#include "test_framework.h"
#include "../src/types.h"
#include "../src/http.h"
#include "../src/config.h"
#include "../src/buffer.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
/* Scratch path of the JSON config fixture: written by
 * create_ssl_rewrite_config() and deleted by cleanup_config(). */
static const char *TEST_CONFIG_FILE = "/tmp/test_host_rewrite_config.json";
/*
 * Write the reverse-proxy fixture used by the config-driven tests to
 * TEST_CONFIG_FILE: port 9999 and a single route mapping
 * "mysite.example.com" to "upstream.internal" on port 443 with use_ssl and
 * rewrite_host both true.
 *
 * The function returns void, so I/O failures cannot be propagated; they are
 * reported on stderr instead of being dropped silently, which makes a
 * missing or truncated fixture visible when a later assertion fails.
 */
static void create_ssl_rewrite_config(void) {
    static const char config_json[] =
        "{\n"
        " \"port\": 9999,\n"
        " \"reverse_proxy\": [\n"
        " {\n"
        " \"hostname\": \"mysite.example.com\",\n"
        " \"upstream_host\": \"upstream.internal\",\n"
        " \"upstream_port\": 443,\n"
        " \"use_ssl\": true,\n"
        " \"rewrite_host\": true\n"
        " }\n"
        " ]\n"
        "}\n";

    FILE *f = fopen(TEST_CONFIG_FILE, "w");
    if (!f) {
        perror(TEST_CONFIG_FILE);
        return;
    }

    /* The text contains no conversion specifiers, so fputs emits the same
     * bytes a format-less fprintf would. */
    if (fputs(config_json, f) == EOF) {
        perror(TEST_CONFIG_FILE);
    }

    /* fclose flushes buffered output; a failure here means the fixture on
     * disk may be incomplete. */
    if (fclose(f) != 0) {
        perror(TEST_CONFIG_FILE);
    }
}
/* Remove the fixture file at TEST_CONFIG_FILE; the result of unlink() is
 * ignored. */
static void cleanup_config(void)
{
    (void)unlink(TEST_CONFIG_FILE);
}
/*
 * Suite: locate the Host header line in a keep-alive GET for /page1.
 *
 * Asserts that http_find_header_line_bounds() returns 1, sets both output
 * pointers, and that the span [line_start, line_end) is exactly the Host
 * header line including its trailing CRLF.
 */
void test_host_header_line_bounds_first_request(void)
{
    TEST_SUITE_BEGIN("Host Header Line Bounds - First Request");

    const char *request =
        "GET /page1 HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "Connection: keep-alive\r\n"
        "\r\n";
    const char *line_start = NULL;
    const char *line_end = NULL;

    int found = http_find_header_line_bounds(request, strlen(request), "Host",
                                             &line_start, &line_end);

    TEST_ASSERT_EQ(1, found, "Host header found in first request");
    TEST_ASSERT(line_start != NULL, "Line start is not NULL");
    TEST_ASSERT(line_end != NULL, "Line end is not NULL");

    /* Copy the span into a NUL-terminated scratch buffer so it can be
     * compared as a string. The comparison is skipped when either bound is
     * missing or the span would not fit. */
    char host_line[256];
    if (line_start != NULL && line_end != NULL &&
        (size_t)(line_end - line_start) < sizeof(host_line)) {
        size_t span = (size_t)(line_end - line_start);

        memcpy(host_line, line_start, span);
        host_line[span] = '\0';
        TEST_ASSERT_STR_EQ("Host: mysite.example.com\r\n", host_line, "Host header line extracted correctly");
    }

    TEST_SUITE_END();
}
/*
 * Suite: locate the Host header line in a keep-alive GET for /page2.
 *
 * Same checks as the first-request suite, run against a different request
 * path: the lookup must return 1, fill in both bounds, and the span
 * [line_start, line_end) must equal the Host line with its CRLF.
 */
void test_host_header_line_bounds_second_request(void)
{
    TEST_SUITE_BEGIN("Host Header Line Bounds - Second Request");

    const char *request =
        "GET /page2 HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "Connection: keep-alive\r\n"
        "\r\n";
    const char *line_start = NULL;
    const char *line_end = NULL;

    int found = http_find_header_line_bounds(request, strlen(request), "Host",
                                             &line_start, &line_end);

    TEST_ASSERT_EQ(1, found, "Host header found in second request");
    TEST_ASSERT(line_start != NULL, "Line start is not NULL");
    TEST_ASSERT(line_end != NULL, "Line end is not NULL");

    /* Only compare the extracted text when both bounds exist and the span
     * fits in the scratch buffer (leaving room for the terminator). */
    char host_line[256];
    if (line_start != NULL && line_end != NULL &&
        (size_t)(line_end - line_start) < sizeof(host_line)) {
        size_t span = (size_t)(line_end - line_start);

        memcpy(host_line, line_start, span);
        host_line[span] = '\0';
        TEST_ASSERT_STR_EQ("Host: mysite.example.com\r\n", host_line, "Host header line extracted correctly");
    }

    TEST_SUITE_END();
}
/*
 * Suite: simulate rewriting the Host header of two consecutive requests.
 *
 * Loads the fixture config, checks the route for "mysite.example.com"
 * (rewrite_host and use_ssl equal to 1, upstream host "upstream.internal"),
 * then for each request:
 *   1. finds the Host header line with http_find_header_line_bounds(),
 *   2. builds a replacement "Host: <upstream>[:port]\r\n" line, omitting the
 *      port when it is 443 with SSL or 80 without,
 *   3. splices prefix + new header + suffix into a freshly allocated copy,
 *   4. asserts the new header is present and the original host is gone.
 *
 * The rewrite itself is performed by this test, not by a project function.
 */
