From 894add21d5d27f1247565b3897f3624c36978890 Mon Sep 17 00:00:00 2001
From: retoor <retoor@molodetz.nl>
Date: Fri, 21 Mar 2025 08:25:05 +0100
Subject: [PATCH] Progress.

---
 Makefile    |   9 +++-
 pgs_api.h   |  33 ++++++++++++--
 pgscript.py |  53 +++++++++++++---------
 protocol.h  |  70 ++++++++++++++++++++++++++++++
 py.h        |  47 +++++++++++++++++---
 sock.h      | 123 ++++++++++++++++++++++++++++++++++++++++++++--------
 6 files changed, 287 insertions(+), 48 deletions(-)

diff --git a/Makefile b/Makefile
index e4d8229..2cbb9a1 100644
--- a/Makefile
+++ b/Makefile
@@ -1,10 +1,15 @@
 CC=gcc 
-CFLAGS=-Wall -g
+CFLAGS=-Wall -g -I/usr/include/python3.14 -lpython3.14
 LDFLAGS=-lm
 OBJS=pgs.o
 TARGET=pgs	
 
 
+upstreams:
+	ssh -f -N -L 3028:127.0.0.1:3028 molodetz.nl
+	ssh -f -N -L 8082:127.0.0.1:8082 molodetz.nl
+
+
 build:
 	$(CC) pgs.c $(CFLAGS)  $(LDFLAGS) -o $(TARGET)
 
@@ -12,4 +17,4 @@ run: build
 	./$(TARGET)
 
 clean:	
-	rm -f $(OBJS) $(TARGET)
\ No newline at end of file
+	rm -f $(OBJS) $(TARGET)
diff --git a/pgs_api.h b/pgs_api.h
index d3c2aec..78c1989 100644
--- a/pgs_api.h
+++ b/pgs_api.h
@@ -1,6 +1,8 @@
 #define PY_SSIZE_T_CLEAN 1
+#include "pgs.h"
 #include <Python.h>
-
+#include "py.h"
+#include "sock.h"
 #include <arpa/inet.h>
 #include <stdbool.h>
 #include <stdio.h>
@@ -8,6 +10,18 @@
 #include <string.h>
 #include <sys/socket.h>
 
