From 833c78debc851c6970606a8749bab757f0f984d1 Mon Sep 17 00:00:00 2001 From: retoor Date: Sat, 29 Nov 2025 04:58:34 +0100 Subject: [PATCH] Update. --- Makefile | 17 +++- README.md | 59 ++++++++++- src/auth.c | 196 +++++++++++++++++++++++++++++++++++ src/auth.h | 13 +++ src/buffer.c | 31 +++++- src/config.c | 247 ++++++++++++++++++++++++++++++++++++++++++++- src/config.h | 4 + src/connection.c | 173 +++++++++++++++++++++++++------ src/connection.h | 3 +- src/dashboard.c | 70 ++++++++++++- src/dashboard.h | 4 +- src/health_check.c | 188 ++++++++++++++++++++++++++++++++++ src/health_check.h | 11 ++ src/http.c | 6 +- src/logging.c | 77 ++++++++++++-- src/logging.h | 4 +- src/main.c | 141 +++++++++++++++++++++++++- src/monitor.c | 23 ++++- src/rate_limit.c | 126 +++++++++++++++++++++++ src/rate_limit.h | 11 ++ src/ssl_handler.c | 56 ++++++++-- src/ssl_handler.h | 3 + src/types.h | 13 +++ 23 files changed, 1404 insertions(+), 72 deletions(-) create mode 100644 src/auth.c create mode 100644 src/auth.h create mode 100644 src/health_check.c create mode 100644 src/health_check.h create mode 100644 src/rate_limit.c create mode 100644 src/rate_limit.h diff --git a/Makefile b/Makefile index d32e503..876e689 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ CC = gcc CFLAGS = -Wall -Wextra -O2 -g -D_GNU_SOURCE -LDFLAGS = -lssl -lcrypto -lsqlite3 -lm +LDFLAGS = -lssl -lcrypto -lsqlite3 -lm -lpthread SRC_DIR = src BUILD_DIR = build @@ -15,6 +15,9 @@ SOURCES = $(SRC_DIR)/main.c \ $(SRC_DIR)/ssl_handler.c \ $(SRC_DIR)/connection.c \ $(SRC_DIR)/dashboard.c \ + $(SRC_DIR)/rate_limit.c \ + $(SRC_DIR)/auth.c \ + $(SRC_DIR)/health_check.c \ cJSON.c OBJECTS = $(patsubst %.c,$(BUILD_DIR)/%.o,$(notdir $(SOURCES))) @@ -37,6 +40,9 @@ TEST_LIB_SOURCES = $(SRC_DIR)/buffer.c \ $(SRC_DIR)/ssl_handler.c \ $(SRC_DIR)/connection.c \ $(SRC_DIR)/dashboard.c \ + $(SRC_DIR)/rate_limit.c \ + $(SRC_DIR)/auth.c \ + $(SRC_DIR)/health_check.c \ cJSON.c TEST_LIB_OBJECTS = $(patsubst %.c,$(BUILD_DIR)/%.o,$(notdir $(TEST_LIB_SOURCES))) @@ -80,6 +86,15 @@ $(BUILD_DIR)/connection.o: $(SRC_DIR)/connection.c $(BUILD_DIR)/dashboard.o: $(SRC_DIR)/dashboard.c $(CC) $(CFLAGS) -c $< -o $@ +$(BUILD_DIR)/rate_limit.o: $(SRC_DIR)/rate_limit.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/auth.o: $(SRC_DIR)/auth.c + $(CC) $(CFLAGS) -c $< -o $@ + +$(BUILD_DIR)/health_check.o: $(SRC_DIR)/health_check.c + $(CC) $(CFLAGS) -c $< -o $@ + $(BUILD_DIR)/cJSON.o: cJSON.c $(CC) $(CFLAGS) -c $< -o $@ diff --git a/README.md b/README.md index 8f25858..434ab7b 100644 --- a/README.md +++ b/README.md @@ -5,19 +5,27 @@ rproxy is a high-performance reverse proxy server written in C. It routes HTTP a ## Features - Reverse proxy routing by hostname -- SSL/TLS support for upstream connections +- SSL/TLS support for upstream connections with certificate verification - WebSocket proxying - Connection pooling and idle timeout management - Real-time monitoring and statistics - Web-based dashboard for metrics visualization - SQLite-based persistent statistics storage - Epoll-based event handling for high concurrency +- Graceful shutdown with connection draining +- Live configuration reload via SIGHUP +- Dashboard authentication (HTTP Basic Auth) +- Rate limiting per client IP +- Health checks for upstream servers +- Automatic upstream connection retries +- File logging support ## Dependencies - GCC - OpenSSL (libssl, libcrypto) - SQLite3 +- pthreads - cJSON library ## Build @@ -55,6 +63,19 @@ Configuration is defined in `proxy_config.json`: - `use_ssl`: Enable SSL for upstream connection - `rewrite_host`: Rewrite Host header to upstream hostname +## Environment Variables + +| Variable | Description | +|----------|-------------| +| `DEBUG` | Enable debug logging (set to `1`) | +| `LOG_FILE` | Path to log file (default: stdout) | +| `RATE_LIMIT` | Max requests per minute per IP | +| `DASHBOARD_USER` | Dashboard authentication username | +| `DASHBOARD_PASS` | Dashboard authentication password | +| `SSL_VERIFY` | Disable SSL verification (set to `0`) | +| `SSL_CA_FILE` | Path to custom CA certificate file | +| `SSL_CA_PATH` | Path to CA certificate directory | + ## Usage ```bash @@ -63,11 +84,44 @@ Configuration is defined in `proxy_config.json`: If no config file is specified, defaults to `proxy_config.json`. +Examples: + +```bash +# Basic usage +./rproxy + +# With custom config +./rproxy /etc/rproxy/config.json + +# With debug logging +DEBUG=1 ./rproxy + +# With file logging +LOG_FILE=/var/log/rproxy.log ./rproxy + +# With rate limiting (100 requests/minute) +RATE_LIMIT=100 ./rproxy + +# With dashboard authentication +DASHBOARD_USER=admin DASHBOARD_PASS=secret ./rproxy + +# Reload configuration +kill -HUP $(pidof rproxy) +``` + ## Endpoints - Dashboard: `http://localhost:{port}/rproxy/dashboard` - API Stats: `http://localhost:{port}/rproxy/api/stats` +## Signals + +| Signal | Action | +|--------|--------| +| `SIGINT` | Graceful shutdown | +| `SIGTERM` | Graceful shutdown | +| `SIGHUP` | Reload configuration | + ## Architecture - **main.c**: Entry point, event loop, signal handling @@ -79,6 +133,9 @@ If no config file is specified, defaults to `proxy_config.json`. - **config.c**: JSON configuration parsing - **buffer.c**: Circular buffer implementation - **logging.c**: Logging utilities +- **rate_limit.c**: Per-IP rate limiting +- **auth.c**: Dashboard authentication +- **health_check.c**: Upstream health monitoring ## Testing diff --git a/src/auth.c b/src/auth.c new file mode 100644 index 0000000..b4cd5ed --- /dev/null +++ b/src/auth.c @@ -0,0 +1,196 @@ +#include "auth.h" +#include "logging.h" +#include +#include +#include +#include + +static char g_dashboard_username[128] = ""; +static char g_dashboard_password_hash[256] = ""; +static int g_auth_enabled = 0; + +static void compute_sha256(const char *input, char *output, size_t output_size) { + EVP_MD_CTX *ctx = EVP_MD_CTX_new(); + if (!ctx) return; + + unsigned char hash[EVP_MAX_MD_SIZE]; + unsigned int hash_len = 0; + + EVP_DigestInit_ex(ctx, EVP_sha256(), NULL); + EVP_DigestUpdate(ctx, input, strlen(input)); + EVP_DigestFinal_ex(ctx, hash, &hash_len); + EVP_MD_CTX_free(ctx); + + for (unsigned int i = 0; i < hash_len && (i * 2 + 2) < output_size; i++) { + snprintf(output + (i * 2), 3, "%02x", hash[i]); + } +} + +void auth_init(const char *username, const char *password) { + if (!username || !password || strlen(username) == 0 || strlen(password) == 0) { + g_auth_enabled = 0; + return; + } + + strncpy(g_dashboard_username, username, sizeof(g_dashboard_username) - 1); + g_dashboard_username[sizeof(g_dashboard_username) - 1] = '\0'; + + compute_sha256(password, g_dashboard_password_hash, sizeof(g_dashboard_password_hash)); + g_auth_enabled = 1; + log_info("Dashboard authentication enabled for user: %s", username); +} + +int auth_is_enabled(void) { + return g_auth_enabled; +} + +int auth_check_credentials(const char *username, const char *password) { + if (!g_auth_enabled) return 1; + if (!username || !password) return 0; + + if (strcmp(username, g_dashboard_username) != 0) return 0; + + char password_hash[256]; + compute_sha256(password, password_hash, sizeof(password_hash)); + + return strcmp(password_hash, g_dashboard_password_hash) == 0; +} + +static int base64_decode_char(char c) { + if (c >= 'A' && c <= 'Z') return c - 'A'; + if (c >= 'a' && c <= 'z') return c - 'a' + 26; + if (c >= '0' && c <= '9') return c - '0' + 52; + if (c == '+') return 62; + if (c == '/') return 63; + return -1; +} + +static int base64_decode(const char *input, char *output, size_t output_size) { + size_t input_len = strlen(input); + size_t output_idx = 0; + + for (size_t i = 0; i < input_len && output_idx < output_size - 1; i += 4) { + int v[4] = {0, 0, 0, 0}; + int pad = 0; + + for (int j = 0; j < 4; j++) { + if (i + j >= input_len || input[i + j] == '=') { + pad++; + v[j] = 0; + } else { + v[j] = base64_decode_char(input[i + j]); + if (v[j] < 0) return -1; + } + } + + if (output_idx < output_size - 1) output[output_idx++] = (v[0] << 2) | (v[1] >> 4); + if (pad < 2 && output_idx < output_size - 1) output[output_idx++] = (v[1] << 4) | (v[2] >> 2); + if (pad < 1 && output_idx < output_size - 1) output[output_idx++] = (v[2] << 6) | v[3]; + } + + output[output_idx] = '\0'; + return output_idx; +} + +int auth_check_basic_auth(const char *auth_header, char *error_msg, size_t error_size) { + if (!g_auth_enabled) return 1; + if (!auth_header) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Authentication required", error_size - 1); + } + return 0; + } + + if (strncmp(auth_header, "Basic ", 6) != 0) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Invalid authentication method", error_size - 1); + } + return 0; + } + + char decoded[512]; + if (base64_decode(auth_header + 6, decoded, sizeof(decoded)) < 0) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Invalid credentials format", error_size - 1); + } + return 0; + } + + char *colon = strchr(decoded, ':'); + if (!colon) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Invalid credentials format", error_size - 1); + } + return 0; + } + + *colon = '\0'; + const char *username = decoded; + const char *password = colon + 1; + + if (!auth_check_credentials(username, password)) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Invalid username or password", error_size - 1); + } + return 0; + } + + return 1; +} + +int auth_check_route_basic_auth(const route_config_t *route, const char *auth_header, char *error_msg, size_t error_size) { + if (!route || !route->use_auth) return 1; + + if (!auth_header) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Authentication required", error_size - 1); + } + return 0; + } + + if (strncmp(auth_header, "Basic ", 6) != 0) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Invalid authentication method", error_size - 1); + } + return 0; + } + + char decoded[512]; + if (base64_decode(auth_header + 6, decoded, sizeof(decoded)) < 0) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Invalid credentials format", error_size - 1); + } + return 0; + } + + char *colon = strchr(decoded, ':'); + if (!colon) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Invalid credentials format", error_size - 1); + } + return 0; + } + + *colon = '\0'; + const char *username = decoded; + const char *password = colon + 1; + + if (strcmp(username, route->username) != 0) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Invalid username or password", error_size - 1); + } + return 0; + } + + char password_hash[256]; + compute_sha256(password, password_hash, sizeof(password_hash)); + + if (strcmp(password_hash, route->password_hash) != 0) { + if (error_msg && error_size > 0) { + strncpy(error_msg, "Invalid username or password", error_size - 1); + } + return 0; + } + + return 1; +} diff --git a/src/auth.h b/src/auth.h new file mode 100644 index 0000000..170d0f1 --- /dev/null +++ b/src/auth.h @@ -0,0 +1,13 @@ +#ifndef RPROXY_AUTH_H +#define RPROXY_AUTH_H + +#include +#include "types.h" + +void auth_init(const char *username, const char *password); +int auth_is_enabled(void); +int auth_check_credentials(const char *username, const char *password); +int auth_check_basic_auth(const char *auth_header, char *error_msg, size_t error_size); +int auth_check_route_basic_auth(const route_config_t *route, const char *auth_header, char *error_msg, size_t error_size); + +#endif diff --git a/src/buffer.c b/src/buffer.c index d0fc88a..e593787 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -1,10 +1,12 @@ #include "buffer.h" #include "logging.h" +#include "types.h" #include #include #include int buffer_init(buffer_t *buf, size_t capacity) { + if (!buf) return -1; buf->data = malloc(capacity); if (!buf->data) { log_error("Failed to allocate buffer"); @@ -17,6 +19,7 @@ int buffer_init(buffer_t *buf, size_t capacity) { } void buffer_free(buffer_t *buf) { + if (!buf) return; if (buf->data) { free(buf->data); buf->data = NULL; @@ -27,23 +30,40 @@ void buffer_free(buffer_t *buf) { } size_t buffer_available_read(buffer_t *buf) { + if (!buf) return 0; return buf->tail - buf->head; } size_t buffer_available_write(buffer_t *buf) { + if (!buf) return 0; return buf->capacity - buf->tail; } int buffer_ensure_capacity(buffer_t *buf, size_t required) { + if (!buf) return -1; if (buf->capacity >= required) return 0; + if (required > MAX_BUFFER_SIZE) { + log_error("Buffer size limit exceeded: requested %zu, max %d", required, MAX_BUFFER_SIZE); + return -1; + } + size_t new_capacity = buf->capacity; while (new_capacity < required) { - new_capacity *= 2; if (new_capacity > SIZE_MAX / 2) { - log_error("Buffer size limit exceeded"); + log_error("Buffer size overflow"); return -1; } + new_capacity *= 2; + if (new_capacity > MAX_BUFFER_SIZE) { + new_capacity = MAX_BUFFER_SIZE; + break; + } + } + + if (new_capacity < required) { + log_error("Cannot satisfy buffer capacity requirement"); + return -1; } char *new_data = realloc(buf->data, new_capacity); @@ -57,7 +77,7 @@ int buffer_ensure_capacity(buffer_t *buf, size_t required) { } void buffer_compact(buffer_t *buf) { - if (buf->head == 0) return; + if (!buf || buf->head == 0) return; size_t len = buf->tail - buf->head; if (len > 0) { memmove(buf->data, buf->data + buf->head, len); @@ -67,6 +87,11 @@ void buffer_compact(buffer_t *buf) { } void buffer_consume(buffer_t *buf, size_t bytes) { + if (!buf) return; + size_t available = buf->tail - buf->head; + if (bytes > available) { + bytes = available; + } buf->head += bytes; if (buf->head >= buf->tail) { buf->head = 0; diff --git a/src/config.c b/src/config.c index 0041235..18199ce 100644 --- a/src/config.c +++ b/src/config.c @@ -4,9 +4,84 @@ #include #include #include +#include +#include +#include +#include + +static pthread_rwlock_t config_lock = PTHREAD_RWLOCK_INITIALIZER; +static time_t config_file_mtime = 0; + +static void compute_password_hash(const char *password, char *output, size_t output_size) { + EVP_MD_CTX *ctx = EVP_MD_CTX_new(); + if (!ctx) return; + + unsigned char hash[EVP_MAX_MD_SIZE]; + unsigned int hash_len = 0; + + EVP_DigestInit_ex(ctx, EVP_sha256(), NULL); + EVP_DigestUpdate(ctx, password, strlen(password)); + EVP_DigestFinal_ex(ctx, hash, &hash_len); + EVP_MD_CTX_free(ctx); + + for (unsigned int i = 0; i < hash_len && (i * 2 + 2) < output_size; i++) { + snprintf(output + (i * 2), 3, "%02x", hash[i]); + } +} app_config_t config; +static int is_valid_hostname(const char *hostname) { + if (!hostname || strlen(hostname) == 0 || strlen(hostname) > 253) return 0; + + const char *p = hostname; + int label_len = 0; + + while (*p) { + char c = *p; + if (c == '.') { + if (label_len == 0) return 0; + label_len = 0; + } else if (isalnum((unsigned char)c) || c == '-' || c == '_') { + label_len++; + if (label_len > 63) return 0; + } else { + return 0; + } + p++; + } + + return 1; +} + +static int is_valid_ip(const char *ip) { + if (!ip) return 0; + int dots = 0; + int num = 0; + int has_digit = 0; + + while (*ip) { + if (*ip == '.') { + if (!has_digit || num > 255) return 0; + dots++; + num = 0; + has_digit = 0; + } else if (isdigit((unsigned char)*ip)) { + num = num * 10 + (*ip - '0'); + has_digit = 1; + } else { + return 0; + } + ip++; + } + + return dots == 3 && has_digit && num <= 255; +} + +static int is_valid_host(const char *host) { + return is_valid_hostname(host) || is_valid_ip(host); +} + static char* read_file_to_string(const char *filename) { FILE *f = fopen(filename, "rb"); if (!f) return NULL; @@ -81,8 +156,20 @@ int config_load(const char *filename) { continue; } + if (!is_valid_host(hostname->valuestring)) { + fprintf(stderr, "Invalid hostname at index %d: %s\n", i, hostname->valuestring); + continue; + } + + if (!is_valid_host(upstream_host->valuestring)) { + fprintf(stderr, "Invalid upstream_host at index %d: %s\n", i, upstream_host->valuestring); + continue; + } + strncpy(route->hostname, hostname->valuestring, sizeof(route->hostname) - 1); + route->hostname[sizeof(route->hostname) - 1] = '\0'; strncpy(route->upstream_host, upstream_host->valuestring, sizeof(route->upstream_host) - 1); + route->upstream_host[sizeof(route->upstream_host) - 1] = '\0'; route->upstream_port = upstream_port->valueint; if (route->upstream_port < 1 || route->upstream_port > 65535) { @@ -93,9 +180,27 @@ int config_load(const char *filename) { route->use_ssl = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "use_ssl")); route->rewrite_host = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "rewrite_host")); - log_info("Route configured: %s -> %s:%d (SSL: %s, Rewrite Host: %s)", + route->use_auth = 0; + route->username[0] = '\0'; + route->password_hash[0] = '\0'; + + cJSON *use_auth = cJSON_GetObjectItem(route_item, "use_auth"); + cJSON *auth_username = cJSON_GetObjectItem(route_item, "username"); + cJSON *auth_password = cJSON_GetObjectItem(route_item, "password"); + + if (cJSON_IsTrue(use_auth) && cJSON_IsString(auth_username) && cJSON_IsString(auth_password)) { + if (strlen(auth_username->valuestring) > 0 && strlen(auth_password->valuestring) > 0) { + route->use_auth = 1; + strncpy(route->username, auth_username->valuestring, sizeof(route->username) - 1); + route->username[sizeof(route->username) - 1] = '\0'; + compute_password_hash(auth_password->valuestring, route->password_hash, sizeof(route->password_hash)); + } + } + + log_info("Route configured: %s -> %s:%d (SSL: %s, Rewrite Host: %s, Auth: %s)", route->hostname, route->upstream_host, route->upstream_port, - route->use_ssl ? "yes" : "no", route->rewrite_host ? "yes" : "no"); + route->use_ssl ? "yes" : "no", route->rewrite_host ? "yes" : "no", + route->use_auth ? "yes" : "no"); i++; } } @@ -150,10 +255,144 @@ void config_create_default(const char *filename) { route_config_t *config_find_route(const char *hostname) { if (!hostname) return NULL; + pthread_rwlock_rdlock(&config_lock); + route_config_t *result = NULL; for (int i = 0; i < config.route_count; i++) { if (strcasecmp(hostname, config.routes[i].hostname) == 0) { - return &config.routes[i]; + result = &config.routes[i]; + break; } } - return NULL; + pthread_rwlock_unlock(&config_lock); + return result; +} + +int config_check_file_changed(const char *filename) { + struct stat st; + if (stat(filename, &st) != 0) { + return 0; + } + if (config_file_mtime == 0) { + config_file_mtime = st.st_mtime; + return 0; + } + if (st.st_mtime != config_file_mtime) { + config_file_mtime = st.st_mtime; + return 1; + } + return 0; +} + +int config_hot_reload(const char *filename) { + log_info("Hot-reloading configuration from %s", filename); + + app_config_t new_config; + memset(&new_config, 0, sizeof(app_config_t)); + + char *json_string = read_file_to_string(filename); + if (!json_string) { + log_error("Hot-reload: Could not read config file"); + return 0; + } + + cJSON *root = cJSON_Parse(json_string); + free(json_string); + if (!root) { + log_error("Hot-reload: JSON parse error: %s", cJSON_GetErrorPtr()); + return 0; + } + + cJSON *port_item = cJSON_GetObjectItem(root, "port"); + new_config.port = cJSON_IsNumber(port_item) ? port_item->valueint : 8080; + + if (new_config.port < 1 || new_config.port > 65535) { + log_error("Hot-reload: Invalid port number: %d", new_config.port); + cJSON_Delete(root); + return 0; + } + + cJSON *proxy_array = cJSON_GetObjectItem(root, "reverse_proxy"); + if (cJSON_IsArray(proxy_array)) { + new_config.route_count = cJSON_GetArraySize(proxy_array); + if (new_config.route_count <= 0) { + cJSON_Delete(root); + return 0; + } + + new_config.routes = calloc(new_config.route_count, sizeof(route_config_t)); + if (!new_config.routes) { + log_error("Hot-reload: Failed to allocate memory for routes"); + cJSON_Delete(root); + return 0; + } + + int i = 0; + cJSON *route_item; + cJSON_ArrayForEach(route_item, proxy_array) { + route_config_t *route = &new_config.routes[i]; + + cJSON *hostname = cJSON_GetObjectItem(route_item, "hostname"); + cJSON *upstream_host = cJSON_GetObjectItem(route_item, "upstream_host"); + cJSON *upstream_port = cJSON_GetObjectItem(route_item, "upstream_port"); + + if (!cJSON_IsString(hostname) || !cJSON_IsString(upstream_host) || !cJSON_IsNumber(upstream_port)) { + continue; + } + + if (!is_valid_host(hostname->valuestring) || !is_valid_host(upstream_host->valuestring)) { + continue; + } + + strncpy(route->hostname, hostname->valuestring, sizeof(route->hostname) - 1); + route->hostname[sizeof(route->hostname) - 1] = '\0'; + strncpy(route->upstream_host, upstream_host->valuestring, sizeof(route->upstream_host) - 1); + route->upstream_host[sizeof(route->upstream_host) - 1] = '\0'; + route->upstream_port = upstream_port->valueint; + + if (route->upstream_port < 1 || route->upstream_port > 65535) { + continue; + } + + route->use_ssl = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "use_ssl")); + route->rewrite_host = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "rewrite_host")); + + route->use_auth = 0; + route->username[0] = '\0'; + route->password_hash[0] = '\0'; + + cJSON *use_auth = cJSON_GetObjectItem(route_item, "use_auth"); + cJSON *auth_username = cJSON_GetObjectItem(route_item, "username"); + cJSON *auth_password = cJSON_GetObjectItem(route_item, "password"); + + if (cJSON_IsTrue(use_auth) && cJSON_IsString(auth_username) && cJSON_IsString(auth_password)) { + if (strlen(auth_username->valuestring) > 0 && strlen(auth_password->valuestring) > 0) { + route->use_auth = 1; + strncpy(route->username, auth_username->valuestring, sizeof(route->username) - 1); + route->username[sizeof(route->username) - 1] = '\0'; + compute_password_hash(auth_password->valuestring, route->password_hash, sizeof(route->password_hash)); + } + } + + log_info("Hot-reload route: %s -> %s:%d (SSL: %s, Auth: %s)", + route->hostname, route->upstream_host, route->upstream_port, + route->use_ssl ? "yes" : "no", route->use_auth ? "yes" : "no"); + i++; + } + new_config.route_count = i; + } + + cJSON_Delete(root); + + pthread_rwlock_wrlock(&config_lock); + route_config_t *old_routes = config.routes; + config.routes = new_config.routes; + config.route_count = new_config.route_count; + pthread_rwlock_unlock(&config_lock); + + if (old_routes) { + free(old_routes); + } + + log_info("Hot-reload complete: %d routes loaded", new_config.route_count); + return 1; } diff --git a/src/config.h b/src/config.h index 3c17c7c..8b3925a 100644 --- a/src/config.h +++ b/src/config.h @@ -3,11 +3,15 @@ #include "types.h" +#define CONFIG_RELOAD_INTERVAL_SECONDS 3 + extern app_config_t config; int config_load(const char *filename); void config_free(void); void config_create_default(const char *filename); route_config_t *config_find_route(const char *hostname); +int config_check_file_changed(const char *filename); +int config_hot_reload(const char *filename); #endif diff --git a/src/connection.c b/src/connection.c index 07f2c81..4a753d5 100644 --- a/src/connection.c +++ b/src/connection.c @@ -6,6 +6,7 @@ #include "http.h" #include "ssl_handler.h" #include "dashboard.h" +#include "auth.h" #include #include @@ -29,11 +30,17 @@ void connection_init_all(void) { } } -void connection_set_non_blocking(int fd) { +int connection_set_non_blocking(int fd) { int flags = fcntl(fd, F_GETFL, 0); - if (flags >= 0) { - fcntl(fd, F_SETFL, flags | O_NONBLOCK); + if (flags < 0) { + log_error("fcntl F_GETFL failed"); + return -1; } + if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0) { + log_error("fcntl F_SETFL failed"); + return -1; + } + return 0; } void connection_set_tcp_keepalive(int fd) { @@ -273,16 +280,22 @@ int connection_do_write(connection_t *conn) { void connection_send_error_response(connection_t *conn, int code, const char* status, const char* body) { if (!conn || !status || !body) return; - char response[2048]; + time_t now = time(NULL); + struct tm *gmt = gmtime(&now); + char date_buf[64]; + strftime(date_buf, sizeof(date_buf), "%a, %d %b %Y %H:%M:%S GMT", gmt); + + char response[ERROR_RESPONSE_SIZE]; int len = snprintf(response, sizeof(response), "HTTP/1.1 %d %s\r\n" "Content-Type: text/plain; charset=utf-8\r\n" "Content-Length: %zu\r\n" "Connection: close\r\n" + "Date: %s\r\n" "Server: ReverseProxy/4.0\r\n" "\r\n" "%s", - code, status, strlen(body), body); + code, status, strlen(body), date_buf, body); if (len > 0 && (size_t)len < sizeof(response)) { if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len) == 0) { @@ -298,6 +311,66 @@ void connection_send_error_response(connection_t *conn, int code, const char* st conn->request.keep_alive = 0; } +void connection_send_auth_required(connection_t *conn, const char *realm) { + if (!conn) return; + + time_t now = time(NULL); + struct tm *gmt = gmtime(&now); + char date_buf[64]; + strftime(date_buf, sizeof(date_buf), "%a, %d %b %Y %H:%M:%S GMT", gmt); + + const char *body = "401 Unauthorized - Authentication required"; + char response[ERROR_RESPONSE_SIZE]; + int len = snprintf(response, sizeof(response), + "HTTP/1.1 401 Unauthorized\r\n" + "Content-Type: text/plain; charset=utf-8\r\n" + "Content-Length: %zu\r\n" + "WWW-Authenticate: Basic realm=\"%s\"\r\n" + "Connection: close\r\n" + "Date: %s\r\n" + "Server: ReverseProxy/4.0\r\n" + "\r\n" + "%s", + strlen(body), realm ? realm : "Protected Area", date_buf, body); + + if (len > 0 && (size_t)len < sizeof(response)) { + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len) == 0) { + memcpy(conn->write_buf.data + conn->write_buf.tail, response, len); + conn->write_buf.tail += len; + + struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT }; + epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event); + } + } + + conn->state = CLIENT_STATE_ERROR; + conn->request.keep_alive = 0; +} + +static int try_upstream_connect(struct sockaddr_in *addr, int *out_fd) { + int up_fd = socket(AF_INET, SOCK_STREAM, 0); + if (up_fd < 0) { + return -1; + } + + if (up_fd >= MAX_FDS) { + close(up_fd); + return -1; + } + + connection_set_non_blocking(up_fd); + connection_set_tcp_keepalive(up_fd); + + int connect_result = connect(up_fd, (struct sockaddr*)addr, sizeof(*addr)); + if (connect_result < 0 && errno != EINPROGRESS) { + close(up_fd); + return -1; + } + + *out_fd = up_fd; + return 0; +} + void connection_connect_to_upstream(connection_t *client, const char *data, size_t data_len) { if (!client || !data) return; @@ -307,39 +380,48 @@ void connection_connect_to_upstream(connection_t *client, const char *data, size return; } - int up_fd = socket(AF_INET, SOCK_STREAM, 0); - if (up_fd < 0) { - connection_send_error_response(client, 502, "Bad Gateway", "Failed to create upstream socket"); - return; - } - - if (up_fd >= MAX_FDS) { - close(up_fd); - connection_send_error_response(client, 502, "Bad Gateway", "Connection limit exceeded"); - return; - } - struct sockaddr_in addr; memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_port = htons(route->upstream_port); if (inet_pton(AF_INET, route->upstream_host, &addr.sin_addr) <= 0) { - struct hostent *he = gethostbyname(route->upstream_host); - if (!he) { - close(up_fd); + struct addrinfo hints, *result; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + + int gai_err = getaddrinfo(route->upstream_host, NULL, &hints, &result); + if (gai_err != 0) { + log_debug("DNS resolution failed for %s: %s", route->upstream_host, gai_strerror(gai_err)); connection_send_error_response(client, 502, "Bad Gateway", "Cannot resolve upstream hostname"); return; } - memcpy(&addr.sin_addr, he->h_addr_list[0], he->h_length); + + struct sockaddr_in *resolved = (struct sockaddr_in *)result->ai_addr; + addr.sin_addr = resolved->sin_addr; + freeaddrinfo(result); } - connection_set_non_blocking(up_fd); - connection_set_tcp_keepalive(up_fd); + int up_fd = -1; + int retry_count = 0; - int connect_result = connect(up_fd, (struct sockaddr*)&addr, sizeof(addr)); - if (connect_result < 0 && errno != EINPROGRESS) { - close(up_fd); + while (retry_count < MAX_UPSTREAM_RETRIES) { + if (try_upstream_connect(&addr, &up_fd) == 0) { + break; + } + + retry_count++; + if (retry_count < MAX_UPSTREAM_RETRIES) { + log_debug("Upstream connection attempt %d failed for %s:%d, retrying...", + retry_count, route->upstream_host, route->upstream_port); + usleep(UPSTREAM_RETRY_DELAY_MS * 1000); + } + } + + if (up_fd < 0) { + log_debug("All %d connection attempts failed for %s:%d", + MAX_UPSTREAM_RETRIES, route->upstream_host, route->upstream_port); connection_send_error_response(client, 502, "Bad Gateway", "Failed to connect to upstream"); return; } @@ -356,9 +438,23 @@ void connection_connect_to_upstream(connection_t *client, const char *data, size up->pair = client; up->vhost_stats = client->vhost_stats; - if (buffer_init(&up->read_buf, CHUNK_SIZE) < 0 || - buffer_init(&up->write_buf, CHUNK_SIZE) < 0) { - connection_close(client->fd); + if (buffer_init(&up->read_buf, CHUNK_SIZE) < 0) { + close(up_fd); + memset(up, 0, sizeof(connection_t)); + up->type = CONN_TYPE_UNUSED; + up->fd = -1; + client->pair = NULL; + connection_send_error_response(client, 502, "Bad Gateway", "Memory allocation failed"); + return; + } + if (buffer_init(&up->write_buf, CHUNK_SIZE) < 0) { + buffer_free(&up->read_buf); + close(up_fd); + memset(up, 0, sizeof(connection_t)); + up->type = CONN_TYPE_UNUSED; + up->fd = -1; + client->pair = NULL; + connection_send_error_response(client, 502, "Bad Gateway", "Memory allocation failed"); return; } @@ -535,9 +631,9 @@ static void handle_client_read(connection_t *conn) { conn->state = CLIENT_STATE_SERVING_INTERNAL; if (strncmp(conn->request.uri, DASHBOARD_PATH, sizeof(DASHBOARD_PATH) - 1) == 0) { - dashboard_serve(conn); + dashboard_serve(conn, data_start, headers_len); } else { - dashboard_serve_stats_api(conn); + dashboard_serve_stats_api(conn, data_start, headers_len); } buffer_consume(buf, total_request_len); @@ -551,6 +647,21 @@ static void handle_client_read(connection_t *conn) { #undef DASHBOARD_PATH #undef STATS_PATH + route_config_t *route = config_find_route(conn->request.host); + if (route && route->use_auth) { + char auth_header[1024] = ""; + const char *headers_start = data_start + (strstr(data_start, "\r\n") - data_start + 2); + http_find_header_value(headers_start, headers_len - (headers_start - data_start), "Authorization", auth_header, sizeof(auth_header)); + + char error_msg[256] = ""; + if (!auth_check_route_basic_auth(route, strlen(auth_header) > 0 ? auth_header : NULL, error_msg, sizeof(error_msg))) { + log_info("[ROUTING-AUTH] Authentication failed for %s: %s", conn->request.host, error_msg); + connection_send_auth_required(conn, conn->request.host); + buffer_consume(buf, total_request_len); + return; + } + } + log_info("[ROUTING-FORWARD] Forwarding request for fd=%d: %s %s", conn->fd, conn->request.method, conn->request.uri); diff --git a/src/connection.h b/src/connection.h index 2b99669..919976c 100644 --- a/src/connection.h +++ b/src/connection.h @@ -14,7 +14,7 @@ void connection_close(int fd); void connection_handle_event(struct epoll_event *event); void connection_cleanup_idle(void); -void connection_set_non_blocking(int fd); +int connection_set_non_blocking(int fd); void connection_set_tcp_keepalive(int fd); void connection_add_to_epoll(int fd, uint32_t events); void connection_modify_epoll(int fd, uint32_t events); @@ -23,6 +23,7 @@ int connection_do_read(connection_t *conn); int connection_do_write(connection_t *conn); void connection_send_error_response(connection_t *conn, int code, const char* status, const char* body); +void connection_send_auth_required(connection_t *conn, const char *realm); void connection_connect_to_upstream(connection_t *client, const char *data, size_t data_len); #endif diff --git a/src/dashboard.c b/src/dashboard.c index cdf8284..7b3b9bc 100644 --- a/src/dashboard.c +++ b/src/dashboard.c @@ -2,11 +2,14 @@ #include "buffer.h" #include "monitor.h" #include "connection.h" +#include "auth.h" +#include "http.h" #include "../cJSON.h" #include #include #include #include +#include static const char *DASHBOARD_HTML = "\n" @@ -289,9 +292,59 @@ static const char *DASHBOARD_HTML = "\n" "\n"; -void dashboard_serve(connection_t *conn) { +static void send_unauthorized_response(connection_t *conn) { + time_t now = time(NULL); + struct tm *gmt = gmtime(&now); + char date_buf[64]; + strftime(date_buf, sizeof(date_buf), "%a, %d %b %Y %H:%M:%S GMT", gmt); + + const char *body = "Unauthorized"; + char header[1024]; + int len = snprintf(header, sizeof(header), + "HTTP/1.1 401 Unauthorized\r\n" + "Content-Type: text/plain; charset=utf-8\r\n" + "Content-Length: %zu\r\n" + "WWW-Authenticate: Basic realm=\"RProxy Dashboard\"\r\n" + "Date: %s\r\n" + "Connection: close\r\n" + "\r\n" + "%s", + strlen(body), date_buf, body); + + if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len) == 0) { + memcpy(conn->write_buf.data + conn->write_buf.tail, header, len); + conn->write_buf.tail += len; + + struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT }; + epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event); + } + + conn->state = CLIENT_STATE_ERROR; + conn->request.keep_alive = 0; +} + +static int check_dashboard_auth(connection_t *conn, const char *request_data, size_t request_len) { + if (!auth_is_enabled()) return 1; + + char auth_header[512] = ""; + http_find_header_value(request_data, request_len, "Authorization", auth_header, sizeof(auth_header)); + + char error_msg[256] = ""; + if (!auth_check_basic_auth(auth_header[0] ? auth_header : NULL, error_msg, sizeof(error_msg))) { + send_unauthorized_response(conn); + return 0; + } + + return 1; +} + +void dashboard_serve(connection_t *conn, const char *request_data, size_t request_len) { if (!conn) return; + if (!check_dashboard_auth(conn, request_data, request_len)) { + return; + } + size_t content_len = strlen(DASHBOARD_HTML); char header[512]; int len = snprintf(header, sizeof(header), @@ -320,7 +373,8 @@ void dashboard_serve(connection_t *conn) { static cJSON* format_history(history_deque_t *dq, int window_seconds) { cJSON *arr = cJSON_CreateArray(); - if (!arr || !dq || !dq->points || dq->count == 0) return arr; + if (!arr) return NULL; + if (!dq || !dq->points || dq->count == 0) return arr; double current_time = time(NULL); int start_index = (dq->head - dq->count + dq->capacity) % dq->capacity; @@ -342,7 +396,8 @@ static cJSON* format_history(history_deque_t *dq, int window_seconds) { static cJSON* format_network_history(network_history_deque_t *dq, int window_seconds, const char *key) { cJSON *arr = cJSON_CreateArray(); - if (!arr || !dq || !dq->points || !key || dq->count == 0) return arr; + if (!arr) return NULL; + if (!dq || !dq->points || !key || dq->count == 0) return arr; double current_time = time(NULL); int start_index = (dq->head - dq->count + dq->capacity) % dq->capacity; @@ -364,7 +419,8 @@ static cJSON* format_network_history(network_history_deque_t *dq, int window_sec static cJSON* format_disk_history(disk_history_deque_t *dq, int window_seconds, const char *key) { cJSON *arr = cJSON_CreateArray(); - if (!arr || !dq || !dq->points || !key || dq->count == 0) return arr; + if (!arr) return NULL; + if (!dq || !dq->points || !key || dq->count == 0) return arr; double current_time = time(NULL); int start_index = (dq->head - dq->count + dq->capacity) % dq->capacity; @@ -384,9 +440,13 @@ static cJSON* format_disk_history(disk_history_deque_t *dq, int window_seconds, return arr; } -void dashboard_serve_stats_api(connection_t *conn) { +void dashboard_serve_stats_api(connection_t *conn, const char *request_data, size_t request_len) { if (!conn) return; + if (!check_dashboard_auth(conn, request_data, request_len)) { + return; + } + cJSON *root = cJSON_CreateObject(); if (!root) { connection_send_error_response(conn, 500, "Internal Server Error", "JSON creation failed"); diff --git a/src/dashboard.h b/src/dashboard.h index bf06406..07345a2 100644 --- a/src/dashboard.h +++ b/src/dashboard.h @@ -3,7 +3,7 @@ #include "types.h" -void dashboard_serve(connection_t *conn); -void dashboard_serve_stats_api(connection_t *conn); +void dashboard_serve(connection_t *conn, const char *request_data, size_t request_len); +void dashboard_serve_stats_api(connection_t *conn, const char *request_data, size_t request_len); #endif diff --git a/src/health_check.c b/src/health_check.c new file mode 100644 index 0000000..92983e7 --- /dev/null +++ b/src/health_check.c @@ -0,0 +1,188 @@ +#include "health_check.h" +#include "logging.h" +#include "config.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +typedef struct { + char hostname[256]; + char upstream_host[256]; + int upstream_port; + int healthy; + int consecutive_failures; + time_t last_check; +} upstream_health_t; + +static upstream_health_t *health_states = NULL; +static int health_state_count = 0; +static pthread_mutex_t health_mutex = PTHREAD_MUTEX_INITIALIZER; +static int g_health_check_enabled = 0; + +void health_check_init(void) { + pthread_mutex_lock(&health_mutex); + + if (health_states) { + free(health_states); + } + + health_state_count = config.route_count; + if (health_state_count <= 0) { + health_states = NULL; + pthread_mutex_unlock(&health_mutex); + return; + } + + health_states = calloc(health_state_count, sizeof(upstream_health_t)); + if (!health_states) { + health_state_count = 0; + pthread_mutex_unlock(&health_mutex); + return; + } + + for (int i = 0; i < health_state_count; i++) { + strncpy(health_states[i].hostname, config.routes[i].hostname, sizeof(health_states[i].hostname) - 1); + strncpy(health_states[i].upstream_host, config.routes[i].upstream_host, sizeof(health_states[i].upstream_host) - 1); + health_states[i].upstream_port = config.routes[i].upstream_port; + health_states[i].healthy = 1; + health_states[i].consecutive_failures = 0; + health_states[i].last_check = 0; + } + + g_health_check_enabled = 1; + log_info("Health check initialized for %d upstreams", health_state_count); + + pthread_mutex_unlock(&health_mutex); +} + +void health_check_cleanup(void) { + pthread_mutex_lock(&health_mutex); + if (health_states) { + free(health_states); + health_states = NULL; + } + health_state_count = 0; + g_health_check_enabled = 0; + pthread_mutex_unlock(&health_mutex); +} + +static int check_tcp_connection(const char *host, int port, int timeout_ms) { + struct addrinfo hints, *result; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + + if (getaddrinfo(host, NULL, &hints, &result) != 0) { + return 0; + } + + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + freeaddrinfo(result); + return 0; + } + + int flags = fcntl(fd, F_GETFL, 0); + fcntl(fd, F_SETFL, flags | O_NONBLOCK); + + struct sockaddr_in *addr = (struct sockaddr_in *)result->ai_addr; + addr->sin_port = htons(port); + + int connect_result = connect(fd, (struct sockaddr *)addr, sizeof(struct sockaddr_in)); + freeaddrinfo(result); + + if (connect_result == 0) { + close(fd); + return 1; + } + + if (errno != EINPROGRESS) { + close(fd); + return 0; + } + + struct pollfd pfd; + pfd.fd = fd; + pfd.events = POLLOUT; + + int poll_result = poll(&pfd, 1, timeout_ms); + if (poll_result <= 0) { + close(fd); + return 0; + } + + int error = 0; + socklen_t len = sizeof(error); + getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &len); + close(fd); + + return error == 0; +} + +void health_check_run(void) { + if (!g_health_check_enabled) return; + + pthread_mutex_lock(&health_mutex); + + time_t now = time(NULL); + + for (int i = 0; i < health_state_count; i++) { + if (now - health_states[i].last_check < HEALTH_CHECK_INTERVAL_SECONDS) { + continue; + } + + health_states[i].last_check = now; + + int is_healthy = check_tcp_connection( + health_states[i].upstream_host, + health_states[i].upstream_port, + HEALTH_CHECK_TIMEOUT_MS + ); + + if (is_healthy) { + if (!health_states[i].healthy) { + log_info("Upstream %s:%d is now healthy", + health_states[i].upstream_host, + health_states[i].upstream_port); + } + health_states[i].healthy = 1; + health_states[i].consecutive_failures = 0; + } else { + health_states[i].consecutive_failures++; + if (health_states[i].consecutive_failures >= 3 && health_states[i].healthy) { + log_info("Upstream %s:%d is now unhealthy (failures: %d)", + health_states[i].upstream_host, + health_states[i].upstream_port, + health_states[i].consecutive_failures); + health_states[i].healthy = 0; + } + } + } + + pthread_mutex_unlock(&health_mutex); +} + +int health_check_is_healthy(const char *hostname) { + if (!g_health_check_enabled || !hostname) return 1; + + pthread_mutex_lock(&health_mutex); + + for (int i = 0; i < health_state_count; i++) { + if (strcasecmp(health_states[i].hostname, hostname) == 0) { + int result = health_states[i].healthy; + pthread_mutex_unlock(&health_mutex); + return result; + } + } + + pthread_mutex_unlock(&health_mutex); + return 1; +} diff --git a/src/health_check.h b/src/health_check.h new file mode 100644 index 0000000..6724d49 --- /dev/null +++ b/src/health_check.h @@ -0,0 +1,11 @@ +#ifndef RPROXY_HEALTH_CHECK_H +#define RPROXY_HEALTH_CHECK_H + +#include "types.h" + +void health_check_init(void); +void health_check_cleanup(void); +void health_check_run(void); +int health_check_is_healthy(const char *hostname); + +#endif diff --git a/src/http.c b/src/http.c index 3e7f25f..4735ccd 100644 --- a/src/http.c +++ b/src/http.c @@ -124,7 +124,11 @@ int http_parse_request(const char *data, size_t len, http_request_t *req) { } if (http_find_header_value(headers_start, len - (headers_start - data), "Content-Length", value, sizeof(value))) { - req->content_length = atol(value); + char *endptr; + long parsed = strtol(value, &endptr, 10); + if (endptr != value && *endptr == '\0' && parsed >= 0) { + req->content_length = parsed; + } } if (http_find_header_value(headers_start, len - (headers_start - data), "Transfer-Encoding", value, sizeof(value))) { diff --git a/src/logging.c b/src/logging.c index 7efc5fc..40b93c5 100644 --- a/src/logging.c +++ b/src/logging.c @@ -2,8 +2,13 @@ #include #include #include +#include +#include +#include static int g_debug_mode = 0; +static FILE *g_log_file = NULL; +static pthread_mutex_t log_mutex = PTHREAD_MUTEX_INITIALIZER; void logging_set_debug(int enabled) { g_debug_mode = enabled; @@ -13,19 +18,77 @@ int logging_get_debug(void) { return g_debug_mode; } -void log_error(const char *msg) { - perror(msg); +int logging_set_file(const char *path) { + pthread_mutex_lock(&log_mutex); + if (g_log_file && g_log_file != stdout && g_log_file != stderr) { + fclose(g_log_file); + } + if (path) { + g_log_file = fopen(path, "a"); + if (!g_log_file) { + g_log_file = stdout; + pthread_mutex_unlock(&log_mutex); + return -1; + } + } else { + g_log_file = stdout; + } + pthread_mutex_unlock(&log_mutex); + return 0; +} + +void logging_cleanup(void) { + pthread_mutex_lock(&log_mutex); + if (g_log_file && g_log_file != stdout && g_log_file != stderr) { + fclose(g_log_file); + } + g_log_file = NULL; + pthread_mutex_unlock(&log_mutex); } static void log_message(const char *level, const char *format, va_list args) { + pthread_mutex_lock(&log_mutex); + + FILE *out = g_log_file ? g_log_file : stdout; time_t now; time(&now); + struct tm *local = localtime(&now); char buf[32]; - strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", localtime(&now)); - printf("%s - %-5s - ", buf, level); - vprintf(format, args); - printf("\n"); - fflush(stdout); + strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", local); + fprintf(out, "%s - %-5s - ", buf, level); + vfprintf(out, format, args); + fprintf(out, "\n"); + fflush(out); + + pthread_mutex_unlock(&log_mutex); +} + +void log_error(const char *format, ...) { + va_list args; + va_start(args, format); + + int saved_errno = errno; + char msg[1024]; + vsnprintf(msg, sizeof(msg), format, args); + va_end(args); + + pthread_mutex_lock(&log_mutex); + + FILE *out = g_log_file ? g_log_file : stderr; + time_t now; + time(&now); + struct tm *local = localtime(&now); + char buf[32]; + strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", local); + + if (saved_errno != 0) { + fprintf(out, "%s - ERROR - %s: %s\n", buf, msg, strerror(saved_errno)); + } else { + fprintf(out, "%s - ERROR - %s\n", buf, msg); + } + fflush(out); + + pthread_mutex_unlock(&log_mutex); } void log_info(const char *format, ...) { diff --git a/src/logging.h b/src/logging.h index 2847dea..5b4e36e 100644 --- a/src/logging.h +++ b/src/logging.h @@ -1,10 +1,12 @@ #ifndef RPROXY_LOGGING_H #define RPROXY_LOGGING_H -void log_error(const char *msg); +void log_error(const char *format, ...); void log_info(const char *format, ...); void log_debug(const char *format, ...); void logging_set_debug(int enabled); int logging_get_debug(void); +int logging_set_file(const char *path); +void logging_cleanup(void); #endif diff --git a/src/main.c b/src/main.c index e3dd134..ce42e9f 100644 --- a/src/main.c +++ b/src/main.c @@ -6,6 +6,7 @@ #include #include #include +#include #include "types.h" #include "logging.h" @@ -13,13 +14,82 @@ #include "monitor.h" #include "ssl_handler.h" #include "connection.h" +#include "rate_limit.h" +#include "auth.h" +#include "health_check.h" -static volatile int g_shutdown = 0; +static volatile sig_atomic_t g_shutdown = 0; +static volatile sig_atomic_t g_reload_config = 0; +static const char *g_config_file = NULL; static void signal_handler(int sig) { if (sig == SIGINT || sig == SIGTERM) { - log_info("Received signal %d, shutting down...", sig); g_shutdown = 1; + } else if (sig == SIGHUP) { + g_reload_config = 1; + } +} + +static void reload_configuration(void) { + if (!g_config_file) return; + + log_info("Reloading configuration from %s", g_config_file); + + app_config_t old_config = config; + memset(&config, 0, sizeof(app_config_t)); + + if (!config_load(g_config_file)) { + log_error("Failed to reload configuration, keeping old config"); + config = old_config; + return; + } + + if (old_config.routes) { + free(old_config.routes); + } + + log_info("Configuration reloaded successfully"); +} + +static void graceful_shutdown(void) { + log_info("Initiating graceful shutdown..."); + + for (int i = 0; i < MAX_FDS; i++) { + connection_t *conn = &connections[i]; + if (conn->type == CONN_TYPE_LISTENER && conn->fd != -1) { + epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->fd, NULL); + close(conn->fd); + conn->fd = -1; + conn->type = CONN_TYPE_UNUSED; + log_info("Stopped accepting new connections"); + } + } + + int active_count = 0; + int drain_timeout = 30; + time_t drain_start = time(NULL); + + do { + active_count = 0; + for (int i = 0; i < MAX_FDS; i++) { + if (connections[i].type == CONN_TYPE_CLIENT || connections[i].type == CONN_TYPE_UPSTREAM) { + if (connections[i].fd != -1) { + active_count++; + } + } + } + + if (active_count > 0 && (time(NULL) - drain_start) < drain_timeout) { + struct epoll_event events[MAX_EVENTS]; + int n = epoll_wait(epoll_fd, events, MAX_EVENTS, 100); + for (int i = 0; i < n; i++) { + connection_handle_event(&events[i]); + } + } + } while (active_count > 0 && (time(NULL) - drain_start) < drain_timeout); + + if (active_count > 0) { + log_info("Drain timeout reached, forcing close of %d connections", active_count); } } @@ -34,6 +104,9 @@ static void cleanup(void) { config_free(); monitor_cleanup(); + rate_limit_cleanup(); + health_check_cleanup(); + logging_cleanup(); if (epoll_fd >= 0) { close(epoll_fd); @@ -49,23 +122,62 @@ int main(int argc, char *argv[]) { signal(SIGPIPE, SIG_IGN); signal(SIGINT, signal_handler); signal(SIGTERM, signal_handler); + signal(SIGHUP, signal_handler); if (getenv("DEBUG")) { logging_set_debug(1); log_info("Debug mode enabled"); } - const char *config_file = (argc > 1) ? argv[1] : "proxy_config.json"; - config_create_default(config_file); + const char *log_file = getenv("LOG_FILE"); + if (log_file) { + if (logging_set_file(log_file) == 0) { + log_info("Logging to file: %s", log_file); + } + } - if (!config_load(config_file)) { + g_config_file = (argc > 1) ? argv[1] : "proxy_config.json"; + config_create_default(g_config_file); + + if (!config_load(g_config_file)) { fprintf(stderr, "Failed to load configuration\n"); return 1; } + const char *ssl_verify = getenv("SSL_VERIFY"); + if (ssl_verify && strcmp(ssl_verify, "0") == 0) { + ssl_set_verify(0); + } + + const char *ca_file = getenv("SSL_CA_FILE"); + if (ca_file) { + ssl_set_ca_file(ca_file); + } + + const char *ca_path = getenv("SSL_CA_PATH"); + if (ca_path) { + ssl_set_ca_path(ca_path); + } + ssl_init(); monitor_init("proxy_stats.db"); + const char *rate_limit_str = getenv("RATE_LIMIT"); + if (rate_limit_str) { + int rate = atoi(rate_limit_str); + if (rate > 0) { + rate_limit_init(rate, RATE_LIMIT_WINDOW_SECONDS); + } + } + + const char *auth_user = getenv("DASHBOARD_USER"); + const char *auth_pass = getenv("DASHBOARD_PASS"); + if (auth_user && auth_pass) { + auth_init(auth_user, auth_pass); + } + + health_check_init(); + epoll_fd = epoll_create1(EPOLL_CLOEXEC); if (epoll_fd == -1) { log_error("epoll_create1 failed"); @@ -78,14 +190,21 @@ int main(int argc, char *argv[]) { log_info("Port %d", config.port); log_info("Dashboard: http://localhost:%d/rproxy/dashboard", config.port); log_info("Stats: http://localhost:%d/rproxy/api/stats", config.port); + log_info("Send SIGHUP to reload configuration"); atexit(cleanup); struct epoll_event events[MAX_EVENTS]; time_t last_monitor_update = 0; time_t last_cleanup = 0; + time_t last_config_check = 0; while (!g_shutdown) { + if (g_reload_config) { + g_reload_config = 0; + reload_configuration(); + } + int n = epoll_wait(epoll_fd, events, MAX_EVENTS, 1000); if (n == -1) { if (errno == EINTR) continue; @@ -104,12 +223,24 @@ int main(int argc, char *argv[]) { last_monitor_update = current_time; } + if (current_time - last_config_check >= CONFIG_RELOAD_INTERVAL_SECONDS) { + if (config_check_file_changed(g_config_file)) { + config_hot_reload(g_config_file); + } + last_config_check = current_time; + } + if (current_time - last_cleanup >= 60) { connection_cleanup_idle(); + rate_limit_purge_expired(); last_cleanup = current_time; } + + health_check_run(); } + log_info("Received shutdown signal"); + graceful_shutdown(); log_info("Shutdown complete"); return 0; } diff --git a/src/monitor.c b/src/monitor.c index 62b7d08..79eefb3 100644 --- a/src/monitor.c +++ b/src/monitor.c @@ -5,8 +5,10 @@ #include #include #include +#include system_monitor_t monitor; +static pthread_mutex_t vhost_stats_mutex = PTHREAD_MUTEX_INITIALIZER; void history_deque_init(history_deque_t *dq, int capacity) { dq->points = calloc(capacity, sizeof(history_point_t)); @@ -300,8 +302,11 @@ static void get_disk_stats(long long *sectors_read, long long *sectors_written) if (nfields >= 11) { strncpy(device, dev, sizeof(device)-1); device[sizeof(device)-1] = '\0'; - sectors_r = atoll(sr); - sectors_w = atoll(sw); + char *endptr; + sectors_r = strtoll(sr, &endptr, 10); + if (endptr == sr) sectors_r = 0; + sectors_w = strtoll(sw, &endptr, 10); + if (endptr == sw) sectors_w = 0; if (strncmp(device, "loop", 4) != 0 && strncmp(device, "ram", 3) != 0) { int len = strlen(device); @@ -398,12 +403,20 @@ void monitor_update(void) { vhost_stats_t* monitor_get_or_create_vhost_stats(const char *vhost_name) { if (!vhost_name || strlen(vhost_name) == 0) return NULL; + pthread_mutex_lock(&vhost_stats_mutex); + for (vhost_stats_t *curr = monitor.vhost_stats_head; curr; curr = curr->next) { - if (strcmp(curr->vhost_name, vhost_name) == 0) return curr; + if (strcmp(curr->vhost_name, vhost_name) == 0) { + pthread_mutex_unlock(&vhost_stats_mutex); + return curr; + } } vhost_stats_t *new_stats = calloc(1, sizeof(vhost_stats_t)); - if (!new_stats) return NULL; + if (!new_stats) { + pthread_mutex_unlock(&vhost_stats_mutex); + return NULL; + } strncpy(new_stats->vhost_name, vhost_name, sizeof(new_stats->vhost_name) - 1); new_stats->last_update = time(NULL); @@ -411,6 +424,8 @@ vhost_stats_t* monitor_get_or_create_vhost_stats(const char *vhost_name) { request_time_deque_init(&new_stats->request_times, 100); new_stats->next = monitor.vhost_stats_head; monitor.vhost_stats_head = new_stats; + + pthread_mutex_unlock(&vhost_stats_mutex); return new_stats; } diff --git a/src/rate_limit.c b/src/rate_limit.c new file mode 100644 index 0000000..ec4b324 --- /dev/null +++ b/src/rate_limit.c @@ -0,0 +1,126 @@ +#include "rate_limit.h" +#include "logging.h" +#include +#include +#include +#include + +#define MAX_RATE_LIMIT_ENTRIES 10000 + +typedef struct rate_limit_entry { + char client_ip[64]; + int request_count; + time_t window_start; + struct rate_limit_entry *next; +} rate_limit_entry_t; + +static rate_limit_entry_t *rate_limit_table[256]; +static pthread_mutex_t rate_limit_mutex = PTHREAD_MUTEX_INITIALIZER; +static int g_rate_limit_enabled = 0; +static int g_requests_per_window = DEFAULT_RATE_LIMIT_REQUESTS; +static int g_window_seconds = RATE_LIMIT_WINDOW_SECONDS; + +static unsigned int hash_ip(const char *ip) { + unsigned int hash = 0; + while (*ip) { + hash = hash * 31 + (unsigned char)*ip++; + } + return hash % 256; +} + +void rate_limit_init(int requests_per_window, int window_seconds) { + g_rate_limit_enabled = 1; + g_requests_per_window = requests_per_window; + g_window_seconds = window_seconds; + memset(rate_limit_table, 0, sizeof(rate_limit_table)); + log_info("Rate limiting enabled: %d requests per %d seconds", requests_per_window, window_seconds); +} + +void rate_limit_cleanup(void) { + pthread_mutex_lock(&rate_limit_mutex); + for (int i = 0; i < 256; i++) { + rate_limit_entry_t *entry = rate_limit_table[i]; + while (entry) { + rate_limit_entry_t *next = entry->next; + free(entry); + entry = next; + } + rate_limit_table[i] = NULL; + } + pthread_mutex_unlock(&rate_limit_mutex); +} + +int rate_limit_check(const char *client_ip) { + if (!g_rate_limit_enabled || !client_ip) return 1; + + pthread_mutex_lock(&rate_limit_mutex); + + time_t now = time(NULL); + unsigned int bucket = hash_ip(client_ip); + rate_limit_entry_t *entry = rate_limit_table[bucket]; + + while (entry) { + if (strcmp(entry->client_ip, client_ip) == 0) { + if (now - entry->window_start >= g_window_seconds) { + entry->window_start = now; + entry->request_count = 1; + pthread_mutex_unlock(&rate_limit_mutex); + return 1; + } + + entry->request_count++; + if (entry->request_count > g_requests_per_window) { + pthread_mutex_unlock(&rate_limit_mutex); + return 0; + } + + pthread_mutex_unlock(&rate_limit_mutex); + return 1; + } + entry = entry->next; + } + + rate_limit_entry_t *new_entry = calloc(1, sizeof(rate_limit_entry_t)); + if (!new_entry) { + pthread_mutex_unlock(&rate_limit_mutex); + return 1; + } + + strncpy(new_entry->client_ip, client_ip, sizeof(new_entry->client_ip) - 1); + new_entry->request_count = 1; + new_entry->window_start = now; + new_entry->next = rate_limit_table[bucket]; + rate_limit_table[bucket] = new_entry; + + pthread_mutex_unlock(&rate_limit_mutex); + return 1; +} + +void rate_limit_purge_expired(void) { + if (!g_rate_limit_enabled) return; + + pthread_mutex_lock(&rate_limit_mutex); + + time_t now = time(NULL); + for (int i = 0; i < 256; i++) { + rate_limit_entry_t *entry = rate_limit_table[i]; + rate_limit_entry_t *prev = NULL; + + while (entry) { + rate_limit_entry_t *next = entry->next; + if (now - entry->window_start >= g_window_seconds * 2) { + if (prev) { + prev->next = next; + } else { + rate_limit_table[i] = next; + } + free(entry); + } else { + prev = entry; + } + entry = next; + } + } + + pthread_mutex_unlock(&rate_limit_mutex); +} diff --git a/src/rate_limit.h b/src/rate_limit.h new file mode 100644 index 0000000..8384a2a --- /dev/null +++ b/src/rate_limit.h @@ -0,0 +1,11 @@ +#ifndef RPROXY_RATE_LIMIT_H +#define RPROXY_RATE_LIMIT_H + +#include "types.h" + +void rate_limit_init(int requests_per_window, int window_seconds); +void rate_limit_cleanup(void); +int rate_limit_check(const char *client_ip); +void rate_limit_purge_expired(void); + +#endif diff --git a/src/ssl_handler.c b/src/ssl_handler.c index 9cacc6f..9b69b42 100644 --- a/src/ssl_handler.c +++ b/src/ssl_handler.c @@ -1,23 +1,68 @@ #include "ssl_handler.h" #include "logging.h" #include +#include #include +#include SSL_CTX *ssl_ctx = NULL; +static int g_ssl_verify_enabled = 1; +static char g_ca_file[512] = ""; +static char g_ca_path[512] = ""; + +void ssl_set_verify(int enabled) { + g_ssl_verify_enabled = enabled; +} + +void ssl_set_ca_file(const char *path) { + if (path) { + strncpy(g_ca_file, path, sizeof(g_ca_file) - 1); + g_ca_file[sizeof(g_ca_file) - 1] = '\0'; + } +} + +void ssl_set_ca_path(const char *path) { + if (path) { + strncpy(g_ca_path, path, sizeof(g_ca_path) - 1); + g_ca_path[sizeof(g_ca_path) - 1] = '\0'; + } +} void ssl_init(void) { - SSL_load_error_strings(); - OpenSSL_add_ssl_algorithms(); ssl_ctx = SSL_CTX_new(TLS_client_method()); if (!ssl_ctx) { log_error("Failed to create SSL context"); exit(EXIT_FAILURE); } - SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, NULL); - SSL_CTX_set_options(ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3); + if (g_ssl_verify_enabled) { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, NULL); + SSL_CTX_set_verify_depth(ssl_ctx, 4); + + int ca_loaded = 0; + if (g_ca_file[0] != '\0' || g_ca_path[0] != '\0') { + const char *file = g_ca_file[0] != '\0' ? g_ca_file : NULL; + const char *path = g_ca_path[0] != '\0' ? g_ca_path : NULL; + if (SSL_CTX_load_verify_locations(ssl_ctx, file, path) == 1) { + ca_loaded = 1; + log_info("Loaded CA certificates from custom location"); + } + } + + if (!ca_loaded) { + if (SSL_CTX_set_default_verify_paths(ssl_ctx) != 1) { + log_info("Warning: Could not load default CA certificates"); + } else { + log_info("Loaded system default CA certificates"); + } + } + } else { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, NULL); + log_info("Warning: SSL certificate verification disabled"); + } + + SSL_CTX_set_options(ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1); SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); - SSL_CTX_set_verify_depth(ssl_ctx, 0); } void ssl_cleanup(void) { @@ -25,7 +70,6 @@ void ssl_cleanup(void) { SSL_CTX_free(ssl_ctx); ssl_ctx = NULL; } - EVP_cleanup(); } int ssl_do_handshake(connection_t *conn) { diff --git a/src/ssl_handler.h b/src/ssl_handler.h index 1129b8c..2a4b823 100644 --- a/src/ssl_handler.h +++ b/src/ssl_handler.h @@ -5,6 +5,9 @@ extern SSL_CTX *ssl_ctx; +void ssl_set_verify(int enabled); +void ssl_set_ca_file(const char *path); +void ssl_set_ca_path(const char *path); void ssl_init(void); void ssl_cleanup(void); int ssl_do_handshake(connection_t *conn); diff --git a/src/types.h b/src/types.h index cdb75ea..9ad3a62 100644 --- a/src/types.h +++ b/src/types.h @@ -18,6 +18,16 @@ #define MAX_REQUEST_LINE_SIZE 4096 #define MAX_URI_SIZE 2048 #define CONNECTION_TIMEOUT 300 +#define ERROR_RESPONSE_SIZE 4096 +#define HOST_HEADER_SIZE 512 +#define MAX_BUFFER_SIZE (64 * 1024 * 1024) +#define MIN_DATA_FOR_REQUEST_CHECK 1 +#define RATE_LIMIT_WINDOW_SECONDS 60 +#define DEFAULT_RATE_LIMIT_REQUESTS 1000 +#define HEALTH_CHECK_INTERVAL_SECONDS 30 +#define HEALTH_CHECK_TIMEOUT_MS 5000 +#define MAX_UPSTREAM_RETRIES 3 +#define UPSTREAM_RETRY_DELAY_MS 100 typedef enum { CONN_TYPE_UNUSED, @@ -79,6 +89,9 @@ typedef struct { int upstream_port; int use_ssl; int rewrite_host; + int use_auth; + char username[128]; + char password_hash[256]; } route_config_t; typedef struct {