diff --git a/rproxy.c b/rproxy.c index 55c8e73..dde70bb 100644 --- a/rproxy.c +++ b/rproxy.c @@ -6,11 +6,11 @@ * Description: High-performance Reverse Proxy with Real-time Monitoring Dashboard. * Single-file C implementation using epoll - PRODUCTION READY VERSION * - * Version: 1.5 (Compiler Warnings Fixed) + * Version: 2.1 (Forwarding and State Machine Corrected) * Created: 2024 * Compiler: gcc * - * Author: Fixed Version + * Author: Corrected Version * * To Compile: * gcc -Wall -Wextra -O2 -g -pthread cJSON.c proxy.c -o rproxy -lssl -lcrypto -lsqlite3 -lm @@ -64,13 +64,9 @@ typedef enum { } conn_type_t; typedef enum { - CLIENT_STATE_READING_REQUEST_LINE, // Initial state, reading headers - CLIENT_STATE_READING_HEADERS, // Legacy, handled by above - CLIENT_STATE_READING_BODY, // New state for handling request bodies - CLIENT_STATE_CONNECTING_UPSTREAM, - CLIENT_STATE_TUNNELING, // Full-duplex forwarding state - CLIENT_STATE_DONE, - CLIENT_STATE_ERROR + CLIENT_STATE_READING_HEADERS, // Reading the initial HTTP request + CLIENT_STATE_FORWARDING, // Actively forwarding data in both directions + CLIENT_STATE_ERROR // An error occurred, connection will be closed } client_state_t; struct connection_s; @@ -90,10 +86,12 @@ typedef struct { int route_count; } app_config_t; +// High-performance buffer using head/tail pointers to avoid memmove typedef struct { char *data; - size_t size; size_t capacity; + size_t head; // Read position + size_t tail; // Write position } buffer_t; typedef struct { @@ -101,10 +99,9 @@ typedef struct { char uri[MAX_URI_SIZE]; char version[16]; char host[256]; - int content_length; + long content_length; int is_websocket; int keep_alive; - int connection_upgrade; } http_request_t; typedef struct connection_s { @@ -118,11 +115,6 @@ typedef struct connection_s { SSL *ssl; int ssl_handshake_done; http_request_t request; - char *http_body_start; - long http_content_length; - long http_body_read; - int is_websocket; - int keep_alive; double request_start_time; time_t last_activity; } connection_t; @@ -202,20 +194,15 @@ static int g_debug_mode = 0; void log_error(const char *msg); void log_info(const char *format, ...); void log_debug(const char *format, ...); -int load_config(const char *filename); -void free_config(); -route_config_t *find_route(const char *hostname); void setup_listener_socket(int port); void accept_new_connection(int listener_fd); void close_connection(int fd); void handle_connection_event(struct epoll_event *event); -void connect_to_upstream(connection_t *client_conn, route_config_t *route); -void handle_client_data(connection_t *conn); -void handle_client_body(connection_t *conn); -int parse_http_request_line(const char *line, http_request_t *request); -int parse_http_headers(const char *headers, http_request_t *request); -int find_header_value(const char *headers, const char *header_name, char *value, size_t value_size); +void connect_to_upstream(connection_t *client_conn, const char *data, size_t data_len); +void handle_client_read(connection_t *conn); +int parse_http_request(const char *data, size_t len, http_request_t *req); void monitor_init(const char *db_file); +void free_config(); void monitor_update(); void monitor_cleanup(); void monitor_record_request_start(vhost_stats_t *stats, int is_websocket); @@ -234,7 +221,8 @@ static inline int buffer_init(buffer_t *buf, size_t capacity) { return -1; } buf->capacity = capacity; - buf->size = 0; + buf->head = 0; + buf->tail = 0; return 0; } @@ -244,7 +232,8 @@ static inline void buffer_free(buffer_t *buf) { buf->data = NULL; } buf->capacity = 0; - buf->size = 0; + buf->head = 0; + buf->tail = 0; } static inline int buffer_ensure_capacity(buffer_t *buf, size_t required) { @@ -265,36 +254,18 @@ static inline int buffer_ensure_capacity(buffer_t *buf, size_t required) { return 0; } +// Compacts the buffer by moving unread data to the beginning +static inline void buffer_compact(buffer_t *buf) { + if (buf->head == 0) return; + size_t len = buf->tail - buf->head; + if (len > 0) { + memmove(buf->data, buf->data + buf->head, len); + } + buf->head = 0; + buf->tail = len; +} + // --- String Utilities --- -static char* safe_strdup(const char *str, size_t max_len) { - if (!str) return NULL; - size_t len = strnlen(str, max_len); - char *dup = malloc(len + 1); - if (!dup) return NULL; - memcpy(dup, str, len); - dup[len] = '\0'; - return dup; -} - -static void trim_whitespace(char *str) { - if (!str) return; - - // Trim leading whitespace - char *start = str; - while (isspace((unsigned char)*start)) start++; - - if (start != str) { - memmove(str, start, strlen(start) + 1); - } - - // Trim trailing whitespace - char *end = str + strlen(str) - 1; - while (end >= str && isspace((unsigned char)*end)) { - *end = '\0'; - end--; - } -} - static int is_valid_http_method(const char *method) { const char *valid_methods[] = {"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH", "TRACE", "CONNECT", NULL}; @@ -304,13 +275,9 @@ static int is_valid_http_method(const char *method) { return 0; } -static int is_valid_http_version(const char *version) { - return (strcmp(version, "HTTP/1.0") == 0 || strcmp(version, "HTTP/1.1") == 0); -} - // ================================================================================================= // -// Logging Implementation +// Logging Implementation // // ================================================================================================= @@ -347,7 +314,7 @@ void log_debug(const char *format, ...) { // ================================================================================================= // -// Configuration Implementation +// Configuration Implementation // // ================================================================================================= @@ -449,14 +416,6 @@ int load_config(const char *filename) { return 1; } -void free_config() { - if (config.routes) { - free(config.routes); - config.routes = NULL; - } - config.route_count = 0; -} - route_config_t *find_route(const char *hostname) { if (!hostname) return NULL; for (int i = 0; i < config.route_count; i++) { @@ -469,7 +428,7 @@ route_config_t *find_route(const char *hostname) { // ================================================================================================= // -// Monitor Implementation +// Monitor Implementation // // ================================================================================================= static void history_deque_init(history_deque_t *dq, int capacity) { @@ -800,7 +759,7 @@ void monitor_record_request_end(vhost_stats_t *stats, double start_time) { // ================================================================================================= // -// Dashboard Implementation +// Dashboard Implementation // // ================================================================================================= const char *DASHBOARD_HTML = @@ -987,10 +946,10 @@ const char *DASHBOARD_HTML = " if (!canvas) return;\n" " if (!window.vhostCharts[chartId]) {\n" " const colors = [\n" -" { border: '#3498db', bg: 'rgba(52, 152, 219, 0.1)' }, { border: '#2ecc71', bg: 'rgba(46, 204, 113, 0.1)' },\n" -" { border: '#f39c12', bg: 'rgba(243, 156, 18, 0.1)' }, { border: '#e74c3c', bg: 'rgba(231, 76, 60, 0.1)' },\n" -" { border: '#9b59b6', bg: 'rgba(155, 89, 182, 0.1)' }, { border: '#1abc9c', bg: 'rgba(26, 188, 156, 0.1)' }\n" -" ];\n" +" { border: '#3498db', bg: 'rgba(52, 152, 219, 0.1)' }, { border: '#2ecc71', bg: 'rgba(46, 204, 113, 0.1)' },\n" +" { border: '#f39c12', bg: 'rgba(243, 156, 18, 0.1)' }, { border: '#e74c3c', bg: 'rgba(231, 76, 60, 0.1)' },\n" +" { border: '#9b59b6', bg: 'rgba(155, 89, 182, 0.1)' }, { border: '#1abc9c', bg: 'rgba(26, 188, 156, 0.1)' }\n" +" ];\n" " const color = colors[index % colors.length];\n" " const options = { ...createBaseChartOptions(60), scales: { ...createBaseChartOptions(60).scales, y: { ...createBaseChartOptions(60).scales.y, ticks: { ...createBaseChartOptions(60).scales.y.ticks, callback: formatSizeTick } } } };\n" " window.vhostCharts[chartId] = new Chart(canvas.getContext('2d'), {\n" @@ -1026,20 +985,18 @@ void serve_dashboard(connection_t *conn) { "Cache-Control: no-cache\r\n" "\r\n", strlen(DASHBOARD_HTML)); - if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.size + len + strlen(DASHBOARD_HTML)) < 0) { + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len + strlen(DASHBOARD_HTML)) < 0) { send_error_response(conn, 500, "Internal Server Error", "Memory allocation failed"); return; } - memcpy(conn->write_buf.data + conn->write_buf.size, header, len); - conn->write_buf.size += len; - memcpy(conn->write_buf.data + conn->write_buf.size, DASHBOARD_HTML, strlen(DASHBOARD_HTML)); - conn->write_buf.size += strlen(DASHBOARD_HTML); + memcpy(conn->write_buf.data + conn->write_buf.tail, header, len); + conn->write_buf.tail += len; + memcpy(conn->write_buf.data + conn->write_buf.tail, DASHBOARD_HTML, strlen(DASHBOARD_HTML)); + conn->write_buf.tail += strlen(DASHBOARD_HTML); struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT }; epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event); - conn->read_buf.size = 0; - conn->keep_alive = 1; } static cJSON* format_history(history_deque_t *dq, int window_seconds) { @@ -1163,165 +1120,136 @@ void serve_stats_api(connection_t *conn) { "Cache-Control: no-cache\r\n" "\r\n", strlen(json_string)); - if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.size + hlen + strlen(json_string)) < 0) { + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + hlen + strlen(json_string)) < 0) { free(json_string); cJSON_Delete(root); send_error_response(conn, 500, "Internal Server Error", "Memory allocation failed"); return; } - memcpy(conn->write_buf.data + conn->write_buf.size, header, hlen); - conn->write_buf.size += hlen; - memcpy(conn->write_buf.data + conn->write_buf.size, json_string, strlen(json_string)); - conn->write_buf.size += strlen(json_string); + memcpy(conn->write_buf.data + conn->write_buf.tail, header, hlen); + conn->write_buf.tail += hlen; + memcpy(conn->write_buf.data + conn->write_buf.tail, json_string, strlen(json_string)); + conn->write_buf.tail += strlen(json_string); struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT }; epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event); cJSON_Delete(root); free(json_string); - conn->read_buf.size = 0; - conn->keep_alive = 1; } // ================================================================================================= // -// HTTP Implementation +// HTTP Implementation // // ================================================================================================= -int find_header_value(const char *headers, const char *header_name, char *value, size_t value_size) { - if (!headers || !header_name || !value || value_size == 0) return 0; +// Helper to find a header value without modifying the source string +static int find_header_value(const char* data, size_t len, const char* name, char* value, size_t value_size) { + size_t name_len = strlen(name); + const char* end = data + len; - size_t header_name_len = strlen(header_name); - const char *line_start = headers; - - while (*line_start) { - const char *line_end = strstr(line_start, "\r\n"); - if (!line_end) { - line_end = line_start + strlen(line_start); + for (const char* p = data; p < end; ) { + const char* line_end = memchr(p, '\n', end - p); + if (!line_end) break; + + if (line_end > p && *(line_end - 1) == '\r') { + line_end--; } - size_t line_len = line_end - line_start; - if (line_len > header_name_len + 1) { - if (strncasecmp(line_start, header_name, header_name_len) == 0 && - line_start[header_name_len] == ':') { + if ((size_t)(line_end - p) > name_len && strncasecmp(p, name, name_len) == 0 && p[name_len] == ':') { + const char* v_start = p + name_len + 1; + while (v_start < line_end && (*v_start == ' ' || *v_start == '\t')) v_start++; - const char *value_start = line_start + header_name_len + 1; - while (value_start < line_end && isspace((unsigned char)*value_start)) { - value_start++; - } + const char* v_end = line_end; + while (v_end > v_start && (*(v_end - 1) == ' ' || *(v_end - 1) == '\t')) v_end--; - size_t value_len = line_end - value_start; - if (value_len >= value_size) { - value_len = value_size - 1; - } + size_t copy_len = v_end - v_start; + if (copy_len >= value_size) copy_len = value_size - 1; - memcpy(value, value_start, value_len); - value[value_len] = '\0'; - trim_whitespace(value); - return 1; - } + memcpy(value, v_start, copy_len); + value[copy_len] = '\0'; + return 1; } - if (*line_end == '\r') { - line_start = line_end + 2; - } else if (*line_end == '\0') { - break; - } else { - line_start = line_end + 1; - } + p = line_end + ( (line_end < end && *(line_end) == '\r') ? 2 : 1 ); } - return 0; } -int parse_http_request_line(const char *line, http_request_t *request) { - if (!line || !request) return 0; - memset(request, 0, sizeof(http_request_t)); +// A robust, non-destructive HTTP request parser +int parse_http_request(const char *data, size_t len, http_request_t *req) { + memset(req, 0, sizeof(http_request_t)); + req->content_length = -1; - char *line_copy = safe_strdup(line, MAX_REQUEST_LINE_SIZE); - if (!line_copy) return 0; - - trim_whitespace(line_copy); - - char *method = strtok(line_copy, " "); - char *uri = strtok(NULL, " "); - char *version = strtok(NULL, " "); - - if (!method || !uri || !version) { - free(line_copy); - return 0; + // Find the end of the request line + const char *line_end = memchr(data, '\n', len); + if (!line_end) return -1; // Incomplete + size_t line_len = line_end - data; + if (line_len > 0 && data[line_len - 1] == '\r') { + line_len--; } - if (!is_valid_http_method(method) || !is_valid_http_version(version)) { - free(line_copy); - return 0; + // Parse Method + const char *method_end = memchr(data, ' ', line_len); + if (!method_end) return 0; // Malformed + size_t method_len = method_end - data; + if (method_len >= sizeof(req->method)) return 0; + memcpy(req->method, data, method_len); + req->method[method_len] = '\0'; + if (!is_valid_http_method(req->method)) return 0; + + // Parse URI + const char *uri_start = method_end + 1; + while (uri_start < data + line_len && *uri_start == ' ') uri_start++; + const char *uri_end = data + line_len; + // Look for version part from the right + const char *version_start = memrchr(uri_start, ' ', uri_end - uri_start); + if (!version_start || version_start == uri_start) return 0; + + // Set URI + size_t uri_len = version_start - uri_start; + if (uri_len >= sizeof(req->uri)) return 0; + memcpy(req->uri, uri_start, uri_len); + req->uri[uri_len] = '\0'; + + // Set Version + version_start++; + while (version_start < uri_end && *version_start == ' ') version_start++; + size_t version_len = uri_end - version_start; + if (version_len >= sizeof(req->version)) return 0; + memcpy(req->version, version_start, version_len); + req->version[version_len] = '\0'; + + // Parse headers + const char *headers_start = line_end + 1; + char value[1024]; + + if (find_header_value(headers_start, len - (headers_start - data), "Host", req->host, sizeof(req->host))) { + char *port_colon = strchr(req->host, ':'); + if (port_colon) *port_colon = '\0'; } - if (strlen(uri) >= sizeof(request->uri)) { - free(line_copy); - return 0; + if (find_header_value(headers_start, len - (headers_start - data), "Content-Length", value, sizeof(value))) { + req->content_length = atol(value); + } + + if (find_header_value(headers_start, len - (headers_start - data), "Connection", value, sizeof(value))) { + if (strcasecmp(value, "keep-alive") == 0) req->keep_alive = 1; + if (strcasecmp(value, "upgrade") == 0) req->is_websocket = 1; } - strncpy(request->method, method, sizeof(request->method) - 1); - request->method[sizeof(request->method) - 1] = '\0'; // Fix: Ensure null termination - - strncpy(request->uri, uri, sizeof(request->uri) - 1); - request->uri[sizeof(request->uri) - 1] = '\0'; // Fix: Ensure null termination - - strncpy(request->version, version, sizeof(request->version) - 1); - request->version[sizeof(request->version) - 1] = '\0'; // Fix: Ensure null termination - - free(line_copy); - return 1; -} - -int parse_http_headers(const char *headers, http_request_t *request) { - if (!headers || !request) return 0; - - char value[256]; - - // Parse Host header - if (find_header_value(headers, "Host", value, sizeof(value))) { - char *colon = strchr(value, ':'); - if (colon) { - *colon = '\0'; // Remove port from hostname - } - strncpy(request->host, value, sizeof(request->host) - 1); - request->host[sizeof(request->host) - 1] = '\0'; // Fix: [-Wstringop-truncation] Ensure null termination - } - - // Parse Content-Length - request->content_length = -1; // Default to unknown - if (find_header_value(headers, "Content-Length", value, sizeof(value))) { - request->content_length = atoi(value); - if (request->content_length < 0) { - request->content_length = 0; - } - } - - // Parse Connection header - if (find_header_value(headers, "Connection", value, sizeof(value))) { - if (strcasecmp(value, "keep-alive") == 0) { - request->keep_alive = 1; - } else if (strcasecmp(value, "upgrade") == 0) { - request->connection_upgrade = 1; - } - } - - // Parse Upgrade header for WebSocket - if (find_header_value(headers, "Upgrade", value, sizeof(value))) { - if (strcasecmp(value, "websocket") == 0) { - request->is_websocket = 1; - } + if (find_header_value(headers_start, len - (headers_start - data), "Upgrade", value, sizeof(value))) { + if (strcasecmp(value, "websocket") == 0) req->is_websocket = 1; } return 1; } + void send_error_response(connection_t *conn, int code, const char* status, const char* body) { if (!conn || !status || !body) return; @@ -1331,166 +1259,24 @@ void send_error_response(connection_t *conn, int code, const char* status, const "Content-Type: text/plain; charset=utf-8\r\n" "Content-Length: %zu\r\n" "Connection: close\r\n" - "Server: ReverseProxy/1.1\r\n" + "Server: ReverseProxy/2.1\r\n" "\r\n" "%s", code, status, strlen(body), body); - if (len > 0 && len < (int)sizeof(response)) { // Fix: [-Wsign-compare] Cast sizeof to int - if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.size + len) == 0) { - memcpy(conn->write_buf.data + conn->write_buf.size, response, len); - conn->write_buf.size += len; + if (len > 0 && (size_t)len < sizeof(response)) { + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len) == 0) { + memcpy(conn->write_buf.data + conn->write_buf.tail, response, len); + conn->write_buf.tail += len; struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT }; epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event); } } - conn->keep_alive = 0; conn->state = CLIENT_STATE_ERROR; } -void handle_client_data(connection_t *conn) { - if (!conn) return; - - // Update last activity time - conn->last_activity = time(NULL); - - int bytes_read = read(conn->fd, conn->read_buf.data + conn->read_buf.size, - conn->read_buf.capacity - conn->read_buf.size - 1); - - if (bytes_read <= 0) { - if (bytes_read < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) { - return; - } - close_connection(conn->fd); - return; - } - - conn->read_buf.size += bytes_read; - conn->read_buf.data[conn->read_buf.size] = '\0'; - - // Check for complete headers - char *headers_end = strstr(conn->read_buf.data, "\r\n\r\n"); - if (!headers_end) { - // Check if we've exceeded max header size - if (conn->read_buf.size >= MAX_HEADER_SIZE) { - send_error_response(conn, 413, "Request Header Too Large", "Request headers too large"); - } - return; - } - - // Mark request start time - struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); - conn->request_start_time = ts.tv_sec + ts.tv_nsec / 1e9; - - // Extract request line - char *first_line_end = strstr(conn->read_buf.data, "\r\n"); - if (!first_line_end) { - send_error_response(conn, 400, "Bad Request", "Invalid request line"); - return; - } - - *first_line_end = '\0'; - - // Parse request line - if (!parse_http_request_line(conn->read_buf.data, &conn->request)) { - send_error_response(conn, 400, "Bad Request", "Invalid request line format"); - *first_line_end = '\r'; // restore - return; - } - - *first_line_end = '\r'; // restore - - // Parse headers - char *headers_start = first_line_end + 2; - *headers_end = '\0'; - - if (!parse_http_headers(headers_start, &conn->request)) { - send_error_response(conn, 400, "Bad Request", "Invalid headers"); - *headers_end = '\r'; // restore - return; - } - - *headers_end = '\r'; // restore - - // Handle special dashboard/API routes - if (strcmp(conn->request.method, "GET") == 0) { - if (strncmp(conn->request.uri, "/dashboard", 10) == 0) { - serve_dashboard(conn); - return; - } - if (strncmp(conn->request.uri, "/api/stats", 10) == 0) { - serve_stats_api(conn); - return; - } - } - - // Validate Host header - if (strlen(conn->request.host) == 0) { - send_error_response(conn, 400, "Bad Request", "Host header is required"); - return; - } - - // Find route configuration - route_config_t *route = find_route(conn->request.host); - if (!route) { - char body[512]; - snprintf(body, sizeof(body), "Host not configured: %s", conn->request.host); - send_error_response(conn, 404, "Not Found", body); - return; - } - - // Get or create vhost stats - conn->vhost_stats = monitor_get_or_create_vhost_stats(route->hostname); - conn->is_websocket = conn->request.is_websocket; - conn->keep_alive = conn->request.keep_alive; - - // **FIX**: Store content length and how much body we've already read - conn->http_content_length = conn->request.content_length; - conn->http_body_read = conn->read_buf.size - ((headers_end + 4) - conn->read_buf.data); - if (conn->http_content_length < 0) conn->http_content_length = 0; // Treat unknown length as 0 for this stage - - // Record request start - monitor_record_request_start(conn->vhost_stats, conn->is_websocket); - - // Rewrite Host header if needed - if (route->rewrite_host) { - char new_host_header[320]; - snprintf(new_host_header, sizeof(new_host_header), "Host: %s:%d\r\n", - route->upstream_host, route->upstream_port); - - char host_search[] = "Host:"; - char *host_start = strcasestr(conn->read_buf.data, host_search); - if (host_start) { - char *host_line_end = strstr(host_start, "\r\n"); - if (host_line_end) { - size_t old_len = (host_line_end + 2) - host_start; - size_t new_len = strlen(new_host_header); - size_t tail_len = conn->read_buf.size - (host_start - conn->read_buf.data) - old_len; - - if (new_len != old_len) { - if (buffer_ensure_capacity(&conn->read_buf, conn->read_buf.size - old_len + new_len + 1) < 0) { - send_error_response(conn, 500, "Internal Server Error", "Memory allocation failed"); - return; - } - // Recalculate pointer after potential realloc - host_start = strcasestr(conn->read_buf.data, host_search); - memmove(host_start + new_len, host_start + old_len, tail_len); - conn->read_buf.size = conn->read_buf.size - old_len + new_len; - } - memcpy(host_start, new_host_header, new_len); - conn->read_buf.data[conn->read_buf.size] = '\0'; - } - } - } - - // Transition to connecting upstream - conn->state = CLIENT_STATE_CONNECTING_UPSTREAM; - connect_to_upstream(conn, route); -} - // ================================================================================================= // // Proxy Implementation @@ -1515,7 +1301,10 @@ static void add_to_epoll(int fd, uint32_t events) { static void modify_epoll(int fd, uint32_t events) { struct epoll_event event = { .data.fd = fd, .events = events }; if (epoll_ctl(epoll_fd, EPOLL_CTL_MOD, fd, &event) == -1) { - close_connection(fd); + // This can happen if fd was closed, log it but don't close again + if(errno != EBADF && errno != ENOENT) { + log_debug("epoll_ctl_mod failed for fd %d", fd); + } } } @@ -1584,7 +1373,7 @@ void accept_new_connection(int listener_fd) { connection_t *conn = &connections[client_fd]; memset(conn, 0, sizeof(connection_t)); conn->type = CONN_TYPE_CLIENT; - conn->state = CLIENT_STATE_READING_REQUEST_LINE; + conn->state = CLIENT_STATE_READING_HEADERS; conn->fd = client_fd; conn->last_activity = time(NULL); @@ -1604,27 +1393,37 @@ void close_connection(int fd) { connection_t *conn = &connections[fd]; - // Close paired connection if exists + // Prevent double-closing + if (conn->fd == -1) return; + + int pair_fd = -1; if (conn->pair) { + pair_fd = conn->pair->fd; conn->pair->pair = NULL; - close_connection(conn->pair->fd); } // Record request end if needed if (conn->vhost_stats && conn->request_start_time > 0) { monitor_record_request_end(conn->vhost_stats, conn->request_start_time); + conn->request_start_time = 0; // Prevent double recording + } + + // Update connection count before we invalidate the pointer + if (conn->type == CONN_TYPE_CLIENT) { + if(monitor.active_connections > 0) monitor.active_connections--; } - // Remove from epoll + log_debug("Closing connection on fd %d, pair %d, remaining: %d", fd, pair_fd, monitor.active_connections); + + // Remove from epoll before closing epoll_ctl(epoll_fd, EPOLL_CTL_DEL, fd, NULL); // Clean up SSL if (conn->ssl) { SSL_shutdown(conn->ssl); SSL_free(conn->ssl); - conn->ssl = NULL; } - + // Close socket close(fd); @@ -1632,20 +1431,24 @@ void close_connection(int fd) { buffer_free(&conn->read_buf); buffer_free(&conn->write_buf); - // Update connection count - if (conn->type == CONN_TYPE_CLIENT) { - monitor.active_connections--; - } - - log_debug("Closed connection on fd %d, remaining: %d", fd, monitor.active_connections); - - // Reset connection structure - memset(conn, 0, sizeof(connection_t)); + // Mark as unused conn->type = CONN_TYPE_UNUSED; + conn->fd = -1; // Mark as closed + + // Close the paired connection + if (pair_fd != -1) { + close_connection(pair_fd); + } } -void connect_to_upstream(connection_t *client, route_config_t *route) { - if (!client || !route) return; +void connect_to_upstream(connection_t *client, const char *data, size_t data_len) { + if (!client || !data) return; + + route_config_t *route = find_route(client->request.host); + if (!route) { + send_error_response(client, 502, "Bad Gateway", "Route not found after parsing"); + return; + } int up_fd = socket(AF_INET, SOCK_STREAM, 0); if (up_fd < 0) { @@ -1698,36 +1501,93 @@ void connect_to_upstream(connection_t *client, route_config_t *route) { if (buffer_init(&up->read_buf, CHUNK_SIZE) < 0 || buffer_init(&up->write_buf, CHUNK_SIZE) < 0) { - close_connection(up_fd); - send_error_response(client, 502, "Bad Gateway", "Memory allocation failed"); + close_connection(client->fd); // Close the whole pair return; } + char *data_to_send = (char*)data; + size_t len_to_send = data_len; + char *modified_request = NULL; + + if (route->rewrite_host) { + char new_host_header[512]; + const char *old_host_header_start = NULL; + const char *old_host_header_end = NULL; + + // Find the old host header to replace it accurately + const char *current = data; + const char *end = data + data_len; + while(current < end) { + const char* line_end = memchr(current, '\n', end - current); + if (!line_end) break; + if (strncasecmp(current, "Host:", 5) == 0) { + old_host_header_start = current; + old_host_header_end = line_end + 1; + break; + } + current = line_end + 1; + } + + if (old_host_header_start) { + snprintf(new_host_header, sizeof(new_host_header), "Host: %s:%d\r\n", route->upstream_host, route->upstream_port); + size_t new_host_len = strlen(new_host_header); + size_t old_host_len = old_host_header_end - old_host_header_start; + + len_to_send = data_len - old_host_len + new_host_len; + modified_request = malloc(len_to_send); + if (modified_request) { + char* p = modified_request; + // Copy part before Host header + memcpy(p, data, old_host_header_start - data); + p += old_host_header_start - data; + // Copy new Host header + memcpy(p, new_host_header, new_host_len); + p += new_host_len; + // Copy part after Host header + memcpy(p, old_host_header_end, data_len - (old_host_header_end - data)); + data_to_send = modified_request; + } + } + } + + if (buffer_ensure_capacity(&up->write_buf, len_to_send) == 0) { + memcpy(up->write_buf.data, data_to_send, len_to_send); + up->write_buf.tail = len_to_send; + } + + if (modified_request) { + free(modified_request); + } + if (route->use_ssl) { up->ssl = SSL_new(ssl_ctx); if (!up->ssl) { - close_connection(up_fd); + close_connection(client->fd); send_error_response(client, 502, "Bad Gateway", "SSL initialization failed"); return; } + + SSL_set_tlsext_host_name(up->ssl, route->upstream_host); // SNI SSL_set_fd(up->ssl, up_fd); SSL_set_connect_state(up->ssl); } - + log_debug("Connecting to upstream %s:%d on fd %d", route->upstream_host, route->upstream_port, up_fd); } + static int do_read(connection_t *conn) { if (!conn) return -1; - - if (buffer_ensure_capacity(&conn->read_buf, conn->read_buf.size + CHUNK_SIZE) < 0) { - return -1; - } + + buffer_t *buf = &conn->read_buf; + if (buf->tail == buf->capacity) buffer_compact(buf); + size_t len = buf->capacity - buf->tail; + if (len == 0) return 0; // Buffer full int bytes_read; if (conn->ssl && conn->ssl_handshake_done) { - bytes_read = SSL_read(conn->ssl, conn->read_buf.data + conn->read_buf.size, CHUNK_SIZE); + bytes_read = SSL_read(conn->ssl, buf->data + buf->tail, len); if (bytes_read <= 0) { int ssl_error = SSL_get_error(conn->ssl, bytes_read); if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) { @@ -1735,18 +1595,23 @@ static int do_read(connection_t *conn) { } } } else { - bytes_read = read(conn->fd, conn->read_buf.data + conn->read_buf.size, CHUNK_SIZE); + bytes_read = read(conn->fd, buf->data + buf->tail, len); } - + + if (bytes_read > 0) buf->tail += bytes_read; return bytes_read; } static int do_write(connection_t *conn) { - if (!conn || conn->write_buf.size == 0) return 0; + if (!conn) return -1; + + buffer_t *buf = &conn->write_buf; + size_t len = buf->tail - buf->head; + if (len == 0) return 0; int written; if (conn->ssl && conn->ssl_handshake_done) { - written = SSL_write(conn->ssl, conn->write_buf.data, conn->write_buf.size); + written = SSL_write(conn->ssl, buf->data + buf->head, len); if (written <= 0) { int ssl_error = SSL_get_error(conn->ssl, written); if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) { @@ -1755,121 +1620,135 @@ static int do_write(connection_t *conn) { } } } else { - written = write(conn->fd, conn->write_buf.data, conn->write_buf.size); + written = write(conn->fd, buf->data + buf->head, len); } if (written > 0) { - memmove(conn->write_buf.data, conn->write_buf.data + written, - conn->write_buf.size - written); - conn->write_buf.size -= written; + buf->head += written; + if (buf->head == buf->tail) { + buf->head = 0; + buf->tail = 0; + } } - return written; } -void handle_client_body(connection_t *conn) { - if (!conn || !conn->pair) { +void handle_client_read(connection_t *conn) { + if (do_read(conn) <= 0 && (errno != EAGAIN && errno != EWOULDBLOCK)) { close_connection(conn->fd); return; } + + buffer_t *buf = &conn->read_buf; - connection_t *upstream = conn->pair; - conn->last_activity = time(NULL); + // Loop to handle pipelined requests for internal routes + while (buf->tail > buf->head) { + char *data_start = buf->data + buf->head; + size_t data_len = buf->tail - buf->head; - long remaining_to_read_total = conn->http_content_length - conn->http_body_read; - if (remaining_to_read_total <= 0) { - conn->state = CLIENT_STATE_TUNNELING; - return; - } + char *headers_end_marker = memmem(data_start, data_len, "\r\n\r\n", 4); + if (!headers_end_marker) { + if (data_len >= MAX_HEADER_SIZE) { + send_error_response(conn, 413, "Request Header Too Large", "Header too large"); + } + break; // Incomplete request, wait for more data + } + size_t headers_len = (headers_end_marker - data_start) + 4; - size_t can_read_now = conn->read_buf.capacity - conn->read_buf.size; - if (can_read_now > (size_t)remaining_to_read_total) { // Fix: [-Wsign-compare] Cast to size_t - can_read_now = remaining_to_read_total; - } - - if(can_read_now == 0) return; - - int bytes = read(conn->fd, conn->read_buf.data, can_read_now); - - if (bytes > 0) { - conn->http_body_read += bytes; - - if (buffer_ensure_capacity(&upstream->write_buf, upstream->write_buf.size + bytes) < 0) { - close_connection(conn->fd); + if (parse_http_request(data_start, headers_len, &conn->request) != 1) { + send_error_response(conn, 400, "Bad Request", "Malformed HTTP request"); return; } - memcpy(upstream->write_buf.data + upstream->write_buf.size, conn->read_buf.data, bytes); - upstream->write_buf.size += bytes; - modify_epoll(upstream->fd, EPOLLIN | EPOLLOUT); + long long body_len = (conn->request.content_length > 0) ? conn->request.content_length : 0; + size_t total_request_len = headers_len + body_len; - if (conn->http_body_read >= conn->http_content_length) { - log_debug("Finished reading body for fd %d, transitioning to TUNNELING", conn->fd); - conn->state = CLIENT_STATE_TUNNELING; - } - } else if (bytes == 0) { - close_connection(conn->fd); - } else if (errno != EAGAIN && errno != EWOULDBLOCK) { - close_connection(conn->fd); - } -} - -static void handle_tunneling(connection_t *conn) { - if (!conn || !conn->pair) { - close_connection(conn->fd); - return; - } - - conn->last_activity = time(NULL); - - int bytes = do_read(conn); - if (bytes > 0) { - conn->read_buf.size += bytes; - connection_t *pair = conn->pair; - - if (buffer_ensure_capacity(&pair->write_buf, - pair->write_buf.size + conn->read_buf.size) < 0) { - close_connection(conn->fd); - return; + if (data_len < total_request_len) { + break; // Incomplete body, wait for more data } - memcpy(pair->write_buf.data + pair->write_buf.size, - conn->read_buf.data, conn->read_buf.size); - pair->write_buf.size += conn->read_buf.size; - - // Update stats - if (conn->vhost_stats) { - if (conn->type == CONN_TYPE_CLIENT) { - conn->vhost_stats->bytes_recv += bytes; - } else { - conn->vhost_stats->bytes_sent += bytes; + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + conn->request_start_time = ts.tv_sec + ts.tv_nsec / 1e9; + + // Handle internal routes + if (strcmp(conn->request.method, "GET") == 0) { + if (strncmp(conn->request.uri, "/dashboard", 10) == 0) { + serve_dashboard(conn); + buf->head += total_request_len; // Consume request + monitor_record_request_end(conn->vhost_stats, conn->request_start_time); + conn->request_start_time = 0; + continue; // Check for another pipelined request + } + if (strncmp(conn->request.uri, "/api/stats", 10) == 0) { + serve_stats_api(conn); + buf->head += total_request_len; // Consume request + monitor_record_request_end(conn->vhost_stats, conn->request_start_time); + conn->request_start_time = 0; + continue; } } + + // This is a request to be forwarded + conn->vhost_stats = monitor_get_or_create_vhost_stats(conn->request.host); + monitor_record_request_start(conn->vhost_stats, conn->request.is_websocket); - conn->read_buf.size = 0; - modify_epoll(pair->fd, EPOLLIN | EPOLLOUT); - } else if (bytes == 0) { - close_connection(conn->fd); - } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + conn->state = CLIENT_STATE_FORWARDING; // <-- CRITICAL: Change state now + + connect_to_upstream(conn, data_start, total_request_len); + + // The data has been handed off; clear the read buffer and stop parsing. + buf->head = 0; + buf->tail = 0; + return; + } + + buffer_compact(buf); +} + +static void handle_forwarding(connection_t *conn) { + conn->last_activity = time(NULL); + connection_t *pair = conn->pair; + if (!pair) { close_connection(conn->fd); return; } + + int bytes_read = do_read(conn); + size_t data_to_forward = conn->read_buf.tail - conn->read_buf.head; + + if (data_to_forward > 0) { + if (buffer_ensure_capacity(&pair->write_buf, pair->write_buf.tail + data_to_forward) == 0) { + memcpy(pair->write_buf.data + pair->write_buf.tail, conn->read_buf.data + conn->read_buf.head, data_to_forward); + pair->write_buf.tail += data_to_forward; + + conn->read_buf.head = 0; + conn->read_buf.tail = 0; + modify_epoll(pair->fd, EPOLLIN | EPOLLOUT); + } else { + log_error("Failed to ensure write buffer capacity during forward."); + close_connection(conn->fd); + return; + } + } + + if (bytes_read == 0) { + shutdown(pair->fd, SHUT_WR); + } else if (bytes_read < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { close_connection(conn->fd); } } + static void handle_write_event(connection_t *conn) { if (!conn) return; - conn->last_activity = time(NULL); if (conn->type == CONN_TYPE_UPSTREAM && !conn->ssl_handshake_done) { - // Check if connection is established int err = 0; socklen_t len = sizeof(err); if (getsockopt(conn->fd, SOL_SOCKET, SO_ERROR, &err, &len) != 0 || err != 0) { - close_connection(conn->fd); + send_error_response(conn->pair, 502, "Bad Gateway", strerror(err)); return; } - // Handle SSL handshake if needed if (conn->ssl) { int ret = SSL_do_handshake(conn->ssl); if (ret == 1) { @@ -1877,67 +1756,24 @@ static void handle_write_event(connection_t *conn) { } else { int ssl_error = SSL_get_error(conn->ssl, ret); if (ssl_error != SSL_ERROR_WANT_READ && ssl_error != SSL_ERROR_WANT_WRITE) { - close_connection(conn->fd); + send_error_response(conn->pair, 502, "Bad Gateway", "SSL handshake failed"); return; } - // Continue handshake later - return; + return; // Handshake in progress } } else { - conn->ssl_handshake_done = 1; // No SSL, consider "handshake" done - } - - // Connection established, forward client request - if (conn->ssl_handshake_done && conn->pair) { - connection_t *client = conn->pair; - if (buffer_ensure_capacity(&conn->write_buf, - conn->write_buf.size + client->read_buf.size) < 0) { - close_connection(conn->fd); - return; - } - - memcpy(conn->write_buf.data + conn->write_buf.size, - client->read_buf.data, client->read_buf.size); - conn->write_buf.size += client->read_buf.size; - client->read_buf.size = 0; - - // **FIX**: Transition to the correct state based on Content-Length - if (client->http_body_read >= client->http_content_length) { - client->state = CLIENT_STATE_TUNNELING; - log_debug("Request has no body or was already read, tunneling fd %d", client->fd); - } else { - client->state = CLIENT_STATE_READING_BODY; - log_debug("Request has body, reading body for fd %d", client->fd); - } - - // Record WebSocket request end immediately after connection - if (client->is_websocket && client->vhost_stats) { - monitor_record_request_end(client->vhost_stats, client->request_start_time); - } - } - } - - // Write data if available - int written = do_write(conn); - if (written > 0) { - // Update stats - if(conn->vhost_stats) { - if(conn->type == CONN_TYPE_CLIENT) { - conn->vhost_stats->bytes_sent += written; - } else { - conn->vhost_stats->bytes_recv += written; // This seems reversed, but from perspective of vhost, upstream sends to it. - } + conn->ssl_handshake_done = 1; // No SSL } } - if (written < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { - close_connection(conn->fd); - return; - } + do_write(conn); - // Adjust epoll events - if (conn->write_buf.size == 0) { - modify_epoll(conn->fd, EPOLLIN); + if (conn->write_buf.tail - conn->write_buf.head == 0) { + if (conn->type == CONN_TYPE_CLIENT && conn->state == CLIENT_STATE_ERROR) { + close_connection(conn->fd); + } else { + modify_epoll(conn->fd, EPOLLIN); + } } } @@ -1949,48 +1785,31 @@ void handle_connection_event(struct epoll_event *event) { return; } - if (fd < 0 || fd >= MAX_FDS) { - return; - } + if (fd < 0 || fd >= MAX_FDS) return; connection_t *conn = &connections[fd]; + if (conn->type == CONN_TYPE_UNUSED) return; if (conn->type == CONN_TYPE_LISTENER) { if (event->events & EPOLLIN) { accept_new_connection(fd); } - } else { + } else { // Client or Upstream if (event->events & EPOLLOUT) { handle_write_event(conn); } - - if ((event->events & EPOLLIN) && conn->type != CONN_TYPE_UNUSED) { - if (conn->type == CONN_TYPE_CLIENT) { - // **FIX**: Dispatch based on the more detailed state machine - switch (conn->state) { - case CLIENT_STATE_READING_REQUEST_LINE: - handle_client_data(conn); - break; - case CLIENT_STATE_READING_BODY: - handle_client_body(conn); - break; - case CLIENT_STATE_TUNNELING: - handle_tunneling(conn); - break; - default: - // Do nothing for other states like CONNECTING, DONE, ERROR - break; - } - } else { // It's an upstream connection - handle_tunneling(conn); + if (connections[fd].type != CONN_TYPE_UNUSED && (event->events & EPOLLIN)) { + if (conn->type == CONN_TYPE_CLIENT && conn->state == CLIENT_STATE_READING_HEADERS) { + handle_client_read(conn); + } else { + handle_forwarding(conn); } } } } - // ================================================================================================= // -// Main Implementation +// Main Implementation // // ================================================================================================= @@ -2039,7 +1858,6 @@ void init_ssl() { exit(EXIT_FAILURE); } - // Set verification options SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, NULL); SSL_CTX_set_options(ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); } @@ -2047,7 +1865,6 @@ void init_ssl() { void cleanup() { log_info("Shutting down proxy..."); - // Close all connections for (int i = 0; i < MAX_FDS; i++) { if (connections[i].type != CONN_TYPE_UNUSED) { close_connection(i); @@ -2072,7 +1889,6 @@ void cleanup_idle_connections() { static time_t last_cleanup = 0; time_t current_time = time(NULL); - // Run cleanup every 60 seconds if (current_time - last_cleanup < 60) { return; } @@ -2091,6 +1907,17 @@ void cleanup_idle_connections() { } } + +void free_config() { + if (config.routes) { + free(config.routes); + config.routes = NULL; + } + config.route_count = 0; +} + + + int main(int argc, char *argv[]) { signal(SIGPIPE, SIG_IGN); @@ -2116,7 +1943,6 @@ int main(int argc, char *argv[]) { return 1; } - // Initialize connections array for (int i = 0; i < MAX_FDS; i++) { connections[i].type = CONN_TYPE_UNUSED; }