Update.
All checks were successful
Build and Test / build (push) Successful in 33s

This commit is contained in:
retoor 2025-11-29 04:58:34 +01:00
parent 01830e2eb2
commit 833c78debc
23 changed files with 1404 additions and 72 deletions

View File

@ -1,6 +1,6 @@
CC = gcc
CFLAGS = -Wall -Wextra -O2 -g -D_GNU_SOURCE
LDFLAGS = -lssl -lcrypto -lsqlite3 -lm
LDFLAGS = -lssl -lcrypto -lsqlite3 -lm -lpthread
SRC_DIR = src
BUILD_DIR = build
@ -15,6 +15,9 @@ SOURCES = $(SRC_DIR)/main.c \
$(SRC_DIR)/ssl_handler.c \
$(SRC_DIR)/connection.c \
$(SRC_DIR)/dashboard.c \
$(SRC_DIR)/rate_limit.c \
$(SRC_DIR)/auth.c \
$(SRC_DIR)/health_check.c \
cJSON.c
OBJECTS = $(patsubst %.c,$(BUILD_DIR)/%.o,$(notdir $(SOURCES)))
@ -37,6 +40,9 @@ TEST_LIB_SOURCES = $(SRC_DIR)/buffer.c \
$(SRC_DIR)/ssl_handler.c \
$(SRC_DIR)/connection.c \
$(SRC_DIR)/dashboard.c \
$(SRC_DIR)/rate_limit.c \
$(SRC_DIR)/auth.c \
$(SRC_DIR)/health_check.c \
cJSON.c
TEST_LIB_OBJECTS = $(patsubst %.c,$(BUILD_DIR)/%.o,$(notdir $(TEST_LIB_SOURCES)))
@ -80,6 +86,15 @@ $(BUILD_DIR)/connection.o: $(SRC_DIR)/connection.c
$(BUILD_DIR)/dashboard.o: $(SRC_DIR)/dashboard.c
$(CC) $(CFLAGS) -c $< -o $@
$(BUILD_DIR)/rate_limit.o: $(SRC_DIR)/rate_limit.c
$(CC) $(CFLAGS) -c $< -o $@
$(BUILD_DIR)/auth.o: $(SRC_DIR)/auth.c
$(CC) $(CFLAGS) -c $< -o $@
$(BUILD_DIR)/health_check.o: $(SRC_DIR)/health_check.c
$(CC) $(CFLAGS) -c $< -o $@
$(BUILD_DIR)/cJSON.o: cJSON.c
$(CC) $(CFLAGS) -c $< -o $@

View File

@ -5,19 +5,27 @@ rproxy is a high-performance reverse proxy server written in C. It routes HTTP a
## Features
- Reverse proxy routing by hostname
- SSL/TLS support for upstream connections
- SSL/TLS support for upstream connections with certificate verification
- WebSocket proxying
- Connection pooling and idle timeout management
- Real-time monitoring and statistics
- Web-based dashboard for metrics visualization
- SQLite-based persistent statistics storage
- Epoll-based event handling for high concurrency
- Graceful shutdown with connection draining
- Live configuration reload via SIGHUP
- Dashboard authentication (HTTP Basic Auth)
- Rate limiting per client IP
- Health checks for upstream servers
- Automatic upstream connection retries
- File logging support
## Dependencies
- GCC
- OpenSSL (libssl, libcrypto)
- SQLite3
- pthreads
- cJSON library
## Build
@ -55,6 +63,19 @@ Configuration is defined in `proxy_config.json`:
- `use_ssl`: Enable SSL for upstream connection
- `rewrite_host`: Rewrite Host header to upstream hostname
## Environment Variables
| Variable | Description |
|----------|-------------|
| `DEBUG` | Enable debug logging (set to `1`) |
| `LOG_FILE` | Path to log file (default: stdout) |
| `RATE_LIMIT` | Max requests per minute per IP |
| `DASHBOARD_USER` | Dashboard authentication username |
| `DASHBOARD_PASS` | Dashboard authentication password |
| `SSL_VERIFY` | Disable SSL verification (set to `0`) |
| `SSL_CA_FILE` | Path to custom CA certificate file |
| `SSL_CA_PATH` | Path to CA certificate directory |
## Usage
```bash
@ -63,11 +84,44 @@ Configuration is defined in `proxy_config.json`:
If no config file is specified, defaults to `proxy_config.json`.
Examples:
```bash
# Basic usage
./rproxy
# With custom config
./rproxy /etc/rproxy/config.json
# With debug logging
DEBUG=1 ./rproxy
# With file logging
LOG_FILE=/var/log/rproxy.log ./rproxy
# With rate limiting (100 requests/minute)
RATE_LIMIT=100 ./rproxy
# With dashboard authentication
DASHBOARD_USER=admin DASHBOARD_PASS=secret ./rproxy
# Reload configuration
kill -HUP $(pidof rproxy)
```
## Endpoints
- Dashboard: `http://localhost:{port}/rproxy/dashboard`
- API Stats: `http://localhost:{port}/rproxy/api/stats`
## Signals
| Signal | Action |
|--------|--------|
| `SIGINT` | Graceful shutdown |
| `SIGTERM` | Graceful shutdown |
| `SIGHUP` | Reload configuration |
## Architecture
- **main.c**: Entry point, event loop, signal handling
@ -79,6 +133,9 @@ If no config file is specified, defaults to `proxy_config.json`.
- **config.c**: JSON configuration parsing
- **buffer.c**: Circular buffer implementation
- **logging.c**: Logging utilities
- **rate_limit.c**: Per-IP rate limiting
- **auth.c**: Dashboard authentication
- **health_check.c**: Upstream health monitoring
## Testing

196
src/auth.c Normal file
View File

