#ifndef PGS_SOCK_H
#define PGS_SOCK_H
#include "protocol.h"
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <pthread.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include "py.h"
/* Number of rows in the connections table and the size of the
 * epoll_wait() event batch used by serve(). */
#define MAX_EVENTS 8096
/* Size in bytes of fixed I/O buffers in this file, e.g. the relay
 * buffer in forward_data(). */
#define BUFFER_SIZE 1024
/*
 * Per-descriptor proxy state. Records live in the global connections
 * table and are looked up by file descriptor number.
 */
typedef struct {
/* Processing stage; handle_connect() starts it at PS_SNIFF. */
PROTOCOL_STATUS status;
/* Protocol detected by protocol_sniff(); stored by handle_stream(). */
PROTOCOL_NAME protocol_name;
/* Accepted client socket. */
int client_fd;
/* Upstream socket; -1 until handle_stream() gets one from py_on_connect(). */
int upstream_fd;
/* Buffer consulted by sock_read(). NOTE(review): nothing visible in this
 * header allocates it - confirm where it is set up. */
char *buffer;
/* sock_read() treats buffer_size - buffer_offset as the usable byte count. */
size_t buffer_size;
/* Offset into buffer used by sock_read(). */
size_t buffer_offset;
/* Value most recently returned by py_http_intercept_headers(). */
char *http_intercepted_headers;
} connection_t;
/* Listening socket; assigned by serve(). */
int listen_fd = 0;
/* epoll instance; assigned by serve(). */
int epoll_fd = 0;
/*
 * Per-descriptor state, indexed by file descriptor number.
 *
 * Every row holds exactly one record. The second dimension exists only so
 * that connections[fd] decays to a connection_t * and the
 * connections[fd]->field idiom used throughout this file keeps working.
 * (The previous row length of sizeof(connection_t) allocated that many
 * records per descriptor although only the first one was ever touched.)
 */
connection_t connections[MAX_EVENTS][1] = {0};
int sock_init(void);
void sock_exit(void);
/*
 * Switch fd to non-blocking mode by adding O_NONBLOCK to its file status
 * flags. Any fcntl() failure is reported with perror() and terminates the
 * process with EXIT_FAILURE.
 */
void set_nonblocking(int fd) {
  const int current = fcntl(fd, F_GETFL, 0);
  if (current == -1) {
    perror("fcntl get");
    exit(EXIT_FAILURE);
  }

  const int wanted = current | O_NONBLOCK;
  if (fcntl(fd, F_SETFL, wanted) != -1) {
    return;
  }

  perror("fcntl set");
  exit(EXIT_FAILURE);
}
/*
 * Create an unconnected IPv4 TCP socket.
 * Returns the new descriptor, or -1 when socket() fails.
 */
int prepare_upstream() {
  return socket(AF_INET, SOCK_STREAM, 0);
}
/*
 * Start a non-blocking TCP connection to an IPv4 upstream.
 *
 * host: dotted-quad IPv4 address text (parsed with inet_pton(), so host
 *       names are not resolved).
 * port: TCP port, 1..65535.
 *
 * Returns the socket descriptor, or -1 on failure. Because the socket is
 * non-blocking, the connection may still be in progress (EINPROGRESS) when
 * this function returns successfully.
 */
int connect_upstream(const char *host, int port) {
  if (host == NULL || port <= 0 || port > 65535) {
    fprintf(stderr, "connect_upstream: invalid host or port\n");
    return -1;
  }

  struct sockaddr_in server_addr;
  memset(&server_addr, 0, sizeof(server_addr));
  server_addr.sin_family = AF_INET;
  server_addr.sin_port = htons((in_port_t)port);

  /* Parse the address before creating the socket so there is nothing to
   * clean up on bad input. inet_pton() returns 0 for malformed text WITHOUT
   * setting errno, so perror() is only meaningful for the -1 case. */
  int parsed = inet_pton(AF_INET, host, &server_addr.sin_addr);
  if (parsed == 0) {
    fprintf(stderr, "inet_pton: not a valid IPv4 address: %s\n", host);
    return -1;
  }
  if (parsed < 0) {
    perror("inet_pton");
    return -1;
  }

  int sockfd = socket(AF_INET, SOCK_STREAM, 0);
  if (sockfd == -1) {
    perror("socket");
    return -1;
  }
  set_nonblocking(sockfd);

  if (connect(sockfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) ==
          -1 &&
      errno != EINPROGRESS) {
    perror("connect");
    close(sockfd);
    return -1;
  }
  return sockfd;
}
/*
 * Create a non-blocking IPv4 TCP socket listening on every local interface
 * at the given port, with SO_REUSEADDR enabled and a SOMAXCONN backlog.
 *
 * Returns the listening descriptor, or -1 after reporting the failing call
 * with perror().
 */
