#include "wren.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <errno.h>
#include <limits.h>
// Platform abstraction: map Winsock (Windows) and BSD sockets/pthreads (POSIX)
// onto one set of names -- socket_t, thread_t, mutex_t, cond_t, closesocket(),
// INVALID_SOCKET_HANDLE and poll() -- so the rest of the file is written once.
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
// MSVC-specific: ask the linker to pull in the Winsock library.
#pragma comment(lib, "ws2_32.lib")
typedef SOCKET socket_t;
typedef HANDLE thread_t;
typedef CRITICAL_SECTION mutex_t;
typedef CONDITION_VARIABLE cond_t;
#define INVALID_SOCKET_HANDLE INVALID_SOCKET
// WSAPoll takes the same (pollfd array, count, timeout) arguments as poll().
#define poll WSAPoll
#else
#include <pthread.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <sys/select.h>
#include <fcntl.h>
#include <poll.h>
#include <signal.h>
typedef int socket_t;
typedef pthread_t thread_t;
typedef pthread_mutex_t mutex_t;
typedef pthread_cond_t cond_t;
// POSIX sockets are plain file descriptors; reuse the Winsock spelling.
#define closesocket(s) close(s)
#define INVALID_SOCKET_HANDLE -1
#endif
// Size limits used by the worker operations:
// MAX_READ_UNTIL caps the buffer grown by READ_UNTIL, MAX_BUFFER_SIZE caps the
// READ / READ_EXACTLY length and WRITE payload, and CHUNK_SIZE caps the bytes
// passed to a single recv()/send() call.
#define MAX_READ_UNTIL (16LL * 1024 * 1024) // 16MB limit for read_until
#define MAX_BUFFER_SIZE (64LL * 1024 * 1024) // 64MB absolute max
#define CHUNK_SIZE 65536 // 64KB chunks for reading
// --- Data Structures ---
// Operation a worker thread performs for one SocketContext. Each value selects
// one case in socketWorkerThread() and decides which result value
// socketManager_processCompletions() hands to the Wren callback.
typedef enum {
SOCKET_OP_CONNECT,      // connect a new non-blocking IPv4 TCP socket to host:port; result: new socket
SOCKET_OP_NEW,          // create a non-blocking IPv4 TCP socket; result: new socket
SOCKET_OP_BIND,         // bind sock to host (empty host = any address) and port; result: true
SOCKET_OP_LISTEN,       // listen(sock, backlog); result: true
SOCKET_OP_ACCEPT,       // accept one pending connection on sock; result: new socket
SOCKET_OP_READ,         // one recv() of up to min(length, CHUNK_SIZE) bytes; result: bytes (may be empty)
SOCKET_OP_READ_UNTIL,   // read until until_bytes is seen; result: bytes up to and including it
SOCKET_OP_READ_EXACTLY, // read exactly `length` bytes; result: bytes
SOCKET_OP_WRITE,        // send all of write_data; result: null
SOCKET_OP_IS_READABLE,  // zero-timeout poll for POLLIN on sock; result: bool
SOCKET_OP_SELECT,       // poll `sockets` for POLLIN (100 ms timeout); result: list of readable sockets
SOCKET_OP_CLOSE         // close sock; result: null
} SocketOp;
// Heap-allocated byte buffer: `data` holds `length` valid bytes out of
// `capacity` allocated bytes. Released (and zeroed) by free_buffer().
typedef struct {
char* data;
size_t length;
size_t capacity;
} Buffer;
// One asynchronous socket request plus, after a worker has run it, its result.
// Allocated zeroed by create_socket_context(). It travels requestQueue ->
// worker -> completionQueue and is freed by free_socket_context(), which frees
// every owned pointer below but never closes a socket.
typedef struct SocketContext {
WrenVM* vm;                 // VM that owns `callback`; needed to release the handle
SocketOp operation;
WrenHandle* callback;       // invoked as callback.call(error, result) when the completion is processed
// Operation specific data
socket_t sock;              // target socket for BIND/LISTEN/ACCEPT/READ*/WRITE/IS_READABLE/CLOSE
char* host;                 // owned copy of the IPv4 address text for CONNECT/BIND (parsed with inet_pton)
int port;                   // CONNECT/BIND port; create_socket_context() validates 0..65535
int backlog;                // LISTEN backlog; out-of-range values are replaced by 128
size_t length;              // READ: maximum bytes; READ_EXACTLY: exact byte count
Buffer write_data;          // owned copy of the bytes to send for WRITE
char* until_bytes;          // owned READ_UNTIL delimiter (1..256 bytes, not NUL-terminated)
size_t until_len;
socket_t* sockets;          // owned SELECT input list
int sockets_count;
// Result data
bool success;               // false => error_message describes the failure
char* error_message;        // owned; written through set_socket_context_error*()
socket_t new_sock;          // NEW/CONNECT/ACCEPT result (INVALID_SOCKET_HANDLE on failure)
Buffer read_data;           // READ/READ_UNTIL/READ_EXACTLY result bytes
bool result_bool;           // BIND/LISTEN/IS_READABLE result
socket_t* readable_sockets; // owned SELECT result
int readable_count;
struct SocketContext* next; // intrusive link used by SocketThreadSafeQueue
} SocketContext;
// --- Thread-Safe Queue for Socket Operations ---
// Mutex-protected FIFO of SocketContext, linked through SocketContext.next.
// Consumers block in socket_queue_pop() until an item arrives or
// socket_queue_shutdown() sets `shutdown`. `shutdown` is only read and written
// while holding `mutex`.
typedef struct {
SocketContext *head, *tail;
mutex_t mutex;
cond_t cond;                // signalled on push, broadcast on shutdown
volatile bool shutdown;
} SocketThreadSafeQueue;
/* Prepare an empty queue: fresh lock and condition variable, no items, not shut down. */
static void socket_queue_init(SocketThreadSafeQueue* q) {
    if (q == NULL) {
        return;
    }
#ifdef _WIN32
    InitializeCriticalSection(&q->mutex);
    InitializeConditionVariable(&q->cond);
#else
    pthread_mutex_init(&q->mutex, NULL);
    pthread_cond_init(&q->cond, NULL);
#endif
    q->shutdown = false;
    q->tail = NULL;
    q->head = NULL;
}
/* Release the synchronization primitives owned by the queue; queued items are NOT freed. */
static void socket_queue_destroy(SocketThreadSafeQueue* q) {
    if (q != NULL) {
#ifdef _WIN32
        /* A Windows CONDITION_VARIABLE has no destroy call; only the critical section is deleted. */
        DeleteCriticalSection(&q->mutex);
#else
        pthread_mutex_destroy(&q->mutex);
        pthread_cond_destroy(&q->cond);
#endif
    }
}
/*
 * Append `context` to the tail of `q` and wake one waiting consumer.
 * Ownership of `context` passes to the queue.
 *
 * Items are accepted even after socket_queue_shutdown(). Shutdown only stops
 * consumers from *waiting* for more work. socket_queue_pop() keeps returning
 * queued items until the queue is empty, so socketManager_destroy() can drain
 * and free them on the VM's thread. Dropping the item on a shut-down queue, as
 * this function used to do, leaked the context, including its buffers and
 * WrenHandle (e.g. a worker finishing an operation while the manager shuts down).
 *
 * A NULL context is ignored. Enqueuing it would set `tail` to NULL while `head`
 * still pointed at the list, so the next push would discard every queued item.
 */
