From b1409609d0327166a424380c3fcd9e5c7c34e227 Mon Sep 17 00:00:00 2001 From: retoor Date: Sat, 29 Nov 2025 01:49:14 +0100 Subject: [PATCH] Refactored using AI --- .gitignore | 6 + Makefile | 115 +++++- src/buffer.c | 75 ++++ src/buffer.h | 14 + src/config.c | 159 +++++++++ src/config.h | 13 + src/connection.c | 769 +++++++++++++++++++++++++++++++++++++++++ src/connection.h | 28 ++ src/dashboard.c | 505 +++++++++++++++++++++++++++ src/dashboard.h | 9 + src/http.c | 152 ++++++++ src/http.h | 10 + src/logging.c | 44 +++ src/logging.h | 10 + src/main.c | 115 ++++++ src/monitor.c | 441 +++++++++++++++++++++++ src/monitor.h | 25 ++ src/ssl_handler.c | 76 ++++ src/ssl_handler.h | 14 + src/types.h | 172 +++++++++ tests/test_buffer.c | 159 +++++++++ tests/test_config.c | 250 ++++++++++++++ tests/test_framework.h | 55 +++ tests/test_http.c | 196 +++++++++++ tests/test_main.c | 41 +++ tests/test_routing.c | 258 ++++++++++++++ 26 files changed, 3705 insertions(+), 6 deletions(-) create mode 100644 src/buffer.c create mode 100644 src/buffer.h create mode 100644 src/config.c create mode 100644 src/config.h create mode 100644 src/connection.c create mode 100644 src/connection.h create mode 100644 src/dashboard.c create mode 100644 src/dashboard.h create mode 100644 src/http.c create mode 100644 src/http.h create mode 100644 src/logging.c create mode 100644 src/logging.h create mode 100644 src/main.c create mode 100644 src/monitor.c create mode 100644 src/monitor.h create mode 100644 src/ssl_handler.c create mode 100644 src/ssl_handler.h create mode 100644 src/types.h create mode 100644 tests/test_buffer.c create mode 100644 tests/test_config.c create mode 100644 tests/test_framework.h create mode 100644 tests/test_http.c create mode 100644 tests/test_main.c create mode 100644 tests/test_routing.c diff --git a/.gitignore b/.gitignore index 853f8d9..7f5d3e0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,10 @@ *.bak2 *.db *.txt +*.log +*_test +*.py +*.sh +build +__pycache__/ rproxy diff --git a/Makefile b/Makefile index 4f0cd30..d32e503 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,114 @@ +CC = gcc +CFLAGS = -Wall -Wextra -O2 -g -D_GNU_SOURCE +LDFLAGS = -lssl -lcrypto -lsqlite3 -lm -all: build run +SRC_DIR = src +BUILD_DIR = build +TESTS_DIR = tests -build: - gcc -Wall -Wextra -O2 -g -pthread cJSON.c rproxy.c -o rproxy -lssl -lcrypto -lsqlite3 -lm +SOURCES = $(SRC_DIR)/main.c \ + $(SRC_DIR)/buffer.c \ + $(SRC_DIR)/logging.c \ + $(SRC_DIR)/config.c \ + $(SRC_DIR)/monitor.c \ + $(SRC_DIR)/http.c \ + $(SRC_DIR)/ssl_handler.c \ + $(SRC_DIR)/connection.c \ + $(SRC_DIR)/dashboard.c \ + cJSON.c -run: - ./rproxy +OBJECTS = $(patsubst %.c,$(BUILD_DIR)/%.o,$(notdir $(SOURCES))) + +TARGET = rproxy + +TEST_SOURCES = $(TESTS_DIR)/test_main.c \ + $(TESTS_DIR)/test_http.c \ + $(TESTS_DIR)/test_buffer.c \ + $(TESTS_DIR)/test_config.c \ + $(TESTS_DIR)/test_routing.c + +TEST_OBJECTS = $(patsubst %.c,$(BUILD_DIR)/%.o,$(notdir $(TEST_SOURCES))) + +TEST_LIB_SOURCES = $(SRC_DIR)/buffer.c \ + $(SRC_DIR)/logging.c \ + $(SRC_DIR)/config.c \ + $(SRC_DIR)/monitor.c \ + $(SRC_DIR)/http.c \ + $(SRC_DIR)/ssl_handler.c \ + $(SRC_DIR)/connection.c \ + $(SRC_DIR)/dashboard.c \ + cJSON.c + +TEST_LIB_OBJECTS = $(patsubst %.c,$(BUILD_DIR)/%.o,$(notdir $(TEST_LIB_SOURCES))) + +TEST_TARGET = rproxy_test + +.PHONY: all clean test legacy run + +all: $(BUILD_DIR) $(TARGET) + +$(BUILD_DIR): + mkdir -p $(BUILD_DIR) + +$(TARGET): $(OBJECTS) + $(CC) $(OBJECTS) -o $@ $(LDFLAGS) + +$(BUILD_DIR)/main.o: $(SRC_DIR)/main.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/buffer.o: $(SRC_DIR)/buffer.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/logging.o: $(SRC_DIR)/logging.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/config.o: $(SRC_DIR)/config.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/monitor.o: $(SRC_DIR)/monitor.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/http.o: $(SRC_DIR)/http.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/ssl_handler.o: $(SRC_DIR)/ssl_handler.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/connection.o: $(SRC_DIR)/connection.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/dashboard.o: $(SRC_DIR)/dashboard.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/cJSON.o: cJSON.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/test_main.o: $(TESTS_DIR)/test_main.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + +$(BUILD_DIR)/test_http.o: $(TESTS_DIR)/test_http.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + +$(BUILD_DIR)/test_buffer.o: $(TESTS_DIR)/test_buffer.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + +$(BUILD_DIR)/test_config.o: $(TESTS_DIR)/test_config.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + +$(BUILD_DIR)/test_routing.o: $(TESTS_DIR)/test_routing.c + $(CC) $(CFLAGS) -I$(SRC_DIR) -c $< -o $@ + +$(TEST_TARGET): $(BUILD_DIR) $(TEST_OBJECTS) $(TEST_LIB_OBJECTS) + $(CC) $(TEST_OBJECTS) $(TEST_LIB_OBJECTS) -o $@ $(LDFLAGS) + +test: $(TEST_TARGET) + ./$(TEST_TARGET) + +legacy: rproxy.c cJSON.c cJSON.h + $(CC) $(CFLAGS) rproxy.c cJSON.c -o rproxy_legacy $(LDFLAGS) + +run: $(TARGET) + ./$(TARGET) clean: - rm -f rproxy + rm -rf $(BUILD_DIR) $(TARGET) $(TEST_TARGET) rproxy_legacy diff --git a/src/buffer.c b/src/buffer.c new file mode 100644 index 0000000..d0fc88a --- /dev/null +++ b/src/buffer.c @@ -0,0 +1,75 @@ +#include "buffer.h" +#include "logging.h" +#include +#include +#include + +int buffer_init(buffer_t *buf, size_t capacity) { + buf->data = malloc(capacity); + if (!buf->data) { + log_error("Failed to allocate buffer"); + return -1; + } + buf->capacity = capacity; + buf->head = 0; + buf->tail = 0; + return 0; +} + +void buffer_free(buffer_t *buf) { + if (buf->data) { + free(buf->data); + buf->data = NULL; + } + buf->capacity = 0; + buf->head = 0; + buf->tail = 0; +} + +size_t buffer_available_read(buffer_t *buf) { + return buf->tail - buf->head; +} + +size_t buffer_available_write(buffer_t *buf) { + return buf->capacity - buf->tail; +} + +int buffer_ensure_capacity(buffer_t *buf, size_t required) { + if (buf->capacity >= required) return 0; + + size_t new_capacity = buf->capacity; + while (new_capacity < required) { + new_capacity *= 2; + if (new_capacity > SIZE_MAX / 2) { + log_error("Buffer size limit exceeded"); + return -1; + } + } + + char *new_data = realloc(buf->data, new_capacity); + if (!new_data) { + log_error("Failed to reallocate buffer"); + return -1; + } + buf->data = new_data; + buf->capacity = new_capacity; + return 0; +} + +void buffer_compact(buffer_t *buf) { + if (buf->head == 0) return; + size_t len = buf->tail - buf->head; + if (len > 0) { + memmove(buf->data, buf->data + buf->head, len); + } + buf->head = 0; + buf->tail = len; +} + +void buffer_consume(buffer_t *buf, size_t bytes) { + buf->head += bytes; + if (buf->head >= buf->tail) { + buf->head = 0; + buf->tail = 0; + } +} diff --git a/src/buffer.h b/src/buffer.h new file mode 100644 index 0000000..bba2741 --- /dev/null +++ b/src/buffer.h @@ -0,0 +1,14 @@ +#ifndef RPROXY_BUFFER_H +#define RPROXY_BUFFER_H + +#include "types.h" + +int buffer_init(buffer_t *buf, size_t capacity); +void buffer_free(buffer_t *buf); +size_t buffer_available_read(buffer_t *buf); +size_t buffer_available_write(buffer_t *buf); +int buffer_ensure_capacity(buffer_t *buf, size_t required); +void buffer_compact(buffer_t *buf); +void buffer_consume(buffer_t *buf, size_t bytes); + +#endif diff --git a/src/config.c b/src/config.c new file mode 100644 index 0000000..0041235 --- /dev/null +++ b/src/config.c @@ -0,0 +1,159 @@ +#include "config.h" +#include "logging.h" +#include "../cJSON.h" +#include +#include +#include + +app_config_t config; + +static char* read_file_to_string(const char *filename) { + FILE *f = fopen(filename, "rb"); + if (!f) return NULL; + + fseek(f, 0, SEEK_END); + long length = ftell(f); + if (length < 0 || length > 1024*1024) { + fclose(f); + return NULL; + } + + fseek(f, 0, SEEK_SET); + char *buffer = malloc(length + 1); + if (buffer) { + size_t read_len = fread(buffer, 1, length, f); + buffer[read_len] = '\0'; + } + fclose(f); + return buffer; +} + +int config_load(const char *filename) { + log_info("Loading configuration from %s", filename); + char *json_string = read_file_to_string(filename); + if (!json_string) { + log_error("Could not read config file"); + return 0; + } + + cJSON *root = cJSON_Parse(json_string); + free(json_string); + if (!root) { + fprintf(stderr, "JSON parse error: %s\n", cJSON_GetErrorPtr()); + return 0; + } + + cJSON *port_item = cJSON_GetObjectItem(root, "port"); + config.port = cJSON_IsNumber(port_item) ? port_item->valueint : 8080; + + if (config.port < 1 || config.port > 65535) { + fprintf(stderr, "Invalid port number: %d\n", config.port); + cJSON_Delete(root); + return 0; + } + + cJSON *proxy_array = cJSON_GetObjectItem(root, "reverse_proxy"); + if (cJSON_IsArray(proxy_array)) { + config.route_count = cJSON_GetArraySize(proxy_array); + if (config.route_count <= 0) { + cJSON_Delete(root); + return 0; + } + + config.routes = calloc(config.route_count, sizeof(route_config_t)); + if (!config.routes) { + log_error("Failed to allocate memory for routes"); + cJSON_Delete(root); + return 0; + } + + int i = 0; + cJSON *route_item; + cJSON_ArrayForEach(route_item, proxy_array) { + route_config_t *route = &config.routes[i]; + + cJSON *hostname = cJSON_GetObjectItem(route_item, "hostname"); + cJSON *upstream_host = cJSON_GetObjectItem(route_item, "upstream_host"); + cJSON *upstream_port = cJSON_GetObjectItem(route_item, "upstream_port"); + + if (!cJSON_IsString(hostname) || !cJSON_IsString(upstream_host) || !cJSON_IsNumber(upstream_port)) { + fprintf(stderr, "Invalid route configuration at index %d\n", i); + continue; + } + + strncpy(route->hostname, hostname->valuestring, sizeof(route->hostname) - 1); + strncpy(route->upstream_host, upstream_host->valuestring, sizeof(route->upstream_host) - 1); + route->upstream_port = upstream_port->valueint; + + if (route->upstream_port < 1 || route->upstream_port > 65535) { + fprintf(stderr, "Invalid upstream port for %s: %d\n", route->hostname, route->upstream_port); + continue; + } + + route->use_ssl = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "use_ssl")); + route->rewrite_host = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "rewrite_host")); + + log_info("Route configured: %s -> %s:%d (SSL: %s, Rewrite Host: %s)", + route->hostname, route->upstream_host, route->upstream_port, + route->use_ssl ? "yes" : "no", route->rewrite_host ? "yes" : "no"); + i++; + } + } + cJSON_Delete(root); + log_info("Loaded %d routes from %s", config.route_count, filename); + return 1; +} + +void config_free(void) { + if (config.routes) { + free(config.routes); + config.routes = NULL; + } + config.route_count = 0; +} + +void config_create_default(const char *filename) { + FILE *f = fopen(filename, "r"); + if (f) { + fclose(f); + return; + } + + f = fopen(filename, "w"); + if (!f) { + log_error("Cannot create default config file"); + return; + } + + fprintf(f, "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"localhost\",\n" + " \"upstream_host\": \"127.0.0.1\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": true\n" + " },\n" + " {\n" + " \"hostname\": \"example.com\",\n" + " \"upstream_host\": \"127.0.0.1\",\n" + " \"upstream_port\": 5000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"); + fclose(f); + log_info("Created default config file: %s", filename); +} + +route_config_t *config_find_route(const char *hostname) { + if (!hostname) return NULL; + for (int i = 0; i < config.route_count; i++) { + if (strcasecmp(hostname, config.routes[i].hostname) == 0) { + return &config.routes[i]; + } + } + return NULL; +} diff --git a/src/config.h b/src/config.h new file mode 100644 index 0000000..3c17c7c --- /dev/null +++ b/src/config.h @@ -0,0 +1,13 @@ +#ifndef RPROXY_CONFIG_H +#define RPROXY_CONFIG_H + +#include "types.h" + +extern app_config_t config; + +int config_load(const char *filename); +void config_free(void); +void config_create_default(const char *filename); +route_config_t *config_find_route(const char *hostname); + +#endif diff --git a/src/connection.c b/src/connection.c new file mode 100644 index 0000000..5264981 --- /dev/null +++ b/src/connection.c @@ -0,0 +1,769 @@ +#include "connection.h" +#include "buffer.h" +#include "logging.h" +#include "config.h" +#include "monitor.h" +#include "http.h" +#include "ssl_handler.h" +#include "dashboard.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +connection_t connections[MAX_FDS]; +int epoll_fd = -1; + +void connection_init_all(void) { + for (int i = 0; i < MAX_FDS; i++) { + connections[i].type = CONN_TYPE_UNUSED; + connections[i].fd = -1; + } +} + +void connection_set_non_blocking(int fd) { + int flags = fcntl(fd, F_GETFL, 0); + if (flags >= 0) { + fcntl(fd, F_SETFL, flags | O_NONBLOCK); + } +} + +void connection_set_tcp_keepalive(int fd) { + int yes = 1; + setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &yes, sizeof(yes)); + int idle = 60; + setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)); + int interval = 10; + setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &interval, sizeof(interval)); + int maxpkt = 6; + setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &maxpkt, sizeof(maxpkt)); +} + +void connection_add_to_epoll(int fd, uint32_t events) { + struct epoll_event event = { .data.fd = fd, .events = events }; + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &event) == -1) { + log_error("epoll_ctl_add failed"); + close(fd); + } +} + +void connection_modify_epoll(int fd, uint32_t events) { + struct epoll_event event = { .data.fd = fd, .events = events }; + if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, fd, &event) == -1) { + if(errno != EBADF && errno != ENOENT) { + log_debug("epoll_ctl_mod failed for fd %d: %s", fd, strerror(errno)); + } + } +} + +void connection_setup_listener(int port) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + log_error("socket failed"); + exit(EXIT_FAILURE); + } + + int reuse = 1; + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) < 0) { + log_error("setsockopt SO_REUSEADDR failed"); + close(fd); + exit(EXIT_FAILURE); + } + + struct sockaddr_in addr = { + .sin_family = AF_INET, + .sin_port = htons(port), + .sin_addr.s_addr = htonl(INADDR_ANY) + }; + + if (bind(fd, (struct sockaddr*)&addr, sizeof(addr)) == -1) { + log_error("bind failed"); + close(fd); + exit(EXIT_FAILURE); + } + + if (listen(fd, SOMAXCONN) == -1) { + log_error("listen failed"); + close(fd); + exit(EXIT_FAILURE); + } + + connection_set_non_blocking(fd); + connection_add_to_epoll(fd, EPOLLIN); + + connections[fd].type = CONN_TYPE_LISTENER; + connections[fd].fd = fd; + + log_info("Listening on port %d (fd=%d)", port, fd); +} + +void connection_accept(int listener_fd) { + while (1) { + struct sockaddr_in client_addr; + socklen_t client_len = sizeof(client_addr); + + int client_fd = accept(listener_fd, (struct sockaddr*)&client_addr, &client_len); + if (client_fd == -1) { + if (errno != EAGAIN && errno != EWOULDBLOCK) { + log_error("accept failed"); + } + break; + } + + if (client_fd >= MAX_FDS) { + log_error("Connection fd too high, closing"); + close(client_fd); + continue; + } + + connection_set_non_blocking(client_fd); + connection_set_tcp_keepalive(client_fd); + connection_add_to_epoll(client_fd, EPOLLIN); + + connection_t *conn = &connections[client_fd]; + memset(conn, 0, sizeof(connection_t)); + conn->type = CONN_TYPE_CLIENT; + conn->state = CLIENT_STATE_READING_HEADERS; + conn->fd = client_fd; + conn->last_activity = time(NULL); + + if (buffer_init(&conn->read_buf, CHUNK_SIZE) < 0 || + buffer_init(&conn->write_buf, CHUNK_SIZE) < 0) { + connection_close(client_fd); + continue; + } + + __sync_fetch_and_add(&monitor.active_connections, 1); + log_debug("New connection on fd %d from %s, total: %d", + client_fd, inet_ntoa(client_addr.sin_addr), monitor.active_connections); + } +} + +void connection_close(int fd) { + if (fd < 0 || fd >= MAX_FDS) return; + + connection_t *conn = &connections[fd]; + if (conn->type == CONN_TYPE_UNUSED || conn->fd == -1) return; + + connection_t *pair = conn->pair; + + if (pair) { + if (conn->type == CONN_TYPE_UPSTREAM && pair->type == CONN_TYPE_CLIENT) { + log_debug("Upstream fd %d is closing. Resetting client fd %d to READING_HEADERS.", fd, pair->fd); + pair->state = CLIENT_STATE_READING_HEADERS; + pair->pair = NULL; + } else if (conn->type == CONN_TYPE_CLIENT && pair->type == CONN_TYPE_UPSTREAM) { + log_debug("Client fd %d is closing. Closing orphaned upstream pair fd %d.", fd, pair->fd); + pair->pair = NULL; + connection_close(pair->fd); + } + conn->pair = NULL; + } + + if (conn->vhost_stats && conn->request_start_time > 0) { + monitor_record_request_end(conn->vhost_stats, conn->request_start_time); + conn->request_start_time = 0; + } + + if (conn->type == CONN_TYPE_CLIENT) { + __sync_fetch_and_sub(&monitor.active_connections, 1); + } + + log_debug("Closing and cleaning up fd %d", fd); + + epoll_ctl(epoll_fd, EPOLL_CTL_DEL, fd, NULL); + + if (conn->ssl) { + SSL_shutdown(conn->ssl); + SSL_free(conn->ssl); + } + + close(fd); + + buffer_free(&conn->read_buf); + buffer_free(&conn->write_buf); + + memset(conn, 0, sizeof(connection_t)); + conn->type = CONN_TYPE_UNUSED; + conn->fd = -1; +} + +int connection_do_read(connection_t *conn) { + if (!conn) return -1; + + buffer_t *buf = &conn->read_buf; + buffer_compact(buf); + + size_t available = buffer_available_write(buf); + if (available == 0) { + if (buffer_ensure_capacity(buf, buf->capacity * 2) < 0) { + return -1; + } + available = buffer_available_write(buf); + } + + int bytes_read; + if (conn->ssl && conn->ssl_handshake_done) { + bytes_read = SSL_read(conn->ssl, buf->data + buf->tail, available); + if (bytes_read <= 0) { + int ssl_error = SSL_get_error(conn->ssl, bytes_read); + if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) { + errno = EAGAIN; + return 0; + } + return bytes_read; + } + } else { + bytes_read = read(conn->fd, buf->data + buf->tail, available); + } + + if (bytes_read > 0) { + buf->tail += bytes_read; + conn->last_activity = time(NULL); + if (conn->vhost_stats) { + monitor_record_bytes(conn->vhost_stats, 0, bytes_read); + } + } + return bytes_read; +} + +int connection_do_write(connection_t *conn) { + if (!conn) return -1; + + buffer_t *buf = &conn->write_buf; + size_t available = buffer_available_read(buf); + if (available == 0) return 0; + + int written; + if (conn->ssl && conn->ssl_handshake_done) { + written = SSL_write(conn->ssl, buf->data + buf->head, available); + if (written <= 0) { + int ssl_error = SSL_get_error(conn->ssl, written); + if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) { + errno = EAGAIN; + return 0; + } + return written; + } + } else { + written = write(conn->fd, buf->data + buf->head, available); + } + + if (written > 0) { + buffer_consume(buf, written); + conn->last_activity = time(NULL); + if (conn->vhost_stats) { + monitor_record_bytes(conn->vhost_stats, written, 0); + } + } + return written; +} + +void connection_send_error_response(connection_t *conn, int code, const char* status, const char* body) { + if (!conn || !status || !body) return; + + char response[2048]; + int len = snprintf(response, sizeof(response), + "HTTP/1.1 %d %s\r\n" + "Content-Type: text/plain; charset=utf-8\r\n" + "Content-Length: %zu\r\n" + "Connection: close\r\n" + "Server: ReverseProxy/4.0\r\n" + "\r\n" + "%s", + code, status, strlen(body), body); + + if (len > 0 && (size_t)len < sizeof(response)) { + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len) == 0) { + memcpy(conn->write_buf.data + conn->write_buf.tail, response, len); + conn->write_buf.tail += len; + + struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT }; + epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event); + } + } + + conn->state = CLIENT_STATE_ERROR; + conn->request.keep_alive = 0; +} + +void connection_connect_to_upstream(connection_t *client, const char *data, size_t data_len) { + if (!client || !data) return; + + route_config_t *route = config_find_route(client->request.host); + if (!route) { + connection_send_error_response(client, 502, "Bad Gateway", "No route configured for this host"); + return; + } + + int up_fd = socket(AF_INET, SOCK_STREAM, 0); + if (up_fd < 0) { + connection_send_error_response(client, 502, "Bad Gateway", "Failed to create upstream socket"); + return; + } + + if (up_fd >= MAX_FDS) { + close(up_fd); + connection_send_error_response(client, 502, "Bad Gateway", "Connection limit exceeded"); + return; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(route->upstream_port); + + if (inet_pton(AF_INET, route->upstream_host, &addr.sin_addr) <= 0) { + struct hostent *he = gethostbyname(route->upstream_host); + if (!he) { + close(up_fd); + connection_send_error_response(client, 502, "Bad Gateway", "Cannot resolve upstream hostname"); + return; + } + memcpy(&addr.sin_addr, he->h_addr_list[0], he->h_length); + } + + connection_set_non_blocking(up_fd); + connection_set_tcp_keepalive(up_fd); + + int connect_result = connect(up_fd, (struct sockaddr*)&addr, sizeof(addr)); + if (connect_result < 0 && errno != EINPROGRESS) { + close(up_fd); + connection_send_error_response(client, 502, "Bad Gateway", "Failed to connect to upstream"); + return; + } + + connection_add_to_epoll(up_fd, EPOLLIN | EPOLLOUT); + + connection_t *up = &connections[up_fd]; + memset(up, 0, sizeof(connection_t)); + up->type = CONN_TYPE_UPSTREAM; + up->fd = up_fd; + up->last_activity = time(NULL); + + client->pair = up; + up->pair = client; + up->vhost_stats = client->vhost_stats; + + if (buffer_init(&up->read_buf, CHUNK_SIZE) < 0 || + buffer_init(&up->write_buf, CHUNK_SIZE) < 0) { + connection_close(client->fd); + return; + } + + char *data_to_send = (char*)data; + size_t len_to_send = data_len; + char *modified_request = NULL; + + if (route->rewrite_host) { + char new_host_header[512]; + const char *old_host_header_start = NULL; + const char *old_host_header_end = NULL; + + const char *current = data; + const char *end = data + data_len; + while(current < end) { + const char* line_end = memchr(current, '\n', end - current); + if (!line_end) break; + if (strncasecmp(current, "Host:", 5) == 0) { + old_host_header_start = current; + old_host_header_end = line_end + 1; + break; + } + current = line_end + 1; + } + + if (old_host_header_start) { + if (route->upstream_port == 80 || route->upstream_port == 443) { + snprintf(new_host_header, sizeof(new_host_header), "Host: %s\r\n", route->upstream_host); + } else { + snprintf(new_host_header, sizeof(new_host_header), "Host: %s:%d\r\n", route->upstream_host, route->upstream_port); + } + size_t new_host_len = strlen(new_host_header); + size_t old_host_len = old_host_header_end - old_host_header_start; + + len_to_send = data_len - old_host_len + new_host_len; + modified_request = malloc(len_to_send + 1); + if (modified_request) { + char* p = modified_request; + size_t prefix_len = old_host_header_start - data; + memcpy(p, data, prefix_len); + p += prefix_len; + memcpy(p, new_host_header, new_host_len); + p += new_host_len; + size_t suffix_len = data_len - (old_host_header_end - data); + memcpy(p, old_host_header_end, suffix_len); + data_to_send = modified_request; + } + } + } + + if (buffer_ensure_capacity(&up->write_buf, len_to_send) == 0) { + memcpy(up->write_buf.data, data_to_send, len_to_send); + up->write_buf.tail = len_to_send; + } + + if (modified_request) { + free(modified_request); + } + + if (route->use_ssl) { + up->ssl = SSL_new(ssl_ctx); + if (!up->ssl) { + connection_close(client->fd); + return; + } + + const char *sni_hostname = route->rewrite_host ? route->upstream_host : client->request.host; + SSL_set_tlsext_host_name(up->ssl, sni_hostname); + SSL_set_fd(up->ssl, up_fd); + SSL_set_connect_state(up->ssl); + + up->ssl_handshake_done = 0; + + log_debug("Setting SNI to: %s for upstream %s:%d", + sni_hostname, route->upstream_host, route->upstream_port); + } + + up->state = CLIENT_STATE_FORWARDING; + + log_debug("Connecting to upstream %s:%d on fd %d (SSL: %s)", + route->upstream_host, route->upstream_port, up_fd, route->use_ssl ? "yes" : "no"); +} + +static void handle_client_read(connection_t *conn) { + int bytes_read = connection_do_read(conn); + if (bytes_read < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { + log_debug("[ROUTING] Closing connection fd=%d due to read error", conn->fd); + connection_close(conn->fd); + return; + } + + buffer_t *buf = &conn->read_buf; + + if (conn->state == CLIENT_STATE_FORWARDING && conn->pair == NULL) { + conn->state = CLIENT_STATE_READING_HEADERS; + } + + if (conn->state == CLIENT_STATE_FORWARDING && conn->pair != NULL) { + char *data_start = buf->data + buf->head; + size_t data_len = buffer_available_read(buf); + + if (data_len >= 1) { + int looks_like_new_request = http_is_request_start(data_start, data_len); + + if (!looks_like_new_request) { + return; + } + + log_debug("Pipelined request detected on fd %d, closing upstream fd %d", + conn->fd, conn->pair->fd); + connection_close(conn->pair->fd); + conn->pair = NULL; + conn->state = CLIENT_STATE_READING_HEADERS; + } else { + return; + } + } + + while (buffer_available_read(buf) > 0 && conn->state == CLIENT_STATE_READING_HEADERS) { + char *data_start = buf->data + buf->head; + size_t data_len = buffer_available_read(buf); + + char *headers_end = memmem(data_start, data_len, "\r\n\r\n", 4); + + if (!headers_end) { + if (data_len >= MAX_HEADER_SIZE) { + connection_send_error_response(conn, 413, "Request Header Too Large", "Header is too large."); + return; + } + log_debug("fd %d: Incomplete headers, waiting for more data.", conn->fd); + break; + } + + size_t headers_len = (headers_end - data_start) + 4; + int parse_result = http_parse_request(data_start, headers_len, &conn->request); + + if (parse_result == 0) { + connection_send_error_response(conn, 400, "Bad Request", "Malformed HTTP request."); + return; + } + if (parse_result < 0) { + break; + } + + long long body_len = (conn->request.content_length > 0) ? conn->request.content_length : 0; + size_t total_request_len = headers_len + body_len; + + if (!conn->request.is_chunked && data_len < total_request_len) { + log_debug("fd %d: Incomplete body, waiting for more data.", conn->fd); + break; + } + + size_t len_to_forward = (conn->request.is_chunked) ? headers_len : total_request_len; + + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + conn->request_start_time = ts.tv_sec + ts.tv_nsec / 1e9; + + if (strncmp(conn->request.uri, "/rproxy/dashboard", 17) == 0 || + strncmp(conn->request.uri, "/rproxy/api/stats", 17) == 0) { + + log_info("[ROUTING-INTERNAL] Serving internal route %s for fd=%d", + conn->request.uri, conn->fd); + + if (conn->pair) { + connection_close(conn->pair->fd); + conn->pair = NULL; + } + + conn->state = CLIENT_STATE_SERVING_INTERNAL; + if (strncmp(conn->request.uri, "/rproxy/dashboard", 17) == 0) { + dashboard_serve(conn); + } else { + dashboard_serve_stats_api(conn); + } + + buffer_consume(buf, total_request_len); + + if (!conn->request.keep_alive) { + conn->state = CLIENT_STATE_CLOSING; + return; + } + + memset(&conn->request, 0, sizeof(http_request_t)); + conn->request.keep_alive = 1; + conn->state = CLIENT_STATE_READING_HEADERS; + continue; + } + + log_info("[ROUTING-FORWARD] Forwarding request for fd=%d: %s %s", + conn->fd, conn->request.method, conn->request.uri); + + conn->vhost_stats = monitor_get_or_create_vhost_stats(conn->request.host); + monitor_record_request_start(conn->vhost_stats, conn->request.is_websocket); + + conn->state = CLIENT_STATE_FORWARDING; + connection_connect_to_upstream(conn, data_start, len_to_forward); + buffer_consume(buf, len_to_forward); + + return; + } +} + +static void handle_forwarding(connection_t *conn) { + connection_t *pair = conn->pair; + + if (!pair || pair->fd == -1) { + if (conn->type == CONN_TYPE_CLIENT) { + log_info("ROUTING-ORPHAN: Client fd=%d lost upstream, resetting to READING_HEADERS", conn->fd); + conn->state = CLIENT_STATE_READING_HEADERS; + conn->pair = NULL; + if (buffer_available_read(&conn->read_buf) > 0) { + handle_client_read(conn); + } + } else { + connection_close(conn->fd); + } + return; + } + + int bytes_read = connection_do_read(conn); + + if (bytes_read == 0) { + log_debug("EOF on fd %d, performing half-close on pair fd %d", conn->fd, pair->fd); + conn->half_closed = 1; + connection_modify_epoll(conn->fd, buffer_available_read(&conn->write_buf) ? EPOLLOUT : 0); + + if (pair->fd != -1 && !pair->write_shutdown) { + if (shutdown(pair->fd, SHUT_WR) == -1 && errno != ENOTCONN) { + log_debug("shutdown(SHUT_WR) failed for fd %d: %s", pair->fd, strerror(errno)); + } + pair->write_shutdown = 1; + } + + if (pair->half_closed) { + connection_close(conn->fd); + } + return; + } + + if (bytes_read < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { + connection_close(conn->fd); + return; + } + + size_t data_to_forward = buffer_available_read(&conn->read_buf); + if (data_to_forward > 0) { + if (buffer_available_read(&pair->write_buf) > CHUNK_SIZE / 2) { + connection_do_write(pair); + } + + size_t space_needed = pair->write_buf.tail + data_to_forward; + if (buffer_ensure_capacity(&pair->write_buf, space_needed) < 0) { + log_debug("Failed to buffer data for fd %d, closing connection", conn->fd); + connection_close(conn->fd); + return; + } + + memcpy(pair->write_buf.data + pair->write_buf.tail, + conn->read_buf.data + conn->read_buf.head, + data_to_forward); + pair->write_buf.tail += data_to_forward; + buffer_consume(&conn->read_buf, data_to_forward); + + connection_do_write(pair); + + connection_modify_epoll(pair->fd, EPOLLIN | EPOLLOUT); + } +} + +static void handle_ssl_handshake(connection_t *conn) { + if (!conn->ssl || conn->ssl_handshake_done) return; + + int ret = SSL_do_handshake(conn->ssl); + if (ret == 1) { + conn->ssl_handshake_done = 1; + log_debug("SSL handshake completed for fd %d", conn->fd); + + if (conn->type == CONN_TYPE_UPSTREAM && conn->pair) { + connection_t *client = conn->pair; + + if (buffer_available_read(&client->read_buf) > 0) { + size_t data_len = buffer_available_read(&client->read_buf); + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + data_len) == 0) { + memcpy(conn->write_buf.data + conn->write_buf.tail, + client->read_buf.data + client->read_buf.head, + data_len); + conn->write_buf.tail += data_len; + buffer_consume(&client->read_buf, data_len); + log_debug("Forwarding %zu bytes of buffered request data after SSL handshake", data_len); + } + } + } + + if (buffer_available_read(&conn->write_buf) > 0) { + connection_modify_epoll(conn->fd, EPOLLIN | EPOLLOUT); + } else { + connection_modify_epoll(conn->fd, EPOLLIN); + } + } else { + int ssl_error = SSL_get_error(conn->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { + connection_modify_epoll(conn->fd, EPOLLIN); + } else if (ssl_error == SSL_ERROR_WANT_WRITE) { + connection_modify_epoll(conn->fd, EPOLLOUT); + } else { + log_debug("SSL handshake failed for fd %d: %d", conn->fd, ssl_error); + if (conn->pair) { + connection_send_error_response(conn->pair, 502, "Bad Gateway", "SSL handshake failed"); + } else { + connection_close(conn->fd); + } + } + } +} + +static void handle_write_event(connection_t *conn) { + conn->last_activity = time(NULL); + + if (conn->type == CONN_TYPE_UPSTREAM && conn->ssl && !conn->ssl_handshake_done) { + handle_ssl_handshake(conn); + if (!conn->ssl_handshake_done) { + return; + } + } + + int written = connection_do_write(conn); + + if (buffer_available_read(&conn->write_buf) == 0) { + if (conn->write_shutdown) { + if (conn->half_closed) { + connection_close(conn->fd); + return; + } + connection_modify_epoll(conn->fd, EPOLLIN); + } else { + if (conn->state == CLIENT_STATE_ERROR || + (conn->state == CLIENT_STATE_SERVING_INTERNAL && !conn->request.keep_alive)) { + connection_close(conn->fd); + } else if (conn->state == CLIENT_STATE_SERVING_INTERNAL && conn->request.keep_alive) { + memset(&conn->request, 0, sizeof(http_request_t)); + conn->request.keep_alive = 1; + conn->state = CLIENT_STATE_READING_HEADERS; + connection_modify_epoll(conn->fd, EPOLLIN); + } else { + connection_modify_epoll(conn->fd, EPOLLIN); + } + } + } else if (written < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { + connection_close(conn->fd); + } +} + +void connection_handle_event(struct epoll_event *event) { + int fd = event->data.fd; + if (fd < 0 || fd >= MAX_FDS) return; + + connection_t *conn = &connections[fd]; + if (conn->type == CONN_TYPE_UNUSED || conn->fd == -1) return; + + if (event->events & (EPOLLERR | EPOLLHUP)) { + if (event->events & EPOLLERR) { + log_debug("EPOLLERR on fd %d", fd); + } + connection_close(fd); + return; + } + + if (conn->type == CONN_TYPE_LISTENER) { + if (event->events & EPOLLIN) { + connection_accept(fd); + } + } else { + if (conn->type == CONN_TYPE_UPSTREAM && conn->ssl && !conn->ssl_handshake_done) { + handle_ssl_handshake(conn); + if (!conn->ssl_handshake_done) { + return; + } + } + + if (event->events & EPOLLOUT) { + handle_write_event(conn); + } + + if (connections[fd].type != CONN_TYPE_UNUSED && (event->events & EPOLLIN)) { + if (conn->type == CONN_TYPE_CLIENT && + (conn->state == CLIENT_STATE_READING_HEADERS || + conn->state == CLIENT_STATE_SERVING_INTERNAL)) { + handle_client_read(conn); + } else if (conn->state == CLIENT_STATE_FORWARDING) { + handle_forwarding(conn); + } + } + } +} + +void connection_cleanup_idle(void) { + time_t current_time = time(NULL); + + for (int i = 0; i < MAX_FDS; i++) { + connection_t *conn = &connections[i]; + if (conn->type != CONN_TYPE_UNUSED && + conn->type != CONN_TYPE_LISTENER && + conn->fd != -1) { + if (current_time - conn->last_activity > CONNECTION_TIMEOUT) { + log_debug("Closing idle connection fd=%d", i); + connection_close(i); + } + } + } +} diff --git a/src/connection.h b/src/connection.h new file mode 100644 index 0000000..2b99669 --- /dev/null +++ b/src/connection.h @@ -0,0 +1,28 @@ +#ifndef RPROXY_CONNECTION_H +#define RPROXY_CONNECTION_H + +#include "types.h" +#include + +extern connection_t connections[MAX_FDS]; +extern int epoll_fd; + +void connection_init_all(void); +void connection_setup_listener(int port); +void connection_accept(int listener_fd); +void connection_close(int fd); +void connection_handle_event(struct epoll_event *event); +void connection_cleanup_idle(void); + +void connection_set_non_blocking(int fd); +void connection_set_tcp_keepalive(int fd); +void connection_add_to_epoll(int fd, uint32_t events); +void connection_modify_epoll(int fd, uint32_t events); + +int connection_do_read(connection_t *conn); +int connection_do_write(connection_t *conn); + +void connection_send_error_response(connection_t *conn, int code, const char* status, const char* body); +void connection_connect_to_upstream(connection_t *client, const char *data, size_t data_len); + +#endif diff --git a/src/dashboard.c b/src/dashboard.c new file mode 100644 index 0000000..cdf8284 --- /dev/null +++ b/src/dashboard.c @@ -0,0 +1,505 @@ +#include "dashboard.h" +#include "buffer.h" +#include "monitor.h" +#include "connection.h" +#include "../cJSON.h" +#include +#include +#include +#include + +static const char *DASHBOARD_HTML = +"\n" +"\n" +"\n" +" Reverse Proxy Monitor\n" +" \n" +" \n" +"\n" +"\n" +"
\n" +"
\n" +"
0
\n" +"
Connections
\n" +"
\n" +"
\n" +"
0
\n" +"
Memory
\n" +"
\n" +"
\n" +"
0
\n" +"
CPU %
\n" +"
\n" +"
\n" +"
0.00
\n" +"
Load 1m
\n" +"
\n" +"
\n" +"
0.00
\n" +"
Load 5m
\n" +"
\n" +"
\n" +"
0.00
\n" +"
Load 15m
\n" +"
\n" +"
\n" +"\n" +"
\n" +"
CPU Usage
\n" +" \n" +"
\n" +"\n" +"
\n" +"
Memory Usage
\n" +" \n" +"
\n" +"\n" +"
\n" +"
Network I/O
\n" +"
\n" +"
RX
\n" +"
TX
\n" +"
\n" +" \n" +"
\n" +"\n" +"
\n" +"
Disk I/O
\n" +"
\n" +"
Read
\n" +"
Write
\n" +"
\n" +" \n" +"
\n" +"\n" +"
\n" +"
Load Average
\n" +"
\n" +"
1 min
\n" +"
5 min
\n" +"
15 min
\n" +"
\n" +" \n" +"
\n" +"\n" +"
\n" +" \n" +" \n" +" \n" +" \n" +" \n" +" \n" +" \n" +" \n" +" \n" +" \n" +" \n" +" \n" +" \n" +"
Virtual HostHTTP ReqWS ReqTotal ReqAvg Resp (ms)SentReceived
\n" +"
\n" +"\n" +" \n" +"\n" +"\n"; + +void dashboard_serve(connection_t *conn) { + if (!conn) return; + + size_t content_len = strlen(DASHBOARD_HTML); + char header[512]; + int len = snprintf(header, sizeof(header), + "HTTP/1.1 200 OK\r\n" + "Content-Type: text/html; charset=utf-8\r\n" + "Content-Length: %zu\r\n" + "Connection: %s\r\n" + "Cache-Control: no-cache\r\n" + "\r\n", + content_len, + conn->request.keep_alive ? "keep-alive" : "close"); + + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len + content_len) < 0) { + connection_send_error_response(conn, 500, "Internal Server Error", "Memory allocation failed"); + return; + } + + memcpy(conn->write_buf.data + conn->write_buf.tail, header, len); + conn->write_buf.tail += len; + memcpy(conn->write_buf.data + conn->write_buf.tail, DASHBOARD_HTML, content_len); + conn->write_buf.tail += content_len; + + struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT }; + epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event); +} + +static cJSON* format_history(history_deque_t *dq, int window_seconds) { + cJSON *arr = cJSON_CreateArray(); + if (!arr || !dq || !dq->points || dq->count == 0) return arr; + + double current_time = time(NULL); + int start_index = (dq->head - dq->count + dq->capacity) % dq->capacity; + + for (int i = 0; i < dq->count; ++i) { + int current_index = (start_index + i) % dq->capacity; + history_point_t *p = &dq->points[current_index]; + if ((current_time - p->time) <= window_seconds) { + cJSON *pt = cJSON_CreateObject(); + if (pt) { + cJSON_AddNumberToObject(pt, "x", (long)((p->time - current_time) * 1000)); + cJSON_AddNumberToObject(pt, "y", p->value); + cJSON_AddItemToArray(arr, pt); + } + } + } + return arr; +} + +static cJSON* format_network_history(network_history_deque_t *dq, int window_seconds, const char *key) { + cJSON *arr = cJSON_CreateArray(); + if (!arr || !dq || !dq->points || !key || dq->count == 0) return arr; + + double current_time = time(NULL); + int start_index = (dq->head - dq->count + dq->capacity) % dq->capacity; + + for (int i = 0; i < dq->count; ++i) { + int current_index = (start_index + i) % dq->capacity; + network_history_point_t *p = &dq->points[current_index]; + if ((current_time - p->time) <= window_seconds) { + cJSON *pt = cJSON_CreateObject(); + if (pt) { + cJSON_AddNumberToObject(pt, "x", (long)((p->time - current_time) * 1000)); + cJSON_AddNumberToObject(pt, "y", strcmp(key, "rx_kbps") == 0 ? p->rx_kbps : p->tx_kbps); + cJSON_AddItemToArray(arr, pt); + } + } + } + return arr; +} + +static cJSON* format_disk_history(disk_history_deque_t *dq, int window_seconds, const char *key) { + cJSON *arr = cJSON_CreateArray(); + if (!arr || !dq || !dq->points || !key || dq->count == 0) return arr; + + double current_time = time(NULL); + int start_index = (dq->head - dq->count + dq->capacity) % dq->capacity; + + for (int i = 0; i < dq->count; ++i) { + int current_index = (start_index + i) % dq->capacity; + disk_history_point_t *p = &dq->points[current_index]; + if ((current_time - p->time) <= window_seconds) { + cJSON *pt = cJSON_CreateObject(); + if (pt) { + cJSON_AddNumberToObject(pt, "x", (long)((p->time - current_time) * 1000)); + cJSON_AddNumberToObject(pt, "y", strcmp(key, "read_mbps") == 0 ? p->read_mbps : p->write_mbps); + cJSON_AddItemToArray(arr, pt); + } + } + } + return arr; +} + +void dashboard_serve_stats_api(connection_t *conn) { + if (!conn) return; + + cJSON *root = cJSON_CreateObject(); + if (!root) { + connection_send_error_response(conn, 500, "Internal Server Error", "JSON creation failed"); + return; + } + + cJSON *current = cJSON_CreateObject(); + if (!current) { + cJSON_Delete(root); + connection_send_error_response(conn, 500, "Internal Server Error", "JSON creation failed"); + return; + } + cJSON_AddItemToObject(root, "current", current); + + char buffer[64]; + double last_cpu = 0, last_mem = 0; + double load1 = 0, load5 = 0, load15 = 0; + + if (monitor.cpu_history.count > 0) { + int last_idx = (monitor.cpu_history.head - 1 + monitor.cpu_history.capacity) % monitor.cpu_history.capacity; + last_cpu = monitor.cpu_history.points[last_idx].value; + } + + if (monitor.memory_history.count > 0) { + int last_idx = (monitor.memory_history.head - 1 + monitor.memory_history.capacity) % monitor.memory_history.capacity; + last_mem = monitor.memory_history.points[last_idx].value; + } + + if (monitor.load1_history.count > 0) { + int idx = (monitor.load1_history.head - 1 + monitor.load1_history.capacity) % monitor.load1_history.capacity; + load1 = monitor.load1_history.points[idx].value; + } + if (monitor.load5_history.count > 0) { + int idx = (monitor.load5_history.head - 1 + monitor.load5_history.capacity) % monitor.load5_history.capacity; + load5 = monitor.load5_history.points[idx].value; + } + if (monitor.load15_history.count > 0) { + int idx = (monitor.load15_history.head - 1 + monitor.load15_history.capacity) % monitor.load15_history.capacity; + load15 = monitor.load15_history.points[idx].value; + } + + snprintf(buffer, sizeof(buffer), "%.2f", last_cpu); + cJSON_AddStringToObject(current, "cpu_percent", buffer); + snprintf(buffer, sizeof(buffer), "%.2f", last_mem); + cJSON_AddStringToObject(current, "memory_gb", buffer); + cJSON_AddNumberToObject(current, "active_connections", monitor.active_connections); + cJSON_AddNumberToObject(current, "load_1m", load1); + cJSON_AddNumberToObject(current, "load_5m", load5); + cJSON_AddNumberToObject(current, "load_15m", load15); + + cJSON_AddItemToObject(root, "cpu_history", format_history(&monitor.cpu_history, HISTORY_SECONDS)); + cJSON_AddItemToObject(root, "memory_history", format_history(&monitor.memory_history, HISTORY_SECONDS)); + cJSON_AddItemToObject(root, "network_rx_history", format_network_history(&monitor.network_history, HISTORY_SECONDS, "rx_kbps")); + cJSON_AddItemToObject(root, "network_tx_history", format_network_history(&monitor.network_history, HISTORY_SECONDS, "tx_kbps")); + cJSON_AddItemToObject(root, "disk_read_history", format_disk_history(&monitor.disk_history, HISTORY_SECONDS, "read_mbps")); + cJSON_AddItemToObject(root, "disk_write_history", format_disk_history(&monitor.disk_history, HISTORY_SECONDS, "write_mbps")); + cJSON_AddItemToObject(root, "throughput_history", format_history(&monitor.throughput_history, HISTORY_SECONDS)); + cJSON_AddItemToObject(root, "load1_history", format_history(&monitor.load1_history, HISTORY_SECONDS)); + cJSON_AddItemToObject(root, "load5_history", format_history(&monitor.load5_history, HISTORY_SECONDS)); + cJSON_AddItemToObject(root, "load15_history", format_history(&monitor.load15_history, HISTORY_SECONDS)); + + cJSON *processes = cJSON_CreateArray(); + if (processes) { + cJSON_AddItemToObject(root, "processes", processes); + for (vhost_stats_t *s = monitor.vhost_stats_head; s; s = s->next) { + cJSON *p = cJSON_CreateObject(); + if (p) { + cJSON_AddStringToObject(p, "name", s->vhost_name); + cJSON_AddNumberToObject(p, "http_requests", s->http_requests); + cJSON_AddNumberToObject(p, "websocket_requests", s->websocket_requests); + cJSON_AddNumberToObject(p, "total_requests", s->total_requests); + cJSON_AddNumberToObject(p, "avg_request_time_ms", s->avg_request_time_ms); + cJSON_AddNumberToObject(p, "bytes_sent", s->bytes_sent); + cJSON_AddNumberToObject(p, "bytes_recv", s->bytes_recv); + cJSON_AddItemToObject(p, "throughput_history", format_history(&s->throughput_history, 60)); + cJSON_AddItemToArray(processes, p); + } + } + } + + char *json_string = cJSON_PrintUnformatted(root); + if (!json_string) { + cJSON_Delete(root); + connection_send_error_response(conn, 500, "Internal Server Error", "JSON serialization failed"); + return; + } + + char header[512]; + int hlen = snprintf(header, sizeof(header), + "HTTP/1.1 200 OK\r\n" + "Content-Type: application/json; charset=utf-8\r\n" + "Content-Length: %zu\r\n" + "Connection: %s\r\n" + "Cache-Control: no-cache\r\n" + "\r\n", + strlen(json_string), + conn->request.keep_alive ? "keep-alive" : "close"); + + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + hlen + strlen(json_string)) < 0) { + free(json_string); + cJSON_Delete(root); + connection_send_error_response(conn, 500, "Internal Server Error", "Memory allocation failed"); + return; + } + + memcpy(conn->write_buf.data + conn->write_buf.tail, header, hlen); + conn->write_buf.tail += hlen; + memcpy(conn->write_buf.data + conn->write_buf.tail, json_string, strlen(json_string)); + conn->write_buf.tail += strlen(json_string); + + struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT }; + epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event); + + cJSON_Delete(root); + free(json_string); +} diff --git a/src/dashboard.h b/src/dashboard.h new file mode 100644 index 0000000..bf06406 --- /dev/null +++ b/src/dashboard.h @@ -0,0 +1,9 @@ +#ifndef RPROXY_DASHBOARD_H +#define RPROXY_DASHBOARD_H + +#include "types.h" + +void dashboard_serve(connection_t *conn); +void dashboard_serve_stats_api(connection_t *conn); + +#endif diff --git a/src/http.c b/src/http.c new file mode 100644 index 0000000..3e7f25f --- /dev/null +++ b/src/http.c @@ -0,0 +1,152 @@ +#include "http.h" +#include +#include +#include + +int http_find_header_value(const char* data, size_t len, const char* name, char* value, size_t value_size) { + size_t name_len = strlen(name); + const char* end = data + len; + + for (const char* p = data; p < end; ) { + const char* line_end = memchr(p, '\n', end - p); + if (!line_end) break; + + if (line_end > p && *(line_end - 1) == '\r') { + line_end--; + } + + if ((size_t)(line_end - p) > name_len && strncasecmp(p, name, name_len) == 0 && p[name_len] == ':') { + const char* v_start = p + name_len + 1; + while (v_start < line_end && (*v_start == ' ' || *v_start == '\t')) v_start++; + + const char* v_end = line_end; + while (v_end > v_start && (*(v_end - 1) == ' ' || *(v_end - 1) == '\t')) v_end--; + + size_t copy_len = v_end - v_start; + if (copy_len >= value_size) copy_len = value_size - 1; + + memcpy(value, v_start, copy_len); + value[copy_len] = '\0'; + return 1; + } + + p = line_end + ( (line_end < end && *(line_end) == '\r') ? 2 : 1 ); + } + return 0; +} + +static int is_valid_method_char(char c) { + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); +} + +static int is_valid_http_method(const char *method, size_t len) { + if (len == 0 || len > 31) return 0; + for (size_t i = 0; i < len; i++) { + if (!is_valid_method_char(method[i])) return 0; + } + return 1; +} + +int http_is_request_start(const char *data, size_t len) { + if (len < 4) return 0; + + const char *space = memchr(data, ' ', len > 32 ? 32 : len); + if (!space) return 0; + + size_t method_len = space - data; + if (method_len == 0 || method_len > 31) return 0; + + for (size_t i = 0; i < method_len; i++) { + if (!is_valid_method_char(data[i])) return 0; + } + + return 1; +} + +int http_parse_request(const char *data, size_t len, http_request_t *req) { + memset(req, 0, sizeof(http_request_t)); + req->content_length = -1; + req->keep_alive = 1; + + const char *line_end = memchr(data, '\n', len); + if (!line_end) return -1; + size_t line_len = line_end - data; + if (line_len > 0 && data[line_len - 1] == '\r') { + line_len--; + } + + const char *method_end = memchr(data, ' ', line_len); + if (!method_end) return 0; + size_t method_len = method_end - data; + if (method_len >= sizeof(req->method)) return 0; + memcpy(req->method, data, method_len); + req->method[method_len] = '\0'; + + if (!is_valid_http_method(req->method, method_len)) return 0; + + const char *uri_start = method_end + 1; + while (uri_start < data + line_len && *uri_start == ' ') uri_start++; + const char *uri_end = data + line_len; + + const char *version_start = NULL; + for (const char *p = uri_end - 1; p > uri_start; p--) { + if (*p == ' ') { + version_start = p + 1; + uri_end = p; + break; + } + } + + if (!version_start || version_start == uri_start) return 0; + + size_t uri_len = uri_end - uri_start; + if (uri_len >= sizeof(req->uri)) return 0; + memcpy(req->uri, uri_start, uri_len); + req->uri[uri_len] = '\0'; + + while (version_start < data + line_len && *version_start == ' ') version_start++; + const char *actual_line_end = data + line_len; + size_t version_len = actual_line_end - version_start; + if (version_len >= sizeof(req->version)) return 0; + memcpy(req->version, version_start, version_len); + req->version[version_len] = '\0'; + + if (strncmp(req->version, "HTTP/1.0", 8) == 0) { + req->keep_alive = 0; + } + + const char *headers_start = line_end + 1; + char value[1024]; + + if (http_find_header_value(headers_start, len - (headers_start - data), "Host", req->host, sizeof(req->host))) { + char *port_colon = strchr(req->host, ':'); + if (port_colon) *port_colon = '\0'; + } + + if (http_find_header_value(headers_start, len - (headers_start - data), "Content-Length", value, sizeof(value))) { + req->content_length = atol(value); + } + + if (http_find_header_value(headers_start, len - (headers_start - data), "Transfer-Encoding", value, sizeof(value))) { + if (strcasecmp(value, "chunked") == 0) { + req->is_chunked = 1; + } + } + + if (http_find_header_value(headers_start, len - (headers_start - data), "Connection", value, sizeof(value))) { + if (strcasecmp(value, "close") == 0) { + req->keep_alive = 0; + req->connection_close = 1; + } else if (strcasecmp(value, "keep-alive") == 0) { + req->keep_alive = 1; + } else if (strcasecmp(value, "upgrade") == 0) { + req->is_websocket = 1; + } + } + + if (http_find_header_value(headers_start, len - (headers_start - data), "Upgrade", value, sizeof(value))) { + if (strcasecmp(value, "websocket") == 0) req->is_websocket = 1; + } + + return 1; +} diff --git a/src/http.h b/src/http.h new file mode 100644 index 0000000..29d71c2 --- /dev/null +++ b/src/http.h @@ -0,0 +1,10 @@ +#ifndef RPROXY_HTTP_H +#define RPROXY_HTTP_H + +#include "types.h" + +int http_parse_request(const char *data, size_t len, http_request_t *req); +int http_find_header_value(const char* data, size_t len, const char* name, char* value, size_t value_size); +int http_is_request_start(const char *data, size_t len); + +#endif diff --git a/src/logging.c b/src/logging.c new file mode 100644 index 0000000..7efc5fc --- /dev/null +++ b/src/logging.c @@ -0,0 +1,44 @@ +#include "logging.h" +#include +#include +#include + +static int g_debug_mode = 0; + +void logging_set_debug(int enabled) { + g_debug_mode = enabled; +} + +int logging_get_debug(void) { + return g_debug_mode; +} + +void log_error(const char *msg) { + perror(msg); +} + +static void log_message(const char *level, const char *format, va_list args) { + time_t now; + time(&now); + char buf[32]; + strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", localtime(&now)); + printf("%s - %-5s - ", buf, level); + vprintf(format, args); + printf("\n"); + fflush(stdout); +} + +void log_info(const char *format, ...) { + va_list args; + va_start(args, format); + log_message("INFO", format, args); + va_end(args); +} + +void log_debug(const char *format, ...) { + if (!g_debug_mode) return; + va_list args; + va_start(args, format); + log_message("DEBUG", format, args); + va_end(args); +} diff --git a/src/logging.h b/src/logging.h new file mode 100644 index 0000000..2847dea --- /dev/null +++ b/src/logging.h @@ -0,0 +1,10 @@ +#ifndef RPROXY_LOGGING_H +#define RPROXY_LOGGING_H + +void log_error(const char *msg); +void log_info(const char *format, ...); +void log_debug(const char *format, ...); +void logging_set_debug(int enabled); +int logging_get_debug(void); + +#endif diff --git a/src/main.c b/src/main.c new file mode 100644 index 0000000..e3dd134 --- /dev/null +++ b/src/main.c @@ -0,0 +1,115 @@ +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif +#include +#include +#include +#include +#include + +#include "types.h" +#include "logging.h" +#include "config.h" +#include "monitor.h" +#include "ssl_handler.h" +#include "connection.h" + +static volatile int g_shutdown = 0; + +static void signal_handler(int sig) { + if (sig == SIGINT || sig == SIGTERM) { + log_info("Received signal %d, shutting down...", sig); + g_shutdown = 1; + } +} + +static void cleanup(void) { + log_info("Cleaning up resources..."); + + for (int i = 0; i < MAX_FDS; i++) { + if (connections[i].type != CONN_TYPE_UNUSED && connections[i].fd != -1) { + connection_close(i); + } + } + + config_free(); + monitor_cleanup(); + + if (epoll_fd >= 0) { + close(epoll_fd); + epoll_fd = -1; + } + + ssl_cleanup(); + + log_info("Cleanup complete"); +} + +int main(int argc, char *argv[]) { + signal(SIGPIPE, SIG_IGN); + signal(SIGINT, signal_handler); + signal(SIGTERM, signal_handler); + + if (getenv("DEBUG")) { + logging_set_debug(1); + log_info("Debug mode enabled"); + } + + const char *config_file = (argc > 1) ? argv[1] : "proxy_config.json"; + config_create_default(config_file); + + if (!config_load(config_file)) { + fprintf(stderr, "Failed to load configuration\n"); + return 1; + } + + ssl_init(); + monitor_init("proxy_stats.db"); + + epoll_fd = epoll_create1(EPOLL_CLOEXEC); + if (epoll_fd == -1) { + log_error("epoll_create1 failed"); + return 1; + } + + connection_init_all(); + connection_setup_listener(config.port); + + log_info("Port %d", config.port); + log_info("Dashboard: http://localhost:%d/rproxy/dashboard", config.port); + log_info("Stats: http://localhost:%d/rproxy/api/stats", config.port); + + atexit(cleanup); + + struct epoll_event events[MAX_EVENTS]; + time_t last_monitor_update = 0; + time_t last_cleanup = 0; + + while (!g_shutdown) { + int n = epoll_wait(epoll_fd, events, MAX_EVENTS, 1000); + if (n == -1) { + if (errno == EINTR) continue; + log_error("epoll_wait failed"); + break; + } + + for (int i = 0; i < n; i++) { + connection_handle_event(&events[i]); + } + + time_t current_time = time(NULL); + + if (current_time > last_monitor_update) { + monitor_update(); + last_monitor_update = current_time; + } + + if (current_time - last_cleanup >= 60) { + connection_cleanup_idle(); + last_cleanup = current_time; + } + } + + log_info("Shutdown complete"); + return 0; +} diff --git a/src/monitor.c b/src/monitor.c new file mode 100644 index 0000000..62b7d08 --- /dev/null +++ b/src/monitor.c @@ -0,0 +1,441 @@ +#include "monitor.h" +#include "logging.h" +#include +#include +#include +#include +#include + +system_monitor_t monitor; + +void history_deque_init(history_deque_t *dq, int capacity) { + dq->points = calloc(capacity, sizeof(history_point_t)); + dq->capacity = capacity; + dq->head = 0; + dq->count = 0; +} + +void history_deque_push(history_deque_t *dq, double time, double value) { + if (!dq || !dq->points) return; + dq->points[dq->head] = (history_point_t){ .time = time, .value = value }; + dq->head = (dq->head + 1) % dq->capacity; + if (dq->count < dq->capacity) dq->count++; +} + +void network_history_deque_init(network_history_deque_t *dq, int capacity) { + dq->points = calloc(capacity, sizeof(network_history_point_t)); + dq->capacity = capacity; + dq->head = 0; + dq->count = 0; +} + +void network_history_deque_push(network_history_deque_t *dq, double time, double rx, double tx) { + if (!dq || !dq->points) return; + dq->points[dq->head] = (network_history_point_t){ .time = time, .rx_kbps = rx, .tx_kbps = tx }; + dq->head = (dq->head + 1) % dq->capacity; + if (dq->count < dq->capacity) dq->count++; +} + +void disk_history_deque_init(disk_history_deque_t *dq, int capacity) { + dq->points = calloc(capacity, sizeof(disk_history_point_t)); + dq->capacity = capacity; + dq->head = 0; + dq->count = 0; +} + +void disk_history_deque_push(disk_history_deque_t *dq, double time, double read_mbps, double write_mbps) { + if (!dq || !dq->points) return; + dq->points[dq->head] = (disk_history_point_t){ .time = time, .read_mbps = read_mbps, .write_mbps = write_mbps }; + dq->head = (dq->head + 1) % dq->capacity; + if (dq->count < dq->capacity) dq->count++; +} + +void request_time_deque_init(request_time_deque_t *dq, int capacity) { + dq->times = calloc(capacity, sizeof(double)); + dq->capacity = capacity; + dq->head = 0; + dq->count = 0; +} + +void request_time_deque_push(request_time_deque_t *dq, double time_ms) { + if (!dq || !dq->times) return; + dq->times[dq->head] = time_ms; + dq->head = (dq->head + 1) % dq->capacity; + if (dq->count < dq->capacity) dq->count++; +} + +static void init_db(void) { + if (!monitor.db) return; + + char *err_msg = 0; + const char *sql_create_table = + "CREATE TABLE IF NOT EXISTS vhost_stats (" + " id INTEGER PRIMARY KEY AUTOINCREMENT," + " vhost TEXT NOT NULL," + " timestamp REAL NOT NULL," + " http_requests INTEGER DEFAULT 0," + " websocket_requests INTEGER DEFAULT 0," + " total_requests INTEGER DEFAULT 0," + " bytes_sent INTEGER DEFAULT 0," + " bytes_recv INTEGER DEFAULT 0," + " avg_request_time_ms REAL DEFAULT 0," + " UNIQUE(vhost, timestamp)" + ");"; + const char *sql_create_index = + "CREATE INDEX IF NOT EXISTS idx_vhost_timestamp ON vhost_stats(vhost, timestamp);"; + + if (sqlite3_exec(monitor.db, sql_create_table, 0, 0, &err_msg) != SQLITE_OK || + sqlite3_exec(monitor.db, sql_create_index, 0, 0, &err_msg) != SQLITE_OK) { + fprintf(stderr, "SQL error: %s\n", err_msg); + sqlite3_free(err_msg); + } +} + +static void load_stats_from_db(void) { + if (!monitor.db) return; + + sqlite3_stmt *res; + const char *sql = + "SELECT vhost, http_requests, websocket_requests, total_requests, " + "bytes_sent, bytes_recv, avg_request_time_ms " + "FROM vhost_stats v1 WHERE timestamp = (" + " SELECT MAX(timestamp) FROM vhost_stats v2 WHERE v2.vhost = v1.vhost" + ")"; + + if (sqlite3_prepare_v2(monitor.db, sql, -1, &res, 0) != SQLITE_OK) { + fprintf(stderr, "Failed to execute statement: %s\n", sqlite3_errmsg(monitor.db)); + return; + } + + int vhost_count = 0; + while (sqlite3_step(res) == SQLITE_ROW) { + vhost_stats_t *stats = monitor_get_or_create_vhost_stats((const char*)sqlite3_column_text(res, 0)); + if (stats) { + stats->http_requests = sqlite3_column_int64(res, 1); + stats->websocket_requests = sqlite3_column_int64(res, 2); + stats->total_requests = sqlite3_column_int64(res, 3); + stats->bytes_sent = sqlite3_column_int64(res, 4); + stats->bytes_recv = sqlite3_column_int64(res, 5); + stats->avg_request_time_ms = sqlite3_column_double(res, 6); + vhost_count++; + } + } + sqlite3_finalize(res); + log_info("Loaded statistics for %d vhosts from database", vhost_count); +} + +void monitor_init(const char *db_file) { + memset(&monitor, 0, sizeof(system_monitor_t)); + monitor.start_time = time(NULL); + + history_deque_init(&monitor.cpu_history, HISTORY_SECONDS); + history_deque_init(&monitor.memory_history, HISTORY_SECONDS); + network_history_deque_init(&monitor.network_history, HISTORY_SECONDS); + disk_history_deque_init(&monitor.disk_history, HISTORY_SECONDS); + history_deque_init(&monitor.throughput_history, HISTORY_SECONDS); + history_deque_init(&monitor.load1_history, HISTORY_SECONDS); + history_deque_init(&monitor.load5_history, HISTORY_SECONDS); + history_deque_init(&monitor.load15_history, HISTORY_SECONDS); + + if (sqlite3_open(db_file, &monitor.db) != SQLITE_OK) { + fprintf(stderr, "Can't open database: %s\n", sqlite3_errmsg(monitor.db)); + if (monitor.db) { + sqlite3_close(monitor.db); + monitor.db = NULL; + } + } else { + init_db(); + load_stats_from_db(); + } + monitor_update(); +} + +void monitor_cleanup(void) { + if (monitor.db) { + sqlite3_close(monitor.db); + monitor.db = NULL; + } + + vhost_stats_t *current = monitor.vhost_stats_head; + while (current) { + vhost_stats_t *next = current->next; + if (current->throughput_history.points) free(current->throughput_history.points); + if (current->request_times.times) free(current->request_times.times); + free(current); + current = next; + } + monitor.vhost_stats_head = NULL; + + if (monitor.cpu_history.points) free(monitor.cpu_history.points); + if (monitor.memory_history.points) free(monitor.memory_history.points); + if (monitor.network_history.points) free(monitor.network_history.points); + if (monitor.disk_history.points) free(monitor.disk_history.points); + if (monitor.throughput_history.points) free(monitor.throughput_history.points); + if (monitor.load1_history.points) free(monitor.load1_history.points); + if (monitor.load5_history.points) free(monitor.load5_history.points); + if (monitor.load15_history.points) free(monitor.load15_history.points); +} + +static void save_stats_to_db(void) { + if (!monitor.db) return; + + sqlite3_stmt *stmt; + const char *sql = + "INSERT OR REPLACE INTO vhost_stats " + "(vhost, timestamp, http_requests, websocket_requests, total_requests, " + "bytes_sent, bytes_recv, avg_request_time_ms) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?);"; + + if (sqlite3_prepare_v2(monitor.db, sql, -1, &stmt, NULL) != SQLITE_OK) return; + + double current_time = (double)time(NULL); + for (vhost_stats_t *s = monitor.vhost_stats_head; s != NULL; s = s->next) { + if (s->request_times.count > 0) { + double total_time = 0; + for(int i = 0; i < s->request_times.count; i++) { + total_time += s->request_times.times[i]; + } + s->avg_request_time_ms = total_time / s->request_times.count; + } + + sqlite3_bind_text(stmt, 1, s->vhost_name, -1, SQLITE_STATIC); + sqlite3_bind_double(stmt, 2, current_time); + sqlite3_bind_int64(stmt, 3, s->http_requests); + sqlite3_bind_int64(stmt, 4, s->websocket_requests); + sqlite3_bind_int64(stmt, 5, s->total_requests); + sqlite3_bind_int64(stmt, 6, s->bytes_sent); + sqlite3_bind_int64(stmt, 7, s->bytes_recv); + sqlite3_bind_double(stmt, 8, s->avg_request_time_ms); + sqlite3_step(stmt); + sqlite3_reset(stmt); + } + sqlite3_finalize(stmt); +} + +static double get_cpu_usage(void) { + static long long prev_user = 0, prev_nice = 0, prev_system = 0, prev_idle = 0; + long long user, nice, system, idle, iowait, irq, softirq; + + FILE *f = fopen("/proc/stat", "r"); + if (!f) return 0.0; + + if (fscanf(f, "cpu %lld %lld %lld %lld %lld %lld %lld", + &user, &nice, &system, &idle, &iowait, &irq, &softirq) != 7) { + fclose(f); + return 0.0; + } + fclose(f); + + long long prev_total = prev_user + prev_nice + prev_system + prev_idle; + long long total = user + nice + system + idle; + long long totald = total - prev_total; + long long idled = idle - prev_idle; + + prev_user = user; prev_nice = nice; prev_system = system; prev_idle = idle; + return totald == 0 ? 0.0 : (double)(totald - idled) * 100.0 / totald; +} + +static void get_memory_usage(double *used_gb) { + struct sysinfo info; + if (sysinfo(&info) != 0) { + *used_gb = 0; + return; + } + *used_gb = (double)(info.totalram - info.freeram - info.bufferram) * info.mem_unit / (1024.0 * 1024.0 * 1024.0); +} + +static void get_network_stats(long long *bytes_sent, long long *bytes_recv) { + FILE *f = fopen("/proc/net/dev", "r"); + if (!f) { + *bytes_sent = 0; + *bytes_recv = 0; + return; + } + + char line[256]; + if (!fgets(line, sizeof(line), f) || !fgets(line, sizeof(line), f)) { + fclose(f); + *bytes_sent = 0; + *bytes_recv = 0; + return; + } + + long long total_recv = 0, total_sent = 0; + while (fgets(line, sizeof(line), f)) { + char iface[32]; + long long r, t; + if (sscanf(line, "%31[^:]: %lld %*d %*d %*d %*d %*d %*d %*d %lld", iface, &r, &t) == 3) { + char *trimmed = iface; + while (*trimmed == ' ') trimmed++; + if (strcmp(trimmed, "lo") != 0) { + total_recv += r; + total_sent += t; + } + } + } + fclose(f); + *bytes_sent = total_sent; + *bytes_recv = total_recv; +} + +static void get_disk_stats(long long *sectors_read, long long *sectors_written) { + FILE *f = fopen("/proc/diskstats", "r"); + if (!f) { + *sectors_read = 0; + *sectors_written = 0; + return; + } + + char line[2048]; + long long total_read = 0, total_written = 0; + while (fgets(line, sizeof(line), f)) { + char device[64]; + long long sectors_r = 0, sectors_w = 0; + int nfields = 0; + + char major[16], minor[16], dev[64]; + char rc[32], rm[32], sr[32], rtm[32], rtm2[32], wc[32], wm[32], sw[32]; + nfields = sscanf(line, "%15s %15s %63s %31s %31s %31s %31s %31s %31s %31s %31s %31s", + major, minor, dev, rc, rm, sr, rtm, rtm2, wc, wm, sw, sw); + if (nfields >= 11) { + strncpy(device, dev, sizeof(device)-1); + device[sizeof(device)-1] = '\0'; + sectors_r = atoll(sr); + sectors_w = atoll(sw); + + if (strncmp(device, "loop", 4) != 0 && strncmp(device, "ram", 3) != 0) { + int len = strlen(device); + if ((strncmp(device, "sd", 2) == 0 && len == 3) || + (strncmp(device, "nvme", 4) == 0 && strstr(device, "n1p") == NULL) || + (strncmp(device, "vd", 2) == 0 && len == 3) || + (strncmp(device, "hd", 2) == 0 && len == 3)) { + total_read += sectors_r; + total_written += sectors_w; + } + } + } + } + fclose(f); + *sectors_read = total_read; + *sectors_written = total_written; +} + +static void get_load_averages(double *load1, double *load5, double *load15) { + FILE *f = fopen("/proc/loadavg", "r"); + if (!f) { + *load1 = *load5 = *load15 = 0.0; + return; + } + + if (fscanf(f, "%lf %lf %lf", load1, load5, load15) != 3) { + *load1 = *load5 = *load15 = 0.0; + } + fclose(f); +} + +void monitor_update(void) { + double current_time = time(NULL); + + history_deque_push(&monitor.cpu_history, current_time, get_cpu_usage()); + + double mem_used_gb; + get_memory_usage(&mem_used_gb); + history_deque_push(&monitor.memory_history, current_time, mem_used_gb); + + long long net_sent, net_recv; + get_network_stats(&net_sent, &net_recv); + double time_delta = current_time - monitor.last_net_update_time; + if (time_delta > 0 && monitor.last_net_update_time > 0) { + double rx = (net_recv - monitor.last_net_recv) / time_delta / 1024.0; + double tx = (net_sent - monitor.last_net_sent) / time_delta / 1024.0; + network_history_deque_push(&monitor.network_history, current_time, fmax(0, rx), fmax(0, tx)); + history_deque_push(&monitor.throughput_history, current_time, fmax(0, rx + tx)); + } + monitor.last_net_sent = net_sent; + monitor.last_net_recv = net_recv; + monitor.last_net_update_time = current_time; + + long long disk_read, disk_write; + get_disk_stats(&disk_read, &disk_write); + double disk_time_delta = current_time - monitor.last_disk_update_time; + if (disk_time_delta > 0 && monitor.last_disk_update_time > 0) { + double read_mbps = (disk_read - monitor.last_disk_read) * 512.0 / disk_time_delta / (1024.0 * 1024.0); + double write_mbps = (disk_write - monitor.last_disk_write) * 512.0 / disk_time_delta / (1024.0 * 1024.0); + disk_history_deque_push(&monitor.disk_history, current_time, fmax(0, read_mbps), fmax(0, write_mbps)); + } + monitor.last_disk_read = disk_read; + monitor.last_disk_write = disk_write; + monitor.last_disk_update_time = current_time; + + double load1, load5, load15; + get_load_averages(&load1, &load5, &load15); + history_deque_push(&monitor.load1_history, current_time, load1); + history_deque_push(&monitor.load5_history, current_time, load5); + history_deque_push(&monitor.load15_history, current_time, load15); + + for (vhost_stats_t *s = monitor.vhost_stats_head; s != NULL; s = s->next) { + double vhost_delta = current_time - s->last_update; + if (vhost_delta >= 1.0) { + double kbps = 0; + if (s->last_update > 0) { + long long bytes_diff = (s->bytes_sent - s->last_bytes_sent) + (s->bytes_recv - s->last_bytes_recv); + kbps = bytes_diff / vhost_delta / 1024.0; + } + history_deque_push(&s->throughput_history, current_time, fmax(0, kbps)); + s->last_bytes_sent = s->bytes_sent; + s->last_bytes_recv = s->bytes_recv; + s->last_update = current_time; + } + } + + static time_t last_db_save = 0; + if (current_time - last_db_save >= 10) { + save_stats_to_db(); + last_db_save = current_time; + } +} + +vhost_stats_t* monitor_get_or_create_vhost_stats(const char *vhost_name) { + if (!vhost_name || strlen(vhost_name) == 0) return NULL; + + for (vhost_stats_t *curr = monitor.vhost_stats_head; curr; curr = curr->next) { + if (strcmp(curr->vhost_name, vhost_name) == 0) return curr; + } + + vhost_stats_t *new_stats = calloc(1, sizeof(vhost_stats_t)); + if (!new_stats) return NULL; + + strncpy(new_stats->vhost_name, vhost_name, sizeof(new_stats->vhost_name) - 1); + new_stats->last_update = time(NULL); + history_deque_init(&new_stats->throughput_history, 60); + request_time_deque_init(&new_stats->request_times, 100); + new_stats->next = monitor.vhost_stats_head; + monitor.vhost_stats_head = new_stats; + return new_stats; +} + +void monitor_record_request_start(vhost_stats_t *stats, int is_websocket) { + if (!stats) return; + if (is_websocket) { + __sync_fetch_and_add(&stats->websocket_requests, 1); + } else { + __sync_fetch_and_add(&stats->http_requests, 1); + } + __sync_fetch_and_add(&stats->total_requests, 1); +} + +void monitor_record_request_end(vhost_stats_t *stats, double start_time) { + if (!stats || start_time <= 0) return; + struct timespec end_time; + clock_gettime(CLOCK_MONOTONIC, &end_time); + double duration_ms = ((end_time.tv_sec + end_time.tv_nsec / 1e9) - start_time) * 1000.0; + if (duration_ms >= 0 && duration_ms < 60000) { + request_time_deque_push(&stats->request_times, duration_ms); + } +} + +void monitor_record_bytes(vhost_stats_t *stats, long long sent, long long recv) { + if (!stats) return; + __sync_fetch_and_add(&stats->bytes_sent, sent); + __sync_fetch_and_add(&stats->bytes_recv, recv); +} diff --git a/src/monitor.h b/src/monitor.h new file mode 100644 index 0000000..4a4aa0b --- /dev/null +++ b/src/monitor.h @@ -0,0 +1,25 @@ +#ifndef RPROXY_MONITOR_H +#define RPROXY_MONITOR_H + +#include "types.h" + +extern system_monitor_t monitor; + +void monitor_init(const char *db_file); +void monitor_cleanup(void); +void monitor_update(void); +vhost_stats_t* monitor_get_or_create_vhost_stats(const char *vhost_name); +void monitor_record_request_start(vhost_stats_t *stats, int is_websocket); +void monitor_record_request_end(vhost_stats_t *stats, double start_time); +void monitor_record_bytes(vhost_stats_t *stats, long long sent, long long recv); + +void history_deque_init(history_deque_t *dq, int capacity); +void history_deque_push(history_deque_t *dq, double time, double value); +void network_history_deque_init(network_history_deque_t *dq, int capacity); +void network_history_deque_push(network_history_deque_t *dq, double time, double rx, double tx); +void disk_history_deque_init(disk_history_deque_t *dq, int capacity); +void disk_history_deque_push(disk_history_deque_t *dq, double time, double read_mbps, double write_mbps); +void request_time_deque_init(request_time_deque_t *dq, int capacity); +void request_time_deque_push(request_time_deque_t *dq, double time_ms); + +#endif diff --git a/src/ssl_handler.c b/src/ssl_handler.c new file mode 100644 index 0000000..9cacc6f --- /dev/null +++ b/src/ssl_handler.c @@ -0,0 +1,76 @@ +#include "ssl_handler.h" +#include "logging.h" +#include +#include + +SSL_CTX *ssl_ctx = NULL; + +void ssl_init(void) { + SSL_load_error_strings(); + OpenSSL_add_ssl_algorithms(); + ssl_ctx = SSL_CTX_new(TLS_client_method()); + if (!ssl_ctx) { + log_error("Failed to create SSL context"); + exit(EXIT_FAILURE); + } + + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, NULL); + SSL_CTX_set_options(ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); + SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + SSL_CTX_set_verify_depth(ssl_ctx, 0); +} + +void ssl_cleanup(void) { + if (ssl_ctx) { + SSL_CTX_free(ssl_ctx); + ssl_ctx = NULL; + } + EVP_cleanup(); +} + +int ssl_do_handshake(connection_t *conn) { + if (!conn->ssl || conn->ssl_handshake_done) return 1; + + int ret = SSL_do_handshake(conn->ssl); + if (ret == 1) { + conn->ssl_handshake_done = 1; + log_debug("SSL handshake completed for fd %d", conn->fd); + return 1; + } + + int ssl_error = SSL_get_error(conn->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) { + return 0; + } + + log_debug("SSL handshake failed for fd %d: %d", conn->fd, ssl_error); + return -1; +} + +int ssl_read(connection_t *conn, char *buf, size_t len) { + if (!conn->ssl || !conn->ssl_handshake_done) return -1; + + int bytes_read = SSL_read(conn->ssl, buf, len); + if (bytes_read <= 0) { + int ssl_error = SSL_get_error(conn->ssl, bytes_read); + if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) { + return 0; + } + return -1; + } + return bytes_read; +} + +int ssl_write(connection_t *conn, const char *buf, size_t len) { + if (!conn->ssl || !conn->ssl_handshake_done) return -1; + + int written = SSL_write(conn->ssl, buf, len); + if (written <= 0) { + int ssl_error = SSL_get_error(conn->ssl, written); + if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) { + return 0; + } + return -1; + } + return written; +} diff --git a/src/ssl_handler.h b/src/ssl_handler.h new file mode 100644 index 0000000..1129b8c --- /dev/null +++ b/src/ssl_handler.h @@ -0,0 +1,14 @@ +#ifndef RPROXY_SSL_HANDLER_H +#define RPROXY_SSL_HANDLER_H + +#include "types.h" + +extern SSL_CTX *ssl_ctx; + +void ssl_init(void); +void ssl_cleanup(void); +int ssl_do_handshake(connection_t *conn); +int ssl_read(connection_t *conn, char *buf, size_t len); +int ssl_write(connection_t *conn, const char *buf, size_t len); + +#endif diff --git a/src/types.h b/src/types.h new file mode 100644 index 0000000..cdb75ea --- /dev/null +++ b/src/types.h @@ -0,0 +1,172 @@ +#ifndef RPROXY_TYPES_H +#define RPROXY_TYPES_H + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif +#include +#include +#include +#include +#include + +#define MAX_EVENTS 4096 +#define MAX_FDS 65536 +#define CHUNK_SIZE 65536 +#define HISTORY_SECONDS 300 +#define MAX_HEADER_SIZE 8192 +#define MAX_REQUEST_LINE_SIZE 4096 +#define MAX_URI_SIZE 2048 +#define CONNECTION_TIMEOUT 300 + +typedef enum { + CONN_TYPE_UNUSED, + CONN_TYPE_LISTENER, + CONN_TYPE_CLIENT, + CONN_TYPE_UPSTREAM +} conn_type_t; + +typedef enum { + CLIENT_STATE_READING_HEADERS, + CLIENT_STATE_FORWARDING, + CLIENT_STATE_SERVING_INTERNAL, + CLIENT_STATE_ERROR, + CLIENT_STATE_CLOSING +} client_state_t; + +typedef struct { + char *data; + size_t capacity; + size_t head; + size_t tail; +} buffer_t; + +typedef struct { + char method[32]; + char uri[MAX_URI_SIZE]; + char version[16]; + char host[256]; + long content_length; + int is_websocket; + int keep_alive; + int connection_close; + bool is_chunked; +} http_request_t; + +struct connection_s; +struct vhost_stats_s; + +typedef struct connection_s { + conn_type_t type; + client_state_t state; + int fd; + struct connection_s *pair; + struct vhost_stats_s *vhost_stats; + buffer_t read_buf; + buffer_t write_buf; + SSL *ssl; + int ssl_handshake_done; + http_request_t request; + double request_start_time; + time_t last_activity; + int half_closed; + int write_shutdown; +} connection_t; + +typedef struct { + char hostname[256]; + char upstream_host[256]; + int upstream_port; + int use_ssl; + int rewrite_host; +} route_config_t; + +typedef struct { + int port; + route_config_t *routes; + int route_count; +} app_config_t; + +typedef struct { + double time; + double value; +} history_point_t; + +typedef struct { + history_point_t *points; + int capacity; + int head; + int count; +} history_deque_t; + +typedef struct { + double time; + double rx_kbps; + double tx_kbps; +} network_history_point_t; + +typedef struct { + network_history_point_t *points; + int capacity; + int head; + int count; +} network_history_deque_t; + +typedef struct { + double time; + double read_mbps; + double write_mbps; +} disk_history_point_t; + +typedef struct { + disk_history_point_t *points; + int capacity; + int head; + int count; +} disk_history_deque_t; + +typedef struct { + double *times; + int capacity; + int head; + int count; +} request_time_deque_t; + +typedef struct vhost_stats_s { + char vhost_name[256]; + long long http_requests; + long long websocket_requests; + long long total_requests; + long long bytes_sent; + long long bytes_recv; + double avg_request_time_ms; + long long last_bytes_sent; + long long last_bytes_recv; + double last_update; + history_deque_t throughput_history; + request_time_deque_t request_times; + struct vhost_stats_s *next; +} vhost_stats_t; + +typedef struct { + time_t start_time; + int active_connections; + history_deque_t cpu_history; + history_deque_t memory_history; + network_history_deque_t network_history; + disk_history_deque_t disk_history; + history_deque_t throughput_history; + history_deque_t load1_history; + history_deque_t load5_history; + history_deque_t load15_history; + long long last_net_sent; + long long last_net_recv; + long long last_disk_read; + long long last_disk_write; + double last_net_update_time; + double last_disk_update_time; + vhost_stats_t *vhost_stats_head; + sqlite3 *db; +} system_monitor_t; + +#endif diff --git a/tests/test_buffer.c b/tests/test_buffer.c new file mode 100644 index 0000000..b3d93f3 --- /dev/null +++ b/tests/test_buffer.c @@ -0,0 +1,159 @@ +#include "test_framework.h" +#include "../src/types.h" +#include "../src/buffer.h" + +void test_buffer_init(void) { + TEST_SUITE_BEGIN("Buffer Initialization"); + + buffer_t buf; + int result = buffer_init(&buf, 1024); + + TEST_ASSERT_EQ(0, result, "Buffer init returns 0"); + TEST_ASSERT(buf.data != NULL, "Buffer data is not NULL"); + TEST_ASSERT_EQ(1024, buf.capacity, "Buffer capacity is 1024"); + TEST_ASSERT_EQ(0, buf.head, "Buffer head is 0"); + TEST_ASSERT_EQ(0, buf.tail, "Buffer tail is 0"); + + buffer_free(&buf); + TEST_ASSERT(buf.data == NULL, "Buffer data is NULL after free"); + + TEST_SUITE_END(); +} + +void test_buffer_read_write(void) { + TEST_SUITE_BEGIN("Buffer Read/Write Operations"); + + buffer_t buf; + buffer_init(&buf, 1024); + + TEST_ASSERT_EQ(1024, buffer_available_write(&buf), "Initial write capacity is 1024"); + TEST_ASSERT_EQ(0, buffer_available_read(&buf), "Initial read capacity is 0"); + + const char *test_data = "Hello, World!"; + size_t len = strlen(test_data); + memcpy(buf.data + buf.tail, test_data, len); + buf.tail += len; + + TEST_ASSERT_EQ(len, buffer_available_read(&buf), "Read capacity equals written data"); + TEST_ASSERT_EQ(1024 - len, buffer_available_write(&buf), "Write capacity reduced"); + + buffer_free(&buf); + + TEST_SUITE_END(); +} + +void test_buffer_consume(void) { + TEST_SUITE_BEGIN("Buffer Consume Operations"); + + buffer_t buf; + buffer_init(&buf, 1024); + + const char *test_data = "0123456789"; + size_t len = strlen(test_data); + memcpy(buf.data + buf.tail, test_data, len); + buf.tail += len; + + buffer_consume(&buf, 5); + TEST_ASSERT_EQ(5, buffer_available_read(&buf), "5 bytes remaining after consume"); + TEST_ASSERT(memcmp(buf.data + buf.head, "56789", 5) == 0, "Correct data after consume"); + + buffer_consume(&buf, 5); + TEST_ASSERT_EQ(0, buffer_available_read(&buf), "Buffer empty after full consume"); + TEST_ASSERT_EQ(0, buf.head, "Head reset to 0"); + TEST_ASSERT_EQ(0, buf.tail, "Tail reset to 0"); + + buffer_free(&buf); + + TEST_SUITE_END(); +} + +void test_buffer_compact(void) { + TEST_SUITE_BEGIN("Buffer Compact Operations"); + + buffer_t buf; + buffer_init(&buf, 1024); + + const char *test_data = "ABCDEFGHIJ"; + size_t len = strlen(test_data); + memcpy(buf.data + buf.tail, test_data, len); + buf.tail += len; + + buffer_consume(&buf, 5); + TEST_ASSERT_EQ(5, buf.head, "Head moved to 5"); + + buffer_compact(&buf); + TEST_ASSERT_EQ(0, buf.head, "Head is 0 after compact"); + TEST_ASSERT_EQ(5, buf.tail, "Tail is 5 after compact"); + TEST_ASSERT(memcmp(buf.data, "FGHIJ", 5) == 0, "Data moved to beginning"); + + buffer_free(&buf); + + TEST_SUITE_END(); +} + +void test_buffer_ensure_capacity(void) { + TEST_SUITE_BEGIN("Buffer Ensure Capacity"); + + buffer_t buf; + buffer_init(&buf, 64); + + int result = buffer_ensure_capacity(&buf, 128); + TEST_ASSERT_EQ(0, result, "Capacity increase successful"); + TEST_ASSERT(buf.capacity >= 128, "Capacity at least 128"); + + result = buffer_ensure_capacity(&buf, 512); + TEST_ASSERT_EQ(0, result, "Second capacity increase successful"); + TEST_ASSERT(buf.capacity >= 512, "Capacity at least 512"); + + result = buffer_ensure_capacity(&buf, 64); + TEST_ASSERT_EQ(0, result, "No change when capacity already sufficient"); + + buffer_free(&buf); + + TEST_SUITE_END(); +} + +void test_buffer_multiple_operations(void) { + TEST_SUITE_BEGIN("Buffer Multiple Operations"); + + buffer_t buf; + buffer_init(&buf, 128); + + for (int i = 0; i < 10; i++) { + char data[10]; + snprintf(data, sizeof(data), "MSG%d", i); + size_t len = strlen(data); + + if (buffer_available_write(&buf) < len) { + buffer_compact(&buf); + } + + memcpy(buf.data + buf.tail, data, len); + buf.tail += len; + } + + TEST_ASSERT(buffer_available_read(&buf) > 0, "Buffer has data after multiple writes"); + + size_t total_read = 0; + while (buffer_available_read(&buf) > 0) { + size_t to_consume = buffer_available_read(&buf) > 5 ? 5 : buffer_available_read(&buf); + buffer_consume(&buf, to_consume); + total_read += to_consume; + } + + TEST_ASSERT(total_read > 0, "All data was consumed"); + TEST_ASSERT_EQ(0, buffer_available_read(&buf), "Buffer is empty"); + + buffer_free(&buf); + + TEST_SUITE_END(); +} + +void run_buffer_tests(void) { + test_buffer_init(); + test_buffer_read_write(); + test_buffer_consume(); + test_buffer_compact(); + test_buffer_ensure_capacity(); + test_buffer_multiple_operations(); +} diff --git a/tests/test_config.c b/tests/test_config.c new file mode 100644 index 0000000..d663848 --- /dev/null +++ b/tests/test_config.c @@ -0,0 +1,250 @@ +#include "test_framework.h" +#include "../src/types.h" +#include "../src/config.h" +#include +#include + +static const char *TEST_CONFIG_FILE = "/tmp/test_proxy_config.json"; + +static void create_test_config(const char *content) { + FILE *f = fopen(TEST_CONFIG_FILE, "w"); + if (f) { + fprintf(f, "%s", content); + fclose(f); + } +} + +static void cleanup_test_config(void) { + unlink(TEST_CONFIG_FILE); +} + +void test_config_load_valid(void) { + TEST_SUITE_BEGIN("Config Load Valid Configuration"); + + const char *valid_config = + "{\n" + " \"port\": 9090,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"test.example.com\",\n" + " \"upstream_host\": \"127.0.0.1\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": true\n" + " },\n" + " {\n" + " \"hostname\": \"api.example.com\",\n" + " \"upstream_host\": \"192.168.1.100\",\n" + " \"upstream_port\": 443,\n" + " \"use_ssl\": true,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(valid_config); + + int result = config_load(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(1, result, "Config loaded successfully"); + TEST_ASSERT_EQ(9090, config.port, "Port is 9090"); + TEST_ASSERT_EQ(2, config.route_count, "Two routes configured"); + + route_config_t *route1 = config_find_route("test.example.com"); + TEST_ASSERT(route1 != NULL, "Route for test.example.com found"); + if (route1) { + TEST_ASSERT_STR_EQ("127.0.0.1", route1->upstream_host, "First route upstream host"); + TEST_ASSERT_EQ(3000, route1->upstream_port, "First route upstream port"); + TEST_ASSERT_EQ(0, route1->use_ssl, "First route SSL disabled"); + TEST_ASSERT_EQ(1, route1->rewrite_host, "First route host rewrite enabled"); + } + + route_config_t *route2 = config_find_route("api.example.com"); + TEST_ASSERT(route2 != NULL, "Route for api.example.com found"); + if (route2) { + TEST_ASSERT_STR_EQ("192.168.1.100", route2->upstream_host, "Second route upstream host"); + TEST_ASSERT_EQ(443, route2->upstream_port, "Second route upstream port"); + TEST_ASSERT_EQ(1, route2->use_ssl, "Second route SSL enabled"); + TEST_ASSERT_EQ(0, route2->rewrite_host, "Second route host rewrite disabled"); + } + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_find_route_case_insensitive(void) { + TEST_SUITE_BEGIN("Config Find Route Case Insensitive"); + + const char *config_content = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"Test.Example.COM\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_content); + config_load(TEST_CONFIG_FILE); + + route_config_t *route1 = config_find_route("test.example.com"); + TEST_ASSERT(route1 != NULL, "Lowercase hostname matches"); + + route_config_t *route2 = config_find_route("TEST.EXAMPLE.COM"); + TEST_ASSERT(route2 != NULL, "Uppercase hostname matches"); + + route_config_t *route3 = config_find_route("TeSt.ExAmPlE.cOm"); + TEST_ASSERT(route3 != NULL, "Mixed case hostname matches"); + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_find_route_nonexistent(void) { + TEST_SUITE_BEGIN("Config Find Nonexistent Route"); + + const char *config_content = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"existing.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_content); + config_load(TEST_CONFIG_FILE); + + route_config_t *route = config_find_route("nonexistent.com"); + TEST_ASSERT(route == NULL, "Nonexistent route returns NULL"); + + route = config_find_route(NULL); + TEST_ASSERT(route == NULL, "NULL hostname returns NULL"); + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_default_port(void) { + TEST_SUITE_BEGIN("Config Default Port"); + + const char *config_content = + "{\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"test.com\",\n" + " \"upstream_host\": \"localhost\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_content); + config_load(TEST_CONFIG_FILE); + + TEST_ASSERT_EQ(8080, config.port, "Default port is 8080 when not specified"); + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_invalid_json(void) { + TEST_SUITE_BEGIN("Config Invalid JSON"); + + const char *invalid_config = "{ invalid json }"; + + create_test_config(invalid_config); + + int result = config_load(TEST_CONFIG_FILE); + TEST_ASSERT_EQ(0, result, "Invalid JSON returns 0"); + + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void test_config_missing_file(void) { + TEST_SUITE_BEGIN("Config Missing File"); + + unlink("/tmp/nonexistent_config.json"); + int result = config_load("/tmp/nonexistent_config.json"); + TEST_ASSERT_EQ(0, result, "Missing file returns 0"); + + TEST_SUITE_END(); +} + +void test_config_ssl_rewrite_host_options(void) { + TEST_SUITE_BEGIN("Config SSL and Rewrite Host Options"); + + const char *config_content = + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"https-rewrite.com\",\n" + " \"upstream_host\": \"secure.example.com\",\n" + " \"upstream_port\": 443,\n" + " \"use_ssl\": true,\n" + " \"rewrite_host\": true\n" + " },\n" + " {\n" + " \"hostname\": \"http-norewrite.com\",\n" + " \"upstream_host\": \"plain.example.com\",\n" + " \"upstream_port\": 80,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"; + + create_test_config(config_content); + config_load(TEST_CONFIG_FILE); + + route_config_t *ssl_route = config_find_route("https-rewrite.com"); + TEST_ASSERT(ssl_route != NULL, "SSL route found"); + if (ssl_route) { + TEST_ASSERT_EQ(1, ssl_route->use_ssl, "SSL enabled for https route"); + TEST_ASSERT_EQ(1, ssl_route->rewrite_host, "Host rewrite enabled for https route"); + } + + route_config_t *plain_route = config_find_route("http-norewrite.com"); + TEST_ASSERT(plain_route != NULL, "Plain route found"); + if (plain_route) { + TEST_ASSERT_EQ(0, plain_route->use_ssl, "SSL disabled for http route"); + TEST_ASSERT_EQ(0, plain_route->rewrite_host, "Host rewrite disabled for http route"); + } + + config_free(); + cleanup_test_config(); + + TEST_SUITE_END(); +} + +void run_config_tests(void) { + test_config_load_valid(); + test_config_find_route_case_insensitive(); + test_config_find_route_nonexistent(); + test_config_default_port(); + test_config_invalid_json(); + test_config_missing_file(); + test_config_ssl_rewrite_host_options(); +} diff --git a/tests/test_framework.h b/tests/test_framework.h new file mode 100644 index 0000000..df2b5c9 --- /dev/null +++ b/tests/test_framework.h @@ -0,0 +1,55 @@ +#ifndef TEST_FRAMEWORK_H +#define TEST_FRAMEWORK_H + +#include +#include +#include + +extern int tests_run; +extern int tests_passed; +extern int tests_failed; + +#define TEST_ASSERT(condition, message) do { \ + tests_run++; \ + if (condition) { \ + tests_passed++; \ + printf(" [PASS] %s\n", message); \ + } else { \ + tests_failed++; \ + printf(" [FAIL] %s (line %d)\n", message, __LINE__); \ + } \ +} while(0) + +#define TEST_ASSERT_EQ(expected, actual, message) do { \ + tests_run++; \ + if ((expected) == (actual)) { \ + tests_passed++; \ + printf(" [PASS] %s\n", message); \ + } else { \ + tests_failed++; \ + printf(" [FAIL] %s (expected %d, got %d, line %d)\n", message, (int)(expected), (int)(actual), __LINE__); \ + } \ +} while(0) + +#define TEST_ASSERT_STR_EQ(expected, actual, message) do { \ + tests_run++; \ + if (strcmp(expected, actual) == 0) { \ + tests_passed++; \ + printf(" [PASS] %s\n", message); \ + } else { \ + tests_failed++; \ + printf(" [FAIL] %s (expected '%s', got '%s', line %d)\n", message, expected, actual, __LINE__); \ + } \ +} while(0) + +#define TEST_SUITE_BEGIN(name) do { \ + printf("\n=== Test Suite: %s ===\n", name); \ +} while(0) + +#define TEST_SUITE_END() do { \ + printf("\n"); \ +} while(0) + +void test_summary(void); + +#endif diff --git a/tests/test_http.c b/tests/test_http.c new file mode 100644 index 0000000..2cee81e --- /dev/null +++ b/tests/test_http.c @@ -0,0 +1,196 @@ +#include "test_framework.h" +#include "../src/types.h" +#include "../src/http.h" + +void test_http_parse_get_request(void) { + TEST_SUITE_BEGIN("HTTP GET Request Parsing"); + + http_request_t req; + const char *request = "GET /path/to/resource HTTP/1.1\r\nHost: example.com\r\nConnection: keep-alive\r\n\r\n"; + int result = http_parse_request(request, strlen(request), &req); + + TEST_ASSERT_EQ(1, result, "Parse GET request"); + TEST_ASSERT_STR_EQ("GET", req.method, "Method is GET"); + TEST_ASSERT_STR_EQ("/path/to/resource", req.uri, "URI parsed correctly"); + TEST_ASSERT_STR_EQ("HTTP/1.1", req.version, "Version is HTTP/1.1"); + TEST_ASSERT_STR_EQ("example.com", req.host, "Host header parsed"); + TEST_ASSERT_EQ(1, req.keep_alive, "Keep-alive is enabled"); + + TEST_SUITE_END(); +} + +void test_http_parse_post_request(void) { + TEST_SUITE_BEGIN("HTTP POST Request Parsing"); + + http_request_t req; + const char *request = "POST /api/data HTTP/1.1\r\nHost: api.example.com\r\nContent-Length: 100\r\nConnection: close\r\n\r\n"; + int result = http_parse_request(request, strlen(request), &req); + + TEST_ASSERT_EQ(1, result, "Parse POST request"); + TEST_ASSERT_STR_EQ("POST", req.method, "Method is POST"); + TEST_ASSERT_STR_EQ("/api/data", req.uri, "URI parsed correctly"); + TEST_ASSERT_EQ(100, req.content_length, "Content-Length parsed"); + TEST_ASSERT_EQ(0, req.keep_alive, "Keep-alive is disabled"); + TEST_ASSERT_EQ(1, req.connection_close, "Connection close flag set"); + + TEST_SUITE_END(); +} + +void test_http_parse_webdav_methods(void) { + TEST_SUITE_BEGIN("HTTP WebDAV Methods Parsing"); + + http_request_t req; + + const char *propfind = "PROPFIND /folder HTTP/1.1\r\nHost: webdav.example.com\r\n\r\n"; + int result = http_parse_request(propfind, strlen(propfind), &req); + TEST_ASSERT_EQ(1, result, "Parse PROPFIND request"); + TEST_ASSERT_STR_EQ("PROPFIND", req.method, "Method is PROPFIND"); + + const char *mkcol = "MKCOL /newfolder HTTP/1.1\r\nHost: webdav.example.com\r\n\r\n"; + result = http_parse_request(mkcol, strlen(mkcol), &req); + TEST_ASSERT_EQ(1, result, "Parse MKCOL request"); + TEST_ASSERT_STR_EQ("MKCOL", req.method, "Method is MKCOL"); + + const char *move = "MOVE /source HTTP/1.1\r\nHost: webdav.example.com\r\n\r\n"; + result = http_parse_request(move, strlen(move), &req); + TEST_ASSERT_EQ(1, result, "Parse MOVE request"); + TEST_ASSERT_STR_EQ("MOVE", req.method, "Method is MOVE"); + + const char *copy = "COPY /source HTTP/1.1\r\nHost: webdav.example.com\r\n\r\n"; + result = http_parse_request(copy, strlen(copy), &req); + TEST_ASSERT_EQ(1, result, "Parse COPY request"); + TEST_ASSERT_STR_EQ("COPY", req.method, "Method is COPY"); + + const char *lock = "LOCK /resource HTTP/1.1\r\nHost: webdav.example.com\r\n\r\n"; + result = http_parse_request(lock, strlen(lock), &req); + TEST_ASSERT_EQ(1, result, "Parse LOCK request"); + TEST_ASSERT_STR_EQ("LOCK", req.method, "Method is LOCK"); + + const char *unlock = "UNLOCK /resource HTTP/1.1\r\nHost: webdav.example.com\r\n\r\n"; + result = http_parse_request(unlock, strlen(unlock), &req); + TEST_ASSERT_EQ(1, result, "Parse UNLOCK request"); + TEST_ASSERT_STR_EQ("UNLOCK", req.method, "Method is UNLOCK"); + + TEST_SUITE_END(); +} + +void test_http_parse_custom_methods(void) { + TEST_SUITE_BEGIN("HTTP Custom Methods Parsing"); + + http_request_t req; + + const char *custom1 = "MYMETHOD /path HTTP/1.1\r\nHost: example.com\r\n\r\n"; + int result = http_parse_request(custom1, strlen(custom1), &req); + TEST_ASSERT_EQ(1, result, "Parse custom MYMETHOD request"); + TEST_ASSERT_STR_EQ("MYMETHOD", req.method, "Method is MYMETHOD"); + + const char *custom2 = "FOOBAR /path HTTP/1.1\r\nHost: example.com\r\n\r\n"; + result = http_parse_request(custom2, strlen(custom2), &req); + TEST_ASSERT_EQ(1, result, "Parse custom FOOBAR request"); + TEST_ASSERT_STR_EQ("FOOBAR", req.method, "Method is FOOBAR"); + + TEST_SUITE_END(); +} + +void test_http_parse_websocket_upgrade(void) { + TEST_SUITE_BEGIN("HTTP WebSocket Upgrade Parsing"); + + http_request_t req; + const char *request = "GET /ws HTTP/1.1\r\n" + "Host: example.com\r\n" + "Upgrade: websocket\r\n" + "Connection: upgrade\r\n" + "\r\n"; + int result = http_parse_request(request, strlen(request), &req); + + TEST_ASSERT_EQ(1, result, "Parse WebSocket upgrade request"); + TEST_ASSERT_STR_EQ("GET", req.method, "Method is GET"); + TEST_ASSERT_EQ(1, req.is_websocket, "WebSocket flag is set"); + + TEST_SUITE_END(); +} + +void test_http_parse_chunked_encoding(void) { + TEST_SUITE_BEGIN("HTTP Chunked Encoding Parsing"); + + http_request_t req; + const char *request = "POST /api/upload HTTP/1.1\r\n" + "Host: example.com\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n"; + int result = http_parse_request(request, strlen(request), &req); + + TEST_ASSERT_EQ(1, result, "Parse chunked request"); + TEST_ASSERT_EQ(1, req.is_chunked, "Chunked flag is set"); + + TEST_SUITE_END(); +} + +void test_http_parse_http10(void) { + TEST_SUITE_BEGIN("HTTP/1.0 Request Parsing"); + + http_request_t req; + const char *request = "GET /path HTTP/1.0\r\nHost: example.com\r\n\r\n"; + int result = http_parse_request(request, strlen(request), &req); + + TEST_ASSERT_EQ(1, result, "Parse HTTP/1.0 request"); + TEST_ASSERT_STR_EQ("HTTP/1.0", req.version, "Version is HTTP/1.0"); + TEST_ASSERT_EQ(0, req.keep_alive, "Keep-alive is disabled by default for HTTP/1.0"); + + TEST_SUITE_END(); +} + +void test_http_parse_host_with_port(void) { + TEST_SUITE_BEGIN("HTTP Host with Port Parsing"); + + http_request_t req; + const char *request = "GET / HTTP/1.1\r\nHost: example.com:8080\r\n\r\n"; + int result = http_parse_request(request, strlen(request), &req); + + TEST_ASSERT_EQ(1, result, "Parse request with host:port"); + TEST_ASSERT_STR_EQ("example.com", req.host, "Port stripped from host"); + + TEST_SUITE_END(); +} + +void test_http_is_request_start(void) { + TEST_SUITE_BEGIN("HTTP Request Start Detection"); + + TEST_ASSERT_EQ(1, http_is_request_start("GET / HTTP/1.1", 14), "GET is valid request start"); + TEST_ASSERT_EQ(1, http_is_request_start("POST /api HTTP/1.1", 18), "POST is valid request start"); + TEST_ASSERT_EQ(1, http_is_request_start("PROPFIND / HTTP/1.1", 19), "PROPFIND is valid request start"); + TEST_ASSERT_EQ(1, http_is_request_start("CUSTOMMETHOD / HTTP/1.1", 23), "Custom method is valid request start"); + TEST_ASSERT_EQ(0, http_is_request_start("123 invalid", 11), "Numbers at start are invalid"); + TEST_ASSERT_EQ(0, http_is_request_start("abc", 3), "Too short is invalid"); + + TEST_SUITE_END(); +} + +void test_http_malformed_requests(void) { + TEST_SUITE_BEGIN("HTTP Malformed Request Handling"); + + http_request_t req; + + const char *no_space = "GETHTTP/1.1\r\nHost: example.com\r\n\r\n"; + int result = http_parse_request(no_space, strlen(no_space), &req); + TEST_ASSERT_EQ(0, result, "Reject request without space after method"); + + const char *no_version = "GET /path\r\nHost: example.com\r\n\r\n"; + result = http_parse_request(no_version, strlen(no_version), &req); + TEST_ASSERT_EQ(0, result, "Reject request without version"); + + TEST_SUITE_END(); +} + +void run_http_tests(void) { + test_http_parse_get_request(); + test_http_parse_post_request(); + test_http_parse_webdav_methods(); + test_http_parse_custom_methods(); + test_http_parse_websocket_upgrade(); + test_http_parse_chunked_encoding(); + test_http_parse_http10(); + test_http_parse_host_with_port(); + test_http_is_request_start(); + test_http_malformed_requests(); +} diff --git a/tests/test_main.c b/tests/test_main.c new file mode 100644 index 0000000..6832b5b --- /dev/null +++ b/tests/test_main.c @@ -0,0 +1,41 @@ +#include "test_framework.h" +#include + +int tests_run = 0; +int tests_passed = 0; +int tests_failed = 0; + +void test_summary(void) { + printf("\n=========================================\n"); + printf("Test Results: %d/%d passed\n", tests_passed, tests_run); + if (tests_failed > 0) { + printf("FAILED: %d tests failed\n", tests_failed); + } else { + printf("SUCCESS: All tests passed\n"); + } + printf("=========================================\n"); +} + +extern void run_http_tests(void); +extern void run_buffer_tests(void); +extern void run_config_tests(void); +extern void run_routing_tests(void); + +int main(int argc, char *argv[]) { + (void)argc; + (void)argv; + + printf("\n"); + printf("=========================================\n"); + printf(" RProxy Enterprise Unit Tests\n"); + printf("=========================================\n"); + + run_buffer_tests(); + run_http_tests(); + run_config_tests(); + run_routing_tests(); + + test_summary(); + + return tests_failed > 0 ? 1 : 0; +} diff --git a/tests/test_routing.c b/tests/test_routing.c new file mode 100644 index 0000000..e1d05e7 --- /dev/null +++ b/tests/test_routing.c @@ -0,0 +1,258 @@ +#include "test_framework.h" +#include "../src/types.h" +#include "../src/http.h" +#include "../src/config.h" +#include +#include + +static const char *TEST_CONFIG_FILE = "/tmp/test_routing_config.json"; + +static void create_routing_config(void) { + FILE *f = fopen(TEST_CONFIG_FILE, "w"); + if (f) { + fprintf(f, + "{\n" + " \"port\": 8080,\n" + " \"reverse_proxy\": [\n" + " {\n" + " \"hostname\": \"api.local\",\n" + " \"upstream_host\": \"backend-api\",\n" + " \"upstream_port\": 3000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": true\n" + " },\n" + " {\n" + " \"hostname\": \"secure.local\",\n" + " \"upstream_host\": \"backend-secure\",\n" + " \"upstream_port\": 443,\n" + " \"use_ssl\": true,\n" + " \"rewrite_host\": true\n" + " },\n" + " {\n" + " \"hostname\": \"passthrough.local\",\n" + " \"upstream_host\": \"backend-pass\",\n" + " \"upstream_port\": 8000,\n" + " \"use_ssl\": false,\n" + " \"rewrite_host\": false\n" + " }\n" + " ]\n" + "}\n"); + fclose(f); + } +} + +static void cleanup_routing_config(void) { + unlink(TEST_CONFIG_FILE); +} + +void test_routing_host_rewrite(void) { + TEST_SUITE_BEGIN("Routing Host Rewrite Logic"); + + create_routing_config(); + config_load(TEST_CONFIG_FILE); + + route_config_t *route = config_find_route("api.local"); + TEST_ASSERT(route != NULL, "Route for api.local found"); + if (route) { + TEST_ASSERT_EQ(1, route->rewrite_host, "Host rewrite enabled for api.local"); + TEST_ASSERT_STR_EQ("backend-api", route->upstream_host, "Upstream host is backend-api"); + TEST_ASSERT_EQ(3000, route->upstream_port, "Upstream port is 3000"); + } + + route = config_find_route("passthrough.local"); + TEST_ASSERT(route != NULL, "Route for passthrough.local found"); + if (route) { + TEST_ASSERT_EQ(0, route->rewrite_host, "Host rewrite disabled for passthrough.local"); + } + + config_free(); + cleanup_routing_config(); + + TEST_SUITE_END(); +} + +void test_routing_ssl_upstream(void) { + TEST_SUITE_BEGIN("Routing SSL Upstream"); + + create_routing_config(); + config_load(TEST_CONFIG_FILE); + + route_config_t *route = config_find_route("secure.local"); + TEST_ASSERT(route != NULL, "Route for secure.local found"); + if (route) { + TEST_ASSERT_EQ(1, route->use_ssl, "SSL enabled for secure.local"); + TEST_ASSERT_EQ(443, route->upstream_port, "SSL port is 443"); + } + + route = config_find_route("api.local"); + TEST_ASSERT(route != NULL, "Route for api.local found"); + if (route) { + TEST_ASSERT_EQ(0, route->use_ssl, "SSL disabled for api.local"); + } + + config_free(); + cleanup_routing_config(); + + TEST_SUITE_END(); +} + +void test_routing_dashboard_detection(void) { + TEST_SUITE_BEGIN("Routing Dashboard Detection"); + + http_request_t req; + + const char *dashboard_req = "GET /rproxy/dashboard HTTP/1.1\r\nHost: anyhost.com\r\n\r\n"; + http_parse_request(dashboard_req, strlen(dashboard_req), &req); + int is_dashboard = (strncmp(req.uri, "/rproxy/dashboard", 17) == 0); + TEST_ASSERT_EQ(1, is_dashboard, "Dashboard URI detected"); + + const char *api_req = "GET /rproxy/api/stats HTTP/1.1\r\nHost: anyhost.com\r\n\r\n"; + http_parse_request(api_req, strlen(api_req), &req); + int is_api = (strncmp(req.uri, "/rproxy/api/stats", 17) == 0); + TEST_ASSERT_EQ(1, is_api, "API stats URI detected"); + + const char *regular_req = "GET /some/other/path HTTP/1.1\r\nHost: api.local\r\n\r\n"; + http_parse_request(regular_req, strlen(regular_req), &req); + int is_regular = (strncmp(req.uri, "/rproxy/", 8) != 0); + TEST_ASSERT_EQ(1, is_regular, "Regular path not matched as internal"); + + TEST_SUITE_END(); +} + +void test_routing_keep_alive_handling(void) { + TEST_SUITE_BEGIN("Routing Keep-Alive Handling"); + + http_request_t req; + + const char *ka_req = "GET /api HTTP/1.1\r\nHost: api.local\r\nConnection: keep-alive\r\n\r\n"; + http_parse_request(ka_req, strlen(ka_req), &req); + TEST_ASSERT_EQ(1, req.keep_alive, "Keep-alive flag set when Connection: keep-alive"); + + const char *close_req = "GET /api HTTP/1.1\r\nHost: api.local\r\nConnection: close\r\n\r\n"; + http_parse_request(close_req, strlen(close_req), &req); + TEST_ASSERT_EQ(0, req.keep_alive, "Keep-alive flag not set when Connection: close"); + + const char *default_req = "GET /api HTTP/1.1\r\nHost: api.local\r\n\r\n"; + http_parse_request(default_req, strlen(default_req), &req); + TEST_ASSERT_EQ(1, req.keep_alive, "Keep-alive default for HTTP/1.1"); + + const char *http10_req = "GET /api HTTP/1.0\r\nHost: api.local\r\n\r\n"; + http_parse_request(http10_req, strlen(http10_req), &req); + TEST_ASSERT_EQ(0, req.keep_alive, "Keep-alive default off for HTTP/1.0"); + + TEST_SUITE_END(); +} + +void test_routing_all_methods_accepted(void) { + TEST_SUITE_BEGIN("Routing All HTTP Methods Accepted"); + + http_request_t req; + int result; + + const char *methods[] = { + "GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH", "TRACE", "CONNECT", + "PROPFIND", "PROPPATCH", "MKCOL", "MOVE", "COPY", "LOCK", "UNLOCK", + "SEARCH", "REPORT", "MKACTIVITY", "CHECKOUT", "MERGE", + "NOTIFY", "SUBSCRIBE", "UNSUBSCRIBE", + "CUSTOMMETHOD", "FOOBAR", "MYPROTO" + }; + + int num_methods = sizeof(methods) / sizeof(methods[0]); + + for (int i = 0; i < num_methods; i++) { + char request[256]; + snprintf(request, sizeof(request), "%s /path HTTP/1.1\r\nHost: test.com\r\n\r\n", methods[i]); + result = http_parse_request(request, strlen(request), &req); + + char msg[128]; + snprintf(msg, sizeof(msg), "Method %s accepted", methods[i]); + TEST_ASSERT_EQ(1, result, msg); + TEST_ASSERT_STR_EQ(methods[i], req.method, msg); + } + + TEST_SUITE_END(); +} + +void test_routing_pipelined_requests(void) { + TEST_SUITE_BEGIN("Routing Pipelined Request Detection"); + + http_request_t req1, req2; + + const char *first_req = "GET /first HTTP/1.1\r\nHost: api.local\r\nConnection: keep-alive\r\n\r\n"; + int result1 = http_parse_request(first_req, strlen(first_req), &req1); + TEST_ASSERT_EQ(1, result1, "First request parsed"); + TEST_ASSERT_STR_EQ("/first", req1.uri, "First request URI"); + TEST_ASSERT_EQ(1, req1.keep_alive, "First request keep-alive"); + + const char *second_req = "GET /second HTTP/1.1\r\nHost: api.local\r\n\r\n"; + int result2 = http_parse_request(second_req, strlen(second_req), &req2); + TEST_ASSERT_EQ(1, result2, "Second request parsed"); + TEST_ASSERT_STR_EQ("/second", req2.uri, "Second request URI"); + + int is_new_request = http_is_request_start(second_req, strlen(second_req)); + TEST_ASSERT_EQ(1, is_new_request, "New request detected for pipelining"); + + TEST_SUITE_END(); +} + +void test_routing_websocket_upgrade(void) { + TEST_SUITE_BEGIN("Routing WebSocket Upgrade"); + + http_request_t req; + + const char *ws_req = + "GET /ws/chat HTTP/1.1\r\n" + "Host: api.local\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + + int result = http_parse_request(ws_req, strlen(ws_req), &req); + TEST_ASSERT_EQ(1, result, "WebSocket upgrade request parsed"); + TEST_ASSERT_EQ(1, req.is_websocket, "WebSocket flag set"); + TEST_ASSERT_STR_EQ("/ws/chat", req.uri, "WebSocket URI parsed"); + + TEST_SUITE_END(); +} + +void test_routing_chunked_transfer(void) { + TEST_SUITE_BEGIN("Routing Chunked Transfer Encoding"); + + http_request_t req; + + const char *chunked_req = + "POST /upload HTTP/1.1\r\n" + "Host: api.local\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n"; + + int result = http_parse_request(chunked_req, strlen(chunked_req), &req); + TEST_ASSERT_EQ(1, result, "Chunked request parsed"); + TEST_ASSERT_EQ(1, req.is_chunked, "Chunked flag set"); + + const char *non_chunked_req = + "POST /upload HTTP/1.1\r\n" + "Host: api.local\r\n" + "Content-Length: 100\r\n" + "\r\n"; + + result = http_parse_request(non_chunked_req, strlen(non_chunked_req), &req); + TEST_ASSERT_EQ(1, result, "Non-chunked request parsed"); + TEST_ASSERT_EQ(0, req.is_chunked, "Chunked flag not set"); + TEST_ASSERT_EQ(100, req.content_length, "Content-Length parsed"); + + TEST_SUITE_END(); +} + +void run_routing_tests(void) { + test_routing_host_rewrite(); + test_routing_ssl_upstream(); + test_routing_dashboard_detection(); + test_routing_keep_alive_handling(); + test_routing_all_methods_accepted(); + test_routing_pipelined_requests(); + test_routing_websocket_upgrade(); + test_routing_chunked_transfer(); +}