Progress.

This commit is contained in:
retoor 2025-03-21 08:25:05 +01:00
parent 11872b350c
commit 894add21d5
6 changed files with 287 additions and 48 deletions

View File

@ -1,10 +1,15 @@
CC=gcc
CFLAGS=-Wall -g
CFLAGS=-Wall -g -I/usr/include/python3.14 -lpython3.14
LDFLAGS=-lm
OBJS=pgs.o
TARGET=pgs
upstreams:
ssh -f -N -L 3028:127.0.0.1:3028 molodetz.nl
ssh -f -N -L 8082:127.0.0.1:8082 molodetz.nl
build:
$(CC) pgs.c $(CFLAGS) $(LDFLAGS) -o $(TARGET)
@ -12,4 +17,4 @@ run: build
./$(TARGET)
clean:
rm -f $(OBJS) $(TARGET)
rm -f $(OBJS) $(TARGET)

View File

@ -1,6 +1,8 @@
#define PY_SSIZE_T_CLEAN 1
#include "pgs.h"
#include <Python.h>
#include "py.h"
#include "sock.h"
#include <arpa/inet.h>
#include <stdbool.h>
#include <stdio.h>
@ -8,6 +10,18 @@
#include <string.h>
#include <sys/socket.h>
static PyObject *pgs_api_connect(PyObject *self, PyObject *args) {
char *host;
int port;
int fd;
if (!PyArg_ParseTuple(args, "si", &host, &port)) {
return PyLong_FromLong(-1);
}
if ((fd = connect_upstream(host, port)) == -1) {
return PyLong_FromLong(-1);
}
return PyLong_FromLong(fd);
}
static PyObject *pgs_api_is_http(PyObject *self, PyObject *args) {
const char *py_bytes;
if (!PyArg_ParseTuple(args, "y", &py_bytes)) {
@ -35,9 +49,18 @@ static PyObject *pgs_api_read(PyObject *self, PyObject *args) {
char buffer[length + 1];
ssize_t bytes_read = read(fd, buffer, length);
buffer[bytes_read] = 0;
Py_buffer *pybuffer = (Py_buffer *)malloc(bytes_read);
PyBuffer_FillInfo(pybuffer, 0, &buffer, bytes_read, false, PyBUF_CONTIG);
return PyMemoryView_FromBuffer(pybuffer);
return PyBytes_FromString(buffer);
}
static PyObject *pgs_api_peek(PyObject *self, PyObject *args) {
int fd, length;
if (!PyArg_ParseTuple(args, "ii", &fd, &length)) {
return NULL;
}
char buffer[length + 1];
ssize_t bytes_read = recv(fd, buffer, length, MSG_PEEK);
buffer[bytes_read] = 0;
return PyBytes_FromString(buffer);
}
static PyObject *pgs_api_write(PyObject *self, PyObject *args) {
@ -76,7 +99,9 @@ static PyObject *mymodule_add(PyObject *self, PyObject *args) {
// Method table for the module
static PyMethodDef MyModuleMethods[] = {
{"add", mymodule_add, METH_VARARGS, "Add two numbers"},
{"connect", pgs_api_connect, METH_VARARGS, "Connect to upstream"},
{"read", pgs_api_read, METH_VARARGS, "Read fd"},
{"peek", pgs_api_peek, METH_VARARGS, "Peek fd"},
{"write", pgs_api_write, METH_VARARGS, "Write fd"},
{"is_ssh", pgs_api_is_ssh, METH_VARARGS,
"Check if header contains SSH data."},

View File

@ -39,7 +39,7 @@ def is_http(header_bytes):
def is_https(header_bytes):
return not any([is_ssh(header_bytes), is_http(header_bytes)])
def route(downstream,upstream):
def on_connect(downstream):
"""
This is a connection router which will be called by the server every
time a client connects. This function will be used to determine
@ -80,15 +80,17 @@ def route(downstream,upstream):
counter += 1
print("Connection nr.", counter)
u = socket.fromfd(upstream, socket.AF_INET, socket.SOCK_STREAM)
#u = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#print("FD:",u.fileno())
peek = pgs.read(downstream, 4096).tobytes()
peek = pgs.read(downstream, 4096)
redirect_to = []
if pgs.is_ssh(peek):
print("Forwarding to ssh molodetz")
u.connect(("molodetz.nl", 22))
redirect_to = "molodetz.nl",22
u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM)
elif pgs.is_http(peek):
if b'/random' in peek or b'random.' in peek:
@ -96,11 +98,13 @@ def route(downstream,upstream):
print("Forwarding to 127.0.0.1:3028.")
peek = peek.replace(b'/random', b'/')
peek = peek.replace(b'random.', b'')
u.connect(("127.0.0.1", 3028))
redirect_to = "127.0.0.1",3028
u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM)
elif b'molodetz.local' in peek:
print("Forwarding to 127.0.0.1:8082.")
peek = peek.replace(b'molodetz.local', b'localhost')
u.connect(("127.0.0.1", 8082))
redirect_to = "127.0.0.1",8082
u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM)
elif b'bench.local' in peek:
print("Responding with bench page.")
body = f"""<html>\n<head>\n<title>Benchmark page.</title>\n</head>\n<body>\n<h1>Bench</h1>\n<p>{counter}</p>\n</body>\n</html>\n""".encode()
@ -127,7 +131,6 @@ Environment: {env}
Total connections: {counter}
Local hostname: {hostname}
Downstream FD: {downstream}
Upstream FD: {upstream}
Current time server: {datetime.now()}
Server started on: {server_start}
Server uptime: {get_server_uptime()}
@ -145,8 +148,8 @@ Server uptime: {get_server_uptime()}
""
]
headers = "\r\n".join(headers)
response = f"{headers}{body}"
#response = f"{headers}{body}"
response = f"{headers}\r\n\r\n{body}"
pgs.write(downstream,response)
# Unset socket so the server will close it.
@ -157,9 +160,8 @@ Server uptime: {get_server_uptime()}
elif is_https(peek) and env == "prod":
print("Forwarding to dev.to")
u.connect(("devrant.com", 443))
peek = peek.replace(b'localhost', b'devrant.com')
peek = peek.replace(b'molodetz.nl', b'devrant.com')
redirect_to = "devrant.com", 443
u = socket.fromfd(pgs.connect(*redirect_to), socket.AF_INET, socket.SOCK_STREAM)
else:
# Error.
print("Could not find upstream for header content.")
@ -171,15 +173,26 @@ Server uptime: {get_server_uptime()}
if not u:
return -1
# Remove reference to the socket so it doesn't get garbage collected.
# This could break the connection. This way, it stays open.
u = None
os.write(upstream,peek)
#os.write(upstream,peek)
# Keep track of connections. Not sure if this is needed.
streams[downstream] = upstream
streams[upstream] = downstream
upstream = u.fileno()
# Remove reference to the socket so it doesn't get garbage collected.
# This could break the connection. This way, it stays open.
u = None
streams[downstream] = dict(upstream=upstream,upstream_host=redirect_to[0],upstream_port=redirect_to[1])
streams[upstream] = dict(dowstream=downstream, upstream_host=redirect_to[0],upstream_port=redirect_to[1])
# Return exact same value as what is given as parameter.
return upstream
def on_headers(downstream, headers):
stream = streams[downstream]
if stream['upstream_host'] == b'devrant.com':
headers = headers.replace(b'localhost', b'devrant.com')
headers = headers.replace(b'molodetz.nl', b'devrant.com')
if stream['upstream_host'] == b'molodetz.nl':
headers = headers.replace(b'localhost', b'molodetz.nl')
return headers

View File

@ -2,3 +2,73 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
typedef enum PROTOCOL_STATUS {
PS_NONE,
PS_SNIFF,
PS_HTTP_READ_HEADER,
PS_HTTP_READ_BODY,
PS_STREAM,
PS_ERROR
} PROTOCOL_STATUS;
typedef enum PROTOCOL_NAME {
PN_NONE,
PN_HTTP,
PN_HTTP_CHUNKED,
PN_HTTP_WEBSOCKET,
PN_HTTP_KEEP_ALIVE,
PN_HTTP_REQUEST,
PN_SSH,
PN_RAW,
PN_ERROR
} PROTOCOL_NAME ;
char * http_get_header_value(char * headers, char * key){
char * result = NULL;
char * start = strstr(headers, key);
if(!start){
return NULL;
}
start += strlen(key);
start += 2;
char * end = strstr(start, "\r\n");
if(!end){
return NULL;
}
result = (char *)malloc(end - start + 1);
strncpy(result, start, end - start);
result[end-start] = 0;
return result;
}
PROTOCOL_NAME protocol_sniff(int fd){
char buffer[4096] = {0};
ssize_t bytes_received = recv(fd, buffer,sizeof(buffer),MSG_PEEK);
buffer[bytes_received] = 0;
if(bytes_received <= 0){
return PN_ERROR;
}
if(strncmp(buffer,"HTTP",4) == 0){
if(strstr(buffer,"Transfer-Encoding: chunked"))
return PN_HTTP_CHUNKED;
if(strstr(buffer,"Connection: keep-alive")){
return PN_HTTP_KEEP_ALIVE;
}
return PN_HTTP;
}
if(!strncmp(buffer,"GET ",4) && strstr(buffer,"Upgrade: websocket")){
return PN_HTTP_WEBSOCKET;
}
if(!strncmp(buffer,"SSH ",4)){
return PN_SSH;
}
if(!strncmp(buffer,"GET ", 4)){
return PN_HTTP_REQUEST;
}
return PN_RAW;
}

47
py.h
View File

@ -1,3 +1,5 @@
#ifndef PGS_PY_H
#define PGS_PY_H
#define PY_SSIZE_T_CLEAN 1
#include "pgs_api.h"
#include <Python.h>
@ -64,14 +66,47 @@ void py_destruct() {
python_initialized = false;
}
int py_route(int downstream, int upstream) {
char py_on_headers(int downstream, char *headers) {
PyObject *pModule = py_construct();
long upstream_fd = 0;
char *new_headers = NULL;
if (pModule != NULL) {
PyObject *pFunc = PyObject_GetAttrString(pModule, "route");
PyObject *pFunc = PyObject_GetAttrString(pModule, "on_headers");
if (PyCallable_Check(pFunc)) {
PyObject *pArgs = PyTuple_Pack(2, PyLong_FromLong(downstream),
PyLong_FromLong(upstream));
PyUnicode_FromString(headers));
PyGILState_STATE gstate = PyGILState_Ensure();
new_headers = PyBytes_AsString(PyObject_CallObject(pFunc, pArgs));
PyGILState_Release(gstate);
Py_DECREF(pArgs);
}
}
return new_headers ? new_headers : headers;
}
char * py_http_intercept_headers(int sock, char * headers){
PyObject *pModule = py_construct();
char *new_headers = NULL;
if (pModule != NULL) {
PyObject *pFunc = PyObject_GetAttrString(pModule, "http_intercept_headers");
if (PyCallable_Check(pFunc)) {
PyObject *pArgs = PyTuple_Pack(2, PyLong_FromLong(sock),
PyUnicode_FromString(headers));
PyGILState_STATE gstate = PyGILState_Ensure();
new_headers = PyBytes_AsString(PyObject_CallObject(pFunc, pArgs));
PyGILState_Release(gstate);
Py_DECREF(pArgs);
}
}
return new_headers ? new_headers : headers;
}
int py_on_connect(int downstream) {
PyObject *pModule = py_construct();
int upstream_fd = -1;
if (pModule != NULL) {
PyObject *pFunc = PyObject_GetAttrString(pModule, "on_connect");
if (PyCallable_Check(pFunc)) {
PyObject *pArgs = PyTuple_Pack(1, PyLong_FromLong(downstream));
PyGILState_STATE gstate = PyGILState_Ensure();
@ -97,5 +132,7 @@ int py_route(int downstream, int upstream) {
fprintf(stderr, "Failed to load 'script'\n");
}
return (int)upstream_fd;
return upstream_fd;
}
#endif

123
sock.h
View File

@ -1,3 +1,6 @@
#ifndef PGS_SOCK_H
#define PGS_SOCK_H
#include "protocol.h"
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
@ -12,16 +15,19 @@
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include "py.h"
#define MAX_EVENTS 8096
#define BUFFER_SIZE 1024
typedef struct {
PROTOCOL_STATUS status;
PROTOCOL_NAME protocol_name;
int client_fd;
int upstream_fd;
char *buffer;
size_t buffer_size;
size_t buffer_offset;
char *http_intercepted_headers;
} connection_t;
int listen_fd = 0;
@ -124,7 +130,7 @@ char *sock_read(int fd, char *buf, size_t size) {
size_t bytes_to_read = size > left_in_buffer ? left_in_buffer : size;
ssize_t bytes_read = 0;
char *buffer;
buffer[size];
buffer[size] = 0;
if (bytes_to_read) {
bytes_read = recv(fd, buffer, size, 0);
buffer[bytes_read] = 0;
@ -170,6 +176,18 @@ int forward_data(int from_fd, int to_fd) {
return (int)bytes_read;
}
ssize_t sock_send_all(int fd, char * content, ssize_t length) {
ssize_t bytes_total_sent = 0;
while(bytes_total_sent < length) {
ssize_t bytes_sent = send(fd, content + bytes_total_sent, length, 0);
if(bytes_sent <= 0){
return -1;
}
bytes_total_sent += bytes_sent;
}
return bytes_total_sent;
}
bool handle_connect(struct epoll_event event, int epoll_fd) {
struct sockaddr_in client_addr;
socklen_t client_len = sizeof(client_addr);
@ -187,7 +205,8 @@ bool handle_connect(struct epoll_event event, int epoll_fd) {
client_event.data.fd = client_fd;
connections[client_fd]->upstream_fd = -1;
connections[client_fd]->client_fd = client_fd;
connections[client_fd]->status = PS_SNIFF;
connections[client_fd]->protocol_name = PS_NONE;
printf("New connection: client_fd=%d\n", client_fd);
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &client_event);
@ -200,34 +219,103 @@ void handle_close(int epoll_fd, connection_t *conn) {
close_connection(epoll_fd, conn);
}
char * sock_read_until(int fd, char * until, int until_len) {
char * result = (char *)malloc(BUFFER_SIZE);
result[0] = 0;
int bytes;
int bytes_total = 0;
char buffer[2];
while((bytes = recv(fd, buffer, 1,0)) > 0){
result[bytes_total] = buffer[0];
bytes_total += bytes;
if(bytes_total >= until_len){
if(!strncmp(buffer + bytes_total - until_len, until, until_len)){
break;
}
}
}
result[bytes_total] = 0;
return result;
}
void http_intercept_headers(int epoll_fd, connection_t *conn) {
char * headers = sock_read_until(conn->client_fd, "\r\n\r\n", 4);
char * intercepted_headers = py_http_intercept_headers(conn->client_fd,headers);
if(!intercepted_headers){
close_connection(epoll_fd, conn);
return;
}
sock_send_all(conn->client_fd, intercepted_headers, strlen(intercepted_headers));
printf("Intercepted headers: %s\n", intercepted_headers);
conn->http_intercepted_headers = intercepted_headers;
conn->status = PS_NONE;
connections[conn->upstream_fd]->status = PS_HTTP_READ_HEADER;
}
void http_intercept_body(int epoll_fd, connection_t *conn) {
printf("HIERR\n");
}
void handle_stream(struct epoll_event event, int epoll_fd, connection_t *conn) {
if (conn->upstream_fd == -1) {
conn->upstream_fd = prepare_upstream();
int upstream_fd = py_route(conn->client_fd, conn->upstream_fd);
if (upstream_fd == -1) {
conn->upstream_fd = py_on_connect(conn->client_fd);
if(conn->upstream_fd == -1){
close_connection(epoll_fd, conn);
return;
}
set_nonblocking(upstream_fd);
conn->protocol_name = protocol_sniff(conn->client_fd);
if(conn->protocol_name == PN_ERROR){
close_connection(epoll_fd, conn);
return;
}else if(conn->protocol_name == PN_HTTP){
conn->status = PS_HTTP_READ_HEADER;
}else if(conn->protocol_name == PN_HTTP_CHUNKED){
conn->status = PS_HTTP_READ_HEADER;
}else if(conn->protocol_name == PN_HTTP_KEEP_ALIVE){
conn->status = PS_HTTP_READ_HEADER;
}else if(conn->protocol_name == PN_HTTP_REQUEST){
conn->status = PS_HTTP_READ_HEADER;
}else if(conn->protocol_name == PN_HTTP_WEBSOCKET){
conn->status = PS_HTTP_READ_HEADER;
}else if(conn->protocol_name == PN_SSH){
conn->status = PS_STREAM;
}else if(conn->protocol_name == PN_RAW){
conn->status = PS_STREAM;
}else {
conn->status = PS_ERROR;
}
if (conn->upstream_fd == -1 || conn->status == PS_ERROR) {
close_connection(epoll_fd, conn);
return;
}
set_nonblocking(conn->upstream_fd);
struct epoll_event upstream_event;
upstream_event.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP;
upstream_event.data.ptr = connections[upstream_fd];
upstream_event.data.fd = upstream_fd;
upstream_event.data.ptr = connections[conn->upstream_fd];
upstream_event.data.fd = conn->upstream_fd;
connections[conn->client_fd]->upstream_fd = conn->upstream_fd;
connections[conn->upstream_fd]->client_fd = conn->client_fd;
connections[conn->upstream_fd]->upstream_fd = conn->upstream_fd;
connections[conn->client_fd]->upstream_fd = upstream_fd;
connections[upstream_fd]->client_fd = conn->client_fd;
connections[upstream_fd]->upstream_fd = upstream_fd;
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, upstream_fd, &upstream_event);
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, conn->upstream_fd, &upstream_event);
printf("Connected: client_fd=%d, upstream_fd=%d\n", conn->client_fd,
conn->upstream_fd);
return;
}
if (event.data.fd == conn->client_fd) {
if(conn->status = PS_HTTP_READ_HEADER){
http_intercept_headers(epoll_fd, conn);
}
if(conn->status == PS_HTTP_READ_BODY){
http_intercept_body(epoll_fd, conn);
}
if (forward_data(conn->client_fd, conn->upstream_fd) < 1) {
close_connection(epoll_fd, conn);
}
@ -290,4 +378,5 @@ void serve(int port) {
}
}
}
}
}
#endif