static void socket_queue_push(SocketThreadSafeQueue* q, SocketContext* context) {
    if (!q || !context) return;
    /* The context is not shared yet, so it can be unlinked outside the lock. */
    context->next = NULL;
#ifdef _WIN32
    EnterCriticalSection(&q->mutex);
#else
    pthread_mutex_lock(&q->mutex);
#endif
    if (q->tail) {
        q->tail->next = context;
    } else {
        q->head = context;
    }
    q->tail = context;
#ifdef _WIN32
    WakeConditionVariable(&q->cond);
    LeaveCriticalSection(&q->mutex);
#else
    pthread_cond_signal(&q->cond);
    pthread_mutex_unlock(&q->mutex);
#endif
}
/*
 * Remove and return the head of `q`, blocking while the queue is empty.
 * Returns NULL only when `q` is NULL, or when the queue has been shut down and
 * holds no more items. Items still queued at shutdown are handed out first.
 */
static SocketContext* socket_queue_pop(SocketThreadSafeQueue* q) {
    SocketContext* item = NULL;
    if (q == NULL) {
        return NULL;
    }
#ifdef _WIN32
    EnterCriticalSection(&q->mutex);
    while (!q->shutdown && q->head == NULL) {
        SleepConditionVariableCS(&q->cond, &q->mutex, INFINITE);
    }
#else
    pthread_mutex_lock(&q->mutex);
    while (!q->shutdown && q->head == NULL) {
        pthread_cond_wait(&q->cond, &q->mutex);
    }
#endif
    /* The wait only ends with an empty queue when it has been shut down. */
    item = q->head;
    if (item != NULL) {
        q->head = item->next;
        if (q->head == NULL) {
            q->tail = NULL;
        }
    }
#ifdef _WIN32
    LeaveCriticalSection(&q->mutex);
#else
    pthread_mutex_unlock(&q->mutex);
#endif
    return item;
}
/* Flag `q` as shut down and wake every thread blocked in socket_queue_pop(). */
static void socket_queue_shutdown(SocketThreadSafeQueue* q) {
    if (q == NULL) {
        return;
    }
#ifdef _WIN32
    EnterCriticalSection(&q->mutex);
#else
    pthread_mutex_lock(&q->mutex);
#endif
    q->shutdown = true;
#ifdef _WIN32
    WakeAllConditionVariable(&q->cond);
    LeaveCriticalSection(&q->mutex);
#else
    pthread_cond_broadcast(&q->cond);
    pthread_mutex_unlock(&q->mutex);
#endif
}
/* Locked snapshot: true when `q` currently holds no items (or `q` is NULL). */
static bool socket_queue_empty(SocketThreadSafeQueue* q) {
    bool has_items = false;
    if (q == NULL) {
        return true;
    }
#ifdef _WIN32
    EnterCriticalSection(&q->mutex);
    has_items = (q->head != NULL);
    LeaveCriticalSection(&q->mutex);
#else
    pthread_mutex_lock(&q->mutex);
    has_items = (q->head != NULL);
    pthread_mutex_unlock(&q->mutex);
#endif
    return !has_items;
}
// --- Async Socket Manager ---
// Process-wide state of the async socket layer. Worker threads take requests
// from requestQueue, perform them, and push the finished contexts onto
// completionQueue, which socketManager_processCompletions() drains.
typedef struct {
WrenVM* vm;                 // VM whose callbacks receive the completions
volatile bool running;      // cleared by socketManager_destroy()/failed create to stop the workers.
                            // NOTE(review): read by workers without a lock; `volatile` is not a
                            // synchronization primitive -- consider atomic_bool.
thread_t threads[16];       // worker pool; socketManager_create() starts all 16
SocketThreadSafeQueue requestQueue;
SocketThreadSafeQueue completionQueue;
} AsyncSocketManager;
// Singleton set by socketManager_create(); NULL when the manager is not running.
static AsyncSocketManager* socketManager = NULL;
/* Free a Buffer's storage and reset it to the empty state. Does nothing when it holds no data. */
static void free_buffer(Buffer* buf) {
    if (buf == NULL || buf->data == NULL) {
        return;
    }
    free(buf->data);
    *buf = (Buffer){0};
}
/*
 * Release everything a SocketContext owns, then the context itself:
 * - the host, error-message and delimiter strings;
 * - the read and write buffers;
 * - the socket lists;
 * - the callback WrenHandle, when a VM is recorded.
 * Socket descriptors stored in the context are NOT closed here.
 */