int create_listening_socket(int port) {
  const int sock = socket(AF_INET, SOCK_STREAM, 0);
  if (sock == -1) {
    perror("socket");
    return -1;
  }

  int reuse = 1;
  struct sockaddr_in bind_addr;
  memset(&bind_addr, 0, sizeof(bind_addr));
  bind_addr.sin_family = AF_INET;
  bind_addr.sin_addr.s_addr = INADDR_ANY;
  bind_addr.sin_port = htons(port);

  /* Run the setup steps in order and remember the first one that fails. */
  const char *failed_step = NULL;
  if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) ==
      -1) {
    failed_step = "setsockopt";
  } else if (bind(sock, (struct sockaddr *)&bind_addr, sizeof(bind_addr)) ==
             -1) {
    failed_step = "bind";
  } else if (listen(sock, SOMAXCONN) == -1) {
    failed_step = "listen";
  }

  if (failed_step != NULL) {
    perror(failed_step);
    close(sock);
    return -1;
  }

  set_nonblocking(sock);
  return sock;
}
/*
 * Receive up to `size` bytes from fd and copy them into buf.
 *
 * The read is staged through the connection's own buffer and is limited to
 * the space between buffer_offset and buffer_size. The bytes copied into
 * buf are NOT NUL-terminated, and buffer_offset is not advanced.
 *
 * Returns buf when at least one byte was received; NULL when fd is outside
 * the connection table, buf is NULL, the connection has no usable buffer
 * space, the peer closed the connection, or recv() failed.
 *
 * NOTE(review): the previous code received into an uninitialized pointer
 * (undefined behaviour) and asked recv() for `size` bytes instead of the
 * bounded count. This version assumes conn->buffer is a staging area of
 * buffer_size bytes - confirm against whoever allocates it.
 */
char *sock_read(int fd, char *buf, size_t size) {
  if (buf == NULL || fd < 0 ||
      (size_t)fd >= sizeof(connections) / sizeof(connections[0])) {
    return NULL;
  }
  connection_t *conn = connections[fd];
  if (conn->buffer == NULL || conn->buffer_offset >= conn->buffer_size) {
    return NULL;
  }

  size_t left_in_buffer = conn->buffer_size - conn->buffer_offset;
  size_t bytes_to_read = size > left_in_buffer ? left_in_buffer : size;
  if (bytes_to_read == 0) {
    /* Nothing requested: do not misreport this as a closed connection. */
    return NULL;
  }

  char *staging = conn->buffer + conn->buffer_offset;
  ssize_t bytes_read = recv(fd, staging, bytes_to_read, 0);
  if (bytes_read > 0) {
    /* Copy only what actually arrived. */
    memcpy(buf, staging, (size_t)bytes_read);
    return buf;
  }
  if (bytes_read == 0) {
    printf("Connection closed by remote (fd=%d)\n", fd);
  } else {
    perror("read");
  }
  return NULL;
}
/*
 * Tear down both sides of a proxied connection: each descriptor that is not
 * -1 is removed from the epoll set and closed.
 *
 * The descriptors stored in *conn are reset to -1 afterwards, so calling
 * this again on the same record is a harmless no-op instead of closing a
 * descriptor number the kernel may already have handed to a new connection.
 *
 * NOTE(review): handle_stream() also stores the same descriptor pair in the
 * upstream descriptor's table entry; that entry is not touched here.
 */
void close_connection(int epoll_fd, connection_t *conn) {
  if (conn == NULL) {
    return;
  }
  if (conn->client_fd != -1) {
    epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->client_fd, NULL);
    close(conn->client_fd);
    conn->client_fd = -1;
  }
  if (conn->upstream_fd != -1) {
    epoll_ctl(epoll_fd, EPOLL_CTL_DEL, conn->upstream_fd, NULL);
    close(conn->upstream_fd);
    conn->upstream_fd = -1;
  }
}
/*
 * Relay one chunk (at most BUFFER_SIZE bytes) from from_fd to to_fd.
 *
 * Returns the number of bytes relayed (> 0), 0 when the sender closed the
 * connection, or -1 when receiving or sending failed. A short write is
 * continued until the whole chunk has been sent; previously the unsent
 * remainder was silently dropped and a send failure still looked like
 * success to the caller.
 *
 * NOTE(review): on a non-blocking to_fd a full socket buffer (EAGAIN) is
 * reported as a failure. Proper back-pressure needs a per-connection output
 * queue driven by EPOLLOUT.
 */
