#ifndef PGS_SOCK_H #define PGS_SOCK_H #include "protocol.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "py.h" #define MAX_EVENTS 8096 #define BUFFER_SIZE 1024 typedef struct { PROTOCOL_STATUS status; PROTOCOL_NAME protocol_name; int client_fd; int upstream_fd; char *buffer; size_t buffer_size; size_t buffer_offset; char *http_intercepted_headers; } connection_t; int listen_fd = 0; int epoll_fd = 0; connection_t connections[MAX_EVENTS][sizeof(connection_t)] = {0}; int sock_init(void); void sock_exit(void); void set_nonblocking(int fd) { int flags = fcntl(fd, F_GETFL, 0); if (flags == -1) { perror("fcntl get"); exit(EXIT_FAILURE); } if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) { perror("fcntl set"); exit(EXIT_FAILURE); } } int prepare_upstream() { int sockfd = socket(AF_INET, SOCK_STREAM, 0); return sockfd; } int connect_upstream(const char *host, int port) { int sockfd = socket(AF_INET, SOCK_STREAM, 0); if (sockfd == -1) { perror("socket"); return -1; } set_nonblocking(sockfd); struct sockaddr_in server_addr; memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; server_addr.sin_port = htons(port); if (inet_pton(AF_INET, host, &server_addr.sin_addr) <= 0) { perror("inet_pton"); close(sockfd); return -1; } if (connect(sockfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) { if (errno != EINPROGRESS) { perror("connect"); close(sockfd); return -1; } } return sockfd; } int create_listening_socket(int port) { int listen_fd = socket(AF_INET, SOCK_STREAM, 0); if (listen_fd == -1) { perror("socket"); return -1; } int opt = 1; if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) { perror("setsockopt"); close(listen_fd); return -1; } struct sockaddr_in server_addr; memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; server_addr.sin_addr.s_addr = INADDR_ANY; server_addr.sin_port = htons(port); if (bind(listen_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) { perror("bind"); close(listen_fd); return -1; } if (listen(listen_fd, SOMAXCONN) == -1) { perror("listen"); close(listen_fd); return -1; } set_nonblocking(listen_fd); return listen_fd; } char *sock_read(int fd, char *buf, size_t size) { connection_t *conn = connections[fd]; size_t left_in_buffer = conn->buffer_size - conn->buffer_offset; size_t bytes_to_read = size > left_in_buffer ? left_in_buffer : size; ssize_t bytes_read = 0; char *buffer; buffer[size] = 0; if (bytes_to_read) { bytes_read = recv(fd, buffer, size, 0); buffer[bytes_read] = 0; } memcpy(buf, conn->buffer + conn->buffer_offset, bytes_to_read); if (bytes_read > 0) { return buf; } else if (bytes_read == 0) { printf("Connection closed by remote (fd=%d)\n", fd); } else { perror("read"); } return NULL; } void close_connection(int epoll_fd, connection_t *conn) { if (conn->client_fd != -1) { epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->client_fd, NULL); close(conn->client_fd); } if (conn->upstream_fd != -1) { epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->upstream_fd, NULL); close(conn->upstream_fd); } } int forward_data(int from_fd, int to_fd) { static char buffer[BUFFER_SIZE]; // Feels great to do somehow. Better safe than sorry. memset(buffer, 0, BUFFER_SIZE); ssize_t bytes_read = recv(from_fd, buffer, sizeof(buffer), 0); if (bytes_read > 0) { ssize_t bytes_written = send(to_fd, buffer, bytes_read, 0); if (bytes_written == -1) { perror("write"); } } else if (bytes_read == 0) { printf("Connection closed by remote (fd=%d)\n", from_fd); } else { perror("read"); } return (int)bytes_read; } ssize_t sock_send_all(int fd, char * content, ssize_t length) { ssize_t bytes_total_sent = 0; while(bytes_total_sent < length) { ssize_t bytes_sent = send(fd, content + bytes_total_sent, length, 0); if(bytes_sent <= 0){ return -1; } bytes_total_sent += bytes_sent; } return bytes_total_sent; } bool handle_connect(struct epoll_event event, int epoll_fd) { struct sockaddr_in client_addr; socklen_t client_len = sizeof(client_addr); int client_fd = accept(listen_fd, (struct sockaddr *)&client_addr, &client_len); if (client_fd == -1) { perror("accept"); return false; } set_nonblocking(client_fd); struct epoll_event client_event; client_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP; client_event.data.ptr = connections[client_fd]; client_event.data.fd = client_fd; connections[client_fd]->upstream_fd = -1; connections[client_fd]->client_fd = client_fd; connections[client_fd]->status = PS_SNIFF; connections[client_fd]->protocol_name = PS_NONE; printf("New connection: client_fd=%d\n", client_fd); epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &client_event); return true; } void handle_close(int epoll_fd, connection_t *conn) { printf("Connection closed: client_fd=%d, upstream_fd=%d\n", conn->client_fd, conn->upstream_fd); close_connection(epoll_fd, conn); } char * sock_read_until(int fd, char * until, int until_len) { char * result = (char *)malloc(BUFFER_SIZE); result[0] = 0; int bytes; int bytes_total = 0; char buffer[2]; while((bytes = recv(fd, buffer, 1,0)) > 0){ result[bytes_total] = buffer[0]; bytes_total += bytes; if(bytes_total >= until_len){ if(!strncmp(buffer + bytes_total - until_len, until, until_len)){ break; } } } result[bytes_total] = 0; return result; } void http_intercept_headers(int epoll_fd, connection_t *conn) { char * headers = sock_read_until(conn->client_fd, "\r\n\r\n", 4); char * intercepted_headers = py_http_intercept_headers(conn->client_fd,headers); if(!intercepted_headers){ close_connection(epoll_fd, conn); return; } sock_send_all(conn->client_fd, intercepted_headers, strlen(intercepted_headers)); printf("Intercepted headers: %s\n", intercepted_headers); conn->http_intercepted_headers = intercepted_headers; conn->status = PS_NONE; connections[conn->upstream_fd]->status = PS_HTTP_READ_HEADER; } void http_intercept_body(int epoll_fd, connection_t *conn) { printf("HIERR\n"); } void handle_stream(struct epoll_event event, int epoll_fd, connection_t *conn) { if (conn->upstream_fd == -1) { conn->upstream_fd = py_on_connect(conn->client_fd); if(conn->upstream_fd == -1){ close_connection(epoll_fd, conn); return; } conn->protocol_name = protocol_sniff(conn->client_fd); if(conn->protocol_name == PN_ERROR){ close_connection(epoll_fd, conn); return; }else if(conn->protocol_name == PN_HTTP){ conn->status = PS_HTTP_READ_HEADER; }else if(conn->protocol_name == PN_HTTP_CHUNKED){ conn->status = PS_HTTP_READ_HEADER; }else if(conn->protocol_name == PN_HTTP_KEEP_ALIVE){ conn->status = PS_HTTP_READ_HEADER; }else if(conn->protocol_name == PN_HTTP_REQUEST){ conn->status = PS_HTTP_READ_HEADER; }else if(conn->protocol_name == PN_HTTP_WEBSOCKET){ conn->status = PS_HTTP_READ_HEADER; }else if(conn->protocol_name == PN_SSH){ conn->status = PS_STREAM; }else if(conn->protocol_name == PN_RAW){ conn->status = PS_STREAM; }else { conn->status = PS_ERROR; } if (conn->upstream_fd == -1 || conn->status == PS_ERROR) { close_connection(epoll_fd, conn); return; } set_nonblocking(conn->upstream_fd); struct epoll_event upstream_event; upstream_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP; upstream_event.data.ptr = connections[conn->upstream_fd]; upstream_event.data.fd = conn->upstream_fd; connections[conn->client_fd]->upstream_fd = conn->upstream_fd; connections[conn->upstream_fd]->client_fd = conn->client_fd; connections[conn->upstream_fd]->upstream_fd = conn->upstream_fd; epoll_ctl(epoll_fd, EPOLL_CTL_ADD, conn->upstream_fd, &upstream_event); printf("Connected: client_fd=%d, upstream_fd=%d\n", conn->client_fd, conn->upstream_fd); return; } if (event.data.fd == conn->client_fd) { if(conn->status = PS_HTTP_READ_HEADER){ http_intercept_headers(epoll_fd, conn); } if(conn->status == PS_HTTP_READ_BODY){ http_intercept_body(epoll_fd, conn); } if (forward_data(conn->client_fd, conn->upstream_fd) < 1) { close_connection(epoll_fd, conn); } } else if (event.data.fd == conn->upstream_fd) { if (forward_data(conn->upstream_fd, conn->client_fd) < 1) { close_connection(epoll_fd, conn); } } } void serve(int port) { listen_fd = create_listening_socket(port); if (listen_fd == -1) { fprintf(stderr, "Failed to create listening socket\n"); return; } epoll_fd = epoll_create1(0); if (epoll_fd == -1) { perror("epoll_create1"); close(listen_fd); return; } struct epoll_event event; event.events = EPOLLIN; event.data.fd = listen_fd; if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, listen_fd, &event) == -1) { perror("epoll_ctl"); close(listen_fd); close(epoll_fd); return; } struct epoll_event events[MAX_EVENTS]; memset(events, 0, sizeof(events)); printf("Pretty Good Server listening on port %d\n", port); while (1) { int num_events = epoll_wait(epoll_fd, events, MAX_EVENTS, -1); if (num_events == -1) { perror("epoll_wait"); break; } for (int i = 0; i < num_events; i++) { if (events[i].data.fd == listen_fd) { handle_connect(events[i], epoll_fd); } else { connection_t *conn = connections[events[i].data.fd]; if (events[i].events & (EPOLLHUP | EPOLLERR)) { handle_close(epoll_fd, conn); } else if (events[i].events & EPOLLIN) { handle_stream(events[i], epoll_fd, conn); } } } } } #endif