static void free_socket_context(SocketContext* context) {
    if (!context) return;

    /* free(NULL) is a no-op, so the optional pointers need no guards. */
    free(context->host);
    free(context->error_message);
    free(context->until_bytes);
    free(context->sockets);
    free(context->readable_sockets);
    free_buffer(&context->write_data);
    free_buffer(&context->read_data);

    if (context->vm != NULL && context->callback != NULL) {
        wrenReleaseHandle(context->vm, context->callback);
    }

    /* Scrub the struct before returning it to the allocator. */
    memset(context, 0, sizeof *context);
    free(context);
}
/*
 * Mark `context` as failed and store a private copy of `message`, truncated to
 * at most 1024 characters. A NULL message records "Unknown socket error".
 * Any previously stored message is freed first. On allocation failure the
 * context is left with no message.
 */
static void set_socket_context_error(SocketContext* context, const char* message) {
    if (!context) return;

    context->success = false;
    free(context->error_message);
    context->error_message = NULL;

    if (!message) {
        context->error_message = strdup("Unknown socket error");
        return;
    }

    size_t n = strlen(message);
    if (n > 1024) n = 1024;  /* keep error strings bounded */

    char* copy = malloc(n + 1);
    if (copy) {
        memcpy(copy, message, n);
        copy[n] = '\0';
    }
    context->error_message = copy;
}
/*
 * Record the calling thread's most recent socket error on `context` as
 * "<prefix>: <system description>". A NULL prefix becomes "Socket error".
 *
 * POSIX: the description is strerror(errno). errno is captured first, so nothing
 * evaluated here can overwrite it.
 *
 * Windows: the description comes from WSAGetLastError() via FormatMessageA.
 * The previous code passed FORMAT_MESSAGE_ALLOCATE_BUFFER with a stack array,
 * which caused four problems:
 *   - FormatMessageA stored a heap *pointer* into the array, which was then read as text.
 *   - That allocation leaked.
 *   - The array was read uninitialised when FormatMessageA failed.
 *   - The prefix was ignored.
 * The message is now formatted directly into a local buffer.
 */
static void set_socket_context_error_errno(SocketContext* context, const char* prefix) {
    if (context == NULL) return;
    char buf[512];
#ifdef _WIN32
    int err = WSAGetLastError();
    const char* label = prefix ? prefix : "Socket error";
    char sysmsg[256];
    DWORD n = FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                             NULL, (DWORD)err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
                             sysmsg, (DWORD)sizeof(sysmsg), NULL);
    /* System messages end in "\r\n"; trim trailing whitespace. n is 0 on failure. */
    while (n > 0 && (sysmsg[n - 1] == '\r' || sysmsg[n - 1] == '\n' || sysmsg[n - 1] == ' ')) {
        n--;
    }
    sysmsg[n] = '\0';
    if (n > 0) {
        snprintf(buf, sizeof(buf), "%s: %s (WSA error %d)", label, sysmsg, err);
    } else {
        snprintf(buf, sizeof(buf), "%s: WSA error %d", label, err);
    }
#else
    int err = errno;
    const char* label = prefix ? prefix : "Socket error";
    snprintf(buf, sizeof(buf), "%s: %s", label, strerror(err));
#endif
    set_socket_context_error(context, buf);
}
/* Put `sock` into non-blocking mode; returns true on success. */
static bool set_socket_nonblocking(socket_t sock) {
#ifdef _WIN32
    u_long enable = 1;
    int rc = ioctlsocket(sock, FIONBIO, &enable);
    return rc == 0;
#else
    int current = fcntl(sock, F_GETFL, 0);
    if (current == -1) {
        return false;
    }
    int rc = fcntl(sock, F_SETFL, current | O_NONBLOCK);
    return rc != -1;
#endif
}
// Forward declaration of the worker-thread entry point, so socketManager_create()
// can hand it to CreateThread()/pthread_create() before its definition below.
// The signature differs per platform to match each thread API.
#ifdef _WIN32
static DWORD WINAPI socketWorkerThread(LPVOID arg);
#else
static void* socketWorkerThread(void* arg);
#endif
/*
 * Create the singleton socket manager for `vm` and start its worker pool.
 * Does nothing if a manager already exists or `vm` is NULL. On any failure
 * the global stays NULL and everything acquired so far is released.
 *
 * Fixes relative to the previous version:
 * - The Windows failure path now calls WSACleanup(), balancing the successful
 *   WSAStartup().
 * - Started threads are joined without a timeout before the manager is freed.
 *   The old 5 s wait on Windows could free memory a worker was still using.
 * - The global is published only after every thread has started, and the
 *   cleanup code is no longer duplicated per platform.
 */
static void socketManager_create(WrenVM* vm) {
    if (socketManager != NULL || !vm) return;
#ifndef _WIN32
    // Writing to a peer that has closed the connection raises SIGPIPE, whose
    // default action terminates the process; ignore it so send() fails instead.
    signal(SIGPIPE, SIG_IGN);
#endif
    AsyncSocketManager* mgr = (AsyncSocketManager*)calloc(1, sizeof *mgr);
    if (mgr == NULL) return;
#ifdef _WIN32
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        free(mgr);
        return;
    }
#endif
    mgr->vm = vm;
    mgr->running = true;
    socket_queue_init(&mgr->requestQueue);
    socket_queue_init(&mgr->completionQueue);

    const int thread_count = (int)(sizeof(mgr->threads) / sizeof(mgr->threads[0]));
    int started = 0;
    for (; started < thread_count; ++started) {
#ifdef _WIN32
        mgr->threads[started] = CreateThread(NULL, 0, socketWorkerThread, mgr, 0, NULL);
        if (mgr->threads[started] == NULL) break;
#else
        if (pthread_create(&mgr->threads[started], NULL, socketWorkerThread, mgr) != 0) break;
#endif
    }
    if (started == thread_count) {
        socketManager = mgr;
        return;
    }

    // A thread failed to start: stop the ones that did and undo everything.
    // The queues are empty here, so the workers exit as soon as they are woken.
    mgr->running = false;
    socket_queue_shutdown(&mgr->requestQueue);
    socket_queue_shutdown(&mgr->completionQueue);
    for (int j = 0; j < started; ++j) {
#ifdef _WIN32
        WaitForSingleObject(mgr->threads[j], INFINITE);
        CloseHandle(mgr->threads[j]);
#else
        pthread_join(mgr->threads[j], NULL);
#endif
    }
    socket_queue_destroy(&mgr->requestQueue);
    socket_queue_destroy(&mgr->completionQueue);
    free(mgr);