int forward_data(int from_fd, int to_fd) {
  char buffer[BUFFER_SIZE];

  ssize_t bytes_read = recv(from_fd, buffer, sizeof(buffer), 0);
  if (bytes_read == 0) {
    printf("Connection closed by remote (fd=%d)\n", from_fd);
    return 0;
  }
  if (bytes_read < 0) {
    perror("read");
    return -1;
  }

  ssize_t bytes_written = 0;
  while (bytes_written < bytes_read) {
    ssize_t sent = send(to_fd, buffer + bytes_written,
                        (size_t)(bytes_read - bytes_written), 0);
    if (sent < 0 && errno == EINTR) {
      continue; /* interrupted before anything was sent: retry */
    }
    if (sent <= 0) {
      perror("write");
      return -1;
    }
    bytes_written += sent;
  }
  return (int)bytes_read;
}
/*
 * Send exactly `length` bytes of `content` on fd, looping over short writes.
 *
 * Returns the number of bytes sent (equal to length) on success, or -1 as
 * soon as send() fails or makes no progress. A length of zero or less sends
 * nothing and returns 0.
 */
ssize_t sock_send_all(int fd, char * content, ssize_t length) {
  ssize_t bytes_total_sent = 0;
  while (bytes_total_sent < length) {
    /* Only the unsent remainder may be offered to send(); passing the full
     * length again after a partial write reads past the end of content. */
    ssize_t bytes_sent = send(fd, content + bytes_total_sent,
                              (size_t)(length - bytes_total_sent), 0);
    if (bytes_sent < 0 && errno == EINTR) {
      continue; /* interrupted before anything was sent: retry */
    }
    if (bytes_sent <= 0) {
      return -1;
    }
    bytes_total_sent += bytes_sent;
  }
  return bytes_total_sent;
}
bool handle_connect(struct epoll_event event, int epoll_fd) {
struct sockaddr_in client_addr;
socklen_t client_len = sizeof(client_addr);
int client_fd =
accept(listen_fd, (struct sockaddr *)&client_addr, &client_len);
if (client_fd == -1) {
perror("accept");
return false;
}
set_nonblocking(client_fd);
struct epoll_event client_event;
client_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP;
client_event.data.ptr = connections[client_fd];
client_event.data.fd = client_fd;
connections[client_fd]->upstream_fd = -1;
connections[client_fd]->client_fd = client_fd;
connections[client_fd]->status = PS_SNIFF;
connections[client_fd]->protocol_name = PS_NONE;
printf("New connection: client_fd=%d\n", client_fd);
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &client_event);
return true;
}
/*
 * Log that a proxied connection is going away, then release both of its
 * descriptors through close_connection().
 */
void handle_close(int epoll_fd, connection_t *conn) {
  const int client = conn->client_fd;
  const int upstream = conn->upstream_fd;

  printf("Connection closed: client_fd=%d, upstream_fd=%d\n", client,
         upstream);
  close_connection(epoll_fd, conn);
}
/*
 * Read from fd one byte at a time until the received data ends with the
 * `until_len`-byte delimiter `until`.
 *
 * Returns a malloc()ed, NUL-terminated string that the caller must free().
 * It contains everything read, including the delimiter when it was found.
 * Reading also stops, returning what has been collected so far, when recv()
 * reports end-of-stream or an error (including "no data yet" on a
 * non-blocking socket) or when the size cap is reached. A NULL or empty
 * delimiter matches immediately and yields an empty string. Returns NULL
 * only when memory cannot be allocated.
 */
