#ifndef NSOCK_H #define NSOCK_H #include "rmalloc.h" #include #include #include #include #include #include #include #include #include #include #include #include "rio.h" int *nsock_socks = NULL; int *nsock_readable = NULL; void **nsock_data = NULL; int nsock_server_fd = 0; int nsock_max_socket_fd = 0; typedef enum nsock_type_t { NSOCK_NONE = 0, NSOCK_SERVER, NSOCK_CLIENT, NSOCK_UPSTREAM } nsock_type_t; typedef struct nsock_it { int fd; int *upstreams; bool connected; bool downstream; unsigned int upstream_count; nsock_type_t type; } nsock_t; nsock_t **nsocks = NULL; int nsocks_count = 0; void (*nsock_on_connect)(int fd) = NULL; void (*nsock_on_data)(int fd) = NULL; void (*nsock_on_close)(int fd) = NULL; void nsock_on_before_data(int fd); nsock_t *nsock_get(int fd) { if (nsock_socks[fd] == 0) { return NULL; } if (fd >= nsocks_count || nsocks[fd] == NULL) { if (fd >= nsocks_count) { nsocks_count = fd + 1; nsocks = (nsock_t **)realloc(nsocks, sizeof(nsock_t *) * sizeof(nsock_t) * (nsocks_count)); nsocks[fd] = (nsock_t *)calloc(1, sizeof(nsock_t)); } nsocks[fd]->upstreams = NULL; nsocks[fd]->fd = fd; nsocks[fd]->connected = false; nsocks[fd]->downstream = false; nsocks[fd]->upstream_count = 0; nsocks[fd]->type = NSOCK_CLIENT; return nsocks[fd]; } return nsocks[fd]; } void nsock_close(int fd) { if (nsock_on_close) nsock_on_close(fd); nsock_t *sock = nsock_get(fd); if (sock && sock->connected) { sock->connected = false; for (unsigned int i = 0; i < sock->upstream_count; i++) { nsock_t *upstream = nsock_get(sock->upstreams[i]); if (upstream->connected) nsock_close(sock->upstreams[i]); sock->upstreams[i] = 0; } if (sock->upstream_count) { free(sock->upstreams); } sock->upstream_count = 0; } nsock_socks[fd] = 0; close(fd); } nsock_t *nsock_create(int fd, nsock_type_t type) { if (fd <= 0) return NULL; nsock_socks[fd] = fd; nsock_t *sock = nsock_get(fd); sock->connected = true; sock->downstream = false; sock->type = type; return sock; } int *nsock_init(int socket_count) { if (nsock_socks) { return nsock_socks; } nsock_socks = (int *)calloc(1, sizeof(int) * sizeof(int *) * socket_count + 1); if (nsock_data) { free(nsock_data); nsock_data = NULL; } nsock_data = (void **)malloc(sizeof(void **) * socket_count + 1); nsock_socks[socket_count] = -1; return nsock_socks; } void nsock_free() { if (nsock_socks) free(nsock_socks); if (nsock_readable) free(nsock_readable); nsock_server_fd = 0; nsock_max_socket_fd = 0; if (nsock_data) { exit(1); } } void nsock_add_upstream(int source, int target, bool downstream) { if (!nsock_socks[target]) return; if (!nsock_socks[source]) return; nsock_t *sock = nsock_get(source); nsock_t *sock_target = nsock_get(target); sock_target->type = NSOCK_UPSTREAM; sock->upstreams = (int *)realloc(sock->upstreams, sizeof(int) * (sock->upstream_count + 1)); sock->downstream = downstream; sock->upstreams[sock->upstream_count] = target; sock->upstream_count++; } void *nsock_get_data(int socket) { return nsock_data[socket]; } void nsock_set_data(int socket, void *data) { nsock_data[socket] = data; } int nsock_connect(const char *host, unsigned int port) { char port_str[10] = {0}; sprintf(port_str, "%d", port); int status; int socket_fd = 0; struct addrinfo hints; struct addrinfo *res; struct addrinfo *p; if ((socket_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) { return false; } memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_INET; hints.ai_socktype = SOCK_STREAM; if ((status = getaddrinfo(host, port_str, &hints, &res)) != 0) { return 0; } for (p = res; p != NULL; p = p->ai_next) { if ((socket_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol)) == -1) { continue; } if (connect(socket_fd, p->ai_addr, p->ai_addrlen) == -1) { close(socket_fd); continue; } break; } if (p == NULL) { freeaddrinfo(res); return 0; } freeaddrinfo(res); if (socket_fd) { if (nsock_socks == NULL) { nsock_init(2048); } nsock_socks[socket_fd] = socket_fd; nsock_t *sock = nsock_create(socket_fd, NSOCK_CLIENT); sock->connected = true; } return socket_fd; } void nsock_listen(int port) { int server_fd; struct sockaddr_in address; if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) { perror("Socket failed"); exit(EXIT_FAILURE); } int opt = 1; if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt))) { perror("setsockopt failed"); close(server_fd); exit(EXIT_FAILURE); } address.sin_family = AF_INET; address.sin_addr.s_addr = INADDR_ANY; address.sin_port = htons(port); if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) { perror("Bind failed"); close(server_fd); exit(EXIT_FAILURE); } if (listen(server_fd, 8096) < 0) { perror("Listen failed"); close(server_fd); exit(EXIT_FAILURE); } nsock_server_fd = server_fd; } int *nsock_select(suseconds_t timeout) { struct timeval tv; tv.tv_sec = 0; tv.tv_usec = timeout; int server_fd = nsock_server_fd; fd_set rfds; FD_ZERO(&rfds); FD_SET(server_fd, &rfds); int *socks = nsock_socks; fd_set efds; FD_ZERO(&efds); nsock_max_socket_fd = server_fd; for (int i = 0; socks[i] != -1; i++) { if (i == server_fd) continue; ; if (!socks[i]) continue; if (socks[i] > nsock_max_socket_fd) { nsock_max_socket_fd = socks[i]; } FD_SET(socks[i], &rfds); FD_SET(socks[i], &efds); } int activity = select(nsock_max_socket_fd + 1, &rfds, NULL, &efds, timeout == 0 ? NULL : &tv); if ((activity < 0) && (errno != EINTR)) { perror("Select error\n"); return NULL; } else if (activity == 0) { return NULL; } if (FD_ISSET(server_fd, &rfds)) { struct sockaddr_in address; int addrlen = sizeof(address); address.sin_family = AF_INET; address.sin_addr.s_addr = INADDR_ANY; int new_socket = 0; if ((new_socket = accept(server_fd, (struct sockaddr *)&address, (socklen_t *)&addrlen)) < 0) { perror("Accept failed"); } else { nsock_socks[new_socket] = new_socket; nsock_create(new_socket, NSOCK_CLIENT); if (nsock_on_connect) nsock_on_connect(new_socket); if (new_socket > nsock_max_socket_fd) nsock_max_socket_fd = new_socket; } } if (nsock_readable) { free(nsock_readable); } nsock_readable = (int *)calloc(1, sizeof(int *) + sizeof(int) * (nsock_max_socket_fd + 2)); nsock_readable[nsock_max_socket_fd + 1] = -1; nsock_readable[0] = 0; int readable_count = 0; for (int i = 0; i < nsock_max_socket_fd + 1; i++) { nsock_t *sock = nsock_get(i); if (!sock) continue; if (FD_ISSET(i, &efds)) { nsock_close(nsock_socks[i]); nsock_socks[i] = 0; nsock_readable[i] = 0; } else if (FD_ISSET(i, &rfds) && i != server_fd) { nsock_readable[i] = i; readable_count++; nsock_on_before_data(i); } else { nsock_readable[i] = 0; sock->connected = false; } } return nsock_readable; } unsigned char *nsock_read(int fd, int length) { if (!nsock_socks[fd]) return NULL; unsigned char *buffer = (unsigned char *)malloc(length + 1); int bytes_read = read(fd, buffer, length); if (bytes_read <= 0) { nsock_close(fd); return NULL; } buffer[bytes_read] = 0; return buffer; } unsigned char *nsock_read_all(int fd, int length) { if (!nsock_socks[fd]) return NULL; unsigned char *buffer = (unsigned char *)malloc(length + 1); int bytes_read = 0; while (bytes_read < length) { int bytes_chunk = read(fd, buffer + bytes_read, length - bytes_read); if (bytes_chunk <= 0) { nsock_close(fd); return NULL; } bytes_read += bytes_chunk; } buffer[bytes_read] = 0; return buffer; } int nsock_write_all(int fd, unsigned char *data, int length) { if (!nsock_socks[fd]) return 0; int bytes_written = 0; while (bytes_written < length) { int bytes_chunk = write(fd, data + bytes_written, length - bytes_written); if (bytes_chunk <= 0) { nsock_close(fd); return 0; } bytes_written += bytes_chunk; } return bytes_written; } int nsock_execute_upstream(int source, size_t buffer_size) { int result = 0; nsock_t *sock = nsock_get(source); unsigned char data[buffer_size]; memset(data, 0, buffer_size); int bytes_read = read(source, data, buffer_size); if (bytes_read <= 0) { nsock_close(source); return 0; } bool downstreamed = false; for (unsigned int i = 0; i < sock->upstream_count; i++) { if (!nsock_socks[sock->upstreams[i]]) continue; int bytes_sent = nsock_write_all(sock->upstreams[i], data, bytes_read); if (bytes_sent <= 0) { nsock_close(sock->upstreams[i]); continue; } if (sock->downstream && downstreamed == false) { downstreamed = true; unsigned char data[4096]; memset(data, 0, 4096); int bytes_read = read(sock->upstreams[i], data, 4096); if (bytes_read <= 0) { nsock_close(source); return 0; } int bytes_sent = nsock_write_all(sock->fd, data, bytes_read); if (bytes_sent <= 0) { nsock_close(sock->upstreams[i]); return 0; } } result++; } return result; } void nsock_on_before_data(int fd) { if (!nsock_socks[fd]) return; nsock_t *sock = nsock_get(fd); if (sock->upstream_count) { int upstreamed_to_count = nsock_execute_upstream(fd, 4096); if (!upstreamed_to_count) { nsock_close(fd); } return; } else if (sock->type == NSOCK_UPSTREAM) { while (rfd_wait(sock->fd, 0)) { unsigned char *data = nsock_read(fd, 4096); (void)data; } } if (nsock_on_data) nsock_on_data(fd); } void nsock(int port, void (*on_connect)(int fd), void (*on_data)(int fd), void (*on_close)(int fd)) { nsock_init(2048); nsock_listen(port); nsock_on_connect = on_connect; nsock_on_data = on_data; nsock_on_close = on_close; int serve_in_terminal = nsock_on_connect == NULL && nsock_on_data == NULL && nsock_on_close == NULL; while (1) { int *readable = nsock_select(0); if (!serve_in_terminal) continue; if (!readable) continue; for (int i = 0; readable[i] != -1; i++) { if (!readable[i]) continue; char buffer[1024] = {0}; int bytes_read = read(readable[i], buffer, 1); buffer[bytes_read] = 0; if (bytes_read <= 0) { nsock_close(readable[i]); continue; } if (write(readable[i], buffer, bytes_read) <= 0) { nsock_close(readable[i]); continue; } } } } #endif