#ifdef _WIN32
    WSACleanup();
#endif
}
static void socketManager_destroy() {
if (!socketManager) return;
socketManager->running = false;
socket_queue_shutdown(&socketManager->requestQueue);
socket_queue_shutdown(&socketManager->completionQueue);
for (int i = 0; i < 16; ++i) {
#ifdef _WIN32
if (socketManager->threads[i] != NULL) {
WaitForSingleObject(socketManager->threads[i], 5000);
CloseHandle(socketManager->threads[i]);
}
#else
pthread_join(socketManager->threads[i], NULL);
#endif
}
// Clean up remaining contexts
SocketContext* ctx;
while ((ctx = socket_queue_pop(&socketManager->requestQueue)) != NULL) {
free_socket_context(ctx);
}
while ((ctx = socket_queue_pop(&socketManager->completionQueue)) != NULL) {
free_socket_context(ctx);
}
socket_queue_destroy(&socketManager->requestQueue);
socket_queue_destroy(&socketManager->completionQueue);
free(socketManager);
socketManager = NULL;
#ifdef _WIN32
WSACleanup();
#endif
}
void socketManager_processCompletions() {
if (!socketManager || !socketManager->vm) return;
int processed = 0;
const int max_per_cycle = 100; // Process max 100 completions per cycle
while (!socket_queue_empty(&socketManager->completionQueue) && processed < max_per_cycle) {
SocketContext* context = socket_queue_pop(&socketManager->completionQueue);
if (context == NULL) continue;
processed++;
if (context->callback == NULL) {
free_socket_context(context);
continue;
}
WrenHandle* callHandle = wrenMakeCallHandle(socketManager->vm, "call(_,_)");
if (!callHandle) {
free_socket_context(context);
continue;
}
wrenEnsureSlots(socketManager->vm, 3);
wrenSetSlotHandle(socketManager->vm, 0, context->callback);
if (context->success) {
wrenSetSlotNull(socketManager->vm, 1);
switch(context->operation) {
case SOCKET_OP_CONNECT:
case SOCKET_OP_NEW:
case SOCKET_OP_ACCEPT:
wrenSetSlotDouble(socketManager->vm, 2, (double)context->new_sock);
break;
case SOCKET_OP_BIND:
case SOCKET_OP_LISTEN:
case SOCKET_OP_IS_READABLE:
wrenSetSlotBool(socketManager->vm, 2, context->result_bool);
break;
case SOCKET_OP_READ:
case SOCKET_OP_READ_UNTIL:
case SOCKET_OP_READ_EXACTLY:
if (context->read_data.data && context->read_data.length > 0) {
wrenSetSlotBytes(socketManager->vm, 2, context->read_data.data, context->read_data.length);
} else {
wrenSetSlotBytes(socketManager->vm, 2, "", 0);
}
break;
case SOCKET_OP_SELECT: {
wrenSetSlotNewList(socketManager->vm, 2);
for (int i = 0; i < context->readable_count; ++i) {
wrenSetSlotDouble(socketManager->vm, 0, (double)context->readable_sockets[i]);
wrenInsertInList(socketManager->vm, 2, -1, 0);
}
break;
}
case SOCKET_OP_WRITE:
case SOCKET_OP_CLOSE:
default:
wrenSetSlotNull(socketManager->vm, 2);
break;
}
} else {
wrenSetSlotString(socketManager->vm, 1,
context->error_message ? context->error_message : "Unknown error");
wrenSetSlotNull(socketManager->vm, 2);
}
wrenCall(socketManager->vm, callHandle);
wrenReleaseHandle(socketManager->vm, callHandle);
free_socket_context(context);
}
}
// --- Worker Thread Implementation ---
#ifdef _WIN32
static DWORD WINAPI socketWorkerThread(LPVOID arg) {
#else
static void* socketWorkerThread(void* arg) {
#endif
AsyncSocketManager* manager = (AsyncSocketManager*)arg;
if (!manager) return 0;
while (manager->running) {
SocketContext* context = socket_queue_pop(&manager->requestQueue);
if (!context || !manager->running) {
if (context) free_socket_context(context);
continue;
}
// Initialize result fields
context->success = false;
context->new_sock = INVALID_SOCKET_HANDLE;
context->result_bool = false;
context->readable_count = 0;
switch (context->operation) {
case SOCKET_OP_NEW: {
context->new_sock = socket(AF_INET, SOCK_STREAM, 0);
if (context->new_sock == INVALID_SOCKET_HANDLE) {
set_socket_context_error_errno(context, "Failed to create socket");
} else {
int opt_val = 1;
if (setsockopt(context->new_sock, SOL_SOCKET, SO_REUSEADDR,
(const char*)&opt_val, sizeof(opt_val)) < 0) {
set_socket_context_error_errno(context, "Failed to set SO_REUSEADDR");
closesocket(context->new_sock);
context->new_sock = INVALID_SOCKET_HANDLE;
} else if (!set_socket_nonblocking(context->new_sock)) {
set_socket_context_error_errno(context, "Failed to set non-blocking");
closesocket(context->new_sock);
context->new_sock = INVALID_SOCKET_HANDLE;
} else {
// Set TCP_NODELAY for better responsiveness
int nodelay = 1;
setsockopt(context->new_sock, IPPROTO_TCP, TCP_NODELAY,
(const char*)&nodelay, sizeof(nodelay));
context->success = true;
}
}
break;
}
case SOCKET_OP_CONNECT: {
struct sockaddr_in serv_addr = {0};
context->new_sock = socket(AF_INET, SOCK_STREAM, 0);
if (context->new_sock == INVALID_SOCKET_HANDLE) {
set_socket_context_error_errno(context, "Socket creation error");
break;
}
if (!set_socket_nonblocking(context->new_sock)) {
set_socket_context_error_errno(context, "Failed to set non-blocking");
closesocket(context->new_sock);
context->new_sock = INVALID_SOCKET_HANDLE;
break;
}
serv_addr.sin_family = AF_INET;
serv_addr.sin_port = htons(context->port);
if (inet_pton(AF_INET, context->host, &serv_addr.sin_addr) <= 0) {
set_socket_context_error(context, "Invalid address");
closesocket(context->new_sock);
context->new_sock = INVALID_SOCKET_HANDLE;
break;
}
int ret = connect(context->new_sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr));
if (ret < 0) {
#ifdef _WIN32
if (WSAGetLastError() != WSAEWOULDBLOCK) {
#else
if (errno != EINPROGRESS) {
#endif
set_socket_context_error_errno(context, "Connection failed");
closesocket(context->new_sock);
context->new_sock = INVALID_SOCKET_HANDLE;
} else {
// Connection in progress, wait for completion
struct pollfd pfd = {context->new_sock, POLLOUT, 0};
int poll_ret = poll(&pfd, 1, 5000); // 5 second timeout
if (poll_ret <= 0) {
set_socket_context_error(context, "Connection timeout");
closesocket(context->new_sock);
context->new_sock = INVALID_SOCKET_HANDLE;
} else {
int error = 0;
socklen_t len = sizeof(error);
if (getsockopt(context->new_sock, SOL_SOCKET, SO_ERROR,
(char*)&error, &len) < 0 || error != 0) {
set_socket_context_error(context, "Connection failed");
closesocket(context->new_sock);
context->new_sock = INVALID_SOCKET_HANDLE;
} else {
context->success = true;
}
}
}
} else {
context->success = true;
}
break;
}
case SOCKET_OP_BIND: {
struct sockaddr_in address = {0};
address.sin_family = AF_INET;
address.sin_port = htons(context->port);
if (context->host && strlen(context->host) > 0) {
if (inet_pton(AF_INET, context->host, &address.sin_addr) <= 0) {
set_socket_context_error(context, "Invalid bind address");
break;
}
} else {
address.sin_addr.s_addr = INADDR_ANY;
}
if (bind(context->sock, (struct sockaddr *)&address, sizeof(address)) < 0) {
set_socket_context_error_errno(context, "Bind failed");
} else {
context->success = true;
context->result_bool = true;
}
break;
}
case SOCKET_OP_LISTEN: {
if (listen(context->sock, context->backlog) < 0) {
set_socket_context_error_errno(context, "Listen failed");
} else {
context->success = true;
context->result_bool = true;
}
break;
}
case SOCKET_OP_ACCEPT: {
struct sockaddr_in address = {0};
socklen_t addrlen = sizeof(address);
context->new_sock = accept(context->sock, (struct sockaddr *)&address, &addrlen);
if (context->new_sock == INVALID_SOCKET_HANDLE) {
#ifdef _WIN32
if (WSAGetLastError() == WSAEWOULDBLOCK) {
#else
if (errno == EAGAIN || errno == EWOULDBLOCK) {
#endif
set_socket_context_error(context, "Would block");
} else {
set_socket_context_error_errno(context, "Accept failed");
}
} else {
if (!set_socket_nonblocking(context->new_sock)) {
closesocket(context->new_sock);
context->new_sock = INVALID_SOCKET_HANDLE;
set_socket_context_error(context, "Failed to set non-blocking on accepted socket");
} else {
context->success = true;
}
}
break;
}
case SOCKET_OP_READ: {
if (context->length <= 0 || context->length > MAX_BUFFER_SIZE) {
set_socket_context_error(context, "Invalid read length");
break;
}
size_t actual_size = context->length < CHUNK_SIZE ? context->length : CHUNK_SIZE;
char* buffer = (char*)malloc(actual_size);
if (!buffer) {
set_socket_context_error(context, "Out of memory");
break;
}
int valread = recv(context->sock, buffer, actual_size, 0);
if (valread > 0) {
context->read_data.data = buffer;
context->read_data.length = valread;
context->read_data.capacity = actual_size;
context->success = true;
} else if (valread == 0) {
free(buffer);
context->read_data.data = NULL;
context->read_data.length = 0;
context->success = true; // EOF is success with 0 bytes
} else {
#ifdef _WIN32
if (WSAGetLastError() == WSAEWOULDBLOCK) {
#else
if (errno == EAGAIN || errno == EWOULDBLOCK) {
#endif
free(buffer);
context->read_data.data = NULL;
context->read_data.length = 0;
context->success = true; // Would block, return empty
} else {
set_socket_context_error_errno(context, "Read failed");
free(buffer);
}
}
break;
}
case SOCKET_OP_READ_UNTIL: {
if (!context->until_bytes || context->until_len <= 0 || context->until_len > 256) {
set_socket_context_error(context, "Invalid until bytes");
break;
}
Buffer buf = {0};
buf.capacity = CHUNK_SIZE;
buf.data = (char*)malloc(buf.capacity);
if (!buf.data) {
set_socket_context_error(context, "Out of memory");
break;
}
bool found = false;
bool error = false;
while (!found && !error && buf.length < MAX_READ_UNTIL) {
// Make sure we have room for more data
if (buf.length + CHUNK_SIZE > buf.capacity) {
size_t new_cap = buf.capacity * 2;
if (new_cap > MAX_READ_UNTIL) new_cap = MAX_READ_UNTIL;
char* new_data = (char*)realloc(buf.data, new_cap);
if (!new_data) {
set_socket_context_error(context, "Out of memory");
error = true;
break;
}
buf.data = new_data;
buf.capacity = new_cap;
}
size_t read_size = buf.capacity - buf.length;
if (read_size > CHUNK_SIZE) read_size = CHUNK_SIZE;
int valread = recv(context->sock, buf.data + buf.length, read_size, 0);
if (valread > 0) {
buf.length += valread;
// Search for until_bytes
if (buf.length >= context->until_len) {
for (size_t i = 0; i <= buf.length - context->until_len; ++i) {
if (memcmp(buf.data + i, context->until_bytes, context->until_len) == 0) {
context->read_data.data = buf.data;
context->read_data.length = i + context->until_len;
context->read_data.capacity = buf.capacity;
found = true;
buf.data = NULL; // Transfer ownership
break;
}
}
}
} else if (valread == 0) {
set_socket_context_error(context, "Unexpected EOF");
error = true;
} else {
#ifdef _WIN32
if (WSAGetLastError() != WSAEWOULDBLOCK) {
#else
if (errno != EAGAIN && errno != EWOULDBLOCK) {
#endif
set_socket_context_error_errno(context, "Read failed");
error = true;
} else {
// Would block, wait a bit
#ifdef _WIN32
Sleep(1);
#else
usleep(1000);
#endif
}
}
}
if (!found && !error && buf.length >= MAX_READ_UNTIL) {
set_socket_context_error(context, "Data exceeds limit");
error = true;
}
if (buf.data) free(buf.data);
if (found) context->success = true;
break;
}
case SOCKET_OP_READ_EXACTLY: {
if (context->length <= 0 || context->length > MAX_BUFFER_SIZE) {
set_socket_context_error(context, "Invalid read length");
break;
}
char* buffer = (char*)malloc(context->length);
if (!buffer) {
set_socket_context_error(context, "Out of memory");
break;
}
size_t total_read = 0;
int retries = 0;
const int max_retries = 1000;
while (total_read < context->length && retries < max_retries) {
size_t to_read = context->length - total_read;
if (to_read > CHUNK_SIZE) to_read = CHUNK_SIZE;
int valread = recv(context->sock, buffer + total_read, to_read, 0);
if (valread > 0) {
total_read += valread;
retries = 0; // Reset retries on successful read
} else if (valread == 0) {
set_socket_context_error(context, "Unexpected EOF");
free(buffer);
total_read = 0;
break;
} else {
#ifdef _WIN32
if (WSAGetLastError() == WSAEWOULDBLOCK) {
#else
if (errno == EAGAIN || errno == EWOULDBLOCK) {
#endif
retries++;
#ifdef _WIN32
Sleep(1);
#else
usleep(1000);
#endif
} else {
set_socket_context_error_errno(context, "Read failed");
free(buffer);
total_read = 0;
break;
}
}
}
if (total_read == context->length) {
context->read_data.data = buffer;
context->read_data.length = total_read;
context->read_data.capacity = context->length;
context->success = true;
} else if (retries >= max_retries) {
set_socket_context_error(context, "Read timeout");
free(buffer);
}
break;
}
case SOCKET_OP_WRITE: {
if (!context->write_data.data || context->write_data.length <= 0) {
set_socket_context_error(context, "Invalid write data");
break;
}
if (context->write_data.length > MAX_BUFFER_SIZE) {
set_socket_context_error(context, "Write data too large");
break;
}
size_t total_sent = 0;
int retries = 0;
const int max_retries = 1000;
while (total_sent < context->write_data.length && retries < max_retries) {
size_t to_send = context->write_data.length - total_sent;
if (to_send > CHUNK_SIZE) to_send = CHUNK_SIZE;
int sent = send(context->sock, context->write_data.data + total_sent, to_send,
#ifdef _WIN32
0
#else
MSG_NOSIGNAL
#endif
);
if (sent > 0) {
total_sent += sent;
retries = 0;
} else if (sent == 0) {
retries++;
#ifdef _WIN32
Sleep(1);
#else
usleep(1000);
#endif
} else {
#ifdef _WIN32
if (WSAGetLastError() == WSAEWOULDBLOCK) {
#else
if (errno == EAGAIN || errno == EWOULDBLOCK) {
#endif
retries++;
#ifdef _WIN32
Sleep(1);
#else
usleep(1000);
#endif
} else {
set_socket_context_error_errno(context, "Write failed");
break;
}
}
}
if (total_sent == context->write_data.length) {
context->success = true;
} else if (retries >= max_retries) {
set_socket_context_error(context, "Write timeout");
}
break;
}
case SOCKET_OP_IS_READABLE: {
struct pollfd pfd = {context->sock, POLLIN, 0};
int result = poll(&pfd, 1, 0);
if (result < 0) {
set_socket_context_error_errno(context, "Poll failed");
} else {
context->success = true;
context->result_bool = (pfd.revents & POLLIN) != 0;
}
break;
}
case SOCKET_OP_SELECT: {
if (!context->sockets || context->sockets_count <= 0 || context->sockets_count > 1024) {
set_socket_context_error(context, "Invalid sockets list");
break;
}
struct pollfd* pfds = (struct pollfd*)calloc(context->sockets_count, sizeof(struct pollfd));
if (!pfds) {
set_socket_context_error(context, "Out of memory");
break;
}
int valid_fds = 0;
for (int i = 0; i < context->sockets_count; ++i) {
socket_t fd = context->sockets[i];
if (fd != INVALID_SOCKET_HANDLE) {
pfds[i].fd = fd;
pfds[i].events = POLLIN;
pfds[i].revents = 0;
valid_fds++;
} else {
pfds[i].fd = -1;
}
}
if (valid_fds == 0) {
set_socket_context_error(context, "No valid sockets");
free(pfds);
break;
}
int result = poll(pfds, context->sockets_count, 100); // 100ms timeout
if (result < 0) {
set_socket_context_error_errno(context, "Poll failed");
free(pfds);
break;
}
int rcount = 0;
for (int i = 0; i < context->sockets_count; ++i) {
if (pfds[i].fd != -1 && (pfds[i].revents & POLLIN)) {
rcount++;
}
}
if (rcount > 0) {
context->readable_sockets = (socket_t*)malloc(rcount * sizeof(socket_t));
if (!context->readable_sockets) {
set_socket_context_error(context, "Out of memory");
free(pfds);
break;
}
int j = 0;
for (int i = 0; i < context->sockets_count && j < rcount; ++i) {
if (pfds[i].fd != -1 && (pfds[i].revents & POLLIN)) {
context->readable_sockets[j++] = pfds[i].fd;
}
}
context->readable_count = j;
} else {
context->readable_count = 0;
}
free(pfds);
context->success = true;
break;
}
case SOCKET_OP_CLOSE: {
if (context->sock != INVALID_SOCKET_HANDLE) {
closesocket(context->sock);
}
context->success = true;
break;
}
default:
set_socket_context_error(context, "Unsupported operation");
break;
}
socket_queue_push(&manager->completionQueue, context);
}
return 0;
}
// --- Wren FFI Functions ---
static void create_socket_context(WrenVM* vm, SocketOp op) {
if (!vm || !socketManager) {
wrenSetSlotString(vm, 0, "Socket manager not initialized");
wrenAbortFiber(vm, 0);
return;
}
SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext));
if (!context) {
wrenSetSlotString(vm, 0, "Out of memory");
wrenAbortFiber(vm, 0);
return;
}
context->vm = vm;
context->operation = op;
context->sock = INVALID_SOCKET_HANDLE;
context->new_sock = INVALID_SOCKET_HANDLE;
bool valid = true;
// Validate and extract parameters based on operation
switch(op) {
case SOCKET_OP_CONNECT:
if (wrenGetSlotType(vm, 1) != WREN_TYPE_STRING) {
valid = false;
break;
}
context->host = strdup(wrenGetSlotString(vm, 1));
if (!context->host) valid = false;
context->port = (int)wrenGetSlotDouble(vm, 2);
if (context->port < 0 || context->port > 65535) valid = false;
context->callback = wrenGetSlotHandle(vm, 3);
break;
case SOCKET_OP_NEW:
// Get callback from slot 1
if (wrenGetSlotType(vm, 1) == WREN_TYPE_NULL) {
free_socket_context(context);
wrenSetSlotString(vm, 0, "Callback cannot be null");
wrenAbortFiber(vm, 0);
return;
}
context->callback = wrenGetSlotHandle(vm, 1);
if (!context->callback) {
free_socket_context(context);
wrenSetSlotString(vm, 0, "Failed to get callback handle");
wrenAbortFiber(vm, 0);
return;
}
break;
case SOCKET_OP_BIND:
context->sock = (socket_t)wrenGetSlotDouble(vm, 1);
if (wrenGetSlotType(vm, 2) != WREN_TYPE_STRING) {
valid = false;
break;
}
context->host = strdup(wrenGetSlotString(vm, 2));
context->port = (int)wrenGetSlotDouble(vm, 3);
if (context->port < 0 || context->port > 65535) valid = false;
context->callback = wrenGetSlotHandle(vm, 4);
break;
case SOCKET_OP_LISTEN:
context->sock = (socket_t)wrenGetSlotDouble(vm, 1);
context->backlog = (int)wrenGetSlotDouble(vm, 2);
if (context->backlog < 0 || context->backlog > 1024) context->backlog = 128;
context->callback = wrenGetSlotHandle(vm, 3);
break;
case SOCKET_OP_ACCEPT:
context->sock = (socket_t)wrenGetSlotDouble(vm, 1);
context->callback = wrenGetSlotHandle(vm, 2);
break;
case SOCKET_OP_READ:
context->sock = (socket_t)wrenGetSlotDouble(vm, 1);
context->length = (size_t)wrenGetSlotDouble(vm, 2);
if (context->length <= 0 || context->length > MAX_BUFFER_SIZE) valid = false;
context->callback = wrenGetSlotHandle(vm, 3);
break;
case SOCKET_OP_READ_UNTIL: {
context->sock = (socket_t)wrenGetSlotDouble(vm, 1);
int ulen;
const char* udata = wrenGetSlotBytes(vm, 2, &ulen);
if (!udata || ulen <= 0 || ulen > 256) {
valid = false;
break;
}
context->until_bytes = (char*)malloc(ulen);
if (context->until_bytes) {
memcpy(context->until_bytes, udata, ulen);
context->until_len = ulen;
} else {
valid = false;
}
context->callback = wrenGetSlotHandle(vm, 3);
break;
}
case SOCKET_OP_READ_EXACTLY:
context->sock = (socket_t)wrenGetSlotDouble(vm, 1);
context->length = (size_t)wrenGetSlotDouble(vm, 2);
if (context->length <= 0 || context->length > MAX_BUFFER_SIZE) valid = false;
context->callback = wrenGetSlotHandle(vm, 3);
break;
case SOCKET_OP_WRITE: {
context->sock = (socket_t)wrenGetSlotDouble(vm, 1);
int len;
const char* data = wrenGetSlotBytes(vm, 2, &len);
if (!data || len <= 0 || len > MAX_BUFFER_SIZE) {
valid = false;
break;
}
context->write_data.data = (char*)malloc(len);
if (context->write_data.data) {
memcpy(context->write_data.data, data, len);
context->write_data.length = len;
context->write_data.capacity = len;
} else {
valid = false;
}
context->callback = wrenGetSlotHandle(vm, 3);
break;
}
case SOCKET_OP_IS_READABLE:
context->sock = (socket_t)wrenGetSlotDouble(vm, 1);
context->callback = wrenGetSlotHandle(vm, 2);
break;
case SOCKET_OP_SELECT: {
if (wrenGetSlotType(vm, 1) != WREN_TYPE_LIST) {
valid = false;
break;
}
int count = wrenGetListCount(vm, 1);
if (count <= 0 || count > 1024) {
valid = false;
break;
}
context->sockets = (socket_t*)malloc(count * sizeof(socket_t));
if (!context->sockets) {
valid = false;
} else {
context->sockets_count = count;
for (int i = 0; i < count; ++i) {
wrenGetListElement(vm, 1, i, 0);
context->sockets[i] = (socket_t)wrenGetSlotDouble(vm, 0);
}
}
context->callback = wrenGetSlotHandle(vm, 2);
break;
}
case SOCKET_OP_CLOSE:
context->sock = (socket_t)wrenGetSlotDouble(vm, 1);
context->callback = wrenGetSlotHandle(vm, 2);
break;
default:
valid = false;
break;
}
if (!valid || !context->callback) {
free_socket_context(context);
wrenSetSlotString(vm, 0, "Invalid context creation");
wrenAbortFiber(vm, 0);
return;
}
socket_queue_push(&socketManager->requestQueue, context);
}
/*
 * Foreign static method Socket.setNonBlocking(fd, flag): switch a socket's
 * non-blocking mode on (flag == true) or off. Returns true in slot 0 on success.
 *
 * For a static foreign method slot 0 holds the receiver (the Socket class), so
 * the arguments are in slots 1 (fd) and 2 (flag). The previous code read them
 * from slots 0 and 1, interpreting the class object as the descriptor.
 * Non-numeric or out-of-range descriptors and non-Bool flags now return false
 * instead of performing an undefined double-to-integer conversion.
 */
static void socketSetNonBlocking(WrenVM* vm) {
#ifdef _WIN32
    const double fd_limit = 9007199254740992.0; /* 2^53 */
#else
    const double fd_limit = (double)INT_MAX + 1.0;
#endif
    if (wrenGetSlotType(vm, 1) != WREN_TYPE_NUM || wrenGetSlotType(vm, 2) != WREN_TYPE_BOOL) {
        wrenSetSlotBool(vm, 0, false);
        return;
    }
    double raw_fd = wrenGetSlotDouble(vm, 1);
    if (!(raw_fd >= 0.0 && raw_fd < fd_limit)) {   /* also rejects NaN */
        wrenSetSlotBool(vm, 0, false);
        return;
    }
    socket_t fd = (socket_t)raw_fd;
    bool flag = wrenGetSlotBool(vm, 2);
#ifdef _WIN32
    u_long mode = flag ? 1 : 0;
    wrenSetSlotBool(vm, 0, ioctlsocket(fd, FIONBIO, &mode) == 0);
#else
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags == -1) {
        wrenSetSlotBool(vm, 0, false);
        return;
    }
    flags = flag ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK);
    wrenSetSlotBool(vm, 0, fcntl(fd, F_SETFL, flags) != -1);
#endif
}
/* Foreign static method Socket.connect_(host, port, callback). */
static void socketConnect(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_CONNECT); }

/*
 * Foreign static method Socket.new_(callback): queue creation of a
 * non-blocking TCP socket. create_socket_context() validates the callback
 * (rejects null; checks the handle). The previous pre-check called
 * wrenGetSlotHandle() only to compare the result with NULL, and never released
 * the handle. Every call therefore leaked a WrenHandle and pinned the callback
 * in memory for the VM's lifetime.
 */
static void socketNew(WrenVM* vm) {
    if (!vm) return;
    if (!socketManager) {
        wrenSetSlotString(vm, 0, "Socket manager not initialized");
        wrenAbortFiber(vm, 0);
        return;
    }
    create_socket_context(vm, SOCKET_OP_NEW);
}

/* Remaining foreign static methods: each queues the matching operation.
 * See create_socket_context() for how each one's slots are read. */
static void socketBind(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_BIND); }                /* bind_(sock, host, port, cb) */
static void socketListen(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_LISTEN); }            /* listen_(sock, backlog, cb) */
static void socketAccept(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_ACCEPT); }            /* accept_(sock, cb) */
static void socketRead(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_READ); }                /* read_(sock, length, cb) */
static void socketReadUntil(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_READ_UNTIL); }     /* readUntil_(sock, delim, cb) */
static void socketReadExactly(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_READ_EXACTLY); } /* readExactly_(sock, length, cb) */
static void socketWrite(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_WRITE); }              /* write_(sock, data, cb) */
static void socketIsReadable(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_IS_READABLE); }   /* isReadable_(sock, cb) */
static void socketSelect(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_SELECT); }            /* select_(sockets, cb) */
static void socketClose(WrenVM* vm) { create_socket_context(vm, SOCKET_OP_CLOSE); }              /* close_(sock, cb) */
/*
 * Wren foreign-method binder for module "socket". Returns the C implementation
 * of a static method of class Socket, or NULL when the module, class,
 * staticness or signature does not match one of the entries below.
 */