char * sock_read_until(int fd, char * until, int until_len) {
  enum { READ_UNTIL_INITIAL = 1024, READ_UNTIL_LIMIT = 64 * 1024 };

  size_t capacity = READ_UNTIL_INITIAL;
  size_t length = 0;
  char *result = malloc(capacity);
  if (result == NULL) {
    return NULL;
  }

  size_t delim_len = (until != NULL && until_len > 0) ? (size_t)until_len : 0;
  while (delim_len > 0) {
    /* Always keep one spare byte for the terminating NUL. */
    if (length + 1 >= capacity) {
      if (capacity >= READ_UNTIL_LIMIT) {
        break; /* the peer sent too much without a delimiter */
      }
      char *grown = realloc(result, capacity * 2);
      if (grown == NULL) {
        free(result);
        return NULL;
      }
      result = grown;
      capacity *= 2;
    }

    char byte;
    ssize_t received = recv(fd, &byte, 1, 0);
    if (received < 0 && errno == EINTR) {
      continue;
    }
    if (received <= 0) {
      break;
    }
    result[length++] = byte;

    /* Compare against the tail of the accumulated data, not the one-byte
     * scratch variable. */
    if (length >= delim_len &&
        memcmp(result + length - delim_len, until, delim_len) == 0) {
      break;
    }
  }
  result[length] = '\0';
  return result;
}
/*
 * Read the HTTP header block (up to and including the blank line) from the
 * client, pass it to py_http_intercept_headers() and send the returned text
 * with sock_send_all().
 *
 * On success the returned text is remembered in conn->http_intercepted_headers,
 * conn->status becomes PS_NONE and the upstream descriptor's table entry is
 * switched to PS_HTTP_READ_HEADER. If the headers cannot be read, the hook
 * returns NULL, or sending fails, the connection is closed.
 *
 * NOTE(review): the intercepted headers are sent to conn->client_fd, exactly
 * as before. If they are meant to reach the upstream server this should be
 * conn->upstream_fd - confirm the intended direction.
 */
void http_intercept_headers(int epoll_fd, connection_t *conn) {
  char * headers = sock_read_until(conn->client_fd, "\r\n\r\n", 4);
  if(!headers){
    close_connection(epoll_fd, conn);
    return;
  }

  char * intercepted_headers = py_http_intercept_headers(conn->client_fd,headers);
  /* The raw header text came from malloc() and was leaked before.
   * NOTE(review): assumes the hook does not keep the pointer it was given;
   * the guard covers a hook that hands the same pointer back. */
  if(intercepted_headers != headers){
    free(headers);
  }
  if(!intercepted_headers){
    close_connection(epoll_fd, conn);
    return;
  }

  conn->http_intercepted_headers = intercepted_headers;
  if(sock_send_all(conn->client_fd, intercepted_headers,
                   (ssize_t)strlen(intercepted_headers)) < 0){
    perror("write");
    close_connection(epoll_fd, conn);
    return;
  }
  printf("Intercepted headers: %s\n", intercepted_headers);
  conn->status = PS_NONE;

  /* Only index the table with a descriptor that fits in it. */
  if(conn->upstream_fd >= 0 &&
     (size_t)conn->upstream_fd < sizeof(connections) / sizeof(connections[0])){
    connections[conn->upstream_fd]->status = PS_HTTP_READ_HEADER;
  }
}
/*
 * Placeholder for HTTP body interception: currently it only writes a marker
 * line to standard output and ignores both arguments.
 */
void http_intercept_body(int epoll_fd, connection_t *conn) {
  (void)epoll_fd;
  (void)conn;
  fputs("HIERR\n", stdout);
}
/*
 * Handle a readable event on either side of a proxied connection.
 *
 * First event for a client (no upstream yet): obtain the upstream socket
 * from py_on_connect(), classify the traffic with protocol_sniff(), choose
 * the initial status, register the upstream socket with epoll and return.
 * No data is relayed in this step.
 *
 * Later events: client-side data is passed through header/body interception
 * when the status asks for it and is otherwise relayed to the upstream;
 * upstream-side data is relayed to the client. The connection is closed
 * when relaying fails or the peer has closed.
 */