void test_host_rewrite_simulation(void)
{
    TEST_SUITE_BEGIN("Host Rewrite Simulation - Multiple Requests");

    create_ssl_rewrite_config();
    config_load(TEST_CONFIG_FILE);

    route_config_t *route = config_find_route("mysite.example.com");
    TEST_ASSERT(route != NULL, "Route found for mysite.example.com");
    if (route == NULL) {
        /* Nothing else can be checked without a route; tear down and stop. */
        config_free();
        cleanup_config();
        TEST_SUITE_END();
        return;
    }

    TEST_ASSERT_EQ(1, route->rewrite_host, "Rewrite host is enabled");
    TEST_ASSERT_EQ(1, route->use_ssl, "SSL is enabled");
    TEST_ASSERT_STR_EQ("upstream.internal", route->upstream_host, "Upstream host is upstream.internal");

    const char *const requests[2] = {
        "GET /page1 HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "Connection: keep-alive\r\n"
        "\r\n",

        "GET /page2 HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "Connection: keep-alive\r\n"
        "\r\n",
    };

    for (int idx = 0; idx < 2; idx++) {
        const int req_num = idx + 1; /* 1-based number used in messages */
        const char *raw = requests[idx];
        size_t raw_len = strlen(raw);
        const char *host_begin = NULL;
        const char *host_end = NULL;

        int found = http_find_header_line_bounds(raw, raw_len, "Host", &host_begin, &host_end);
        char msg[128];
        snprintf(msg, sizeof(msg), "Request %d: Host header found", req_num);
        TEST_ASSERT_EQ(1, found, msg);
        if (!found || host_begin == NULL || host_end == NULL) {
            continue;
        }

        /* Default port for the scheme in use: 443 with SSL, 80 without. */
        char new_host_header[512];
        if (route->upstream_port == (route->use_ssl ? 443 : 80)) {
            snprintf(new_host_header, sizeof(new_host_header), "Host: %s\r\n", route->upstream_host);
        } else {
            snprintf(new_host_header, sizeof(new_host_header), "Host: %s:%d\r\n",
                     route->upstream_host, route->upstream_port);
        }

        /* Bytes before the old Host line, the replacement, and bytes after it. */
        size_t new_host_len = strlen(new_host_header);
        size_t head_len = (size_t)(host_begin - raw);
        size_t tail_len = raw_len - (size_t)(host_end - raw);
        size_t new_request_len = head_len + new_host_len + tail_len;

        char *modified_request = malloc(new_request_len + 1);
        TEST_ASSERT(modified_request != NULL, "Memory allocated for modified request");
        if (modified_request == NULL) {
            continue;
        }

        memcpy(modified_request, raw, head_len);
        memcpy(modified_request + head_len, new_host_header, new_host_len);
        memcpy(modified_request + head_len + new_host_len, host_end, tail_len);
        modified_request[new_request_len] = '\0';

        int has_new_host = (strstr(modified_request, "Host: upstream.internal\r\n") != NULL);
        snprintf(msg, sizeof(msg), "Request %d: Host rewritten to upstream.internal", req_num);
        TEST_ASSERT_EQ(1, has_new_host, msg);

        int no_old_host = (strstr(modified_request, "Host: mysite.example.com") == NULL);
        snprintf(msg, sizeof(msg), "Request %d: Original host removed", req_num);
        TEST_ASSERT_EQ(1, no_old_host, msg);

        free(modified_request);
    }

    config_free();
    cleanup_config();
    TEST_SUITE_END();
}
/*
 * Suite: check which byte sequences http_is_request_start() accepts.
 *
 * Two GET requests must yield 1; an HTTP response and a four-byte binary
 * blob must yield 0.
 */