WrenForeignMethodFn bindSocketForeignMethod(WrenVM* vm, const char* module, const char* className, bool isStatic, const char* signature) {
    static const struct {
        const char* signature;
        WrenForeignMethodFn fn;
    } socket_methods[] = {
        {"connect_(_,_,_)",     socketConnect},
        {"new_(_)",             socketNew},
        {"bind_(_,_,_,_)",      socketBind},
        {"listen_(_,_,_)",      socketListen},
        {"accept_(_,_)",        socketAccept},
        {"read_(_,_,_)",        socketRead},
        {"readUntil_(_,_,_)",   socketReadUntil},
        {"readExactly_(_,_,_)", socketReadExactly},
        {"write_(_,_,_)",       socketWrite},
        {"isReadable_(_,_)",    socketIsReadable},
        {"select_(_,_)",        socketSelect},
        {"setNonBlocking(_,_)", socketSetNonBlocking},
        {"close_(_,_)",         socketClose},
    };
    (void)vm;

    if (strcmp(module, "socket") != 0) return NULL;
    if (strcmp(className, "Socket") != 0 || !isStatic) return NULL;

    for (size_t i = 0; i < sizeof socket_methods / sizeof socket_methods[0]; ++i) {
        if (strcmp(signature, socket_methods[i].signature) == 0) {
            return socket_methods[i].fn;
        }
    }
    return NULL;
}
/* The socket module declares no foreign classes: always return empty allocate/finalize hooks. */
WrenForeignClassMethods bindSocketForeignClass(WrenVM* vm, const char* module, const char* className) {
    (void)vm;
    (void)module;
    (void)className;
    WrenForeignClassMethods none;
    none.allocate = NULL;
    none.finalize = NULL;
    return none;
}