From 894add21d5d27f1247565b3897f3624c36978890 Mon Sep 17 00:00:00 2001 From: retoor <retoor@molodetz.nl> Date: Fri, 21 Mar 2025 08:25:05 +0100 Subject: [PATCH] Progress. --- Makefile | 9 +++- pgs_api.h | 33 ++++++++++++-- pgscript.py | 53 +++++++++++++--------- protocol.h | 70 ++++++++++++++++++++++++++++++ py.h | 47 +++++++++++++++++--- sock.h | 123 ++++++++++++++++++++++++++++++++++++++++++++-------- 6 files changed, 287 insertions(+), 48 deletions(-) diff --git a/Makefile b/Makefile index e4d8229..2cbb9a1 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,15 @@ CC=gcc -CFLAGS=-Wall -g +CFLAGS=-Wall -g -I/usr/include/python3.14 -lpython3.14 LDFLAGS=-lm OBJS=pgs.o TARGET=pgs +upstreams: + ssh -f -N -L 3028:127.0.0.1:3028 molodetz.nl + ssh -f -N -L 8082:127.0.0.1:8082 molodetz.nl + + build: $(CC) pgs.c $(CFLAGS) $(LDFLAGS) -o $(TARGET) @@ -12,4 +17,4 @@ run: build ./$(TARGET) clean: - rm -f $(OBJS) $(TARGET) \ No newline at end of file + rm -f $(OBJS) $(TARGET) diff --git a/pgs_api.h b/pgs_api.h index d3c2aec..78c1989 100644 --- a/pgs_api.h +++ b/pgs_api.h @@ -1,6 +1,8 @@ #define PY_SSIZE_T_CLEAN 1 +#include "pgs.h" #include <Python.h> - +#include "py.h" +#include "sock.h" #include <arpa/inet.h> #include <stdbool.h> #include <stdio.h> @@ -8,6 +10,18 @@ #include <string.h> #include <sys/socket.h> +static PyObject *pgs_api_connect(PyObject *self, PyObject *args) { + char *host; + int port; + int fd; + if (!PyArg_ParseTuple(args, "si", &host, &port)) { + return PyLong_FromLong(-1); + } + if ((fd = connect_upstream(host, port)) == -1) { + return PyLong_FromLong(-1); + } + return PyLong_FromLong(fd); +} static PyObject *pgs_api_is_http(PyObject *self, PyObject *args) { const char *py_bytes; if (!PyArg_ParseTuple(args, "y", &py_bytes)) { @@ -35,9 +49,18 @@ static PyObject *pgs_api_read(PyObject *self, PyObject *args) { char buffer[length + 1]; ssize_t bytes_read = read(fd, buffer, length); buffer[bytes_read] = 0; - Py_buffer *pybuffer = (Py_buffer *)malloc(bytes_read); - PyBuffer_FillInfo(pybuffer, 0, &buffer, bytes_read, false, PyBUF_CONTIG); - return PyMemoryView_FromBuffer(pybuffer); + return PyBytes_FromString(buffer); +} + +static PyObject *pgs_api_peek(PyObject *self, PyObject *args) { + int fd, length; + if (!PyArg_ParseTuple(args, "ii", &fd, &length)) { + return NULL; + } + char buffer[length + 1]; + ssize_t bytes_read = recv(fd, buffer, length, MSG_PEEK); + buffer[bytes_read] = 0; + return PyBytes_FromString(buffer); } static PyObject *pgs_api_write(PyObject *self, PyObject *args) { @@ -76,7 +99,9 @@ static PyObject *mymodule_add(PyObject *self, PyObject *args) { // Method table for the module static PyMethodDef MyModuleMethods[] = { {"add", mymodule_add, METH_VARARGS, "Add two numbers"}, + {"connect", pgs_api_connect, METH_VARARGS, "Connect to upstream"}, {"read", pgs_api_read, METH_VARARGS, "Read fd"}, + {"peek", pgs_api_peek, METH_VARARGS, "Peek fd"}, {"write", pgs_api_write, METH_VARARGS, "Write fd"}, {"is_ssh", pgs_api_is_ssh, METH_VARARGS, "Check if header contains SSH data."}, diff --git a/pgscript.py b/pgscript.py index c9d41da..d5e48b5 100644 --- a/pgscript.py +++ b/pgscript.py @@ -39,7 +39,7 @@ def is_http(header_bytes): def is_https(header_bytes): return not any([is_ssh(header_bytes), is_http(header_bytes)]) -def route(downstream,upstream): +def on_connect(downstream): """ This is a connection router which will be called by the server every time a client connects. This function will be used to determine @@ -80,15 +80,17 @@ def route(downstream,upstream): counter += 1 print("Connection nr.", counter) - u = socket.fromfd(upstream, socket.AF_INET, socket.SOCK_STREAM) + #u = socket.socket(socket.AF_INET, socket.SOCK_STREAM) #print("FD:",u.fileno()) - peek = pgs.read(downstream, 4096).tobytes() - + peek = pgs.read(downstream, 4096) + + redirect_to = [] if pgs.is_ssh(peek): print("Forwarding to ssh molodetz") - u.connect(("molodetz.nl", 22)) + redirect_to = "molodetz.nl",22 + u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM) elif pgs.is_http(peek): if b'/random' in peek or b'random.' in peek: @@ -96,11 +98,13 @@ def route(downstream,upstream): print("Forwarding to 127.0.0.1:3028.") peek = peek.replace(b'/random', b'/') peek = peek.replace(b'random.', b'') - u.connect(("127.0.0.1", 3028)) + redirect_to = "127.0.0.1",3028 + u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM) elif b'molodetz.local' in peek: print("Forwarding to 127.0.0.1:8082.") peek = peek.replace(b'molodetz.local', b'localhost') - u.connect(("127.0.0.1", 8082)) + redirect_to = "127.0.0.1",8082 + u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM) elif b'bench.local' in peek: print("Responding with bench page.") body = f"""<html>\n<head>\n<title>Benchmark page.</title>\n</head>\n<body>\n<h1>Bench</h1>\n<p>{counter}</p>\n</body>\n</html>\n""".encode() @@ -127,7 +131,6 @@ Environment: {env} Total connections: {counter} Local hostname: {hostname} Downstream FD: {downstream} -Upstream FD: {upstream} Current time server: {datetime.now()} Server started on: {server_start} Server uptime: {get_server_uptime()} @@ -145,8 +148,8 @@ Server uptime: {get_server_uptime()} "" ] headers = "\r\n".join(headers) - response = f"{headers}{body}" - + #response = f"{headers}{body}" + response = f"{headers}\r\n\r\n{body}" pgs.write(downstream,response) # Unset socket so the server will close it. @@ -157,9 +160,8 @@ Server uptime: {get_server_uptime()} elif is_https(peek) and env == "prod": print("Forwarding to dev.to") - u.connect(("devrant.com", 443)) - peek = peek.replace(b'localhost', b'devrant.com') - peek = peek.replace(b'molodetz.nl', b'devrant.com') + redirect_to = "devrant.com", 443 + u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM) else: # Error. print("Could not find upstream for header content.") @@ -171,15 +173,26 @@ Server uptime: {get_server_uptime()} if not u: return -1 - # Remove reference to the socket so it doesn't get garbage collected. - # This could break the connection. This way, it stays open. - u = None - - os.write(upstream,peek) + + #os.write(upstream,peek) # Keep track of connections. Not sure if this is needed. - streams[downstream] = upstream - streams[upstream] = downstream + upstream = u.fileno() + + # Remove reference to the socket so it doesn't get garbage collected. + # This could break the connection. This way, it stays open. + u = None + streams[downstream] = dict(upstream=upstream,upstream_host=redirect_to[0],upstream_port=redirect_to[1]) + streams[upstream] = dict(dowstream=downstream, upstream_host=redirect_to[0],upstream_port=redirect_to[1]) # Return exact same value as what is given as parameter. return upstream + +def on_headers(downstream, headers): + stream = streams[downstream] + if stream['upstream_host'] == b'devrant.com': + headers = headers.replace(b'localhost', b'devrant.com') + headers = headers.replace(b'molodetz.nl', b'devrant.com') + if stream['upstream_host'] == b'molodetz.nl': + headers = headers.replace(b'localhost', b'molodetz.nl') + return headers diff --git a/protocol.h b/protocol.h index c71e9b6..ecd8acf 100644 --- a/protocol.h +++ b/protocol.h @@ -2,3 +2,73 @@ #include <stdio.h> #include <stdlib.h> #include <string.h> +#include <sys/socket.h> + +typedef enum PROTOCOL_STATUS { + PS_NONE, + PS_SNIFF, + PS_HTTP_READ_HEADER, + PS_HTTP_READ_BODY, + PS_STREAM, + PS_ERROR + +} PROTOCOL_STATUS; + +typedef enum PROTOCOL_NAME { + PN_NONE, + PN_HTTP, + PN_HTTP_CHUNKED, + PN_HTTP_WEBSOCKET, + PN_HTTP_KEEP_ALIVE, + PN_HTTP_REQUEST, + PN_SSH, + PN_RAW, + PN_ERROR +} PROTOCOL_NAME ; + +char * http_get_header_value(char * headers, char * key){ + char * result = NULL; + char * start = strstr(headers, key); + if(!start){ + return NULL; + } + start += strlen(key); + start += 2; + char * end = strstr(start, "\r\n"); + if(!end){ + return NULL; + } + result = (char *)malloc(end - start + 1); + strncpy(result, start, end - start); + result[end-start] = 0; + return result; +} + +PROTOCOL_NAME protocol_sniff(int fd){ + char buffer[4096] = {0}; + + ssize_t bytes_received = recv(fd, buffer,sizeof(buffer),MSG_PEEK); + buffer[bytes_received] = 0; + if(bytes_received <= 0){ + return PN_ERROR; + } + + if(strncmp(buffer,"HTTP",4) == 0){ + if(strstr(buffer,"Transfer-Encoding: chunked")) + return PN_HTTP_CHUNKED; + if(strstr(buffer,"Connection: keep-alive")){ + return PN_HTTP_KEEP_ALIVE; + } + return PN_HTTP; + } + if(!strncmp(buffer,"GET ",4) && strstr(buffer,"Upgrade: websocket")){ + return PN_HTTP_WEBSOCKET; + } + if(!strncmp(buffer,"SSH ",4)){ + return PN_SSH; + } + if(!strncmp(buffer,"GET ", 4)){ + return PN_HTTP_REQUEST; + } + return PN_RAW; +} \ No newline at end of file diff --git a/py.h b/py.h index d092c8d..1afd86d 100644 --- a/py.h +++ b/py.h @@ -1,3 +1,5 @@ +#ifndef PGS_PY_H +#define PGS_PY_H #define PY_SSIZE_T_CLEAN 1 #include "pgs_api.h" #include <Python.h> @@ -64,14 +66,47 @@ void py_destruct() { python_initialized = false; } -int py_route(int downstream, int upstream) { +char py_on_headers(int downstream, char *headers) { PyObject *pModule = py_construct(); - long upstream_fd = 0; + char *new_headers = NULL; if (pModule != NULL) { - PyObject *pFunc = PyObject_GetAttrString(pModule, "route"); + PyObject *pFunc = PyObject_GetAttrString(pModule, "on_headers"); if (PyCallable_Check(pFunc)) { PyObject *pArgs = PyTuple_Pack(2, PyLong_FromLong(downstream), - PyLong_FromLong(upstream)); + PyUnicode_FromString(headers)); + PyGILState_STATE gstate = PyGILState_Ensure(); + new_headers = PyBytes_AsString(PyObject_CallObject(pFunc, pArgs)); + PyGILState_Release(gstate); + Py_DECREF(pArgs); + } + } + return new_headers ? new_headers : headers; +} + +char * py_http_intercept_headers(int sock, char * headers){ + PyObject *pModule = py_construct(); + char *new_headers = NULL; + if (pModule != NULL) { + PyObject *pFunc = PyObject_GetAttrString(pModule, "http_intercept_headers"); + if (PyCallable_Check(pFunc)) { + PyObject *pArgs = PyTuple_Pack(2, PyLong_FromLong(sock), + PyUnicode_FromString(headers)); + PyGILState_STATE gstate = PyGILState_Ensure(); + new_headers = PyBytes_AsString(PyObject_CallObject(pFunc, pArgs)); + PyGILState_Release(gstate); + Py_DECREF(pArgs); + } + } + return new_headers ? new_headers : headers; +} + +int py_on_connect(int downstream) { + PyObject *pModule = py_construct(); + int upstream_fd = -1; + if (pModule != NULL) { + PyObject *pFunc = PyObject_GetAttrString(pModule, "on_connect"); + if (PyCallable_Check(pFunc)) { + PyObject *pArgs = PyTuple_Pack(1, PyLong_FromLong(downstream)); PyGILState_STATE gstate = PyGILState_Ensure(); @@ -97,5 +132,7 @@ int py_route(int downstream, int upstream) { fprintf(stderr, "Failed to load 'script'\n"); } - return (int)upstream_fd; + return upstream_fd; } + +#endif \ No newline at end of file diff --git a/sock.h b/sock.h index 462f272..4046525 100644 --- a/sock.h +++ b/sock.h @@ -1,3 +1,6 @@ +#ifndef PGS_SOCK_H +#define PGS_SOCK_H +#include "protocol.h" #include <arpa/inet.h> #include <errno.h> #include <fcntl.h> @@ -12,16 +15,19 @@ #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> - +#include "py.h" #define MAX_EVENTS 8096 #define BUFFER_SIZE 1024 typedef struct { + PROTOCOL_STATUS status; + PROTOCOL_NAME protocol_name; int client_fd; int upstream_fd; char *buffer; size_t buffer_size; size_t buffer_offset; + char *http_intercepted_headers; } connection_t; int listen_fd = 0; @@ -124,7 +130,7 @@ char *sock_read(int fd, char *buf, size_t size) { size_t bytes_to_read = size > left_in_buffer ? left_in_buffer : size; ssize_t bytes_read = 0; char *buffer; - buffer[size]; + buffer[size] = 0; if (bytes_to_read) { bytes_read = recv(fd, buffer, size, 0); buffer[bytes_read] = 0; @@ -170,6 +176,18 @@ int forward_data(int from_fd, int to_fd) { return (int)bytes_read; } +ssize_t sock_send_all(int fd, char * content, ssize_t length) { + ssize_t bytes_total_sent = 0; + while(bytes_total_sent < length) { + ssize_t bytes_sent = send(fd, content + bytes_total_sent, length, 0); + if(bytes_sent <= 0){ + return -1; + } + bytes_total_sent += bytes_sent; + } + return bytes_total_sent; +} + bool handle_connect(struct epoll_event event, int epoll_fd) { struct sockaddr_in client_addr; socklen_t client_len = sizeof(client_addr); @@ -187,7 +205,8 @@ bool handle_connect(struct epoll_event event, int epoll_fd) { client_event.data.fd = client_fd; connections[client_fd]->upstream_fd = -1; connections[client_fd]->client_fd = client_fd; - + connections[client_fd]->status = PS_SNIFF; + connections[client_fd]->protocol_name = PS_NONE; printf("New connection: client_fd=%d\n", client_fd); epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &client_event); @@ -200,34 +219,103 @@ void handle_close(int epoll_fd, connection_t *conn) { close_connection(epoll_fd, conn); } +char * sock_read_until(int fd, char * until, int until_len) { + + char * result = (char *)malloc(BUFFER_SIZE); + result[0] = 0; + int bytes; + int bytes_total = 0; + char buffer[2]; + while((bytes = recv(fd, buffer, 1,0)) > 0){ + result[bytes_total] = buffer[0]; + bytes_total += bytes; + if(bytes_total >= until_len){ + if(!strncmp(buffer + bytes_total - until_len, until, until_len)){ + break; + } + } + } + result[bytes_total] = 0; + return result; +} + +void http_intercept_headers(int epoll_fd, connection_t *conn) { + char * headers = sock_read_until(conn->client_fd, "\r\n\r\n", 4); + + char * intercepted_headers = py_http_intercept_headers(conn->client_fd,headers); + if(!intercepted_headers){ + close_connection(epoll_fd, conn); + return; + } + sock_send_all(conn->client_fd, intercepted_headers, strlen(intercepted_headers)); + printf("Intercepted headers: %s\n", intercepted_headers); + conn->http_intercepted_headers = intercepted_headers; + conn->status = PS_NONE; + connections[conn->upstream_fd]->status = PS_HTTP_READ_HEADER; +} + + +void http_intercept_body(int epoll_fd, connection_t *conn) { + printf("HIERR\n"); +} + void handle_stream(struct epoll_event event, int epoll_fd, connection_t *conn) { if (conn->upstream_fd == -1) { - conn->upstream_fd = prepare_upstream(); - int upstream_fd = py_route(conn->client_fd, conn->upstream_fd); - - if (upstream_fd == -1) { + conn->upstream_fd = py_on_connect(conn->client_fd); + if(conn->upstream_fd == -1){ close_connection(epoll_fd, conn); return; } - set_nonblocking(upstream_fd); + conn->protocol_name = protocol_sniff(conn->client_fd); + if(conn->protocol_name == PN_ERROR){ + close_connection(epoll_fd, conn); + return; + }else if(conn->protocol_name == PN_HTTP){ + conn->status = PS_HTTP_READ_HEADER; + }else if(conn->protocol_name == PN_HTTP_CHUNKED){ + conn->status = PS_HTTP_READ_HEADER; + }else if(conn->protocol_name == PN_HTTP_KEEP_ALIVE){ + conn->status = PS_HTTP_READ_HEADER; + }else if(conn->protocol_name == PN_HTTP_REQUEST){ + conn->status = PS_HTTP_READ_HEADER; + }else if(conn->protocol_name == PN_HTTP_WEBSOCKET){ + conn->status = PS_HTTP_READ_HEADER; + }else if(conn->protocol_name == PN_SSH){ + conn->status = PS_STREAM; + }else if(conn->protocol_name == PN_RAW){ + conn->status = PS_STREAM; + }else { + conn->status = PS_ERROR; + } + + if (conn->upstream_fd == -1 || conn->status == PS_ERROR) { + close_connection(epoll_fd, conn); + return; + } + set_nonblocking(conn->upstream_fd); struct epoll_event upstream_event; upstream_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP; - upstream_event.data.ptr = connections[upstream_fd]; - upstream_event.data.fd = upstream_fd; + upstream_event.data.ptr = connections[conn->upstream_fd]; + upstream_event.data.fd = conn->upstream_fd; + connections[conn->client_fd]->upstream_fd = conn->upstream_fd; + connections[conn->upstream_fd]->client_fd = conn->client_fd; + connections[conn->upstream_fd]->upstream_fd = conn->upstream_fd; - connections[conn->client_fd]->upstream_fd = upstream_fd; - connections[upstream_fd]->client_fd = conn->client_fd; - connections[upstream_fd]->upstream_fd = upstream_fd; - - epoll_ctl(epoll_fd, EPOLL_CTL_ADD, upstream_fd, &upstream_event); + epoll_ctl(epoll_fd, EPOLL_CTL_ADD, conn->upstream_fd, &upstream_event); printf("Connected: client_fd=%d, upstream_fd=%d\n", conn->client_fd, conn->upstream_fd); + return; } if (event.data.fd == conn->client_fd) { - + if(conn->status = PS_HTTP_READ_HEADER){ + http_intercept_headers(epoll_fd, conn); + } + if(conn->status == PS_HTTP_READ_BODY){ + http_intercept_body(epoll_fd, conn); + } if (forward_data(conn->client_fd, conn->upstream_fd) < 1) { close_connection(epoll_fd, conn); } @@ -290,4 +378,5 @@ void serve(int port) { } } } -} \ No newline at end of file +} +#endif \ No newline at end of file