// Written by retoor@molodetz.nl

// This program load-tests an IRC server by simulating multiple IRC clients that connect, join a channel, and send a series of messages. It reports the total number of messages, the bytes sent and received, and the corresponding per-second rates.

// Uses the following POSIX (non-ISO C) interfaces: pthread.h, unistd.h, sys/socket.h, netinet/in.h, arpa/inet.h

// MIT License

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <pthread.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <time.h>

/* Target IRC server (IPv4 literal), TCP port, and channel every client joins. */
#define SERVER  "127.0.0.1"
#define PORT    6667
#define CHANNEL "#loadtest"

/*
 * Load shape: NUM_CLIENTS concurrent clients, each sending MESSAGES_PER_CLIENT
 * PRIVMSGs. MESSAGE_INTERVAL is the pause between them, in microseconds
 * (it is passed to usleep()).
 */
#define NUM_CLIENTS         300
#define MESSAGES_PER_CLIENT 100
#define MESSAGE_INTERVAL    (100 * 1000)

/*
 * Shared statistics. bytes_mutex guards the two byte totals, which are
 * updated from every client thread. total_messages is the planned PRIVMSG
 * count (clients x messages each), fixed at startup.
 */
pthread_mutex_t bytes_mutex = PTHREAD_MUTEX_INITIALIZER;
int total_messages = NUM_CLIENTS * MESSAGES_PER_CLIENT;
long long total_bytes_sent = 0LL;
long long total_bytes_received = 0LL;

/*
 * Fill `nick` (a buffer of `len` bytes) with a NUL-terminated nickname:
 * the prefix "LT" followed by random characters from [a-zA-Z0-9], for a
 * total string length of len - 1.
 *
 * Buffers too small for the full prefix get a truncated prefix ("L" for
 * len == 2, "" for len == 1). len == 0 or a NULL buffer writes nothing.
 * The previous version overflowed in all of those cases.
 *
 * Called concurrently from every client thread. C11 does not require rand()
 * to be free of data races, so the rand() calls are serialized with a
 * function-local mutex. The srand() seed set by the caller is still honoured.
 */
void random_nick(char *nick, size_t len) {
    static const char charset[] = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    static const char prefix[] = "LT";
    static pthread_mutex_t rand_mutex = PTHREAD_MUTEX_INITIALIZER;

    if (nick == NULL || len == 0) {
        return;
    }

    const size_t last = len - 1; /* index reserved for the terminator */
    size_t i = 0;

    /* Copy as much of the prefix as fits before the terminator. */
    for (; i < sizeof(prefix) - 1 && i < last; ++i) {
        nick[i] = prefix[i];
    }

    pthread_mutex_lock(&rand_mutex);
    for (; i < last; ++i) {
        int key = rand() % (int)(sizeof(charset) - 1);
        nick[i] = charset[key];
    }
    pthread_mutex_unlock(&rand_mutex);

    nick[last] = '\0';
}

/*
 * Add a send()/recv() result to one of the shared byte counters under
 * bytes_mutex. Non-positive results (error, EOF, would-block) are ignored.
 */
static void add_bytes(long long *counter, ssize_t n) {
    if (n <= 0) {
        return;
    }
    pthread_mutex_lock(&bytes_mutex);
    *counter += n;
    pthread_mutex_unlock(&bytes_mutex);
}

/*
 * Send the NUL-terminated string `line` in full. Partial sends are retried.
 * The bytes actually written are added to total_bytes_sent.
 * Returns 0 if every byte was sent, -1 otherwise.
 */
static int send_line(int sockfd, const char *line) {
    int flags = 0;
#ifdef MSG_NOSIGNAL
    /* Without this, writing to a socket the server has already closed raises
     * SIGPIPE, whose default action terminates the whole load-test process. */
    flags |= MSG_NOSIGNAL;
#endif
    const size_t len = strlen(line);
    size_t off = 0;
    while (off < len) {
        ssize_t n = send(sockfd, line + off, len - off, flags);
        if (n <= 0) {
            break;
        }
        off += (size_t)n;
    }
    add_bytes(&total_bytes_sent, (ssize_t)off);
    return off == len ? 0 : -1;
}

/*
 * Read everything currently queued on the socket without blocking, and add
 * it to total_bytes_received. The data itself is discarded.
 */
static void drain_incoming(int sockfd) {
    char rx[4096];
    ssize_t n;
    while ((n = recv(sockfd, rx, sizeof(rx), MSG_DONTWAIT)) > 0) {
        add_bytes(&total_bytes_received, n);
    }
}

/*
 * Thread entry point for one simulated IRC client.
 *
 * Sequence:
 *   1. Register with a random nick (NICK, USER).
 *   2. JOIN CHANNEL.
 *   3. Send MESSAGES_PER_CLIENT PRIVMSGs, MESSAGE_INTERVAL microseconds apart,
 *      draining incoming data after each one.
 *   4. Send QUIT and close the socket.
 *
 * All traffic is counted into the shared byte totals. If any send fails
 * (e.g. the server dropped the connection), the client logs the failure and
 * stops early instead of continuing on a dead socket.
 *
 * `arg` is the int * client index passed by main. It is currently unused.
 *
 * NOTE(review): incoming server lines, including PING, are discarded and
 * never answered with PONG. Servers that require a PONG (e.g. during
 * registration) may drop these clients; confirm against the target server.
 */