@ -0,0 +1,196 @@
#include "auth.h"
#include "logging.h"
#include <stdlib.h>
#include <string.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
static char g_dashboard_username[128] = "";
static char g_dashboard_password_hash[256] = "";
static int g_auth_enabled = 0;
static void compute_sha256(const char *input, char *output, size_t output_size) {
EVP_MD_CTX *ctx = EVP_MD_CTX_new();
if (!ctx) return;
unsigned char hash[EVP_MAX_MD_SIZE];
unsigned int hash_len = 0;
EVP_DigestInit_ex(ctx, EVP_sha256(), NULL);
EVP_DigestUpdate(ctx, input, strlen(input));
EVP_DigestFinal_ex(ctx, hash, &hash_len);
EVP_MD_CTX_free(ctx);
for (unsigned int i = 0; i < hash_len && (i * 2 + 2) < output_size; i++) {
snprintf(output + (i * 2), 3, "%02x", hash[i]);
}
}
void auth_init(const char *username, const char *password) {
if (!username || !password || strlen(username) == 0 || strlen(password) == 0) {
g_auth_enabled = 0;
return;
}
strncpy(g_dashboard_username, username, sizeof(g_dashboard_username) - 1);
g_dashboard_username[sizeof(g_dashboard_username) - 1] = '\0';
compute_sha256(password, g_dashboard_password_hash, sizeof(g_dashboard_password_hash));
g_auth_enabled = 1;
log_info("Dashboard authentication enabled for user: %s", username);
}
int auth_is_enabled(void) {
return g_auth_enabled;
}
int auth_check_credentials(const char *username, const char *password) {
if (!g_auth_enabled) return 1;
if (!username || !password) return 0;
if (strcmp(username, g_dashboard_username) != 0) return 0;
char password_hash[256];
compute_sha256(password, password_hash, sizeof(password_hash));
return strcmp(password_hash, g_dashboard_password_hash) == 0;
}
static int base64_decode_char(char c) {
if (c >= 'A' && c <= 'Z') return c - 'A';
if (c >= 'a' && c <= 'z') return c - 'a' + 26;
if (c >= '0' && c <= '9') return c - '0' + 52;
if (c == '+') return 62;
if (c == '/') return 63;
return -1;
}
static int base64_decode(const char *input, char *output, size_t output_size) {
size_t input_len = strlen(input);
size_t output_idx = 0;
for (size_t i = 0; i < input_len && output_idx < output_size - 1; i += 4) {
int v[4] = {0, 0, 0, 0};
int pad = 0;
for (int j = 0; j < 4; j++) {
if (i + j >= input_len || input[i + j] == '=') {
pad++;
v[j] = 0;
} else {
v[j] = base64_decode_char(input[i + j]);
if (v[j] < 0) return -1;
}
}
if (output_idx < output_size - 1) output[output_idx++] = (v[0] << 2) | (v[1] >> 4);
if (pad < 2 && output_idx < output_size - 1) output[output_idx++] = (v[1] << 4) | (v[2] >> 2);
if (pad < 1 && output_idx < output_size - 1) output[output_idx++] = (v[2] << 6) | v[3];
}
output[output_idx] = '\0';
return output_idx;
}
int auth_check_basic_auth(const char *auth_header, char *error_msg, size_t error_size) {
if (!g_auth_enabled) return 1;
if (!auth_header) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Authentication required", error_size - 1);
}
return 0;
}
if (strncmp(auth_header, "Basic ", 6) != 0) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Invalid authentication method", error_size - 1);
}
return 0;
}
char decoded[512];
if (base64_decode(auth_header + 6, decoded, sizeof(decoded)) < 0) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Invalid credentials format", error_size - 1);
}
return 0;
}
char *colon = strchr(decoded, ':');
if (!colon) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Invalid credentials format", error_size - 1);
}
return 0;
}
*colon = '\0';
const char *username = decoded;
const char *password = colon + 1;
if (!auth_check_credentials(username, password)) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Invalid username or password", error_size - 1);
}
return 0;
}
return 1;
}
int auth_check_route_basic_auth(const route_config_t *route, const char *auth_header, char *error_msg, size_t error_size) {
if (!route || !route->use_auth) return 1;
if (!auth_header) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Authentication required", error_size - 1);
}
return 0;
}
if (strncmp(auth_header, "Basic ", 6) != 0) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Invalid authentication method", error_size - 1);
}
return 0;
}
char decoded[512];
if (base64_decode(auth_header + 6, decoded, sizeof(decoded)) < 0) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Invalid credentials format", error_size - 1);
}
return 0;
}
char *colon = strchr(decoded, ':');
if (!colon) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Invalid credentials format", error_size - 1);
}
return 0;
}
*colon = '\0';
const char *username = decoded;
const char *password = colon + 1;
if (strcmp(username, route->username) != 0) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Invalid username or password", error_size - 1);
}
return 0;
}
char password_hash[256];
compute_sha256(password, password_hash, sizeof(password_hash));
if (strcmp(password_hash, route->password_hash) != 0) {
if (error_msg && error_size > 0) {
strncpy(error_msg, "Invalid username or password", error_size - 1);
}
return 0;
}
return 1;
}

13
src/auth.h Normal file
View File

@ -0,0 +1,13 @@
#ifndef RPROXY_AUTH_H
#define RPROXY_AUTH_H
#include <stddef.h>
#include "types.h"
void auth_init(const char *username, const char *password);
int auth_is_enabled(void);
int auth_check_credentials(const char *username, const char *password);
int auth_check_basic_auth(const char *auth_header, char *error_msg, size_t error_size);
int auth_check_route_basic_auth(const route_config_t *route, const char *auth_header, char *error_msg, size_t error_size);
#endif

View File