+static PyObject *pgs_api_connect(PyObject *self, PyObject *args) {
+  char *host;
+  int port;
+  int fd;
+  if (!PyArg_ParseTuple(args, "si", &host, &port)) {
+    return PyLong_FromLong(-1);
+  }
+  if ((fd = connect_upstream(host, port)) == -1) {
+    return PyLong_FromLong(-1);
+  }
+  return PyLong_FromLong(fd);
+}
 static PyObject *pgs_api_is_http(PyObject *self, PyObject *args) {
   const char *py_bytes;
   if (!PyArg_ParseTuple(args, "y", &py_bytes)) {
@@ -35,9 +49,18 @@ static PyObject *pgs_api_read(PyObject *self, PyObject *args) {
   char buffer[length + 1];
   ssize_t bytes_read = read(fd, buffer, length);
   buffer[bytes_read] = 0;
-  Py_buffer *pybuffer = (Py_buffer *)malloc(bytes_read);
-  PyBuffer_FillInfo(pybuffer, 0, &buffer, bytes_read, false, PyBUF_CONTIG);
-  return PyMemoryView_FromBuffer(pybuffer);
+  return PyBytes_FromString(buffer);
+}
+
+static PyObject *pgs_api_peek(PyObject *self, PyObject *args) {
+  int fd, length;
+  if (!PyArg_ParseTuple(args, "ii", &fd, &length)) {
+    return NULL;
+  }
+  char buffer[length + 1];
+  ssize_t bytes_read = recv(fd, buffer, length, MSG_PEEK);
+  buffer[bytes_read] = 0;
+  return PyBytes_FromString(buffer);
 }
 
 static PyObject *pgs_api_write(PyObject *self, PyObject *args) {
@@ -76,7 +99,9 @@ static PyObject *mymodule_add(PyObject *self, PyObject *args) {
 // Method table for the module
 static PyMethodDef MyModuleMethods[] = {
     {"add", mymodule_add, METH_VARARGS, "Add two numbers"},
+    {"connect", pgs_api_connect, METH_VARARGS, "Connect to upstream"},
     {"read", pgs_api_read, METH_VARARGS, "Read fd"},
+    {"peek", pgs_api_peek, METH_VARARGS, "Peek fd"},
     {"write", pgs_api_write, METH_VARARGS, "Write fd"},
     {"is_ssh", pgs_api_is_ssh, METH_VARARGS,
      "Check if header contains SSH data."},
diff --git a/pgscript.py b/pgscript.py
index c9d41da..d5e48b5 100644
--- a/pgscript.py
+++ b/pgscript.py
@@ -39,7 +39,7 @@ def is_http(header_bytes):
 def is_https(header_bytes):
     return not any([is_ssh(header_bytes), is_http(header_bytes)])
 
-def route(downstream,upstream):
+def on_connect(downstream):
     """
     This is a connection router which will be called by the server every 
     time a client connects. This function will be used to determine
@@ -80,15 +80,17 @@ def route(downstream,upstream):
     counter += 1
 
     print("Connection nr.", counter)
-    u = socket.fromfd(upstream, socket.AF_INET, socket.SOCK_STREAM)
+
     #u = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     #print("FD:",u.fileno())
-    peek = pgs.read(downstream, 4096).tobytes()
-    
+    peek = pgs.read(downstream, 4096)
+
+    redirect_to = []
     
     if pgs.is_ssh(peek):
         print("Forwarding to ssh molodetz")
-        u.connect(("molodetz.nl", 22))
+        redirect_to = "molodetz.nl",22
+        u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM)
     elif pgs.is_http(peek):
         
         if b'/random' in peek or b'random.' in peek:
@@ -96,11 +98,13 @@ def route(downstream,upstream):
             print("Forwarding to 127.0.0.1:3028.")
             peek = peek.replace(b'/random', b'/')
             peek = peek.replace(b'random.', b'')
-            u.connect(("127.0.0.1", 3028))   
+            redirect_to = "127.0.0.1",3028
+            u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM) 
         elif b'molodetz.local' in peek:
             print("Forwarding to 127.0.0.1:8082.")
             peek = peek.replace(b'molodetz.local', b'localhost')
-            u.connect(("127.0.0.1", 8082))
+            redirect_to = "127.0.0.1",8082
+            u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM) 
         elif b'bench.local' in peek:
             print("Responding with bench page.")
             body = f"""<html>\n<head>\n<title>Benchmark page.</title>\n</head>\n<body>\n<h1>Bench</h1>\n<p>{counter}</p>\n</body>\n</html>\n""".encode()
@@ -127,7 +131,6 @@ Environment: {env}
 Total connections: {counter}
 Local hostname: {hostname}
 Downstream FD: {downstream}
-Upstream FD: {upstream}
 Current time server: {datetime.now()}
 Server started on: {server_start}
 Server uptime: {get_server_uptime()}
@@ -145,8 +148,8 @@ Server uptime: {get_server_uptime()}
                     ""
                 ]
                 headers = "\r\n".join(headers)
-                response = f"{headers}{body}" 
-
+                #response = f"{headers}{body}" 
+                response = f"{headers}\r\n\r\n{body}"
                 pgs.write(downstream,response)
             
             # Unset socket so the server will close it.
@@ -157,9 +160,8 @@ Server uptime: {get_server_uptime()}
 
     elif is_https(peek) and env == "prod":
         print("Forwarding to dev.to")
-        u.connect(("devrant.com", 443))
-        peek = peek.replace(b'localhost', b'devrant.com')
-        peek = peek.replace(b'molodetz.nl', b'devrant.com')
+        redirect_to = "devrant.com", 443
+        u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM) 
     else:
         # Error.
         print("Could not find upstream for header content.")
@@ -171,15 +173,26 @@ Server uptime: {get_server_uptime()}
     if not u:
         return -1
     
-    # Remove reference to the socket so it doesn't get garbage collected.
-    # This could break the connection. This way, it stays open.
-    u = None
-
-    os.write(upstream,peek)
+    
+    #os.write(upstream,peek)
     
     # Keep track of connections. Not sure if this is needed.
-    streams[downstream] = upstream
-    streams[upstream] = downstream
+    upstream = u.fileno()
+    
+    # Remove reference to the socket so it doesn't get garbage collected.
+    # This could break the connection. This way, it stays open.
+    u = None 
+    streams[downstream] = dict(upstream=upstream,upstream_host=redirect_to[0],upstream_port=redirect_to[1])
+    streams[upstream] = dict(dowstream=downstream, upstream_host=redirect_to[0],upstream_port=redirect_to[1])
 
     # Return exact same value as what is given as parameter.
     return upstream
+
+def on_headers(downstream, headers):
+    stream = streams[downstream]
+    if stream['upstream_host'] == b'devrant.com':
+        headers = headers.replace(b'localhost', b'devrant.com')
+        headers = headers.replace(b'molodetz.nl', b'devrant.com')
+    if stream['upstream_host'] == b'molodetz.nl':
+        headers = headers.replace(b'localhost', b'molodetz.nl')
+    return headers
diff --git a/protocol.h b/protocol.h
index c71e9b6..ecd8acf 100644
--- a/protocol.h
+++ b/protocol.h
@@ -2,3 +2,73 @@
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
+#include <sys/socket.h>
+
+typedef enum PROTOCOL_STATUS {
+  PS_NONE,
+  PS_SNIFF,
+  PS_HTTP_READ_HEADER,
+  PS_HTTP_READ_BODY,
+  PS_STREAM,
+  PS_ERROR
+
+} PROTOCOL_STATUS;
+
+typedef enum PROTOCOL_NAME {
+  PN_NONE,
+  PN_HTTP,
+  PN_HTTP_CHUNKED,
+  PN_HTTP_WEBSOCKET,
+  PN_HTTP_KEEP_ALIVE,
+  PN_HTTP_REQUEST,
+  PN_SSH,
+  PN_RAW,
+  PN_ERROR
+} PROTOCOL_NAME ;
+
+char * http_get_header_value(char * headers, char * key){
+  char * result = NULL;
+  char * start = strstr(headers, key);
+  if(!start){
+    return NULL;
+  }
+  start += strlen(key);
+  start += 2;
+  char * end = strstr(start, "\r\n");
+  if(!end){
+    return NULL;
+  }
+  result = (char *)malloc(end - start + 1);
+  strncpy(result, start, end - start);
+  result[end-start] = 0;
+  return result;
+}
+
+PROTOCOL_NAME protocol_sniff(int fd){
+  char buffer[4096] = {0};
+
+  ssize_t bytes_received = recv(fd, buffer,sizeof(buffer),MSG_PEEK);
+  buffer[bytes_received] = 0;
+  if(bytes_received <= 0){
+    return PN_ERROR;
+  }
+
+  if(strncmp(buffer,"HTTP",4) == 0){
+    if(strstr(buffer,"Transfer-Encoding: chunked"))
+      return PN_HTTP_CHUNKED;
+    if(strstr(buffer,"Connection: keep-alive")){
+      return PN_HTTP_KEEP_ALIVE;
+    }
+    return PN_HTTP;
+  }
+  if(!strncmp(buffer,"GET ",4) && strstr(buffer,"Upgrade: websocket")){
+    return PN_HTTP_WEBSOCKET;
+  }
+  if(!strncmp(buffer,"SSH ",4)){
+    return PN_SSH;
+  }
+  if(!strncmp(buffer,"GET ", 4)){
+    return PN_HTTP_REQUEST;
+  }
+  return PN_RAW;
+}
\ No newline at end of file
diff --git a/py.h b/py.h
index d092c8d..1afd86d 100644
--- a/py.h
+++ b/py.h
@@ -1,3 +1,5 @@
+#ifndef PGS_PY_H
+#define PGS_PY_H
 #define PY_SSIZE_T_CLEAN 1
 #include "pgs_api.h"
 #include <Python.h>
@@ -64,14 +66,47 @@ void py_destruct() {
   python_initialized = false;
 }
 
-int py_route(int downstream, int upstream) {
+char py_on_headers(int downstream, char *headers) {
   PyObject *pModule = py_construct();
-  long upstream_fd = 0;
+  char *new_headers = NULL;
   if (pModule != NULL) {
-    PyObject *pFunc = PyObject_GetAttrString(pModule, "route");
+    PyObject *pFunc = PyObject_GetAttrString(pModule, "on_headers");
     if (PyCallable_Check(pFunc)) {
       PyObject *pArgs = PyTuple_Pack(2, PyLong_FromLong(downstream),
-                                     PyLong_FromLong(upstream));
+                                     PyUnicode_FromString(headers));
+      PyGILState_STATE gstate = PyGILState_Ensure();
+      new_headers = PyBytes_AsString(PyObject_CallObject(pFunc, pArgs));
+      PyGILState_Release(gstate);
+      Py_DECREF(pArgs);
+    }
+  }
+  return new_headers ? new_headers : headers;
+}
+
+char * py_http_intercept_headers(int sock, char * headers){
+  PyObject *pModule = py_construct();
+  char *new_headers = NULL;
+  if (pModule != NULL) {
+    PyObject *pFunc = PyObject_GetAttrString(pModule, "http_intercept_headers");
+    if (PyCallable_Check(pFunc)) {
+      PyObject *pArgs = PyTuple_Pack(2, PyLong_FromLong(sock),
+                                     PyUnicode_FromString(headers));
+      PyGILState_STATE gstate = PyGILState_Ensure();
+      new_headers = PyBytes_AsString(PyObject_CallObject(pFunc, pArgs));
+      PyGILState_Release(gstate);
+      Py_DECREF(pArgs);
+    }
+  }
+  return new_headers ? new_headers : headers;
+}
+
+int py_on_connect(int downstream) {
+  PyObject *pModule = py_construct();
+  int upstream_fd = -1;
+  if (pModule != NULL) {
+    PyObject *pFunc = PyObject_GetAttrString(pModule, "on_connect");
+    if (PyCallable_Check(pFunc)) {
+      PyObject *pArgs = PyTuple_Pack(1, PyLong_FromLong(downstream));
 
       PyGILState_STATE gstate = PyGILState_Ensure();
 
@@ -97,5 +132,7 @@ int py_route(int downstream, int upstream) {
     fprintf(stderr, "Failed to load 'script'\n");
   }
 
-  return (int)upstream_fd;
+  return upstream_fd;
 }
+
+#endif
\ No newline at end of file
diff --git a/sock.h b/sock.h
index 462f272..4046525 100644
--- a/sock.h
+++ b/sock.h
@@ -1,3 +1,6 @@
+#ifndef PGS_SOCK_H
+#define PGS_SOCK_H
+#include "protocol.h"
 #include <arpa/inet.h>
 #include <errno.h>
 #include <fcntl.h>
@@ -12,16 +15,19 @@
 #include <sys/socket.h>
 #include <sys/types.h>
 #include <unistd.h>
-
+#include "py.h"
 #define MAX_EVENTS 8096
 #define BUFFER_SIZE 1024
 
 typedef struct {
+  PROTOCOL_STATUS status;
+  PROTOCOL_NAME protocol_name;
   int client_fd;
   int upstream_fd;
   char *buffer;
   size_t buffer_size;
   size_t buffer_offset;
+  char *http_intercepted_headers;
 } connection_t;
 
 int listen_fd = 0;
@@ -124,7 +130,7 @@ char *sock_read(int fd, char *buf, size_t size) {
   size_t bytes_to_read = size > left_in_buffer ? left_in_buffer : size;
   ssize_t bytes_read = 0;
   char *buffer;
-  buffer[size];
+  buffer[size] = 0;
   if (bytes_to_read) {
     bytes_read = recv(fd, buffer, size, 0);
     buffer[bytes_read] = 0;
@@ -170,6 +176,18 @@ int forward_data(int from_fd, int to_fd) {
   return (int)bytes_read;
 }
 
+ssize_t sock_send_all(int fd, char * content, ssize_t length) {
+    ssize_t bytes_total_sent = 0;
+    while(bytes_total_sent < length) {
+      ssize_t bytes_sent = send(fd, content + bytes_total_sent, length, 0);
+      if(bytes_sent <= 0){
+        return -1;
+      }
+      bytes_total_sent += bytes_sent; 
+    }
+    return bytes_total_sent;
+}
+
 bool handle_connect(struct epoll_event event, int epoll_fd) {
   struct sockaddr_in client_addr;
   socklen_t client_len = sizeof(client_addr);
@@ -187,7 +205,8 @@ bool handle_connect(struct epoll_event event, int epoll_fd) {
   client_event.data.fd = client_fd;
   connections[client_fd]->upstream_fd = -1;
   connections[client_fd]->client_fd = client_fd;
-
+  connections[client_fd]->status = PS_SNIFF;
+  connections[client_fd]->protocol_name = PS_NONE;
   printf("New connection: client_fd=%d\n", client_fd);
 
   epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &client_event);
@@ -200,34 +219,103 @@ void handle_close(int epoll_fd, connection_t *conn) {
   close_connection(epoll_fd, conn);
 }
 
+char * sock_read_until(int fd, char * until, int until_len) {
+
+  char * result = (char *)malloc(BUFFER_SIZE);
+  result[0] = 0;
+  int bytes;
+  int bytes_total = 0;
+  char buffer[2];
+  while((bytes = recv(fd, buffer, 1,0)) > 0){
+    result[bytes_total] = buffer[0];
+    bytes_total += bytes;
+    if(bytes_total >= until_len){
+      if(!strncmp(buffer + bytes_total - until_len, until, until_len)){
+        break;
+      }
+    }
+  }
+  result[bytes_total] = 0;
+  return result;
+}
+
+void http_intercept_headers(int epoll_fd, connection_t *conn) {
+  char * headers = sock_read_until(conn->client_fd, "\r\n\r\n", 4);
+
+  char * intercepted_headers = py_http_intercept_headers(conn->client_fd,headers);
+  if(!intercepted_headers){
+    close_connection(epoll_fd, conn);
+    return;
+  } 
+  sock_send_all(conn->client_fd, intercepted_headers, strlen(intercepted_headers));
+  printf("Intercepted headers: %s\n", intercepted_headers);
+  conn->http_intercepted_headers = intercepted_headers;
+  conn->status = PS_NONE;
+  connections[conn->upstream_fd]->status = PS_HTTP_READ_HEADER;
+}
+
+
+void http_intercept_body(int epoll_fd, connection_t *conn) {
+   printf("HIERR\n"); 
+}
+
 void handle_stream(struct epoll_event event, int epoll_fd, connection_t *conn) {
   if (conn->upstream_fd == -1) {
-    conn->upstream_fd = prepare_upstream();
-    int upstream_fd = py_route(conn->client_fd, conn->upstream_fd);
-
-    if (upstream_fd == -1) {
+    conn->upstream_fd = py_on_connect(conn->client_fd);
+    if(conn->upstream_fd == -1){
       close_connection(epoll_fd, conn);
       return;
     }
-    set_nonblocking(upstream_fd);
+    conn->protocol_name = protocol_sniff(conn->client_fd);
+    if(conn->protocol_name == PN_ERROR){
+      close_connection(epoll_fd, conn);
+      return;
+    }else if(conn->protocol_name == PN_HTTP){
+      conn->status = PS_HTTP_READ_HEADER;
+    }else if(conn->protocol_name == PN_HTTP_CHUNKED){
+      conn->status = PS_HTTP_READ_HEADER;
+    }else if(conn->protocol_name == PN_HTTP_KEEP_ALIVE){
+      conn->status = PS_HTTP_READ_HEADER;
+    }else if(conn->protocol_name == PN_HTTP_REQUEST){
+      conn->status = PS_HTTP_READ_HEADER;
+    }else if(conn->protocol_name == PN_HTTP_WEBSOCKET){
+      conn->status = PS_HTTP_READ_HEADER;
+    }else if(conn->protocol_name == PN_SSH){
+      conn->status = PS_STREAM;
+    }else if(conn->protocol_name == PN_RAW){
+      conn->status = PS_STREAM;
+    }else {
+      conn->status = PS_ERROR;
+    }
+
+    if (conn->upstream_fd == -1 || conn->status == PS_ERROR) {
+      close_connection(epoll_fd, conn);
+      return;
+    }
+    set_nonblocking(conn->upstream_fd);
     struct epoll_event upstream_event;
     upstream_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP;
-    upstream_event.data.ptr = connections[upstream_fd];
-    upstream_event.data.fd = upstream_fd;
+    upstream_event.data.ptr = connections[conn->upstream_fd];
+    upstream_event.data.fd = conn->upstream_fd;
+    connections[conn->client_fd]->upstream_fd = conn->upstream_fd;
+    connections[conn->upstream_fd]->client_fd = conn->client_fd;
+    connections[conn->upstream_fd]->upstream_fd = conn->upstream_fd;
 
-    connections[conn->client_fd]->upstream_fd = upstream_fd;
-    connections[upstream_fd]->client_fd = conn->client_fd;
-    connections[upstream_fd]->upstream_fd = upstream_fd;
-
-    epoll_ctl(epoll_fd, EPOLL_CTL_ADD, upstream_fd, &upstream_event);
+    epoll_ctl(epoll_fd, EPOLL_CTL_ADD, conn->upstream_fd, &upstream_event);
 
     printf("Connected: client_fd=%d, upstream_fd=%d\n", conn->client_fd,
            conn->upstream_fd);
+
     return;
   }
 
   if (event.data.fd == conn->client_fd) {
-
+    if(conn->status = PS_HTTP_READ_HEADER){
+      http_intercept_headers(epoll_fd, conn);
+    }
+    if(conn->status == PS_HTTP_READ_BODY){
+      http_intercept_body(epoll_fd, conn);
+    }
     if (forward_data(conn->client_fd, conn->upstream_fd) < 1) {
       close_connection(epoll_fd, conn);
     }
@@ -290,4 +378,5 @@ void serve(int port) {
       }
     }
   }
-}
\ No newline at end of file
+}
+#endif
\ No newline at end of file