void test_pipelined_request_detection(void)
{
    TEST_SUITE_BEGIN("Pipelined Request Detection for Host Rewrite");

    /* Inputs expected to be recognised as the start of a request. */
    const char *first_request =
        "GET /first HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "\r\n";
    const char *second_request =
        "GET /second HTTP/1.1\r\n"
        "Host: mysite.example.com\r\n"
        "\r\n";

    size_t first_len = strlen(first_request);
    int is_first_req_start = http_is_request_start(first_request, first_len);
    TEST_ASSERT_EQ(1, is_first_req_start, "First request detected as HTTP request start");

    size_t second_len = strlen(second_request);
    int is_second_req_start = http_is_request_start(second_request, second_len);
    TEST_ASSERT_EQ(1, is_second_req_start, "Second request detected as HTTP request start");

    /* Inputs expected to be rejected. */
    const char *response_data = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello";
    size_t response_len = strlen(response_data);
    int is_response_req = http_is_request_start(response_data, response_len);
    TEST_ASSERT_EQ(0, is_response_req, "Response data NOT detected as request start");

    /* Begins with a NUL byte, so the length is given explicitly rather
     * than taken with strlen(). */
    const char *binary_data = "\x00\x01\x02\x03";
    int is_binary_req = http_is_request_start(binary_data, 4);
    TEST_ASSERT_EQ(0, is_binary_req, "Binary data NOT detected as request start");

    TEST_SUITE_END();
}
/*
 * Suite: SNI hostname selection when rewrite_host is enabled.
 *
 * Loads the fixture config and, if the route for "mysite.example.com" is
 * found, picks an SNI name locally: the route's upstream host when
 * rewrite_host is set, otherwise the client-facing hostname. The result
 * must be "upstream.internal".
 */
void test_ssl_sni_with_host_rewrite(void)
{
    TEST_SUITE_BEGIN("SSL SNI Hostname with Host Rewrite");

    create_ssl_rewrite_config();
    config_load(TEST_CONFIG_FILE);

    route_config_t *route = config_find_route("mysite.example.com");
    TEST_ASSERT(route != NULL, "Route found");

    if (route != NULL) {
        /* Default to the original hostname; override when rewriting. */
        const char *sni_hostname = "mysite.example.com";
        if (route->rewrite_host) {
            sni_hostname = route->upstream_host;
        }
        TEST_ASSERT_STR_EQ("upstream.internal", sni_hostname, "SNI should be upstream.internal when rewrite_host is true");
    }

    config_free();
    cleanup_config();
    TEST_SUITE_END();
}
/*
 * Suite: five back-to-back requests (four GETs and one POST) for the same
 * host, as they would arrive on one kept-alive connection.
 *
 * For each request the test asserts that http_parse_request() returns 1,
 * that config_find_route() resolves the parsed host, that the route has
 * rewrite_host equal to 1, and that http_find_header_line_bounds() locates
 * the Host header.
 *
 * parsed_req is zeroed before every parse and the host-dependent checks are
 * skipped when parsing did not return 1, so parsed_req.host is never read
 * while its contents are indeterminate.
 */
void test_consecutive_requests_same_connection(void) {
    TEST_SUITE_BEGIN("Consecutive Requests on Same Connection");

    create_ssl_rewrite_config();
    config_load(TEST_CONFIG_FILE);

    route_config_t *route = config_find_route("mysite.example.com");
    TEST_ASSERT(route != NULL, "Route found");
    if (!route) {
        /* Without the route the per-request checks are meaningless. */
        config_free();
        cleanup_config();
        TEST_SUITE_END();
        return;
    }

    const char *requests[] = {
        "GET /page1 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: keep-alive\r\n\r\n",
        "GET /page2 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: keep-alive\r\n\r\n",
        "GET /page3 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: keep-alive\r\n\r\n",
        "POST /api HTTP/1.1\r\nHost: mysite.example.com\r\nContent-Length: 0\r\n\r\n",
        "GET /page4 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: close\r\n\r\n"
    };
    int num_requests = (int)(sizeof(requests) / sizeof(requests[0]));

    for (int i = 0; i < num_requests; i++) {
        const char *request = requests[i];
        size_t request_len = strlen(request);
        char msg[128];

        /* Start from a known state so a parser that bails out early cannot
         * leave stack garbage behind. */
        http_request_t parsed_req;
        memset(&parsed_req, 0, sizeof(parsed_req));

        int parse_result = http_parse_request(request, request_len, &parsed_req);
        snprintf(msg, sizeof(msg), "Request %d: Parsed successfully", i + 1);
        TEST_ASSERT_EQ(1, parse_result, msg);
        if (parse_result != 1) {
            /* The failure is already recorded above; the remaining checks
             * all depend on parsed_req.host, which cannot be trusted. */
            continue;
        }

        route_config_t *found_route = config_find_route(parsed_req.host);
        snprintf(msg, sizeof(msg), "Request %d: Route found", i + 1);
        TEST_ASSERT(found_route != NULL, msg);
        if (found_route) {
            snprintf(msg, sizeof(msg), "Request %d: Rewrite host enabled", i + 1);
            TEST_ASSERT_EQ(1, found_route->rewrite_host, msg);

            const char *old_host_start = NULL;
            const char *old_host_end = NULL;
            int found = http_find_header_line_bounds(request, request_len, "Host",
                                                     &old_host_start, &old_host_end);
            snprintf(msg, sizeof(msg), "Request %d: Host header found for rewrite", i + 1);
            TEST_ASSERT_EQ(1, found, msg);
        }
    }

    config_free();
    cleanup_config();
    TEST_SUITE_END();
}
void run_host_rewrite_tests(void) {
test_host_header_line_bounds_first_request();
test_host_header_line_bounds_second_request();
test_host_rewrite_simulation();
test_pipelined_request_detection();
test_ssl_sni_with_host_rewrite();
test_consecutive_requests_same_connection();
}