void handle_stream(struct epoll_event event, int epoll_fd, connection_t *conn) {
  if (conn->upstream_fd == -1) {
    conn->upstream_fd = py_on_connect(conn->client_fd);
    if (conn->upstream_fd == -1) {
      close_connection(epoll_fd, conn);
      return;
    }
    /* The table is indexed by descriptor number; refuse what cannot fit. */
    if (conn->upstream_fd < 0 ||
        (size_t)conn->upstream_fd >=
            sizeof(connections) / sizeof(connections[0])) {
      fprintf(stderr, "Upstream fd %d does not fit the connection table\n",
              conn->upstream_fd);
      close_connection(epoll_fd, conn);
      return;
    }

    conn->protocol_name = protocol_sniff(conn->client_fd);
    if (conn->protocol_name == PN_ERROR) {
      close_connection(epoll_fd, conn);
      return;
    }
    if (conn->protocol_name == PN_HTTP ||
        conn->protocol_name == PN_HTTP_CHUNKED ||
        conn->protocol_name == PN_HTTP_KEEP_ALIVE ||
        conn->protocol_name == PN_HTTP_REQUEST ||
        conn->protocol_name == PN_HTTP_WEBSOCKET) {
      conn->status = PS_HTTP_READ_HEADER;
    } else if (conn->protocol_name == PN_SSH ||
               conn->protocol_name == PN_RAW) {
      conn->status = PS_STREAM;
    } else {
      /* Unknown protocol: nothing sensible to do with this connection. */
      conn->status = PS_ERROR;
      close_connection(epoll_fd, conn);
      return;
    }

    set_nonblocking(conn->upstream_fd);

    /* Mirror the descriptor pair into the upstream's table entry so events
     * on the upstream socket resolve to the same pair. */
    connections[conn->client_fd]->upstream_fd = conn->upstream_fd;
    connections[conn->upstream_fd]->client_fd = conn->client_fd;
    connections[conn->upstream_fd]->upstream_fd = conn->upstream_fd;

    struct epoll_event upstream_event;
    memset(&upstream_event, 0, sizeof(upstream_event));
    /* No EPOLLOUT: it is level-triggered and unhandled, so requesting it
     * would make epoll_wait() spin. epoll_data is a union, so only the
     * descriptor is stored. */
    upstream_event.events = EPOLLIN | EPOLLERR | EPOLLHUP;
    upstream_event.data.fd = conn->upstream_fd;
    if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, conn->upstream_fd,
                  &upstream_event) == -1) {
      perror("epoll_ctl");
      close_connection(epoll_fd, conn);
      return;
    }
    printf("Connected: client_fd=%d, upstream_fd=%d\n", conn->client_fd,
           conn->upstream_fd);
    return;
  }

  if (event.data.fd == conn->client_fd) {
    /* This was an assignment (=), which ran header interception on every
     * client event regardless of the connection's status. */
    if (conn->status == PS_HTTP_READ_HEADER) {
      http_intercept_headers(epoll_fd, conn);
      /* Interception may have closed the connection. Any bytes after the
       * header block are still in the socket and will raise another
       * readable event. */
      return;
    }
    if (conn->status == PS_HTTP_READ_BODY) {
      http_intercept_body(epoll_fd, conn);
    }
    if (forward_data(conn->client_fd, conn->upstream_fd) < 1) {
      close_connection(epoll_fd, conn);
    }
  } else if (event.data.fd == conn->upstream_fd) {
    if (forward_data(conn->upstream_fd, conn->client_fd) < 1) {
      close_connection(epoll_fd, conn);
    }
  }
}
/*
 * Run the proxy event loop on the given port.
 *
 * Creates the listening socket and the epoll instance (stored in the global
 * listen_fd / epoll_fd), then dispatches events until epoll_wait() fails:
 * listener events go to handle_connect(), hang-up or error events to
 * handle_close(), and readable events to handle_stream(). Setup failures
 * are reported and the function returns without serving.
 *
 * NOTE(review): both descriptors stay open when the loop ends; presumably
 * sock_exit() releases them - confirm.
 */
void serve(int port) {
  listen_fd = create_listening_socket(port);
  if (listen_fd == -1) {
    fprintf(stderr, "Failed to create listening socket\n");
    return;
  }

  epoll_fd = epoll_create1(0);
  if (epoll_fd == -1) {
    perror("epoll_create1");
    close(listen_fd);
    return;
  }

  struct epoll_event listener;
  listener.events = EPOLLIN;
  listener.data.fd = listen_fd;
  if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, listen_fd, &listener) == -1) {
    perror("epoll_ctl");
    close(listen_fd);
    close(epoll_fd);
    return;
  }

  struct epoll_event ready[MAX_EVENTS];
  memset(ready, 0, sizeof(ready));
  printf("Pretty Good Server listening on port %d\n", port);

  for (;;) {
    const int ready_count = epoll_wait(epoll_fd, ready, MAX_EVENTS, -1);
    if (ready_count == -1) {
      perror("epoll_wait");
      return;
    }

    for (int idx = 0; idx < ready_count; idx++) {
      struct epoll_event *ev = &ready[idx];

      /* New client waiting on the listener. */
      if (ev->data.fd == listen_fd) {
        handle_connect(*ev, epoll_fd);
        continue;
      }

      /* Hang-up and error take precedence over readable data. */
      connection_t *conn = connections[ev->data.fd];
      if (ev->events & (EPOLLHUP | EPOLLERR)) {
        handle_close(epoll_fd, conn);
      } else if (ev->events & EPOLLIN) {
        handle_stream(*ev, epoll_fd, conn);
      }
    }
  }
}
#endif