#include "test_framework.h"
#include "../src/types.h"
#include "../src/http.h"
#include "../src/config.h"
#include "../src/buffer.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/*
 * Tests for Host-header rewriting on reverse-proxy routes: locating the Host
 * line inside a request, splicing in the upstream host, request-start
 * detection, and route lookup across several requests on one connection.
 */

/* Scratch config file written by the tests that need a loaded route table. */
static const char *TEST_CONFIG_FILE = "/tmp/test_host_rewrite_config.json";

/*
 * Write TEST_CONFIG_FILE with a single reverse-proxy route:
 * mysite.example.com -> upstream.internal:443, use_ssl and rewrite_host on.
 * Returns 1 if the file was written and closed successfully, 0 otherwise.
 */
static int create_ssl_rewrite_config(void)
{
    FILE *f = fopen(TEST_CONFIG_FILE, "w");
    if (f == NULL) {
        return 0;
    }

    int write_ok = fprintf(f,
                           "{\n"
                           "  \"port\": 9999,\n"
                           "  \"reverse_proxy\": [\n"
                           "    {\n"
                           "      \"hostname\": \"mysite.example.com\",\n"
                           "      \"upstream_host\": \"upstream.internal\",\n"
                           "      \"upstream_port\": 443,\n"
                           "      \"use_ssl\": true,\n"
                           "      \"rewrite_host\": true\n"
                           "    }\n"
                           "  ]\n"
                           "}\n") >= 0;
    /* fclose flushes buffered output, so a failure here also means a bad file. */
    int close_ok = fclose(f) == 0;
    return write_ok && close_ok;
}

/* Best-effort removal of the scratch config; the unlink result is ignored. */
static void cleanup_config(void)
{
    unlink(TEST_CONFIG_FILE);
}

/*
 * Copy the byte range [start, end) into buf as a NUL-terminated string.
 * Returns 1 on success, 0 if the range plus terminator does not fit in buf.
 */
static int copy_span(const char *start, const char *end, char *buf, size_t buf_size)
{
    size_t len = (size_t)(end - start);
    if (len >= buf_size) {
        return 0;
    }
    memcpy(buf, start, len);
    buf[len] = '\0';
    return 1;
}

/*
 * Format the replacement "Host: ...\r\n" line for a route into buf.
 * The port is omitted when it is the scheme default (443 with SSL, 80 without).
 */
static void format_upstream_host_header(const route_config_t *route, char *buf, size_t buf_size)
{
    int is_default_port = (route->use_ssl && route->upstream_port == 443) ||
                          (!route->use_ssl && route->upstream_port == 80);
    if (is_default_port) {
        snprintf(buf, buf_size, "Host: %s\r\n", route->upstream_host);
    } else {
        /* Cast so the argument matches %d regardless of the field's integer type. */
        snprintf(buf, buf_size, "Host: %s:%d\r\n", route->upstream_host,
                 (int)route->upstream_port);
    }
}

/*
 * Return a new NUL-terminated copy of request[0..request_len) in which the
 * bytes [old_start, old_end) are replaced by `replacement`.
 * old_start/old_end must point into request. The caller owns the result and
 * must free() it. Returns NULL on allocation failure.
 */
static char *splice_header_line(const char *request, size_t request_len,
                                const char *old_start, const char *old_end,
                                const char *replacement)
{
    size_t prefix_len = (size_t)(old_start - request);
    size_t old_len = (size_t)(old_end - old_start);
    size_t suffix_len = request_len - prefix_len - old_len;
    size_t new_len = strlen(replacement);

    char *out = malloc(prefix_len + new_len + suffix_len + 1);
    if (out == NULL) {
        return NULL;
    }
    memcpy(out, request, prefix_len);
    memcpy(out + prefix_len, replacement, new_len);
    memcpy(out + prefix_len + new_len, old_end, suffix_len);
    out[prefix_len + new_len + suffix_len] = '\0';
    return out;
}

/* http_find_header_line_bounds must return the full Host line, CRLF included. */
void test_host_header_line_bounds_first_request(void)
{
    TEST_SUITE_BEGIN("Host Header Line Bounds - First Request");

    const char *request =
        "GET /page1 HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "Connection: keep-alive\r\n"
        "\r\n";
    size_t len = strlen(request);
    const char *line_start = NULL;
    const char *line_end = NULL;

    int found = http_find_header_line_bounds(request, len, "Host", &line_start, &line_end);
    TEST_ASSERT_EQ(1, found, "Host header found in first request");
    TEST_ASSERT(line_start != NULL, "Line start is not NULL");
    TEST_ASSERT(line_end != NULL, "Line end is not NULL");

    if (line_start && line_end) {
        char host_line[256];
        int fits = copy_span(line_start, line_end, host_line, sizeof(host_line));
        /* Fail loudly instead of silently skipping the comparison. */
        TEST_ASSERT(fits, "Host header line fits in extraction buffer");
        if (fits) {
            TEST_ASSERT_STR_EQ("Host: mysite.example.com\r\n", host_line,
                               "Host header line extracted correctly");
        }
    }

    TEST_SUITE_END();
}