@ -1,10 +1,12 @@
#include "buffer.h"
#include "logging.h"
#include "types.h"
#include <stdlib.h>
#include <string.h>
#include <stdint.h>
int buffer_init(buffer_t *buf, size_t capacity) {
if (!buf) return -1;
buf->data = malloc(capacity);
if (!buf->data) {
log_error("Failed to allocate buffer");
@ -17,6 +19,7 @@ int buffer_init(buffer_t *buf, size_t capacity) {
}
void buffer_free(buffer_t *buf) {
if (!buf) return;
if (buf->data) {
free(buf->data);
buf->data = NULL;
@ -27,23 +30,40 @@ void buffer_free(buffer_t *buf) {
}
size_t buffer_available_read(buffer_t *buf) {
if (!buf) return 0;
return buf->tail - buf->head;
}
size_t buffer_available_write(buffer_t *buf) {
if (!buf) return 0;
return buf->capacity - buf->tail;
}
int buffer_ensure_capacity(buffer_t *buf, size_t required) {
if (!buf) return -1;
if (buf->capacity >= required) return 0;
if (required > MAX_BUFFER_SIZE) {
log_error("Buffer size limit exceeded: requested %zu, max %d", required, MAX_BUFFER_SIZE);
return -1;
}
size_t new_capacity = buf->capacity;
while (new_capacity < required) {
new_capacity *= 2;
if (new_capacity > SIZE_MAX / 2) {
log_error("Buffer size limit exceeded");
log_error("Buffer size overflow");
return -1;
}
new_capacity *= 2;
if (new_capacity > MAX_BUFFER_SIZE) {
new_capacity = MAX_BUFFER_SIZE;
break;
}
}
if (new_capacity < required) {
log_error("Cannot satisfy buffer capacity requirement");
return -1;
}
char *new_data = realloc(buf->data, new_capacity);
@ -57,7 +77,7 @@ int buffer_ensure_capacity(buffer_t *buf, size_t required) {
}
void buffer_compact(buffer_t *buf) {
if (buf->head == 0) return;
if (!buf || buf->head == 0) return;
size_t len = buf->tail - buf->head;
if (len > 0) {
memmove(buf->data, buf->data + buf->head, len);
@ -67,6 +87,11 @@ void buffer_compact(buffer_t *buf) {
}
void buffer_consume(buffer_t *buf, size_t bytes) {
if (!buf) return;
size_t available = buf->tail - buf->head;
if (bytes > available) {
bytes = available;
}
buf->head += bytes;
if (buf->head >= buf->tail) {
buf->head = 0;

View File

@ -4,9 +4,84 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <openssl/evp.h>
#include <sys/stat.h>
#include <pthread.h>
static pthread_rwlock_t config_lock = PTHREAD_RWLOCK_INITIALIZER;
static time_t config_file_mtime = 0;
static void compute_password_hash(const char *password, char *output, size_t output_size) {
EVP_MD_CTX *ctx = EVP_MD_CTX_new();
if (!ctx) return;
unsigned char hash[EVP_MAX_MD_SIZE];
unsigned int hash_len = 0;
EVP_DigestInit_ex(ctx, EVP_sha256(), NULL);
EVP_DigestUpdate(ctx, password, strlen(password));
EVP_DigestFinal_ex(ctx, hash, &hash_len);
EVP_MD_CTX_free(ctx);
for (unsigned int i = 0; i < hash_len && (i * 2 + 2) < output_size; i++) {
snprintf(output + (i * 2), 3, "%02x", hash[i]);
}
}
app_config_t config;
static int is_valid_hostname(const char *hostname) {
if (!hostname || strlen(hostname) == 0 || strlen(hostname) > 253) return 0;
const char *p = hostname;
int label_len = 0;
while (*p) {
char c = *p;
if (c == '.') {
if (label_len == 0) return 0;
label_len = 0;
} else if (isalnum((unsigned char)c) || c == '-' || c == '_') {
label_len++;
if (label_len > 63) return 0;
} else {
return 0;
}
p++;
}
return 1;
}
static int is_valid_ip(const char *ip) {
if (!ip) return 0;
int dots = 0;
int num = 0;
int has_digit = 0;
while (*ip) {
if (*ip == '.') {
if (!has_digit || num > 255) return 0;
dots++;
num = 0;
has_digit = 0;
} else if (isdigit((unsigned char)*ip)) {
num = num * 10 + (*ip - '0');
has_digit = 1;
} else {
return 0;
}
ip++;
}
return dots == 3 && has_digit && num <= 255;
}
static int is_valid_host(const char *host) {
return is_valid_hostname(host) || is_valid_ip(host);
}
static char* read_file_to_string(const char *filename) {
FILE *f = fopen(filename, "rb");
if (!f) return NULL;
@ -81,8 +156,20 @@ int config_load(const char *filename) {
continue;
}
if (!is_valid_host(hostname->valuestring)) {
fprintf(stderr, "Invalid hostname at index %d: %s\n", i, hostname->valuestring);
continue;
}
if (!is_valid_host(upstream_host->valuestring)) {
fprintf(stderr, "Invalid upstream_host at index %d: %s\n", i, upstream_host->valuestring);
continue;
}
strncpy(route->hostname, hostname->valuestring, sizeof(route->hostname) - 1);
route->hostname[sizeof(route->hostname) - 1] = '\0';
strncpy(route->upstream_host, upstream_host->valuestring, sizeof(route->upstream_host) - 1);
route->upstream_host[sizeof(route->upstream_host) - 1] = '\0';
route->upstream_port = upstream_port->valueint;
if (route->upstream_port < 1 || route->upstream_port > 65535) {
@ -93,9 +180,27 @@ int config_load(const char *filename) {
route->use_ssl = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "use_ssl"));
route->rewrite_host = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "rewrite_host"));
log_info("Route configured: %s -> %s:%d (SSL: %s, Rewrite Host: %s)",
route->use_auth = 0;
route->username[0] = '\0';
route->password_hash[0] = '\0';
cJSON *use_auth = cJSON_GetObjectItem(route_item, "use_auth");
cJSON *auth_username = cJSON_GetObjectItem(route_item, "username");
cJSON *auth_password = cJSON_GetObjectItem(route_item, "password");
if (cJSON_IsTrue(use_auth) && cJSON_IsString(auth_username) && cJSON_IsString(auth_password)) {
if (strlen(auth_username->valuestring) > 0 && strlen(auth_password->valuestring) > 0) {
route->use_auth = 1;
strncpy(route->username, auth_username->valuestring, sizeof(route->username) - 1);
route->username[sizeof(route->username) - 1] = '\0';
compute_password_hash(auth_password->valuestring, route->password_hash, sizeof(route->password_hash));
}
}
log_info("Route configured: %s -> %s:%d (SSL: %s, Rewrite Host: %s, Auth: %s)",
route->hostname, route->upstream_host, route->upstream_port,
route->use_ssl ? "yes" : "no", route->rewrite_host ? "yes" : "no");
route->use_ssl ? "yes" : "no", route->rewrite_host ? "yes" : "no",
route->use_auth ? "yes" : "no");
i++;
}
}
@ -150,10 +255,144 @@ void config_create_default(const char *filename) {
route_config_t *config_find_route(const char *hostname) {
if (!hostname) return NULL;
pthread_rwlock_rdlock(&config_lock);
route_config_t *result = NULL;
for (int i = 0; i < config.route_count; i++) {
if (strcasecmp(hostname, config.routes[i].hostname) == 0) {
return &config.routes[i];
result = &config.routes[i];
break;
}
}
return NULL;
pthread_rwlock_unlock(&config_lock);
return result;
}
int config_check_file_changed(const char *filename) {
struct stat st;
if (stat(filename, &st) != 0) {
return 0;
}
if (config_file_mtime == 0) {
config_file_mtime = st.st_mtime;
return 0;
}
if (st.st_mtime != config_file_mtime) {
config_file_mtime = st.st_mtime;
return 1;
}
return 0;
}
int config_hot_reload(const char *filename) {
log_info("Hot-reloading configuration from %s", filename);
app_config_t new_config;
memset(&new_config, 0, sizeof(app_config_t));
char *json_string = read_file_to_string(filename);
if (!json_string) {
log_error("Hot-reload: Could not read config file");
return 0;
}
cJSON *root = cJSON_Parse(json_string);
free(json_string);
if (!root) {
log_error("Hot-reload: JSON parse error: %s", cJSON_GetErrorPtr());
return 0;
}
cJSON *port_item = cJSON_GetObjectItem(root, "port");
new_config.port = cJSON_IsNumber(port_item) ? port_item->valueint : 8080;
if (new_config.port < 1 || new_config.port > 65535) {
log_error("Hot-reload: Invalid port number: %d", new_config.port);
cJSON_Delete(root);
return 0;
}
cJSON *proxy_array = cJSON_GetObjectItem(root, "reverse_proxy");
if (cJSON_IsArray(proxy_array)) {
new_config.route_count = cJSON_GetArraySize(proxy_array);
if (new_config.route_count <= 0) {
cJSON_Delete(root);
return 0;
}
new_config.routes = calloc(new_config.route_count, sizeof(route_config_t));
if (!new_config.routes) {
log_error("Hot-reload: Failed to allocate memory for routes");
cJSON_Delete(root);
return 0;
}
int i = 0;
cJSON *route_item;
cJSON_ArrayForEach(route_item, proxy_array) {
route_config_t *route = &new_config.routes[i];
cJSON *hostname = cJSON_GetObjectItem(route_item, "hostname");
cJSON *upstream_host = cJSON_GetObjectItem(route_item, "upstream_host");
cJSON *upstream_port = cJSON_GetObjectItem(route_item, "upstream_port");
if (!cJSON_IsString(hostname) || !cJSON_IsString(upstream_host) || !cJSON_IsNumber(upstream_port)) {
continue;
}
if (!is_valid_host(hostname->valuestring) || !is_valid_host(upstream_host->valuestring)) {
continue;
}
strncpy(route->hostname, hostname->valuestring, sizeof(route->hostname) - 1);
route->hostname[sizeof(route->hostname) - 1] = '\0';
strncpy(route->upstream_host, upstream_host->valuestring, sizeof(route->upstream_host) - 1);
route->upstream_host[sizeof(route->upstream_host) - 1] = '\0';
route->upstream_port = upstream_port->valueint;
if (route->upstream_port < 1 || route->upstream_port > 65535) {
continue;
}
route->use_ssl = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "use_ssl"));
route->rewrite_host = cJSON_IsTrue(cJSON_GetObjectItem(route_item, "rewrite_host"));
route->use_auth = 0;
route->username[0] = '\0';
route->password_hash[0] = '\0';
cJSON *use_auth = cJSON_GetObjectItem(route_item, "use_auth");
cJSON *auth_username = cJSON_GetObjectItem(route_item, "username");
cJSON *auth_password = cJSON_GetObjectItem(route_item, "password");
if (cJSON_IsTrue(use_auth) && cJSON_IsString(auth_username) && cJSON_IsString(auth_password)) {
if (strlen(auth_username->valuestring) > 0 && strlen(auth_password->valuestring) > 0) {
route->use_auth = 1;
strncpy(route->username, auth_username->valuestring, sizeof(route->username) - 1);
route->username[sizeof(route->username) - 1] = '\0';
compute_password_hash(auth_password->valuestring, route->password_hash, sizeof(route->password_hash));
}
}
log_info("Hot-reload route: %s -> %s:%d (SSL: %s, Auth: %s)",
route->hostname, route->upstream_host, route->upstream_port,
route->use_ssl ? "yes" : "no", route->use_auth ? "yes" : "no");
i++;
}
new_config.route_count = i;
}
cJSON_Delete(root);
pthread_rwlock_wrlock(&config_lock);
route_config_t *old_routes = config.routes;
config.routes = new_config.routes;
config.route_count = new_config.route_count;
pthread_rwlock_unlock(&config_lock);
if (old_routes) {
free(old_routes);
}
log_info("Hot-reload complete: %d routes loaded", new_config.route_count);
return 1;
}

View File

@ -3,11 +3,15 @@
#include "types.h"
#define CONFIG_RELOAD_INTERVAL_SECONDS 3
extern app_config_t config;
int config_load(const char *filename);
void config_free(void);
void config_create_default(const char *filename);
route_config_t *config_find_route(const char *hostname);
int config_check_file_changed(const char *filename);
int config_hot_reload(const char *filename);
#endif

View File

@ -6,6 +6,7 @@
#include "http.h"
#include "ssl_handler.h"
#include "dashboard.h"
#include "auth.h"
#include <stdio.h>
#include <stdlib.h>
@ -29,11 +30,17 @@ void connection_init_all(void) {
}
}
void connection_set_non_blocking(int fd) {
int connection_set_non_blocking(int fd) {
int flags = fcntl(fd, F_GETFL, 0);
if (flags >= 0) {
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
if (flags < 0) {
log_error("fcntl F_GETFL failed");
return -1;
}
if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0) {
log_error("fcntl F_SETFL failed");
return -1;
}
return 0;
}
void connection_set_tcp_keepalive(int fd) {
@ -273,16 +280,22 @@ int connection_do_write(connection_t *conn) {
void connection_send_error_response(connection_t *conn, int code, const char* status, const char* body) {
if (!conn || !status || !body) return;
char response[2048];
time_t now = time(NULL);
struct tm *gmt = gmtime(&now);
char date_buf[64];
strftime(date_buf, sizeof(date_buf), "%a, %d %b %Y %H:%M:%S GMT", gmt);
char response[ERROR_RESPONSE_SIZE];
int len = snprintf(response, sizeof(response),
"HTTP/1.1 %d %s\r\n"
"Content-Type: text/plain; charset=utf-8\r\n"
"Content-Length: %zu\r\n"
"Connection: close\r\n"
"Date: %s\r\n"
"Server: ReverseProxy/4.0\r\n"
"\r\n"
"%s",
code, status, strlen(body), body);
code, status, strlen(body), date_buf, body);
if (len > 0 && (size_t)len < sizeof(response)) {
if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len) == 0) {
@ -298,6 +311,66 @@ void connection_send_error_response(connection_t *conn, int code, const char* st
conn->request.keep_alive = 0;
}
void connection_send_auth_required(connection_t *conn, const char *realm) {
if (!conn) return;
time_t now = time(NULL);
struct tm *gmt = gmtime(&now);
char date_buf[64];
strftime(date_buf, sizeof(date_buf), "%a, %d %b %Y %H:%M:%S GMT", gmt);
const char *body = "401 Unauthorized - Authentication required";
char response[ERROR_RESPONSE_SIZE];
int len = snprintf(response, sizeof(response),
"HTTP/1.1 401 Unauthorized\r\n"
"Content-Type: text/plain; charset=utf-8\r\n"
"Content-Length: %zu\r\n"
"WWW-Authenticate: Basic realm=\"%s\"\r\n"
"Connection: close\r\n"
"Date: %s\r\n"
"Server: ReverseProxy/4.0\r\n"
"\r\n"
"%s",
strlen(body), realm ? realm : "Protected Area", date_buf, body);
if (len > 0 && (size_t)len < sizeof(response)) {
if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len) == 0) {
memcpy(conn->write_buf.data + conn->write_buf.tail, response, len);
conn->write_buf.tail += len;
struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT };
epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event);
}
}
conn->state = CLIENT_STATE_ERROR;
conn->request.keep_alive = 0;
}
static int try_upstream_connect(struct sockaddr_in *addr, int *out_fd) {
int up_fd = socket(AF_INET, SOCK_STREAM, 0);
if (up_fd < 0) {
return -1;
}
if (up_fd >= MAX_FDS) {
close(up_fd);
return -1;
}
connection_set_non_blocking(up_fd);
connection_set_tcp_keepalive(up_fd);
int connect_result = connect(up_fd, (struct sockaddr*)addr, sizeof(*addr));
if (connect_result < 0 && errno != EINPROGRESS) {
close(up_fd);
return -1;
}
*out_fd = up_fd;
return 0;
}
void connection_connect_to_upstream(connection_t *client, const char *data, size_t data_len) {
if (!client || !data) return;
@ -307,39 +380,48 @@ void connection_connect_to_upstream(connection_t *client, const char *data, size
return;
}
int up_fd = socket(AF_INET, SOCK_STREAM, 0);
if (up_fd < 0) {
connection_send_error_response(client, 502, "Bad Gateway", "Failed to create upstream socket");
return;
}
if (up_fd >= MAX_FDS) {
close(up_fd);
connection_send_error_response(client, 502, "Bad Gateway", "Connection limit exceeded");
return;
}
struct sockaddr_in addr;
memset(&addr, 0, sizeof(addr));
addr.sin_family = AF_INET;
addr.sin_port = htons(route->upstream_port);
if (inet_pton(AF_INET, route->upstream_host, &addr.sin_addr) <= 0) {
struct hostent *he = gethostbyname(route->upstream_host);
if (!he) {
close(up_fd);
struct addrinfo hints, *result;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
int gai_err = getaddrinfo(route->upstream_host, NULL, &hints, &result);
if (gai_err != 0) {
log_debug("DNS resolution failed for %s: %s", route->upstream_host, gai_strerror(gai_err));
connection_send_error_response(client, 502, "Bad Gateway", "Cannot resolve upstream hostname");
return;
}
memcpy(&addr.sin_addr, he->h_addr_list[0], he->h_length);
struct sockaddr_in *resolved = (struct sockaddr_in *)result->ai_addr;
addr.sin_addr = resolved->sin_addr;
freeaddrinfo(result);
}
connection_set_non_blocking(up_fd);
connection_set_tcp_keepalive(up_fd);
int up_fd = -1;
int retry_count = 0;
int connect_result = connect(up_fd, (struct sockaddr*)&addr, sizeof(addr));
if (connect_result < 0 && errno != EINPROGRESS) {
close(up_fd);
while (retry_count < MAX_UPSTREAM_RETRIES) {
if (try_upstream_connect(&addr, &up_fd) == 0) {
break;
}
retry_count++;
if (retry_count < MAX_UPSTREAM_RETRIES) {
log_debug("Upstream connection attempt %d failed for %s:%d, retrying...",
retry_count, route->upstream_host, route->upstream_port);
usleep(UPSTREAM_RETRY_DELAY_MS * 1000);
}
}
if (up_fd < 0) {
log_debug("All %d connection attempts failed for %s:%d",
MAX_UPSTREAM_RETRIES, route->upstream_host, route->upstream_port);
connection_send_error_response(client, 502, "Bad Gateway", "Failed to connect to upstream");
return;
}
@ -356,9 +438,23 @@ void connection_connect_to_upstream(connection_t *client, const char *data, size
up->pair = client;
up->vhost_stats = client->vhost_stats;
if (buffer_init(&up->read_buf, CHUNK_SIZE) < 0 ||
buffer_init(&up->write_buf, CHUNK_SIZE) < 0) {
connection_close(client->fd);
if (buffer_init(&up->read_buf, CHUNK_SIZE) < 0) {
close(up_fd);
memset(up, 0, sizeof(connection_t));
up->type = CONN_TYPE_UNUSED;
up->fd = -1;
client->pair = NULL;
connection_send_error_response(client, 502, "Bad Gateway", "Memory allocation failed");
return;
}
if (buffer_init(&up->write_buf, CHUNK_SIZE) < 0) {
buffer_free(&up->read_buf);
close(up_fd);
memset(up, 0, sizeof(connection_t));
up->type = CONN_TYPE_UNUSED;
up->fd = -1;
client->pair = NULL;
connection_send_error_response(client, 502, "Bad Gateway", "Memory allocation failed");
return;
}
@ -535,9 +631,9 @@ static void handle_client_read(connection_t *conn) {
conn->state = CLIENT_STATE_SERVING_INTERNAL;
if (strncmp(conn->request.uri, DASHBOARD_PATH, sizeof(DASHBOARD_PATH) - 1) == 0) {
dashboard_serve(conn);
dashboard_serve(conn, data_start, headers_len);
} else {
dashboard_serve_stats_api(conn);
dashboard_serve_stats_api(conn, data_start, headers_len);
}
buffer_consume(buf, total_request_len);
@ -551,6 +647,21 @@ static void handle_client_read(connection_t *conn) {
#undef DASHBOARD_PATH
#undef STATS_PATH
route_config_t *route = config_find_route(conn->request.host);
if (route && route->use_auth) {
char auth_header[1024] = "";
const char *headers_start = data_start + (strstr(data_start, "\r\n") - data_start + 2);
http_find_header_value(headers_start, headers_len - (headers_start - data_start), "Authorization", auth_header, sizeof(auth_header));
char error_msg[256] = "";
if (!auth_check_route_basic_auth(route, strlen(auth_header) > 0 ? auth_header : NULL, error_msg, sizeof(error_msg))) {
log_info("[ROUTING-AUTH] Authentication failed for %s: %s", conn->request.host, error_msg);
connection_send_auth_required(conn, conn->request.host);
buffer_consume(buf, total_request_len);
return;
}
}
log_info("[ROUTING-FORWARD] Forwarding request for fd=%d: %s %s",
conn->fd, conn->request.method, conn->request.uri);

View File

@ -14,7 +14,7 @@ void connection_close(int fd);
void connection_handle_event(struct epoll_event *event);
void connection_cleanup_idle(void);
void connection_set_non_blocking(int fd);
int connection_set_non_blocking(int fd);
void connection_set_tcp_keepalive(int fd);
void connection_add_to_epoll(int fd, uint32_t events);
void connection_modify_epoll(int fd, uint32_t events);
@ -23,6 +23,7 @@ int connection_do_read(connection_t *conn);
int connection_do_write(connection_t *conn);
void connection_send_error_response(connection_t *conn, int code, const char* status, const char* body);
void connection_send_auth_required(connection_t *conn, const char *realm);
void connection_connect_to_upstream(connection_t *client, const char *data, size_t data_len);
#endif

View File

@ -2,11 +2,14 @@
#include "buffer.h"
#include "monitor.h"
#include "connection.h"
#include "auth.h"
#include "http.h"
#include "../cJSON.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <time.h>
static const char *DASHBOARD_HTML =
"<!DOCTYPE html>\n"
@ -289,9 +292,59 @@ static const char *DASHBOARD_HTML =
"</body>\n"
"</html>\n";
void dashboard_serve(connection_t *conn) {
static void send_unauthorized_response(connection_t *conn) {
time_t now = time(NULL);
struct tm *gmt = gmtime(&now);
char date_buf[64];
strftime(date_buf, sizeof(date_buf), "%a, %d %b %Y %H:%M:%S GMT", gmt);
const char *body = "Unauthorized";
char header[1024];
int len = snprintf(header, sizeof(header),
"HTTP/1.1 401 Unauthorized\r\n"
"Content-Type: text/plain; charset=utf-8\r\n"
"Content-Length: %zu\r\n"
"WWW-Authenticate: Basic realm=\"RProxy Dashboard\"\r\n"
"Date: %s\r\n"
"Connection: close\r\n"
"\r\n"
"%s",
strlen(body), date_buf, body);
if (buffer_ensure_capacity(&conn->write_buf, conn->write_buf.tail + len) == 0) {
memcpy(conn->write_buf.data + conn->write_buf.tail, header, len);
conn->write_buf.tail += len;
struct epoll_event event = { .data.fd = conn->fd, .events = EPOLLIN | EPOLLOUT };
epoll_ctl(epoll_fd, EPOLL_CTL_MOD, conn->fd, &event);
}
conn->state = CLIENT_STATE_ERROR;
conn->request.keep_alive = 0;
}
static int check_dashboard_auth(connection_t *conn, const char *request_data, size_t request_len) {
if (!auth_is_enabled()) return 1;
char auth_header[512] = "";
http_find_header_value(request_data, request_len, "Authorization", auth_header, sizeof(auth_header));
char error_msg[256] = "";
if (!auth_check_basic_auth(auth_header[0] ? auth_header : NULL, error_msg, sizeof(error_msg))) {
send_unauthorized_response(conn);
return 0;
}
return 1;
}
void dashboard_serve(connection_t *conn, const char *request_data, size_t request_len) {
if (!conn) return;
if (!check_dashboard_auth(conn, request_data, request_len)) {
return;
}
size_t content_len = strlen(DASHBOARD_HTML);
char header[512];
int len = snprintf(header, sizeof(header),
@ -320,7 +373,8 @@ void dashboard_serve(connection_t *conn) {
static cJSON* format_history(history_deque_t *dq, int window_seconds) {
cJSON *arr = cJSON_CreateArray();
if (!arr || !dq || !dq->points || dq->count == 0) return arr;
if (!arr) return NULL;
if (!dq || !dq->points || dq->count == 0) return arr;
double current_time = time(NULL);
int start_index = (dq->head - dq->count + dq->capacity) % dq->capacity;
@ -342,7 +396,8 @@ static cJSON* format_history(history_deque_t *dq, int window_seconds) {
static cJSON* format_network_history(network_history_deque_t *dq, int window_seconds, const char *key) {
cJSON *arr = cJSON_CreateArray();
if (!arr || !dq || !dq->points || !key || dq->count == 0) return arr;
if (!arr) return NULL;
if (!dq || !dq->points || !key || dq->count == 0) return arr;
double current_time = time(NULL);
int start_index = (dq->head - dq->count + dq->capacity) % dq->capacity;
@ -364,7 +419,8 @@ static cJSON* format_network_history(network_history_deque_t *dq, int window_sec
static cJSON* format_disk_history(disk_history_deque_t *dq, int window_seconds, const char *key) {
cJSON *arr = cJSON_CreateArray();
if (!arr || !dq || !dq->points || !key || dq->count == 0) return arr;
if (!arr) return NULL;
if (!dq || !dq->points || !key || dq->count == 0) return arr;
double current_time = time(NULL);
int start_index = (dq->head - dq->count + dq->capacity) % dq->capacity;
@ -384,9 +440,13 @@ static cJSON* format_disk_history(disk_history_deque_t *dq, int window_seconds,
return arr;
}
void dashboard_serve_stats_api(connection_t *conn) {
void dashboard_serve_stats_api(connection_t *conn, const char *request_data, size_t request_len) {
if (!conn) return;
if (!check_dashboard_auth(conn, request_data, request_len)) {
return;
}
cJSON *root = cJSON_CreateObject();
if (!root) {
connection_send_error_response(conn, 500, "Internal Server Error", "JSON creation failed");

View File

@ -3,7 +3,7 @@
#include "types.h"
void dashboard_serve(connection_t *conn);
void dashboard_serve_stats_api(connection_t *conn);
void dashboard_serve(connection_t *conn, const char *request_data, size_t request_len);
void dashboard_serve_stats_api(connection_t *conn, const char *request_data, size_t request_len);
#endif

188
src/health_check.c Normal file
View File

@ -0,0 +1,188 @@
#include "health_check.h"
#include "logging.h"
#include "config.h"
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <fcntl.h>
#include <errno.h>
#include <poll.h>
#include <pthread.h>
typedef struct {
char hostname[256];
char upstream_host[256];
int upstream_port;
int healthy;
int consecutive_failures;
time_t last_check;
} upstream_health_t;
static upstream_health_t *health_states = NULL;
static int health_state_count = 0;
static pthread_mutex_t health_mutex = PTHREAD_MUTEX_INITIALIZER;
static int g_health_check_enabled = 0;
void health_check_init(void) {
pthread_mutex_lock(&health_mutex);
if (health_states) {
free(health_states);
}
health_state_count = config.route_count;
if (health_state_count <= 0) {
health_states = NULL;
pthread_mutex_unlock(&health_mutex);
return;
}
health_states = calloc(health_state_count, sizeof(upstream_health_t));
if (!health_states) {
health_state_count = 0;
pthread_mutex_unlock(&health_mutex);
return;
}
for (int i = 0; i < health_state_count; i++) {
strncpy(health_states[i].hostname, config.routes[i].hostname, sizeof(health_states[i].hostname) - 1);
strncpy(health_states[i].upstream_host, config.routes[i].upstream_host, sizeof(health_states[i].upstream_host) - 1);
health_states[i].upstream_port = config.routes[i].upstream_port;
health_states[i].healthy = 1;
health_states[i].consecutive_failures = 0;
health_states[i].last_check = 0;
}
g_health_check_enabled = 1;
log_info("Health check initialized for %d upstreams", health_state_count);
pthread_mutex_unlock(&health_mutex);
}
void health_check_cleanup(void) {
pthread_mutex_lock(&health_mutex);
if (health_states) {
free(health_states);
health_states = NULL;
}
health_state_count = 0;
g_health_check_enabled = 0;
pthread_mutex_unlock(&health_mutex);
}
static int check_tcp_connection(const char *host, int port, int timeout_ms) {
struct addrinfo hints, *result;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
if (getaddrinfo(host, NULL, &hints, &result) != 0) {
return 0;
}
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
freeaddrinfo(result);
return 0;
}
int flags = fcntl(fd, F_GETFL, 0);
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
struct sockaddr_in *addr = (struct sockaddr_in *)result->ai_addr;
addr->sin_port = htons(port);
int connect_result = connect(fd, (struct sockaddr *)addr, sizeof(struct sockaddr_in));
freeaddrinfo(result);
if (connect_result == 0) {
close(fd);
return 1;
}
if (errno != EINPROGRESS) {
close(fd);
return 0;
}
struct pollfd pfd;
pfd.fd = fd;
pfd.events = POLLOUT;
int poll_result = poll(&pfd, 1, timeout_ms);
if (poll_result <= 0) {
close(fd);
return 0;
}
int error = 0;
socklen_t len = sizeof(error);
getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &len);
close(fd);
return error == 0;
}
void health_check_run(void) {
if (!g_health_check_enabled) return;
pthread_mutex_lock(&health_mutex);
time_t now = time(NULL);
for (int i = 0; i < health_state_count; i++) {
if (now - health_states[i].last_check < HEALTH_CHECK_INTERVAL_SECONDS) {
continue;
}
health_states[i].last_check = now;
int is_healthy = check_tcp_connection(
health_states[i].upstream_host,
health_states[i].upstream_port,
HEALTH_CHECK_TIMEOUT_MS
);
if (is_healthy) {
if (!health_states[i].healthy) {
log_info("Upstream %s:%d is now healthy",
health_states[i].upstream_host,
health_states[i].upstream_port);
}
health_states[i].healthy = 1;
health_states[i].consecutive_failures = 0;
} else {
health_states[i].consecutive_failures++;
if (health_states[i].consecutive_failures >= 3 && health_states[i].healthy) {
log_info("Upstream %s:%d is now unhealthy (failures: %d)",
health_states[i].upstream_host,
health_states[i].upstream_port,
health_states[i].consecutive_failures);
health_states[i].healthy = 0;
}
}
}
pthread_mutex_unlock(&health_mutex);
}
int health_check_is_healthy(const char *hostname) {
if (!g_health_check_enabled || !hostname) return 1;
pthread_mutex_lock(&health_mutex);
for (int i = 0; i < health_state_count; i++) {
if (strcasecmp(health_states[i].hostname, hostname) == 0) {
int result = health_states[i].healthy;
pthread_mutex_unlock(&health_mutex);
return result;
}
}
pthread_mutex_unlock(&health_mutex);
return 1;
}

11
src/health_check.h Normal file
View File

@ -0,0 +1,11 @@
#ifndef RPROXY_HEALTH_CHECK_H
#define RPROXY_HEALTH_CHECK_H
#include "types.h"
void health_check_init(void);
void health_check_cleanup(void);
void health_check_run(void);
int health_check_is_healthy(const char *hostname);
#endif

View File

@ -124,7 +124,11 @@ int http_parse_request(const char *data, size_t len, http_request_t *req) {
}
if (http_find_header_value(headers_start, len - (headers_start - data), "Content-Length", value, sizeof(value))) {
req->content_length = atol(value);
char *endptr;
long parsed = strtol(value, &endptr, 10);
if (endptr != value && *endptr == '\0' && parsed >= 0) {
req->content_length = parsed;
}
}
if (http_find_header_value(headers_start, len - (headers_start - data), "Transfer-Encoding", value, sizeof(value))) {

View File

@ -2,8 +2,13 @@
#include <stdio.h>
#include <stdarg.h>
#include <time.h>
#include <string.h>
#include <errno.h>
#include <pthread.h>
static int g_debug_mode = 0;
static FILE *g_log_file = NULL;
static pthread_mutex_t log_mutex = PTHREAD_MUTEX_INITIALIZER;
void logging_set_debug(int enabled) {
g_debug_mode = enabled;
@ -13,19 +18,77 @@ int logging_get_debug(void) {
return g_debug_mode;
}
void log_error(const char *msg) {
perror(msg);
int logging_set_file(const char *path) {
pthread_mutex_lock(&log_mutex);
if (g_log_file && g_log_file != stdout && g_log_file != stderr) {
fclose(g_log_file);
}
if (path) {
g_log_file = fopen(path, "a");
if (!g_log_file) {
g_log_file = stdout;
pthread_mutex_unlock(&log_mutex);
return -1;
}
} else {
g_log_file = stdout;
}
pthread_mutex_unlock(&log_mutex);
return 0;
}
void logging_cleanup(void) {
pthread_mutex_lock(&log_mutex);
if (g_log_file && g_log_file != stdout && g_log_file != stderr) {
fclose(g_log_file);
}
g_log_file = NULL;
pthread_mutex_unlock(&log_mutex);
}
static void log_message(const char *level, const char *format, va_list args) {
pthread_mutex_lock(&log_mutex);
FILE *out = g_log_file ? g_log_file : stdout;
time_t now;
time(&now);
struct tm *local = localtime(&now);
char buf[32];
strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", localtime(&now));
printf("%s - %-5s - ", buf, level);
vprintf(format, args);
printf("\n");
fflush(stdout);
strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", local);
fprintf(out, "%s - %-5s - ", buf, level);
vfprintf(out, format, args);
fprintf(out, "\n");
fflush(out);
pthread_mutex_unlock(&log_mutex);
}
void log_error(const char *format, ...) {
va_list args;
va_start(args, format);
int saved_errno = errno;
char msg[1024];
vsnprintf(msg, sizeof(msg), format, args);
va_end(args);
pthread_mutex_lock(&log_mutex);
FILE *out = g_log_file ? g_log_file : stderr;
time_t now;
time(&now);
struct tm *local = localtime(&now);
char buf[32];
strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", local);
if (saved_errno != 0) {
fprintf(out, "%s - ERROR - %s: %s\n", buf, msg, strerror(saved_errno));
} else {
fprintf(out, "%s - ERROR - %s\n", buf, msg);
}
fflush(out);
pthread_mutex_unlock(&log_mutex);
}
void log_info(const char *format, ...) {

View File

@ -1,10 +1,12 @@
#ifndef RPROXY_LOGGING_H
#define RPROXY_LOGGING_H
void log_error(const char *msg);
void log_error(const char *format, ...);
void log_info(const char *format, ...);
void log_debug(const char *format, ...);
void logging_set_debug(int enabled);
int logging_get_debug(void);
int logging_set_file(const char *path);
void logging_cleanup(void);
#endif

View File

@ -6,6 +6,7 @@
#include <signal.h>
#include <errno.h>
#include <sys/epoll.h>
#include <unistd.h>
#include "types.h"
#include "logging.h"
@ -13,13 +14,82 @@
#include "monitor.h"
#include "ssl_handler.h"
#include "connection.h"
#include "rate_limit.h"
#include "auth.h"
#include "health_check.h"
static volatile int g_shutdown = 0;
static volatile sig_atomic_t g_shutdown = 0;
static volatile sig_atomic_t g_reload_config = 0;
static const char *g_config_file = NULL;
static void signal_handler(int sig) {
if (sig == SIGINT || sig == SIGTERM) {
log_info("Received signal %d, shutting down...", sig);
g_shutdown = 1;
} else if (sig == SIGHUP) {
g_reload_config = 1;
}
}
static void reload_configuration(void) {
if (!g_config_file) return;
log_info("Reloading configuration from %s", g_config_file);
app_config_t old_config = config;
memset(&config, 0, sizeof(app_config_t));
if (!config_load(g_config_file)) {
log_error("Failed to reload configuration, keeping old config");
config = old_config;
return;
}
if (old_config.routes) {
free(old_config.routes);
}
log_info("Configuration reloaded successfully");
}
static void graceful_shutdown(void) {
log_info("Initiating graceful shutdown...");
for (int i = 0; i < MAX_FDS; i++) {
connection_t *conn = &connections[i];
if (conn->type == CONN_TYPE_LISTENER && conn->fd != -1) {
epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->fd, NULL);
close(conn->fd);
conn->fd = -1;
conn->type = CONN_TYPE_UNUSED;
log_info("Stopped accepting new connections");
}
}
int active_count = 0;
int drain_timeout = 30;
time_t drain_start = time(NULL);
do {
active_count = 0;
for (int i = 0; i < MAX_FDS; i++) {
if (connections[i].type == CONN_TYPE_CLIENT || connections[i].type == CONN_TYPE_UPSTREAM) {
if (connections[i].fd != -1) {
active_count++;
}
}
}
if (active_count > 0 && (time(NULL) - drain_start) < drain_timeout) {
struct epoll_event events[MAX_EVENTS];
int n = epoll_wait(epoll_fd, events, MAX_EVENTS, 100);
for (int i = 0; i < n; i++) {
connection_handle_event(&events[i]);
}
}
} while (active_count > 0 && (time(NULL) - drain_start) < drain_timeout);
if (active_count > 0) {
log_info("Drain timeout reached, forcing close of %d connections", active_count);
}
}
@ -34,6 +104,9 @@ static void cleanup(void) {
config_free();
monitor_cleanup();
rate_limit_cleanup();
health_check_cleanup();
logging_cleanup();
if (epoll_fd >= 0) {
close(epoll_fd);
@ -49,23 +122,62 @@ int main(int argc, char *argv[]) {
signal(SIGPIPE, SIG_IGN);
signal(SIGINT, signal_handler);
signal(SIGTERM, signal_handler);
signal(SIGHUP, signal_handler);
if (getenv("DEBUG")) {
logging_set_debug(1);
log_info("Debug mode enabled");
}
const char *config_file = (argc > 1) ? argv[1] : "proxy_config.json";
config_create_default(config_file);
const char *log_file = getenv("LOG_FILE");
if (log_file) {
if (logging_set_file(log_file) == 0) {
log_info("Logging to file: %s", log_file);
}
}
if (!config_load(config_file)) {
g_config_file = (argc > 1) ? argv[1] : "proxy_config.json";
config_create_default(g_config_file);
if (!config_load(g_config_file)) {
fprintf(stderr, "Failed to load configuration\n");
return 1;
}
const char *ssl_verify = getenv("SSL_VERIFY");
if (ssl_verify && strcmp(ssl_verify, "0") == 0) {
ssl_set_verify(0);
}
const char *ca_file = getenv("SSL_CA_FILE");
if (ca_file) {
ssl_set_ca_file(ca_file);
}
const char *ca_path = getenv("SSL_CA_PATH");
if (ca_path) {
ssl_set_ca_path(ca_path);
}
ssl_init();
monitor_init("proxy_stats.db");
const char *rate_limit_str = getenv("RATE_LIMIT");
if (rate_limit_str) {
int rate = atoi(rate_limit_str);
if (rate > 0) {
rate_limit_init(rate, RATE_LIMIT_WINDOW_SECONDS);
}
}
const char *auth_user = getenv("DASHBOARD_USER");
const char *auth_pass = getenv("DASHBOARD_PASS");
if (auth_user && auth_pass) {
auth_init(auth_user, auth_pass);
}
health_check_init();
epoll_fd = epoll_create1(EPOLL_CLOEXEC);
if (epoll_fd == -1) {
log_error("epoll_create1 failed");
@ -78,14 +190,21 @@ int main(int argc, char *argv[]) {
log_info("Port %d", config.port);
log_info("Dashboard: http://localhost:%d/rproxy/dashboard", config.port);
log_info("Stats: http://localhost:%d/rproxy/api/stats", config.port);
log_info("Send SIGHUP to reload configuration");
atexit(cleanup);
struct epoll_event events[MAX_EVENTS];
time_t last_monitor_update = 0;
time_t last_cleanup = 0;
time_t last_config_check = 0;
while (!g_shutdown) {
if (g_reload_config) {
g_reload_config = 0;
reload_configuration();
}
int n = epoll_wait(epoll_fd, events, MAX_EVENTS, 1000);
if (n == -1) {
if (errno == EINTR) continue;
@ -104,12 +223,24 @@ int main(int argc, char *argv[]) {
last_monitor_update = current_time;
}
if (current_time - last_config_check >= CONFIG_RELOAD_INTERVAL_SECONDS) {
if (config_check_file_changed(g_config_file)) {
config_hot_reload(g_config_file);
}
last_config_check = current_time;
}
if (current_time - last_cleanup >= 60) {
connection_cleanup_idle();
rate_limit_purge_expired();
last_cleanup = current_time;
}
health_check_run();
}
log_info("Received shutdown signal");
graceful_shutdown();
log_info("Shutdown complete");
return 0;
}

View File

@ -5,8 +5,10 @@
#include <string.h>
#include <sys/sysinfo.h>
#include <math.h>
#include <pthread.h>
system_monitor_t monitor;
static pthread_mutex_t vhost_stats_mutex = PTHREAD_MUTEX_INITIALIZER;
void history_deque_init(history_deque_t *dq, int capacity) {
dq->points = calloc(capacity, sizeof(history_point_t));
@ -300,8 +302,11 @@ static void get_disk_stats(long long *sectors_read, long long *sectors_written)
if (nfields >= 11) {
strncpy(device, dev, sizeof(device)-1);
device[sizeof(device)-1] = '\0';
sectors_r = atoll(sr);
sectors_w = atoll(sw);
char *endptr;
sectors_r = strtoll(sr, &endptr, 10);
if (endptr == sr) sectors_r = 0;
sectors_w = strtoll(sw, &endptr, 10);
if (endptr == sw) sectors_w = 0;
if (strncmp(device, "loop", 4) != 0 && strncmp(device, "ram", 3) != 0) {
int len = strlen(device);
@ -398,12 +403,20 @@ void monitor_update(void) {
vhost_stats_t* monitor_get_or_create_vhost_stats(const char *vhost_name) {
if (!vhost_name || strlen(vhost_name) == 0) return NULL;
pthread_mutex_lock(&vhost_stats_mutex);
for (vhost_stats_t *curr = monitor.vhost_stats_head; curr; curr = curr->next) {
if (strcmp(curr->vhost_name, vhost_name) == 0) return curr;
if (strcmp(curr->vhost_name, vhost_name) == 0) {
pthread_mutex_unlock(&vhost_stats_mutex);
return curr;
}
}
vhost_stats_t *new_stats = calloc(1, sizeof(vhost_stats_t));
if (!new_stats) return NULL;
if (!new_stats) {
pthread_mutex_unlock(&vhost_stats_mutex);
return NULL;
}
strncpy(new_stats->vhost_name, vhost_name, sizeof(new_stats->vhost_name) - 1);
new_stats->last_update = time(NULL);
@ -411,6 +424,8 @@ vhost_stats_t* monitor_get_or_create_vhost_stats(const char *vhost_name) {
request_time_deque_init(&new_stats->request_times, 100);
new_stats->next = monitor.vhost_stats_head;
monitor.vhost_stats_head = new_stats;
pthread_mutex_unlock(&vhost_stats_mutex);
return new_stats;
}

126
src/rate_limit.c Normal file
View File

@ -0,0 +1,126 @@
#include "rate_limit.h"
#include "logging.h"
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <pthread.h>
#define MAX_RATE_LIMIT_ENTRIES 10000
typedef struct rate_limit_entry {
char client_ip[64];
int request_count;
time_t window_start;
struct rate_limit_entry *next;
} rate_limit_entry_t;
static rate_limit_entry_t *rate_limit_table[256];
static pthread_mutex_t rate_limit_mutex = PTHREAD_MUTEX_INITIALIZER;
static int g_rate_limit_enabled = 0;
static int g_requests_per_window = DEFAULT_RATE_LIMIT_REQUESTS;
static int g_window_seconds = RATE_LIMIT_WINDOW_SECONDS;
static unsigned int hash_ip(const char *ip) {
unsigned int hash = 0;
while (*ip) {
hash = hash * 31 + (unsigned char)*ip++;
}
return hash % 256;
}
void rate_limit_init(int requests_per_window, int window_seconds) {
g_rate_limit_enabled = 1;
g_requests_per_window = requests_per_window;
g_window_seconds = window_seconds;
memset(rate_limit_table, 0, sizeof(rate_limit_table));
log_info("Rate limiting enabled: %d requests per %d seconds", requests_per_window, window_seconds);
}
void rate_limit_cleanup(void) {
pthread_mutex_lock(&rate_limit_mutex);
for (int i = 0; i < 256; i++) {
rate_limit_entry_t *entry = rate_limit_table[i];
while (entry) {
rate_limit_entry_t *next = entry->next;
free(entry);
entry = next;
}
rate_limit_table[i] = NULL;
}
pthread_mutex_unlock(&rate_limit_mutex);
}
int rate_limit_check(const char *client_ip) {
if (!g_rate_limit_enabled || !client_ip) return 1;
pthread_mutex_lock(&rate_limit_mutex);
time_t now = time(NULL);
unsigned int bucket = hash_ip(client_ip);
rate_limit_entry_t *entry = rate_limit_table[bucket];
while (entry) {
if (strcmp(entry->client_ip, client_ip) == 0) {
if (now - entry->window_start >= g_window_seconds) {
entry->window_start = now;
entry->request_count = 1;
pthread_mutex_unlock(&rate_limit_mutex);
return 1;
}
entry->request_count++;
if (entry->request_count > g_requests_per_window) {
pthread_mutex_unlock(&rate_limit_mutex);
return 0;
}
pthread_mutex_unlock(&rate_limit_mutex);
return 1;
}
entry = entry->next;
}
rate_limit_entry_t *new_entry = calloc(1, sizeof(rate_limit_entry_t));
if (!new_entry) {
pthread_mutex_unlock(&rate_limit_mutex);
return 1;
}
strncpy(new_entry->client_ip, client_ip, sizeof(new_entry->client_ip) - 1);
new_entry->request_count = 1;
new_entry->window_start = now;
new_entry->next = rate_limit_table[bucket];
rate_limit_table[bucket] = new_entry;
pthread_mutex_unlock(&rate_limit_mutex);
return 1;
}
void rate_limit_purge_expired(void) {
if (!g_rate_limit_enabled) return;
pthread_mutex_lock(&rate_limit_mutex);
time_t now = time(NULL);
for (int i = 0; i < 256; i++) {
rate_limit_entry_t *entry = rate_limit_table[i];
rate_limit_entry_t *prev = NULL;
while (entry) {
rate_limit_entry_t *next = entry->next;
if (now - entry->window_start >= g_window_seconds * 2) {
if (prev) {
prev->next = next;
} else {
rate_limit_table[i] = next;
}
free(entry);
} else {
prev = entry;
}
entry = next;
}
}
pthread_mutex_unlock(&rate_limit_mutex);
}

11
src/rate_limit.h Normal file
View File

@ -0,0 +1,11 @@
#ifndef RPROXY_RATE_LIMIT_H
#define RPROXY_RATE_LIMIT_H
#include "types.h"
void rate_limit_init(int requests_per_window, int window_seconds);
void rate_limit_cleanup(void);
int rate_limit_check(const char *client_ip);
void rate_limit_purge_expired(void);
#endif

View File

@ -1,23 +1,68 @@
#include "ssl_handler.h"
#include "logging.h"
#include <openssl/err.h>
#include <openssl/x509_vfy.h>
#include <stdlib.h>
#include <string.h>
SSL_CTX *ssl_ctx = NULL;
static int g_ssl_verify_enabled = 1;
static char g_ca_file[512] = "";
static char g_ca_path[512] = "";
void ssl_set_verify(int enabled) {
g_ssl_verify_enabled = enabled;
}
void ssl_set_ca_file(const char *path) {
if (path) {
strncpy(g_ca_file, path, sizeof(g_ca_file) - 1);
g_ca_file[sizeof(g_ca_file) - 1] = '\0';
}
}
void ssl_set_ca_path(const char *path) {
if (path) {
strncpy(g_ca_path, path, sizeof(g_ca_path) - 1);
g_ca_path[sizeof(g_ca_path) - 1] = '\0';
}
}
void ssl_init(void) {
SSL_load_error_strings();
OpenSSL_add_ssl_algorithms();
ssl_ctx = SSL_CTX_new(TLS_client_method());
if (!ssl_ctx) {
log_error("Failed to create SSL context");
exit(EXIT_FAILURE);
}
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, NULL);
SSL_CTX_set_options(ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
if (g_ssl_verify_enabled) {
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, NULL);
SSL_CTX_set_verify_depth(ssl_ctx, 4);
int ca_loaded = 0;
if (g_ca_file[0] != '\0' || g_ca_path[0] != '\0') {
const char *file = g_ca_file[0] != '\0' ? g_ca_file : NULL;
const char *path = g_ca_path[0] != '\0' ? g_ca_path : NULL;
if (SSL_CTX_load_verify_locations(ssl_ctx, file, path) == 1) {
ca_loaded = 1;
log_info("Loaded CA certificates from custom location");
}
}
if (!ca_loaded) {
if (SSL_CTX_set_default_verify_paths(ssl_ctx) != 1) {
log_info("Warning: Could not load default CA certificates");
} else {
log_info("Loaded system default CA certificates");
}
}
} else {
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_NONE, NULL);
log_info("Warning: SSL certificate verification disabled");
}
SSL_CTX_set_options(ssl_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1);
SSL_CTX_set_mode(ssl_ctx, SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_verify_depth(ssl_ctx, 0);
}
void ssl_cleanup(void) {
@ -25,7 +70,6 @@ void ssl_cleanup(void) {
SSL_CTX_free(ssl_ctx);
ssl_ctx = NULL;
}
EVP_cleanup();
}
int ssl_do_handshake(connection_t *conn) {

View File

@ -5,6 +5,9 @@
extern SSL_CTX *ssl_ctx;
void ssl_set_verify(int enabled);
void ssl_set_ca_file(const char *path);
void ssl_set_ca_path(const char *path);
void ssl_init(void);
void ssl_cleanup(void);
int ssl_do_handshake(connection_t *conn);

View File

@ -18,6 +18,16 @@
#define MAX_REQUEST_LINE_SIZE 4096
#define MAX_URI_SIZE 2048
#define CONNECTION_TIMEOUT 300
#define ERROR_RESPONSE_SIZE 4096
#define HOST_HEADER_SIZE 512
#define MAX_BUFFER_SIZE (64 * 1024 * 1024)
#define MIN_DATA_FOR_REQUEST_CHECK 1
#define RATE_LIMIT_WINDOW_SECONDS 60
#define DEFAULT_RATE_LIMIT_REQUESTS 1000
#define HEALTH_CHECK_INTERVAL_SECONDS 30
#define HEALTH_CHECK_TIMEOUT_MS 5000
#define MAX_UPSTREAM_RETRIES 3
#define UPSTREAM_RETRY_DELAY_MS 100
typedef enum {
CONN_TYPE_UNUSED,
@ -79,6 +89,9 @@ typedef struct {
int upstream_port;
int use_ssl;
int rewrite_host;
int use_auth;
char username[128];
char password_hash[256];
} route_config_t;
typedef struct {