void *irc_client(void *arg) {
    (void)arg;

    char nick[11];
    random_nick(nick, sizeof(nick));
    char user[64];
    snprintf(user, sizeof(user), "%s 0 * :Load Tester", nick);

    /* One buffer for every outgoing line; the PRIVMSG line is ~290 bytes. */
    char line[1024];
    /* PRIVMSG payload: 249 'a' characters. Built once, reused per message. */
    char msg[250];
    memset(msg, 'a', sizeof(msg) - 1);
    msg[sizeof(msg) - 1] = '\0';

    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd < 0) {
        perror("socket");
        return NULL;
    }

    struct sockaddr_in serv_addr;
    memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_port = htons(PORT);
    if (inet_pton(AF_INET, SERVER, &serv_addr.sin_addr) != 1) {
        printf("[%s] Error: invalid server address\n", nick);
        close(sockfd);
        return NULL;
    }

    if (connect(sockfd, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) {
        printf("[%s] Error: connect failed\n", nick);
        close(sockfd);
        return NULL;
    }

    snprintf(line, sizeof(line), "NICK %s\r\n", nick);
    if (send_line(sockfd, line) < 0) {
        goto send_failed;
    }
    snprintf(line, sizeof(line), "USER %s\r\n", user);
    if (send_line(sockfd, line) < 0) {
        goto send_failed;
    }
    usleep(1000000); /* presumably lets the server finish registration */

    snprintf(line, sizeof(line), "JOIN %s\r\n", CHANNEL);
    if (send_line(sockfd, line) < 0) {
        goto send_failed;
    }
    usleep(1000000);

    for (int i = 0; i < MESSAGES_PER_CLIENT; ++i) {
        snprintf(line, sizeof(line), "PRIVMSG %s :%s from %s #%d\r\n", CHANNEL, msg, nick, i);
        if (send_line(sockfd, line) < 0) {
            goto send_failed;
        }
        drain_incoming(sockfd);
        usleep(MESSAGE_INTERVAL);
    }

    (void)send_line(sockfd, "QUIT\r\n"); /* best effort: we close either way */
    close(sockfd);
    return NULL;

send_failed:
    printf("[%s] Error: send failed\n", nick);
    close(sockfd);
    return NULL;
}

/*
 * Current wall-clock time in seconds. Uses C11 timespec_get for sub-second
 * resolution and falls back to whole seconds from time() if it fails.
 */
static double now_seconds(void) {
    struct timespec ts;
    if (timespec_get(&ts, TIME_UTC) == TIME_UTC) {
        return (double)ts.tv_sec + (double)ts.tv_nsec / 1e9;
    }
    return (double)time(NULL);
}

/*
 * Start NUL-safe load test: launch NUM_CLIENTS irc_client threads staggered
 * 50 ms apart, wait for all of them, then print totals and rates.
 *
 * Elapsed time is wall-clock time. The previous clock()-based measurement
 * reported processor time. The client threads spend nearly all their time
 * blocked in usleep()/send()/recv(), so that value was a tiny fraction of
 * the real duration, and every rate came out wildly inflated (or divided
 * by ~0).
 *
 * "Total messages" is the planned count (total_messages =
 * NUM_CLIENTS * MESSAGES_PER_CLIENT). It does not reflect clients that
 * failed to connect or stopped early.
 */
int main(void) {
    srand((unsigned)time(NULL));
    const double start_time = now_seconds();

    pthread_t threads[NUM_CLIENTS];
    int ids[NUM_CLIENTS];
    int started[NUM_CLIENTS]; /* nonzero iff threads[i] holds a live thread */

    for (int i = 0; i < NUM_CLIENTS; ++i) {
        ids[i] = i;
        int rc = pthread_create(&threads[i], NULL, irc_client, &ids[i]);
        started[i] = (rc == 0);
        if (rc != 0) {
            /* threads[i] is unspecified on failure and must not be joined. */
            fprintf(stderr, "pthread_create failed for client %d: %s\n", i, strerror(rc));
        }
        usleep(50000); /* stagger client start-up by 50 ms */
    }
    for (int i = 0; i < NUM_CLIENTS; ++i) {
        if (started[i]) {
            pthread_join(threads[i], NULL);
        }
    }

    const double total_time = now_seconds() - start_time;
    /* The staggered start alone takes NUM_CLIENTS * 50 ms. The guard only
     * matters if the wall clock is stepped backwards during the run. */
    const double per_second = total_time > 0.0 ? 1.0 / total_time : 0.0;
    const double mb_sent = total_bytes_sent / (1024.0 * 1024.0);
    const double mb_received = total_bytes_received / (1024.0 * 1024.0);

    printf("Total time: %.2f seconds\n", total_time);
    printf("Total messages: %d\n", total_messages);
    printf("Messages per second: %.2f\n", total_messages * per_second);
    printf("Total bytes sent: %lld (%.2f MB)\n", total_bytes_sent, mb_sent);
    printf("Total bytes received: %lld (%.2f MB)\n", total_bytes_received, mb_received);
    printf("Send rate: %.2f MB/s\n", mb_sent * per_second);
    printf("Receive rate: %.2f MB/s\n", mb_received * per_second);
    printf("Load test complete.\n");
    return 0;
}