/* Same check on a second request, as would follow the first on a keep-alive connection. */
void test_host_header_line_bounds_second_request(void)
{
    TEST_SUITE_BEGIN("Host Header Line Bounds - Second Request");

    const char *request =
        "GET /page2 HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "Connection: keep-alive\r\n"
        "\r\n";
    size_t len = strlen(request);
    const char *line_start = NULL;
    const char *line_end = NULL;

    int found = http_find_header_line_bounds(request, len, "Host", &line_start, &line_end);
    TEST_ASSERT_EQ(1, found, "Host header found in second request");
    TEST_ASSERT(line_start != NULL, "Line start is not NULL");
    TEST_ASSERT(line_end != NULL, "Line end is not NULL");

    if (line_start && line_end) {
        char host_line[256];
        int fits = copy_span(line_start, line_end, host_line, sizeof(host_line));
        /* Fail loudly instead of silently skipping the comparison. */
        TEST_ASSERT(fits, "Host header line fits in extraction buffer");
        if (fits) {
            TEST_ASSERT_STR_EQ("Host: mysite.example.com\r\n", host_line,
                               "Host header line extracted correctly");
        }
    }

    TEST_SUITE_END();
}

/*
 * Load the SSL rewrite route, then for each request replace its Host line with
 * the upstream host. Check that the new host is present and the original is gone.
 */
void test_host_rewrite_simulation(void)
{
    TEST_SUITE_BEGIN("Host Rewrite Simulation - Multiple Requests");

    int config_written = create_ssl_rewrite_config();
    TEST_ASSERT(config_written, "Test config file written");
    config_load(TEST_CONFIG_FILE);

    route_config_t *route = config_find_route("mysite.example.com");
    TEST_ASSERT(route != NULL, "Route found for mysite.example.com");
    if (!route) {
        config_free();
        cleanup_config();
        TEST_SUITE_END();
        return;
    }

    TEST_ASSERT_EQ(1, route->rewrite_host, "Rewrite host is enabled");
    TEST_ASSERT_EQ(1, route->use_ssl, "SSL is enabled");
    TEST_ASSERT_STR_EQ("upstream.internal", route->upstream_host,
                       "Upstream host is upstream.internal");

    const char *requests[] = {
        "GET /page1 HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "Connection: keep-alive\r\n"
        "\r\n",
        "GET /page2 HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "Connection: keep-alive\r\n"
        "\r\n",
    };
    const int num_requests = (int)(sizeof(requests) / sizeof(requests[0]));

    /* The replacement line depends only on the route, so build it once. */
    char new_host_header[512];
    format_upstream_host_header(route, new_host_header, sizeof(new_host_header));

    for (int i = 0; i < num_requests; i++) {
        int req_num = i + 1;
        const char *request = requests[i];
        size_t request_len = strlen(request);
        const char *old_host_start = NULL;
        const char *old_host_end = NULL;
        char msg[128];

        int found = http_find_header_line_bounds(request, request_len, "Host",
                                                 &old_host_start, &old_host_end);
        snprintf(msg, sizeof(msg), "Request %d: Host header found", req_num);
        TEST_ASSERT_EQ(1, found, msg);
        if (!found || !old_host_start || !old_host_end) {
            continue;
        }

        char *modified_request = splice_header_line(request, request_len, old_host_start,
                                                    old_host_end, new_host_header);
        TEST_ASSERT(modified_request != NULL, "Memory allocated for modified request");
        if (!modified_request) {
            continue;
        }

        int has_new_host = strstr(modified_request, "Host: upstream.internal\r\n") != NULL;
        snprintf(msg, sizeof(msg), "Request %d: Host rewritten to upstream.internal", req_num);
        TEST_ASSERT_EQ(1, has_new_host, msg);

        int no_old_host = strstr(modified_request, "Host: mysite.example.com") == NULL;
        snprintf(msg, sizeof(msg), "Request %d: Original host removed", req_num);
        TEST_ASSERT_EQ(1, no_old_host, msg);

        free(modified_request);
    }

    config_free();
    cleanup_config();
    TEST_SUITE_END();
}

