commit 035df8254b5f9dc922185df5f61d5c61da7f1a6f
Author: retoor
Date:   Sat Sep 27 00:33:35 2025 +0200

    Initial commit

diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..df02419
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,10 @@
+
+build:
+	gcc -D_GNU_SOURCE -O3 -march=native -mtune=native -flto -pthread rpubsub.c -o rpubsub -lrt -std=c11
+	gcc -O3 -march=native -std=c11 -o load_test load_test.c -lm -lrt
+
+clean:
+	rm -f rpubsub load_test
+
+
+
diff --git a/load_test b/load_test
new file mode 100755
index 0000000..2e5cd72
Binary files /dev/null and b/load_test differ
diff --git a/load_test.c b/load_test.c
new file mode 100644
index 0000000..ebbfde7
--- /dev/null
+++ b/load_test.c
@@ -0,0 +1,518 @@
+#define _GNU_SOURCE
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <math.h>
+#include <time.h>
+#include <stdint.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/epoll.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+
+// --- Test Configuration ---
+#define HOST "127.0.0.1"
+#define PORT 8080
+#define NUM_SUBSCRIBERS 1000
+#define NUM_PUBLISHERS 10
+#define TOTAL_CLIENTS (NUM_SUBSCRIBERS + NUM_PUBLISHERS)
+#define TEST_DURATION_S 15
+#define MESSAGES_PER_SECOND_PER_PUBLISHER 100
+
+// --- Internal Configuration ---
+#define MAX_EVENTS TOTAL_CLIENTS
+#define RW_BUFFER_SIZE 8192
+#define MAX_LATENCIES 20000000 // Pre-allocate for ~1.3M messages/sec
+
+// --- WebSocket Constants ---
+#define WEBSOCKET_KEY_MAGIC "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
+
+// --- Helper Enums ---
+typedef enum {
+    CLIENT_SUBSCRIBER,
+    CLIENT_PUBLISHER
+} ClientType;
+
+typedef enum {
+    STATE_CONNECTING,
+    STATE_HANDSHAKE_SEND,
+    STATE_HANDSHAKE_RECV,
+    STATE_SUBSCRIBING,
+    STATE_RUNNING,
+    STATE_CLOSED
+} ClientState;
+
+// --- Client State Structure ---
+typedef struct {
+    int fd;
+    ClientType type;
+    ClientState state;
+    char read_buf[RW_BUFFER_SIZE];
+    size_t read_len;
+    char write_buf[RW_BUFFER_SIZE];
+    size_t write_len;
+    size_t write_pos;
+    double next_send_time;
+} Client;
+
+// --- Global State & Metrics ---
+double* latencies;
+size_t latencies_count = 0;
+uint64_t messages_sent = 0;
+uint64_t messages_received = 0;
+int subscriber_setup_count = 0;
+int all_subscribed = 0;
+int epoll_fd;
+
+const char* CHANNELS[] = {"news", "sports", "tech", "finance", "weather"};
+const int NUM_CHANNELS = sizeof(CHANNELS) / sizeof(CHANNELS[0]);
+
+
+// =============================================================================
+// START: Clean SHA-1 and Base64 Implementations
+// =============================================================================
+
+// --- SHA-1 Implementation ---
+typedef struct {
+    uint32_t state[5];
+    uint32_t count[2];
+    unsigned char buffer[64];
+} SHA1_CTX;
+
+#define SHA1_ROTLEFT(n,c) (((n) << (c)) | ((n) >> (32 - (c))))
+
+void SHA1_Transform(uint32_t state[5], const unsigned char buffer[64]) {
+    uint32_t a, b, c, d, e;
+    uint32_t block[16];
+    memcpy(block, buffer, 64);
+    for(int i = 0; i < 16; i++) {
+        uint8_t *p = (uint8_t*)&block[i];
+        block[i] = (p[0] << 24) | (p[1] << 16) | (p[2] << 8) | p[3];
+    }
+
+    a = state[0]; b = state[1]; c = state[2]; d = state[3]; e = state[4];
+
+    uint32_t W[80];
+    for (int t = 0; t < 80; t++) {
+        if (t < 16) {
+            W[t] = block[t];
+        } else {
+            W[t] = SHA1_ROTLEFT(W[t - 3] ^ W[t - 8] ^ W[t - 14] ^ W[t - 16], 1);
+        }
+        uint32_t temp = SHA1_ROTLEFT(a, 5) + e + W[t];
+        if (t < 20) temp += ((b & c) | (~b & d)) + 0x5A827999;
+        else if (t < 40) temp += (b ^ c ^ d) + 0x6ED9EBA1;
+        else if (t < 60) temp += ((b & c) | (b & d) | (c & d)) + 0x8F1BBCDC;
+        else temp += (b ^ c ^ d) + 0xCA62C1D6;
+        e = d; d = c; c =
SHA1_ROTLEFT(b, 30); b = a; a = temp; + } + state[0] += a; state[1] += b; state[2] += c; state[3] += d; state[4] += e; +} + +void SHA1_Init(SHA1_CTX* context) { + context->state[0] = 0x67452301; + context->state[1] = 0xEFCDAB89; + context->state[2] = 0x98BADCFE; + context->state[3] = 0x10325476; + context->state[4] = 0xC3D2E1F0; + context->count[0] = context->count[1] = 0; +} + +void SHA1_Update(SHA1_CTX* context, const unsigned char* data, uint32_t len) { + uint32_t i, j; + j = context->count[0]; + if ((context->count[0] += len << 3) < j) context->count[1]++; + context->count[1] += (len >> 29); + j = (j >> 3) & 63; + if ((j + len) > 63) { + memcpy(&context->buffer[j], data, (i = 64 - j)); + SHA1_Transform(context->state, context->buffer); + for (; i + 63 < len; i += 64) { + SHA1_Transform(context->state, &data[i]); + } + j = 0; + } else { + i = 0; + } + memcpy(&context->buffer[j], &data[i], len - i); +} + +void SHA1_Final(unsigned char digest[20], SHA1_CTX* context) { + uint32_t i; + unsigned char finalcount[8]; + for (i = 0; i < 8; i++) { + finalcount[i] = (unsigned char)((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255); + } + SHA1_Update(context, (unsigned char*)"\x80", 1); + while ((context->count[0] & 504) != 448) { + SHA1_Update(context, (unsigned char*)"\0", 1); + } + SHA1_Update(context, finalcount, 8); + for (i = 0; i < 20; i++) { + digest[i] = (unsigned char)((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255); + } +} + +// --- Base64 Implementation --- +char* base64_encode(const unsigned char *data, size_t input_length) { + const char b64_table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + size_t output_length = 4 * ((input_length + 2) / 3); + char *encoded_data = malloc(output_length + 1); + if (encoded_data == NULL) return NULL; + + for (size_t i = 0, j = 0; i < input_length;) { + uint32_t octet_a = i < input_length ? data[i++] : 0; + uint32_t octet_b = i < input_length ? data[i++] : 0; + uint32_t octet_c = i < input_length ? 
data[i++] : 0; + uint32_t triple = (octet_a << 16) + (octet_b << 8) + octet_c; + + encoded_data[j++] = b64_table[(triple >> 18) & 0x3F]; + encoded_data[j++] = b64_table[(triple >> 12) & 0x3F]; + encoded_data[j++] = b64_table[(triple >> 6) & 0x3F]; + encoded_data[j++] = b64_table[triple & 0x3F]; + } + + for (size_t i = 0; i < (3 - input_length % 3) % 3; i++) { + encoded_data[output_length - 1 - i] = '='; + } + encoded_data[output_length] = '\0'; + return encoded_data; +} + +// ============================================================================= +// END: Clean SHA-1 and Base64 Implementations +// ============================================================================= + + +// --- Utility Functions --- +double get_time_double() { + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + return ts.tv_sec + ts.tv_nsec / 1e9; +} + +void epoll_ctl_mod(int fd, uint32_t events, void* ptr) { + struct epoll_event ev; + ev.events = events; + ev.data.ptr = ptr; + epoll_ctl(epoll_fd, EPOLL_CTL_MOD, fd, &ev); +} + +void close_client(Client* client) { + if (client->state != STATE_CLOSED) { + epoll_ctl(epoll_fd, EPOLL_CTL_DEL, client->fd, NULL); + close(client->fd); + client->state = STATE_CLOSED; + } +} + +// --- WebSocket Core Functions --- +size_t create_ws_frame(const char* payload, size_t payload_len, char* out_buffer) { + size_t frame_len = 2 + payload_len + 4; // Header + Mask + Payload + if (payload_len > 125) frame_len += 2; // For 16-bit length + + out_buffer[0] = 0x81; // FIN + Text Frame + if (payload_len <= 125) { + out_buffer[1] = 0x80 | payload_len; + } else { + out_buffer[1] = 0x80 | 126; + *(uint16_t*)(out_buffer + 2) = htons(payload_len); + } + + size_t header_len = (payload_len <= 125) ? 2 : 4; + uint32_t mask = rand(); + *(uint32_t*)(out_buffer + header_len) = mask; + + uint8_t* mask_bytes = (uint8_t*)&mask; + for(size_t i = 0; i < payload_len; ++i) { + out_buffer[header_len + 4 + i] = payload[i] ^ mask_bytes[i % 4]; + } + return header_len + 4 + payload_len; +} + +void send_handshake(Client* client) { + unsigned char key_bytes[16]; + for (int i = 0; i < 16; i++) key_bytes[i] = rand() % 256; + char* b64_key = base64_encode(key_bytes, 16); + + client->write_len = snprintf(client->write_buf, RW_BUFFER_SIZE, + "GET / HTTP/1.1\r\n" + "Host: %s:%d\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: %s\r\n" + "Sec-WebSocket-Version: 13\r\n\r\n", + HOST, PORT, b64_key); + + client->state = STATE_HANDSHAKE_SEND; + epoll_ctl_mod(client->fd, EPOLLIN | EPOLLOUT | EPOLLET, client); + free(b64_key); +} + +// --- Event Handlers --- +void handle_write(Client* client) { + ssize_t sent = send(client->fd, client->write_buf + client->write_pos, client->write_len - client->write_pos, 0); + if (sent < 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK) close_client(client); + return; + } + client->write_pos += sent; + + if (client->write_pos >= client->write_len) { + client->write_pos = 0; + client->write_len = 0; + epoll_ctl_mod(client->fd, EPOLLIN | EPOLLET, client); // Done writing, wait for reads + + if (client->state == STATE_HANDSHAKE_SEND) { + client->state = STATE_HANDSHAKE_RECV; + } else if (client->state == STATE_SUBSCRIBING) { + client->state = STATE_RUNNING; + subscriber_setup_count++; + if (subscriber_setup_count == NUM_SUBSCRIBERS) { + printf("✅ All subscribers are connected and subscribed. 
Starting publishers...\n"); + all_subscribed = 1; + } + } + } +} + +void handle_read(Client* client) { + ssize_t n = read(client->fd, client->read_buf + client->read_len, RW_BUFFER_SIZE - client->read_len); + if (n <= 0) { + if (n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) perror("read error"); + close_client(client); + return; + } + client->read_len += n; + + if (client->state == STATE_HANDSHAKE_RECV) { + if (strstr(client->read_buf, "\r\n\r\n")) { + if (strstr(client->read_buf, " 101 ") == NULL) { + fprintf(stderr, "Handshake failed for fd %d\n", client->fd); + close_client(client); + return; + } + client->read_len = 0; // Handshake complete, clear buffer + + if (client->type == CLIENT_SUBSCRIBER) { + const char* channel = CHANNELS[rand() % NUM_CHANNELS]; + char sub_msg[128]; + int msg_len = snprintf(sub_msg, sizeof(sub_msg), "sub %s", channel); + client->write_len = create_ws_frame(sub_msg, msg_len, client->write_buf); + client->state = STATE_SUBSCRIBING; + epoll_ctl_mod(client->fd, EPOLLIN | EPOLLOUT | EPOLLET, client); + } else { // Publisher + client->state = STATE_RUNNING; + } + } + } else if (client->state == STATE_RUNNING && client->type == CLIENT_SUBSCRIBER) { + // Simple WebSocket frame parsing for this specific test case + while (client->read_len >= 2) { + uint64_t payload_len = client->read_buf[1] & 0x7F; + size_t header_len = 2; + if (payload_len == 126) { + if (client->read_len < 4) break; + payload_len = ntohs(*(uint16_t*)(client->read_buf + 2)); + header_len = 4; + } else if (payload_len == 127) { + // Not expected for this test, would require 64-bit length handling + close_client(client); + break; + } + + if (client->read_len >= header_len + payload_len) { + char* payload = client->read_buf + header_len; + double sent_time = atof(payload); + if (sent_time > 0) { + double latency = get_time_double() - sent_time; + if (latencies_count < MAX_LATENCIES) { + latencies[latencies_count++] = latency; + } + messages_received++; + } + + size_t frame_size = header_len + payload_len; + memmove(client->read_buf, client->read_buf + frame_size, client->read_len - frame_size); + client->read_len -= frame_size; + } else { + break; // Incomplete frame + } + } + } +} + +// --- Statistics --- +int compare_doubles(const void* a, const void* b) { + double da = *(const double*)a; + double db = *(const double*)b; + if (da < db) return -1; + if (da > db) return 1; + return 0; +} + +void print_report() { + printf("\n" + "================================================================================\n"); + printf("%s\n", " PERFORMANCE REPORT "); + printf("================================================================================\n"); + + if (latencies_count == 0) { + printf("No messages were received. Cannot generate a report. Is the server running?\n"); + return; + } + + uint64_t message_loss = (messages_sent > messages_received) ? (messages_sent - messages_received) : 0; + double loss_rate = (messages_sent > 0) ? 
((double)message_loss / messages_sent * 100.0) : 0; + double throughput = (double)messages_received / TEST_DURATION_S; + + printf("Test Duration: %d seconds\n", TEST_DURATION_S); + printf("Total Messages Sent: %lu\n", messages_sent); + printf("Total Messages Rcvd: %lu\n", messages_received); + printf("Message Loss: %lu (%.2f%%)\n", message_loss, loss_rate); + printf("Actual Throughput: %.2f msg/sec\n", throughput); + printf("--------------------------------------------------------------------------------\n"); + + qsort(latencies, latencies_count, sizeof(double), compare_doubles); + + double sum = 0; + for (size_t i = 0; i < latencies_count; ++i) sum += latencies[i]; + + printf("Latency Statistics (ms):\n"); + printf(" Average: %.4f ms\n", (sum / latencies_count) * 1000.0); + printf(" Min: %.4f ms\n", latencies[0] * 1000.0); + printf(" Max: %.4f ms\n", latencies[latencies_count - 1] * 1000.0); + printf(" Median (p50): %.4f ms\n", latencies[(size_t)(latencies_count * 0.50)] * 1000.0); + printf(" 95th Percentile: %.4f ms\n", latencies[(size_t)(latencies_count * 0.95)] * 1000.0); + printf(" 99th Percentile: %.4f ms\n", latencies[(size_t)(latencies_count * 0.99)] * 1000.0); + printf("================================================================================\n"); +} + +// --- Main Function --- +int main() { + srand(time(NULL)); + latencies = malloc(sizeof(double) * MAX_LATENCIES); + if (!latencies) { + perror("malloc latencies"); + return 1; + } + + printf("Starting WebSocket Pub/Sub Load Test...\n"); + printf("Simulating %d subscribers and %d publishers.\n", NUM_SUBSCRIBERS, NUM_PUBLISHERS); + printf("Publishing at ~%d msg/sec for %d seconds.\n", NUM_PUBLISHERS * MESSAGES_PER_SECOND_PER_PUBLISHER, TEST_DURATION_S); + printf("--------------------------------------------------------------------------------\n"); + + epoll_fd = epoll_create1(0); + if (epoll_fd == -1) { + perror("epoll_create1"); + free(latencies); + return 1; + } + + // *** FIX: Allocate clients on the heap, not the stack *** + Client* clients = malloc(sizeof(Client) * TOTAL_CLIENTS); + if (!clients) { + perror("malloc clients"); + free(latencies); + close(epoll_fd); + return 1; + } + + struct sockaddr_in server_addr; + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + inet_pton(AF_INET, HOST, &server_addr.sin_addr); + + for (int i = 0; i < TOTAL_CLIENTS; ++i) { + clients[i].fd = socket(AF_INET, SOCK_STREAM, 0); + fcntl(clients[i].fd, F_SETFL, O_NONBLOCK); + + clients[i].type = (i < NUM_SUBSCRIBERS) ? 
CLIENT_SUBSCRIBER : CLIENT_PUBLISHER; + clients[i].state = STATE_CONNECTING; + clients[i].read_len = 0; + clients[i].write_len = 0; + clients[i].write_pos = 0; + clients[i].next_send_time = 0; + + connect(clients[i].fd, (struct sockaddr*)&server_addr, sizeof(server_addr)); + + struct epoll_event ev; + ev.events = EPOLLIN | EPOLLOUT | EPOLLET; + ev.data.ptr = &clients[i]; + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, clients[i].fd, &ev) == -1) { + perror("epoll_ctl: add"); + free(latencies); + free(clients); + close(epoll_fd); + return 1; + } + } + + struct epoll_event events[MAX_EVENTS]; + double start_time = get_time_double(); + double end_time = start_time + TEST_DURATION_S; + + while (get_time_double() < end_time) { + int n = epoll_wait(epoll_fd, events, MAX_EVENTS, 10); // 10ms timeout + double now = get_time_double(); + + for (int i = 0; i < n; ++i) { + Client* client = (Client*)events[i].data.ptr; + if (events[i].events & (EPOLLERR | EPOLLHUP)) { + close_client(client); + continue; + } + + if (client->state == STATE_CONNECTING && (events[i].events & EPOLLOUT)) { + int result; + socklen_t result_len = sizeof(result); + getsockopt(client->fd, SOL_SOCKET, SO_ERROR, &result, &result_len); + if (result == 0) { + send_handshake(client); + } else { + close_client(client); + } + } else { + if (events[i].events & EPOLLIN) handle_read(client); + if (events[i].events & EPOLLOUT) handle_write(client); + } + } + + // Publisher logic + if (all_subscribed) { + for (int i = NUM_SUBSCRIBERS; i < TOTAL_CLIENTS; ++i) { + Client* client = &clients[i]; + if (client->state == STATE_RUNNING && client->write_len == 0 && now >= client->next_send_time) { + const char* channel = CHANNELS[rand() % NUM_CHANNELS]; + char message[256]; + int msg_len = snprintf(message, sizeof(message), "%.6f:Hello from publisher %d on channel %s", now, i - NUM_SUBSCRIBERS, channel); + + char pub_msg[384]; + int pub_msg_len = snprintf(pub_msg, sizeof(pub_msg), "pub %s %s", channel, message); + + client->write_len = create_ws_frame(pub_msg, pub_msg_len, client->write_buf); + messages_sent++; + + client->next_send_time = now + (1.0 / MESSAGES_PER_SECOND_PER_PUBLISHER); + epoll_ctl_mod(client->fd, EPOLLOUT | EPOLLIN | EPOLLET, client); + } + } + } + } + + printf("\nTest duration finished. Shutting down clients...\n"); + for (int i = 0; i < TOTAL_CLIENTS; ++i) { + close_client(&clients[i]); + } + + print_report(); + + free(clients); + free(latencies); + close(epoll_fd); + + return 0; +} diff --git a/rpubsub b/rpubsub new file mode 100755 index 0000000..0c49139 Binary files /dev/null and b/rpubsub differ diff --git a/rpubsub.c b/rpubsub.c new file mode 100644 index 0000000..9eb6dc3 --- /dev/null +++ b/rpubsub.c @@ -0,0 +1,766 @@ +// ============================================================================= +// High-Performance Multi-Threaded WebSocket Pub/Sub Server +// +// Author: Gemini +// Date: September 27, 2025 +// +// Key Optimizations: +// 1. Worker Thread Pool: Offloads message fan-out from the I/O thread. +// 2. Lock-Free Task Queue: Efficiently passes tasks to workers. +// 3. Decoupled I/O: Workers queue data; the I/O thread sends it. +// 4. Circular Ring Buffers: Simplified and efficient client write buffers. +// 5. Thread-Safe Epoll Control: Uses a pipe to signal I/O thread safely. +// 6. Optimized Data Structures: Faster channel lookups and client removal. 
+// =============================================================================
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <signal.h>
+#include <pthread.h>
+#include <stdatomic.h>
+#include <stdint.h>
+#include <time.h>
+#include <sys/socket.h>
+#include <sys/epoll.h>
+#include <sys/uio.h>
+#include <netinet/in.h>
+#include <stdbool.h> // <--- FIX: Added for bool type
+#include <netinet/tcp.h>
+#include <arpa/inet.h>
+#include <sched.h>
+#include <string.h> // for strcasestr
+
+// --- Server Configuration ---
+#define PORT 8080
+#define MAX_CLIENTS 65536
+#define MAX_EVENTS 2048
+#define READ_BUFFER_SIZE 8192
+#define WRITE_BUFFER_SIZE 262144 // 256KB per-client write buffer
+#define MAX_FRAME_SIZE 65536 // 64KB max incoming frame
+#define MAX_CHANNELS 1024
+#define MAX_SUBSCRIPTIONS 32
+#define WORKER_THREADS 4 // Number of threads for broadcasting
+#define TASK_QUEUE_SIZE 16384
+
+#define WEBSOCKET_KEY_MAGIC "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
+#define LISTEN_BACKLOG 32768
+
+// Forward declarations
+struct ChannelNode;
+
+// --- Data Structures ---
+
+// Circular write buffer for non-blocking sends
+typedef struct {
+    uint8_t* data;
+    size_t capacity;
+    atomic_size_t head;
+    atomic_size_t tail;
+    pthread_spinlock_t lock; // Protects against concurrent writes from workers
+} RingBuffer;
+
+typedef enum {
+    STATE_HANDSHAKE,
+    STATE_CONNECTED,
+    STATE_CLOSED
+} ClientState;
+
+typedef struct {
+    ClientState state;
+    RingBuffer write_buf;
+    uint8_t* read_buf;
+    size_t read_len;
+    struct ChannelNode* subscriptions[MAX_SUBSCRIPTIONS];
+    int sub_count;
+    atomic_char write_registered; // <--- FIX: Changed from atomic_bool to atomic_char
+} Client;
+
+// Channel for pub/sub
+typedef struct ChannelNode {
+    char name[64];
+    int* subscribers; // Array of client FDs
+    int sub_count;
+    int sub_capacity;
+    pthread_rwlock_t lock;
+    struct ChannelNode* next;
+} ChannelNode;
+
+typedef struct {
+    ChannelNode* buckets[256]; // Simple hash table for channels
+} ChannelTable;
+
+// Task for worker threads to execute broadcasts
+typedef struct {
+    struct ChannelNode* channel;
+    uint8_t* frame_data;
+    size_t frame_len;
+} BroadcastTask;
+
+// Lock-free Single-Producer, Multi-Consumer queue for tasks
+typedef struct {
+    BroadcastTask* tasks;
+    atomic_size_t head;
+    atomic_size_t tail;
+    size_t capacity;
+} SPMCQueue;
+
+// --- Globals ---
+Client* clients;
+ChannelTable channels;
+int epoll_fd;
+int notify_pipe[2]; // Pipe for workers to signal main thread
+SPMCQueue task_queue;
+pthread_t worker_threads[WORKER_THREADS];
+volatile sig_atomic_t running = 1;
+atomic_int active_connections = 0;
+
+// --- Function Prototypes ---
+void remove_client(int fd, int gracefully);
+void arm_write(int fd);
+
+// --- Utils ---
+void handle_sigint(int sig) { running = 0; }
+
+static inline uint64_t get_ns_time() {
+    struct timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    return ts.tv_sec * 1000000000ULL + ts.tv_nsec;
+}
+
+// --- Ring Buffer Implementation ---
+void ring_buffer_init(RingBuffer* rb) {
+    rb->data = malloc(WRITE_BUFFER_SIZE);
+    rb->capacity = WRITE_BUFFER_SIZE;
+    atomic_init(&rb->head, 0);
+    atomic_init(&rb->tail, 0);
+    pthread_spin_init(&rb->lock, PTHREAD_PROCESS_PRIVATE);
+}
+
+void ring_buffer_free(RingBuffer* rb) {
+    if (rb->data) free(rb->data);
+    pthread_spin_destroy(&rb->lock);
+}
+
+// Tries to write data to the buffer. Used by worker threads.
+int ring_buffer_write(RingBuffer* rb, const uint8_t* data, size_t len) { + pthread_spin_lock(&rb->lock); + size_t head = atomic_load_explicit(&rb->head, memory_order_relaxed); + size_t tail = atomic_load_explicit(&rb->tail, memory_order_relaxed); + size_t free_space = rb->capacity - (head - tail); + + if (len > free_space) { + pthread_spin_unlock(&rb->lock); + return 0; // Not enough space + } + + size_t head_idx = head % rb->capacity; + size_t to_end = rb->capacity - head_idx; + if (len <= to_end) { + memcpy(rb->data + head_idx, data, len); + } else { + memcpy(rb->data + head_idx, data, to_end); + memcpy(rb->data, data + to_end, len - to_end); + } + atomic_store_explicit(&rb->head, head + len, memory_order_release); + pthread_spin_unlock(&rb->lock); + return 1; +} + +// --- Task Queue --- +void queue_init(SPMCQueue* q) { + q->tasks = calloc(TASK_QUEUE_SIZE, sizeof(BroadcastTask)); + atomic_init(&q->head, 0); + atomic_init(&q->tail, 0); + q->capacity = TASK_QUEUE_SIZE; +} + +// Used by main I/O thread (single producer) +int queue_push(SPMCQueue* q, BroadcastTask task) { + size_t head = atomic_load_explicit(&q->head, memory_order_relaxed); + size_t tail = atomic_load_explicit(&q->tail, memory_order_acquire); + if (head - tail >= q->capacity) { + return 0; // Queue full + } + q->tasks[head % q->capacity] = task; + atomic_store_explicit(&q->head, head + 1, memory_order_release); + return 1; +} + +// Used by worker threads (multi-consumer) +int queue_pop(SPMCQueue* q, BroadcastTask* task) { + while (1) { + size_t tail = atomic_load_explicit(&q->tail, memory_order_relaxed); + size_t head = atomic_load_explicit(&q->head, memory_order_acquire); + if (tail >= head) { + return 0; // Queue empty + } + *task = q->tasks[tail % q->capacity]; + if (atomic_compare_exchange_weak_explicit(&q->tail, &tail, tail + 1, memory_order_release, memory_order_relaxed)) { + return 1; + } + } +} + +// --- SHA-1 Implementation --- +typedef struct { uint32_t s[5]; uint32_t c[2]; uint8_t b[64]; } SHA1_CTX; +#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits)))) +#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) | (rol(block->l[i],8)&0x00FF00FF)) +#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15]^block->l[(i+2)&15]^block->l[i&15],1)) +#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30); +#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30); +#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30); + +void SHA1_Transform(uint32_t s[5], const uint8_t buffer[64]) { + uint32_t a, b, c, d, e; + typedef union { uint8_t c[64]; uint32_t l[16]; } CHAR64LONG16; + CHAR64LONG16* block = (CHAR64LONG16*)buffer; + a = s[0]; b = s[1]; c = s[2]; d = s[3]; e = s[4]; + R0(a,b,c,d,e, 0); R0(e,a,b,c,d, 1); R0(d,e,a,b,c, 2); R0(c,d,e,a,b, 3); + R0(b,c,d,e,a, 4); R0(a,b,c,d,e, 5); R0(e,a,b,c,d, 6); R0(d,e,a,b,c, 7); + R0(c,d,e,a,b, 8); R0(b,c,d,e,a, 9); R0(a,b,c,d,e,10); R0(e,a,b,c,d,11); + R0(d,e,a,b,c,12); R0(c,d,e,a,b,13); R0(b,c,d,e,a,14); R0(a,b,c,d,e,15); + R1(e,a,b,c,d,16); R1(d,e,a,b,c,17); R1(c,d,e,a,b,18); R1(b,c,d,e,a,19); + R2(a,b,c,d,e,20); R2(e,a,b,c,d,21); R2(d,e,a,b,c,22); R2(c,d,e,a,b,23); + R2(b,c,d,e,a,24); R2(a,b,c,d,e,25); R2(e,a,b,c,d,26); R2(d,e,a,b,c,27); + R2(c,d,e,a,b,28); R2(b,c,d,e,a,29); R2(a,b,c,d,e,30); R2(e,a,b,c,d,31); 
+ R2(d,e,a,b,c,32); R2(c,d,e,a,b,33); R2(b,c,d,e,a,34); R2(a,b,c,d,e,35); + R2(e,a,b,c,d,36); R2(d,e,a,b,c,37); R2(c,d,e,a,b,38); R2(b,c,d,e,a,39); + R3(a,b,c,d,e,40); R3(e,a,b,c,d,41); R3(d,e,a,b,c,42); R3(c,d,e,a,b,43); + R3(b,c,d,e,a,44); R3(a,b,c,d,e,45); R3(e,a,b,c,d,46); R3(d,e,a,b,c,47); + R3(c,d,e,a,b,48); R3(b,c,d,e,a,49); R3(a,b,c,d,e,50); R3(e,a,b,c,d,51); + R3(d,e,a,b,c,52); R3(c,d,e,a,b,53); R3(b,c,d,e,a,54); R3(a,b,c,d,e,55); + R3(e,a,b,c,d,56); R3(d,e,a,b,c,57); R3(c,d,e,a,b,58); R3(b,c,d,e,a,59); + R4(a,b,c,d,e,60); R4(e,a,b,c,d,61); R4(d,e,a,b,c,62); R4(c,d,e,a,b,63); + R4(b,c,d,e,a,64); R4(a,b,c,d,e,65); R4(e,a,b,c,d,66); R4(d,e,a,b,c,67); + R4(c,d,e,a,b,68); R4(b,c,d,e,a,69); R4(a,b,c,d,e,70); R4(e,a,b,c,d,71); + R4(d,e,a,b,c,72); R4(c,d,e,a,b,73); R4(b,c,d,e,a,74); R4(a,b,c,d,e,75); + R4(e,a,b,c,d,76); R4(d,e,a,b,c,77); R4(c,d,e,a,b,78); R4(b,c,d,e,a,79); + s[0] += a; s[1] += b; s[2] += c; s[3] += d; s[4] += e; +} + +void SHA1_Init(SHA1_CTX* c) { + c->s[0] = 0x67452301; c->s[1] = 0xEFCDAB89; c->s[2] = 0x98BADCFE; + c->s[3] = 0x10325476; c->s[4] = 0xC3D2E1F0; + c->c[0] = c->c[1] = 0; +} + +void SHA1_Update(SHA1_CTX* c, const uint8_t* d, uint32_t l) { + uint32_t i, j; j = (c->c[0] >> 3) & 63; + if ((c->c[0] += l << 3) < (l << 3)) c->c[1]++; c->c[1] += (l >> 29); + if ((j + l) > 63) { + memcpy(&c->b[j], d, (i = 64-j)); + SHA1_Transform(c->s, c->b); + for (; i + 63 < l; i += 64) SHA1_Transform(c->s, &d[i]); + j = 0; + } else i = 0; + memcpy(&c->b[j], &d[i], l - i); +} + +void SHA1_Final(uint8_t d[20], SHA1_CTX* c) { + uint32_t i; uint8_t fc[8]; + for (i = 0; i < 8; i++) + fc[i] = (uint8_t)((c->c[(i >= 4 ? 0 : 1)] >> ((3-(i & 3)) * 8)) & 255); + SHA1_Update(c, (uint8_t*)"\200", 1); + while ((c->c[0] & 504) != 448) SHA1_Update(c, (uint8_t*)"\0", 1); + SHA1_Update(c, fc, 8); + for (i = 0; i < 20; i++) + d[i] = (uint8_t)((c->s[i>>2] >> ((3-(i & 3)) * 8)) & 255); +} + +// --- Base64 Implementation --- +const char b64_table[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +char* base64_encode(const uint8_t* data, size_t len) { + size_t out_len = 4 * ((len + 2) / 3); + char* out = malloc(out_len + 1); + if (!out) return NULL; + for (size_t i = 0, j = 0; i < len;) { + uint32_t a = i < len ? data[i++] : 0; + uint32_t b = i < len ? data[i++] : 0; + uint32_t c = i < len ? 
data[i++] : 0; + uint32_t t = (a << 16) + (b << 8) + c; + out[j++] = b64_table[(t >> 18) & 0x3F]; + out[j++] = b64_table[(t >> 12) & 0x3F]; + out[j++] = b64_table[(t >> 6) & 0x3F]; + out[j++] = b64_table[t & 0x3F]; + } + for (size_t i = 0; i < (3 - len % 3) % 3; i++) + out[out_len - 1 - i] = '='; + out[out_len] = '\0'; + return out; +} + +// --- Channel Management --- +uint8_t hash_channel(const char* name) { + uint8_t hash = 53; // A prime starting number + while (*name) hash = (hash * 31) + *name++; // Another prime multiplier + return hash; +} + +ChannelNode* find_or_create_channel(const char* name) { + uint8_t h = hash_channel(name); + ChannelNode* node = channels.buckets[h]; + while (node) { + if (strcmp(node->name, name) == 0) return node; + node = node->next; + } + node = calloc(1, sizeof(ChannelNode)); + strncpy(node->name, name, 63); + node->sub_capacity = 8; + node->subscribers = malloc(sizeof(int) * node->sub_capacity); + pthread_rwlock_init(&node->lock, NULL); + node->next = channels.buckets[h]; + channels.buckets[h] = node; + return node; +} + +void add_subscriber(ChannelNode* ch, int fd) { + pthread_rwlock_wrlock(&ch->lock); + if (ch->sub_count >= ch->sub_capacity) { + ch->sub_capacity *= 2; + ch->subscribers = realloc(ch->subscribers, sizeof(int) * ch->sub_capacity); + } + ch->subscribers[ch->sub_count++] = fd; + pthread_rwlock_unlock(&ch->lock); + + Client* c = &clients[fd]; + if (c->sub_count < MAX_SUBSCRIPTIONS) { + c->subscriptions[c->sub_count++] = ch; + } +} + +void remove_subscriber(ChannelNode* ch, int fd) { + pthread_rwlock_wrlock(&ch->lock); + for (int i = 0; i < ch->sub_count; i++) { + if (ch->subscribers[i] == fd) { + ch->subscribers[i] = ch->subscribers[--ch->sub_count]; + break; + } + } + pthread_rwlock_unlock(&ch->lock); +} + +// --- WebSocket Logic --- +void handle_handshake(int fd) { + Client* c = &clients[fd]; + char* req = (char*)c->read_buf; + if (!strstr(req, "\r\n\r\n")) return; + + char* key_start = strcasestr(req, "Sec-WebSocket-Key: "); + if (!key_start) { remove_client(fd, 0); return; } + key_start += 19; + char* key_end = strchr(key_start, '\r'); + if (!key_end) { remove_client(fd, 0); return; } + + char key[256]; + size_t key_len = key_end - key_start; + memcpy(key, key_start, key_len); + memcpy(key + key_len, WEBSOCKET_KEY_MAGIC, strlen(WEBSOCKET_KEY_MAGIC)); + key[key_len + strlen(WEBSOCKET_KEY_MAGIC)] = '\0'; + + uint8_t sha1[20]; + SHA1_CTX ctx; + SHA1_Init(&ctx); + SHA1_Update(&ctx, (uint8_t*)key, strlen(key)); + SHA1_Final(sha1, &ctx); + + char* accept = base64_encode(sha1, 20); + char response[256]; + int len = snprintf(response, sizeof(response), + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: %s\r\n\r\n", accept); + free(accept); + + if (send(fd, response, len, MSG_NOSIGNAL | MSG_DONTWAIT) == len) { + c->state = STATE_CONNECTED; + c->read_len = 0; // Clear handshake data + atomic_fetch_add(&active_connections, 1); + } else { + remove_client(fd, 0); + } +} + +void process_ws_message(int fd, uint8_t* payload, size_t len) { + payload[len] = '\0'; // Ensure null termination for string functions + char cmd[16], channel_name[64]; + + if (sscanf((char*)payload, "%15s %63s", cmd, channel_name) < 2) return; + + if (strcmp(cmd, "sub") == 0) { + ChannelNode* ch = find_or_create_channel(channel_name); + if (ch) add_subscriber(ch, fd); + } else if (strcmp(cmd, "pub") == 0) { + char* msg_start = (char*)payload + strlen(cmd) + 1 + strlen(channel_name) + 1; + if (msg_start >= 
(char*)payload + len) return; + size_t msg_len = len - (msg_start - (char*)payload); + + ChannelNode* ch = find_or_create_channel(channel_name); + if (!ch || ch->sub_count == 0) return; + + // Build WebSocket frame header once + uint8_t header[10]; + int header_len = 2; + header[0] = 0x81; // FIN + Text Frame + if (msg_len < 126) { + header[1] = msg_len; + } else { + header[1] = 126; + header[2] = (msg_len >> 8) & 0xFF; + header[3] = msg_len & 0xFF; + header_len = 4; + } + + // Allocate a single buffer for the entire frame + size_t frame_len = header_len + msg_len; + uint8_t* frame_data = malloc(frame_len); + if (!frame_data) return; + memcpy(frame_data, header, header_len); + memcpy(frame_data + header_len, msg_start, msg_len); + + BroadcastTask task = { .channel = ch, .frame_data = frame_data, .frame_len = frame_len }; + if (!queue_push(&task_queue, task)) { + // If queue is full, drop the message and free memory + free(frame_data); + } + } +} + +void handle_ws_data(int fd) { + Client* c = &clients[fd]; + uint8_t* buf = c->read_buf; + size_t len = c->read_len; + + while (len >= 2) { + uint64_t payload_len = buf[1] & 0x7F; + size_t header_len = 2; + if (payload_len == 126) { + if (len < 4) break; + payload_len = ((uint64_t)buf[2] << 8) | buf[3]; + header_len = 4; + } else if (payload_len == 127) { + if (len < 10) break; + payload_len = __builtin_bswap64(*(uint64_t*)(buf + 2)); + header_len = 10; + } + + if (payload_len > MAX_FRAME_SIZE) { remove_client(fd, 0); return; } + + size_t mask_offset = header_len; + size_t payload_offset = header_len + 4; + size_t total_frame_len = payload_offset + payload_len; + + if (len < total_frame_len) break; // Incomplete frame + + uint32_t* mask = (uint32_t*)(buf + mask_offset); + uint8_t* payload = buf + payload_offset; + + // Unmask payload (optimized for 4-byte chunks) + for (size_t i = 0; i < payload_len / 4; i++) { + ((uint32_t*)payload)[i] ^= *mask; + } + for (size_t i = payload_len - (payload_len % 4); i < payload_len; i++) { + payload[i] ^= ((uint8_t*)mask)[i % 4]; + } + + uint8_t opcode = buf[0] & 0x0F; + if (opcode == 0x01) { // Text + process_ws_message(fd, payload, payload_len); + } else if (opcode == 0x08) { // Close + remove_client(fd, 1); + return; + } else if (opcode == 0x09) { // Ping + uint8_t frame[12]; + frame[0] = 0x8A; // Pong frame + memcpy(frame + 2, payload, payload_len < 10 ? payload_len : 10); + ring_buffer_write(&c->write_buf, frame, 2 + payload_len); + arm_write(fd); + } + + memmove(buf, buf + total_frame_len, len - total_frame_len); + len -= total_frame_len; + } + c->read_len = len; +} + +// --- Network Event Handlers --- +void handle_read(int fd) { + Client* c = &clients[fd]; + ssize_t n = recv(fd, c->read_buf + c->read_len, READ_BUFFER_SIZE - c->read_len, MSG_DONTWAIT); + + if (n > 0) { + c->read_len += n; + if (c->state == STATE_HANDSHAKE) { + handle_handshake(fd); + } else if (c->state == STATE_CONNECTED) { + handle_ws_data(fd); + } + } else if (n == 0 || (errno != EAGAIN && errno != EWOULDBLOCK)) { + remove_client(fd, 0); + } +} + +void handle_write(int fd) { + Client* c = &clients[fd]; + RingBuffer* rb = &c->write_buf; + + size_t tail = atomic_load_explicit(&rb->tail, memory_order_acquire); + size_t head = atomic_load_explicit(&rb->head, memory_order_acquire); + if (tail == head) return; // Nothing to write + + size_t tail_idx = tail % rb->capacity; + size_t head_idx = head % rb->capacity; + size_t len = (head > tail) ? 
(head - tail) : (rb->capacity - tail_idx + head_idx); + + ssize_t sent; + if (head_idx > tail_idx || tail_idx == head_idx) { // Data does not wrap or buffer is full but appears as non-wrapping + sent = send(fd, rb->data + tail_idx, len, MSG_NOSIGNAL | MSG_DONTWAIT); + } else { // Wraps around + struct iovec iov[2]; + iov[0].iov_base = rb->data + tail_idx; + iov[0].iov_len = rb->capacity - tail_idx; + iov[1].iov_base = rb->data; + iov[1].iov_len = head_idx; + sent = writev(fd, iov, 2); + } + + if (sent > 0) { + atomic_store_explicit(&rb->tail, tail + sent, memory_order_release); + } else if (errno != EAGAIN && errno != EWOULDBLOCK) { + remove_client(fd, 0); + return; + } + + // If buffer is not empty, we need to keep writing + if (atomic_load_explicit(&rb->tail, memory_order_relaxed) != atomic_load_explicit(&rb->head, memory_order_relaxed)) { + arm_write(fd); + } else { + atomic_store(&c->write_registered, 0); + } +} + +void handle_accept(int server_fd) { + while (1) { + int fd = accept4(server_fd, NULL, NULL, SOCK_NONBLOCK); + if (fd < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) break; + perror("accept4"); + continue; + } + if (fd >= MAX_CLIENTS) { close(fd); continue; } + + int opt = 1; + setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); + + Client* c = &clients[fd]; + memset(c, 0, sizeof(Client)); + c->state = STATE_HANDSHAKE; + c->read_buf = malloc(READ_BUFFER_SIZE); + ring_buffer_init(&c->write_buf); + atomic_init(&c->write_registered, 0); + + struct epoll_event ev = { .events = EPOLLIN | EPOLLET | EPOLLRDHUP, .data.fd = fd }; + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &ev) < 0) { + perror("epoll_ctl add client"); + free(c->read_buf); + ring_buffer_free(&c->write_buf); + close(fd); + } + } +} + +void remove_client(int fd, int gracefully) { + if (fd < 0 || fd >= MAX_CLIENTS || clients[fd].state == STATE_CLOSED) return; + + Client* c = &clients[fd]; + if (c->state == STATE_CONNECTED) { + atomic_fetch_sub(&active_connections, 1); + } + c->state = STATE_CLOSED; + + // Unsubscribe from channels efficiently + for (int i = 0; i < c->sub_count; i++) { + if (c->subscriptions[i]) { + remove_subscriber(c->subscriptions[i], fd); + } + } + + epoll_ctl(epoll_fd, EPOLL_CTL_DEL, fd, NULL); + close(fd); + free(c->read_buf); + ring_buffer_free(&c->write_buf); +} + +// --- Worker Thread Logic --- +void execute_broadcast(BroadcastTask* task) { + ChannelNode* ch = task->channel; + pthread_rwlock_rdlock(&ch->lock); + + // Create a temporary copy to avoid holding the lock for too long + int num_subs = ch->sub_count; + if (num_subs == 0) { + pthread_rwlock_unlock(&ch->lock); + return; + } + + int* subs_copy = malloc(sizeof(int) * num_subs); + if (subs_copy) { + memcpy(subs_copy, ch->subscribers, sizeof(int) * num_subs); + } + pthread_rwlock_unlock(&ch->lock); + + if (!subs_copy) return; + + for (int i = 0; i < num_subs; i++) { + int fd = subs_copy[i]; + if (fd < 0 || fd >= MAX_CLIENTS) continue; + + Client* c = &clients[fd]; + if (c->state != STATE_CONNECTED) continue; + + // Check if write buffer was empty before adding data + size_t head = atomic_load_explicit(&c->write_buf.head, memory_order_relaxed); + size_t tail = atomic_load_explicit(&c->write_buf.tail, memory_order_relaxed); + int was_empty = (head == tail); + + if (ring_buffer_write(&c->write_buf, task->frame_data, task->frame_len)) { + // If it was empty, we need to tell the I/O thread to arm EPOLLOUT + if (was_empty) { + arm_write(fd); + } + } + } + free(subs_copy); +} + +void* worker_main(void* arg) { + int id = *(int*)arg; + 
cpu_set_t cpuset; + CPU_ZERO(&cpuset); + if (id + 1 < sysconf(_SC_NPROCESSORS_ONLN)) { + CPU_SET(id + 1, &cpuset); // Pin workers to cores 1, 2, 3... + pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); + } + + while (running) { + BroadcastTask task; + if (queue_pop(&task_queue, &task)) { + execute_broadcast(&task); + free(task.frame_data); // Free the frame after broadcasting + } else { + usleep(100); // Sleep briefly if queue is empty + } + } + return NULL; +} + +// Safely tells the main I/O thread to arm EPOLLOUT for a given FD +void arm_write(int fd) { + if (fd < 0 || fd >= MAX_CLIENTS) return; + Client* c = &clients[fd]; + // Use CAS to avoid redundant pipe writes and epoll_ctl calls + char expected = 0; // <--- FIX: Changed from bool to char + if (atomic_compare_exchange_strong(&c->write_registered, &expected, 1)) { + write(notify_pipe[1], &fd, sizeof(fd)); + } +} + +// --- Main Server --- +int main() { + signal(SIGINT, handle_sigint); + signal(SIGPIPE, SIG_IGN); + + clients = calloc(MAX_CLIENTS, sizeof(Client)); + queue_init(&task_queue); + + // Create server socket + int server_fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); + int opt = 1; + setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + struct sockaddr_in addr = { .sin_family = AF_INET, .sin_port = htons(PORT), .sin_addr.s_addr = INADDR_ANY }; + if (bind(server_fd, (struct sockaddr*)&addr, sizeof(addr)) < 0) { perror("bind"); return 1; } + if (listen(server_fd, LISTEN_BACKLOG) < 0) { perror("listen"); return 1; } + + epoll_fd = epoll_create1(0); + struct epoll_event ev = { .events = EPOLLIN | EPOLLET, .data.fd = server_fd }; + epoll_ctl(epoll_fd, EPOLL_CTL_ADD, server_fd, &ev); + + // Create pipe for thread communication + if (pipe2(notify_pipe, O_NONBLOCK) < 0) { perror("pipe2"); return 1; } + ev.events = EPOLLIN | EPOLLET; + ev.data.fd = notify_pipe[0]; + epoll_ctl(epoll_fd, EPOLL_CTL_ADD, notify_pipe[0], &ev); + + // Pin main I/O thread to core 0 + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(0, &cpuset); + sched_setaffinity(0, sizeof(cpuset), &cpuset); + + // Start worker threads + int worker_ids[WORKER_THREADS]; + for (int i = 0; i < WORKER_THREADS; i++) { + worker_ids[i] = i; + pthread_create(&worker_threads[i], NULL, worker_main, &worker_ids[i]); + } + + printf("Server started on port %d with %d worker threads.\n", PORT, WORKER_THREADS); + + struct epoll_event events[MAX_EVENTS]; + uint64_t last_stats_time = get_ns_time(); + + while (running) { + int n = epoll_wait(epoll_fd, events, MAX_EVENTS, 200); + for (int i = 0; i < n; i++) { + int fd = events[i].data.fd; + uint32_t e = events[i].events; + + if (fd == server_fd) { + handle_accept(server_fd); + } else if (fd == notify_pipe[0]) { + int client_fd; + while (read(notify_pipe[0], &client_fd, sizeof(client_fd)) > 0) { + struct epoll_event client_ev = { + .events = EPOLLIN | EPOLLOUT | EPOLLET | EPOLLRDHUP, + .data.fd = client_fd + }; + epoll_ctl(epoll_fd, EPOLL_CTL_MOD, client_fd, &client_ev); + } + } else { + if (e & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) { + remove_client(fd, 0); + continue; + } + if (e & EPOLLIN) handle_read(fd); + if (e & EPOLLOUT) handle_write(fd); + } + } + + uint64_t now = get_ns_time(); + if (now - last_stats_time > 5000000000ULL) { + printf("Active connections: %d\n", atomic_load(&active_connections)); + last_stats_time = now; + } + } + + printf("Shutting down...\n"); + for (int i = 0; i < WORKER_THREADS; i++) { + pthread_join(worker_threads[i], NULL); + } + + close(server_fd); + close(notify_pipe[0]); 
+ close(notify_pipe[1]); + close(epoll_fd); + free(clients); + // ... further cleanup for channel structures etc. would be ideal in a real app ... + + printf("Server shutdown complete.\n"); + return 0; +} diff --git a/test.py b/test.py new file mode 100644 index 0000000..7408f2c --- /dev/null +++ b/test.py @@ -0,0 +1,148 @@ +import asyncio +import time +import random +import statistics +from collections import deque +import websockets + +# --- Test Configuration --- +HOST = "127.0.0.1" +PORT = 8080 +URI = f"ws://{HOST}:{PORT}" + +# Client setup +NUM_SUBSCRIBERS = 1000 +NUM_PUBLISHERS = 10 +CHANNELS = ["news", "sports", "tech", "finance", "weather"] + +# Test execution +TEST_DURATION_S = 15 +MESSAGES_PER_SECOND_PER_PUBLISHER = 100 # Increased message rate + +# --- Global State & Metrics --- +latencies = deque() +messages_sent = 0 +messages_received = 0 +subscriber_setup_count = 0 +all_subscribed_event = asyncio.Event() + +async def subscriber_client(client_id: int): + global subscriber_setup_count, messages_received + channel = random.choice(CHANNELS) + + try: + async with websockets.connect(URI) as websocket: + await websocket.send(f"sub {channel}") + subscriber_setup_count += 1 + if subscriber_setup_count == NUM_SUBSCRIBERS: + print("✅ All subscribers are connected and subscribed. Starting publishers...") + all_subscribed_event.set() + + while True: + message = await websocket.recv() + try: + sent_time_str = message.split(":", 1)[0] + sent_time = float(sent_time_str) + latency = time.time() - sent_time + latencies.append(latency) + messages_received += 1 + except (ValueError, IndexError): + print(f"Warning: Received malformed message: {message}") + + except (websockets.exceptions.ConnectionClosedError, ConnectionRefusedError) as e: + print(f"Subscriber {client_id} disconnected: {e}") + except asyncio.CancelledError: + pass + except Exception as e: + print(f"An unexpected error occurred in subscriber {client_id}: {e}") + +async def publisher_client(client_id: int): + global messages_sent + await all_subscribed_event.wait() + + sleep_interval = 1.0 / MESSAGES_PER_SECOND_PER_PUBLISHER + + try: + async with websockets.connect(URI) as websocket: + while True: + channel = random.choice(CHANNELS) + send_time = time.time() + message = f"{send_time:.6f}:Hello from publisher {client_id} on channel {channel}" + + await websocket.send(f"pub {channel} {message}") + messages_sent += 1 + + await asyncio.sleep(sleep_interval) + + except (websockets.exceptions.ConnectionClosedError, ConnectionRefusedError) as e: + print(f"Publisher {client_id} disconnected: {e}") + except asyncio.CancelledError: + pass + except Exception as e: + print(f"An unexpected error occurred in publisher {client_id}: {e}") + +def print_report(): + print("\n" + "="*80) + print("PERFORMANCE REPORT".center(80)) + print("="*80) + + if not latencies: + print("No messages were received. Cannot generate a report. 
Is the server running?") + return + + total_sent = messages_sent + total_received = messages_received + message_loss = max(0, total_sent - total_received) + loss_rate = (message_loss / total_sent * 100) if total_sent > 0 else 0 + throughput = total_received / TEST_DURATION_S + + print(f"Test Duration: {TEST_DURATION_S} seconds") + print(f"Total Messages Sent: {total_sent}") + print(f"Total Messages Rcvd: {total_received}") + print(f"Message Loss: {message_loss} ({loss_rate:.2f}%)") + print(f"Actual Throughput: {throughput:.2f} msg/sec") + print("-"*80) + + sorted_latencies = sorted(latencies) + avg_latency_ms = statistics.mean(sorted_latencies) * 1000 + min_latency_ms = sorted_latencies[0] * 1000 + max_latency_ms = sorted_latencies[-1] * 1000 + p50_latency_ms = statistics.median(sorted_latencies) * 1000 + p95_latency_ms = sorted_latencies[int(len(sorted_latencies) * 0.95)] * 1000 + p99_latency_ms = sorted_latencies[int(len(sorted_latencies) * 0.99)] * 1000 + + print("Latency Statistics (ms):") + print(f" Average: {avg_latency_ms:.4f} ms") + print(f" Min: {min_latency_ms:.4f} ms") + print(f" Max: {max_latency_ms:.4f} ms") + print(f" Median (p50): {p50_latency_ms:.4f} ms") + print(f" 95th Percentile: {p95_latency_ms:.4f} ms") + print(f" 99th Percentile: {p99_latency_ms:.4f} ms") + print("="*80) + +async def main(): + print("Starting WebSocket Pub/Sub Load Test...") + print(f"Simulating {NUM_SUBSCRIBERS} subscribers and {NUM_PUBLISHERS} publishers.") + print(f"Publishing at ~{NUM_PUBLISHERS * MESSAGES_PER_SECOND_PER_PUBLISHER} msg/sec for {TEST_DURATION_S} seconds.") + print("-"*80) + + subscriber_tasks = [asyncio.create_task(subscriber_client(i)) for i in range(NUM_SUBSCRIBERS)] + publisher_tasks = [asyncio.create_task(publisher_client(i)) for i in range(NUM_PUBLISHERS)] + all_tasks = subscriber_tasks + publisher_tasks + + try: + await asyncio.sleep(TEST_DURATION_S) + finally: + print("\nTest duration finished. Shutting down clients...") + for task in all_tasks: + task.cancel() + + await asyncio.gather(*all_tasks, return_exceptions=True) + print_report() + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nTest interrupted by user.") +
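
For reference, a minimal interactive client for the same "sub"/"pub" wire protocol can be sketched as follows. This is not part of the commit above; it assumes the third-party `websockets` package that test.py already uses, the default ws://127.0.0.1:8080 endpoint from the server configuration, and a running rpubsub instance. The file name quick_client.py is only illustrative.

# quick_client.py -- hypothetical sketch, not included in the commit above.
# Protocol: a client sends the text frame "sub <channel>" to subscribe and
# "pub <channel> <message>"; the server fans <message> out to every
# subscriber of <channel> as a plain text frame.
import asyncio
import websockets

URI = "ws://127.0.0.1:8080"  # HOST/PORT defaults from rpubsub.c and load_test.c

async def main():
    async with websockets.connect(URI) as sub, websockets.connect(URI) as pub:
        await sub.send("sub news")    # first connection subscribes to "news"
        await asyncio.sleep(0.1)      # give the server time to register the subscription
        await pub.send("pub news hello from quick_client")
        print(await sub.recv())       # prints: hello from quick_client

asyncio.run(main())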