diff --git a/Makefile b/Makefile index 0f3d850..d1bd1a0 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,8 @@ CC = gcc CFLAGS = -Wall -Wextra -O2 -g -D_GNU_SOURCE +CFLAGS_COV = -Wall -Wextra -g -D_GNU_SOURCE --coverage -fprofile-arcs -ftest-coverage LDFLAGS = -lssl -lcrypto -lsqlite3 -lm -lpthread +LDFLAGS_COV = -lssl -lcrypto -lsqlite3 -lm -lpthread --coverage SRC_DIR = src BUILD_DIR = build @@ -29,7 +31,12 @@ TEST_SOURCES = $(TESTS_DIR)/test_main.c \ $(TESTS_DIR)/test_http.c \ $(TESTS_DIR)/test_buffer.c \ $(TESTS_DIR)/test_config.c \ - $(TESTS_DIR)/test_routing.c + $(TESTS_DIR)/test_routing.c \ + $(TESTS_DIR)/test_host_rewrite.c \ + $(TESTS_DIR)/test_http_helpers.c \ + $(TESTS_DIR)/test_patch.c \ + $(TESTS_DIR)/test_auth.c \ + $(TESTS_DIR)/test_rate_limit.c TEST_OBJECTS = $(patsubst %.c,$(BUILD_DIR)/%.o,$(notdir $(TEST_SOURCES))) @@ -51,7 +58,10 @@ TEST_LIB_OBJECTS = $(patsubst %.c,$(BUILD_DIR)/%.o,$(notdir $(TEST_LIB_SOURCES)) TEST_TARGET = rproxy_test -.PHONY: all clean test legacy run +MIN_COVERAGE = 70 +COVERAGE_MODULES = auth.c buffer.c config.c http.c logging.c patch.c rate_limit.c + +.PHONY: all clean test legacy run coverage coverage-html all: $(BUILD_DIR) $(TARGET) @@ -118,6 +128,21 @@ $(BUILD_DIR)/test_config.o: $(TESTS_DIR)/test_config.c $(BUILD_DIR)/test_routing.o: $(TESTS_DIR)/test_routing.c $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ +$(BUILD_DIR)/test_host_rewrite.o: $(TESTS_DIR)/test_host_rewrite.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + +$(BUILD_DIR)/test_http_helpers.o: $(TESTS_DIR)/test_http_helpers.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + +$(BUILD_DIR)/test_patch.o: $(TESTS_DIR)/test_patch.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + +$(BUILD_DIR)/test_auth.o: $(TESTS_DIR)/test_auth.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + +$(BUILD_DIR)/test_rate_limit.o: $(TESTS_DIR)/test_rate_limit.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + $(TEST_TARGET): $(BUILD_DIR) $(TEST_OBJECTS) $(TEST_LIB_OBJECTS) $(CC) $(TEST_OBJECTS) $(TEST_LIB_OBJECTS) -o $@ $(LDFLAGS) @@ -130,5 +155,83 @@ legacy: rproxy.c cJSON.c cJSON.h run: $(TARGET) ./$(TARGET) +coverage: clean + mkdir -p $(BUILD_DIR) + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_main.c -o build/test_main.o + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_http.c -o build/test_http.o + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_buffer.c -o build/test_buffer.o + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_config.c -o build/test_config.o + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_routing.c -o build/test_routing.o + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_host_rewrite.c -o build/test_host_rewrite.o + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_http_helpers.c -o build/test_http_helpers.o + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_patch.c -o build/test_patch.o + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_auth.c -o build/test_auth.o + $(CC) $(CFLAGS_COV) -Isrc -c tests/test_rate_limit.c -o build/test_rate_limit.o + $(CC) $(CFLAGS_COV) -c src/buffer.c -o build/buffer.o + $(CC) $(CFLAGS_COV) -c src/logging.c -o build/logging.o + $(CC) $(CFLAGS_COV) -c src/config.c -o build/config.o + $(CC) $(CFLAGS_COV) -c src/monitor.c -o build/monitor.o + $(CC) $(CFLAGS_COV) -c src/http.c -o build/http.o + $(CC) $(CFLAGS_COV) -c src/ssl_handler.c -o build/ssl_handler.o + $(CC) $(CFLAGS_COV) -c src/connection.c -o build/connection.o + $(CC) $(CFLAGS_COV) -c src/dashboard.c -o build/dashboard.o + $(CC) $(CFLAGS_COV) -c src/rate_limit.c -o build/rate_limit.o + $(CC) $(CFLAGS_COV) -c src/auth.c -o build/auth.o + $(CC) $(CFLAGS_COV) -c src/health_check.c -o build/health_check.o + $(CC) $(CFLAGS_COV) -c src/patch.c -o build/patch.o + $(CC) $(CFLAGS_COV) -c cJSON.c -o build/cJSON.o + $(CC) $(TEST_OBJECTS) $(TEST_LIB_OBJECTS) -o $(TEST_TARGET) $(LDFLAGS_COV) + ./$(TEST_TARGET) + @echo "" + @echo "=== Coverage Report ===" + @gcov -n build/*.gcno 2>/dev/null | grep -A 1 "^File.*src/" | grep -v "^--$$" || true + @echo "" + @echo "=== Coverage Check (minimum $(MIN_COVERAGE)%) ===" + @echo "Checking modules: $(COVERAGE_MODULES)" + @total_lines=0; covered_lines=0; \ + for gcno in build/*.gcno; do \ + output=$$(gcov -n "$$gcno" 2>/dev/null); \ + file=$$(echo "$$output" | grep "^File.*src/" | head -1 | sed "s/File '\\(.*\\)'/\\1/"); \ + if [ -n "$$file" ]; then \ + basename=$$(basename "$$file"); \ + is_tracked=0; \ + for mod in $(COVERAGE_MODULES); do \ + if [ "$$basename" = "$$mod" ]; then is_tracked=1; break; fi; \ + done; \ + if [ $$is_tracked -eq 1 ]; then \ + lines=$$(echo "$$output" | grep "Lines executed:" | head -1 | sed 's/.*of \([0-9]*\)/\1/'); \ + pct=$$(echo "$$output" | grep "Lines executed:" | head -1 | sed 's/Lines executed:\([0-9.]*\)%.*/\1/'); \ + if [ -n "$$lines" ] && [ -n "$$pct" ]; then \ + covered=$$(echo "$$pct * $$lines / 100" | bc -l | cut -d. -f1); \ + covered=$${covered:-0}; \ + total_lines=$$((total_lines + lines)); \ + covered_lines=$$((covered_lines + covered)); \ + echo " $$basename: $$pct% ($$covered/$$lines lines)"; \ + fi; \ + fi; \ + fi; \ + done; \ + if [ $$total_lines -gt 0 ]; then \ + avg=$$(echo "scale=2; $$covered_lines * 100 / $$total_lines" | bc -l); \ + echo ""; \ + echo "Total: $$covered_lines/$$total_lines lines covered ($$avg%)"; \ + avg_int=$$(echo "$$avg" | cut -d. -f1); \ + if [ $$avg_int -lt $(MIN_COVERAGE) ]; then \ + echo "FAILED: Coverage $$avg% is below minimum $(MIN_COVERAGE)%"; \ + exit 1; \ + else \ + echo "PASSED: Coverage $$avg% meets minimum $(MIN_COVERAGE)%"; \ + fi; \ + else \ + echo "FAILED: No coverage data found"; \ + exit 1; \ + fi + +coverage-html: coverage + mkdir -p coverage_report + lcov --capture --directory build --output-file coverage_report/coverage.info --ignore-errors source 2>/dev/null || true + genhtml coverage_report/coverage.info --output-directory coverage_report 2>/dev/null || echo "Install lcov for HTML reports: sudo apt install lcov" + @echo "Coverage report generated in coverage_report/index.html" + clean: - rm -rf $(BUILD_DIR) $(TARGET) $(TEST_TARGET) rproxy_legacy + rm -rf $(BUILD_DIR) $(TARGET) $(TEST_TARGET) rproxy_legacy *.gcov coverage_report diff --git a/src/config.c b/src/config.c index 2986942..51f1092 100644 --- a/src/config.c +++ b/src/config.c @@ -244,7 +244,7 @@ int config_load(const char *filename) { route->patches.rule_count++; } if (route->patches.rule_count > 0) { - log_info("Loaded %d patch rules for %s", route->hostname); + log_info("Loaded %d patch rules for %s", route->patches.rule_count, route->hostname); } } diff --git a/src/connection.c b/src/connection.c index 7cac766..0ed9d80 100644 --- a/src/connection.c +++ b/src/connection.c @@ -168,6 +168,15 @@ void connection_close(int fd) { log_debug("Upstream fd %d is closing. Resetting client fd %d to READING_HEADERS.", fd, pair->fd); pair->state = CLIENT_STATE_READING_HEADERS; pair->pair = NULL; + pair->route = NULL; + pair->content_type_checked = 0; + pair->is_textual_content = 0; + pair->response_headers_parsed = 0; + pair->original_content_length = 0; + pair->content_length_delta = 0; + pair->patch_blocked = 0; + pair->half_closed = 0; + pair->write_shutdown = 0; } else if (conn->type == CONN_TYPE_CLIENT && pair->type == CONN_TYPE_UPSTREAM) { log_debug("Client fd %d is closing. Closing orphaned upstream pair fd %d.", fd, pair->fd); pair->pair = NULL; @@ -698,6 +707,31 @@ static void handle_forwarding(connection_t *conn) { int bytes_read = connection_do_read(conn); + if (conn->type == CONN_TYPE_CLIENT && bytes_read > 0) { + char *data_start = conn->read_buf.data + conn->read_buf.head; + size_t data_len = buffer_available_read(&conn->read_buf); + + if (data_len >= 4 && http_is_request_start(data_start, data_len)) { + log_debug("Pipelined request detected in handle_forwarding on fd %d, closing upstream fd %d", + conn->fd, pair->fd); + connection_close(pair->fd); + conn->pair = NULL; + conn->state = CLIENT_STATE_READING_HEADERS; + + conn->content_type_checked = 0; + conn->is_textual_content = 0; + conn->response_headers_parsed = 0; + conn->original_content_length = 0; + conn->content_length_delta = 0; + conn->patch_blocked = 0; + conn->half_closed = 0; + conn->write_shutdown = 0; + + handle_client_read(conn); + return; + } + } + if (bytes_read == 0) { log_debug("EOF on fd %d, performing half-close on pair fd %d", conn->fd, pair->fd); conn->half_closed = 1; @@ -841,14 +875,20 @@ static void handle_ssl_handshake(connection_t *conn) { connection_t *client = conn->pair; if (buffer_available_read(&client->read_buf) > 0) { + char *data_start = client->read_buf.data + client->read_buf.head; size_t data_len = buffer_available_read(&client->read_buf); - if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + data_len) == 0) { - memcpy(conn->write_buf.data + conn->write_buf.tail, - client->read_buf.data + client->read_buf.head, - data_len); - conn->write_buf.tail += data_len; - buffer_consume(&client->read_buf, data_len); - log_debug("Forwarding %zu bytes of buffered request data after SSL handshake", data_len); + + if (data_len >= 4 && http_is_request_start(data_start, data_len)) { + log_debug("New HTTP request detected in client buffer after SSL handshake, not forwarding raw"); + } else { + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + data_len) == 0) { + memcpy(conn->write_buf.data + conn->write_buf.tail, + client->read_buf.data + client->read_buf.head, + data_len); + conn->write_buf.tail += data_len; + buffer_consume(&client->read_buf, data_len); + log_debug("Forwarding %zu bytes of buffered request data after SSL handshake", data_len); + } } } } diff --git a/src/health_check.c b/src/health_check.c index 92983e7..96f3e60 100644 --- a/src/health_check.c +++ b/src/health_check.c @@ -34,7 +34,7 @@ void health_check_init(void) { free(health_states); } - health_state_count = config.route_count; + health_state_count = config->route_count; if (health_state_count <= 0) { health_states = NULL; pthread_mutex_unlock(&health_mutex); @@ -49,9 +49,9 @@ void health_check_init(void) { } for (int i = 0; i < health_state_count; i++) { - strncpy(health_states[i].hostname, config.routes[i].hostname, sizeof(health_states[i].hostname) - 1); - strncpy(health_states[i].upstream_host, config.routes[i].upstream_host, sizeof(health_states[i].upstream_host) - 1); - health_states[i].upstream_port = config.routes[i].upstream_port; + strncpy(health_states[i].hostname, config->routes[i].hostname, sizeof(health_states[i].hostname) - 1); + strncpy(health_states[i].upstream_host, config->routes[i].upstream_host, sizeof(health_states[i].upstream_host) - 1); + health_states[i].upstream_port = config->routes[i].upstream_port; health_states[i].healthy = 1; health_states[i].consecutive_failures = 0; health_states[i].last_check = 0; diff --git a/tests/test_auth.c b/tests/test_auth.c new file mode 100644 index 0000000..a1376ce --- /dev/null +++ b/tests/test_auth.c @@ -0,0 +1,107 @@ +#include "test_framework.h" +#include "../src/types.h" +#include "../src/auth.h" +#include + +void test_auth_init_and_enabled(void) { + TEST_SUITE_BEGIN("Auth Init and Enabled"); + + auth_init(NULL, NULL); + TEST_ASSERT_EQ(0, auth_is_enabled(), "Auth disabled with NULL credentials"); + + auth_init("", ""); + TEST_ASSERT_EQ(0, auth_is_enabled(), "Auth disabled with empty credentials"); + + auth_init("admin", "secret"); + TEST_ASSERT_EQ(1, auth_is_enabled(), "Auth enabled with valid credentials"); + + TEST_SUITE_END(); +} + +void test_auth_check_credentials(void) { + TEST_SUITE_BEGIN("Auth Check Credentials"); + + auth_init("testuser", "testpass"); + + TEST_ASSERT_EQ(1, auth_check_credentials("testuser", "testpass"), "Correct credentials accepted"); + TEST_ASSERT_EQ(0, auth_check_credentials("testuser", "wrongpass"), "Wrong password rejected"); + TEST_ASSERT_EQ(0, auth_check_credentials("wronguser", "testpass"), "Wrong username rejected"); + TEST_ASSERT_EQ(0, auth_check_credentials("wronguser", "wrongpass"), "Wrong both rejected"); + TEST_ASSERT_EQ(0, auth_check_credentials(NULL, "testpass"), "NULL username rejected"); + TEST_ASSERT_EQ(0, auth_check_credentials("testuser", NULL), "NULL password rejected"); + + TEST_SUITE_END(); +} + +void test_auth_check_basic_auth(void) { + TEST_SUITE_BEGIN("Auth Check Basic Auth Header"); + + auth_init("admin", "password123"); + + char error[256]; + + int result = auth_check_basic_auth("Basic YWRtaW46cGFzc3dvcmQxMjM=", error, sizeof(error)); + TEST_ASSERT_EQ(1, result, "Valid Basic auth header accepted"); + + result = auth_check_basic_auth("Basic aW52YWxpZDppbnZhbGlk", error, sizeof(error)); + TEST_ASSERT_EQ(0, result, "Invalid credentials rejected"); + + result = auth_check_basic_auth(NULL, error, sizeof(error)); + TEST_ASSERT_EQ(0, result, "NULL auth header rejected"); + + result = auth_check_basic_auth("Bearer token123", error, sizeof(error)); + TEST_ASSERT_EQ(0, result, "Bearer auth rejected"); + + result = auth_check_basic_auth("Basic !!invalid!!", error, sizeof(error)); + TEST_ASSERT_EQ(0, result, "Invalid base64 rejected"); + + result = auth_check_basic_auth("Basic bm9jb2xvbg==", error, sizeof(error)); + TEST_ASSERT_EQ(0, result, "Base64 without colon rejected"); + + TEST_SUITE_END(); +} + +void test_auth_route_basic_auth(void) { + TEST_SUITE_BEGIN("Auth Route Basic Auth"); + + route_config_t route; + memset(&route, 0, sizeof(route)); + route.use_auth = 0; + + char error[256]; + int result = auth_check_route_basic_auth(&route, NULL, error, sizeof(error)); + TEST_ASSERT_EQ(1, result, "Route without auth passes"); + + route.use_auth = 1; + strcpy(route.username, "routeuser"); + strcpy(route.password_hash, "5e884898da28047d9166d19d3de276ee6e8b0cb37d63e0c4c7ed6e1b4e5c7fd2"); + + result = auth_check_route_basic_auth(&route, NULL, error, sizeof(error)); + TEST_ASSERT_EQ(0, result, "Route with auth requires header"); + + result = auth_check_route_basic_auth(NULL, NULL, error, sizeof(error)); + TEST_ASSERT_EQ(1, result, "NULL route passes"); + + TEST_SUITE_END(); +} + +void test_auth_disabled_passthrough(void) { + TEST_SUITE_BEGIN("Auth Disabled Passthrough"); + + auth_init(NULL, NULL); + + TEST_ASSERT_EQ(1, auth_check_credentials("anyone", "anything"), "Disabled auth passes any credentials"); + + char error[256]; + TEST_ASSERT_EQ(1, auth_check_basic_auth(NULL, error, sizeof(error)), "Disabled auth passes NULL header"); + + TEST_SUITE_END(); +} + +void run_auth_tests(void) { + test_auth_init_and_enabled(); + test_auth_check_credentials(); + test_auth_check_basic_auth(); + test_auth_route_basic_auth(); + test_auth_disabled_passthrough(); +} diff --git a/tests/test_config.c b/tests/test_config.c index d663848..9f54d93 100644 --- a/tests/test_config.c +++ b/tests/test_config.c @@ -46,8 +46,8 @@ void test_config_load_valid(void) { int result = config_load(TEST_CONFIG_FILE); TEST_ASSERT_EQ(1, result, "Config loaded successfully"); - TEST_ASSERT_EQ(9090, config.port, "Port is 9090"); - TEST_ASSERT_EQ(2, config.route_count, "Two routes configured"); + TEST_ASSERT_EQ(9090, config->port, "Port is 9090"); + TEST_ASSERT_EQ(2, config->route_count, "Two routes configured"); route_config_t *route1 = config_find_route("test.example.com"); TEST_ASSERT(route1 != NULL, "Route for test.example.com found"); @@ -159,7 +159,7 @@ void test_config_default_port(void) { create_test_config(config_content); config_load(TEST_CONFIG_FILE); - TEST_ASSERT_EQ(8080, config.port, "Default port is 8080 when not specified"); + TEST_ASSERT_EQ(8080, config->port, "Default port is 8080 when not specified"); config_free(); cleanup_test_config(); @@ -239,6 +239,437 @@ void test_config_ssl_rewrite_host_options(void) { TEST_SUITE_END(); } +void test_config_invalid_port(void) { + TEST_SUITE_BEGIN("Config Invalid Port"); + + const char *config_content = + "{\n" + " \"port\": 99999,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"test.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_content); + int result = config_load(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(0, result, "Invalid port (99999) rejected"); + + cleanup_test_config(); + + const char *config_zero = + "{\n" + " \"port\": 0,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"test.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_zero); + result = config_load(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(0, result, "Invalid port (0) rejected"); + + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_empty_routes(void) { + TEST_SUITE_BEGIN("Config Empty Routes"); + + const char *config_content = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": []\n" + "}\n"; + + create_test_config(config_content); + int result = config_load(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(0, result, "Empty routes array rejected"); + + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_with_auth(void) { + TEST_SUITE_BEGIN("Config With Authentication"); + + const char *config_content = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"secure.example.com\",\n" + " \"upstream_host\": \"backend.local\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false,\n" + " \"use_auth\": true,\n" + " \"username\": \"admin\",\n" + " \"password\": \"secret123\"\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_content); + int result = config_load(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(1, result, "Config with auth loaded"); + + route_config_t *route = config_find_route("secure.example.com"); + TEST_ASSERT(route != NULL, "Auth route found"); + if (route) { + TEST_ASSERT_EQ(1, route->use_auth, "Auth is enabled"); + TEST_ASSERT_STR_EQ("admin", route->username, "Username is admin"); + TEST_ASSERT(strlen(route->password_hash) > 0, "Password hash is set"); + } + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_with_patches(void) { + TEST_SUITE_BEGIN("Config With Patch Rules"); + + const char *config_content = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"patched.example.com\",\n" + " \"upstream_host\": \"backend.local\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false,\n" + " \"patch\": {\n" + " \"old-text\": \"new-text\",\n" + " \"remove-this\": null\n" + " }\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_content); + int result = config_load(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(1, result, "Config with patches loaded"); + + route_config_t *route = config_find_route("patched.example.com"); + TEST_ASSERT(route != NULL, "Patched route found"); + if (route) { + TEST_ASSERT_EQ(2, route->patches.rule_count, "Two patch rules loaded"); + } + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_create_default(void) { + TEST_SUITE_BEGIN("Config Create Default"); + + const char *default_file = "/tmp/test_default_config.json"; + unlink(default_file); + + config_create_default(default_file); + + FILE *f = fopen(default_file, "r"); + TEST_ASSERT(f != NULL, "Default config file created"); + if (f) { + fclose(f); + } + + config_create_default(default_file); + + int result = config_load(default_file); + TEST_ASSERT_EQ(1, result, "Default config is valid"); + + config_free(); + unlink(default_file); + + TEST_SUITE_END(); +} + +void test_config_check_file_changed(void) { + TEST_SUITE_BEGIN("Config Check File Changed"); + + const char *config_content = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"test.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_content); + + int first_check = config_check_file_changed(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(0, first_check, "First check returns 0 (initializes mtime)"); + + int second_check = config_check_file_changed(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(0, second_check, "Second check returns 0 (unchanged)"); + + int missing_check = config_check_file_changed("/tmp/nonexistent_file.json"); + TEST_ASSERT_EQ(0, missing_check, "Missing file returns 0"); + + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_hot_reload(void) { + TEST_SUITE_BEGIN("Config Hot Reload"); + + const char *initial_config = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"initial.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(initial_config); + config_load(TEST_CONFIG_FILE); + + route_config_t *route1 = config_find_route("initial.com"); + TEST_ASSERT(route1 != NULL, "Initial route found"); + + const char *updated_config = + "{\n" + " \"port\": 9090,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"updated.com\",\n" + " \"upstream_host\": \"newhost\",\n" + " \"upstream_port\": 4000,\n" + " \"use_ssl\": true,\n" + " \"rewrite_host\": true\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(updated_config); + int result = config_hot_reload(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(1, result, "Hot reload succeeded"); + + route_config_t *route2 = config_find_route("updated.com"); + TEST_ASSERT(route2 != NULL, "Updated route found after hot reload"); + if (route2) { + TEST_ASSERT_STR_EQ("newhost", route2->upstream_host, "New upstream host"); + TEST_ASSERT_EQ(4000, route2->upstream_port, "New upstream port"); + TEST_ASSERT_EQ(1, route2->use_ssl, "SSL enabled after reload"); + } + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_hot_reload_invalid(void) { + TEST_SUITE_BEGIN("Config Hot Reload Invalid"); + + const char *valid_config = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"test.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(valid_config); + config_load(TEST_CONFIG_FILE); + + create_test_config("{ invalid json }"); + int result = config_hot_reload(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(0, result, "Hot reload with invalid JSON fails"); + + route_config_t *route = config_find_route("test.com"); + TEST_ASSERT(route != NULL, "Original config preserved after failed reload"); + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_hot_reload_invalid_port(void) { + TEST_SUITE_BEGIN("Config Hot Reload Invalid Port"); + + const char *valid_config = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"test.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(valid_config); + config_load(TEST_CONFIG_FILE); + + const char *invalid_port_config = + "{\n" + " \"port\": 99999,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"new.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(invalid_port_config); + int result = config_hot_reload(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(0, result, "Hot reload with invalid port fails"); + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_ref_counting(void) { + TEST_SUITE_BEGIN("Config Reference Counting"); + + const char *config_content = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"test.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_content); + config_load(TEST_CONFIG_FILE); + + TEST_ASSERT(config != NULL, "Config loaded"); + TEST_ASSERT_EQ(1, config->ref_count, "Initial ref count is 1"); + + config_ref_inc(config); + TEST_ASSERT_EQ(2, config->ref_count, "Ref count incremented to 2"); + + config_ref_inc(config); + TEST_ASSERT_EQ(3, config->ref_count, "Ref count incremented to 3"); + + config_ref_dec(config); + TEST_ASSERT_EQ(2, config->ref_count, "Ref count decremented to 2"); + + config_ref_dec(config); + TEST_ASSERT_EQ(1, config->ref_count, "Ref count decremented to 1"); + + config_ref_inc(NULL); + config_ref_dec(NULL); + TEST_ASSERT(1, "NULL ref operations don't crash"); + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_hot_reload_with_auth_and_patches(void) { + TEST_SUITE_BEGIN("Config Hot Reload With Auth And Patches"); + + const char *initial_config = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"test.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(initial_config); + config_load(TEST_CONFIG_FILE); + + const char *updated_config = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"secure.com\",\n" + " \"upstream_host\": \"backend\",\n" + " \"upstream_port\": 443,\n" + " \"use_ssl\": true,\n" + " \"rewrite_host\": true,\n" + " \"use_auth\": true,\n" + " \"username\": \"user\",\n" + " \"password\": \"pass\",\n" + " \"patch\": {\n" + " \"find\": \"replace\"\n" + " }\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(updated_config); + int result = config_hot_reload(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(1, result, "Hot reload with auth and patches succeeded"); + + route_config_t *route = config_find_route("secure.com"); + TEST_ASSERT(route != NULL, "Route with auth found"); + if (route) { + TEST_ASSERT_EQ(1, route->use_auth, "Auth enabled after reload"); + TEST_ASSERT_EQ(1, route->patches.rule_count, "Patch rule loaded after reload"); + } + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + void run_config_tests(void) { test_config_load_valid(); test_config_find_route_case_insensitive(); @@ -247,4 +678,15 @@ void run_config_tests(void) { test_config_invalid_json(); test_config_missing_file(); test_config_ssl_rewrite_host_options(); + test_config_invalid_port(); + test_config_empty_routes(); + test_config_with_auth(); + test_config_with_patches(); + test_config_create_default(); + test_config_check_file_changed(); + test_config_hot_reload(); + test_config_hot_reload_invalid(); + test_config_hot_reload_invalid_port(); + test_config_ref_counting(); + test_config_hot_reload_with_auth_and_patches(); } diff --git a/tests/test_host_rewrite.c b/tests/test_host_rewrite.c new file mode 100644 index 0000000..e235950 --- /dev/null +++ b/tests/test_host_rewrite.c @@ -0,0 +1,307 @@ +#include "test_framework.h" +#include "../src/types.h" +#include "../src/http.h" +#include "../src/config.h" +#include "../src/buffer.h" +#include +#include +#include +#include + +static const char *TEST_CONFIG_FILE = "/tmp/test_host_rewrite_config.json"; + +static void create_ssl_rewrite_config(void) { + FILE *f = fopen(TEST_CONFIG_FILE, "w"); + if (f) { + fprintf(f, + "{\n" + " \"port\": 9999,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"mysite.example.com\",\n" + " \"upstream_host\": \"upstream.internal\",\n" + " \"upstream_port\": 443,\n" + " \"use_ssl\": true,\n" + " \"rewrite_host\": true\n" + " }\n" + " ]\n" + "}\n"); + fclose(f); + } +} + +static void cleanup_config(void) { + unlink(TEST_CONFIG_FILE); +} + +void test_host_header_line_bounds_first_request(void) { + TEST_SUITE_BEGIN("Host Header Line Bounds - First Request"); + + const char *request = + "GET /page1 HTTP/1.1\r\n" + "Host: mysite.example.com\r\n" + "Connection: keep-alive\r\n" + "\r\n"; + size_t len = strlen(request); + + const char *line_start = NULL; + const char *line_end = NULL; + int found = http_find_header_line_bounds(request, len, "Host", &line_start, &line_end); + + TEST_ASSERT_EQ(1, found, "Host header found in first request"); + TEST_ASSERT(line_start != NULL, "Line start is not NULL"); + TEST_ASSERT(line_end != NULL, "Line end is not NULL"); + + if (line_start && line_end) { + size_t host_line_len = line_end - line_start; + char host_line[256]; + if (host_line_len < sizeof(host_line)) { + memcpy(host_line, line_start, host_line_len); + host_line[host_line_len] = '\0'; + TEST_ASSERT_STR_EQ("Host: mysite.example.com\r\n", host_line, "Host header line extracted correctly"); + } + } + + TEST_SUITE_END(); +} + +void test_host_header_line_bounds_second_request(void) { + TEST_SUITE_BEGIN("Host Header Line Bounds - Second Request"); + + const char *request = + "GET /page2 HTTP/1.1\r\n" + "Host: mysite.example.com\r\n" + "Connection: keep-alive\r\n" + "\r\n"; + size_t len = strlen(request); + + const char *line_start = NULL; + const char *line_end = NULL; + int found = http_find_header_line_bounds(request, len, "Host", &line_start, &line_end); + + TEST_ASSERT_EQ(1, found, "Host header found in second request"); + TEST_ASSERT(line_start != NULL, "Line start is not NULL"); + TEST_ASSERT(line_end != NULL, "Line end is not NULL"); + + if (line_start && line_end) { + size_t host_line_len = line_end - line_start; + char host_line[256]; + if (host_line_len < sizeof(host_line)) { + memcpy(host_line, line_start, host_line_len); + host_line[host_line_len] = '\0'; + TEST_ASSERT_STR_EQ("Host: mysite.example.com\r\n", host_line, "Host header line extracted correctly"); + } + } + + TEST_SUITE_END(); +} + +void test_host_rewrite_simulation(void) { + TEST_SUITE_BEGIN("Host Rewrite Simulation - Multiple Requests"); + + create_ssl_rewrite_config(); + config_load(TEST_CONFIG_FILE); + + route_config_t *route = config_find_route("mysite.example.com"); + TEST_ASSERT(route != NULL, "Route found for mysite.example.com"); + if (!route) { + config_free(); + cleanup_config(); + TEST_SUITE_END(); + return; + } + + TEST_ASSERT_EQ(1, route->rewrite_host, "Rewrite host is enabled"); + TEST_ASSERT_EQ(1, route->use_ssl, "SSL is enabled"); + TEST_ASSERT_STR_EQ("upstream.internal", route->upstream_host, "Upstream host is upstream.internal"); + + const char *request1 = + "GET /page1 HTTP/1.1\r\n" + "Host: mysite.example.com\r\n" + "Connection: keep-alive\r\n" + "\r\n"; + + const char *request2 = + "GET /page2 HTTP/1.1\r\n" + "Host: mysite.example.com\r\n" + "Connection: keep-alive\r\n" + "\r\n"; + + for (int req_num = 1; req_num <= 2; req_num++) { + const char *request = (req_num == 1) ? request1 : request2; + size_t request_len = strlen(request); + + const char *old_host_start = NULL; + const char *old_host_end = NULL; + int found = http_find_header_line_bounds(request, request_len, "Host", &old_host_start, &old_host_end); + + char msg[128]; + snprintf(msg, sizeof(msg), "Request %d: Host header found", req_num); + TEST_ASSERT_EQ(1, found, msg); + + if (found && old_host_start && old_host_end) { + char new_host_header[512]; + int is_default_port = (route->use_ssl && route->upstream_port == 443) || + (!route->use_ssl && route->upstream_port == 80); + if (is_default_port) { + snprintf(new_host_header, sizeof(new_host_header), "Host: %s\r\n", route->upstream_host); + } else { + snprintf(new_host_header, sizeof(new_host_header), "Host: %s:%d\r\n", + route->upstream_host, route->upstream_port); + } + + size_t new_host_len = strlen(new_host_header); + size_t old_host_len = old_host_end - old_host_start; + size_t new_request_len = request_len - old_host_len + new_host_len; + + char *modified_request = malloc(new_request_len + 1); + TEST_ASSERT(modified_request != NULL, "Memory allocated for modified request"); + + if (modified_request) { + char *p = modified_request; + size_t prefix_len = old_host_start - request; + memcpy(p, request, prefix_len); + p += prefix_len; + memcpy(p, new_host_header, new_host_len); + p += new_host_len; + size_t suffix_len = request_len - (old_host_end - request); + memcpy(p, old_host_end, suffix_len); + modified_request[new_request_len] = '\0'; + + int has_new_host = (strstr(modified_request, "Host: upstream.internal\r\n") != NULL); + snprintf(msg, sizeof(msg), "Request %d: Host rewritten to upstream.internal", req_num); + TEST_ASSERT_EQ(1, has_new_host, msg); + + int no_old_host = (strstr(modified_request, "Host: mysite.example.com") == NULL); + snprintf(msg, sizeof(msg), "Request %d: Original host removed", req_num); + TEST_ASSERT_EQ(1, no_old_host, msg); + + free(modified_request); + } + } + } + + config_free(); + cleanup_config(); + + TEST_SUITE_END(); +} + +void test_pipelined_request_detection(void) { + TEST_SUITE_BEGIN("Pipelined Request Detection for Host Rewrite"); + + const char *first_request = + "GET /first HTTP/1.1\r\n" + "Host: mysite.example.com\r\n" + "\r\n"; + + const char *second_request = + "GET /second HTTP/1.1\r\n" + "Host: mysite.example.com\r\n" + "\r\n"; + + int is_first_req_start = http_is_request_start(first_request, strlen(first_request)); + TEST_ASSERT_EQ(1, is_first_req_start, "First request detected as HTTP request start"); + + int is_second_req_start = http_is_request_start(second_request, strlen(second_request)); + TEST_ASSERT_EQ(1, is_second_req_start, "Second request detected as HTTP request start"); + + const char *response_data = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"; + int is_response_req = http_is_request_start(response_data, strlen(response_data)); + TEST_ASSERT_EQ(0, is_response_req, "Response data NOT detected as request start"); + + const char *binary_data = "\x00\x01\x02\x03"; + int is_binary_req = http_is_request_start(binary_data, 4); + TEST_ASSERT_EQ(0, is_binary_req, "Binary data NOT detected as request start"); + + TEST_SUITE_END(); +} + +void test_ssl_sni_with_host_rewrite(void) { + TEST_SUITE_BEGIN("SSL SNI Hostname with Host Rewrite"); + + create_ssl_rewrite_config(); + config_load(TEST_CONFIG_FILE); + + route_config_t *route = config_find_route("mysite.example.com"); + TEST_ASSERT(route != NULL, "Route found"); + + if (route) { + const char *sni_hostname = route->rewrite_host ? route->upstream_host : "mysite.example.com"; + TEST_ASSERT_STR_EQ("upstream.internal", sni_hostname, "SNI should be upstream.internal when rewrite_host is true"); + } + + config_free(); + cleanup_config(); + + TEST_SUITE_END(); +} + +void test_consecutive_requests_same_connection(void) { + TEST_SUITE_BEGIN("Consecutive Requests on Same Connection"); + + create_ssl_rewrite_config(); + config_load(TEST_CONFIG_FILE); + + route_config_t *route = config_find_route("mysite.example.com"); + TEST_ASSERT(route != NULL, "Route found"); + + if (!route) { + config_free(); + cleanup_config(); + TEST_SUITE_END(); + return; + } + + const char *requests[] = { + "GET /page1 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: keep-alive\r\n\r\n", + "GET /page2 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: keep-alive\r\n\r\n", + "GET /page3 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: keep-alive\r\n\r\n", + "POST /api HTTP/1.1\r\nHost: mysite.example.com\r\nContent-Length: 0\r\n\r\n", + "GET /page4 HTTP/1.1\r\nHost: mysite.example.com\r\nConnection: close\r\n\r\n" + }; + + int num_requests = sizeof(requests) / sizeof(requests[0]); + + for (int i = 0; i < num_requests; i++) { + const char *request = requests[i]; + size_t request_len = strlen(request); + + http_request_t parsed_req; + int parse_result = http_parse_request(request, request_len, &parsed_req); + + char msg[128]; + snprintf(msg, sizeof(msg), "Request %d: Parsed successfully", i + 1); + TEST_ASSERT_EQ(1, parse_result, msg); + + route_config_t *found_route = config_find_route(parsed_req.host); + snprintf(msg, sizeof(msg), "Request %d: Route found", i + 1); + TEST_ASSERT(found_route != NULL, msg); + + if (found_route) { + snprintf(msg, sizeof(msg), "Request %d: Rewrite host enabled", i + 1); + TEST_ASSERT_EQ(1, found_route->rewrite_host, msg); + + const char *old_host_start = NULL; + const char *old_host_end = NULL; + int found = http_find_header_line_bounds(request, request_len, "Host", + &old_host_start, &old_host_end); + snprintf(msg, sizeof(msg), "Request %d: Host header found for rewrite", i + 1); + TEST_ASSERT_EQ(1, found, msg); + } + } + + config_free(); + cleanup_config(); + + TEST_SUITE_END(); +} + +void run_host_rewrite_tests(void) { + test_host_header_line_bounds_first_request(); + test_host_header_line_bounds_second_request(); + test_host_rewrite_simulation(); + test_pipelined_request_detection(); + test_ssl_sni_with_host_rewrite(); + test_consecutive_requests_same_connection(); +} diff --git a/tests/test_http_helpers.c b/tests/test_http_helpers.c new file mode 100644 index 0000000..a24c00f --- /dev/null +++ b/tests/test_http_helpers.c @@ -0,0 +1,200 @@ +#include "test_framework.h" +#include "../src/types.h" +#include "../src/http.h" +#include + +void test_http_find_header_value(void) { + TEST_SUITE_BEGIN("HTTP Find Header Value"); + + const char *headers = + "Host: example.com\r\n" + "Content-Type: application/json\r\n" + "Content-Length: 42\r\n" + "X-Custom-Header: custom-value\r\n" + "Connection: keep-alive\r\n" + "\r\n"; + size_t len = strlen(headers); + char value[256]; + + int found = http_find_header_value(headers, len, "Host", value, sizeof(value)); + TEST_ASSERT_EQ(1, found, "Host header found"); + TEST_ASSERT_STR_EQ("example.com", value, "Host value correct"); + + found = http_find_header_value(headers, len, "Content-Type", value, sizeof(value)); + TEST_ASSERT_EQ(1, found, "Content-Type header found"); + TEST_ASSERT_STR_EQ("application/json", value, "Content-Type value correct"); + + found = http_find_header_value(headers, len, "Content-Length", value, sizeof(value)); + TEST_ASSERT_EQ(1, found, "Content-Length header found"); + TEST_ASSERT_STR_EQ("42", value, "Content-Length value correct"); + + found = http_find_header_value(headers, len, "X-Custom-Header", value, sizeof(value)); + TEST_ASSERT_EQ(1, found, "Custom header found"); + TEST_ASSERT_STR_EQ("custom-value", value, "Custom header value correct"); + + found = http_find_header_value(headers, len, "Non-Existent", value, sizeof(value)); + TEST_ASSERT_EQ(0, found, "Non-existent header not found"); + + found = http_find_header_value(headers, len, "host", value, sizeof(value)); + TEST_ASSERT_EQ(1, found, "Case-insensitive header search works"); + + TEST_SUITE_END(); +} + +void test_http_is_textual_content_type(void) { + TEST_SUITE_BEGIN("HTTP Textual Content Type Detection"); + + TEST_ASSERT_EQ(1, http_is_textual_content_type("text/html"), "text/html is textual"); + TEST_ASSERT_EQ(1, http_is_textual_content_type("text/plain"), "text/plain is textual"); + TEST_ASSERT_EQ(1, http_is_textual_content_type("text/css"), "text/css is textual"); + TEST_ASSERT_EQ(1, http_is_textual_content_type("text/javascript"), "text/javascript is textual"); + TEST_ASSERT_EQ(1, http_is_textual_content_type("application/json"), "application/json is textual"); + TEST_ASSERT_EQ(1, http_is_textual_content_type("application/javascript"), "application/javascript is textual"); + TEST_ASSERT_EQ(1, http_is_textual_content_type("application/xml"), "application/xml is textual"); + TEST_ASSERT_EQ(1, http_is_textual_content_type("application/xhtml+xml"), "application/xhtml+xml is textual"); + + TEST_ASSERT_EQ(0, http_is_textual_content_type("image/png"), "image/png is not textual"); + TEST_ASSERT_EQ(0, http_is_textual_content_type("image/jpeg"), "image/jpeg is not textual"); + TEST_ASSERT_EQ(0, http_is_textual_content_type("application/octet-stream"), "octet-stream is not textual"); + TEST_ASSERT_EQ(0, http_is_textual_content_type("video/mp4"), "video/mp4 is not textual"); + TEST_ASSERT_EQ(0, http_is_textual_content_type("audio/mpeg"), "audio/mpeg is not textual"); + + TEST_ASSERT_EQ(1, http_is_textual_content_type("text/html; charset=utf-8"), "text/html with charset is textual"); + TEST_ASSERT_EQ(1, http_is_textual_content_type("application/json; charset=utf-8"), "json with charset is textual"); + + TEST_ASSERT_EQ(0, http_is_textual_content_type(NULL), "NULL is not textual"); + + TEST_SUITE_END(); +} + +void test_http_detect_binary_content(void) { + TEST_SUITE_BEGIN("HTTP Binary Content Detection"); + + const char *text_data = "Hello, World! This is plain text content."; + TEST_ASSERT_EQ(0, http_detect_binary_content(text_data, strlen(text_data)), "Plain text is not binary"); + + const char *html_data = "

Hello

"; + TEST_ASSERT_EQ(0, http_detect_binary_content(html_data, strlen(html_data)), "HTML is not binary"); + + const char *json_data = "{\"key\": \"value\", \"number\": 123}"; + TEST_ASSERT_EQ(0, http_detect_binary_content(json_data, strlen(json_data)), "JSON is not binary"); + + const unsigned char binary_data[] = {0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}; + TEST_ASSERT_EQ(1, http_detect_binary_content((const char*)binary_data, sizeof(binary_data)), "PNG header is binary"); + + const unsigned char null_bytes[] = {'H', 'e', 'l', 'l', 'o', 0x00, 'W', 'o', 'r', 'l', 'd'}; + TEST_ASSERT_EQ(1, http_detect_binary_content((const char*)null_bytes, sizeof(null_bytes)), "Data with null bytes is binary"); + + TEST_ASSERT_EQ(0, http_detect_binary_content(NULL, 0), "NULL data returns 0"); + TEST_ASSERT_EQ(0, http_detect_binary_content("", 0), "Empty data returns 0"); + + TEST_SUITE_END(); +} + +void test_http_get_content_length(void) { + TEST_SUITE_BEGIN("HTTP Get Content-Length"); + + const char *headers1 = "HTTP/1.1 200 OK\r\nContent-Length: 1234\r\nContent-Type: text/html\r\n\r\n"; + TEST_ASSERT_EQ(1234, http_get_content_length(headers1, strlen(headers1)), "Content-Length 1234 parsed"); + + const char *headers2 = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: 0\r\n\r\n"; + TEST_ASSERT_EQ(0, http_get_content_length(headers2, strlen(headers2)), "Content-Length 0 parsed"); + + const char *headers3 = "HTTP/1.1 200 OK\r\nContent-Length: 9999999\r\n\r\n"; + TEST_ASSERT_EQ(9999999, http_get_content_length(headers3, strlen(headers3)), "Large Content-Length parsed"); + + const char *headers4 = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; + TEST_ASSERT_EQ(-1, http_get_content_length(headers4, strlen(headers4)), "Missing Content-Length returns -1"); + + const char *headers5 = "HTTP/1.1 200 OK\r\ncontent-length: 500\r\n\r\n"; + long len5 = http_get_content_length(headers5, strlen(headers5)); + TEST_ASSERT(len5 == 500 || len5 == -1, "Lowercase content-length handled"); + + TEST_SUITE_END(); +} + +void test_http_find_headers_end(void) { + TEST_SUITE_BEGIN("HTTP Find Headers End"); + + const char *request1 = "GET / HTTP/1.1\r\nHost: example.com\r\n\r\nBody content here"; + size_t end1; + int found1 = http_find_headers_end(request1, strlen(request1), &end1); + TEST_ASSERT_EQ(1, found1, "Headers end found in complete request"); + TEST_ASSERT(end1 > 0 && end1 < strlen(request1), "Headers end position is valid"); + + const char *request2 = "GET / HTTP/1.1\r\nHost: example.com\r\n"; + size_t end2; + int found2 = http_find_headers_end(request2, strlen(request2), &end2); + TEST_ASSERT_EQ(0, found2, "Incomplete headers not found"); + + const char *request3 = "GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"; + size_t end3; + int found3 = http_find_headers_end(request3, strlen(request3), &end3); + TEST_ASSERT_EQ(1, found3, "Headers end found with no body"); + TEST_ASSERT_EQ(strlen(request3), end3, "Headers end at string end for no body"); + + TEST_SUITE_END(); +} + +void test_http_rewrite_content_length(void) { + TEST_SUITE_BEGIN("HTTP Rewrite Content-Length"); + + char headers1[512] = "HTTP/1.1 200 OK\r\nContent-Length: 100\r\nContent-Type: text/html\r\n\r\n"; + size_t len1 = strlen(headers1); + int result1 = http_rewrite_content_length(headers1, &len1, sizeof(headers1), 200); + TEST_ASSERT_EQ(1, result1, "Content-Length rewrite succeeded"); + TEST_ASSERT(strstr(headers1, "200") != NULL, "New length present in headers"); + + char headers2[512] = "HTTP/1.1 200 OK\r\nContent-Length: 99999\r\n\r\n"; + size_t len2 = strlen(headers2); + int result2 = http_rewrite_content_length(headers2, &len2, sizeof(headers2), 50); + TEST_ASSERT_EQ(1, result2, "Content-Length rewrite to smaller succeeded"); + + char headers3[512] = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n"; + size_t len3 = strlen(headers3); + int result3 = http_rewrite_content_length(headers3, &len3, sizeof(headers3), 100); + TEST_ASSERT_EQ(0, result3, "Rewrite fails when no Content-Length present"); + + TEST_SUITE_END(); +} + +void test_http_find_header_line_bounds_edge_cases(void) { + TEST_SUITE_BEGIN("HTTP Header Line Bounds Edge Cases"); + + const char *request1 = "GET / HTTP/1.1\r\nHost: example.com\r\nX-Test:no-space\r\n\r\n"; + const char *start1, *end1; + int found1 = http_find_header_line_bounds(request1, strlen(request1), "X-Test", &start1, &end1); + TEST_ASSERT_EQ(1, found1, "Header without space after colon found"); + + const char *request2 = "GET / HTTP/1.1\r\nHost: lots-of-spaces \r\n\r\n"; + const char *start2, *end2; + int found2 = http_find_header_line_bounds(request2, strlen(request2), "Host", &start2, &end2); + TEST_ASSERT_EQ(1, found2, "Header with extra spaces found"); + + const char *request3 = "GET / HTTP/1.1\r\nhost: lowercase.com\r\n\r\n"; + const char *start3, *end3; + int found3 = http_find_header_line_bounds(request3, strlen(request3), "Host", &start3, &end3); + TEST_ASSERT_EQ(1, found3, "Lowercase header name found"); + + const char *request4 = "GET / HTTP/1.1\r\nHOST: UPPERCASE.COM\r\n\r\n"; + const char *start4, *end4; + int found4 = http_find_header_line_bounds(request4, strlen(request4), "host", &start4, &end4); + TEST_ASSERT_EQ(1, found4, "Uppercase header with lowercase search found"); + + const char *request5 = "GET / HTTP/1.1\r\n\r\n"; + const char *start5, *end5; + int found5 = http_find_header_line_bounds(request5, strlen(request5), "Host", &start5, &end5); + TEST_ASSERT_EQ(0, found5, "Header not found in request with no headers"); + + TEST_SUITE_END(); +} + +void run_http_helper_tests(void) { + test_http_find_header_value(); + test_http_is_textual_content_type(); + test_http_detect_binary_content(); + test_http_get_content_length(); + test_http_find_headers_end(); + test_http_rewrite_content_length(); + test_http_find_header_line_bounds_edge_cases(); +} diff --git a/tests/test_main.c b/tests/test_main.c index 6832b5b..e2b98a0 100644 --- a/tests/test_main.c +++ b/tests/test_main.c @@ -20,6 +20,11 @@ extern void run_http_tests(void); extern void run_buffer_tests(void); extern void run_config_tests(void); extern void run_routing_tests(void); +extern void run_host_rewrite_tests(void); +extern void run_http_helper_tests(void); +extern void run_patch_tests(void); +extern void run_auth_tests(void); +extern void run_rate_limit_tests(void); int main(int argc, char *argv[]) { (void)argc; @@ -32,8 +37,13 @@ int main(int argc, char *argv[]) { run_buffer_tests(); run_http_tests(); + run_http_helper_tests(); run_config_tests(); run_routing_tests(); + run_host_rewrite_tests(); + run_patch_tests(); + run_auth_tests(); + run_rate_limit_tests(); test_summary(); diff --git a/tests/test_patch.c b/tests/test_patch.c new file mode 100644 index 0000000..3b8b908 --- /dev/null +++ b/tests/test_patch.c @@ -0,0 +1,237 @@ +#include "test_framework.h" +#include "../src/types.h" +#include "../src/patch.h" +#include + +void test_patch_has_rules(void) { + TEST_SUITE_BEGIN("Patch Has Rules"); + + patch_config_t empty_config; + memset(&empty_config, 0, sizeof(empty_config)); + empty_config.rule_count = 0; + TEST_ASSERT_EQ(0, patch_has_rules(&empty_config), "Empty config has no rules"); + + patch_config_t config_with_rules; + memset(&config_with_rules, 0, sizeof(config_with_rules)); + config_with_rules.rule_count = 1; + strcpy(config_with_rules.rules[0].key, "test"); + config_with_rules.rules[0].key_len = 4; + TEST_ASSERT_EQ(1, patch_has_rules(&config_with_rules), "Config with rules returns true"); + + TEST_ASSERT_EQ(0, patch_has_rules(NULL), "NULL config has no rules"); + + TEST_SUITE_END(); +} + +void test_patch_check_for_block(void) { + TEST_SUITE_BEGIN("Patch Check For Block"); + + patch_config_t config; + memset(&config, 0, sizeof(config)); + config.rule_count = 1; + strcpy(config.rules[0].key, "blocked-word"); + config.rules[0].key_len = strlen("blocked-word"); + config.rules[0].is_null = 1; + + const char *data_with_block = "This content contains blocked-word in it"; + TEST_ASSERT_EQ(1, patch_check_for_block(&config, data_with_block, strlen(data_with_block)), + "Content with blocked word is blocked"); + + const char *data_without_block = "This content is clean and allowed"; + TEST_ASSERT_EQ(0, patch_check_for_block(&config, data_without_block, strlen(data_without_block)), + "Clean content is not blocked"); + + patch_config_t empty_config; + memset(&empty_config, 0, sizeof(empty_config)); + TEST_ASSERT_EQ(0, patch_check_for_block(&empty_config, data_with_block, strlen(data_with_block)), + "Empty config blocks nothing"); + + TEST_ASSERT_EQ(0, patch_check_for_block(NULL, data_with_block, strlen(data_with_block)), + "NULL config blocks nothing"); + TEST_ASSERT_EQ(0, patch_check_for_block(&config, NULL, 0), + "NULL data is not blocked"); + + TEST_SUITE_END(); +} + +void test_patch_apply_simple_replace(void) { + TEST_SUITE_BEGIN("Patch Apply Simple Replacement"); + + patch_config_t config; + memset(&config, 0, sizeof(config)); + config.rule_count = 1; + strcpy(config.rules[0].key, "old"); + config.rules[0].key_len = 3; + strcpy(config.rules[0].value, "new"); + config.rules[0].value_len = 3; + config.rules[0].is_null = 0; + + const char *input = "This is old text with old words"; + char output[256]; + patch_result_t result = patch_apply(&config, input, strlen(input), output, sizeof(output)); + + TEST_ASSERT_EQ(0, result.should_block, "Simple replace does not block"); + TEST_ASSERT(result.output_len > 0, "Output has content"); + TEST_ASSERT(strstr(output, "new") != NULL, "Replacement word present"); + + TEST_SUITE_END(); +} + +void test_patch_apply_size_change(void) { + TEST_SUITE_BEGIN("Patch Apply Size Change"); + + patch_config_t config; + memset(&config, 0, sizeof(config)); + config.rule_count = 1; + strcpy(config.rules[0].key, "short"); + config.rules[0].key_len = 5; + strcpy(config.rules[0].value, "much-longer-replacement"); + config.rules[0].value_len = strlen("much-longer-replacement"); + config.rules[0].is_null = 0; + + const char *input = "This is short text"; + char output[256]; + patch_result_t result = patch_apply(&config, input, strlen(input), output, sizeof(output)); + + TEST_ASSERT_EQ(0, result.should_block, "Size change replace does not block"); + TEST_ASSERT(result.output_len > strlen(input), "Output is longer than input"); + TEST_ASSERT(result.size_delta > 0, "Size delta is positive"); + + TEST_SUITE_END(); +} + +void test_patch_apply_shrink(void) { + TEST_SUITE_BEGIN("Patch Apply Shrink"); + + patch_config_t config; + memset(&config, 0, sizeof(config)); + config.rule_count = 1; + strcpy(config.rules[0].key, "very-long-word"); + config.rules[0].key_len = strlen("very-long-word"); + strcpy(config.rules[0].value, "tiny"); + config.rules[0].value_len = 4; + config.rules[0].is_null = 0; + + const char *input = "This has very-long-word in it"; + char output[256]; + patch_result_t result = patch_apply(&config, input, strlen(input), output, sizeof(output)); + + TEST_ASSERT_EQ(0, result.should_block, "Shrink replace does not block"); + TEST_ASSERT(result.output_len < strlen(input), "Output is shorter than input"); + TEST_ASSERT(result.size_delta < 0, "Size delta is negative"); + + TEST_SUITE_END(); +} + +void test_patch_apply_multiple_rules(void) { + TEST_SUITE_BEGIN("Patch Apply Multiple Rules"); + + patch_config_t config; + memset(&config, 0, sizeof(config)); + config.rule_count = 3; + + strcpy(config.rules[0].key, "foo"); + config.rules[0].key_len = 3; + strcpy(config.rules[0].value, "bar"); + config.rules[0].value_len = 3; + config.rules[0].is_null = 0; + + strcpy(config.rules[1].key, "hello"); + config.rules[1].key_len = 5; + strcpy(config.rules[1].value, "world"); + config.rules[1].value_len = 5; + config.rules[1].is_null = 0; + + strcpy(config.rules[2].key, "test"); + config.rules[2].key_len = 4; + strcpy(config.rules[2].value, "demo"); + config.rules[2].value_len = 4; + config.rules[2].is_null = 0; + + const char *input = "foo says hello during test"; + char output[256]; + patch_result_t result = patch_apply(&config, input, strlen(input), output, sizeof(output)); + + TEST_ASSERT_EQ(0, result.should_block, "Multiple rule replace does not block"); + TEST_ASSERT(result.output_len > 0, "Output has content"); + output[result.output_len] = '\0'; + TEST_ASSERT(strstr(output, "bar") != NULL, "First replacement applied"); + TEST_ASSERT(strstr(output, "world") != NULL, "Second replacement applied"); + TEST_ASSERT(strstr(output, "demo") != NULL, "Third replacement applied"); + + TEST_SUITE_END(); +} + +void test_patch_apply_no_match(void) { + TEST_SUITE_BEGIN("Patch Apply No Match"); + + patch_config_t config; + memset(&config, 0, sizeof(config)); + config.rule_count = 1; + strcpy(config.rules[0].key, "nonexistent"); + config.rules[0].key_len = strlen("nonexistent"); + strcpy(config.rules[0].value, "replacement"); + config.rules[0].value_len = strlen("replacement"); + config.rules[0].is_null = 0; + + const char *input = "This text has no matching patterns"; + char output[256]; + patch_result_t result = patch_apply(&config, input, strlen(input), output, sizeof(output)); + + TEST_ASSERT_EQ(0, result.should_block, "No match does not block"); + TEST_ASSERT_EQ(strlen(input), result.output_len, "Output length equals input length"); + TEST_ASSERT_EQ(0, result.size_delta, "Size delta is zero for no changes"); + + TEST_SUITE_END(); +} + +void test_patch_apply_empty_config(void) { + TEST_SUITE_BEGIN("Patch Apply Empty Config"); + + patch_config_t config; + memset(&config, 0, sizeof(config)); + + const char *input = "This is some input data"; + char output[256]; + patch_result_t result = patch_apply(&config, input, strlen(input), output, sizeof(output)); + + TEST_ASSERT_EQ(0, result.should_block, "Empty config does not block"); + TEST_ASSERT_EQ(strlen(input), result.output_len, "Output equals input for empty config"); + + TEST_SUITE_END(); +} + +void test_patch_apply_block_rule(void) { + TEST_SUITE_BEGIN("Patch Apply Block Rule"); + + patch_config_t config; + memset(&config, 0, sizeof(config)); + config.rule_count = 1; + strcpy(config.rules[0].key, "malware"); + config.rules[0].key_len = strlen("malware"); + config.rules[0].is_null = 1; + + const char *malicious = "This contains malware content"; + char output[256]; + patch_result_t result = patch_apply(&config, malicious, strlen(malicious), output, sizeof(output)); + + TEST_ASSERT_EQ(1, result.should_block, "Block rule triggers block"); + + const char *clean = "This is clean content"; + result = patch_apply(&config, clean, strlen(clean), output, sizeof(output)); + TEST_ASSERT_EQ(0, result.should_block, "Clean content not blocked"); + + TEST_SUITE_END(); +} + +void run_patch_tests(void) { + test_patch_has_rules(); + test_patch_check_for_block(); + test_patch_apply_simple_replace(); + test_patch_apply_size_change(); + test_patch_apply_shrink(); + test_patch_apply_multiple_rules(); + test_patch_apply_no_match(); + test_patch_apply_empty_config(); + test_patch_apply_block_rule(); +} diff --git a/tests/test_rate_limit.c b/tests/test_rate_limit.c new file mode 100644 index 0000000..c26dd57 --- /dev/null +++ b/tests/test_rate_limit.c @@ -0,0 +1,152 @@ +#include "test_framework.h" +#include "../src/types.h" +#include "../src/rate_limit.h" +#include + +void test_rate_limit_disabled(void) { + TEST_SUITE_BEGIN("Rate Limit Disabled"); + + rate_limit_cleanup(); + + TEST_ASSERT_EQ(1, rate_limit_check("192.168.1.1"), "Rate limit check passes when disabled"); + TEST_ASSERT_EQ(1, rate_limit_check("10.0.0.1"), "Any IP passes when disabled"); + TEST_ASSERT_EQ(1, rate_limit_check(NULL), "NULL IP passes when disabled"); + + TEST_SUITE_END(); +} + +void test_rate_limit_init(void) { + TEST_SUITE_BEGIN("Rate Limit Init"); + + rate_limit_cleanup(); + rate_limit_init(100, 60); + + TEST_ASSERT_EQ(1, rate_limit_check("192.168.1.100"), "First request passes"); + + rate_limit_cleanup(); + + TEST_SUITE_END(); +} + +void test_rate_limit_within_limit(void) { + TEST_SUITE_BEGIN("Rate Limit Within Limit"); + + rate_limit_cleanup(); + rate_limit_init(10, 60); + + const char *ip = "192.168.1.50"; + int all_passed = 1; + + for (int i = 0; i < 10; i++) { + if (rate_limit_check(ip) != 1) { + all_passed = 0; + break; + } + } + + TEST_ASSERT_EQ(1, all_passed, "All requests within limit pass"); + + rate_limit_cleanup(); + + TEST_SUITE_END(); +} + +void test_rate_limit_exceeded(void) { + TEST_SUITE_BEGIN("Rate Limit Exceeded"); + + rate_limit_cleanup(); + rate_limit_init(5, 60); + + const char *ip = "192.168.1.200"; + + for (int i = 0; i < 5; i++) { + rate_limit_check(ip); + } + + int exceeded = rate_limit_check(ip); + TEST_ASSERT_EQ(0, exceeded, "Request exceeding limit is blocked"); + + rate_limit_cleanup(); + + TEST_SUITE_END(); +} + +void test_rate_limit_different_ips(void) { + TEST_SUITE_BEGIN("Rate Limit Different IPs"); + + rate_limit_cleanup(); + rate_limit_init(3, 60); + + for (int i = 0; i < 3; i++) { + rate_limit_check("10.0.0.1"); + } + TEST_ASSERT_EQ(0, rate_limit_check("10.0.0.1"), "First IP blocked after limit"); + + TEST_ASSERT_EQ(1, rate_limit_check("10.0.0.2"), "Second IP still allowed"); + TEST_ASSERT_EQ(1, rate_limit_check("10.0.0.3"), "Third IP still allowed"); + + rate_limit_cleanup(); + + TEST_SUITE_END(); +} + +void test_rate_limit_null_ip(void) { + TEST_SUITE_BEGIN("Rate Limit NULL IP"); + + rate_limit_cleanup(); + rate_limit_init(5, 60); + + TEST_ASSERT_EQ(1, rate_limit_check(NULL), "NULL IP passes (fail-open)"); + + rate_limit_cleanup(); + + TEST_SUITE_END(); +} + +void test_rate_limit_purge_expired(void) { + TEST_SUITE_BEGIN("Rate Limit Purge Expired"); + + rate_limit_cleanup(); + rate_limit_init(100, 60); + + rate_limit_check("192.168.1.1"); + rate_limit_check("192.168.1.2"); + rate_limit_check("192.168.1.3"); + + rate_limit_purge_expired(); + + TEST_ASSERT_EQ(1, rate_limit_check("192.168.1.1"), "Entry still valid after purge"); + + rate_limit_cleanup(); + + TEST_SUITE_END(); +} + +void test_rate_limit_cleanup(void) { + TEST_SUITE_BEGIN("Rate Limit Cleanup"); + + rate_limit_init(10, 60); + + for (int i = 0; i < 100; i++) { + char ip[32]; + snprintf(ip, sizeof(ip), "192.168.%d.%d", i / 256, i % 256); + rate_limit_check(ip); + } + + rate_limit_cleanup(); + + TEST_ASSERT_EQ(1, rate_limit_check("192.168.1.1"), "After cleanup, rate limit disabled"); + + TEST_SUITE_END(); +} + +void run_rate_limit_tests(void) { + test_rate_limit_disabled(); + test_rate_limit_init(); + test_rate_limit_within_limit(); + test_rate_limit_exceeded(); + test_rate_limit_different_ips(); + test_rate_limit_null_ip(); + test_rate_limit_purge_expired(); + test_rate_limit_cleanup(); +}