/* http_is_request_start must accept request lines and reject responses and binary data. */
void test_pipelined_request_detection(void)
{
    TEST_SUITE_BEGIN("Pipelined Request Detection for Host Rewrite");

    const char *first_request =
        "GET /first HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "\r\n";
    const char *second_request =
        "GET /second HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "\r\n";

    int is_first_req_start = http_is_request_start(first_request, strlen(first_request));
    TEST_ASSERT_EQ(1, is_first_req_start, "First request detected as HTTP request start");

    int is_second_req_start = http_is_request_start(second_request, strlen(second_request));
    TEST_ASSERT_EQ(1, is_second_req_start, "Second request detected as HTTP request start");

    const char *response_data = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    int is_response_req = http_is_request_start(response_data, strlen(response_data));
    TEST_ASSERT_EQ(0, is_response_req, "Response data NOT detected as request start");

    /* Length is passed explicitly because the data begins with a NUL byte. */
    const char *binary_data = "\x00\x01\x02\x03";
    int is_binary_req = http_is_request_start(binary_data, 4);
    TEST_ASSERT_EQ(0, is_binary_req, "Binary data NOT detected as request start");

    TEST_SUITE_END();
}

/* With rewrite_host enabled, the SNI name should be the upstream host. */
void test_ssl_sni_with_host_rewrite(void)
{
    TEST_SUITE_BEGIN("SSL SNI Hostname with Host Rewrite");

    int config_written = create_ssl_rewrite_config();
    TEST_ASSERT(config_written, "Test config file written");
    config_load(TEST_CONFIG_FILE);

    route_config_t *route = config_find_route("mysite.example.com");
    TEST_ASSERT(route != NULL, "Route found");
    if (route) {
        /*
         * NOTE(review): this re-implements the SNI selection inline rather than
         * calling proxy code, so it only checks the loaded route fields.
         * Consider asserting against the real SNI-selection function if one exists.
         */
        const char *sni_hostname = route->rewrite_host ? route->upstream_host
                                                       : "mysite.example.com";
        TEST_ASSERT_STR_EQ("upstream.internal", sni_hostname,
                           "SNI should be upstream.internal when rewrite_host is true");
    }

    config_free();
    cleanup_config();
    TEST_SUITE_END();
}

/*
 * Parse several requests as if sent on one keep-alive connection. Each must
 * resolve to the rewrite route and expose a Host line that can be rewritten.
 */
void test_consecutive_requests_same_connection(void)
{
    TEST_SUITE_BEGIN("Consecutive Requests on Same Connection");

    int config_written = create_ssl_rewrite_config();
    TEST_ASSERT(config_written, "Test config file written");
    config_load(TEST_CONFIG_FILE);

    route_config_t *route = config_find_route("mysite.example.com");
    TEST_ASSERT(route != NULL, "Route found");
    if (!route) {
        config_free();
        cleanup_config();
        TEST_SUITE_END();
        return;
    }

    const char *requests[] = {
        "GET /page1 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: keep-alive\r\n\r\n",
        "GET /page2 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: keep-alive\r\n\r\n",
        "GET /page3 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: keep-alive\r\n\r\n",
        "POST /api HTTP/1.1\r\nHost: mysite.example.com\r\nContent-Length: 0\r\n\r\n",
        "GET /page4 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: close\r\n\r\n"
    };
    const int num_requests = (int)(sizeof(requests) / sizeof(requests[0]));

    for (int i = 0; i < num_requests; i++) {
        const char *request = requests[i];
        size_t request_len = strlen(request);
        char msg[128];

        /*
         * Zero-init so no field is indeterminate. On a failed parse the request
         * is skipped below, so parsed_req.host is never read.
         * NOTE(review): parsed_req is never released; confirm http_request_t
         * owns no heap memory.
         */
        http_request_t parsed_req;
        memset(&parsed_req, 0, sizeof(parsed_req));
        int parse_result = http_parse_request(request, request_len, &parsed_req);
        snprintf(msg, sizeof(msg), "Request %d: Parsed successfully", i + 1);
        TEST_ASSERT_EQ(1, parse_result, msg);
        if (parse_result != 1) {
            continue;
        }

        route_config_t *found_route = config_find_route(parsed_req.host);
        snprintf(msg, sizeof(msg), "Request %d: Route found", i + 1);
        TEST_ASSERT(found_route != NULL, msg);
        if (!found_route) {
            continue;
        }

        snprintf(msg, sizeof(msg), "Request %d: Rewrite host enabled", i + 1);
        TEST_ASSERT_EQ(1, found_route->rewrite_host, msg);

        const char *old_host_start = NULL;
        const char *old_host_end = NULL;
        int found = http_find_header_line_bounds(request, request_len, "Host",
                                                 &old_host_start, &old_host_end);
        snprintf(msg, sizeof(msg), "Request %d: Host header found for rewrite", i + 1);
        TEST_ASSERT_EQ(1, found, msg);
    }

    config_free();
    cleanup_config();
    TEST_SUITE_END();
}

/* Entry point: run every Host-rewrite test suite in order. */
void run_host_rewrite_tests(void)
{
    test_host_header_line_bounds_first_request();
    test_host_header_line_bounds_second_request();
    test_host_rewrite_simulation();
    test_pipelined_request_detection();
    test_ssl_sni_with_host_rewrite();
    test_consecutive_requests_same_connection();
}