|
// Start of socket_backend.c
|
|
// socket_backend.c (Corrected with better handle safety and non-blocking I/O)
|
|
#include "wren.h"
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <string.h>
|
|
#include <stdbool.h>
|
|
#include <time.h>
|
|
|
|
// Platform-specific includes and definitions
|
|
#ifdef _WIN32
|
|
#include <winsock2.h>
|
|
#include <ws2tcpip.h>
|
|
#include <windows.h>
|
|
#pragma comment(lib, "ws2_32.lib")
|
|
typedef SOCKET socket_t;
|
|
typedef int socklen_t;
|
|
typedef HANDLE thread_t;
|
|
typedef CRITICAL_SECTION mutex_t;
|
|
typedef CONDITION_VARIABLE cond_t;
|
|
#define IS_SOCKET_VALID(s) ((s) != INVALID_SOCKET)
|
|
#define CLOSE_SOCKET(s) closesocket(s)
|
|
#else
|
|
#include <pthread.h>
|
|
#include <sys/socket.h>
|
|
#include <netinet/in.h>
|
|
#include <arpa/inet.h>
|
|
#include <unistd.h>
|
|
#include <fcntl.h>
|
|
#include <netdb.h>
|
|
#include <errno.h>
|
|
#include <sys/select.h>
|
|
typedef int socket_t;
|
|
typedef pthread_t thread_t;
|
|
typedef pthread_mutex_t mutex_t;
|
|
typedef pthread_cond_t cond_t;
|
|
#define INVALID_SOCKET -1
|
|
#define IS_SOCKET_VALID(s) ((s) >= 0)
|
|
#define CLOSE_SOCKET(s) close(s)
|
|
#endif
|
|
|
|
// --- Forward Declarations ---
|
|
typedef struct SocketContext SocketContext;
|
|
|
|
// --- Socket Data Structures ---
|
|
|
|
// Kind of asynchronous request a SocketContext asks a worker thread to run.
typedef enum {
    SOCKET_OP_CONNECT = 0,  // resolve context->host and open a connection
    SOCKET_OP_READ = 1,     // receive bytes from the socket
    SOCKET_OP_WRITE = 2     // send context->data on the socket
} SocketOp;
|
|
|
|
typedef struct {
|
|
socket_t sock;
|
|
bool isListener;
|
|
} SocketData;
|
|
|
|
struct SocketContext {
|
|
SocketOp operation;
|
|
WrenVM* vm;
|
|
WrenHandle* socketHandle;
|
|
WrenHandle* callback;
|
|
|
|
char* host;
|
|
int port;
|
|
char* data;
|
|
size_t dataLength;
|
|
|
|
bool success;
|
|
char* resultData;
|
|
size_t resultDataLength;
|
|
char* errorMessage;
|
|
socket_t newSocket;
|
|
struct SocketContext* next;
|
|
};
|
|
|
|
// --- Thread-Safe Queue Implementation in C ---
|
|
typedef struct {
|
|
SocketContext *head, *tail;
|
|
mutex_t mutex;
|
|
cond_t cond;
|
|
} ThreadSafeQueueSocket;
|
|
|
|
// Prepares `q` for use: empties the list and creates its mutex and
// condition variable with default attributes.
void queue_init(ThreadSafeQueueSocket* q) {
    q->tail = NULL;
    q->head = NULL;

#ifndef _WIN32
    pthread_mutex_init(&q->mutex, NULL);
    pthread_cond_init(&q->cond, NULL);
#else
    InitializeCriticalSection(&q->mutex);
    InitializeConditionVariable(&q->cond);
#endif
}
|
|
|
|
// Releases the synchronisation primitives created by queue_init().
// Contexts still linked in the queue are not freed here.
void queue_destroy(ThreadSafeQueueSocket* q) {
#ifndef _WIN32
    pthread_mutex_destroy(&q->mutex);
    pthread_cond_destroy(&q->cond);
#else
    // Only the critical section is torn down on Windows.
    DeleteCriticalSection(&q->mutex);
#endif
}
|
|
|
|
// Appends `context` to the tail of `q` and signals one thread waiting in
// queue_pop().
//
// A NULL `context` is never linked into the list. Previously a NULL push
// onto a non-empty queue reset `tail` to NULL while `head` still pointed at
// queued items, so the next push overwrote `head` and silently dropped (and
// leaked) everything that was queued. The condition variable is still
// signalled for a NULL push, exactly as before, so the call remains a
// harmless "nudge".
void queue_push(ThreadSafeQueueSocket* q, SocketContext* context) {
#ifdef _WIN32
    EnterCriticalSection(&q->mutex);
#else
    pthread_mutex_lock(&q->mutex);
#endif

    if (context != NULL) {
        context->next = NULL;
        if (q->tail != NULL) {
            q->tail->next = context;
        } else {
            q->head = context;
        }
        q->tail = context;
    }

#ifdef _WIN32
    WakeConditionVariable(&q->cond);
    LeaveCriticalSection(&q->mutex);
#else
    pthread_cond_signal(&q->cond);
    pthread_mutex_unlock(&q->mutex);
#endif
}
|
|
|
|
// Removes and returns the context at the head of `q`.
// Blocks on the queue's condition variable for as long as the queue is
// empty, so the returned pointer is never NULL.
SocketContext* queue_pop(ThreadSafeQueueSocket* q) {
    SocketContext* front;

#ifndef _WIN32
    pthread_mutex_lock(&q->mutex);
    while (!q->head) {
        pthread_cond_wait(&q->cond, &q->mutex);
    }
#else
    EnterCriticalSection(&q->mutex);
    while (!q->head) {
        SleepConditionVariableCS(&q->cond, &q->mutex, INFINITE);
    }
#endif

    // Unlink the head; an emptied queue must also forget its tail.
    front = q->head;
    q->head = front->next;
    if (!q->head) {
        q->tail = NULL;
    }

#ifndef _WIN32
    pthread_mutex_unlock(&q->mutex);
#else
    LeaveCriticalSection(&q->mutex);
#endif

    return front;
}
|
|
|
|
// Reports whether `q` currently holds no contexts. The check is made under
// the queue mutex; the answer may be stale as soon as the lock is released.
bool queue_empty(ThreadSafeQueueSocket* q) {
    bool hasItems;

#ifndef _WIN32
    pthread_mutex_lock(&q->mutex);
    hasItems = (q->head != NULL);
    pthread_mutex_unlock(&q->mutex);
#else
    EnterCriticalSection(&q->mutex);
    hasItems = (q->head != NULL);
    LeaveCriticalSection(&q->mutex);
#endif

    return !hasItems;
}
|
|
|
|
// --- Asynchronous Socket Manager ---
|
|
|
|
#define MAX_LISTENERS 64
|
|
|
|
typedef struct {
|
|
WrenVM* vm;
|
|
volatile bool running;
|
|
thread_t worker_threads[4];
|
|
thread_t listener_thread;
|
|
|
|
ThreadSafeQueueSocket requestQueue;
|
|
ThreadSafeQueueSocket completionQueue;
|
|
ThreadSafeQueueSocket acceptQueue;
|
|
|
|
mutex_t listener_mutex;
|
|
socket_t listener_sockets[MAX_LISTENERS];
|
|
int listener_count;
|
|
#ifndef _WIN32
|
|
socket_t wake_pipe[2];
|
|
#endif
|
|
} AsyncSocketManager;
|
|
|
|
static AsyncSocketManager* socketManager = NULL;
|
|
|
|
// Frees `context` together with every heap buffer it owns (host, data,
// resultData, errorMessage). Passing NULL is a no-op. The Wren handles held
// by the context are not released here.
void free_socket_context_data(SocketContext* context) {
    if (context != NULL) {
        free(context->errorMessage);
        free(context->resultData);
        free(context->data);
        free(context->host);
        free(context);
    }
}
|
|
|
|
// Thread entry points. Both receive the AsyncSocketManager as `arg`; the
// signature follows the native thread API of each platform.
#ifndef _WIN32
void* listenerThread(void* arg);
void* workerThread(void* arg);
#else
DWORD WINAPI listenerThread(LPVOID arg);
DWORD WINAPI workerThread(LPVOID arg);
#endif
|
|
|
|
// --- Worker and Listener Thread Implementations ---
|
|
|
|
#ifdef _WIN32
|
|
DWORD WINAPI listenerThread(LPVOID arg) {
|
|
#else
|
|
void* listenerThread(void* arg) {
|
|
#endif
|
|
AsyncSocketManager* manager = (AsyncSocketManager*)arg;
|
|
|
|
while (manager->running) {
|
|
fd_set read_fds;
|
|
FD_ZERO(&read_fds);
|
|
|
|
socket_t max_fd = 0;
|
|
|
|
#ifndef _WIN32
|
|
FD_SET(manager->wake_pipe[0], &read_fds);
|
|
max_fd = manager->wake_pipe[0];
|
|
#endif
|
|
|
|
#ifdef _WIN32
|
|
EnterCriticalSection(&manager->listener_mutex);
|
|
#else
|
|
pthread_mutex_lock(&manager->listener_mutex);
|
|
#endif
|
|
|
|
for (int i = 0; i < manager->listener_count; i++) {
|
|
socket_t sock = manager->listener_sockets[i];
|
|
if (IS_SOCKET_VALID(sock)) {
|
|
FD_SET(sock, &read_fds);
|
|
if (sock > max_fd) {
|
|
max_fd = sock;
|
|
}
|
|
}
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
LeaveCriticalSection(&manager->listener_mutex);
|
|
#else
|
|
pthread_mutex_unlock(&manager->listener_mutex);
|
|
#endif
|
|
|
|
struct timeval timeout;
|
|
timeout.tv_sec = 1;
|
|
timeout.tv_usec = 0;
|
|
|
|
int activity = select(max_fd + 1, &read_fds, NULL, NULL, &timeout);
|
|
|
|
if (!manager->running) break;
|
|
if (activity < 0) {
|
|
#ifndef _WIN32
|
|
if (errno != EINTR) {
|
|
perror("select error");
|
|
}
|
|
#endif
|
|
continue;
|
|
}
|
|
if (activity == 0) continue;
|
|
|
|
#ifndef _WIN32
|
|
if (FD_ISSET(manager->wake_pipe[0], &read_fds)) {
|
|
char buffer[1];
|
|
read(manager->wake_pipe[0], buffer, 1);
|
|
}
|
|
#endif
|
|
|
|
#ifdef _WIN32
|
|
EnterCriticalSection(&manager->listener_mutex);
|
|
#else
|
|
pthread_mutex_lock(&manager->listener_mutex);
|
|
#endif
|
|
|
|
for (int i = 0; i < manager->listener_count; i++) {
|
|
socket_t sock = manager->listener_sockets[i];
|
|
if (IS_SOCKET_VALID(sock) && FD_ISSET(sock, &read_fds)) {
|
|
if (!queue_empty(&manager->acceptQueue)) {
|
|
SocketContext* context = queue_pop(&manager->acceptQueue);
|
|
context->newSocket = accept(sock, NULL, NULL);
|
|
context->success = IS_SOCKET_VALID(context->newSocket);
|
|
if (!context->success) {
|
|
context->errorMessage = strdup("Accept failed.");
|
|
}
|
|
queue_push(&manager->completionQueue, context);
|
|
}
|
|
}
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
LeaveCriticalSection(&manager->listener_mutex);
|
|
#else
|
|
pthread_mutex_unlock(&manager->listener_mutex);
|
|
#endif
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
DWORD WINAPI workerThread(LPVOID arg) {
|
|
#else
|
|
void* workerThread(void* arg) {
|
|
#endif
|
|
AsyncSocketManager* manager = (AsyncSocketManager*)arg;
|
|
|
|
while (manager->running) {
|
|
SocketContext* context = queue_pop(&manager->requestQueue);
|
|
if (!context || !manager->running) {
|
|
if (context) free_socket_context_data(context);
|
|
break;
|
|
}
|
|
|
|
wrenEnsureSlots(context->vm, 1);
|
|
wrenSetSlotHandle(context->vm, 0, context->socketHandle);
|
|
SocketData* socketData = (wrenGetSlotType(context->vm, 0) == WREN_TYPE_FOREIGN)
|
|
? (SocketData*)wrenGetSlotForeign(context->vm, 0)
|
|
: NULL;
|
|
|
|
if (!socketData) {
|
|
context->success = false;
|
|
context->errorMessage = strdup("Invalid socket object.");
|
|
queue_push(&manager->completionQueue, context);
|
|
continue;
|
|
}
|
|
|
|
switch (context->operation) {
|
|
case SOCKET_OP_CONNECT: {
|
|
struct addrinfo hints = {0}, *addrs;
|
|
hints.ai_family = AF_UNSPEC;
|
|
hints.ai_socktype = SOCK_STREAM;
|
|
char port_str[6];
|
|
snprintf(port_str, 6, "%d", context->port);
|
|
if (getaddrinfo(context->host, port_str, &hints, &addrs) != 0) {
|
|
context->success = false;
|
|
context->errorMessage = strdup("Host lookup failed.");
|
|
break;
|
|
}
|
|
|
|
socket_t sock = INVALID_SOCKET;
|
|
for (struct addrinfo* addr = addrs; addr; addr = addr->ai_next) {
|
|
sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
|
|
if (!IS_SOCKET_VALID(sock)) continue;
|
|
if (connect(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break;
|
|
CLOSE_SOCKET(sock);
|
|
sock = INVALID_SOCKET;
|
|
}
|
|
freeaddrinfo(addrs);
|
|
|
|
if (IS_SOCKET_VALID(sock)) {
|
|
socketData->sock = sock;
|
|
socketData->isListener = false;
|
|
context->success = true;
|
|
} else {
|
|
context->success = false;
|
|
context->errorMessage = strdup("Connection failed.");
|
|
}
|
|
break;
|
|
}
|
|
case SOCKET_OP_READ: {
|
|
if (socketData->isListener) {
|
|
context->success = false;
|
|
context->errorMessage = strdup("Cannot read from a listening socket.");
|
|
break;
|
|
}
|
|
|
|
fd_set read_fds;
|
|
FD_ZERO(&read_fds);
|
|
FD_SET(socketData->sock, &read_fds);
|
|
struct timeval timeout = { .tv_sec = 5, .tv_usec = 0 }; // 5-second timeout
|
|
|
|
int activity = select(socketData->sock + 1, &read_fds, NULL, NULL, &timeout);
|
|
if (activity > 0 && FD_ISSET(socketData->sock, &read_fds)) {
|
|
char buf[4096];
|
|
ssize_t len = recv(socketData->sock, buf, sizeof(buf), 0);
|
|
if (len > 0) {
|
|
context->resultData = (char*)malloc(len);
|
|
memcpy(context->resultData, buf, len);
|
|
context->resultDataLength = len;
|
|
context->success = true;
|
|
} else {
|
|
context->success = false;
|
|
context->errorMessage = strdup("Read failed or connection closed.");
|
|
}
|
|
} else {
|
|
context->success = false;
|
|
context->errorMessage = strdup("Read timeout or error.");
|
|
}
|
|
break;
|
|
}
|
|
case SOCKET_OP_WRITE: {
|
|
if (socketData->isListener) {
|
|
context->success = false;
|
|
context->errorMessage = strdup("Cannot write to a listening socket.");
|
|
break;
|
|
}
|
|
ssize_t written = send(socketData->sock, context->data, context->dataLength, 0);
|
|
context->success = (written == (ssize_t)context->dataLength);
|
|
if(!context->success) context->errorMessage = strdup("Write failed.");
|
|
break;
|
|
}
|
|
}
|
|
queue_push(&manager->completionQueue, context);
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// --- Manager Lifecycle ---
|
|
|
|
// Allocates the global socket manager for `vm`, initialises its queues and
// locks, and starts four worker threads plus the listener thread.
//
// Any failure to obtain a required resource (memory, wake pipe, thread) is
// fatal: a message is printed and the process exits with status 1, matching
// the existing handling of pipe() failure. Previously an unchecked malloc()
// was dereferenced and a failed thread creation left an uninitialised
// thread handle that socketManager_destroy() would later join.
void socketManager_create(WrenVM* vm) {
    socketManager = (AsyncSocketManager*)malloc(sizeof(AsyncSocketManager));
    if (socketManager == NULL) {
        perror("malloc");
        exit(1);
    }

    socketManager->vm = vm;
    socketManager->running = true;
    socketManager->listener_count = 0;
    for (int i = 0; i < MAX_LISTENERS; i++) {
        socketManager->listener_sockets[i] = INVALID_SOCKET;
    }

    queue_init(&socketManager->requestQueue);
    queue_init(&socketManager->completionQueue);
    queue_init(&socketManager->acceptQueue);

#ifdef _WIN32
    InitializeCriticalSection(&socketManager->listener_mutex);
#else
    pthread_mutex_init(&socketManager->listener_mutex, NULL);
#endif

#ifndef _WIN32
    if (pipe(socketManager->wake_pipe) == -1) {
        perror("pipe");
        exit(1);
    }
#endif

    for (int i = 0; i < 4; i++) {
#ifdef _WIN32
        socketManager->worker_threads[i] = CreateThread(NULL, 0, workerThread, socketManager, 0, NULL);
        if (socketManager->worker_threads[i] == NULL) {
            fprintf(stderr, "CreateThread failed for socket worker %d\n", i);
            exit(1);
        }
#else
        if (pthread_create(&socketManager->worker_threads[i], NULL, workerThread, socketManager) != 0) {
            fprintf(stderr, "pthread_create failed for socket worker %d\n", i);
            exit(1);
        }
#endif
    }

#ifdef _WIN32
    socketManager->listener_thread = CreateThread(NULL, 0, listenerThread, socketManager, 0, NULL);
    if (socketManager->listener_thread == NULL) {
        fprintf(stderr, "CreateThread failed for socket listener\n");
        exit(1);
    }
#else
    if (pthread_create(&socketManager->listener_thread, NULL, listenerThread, socketManager) != 0) {
        fprintf(stderr, "pthread_create failed for socket listener\n");
        exit(1);
    }
#endif
}
|
|
|
|
void socketManager_destroy() {
|
|
socketManager->running = false;
|
|
|
|
#ifndef _WIN32
|
|
write(socketManager->wake_pipe[1], "w", 1);
|
|
#endif
|
|
|
|
for (int i = 0; i < 4; i++) {
|
|
queue_push(&socketManager->requestQueue, NULL);
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
WaitForSingleObject(socketManager->listener_thread, INFINITE);
|
|
CloseHandle(socketManager->listener_thread);
|
|
for (int i = 0; i < 4; i++) {
|
|
WaitForSingleObject(socketManager->worker_threads[i], INFINITE);
|
|
CloseHandle(socketManager->worker_threads[i]);
|
|
}
|
|
#else
|
|
pthread_join(socketManager->listener_thread, NULL);
|
|
for (int i = 0; i < 4; i++) {
|
|
pthread_join(socketManager->worker_threads[i], NULL);
|
|
}
|
|
close(socketManager->wake_pipe[0]);
|
|
close(socketManager->wake_pipe[1]);
|
|
#endif
|
|
|
|
queue_destroy(&socketManager->requestQueue);
|
|
queue_destroy(&socketManager->completionQueue);
|
|
queue_destroy(&socketManager->acceptQueue);
|
|
|
|
#ifdef _WIN32
|
|
DeleteCriticalSection(&socketManager->listener_mutex);
|
|
#else
|
|
pthread_mutex_destroy(&socketManager->listener_mutex);
|
|
#endif
|
|
|
|
free(socketManager);
|
|
}
|
|
|
|
void socketManager_processCompletions() {
|
|
WrenHandle* callHandle = wrenMakeCallHandle(socketManager->vm, "call(_,_)");
|
|
while (!queue_empty(&socketManager->completionQueue)) {
|
|
SocketContext* context = queue_pop(&socketManager->completionQueue);
|
|
|
|
wrenEnsureSlots(socketManager->vm, 3);
|
|
wrenSetSlotHandle(socketManager->vm, 0, context->callback);
|
|
if (context->success) {
|
|
wrenSetSlotNull(socketManager->vm, 1);
|
|
if (IS_SOCKET_VALID(context->newSocket)) {
|
|
wrenGetVariable(socketManager->vm, "socket", "Socket", 2);
|
|
void* foreign = wrenSetSlotNewForeign(socketManager->vm, 2, 2, sizeof(SocketData));
|
|
SocketData* clientData = (SocketData*)foreign;
|
|
clientData->sock = context->newSocket;
|
|
clientData->isListener = false;
|
|
} else if (context->resultData) {
|
|
wrenSetSlotBytes(socketManager->vm, 2, context->resultData, context->resultDataLength);
|
|
} else {
|
|
wrenSetSlotNull(socketManager->vm, 2);
|
|
}
|
|
} else {
|
|
wrenSetSlotString(socketManager->vm, 1, context->errorMessage ? context->errorMessage : "Unknown error.");
|
|
wrenSetSlotNull(socketManager->vm, 2);
|
|
}
|
|
|
|
wrenCall(socketManager->vm, callHandle);
|
|
|
|
// Safely release handles here on the main thread
|
|
wrenReleaseHandle(socketManager->vm, context->socketHandle);
|
|
wrenReleaseHandle(socketManager->vm, context->callback);
|
|
free_socket_context_data(context);
|
|
}
|
|
wrenReleaseHandle(socketManager->vm, callHandle);
|
|
}
|
|
|
|
// --- Wren Foreign Method Implementations ---
|
|
// Foreign-class allocator for `Socket`: creates the SocketData payload in
// slot 0 and marks it as not yet open and not a listener.
void socketAllocate(WrenVM* vm) {
    void* storage = wrenSetSlotNewForeign(vm, 0, 0, sizeof(SocketData));
    SocketData* fresh = (SocketData*)storage;

    fresh->isListener = false;
    fresh->sock = INVALID_SOCKET;
}
|
|
|
|
// Foreign method Socket.connect(host, port, callback).
// Slot 0: the Socket, slot 1: host string, slot 2: port number,
// slot 3: callback. Queues an asynchronous connect request; the outcome is
// reported later through the completion queue.
//
// Fixes: argument types and the port range are validated before the slots
// are read, allocation failures no longer dereference NULL, and newSocket
// starts as INVALID_SOCKET (zero-initialised it looked like a valid
// descriptor to the completion code). Invalid calls abort the calling
// fiber with an error message.
void socketConnect(WrenVM* vm) {
    if (wrenGetSlotType(vm, 1) != WREN_TYPE_STRING ||
        wrenGetSlotType(vm, 2) != WREN_TYPE_NUM) {
        wrenSetSlotString(vm, 0, "connect expects a host string and a port number.");
        wrenAbortFiber(vm, 0);
        return;
    }

    double port = wrenGetSlotDouble(vm, 2);
    if (!(port >= 0 && port <= 65535)) {
        wrenSetSlotString(vm, 0, "Port must be between 0 and 65535.");
        wrenAbortFiber(vm, 0);
        return;
    }

    SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext));
    char* host = strdup(wrenGetSlotString(vm, 1));
    if (context == NULL || host == NULL) {
        free(context);
        free(host);
        wrenSetSlotString(vm, 0, "Out of memory.");
        wrenAbortFiber(vm, 0);
        return;
    }

    context->operation = SOCKET_OP_CONNECT;
    context->vm = vm;
    context->host = host;
    context->port = (int)port;
    context->newSocket = INVALID_SOCKET;
    // Handles are taken last so the early exits above cannot leak them.
    context->socketHandle = wrenGetSlotHandle(vm, 0);
    context->callback = wrenGetSlotHandle(vm, 3);
    queue_push(&socketManager->requestQueue, context);
}
|
|
|
|
// Foreign method Socket.listen(host, port, backlog).
// Slot 0: the Socket, slot 1: host string, slot 2: port number,
// slot 3: backlog number. Binds to the first usable address, starts
// listening, registers the socket with the listener thread and returns
// true in slot 0; returns false on any failure.
//
// Fixes:
//  * When the listener table was already full the socket was still stored
//    and `true` returned, although the listener thread would never poll it.
//    That case now closes the socket and returns false.
//  * Ports outside 0..65535 were silently truncated by the 6-byte port
//    buffer; they are now rejected.
//  * Arguments of the wrong type are rejected instead of being read as
//    string/number.
void socketListen(WrenVM* vm) {
    SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0);

    if (wrenGetSlotType(vm, 1) != WREN_TYPE_STRING ||
        wrenGetSlotType(vm, 2) != WREN_TYPE_NUM ||
        wrenGetSlotType(vm, 3) != WREN_TYPE_NUM) {
        wrenSetSlotBool(vm, 0, false);
        return;
    }

    const char* host = wrenGetSlotString(vm, 1);
    double portValue = wrenGetSlotDouble(vm, 2);
    int backlog = (int)wrenGetSlotDouble(vm, 3);

    if (!(portValue >= 0 && portValue <= 65535)) {
        wrenSetSlotBool(vm, 0, false);
        return;
    }
    int port = (int)portValue;

    struct addrinfo hints = {0}, *addrs = NULL;
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE;

    char port_str[6];
    snprintf(port_str, sizeof(port_str), "%d", port);
    if (getaddrinfo(host, port_str, &hints, &addrs) != 0) {
        wrenSetSlotBool(vm, 0, false);
        return;
    }

    socket_t sock = INVALID_SOCKET;
    for (struct addrinfo* addr = addrs; addr; addr = addr->ai_next) {
        sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
        if (!IS_SOCKET_VALID(sock)) continue;

        int yes = 1;
        setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char*)&yes, sizeof(yes));

        if (bind(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break;

        CLOSE_SOCKET(sock);
        sock = INVALID_SOCKET;
    }
    freeaddrinfo(addrs);

    if (!IS_SOCKET_VALID(sock) || listen(sock, backlog) != 0) {
        if (IS_SOCKET_VALID(sock)) CLOSE_SOCKET(sock);
        wrenSetSlotBool(vm, 0, false);
        return;
    }

    bool registered = false;

#ifdef _WIN32
    EnterCriticalSection(&socketManager->listener_mutex);
#else
    pthread_mutex_lock(&socketManager->listener_mutex);
#endif

    if (socketManager->listener_count < MAX_LISTENERS) {
        socketManager->listener_sockets[socketManager->listener_count++] = sock;
        registered = true;
    }

#ifdef _WIN32
    LeaveCriticalSection(&socketManager->listener_mutex);
#else
    pthread_mutex_unlock(&socketManager->listener_mutex);
#endif

    if (!registered) {
        // No free slot: a listener nobody polls would be useless.
        CLOSE_SOCKET(sock);
        wrenSetSlotBool(vm, 0, false);
        return;
    }

    data->sock = sock;
    data->isListener = true;

#ifndef _WIN32
    // Make the listener thread rebuild its select() set right away.
    ssize_t woke = write(socketManager->wake_pipe[1], "w", 1);
    (void)woke;
#endif

    wrenSetSlotBool(vm, 0, true);
}
|
|
|
|
// Foreign method Socket.accept(callback).
// Slot 0: the Socket, slot 1: callback. Queues an accept request for the
// listener thread; the accepted connection (or an error) is reported later
// through the completion queue.
//
// Fix: a failed allocation is reported by aborting the calling fiber
// instead of dereferencing NULL. newSocket also starts out as
// INVALID_SOCKET rather than descriptor 0.
void socketAccept(WrenVM* vm) {
    SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext));
    if (context == NULL) {
        wrenSetSlotString(vm, 0, "Out of memory.");
        wrenAbortFiber(vm, 0);
        return;
    }

    context->vm = vm;
    context->newSocket = INVALID_SOCKET;
    context->socketHandle = wrenGetSlotHandle(vm, 0);
    context->callback = wrenGetSlotHandle(vm, 1);
    queue_push(&socketManager->acceptQueue, context);
}
|
|
|
|
// Foreign method Socket.read(callback).
// Slot 0: the Socket, slot 1: callback. Queues an asynchronous read; the
// received bytes (or an error) are reported later through the completion
// queue.
//
// Fixes: a failed allocation aborts the calling fiber instead of
// dereferencing NULL, and newSocket starts as INVALID_SOCKET — left at zero
// it looked like a valid accepted socket to the completion code, which then
// delivered a Socket object instead of the data that was read.
void socketRead(WrenVM* vm) {
    SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext));
    if (context == NULL) {
        wrenSetSlotString(vm, 0, "Out of memory.");
        wrenAbortFiber(vm, 0);
        return;
    }

    context->operation = SOCKET_OP_READ;
    context->vm = vm;
    context->newSocket = INVALID_SOCKET;
    context->socketHandle = wrenGetSlotHandle(vm, 0);
    context->callback = wrenGetSlotHandle(vm, 1);
    queue_push(&socketManager->requestQueue, context);
}
|
|
|
|
// Foreign method Socket.write_(data, callback).
// Slot 0: the Socket, slot 1: bytes to send, slot 2: callback. Copies the
// bytes and queues an asynchronous write; the outcome is reported later
// through the completion queue.
//
// Fixes: the data argument's type is checked before it is read, allocation
// failures abort the calling fiber instead of dereferencing NULL, an empty
// payload no longer relies on malloc(0) returning a usable pointer, and
// newSocket starts as INVALID_SOCKET rather than descriptor 0.
void socketWrite(WrenVM* vm) {
    if (wrenGetSlotType(vm, 1) != WREN_TYPE_STRING) {
        wrenSetSlotString(vm, 0, "write_ expects the data as a string.");
        wrenAbortFiber(vm, 0);
        return;
    }

    int len = 0;
    const char* bytes = wrenGetSlotBytes(vm, 1, &len);
    if (len < 0) len = 0;

    SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext));
    // Always request at least one byte so NULL unambiguously means failure.
    char* copy = (char*)malloc(len > 0 ? (size_t)len : 1);
    if (context == NULL || copy == NULL) {
        free(context);
        free(copy);
        wrenSetSlotString(vm, 0, "Out of memory.");
        wrenAbortFiber(vm, 0);
        return;
    }
    if (len > 0) {
        memcpy(copy, bytes, (size_t)len);
    }

    context->operation = SOCKET_OP_WRITE;
    context->vm = vm;
    context->data = copy;
    context->dataLength = (size_t)len;
    context->newSocket = INVALID_SOCKET;
    context->socketHandle = wrenGetSlotHandle(vm, 0);
    context->callback = wrenGetSlotHandle(vm, 2);
    queue_push(&socketManager->requestQueue, context);
}
|
|
|
|
// Foreign method Socket.close().
// Closes the socket held by the Socket in slot 0 and marks it as
// INVALID_SOCKET. A listening socket is first removed from the manager's
// listener table. Closing a socket that is not open does nothing.
void socketClose(WrenVM* vm) {
    SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0);

    if (!IS_SOCKET_VALID(data->sock)) {
        return;
    }

    if (data->isListener) {
#ifndef _WIN32
        pthread_mutex_lock(&socketManager->listener_mutex);
#else
        EnterCriticalSection(&socketManager->listener_mutex);
#endif

        // Unordered removal: the last entry takes the freed slot.
        int slot = 0;
        while (slot < socketManager->listener_count) {
            if (socketManager->listener_sockets[slot] == data->sock) {
                int last = socketManager->listener_count - 1;
                socketManager->listener_sockets[slot] = socketManager->listener_sockets[last];
                socketManager->listener_count = last;
                break;
            }
            slot++;
        }

#ifndef _WIN32
        pthread_mutex_unlock(&socketManager->listener_mutex);
#else
        LeaveCriticalSection(&socketManager->listener_mutex);
#endif
    }

    CLOSE_SOCKET(data->sock);
    data->sock = INVALID_SOCKET;
}
|
|
|
|
// Foreign getter Socket.isOpen: puts true in slot 0 when the Socket in
// slot 0 currently holds a valid OS socket, false otherwise.
void socketIsOpen(WrenVM* vm) {
    const SocketData* self = (const SocketData*)wrenGetSlotForeign(vm, 0);
    bool open = IS_SOCKET_VALID(self->sock);

    wrenSetSlotBool(vm, 0, open);
}
|
|
|
|
// Foreign getter Socket.remoteAddress: puts the peer's IP address, as text,
// in slot 0. Returns null when the socket is not open, is a listener, the
// peer cannot be determined, or the peer address is neither IPv4 nor IPv6.
//
// Fixes: every non-IPv4 family used to be formatted as IPv6, and the result
// of inet_ntop() was ignored, so a failed conversion handed an uninitialised
// buffer to the VM. The family is now checked explicitly (as
// socketRemotePort does) and a failed conversion yields null.
void socketRemoteAddress(WrenVM* vm) {
    SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0);
    if (!IS_SOCKET_VALID(data->sock) || data->isListener) {
        wrenSetSlotNull(vm, 0);
        return;
    }

    struct sockaddr_storage addr;
    socklen_t len = sizeof(addr);
    char ipstr[INET6_ADDRSTRLEN];

    if (getpeername(data->sock, (struct sockaddr*)&addr, &len) != 0) {
        wrenSetSlotNull(vm, 0);
        return;
    }

    int family = addr.ss_family;
    void* rawAddress = NULL;
    if (family == AF_INET) {
        rawAddress = &((struct sockaddr_in*)&addr)->sin_addr;
    } else if (family == AF_INET6) {
        rawAddress = &((struct sockaddr_in6*)&addr)->sin6_addr;
    }

    if (rawAddress == NULL ||
        inet_ntop(family, rawAddress, ipstr, sizeof(ipstr)) == NULL) {
        wrenSetSlotNull(vm, 0);
        return;
    }

    wrenSetSlotString(vm, 0, ipstr);
}
|
|
|
|
// Foreign getter Socket.remotePort: puts the peer's port number in slot 0.
// Returns null when the socket is not open, is a listener, or the peer
// cannot be determined. A peer that is neither IPv4 nor IPv6 yields 0.
void socketRemotePort(WrenVM* vm) {
    SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0);
    struct sockaddr_storage peer;
    socklen_t peerLen = sizeof(peer);

    if (data->isListener || !IS_SOCKET_VALID(data->sock) ||
        getpeername(data->sock, (struct sockaddr*)&peer, &peerLen) != 0) {
        wrenSetSlotNull(vm, 0);
        return;
    }

    int port = 0;
    switch (peer.ss_family) {
        case AF_INET:
            port = ntohs(((struct sockaddr_in*)&peer)->sin_port);
            break;
        case AF_INET6:
            port = ntohs(((struct sockaddr_in6*)&peer)->sin6_port);
            break;
        default:
            break;
    }
    wrenSetSlotDouble(vm, 0, (double)port);
}
|
|
|
|
// Resolves a foreign method of the `socket` module's `Socket` class to its
// C implementation. Returns NULL for any other module or class, for static
// methods, and for signatures that are not listed below. `vm` is unused.
WrenForeignMethodFn bindSocketForeignMethod(WrenVM* vm, const char* module, const char* className, bool isStatic, const char* signature) {
    // NOTE: The signature for read() in Wren takes one argument (the callback) now.
    static const struct {
        const char* signature;
        WrenForeignMethodFn fn;
    } methods[] = {
        { "connect(_,_,_)", socketConnect },
        { "listen(_,_,_)",  socketListen },
        { "accept(_)",      socketAccept },
        { "read(_)",        socketRead },
        { "write_(_,_)",    socketWrite },
        { "close()",        socketClose },
        { "isOpen",         socketIsOpen },
        { "remoteAddress",  socketRemoteAddress },
        { "remotePort",     socketRemotePort },
    };

    (void)vm;

    if (strcmp(module, "socket") != 0) return NULL;
    if (isStatic || strcmp(className, "Socket") != 0) return NULL;

    for (size_t i = 0; i < sizeof(methods) / sizeof(methods[0]); i++) {
        if (strcmp(signature, methods[i].signature) == 0) {
            return methods[i].fn;
        }
    }
    return NULL;
}
|
|
|
|
// Supplies the foreign-class hooks for the `socket` module's `Socket`
// class: socketAllocate as allocator and no finalizer. Every other class
// gets both hooks set to NULL. `vm` is unused.
WrenForeignClassMethods bindSocketForeignClass(WrenVM* vm, const char* module, const char* className) {
    WrenForeignClassMethods hooks;
    bool isSocketClass =
        strcmp(module, "socket") == 0 && strcmp(className, "Socket") == 0;

    (void)vm;

    hooks.allocate = isSocketClass ? socketAllocate : NULL;
    hooks.finalize = NULL;
    return hooks;
}
|
|
|
|
// End of socket_backend.c
|
|
|
|
// Start of httplib.h
|
|
//
|
|
// httplib.h
|
|
//
|
|
// Copyright (c) 2025 Yuji Hirose. All rights reserved.
|
|
// MIT License
|
|
//
|
|
|
|
#ifndef CPPHTTPLIB_HTTPLIB_H
|
|
#define CPPHTTPLIB_HTTPLIB_H
|
|
|
|
#define CPPHTTPLIB_VERSION "0.23.1"
|
|
#define CPPHTTPLIB_VERSION_NUM "0x001701"
|
|
|
|
/*
|
|
* Platform compatibility check
|
|
*/
|
|
|
|
#if defined(_WIN32) && !defined(_WIN64)
|
|
#error \
|
|
"cpp-httplib doesn't support 32-bit Windows. Please use a 64-bit compiler."
|
|
#elif defined(__SIZEOF_POINTER__) && __SIZEOF_POINTER__ < 8
|
|
#warning \
|
|
"cpp-httplib doesn't support 32-bit platforms. Please use a 64-bit compiler."
|
|
#elif defined(__SIZEOF_SIZE_T__) && __SIZEOF_SIZE_T__ < 8
|
|
#warning \
|
|
"cpp-httplib doesn't support platforms where size_t is less than 64 bits."
|
|
#endif
|
|
|
|
#ifdef _WIN32
|
|
#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0602
|
|
#error \
|
|
"cpp-httplib doesn't support Windows 8 or lower. Please use Windows 10 or later."
|
|
#endif
|
|
#endif
|
|
|
|
/*
|
|
* Configuration
|
|
*/
|
|
|
|
#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND
|
|
#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND
|
|
#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT
|
|
#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND
|
|
#define CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND 300
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND
|
|
#define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND
|
|
#define CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND 5
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND
|
|
#define CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND 0
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND
|
|
#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND 5
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND
|
|
#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND 0
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND
|
|
#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND 300
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND
|
|
#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND 0
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND
|
|
#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND 5
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND
|
|
#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND
|
|
#define CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND 0
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND
|
|
#define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_IDLE_INTERVAL_USECOND
|
|
#ifdef _WIN64
|
|
#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 1000
|
|
#else
|
|
#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 0
|
|
#endif
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH
|
|
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH
|
|
#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_HEADER_MAX_COUNT
|
|
#define CPPHTTPLIB_HEADER_MAX_COUNT 100
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT
|
|
#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT
|
|
#define CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT 1024
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH
|
|
#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits<size_t>::max)())
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH
|
|
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_RANGE_MAX_COUNT
|
|
#define CPPHTTPLIB_RANGE_MAX_COUNT 1024
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_TCP_NODELAY
|
|
#define CPPHTTPLIB_TCP_NODELAY false
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_IPV6_V6ONLY
|
|
#define CPPHTTPLIB_IPV6_V6ONLY false
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_RECV_BUFSIZ
|
|
#define CPPHTTPLIB_RECV_BUFSIZ size_t(16384u)
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_SEND_BUFSIZ
|
|
#define CPPHTTPLIB_SEND_BUFSIZ size_t(16384u)
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ
|
|
#define CPPHTTPLIB_COMPRESSION_BUFSIZ size_t(16384u)
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_THREAD_POOL_COUNT
|
|
#define CPPHTTPLIB_THREAD_POOL_COUNT \
|
|
((std::max)(8u, std::thread::hardware_concurrency() > 0 \
|
|
? std::thread::hardware_concurrency() - 1 \
|
|
: 0))
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_RECV_FLAGS
|
|
#define CPPHTTPLIB_RECV_FLAGS 0
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_SEND_FLAGS
|
|
#define CPPHTTPLIB_SEND_FLAGS 0
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_LISTEN_BACKLOG
|
|
#define CPPHTTPLIB_LISTEN_BACKLOG 5
|
|
#endif
|
|
|
|
#ifndef CPPHTTPLIB_MAX_LINE_LENGTH
|
|
#define CPPHTTPLIB_MAX_LINE_LENGTH 32768
|
|
#endif
|
|
|
|
/*
|
|
* Headers
|
|
*/
|
|
|
|
#ifdef _WIN64
|
|
#ifndef _CRT_SECURE_NO_WARNINGS
|
|
#define _CRT_SECURE_NO_WARNINGS
|
|
#endif //_CRT_SECURE_NO_WARNINGS
|
|
|
|
#ifndef _CRT_NONSTDC_NO_DEPRECATE
|
|
#define _CRT_NONSTDC_NO_DEPRECATE
|
|
#endif //_CRT_NONSTDC_NO_DEPRECATE
|
|
|
|
#if defined(_MSC_VER)
|
|
#if _MSC_VER < 1900
|
|
#error Sorry, Visual Studio versions prior to 2015 are not supported
|
|
#endif
|
|
|
|
#pragma comment(lib, "ws2_32.lib")
|
|
|
|
using ssize_t = __int64;
|
|
#endif // _MSC_VER
|
|
|
|
#ifndef S_ISREG
|
|
#define S_ISREG(m) (((m) & S_IFREG) == S_IFREG)
|
|
#endif // S_ISREG
|
|
|
|
#ifndef S_ISDIR
|
|
#define S_ISDIR(m) (((m) & S_IFDIR) == S_IFDIR)
|
|
#endif // S_ISDIR
|
|
|
|
#ifndef NOMINMAX
|
|
#define NOMINMAX
|
|
#endif // NOMINMAX
|
|
|
|
#include <io.h>
|
|
#include <winsock2.h>
|
|
#include <ws2tcpip.h>
|
|
|
|
#if defined(__has_include)
|
|
#if __has_include(<afunix.h>)
|
|
// afunix.h uses types declared in winsock2.h, so has to be included after it.
|
|
#include <afunix.h>
|
|
#define CPPHTTPLIB_HAVE_AFUNIX_H 1
|
|
#endif
|
|
#endif
|
|
|
|
#ifndef WSA_FLAG_NO_HANDLE_INHERIT
|
|
#define WSA_FLAG_NO_HANDLE_INHERIT 0x80
|
|
#endif
|
|
|
|
using nfds_t = unsigned long;
|
|
using socket_t = SOCKET;
|
|
using socklen_t = int;
|
|
|
|
#else // not _WIN64
|
|
|
|
#include <arpa/inet.h>
|
|
#if !defined(_AIX) && !defined(__MVS__)
|
|
#include <ifaddrs.h>
|
|
#endif
|
|
#ifdef __MVS__
|
|
#include <strings.h>
|
|
#ifndef NI_MAXHOST
|
|
#define NI_MAXHOST 1025
|
|
#endif
|
|
#endif
|
|
#include <net/if.h>
|
|
#include <netdb.h>
|
|
#include <netinet/in.h>
|
|
#ifdef __linux__
|
|
#include <resolv.h>
|
|
#endif
|
|
#include <csignal>
|
|
#include <netinet/tcp.h>
|
|
#include <poll.h>
|
|
#include <pthread.h>
|
|
#include <sys/mman.h>
|
|
#include <sys/socket.h>
|
|
#include <sys/un.h>
|
|
#include <unistd.h>
|
|
|
|
using socket_t = int;
|
|
#ifndef INVALID_SOCKET
|
|
#define INVALID_SOCKET (-1)
|
|
#endif
|
|
#endif //_WIN64
|
|
|
|
#if defined(__APPLE__)
|
|
#include <TargetConditionals.h>
|
|
#endif
|
|
|
|
#include <algorithm>
|
|
#include <array>
|
|
#include <atomic>
|
|
#include <cassert>
|
|
#include <cctype>
|
|
#include <climits>
|
|
#include <condition_variable>
|
|
#include <cstring>
|
|
#include <errno.h>
|
|
#include <exception>
|
|
#include <fcntl.h>
|
|
#include <functional>
|
|
#include <iomanip>
|
|
#include <iostream>
|
|
#include <list>
|
|
#include <map>
|
|
#include <memory>
|
|
#include <mutex>
|
|
#include <random>
|
|
#include <regex>
|
|
#include <set>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <sys/stat.h>
|
|
#include <thread>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
|
|
#if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \
|
|
defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
|
#if TARGET_OS_OSX
|
|
#include <CFNetwork/CFHost.h>
|
|
#include <CoreFoundation/CoreFoundation.h>
|
|
#endif
|
|
#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO or
|
|
// CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
#ifdef _WIN64
|
|
#include <wincrypt.h>
|
|
|
|
// these are defined in wincrypt.h and it breaks compilation if BoringSSL is
|
|
// used
|
|
#undef X509_NAME
|
|
#undef X509_CERT_PAIR
|
|
#undef X509_EXTENSIONS
|
|
#undef PKCS7_SIGNER_INFO
|
|
|
|
#ifdef _MSC_VER
|
|
#pragma comment(lib, "crypt32.lib")
|
|
#endif
|
|
#endif // _WIN64
|
|
|
|
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
|
#if TARGET_OS_OSX
|
|
#include <Security/Security.h>
|
|
#endif
|
|
#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
|
|
|
#include <openssl/err.h>
|
|
#include <openssl/evp.h>
|
|
#include <openssl/ssl.h>
|
|
#include <openssl/x509v3.h>
|
|
|
|
#if defined(_WIN64) && defined(OPENSSL_USE_APPLINK)
|
|
#include <openssl/applink.c>
|
|
#endif
|
|
|
|
#include <iostream>
|
|
#include <sstream>
|
|
|
|
#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER)
|
|
#if OPENSSL_VERSION_NUMBER < 0x1010107f
|
|
#error Please use OpenSSL or a current version of BoringSSL
|
|
#endif
|
|
#define SSL_get1_peer_certificate SSL_get_peer_certificate
|
|
#elif OPENSSL_VERSION_NUMBER < 0x30000000L
|
|
#error Sorry, OpenSSL versions prior to 3.0.0 are not supported
|
|
#endif
|
|
|
|
#endif // CPPHTTPLIB_OPENSSL_SUPPORT
|
|
|
|
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
|
#include <zlib.h>
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
|
|
#include <brotli/decode.h>
|
|
#include <brotli/encode.h>
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
|
|
#include <zstd.h>
|
|
#endif
|
|
|
|
/*
|
|
* Declaration
|
|
*/
|
|
namespace httplib {
|
|
|
|
namespace detail {
|
|
|
|
/*
|
|
* Backport std::make_unique from C++14.
|
|
*
|
|
* NOTE: This code came up with the following stackoverflow post:
|
|
* https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique
|
|
*
|
|
*/
|
|
|
|
// Single-object overload: selected only when T is not an array type.
// Forwards `args` to T's constructor and hands ownership of the new object
// to the returned std::unique_ptr<T>.
template <class T, class... Args>
typename std::enable_if<!std::is_array<T>::value, std::unique_ptr<T>>::type
make_unique(Args &&...args) {
  T *const raw = new T(std::forward<Args>(args)...);
  return std::unique_ptr<T>(raw);
}
|
|
|
|
// Array overload: selected only when T is an array type (e.g. `int[]`).
// Allocates `n` elements with `new Element[n]`, so the elements are
// default-initialized, not value-initialized.
template <class T>
typename std::enable_if<std::is_array<T>::value, std::unique_ptr<T>>::type
make_unique(std::size_t n) {
  using Element = typename std::remove_extent<T>::type;
  Element *const storage = new Element[n];
  return std::unique_ptr<T>(storage);
}
|
|
|
|
namespace case_ignore {
|
|
|
|
// Maps one byte to its lower-case counterpart via a 256-entry lookup table.
//
// `c` is narrowed to `char` and then to `unsigned char` before indexing, so
// any int (including negative values produced by a signed `char`) selects a
// valid table slot. Per the table below, 65..90 map to 97..122, and
// 192..214 / 216..222 map to the value 32 higher; every other byte maps to
// itself (including 215 and 223).
inline unsigned char to_lower(int c) {
  // 16 entries per row; the row's first index is shown in the trailing comment.
  const static unsigned char table[256] = {
      0,   1,   2,   3,   4,   5,   6,   7,
      8,   9,   10,  11,  12,  13,  14,  15,  //   0
      16,  17,  18,  19,  20,  21,  22,  23,
      24,  25,  26,  27,  28,  29,  30,  31,  //  16
      32,  33,  34,  35,  36,  37,  38,  39,
      40,  41,  42,  43,  44,  45,  46,  47,  //  32
      48,  49,  50,  51,  52,  53,  54,  55,
      56,  57,  58,  59,  60,  61,  62,  63,  //  48
      64,  97,  98,  99,  100, 101, 102, 103,
      104, 105, 106, 107, 108, 109, 110, 111, //  64
      112, 113, 114, 115, 116, 117, 118, 119,
      120, 121, 122, 91,  92,  93,  94,  95,  //  80
      96,  97,  98,  99,  100, 101, 102, 103,
      104, 105, 106, 107, 108, 109, 110, 111, //  96
      112, 113, 114, 115, 116, 117, 118, 119,
      120, 121, 122, 123, 124, 125, 126, 127, // 112
      128, 129, 130, 131, 132, 133, 134, 135,
      136, 137, 138, 139, 140, 141, 142, 143, // 128
      144, 145, 146, 147, 148, 149, 150, 151,
      152, 153, 154, 155, 156, 157, 158, 159, // 144
      160, 161, 162, 163, 164, 165, 166, 167,
      168, 169, 170, 171, 172, 173, 174, 175, // 160
      176, 177, 178, 179, 180, 181, 182, 183,
      184, 185, 186, 187, 188, 189, 190, 191, // 176
      224, 225, 226, 227, 228, 229, 230, 231,
      232, 233, 234, 235, 236, 237, 238, 239, // 192
      240, 241, 242, 243, 244, 245, 246, 215,
      248, 249, 250, 251, 252, 253, 254, 223, // 208
      224, 225, 226, 227, 228, 229, 230, 231,
      232, 233, 234, 235, 236, 237, 238, 239, // 224
      240, 241, 242, 243, 244, 245, 246, 247,
      248, 249, 250, 251, 252, 253, 254, 255, // 240
  };
  const unsigned char index = (unsigned char)(char)c;
  return table[index];
}
|
|
|
|
inline bool equal(const std::string &a, const std::string &b) {
|
|
return a.size() == b.size() &&
|
|
std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) {
|
|
return to_lower(ca) == to_lower(cb);
|
|
});
|
|
}
|
|
|
|
struct equal_to {
|
|
bool operator()(const std::string &a, const std::string &b) const {
|
|
return equal(a, b);
|
|
}
|
|
};
|
|
|
|
struct hash {
|
|
size_t operator()(const std::string &key) const {
|
|
return hash_core(key.data(), key.size(), 0);
|
|
}
|
|
|
|
size_t hash_core(const char *s, size_t l, size_t h) const {
|
|
return (l == 0) ? h
|
|
: hash_core(s + 1, l - 1,
|
|
// Unsets the 6 high bits of h, therefore no
|
|
// overflow happens
|
|
(((std::numeric_limits<size_t>::max)() >> 6) &
|
|
h * 33) ^
|
|
static_cast<unsigned char>(to_lower(*s)));
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
using unordered_set = std::unordered_set<T, detail::case_ignore::hash,
|
|
detail::case_ignore::equal_to>;
|
|
|
|
} // namespace case_ignore
|
|
|
|
// This is based on
|
|
// "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189".
|
|
|
|
// Scope guard: runs the stored callable from the destructor unless release()
// was called first. Movable (the moved-from guard is disarmed), but neither
// copyable nor assignable.
struct scope_exit {
  // Takes ownership of `f` and arms the guard.
  explicit scope_exit(std::function<void(void)> &&f)
      : exit_function(std::move(f)), execute_on_destruction{true} {}

  // Transfers the callable and the armed/disarmed state from `rhs`, then
  // disarms `rhs` so the callable runs at most once.
  scope_exit(scope_exit &&rhs) noexcept
      : exit_function(std::move(rhs.exit_function)),
        execute_on_destruction{rhs.execute_on_destruction} {
    rhs.execute_on_destruction = false;
  }

  ~scope_exit() {
    if (!execute_on_destruction) { return; }
    exit_function();
  }

  // Disarms the guard; the callable will not be invoked on destruction.
  void release() { execute_on_destruction = false; }

private:
  scope_exit(const scope_exit &) = delete;
  void operator=(const scope_exit &) = delete;
  scope_exit &operator=(scope_exit &&) = delete;

  std::function<void(void)> exit_function;
  bool execute_on_destruction;
};
|
|
|
|
} // namespace detail
|
|
|
|
// Outcome reported by a certificate verifier.
enum SSLVerifierResponse {
  NoDecisionMade = 0,      // no decision has been made, use the built-in
                           // certificate verifier
  CertificateAccepted = 1, // connection certificate is verified and accepted
  CertificateRejected = 2  // connection certificate was processed but is
                           // rejected
};
|
|
|
|
// HTTP status codes. Each enumerator's numeric value equals the status code
// embedded in its name.
enum StatusCode {
  // 1xx: Information responses
  Continue_100 = 100,
  SwitchingProtocol_101 = 101,
  Processing_102 = 102,
  EarlyHints_103 = 103,

  // 2xx: Successful responses
  OK_200 = 200,
  Created_201 = 201,
  Accepted_202 = 202,
  NonAuthoritativeInformation_203 = 203,
  NoContent_204 = 204,
  ResetContent_205 = 205,
  PartialContent_206 = 206,
  MultiStatus_207 = 207,
  AlreadyReported_208 = 208,
  IMUsed_226 = 226,

  // 3xx: Redirection messages
  MultipleChoices_300 = 300,
  MovedPermanently_301 = 301,
  Found_302 = 302,
  SeeOther_303 = 303,
  NotModified_304 = 304,
  UseProxy_305 = 305,
  unused_306 = 306,
  TemporaryRedirect_307 = 307,
  PermanentRedirect_308 = 308,

  // 4xx: Client error responses
  BadRequest_400 = 400,
  Unauthorized_401 = 401,
  PaymentRequired_402 = 402,
  Forbidden_403 = 403,
  NotFound_404 = 404,
  MethodNotAllowed_405 = 405,
  NotAcceptable_406 = 406,
  ProxyAuthenticationRequired_407 = 407,
  RequestTimeout_408 = 408,
  Conflict_409 = 409,
  Gone_410 = 410,
  LengthRequired_411 = 411,
  PreconditionFailed_412 = 412,
  PayloadTooLarge_413 = 413,
  UriTooLong_414 = 414,
  UnsupportedMediaType_415 = 415,
  RangeNotSatisfiable_416 = 416,
  ExpectationFailed_417 = 417,
  ImATeapot_418 = 418,
  MisdirectedRequest_421 = 421,
  UnprocessableContent_422 = 422,
  Locked_423 = 423,
  FailedDependency_424 = 424,
  TooEarly_425 = 425,
  UpgradeRequired_426 = 426,
  PreconditionRequired_428 = 428,
  TooManyRequests_429 = 429,
  RequestHeaderFieldsTooLarge_431 = 431,
  UnavailableForLegalReasons_451 = 451,

  // 5xx: Server error responses
  InternalServerError_500 = 500,
  NotImplemented_501 = 501,
  BadGateway_502 = 502,
  ServiceUnavailable_503 = 503,
  GatewayTimeout_504 = 504,
  HttpVersionNotSupported_505 = 505,
  VariantAlsoNegotiates_506 = 506,
  InsufficientStorage_507 = 507,
  LoopDetected_508 = 508,
  NotExtended_510 = 510,
  NetworkAuthenticationRequired_511 = 511,
};
|
|
|
|
using Headers =
|
|
std::unordered_multimap<std::string, std::string, detail::case_ignore::hash,
|
|
detail::case_ignore::equal_to>;
|
|
|
|
// Ordered key/value collection that allows repeated keys.
using Params = std::multimap<std::string, std::string>;

// Result type of a std::regex match over a std::string.
using Match = std::smatch;

// Progress callbacks receive `current` and `total` and return a bool.
using DownloadProgress = std::function<bool(size_t current, size_t total)>;
using UploadProgress = std::function<bool(size_t current, size_t total)>;

struct Response;

// Callback that inspects a Response and returns a bool.
using ResponseHandler = std::function<bool(const Response &response)>;
|
|
|
|
struct FormData {
|
|
std::string name;
|
|
std::string content;
|
|
std::string filename;
|
|
std::string content_type;
|
|
Headers headers;
|
|
};
|
|
|
|
struct FormField {
|
|
std::string name;
|
|
std::string content;
|
|
Headers headers;
|
|
};
|
|
using FormFields = std::multimap<std::string, FormField>;
|
|
|
|
using FormFiles = std::multimap<std::string, FormData>;
|
|
|
|
struct MultipartFormData {
|
|
FormFields fields; // Text fields from multipart
|
|
FormFiles files; // Files from multipart
|
|
|
|
// Text field access
|
|
std::string get_field(const std::string &key, size_t id = 0) const;
|
|
std::vector<std::string> get_fields(const std::string &key) const;
|
|
bool has_field(const std::string &key) const;
|
|
size_t get_field_count(const std::string &key) const;
|
|
|
|
// File access
|
|
FormData get_file(const std::string &key, size_t id = 0) const;
|
|
std::vector<FormData> get_files(const std::string &key) const;
|
|
bool has_file(const std::string &key) const;
|
|
size_t get_file_count(const std::string &key) const;
|
|
};
|
|
|
|
// Description of one form part to upload: name, content, file name and
// content type, all held as strings.
struct UploadFormData {
  std::string name;
  std::string content;
  std::string filename;
  std::string content_type;
};

// Sequence of form parts to upload.
using UploadFormDataItems = std::vector<UploadFormData>;
|
|
|
|
class DataSink {
|
|
public:
|
|
DataSink() : os(&sb_), sb_(*this) {}
|
|
|
|
DataSink(const DataSink &) = delete;
|
|
DataSink &operator=(const DataSink &) = delete;
|
|
DataSink(DataSink &&) = delete;
|
|
DataSink &operator=(DataSink &&) = delete;
|
|
|
|
std::function<bool(const char *data, size_t data_len)> write;
|
|
std::function<bool()> is_writable;
|
|
std::function<void()> done;
|
|
std::function<void(const Headers &trailer)> done_with_trailer;
|
|
std::ostream os;
|
|
|
|
private:
|
|
class data_sink_streambuf final : public std::streambuf {
|
|
public:
|
|
explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {}
|
|
|
|
protected:
|
|
std::streamsize xsputn(const char *s, std::streamsize n) override {
|
|
sink_.write(s, static_cast<size_t>(n));
|
|
return n;
|
|
}
|
|
|
|
private:
|
|
DataSink &sink_;
|
|
};
|
|
|
|
data_sink_streambuf sb_;
|
|
};
|
|
|
|
using ContentProvider =
|
|
std::function<bool(size_t offset, size_t length, DataSink &sink)>;
|
|
|
|
using ContentProviderWithoutLength =
|
|
std::function<bool(size_t offset, DataSink &sink)>;
|
|
|
|
using ContentProviderResourceReleaser = std::function<void(bool success)>;
|
|
|
|
struct FormDataProvider {
|
|
std::string name;
|
|
ContentProviderWithoutLength provider;
|
|
std::string filename;
|
|
std::string content_type;
|
|
};
|
|
using FormDataProviderItems = std::vector<FormDataProvider>;
|
|
|
|
using ContentReceiverWithProgress = std::function<bool(
|
|
const char *data, size_t data_length, size_t offset, size_t total_length)>;
|
|
|
|
using ContentReceiver =
|
|
std::function<bool(const char *data, size_t data_length)>;
|
|
|
|
using FormDataHeader = std::function<bool(const FormData &file)>;
|
|
|
|
class ContentReader {
|
|
public:
|
|
using Reader = std::function<bool(ContentReceiver receiver)>;
|
|
using FormDataReader =
|
|
std::function<bool(FormDataHeader header, ContentReceiver receiver)>;
|
|
|
|
ContentReader(Reader reader, FormDataReader multipart_reader)
|
|
: reader_(std::move(reader)),
|
|
formdata_reader_(std::move(multipart_reader)) {}
|
|
|
|
bool operator()(FormDataHeader header, ContentReceiver receiver) const {
|
|
return formdata_reader_(std::move(header), std::move(receiver));
|
|
}
|
|
|
|
bool operator()(ContentReceiver receiver) const {
|
|
return reader_(std::move(receiver));
|
|
}
|
|
|
|
Reader reader_;
|
|
FormDataReader formdata_reader_;
|
|
};
|
|
|
|
// A pair of signed sizes describing one range, and a list of such pairs.
using Range = std::pair<ssize_t, ssize_t>;
using Ranges = std::vector<Range>;
|
|
|
|
struct Request {
|
|
std::string method;
|
|
std::string path;
|
|
std::string matched_route;
|
|
Params params;
|
|
Headers headers;
|
|
Headers trailers;
|
|
std::string body;
|
|
|
|
std::string remote_addr;
|
|
int remote_port = -1;
|
|
std::string local_addr;
|
|
int local_port = -1;
|
|
|
|
// for server
|
|
std::string version;
|
|
std::string target;
|
|
MultipartFormData form;
|
|
Ranges ranges;
|
|
Match matches;
|
|
std::unordered_map<std::string, std::string> path_params;
|
|
std::function<bool()> is_connection_closed = []() { return true; };
|
|
|
|
// for client
|
|
std::vector<std::string> accept_content_types;
|
|
ResponseHandler response_handler;
|
|
ContentReceiverWithProgress content_receiver;
|
|
DownloadProgress download_progress;
|
|
UploadProgress upload_progress;
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
const SSL *ssl = nullptr;
|
|
#endif
|
|
|
|
bool has_header(const std::string &key) const;
|
|
std::string get_header_value(const std::string &key, const char *def = "",
|
|
size_t id = 0) const;
|
|
size_t get_header_value_u64(const std::string &key, size_t def = 0,
|
|
size_t id = 0) const;
|
|
size_t get_header_value_count(const std::string &key) const;
|
|
void set_header(const std::string &key, const std::string &val);
|
|
|
|
bool has_trailer(const std::string &key) const;
|
|
std::string get_trailer_value(const std::string &key, size_t id = 0) const;
|
|
size_t get_trailer_value_count(const std::string &key) const;
|
|
|
|
bool has_param(const std::string &key) const;
|
|
std::string get_param_value(const std::string &key, size_t id = 0) const;
|
|
size_t get_param_value_count(const std::string &key) const;
|
|
|
|
bool is_multipart_form_data() const;
|
|
|
|
// private members...
|
|
size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT;
|
|
size_t content_length_ = 0;
|
|
ContentProvider content_provider_;
|
|
bool is_chunked_content_provider_ = false;
|
|
size_t authorization_count_ = 0;
|
|
std::chrono::time_point<std::chrono::steady_clock> start_time_ =
|
|
(std::chrono::steady_clock::time_point::min)();
|
|
};
|
|
|
|
struct Response {
|
|
std::string version;
|
|
int status = -1;
|
|
std::string reason;
|
|
Headers headers;
|
|
Headers trailers;
|
|
std::string body;
|
|
std::string location; // Redirect location
|
|
|
|
bool has_header(const std::string &key) const;
|
|
std::string get_header_value(const std::string &key, const char *def = "",
|
|
size_t id = 0) const;
|
|
size_t get_header_value_u64(const std::string &key, size_t def = 0,
|
|
size_t id = 0) const;
|
|
size_t get_header_value_count(const std::string &key) const;
|
|
void set_header(const std::string &key, const std::string &val);
|
|
|
|
bool has_trailer(const std::string &key) const;
|
|
std::string get_trailer_value(const std::string &key, size_t id = 0) const;
|
|
size_t get_trailer_value_count(const std::string &key) const;
|
|
|
|
void set_redirect(const std::string &url, int status = StatusCode::Found_302);
|
|
void set_content(const char *s, size_t n, const std::string &content_type);
|
|
void set_content(const std::string &s, const std::string &content_type);
|
|
void set_content(std::string &&s, const std::string &content_type);
|
|
|
|
void set_content_provider(
|
|
size_t length, const std::string &content_type, ContentProvider provider,
|
|
ContentProviderResourceReleaser resource_releaser = nullptr);
|
|
|
|
void set_content_provider(
|
|
const std::string &content_type, ContentProviderWithoutLength provider,
|
|
ContentProviderResourceReleaser resource_releaser = nullptr);
|
|
|
|
void set_chunked_content_provider(
|
|
const std::string &content_type, ContentProviderWithoutLength provider,
|
|
ContentProviderResourceReleaser resource_releaser = nullptr);
|
|
|
|
void set_file_content(const std::string &path,
|
|
const std::string &content_type);
|
|
void set_file_content(const std::string &path);
|
|
|
|
Response() = default;
|
|
Response(const Response &) = default;
|
|
Response &operator=(const Response &) = default;
|
|
Response(Response &&) = default;
|
|
Response &operator=(Response &&) = default;
|
|
~Response() {
|
|
if (content_provider_resource_releaser_) {
|
|
content_provider_resource_releaser_(content_provider_success_);
|
|
}
|
|
}
|
|
|
|
// private members...
|
|
size_t content_length_ = 0;
|
|
ContentProvider content_provider_;
|
|
ContentProviderResourceReleaser content_provider_resource_releaser_;
|
|
bool is_chunked_content_provider_ = false;
|
|
bool content_provider_success_ = false;
|
|
std::string file_content_path_;
|
|
std::string file_content_content_type_;
|
|
};
|
|
|
|
class Stream {
|
|
public:
|
|
virtual ~Stream() = default;
|
|
|
|
virtual bool is_readable() const = 0;
|
|
virtual bool wait_readable() const = 0;
|
|
virtual bool wait_writable() const = 0;
|
|
|
|
virtual ssize_t read(char *ptr, size_t size) = 0;
|
|
virtual ssize_t write(const char *ptr, size_t size) = 0;
|
|
virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0;
|
|
virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0;
|
|
virtual socket_t socket() const = 0;
|
|
|
|
virtual time_t duration() const = 0;
|
|
|
|
ssize_t write(const char *ptr);
|
|
ssize_t write(const std::string &s);
|
|
};
|
|
|
|
// Abstract task queue: implementations provide enqueue() and shutdown().
// on_idle() is an optional hook whose default implementation does nothing.
class TaskQueue {
public:
  TaskQueue() = default;
  virtual ~TaskQueue() = default;

  // Submits `fn`; the returned bool reports whether it was accepted.
  virtual bool enqueue(std::function<void()> fn) = 0;
  virtual void shutdown() = 0;

  virtual void on_idle() {}
};
|
|
|
|
class ThreadPool final : public TaskQueue {
|
|
public:
|
|
explicit ThreadPool(size_t n, size_t mqr = 0)
|
|
: shutdown_(false), max_queued_requests_(mqr) {
|
|
while (n) {
|
|
threads_.emplace_back(worker(*this));
|
|
n--;
|
|
}
|
|
}
|
|
|
|
ThreadPool(const ThreadPool &) = delete;
|
|
~ThreadPool() override = default;
|
|
|
|
bool enqueue(std::function<void()> fn) override {
|
|
{
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) {
|
|
return false;
|
|
}
|
|
jobs_.push_back(std::move(fn));
|
|
}
|
|
|
|
cond_.notify_one();
|
|
return true;
|
|
}
|
|
|
|
void shutdown() override {
|
|
// Stop all worker threads...
|
|
{
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
shutdown_ = true;
|
|
}
|
|
|
|
cond_.notify_all();
|
|
|
|
// Join...
|
|
for (auto &t : threads_) {
|
|
t.join();
|
|
}
|
|
}
|
|
|
|
private:
|
|
struct worker {
|
|
explicit worker(ThreadPool &pool) : pool_(pool) {}
|
|
|
|
void operator()() {
|
|
for (;;) {
|
|
std::function<void()> fn;
|
|
{
|
|
std::unique_lock<std::mutex> lock(pool_.mutex_);
|
|
|
|
pool_.cond_.wait(
|
|
lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; });
|
|
|
|
if (pool_.shutdown_ && pool_.jobs_.empty()) { break; }
|
|
|
|
fn = pool_.jobs_.front();
|
|
pool_.jobs_.pop_front();
|
|
}
|
|
|
|
assert(true == static_cast<bool>(fn));
|
|
fn();
|
|
}
|
|
|
|
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \
|
|
!defined(LIBRESSL_VERSION_NUMBER)
|
|
OPENSSL_thread_stop();
|
|
#endif
|
|
}
|
|
|
|
ThreadPool &pool_;
|
|
};
|
|
friend struct worker;
|
|
|
|
std::vector<std::thread> threads_;
|
|
std::list<std::function<void()>> jobs_;
|
|
|
|
bool shutdown_;
|
|
size_t max_queued_requests_ = 0;
|
|
|
|
std::condition_variable cond_;
|
|
std::mutex mutex_;
|
|
};
|
|
|
|
using Logger = std::function<void(const Request &, const Response &)>;
|
|
|
|
using SocketOptions = std::function<void(socket_t sock)>;
|
|
|
|
namespace detail {
|
|
|
|
bool set_socket_opt_impl(socket_t sock, int level, int optname,
|
|
const void *optval, socklen_t optlen);
|
|
bool set_socket_opt(socket_t sock, int level, int optname, int opt);
|
|
bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec,
|
|
time_t usec);
|
|
|
|
} // namespace detail
|
|
|
|
void default_socket_options(socket_t sock);
|
|
|
|
const char *status_message(int status);
|
|
|
|
std::string get_bearer_token_auth(const Request &req);
|
|
|
|
namespace detail {
|
|
|
|
class MatcherBase {
|
|
public:
|
|
MatcherBase(std::string pattern) : pattern_(pattern) {}
|
|
virtual ~MatcherBase() = default;
|
|
|
|
const std::string &pattern() const { return pattern_; }
|
|
|
|
// Match request path and populate its matches and
|
|
virtual bool match(Request &request) const = 0;
|
|
|
|
private:
|
|
std::string pattern_;
|
|
};
|
|
|
|
/**
|
|
* Captures parameters in request path and stores them in Request::path_params
|
|
*
|
|
* Capture name is a substring of a pattern from : to /.
|
|
* The rest of the pattern is matched against the request path directly
|
|
* Parameters are captured starting from the next character after
|
|
* the end of the last matched static pattern fragment until the next /.
|
|
*
|
|
* Example pattern:
|
|
* "/path/fragments/:capture/more/fragments/:second_capture"
|
|
* Static fragments:
|
|
* "/path/fragments/", "more/fragments/"
|
|
*
|
|
* Given the following request path:
|
|
* "/path/fragments/:1/more/fragments/:2"
|
|
* the resulting capture will be
|
|
* {{"capture", "1"}, {"second_capture", "2"}}
|
|
*/
|
|
class PathParamsMatcher final : public MatcherBase {
|
|
public:
|
|
PathParamsMatcher(const std::string &pattern);
|
|
|
|
bool match(Request &request) const override;
|
|
|
|
private:
|
|
// Treat segment separators as the end of path parameter capture
|
|
// Does not need to handle query parameters as they are parsed before path
|
|
// matching
|
|
static constexpr char separator = '/';
|
|
|
|
// Contains static path fragments to match against, excluding the '/' after
|
|
// path params
|
|
// Fragments are separated by path params
|
|
std::vector<std::string> static_fragments_;
|
|
// Stores the names of the path parameters to be used as keys in the
|
|
// Request::path_params map
|
|
std::vector<std::string> param_names_;
|
|
};
|
|
|
|
/**
|
|
* Performs std::regex_match on request path
|
|
* and stores the result in Request::matches
|
|
*
|
|
* Note that regex match is performed directly on the whole request.
|
|
* This means that wildcard patterns may match multiple path segments with /:
|
|
* "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end".
|
|
*/
|
|
class RegexMatcher final : public MatcherBase {
|
|
public:
|
|
RegexMatcher(const std::string &pattern)
|
|
: MatcherBase(pattern), regex_(pattern) {}
|
|
|
|
bool match(Request &request) const override;
|
|
|
|
private:
|
|
std::regex regex_;
|
|
};
|
|
|
|
// Takes a stream and a header collection and returns a signed size
// (declaration only; defined elsewhere).
ssize_t write_headers(Stream &strm, const Headers &headers);
|
|
|
|
} // namespace detail
|
|
|
|
class Server {
|
|
public:
|
|
using Handler = std::function<void(const Request &, Response &)>;
|
|
|
|
using ExceptionHandler =
|
|
std::function<void(const Request &, Response &, std::exception_ptr ep)>;
|
|
|
|
enum class HandlerResponse {
|
|
Handled,
|
|
Unhandled,
|
|
};
|
|
using HandlerWithResponse =
|
|
std::function<HandlerResponse(const Request &, Response &)>;
|
|
|
|
using HandlerWithContentReader = std::function<void(
|
|
const Request &, Response &, const ContentReader &content_reader)>;
|
|
|
|
using Expect100ContinueHandler =
|
|
std::function<int(const Request &, Response &)>;
|
|
|
|
Server();
|
|
|
|
virtual ~Server();
|
|
|
|
virtual bool is_valid() const;
|
|
|
|
Server &Get(const std::string &pattern, Handler handler);
|
|
Server &Post(const std::string &pattern, Handler handler);
|
|
Server &Post(const std::string &pattern, HandlerWithContentReader handler);
|
|
Server &Put(const std::string &pattern, Handler handler);
|
|
Server &Put(const std::string &pattern, HandlerWithContentReader handler);
|
|
Server &Patch(const std::string &pattern, Handler handler);
|
|
Server &Patch(const std::string &pattern, HandlerWithContentReader handler);
|
|
Server &Delete(const std::string &pattern, Handler handler);
|
|
Server &Delete(const std::string &pattern, HandlerWithContentReader handler);
|
|
Server &Options(const std::string &pattern, Handler handler);
|
|
|
|
bool set_base_dir(const std::string &dir,
|
|
const std::string &mount_point = std::string());
|
|
bool set_mount_point(const std::string &mount_point, const std::string &dir,
|
|
Headers headers = Headers());
|
|
bool remove_mount_point(const std::string &mount_point);
|
|
Server &set_file_extension_and_mimetype_mapping(const std::string &ext,
|
|
const std::string &mime);
|
|
Server &set_default_file_mimetype(const std::string &mime);
|
|
Server &set_file_request_handler(Handler handler);
|
|
|
|
template <class ErrorHandlerFunc>
|
|
Server &set_error_handler(ErrorHandlerFunc &&handler) {
|
|
return set_error_handler_core(
|
|
std::forward<ErrorHandlerFunc>(handler),
|
|
std::is_convertible<ErrorHandlerFunc, HandlerWithResponse>{});
|
|
}
|
|
|
|
Server &set_exception_handler(ExceptionHandler handler);
|
|
|
|
Server &set_pre_routing_handler(HandlerWithResponse handler);
|
|
Server &set_post_routing_handler(Handler handler);
|
|
|
|
Server &set_pre_request_handler(HandlerWithResponse handler);
|
|
|
|
Server &set_expect_100_continue_handler(Expect100ContinueHandler handler);
|
|
Server &set_logger(Logger logger);
|
|
Server &set_pre_compression_logger(Logger logger);
|
|
|
|
Server &set_address_family(int family);
|
|
Server &set_tcp_nodelay(bool on);
|
|
Server &set_ipv6_v6only(bool on);
|
|
Server &set_socket_options(SocketOptions socket_options);
|
|
|
|
Server &set_default_headers(Headers headers);
|
|
Server &
|
|
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
|
|
|
|
Server &set_keep_alive_max_count(size_t count);
|
|
Server &set_keep_alive_timeout(time_t sec);
|
|
|
|
Server &set_read_timeout(time_t sec, time_t usec = 0);
|
|
template <class Rep, class Period>
|
|
Server &set_read_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
Server &set_write_timeout(time_t sec, time_t usec = 0);
|
|
template <class Rep, class Period>
|
|
Server &set_write_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
Server &set_idle_interval(time_t sec, time_t usec = 0);
|
|
template <class Rep, class Period>
|
|
Server &set_idle_interval(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
Server &set_payload_max_length(size_t length);
|
|
|
|
bool bind_to_port(const std::string &host, int port, int socket_flags = 0);
|
|
int bind_to_any_port(const std::string &host, int socket_flags = 0);
|
|
bool listen_after_bind();
|
|
|
|
bool listen(const std::string &host, int port, int socket_flags = 0);
|
|
|
|
bool is_running() const;
|
|
void wait_until_ready() const;
|
|
void stop();
|
|
void decommission();
|
|
|
|
std::function<TaskQueue *(void)> new_task_queue;
|
|
|
|
protected:
|
|
bool process_request(Stream &strm, const std::string &remote_addr,
|
|
int remote_port, const std::string &local_addr,
|
|
int local_port, bool close_connection,
|
|
bool &connection_closed,
|
|
const std::function<void(Request &)> &setup_request);
|
|
|
|
std::atomic<socket_t> svr_sock_{INVALID_SOCKET};
|
|
size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
|
|
time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND;
|
|
time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND;
|
|
time_t read_timeout_usec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND;
|
|
time_t write_timeout_sec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND;
|
|
time_t write_timeout_usec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND;
|
|
time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND;
|
|
time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND;
|
|
size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH;
|
|
|
|
private:
|
|
using Handlers =
|
|
std::vector<std::pair<std::unique_ptr<detail::MatcherBase>, Handler>>;
|
|
using HandlersForContentReader =
|
|
std::vector<std::pair<std::unique_ptr<detail::MatcherBase>,
|
|
HandlerWithContentReader>>;
|
|
|
|
static std::unique_ptr<detail::MatcherBase>
|
|
make_matcher(const std::string &pattern);
|
|
|
|
Server &set_error_handler_core(HandlerWithResponse handler, std::true_type);
|
|
Server &set_error_handler_core(Handler handler, std::false_type);
|
|
|
|
socket_t create_server_socket(const std::string &host, int port,
|
|
int socket_flags,
|
|
SocketOptions socket_options) const;
|
|
int bind_internal(const std::string &host, int port, int socket_flags);
|
|
bool listen_internal();
|
|
|
|
bool routing(Request &req, Response &res, Stream &strm);
|
|
bool handle_file_request(const Request &req, Response &res);
|
|
bool dispatch_request(Request &req, Response &res,
|
|
const Handlers &handlers) const;
|
|
bool dispatch_request_for_content_reader(
|
|
Request &req, Response &res, ContentReader content_reader,
|
|
const HandlersForContentReader &handlers) const;
|
|
|
|
bool parse_request_line(const char *s, Request &req) const;
|
|
void apply_ranges(const Request &req, Response &res,
|
|
std::string &content_type, std::string &boundary) const;
|
|
bool write_response(Stream &strm, bool close_connection, Request &req,
|
|
Response &res);
|
|
bool write_response_with_content(Stream &strm, bool close_connection,
|
|
const Request &req, Response &res);
|
|
bool write_response_core(Stream &strm, bool close_connection,
|
|
const Request &req, Response &res,
|
|
bool need_apply_ranges);
|
|
bool write_content_with_provider(Stream &strm, const Request &req,
|
|
Response &res, const std::string &boundary,
|
|
const std::string &content_type);
|
|
bool read_content(Stream &strm, Request &req, Response &res);
|
|
bool read_content_with_content_receiver(Stream &strm, Request &req,
|
|
Response &res,
|
|
ContentReceiver receiver,
|
|
FormDataHeader multipart_header,
|
|
ContentReceiver multipart_receiver);
|
|
bool read_content_core(Stream &strm, Request &req, Response &res,
|
|
ContentReceiver receiver,
|
|
FormDataHeader multipart_header,
|
|
ContentReceiver multipart_receiver) const;
|
|
|
|
virtual bool process_and_close_socket(socket_t sock);
|
|
|
|
std::atomic<bool> is_running_{false};
|
|
std::atomic<bool> is_decommissioned{false};
|
|
|
|
struct MountPointEntry {
|
|
std::string mount_point;
|
|
std::string base_dir;
|
|
Headers headers;
|
|
};
|
|
std::vector<MountPointEntry> base_dirs_;
|
|
std::map<std::string, std::string> file_extension_and_mimetype_map_;
|
|
std::string default_file_mimetype_ = "application/octet-stream";
|
|
Handler file_request_handler_;
|
|
|
|
Handlers get_handlers_;
|
|
Handlers post_handlers_;
|
|
HandlersForContentReader post_handlers_for_content_reader_;
|
|
Handlers put_handlers_;
|
|
HandlersForContentReader put_handlers_for_content_reader_;
|
|
Handlers patch_handlers_;
|
|
HandlersForContentReader patch_handlers_for_content_reader_;
|
|
Handlers delete_handlers_;
|
|
HandlersForContentReader delete_handlers_for_content_reader_;
|
|
Handlers options_handlers_;
|
|
|
|
HandlerWithResponse error_handler_;
|
|
ExceptionHandler exception_handler_;
|
|
HandlerWithResponse pre_routing_handler_;
|
|
Handler post_routing_handler_;
|
|
HandlerWithResponse pre_request_handler_;
|
|
Expect100ContinueHandler expect_100_continue_handler_;
|
|
|
|
Logger logger_;
|
|
Logger pre_compression_logger_;
|
|
|
|
int address_family_ = AF_UNSPEC;
|
|
bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY;
|
|
bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY;
|
|
SocketOptions socket_options_ = default_socket_options;
|
|
|
|
Headers default_headers_;
|
|
std::function<ssize_t(Stream &, Headers &)> header_writer_ =
|
|
detail::write_headers;
|
|
};
|
|
|
|
// Error codes. `Success` is 0; the remaining enumerators take consecutive
// values in declaration order.
enum class Error {
  Success = 0,
  Unknown,
  Connection,
  BindIPAddress,
  Read,
  Write,
  ExceedRedirectCount,
  Canceled,
  SSLConnection,
  SSLLoadingCerts,
  SSLServerVerification,
  SSLServerHostnameVerification,
  UnsupportedMultipartBoundaryChars,
  Compression,
  ConnectionTimeout,
  ProxyConnection,

  // For internal use only
  SSLPeerCouldBeClosed_,
};
|
|
|
|
std::string to_string(Error error);
|
|
|
|
std::ostream &operator<<(std::ostream &os, const Error &obj);
|
|
|
|
class Result {
|
|
public:
|
|
Result() = default;
|
|
Result(std::unique_ptr<Response> &&res, Error err,
|
|
Headers &&request_headers = Headers{})
|
|
: res_(std::move(res)), err_(err),
|
|
request_headers_(std::move(request_headers)) {}
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
Result(std::unique_ptr<Response> &&res, Error err, Headers &&request_headers,
|
|
int ssl_error)
|
|
: res_(std::move(res)), err_(err),
|
|
request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {}
|
|
Result(std::unique_ptr<Response> &&res, Error err, Headers &&request_headers,
|
|
int ssl_error, unsigned long ssl_openssl_error)
|
|
: res_(std::move(res)), err_(err),
|
|
request_headers_(std::move(request_headers)), ssl_error_(ssl_error),
|
|
ssl_openssl_error_(ssl_openssl_error) {}
|
|
#endif
|
|
// Response
|
|
operator bool() const { return res_ != nullptr; }
|
|
bool operator==(std::nullptr_t) const { return res_ == nullptr; }
|
|
bool operator!=(std::nullptr_t) const { return res_ != nullptr; }
|
|
const Response &value() const { return *res_; }
|
|
Response &value() { return *res_; }
|
|
const Response &operator*() const { return *res_; }
|
|
Response &operator*() { return *res_; }
|
|
const Response *operator->() const { return res_.get(); }
|
|
Response *operator->() { return res_.get(); }
|
|
|
|
// Error
|
|
Error error() const { return err_; }
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
// SSL Error
|
|
int ssl_error() const { return ssl_error_; }
|
|
// OpenSSL Error
|
|
unsigned long ssl_openssl_error() const { return ssl_openssl_error_; }
|
|
#endif
|
|
|
|
// Request Headers
|
|
bool has_request_header(const std::string &key) const;
|
|
std::string get_request_header_value(const std::string &key,
|
|
const char *def = "",
|
|
size_t id = 0) const;
|
|
size_t get_request_header_value_u64(const std::string &key, size_t def = 0,
|
|
size_t id = 0) const;
|
|
size_t get_request_header_value_count(const std::string &key) const;
|
|
|
|
private:
|
|
std::unique_ptr<Response> res_;
|
|
Error err_ = Error::Unknown;
|
|
Headers request_headers_;
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
int ssl_error_ = 0;
|
|
unsigned long ssl_openssl_error_ = 0;
|
|
#endif
|
|
};
|
|
|
|
class ClientImpl {
|
|
public:
|
|
explicit ClientImpl(const std::string &host);
|
|
|
|
explicit ClientImpl(const std::string &host, int port);
|
|
|
|
explicit ClientImpl(const std::string &host, int port,
|
|
const std::string &client_cert_path,
|
|
const std::string &client_key_path);
|
|
|
|
virtual ~ClientImpl();
|
|
|
|
virtual bool is_valid() const;
|
|
|
|
// clang-format off
|
|
Result Get(const std::string &path, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Params ¶ms, const Headers &headers, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
|
|
Result Head(const std::string &path);
|
|
Result Head(const std::string &path, const Headers &headers);
|
|
|
|
Result Post(const std::string &path);
|
|
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Params ¶ms);
|
|
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers);
|
|
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const Params ¶ms);
|
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
|
|
Result Put(const std::string &path);
|
|
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Params ¶ms);
|
|
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers);
|
|
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const Params ¶ms);
|
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
|
|
Result Patch(const std::string &path);
|
|
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Params ¶ms);
|
|
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const Params ¶ms);
|
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
|
|
Result Delete(const std::string &path, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Params ¶ms, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Headers &headers, const Params ¶ms, DownloadProgress progress = nullptr);
|
|
|
|
Result Options(const std::string &path);
|
|
Result Options(const std::string &path, const Headers &headers);
|
|
// clang-format on
|
|
|
|
bool send(Request &req, Response &res, Error &error);
|
|
Result send(const Request &req);
|
|
|
|
void stop();
|
|
|
|
std::string host() const;
|
|
int port() const;
|
|
|
|
size_t is_socket_open() const;
|
|
socket_t socket() const;
|
|
|
|
void set_hostname_addr_map(std::map<std::string, std::string> addr_map);
|
|
|
|
void set_default_headers(Headers headers);
|
|
|
|
void
|
|
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
|
|
|
|
void set_address_family(int family);
|
|
void set_tcp_nodelay(bool on);
|
|
void set_ipv6_v6only(bool on);
|
|
void set_socket_options(SocketOptions socket_options);
|
|
|
|
void set_connection_timeout(time_t sec, time_t usec = 0);
|
|
template <class Rep, class Period>
|
|
void
|
|
set_connection_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
void set_read_timeout(time_t sec, time_t usec = 0);
|
|
template <class Rep, class Period>
|
|
void set_read_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
void set_write_timeout(time_t sec, time_t usec = 0);
|
|
template <class Rep, class Period>
|
|
void set_write_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
void set_max_timeout(time_t msec);
|
|
template <class Rep, class Period>
|
|
void set_max_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
void set_basic_auth(const std::string &username, const std::string &password);
|
|
void set_bearer_token_auth(const std::string &token);
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
void set_digest_auth(const std::string &username,
|
|
const std::string &password);
|
|
#endif
|
|
|
|
void set_keep_alive(bool on);
|
|
void set_follow_location(bool on);
|
|
|
|
void set_path_encode(bool on);
|
|
|
|
void set_compress(bool on);
|
|
|
|
void set_decompress(bool on);
|
|
|
|
void set_interface(const std::string &intf);
|
|
|
|
void set_proxy(const std::string &host, int port);
|
|
void set_proxy_basic_auth(const std::string &username,
|
|
const std::string &password);
|
|
void set_proxy_bearer_token_auth(const std::string &token);
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
void set_proxy_digest_auth(const std::string &username,
|
|
const std::string &password);
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
void set_ca_cert_path(const std::string &ca_cert_file_path,
|
|
const std::string &ca_cert_dir_path = std::string());
|
|
void set_ca_cert_store(X509_STORE *ca_cert_store);
|
|
X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const;
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
void enable_server_certificate_verification(bool enabled);
|
|
void enable_server_hostname_verification(bool enabled);
|
|
void set_server_certificate_verifier(
|
|
std::function<SSLVerifierResponse(SSL *ssl)> verifier);
|
|
#endif
|
|
|
|
void set_logger(Logger logger);
|
|
|
|
protected:
|
|
struct Socket {
|
|
socket_t sock = INVALID_SOCKET;
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
SSL *ssl = nullptr;
|
|
#endif
|
|
|
|
bool is_open() const { return sock != INVALID_SOCKET; }
|
|
};
|
|
|
|
virtual bool create_and_connect_socket(Socket &socket, Error &error);
|
|
|
|
// All of:
|
|
// shutdown_ssl
|
|
// shutdown_socket
|
|
// close_socket
|
|
// should ONLY be called when socket_mutex_ is locked.
|
|
// Also, shutdown_ssl and close_socket should also NOT be called concurrently
|
|
// with a DIFFERENT thread sending requests using that socket.
|
|
virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully);
|
|
void shutdown_socket(Socket &socket) const;
|
|
void close_socket(Socket &socket);
|
|
|
|
bool process_request(Stream &strm, Request &req, Response &res,
|
|
bool close_connection, Error &error);
|
|
|
|
bool write_content_with_provider(Stream &strm, const Request &req,
|
|
Error &error) const;
|
|
|
|
void copy_settings(const ClientImpl &rhs);
|
|
|
|
// Socket endpoint information
|
|
const std::string host_;
|
|
const int port_;
|
|
const std::string host_and_port_;
|
|
|
|
// Current open socket
|
|
Socket socket_;
|
|
mutable std::mutex socket_mutex_;
|
|
std::recursive_mutex request_mutex_;
|
|
|
|
// These are all protected under socket_mutex
|
|
size_t socket_requests_in_flight_ = 0;
|
|
std::thread::id socket_requests_are_from_thread_ = std::thread::id();
|
|
bool socket_should_be_closed_when_request_is_done_ = false;
|
|
|
|
// Hostname-IP map
|
|
std::map<std::string, std::string> addr_map_;
|
|
|
|
// Default headers
|
|
Headers default_headers_;
|
|
|
|
// Header writer
|
|
std::function<ssize_t(Stream &, Headers &)> header_writer_ =
|
|
detail::write_headers;
|
|
|
|
// Settings
|
|
std::string client_cert_path_;
|
|
std::string client_key_path_;
|
|
|
|
time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND;
|
|
time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND;
|
|
time_t read_timeout_sec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND;
|
|
time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND;
|
|
time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND;
|
|
time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND;
|
|
time_t max_timeout_msec_ = CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND;
|
|
|
|
std::string basic_auth_username_;
|
|
std::string basic_auth_password_;
|
|
std::string bearer_token_auth_token_;
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
std::string digest_auth_username_;
|
|
std::string digest_auth_password_;
|
|
#endif
|
|
|
|
bool keep_alive_ = false;
|
|
bool follow_location_ = false;
|
|
|
|
bool path_encode_ = true;
|
|
|
|
int address_family_ = AF_UNSPEC;
|
|
bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY;
|
|
bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY;
|
|
SocketOptions socket_options_ = nullptr;
|
|
|
|
bool compress_ = false;
|
|
bool decompress_ = true;
|
|
|
|
std::string interface_;
|
|
|
|
std::string proxy_host_;
|
|
int proxy_port_ = -1;
|
|
|
|
std::string proxy_basic_auth_username_;
|
|
std::string proxy_basic_auth_password_;
|
|
std::string proxy_bearer_token_auth_token_;
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
std::string proxy_digest_auth_username_;
|
|
std::string proxy_digest_auth_password_;
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
std::string ca_cert_file_path_;
|
|
std::string ca_cert_dir_path_;
|
|
|
|
X509_STORE *ca_cert_store_ = nullptr;
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
bool server_certificate_verification_ = true;
|
|
bool server_hostname_verification_ = true;
|
|
std::function<SSLVerifierResponse(SSL *ssl)> server_certificate_verifier_;
|
|
#endif
|
|
|
|
Logger logger_;
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
int last_ssl_error_ = 0;
|
|
unsigned long last_openssl_error_ = 0;
|
|
#endif
|
|
|
|
private:
|
|
bool send_(Request &req, Response &res, Error &error);
|
|
Result send_(Request &&req);
|
|
|
|
socket_t create_client_socket(Error &error) const;
|
|
bool read_response_line(Stream &strm, const Request &req,
|
|
Response &res) const;
|
|
bool write_request(Stream &strm, Request &req, bool close_connection,
|
|
Error &error);
|
|
bool redirect(Request &req, Response &res, Error &error);
|
|
bool create_redirect_client(const std::string &scheme,
|
|
const std::string &host, int port, Request &req,
|
|
Response &res, const std::string &path,
|
|
const std::string &location, Error &error);
|
|
template <typename ClientType> void setup_redirect_client(ClientType &client);
|
|
bool handle_request(Stream &strm, Request &req, Response &res,
|
|
bool close_connection, Error &error);
|
|
std::unique_ptr<Response> send_with_content_provider(
|
|
Request &req, const char *body, size_t content_length,
|
|
ContentProvider content_provider,
|
|
ContentProviderWithoutLength content_provider_without_length,
|
|
const std::string &content_type, Error &error);
|
|
Result send_with_content_provider(
|
|
const std::string &method, const std::string &path,
|
|
const Headers &headers, const char *body, size_t content_length,
|
|
ContentProvider content_provider,
|
|
ContentProviderWithoutLength content_provider_without_length,
|
|
const std::string &content_type, UploadProgress progress);
|
|
ContentProviderWithoutLength get_multipart_content_provider(
|
|
const std::string &boundary, const UploadFormDataItems &items,
|
|
const FormDataProviderItems &provider_items) const;
|
|
|
|
std::string adjust_host_string(const std::string &host) const;
|
|
|
|
virtual bool
|
|
process_socket(const Socket &socket,
|
|
std::chrono::time_point<std::chrono::steady_clock> start_time,
|
|
std::function<bool(Stream &strm)> callback);
|
|
virtual bool is_ssl() const;
|
|
};
|
|
|
|
class Client {
|
|
public:
|
|
// Universal interface
|
|
explicit Client(const std::string &scheme_host_port);
|
|
|
|
explicit Client(const std::string &scheme_host_port,
|
|
const std::string &client_cert_path,
|
|
const std::string &client_key_path);
|
|
|
|
// HTTP only interface
|
|
explicit Client(const std::string &host, int port);
|
|
|
|
explicit Client(const std::string &host, int port,
|
|
const std::string &client_cert_path,
|
|
const std::string &client_key_path);
|
|
|
|
Client(Client &&) = default;
|
|
Client &operator=(Client &&) = default;
|
|
|
|
~Client();
|
|
|
|
bool is_valid() const;
|
|
|
|
// clang-format off
|
|
Result Get(const std::string &path, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Params ¶ms, const Headers &headers, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
|
|
Result Head(const std::string &path);
|
|
Result Head(const std::string &path, const Headers &headers);
|
|
|
|
Result Post(const std::string &path);
|
|
Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Params ¶ms);
|
|
Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers);
|
|
Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const Params ¶ms);
|
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr);
|
|
Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
|
|
Result Put(const std::string &path);
|
|
Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Params ¶ms);
|
|
Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers);
|
|
Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const Params ¶ms);
|
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr);
|
|
Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
|
|
Result Patch(const std::string &path);
|
|
Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Params ¶ms);
|
|
Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers);
|
|
Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const Params ¶ms);
|
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr);
|
|
Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr);
|
|
|
|
Result Delete(const std::string &path, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Params ¶ms, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr);
|
|
Result Delete(const std::string &path, const Headers &headers, const Params ¶ms, DownloadProgress progress = nullptr);
|
|
|
|
Result Options(const std::string &path);
|
|
Result Options(const std::string &path, const Headers &headers);
|
|
// clang-format on
|
|
|
|
bool send(Request &req, Response &res, Error &error);
|
|
Result send(const Request &req);
|
|
|
|
void stop();
|
|
|
|
std::string host() const;
|
|
int port() const;
|
|
|
|
size_t is_socket_open() const;
|
|
socket_t socket() const;
|
|
|
|
void set_hostname_addr_map(std::map<std::string, std::string> addr_map);
|
|
|
|
void set_default_headers(Headers headers);
|
|
|
|
void
|
|
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
|
|
|
|
void set_address_family(int family);
|
|
void set_tcp_nodelay(bool on);
|
|
void set_socket_options(SocketOptions socket_options);
|
|
|
|
void set_connection_timeout(time_t sec, time_t usec = 0);
|
|
template <class Rep, class Period>
|
|
void
|
|
set_connection_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
void set_read_timeout(time_t sec, time_t usec = 0);
|
|
template <class Rep, class Period>
|
|
void set_read_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
void set_write_timeout(time_t sec, time_t usec = 0);
|
|
template <class Rep, class Period>
|
|
void set_write_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
void set_max_timeout(time_t msec);
|
|
template <class Rep, class Period>
|
|
void set_max_timeout(const std::chrono::duration<Rep, Period> &duration);
|
|
|
|
void set_basic_auth(const std::string &username, const std::string &password);
|
|
void set_bearer_token_auth(const std::string &token);
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
void set_digest_auth(const std::string &username,
|
|
const std::string &password);
|
|
#endif
|
|
|
|
void set_keep_alive(bool on);
|
|
void set_follow_location(bool on);
|
|
|
|
void set_path_encode(bool on);
|
|
void set_url_encode(bool on);
|
|
|
|
void set_compress(bool on);
|
|
|
|
void set_decompress(bool on);
|
|
|
|
void set_interface(const std::string &intf);
|
|
|
|
void set_proxy(const std::string &host, int port);
|
|
void set_proxy_basic_auth(const std::string &username,
|
|
const std::string &password);
|
|
void set_proxy_bearer_token_auth(const std::string &token);
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
void set_proxy_digest_auth(const std::string &username,
|
|
const std::string &password);
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
void enable_server_certificate_verification(bool enabled);
|
|
void enable_server_hostname_verification(bool enabled);
|
|
void set_server_certificate_verifier(
|
|
std::function<SSLVerifierResponse(SSL *ssl)> verifier);
|
|
#endif
|
|
|
|
void set_logger(Logger logger);
|
|
|
|
// SSL
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
void set_ca_cert_path(const std::string &ca_cert_file_path,
|
|
const std::string &ca_cert_dir_path = std::string());
|
|
|
|
void set_ca_cert_store(X509_STORE *ca_cert_store);
|
|
void load_ca_cert_store(const char *ca_cert, std::size_t size);
|
|
|
|
long get_openssl_verify_result() const;
|
|
|
|
SSL_CTX *ssl_context() const;
|
|
#endif
|
|
|
|
private:
|
|
std::unique_ptr<ClientImpl> cli_;
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
bool is_ssl_ = false;
|
|
#endif
|
|
};
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
class SSLServer : public Server {
|
|
public:
|
|
SSLServer(const char *cert_path, const char *private_key_path,
|
|
const char *client_ca_cert_file_path = nullptr,
|
|
const char *client_ca_cert_dir_path = nullptr,
|
|
const char *private_key_password = nullptr);
|
|
|
|
SSLServer(X509 *cert, EVP_PKEY *private_key,
|
|
X509_STORE *client_ca_cert_store = nullptr);
|
|
|
|
SSLServer(
|
|
const std::function<bool(SSL_CTX &ssl_ctx)> &setup_ssl_ctx_callback);
|
|
|
|
~SSLServer() override;
|
|
|
|
bool is_valid() const override;
|
|
|
|
SSL_CTX *ssl_context() const;
|
|
|
|
void update_certs(X509 *cert, EVP_PKEY *private_key,
|
|
X509_STORE *client_ca_cert_store = nullptr);
|
|
|
|
private:
|
|
bool process_and_close_socket(socket_t sock) override;
|
|
|
|
SSL_CTX *ctx_;
|
|
std::mutex ctx_mutex_;
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
int last_ssl_error_ = 0;
|
|
#endif
|
|
};
|
|
|
|
class SSLClient final : public ClientImpl {
|
|
public:
|
|
explicit SSLClient(const std::string &host);
|
|
|
|
explicit SSLClient(const std::string &host, int port);
|
|
|
|
explicit SSLClient(const std::string &host, int port,
|
|
const std::string &client_cert_path,
|
|
const std::string &client_key_path,
|
|
const std::string &private_key_password = std::string());
|
|
|
|
explicit SSLClient(const std::string &host, int port, X509 *client_cert,
|
|
EVP_PKEY *client_key,
|
|
const std::string &private_key_password = std::string());
|
|
|
|
~SSLClient() override;
|
|
|
|
bool is_valid() const override;
|
|
|
|
void set_ca_cert_store(X509_STORE *ca_cert_store);
|
|
void load_ca_cert_store(const char *ca_cert, std::size_t size);
|
|
|
|
long get_openssl_verify_result() const;
|
|
|
|
SSL_CTX *ssl_context() const;
|
|
|
|
private:
|
|
bool create_and_connect_socket(Socket &socket, Error &error) override;
|
|
void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override;
|
|
void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully);
|
|
|
|
bool
|
|
process_socket(const Socket &socket,
|
|
std::chrono::time_point<std::chrono::steady_clock> start_time,
|
|
std::function<bool(Stream &strm)> callback) override;
|
|
bool is_ssl() const override;
|
|
|
|
bool connect_with_proxy(
|
|
Socket &sock,
|
|
std::chrono::time_point<std::chrono::steady_clock> start_time,
|
|
Response &res, bool &success, Error &error);
|
|
bool initialize_ssl(Socket &socket, Error &error);
|
|
|
|
bool load_certs();
|
|
|
|
bool verify_host(X509 *server_cert) const;
|
|
bool verify_host_with_subject_alt_name(X509 *server_cert) const;
|
|
bool verify_host_with_common_name(X509 *server_cert) const;
|
|
bool check_host_name(const char *pattern, size_t pattern_len) const;
|
|
|
|
SSL_CTX *ctx_;
|
|
std::mutex ctx_mutex_;
|
|
std::once_flag initialize_cert_;
|
|
|
|
std::vector<std::string> host_components_;
|
|
|
|
long verify_result_ = 0;
|
|
|
|
friend class ClientImpl;
|
|
};
|
|
#endif
|
|
|
|
/*
|
|
* Implementation of template methods.
|
|
*/
|
|
|
|
namespace detail {
|
|
|
|
// Splits `duration` into whole seconds plus the remaining microseconds and
// passes both, cast to time_t, to `callback(sec, usec)`.
template <typename T, typename U>
inline void duration_to_sec_and_usec(const T &duration, U callback) {
  using std::chrono::duration_cast;

  const auto whole_seconds =
      duration_cast<std::chrono::seconds>(duration).count();
  const auto leftover = duration - std::chrono::seconds(whole_seconds);
  const auto micros =
      duration_cast<std::chrono::microseconds>(leftover).count();

  callback(static_cast<time_t>(whole_seconds), static_cast<time_t>(micros));
}
|
|
|
|
// Compile-time length of a char array of extent N, not counting the final
// element (the terminating NUL for string literals).
template <size_t N>
inline constexpr size_t str_len(const char (&)[N]) {
  return N - 1;
}
|
|
|
|
// True when `str` is non-empty and every character is a decimal digit
// according to std::isdigit.
inline bool is_numeric(const std::string &str) {
  if (str.empty()) { return false; }
  for (unsigned char ch : str) {
    if (!std::isdigit(ch)) { return false; }
  }
  return true;
}
|
|
|
|
inline size_t get_header_value_u64(const Headers &headers,
|
|
const std::string &key, size_t def,
|
|
size_t id, bool &is_invalid_value) {
|
|
is_invalid_value = false;
|
|
auto rng = headers.equal_range(key);
|
|
auto it = rng.first;
|
|
std::advance(it, static_cast<ssize_t>(id));
|
|
if (it != rng.second) {
|
|
if (is_numeric(it->second)) {
|
|
return std::strtoull(it->second.data(), nullptr, 10);
|
|
} else {
|
|
is_invalid_value = true;
|
|
}
|
|
}
|
|
return def;
|
|
}
|
|
|
|
inline size_t get_header_value_u64(const Headers &headers,
|
|
const std::string &key, size_t def,
|
|
size_t id) {
|
|
bool dummy = false;
|
|
return get_header_value_u64(headers, key, def, id, dummy);
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
inline size_t Request::get_header_value_u64(const std::string &key, size_t def,
|
|
size_t id) const {
|
|
return detail::get_header_value_u64(headers, key, def, id);
|
|
}
|
|
|
|
inline size_t Response::get_header_value_u64(const std::string &key, size_t def,
|
|
size_t id) const {
|
|
return detail::get_header_value_u64(headers, key, def, id);
|
|
}
|
|
|
|
namespace detail {
|
|
|
|
// Thin wrapper over setsockopt() that reports success as a bool. Under
// _WIN64 the option pointer is cast to `const char *` before the call.
inline bool set_socket_opt_impl(socket_t sock, int level, int optname,
                                const void *optval, socklen_t optlen) {
#ifdef _WIN64
  const char *value = reinterpret_cast<const char *>(optval);
#else
  const void *value = optval;
#endif
  return setsockopt(sock, level, optname, value, optlen) == 0;
}
|
|
|
|
// Sets an int-valued socket option; returns true on success.
inline bool set_socket_opt(socket_t sock, int level, int optname, int optval) {
  const void *value = &optval;
  return set_socket_opt_impl(sock, level, optname, value, sizeof(optval));
}
|
|
|
|
// Sets a time-valued socket option. Under _WIN64 the value is passed as a
// 32-bit millisecond count; otherwise as a `timeval` filled from sec / usec.
inline bool set_socket_opt_time(socket_t sock, int level, int optname,
                                time_t sec, time_t usec) {
#ifdef _WIN64
  auto wait = static_cast<uint32_t>(sec * 1000 + usec / 1000);
#else
  timeval wait;
  wait.tv_sec = static_cast<long>(sec);
  wait.tv_usec = static_cast<decltype(wait.tv_usec)>(usec);
#endif
  return set_socket_opt_impl(sock, level, optname, &wait, sizeof(wait));
}
|
|
|
|
} // namespace detail
|
|
|
|
// Default option callback: enables SO_REUSEPORT when the platform defines it,
// SO_REUSEADDR otherwise. The result of the underlying call is not checked.
inline void default_socket_options(socket_t sock) {
#ifdef SO_REUSEPORT
  const int reuse_option = SO_REUSEPORT;
#else
  const int reuse_option = SO_REUSEADDR;
#endif
  detail::set_socket_opt(sock, SOL_SOCKET, reuse_option, 1);
}
|
|
|
|
inline const char *status_message(int status) {
|
|
switch (status) {
|
|
case StatusCode::Continue_100: return "Continue";
|
|
case StatusCode::SwitchingProtocol_101: return "Switching Protocol";
|
|
case StatusCode::Processing_102: return "Processing";
|
|
case StatusCode::EarlyHints_103: return "Early Hints";
|
|
case StatusCode::OK_200: return "OK";
|
|
case StatusCode::Created_201: return "Created";
|
|
case StatusCode::Accepted_202: return "Accepted";
|
|
case StatusCode::NonAuthoritativeInformation_203:
|
|
return "Non-Authoritative Information";
|
|
case StatusCode::NoContent_204: return "No Content";
|
|
case StatusCode::ResetContent_205: return "Reset Content";
|
|
case StatusCode::PartialContent_206: return "Partial Content";
|
|
case StatusCode::MultiStatus_207: return "Multi-Status";
|
|
case StatusCode::AlreadyReported_208: return "Already Reported";
|
|
case StatusCode::IMUsed_226: return "IM Used";
|
|
case StatusCode::MultipleChoices_300: return "Multiple Choices";
|
|
case StatusCode::MovedPermanently_301: return "Moved Permanently";
|
|
case StatusCode::Found_302: return "Found";
|
|
case StatusCode::SeeOther_303: return "See Other";
|
|
case StatusCode::NotModified_304: return "Not Modified";
|
|
case StatusCode::UseProxy_305: return "Use Proxy";
|
|
case StatusCode::unused_306: return "unused";
|
|
case StatusCode::TemporaryRedirect_307: return "Temporary Redirect";
|
|
case StatusCode::PermanentRedirect_308: return "Permanent Redirect";
|
|
case StatusCode::BadRequest_400: return "Bad Request";
|
|
case StatusCode::Unauthorized_401: return "Unauthorized";
|
|
case StatusCode::PaymentRequired_402: return "Payment Required";
|
|
case StatusCode::Forbidden_403: return "Forbidden";
|
|
case StatusCode::NotFound_404: return "Not Found";
|
|
case StatusCode::MethodNotAllowed_405: return "Method Not Allowed";
|
|
case StatusCode::NotAcceptable_406: return "Not Acceptable";
|
|
case StatusCode::ProxyAuthenticationRequired_407:
|
|
return "Proxy Authentication Required";
|
|
case StatusCode::RequestTimeout_408: return "Request Timeout";
|
|
case StatusCode::Conflict_409: return "Conflict";
|
|
case StatusCode::Gone_410: return "Gone";
|
|
case StatusCode::LengthRequired_411: return "Length Required";
|
|
case StatusCode::PreconditionFailed_412: return "Precondition Failed";
|
|
case StatusCode::PayloadTooLarge_413: return "Payload Too Large";
|
|
case StatusCode::UriTooLong_414: return "URI Too Long";
|
|
case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type";
|
|
case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable";
|
|
case StatusCode::ExpectationFailed_417: return "Expectation Failed";
|
|
case StatusCode::ImATeapot_418: return "I'm a teapot";
|
|
case StatusCode::MisdirectedRequest_421: return "Misdirected Request";
|
|
case StatusCode::UnprocessableContent_422: return "Unprocessable Content";
|
|
case StatusCode::Locked_423: return "Locked";
|
|
case StatusCode::FailedDependency_424: return "Failed Dependency";
|
|
case StatusCode::TooEarly_425: return "Too Early";
|
|
case StatusCode::UpgradeRequired_426: return "Upgrade Required";
|
|
case StatusCode::PreconditionRequired_428: return "Precondition Required";
|
|
case StatusCode::TooManyRequests_429: return "Too Many Requests";
|
|
case StatusCode::RequestHeaderFieldsTooLarge_431:
|
|
return "Request Header Fields Too Large";
|
|
case StatusCode::UnavailableForLegalReasons_451:
|
|
return "Unavailable For Legal Reasons";
|
|
case StatusCode::NotImplemented_501: return "Not Implemented";
|
|
case StatusCode::BadGateway_502: return "Bad Gateway";
|
|
case StatusCode::ServiceUnavailable_503: return "Service Unavailable";
|
|
case StatusCode::GatewayTimeout_504: return "Gateway Timeout";
|
|
case StatusCode::HttpVersionNotSupported_505:
|
|
return "HTTP Version Not Supported";
|
|
case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates";
|
|
case StatusCode::InsufficientStorage_507: return "Insufficient Storage";
|
|
case StatusCode::LoopDetected_508: return "Loop Detected";
|
|
case StatusCode::NotExtended_510: return "Not Extended";
|
|
case StatusCode::NetworkAuthenticationRequired_511:
|
|
return "Network Authentication Required";
|
|
|
|
default:
|
|
case StatusCode::InternalServerError_500: return "Internal Server Error";
|
|
}
|
|
}
|
|
|
|
inline std::string get_bearer_token_auth(const Request &req) {
|
|
if (req.has_header("Authorization")) {
|
|
constexpr auto bearer_header_prefix_len = detail::str_len("Bearer ");
|
|
return req.get_header_value("Authorization")
|
|
.substr(bearer_header_prefix_len);
|
|
}
|
|
return "";
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline Server &
|
|
Server::set_read_timeout(const std::chrono::duration<Rep, Period> &duration) {
|
|
detail::duration_to_sec_and_usec(
|
|
duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); });
|
|
return *this;
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline Server &
|
|
Server::set_write_timeout(const std::chrono::duration<Rep, Period> &duration) {
|
|
detail::duration_to_sec_and_usec(
|
|
duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); });
|
|
return *this;
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline Server &
|
|
Server::set_idle_interval(const std::chrono::duration<Rep, Period> &duration) {
|
|
detail::duration_to_sec_and_usec(
|
|
duration, [&](time_t sec, time_t usec) { set_idle_interval(sec, usec); });
|
|
return *this;
|
|
}
|
|
|
|
inline std::string to_string(const Error error) {
|
|
switch (error) {
|
|
case Error::Success: return "Success (no error)";
|
|
case Error::Connection: return "Could not establish connection";
|
|
case Error::BindIPAddress: return "Failed to bind IP address";
|
|
case Error::Read: return "Failed to read connection";
|
|
case Error::Write: return "Failed to write connection";
|
|
case Error::ExceedRedirectCount: return "Maximum redirect count exceeded";
|
|
case Error::Canceled: return "Connection handling canceled";
|
|
case Error::SSLConnection: return "SSL connection failed";
|
|
case Error::SSLLoadingCerts: return "SSL certificate loading failed";
|
|
case Error::SSLServerVerification: return "SSL server verification failed";
|
|
case Error::SSLServerHostnameVerification:
|
|
return "SSL server hostname verification failed";
|
|
case Error::UnsupportedMultipartBoundaryChars:
|
|
return "Unsupported HTTP multipart boundary characters";
|
|
case Error::Compression: return "Compression failed";
|
|
case Error::ConnectionTimeout: return "Connection timed out";
|
|
case Error::ProxyConnection: return "Proxy connection failed";
|
|
case Error::Unknown: return "Unknown";
|
|
default: break;
|
|
}
|
|
|
|
return "Invalid";
|
|
}
|
|
|
|
inline std::ostream &operator<<(std::ostream &os, const Error &obj) {
|
|
os << to_string(obj);
|
|
os << " (" << static_cast<std::underlying_type<Error>::type>(obj) << ')';
|
|
return os;
|
|
}
|
|
|
|
inline size_t Result::get_request_header_value_u64(const std::string &key,
|
|
size_t def,
|
|
size_t id) const {
|
|
return detail::get_header_value_u64(request_headers_, key, def, id);
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline void ClientImpl::set_connection_timeout(
|
|
const std::chrono::duration<Rep, Period> &duration) {
|
|
detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) {
|
|
set_connection_timeout(sec, usec);
|
|
});
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline void ClientImpl::set_read_timeout(
|
|
const std::chrono::duration<Rep, Period> &duration) {
|
|
detail::duration_to_sec_and_usec(
|
|
duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); });
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline void ClientImpl::set_write_timeout(
|
|
const std::chrono::duration<Rep, Period> &duration) {
|
|
detail::duration_to_sec_and_usec(
|
|
duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); });
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline void ClientImpl::set_max_timeout(
|
|
const std::chrono::duration<Rep, Period> &duration) {
|
|
auto msec =
|
|
std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
|
|
set_max_timeout(msec);
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline void Client::set_connection_timeout(
|
|
const std::chrono::duration<Rep, Period> &duration) {
|
|
cli_->set_connection_timeout(duration);
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline void
|
|
Client::set_read_timeout(const std::chrono::duration<Rep, Period> &duration) {
|
|
cli_->set_read_timeout(duration);
|
|
}
|
|
|
|
template <class Rep, class Period>
|
|
inline void
|
|
Client::set_write_timeout(const std::chrono::duration<Rep, Period> &duration) {
|
|
cli_->set_write_timeout(duration);
|
|
}
|
|
|
|
// Forwards the maximum timeout, given in milliseconds, to the wrapped
// implementation object (`cli_`).
inline void Client::set_max_timeout(time_t msec) {
  cli_->set_max_timeout(msec);
}
|
|
|
|
template <class Rep, class Period>
|
|
inline void
|
|
Client::set_max_timeout(const std::chrono::duration<Rep, Period> &duration) {
|
|
cli_->set_max_timeout(duration);
|
|
}
|
|
|
|
/*
|
|
* Forward declarations and types that will be part of the .h file if split into
|
|
* .h + .cc.
|
|
*/
|
|
|
|
std::string hosted_at(const std::string &hostname);
|
|
|
|
void hosted_at(const std::string &hostname, std::vector<std::string> &addrs);
|
|
|
|
std::string encode_uri_component(const std::string &value);
|
|
|
|
std::string encode_uri(const std::string &value);
|
|
|
|
std::string decode_uri_component(const std::string &value);
|
|
|
|
std::string decode_uri(const std::string &value);
|
|
|
|
std::string encode_query_param(const std::string &value);
|
|
|
|
// Declaration only. The second parameter had been mis-encoded ("&para" turned
// into a pilcrow sign); restored to a const reference named `params`.
std::string append_query_params(const std::string &path, const Params &params);
|
|
|
|
std::pair<std::string, std::string> make_range_header(const Ranges &ranges);
|
|
|
|
std::pair<std::string, std::string>
|
|
make_basic_authentication_header(const std::string &username,
|
|
const std::string &password,
|
|
bool is_proxy = false);
|
|
|
|
namespace detail {
|
|
|
|
#if defined(_WIN64)
|
|
inline std::wstring u8string_to_wstring(const char *s) {
|
|
std::wstring ws;
|
|
auto len = static_cast<int>(strlen(s));
|
|
auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0);
|
|
if (wlen > 0) {
|
|
ws.resize(wlen);
|
|
wlen = ::MultiByteToWideChar(
|
|
CP_UTF8, 0, s, len,
|
|
const_cast<LPWSTR>(reinterpret_cast<LPCWSTR>(ws.data())), wlen);
|
|
if (wlen != static_cast<int>(ws.size())) { ws.clear(); }
|
|
}
|
|
return ws;
|
|
}
|
|
#endif
|
|
|
|
// Holds the result of one stat lookup on a path: the stat record and the
// call's return code (`ret_`, -1 until the constructor stores a result).
struct FileStat {
  FileStat(const std::string &path);

  bool is_file() const;
  bool is_dir() const;

private:
  // The stat record type differs between _WIN64 and other platforms.
#ifdef _WIN64
  struct _stat st_;
#else
  struct stat st_;
#endif
  int ret_ = -1;
};
|
|
|
|
std::string decode_path(const std::string &s, bool convert_plus_to_space);
|
|
|
|
std::string trim_copy(const std::string &s);
|
|
|
|
void divide(
|
|
const char *data, std::size_t size, char d,
|
|
std::function<void(const char *, std::size_t, const char *, std::size_t)>
|
|
fn);
|
|
|
|
void divide(
|
|
const std::string &str, char d,
|
|
std::function<void(const char *, std::size_t, const char *, std::size_t)>
|
|
fn);
|
|
|
|
void split(const char *b, const char *e, char d,
|
|
std::function<void(const char *, const char *)> fn);
|
|
|
|
void split(const char *b, const char *e, char d, size_t m,
|
|
std::function<void(const char *, const char *)> fn);
|
|
|
|
bool process_client_socket(
|
|
socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec,
|
|
time_t write_timeout_sec, time_t write_timeout_usec,
|
|
time_t max_timeout_msec,
|
|
std::chrono::time_point<std::chrono::steady_clock> start_time,
|
|
std::function<bool(Stream &)> callback);
|
|
|
|
socket_t create_client_socket(const std::string &host, const std::string &ip,
|
|
int port, int address_family, bool tcp_nodelay,
|
|
bool ipv6_v6only, SocketOptions socket_options,
|
|
time_t connection_timeout_sec,
|
|
time_t connection_timeout_usec,
|
|
time_t read_timeout_sec, time_t read_timeout_usec,
|
|
time_t write_timeout_sec,
|
|
time_t write_timeout_usec,
|
|
const std::string &intf, Error &error);
|
|
|
|
const char *get_header_value(const Headers &headers, const std::string &key,
|
|
const char *def, size_t id);
|
|
|
|
// Declaration only. The parameter had been mis-encoded ("&para" turned into a
// pilcrow sign); restored to a const reference named `params`.
std::string params_to_query_str(const Params &params);
|
|
|
|
// Declaration only. The output parameter had been mis-encoded ("&para" turned
// into a pilcrow sign); restored to a mutable reference named `params`.
void parse_query_text(const char *data, std::size_t size, Params &params);
|
|
|
|
// Declaration only. The output parameter had been mis-encoded ("&para" turned
// into a pilcrow sign); restored to a mutable reference named `params`.
void parse_query_text(const std::string &s, Params &params);
|
|
|
|
bool parse_multipart_boundary(const std::string &content_type,
|
|
std::string &boundary);
|
|
|
|
bool parse_range_header(const std::string &s, Ranges &ranges);
|
|
|
|
bool parse_accept_header(const std::string &s,
|
|
std::vector<std::string> &content_types);
|
|
|
|
int close_socket(socket_t sock);
|
|
|
|
ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags);
|
|
|
|
ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags);
|
|
|
|
// Content encodings a response body can be given; `None` (0) means the body
// is left as-is. Returned by encoding_type(), declared just below.
enum class EncodingType { None = 0, Gzip, Brotli, Zstd };
|
|
|
|
EncodingType encoding_type(const Request &req, const Response &res);
|
|
|
|
class BufferStream final : public Stream {
|
|
public:
|
|
BufferStream() = default;
|
|
~BufferStream() override = default;
|
|
|
|
bool is_readable() const override;
|
|
bool wait_readable() const override;
|
|
bool wait_writable() const override;
|
|
ssize_t read(char *ptr, size_t size) override;
|
|
ssize_t write(const char *ptr, size_t size) override;
|
|
void get_remote_ip_and_port(std::string &ip, int &port) const override;
|
|
void get_local_ip_and_port(std::string &ip, int &port) const override;
|
|
socket_t socket() const override;
|
|
time_t duration() const override;
|
|
|
|
const std::string &get_buffer() const;
|
|
|
|
private:
|
|
std::string buffer;
|
|
size_t position = 0;
|
|
};
|
|
|
|
// Abstract interface for body compressors. compress() receives a chunk of
// input, a flag marking the last chunk, and a callback for produced output.
class compressor {
public:
  virtual ~compressor() = default;

  using Callback = std::function<bool(const char *data, size_t data_len)>;

  virtual bool compress(const char *data, size_t data_length, bool last,
                        Callback callback) = 0;
};
|
|
|
|
// Abstract interface for body decompressors. decompress() receives a chunk
// of input and a callback for produced output; is_valid() reports whether
// the instance is usable.
class decompressor {
public:
  virtual ~decompressor() = default;

  virtual bool is_valid() const = 0;

  using Callback = std::function<bool(const char *data, size_t data_len)>;

  virtual bool decompress(const char *data, size_t data_length,
                          Callback callback) = 0;
};
|
|
|
|
class nocompressor final : public compressor {
|
|
public:
|
|
~nocompressor() override = default;
|
|
|
|
bool compress(const char *data, size_t data_length, bool /*last*/,
|
|
Callback callback) override;
|
|
};
|
|
|
|
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
|
class gzip_compressor final : public compressor {
|
|
public:
|
|
gzip_compressor();
|
|
~gzip_compressor() override;
|
|
|
|
bool compress(const char *data, size_t data_length, bool last,
|
|
Callback callback) override;
|
|
|
|
private:
|
|
bool is_valid_ = false;
|
|
z_stream strm_;
|
|
};
|
|
|
|
class gzip_decompressor final : public decompressor {
|
|
public:
|
|
gzip_decompressor();
|
|
~gzip_decompressor() override;
|
|
|
|
bool is_valid() const override;
|
|
|
|
bool decompress(const char *data, size_t data_length,
|
|
Callback callback) override;
|
|
|
|
private:
|
|
bool is_valid_ = false;
|
|
z_stream strm_;
|
|
};
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
|
|
class brotli_compressor final : public compressor {
|
|
public:
|
|
brotli_compressor();
|
|
~brotli_compressor();
|
|
|
|
bool compress(const char *data, size_t data_length, bool last,
|
|
Callback callback) override;
|
|
|
|
private:
|
|
BrotliEncoderState *state_ = nullptr;
|
|
};
|
|
|
|
class brotli_decompressor final : public decompressor {
|
|
public:
|
|
brotli_decompressor();
|
|
~brotli_decompressor();
|
|
|
|
bool is_valid() const override;
|
|
|
|
bool decompress(const char *data, size_t data_length,
|
|
Callback callback) override;
|
|
|
|
private:
|
|
BrotliDecoderResult decoder_r;
|
|
BrotliDecoderState *decoder_s = nullptr;
|
|
};
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
|
|
class zstd_compressor : public compressor {
|
|
public:
|
|
zstd_compressor();
|
|
~zstd_compressor();
|
|
|
|
bool compress(const char *data, size_t data_length, bool last,
|
|
Callback callback) override;
|
|
|
|
private:
|
|
ZSTD_CCtx *ctx_ = nullptr;
|
|
};
|
|
|
|
class zstd_decompressor : public decompressor {
|
|
public:
|
|
zstd_decompressor();
|
|
~zstd_decompressor();
|
|
|
|
bool is_valid() const override;
|
|
|
|
bool decompress(const char *data, size_t data_length,
|
|
Callback callback) override;
|
|
|
|
private:
|
|
ZSTD_DCtx *ctx_ = nullptr;
|
|
};
|
|
#endif
|
|
|
|
// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer`
|
|
// to store data. The call can set memory on stack for performance.
|
|
class stream_line_reader {
|
|
public:
|
|
stream_line_reader(Stream &strm, char *fixed_buffer,
|
|
size_t fixed_buffer_size);
|
|
const char *ptr() const;
|
|
size_t size() const;
|
|
bool end_with_crlf() const;
|
|
bool getline();
|
|
|
|
private:
|
|
void append(char c);
|
|
|
|
Stream &strm_;
|
|
char *fixed_buffer_;
|
|
const size_t fixed_buffer_size_;
|
|
size_t fixed_buffer_used_size_ = 0;
|
|
std::string growable_buffer_;
|
|
};
|
|
|
|
// Wrapper exposing a file's contents through data() / size(). Holds file and
// mapping handles under _WIN64 and a file descriptor elsewhere.
//
// NOTE(review): a destructor and raw handle members are declared, but no
// deleted or custom copy operations are visible here -- confirm instances
// are never copied.
class mmap {
public:
  mmap(const char *path);
  ~mmap();

  bool open(const char *path);
  void close();

  bool is_open() const;
  size_t size() const;
  const char *data() const;

private:
#ifdef _WIN64
  HANDLE hFile_ = NULL;
  HANDLE hMapping_ = NULL;
#else
  int fd_ = -1;
#endif
  size_t size_ = 0;
  void *addr_ = nullptr;
  bool is_open_empty_file = false;
};
|
|
|
|
// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5
|
|
namespace fields {

// True for characters allowed in a token: alphanumerics plus the listed
// punctuation.
inline bool is_token_char(char c) {
  // <cctype> classifiers require a value representable as unsigned char (or
  // EOF); passing a negative plain `char` is undefined behaviour.
  return std::isalnum(static_cast<unsigned char>(c)) || c == '!' || c == '#' ||
         c == '$' || c == '%' || c == '&' || c == '\'' || c == '*' ||
         c == '+' || c == '-' || c == '.' || c == '^' || c == '_' ||
         c == '`' || c == '|' || c == '~';
}

// True when `s` is non-empty and made only of token characters.
inline bool is_token(const std::string &s) {
  if (s.empty()) { return false; }
  for (auto c : s) {
    if (!is_token_char(c)) { return false; }
  }
  return true;
}

// A field name is a token.
inline bool is_field_name(const std::string &s) { return is_token(s); }

// Visible ASCII character (33..126).
inline bool is_vchar(char c) { return c >= 33 && c <= 126; }

// Byte with the high bit set (128..255).
inline bool is_obs_text(char c) { return 128 <= static_cast<unsigned char>(c); }

inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); }

// True when `s` is empty, or when its first and last characters are field
// vchars and every character in between is a space, a tab or a field vchar.
inline bool is_field_content(const std::string &s) {
  if (s.empty()) { return true; }

  // Leading / trailing whitespace is not allowed. For one- and two-character
  // strings these two checks cover every character.
  if (!is_field_vchar(s.front()) || !is_field_vchar(s.back())) {
    return false;
  }

  for (size_t i = 1; i + 1 < s.size(); i++) {
    const auto c = s[i];
    if (!(c == ' ' || c == '\t' || is_field_vchar(c))) { return false; }
  }

  return true;
}

inline bool is_field_value(const std::string &s) { return is_field_content(s); }

} // namespace fields
|
|
|
|
} // namespace detail
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
/*
|
|
* Implementation that will be part of the .cc file if split into .h + .cc.
|
|
*/
|
|
|
|
namespace detail {
|
|
|
|
// If `c` is a hexadecimal digit, stores its value (0..15) in `v` and returns
// true. Otherwise returns false and leaves `v` untouched.
inline bool is_hex(char c, int &v) {
  int digit = -1;

  if (0x20 <= c && isdigit(c)) {
    digit = c - '0';
  } else if ('A' <= c && c <= 'F') {
    digit = c - 'A' + 10;
  } else if ('a' <= c && c <= 'f') {
    digit = c - 'a' + 10;
  }

  if (digit < 0) { return false; }
  v = digit;
  return true;
}
|
|
|
|
inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt,
|
|
int &val) {
|
|
if (i >= s.size()) { return false; }
|
|
|
|
val = 0;
|
|
for (; cnt; i++, cnt--) {
|
|
if (!s[i]) { return false; }
|
|
auto v = 0;
|
|
if (is_hex(s[i], v)) {
|
|
val = val * 16 + v;
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Lower-case hexadecimal representation of `n` without a prefix or leading
// zeros ("0" for zero).
inline std::string from_i_to_hex(size_t n) {
  static const char digits[] = "0123456789abcdef";

  std::string hex;
  do {
    // Least-significant nibble first, so each digit goes to the front.
    hex.insert(hex.begin(), digits[n & 15]);
    n >>= 4;
  } while (n > 0);

  return hex;
}
|
|
|
|
// Encodes the code point `code` as UTF-8 into `buff` and returns the number
// of bytes written (1..4). Returns 0, writing nothing, for the surrogate
// range D800..DFFF and for values of 0x110000 and above.
inline size_t to_utf8(int code, char *buff) {
  // Continuation byte carrying bits [shift, shift + 6) of the code point.
  auto trail = [code](int shift) {
    return static_cast<char>(0x80 | ((code >> shift) & 0x3F));
  };

  if (code < 0x0080) {
    buff[0] = static_cast<char>(code & 0x7F);
    return 1;
  }

  if (code < 0x0800) {
    buff[0] = static_cast<char>(0xC0 | ((code >> 6) & 0x1F));
    buff[1] = trail(0);
    return 2;
  }

  // D800 - DFFF is invalid...
  if (0xD800 <= code && code < 0xE000) { return 0; }

  if (code < 0x10000) {
    buff[0] = static_cast<char>(0xE0 | ((code >> 12) & 0xF));
    buff[1] = trail(6);
    buff[2] = trail(0);
    return 3;
  }

  if (code < 0x110000) {
    buff[0] = static_cast<char>(0xF0 | ((code >> 18) & 0x7));
    buff[1] = trail(12);
    buff[2] = trail(6);
    buff[3] = trail(0);
    return 4;
  }

  // Out of the Unicode range.
  return 0;
}
|
|
|
|
// NOTE: This code is adapted from the following Stack Overflow post:
// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c
//
// Encodes `in` with the standard Base64 alphabet and pads the result with
// '=' to a multiple of four characters.
inline std::string base64_encode(const std::string &in) {
  static const auto lookup =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

  std::string out;
  // Every 3 input bytes become 4 output characters (including padding).
  out.reserve(((in.size() + 2) / 3) * 4);

  // Unsigned accumulator: only the low bits are ever read, and shifting an
  // unsigned value left wraps instead of overflowing a signed int (which is
  // undefined behaviour once more than three bytes have been shifted in).
  uint32_t val = 0;
  int valb = -6;

  for (auto c : in) {
    val = (val << 8) + static_cast<uint8_t>(c);
    valb += 8;
    while (valb >= 0) {
      out.push_back(lookup[(val >> valb) & 0x3F]);
      valb -= 6;
    }
  }

  // Flush the 2 or 4 bits left over from a partial 3-byte group.
  if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); }

  while (out.size() % 4) {
    out.push_back('=');
  }

  return out;
}
|
|
|
|
// Checks a '/'-separated path component by component. Returns false when a
// component contains a NUL or a backslash, or when a ".." component would
// climb above the starting level. "." components are ignored.
inline bool is_valid_path(const std::string &path) {
  const size_t n = path.size();
  size_t depth = 0;
  size_t pos = 0;

  auto skip_slashes = [&]() {
    while (pos < n && path[pos] == '/') {
      ++pos;
    }
  };

  skip_slashes();

  while (pos < n) {
    // Scan one component.
    const size_t start = pos;
    for (; pos < n && path[pos] != '/'; ++pos) {
      if (path[pos] == '\0' || path[pos] == '\\') { return false; }
    }

    const size_t len = pos - start;
    assert(len > 0);

    if (path.compare(start, len, "..") == 0) {
      if (depth == 0) { return false; }
      --depth;
    } else if (path.compare(start, len, ".") != 0) {
      ++depth;
    }

    skip_slashes();
  }

  return true;
}
|
|
|
|
// Runs the platform stat call on `path` and keeps both the record (`st_`) and
// the return code (`ret_`). Under _WIN64 the path is first converted with
// u8string_to_wstring and passed to _wstat.
inline FileStat::FileStat(const std::string &path) {
#ifdef _WIN64
  const auto wide_path = u8string_to_wstring(path.c_str());
  ret_ = _wstat(wide_path.c_str(), &st_);
#else
  ret_ = stat(path.c_str(), &st_);
#endif
}
|
|
inline bool FileStat::is_file() const {
|
|
return ret_ >= 0 && S_ISREG(st_.st_mode);
|
|
}
|
|
inline bool FileStat::is_dir() const {
|
|
return ret_ >= 0 && S_ISDIR(st_.st_mode);
|
|
}
|
|
|
|
// Percent-encodes a fixed set of characters (space, '+', CR, LF, '\'', ',',
// ';') and every byte >= 0x80; all other bytes are copied unchanged.
// Processing stops at the first NUL byte in `s`.
inline std::string encode_path(const std::string &s) {
  std::string result;
  result.reserve(s.size());

  for (const char *p = s.c_str(); *p != '\0'; ++p) {
    const char ch = *p;
    switch (ch) {
    case ' ': result += "%20"; break;
    case '+': result += "%2B"; break;
    case '\r': result += "%0D"; break;
    case '\n': result += "%0A"; break;
    case '\'': result += "%27"; break;
    case ',': result += "%2C"; break;
    // case ':': result += "%3A"; break; // ok? probably...
    case ';': result += "%3B"; break;
    default: {
      const auto byte = static_cast<uint8_t>(ch);
      if (byte < 0x80) {
        result += ch;
        break;
      }
      // Non-ASCII byte: emit '%' followed by two upper-case hex digits.
      char hex[4];
      const auto len = snprintf(hex, sizeof(hex) - 1, "%02X", byte);
      assert(len == 2);
      result += '%';
      result.append(hex, static_cast<size_t>(len));
      break;
    }
    }
  }

  return result;
}
|
|
|
|
inline std::string decode_path(const std::string &s,
|
|
bool convert_plus_to_space) {
|
|
std::string result;
|
|
|
|
for (size_t i = 0; i < s.size(); i++) {
|
|
if (s[i] == '%' && i + 1 < s.size()) {
|
|
if (s[i + 1] == 'u') {
|
|
auto val = 0;
|
|
if (from_hex_to_i(s, i + 2, 4, val)) {
|
|
// 4 digits Unicode codes
|
|
char buff[4];
|
|
size_t len = to_utf8(val, buff);
|
|
if (len > 0) { result.append(buff, len); }
|
|
i += 5; // 'u0000'
|
|
} else {
|
|
result += s[i];
|
|
}
|
|
} else {
|
|
auto val = 0;
|
|
if (from_hex_to_i(s, i + 1, 2, val)) {
|
|
// 2 digits hex codes
|
|
result += static_cast<char>(val);
|
|
i += 2; // '00'
|
|
} else {
|
|
result += s[i];
|
|
}
|
|
}
|
|
} else if (convert_plus_to_space && s[i] == '+') {
|
|
result += ' ';
|
|
} else {
|
|
result += s[i];
|
|
}
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
// Returns the alphanumeric run following the last '.' at the very end of
// `path` (without the dot), or an empty string when there is none.
inline std::string file_extension(const std::string &path) {
  // One compiled regex per thread.
  thread_local const std::regex ext_re("\\.([a-zA-Z0-9]+)$");

  std::smatch match;
  if (!std::regex_search(path, match, ext_re)) { return std::string(); }
  return match[1].str();
}
|
|
|
|
// True for the two characters treated as trimmable whitespace: ' ' and '\t'.
inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; }

// Shrinks the index range [left, right) over the buffer starting at `b` so
// that it no longer begins or ends with a space or tab.
//
// `e` is the end of the buffer; leading whitespace is only skipped while
// `b + left < e`. Both scans stop as soon as the range becomes empty, so for
// any input with left <= right the result satisfies first <= second and
// `second - first` is always a valid length. (Previously an all-whitespace
// range could come back inverted, e.g. (n, 0).)
//
// Returns the trimmed (first, second) index pair.
inline std::pair<size_t, size_t> trim(const char *b, const char *e, size_t left,
                                      size_t right) {
  while (left < right && b + left < e && is_space_or_tab(b[left])) {
    left++;
  }
  while (right > left && is_space_or_tab(b[right - 1])) {
    right--;
  }
  return std::make_pair(left, right);
}

// Returns a copy of `s` with leading and trailing spaces/tabs removed.
inline std::string trim_copy(const std::string &s) {
  const auto r = trim(s.data(), s.data() + s.size(), 0, s.size());
  return s.substr(r.first, r.second - r.first);
}
|
|
|
|
// Returns `s` without its surrounding double quotes when it both starts and
// ends with '"' (and is at least two characters long); otherwise returns `s`
// unchanged.
inline std::string trim_double_quotes_copy(const std::string &s) {
  const auto n = s.size();
  const bool quoted = n >= 2 && s[0] == '"' && s[n - 1] == '"';
  return quoted ? s.substr(1, n - 2) : s;
}
|
|
|
|
// Splits [data, data + size) at the first occurrence of `d` and passes both
// halves to `fn` as (lhs_data, lhs_size, rhs_data, rhs_size). The delimiter is
// part of neither half. When `d` is absent, the left half is the whole input
// and the right half is empty (pointing at data + size).
inline void
divide(const char *data, std::size_t size, char d,
       std::function<void(const char *, std::size_t, const char *, std::size_t)>
           fn) {
  const char *const end = data + size;
  const char *const sep = std::find(data, end, d);
  const auto lhs_len = static_cast<std::size_t>(sep - data);

  if (sep == end) {
    fn(data, lhs_len, end, 0);
    return;
  }

  fn(data, lhs_len, sep + 1, size - lhs_len - 1);
}
|
|
|
|
inline void
|
|
divide(const std::string &str, char d,
|
|
std::function<void(const char *, std::size_t, const char *, std::size_t)>
|
|
fn) {
|
|
divide(str.data(), str.size(), d, std::move(fn));
|
|
}
|
|
|
|
inline void split(const char *b, const char *e, char d,
|
|
std::function<void(const char *, const char *)> fn) {
|
|
return split(b, e, d, (std::numeric_limits<size_t>::max)(), std::move(fn));
|
|
}
|
|
|
|
inline void split(const char *b, const char *e, char d, size_t m,
|
|
std::function<void(const char *, const char *)> fn) {
|
|
size_t i = 0;
|
|
size_t beg = 0;
|
|
size_t count = 1;
|
|
|
|
while (e ? (b + i < e) : (b[i] != '\0')) {
|
|
if (b[i] == d && count < m) {
|
|
auto r = trim(b, e, beg, i);
|
|
if (r.first < r.second) { fn(&b[r.first], &b[r.second]); }
|
|
beg = i + 1;
|
|
count++;
|
|
}
|
|
i++;
|
|
}
|
|
|
|
if (i) {
|
|
auto r = trim(b, e, beg, i);
|
|
if (r.first < r.second) { fn(&b[r.first], &b[r.second]); }
|
|
}
|
|
}
|
|
|
|
// Binds the reader to `strm` and to a caller-supplied fixed buffer of
// `fixed_buffer_size` bytes. The buffer is only stored here, not copied or
// written.
inline stream_line_reader::stream_line_reader(Stream &strm,
                                              char *fixed_buffer,
                                              size_t fixed_buffer_size)
    : strm_(strm),
      fixed_buffer_(fixed_buffer),
      fixed_buffer_size_(fixed_buffer_size) {
}
|
|
|
|
inline const char *stream_line_reader::ptr() const {
|
|
if (growable_buffer_.empty()) {
|
|
return fixed_buffer_;
|
|
} else {
|
|
return growable_buffer_.data();
|
|
}
|
|
}
|
|
|
|
inline size_t stream_line_reader::size() const {
|
|
if (growable_buffer_.empty()) {
|
|
return fixed_buffer_used_size_;
|
|
} else {
|
|
return growable_buffer_.size();
|
|
}
|
|
}
|
|
|
|
inline bool stream_line_reader::end_with_crlf() const {
|
|
auto end = ptr() + size();
|
|
return size() >= 2 && end[-2] == '\r' && end[-1] == '\n';
|
|
}
|
|
|
|
inline bool stream_line_reader::getline() {
|
|
fixed_buffer_used_size_ = 0;
|
|
growable_buffer_.clear();
|
|
|
|
#ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR
|
|
char prev_byte = 0;
|
|
#endif
|
|
|
|
for (size_t i = 0;; i++) {
|
|
if (size() >= CPPHTTPLIB_MAX_LINE_LENGTH) {
|
|
// Treat exceptionally long lines as an error to
|
|
// prevent infinite loops/memory exhaustion
|
|
return false;
|
|
}
|
|
char byte;
|
|
auto n = strm_.read(&byte, 1);
|
|
|
|
if (n < 0) {
|
|
return false;
|
|
} else if (n == 0) {
|
|
if (i == 0) {
|
|
return false;
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
append(byte);
|
|
|
|
#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR
|
|
if (byte == '\n') { break; }
|
|
#else
|
|
if (prev_byte == '\r' && byte == '\n') { break; }
|
|
prev_byte = byte;
|
|
#endif
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
inline void stream_line_reader::append(char c) {
|
|
if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) {
|
|
fixed_buffer_[fixed_buffer_used_size_++] = c;
|
|
fixed_buffer_[fixed_buffer_used_size_] = '\0';
|
|
} else {
|
|
if (growable_buffer_.empty()) {
|
|
assert(fixed_buffer_[fixed_buffer_used_size_] == '\0');
|
|
growable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_);
|
|
}
|
|
growable_buffer_ += c;
|
|
}
|
|
}
|
|
|
|
// Opens `path` immediately. The result of open() is not reported here; query
// is_open() afterwards.
inline mmap::mmap(const char *path) {
  open(path);
}
|
|
|
|
// Releases whatever close() releases for the current mapping.
inline mmap::~mmap() {
  close();
}
|
|
|
|
inline bool mmap::open(const char *path) {
|
|
close();
|
|
|
|
#if defined(_WIN64)
|
|
auto wpath = u8string_to_wstring(path);
|
|
if (wpath.empty()) { return false; }
|
|
|
|
hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ,
|
|
OPEN_EXISTING, NULL);
|
|
|
|
if (hFile_ == INVALID_HANDLE_VALUE) { return false; }
|
|
|
|
LARGE_INTEGER size{};
|
|
if (!::GetFileSizeEx(hFile_, &size)) { return false; }
|
|
// If the following line doesn't compile due to QuadPart, update Windows SDK.
|
|
// See:
|
|
// https://github.com/yhirose/cpp-httplib/issues/1903#issuecomment-2316520721
|
|
if (static_cast<ULONGLONG>(size.QuadPart) >
|
|
(std::numeric_limits<decltype(size_)>::max)()) {
|
|
// `size_t` might be 32-bits, on 32-bits Windows.
|
|
return false;
|
|
}
|
|
size_ = static_cast<size_t>(size.QuadPart);
|
|
|
|
hMapping_ =
|
|
::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL);
|
|
|
|
// Special treatment for an empty file...
|
|
if (hMapping_ == NULL && size_ == 0) {
|
|
close();
|
|
is_open_empty_file = true;
|
|
return true;
|
|
}
|
|
|
|
if (hMapping_ == NULL) {
|
|
close();
|
|
return false;
|
|
}
|
|
|
|
addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0);
|
|
|
|
if (addr_ == nullptr) {
|
|
close();
|
|
return false;
|
|
}
|
|
#else
|
|
fd_ = ::open(path, O_RDONLY);
|
|
if (fd_ == -1) { return false; }
|
|
|
|
struct stat sb;
|
|
if (fstat(fd_, &sb) == -1) {
|
|
close();
|
|
return false;
|
|
}
|
|
size_ = static_cast<size_t>(sb.st_size);
|
|
|
|
addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0);
|
|
|
|
// Special treatment for an empty file...
|
|
if (addr_ == MAP_FAILED && size_ == 0) {
|
|
close();
|
|
is_open_empty_file = true;
|
|
return false;
|
|
}
|
|
#endif
|
|
|
|
return true;
|
|
}
|
|
|
|
inline bool mmap::is_open() const {
|
|
return is_open_empty_file ? true : addr_ != nullptr;
|
|
}
|
|
|
|
// Size in bytes recorded by open(); reset to 0 by close().
inline size_t mmap::size() const {
  return size_;
}
|
|
|
|
inline const char *mmap::data() const {
|
|
return is_open_empty_file ? "" : static_cast<const char *>(addr_);
|
|
}
|
|
|
|
inline void mmap::close() {
|
|
#if defined(_WIN64)
|
|
if (addr_) {
|
|
::UnmapViewOfFile(addr_);
|
|
addr_ = nullptr;
|
|
}
|
|
|
|
if (hMapping_) {
|
|
::CloseHandle(hMapping_);
|
|
hMapping_ = NULL;
|
|
}
|
|
|
|
if (hFile_ != INVALID_HANDLE_VALUE) {
|
|
::CloseHandle(hFile_);
|
|
hFile_ = INVALID_HANDLE_VALUE;
|
|
}
|
|
|
|
is_open_empty_file = false;
|
|
#else
|
|
if (addr_ != nullptr) {
|
|
munmap(addr_, size_);
|
|
addr_ = nullptr;
|
|
}
|
|
|
|
if (fd_ != -1) {
|
|
::close(fd_);
|
|
fd_ = -1;
|
|
}
|
|
#endif
|
|
size_ = 0;
|
|
}
|
|
// Closes `sock` with the platform's call (close(), or closesocket() on 64-bit
// Windows) and returns that call's result.
inline int close_socket(socket_t sock) {
#ifndef _WIN64
  return close(sock);
#else
  return closesocket(sock);
#endif
}
|
|
|
|
// Calls `fn` until it either succeeds or fails for a reason other than EINTR,
// and returns that final result. Between retries the thread sleeps for one
// microsecond.
template <typename T> inline ssize_t handle_EINTR(T fn) {
  for (;;) {
    const ssize_t rc = fn();
    if (rc >= 0 || errno != EINTR) { return rc; }

    // Interrupted by a signal: back off briefly, then retry.
    std::this_thread::sleep_for(std::chrono::microseconds{1});
  }
}
|
|
|
|
// recv() wrapper that retries on EINTR (via handle_EINTR) and adapts the
// buffer/length argument types for 64-bit Windows. Returns recv()'s result.
inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) {
#ifdef _WIN64
  char *const buf = static_cast<char *>(ptr);
  const int len = static_cast<int>(size);
#else
  void *const buf = ptr;
  const size_t len = size;
#endif

  const auto receive = [&]() { return recv(sock, buf, len, flags); };
  return handle_EINTR(receive);
}
|
|
|
|
// send() wrapper that retries on EINTR (via handle_EINTR) and adapts the
// buffer/length argument types for 64-bit Windows. Returns send()'s result.
inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size,
                           int flags) {
#ifdef _WIN64
  const char *const buf = static_cast<const char *>(ptr);
  const int len = static_cast<int>(size);
#else
  const void *const buf = ptr;
  const size_t len = size;
#endif

  const auto transmit = [&]() { return send(sock, buf, len, flags); };
  return handle_EINTR(transmit);
}
|
|
|
|
// Platform shim: forwards to poll(), or to WSAPoll() on 64-bit Windows, and
// returns that call's result.
inline int poll_wrapper(struct pollfd *fds, nfds_t nfds, int timeout) {
#ifndef _WIN64
  return ::poll(fds, nfds, timeout);
#else
  return ::WSAPoll(fds, nfds, timeout);
#endif
}
|
|
|
|
template <bool Read>
|
|
inline ssize_t select_impl(socket_t sock, time_t sec, time_t usec) {
|
|
#ifdef __APPLE__
|
|
if (sock >= FD_SETSIZE) { return -1; }
|
|
|
|
fd_set fds, *rfds, *wfds;
|
|
FD_ZERO(&fds);
|
|
FD_SET(sock, &fds);
|
|
rfds = (Read ? &fds : nullptr);
|
|
wfds = (Read ? nullptr : &fds);
|
|
|
|
timeval tv;
|
|
tv.tv_sec = static_cast<long>(sec);
|
|
tv.tv_usec = static_cast<decltype(tv.tv_usec)>(usec);
|
|
|
|
return handle_EINTR([&]() {
|
|
return select(static_cast<int>(sock + 1), rfds, wfds, nullptr, &tv);
|
|
});
|
|
#else
|
|
struct pollfd pfd;
|
|
pfd.fd = sock;
|
|
pfd.events = (Read ? POLLIN : POLLOUT);
|
|
|
|
auto timeout = static_cast<int>(sec * 1000 + usec / 1000);
|
|
|
|
return handle_EINTR([&]() { return poll_wrapper(&pfd, 1, timeout); });
|
|
#endif
|
|
}
|
|
|
|
// Waits for `sock` to become readable; see select_impl<true>.
inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) {
  constexpr bool kWaitForRead = true;
  return select_impl<kWaitForRead>(sock, sec, usec);
}
|
|
|
|
// Waits for `sock` to become writable; see select_impl<false>.
inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) {
  constexpr bool kWaitForRead = false;
  return select_impl<kWaitForRead>(sock, sec, usec);
}
|
|
|
|
// Waits up to `sec` seconds + `usec` microseconds for `sock` to become
// readable or writable, then checks the socket's pending error (SO_ERROR).
//
// Returns:
//   Error::ConnectionTimeout - the wait timed out,
//   Error::Success           - the socket is ready and SO_ERROR is clear,
//   Error::Connection        - anything else (wait failure, no read/write
//                              event, getsockopt failure, pending error, or
//                              on Apple a descriptor >= FD_SETSIZE).
inline Error wait_until_socket_is_ready(socket_t sock, time_t sec,
                                        time_t usec) {
  ssize_t ready_count = 0;
  bool has_io_event = false;

#ifdef __APPLE__
  if (sock >= FD_SETSIZE) { return Error::Connection; }

  fd_set readable, writable;
  FD_ZERO(&readable);
  FD_ZERO(&writable);
  FD_SET(sock, &readable);
  FD_SET(sock, &writable);

  timeval tv;
  tv.tv_sec = static_cast<long>(sec);
  tv.tv_usec = static_cast<decltype(tv.tv_usec)>(usec);

  ready_count = handle_EINTR([&]() {
    return select(static_cast<int>(sock + 1), &readable, &writable, nullptr,
                  &tv);
  });

  if (ready_count > 0) {
    has_io_event = FD_ISSET(sock, &readable) || FD_ISSET(sock, &writable);
  }
#else
  struct pollfd entry;
  entry.fd = sock;
  entry.events = POLLIN | POLLOUT;

  const auto timeout_ms = static_cast<int>(sec * 1000 + usec / 1000);

  ready_count =
      handle_EINTR([&]() { return poll_wrapper(&entry, 1, timeout_ms); });

  if (ready_count > 0) {
    has_io_event = (entry.revents & (POLLIN | POLLOUT)) != 0;
  }
#endif

  if (ready_count == 0) { return Error::ConnectionTimeout; }
  if (!has_io_event) { return Error::Connection; }

  // The socket reported an event; a pending error still means failure.
  int so_error = 0;
  socklen_t len = sizeof(so_error);
  const auto rc = getsockopt(sock, SOL_SOCKET, SO_ERROR,
                             reinterpret_cast<char *>(&so_error), &len);
  if (rc < 0 || so_error) { return Error::Connection; }
  return Error::Success;
}
|
|
|
|
// Non-blocking liveness probe for `sock`.
//
// With nothing to read the socket is considered alive. A failed readiness
// check with errno == EBADF means it is not. Otherwise one byte is peeked
// (MSG_PEEK, so nothing is consumed) and the socket counts as alive only if
// that read returns data.
inline bool is_socket_alive(socket_t sock) {
  const auto ready = detail::select_read(sock, 0, 0);
  if (ready == 0) { return true; }
  if (ready < 0 && errno == EBADF) { return false; }

  char probe[1];
  const auto peeked =
      detail::read_socket(sock, &probe[0], sizeof(probe), MSG_PEEK);
  return peeked > 0;
}
|
|
|
|
class SocketStream final : public Stream {
|
|
public:
|
|
SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec,
|
|
time_t write_timeout_sec, time_t write_timeout_usec,
|
|
time_t max_timeout_msec = 0,
|
|
std::chrono::time_point<std::chrono::steady_clock> start_time =
|
|
(std::chrono::steady_clock::time_point::min)());
|
|
~SocketStream() override;
|
|
|
|
bool is_readable() const override;
|
|
bool wait_readable() const override;
|
|
bool wait_writable() const override;
|
|
ssize_t read(char *ptr, size_t size) override;
|
|
ssize_t write(const char *ptr, size_t size) override;
|
|
void get_remote_ip_and_port(std::string &ip, int &port) const override;
|
|
void get_local_ip_and_port(std::string &ip, int &port) const override;
|
|
socket_t socket() const override;
|
|
time_t duration() const override;
|
|
|
|
private:
|
|
socket_t sock_;
|
|
time_t read_timeout_sec_;
|
|
time_t read_timeout_usec_;
|
|
time_t write_timeout_sec_;
|
|
time_t write_timeout_usec_;
|
|
time_t max_timeout_msec_;
|
|
const std::chrono::time_point<std::chrono::steady_clock> start_time_;
|
|
|
|
std::vector<char> read_buff_;
|
|
size_t read_buff_off_ = 0;
|
|
size_t read_buff_content_size_ = 0;
|
|
|
|
static const size_t read_buff_size_ = 1024l * 4;
|
|
};
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
// Stream implementation over a socket paired with an OpenSSL `SSL` object.
// Only compiled with CPPHTTPLIB_OPENSSL_SUPPORT; the member functions are
// defined elsewhere in the file.
class SSLSocketStream final : public Stream {
public:
  // `start_time` defaults to steady_clock::time_point::min() and
  // `max_timeout_msec` to 0.
  SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec,
                  time_t read_timeout_usec, time_t write_timeout_sec,
                  time_t write_timeout_usec, time_t max_timeout_msec = 0,
                  std::chrono::steady_clock::time_point start_time =
                      (std::chrono::steady_clock::time_point::min)());
  ~SSLSocketStream() override;

  // --- Stream interface ---
  bool is_readable() const override;
  bool wait_readable() const override;
  bool wait_writable() const override;
  ssize_t read(char *ptr, size_t size) override;
  ssize_t write(const char *ptr, size_t size) override;
  void get_remote_ip_and_port(std::string &ip, int &port) const override;
  void get_local_ip_and_port(std::string &ip, int &port) const override;
  socket_t socket() const override;
  time_t duration() const override;

private:
  // Values received through the constructor.
  socket_t sock_;
  SSL *ssl_;
  time_t read_timeout_sec_;
  time_t read_timeout_usec_;
  time_t write_timeout_sec_;
  time_t write_timeout_usec_;
  time_t max_timeout_msec_;
  const std::chrono::steady_clock::time_point start_time_;
};
#endif
|
|
|
|
inline bool keep_alive(const std::atomic<socket_t> &svr_sock, socket_t sock,
|
|
time_t keep_alive_timeout_sec) {
|
|
using namespace std::chrono;
|
|
|
|
const auto interval_usec =
|
|
CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND;
|
|
|
|
// Avoid expensive `steady_clock::now()` call for the first time
|
|
if (select_read(sock, 0, interval_usec) > 0) { return true; }
|
|
|
|
const auto start = steady_clock::now() - microseconds{interval_usec};
|
|
const auto timeout = seconds{keep_alive_timeout_sec};
|
|
|
|
while (true) {
|
|
if (svr_sock == INVALID_SOCKET) {
|
|
break; // Server socket is closed
|
|
}
|
|
|
|
auto val = select_read(sock, 0, interval_usec);
|
|
if (val < 0) {
|
|
break; // Ssocket error
|
|
} else if (val == 0) {
|
|
if (steady_clock::now() - start > timeout) {
|
|
break; // Timeout
|
|
}
|
|
} else {
|
|
return true; // Ready for read
|
|
}
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
template <typename T>
|
|
inline bool
|
|
process_server_socket_core(const std::atomic<socket_t> &svr_sock, socket_t sock,
|
|
size_t keep_alive_max_count,
|
|
time_t keep_alive_timeout_sec, T callback) {
|
|
assert(keep_alive_max_count > 0);
|
|
auto ret = false;
|
|
auto count = keep_alive_max_count;
|
|
while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) {
|
|
auto close_connection = count == 1;
|
|
auto connection_closed = false;
|
|
ret = callback(close_connection, connection_closed);
|
|
if (!ret || connection_closed) { break; }
|
|
count--;
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
template <typename T>
|
|
inline bool
|
|
process_server_socket(const std::atomic<socket_t> &svr_sock, socket_t sock,
|
|
size_t keep_alive_max_count,
|
|
time_t keep_alive_timeout_sec, time_t read_timeout_sec,
|
|
time_t read_timeout_usec, time_t write_timeout_sec,
|
|
time_t write_timeout_usec, T callback) {
|
|
return process_server_socket_core(
|
|
svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec,
|
|
[&](bool close_connection, bool &connection_closed) {
|
|
SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
|
|
write_timeout_sec, write_timeout_usec);
|
|
return callback(strm, close_connection, connection_closed);
|
|
});
|
|
}
|
|
|
|
// Wraps `sock` in a SocketStream configured with the given timeouts, overall
// time budget and start time, passes it to `callback`, and returns the
// callback's result.
inline bool process_client_socket(
    socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec,
    time_t write_timeout_sec, time_t write_timeout_usec,
    time_t max_timeout_msec,
    std::chrono::time_point<std::chrono::steady_clock> start_time,
    std::function<bool(Stream &)> callback) {
  SocketStream client_stream(sock, read_timeout_sec, read_timeout_usec,
                             write_timeout_sec, write_timeout_usec,
                             max_timeout_msec, start_time);
  const bool handled = callback(client_stream);
  return handled;
}
|
|
|
|
// Shuts down both directions of `sock` (SHUT_RDWR, or SD_BOTH on 64-bit
// Windows) and returns shutdown()'s result.
inline int shutdown_socket(socket_t sock) {
#ifndef _WIN64
  return shutdown(sock, SHUT_RDWR);
#else
  return shutdown(sock, SD_BOTH);
#endif
}
|
|
|
|
// Returns `s` with a leading NUL byte replaced by '@' when `s` is longer than
// one character; any other string is returned unchanged.
inline std::string escape_abstract_namespace_unix_domain(const std::string &s) {
  const bool has_nul_prefix = s.size() > 1 && s[0] == '\0';
  if (!has_nul_prefix) { return s; }

  std::string escaped(s);
  escaped.front() = '@';
  return escaped;
}
|
|
|
|
// Inverse of escape_abstract_namespace_unix_domain(): returns `s` with a
// leading '@' replaced by a NUL byte when `s` is longer than one character;
// any other string is returned unchanged.
inline std::string
unescape_abstract_namespace_unix_domain(const std::string &s) {
  const bool has_at_prefix = s.size() > 1 && s[0] == '@';
  if (!has_at_prefix) { return s; }

  std::string unescaped(s);
  unescaped.front() = '\0';
  return unescaped;
}
|
|
|
|
// getaddrinfo() with an optional upper bound on how long resolution may take.
//
// Parameters are those of getaddrinfo() plus `timeout_sec`. Returns 0 on
// success with `*res` set (release it with freeaddrinfo()), or a non-zero
// error code; EAI_AGAIN is returned when the timeout expires.
//
// The timeout is only honoured when CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO
// is defined and `timeout_sec` > 0; otherwise this is a plain getaddrinfo()
// call. Depending on the platform the bounded lookup uses GetAddrInfoExW
// (Windows), CFHost (macOS), getaddrinfo_a (glibc) or a helper thread.
inline int getaddrinfo_with_timeout(const char *node, const char *service,
                                    const struct addrinfo *hints,
                                    struct addrinfo **res, time_t timeout_sec) {
#ifdef CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO
  if (timeout_sec <= 0) {
    // No timeout specified, use standard getaddrinfo
    return getaddrinfo(node, service, hints, res);
  }

#ifdef _WIN64
  // Windows-specific implementation using GetAddrInfoEx with overlapped I/O
  OVERLAPPED overlapped = {0};
  HANDLE event = CreateEventW(nullptr, TRUE, FALSE, nullptr);
  if (!event) { return EAI_FAIL; }

  overlapped.hEvent = event;

  PADDRINFOEXW result_addrinfo = nullptr;
  HANDLE cancel_handle = nullptr;

  ADDRINFOEXW hints_ex = {0};
  if (hints) {
    hints_ex.ai_flags = hints->ai_flags;
    hints_ex.ai_family = hints->ai_family;
    hints_ex.ai_socktype = hints->ai_socktype;
    hints_ex.ai_protocol = hints->ai_protocol;
  }

  auto wnode = u8string_to_wstring(node);
  auto wservice = u8string_to_wstring(service);

  auto ret = ::GetAddrInfoExW(wnode.data(), wservice.data(), NS_DNS, nullptr,
                              hints ? &hints_ex : nullptr, &result_addrinfo,
                              nullptr, &overlapped, nullptr, &cancel_handle);

  if (ret == WSA_IO_PENDING) {
    auto wait_result =
        ::WaitForSingleObject(event, static_cast<DWORD>(timeout_sec * 1000));
    if (wait_result == WAIT_TIMEOUT) {
      if (cancel_handle) { ::GetAddrInfoExCancel(&cancel_handle); }
      ::CloseHandle(event);
      return EAI_AGAIN;
    }

    DWORD bytes_returned;
    if (!::GetOverlappedResult((HANDLE)INVALID_SOCKET, &overlapped,
                               &bytes_returned, FALSE)) {
      ::CloseHandle(event);
      return ::WSAGetLastError();
    }
  }

  ::CloseHandle(event);

  if (ret == NO_ERROR || ret == WSA_IO_PENDING) {
    *res = reinterpret_cast<struct addrinfo *>(result_addrinfo);
    return 0;
  }

  return ret;
#elif defined(TARGET_OS_OSX)
  // macOS implementation using CFHost API for asynchronous DNS resolution
  CFStringRef hostname_ref = CFStringCreateWithCString(
      kCFAllocatorDefault, node, kCFStringEncodingUTF8);
  if (!hostname_ref) { return EAI_MEMORY; }

  CFHostRef host_ref = CFHostCreateWithName(kCFAllocatorDefault, hostname_ref);
  CFRelease(hostname_ref);
  if (!host_ref) { return EAI_MEMORY; }

  // Set up context for callback
  struct CFHostContext {
    bool completed = false;
    bool success = false;
    CFArrayRef addresses = nullptr;
    std::mutex mutex;
    std::condition_variable cv;
  } context;

  CFHostClientContext client_context;
  memset(&client_context, 0, sizeof(client_context));
  client_context.info = &context;

  // Set callback
  auto callback = [](CFHostRef theHost, CFHostInfoType /*typeInfo*/,
                     const CFStreamError *error, void *info) {
    auto ctx = static_cast<CFHostContext *>(info);
    std::lock_guard<std::mutex> lock(ctx->mutex);

    if (error && error->error != 0) {
      ctx->success = false;
    } else {
      Boolean hasBeenResolved;
      ctx->addresses = CFHostGetAddressing(theHost, &hasBeenResolved);
      if (ctx->addresses && hasBeenResolved) {
        CFRetain(ctx->addresses);
        ctx->success = true;
      } else {
        ctx->success = false;
      }
    }
    ctx->completed = true;
    ctx->cv.notify_one();
  };

  if (!CFHostSetClient(host_ref, callback, &client_context)) {
    CFRelease(host_ref);
    return EAI_SYSTEM;
  }

  // Schedule on run loop
  CFRunLoopRef run_loop = CFRunLoopGetCurrent();
  CFHostScheduleWithRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode);

  // Start resolution
  CFStreamError stream_error;
  if (!CFHostStartInfoResolution(host_ref, kCFHostAddresses, &stream_error)) {
    CFHostUnscheduleFromRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode);
    CFRelease(host_ref);
    return EAI_FAIL;
  }

  // Wait for completion with timeout
  auto timeout_time =
      std::chrono::steady_clock::now() + std::chrono::seconds(timeout_sec);
  bool timed_out = false;

  {
    std::unique_lock<std::mutex> lock(context.mutex);

    while (!context.completed) {
      auto now = std::chrono::steady_clock::now();
      if (now >= timeout_time) {
        timed_out = true;
        break;
      }

      // Run the runloop for a short time
      lock.unlock();
      CFRunLoopRunInMode(kCFRunLoopDefaultMode, 0.1, true);
      lock.lock();
    }
  }

  // Clean up
  CFHostUnscheduleFromRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode);
  CFHostSetClient(host_ref, nullptr, nullptr);

  if (timed_out || !context.completed) {
    CFHostCancelInfoResolution(host_ref, kCFHostAddresses);
    CFRelease(host_ref);
    return EAI_AGAIN;
  }

  if (!context.success || !context.addresses) {
    CFRelease(host_ref);
    return EAI_NODATA;
  }

  // Convert CFArray to addrinfo
  CFIndex count = CFArrayGetCount(context.addresses);
  if (count == 0) {
    CFRelease(context.addresses);
    CFRelease(host_ref);
    return EAI_NODATA;
  }

  struct addrinfo *result_addrinfo = nullptr;
  struct addrinfo **current = &result_addrinfo;

  for (CFIndex i = 0; i < count; i++) {
    CFDataRef addr_data =
        static_cast<CFDataRef>(CFArrayGetValueAtIndex(context.addresses, i));
    if (!addr_data) continue;

    const struct sockaddr *sockaddr_ptr =
        reinterpret_cast<const struct sockaddr *>(CFDataGetBytePtr(addr_data));
    socklen_t sockaddr_len = static_cast<socklen_t>(CFDataGetLength(addr_data));

    // Allocate addrinfo structure
    *current = static_cast<struct addrinfo *>(malloc(sizeof(struct addrinfo)));
    if (!*current) {
      freeaddrinfo(result_addrinfo);
      CFRelease(context.addresses);
      CFRelease(host_ref);
      return EAI_MEMORY;
    }

    memset(*current, 0, sizeof(struct addrinfo));

    // Set up addrinfo fields
    (*current)->ai_family = sockaddr_ptr->sa_family;
    (*current)->ai_socktype = hints ? hints->ai_socktype : SOCK_STREAM;
    (*current)->ai_protocol = hints ? hints->ai_protocol : IPPROTO_TCP;
    (*current)->ai_addrlen = sockaddr_len;

    // Copy sockaddr
    (*current)->ai_addr = static_cast<struct sockaddr *>(malloc(sockaddr_len));
    if (!(*current)->ai_addr) {
      freeaddrinfo(result_addrinfo);
      CFRelease(context.addresses);
      CFRelease(host_ref);
      return EAI_MEMORY;
    }
    memcpy((*current)->ai_addr, sockaddr_ptr, sockaddr_len);

    // Set port if service is specified
    if (service && strlen(service) > 0) {
      int port = atoi(service);
      if (port > 0) {
        if (sockaddr_ptr->sa_family == AF_INET) {
          reinterpret_cast<struct sockaddr_in *>((*current)->ai_addr)
              ->sin_port = htons(static_cast<uint16_t>(port));
        } else if (sockaddr_ptr->sa_family == AF_INET6) {
          reinterpret_cast<struct sockaddr_in6 *>((*current)->ai_addr)
              ->sin6_port = htons(static_cast<uint16_t>(port));
        }
      }
    }

    current = &((*current)->ai_next);
  }

  CFRelease(context.addresses);
  CFRelease(host_ref);

  *res = result_addrinfo;
  return 0;
#elif defined(_GNU_SOURCE) && defined(__GLIBC__) &&                            \
    (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 2))
  // Linux implementation using getaddrinfo_a for asynchronous DNS resolution
  struct gaicb request;
  struct gaicb *requests[1] = {&request};
  struct sigevent sevp;
  struct timespec timeout;

  // Initialize the request structure
  memset(&request, 0, sizeof(request));
  request.ar_name = node;
  request.ar_service = service;
  request.ar_request = hints;

  // Set up timeout
  timeout.tv_sec = timeout_sec;
  timeout.tv_nsec = 0;

  // Initialize sigevent structure (not used, but required)
  memset(&sevp, 0, sizeof(sevp));
  sevp.sigev_notify = SIGEV_NONE;

  // Start asynchronous resolution
  int start_result = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp);
  if (start_result != 0) { return start_result; }

  // Wait for completion with timeout
  int wait_result =
      gai_suspend((const struct gaicb *const *)requests, 1, &timeout);

  if (wait_result == 0) {
    // Completed successfully, get the result
    int gai_result = gai_error(&request);
    if (gai_result == 0) {
      *res = request.ar_result;
      return 0;
    } else {
      // Clean up on error
      if (request.ar_result) { freeaddrinfo(request.ar_result); }
      return gai_result;
    }
  } else if (wait_result == EAI_AGAIN) {
    // Timeout occurred, cancel the request
    // NOTE(review): `request` lives on this stack frame; if gai_cancel()
    // cannot cancel an in-flight lookup it may still be referenced after we
    // return - confirm against the glibc documentation.
    gai_cancel(&request);
    return EAI_AGAIN;
  } else {
    // Other error occurred
    gai_cancel(&request);
    return wait_result;
  }
#else
  // Fallback implementation using thread-based timeout for other Unix systems.
  //
  // Everything the resolver thread touches lives in a heap-allocated state
  // block shared through std::shared_ptr. On timeout the thread is detached
  // and may outlive this call, so it must never reference this stack frame
  // or the caller's `node`/`service`/`hints` storage.
  struct ResolveState {
    std::mutex mutex;
    std::condition_variable cv;
    bool completed = false;
    bool abandoned = false; // set by the waiter when it gives up
    int result = EAI_SYSTEM;
    struct addrinfo *info = nullptr;

    // Private copies of the inputs (null pointers are remembered as flags).
    bool has_node = false;
    bool has_service = false;
    bool has_hints = false;
    std::string node;
    std::string service;
    struct addrinfo hints;
  };

  auto state = std::make_shared<ResolveState>();
  if (node) {
    state->has_node = true;
    state->node = node;
  }
  if (service) {
    state->has_service = true;
    state->service = service;
  }
  if (hints) {
    state->has_hints = true;
    state->hints = *hints;
  }

  std::thread resolve_thread([state]() {
    struct addrinfo *info = nullptr;
    auto thread_result =
        getaddrinfo(state->has_node ? state->node.c_str() : nullptr,
                    state->has_service ? state->service.c_str() : nullptr,
                    state->has_hints ? &state->hints : nullptr, &info);

    std::lock_guard<std::mutex> lock(state->mutex);
    if (state->abandoned) {
      // Nobody is waiting any more: release the result instead of leaking it.
      if (thread_result == 0 && info) { freeaddrinfo(info); }
      return;
    }
    state->result = thread_result;
    state->info = info;
    state->completed = true;
    state->cv.notify_one();
  });

  // Wait for completion or timeout
  std::unique_lock<std::mutex> lock(state->mutex);
  auto finished = state->cv.wait_for(lock, std::chrono::seconds(timeout_sec),
                                     [&] { return state->completed; });

  if (finished) {
    // Operation completed within timeout
    lock.unlock();
    resolve_thread.join();
    *res = state->info;
    return state->result;
  }

  // Timeout occurred: hand ownership of any late result to the thread and
  // let it finish in the background.
  state->abandoned = true;
  lock.unlock();
  resolve_thread.detach();
  return EAI_AGAIN; // Return timeout error
#endif
#else
  (void)(timeout_sec); // Unused parameter for non-blocking getaddrinfo
  return getaddrinfo(node, service, hints, res);
#endif
}
|
|
|
|
template <typename BindOrConnect>
|
|
socket_t create_socket(const std::string &host, const std::string &ip, int port,
|
|
int address_family, int socket_flags, bool tcp_nodelay,
|
|
bool ipv6_v6only, SocketOptions socket_options,
|
|
BindOrConnect bind_or_connect, time_t timeout_sec = 0) {
|
|
// Get address info
|
|
const char *node = nullptr;
|
|
struct addrinfo hints;
|
|
struct addrinfo *result;
|
|
|
|
memset(&hints, 0, sizeof(struct addrinfo));
|
|
hints.ai_socktype = SOCK_STREAM;
|
|
hints.ai_protocol = IPPROTO_IP;
|
|
|
|
if (!ip.empty()) {
|
|
node = ip.c_str();
|
|
// Ask getaddrinfo to convert IP in c-string to address
|
|
hints.ai_family = AF_UNSPEC;
|
|
hints.ai_flags = AI_NUMERICHOST;
|
|
} else {
|
|
if (!host.empty()) { node = host.c_str(); }
|
|
hints.ai_family = address_family;
|
|
hints.ai_flags = socket_flags;
|
|
}
|
|
|
|
#if !defined(_WIN64) || defined(CPPHTTPLIB_HAVE_AFUNIX_H)
|
|
if (hints.ai_family == AF_UNIX) {
|
|
const auto addrlen = host.length();
|
|
if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; }
|
|
|
|
#ifdef SOCK_CLOEXEC
|
|
auto sock = socket(hints.ai_family, hints.ai_socktype | SOCK_CLOEXEC,
|
|
hints.ai_protocol);
|
|
#else
|
|
auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol);
|
|
#endif
|
|
|
|
if (sock != INVALID_SOCKET) {
|
|
sockaddr_un addr{};
|
|
addr.sun_family = AF_UNIX;
|
|
|
|
auto unescaped_host = unescape_abstract_namespace_unix_domain(host);
|
|
std::copy(unescaped_host.begin(), unescaped_host.end(), addr.sun_path);
|
|
|
|
hints.ai_addr = reinterpret_cast<sockaddr *>(&addr);
|
|
hints.ai_addrlen = static_cast<socklen_t>(
|
|
sizeof(addr) - sizeof(addr.sun_path) + addrlen);
|
|
|
|
#ifndef SOCK_CLOEXEC
|
|
#ifndef _WIN64
|
|
fcntl(sock, F_SETFD, FD_CLOEXEC);
|
|
#endif
|
|
#endif
|
|
|
|
if (socket_options) { socket_options(sock); }
|
|
|
|
#ifdef _WIN64
|
|
// Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so
|
|
// remove the option.
|
|
detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0);
|
|
#endif
|
|
|
|
bool dummy;
|
|
if (!bind_or_connect(sock, hints, dummy)) {
|
|
close_socket(sock);
|
|
sock = INVALID_SOCKET;
|
|
}
|
|
}
|
|
return sock;
|
|
}
|
|
#endif
|
|
|
|
auto service = std::to_string(port);
|
|
|
|
if (getaddrinfo_with_timeout(node, service.c_str(), &hints, &result,
|
|
timeout_sec)) {
|
|
#if defined __linux__ && !defined __ANDROID__
|
|
res_init();
|
|
#endif
|
|
return INVALID_SOCKET;
|
|
}
|
|
auto se = detail::scope_exit([&] { freeaddrinfo(result); });
|
|
|
|
for (auto rp = result; rp; rp = rp->ai_next) {
|
|
// Create a socket
|
|
#ifdef _WIN64
|
|
auto sock =
|
|
WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0,
|
|
WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED);
|
|
/**
|
|
* Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1
|
|
* and above the socket creation fails on older Windows Systems.
|
|
*
|
|
* Let's try to create a socket the old way in this case.
|
|
*
|
|
* Reference:
|
|
* https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa
|
|
*
|
|
* WSA_FLAG_NO_HANDLE_INHERIT:
|
|
* This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with
|
|
* SP1, and later
|
|
*
|
|
*/
|
|
if (sock == INVALID_SOCKET) {
|
|
sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
|
|
}
|
|
#else
|
|
|
|
#ifdef SOCK_CLOEXEC
|
|
auto sock =
|
|
socket(rp->ai_family, rp->ai_socktype | SOCK_CLOEXEC, rp->ai_protocol);
|
|
#else
|
|
auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
|
|
#endif
|
|
|
|
#endif
|
|
if (sock == INVALID_SOCKET) { continue; }
|
|
|
|
#if !defined _WIN64 && !defined SOCK_CLOEXEC
|
|
if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) {
|
|
close_socket(sock);
|
|
continue;
|
|
}
|
|
#endif
|
|
|
|
if (tcp_nodelay) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); }
|
|
|
|
if (rp->ai_family == AF_INET6) {
|
|
set_socket_opt(sock, IPPROTO_IPV6, IPV6_V6ONLY, ipv6_v6only ? 1 : 0);
|
|
}
|
|
|
|
if (socket_options) { socket_options(sock); }
|
|
|
|
// bind or connect
|
|
auto quit = false;
|
|
if (bind_or_connect(sock, *rp, quit)) { return sock; }
|
|
|
|
close_socket(sock);
|
|
|
|
if (quit) { break; }
|
|
}
|
|
|
|
return INVALID_SOCKET;
|
|
}
|
|
|
|
// Switches `sock` between non-blocking (`nonblocking` == true) and blocking
// mode: via the O_NONBLOCK file status flag, or ioctlsocket(FIONBIO) on
// 64-bit Windows. Failures of the underlying calls are not reported.
inline void set_nonblocking(socket_t sock, bool nonblocking) {
#ifndef _WIN64
  const auto current = fcntl(sock, F_GETFL, 0);
  const auto updated =
      nonblocking ? (current | O_NONBLOCK) : (current & (~O_NONBLOCK));
  fcntl(sock, F_SETFL, updated);
#else
  unsigned long mode = nonblocking ? 1UL : 0UL;
  ioctlsocket(sock, FIONBIO, &mode);
#endif
}
|
|
|
|
// Inspects the last socket error and reports whether it is a real failure.
// Returns false only for the "operation still in progress" code
// (WSAEWOULDBLOCK on Windows, EINPROGRESS elsewhere); true for anything else.
inline bool is_connection_error() {
#ifdef _WIN64
  const auto last_error = WSAGetLastError();
  return !(last_error == WSAEWOULDBLOCK);
#else
  const auto last_error = errno;
  return !(last_error == EINPROGRESS);
#endif
}
|
|
|
|
// Binds `sock` to the local address named by `host`.
//
// `host` is resolved as a stream address of any family with service "0",
// and each resolved address is tried in order until one bind() succeeds.
// Returns true on the first successful bind, false if resolution fails or
// no address can be bound.
inline bool bind_ip_address(socket_t sock, const std::string &host) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(struct addrinfo));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_protocol = 0;

  // NOTE(review): the trailing 0 is presumably "no timeout" - confirm against
  // getaddrinfo_with_timeout.
  struct addrinfo *candidates;
  if (getaddrinfo_with_timeout(host.c_str(), "0", &hints, &candidates, 0)) {
    return false;
  }

  // Free the resolver result on every exit path below.
  auto release = detail::scope_exit([&] { freeaddrinfo(candidates); });

  auto cur = candidates;
  while (cur) {
    const auto rc =
        ::bind(sock, cur->ai_addr, static_cast<socklen_t>(cur->ai_addrlen));
    if (rc == 0) { return true; }
    cur = cur->ai_next;
  }

  return false;
}
|
|
|
|
// Enable the interface-name-to-IP lookup (if2ip) on every platform except
// Win64, Android, AIX and z/OS (__MVS__).
#if !(defined(_WIN64) || defined(ANDROID) || defined(_AIX) ||                  \
      defined(__MVS__))
#define USE_IF2IP
#endif
|
|
|
|
#ifdef USE_IF2IP
|
|
inline std::string if2ip(int address_family, const std::string &ifn) {
|
|
struct ifaddrs *ifap;
|
|
getifaddrs(&ifap);
|
|
auto se = detail::scope_exit([&] { freeifaddrs(ifap); });
|
|
|
|
std::string addr_candidate;
|
|
for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) {
|
|
if (ifa->ifa_addr && ifn == ifa->ifa_name &&
|
|
(AF_UNSPEC == address_family ||
|
|
ifa->ifa_addr->sa_family == address_family)) {
|
|
if (ifa->ifa_addr->sa_family == AF_INET) {
|
|
auto sa = reinterpret_cast<struct sockaddr_in *>(ifa->ifa_addr);
|
|
char buf[INET_ADDRSTRLEN];
|
|
if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) {
|
|
return std::string(buf, INET_ADDRSTRLEN);
|
|
}
|
|
} else if (ifa->ifa_addr->sa_family == AF_INET6) {
|
|
auto sa = reinterpret_cast<struct sockaddr_in6 *>(ifa->ifa_addr);
|
|
if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) {
|
|
char buf[INET6_ADDRSTRLEN] = {};
|
|
if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) {
|
|
// equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL
|
|
auto s6_addr_head = sa->sin6_addr.s6_addr[0];
|
|
if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) {
|
|
addr_candidate = std::string(buf, INET6_ADDRSTRLEN);
|
|
} else {
|
|
return std::string(buf, INET6_ADDRSTRLEN);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return addr_candidate;
|
|
}
|
|
#endif
|
|
|
|
// Creates a client socket connected to `host`/`ip`:`port`.
//
// The work is delegated to create_socket(), which calls the connect step
// below for each candidate socket/address it produces. The connect step:
//   1. binds the socket to the interface `intf` when one is given (only on
//      builds where USE_IF2IP is defined; otherwise `intf` is ignored),
//   2. starts a non-blocking connect and waits up to the connection timeout,
//   3. restores blocking mode and installs the read/write timeouts.
//
// `error` receives Error::Success on success, or the reason for the failure;
// a connection timeout additionally stops create_socket() from trying further
// addresses. Returns the connected socket, or INVALID_SOCKET on failure.
inline socket_t create_client_socket(
    const std::string &host, const std::string &ip, int port,
    int address_family, bool tcp_nodelay, bool ipv6_v6only,
    SocketOptions socket_options, time_t connection_timeout_sec,
    time_t connection_timeout_usec, time_t read_timeout_sec,
    time_t read_timeout_usec, time_t write_timeout_sec,
    time_t write_timeout_usec, const std::string &intf, Error &error) {
  auto connect_step = [&](socket_t fd, struct addrinfo &ai,
                          bool &quit) -> bool {
    if (!intf.empty()) {
#ifdef USE_IF2IP
      // Translate the interface name to an address; fall back to treating
      // `intf` itself as the address when the lookup yields nothing.
      auto bind_addr = if2ip(address_family, intf);
      if (bind_addr.empty()) { bind_addr = intf; }
      if (!bind_ip_address(fd, bind_addr)) {
        error = Error::BindIPAddress;
        return false;
      }
#endif
    }

    set_nonblocking(fd, true);

    const auto rc =
        ::connect(fd, ai.ai_addr, static_cast<socklen_t>(ai.ai_addrlen));

    if (rc < 0) {
      if (is_connection_error()) {
        error = Error::Connection;
        return false;
      }

      // The connect is still in progress: wait for it to finish.
      error = wait_until_socket_is_ready(fd, connection_timeout_sec,
                                         connection_timeout_usec);
      if (error == Error::ConnectionTimeout) { quit = true; }
      if (error != Error::Success) { return false; }
    }

    set_nonblocking(fd, false);
    set_socket_opt_time(fd, SOL_SOCKET, SO_RCVTIMEO, read_timeout_sec,
                        read_timeout_usec);
    set_socket_opt_time(fd, SOL_SOCKET, SO_SNDTIMEO, write_timeout_sec,
                        write_timeout_usec);

    error = Error::Success;
    return true;
  };

  auto sock = create_socket(host, ip, port, address_family, 0, tcp_nodelay,
                            ipv6_v6only, std::move(socket_options),
                            connect_step,
                            connection_timeout_sec); // Pass DNS timeout

  if (sock == INVALID_SOCKET) {
    // Make sure a failure never leaves `error` reporting success.
    if (error == Error::Success) { error = Error::Connection; }
  } else {
    error = Error::Success;
  }

  return sock;
}
|
|
|
|
// Extracts the numeric IP string and the port from a socket address.
//
// addr / addr_len: the address to decode; only AF_INET and AF_INET6 are
//                  supported.
// ip:   receives the numeric host string on success.
// port: receives the port in host byte order. It is written as soon as the
//       family is recognised, even if the later name lookup fails.
//
// Returns false for any other address family or when getnameinfo() fails.
inline bool get_ip_and_port(const struct sockaddr_storage &addr,
                            socklen_t addr_len, std::string &ip, int &port) {
  switch (addr.ss_family) {
  case AF_INET: {
    const auto v4 = reinterpret_cast<const struct sockaddr_in *>(&addr);
    port = ntohs(v4->sin_port);
    break;
  }
  case AF_INET6: {
    const auto v6 = reinterpret_cast<const struct sockaddr_in6 *>(&addr);
    port = ntohs(v6->sin6_port);
    break;
  }
  default: return false;
  }

  std::array<char, NI_MAXHOST> text{};
  const auto rc = getnameinfo(
      reinterpret_cast<const struct sockaddr *>(&addr), addr_len, text.data(),
      static_cast<socklen_t>(text.size()), nullptr, 0, NI_NUMERICHOST);
  if (rc != 0) { return false; }

  ip = text.data();
  return true;
}
|
|
|
|
// Fills `ip` and `port` with the local address `sock` is bound to.
// Both outputs are left untouched when getsockname() fails.
inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) {
  struct sockaddr_storage local;
  socklen_t local_len = sizeof(local);

  const auto rc = getsockname(
      sock, reinterpret_cast<struct sockaddr *>(&local), &local_len);
  if (rc != 0) { return; }

  get_ip_and_port(local, local_len, ip, port);
}
|
|
|
|
// Fills `ip` and `port` with the address of the peer connected to `sock`.
//
// Nothing is written when getpeername() fails. For a UNIX domain socket
// (non-Windows builds) there is no IP/port pair; instead `port` receives the
// peer's process id where the platform can report it, and `ip` is left
// untouched.
inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) {
  struct sockaddr_storage peer;
  socklen_t peer_len = sizeof(peer);

  if (getpeername(sock, reinterpret_cast<struct sockaddr *>(&peer),
                  &peer_len) != 0) {
    return;
  }

#ifndef _WIN64
  if (peer.ss_family == AF_UNIX) {
#if defined(__linux__)
    // Linux: peer credentials carry the pid.
    struct ucred cred;
    socklen_t cred_len = sizeof(cred);
    if (getsockopt(sock, SOL_SOCKET, SO_PEERCRED, &cred, &cred_len) == 0) {
      port = cred.pid;
    }
#elif defined(SOL_LOCAL) && defined(SO_PEERPID)
    // Platforms exposing SO_PEERPID at the SOL_LOCAL level.
    pid_t peer_pid;
    socklen_t pid_len = sizeof(peer_pid);
    if (getsockopt(sock, SOL_LOCAL, SO_PEERPID, &peer_pid, &pid_len) == 0) {
      port = peer_pid;
    }
#endif
    return;
  }
#endif

  get_ip_and_port(peer, peer_len, ip, port);
}
|
|
|
|
// Compile-time string hash used to switch on string values.
//
// Folds the `l` characters starting at `s` into the running hash `h`, one
// character per recursion step: h' = (mask & (h * 33)) ^ c, where `mask`
// clears the 6 most significant bits of the product. Returns `h` unchanged
// once `l` reaches 0. Written as a single return expression so it stays a
// valid C++11 constexpr function.
inline constexpr unsigned int str2tag_core(const char *s, size_t l,
                                           unsigned int h) {
  return l != 0
             ? str2tag_core(
                   s + 1, l - 1,
                   (((std::numeric_limits<unsigned int>::max)() >> 6) &
                    (h * 33)) ^
                       static_cast<unsigned char>(*s))
             : h;
}
|
|
|
|
inline unsigned int str2tag(const std::string &s) {
|
|
return str2tag_core(s.data(), s.size(), 0);
|
|
}
|
|
|
|
namespace udl {
|
|
|
|
inline constexpr unsigned int operator""_t(const char *s, size_t l) {
|
|
return str2tag_core(s, l, 0);
|
|
}
|
|
|
|
} // namespace udl
|
|
|
|
inline std::string
|
|
find_content_type(const std::string &path,
|
|
const std::map<std::string, std::string> &user_data,
|
|
const std::string &default_content_type) {
|
|
auto ext = file_extension(path);
|
|
|
|
auto it = user_data.find(ext);
|
|
if (it != user_data.end()) { return it->second; }
|
|
|
|
using udl::operator""_t;
|
|
|
|
switch (str2tag(ext)) {
|
|
default: return default_content_type;
|
|
|
|
case "css"_t: return "text/css";
|
|
case "csv"_t: return "text/csv";
|
|
case "htm"_t:
|
|
case "html"_t: return "text/html";
|
|
case "js"_t:
|
|
case "mjs"_t: return "text/javascript";
|
|
case "txt"_t: return "text/plain";
|
|
case "vtt"_t: return "text/vtt";
|
|
|
|
case "apng"_t: return "image/apng";
|
|
case "avif"_t: return "image/avif";
|
|
case "bmp"_t: return "image/bmp";
|
|
case "gif"_t: return "image/gif";
|
|
case "png"_t: return "image/png";
|
|
case "svg"_t: return "image/svg+xml";
|
|
case "webp"_t: return "image/webp";
|
|
case "ico"_t: return "image/x-icon";
|
|
case "tif"_t: return "image/tiff";
|
|
case "tiff"_t: return "image/tiff";
|
|
case "jpg"_t:
|
|
case "jpeg"_t: return "image/jpeg";
|
|
|
|
case "mp4"_t: return "video/mp4";
|
|
case "mpeg"_t: return "video/mpeg";
|
|
case "webm"_t: return "video/webm";
|
|
|
|
case "mp3"_t: return "audio/mp3";
|
|
case "mpga"_t: return "audio/mpeg";
|
|
case "weba"_t: return "audio/webm";
|
|
case "wav"_t: return "audio/wave";
|
|
|
|
case "otf"_t: return "font/otf";
|
|
case "ttf"_t: return "font/ttf";
|
|
case "woff"_t: return "font/woff";
|
|
case "woff2"_t: return "font/woff2";
|
|
|
|
case "7z"_t: return "application/x-7z-compressed";
|
|
case "atom"_t: return "application/atom+xml";
|
|
case "pdf"_t: return "application/pdf";
|
|
case "json"_t: return "application/json";
|
|
case "rss"_t: return "application/rss+xml";
|
|
case "tar"_t: return "application/x-tar";
|
|
case "xht"_t:
|
|
case "xhtml"_t: return "application/xhtml+xml";
|
|
case "xslt"_t: return "application/xslt+xml";
|
|
case "xml"_t: return "application/xml";
|
|
case "gz"_t: return "application/gzip";
|
|
case "zip"_t: return "application/zip";
|
|
case "wasm"_t: return "application/wasm";
|
|
}
|
|
}
|
|
|
|
inline bool can_compress_content_type(const std::string &content_type) {
|
|
using udl::operator""_t;
|
|
|
|
auto tag = str2tag(content_type);
|
|
|
|
switch (tag) {
|
|
case "image/svg+xml"_t:
|
|
case "application/javascript"_t:
|
|
case "application/json"_t:
|
|
case "application/xml"_t:
|
|
case "application/protobuf"_t:
|
|
case "application/xhtml+xml"_t: return true;
|
|
|
|
case "text/event-stream"_t: return false;
|
|
|
|
default: return !content_type.rfind("text/", 0);
|
|
}
|
|
}
|
|
|
|
// Chooses the content encoding for a response.
//
// Returns EncodingType::None when the response's Content-Type is not
// compressible. Otherwise the request's Accept-Encoding header is searched,
// in the order brotli, gzip, zstd, for the first encoding that is both
// compiled in and mentioned. Matching is a plain substring search, so
// quality values such as "br;q=0" are not honoured (see the TODOs).
inline EncodingType encoding_type(const Request &req, const Response &res) {
  if (!detail::can_compress_content_type(
          res.get_header_value("Content-Type"))) {
    return EncodingType::None;
  }

  const auto &accepted = req.get_header_value("Accept-Encoding");
  (void)(accepted); // unused when no compression backend is compiled in

#ifdef CPPHTTPLIB_BROTLI_SUPPORT
  // TODO: 'Accept-Encoding' has br, not br;q=0
  if (accepted.find("br") != std::string::npos) {
    return EncodingType::Brotli;
  }
#endif

#ifdef CPPHTTPLIB_ZLIB_SUPPORT
  // TODO: 'Accept-Encoding' has gzip, not gzip;q=0
  if (accepted.find("gzip") != std::string::npos) {
    return EncodingType::Gzip;
  }
#endif

#ifdef CPPHTTPLIB_ZSTD_SUPPORT
  // TODO: 'Accept-Encoding' has zstd, not zstd;q=0
  if (accepted.find("zstd") != std::string::npos) {
    return EncodingType::Zstd;
  }
#endif

  return EncodingType::None;
}
|
|
|
|
inline bool nocompressor::compress(const char *data, size_t data_length,
|
|
bool /*last*/, Callback callback) {
|
|
if (!data_length) { return true; }
|
|
return callback(data, data_length);
|
|
}
|
|
|
|
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
|
// Zeroes the zlib stream, selects zlib's default allocator and initialises a
// deflate stream. is_valid_ records whether the initialisation succeeded.
inline gzip_compressor::gzip_compressor() {
  std::memset(&strm_, 0, sizeof(strm_));
  strm_.zalloc = Z_NULL;
  strm_.zfree = Z_NULL;
  strm_.opaque = Z_NULL;

  // windowBits 31 (= 15 + 16) asks zlib for gzip framing with the largest
  // window, per the zlib manual; memLevel is 8.
  const auto rc = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31,
                               8, Z_DEFAULT_STRATEGY);
  is_valid_ = (rc == Z_OK);
}
|
|
|
|
// Tears down the deflate stream via deflateEnd().
inline gzip_compressor::~gzip_compressor() {
  deflateEnd(&strm_);
}
|
|
|
|
inline bool gzip_compressor::compress(const char *data, size_t data_length,
|
|
bool last, Callback callback) {
|
|
assert(is_valid_);
|
|
|
|
do {
|
|
constexpr size_t max_avail_in =
|
|
(std::numeric_limits<decltype(strm_.avail_in)>::max)();
|
|
|
|
strm_.avail_in = static_cast<decltype(strm_.avail_in)>(
|
|
(std::min)(data_length, max_avail_in));
|
|
strm_.next_in = const_cast<Bytef *>(reinterpret_cast<const Bytef *>(data));
|
|
|
|
data_length -= strm_.avail_in;
|
|
data += strm_.avail_in;
|
|
|
|
auto flush = (last && data_length == 0) ? Z_FINISH : Z_NO_FLUSH;
|
|
auto ret = Z_OK;
|
|
|
|
std::array<char, CPPHTTPLIB_COMPRESSION_BUFSIZ> buff{};
|
|
do {
|
|
strm_.avail_out = static_cast<uInt>(buff.size());
|
|
strm_.next_out = reinterpret_cast<Bytef *>(buff.data());
|
|
|
|
ret = deflate(&strm_, flush);
|
|
if (ret == Z_STREAM_ERROR) { return false; }
|
|
|
|
if (!callback(buff.data(), buff.size() - strm_.avail_out)) {
|
|
return false;
|
|
}
|
|
} while (strm_.avail_out == 0);
|
|
|
|
assert((flush == Z_FINISH && ret == Z_STREAM_END) ||
|
|
(flush == Z_NO_FLUSH && ret == Z_OK));
|
|
assert(strm_.avail_in == 0);
|
|
} while (data_length > 0);
|
|
|
|
return true;
|
|
}
|
|
|
|
// Zeroes the zlib stream, selects zlib's default allocator and initialises an
// inflate stream. is_valid_ records whether the initialisation succeeded.
inline gzip_decompressor::gzip_decompressor() {
  std::memset(&strm_, 0, sizeof(strm_));
  strm_.zalloc = Z_NULL;
  strm_.zfree = Z_NULL;
  strm_.opaque = Z_NULL;

  // wbits is 15, the maximum, so that any gzip stream can be decoded. The
  // extra 32 makes zlib detect automatically whether the stream is gzip or
  // deflate.
  const auto rc = inflateInit2(&strm_, 32 + 15);
  is_valid_ = (rc == Z_OK);
}
|
|
|
|
// Tears down the inflate stream via inflateEnd().
inline gzip_decompressor::~gzip_decompressor() {
  inflateEnd(&strm_);
}
|
|
|
|
// Reports whether the constructor managed to initialise the inflate stream.
inline bool gzip_decompressor::is_valid() const {
  return is_valid_;
}
|
|
|
|
inline bool gzip_decompressor::decompress(const char *data, size_t data_length,
|
|
Callback callback) {
|
|
assert(is_valid_);
|
|
|
|
auto ret = Z_OK;
|
|
|
|
do {
|
|
constexpr size_t max_avail_in =
|
|
(std::numeric_limits<decltype(strm_.avail_in)>::max)();
|
|
|
|
strm_.avail_in = static_cast<decltype(strm_.avail_in)>(
|
|
(std::min)(data_length, max_avail_in));
|
|
strm_.next_in = const_cast<Bytef *>(reinterpret_cast<const Bytef *>(data));
|
|
|
|
data_length -= strm_.avail_in;
|
|
data += strm_.avail_in;
|
|
|
|
std::array<char, CPPHTTPLIB_COMPRESSION_BUFSIZ> buff{};
|
|
while (strm_.avail_in > 0 && ret == Z_OK) {
|
|
strm_.avail_out = static_cast<uInt>(buff.size());
|
|
strm_.next_out = reinterpret_cast<Bytef *>(buff.data());
|
|
|
|
ret = inflate(&strm_, Z_NO_FLUSH);
|
|
|
|
assert(ret != Z_STREAM_ERROR);
|
|
switch (ret) {
|
|
case Z_NEED_DICT:
|
|
case Z_DATA_ERROR:
|
|
case Z_MEM_ERROR: inflateEnd(&strm_); return false;
|
|
}
|
|
|
|
if (!callback(buff.data(), buff.size() - strm_.avail_out)) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
if (ret != Z_OK && ret != Z_STREAM_END) { return false; }
|
|
|
|
} while (data_length > 0);
|
|
|
|
return true;
|
|
}
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
|
|
// Creates the Brotli encoder instance; all three arguments are left null.
// NOTE(review): per the Brotli API this should select the default allocator -
// confirm.
inline brotli_compressor::brotli_compressor() {
  auto *encoder = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr);
  state_ = encoder;
}
|
|
|
|
// Destroys the Brotli encoder instance created by the constructor.
inline brotli_compressor::~brotli_compressor() { BrotliEncoderDestroyInstance(state_); }
|
|
|
|
inline bool brotli_compressor::compress(const char *data, size_t data_length,
|
|
bool last, Callback callback) {
|
|
std::array<uint8_t, CPPHTTPLIB_COMPRESSION_BUFSIZ> buff{};
|
|
|
|
auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS;
|
|
auto available_in = data_length;
|
|
auto next_in = reinterpret_cast<const uint8_t *>(data);
|
|
|
|
for (;;) {
|
|
if (last) {
|
|
if (BrotliEncoderIsFinished(state_)) { break; }
|
|
} else {
|
|
if (!available_in) { break; }
|
|
}
|
|
|
|
auto available_out = buff.size();
|
|
auto next_out = buff.data();
|
|
|
|
if (!BrotliEncoderCompressStream(state_, operation, &available_in, &next_in,
|
|
&available_out, &next_out, nullptr)) {
|
|
return false;
|
|
}
|
|
|
|
auto output_bytes = buff.size() - available_out;
|
|
if (output_bytes) {
|
|
callback(reinterpret_cast<const char *>(buff.data()), output_bytes);
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Creates the Brotli decoder. The initial decoder result is "needs more
// input" when creation succeeded and "error" when it did not.
inline brotli_decompressor::brotli_decompressor() {
  decoder_s = BrotliDecoderCreateInstance(0, 0, 0);
  if (decoder_s) {
    decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT;
  } else {
    decoder_r = BROTLI_DECODER_RESULT_ERROR;
  }
}
|
|
|
|
// Destroys the Brotli decoder, if one was created.
inline brotli_decompressor::~brotli_decompressor() {
  if (!decoder_s) { return; }
  BrotliDecoderDestroyInstance(decoder_s);
}
|
|
|
|
// True when the constructor obtained a decoder instance.
inline bool brotli_decompressor::is_valid() const {
  return decoder_s != nullptr;
}
|
|
|
|
inline bool brotli_decompressor::decompress(const char *data,
|
|
size_t data_length,
|
|
Callback callback) {
|
|
if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS ||
|
|
decoder_r == BROTLI_DECODER_RESULT_ERROR) {
|
|
return 0;
|
|
}
|
|
|
|
auto next_in = reinterpret_cast<const uint8_t *>(data);
|
|
size_t avail_in = data_length;
|
|
size_t total_out;
|
|
|
|
decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT;
|
|
|
|
std::array<char, CPPHTTPLIB_COMPRESSION_BUFSIZ> buff{};
|
|
while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) {
|
|
char *next_out = buff.data();
|
|
size_t avail_out = buff.size();
|
|
|
|
decoder_r = BrotliDecoderDecompressStream(
|
|
decoder_s, &avail_in, &next_in, &avail_out,
|
|
reinterpret_cast<uint8_t **>(&next_out), &total_out);
|
|
|
|
if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { return false; }
|
|
|
|
if (!callback(buff.data(), buff.size() - avail_out)) { return false; }
|
|
}
|
|
|
|
return decoder_r == BROTLI_DECODER_RESULT_SUCCESS ||
|
|
decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT;
|
|
}
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
|
|
// Creates the zstd compression context and sets its compression level.
// NOTE(review): ZSTD_fast is passed as the *level* value here - confirm this
// is the intended constant.
inline zstd_compressor::zstd_compressor() {
  ctx_ = ZSTD_createCCtx();

  ZSTD_CCtx_setParameter(ctx_, ZSTD_c_compressionLevel, ZSTD_fast);
}
|
|
|
|
// Frees the zstd compression context.
inline zstd_compressor::~zstd_compressor() {
  ZSTD_freeCCtx(ctx_);
}
|
|
|
|
inline bool zstd_compressor::compress(const char *data, size_t data_length,
|
|
bool last, Callback callback) {
|
|
std::array<char, CPPHTTPLIB_COMPRESSION_BUFSIZ> buff{};
|
|
|
|
ZSTD_EndDirective mode = last ? ZSTD_e_end : ZSTD_e_continue;
|
|
ZSTD_inBuffer input = {data, data_length, 0};
|
|
|
|
bool finished;
|
|
do {
|
|
ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0};
|
|
size_t const remaining = ZSTD_compressStream2(ctx_, &output, &input, mode);
|
|
|
|
if (ZSTD_isError(remaining)) { return false; }
|
|
|
|
if (!callback(buff.data(), output.pos)) { return false; }
|
|
|
|
finished = last ? (remaining == 0) : (input.pos == input.size);
|
|
|
|
} while (!finished);
|
|
|
|
return true;
|
|
}
|
|
|
|
// Creates the zstd decompression context.
inline zstd_decompressor::zstd_decompressor() {
  ctx_ = ZSTD_createDCtx();
}
|
|
|
|
// Frees the zstd decompression context.
inline zstd_decompressor::~zstd_decompressor() {
  ZSTD_freeDCtx(ctx_);
}
|
|
|
|
// True when the constructor obtained a decompression context.
inline bool zstd_decompressor::is_valid() const {
  return !(ctx_ == nullptr);
}
|
|
|
|
inline bool zstd_decompressor::decompress(const char *data, size_t data_length,
|
|
Callback callback) {
|
|
std::array<char, CPPHTTPLIB_COMPRESSION_BUFSIZ> buff{};
|
|
ZSTD_inBuffer input = {data, data_length, 0};
|
|
|
|
while (input.pos < input.size) {
|
|
ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0};
|
|
size_t const remaining = ZSTD_decompressStream(ctx_, &output, &input);
|
|
|
|
if (ZSTD_isError(remaining)) { return false; }
|
|
|
|
if (!callback(buff.data(), output.pos)) { return false; }
|
|
}
|
|
|
|
return true;
|
|
}
|
|
#endif
|
|
|
|
inline bool has_header(const Headers &headers, const std::string &key) {
|
|
return headers.find(key) != headers.end();
|
|
}
|
|
|
|
inline const char *get_header_value(const Headers &headers,
|
|
const std::string &key, const char *def,
|
|
size_t id) {
|
|
auto rng = headers.equal_range(key);
|
|
auto it = rng.first;
|
|
std::advance(it, static_cast<ssize_t>(id));
|
|
if (it != rng.second) { return it->second.c_str(); }
|
|
return def;
|
|
}
|
|
|
|
template <typename T>
|
|
inline bool parse_header(const char *beg, const char *end, T fn) {
|
|
// Skip trailing spaces and tabs.
|
|
while (beg < end && is_space_or_tab(end[-1])) {
|
|
end--;
|
|
}
|
|
|
|
auto p = beg;
|
|
while (p < end && *p != ':') {
|
|
p++;
|
|
}
|
|
|
|
auto name = std::string(beg, p);
|
|
if (!detail::fields::is_field_name(name)) { return false; }
|
|
|
|
if (p == end) { return false; }
|
|
|
|
auto key_end = p;
|
|
|
|
if (*p++ != ':') { return false; }
|
|
|
|
while (p < end && is_space_or_tab(*p)) {
|
|
p++;
|
|
}
|
|
|
|
if (p <= end) {
|
|
auto key_len = key_end - beg;
|
|
if (!key_len) { return false; }
|
|
|
|
auto key = std::string(beg, key_end);
|
|
auto val = std::string(p, end);
|
|
|
|
if (!detail::fields::is_field_value(val)) { return false; }
|
|
|
|
if (case_ignore::equal(key, "Location") ||
|
|
case_ignore::equal(key, "Referer")) {
|
|
fn(key, val);
|
|
} else {
|
|
fn(key, decode_path(val, false));
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
// Reads header lines from `strm` into `headers` until the blank line that
// ends the header section.
//
// Returns true once that blank line has been read. Returns false when the
// stream ends first, when a line is longer than CPPHTTPLIB_HEADER_MAX_LENGTH,
// when more than CPPHTTPLIB_HEADER_MAX_COUNT headers arrive, or when a line
// fails to parse. Lines that do not end in CRLF are skipped, unless
// CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR is defined, in which case a bare LF
// terminator is accepted.
inline bool read_headers(Stream &strm, Headers &headers) {
  const auto bufsiz = 2048;
  char buf[bufsiz];
  stream_line_reader line_reader(strm, buf, bufsiz);

  size_t header_count = 0;

  while (line_reader.getline()) {
    // Length of the terminator to strip: 2 for CRLF, 1 for bare LF.
    auto line_terminator_len = 2;
    if (line_reader.end_with_crlf()) {
      // Blank line indicates end of headers.
      if (line_reader.size() == 2) { return true; }
    } else {
#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR
      // Blank line indicates end of headers.
      if (line_reader.size() == 1) { return true; }
      line_terminator_len = 1;
#else
      continue; // Skip invalid line.
#endif
    }

    if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; }

    // Check header count limit
    if (header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { return false; }

    // Exclude line terminator
    const auto line_beg = line_reader.ptr();
    const auto line_end = line_beg + line_reader.size() - line_terminator_len;

    const auto parsed =
        parse_header(line_beg, line_end,
                     [&](const std::string &key, const std::string &val) {
                       headers.emplace(key, val);
                     });
    if (!parsed) { return false; }

    header_count++;
  }

  // The stream ended before the blank line was seen.
  return false;
}
|
|
|
|
// Reads exactly `len` body bytes from `strm`, forwarding each received piece
// to `out(buf, n, offset, len)`.
//
// After every piece `progress(received, len)` is called when a progress
// callback is set. Returns false as soon as a read yields no data or fails,
// or when `out` or `progress` returns false; true once `len` bytes have been
// delivered.
inline bool read_content_with_length(Stream &strm, size_t len,
                                     DownloadProgress progress,
                                     ContentReceiverWithProgress out) {
  char buf[CPPHTTPLIB_RECV_BUFSIZ];

  for (size_t received = 0; received < len;) {
    auto remaining = static_cast<size_t>(len - received);
    const auto n =
        strm.read(buf, (std::min)(remaining, CPPHTTPLIB_RECV_BUFSIZ));
    if (n <= 0) { return false; }

    const auto got = static_cast<size_t>(n);
    if (!out(buf, got, received, len)) { return false; }
    received += got;

    if (progress && !progress(received, len)) { return false; }
  }

  return true;
}
|
|
|
|
// Reads and discards up to `len` bytes from `strm`. Stops early, silently,
// when a read yields no data or fails.
inline void skip_content_with_length(Stream &strm, size_t len) {
  char buf[CPPHTTPLIB_RECV_BUFSIZ];

  for (size_t discarded = 0; discarded < len;) {
    auto remaining = static_cast<size_t>(len - discarded);
    const auto n =
        strm.read(buf, (std::min)(remaining, CPPHTTPLIB_RECV_BUFSIZ));
    if (n <= 0) { return; }
    discarded += static_cast<size_t>(n);
  }
}
|
|
|
|
// Outcome of reading a message body.
enum class ReadContentResult {
  // Successfully read the content
  Success,

  // The content exceeds the specified payload limit
  PayloadTooLarge,

  // An error occurred while reading the content
  Error
};
|
|
|
|
// Reads a body of unknown length from `strm` until a read returns 0,
// forwarding each piece to `out(buf, n, offset, 0)`.
//
// Returns Success when a read returns 0, Error on a read failure or when
// `out` returns false, and PayloadTooLarge as soon as accepting the next
// piece would take the total beyond `payload_max_length`.
inline ReadContentResult
read_content_without_length(Stream &strm, size_t payload_max_length,
                            ContentReceiverWithProgress out) {
  char buf[CPPHTTPLIB_RECV_BUFSIZ];
  size_t total = 0;

  while (true) {
    const auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ);
    if (n == 0) { return ReadContentResult::Success; }
    if (n < 0) { return ReadContentResult::Error; }

    const auto piece = static_cast<size_t>(n);

    // Check if adding this data would exceed the payload limit. Written as a
    // subtraction so the comparison itself cannot overflow.
    const bool over_limit =
        total > payload_max_length || payload_max_length - total < piece;
    if (over_limit) { return ReadContentResult::PayloadTooLarge; }

    if (!out(buf, piece, total, 0)) { return ReadContentResult::Error; }
    total += piece;
  }

  return ReadContentResult::Success;
}
|
|
|
|
template <typename T>
|
|
inline ReadContentResult read_content_chunked(Stream &strm, T &x,
|
|
size_t payload_max_length,
|
|
ContentReceiverWithProgress out) {
|
|
const auto bufsiz = 16;
|
|
char buf[bufsiz];
|
|
|
|
stream_line_reader line_reader(strm, buf, bufsiz);
|
|
|
|
if (!line_reader.getline()) { return ReadContentResult::Error; }
|
|
|
|
unsigned long chunk_len;
|
|
size_t total_len = 0;
|
|
while (true) {
|
|
char *end_ptr;
|
|
|
|
chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16);
|
|
|
|
if (end_ptr == line_reader.ptr()) { return ReadContentResult::Error; }
|
|
if (chunk_len == ULONG_MAX) { return ReadContentResult::Error; }
|
|
|
|
if (chunk_len == 0) { break; }
|
|
|
|
// Check if adding this chunk would exceed the payload limit
|
|
if (total_len > payload_max_length ||
|
|
payload_max_length - total_len < chunk_len) {
|
|
return ReadContentResult::PayloadTooLarge;
|
|
}
|
|
|
|
total_len += chunk_len;
|
|
|
|
if (!read_content_with_length(strm, chunk_len, nullptr, out)) {
|
|
return ReadContentResult::Error;
|
|
}
|
|
|
|
if (!line_reader.getline()) { return ReadContentResult::Error; }
|
|
|
|
if (strcmp(line_reader.ptr(), "\r\n") != 0) {
|
|
return ReadContentResult::Error;
|
|
}
|
|
|
|
if (!line_reader.getline()) { return ReadContentResult::Error; }
|
|
}
|
|
|
|
assert(chunk_len == 0);
|
|
|
|
// NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked
|
|
// transfer coding is complete when a chunk with a chunk-size of zero is
|
|
// received, possibly followed by a trailer section, and finally terminated by
|
|
// an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1
|
|
//
|
|
// In '7.1.3. Decoding Chunked', however, the pseudo-code in the section
|
|
// does't care for the existence of the final CRLF. In other words, it seems
|
|
// to be ok whether the final CRLF exists or not in the chunked data.
|
|
// https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3
|
|
//
|
|
// According to the reference code in RFC 9112, cpp-httplib now allows
|
|
// chunked transfer coding data without the final CRLF.
|
|
if (!line_reader.getline()) { return ReadContentResult::Success; }
|
|
|
|
// RFC 7230 Section 4.1.2 - Headers prohibited in trailers
|
|
thread_local case_ignore::unordered_set<std::string> prohibited_trailers = {
|
|
// Message framing
|
|
"transfer-encoding", "content-length",
|
|
|
|
// Routing
|
|
"host",
|
|
|
|
// Authentication
|
|
"authorization", "www-authenticate", "proxy-authenticate",
|
|
"proxy-authorization", "cookie", "set-cookie",
|
|
|
|
// Request modifiers
|
|
"cache-control", "expect", "max-forwards", "pragma", "range", "te",
|
|
|
|
// Response control
|
|
"age", "expires", "date", "location", "retry-after", "vary", "warning",
|
|
|
|
// Payload processing
|
|
"content-encoding", "content-type", "content-range", "trailer"};
|
|
|
|
// Parse declared trailer headers once for performance
|
|
case_ignore::unordered_set<std::string> declared_trailers;
|
|
if (has_header(x.headers, "Trailer")) {
|
|
auto trailer_header = get_header_value(x.headers, "Trailer", "", 0);
|
|
auto len = std::strlen(trailer_header);
|
|
|
|
split(trailer_header, trailer_header + len, ',',
|
|
[&](const char *b, const char *e) {
|
|
std::string key(b, e);
|
|
if (prohibited_trailers.find(key) == prohibited_trailers.end()) {
|
|
declared_trailers.insert(key);
|
|
}
|
|
});
|
|
}
|
|
|
|
size_t trailer_header_count = 0;
|
|
while (strcmp(line_reader.ptr(), "\r\n") != 0) {
|
|
if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) {
|
|
return ReadContentResult::Error;
|
|
}
|
|
|
|
// Check trailer header count limit
|
|
if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) {
|
|
return ReadContentResult::Error;
|
|
}
|
|
|
|
// Exclude line terminator
|
|
constexpr auto line_terminator_len = 2;
|
|
auto end = line_reader.ptr() + line_reader.size() - line_terminator_len;
|
|
|
|
parse_header(line_reader.ptr(), end,
|
|
[&](const std::string &key, const std::string &val) {
|
|
if (declared_trailers.find(key) != declared_trailers.end()) {
|
|
x.trailers.emplace(key, val);
|
|
trailer_header_count++;
|
|
}
|
|
});
|
|
|
|
if (!line_reader.getline()) { return ReadContentResult::Error; }
|
|
}
|
|
|
|
return ReadContentResult::Success;
|
|
}
|
|
|
|
inline bool is_chunked_transfer_encoding(const Headers &headers) {
|
|
return case_ignore::equal(
|
|
get_header_value(headers, "Transfer-Encoding", "", 0), "chunked");
|
|
}
|
|
|
|
template <typename T, typename U>
|
|
bool prepare_content_receiver(T &x, int &status,
|
|
ContentReceiverWithProgress receiver,
|
|
bool decompress, U callback) {
|
|
if (decompress) {
|
|
std::string encoding = x.get_header_value("Content-Encoding");
|
|
std::unique_ptr<decompressor> decompressor;
|
|
|
|
if (encoding == "gzip" || encoding == "deflate") {
|
|
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
|
decompressor = detail::make_unique<gzip_decompressor>();
|
|
#else
|
|
status = StatusCode::UnsupportedMediaType_415;
|
|
return false;
|
|
#endif
|
|
} else if (encoding.find("br") != std::string::npos) {
|
|
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
|
|
decompressor = detail::make_unique<brotli_decompressor>();
|
|
#else
|
|
status = StatusCode::UnsupportedMediaType_415;
|
|
return false;
|
|
#endif
|
|
} else if (encoding == "zstd") {
|
|
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
|
|
decompressor = detail::make_unique<zstd_decompressor>();
|
|
#else
|
|
status = StatusCode::UnsupportedMediaType_415;
|
|
return false;
|
|
#endif
|
|
}
|
|
|
|
if (decompressor) {
|
|
if (decompressor->is_valid()) {
|
|
ContentReceiverWithProgress out = [&](const char *buf, size_t n,
|
|
size_t off, size_t len) {
|
|
return decompressor->decompress(buf, n,
|
|
[&](const char *buf2, size_t n2) {
|
|
return receiver(buf2, n2, off, len);
|
|
});
|
|
};
|
|
return callback(std::move(out));
|
|
} else {
|
|
status = StatusCode::InternalServerError_500;
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
ContentReceiverWithProgress out = [&](const char *buf, size_t n, size_t off,
|
|
size_t len) {
|
|
return receiver(buf, n, off, len);
|
|
};
|
|
return callback(std::move(out));
|
|
}
|
|
|
|
template <typename T>
|
|
bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
|
|
DownloadProgress progress,
|
|
ContentReceiverWithProgress receiver, bool decompress) {
|
|
return prepare_content_receiver(
|
|
x, status, std::move(receiver), decompress,
|
|
[&](const ContentReceiverWithProgress &out) {
|
|
auto ret = true;
|
|
auto exceed_payload_max_length = false;
|
|
|
|
if (is_chunked_transfer_encoding(x.headers)) {
|
|
auto result = read_content_chunked(strm, x, payload_max_length, out);
|
|
if (result == ReadContentResult::Success) {
|
|
ret = true;
|
|
} else if (result == ReadContentResult::PayloadTooLarge) {
|
|
exceed_payload_max_length = true;
|
|
ret = false;
|
|
} else {
|
|
ret = false;
|
|
}
|
|
} else if (!has_header(x.headers, "Content-Length")) {
|
|
auto result =
|
|
read_content_without_length(strm, payload_max_length, out);
|
|
if (result == ReadContentResult::Success) {
|
|
ret = true;
|
|
} else if (result == ReadContentResult::PayloadTooLarge) {
|
|
exceed_payload_max_length = true;
|
|
ret = false;
|
|
} else {
|
|
ret = false;
|
|
}
|
|
} else {
|
|
auto is_invalid_value = false;
|
|
auto len = get_header_value_u64(x.headers, "Content-Length",
|
|
(std::numeric_limits<size_t>::max)(),
|
|
0, is_invalid_value);
|
|
|
|
if (is_invalid_value) {
|
|
ret = false;
|
|
} else if (len > payload_max_length) {
|
|
exceed_payload_max_length = true;
|
|
skip_content_with_length(strm, len);
|
|
ret = false;
|
|
} else if (len > 0) {
|
|
ret = read_content_with_length(strm, len, std::move(progress), out);
|
|
}
|
|
}
|
|
|
|
if (!ret) {
|
|
status = exceed_payload_max_length ? StatusCode::PayloadTooLarge_413
|
|
: StatusCode::BadRequest_400;
|
|
}
|
|
return ret;
|
|
});
|
|
}
|
|
|
|
// Writes the request line "<method> <path> HTTP/1.1\r\n" to `strm` and
// returns the result of the stream's write().
inline ssize_t write_request_line(Stream &strm, const std::string &method,
                                  const std::string &path) {
  const std::string line = method + " " + path + " HTTP/1.1\r\n";
  return strm.write(line.data(), line.size());
}
|
|
|
|
// Writes the status line "HTTP/1.1 <status> <reason>\r\n" to `strm`, taking
// the reason phrase from httplib::status_message(), and returns the result of
// the stream's write().
inline ssize_t write_response_line(Stream &strm, int status) {
  const std::string line = "HTTP/1.1 " + std::to_string(status) + " " +
                           httplib::status_message(status) + "\r\n";
  return strm.write(line.data(), line.size());
}
|
|
|
|
// Writes every header as "<name>: <value>\r\n" followed by the blank line
// that ends the header section.
//
// Returns the sum of the values returned by the individual writes, or the
// first negative write result encountered.
inline ssize_t write_headers(Stream &strm, const Headers &headers) {
  ssize_t total = 0;

  for (const auto &field : headers) {
    const std::string line = field.first + ": " + field.second + "\r\n";

    const auto written = strm.write(line.data(), line.size());
    if (written < 0) { return written; }
    total += written;
  }

  // Blank line terminating the header section.
  const auto written = strm.write("\r\n");
  if (written < 0) { return written; }

  return total + written;
}
|
|
|
|
// Writes all `l` bytes starting at `d` to `strm`, repeating the write until
// everything has been accepted. Returns false as soon as a write reports a
// negative result, true once all bytes are written.
inline bool write_data(Stream &strm, const char *d, size_t l) {
  for (size_t sent = 0; sent < l;) {
    const auto n = strm.write(d + sent, l - sent);
    if (n < 0) { return false; }
    sent += static_cast<size_t>(n);
  }
  return true;
}
|
|
|
|
template <typename T>
|
|
inline bool write_content_with_progress(Stream &strm,
|
|
const ContentProvider &content_provider,
|
|
size_t offset, size_t length,
|
|
T is_shutting_down,
|
|
const UploadProgress &upload_progress,
|
|
Error &error) {
|
|
size_t end_offset = offset + length;
|
|
size_t start_offset = offset;
|
|
auto ok = true;
|
|
DataSink data_sink;
|
|
|
|
data_sink.write = [&](const char *d, size_t l) -> bool {
|
|
if (ok) {
|
|
if (write_data(strm, d, l)) {
|
|
offset += l;
|
|
|
|
if (upload_progress && length > 0) {
|
|
size_t current_written = offset - start_offset;
|
|
if (!upload_progress(current_written, length)) {
|
|
ok = false;
|
|
return false;
|
|
}
|
|
}
|
|
} else {
|
|
ok = false;
|
|
}
|
|
}
|
|
return ok;
|
|
};
|
|
|
|
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
|
|
|
|
while (offset < end_offset && !is_shutting_down()) {
|
|
if (!strm.wait_writable()) {
|
|
error = Error::Write;
|
|
return false;
|
|
} else if (!content_provider(offset, end_offset - offset, data_sink)) {
|
|
error = Error::Canceled;
|
|
return false;
|
|
} else if (!ok) {
|
|
error = Error::Write;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
error = Error::Success;
|
|
return true;
|
|
}
|
|
|
|
template <typename T>
|
|
inline bool write_content(Stream &strm, const ContentProvider &content_provider,
|
|
size_t offset, size_t length, T is_shutting_down,
|
|
Error &error) {
|
|
return write_content_with_progress<T>(strm, content_provider, offset, length,
|
|
is_shutting_down, nullptr, error);
|
|
}
|
|
|
|
template <typename T>
|
|
inline bool write_content(Stream &strm, const ContentProvider &content_provider,
|
|
size_t offset, size_t length,
|
|
const T &is_shutting_down) {
|
|
auto error = Error::Success;
|
|
return write_content(strm, content_provider, offset, length, is_shutting_down,
|
|
error);
|
|
}
|
|
|
|
template <typename T>
|
|
inline bool
|
|
write_content_without_length(Stream &strm,
|
|
const ContentProvider &content_provider,
|
|
const T &is_shutting_down) {
|
|
size_t offset = 0;
|
|
auto data_available = true;
|
|
auto ok = true;
|
|
DataSink data_sink;
|
|
|
|
data_sink.write = [&](const char *d, size_t l) -> bool {
|
|
if (ok) {
|
|
offset += l;
|
|
if (!write_data(strm, d, l)) { ok = false; }
|
|
}
|
|
return ok;
|
|
};
|
|
|
|
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
|
|
|
|
data_sink.done = [&](void) { data_available = false; };
|
|
|
|
while (data_available && !is_shutting_down()) {
|
|
if (!strm.wait_writable()) {
|
|
return false;
|
|
} else if (!content_provider(offset, 0, data_sink)) {
|
|
return false;
|
|
} else if (!ok) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <typename T, typename U>
|
|
inline bool
|
|
write_content_chunked(Stream &strm, const ContentProvider &content_provider,
|
|
const T &is_shutting_down, U &compressor, Error &error) {
|
|
size_t offset = 0;
|
|
auto data_available = true;
|
|
auto ok = true;
|
|
DataSink data_sink;
|
|
|
|
data_sink.write = [&](const char *d, size_t l) -> bool {
|
|
if (ok) {
|
|
data_available = l > 0;
|
|
offset += l;
|
|
|
|
std::string payload;
|
|
if (compressor.compress(d, l, false,
|
|
[&](const char *data, size_t data_len) {
|
|
payload.append(data, data_len);
|
|
return true;
|
|
})) {
|
|
if (!payload.empty()) {
|
|
// Emit chunked response header and footer for each chunk
|
|
auto chunk =
|
|
from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
|
|
if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; }
|
|
}
|
|
} else {
|
|
ok = false;
|
|
}
|
|
}
|
|
return ok;
|
|
};
|
|
|
|
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
|
|
|
|
auto done_with_trailer = [&](const Headers *trailer) {
|
|
if (!ok) { return; }
|
|
|
|
data_available = false;
|
|
|
|
std::string payload;
|
|
if (!compressor.compress(nullptr, 0, true,
|
|
[&](const char *data, size_t data_len) {
|
|
payload.append(data, data_len);
|
|
return true;
|
|
})) {
|
|
ok = false;
|
|
return;
|
|
}
|
|
|
|
if (!payload.empty()) {
|
|
// Emit chunked response header and footer for each chunk
|
|
auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n";
|
|
if (!write_data(strm, chunk.data(), chunk.size())) {
|
|
ok = false;
|
|
return;
|
|
}
|
|
}
|
|
|
|
constexpr const char done_marker[] = "0\r\n";
|
|
if (!write_data(strm, done_marker, str_len(done_marker))) { ok = false; }
|
|
|
|
// Trailer
|
|
if (trailer) {
|
|
for (const auto &kv : *trailer) {
|
|
std::string field_line = kv.first + ": " + kv.second + "\r\n";
|
|
if (!write_data(strm, field_line.data(), field_line.size())) {
|
|
ok = false;
|
|
}
|
|
}
|
|
}
|
|
|
|
constexpr const char crlf[] = "\r\n";
|
|
if (!write_data(strm, crlf, str_len(crlf))) { ok = false; }
|
|
};
|
|
|
|
data_sink.done = [&](void) { done_with_trailer(nullptr); };
|
|
|
|
data_sink.done_with_trailer = [&](const Headers &trailer) {
|
|
done_with_trailer(&trailer);
|
|
};
|
|
|
|
while (data_available && !is_shutting_down()) {
|
|
if (!strm.wait_writable()) {
|
|
error = Error::Write;
|
|
return false;
|
|
} else if (!content_provider(offset, 0, data_sink)) {
|
|
error = Error::Canceled;
|
|
return false;
|
|
} else if (!ok) {
|
|
error = Error::Write;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
error = Error::Success;
|
|
return true;
|
|
}
|
|
|
|
template <typename T, typename U>
|
|
inline bool write_content_chunked(Stream &strm,
|
|
const ContentProvider &content_provider,
|
|
const T &is_shutting_down, U &compressor) {
|
|
auto error = Error::Success;
|
|
return write_content_chunked(strm, content_provider, is_shutting_down,
|
|
compressor, error);
|
|
}
|
|
|
|
template <typename T>
|
|
inline bool redirect(T &cli, Request &req, Response &res,
|
|
const std::string &path, const std::string &location,
|
|
Error &error) {
|
|
Request new_req = req;
|
|
new_req.path = path;
|
|
new_req.redirect_count_ -= 1;
|
|
|
|
if (res.status == StatusCode::SeeOther_303 &&
|
|
(req.method != "GET" && req.method != "HEAD")) {
|
|
new_req.method = "GET";
|
|
new_req.body.clear();
|
|
new_req.headers.clear();
|
|
}
|
|
|
|
Response new_res;
|
|
|
|
auto ret = cli.send(new_req, new_res, error);
|
|
if (ret) {
|
|
req = new_req;
|
|
res = new_res;
|
|
|
|
if (res.location.empty()) { res.location = location; }
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
inline std::string params_to_query_str(const Params ¶ms) {
|
|
std::string query;
|
|
|
|
for (auto it = params.begin(); it != params.end(); ++it) {
|
|
if (it != params.begin()) { query += "&"; }
|
|
query += it->first;
|
|
query += "=";
|
|
query += httplib::encode_uri_component(it->second);
|
|
}
|
|
return query;
|
|
}
|
|
|
|
inline void parse_query_text(const char *data, std::size_t size,
|
|
Params ¶ms) {
|
|
std::set<std::string> cache;
|
|
split(data, data + size, '&', [&](const char *b, const char *e) {
|
|
std::string kv(b, e);
|
|
if (cache.find(kv) != cache.end()) { return; }
|
|
cache.insert(std::move(kv));
|
|
|
|
std::string key;
|
|
std::string val;
|
|
divide(b, static_cast<std::size_t>(e - b), '=',
|
|
[&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data,
|
|
std::size_t rhs_size) {
|
|
key.assign(lhs_data, lhs_size);
|
|
val.assign(rhs_data, rhs_size);
|
|
});
|
|
|
|
if (!key.empty()) {
|
|
params.emplace(decode_path(key, true), decode_path(val, true));
|
|
}
|
|
});
|
|
}
|
|
|
|
inline void parse_query_text(const std::string &s, Params ¶ms) {
|
|
parse_query_text(s.data(), s.size(), params);
|
|
}
|
|
|
|
// Extracts the `boundary=` parameter from a Content-Type value.
//
// The value runs from just after "boundary=" to the next ';' (or to the end
// of the string when there is none) and is passed through
// trim_double_quotes_copy. Returns false when the keyword is missing or the
// resulting boundary is empty.
inline bool parse_multipart_boundary(const std::string &content_type,
                                     std::string &boundary) {
  constexpr char keyword[] = "boundary=";

  const auto key_pos = content_type.find(keyword);
  if (key_pos == std::string::npos) { return false; }

  const auto value_beg = key_pos + (sizeof(keyword) - 1);
  const auto value_end = content_type.find(';', key_pos);

  // When no ';' follows, value_end is npos and substr() clamps the count to
  // the remainder of the string.
  boundary = trim_double_quotes_copy(
      content_type.substr(value_beg, value_end - value_beg));
  return !boundary.empty();
}
|
|
|
|
inline void parse_disposition_params(const std::string &s, Params ¶ms) {
|
|
std::set<std::string> cache;
|
|
split(s.data(), s.data() + s.size(), ';', [&](const char *b, const char *e) {
|
|
std::string kv(b, e);
|
|
if (cache.find(kv) != cache.end()) { return; }
|
|
cache.insert(kv);
|
|
|
|
std::string key;
|
|
std::string val;
|
|
split(b, e, '=', [&](const char *b2, const char *e2) {
|
|
if (key.empty()) {
|
|
key.assign(b2, e2);
|
|
} else {
|
|
val.assign(b2, e2);
|
|
}
|
|
});
|
|
|
|
if (!key.empty()) {
|
|
params.emplace(trim_double_quotes_copy((key)),
|
|
trim_double_quotes_copy((val)));
|
|
}
|
|
});
|
|
}
|
|
|
|
#ifdef CPPHTTPLIB_NO_EXCEPTIONS
|
|
inline bool parse_range_header(const std::string &s, Ranges &ranges) {
|
|
#else
|
|
inline bool parse_range_header(const std::string &s, Ranges &ranges) try {
|
|
#endif
|
|
auto is_valid = [](const std::string &str) {
|
|
return std::all_of(str.cbegin(), str.cend(),
|
|
[](unsigned char c) { return std::isdigit(c); });
|
|
};
|
|
|
|
if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) {
|
|
const auto pos = static_cast<size_t>(6);
|
|
const auto len = static_cast<size_t>(s.size() - 6);
|
|
auto all_valid_ranges = true;
|
|
split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) {
|
|
if (!all_valid_ranges) { return; }
|
|
|
|
const auto it = std::find(b, e, '-');
|
|
if (it == e) {
|
|
all_valid_ranges = false;
|
|
return;
|
|
}
|
|
|
|
const auto lhs = std::string(b, it);
|
|
const auto rhs = std::string(it + 1, e);
|
|
if (!is_valid(lhs) || !is_valid(rhs)) {
|
|
all_valid_ranges = false;
|
|
return;
|
|
}
|
|
|
|
const auto first =
|
|
static_cast<ssize_t>(lhs.empty() ? -1 : std::stoll(lhs));
|
|
const auto last =
|
|
static_cast<ssize_t>(rhs.empty() ? -1 : std::stoll(rhs));
|
|
if ((first == -1 && last == -1) ||
|
|
(first != -1 && last != -1 && first > last)) {
|
|
all_valid_ranges = false;
|
|
return;
|
|
}
|
|
|
|
ranges.emplace_back(first, last);
|
|
});
|
|
return all_valid_ranges && !ranges.empty();
|
|
}
|
|
return false;
|
|
#ifdef CPPHTTPLIB_NO_EXCEPTIONS
|
|
}
|
|
#else
|
|
} catch (...) { return false; }
|
|
#endif
|
|
|
|
inline bool parse_accept_header(const std::string &s,
|
|
std::vector<std::string> &content_types) {
|
|
content_types.clear();
|
|
|
|
// Empty string is considered valid (no preference)
|
|
if (s.empty()) { return true; }
|
|
|
|
// Check for invalid patterns: leading/trailing commas or consecutive commas
|
|
if (s.front() == ',' || s.back() == ',' ||
|
|
s.find(",,") != std::string::npos) {
|
|
return false;
|
|
}
|
|
|
|
struct AcceptEntry {
|
|
std::string media_type;
|
|
double quality;
|
|
int order; // Original order in header
|
|
};
|
|
|
|
std::vector<AcceptEntry> entries;
|
|
int order = 0;
|
|
bool has_invalid_entry = false;
|
|
|
|
// Split by comma and parse each entry
|
|
split(s.data(), s.data() + s.size(), ',', [&](const char *b, const char *e) {
|
|
std::string entry(b, e);
|
|
entry = trim_copy(entry);
|
|
|
|
if (entry.empty()) {
|
|
has_invalid_entry = true;
|
|
return;
|
|
}
|
|
|
|
AcceptEntry accept_entry;
|
|
accept_entry.quality = 1.0; // Default quality
|
|
accept_entry.order = order++;
|
|
|
|
// Find q= parameter
|
|
auto q_pos = entry.find(";q=");
|
|
if (q_pos == std::string::npos) { q_pos = entry.find("; q="); }
|
|
|
|
if (q_pos != std::string::npos) {
|
|
// Extract media type (before q parameter)
|
|
accept_entry.media_type = trim_copy(entry.substr(0, q_pos));
|
|
|
|
// Extract quality value
|
|
auto q_start = entry.find('=', q_pos) + 1;
|
|
auto q_end = entry.find(';', q_start);
|
|
if (q_end == std::string::npos) { q_end = entry.length(); }
|
|
|
|
std::string quality_str =
|
|
trim_copy(entry.substr(q_start, q_end - q_start));
|
|
if (quality_str.empty()) {
|
|
has_invalid_entry = true;
|
|
return;
|
|
}
|
|
|
|
#ifdef CPPHTTPLIB_NO_EXCEPTIONS
|
|
{
|
|
std::istringstream iss(quality_str);
|
|
iss >> accept_entry.quality;
|
|
|
|
// Check if conversion was successful and entire string was consumed
|
|
if (iss.fail() || !iss.eof()) {
|
|
has_invalid_entry = true;
|
|
return;
|
|
}
|
|
}
|
|
#else
|
|
try {
|
|
accept_entry.quality = std::stod(quality_str);
|
|
} catch (...) {
|
|
has_invalid_entry = true;
|
|
return;
|
|
}
|
|
#endif
|
|
// Check if quality is in valid range [0.0, 1.0]
|
|
if (accept_entry.quality < 0.0 || accept_entry.quality > 1.0) {
|
|
has_invalid_entry = true;
|
|
return;
|
|
}
|
|
} else {
|
|
// No quality parameter, use entire entry as media type
|
|
accept_entry.media_type = entry;
|
|
}
|
|
|
|
// Remove additional parameters from media type
|
|
auto param_pos = accept_entry.media_type.find(';');
|
|
if (param_pos != std::string::npos) {
|
|
accept_entry.media_type =
|
|
trim_copy(accept_entry.media_type.substr(0, param_pos));
|
|
}
|
|
|
|
// Basic validation of media type format
|
|
if (accept_entry.media_type.empty()) {
|
|
has_invalid_entry = true;
|
|
return;
|
|
}
|
|
|
|
// Check for basic media type format (should contain '/' or be '*')
|
|
if (accept_entry.media_type != "*" &&
|
|
accept_entry.media_type.find('/') == std::string::npos) {
|
|
has_invalid_entry = true;
|
|
return;
|
|
}
|
|
|
|
entries.push_back(accept_entry);
|
|
});
|
|
|
|
// Return false if any invalid entry was found
|
|
if (has_invalid_entry) { return false; }
|
|
|
|
// Sort by quality (descending), then by original order (ascending)
|
|
std::sort(entries.begin(), entries.end(),
|
|
[](const AcceptEntry &a, const AcceptEntry &b) {
|
|
if (a.quality != b.quality) {
|
|
return a.quality > b.quality; // Higher quality first
|
|
}
|
|
return a.order < b.order; // Earlier order first for same quality
|
|
});
|
|
|
|
// Extract sorted media types
|
|
content_types.reserve(entries.size());
|
|
for (const auto &entry : entries) {
|
|
content_types.push_back(entry.media_type);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
class FormDataParser {
|
|
public:
|
|
FormDataParser() = default;
|
|
|
|
void set_boundary(std::string &&boundary) {
|
|
boundary_ = boundary;
|
|
dash_boundary_crlf_ = dash_ + boundary_ + crlf_;
|
|
crlf_dash_boundary_ = crlf_ + dash_ + boundary_;
|
|
}
|
|
|
|
bool is_valid() const { return is_valid_; }
|
|
|
|
bool parse(const char *buf, size_t n, const FormDataHeader &header_callback,
|
|
const ContentReceiver &content_callback) {
|
|
|
|
buf_append(buf, n);
|
|
|
|
while (buf_size() > 0) {
|
|
switch (state_) {
|
|
case 0: { // Initial boundary
|
|
auto pos = buf_find(dash_boundary_crlf_);
|
|
if (pos == buf_size()) { return true; }
|
|
buf_erase(pos + dash_boundary_crlf_.size());
|
|
state_ = 1;
|
|
break;
|
|
}
|
|
case 1: { // New entry
|
|
clear_file_info();
|
|
state_ = 2;
|
|
break;
|
|
}
|
|
case 2: { // Headers
|
|
auto pos = buf_find(crlf_);
|
|
if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; }
|
|
while (pos < buf_size()) {
|
|
// Empty line
|
|
if (pos == 0) {
|
|
if (!header_callback(file_)) {
|
|
is_valid_ = false;
|
|
return false;
|
|
}
|
|
buf_erase(crlf_.size());
|
|
state_ = 3;
|
|
break;
|
|
}
|
|
|
|
const auto header = buf_head(pos);
|
|
|
|
if (!parse_header(header.data(), header.data() + header.size(),
|
|
[&](const std::string &, const std::string &) {})) {
|
|
is_valid_ = false;
|
|
return false;
|
|
}
|
|
|
|
// Parse and emplace space trimmed headers into a map
|
|
if (!parse_header(
|
|
header.data(), header.data() + header.size(),
|
|
[&](const std::string &key, const std::string &val) {
|
|
file_.headers.emplace(key, val);
|
|
})) {
|
|
is_valid_ = false;
|
|
return false;
|
|
}
|
|
|
|
constexpr const char header_content_type[] = "Content-Type:";
|
|
|
|
if (start_with_case_ignore(header, header_content_type)) {
|
|
file_.content_type =
|
|
trim_copy(header.substr(str_len(header_content_type)));
|
|
} else {
|
|
thread_local const std::regex re_content_disposition(
|
|
R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~",
|
|
std::regex_constants::icase);
|
|
|
|
std::smatch m;
|
|
if (std::regex_match(header, m, re_content_disposition)) {
|
|
Params params;
|
|
parse_disposition_params(m[1], params);
|
|
|
|
auto it = params.find("name");
|
|
if (it != params.end()) {
|
|
file_.name = it->second;
|
|
} else {
|
|
is_valid_ = false;
|
|
return false;
|
|
}
|
|
|
|
it = params.find("filename");
|
|
if (it != params.end()) { file_.filename = it->second; }
|
|
|
|
it = params.find("filename*");
|
|
if (it != params.end()) {
|
|
// Only allow UTF-8 encoding...
|
|
thread_local const std::regex re_rfc5987_encoding(
|
|
R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase);
|
|
|
|
std::smatch m2;
|
|
if (std::regex_match(it->second, m2, re_rfc5987_encoding)) {
|
|
file_.filename = decode_path(m2[1], false); // override...
|
|
} else {
|
|
is_valid_ = false;
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
buf_erase(pos + crlf_.size());
|
|
pos = buf_find(crlf_);
|
|
}
|
|
if (state_ != 3) { return true; }
|
|
break;
|
|
}
|
|
case 3: { // Body
|
|
if (crlf_dash_boundary_.size() > buf_size()) { return true; }
|
|
auto pos = buf_find(crlf_dash_boundary_);
|
|
if (pos < buf_size()) {
|
|
if (!content_callback(buf_data(), pos)) {
|
|
is_valid_ = false;
|
|
return false;
|
|
}
|
|
buf_erase(pos + crlf_dash_boundary_.size());
|
|
state_ = 4;
|
|
} else {
|
|
auto len = buf_size() - crlf_dash_boundary_.size();
|
|
if (len > 0) {
|
|
if (!content_callback(buf_data(), len)) {
|
|
is_valid_ = false;
|
|
return false;
|
|
}
|
|
buf_erase(len);
|
|
}
|
|
return true;
|
|
}
|
|
break;
|
|
}
|
|
case 4: { // Boundary
|
|
if (crlf_.size() > buf_size()) { return true; }
|
|
if (buf_start_with(crlf_)) {
|
|
buf_erase(crlf_.size());
|
|
state_ = 1;
|
|
} else {
|
|
if (dash_.size() > buf_size()) { return true; }
|
|
if (buf_start_with(dash_)) {
|
|
buf_erase(dash_.size());
|
|
is_valid_ = true;
|
|
buf_erase(buf_size()); // Remove epilogue
|
|
} else {
|
|
return true;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
private:
|
|
void clear_file_info() {
|
|
file_.name.clear();
|
|
file_.filename.clear();
|
|
file_.content_type.clear();
|
|
file_.headers.clear();
|
|
}
|
|
|
|
bool start_with_case_ignore(const std::string &a, const char *b) const {
|
|
const auto b_len = strlen(b);
|
|
if (a.size() < b_len) { return false; }
|
|
for (size_t i = 0; i < b_len; i++) {
|
|
if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
const std::string dash_ = "--";
|
|
const std::string crlf_ = "\r\n";
|
|
std::string boundary_;
|
|
std::string dash_boundary_crlf_;
|
|
std::string crlf_dash_boundary_;
|
|
|
|
size_t state_ = 0;
|
|
bool is_valid_ = false;
|
|
FormData file_;
|
|
|
|
// Buffer
|
|
bool start_with(const std::string &a, size_t spos, size_t epos,
|
|
const std::string &b) const {
|
|
if (epos - spos < b.size()) { return false; }
|
|
for (size_t i = 0; i < b.size(); i++) {
|
|
if (a[i + spos] != b[i]) { return false; }
|
|
}
|
|
return true;
|
|
}
|
|
|
|
size_t buf_size() const { return buf_epos_ - buf_spos_; }
|
|
|
|
const char *buf_data() const { return &buf_[buf_spos_]; }
|
|
|
|
std::string buf_head(size_t l) const { return buf_.substr(buf_spos_, l); }
|
|
|
|
bool buf_start_with(const std::string &s) const {
|
|
return start_with(buf_, buf_spos_, buf_epos_, s);
|
|
}
|
|
|
|
size_t buf_find(const std::string &s) const {
|
|
auto c = s.front();
|
|
|
|
size_t off = buf_spos_;
|
|
while (off < buf_epos_) {
|
|
auto pos = off;
|
|
while (true) {
|
|
if (pos == buf_epos_) { return buf_size(); }
|
|
if (buf_[pos] == c) { break; }
|
|
pos++;
|
|
}
|
|
|
|
auto remaining_size = buf_epos_ - pos;
|
|
if (s.size() > remaining_size) { return buf_size(); }
|
|
|
|
if (start_with(buf_, pos, buf_epos_, s)) { return pos - buf_spos_; }
|
|
|
|
off = pos + 1;
|
|
}
|
|
|
|
return buf_size();
|
|
}
|
|
|
|
void buf_append(const char *data, size_t n) {
|
|
auto remaining_size = buf_size();
|
|
if (remaining_size > 0 && buf_spos_ > 0) {
|
|
for (size_t i = 0; i < remaining_size; i++) {
|
|
buf_[i] = buf_[buf_spos_ + i];
|
|
}
|
|
}
|
|
buf_spos_ = 0;
|
|
buf_epos_ = remaining_size;
|
|
|
|
if (remaining_size + n > buf_.size()) { buf_.resize(remaining_size + n); }
|
|
|
|
for (size_t i = 0; i < n; i++) {
|
|
buf_[buf_epos_ + i] = data[i];
|
|
}
|
|
buf_epos_ += n;
|
|
}
|
|
|
|
void buf_erase(size_t size) { buf_spos_ += size; }
|
|
|
|
std::string buf_;
|
|
size_t buf_spos_ = 0;
|
|
size_t buf_epos_ = 0;
|
|
};
|
|
|
|
// Returns `length` characters drawn from [0-9A-Za-z] using a per-thread
// std::mt19937 that is seeded once from std::random_device.
inline std::string random_string(size_t length) {
  constexpr const char alphabet[] =
      "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
  constexpr size_t alphabet_size = sizeof(alphabet) - 1; // without the NUL

  thread_local std::mt19937 engine = []() {
    // std::random_device might actually be deterministic on some
    // platforms, but due to lack of support in the c++ standard library,
    // doing better requires either some ugly hacks or breaking portability.
    std::random_device seed_gen;
    // Request 128 bits of entropy for initialization
    std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()};
    return std::mt19937(seed_sequence);
  }();

  std::string result;
  for (size_t produced = 0; produced < length; ++produced) {
    result.push_back(alphabet[engine() % alphabet_size]);
  }
  return result;
}
|
|
|
|
inline std::string make_multipart_data_boundary() {
|
|
return "--cpp-httplib-multipart-data-" + detail::random_string(16);
|
|
}
|
|
|
|
// Returns true when every character of `boundary` is alphanumeric, '-' or
// '_'. An empty boundary is considered valid by this check.
//
// Characters are converted to unsigned char before being classified:
// passing a negative plain char to std::isalnum is undefined behaviour.
inline bool is_multipart_boundary_chars_valid(const std::string &boundary) {
  return std::all_of(boundary.begin(), boundary.end(), [](unsigned char c) {
    return std::isalnum(c) != 0 || c == '-' || c == '_';
  });
}
|
|
|
|
// Produces the opening of one multipart part: the boundary line, the
// Content-Disposition header (with filename when `item.filename` is set), an
// optional Content-Type header, and the blank line preceding the content.
// `item` must expose `name`, `filename` and `content_type` strings; their
// values are inserted verbatim.
template <typename T>
inline std::string
serialize_multipart_formdata_item_begin(const T &item,
                                        const std::string &boundary) {
  std::string head;

  head += "--" + boundary + "\r\n";

  head += "Content-Disposition: form-data; name=\"" + item.name + "\"";
  const bool has_filename = !item.filename.empty();
  if (has_filename) { head += "; filename=\"" + item.filename + "\""; }
  head += "\r\n";

  const bool has_content_type = !item.content_type.empty();
  if (has_content_type) {
    head += "Content-Type: " + item.content_type + "\r\n";
  }

  head += "\r\n";
  return head;
}
|
|
|
|
// Terminator written after the content of every multipart part.
inline std::string serialize_multipart_formdata_item_end() {
  return std::string("\r\n");
}
|
|
|
|
// Closing delimiter of a multipart body: "--<boundary>--\r\n".
inline std::string
serialize_multipart_formdata_finish(const std::string &boundary) {
  std::string closing = "--";
  closing += boundary;
  closing += "--\r\n";
  return closing;
}
|
|
|
|
// Content-Type header value announcing a multipart/form-data body that uses
// `boundary`.
inline std::string
serialize_multipart_formdata_get_content_type(const std::string &boundary) {
  std::string value = "multipart/form-data; boundary=";
  value += boundary;
  return value;
}
|
|
|
|
inline std::string
|
|
serialize_multipart_formdata(const UploadFormDataItems &items,
|
|
const std::string &boundary, bool finish = true) {
|
|
std::string body;
|
|
|
|
for (const auto &item : items) {
|
|
body += serialize_multipart_formdata_item_begin(item, boundary);
|
|
body += item.content + serialize_multipart_formdata_item_end();
|
|
}
|
|
|
|
if (finish) { body += serialize_multipart_formdata_finish(boundary); }
|
|
|
|
return body;
|
|
}
|
|
|
|
// Sorts `ranges` by start position and merges ranges that overlap or are
// adjacent, after resolving open-ended (-1) positions against
// `content_length`.
//
// - A list with at most one range is left untouched.
// - Ranges that are invalid after resolution are dropped.
// - A range identical to the previously kept one is NOT merged, so exact
//   duplicates survive.
inline void coalesce_ranges(Ranges &ranges, size_t content_length) {
  if (ranges.size() <= 1) return;

  const auto total = static_cast<ssize_t>(content_length);

  // Sort ranges by start position
  std::sort(ranges.begin(), ranges.end(),
            [](const Range &a, const Range &b) { return a.first < b.first; });

  Ranges merged;
  merged.reserve(ranges.size());

  for (const auto &r : ranges) {
    auto lo = r.first;
    auto hi = r.second;

    // Resolve open-ended positions the same way range_error does.
    if (lo == -1 && hi == -1) {
      lo = 0;
      hi = total;
    }

    // Suffix range: "-N" means the last N bytes.
    if (lo == -1) {
      lo = total - hi;
      hi = total - 1;
    }

    if (hi == -1 || hi >= total) { hi = total - 1; }

    // Skip invalid ranges
    const bool in_bounds = 0 <= lo && lo <= hi && hi < total;
    if (!in_bounds) { continue; }

    if (!merged.empty()) {
      auto &prev = merged.back();
      const bool touches_prev = lo <= prev.second + 1;
      const bool same_as_prev = lo == prev.first && hi == prev.second;
      if (touches_prev && !same_as_prev) {
        // Extend the previous range
        prev.second = (std::max)(prev.second, hi);
        continue;
      }
    }

    // Add new range
    merged.emplace_back(lo, hi);
  }

  ranges = std::move(merged);
}
|
|
|
|
// Validates the byte ranges of `req` against the content length of `res`.
// Returns true when the ranges cannot be satisfied, false otherwise.
//
// Only applies when the request has ranges and the response status is 2xx.
// As a side effect the ranges in `req` are rewritten in place with resolved
// positions (no -1 left) and finally coalesced via coalesce_ranges.
inline bool range_error(Request &req, Response &res) {
  const bool applicable =
      !req.ranges.empty() && 200 <= res.status && res.status < 300;
  if (!applicable) { return false; }

  ssize_t content_len = static_cast<ssize_t>(
      res.content_length_ ? res.content_length_ : res.body.size());

  std::vector<std::pair<ssize_t, ssize_t>> seen_ranges;
  size_t overlap_count = 0;

  // NOTE: The following Range check is based on '14.2. Range' in RFC 9110
  // 'HTTP Semantics' to avoid potential denial-of-service attacks.
  // https://www.rfc-editor.org/rfc/rfc9110#section-14.2

  // Too many ranges
  if (req.ranges.size() > CPPHTTPLIB_RANGE_MAX_COUNT) { return true; }

  for (auto &range : req.ranges) {
    // References: the request's range is updated in place.
    auto &first_pos = range.first;
    auto &last_pos = range.second;

    if (first_pos == -1 && last_pos == -1) {
      first_pos = 0;
      last_pos = content_len;
    }

    // Suffix range: "-N" means the last N bytes.
    if (first_pos == -1) {
      first_pos = content_len - last_pos;
      last_pos = content_len - 1;
    }

    // NOTE: RFC-9110 '14.1.2. Byte Ranges':
    // A client can limit the number of bytes requested without knowing the
    // size of the selected representation. If the last-pos value is absent,
    // or if the value is greater than or equal to the current length of the
    // representation data, the byte range is interpreted as the remainder of
    // the representation (i.e., the server replaces the value of last-pos
    // with a value that is one less than the current length of the selected
    // representation).
    // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6
    if (last_pos == -1 || last_pos >= content_len) {
      last_pos = content_len - 1;
    }

    // Range must be within content length
    const bool within_content = 0 <= first_pos && first_pos <= last_pos &&
                                last_pos <= content_len - 1;
    if (!within_content) { return true; }

    // Request must not have more than two overlapping ranges
    for (const auto &earlier : seen_ranges) {
      const bool overlaps =
          last_pos >= earlier.first && first_pos <= earlier.second;
      if (overlaps) {
        overlap_count++;
        if (overlap_count > 2) { return true; }
        break; // Only count once per range
      }
    }

    seen_ranges.emplace_back(first_pos, last_pos);
  }

  // After validation, coalesce overlapping ranges as per RFC 9110
  coalesce_ranges(req.ranges, static_cast<size_t>(content_len));

  return false;
}
|
|
|
|
inline std::pair<size_t, size_t>
|
|
get_range_offset_and_length(Range r, size_t content_length) {
|
|
assert(r.first != -1 && r.second != -1);
|
|
assert(0 <= r.first && r.first < static_cast<ssize_t>(content_length));
|
|
assert(r.first <= r.second &&
|
|
r.second < static_cast<ssize_t>(content_length));
|
|
(void)(content_length);
|
|
return std::make_pair(r.first, static_cast<size_t>(r.second - r.first) + 1);
|
|
}
|
|
|
|
// Formats a Content-Range value "bytes <first>-<last>/<content_length>" from
// an (offset, length) pair; <last> is offset + length - 1.
inline std::string make_content_range_header_field(
    const std::pair<size_t, size_t> &offset_and_length, size_t content_length) {
  const auto first_byte = offset_and_length.first;
  const auto last_byte = first_byte + offset_and_length.second - 1;

  return "bytes " + std::to_string(first_byte) + "-" +
         std::to_string(last_byte) + "/" + std::to_string(content_length);
}
|
|
|
|
template <typename SToken, typename CToken, typename Content>
|
|
bool process_multipart_ranges_data(const Request &req,
|
|
const std::string &boundary,
|
|
const std::string &content_type,
|
|
size_t content_length, SToken stoken,
|
|
CToken ctoken, Content content) {
|
|
for (size_t i = 0; i < req.ranges.size(); i++) {
|
|
ctoken("--");
|
|
stoken(boundary);
|
|
ctoken("\r\n");
|
|
if (!content_type.empty()) {
|
|
ctoken("Content-Type: ");
|
|
stoken(content_type);
|
|
ctoken("\r\n");
|
|
}
|
|
|
|
auto offset_and_length =
|
|
get_range_offset_and_length(req.ranges[i], content_length);
|
|
|
|
ctoken("Content-Range: ");
|
|
stoken(make_content_range_header_field(offset_and_length, content_length));
|
|
ctoken("\r\n");
|
|
ctoken("\r\n");
|
|
|
|
if (!content(offset_and_length.first, offset_and_length.second)) {
|
|
return false;
|
|
}
|
|
ctoken("\r\n");
|
|
}
|
|
|
|
ctoken("--");
|
|
stoken(boundary);
|
|
ctoken("--");
|
|
|
|
return true;
|
|
}
|
|
|
|
// Appends to `data` the complete multipart/byteranges body for `req.ranges`,
// taking the bytes of each range from `res.body`.
inline void make_multipart_ranges_data(const Request &req, Response &res,
                                       const std::string &boundary,
                                       const std::string &content_type,
                                       size_t content_length,
                                       std::string &data) {
  // Protocol text and variable strings are handled the same way here.
  auto append_token = [&](const std::string &token) { data += token; };

  auto append_body_slice = [&](size_t offset, size_t length) {
    assert(offset + length <= content_length);
    data.append(res.body, offset, length);
    return true;
  };

  process_multipart_ranges_data(req, boundary, content_type, content_length,
                                append_token, append_token, append_body_slice);
}
|
|
|
|
// Computes the size in bytes of the multipart/byteranges body that would be
// produced for `req.ranges`, without building it.
inline size_t get_multipart_ranges_data_length(const Request &req,
                                               const std::string &boundary,
                                               const std::string &content_type,
                                               size_t content_length) {
  size_t total = 0;

  auto count_token = [&](const std::string &token) { total += token.size(); };

  auto count_content = [&](size_t /*offset*/, size_t length) {
    total += length;
    return true;
  };

  process_multipart_ranges_data(req, boundary, content_type, content_length,
                                count_token, count_token, count_content);

  return total;
}
|
|
|
|
template <typename T>
|
|
inline bool
|
|
write_multipart_ranges_data(Stream &strm, const Request &req, Response &res,
|
|
const std::string &boundary,
|
|
const std::string &content_type,
|
|
size_t content_length, const T &is_shutting_down) {
|
|
return process_multipart_ranges_data(
|
|
req, boundary, content_type, content_length,
|
|
[&](const std::string &token) { strm.write(token); },
|
|
[&](const std::string &token) { strm.write(token); },
|
|
[&](size_t offset, size_t length) {
|
|
return write_content(strm, res.content_provider_, offset, length,
|
|
is_shutting_down);
|
|
});
|
|
}
|
|
|
|
inline bool expect_content(const Request &req) {
|
|
if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" ||
|
|
req.method == "DELETE") {
|
|
return true;
|
|
}
|
|
if (req.has_header("Content-Length") &&
|
|
req.get_header_value_u64("Content-Length") > 0) {
|
|
return true;
|
|
}
|
|
if (is_chunked_transfer_encoding(req.headers)) { return true; }
|
|
return false;
|
|
}
|
|
|
|
// Returns true when `s` contains a carriage return or a line feed anywhere.
//
// The whole std::string is searched, including bytes after an embedded NUL,
// so a CR/LF hidden behind '\0' is still detected.
inline bool has_crlf(const std::string &s) {
  return s.find_first_of("\r\n") != std::string::npos;
}
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
inline std::string message_digest(const std::string &s, const EVP_MD *algo) {
|
|
auto context = std::unique_ptr<EVP_MD_CTX, decltype(&EVP_MD_CTX_free)>(
|
|
EVP_MD_CTX_new(), EVP_MD_CTX_free);
|
|
|
|
unsigned int hash_length = 0;
|
|
unsigned char hash[EVP_MAX_MD_SIZE];
|
|
|
|
EVP_DigestInit_ex(context.get(), algo, nullptr);
|
|
EVP_DigestUpdate(context.get(), s.c_str(), s.size());
|
|
EVP_DigestFinal_ex(context.get(), hash, &hash_length);
|
|
|
|
std::stringstream ss;
|
|
for (auto i = 0u; i < hash_length; ++i) {
|
|
ss << std::hex << std::setw(2) << std::setfill('0')
|
|
<< static_cast<unsigned int>(hash[i]);
|
|
}
|
|
|
|
return ss.str();
|
|
}
|
|
|
|
inline std::string MD5(const std::string &s) {
|
|
return message_digest(s, EVP_md5());
|
|
}
|
|
|
|
inline std::string SHA_256(const std::string &s) {
|
|
return message_digest(s, EVP_sha256());
|
|
}
|
|
|
|
inline std::string SHA_512(const std::string &s) {
|
|
return message_digest(s, EVP_sha512());
|
|
}
|
|
|
|
inline std::pair<std::string, std::string> make_digest_authentication_header(
|
|
const Request &req, const std::map<std::string, std::string> &auth,
|
|
size_t cnonce_count, const std::string &cnonce, const std::string &username,
|
|
const std::string &password, bool is_proxy = false) {
|
|
std::string nc;
|
|
{
|
|
std::stringstream ss;
|
|
ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count;
|
|
nc = ss.str();
|
|
}
|
|
|
|
std::string qop;
|
|
if (auth.find("qop") != auth.end()) {
|
|
qop = auth.at("qop");
|
|
if (qop.find("auth-int") != std::string::npos) {
|
|
qop = "auth-int";
|
|
} else if (qop.find("auth") != std::string::npos) {
|
|
qop = "auth";
|
|
} else {
|
|
qop.clear();
|
|
}
|
|
}
|
|
|
|
std::string algo = "MD5";
|
|
if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); }
|
|
|
|
std::string response;
|
|
{
|
|
auto H = algo == "SHA-256" ? detail::SHA_256
|
|
: algo == "SHA-512" ? detail::SHA_512
|
|
: detail::MD5;
|
|
|
|
auto A1 = username + ":" + auth.at("realm") + ":" + password;
|
|
|
|
auto A2 = req.method + ":" + req.path;
|
|
if (qop == "auth-int") { A2 += ":" + H(req.body); }
|
|
|
|
if (qop.empty()) {
|
|
response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2));
|
|
} else {
|
|
response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce +
|
|
":" + qop + ":" + H(A2));
|
|
}
|
|
}
|
|
|
|
auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : "";
|
|
|
|
auto field = "Digest username=\"" + username + "\", realm=\"" +
|
|
auth.at("realm") + "\", nonce=\"" + auth.at("nonce") +
|
|
"\", uri=\"" + req.path + "\", algorithm=" + algo +
|
|
(qop.empty() ? ", response=\""
|
|
: ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" +
|
|
cnonce + "\", response=\"") +
|
|
response + "\"" +
|
|
(opaque.empty() ? "" : ", opaque=\"" + opaque + "\"");
|
|
|
|
auto key = is_proxy ? "Proxy-Authorization" : "Authorization";
|
|
return std::make_pair(key, field);
|
|
}
|
|
|
|
// Probes whether the TLS peer appears to have closed the connection.
// The socket is switched to non-blocking for the duration of the probe and
// switched back on exit. Returns true only when SSL_peek() returns 0 and
// SSL_get_error() reports SSL_ERROR_ZERO_RETURN.
inline bool is_ssl_peer_could_be_closed(SSL *ssl, socket_t sock) {
  detail::set_nonblocking(sock, true);
  auto restore_blocking =
      detail::scope_exit([&]() { detail::set_nonblocking(sock, false); });

  char probe[1];
  if (SSL_peek(ssl, probe, 1) != 0) { return false; }

  return SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN;
}
|
|
|
|
#ifdef _WIN64
|
|
// NOTE: This code is based on the following Stack Overflow post:
|
|
// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store
|
|
// Copies every certificate that can be decoded from the Windows "ROOT"
// system store into `store`. Returns true if at least one certificate was
// decoded, false if the system store could not be opened or nothing decoded.
// The return value of X509_STORE_add_cert() is not checked.
inline bool load_system_certs_on_windows(X509_STORE *store) {
  auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT");
  if (!hStore) { return false; }

  bool loaded_any = false;
  PCCERT_CONTEXT cert_ctx = NULL;

  for (;;) {
    // The previous context is handed back in to continue the enumeration.
    cert_ctx = CertEnumCertificatesInStore(hStore, cert_ctx);
    if (cert_ctx == nullptr) { break; }

    // d2i_X509 advances the pointer it is given, so work on a local copy.
    auto der = static_cast<const unsigned char *>(cert_ctx->pbCertEncoded);
    auto x509 = d2i_X509(NULL, &der, cert_ctx->cbCertEncoded);
    if (!x509) { continue; }

    X509_STORE_add_cert(store, x509);
    X509_free(x509);
    loaded_any = true;
  }

  CertFreeCertificateContext(cert_ctx);
  CertCloseStore(hStore, 0);

  return loaded_any;
}
|
|
#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && \
|
|
defined(TARGET_OS_OSX)
|
|
template <typename T>
|
|
using CFObjectPtr =
|
|
std::unique_ptr<typename std::remove_pointer<T>::type, void (*)(CFTypeRef)>;
|
|
|
|
// Deleter for CFObjectPtr: releases `obj` with CFRelease(), ignoring null.
inline void cf_object_ptr_deleter(CFTypeRef obj) {
  if (!obj) { return; }
  CFRelease(obj);
}
|
|
|
|
// Queries the keychain for all certificate items (kSecClassCertificate,
// kSecMatchLimitAll, returned as references) and stores the resulting array
// in `certs`.
//
// Returns false when the query dictionary cannot be created, when
// SecItemCopyMatching() fails, or when the result is missing or not a
// CFArray; `certs` is left untouched in those cases.
inline bool retrieve_certs_from_keychain(CFObjectPtr<CFArrayRef> &certs) {
  CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef};
  CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll,
                        kCFBooleanTrue};

  CFObjectPtr<CFDictionaryRef> query(
      CFDictionaryCreate(nullptr, reinterpret_cast<const void **>(keys), values,
                         sizeof(keys) / sizeof(keys[0]),
                         &kCFTypeDictionaryKeyCallBacks,
                         &kCFTypeDictionaryValueCallBacks),
      cf_object_ptr_deleter);

  if (!query) { return false; }

  CFTypeRef security_items = nullptr;
  if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess) {
    return false;
  }

  // Guard against a successful call that produced no object before asking
  // for its type.
  if (!security_items) { return false; }

  if (CFArrayGetTypeID() != CFGetTypeID(security_items)) {
    // The copied result is owned by this function; release it instead of
    // leaking it when it has an unexpected type.
    CFRelease(security_items);
    return false;
  }

  certs.reset(reinterpret_cast<CFArrayRef>(security_items));
  return true;
}
|
|
|
|
// Fetches the anchor certificates via SecTrustCopyAnchorCertificates() and
// hands the resulting array to `certs`. Returns false (leaving `certs`
// untouched) when the call does not report errSecSuccess.
inline bool retrieve_root_certs_from_keychain(CFObjectPtr<CFArrayRef> &certs) {
  CFArrayRef anchors = nullptr;
  const auto status = SecTrustCopyAnchorCertificates(&anchors);
  if (status != errSecSuccess) { return false; }

  certs.reset(anchors);
  return true;
}
|
|
|
|
// Exports each SecCertificate in `certs` as X.509, decodes it with OpenSSL
// and adds it to `store`. Entries that are not certificates, fail to export,
// or fail to decode are skipped. Returns true if at least one certificate
// was decoded.
inline bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) {
  bool added_any = false;

  const auto count = CFArrayGetCount(certs);
  for (CFIndex idx = 0; idx < count; ++idx) {
    const auto cert = reinterpret_cast<const __SecCertificate *>(
        CFArrayGetValueAtIndex(certs, idx));

    if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { continue; }

    CFDataRef exported = nullptr;
    const auto status =
        SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &exported);
    if (status != errSecSuccess) { continue; }

    // Takes ownership of the exported data so it is released every iteration.
    CFObjectPtr<CFDataRef> exported_guard(exported, cf_object_ptr_deleter);

    // d2i_X509 advances the pointer it is given, so work on a local copy.
    auto der = static_cast<const unsigned char *>(
        CFDataGetBytePtr(exported_guard.get()));
    auto x509 = d2i_X509(NULL, &der, CFDataGetLength(exported_guard.get()));
    if (!x509) { continue; }

    X509_STORE_add_cert(store, x509);
    X509_free(x509);
    added_any = true;
  }

  return added_any;
}
|
|
|
|
// Loads certificates into `store` from two sources: the keychain query and
// the anchor certificates. Both sources are always attempted. Returns true
// if either of them contributed at least one certificate.
inline bool load_system_certs_on_macos(X509_STORE *store) {
  CFObjectPtr<CFArrayRef> certs(nullptr, cf_object_ptr_deleter);

  bool from_keychain = false;
  if (retrieve_certs_from_keychain(certs) && certs) {
    from_keychain = add_certs_to_x509_store(certs.get(), store);
  }

  bool from_anchors = false;
  if (retrieve_root_certs_from_keychain(certs) && certs) {
    from_anchors = add_certs_to_x509_store(certs.get(), store);
  }

  return from_anchors || from_keychain;
}
|
|
#endif // _WIN64
|
|
#endif // CPPHTTPLIB_OPENSSL_SUPPORT
|
|
|
|
#ifdef _WIN64
|
|
class WSInit {
|
|
public:
|
|
WSInit() {
|
|
WSADATA wsaData;
|
|
if (WSAStartup(0x0002, &wsaData) == 0) is_valid_ = true;
|
|
}
|
|
|
|
~WSInit() {
|
|
if (is_valid_) WSACleanup();
|
|
}
|
|
|
|
bool is_valid_ = false;
|
|
};
|
|
|
|
static WSInit wsinit_;
|
|
#endif
|
|
|
|
inline bool parse_www_authenticate(const Response &res,
|
|
std::map<std::string, std::string> &auth,
|
|
bool is_proxy) {
|
|
auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate";
|
|
if (res.has_header(auth_key)) {
|
|
thread_local auto re =
|
|
std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~");
|
|
auto s = res.get_header_value(auth_key);
|
|
auto pos = s.find(' ');
|
|
if (pos != std::string::npos) {
|
|
auto type = s.substr(0, pos);
|
|
if (type == "Basic") {
|
|
return false;
|
|
} else if (type == "Digest") {
|
|
s = s.substr(pos + 1);
|
|
auto beg = std::sregex_iterator(s.begin(), s.end(), re);
|
|
for (auto i = beg; i != std::sregex_iterator(); ++i) {
|
|
const auto &m = *i;
|
|
auto key = s.substr(static_cast<size_t>(m.position(1)),
|
|
static_cast<size_t>(m.length(1)));
|
|
auto val = m.length(2) > 0
|
|
? s.substr(static_cast<size_t>(m.position(2)),
|
|
static_cast<size_t>(m.length(2)))
|
|
: s.substr(static_cast<size_t>(m.position(3)),
|
|
static_cast<size_t>(m.length(3)));
|
|
auth[key] = val;
|
|
}
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
class ContentProviderAdapter {
|
|
public:
|
|
explicit ContentProviderAdapter(
|
|
ContentProviderWithoutLength &&content_provider)
|
|
: content_provider_(content_provider) {}
|
|
|
|
bool operator()(size_t offset, size_t, DataSink &sink) {
|
|
return content_provider_(offset, sink);
|
|
}
|
|
|
|
private:
|
|
ContentProviderWithoutLength content_provider_;
|
|
};
|
|
|
|
} // namespace detail
|
|
|
|
inline std::string hosted_at(const std::string &hostname) {
|
|
std::vector<std::string> addrs;
|
|
hosted_at(hostname, addrs);
|
|
if (addrs.empty()) { return std::string(); }
|
|
return addrs[0];
|
|
}
|
|
|
|
// Resolves `hostname` (any address family, stream sockets) and appends the
// textual IP of every result that get_ip_and_port() can convert to `addrs`.
// On resolver failure nothing is appended; on Linux (non-Android)
// res_init() is called in that case.
inline void hosted_at(const std::string &hostname,
                      std::vector<std::string> &addrs) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(struct addrinfo));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_protocol = 0;

  struct addrinfo *result;
  const auto rc = detail::getaddrinfo_with_timeout(hostname.c_str(), nullptr,
                                                   &hints, &result, 0);
  if (rc) {
#if defined __linux__ && !defined __ANDROID__
    res_init();
#endif
    return;
  }

  // Release the resolver's list on every exit path below.
  auto free_result = detail::scope_exit([&] { freeaddrinfo(result); });

  for (auto node = result; node; node = node->ai_next) {
    const auto &addr =
        *reinterpret_cast<struct sockaddr_storage *>(node->ai_addr);

    std::string ip;
    auto ignored_port = -1;
    if (!detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip,
                                 ignored_port)) {
      continue;
    }
    addrs.push_back(ip);
  }
}
|
|
|
|
// Percent-encodes `value` for use as a URI component. ASCII alphanumerics
// and the characters - _ . ! ~ * ' ( ) are copied unchanged; every other
// byte becomes "%XX" with two uppercase hex digits.
inline std::string encode_uri_component(const std::string &value) {
  static const char hex_upper[] = "0123456789ABCDEF";

  auto is_unescaped = [](char c) -> bool {
    return std::isalnum(static_cast<uint8_t>(c)) || c == '-' || c == '_' ||
           c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' ||
           c == '(' || c == ')';
  };

  std::string encoded;
  for (auto c : value) {
    if (is_unescaped(c)) {
      encoded += c;
      continue;
    }
    const auto byte = static_cast<unsigned char>(c);
    encoded += '%';
    encoded += hex_upper[byte >> 4];
    encoded += hex_upper[byte & 0x0F];
  }

  return encoded;
}
|
|
|
|
// Percent-encodes `value` as a whole URI. In addition to the characters kept
// by encode_uri_component (alphanumerics and - _ . ! ~ * ' ( )), the URI
// delimiters ; / ? : @ & = + $ , # are also copied unchanged. Every other
// byte becomes "%XX" with two uppercase hex digits.
inline std::string encode_uri(const std::string &value) {
  static const char hex_upper[] = "0123456789ABCDEF";

  auto is_unescaped = [](char c) -> bool {
    return std::isalnum(static_cast<uint8_t>(c)) || c == '-' || c == '_' ||
           c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' ||
           c == '(' || c == ')' || c == ';' || c == '/' || c == '?' ||
           c == ':' || c == '@' || c == '&' || c == '=' || c == '+' ||
           c == '$' || c == ',' || c == '#';
  };

  std::string encoded;
  for (auto c : value) {
    if (is_unescaped(c)) {
      encoded += c;
      continue;
    }
    const auto byte = static_cast<unsigned char>(c);
    encoded += '%';
    encoded += hex_upper[byte >> 4];
    encoded += hex_upper[byte & 0x0F];
  }

  return encoded;
}
|
|
|
|
inline std::string decode_uri_component(const std::string &value) {
|
|
std::string result;
|
|
|
|
for (size_t i = 0; i < value.size(); i++) {
|
|
if (value[i] == '%' && i + 2 < value.size()) {
|
|
auto val = 0;
|
|
if (detail::from_hex_to_i(value, i + 1, 2, val)) {
|
|
result += static_cast<char>(val);
|
|
i += 2;
|
|
} else {
|
|
result += value[i];
|
|
}
|
|
} else {
|
|
result += value[i];
|
|
}
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
inline std::string decode_uri(const std::string &value) {
|
|
std::string result;
|
|
|
|
for (size_t i = 0; i < value.size(); i++) {
|
|
if (value[i] == '%' && i + 2 < value.size()) {
|
|
auto val = 0;
|
|
if (detail::from_hex_to_i(value, i + 1, 2, val)) {
|
|
result += static_cast<char>(val);
|
|
i += 2;
|
|
} else {
|
|
result += value[i];
|
|
}
|
|
} else {
|
|
result += value[i];
|
|
}
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
[[deprecated("Use encode_uri_component instead")]]
|
|
inline std::string encode_query_param(const std::string &value) {
|
|
return encode_uri_component(value);
|
|
}
|
|
|
|
inline std::string append_query_params(const std::string &path,
|
|
const Params ¶ms) {
|
|
std::string path_with_query = path;
|
|
thread_local const std::regex re("[^?]+\\?.*");
|
|
auto delm = std::regex_match(path, re) ? '&' : '?';
|
|
path_with_query += delm + detail::params_to_query_str(params);
|
|
return path_with_query;
|
|
}
|
|
|
|
// Header utilities
|
|
inline std::pair<std::string, std::string>
|
|
make_range_header(const Ranges &ranges) {
|
|
std::string field = "bytes=";
|
|
auto i = 0;
|
|
for (const auto &r : ranges) {
|
|
if (i != 0) { field += ", "; }
|
|
if (r.first != -1) { field += std::to_string(r.first); }
|
|
field += '-';
|
|
if (r.second != -1) { field += std::to_string(r.second); }
|
|
i++;
|
|
}
|
|
return std::make_pair("Range", std::move(field));
|
|
}
|
|
|
|
inline std::pair<std::string, std::string>
|
|
make_basic_authentication_header(const std::string &username,
|
|
const std::string &password, bool is_proxy) {
|
|
auto field = "Basic " + detail::base64_encode(username + ":" + password);
|
|
auto key = is_proxy ? "Proxy-Authorization" : "Authorization";
|
|
return std::make_pair(key, std::move(field));
|
|
}
|
|
|
|
// Builds a Bearer authentication header: "Bearer " followed by `token`.
// The header name is "Proxy-Authorization" when `is_proxy` is set,
// "Authorization" otherwise.
inline std::pair<std::string, std::string>
make_bearer_token_authentication_header(const std::string &token,
                                        bool is_proxy = false) {
  const char *header_name = "Authorization";
  if (is_proxy) { header_name = "Proxy-Authorization"; }

  std::string field = "Bearer ";
  field += token;

  return std::make_pair(header_name, std::move(field));
}
|
|
|
|
// Request implementation
|
|
inline bool Request::has_header(const std::string &key) const {
|
|
return detail::has_header(headers, key);
|
|
}
|
|
|
|
inline std::string Request::get_header_value(const std::string &key,
|
|
const char *def, size_t id) const {
|
|
return detail::get_header_value(headers, key, def, id);
|
|
}
|
|
|
|
inline size_t Request::get_header_value_count(const std::string &key) const {
|
|
auto r = headers.equal_range(key);
|
|
return static_cast<size_t>(std::distance(r.first, r.second));
|
|
}
|
|
|
|
inline void Request::set_header(const std::string &key,
|
|
const std::string &val) {
|
|
if (detail::fields::is_field_name(key) &&
|
|
detail::fields::is_field_value(val)) {
|
|
headers.emplace(key, val);
|
|
}
|
|
}
|
|
|
|
inline bool Request::has_trailer(const std::string &key) const {
|
|
return trailers.find(key) != trailers.end();
|
|
}
|
|
|
|
inline std::string Request::get_trailer_value(const std::string &key,
|
|
size_t id) const {
|
|
auto rng = trailers.equal_range(key);
|
|
auto it = rng.first;
|
|
std::advance(it, static_cast<ssize_t>(id));
|
|
if (it != rng.second) { return it->second; }
|
|
return std::string();
|
|
}
|
|
|
|
inline size_t Request::get_trailer_value_count(const std::string &key) const {
|
|
auto r = trailers.equal_range(key);
|
|
return static_cast<size_t>(std::distance(r.first, r.second));
|
|
}
|
|
|
|
inline bool Request::has_param(const std::string &key) const {
|
|
return params.find(key) != params.end();
|
|
}
|
|
|
|
inline std::string Request::get_param_value(const std::string &key,
|
|
size_t id) const {
|
|
auto rng = params.equal_range(key);
|
|
auto it = rng.first;
|
|
std::advance(it, static_cast<ssize_t>(id));
|
|
if (it != rng.second) { return it->second; }
|
|
return std::string();
|
|
}
|
|
|
|
inline size_t Request::get_param_value_count(const std::string &key) const {
|
|
auto r = params.equal_range(key);
|
|
return static_cast<size_t>(std::distance(r.first, r.second));
|
|
}
|
|
|
|
inline bool Request::is_multipart_form_data() const {
|
|
const auto &content_type = get_header_value("Content-Type");
|
|
return !content_type.rfind("multipart/form-data", 0);
|
|
}
|
|
|
|
// Multipart FormData implementation
|
|
inline std::string MultipartFormData::get_field(const std::string &key,
|
|
size_t id) const {
|
|
auto rng = fields.equal_range(key);
|
|
auto it = rng.first;
|
|
std::advance(it, static_cast<ssize_t>(id));
|
|
if (it != rng.second) { return it->second.content; }
|
|
return std::string();
|
|
}
|
|
|
|
inline std::vector<std::string>
|
|
MultipartFormData::get_fields(const std::string &key) const {
|
|
std::vector<std::string> values;
|
|
auto rng = fields.equal_range(key);
|
|
for (auto it = rng.first; it != rng.second; it++) {
|
|
values.push_back(it->second.content);
|
|
}
|
|
return values;
|
|
}
|
|
|
|
inline bool MultipartFormData::has_field(const std::string &key) const {
|
|
return fields.find(key) != fields.end();
|
|
}
|
|
|
|
inline size_t MultipartFormData::get_field_count(const std::string &key) const {
|
|
auto r = fields.equal_range(key);
|
|
return static_cast<size_t>(std::distance(r.first, r.second));
|
|
}
|
|
|
|
// Returns the `id`-th file stored under `key`, or a default-constructed
// FormData when fewer than `id + 1` files exist.
//
// The index is checked against the size of the key's range before the
// iterator is moved. Advancing by an unchecked `id` could step past the end
// of the range and return another key's file, or run off the container.
inline FormData MultipartFormData::get_file(const std::string &key,
                                            size_t id) const {
  auto rng = files.equal_range(key);
  const auto count =
      static_cast<size_t>(std::distance(rng.first, rng.second));
  if (id >= count) { return FormData(); }

  auto it = rng.first;
  std::advance(it, static_cast<ssize_t>(id));
  return it->second;
}
|
|
|
|
inline std::vector<FormData>
|
|
MultipartFormData::get_files(const std::string &key) const {
|
|
std::vector<FormData> values;
|
|
auto rng = files.equal_range(key);
|
|
for (auto it = rng.first; it != rng.second; it++) {
|
|
values.push_back(it->second);
|
|
}
|
|
return values;
|
|
}
|
|
|
|
inline bool MultipartFormData::has_file(const std::string &key) const {
|
|
return files.find(key) != files.end();
|
|
}
|
|
|
|
inline size_t MultipartFormData::get_file_count(const std::string &key) const {
|
|
auto r = files.equal_range(key);
|
|
return static_cast<size_t>(std::distance(r.first, r.second));
|
|
}
|
|
|
|
// Response implementation
|
|
inline bool Response::has_header(const std::string &key) const {
|
|
return headers.find(key) != headers.end();
|
|
}
|
|
|
|
inline std::string Response::get_header_value(const std::string &key,
|
|
const char *def,
|
|
size_t id) const {
|
|
return detail::get_header_value(headers, key, def, id);
|
|
}
|
|
|
|
inline size_t Response::get_header_value_count(const std::string &key) const {
|
|
auto r = headers.equal_range(key);
|
|
return static_cast<size_t>(std::distance(r.first, r.second));
|
|
}
|
|
|
|
inline void Response::set_header(const std::string &key,
|
|
const std::string &val) {
|
|
if (detail::fields::is_field_name(key) &&
|
|
detail::fields::is_field_value(val)) {
|
|
headers.emplace(key, val);
|
|
}
|
|
}
|
|
inline bool Response::has_trailer(const std::string &key) const {
|
|
return trailers.find(key) != trailers.end();
|
|
}
|
|
|
|
inline std::string Response::get_trailer_value(const std::string &key,
|
|
size_t id) const {
|
|
auto rng = trailers.equal_range(key);
|
|
auto it = rng.first;
|
|
std::advance(it, static_cast<ssize_t>(id));
|
|
if (it != rng.second) { return it->second; }
|
|
return std::string();
|
|
}
|
|
|
|
inline size_t Response::get_trailer_value_count(const std::string &key) const {
|
|
auto r = trailers.equal_range(key);
|
|
return static_cast<size_t>(std::distance(r.first, r.second));
|
|
}
|
|
|
|
inline void Response::set_redirect(const std::string &url, int stat) {
|
|
if (detail::fields::is_field_value(url)) {
|
|
set_header("Location", url);
|
|
if (300 <= stat && stat < 400) {
|
|
this->status = stat;
|
|
} else {
|
|
this->status = StatusCode::Found_302;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sets the body to the `n` bytes at `s` and replaces every existing
// Content-Type header with `content_type`.
inline void Response::set_content(const char *s, size_t n,
                                  const std::string &content_type) {
  body.assign(s, n);

  // Drop all current Content-Type entries before adding the new one.
  const auto old_types = headers.equal_range("Content-Type");
  if (old_types.first != old_types.second) {
    headers.erase(old_types.first, old_types.second);
  }
  set_header("Content-Type", content_type);
}
|
|
|
|
inline void Response::set_content(const std::string &s,
|
|
const std::string &content_type) {
|
|
set_content(s.data(), s.size(), content_type);
|
|
}
|
|
|
|
// Moves `s` into the body and replaces every existing Content-Type header
// with `content_type`.
inline void Response::set_content(std::string &&s,
                                  const std::string &content_type) {
  body = std::move(s);

  // Drop all current Content-Type entries before adding the new one.
  const auto old_types = headers.equal_range("Content-Type");
  if (old_types.first != old_types.second) {
    headers.erase(old_types.first, old_types.second);
  }
  set_header("Content-Type", content_type);
}
|
|
|
|
inline void Response::set_content_provider(
|
|
size_t in_length, const std::string &content_type, ContentProvider provider,
|
|
ContentProviderResourceReleaser resource_releaser) {
|
|
set_header("Content-Type", content_type);
|
|
content_length_ = in_length;
|
|
if (in_length > 0) { content_provider_ = std::move(provider); }
|
|
content_provider_resource_releaser_ = std::move(resource_releaser);
|
|
is_chunked_content_provider_ = false;
|
|
}
|
|
|
|
inline void Response::set_content_provider(
|
|
const std::string &content_type, ContentProviderWithoutLength provider,
|
|
ContentProviderResourceReleaser resource_releaser) {
|
|
set_header("Content-Type", content_type);
|
|
content_length_ = 0;
|
|
content_provider_ = detail::ContentProviderAdapter(std::move(provider));
|
|
content_provider_resource_releaser_ = std::move(resource_releaser);
|
|
is_chunked_content_provider_ = false;
|
|
}
|
|
|
|
inline void Response::set_chunked_content_provider(
|
|
const std::string &content_type, ContentProviderWithoutLength provider,
|
|
ContentProviderResourceReleaser resource_releaser) {
|
|
set_header("Content-Type", content_type);
|
|
content_length_ = 0;
|
|
content_provider_ = detail::ContentProviderAdapter(std::move(provider));
|
|
content_provider_resource_releaser_ = std::move(resource_releaser);
|
|
is_chunked_content_provider_ = true;
|
|
}
|
|
|
|
// Records `path` and `content_type` as the file to serve. No file access
// happens here; the values are only stored.
inline void Response::set_file_content(const std::string &path,
                                       const std::string &content_type) {
  file_content_content_type_ = content_type;
  file_content_path_ = path;
}
|
|
|
|
inline void Response::set_file_content(const std::string &path) {
|
|
file_content_path_ = path;
|
|
}
|
|
|
|
// Result implementation
|
|
inline bool Result::has_request_header(const std::string &key) const {
|
|
return request_headers_.find(key) != request_headers_.end();
|
|
}
|
|
|
|
inline std::string Result::get_request_header_value(const std::string &key,
|
|
const char *def,
|
|
size_t id) const {
|
|
return detail::get_header_value(request_headers_, key, def, id);
|
|
}
|
|
|
|
inline size_t
|
|
Result::get_request_header_value_count(const std::string &key) const {
|
|
auto r = request_headers_.equal_range(key);
|
|
return static_cast<size_t>(std::distance(r.first, r.second));
|
|
}
|
|
|
|
// Stream implementation
|
|
// Writes the NUL-terminated string `ptr` (terminator excluded) by
// forwarding to write(ptr, size).
inline ssize_t Stream::write(const char *ptr) {
  const size_t length = strlen(ptr);
  return write(ptr, length);
}
|
|
|
|
inline ssize_t Stream::write(const std::string &s) {
|
|
return write(s.data(), s.size());
|
|
}
|
|
|
|
namespace detail {
|
|
|
|
// Computes the timeout to use for the next wait, given an overall budget.
//
//   max_timeout_msec  - total budget in milliseconds
//   duration_msec     - milliseconds already spent
//   timeout_sec/usec  - configured per-operation timeout
//   actual_timeout_*  - outputs: the smaller of the remaining budget and the
//                       configured timeout, clamped at zero, split into
//                       seconds and microseconds
//
// The result has millisecond resolution: sub-millisecond parts of
// `timeout_usec` are discarded.
inline void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec,
                                time_t timeout_sec, time_t timeout_usec,
                                time_t &actual_timeout_sec,
                                time_t &actual_timeout_usec) {
  const auto configured_msec = (timeout_sec * 1000) + (timeout_usec / 1000);
  const auto remaining_msec = max_timeout_msec - duration_msec;

  auto effective_msec =
      remaining_msec < configured_msec ? remaining_msec : configured_msec;
  if (effective_msec < 0) { effective_msec = 0; }

  actual_timeout_sec = effective_msec / 1000;
  actual_timeout_usec = (effective_msec % 1000) * 1000;
}
|
|
|
|
// Socket stream implementation
|
|
// Stores the socket, the read/write timeouts, the overall time budget and
// the start time, and allocates a zero-filled read buffer of
// read_buff_size_ bytes. No I/O is performed.
inline SocketStream::SocketStream(
    socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec,
    time_t write_timeout_sec, time_t write_timeout_usec,
    time_t max_timeout_msec,
    std::chrono::time_point<std::chrono::steady_clock> start_time)
    : sock_(sock),
      read_timeout_sec_(read_timeout_sec),
      read_timeout_usec_(read_timeout_usec),
      write_timeout_sec_(write_timeout_sec),
      write_timeout_usec_(write_timeout_usec),
      max_timeout_msec_(max_timeout_msec),
      start_time_(start_time),
      read_buff_(read_buff_size_, 0) {}
|
|
|
|
// Defaulted destructor: the body does nothing itself (in particular it makes
// no call to close sock_); members are destroyed as usual.
inline SocketStream::~SocketStream() = default;
|
|
|
|
inline bool SocketStream::is_readable() const {
|
|
return read_buff_off_ < read_buff_content_size_;
|
|
}
|
|
|
|
inline bool SocketStream::wait_readable() const {
|
|
if (max_timeout_msec_ <= 0) {
|
|
return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
|
|
}
|
|
|
|
time_t read_timeout_sec;
|
|
time_t read_timeout_usec;
|
|
calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_,
|
|
read_timeout_usec_, read_timeout_sec, read_timeout_usec);
|
|
|
|
return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0;
|
|
}
|
|
|
|
inline bool SocketStream::wait_writable() const {
|
|
return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
|
|
is_socket_alive(sock_);
|
|
}
|
|
|
|
inline ssize_t SocketStream::read(char *ptr, size_t size) {
|
|
#ifdef _WIN64
|
|
size =
|
|
(std::min)(size, static_cast<size_t>((std::numeric_limits<int>::max)()));
|
|
#else
|
|
size = (std::min)(size,
|
|
static_cast<size_t>((std::numeric_limits<ssize_t>::max)()));
|
|
#endif
|
|
|
|
if (read_buff_off_ < read_buff_content_size_) {
|
|
auto remaining_size = read_buff_content_size_ - read_buff_off_;
|
|
if (size <= remaining_size) {
|
|
memcpy(ptr, read_buff_.data() + read_buff_off_, size);
|
|
read_buff_off_ += size;
|
|
return static_cast<ssize_t>(size);
|
|
} else {
|
|
memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size);
|
|
read_buff_off_ += remaining_size;
|
|
return static_cast<ssize_t>(remaining_size);
|
|
}
|
|
}
|
|
|
|
if (!wait_readable()) { return -1; }
|
|
|
|
read_buff_off_ = 0;
|
|
read_buff_content_size_ = 0;
|
|
|
|
if (size < read_buff_size_) {
|
|
auto n = read_socket(sock_, read_buff_.data(), read_buff_size_,
|
|
CPPHTTPLIB_RECV_FLAGS);
|
|
if (n <= 0) {
|
|
return n;
|
|
} else if (n <= static_cast<ssize_t>(size)) {
|
|
memcpy(ptr, read_buff_.data(), static_cast<size_t>(n));
|
|
return n;
|
|
} else {
|
|
memcpy(ptr, read_buff_.data(), size);
|
|
read_buff_off_ = size;
|
|
read_buff_content_size_ = static_cast<size_t>(n);
|
|
return static_cast<ssize_t>(size);
|
|
}
|
|
} else {
|
|
return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS);
|
|
}
|
|
}
|
|
|
|
inline ssize_t SocketStream::write(const char *ptr, size_t size) {
|
|
if (!wait_writable()) { return -1; }
|
|
|
|
#if defined(_WIN64) && !defined(_WIN64)
|
|
size =
|
|
(std::min)(size, static_cast<size_t>((std::numeric_limits<int>::max)()));
|
|
#endif
|
|
|
|
return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS);
|
|
}
|
|
|
|
inline void SocketStream::get_remote_ip_and_port(std::string &ip,
|
|
int &port) const {
|
|
return detail::get_remote_ip_and_port(sock_, ip, port);
|
|
}
|
|
|
|
inline void SocketStream::get_local_ip_and_port(std::string &ip,
|
|
int &port) const {
|
|
return detail::get_local_ip_and_port(sock_, ip, port);
|
|
}
|
|
|
|
// Returns the underlying socket handle.
inline socket_t SocketStream::socket() const {
  return sock_;
}
|
|
|
|
inline time_t SocketStream::duration() const {
|
|
return std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
std::chrono::steady_clock::now() - start_time_)
|
|
.count();
|
|
}
|
|
|
|
// Buffer stream implementation
|
|
// An in-memory stream always reports itself as readable.
inline bool BufferStream::is_readable() const {
  return true;
}
|
|
|
|
// Never waits: an in-memory stream is reported as readable immediately.
inline bool BufferStream::wait_readable() const {
  return true;
}
|
|
|
|
// Never waits: an in-memory stream is reported as writable immediately.
inline bool BufferStream::wait_writable() const {
  return true;
}
|
|
|
|
inline ssize_t BufferStream::read(char *ptr, size_t size) {
|
|
#if defined(_MSC_VER) && _MSC_VER < 1910
|
|
auto len_read = buffer._Copy_s(ptr, size, size, position);
|
|
#else
|
|
auto len_read = buffer.copy(ptr, size, position);
|
|
#endif
|
|
position += static_cast<size_t>(len_read);
|
|
return static_cast<ssize_t>(len_read);
|
|
}
|
|
|
|
// Appends the `size` bytes at `ptr` to the buffer and reports all of them
// as written.
inline ssize_t BufferStream::write(const char *ptr, size_t size) {
  buffer.append(ptr, size);
  const auto written = static_cast<ssize_t>(size);
  return written;
}
|
|
|
|
inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/,
|
|
int & /*port*/) const {}
|
|
|
|
inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/,
|
|
int & /*port*/) const {}
|
|
|
|
// Always returns 0; this stream does not store a socket of its own.
inline socket_t BufferStream::socket() const {
  return 0;
}
|
|
|
|
// Always returns 0; no elapsed time is tracked for this stream.
inline time_t BufferStream::duration() const {
  return 0;
}
|
|
|
|
// Read-only access to the stream's internal buffer string.
inline const std::string &BufferStream::get_buffer() const {
  return buffer;
}
|
|
|
|
// Splits a route pattern such as "/users/:id/subscriptions" into static
// fragments ("/users/", "subscriptions") and parameter names ("id").
//
// Each "/:" marker starts a parameter whose name runs up to the next
// `separator` or the end of the pattern. The static text before a marker,
// including its '/', is stored as a fragment; static text after the last
// parameter becomes a final fragment.
//
// Throws std::invalid_argument when the same parameter name appears more
// than once, unless CPPHTTPLIB_NO_EXCEPTIONS is defined.
inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern)
    : MatcherBase(pattern) {
  constexpr const char marker[] = "/:";

  // One past the last ending position of a path param substring
  std::size_t last_param_end = 0;

#ifndef CPPHTTPLIB_NO_EXCEPTIONS
  // Needed to ensure that parameter names are unique during matcher
  // construction
  // If exceptions are disabled, only last duplicate path
  // parameter will be set
  std::unordered_set<std::string> param_name_set;
#endif

  while (true) {
    // Search from the separator that ended the previous parameter, because
    // that same '/' may start the next "/:" marker.
    const auto marker_pos = pattern.find(
        marker, last_param_end == 0 ? last_param_end : last_param_end - 1);
    if (marker_pos == std::string::npos) { break; }

    static_fragments_.push_back(
        pattern.substr(last_param_end, marker_pos - last_param_end + 1));

    const auto param_name_start = marker_pos + str_len(marker);

    auto sep_pos = pattern.find(separator, param_name_start);
    if (sep_pos == std::string::npos) { sep_pos = pattern.length(); }

    auto param_name =
        pattern.substr(param_name_start, sep_pos - param_name_start);

#ifndef CPPHTTPLIB_NO_EXCEPTIONS
    // insert() reports whether the name was new. Recording each name is what
    // makes the duplicate check effective; a set that is never filled can
    // never report a duplicate.
    if (!param_name_set.insert(param_name).second) {
      std::string msg = "Encountered path parameter '" + param_name +
                        "' multiple times in route pattern '" + pattern + "'.";
      throw std::invalid_argument(msg);
    }
#endif

    param_names_.push_back(std::move(param_name));

    last_param_end = sep_pos + 1;
  }

  if (last_param_end < pattern.length()) {
    static_fragments_.push_back(pattern.substr(last_param_end));
  }
}
|
|
|
|
// Matches request.path against the fragments and parameter names prepared
// by the constructor.
//
// request.matches is reset and request.path_params is cleared first. On
// success path_params holds one entry per parameter name. On failure the
// function returns false; path_params may then contain the parameters
// matched before the mismatch.
inline bool PathParamsMatcher::match(Request &request) const {
  request.matches = std::smatch();
  request.path_params.clear();
  request.path_params.reserve(param_names_.size());

  // One past the position at which the path matched the pattern last time
  std::size_t starting_pos = 0;
  for (size_t i = 0; i < static_fragments_.size(); ++i) {
    const auto &fragment = static_fragments_[i];

    if (starting_pos + fragment.length() > request.path.length()) {
      return false;
    }

    // Avoid unnecessary allocation by using strncmp instead of substr +
    // comparison
    if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(),
                     fragment.length()) != 0) {
      return false;
    }

    starting_pos += fragment.length();

    // Should only happen when we have a static fragment after a param
    // Example: '/users/:id/subscriptions'
    // The 'subscriptions' fragment here does not have a corresponding param
    if (i >= param_names_.size()) { continue; }

    auto sep_pos = request.path.find(separator, starting_pos);
    if (sep_pos == std::string::npos) { sep_pos = request.path.length(); }

    const auto &param_name = param_names_[i];

    request.path_params.emplace(
        param_name, request.path.substr(starting_pos, sep_pos - starting_pos));

    // Mark everything up to '/' as matched
    starting_pos = sep_pos + 1;
  }
  // Returns false if the path is longer than the pattern
  return starting_pos >= request.path.length();
}
|
|
|
|
// Clears request.path_params, then requires the whole of request.path to
// match regex_; capture groups are stored in request.matches.
inline bool RegexMatcher::match(Request &request) const {
  request.path_params.clear();
  const bool matched = std::regex_match(request.path, request.matches, regex_);
  return matched;
}
|
|
|
|
} // namespace detail
|
|
|
|
// HTTP server implementation
|
|
// Sets the default task-queue factory, which creates a ThreadPool with
// CPPHTTPLIB_THREAD_POOL_COUNT workers. When _WIN64 is not defined, the
// process-wide SIGPIPE disposition is also set to SIG_IGN.
inline Server::Server()
    : new_task_queue([] {
        return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT);
      }) {
#ifndef _WIN64
  signal(SIGPIPE, SIG_IGN);
#endif
}
|
|
|
|
// Destructor is defaulted out of line; members clean themselves up.
inline Server::~Server() = default;
|
|
|
|
inline std::unique_ptr<detail::MatcherBase>
|
|
Server::make_matcher(const std::string &pattern) {
|
|
if (pattern.find("/:") != std::string::npos) {
|
|
return detail::make_unique<detail::PathParamsMatcher>(pattern);
|
|
} else {
|
|
return detail::make_unique<detail::RegexMatcher>(pattern);
|
|
}
|
|
}
|
|
|
|
// Appends (matcher for `pattern`, `handler`) to `get_handlers_`, the list
// routing consults for GET and HEAD requests. Returns *this for chaining.
inline Server &Server::Get(const std::string &pattern, Handler handler) {
  auto matcher = make_matcher(pattern);
  get_handlers_.emplace_back(std::move(matcher), std::move(handler));
  return *this;
}
|
|
|
|
// Appends (matcher for `pattern`, `handler`) to `post_handlers_`.
// Returns *this for chaining.
inline Server &Server::Post(const std::string &pattern, Handler handler) {
  auto matcher = make_matcher(pattern);
  post_handlers_.emplace_back(std::move(matcher), std::move(handler));
  return *this;
}
|
|
|
|
// Content-reader overload: appends (matcher for `pattern`, `handler`) to
// `post_handlers_for_content_reader_`. Returns *this for chaining.
inline Server &Server::Post(const std::string &pattern,
                            HandlerWithContentReader handler) {
  auto matcher = make_matcher(pattern);
  post_handlers_for_content_reader_.emplace_back(std::move(matcher),
                                                 std::move(handler));
  return *this;
}
|
|
|
|
// Appends (matcher for `pattern`, `handler`) to `put_handlers_`.
// Returns *this for chaining.
inline Server &Server::Put(const std::string &pattern, Handler handler) {
  auto matcher = make_matcher(pattern);
  put_handlers_.emplace_back(std::move(matcher), std::move(handler));
  return *this;
}
|
|
|
|
// Content-reader overload: appends (matcher for `pattern`, `handler`) to
// `put_handlers_for_content_reader_`. Returns *this for chaining.
inline Server &Server::Put(const std::string &pattern,
                           HandlerWithContentReader handler) {
  auto matcher = make_matcher(pattern);
  put_handlers_for_content_reader_.emplace_back(std::move(matcher),
                                                std::move(handler));
  return *this;
}
|
|
|
|
// Appends (matcher for `pattern`, `handler`) to `patch_handlers_`.
// Returns *this for chaining.
inline Server &Server::Patch(const std::string &pattern, Handler handler) {
  auto matcher = make_matcher(pattern);
  patch_handlers_.emplace_back(std::move(matcher), std::move(handler));
  return *this;
}
|
|
|
|
// Content-reader overload: appends (matcher for `pattern`, `handler`) to
// `patch_handlers_for_content_reader_`. Returns *this for chaining.
inline Server &Server::Patch(const std::string &pattern,
                             HandlerWithContentReader handler) {
  auto matcher = make_matcher(pattern);
  patch_handlers_for_content_reader_.emplace_back(std::move(matcher),
                                                  std::move(handler));
  return *this;
}
|
|
|
|
// Appends (matcher for `pattern`, `handler`) to `delete_handlers_`.
// Returns *this for chaining.
inline Server &Server::Delete(const std::string &pattern, Handler handler) {
  auto matcher = make_matcher(pattern);
  delete_handlers_.emplace_back(std::move(matcher), std::move(handler));
  return *this;
}
|
|
|
|
// Content-reader overload: appends (matcher for `pattern`, `handler`) to
// `delete_handlers_for_content_reader_`. Returns *this for chaining.
inline Server &Server::Delete(const std::string &pattern,
                              HandlerWithContentReader handler) {
  auto matcher = make_matcher(pattern);
  delete_handlers_for_content_reader_.emplace_back(std::move(matcher),
                                                   std::move(handler));
  return *this;
}
|
|
|
|
// Appends (matcher for `pattern`, `handler`) to `options_handlers_`.
// Returns *this for chaining.
inline Server &Server::Options(const std::string &pattern, Handler handler) {
  auto matcher = make_matcher(pattern);
  options_handlers_.emplace_back(std::move(matcher), std::move(handler));
  return *this;
}
|
|
|
|
inline bool Server::set_base_dir(const std::string &dir,
|
|
const std::string &mount_point) {
|
|
return set_mount_point(mount_point, dir);
|
|
}
|
|
|
|
inline bool Server::set_mount_point(const std::string &mount_point,
|
|
const std::string &dir, Headers headers) {
|
|
detail::FileStat stat(dir);
|
|
if (stat.is_dir()) {
|
|
std::string mnt = !mount_point.empty() ? mount_point : "/";
|
|
if (!mnt.empty() && mnt[0] == '/') {
|
|
base_dirs_.push_back({mnt, dir, std::move(headers)});
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Erases the first entry of `base_dirs_` whose mount point equals
// `mount_point`. Returns true if an entry was removed, false otherwise.
inline bool Server::remove_mount_point(const std::string &mount_point) {
  auto entry = base_dirs_.begin();
  while (entry != base_dirs_.end() && !(entry->mount_point == mount_point)) {
    ++entry;
  }
  if (entry == base_dirs_.end()) { return false; }

  base_dirs_.erase(entry);
  return true;
}
|
|
|
|
// Adds or overwrites the MIME type associated with file extension `ext`.
// Returns *this for chaining.
inline Server &
Server::set_file_extension_and_mimetype_mapping(const std::string &ext,
                                                const std::string &mime) {
  auto &mapped_mime = file_extension_and_mimetype_map_[ext];
  mapped_mime = mime;
  return *this;
}
|
|
|
|
// Stores the fallback MIME type that handle_file_request() passes to
// detail::find_content_type(). Returns *this for chaining.
inline Server &Server::set_default_file_mimetype(const std::string &mime) {
  this->default_file_mimetype_ = mime;
  return *this;
}
|
|
|
|
// Installs the hook that handle_file_request() invokes after preparing a
// static-file response for a non-HEAD request. Returns *this for chaining.
inline Server &Server::set_file_request_handler(Handler handler) {
  this->file_request_handler_ = std::move(handler);
  return *this;
}
|
|
|
|
// Tag-dispatched overload (std::true_type): the handler already reports a
// HandlerResponse, so it is stored as-is. Returns *this for chaining.
inline Server &Server::set_error_handler_core(HandlerWithResponse handler,
                                              std::true_type) {
  this->error_handler_ = std::move(handler);
  return *this;
}
|
|
|
|
// Tag-dispatched overload (std::false_type): wraps a plain Handler in an
// adapter that calls it and then always reports HandlerResponse::Handled.
// Returns *this for chaining.
inline Server &Server::set_error_handler_core(Handler handler,
                                              std::false_type) {
  auto adapter = [handler](const Request &req, Response &res) {
    handler(req, res);
    return HandlerResponse::Handled;
  };
  error_handler_ = std::move(adapter);
  return *this;
}
|
|
|
|
// Stores the callback that process_request() invokes when routing throws.
// Returns *this for chaining.
inline Server &Server::set_exception_handler(ExceptionHandler handler) {
  this->exception_handler_ = std::move(handler);
  return *this;
}
|
|
|
|
// Stores the hook that routing() calls first; when it reports Handled,
// routing stops there. Returns *this for chaining.
inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) {
  this->pre_routing_handler_ = std::move(handler);
  return *this;
}
|
|
|
|
// Stores the hook that write_response_core() calls just before the response
// line and headers are serialized. Returns *this for chaining.
inline Server &Server::set_post_routing_handler(Handler handler) {
  this->post_routing_handler_ = std::move(handler);
  return *this;
}
|
|
|
|
// Stores the hook that the dispatch functions call after a route matched;
// when it reports Handled, the route handler itself is skipped.
// Returns *this for chaining.
inline Server &Server::set_pre_request_handler(HandlerWithResponse handler) {
  this->pre_request_handler_ = std::move(handler);
  return *this;
}
|
|
|
|
// Stores the logger that write_response_core() calls after the response has
// been written. Returns *this for chaining.
inline Server &Server::set_logger(Logger logger) {
  this->logger_ = std::move(logger);
  return *this;
}
|
|
|
|
// Stores the logger that apply_ranges() calls right before it compresses a
// response body. Returns *this for chaining.
inline Server &Server::set_pre_compression_logger(Logger logger) {
  this->pre_compression_logger_ = std::move(logger);
  return *this;
}
|
|
|
|
// Stores the callback that decides the status sent in reply to a request
// carrying "Expect: 100-continue". Returns *this for chaining.
inline Server &
Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) {
  this->expect_100_continue_handler_ = std::move(handler);
  return *this;
}
|
|
|
|
// Records the address family that create_server_socket() forwards to
// detail::create_socket(). Returns *this for chaining.
inline Server &Server::set_address_family(int family) {
  this->address_family_ = family;
  return *this;
}
|
|
|
|
// Records the TCP no-delay flag that create_server_socket() forwards to
// detail::create_socket(). Returns *this for chaining.
inline Server &Server::set_tcp_nodelay(bool on) {
  this->tcp_nodelay_ = on;
  return *this;
}
|
|
|
|
// Records the IPv6-only flag that create_server_socket() forwards to
// detail::create_socket(). Returns *this for chaining.
inline Server &Server::set_ipv6_v6only(bool on) {
  this->ipv6_v6only_ = on;
  return *this;
}
|
|
|
|
// Stores the socket-options callback that bind_internal() hands to
// create_server_socket(). Returns *this for chaining.
inline Server &Server::set_socket_options(SocketOptions socket_options) {
  this->socket_options_ = std::move(socket_options);
  return *this;
}
|
|
|
|
// Stores the headers that process_request() copies into every new Response
// before handling. Returns *this for chaining.
inline Server &Server::set_default_headers(Headers headers) {
  this->default_headers_ = std::move(headers);
  return *this;
}
|
|
|
|
// Replaces the function write_response_core() uses to serialize response
// headers into the stream. Returns *this for chaining.
inline Server &Server::set_header_writer(
    std::function<ssize_t(Stream &, Headers &)> const &writer) {
  this->header_writer_ = writer;
  return *this;
}
|
|
|
|
// Sets the value reported as "max=" in the Keep-Alive response header built
// by write_response_core(). Returns *this for chaining.
inline Server &Server::set_keep_alive_max_count(size_t count) {
  this->keep_alive_max_count_ = count;
  return *this;
}
|
|
|
|
// Sets the value (in seconds) reported as "timeout=" in the Keep-Alive
// response header built by write_response_core(). Returns *this for chaining.
inline Server &Server::set_keep_alive_timeout(time_t sec) {
  this->keep_alive_timeout_sec_ = sec;
  return *this;
}
|
|
|
|
// Sets the receive timeout (seconds + microseconds) that listen_internal()
// applies to each accepted socket via SO_RCVTIMEO.
// Returns *this for chaining.
inline Server &Server::set_read_timeout(time_t sec, time_t usec) {
  this->read_timeout_sec_ = sec;
  this->read_timeout_usec_ = usec;
  return *this;
}
|
|
|
|
// Sets the send timeout (seconds + microseconds) that listen_internal()
// applies to each accepted socket via SO_SNDTIMEO.
// Returns *this for chaining.
inline Server &Server::set_write_timeout(time_t sec, time_t usec) {
  this->write_timeout_sec_ = sec;
  this->write_timeout_usec_ = usec;
  return *this;
}
|
|
|
|
// Sets the timeout (seconds + microseconds) the accept loop passes to
// detail::select_read(); on timeout the task queue's on_idle() is called.
// Returns *this for chaining.
inline Server &Server::set_idle_interval(time_t sec, time_t usec) {
  this->idle_interval_sec_ = sec;
  this->idle_interval_usec_ = usec;
  return *this;
}
|
|
|
|
// Sets the payload limit that read_content_core() passes to
// detail::read_content(). Returns *this for chaining.
inline Server &Server::set_payload_max_length(size_t length) {
  this->payload_max_length_ = length;
  return *this;
}
|
|
|
|
inline bool Server::bind_to_port(const std::string &host, int port,
|
|
int socket_flags) {
|
|
auto ret = bind_internal(host, port, socket_flags);
|
|
if (ret == -1) { is_decommissioned = true; }
|
|
return ret >= 0;
|
|
}
|
|
inline int Server::bind_to_any_port(const std::string &host, int socket_flags) {
|
|
auto ret = bind_internal(host, 0, socket_flags);
|
|
if (ret == -1) { is_decommissioned = true; }
|
|
return ret;
|
|
}
|
|
|
|
// Runs the accept loop on an already-bound socket; returns the result of
// listen_internal().
inline bool Server::listen_after_bind() {
  const bool ok = listen_internal();
  return ok;
}
|
|
|
|
inline bool Server::listen(const std::string &host, int port,
|
|
int socket_flags) {
|
|
return bind_to_port(host, port, socket_flags) && listen_internal();
|
|
}
|
|
|
|
// Reports the current value of `is_running_`, which listen_internal() sets
// for the duration of the accept loop.
inline bool Server::is_running() const {
  const bool running = is_running_;
  return running;
}
|
|
|
|
inline void Server::wait_until_ready() const {
|
|
while (!is_running_ && !is_decommissioned) {
|
|
std::this_thread::sleep_for(std::chrono::milliseconds{1});
|
|
}
|
|
}
|
|
|
|
inline void Server::stop() {
|
|
if (is_running_) {
|
|
assert(svr_sock_ != INVALID_SOCKET);
|
|
std::atomic<socket_t> sock(svr_sock_.exchange(INVALID_SOCKET));
|
|
detail::shutdown_socket(sock);
|
|
detail::close_socket(sock);
|
|
}
|
|
is_decommissioned = false;
|
|
}
|
|
|
|
// Marks the server as decommissioned; bind_internal() and listen_internal()
// refuse to run while this flag is set.
inline void Server::decommission() {
  is_decommissioned = true;
}
|
|
|
|
inline bool Server::parse_request_line(const char *s, Request &req) const {
|
|
auto len = strlen(s);
|
|
if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; }
|
|
len -= 2;
|
|
|
|
{
|
|
size_t count = 0;
|
|
|
|
detail::split(s, s + len, ' ', [&](const char *b, const char *e) {
|
|
switch (count) {
|
|
case 0: req.method = std::string(b, e); break;
|
|
case 1: req.target = std::string(b, e); break;
|
|
case 2: req.version = std::string(b, e); break;
|
|
default: break;
|
|
}
|
|
count++;
|
|
});
|
|
|
|
if (count != 3) { return false; }
|
|
}
|
|
|
|
thread_local const std::set<std::string> methods{
|
|
"GET", "HEAD", "POST", "PUT", "DELETE",
|
|
"CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"};
|
|
|
|
if (methods.find(req.method) == methods.end()) { return false; }
|
|
|
|
if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { return false; }
|
|
|
|
{
|
|
// Skip URL fragment
|
|
for (size_t i = 0; i < req.target.size(); i++) {
|
|
if (req.target[i] == '#') {
|
|
req.target.erase(i);
|
|
break;
|
|
}
|
|
}
|
|
|
|
detail::divide(req.target, '?',
|
|
[&](const char *lhs_data, std::size_t lhs_size,
|
|
const char *rhs_data, std::size_t rhs_size) {
|
|
req.path = detail::decode_path(
|
|
std::string(lhs_data, lhs_size), false);
|
|
detail::parse_query_text(rhs_data, rhs_size, req.params);
|
|
});
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Writes `res` without range processing requested by the caller. Used for
// responses (such as errors) whose content must not be sliced by the
// request's Range header, so the parsed ranges are dropped first.
inline bool Server::write_response(Stream &strm, bool close_connection,
                                   Request &req, Response &res) {
  // NOTE: `req.ranges` should be empty, otherwise it will be applied
  // incorrectly to the error content.
  req.ranges.clear();

  const bool need_apply_ranges = false;
  return write_response_core(strm, close_connection, req, res,
                             need_apply_ranges);
}
|
|
|
|
// Writes `res` with range/encoding processing enabled (apply_ranges() will
// run inside write_response_core()).
inline bool Server::write_response_with_content(Stream &strm,
                                                bool close_connection,
                                                const Request &req,
                                                Response &res) {
  const bool need_apply_ranges = true;
  return write_response_core(strm, close_connection, req, res,
                             need_apply_ranges);
}
|
|
|
|
// Serializes `res` to `strm`: runs the error handler for status >= 400,
// optionally applies ranges/compression, fills in connection and content
// headers, writes the status line and headers, then the body (skipped for
// HEAD), and finally calls the logger.
//
// Returns false when the status line or headers could not be buffered, or
// when writing the body failed; in the latter case the logger still runs.
inline bool Server::write_response_core(Stream &strm, bool close_connection,
                                        const Request &req, Response &res,
                                        bool need_apply_ranges) {
  assert(res.status != -1);

  // An error handler that takes over the response may have produced content,
  // so range/encoding processing is forced on.
  const bool is_error_status = 400 <= res.status;
  if (is_error_status && error_handler_ &&
      error_handler_(req, res) == HandlerResponse::Handled) {
    need_apply_ranges = true;
  }

  std::string content_type;
  std::string boundary;
  if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }

  // Prepare additional headers
  if (close_connection || req.get_header_value("Connection") == "close") {
    res.set_header("Connection", "close");
  } else {
    const std::string keep_alive =
        "timeout=" + std::to_string(keep_alive_timeout_sec_) +
        ", max=" + std::to_string(keep_alive_max_count_);
    res.set_header("Keep-Alive", keep_alive);
  }

  // Default the content type when there is some payload but no explicit type.
  if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) &&
      !res.has_header("Content-Type")) {
    res.set_header("Content-Type", "text/plain");
  }

  // No payload at all: advertise an explicit zero length.
  if (res.body.empty() && !res.content_length_ && !res.content_provider_ &&
      !res.has_header("Content-Length")) {
    res.set_header("Content-Length", "0");
  }

  const bool is_head = req.method == "HEAD";
  if (is_head && !res.has_header("Accept-Ranges")) {
    res.set_header("Accept-Ranges", "bytes");
  }

  if (post_routing_handler_) { post_routing_handler_(req, res); }

  // Response line and headers
  {
    detail::BufferStream header_buf;
    if (!detail::write_response_line(header_buf, res.status)) { return false; }
    if (!header_writer_(header_buf, res.headers)) { return false; }

    // Flush buffer
    auto &bytes = header_buf.get_buffer();
    detail::write_data(strm, bytes.data(), bytes.size());
  }

  // Body
  auto body_ok = true;
  if (!is_head) {
    if (!res.body.empty()) {
      if (!detail::write_data(strm, res.body.data(), res.body.size())) {
        body_ok = false;
      }
    } else if (res.content_provider_) {
      if (write_content_with_provider(strm, req, res, boundary, content_type)) {
        res.content_provider_success_ = true;
      } else {
        body_ok = false;
      }
    }
  }

  // Log
  if (logger_) { logger_(req, res); }

  return body_ok;
}
|
|
|
|
inline bool
|
|
Server::write_content_with_provider(Stream &strm, const Request &req,
|
|
Response &res, const std::string &boundary,
|
|
const std::string &content_type) {
|
|
auto is_shutting_down = [this]() {
|
|
return this->svr_sock_ == INVALID_SOCKET;
|
|
};
|
|
|
|
if (res.content_length_ > 0) {
|
|
if (req.ranges.empty()) {
|
|
return detail::write_content(strm, res.content_provider_, 0,
|
|
res.content_length_, is_shutting_down);
|
|
} else if (req.ranges.size() == 1) {
|
|
auto offset_and_length = detail::get_range_offset_and_length(
|
|
req.ranges[0], res.content_length_);
|
|
|
|
return detail::write_content(strm, res.content_provider_,
|
|
offset_and_length.first,
|
|
offset_and_length.second, is_shutting_down);
|
|
} else {
|
|
return detail::write_multipart_ranges_data(
|
|
strm, req, res, boundary, content_type, res.content_length_,
|
|
is_shutting_down);
|
|
}
|
|
} else {
|
|
if (res.is_chunked_content_provider_) {
|
|
auto type = detail::encoding_type(req, res);
|
|
|
|
std::unique_ptr<detail::compressor> compressor;
|
|
if (type == detail::EncodingType::Gzip) {
|
|
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
|
compressor = detail::make_unique<detail::gzip_compressor>();
|
|
#endif
|
|
} else if (type == detail::EncodingType::Brotli) {
|
|
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
|
|
compressor = detail::make_unique<detail::brotli_compressor>();
|
|
#endif
|
|
} else if (type == detail::EncodingType::Zstd) {
|
|
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
|
|
compressor = detail::make_unique<detail::zstd_compressor>();
|
|
#endif
|
|
} else {
|
|
compressor = detail::make_unique<detail::nocompressor>();
|
|
}
|
|
assert(compressor != nullptr);
|
|
|
|
return detail::write_content_chunked(strm, res.content_provider_,
|
|
is_shutting_down, *compressor);
|
|
} else {
|
|
return detail::write_content_without_length(strm, res.content_provider_,
|
|
is_shutting_down);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Reads the whole request body into `req`.
//
// Non-multipart bodies are appended to req.body. Multipart parts are stored
// in req.form.fields (parts without a filename) or req.form.files (parts
// with one), and their data is appended to the part most recently started.
// After a successful read, a body whose Content-Type starts with
// "application/x-www-form-urlencoded" is parsed into req.params; if that
// body exceeds CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH the status is
// set to 413 and false is returned.
inline bool Server::read_content(Stream &strm, Request &req, Response &res) {
  FormFields::iterator cur_field;
  FormFiles::iterator cur_file;
  auto is_text_field = false;
  size_t part_count = 0;

  // Regular (non-multipart) body chunks
  auto on_body_chunk = [&](const char *buf, size_t n) {
    if (req.body.size() + n > req.body.max_size()) { return false; }
    req.body.append(buf, n);
    return true;
  };

  // Start of a multipart part
  auto on_part_header = [&](const FormData &file) {
    if (part_count++ == CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT) {
      return false;
    }

    if (file.filename.empty()) {
      cur_field = req.form.fields.emplace(
          file.name, FormField{file.name, file.content, file.headers});
      is_text_field = true;
    } else {
      cur_file = req.form.files.emplace(file.name, file);
      is_text_field = false;
    }
    return true;
  };

  // Data belonging to the part started last
  auto on_part_data = [&](const char *buf, size_t n) {
    auto &content =
        is_text_field ? cur_field->second.content : cur_file->second.content;
    if (content.size() + n > content.max_size()) { return false; }
    content.append(buf, n);
    return true;
  };

  if (!read_content_core(strm, req, res, on_body_chunk, on_part_header,
                         on_part_data)) {
    return false;
  }

  const auto &content_type = req.get_header_value("Content-Type");
  // find() == 0 means the header value starts with the form media type.
  if (content_type.find("application/x-www-form-urlencoded") == 0) {
    if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) {
      res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414?
      return false;
    }
    detail::parse_query_text(req.body, req.params);
  }
  return true;
}
|
|
|
|
// Forwards to read_content_core(), moving all three caller-supplied
// callbacks into it; returns its result unchanged.
inline bool Server::read_content_with_content_receiver(
    Stream &strm, Request &req, Response &res, ContentReceiver receiver,
    FormDataHeader multipart_header, ContentReceiver multipart_receiver) {
  const bool ok = read_content_core(strm, req, res, std::move(receiver),
                                    std::move(multipart_header),
                                    std::move(multipart_receiver));
  return ok;
}
|
|
|
|
// Reads the request body from `strm` and feeds it to the callbacks.
//
// Multipart requests are run through a FormDataParser that calls
// `multipart_header` / `multipart_receiver`; all other requests pass each
// chunk to `receiver`. A DELETE request without a Content-Length header is
// treated as having no body. Sets res.status to 400 and returns false when
// the multipart boundary cannot be parsed or the parser ends up in an
// invalid state; also returns false when detail::read_content() fails.
inline bool Server::read_content_core(
    Stream &strm, Request &req, Response &res, ContentReceiver receiver,
    FormDataHeader multipart_header, ContentReceiver multipart_receiver) const {
  detail::FormDataParser form_parser;
  ContentReceiverWithProgress sink;

  if (!req.is_multipart_form_data()) {
    sink = [receiver](const char *buf, size_t n, size_t /*off*/,
                      size_t /*len*/) { return receiver(buf, n); };
  } else {
    const auto &content_type = req.get_header_value("Content-Type");
    std::string boundary;
    if (!detail::parse_multipart_boundary(content_type, boundary)) {
      res.status = StatusCode::BadRequest_400;
      return false;
    }

    form_parser.set_boundary(std::move(boundary));
    sink = [&](const char *buf, size_t n, size_t /*off*/, size_t /*len*/) {
      return form_parser.parse(buf, n, multipart_header, multipart_receiver);
    };
  }

  const bool bodyless_delete =
      req.method == "DELETE" && !req.has_header("Content-Length");
  if (bodyless_delete) { return true; }

  if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr,
                            sink, true)) {
    return false;
  }

  if (req.is_multipart_form_data() && !form_parser.is_valid()) {
    res.status = StatusCode::BadRequest_400;
    return false;
  }

  return true;
}
|
|
|
|
inline bool Server::handle_file_request(const Request &req, Response &res) {
|
|
for (const auto &entry : base_dirs_) {
|
|
// Prefix match
|
|
if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) {
|
|
std::string sub_path = "/" + req.path.substr(entry.mount_point.size());
|
|
if (detail::is_valid_path(sub_path)) {
|
|
auto path = entry.base_dir + sub_path;
|
|
if (path.back() == '/') { path += "index.html"; }
|
|
|
|
detail::FileStat stat(path);
|
|
|
|
if (stat.is_dir()) {
|
|
res.set_redirect(sub_path + "/", StatusCode::MovedPermanently_301);
|
|
return true;
|
|
}
|
|
|
|
if (stat.is_file()) {
|
|
for (const auto &kv : entry.headers) {
|
|
res.set_header(kv.first, kv.second);
|
|
}
|
|
|
|
auto mm = std::make_shared<detail::mmap>(path.c_str());
|
|
if (!mm->is_open()) { return false; }
|
|
|
|
res.set_content_provider(
|
|
mm->size(),
|
|
detail::find_content_type(path, file_extension_and_mimetype_map_,
|
|
default_file_mimetype_),
|
|
[mm](size_t offset, size_t length, DataSink &sink) -> bool {
|
|
sink.write(mm->data() + offset, length);
|
|
return true;
|
|
});
|
|
|
|
if (req.method != "HEAD" && file_request_handler_) {
|
|
file_request_handler_(req, res);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Creates the listening socket through detail::create_socket(), using the
// server's address family / TCP_NODELAY / IPv6-only settings. For each
// candidate address the callback binds the socket and starts listening with
// a backlog of CPPHTTPLIB_LISTEN_BACKLOG; it reports success only when both
// calls return 0.
inline socket_t
Server::create_server_socket(const std::string &host, int port,
                             int socket_flags,
                             SocketOptions socket_options) const {
  return detail::create_socket(
      host, std::string(), port, address_family_, socket_flags, tcp_nodelay_,
      ipv6_v6only_, std::move(socket_options),
      [](socket_t sock, struct addrinfo &ai, bool & /*quit*/) -> bool {
        const bool bound =
            ::bind(sock, ai.ai_addr, static_cast<socklen_t>(ai.ai_addrlen)) ==
            0;
        if (!bound) { return false; }
        return ::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG) == 0;
      });
}
|
|
|
|
inline int Server::bind_internal(const std::string &host, int port,
|
|
int socket_flags) {
|
|
if (is_decommissioned) { return -1; }
|
|
|
|
if (!is_valid()) { return -1; }
|
|
|
|
svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_);
|
|
if (svr_sock_ == INVALID_SOCKET) { return -1; }
|
|
|
|
if (port == 0) {
|
|
struct sockaddr_storage addr;
|
|
socklen_t addr_len = sizeof(addr);
|
|
if (getsockname(svr_sock_, reinterpret_cast<struct sockaddr *>(&addr),
|
|
&addr_len) == -1) {
|
|
return -1;
|
|
}
|
|
if (addr.ss_family == AF_INET) {
|
|
return ntohs(reinterpret_cast<struct sockaddr_in *>(&addr)->sin_port);
|
|
} else if (addr.ss_family == AF_INET6) {
|
|
return ntohs(reinterpret_cast<struct sockaddr_in6 *>(&addr)->sin6_port);
|
|
} else {
|
|
return -1;
|
|
}
|
|
} else {
|
|
return port;
|
|
}
|
|
}
|
|
|
|
inline bool Server::listen_internal() {
|
|
if (is_decommissioned) { return false; }
|
|
|
|
auto ret = true;
|
|
is_running_ = true;
|
|
auto se = detail::scope_exit([&]() { is_running_ = false; });
|
|
|
|
{
|
|
std::unique_ptr<TaskQueue> task_queue(new_task_queue());
|
|
|
|
while (svr_sock_ != INVALID_SOCKET) {
|
|
#ifndef _WIN64
|
|
if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) {
|
|
#endif
|
|
auto val = detail::select_read(svr_sock_, idle_interval_sec_,
|
|
idle_interval_usec_);
|
|
if (val == 0) { // Timeout
|
|
task_queue->on_idle();
|
|
continue;
|
|
}
|
|
#ifndef _WIN64
|
|
}
|
|
#endif
|
|
|
|
#if defined _WIN64
|
|
// sockets connected via WASAccept inherit flags NO_HANDLE_INHERIT,
|
|
// OVERLAPPED
|
|
socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0);
|
|
#elif defined SOCK_CLOEXEC
|
|
socket_t sock = accept4(svr_sock_, nullptr, nullptr, SOCK_CLOEXEC);
|
|
#else
|
|
socket_t sock = accept(svr_sock_, nullptr, nullptr);
|
|
#endif
|
|
|
|
if (sock == INVALID_SOCKET) {
|
|
if (errno == EMFILE) {
|
|
// The per-process limit of open file descriptors has been reached.
|
|
// Try to accept new connections after a short sleep.
|
|
std::this_thread::sleep_for(std::chrono::microseconds{1});
|
|
continue;
|
|
} else if (errno == EINTR || errno == EAGAIN) {
|
|
continue;
|
|
}
|
|
if (svr_sock_ != INVALID_SOCKET) {
|
|
detail::close_socket(svr_sock_);
|
|
ret = false;
|
|
} else {
|
|
; // The server socket was closed by user.
|
|
}
|
|
break;
|
|
}
|
|
|
|
detail::set_socket_opt_time(sock, SOL_SOCKET, SO_RCVTIMEO,
|
|
read_timeout_sec_, read_timeout_usec_);
|
|
detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO,
|
|
write_timeout_sec_, write_timeout_usec_);
|
|
|
|
if (!task_queue->enqueue(
|
|
[this, sock]() { process_and_close_socket(sock); })) {
|
|
detail::shutdown_socket(sock);
|
|
detail::close_socket(sock);
|
|
}
|
|
}
|
|
|
|
task_queue->shutdown();
|
|
}
|
|
|
|
is_decommissioned = !ret;
|
|
return ret;
|
|
}
|
|
|
|
// Routes a parsed request to a handler. Order of attempts:
//   1. the pre-routing hook (stops routing when it reports Handled);
//   2. static files, for GET and HEAD;
//   3. if the request is expected to carry content: a content-reader handler
//      for POST/PUT/PATCH/DELETE, otherwise the body is read into `req`;
//   4. the regular handler list for the method.
// Returns false when the body could not be read, when no handler matched,
// or (with status 400) when the method has no handler list.
inline bool Server::routing(Request &req, Response &res, Stream &strm) {
  if (pre_routing_handler_ &&
      pre_routing_handler_(req, res) == HandlerResponse::Handled) {
    return true;
  }

  // File handler
  if ((req.method == "GET" || req.method == "HEAD") &&
      handle_file_request(req, res)) {
    return true;
  }

  if (detail::expect_content(req)) {
    // Content reader handler
    {
      ContentReader reader(
          [&](ContentReceiver receiver) {
            return read_content_with_content_receiver(
                strm, req, res, std::move(receiver), nullptr, nullptr);
          },
          [&](FormDataHeader header, ContentReceiver receiver) {
            return read_content_with_content_receiver(strm, req, res, nullptr,
                                                      std::move(header),
                                                      std::move(receiver));
          });

      // Select the content-reader handler list for this method, if any.
      const HandlersForContentReader *reader_handlers = nullptr;
      if (req.method == "POST") {
        reader_handlers = &post_handlers_for_content_reader_;
      } else if (req.method == "PUT") {
        reader_handlers = &put_handlers_for_content_reader_;
      } else if (req.method == "PATCH") {
        reader_handlers = &patch_handlers_for_content_reader_;
      } else if (req.method == "DELETE") {
        reader_handlers = &delete_handlers_for_content_reader_;
      }

      if (reader_handlers != nullptr &&
          dispatch_request_for_content_reader(req, res, std::move(reader),
                                              *reader_handlers)) {
        return true;
      }
    }

    // Read content into `req.body`
    if (!read_content(strm, req, res)) { return false; }
  }

  // Regular handler
  const Handlers *handlers = nullptr;
  if (req.method == "GET" || req.method == "HEAD") {
    handlers = &get_handlers_;
  } else if (req.method == "POST") {
    handlers = &post_handlers_;
  } else if (req.method == "PUT") {
    handlers = &put_handlers_;
  } else if (req.method == "DELETE") {
    handlers = &delete_handlers_;
  } else if (req.method == "OPTIONS") {
    handlers = &options_handlers_;
  } else if (req.method == "PATCH") {
    handlers = &patch_handlers_;
  }

  if (handlers != nullptr) { return dispatch_request(req, res, *handlers); }

  res.status = StatusCode::BadRequest_400;
  return false;
}
|
|
|
|
// Runs the first handler in `handlers` whose matcher accepts the request.
// The matched pattern is recorded in req.matched_route. The handler is
// skipped when a pre-request hook is installed and reports Handled.
// Returns true if some matcher accepted the request, false otherwise.
inline bool Server::dispatch_request(Request &req, Response &res,
                                     const Handlers &handlers) const {
  for (const auto &route : handlers) {
    const auto &matcher = route.first;
    if (!matcher->match(req)) { continue; }

    req.matched_route = matcher->pattern();

    const bool intercepted =
        pre_request_handler_ &&
        pre_request_handler_(req, res) == HandlerResponse::Handled;
    if (!intercepted) {
      const auto &handler = route.second;
      handler(req, res);
    }
    return true;
  }
  return false;
}
|
|
|
|
inline void Server::apply_ranges(const Request &req, Response &res,
|
|
std::string &content_type,
|
|
std::string &boundary) const {
|
|
if (req.ranges.size() > 1 && res.status == StatusCode::PartialContent_206) {
|
|
auto it = res.headers.find("Content-Type");
|
|
if (it != res.headers.end()) {
|
|
content_type = it->second;
|
|
res.headers.erase(it);
|
|
}
|
|
|
|
boundary = detail::make_multipart_data_boundary();
|
|
|
|
res.set_header("Content-Type",
|
|
"multipart/byteranges; boundary=" + boundary);
|
|
}
|
|
|
|
auto type = detail::encoding_type(req, res);
|
|
|
|
if (res.body.empty()) {
|
|
if (res.content_length_ > 0) {
|
|
size_t length = 0;
|
|
if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) {
|
|
length = res.content_length_;
|
|
} else if (req.ranges.size() == 1) {
|
|
auto offset_and_length = detail::get_range_offset_and_length(
|
|
req.ranges[0], res.content_length_);
|
|
|
|
length = offset_and_length.second;
|
|
|
|
auto content_range = detail::make_content_range_header_field(
|
|
offset_and_length, res.content_length_);
|
|
res.set_header("Content-Range", content_range);
|
|
} else {
|
|
length = detail::get_multipart_ranges_data_length(
|
|
req, boundary, content_type, res.content_length_);
|
|
}
|
|
res.set_header("Content-Length", std::to_string(length));
|
|
} else {
|
|
if (res.content_provider_) {
|
|
if (res.is_chunked_content_provider_) {
|
|
res.set_header("Transfer-Encoding", "chunked");
|
|
if (type == detail::EncodingType::Gzip) {
|
|
res.set_header("Content-Encoding", "gzip");
|
|
} else if (type == detail::EncodingType::Brotli) {
|
|
res.set_header("Content-Encoding", "br");
|
|
} else if (type == detail::EncodingType::Zstd) {
|
|
res.set_header("Content-Encoding", "zstd");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) {
|
|
;
|
|
} else if (req.ranges.size() == 1) {
|
|
auto offset_and_length =
|
|
detail::get_range_offset_and_length(req.ranges[0], res.body.size());
|
|
auto offset = offset_and_length.first;
|
|
auto length = offset_and_length.second;
|
|
|
|
auto content_range = detail::make_content_range_header_field(
|
|
offset_and_length, res.body.size());
|
|
res.set_header("Content-Range", content_range);
|
|
|
|
assert(offset + length <= res.body.size());
|
|
res.body = res.body.substr(offset, length);
|
|
} else {
|
|
std::string data;
|
|
detail::make_multipart_ranges_data(req, res, boundary, content_type,
|
|
res.body.size(), data);
|
|
res.body.swap(data);
|
|
}
|
|
|
|
if (type != detail::EncodingType::None) {
|
|
if (pre_compression_logger_) { pre_compression_logger_(req, res); }
|
|
|
|
std::unique_ptr<detail::compressor> compressor;
|
|
std::string content_encoding;
|
|
|
|
if (type == detail::EncodingType::Gzip) {
|
|
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
|
|
compressor = detail::make_unique<detail::gzip_compressor>();
|
|
content_encoding = "gzip";
|
|
#endif
|
|
} else if (type == detail::EncodingType::Brotli) {
|
|
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
|
|
compressor = detail::make_unique<detail::brotli_compressor>();
|
|
content_encoding = "br";
|
|
#endif
|
|
} else if (type == detail::EncodingType::Zstd) {
|
|
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
|
|
compressor = detail::make_unique<detail::zstd_compressor>();
|
|
content_encoding = "zstd";
|
|
#endif
|
|
}
|
|
|
|
if (compressor) {
|
|
std::string compressed;
|
|
if (compressor->compress(res.body.data(), res.body.size(), true,
|
|
[&](const char *data, size_t data_len) {
|
|
compressed.append(data, data_len);
|
|
return true;
|
|
})) {
|
|
res.body.swap(compressed);
|
|
res.set_header("Content-Encoding", content_encoding);
|
|
}
|
|
}
|
|
}
|
|
|
|
auto length = std::to_string(res.body.size());
|
|
res.set_header("Content-Length", length);
|
|
}
|
|
}
|
|
|
|
// Content-reader counterpart of dispatch_request(): runs the first handler
// whose matcher accepts the request, passing it `content_reader` so it can
// pull the body itself. The matched pattern is recorded in
// req.matched_route, and the handler is skipped when the pre-request hook
// reports Handled. Returns true if some matcher accepted the request.
inline bool Server::dispatch_request_for_content_reader(
    Request &req, Response &res, ContentReader content_reader,
    const HandlersForContentReader &handlers) const {
  for (const auto &route : handlers) {
    const auto &matcher = route.first;
    if (!matcher->match(req)) { continue; }

    req.matched_route = matcher->pattern();

    const bool intercepted =
        pre_request_handler_ &&
        pre_request_handler_(req, res) == HandlerResponse::Handled;
    if (!intercepted) {
      const auto &handler = route.second;
      handler(req, res, content_reader);
    }
    return true;
  }
  return false;
}
|
|
|
|
inline bool
|
|
Server::process_request(Stream &strm, const std::string &remote_addr,
|
|
int remote_port, const std::string &local_addr,
|
|
int local_port, bool close_connection,
|
|
bool &connection_closed,
|
|
const std::function<void(Request &)> &setup_request) {
|
|
std::array<char, 2048> buf{};
|
|
|
|
detail::stream_line_reader line_reader(strm, buf.data(), buf.size());
|
|
|
|
// Connection has been closed on client
|
|
if (!line_reader.getline()) { return false; }
|
|
|
|
Request req;
|
|
|
|
Response res;
|
|
res.version = "HTTP/1.1";
|
|
res.headers = default_headers_;
|
|
|
|
#ifdef __APPLE__
|
|
// Socket file descriptor exceeded FD_SETSIZE...
|
|
if (strm.socket() >= FD_SETSIZE) {
|
|
Headers dummy;
|
|
detail::read_headers(strm, dummy);
|
|
res.status = StatusCode::InternalServerError_500;
|
|
return write_response(strm, close_connection, req, res);
|
|
}
|
|
#endif
|
|
|
|
// Request line and headers
|
|
if (!parse_request_line(line_reader.ptr(), req) ||
|
|
!detail::read_headers(strm, req.headers)) {
|
|
res.status = StatusCode::BadRequest_400;
|
|
return write_response(strm, close_connection, req, res);
|
|
}
|
|
|
|
// Check if the request URI doesn't exceed the limit
|
|
if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
|
|
Headers dummy;
|
|
detail::read_headers(strm, dummy);
|
|
res.status = StatusCode::UriTooLong_414;
|
|
return write_response(strm, close_connection, req, res);
|
|
}
|
|
|
|
if (req.get_header_value("Connection") == "close") {
|
|
connection_closed = true;
|
|
}
|
|
|
|
if (req.version == "HTTP/1.0" &&
|
|
req.get_header_value("Connection") != "Keep-Alive") {
|
|
connection_closed = true;
|
|
}
|
|
|
|
req.remote_addr = remote_addr;
|
|
req.remote_port = remote_port;
|
|
req.set_header("REMOTE_ADDR", req.remote_addr);
|
|
req.set_header("REMOTE_PORT", std::to_string(req.remote_port));
|
|
|
|
req.local_addr = local_addr;
|
|
req.local_port = local_port;
|
|
req.set_header("LOCAL_ADDR", req.local_addr);
|
|
req.set_header("LOCAL_PORT", std::to_string(req.local_port));
|
|
|
|
if (req.has_header("Accept")) {
|
|
const auto &accept_header = req.get_header_value("Accept");
|
|
if (!detail::parse_accept_header(accept_header, req.accept_content_types)) {
|
|
res.status = StatusCode::BadRequest_400;
|
|
return write_response(strm, close_connection, req, res);
|
|
}
|
|
}
|
|
|
|
if (req.has_header("Range")) {
|
|
const auto &range_header_value = req.get_header_value("Range");
|
|
if (!detail::parse_range_header(range_header_value, req.ranges)) {
|
|
res.status = StatusCode::RangeNotSatisfiable_416;
|
|
return write_response(strm, close_connection, req, res);
|
|
}
|
|
}
|
|
|
|
if (setup_request) { setup_request(req); }
|
|
|
|
if (req.get_header_value("Expect") == "100-continue") {
|
|
int status = StatusCode::Continue_100;
|
|
if (expect_100_continue_handler_) {
|
|
status = expect_100_continue_handler_(req, res);
|
|
}
|
|
switch (status) {
|
|
case StatusCode::Continue_100:
|
|
case StatusCode::ExpectationFailed_417:
|
|
detail::write_response_line(strm, status);
|
|
strm.write("\r\n");
|
|
break;
|
|
default:
|
|
connection_closed = true;
|
|
return write_response(strm, true, req, res);
|
|
}
|
|
}
|
|
|
|
// Setup `is_connection_closed` method
|
|
auto sock = strm.socket();
|
|
req.is_connection_closed = [sock]() {
|
|
return !detail::is_socket_alive(sock);
|
|
};
|
|
|
|
// Routing
|
|
auto routed = false;
|
|
#ifdef CPPHTTPLIB_NO_EXCEPTIONS
|
|
routed = routing(req, res, strm);
|
|
#else
|
|
try {
|
|
routed = routing(req, res, strm);
|
|
} catch (std::exception &e) {
|
|
if (exception_handler_) {
|
|
auto ep = std::current_exception();
|
|
exception_handler_(req, res, ep);
|
|
routed = true;
|
|
} else {
|
|
res.status = StatusCode::InternalServerError_500;
|
|
std::string val;
|
|
auto s = e.what();
|
|
for (size_t i = 0; s[i]; i++) {
|
|
switch (s[i]) {
|
|
case '\r': val += "\\r"; break;
|
|
case '\n': val += "\\n"; break;
|
|
default: val += s[i]; break;
|
|
}
|
|
}
|
|
res.set_header("EXCEPTION_WHAT", val);
|
|
}
|
|
} catch (...) {
|
|
if (exception_handler_) {
|
|
auto ep = std::current_exception();
|
|
exception_handler_(req, res, ep);
|
|
routed = true;
|
|
} else {
|
|
res.status = StatusCode::InternalServerError_500;
|
|
res.set_header("EXCEPTION_WHAT", "UNKNOWN");
|
|
}
|
|
}
|
|
#endif
|
|
if (routed) {
|
|
if (res.status == -1) {
|
|
res.status = req.ranges.empty() ? StatusCode::OK_200
|
|
: StatusCode::PartialContent_206;
|
|
}
|
|
|
|
// Serve file content by using a content provider
|
|
if (!res.file_content_path_.empty()) {
|
|
const auto &path = res.file_content_path_;
|
|
auto mm = std::make_shared<detail::mmap>(path.c_str());
|
|
if (!mm->is_open()) {
|
|
res.body.clear();
|
|
res.content_length_ = 0;
|
|
res.content_provider_ = nullptr;
|
|
res.status = StatusCode::NotFound_404;
|
|
return write_response(strm, close_connection, req, res);
|
|
}
|
|
|
|
auto content_type = res.file_content_content_type_;
|
|
if (content_type.empty()) {
|
|
content_type = detail::find_content_type(
|
|
path, file_extension_and_mimetype_map_, default_file_mimetype_);
|
|
}
|
|
|
|
res.set_content_provider(
|
|
mm->size(), content_type,
|
|
[mm](size_t offset, size_t length, DataSink &sink) -> bool {
|
|
sink.write(mm->data() + offset, length);
|
|
return true;
|
|
});
|
|
}
|
|
|
|
if (detail::range_error(req, res)) {
|
|
res.body.clear();
|
|
res.content_length_ = 0;
|
|
res.content_provider_ = nullptr;
|
|
res.status = StatusCode::RangeNotSatisfiable_416;
|
|
return write_response(strm, close_connection, req, res);
|
|
}
|
|
|
|
return write_response_with_content(strm, close_connection, req, res);
|
|
} else {
|
|
if (res.status == -1) { res.status = StatusCode::NotFound_404; }
|
|
|
|
return write_response(strm, close_connection, req, res);
|
|
}
|
|
}
|
|
|
|
// Validity check for the server object; this implementation unconditionally
// reports true.
inline bool Server::is_valid() const {
  return true;
}
|
|
|
|
// Handles an accepted connection and then tears it down.
//
// The remote and local endpoints of `sock` are resolved once up front and
// forwarded to process_request() through the callback given to
// detail::process_server_socket(). Whatever the outcome, the socket is shut
// down and closed before returning.
//
// Returns the value produced by detail::process_server_socket().
inline bool Server::process_and_close_socket(socket_t sock) {
  std::string peer_ip;
  int peer_port = 0;
  detail::get_remote_ip_and_port(sock, peer_ip, peer_port);

  std::string own_ip;
  int own_port = 0;
  detail::get_local_ip_and_port(sock, own_ip, own_port);

  // Per-request callback; no content-reader setup hook is supplied (nullptr).
  auto handle_one = [&](Stream &strm, bool close_connection,
                        bool &connection_closed) {
    return process_request(strm, peer_ip, peer_port, own_ip, own_port,
                           close_connection, connection_closed, nullptr);
  };

  auto served = detail::process_server_socket(
      svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_,
      read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
      write_timeout_usec_, handle_one);

  detail::shutdown_socket(sock);
  detail::close_socket(sock);
  return served;
}
|
|
|
|
// HTTP client implementation
|
|
inline ClientImpl::ClientImpl(const std::string &host)
|
|
: ClientImpl(host, 80, std::string(), std::string()) {}
|
|
|
|
inline ClientImpl::ClientImpl(const std::string &host, int port)
|
|
: ClientImpl(host, port, std::string(), std::string()) {}
|
|
|
|
// Primary constructor.
//
// Stores the host (after detail::escape_abstract_namespace_unix_domain()),
// the port, a precomputed "host:port" string built from the stored host via
// adjust_host_string(), and the client certificate / key paths. No connection
// is opened here.
inline ClientImpl::ClientImpl(const std::string &host, int port,
                              const std::string &client_cert_path,
                              const std::string &client_key_path)
    : host_(detail::escape_abstract_namespace_unix_domain(host)),
      port_(port),
      host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)),
      client_cert_path_(client_cert_path),
      client_key_path_(client_key_path) {}
|
|
|
|
// Destructor: gives in-flight requests a short grace period, then shuts down
// and closes the socket while holding socket_mutex_.
inline ClientImpl::~ClientImpl() {
  // Poll up to 10 times, sleeping 1 ms between polls, for the in-flight
  // request counter to reach zero. The mutex is held only while reading the
  // counter, never while sleeping.
  for (size_t attempt = 0; attempt < 10; ++attempt) {
    bool idle = false;
    {
      std::lock_guard<std::mutex> lock(socket_mutex_);
      idle = (socket_requests_in_flight_ == 0);
    }
    if (idle) { break; }
    std::this_thread::sleep_for(std::chrono::milliseconds{1});
  }

  std::lock_guard<std::mutex> lock(socket_mutex_);
  shutdown_socket(socket_);
  close_socket(socket_);
}
|
|
|
|
// Validity check for the client object; this implementation unconditionally
// reports true.
inline bool ClientImpl::is_valid() const {
  return true;
}
|
|
|
|
inline void ClientImpl::copy_settings(const ClientImpl &rhs) {
|
|
client_cert_path_ = rhs.client_cert_path_;
|
|
client_key_path_ = rhs.client_key_path_;
|
|
connection_timeout_sec_ = rhs.connection_timeout_sec_;
|
|
read_timeout_sec_ = rhs.read_timeout_sec_;
|
|
read_timeout_usec_ = rhs.read_timeout_usec_;
|
|
write_timeout_sec_ = rhs.write_timeout_sec_;
|
|
write_timeout_usec_ = rhs.write_timeout_usec_;
|
|
max_timeout_msec_ = rhs.max_timeout_msec_;
|
|
basic_auth_username_ = rhs.basic_auth_username_;
|
|
basic_auth_password_ = rhs.basic_auth_password_;
|
|
bearer_token_auth_token_ = rhs.bearer_token_auth_token_;
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
digest_auth_username_ = rhs.digest_auth_username_;
|
|
digest_auth_password_ = rhs.digest_auth_password_;
|
|
#endif
|
|
keep_alive_ = rhs.keep_alive_;
|
|
follow_location_ = rhs.follow_location_;
|
|
path_encode_ = rhs.path_encode_;
|
|
address_family_ = rhs.address_family_;
|
|
tcp_nodelay_ = rhs.tcp_nodelay_;
|
|
ipv6_v6only_ = rhs.ipv6_v6only_;
|
|
socket_options_ = rhs.socket_options_;
|
|
compress_ = rhs.compress_;
|
|
decompress_ = rhs.decompress_;
|
|
interface_ = rhs.interface_;
|
|
proxy_host_ = rhs.proxy_host_;
|
|
proxy_port_ = rhs.proxy_port_;
|
|
proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_;
|
|
proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_;
|
|
proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_;
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_;
|
|
proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_;
|
|
#endif
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
ca_cert_file_path_ = rhs.ca_cert_file_path_;
|
|
ca_cert_dir_path_ = rhs.ca_cert_dir_path_;
|
|
ca_cert_store_ = rhs.ca_cert_store_;
|
|
#endif
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
server_certificate_verification_ = rhs.server_certificate_verification_;
|
|
server_hostname_verification_ = rhs.server_hostname_verification_;
|
|
server_certificate_verifier_ = rhs.server_certificate_verifier_;
|
|
#endif
|
|
logger_ = rhs.logger_;
|
|
}
|
|
|
|
// Opens a socket towards the configured endpoint via
// detail::create_client_socket().
//
// When a proxy is configured (non-empty proxy host and a proxy port other
// than -1) the connection goes to the proxy. Otherwise it goes to host_ /
// port_, using the address registered for host_ in addr_map_ when one exists
// and an empty address string when none does.
//
// Returns whatever detail::create_client_socket() returns; failures are
// reported through `error` by that helper.
inline socket_t ClientImpl::create_client_socket(Error &error) const {
  const bool via_proxy = !proxy_host_.empty() && proxy_port_ != -1;
  if (via_proxy) {
    return detail::create_client_socket(
        proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_,
        ipv6_v6only_, socket_options_, connection_timeout_sec_,
        connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_,
        write_timeout_sec_, write_timeout_usec_, interface_, error);
  }

  // Use the caller-registered address for host_, if any.
  const auto mapped = addr_map_.find(host_);
  const std::string ip =
      (mapped != addr_map_.end()) ? mapped->second : std::string();

  return detail::create_client_socket(
      host_, ip, port_, address_family_, tcp_nodelay_, ipv6_v6only_,
      socket_options_, connection_timeout_sec_, connection_timeout_usec_,
      read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
      write_timeout_usec_, interface_, error);
}
|
|
|
|
// Creates a client socket and, on success, stores its descriptor in
// `socket.sock`.
//
// Returns false (leaving `socket` untouched) when create_client_socket()
// yields INVALID_SOCKET; `error` is whatever that call reported.
inline bool ClientImpl::create_and_connect_socket(Socket &socket,
                                                  Error &error) {
  const auto fd = create_client_socket(error);
  const bool created = (fd != INVALID_SOCKET);
  if (created) { socket.sock = fd; }
  return created;
}
|
|
|
|
// Base-class TLS shutdown hook. It ignores both arguments and only checks
// (in builds where assert is active) the threading precondition.
inline void ClientImpl::shutdown_ssl(Socket & /*socket*/,
                                     bool /*shutdown_gracefully*/) {
  // Individual ssl* objects are not thread-safe, so this must not run while
  // another thread still has requests in flight on the socket: either nothing
  // is in flight, or the in-flight requests belong to the calling thread.
  assert(socket_requests_in_flight_ == 0 ||
         socket_requests_are_from_thread_ == std::this_thread::get_id());
}
|
|
|
|
// Shuts down `socket` through detail::shutdown_socket(); does nothing when
// the descriptor is INVALID_SOCKET. The descriptor itself is not closed.
inline void ClientImpl::shutdown_socket(Socket &socket) const {
  if (socket.sock != INVALID_SOCKET) { detail::shutdown_socket(socket.sock); }
}
|
|
|
|
// Closes `socket` and marks it as INVALID_SOCKET; does nothing when it is
// already invalid.
//
// Preconditions (checked with assert):
//  - No other thread has requests in flight on this socket. Closing under
//    another thread usually just makes its I/O fail, but the OS may hand the
//    same descriptor value to a new socket, and that thread would then operate
//    on an unrelated live connection.
//  - With OpenSSL support, the TLS session has already been torn down
//    (`socket.ssl` is null).
inline void ClientImpl::close_socket(Socket &socket) {
  assert(socket_requests_in_flight_ == 0 ||
         socket_requests_are_from_thread_ == std::this_thread::get_id());

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  assert(socket.ssl == nullptr);
#endif

  if (socket.sock != INVALID_SOCKET) {
    detail::close_socket(socket.sock);
    socket.sock = INVALID_SOCKET;
  }
}
|
|
|
|
// Reads and parses the HTTP status line from `strm` into `res.version`,
// `res.status` and `res.reason`.
//
// Interim "100 Continue" responses are skipped: for each one, the following
// line and then the next status line are read, until a non-100 status is seen.
//
// Returns false when a line cannot be read or a status line following a 100
// does not parse. When the very first line does not parse, the result is true
// only for CONNECT requests (and `res` is left unmodified).
inline bool ClientImpl::read_response_line(Stream &strm, const Request &req,
                                           Response &res) const {
  std::array<char, 2048> line_buf{};
  detail::stream_line_reader reader(strm, line_buf.data(), line_buf.size());

  if (!reader.getline()) { return false; }

#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR
  thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n");
#else
  thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n");
#endif

  std::cmatch m;

  // Copies the captured version / status / reason into `res`.
  const auto store_status_line = [&res, &m]() {
    res.version = std::string(m[1]);
    res.status = std::stoi(std::string(m[2]));
    res.reason = std::string(m[3]);
  };

  if (!std::regex_match(reader.ptr(), m, re)) {
    return req.method == "CONNECT";
  }
  store_status_line();

  // Ignore '100 Continue'
  while (res.status == StatusCode::Continue_100) {
    if (!reader.getline()) { return false; } // CRLF
    if (!reader.getline()) { return false; } // next response line

    if (!std::regex_match(reader.ptr(), m, re)) { return false; }
    store_status_line();
  }

  return true;
}
|
|
|
|
// Sends `req` and fills `res`, holding request_mutex_ (a recursive mutex) for
// the whole exchange.
//
// If the attempt ends with Error::SSLPeerCouldBeClosed_, the request is
// issued exactly once more and the result of that second attempt is returned.
inline bool ClientImpl::send(Request &req, Response &res, Error &error) {
  std::lock_guard<std::recursive_mutex> request_mutex_guard(request_mutex_);

  auto ok = send_(req, res, error);
  if (error != Error::SSLPeerCouldBeClosed_) { return ok; }

  // That error code is only expected together with a failed attempt.
  assert(!ok);
  return send_(req, res, error);
}
|
|
|
|
// Performs one request/response exchange on the client's socket.
//
// Phase 1 (under socket_mutex_): make sure a usable connection exists,
// reconnecting (and, for TLS clients, re-running proxy CONNECT and the TLS
// setup) when the current socket is closed or no longer alive, then register
// this request as in flight.
// Phase 2 (mutex released): add default headers the request lacks and run
// handle_request() through process_socket().
// On every exit from phase 2 a scope guard unregisters the request and closes
// the connection when required.
//
// Returns true on success. On failure `error` is set; a failure that left
// `error` at Error::Success is reported as Error::Unknown.
inline bool ClientImpl::send_(Request &req, Response &res, Error &error) {
  {
    std::lock_guard<std::mutex> guard(socket_mutex_);

    // Reset first: if this is true by the end of the request, another thread
    // asked for the socket to be closed in the meantime.
    socket_should_be_closed_when_request_is_done_ = false;

    auto reusable = false;
    if (socket_.is_open()) {
      reusable = detail::is_socket_alive(socket_.sock);

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
      if (reusable && is_ssl() &&
          detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) {
        reusable = false;
      }
#endif

      if (!reusable) {
        // The peer seems to have gone away: shut down non-gracefully to try
        // to avoid SIGPIPE. No other thread can have requests in flight
        // because request_mutex_ is held, so closing right away is safe.
        shutdown_ssl(socket_, /*shutdown_gracefully=*/false);
        shutdown_socket(socket_);
        close_socket(socket_);
      }
    }

    if (!reusable) {
      if (!create_and_connect_socket(socket_, error)) { return false; }

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
      // TODO: refactoring
      if (is_ssl()) {
        auto &ssl_client = static_cast<SSLClient &>(*this);

        const bool via_proxy = !proxy_host_.empty() && proxy_port_ != -1;
        if (via_proxy) {
          auto success = false;
          if (!ssl_client.connect_with_proxy(socket_, req.start_time_, res,
                                             success, error)) {
            return success;
          }
        }

        if (!ssl_client.initialize_ssl(socket_, error)) { return false; }
      }
#endif
    }

    // Mark the socket as in use so nobody else closes it while this request
    // is running, even though the mutex is released below.
    if (socket_requests_in_flight_ > 1) {
      assert(socket_requests_are_from_thread_ == std::this_thread::get_id());
    }
    socket_requests_in_flight_ += 1;
    socket_requests_are_from_thread_ = std::this_thread::get_id();
  }

  // Default headers never override a header the request already carries.
  for (const auto &default_header : default_headers_) {
    if (req.headers.count(default_header.first) == 0) {
      req.headers.insert(default_header);
    }
  }

  auto ok = false;
  auto close_connection = !keep_alive_;

  // Runs when this function exits, after `ok` has its final value.
  auto on_exit = detail::scope_exit([&]() {
    // Briefly lock the mutex to mark that a request is no longer ongoing.
    std::lock_guard<std::mutex> guard(socket_mutex_);

    socket_requests_in_flight_ -= 1;
    if (socket_requests_in_flight_ <= 0) {
      assert(socket_requests_in_flight_ == 0);
      socket_requests_are_from_thread_ = std::thread::id();
    }

    const bool must_close =
        socket_should_be_closed_when_request_is_done_ || close_connection ||
        !ok;
    if (must_close) {
      shutdown_ssl(socket_, true);
      shutdown_socket(socket_);
      close_socket(socket_);
    }
  });

  ok = process_socket(socket_, req.start_time_, [&](Stream &strm) {
    return handle_request(strm, req, res, close_connection, error);
  });

  if (!ok && error == Error::Success) { error = Error::Unknown; }

  return ok;
}
|
|
|
|
// Sends a copy of `req` (the caller's object is not modified) and returns the
// outcome as a Result.
inline Result ClientImpl::send(const Request &req) {
  return send_(Request(req));
}
|
|
|
|
// Sends `req` and packages the outcome as a Result.
//
// The Result holds the response on success and a null response on failure,
// together with the error code and the request headers (moved out of `req`).
// With OpenSSL support it also carries the last recorded SSL / OpenSSL error
// values.
inline Result ClientImpl::send_(Request &&req) {
  auto response = detail::make_unique<Response>();
  auto error = Error::Success;

  // A failed exchange must not expose a partially filled response.
  if (!send(req, *response, error)) { response.reset(); }

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  return Result{std::move(response), error, std::move(req.headers),
                last_ssl_error_, last_openssl_error_};
#else
  return Result{std::move(response), error, std::move(req.headers)};
#endif
}
|
|
|
|
// Runs one request over `strm` and applies the post-processing that depends
// on the response: closing the connection, following redirects and (with
// OpenSSL support) answering digest authentication challenges.
//
// Fails with Error::Connection when the request path is empty. Otherwise
// returns the result of the exchange, or of the follow-up request when a
// redirect or digest retry was performed.
inline bool ClientImpl::handle_request(Stream &strm, Request &req,
                                       Response &res, bool close_connection,
                                       Error &error) {
  if (req.path.empty()) {
    error = Error::Connection;
    return false;
  }

  // Snapshot taken before processing; used to restore the path after a
  // proxied request and as the starting point of a redirect.
  auto pristine_req = req;

  const bool plain_http_via_proxy =
      !is_ssl() && !proxy_host_.empty() && proxy_port_ != -1;

  bool ok;
  if (plain_http_via_proxy) {
    // The request sent through the proxy carries the absolute URI as its
    // path. Changes made while processing are kept, but the caller's
    // original path is restored.
    auto proxied_req = req;
    proxied_req.path = "http://" + host_and_port_ + req.path;
    ok = process_request(strm, proxied_req, res, close_connection, error);
    req = proxied_req;
    req.path = pristine_req.path;
  } else {
    ok = process_request(strm, req, res, close_connection, error);
  }

  if (!ok) { return false; }

  const bool asked_to_close = res.get_header_value("Connection") == "close";
  const bool http10_response =
      res.version == "HTTP/1.0" && res.reason != "Connection established";
  if (asked_to_close || http10_response) {
    // TODO this requires a not-entirely-obvious chain of calls to be correct
    // for this to be safe.

    // Safe here because handle_request is only called by send_, which holds
    // the request mutex for the duration. Doing this to the socket from a
    // different thread while another thread is using it would be a bug.
    std::lock_guard<std::mutex> guard(socket_mutex_);
    shutdown_ssl(socket_, true);
    shutdown_socket(socket_);
    close_socket(socket_);
  }

  const bool is_redirect = 300 < res.status && res.status < 400;
  if (is_redirect && follow_location_) {
    req = pristine_req;
    ok = redirect(req, res, error);
  }

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  const bool auth_challenge =
      res.status == StatusCode::Unauthorized_401 ||
      res.status == StatusCode::ProxyAuthenticationRequired_407;
  if (auth_challenge && req.authorization_count_ < 5) {
    auto is_proxy = res.status == StatusCode::ProxyAuthenticationRequired_407;
    const auto &username =
        is_proxy ? proxy_digest_auth_username_ : digest_auth_username_;
    const auto &password =
        is_proxy ? proxy_digest_auth_password_ : digest_auth_password_;

    if (!username.empty() && !password.empty()) {
      std::map<std::string, std::string> auth;
      if (detail::parse_www_authenticate(res, auth, is_proxy)) {
        // Retry with a digest header replacing any existing credentials.
        Request retry_req = req;
        retry_req.authorization_count_ += 1;
        retry_req.headers.erase(is_proxy ? "Proxy-Authorization"
                                         : "Authorization");
        retry_req.headers.insert(detail::make_digest_authentication_header(
            req, auth, retry_req.authorization_count_,
            detail::random_string(10), username, password, is_proxy));

        Response retry_res;

        ok = send(retry_req, retry_res, error);
        if (ok) { res = retry_res; }
      }
    }
  }
#endif

  return ok;
}
|
|
|
|
// Follows the redirect announced by the "location" header of `res`.
//
// The location is split into scheme, host, port, path and query; missing
// parts default to those of the current connection (scheme and host), to the
// scheme's default port when only a scheme is given, and to "/" for the path.
// A redirect to the same scheme/host/port is performed with this client;
// anything else goes through create_redirect_client().
//
// Returns false when:
//  - the redirect budget is exhausted (`error` = Error::ExceedRedirectCount),
//  - the location header is missing or does not match the URL pattern,
//  - the port in the location is larger than 65535.
//
// Fix: the port comes from the remote server and the pattern accepts any
// number of digits. It used to be converted with std::stoi, which throws
// std::out_of_range for values that do not fit in an int, letting a hostile
// Location header raise an exception out of the client. The port is now
// parsed with an explicit upper bound and the redirect is rejected instead.
inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) {
  if (req.redirect_count_ == 0) {
    error = Error::ExceedRedirectCount;
    return false;
  }

  auto location = res.get_header_value("location");
  if (location.empty()) { return false; }

  thread_local const std::regex re(
      R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)");

  std::smatch m;
  if (!std::regex_match(location, m, re)) { return false; }

  auto scheme = is_ssl() ? "https" : "http";

  auto next_scheme = m[1].str();
  auto next_host = m[2].str(); // bracketed IPv6 literal, without brackets
  if (next_host.empty()) { next_host = m[3].str(); }
  auto port_str = m[4].str();
  auto next_path = m[5].str();
  auto next_query = m[6].str();

  auto next_port = port_;
  if (!port_str.empty()) {
    // The pattern guarantees `port_str` is all digits. Accumulate with a
    // bound check on every step so the value can never overflow.
    int parsed_port = 0;
    for (const char digit : port_str) {
      parsed_port = parsed_port * 10 + (digit - '0');
      if (parsed_port > 65535) { return false; }
    }
    next_port = parsed_port;
  } else if (!next_scheme.empty()) {
    next_port = next_scheme == "https" ? 443 : 80;
  }

  if (next_scheme.empty()) { next_scheme = scheme; }
  if (next_host.empty()) { next_host = host_; }
  if (next_path.empty()) { next_path = "/"; }

  auto path = detail::decode_path(next_path, true) + next_query;

  // Same host redirect - use current client
  if (next_scheme == scheme && next_host == host_ && next_port == port_) {
    return detail::redirect(*this, req, res, path, location, error);
  }

  // Cross-host/scheme redirect - create new client with robust setup
  return create_redirect_client(next_scheme, next_host, next_port, req, res,
                                path, location, error);
}
|
|
|
|
// New method for robust redirect client creation
|
|
inline bool ClientImpl::create_redirect_client(
|
|
const std::string &scheme, const std::string &host, int port, Request &req,
|
|
Response &res, const std::string &path, const std::string &location,
|
|
Error &error) {
|
|
// Determine if we need SSL
|
|
auto need_ssl = (scheme == "https");
|
|
|
|
// Clean up request headers that are host/client specific
|
|
// Remove headers that should not be carried over to new host
|
|
auto headers_to_remove =
|
|
std::vector<std::string>{"Host", "Proxy-Authorization", "Authorization"};
|
|
|
|
for (const auto &header_name : headers_to_remove) {
|
|
auto it = req.headers.find(header_name);
|
|
while (it != req.headers.end()) {
|
|
it = req.headers.erase(it);
|
|
it = req.headers.find(header_name);
|
|
}
|
|
}
|
|
|
|
// Create appropriate client type and handle redirect
|
|
if (need_ssl) {
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
// Create SSL client for HTTPS redirect
|
|
SSLClient redirect_client(host, port);
|
|
|
|
// Setup basic client configuration first
|
|
setup_redirect_client(redirect_client);
|
|
|
|
// SSL-specific configuration for proxy environments
|
|
if (!proxy_host_.empty() && proxy_port_ != -1) {
|
|
// Critical: Disable SSL verification for proxy environments
|
|
redirect_client.enable_server_certificate_verification(false);
|
|
redirect_client.enable_server_hostname_verification(false);
|
|
} else {
|
|
// For direct SSL connections, copy SSL verification settings
|
|
redirect_client.enable_server_certificate_verification(
|
|
server_certificate_verification_);
|
|
redirect_client.enable_server_hostname_verification(
|
|
server_hostname_verification_);
|
|
}
|
|
|
|
// Handle CA certificate store and paths if available
|
|
if (ca_cert_store_) { redirect_client.set_ca_cert_store(ca_cert_store_); }
|
|
if (!ca_cert_file_path_.empty()) {
|
|
redirect_client.set_ca_cert_path(ca_cert_file_path_, ca_cert_dir_path_);
|
|
}
|
|
|
|
// Client certificates are set through constructor for SSLClient
|
|
// NOTE: SSLClient constructor already takes client_cert_path and
|
|
// client_key_path so we need to create it properly if client certs are
|
|
// needed
|
|
|
|
// Execute the redirect
|
|
return detail::redirect(redirect_client, req, res, path, location, error);
|
|
#else
|
|
// SSL not supported - set appropriate error
|
|
error = Error::SSLConnection;
|
|
return false;
|
|
#endif
|
|
} else {
|
|
// HTTP redirect
|
|
ClientImpl redirect_client(host, port);
|
|
|
|
// Setup client with robust configuration
|
|
setup_redirect_client(redirect_client);
|
|
|
|
// Execute the redirect
|
|
return detail::redirect(redirect_client, req, res, path, location, error);
|
|
}
|
|
}
|
|
|
|
// New method for robust client setup (based on basic_manual_redirect.cpp logic)
|
|
template <typename ClientType>
|
|
inline void ClientImpl::setup_redirect_client(ClientType &client) {
|
|
// Copy basic settings first
|
|
client.set_connection_timeout(connection_timeout_sec_);
|
|
client.set_read_timeout(read_timeout_sec_, read_timeout_usec_);
|
|
client.set_write_timeout(write_timeout_sec_, write_timeout_usec_);
|
|
client.set_keep_alive(keep_alive_);
|
|
client.set_follow_location(
|
|
true); // Enable redirects to handle multi-step redirects
|
|
client.set_path_encode(path_encode_);
|
|
client.set_compress(compress_);
|
|
client.set_decompress(decompress_);
|
|
|
|
// Copy authentication settings BEFORE proxy setup
|
|
if (!basic_auth_username_.empty()) {
|
|
client.set_basic_auth(basic_auth_username_, basic_auth_password_);
|
|
}
|
|
if (!bearer_token_auth_token_.empty()) {
|
|
client.set_bearer_token_auth(bearer_token_auth_token_);
|
|
}
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
if (!digest_auth_username_.empty()) {
|
|
client.set_digest_auth(digest_auth_username_, digest_auth_password_);
|
|
}
|
|
#endif
|
|
|
|
// Setup proxy configuration (CRITICAL ORDER - proxy must be set
|
|
// before proxy auth)
|
|
if (!proxy_host_.empty() && proxy_port_ != -1) {
|
|
// First set proxy host and port
|
|
client.set_proxy(proxy_host_, proxy_port_);
|
|
|
|
// Then set proxy authentication (order matters!)
|
|
if (!proxy_basic_auth_username_.empty()) {
|
|
client.set_proxy_basic_auth(proxy_basic_auth_username_,
|
|
proxy_basic_auth_password_);
|
|
}
|
|
if (!proxy_bearer_token_auth_token_.empty()) {
|
|
client.set_proxy_bearer_token_auth(proxy_bearer_token_auth_token_);
|
|
}
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
if (!proxy_digest_auth_username_.empty()) {
|
|
client.set_proxy_digest_auth(proxy_digest_auth_username_,
|
|
proxy_digest_auth_password_);
|
|
}
|
|
#endif
|
|
}
|
|
|
|
// Copy network and socket settings
|
|
client.set_address_family(address_family_);
|
|
client.set_tcp_nodelay(tcp_nodelay_);
|
|
client.set_ipv6_v6only(ipv6_v6only_);
|
|
if (socket_options_) { client.set_socket_options(socket_options_); }
|
|
if (!interface_.empty()) { client.set_interface(interface_); }
|
|
|
|
// Copy logging and headers
|
|
if (logger_) { client.set_logger(logger_); }
|
|
|
|
// NOTE: DO NOT copy default_headers_ as they may contain stale Host headers
|
|
// Each new client should generate its own headers based on its target host
|
|
}
|
|
|
|
// Streams the request body produced by `req.content_provider_` to `strm`.
//
// A provider with a known length is written through
// detail::write_content_with_progress(), covering bytes
// [0, req.content_length_) and reporting to `req.upload_progress`. A chunked
// provider is written through detail::write_content_chunked(), gzip-compressed
// when zlib support is compiled in and compress_ is set, uncompressed
// otherwise.
//
// Returns the helper's result; failures are reported through `error`.
inline bool ClientImpl::write_content_with_provider(Stream &strm,
                                                    const Request &req,
                                                    Error &error) const {
  // This path has no shutdown signal, so the predicate always says "no".
  auto never_shutting_down = []() { return false; };

  if (!req.is_chunked_content_provider_) {
    return detail::write_content_with_progress(
        strm, req.content_provider_, 0, req.content_length_,
        never_shutting_down, req.upload_progress, error);
  }

  // TODO: Brotli support
  std::unique_ptr<detail::compressor> body_compressor;
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
  if (compress_) {
    body_compressor = detail::make_unique<detail::gzip_compressor>();
  } else
#endif
  {
    body_compressor = detail::make_unique<detail::nocompressor>();
  }

  return detail::write_content_chunked(strm, req.content_provider_,
                                       never_shutting_down, *body_compressor,
                                       error);
}
|
|
|
|
// Completes the headers of `req`, then writes the request line, headers and
// body to `strm`.
//
// Headers added when missing: Connection (only when `close_connection`),
// Host, Accept, Accept-Encoding and User-Agent (both only when the request
// has no content receiver), Content-Type / Content-Length as appropriate for
// the body, and Authorization / Proxy-Authorization from the configured basic
// or bearer credentials. `req` is modified in place.
//
// Returns false with `error` = Error::Write when writing fails, or
// Error::Canceled when the upload progress callback returns false.
inline bool ClientImpl::write_request(Stream &strm, Request &req,
                                      bool close_connection, Error &error) {
  // Prepare additional headers
  if (close_connection && !req.has_header("Connection")) {
    req.set_header("Connection", "close");
  }

  if (!req.has_header("Host")) {
    if (address_family_ == AF_UNIX) {
      // For Unix socket connections, use "localhost" as Host header (similar
      // to curl behavior)
      req.set_header("Host", "localhost");
    } else {
      // The port is left out when it is the default one for the scheme.
      const int scheme_default_port = is_ssl() ? 443 : 80;
      req.set_header("Host",
                     port_ == scheme_default_port ? host_ : host_and_port_);
    }
  }

  if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); }

  if (!req.content_receiver) {
    if (!req.has_header("Accept-Encoding")) {
      // Advertise every encoding compiled in; the value is an empty string
      // when none is available.
      std::string encodings;
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
      encodings = "br";
#endif
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
      if (!encodings.empty()) { encodings += ", "; }
      encodings += "gzip, deflate";
#endif
#ifdef CPPHTTPLIB_ZSTD_SUPPORT
      if (!encodings.empty()) { encodings += ", "; }
      encodings += "zstd";
#endif
      req.set_header("Accept-Encoding", encodings);
    }

#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT
    if (!req.has_header("User-Agent")) {
      req.set_header("User-Agent",
                     std::string("cpp-httplib/") + CPPHTTPLIB_VERSION);
    }
#endif
  }

  if (!req.body.empty()) {
    if (!req.has_header("Content-Type")) {
      req.set_header("Content-Type", "text/plain");
    }
    if (!req.has_header("Content-Length")) {
      req.set_header("Content-Length", std::to_string(req.body.size()));
    }
  } else if (req.content_provider_) {
    // Chunked providers carry no Content-Length.
    if (!req.is_chunked_content_provider_ &&
        !req.has_header("Content-Length")) {
      req.set_header("Content-Length", std::to_string(req.content_length_));
    }
  } else if (req.method == "POST" || req.method == "PUT" ||
             req.method == "PATCH") {
    // No body at all on a method that normally has one.
    req.set_header("Content-Length", "0");
  }

  // Credentials. The order matters: the bearer token is only used when basic
  // auth has not already produced an Authorization header.
  if ((!basic_auth_password_.empty() || !basic_auth_username_.empty()) &&
      !req.has_header("Authorization")) {
    req.headers.insert(make_basic_authentication_header(
        basic_auth_username_, basic_auth_password_, false));
  }

  if (!proxy_basic_auth_username_.empty() &&
      !proxy_basic_auth_password_.empty() &&
      !req.has_header("Proxy-Authorization")) {
    req.headers.insert(make_basic_authentication_header(
        proxy_basic_auth_username_, proxy_basic_auth_password_, true));
  }

  if (!bearer_token_auth_token_.empty() && !req.has_header("Authorization")) {
    req.headers.insert(make_bearer_token_authentication_header(
        bearer_token_auth_token_, false));
  }

  if (!proxy_bearer_token_auth_token_.empty() &&
      !req.has_header("Proxy-Authorization")) {
    req.headers.insert(make_bearer_token_authentication_header(
        proxy_bearer_token_auth_token_, true));
  }

  // Request line and headers
  {
    detail::BufferStream head;

    const auto target_with_query =
        req.params.empty() ? req.path
                           : append_query_params(req.path, req.params);

    const auto request_target = path_encode_
                                    ? detail::encode_path(target_with_query)
                                    : target_with_query;

    detail::write_request_line(head, req.method, request_target);

    header_writer_(head, req.headers);

    // Flush buffer
    auto &head_bytes = head.get_buffer();
    if (!detail::write_data(strm, head_bytes.data(), head_bytes.size())) {
      error = Error::Write;
      return false;
    }
  }

  // Body
  if (req.body.empty()) {
    return write_content_with_provider(strm, req, error);
  }

  if (!req.upload_progress) {
    // No progress reporting requested: send the body in one call.
    if (!detail::write_data(strm, req.body.data(), req.body.size())) {
      error = Error::Write;
      return false;
    }
    return true;
  }

  // Send in slices of at most CPPHTTPLIB_SEND_BUFSIZ bytes, reporting after
  // each slice; the callback can cancel the upload by returning false.
  const auto total = req.body.size();
  const auto bytes = req.body.data();

  for (size_t sent = 0; sent < total;) {
    size_t slice = (std::min)(CPPHTTPLIB_SEND_BUFSIZ, total - sent);
    if (!detail::write_data(strm, bytes + sent, slice)) {
      error = Error::Write;
      return false;
    }
    sent += slice;

    if (!req.upload_progress(sent, total)) {
      error = Error::Canceled;
      return false;
    }
  }

  return true;
}
|
|
|
|
// Prepares the body of `req` from one of three sources and sends it.
//
// Body source, in order of precedence:
//  - `content_provider` (length known: `content_length`),
//  - `content_provider_without_length` (sent with chunked transfer encoding),
//  - the raw buffer `body` of `content_length` bytes.
//
// With zlib support and compress_ set, "Content-Encoding: gzip" is added; for
// the first and third source the payload is gzip-compressed up front into
// `req.body`, while a provider without length is left to be compressed while
// streaming.
//
// Returns the response, or nullptr on failure with `error` set:
//  - Error::Canceled when the content provider returns false,
//  - Error::Compression when compressing the payload fails,
//  - otherwise whatever send() reported.
//
// Fix: a compression failure inside the provider's data sink used to be
// detected only if the provider passed the sink's `false` result on as its own
// return value. A provider that ignores the sink result and returns true ended
// the loop with a truncated gzip stream in `req.body`, which was then sent.
// The failure is now reported as Error::Compression.
inline std::unique_ptr<Response> ClientImpl::send_with_content_provider(
    Request &req, const char *body, size_t content_length,
    ContentProvider content_provider,
    ContentProviderWithoutLength content_provider_without_length,
    const std::string &content_type, Error &error) {
  if (!content_type.empty()) { req.set_header("Content-Type", content_type); }

#ifdef CPPHTTPLIB_ZLIB_SUPPORT
  if (compress_) { req.set_header("Content-Encoding", "gzip"); }
#endif

#ifdef CPPHTTPLIB_ZLIB_SUPPORT
  if (compress_ && !content_provider_without_length) {
    // TODO: Brotli support
    detail::gzip_compressor compressor;

    if (content_provider) {
      auto ok = true;    // cleared once compression has failed
      size_t offset = 0; // uncompressed bytes consumed so far
      DataSink data_sink;

      data_sink.write = [&](const char *data, size_t data_len) -> bool {
        if (ok) {
          // Tell the compressor when this is the final piece.
          auto last = offset + data_len == content_length;

          auto ret = compressor.compress(
              data, data_len, last,
              [&](const char *compressed_data, size_t compressed_data_len) {
                req.body.append(compressed_data, compressed_data_len);
                return true;
              });

          if (ret) {
            offset += data_len;
          } else {
            ok = false;
          }
        }
        return ok;
      };

      while (ok && offset < content_length) {
        if (!content_provider(offset, content_length - offset, data_sink)) {
          error = Error::Canceled;
          return nullptr;
        }
      }

      // The loop also ends when compression failed but the provider kept
      // returning true; never send the partial body in that case.
      if (!ok) {
        error = Error::Compression;
        return nullptr;
      }
    } else {
      if (!compressor.compress(body, content_length, true,
                               [&](const char *data, size_t data_len) {
                                 req.body.append(data, data_len);
                                 return true;
                               })) {
        error = Error::Compression;
        return nullptr;
      }
    }
  } else
#endif
  {
    if (content_provider) {
      req.content_length_ = content_length;
      req.content_provider_ = std::move(content_provider);
      req.is_chunked_content_provider_ = false;
    } else if (content_provider_without_length) {
      req.content_length_ = 0;
      req.content_provider_ = detail::ContentProviderAdapter(
          std::move(content_provider_without_length));
      req.is_chunked_content_provider_ = true;
      req.set_header("Transfer-Encoding", "chunked");
    } else {
      req.body.assign(body, content_length);
    }
  }

  auto res = detail::make_unique<Response>();
  return send(req, *res, error) ? std::move(res) : nullptr;
}
|
|
|
|
// Builds a Request for `method` / `path` / `headers`, hands it together with
// the body source to the Request-based send_with_content_provider overload,
// and packages the outcome as a Result (response, error code and the request
// headers; with OpenSSL support also the last SSL error values).
inline Result ClientImpl::send_with_content_provider(
    const std::string &method, const std::string &path, const Headers &headers,
    const char *body, size_t content_length, ContentProvider content_provider,
    ContentProviderWithoutLength content_provider_without_length,
    const std::string &content_type, UploadProgress progress) {
  Request request;
  request.method = method;
  request.headers = headers;
  request.path = path;
  request.upload_progress = std::move(progress);

  // The start time is recorded only when max_timeout_msec_ is positive.
  const bool track_start = max_timeout_msec_ > 0;
  if (track_start) { request.start_time_ = std::chrono::steady_clock::now(); }

  auto status = Error::Success;
  auto response = send_with_content_provider(
      request, body, content_length, std::move(content_provider),
      std::move(content_provider_without_length), content_type, status);

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  return Result{std::move(response), status, std::move(request.headers),
                last_ssl_error_, last_openssl_error_};
#else
  return Result{std::move(response), status, std::move(request.headers)};
#endif
}
|
|
|
|
inline std::string
|
|
ClientImpl::adjust_host_string(const std::string &host) const {
|
|
if (host.find(':') != std::string::npos) { return "[" + host + "]"; }
|
|
return host;
|
|
}
|
|
|
|
// Sends `req` over `strm` and reads the reply into `res`.
//
// Steps: write the request; in OpenSSL builds, for an SSL client without a
// proxy, check whether the peer may already have closed; read the status line
// and headers; read the body unless the status is 204 or the method is HEAD
// or CONNECT; finally call the logger if one is set.
//
// Returns true on success. On failure returns false; this function sets
// `error` to Error::SSLPeerCouldBeClosed_, Error::Read or Error::Canceled
// (write_request is handed `error` for its own failures).
inline bool ClientImpl::process_request(Stream &strm, Request &req,
                                        Response &res, bool close_connection,
                                        Error &error) {
  // Send request
  if (!write_request(strm, req, close_connection, error)) { return false; }

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  if (is_ssl()) {
    const auto via_proxy = !proxy_host_.empty() && proxy_port_ != -1;
    if (!via_proxy &&
        detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) {
      error = Error::SSLPeerCouldBeClosed_;
      return false;
    }
  }
#endif

  // Receive response and headers
  const bool head_ok = read_response_line(strm, req, res) &&
                       detail::read_headers(strm, res.headers);
  if (!head_ok) {
    error = Error::Read;
    return false;
  }

  // Body
  const bool wants_body = (res.status != StatusCode::NoContent_204) &&
                          req.method != "HEAD" && req.method != "CONNECT";
  if (wants_body) {
    // While a redirect is going to be followed, the user's callbacks are
    // bypassed: the handler is not called and received data is discarded.
    const auto is_redirect = 300 < res.status && res.status < 400 &&
                             res.status != StatusCode::NotModified_304 &&
                             follow_location_;

    if (req.response_handler && !is_redirect && !req.response_handler(res)) {
      error = Error::Canceled;
      return false;
    }

    ContentReceiverWithProgress receiver;
    if (req.content_receiver) {
      // Stream the body to the caller-supplied receiver.
      receiver = [&](const char *buf, size_t n, size_t off, size_t len) {
        if (is_redirect) { return true; }
        auto ret = req.content_receiver(buf, n, off, len);
        if (!ret) { error = Error::Canceled; }
        return ret;
      };
    } else {
      // No receiver: accumulate the body in the response object.
      receiver = [&](const char *buf, size_t n, size_t /*off*/,
                     size_t /*len*/) {
        assert(res.body.size() + n <= res.body.max_size());
        res.body.append(buf, n);
        return true;
      };
    }

    auto on_progress = [&](size_t current, size_t total) {
      if (!req.download_progress || is_redirect) { return true; }
      auto ret = req.download_progress(current, total);
      if (!ret) { error = Error::Canceled; }
      return ret;
    };

    // When buffering into res.body, reserve space from Content-Length and
    // reject a length the string could never hold.
    if (res.has_header("Content-Length") && !req.content_receiver) {
      auto len = res.get_header_value_u64("Content-Length");
      if (len > res.body.max_size()) {
        error = Error::Read;
        return false;
      }
      res.body.reserve(static_cast<size_t>(len));
    }

    if (res.status != StatusCode::NotModified_304) {
      int dummy_status;
      const bool body_ok = detail::read_content(
          strm, res, (std::numeric_limits<size_t>::max)(), dummy_status,
          std::move(on_progress), std::move(receiver), decompress_);
      if (!body_ok) {
        // Keep a cancellation recorded by one of the callbacks above.
        if (error != Error::Canceled) { error = Error::Read; }
        return false;
      }
    }
  }

  // Log
  if (logger_) { logger_(req, res); }

  return true;
}
|
|
|
|
// Returns a length-less content provider that emits a multipart body:
// first the serialized `items` (when the offset is 0 and `items` is not
// empty), then each entry of `provider_items` framed by its begin/end
// markers, and finally the closing boundary followed by sink.done().
//
// The returned callable refers to `boundary`, `items` and `provider_items`
// by reference, so they must stay alive for as long as it is used.
inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider(
    const std::string &boundary, const UploadFormDataItems &items,
    const FormDataProviderItems &provider_items) const {
  size_t item_index = 0;
  size_t item_start = 0;
  // item_index and item_start are copied into the std::function and keep
  // their values between successive calls
  return [&, item_index, item_start](size_t offset,
                                     DataSink &sink) mutable -> bool {
    // Very first call: write all the plain form fields in one go.
    if (!offset && !items.empty()) {
      sink.os << detail::serialize_multipart_formdata(items, boundary, false);
      return true;
    }

    // Nothing left to stream: close the multipart body.
    if (!(item_index < provider_items.size())) {
      sink.os << detail::serialize_multipart_formdata_finish(boundary);
      sink.done();
      return true;
    }

    // item_start == 0 marks an item whose begin marker is still unwritten.
    if (!item_start) {
      const auto &header = detail::serialize_multipart_formdata_item_begin(
          provider_items[item_index], boundary);
      offset += header.size();
      item_start = offset;
      sink.os << header;
    }

    // The item's provider gets its own sink so that its done() call ends
    // only this item, not the whole body.
    DataSink item_sink;
    auto more_data = true;
    item_sink.write = sink.write;
    item_sink.done = [&]() { more_data = false; };

    if (!provider_items[item_index].provider(offset - item_start, item_sink)) {
      return false;
    }

    if (!more_data) {
      sink.os << detail::serialize_multipart_formdata_item_end();
      item_index++;
      item_start = 0;
    }
    return true;
  };
}
|
|
|
|
inline bool ClientImpl::process_socket(
|
|
const Socket &socket,
|
|
std::chrono::time_point<std::chrono::steady_clock> start_time,
|
|
std::function<bool(Stream &strm)> callback) {
|
|
return detail::process_client_socket(
|
|
socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
|
|
write_timeout_usec_, max_timeout_msec_, start_time, std::move(callback));
|
|
}
|
|
|
|
// This implementation always reports false.
inline bool ClientImpl::is_ssl() const {
  return false;
}
|
|
|
|
// GET `path` with an empty header set; forwards to the Headers overload.
inline Result ClientImpl::Get(const std::string &path,
                              DownloadProgress progress) {
  const Headers no_headers = Headers();
  return Get(path, no_headers, std::move(progress));
}
|
|
|
|
// GET `path` with `params` appended as a query string (via
// append_query_params) and the given request headers.
//
// When `params` is empty the path is used as is. `progress` is forwarded in
// both cases (the empty-params branch used to drop it).
inline Result ClientImpl::Get(const std::string &path, const Params &params,
                              const Headers &headers,
                              DownloadProgress progress) {
  if (params.empty()) { return Get(path, headers, std::move(progress)); }

  std::string path_with_query = append_query_params(path, params);
  return Get(path_with_query, headers, std::move(progress));
}
|
|
|
|
// Builds a GET request for `path` with `headers` and the download progress
// callback, then sends it with send_.
inline Result ClientImpl::Get(const std::string &path, const Headers &headers,
                              DownloadProgress progress) {
  Request request;
  request.method = "GET";
  request.headers = headers;
  request.path = path;
  request.download_progress = std::move(progress);

  // The start time is recorded only when max_timeout_msec_ is positive.
  const bool track_start = max_timeout_msec_ > 0;
  if (track_start) { request.start_time_ = std::chrono::steady_clock::now(); }

  return send_(std::move(request));
}
|
|
|
|
// GET `path` streaming the body to `content_receiver`; forwards with an
// empty header set and no response handler.
inline Result ClientImpl::Get(const std::string &path,
                              ContentReceiver content_receiver,
                              DownloadProgress progress) {
  const Headers no_headers = Headers();
  return Get(path, no_headers, nullptr, std::move(content_receiver),
             std::move(progress));
}
|
|
|
|
// GET `path` with `headers`, streaming the body to `content_receiver`;
// forwards to the full overload with no response handler.
inline Result ClientImpl::Get(const std::string &path, const Headers &headers,
                              ContentReceiver content_receiver,
                              DownloadProgress progress) {
  auto result = Get(path, headers, nullptr, std::move(content_receiver),
                    std::move(progress));
  return result;
}
|
|
|
|
// GET `path` with a response handler and a content receiver; forwards to the
// full overload with an empty header set.
inline Result ClientImpl::Get(const std::string &path,
                              ResponseHandler response_handler,
                              ContentReceiver content_receiver,
                              DownloadProgress progress) {
  const Headers no_headers = Headers();
  return Get(path, no_headers, std::move(response_handler),
             std::move(content_receiver), std::move(progress));
}
|
|
|
|
// Builds a GET request for `path` with `headers`, a response handler, a
// content receiver and a download progress callback, then sends it.
//
// `content_receiver` takes (data, length); it is wrapped in a callable that
// accepts the extra offset / total-length arguments and ignores them.
inline Result ClientImpl::Get(const std::string &path, const Headers &headers,
                              ResponseHandler response_handler,
                              ContentReceiver content_receiver,
                              DownloadProgress progress) {
  Request request;
  request.method = "GET";
  request.headers = headers;
  request.path = path;
  request.response_handler = std::move(response_handler);
  request.download_progress = std::move(progress);
  request.content_receiver =
      [content_receiver](const char *data, size_t data_length,
                         size_t /*offset*/, size_t /*total_length*/) {
        return content_receiver(data, data_length);
      };

  // The start time is recorded only when max_timeout_msec_ is positive.
  const bool track_start = max_timeout_msec_ > 0;
  if (track_start) { request.start_time_ = std::chrono::steady_clock::now(); }

  return send_(std::move(request));
}
|
|
|
|
// GET `path` with query `params` and `headers`, streaming the body to
// `content_receiver`; forwards to the full overload with no response
// handler.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Get(const std::string &path, const Params &params,
                              const Headers &headers,
                              ContentReceiver content_receiver,
                              DownloadProgress progress) {
  return Get(path, params, headers, nullptr, std::move(content_receiver),
             std::move(progress));
}
|
|
|
|
// GET `path` with `params` appended as a query string (via
// append_query_params), plus headers, response handler, content receiver
// and progress callback. An empty `params` leaves the path untouched.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Get(const std::string &path, const Params &params,
                              const Headers &headers,
                              ResponseHandler response_handler,
                              ContentReceiver content_receiver,
                              DownloadProgress progress) {
  if (params.empty()) {
    return Get(path, headers, std::move(response_handler),
               std::move(content_receiver), std::move(progress));
  }

  std::string path_with_query = append_query_params(path, params);
  return Get(path_with_query, headers, std::move(response_handler),
             std::move(content_receiver), std::move(progress));
}
|
|
|
|
// HEAD `path` with an empty header set; forwards to the Headers overload.
inline Result ClientImpl::Head(const std::string &path) {
  const Headers no_headers = Headers();
  return Head(path, no_headers);
}
|
|
|
|
// Builds a HEAD request for `path` with `headers` and sends it with send_.
inline Result ClientImpl::Head(const std::string &path,
                               const Headers &headers) {
  Request request;
  request.method = "HEAD";
  request.path = path;
  request.headers = headers;

  // The start time is recorded only when max_timeout_msec_ is positive.
  const bool track_start = max_timeout_msec_ > 0;
  if (track_start) { request.start_time_ = std::chrono::steady_clock::now(); }

  return send_(std::move(request));
}
|
|
|
|
// POST to `path` with an empty body and an empty content type.
inline Result ClientImpl::Post(const std::string &path) {
  const std::string no_body;
  const std::string no_content_type;
  return Post(path, no_body, no_content_type);
}
|
|
|
|
// POST to `path` with `headers`, a null body of length 0 and an empty
// content type.
inline Result ClientImpl::Post(const std::string &path,
                               const Headers &headers) {
  const std::string no_content_type;
  return Post(path, headers, nullptr, 0, no_content_type);
}
|
|
|
|
// POST the `content_length` bytes at `body` to `path` with an empty header
// set; forwards to the Headers overload.
inline Result ClientImpl::Post(const std::string &path, const char *body,
                               size_t content_length,
                               const std::string &content_type,
                               UploadProgress progress) {
  const Headers no_headers = Headers();
  return Post(path, no_headers, body, content_length, content_type, progress);
}
|
|
|
|
// POST the string `body` to `path` with an empty header set; forwards to the
// Headers overload.
inline Result ClientImpl::Post(const std::string &path, const std::string &body,
                               const std::string &content_type,
                               UploadProgress progress) {
  const Headers no_headers = Headers();
  return Post(path, no_headers, body, content_type, progress);
}
|
|
|
|
// POST `params` to `path` with an empty header set; forwards to the
// (path, headers, params) overload.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Post(const std::string &path, const Params &params) {
  return Post(path, Headers(), params);
}
|
|
|
|
// POST to `path` with a body of `content_length` bytes produced by
// `content_provider`, using an empty header set.
inline Result ClientImpl::Post(const std::string &path, size_t content_length,
                               ContentProvider content_provider,
                               const std::string &content_type,
                               UploadProgress progress) {
  const Headers no_headers = Headers();
  return Post(path, no_headers, content_length, std::move(content_provider),
              content_type, progress);
}
|
|
|
|
// POST to `path` with a body produced by a length-less content provider,
// using an empty header set.
inline Result ClientImpl::Post(const std::string &path,
                               ContentProviderWithoutLength content_provider,
                               const std::string &content_type,
                               UploadProgress progress) {
  const Headers no_headers = Headers();
  return Post(path, no_headers, std::move(content_provider), content_type,
              progress);
}
|
|
|
|
// POST `params` to `path` as an "application/x-www-form-urlencoded" body
// built by detail::params_to_query_str.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Post(const std::string &path, const Headers &headers,
                               const Params &params) {
  auto query = detail::params_to_query_str(params);
  return Post(path, headers, query, "application/x-www-form-urlencoded");
}
|
|
|
|
// POST multipart form `items` to `path` with an empty header set.
inline Result ClientImpl::Post(const std::string &path,
                               const UploadFormDataItems &items,
                               UploadProgress progress) {
  const Headers no_headers = Headers();
  return Post(path, no_headers, items, progress);
}
|
|
|
|
// POST multipart form `items` to `path`: generates a boundary, serializes
// the items into a body and posts it with the matching content type.
inline Result ClientImpl::Post(const std::string &path, const Headers &headers,
                               const UploadFormDataItems &items,
                               UploadProgress progress) {
  const auto boundary = detail::make_multipart_data_boundary();
  const auto multipart_type =
      detail::serialize_multipart_formdata_get_content_type(boundary);
  const auto payload = detail::serialize_multipart_formdata(items, boundary);
  return Post(path, headers, payload, multipart_type, progress);
}
|
|
|
|
// POST multipart form `items` to `path` using the caller's `boundary`.
// A boundary rejected by detail::is_multipart_boundary_chars_valid yields a
// Result with no response and Error::UnsupportedMultipartBoundaryChars.
inline Result ClientImpl::Post(const std::string &path, const Headers &headers,
                               const UploadFormDataItems &items,
                               const std::string &boundary,
                               UploadProgress progress) {
  const bool boundary_ok = detail::is_multipart_boundary_chars_valid(boundary);
  if (!boundary_ok) {
    return Result{nullptr, Error::UnsupportedMultipartBoundaryChars};
  }

  const auto multipart_type =
      detail::serialize_multipart_formdata_get_content_type(boundary);
  const auto payload = detail::serialize_multipart_formdata(items, boundary);
  return Post(path, headers, payload, multipart_type, progress);
}
|
|
|
|
// POST the `content_length` bytes at `body` to `path`; delegates to
// send_with_content_provider with no content providers.
inline Result ClientImpl::Post(const std::string &path, const Headers &headers,
                               const char *body, size_t content_length,
                               const std::string &content_type,
                               UploadProgress progress) {
  auto result = send_with_content_provider("POST", path, headers, body,
                                           content_length, nullptr, nullptr,
                                           content_type, progress);
  return result;
}
|
|
|
|
// POST the string `body` to `path`; delegates to send_with_content_provider
// with the string's data pointer and size and no content providers.
inline Result ClientImpl::Post(const std::string &path, const Headers &headers,
                               const std::string &body,
                               const std::string &content_type,
                               UploadProgress progress) {
  const char *bytes = body.data();
  const size_t byte_count = body.size();
  return send_with_content_provider("POST", path, headers, bytes, byte_count,
                                    nullptr, nullptr, content_type, progress);
}
|
|
|
|
// POST to `path` with a body of `content_length` bytes produced by
// `content_provider`; delegates to send_with_content_provider.
inline Result ClientImpl::Post(const std::string &path, const Headers &headers,
                               size_t content_length,
                               ContentProvider content_provider,
                               const std::string &content_type,
                               UploadProgress progress) {
  auto result = send_with_content_provider(
      "POST", path, headers, nullptr, content_length,
      std::move(content_provider), nullptr, content_type, progress);
  return result;
}
|
|
|
|
// POST to `path` with a body produced by a length-less content provider;
// delegates to send_with_content_provider.
inline Result ClientImpl::Post(const std::string &path, const Headers &headers,
                               ContentProviderWithoutLength content_provider,
                               const std::string &content_type,
                               UploadProgress progress) {
  auto result = send_with_content_provider(
      "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider),
      content_type, progress);
  return result;
}
|
|
|
|
// POST a multipart body to `path` made of plain `items` plus streamed
// `provider_items`: generates a boundary and sends through the length-less
// provider built by get_multipart_content_provider.
inline Result ClientImpl::Post(const std::string &path, const Headers &headers,
                               const UploadFormDataItems &items,
                               const FormDataProviderItems &provider_items,
                               UploadProgress progress) {
  const auto boundary = detail::make_multipart_data_boundary();
  const auto multipart_type =
      detail::serialize_multipart_formdata_get_content_type(boundary);
  return send_with_content_provider(
      "POST", path, headers, nullptr, 0, nullptr,
      get_multipart_content_provider(boundary, items, provider_items),
      multipart_type, progress);
}
|
|
|
|
// Builds a POST request carrying `body` whose response body is streamed to
// `content_receiver`, then sends it with send_.
//
// `content_receiver` takes (data, length); it is wrapped in a callable that
// accepts the extra offset / total-length arguments and ignores them. A
// non-empty `content_type` is stored as the Content-Type header.
inline Result ClientImpl::Post(const std::string &path, const Headers &headers,
                               const std::string &body,
                               const std::string &content_type,
                               ContentReceiver content_receiver,
                               DownloadProgress progress) {
  Request request;
  request.method = "POST";
  request.headers = headers;
  request.path = path;
  request.body = body;
  request.download_progress = std::move(progress);
  request.content_receiver =
      [content_receiver](const char *data, size_t data_length,
                         size_t /*offset*/, size_t /*total_length*/) {
        return content_receiver(data, data_length);
      };

  // The start time is recorded only when max_timeout_msec_ is positive.
  const bool track_start = max_timeout_msec_ > 0;
  if (track_start) { request.start_time_ = std::chrono::steady_clock::now(); }

  const bool has_type = !content_type.empty();
  if (has_type) { request.set_header("Content-Type", content_type); }

  return send_(std::move(request));
}
|
|
|
|
// PUT to `path` with an empty body and an empty content type.
inline Result ClientImpl::Put(const std::string &path) {
  const std::string no_body;
  const std::string no_content_type;
  return Put(path, no_body, no_content_type);
}
|
|
|
|
// PUT to `path` with `headers`, a null body of length 0 and an empty content
// type.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers) {
  const std::string no_content_type;
  return Put(path, headers, nullptr, 0, no_content_type);
}
|
|
|
|
// PUT the `content_length` bytes at `body` to `path` with an empty header
// set; forwards to the Headers overload.
inline Result ClientImpl::Put(const std::string &path, const char *body,
                              size_t content_length,
                              const std::string &content_type,
                              UploadProgress progress) {
  const Headers no_headers = Headers();
  return Put(path, no_headers, body, content_length, content_type, progress);
}
|
|
|
|
// PUT the string `body` to `path` with an empty header set; forwards to the
// Headers overload.
inline Result ClientImpl::Put(const std::string &path, const std::string &body,
                              const std::string &content_type,
                              UploadProgress progress) {
  const Headers no_headers = Headers();
  return Put(path, no_headers, body, content_type, progress);
}
|
|
|
|
// PUT `params` to `path` with an empty header set; forwards to the
// (path, headers, params) overload.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Put(const std::string &path, const Params &params) {
  return Put(path, Headers(), params);
}
|
|
|
|
// PUT to `path` with a body of `content_length` bytes produced by
// `content_provider`, using an empty header set.
inline Result ClientImpl::Put(const std::string &path, size_t content_length,
                              ContentProvider content_provider,
                              const std::string &content_type,
                              UploadProgress progress) {
  const Headers no_headers = Headers();
  return Put(path, no_headers, content_length, std::move(content_provider),
             content_type, progress);
}
|
|
|
|
// PUT to `path` with a body produced by a length-less content provider,
// using an empty header set.
inline Result ClientImpl::Put(const std::string &path,
                              ContentProviderWithoutLength content_provider,
                              const std::string &content_type,
                              UploadProgress progress) {
  const Headers no_headers = Headers();
  return Put(path, no_headers, std::move(content_provider), content_type,
             progress);
}
|
|
|
|
// PUT `params` to `path` as an "application/x-www-form-urlencoded" body
// built by detail::params_to_query_str.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers,
                              const Params &params) {
  auto query = detail::params_to_query_str(params);
  return Put(path, headers, query, "application/x-www-form-urlencoded");
}
|
|
|
|
// PUT multipart form `items` to `path` with an empty header set.
inline Result ClientImpl::Put(const std::string &path,
                              const UploadFormDataItems &items,
                              UploadProgress progress) {
  const Headers no_headers = Headers();
  return Put(path, no_headers, items, progress);
}
|
|
|
|
// PUT multipart form `items` to `path`: generates a boundary, serializes the
// items into a body and sends it with the matching content type.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers,
                              const UploadFormDataItems &items,
                              UploadProgress progress) {
  const auto boundary = detail::make_multipart_data_boundary();
  const auto multipart_type =
      detail::serialize_multipart_formdata_get_content_type(boundary);
  const auto payload = detail::serialize_multipart_formdata(items, boundary);
  return Put(path, headers, payload, multipart_type, progress);
}
|
|
|
|
// PUT multipart form `items` to `path` using the caller's `boundary`.
// A boundary rejected by detail::is_multipart_boundary_chars_valid yields a
// Result with no response and Error::UnsupportedMultipartBoundaryChars.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers,
                              const UploadFormDataItems &items,
                              const std::string &boundary,
                              UploadProgress progress) {
  const bool boundary_ok = detail::is_multipart_boundary_chars_valid(boundary);
  if (!boundary_ok) {
    return Result{nullptr, Error::UnsupportedMultipartBoundaryChars};
  }

  const auto multipart_type =
      detail::serialize_multipart_formdata_get_content_type(boundary);
  const auto payload = detail::serialize_multipart_formdata(items, boundary);
  return Put(path, headers, payload, multipart_type, progress);
}
|
|
|
|
// PUT the `content_length` bytes at `body` to `path`; delegates to
// send_with_content_provider with no content providers.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers,
                              const char *body, size_t content_length,
                              const std::string &content_type,
                              UploadProgress progress) {
  auto result = send_with_content_provider("PUT", path, headers, body,
                                           content_length, nullptr, nullptr,
                                           content_type, progress);
  return result;
}
|
|
|
|
// PUT the string `body` to `path`; delegates to send_with_content_provider
// with the string's data pointer and size and no content providers.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers,
                              const std::string &body,
                              const std::string &content_type,
                              UploadProgress progress) {
  const char *bytes = body.data();
  const size_t byte_count = body.size();
  return send_with_content_provider("PUT", path, headers, bytes, byte_count,
                                    nullptr, nullptr, content_type, progress);
}
|
|
|
|
// PUT to `path` with a body of `content_length` bytes produced by
// `content_provider`; delegates to send_with_content_provider.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers,
                              size_t content_length,
                              ContentProvider content_provider,
                              const std::string &content_type,
                              UploadProgress progress) {
  auto result = send_with_content_provider(
      "PUT", path, headers, nullptr, content_length,
      std::move(content_provider), nullptr, content_type, progress);
  return result;
}
|
|
|
|
// PUT to `path` with a body produced by a length-less content provider;
// delegates to send_with_content_provider.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers,
                              ContentProviderWithoutLength content_provider,
                              const std::string &content_type,
                              UploadProgress progress) {
  auto result = send_with_content_provider(
      "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider),
      content_type, progress);
  return result;
}
|
|
|
|
// PUT a multipart body to `path` made of plain `items` plus streamed
// `provider_items`: generates a boundary and sends through the length-less
// provider built by get_multipart_content_provider.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers,
                              const UploadFormDataItems &items,
                              const FormDataProviderItems &provider_items,
                              UploadProgress progress) {
  const auto boundary = detail::make_multipart_data_boundary();
  const auto multipart_type =
      detail::serialize_multipart_formdata_get_content_type(boundary);
  return send_with_content_provider(
      "PUT", path, headers, nullptr, 0, nullptr,
      get_multipart_content_provider(boundary, items, provider_items),
      multipart_type, progress);
}
|
|
|
|
// Builds a PUT request carrying `body` whose response body is streamed to
// `content_receiver`, then sends it with send_.
//
// `content_receiver` takes (data, length); it is wrapped in a callable that
// accepts the extra offset / total-length arguments and ignores them. A
// non-empty `content_type` is stored as the Content-Type header.
inline Result ClientImpl::Put(const std::string &path, const Headers &headers,
                              const std::string &body,
                              const std::string &content_type,
                              ContentReceiver content_receiver,
                              DownloadProgress progress) {
  Request request;
  request.method = "PUT";
  request.headers = headers;
  request.path = path;
  request.body = body;
  request.download_progress = std::move(progress);
  request.content_receiver =
      [content_receiver](const char *data, size_t data_length,
                         size_t /*offset*/, size_t /*total_length*/) {
        return content_receiver(data, data_length);
      };

  // The start time is recorded only when max_timeout_msec_ is positive.
  const bool track_start = max_timeout_msec_ > 0;
  if (track_start) { request.start_time_ = std::chrono::steady_clock::now(); }

  const bool has_type = !content_type.empty();
  if (has_type) { request.set_header("Content-Type", content_type); }

  return send_(std::move(request));
}
|
|
|
|
// PATCH `path` with an empty body and an empty content type.
inline Result ClientImpl::Patch(const std::string &path) {
  const std::string no_body;
  const std::string no_content_type;
  return Patch(path, no_body, no_content_type);
}
|
|
|
|
// PATCH `path` with `headers`, a null body of length 0 and an empty content
// type.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                UploadProgress progress) {
  const std::string no_content_type;
  return Patch(path, headers, nullptr, 0, no_content_type, progress);
}
|
|
|
|
// PATCH `path` with the `content_length` bytes at `body` and an empty header
// set; forwards to the Headers overload.
inline Result ClientImpl::Patch(const std::string &path, const char *body,
                                size_t content_length,
                                const std::string &content_type,
                                UploadProgress progress) {
  const Headers no_headers = Headers();
  return Patch(path, no_headers, body, content_length, content_type, progress);
}
|
|
|
|
// PATCH `path` with the string `body` and an empty header set; forwards to
// the Headers overload.
inline Result ClientImpl::Patch(const std::string &path,
                                const std::string &body,
                                const std::string &content_type,
                                UploadProgress progress) {
  const Headers no_headers = Headers();
  return Patch(path, no_headers, body, content_type, progress);
}
|
|
|
|
// PATCH `path` with `params` and an empty header set; forwards to the
// (path, headers, params) overload.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Patch(const std::string &path, const Params &params) {
  return Patch(path, Headers(), params);
}
|
|
|
|
// PATCH `path` with a body of `content_length` bytes produced by
// `content_provider`, using an empty header set.
inline Result ClientImpl::Patch(const std::string &path, size_t content_length,
                                ContentProvider content_provider,
                                const std::string &content_type,
                                UploadProgress progress) {
  const Headers no_headers = Headers();
  return Patch(path, no_headers, content_length, std::move(content_provider),
               content_type, progress);
}
|
|
|
|
// PATCH `path` with a body produced by a length-less content provider,
// using an empty header set.
inline Result ClientImpl::Patch(const std::string &path,
                                ContentProviderWithoutLength content_provider,
                                const std::string &content_type,
                                UploadProgress progress) {
  const Headers no_headers = Headers();
  return Patch(path, no_headers, std::move(content_provider), content_type,
               progress);
}
|
|
|
|
// PATCH `path` with `params` as an "application/x-www-form-urlencoded" body
// built by detail::params_to_query_str.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                const Params &params) {
  auto query = detail::params_to_query_str(params);
  return Patch(path, headers, query, "application/x-www-form-urlencoded");
}
|
|
|
|
// PATCH `path` with multipart form `items` and an empty header set.
inline Result ClientImpl::Patch(const std::string &path,
                                const UploadFormDataItems &items,
                                UploadProgress progress) {
  const Headers no_headers = Headers();
  return Patch(path, no_headers, items, progress);
}
|
|
|
|
// PATCH `path` with multipart form `items`: generates a boundary, serializes
// the items into a body and sends it with the matching content type.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                const UploadFormDataItems &items,
                                UploadProgress progress) {
  const auto boundary = detail::make_multipart_data_boundary();
  const auto multipart_type =
      detail::serialize_multipart_formdata_get_content_type(boundary);
  const auto payload = detail::serialize_multipart_formdata(items, boundary);
  return Patch(path, headers, payload, multipart_type, progress);
}
|
|
|
|
// PATCH `path` with multipart form `items` using the caller's `boundary`.
// A boundary rejected by detail::is_multipart_boundary_chars_valid yields a
// Result with no response and Error::UnsupportedMultipartBoundaryChars.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                const UploadFormDataItems &items,
                                const std::string &boundary,
                                UploadProgress progress) {
  const bool boundary_ok = detail::is_multipart_boundary_chars_valid(boundary);
  if (!boundary_ok) {
    return Result{nullptr, Error::UnsupportedMultipartBoundaryChars};
  }

  const auto multipart_type =
      detail::serialize_multipart_formdata_get_content_type(boundary);
  const auto payload = detail::serialize_multipart_formdata(items, boundary);
  return Patch(path, headers, payload, multipart_type, progress);
}
|
|
|
|
// PATCH `path` with the `content_length` bytes at `body`; delegates to
// send_with_content_provider with no content providers.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                const char *body, size_t content_length,
                                const std::string &content_type,
                                UploadProgress progress) {
  auto result = send_with_content_provider("PATCH", path, headers, body,
                                           content_length, nullptr, nullptr,
                                           content_type, progress);
  return result;
}
|
|
|
|
// PATCH `path` with the string `body`; delegates to
// send_with_content_provider with the string's data pointer and size and no
// content providers.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                const std::string &body,
                                const std::string &content_type,
                                UploadProgress progress) {
  const char *bytes = body.data();
  const size_t byte_count = body.size();
  return send_with_content_provider("PATCH", path, headers, bytes, byte_count,
                                    nullptr, nullptr, content_type, progress);
}
|
|
|
|
// PATCH `path` with a body of `content_length` bytes produced by
// `content_provider`; delegates to send_with_content_provider.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                size_t content_length,
                                ContentProvider content_provider,
                                const std::string &content_type,
                                UploadProgress progress) {
  auto result = send_with_content_provider(
      "PATCH", path, headers, nullptr, content_length,
      std::move(content_provider), nullptr, content_type, progress);
  return result;
}
|
|
|
|
// PATCH `path` with a body produced by a length-less content provider;
// delegates to send_with_content_provider.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                ContentProviderWithoutLength content_provider,
                                const std::string &content_type,
                                UploadProgress progress) {
  auto result = send_with_content_provider(
      "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider),
      content_type, progress);
  return result;
}
|
|
|
|
// PATCH `path` with a multipart body made of plain `items` plus streamed
// `provider_items`: generates a boundary and sends through the length-less
// provider built by get_multipart_content_provider.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                const UploadFormDataItems &items,
                                const FormDataProviderItems &provider_items,
                                UploadProgress progress) {
  const auto boundary = detail::make_multipart_data_boundary();
  const auto multipart_type =
      detail::serialize_multipart_formdata_get_content_type(boundary);
  return send_with_content_provider(
      "PATCH", path, headers, nullptr, 0, nullptr,
      get_multipart_content_provider(boundary, items, provider_items),
      multipart_type, progress);
}
|
|
|
|
// Builds a PATCH request carrying `body` whose response body is streamed to
// `content_receiver`, then sends it with send_.
//
// `content_receiver` takes (data, length); it is wrapped in a callable that
// accepts the extra offset / total-length arguments and ignores them. A
// non-empty `content_type` is stored as the Content-Type header.
inline Result ClientImpl::Patch(const std::string &path, const Headers &headers,
                                const std::string &body,
                                const std::string &content_type,
                                ContentReceiver content_receiver,
                                DownloadProgress progress) {
  Request request;
  request.method = "PATCH";
  request.headers = headers;
  request.path = path;
  request.body = body;
  request.download_progress = std::move(progress);
  request.content_receiver =
      [content_receiver](const char *data, size_t data_length,
                         size_t /*offset*/, size_t /*total_length*/) {
        return content_receiver(data, data_length);
      };

  // The start time is recorded only when max_timeout_msec_ is positive.
  const bool track_start = max_timeout_msec_ > 0;
  if (track_start) { request.start_time_ = std::chrono::steady_clock::now(); }

  const bool has_type = !content_type.empty();
  if (has_type) { request.set_header("Content-Type", content_type); }

  return send_(std::move(request));
}
|
|
|
|
// DELETE `path` with an empty header set, empty body and empty content type.
inline Result ClientImpl::Delete(const std::string &path,
                                 DownloadProgress progress) {
  const Headers no_headers = Headers();
  const std::string no_body;
  const std::string no_content_type;
  return Delete(path, no_headers, no_body, no_content_type, progress);
}
|
|
|
|
// DELETE `path` with `headers`, an empty body and an empty content type.
inline Result ClientImpl::Delete(const std::string &path,
                                 const Headers &headers,
                                 DownloadProgress progress) {
  const std::string no_body;
  const std::string no_content_type;
  return Delete(path, headers, no_body, no_content_type, progress);
}
|
|
|
|
// DELETE `path` sending the `content_length` bytes at `body`, with an empty
// header set; forwards to the Headers overload.
inline Result ClientImpl::Delete(const std::string &path, const char *body,
                                 size_t content_length,
                                 const std::string &content_type,
                                 DownloadProgress progress) {
  const Headers no_headers = Headers();
  return Delete(path, no_headers, body, content_length, content_type,
                progress);
}
|
|
|
|
// DELETE `path` sending the string `body`, with an empty header set;
// forwards the string's data pointer and size to the raw-buffer overload.
inline Result ClientImpl::Delete(const std::string &path,
                                 const std::string &body,
                                 const std::string &content_type,
                                 DownloadProgress progress) {
  const Headers no_headers = Headers();
  return Delete(path, no_headers, body.data(), body.size(), content_type,
                progress);
}
|
|
|
|
// DELETE `path` with `headers`, sending the string `body`; forwards the
// string's data pointer and size to the raw-buffer overload.
inline Result ClientImpl::Delete(const std::string &path,
                                 const Headers &headers,
                                 const std::string &body,
                                 const std::string &content_type,
                                 DownloadProgress progress) {
  auto result = Delete(path, headers, body.data(), body.size(), content_type,
                       progress);
  return result;
}
|
|
|
|
// DELETE `path` sending `params`, with an empty header set; forwards to the
// (path, headers, params, progress) overload.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Delete(const std::string &path, const Params &params,
                                 DownloadProgress progress) {
  return Delete(path, Headers(), params, progress);
}
|
|
|
|
// DELETE `path` sending `params` as an "application/x-www-form-urlencoded"
// body built by detail::params_to_query_str.
//
// The parameter declaration had been mangled to `¶ms`; it is restored to
// `&params`.
inline Result ClientImpl::Delete(const std::string &path,
                                 const Headers &headers, const Params &params,
                                 DownloadProgress progress) {
  auto query = detail::params_to_query_str(params);
  return Delete(path, headers, query, "application/x-www-form-urlencoded",
                progress);
}
|
|
|
|
// Builds a DELETE request for `path` with `headers` and a body copied from
// the `content_length` bytes at `body`, then sends it with send_. A
// non-empty `content_type` is stored as the Content-Type header.
inline Result ClientImpl::Delete(const std::string &path,
                                 const Headers &headers, const char *body,
                                 size_t content_length,
                                 const std::string &content_type,
                                 DownloadProgress progress) {
  Request request;
  request.method = "DELETE";
  request.path = path;
  request.headers = headers;
  request.download_progress = std::move(progress);

  // The start time is recorded only when max_timeout_msec_ is positive.
  const bool track_start = max_timeout_msec_ > 0;
  if (track_start) { request.start_time_ = std::chrono::steady_clock::now(); }

  const bool has_type = !content_type.empty();
  if (has_type) { request.set_header("Content-Type", content_type); }
  request.body.assign(body, content_length);

  return send_(std::move(request));
}
|
|
|
|
// OPTIONS `path` with an empty header set; forwards to the Headers overload.
inline Result ClientImpl::Options(const std::string &path) {
  const Headers no_headers = Headers();
  return Options(path, no_headers);
}
|
|
|
|
// Builds an OPTIONS request for `path` with `headers` and sends it with
// send_.
inline Result ClientImpl::Options(const std::string &path,
                                  const Headers &headers) {
  Request request;
  request.method = "OPTIONS";
  request.path = path;
  request.headers = headers;

  // The start time is recorded only when max_timeout_msec_ is positive.
  const bool track_start = max_timeout_msec_ > 0;
  if (track_start) { request.start_time_ = std::chrono::steady_clock::now(); }

  return send_(std::move(request));
}
|
|
|
|
inline void ClientImpl::stop() {
|
|
std::lock_guard<std::mutex> guard(socket_mutex_);
|
|
|
|
// If there is anything ongoing right now, the ONLY thread-safe thing we can
|
|
// do is to shutdown_socket, so that threads using this socket suddenly
|
|
// discover they can't read/write any more and error out. Everything else
|
|
// (closing the socket, shutting ssl down) is unsafe because these actions are
|
|
// not thread-safe.
|
|
if (socket_requests_in_flight_ > 0) {
|
|
shutdown_socket(socket_);
|
|
|
|
// Aside from that, we set a flag for the socket to be closed when we're
|
|
// done.
|
|
socket_should_be_closed_when_request_is_done_ = true;
|
|
return;
|
|
}
|
|
|
|
// Otherwise, still holding the mutex, we can shut everything down ourselves
|
|
shutdown_ssl(socket_, true);
|
|
shutdown_socket(socket_);
|
|
close_socket(socket_);
|
|
}
|
|
|
|
// Returns a copy of the stored host_ string.
inline std::string ClientImpl::host() const {
  return host_;
}
|
|
|
|
// Returns the stored port_ value.
inline int ClientImpl::port() const {
  return port_;
}
|
|
|
|
inline size_t ClientImpl::is_socket_open() const {
|
|
std::lock_guard<std::mutex> guard(socket_mutex_);
|
|
return socket_.is_open();
|
|
}
|
|
|
|
// Returns the raw socket handle stored in socket_.sock (read without taking
// socket_mutex_).
inline socket_t ClientImpl::socket() const {
  return socket_.sock;
}
|
|
|
|
// Stores the connection timeout as a seconds part and a microseconds part.
inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) {
  connection_timeout_usec_ = usec;
  connection_timeout_sec_ = sec;
}
|
|
|
|
// Stores the read timeout as a seconds part and a microseconds part.
inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) {
  read_timeout_usec_ = usec;
  read_timeout_sec_ = sec;
}
|
|
|
|
// Stores the write timeout as a seconds part and a microseconds part.
inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) {
  write_timeout_usec_ = usec;
  write_timeout_sec_ = sec;
}
|
|
|
|
// Stores the maximum timeout in milliseconds. The request builders in this
// file record a start time only when this value is positive.
inline void ClientImpl::set_max_timeout(time_t msec) {
  max_timeout_msec_ = msec;
}
|
|
|
|
// Stores the username and password used for basic authentication.
inline void ClientImpl::set_basic_auth(const std::string &username,
                                       const std::string &password) {
  basic_auth_password_ = password;
  basic_auth_username_ = username;
}
|
|
|
|
inline void ClientImpl::set_bearer_token_auth(const std::string &token) {
|
|
bearer_token_auth_token_ = token;
|
|
}
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
// Remembers the user name / password pair for HTTP Digest authentication.
// Only compiled when CPPHTTPLIB_OPENSSL_SUPPORT is defined.
inline void ClientImpl::set_digest_auth(const std::string &username,
                                        const std::string &password) {
  digest_auth_password_ = password;
  digest_auth_username_ = username;
}
|
|
#endif
|
|
|
|
// Stores the keep-alive flag.
inline void ClientImpl::set_keep_alive(bool on) {
  keep_alive_ = on;
}
|
|
|
|
// Stores the follow-location flag.
inline void ClientImpl::set_follow_location(bool on) {
  follow_location_ = on;
}
|
|
|
|
// Stores the path-encode flag.
inline void ClientImpl::set_path_encode(bool on) {
  path_encode_ = on;
}
|
|
|
|
inline void
|
|
ClientImpl::set_hostname_addr_map(std::map<std::string, std::string> addr_map) {
|
|
addr_map_ = std::move(addr_map);
|
|
}
|
|
|
|
// Takes ownership of `headers` (moved into default_headers_).
inline void ClientImpl::set_default_headers(Headers headers) {
  this->default_headers_ = std::move(headers);
}
|
|
|
|
inline void ClientImpl::set_header_writer(
|
|
std::function<ssize_t(Stream &, Headers &)> const &writer) {
|
|
header_writer_ = writer;
|
|
}
|
|
|
|
inline void ClientImpl::set_address_family(int family) {
|
|
address_family_ = family;
|
|
}
|
|
|
|
// Stores the TCP no-delay flag.
inline void ClientImpl::set_tcp_nodelay(bool on) {
  tcp_nodelay_ = on;
}
|
|
|
|
// Stores the IPv6 v6-only flag.
inline void ClientImpl::set_ipv6_v6only(bool on) {
  ipv6_v6only_ = on;
}
|
|
|
|
// Takes ownership of `socket_options` (moved into socket_options_).
inline void ClientImpl::set_socket_options(SocketOptions socket_options) {
  this->socket_options_ = std::move(socket_options);
}
|
|
|
|
// Stores the compress flag.
inline void ClientImpl::set_compress(bool on) {
  compress_ = on;
}
|
|
|
|
// Stores the decompress flag.
inline void ClientImpl::set_decompress(bool on) {
  decompress_ = on;
}
|
|
|
|
inline void ClientImpl::set_interface(const std::string &intf) {
|
|
interface_ = intf;
|
|
}
|
|
|
|
// Remembers the proxy's host and port.
inline void ClientImpl::set_proxy(const std::string &host, int port) {
  proxy_port_ = port;
  proxy_host_ = host;
}
|
|
|
|
// Remembers the user name / password pair for Basic authentication against
// the proxy.
inline void ClientImpl::set_proxy_basic_auth(const std::string &username,
                                             const std::string &password) {
  proxy_basic_auth_password_ = password;
  proxy_basic_auth_username_ = username;
}
|
|
|
|
inline void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) {
|
|
proxy_bearer_token_auth_token_ = token;
|
|
}
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
// Remembers the user name / password pair for Digest authentication against
// the proxy (read by SSLClient::connect_with_proxy).
inline void ClientImpl::set_proxy_digest_auth(const std::string &username,
                                              const std::string &password) {
  proxy_digest_auth_password_ = password;
  proxy_digest_auth_username_ = username;
}
|
|
|
|
inline void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path,
|
|
const std::string &ca_cert_dir_path) {
|
|
ca_cert_file_path_ = ca_cert_file_path;
|
|
ca_cert_dir_path_ = ca_cert_dir_path;
|
|
}
|
|
|
|
// Records `ca_cert_store` in ca_cert_store_. Null pointers and the store
// that is already recorded are ignored.
inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) {
  if (ca_cert_store == nullptr) { return; }
  if (ca_cert_store == ca_cert_store_) { return; }
  ca_cert_store_ = ca_cert_store;
}
|
|
|
|
// Builds a new X509_STORE from PEM data held in memory.
//
// `ca_cert` / `size` describe the PEM buffer. Every certificate and CRL found
// in it is added to the new store. Returns nullptr when the memory BIO cannot
// be created or the PEM data cannot be parsed; otherwise returns the result
// of X509_STORE_new() (which may itself be nullptr).
inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert,
                                                    std::size_t size) const {
  auto bio = BIO_new_mem_buf(ca_cert, static_cast<int>(size));
  // Released on every exit path, including the early returns below.
  auto bio_guard = detail::scope_exit([&] { BIO_free_all(bio); });
  if (bio == nullptr) { return nullptr; }

  auto infos = PEM_X509_INFO_read_bio(bio, nullptr, nullptr, nullptr);
  if (infos == nullptr) { return nullptr; }

  auto store = X509_STORE_new();
  if (store != nullptr) {
    const auto count = static_cast<int>(sk_X509_INFO_num(infos));
    for (auto idx = 0; idx < count; ++idx) {
      auto entry = sk_X509_INFO_value(infos, idx);
      if (entry == nullptr) { continue; }

      if (entry->x509 != nullptr) { X509_STORE_add_cert(store, entry->x509); }
      if (entry->crl != nullptr) { X509_STORE_add_crl(store, entry->crl); }
    }
  }

  sk_X509_INFO_pop_free(infos, X509_INFO_free);
  return store;
}
|
|
|
|
inline void ClientImpl::enable_server_certificate_verification(bool enabled) {
|
|
server_certificate_verification_ = enabled;
|
|
}
|
|
|
|
inline void ClientImpl::enable_server_hostname_verification(bool enabled) {
|
|
server_hostname_verification_ = enabled;
|
|
}
|
|
|
|
inline void ClientImpl::set_server_certificate_verifier(
|
|
std::function<SSLVerifierResponse(SSL *ssl)> verifier) {
|
|
server_certificate_verifier_ = verifier;
|
|
}
|
|
#endif
|
|
|
|
// Takes ownership of `logger` (moved into logger_).
inline void ClientImpl::set_logger(Logger logger) {
  this->logger_ = std::move(logger);
}
|
|
|
|
/*
|
|
* SSL Implementation
|
|
*/
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
namespace detail {
|
|
|
|
template <typename U, typename V>
|
|
inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex,
|
|
U SSL_connect_or_accept, V setup) {
|
|
SSL *ssl = nullptr;
|
|
{
|
|
std::lock_guard<std::mutex> guard(ctx_mutex);
|
|
ssl = SSL_new(ctx);
|
|
}
|
|
|
|
if (ssl) {
|
|
set_nonblocking(sock, true);
|
|
auto bio = BIO_new_socket(static_cast<int>(sock), BIO_NOCLOSE);
|
|
BIO_set_nbio(bio, 1);
|
|
SSL_set_bio(ssl, bio, bio);
|
|
|
|
if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) {
|
|
SSL_shutdown(ssl);
|
|
{
|
|
std::lock_guard<std::mutex> guard(ctx_mutex);
|
|
SSL_free(ssl);
|
|
}
|
|
set_nonblocking(sock, false);
|
|
return nullptr;
|
|
}
|
|
BIO_set_nbio(bio, 0);
|
|
set_nonblocking(sock, false);
|
|
}
|
|
|
|
return ssl;
|
|
}
|
|
|
|
inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock,
|
|
bool shutdown_gracefully) {
|
|
// sometimes we may want to skip this to try to avoid SIGPIPE if we know
|
|
// the remote has closed the network connection
|
|
// Note that it is not always possible to avoid SIGPIPE, this is merely a
|
|
// best-efforts.
|
|
if (shutdown_gracefully) {
|
|
(void)(sock);
|
|
// SSL_shutdown() returns 0 on first call (indicating close_notify alert
|
|
// sent) and 1 on subsequent call (indicating close_notify alert received)
|
|
if (SSL_shutdown(ssl) == 0) {
|
|
// Expected to return 1, but even if it doesn't, we free ssl
|
|
SSL_shutdown(ssl);
|
|
}
|
|
}
|
|
|
|
std::lock_guard<std::mutex> guard(ctx_mutex);
|
|
SSL_free(ssl);
|
|
}
|
|
|
|
template <typename U>
|
|
bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl,
|
|
U ssl_connect_or_accept,
|
|
time_t timeout_sec, time_t timeout_usec,
|
|
int *ssl_error) {
|
|
auto res = 0;
|
|
while ((res = ssl_connect_or_accept(ssl)) != 1) {
|
|
auto err = SSL_get_error(ssl, res);
|
|
switch (err) {
|
|
case SSL_ERROR_WANT_READ:
|
|
if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; }
|
|
break;
|
|
case SSL_ERROR_WANT_WRITE:
|
|
if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; }
|
|
break;
|
|
default: break;
|
|
}
|
|
if (ssl_error) { *ssl_error = err; }
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <typename T>
|
|
inline bool process_server_socket_ssl(
|
|
const std::atomic<socket_t> &svr_sock, SSL *ssl, socket_t sock,
|
|
size_t keep_alive_max_count, time_t keep_alive_timeout_sec,
|
|
time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec,
|
|
time_t write_timeout_usec, T callback) {
|
|
return process_server_socket_core(
|
|
svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec,
|
|
[&](bool close_connection, bool &connection_closed) {
|
|
SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
|
|
write_timeout_sec, write_timeout_usec);
|
|
return callback(strm, close_connection, connection_closed);
|
|
});
|
|
}
|
|
|
|
template <typename T>
|
|
inline bool process_client_socket_ssl(
|
|
SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec,
|
|
time_t write_timeout_sec, time_t write_timeout_usec,
|
|
time_t max_timeout_msec,
|
|
std::chrono::time_point<std::chrono::steady_clock> start_time, T callback) {
|
|
SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
|
|
write_timeout_sec, write_timeout_usec, max_timeout_msec,
|
|
start_time);
|
|
return callback(strm);
|
|
}
|
|
|
|
// SSL socket stream implementation
|
|
// Captures the socket, SSL object, timeouts and start time, then clears
// SSL_MODE_AUTO_RETRY on the SSL object.
inline SSLSocketStream::SSLSocketStream(
    socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec,
    time_t write_timeout_sec, time_t write_timeout_usec,
    time_t max_timeout_msec,
    std::chrono::time_point<std::chrono::steady_clock> start_time)
    : sock_(sock),
      ssl_(ssl),
      read_timeout_sec_(read_timeout_sec),
      read_timeout_usec_(read_timeout_usec),
      write_timeout_sec_(write_timeout_sec),
      write_timeout_usec_(write_timeout_usec),
      max_timeout_msec_(max_timeout_msec),
      start_time_(start_time) {
  SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY);
}
|
|
|
|
// Defaulted destructor: performs no explicit cleanup of the socket or SSL
// object here.
inline SSLSocketStream::~SSLSocketStream() = default;
|
|
|
|
inline bool SSLSocketStream::is_readable() const {
|
|
return SSL_pending(ssl_) > 0;
|
|
}
|
|
|
|
inline bool SSLSocketStream::wait_readable() const {
|
|
if (max_timeout_msec_ <= 0) {
|
|
return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0;
|
|
}
|
|
|
|
time_t read_timeout_sec;
|
|
time_t read_timeout_usec;
|
|
calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_,
|
|
read_timeout_usec_, read_timeout_sec, read_timeout_usec);
|
|
|
|
return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0;
|
|
}
|
|
|
|
inline bool SSLSocketStream::wait_writable() const {
|
|
return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
|
|
is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_);
|
|
}
|
|
|
|
inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
|
|
if (SSL_pending(ssl_) > 0) {
|
|
return SSL_read(ssl_, ptr, static_cast<int>(size));
|
|
} else if (wait_readable()) {
|
|
auto ret = SSL_read(ssl_, ptr, static_cast<int>(size));
|
|
if (ret < 0) {
|
|
auto err = SSL_get_error(ssl_, ret);
|
|
auto n = 1000;
|
|
#ifdef _WIN64
|
|
while (--n >= 0 && (err == SSL_ERROR_WANT_READ ||
|
|
(err == SSL_ERROR_SYSCALL &&
|
|
WSAGetLastError() == WSAETIMEDOUT))) {
|
|
#else
|
|
while (--n >= 0 && err == SSL_ERROR_WANT_READ) {
|
|
#endif
|
|
if (SSL_pending(ssl_) > 0) {
|
|
return SSL_read(ssl_, ptr, static_cast<int>(size));
|
|
} else if (wait_readable()) {
|
|
std::this_thread::sleep_for(std::chrono::microseconds{10});
|
|
ret = SSL_read(ssl_, ptr, static_cast<int>(size));
|
|
if (ret >= 0) { return ret; }
|
|
err = SSL_get_error(ssl_, ret);
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
assert(ret < 0);
|
|
}
|
|
return ret;
|
|
} else {
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) {
|
|
if (wait_writable()) {
|
|
auto handle_size = static_cast<int>(
|
|
std::min<size_t>(size, (std::numeric_limits<int>::max)()));
|
|
|
|
auto ret = SSL_write(ssl_, ptr, static_cast<int>(handle_size));
|
|
if (ret < 0) {
|
|
auto err = SSL_get_error(ssl_, ret);
|
|
auto n = 1000;
|
|
#ifdef _WIN64
|
|
while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE ||
|
|
(err == SSL_ERROR_SYSCALL &&
|
|
WSAGetLastError() == WSAETIMEDOUT))) {
|
|
#else
|
|
while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) {
|
|
#endif
|
|
if (wait_writable()) {
|
|
std::this_thread::sleep_for(std::chrono::microseconds{10});
|
|
ret = SSL_write(ssl_, ptr, static_cast<int>(handle_size));
|
|
if (ret >= 0) { return ret; }
|
|
err = SSL_get_error(ssl_, ret);
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
assert(ret < 0);
|
|
}
|
|
return ret;
|
|
}
|
|
return -1;
|
|
}
|
|
|
|
inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip,
|
|
int &port) const {
|
|
detail::get_remote_ip_and_port(sock_, ip, port);
|
|
}
|
|
|
|
inline void SSLSocketStream::get_local_ip_and_port(std::string &ip,
|
|
int &port) const {
|
|
detail::get_local_ip_and_port(sock_, ip, port);
|
|
}
|
|
|
|
// Returns the underlying socket descriptor.
inline socket_t SSLSocketStream::socket() const {
  return sock_;
}
|
|
|
|
inline time_t SSLSocketStream::duration() const {
|
|
return std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
std::chrono::steady_clock::now() - start_time_)
|
|
.count();
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
// SSL HTTP server implementation
|
|
// Builds a TLS server context from PEM files on disk.
//
// The context disables compression and session resumption on renegotiation
// and requires TLS 1.2 or newer. If the certificate chain or private key
// cannot be loaded, or they do not match, the OpenSSL error is saved in
// last_ssl_error_ and ctx_ is reset to nullptr (is_valid() then reports
// false). When a client CA file or directory is given, client certificates
// become mandatory.
inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path,
                            const char *client_ca_cert_file_path,
                            const char *client_ca_cert_dir_path,
                            const char *private_key_password) {
  ctx_ = SSL_CTX_new(TLS_server_method());
  if (ctx_ == nullptr) { return; }

  SSL_CTX_set_options(ctx_,
                      SSL_OP_NO_COMPRESSION |
                          SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION);

  SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION);

  const bool has_password =
      private_key_password != nullptr && (private_key_password[0] != '\0');
  if (has_password) {
    SSL_CTX_set_default_passwd_cb_userdata(
        ctx_,
        reinterpret_cast<void *>(const_cast<char *>(private_key_password)));
  }

  const bool credentials_ok =
      SSL_CTX_use_certificate_chain_file(ctx_, cert_path) == 1 &&
      SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) ==
          1 &&
      SSL_CTX_check_private_key(ctx_) == 1;
  if (!credentials_ok) {
    last_ssl_error_ = static_cast<int>(ERR_get_error());
    SSL_CTX_free(ctx_);
    ctx_ = nullptr;
    return;
  }

  if (client_ca_cert_file_path || client_ca_cert_dir_path) {
    SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path,
                                  client_ca_cert_dir_path);

    SSL_CTX_set_verify(
        ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
  }
}
|
|
|
|
// Builds a TLS server context from an in-memory certificate and key.
//
// The context disables compression and session resumption on renegotiation
// and requires TLS 1.2 or newer. If the certificate or key is rejected, ctx_
// is reset to nullptr (is_valid() then reports false) and the OpenSSL error
// is saved in last_ssl_error_, as the file-based constructor does. When
// `client_ca_cert_store` is given it is installed on the context and client
// certificates become mandatory.
inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key,
                            X509_STORE *client_ca_cert_store) {
  ctx_ = SSL_CTX_new(TLS_server_method());

  if (ctx_) {
    SSL_CTX_set_options(ctx_,
                        SSL_OP_NO_COMPRESSION |
                            SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION);

    SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION);

    if (SSL_CTX_use_certificate(ctx_, cert) != 1 ||
        SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) {
      // Record why setup failed so callers can inspect it.
      last_ssl_error_ = static_cast<int>(ERR_get_error());
      SSL_CTX_free(ctx_);
      ctx_ = nullptr;
    } else if (client_ca_cert_store) {
      SSL_CTX_set_cert_store(ctx_, client_ca_cert_store);

      SSL_CTX_set_verify(
          ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
    }
  }
}
|
|
|
|
// Builds a TLS context (TLS_method) and lets the caller configure it through
// `setup_ssl_ctx_callback`. If the callback returns false the context is
// freed and ctx_ is reset to nullptr.
inline SSLServer::SSLServer(
    const std::function<bool(SSL_CTX &ssl_ctx)> &setup_ssl_ctx_callback) {
  ctx_ = SSL_CTX_new(TLS_method());
  if (ctx_ == nullptr) { return; }

  const bool configured = setup_ssl_ctx_callback(*ctx_);
  if (!configured) {
    SSL_CTX_free(ctx_);
    ctx_ = nullptr;
  }
}
|
|
|
|
// Releases the SSL context, if one was created.
inline SSLServer::~SSLServer() {
  if (ctx_ != nullptr) {
    SSL_CTX_free(ctx_);
  }
}
|
|
|
|
// True when an SSL context is present (ctx_ is non-null).
inline bool SSLServer::is_valid() const {
  return ctx_ != nullptr;
}
|
|
|
|
// Returns the raw SSL context pointer (may be nullptr).
inline SSL_CTX *SSLServer::ssl_context() const {
  return ctx_;
}
|
|
|
|
// Replaces the certificate and private key on the live context, and the
// client CA store when `client_ca_cert_store` is non-null. Runs under
// ctx_mutex_, the same mutex that guards SSL_new / SSL_free on this context.
//
// Does nothing when the server has no context: the constructors reset ctx_
// to nullptr on failure, and OpenSSL must not be handed a null context.
inline void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key,
                                    X509_STORE *client_ca_cert_store) {

  std::lock_guard<std::mutex> guard(ctx_mutex_);

  if (ctx_ == nullptr) { return; }

  SSL_CTX_use_certificate(ctx_, cert);
  SSL_CTX_use_PrivateKey(ctx_, private_key);

  if (client_ca_cert_store != nullptr) {
    SSL_CTX_set_cert_store(ctx_, client_ca_cert_store);
  }
}
|
|
|
|
// Serves one accepted connection over TLS and then closes it.
//
// Performs the server-side handshake (bounded by the read timeout), runs the
// request loop through process_request, frees the SSL object, and finally
// shuts down and closes `sock`. Returns the result of the request loop, or
// false when the handshake failed.
inline bool SSLServer::process_and_close_socket(socket_t sock) {
  auto accept_handshake = [&](SSL *candidate) {
    return detail::ssl_connect_or_accept_nonblocking(
        sock, candidate, SSL_accept, read_timeout_sec_, read_timeout_usec_,
        &last_ssl_error_);
  };
  auto no_setup = [](SSL * /*candidate*/) { return true; };

  auto ssl = detail::ssl_new(sock, ctx_, ctx_mutex_, accept_handshake,
                             no_setup);

  auto served = false;
  if (ssl) {
    std::string remote_addr;
    int remote_port = 0;
    detail::get_remote_ip_and_port(sock, remote_addr, remote_port);

    std::string local_addr;
    int local_port = 0;
    detail::get_local_ip_and_port(sock, local_addr, local_port);

    auto handle_request = [&](Stream &strm, bool close_connection,
                              bool &connection_closed) {
      return process_request(strm, remote_addr, remote_port, local_addr,
                             local_port, close_connection, connection_closed,
                             [&](Request &req) { req.ssl = ssl; });
    };

    served = detail::process_server_socket_ssl(
        svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_,
        read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
        write_timeout_usec_, handle_request);

    // Shut down gracefully if the result seemed successful, non-gracefully
    // if the connection appeared to be closed.
    detail::ssl_delete(ctx_mutex_, ssl, sock, served);
  }

  detail::shutdown_socket(sock);
  detail::close_socket(sock);
  return served;
}
|
|
|
|
// SSL HTTP client implementation
|
|
// HTTPS client for `host` on port 443 with no client certificate; delegates
// to the (host, port, cert path, key path) constructor.
inline SSLClient::SSLClient(const std::string &host)
    : SSLClient(host, 443, std::string(), std::string()) {
}
|
|
|
|
// HTTPS client for `host`:`port` with no client certificate; delegates to
// the (host, port, cert path, key path) constructor.
inline SSLClient::SSLClient(const std::string &host, int port)
    : SSLClient(host, port, std::string(), std::string()) {
}
|
|
|
|
// HTTPS client that optionally loads a client certificate and key from PEM
// files.
//
// Requires TLS 1.2 or newer. host_ is split on '.' into host_components_
// for later host name checks. The certificate is loaded only when both
// paths are non-empty; `private_key_password` (if non-empty) is handed to
// OpenSSL as the password-callback user data. On failure the OpenSSL error
// is stored in last_openssl_error_ and ctx_ is reset to nullptr, which makes
// is_valid() return false.
inline SSLClient::SSLClient(const std::string &host, int port,
                            const std::string &client_cert_path,
                            const std::string &client_key_path,
                            const std::string &private_key_password)
    : ClientImpl(host, port, client_cert_path, client_key_path) {
  ctx_ = SSL_CTX_new(TLS_client_method());

  // SSL_CTX_new() can fail; never hand a null context to OpenSSL below.
  if (ctx_) {
    SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION);
  } else {
    last_openssl_error_ = ERR_get_error();
  }

  detail::split(&host_[0], &host_[host_.size()], '.',
                [&](const char *b, const char *e) {
                  host_components_.emplace_back(b, e);
                });

  if (ctx_ && !client_cert_path.empty() && !client_key_path.empty()) {
    if (!private_key_password.empty()) {
      SSL_CTX_set_default_passwd_cb_userdata(
          ctx_, reinterpret_cast<void *>(
                    const_cast<char *>(private_key_password.c_str())));
    }

    if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(),
                                     SSL_FILETYPE_PEM) != 1 ||
        SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(),
                                    SSL_FILETYPE_PEM) != 1) {
      last_openssl_error_ = ERR_get_error();
      SSL_CTX_free(ctx_);
      ctx_ = nullptr;
    }
  }
}
|
|
|
|
// HTTPS client that optionally uses an in-memory client certificate and key.
//
// Requires TLS 1.2 or newer, matching the file-based SSLClient constructor
// and both SSLServer constructors. host_ is split on '.' into
// host_components_ for later host name checks. The certificate is installed
// only when both `client_cert` and `client_key` are non-null. On failure the
// OpenSSL error is stored in last_openssl_error_ and ctx_ is reset to
// nullptr, which makes is_valid() return false.
inline SSLClient::SSLClient(const std::string &host, int port,
                            X509 *client_cert, EVP_PKEY *client_key,
                            const std::string &private_key_password)
    : ClientImpl(host, port) {
  ctx_ = SSL_CTX_new(TLS_client_method());

  // SSL_CTX_new() can fail; never hand a null context to OpenSSL below.
  if (ctx_) {
    SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION);
  } else {
    last_openssl_error_ = ERR_get_error();
  }

  detail::split(&host_[0], &host_[host_.size()], '.',
                [&](const char *b, const char *e) {
                  host_components_.emplace_back(b, e);
                });

  if (ctx_ && client_cert != nullptr && client_key != nullptr) {
    if (!private_key_password.empty()) {
      SSL_CTX_set_default_passwd_cb_userdata(
          ctx_, reinterpret_cast<void *>(
                    const_cast<char *>(private_key_password.c_str())));
    }

    if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 ||
        SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) {
      last_openssl_error_ = ERR_get_error();
      SSL_CTX_free(ctx_);
      ctx_ = nullptr;
    }
  }
}
|
|
|
|
// Releases the SSL context and any SSL object still attached to socket_.
inline SSLClient::~SSLClient() {
  if (ctx_ != nullptr) {
    SSL_CTX_free(ctx_);
  }

  // SSL must be shut down here: by the time the base class destructor runs,
  // shutdown_ssl resolves to the base function rather than the derived one
  // and would not free the SSL object (causing a leak).
  shutdown_ssl_impl(socket_, true);
}
|
|
|
|
// True when an SSL context is present (ctx_ is non-null).
inline bool SSLClient::is_valid() const {
  return ctx_ != nullptr;
}
|
|
|
|
// Installs `ca_cert_store` on the SSL context.
//
// A null store is ignored. Without a context the store is freed here. If the
// context already uses this exact store nothing happens; otherwise
// SSL_CTX_set_cert_store replaces the context's previous store with it.
inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) {
  if (ca_cert_store == nullptr) { return; }

  if (ctx_ == nullptr) {
    X509_STORE_free(ca_cert_store);
    return;
  }

  const bool already_installed =
      SSL_CTX_get_cert_store(ctx_) == ca_cert_store;
  if (!already_installed) {
    // Free memory allocated for old cert and use new store `ca_cert_store`
    SSL_CTX_set_cert_store(ctx_, ca_cert_store);
  }
}
|
|
|
|
// Builds a store from the in-memory PEM buffer via
// ClientImpl::create_ca_cert_store and passes it to set_ca_cert_store.
inline void SSLClient::load_ca_cert_store(const char *ca_cert,
                                          std::size_t size) {
  auto store = ClientImpl::create_ca_cert_store(ca_cert, size);
  set_ca_cert_store(store);
}
|
|
|
|
inline long SSLClient::get_openssl_verify_result() const {
|
|
return verify_result_;
|
|
}
|
|
|
|
// Returns the raw SSL context pointer (may be nullptr).
inline SSL_CTX *SSLClient::ssl_context() const {
  return ctx_;
}
|
|
|
|
// Refuses to connect when the client has no SSL context; otherwise defers
// to ClientImpl::create_and_connect_socket.
inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) {
  if (!is_valid()) { return false; }
  return ClientImpl::create_and_connect_socket(socket, error);
}
|
|
|
|
// Assumes that socket_mutex_ is locked and that there are no requests in flight
|
|
// Establishes a tunnel through the configured proxy with an HTTP CONNECT.
//
// Sends CONNECT for host_and_port_. If the proxy answers 407 and proxy
// digest credentials are set (and the challenge parses), the connection is
// re-created and CONNECT is retried with a digest authentication header.
//
// Returns true when the proxy finally answers 200. On a transport failure
// `success` is set to false and false is returned. When the proxy answers
// with a non-200 status, `error` becomes Error::ProxyConnection, the proxy
// response is moved into `res`, and false is returned while `success` stays
// true. In every failing case the connection is torn down.
inline bool SSLClient::connect_with_proxy(
    Socket &socket,
    std::chrono::time_point<std::chrono::steady_clock> start_time,
    Response &res, bool &success, Error &error) {
  // Thread-safe to close everything because we are assuming there are no
  // requests in flight.
  auto close_everything = [&]() {
    shutdown_ssl(socket, true);
    shutdown_socket(socket);
    close_socket(socket);
  };

  success = true;
  Response proxy_res;

  const bool first_attempt_ok = detail::process_client_socket(
      socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
      write_timeout_usec_, max_timeout_msec_, start_time, [&](Stream &strm) {
        Request connect_req;
        connect_req.method = "CONNECT";
        connect_req.path = host_and_port_;
        if (max_timeout_msec_ > 0) {
          connect_req.start_time_ = std::chrono::steady_clock::now();
        }
        return process_request(strm, connect_req, proxy_res, false, error);
      });
  if (!first_attempt_ok) {
    close_everything();
    success = false;
    return false;
  }

  const bool can_try_digest =
      proxy_res.status == StatusCode::ProxyAuthenticationRequired_407 &&
      !proxy_digest_auth_username_.empty() &&
      !proxy_digest_auth_password_.empty();
  if (can_try_digest) {
    std::map<std::string, std::string> auth;
    if (detail::parse_www_authenticate(proxy_res, auth, true)) {
      // Close the current socket and create a new one for the authenticated
      // CONNECT request.
      close_everything();

      if (!create_and_connect_socket(socket, error)) {
        success = false;
        return false;
      }

      proxy_res = Response();
      const bool second_attempt_ok = detail::process_client_socket(
          socket.sock, read_timeout_sec_, read_timeout_usec_,
          write_timeout_sec_, write_timeout_usec_, max_timeout_msec_,
          start_time, [&](Stream &strm) {
            Request auth_req;
            auth_req.method = "CONNECT";
            auth_req.path = host_and_port_;
            auth_req.headers.insert(detail::make_digest_authentication_header(
                auth_req, auth, 1, detail::random_string(10),
                proxy_digest_auth_username_, proxy_digest_auth_password_,
                true));
            if (max_timeout_msec_ > 0) {
              auth_req.start_time_ = std::chrono::steady_clock::now();
            }
            return process_request(strm, auth_req, proxy_res, false, error);
          });
      if (!second_attempt_ok) {
        close_everything();
        success = false;
        return false;
      }
    }
  }

  // If status code is not 200, proxy request is failed. Set error to
  // ProxyConnection and return proxy response as the response of the request.
  if (proxy_res.status != StatusCode::OK_200) {
    error = Error::ProxyConnection;
    res = std::move(proxy_res);
    close_everything();
    return false;
  }

  return true;
}
|
|
|
|
inline bool SSLClient::load_certs() {
|
|
auto ret = true;
|
|
|
|
std::call_once(initialize_cert_, [&]() {
|
|
std::lock_guard<std::mutex> guard(ctx_mutex_);
|
|
if (!ca_cert_file_path_.empty()) {
|
|
if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(),
|
|
nullptr)) {
|
|
last_openssl_error_ = ERR_get_error();
|
|
ret = false;
|
|
}
|
|
} else if (!ca_cert_dir_path_.empty()) {
|
|
if (!SSL_CTX_load_verify_locations(ctx_, nullptr,
|
|
ca_cert_dir_path_.c_str())) {
|
|
last_openssl_error_ = ERR_get_error();
|
|
ret = false;
|
|
}
|
|
} else {
|
|
auto loaded = false;
|
|
#ifdef _WIN64
|
|
loaded =
|
|
detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_));
|
|
#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && \
|
|
defined(TARGET_OS_OSX)
|
|
loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_));
|
|
#endif // _WIN64
|
|
if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); }
|
|
}
|
|
});
|
|
|
|
return ret;
|
|
}
|
|
|
|
// Performs the client-side TLS handshake on `socket` and verifies the peer.
//
// On success the new SSL object is stored in socket.ssl and true is
// returned. On failure `error` is set by the failing step where one is
// identified, the socket is shut down and closed, and false is returned.
inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) {
  // Handshake plus (optional) certificate / host name verification.
  auto connect_and_verify = [&](SSL *handle) {
    if (server_certificate_verification_) {
      if (!load_certs()) {
        error = Error::SSLLoadingCerts;
        return false;
      }
      // The verification outcome is inspected explicitly after the
      // handshake (see SSL_get_verify_result below).
      SSL_set_verify(handle, SSL_VERIFY_NONE, nullptr);
    }

    if (!detail::ssl_connect_or_accept_nonblocking(
            socket.sock, handle, SSL_connect, connection_timeout_sec_,
            connection_timeout_usec_, &last_ssl_error_)) {
      error = Error::SSLConnection;
      return false;
    }

    if (!server_certificate_verification_) { return true; }

    // Give the user-supplied verifier the first say, if there is one.
    auto decision = SSLVerifierResponse::NoDecisionMade;
    if (server_certificate_verifier_) {
      decision = server_certificate_verifier_(handle);
    }

    if (decision == SSLVerifierResponse::CertificateRejected) {
      last_openssl_error_ = ERR_get_error();
      error = Error::SSLServerVerification;
      return false;
    }

    if (decision != SSLVerifierResponse::NoDecisionMade) { return true; }

    // No decision from a custom verifier: run the built-in checks.
    verify_result_ = SSL_get_verify_result(handle);
    if (verify_result_ != X509_V_OK) {
      last_openssl_error_ = static_cast<unsigned long>(verify_result_);
      error = Error::SSLServerVerification;
      return false;
    }

    auto peer_cert = SSL_get1_peer_certificate(handle);
    auto cert_guard = detail::scope_exit([&] { X509_free(peer_cert); });

    if (peer_cert == nullptr) {
      last_openssl_error_ = ERR_get_error();
      error = Error::SSLServerVerification;
      return false;
    }

    if (server_hostname_verification_ && !verify_host(peer_cert)) {
      last_openssl_error_ = X509_V_ERR_HOSTNAME_MISMATCH;
      error = Error::SSLServerHostnameVerification;
      return false;
    }

    return true;
  };

  // Runs before the handshake: sets the TLS host name extension to host_.
  auto set_sni = [&](SSL *handle) {
#if defined(OPENSSL_IS_BORINGSSL)
    SSL_set_tlsext_host_name(handle, host_.c_str());
#else
    // NOTE: Direct call instead of using the OpenSSL macro to suppress
    // -Wold-style-cast warning
    SSL_ctrl(handle, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name,
             static_cast<void *>(const_cast<char *>(host_.c_str())));
#endif
    return true;
  };

  auto ssl = detail::ssl_new(socket.sock, ctx_, ctx_mutex_, connect_and_verify,
                             set_sni);
  if (!ssl) {
    shutdown_socket(socket);
    close_socket(socket);
    return false;
  }

  socket.ssl = ssl;
  return true;
}
|
|
|
|
// Forwards to shutdown_ssl_impl, which is also called directly from the
// destructor.
inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) {
  this->shutdown_ssl_impl(socket, shutdown_gracefully);
}
|
|
|
|
// Frees the SSL object attached to `socket`, if any, and clears socket.ssl.
// An invalid socket is expected to carry no SSL object (asserted).
inline void SSLClient::shutdown_ssl_impl(Socket &socket,
                                         bool shutdown_gracefully) {
  const bool has_socket = socket.sock != INVALID_SOCKET;
  if (!has_socket) {
    assert(socket.ssl == nullptr);
    return;
  }

  if (socket.ssl != nullptr) {
    detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock,
                       shutdown_gracefully);
    socket.ssl = nullptr;
  }

  assert(socket.ssl == nullptr);
}
|
|
|
|
inline bool SSLClient::process_socket(
|
|
const Socket &socket,
|
|
std::chrono::time_point<std::chrono::steady_clock> start_time,
|
|
std::function<bool(Stream &strm)> callback) {
|
|
assert(socket.ssl);
|
|
return detail::process_client_socket_ssl(
|
|
socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_,
|
|
write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time,
|
|
std::move(callback));
|
|
}
|
|
|
|
// Always true for SSLClient.
inline bool SSLClient::is_ssl() const {
  return true;
}
|
|
|
|
// Checks that `server_cert` was issued for host_: first against the
// subjectAltName entries, then against the subject's Common Name.
inline bool SSLClient::verify_host(X509 *server_cert) const {
  /* Quote from RFC2818 section 3.1 "Server Identity"

     If a subjectAltName extension of type dNSName is present, that MUST
     be used as the identity. Otherwise, the (most specific) Common Name
     field in the Subject field of the certificate MUST be used. Although
     the use of the Common Name is existing practice, it is deprecated and
     Certification Authorities are encouraged to use the dNSName instead.

     Matching is performed using the matching rules specified by
     [RFC2459].  If more than one identity of a given type is present in
     the certificate (e.g., more than one dNSName name, a match in any one
     of the set is considered acceptable.) Names may contain the wildcard
     character * which is considered to match any single domain name
     component or component fragment. E.g., *.a.com matches foo.a.com but
     not bar.foo.a.com. f*.com matches foo.com but not bar.com.

     In some cases, the URI is specified as an IP address rather than a
     hostname. In this case, the iPAddress subjectAltName must be present
     in the certificate and must exactly match the IP in the URI.
  */
  if (verify_host_with_subject_alt_name(server_cert)) { return true; }
  return verify_host_with_common_name(server_cert);
}
|
|
|
|
inline bool
|
|
SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const {
|
|
auto ret = false;
|
|
|
|
auto type = GEN_DNS;
|
|
|
|
struct in6_addr addr6 = {};
|
|
struct in_addr addr = {};
|
|
size_t addr_len = 0;
|
|
|
|
#ifndef __MINGW32__
|
|
if (inet_pton(AF_INET6, host_.c_str(), &addr6)) {
|
|
type = GEN_IPADD;
|
|
addr_len = sizeof(struct in6_addr);
|
|
} else if (inet_pton(AF_INET, host_.c_str(), &addr)) {
|
|
type = GEN_IPADD;
|
|
addr_len = sizeof(struct in_addr);
|
|
}
|
|
#endif
|
|
|
|
auto alt_names = static_cast<const struct stack_st_GENERAL_NAME *>(
|
|
X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr));
|
|
|
|
if (alt_names) {
|
|
auto dsn_matched = false;
|
|
auto ip_matched = false;
|
|
|
|
auto count = sk_GENERAL_NAME_num(alt_names);
|
|
|
|
for (decltype(count) i = 0; i < count && !dsn_matched; i++) {
|
|
auto val = sk_GENERAL_NAME_value(alt_names, i);
|
|
if (val->type == type) {
|
|
auto name =
|
|
reinterpret_cast<const char *>(ASN1_STRING_get0_data(val->d.ia5));
|
|
auto name_len = static_cast<size_t>(ASN1_STRING_length(val->d.ia5));
|
|
|
|
switch (type) {
|
|
case GEN_DNS: dsn_matched = check_host_name(name, name_len); break;
|
|
|
|
case GEN_IPADD:
|
|
if (!memcmp(&addr6, name, addr_len) ||
|
|
!memcmp(&addr, name, addr_len)) {
|
|
ip_matched = true;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (dsn_matched || ip_matched) { ret = true; }
|
|
}
|
|
|
|
GENERAL_NAMES_free(const_cast<STACK_OF(GENERAL_NAME) *>(
|
|
reinterpret_cast<const STACK_OF(GENERAL_NAME) *>(alt_names)));
|
|
return ret;
|
|
}
|
|
|
|
// Checks host_ against the Common Name of the certificate's subject.
// Returns false when the certificate has no subject name or
// X509_NAME_get_text_by_NID reports -1 for the Common Name.
inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const {
  const auto subject = X509_get_subject_name(server_cert);
  if (subject == nullptr) { return false; }

  char common_name[BUFSIZ];
  const auto length = X509_NAME_get_text_by_NID(
      subject, NID_commonName, common_name, sizeof(common_name));
  if (length == -1) { return false; }

  return check_host_name(common_name, static_cast<size_t>(length));
}
|
|
|
|
inline bool SSLClient::check_host_name(const char *pattern,
|
|
size_t pattern_len) const {
|
|
if (host_.size() == pattern_len && host_ == pattern) { return true; }
|
|
|
|
// Wildcard match
|
|
// https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484
|
|
std::vector<std::string> pattern_components;
|
|
detail::split(&pattern[0], &pattern[pattern_len], '.',
|
|
[&](const char *b, const char *e) {
|
|
pattern_components.emplace_back(b, e);
|
|
});
|
|
|
|
if (host_components_.size() != pattern_components.size()) { return false; }
|
|
|
|
auto itr = pattern_components.begin();
|
|
for (const auto &h : host_components_) {
|
|
auto &p = *itr;
|
|
if (p != h && p != "*") {
|
|
auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' &&
|
|
!p.compare(0, p.size() - 1, h));
|
|
if (!partial_match) { return false; }
|
|
}
|
|
++itr;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
#endif
|
|
|
|
// Universal client implementation
|
|
// Creates a client from a "[scheme://]host[:port]" string without a client
// certificate; delegates to the three-argument constructor.
inline Client::Client(const std::string &scheme_host_port)
    : Client(scheme_host_port, std::string(), std::string()) {
}
|
|
|
|
// Creates the underlying client from a "[scheme://]host[:port]" string.
//
// Accepted schemes are "http" and, with CPPHTTPLIB_OPENSSL_SUPPORT, "https".
// An unsupported scheme throws std::invalid_argument, or leaves the client
// without an implementation when CPPHTTPLIB_NO_EXCEPTIONS is defined. The
// host may be a bracketed IPv6 literal. A missing port defaults to 443 for
// https and 80 otherwise. If the string does not match the expected shape at
// all, it is used verbatim as the host name of a plain client on port 80.
inline Client::Client(const std::string &scheme_host_port,
                      const std::string &client_cert_path,
                      const std::string &client_key_path) {
  const static std::regex re(
      R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)");

  std::smatch m;
  if (!std::regex_match(scheme_host_port, m, re)) {
    // NOTE: Update TEST(UniversalClientImplTest, Ipv6LiteralAddress)
    // if port param below changes.
    cli_ = detail::make_unique<ClientImpl>(scheme_host_port, 80,
                                           client_cert_path, client_key_path);
    return;
  }

  auto scheme = m[1].str();

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  const bool scheme_ok =
      scheme.empty() || scheme == "http" || scheme == "https";
#else
  const bool scheme_ok = scheme.empty() || scheme == "http";
#endif
  if (!scheme_ok) {
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
    std::string msg = "'" + scheme + "' scheme is not supported.";
    throw std::invalid_argument(msg);
#endif
    return;
  }

  auto is_ssl = scheme == "https";

  // Group 2 is a bracketed IPv6 literal, group 3 any other host.
  auto host = m[2].str();
  if (host.empty()) { host = m[3].str(); }

  auto port_str = m[4].str();
  int port = is_ssl ? 443 : 80;
  if (!port_str.empty()) { port = std::stoi(port_str); }

  if (is_ssl) {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
    cli_ = detail::make_unique<SSLClient>(host, port, client_cert_path,
                                          client_key_path);
    is_ssl_ = is_ssl;
#endif
  } else {
    cli_ = detail::make_unique<ClientImpl>(host, port, client_cert_path,
                                           client_key_path);
  }
}
|
|
|
|
inline Client::Client(const std::string &host, int port)
|
|
: cli_(detail::make_unique<ClientImpl>(host, port)) {}
|
|
|
|
inline Client::Client(const std::string &host, int port,
|
|
const std::string &client_cert_path,
|
|
const std::string &client_key_path)
|
|
: cli_(detail::make_unique<ClientImpl>(host, port, client_cert_path,
|
|
client_key_path)) {}
|
|
|
|
// Compiler-generated destructor, defined out of line.
inline Client::~Client() = default;
|
|
|
|
inline bool Client::is_valid() const {
|
|
return cli_ != nullptr && cli_->is_valid();
|
|
}
|
|
|
|
// GET: delegates (path, progress) to cli_->Get.
inline Result Client::Get(const std::string &path, DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Get(path, std::move(progress));
}
|
|
// GET: delegates (path, headers, progress) to cli_->Get.
inline Result Client::Get(const std::string &path, const Headers &headers,
                          DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Get(path, headers, std::move(progress));
}
|
|
// GET: delegates (path, content_receiver, progress) to cli_->Get.
inline Result Client::Get(const std::string &path,
                          ContentReceiver content_receiver,
                          DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Get(path, std::move(content_receiver), std::move(progress));
}
|
|
// GET: delegates (path, headers, content_receiver, progress) to cli_->Get.
inline Result Client::Get(const std::string &path, const Headers &headers,
                          ContentReceiver content_receiver,
                          DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Get(path, headers, std::move(content_receiver),
                  std::move(progress));
}
|
|
// GET: delegates (path, response_handler, content_receiver, progress) to
// cli_->Get.
inline Result Client::Get(const std::string &path,
                          ResponseHandler response_handler,
                          ContentReceiver content_receiver,
                          DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Get(path, std::move(response_handler),
                  std::move(content_receiver), std::move(progress));
}
|
|
// GET: delegates (path, headers, response_handler, content_receiver,
// progress) to cli_->Get.
inline Result Client::Get(const std::string &path, const Headers &headers,
                          ResponseHandler response_handler,
                          ContentReceiver content_receiver,
                          DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Get(path, headers, std::move(response_handler),
                  std::move(content_receiver), std::move(progress));
}
|
|
// GET: delegates (path, params, headers, progress) to cli_->Get.
// The `params` parameter name had been corrupted to a pilcrow sequence,
// leaving the body's `params` undeclared; it is restored here.
inline Result Client::Get(const std::string &path, const Params &params,
                          const Headers &headers, DownloadProgress progress) {
  return cli_->Get(path, params, headers, std::move(progress));
}
|
|
// GET: delegates (path, params, headers, content_receiver, progress) to
// cli_->Get. The `params` parameter name had been corrupted to a pilcrow
// sequence, leaving the body's `params` undeclared; it is restored here.
inline Result Client::Get(const std::string &path, const Params &params,
                          const Headers &headers,
                          ContentReceiver content_receiver,
                          DownloadProgress progress) {
  return cli_->Get(path, params, headers, std::move(content_receiver),
                   std::move(progress));
}
|
|
// GET: delegates (path, params, headers, response_handler, content_receiver,
// progress) to cli_->Get. The `params` parameter name had been corrupted to a
// pilcrow sequence, leaving the body's `params` undeclared; it is restored
// here.
inline Result Client::Get(const std::string &path, const Params &params,
                          const Headers &headers,
                          ResponseHandler response_handler,
                          ContentReceiver content_receiver,
                          DownloadProgress progress) {
  return cli_->Get(path, params, headers, std::move(response_handler),
                   std::move(content_receiver), std::move(progress));
}
|
|
|
|
// HEAD: delegates (path) to cli_->Head.
inline Result Client::Head(const std::string &path) {
  auto &impl = *cli_;
  return impl.Head(path);
}
|
|
// HEAD: delegates (path, headers) to cli_->Head.
inline Result Client::Head(const std::string &path, const Headers &headers) {
  auto &impl = *cli_;
  return impl.Head(path, headers);
}
|
|
|
|
// POST: delegates (path) to cli_->Post.
inline Result Client::Post(const std::string &path) {
  auto &impl = *cli_;
  return impl.Post(path);
}
|
|
// POST: delegates (path, headers) to cli_->Post.
inline Result Client::Post(const std::string &path, const Headers &headers) {
  auto &impl = *cli_;
  return impl.Post(path, headers);
}
|
|
// POST: delegates (path, body, content_length, content_type, progress) to
// cli_->Post.
inline Result Client::Post(const std::string &path, const char *body,
                           size_t content_length,
                           const std::string &content_type,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, body, content_length, content_type, progress);
}
|
|
// POST: delegates (path, headers, body, content_length, content_type,
// progress) to cli_->Post.
inline Result Client::Post(const std::string &path, const Headers &headers,
                           const char *body, size_t content_length,
                           const std::string &content_type,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, headers, body, content_length, content_type,
                   progress);
}
|
|
// POST: delegates (path, body, content_type, progress) to cli_->Post.
inline Result Client::Post(const std::string &path, const std::string &body,
                           const std::string &content_type,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, body, content_type, progress);
}
|
|
// POST: delegates (path, headers, body, content_type, progress) to
// cli_->Post.
inline Result Client::Post(const std::string &path, const Headers &headers,
                           const std::string &body,
                           const std::string &content_type,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, headers, body, content_type, progress);
}
|
|
// POST: delegates (path, content_length, content_provider, content_type,
// progress) to cli_->Post; the provider is moved.
inline Result Client::Post(const std::string &path, size_t content_length,
                           ContentProvider content_provider,
                           const std::string &content_type,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, content_length, std::move(content_provider),
                   content_type, progress);
}
|
|
// POST: delegates (path, content_provider, content_type, progress) to
// cli_->Post; the length-less provider is moved.
inline Result Client::Post(const std::string &path,
                           ContentProviderWithoutLength content_provider,
                           const std::string &content_type,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, std::move(content_provider), content_type, progress);
}
|
|
// POST: delegates (path, headers, content_length, content_provider,
// content_type, progress) to cli_->Post; the provider is moved.
inline Result Client::Post(const std::string &path, const Headers &headers,
                           size_t content_length,
                           ContentProvider content_provider,
                           const std::string &content_type,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, headers, content_length, std::move(content_provider),
                   content_type, progress);
}
|
|
// POST: delegates (path, headers, content_provider, content_type, progress)
// to cli_->Post; the length-less provider is moved.
inline Result Client::Post(const std::string &path, const Headers &headers,
                           ContentProviderWithoutLength content_provider,
                           const std::string &content_type,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, headers, std::move(content_provider), content_type,
                   progress);
}
|
|
// POST: delegates (path, params) to cli_->Post.
// The `params` parameter name had been corrupted to a pilcrow sequence,
// leaving the body's `params` undeclared; it is restored here.
inline Result Client::Post(const std::string &path, const Params &params) {
  return cli_->Post(path, params);
}
|
|
// POST: delegates (path, headers, params) to cli_->Post.
// The `params` parameter name had been corrupted to a pilcrow sequence,
// leaving the body's `params` undeclared; it is restored here.
inline Result Client::Post(const std::string &path, const Headers &headers,
                           const Params &params) {
  return cli_->Post(path, headers, params);
}
|
|
// POST (multipart form data): delegates (path, items, progress) to
// cli_->Post.
inline Result Client::Post(const std::string &path,
                           const UploadFormDataItems &items,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, items, progress);
}
|
|
// POST (multipart form data): delegates (path, headers, items, progress) to
// cli_->Post.
inline Result Client::Post(const std::string &path, const Headers &headers,
                           const UploadFormDataItems &items,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, headers, items, progress);
}
|
|
// POST (multipart form data): delegates (path, headers, items, boundary,
// progress) to cli_->Post.
inline Result Client::Post(const std::string &path, const Headers &headers,
                           const UploadFormDataItems &items,
                           const std::string &boundary,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, headers, items, boundary, progress);
}
|
|
// POST (multipart form data): delegates (path, headers, items,
// provider_items, progress) to cli_->Post.
inline Result Client::Post(const std::string &path, const Headers &headers,
                           const UploadFormDataItems &items,
                           const FormDataProviderItems &provider_items,
                           UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, headers, items, provider_items, progress);
}
|
|
// POST with a response content receiver: delegates (path, headers, body,
// content_type, content_receiver, progress) to cli_->Post.
inline Result Client::Post(const std::string &path, const Headers &headers,
                           const std::string &body,
                           const std::string &content_type,
                           ContentReceiver content_receiver,
                           DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Post(path, headers, body, content_type, content_receiver,
                   progress);
}
|
|
|
|
// PUT: delegates (path) to cli_->Put.
inline Result Client::Put(const std::string &path) {
  auto &impl = *cli_;
  return impl.Put(path);
}
|
|
// PUT: delegates (path, headers) to cli_->Put.
inline Result Client::Put(const std::string &path, const Headers &headers) {
  auto &impl = *cli_;
  return impl.Put(path, headers);
}
|
|
// PUT: delegates (path, body, content_length, content_type, progress) to
// cli_->Put.
inline Result Client::Put(const std::string &path, const char *body,
                          size_t content_length,
                          const std::string &content_type,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, body, content_length, content_type, progress);
}
|
|
// PUT: delegates (path, headers, body, content_length, content_type,
// progress) to cli_->Put.
inline Result Client::Put(const std::string &path, const Headers &headers,
                          const char *body, size_t content_length,
                          const std::string &content_type,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, headers, body, content_length, content_type, progress);
}
|
|
// PUT: delegates (path, body, content_type, progress) to cli_->Put.
inline Result Client::Put(const std::string &path, const std::string &body,
                          const std::string &content_type,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, body, content_type, progress);
}
|
|
// PUT: delegates (path, headers, body, content_type, progress) to cli_->Put.
inline Result Client::Put(const std::string &path, const Headers &headers,
                          const std::string &body,
                          const std::string &content_type,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, headers, body, content_type, progress);
}
|
|
// PUT: delegates (path, content_length, content_provider, content_type,
// progress) to cli_->Put; the provider is moved.
inline Result Client::Put(const std::string &path, size_t content_length,
                          ContentProvider content_provider,
                          const std::string &content_type,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, content_length, std::move(content_provider),
                  content_type, progress);
}
|
|
// PUT: delegates (path, content_provider, content_type, progress) to
// cli_->Put; the length-less provider is moved.
inline Result Client::Put(const std::string &path,
                          ContentProviderWithoutLength content_provider,
                          const std::string &content_type,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, std::move(content_provider), content_type, progress);
}
|
|
// PUT: delegates (path, headers, content_length, content_provider,
// content_type, progress) to cli_->Put; the provider is moved.
inline Result Client::Put(const std::string &path, const Headers &headers,
                          size_t content_length,
                          ContentProvider content_provider,
                          const std::string &content_type,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, headers, content_length, std::move(content_provider),
                  content_type, progress);
}
|
|
// PUT: delegates (path, headers, content_provider, content_type, progress)
// to cli_->Put; the length-less provider is moved.
inline Result Client::Put(const std::string &path, const Headers &headers,
                          ContentProviderWithoutLength content_provider,
                          const std::string &content_type,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, headers, std::move(content_provider), content_type,
                  progress);
}
|
|
// PUT: delegates (path, params) to cli_->Put.
// The `params` parameter name had been corrupted to a pilcrow sequence,
// leaving the body's `params` undeclared; it is restored here.
inline Result Client::Put(const std::string &path, const Params &params) {
  return cli_->Put(path, params);
}
|
|
// PUT: delegates (path, headers, params) to cli_->Put.
// The `params` parameter name had been corrupted to a pilcrow sequence,
// leaving the body's `params` undeclared; it is restored here.
inline Result Client::Put(const std::string &path, const Headers &headers,
                          const Params &params) {
  return cli_->Put(path, headers, params);
}
|
|
// PUT (multipart form data): delegates (path, items, progress) to cli_->Put.
inline Result Client::Put(const std::string &path,
                          const UploadFormDataItems &items,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, items, progress);
}
|
|
// PUT (multipart form data): delegates (path, headers, items, progress) to
// cli_->Put.
inline Result Client::Put(const std::string &path, const Headers &headers,
                          const UploadFormDataItems &items,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, headers, items, progress);
}
|
|
// PUT (multipart form data): delegates (path, headers, items, boundary,
// progress) to cli_->Put.
inline Result Client::Put(const std::string &path, const Headers &headers,
                          const UploadFormDataItems &items,
                          const std::string &boundary,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, headers, items, boundary, progress);
}
|
|
// PUT (multipart form data): delegates (path, headers, items,
// provider_items, progress) to cli_->Put.
inline Result Client::Put(const std::string &path, const Headers &headers,
                          const UploadFormDataItems &items,
                          const FormDataProviderItems &provider_items,
                          UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, headers, items, provider_items, progress);
}
|
|
// PUT with a response content receiver: delegates (path, headers, body,
// content_type, content_receiver, progress) to cli_->Put.
inline Result Client::Put(const std::string &path, const Headers &headers,
                          const std::string &body,
                          const std::string &content_type,
                          ContentReceiver content_receiver,
                          DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Put(path, headers, body, content_type, content_receiver,
                  progress);
}
|
|
|
|
// PATCH: delegates (path) to cli_->Patch.
inline Result Client::Patch(const std::string &path) {
  auto &impl = *cli_;
  return impl.Patch(path);
}
|
|
// PATCH: delegates (path, headers) to cli_->Patch.
inline Result Client::Patch(const std::string &path, const Headers &headers) {
  auto &impl = *cli_;
  return impl.Patch(path, headers);
}
|
|
// PATCH: delegates (path, body, content_length, content_type, progress) to
// cli_->Patch.
inline Result Client::Patch(const std::string &path, const char *body,
                            size_t content_length,
                            const std::string &content_type,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, body, content_length, content_type, progress);
}
|
|
// PATCH: delegates (path, headers, body, content_length, content_type,
// progress) to cli_->Patch.
inline Result Client::Patch(const std::string &path, const Headers &headers,
                            const char *body, size_t content_length,
                            const std::string &content_type,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, headers, body, content_length, content_type,
                    progress);
}
|
|
// PATCH: delegates (path, body, content_type, progress) to cli_->Patch.
inline Result Client::Patch(const std::string &path, const std::string &body,
                            const std::string &content_type,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, body, content_type, progress);
}
|
|
// PATCH: delegates (path, headers, body, content_type, progress) to
// cli_->Patch.
inline Result Client::Patch(const std::string &path, const Headers &headers,
                            const std::string &body,
                            const std::string &content_type,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, headers, body, content_type, progress);
}
|
|
// PATCH: delegates (path, content_length, content_provider, content_type,
// progress) to cli_->Patch; the provider is moved.
inline Result Client::Patch(const std::string &path, size_t content_length,
                            ContentProvider content_provider,
                            const std::string &content_type,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, content_length, std::move(content_provider),
                    content_type, progress);
}
|
|
// PATCH: delegates (path, content_provider, content_type, progress) to
// cli_->Patch; the length-less provider is moved.
inline Result Client::Patch(const std::string &path,
                            ContentProviderWithoutLength content_provider,
                            const std::string &content_type,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, std::move(content_provider), content_type, progress);
}
|
|
// PATCH: delegates (path, headers, content_length, content_provider,
// content_type, progress) to cli_->Patch; the provider is moved.
inline Result Client::Patch(const std::string &path, const Headers &headers,
                            size_t content_length,
                            ContentProvider content_provider,
                            const std::string &content_type,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, headers, content_length, std::move(content_provider),
                    content_type, progress);
}
|
|
// PATCH: delegates (path, headers, content_provider, content_type, progress)
// to cli_->Patch; the length-less provider is moved.
inline Result Client::Patch(const std::string &path, const Headers &headers,
                            ContentProviderWithoutLength content_provider,
                            const std::string &content_type,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, headers, std::move(content_provider), content_type,
                    progress);
}
|
|
// PATCH: delegates (path, params) to cli_->Patch.
// The `params` parameter name had been corrupted to a pilcrow sequence,
// leaving the body's `params` undeclared; it is restored here.
inline Result Client::Patch(const std::string &path, const Params &params) {
  return cli_->Patch(path, params);
}
|
|
// PATCH: delegates (path, headers, params) to cli_->Patch.
// The `params` parameter name had been corrupted to a pilcrow sequence,
// leaving the body's `params` undeclared; it is restored here.
inline Result Client::Patch(const std::string &path, const Headers &headers,
                            const Params &params) {
  return cli_->Patch(path, headers, params);
}
|
|
// PATCH (multipart form data): delegates (path, items, progress) to
// cli_->Patch.
inline Result Client::Patch(const std::string &path,
                            const UploadFormDataItems &items,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, items, progress);
}
|
|
// PATCH (multipart form data): delegates (path, headers, items, progress) to
// cli_->Patch.
inline Result Client::Patch(const std::string &path, const Headers &headers,
                            const UploadFormDataItems &items,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, headers, items, progress);
}
|
|
// PATCH (multipart form data): delegates (path, headers, items, boundary,
// progress) to cli_->Patch.
inline Result Client::Patch(const std::string &path, const Headers &headers,
                            const UploadFormDataItems &items,
                            const std::string &boundary,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, headers, items, boundary, progress);
}
|
|
// PATCH (multipart form data): delegates (path, headers, items,
// provider_items, progress) to cli_->Patch.
inline Result Client::Patch(const std::string &path, const Headers &headers,
                            const UploadFormDataItems &items,
                            const FormDataProviderItems &provider_items,
                            UploadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, headers, items, provider_items, progress);
}
|
|
// PATCH with a response content receiver: delegates (path, headers, body,
// content_type, content_receiver, progress) to cli_->Patch.
inline Result Client::Patch(const std::string &path, const Headers &headers,
                            const std::string &body,
                            const std::string &content_type,
                            ContentReceiver content_receiver,
                            DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Patch(path, headers, body, content_type, content_receiver,
                    progress);
}
|
|
|
|
// DELETE: delegates (path, progress) to cli_->Delete.
inline Result Client::Delete(const std::string &path,
                             DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Delete(path, progress);
}
|
|
// DELETE: delegates (path, headers, progress) to cli_->Delete.
inline Result Client::Delete(const std::string &path, const Headers &headers,
                             DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Delete(path, headers, progress);
}
|
|
// DELETE: delegates (path, body, content_length, content_type, progress) to
// cli_->Delete.
inline Result Client::Delete(const std::string &path, const char *body,
                             size_t content_length,
                             const std::string &content_type,
                             DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Delete(path, body, content_length, content_type, progress);
}
|
|
// DELETE: delegates (path, headers, body, content_length, content_type,
// progress) to cli_->Delete.
inline Result Client::Delete(const std::string &path, const Headers &headers,
                             const char *body, size_t content_length,
                             const std::string &content_type,
                             DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Delete(path, headers, body, content_length, content_type,
                     progress);
}
|
|
// DELETE: delegates (path, body, content_type, progress) to cli_->Delete.
inline Result Client::Delete(const std::string &path, const std::string &body,
                             const std::string &content_type,
                             DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Delete(path, body, content_type, progress);
}
|
|
// DELETE: delegates (path, headers, body, content_type, progress) to
// cli_->Delete.
inline Result Client::Delete(const std::string &path, const Headers &headers,
                             const std::string &body,
                             const std::string &content_type,
                             DownloadProgress progress) {
  auto &impl = *cli_;
  return impl.Delete(path, headers, body, content_type, progress);
}
|
|
// DELETE: delegates (path, params, progress) to cli_->Delete.
// The `params` parameter name had been corrupted to a pilcrow sequence,
// leaving the body's `params` undeclared; it is restored here.
inline Result Client::Delete(const std::string &path, const Params &params,
                             DownloadProgress progress) {
  return cli_->Delete(path, params, progress);
}
|
|
// DELETE: delegates (path, headers, params, progress) to cli_->Delete.
// The `params` parameter name had been corrupted to a pilcrow sequence,
// leaving the body's `params` undeclared; it is restored here.
inline Result Client::Delete(const std::string &path, const Headers &headers,
                             const Params &params, DownloadProgress progress) {
  return cli_->Delete(path, headers, params, progress);
}
|
|
|
|
// OPTIONS: delegates (path) to cli_->Options.
inline Result Client::Options(const std::string &path) {
  auto &impl = *cli_;
  return impl.Options(path);
}
|
|
// OPTIONS: delegates (path, headers) to cli_->Options.
inline Result Client::Options(const std::string &path, const Headers &headers) {
  auto &impl = *cli_;
  return impl.Options(path, headers);
}
|
|
|
|
// Delegates (req, res, error) to cli_->send and returns its bool result.
inline bool Client::send(Request &req, Response &res, Error &error) {
  auto &impl = *cli_;
  return impl.send(req, res, error);
}
|
|
|
|
// Delegates the request to cli_->send and returns its Result.
inline Result Client::send(const Request &req) {
  auto &impl = *cli_;
  return impl.send(req);
}
|
|
|
|
// Delegates to cli_->stop().
inline void Client::stop() {
  auto &impl = *cli_;
  impl.stop();
}
|
|
|
|
// Returns the host reported by the underlying implementation.
inline std::string Client::host() const {
  auto &impl = *cli_;
  return impl.host();
}
|
|
|
|
// Returns the port reported by the underlying implementation.
inline int Client::port() const {
  auto &impl = *cli_;
  return impl.port();
}
|
|
|
|
// Returns cli_->is_socket_open() converted to the declared size_t return type.
inline size_t Client::is_socket_open() const {
  auto &impl = *cli_;
  return impl.is_socket_open();
}
|
|
|
|
// Returns the socket handle reported by the underlying implementation.
inline socket_t Client::socket() const {
  auto &impl = *cli_;
  return impl.socket();
}
|
|
|
|
inline void
|
|
Client::set_hostname_addr_map(std::map<std::string, std::string> addr_map) {
|
|
cli_->set_hostname_addr_map(std::move(addr_map));
|
|
}
|
|
|
|
// Moves the headers into cli_->set_default_headers.
inline void Client::set_default_headers(Headers headers) {
  auto &impl = *cli_;
  impl.set_default_headers(std::move(headers));
}
|
|
|
|
inline void Client::set_header_writer(
|
|
std::function<ssize_t(Stream &, Headers &)> const &writer) {
|
|
cli_->set_header_writer(writer);
|
|
}
|
|
|
|
// Passes the address family through to cli_->set_address_family.
inline void Client::set_address_family(int family) {
  auto &impl = *cli_;
  impl.set_address_family(family);
}
|
|
|
|
// Passes the flag through to cli_->set_tcp_nodelay.
inline void Client::set_tcp_nodelay(bool on) {
  auto &impl = *cli_;
  impl.set_tcp_nodelay(on);
}
|
|
|
|
// Moves the socket-options callback into cli_->set_socket_options.
inline void Client::set_socket_options(SocketOptions socket_options) {
  auto &impl = *cli_;
  impl.set_socket_options(std::move(socket_options));
}
|
|
|
|
// Passes (sec, usec) through to cli_->set_connection_timeout.
inline void Client::set_connection_timeout(time_t sec, time_t usec) {
  auto &impl = *cli_;
  impl.set_connection_timeout(sec, usec);
}
|
|
|
|
// Passes (sec, usec) through to cli_->set_read_timeout.
inline void Client::set_read_timeout(time_t sec, time_t usec) {
  auto &impl = *cli_;
  impl.set_read_timeout(sec, usec);
}
|
|
|
|
// Passes (sec, usec) through to cli_->set_write_timeout.
inline void Client::set_write_timeout(time_t sec, time_t usec) {
  auto &impl = *cli_;
  impl.set_write_timeout(sec, usec);
}
|
|
|
|
// Passes the credentials through to cli_->set_basic_auth.
inline void Client::set_basic_auth(const std::string &username,
                                   const std::string &password) {
  auto &impl = *cli_;
  impl.set_basic_auth(username, password);
}
|
|
// Passes the token through to cli_->set_bearer_token_auth.
inline void Client::set_bearer_token_auth(const std::string &token) {
  auto &impl = *cli_;
  impl.set_bearer_token_auth(token);
}
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
// Passes the credentials through to cli_->set_digest_auth.
inline void Client::set_digest_auth(const std::string &username,
                                    const std::string &password) {
  auto &impl = *cli_;
  impl.set_digest_auth(username, password);
}
|
|
#endif
|
|
|
|
// Passes the flag through to cli_->set_keep_alive.
inline void Client::set_keep_alive(bool on) {
  auto &impl = *cli_;
  impl.set_keep_alive(on);
}
|
|
// Passes the flag through to cli_->set_follow_location.
inline void Client::set_follow_location(bool on) {
  auto &impl = *cli_;
  impl.set_follow_location(on);
}
|
|
|
|
// Passes the flag through to cli_->set_path_encode.
inline void Client::set_path_encode(bool on) {
  auto &impl = *cli_;
  impl.set_path_encode(on);
}
|
|
|
|
[[deprecated("Use set_path_encode instead")]]
|
|
inline void Client::set_url_encode(bool on) {
|
|
cli_->set_path_encode(on);
|
|
}
|
|
|
|
// Passes the flag through to cli_->set_compress.
inline void Client::set_compress(bool on) {
  auto &impl = *cli_;
  impl.set_compress(on);
}
|
|
|
|
// Passes the flag through to cli_->set_decompress.
inline void Client::set_decompress(bool on) {
  auto &impl = *cli_;
  impl.set_decompress(on);
}
|
|
|
|
// Passes the interface string through to cli_->set_interface.
inline void Client::set_interface(const std::string &intf) {
  auto &impl = *cli_;
  impl.set_interface(intf);
}
|
|
|
|
// Passes the proxy host and port through to cli_->set_proxy.
inline void Client::set_proxy(const std::string &host, int port) {
  auto &impl = *cli_;
  impl.set_proxy(host, port);
}
|
|
// Passes the proxy credentials through to cli_->set_proxy_basic_auth.
inline void Client::set_proxy_basic_auth(const std::string &username,
                                         const std::string &password) {
  auto &impl = *cli_;
  impl.set_proxy_basic_auth(username, password);
}
|
|
// Passes the proxy token through to cli_->set_proxy_bearer_token_auth.
inline void Client::set_proxy_bearer_token_auth(const std::string &token) {
  auto &impl = *cli_;
  impl.set_proxy_bearer_token_auth(token);
}
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
// Passes the proxy credentials through to cli_->set_proxy_digest_auth.
inline void Client::set_proxy_digest_auth(const std::string &username,
                                          const std::string &password) {
  auto &impl = *cli_;
  impl.set_proxy_digest_auth(username, password);
}
|
|
#endif
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
// Passes the flag through to cli_->enable_server_certificate_verification.
inline void Client::enable_server_certificate_verification(bool enabled) {
  auto &impl = *cli_;
  impl.enable_server_certificate_verification(enabled);
}
|
|
|
|
// Passes the flag through to cli_->enable_server_hostname_verification.
inline void Client::enable_server_hostname_verification(bool enabled) {
  auto &impl = *cli_;
  impl.enable_server_hostname_verification(enabled);
}
|
|
|
|
inline void Client::set_server_certificate_verifier(
|
|
std::function<SSLVerifierResponse(SSL *ssl)> verifier) {
|
|
cli_->set_server_certificate_verifier(verifier);
|
|
}
|
|
#endif
|
|
|
|
// Moves the logger callback into cli_->set_logger.
inline void Client::set_logger(Logger logger) {
  auto &impl = *cli_;
  impl.set_logger(std::move(logger));
}
|
|
|
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
|
// Passes the CA certificate file and directory paths through to
// cli_->set_ca_cert_path.
inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path,
                                     const std::string &ca_cert_dir_path) {
  auto &impl = *cli_;
  impl.set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path);
}
|
|
|
|
// Installs the CA certificate store. When is_ssl_ is set the call is made
// through an SSLClient reference; otherwise through cli_ directly.
inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) {
  if (!is_ssl_) {
    cli_->set_ca_cert_store(ca_cert_store);
    return;
  }
  auto &ssl_impl = static_cast<SSLClient &>(*cli_);
  ssl_impl.set_ca_cert_store(ca_cert_store);
}
|
|
|
|
// Builds a store from the `size` bytes at `ca_cert` via
// cli_->create_ca_cert_store and installs it with set_ca_cert_store.
inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) {
  auto *store = cli_->create_ca_cert_store(ca_cert, size);
  set_ca_cert_store(store);
}
|
|
|
|
inline long Client::get_openssl_verify_result() const {
|
|
if (is_ssl_) {
|
|
return static_cast<SSLClient &>(*cli_).get_openssl_verify_result();
|
|
}
|
|
return -1; // NOTE: -1 doesn't match any of X509_V_ERR_???
|
|
}
|
|
|
|
// Returns the SSLClient's SSL_CTX when is_ssl_ is set, else nullptr.
inline SSL_CTX *Client::ssl_context() const {
  if (!is_ssl_) { return nullptr; }

  auto &ssl_impl = static_cast<SSLClient &>(*cli_);
  return ssl_impl.ssl_context();
}
|
|
#endif
|
|
|
|
// ----------------------------------------------------------------------------
|
|
|
|
} // namespace httplib
|
|
|
|
#endif // CPPHTTPLIB_HTTPLIB_H
|
|
|
|
// End of httplib.h
|
|
|
|
// Start of main.c
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <string.h>
|
|
#include <stdbool.h>
|
|
#include <curl/curl.h>
|
|
|
|
#ifdef _WIN32
|
|
#include <windows.h>
|
|
#else
|
|
#include <unistd.h>
|
|
#endif
|
|
|
|
#include "wren.h"
|
|
#include "requests_backend.c"
|
|
#include "socket_backend.c"
|
|
|
|
// --- Main-loop termination flag ---
// Set by hostSignalDone() and by errorFn() on runtime errors; main() polls it
// to decide when to stop resuming the Wren fiber.
static volatile bool g_mainFiberIsDone = false;
|
|
|
|
// --- Foreign function for Wren to signal the host to exit ---
|
|
// Foreign method body: raises the global done flag. The VM argument is unused.
void hostSignalDone(WrenVM* vm) {
    g_mainFiberIsDone = true;
    (void)vm;
}
|
|
|
|
// --- File/VM Setup ---
|
|
/*
 * Reads the whole file at `path` into a freshly allocated, NUL-terminated
 * buffer.
 *
 * Returns NULL when the file cannot be opened, its size cannot be determined
 * (fseek/ftell failure), allocation fails, or fewer bytes than the reported
 * size could be read. On success the caller owns the buffer and must free()
 * it.
 */
static char* readFile(const char* path) {
    char* buffer = NULL;
    long reportedSize = 0;
    size_t fileSize = 0;

    FILE* file = fopen(path, "rb");
    if (file == NULL) return NULL;

    // ftell() reports failure as -1L. Converting that straight to size_t would
    // wrap to SIZE_MAX and make the allocation below far too small for the read.
    if (fseek(file, 0L, SEEK_END) != 0) goto done;
    reportedSize = ftell(file);
    if (reportedSize < 0) goto done;
    rewind(file);

    fileSize = (size_t)reportedSize;
    buffer = (char*)malloc(fileSize + 1);
    if (buffer == NULL) goto done;

    if (fread(buffer, sizeof(char), fileSize, file) < fileSize) {
        // Short read: treat as failure rather than returning partial content.
        free(buffer);
        buffer = NULL;
        goto done;
    }
    buffer[fileSize] = '\0';

done:
    fclose(file);
    return buffer;
}
|
|
|
|
// Wren write callback: prints `text` to stdout unchanged. The VM is unused.
static void writeFn(WrenVM* vm, const char* text) {
    (void)vm;
    printf("%s", text);
}
|
|
|
|
// Wren error callback: writes a formatted message to stderr depending on the
// error type. A runtime error additionally raises g_mainFiberIsDone so the
// main loop stops. Any other type value is ignored.
static void errorFn(WrenVM* vm, WrenErrorType type, const char* module, int line, const char* message) {
    (void)vm;

    if (type == WREN_ERROR_COMPILE) {
        fprintf(stderr, "[%s line %d] [Error] %s\n", module, line, message);
    } else if (type == WREN_ERROR_RUNTIME) {
        fprintf(stderr, "[Runtime Error] %s\n", message);
        g_mainFiberIsDone = true; // Stop on runtime errors
    } else if (type == WREN_ERROR_STACK_TRACE) {
        fprintf(stderr, "[%s line %d] in %s\n", module, line, message);
    }
}
|
|
|
|
// Module-load completion callback: frees the source buffer carried in
// `result`. free(NULL) is a no-op, so no NULL guard is required.
static void onModuleComplete(WrenVM* vm, const char* name, WrenLoadModuleResult result) {
    (void)name;
    (void)vm;
    free((void*)result.source);
}
|
|
|
|
// Module loader: maps module `name` to the file "<name>.wren", read relative
// to the current working directory.
//
// Returns a zeroed result when the path does not fit the local buffer or the
// file cannot be read. On success `source` holds the heap buffer from
// readFile() and `onComplete` is set to onModuleComplete, which frees it.
static WrenLoadModuleResult loadModule(WrenVM* vm, const char* name) {
    (void)vm;
    WrenLoadModuleResult result = {0};

    char path[256];
    int written = snprintf(path, sizeof(path), "%s.wren", name);
    // A truncated path would name a different file than requested, so treat
    // truncation (or an encoding error) as "module not found".
    if (written < 0 || (size_t)written >= sizeof(path)) {
        return result;
    }

    char* source = readFile(path);
    if (source != NULL) {
        result.source = source;
        result.onComplete = onModuleComplete;
    }
    return result;
}
|
|
|
|
// --- Combined Foreign Function Binders ---
|
|
// Resolves a foreign method by module:
//   "socket"   -> bindSocketForeignMethod
//   "requests" -> bindForeignMethod
//   "main"     -> only the static Host.signalDone() is provided here
// Anything else yields NULL.
WrenForeignMethodFn combinedBindForeignMethod(WrenVM* vm, const char* module, const char* className, bool isStatic, const char* signature) {
    if (strcmp(module, "socket") == 0) {
        return bindSocketForeignMethod(vm, module, className, isStatic, signature);
    }

    if (strcmp(module, "requests") == 0) {
        return bindForeignMethod(vm, module, className, isStatic, signature);
    }

    // Host-specific methods live on the static side of main.Host.
    if (strcmp(module, "main") != 0) return NULL;
    if (strcmp(className, "Host") != 0) return NULL;
    if (!isStatic) return NULL;

    if (strcmp(signature, "signalDone()") == 0) {
        return hostSignalDone;
    }
    return NULL;
}
|
|
|
|
// Resolves foreign class allocators by module: "socket" and "requests" are
// delegated to their backends; every other module gets a zeroed method table.
WrenForeignClassMethods combinedBindForeignClass(WrenVM* vm, const char* module, const char* className) {
    const bool isSocketModule = (strcmp(module, "socket") == 0);
    if (isSocketModule) {
        return bindSocketForeignClass(vm, module, className);
    }

    const bool isRequestsModule = (strcmp(module, "requests") == 0);
    if (isRequestsModule) {
        return bindForeignClass(vm, module, className);
    }

    WrenForeignClassMethods methods = {0, 0};
    return methods;
}
|
|
|
|
|
|
// --- Main Application Entry Point ---
|
|
int main(int argc, char* argv[]) {
|
|
if (argc < 2) {
|
|
fprintf(stderr, "Usage: %s <script.wren>\n", argv[0]);
|
|
return 1;
|
|
}
|
|
|
|
// Initialize libcurl for the requests module
|
|
curl_global_init(CURL_GLOBAL_ALL);
|
|
|
|
WrenConfiguration config;
|
|
wrenInitConfiguration(&config);
|
|
config.writeFn = writeFn;
|
|
config.errorFn = errorFn;
|
|
config.bindForeignMethodFn = combinedBindForeignMethod;
|
|
config.bindForeignClassFn = combinedBindForeignClass;
|
|
config.loadModuleFn = loadModule;
|
|
|
|
WrenVM* vm = wrenNewVM(&config);
|
|
|
|
// ** Initialize BOTH managers **
|
|
socketManager_create(vm);
|
|
httpManager_create(vm);
|
|
|
|
char* mainSource = readFile(argv[1]);
|
|
if (!mainSource) {
|
|
fprintf(stderr, "Could not open script: %s\n", argv[1]);
|
|
socketManager_destroy();
|
|
httpManager_destroy();
|
|
wrenFreeVM(vm);
|
|
curl_global_cleanup();
|
|
return 1;
|
|
}
|
|
|
|
wrenInterpret(vm, "main", mainSource);
|
|
free(mainSource);
|
|
|
|
if (g_mainFiberIsDone) {
|
|
socketManager_destroy();
|
|
httpManager_destroy();
|
|
wrenFreeVM(vm);
|
|
curl_global_cleanup();
|
|
return 1;
|
|
}
|
|
|
|
wrenEnsureSlots(vm, 1);
|
|
wrenGetVariable(vm, "main", "mainFiber", 0);
|
|
WrenHandle* mainFiberHandle = wrenGetSlotHandle(vm, 0);
|
|
WrenHandle* callHandle = wrenMakeCallHandle(vm, "call()");
|
|
|
|
// === Main Event Loop ===
|
|
while (!g_mainFiberIsDone) {
|
|
// ** Process completions for BOTH managers **
|
|
socketManager_processCompletions();
|
|
httpManager_processCompletions();
|
|
|
|
// Resume the main Wren fiber
|
|
wrenEnsureSlots(vm, 1);
|
|
wrenSetSlotHandle(vm, 0, mainFiberHandle);
|
|
WrenInterpretResult result = wrenCall(vm, callHandle);
|
|
if (result == WREN_RESULT_RUNTIME_ERROR) {
|
|
g_mainFiberIsDone = true;
|
|
}
|
|
|
|
// Prevent 100% CPU usage
|
|
#ifdef _WIN32
|
|
Sleep(1);
|
|
#else
|
|
usleep(1000); // 1ms
|
|
#endif
|
|
}
|
|
|
|
// Process any final completions before shutting down
|
|
socketManager_processCompletions();
|
|
httpManager_processCompletions();
|
|
|
|
wrenReleaseHandle(vm, mainFiberHandle);
|
|
wrenReleaseHandle(vm, callHandle);
|
|
|
|
// ** Destroy BOTH managers **
|
|
socketManager_destroy();
|
|
httpManager_destroy();
|
|
|
|
wrenFreeVM(vm);
|
|
curl_global_cleanup();
|
|
|
|
printf("\nHost application finished.\n");
|
|
return 0;
|
|
}
|
|
|
|
// End of main.c
|
|
|
|
// Start of requests_backend.c
|
|
// requests_backend.c (Corrected; formerly http_backend.c)
|
|
#include "wren.h"
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <string.h>
|
|
#include <curl/curl.h>
|
|
|
|
// Platform aliases for the thread, mutex and condition-variable types used by
// the queue and manager structures below.
#ifdef _WIN32
#include <windows.h>
typedef HANDLE thread_t;
typedef CRITICAL_SECTION mutex_t;
typedef CONDITION_VARIABLE cond_t;
#else
#include <pthread.h>
typedef pthread_t thread_t;
typedef pthread_mutex_t mutex_t;
typedef pthread_cond_t cond_t;
#endif
|
|
|
|
// --- Data Structures ---
|
|
|
|
// Result record for an HTTP response.
// NOTE(review): not referenced anywhere in the visible part of this file;
// the field meanings below are inferred from the names, so confirm them.
typedef struct {
    int isError;      // presumably non-zero when the request failed
    long statusCode;  // presumably the HTTP status code
    char* body;       // presumably the response payload
    size_t body_len;  // presumably the payload length in bytes
} ResponseData;
|
|
|
|
// Growable byte buffer filled by write_memory_callback().
typedef struct {
    char* memory;  // heap buffer, kept NUL-terminated by the callback
    size_t size;   // bytes stored, excluding the trailing NUL
} MemoryStruct;
|
|
|
|
typedef struct HttpContext {
|
|
WrenVM* vm;
|
|
WrenHandle* callback;
|
|
|
|
char* url;
|
|
char* method;
|
|
char* body;
|
|
struct curl_slist* headers;
|
|
|
|
bool success;
|
|
char* response_body;
|
|
size_t response_body_len;
|
|
long status_code;
|
|
char* error_message;
|
|
struct HttpContext* next;
|
|
} HttpContext;
|
|
|
|
|
|
// --- Thread-Safe Queue ---
|
|
|
|
typedef struct {
|
|
HttpContext *head, *tail;
|
|
mutex_t mutex;
|
|
cond_t cond;
|
|
} ThreadSafeQueue;
|
|
|
|
// Puts the queue into the empty state and initializes its mutex and condition
// variable with the platform's primitives.
void http_queue_init(ThreadSafeQueue* q) {
    q->tail = NULL;
    q->head = NULL;

#ifdef _WIN32
    InitializeCriticalSection(&q->mutex);
    InitializeConditionVariable(&q->cond);
#else
    pthread_mutex_init(&q->mutex, NULL);
    pthread_cond_init(&q->cond, NULL);
#endif
}
|
|
|
|
// Releases the queue's synchronization primitives. Nodes still linked in the
// queue are not freed here. On Windows only the critical section is deleted;
// no call is made for the condition variable.
void http_queue_destroy(ThreadSafeQueue* q) {
#ifndef _WIN32
    pthread_mutex_destroy(&q->mutex);
    pthread_cond_destroy(&q->cond);
#else
    DeleteCriticalSection(&q->mutex);
#endif
}
|
|
|
|
// Appends `context` to the tail of the queue and wakes one waiter.
//
// A NULL `context` links nothing and only signals the condition variable.
// Previously a NULL push on a non-empty queue reset `tail` to NULL while
// `head` still pointed at queued nodes, so the next push overwrote `head`
// and the queued nodes were lost.
void http_queue_push(ThreadSafeQueue* q, HttpContext* context) {
#ifdef _WIN32
    EnterCriticalSection(&q->mutex);
#else
    pthread_mutex_lock(&q->mutex);
#endif

    if (context != NULL) {
        context->next = NULL;
        if (q->tail != NULL) {
            q->tail->next = context;
        } else {
            q->head = context;
        }
        q->tail = context;
    }

#ifdef _WIN32
    WakeConditionVariable(&q->cond);
    LeaveCriticalSection(&q->mutex);
#else
    pthread_cond_signal(&q->cond);
    pthread_mutex_unlock(&q->mutex);
#endif
}
|
|
|
|
// Removes and returns the node at the head of the queue, blocking on the
// condition variable for as long as the queue is empty.
HttpContext* http_queue_pop(ThreadSafeQueue* q) {
#ifdef _WIN32
    EnterCriticalSection(&q->mutex);
#else
    pthread_mutex_lock(&q->mutex);
#endif

    // Re-check after every wake-up: the loop only exits once a node is present.
    while (q->head == NULL) {
#ifdef _WIN32
        SleepConditionVariableCS(&q->cond, &q->mutex, INFINITE);
#else
        pthread_cond_wait(&q->cond, &q->mutex);
#endif
    }

    HttpContext* front = q->head;
    q->head = front->next;
    if (q->head == NULL) {
        q->tail = NULL;
    }

#ifdef _WIN32
    LeaveCriticalSection(&q->mutex);
#else
    pthread_mutex_unlock(&q->mutex);
#endif

    return front;
}
|
|
|
|
// Reports, under the queue mutex, whether the queue currently has no nodes.
bool http_queue_empty(ThreadSafeQueue* q) {
    bool isEmpty;

#ifdef _WIN32
    EnterCriticalSection(&q->mutex);
#else
    pthread_mutex_lock(&q->mutex);
#endif

    isEmpty = (q->head == NULL);

#ifdef _WIN32
    LeaveCriticalSection(&q->mutex);
#else
    pthread_mutex_unlock(&q->mutex);
#endif

    return isEmpty;
}
|
|
|
|
|
|
// --- libcurl Helpers ---
|
|
static size_t write_memory_callback(void *contents, size_t size, size_t nmemb, void *userp) {
|
|
size_t realsize = size * nmemb;
|
|
MemoryStruct *mem = (MemoryStruct *)userp;
|
|
char *ptr = (char*)realloc(mem->memory, mem->size + realsize + 1);
|
|
if (ptr == NULL) return 0;
|
|
mem->memory = ptr;
|
|
memcpy(&(mem->memory[mem->size]), contents, realsize);
|
|
mem->size += realsize;
|
|
mem->memory[mem->size] = 0;
|
|
return realsize;
|
|
}
|
|
|
|
// --- Async HTTP Manager ---
|
|
|
|
typedef struct {
|
|
WrenVM* vm;
|
|
volatile bool running;
|
|
thread_t threads[4];
|
|
ThreadSafeQueue requestQueue;
|
|
ThreadSafeQueue completionQueue;
|
|
} AsyncHttpManager;
|
|
|
|
static AsyncHttpManager* httpManager = NULL;
|
|
|
|
// Releases a request context together with every buffer it owns: URL, method,
// body, header list, response body and error message. A NULL `context` is a
// no-op. The Wren callback handle stored in the context is not released here.
void free_http_context(HttpContext* context) {
    if (context == NULL) {
        return;
    }

    // Request side.
    free(context->url);
    free(context->method);
    free(context->body);
    curl_slist_free_all(context->headers);

    // Result side.
    free(context->response_body);
    free(context->error_message);

    free(context);
}
|
|
|
|
#ifdef _WIN32
|
|
DWORD WINAPI httpWorkerThread(LPVOID arg) {
|
|
#else
|
|
void* httpWorkerThread(void* arg) {
|
|
#endif
|
|
AsyncHttpManager* manager = (AsyncHttpManager*)arg;
|
|
while (manager->running) {
|
|
HttpContext* context = http_queue_pop(&manager->requestQueue);
|
|
if (!context || !manager->running) {
|
|
if (context) free_http_context(context);
|
|
break;
|
|
}
|
|
|
|
CURL *curl = curl_easy_init();
|
|
if (curl) {
|
|
MemoryStruct chunk;
|
|
chunk.memory = (char*)malloc(1);
|
|
chunk.size = 0;
|
|
|
|
curl_easy_setopt(curl, CURLOPT_URL, context->url);
|
|
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_memory_callback);
|
|
curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void *)&chunk);
|
|
curl_easy_setopt(curl, CURLOPT_USERAGENT, "wren-curl-agent/1.0");
|
|
|
|
if (strcmp(context->method, "POST") == 0) {
|
|
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, context->body);
|
|
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, context->headers);
|
|
}
|
|
|
|
CURLcode res = curl_easy_perform(curl);
|
|
|
|
if (res == CURLE_OK) {
|
|
context->success = true;
|
|
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &context->status_code);
|
|
context->response_body = chunk.memory;
|
|
context->response_body_len = chunk.size;
|
|
} else {
|
|
context->success = false;
|
|
context->status_code = -1;
|
|
context->error_message = strdup(curl_easy_strerror(res));
|
|
free(chunk.memory);
|
|
}
|
|
curl_easy_cleanup(curl);
|
|
} else {
|
|
context->success = false;
|
|
context->error_message = strdup("Failed to initialize cURL handle.");
|
|
}
|
|
http_queue_push(&manager->completionQueue, context);
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// Creates the global HTTP manager for `vm`: allocates it, initialises both
// queues and starts one worker thread per slot of `threads`.
//
// If the manager cannot be allocated, nothing is started and the global
// `httpManager` is left unchanged.
//
// NOTE(review): the result of thread creation is still not checked; the
// struct has no field to record how many workers actually started.
void httpManager_create(WrenVM* vm) {
    AsyncHttpManager* manager = (AsyncHttpManager*)malloc(sizeof(AsyncHttpManager));
    if (manager == NULL) {
        return;
    }

    manager->vm = vm;
    manager->running = true;
    http_queue_init(&manager->requestQueue);
    http_queue_init(&manager->completionQueue);

    // Publish the fully initialised manager before any worker runs.
    httpManager = manager;

    const size_t workerCount = sizeof manager->threads / sizeof manager->threads[0];
    for (size_t i = 0; i < workerCount; ++i) {
#ifdef _WIN32
        manager->threads[i] = CreateThread(NULL, 0, httpWorkerThread, manager, 0, NULL);
#else
        pthread_create(&manager->threads[i], NULL, httpWorkerThread, manager);
#endif
    }
}
|
|
|
|
void httpManager_destroy() {
|
|
httpManager->running = false;
|
|
for (int i = 0; i < 4; ++i) {
|
|
http_queue_push(&httpManager->requestQueue, NULL);
|
|
}
|
|
for (int i = 0; i < 4; ++i) {
|
|
#ifdef _WIN32
|
|
WaitForSingleObject(httpManager->threads[i], INFINITE);
|
|
CloseHandle(httpManager->threads[i]);
|
|
#else
|
|
pthread_join(httpManager->threads[i], NULL);
|
|
#endif
|
|
}
|
|
http_queue_destroy(&httpManager->requestQueue);
|
|
http_queue_destroy(&httpManager->completionQueue);
|
|
free(httpManager);
|
|
}
|
|
|
|
void httpManager_processCompletions() {
|
|
while (!http_queue_empty(&httpManager->completionQueue)) {
|
|
HttpContext* context = http_queue_pop(&httpManager->completionQueue);
|
|
|
|
WrenHandle* callHandle = wrenMakeCallHandle(httpManager->vm, "call(_,_)");
|
|
wrenEnsureSlots(httpManager->vm, 3);
|
|
wrenSetSlotHandle(httpManager->vm, 0, context->callback);
|
|
|
|
if (context->success) {
|
|
wrenSetSlotNull(httpManager->vm, 1);
|
|
|
|
wrenGetVariable(httpManager->vm, "requests", "Response", 2);
|
|
void* foreign = wrenSetSlotNewForeign(httpManager->vm, 2, 2, sizeof(ResponseData));
|
|
ResponseData* data = (ResponseData*)foreign;
|
|
data->isError = false;
|
|
data->statusCode = context->status_code;
|
|
data->body = context->response_body;
|
|
data->body_len = context->response_body_len;
|
|
context->response_body = NULL;
|
|
} else {
|
|
wrenSetSlotString(httpManager->vm, 1, context->error_message);
|
|
wrenSetSlotNull(httpManager->vm, 2);
|
|
}
|
|
|
|
wrenCall(httpManager->vm, callHandle);
|
|
wrenReleaseHandle(httpManager->vm, context->callback);
|
|
wrenReleaseHandle(httpManager->vm, callHandle);
|
|
free_http_context(context);
|
|
}
|
|
}
|
|
|
|
// Queues `context` for the worker threads by appending it to the global
// manager's request queue.
void httpManager_submit(HttpContext* context) {
    ThreadSafeQueue* pending = &httpManager->requestQueue;
    http_queue_push(pending, context);
}
|
|
|
|
// --- Wren Foreign Methods ---
|
|
|
|
void responseFinalize(void* data) {
|
|
ResponseData* response = (ResponseData*)data;
|
|
free(response->body);
|
|
}
|
|
|
|
// Allocator for the `Response` foreign class: creates the foreign object in
// slot 0 and gives its ResponseData an empty state (no error, status 0, no
// body).
void responseAllocate(WrenVM* vm) {
    void* storage = wrenSetSlotNewForeign(vm, 0, 0, sizeof(ResponseData));
    ResponseData* fresh = (ResponseData*)storage;

    fresh->body = NULL;
    fresh->body_len = 0;
    fresh->statusCode = 0;
    fresh->isError = 0;
}
|
|
|
|
// `Response.isError` getter: returns the receiver's error flag as a Wren bool.
void responseIsError(WrenVM* vm) {
    ResponseData* self = (ResponseData*)wrenGetSlotForeign(vm, 0);
    if (self->isError) {
        wrenSetSlotBool(vm, 0, true);
    } else {
        wrenSetSlotBool(vm, 0, false);
    }
}
|
|
|
|
// `Response.statusCode` getter: returns the stored status code as a Wren
// number.
void responseStatusCode(WrenVM* vm) {
    ResponseData* self = (ResponseData*)wrenGetSlotForeign(vm, 0);
    double code = (double)self->statusCode;
    wrenSetSlotDouble(vm, 0, code);
}
|
|
|
|
// `Response.body` getter: returns the response body as a Wren string of
// `body_len` bytes. A missing body buffer is replaced by "".
void responseBody(WrenVM* vm) {
    ResponseData* self = (ResponseData*)wrenGetSlotForeign(vm, 0);

    const char* bytes = self->body;
    if (bytes == NULL) {
        bytes = "";
    }
    wrenSetSlotBytes(vm, 0, bytes, self->body_len);
}
|
|
|
|
// `Response.json()`: returns the raw response body bytes as a Wren string.
// No JSON parsing happens in this function.
void responseJson(WrenVM* vm) {
    ResponseData* self = (ResponseData*)wrenGetSlotForeign(vm, 0);

    const char* payload = (self->body != NULL) ? self->body : "";
    wrenSetSlotBytes(vm, 0, payload, self->body_len);
}
|
|
|
|
// Foreign method `Requests.get_(_,_,_)`: queues an asynchronous GET request.
//
// Slots read: 1 = URL string, 3 = callback (captured as a handle that is
// released once the completion has been delivered). Slot 2 is not read here.
//
// If the request context or one of its strings cannot be allocated, nothing
// is queued and the calling fiber is aborted with an error message.
void requestsGet(WrenVM* vm) {
    HttpContext* context = (HttpContext*)calloc(1, sizeof(HttpContext));
    if (context != NULL) {
        context->vm = vm;
        context->method = strdup("GET");
        context->url = strdup(wrenGetSlotString(vm, 1));
    }

    if (context == NULL || context->method == NULL || context->url == NULL) {
        free_http_context(context);  // accepts NULL
        wrenSetSlotString(vm, 0, "Out of memory while creating HTTP request.");
        wrenAbortFiber(vm, 0);
        return;
    }

    // Capture the callback last so no handle exists on the failure path.
    context->callback = wrenGetSlotHandle(vm, 3);
    httpManager_submit(context);
}
|
|
|
|
// Foreign method `Requests.post_(_,_,_,_,_)`: queues an asynchronous POST
// request.
//
// Slots read: 1 = URL string, 2 = body string, 3 = content type (sent as a
// "Content-Type: ..." header), 5 = callback (captured as a handle that is
// released once the completion has been delivered). Slot 4 is not read here.
//
// If the request context, one of its strings or the header list cannot be
// allocated, nothing is queued and the calling fiber is aborted with an
// error message.
void requestsPost(WrenVM* vm) {
    HttpContext* context = (HttpContext*)calloc(1, sizeof(HttpContext));
    if (context != NULL) {
        context->vm = vm;
        context->method = strdup("POST");
        context->url = strdup(wrenGetSlotString(vm, 1));
        context->body = strdup(wrenGetSlotString(vm, 2));

        // Longer content types are truncated to fit the buffer.
        const char* contentType = wrenGetSlotString(vm, 3);
        char contentTypeHeader[256];
        snprintf(contentTypeHeader, sizeof(contentTypeHeader), "Content-Type: %s", contentType);
        context->headers = curl_slist_append(NULL, contentTypeHeader);
    }

    if (context == NULL || context->method == NULL || context->url == NULL ||
        context->body == NULL || context->headers == NULL) {
        free_http_context(context);  // accepts NULL
        wrenSetSlotString(vm, 0, "Out of memory while creating HTTP request.");
        wrenAbortFiber(vm, 0);
        return;
    }

    // Capture the callback last so no handle exists on the failure path.
    context->callback = wrenGetSlotHandle(vm, 5);
    httpManager_submit(context);
}
|
|
|
|
// --- FFI Binding Functions ---
|
|
|
|
// Resolves a foreign method declared in the "requests" Wren module to its C
// implementation. Returns NULL for any other module and for class/signature
// combinations that are not listed below.
WrenForeignMethodFn bindForeignMethod(WrenVM* vm, const char* module,
    const char* className, bool isStatic, const char* signature) {
    static const struct {
        const char* className;
        bool isStatic;
        const char* signature;
        WrenForeignMethodFn fn;
    } bindings[] = {
        { "Requests", true,  "get_(_,_,_)",      requestsGet },
        { "Requests", true,  "post_(_,_,_,_,_)", requestsPost },
        { "Response", false, "isError",          responseIsError },
        { "Response", false, "statusCode",       responseStatusCode },
        { "Response", false, "body",             responseBody },
        { "Response", false, "json()",           responseJson },
    };

    if (strcmp(module, "requests") != 0) return NULL;

    for (size_t i = 0; i < sizeof bindings / sizeof bindings[0]; ++i) {
        if (bindings[i].isStatic == isStatic &&
            strcmp(bindings[i].className, className) == 0 &&
            strcmp(bindings[i].signature, signature) == 0) {
            return bindings[i].fn;
        }
    }

    return NULL;
}
|
|
|
|
// Supplies the allocate/finalize callbacks for foreign classes. Only
// `Response` in module "requests" has callbacks; every other class gets a
// pair of null pointers.
WrenForeignClassMethods bindForeignClass(WrenVM* vm, const char* module, const char* className) {
    WrenForeignClassMethods methods = { NULL, NULL };

    bool isResponse = strcmp(module, "requests") == 0 &&
                      strcmp(className, "Response") == 0;
    if (!isResponse) {
        return methods;
    }

    methods.allocate = responseAllocate;
    methods.finalize = responseFinalize;
    return methods;
}
|
|
|
|
// End of requests_backend.c
|
|
|
|
// Start of wren.h
|
|
#ifndef wren_h
|
|
#define wren_h
|
|
|
|
#include <stdarg.h>
|
|
#include <stdlib.h>
|
|
#include <stdbool.h>
|
|
|
|
// Components of Wren's semantic version.
#define WREN_VERSION_MAJOR 0
#define WREN_VERSION_MINOR 4
#define WREN_VERSION_PATCH 0

// The same version as a human-friendly string.
#define WREN_VERSION_STRING "0.4.0"

// The version folded into one monotonically increasing integer
// (major * 1000000 + minor * 1000 + patch). Use it for range checks over
// versions.
#define WREN_VERSION_NUMBER \
  ((WREN_VERSION_MAJOR * 1000 + WREN_VERSION_MINOR) * 1000 + WREN_VERSION_PATCH)
|
|
|
|
// Export annotation placed on every public API declaration. It is only
// defined here when the build has not already provided one: MSVC builds that
// define WREN_API_DLLEXPORT get `__declspec( dllexport )`, everything else
// gets an empty macro.
#if !defined(WREN_API)
  #if defined(_MSC_VER) && defined(WREN_API_DLLEXPORT)
    #define WREN_API __declspec( dllexport )
  #else
    #define WREN_API
  #endif
#endif // WREN_API
|
|
|
|
// A single virtual machine for executing Wren code.
|
|
//
|
|
// Wren has no global state, so all state stored by a running interpreter lives
|
|
// here.
|
|
typedef struct WrenVM WrenVM;
|
|
|
|
// A handle to a Wren object.
|
|
//
|
|
// This lets code outside of the VM hold a persistent reference to an object.
|
|
// After a handle is acquired, and until it is released, this ensures the
|
|
// garbage collector will not reclaim the object it references.
|
|
typedef struct WrenHandle WrenHandle;
|
|
|
|
// A generic allocation function that handles all explicit memory management
|
|
// used by Wren. It's used like so:
|
|
//
|
|
// - To allocate new memory, [memory] is NULL and [newSize] is the desired
|
|
// size. It should return the allocated memory or NULL on failure.
|
|
//
|
|
// - To attempt to grow an existing allocation, [memory] is the memory, and
|
|
// [newSize] is the desired size. It should return [memory] if it was able to
|
|
// grow it in place, or a new pointer if it had to move it.
|
|
//
|
|
// - To shrink memory, [memory] and [newSize] are the same as above but it will
|
|
// always return [memory].
|
|
//
|
|
// - To free memory, [memory] will be the memory to free and [newSize] will be
|
|
// zero. It should return NULL.
|
|
typedef void* (*WrenReallocateFn)(void* memory, size_t newSize, void* userData);
|
|
|
|
// A function callable from Wren code, but implemented in C.
|
|
typedef void (*WrenForeignMethodFn)(WrenVM* vm);
|
|
|
|
// A finalizer function for freeing resources owned by an instance of a foreign
|
|
// class. Unlike most foreign methods, finalizers do not have access to the VM
|
|
// and should not interact with it since it's in the middle of a garbage
|
|
// collection.
|
|
typedef void (*WrenFinalizerFn)(void* data);
|
|
|
|
// Gives the host a chance to canonicalize the imported module name,
|
|
// potentially taking into account the (previously resolved) name of the module
|
|
// that contains the import. Typically, this is used to implement relative
|
|
// imports.
|
|
typedef const char* (*WrenResolveModuleFn)(WrenVM* vm,
|
|
const char* importer, const char* name);
|
|
|
|
// Forward declare
|
|
struct WrenLoadModuleResult;
|
|
|
|
// Called after loadModuleFn is called for module [name]. The original returned result
|
|
// is handed back to you in this callback, so that you can free memory if appropriate.
|
|
typedef void (*WrenLoadModuleCompleteFn)(WrenVM* vm, const char* name, struct WrenLoadModuleResult result);
|
|
|
|
// The result of a loadModuleFn call.
|
|
// [source] is the source code for the module, or NULL if the module is not found.
|
|
// [onComplete] an optional callback that will be called once Wren is done with the result.
|
|
typedef struct WrenLoadModuleResult
|
|
{
|
|
const char* source;
|
|
WrenLoadModuleCompleteFn onComplete;
|
|
void* userData;
|
|
} WrenLoadModuleResult;
|
|
|
|
// Loads and returns the source code for the module [name].
|
|
typedef WrenLoadModuleResult (*WrenLoadModuleFn)(WrenVM* vm, const char* name);
|
|
|
|
// Returns a pointer to a foreign method on [className] in [module] with
|
|
// [signature].
|
|
typedef WrenForeignMethodFn (*WrenBindForeignMethodFn)(WrenVM* vm,
|
|
const char* module, const char* className, bool isStatic,
|
|
const char* signature);
|
|
|
|
// Displays a string of text to the user.
|
|
typedef void (*WrenWriteFn)(WrenVM* vm, const char* text);
|
|
|
|
// The kinds of report passed to a WrenErrorFn.
typedef enum {
  // Compile time: a syntax or resolution error.
  WREN_ERROR_COMPILE,

  // Run time: the error message of a runtime error.
  WREN_ERROR_RUNTIME,

  // Run time: a single entry of the runtime error's stack trace.
  WREN_ERROR_STACK_TRACE
} WrenErrorType;
|
|
|
|
// Reports an error to the user.
|
|
//
|
|
// An error detected during compile time is reported by calling this once with
|
|
// [type] `WREN_ERROR_COMPILE`, the resolved name of the [module] and [line]
|
|
// where the error occurs, and the compiler's error [message].
|
|
//
|
|
// A runtime error is reported by calling this once with [type]
|
|
// `WREN_ERROR_RUNTIME`, no [module] or [line], and the runtime error's
|
|
// [message]. After that, a series of [type] `WREN_ERROR_STACK_TRACE` calls are
|
|
// made for each line in the stack trace. Each of those has the resolved
|
|
// [module] and [line] where the method or function is defined and [message] is
|
|
// the name of the method or function.
|
|
typedef void (*WrenErrorFn)(
|
|
WrenVM* vm, WrenErrorType type, const char* module, int line,
|
|
const char* message);
|
|
|
|
typedef struct
|
|
{
|
|
// The callback invoked when the foreign object is created.
|
|
//
|
|
// This must be provided. Inside the body of this, it must call
|
|
// [wrenSetSlotNewForeign()] exactly once.
|
|
WrenForeignMethodFn allocate;
|
|
|
|
// The callback invoked when the garbage collector is about to collect a
|
|
// foreign object's memory.
|
|
//
|
|
// This may be `NULL` if the foreign class does not need to finalize.
|
|
WrenFinalizerFn finalize;
|
|
} WrenForeignClassMethods;
|
|
|
|
// Returns a pair of pointers to the foreign methods used to allocate and
|
|
// finalize the data for instances of [className] in resolved [module].
|
|
typedef WrenForeignClassMethods (*WrenBindForeignClassFn)(
|
|
WrenVM* vm, const char* module, const char* className);
|
|
|
|
typedef struct
|
|
{
|
|
// The callback Wren will use to allocate, reallocate, and deallocate memory.
|
|
//
|
|
// If `NULL`, defaults to a built-in function that uses `realloc` and `free`.
|
|
WrenReallocateFn reallocateFn;
|
|
|
|
// The callback Wren uses to resolve a module name.
|
|
//
|
|
// Some host applications may wish to support "relative" imports, where the
|
|
// meaning of an import string depends on the module that contains it. To
|
|
// support that without baking any policy into Wren itself, the VM gives the
|
|
// host a chance to resolve an import string.
|
|
//
|
|
// Before an import is loaded, it calls this, passing in the name of the
|
|
// module that contains the import and the import string. The host app can
|
|
// look at both of those and produce a new "canonical" string that uniquely
|
|
// identifies the module. This string is then used as the name of the module
|
|
// going forward. It is what is passed to [loadModuleFn], how duplicate
|
|
// imports of the same module are detected, and how the module is reported in
|
|
// stack traces.
|
|
//
|
|
// If you leave this function NULL, then the original import string is
|
|
// treated as the resolved string.
|
|
//
|
|
// If an import cannot be resolved by the embedder, it should return NULL and
|
|
// Wren will report that as a runtime error.
|
|
//
|
|
// Wren will take ownership of the string you return and free it for you, so
|
|
// it should be allocated using the same allocation function you provide
|
|
// above.
|
|
WrenResolveModuleFn resolveModuleFn;
|
|
|
|
// The callback Wren uses to load a module.
|
|
//
|
|
// Since Wren does not talk directly to the file system, it relies on the
|
|
// embedder to physically locate and read the source code for a module. The
|
|
// first time an import appears, Wren will call this and pass in the name of
|
|
// the module being imported. The method will return a result, which contains
|
|
// the source code for that module. Memory for the source is owned by the
|
|
// host application, and can be freed using the onComplete callback.
|
|
//
|
|
// This will only be called once for any given module name. Wren caches the
|
|
// result internally so subsequent imports of the same module will use the
|
|
// previous source and not call this.
|
|
//
|
|
// If a module with the given name could not be found by the embedder, it
|
|
// should return NULL and Wren will report that as a runtime error.
|
|
WrenLoadModuleFn loadModuleFn;
|
|
|
|
// The callback Wren uses to find a foreign method and bind it to a class.
|
|
//
|
|
// When a foreign method is declared in a class, this will be called with the
|
|
// foreign method's module, class, and signature when the class body is
|
|
// executed. It should return a pointer to the foreign function that will be
|
|
// bound to that method.
|
|
//
|
|
// If the foreign function could not be found, this should return NULL and
|
|
// Wren will report it as runtime error.
|
|
WrenBindForeignMethodFn bindForeignMethodFn;
|
|
|
|
// The callback Wren uses to find a foreign class and get its foreign methods.
|
|
//
|
|
// When a foreign class is declared, this will be called with the class's
|
|
// module and name when the class body is executed. It should return the
|
|
// foreign functions uses to allocate and (optionally) finalize the bytes
|
|
// stored in the foreign object when an instance is created.
|
|
WrenBindForeignClassFn bindForeignClassFn;
|
|
|
|
// The callback Wren uses to display text when `System.print()` or the other
|
|
// related functions are called.
|
|
//
|
|
// If this is `NULL`, Wren discards any printed text.
|
|
WrenWriteFn writeFn;
|
|
|
|
// The callback Wren uses to report errors.
|
|
//
|
|
// When an error occurs, this will be called with the module name, line
|
|
// number, and an error message. If this is `NULL`, Wren doesn't report any
|
|
// errors.
|
|
WrenErrorFn errorFn;
|
|
|
|
// The number of bytes Wren will allocate before triggering the first garbage
|
|
// collection.
|
|
//
|
|
// If zero, defaults to 10MB.
|
|
size_t initialHeapSize;
|
|
|
|
// After a collection occurs, the threshold for the next collection is
|
|
// determined based on the number of bytes remaining in use. This allows Wren
|
|
// to shrink its memory usage automatically after reclaiming a large amount
|
|
// of memory.
|
|
//
|
|
// This can be used to ensure that the heap does not get too small, which can
|
|
// in turn lead to a large number of collections afterwards as the heap grows
|
|
// back to a usable size.
|
|
//
|
|
// If zero, defaults to 1MB.
|
|
size_t minHeapSize;
|
|
|
|
// Wren will resize the heap automatically as the number of bytes
|
|
// remaining in use after a collection changes. This number determines the
|
|
// amount of additional memory Wren will use after a collection, as a
|
|
// percentage of the current heap size.
|
|
//
|
|
// For example, say that this is 50. After a garbage collection, when there
|
|
// are 400 bytes of memory still in use, the next collection will be triggered
|
|
// after a total of 600 bytes are allocated (including the 400 already in
|
|
// use.)
|
|
//
|
|
// Setting this to a smaller number wastes less memory, but triggers more
|
|
// frequent garbage collections.
|
|
//
|
|
// If zero, defaults to 50.
|
|
int heapGrowthPercent;
|
|
|
|
// User-defined data associated with the VM.
|
|
void* userData;
|
|
|
|
} WrenConfiguration;
|
|
|
|
// Result codes: success, compile error, or runtime error.
typedef enum {
  WREN_RESULT_SUCCESS,
  WREN_RESULT_COMPILE_ERROR,
  WREN_RESULT_RUNTIME_ERROR
} WrenInterpretResult;
|
|
|
|
// The type of an object stored in a slot.
//
// Note that this is the object's low-level representation type, which is not
// necessarily the same thing as its *class*.
typedef enum {
  WREN_TYPE_BOOL,
  WREN_TYPE_NUM,
  WREN_TYPE_FOREIGN,
  WREN_TYPE_LIST,
  WREN_TYPE_MAP,
  WREN_TYPE_NULL,
  WREN_TYPE_STRING,

  // The object has a type that the C API cannot access.
  WREN_TYPE_UNKNOWN
} WrenType;
|
|
|
|
// Get the current wren version number.
|
|
//
|
|
// Can be used to do range checks over versions.
|
|
WREN_API int wrenGetVersionNumber();
|
|
|
|
// Initializes [configuration] with all of its default values.
|
|
//
|
|
// Call this before setting the particular fields you care about.
|
|
WREN_API void wrenInitConfiguration(WrenConfiguration* configuration);
|
|
|
|
// Creates a new Wren virtual machine using the given [configuration]. Wren
|
|
// will copy the configuration data, so the argument passed to this can be
|
|
// freed after calling this. If [configuration] is `NULL`, uses a default
|
|
// configuration.
|
|
WREN_API WrenVM* wrenNewVM(WrenConfiguration* configuration);
|
|
|
|
// Disposes of all resources in use by [vm], which was previously created by a
|
|
// call to [wrenNewVM].
|
|
WREN_API void wrenFreeVM(WrenVM* vm);
|
|
|
|
// Immediately run the garbage collector to free unused memory.
|
|
WREN_API void wrenCollectGarbage(WrenVM* vm);
|
|
|
|
// Runs [source], a string of Wren source code in a new fiber in [vm] in the
|
|
// context of resolved [module].
|
|
WREN_API WrenInterpretResult wrenInterpret(WrenVM* vm, const char* module,
|
|
const char* source);
|
|
|
|
// Creates a handle that can be used to invoke a method with [signature]
|
|
// using a receiver and arguments that are set up on the stack.
|
|
//
|
|
// This handle can be used repeatedly to directly invoke that method from C
|
|
// code using [wrenCall].
|
|
//
|
|
// When you are done with this handle, it must be released using
|
|
// [wrenReleaseHandle].
|
|
WREN_API WrenHandle* wrenMakeCallHandle(WrenVM* vm, const char* signature);
|
|
|
|
// Calls [method], using the receiver and arguments previously set up on the
|
|
// stack.
|
|
//
|
|
// [method] must have been created by a call to [wrenMakeCallHandle]. The
|
|
// arguments to the method must be already on the stack. The receiver should be
|
|
// in slot 0 with the remaining arguments following it, in order. It is an
|
|
// error if the number of arguments provided does not match the method's
|
|
// signature.
|
|
//
|
|
// After this returns, you can access the return value from slot 0 on the stack.
|
|
WREN_API WrenInterpretResult wrenCall(WrenVM* vm, WrenHandle* method);
|
|
|
|
// Releases the reference stored in [handle]. After calling this, [handle] can
|
|
// no longer be used.
|
|
WREN_API void wrenReleaseHandle(WrenVM* vm, WrenHandle* handle);
|
|
|
|
// The following functions are intended to be called from foreign methods or
|
|
// finalizers. The interface Wren provides to a foreign method is like a
|
|
// register machine: you are given a numbered array of slots that values can be
|
|
// read from and written to. Values always live in a slot (unless explicitly
|
|
// captured using wrenGetSlotHandle()), which ensures the garbage collector can
|
|
// find them.
|
|
//
|
|
// When your foreign function is called, you are given one slot for the receiver
|
|
// and each argument to the method. The receiver is in slot 0 and the arguments
|
|
// are in increasingly numbered slots after that. You are free to read and
|
|
// write to those slots as you want. If you want more slots to use as scratch
|
|
// space, you can call wrenEnsureSlots() to add more.
|
|
//
|
|
// When your function returns, every slot except slot zero is discarded and the
|
|
// value in slot zero is used as the return value of the method. If you don't
|
|
// store a return value in that slot yourself, it will retain its previous
|
|
// value, the receiver.
|
|
//
|
|
// While Wren is dynamically typed, C is not. This means the C interface has to
|
|
// support the various types of primitive values a Wren variable can hold: bool,
|
|
// double, string, etc. If we supported this for every operation in the C API,
|
|
// there would be a combinatorial explosion of functions, like "get a
|
|
// double-valued element from a list", "insert a string key and double value
|
|
// into a map", etc.
|
|
//
|
|
// To avoid that, the only way to convert to and from a raw C value is by going
|
|
// into and out of a slot. All other functions work with values already in a
|
|
// slot. So, to add an element to a list, you put the list in one slot, and the
|
|
// element in another. Then there is a single API function wrenInsertInList()
|
|
// that takes the element out of that slot and puts it into the list.
|
|
//
|
|
// The goal of this API is to be easy to use while not compromising performance.
|
|
// The latter means it does not do type or bounds checking at runtime except
|
|
// using assertions which are generally removed from release builds. C is an
|
|
// unsafe language, so it's up to you to be careful to use it correctly. In
|
|
// return, you get a very fast FFI.
|
|
|
|
// Returns the number of slots available to the current foreign method.
|
|
WREN_API int wrenGetSlotCount(WrenVM* vm);
|
|
|
|
// Ensures that the foreign method stack has at least [numSlots] available for
|
|
// use, growing the stack if needed.
|
|
//
|
|
// Does not shrink the stack if it has more than enough slots.
|
|
//
|
|
// It is an error to call this from a finalizer.
|
|
WREN_API void wrenEnsureSlots(WrenVM* vm, int numSlots);
|
|
|
|
// Gets the type of the object in [slot].
|
|
WREN_API WrenType wrenGetSlotType(WrenVM* vm, int slot);
|
|
|
|
// Reads a boolean value from [slot].
|
|
//
|
|
// It is an error to call this if the slot does not contain a boolean value.
|
|
WREN_API bool wrenGetSlotBool(WrenVM* vm, int slot);
|
|
|
|
// Reads a byte array from [slot].
|
|
//
|
|
// The memory for the returned string is owned by Wren. You can inspect it
|
|
// while in your foreign method, but cannot keep a pointer to it after the
|
|
// function returns, since the garbage collector may reclaim it.
|
|
//
|
|
// Returns a pointer to the first byte of the array and fill [length] with the
|
|
// number of bytes in the array.
|
|
//
|
|
// It is an error to call this if the slot does not contain a string.
|
|
WREN_API const char* wrenGetSlotBytes(WrenVM* vm, int slot, int* length);
|
|
|
|
// Reads a number from [slot].
|
|
//
|
|
// It is an error to call this if the slot does not contain a number.
|
|
WREN_API double wrenGetSlotDouble(WrenVM* vm, int slot);
|
|
|
|
// Reads a foreign object from [slot] and returns a pointer to the foreign data
|
|
// stored with it.
|
|
//
|
|
// It is an error to call this if the slot does not contain an instance of a
|
|
// foreign class.
|
|
WREN_API void* wrenGetSlotForeign(WrenVM* vm, int slot);
|
|
|
|
// Reads a string from [slot].
|
|
//
|
|
// The memory for the returned string is owned by Wren. You can inspect it
|
|
// while in your foreign method, but cannot keep a pointer to it after the
|
|
// function returns, since the garbage collector may reclaim it.
|
|
//
|
|
// It is an error to call this if the slot does not contain a string.
|
|
WREN_API const char* wrenGetSlotString(WrenVM* vm, int slot);
|
|
|
|
// Creates a handle for the value stored in [slot].
|
|
//
|
|
// This will prevent the object that is referred to from being garbage collected
|
|
// until the handle is released by calling [wrenReleaseHandle()].
|
|
WREN_API WrenHandle* wrenGetSlotHandle(WrenVM* vm, int slot);
|
|
|
|
// Stores the boolean [value] in [slot].
|
|
WREN_API void wrenSetSlotBool(WrenVM* vm, int slot, bool value);
|
|
|
|
// Stores the array [length] of [bytes] in [slot].
|
|
//
|
|
// The bytes are copied to a new string within Wren's heap, so you can free
|
|
// memory used by them after this is called.
|
|
WREN_API void wrenSetSlotBytes(WrenVM* vm, int slot, const char* bytes, size_t length);
|
|
|
|
// Stores the numeric [value] in [slot].
|
|
WREN_API void wrenSetSlotDouble(WrenVM* vm, int slot, double value);
|
|
|
|
// Creates a new instance of the foreign class stored in [classSlot] with [size]
|
|
// bytes of raw storage and places the resulting object in [slot].
|
|
//
|
|
// This does not invoke the foreign class's constructor on the new instance. If
|
|
// you need that to happen, call the constructor from Wren, which will then
|
|
// call the allocator foreign method. In there, call this to create the object
|
|
// and then the constructor will be invoked when the allocator returns.
|
|
//
|
|
// Returns a pointer to the foreign object's data.
|
|
WREN_API void* wrenSetSlotNewForeign(WrenVM* vm, int slot, int classSlot, size_t size);
|
|
|
|
// Stores a new empty list in [slot].
|
|
WREN_API void wrenSetSlotNewList(WrenVM* vm, int slot);
|
|
|
|
// Stores a new empty map in [slot].
|
|
WREN_API void wrenSetSlotNewMap(WrenVM* vm, int slot);
|
|
|
|
// Stores null in [slot].
|
|
WREN_API void wrenSetSlotNull(WrenVM* vm, int slot);
|
|
|
|
// Stores the string [text] in [slot].
|
|
//
|
|
// The [text] is copied to a new string within Wren's heap, so you can free
|
|
// memory used by it after this is called. The length is calculated using
|
|
// [strlen()]. If the string may contain any null bytes in the middle, then you
|
|
// should use [wrenSetSlotBytes()] instead.
|
|
WREN_API void wrenSetSlotString(WrenVM* vm, int slot, const char* text);
|
|
|
|
// Stores the value captured in [handle] in [slot].
|
|
//
|
|
// This does not release the handle for the value.
|
|
WREN_API void wrenSetSlotHandle(WrenVM* vm, int slot, WrenHandle* handle);
|
|
|
|
// Returns the number of elements in the list stored in [slot].
|
|
WREN_API int wrenGetListCount(WrenVM* vm, int slot);
|
|
|
|
// Reads element [index] from the list in [listSlot] and stores it in
|
|
// [elementSlot].
|
|
WREN_API void wrenGetListElement(WrenVM* vm, int listSlot, int index, int elementSlot);
|
|
|
|
// Sets the value stored at [index] in the list at [listSlot],
|
|
// to the value from [elementSlot].
|
|
WREN_API void wrenSetListElement(WrenVM* vm, int listSlot, int index, int elementSlot);
|
|
|
|
// Takes the value stored at [elementSlot] and inserts it into the list stored
|
|
// at [listSlot] at [index].
|
|
//
|
|
// As in Wren, negative indexes can be used to insert from the end. To append
|
|
// an element, use `-1` for the index.
|
|
WREN_API void wrenInsertInList(WrenVM* vm, int listSlot, int index, int elementSlot);
|
|
|
|
// Returns the number of entries in the map stored in [slot].
|
|
WREN_API int wrenGetMapCount(WrenVM* vm, int slot);
|
|
|
|
// Returns true if the key in [keySlot] is found in the map placed in [mapSlot].
|
|
WREN_API bool wrenGetMapContainsKey(WrenVM* vm, int mapSlot, int keySlot);
|
|
|
|
// Retrieves a value with the key in [keySlot] from the map in [mapSlot] and
|
|
// stores it in [valueSlot].
|
|
WREN_API void wrenGetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot);
|
|
|
|
// Takes the value stored at [valueSlot] and inserts it into the map stored
|
|
// at [mapSlot] with key [keySlot].
|
|
WREN_API void wrenSetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot);
|
|
|
|
// Removes a value from the map in [mapSlot], with the key from [keySlot],
|
|
// and place it in [removedValueSlot]. If not found, [removedValueSlot] is
|
|
// set to null, the same behaviour as the Wren Map API.
|
|
WREN_API void wrenRemoveMapValue(WrenVM* vm, int mapSlot, int keySlot,
|
|
int removedValueSlot);
|
|
|
|
// Looks up the top level variable with [name] in resolved [module] and stores
|
|
// it in [slot].
|
|
WREN_API void wrenGetVariable(WrenVM* vm, const char* module, const char* name,
|
|
int slot);
|
|
|
|
// Looks up the top level variable with [name] in resolved [module],
|
|
// returns false if not found. The module must be imported at the time,
|
|
// use wrenHasModule to ensure that before calling.
|
|
WREN_API bool wrenHasVariable(WrenVM* vm, const char* module, const char* name);
|
|
|
|
// Returns true if [module] has been imported/resolved before, false if not.
|
|
WREN_API bool wrenHasModule(WrenVM* vm, const char* module);
|
|
|
|
// Sets the current fiber to be aborted, and uses the value in [slot] as the
|
|
// runtime error object.
|
|
WREN_API void wrenAbortFiber(WrenVM* vm, int slot);
|
|
|
|
// Returns the user data associated with the WrenVM.
|
|
WREN_API void* wrenGetUserData(WrenVM* vm);
|
|
|
|
// Sets user data associated with the WrenVM.
|
|
WREN_API void wrenSetUserData(WrenVM* vm, void* userData);
|
|
|
|
#endif
|
|
|
|
// End of wren.h
|
|
|
|
// Start of async_http.c
|
|
#include <atomic>
#include <string>
#include <thread>
#include <vector>

#include "httplib.h"
#include "wren.h"
|
|
|
|
// A struct to hold the context for an asynchronous HTTP request
|
|
struct RequestContext {
|
|
std::string url;
|
|
WrenHandle* callback;
|
|
WrenVM* vm;
|
|
std::string response;
|
|
bool error;
|
|
};
|
|
|
|
// A class to manage asynchronous HTTP requests
|
|
class AsyncHttp {
|
|
public:
|
|
AsyncHttp(WrenVM* vm) : vm_(vm), running_(true) {
|
|
// Create a pool of worker threads
|
|
for (int i = 0; i < 4; ++i) {
|
|
threads_.emplace_back([this] {
|
|
while (running_) {
|
|
RequestContext* context = requestQueue_.pop();
|
|
if (!running_) break;
|
|
|
|
httplib::Client cli("http://example.com");
|
|
if (auto res = cli.Get(context->url.c_str())) {
|
|
context->response = res->body;
|
|
context->error = false;
|
|
} else {
|
|
context->response = "Error: " + to_string(res.error());
|
|
context->error = true;
|
|
}
|
|
|
|
completionQueue_.push(context);
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
~AsyncHttp() {
|
|
running_ = false;
|
|
// Add dummy requests to unblock worker threads
|
|
for (size_t i = 0; i < threads_.size(); ++i) {
|
|
requestQueue_.push(nullptr);
|
|
}
|
|
for (auto& thread : threads_) {
|
|
thread.join();
|
|
}
|
|
}
|
|
|
|
void request(const std::string& url, WrenHandle* callback) {
|
|
RequestContext* context = new RequestContext{url, callback, vm_};
|
|
requestQueue_.push(context);
|
|
}
|
|
|
|
void processCompletions() {
|
|
while (!completionQueue_.empty()) {
|
|
RequestContext* context = completionQueue_.pop();
|
|
|
|
// Create a handle for the callback function
|
|
WrenHandle* callHandle = wrenMakeCallHandle(vm_, "call(_)");
|
|
|
|
wrenEnsureSlots(vm_, 2);
|
|
wrenSetSlotHandle(vm_, 0, context->callback);
|
|
wrenSetSlotString(vm_, 1, context->response.c_str());
|
|
wrenCall(vm_, callHandle);
|
|
|
|
wrenReleaseHandle(vm_, callHandle);
|
|
wrenReleaseHandle(vm_, context->callback);
|
|
delete context;
|
|
}
|
|
}
|
|
|
|
private:
|
|
WrenVM* vm_;
|
|
bool running_;
|
|
std::vector<std::thread> threads_;
|
|
ThreadSafeQueue<RequestContext*> requestQueue_;
|
|
ThreadSafeQueue<RequestContext*> completionQueue_;
|
|
};
|
|
|
|
// End of async_http.c
|
|
|
|
// Start of wren.c
|
|
// MIT License
|
|
//
|
|
// Copyright (c) 2013-2021 Robert Nystrom and Wren Contributors
|
|
//
|
|
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
// of this software and associated documentation files (the "Software"), to deal
|
|
// in the Software without restriction, including without limitation the rights
|
|
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
// copies of the Software, and to permit persons to whom the Software is
|
|
// furnished to do so, subject to the following conditions:
|
|
//
|
|
// The above copyright notice and this permission notice shall be included in all
|
|
// copies or substantial portions of the Software.
|
|
//
|
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
// SOFTWARE.
|
|
|
|
// Begin file "wren.h"
|
|
#ifndef wren_h
|
|
#define wren_h
|
|
|
|
#include <stdarg.h>
|
|
#include <stdlib.h>
|
|
#include <stdbool.h>
|
|
|
|
// Semantic version of this Wren release, one macro per component.
#define WREN_VERSION_MAJOR 0
#define WREN_VERSION_MINOR 4
#define WREN_VERSION_PATCH 0

// The same version written out as display text.
#define WREN_VERSION_STRING "0.4.0"

// The version packed into a single integer
// (major * 1000000 + minor * 1000 + patch). The value grows with every
// release, so versions can be range-checked with ordinary comparisons.
#define WREN_VERSION_NUMBER \
  ((WREN_VERSION_MAJOR * 1000000) + (WREN_VERSION_MINOR * 1000) + WREN_VERSION_PATCH)
|
|
|
|
// Decoration placed in front of the public function declarations in this
// header. A definition supplied before this point is kept as is. Otherwise the
// macro expands to `__declspec( dllexport )` when compiling with MSVC and
// WREN_API_DLLEXPORT is defined, and to nothing in every other case.
#if !defined(WREN_API)
  #if !defined(_MSC_VER) || !defined(WREN_API_DLLEXPORT)
    #define WREN_API
  #else
    #define WREN_API __declspec( dllexport )
  #endif
#endif // WREN_API
|
|
|
|
// A single virtual machine for executing Wren code.
|
|
//
|
|
// Wren has no global state, so all state stored by a running interpreter lives
|
|
// here.
|
|
typedef struct WrenVM WrenVM;
|
|
|
|
// A handle to a Wren object.
|
|
//
|
|
// This lets code outside of the VM hold a persistent reference to an object.
|
|
// After a handle is acquired, and until it is released, this ensures the
|
|
// garbage collector will not reclaim the object it references.
|
|
typedef struct WrenHandle WrenHandle;
|
|
|
|
// A generic allocation function that handles all explicit memory management
|
|
// used by Wren. It's used like so:
|
|
//
|
|
// - To allocate new memory, [memory] is NULL and [newSize] is the desired
|
|
// size. It should return the allocated memory or NULL on failure.
|
|
//
|
|
// - To attempt to grow an existing allocation, [memory] is the memory, and
|
|
// [newSize] is the desired size. It should return [memory] if it was able to
|
|
// grow it in place, or a new pointer if it had to move it.
|
|
//
|
|
// - To shrink memory, [memory] and [newSize] are the same as above but it will
|
|
// always return [memory].
|
|
//
|
|
// - To free memory, [memory] will be the memory to free and [newSize] will be
|
|
// zero. It should return NULL.
|
|
typedef void* (*WrenReallocateFn)(void* memory, size_t newSize, void* userData);
|
|
|
|
// A function callable from Wren code, but implemented in C.
|
|
typedef void (*WrenForeignMethodFn)(WrenVM* vm);
|
|
|
|
// A finalizer function for freeing resources owned by an instance of a foreign
|
|
// class. Unlike most foreign methods, finalizers do not have access to the VM
|
|
// and should not interact with it since it's in the middle of a garbage
|
|
// collection.
|
|
typedef void (*WrenFinalizerFn)(void* data);
|
|
|
|
// Gives the host a chance to canonicalize the imported module name,
|
|
// potentially taking into account the (previously resolved) name of the module
|
|
// that contains the import. Typically, this is used to implement relative
|
|
// imports.
|
|
typedef const char* (*WrenResolveModuleFn)(WrenVM* vm,
|
|
const char* importer, const char* name);
|
|
|
|
// Forward declare
|
|
struct WrenLoadModuleResult;
|
|
|
|
// Called after loadModuleFn is called for module [name]. The original returned result
|
|
// is handed back to you in this callback, so that you can free memory if appropriate.
|
|
typedef void (*WrenLoadModuleCompleteFn)(WrenVM* vm, const char* name, struct WrenLoadModuleResult result);
|
|
|
|
// The result of a loadModuleFn call.
|
|
// [source] is the source code for the module, or NULL if the module is not found.
|
|
// [onComplete] an optional callback that will be called once Wren is done with the result.
|
|
typedef struct WrenLoadModuleResult
|
|
{
|
|
const char* source;
|
|
WrenLoadModuleCompleteFn onComplete;
|
|
void* userData;
|
|
} WrenLoadModuleResult;
|
|
|
|
// Loads and returns the source code for the module [name].
|
|
typedef WrenLoadModuleResult (*WrenLoadModuleFn)(WrenVM* vm, const char* name);
|
|
|
|
// Returns a pointer to a foreign method on [className] in [module] with
|
|
// [signature].
|
|
typedef WrenForeignMethodFn (*WrenBindForeignMethodFn)(WrenVM* vm,
|
|
const char* module, const char* className, bool isStatic,
|
|
const char* signature);
|
|
|
|
// Displays a string of text to the user.
|
|
typedef void (*WrenWriteFn)(WrenVM* vm, const char* text);
|
|
|
|
// Identifies what kind of report is being passed to a WrenErrorFn.
typedef enum
{
  // Syntax or resolution error found while compiling.
  WREN_ERROR_COMPILE = 0,

  // The message of a runtime error.
  WREN_ERROR_RUNTIME = 1,

  // A single entry of the stack trace belonging to a runtime error.
  WREN_ERROR_STACK_TRACE = 2
} WrenErrorType;
|
|
|
|
// Reports an error to the user.
|
|
//
|
|
// An error detected during compile time is reported by calling this once with
|
|
// [type] `WREN_ERROR_COMPILE`, the resolved name of the [module] and [line]
|
|
// where the error occurs, and the compiler's error [message].
|
|
//
|
|
// A runtime error is reported by calling this once with [type]
|
|
// `WREN_ERROR_RUNTIME`, no [module] or [line], and the runtime error's
|
|
// [message]. After that, a series of [type] `WREN_ERROR_STACK_TRACE` calls are
|
|
// made for each line in the stack trace. Each of those has the resolved
|
|
// [module] and [line] where the method or function is defined and [message] is
|
|
// the name of the method or function.
|
|
typedef void (*WrenErrorFn)(
|
|
WrenVM* vm, WrenErrorType type, const char* module, int line,
|
|
const char* message);
|
|
|
|
typedef struct
|
|
{
|
|
// The callback invoked when the foreign object is created.
|
|
//
|
|
// This must be provided. Inside the body of this, it must call
|
|
// [wrenSetSlotNewForeign()] exactly once.
|
|
WrenForeignMethodFn allocate;
|
|
|
|
// The callback invoked when the garbage collector is about to collect a
|
|
// foreign object's memory.
|
|
//
|
|
// This may be `NULL` if the foreign class does not need to finalize.
|
|
WrenFinalizerFn finalize;
|
|
} WrenForeignClassMethods;
|
|
|
|
// Returns a pair of pointers to the foreign methods used to allocate and
|
|
// finalize the data for instances of [className] in resolved [module].
|
|
typedef WrenForeignClassMethods (*WrenBindForeignClassFn)(
|
|
WrenVM* vm, const char* module, const char* className);
|
|
|
|
typedef struct
|
|
{
|
|
// The callback Wren will use to allocate, reallocate, and deallocate memory.
|
|
//
|
|
// If `NULL`, defaults to a built-in function that uses `realloc` and `free`.
|
|
WrenReallocateFn reallocateFn;
|
|
|
|
// The callback Wren uses to resolve a module name.
|
|
//
|
|
// Some host applications may wish to support "relative" imports, where the
|
|
// meaning of an import string depends on the module that contains it. To
|
|
// support that without baking any policy into Wren itself, the VM gives the
|
|
// host a chance to resolve an import string.
|
|
//
|
|
// Before an import is loaded, it calls this, passing in the name of the
|
|
// module that contains the import and the import string. The host app can
|
|
// look at both of those and produce a new "canonical" string that uniquely
|
|
// identifies the module. This string is then used as the name of the module
|
|
// going forward. It is what is passed to [loadModuleFn], how duplicate
|
|
// imports of the same module are detected, and how the module is reported in
|
|
// stack traces.
|
|
//
|
|
// If you leave this function NULL, then the original import string is
|
|
// treated as the resolved string.
|
|
//
|
|
// If an import cannot be resolved by the embedder, it should return NULL and
|
|
// Wren will report that as a runtime error.
|
|
//
|
|
// Wren will take ownership of the string you return and free it for you, so
|
|
// it should be allocated using the same allocation function you provide
|
|
// above.
|
|
WrenResolveModuleFn resolveModuleFn;
|
|
|
|
// The callback Wren uses to load a module.
|
|
//
|
|
// Since Wren does not talk directly to the file system, it relies on the
|
|
// embedder to physically locate and read the source code for a module. The
|
|
// first time an import appears, Wren will call this and pass in the name of
|
|
// the module being imported. The method will return a result, which contains
|
|
// the source code for that module. Memory for the source is owned by the
|
|
// host application, and can be freed using the onComplete callback.
|
|
//
|
|
// This will only be called once for any given module name. Wren caches the
|
|
// result internally so subsequent imports of the same module will use the
|
|
// previous source and not call this.
|
|
//
|
|
// If a module with the given name could not be found by the embedder, it
|
|
// should return NULL and Wren will report that as a runtime error.
|
|
WrenLoadModuleFn loadModuleFn;
|
|
|
|
// The callback Wren uses to find a foreign method and bind it to a class.
|
|
//
|
|
// When a foreign method is declared in a class, this will be called with the
|
|
// foreign method's module, class, and signature when the class body is
|
|
// executed. It should return a pointer to the foreign function that will be
|
|
// bound to that method.
|
|
//
|
|
// If the foreign function could not be found, this should return NULL and
|
|
// Wren will report it as runtime error.
|
|
WrenBindForeignMethodFn bindForeignMethodFn;
|
|
|
|
// The callback Wren uses to find a foreign class and get its foreign methods.
|
|
//
|
|
// When a foreign class is declared, this will be called with the class's
|
|
// module and name when the class body is executed. It should return the
|
|
// foreign functions uses to allocate and (optionally) finalize the bytes
|
|
// stored in the foreign object when an instance is created.
|
|
WrenBindForeignClassFn bindForeignClassFn;
|
|
|
|
// The callback Wren uses to display text when `System.print()` or the other
|
|
// related functions are called.
|
|
//
|
|
// If this is `NULL`, Wren discards any printed text.
|
|
WrenWriteFn writeFn;
|
|
|
|
// The callback Wren uses to report errors.
|
|
//
|
|
// When an error occurs, this will be called with the module name, line
|
|
// number, and an error message. If this is `NULL`, Wren doesn't report any
|
|
// errors.
|
|
WrenErrorFn errorFn;
|
|
|
|
// The number of bytes Wren will allocate before triggering the first garbage
|
|
// collection.
|
|
//
|
|
// If zero, defaults to 10MB.
|
|
size_t initialHeapSize;
|
|
|
|
// After a collection occurs, the threshold for the next collection is
|
|
// determined based on the number of bytes remaining in use. This allows Wren
|
|
// to shrink its memory usage automatically after reclaiming a large amount
|
|
// of memory.
|
|
//
|
|
// This can be used to ensure that the heap does not get too small, which can
|
|
// in turn lead to a large number of collections afterwards as the heap grows
|
|
// back to a usable size.
|
|
//
|
|
// If zero, defaults to 1MB.
|
|
size_t minHeapSize;
|
|
|
|
// Wren will resize the heap automatically as the number of bytes
|
|
// remaining in use after a collection changes. This number determines the
|
|
// amount of additional memory Wren will use after a collection, as a
|
|
// percentage of the current heap size.
|
|
//
|
|
// For example, say that this is 50. After a garbage collection, when there
|
|
// are 400 bytes of memory still in use, the next collection will be triggered
|
|
// after a total of 600 bytes are allocated (including the 400 already in
|
|
// use.)
|
|
//
|
|
// Setting this to a smaller number wastes less memory, but triggers more
|
|
// frequent garbage collections.
|
|
//
|
|
// If zero, defaults to 50.
|
|
int heapGrowthPercent;
|
|
|
|
// User-defined data associated with the VM.
|
|
void* userData;
|
|
|
|
} WrenConfiguration;
|
|
|
|
// Result codes returned by wrenInterpret() and wrenCall().
typedef enum
{
  WREN_RESULT_SUCCESS = 0,
  WREN_RESULT_COMPILE_ERROR = 1,
  WREN_RESULT_RUNTIME_ERROR = 2
} WrenInterpretResult;
|
|
|
|
// The low-level representation type of the value held in a slot.
//
// This describes how the value is stored, which is not necessarily the same
// thing as the object's *class*.
typedef enum
{
  WREN_TYPE_BOOL = 0,
  WREN_TYPE_NUM = 1,
  WREN_TYPE_FOREIGN = 2,
  WREN_TYPE_LIST = 3,
  WREN_TYPE_MAP = 4,
  WREN_TYPE_NULL = 5,
  WREN_TYPE_STRING = 6,

  // The value has a type that the C API cannot access.
  WREN_TYPE_UNKNOWN = 7
} WrenType;
|
|
|
|
// Get the current wren version number.
|
|
//
|
|
// Can be used to range checks over versions.
|
|
WREN_API int wrenGetVersionNumber();
|
|
|
|
// Initializes [configuration] with all of its default values.
|
|
//
|
|
// Call this before setting the particular fields you care about.
|
|
WREN_API void wrenInitConfiguration(WrenConfiguration* configuration);
|
|
|
|
// Creates a new Wren virtual machine using the given [configuration]. Wren
|
|
// will copy the configuration data, so the argument passed to this can be
|
|
// freed after calling this. If [configuration] is `NULL`, uses a default
|
|
// configuration.
|
|
WREN_API WrenVM* wrenNewVM(WrenConfiguration* configuration);
|
|
|
|
// Disposes of all resources in use by [vm], which was previously created by a
|
|
// call to [wrenNewVM].
|
|
WREN_API void wrenFreeVM(WrenVM* vm);
|
|
|
|
// Immediately run the garbage collector to free unused memory.
|
|
WREN_API void wrenCollectGarbage(WrenVM* vm);
|
|
|
|
// Runs [source], a string of Wren source code in a new fiber in [vm] in the
|
|
// context of resolved [module].
|
|
WREN_API WrenInterpretResult wrenInterpret(WrenVM* vm, const char* module,
|
|
const char* source);
|
|
|
|
// Creates a handle that can be used to invoke a method with [signature]
|
|
// using a receiver and arguments that are set up on the stack.
|
|
//
|
|
// This handle can be used repeatedly to directly invoke that method from C
|
|
// code using [wrenCall].
|
|
//
|
|
// When you are done with this handle, it must be released using
|
|
// [wrenReleaseHandle].
|
|
WREN_API WrenHandle* wrenMakeCallHandle(WrenVM* vm, const char* signature);
|
|
|
|
// Calls [method], using the receiver and arguments previously set up on the
|
|
// stack.
|
|
//
|
|
// [method] must have been created by a call to [wrenMakeCallHandle]. The
|
|
// arguments to the method must be already on the stack. The receiver should be
|
|
// in slot 0 with the remaining arguments following it, in order. It is an
|
|
// error if the number of arguments provided does not match the method's
|
|
// signature.
|
|
//
|
|
// After this returns, you can access the return value from slot 0 on the stack.
|
|
WREN_API WrenInterpretResult wrenCall(WrenVM* vm, WrenHandle* method);
|
|
|
|
// Releases the reference stored in [handle]. After calling this, [handle] can
|
|
// no longer be used.
|
|
WREN_API void wrenReleaseHandle(WrenVM* vm, WrenHandle* handle);
|
|
|
|
// The following functions are intended to be called from foreign methods or
|
|
// finalizers. The interface Wren provides to a foreign method is like a
|
|
// register machine: you are given a numbered array of slots that values can be
|
|
// read from and written to. Values always live in a slot (unless explicitly
|
|
// captured using wrenGetSlotHandle()), which ensures the garbage collector can
|
|
// find them.
|
|
//
|
|
// When your foreign function is called, you are given one slot for the receiver
|
|
// and each argument to the method. The receiver is in slot 0 and the arguments
|
|
// are in increasingly numbered slots after that. You are free to read and
|
|
// write to those slots as you want. If you want more slots to use as scratch
|
|
// space, you can call wrenEnsureSlots() to add more.
|
|
//
|
|
// When your function returns, every slot except slot zero is discarded and the
|
|
// value in slot zero is used as the return value of the method. If you don't
|
|
// store a return value in that slot yourself, it will retain its previous
|
|
// value, the receiver.
|
|
//
|
|
// While Wren is dynamically typed, C is not. This means the C interface has to
|
|
// support the various types of primitive values a Wren variable can hold: bool,
|
|
// double, string, etc. If we supported this for every operation in the C API,
|
|
// there would be a combinatorial explosion of functions, like "get a
|
|
// double-valued element from a list", "insert a string key and double value
|
|
// into a map", etc.
|
|
//
|
|
// To avoid that, the only way to convert to and from a raw C value is by going
|
|
// into and out of a slot. All other functions work with values already in a
|
|
// slot. So, to add an element to a list, you put the list in one slot, and the
|
|
// element in another. Then there is a single API function wrenInsertInList()
|
|
// that takes the element out of that slot and puts it into the list.
|
|
//
|
|
// The goal of this API is to be easy to use while not compromising performance.
|
|
// The latter means it does not do type or bounds checking at runtime except
|
|
// using assertions which are generally removed from release builds. C is an
|
|
// unsafe language, so it's up to you to be careful to use it correctly. In
|
|
// return, you get a very fast FFI.
|
|
|
|
// Returns the number of slots available to the current foreign method.
|
|
WREN_API int wrenGetSlotCount(WrenVM* vm);
|
|
|
|
// Ensures that the foreign method stack has at least [numSlots] available for
|
|
// use, growing the stack if needed.
|
|
//
|
|
// Does not shrink the stack if it has more than enough slots.
|
|
//
|
|
// It is an error to call this from a finalizer.
|
|
WREN_API void wrenEnsureSlots(WrenVM* vm, int numSlots);
|
|
|
|
// Gets the type of the object in [slot].
|
|
WREN_API WrenType wrenGetSlotType(WrenVM* vm, int slot);
|
|
|
|
// Reads a boolean value from [slot].
|
|
//
|
|
// It is an error to call this if the slot does not contain a boolean value.
|
|
WREN_API bool wrenGetSlotBool(WrenVM* vm, int slot);
|
|
|
|
// Reads a byte array from [slot].
|
|
//
|
|
// The memory for the returned string is owned by Wren. You can inspect it
|
|
// while in your foreign method, but cannot keep a pointer to it after the
|
|
// function returns, since the garbage collector may reclaim it.
|
|
//
|
|
// Returns a pointer to the first byte of the array and fill [length] with the
|
|
// number of bytes in the array.
|
|
//
|
|
// It is an error to call this if the slot does not contain a string.
|
|
WREN_API const char* wrenGetSlotBytes(WrenVM* vm, int slot, int* length);
|
|
|
|
// Reads a number from [slot].
|
|
//
|
|
// It is an error to call this if the slot does not contain a number.
|
|
WREN_API double wrenGetSlotDouble(WrenVM* vm, int slot);
|
|
|
|
// Reads a foreign object from [slot] and returns a pointer to the foreign data
|
|
// stored with it.
|
|
//
|
|
// It is an error to call this if the slot does not contain an instance of a
|
|
// foreign class.
|
|
WREN_API void* wrenGetSlotForeign(WrenVM* vm, int slot);
|
|
|
|
// Reads a string from [slot].
|
|
//
|
|
// The memory for the returned string is owned by Wren. You can inspect it
|
|
// while in your foreign method, but cannot keep a pointer to it after the
|
|
// function returns, since the garbage collector may reclaim it.
|
|
//
|
|
// It is an error to call this if the slot does not contain a string.
|
|
WREN_API const char* wrenGetSlotString(WrenVM* vm, int slot);
|
|
|
|
// Creates a handle for the value stored in [slot].
|
|
//
|
|
// This will prevent the object that is referred to from being garbage collected
|
|
// until the handle is released by calling [wrenReleaseHandle()].
|
|
WREN_API WrenHandle* wrenGetSlotHandle(WrenVM* vm, int slot);
|
|
|
|
// Stores the boolean [value] in [slot].
|
|
WREN_API void wrenSetSlotBool(WrenVM* vm, int slot, bool value);
|
|
|
|
// Stores the array [length] of [bytes] in [slot].
|
|
//
|
|
// The bytes are copied to a new string within Wren's heap, so you can free
|
|
// memory used by them after this is called.
|
|
WREN_API void wrenSetSlotBytes(WrenVM* vm, int slot, const char* bytes, size_t length);
|
|
|
|
// Stores the numeric [value] in [slot].
|
|
WREN_API void wrenSetSlotDouble(WrenVM* vm, int slot, double value);
|
|
|
|
// Creates a new instance of the foreign class stored in [classSlot] with [size]
|
|
// bytes of raw storage and places the resulting object in [slot].
|
|
//
|
|
// This does not invoke the foreign class's constructor on the new instance. If
|
|
// you need that to happen, call the constructor from Wren, which will then
|
|
// call the allocator foreign method. In there, call this to create the object
|
|
// and then the constructor will be invoked when the allocator returns.
|
|
//
|
|
// Returns a pointer to the foreign object's data.
|
|
WREN_API void* wrenSetSlotNewForeign(WrenVM* vm, int slot, int classSlot, size_t size);
|
|
|
|
// Stores a new empty list in [slot].
|
|
WREN_API void wrenSetSlotNewList(WrenVM* vm, int slot);
|
|
|
|
// Stores a new empty map in [slot].
|
|
WREN_API void wrenSetSlotNewMap(WrenVM* vm, int slot);
|
|
|
|
// Stores null in [slot].
|
|
WREN_API void wrenSetSlotNull(WrenVM* vm, int slot);
|
|
|
|
// Stores the string [text] in [slot].
|
|
//
|
|
// The [text] is copied to a new string within Wren's heap, so you can free
|
|
// memory used by it after this is called. The length is calculated using
|
|
// [strlen()]. If the string may contain any null bytes in the middle, then you
|
|
// should use [wrenSetSlotBytes()] instead.
|
|
WREN_API void wrenSetSlotString(WrenVM* vm, int slot, const char* text);
|
|
|
|
// Stores the value captured in [handle] in [slot].
|
|
//
|
|
// This does not release the handle for the value.
|
|
WREN_API void wrenSetSlotHandle(WrenVM* vm, int slot, WrenHandle* handle);
|
|
|
|
// Returns the number of elements in the list stored in [slot].
|
|
WREN_API int wrenGetListCount(WrenVM* vm, int slot);
|
|
|
|
// Reads element [index] from the list in [listSlot] and stores it in
|
|
// [elementSlot].
|
|
WREN_API void wrenGetListElement(WrenVM* vm, int listSlot, int index, int elementSlot);
|
|
|
|
// Sets the value stored at [index] in the list at [listSlot],
|
|
// to the value from [elementSlot].
|
|
WREN_API void wrenSetListElement(WrenVM* vm, int listSlot, int index, int elementSlot);
|
|
|
|
// Takes the value stored at [elementSlot] and inserts it into the list stored
|
|
// at [listSlot] at [index].
|
|
//
|
|
// As in Wren, negative indexes can be used to insert from the end. To append
|
|
// an element, use `-1` for the index.
|
|
WREN_API void wrenInsertInList(WrenVM* vm, int listSlot, int index, int elementSlot);
|
|
|
|
// Returns the number of entries in the map stored in [slot].
|
|
WREN_API int wrenGetMapCount(WrenVM* vm, int slot);
|
|
|
|
// Returns true if the key in [keySlot] is found in the map placed in [mapSlot].
|
|
WREN_API bool wrenGetMapContainsKey(WrenVM* vm, int mapSlot, int keySlot);
|
|
|
|
// Retrieves a value with the key in [keySlot] from the map in [mapSlot] and
|
|
// stores it in [valueSlot].
|
|
WREN_API void wrenGetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot);
|
|
|
|
// Takes the value stored at [valueSlot] and inserts it into the map stored
|
|
// at [mapSlot] with key [keySlot].
|
|
WREN_API void wrenSetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot);
|
|
|
|
// Removes a value from the map in [mapSlot], with the key from [keySlot],
|
|
// and places it in [removedValueSlot]. If not found, [removedValueSlot] is
|
|
// set to null, the same behaviour as the Wren Map API.
|
|
WREN_API void wrenRemoveMapValue(WrenVM* vm, int mapSlot, int keySlot,
|
|
int removedValueSlot);
|
|
|
|
// Looks up the top level variable with [name] in resolved [module] and stores
|
|
// it in [slot].
|
|
WREN_API void wrenGetVariable(WrenVM* vm, const char* module, const char* name,
|
|
int slot);
|
|
|
|
// Looks up the top level variable with [name] in resolved [module],
|
|
// returns false if not found. The module must be imported at the time,
|
|
// use wrenHasModule to ensure that before calling.
|
|
WREN_API bool wrenHasVariable(WrenVM* vm, const char* module, const char* name);
|
|
|
|
// Returns true if [module] has been imported/resolved before, false if not.
|
|
WREN_API bool wrenHasModule(WrenVM* vm, const char* module);
|
|
|
|
// Sets the current fiber to be aborted, and uses the value in [slot] as the
|
|
// runtime error object.
|
|
WREN_API void wrenAbortFiber(WrenVM* vm, int slot);
|
|
|
|
// Returns the user data associated with the WrenVM.
|
|
WREN_API void* wrenGetUserData(WrenVM* vm);
|
|
|
|
// Sets user data associated with the WrenVM.
|
|
WREN_API void wrenSetUserData(WrenVM* vm, void* userData);
|
|
|
|
#endif
|
|
// End file "wren.h"
|
|
// Begin file "wren_debug.h"
|
|
#ifndef wren_debug_h
|
|
#define wren_debug_h
|
|
|
|
// Begin file "wren_value.h"
|
|
#ifndef wren_value_h
|
|
#define wren_value_h
|
|
|
|
#include <stdbool.h>
|
|
#include <string.h>
|
|
|
|
// Begin file "wren_common.h"
|
|
#ifndef wren_common_h
|
|
#define wren_common_h
|
|
|
|
// This header contains macros and defines used across the entire Wren
|
|
// implementation. In particular, it contains "configuration" defines that
|
|
// control how Wren works. Some of these are only used while hacking on Wren
|
|
// itself.
|
|
//
|
|
// This header is *not* intended to be included by code outside of Wren itself.
|
|
|
|
// Wren pervasively uses the C99 integer types (uint16_t, etc.) along with some
// of the associated limit constants (UINT32_MAX, etc.). The constants are not
// part of standard C++, so aren't included by default by C++ compilers when you
// include <stdint.h> unless __STDC_LIMIT_MACROS is defined.
//
// The define is guarded so that a build which already passes
// -D__STDC_LIMIT_MACROS does not get a macro-redefinition diagnostic here.
#ifndef __STDC_LIMIT_MACROS
  #define __STDC_LIMIT_MACROS
#endif
#include <stdint.h>
|
|
|
|
// These flags let you control some details of the interpreter's implementation.
|
|
// Usually they trade-off a bit of portability for speed. They default to the
|
|
// most efficient behavior.
|
|
|
|
// Selects the core value representation. When true, Wren stores values as
// NaN-tagged doubles; when false, it uses a larger, more conventional struct.
// NaN tagging is significantly faster and more compact, while the struct form
// is useful for debugging and may be more portable.
//
// On unless the build defines it otherwise.
#if !defined(WREN_NAN_TAGGING)
  #define WREN_NAN_TAGGING 1
#endif
|
|
|
|
// When true, the interpreter loop of the VM dispatches with computed gotos.
// Background:
// http://gcc.gnu.org/onlinedocs/gcc-3.1.1/gcc/Labels-as-Values.html
// This makes the main dispatch loop a bit faster but needs compiler support.
// For an alternative see https://bullno1.com/blog/switched-goto
//
// Unless the build defines it, this is on everywhere except Visual Studio
// (MSVC without clang), which has no computed gotos.
#if !defined(WREN_COMPUTED_GOTO)
  #if !defined(_MSC_VER) || defined(__clang__)
    #define WREN_COMPUTED_GOTO 1
  #else
    #define WREN_COMPUTED_GOTO 0
  #endif
#endif
|
|
|
|
// Optional modules bundled with the VM. Each one is available unless the build
// turns it off by defining the matching `WREN_OPT_<name>` as `0`.
#if !defined(WREN_OPT_META)
  #define WREN_OPT_META 1
#endif

#if !defined(WREN_OPT_RANDOM)
  #define WREN_OPT_RANDOM 1
#endif
|
|
|
|
// These flags are useful for debugging and hacking on Wren itself. They are not
// intended to be used for production code. They default to off.
//
// Each one is guarded with #ifndef, like the configuration flags above, so a
// debug build can switch it on from the compiler command line (for example
// -DWREN_DEBUG_TRACE_GC=1) without editing this file or triggering a
// macro-redefinition diagnostic.

// Set this to true to stress test the GC. It will perform a collection before
// every allocation. This is useful to ensure that memory is always correctly
// reachable.
#ifndef WREN_DEBUG_GC_STRESS
  #define WREN_DEBUG_GC_STRESS 0
#endif

// Set this to true to log memory operations as they occur.
#ifndef WREN_DEBUG_TRACE_MEMORY
  #define WREN_DEBUG_TRACE_MEMORY 0
#endif

// Set this to true to log garbage collections as they occur.
#ifndef WREN_DEBUG_TRACE_GC
  #define WREN_DEBUG_TRACE_GC 0
#endif

// Set this to true to print out the compiled bytecode of each function.
#ifndef WREN_DEBUG_DUMP_COMPILED_CODE
  #define WREN_DEBUG_DUMP_COMPILED_CODE 0
#endif

// Set this to true to trace each instruction as it's executed.
#ifndef WREN_DEBUG_TRACE_INSTRUCTIONS
  #define WREN_DEBUG_TRACE_INSTRUCTIONS 0
#endif
|
|
|
|
// The maximum number of module-level variables that may be defined at one time.
|
|
// This limitation comes from the 16 bits used for the arguments to
|
|
// `CODE_LOAD_MODULE_VAR` and `CODE_STORE_MODULE_VAR`.
|
|
#define MAX_MODULE_VARS 65536
|
|
|
|
// The maximum number of arguments that can be passed to a method. Note that
|
|
// this limitation is hardcoded in other places in the VM, in particular, the
|
|
// `CODE_CALL_XX` instructions assume a certain maximum number.
|
|
#define MAX_PARAMETERS 16
|
|
|
|
// The maximum length of a method name, not including the signature. This is an
|
|
// arbitrary but enforced maximum just so we know how long the method name
|
|
// strings need to be in the parser.
|
|
#define MAX_METHOD_NAME 64
|
|
|
|
// The maximum length of a method signature. Signatures look like:
|
|
//
|
|
// foo // Getter.
|
|
// foo() // No-argument method.
|
|
// foo(_) // One-argument method.
|
|
// foo(_,_) // Two-argument method.
|
|
// init foo() // Constructor initializer.
|
|
//
|
|
// The maximum signature length takes into account the longest method name, the
|
|
// maximum number of parameters with separators between them, "init ", and "()".
|
|
#define MAX_METHOD_SIGNATURE (MAX_METHOD_NAME + (MAX_PARAMETERS * 2) + 6)
|
|
|
|
// The maximum length of an identifier. The only real reason for this limitation
|
|
// is so that error messages mentioning variables can be stack allocated.
|
|
#define MAX_VARIABLE_NAME 64
|
|
|
|
// The maximum number of fields a class can have, including inherited fields.
|
|
// This is explicit in the bytecode since `CODE_CLASS` and `CODE_SUBCLASS` take
|
|
// a single byte for the number of fields. Note that it's 255 and not 256
|
|
// because creating a class takes the *number* of fields, not the *highest
|
|
// field index*.
|
|
#define MAX_FIELDS 255
|
|
|
|
// Use the VM's allocator to allocate an object of [type].
|
|
#define ALLOCATE(vm, type) \
|
|
((type*)wrenReallocate(vm, NULL, 0, sizeof(type)))
|
|
|
|
// Use the VM's allocator to allocate an object of [mainType] containing a
|
|
// flexible array of [count] objects of [arrayType].
|
|
#define ALLOCATE_FLEX(vm, mainType, arrayType, count) \
|
|
((mainType*)wrenReallocate(vm, NULL, 0, \
|
|
sizeof(mainType) + sizeof(arrayType) * (count)))
|
|
|
|
// Use the VM's allocator to allocate an array of [count] elements of [type].
|
|
#define ALLOCATE_ARRAY(vm, type, count) \
|
|
((type*)wrenReallocate(vm, NULL, 0, sizeof(type) * (count)))
|
|
|
|
// Use the VM's allocator to free the previously allocated memory at [pointer].
|
|
#define DEALLOCATE(vm, pointer) wrenReallocate(vm, pointer, 0, 0)
|
|
|
|
// The Microsoft compiler does not support the "inline" modifier when compiling
|
|
// as plain C.
|
|
#if defined( _MSC_VER ) && !defined(__cplusplus)
|
|
#define inline _inline
|
|
#endif
|
|
|
|
// This is used to clearly mark flexible-sized arrays that appear at the end of
|
|
// some dynamically-allocated structs, known as the "struct hack".
|
|
// __STDC_VERSION__ is not defined by C89 or C++ compilers, so check that it
// exists before comparing it rather than relying on an undefined macro
// silently evaluating to 0 (which is diagnosed by -Wundef).
#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
  // In C99 and later, a flexible array member is declared with just "[]".
  #define FLEXIBLE_ARRAY
#else
  // Elsewhere, use a zero-sized array. It's technically undefined behavior,
  // but works reliably in most known compilers.
  #define FLEXIBLE_ARRAY 0
#endif
|
|
|
|
// Assertions are used to validate program invariants. They indicate things the
|
|
// program expects to be true about its internal state during execution. If an
|
|
// assertion fails, there is a bug in Wren.
|
|
//
|
|
// Assertions add significant overhead, so are only enabled in debug builds.
|
|
#ifdef DEBUG
|
|
|
|
#include <stdio.h>
|
|
|
|
#define ASSERT(condition, message) \
|
|
do \
|
|
{ \
|
|
if (!(condition)) \
|
|
{ \
|
|
fprintf(stderr, "[%s:%d] Assert failed in %s(): %s\n", \
|
|
__FILE__, __LINE__, __func__, message); \
|
|
abort(); \
|
|
} \
|
|
} while (false)
|
|
|
|
// Indicates that we know execution should never reach this point in the
|
|
// program. In debug mode, we assert this fact because it's a bug to get here.
|
|
//
|
|
// In release mode, we use compiler-specific built in functions to tell the
|
|
// compiler the code can't be reached. This avoids "missing return" warnings
|
|
// in some cases and also lets it perform some optimizations by assuming the
|
|
// code is never reached.
|
|
#define UNREACHABLE() \
|
|
do \
|
|
{ \
|
|
fprintf(stderr, "[%s:%d] This code should not be reached in %s()\n", \
|
|
__FILE__, __LINE__, __func__); \
|
|
abort(); \
|
|
} while (false)
|
|
|
|
#else
|
|
|
|
#define ASSERT(condition, message) do { } while (false)
|
|
|
|
// Tell the compiler that this part of the code will never be reached.
|
|
#if defined( _MSC_VER )
|
|
#define UNREACHABLE() __assume(0)
|
|
#elif (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5))
|
|
#define UNREACHABLE() __builtin_unreachable()
|
|
#else
|
|
#define UNREACHABLE()
|
|
#endif
|
|
|
|
#endif
|
|
|
|
#endif
|
|
// End file "wren_common.h"
|
|
// Begin file "wren_math.h"
|
|
#ifndef wren_math_h
|
|
#define wren_math_h
|
|
|
|
#include <math.h>
|
|
#include <stdint.h>
|
|
|
|
// A union to let us reinterpret a double as raw bits and back.
|
|
// Overlays a double with integer views of the same storage so that a value
// written through one member can be read back through another.
typedef union
{
  uint64_t bits64;
  uint32_t bits32[2];
  double num;
} WrenDoubleBits;

// Lowest and highest 64-bit patterns named as positive quiet NaNs.
#define WREN_DOUBLE_QNAN_POS_MIN_BITS (UINT64_C(0x7FF8000000000000))
#define WREN_DOUBLE_QNAN_POS_MAX_BITS (UINT64_C(0x7FFFFFFFFFFFFFFF))

// The double whose bit pattern is WREN_DOUBLE_QNAN_POS_MIN_BITS.
#define WREN_DOUBLE_NAN (wrenDoubleFromBits(WREN_DOUBLE_QNAN_POS_MIN_BITS))

// Returns the double whose raw 64-bit representation is [bits].
static inline double wrenDoubleFromBits(uint64_t bits)
{
  // Brace initialization stores into the first member, bits64.
  WrenDoubleBits pun = { bits };
  return pun.num;
}

// Returns the raw 64-bit representation of [num].
static inline uint64_t wrenDoubleToBits(double num)
{
  WrenDoubleBits pun;
  pun.num = num;
  return pun.bits64;
}
|
|
|
|
#endif
|
|
// End file "wren_math.h"
|
|
// Begin file "wren_utils.h"
|
|
#ifndef wren_utils_h
|
|
#define wren_utils_h
|
|
|
|
|
|
// Reusable data structures and other utility functions.
|
|
|
|
// Forward declare this here to break a cycle between wren_utils.h and
|
|
// wren_value.h.
|
|
typedef struct sObjString ObjString;
|
|
|
|
// We need buffers of a few different types. To avoid lots of casting between
|
|
// void* and back, we'll use the preprocessor as a poor man's generics and let
|
|
// it generate a few type-specific ones.
|
|
#define DECLARE_BUFFER(name, type) \
|
|
typedef struct \
|
|
{ \
|
|
type* data; \
|
|
int count; \
|
|
int capacity; \
|
|
} name##Buffer; \
|
|
void wren##name##BufferInit(name##Buffer* buffer); \
|
|
void wren##name##BufferClear(WrenVM* vm, name##Buffer* buffer); \
|
|
void wren##name##BufferFill(WrenVM* vm, name##Buffer* buffer, type data, \
|
|
int count); \
|
|
void wren##name##BufferWrite(WrenVM* vm, name##Buffer* buffer, type data)
|
|
|
|
// This should be used once for each type instantiation, somewhere in a .c file.
|
|
#define DEFINE_BUFFER(name, type) \
|
|
void wren##name##BufferInit(name##Buffer* buffer) \
|
|
{ \
|
|
buffer->data = NULL; \
|
|
buffer->capacity = 0; \
|
|
buffer->count = 0; \
|
|
} \
|
|
\
|
|
void wren##name##BufferClear(WrenVM* vm, name##Buffer* buffer) \
|
|
{ \
|
|
wrenReallocate(vm, buffer->data, 0, 0); \
|
|
wren##name##BufferInit(buffer); \
|
|
} \
|
|
\
|
|
void wren##name##BufferFill(WrenVM* vm, name##Buffer* buffer, type data, \
|
|
int count) \
|
|
{ \
|
|
if (buffer->capacity < buffer->count + count) \
|
|
{ \
|
|
int capacity = wrenPowerOf2Ceil(buffer->count + count); \
|
|
buffer->data = (type*)wrenReallocate(vm, buffer->data, \
|
|
buffer->capacity * sizeof(type), capacity * sizeof(type)); \
|
|
buffer->capacity = capacity; \
|
|
} \
|
|
\
|
|
for (int i = 0; i < count; i++) \
|
|
{ \
|
|
buffer->data[buffer->count++] = data; \
|
|
} \
|
|
} \
|
|
\
|
|
void wren##name##BufferWrite(WrenVM* vm, name##Buffer* buffer, type data) \
|
|
{ \
|
|
wren##name##BufferFill(vm, buffer, data, 1); \
|
|
}
|
|
|
|
DECLARE_BUFFER(Byte, uint8_t);
|
|
DECLARE_BUFFER(Int, int);
|
|
DECLARE_BUFFER(String, ObjString*);
|
|
|
|
// TODO: Change this to use a map.
|
|
typedef StringBuffer SymbolTable;
|
|
|
|
// Initializes the symbol table.
|
|
void wrenSymbolTableInit(SymbolTable* symbols);
|
|
|
|
// Frees all dynamically allocated memory used by the symbol table, but not the
|
|
// SymbolTable itself.
|
|
void wrenSymbolTableClear(WrenVM* vm, SymbolTable* symbols);
|
|
|
|
// Adds name to the symbol table. Returns the index of it in the table.
|
|
int wrenSymbolTableAdd(WrenVM* vm, SymbolTable* symbols,
|
|
const char* name, size_t length);
|
|
|
|
// Adds name to the symbol table. Returns the index of it in the table. Will
|
|
// use an existing symbol if already present.
|
|
int wrenSymbolTableEnsure(WrenVM* vm, SymbolTable* symbols,
|
|
const char* name, size_t length);
|
|
|
|
// Looks up name in the symbol table. Returns its index if found or -1 if not.
|
|
int wrenSymbolTableFind(const SymbolTable* symbols,
|
|
const char* name, size_t length);
|
|
|
|
void wrenBlackenSymbolTable(WrenVM* vm, SymbolTable* symbolTable);
|
|
|
|
// Returns the number of bytes needed to encode [value] in UTF-8.
|
|
//
|
|
// Returns 0 if [value] is too large to encode.
|
|
int wrenUtf8EncodeNumBytes(int value);
|
|
|
|
// Encodes value as a series of bytes in [bytes], which is assumed to be large
|
|
// enough to hold the encoded result.
|
|
//
|
|
// Returns the number of written bytes.
|
|
int wrenUtf8Encode(int value, uint8_t* bytes);
|
|
|
|
// Decodes the UTF-8 sequence starting at [bytes] (which has max [length]),
|
|
// returning the code point.
|
|
//
|
|
// Returns -1 if the bytes are not a valid UTF-8 sequence.
|
|
int wrenUtf8Decode(const uint8_t* bytes, uint32_t length);
|
|
|
|
// Returns the number of bytes in the UTF-8 sequence starting with [byte].
|
|
//
|
|
// If the character at that index is not the beginning of a UTF-8 sequence,
|
|
// returns 0.
|
|
int wrenUtf8DecodeNumBytes(uint8_t byte);
|
|
|
|
// Returns the smallest power of two that is equal to or greater than [n].
|
|
int wrenPowerOf2Ceil(int n);
|
|
|
|
// Validates that [value] is within `[0, count)`. Also allows
|
|
// negative indices which map backwards from the end. Returns the valid positive
|
|
// index value. If invalid, returns `UINT32_MAX`.
|
|
uint32_t wrenValidateIndex(uint32_t count, int64_t value);
|
|
|
|
#endif
|
|
// End file "wren_utils.h"
|
|
|
|
// This defines the built-in types and their core representations in memory.
|
|
// Since Wren is dynamically typed, any variable can hold a value of any type,
|
|
// and the type can change at runtime. Implementing this efficiently is
|
|
// critical for performance.
|
|
//
|
|
// The main type exposed by this is [Value]. A C variable of that type is a
|
|
// storage location that can hold any Wren value. The stack, module variables,
|
|
// and instance fields are all implemented in C as variables of type Value.
|
|
//
|
|
// The built-in types for booleans, numbers, and null are unboxed: their value
|
|
// is stored directly in the Value, and copying a Value copies the value. Other
|
|
// types--classes, instances of classes, functions, lists, and strings--are all
|
|
// reference types. They are stored on the heap and the Value just stores a
|
|
// pointer to it. Copying the Value copies a reference to the same object. The
|
|
// Wren implementation calls these "Obj", or objects, though to a user, all
|
|
// values are objects.
|
|
//
|
|
// There is also a special singleton value "undefined". It is used internally
|
|
// but never appears as a real value to a user. It has two uses:
|
|
//
|
|
// - It is used to identify module variables that have been implicitly declared
|
|
// by use in a forward reference but not yet explicitly declared. These only
|
|
// exist during compilation and do not appear at runtime.
|
|
//
|
|
// - It is used to represent unused map entries in an ObjMap.
|
|
//
|
|
// There are two supported Value representations. The main one uses a technique
|
|
// called "NaN tagging" (explained in detail below) to store a number, any of
|
|
// the value types, or a pointer, all inside one double-precision floating
|
|
// point number. A larger, slower, Value type that uses a struct to store these
|
|
// is also supported, and is useful for debugging the VM.
|
|
//
|
|
// The representation is controlled by the `WREN_NAN_TAGGING` define. If that's
|
|
// non-zero, NaN tagging is used.
|
|
|
|
// These macros cast a Value to one of the specific object types. These do *not*
|
|
// perform any validation, so must only be used after the Value has been
|
|
// ensured to be the right type.
|
|
#define AS_CLASS(value) ((ObjClass*)AS_OBJ(value)) // ObjClass*
|
|
#define AS_CLOSURE(value) ((ObjClosure*)AS_OBJ(value)) // ObjClosure*
|
|
#define AS_FIBER(v) ((ObjFiber*)AS_OBJ(v)) // ObjFiber*
|
|
#define AS_FN(value) ((ObjFn*)AS_OBJ(value)) // ObjFn*
|
|
#define AS_FOREIGN(v) ((ObjForeign*)AS_OBJ(v)) // ObjForeign*
|
|
#define AS_INSTANCE(value) ((ObjInstance*)AS_OBJ(value)) // ObjInstance*
|
|
#define AS_LIST(value) ((ObjList*)AS_OBJ(value)) // ObjList*
|
|
#define AS_MAP(value) ((ObjMap*)AS_OBJ(value)) // ObjMap*
|
|
#define AS_MODULE(value) ((ObjModule*)AS_OBJ(value)) // ObjModule*
|
|
#define AS_NUM(value) (wrenValueToNum(value)) // double
|
|
#define AS_RANGE(v) ((ObjRange*)AS_OBJ(v)) // ObjRange*
|
|
#define AS_STRING(v) ((ObjString*)AS_OBJ(v)) // ObjString*
|
|
#define AS_CSTRING(v) (AS_STRING(v)->value) // const char*
|
|
|
|
// These macros promote a primitive C value to a full Wren Value. There are
|
|
// more defined below that are specific to the Nan tagged or other
|
|
// representation.
|
|
#define BOOL_VAL(boolean) ((boolean) ? TRUE_VAL : FALSE_VAL) // boolean
|
|
#define NUM_VAL(num) (wrenNumToValue(num)) // double
|
|
#define OBJ_VAL(obj) (wrenObjectToValue((Obj*)(obj))) // Any Obj___*
|
|
|
|
// These perform type tests on a Value, returning `true` if the Value is of the
|
|
// given type.
|
|
#define IS_BOOL(value) (wrenIsBool(value)) // Bool
|
|
#define IS_CLASS(value) (wrenIsObjType(value, OBJ_CLASS)) // ObjClass
|
|
#define IS_CLOSURE(value) (wrenIsObjType(value, OBJ_CLOSURE)) // ObjClosure
|
|
#define IS_FIBER(value) (wrenIsObjType(value, OBJ_FIBER)) // ObjFiber
|
|
#define IS_FN(value) (wrenIsObjType(value, OBJ_FN)) // ObjFn
|
|
#define IS_FOREIGN(value) (wrenIsObjType(value, OBJ_FOREIGN)) // ObjForeign
|
|
#define IS_INSTANCE(value) (wrenIsObjType(value, OBJ_INSTANCE)) // ObjInstance
|
|
#define IS_LIST(value) (wrenIsObjType(value, OBJ_LIST)) // ObjList
|
|
#define IS_MAP(value) (wrenIsObjType(value, OBJ_MAP)) // ObjMap
|
|
#define IS_RANGE(value) (wrenIsObjType(value, OBJ_RANGE)) // ObjRange
|
|
#define IS_STRING(value) (wrenIsObjType(value, OBJ_STRING)) // ObjString
|
|
|
|
// Creates a new string object from [text], which should be a bare C string
|
|
// literal. This determines the length of the string automatically at compile
|
|
// time based on the size of the character array (-1 for the terminating '\0').
|
|
#define CONST_STRING(vm, text) wrenNewStringLength((vm), (text), sizeof(text) - 1)
|
|
|
|
// Identifies which specific type a heap-allocated object is.
|
|
typedef enum {
|
|
OBJ_CLASS,
|
|
OBJ_CLOSURE,
|
|
OBJ_FIBER,
|
|
OBJ_FN,
|
|
OBJ_FOREIGN,
|
|
OBJ_INSTANCE,
|
|
OBJ_LIST,
|
|
OBJ_MAP,
|
|
OBJ_MODULE,
|
|
OBJ_RANGE,
|
|
OBJ_STRING,
|
|
OBJ_UPVALUE
|
|
} ObjType;
|
|
|
|
typedef struct sObjClass ObjClass;
|
|
|
|
// Base struct for all heap-allocated objects.
|
|
typedef struct sObj Obj;
|
|
struct sObj
|
|
{
|
|
ObjType type;
|
|
bool isDark;
|
|
|
|
// The object's class.
|
|
ObjClass* classObj;
|
|
|
|
// The next object in the linked list of all currently allocated objects.
|
|
struct sObj* next;
|
|
};
|
|
|
|
#if WREN_NAN_TAGGING
|
|
|
|
typedef uint64_t Value;
|
|
|
|
#else
|
|
|
|
typedef enum
|
|
{
|
|
VAL_FALSE,
|
|
VAL_NULL,
|
|
VAL_NUM,
|
|
VAL_TRUE,
|
|
VAL_UNDEFINED,
|
|
VAL_OBJ
|
|
} ValueType;
|
|
|
|
typedef struct
|
|
{
|
|
ValueType type;
|
|
union
|
|
{
|
|
double num;
|
|
Obj* obj;
|
|
} as;
|
|
} Value;
|
|
|
|
#endif
|
|
|
|
DECLARE_BUFFER(Value, Value);
|
|
|
|
// A heap-allocated string object.
|
|
struct sObjString
|
|
{
|
|
Obj obj;
|
|
|
|
// Number of bytes in the string, not including the null terminator.
|
|
uint32_t length;
|
|
|
|
// The hash value of the string's contents.
|
|
uint32_t hash;
|
|
|
|
// Inline array of the string's bytes followed by a null terminator.
|
|
char value[FLEXIBLE_ARRAY];
|
|
};
|
|
|
|
// The dynamically allocated data structure for a variable that has been used
|
|
// by a closure. Whenever a function accesses a variable declared in an
|
|
// enclosing function, it will get to it through this.
|
|
//
|
|
// An upvalue can be either "closed" or "open". An open upvalue points directly
|
|
// to a [Value] that is still stored on the fiber's stack because the local
|
|
// variable is still in scope in the function where it's declared.
|
|
//
|
|
// When that local variable goes out of scope, the upvalue pointing to it will
|
|
// be closed. When that happens, the value gets copied off the stack into the
|
|
// upvalue itself. That way, it can have a longer lifetime than the stack
|
|
// variable.
|
|
typedef struct sObjUpvalue
|
|
{
|
|
// The object header. Note that upvalues have this because they are garbage
|
|
// collected, but they are not first class Wren objects.
|
|
Obj obj;
|
|
|
|
// Pointer to the variable this upvalue is referencing.
|
|
Value* value;
|
|
|
|
// If the upvalue is closed (i.e. the local variable it was pointing to has
|
|
// been popped off the stack) then the closed-over value will be hoisted out
|
|
// of the stack into here. [value] will then be changed to point to this.
|
|
Value closed;
|
|
|
|
// Open upvalues are stored in a linked list by the fiber. This points to the
|
|
// next upvalue in that list.
|
|
struct sObjUpvalue* next;
|
|
} ObjUpvalue;
|
|
|
|
// The type of a primitive function.
|
|
//
|
|
// Primitives are similar to foreign functions, but have more direct access to
|
|
// VM internals. It is passed the arguments in [args]. If it returns a value,
|
|
// it places it in `args[0]` and returns `true`. If it causes a runtime error
|
|
// or modifies the running fiber, it returns `false`.
|
|
typedef bool (*Primitive)(WrenVM* vm, Value* args);
|
|
|
|
// TODO: See if it's actually a perf improvement to have this in a separate
|
|
// struct instead of in ObjFn.
|
|
// Stores debugging information for a function used for things like stack
|
|
// traces.
|
|
typedef struct
|
|
{
|
|
// The name of the function. Heap allocated and owned by the FnDebug.
|
|
char* name;
|
|
|
|
// An array of line numbers. There is one element in this array for each
|
|
// bytecode in the function's bytecode array. The value of that element is
|
|
// the line in the source code that generated that instruction.
|
|
IntBuffer sourceLines;
|
|
} FnDebug;
|
|
|
|
// A loaded module and the top-level variables it defines.
|
|
//
|
|
// While this is an Obj and is managed by the GC, it never appears as a
|
|
// first-class object in Wren.
|
|
typedef struct
|
|
{
|
|
Obj obj;
|
|
|
|
// The currently defined top-level variables.
|
|
ValueBuffer variables;
|
|
|
|
// Symbol table for the names of all module variables. Indexes here directly
|
|
// correspond to entries in [variables].
|
|
SymbolTable variableNames;
|
|
|
|
// The name of the module.
|
|
ObjString* name;
|
|
} ObjModule;
|
|
|
|
// A function object. It wraps and owns the bytecode and other debug information
|
|
// for a callable chunk of code.
|
|
//
|
|
// Function objects are not passed around and invoked directly. Instead, they
|
|
// are always referenced by an [ObjClosure] which is the real first-class
|
|
// representation of a function. This isn't strictly necessary if the function
|
|
// has no upvalues, but lets the rest of the VM assume all called objects will
|
|
// be closures.
|
|
typedef struct
|
|
{
|
|
Obj obj;
|
|
|
|
ByteBuffer code;
|
|
ValueBuffer constants;
|
|
|
|
// The module where this function was defined.
|
|
ObjModule* module;
|
|
|
|
// The maximum number of stack slots this function may use.
|
|
int maxSlots;
|
|
|
|
// The number of upvalues this function closes over.
|
|
int numUpvalues;
|
|
|
|
// The number of parameters this function expects. Used to ensure that .call
|
|
// handles a mismatch between number of parameters and arguments. This will
|
|
// only be set for fns, and not ObjFns that represent methods or scripts.
|
|
int arity;
|
|
FnDebug* debug;
|
|
} ObjFn;
|
|
|
|
// An instance of a first-class function and the environment it has closed over.
|
|
// Unlike [ObjFn], this has captured the upvalues that the function accesses.
|
|
typedef struct
|
|
{
|
|
Obj obj;
|
|
|
|
// The function that this closure is an instance of.
|
|
ObjFn* fn;
|
|
|
|
// The upvalues this function has closed over.
|
|
ObjUpvalue* upvalues[FLEXIBLE_ARRAY];
|
|
} ObjClosure;
|
|
|
|
// The bookkeeping for one in-progress function invocation: which closure is
// running, where in its bytecode execution is, and where its stack slots begin.
typedef struct
|
|
{
|
|
// Pointer to the current (really next-to-be-executed) instruction in the
|
|
// function's bytecode.
|
|
uint8_t* ip;
|
|
|
|
// The closure being executed.
|
|
ObjClosure* closure;
|
|
|
|
// Pointer to the first stack slot used by this call frame. This will contain
|
|
// the receiver, followed by the function's parameters, then local variables
|
|
// and temporaries.
|
|
Value* stackStart;
|
|
} CallFrame;
|
|
|
|
// Tracks how this fiber has been invoked, aside from the ways that can be
|
|
// detected from the state of other fields in the fiber.
|
|
typedef enum
|
|
{
|
|
// The fiber is being run from another fiber using a call to `try()`.
|
|
FIBER_TRY,
|
|
|
|
// The fiber was directly invoked by `runInterpreter()`. This means it's the
|
|
// initial fiber used by a call to `wrenCall()` or `wrenInterpret()`.
|
|
FIBER_ROOT,
|
|
|
|
// The fiber is invoked some other way. If [caller] is `NULL` then the fiber
|
|
// was invoked using `call()`. If [numFrames] is zero, then the fiber has
|
|
// finished running and is done. If [numFrames] is one and that frame's `ip`
|
|
// points to the first byte of code, the fiber has not been started yet.
|
|
FIBER_OTHER,
|
|
} FiberState;
|
|
|
|
typedef struct sObjFiber
|
|
{
|
|
Obj obj;
|
|
|
|
// The stack of value slots. This is used for holding local variables and
|
|
// temporaries while the fiber is executing. It is heap-allocated and grown
|
|
// as needed.
|
|
Value* stack;
|
|
|
|
// A pointer to one past the top-most value on the stack.
|
|
Value* stackTop;
|
|
|
|
// The number of allocated slots in the stack array.
|
|
int stackCapacity;
|
|
|
|
// The stack of call frames. This is a dynamic array that grows as needed but
|
|
// never shrinks.
|
|
CallFrame* frames;
|
|
|
|
// The number of frames currently in use in [frames].
|
|
int numFrames;
|
|
|
|
// The number of [frames] allocated.
|
|
int frameCapacity;
|
|
|
|
// Pointer to the first node in the linked list of open upvalues that are
|
|
// pointing to values still on the stack. The head of the list will be the
|
|
// upvalue closest to the top of the stack, and then the list works downwards.
|
|
ObjUpvalue* openUpvalues;
|
|
|
|
// The fiber that ran this one. If this fiber is yielded, control will resume
|
|
// to this one. May be `NULL`.
|
|
struct sObjFiber* caller;
|
|
|
|
// If the fiber failed because of a runtime error, this will contain the
|
|
// error object. Otherwise, it will be null.
|
|
Value error;
|
|
|
|
FiberState state;
|
|
} ObjFiber;
|
|
|
|
typedef enum
|
|
{
|
|
// A primitive method implemented in C in the VM. Unlike foreign methods,
|
|
// this can directly manipulate the fiber's stack.
|
|
METHOD_PRIMITIVE,
|
|
|
|
// A primitive that handles .call on Fn.
|
|
METHOD_FUNCTION_CALL,
|
|
|
|
// An externally-defined C method.
|
|
METHOD_FOREIGN,
|
|
|
|
// A normal user-defined method.
|
|
METHOD_BLOCK,
|
|
|
|
// No method for the given symbol.
|
|
METHOD_NONE
|
|
} MethodType;
|
|
|
|
// A single method slot: a tag saying what kind of method this is, plus the
// callable stored in whichever union member matches that tag.
typedef struct
|
|
{
|
|
// Which kind of method this slot holds (see MethodType).
MethodType type;
|
|
|
|
// The method function itself. The [type] determines which field of the union
|
|
// is used.
|
|
union
|
|
{
|
|
// A C primitive with the `Primitive` signature.
// NOTE(review): presumably also used for METHOD_FUNCTION_CALL -- confirm.
Primitive primitive;
|
|
// A foreign method function supplied through the embedding API.
WrenForeignMethodFn foreign;
|
|
// A closure object holding the method body.
// NOTE(review): presumably selected by METHOD_BLOCK -- confirm at use sites.
ObjClosure* closure;
|
|
} as;
|
|
} Method;
|
|
|
|
DECLARE_BUFFER(Method, Method);
|
|
|
|
struct sObjClass
|
|
{
|
|
Obj obj;
|
|
ObjClass* superclass;
|
|
|
|
// The number of fields needed for an instance of this class, including all
|
|
// of its superclass fields.
|
|
int numFields;
|
|
|
|
// The table of methods that are defined in or inherited by this class.
|
|
// Methods are called by symbol, and the symbol directly maps to an index in
|
|
// this table. This makes method calls fast at the expense of empty cells in
|
|
// the list for methods the class doesn't support.
|
|
//
|
|
// You can think of it as a hash table that never has collisions but has a
|
|
// really low load factor. Since methods are pretty small (just a type and a
|
|
// pointer), this should be a worthwhile trade-off.
|
|
MethodBuffer methods;
|
|
|
|
// The name of the class.
|
|
ObjString* name;
|
|
|
|
// The ClassAttribute for the class, if any
|
|
Value attributes;
|
|
};
|
|
|
|
// A heap object whose payload is a run of raw bytes rather than Wren values.
typedef struct
|
|
{
|
|
// The common object header.
Obj obj;
|
|
// Byte storage laid out inline after the header. Its length is not recorded
// in this struct.
// NOTE(review): presumably sized by whoever allocates the object -- confirm.
uint8_t data[FLEXIBLE_ARRAY];
|
|
} ObjForeign;
|
|
|
|
// A heap object that carries an inline array of Value field slots.
typedef struct
|
|
{
|
|
// The common object header.
Obj obj;
|
|
// The field slots, laid out inline after the header.
// NOTE(review): presumably one slot per field counted by the class's
// numFields -- confirm at the allocation site.
Value fields[FLEXIBLE_ARRAY];
|
|
} ObjInstance;
|
|
|
|
// A heap-allocated list object backed by a growable buffer of Values.
typedef struct
|
|
{
|
|
// The common object header.
Obj obj;
|
|
|
|
// The elements in the list.
|
|
ValueBuffer elements;
|
|
} ObjList;
|
|
|
|
typedef struct
|
|
{
|
|
// The entry's key, or UNDEFINED_VAL if the entry is not in use.
|
|
Value key;
|
|
|
|
// The value associated with the key. If the key is UNDEFINED_VAL, this will
|
|
// be false to indicate an open available entry or true to indicate a
|
|
// tombstone -- an entry that was previously in use but was then deleted.
|
|
Value value;
|
|
} MapEntry;
|
|
|
|
// A hash table mapping keys to values.
|
|
//
|
|
// We use something very simple: open addressing with linear probing. The hash
|
|
// table is an array of entries. Each entry is a key-value pair. If the key is
|
|
// the special UNDEFINED_VAL, it indicates no value is currently in that slot.
|
|
// Otherwise, it's a valid key, and the value is the value associated with it.
|
|
//
|
|
// When entries are added, the array is dynamically scaled by GROW_FACTOR to
|
|
// keep the number of filled slots under MAP_LOAD_PERCENT. Likewise, if the map
|
|
// gets empty enough, it will be resized to a smaller array. When this happens,
|
|
// all existing entries are rehashed and re-added to the new array.
|
|
//
|
|
// When an entry is removed, its slot is replaced with a "tombstone". This is an
|
|
// entry whose key is UNDEFINED_VAL and whose value is TRUE_VAL. When probing
|
|
// for a key, we will continue past tombstones, because the desired key may be
|
|
// found after them if the key that was removed was part of a prior collision.
|
|
// When the array gets resized, all tombstones are discarded.
|
|
typedef struct
|
|
{
|
|
Obj obj;
|
|
|
|
// The number of entries allocated.
|
|
uint32_t capacity;
|
|
|
|
// The number of entries in the map.
|
|
uint32_t count;
|
|
|
|
// Pointer to a contiguous array of [capacity] entries.
|
|
MapEntry* entries;
|
|
} ObjMap;
|
|
|
|
// A heap-allocated range object described by two double endpoints and a flag
// saying whether the end point belongs to the range.
typedef struct
|
|
{
|
|
// The common object header.
Obj obj;
|
|
|
|
// The beginning of the range.
|
|
double from;
|
|
|
|
// The end of the range. May be greater or less than [from].
|
|
double to;
|
|
|
|
// True if [to] is included in the range.
|
|
bool isInclusive;
|
|
} ObjRange;
|
|
|
|
// An IEEE 754 double-precision float is a 64-bit value with bits laid out like:
|
|
//
|
|
// 1 Sign bit
|
|
// | 11 Exponent bits
|
|
// | | 52 Mantissa (i.e. fraction) bits
|
|
// | | |
|
|
// S[Exponent-][Mantissa------------------------------------------]
|
|
//
|
|
// The details of how these are used to represent numbers aren't really
|
|
// relevant here as long as we don't interfere with them. The important bit is NaN.
|
|
//
|
|
// An IEEE double can represent a few magical values like NaN ("not a number"),
|
|
// Infinity, and -Infinity. A NaN is any value where all exponent bits are set:
|
|
//
|
|
// v--NaN bits
|
|
// -11111111111----------------------------------------------------
|
|
//
|
|
// Here, "-" means "doesn't matter". Any bit sequence that matches the above is
|
|
// a NaN. With all of those "-", it's obvious there are a *lot* of different
|
|
// bit patterns that all mean the same thing. NaN tagging takes advantage of
|
|
// this. We'll use those available bit patterns to represent things other than
|
|
// numbers without giving up any valid numeric values.
|
|
//
|
|
// NaN values come in two flavors: "signalling" and "quiet". The former are
|
|
// intended to halt execution, while the latter just flow through arithmetic
|
|
// operations silently. We want the latter. Quiet NaNs are indicated by setting
|
|
// the highest mantissa bit:
|
|
//
|
|
// v--Highest mantissa bit
|
|
// -[NaN ]1---------------------------------------------------
|
|
//
|
|
// If all of the NaN bits are set, it's not a number. Otherwise, it is.
|
|
// That leaves all of the remaining bits as available for us to play with. We
|
|
// stuff a few different kinds of things here: special singleton values like
|
|
// "true", "false", and "null", and pointers to objects allocated on the heap.
|
|
// We'll use the sign bit to distinguish singleton values from pointers. If
|
|
// it's set, it's a pointer.
|
|
//
|
|
// v--Pointer or singleton?
|
|
// S[NaN ]1---------------------------------------------------
|
|
//
|
|
// For singleton values, we just enumerate the different values. We'll use the
|
|
// low bits of the mantissa for that, and only need a few:
|
|
//
|
|
// 3 Type bits--v
|
|
// 0[NaN ]1------------------------------------------------[T]
|
|
//
|
|
// For pointers, we are left with 51 bits of mantissa to store an address.
|
|
// That's more than enough room for a 32-bit address. Even 64-bit machines
|
|
// only actually use 48 bits for addresses, so we've got plenty. We just stuff
|
|
// the address right into the mantissa.
|
|
//
|
|
// Ta-da, double precision numbers, pointers, and a bunch of singleton values,
|
|
// all stuffed into a single 64-bit sequence. Even better, we don't have to
|
|
// do any masking or work to extract number values: they are unmodified. This
|
|
// means math on numbers is fast.
|
|
#if WREN_NAN_TAGGING

// Selects only the topmost (sign) bit of a 64-bit value.
#define SIGN_BIT ((uint64_t)1 << 63)

// Every bit in this pattern must be set for a value to be a quiet NaN.
#define QNAN ((uint64_t)0x7ffc000000000000)

// Isolates the low bits that hold the tag of a singleton value.
#define MASK_TAG (7)

// The tags that distinguish the singleton values from each other.
#define TAG_NAN       (0)
#define TAG_NULL      (1)
#define TAG_FALSE     (2)
#define TAG_TRUE      (3)
#define TAG_UNDEFINED (4)
#define TAG_UNUSED2   (5)
#define TAG_UNUSED3   (6)
#define TAG_UNUSED4   (7)

// The singleton values: the quiet NaN bits combined with the matching tag.
#define NULL_VAL      ((Value)(uint64_t)(QNAN | TAG_NULL))
#define FALSE_VAL     ((Value)(uint64_t)(QNAN | TAG_FALSE))
#define TRUE_VAL      ((Value)(uint64_t)(QNAN | TAG_TRUE))
#define UNDEFINED_VAL ((Value)(uint64_t)(QNAN | TAG_UNDEFINED))

// Extracts the singleton tag from a Value. Only meaningful for singletons.
#define GET_TAG(v) ((int)((v) & MASK_TAG))

// A value is a number unless all of the quiet NaN bits are set.
#define IS_NUM(v) (((v) & QNAN) != QNAN)

// Object pointers are quiet NaNs that additionally have the sign bit set.
#define IS_OBJ(v) (((v) & (QNAN | SIGN_BIT)) == (QNAN | SIGN_BIT))

// Singleton tests are plain comparisons against the singleton's bits.
#define IS_FALSE(v)     ((v) == FALSE_VAL)
#define IS_NULL(v)      ((v) == NULL_VAL)
#define IS_UNDEFINED(v) ((v) == UNDEFINED_VAL)

// Value -> 0 or 1. Yields 1 only for the true singleton.
#define AS_BOOL(v) ((v) == TRUE_VAL)

// Value -> Obj*. Strips the sign and quiet NaN bits, leaving the address.
#define AS_OBJ(v) ((Obj*)(uintptr_t)((v) & ~(SIGN_BIT | QNAN)))

#else

// Singleton values: the type field alone identifies them, the payload is zero.
#define FALSE_VAL     ((Value){ VAL_FALSE, { 0 } })
#define NULL_VAL      ((Value){ VAL_NULL, { 0 } })
#define TRUE_VAL      ((Value){ VAL_TRUE, { 0 } })
#define UNDEFINED_VAL ((Value){ VAL_UNDEFINED, { 0 } })

// Type tests, each a comparison of the value's type field.
#define IS_FALSE(v)     ((v).type == VAL_FALSE)
#define IS_NULL(v)      ((v).type == VAL_NULL)
#define IS_NUM(v)       ((v).type == VAL_NUM)
#define IS_UNDEFINED(v) ((v).type == VAL_UNDEFINED)

// Tells whether the value refers to a garbage-collected object.
#define IS_OBJ(v) ((v).type == VAL_OBJ)

// Value -> 0 or 1. Yields 1 only when the type is VAL_TRUE.
#define AS_BOOL(v) ((v).type == VAL_TRUE)

// Value -> Obj*. Reads the object member of the payload.
#define AS_OBJ(v) ((v).as.obj)

#endif
|
|
|
|
// Creates a new "raw" class. It has no metaclass or superclass whatsoever.
|
|
// This is only used for bootstrapping the initial Object and Class classes,
|
|
// which are a little special.
|
|
ObjClass* wrenNewSingleClass(WrenVM* vm, int numFields, ObjString* name);
|
|
|
|
// Makes [superclass] the superclass of [subclass], and causes subclass to
|
|
// inherit its methods. This should be called before any methods are defined
|
|
// on subclass.
|
|
void wrenBindSuperclass(WrenVM* vm, ObjClass* subclass, ObjClass* superclass);
|
|
|
|
// Creates a new class object as well as its associated metaclass.
|
|
ObjClass* wrenNewClass(WrenVM* vm, ObjClass* superclass, int numFields,
|
|
ObjString* name);
|
|
|
|
void wrenBindMethod(WrenVM* vm, ObjClass* classObj, int symbol, Method method);
|
|
|
|
// Creates a new closure object that invokes [fn]. Allocates room for its
|
|
// upvalues, but assumes outside code will populate it.
|
|
ObjClosure* wrenNewClosure(WrenVM* vm, ObjFn* fn);
|
|
|
|
// Creates a new fiber object that will invoke [closure].
|
|
ObjFiber* wrenNewFiber(WrenVM* vm, ObjClosure* closure);
|
|
|
|
// Adds a new [CallFrame] to [fiber] invoking [closure] whose stack starts at
|
|
// [stackStart].
|
|
static inline void wrenAppendCallFrame(WrenVM* vm, ObjFiber* fiber,
|
|
ObjClosure* closure, Value* stackStart)
|
|
{
|
|
// The caller should have ensured we already have enough capacity.
|
|
ASSERT(fiber->frameCapacity > fiber->numFrames, "No memory for call frame.");
|
|
|
|
CallFrame* frame = &fiber->frames[fiber->numFrames++];
|
|
frame->stackStart = stackStart;
|
|
frame->closure = closure;
|
|
frame->ip = closure->fn->code.data;
|
|
}
|
|
|
|
// Ensures [fiber]'s stack has at least [needed] slots.
|
|
void wrenEnsureStack(WrenVM* vm, ObjFiber* fiber, int needed);
|
|
|
|
static inline bool wrenHasError(const ObjFiber* fiber)
|
|
{
|
|
return !IS_NULL(fiber->error);
|
|
}
|
|
|
|
ObjForeign* wrenNewForeign(WrenVM* vm, ObjClass* classObj, size_t size);
|
|
|
|
// Creates a new empty function. Before being used, it must have code,
|
|
// constants, etc. added to it.
|
|
ObjFn* wrenNewFunction(WrenVM* vm, ObjModule* module, int maxSlots);
|
|
|
|
void wrenFunctionBindName(WrenVM* vm, ObjFn* fn, const char* name, int length);
|
|
|
|
// Creates a new instance of the given [classObj].
|
|
Value wrenNewInstance(WrenVM* vm, ObjClass* classObj);
|
|
|
|
// Creates a new list with [numElements] elements (which are left
|
|
// uninitialized.)
|
|
ObjList* wrenNewList(WrenVM* vm, uint32_t numElements);
|
|
|
|
// Inserts [value] in [list] at [index], shifting down the other elements.
|
|
void wrenListInsert(WrenVM* vm, ObjList* list, Value value, uint32_t index);
|
|
|
|
// Removes and returns the item at [index] from [list].
|
|
Value wrenListRemoveAt(WrenVM* vm, ObjList* list, uint32_t index);
|
|
|
|
// Searches for [value] in [list], returns the index or -1 if not found.
|
|
int wrenListIndexOf(WrenVM* vm, ObjList* list, Value value);
|
|
|
|
// Creates a new empty map.
|
|
ObjMap* wrenNewMap(WrenVM* vm);
|
|
|
|
// Validates that [arg] is a valid object for use as a map key. Returns true if
|
|
// it is and returns false otherwise. Use validateKey usually, for a runtime error.
|
|
// This separation exists to aid the API in surfacing errors to the developer as well.
|
|
static inline bool wrenMapIsValidKey(Value arg);
|
|
|
|
// Looks up [key] in [map]. If found, returns the value. Otherwise, returns
|
|
// `UNDEFINED_VAL`.
|
|
Value wrenMapGet(ObjMap* map, Value key);
|
|
|
|
// Associates [key] with [value] in [map].
|
|
void wrenMapSet(WrenVM* vm, ObjMap* map, Value key, Value value);
|
|
|
|
void wrenMapClear(WrenVM* vm, ObjMap* map);
|
|
|
|
// Removes [key] from [map], if present. Returns the value for the key if found
|
|
// or `NULL_VAL` otherwise.
|
|
Value wrenMapRemoveKey(WrenVM* vm, ObjMap* map, Value key);
|
|
|
|
// Creates a new module.
|
|
ObjModule* wrenNewModule(WrenVM* vm, ObjString* name);
|
|
|
|
// Creates a new range from [from] to [to].
|
|
Value wrenNewRange(WrenVM* vm, double from, double to, bool isInclusive);
|
|
|
|
// Creates a new string object and copies [text] into it.
|
|
//
|
|
// [text] must be non-NULL.
|
|
Value wrenNewString(WrenVM* vm, const char* text);
|
|
|
|
// Creates a new string object of [length] and copies [text] into it.
|
|
//
|
|
// [text] may be NULL if [length] is zero.
|
|
Value wrenNewStringLength(WrenVM* vm, const char* text, size_t length);
|
|
|
|
// Creates a new string object by taking a range of characters from [source].
|
|
// The range starts at [start], contains [count] bytes, and increments by
|
|
// [step].
|
|
Value wrenNewStringFromRange(WrenVM* vm, ObjString* source, int start,
|
|
uint32_t count, int step);
|
|
|
|
// Produces a string representation of [value].
|
|
Value wrenNumToString(WrenVM* vm, double value);
|
|
|
|
// Creates a new formatted string from [format] and any additional arguments
|
|
// used in the format string.
|
|
//
|
|
// This is a very restricted flavor of formatting, intended only for internal
|
|
// use by the VM. Two formatting characters are supported, each of which reads
|
|
// the next argument as a certain type:
|
|
//
|
|
// $ - A C string.
|
|
// @ - A Wren string object.
|
|
Value wrenStringFormat(WrenVM* vm, const char* format, ...);
|
|
|
|
// Creates a new string containing the UTF-8 encoding of [value].
|
|
Value wrenStringFromCodePoint(WrenVM* vm, int value);
|
|
|
|
// Creates a new string from the integer representation of a byte
|
|
Value wrenStringFromByte(WrenVM* vm, uint8_t value);
|
|
|
|
// Creates a new string containing the code point in [string] starting at byte
|
|
// [index]. If [index] points into the middle of a UTF-8 sequence, returns an
|
|
// empty string.
|
|
Value wrenStringCodePointAt(WrenVM* vm, ObjString* string, uint32_t index);
|
|
|
|
// Searches for the first occurrence of [needle] within [haystack] and returns its
|
|
// zero-based offset. Returns `UINT32_MAX` if [haystack] does not contain
|
|
// [needle].
|
|
uint32_t wrenStringFind(ObjString* haystack, ObjString* needle,
|
|
uint32_t startIndex);
|
|
|
|
// Returns true if [a] and [b] represent the same string.
|
|
static inline bool wrenStringEqualsCString(const ObjString* a,
|
|
const char* b, size_t length)
|
|
{
|
|
return a->length == length && memcmp(a->value, b, length) == 0;
|
|
}
|
|
|
|
// Creates a new open upvalue pointing to [value] on the stack.
|
|
ObjUpvalue* wrenNewUpvalue(WrenVM* vm, Value* value);
|
|
|
|
// Mark [obj] as reachable and still in use. This should only be called
|
|
// during the sweep phase of a garbage collection.
|
|
void wrenGrayObj(WrenVM* vm, Obj* obj);
|
|
|
|
// Mark [value] as reachable and still in use. This should only be called
|
|
// during the sweep phase of a garbage collection.
|
|
void wrenGrayValue(WrenVM* vm, Value value);
|
|
|
|
// Mark the values in [buffer] as reachable and still in use. This should only
|
|
// be called during the sweep phase of a garbage collection.
|
|
void wrenGrayBuffer(WrenVM* vm, ValueBuffer* buffer);
|
|
|
|
// Processes every object in the gray stack until all reachable objects have
|
|
// been marked. After that, all objects are either white (freeable) or black
|
|
// (in use and fully traversed).
|
|
void wrenBlackenObjects(WrenVM* vm);
|
|
|
|
// Releases all memory owned by [obj], including [obj] itself.
|
|
void wrenFreeObj(WrenVM* vm, Obj* obj);
|
|
|
|
// Returns the class of [value].
|
|
//
|
|
// Unlike wrenGetClassInline in wren_vm.h, this is not inlined. Inlining helps
|
|
// performance (significantly) in some cases, but degrades it in others. The
|
|
// ones used by the implementation were chosen to give the best results in the
|
|
// benchmarks.
|
|
ObjClass* wrenGetClass(WrenVM* vm, Value value);
|
|
|
|
// Returns true if [a] and [b] are strictly the same value. This is identity
|
|
// for object values, and value equality for unboxed values.
|
|
static inline bool wrenValuesSame(Value a, Value b)
|
|
{
|
|
#if WREN_NAN_TAGGING
|
|
// Value types have unique bit representations and we compare object types
|
|
// by identity (i.e. pointer), so all we need to do is compare the bits.
|
|
return a == b;
|
|
#else
|
|
if (a.type != b.type) return false;
|
|
if (a.type == VAL_NUM) return a.as.num == b.as.num;
|
|
return a.as.obj == b.as.obj;
|
|
#endif
|
|
}
|
|
|
|
// Returns true if [a] and [b] are equivalent. Immutable values (null, bools,
|
|
// numbers, ranges, and strings) are equal if they have the same data. All
|
|
// other values are equal if they are identical objects.
|
|
bool wrenValuesEqual(Value a, Value b);
|
|
|
|
// Returns true if [value] is a bool. Do not call this directly, instead use
|
|
// [IS_BOOL].
|
|
static inline bool wrenIsBool(Value value)
|
|
{
|
|
#if WREN_NAN_TAGGING
|
|
return value == TRUE_VAL || value == FALSE_VAL;
|
|
#else
|
|
return value.type == VAL_FALSE || value.type == VAL_TRUE;
|
|
#endif
|
|
}
|
|
|
|
// Returns true if [value] is an object of type [type]. Do not call this
|
|
// directly, instead use the [IS___] macro for the type in question.
|
|
static inline bool wrenIsObjType(Value value, ObjType type)
|
|
{
|
|
return IS_OBJ(value) && AS_OBJ(value)->type == type;
|
|
}
|
|
|
|
// Converts the raw object pointer [obj] to a [Value].
|
|
static inline Value wrenObjectToValue(Obj* obj)
|
|
{
|
|
#if WREN_NAN_TAGGING
|
|
// The triple casting is necessary here to satisfy some compilers:
|
|
// 1. (uintptr_t) Convert the pointer to a number of the right size.
|
|
// 2. (uint64_t) Pad it up to 64 bits in 32-bit builds.
|
|
// 3. Or in the bits to make a tagged Nan.
|
|
// 4. Cast to a typedef'd value.
|
|
return (Value)(SIGN_BIT | QNAN | (uint64_t)(uintptr_t)(obj));
|
|
#else
|
|
Value value;
|
|
value.type = VAL_OBJ;
|
|
value.as.obj = obj;
|
|
return value;
|
|
#endif
|
|
}
|
|
|
|
// Interprets [value] as a [double].
|
|
static inline double wrenValueToNum(Value value)
|
|
{
|
|
#if WREN_NAN_TAGGING
|
|
return wrenDoubleFromBits(value);
|
|
#else
|
|
return value.as.num;
|
|
#endif
|
|
}
|
|
|
|
// Converts [num] to a [Value].
|
|
static inline Value wrenNumToValue(double num)
|
|
{
|
|
#if WREN_NAN_TAGGING
|
|
return wrenDoubleToBits(num);
|
|
#else
|
|
Value value;
|
|
value.type = VAL_NUM;
|
|
value.as.num = num;
|
|
return value;
|
|
#endif
|
|
}
|
|
|
|
// Tells whether [arg] may be used as a map key. Accepted kinds are bools,
// classes, null, numbers, ranges, and strings; anything else is rejected.
static inline bool wrenMapIsValidKey(Value arg)
{
  if (IS_BOOL(arg)) return true;
  if (IS_CLASS(arg)) return true;
  if (IS_NULL(arg)) return true;
  if (IS_NUM(arg)) return true;
  if (IS_RANGE(arg)) return true;
  if (IS_STRING(arg)) return true;

  return false;
}
|
|
|
|
#endif
|
|
// End file "wren_value.h"
|
|
// Begin file "wren_vm.h"
|
|
#ifndef wren_vm_h
|
|
#define wren_vm_h
|
|
|
|
// Begin file "wren_compiler.h"
|
|
#ifndef wren_compiler_h
|
|
#define wren_compiler_h
|
|
|
|
|
|
typedef struct sCompiler Compiler;
|
|
|
|
// This module defines the compiler for Wren. It takes a string of source code
|
|
// and lexes, parses, and compiles it. Wren uses a single-pass compiler. It
|
|
// does not build an actual AST during parsing and then consume that to
|
|
// generate code. Instead, the parser directly emits bytecode.
|
|
//
|
|
// This forces a few restrictions on the grammar and semantics of the language.
|
|
// Things like forward references and arbitrary lookahead are much harder. We
|
|
// get a lot in return for that, though.
|
|
//
|
|
// The implementation is much simpler since we don't need to define a bunch of
|
|
// AST data structures. More so, we don't have to deal with managing memory for
|
|
// AST objects. The compiler does almost no dynamic allocation while running.
|
|
//
|
|
// Compilation is also faster since we don't create a bunch of temporary data
|
|
// structures and destroy them after generating code.
|
|
|
|
// Compiles [source], a string of Wren source code located in [module], to an
|
|
// [ObjFn] that will execute that code when invoked. Returns `NULL` if the
|
|
// source contains any syntax errors.
|
|
//
|
|
// If [isExpression] is `true`, [source] should be a single expression, and
|
|
// this compiles it to a function that evaluates and returns that expression.
|
|
// Otherwise, [source] should be a series of top level statements.
|
|
//
|
|
// If [printErrors] is `true`, any compile errors are output to stderr.
|
|
// Otherwise, they are silently discarded.
|
|
ObjFn* wrenCompile(WrenVM* vm, ObjModule* module, const char* source,
|
|
bool isExpression, bool printErrors);
|
|
|
|
// When a class is defined, its superclass is not known until runtime since
|
|
// class definitions are just imperative statements. Most of the bytecode for a
|
|
// method doesn't care, but there are two places where it matters:
|
|
//
|
|
// - To load or store a field, we need to know the index of the field in the
|
|
// instance's field array. We need to adjust this so that subclass fields
|
|
// are positioned after superclass fields, and we don't know this until the
|
|
// superclass is known.
|
|
//
|
|
// - Superclass calls need to know which superclass to dispatch to.
|
|
//
|
|
// We could handle this dynamically, but that adds overhead. Instead, when a
|
|
// method is bound, we walk the bytecode for the function and patch it up.
|
|
void wrenBindMethodCode(ObjClass* classObj, ObjFn* fn);
|
|
|
|
// Reaches all of the heap-allocated objects in use by [compiler] (and all of
|
|
// its parents) so that they are not collected by the GC.
|
|
void wrenMarkCompiler(WrenVM* vm, Compiler* compiler);
|
|
|
|
#endif
|
|
// End file "wren_compiler.h"
|
|
|
|
// The maximum number of temporary objects that can be made visible to the GC
|
|
// at one time.
|
|
#define WREN_MAX_TEMP_ROOTS 8
|
|
|
|
// The bytecode instruction set. Each OPCODE(name, effect) entry below becomes
// the enumerator CODE_<name>; the second argument is ignored here.
typedef enum
{
  #define OPCODE(name, effect) CODE_##name,
// Begin file "wren_opcodes.h"
// The list of bytecode instructions understood by the VM, written as a series
// of OPCODE() invocations. Whoever includes this list must define OPCODE()
// first. (See: http://en.wikipedia.org/wiki/X_Macro for more.)
//
// OPCODE() takes the instruction's name followed by its "stack effect": how
// much the instruction changes the size of the stack. An effect of 1 means a
// value is pushed and the stack grows by one; -2 means two values are popped.
//
// The order of the entries determines the order of the dispatch table in the
// VM's interpreter loop, which influences caching and therefore overall
// performance. Run the benchmarks before reordering anything.

// Loads the constant at index [arg].
OPCODE(CONSTANT, 1)

// Pushes null.
OPCODE(NULL, 1)

// Pushes false.
OPCODE(FALSE, 1)

// Pushes true.
OPCODE(TRUE, 1)

// Push the value held in the local slot named by the instruction.
OPCODE(LOAD_LOCAL_0, 1)
OPCODE(LOAD_LOCAL_1, 1)
OPCODE(LOAD_LOCAL_2, 1)
OPCODE(LOAD_LOCAL_3, 1)
OPCODE(LOAD_LOCAL_4, 1)
OPCODE(LOAD_LOCAL_5, 1)
OPCODE(LOAD_LOCAL_6, 1)
OPCODE(LOAD_LOCAL_7, 1)
OPCODE(LOAD_LOCAL_8, 1)

// Note: the compiler relies on each _STORE instruction below coming directly
// after its matching _LOAD instruction.

// Pushes the value held in local slot [arg].
OPCODE(LOAD_LOCAL, 1)

// Writes the top of the stack to local slot [arg] without popping it.
OPCODE(STORE_LOCAL, 0)

// Pushes the value held in upvalue [arg].
OPCODE(LOAD_UPVALUE, 1)

// Writes the top of the stack to upvalue [arg] without popping it.
OPCODE(STORE_UPVALUE, 0)

// Pushes the value of the top-level variable in slot [arg].
OPCODE(LOAD_MODULE_VAR, 1)

// Writes the top of the stack to top-level variable slot [arg] without
// popping it.
OPCODE(STORE_MODULE_VAR, 0)

// Pushes the field in slot [arg] of the current function's receiver. Emitted
// for plain field reads on "this" inside methods. Faster than the more
// general CODE_LOAD_FIELD instruction.
OPCODE(LOAD_FIELD_THIS, 1)

// Writes the top of the stack to field slot [arg] of the receiver without
// popping the value. Faster than the more general CODE_STORE_FIELD
// instruction.
OPCODE(STORE_FIELD_THIS, 0)

// Pops an instance and pushes the field in slot [arg] of that instance.
OPCODE(LOAD_FIELD, 0)

// Pops an instance and writes the new top of the stack to field slot [arg] of
// that instance. The value itself is not popped.
OPCODE(STORE_FIELD, -1)

// Pops the top of the stack and throws it away.
OPCODE(POP, -1)

// Invoke the method with symbol [arg]. The number in the name is the argument
// count, not counting the receiver.
OPCODE(CALL_0, 0)
OPCODE(CALL_1, -1)
OPCODE(CALL_2, -2)
OPCODE(CALL_3, -3)
OPCODE(CALL_4, -4)
OPCODE(CALL_5, -5)
OPCODE(CALL_6, -6)
OPCODE(CALL_7, -7)
OPCODE(CALL_8, -8)
OPCODE(CALL_9, -9)
OPCODE(CALL_10, -10)
OPCODE(CALL_11, -11)
OPCODE(CALL_12, -12)
OPCODE(CALL_13, -13)
OPCODE(CALL_14, -14)
OPCODE(CALL_15, -15)
OPCODE(CALL_16, -16)

// Invoke a superclass method with symbol [arg]. The number in the name is the
// argument count, not counting the receiver.
OPCODE(SUPER_0, 0)
OPCODE(SUPER_1, -1)
OPCODE(SUPER_2, -2)
OPCODE(SUPER_3, -3)
OPCODE(SUPER_4, -4)
OPCODE(SUPER_5, -5)
OPCODE(SUPER_6, -6)
OPCODE(SUPER_7, -7)
OPCODE(SUPER_8, -8)
OPCODE(SUPER_9, -9)
OPCODE(SUPER_10, -10)
OPCODE(SUPER_11, -11)
OPCODE(SUPER_12, -12)
OPCODE(SUPER_13, -13)
OPCODE(SUPER_14, -14)
OPCODE(SUPER_15, -15)
OPCODE(SUPER_16, -16)

// Moves the instruction pointer [arg] forward.
OPCODE(JUMP, 0)

// Moves the instruction pointer [arg] backward.
OPCODE(LOOP, 0)

// Pops a value; if it is not truthy, moves the instruction pointer [arg]
// forward.
OPCODE(JUMP_IF, -1)

// If the top of the stack is false, jumps [arg] forward. Otherwise pops it
// and carries on.
OPCODE(AND, -1)

// If the top of the stack is non-false, jumps [arg] forward. Otherwise pops
// it and carries on.
OPCODE(OR, -1)

// Closes the upvalue for the local on top of the stack, then pops that local.
OPCODE(CLOSE_UPVALUE, -1)

// Leaves the current function, returning the value on top of the stack.
OPCODE(RETURN, 0)

// Builds a closure for the function stored at [arg] in the constant table.
//
// The function argument is followed by two further arguments per upvalue:
// first a flag that is true when the captured variable is a local (rather
// than an upvalue), then the index of the local or upvalue being captured.
//
// The resulting closure is pushed.
OPCODE(CLOSURE, 1)

// Instantiates a class.
//
// Expects the class object in slot zero and overwrites it with the new,
// uninitialized instance. Only the compiler-generated constructor metaclass
// methods emit this opcode.
OPCODE(CONSTRUCT, 0)

// Instantiates a foreign class.
//
// Expects the class object in slot zero and overwrites it with the new,
// uninitialized instance. Only the compiler-generated constructor metaclass
// methods emit this opcode.
OPCODE(FOREIGN_CONSTRUCT, 0)

// Creates a class. The superclass is on top of the stack with the class name
// string beneath it. Byte [arg] gives the number of fields in the class.
OPCODE(CLASS, -1)

// Finishes a class.
// At this point the stack holds the class and the ClassAttributes (or null).
OPCODE(END_CLASS, -2)

// Creates a foreign class. The superclass is on top of the stack with the
// class name string beneath it.
OPCODE(FOREIGN_CLASS, -1)

// Defines a method for symbol [arg]. The class that receives the method is
// popped first, then the function that forms the body.
//
// For a foreign method the "function" is a string identifying the foreign
// method; otherwise it is a function or closure.
OPCODE(METHOD_INSTANCE, -2)

// Defines a method for symbol [arg]. The class whose metaclass receives the
// method is popped first, then the function that forms the body.
//
// For a foreign method the "function" is a string identifying the foreign
// method; otherwise it is a function or closure.
OPCODE(METHOD_STATIC, -2)

// Runs at the end of a module's body. Pushes NULL as the "return value" of
// the import statement and records the module as the most recently imported
// one.
OPCODE(END_MODULE, 1)

// Imports the module whose name is the string at [arg] in the constant table.
//
// Pushes null so that the fiber of the imported module has something to
// replace with a dummy value when it returns. (Fibers always return a value
// when resuming a caller.)
OPCODE(IMPORT_MODULE, 1)

// Imports a variable from the most recently imported module. The variable's
// name is at [arg] in the constant table. Pushes the variable's value.
OPCODE(IMPORT_VARIABLE, 1)

// Pseudo-instruction marking the end of the bytecode. A `CODE_RETURN` should
// always come right before it, so it is never actually executed.
OPCODE(END, 0)
// End file "wren_opcodes.h"
  #undef OPCODE
} Code;
|
|
|
|
// A handle to a value, basically just a linked list of extra GC roots.
|
|
//
|
|
// Note that even non-heap-allocated values can be stored here.
|
|
struct WrenHandle
|
|
{
|
|
Value value;
|
|
|
|
WrenHandle* prev;
|
|
WrenHandle* next;
|
|
};
|
|
|
|
struct WrenVM
|
|
{
|
|
ObjClass* boolClass;
|
|
ObjClass* classClass;
|
|
ObjClass* fiberClass;
|
|
ObjClass* fnClass;
|
|
ObjClass* listClass;
|
|
ObjClass* mapClass;
|
|
ObjClass* nullClass;
|
|
ObjClass* numClass;
|
|
ObjClass* objectClass;
|
|
ObjClass* rangeClass;
|
|
ObjClass* stringClass;
|
|
|
|
// The fiber that is currently running.
|
|
ObjFiber* fiber;
|
|
|
|
// The loaded modules. Each key is an ObjString (except for the main module,
|
|
// whose key is null) for the module's name and the value is the ObjModule
|
|
// for the module.
|
|
ObjMap* modules;
|
|
|
|
// The most recently imported module. More specifically, the module whose
|
|
// code has most recently finished executing.
|
|
//
|
|
// Not treated like a GC root since the module is already in [modules].
|
|
ObjModule* lastModule;
|
|
|
|
// Memory management data:
|
|
|
|
// The number of bytes that are known to be currently allocated. Includes all
|
|
// memory that was proven live after the last GC, as well as any new bytes
|
|
// that were allocated since then. Does *not* include bytes for objects that
|
|
// were freed since the last GC.
|
|
size_t bytesAllocated;
|
|
|
|
// The number of total allocated bytes that will trigger the next GC.
|
|
size_t nextGC;
|
|
|
|
// The first object in the linked list of all currently allocated objects.
|
|
Obj* first;
|
|
|
|
// The "gray" set for the garbage collector. This is the stack of unprocessed
|
|
// objects while a garbage collection pass is in process.
|
|
Obj** gray;
|
|
int grayCount;
|
|
int grayCapacity;
|
|
|
|
// The list of temporary roots. This is for temporary or new objects that are
|
|
// not otherwise reachable but should not be collected.
|
|
//
|
|
// They are organized as a stack of pointers stored in this array. This
|
|
// implies that temporary roots need to have stack semantics: only the most
|
|
// recently pushed object can be released.
|
|
Obj* tempRoots[WREN_MAX_TEMP_ROOTS];
|
|
|
|
int numTempRoots;
|
|
|
|
// Pointer to the first node in the linked list of active handles or NULL if
|
|
// there are none.
|
|
WrenHandle* handles;
|
|
|
|
// Pointer to the bottom of the range of stack slots available for use from
|
|
// the C API. During a foreign method, this will be in the stack of the fiber
|
|
// that is executing a method.
|
|
//
|
|
// If not in a foreign method, this is initially NULL. If the user requests
|
|
// slots by calling wrenEnsureSlots(), a stack is created and this is
|
|
// initialized.
|
|
Value* apiStack;
|
|
|
|
WrenConfiguration config;
|
|
|
|
// Compiler and debugger data:
|
|
|
|
// The compiler that is currently compiling code. This is used so that heap
|
|
// allocated objects used by the compiler can be found if a GC is kicked off
|
|
// in the middle of a compile.
|
|
Compiler* compiler;
|
|
|
|
// There is a single global symbol table for all method names on all classes.
|
|
// Method calls are dispatched directly by index in this table.
|
|
SymbolTable methodNames;
|
|
};
|
|
|
|
// A generic allocation function that handles all explicit memory management.
|
|
// It's used like so:
|
|
//
|
|
// - To allocate new memory, [memory] is NULL and [oldSize] is zero. It should
|
|
// return the allocated memory or NULL on failure.
|
|
//
|
|
// - To attempt to grow an existing allocation, [memory] is the memory,
|
|
// [oldSize] is its previous size, and [newSize] is the desired size.
|
|
// It should return [memory] if it was able to grow it in place, or a new
|
|
// pointer if it had to move it.
|
|
//
|
|
// - To shrink memory, [memory], [oldSize], and [newSize] are the same as above
|
|
// but it will always return [memory].
|
|
//
|
|
// - To free memory, [memory] will be the memory to free and [newSize] and
|
|
// [oldSize] will be zero. It should return NULL.
|
|
void* wrenReallocate(WrenVM* vm, void* memory, size_t oldSize, size_t newSize);
|
|
|
|
// Invoke the finalizer for the foreign object referenced by [foreign].
|
|
void wrenFinalizeForeign(WrenVM* vm, ObjForeign* foreign);
|
|
|
|
// Creates a new [WrenHandle] for [value].
|
|
WrenHandle* wrenMakeHandle(WrenVM* vm, Value value);
|
|
|
|
// Compiles [source] in the context of [module] and wraps the result in a
// closure that can be invoked to run it.
|
|
//
|
|
// Returns NULL if a compile error occurred.
|
|
ObjClosure* wrenCompileSource(WrenVM* vm, const char* module,
|
|
const char* source, bool isExpression,
|
|
bool printErrors);
|
|
|
|
// Looks up a variable from a previously-loaded module.
|
|
//
|
|
// Aborts the current fiber if the module or variable could not be found.
|
|
Value wrenGetModuleVariable(WrenVM* vm, Value moduleName, Value variableName);
|
|
|
|
// Returns the value of the module-level variable named [name] in [module].
|
|
Value wrenFindVariable(WrenVM* vm, ObjModule* module, const char* name);
|
|
|
|
// Adds a new implicitly declared top-level variable named [name] to [module]
|
|
// based on a use site occurring on [line].
|
|
//
|
|
// Does not check to see if a variable with that name is already declared or
|
|
// defined. Returns the symbol for the new variable or -2 if there are too many
|
|
// variables defined.
|
|
int wrenDeclareVariable(WrenVM* vm, ObjModule* module, const char* name,
|
|
size_t length, int line);
|
|
|
|
// Adds a new top-level variable named [name] to [module], and optionally
|
|
// populates line with the line of the implicit first use (line can be NULL).
|
|
//
|
|
// Returns the symbol for the new variable, -1 if a variable with the given name
|
|
// is already defined, or -2 if there are too many variables defined.
|
|
// Returns -3 if this is a top-level lowercase variable (localname) that was
|
|
// used before being defined.
|
|
int wrenDefineVariable(WrenVM* vm, ObjModule* module, const char* name,
|
|
size_t length, Value value, int* line);
|
|
|
|
// Pushes [closure] onto [fiber]'s callstack to invoke it. Expects [numArgs]
|
|
// arguments (including the receiver) to be on the top of the stack already.
|
|
static inline void wrenCallFunction(WrenVM* vm, ObjFiber* fiber,
|
|
ObjClosure* closure, int numArgs)
|
|
{
|
|
// Grow the call frame array if needed.
|
|
if (fiber->numFrames + 1 > fiber->frameCapacity)
|
|
{
|
|
int max = fiber->frameCapacity * 2;
|
|
fiber->frames = (CallFrame*)wrenReallocate(vm, fiber->frames,
|
|
sizeof(CallFrame) * fiber->frameCapacity, sizeof(CallFrame) * max);
|
|
fiber->frameCapacity = max;
|
|
}
|
|
|
|
// Grow the stack if needed.
|
|
int stackSize = (int)(fiber->stackTop - fiber->stack);
|
|
int needed = stackSize + closure->fn->maxSlots;
|
|
wrenEnsureStack(vm, fiber, needed);
|
|
|
|
wrenAppendCallFrame(vm, fiber, closure, fiber->stackTop - numArgs);
|
|
}
|
|
|
|
// Marks [obj] as a GC root so that it doesn't get collected.
|
|
void wrenPushRoot(WrenVM* vm, Obj* obj);
|
|
|
|
// Removes the most recently pushed temporary root.
|
|
void wrenPopRoot(WrenVM* vm);
|
|
|
|
// Returns the class of [value].
|
|
//
|
|
// Defined here instead of in wren_value.h because it's critical that this be
|
|
// inlined. That means it must be defined in the header, but the wren_value.h
|
|
// header doesn't have a full definitely of WrenVM yet.
|
|
static inline ObjClass* wrenGetClassInline(WrenVM* vm, Value value)
|
|
{
|
|
if (IS_NUM(value)) return vm->numClass;
|
|
if (IS_OBJ(value)) return AS_OBJ(value)->classObj;
|
|
|
|
#if WREN_NAN_TAGGING
|
|
switch (GET_TAG(value))
|
|
{
|
|
case TAG_FALSE: return vm->boolClass; break;
|
|
case TAG_NAN: return vm->numClass; break;
|
|
case TAG_NULL: return vm->nullClass; break;
|
|
case TAG_TRUE: return vm->boolClass; break;
|
|
case TAG_UNDEFINED: UNREACHABLE();
|
|
}
|
|
#else
|
|
switch (value.type)
|
|
{
|
|
case VAL_FALSE: return vm->boolClass;
|
|
case VAL_NULL: return vm->nullClass;
|
|
case VAL_NUM: return vm->numClass;
|
|
case VAL_TRUE: return vm->boolClass;
|
|
case VAL_OBJ: return AS_OBJ(value)->classObj;
|
|
case VAL_UNDEFINED: UNREACHABLE();
|
|
}
|
|
#endif
|
|
|
|
UNREACHABLE();
|
|
return NULL;
|
|
}
|
|
|
|
// Tells whether [name] is a local variable name, which is the case when its
// first character is a lowercase letter in the range 'a' to 'z'.
static inline bool wrenIsLocalName(const char* name)
{
  char first = name[0];

  if (first < 'a') return false;
  return first <= 'z';
}
|
|
|
|
// Tells whether [value] counts as falsy. Only false and null do.
static inline bool wrenIsFalsyValue(Value value)
{
  if (IS_FALSE(value)) return true;
  return IS_NULL(value);
}
|
|
|
|
#endif
|
|
// End file "wren_vm.h"
|
|
|
|
// Prints the stack trace for the current fiber.
|
|
//
|
|
// Used when a fiber throws a runtime error which is not caught.
|
|
void wrenDebugPrintStackTrace(WrenVM* vm);
|
|
|
|
// The "dump" functions are used for debugging Wren itself. Normal code paths
|
|
// will not call them unless one of the various DEBUG_ flags is enabled.
|
|
|
|
// Prints a representation of [value] to stdout.
|
|
void wrenDumpValue(Value value);
|
|
|
|
// Prints a representation of the bytecode for [fn] at instruction [i].
|
|
int wrenDumpInstruction(WrenVM* vm, ObjFn* fn, int i);
|
|
|
|
// Prints the disassembled code for [fn] to stdout.
|
|
void wrenDumpCode(WrenVM* vm, ObjFn* fn);
|
|
|
|
// Prints the contents of the current stack for [fiber] to stdout.
|
|
void wrenDumpStack(ObjFiber* fiber);
|
|
|
|
#endif
|
|
// End file "wren_debug.h"
|
|
// Begin file "wren_debug.c"
|
|
#include <stdio.h>
|
|
|
|
|
|
// Reports the current fiber's error through the host's error callback,
// followed by one WREN_ERROR_STACK_TRACE entry per reportable call frame,
// innermost frame first.
void wrenDebugPrintStackTrace(WrenVM* vm)
{
  // Bail if the host doesn't enable printing errors.
  if (vm->config.errorFn == NULL) return;

  ObjFiber* fiber = vm->fiber;

  // TODO: Print something a little useful for non-string errors. Maybe the
  // name of the error's class?
  const char* message = IS_STRING(fiber->error) ? AS_CSTRING(fiber->error)
                                                : "[error object]";
  vm->config.errorFn(vm, WREN_ERROR_RUNTIME, NULL, -1, message);

  int frameIndex = fiber->numFrames;
  while (frameIndex-- > 0)
  {
    CallFrame* frame = &fiber->frames[frameIndex];
    ObjFn* fn = frame->closure->fn;

    // A NULL module marks a stub function used for calling methods from the
    // C API. A NULL module name marks the built-in core module, which is
    // omitted so that users aren't shown which parts of core are written in
    // C and which in Wren.
    if (fn->module == NULL || fn->module->name == NULL) continue;

    // -1 because IP has advanced past the instruction that it just executed.
    int line = fn->debug->sourceLines.data[frame->ip - fn->code.data - 1];
    vm->config.errorFn(vm, WREN_ERROR_STACK_TRACE,
                       fn->module->name->value, line, fn->debug->name);
  }
}
|
|
|
|
// Prints a short, bracketed description of the heap object [obj] to stdout.
//
// Classes are printed with their name, strings are printed as their raw
// contents, and every other known object type is printed as its kind plus
// its address. Unrecognized type values are printed numerically.
static void dumpObject(Obj* obj)
{
  // The "%p" conversion is only defined for `void*` arguments, so convert
  // once up front instead of passing an `Obj*` through the varargs.
  void* address = (void*)obj;

  switch (obj->type)
  {
    case OBJ_CLASS:
      printf("[class %s %p]", ((ObjClass*)obj)->name->value, address);
      break;
    case OBJ_CLOSURE:  printf("[closure %p]", address); break;
    case OBJ_FIBER:    printf("[fiber %p]", address); break;
    case OBJ_FN:       printf("[fn %p]", address); break;
    case OBJ_FOREIGN:  printf("[foreign %p]", address); break;
    case OBJ_INSTANCE: printf("[instance %p]", address); break;
    case OBJ_LIST:     printf("[list %p]", address); break;
    case OBJ_MAP:      printf("[map %p]", address); break;
    case OBJ_MODULE:   printf("[module %p]", address); break;
    case OBJ_RANGE:    printf("[range %p]", address); break;
    case OBJ_STRING:   printf("%s", ((ObjString*)obj)->value); break;
    case OBJ_UPVALUE:  printf("[upvalue %p]", address); break;
    // "%d" expects an int; the enum's underlying type is implementation
    // defined, so cast explicitly.
    default: printf("[unknown object %d]", (int)obj->type); break;
  }
}
|
|
|
|
// Prints a representation of [value] to stdout: numbers with "%.14g",
// objects via dumpObject(), and the singleton values by name.
void wrenDumpValue(Value value)
{
#if WREN_NAN_TAGGING
  if (IS_NUM(value))
  {
    printf("%.14g", AS_NUM(value));
    return;
  }

  if (IS_OBJ(value))
  {
    dumpObject(AS_OBJ(value));
    return;
  }

  // Neither a number nor an object, so it must be a singleton value.
  switch (GET_TAG(value))
  {
    case TAG_FALSE: printf("false"); break;
    case TAG_NAN:   printf("NaN"); break;
    case TAG_NULL:  printf("null"); break;
    case TAG_TRUE:  printf("true"); break;
    case TAG_UNDEFINED: UNREACHABLE();
  }
#else
  switch (value.type)
  {
    case VAL_FALSE: printf("false"); break;
    case VAL_NULL:  printf("null"); break;
    case VAL_NUM:   printf("%.14g", AS_NUM(value)); break;
    case VAL_TRUE:  printf("true"); break;
    case VAL_OBJ:   dumpObject(AS_OBJ(value)); break;
    case VAL_UNDEFINED: UNREACHABLE();
  }
#endif
}
|
|
|
|
// Prints the instruction at byte offset [i] of [fn]'s bytecode to stdout,
// together with its operands.
//
// If [lastLine] is non-NULL, the source line prefix is printed only when the
// instruction's line differs from `*lastLine`, and `*lastLine` is then
// updated. If it is NULL, the line prefix is always printed.
//
// Returns the number of bytes the instruction occupies (opcode plus
// operands), or -1 if the instruction is CODE_END.
static int dumpInstruction(WrenVM* vm, ObjFn* fn, int i, int* lastLine)
{
  int start = i;
  uint8_t* bytecode = fn->code.data;
  Code code = (Code)bytecode[i];

  int line = fn->debug->sourceLines.data[i];
  if (lastLine == NULL || *lastLine != line)
  {
    printf("%4d:", line);
    if (lastLine != NULL) *lastLine = line;
  }
  else
  {
    printf(" ");
  }

  // From here on [i] points just past the opcode, at its first operand.
  printf(" %04d ", i++);

  #define READ_BYTE() (bytecode[i++])
  #define READ_SHORT() (i += 2, (bytecode[i - 2] << 8) | bytecode[i - 1])

  #define BYTE_INSTRUCTION(name)                                               \
      printf("%-16s %5d\n", name, READ_BYTE());                                \
      break

  switch (code)
  {
    case CODE_CONSTANT:
    {
      int constant = READ_SHORT();
      printf("%-16s %5d '", "CONSTANT", constant);
      wrenDumpValue(fn->constants.data[constant]);
      printf("'\n");
      break;
    }

    case CODE_NULL:  printf("NULL\n"); break;
    case CODE_FALSE: printf("FALSE\n"); break;
    case CODE_TRUE:  printf("TRUE\n"); break;

    // The LOAD_LOCAL_n opcodes are consecutive, so the slot number is the
    // distance from CODE_LOAD_LOCAL_0.
    case CODE_LOAD_LOCAL_0:
    case CODE_LOAD_LOCAL_1:
    case CODE_LOAD_LOCAL_2:
    case CODE_LOAD_LOCAL_3:
    case CODE_LOAD_LOCAL_4:
    case CODE_LOAD_LOCAL_5:
    case CODE_LOAD_LOCAL_6:
    case CODE_LOAD_LOCAL_7:
    case CODE_LOAD_LOCAL_8:
      printf("LOAD_LOCAL_%d\n", (int)code - (int)CODE_LOAD_LOCAL_0);
      break;

    case CODE_LOAD_LOCAL:    BYTE_INSTRUCTION("LOAD_LOCAL");
    case CODE_STORE_LOCAL:   BYTE_INSTRUCTION("STORE_LOCAL");
    case CODE_LOAD_UPVALUE:  BYTE_INSTRUCTION("LOAD_UPVALUE");
    case CODE_STORE_UPVALUE: BYTE_INSTRUCTION("STORE_UPVALUE");

    case CODE_LOAD_MODULE_VAR:
    {
      int slot = READ_SHORT();
      printf("%-16s %5d '%s'\n", "LOAD_MODULE_VAR", slot,
             fn->module->variableNames.data[slot]->value);
      break;
    }

    case CODE_STORE_MODULE_VAR:
    {
      int slot = READ_SHORT();
      printf("%-16s %5d '%s'\n", "STORE_MODULE_VAR", slot,
             fn->module->variableNames.data[slot]->value);
      break;
    }

    case CODE_LOAD_FIELD_THIS:  BYTE_INSTRUCTION("LOAD_FIELD_THIS");
    case CODE_STORE_FIELD_THIS: BYTE_INSTRUCTION("STORE_FIELD_THIS");
    case CODE_LOAD_FIELD:       BYTE_INSTRUCTION("LOAD_FIELD");
    case CODE_STORE_FIELD:      BYTE_INSTRUCTION("STORE_FIELD");

    case CODE_POP: printf("POP\n"); break;

    case CODE_CALL_0:
    case CODE_CALL_1:
    case CODE_CALL_2:
    case CODE_CALL_3:
    case CODE_CALL_4:
    case CODE_CALL_5:
    case CODE_CALL_6:
    case CODE_CALL_7:
    case CODE_CALL_8:
    case CODE_CALL_9:
    case CODE_CALL_10:
    case CODE_CALL_11:
    case CODE_CALL_12:
    case CODE_CALL_13:
    case CODE_CALL_14:
    case CODE_CALL_15:
    case CODE_CALL_16:
    {
      // The argument count is encoded in the opcode itself.
      int numArgs = bytecode[i - 1] - CODE_CALL_0;
      int symbol = READ_SHORT();
      printf("CALL_%-11d %5d '%s'\n", numArgs, symbol,
             vm->methodNames.data[symbol]->value);
      break;
    }

    case CODE_SUPER_0:
    case CODE_SUPER_1:
    case CODE_SUPER_2:
    case CODE_SUPER_3:
    case CODE_SUPER_4:
    case CODE_SUPER_5:
    case CODE_SUPER_6:
    case CODE_SUPER_7:
    case CODE_SUPER_8:
    case CODE_SUPER_9:
    case CODE_SUPER_10:
    case CODE_SUPER_11:
    case CODE_SUPER_12:
    case CODE_SUPER_13:
    case CODE_SUPER_14:
    case CODE_SUPER_15:
    case CODE_SUPER_16:
    {
      // The argument count is encoded in the opcode itself.
      int numArgs = bytecode[i - 1] - CODE_SUPER_0;
      int symbol = READ_SHORT();
      int superclass = READ_SHORT();
      printf("SUPER_%-10d %5d '%s' %5d\n", numArgs, symbol,
             vm->methodNames.data[symbol]->value, superclass);
      break;
    }

    case CODE_JUMP:
    {
      int offset = READ_SHORT();
      printf("%-16s %5d to %d\n", "JUMP", offset, i + offset);
      break;
    }

    case CODE_LOOP:
    {
      // Loops jump backwards, so the target is below the current offset.
      int offset = READ_SHORT();
      printf("%-16s %5d to %d\n", "LOOP", offset, i - offset);
      break;
    }

    case CODE_JUMP_IF:
    {
      int offset = READ_SHORT();
      printf("%-16s %5d to %d\n", "JUMP_IF", offset, i + offset);
      break;
    }

    case CODE_AND:
    {
      int offset = READ_SHORT();
      printf("%-16s %5d to %d\n", "AND", offset, i + offset);
      break;
    }

    case CODE_OR:
    {
      int offset = READ_SHORT();
      printf("%-16s %5d to %d\n", "OR", offset, i + offset);
      break;
    }

    case CODE_CLOSE_UPVALUE: printf("CLOSE_UPVALUE\n"); break;
    case CODE_RETURN:        printf("RETURN\n"); break;

    case CODE_CLOSURE:
    {
      int constant = READ_SHORT();
      printf("%-16s %5d ", "CLOSURE", constant);
      wrenDumpValue(fn->constants.data[constant]);
      printf(" ");

      // The closure operand is followed by an (isLocal, index) byte pair for
      // each upvalue of the function being closed over.
      ObjFn* loadedFn = AS_FN(fn->constants.data[constant]);
      for (int j = 0; j < loadedFn->numUpvalues; j++)
      {
        int isLocal = READ_BYTE();
        int index = READ_BYTE();
        if (j > 0) printf(", ");
        printf("%s %d", isLocal ? "local" : "upvalue", index);
      }
      printf("\n");
      break;
    }

    case CODE_CONSTRUCT:         printf("CONSTRUCT\n"); break;
    case CODE_FOREIGN_CONSTRUCT: printf("FOREIGN_CONSTRUCT\n"); break;

    case CODE_CLASS:
    {
      int numFields = READ_BYTE();
      printf("%-16s %5d fields\n", "CLASS", numFields);
      break;
    }

    case CODE_FOREIGN_CLASS: printf("FOREIGN_CLASS\n"); break;
    case CODE_END_CLASS:     printf("END_CLASS\n"); break;

    case CODE_METHOD_INSTANCE:
    {
      int symbol = READ_SHORT();
      printf("%-16s %5d '%s'\n", "METHOD_INSTANCE", symbol,
             vm->methodNames.data[symbol]->value);
      break;
    }

    case CODE_METHOD_STATIC:
    {
      int symbol = READ_SHORT();
      printf("%-16s %5d '%s'\n", "METHOD_STATIC", symbol,
             vm->methodNames.data[symbol]->value);
      break;
    }

    case CODE_END_MODULE:
      printf("END_MODULE\n");
      break;

    case CODE_IMPORT_MODULE:
    {
      int name = READ_SHORT();
      printf("%-16s %5d '", "IMPORT_MODULE", name);
      wrenDumpValue(fn->constants.data[name]);
      printf("'\n");
      break;
    }

    case CODE_IMPORT_VARIABLE:
    {
      int variable = READ_SHORT();
      printf("%-16s %5d '", "IMPORT_VARIABLE", variable);
      wrenDumpValue(fn->constants.data[variable]);
      printf("'\n");
      break;
    }

    case CODE_END:
      printf("END\n");
      break;

    default:
      printf("UNKNOWN! [%d]\n", bytecode[i - 1]);
      break;
  }

  // Return how many bytes this instruction takes, or -1 if it's an END.
  if (code == CODE_END) return -1;
  return i - start;

  // Undefine every helper macro so that none of them leaks into the rest of
  // the translation unit.
  #undef READ_BYTE
  #undef READ_SHORT
  #undef BYTE_INSTRUCTION
}
|
|
|
|
// Prints the instruction at byte offset [i] of [fn] to stdout. No previous
// line is tracked, so the source line prefix is always printed. Returns
// whatever dumpInstruction() returns for that instruction.
int wrenDumpInstruction(WrenVM* vm, ObjFn* fn, int i)
{
  int* noLineTracking = NULL;
  return dumpInstruction(vm, fn, i, noLineTracking);
}
|
|
|
|
// Prints a header naming [fn]'s module and function, then every instruction
// of [fn] up to and including its END marker, followed by a blank line.
void wrenDumpCode(WrenVM* vm, ObjFn* fn)
{
  // A module without a name is labelled "<core>".
  const char* moduleName = "<core>";
  if (fn->module->name != NULL) moduleName = fn->module->name->value;
  printf("%s: %s\n", moduleName, fn->debug->name);

  int lastLine = -1;
  int position = 0;
  int size;

  // dumpInstruction() reports -1 once it has printed the END instruction.
  while ((size = dumpInstruction(vm, fn, position, &lastLine)) != -1)
  {
    position += size;
  }

  printf("\n");
}
|
|
|
|
// Prints the address of [fiber] followed by every value currently on its
// stack, from the bottom slot up to (but not including) the stack top, each
// followed by " | ", and then a newline.
void wrenDumpStack(ObjFiber* fiber)
{
  // The "%p" conversion is only defined for `void*` arguments.
  printf("(fiber %p) ", (void*)fiber);

  for (Value* slot = fiber->stack; slot < fiber->stackTop; slot++)
  {
    wrenDumpValue(*slot);
    printf(" | ");
  }

  printf("\n");
}
|
|
// End file "wren_debug.c"
|
|
// Begin file "wren_compiler.c"
|
|
#include <errno.h>
|
|
#include <stdbool.h>
|
|
#include <stdio.h>
|
|
#include <string.h>
|
|
|
|
|
|
#if WREN_DEBUG_DUMP_COMPILED_CODE
|
|
#endif
|
|
|
|
// This is written in bottom-up order, so the tokenization comes first, then
|
|
// parsing/code generation. This minimizes the number of explicit forward
|
|
// declarations needed.
|
|
|
|
// The maximum number of local (i.e. not module level) variables that can be
|
|
// declared in a single function, method, or chunk of top level code. This is
|
|
// the maximum number of variables in scope at one time, and spans block scopes.
|
|
//
|
|
// Note that this limitation is also explicit in the bytecode. Since
|
|
// `CODE_LOAD_LOCAL` and `CODE_STORE_LOCAL` use a single argument byte to
|
|
// identify the local, only 256 can be in scope at one time.
|
|
#define MAX_LOCALS 256
|
|
|
|
// The maximum number of upvalues (i.e. variables from enclosing functions)
|
|
// that a function can close over.
|
|
#define MAX_UPVALUES 256
|
|
|
|
// The maximum number of distinct constants that a function can contain. This
|
|
// value is explicit in the bytecode since `CODE_CONSTANT` only takes a single
|
|
// two-byte argument.
|
|
#define MAX_CONSTANTS (1 << 16)
|
|
|
|
// The maximum distance a CODE_JUMP or CODE_JUMP_IF instruction can move the
|
|
// instruction pointer.
|
|
#define MAX_JUMP (1 << 16)
|
|
|
|
// The maximum depth that interpolation can nest. For example, this string has
|
|
// three levels:
|
|
//
|
|
// "outside %(one + "%(two + "%(three)")")"
|
|
#define MAX_INTERPOLATION_NESTING 8
|
|
|
|
// The buffer size used to format a compile error message, excluding the header
|
|
// with the module name and error location. Using a hardcoded buffer for this
|
|
// is kind of hairy, but fortunately we can control what the longest possible
|
|
// message is and handle that. Ideally, we'd use `snprintf()`, but that's not
|
|
// available in standard C++98.
|
|
#define ERROR_MESSAGE_SIZE (80 + MAX_VARIABLE_NAME + 15)
|
|
|
|
// Every kind of token the lexer can produce.
typedef enum
{
  // Grouping and punctuation.
  TOKEN_LEFT_PAREN, TOKEN_RIGHT_PAREN,
  TOKEN_LEFT_BRACKET, TOKEN_RIGHT_BRACKET,
  TOKEN_LEFT_BRACE, TOKEN_RIGHT_BRACE,
  TOKEN_COLON,
  TOKEN_DOT, TOKEN_DOTDOT, TOKEN_DOTDOTDOT,
  TOKEN_COMMA,

  // Operators.
  TOKEN_STAR, TOKEN_SLASH, TOKEN_PERCENT,
  TOKEN_HASH,
  TOKEN_PLUS, TOKEN_MINUS,
  TOKEN_LTLT, TOKEN_GTGT,
  TOKEN_PIPE, TOKEN_PIPEPIPE,
  TOKEN_CARET,
  TOKEN_AMP, TOKEN_AMPAMP,
  TOKEN_BANG, TOKEN_TILDE, TOKEN_QUESTION,
  TOKEN_EQ,
  TOKEN_LT, TOKEN_GT, TOKEN_LTEQ, TOKEN_GTEQ,
  TOKEN_EQEQ, TOKEN_BANGEQ,

  // Keywords.
  TOKEN_BREAK,
  TOKEN_CONTINUE,
  TOKEN_CLASS,
  TOKEN_CONSTRUCT,
  TOKEN_ELSE,
  TOKEN_FALSE,
  TOKEN_FOR,
  TOKEN_FOREIGN,
  TOKEN_IF,
  TOKEN_IMPORT,
  TOKEN_AS,
  TOKEN_IN,
  TOKEN_IS,
  TOKEN_NULL,
  TOKEN_RETURN,
  TOKEN_STATIC,
  TOKEN_SUPER,
  TOKEN_THIS,
  TOKEN_TRUE,
  TOKEN_VAR,
  TOKEN_WHILE,

  // Names and literals.
  TOKEN_FIELD,
  TOKEN_STATIC_FIELD,
  TOKEN_NAME,
  TOKEN_NUMBER,

  // A string literal without any interpolation, or the last section of a
  // string following the last interpolated expression.
  TOKEN_STRING,

  // A portion of a string literal preceding an interpolated expression. This
  // string:
  //
  // "a %(b) c %(d) e"
  //
  // is tokenized to:
  //
  // TOKEN_INTERPOLATION "a "
  // TOKEN_NAME b
  // TOKEN_INTERPOLATION " c "
  // TOKEN_NAME d
  // TOKEN_STRING " e"
  TOKEN_INTERPOLATION,

  TOKEN_LINE,

  TOKEN_ERROR,
  TOKEN_EOF
} TokenType;
|
|
|
|
typedef struct
|
|
{
|
|
TokenType type;
|
|
|
|
// The beginning of the token, pointing directly into the source.
|
|
const char* start;
|
|
|
|
// The length of the token in characters.
|
|
int length;
|
|
|
|
// The 1-based line where the token appears.
|
|
int line;
|
|
|
|
// The parsed value if the token is a literal.
|
|
Value value;
|
|
} Token;
|
|
|
|
typedef struct
|
|
{
|
|
WrenVM* vm;
|
|
|
|
// The module being parsed.
|
|
ObjModule* module;
|
|
|
|
// The source code being parsed.
|
|
const char* source;
|
|
|
|
// The beginning of the currently-being-lexed token in [source].
|
|
const char* tokenStart;
|
|
|
|
// The current character being lexed in [source].
|
|
const char* currentChar;
|
|
|
|
// The 1-based line number of [currentChar].
|
|
int currentLine;
|
|
|
|
// The upcoming token.
|
|
Token next;
|
|
|
|
// The most recently lexed token.
|
|
Token current;
|
|
|
|
// The most recently consumed/advanced token.
|
|
Token previous;
|
|
|
|
// Tracks the lexing state when tokenizing interpolated strings.
|
|
//
|
|
// Interpolated strings make the lexer not strictly regular: we don't know
|
|
// whether a ")" should be treated as a RIGHT_PAREN token or as ending an
|
|
// interpolated expression unless we know whether we are inside a string
|
|
// interpolation and how many unmatched "(" there are. This is particularly
|
|
// complex because interpolation can nest:
|
|
//
|
|
// " %( " %( inner ) " ) "
|
|
//
|
|
// This tracks that state. The parser maintains a stack of ints, one for each
|
|
// level of current interpolation nesting. Each value is the number of
|
|
// unmatched "(" that are waiting to be closed.
|
|
int parens[MAX_INTERPOLATION_NESTING];
|
|
int numParens;
|
|
|
|
// Whether compile errors should be printed to stderr or discarded.
|
|
bool printErrors;
|
|
|
|
// If a syntax or compile error has occurred.
|
|
bool hasError;
|
|
} Parser;
|
|
|
|
// A local variable that is currently in scope in the function being compiled.
typedef struct
{
  // The variable's name. Points directly into the original source string.
  const char* name;

  // Number of characters in [name].
  int length;

  // Scope depth at which the variable was declared. Zero is the outermost
  // scope (a method's parameters, or the first local block in top level
  // code); one is the scope nested inside that, and so on.
  int depth;

  // True if this local is being used as an upvalue.
  bool isUpvalue;
} Local;
|
|
|
|
// Describes one variable captured by the function being compiled.
typedef struct
{
  // True when the captured variable is a local of the enclosing function;
  // false when it is itself an upvalue of the enclosing function.
  bool isLocal;

  // Index of the captured local or upvalue in the enclosing function.
  int index;
} CompilerUpvalue;
|
|
|
|
// Bookkeeping for the loop currently being compiled. Loops nest, so each one
// links to the loop that encloses it.
typedef struct sLoop
{
  // Index of the instruction the loop jumps back to.
  int start;

  // Index of the argument of the CODE_JUMP_IF instruction that exits the
  // loop. Remembered so it can be patched once the end of the loop is known.
  int exitJump;

  // Index of the first instruction of the loop body.
  int body;

  // Depth of the scope(s) that must be exited when a break is hit inside the
  // loop.
  int scopeDepth;

  // The enclosing loop, or NULL for the outermost loop.
  struct sLoop* enclosing;
} Loop;
|
|
|
|
// The syntactic forms a method signature can take.
typedef enum
{
  // A name followed by a (possibly empty) parenthesized parameter list. Also
  // used for binary operators.
  SIG_METHOD,

  // A bare name. Also used for unary operators.
  SIG_GETTER,

  // A name followed by "=".
  SIG_SETTER,

  // A parameter list in square brackets.
  SIG_SUBSCRIPT,

  // A parameter list in square brackets, followed by "=".
  SIG_SUBSCRIPT_SETTER,

  // A constructor initializer function. It has its own signature so that it
  // cannot be invoked directly outside of the constructor on the metaclass.
  SIG_INITIALIZER
} SignatureType;
|
|
|
|
typedef struct
|
|
{
|
|
const char* name;
|
|
int length;
|
|
SignatureType type;
|
|
int arity;
|
|
} Signature;
|
|
|
|
// Bookkeeping information for compiling a class definition.
|
|
typedef struct
|
|
{
|
|
// The name of the class.
|
|
ObjString* name;
|
|
|
|
// Attributes for the class itself
|
|
ObjMap* classAttributes;
|
|
// Attributes for methods in this class
|
|
ObjMap* methodAttributes;
|
|
|
|
// Symbol table for the fields of the class.
|
|
SymbolTable fields;
|
|
|
|
// Symbols for the methods defined by the class. Used to detect duplicate
|
|
// method definitions.
|
|
IntBuffer methods;
|
|
IntBuffer staticMethods;
|
|
|
|
// True if the class being compiled is a foreign class.
|
|
bool isForeign;
|
|
|
|
// True if the current method being compiled is static.
|
|
bool inStatic;
|
|
|
|
// The signature of the method being compiled.
|
|
Signature* signature;
|
|
} ClassInfo;
|
|
|
|
struct sCompiler
|
|
{
|
|
Parser* parser;
|
|
|
|
// The compiler for the function enclosing this one, or NULL if it's the
|
|
// top level.
|
|
struct sCompiler* parent;
|
|
|
|
// The currently in scope local variables.
|
|
Local locals[MAX_LOCALS];
|
|
|
|
// The number of local variables currently in scope.
|
|
int numLocals;
|
|
|
|
// The upvalues that this function has captured from outer scopes. The count
|
|
// of them is stored in [numUpvalues].
|
|
CompilerUpvalue upvalues[MAX_UPVALUES];
|
|
|
|
// The current level of block scope nesting, where zero is no nesting. A -1
|
|
// here means top-level code is being compiled and there is no block scope
|
|
// in effect at all. Any variables declared will be module-level.
|
|
int scopeDepth;
|
|
|
|
// The current number of slots (locals and temporaries) in use.
|
|
//
|
|
// We use this and maxSlots to track the maximum number of additional slots
|
|
// a function may need while executing. When the function is called, the
|
|
// fiber will check to ensure its stack has enough room to cover that worst
|
|
// case and grow the stack if needed.
|
|
//
|
|
// This value here doesn't include parameters to the function. Since those
|
|
// are already pushed onto the stack by the caller and tracked there, we
|
|
// don't need to double count them here.
|
|
int numSlots;
|
|
|
|
// The current innermost loop being compiled, or NULL if not in a loop.
|
|
Loop* loop;
|
|
|
|
// If this is a compiler for a method, keeps track of the class enclosing it.
|
|
ClassInfo* enclosingClass;
|
|
|
|
// The function being compiled.
|
|
ObjFn* fn;
|
|
|
|
// The constants for the function being compiled.
|
|
ObjMap* constants;
|
|
|
|
// Whether or not the compiler is for a constructor initializer
|
|
bool isInitializer;
|
|
|
|
// The number of attributes seen while parsing.
|
|
// We track this separately as compile time attributes
|
|
// are not stored, so we can't rely on attributes->count
|
|
// to enforce an error message when attributes are used
|
|
// anywhere other than methods or classes.
|
|
int numAttributes;
|
|
// Attributes for the next class or method.
|
|
ObjMap* attributes;
|
|
};
|
|
|
|
// The kinds of place in which a variable can be declared.
typedef enum
{
  // A local variable of the current function.
  SCOPE_LOCAL,

  // A local variable declared in an enclosing function.
  SCOPE_UPVALUE,

  // A top-level module variable.
  SCOPE_MODULE
} Scope;
|
|
|
|
// A reference to a variable and the scope where it is defined. This contains
|
|
// enough information to emit correct code to load or store the variable.
|
|
typedef struct
|
|
{
|
|
// The stack slot, upvalue slot, or module symbol defining the variable.
|
|
int index;
|
|
|
|
// Where the variable is declared.
|
|
Scope scope;
|
|
} Variable;
|
|
|
|
// Forward declarations
|
|
static void disallowAttributes(Compiler* compiler);
|
|
static void addToAttributeGroup(Compiler* compiler, Value group, Value key, Value value);
|
|
static void emitClassAttributes(Compiler* compiler, ClassInfo* classInfo);
|
|
static void copyAttributes(Compiler* compiler, ObjMap* into);
|
|
static void copyMethodAttributes(Compiler* compiler, bool isForeign,
|
|
bool isStatic, const char* fullSignature, int32_t length);
|
|
|
|
// The stack effect of each opcode. The index in the array is the opcode, and
|
|
// the value is the stack effect of that instruction.
|
|
static const int stackEffects[] = {
|
|
#define OPCODE(_, effect) effect,
|
|
// Begin file "wren_opcodes.h"
|
|
// This defines the bytecode instructions used by the VM. It does so by invoking
|
|
// an OPCODE() macro which is expected to be defined at the point that this is
|
|
// included. (See: http://en.wikipedia.org/wiki/X_Macro for more.)
|
|
//
|
|
// The first argument is the name of the opcode. The second is its "stack
|
|
// effect" -- the amount that the op code changes the size of the stack. A
|
|
// stack effect of 1 means it pushes a value and the stack grows one larger.
|
|
// -2 means it pops two values, etc.
|
|
//
|
|
// Note that the order of instructions here affects the order of the dispatch
|
|
// table in the VM's interpreter loop. That in turn affects caching which
|
|
// affects overall performance. Take care to run benchmarks if you change the
|
|
// order here.
|
|
|
|
// Load the constant at index [arg].
|
|
OPCODE(CONSTANT, 1)
|
|
|
|
// Push null onto the stack.
|
|
OPCODE(NULL, 1)
|
|
|
|
// Push false onto the stack.
|
|
OPCODE(FALSE, 1)
|
|
|
|
// Push true onto the stack.
|
|
OPCODE(TRUE, 1)
|
|
|
|
// Pushes the value in the given local slot.
|
|
OPCODE(LOAD_LOCAL_0, 1)
|
|
OPCODE(LOAD_LOCAL_1, 1)
|
|
OPCODE(LOAD_LOCAL_2, 1)
|
|
OPCODE(LOAD_LOCAL_3, 1)
|
|
OPCODE(LOAD_LOCAL_4, 1)
|
|
OPCODE(LOAD_LOCAL_5, 1)
|
|
OPCODE(LOAD_LOCAL_6, 1)
|
|
OPCODE(LOAD_LOCAL_7, 1)
|
|
OPCODE(LOAD_LOCAL_8, 1)
|
|
|
|
// Note: The compiler assumes the following _STORE instructions always
|
|
// immediately follow their corresponding _LOAD ones.
|
|
|
|
// Pushes the value in local slot [arg].
|
|
OPCODE(LOAD_LOCAL, 1)
|
|
|
|
// Stores the top of stack in local slot [arg]. Does not pop it.
|
|
OPCODE(STORE_LOCAL, 0)
|
|
|
|
// Pushes the value in upvalue [arg].
|
|
OPCODE(LOAD_UPVALUE, 1)
|
|
|
|
// Stores the top of stack in upvalue [arg]. Does not pop it.
|
|
OPCODE(STORE_UPVALUE, 0)
|
|
|
|
// Pushes the value of the top-level variable in slot [arg].
|
|
OPCODE(LOAD_MODULE_VAR, 1)
|
|
|
|
// Stores the top of stack in top-level variable slot [arg]. Does not pop it.
|
|
OPCODE(STORE_MODULE_VAR, 0)
|
|
|
|
// Pushes the value of the field in slot [arg] of the receiver of the current
|
|
// function. This is used for regular field accesses on "this" directly in
|
|
// methods. This instruction is faster than the more general CODE_LOAD_FIELD
|
|
// instruction.
|
|
OPCODE(LOAD_FIELD_THIS, 1)
|
|
|
|
// Stores the top of the stack in field slot [arg] in the receiver of the
|
|
// current function. Does not pop the value. This instruction is faster than the
|
|
// more general CODE_STORE_FIELD instruction.
|
|
OPCODE(STORE_FIELD_THIS, 0)
|
|
|
|
// Pops an instance and pushes the value of the field in slot [arg] of it.
|
|
OPCODE(LOAD_FIELD, 0)
|
|
|
|
// Pops an instance and stores the subsequent top of stack in field slot
|
|
// [arg] in it. Does not pop the value.
|
|
OPCODE(STORE_FIELD, -1)
|
|
|
|
// Pop and discard the top of stack.
|
|
OPCODE(POP, -1)
|
|
|
|
// Invoke the method with symbol [arg]. The number indicates the number of
|
|
// arguments (not including the receiver).
|
|
OPCODE(CALL_0, 0)
|
|
OPCODE(CALL_1, -1)
|
|
OPCODE(CALL_2, -2)
|
|
OPCODE(CALL_3, -3)
|
|
OPCODE(CALL_4, -4)
|
|
OPCODE(CALL_5, -5)
|
|
OPCODE(CALL_6, -6)
|
|
OPCODE(CALL_7, -7)
|
|
OPCODE(CALL_8, -8)
|
|
OPCODE(CALL_9, -9)
|
|
OPCODE(CALL_10, -10)
|
|
OPCODE(CALL_11, -11)
|
|
OPCODE(CALL_12, -12)
|
|
OPCODE(CALL_13, -13)
|
|
OPCODE(CALL_14, -14)
|
|
OPCODE(CALL_15, -15)
|
|
OPCODE(CALL_16, -16)
|
|
|
|
// Invoke a superclass method with symbol [arg]. The number indicates the
|
|
// number of arguments (not including the receiver).
|
|
OPCODE(SUPER_0, 0)
|
|
OPCODE(SUPER_1, -1)
|
|
OPCODE(SUPER_2, -2)
|
|
OPCODE(SUPER_3, -3)
|
|
OPCODE(SUPER_4, -4)
|
|
OPCODE(SUPER_5, -5)
|
|
OPCODE(SUPER_6, -6)
|
|
OPCODE(SUPER_7, -7)
|
|
OPCODE(SUPER_8, -8)
|
|
OPCODE(SUPER_9, -9)
|
|
OPCODE(SUPER_10, -10)
|
|
OPCODE(SUPER_11, -11)
|
|
OPCODE(SUPER_12, -12)
|
|
OPCODE(SUPER_13, -13)
|
|
OPCODE(SUPER_14, -14)
|
|
OPCODE(SUPER_15, -15)
|
|
OPCODE(SUPER_16, -16)
|
|
|
|
// Jump the instruction pointer [arg] forward.
|
|
OPCODE(JUMP, 0)
|
|
|
|
// Jump the instruction pointer [arg] backward.
|
|
OPCODE(LOOP, 0)
|
|
|
|
// Pop and if not truthy then jump the instruction pointer [arg] forward.
|
|
OPCODE(JUMP_IF, -1)
|
|
|
|
// If the top of the stack is false, jump [arg] forward. Otherwise, pop and
|
|
// continue.
|
|
OPCODE(AND, -1)
|
|
|
|
// If the top of the stack is non-false, jump [arg] forward. Otherwise, pop
|
|
// and continue.
|
|
OPCODE(OR, -1)
|
|
|
|
// Close the upvalue for the local on the top of the stack, then pop it.
|
|
OPCODE(CLOSE_UPVALUE, -1)
|
|
|
|
// Exit from the current function and return the value on the top of the
|
|
// stack.
|
|
OPCODE(RETURN, 0)
|
|
|
|
// Creates a closure for the function stored at [arg] in the constant table.
|
|
//
|
|
// Following the function argument is a number of arguments, two for each
|
|
// upvalue. The first is true if the variable being captured is a local (as
|
|
// opposed to an upvalue), and the second is the index of the local or
|
|
// upvalue being captured.
|
|
//
|
|
// Pushes the created closure.
|
|
OPCODE(CLOSURE, 1)
|
|
|
|
// Creates a new instance of a class.
|
|
//
|
|
// Assumes the class object is in slot zero, and replaces it with the new
|
|
// uninitialized instance of that class. This opcode is only emitted by the
|
|
// compiler-generated constructor metaclass methods.
|
|
OPCODE(CONSTRUCT, 0)
|
|
|
|
// Creates a new instance of a foreign class.
|
|
//
|
|
// Assumes the class object is in slot zero, and replaces it with the new
|
|
// uninitialized instance of that class. This opcode is only emitted by the
|
|
// compiler-generated constructor metaclass methods.
|
|
OPCODE(FOREIGN_CONSTRUCT, 0)
|
|
|
|
// Creates a class. Top of stack is the superclass. Below that is a string for
|
|
// the name of the class. Byte [arg] is the number of fields in the class.
|
|
OPCODE(CLASS, -1)
|
|
|
|
// Ends a class.
|
|
// At this point the stack contains the class and the ClassAttributes (or null).
|
|
OPCODE(END_CLASS, -2)
|
|
|
|
// Creates a foreign class. Top of stack is the superclass. Below that is a
|
|
// string for the name of the class.
|
|
OPCODE(FOREIGN_CLASS, -1)
|
|
|
|
// Define a method for symbol [arg]. The class receiving the method is popped
|
|
// off the stack, then the function defining the body is popped.
|
|
//
|
|
// If a foreign method is being defined, the "function" will be a string
|
|
// identifying the foreign method. Otherwise, it will be a function or
|
|
// closure.
|
|
OPCODE(METHOD_INSTANCE, -2)
|
|
|
|
// Define a method for symbol [arg]. The class whose metaclass will receive
|
|
// the method is popped off the stack, then the function defining the body is
|
|
// popped.
|
|
//
|
|
// If a foreign method is being defined, the "function" will be a string
|
|
// identifying the foreign method. Otherwise, it will be a function or
|
|
// closure.
|
|
OPCODE(METHOD_STATIC, -2)
|
|
|
|
// This is executed at the end of the module's body. Pushes NULL onto the stack
|
|
// as the "return value" of the import statement and stores the module as the
|
|
// most recently imported one.
|
|
OPCODE(END_MODULE, 1)
|
|
|
|
// Import a module whose name is the string stored at [arg] in the constant
|
|
// table.
|
|
//
|
|
// Pushes null onto the stack so that the fiber for the imported module can
|
|
// replace that with a dummy value when it returns. (Fibers always return a
|
|
// value when resuming a caller.)
|
|
OPCODE(IMPORT_MODULE, 1)
|
|
|
|
// Import a variable from the most recently imported module. The name of the
|
|
// variable to import is at [arg] in the constant table. Pushes the loaded
|
|
// variable's value.
|
|
OPCODE(IMPORT_VARIABLE, 1)
|
|
|
|
// This pseudo-instruction indicates the end of the bytecode. It should
|
|
// always be preceded by a `CODE_RETURN`, so is never actually executed.
|
|
OPCODE(END, 0)
|
|
// End file "wren_opcodes.h"
|
|
#undef OPCODE
|
|
};
|
|
|
|
// Marks [parser] as having an error and, if error printing is enabled and the
// host supplied an error callback, reports a compile error for [line].
//
// The reported message is "[label]: " followed by [format] expanded with
// [args]. It is built in a fixed-size buffer; output that would not fit is
// truncated rather than written past the end of the buffer, and debug builds
// still assert that the full message fits.
static void printError(Parser* parser, int line, const char* label,
                       const char* format, va_list args)
{
  parser->hasError = true;
  if (!parser->printErrors) return;

  // Only report errors if there is a WrenErrorFn to handle them.
  if (parser->vm->config.errorFn == NULL) return;

  // Format the label, never writing more than the buffer can hold.
  char message[ERROR_MESSAGE_SIZE];
  int length = snprintf(message, sizeof(message), "%s: ", label);
  if (length < 0)
  {
    // Formatting failed; start from an empty message.
    message[0] = '\0';
    length = 0;
  }
  else if (length >= ERROR_MESSAGE_SIZE)
  {
    // The label alone was truncated; continue at the terminator.
    length = ERROR_MESSAGE_SIZE - 1;
  }

  // Append the message body in whatever space remains.
  int written = vsnprintf(message + length, sizeof(message) - (size_t)length,
                          format, args);
  if (written < 0)
  {
    message[length] = '\0';
    written = 0;
  }

  // vsnprintf() returns the length the text would have had, so this still
  // catches messages that are too long for the buffer.
  ASSERT(length + written < ERROR_MESSAGE_SIZE,
         "Error should not exceed buffer.");

  ObjString* module = parser->module->name;
  const char* module_name = module ? module->value : "<unknown>";

  parser->vm->config.errorFn(parser->vm, WREN_ERROR_COMPILE,
                             module_name, line, message);
}
|
|
|
|
// Outputs a lexical error.
|
|
static void lexError(Parser* parser, const char* format, ...)
|
|
{
|
|
va_list args;
|
|
va_start(args, format);
|
|
printError(parser, parser->currentLine, "Error", format, args);
|
|
va_end(args);
|
|
}
|
|
|
|
// Outputs a compile or syntax error. This also marks the compilation as having
|
|
// an error, which ensures that the resulting code will be discarded and never
|
|
// run. This means that after calling error(), it's fine to generate whatever
|
|
// invalid bytecode you want since it won't be used.
|
|
//
|
|
// You'll note that most places that call error() continue to parse and compile
|
|
// after that. That's so that we can try to find as many compilation errors in
|
|
// one pass as possible instead of just bailing at the first one.
|
|
static void error(Compiler* compiler, const char* format, ...)
|
|
{
|
|
Token* token = &compiler->parser->previous;
|
|
|
|
// If the parse error was caused by an error token, the lexer has already
|
|
// reported it.
|
|
if (token->type == TOKEN_ERROR) return;
|
|
|
|
va_list args;
|
|
va_start(args, format);
|
|
if (token->type == TOKEN_LINE)
|
|
{
|
|
printError(compiler->parser, token->line, "Error at newline", format, args);
|
|
}
|
|
else if (token->type == TOKEN_EOF)
|
|
{
|
|
printError(compiler->parser, token->line,
|
|
"Error at end of file", format, args);
|
|
}
|
|
else
|
|
{
|
|
// Make sure we don't exceed the buffer with a very long token.
|
|
char label[10 + MAX_VARIABLE_NAME + 4 + 1];
|
|
if (token->length <= MAX_VARIABLE_NAME)
|
|
{
|
|
sprintf(label, "Error at '%.*s'", token->length, token->start);
|
|
}
|
|
else
|
|
{
|
|
sprintf(label, "Error at '%.*s...'", MAX_VARIABLE_NAME, token->start);
|
|
}
|
|
printError(compiler->parser, token->line, label, format, args);
|
|
}
|
|
va_end(args);
|
|
}
|
|
|
|
// Adds [constant] to the constant pool and returns its index.
|
|
static int addConstant(Compiler* compiler, Value constant)
|
|
{
|
|
if (compiler->parser->hasError) return -1;
|
|
|
|
// See if we already have a constant for the value. If so, reuse it.
|
|
if (compiler->constants != NULL)
|
|
{
|
|
Value existing = wrenMapGet(compiler->constants, constant);
|
|
if (IS_NUM(existing)) return (int)AS_NUM(existing);
|
|
}
|
|
|
|
// It's a new constant.
|
|
if (compiler->fn->constants.count < MAX_CONSTANTS)
|
|
{
|
|
if (IS_OBJ(constant)) wrenPushRoot(compiler->parser->vm, AS_OBJ(constant));
|
|
wrenValueBufferWrite(compiler->parser->vm, &compiler->fn->constants,
|
|
constant);
|
|
if (IS_OBJ(constant)) wrenPopRoot(compiler->parser->vm);
|
|
|
|
if (compiler->constants == NULL)
|
|
{
|
|
compiler->constants = wrenNewMap(compiler->parser->vm);
|
|
}
|
|
wrenMapSet(compiler->parser->vm, compiler->constants, constant,
|
|
NUM_VAL(compiler->fn->constants.count - 1));
|
|
}
|
|
else
|
|
{
|
|
error(compiler, "A function may only contain %d unique constants.",
|
|
MAX_CONSTANTS);
|
|
}
|
|
|
|
return compiler->fn->constants.count - 1;
|
|
}
|
|
|
|
// Initializes [compiler].
|
|
static void initCompiler(Compiler* compiler, Parser* parser, Compiler* parent,
|
|
bool isMethod)
|
|
{
|
|
compiler->parser = parser;
|
|
compiler->parent = parent;
|
|
compiler->loop = NULL;
|
|
compiler->enclosingClass = NULL;
|
|
compiler->isInitializer = false;
|
|
|
|
// Initialize these to NULL before allocating in case a GC gets triggered in
|
|
// the middle of initializing the compiler.
|
|
compiler->fn = NULL;
|
|
compiler->constants = NULL;
|
|
compiler->attributes = NULL;
|
|
|
|
parser->vm->compiler = compiler;
|
|
|
|
// Declare a local slot for either the closure or method receiver so that we
|
|
// don't try to reuse that slot for a user-defined local variable. For
|
|
// methods, we name it "this", so that we can resolve references to that like
|
|
// a normal variable. For functions, they have no explicit "this", so we use
|
|
// an empty name. That way references to "this" inside a function walks up
|
|
// the parent chain to find a method enclosing the function whose "this" we
|
|
// can close over.
|
|
compiler->numLocals = 1;
|
|
compiler->numSlots = compiler->numLocals;
|
|
|
|
if (isMethod)
|
|
{
|
|
compiler->locals[0].name = "this";
|
|
compiler->locals[0].length = 4;
|
|
}
|
|
else
|
|
{
|
|
compiler->locals[0].name = NULL;
|
|
compiler->locals[0].length = 0;
|
|
}
|
|
|
|
compiler->locals[0].depth = -1;
|
|
compiler->locals[0].isUpvalue = false;
|
|
|
|
if (parent == NULL)
|
|
{
|
|
// Compiling top-level code, so the initial scope is module-level.
|
|
compiler->scopeDepth = -1;
|
|
}
|
|
else
|
|
{
|
|
// The initial scope for functions and methods is local scope.
|
|
compiler->scopeDepth = 0;
|
|
}
|
|
|
|
compiler->numAttributes = 0;
|
|
compiler->attributes = wrenNewMap(parser->vm);
|
|
compiler->fn = wrenNewFunction(parser->vm, parser->module,
|
|
compiler->numLocals);
|
|
}
|
|
|
|
// Lexing ----------------------------------------------------------------------
|
|
|
|
typedef struct
|
|
{
|
|
const char* identifier;
|
|
size_t length;
|
|
TokenType tokenType;
|
|
} Keyword;
|
|
|
|
// The table of reserved words and their associated token types.
|
|
static Keyword keywords[] =
|
|
{
|
|
{"break", 5, TOKEN_BREAK},
|
|
{"continue", 8, TOKEN_CONTINUE},
|
|
{"class", 5, TOKEN_CLASS},
|
|
{"construct", 9, TOKEN_CONSTRUCT},
|
|
{"else", 4, TOKEN_ELSE},
|
|
{"false", 5, TOKEN_FALSE},
|
|
{"for", 3, TOKEN_FOR},
|
|
{"foreign", 7, TOKEN_FOREIGN},
|
|
{"if", 2, TOKEN_IF},
|
|
{"import", 6, TOKEN_IMPORT},
|
|
{"as", 2, TOKEN_AS},
|
|
{"in", 2, TOKEN_IN},
|
|
{"is", 2, TOKEN_IS},
|
|
{"null", 4, TOKEN_NULL},
|
|
{"return", 6, TOKEN_RETURN},
|
|
{"static", 6, TOKEN_STATIC},
|
|
{"super", 5, TOKEN_SUPER},
|
|
{"this", 4, TOKEN_THIS},
|
|
{"true", 4, TOKEN_TRUE},
|
|
{"var", 3, TOKEN_VAR},
|
|
{"while", 5, TOKEN_WHILE},
|
|
{NULL, 0, TOKEN_EOF} // Sentinel to mark the end of the array.
|
|
};
|
|
|
|
// Returns true if [c] is an ASCII letter or an underscore. Digits are not
// accepted here; they are tested separately with isDigit().
static bool isName(char c)
{
  if (c == '_') return true;
  if (c >= 'a' && c <= 'z') return true;
  return c >= 'A' && c <= 'Z';
}
|
|
|
|
// Returns true if [c] is one of the ASCII decimal digits '0' through '9'.
static bool isDigit(char c)
{
  if (c < '0') return false;
  return c <= '9';
}
|
|
|
|
// Returns the current character the parser is sitting on.
|
|
static char peekChar(Parser* parser)
|
|
{
|
|
return *parser->currentChar;
|
|
}
|
|
|
|
// Returns the character after the current character.
|
|
static char peekNextChar(Parser* parser)
|
|
{
|
|
// If we're at the end of the source, don't read past it.
|
|
if (peekChar(parser) == '\0') return '\0';
|
|
return *(parser->currentChar + 1);
|
|
}
|
|
|
|
// Advances the parser forward one character.
|
|
static char nextChar(Parser* parser)
|
|
{
|
|
char c = peekChar(parser);
|
|
parser->currentChar++;
|
|
if (c == '\n') parser->currentLine++;
|
|
return c;
|
|
}
|
|
|
|
// If the current character is [c], consumes it and returns `true`.
|
|
static bool matchChar(Parser* parser, char c)
|
|
{
|
|
if (peekChar(parser) != c) return false;
|
|
nextChar(parser);
|
|
return true;
|
|
}
|
|
|
|
// Sets the parser's current token to the given [type] and current character
|
|
// range.
|
|
static void makeToken(Parser* parser, TokenType type)
|
|
{
|
|
parser->next.type = type;
|
|
parser->next.start = parser->tokenStart;
|
|
parser->next.length = (int)(parser->currentChar - parser->tokenStart);
|
|
parser->next.line = parser->currentLine;
|
|
|
|
// Make line tokens appear on the line containing the "\n".
|
|
if (type == TOKEN_LINE) parser->next.line--;
|
|
}
|
|
|
|
// If the current character is [c], then consumes it and makes a token of type
|
|
// [two]. Otherwise makes a token of type [one].
|
|
static void twoCharToken(Parser* parser, char c, TokenType two, TokenType one)
|
|
{
|
|
makeToken(parser, matchChar(parser, c) ? two : one);
|
|
}
|
|
|
|
// Skips the rest of the current line.
|
|
static void skipLineComment(Parser* parser)
|
|
{
|
|
while (peekChar(parser) != '\n' && peekChar(parser) != '\0')
|
|
{
|
|
nextChar(parser);
|
|
}
|
|
}
|
|
|
|
// Skips the rest of a block comment.
|
|
static void skipBlockComment(Parser* parser)
|
|
{
|
|
int nesting = 1;
|
|
while (nesting > 0)
|
|
{
|
|
if (peekChar(parser) == '\0')
|
|
{
|
|
lexError(parser, "Unterminated block comment.");
|
|
return;
|
|
}
|
|
|
|
if (peekChar(parser) == '/' && peekNextChar(parser) == '*')
|
|
{
|
|
nextChar(parser);
|
|
nextChar(parser);
|
|
nesting++;
|
|
continue;
|
|
}
|
|
|
|
if (peekChar(parser) == '*' && peekNextChar(parser) == '/')
|
|
{
|
|
nextChar(parser);
|
|
nextChar(parser);
|
|
nesting--;
|
|
continue;
|
|
}
|
|
|
|
// Regular comment character.
|
|
nextChar(parser);
|
|
}
|
|
}
|
|
|
|
// Reads the next character, which should be a hex digit (0-9, a-f, or A-F) and
|
|
// returns its numeric value. If the character isn't a hex digit, returns -1.
|
|
static int readHexDigit(Parser* parser)
|
|
{
|
|
char c = nextChar(parser);
|
|
if (c >= '0' && c <= '9') return c - '0';
|
|
if (c >= 'a' && c <= 'f') return c - 'a' + 10;
|
|
if (c >= 'A' && c <= 'F') return c - 'A' + 10;
|
|
|
|
// Don't consume it if it isn't expected. Keeps us from reading past the end
|
|
// of an unterminated string.
|
|
parser->currentChar--;
|
|
return -1;
|
|
}
|
|
|
|
// Parses the numeric value of the current token.
|
|
static void makeNumber(Parser* parser, bool isHex)
|
|
{
|
|
errno = 0;
|
|
|
|
if (isHex)
|
|
{
|
|
parser->next.value = NUM_VAL((double)strtoll(parser->tokenStart, NULL, 16));
|
|
}
|
|
else
|
|
{
|
|
parser->next.value = NUM_VAL(strtod(parser->tokenStart, NULL));
|
|
}
|
|
|
|
if (errno == ERANGE)
|
|
{
|
|
lexError(parser, "Number literal was too large (%d).", sizeof(long int));
|
|
parser->next.value = NUM_VAL(0);
|
|
}
|
|
|
|
// We don't check that the entire token is consumed after calling strtoll()
|
|
// or strtod() because we've already scanned it ourselves and know it's valid.
|
|
|
|
makeToken(parser, TOKEN_NUMBER);
|
|
}
|
|
|
|
// Finishes lexing a hexadecimal number literal.
|
|
static void readHexNumber(Parser* parser)
|
|
{
|
|
// Skip past the `x` used to denote a hexadecimal literal.
|
|
nextChar(parser);
|
|
|
|
// Iterate over all the valid hexadecimal digits found.
|
|
while (readHexDigit(parser) != -1) continue;
|
|
|
|
makeNumber(parser, true);
|
|
}
|
|
|
|
// Finishes lexing a number literal.
|
|
static void readNumber(Parser* parser)
|
|
{
|
|
while (isDigit(peekChar(parser))) nextChar(parser);
|
|
|
|
// See if it has a floating point. Make sure there is a digit after the "."
|
|
// so we don't get confused by method calls on number literals.
|
|
if (peekChar(parser) == '.' && isDigit(peekNextChar(parser)))
|
|
{
|
|
nextChar(parser);
|
|
while (isDigit(peekChar(parser))) nextChar(parser);
|
|
}
|
|
|
|
// See if the number is in scientific notation.
|
|
if (matchChar(parser, 'e') || matchChar(parser, 'E'))
|
|
{
|
|
// Allow a single positive/negative exponent symbol.
|
|
if(!matchChar(parser, '+'))
|
|
{
|
|
matchChar(parser, '-');
|
|
}
|
|
|
|
if (!isDigit(peekChar(parser)))
|
|
{
|
|
lexError(parser, "Unterminated scientific notation.");
|
|
}
|
|
|
|
while (isDigit(peekChar(parser))) nextChar(parser);
|
|
}
|
|
|
|
makeNumber(parser, false);
|
|
}
|
|
|
|
// Finishes lexing an identifier. Handles reserved words.
|
|
static void readName(Parser* parser, TokenType type, char firstChar)
|
|
{
|
|
ByteBuffer string;
|
|
wrenByteBufferInit(&string);
|
|
wrenByteBufferWrite(parser->vm, &string, firstChar);
|
|
|
|
while (isName(peekChar(parser)) || isDigit(peekChar(parser)))
|
|
{
|
|
char c = nextChar(parser);
|
|
wrenByteBufferWrite(parser->vm, &string, c);
|
|
}
|
|
|
|
// Update the type if it's a keyword.
|
|
size_t length = parser->currentChar - parser->tokenStart;
|
|
for (int i = 0; keywords[i].identifier != NULL; i++)
|
|
{
|
|
if (length == keywords[i].length &&
|
|
memcmp(parser->tokenStart, keywords[i].identifier, length) == 0)
|
|
{
|
|
type = keywords[i].tokenType;
|
|
break;
|
|
}
|
|
}
|
|
|
|
parser->next.value = wrenNewStringLength(parser->vm,
|
|
(char*)string.data, string.count);
|
|
|
|
wrenByteBufferClear(parser->vm, &string);
|
|
makeToken(parser, type);
|
|
}
|
|
|
|
// Reads [digits] hex digits in a string literal and returns their number value.
|
|
static int readHexEscape(Parser* parser, int digits, const char* description)
|
|
{
|
|
int value = 0;
|
|
for (int i = 0; i < digits; i++)
|
|
{
|
|
if (peekChar(parser) == '"' || peekChar(parser) == '\0')
|
|
{
|
|
lexError(parser, "Incomplete %s escape sequence.", description);
|
|
|
|
// Don't consume it if it isn't expected. Keeps us from reading past the
|
|
// end of an unterminated string.
|
|
parser->currentChar--;
|
|
break;
|
|
}
|
|
|
|
int digit = readHexDigit(parser);
|
|
if (digit == -1)
|
|
{
|
|
lexError(parser, "Invalid %s escape sequence.", description);
|
|
break;
|
|
}
|
|
|
|
value = (value * 16) | digit;
|
|
}
|
|
|
|
return value;
|
|
}
|
|
|
|
// Reads a hex digit Unicode escape sequence in a string literal.
|
|
static void readUnicodeEscape(Parser* parser, ByteBuffer* string, int length)
|
|
{
|
|
int value = readHexEscape(parser, length, "Unicode");
|
|
|
|
// Grow the buffer enough for the encoded result.
|
|
int numBytes = wrenUtf8EncodeNumBytes(value);
|
|
if (numBytes != 0)
|
|
{
|
|
wrenByteBufferFill(parser->vm, string, 0, numBytes);
|
|
wrenUtf8Encode(value, string->data + string->count - numBytes);
|
|
}
|
|
}
|
|
|
|
// Finishes lexing a raw string literal delimited by three double quotes.
// Expects the first opening quote to have been consumed already.
//
// Carriage returns are dropped. If only spaces and tabs precede the first
// newline, that whitespace and the newline are trimmed from the front; if
// only spaces and tabs follow the last newline, that newline and the
// whitespace are trimmed from the back.
//
// On an unterminated raw string a lexical error is reported and the parser
// is left sitting on the source's terminating '\0'. The lexer never peeks or
// advances past that terminator.
static void readRawString(Parser* parser)
{
  ByteBuffer string;
  wrenByteBufferInit(&string);
  TokenType type = TOKEN_STRING;

  // Consume the second and third " of the opening delimiter.
  nextChar(parser);
  nextChar(parser);

  int skipStart = 0;
  int firstNewline = -1;

  int skipEnd = -1;
  int lastNewline = -1;

  bool terminated = true;

  for (;;)
  {
    char c = nextChar(parser);

    // Once a '\0' has been seen there is nothing valid to look at behind it.
    char c1 = (c == '\0') ? '\0' : peekChar(parser);
    char c2 = (c1 == '\0') ? '\0' : peekNextChar(parser);

    if (c == '\r') continue;

    if (c == '\n') {
      lastNewline = string.count;
      skipEnd = lastNewline;
      firstNewline = firstNewline == -1 ? string.count : firstNewline;
    }

    if (c == '"' && c1 == '"' && c2 == '"') break;

    bool isWhitespace = c == ' ' || c == '\t';
    skipEnd = c == '\n' || isWhitespace ? skipEnd : -1;

    // If we haven't seen a newline or other character yet,
    // and still seeing whitespace, count the characters
    // as skippable till we know otherwise
    bool skippable = skipStart != -1 && isWhitespace && firstNewline == -1;
    skipStart = skippable ? string.count + 1 : skipStart;

    // We've counted leading whitespace till we hit something else,
    // but it's not a newline, so we reset skipStart since we need these characters
    if (firstNewline == -1 && !isWhitespace && c != '\n') skipStart = -1;

    if (c == '\0' || c1 == '\0' || c2 == '\0')
    {
      lexError(parser, "Unterminated raw string.");

      // Fewer than three characters remain, so the closing delimiter cannot
      // be present. Park the parser exactly on the terminator.
      if (c == '\0')
      {
        // The terminator itself was consumed; step back onto it.
        parser->currentChar--;
      }
      else if (c1 != '\0')
      {
        // Only [c1] stands between us and the terminator.
        nextChar(parser);
      }

      terminated = false;
      break;
    }

    wrenByteBufferWrite(parser->vm, &string, c);
  }

  // Consume the second and third " of the closing delimiter. There is none
  // to consume when the string was unterminated.
  if (terminated)
  {
    nextChar(parser);
    nextChar(parser);
  }

  int offset = 0;
  int count = string.count;

  if(firstNewline != -1 && skipStart == firstNewline) offset = firstNewline + 1;
  if(lastNewline != -1 && skipEnd == lastNewline) count = lastNewline;

  count -= (offset > count) ? count : offset;

  parser->next.value = wrenNewStringLength(parser->vm,
                                           ((char*)string.data) + offset, count);

  wrenByteBufferClear(parser->vm, &string);
  makeToken(parser, type);
}
|
|
|
|
// Finishes lexing a string literal.
|
|
static void readString(Parser* parser)
|
|
{
|
|
ByteBuffer string;
|
|
TokenType type = TOKEN_STRING;
|
|
wrenByteBufferInit(&string);
|
|
|
|
for (;;)
|
|
{
|
|
char c = nextChar(parser);
|
|
if (c == '"') break;
|
|
if (c == '\r') continue;
|
|
|
|
if (c == '\0')
|
|
{
|
|
lexError(parser, "Unterminated string.");
|
|
|
|
// Don't consume it if it isn't expected. Keeps us from reading past the
|
|
// end of an unterminated string.
|
|
parser->currentChar--;
|
|
break;
|
|
}
|
|
|
|
if (c == '%')
|
|
{
|
|
if (parser->numParens < MAX_INTERPOLATION_NESTING)
|
|
{
|
|
// TODO: Allow format string.
|
|
if (nextChar(parser) != '(') lexError(parser, "Expect '(' after '%%'.");
|
|
|
|
parser->parens[parser->numParens++] = 1;
|
|
type = TOKEN_INTERPOLATION;
|
|
break;
|
|
}
|
|
|
|
lexError(parser, "Interpolation may only nest %d levels deep.",
|
|
MAX_INTERPOLATION_NESTING);
|
|
}
|
|
|
|
if (c == '\\')
|
|
{
|
|
switch (nextChar(parser))
|
|
{
|
|
case '"': wrenByteBufferWrite(parser->vm, &string, '"'); break;
|
|
case '\\': wrenByteBufferWrite(parser->vm, &string, '\\'); break;
|
|
case '%': wrenByteBufferWrite(parser->vm, &string, '%'); break;
|
|
case '0': wrenByteBufferWrite(parser->vm, &string, '\0'); break;
|
|
case 'a': wrenByteBufferWrite(parser->vm, &string, '\a'); break;
|
|
case 'b': wrenByteBufferWrite(parser->vm, &string, '\b'); break;
|
|
case 'e': wrenByteBufferWrite(parser->vm, &string, '\33'); break;
|
|
case 'f': wrenByteBufferWrite(parser->vm, &string, '\f'); break;
|
|
case 'n': wrenByteBufferWrite(parser->vm, &string, '\n'); break;
|
|
case 'r': wrenByteBufferWrite(parser->vm, &string, '\r'); break;
|
|
case 't': wrenByteBufferWrite(parser->vm, &string, '\t'); break;
|
|
case 'u': readUnicodeEscape(parser, &string, 4); break;
|
|
case 'U': readUnicodeEscape(parser, &string, 8); break;
|
|
case 'v': wrenByteBufferWrite(parser->vm, &string, '\v'); break;
|
|
case 'x':
|
|
wrenByteBufferWrite(parser->vm, &string,
|
|
(uint8_t)readHexEscape(parser, 2, "byte"));
|
|
break;
|
|
|
|
default:
|
|
lexError(parser, "Invalid escape character '%c'.",
|
|
*(parser->currentChar - 1));
|
|
break;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
wrenByteBufferWrite(parser->vm, &string, c);
|
|
}
|
|
}
|
|
|
|
parser->next.value = wrenNewStringLength(parser->vm,
|
|
(char*)string.data, string.count);
|
|
|
|
wrenByteBufferClear(parser->vm, &string);
|
|
makeToken(parser, type);
|
|
}
|
|
|
|
// Lex the next token and store it in [parser.next].
|
|
static void nextToken(Parser* parser)
|
|
{
|
|
parser->previous = parser->current;
|
|
parser->current = parser->next;
|
|
|
|
// If we are out of tokens, don't try to tokenize any more. We *do* still
|
|
// copy the TOKEN_EOF to previous so that code that expects it to be consumed
|
|
// will still work.
|
|
if (parser->next.type == TOKEN_EOF) return;
|
|
if (parser->current.type == TOKEN_EOF) return;
|
|
|
|
while (peekChar(parser) != '\0')
|
|
{
|
|
parser->tokenStart = parser->currentChar;
|
|
|
|
char c = nextChar(parser);
|
|
switch (c)
|
|
{
|
|
case '(':
|
|
// If we are inside an interpolated expression, count the unmatched "(".
|
|
if (parser->numParens > 0) parser->parens[parser->numParens - 1]++;
|
|
makeToken(parser, TOKEN_LEFT_PAREN);
|
|
return;
|
|
|
|
case ')':
|
|
// If we are inside an interpolated expression, count the ")".
|
|
if (parser->numParens > 0 &&
|
|
--parser->parens[parser->numParens - 1] == 0)
|
|
{
|
|
// This is the final ")", so the interpolation expression has ended.
|
|
// This ")" now begins the next section of the template string.
|
|
parser->numParens--;
|
|
readString(parser);
|
|
return;
|
|
}
|
|
|
|
makeToken(parser, TOKEN_RIGHT_PAREN);
|
|
return;
|
|
|
|
case '[': makeToken(parser, TOKEN_LEFT_BRACKET); return;
|
|
case ']': makeToken(parser, TOKEN_RIGHT_BRACKET); return;
|
|
case '{': makeToken(parser, TOKEN_LEFT_BRACE); return;
|
|
case '}': makeToken(parser, TOKEN_RIGHT_BRACE); return;
|
|
case ':': makeToken(parser, TOKEN_COLON); return;
|
|
case ',': makeToken(parser, TOKEN_COMMA); return;
|
|
case '*': makeToken(parser, TOKEN_STAR); return;
|
|
case '%': makeToken(parser, TOKEN_PERCENT); return;
|
|
case '#': {
|
|
// Ignore shebang on the first line.
|
|
if (parser->currentLine == 1 && peekChar(parser) == '!' && peekNextChar(parser) == '/')
|
|
{
|
|
skipLineComment(parser);
|
|
break;
|
|
}
|
|
// Otherwise we treat it as a token
|
|
makeToken(parser, TOKEN_HASH);
|
|
return;
|
|
}
|
|
case '^': makeToken(parser, TOKEN_CARET); return;
|
|
case '+': makeToken(parser, TOKEN_PLUS); return;
|
|
case '-': makeToken(parser, TOKEN_MINUS); return;
|
|
case '~': makeToken(parser, TOKEN_TILDE); return;
|
|
case '?': makeToken(parser, TOKEN_QUESTION); return;
|
|
|
|
case '|': twoCharToken(parser, '|', TOKEN_PIPEPIPE, TOKEN_PIPE); return;
|
|
case '&': twoCharToken(parser, '&', TOKEN_AMPAMP, TOKEN_AMP); return;
|
|
case '=': twoCharToken(parser, '=', TOKEN_EQEQ, TOKEN_EQ); return;
|
|
case '!': twoCharToken(parser, '=', TOKEN_BANGEQ, TOKEN_BANG); return;
|
|
|
|
case '.':
|
|
if (matchChar(parser, '.'))
|
|
{
|
|
twoCharToken(parser, '.', TOKEN_DOTDOTDOT, TOKEN_DOTDOT);
|
|
return;
|
|
}
|
|
|
|
makeToken(parser, TOKEN_DOT);
|
|
return;
|
|
|
|
case '/':
|
|
if (matchChar(parser, '/'))
|
|
{
|
|
skipLineComment(parser);
|
|
break;
|
|
}
|
|
|
|
if (matchChar(parser, '*'))
|
|
{
|
|
skipBlockComment(parser);
|
|
break;
|
|
}
|
|
|
|
makeToken(parser, TOKEN_SLASH);
|
|
return;
|
|
|
|
case '<':
|
|
if (matchChar(parser, '<'))
|
|
{
|
|
makeToken(parser, TOKEN_LTLT);
|
|
}
|
|
else
|
|
{
|
|
twoCharToken(parser, '=', TOKEN_LTEQ, TOKEN_LT);
|
|
}
|
|
return;
|
|
|
|
case '>':
|
|
if (matchChar(parser, '>'))
|
|
{
|
|
makeToken(parser, TOKEN_GTGT);
|
|
}
|
|
else
|
|
{
|
|
twoCharToken(parser, '=', TOKEN_GTEQ, TOKEN_GT);
|
|
}
|
|
return;
|
|
|
|
case '\n':
|
|
makeToken(parser, TOKEN_LINE);
|
|
return;
|
|
|
|
case ' ':
|
|
case '\r':
|
|
case '\t':
|
|
// Skip forward until we run out of whitespace.
|
|
while (peekChar(parser) == ' ' ||
|
|
peekChar(parser) == '\r' ||
|
|
peekChar(parser) == '\t')
|
|
{
|
|
nextChar(parser);
|
|
}
|
|
break;
|
|
|
|
case '"': {
|
|
if(peekChar(parser) == '"' && peekNextChar(parser) == '"') {
|
|
readRawString(parser);
|
|
return;
|
|
}
|
|
readString(parser); return;
|
|
}
|
|
case '_':
|
|
readName(parser,
|
|
peekChar(parser) == '_' ? TOKEN_STATIC_FIELD : TOKEN_FIELD, c);
|
|
return;
|
|
|
|
case '0':
|
|
if (peekChar(parser) == 'x')
|
|
{
|
|
readHexNumber(parser);
|
|
return;
|
|
}
|
|
|
|
readNumber(parser);
|
|
return;
|
|
|
|
default:
|
|
if (isName(c))
|
|
{
|
|
readName(parser, TOKEN_NAME, c);
|
|
}
|
|
else if (isDigit(c))
|
|
{
|
|
readNumber(parser);
|
|
}
|
|
else
|
|
{
|
|
if (c >= 32 && c <= 126)
|
|
{
|
|
lexError(parser, "Invalid character '%c'.", c);
|
|
}
|
|
else
|
|
{
|
|
// Don't show non-ASCII values since we didn't UTF-8 decode the
|
|
// bytes. Since there are no non-ASCII byte values that are
|
|
// meaningful code units in Wren, the lexer works on raw bytes,
|
|
// even though the source code and console output are UTF-8.
|
|
lexError(parser, "Invalid byte 0x%x.", (uint8_t)c);
|
|
}
|
|
parser->next.type = TOKEN_ERROR;
|
|
parser->next.length = 0;
|
|
}
|
|
return;
|
|
}
|
|
}
|
|
|
|
// If we get here, we're out of source, so just make EOF tokens.
|
|
parser->tokenStart = parser->currentChar;
|
|
makeToken(parser, TOKEN_EOF);
|
|
}
|
|
|
|
// Parsing ---------------------------------------------------------------------
|
|
|
|
// Returns the type of the current token.
|
|
static TokenType peek(Compiler* compiler)
|
|
{
|
|
return compiler->parser->current.type;
|
|
}
|
|
|
|
// Returns the type of the current token.
|
|
static TokenType peekNext(Compiler* compiler)
|
|
{
|
|
return compiler->parser->next.type;
|
|
}
|
|
|
|
// Consumes the current token if its type is [expected]. Returns true if a
|
|
// token was consumed.
|
|
static bool match(Compiler* compiler, TokenType expected)
|
|
{
|
|
if (peek(compiler) != expected) return false;
|
|
|
|
nextToken(compiler->parser);
|
|
return true;
|
|
}
|
|
|
|
// Consumes the current token. Emits an error if its type is not [expected].
|
|
static void consume(Compiler* compiler, TokenType expected,
|
|
const char* errorMessage)
|
|
{
|
|
nextToken(compiler->parser);
|
|
if (compiler->parser->previous.type != expected)
|
|
{
|
|
error(compiler, errorMessage);
|
|
|
|
// If the next token is the one we want, assume the current one is just a
|
|
// spurious error and discard it to minimize the number of cascaded errors.
|
|
if (compiler->parser->current.type == expected) nextToken(compiler->parser);
|
|
}
|
|
}
|
|
|
|
// Matches one or more newlines. Returns true if at least one was found.
|
|
static bool matchLine(Compiler* compiler)
|
|
{
|
|
if (!match(compiler, TOKEN_LINE)) return false;
|
|
|
|
while (match(compiler, TOKEN_LINE));
|
|
return true;
|
|
}
|
|
|
|
// Discards any newlines starting at the current token.
|
|
static void ignoreNewlines(Compiler* compiler)
|
|
{
|
|
matchLine(compiler);
|
|
}
|
|
|
|
// Consumes the current token. Emits an error if it is not a newline. Then
|
|
// discards any duplicate newlines following it.
|
|
static void consumeLine(Compiler* compiler, const char* errorMessage)
|
|
{
|
|
consume(compiler, TOKEN_LINE, errorMessage);
|
|
ignoreNewlines(compiler);
|
|
}
|
|
|
|
// Skips a single newline token when it is immediately followed by a "."
// token. Does nothing otherwise.
static void allowLineBeforeDot(Compiler* compiler)
{
  if (peek(compiler) != TOKEN_LINE) return;
  if (peekNext(compiler) != TOKEN_DOT) return;

  nextToken(compiler->parser);
}
|
|
|
|
// Variables and scopes --------------------------------------------------------
|
|
|
|
// Emits one single-byte argument. Returns its index.
|
|
static int emitByte(Compiler* compiler, int byte)
|
|
{
|
|
wrenByteBufferWrite(compiler->parser->vm, &compiler->fn->code, (uint8_t)byte);
|
|
|
|
// Assume the instruction is associated with the most recently consumed token.
|
|
wrenIntBufferWrite(compiler->parser->vm, &compiler->fn->debug->sourceLines,
|
|
compiler->parser->previous.line);
|
|
|
|
return compiler->fn->code.count - 1;
|
|
}
|
|
|
|
// Emits one bytecode instruction.
|
|
static void emitOp(Compiler* compiler, Code instruction)
|
|
{
|
|
emitByte(compiler, instruction);
|
|
|
|
// Keep track of the stack's high water mark.
|
|
compiler->numSlots += stackEffects[instruction];
|
|
if (compiler->numSlots > compiler->fn->maxSlots)
|
|
{
|
|
compiler->fn->maxSlots = compiler->numSlots;
|
|
}
|
|
}
|
|
|
|
// Emits one 16-bit argument, which will be written big endian.
|
|
static void emitShort(Compiler* compiler, int arg)
|
|
{
|
|
emitByte(compiler, (arg >> 8) & 0xff);
|
|
emitByte(compiler, arg & 0xff);
|
|
}
|
|
|
|
// Emits one bytecode instruction followed by a 8-bit argument. Returns the
|
|
// index of the argument in the bytecode.
|
|
static int emitByteArg(Compiler* compiler, Code instruction, int arg)
|
|
{
|
|
emitOp(compiler, instruction);
|
|
return emitByte(compiler, arg);
|
|
}
|
|
|
|
// Emits one bytecode instruction followed by a 16-bit argument, which will be
|
|
// written big endian.
|
|
static void emitShortArg(Compiler* compiler, Code instruction, int arg)
|
|
{
|
|
emitOp(compiler, instruction);
|
|
emitShort(compiler, arg);
|
|
}
|
|
|
|
// Emits [instruction] followed by a placeholder for a jump offset. The
|
|
// placeholder can be patched by calling [jumpPatch]. Returns the index of the
|
|
// placeholder.
|
|
static int emitJump(Compiler* compiler, Code instruction)
|
|
{
|
|
emitOp(compiler, instruction);
|
|
emitByte(compiler, 0xff);
|
|
return emitByte(compiler, 0xff) - 1;
|
|
}
|
|
|
|
// Creates a new constant for the current value and emits the bytecode to load
|
|
// it from the constant table.
|
|
static void emitConstant(Compiler* compiler, Value value)
|
|
{
|
|
int constant = addConstant(compiler, value);
|
|
|
|
// Compile the code to load the constant.
|
|
emitShortArg(compiler, CODE_CONSTANT, constant);
|
|
}
|
|
|
|
// Create a new local variable with [name]. Assumes the current scope is local
|
|
// and the name is unique.
|
|
static int addLocal(Compiler* compiler, const char* name, int length)
|
|
{
|
|
Local* local = &compiler->locals[compiler->numLocals];
|
|
local->name = name;
|
|
local->length = length;
|
|
local->depth = compiler->scopeDepth;
|
|
local->isUpvalue = false;
|
|
return compiler->numLocals++;
|
|
}
|
|
|
|
// Declares a variable in the current scope whose name is the given token.
|
|
//
|
|
// If [token] is `NULL`, uses the previously consumed token. Returns its symbol.
|
|
static int declareVariable(Compiler* compiler, Token* token)
|
|
{
|
|
if (token == NULL) token = &compiler->parser->previous;
|
|
|
|
if (token->length > MAX_VARIABLE_NAME)
|
|
{
|
|
error(compiler, "Variable name cannot be longer than %d characters.",
|
|
MAX_VARIABLE_NAME);
|
|
}
|
|
|
|
// Top-level module scope.
|
|
if (compiler->scopeDepth == -1)
|
|
{
|
|
int line = -1;
|
|
int symbol = wrenDefineVariable(compiler->parser->vm,
|
|
compiler->parser->module,
|
|
token->start, token->length,
|
|
NULL_VAL, &line);
|
|
|
|
if (symbol == -1)
|
|
{
|
|
error(compiler, "Module variable is already defined.");
|
|
}
|
|
else if (symbol == -2)
|
|
{
|
|
error(compiler, "Too many module variables defined.");
|
|
}
|
|
else if (symbol == -3)
|
|
{
|
|
error(compiler,
|
|
"Variable '%.*s' referenced before this definition (first use at line %d).",
|
|
token->length, token->start, line);
|
|
}
|
|
|
|
return symbol;
|
|
}
|
|
|
|
// See if there is already a variable with this name declared in the current
|
|
// scope. (Outer scopes are OK: those get shadowed.)
|
|
for (int i = compiler->numLocals - 1; i >= 0; i--)
|
|
{
|
|
Local* local = &compiler->locals[i];
|
|
|
|
// Once we escape this scope and hit an outer one, we can stop.
|
|
if (local->depth < compiler->scopeDepth) break;
|
|
|
|
if (local->length == token->length &&
|
|
memcmp(local->name, token->start, token->length) == 0)
|
|
{
|
|
error(compiler, "Variable is already declared in this scope.");
|
|
return i;
|
|
}
|
|
}
|
|
|
|
if (compiler->numLocals == MAX_LOCALS)
|
|
{
|
|
error(compiler, "Cannot declare more than %d variables in one scope.",
|
|
MAX_LOCALS);
|
|
return -1;
|
|
}
|
|
|
|
return addLocal(compiler, token->start, token->length);
|
|
}
|
|
|
|
// Parses a name token and declares a variable in the current scope with that
|
|
// name. Returns its slot.
|
|
static int declareNamedVariable(Compiler* compiler)
|
|
{
|
|
consume(compiler, TOKEN_NAME, "Expect variable name.");
|
|
return declareVariable(compiler, NULL);
|
|
}
|
|
|
|
// Stores a variable with the previously defined symbol in the current scope.
|
|
static void defineVariable(Compiler* compiler, int symbol)
|
|
{
|
|
// Store the variable. If it's a local, the result of the initializer is
|
|
// in the correct slot on the stack already so we're done.
|
|
if (compiler->scopeDepth >= 0) return;
|
|
|
|
// It's a module-level variable, so store the value in the module slot and
|
|
// then discard the temporary for the initializer.
|
|
emitShortArg(compiler, CODE_STORE_MODULE_VAR, symbol);
|
|
emitOp(compiler, CODE_POP);
|
|
}
|
|
|
|
// Starts a new local block scope.
|
|
static void pushScope(Compiler* compiler)
|
|
{
|
|
compiler->scopeDepth++;
|
|
}
|
|
|
|
// Generates code to discard local variables at [depth] or greater. Does *not*
|
|
// actually undeclare variables or pop any scopes, though. This is called
|
|
// directly when compiling "break" statements to ditch the local variables
|
|
// before jumping out of the loop even though they are still in scope *past*
|
|
// the break instruction.
|
|
//
|
|
// Returns the number of local variables that were eliminated.
|
|
static int discardLocals(Compiler* compiler, int depth)
|
|
{
|
|
ASSERT(compiler->scopeDepth > -1, "Cannot exit top-level scope.");
|
|
|
|
int local = compiler->numLocals - 1;
|
|
while (local >= 0 && compiler->locals[local].depth >= depth)
|
|
{
|
|
// If the local was closed over, make sure the upvalue gets closed when it
|
|
// goes out of scope on the stack. We use emitByte() and not emitOp() here
|
|
// because we don't want to track that stack effect of these pops since the
|
|
// variables are still in scope after the break.
|
|
if (compiler->locals[local].isUpvalue)
|
|
{
|
|
emitByte(compiler, CODE_CLOSE_UPVALUE);
|
|
}
|
|
else
|
|
{
|
|
emitByte(compiler, CODE_POP);
|
|
}
|
|
|
|
|
|
local--;
|
|
}
|
|
|
|
return compiler->numLocals - local - 1;
|
|
}
|
|
|
|
// Closes the last pushed block scope and discards any local variables declared
|
|
// in that scope. This should only be called in a statement context where no
|
|
// temporaries are still on the stack.
|
|
static void popScope(Compiler* compiler)
|
|
{
|
|
int popped = discardLocals(compiler, compiler->scopeDepth);
|
|
compiler->numLocals -= popped;
|
|
compiler->numSlots -= popped;
|
|
compiler->scopeDepth--;
|
|
}
|
|
|
|
// Attempts to look up the name in the local variables of [compiler]. If found,
|
|
// returns its index, otherwise returns -1.
|
|
static int resolveLocal(Compiler* compiler, const char* name, int length)
|
|
{
|
|
// Look it up in the local scopes. Look in reverse order so that the most
|
|
// nested variable is found first and shadows outer ones.
|
|
for (int i = compiler->numLocals - 1; i >= 0; i--)
|
|
{
|
|
if (compiler->locals[i].length == length &&
|
|
memcmp(name, compiler->locals[i].name, length) == 0)
|
|
{
|
|
return i;
|
|
}
|
|
}
|
|
|
|
return -1;
|
|
}
|
|
|
|
// Adds an upvalue to [compiler]'s function with the given properties. Does not
|
|
// add one if an upvalue for that variable is already in the list. Returns the
|
|
// index of the upvalue.
|
|
static int addUpvalue(Compiler* compiler, bool isLocal, int index)
|
|
{
|
|
// Look for an existing one.
|
|
for (int i = 0; i < compiler->fn->numUpvalues; i++)
|
|
{
|
|
CompilerUpvalue* upvalue = &compiler->upvalues[i];
|
|
if (upvalue->index == index && upvalue->isLocal == isLocal) return i;
|
|
}
|
|
|
|
// If we got here, it's a new upvalue.
|
|
compiler->upvalues[compiler->fn->numUpvalues].isLocal = isLocal;
|
|
compiler->upvalues[compiler->fn->numUpvalues].index = index;
|
|
return compiler->fn->numUpvalues++;
|
|
}
|
|
|
|
// Attempts to look up [name] in the functions enclosing the one being compiled
|
|
// by [compiler]. If found, it adds an upvalue for it to this compiler's list
|
|
// of upvalues (unless it's already in there) and returns its index. If not
|
|
// found, returns -1.
|
|
//
|
|
// If the name is found outside of the immediately enclosing function, this
|
|
// will flatten the closure and add upvalues to all of the intermediate
|
|
// functions so that it gets walked down to this one.
|
|
//
|
|
// If it reaches a method boundary, this stops and returns -1 since methods do
|
|
// not close over local variables.
|
|
static int findUpvalue(Compiler* compiler, const char* name, int length)
|
|
{
|
|
// If we are at the top level, we didn't find it.
|
|
if (compiler->parent == NULL) return -1;
|
|
|
|
// If we hit the method boundary (and the name isn't a static field), then
|
|
// stop looking for it. We'll instead treat it as a self send.
|
|
if (name[0] != '_' && compiler->parent->enclosingClass != NULL) return -1;
|
|
|
|
// See if it's a local variable in the immediately enclosing function.
|
|
int local = resolveLocal(compiler->parent, name, length);
|
|
if (local != -1)
|
|
{
|
|
// Mark the local as an upvalue so we know to close it when it goes out of
|
|
// scope.
|
|
compiler->parent->locals[local].isUpvalue = true;
|
|
|
|
return addUpvalue(compiler, true, local);
|
|
}
|
|
|
|
// See if it's an upvalue in the immediately enclosing function. In other
|
|
// words, if it's a local variable in a non-immediately enclosing function.
|
|
// This "flattens" closures automatically: it adds upvalues to all of the
|
|
// intermediate functions to get from the function where a local is declared
|
|
// all the way into the possibly deeply nested function that is closing over
|
|
// it.
|
|
int upvalue = findUpvalue(compiler->parent, name, length);
|
|
if (upvalue != -1)
|
|
{
|
|
return addUpvalue(compiler, false, upvalue);
|
|
}
|
|
|
|
// If we got here, we walked all the way up the parent chain and couldn't
|
|
// find it.
|
|
return -1;
|
|
}
|
|
|
|
// Look up [name] in the current scope to see what variable it refers to.
|
|
// Returns the variable either in local scope, or the enclosing function's
|
|
// upvalue list. Does not search the module scope. Returns a variable with
|
|
// index -1 if not found.
|
|
static Variable resolveNonmodule(Compiler* compiler,
|
|
const char* name, int length)
|
|
{
|
|
// Look it up in the local scopes.
|
|
Variable variable;
|
|
variable.scope = SCOPE_LOCAL;
|
|
variable.index = resolveLocal(compiler, name, length);
|
|
if (variable.index != -1) return variable;
|
|
|
|
// Tt's not a local, so guess that it's an upvalue.
|
|
variable.scope = SCOPE_UPVALUE;
|
|
variable.index = findUpvalue(compiler, name, length);
|
|
return variable;
|
|
}
|
|
|
|
// Look up [name] in the current scope to see what variable it refers to.
|
|
// Returns the variable either in module scope, local scope, or the enclosing
|
|
// function's upvalue list. Returns a variable with index -1 if not found.
|
|
static Variable resolveName(Compiler* compiler, const char* name, int length)
|
|
{
|
|
Variable variable = resolveNonmodule(compiler, name, length);
|
|
if (variable.index != -1) return variable;
|
|
|
|
variable.scope = SCOPE_MODULE;
|
|
variable.index = wrenSymbolTableFind(&compiler->parser->module->variableNames,
|
|
name, length);
|
|
return variable;
|
|
}
|
|
|
|
// Emits the instruction that loads the local in [slot]. Slots 0 through 8 use
// the dedicated CODE_LOAD_LOCAL_<n> opcodes; higher slots use CODE_LOAD_LOCAL
// with the slot as a byte argument.
static void loadLocal(Compiler* compiler, int slot)
{
  if (slot > 8)
  {
    emitByteArg(compiler, CODE_LOAD_LOCAL, slot);
  }
  else
  {
    emitOp(compiler, (Code)(CODE_LOAD_LOCAL_0 + slot));
  }
}
|
|
|
|
// Finishes [compiler], which is compiling a function, method, or chunk of top
|
|
// level code. If there is a parent compiler, then this emits code in the
|
|
// parent compiler to load the resulting function.
|
|
static ObjFn* endCompiler(Compiler* compiler,
|
|
const char* debugName, int debugNameLength)
|
|
{
|
|
// If we hit an error, don't finish the function since it's borked anyway.
|
|
if (compiler->parser->hasError)
|
|
{
|
|
compiler->parser->vm->compiler = compiler->parent;
|
|
return NULL;
|
|
}
|
|
|
|
// Mark the end of the bytecode. Since it may contain multiple early returns,
|
|
// we can't rely on CODE_RETURN to tell us we're at the end.
|
|
emitOp(compiler, CODE_END);
|
|
|
|
wrenFunctionBindName(compiler->parser->vm, compiler->fn,
|
|
debugName, debugNameLength);
|
|
|
|
// In the function that contains this one, load the resulting function object.
|
|
if (compiler->parent != NULL)
|
|
{
|
|
int constant = addConstant(compiler->parent, OBJ_VAL(compiler->fn));
|
|
|
|
// Wrap the function in a closure. We do this even if it has no upvalues so
|
|
// that the VM can uniformly assume all called objects are closures. This
|
|
// makes creating a function a little slower, but makes invoking them
|
|
// faster. Given that functions are invoked more often than they are
|
|
// created, this is a win.
|
|
emitShortArg(compiler->parent, CODE_CLOSURE, constant);
|
|
|
|
// Emit arguments for each upvalue to know whether to capture a local or
|
|
// an upvalue.
|
|
for (int i = 0; i < compiler->fn->numUpvalues; i++)
|
|
{
|
|
emitByte(compiler->parent, compiler->upvalues[i].isLocal ? 1 : 0);
|
|
emitByte(compiler->parent, compiler->upvalues[i].index);
|
|
}
|
|
}
|
|
|
|
// Pop this compiler off the stack.
|
|
compiler->parser->vm->compiler = compiler->parent;
|
|
|
|
#if WREN_DEBUG_DUMP_COMPILED_CODE
|
|
wrenDumpCode(compiler->parser->vm, compiler->fn);
|
|
#endif
|
|
|
|
return compiler->fn;
|
|
}
|
|
|
|
// Grammar ---------------------------------------------------------------------
|
|
|
|
// Operator precedence levels, listed from loosest to tightest binding. The
// numeric order matters: parsing code adds 1 to a level to get the next
// tighter one.
typedef enum
{
  PREC_NONE,
  PREC_LOWEST,

  PREC_ASSIGNMENT,    // =
  PREC_CONDITIONAL,   // ?:

  PREC_LOGICAL_OR,    // ||
  PREC_LOGICAL_AND,   // &&

  PREC_EQUALITY,      // == !=
  PREC_IS,            // is
  PREC_COMPARISON,    // < > <= >=

  PREC_BITWISE_OR,    // |
  PREC_BITWISE_XOR,   // ^
  PREC_BITWISE_AND,   // &
  PREC_BITWISE_SHIFT, // << >>

  PREC_RANGE,         // .. ...
  PREC_TERM,          // + -
  PREC_FACTOR,        // * / %
  PREC_UNARY,         // unary - ! ~
  PREC_CALL,          // . () []

  PREC_PRIMARY
} Precedence;
|
|
|
|
typedef void (*GrammarFn)(Compiler*, bool canAssign);
|
|
|
|
typedef void (*SignatureFn)(Compiler* compiler, Signature* signature);
|
|
|
|
typedef struct
|
|
{
|
|
GrammarFn prefix;
|
|
GrammarFn infix;
|
|
SignatureFn method;
|
|
Precedence precedence;
|
|
const char* name;
|
|
} GrammarRule;
|
|
|
|
// Forward declarations since the grammar is recursive.
|
|
static GrammarRule* getRule(TokenType type);
|
|
static void expression(Compiler* compiler);
|
|
static void statement(Compiler* compiler);
|
|
static void definition(Compiler* compiler);
|
|
static void parsePrecedence(Compiler* compiler, Precedence precedence);
|
|
|
|
// Replaces the placeholder argument for a previous CODE_JUMP or CODE_JUMP_IF
|
|
// instruction with an offset that jumps to the current end of bytecode.
|
|
static void patchJump(Compiler* compiler, int offset)
|
|
{
|
|
// -2 to adjust for the bytecode for the jump offset itself.
|
|
int jump = compiler->fn->code.count - offset - 2;
|
|
if (jump > MAX_JUMP) error(compiler, "Too much code to jump over.");
|
|
|
|
compiler->fn->code.data[offset] = (jump >> 8) & 0xff;
|
|
compiler->fn->code.data[offset + 1] = jump & 0xff;
|
|
}
|
|
|
|
// Parses a block body, after the initial "{" has been consumed.
|
|
//
|
|
// Returns true if it was a expression body, false if it was a statement body.
|
|
// (More precisely, returns true if a value was left on the stack. An empty
|
|
// block returns false.)
|
|
static bool finishBlock(Compiler* compiler)
|
|
{
|
|
// Empty blocks do nothing.
|
|
if (match(compiler, TOKEN_RIGHT_BRACE)) return false;
|
|
|
|
// If there's no line after the "{", it's a single-expression body.
|
|
if (!matchLine(compiler))
|
|
{
|
|
expression(compiler);
|
|
consume(compiler, TOKEN_RIGHT_BRACE, "Expect '}' at end of block.");
|
|
return true;
|
|
}
|
|
|
|
// Empty blocks (with just a newline inside) do nothing.
|
|
if (match(compiler, TOKEN_RIGHT_BRACE)) return false;
|
|
|
|
// Compile the definition list.
|
|
do
|
|
{
|
|
definition(compiler);
|
|
consumeLine(compiler, "Expect newline after statement.");
|
|
}
|
|
while (peek(compiler) != TOKEN_RIGHT_BRACE && peek(compiler) != TOKEN_EOF);
|
|
|
|
consume(compiler, TOKEN_RIGHT_BRACE, "Expect '}' at end of block.");
|
|
return false;
|
|
}
|
|
|
|
// Parses a method or function body, after the initial "{" has been consumed.
|
|
//
|
|
// If [Compiler->isInitializer] is `true`, this is the body of a constructor
|
|
// initializer. In that case, this adds the code to ensure it returns `this`.
|
|
static void finishBody(Compiler* compiler)
|
|
{
|
|
bool isExpressionBody = finishBlock(compiler);
|
|
|
|
if (compiler->isInitializer)
|
|
{
|
|
// If the initializer body evaluates to a value, discard it.
|
|
if (isExpressionBody) emitOp(compiler, CODE_POP);
|
|
|
|
// The receiver is always stored in the first local slot.
|
|
emitOp(compiler, CODE_LOAD_LOCAL_0);
|
|
}
|
|
else if (!isExpressionBody)
|
|
{
|
|
// Implicitly return null in statement bodies.
|
|
emitOp(compiler, CODE_NULL);
|
|
}
|
|
|
|
emitOp(compiler, CODE_RETURN);
|
|
}
|
|
|
|
// The VM can only handle a certain number of parameters, so check that we
|
|
// haven't exceeded that and give a usable error.
|
|
static void validateNumParameters(Compiler* compiler, int numArgs)
|
|
{
|
|
if (numArgs == MAX_PARAMETERS + 1)
|
|
{
|
|
// Only show an error at exactly max + 1 so that we can keep parsing the
|
|
// parameters and minimize cascaded errors.
|
|
error(compiler, "Methods cannot have more than %d parameters.",
|
|
MAX_PARAMETERS);
|
|
}
|
|
}
|
|
|
|
// Parses the rest of a comma-separated parameter list after the opening
|
|
// delimeter. Updates `arity` in [signature] with the number of parameters.
|
|
static void finishParameterList(Compiler* compiler, Signature* signature)
|
|
{
|
|
do
|
|
{
|
|
ignoreNewlines(compiler);
|
|
validateNumParameters(compiler, ++signature->arity);
|
|
|
|
// Define a local variable in the method for the parameter.
|
|
declareNamedVariable(compiler);
|
|
}
|
|
while (match(compiler, TOKEN_COMMA));
|
|
}
|
|
|
|
// Gets the symbol for a method [name] with [length].
|
|
static int methodSymbol(Compiler* compiler, const char* name, int length)
|
|
{
|
|
return wrenSymbolTableEnsure(compiler->parser->vm,
|
|
&compiler->parser->vm->methodNames, name, length);
|
|
}
|
|
|
|
// Appends characters to [name] (and updates [length]) for [numParams] "_"
|
|
// surrounded by [leftBracket] and [rightBracket].
|
|
static void signatureParameterList(char name[MAX_METHOD_SIGNATURE], int* length,
|
|
int numParams, char leftBracket, char rightBracket)
|
|
{
|
|
name[(*length)++] = leftBracket;
|
|
|
|
// This function may be called with too many parameters. When that happens,
|
|
// a compile error has already been reported, but we need to make sure we
|
|
// don't overflow the string too, hence the MAX_PARAMETERS check.
|
|
for (int i = 0; i < numParams && i < MAX_PARAMETERS; i++)
|
|
{
|
|
if (i > 0) name[(*length)++] = ',';
|
|
name[(*length)++] = '_';
|
|
}
|
|
name[(*length)++] = rightBracket;
|
|
}
|
|
|
|
// Fills [name] with the stringified version of [signature] and updates
|
|
// [length] to the resulting length.
|
|
static void signatureToString(Signature* signature,
|
|
char name[MAX_METHOD_SIGNATURE], int* length)
|
|
{
|
|
*length = 0;
|
|
|
|
// Build the full name from the signature.
|
|
memcpy(name + *length, signature->name, signature->length);
|
|
*length += signature->length;
|
|
|
|
switch (signature->type)
|
|
{
|
|
case SIG_METHOD:
|
|
signatureParameterList(name, length, signature->arity, '(', ')');
|
|
break;
|
|
|
|
case SIG_GETTER:
|
|
// The signature is just the name.
|
|
break;
|
|
|
|
case SIG_SETTER:
|
|
name[(*length)++] = '=';
|
|
signatureParameterList(name, length, 1, '(', ')');
|
|
break;
|
|
|
|
case SIG_SUBSCRIPT:
|
|
signatureParameterList(name, length, signature->arity, '[', ']');
|
|
break;
|
|
|
|
case SIG_SUBSCRIPT_SETTER:
|
|
signatureParameterList(name, length, signature->arity - 1, '[', ']');
|
|
name[(*length)++] = '=';
|
|
signatureParameterList(name, length, 1, '(', ')');
|
|
break;
|
|
|
|
case SIG_INITIALIZER:
|
|
memcpy(name, "init ", 5);
|
|
memcpy(name + 5, signature->name, signature->length);
|
|
*length = 5 + signature->length;
|
|
signatureParameterList(name, length, signature->arity, '(', ')');
|
|
break;
|
|
}
|
|
|
|
name[*length] = '\0';
|
|
}
|
|
|
|
// Gets the symbol for a method with [signature].
|
|
static int signatureSymbol(Compiler* compiler, Signature* signature)
|
|
{
|
|
// Build the full name from the signature.
|
|
char name[MAX_METHOD_SIGNATURE];
|
|
int length;
|
|
signatureToString(signature, name, &length);
|
|
|
|
return methodSymbol(compiler, name, length);
|
|
}
|
|
|
|
// Returns a signature with [type] whose name is from the last consumed token.
|
|
static Signature signatureFromToken(Compiler* compiler, SignatureType type)
|
|
{
|
|
Signature signature;
|
|
|
|
// Get the token for the method name.
|
|
Token* token = &compiler->parser->previous;
|
|
signature.name = token->start;
|
|
signature.length = token->length;
|
|
signature.type = type;
|
|
signature.arity = 0;
|
|
|
|
if (signature.length > MAX_METHOD_NAME)
|
|
{
|
|
error(compiler, "Method names cannot be longer than %d characters.",
|
|
MAX_METHOD_NAME);
|
|
signature.length = MAX_METHOD_NAME;
|
|
}
|
|
|
|
return signature;
|
|
}
|
|
|
|
// Parses a comma-separated list of arguments. Modifies [signature] to include
|
|
// the arity of the argument list.
|
|
static void finishArgumentList(Compiler* compiler, Signature* signature)
|
|
{
|
|
do
|
|
{
|
|
ignoreNewlines(compiler);
|
|
validateNumParameters(compiler, ++signature->arity);
|
|
expression(compiler);
|
|
}
|
|
while (match(compiler, TOKEN_COMMA));
|
|
|
|
// Allow a newline before the closing delimiter.
|
|
ignoreNewlines(compiler);
|
|
}
|
|
|
|
// Compiles a method call with [signature] using [instruction].
|
|
static void callSignature(Compiler* compiler, Code instruction,
|
|
Signature* signature)
|
|
{
|
|
int symbol = signatureSymbol(compiler, signature);
|
|
emitShortArg(compiler, (Code)(instruction + signature->arity), symbol);
|
|
|
|
if (instruction == CODE_SUPER_0)
|
|
{
|
|
// Super calls need to be statically bound to the class's superclass. This
|
|
// ensures we call the right method even when a method containing a super
|
|
// call is inherited by another subclass.
|
|
//
|
|
// We bind it at class definition time by storing a reference to the
|
|
// superclass in a constant. So, here, we create a slot in the constant
|
|
// table and store NULL in it. When the method is bound, we'll look up the
|
|
// superclass then and store it in the constant slot.
|
|
emitShort(compiler, addConstant(compiler, NULL_VAL));
|
|
}
|
|
}
|
|
|
|
// Compiles a method call with [numArgs] for a method with [name] with [length].
|
|
static void callMethod(Compiler* compiler, int numArgs, const char* name,
|
|
int length)
|
|
{
|
|
int symbol = methodSymbol(compiler, name, length);
|
|
emitShortArg(compiler, (Code)(CODE_CALL_0 + numArgs), symbol);
|
|
}
|
|
|
|
// Compiles an (optional) argument list for a method call with [methodSignature]
|
|
// and then calls it.
|
|
static void methodCall(Compiler* compiler, Code instruction,
|
|
Signature* signature)
|
|
{
|
|
// Make a new signature that contains the updated arity and type based on
|
|
// the arguments we find.
|
|
Signature called = { signature->name, signature->length, SIG_GETTER, 0 };
|
|
|
|
// Parse the argument list, if any.
|
|
if (match(compiler, TOKEN_LEFT_PAREN))
|
|
{
|
|
called.type = SIG_METHOD;
|
|
|
|
// Allow new line before an empty argument list
|
|
ignoreNewlines(compiler);
|
|
|
|
// Allow empty an argument list.
|
|
if (peek(compiler) != TOKEN_RIGHT_PAREN)
|
|
{
|
|
finishArgumentList(compiler, &called);
|
|
}
|
|
consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after arguments.");
|
|
}
|
|
|
|
// Parse the block argument, if any.
|
|
if (match(compiler, TOKEN_LEFT_BRACE))
|
|
{
|
|
// Include the block argument in the arity.
|
|
called.type = SIG_METHOD;
|
|
called.arity++;
|
|
|
|
Compiler fnCompiler;
|
|
initCompiler(&fnCompiler, compiler->parser, compiler, false);
|
|
|
|
// Make a dummy signature to track the arity.
|
|
Signature fnSignature = { "", 0, SIG_METHOD, 0 };
|
|
|
|
// Parse the parameter list, if any.
|
|
if (match(compiler, TOKEN_PIPE))
|
|
{
|
|
finishParameterList(&fnCompiler, &fnSignature);
|
|
consume(compiler, TOKEN_PIPE, "Expect '|' after function parameters.");
|
|
}
|
|
|
|
fnCompiler.fn->arity = fnSignature.arity;
|
|
|
|
finishBody(&fnCompiler);
|
|
|
|
// Name the function based on the method its passed to.
|
|
char blockName[MAX_METHOD_SIGNATURE + 15];
|
|
int blockLength;
|
|
signatureToString(&called, blockName, &blockLength);
|
|
memmove(blockName + blockLength, " block argument", 16);
|
|
|
|
endCompiler(&fnCompiler, blockName, blockLength + 15);
|
|
}
|
|
|
|
// TODO: Allow Grace-style mixfix methods?
|
|
|
|
// If this is a super() call for an initializer, make sure we got an actual
|
|
// argument list.
|
|
if (signature->type == SIG_INITIALIZER)
|
|
{
|
|
if (called.type != SIG_METHOD)
|
|
{
|
|
error(compiler, "A superclass constructor must have an argument list.");
|
|
}
|
|
|
|
called.type = SIG_INITIALIZER;
|
|
}
|
|
|
|
callSignature(compiler, instruction, &called);
|
|
}
|
|
|
|
// Compiles a call whose name is the previously consumed token. This includes
|
|
// getters, method calls with arguments, and setter calls.
|
|
static void namedCall(Compiler* compiler, bool canAssign, Code instruction)
|
|
{
|
|
// Get the token for the method name.
|
|
Signature signature = signatureFromToken(compiler, SIG_GETTER);
|
|
|
|
if (canAssign && match(compiler, TOKEN_EQ))
|
|
{
|
|
ignoreNewlines(compiler);
|
|
|
|
// Build the setter signature.
|
|
signature.type = SIG_SETTER;
|
|
signature.arity = 1;
|
|
|
|
// Compile the assigned value.
|
|
expression(compiler);
|
|
callSignature(compiler, instruction, &signature);
|
|
}
|
|
else
|
|
{
|
|
methodCall(compiler, instruction, &signature);
|
|
allowLineBeforeDot(compiler);
|
|
}
|
|
}
|
|
|
|
// Emits the code to load [variable] onto the stack.
|
|
static void loadVariable(Compiler* compiler, Variable variable)
|
|
{
|
|
switch (variable.scope)
|
|
{
|
|
case SCOPE_LOCAL:
|
|
loadLocal(compiler, variable.index);
|
|
break;
|
|
case SCOPE_UPVALUE:
|
|
emitByteArg(compiler, CODE_LOAD_UPVALUE, variable.index);
|
|
break;
|
|
case SCOPE_MODULE:
|
|
emitShortArg(compiler, CODE_LOAD_MODULE_VAR, variable.index);
|
|
break;
|
|
default:
|
|
UNREACHABLE();
|
|
}
|
|
}
|
|
|
|
// Loads the receiver of the currently enclosing method. Correctly handles
|
|
// functions defined inside methods.
|
|
static void loadThis(Compiler* compiler)
|
|
{
|
|
loadVariable(compiler, resolveNonmodule(compiler, "this", 4));
|
|
}
|
|
|
|
// Pushes the value for a module-level variable implicitly imported from core.
|
|
static void loadCoreVariable(Compiler* compiler, const char* name)
|
|
{
|
|
int symbol = wrenSymbolTableFind(&compiler->parser->module->variableNames,
|
|
name, strlen(name));
|
|
ASSERT(symbol != -1, "Should have already defined core name.");
|
|
emitShortArg(compiler, CODE_LOAD_MODULE_VAR, symbol);
|
|
}
|
|
|
|
// A parenthesized expression.
|
|
static void grouping(Compiler* compiler, bool canAssign)
|
|
{
|
|
expression(compiler);
|
|
consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after expression.");
|
|
}
|
|
|
|
// A list literal.
|
|
static void list(Compiler* compiler, bool canAssign)
|
|
{
|
|
// Instantiate a new list.
|
|
loadCoreVariable(compiler, "List");
|
|
callMethod(compiler, 0, "new()", 5);
|
|
|
|
// Compile the list elements. Each one compiles to a ".add()" call.
|
|
do
|
|
{
|
|
ignoreNewlines(compiler);
|
|
|
|
// Stop if we hit the end of the list.
|
|
if (peek(compiler) == TOKEN_RIGHT_BRACKET) break;
|
|
|
|
// The element.
|
|
expression(compiler);
|
|
callMethod(compiler, 1, "addCore_(_)", 11);
|
|
} while (match(compiler, TOKEN_COMMA));
|
|
|
|
// Allow newlines before the closing ']'.
|
|
ignoreNewlines(compiler);
|
|
consume(compiler, TOKEN_RIGHT_BRACKET, "Expect ']' after list elements.");
|
|
}
|
|
|
|
// A map literal.
|
|
static void map(Compiler* compiler, bool canAssign)
|
|
{
|
|
// Instantiate a new map.
|
|
loadCoreVariable(compiler, "Map");
|
|
callMethod(compiler, 0, "new()", 5);
|
|
|
|
// Compile the map elements. Each one is compiled to just invoke the
|
|
// subscript setter on the map.
|
|
do
|
|
{
|
|
ignoreNewlines(compiler);
|
|
|
|
// Stop if we hit the end of the map.
|
|
if (peek(compiler) == TOKEN_RIGHT_BRACE) break;
|
|
|
|
// The key.
|
|
parsePrecedence(compiler, PREC_UNARY);
|
|
consume(compiler, TOKEN_COLON, "Expect ':' after map key.");
|
|
ignoreNewlines(compiler);
|
|
|
|
// The value.
|
|
expression(compiler);
|
|
callMethod(compiler, 2, "addCore_(_,_)", 13);
|
|
} while (match(compiler, TOKEN_COMMA));
|
|
|
|
// Allow newlines before the closing '}'.
|
|
ignoreNewlines(compiler);
|
|
consume(compiler, TOKEN_RIGHT_BRACE, "Expect '}' after map entries.");
|
|
}
|
|
|
|
// Unary operators like `-foo`.
|
|
static void unaryOp(Compiler* compiler, bool canAssign)
|
|
{
|
|
GrammarRule* rule = getRule(compiler->parser->previous.type);
|
|
|
|
ignoreNewlines(compiler);
|
|
|
|
// Compile the argument.
|
|
parsePrecedence(compiler, (Precedence)(PREC_UNARY + 1));
|
|
|
|
// Call the operator method on the left-hand side.
|
|
callMethod(compiler, 0, rule->name, 1);
|
|
}
|
|
|
|
// Prefix handler for boolean literals: emits CODE_FALSE when the consumed
// token is TOKEN_FALSE and CODE_TRUE otherwise.
static void boolean(Compiler* compiler, bool canAssign)
{
  if (compiler->parser->previous.type == TOKEN_FALSE)
  {
    emitOp(compiler, CODE_FALSE);
  }
  else
  {
    emitOp(compiler, CODE_TRUE);
  }
}
|
|
|
|
// Walks the compiler chain to find the compiler for the nearest class
|
|
// enclosing this one. Returns NULL if not currently inside a class definition.
|
|
static Compiler* getEnclosingClassCompiler(Compiler* compiler)
|
|
{
|
|
while (compiler != NULL)
|
|
{
|
|
if (compiler->enclosingClass != NULL) return compiler;
|
|
compiler = compiler->parent;
|
|
}
|
|
|
|
return NULL;
|
|
}
|
|
|
|
// Walks the compiler chain to find the nearest class enclosing this one.
|
|
// Returns NULL if not currently inside a class definition.
|
|
static ClassInfo* getEnclosingClass(Compiler* compiler)
|
|
{
|
|
compiler = getEnclosingClassCompiler(compiler);
|
|
return compiler == NULL ? NULL : compiler->enclosingClass;
|
|
}
|
|
|
|
// Prefix handler for an instance field whose name is the token that was just
// consumed. Compiles either a load of the field or, when an "=" follows and
// assignment is allowed, a store of the right-hand side into it.
static void field(Compiler* compiler, bool canAssign)
{
  ClassInfo* enclosingClass = getEnclosingClass(compiler);

  // Start from a fake value so that parsing can continue after an error,
  // which keeps the number of cascaded errors down.
  int fieldIndex = MAX_FIELDS;

  if (enclosingClass == NULL)
  {
    error(compiler, "Cannot reference a field outside of a class definition.");
  }
  else if (enclosingClass->isForeign)
  {
    error(compiler, "Cannot define fields in a foreign class.");
  }
  else if (enclosingClass->inStatic)
  {
    error(compiler, "Cannot use an instance field in a static method.");
  }
  else
  {
    // Find the field, implicitly defining it on first use.
    fieldIndex = wrenSymbolTableEnsure(compiler->parser->vm,
                                       &enclosingClass->fields,
                                       compiler->parser->previous.start,
                                       compiler->parser->previous.length);

    if (fieldIndex >= MAX_FIELDS)
    {
      error(compiler, "A class can only have %d fields.", MAX_FIELDS);
    }
  }

  // An "=" after the field name makes this an assignment; compile the
  // right-hand side in that case.
  const bool isStore = canAssign && match(compiler, TOKEN_EQ);
  if (isStore) expression(compiler);

  // Code that sits directly inside a method can use the more optimal
  // "..._THIS" instructions; anything nested deeper has to load the receiver
  // explicitly first.
  const bool directlyInMethod =
      compiler->parent != NULL &&
      compiler->parent->enclosingClass == enclosingClass;

  if (directlyInMethod)
  {
    emitByteArg(compiler,
                isStore ? CODE_STORE_FIELD_THIS : CODE_LOAD_FIELD_THIS,
                fieldIndex);
  }
  else
  {
    loadThis(compiler);
    emitByteArg(compiler, isStore ? CODE_STORE_FIELD : CODE_LOAD_FIELD,
                fieldIndex);
  }

  allowLineBeforeDot(compiler);
}
|
|
|
|
// Compiles a read or assignment to [variable].
|
|
static void bareName(Compiler* compiler, bool canAssign, Variable variable)
|
|
{
|
|
// If there's an "=" after a bare name, it's a variable assignment.
|
|
if (canAssign && match(compiler, TOKEN_EQ))
|
|
{
|
|
// Compile the right-hand side.
|
|
expression(compiler);
|
|
|
|
// Emit the store instruction.
|
|
switch (variable.scope)
|
|
{
|
|
case SCOPE_LOCAL:
|
|
emitByteArg(compiler, CODE_STORE_LOCAL, variable.index);
|
|
break;
|
|
case SCOPE_UPVALUE:
|
|
emitByteArg(compiler, CODE_STORE_UPVALUE, variable.index);
|
|
break;
|
|
case SCOPE_MODULE:
|
|
emitShortArg(compiler, CODE_STORE_MODULE_VAR, variable.index);
|
|
break;
|
|
default:
|
|
UNREACHABLE();
|
|
}
|
|
return;
|
|
}
|
|
|
|
// Emit the load instruction.
|
|
loadVariable(compiler, variable);
|
|
|
|
allowLineBeforeDot(compiler);
|
|
}
|
|
|
|
// Prefix handler for a static field whose name is the token that was just
// consumed. Static fields are compiled as variables in the scope that
// surrounds the class definition.
static void staticField(Compiler* compiler, bool canAssign)
{
  Compiler* classCompiler = getEnclosingClassCompiler(compiler);
  if (classCompiler == NULL)
  {
    error(compiler, "Cannot use a static field outside of a class definition.");
    return;
  }

  // The field's name.
  Token* token = &compiler->parser->previous;

  // On first sight of this static field, implicitly define it as a variable
  // in the scope surrounding the class definition, initialized to null.
  const bool alreadyDeclared =
      resolveLocal(classCompiler, token->start, token->length) != -1;
  if (!alreadyDeclared)
  {
    int symbol = declareVariable(classCompiler, NULL);

    emitOp(classCompiler, CODE_NULL);
    defineVariable(classCompiler, symbol);
  }

  // The variable exists for sure now. It is resolved through the normal scope
  // chain rather than with resolveLocal() because it may already have been
  // closed over as an upvalue.
  Variable variable = resolveName(compiler, token->start, token->length);
  bareName(compiler, canAssign, variable);
}
|
|
|
|
// Compiles a variable name or method call with an implicit receiver.
|
|
static void name(Compiler* compiler, bool canAssign)
|
|
{
|
|
// Look for the name in the scope chain up to the nearest enclosing method.
|
|
Token* token = &compiler->parser->previous;
|
|
|
|
Variable variable = resolveNonmodule(compiler, token->start, token->length);
|
|
if (variable.index != -1)
|
|
{
|
|
bareName(compiler, canAssign, variable);
|
|
return;
|
|
}
|
|
|
|
// TODO: The fact that we return above here if the variable is known and parse
|
|
// an optional argument list below if not means that the grammar is not
|
|
// context-free. A line of code in a method like "someName(foo)" is a parse
|
|
// error if "someName" is a defined variable in the surrounding scope and not
|
|
// if it isn't. Fix this. One option is to have "someName(foo)" always
|
|
// resolve to a self-call if there is an argument list, but that makes
|
|
// getters a little confusing.
|
|
|
|
// If we're inside a method and the name is lowercase, treat it as a method
|
|
// on this.
|
|
if (wrenIsLocalName(token->start) && getEnclosingClass(compiler) != NULL)
|
|
{
|
|
loadThis(compiler);
|
|
namedCall(compiler, canAssign, CODE_CALL_0);
|
|
return;
|
|
}
|
|
|
|
// Otherwise, look for a module-level variable with the name.
|
|
variable.scope = SCOPE_MODULE;
|
|
variable.index = wrenSymbolTableFind(&compiler->parser->module->variableNames,
|
|
token->start, token->length);
|
|
if (variable.index == -1)
|
|
{
|
|
// Implicitly define a module-level variable in
|
|
// the hopes that we get a real definition later.
|
|
variable.index = wrenDeclareVariable(compiler->parser->vm,
|
|
compiler->parser->module,
|
|
token->start, token->length,
|
|
token->line);
|
|
|
|
if (variable.index == -2)
|
|
{
|
|
error(compiler, "Too many module variables defined.");
|
|
}
|
|
}
|
|
|
|
bareName(compiler, canAssign, variable);
|
|
}
|
|
|
|
// Prefix handler for the `null` literal: emits CODE_NULL. [canAssign] is not
// used.
static void null(Compiler* compiler, bool canAssign)
{
  emitOp(compiler, CODE_NULL);
}
|
|
|
|
// A number or string literal.
|
|
static void literal(Compiler* compiler, bool canAssign)
|
|
{
|
|
emitConstant(compiler, compiler->parser->previous.value);
|
|
}
|
|
|
|
// A string literal that contains interpolated expressions.
|
|
//
|
|
// Interpolation is syntactic sugar for calling ".join()" on a list. So the
|
|
// string:
|
|
//
|
|
// "a %(b + c) d"
|
|
//
|
|
// is compiled roughly like:
|
|
//
|
|
// ["a ", b + c, " d"].join()
|
|
static void stringInterpolation(Compiler* compiler, bool canAssign)
|
|
{
|
|
// Instantiate a new list.
|
|
loadCoreVariable(compiler, "List");
|
|
callMethod(compiler, 0, "new()", 5);
|
|
|
|
do
|
|
{
|
|
// The opening string part.
|
|
literal(compiler, false);
|
|
callMethod(compiler, 1, "addCore_(_)", 11);
|
|
|
|
// The interpolated expression.
|
|
ignoreNewlines(compiler);
|
|
expression(compiler);
|
|
callMethod(compiler, 1, "addCore_(_)", 11);
|
|
|
|
ignoreNewlines(compiler);
|
|
} while (match(compiler, TOKEN_INTERPOLATION));
|
|
|
|
// The trailing string part.
|
|
consume(compiler, TOKEN_STRING, "Expect end of string interpolation.");
|
|
literal(compiler, false);
|
|
callMethod(compiler, 1, "addCore_(_)", 11);
|
|
|
|
// The list of interpolated parts.
|
|
callMethod(compiler, 0, "join()", 6);
|
|
}
|
|
|
|
// Prefix parser for `super`. Loads `this` and then compiles either a named
// superclass call (`super.name`) or, when no name is given, a call that reuses
// the signature of the enclosing method.
static void super_(Compiler* compiler, bool canAssign)
{
  ClassInfo* classInfo = getEnclosingClass(compiler);
  if (classInfo == NULL)
  {
    error(compiler, "Cannot use 'super' outside of a method.");
  }

  loadThis(compiler);

  // TODO: Super operator calls.
  // TODO: There's no syntax for invoking a superclass constructor with a
  // different name from the enclosing one. Figure that out.

  // Named super call.
  if (match(compiler, TOKEN_DOT))
  {
    consume(compiler, TOKEN_NAME, "Expect method name after 'super.'.");
    namedCall(compiler, canAssign, CODE_SUPER_0);
    return;
  }

  // Unnamed super call. The missing class was already reported above; bail
  // out here so the NULL class info is never dereferenced.
  if (classInfo == NULL) return;

  // No explicit name, so use the signature of the enclosing method.
  methodCall(compiler, CODE_SUPER_0, classInfo->signature);
}
|
|
|
|
// Prefix parser for `this`. Reports an error when there is no enclosing class,
// otherwise loads the receiver. [canAssign] is not used.
static void this_(Compiler* compiler, bool canAssign)
{
  (void)canAssign;

  if (getEnclosingClass(compiler) != NULL)
  {
    loadThis(compiler);
  }
  else
  {
    error(compiler, "Cannot use 'this' outside of a method.");
  }
}
|
|
|
|
// Subscript or "array indexing" operator like `foo[bar]`.
|
|
static void subscript(Compiler* compiler, bool canAssign)
|
|
{
|
|
Signature signature = { "", 0, SIG_SUBSCRIPT, 0 };
|
|
|
|
// Parse the argument list.
|
|
finishArgumentList(compiler, &signature);
|
|
consume(compiler, TOKEN_RIGHT_BRACKET, "Expect ']' after arguments.");
|
|
|
|
allowLineBeforeDot(compiler);
|
|
|
|
if (canAssign && match(compiler, TOKEN_EQ))
|
|
{
|
|
signature.type = SIG_SUBSCRIPT_SETTER;
|
|
|
|
// Compile the assigned value.
|
|
validateNumParameters(compiler, ++signature.arity);
|
|
expression(compiler);
|
|
}
|
|
|
|
callSignature(compiler, CODE_CALL_0, &signature);
|
|
}
|
|
|
|
// Infix parser for `.`: skips newlines, requires a method name, and compiles
// the named call on the value to the left.
static void call(Compiler* compiler, bool canAssign)
{
  ignoreNewlines(compiler);

  consume(compiler, TOKEN_NAME, "Expect method name after '.'.");

  namedCall(compiler, canAssign, CODE_CALL_0);
}
|
|
|
|
// Infix parser for `&&`. [canAssign] is not used.
static void and_(Compiler* compiler, bool canAssign)
{
  (void)canAssign;
  ignoreNewlines(compiler);

  // When the left operand is false, jump past the right operand.
  int skipRight = emitJump(compiler, CODE_AND);
  parsePrecedence(compiler, PREC_LOGICAL_AND);
  patchJump(compiler, skipRight);
}
|
|
|
|
// Infix parser for `||`. [canAssign] is not used.
static void or_(Compiler* compiler, bool canAssign)
{
  (void)canAssign;
  ignoreNewlines(compiler);

  // When the left operand is true, jump past the right operand.
  int skipRight = emitJump(compiler, CODE_OR);
  parsePrecedence(compiler, PREC_LOGICAL_OR);
  patchJump(compiler, skipRight);
}
|
|
|
|
// Infix parser for the conditional operator `cond ? then : else`. Called after
// the `?` has been consumed. [canAssign] is not used.
static void conditional(Compiler* compiler, bool canAssign)
{
  (void)canAssign;

  // A newline may follow the '?'.
  ignoreNewlines(compiler);

  // If the condition is false, skip ahead to the else branch.
  int toElse = emitJump(compiler, CODE_JUMP_IF);

  // Then branch.
  parsePrecedence(compiler, PREC_CONDITIONAL);

  consume(compiler, TOKEN_COLON,
          "Expect ':' after then branch of conditional operator.");
  ignoreNewlines(compiler);

  // Once the then branch has run, skip over the else branch.
  int toEnd = emitJump(compiler, CODE_JUMP);

  // Else branch starts here.
  patchJump(compiler, toElse);
  parsePrecedence(compiler, PREC_ASSIGNMENT);

  // End of the whole expression.
  patchJump(compiler, toEnd);
}
|
|
|
|
// Infix parser for binary operators that are compiled as method calls. The
// operator token has already been consumed; its grammar rule supplies the
// method name and precedence. [canAssign] is not used.
void infixOp(Compiler* compiler, bool canAssign)
{
  (void)canAssign;
  GrammarRule* rule = getRule(compiler->parser->previous.type);

  // An infix operator cannot end an expression, so skip trailing newlines.
  ignoreNewlines(compiler);

  // Right-hand operand, parsed one precedence level above the operator.
  Precedence operandPrecedence = (Precedence)(rule->precedence + 1);
  parsePrecedence(compiler, operandPrecedence);

  // Invoke the operator as a one-argument method on the left-hand operand.
  int nameLength = (int)strlen(rule->name);
  Signature signature = { rule->name, nameLength, SIG_METHOD, 1 };
  callSignature(compiler, CODE_CALL_0, &signature);
}
|
|
|
|
// Compiles a method signature for an infix operator.
|
|
void infixSignature(Compiler* compiler, Signature* signature)
|
|
{
|
|
// Add the RHS parameter.
|
|
signature->type = SIG_METHOD;
|
|
signature->arity = 1;
|
|
|
|
// Parse the parameter name.
|
|
consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after operator name.");
|
|
declareNamedVariable(compiler);
|
|
consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameter name.");
|
|
}
|
|
|
|
// Compiles a method signature for an unary operator (i.e. "!").
|
|
void unarySignature(Compiler* compiler, Signature* signature)
|
|
{
|
|
// Do nothing. The name is already complete.
|
|
signature->type = SIG_GETTER;
|
|
}
|
|
|
|
// Compiles a method signature for an operator that can either be unary or
|
|
// infix (i.e. "-").
|
|
void mixedSignature(Compiler* compiler, Signature* signature)
|
|
{
|
|
signature->type = SIG_GETTER;
|
|
|
|
// If there is a parameter, it's an infix operator, otherwise it's unary.
|
|
if (match(compiler, TOKEN_LEFT_PAREN))
|
|
{
|
|
// Add the RHS parameter.
|
|
signature->type = SIG_METHOD;
|
|
signature->arity = 1;
|
|
|
|
// Parse the parameter name.
|
|
declareNamedVariable(compiler);
|
|
consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameter name.");
|
|
}
|
|
}
|
|
|
|
// Compiles an optional setter parameter in a method [signature].
|
|
//
|
|
// Returns `true` if it was a setter.
|
|
static bool maybeSetter(Compiler* compiler, Signature* signature)
|
|
{
|
|
// See if it's a setter.
|
|
if (!match(compiler, TOKEN_EQ)) return false;
|
|
|
|
// It's a setter.
|
|
if (signature->type == SIG_SUBSCRIPT)
|
|
{
|
|
signature->type = SIG_SUBSCRIPT_SETTER;
|
|
}
|
|
else
|
|
{
|
|
signature->type = SIG_SETTER;
|
|
}
|
|
|
|
// Parse the value parameter.
|
|
consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after '='.");
|
|
declareNamedVariable(compiler);
|
|
consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameter name.");
|
|
|
|
signature->arity++;
|
|
|
|
return true;
|
|
}
|
|
|
|
// Compiles a method signature for a subscript operator.
|
|
void subscriptSignature(Compiler* compiler, Signature* signature)
|
|
{
|
|
signature->type = SIG_SUBSCRIPT;
|
|
|
|
// The signature currently has "[" as its name since that was the token that
|
|
// matched it. Clear that out.
|
|
signature->length = 0;
|
|
|
|
// Parse the parameters inside the subscript.
|
|
finishParameterList(compiler, signature);
|
|
consume(compiler, TOKEN_RIGHT_BRACKET, "Expect ']' after parameters.");
|
|
|
|
maybeSetter(compiler, signature);
|
|
}
|
|
|
|
// Parses an optional parenthesized parameter list. Updates `type` and `arity`
|
|
// in [signature] to match what was parsed.
|
|
static void parameterList(Compiler* compiler, Signature* signature)
|
|
{
|
|
// The parameter list is optional.
|
|
if (!match(compiler, TOKEN_LEFT_PAREN)) return;
|
|
|
|
signature->type = SIG_METHOD;
|
|
|
|
// Allow new line before an empty argument list
|
|
ignoreNewlines(compiler);
|
|
|
|
// Allow an empty parameter list.
|
|
if (match(compiler, TOKEN_RIGHT_PAREN)) return;
|
|
|
|
finishParameterList(compiler, signature);
|
|
consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameters.");
|
|
}
|
|
|
|
// Compiles a method signature for a named method or setter.
|
|
void namedSignature(Compiler* compiler, Signature* signature)
|
|
{
|
|
signature->type = SIG_GETTER;
|
|
|
|
// If it's a setter, it can't also have a parameter list.
|
|
if (maybeSetter(compiler, signature)) return;
|
|
|
|
// Regular named method with an optional parameter list.
|
|
parameterList(compiler, signature);
|
|
}
|
|
|
|
// Compiles a method signature for a constructor.
|
|
void constructorSignature(Compiler* compiler, Signature* signature)
|
|
{
|
|
consume(compiler, TOKEN_NAME, "Expect constructor name after 'construct'.");
|
|
|
|
// Capture the name.
|
|
*signature = signatureFromToken(compiler, SIG_INITIALIZER);
|
|
|
|
if (match(compiler, TOKEN_EQ))
|
|
{
|
|
error(compiler, "A constructor cannot be a setter.");
|
|
}
|
|
|
|
if (!match(compiler, TOKEN_LEFT_PAREN))
|
|
{
|
|
error(compiler, "A constructor cannot be a getter.");
|
|
return;
|
|
}
|
|
|
|
// Allow an empty parameter list.
|
|
if (match(compiler, TOKEN_RIGHT_PAREN)) return;
|
|
|
|
finishParameterList(compiler, signature);
|
|
consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameters.");
|
|
}
|
|
|
|
// This table defines all of the parsing rules for the prefix and infix
|
|
// expressions in the grammar. Expressions are parsed using a Pratt parser.
|
|
//
|
|
// See: http://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/
|
|
#define UNUSED { NULL, NULL, NULL, PREC_NONE, NULL }
|
|
#define PREFIX(fn) { fn, NULL, NULL, PREC_NONE, NULL }
|
|
#define INFIX(prec, fn) { NULL, fn, NULL, prec, NULL }
|
|
#define INFIX_OPERATOR(prec, name) { NULL, infixOp, infixSignature, prec, name }
|
|
#define PREFIX_OPERATOR(name) { unaryOp, NULL, unarySignature, PREC_NONE, name }
|
|
#define OPERATOR(name) { unaryOp, infixOp, mixedSignature, PREC_TERM, name }
|
|
|
|
GrammarRule rules[] =
|
|
{
|
|
/* TOKEN_LEFT_PAREN */ PREFIX(grouping),
|
|
/* TOKEN_RIGHT_PAREN */ UNUSED,
|
|
/* TOKEN_LEFT_BRACKET */ { list, subscript, subscriptSignature, PREC_CALL, NULL },
|
|
/* TOKEN_RIGHT_BRACKET */ UNUSED,
|
|
/* TOKEN_LEFT_BRACE */ PREFIX(map),
|
|
/* TOKEN_RIGHT_BRACE */ UNUSED,
|
|
/* TOKEN_COLON */ UNUSED,
|
|
/* TOKEN_DOT */ INFIX(PREC_CALL, call),
|
|
/* TOKEN_DOTDOT */ INFIX_OPERATOR(PREC_RANGE, ".."),
|
|
/* TOKEN_DOTDOTDOT */ INFIX_OPERATOR(PREC_RANGE, "..."),
|
|
/* TOKEN_COMMA */ UNUSED,
|
|
/* TOKEN_STAR */ INFIX_OPERATOR(PREC_FACTOR, "*"),
|
|
/* TOKEN_SLASH */ INFIX_OPERATOR(PREC_FACTOR, "/"),
|
|
/* TOKEN_PERCENT */ INFIX_OPERATOR(PREC_FACTOR, "%"),
|
|
/* TOKEN_HASH */ UNUSED,
|
|
/* TOKEN_PLUS */ INFIX_OPERATOR(PREC_TERM, "+"),
|
|
/* TOKEN_MINUS */ OPERATOR("-"),
|
|
/* TOKEN_LTLT */ INFIX_OPERATOR(PREC_BITWISE_SHIFT, "<<"),
|
|
/* TOKEN_GTGT */ INFIX_OPERATOR(PREC_BITWISE_SHIFT, ">>"),
|
|
/* TOKEN_PIPE */ INFIX_OPERATOR(PREC_BITWISE_OR, "|"),
|
|
/* TOKEN_PIPEPIPE */ INFIX(PREC_LOGICAL_OR, or_),
|
|
/* TOKEN_CARET */ INFIX_OPERATOR(PREC_BITWISE_XOR, "^"),
|
|
/* TOKEN_AMP */ INFIX_OPERATOR(PREC_BITWISE_AND, "&"),
|
|
/* TOKEN_AMPAMP */ INFIX(PREC_LOGICAL_AND, and_),
|
|
/* TOKEN_BANG */ PREFIX_OPERATOR("!"),
|
|
/* TOKEN_TILDE */ PREFIX_OPERATOR("~"),
|
|
/* TOKEN_QUESTION */ INFIX(PREC_ASSIGNMENT, conditional),
|
|
/* TOKEN_EQ */ UNUSED,
|
|
/* TOKEN_LT */ INFIX_OPERATOR(PREC_COMPARISON, "<"),
|
|
/* TOKEN_GT */ INFIX_OPERATOR(PREC_COMPARISON, ">"),
|
|
/* TOKEN_LTEQ */ INFIX_OPERATOR(PREC_COMPARISON, "<="),
|
|
/* TOKEN_GTEQ */ INFIX_OPERATOR(PREC_COMPARISON, ">="),
|
|
/* TOKEN_EQEQ */ INFIX_OPERATOR(PREC_EQUALITY, "=="),
|
|
/* TOKEN_BANGEQ */ INFIX_OPERATOR(PREC_EQUALITY, "!="),
|
|
/* TOKEN_BREAK */ UNUSED,
|
|
/* TOKEN_CONTINUE */ UNUSED,
|
|
/* TOKEN_CLASS */ UNUSED,
|
|
/* TOKEN_CONSTRUCT */ { NULL, NULL, constructorSignature, PREC_NONE, NULL },
|
|
/* TOKEN_ELSE */ UNUSED,
|
|
/* TOKEN_FALSE */ PREFIX(boolean),
|
|
/* TOKEN_FOR */ UNUSED,
|
|
/* TOKEN_FOREIGN */ UNUSED,
|
|
/* TOKEN_IF */ UNUSED,
|
|
/* TOKEN_IMPORT */ UNUSED,
|
|
/* TOKEN_AS */ UNUSED,
|
|
/* TOKEN_IN */ UNUSED,
|
|
/* TOKEN_IS */ INFIX_OPERATOR(PREC_IS, "is"),
|
|
/* TOKEN_NULL */ PREFIX(null),
|
|
/* TOKEN_RETURN */ UNUSED,
|
|
/* TOKEN_STATIC */ UNUSED,
|
|
/* TOKEN_SUPER */ PREFIX(super_),
|
|
/* TOKEN_THIS */ PREFIX(this_),
|
|
/* TOKEN_TRUE */ PREFIX(boolean),
|
|
/* TOKEN_VAR */ UNUSED,
|
|
/* TOKEN_WHILE */ UNUSED,
|
|
/* TOKEN_FIELD */ PREFIX(field),
|
|
/* TOKEN_STATIC_FIELD */ PREFIX(staticField),
|
|
/* TOKEN_NAME */ { name, NULL, namedSignature, PREC_NONE, NULL },
|
|
/* TOKEN_NUMBER */ PREFIX(literal),
|
|
/* TOKEN_STRING */ PREFIX(literal),
|
|
/* TOKEN_INTERPOLATION */ PREFIX(stringInterpolation),
|
|
/* TOKEN_LINE */ UNUSED,
|
|
/* TOKEN_ERROR */ UNUSED,
|
|
/* TOKEN_EOF */ UNUSED
|
|
};
|
|
|
|
// Gets the [GrammarRule] associated with tokens of [type].
|
|
static GrammarRule* getRule(TokenType type)
|
|
{
|
|
return &rules[type];
|
|
}
|
|
|
|
// The main entrypoint for the top-down operator precedence parser.
|
|
void parsePrecedence(Compiler* compiler, Precedence precedence)
|
|
{
|
|
nextToken(compiler->parser);
|
|
GrammarFn prefix = rules[compiler->parser->previous.type].prefix;
|
|
|
|
if (prefix == NULL)
|
|
{
|
|
error(compiler, "Expected expression.");
|
|
return;
|
|
}
|
|
|
|
// Track if the precendence of the surrounding expression is low enough to
|
|
// allow an assignment inside this one. We can't compile an assignment like
|
|
// a normal expression because it requires us to handle the LHS specially --
|
|
// it needs to be an lvalue, not an rvalue. So, for each of the kinds of
|
|
// expressions that are valid lvalues -- names, subscripts, fields, etc. --
|
|
// we pass in whether or not it appears in a context loose enough to allow
|
|
// "=". If so, it will parse the "=" itself and handle it appropriately.
|
|
bool canAssign = precedence <= PREC_CONDITIONAL;
|
|
prefix(compiler, canAssign);
|
|
|
|
while (precedence <= rules[compiler->parser->current.type].precedence)
|
|
{
|
|
nextToken(compiler->parser);
|
|
GrammarFn infix = rules[compiler->parser->previous.type].infix;
|
|
infix(compiler, canAssign);
|
|
}
|
|
}
|
|
|
|
// Parses an expression. Unlike statements, expressions leave a resulting value
|
|
// on the stack.
|
|
void expression(Compiler* compiler)
|
|
{
|
|
parsePrecedence(compiler, PREC_LOWEST);
|
|
}
|
|
|
|
// Returns the number of bytes for the arguments to the instruction
|
|
// at [ip] in [fn]'s bytecode.
|
|
static int getByteCountForArguments(const uint8_t* bytecode,
|
|
const Value* constants, int ip)
|
|
{
|
|
Code instruction = (Code)bytecode[ip];
|
|
switch (instruction)
|
|
{
|
|
case CODE_NULL:
|
|
case CODE_FALSE:
|
|
case CODE_TRUE:
|
|
case CODE_POP:
|
|
case CODE_CLOSE_UPVALUE:
|
|
case CODE_RETURN:
|
|
case CODE_END:
|
|
case CODE_LOAD_LOCAL_0:
|
|
case CODE_LOAD_LOCAL_1:
|
|
case CODE_LOAD_LOCAL_2:
|
|
case CODE_LOAD_LOCAL_3:
|
|
case CODE_LOAD_LOCAL_4:
|
|
case CODE_LOAD_LOCAL_5:
|
|
case CODE_LOAD_LOCAL_6:
|
|
case CODE_LOAD_LOCAL_7:
|
|
case CODE_LOAD_LOCAL_8:
|
|
case CODE_CONSTRUCT:
|
|
case CODE_FOREIGN_CONSTRUCT:
|
|
case CODE_FOREIGN_CLASS:
|
|
case CODE_END_MODULE:
|
|
case CODE_END_CLASS:
|
|
return 0;
|
|
|
|
case CODE_LOAD_LOCAL:
|
|
case CODE_STORE_LOCAL:
|
|
case CODE_LOAD_UPVALUE:
|
|
case CODE_STORE_UPVALUE:
|
|
case CODE_LOAD_FIELD_THIS:
|
|
case CODE_STORE_FIELD_THIS:
|
|
case CODE_LOAD_FIELD:
|
|
case CODE_STORE_FIELD:
|
|
case CODE_CLASS:
|
|
return 1;
|
|
|
|
case CODE_CONSTANT:
|
|
case CODE_LOAD_MODULE_VAR:
|
|
case CODE_STORE_MODULE_VAR:
|
|
case CODE_CALL_0:
|
|
case CODE_CALL_1:
|
|
case CODE_CALL_2:
|
|
case CODE_CALL_3:
|
|
case CODE_CALL_4:
|
|
case CODE_CALL_5:
|
|
case CODE_CALL_6:
|
|
case CODE_CALL_7:
|
|
case CODE_CALL_8:
|
|
case CODE_CALL_9:
|
|
case CODE_CALL_10:
|
|
case CODE_CALL_11:
|
|
case CODE_CALL_12:
|
|
case CODE_CALL_13:
|
|
case CODE_CALL_14:
|
|
case CODE_CALL_15:
|
|
case CODE_CALL_16:
|
|
case CODE_JUMP:
|
|
case CODE_LOOP:
|
|
case CODE_JUMP_IF:
|
|
case CODE_AND:
|
|
case CODE_OR:
|
|
case CODE_METHOD_INSTANCE:
|
|
case CODE_METHOD_STATIC:
|
|
case CODE_IMPORT_MODULE:
|
|
case CODE_IMPORT_VARIABLE:
|
|
return 2;
|
|
|
|
case CODE_SUPER_0:
|
|
case CODE_SUPER_1:
|
|
case CODE_SUPER_2:
|
|
case CODE_SUPER_3:
|
|
case CODE_SUPER_4:
|
|
case CODE_SUPER_5:
|
|
case CODE_SUPER_6:
|
|
case CODE_SUPER_7:
|
|
case CODE_SUPER_8:
|
|
case CODE_SUPER_9:
|
|
case CODE_SUPER_10:
|
|
case CODE_SUPER_11:
|
|
case CODE_SUPER_12:
|
|
case CODE_SUPER_13:
|
|
case CODE_SUPER_14:
|
|
case CODE_SUPER_15:
|
|
case CODE_SUPER_16:
|
|
return 4;
|
|
|
|
case CODE_CLOSURE:
|
|
{
|
|
int constant = (bytecode[ip + 1] << 8) | bytecode[ip + 2];
|
|
ObjFn* loadedFn = AS_FN(constants[constant]);
|
|
|
|
// There are two bytes for the constant, then two for each upvalue.
|
|
return 2 + (loadedFn->numUpvalues * 2);
|
|
}
|
|
}
|
|
|
|
UNREACHABLE();
|
|
return 0;
|
|
}
|
|
|
|
// Marks the beginning of a loop. Keeps track of the current instruction so we
|
|
// know what to loop back to at the end of the body.
|
|
static void startLoop(Compiler* compiler, Loop* loop)
|
|
{
|
|
loop->enclosing = compiler->loop;
|
|
loop->start = compiler->fn->code.count - 1;
|
|
loop->scopeDepth = compiler->scopeDepth;
|
|
compiler->loop = loop;
|
|
}
|
|
|
|
// Emits the [CODE_JUMP_IF] instruction used to test the loop condition and
|
|
// potentially exit the loop. Keeps track of the instruction so we can patch it
|
|
// later once we know where the end of the body is.
|
|
static void testExitLoop(Compiler* compiler)
|
|
{
|
|
compiler->loop->exitJump = emitJump(compiler, CODE_JUMP_IF);
|
|
}
|
|
|
|
// Compiles the body of the loop and tracks its extent so that contained "break"
|
|
// statements can be handled correctly.
|
|
static void loopBody(Compiler* compiler)
|
|
{
|
|
compiler->loop->body = compiler->fn->code.count;
|
|
statement(compiler);
|
|
}
|
|
|
|
// Ends the current innermost loop. Patches up all jumps and breaks now that
|
|
// we know where the end of the loop is.
|
|
static void endLoop(Compiler* compiler)
|
|
{
|
|
// We don't check for overflow here since the forward jump over the loop body
|
|
// will report an error for the same problem.
|
|
int loopOffset = compiler->fn->code.count - compiler->loop->start + 2;
|
|
emitShortArg(compiler, CODE_LOOP, loopOffset);
|
|
|
|
patchJump(compiler, compiler->loop->exitJump);
|
|
|
|
// Find any break placeholder instructions (which will be CODE_END in the
|
|
// bytecode) and replace them with real jumps.
|
|
int i = compiler->loop->body;
|
|
while (i < compiler->fn->code.count)
|
|
{
|
|
if (compiler->fn->code.data[i] == CODE_END)
|
|
{
|
|
compiler->fn->code.data[i] = CODE_JUMP;
|
|
patchJump(compiler, i + 1);
|
|
i += 3;
|
|
}
|
|
else
|
|
{
|
|
// Skip this instruction and its arguments.
|
|
i += 1 + getByteCountForArguments(compiler->fn->code.data,
|
|
compiler->fn->constants.data, i);
|
|
}
|
|
}
|
|
|
|
compiler->loop = compiler->loop->enclosing;
|
|
}
|
|
|
|
// Compiles a `for` statement. The `for` keyword has already been consumed.
static void forStatement(Compiler* compiler)
{
  // A for statement like:
  //
  //     for (i in sequence.expression) {
  //       System.print(i)
  //     }
  //
  // is compiled to bytecode almost as if the source looked like this:
  //
  //     {
  //       var seq_ = sequence.expression
  //       var iter_
  //       while (iter_ = seq_.iterate(iter_)) {
  //         var i = seq_.iteratorValue(iter_)
  //         System.print(i)
  //       }
  //     }
  //
  // It's not exactly this, because the synthetic variables `seq_` and `iter_`
  // actually get names that aren't valid Wren identifiers, but that's the
  // basic idea.
  //
  // The important parts are:
  // - The sequence expression is only evaluated once.
  // - The .iterate() method is used to advance the iterator and determine if
  //   it should exit the loop.
  // - The .iteratorValue() method is used to get the value at the current
  //   iterator position.

  // Scope that holds the hidden locals used for iteration.
  pushScope(compiler);

  consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after 'for'.");
  consume(compiler, TOKEN_NAME, "Expect for loop variable name.");

  // Keep hold of the loop variable's name; it is declared further down.
  const char* variableName = compiler->parser->previous.start;
  int variableLength = compiler->parser->previous.length;

  consume(compiler, TOKEN_IN, "Expect 'in' after loop variable.");
  ignoreNewlines(compiler);

  // Evaluate the sequence expression. Its value is stored in a hidden local
  // whose name contains a space so it can't collide with a user-defined
  // variable.
  expression(compiler);

  // Make sure there is room for the hidden locals. This relies on exactly two
  // addLocal calls following right after this check.
  if (compiler->numLocals + 2 > MAX_LOCALS)
  {
    error(compiler, "Cannot declare more than %d variables in one scope. (Not enough space for for-loops internal variables)",
          MAX_LOCALS);
    return;
  }
  int sequenceSlot = addLocal(compiler, "seq ", 4);

  // A second hidden local, initialized to null, holds the iterator.
  null(compiler, false);
  int iteratorSlot = addLocal(compiler, "iter ", 5);

  consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after loop expression.");

  Loop loop;
  startLoop(compiler, &loop);

  // Advance the iterator by calling ".iterate" on the sequence, store the
  // result back into the iterator local, and exit the loop when it is falsy.
  loadLocal(compiler, sequenceSlot);
  loadLocal(compiler, iteratorSlot);
  callMethod(compiler, 1, "iterate(_)", 10);
  emitByteArg(compiler, CODE_STORE_LOCAL, iteratorSlot);
  testExitLoop(compiler);

  // Fetch the current element by calling ".iteratorValue" on the sequence.
  loadLocal(compiler, sequenceSlot);
  loadLocal(compiler, iteratorSlot);
  callMethod(compiler, 1, "iteratorValue(_)", 16);

  // Bind the loop variable in its own scope so that each iteration gets a
  // fresh variable and closures over it don't all see the same one.
  pushScope(compiler);
  addLocal(compiler, variableName, variableLength);

  loopBody(compiler);

  // Close the loop variable's scope.
  popScope(compiler);

  endLoop(compiler);

  // Close the scope of the hidden variables.
  popScope(compiler);
}
|
|
|
|
// Compiles an `if` statement with an optional `else` branch. The `if` keyword
// has already been consumed.
static void ifStatement(Compiler* compiler)
{
  // Condition.
  consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after 'if'.");
  expression(compiler);
  consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after if condition.");

  // When the condition is false, skip the then branch.
  int skipThen = emitJump(compiler, CODE_JUMP_IF);

  // Then branch.
  statement(compiler);

  if (!match(compiler, TOKEN_ELSE))
  {
    // No else branch: a false condition lands right after the then branch.
    patchJump(compiler, skipThen);
    return;
  }

  // When the then branch was taken, jump over the else branch.
  int skipElse = emitJump(compiler, CODE_JUMP);
  patchJump(compiler, skipThen);

  // Else branch.
  statement(compiler);

  // Land here after the then branch.
  patchJump(compiler, skipElse);
}
|
|
|
|
// Compiles a `while` statement. The `while` keyword has already been consumed.
static void whileStatement(Compiler* compiler)
{
  Loop whileLoop;
  startLoop(compiler, &whileLoop);

  // Condition, evaluated at the top of every iteration.
  consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after 'while'.");
  expression(compiler);
  consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after while condition.");
  testExitLoop(compiler);

  loopBody(compiler);

  endLoop(compiler);
}
|
|
|
|
// Compiles a simple statement. These can only appear at the top-level or
|
|
// within curly blocks. Simple statements exclude variable binding statements
|
|
// like "var" and "class" which are not allowed directly in places like the
|
|
// branches of an "if" statement.
|
|
//
|
|
// Unlike expressions, statements do not leave a value on the stack.
|
|
void statement(Compiler* compiler)
|
|
{
|
|
if (match(compiler, TOKEN_BREAK))
|
|
{
|
|
if (compiler->loop == NULL)
|
|
{
|
|
error(compiler, "Cannot use 'break' outside of a loop.");
|
|
return;
|
|
}
|
|
|
|
// Since we will be jumping out of the scope, make sure any locals in it
|
|
// are discarded first.
|
|
discardLocals(compiler, compiler->loop->scopeDepth + 1);
|
|
|
|
// Emit a placeholder instruction for the jump to the end of the body. When
|
|
// we're done compiling the loop body and know where the end is, we'll
|
|
// replace these with `CODE_JUMP` instructions with appropriate offsets.
|
|
// We use `CODE_END` here because that can't occur in the middle of
|
|
// bytecode.
|
|
emitJump(compiler, CODE_END);
|
|
}
|
|
else if (match(compiler, TOKEN_CONTINUE))
|
|
{
|
|
if (compiler->loop == NULL)
|
|
{
|
|
error(compiler, "Cannot use 'continue' outside of a loop.");
|
|
return;
|
|
}
|
|
|
|
// Since we will be jumping out of the scope, make sure any locals in it
|
|
// are discarded first.
|
|
discardLocals(compiler, compiler->loop->scopeDepth + 1);
|
|
|
|
// emit a jump back to the top of the loop
|
|
int loopOffset = compiler->fn->code.count - compiler->loop->start + 2;
|
|
emitShortArg(compiler, CODE_LOOP, loopOffset);
|
|
}
|
|
else if (match(compiler, TOKEN_FOR))
|
|
{
|
|
forStatement(compiler);
|
|
}
|
|
else if (match(compiler, TOKEN_IF))
|
|
{
|
|
ifStatement(compiler);
|
|
}
|
|
else if (match(compiler, TOKEN_RETURN))
|
|
{
|
|
// Compile the return value.
|
|
if (peek(compiler) == TOKEN_LINE)
|
|
{
|
|
// If there's no expression after return, initializers should
|
|
// return 'this' and regular methods should return null
|
|
Code result = compiler->isInitializer ? CODE_LOAD_LOCAL_0 : CODE_NULL;
|
|
emitOp(compiler, result);
|
|
}
|
|
else
|
|
{
|
|
if (compiler->isInitializer)
|
|
{
|
|
error(compiler, "A constructor cannot return a value.");
|
|
}
|
|
|
|
expression(compiler);
|
|
}
|
|
|
|
emitOp(compiler, CODE_RETURN);
|
|
}
|
|
else if (match(compiler, TOKEN_WHILE))
|
|
{
|
|
whileStatement(compiler);
|
|
}
|
|
else if (match(compiler, TOKEN_LEFT_BRACE))
|
|
{
|
|
// Block statement.
|
|
pushScope(compiler);
|
|
if (finishBlock(compiler))
|
|
{
|
|
// Block was an expression, so discard it.
|
|
emitOp(compiler, CODE_POP);
|
|
}
|
|
popScope(compiler);
|
|
}
|
|
else
|
|
{
|
|
// Expression statement.
|
|
expression(compiler);
|
|
emitOp(compiler, CODE_POP);
|
|
}
|
|
}
|
|
|
|
// Creates a matching constructor method for an initializer with [signature]
|
|
// and [initializerSymbol].
|
|
//
|
|
// Construction is a two-stage process in Wren that involves two separate
|
|
// methods. There is a static method that allocates a new instance of the class.
|
|
// It then invokes an initializer method on the new instance, forwarding all of
|
|
// the constructor arguments to it.
|
|
//
|
|
// The allocator method always has a fixed implementation:
|
|
//
|
|
// CODE_CONSTRUCT - Replace the class in slot 0 with a new instance of it.
|
|
// CODE_CALL - Invoke the initializer on the new instance.
|
|
//
|
|
// This creates that method and calls the initializer with [initializerSymbol].
|
|
static void createConstructor(Compiler* compiler, Signature* signature,
|
|
int initializerSymbol)
|
|
{
|
|
Compiler methodCompiler;
|
|
initCompiler(&methodCompiler, compiler->parser, compiler, true);
|
|
|
|
// Allocate the instance.
|
|
emitOp(&methodCompiler, compiler->enclosingClass->isForeign
|
|
? CODE_FOREIGN_CONSTRUCT : CODE_CONSTRUCT);
|
|
|
|
// Run its initializer.
|
|
emitShortArg(&methodCompiler, (Code)(CODE_CALL_0 + signature->arity),
|
|
initializerSymbol);
|
|
|
|
// Return the instance.
|
|
emitOp(&methodCompiler, CODE_RETURN);
|
|
|
|
endCompiler(&methodCompiler, "", 0);
|
|
}
|
|
|
|
// Loads the enclosing class onto the stack and then binds the function already
|
|
// on the stack as a method on that class.
|
|
static void defineMethod(Compiler* compiler, Variable classVariable,
|
|
bool isStatic, int methodSymbol)
|
|
{
|
|
// Load the class. We have to do this for each method because we can't
|
|
// keep the class on top of the stack. If there are static fields, they
|
|
// will be locals above the initial variable slot for the class on the
|
|
// stack. To skip past those, we just load the class each time right before
|
|
// defining a method.
|
|
loadVariable(compiler, classVariable);
|
|
|
|
// Define the method.
|
|
Code instruction = isStatic ? CODE_METHOD_STATIC : CODE_METHOD_INSTANCE;
|
|
emitShortArg(compiler, instruction, methodSymbol);
|
|
}
|
|
|
|
// Declares a method in the enclosing class with [signature].
|
|
//
|
|
// Reports an error if a method with that signature is already declared.
|
|
// Returns the symbol for the method.
|
|
static int declareMethod(Compiler* compiler, Signature* signature,
|
|
const char* name, int length)
|
|
{
|
|
int symbol = signatureSymbol(compiler, signature);
|
|
|
|
// See if the class has already declared method with this signature.
|
|
ClassInfo* classInfo = compiler->enclosingClass;
|
|
IntBuffer* methods = classInfo->inStatic
|
|
? &classInfo->staticMethods : &classInfo->methods;
|
|
for (int i = 0; i < methods->count; i++)
|
|
{
|
|
if (methods->data[i] == symbol)
|
|
{
|
|
const char* staticPrefix = classInfo->inStatic ? "static " : "";
|
|
error(compiler, "Class %s already defines a %smethod '%s'.",
|
|
&compiler->enclosingClass->name->value, staticPrefix, name);
|
|
break;
|
|
}
|
|
}
|
|
|
|
wrenIntBufferWrite(compiler->parser->vm, methods, symbol);
|
|
return symbol;
|
|
}
|
|
|
|
// Consumes a single literal token and returns its value.
//
// Accepts `false`, `true`, a number, a string, or a name token. For the last
// three the value attached to the consumed token is returned. If the next
// token is none of these, reports [message] as a compile error, skips one
// token, and returns null.
static Value consumeLiteral(Compiler* compiler, const char* message)
{
  if (match(compiler, TOKEN_FALSE)) return FALSE_VAL;
  if (match(compiler, TOKEN_TRUE)) return TRUE_VAL;

  if (match(compiler, TOKEN_NUMBER) ||
      match(compiler, TOKEN_STRING) ||
      match(compiler, TOKEN_NAME))
  {
    return compiler->parser->previous.value;
  }

  // error() takes a printf-style format. Route the caller-supplied text
  // through "%s" so a '%' in [message] can never be read as a conversion.
  error(compiler, "%s", message);
  nextToken(compiler->parser);
  return NULL_VAL;
}
|
|
|
|
// Parses one attribute line if the next token is `#`.
//
// Handles both the single form (`#key` or `#key = value`) and the grouped form
// (`#group(key = value, ...)`). A `!` directly after the `#` means the parsed
// entries are also passed to addToAttributeGroup(); without it they are parsed
// and dropped.
//
// Returns `true` if a `#` was consumed (even when the attribute turned out to
// be malformed and an error was reported), or `false` if there was none.
static bool matchAttribute(Compiler* compiler)
{
  if (!match(compiler, TOKEN_HASH)) return false;

  compiler->numAttributes++;
  bool keepAtRuntime = match(compiler, TOKEN_BANG);

  if (!match(compiler, TOKEN_NAME))
  {
    error(compiler, "Expect an attribute definition after #.");
  }
  else
  {
    // The first name is either a lone key or the name of a group.
    Value group = compiler->parser->previous.value;
    TokenType upcoming = peek(compiler);

    if (upcoming == TOKEN_EQ || upcoming == TOKEN_LINE)
    {
      // Single attribute: the name is the key and there is no group.
      Value value = NULL_VAL;
      if (match(compiler, TOKEN_EQ))
      {
        value = consumeLiteral(compiler, "Expect a Bool, Num, String or Identifier literal for an attribute value.");
      }
      if (keepAtRuntime) addToAttributeGroup(compiler, NULL_VAL, group, value);
    }
    else if (!match(compiler, TOKEN_LEFT_PAREN))
    {
      error(compiler, "Expect an equal, newline or grouping after an attribute key.");
    }
    else
    {
      // Grouped attributes.
      ignoreNewlines(compiler);
      if (match(compiler, TOKEN_RIGHT_PAREN))
      {
        error(compiler, "Expected attributes in group, group cannot be empty.");
      }
      else
      {
        // Comma-separated `key` or `key = value` entries.
        while (peek(compiler) != TOKEN_RIGHT_PAREN)
        {
          consume(compiler, TOKEN_NAME, "Expect name for attribute key.");
          Value entryKey = compiler->parser->previous.value;
          Value entryValue = NULL_VAL;
          if (match(compiler, TOKEN_EQ))
          {
            entryValue = consumeLiteral(compiler, "Expect a Bool, Num, String or Identifier literal for an attribute value.");
          }
          if (keepAtRuntime)
          {
            addToAttributeGroup(compiler, group, entryKey, entryValue);
          }

          ignoreNewlines(compiler);
          if (!match(compiler, TOKEN_COMMA)) break;
          ignoreNewlines(compiler);
        }

        ignoreNewlines(compiler);
        consume(compiler, TOKEN_RIGHT_PAREN,
                "Expected ')' after grouped attributes.");
      }
    }
  }

  consumeLine(compiler, "Expect newline after attribute.");
  return true;
}
|
|
|
|
// Compiles a method definition inside a class body.
|
|
//
|
|
// Returns `true` if it compiled successfully, or `false` if the method couldn't
|
|
// be parsed.
|
|
static bool method(Compiler* compiler, Variable classVariable)
|
|
{
|
|
// Parse any attributes before the method and store them
|
|
if(matchAttribute(compiler)) {
|
|
return method(compiler, classVariable);
|
|
}
|
|
|
|
// TODO: What about foreign constructors?
|
|
bool isForeign = match(compiler, TOKEN_FOREIGN);
|
|
bool isStatic = match(compiler, TOKEN_STATIC);
|
|
compiler->enclosingClass->inStatic = isStatic;
|
|
|
|
SignatureFn signatureFn = rules[compiler->parser->current.type].method;
|
|
nextToken(compiler->parser);
|
|
|
|
if (signatureFn == NULL)
|
|
{
|
|
error(compiler, "Expect method definition.");
|
|
return false;
|
|
}
|
|
|
|
// Build the method signature.
|
|
Signature signature = signatureFromToken(compiler, SIG_GETTER);
|
|
compiler->enclosingClass->signature = &signature;
|
|
|
|
Compiler methodCompiler;
|
|
initCompiler(&methodCompiler, compiler->parser, compiler, true);
|
|
|
|
// Compile the method signature.
|
|
signatureFn(&methodCompiler, &signature);
|
|
|
|
methodCompiler.isInitializer = signature.type == SIG_INITIALIZER;
|
|
|
|
if (isStatic && signature.type == SIG_INITIALIZER)
|
|
{
|
|
error(compiler, "A constructor cannot be static.");
|
|
}
|
|
|
|
// Include the full signature in debug messages in stack traces.
|
|
char fullSignature[MAX_METHOD_SIGNATURE];
|
|
int length;
|
|
signatureToString(&signature, fullSignature, &length);
|
|
|
|
// Copy any attributes the compiler collected into the enclosing class
|
|
copyMethodAttributes(compiler, isForeign, isStatic, fullSignature, length);
|
|
|
|
// Check for duplicate methods. Doesn't matter that it's already been
|
|
// defined, error will discard bytecode anyway.
|
|
// Check if the method table already contains this symbol
|
|
int methodSymbol = declareMethod(compiler, &signature, fullSignature, length);
|
|
|
|
if (isForeign)
|
|
{
|
|
// Define a constant for the signature.
|
|
emitConstant(compiler, wrenNewStringLength(compiler->parser->vm,
|
|
fullSignature, length));
|
|
|
|
// We don't need the function we started compiling in the parameter list
|
|
// any more.
|
|
methodCompiler.parser->vm->compiler = methodCompiler.parent;
|
|
}
|
|
else
|
|
{
|
|
consume(compiler, TOKEN_LEFT_BRACE, "Expect '{' to begin method body.");
|
|
finishBody(&methodCompiler);
|
|
endCompiler(&methodCompiler, fullSignature, length);
|
|
}
|
|
|
|
// Define the method. For a constructor, this defines the instance
|
|
// initializer method.
|
|
defineMethod(compiler, classVariable, isStatic, methodSymbol);
|
|
|
|
if (signature.type == SIG_INITIALIZER)
|
|
{
|
|
// Also define a matching constructor method on the metaclass.
|
|
signature.type = SIG_METHOD;
|
|
int constructorSymbol = signatureSymbol(compiler, &signature);
|
|
|
|
createConstructor(compiler, &signature, methodSymbol);
|
|
defineMethod(compiler, classVariable, true, constructorSymbol);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Compiles a class definition. Assumes the "class" token has already been
|
|
// consumed (along with a possibly preceding "foreign" token).
|
|
static void classDefinition(Compiler* compiler, bool isForeign)
|
|
{
|
|
// Create a variable to store the class in.
|
|
Variable classVariable;
|
|
classVariable.scope = compiler->scopeDepth == -1 ? SCOPE_MODULE : SCOPE_LOCAL;
|
|
classVariable.index = declareNamedVariable(compiler);
|
|
|
|
// Create shared class name value
|
|
Value classNameString = wrenNewStringLength(compiler->parser->vm,
|
|
compiler->parser->previous.start, compiler->parser->previous.length);
|
|
|
|
// Create class name string to track method duplicates
|
|
ObjString* className = AS_STRING(classNameString);
|
|
|
|
// Make a string constant for the name.
|
|
emitConstant(compiler, classNameString);
|
|
|
|
// Load the superclass (if there is one).
|
|
if (match(compiler, TOKEN_IS))
|
|
{
|
|
parsePrecedence(compiler, PREC_CALL);
|
|
}
|
|
else
|
|
{
|
|
// Implicitly inherit from Object.
|
|
loadCoreVariable(compiler, "Object");
|
|
}
|
|
|
|
// Store a placeholder for the number of fields argument. We don't know the
|
|
// count until we've compiled all the methods to see which fields are used.
|
|
int numFieldsInstruction = -1;
|
|
if (isForeign)
|
|
{
|
|
emitOp(compiler, CODE_FOREIGN_CLASS);
|
|
}
|
|
else
|
|
{
|
|
numFieldsInstruction = emitByteArg(compiler, CODE_CLASS, 255);
|
|
}
|
|
|
|
// Store it in its name.
|
|
defineVariable(compiler, classVariable.index);
|
|
|
|
// Push a local variable scope. Static fields in a class body are hoisted out
|
|
// into local variables declared in this scope. Methods that use them will
|
|
// have upvalues referencing them.
|
|
pushScope(compiler);
|
|
|
|
ClassInfo classInfo;
|
|
classInfo.isForeign = isForeign;
|
|
classInfo.name = className;
|
|
|
|
// Allocate attribute maps if necessary.
|
|
// A method will allocate the methods one if needed
|
|
classInfo.classAttributes = compiler->attributes->count > 0
|
|
? wrenNewMap(compiler->parser->vm)
|
|
: NULL;
|
|
classInfo.methodAttributes = NULL;
|
|
// Copy any existing attributes into the class
|
|
copyAttributes(compiler, classInfo.classAttributes);
|
|
|
|
// Set up a symbol table for the class's fields. We'll initially compile
|
|
// them to slots starting at zero. When the method is bound to the class, the
|
|
// bytecode will be adjusted by [wrenBindMethod] to take inherited fields
|
|
// into account.
|
|
wrenSymbolTableInit(&classInfo.fields);
|
|
|
|
// Set up symbol buffers to track duplicate static and instance methods.
|
|
wrenIntBufferInit(&classInfo.methods);
|
|
wrenIntBufferInit(&classInfo.staticMethods);
|
|
compiler->enclosingClass = &classInfo;
|
|
|
|
// Compile the method definitions.
|
|
consume(compiler, TOKEN_LEFT_BRACE, "Expect '{' after class declaration.");
|
|
matchLine(compiler);
|
|
|
|
while (!match(compiler, TOKEN_RIGHT_BRACE))
|
|
{
|
|
if (!method(compiler, classVariable)) break;
|
|
|
|
// Don't require a newline after the last definition.
|
|
if (match(compiler, TOKEN_RIGHT_BRACE)) break;
|
|
|
|
consumeLine(compiler, "Expect newline after definition in class.");
|
|
}
|
|
|
|
// If any attributes are present,
|
|
// instantiate a ClassAttributes instance for the class
|
|
// and send it over to CODE_END_CLASS
|
|
bool hasAttr = classInfo.classAttributes != NULL ||
|
|
classInfo.methodAttributes != NULL;
|
|
if(hasAttr) {
|
|
emitClassAttributes(compiler, &classInfo);
|
|
loadVariable(compiler, classVariable);
|
|
// At the moment, we don't have other uses for CODE_END_CLASS,
|
|
// so we put it inside this condition. Later, we can always
|
|
// emit it and use it as needed.
|
|
emitOp(compiler, CODE_END_CLASS);
|
|
}
|
|
|
|
// Update the class with the number of fields.
|
|
if (!isForeign)
|
|
{
|
|
compiler->fn->code.data[numFieldsInstruction] =
|
|
(uint8_t)classInfo.fields.count;
|
|
}
|
|
|
|
// Clear symbol tables for tracking field and method names.
|
|
wrenSymbolTableClear(compiler->parser->vm, &classInfo.fields);
|
|
wrenIntBufferClear(compiler->parser->vm, &classInfo.methods);
|
|
wrenIntBufferClear(compiler->parser->vm, &classInfo.staticMethods);
|
|
compiler->enclosingClass = NULL;
|
|
popScope(compiler);
|
|
}
|
|
|
|
// Compiles an "import" statement.
|
|
//
|
|
// An import compiles to a series of instructions. Given:
|
|
//
|
|
// import "foo" for Bar, Baz
|
|
//
|
|
// We compile a single IMPORT_MODULE "foo" instruction to load the module
|
|
// itself. When that finishes executing the imported module, it leaves the
|
|
// ObjModule in vm->lastModule. Then, for Bar and Baz, we:
|
|
//
|
|
// * Declare a variable in the current scope with that name.
|
|
// * Emit an IMPORT_VARIABLE instruction to load the variable's value from the
|
|
// other module.
|
|
// * Compile the code to store that value in the variable in this scope.
|
|
static void import(Compiler* compiler)
|
|
{
|
|
ignoreNewlines(compiler);
|
|
consume(compiler, TOKEN_STRING, "Expect a string after 'import'.");
|
|
int moduleConstant = addConstant(compiler, compiler->parser->previous.value);
|
|
|
|
// Load the module.
|
|
emitShortArg(compiler, CODE_IMPORT_MODULE, moduleConstant);
|
|
|
|
// Discard the unused result value from calling the module body's closure.
|
|
emitOp(compiler, CODE_POP);
|
|
|
|
// The for clause is optional.
|
|
if (!match(compiler, TOKEN_FOR)) return;
|
|
|
|
// Compile the comma-separated list of variables to import.
|
|
do
|
|
{
|
|
ignoreNewlines(compiler);
|
|
|
|
consume(compiler, TOKEN_NAME, "Expect variable name.");
|
|
|
|
// We need to hold onto the source variable,
|
|
// in order to reference it in the import later
|
|
Token sourceVariableToken = compiler->parser->previous;
|
|
|
|
// Define a string constant for the original variable name.
|
|
int sourceVariableConstant = addConstant(compiler,
|
|
wrenNewStringLength(compiler->parser->vm,
|
|
sourceVariableToken.start,
|
|
sourceVariableToken.length));
|
|
|
|
// Store the symbol we care about for the variable
|
|
int slot = -1;
|
|
if(match(compiler, TOKEN_AS))
|
|
{
|
|
//import "module" for Source as Dest
|
|
//Use 'Dest' as the name by declaring a new variable for it.
|
|
//This parses a name after the 'as' and defines it.
|
|
slot = declareNamedVariable(compiler);
|
|
}
|
|
else
|
|
{
|
|
//import "module" for Source
|
|
//Uses 'Source' as the name directly
|
|
slot = declareVariable(compiler, &sourceVariableToken);
|
|
}
|
|
|
|
// Load the variable from the other module.
|
|
emitShortArg(compiler, CODE_IMPORT_VARIABLE, sourceVariableConstant);
|
|
|
|
// Store the result in the variable here.
|
|
defineVariable(compiler, slot);
|
|
} while (match(compiler, TOKEN_COMMA));
|
|
}
|
|
|
|
// Compiles a "var" variable definition statement.
|
|
static void variableDefinition(Compiler* compiler)
|
|
{
|
|
// Grab its name, but don't declare it yet. A (local) variable shouldn't be
|
|
// in scope in its own initializer.
|
|
consume(compiler, TOKEN_NAME, "Expect variable name.");
|
|
Token nameToken = compiler->parser->previous;
|
|
|
|
// Compile the initializer.
|
|
if (match(compiler, TOKEN_EQ))
|
|
{
|
|
ignoreNewlines(compiler);
|
|
expression(compiler);
|
|
}
|
|
else
|
|
{
|
|
// Default initialize it to null.
|
|
null(compiler, false);
|
|
}
|
|
|
|
// Now put it in scope.
|
|
int symbol = declareVariable(compiler, &nameToken);
|
|
defineVariable(compiler, symbol);
|
|
}
|
|
|
|
// Compiles a "definition". These are the statements that bind new variables.
|
|
// They can only appear at the top level of a block and are prohibited in places
|
|
// like the non-curly body of an if or while.
|
|
void definition(Compiler* compiler)
|
|
{
|
|
if(matchAttribute(compiler)) {
|
|
definition(compiler);
|
|
return;
|
|
}
|
|
|
|
if (match(compiler, TOKEN_CLASS))
|
|
{
|
|
classDefinition(compiler, false);
|
|
return;
|
|
}
|
|
else if (match(compiler, TOKEN_FOREIGN))
|
|
{
|
|
consume(compiler, TOKEN_CLASS, "Expect 'class' after 'foreign'.");
|
|
classDefinition(compiler, true);
|
|
return;
|
|
}
|
|
|
|
disallowAttributes(compiler);
|
|
|
|
if (match(compiler, TOKEN_IMPORT))
|
|
{
|
|
import(compiler);
|
|
}
|
|
else if (match(compiler, TOKEN_VAR))
|
|
{
|
|
variableDefinition(compiler);
|
|
}
|
|
else
|
|
{
|
|
statement(compiler);
|
|
}
|
|
}
|
|
|
|
// Compiles [source] in the context of [module] and returns the result of
// ending the top-level compiler.
//
// When [isExpression] is set, the source is parsed as a single expression
// followed by end of file; otherwise it is parsed as a sequence of
// definitions and CODE_END_MODULE is emitted after them. [printErrors] is
// stored on the parser.
ObjFn* wrenCompile(WrenVM* vm, ObjModule* module, const char* source,
                   bool isExpression, bool printErrors)
{
  // Skip the UTF-8 BOM if there is one.
  if (strncmp(source, "\xEF\xBB\xBF", 3) == 0)
  {
    source += 3;
  }

  Parser parser;
  parser.vm = vm;
  parser.module = module;
  parser.source = source;

  parser.tokenStart = source;
  parser.currentChar = source;
  parser.currentLine = 1;
  parser.numParens = 0;

  // Seed the lookahead slot with an empty placeholder token. It is shifted
  // along as nextToken() is called below.
  parser.next.type = TOKEN_ERROR;
  parser.next.start = source;
  parser.next.length = 0;
  parser.next.line = 0;
  parser.next.value = UNDEFINED_VAL;

  parser.printErrors = printErrors;
  parser.hasError = false;

  // Prime the parser: the first call reads the first token into `next`, the
  // second copies `next` into `current`.
  nextToken(&parser);
  nextToken(&parser);

  // Remember how many module variables existed before this compile so that
  // only newly added ones are checked at the end.
  int numExistingVariables = module->variables.count;

  Compiler compiler;
  initCompiler(&compiler, &parser, NULL, false);
  ignoreNewlines(&compiler);

  if (!isExpression)
  {
    while (!match(&compiler, TOKEN_EOF))
    {
      definition(&compiler);

      // If there is no newline, it must be the end of file on the same line.
      if (!matchLine(&compiler))
      {
        consume(&compiler, TOKEN_EOF, "Expect end of file.");
        break;
      }
    }

    emitOp(&compiler, CODE_END_MODULE);
  }
  else
  {
    expression(&compiler);
    consume(&compiler, TOKEN_EOF, "Expect end of expression.");
  }

  emitOp(&compiler, CODE_RETURN);

  // See if there are any implicitly declared module-level variables that
  // never got an explicit definition. They will have values that are numbers
  // indicating the line where the variable was first used.
  for (int i = numExistingVariables; i < parser.module->variables.count; i++)
  {
    if (!IS_NUM(parser.module->variables.data[i])) continue;

    // Synthesize a token for the original use site so the error is reported
    // against it.
    parser.previous.type = TOKEN_NAME;
    parser.previous.start = parser.module->variableNames.data[i]->value;
    parser.previous.length = parser.module->variableNames.data[i]->length;
    parser.previous.line = (int)AS_NUM(parser.module->variables.data[i]);
    error(&compiler, "Variable is used but not defined.");
  }

  return endCompiler(&compiler, "(script)", 8);
}
|
|
|
|
// Walks the bytecode of [fn] and adjusts it for the class [classObj] it is
// being bound to:
//
// * Field load/store operands are offset by the superclass's field count.
// * Super calls get the superclass written into their constant slot.
// * Functions referenced by CODE_CLOSURE are bound recursively.
//
// The walk stops at CODE_END.
void wrenBindMethodCode(ObjClass* classObj, ObjFn* fn)
{
  int offset = 0;
  while (true)
  {
    Code instruction = (Code)fn->code.data[offset];
    switch (instruction)
    {
      case CODE_LOAD_FIELD:
      case CODE_STORE_FIELD:
      case CODE_LOAD_FIELD_THIS:
      case CODE_STORE_FIELD_THIS:
        // Shift this class's fields down past the inherited ones. We don't
        // check for overflow here because we'll see if the number of fields
        // overflows when the subclass is created.
        fn->code.data[offset + 1] += classObj->superclass->numFields;
        break;

      case CODE_SUPER_0:
      case CODE_SUPER_1:
      case CODE_SUPER_2:
      case CODE_SUPER_3:
      case CODE_SUPER_4:
      case CODE_SUPER_5:
      case CODE_SUPER_6:
      case CODE_SUPER_7:
      case CODE_SUPER_8:
      case CODE_SUPER_9:
      case CODE_SUPER_10:
      case CODE_SUPER_11:
      case CODE_SUPER_12:
      case CODE_SUPER_13:
      case CODE_SUPER_14:
      case CODE_SUPER_15:
      case CODE_SUPER_16:
      {
        // The constant index is the big-endian 16-bit value in operand
        // bytes 3 and 4; fill that slot with a reference to the superclass.
        int superclassSlot =
            (fn->code.data[offset + 3] << 8) | fn->code.data[offset + 4];
        fn->constants.data[superclassSlot] = OBJ_VAL(classObj->superclass);
        break;
      }

      case CODE_CLOSURE:
      {
        // The nested function is the constant named by the first 16-bit
        // operand; bind it too.
        int fnSlot =
            (fn->code.data[offset + 1] << 8) | fn->code.data[offset + 2];
        wrenBindMethodCode(classObj, AS_FN(fn->constants.data[fnSlot]));
        break;
      }

      case CODE_END:
        return;

      default:
        // Other instructions are unaffected, so just skip over them.
        break;
    }

    // Advance past the opcode byte and its operands.
    offset += 1 + getByteCountForArguments(fn->code.data, fn->constants.data,
                                           offset);
  }
}
|
|
|
|
// Marks the objects referenced by [compiler], its parser's tokens and all of
// its parent compilers so the garbage collector treats them as reachable.
void wrenMarkCompiler(WrenVM* vm, Compiler* compiler)
{
  // Token values held by the shared parser.
  Parser* parser = compiler->parser;
  wrenGrayValue(vm, parser->current.value);
  wrenGrayValue(vm, parser->previous.value);
  wrenGrayValue(vm, parser->next.value);

  // Walk up the parent chain to mark the outer compilers too. The VM only
  // tracks the innermost one.
  for (Compiler* current = compiler; current != NULL; current = current->parent)
  {
    wrenGrayObj(vm, (Obj*)current->fn);
    wrenGrayObj(vm, (Obj*)current->constants);
    wrenGrayObj(vm, (Obj*)current->attributes);

    ClassInfo* classInfo = current->enclosingClass;
    if (classInfo == NULL) continue;

    wrenBlackenSymbolTable(vm, &classInfo->fields);

    // The attribute maps are optional, so only mark the ones that exist.
    if (classInfo->methodAttributes != NULL)
    {
      wrenGrayObj(vm, (Obj*)classInfo->methodAttributes);
    }
    if (classInfo->classAttributes != NULL)
    {
      wrenGrayObj(vm, (Obj*)classInfo->classAttributes);
    }
  }
}
|
|
|
|
// Helpers for Attributes
|
|
|
|
// Throw an error if any attributes were found preceding,
|
|
// and clear the attributes so the error doesn't keep happening.
|
|
static void disallowAttributes(Compiler* compiler)
|
|
{
|
|
if (compiler->numAttributes > 0)
|
|
{
|
|
error(compiler, "Attributes can only specified before a class or a method");
|
|
wrenMapClear(compiler->parser->vm, compiler->attributes);
|
|
compiler->numAttributes = 0;
|
|
}
|
|
}
|
|
|
|
// Add an attribute to a given group in the compiler attribues map
|
|
static void addToAttributeGroup(Compiler* compiler,
|
|
Value group, Value key, Value value)
|
|
{
|
|
WrenVM* vm = compiler->parser->vm;
|
|
|
|
if(IS_OBJ(group)) wrenPushRoot(vm, AS_OBJ(group));
|
|
if(IS_OBJ(key)) wrenPushRoot(vm, AS_OBJ(key));
|
|
if(IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value));
|
|
|
|
Value groupMapValue = wrenMapGet(compiler->attributes, group);
|
|
if(IS_UNDEFINED(groupMapValue))
|
|
{
|
|
groupMapValue = OBJ_VAL(wrenNewMap(vm));
|
|
wrenMapSet(vm, compiler->attributes, group, groupMapValue);
|
|
}
|
|
|
|
//we store them as a map per so we can maintain duplicate keys
|
|
//group = { key:[value, ...], }
|
|
ObjMap* groupMap = AS_MAP(groupMapValue);
|
|
|
|
//var keyItems = group[key]
|
|
//if(!keyItems) keyItems = group[key] = []
|
|
Value keyItemsValue = wrenMapGet(groupMap, key);
|
|
if(IS_UNDEFINED(keyItemsValue))
|
|
{
|
|
keyItemsValue = OBJ_VAL(wrenNewList(vm, 0));
|
|
wrenMapSet(vm, groupMap, key, keyItemsValue);
|
|
}
|
|
|
|
//keyItems.add(value)
|
|
ObjList* keyItems = AS_LIST(keyItemsValue);
|
|
wrenValueBufferWrite(vm, &keyItems->elements, value);
|
|
|
|
if(IS_OBJ(group)) wrenPopRoot(vm);
|
|
if(IS_OBJ(key)) wrenPopRoot(vm);
|
|
if(IS_OBJ(value)) wrenPopRoot(vm);
|
|
}
|
|
|
|
|
|
// Emit the attributes in the give map onto the stack
|
|
static void emitAttributes(Compiler* compiler, ObjMap* attributes)
|
|
{
|
|
// Instantiate a new map for the attributes
|
|
loadCoreVariable(compiler, "Map");
|
|
callMethod(compiler, 0, "new()", 5);
|
|
|
|
// The attributes are stored as group = { key:[value, value, ...] }
|
|
// so our first level is the group map
|
|
for(uint32_t groupIdx = 0; groupIdx < attributes->capacity; groupIdx++)
|
|
{
|
|
const MapEntry* groupEntry = &attributes->entries[groupIdx];
|
|
if(IS_UNDEFINED(groupEntry->key)) continue;
|
|
//group key
|
|
emitConstant(compiler, groupEntry->key);
|
|
|
|
//group value is gonna be a map
|
|
loadCoreVariable(compiler, "Map");
|
|
callMethod(compiler, 0, "new()", 5);
|
|
|
|
ObjMap* groupItems = AS_MAP(groupEntry->value);
|
|
for(uint32_t itemIdx = 0; itemIdx < groupItems->capacity; itemIdx++)
|
|
{
|
|
const MapEntry* itemEntry = &groupItems->entries[itemIdx];
|
|
if(IS_UNDEFINED(itemEntry->key)) continue;
|
|
|
|
emitConstant(compiler, itemEntry->key);
|
|
// Attribute key value, key = []
|
|
loadCoreVariable(compiler, "List");
|
|
callMethod(compiler, 0, "new()", 5);
|
|
// Add the items to the key list
|
|
ObjList* items = AS_LIST(itemEntry->value);
|
|
for(int itemIdx = 0; itemIdx < items->elements.count; ++itemIdx)
|
|
{
|
|
emitConstant(compiler, items->elements.data[itemIdx]);
|
|
callMethod(compiler, 1, "addCore_(_)", 11);
|
|
}
|
|
// Add the list to the map
|
|
callMethod(compiler, 2, "addCore_(_,_)", 13);
|
|
}
|
|
|
|
// Add the key/value to the map
|
|
callMethod(compiler, 2, "addCore_(_,_)", 13);
|
|
}
|
|
|
|
}
|
|
|
|
// Methods are stored as method <-> attributes, so we have to have
|
|
// an indirection to resolve for methods
|
|
static void emitAttributeMethods(Compiler* compiler, ObjMap* attributes)
|
|
{
|
|
// Instantiate a new map for the attributes
|
|
loadCoreVariable(compiler, "Map");
|
|
callMethod(compiler, 0, "new()", 5);
|
|
|
|
for(uint32_t methodIdx = 0; methodIdx < attributes->capacity; methodIdx++)
|
|
{
|
|
const MapEntry* methodEntry = &attributes->entries[methodIdx];
|
|
if(IS_UNDEFINED(methodEntry->key)) continue;
|
|
emitConstant(compiler, methodEntry->key);
|
|
ObjMap* attributeMap = AS_MAP(methodEntry->value);
|
|
emitAttributes(compiler, attributeMap);
|
|
callMethod(compiler, 2, "addCore_(_,_)", 13);
|
|
}
|
|
}
|
|
|
|
|
|
// Emit the final ClassAttributes that exists at runtime
|
|
static void emitClassAttributes(Compiler* compiler, ClassInfo* classInfo)
|
|
{
|
|
loadCoreVariable(compiler, "ClassAttributes");
|
|
|
|
classInfo->classAttributes
|
|
? emitAttributes(compiler, classInfo->classAttributes)
|
|
: null(compiler, false);
|
|
|
|
classInfo->methodAttributes
|
|
? emitAttributeMethods(compiler, classInfo->methodAttributes)
|
|
: null(compiler, false);
|
|
|
|
callMethod(compiler, 2, "new(_,_)", 8);
|
|
}
|
|
|
|
// Copy the current attributes stored in the compiler into a destination map
|
|
// This also resets the counter, since the intent is to consume the attributes
|
|
static void copyAttributes(Compiler* compiler, ObjMap* into)
|
|
{
|
|
compiler->numAttributes = 0;
|
|
|
|
if(compiler->attributes->count == 0) return;
|
|
if(into == NULL) return;
|
|
|
|
WrenVM* vm = compiler->parser->vm;
|
|
|
|
// Note we copy the actual values as is since we'll take ownership
|
|
// and clear the original map
|
|
for(uint32_t attrIdx = 0; attrIdx < compiler->attributes->capacity; attrIdx++)
|
|
{
|
|
const MapEntry* attrEntry = &compiler->attributes->entries[attrIdx];
|
|
if(IS_UNDEFINED(attrEntry->key)) continue;
|
|
wrenMapSet(vm, into, attrEntry->key, attrEntry->value);
|
|
}
|
|
|
|
wrenMapClear(vm, compiler->attributes);
|
|
}
|
|
|
|
// Copy the current attributes stored in the compiler into the method specific
|
|
// attributes for the current enclosingClass.
|
|
// This also resets the counter, since the intent is to consume the attributes
|
|
static void copyMethodAttributes(Compiler* compiler, bool isForeign,
|
|
bool isStatic, const char* fullSignature, int32_t length)
|
|
{
|
|
compiler->numAttributes = 0;
|
|
|
|
if(compiler->attributes->count == 0) return;
|
|
|
|
WrenVM* vm = compiler->parser->vm;
|
|
|
|
// Make a map for this method to copy into
|
|
ObjMap* methodAttr = wrenNewMap(vm);
|
|
wrenPushRoot(vm, (Obj*)methodAttr);
|
|
copyAttributes(compiler, methodAttr);
|
|
|
|
// Include 'foreign static ' in front as needed
|
|
int32_t fullLength = length;
|
|
if(isForeign) fullLength += 8;
|
|
if(isStatic) fullLength += 7;
|
|
char fullSignatureWithPrefix[MAX_METHOD_SIGNATURE + 8 + 7];
|
|
const char* foreignPrefix = isForeign ? "foreign " : "";
|
|
const char* staticPrefix = isStatic ? "static " : "";
|
|
sprintf(fullSignatureWithPrefix, "%s%s%.*s", foreignPrefix, staticPrefix,
|
|
length, fullSignature);
|
|
fullSignatureWithPrefix[fullLength] = '\0';
|
|
|
|
if(compiler->enclosingClass->methodAttributes == NULL) {
|
|
compiler->enclosingClass->methodAttributes = wrenNewMap(vm);
|
|
}
|
|
|
|
// Store the method attributes in the class map
|
|
Value key = wrenNewStringLength(vm, fullSignatureWithPrefix, fullLength);
|
|
wrenMapSet(vm, compiler->enclosingClass->methodAttributes, key, OBJ_VAL(methodAttr));
|
|
|
|
wrenPopRoot(vm);
|
|
}
|
|
// End file "wren_compiler.c"
|
|
// Begin file "wren_primitive.c"
|
|
// Begin file "wren_primitive.h"
|
|
#ifndef wren_primitive_h
|
|
#define wren_primitive_h
|
|
|
|
|
|
// Binds the C function `prim_[function]` as a primitive method called [name]
// (the Wren-side signature string) on the `ObjClass` [cls].
//
// Expands to a single statement and expects a `WrenVM* vm` to be in scope.
#define PRIMITIVE(cls, name, function)                                         \
  do                                                                           \
  {                                                                            \
    Method method;                                                             \
    int symbol = wrenSymbolTableEnsure(vm, &vm->methodNames,                   \
                                       name, strlen(name));                    \
    method.type = METHOD_PRIMITIVE;                                            \
    method.as.primitive = prim_##function;                                     \
    wrenBindMethod(vm, cls, symbol, method);                                   \
  } while (false)
|
|
|
|
// Same as PRIMITIVE, except that the bound method is tagged
// METHOD_FUNCTION_CALL: binds the C function `prim_[function]` under the
// Wren-side name [name] on the `ObjClass` [cls], but as a FN call.
//
// Expands to a single statement and expects a `WrenVM* vm` to be in scope.
#define FUNCTION_CALL(cls, name, function)                                     \
  do                                                                           \
  {                                                                            \
    Method method;                                                             \
    int symbol = wrenSymbolTableEnsure(vm, &vm->methodNames,                   \
                                       name, strlen(name));                    \
    method.type = METHOD_FUNCTION_CALL;                                        \
    method.as.primitive = prim_##function;                                     \
    wrenBindMethod(vm, cls, symbol, method);                                   \
  } while (false)
|
|
|
|
// Opens the definition of the primitive whose C function is `prim_[name]`.
// Hiding the actual type signature behind this macro makes it clear which C
// functions are invoked as primitives. The body receives `vm` and `args`.
#define DEF_PRIMITIVE(name)                                                    \
  static bool prim_##name(WrenVM* vm, Value* args)
|
|
|
|
// Result helpers for use inside a primitive (they rely on the `vm` and `args`
// parameters declared by DEF_PRIMITIVE).

// Stores [value] in args[0] and returns `true` from the primitive.
#define RETURN_VAL(value)                                                      \
  do                                                                           \
  {                                                                            \
    args[0] = value;                                                           \
    return true;                                                               \
  } while (false)

// Typed shorthands built on RETURN_VAL.
#define RETURN_OBJ(obj)     RETURN_VAL(OBJ_VAL(obj))
#define RETURN_BOOL(value)  RETURN_VAL(BOOL_VAL(value))
#define RETURN_FALSE        RETURN_VAL(FALSE_VAL)
#define RETURN_NULL         RETURN_VAL(NULL_VAL)
#define RETURN_NUM(value)   RETURN_VAL(NUM_VAL(value))
#define RETURN_TRUE         RETURN_VAL(TRUE_VAL)

// Sets the fiber's error to [msg] and returns `false`. The length is taken
// with sizeof, so [msg] has to be a string literal (or char array).
#define RETURN_ERROR(msg)                                                      \
  do                                                                           \
  {                                                                            \
    vm->fiber->error = wrenNewStringLength(vm, msg, sizeof(msg) - 1);          \
    return false;                                                              \
  } while (false)

// Sets the fiber's error to a string built by wrenStringFormat() from the
// given arguments and returns `false`.
#define RETURN_ERROR_FMT(...)                                                  \
  do                                                                           \
  {                                                                            \
    vm->fiber->error = wrenStringFormat(vm, __VA_ARGS__);                      \
    return false;                                                              \
  } while (false)
|
|
|
|
// Validates that the given [arg] is a function. Returns true if it is. If not,
|
|
// reports an error and returns false.
|
|
bool validateFn(WrenVM* vm, Value arg, const char* argName);
|
|
|
|
// Validates that the given [arg] is a Num. Returns true if it is. If not,
|
|
// reports an error and returns false.
|
|
bool validateNum(WrenVM* vm, Value arg, const char* argName);
|
|
|
|
// Validates that [value] is an integer. Returns true if it is. If not, reports
|
|
// an error and returns false.
|
|
bool validateIntValue(WrenVM* vm, double value, const char* argName);
|
|
|
|
// Validates that the given [arg] is an integer. Returns true if it is. If not,
|
|
// reports an error and returns false.
|
|
bool validateInt(WrenVM* vm, Value arg, const char* argName);
|
|
|
|
// Validates that [arg] is a valid object for use as a map key. Returns true if
|
|
// it is. If not, reports an error and returns false.
|
|
bool validateKey(WrenVM* vm, Value arg);
|
|
|
|
// Validates that the argument at [argIndex] is an integer within `[0, count)`.
|
|
// Also allows negative indices which map backwards from the end. Returns the
|
|
// valid positive index value. If invalid, reports an error and returns
|
|
// `UINT32_MAX`.
|
|
uint32_t validateIndex(WrenVM* vm, Value arg, uint32_t count,
|
|
const char* argName);
|
|
|
|
// Validates that the given [arg] is a String. Returns true if it is. If not,
|
|
// reports an error and returns false.
|
|
bool validateString(WrenVM* vm, Value arg, const char* argName);
|
|
|
|
// Given a [range] and the [length] of the object being operated on, determines
|
|
// the series of elements that should be chosen from the underlying object.
|
|
// Handles ranges that count backwards from the end as well as negative ranges.
|
|
//
|
|
// Returns the index from which the range should start or `UINT32_MAX` if the
|
|
// range is invalid. After calling, [length] will be updated with the number of
|
|
// elements in the resulting sequence. [step] will be direction that the range
|
|
// is going: `1` if the range is increasing from the start index or `-1` if the
|
|
// range is decreasing.
|
|
uint32_t calculateRange(WrenVM* vm, ObjRange* range, uint32_t* length,
|
|
int* step);
|
|
|
|
#endif
|
|
// End file "wren_primitive.h"
|
|
|
|
#include <math.h>
|
|
|
|
// Validates that [value] is an integer within `[0, count)`. Also allows
|
|
// negative indices which map backwards from the end. Returns the valid positive
|
|
// index value. If invalid, reports an error and returns `UINT32_MAX`.
|
|
static uint32_t validateIndexValue(WrenVM* vm, uint32_t count, double value,
|
|
const char* argName)
|
|
{
|
|
if (!validateIntValue(vm, value, argName)) return UINT32_MAX;
|
|
|
|
// Negative indices count from the end.
|
|
if (value < 0) value = count + value;
|
|
|
|
// Check bounds.
|
|
if (value >= 0 && value < count) return (uint32_t)value;
|
|
|
|
vm->fiber->error = wrenStringFormat(vm, "$ out of bounds.", argName);
|
|
return UINT32_MAX;
|
|
}
|
|
|
|
// Returns true if [arg] is a closure. Otherwise sets the fiber error to
// "[argName] must be a function." and returns false.
bool validateFn(WrenVM* vm, Value arg, const char* argName)
{
  if (!IS_CLOSURE(arg))
  {
    RETURN_ERROR_FMT("$ must be a function.", argName);
  }

  return true;
}
|
|
|
|
// Returns true if [arg] is a Num. Otherwise sets the fiber error to
// "[argName] must be a number." and returns false.
bool validateNum(WrenVM* vm, Value arg, const char* argName)
{
  if (!IS_NUM(arg))
  {
    RETURN_ERROR_FMT("$ must be a number.", argName);
  }

  return true;
}
|
|
|
|
// Returns true if [value] has no fractional part, i.e. truncating it leaves
// it unchanged. Otherwise sets the fiber error to "[argName] must be an
// integer." and returns false.
bool validateIntValue(WrenVM* vm, double value, const char* argName)
{
  if (value != trunc(value))
  {
    RETURN_ERROR_FMT("$ must be an integer.", argName);
  }

  return true;
}
|
|
|
|
// Returns true if [arg] is a Num that holds an integer. Otherwise the failing
// check (number first, then integer) reports its error and false is returned.
bool validateInt(WrenVM* vm, Value arg, const char* argName)
{
  // The integer check reads the numeric payload, so [arg] has to be
  // confirmed as a number first.
  return validateNum(vm, arg, argName) &&
         validateIntValue(vm, AS_NUM(arg), argName);
}
|
|
|
|
// Returns true if [arg] may be used as a map key according to
// wrenMapIsValidKey(). Otherwise sets the fiber error and returns false.
bool validateKey(WrenVM* vm, Value arg)
{
  if (!wrenMapIsValidKey(arg))
  {
    RETURN_ERROR("Key must be a value type.");
  }

  return true;
}
|
|
|
|
// Checks that [arg] is a number that is a valid index into a sequence of
// [count] elements (see validateIndexValue() for the accepted values).
// Returns the resolved index, or `UINT32_MAX` after reporting an error.
uint32_t validateIndex(WrenVM* vm, Value arg, uint32_t count,
                       const char* argName)
{
  if (validateNum(vm, arg, argName))
  {
    return validateIndexValue(vm, count, AS_NUM(arg), argName);
  }

  return UINT32_MAX;
}
|
|
|
|
// Returns true if [arg] is a String. Otherwise sets the fiber error to
// "[argName] must be a string." and returns false.
bool validateString(WrenVM* vm, Value arg, const char* argName)
{
  if (!IS_STRING(arg))
  {
    RETURN_ERROR_FMT("$ must be a string.", argName);
  }

  return true;
}
|
|
|
|
// Resolves [range] against a sequence of [*length] elements.
//
// Returns the index the range starts at, or `UINT32_MAX` (with an error set
// on the fiber) if the range is invalid. On success [*length] is replaced by
// the number of elements the range selects and [*step] is `1` for an
// increasing range or `-1` for a decreasing one. [*step] is left at `0` when
// the range is empty or invalid.
//
// Negative endpoints count backwards from the end of the sequence.
uint32_t calculateRange(WrenVM* vm, ObjRange* range, uint32_t* length,
                        int* step)
{
  *step = 0;

  // Edge case: an empty range is allowed at the end of a sequence. This way,
  // list[0..-1] and list[0...list.count] can be used to copy a list even when
  // empty.
  if (range->from == *length &&
      range->to == (range->isInclusive ? -1.0 : (double)*length))
  {
    *length = 0;
    return 0;
  }

  uint32_t from = validateIndexValue(vm, *length, range->from, "Range start");
  if (from == UINT32_MAX) return UINT32_MAX;

  // Bounds check the end manually to handle exclusive ranges.
  double value = range->to;
  if (!validateIntValue(vm, value, "Range end")) return UINT32_MAX;

  // Negative indices count from the end.
  if (value < 0) value = *length + value;

  // Convert the exclusive range to an inclusive one.
  if (!range->isInclusive)
  {
    // An exclusive range with the same start and end points is empty.
    if (value == from)
    {
      *length = 0;
      return from;
    }

    // Shift the endpoint to make it inclusive, handling both increasing and
    // decreasing ranges.
    value += value >= from ? -1 : 1;
  }

  // Check bounds.
  if (value < 0 || value >= *length)
  {
    vm->fiber->error = CONST_STRING(vm, "Range end out of bounds.");
    return UINT32_MAX;
  }

  uint32_t to = (uint32_t)value;

  // Take the distance in unsigned arithmetic, larger minus smaller. Going
  // through `abs((int)(from - to))` depends on an implementation-defined
  // conversion and gives a wrong result once the distance exceeds INT_MAX.
  uint32_t distance = from < to ? to - from : from - to;
  *length = distance + 1;
  *step = from < to ? 1 : -1;
  return from;
}
|
|
// End file "wren_primitive.c"
|
|
// Begin file "wren_core.c"
|
|
#include <ctype.h>
|
|
#include <errno.h>
|
|
#include <float.h>
|
|
#include <math.h>
|
|
#include <string.h>
|
|
#include <time.h>
|
|
|
|
// Begin file "wren_core.h"
|
|
#ifndef wren_core_h
|
|
#define wren_core_h
|
|
|
|
|
|
// This module defines the built-in classes and their primitive methods that
|
|
// are implemented directly in C code. Some languages try to implement as much
|
|
// of the core module itself in the primary language instead of in the host
|
|
// language.
|
|
//
|
|
// With Wren, we try to do as much of it in C as possible. Primitive methods
|
|
// are always faster than code written in Wren, and it minimizes startup time
|
|
// since we don't have to parse, compile, and execute Wren code.
|
|
//
|
|
// There is one limitation, though. Methods written in C cannot call Wren ones.
|
|
// They can only be the top of the callstack, and immediately return. This
|
|
// makes it difficult to have primitive methods that rely on polymorphic
|
|
// behavior. For example, `System.print` should call `toString` on its argument,
|
|
// including user-defined `toString` methods on user-defined classes.
|
|
|
|
void wrenInitializeCore(WrenVM* vm);
|
|
|
|
#endif
|
|
// End file "wren_core.h"
|
|
|
|
// Begin file "wren_core.wren.inc"
|
|
// Generated automatically from src/vm/wren_core.wren. Do not edit.
|
|
static const char* coreModuleSource =
|
|
"class Bool {}\n"
|
|
"class Fiber {}\n"
|
|
"class Fn {}\n"
|
|
"class Null {}\n"
|
|
"class Num {}\n"
|
|
"\n"
|
|
"class Sequence {\n"
|
|
" all(f) {\n"
|
|
" var result = true\n"
|
|
" for (element in this) {\n"
|
|
" result = f.call(element)\n"
|
|
" if (!result) return result\n"
|
|
" }\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" any(f) {\n"
|
|
" var result = false\n"
|
|
" for (element in this) {\n"
|
|
" result = f.call(element)\n"
|
|
" if (result) return result\n"
|
|
" }\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" contains(element) {\n"
|
|
" for (item in this) {\n"
|
|
" if (element == item) return true\n"
|
|
" }\n"
|
|
" return false\n"
|
|
" }\n"
|
|
"\n"
|
|
" count {\n"
|
|
" var result = 0\n"
|
|
" for (element in this) {\n"
|
|
" result = result + 1\n"
|
|
" }\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" count(f) {\n"
|
|
" var result = 0\n"
|
|
" for (element in this) {\n"
|
|
" if (f.call(element)) result = result + 1\n"
|
|
" }\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" each(f) {\n"
|
|
" for (element in this) {\n"
|
|
" f.call(element)\n"
|
|
" }\n"
|
|
" }\n"
|
|
"\n"
|
|
" isEmpty { iterate(null) ? false : true }\n"
|
|
"\n"
|
|
" map(transformation) { MapSequence.new(this, transformation) }\n"
|
|
"\n"
|
|
" skip(count) {\n"
|
|
" if (!(count is Num) || !count.isInteger || count < 0) {\n"
|
|
" Fiber.abort(\"Count must be a non-negative integer.\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" return SkipSequence.new(this, count)\n"
|
|
" }\n"
|
|
"\n"
|
|
" take(count) {\n"
|
|
" if (!(count is Num) || !count.isInteger || count < 0) {\n"
|
|
" Fiber.abort(\"Count must be a non-negative integer.\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" return TakeSequence.new(this, count)\n"
|
|
" }\n"
|
|
"\n"
|
|
" where(predicate) { WhereSequence.new(this, predicate) }\n"
|
|
"\n"
|
|
" reduce(acc, f) {\n"
|
|
" for (element in this) {\n"
|
|
" acc = f.call(acc, element)\n"
|
|
" }\n"
|
|
" return acc\n"
|
|
" }\n"
|
|
"\n"
|
|
" reduce(f) {\n"
|
|
" var iter = iterate(null)\n"
|
|
" if (!iter) Fiber.abort(\"Can't reduce an empty sequence.\")\n"
|
|
"\n"
|
|
" // Seed with the first element.\n"
|
|
" var result = iteratorValue(iter)\n"
|
|
" while (iter = iterate(iter)) {\n"
|
|
" result = f.call(result, iteratorValue(iter))\n"
|
|
" }\n"
|
|
"\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" join() { join(\"\") }\n"
|
|
"\n"
|
|
" join(sep) {\n"
|
|
" var first = true\n"
|
|
" var result = \"\"\n"
|
|
"\n"
|
|
" for (element in this) {\n"
|
|
" if (!first) result = result + sep\n"
|
|
" first = false\n"
|
|
" result = result + element.toString\n"
|
|
" }\n"
|
|
"\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" toList {\n"
|
|
" var result = List.new()\n"
|
|
" for (element in this) {\n"
|
|
" result.add(element)\n"
|
|
" }\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class MapSequence is Sequence {\n"
|
|
" construct new(sequence, fn) {\n"
|
|
" _sequence = sequence\n"
|
|
" _fn = fn\n"
|
|
" }\n"
|
|
"\n"
|
|
" iterate(iterator) { _sequence.iterate(iterator) }\n"
|
|
" iteratorValue(iterator) { _fn.call(_sequence.iteratorValue(iterator)) }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class SkipSequence is Sequence {\n"
|
|
" construct new(sequence, count) {\n"
|
|
" _sequence = sequence\n"
|
|
" _count = count\n"
|
|
" }\n"
|
|
"\n"
|
|
" iterate(iterator) {\n"
|
|
" if (iterator) {\n"
|
|
" return _sequence.iterate(iterator)\n"
|
|
" } else {\n"
|
|
" iterator = _sequence.iterate(iterator)\n"
|
|
" var count = _count\n"
|
|
" while (count > 0 && iterator) {\n"
|
|
" iterator = _sequence.iterate(iterator)\n"
|
|
" count = count - 1\n"
|
|
" }\n"
|
|
" return iterator\n"
|
|
" }\n"
|
|
" }\n"
|
|
"\n"
|
|
" iteratorValue(iterator) { _sequence.iteratorValue(iterator) }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class TakeSequence is Sequence {\n"
|
|
" construct new(sequence, count) {\n"
|
|
" _sequence = sequence\n"
|
|
" _count = count\n"
|
|
" }\n"
|
|
"\n"
|
|
" iterate(iterator) {\n"
|
|
" if (!iterator) _taken = 1 else _taken = _taken + 1\n"
|
|
" return _taken > _count ? null : _sequence.iterate(iterator)\n"
|
|
" }\n"
|
|
"\n"
|
|
" iteratorValue(iterator) { _sequence.iteratorValue(iterator) }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class WhereSequence is Sequence {\n"
|
|
" construct new(sequence, fn) {\n"
|
|
" _sequence = sequence\n"
|
|
" _fn = fn\n"
|
|
" }\n"
|
|
"\n"
|
|
" iterate(iterator) {\n"
|
|
" while (iterator = _sequence.iterate(iterator)) {\n"
|
|
" if (_fn.call(_sequence.iteratorValue(iterator))) break\n"
|
|
" }\n"
|
|
" return iterator\n"
|
|
" }\n"
|
|
"\n"
|
|
" iteratorValue(iterator) { _sequence.iteratorValue(iterator) }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class String is Sequence {\n"
|
|
" bytes { StringByteSequence.new(this) }\n"
|
|
" codePoints { StringCodePointSequence.new(this) }\n"
|
|
"\n"
|
|
" split(delimiter) {\n"
|
|
" if (!(delimiter is String) || delimiter.isEmpty) {\n"
|
|
" Fiber.abort(\"Delimiter must be a non-empty string.\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" var result = []\n"
|
|
"\n"
|
|
" var last = 0\n"
|
|
" var index = 0\n"
|
|
"\n"
|
|
" var delimSize = delimiter.byteCount_\n"
|
|
" var size = byteCount_\n"
|
|
"\n"
|
|
" while (last < size && (index = indexOf(delimiter, last)) != -1) {\n"
|
|
" result.add(this[last...index])\n"
|
|
" last = index + delimSize\n"
|
|
" }\n"
|
|
"\n"
|
|
" if (last < size) {\n"
|
|
" result.add(this[last..-1])\n"
|
|
" } else {\n"
|
|
" result.add(\"\")\n"
|
|
" }\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" replace(from, to) {\n"
|
|
" if (!(from is String) || from.isEmpty) {\n"
|
|
" Fiber.abort(\"From must be a non-empty string.\")\n"
|
|
" } else if (!(to is String)) {\n"
|
|
" Fiber.abort(\"To must be a string.\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" var result = \"\"\n"
|
|
"\n"
|
|
" var last = 0\n"
|
|
" var index = 0\n"
|
|
"\n"
|
|
" var fromSize = from.byteCount_\n"
|
|
" var size = byteCount_\n"
|
|
"\n"
|
|
" while (last < size && (index = indexOf(from, last)) != -1) {\n"
|
|
" result = result + this[last...index] + to\n"
|
|
" last = index + fromSize\n"
|
|
" }\n"
|
|
"\n"
|
|
" if (last < size) result = result + this[last..-1]\n"
|
|
"\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" trim() { trim_(\"\\t\\r\\n \", true, true) }\n"
|
|
" trim(chars) { trim_(chars, true, true) }\n"
|
|
" trimEnd() { trim_(\"\\t\\r\\n \", false, true) }\n"
|
|
" trimEnd(chars) { trim_(chars, false, true) }\n"
|
|
" trimStart() { trim_(\"\\t\\r\\n \", true, false) }\n"
|
|
" trimStart(chars) { trim_(chars, true, false) }\n"
|
|
"\n"
|
|
" trim_(chars, trimStart, trimEnd) {\n"
|
|
" if (!(chars is String)) {\n"
|
|
" Fiber.abort(\"Characters must be a string.\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" var codePoints = chars.codePoints.toList\n"
|
|
"\n"
|
|
" var start\n"
|
|
" if (trimStart) {\n"
|
|
" while (start = iterate(start)) {\n"
|
|
" if (!codePoints.contains(codePointAt_(start))) break\n"
|
|
" }\n"
|
|
"\n"
|
|
" if (start == false) return \"\"\n"
|
|
" } else {\n"
|
|
" start = 0\n"
|
|
" }\n"
|
|
"\n"
|
|
" var end\n"
|
|
" if (trimEnd) {\n"
|
|
" end = byteCount_ - 1\n"
|
|
" while (end >= start) {\n"
|
|
" var codePoint = codePointAt_(end)\n"
|
|
" if (codePoint != -1 && !codePoints.contains(codePoint)) break\n"
|
|
" end = end - 1\n"
|
|
" }\n"
|
|
"\n"
|
|
" if (end < start) return \"\"\n"
|
|
" } else {\n"
|
|
" end = -1\n"
|
|
" }\n"
|
|
"\n"
|
|
" return this[start..end]\n"
|
|
" }\n"
|
|
"\n"
|
|
" *(count) {\n"
|
|
" if (!(count is Num) || !count.isInteger || count < 0) {\n"
|
|
" Fiber.abort(\"Count must be a non-negative integer.\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" var result = \"\"\n"
|
|
" for (i in 0...count) {\n"
|
|
" result = result + this\n"
|
|
" }\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class StringByteSequence is Sequence {\n"
|
|
" construct new(string) {\n"
|
|
" _string = string\n"
|
|
" }\n"
|
|
"\n"
|
|
" [index] { _string.byteAt_(index) }\n"
|
|
" iterate(iterator) { _string.iterateByte_(iterator) }\n"
|
|
" iteratorValue(iterator) { _string.byteAt_(iterator) }\n"
|
|
"\n"
|
|
" count { _string.byteCount_ }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class StringCodePointSequence is Sequence {\n"
|
|
" construct new(string) {\n"
|
|
" _string = string\n"
|
|
" }\n"
|
|
"\n"
|
|
" [index] { _string.codePointAt_(index) }\n"
|
|
" iterate(iterator) { _string.iterate(iterator) }\n"
|
|
" iteratorValue(iterator) { _string.codePointAt_(iterator) }\n"
|
|
"\n"
|
|
" count { _string.count }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class List is Sequence {\n"
|
|
" addAll(other) {\n"
|
|
" for (element in other) {\n"
|
|
" add(element)\n"
|
|
" }\n"
|
|
" return other\n"
|
|
" }\n"
|
|
"\n"
|
|
" sort() { sort {|low, high| low < high } }\n"
|
|
"\n"
|
|
" sort(comparer) {\n"
|
|
" if (!(comparer is Fn)) {\n"
|
|
" Fiber.abort(\"Comparer must be a function.\")\n"
|
|
" }\n"
|
|
" quicksort_(0, count - 1, comparer)\n"
|
|
" return this\n"
|
|
" }\n"
|
|
"\n"
|
|
" quicksort_(low, high, comparer) {\n"
|
|
" if (low < high) {\n"
|
|
" var p = partition_(low, high, comparer)\n"
|
|
" quicksort_(low, p - 1, comparer)\n"
|
|
" quicksort_(p + 1, high, comparer)\n"
|
|
" }\n"
|
|
" }\n"
|
|
"\n"
|
|
" partition_(low, high, comparer) {\n"
|
|
" var p = this[high]\n"
|
|
" var i = low - 1\n"
|
|
" for (j in low..(high-1)) {\n"
|
|
" if (comparer.call(this[j], p)) { \n"
|
|
" i = i + 1\n"
|
|
" var t = this[i]\n"
|
|
" this[i] = this[j]\n"
|
|
" this[j] = t\n"
|
|
" }\n"
|
|
" }\n"
|
|
" var t = this[i+1]\n"
|
|
" this[i+1] = this[high]\n"
|
|
" this[high] = t\n"
|
|
" return i+1\n"
|
|
" }\n"
|
|
"\n"
|
|
" toString { \"[%(join(\", \"))]\" }\n"
|
|
"\n"
|
|
" +(other) {\n"
|
|
" var result = this[0..-1]\n"
|
|
" for (element in other) {\n"
|
|
" result.add(element)\n"
|
|
" }\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" *(count) {\n"
|
|
" if (!(count is Num) || !count.isInteger || count < 0) {\n"
|
|
" Fiber.abort(\"Count must be a non-negative integer.\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" var result = []\n"
|
|
" for (i in 0...count) {\n"
|
|
" result.addAll(this)\n"
|
|
" }\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class Map is Sequence {\n"
|
|
" keys { MapKeySequence.new(this) }\n"
|
|
" values { MapValueSequence.new(this) }\n"
|
|
"\n"
|
|
" toString {\n"
|
|
" var first = true\n"
|
|
" var result = \"{\"\n"
|
|
"\n"
|
|
" for (key in keys) {\n"
|
|
" if (!first) result = result + \", \"\n"
|
|
" first = false\n"
|
|
" result = result + \"%(key): %(this[key])\"\n"
|
|
" }\n"
|
|
"\n"
|
|
" return result + \"}\"\n"
|
|
" }\n"
|
|
"\n"
|
|
" iteratorValue(iterator) {\n"
|
|
" return MapEntry.new(\n"
|
|
" keyIteratorValue_(iterator),\n"
|
|
" valueIteratorValue_(iterator))\n"
|
|
" }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class MapEntry {\n"
|
|
" construct new(key, value) {\n"
|
|
" _key = key\n"
|
|
" _value = value\n"
|
|
" }\n"
|
|
"\n"
|
|
" key { _key }\n"
|
|
" value { _value }\n"
|
|
"\n"
|
|
" toString { \"%(_key):%(_value)\" }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class MapKeySequence is Sequence {\n"
|
|
" construct new(map) {\n"
|
|
" _map = map\n"
|
|
" }\n"
|
|
"\n"
|
|
" iterate(n) { _map.iterate(n) }\n"
|
|
" iteratorValue(iterator) { _map.keyIteratorValue_(iterator) }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class MapValueSequence is Sequence {\n"
|
|
" construct new(map) {\n"
|
|
" _map = map\n"
|
|
" }\n"
|
|
"\n"
|
|
" iterate(n) { _map.iterate(n) }\n"
|
|
" iteratorValue(iterator) { _map.valueIteratorValue_(iterator) }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class Range is Sequence {}\n"
|
|
"\n"
|
|
"class System {\n"
|
|
" static print() {\n"
|
|
" writeString_(\"\\n\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" static print(obj) {\n"
|
|
" writeObject_(obj)\n"
|
|
" writeString_(\"\\n\")\n"
|
|
" return obj\n"
|
|
" }\n"
|
|
"\n"
|
|
" static printAll(sequence) {\n"
|
|
" for (object in sequence) writeObject_(object)\n"
|
|
" writeString_(\"\\n\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" static write(obj) {\n"
|
|
" writeObject_(obj)\n"
|
|
" return obj\n"
|
|
" }\n"
|
|
"\n"
|
|
" static writeAll(sequence) {\n"
|
|
" for (object in sequence) writeObject_(object)\n"
|
|
" }\n"
|
|
"\n"
|
|
" static writeObject_(obj) {\n"
|
|
" var string = obj.toString\n"
|
|
" if (string is String) {\n"
|
|
" writeString_(string)\n"
|
|
" } else {\n"
|
|
" writeString_(\"[invalid toString]\")\n"
|
|
" }\n"
|
|
" }\n"
|
|
"}\n"
|
|
"\n"
|
|
"class ClassAttributes {\n"
|
|
" self { _attributes }\n"
|
|
" methods { _methods }\n"
|
|
" construct new(attributes, methods) {\n"
|
|
" _attributes = attributes\n"
|
|
" _methods = methods\n"
|
|
" }\n"
|
|
" toString { \"attributes:%(_attributes) methods:%(_methods)\" }\n"
|
|
"}\n";
|
|
// End file "wren_core.wren.inc"
|
|
|
|
DEF_PRIMITIVE(bool_not)
|
|
{
|
|
RETURN_BOOL(!AS_BOOL(args[0]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(bool_toString)
|
|
{
|
|
if (AS_BOOL(args[0]))
|
|
{
|
|
RETURN_VAL(CONST_STRING(vm, "true"));
|
|
}
|
|
else
|
|
{
|
|
RETURN_VAL(CONST_STRING(vm, "false"));
|
|
}
|
|
}
|
|
|
|
DEF_PRIMITIVE(class_name)
|
|
{
|
|
RETURN_OBJ(AS_CLASS(args[0])->name);
|
|
}
|
|
|
|
DEF_PRIMITIVE(class_supertype)
|
|
{
|
|
ObjClass* classObj = AS_CLASS(args[0]);
|
|
|
|
// Object has no superclass.
|
|
if (classObj->superclass == NULL) RETURN_NULL;
|
|
|
|
RETURN_OBJ(classObj->superclass);
|
|
}
|
|
|
|
DEF_PRIMITIVE(class_toString)
|
|
{
|
|
RETURN_OBJ(AS_CLASS(args[0])->name);
|
|
}
|
|
|
|
DEF_PRIMITIVE(class_attributes)
|
|
{
|
|
RETURN_VAL(AS_CLASS(args[0])->attributes);
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_new)
|
|
{
|
|
if (!validateFn(vm, args[1], "Argument")) return false;
|
|
|
|
ObjClosure* closure = AS_CLOSURE(args[1]);
|
|
if (closure->fn->arity > 1)
|
|
{
|
|
RETURN_ERROR("Function cannot take more than one parameter.");
|
|
}
|
|
|
|
RETURN_OBJ(wrenNewFiber(vm, closure));
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_abort)
|
|
{
|
|
vm->fiber->error = args[1];
|
|
|
|
// If the error is explicitly null, it's not really an abort.
|
|
return IS_NULL(args[1]);
|
|
}
|
|
|
|
// Transfer execution to [fiber] coming from the current fiber whose stack has
|
|
// [args].
|
|
//
|
|
// [isCall] is true if [fiber] is being called and not transferred.
|
|
//
|
|
// [hasValue] is true if a value in [args] is being passed to the new fiber.
|
|
// Otherwise, `null` is implicitly being passed.
|
|
static bool runFiber(WrenVM* vm, ObjFiber* fiber, Value* args, bool isCall,
|
|
bool hasValue, const char* verb)
|
|
{
|
|
|
|
if (wrenHasError(fiber))
|
|
{
|
|
RETURN_ERROR_FMT("Cannot $ an aborted fiber.", verb);
|
|
}
|
|
|
|
if (isCall)
|
|
{
|
|
// You can't call a called fiber, but you can transfer directly to it,
|
|
// which is why this check is gated on `isCall`. This way, after resuming a
|
|
// suspended fiber, it will run and then return to the fiber that called it
|
|
// and so on.
|
|
if (fiber->caller != NULL) RETURN_ERROR("Fiber has already been called.");
|
|
|
|
if (fiber->state == FIBER_ROOT) RETURN_ERROR("Cannot call root fiber.");
|
|
|
|
// Remember who ran it.
|
|
fiber->caller = vm->fiber;
|
|
}
|
|
|
|
if (fiber->numFrames == 0)
|
|
{
|
|
RETURN_ERROR_FMT("Cannot $ a finished fiber.", verb);
|
|
}
|
|
|
|
// When the calling fiber resumes, we'll store the result of the call in its
|
|
// stack. If the call has two arguments (the fiber and the value), we only
|
|
// need one slot for the result, so discard the other slot now.
|
|
if (hasValue) vm->fiber->stackTop--;
|
|
|
|
if (fiber->numFrames == 1 &&
|
|
fiber->frames[0].ip == fiber->frames[0].closure->fn->code.data)
|
|
{
|
|
// The fiber is being started for the first time. If its function takes a
|
|
// parameter, bind an argument to it.
|
|
if (fiber->frames[0].closure->fn->arity == 1)
|
|
{
|
|
fiber->stackTop[0] = hasValue ? args[1] : NULL_VAL;
|
|
fiber->stackTop++;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
// The fiber is being resumed, make yield() or transfer() return the result.
|
|
fiber->stackTop[-1] = hasValue ? args[1] : NULL_VAL;
|
|
}
|
|
|
|
vm->fiber = fiber;
|
|
return false;
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_call)
|
|
{
|
|
return runFiber(vm, AS_FIBER(args[0]), args, true, false, "call");
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_call1)
|
|
{
|
|
return runFiber(vm, AS_FIBER(args[0]), args, true, true, "call");
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_current)
|
|
{
|
|
RETURN_OBJ(vm->fiber);
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_error)
|
|
{
|
|
RETURN_VAL(AS_FIBER(args[0])->error);
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_isDone)
|
|
{
|
|
ObjFiber* runFiber = AS_FIBER(args[0]);
|
|
RETURN_BOOL(runFiber->numFrames == 0 || wrenHasError(runFiber));
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_suspend)
|
|
{
|
|
// Switching to a null fiber tells the interpreter to stop and exit.
|
|
vm->fiber = NULL;
|
|
vm->apiStack = NULL;
|
|
return false;
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_transfer)
|
|
{
|
|
return runFiber(vm, AS_FIBER(args[0]), args, false, false, "transfer to");
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_transfer1)
|
|
{
|
|
return runFiber(vm, AS_FIBER(args[0]), args, false, true, "transfer to");
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_transferError)
|
|
{
|
|
runFiber(vm, AS_FIBER(args[0]), args, false, true, "transfer to");
|
|
vm->fiber->error = args[1];
|
|
return false;
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_try)
|
|
{
|
|
runFiber(vm, AS_FIBER(args[0]), args, true, false, "try");
|
|
|
|
// If we're switching to a valid fiber to try, remember that we're trying it.
|
|
if (!wrenHasError(vm->fiber)) vm->fiber->state = FIBER_TRY;
|
|
return false;
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_try1)
|
|
{
|
|
runFiber(vm, AS_FIBER(args[0]), args, true, true, "try");
|
|
|
|
// If we're switching to a valid fiber to try, remember that we're trying it.
|
|
if (!wrenHasError(vm->fiber)) vm->fiber->state = FIBER_TRY;
|
|
return false;
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_yield)
|
|
{
|
|
ObjFiber* current = vm->fiber;
|
|
vm->fiber = current->caller;
|
|
|
|
// Unhook this fiber from the one that called it.
|
|
current->caller = NULL;
|
|
current->state = FIBER_OTHER;
|
|
|
|
if (vm->fiber != NULL)
|
|
{
|
|
// Make the caller's run method return null.
|
|
vm->fiber->stackTop[-1] = NULL_VAL;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
DEF_PRIMITIVE(fiber_yield1)
|
|
{
|
|
ObjFiber* current = vm->fiber;
|
|
vm->fiber = current->caller;
|
|
|
|
// Unhook this fiber from the one that called it.
|
|
current->caller = NULL;
|
|
current->state = FIBER_OTHER;
|
|
|
|
if (vm->fiber != NULL)
|
|
{
|
|
// Make the caller's run method return the argument passed to yield.
|
|
vm->fiber->stackTop[-1] = args[1];
|
|
|
|
// When the yielding fiber resumes, we'll store the result of the yield
|
|
// call in its stack. Since Fiber.yield(value) has two arguments (the Fiber
|
|
// class and the value) and we only need one slot for the result, discard
|
|
// the other slot now.
|
|
current->stackTop--;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
DEF_PRIMITIVE(fn_new)
|
|
{
|
|
if (!validateFn(vm, args[1], "Argument")) return false;
|
|
|
|
// The block argument is already a function, so just return it.
|
|
RETURN_VAL(args[1]);
|
|
}
|
|
|
|
DEF_PRIMITIVE(fn_arity)
|
|
{
|
|
RETURN_NUM(AS_CLOSURE(args[0])->fn->arity);
|
|
}
|
|
|
|
// Invokes the closure in args[0] on the current fiber with [numArgs] explicit
// arguments.
static void call_fn(WrenVM* vm, Value* args, int numArgs)
{
  ObjClosure* closure = AS_CLOSURE(args[0]);

  // One extra slot accounts for the function itself.
  int slotCount = numArgs + 1;
  wrenCallFunction(vm, vm->fiber, closure, slotCount);
}
|
|
|
|
#define DEF_FN_CALL(numArgs) \
|
|
DEF_PRIMITIVE(fn_call##numArgs) \
|
|
{ \
|
|
call_fn(vm, args, numArgs); \
|
|
return false; \
|
|
}
|
|
|
|
DEF_FN_CALL(0)
|
|
DEF_FN_CALL(1)
|
|
DEF_FN_CALL(2)
|
|
DEF_FN_CALL(3)
|
|
DEF_FN_CALL(4)
|
|
DEF_FN_CALL(5)
|
|
DEF_FN_CALL(6)
|
|
DEF_FN_CALL(7)
|
|
DEF_FN_CALL(8)
|
|
DEF_FN_CALL(9)
|
|
DEF_FN_CALL(10)
|
|
DEF_FN_CALL(11)
|
|
DEF_FN_CALL(12)
|
|
DEF_FN_CALL(13)
|
|
DEF_FN_CALL(14)
|
|
DEF_FN_CALL(15)
|
|
DEF_FN_CALL(16)
|
|
|
|
DEF_PRIMITIVE(fn_toString)
|
|
{
|
|
RETURN_VAL(CONST_STRING(vm, "<fn>"));
|
|
}
|
|
|
|
// Creates a new list of size args[1], with all elements initialized to args[2].
|
|
DEF_PRIMITIVE(list_filled)
|
|
{
|
|
if (!validateInt(vm, args[1], "Size")) return false;
|
|
if (AS_NUM(args[1]) < 0) RETURN_ERROR("Size cannot be negative.");
|
|
|
|
uint32_t size = (uint32_t)AS_NUM(args[1]);
|
|
ObjList* list = wrenNewList(vm, size);
|
|
|
|
for (uint32_t i = 0; i < size; i++)
|
|
{
|
|
list->elements.data[i] = args[2];
|
|
}
|
|
|
|
RETURN_OBJ(list);
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_new)
|
|
{
|
|
RETURN_OBJ(wrenNewList(vm, 0));
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_add)
|
|
{
|
|
wrenValueBufferWrite(vm, &AS_LIST(args[0])->elements, args[1]);
|
|
RETURN_VAL(args[1]);
|
|
}
|
|
|
|
// Adds an element to the list and then returns the list itself. This is called
|
|
// by the compiler when compiling list literals instead of using add() to
|
|
// minimize stack churn.
|
|
DEF_PRIMITIVE(list_addCore)
|
|
{
|
|
wrenValueBufferWrite(vm, &AS_LIST(args[0])->elements, args[1]);
|
|
|
|
// Return the list.
|
|
RETURN_VAL(args[0]);
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_clear)
|
|
{
|
|
wrenValueBufferClear(vm, &AS_LIST(args[0])->elements);
|
|
RETURN_NULL;
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_count)
|
|
{
|
|
RETURN_NUM(AS_LIST(args[0])->elements.count);
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_insert)
|
|
{
|
|
ObjList* list = AS_LIST(args[0]);
|
|
|
|
// count + 1 here so you can "insert" at the very end.
|
|
uint32_t index = validateIndex(vm, args[1], list->elements.count + 1,
|
|
"Index");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
wrenListInsert(vm, list, args[2], index);
|
|
RETURN_VAL(args[2]);
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_iterate)
|
|
{
|
|
ObjList* list = AS_LIST(args[0]);
|
|
|
|
// If we're starting the iteration, return the first index.
|
|
if (IS_NULL(args[1]))
|
|
{
|
|
if (list->elements.count == 0) RETURN_FALSE;
|
|
RETURN_NUM(0);
|
|
}
|
|
|
|
if (!validateInt(vm, args[1], "Iterator")) return false;
|
|
|
|
// Stop if we're out of bounds.
|
|
double index = AS_NUM(args[1]);
|
|
if (index < 0 || index >= list->elements.count - 1) RETURN_FALSE;
|
|
|
|
// Otherwise, move to the next index.
|
|
RETURN_NUM(index + 1);
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_iteratorValue)
|
|
{
|
|
ObjList* list = AS_LIST(args[0]);
|
|
uint32_t index = validateIndex(vm, args[1], list->elements.count, "Iterator");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
RETURN_VAL(list->elements.data[index]);
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_removeAt)
|
|
{
|
|
ObjList* list = AS_LIST(args[0]);
|
|
uint32_t index = validateIndex(vm, args[1], list->elements.count, "Index");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
RETURN_VAL(wrenListRemoveAt(vm, list, index));
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_removeValue) {
|
|
ObjList* list = AS_LIST(args[0]);
|
|
int index = wrenListIndexOf(vm, list, args[1]);
|
|
if(index == -1) RETURN_NULL;
|
|
RETURN_VAL(wrenListRemoveAt(vm, list, index));
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_indexOf)
|
|
{
|
|
ObjList* list = AS_LIST(args[0]);
|
|
RETURN_NUM(wrenListIndexOf(vm, list, args[1]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_swap)
|
|
{
|
|
ObjList* list = AS_LIST(args[0]);
|
|
uint32_t indexA = validateIndex(vm, args[1], list->elements.count, "Index 0");
|
|
if (indexA == UINT32_MAX) return false;
|
|
uint32_t indexB = validateIndex(vm, args[2], list->elements.count, "Index 1");
|
|
if (indexB == UINT32_MAX) return false;
|
|
|
|
Value a = list->elements.data[indexA];
|
|
list->elements.data[indexA] = list->elements.data[indexB];
|
|
list->elements.data[indexB] = a;
|
|
|
|
RETURN_NULL;
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_subscript)
|
|
{
|
|
ObjList* list = AS_LIST(args[0]);
|
|
|
|
if (IS_NUM(args[1]))
|
|
{
|
|
uint32_t index = validateIndex(vm, args[1], list->elements.count,
|
|
"Subscript");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
RETURN_VAL(list->elements.data[index]);
|
|
}
|
|
|
|
if (!IS_RANGE(args[1]))
|
|
{
|
|
RETURN_ERROR("Subscript must be a number or a range.");
|
|
}
|
|
|
|
int step;
|
|
uint32_t count = list->elements.count;
|
|
uint32_t start = calculateRange(vm, AS_RANGE(args[1]), &count, &step);
|
|
if (start == UINT32_MAX) return false;
|
|
|
|
ObjList* result = wrenNewList(vm, count);
|
|
for (uint32_t i = 0; i < count; i++)
|
|
{
|
|
result->elements.data[i] = list->elements.data[start + i * step];
|
|
}
|
|
|
|
RETURN_OBJ(result);
|
|
}
|
|
|
|
DEF_PRIMITIVE(list_subscriptSetter)
|
|
{
|
|
ObjList* list = AS_LIST(args[0]);
|
|
uint32_t index = validateIndex(vm, args[1], list->elements.count,
|
|
"Subscript");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
list->elements.data[index] = args[2];
|
|
RETURN_VAL(args[2]);
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_new)
|
|
{
|
|
RETURN_OBJ(wrenNewMap(vm));
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_subscript)
|
|
{
|
|
if (!validateKey(vm, args[1])) return false;
|
|
|
|
ObjMap* map = AS_MAP(args[0]);
|
|
Value value = wrenMapGet(map, args[1]);
|
|
if (IS_UNDEFINED(value)) RETURN_NULL;
|
|
|
|
RETURN_VAL(value);
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_subscriptSetter)
|
|
{
|
|
if (!validateKey(vm, args[1])) return false;
|
|
|
|
wrenMapSet(vm, AS_MAP(args[0]), args[1], args[2]);
|
|
RETURN_VAL(args[2]);
|
|
}
|
|
|
|
// Adds an entry to the map and then returns the map itself. This is called by
|
|
// the compiler when compiling map literals instead of using [_]=(_) to
|
|
// minimize stack churn.
|
|
DEF_PRIMITIVE(map_addCore)
|
|
{
|
|
if (!validateKey(vm, args[1])) return false;
|
|
|
|
wrenMapSet(vm, AS_MAP(args[0]), args[1], args[2]);
|
|
|
|
// Return the map itself.
|
|
RETURN_VAL(args[0]);
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_clear)
|
|
{
|
|
wrenMapClear(vm, AS_MAP(args[0]));
|
|
RETURN_NULL;
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_containsKey)
|
|
{
|
|
if (!validateKey(vm, args[1])) return false;
|
|
|
|
RETURN_BOOL(!IS_UNDEFINED(wrenMapGet(AS_MAP(args[0]), args[1])));
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_count)
|
|
{
|
|
RETURN_NUM(AS_MAP(args[0])->count);
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_iterate)
|
|
{
|
|
ObjMap* map = AS_MAP(args[0]);
|
|
|
|
if (map->count == 0) RETURN_FALSE;
|
|
|
|
// If we're starting the iteration, start at the first used entry.
|
|
uint32_t index = 0;
|
|
|
|
// Otherwise, start one past the last entry we stopped at.
|
|
if (!IS_NULL(args[1]))
|
|
{
|
|
if (!validateInt(vm, args[1], "Iterator")) return false;
|
|
|
|
if (AS_NUM(args[1]) < 0) RETURN_FALSE;
|
|
index = (uint32_t)AS_NUM(args[1]);
|
|
|
|
if (index >= map->capacity) RETURN_FALSE;
|
|
|
|
// Advance the iterator.
|
|
index++;
|
|
}
|
|
|
|
// Find a used entry, if any.
|
|
for (; index < map->capacity; index++)
|
|
{
|
|
if (!IS_UNDEFINED(map->entries[index].key)) RETURN_NUM(index);
|
|
}
|
|
|
|
// If we get here, walked all of the entries.
|
|
RETURN_FALSE;
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_remove)
|
|
{
|
|
if (!validateKey(vm, args[1])) return false;
|
|
|
|
RETURN_VAL(wrenMapRemoveKey(vm, AS_MAP(args[0]), args[1]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_keyIteratorValue)
|
|
{
|
|
ObjMap* map = AS_MAP(args[0]);
|
|
uint32_t index = validateIndex(vm, args[1], map->capacity, "Iterator");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
MapEntry* entry = &map->entries[index];
|
|
if (IS_UNDEFINED(entry->key))
|
|
{
|
|
RETURN_ERROR("Invalid map iterator.");
|
|
}
|
|
|
|
RETURN_VAL(entry->key);
|
|
}
|
|
|
|
DEF_PRIMITIVE(map_valueIteratorValue)
|
|
{
|
|
ObjMap* map = AS_MAP(args[0]);
|
|
uint32_t index = validateIndex(vm, args[1], map->capacity, "Iterator");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
MapEntry* entry = &map->entries[index];
|
|
if (IS_UNDEFINED(entry->key))
|
|
{
|
|
RETURN_ERROR("Invalid map iterator.");
|
|
}
|
|
|
|
RETURN_VAL(entry->value);
|
|
}
|
|
|
|
DEF_PRIMITIVE(null_not)
|
|
{
|
|
RETURN_VAL(TRUE_VAL);
|
|
}
|
|
|
|
DEF_PRIMITIVE(null_toString)
|
|
{
|
|
RETURN_VAL(CONST_STRING(vm, "null"));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_fromString)
|
|
{
|
|
if (!validateString(vm, args[1], "Argument")) return false;
|
|
|
|
ObjString* string = AS_STRING(args[1]);
|
|
|
|
// Corner case: Can't parse an empty string.
|
|
if (string->length == 0) RETURN_NULL;
|
|
|
|
errno = 0;
|
|
char* end;
|
|
double number = strtod(string->value, &end);
|
|
|
|
// Skip past any trailing whitespace.
|
|
while (*end != '\0' && isspace((unsigned char)*end)) end++;
|
|
|
|
if (errno == ERANGE) RETURN_ERROR("Number literal is too large.");
|
|
|
|
// We must have consumed the entire string. Otherwise, it contains non-number
|
|
// characters and we can't parse it.
|
|
if (end < string->value + string->length) RETURN_NULL;
|
|
|
|
RETURN_NUM(number);
|
|
}
|
|
|
|
// Defines a primitive on Num that calls infix [op] and returns [type].
|
|
#define DEF_NUM_CONSTANT(name, value) \
|
|
DEF_PRIMITIVE(num_##name) \
|
|
{ \
|
|
RETURN_NUM(value); \
|
|
}
|
|
|
|
DEF_NUM_CONSTANT(infinity, INFINITY)
|
|
DEF_NUM_CONSTANT(nan, WREN_DOUBLE_NAN)
|
|
DEF_NUM_CONSTANT(pi, 3.14159265358979323846264338327950288)
|
|
DEF_NUM_CONSTANT(tau, 6.28318530717958647692528676655900577)
|
|
|
|
DEF_NUM_CONSTANT(largest, DBL_MAX)
|
|
DEF_NUM_CONSTANT(smallest, DBL_MIN)
|
|
|
|
DEF_NUM_CONSTANT(maxSafeInteger, 9007199254740991.0)
|
|
DEF_NUM_CONSTANT(minSafeInteger, -9007199254740991.0)
|
|
|
|
// Defines a primitive on Num that calls infix [op] and returns [type].
|
|
#define DEF_NUM_INFIX(name, op, type) \
|
|
DEF_PRIMITIVE(num_##name) \
|
|
{ \
|
|
if (!validateNum(vm, args[1], "Right operand")) return false; \
|
|
RETURN_##type(AS_NUM(args[0]) op AS_NUM(args[1])); \
|
|
}
|
|
|
|
DEF_NUM_INFIX(minus, -, NUM)
|
|
DEF_NUM_INFIX(plus, +, NUM)
|
|
DEF_NUM_INFIX(multiply, *, NUM)
|
|
DEF_NUM_INFIX(divide, /, NUM)
|
|
DEF_NUM_INFIX(lt, <, BOOL)
|
|
DEF_NUM_INFIX(gt, >, BOOL)
|
|
DEF_NUM_INFIX(lte, <=, BOOL)
|
|
DEF_NUM_INFIX(gte, >=, BOOL)
|
|
|
|
// Defines a primitive on Num that call infix bitwise [op].
|
|
#define DEF_NUM_BITWISE(name, op) \
|
|
DEF_PRIMITIVE(num_bitwise##name) \
|
|
{ \
|
|
if (!validateNum(vm, args[1], "Right operand")) return false; \
|
|
uint32_t left = (uint32_t)AS_NUM(args[0]); \
|
|
uint32_t right = (uint32_t)AS_NUM(args[1]); \
|
|
RETURN_NUM(left op right); \
|
|
}
|
|
|
|
DEF_NUM_BITWISE(And, &)
|
|
DEF_NUM_BITWISE(Or, |)
|
|
DEF_NUM_BITWISE(Xor, ^)
|
|
DEF_NUM_BITWISE(LeftShift, <<)
|
|
DEF_NUM_BITWISE(RightShift, >>)
|
|
|
|
// Defines a primitive method on Num that returns the result of [fn].
|
|
#define DEF_NUM_FN(name, fn) \
|
|
DEF_PRIMITIVE(num_##name) \
|
|
{ \
|
|
RETURN_NUM(fn(AS_NUM(args[0]))); \
|
|
}
|
|
|
|
DEF_NUM_FN(abs, fabs)
|
|
DEF_NUM_FN(acos, acos)
|
|
DEF_NUM_FN(asin, asin)
|
|
DEF_NUM_FN(atan, atan)
|
|
DEF_NUM_FN(cbrt, cbrt)
|
|
DEF_NUM_FN(ceil, ceil)
|
|
DEF_NUM_FN(cos, cos)
|
|
DEF_NUM_FN(floor, floor)
|
|
DEF_NUM_FN(negate, -)
|
|
DEF_NUM_FN(round, round)
|
|
DEF_NUM_FN(sin, sin)
|
|
DEF_NUM_FN(sqrt, sqrt)
|
|
DEF_NUM_FN(tan, tan)
|
|
DEF_NUM_FN(log, log)
|
|
DEF_NUM_FN(log2, log2)
|
|
DEF_NUM_FN(exp, exp)
|
|
|
|
DEF_PRIMITIVE(num_mod)
|
|
{
|
|
if (!validateNum(vm, args[1], "Right operand")) return false;
|
|
RETURN_NUM(fmod(AS_NUM(args[0]), AS_NUM(args[1])));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_eqeq)
|
|
{
|
|
if (!IS_NUM(args[1])) RETURN_FALSE;
|
|
RETURN_BOOL(AS_NUM(args[0]) == AS_NUM(args[1]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_bangeq)
|
|
{
|
|
if (!IS_NUM(args[1])) RETURN_TRUE;
|
|
RETURN_BOOL(AS_NUM(args[0]) != AS_NUM(args[1]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_bitwiseNot)
|
|
{
|
|
// Bitwise operators always work on 32-bit unsigned ints.
|
|
RETURN_NUM(~(uint32_t)AS_NUM(args[0]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_dotDot)
|
|
{
|
|
if (!validateNum(vm, args[1], "Right hand side of range")) return false;
|
|
|
|
double from = AS_NUM(args[0]);
|
|
double to = AS_NUM(args[1]);
|
|
RETURN_VAL(wrenNewRange(vm, from, to, true));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_dotDotDot)
|
|
{
|
|
if (!validateNum(vm, args[1], "Right hand side of range")) return false;
|
|
|
|
double from = AS_NUM(args[0]);
|
|
double to = AS_NUM(args[1]);
|
|
RETURN_VAL(wrenNewRange(vm, from, to, false));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_atan2)
|
|
{
|
|
if (!validateNum(vm, args[1], "x value")) return false;
|
|
|
|
RETURN_NUM(atan2(AS_NUM(args[0]), AS_NUM(args[1])));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_min)
|
|
{
|
|
if (!validateNum(vm, args[1], "Other value")) return false;
|
|
|
|
double value = AS_NUM(args[0]);
|
|
double other = AS_NUM(args[1]);
|
|
RETURN_NUM(value <= other ? value : other);
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_max)
|
|
{
|
|
if (!validateNum(vm, args[1], "Other value")) return false;
|
|
|
|
double value = AS_NUM(args[0]);
|
|
double other = AS_NUM(args[1]);
|
|
RETURN_NUM(value > other ? value : other);
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_clamp)
|
|
{
|
|
if (!validateNum(vm, args[1], "Min value")) return false;
|
|
if (!validateNum(vm, args[2], "Max value")) return false;
|
|
|
|
double value = AS_NUM(args[0]);
|
|
double min = AS_NUM(args[1]);
|
|
double max = AS_NUM(args[2]);
|
|
double result = (value < min) ? min : ((value > max) ? max : value);
|
|
RETURN_NUM(result);
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_pow)
|
|
{
|
|
if (!validateNum(vm, args[1], "Power value")) return false;
|
|
|
|
RETURN_NUM(pow(AS_NUM(args[0]), AS_NUM(args[1])));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_fraction)
|
|
{
|
|
double unused;
|
|
RETURN_NUM(modf(AS_NUM(args[0]) , &unused));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_isInfinity)
|
|
{
|
|
RETURN_BOOL(isinf(AS_NUM(args[0])));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_isInteger)
|
|
{
|
|
double value = AS_NUM(args[0]);
|
|
if (isnan(value) || isinf(value)) RETURN_FALSE;
|
|
RETURN_BOOL(trunc(value) == value);
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_isNan)
|
|
{
|
|
RETURN_BOOL(isnan(AS_NUM(args[0])));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_sign)
|
|
{
|
|
double value = AS_NUM(args[0]);
|
|
if (value > 0)
|
|
{
|
|
RETURN_NUM(1);
|
|
}
|
|
else if (value < 0)
|
|
{
|
|
RETURN_NUM(-1);
|
|
}
|
|
else
|
|
{
|
|
RETURN_NUM(0);
|
|
}
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_toString)
|
|
{
|
|
RETURN_VAL(wrenNumToString(vm, AS_NUM(args[0])));
|
|
}
|
|
|
|
DEF_PRIMITIVE(num_truncate)
|
|
{
|
|
double integer;
|
|
modf(AS_NUM(args[0]) , &integer);
|
|
RETURN_NUM(integer);
|
|
}
|
|
|
|
DEF_PRIMITIVE(object_same)
|
|
{
|
|
RETURN_BOOL(wrenValuesEqual(args[1], args[2]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(object_not)
|
|
{
|
|
RETURN_VAL(FALSE_VAL);
|
|
}
|
|
|
|
DEF_PRIMITIVE(object_eqeq)
|
|
{
|
|
RETURN_BOOL(wrenValuesEqual(args[0], args[1]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(object_bangeq)
|
|
{
|
|
RETURN_BOOL(!wrenValuesEqual(args[0], args[1]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(object_is)
|
|
{
|
|
if (!IS_CLASS(args[1]))
|
|
{
|
|
RETURN_ERROR("Right operand must be a class.");
|
|
}
|
|
|
|
ObjClass *classObj = wrenGetClass(vm, args[0]);
|
|
ObjClass *baseClassObj = AS_CLASS(args[1]);
|
|
|
|
// Walk the superclass chain looking for the class.
|
|
do
|
|
{
|
|
if (baseClassObj == classObj) RETURN_BOOL(true);
|
|
|
|
classObj = classObj->superclass;
|
|
}
|
|
while (classObj != NULL);
|
|
|
|
RETURN_BOOL(false);
|
|
}
|
|
|
|
DEF_PRIMITIVE(object_toString)
|
|
{
|
|
Obj* obj = AS_OBJ(args[0]);
|
|
Value name = OBJ_VAL(obj->classObj->name);
|
|
RETURN_VAL(wrenStringFormat(vm, "instance of @", name));
|
|
}
|
|
|
|
DEF_PRIMITIVE(object_type)
|
|
{
|
|
RETURN_OBJ(wrenGetClass(vm, args[0]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(range_from)
|
|
{
|
|
RETURN_NUM(AS_RANGE(args[0])->from);
|
|
}
|
|
|
|
DEF_PRIMITIVE(range_to)
|
|
{
|
|
RETURN_NUM(AS_RANGE(args[0])->to);
|
|
}
|
|
|
|
DEF_PRIMITIVE(range_min)
|
|
{
|
|
ObjRange* range = AS_RANGE(args[0]);
|
|
RETURN_NUM(fmin(range->from, range->to));
|
|
}
|
|
|
|
DEF_PRIMITIVE(range_max)
|
|
{
|
|
ObjRange* range = AS_RANGE(args[0]);
|
|
RETURN_NUM(fmax(range->from, range->to));
|
|
}
|
|
|
|
DEF_PRIMITIVE(range_isInclusive)
|
|
{
|
|
RETURN_BOOL(AS_RANGE(args[0])->isInclusive);
|
|
}
|
|
|
|
DEF_PRIMITIVE(range_iterate)
|
|
{
|
|
ObjRange* range = AS_RANGE(args[0]);
|
|
|
|
// Special case: empty range.
|
|
if (range->from == range->to && !range->isInclusive) RETURN_FALSE;
|
|
|
|
// Start the iteration.
|
|
if (IS_NULL(args[1])) RETURN_NUM(range->from);
|
|
|
|
if (!validateNum(vm, args[1], "Iterator")) return false;
|
|
|
|
double iterator = AS_NUM(args[1]);
|
|
|
|
// Iterate towards [to] from [from].
|
|
if (range->from < range->to)
|
|
{
|
|
iterator++;
|
|
if (iterator > range->to) RETURN_FALSE;
|
|
}
|
|
else
|
|
{
|
|
iterator--;
|
|
if (iterator < range->to) RETURN_FALSE;
|
|
}
|
|
|
|
if (!range->isInclusive && iterator == range->to) RETURN_FALSE;
|
|
|
|
RETURN_NUM(iterator);
|
|
}
|
|
|
|
DEF_PRIMITIVE(range_iteratorValue)
|
|
{
|
|
// Assume the iterator is a number so that is the value of the range.
|
|
RETURN_VAL(args[1]);
|
|
}
|
|
|
|
DEF_PRIMITIVE(range_toString)
|
|
{
|
|
ObjRange* range = AS_RANGE(args[0]);
|
|
|
|
Value from = wrenNumToString(vm, range->from);
|
|
wrenPushRoot(vm, AS_OBJ(from));
|
|
|
|
Value to = wrenNumToString(vm, range->to);
|
|
wrenPushRoot(vm, AS_OBJ(to));
|
|
|
|
Value result = wrenStringFormat(vm, "@$@", from,
|
|
range->isInclusive ? ".." : "...", to);
|
|
|
|
wrenPopRoot(vm);
|
|
wrenPopRoot(vm);
|
|
RETURN_VAL(result);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_fromCodePoint)
|
|
{
|
|
if (!validateInt(vm, args[1], "Code point")) return false;
|
|
|
|
int codePoint = (int)AS_NUM(args[1]);
|
|
if (codePoint < 0)
|
|
{
|
|
RETURN_ERROR("Code point cannot be negative.");
|
|
}
|
|
else if (codePoint > 0x10ffff)
|
|
{
|
|
RETURN_ERROR("Code point cannot be greater than 0x10ffff.");
|
|
}
|
|
|
|
RETURN_VAL(wrenStringFromCodePoint(vm, codePoint));
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_fromByte)
|
|
{
|
|
if (!validateInt(vm, args[1], "Byte")) return false;
|
|
int byte = (int) AS_NUM(args[1]);
|
|
if (byte < 0)
|
|
{
|
|
RETURN_ERROR("Byte cannot be negative.");
|
|
}
|
|
else if (byte > 0xff)
|
|
{
|
|
RETURN_ERROR("Byte cannot be greater than 0xff.");
|
|
}
|
|
RETURN_VAL(wrenStringFromByte(vm, (uint8_t) byte));
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_byteAt)
|
|
{
|
|
ObjString* string = AS_STRING(args[0]);
|
|
|
|
uint32_t index = validateIndex(vm, args[1], string->length, "Index");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
RETURN_NUM((uint8_t)string->value[index]);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_byteCount)
|
|
{
|
|
RETURN_NUM(AS_STRING(args[0])->length);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_codePointAt)
|
|
{
|
|
ObjString* string = AS_STRING(args[0]);
|
|
|
|
uint32_t index = validateIndex(vm, args[1], string->length, "Index");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
// If we are in the middle of a UTF-8 sequence, indicate that.
|
|
const uint8_t* bytes = (uint8_t*)string->value;
|
|
if ((bytes[index] & 0xc0) == 0x80) RETURN_NUM(-1);
|
|
|
|
// Decode the UTF-8 sequence.
|
|
RETURN_NUM(wrenUtf8Decode((uint8_t*)string->value + index,
|
|
string->length - index));
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_contains)
|
|
{
|
|
if (!validateString(vm, args[1], "Argument")) return false;
|
|
|
|
ObjString* string = AS_STRING(args[0]);
|
|
ObjString* search = AS_STRING(args[1]);
|
|
|
|
RETURN_BOOL(wrenStringFind(string, search, 0) != UINT32_MAX);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_endsWith)
|
|
{
|
|
if (!validateString(vm, args[1], "Argument")) return false;
|
|
|
|
ObjString* string = AS_STRING(args[0]);
|
|
ObjString* search = AS_STRING(args[1]);
|
|
|
|
// Edge case: If the search string is longer then return false right away.
|
|
if (search->length > string->length) RETURN_FALSE;
|
|
|
|
RETURN_BOOL(memcmp(string->value + string->length - search->length,
|
|
search->value, search->length) == 0);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_indexOf1)
|
|
{
|
|
if (!validateString(vm, args[1], "Argument")) return false;
|
|
|
|
ObjString* string = AS_STRING(args[0]);
|
|
ObjString* search = AS_STRING(args[1]);
|
|
|
|
uint32_t index = wrenStringFind(string, search, 0);
|
|
RETURN_NUM(index == UINT32_MAX ? -1 : (int)index);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_indexOf2)
|
|
{
|
|
if (!validateString(vm, args[1], "Argument")) return false;
|
|
|
|
ObjString* string = AS_STRING(args[0]);
|
|
ObjString* search = AS_STRING(args[1]);
|
|
uint32_t start = validateIndex(vm, args[2], string->length, "Start");
|
|
if (start == UINT32_MAX) return false;
|
|
|
|
uint32_t index = wrenStringFind(string, search, start);
|
|
RETURN_NUM(index == UINT32_MAX ? -1 : (int)index);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_iterate)
|
|
{
|
|
ObjString* string = AS_STRING(args[0]);
|
|
|
|
// If we're starting the iteration, return the first index.
|
|
if (IS_NULL(args[1]))
|
|
{
|
|
if (string->length == 0) RETURN_FALSE;
|
|
RETURN_NUM(0);
|
|
}
|
|
|
|
if (!validateInt(vm, args[1], "Iterator")) return false;
|
|
|
|
if (AS_NUM(args[1]) < 0) RETURN_FALSE;
|
|
uint32_t index = (uint32_t)AS_NUM(args[1]);
|
|
|
|
// Advance to the beginning of the next UTF-8 sequence.
|
|
do
|
|
{
|
|
index++;
|
|
if (index >= string->length) RETURN_FALSE;
|
|
} while ((string->value[index] & 0xc0) == 0x80);
|
|
|
|
RETURN_NUM(index);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_iterateByte)
|
|
{
|
|
ObjString* string = AS_STRING(args[0]);
|
|
|
|
// If we're starting the iteration, return the first index.
|
|
if (IS_NULL(args[1]))
|
|
{
|
|
if (string->length == 0) RETURN_FALSE;
|
|
RETURN_NUM(0);
|
|
}
|
|
|
|
if (!validateInt(vm, args[1], "Iterator")) return false;
|
|
|
|
if (AS_NUM(args[1]) < 0) RETURN_FALSE;
|
|
uint32_t index = (uint32_t)AS_NUM(args[1]);
|
|
|
|
// Advance to the next byte.
|
|
index++;
|
|
if (index >= string->length) RETURN_FALSE;
|
|
|
|
RETURN_NUM(index);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_iteratorValue)
|
|
{
|
|
ObjString* string = AS_STRING(args[0]);
|
|
uint32_t index = validateIndex(vm, args[1], string->length, "Iterator");
|
|
if (index == UINT32_MAX) return false;
|
|
|
|
RETURN_VAL(wrenStringCodePointAt(vm, string, index));
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_startsWith)
|
|
{
|
|
if (!validateString(vm, args[1], "Argument")) return false;
|
|
|
|
ObjString* string = AS_STRING(args[0]);
|
|
ObjString* search = AS_STRING(args[1]);
|
|
|
|
// Edge case: If the search string is longer then return false right away.
|
|
if (search->length > string->length) RETURN_FALSE;
|
|
|
|
RETURN_BOOL(memcmp(string->value, search->value, search->length) == 0);
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_plus)
|
|
{
|
|
if (!validateString(vm, args[1], "Right operand")) return false;
|
|
RETURN_VAL(wrenStringFormat(vm, "@@", args[0], args[1]));
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_subscript)
|
|
{
|
|
ObjString* string = AS_STRING(args[0]);
|
|
|
|
if (IS_NUM(args[1]))
|
|
{
|
|
int index = validateIndex(vm, args[1], string->length, "Subscript");
|
|
if (index == -1) return false;
|
|
|
|
RETURN_VAL(wrenStringCodePointAt(vm, string, index));
|
|
}
|
|
|
|
if (!IS_RANGE(args[1]))
|
|
{
|
|
RETURN_ERROR("Subscript must be a number or a range.");
|
|
}
|
|
|
|
int step;
|
|
uint32_t count = string->length;
|
|
int start = calculateRange(vm, AS_RANGE(args[1]), &count, &step);
|
|
if (start == -1) return false;
|
|
|
|
RETURN_VAL(wrenNewStringFromRange(vm, string, start, count, step));
|
|
}
|
|
|
|
DEF_PRIMITIVE(string_toString)
|
|
{
|
|
RETURN_VAL(args[0]);
|
|
}
|
|
|
|
// System.clock: the C clock() reading converted to seconds.
DEF_PRIMITIVE(system_clock)
{
  double ticks = (double)clock();
  RETURN_NUM(ticks / CLOCKS_PER_SEC);
}
|
|
|
|
DEF_PRIMITIVE(system_gc)
|
|
{
|
|
wrenCollectGarbage(vm);
|
|
RETURN_NULL;
|
|
}
|
|
|
|
DEF_PRIMITIVE(system_writeString)
|
|
{
|
|
if (vm->config.writeFn != NULL)
|
|
{
|
|
vm->config.writeFn(vm, AS_CSTRING(args[1]));
|
|
}
|
|
|
|
RETURN_VAL(args[1]);
|
|
}
|
|
|
|
// Creates either the Object or Class class in the core module with [name].
|
|
static ObjClass* defineClass(WrenVM* vm, ObjModule* module, const char* name)
|
|
{
|
|
ObjString* nameString = AS_STRING(wrenNewString(vm, name));
|
|
wrenPushRoot(vm, (Obj*)nameString);
|
|
|
|
ObjClass* classObj = wrenNewSingleClass(vm, 0, nameString);
|
|
|
|
wrenDefineVariable(vm, module, name, nameString->length, OBJ_VAL(classObj), NULL);
|
|
|
|
wrenPopRoot(vm);
|
|
return classObj;
|
|
}
|
|
|
|
// Builds the core module for [vm]:
//
//  1. creates the module and stores it in vm->modules under the null key,
//  2. hand-builds Object, Class and Object's metaclass and wires them up,
//  3. interprets coreModuleSource to define the remaining classes,
//  4. caches those classes on the VM and binds their native primitives,
//  5. patches the class pointer of every string object created so far.
void wrenInitializeCore(WrenVM* vm)
{
  ObjModule* coreModule = wrenNewModule(vm, NULL);

  // The core module is stored in the module map under the null key. It is
  // rooted only for the duration of the insertion.
  wrenPushRoot(vm, (Obj*)coreModule);
  wrenMapSet(vm, vm->modules, NULL_VAL, OBJ_VAL(coreModule));
  wrenPopRoot(vm); // coreModule.

  // --- Object -------------------------------------------------------------
  // The root class is built by hand because it has no superclass.
  vm->objectClass = defineClass(vm, coreModule, "Object");
  PRIMITIVE(vm->objectClass, "!", object_not);
  PRIMITIVE(vm->objectClass, "==(_)", object_eqeq);
  PRIMITIVE(vm->objectClass, "!=(_)", object_bangeq);
  PRIMITIVE(vm->objectClass, "is(_)", object_is);
  PRIMITIVE(vm->objectClass, "toString", object_toString);
  PRIMITIVE(vm->objectClass, "type", object_type);

  // --- Class --------------------------------------------------------------
  // Class derives from Object, so it can only be defined once Object exists.
  vm->classClass = defineClass(vm, coreModule, "Class");
  wrenBindSuperclass(vm, vm->classClass, vm->objectClass);
  PRIMITIVE(vm->classClass, "name", class_name);
  PRIMITIVE(vm->classClass, "supertype", class_supertype);
  PRIMITIVE(vm->classClass, "toString", class_toString);
  PRIMITIVE(vm->classClass, "attributes", class_attributes);

  // --- Object metaclass ---------------------------------------------------
  // Object's metaclass is itself a subclass of Class.
  ObjClass* objectMetaclass = defineClass(vm, coreModule, "Object metaclass");

  // With all three classes in existence, connect each to its metaclass.
  vm->objectClass->obj.classObj = objectMetaclass;
  objectMetaclass->obj.classObj = vm->classClass;
  vm->classClass->obj.classObj = vm->classClass;

  // Binding the superclass comes after the wiring above so that
  // objectMetaclass doesn't get collected.
  wrenBindSuperclass(vm, objectMetaclass, vm->classClass);

  PRIMITIVE(objectMetaclass, "same(_,_)", object_same);

  // The core class diagram ends up looking like this, where single lines point
  // to a class's superclass, and double lines point to its metaclass:
  //
  //        .------------------------------------. .====.
  //        |                  .---------------. | #    #
  //        v                  |               v | v    #
  //   .---------.   .-------------------.   .-------.  #
  //   | Object  |==>| Object metaclass  |==>| Class |=="
  //   '---------'   '-------------------'   '-------'
  //        ^                                 ^ ^ ^ ^
  //        |                  .--------------' # | #
  //        |                  |                # | #
  //   .---------.   .-------------------.      # | # -.
  //   |  Base   |==>|  Base metaclass   |======" | #  |
  //   '---------'   '-------------------'        | #  |
  //        ^                                     | #  |
  //        |                  .------------------' #  | Example classes
  //        |                  |                    #  |
  //   .---------.   .-------------------.          #  |
  //   | Derived |==>| Derived metaclass |=========="  |
  //   '---------'   '-------------------'            -'

  // Everything else is declared by the core module's own source code; the
  // sections below look each class up and attach its native methods.
  wrenInterpret(vm, NULL, coreModuleSource);

  // --- Bool ---------------------------------------------------------------
  vm->boolClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Bool"));
  PRIMITIVE(vm->boolClass, "toString", bool_toString);
  PRIMITIVE(vm->boolClass, "!", bool_not);

  // --- Fiber --------------------------------------------------------------
  vm->fiberClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Fiber"));
  ObjClass* fiberMetaclass = vm->fiberClass->obj.classObj;
  PRIMITIVE(fiberMetaclass, "new(_)", fiber_new);
  PRIMITIVE(fiberMetaclass, "abort(_)", fiber_abort);
  PRIMITIVE(fiberMetaclass, "current", fiber_current);
  PRIMITIVE(fiberMetaclass, "suspend()", fiber_suspend);
  PRIMITIVE(fiberMetaclass, "yield()", fiber_yield);
  PRIMITIVE(fiberMetaclass, "yield(_)", fiber_yield1);
  PRIMITIVE(vm->fiberClass, "call()", fiber_call);
  PRIMITIVE(vm->fiberClass, "call(_)", fiber_call1);
  PRIMITIVE(vm->fiberClass, "error", fiber_error);
  PRIMITIVE(vm->fiberClass, "isDone", fiber_isDone);
  PRIMITIVE(vm->fiberClass, "transfer()", fiber_transfer);
  PRIMITIVE(vm->fiberClass, "transfer(_)", fiber_transfer1);
  PRIMITIVE(vm->fiberClass, "transferError(_)", fiber_transferError);
  PRIMITIVE(vm->fiberClass, "try()", fiber_try);
  PRIMITIVE(vm->fiberClass, "try(_)", fiber_try1);

  // --- Fn -----------------------------------------------------------------
  vm->fnClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Fn"));
  ObjClass* fnMetaclass = vm->fnClass->obj.classObj;
  PRIMITIVE(fnMetaclass, "new(_)", fn_new);

  PRIMITIVE(vm->fnClass, "arity", fn_arity);

  // One call signature per supported argument count (0 through 16).
  FUNCTION_CALL(vm->fnClass, "call()", fn_call0);
  FUNCTION_CALL(vm->fnClass, "call(_)", fn_call1);
  FUNCTION_CALL(vm->fnClass, "call(_,_)", fn_call2);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_)", fn_call3);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_)", fn_call4);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_)", fn_call5);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_)", fn_call6);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_)", fn_call7);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_)", fn_call8);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_)", fn_call9);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_)", fn_call10);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_)", fn_call11);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_)", fn_call12);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call13);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call14);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call15);
  FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call16);

  PRIMITIVE(vm->fnClass, "toString", fn_toString);

  // --- Null ---------------------------------------------------------------
  vm->nullClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Null"));
  PRIMITIVE(vm->nullClass, "!", null_not);
  PRIMITIVE(vm->nullClass, "toString", null_toString);

  // --- Num ----------------------------------------------------------------
  vm->numClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Num"));
  ObjClass* numMetaclass = vm->numClass->obj.classObj;
  PRIMITIVE(numMetaclass, "fromString(_)", num_fromString);
  PRIMITIVE(numMetaclass, "infinity", num_infinity);
  PRIMITIVE(numMetaclass, "nan", num_nan);
  PRIMITIVE(numMetaclass, "pi", num_pi);
  PRIMITIVE(numMetaclass, "tau", num_tau);
  PRIMITIVE(numMetaclass, "largest", num_largest);
  PRIMITIVE(numMetaclass, "smallest", num_smallest);
  PRIMITIVE(numMetaclass, "maxSafeInteger", num_maxSafeInteger);
  PRIMITIVE(numMetaclass, "minSafeInteger", num_minSafeInteger);
  PRIMITIVE(vm->numClass, "-(_)", num_minus);
  PRIMITIVE(vm->numClass, "+(_)", num_plus);
  PRIMITIVE(vm->numClass, "*(_)", num_multiply);
  PRIMITIVE(vm->numClass, "/(_)", num_divide);
  PRIMITIVE(vm->numClass, "<(_)", num_lt);
  PRIMITIVE(vm->numClass, ">(_)", num_gt);
  PRIMITIVE(vm->numClass, "<=(_)", num_lte);
  PRIMITIVE(vm->numClass, ">=(_)", num_gte);
  PRIMITIVE(vm->numClass, "&(_)", num_bitwiseAnd);
  PRIMITIVE(vm->numClass, "|(_)", num_bitwiseOr);
  PRIMITIVE(vm->numClass, "^(_)", num_bitwiseXor);
  PRIMITIVE(vm->numClass, "<<(_)", num_bitwiseLeftShift);
  PRIMITIVE(vm->numClass, ">>(_)", num_bitwiseRightShift);
  PRIMITIVE(vm->numClass, "abs", num_abs);
  PRIMITIVE(vm->numClass, "acos", num_acos);
  PRIMITIVE(vm->numClass, "asin", num_asin);
  PRIMITIVE(vm->numClass, "atan", num_atan);
  PRIMITIVE(vm->numClass, "cbrt", num_cbrt);
  PRIMITIVE(vm->numClass, "ceil", num_ceil);
  PRIMITIVE(vm->numClass, "cos", num_cos);
  PRIMITIVE(vm->numClass, "floor", num_floor);
  PRIMITIVE(vm->numClass, "-", num_negate);
  PRIMITIVE(vm->numClass, "round", num_round);
  PRIMITIVE(vm->numClass, "min(_)", num_min);
  PRIMITIVE(vm->numClass, "max(_)", num_max);
  PRIMITIVE(vm->numClass, "clamp(_,_)", num_clamp);
  PRIMITIVE(vm->numClass, "sin", num_sin);
  PRIMITIVE(vm->numClass, "sqrt", num_sqrt);
  PRIMITIVE(vm->numClass, "tan", num_tan);
  PRIMITIVE(vm->numClass, "log", num_log);
  PRIMITIVE(vm->numClass, "log2", num_log2);
  PRIMITIVE(vm->numClass, "exp", num_exp);
  PRIMITIVE(vm->numClass, "%(_)", num_mod);
  PRIMITIVE(vm->numClass, "~", num_bitwiseNot);
  PRIMITIVE(vm->numClass, "..(_)", num_dotDot);
  PRIMITIVE(vm->numClass, "...(_)", num_dotDotDot);
  PRIMITIVE(vm->numClass, "atan(_)", num_atan2);
  PRIMITIVE(vm->numClass, "pow(_)", num_pow);
  PRIMITIVE(vm->numClass, "fraction", num_fraction);
  PRIMITIVE(vm->numClass, "isInfinity", num_isInfinity);
  PRIMITIVE(vm->numClass, "isInteger", num_isInteger);
  PRIMITIVE(vm->numClass, "isNan", num_isNan);
  PRIMITIVE(vm->numClass, "sign", num_sign);
  PRIMITIVE(vm->numClass, "toString", num_toString);
  PRIMITIVE(vm->numClass, "truncate", num_truncate);

  // Num gets its own equality operators so that 0 and -0 compare equal, as
  // IEEE 754 specifies, even though their bit representations differ.
  PRIMITIVE(vm->numClass, "==(_)", num_eqeq);
  PRIMITIVE(vm->numClass, "!=(_)", num_bangeq);

  // --- String -------------------------------------------------------------
  vm->stringClass = AS_CLASS(wrenFindVariable(vm, coreModule, "String"));
  ObjClass* stringMetaclass = vm->stringClass->obj.classObj;
  PRIMITIVE(stringMetaclass, "fromCodePoint(_)", string_fromCodePoint);
  PRIMITIVE(stringMetaclass, "fromByte(_)", string_fromByte);
  PRIMITIVE(vm->stringClass, "+(_)", string_plus);
  PRIMITIVE(vm->stringClass, "[_]", string_subscript);
  PRIMITIVE(vm->stringClass, "byteAt_(_)", string_byteAt);
  PRIMITIVE(vm->stringClass, "byteCount_", string_byteCount);
  PRIMITIVE(vm->stringClass, "codePointAt_(_)", string_codePointAt);
  PRIMITIVE(vm->stringClass, "contains(_)", string_contains);
  PRIMITIVE(vm->stringClass, "endsWith(_)", string_endsWith);
  PRIMITIVE(vm->stringClass, "indexOf(_)", string_indexOf1);
  PRIMITIVE(vm->stringClass, "indexOf(_,_)", string_indexOf2);
  PRIMITIVE(vm->stringClass, "iterate(_)", string_iterate);
  PRIMITIVE(vm->stringClass, "iterateByte_(_)", string_iterateByte);
  PRIMITIVE(vm->stringClass, "iteratorValue(_)", string_iteratorValue);
  PRIMITIVE(vm->stringClass, "startsWith(_)", string_startsWith);
  PRIMITIVE(vm->stringClass, "toString", string_toString);

  // --- List ---------------------------------------------------------------
  vm->listClass = AS_CLASS(wrenFindVariable(vm, coreModule, "List"));
  ObjClass* listMetaclass = vm->listClass->obj.classObj;
  PRIMITIVE(listMetaclass, "filled(_,_)", list_filled);
  PRIMITIVE(listMetaclass, "new()", list_new);
  PRIMITIVE(vm->listClass, "[_]", list_subscript);
  PRIMITIVE(vm->listClass, "[_]=(_)", list_subscriptSetter);
  PRIMITIVE(vm->listClass, "add(_)", list_add);
  PRIMITIVE(vm->listClass, "addCore_(_)", list_addCore);
  PRIMITIVE(vm->listClass, "clear()", list_clear);
  PRIMITIVE(vm->listClass, "count", list_count);
  PRIMITIVE(vm->listClass, "insert(_,_)", list_insert);
  PRIMITIVE(vm->listClass, "iterate(_)", list_iterate);
  PRIMITIVE(vm->listClass, "iteratorValue(_)", list_iteratorValue);
  PRIMITIVE(vm->listClass, "removeAt(_)", list_removeAt);
  PRIMITIVE(vm->listClass, "remove(_)", list_removeValue);
  PRIMITIVE(vm->listClass, "indexOf(_)", list_indexOf);
  PRIMITIVE(vm->listClass, "swap(_,_)", list_swap);

  // --- Map ----------------------------------------------------------------
  vm->mapClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Map"));
  ObjClass* mapMetaclass = vm->mapClass->obj.classObj;
  PRIMITIVE(mapMetaclass, "new()", map_new);
  PRIMITIVE(vm->mapClass, "[_]", map_subscript);
  PRIMITIVE(vm->mapClass, "[_]=(_)", map_subscriptSetter);
  PRIMITIVE(vm->mapClass, "addCore_(_,_)", map_addCore);
  PRIMITIVE(vm->mapClass, "clear()", map_clear);
  PRIMITIVE(vm->mapClass, "containsKey(_)", map_containsKey);
  PRIMITIVE(vm->mapClass, "count", map_count);
  PRIMITIVE(vm->mapClass, "remove(_)", map_remove);
  PRIMITIVE(vm->mapClass, "iterate(_)", map_iterate);
  PRIMITIVE(vm->mapClass, "keyIteratorValue_(_)", map_keyIteratorValue);
  PRIMITIVE(vm->mapClass, "valueIteratorValue_(_)", map_valueIteratorValue);

  // --- Range --------------------------------------------------------------
  vm->rangeClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Range"));
  PRIMITIVE(vm->rangeClass, "from", range_from);
  PRIMITIVE(vm->rangeClass, "to", range_to);
  PRIMITIVE(vm->rangeClass, "min", range_min);
  PRIMITIVE(vm->rangeClass, "max", range_max);
  PRIMITIVE(vm->rangeClass, "isInclusive", range_isInclusive);
  PRIMITIVE(vm->rangeClass, "iterate(_)", range_iterate);
  PRIMITIVE(vm->rangeClass, "iteratorValue(_)", range_iteratorValue);
  PRIMITIVE(vm->rangeClass, "toString", range_toString);

  // --- System -------------------------------------------------------------
  // System only has static methods, so just its metaclass is needed.
  ObjClass* systemMetaclass =
      AS_CLASS(wrenFindVariable(vm, coreModule, "System"))->obj.classObj;
  PRIMITIVE(systemMetaclass, "clock", system_clock);
  PRIMITIVE(systemMetaclass, "gc()", system_gc);
  PRIMITIVE(systemMetaclass, "writeString_(_)", system_writeString);

  // While bootstrapping the core types and running the core module, a number
  // of string objects have been created, many of which were instantiated
  // before stringClass was stored in the VM. Some of them *must* be created
  // first -- the ObjClass for string itself has a reference to the ObjString
  // for its name.
  //
  // These all currently have a NULL classObj pointer, so walk the VM's object
  // list and assign the class now that it is known.
  Obj* object = vm->first;
  while (object != NULL)
  {
    if (object->type == OBJ_STRING) object->classObj = vm->stringClass;
    object = object->next;
  }
}
|
|
// End file "wren_core.c"
|
|
// Begin file "wren_value.c"
|
|
#include <math.h>
|
|
#include <stdarg.h>
|
|
#include <stdio.h>
|
|
#include <string.h>
|
|
|
|
|
|
#if WREN_DEBUG_TRACE_MEMORY
|
|
#endif
|
|
|
|
// TODO: Tune these.

// Starting capacity of a non-empty list or map object; capacities never go
// below this value either.
#define MIN_CAPACITY 16

// Multiplier applied to a collection's capacity whenever its size outgrows
// it. Growth has to be geometric (multiplicative, not additive) for adding
// to a collection to cost O(1) amortized.
#define GROW_FACTOR 2

// Highest percentage of a map's entries that may be occupied before the map
// is grown. Lowering it costs memory but means fewer collisions and therefore
// faster lookups.
#define MAP_LOAD_PERCENT 75

// How many call frames a freshly created fiber is given. A smaller number
// keeps new fibers lighter at the price of more reallocation as the call
// stack deepens.
#define INITIAL_CALL_FRAMES 4

// Instantiate the buffer functions for Value and Method elements.
DEFINE_BUFFER(Value, Value);
DEFINE_BUFFER(Method, Method);
|
|
|
|
// Fills in the common object header of [obj] (type, cleared dark flag and
// class) and pushes it onto the front of the VM's linked list of objects.
static void initObj(WrenVM* vm, Obj* obj, ObjType type, ObjClass* classObj)
{
  // Header fields.
  obj->classObj = classObj;
  obj->type = type;
  obj->isDark = false;

  // Prepend to the VM's object list.
  obj->next = vm->first;
  vm->first = obj;
}
|
|
|
|
// Allocates a class object called [name] with [numFields] fields. The result
// has no superclass, no attributes, an empty method buffer and a NULL class
// pointer (i.e. no metaclass); callers wire those up afterwards.
ObjClass* wrenNewSingleClass(WrenVM* vm, int numFields, ObjString* name)
{
  ObjClass* created = ALLOCATE(vm, ObjClass);
  initObj(vm, &created->obj, OBJ_CLASS, NULL);

  created->name = name;
  created->numFields = numFields;
  created->superclass = NULL;
  created->attributes = NULL_VAL;

  // The class is rooted while its method buffer is set up.
  wrenPushRoot(vm, (Obj*)created);
  wrenMethodBufferInit(&created->methods);
  wrenPopRoot(vm);

  return created;
}
|
|
|
|
// Makes [superclass] the superclass of [subclass]: records the link, adds the
// inherited field count to the subclass, and copies every method slot of the
// superclass into the subclass.
void wrenBindSuperclass(WrenVM* vm, ObjClass* subclass, ObjClass* superclass)
{
  ASSERT(superclass != NULL, "Must have superclass.");

  subclass->superclass = superclass;

  if (subclass->numFields == -1)
  {
    // A field count of -1 marks a foreign class, which has no room for
    // inherited fields.
    ASSERT(superclass->numFields == 0,
           "A foreign class cannot inherit from a class with fields.");
  }
  else
  {
    // The subclass's total includes the fields of its superclass.
    subclass->numFields += superclass->numFields;
  }

  // Copy down the superclass's methods, slot by slot.
  int inherited = superclass->methods.count;
  for (int symbol = 0; symbol < inherited; symbol++)
  {
    wrenBindMethod(vm, subclass, symbol, superclass->methods.data[symbol]);
  }
}
|
|
|
|
ObjClass* wrenNewClass(WrenVM* vm, ObjClass* superclass, int numFields,
|
|
ObjString* name)
|
|
{
|
|
// Create the metaclass.
|
|
Value metaclassName = wrenStringFormat(vm, "@ metaclass", OBJ_VAL(name));
|
|
wrenPushRoot(vm, AS_OBJ(metaclassName));
|
|
|
|
ObjClass* metaclass = wrenNewSingleClass(vm, 0, AS_STRING(metaclassName));
|
|
metaclass->obj.classObj = vm->classClass;
|
|
|
|
wrenPopRoot(vm);
|
|
|
|
// Make sure the metaclass isn't collected when we allocate the class.
|
|
wrenPushRoot(vm, (Obj*)metaclass);
|
|
|
|
// Metaclasses always inherit Class and do not parallel the non-metaclass
|
|
// hierarchy.
|
|
wrenBindSuperclass(vm, metaclass, vm->classClass);
|
|
|
|
ObjClass* classObj = wrenNewSingleClass(vm, numFields, name);
|
|
|
|
// Make sure the class isn't collected while the inherited methods are being
|
|
// bound.
|
|
wrenPushRoot(vm, (Obj*)classObj);
|
|
|
|
classObj->obj.classObj = metaclass;
|
|
wrenBindSuperclass(vm, classObj, superclass);
|
|
|
|
wrenPopRoot(vm);
|
|
wrenPopRoot(vm);
|
|
|
|
return classObj;
|
|
}
|
|
|
|
// Stores [method] in slot [symbol] of [classObj]'s method buffer, first
// padding the buffer with METHOD_NONE entries if the slot does not exist yet.
void wrenBindMethod(WrenVM* vm, ObjClass* classObj, int symbol, Method method)
{
  int known = classObj->methods.count;

  // Grow the buffer far enough for [symbol] to be a valid index.
  if (symbol >= known)
  {
    Method placeholder;
    placeholder.type = METHOD_NONE;

    int missing = symbol - known + 1;
    wrenMethodBufferFill(vm, &classObj->methods, placeholder, missing);
  }

  classObj->methods.data[symbol] = method;
}
|
|
|
|
// Allocates a closure over [fn] with one upvalue slot per upvalue of the
// function. Every slot starts out NULL.
ObjClosure* wrenNewClosure(WrenVM* vm, ObjFn* fn)
{
  ObjClosure* created = ALLOCATE_FLEX(vm, ObjClosure,
                                      ObjUpvalue*, fn->numUpvalues);
  initObj(vm, &created->obj, OBJ_CLOSURE, vm->fnClass);

  created->fn = fn;

  // Null out the upvalue slots: a GC may run after the closure exists but
  // before its upvalues have been filled in.
  int slot = 0;
  while (slot < fn->numUpvalues)
  {
    created->upvalues[slot] = NULL;
    slot++;
  }

  return created;
}
|
|
|
|
// Creates a fiber that will run [closure]. With a NULL closure the fiber gets
// a one-slot stack and no call frame; otherwise its first call frame is set
// up and the closure is placed in stack slot zero.
ObjFiber* wrenNewFiber(WrenVM* vm, ObjClosure* closure)
{
  // Both arrays are allocated before the fiber object itself, in case an
  // allocation triggers a GC.
  CallFrame* frames = ALLOCATE_ARRAY(vm, CallFrame, INITIAL_CALL_FRAMES);

  // The stack needs one slot more than the function asks for: the compiler
  // assumes every function has an (unused) implicit receiver slot.
  int stackCapacity;
  if (closure == NULL)
  {
    stackCapacity = 1;
  }
  else
  {
    stackCapacity = wrenPowerOf2Ceil(closure->fn->maxSlots + 1);
  }
  Value* stack = ALLOCATE_ARRAY(vm, Value, stackCapacity);

  ObjFiber* created = ALLOCATE(vm, ObjFiber);
  initObj(vm, &created->obj, OBJ_FIBER, vm->fiberClass);

  // Value stack, initially empty.
  created->stack = stack;
  created->stackTop = stack;
  created->stackCapacity = stackCapacity;

  // Call frames, initially none.
  created->frames = frames;
  created->frameCapacity = INITIAL_CALL_FRAMES;
  created->numFrames = 0;

  // Remaining bookkeeping.
  created->openUpvalues = NULL;
  created->caller = NULL;
  created->error = NULL_VAL;
  created->state = FIBER_OTHER;

  if (closure == NULL) return created;

  // Set up the first call frame, starting at the bottom of the stack.
  wrenAppendCallFrame(vm, created, closure, created->stack);

  // Slot zero always holds the closure being run.
  *created->stackTop = OBJ_VAL(closure);
  created->stackTop++;

  return created;
}
|
|
|
|
// Makes sure [fiber]'s value stack can hold at least [needed] slots, growing
// it to the next power of two when it cannot. If growing moves the stack in
// memory, every pointer into the old stack (vm->apiStack, each call frame's
// stackStart, every open upvalue and the stack top) is rebased.
void wrenEnsureStack(WrenVM* vm, ObjFiber* fiber, int needed)
{
  // Already big enough.
  if (needed <= fiber->stackCapacity) return;

  int newCapacity = wrenPowerOf2Ceil(needed);
  Value* previous = fiber->stack;

  fiber->stack = (Value*)wrenReallocate(vm, previous,
                                        sizeof(Value) * fiber->stackCapacity,
                                        sizeof(Value) * newCapacity);
  fiber->stackCapacity = newCapacity;

  // Grown in place: all existing pointers are still good.
  if (fiber->stack == previous) return;

  // The stack moved, so every pointer into the old block has to be redirected
  // to the same relative position in the new block. Pointer subtraction is
  // only well-defined within a single array, which is why each adjustment is
  // phrased as "new base + (old pointer - old base)".

  // The API stack, when it points into this fiber's stack. Note that
  // fiber->stackTop still refers to the old block at this point.
  if (vm->apiStack >= previous && vm->apiStack <= fiber->stackTop)
  {
    vm->apiStack = fiber->stack + (vm->apiStack - previous);
  }

  // The stack window of every active call frame.
  for (int f = 0; f < fiber->numFrames; f++)
  {
    CallFrame* frame = &fiber->frames[f];
    frame->stackStart = fiber->stack + (frame->stackStart - previous);
  }

  // Every open upvalue.
  ObjUpvalue* open = fiber->openUpvalues;
  while (open != NULL)
  {
    open->value = fiber->stack + (open->value - previous);
    open = open->next;
  }

  // Finally the stack top itself.
  fiber->stackTop = fiber->stack + (fiber->stackTop - previous);
}
|
|
|
|
// Allocates a foreign object of class [classObj] carrying [size] bytes of
// payload, all initialised to zero.
ObjForeign* wrenNewForeign(WrenVM* vm, ObjClass* classObj, size_t size)
{
  ObjForeign* foreign = ALLOCATE_FLEX(vm, ObjForeign, uint8_t, size);
  initObj(vm, &foreign->obj, OBJ_FOREIGN, classObj);

  // Hand the payload out zero-filled.
  memset(foreign->data, 0, size);

  return foreign;
}
|
|
|
|
// Creates a new function object with empty code and constant buffers, no
// upvalues, and an arity of zero. [module] and [maxSlots] are stored on it
// as given. Its debug record starts with no name and an empty line buffer.
ObjFn* wrenNewFunction(WrenVM* vm, ObjModule* module, int maxSlots)
{
  // The debug record is allocated before the function object itself.
  FnDebug* debugInfo = ALLOCATE(vm, FnDebug);
  debugInfo->name = NULL;
  wrenIntBufferInit(&debugInfo->sourceLines);

  ObjFn* function = ALLOCATE(vm, ObjFn);
  initObj(vm, &function->obj, OBJ_FN, vm->fnClass);

  wrenValueBufferInit(&function->constants);
  wrenByteBufferInit(&function->code);

  function->debug = debugInfo;
  function->module = module;
  function->maxSlots = maxSlots;
  function->arity = 0;
  function->numUpvalues = 0;

  return function;
}
|
|
|
|
// Stores a null-terminated copy of the first [length] bytes of [name] as the
// debug name of [fn].
void wrenFunctionBindName(WrenVM* vm, ObjFn* fn, const char* name, int length)
{
  // One extra byte for the terminator.
  char* copy = ALLOCATE_ARRAY(vm, char, length + 1);
  fn->debug->name = copy;

  memcpy(copy, name, length);
  copy[length] = '\0';
}
|
|
|
|
// Creates a new instance of [classObj] with one field slot per field the
// class declares. Every field starts out as null.
Value wrenNewInstance(WrenVM* vm, ObjClass* classObj)
{
  ObjInstance* created = ALLOCATE_FLEX(vm, ObjInstance,
                                       Value, classObj->numFields);
  initObj(vm, &created->obj, OBJ_INSTANCE, classObj);

  // Null out every field.
  int field = 0;
  while (field < classObj->numFields)
  {
    created->fields[field] = NULL_VAL;
    field++;
  }

  return OBJ_VAL(created);
}
|
|
|
|
// Creates a new list whose count and capacity are both [numElements]. The
// element storage is allocated but its contents are not initialized here.
ObjList* wrenNewList(WrenVM* vm, uint32_t numElements)
{
  // The element array is allocated before the list object: allocating it
  // afterwards could trigger a GC that frees the not-yet-rooted list.
  Value* storage = NULL;
  if (numElements > 0) storage = ALLOCATE_ARRAY(vm, Value, numElements);

  ObjList* created = ALLOCATE(vm, ObjList);
  initObj(vm, &created->obj, OBJ_LIST, vm->listClass);

  created->elements.data = storage;
  created->elements.count = numElements;
  created->elements.capacity = numElements;

  return created;
}
|
|
|
|
// Inserts [value] into [list] at [index], moving the elements at and after
// that position one slot towards the end.
void wrenListInsert(WrenVM* vm, ObjList* list, Value value, uint32_t index)
{
  // Growing the buffer may allocate, so keep an object value rooted while
  // that happens.
  if (IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value));

  // Extend the list by one placeholder slot.
  wrenValueBufferWrite(vm, &list->elements, NULL_VAL);

  if (IS_OBJ(value)) wrenPopRoot(vm);

  // Walk backwards from the new last slot, pulling each element over from
  // its predecessor until the hole sits at [index].
  uint32_t slot = list->elements.count - 1;
  while (slot > index)
  {
    list->elements.data[slot] = list->elements.data[slot - 1];
    slot--;
  }

  // Drop the new element into the hole.
  list->elements.data[index] = value;
}
|
|
|
|
// Returns the index of the first element of [list] that is equal to [value]
// according to wrenValuesEqual(), or -1 if no element matches.
int wrenListIndexOf(WrenVM* vm, ObjList* list, Value value)
{
  int total = list->elements.count;
  int position = 0;

  while (position < total)
  {
    if (wrenValuesEqual(list->elements.data[position], value))
    {
      return position;
    }

    position++;
  }

  return -1;
}
|
|
|
|
// Removes the element at [index] from [list] and returns it. Later elements
// move one slot towards the front, and the storage is shrunk when enough of
// it is unused.
Value wrenListRemoveAt(WrenVM* vm, ObjList* list, uint32_t index)
{
  Value taken = list->elements.data[index];

  // Shrinking below may reallocate, so keep a removed object rooted until
  // that is done.
  if (IS_OBJ(taken)) wrenPushRoot(vm, AS_OBJ(taken));

  // Close the gap by pulling each later element one slot forward.
  int slot = index;
  while (slot < list->elements.count - 1)
  {
    list->elements.data[slot] = list->elements.data[slot + 1];
    slot++;
  }

  // Give memory back when the reduced capacity would still hold the current
  // count.
  if (list->elements.capacity / GROW_FACTOR >= list->elements.count)
  {
    list->elements.data = (Value*)wrenReallocate(vm, list->elements.data,
        sizeof(Value) * list->elements.capacity,
        sizeof(Value) * (list->elements.capacity / GROW_FACTOR));
    list->elements.capacity /= GROW_FACTOR;
  }

  if (IS_OBJ(taken)) wrenPopRoot(vm);

  list->elements.count--;
  return taken;
}
|
|
|
|
// Creates a new empty map. No entry array is allocated; capacity and count
// both start at zero.
ObjMap* wrenNewMap(WrenVM* vm)
{
  ObjMap* created = ALLOCATE(vm, ObjMap);
  initObj(vm, &created->obj, OBJ_MAP, vm->mapClass);

  created->entries = NULL;
  created->count = 0;
  created->capacity = 0;

  return created;
}
|
|
|
|
// Mixes the 64 bits of [hash] down to a 30-bit hash code.
//
// From v8's ComputeLongHash() which in turn cites:
// Thomas Wang, Integer Hash Functions.
// http://www.concentric.net/~Ttwang/tech/inthash.htm
static inline uint32_t hashBits(uint64_t hash)
{
  hash = ~hash + (hash << 18);  // Same as (hash << 18) - hash - 1.
  hash ^= hash >> 31;
  hash *= 21;                   // Same as (hash + (hash << 2)) + (hash << 4).
  hash ^= hash >> 11;
  hash += hash << 6;
  hash ^= hash >> 22;

  // Only the low 30 bits are kept.
  return (uint32_t)(hash & 0x3fffffff);
}
|
|
|
|
// Generates a hash code for [num] by hashing the bit pattern that
// wrenDoubleToBits() returns for it.
static inline uint32_t hashNumber(double num)
{
  uint64_t bits = wrenDoubleToBits(num);
  return hashBits(bits);
}
|
|
|
|
// Generates a hash code for [object].
|
|
static uint32_t hashObject(Obj* object)
|
|
{
|
|
switch (object->type)
|
|
{
|
|
case OBJ_CLASS:
|
|
// Classes just use their name.
|
|
return hashObject((Obj*)((ObjClass*)object)->name);
|
|
|
|
// Allow bare (non-closure) functions so that we can use a map to find
|
|
// existing constants in a function's constant table. This is only used
|
|
// internally. Since user code never sees a non-closure function, they
|
|
// cannot use them as map keys.
|
|
case OBJ_FN:
|
|
{
|
|
ObjFn* fn = (ObjFn*)object;
|
|
return hashNumber(fn->arity) ^ hashNumber(fn->code.count);
|
|
}
|
|
|
|
case OBJ_RANGE:
|
|
{
|
|
ObjRange* range = (ObjRange*)object;
|
|
return hashNumber(range->from) ^ hashNumber(range->to);
|
|
}
|
|
|
|
case OBJ_STRING:
|
|
return ((ObjString*)object)->hash;
|
|
|
|
default:
|
|
ASSERT(false, "Only immutable objects can be hashed.");
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
// Generates a hash code for [value], which must be one of the built-in
|
|
// immutable types: null, bool, class, num, range, or string.
|
|
static uint32_t hashValue(Value value)
|
|
{
|
|
// TODO: We'll probably want to randomize this at some point.
|
|
|
|
#if WREN_NAN_TAGGING
|
|
if (IS_OBJ(value)) return hashObject(AS_OBJ(value));
|
|
|
|
// Hash the raw bits of the unboxed value.
|
|
return hashBits(value);
|
|
#else
|
|
switch (value.type)
|
|
{
|
|
case VAL_FALSE: return 0;
|
|
case VAL_NULL: return 1;
|
|
case VAL_NUM: return hashNumber(AS_NUM(value));
|
|
case VAL_TRUE: return 2;
|
|
case VAL_OBJ: return hashObject(AS_OBJ(value));
|
|
default: UNREACHABLE();
|
|
}
|
|
|
|
return 0;
|
|
#endif
|
|
}
|
|
|
|
// Looks for an entry with [key] in an array of [capacity] [entries].
|
|
//
|
|
// If found, sets [result] to point to it and returns `true`. Otherwise,
|
|
// returns `false` and points [result] to the entry where the key/value pair
|
|
// should be inserted.
|
|
static bool findEntry(MapEntry* entries, uint32_t capacity, Value key,
|
|
MapEntry** result)
|
|
{
|
|
// If there is no entry array (an empty map), we definitely won't find it.
|
|
if (capacity == 0) return false;
|
|
|
|
// Figure out where to insert it in the table. Use open addressing and
|
|
// basic linear probing.
|
|
uint32_t startIndex = hashValue(key) % capacity;
|
|
uint32_t index = startIndex;
|
|
|
|
// If we pass a tombstone and don't end up finding the key, its entry will
|
|
// be re-used for the insert.
|
|
MapEntry* tombstone = NULL;
|
|
|
|
// Walk the probe sequence until we've tried every slot.
|
|
do
|
|
{
|
|
MapEntry* entry = &entries[index];
|
|
|
|
if (IS_UNDEFINED(entry->key))
|
|
{
|
|
// If we found an empty slot, the key is not in the table. If we found a
|
|
// slot that contains a deleted key, we have to keep looking.
|
|
if (IS_FALSE(entry->value))
|
|
{
|
|
// We found an empty slot, so we've reached the end of the probe
|
|
// sequence without finding the key. If we passed a tombstone, then
|
|
// that's where we should insert the item, otherwise, put it here at
|
|
// the end of the sequence.
|
|
*result = tombstone != NULL ? tombstone : entry;
|
|
return false;
|
|
}
|
|
else
|
|
{
|
|
// We found a tombstone. We need to keep looking in case the key is
|
|
// after it, but we'll use this entry as the insertion point if the
|
|
// key ends up not being found.
|
|
if (tombstone == NULL) tombstone = entry;
|
|
}
|
|
}
|
|
else if (wrenValuesEqual(entry->key, key))
|
|
{
|
|
// We found the key.
|
|
*result = entry;
|
|
return true;
|
|
}
|
|
|
|
// Try the next slot.
|
|
index = (index + 1) % capacity;
|
|
}
|
|
while (index != startIndex);
|
|
|
|
// If we get here, the table is full of tombstones. Return the first one we
|
|
// found.
|
|
ASSERT(tombstone != NULL, "Map should have tombstones or empty entries.");
|
|
*result = tombstone;
|
|
return false;
|
|
}
|
|
|
|
// Inserts [key] and [value] in the array of [entries] with the given
|
|
// [capacity].
|
|
//
|
|
// Returns `true` if this is the first time [key] was added to the map.
|
|
static bool insertEntry(MapEntry* entries, uint32_t capacity,
|
|
Value key, Value value)
|
|
{
|
|
ASSERT(entries != NULL, "Should ensure capacity before inserting.");
|
|
|
|
MapEntry* entry;
|
|
if (findEntry(entries, capacity, key, &entry))
|
|
{
|
|
// Already present, so just replace the value.
|
|
entry->value = value;
|
|
return false;
|
|
}
|
|
else
|
|
{
|
|
entry->key = key;
|
|
entry->value = value;
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// Updates [map]'s entry array to [capacity].
|
|
static void resizeMap(WrenVM* vm, ObjMap* map, uint32_t capacity)
|
|
{
|
|
// Create the new empty hash table.
|
|
MapEntry* entries = ALLOCATE_ARRAY(vm, MapEntry, capacity);
|
|
for (uint32_t i = 0; i < capacity; i++)
|
|
{
|
|
entries[i].key = UNDEFINED_VAL;
|
|
entries[i].value = FALSE_VAL;
|
|
}
|
|
|
|
// Re-add the existing entries.
|
|
if (map->capacity > 0)
|
|
{
|
|
for (uint32_t i = 0; i < map->capacity; i++)
|
|
{
|
|
MapEntry* entry = &map->entries[i];
|
|
|
|
// Don't copy empty entries or tombstones.
|
|
if (IS_UNDEFINED(entry->key)) continue;
|
|
|
|
insertEntry(entries, capacity, entry->key, entry->value);
|
|
}
|
|
}
|
|
|
|
// Replace the array.
|
|
DEALLOCATE(vm, map->entries);
|
|
map->entries = entries;
|
|
map->capacity = capacity;
|
|
}
|
|
|
|
// Returns the value stored under [key] in [map], or UNDEFINED_VAL when the
// key is not present.
Value wrenMapGet(ObjMap* map, Value key)
{
  MapEntry* found;
  if (!findEntry(map->entries, map->capacity, key, &found))
  {
    return UNDEFINED_VAL;
  }

  return found->value;
}
|
|
|
|
// Associates [key] with [value] in [map], growing the entry array first when
// one more entry would push the map past its load limit.
void wrenMapSet(WrenVM* vm, ObjMap* map, Value key, Value value)
{
  // Make room before inserting if the table is getting too full.
  if (map->count + 1 > map->capacity * MAP_LOAD_PERCENT / 100)
  {
    // Grow by the growth factor, but never go below the minimum size.
    uint32_t grown = map->capacity * GROW_FACTOR;
    if (grown < MIN_CAPACITY) grown = MIN_CAPACITY;

    resizeMap(vm, map, grown);
  }

  // Only a brand-new key changes the entry count.
  bool added = insertEntry(map->entries, map->capacity, key, value);
  if (added) map->count++;
}
|
|
|
|
// Removes every entry from [map] by freeing its entry array and resetting
// the count and capacity to zero.
void wrenMapClear(WrenVM* vm, ObjMap* map)
{
  DEALLOCATE(vm, map->entries);

  map->entries = NULL;
  map->count = 0;
  map->capacity = 0;
}
|
|
|
|
// Removes [key] from [map] and returns the value that was stored under it,
// or null if the key was not present. The entry array is freed when the map
// becomes empty and shrunk when it becomes sparse.
Value wrenMapRemoveKey(WrenVM* vm, ObjMap* map, Value key)
{
  MapEntry* found;
  if (!findEntry(map->entries, map->capacity, key, &found)) return NULL_VAL;

  // Turn the entry into a tombstone: an undefined key with a `true` value.
  // Lookups stop at never-used slots but keep probing past tombstones, so
  // keys that collided with this one remain reachable.
  Value removed = found->value;
  found->key = UNDEFINED_VAL;
  found->value = TRUE_VAL;

  // The resize below may allocate, so keep a removed object rooted until the
  // map has settled.
  if (IS_OBJ(removed)) wrenPushRoot(vm, AS_OBJ(removed));

  map->count--;

  if (map->count == 0)
  {
    // That was the last item, so release the whole array.
    wrenMapClear(vm, map);
  }
  else if (map->capacity > MIN_CAPACITY &&
           map->count < map->capacity / GROW_FACTOR * MAP_LOAD_PERCENT / 100)
  {
    // The map has become sparse, so shrink the entry array back down,
    // without going below the minimum size.
    // TODO: Should we do this less aggressively than we grow?
    uint32_t reduced = map->capacity / GROW_FACTOR;
    if (reduced < MIN_CAPACITY) reduced = MIN_CAPACITY;

    resizeMap(vm, map, reduced);
  }

  if (IS_OBJ(removed)) wrenPopRoot(vm);
  return removed;
}
|
|
|
|
// Creates a new module called [name] with an empty variable table.
ObjModule* wrenNewModule(WrenVM* vm, ObjString* name)
{
  ObjModule* created = ALLOCATE(vm, ObjModule);

  // Modules are never used as first-class objects, so they get no class.
  initObj(vm, (Obj*)created, OBJ_MODULE, NULL);

  // The module is kept rooted while the rest of it is set up.
  wrenPushRoot(vm, (Obj*)created);

  wrenSymbolTableInit(&created->variableNames);
  wrenValueBufferInit(&created->variables);
  created->name = name;

  wrenPopRoot(vm);
  return created;
}
|
|
|
|
// Creates a new range object running from [from] to [to]. [isInclusive] is
// stored on the range as given.
Value wrenNewRange(WrenVM* vm, double from, double to, bool isInclusive)
{
  ObjRange* created = ALLOCATE(vm, ObjRange);
  initObj(vm, &created->obj, OBJ_RANGE, vm->rangeClass);

  created->isInclusive = isInclusive;
  created->from = from;
  created->to = to;

  return OBJ_VAL(created);
}
|
|
|
|
// Creates a new string object with a null-terminated buffer large enough to
|
|
// hold a string of [length] but does not fill in the bytes.
|
|
//
|
|
// The caller is expected to fill in the buffer and then calculate the string's
|
|
// hash.
|
|
static ObjString* allocateString(WrenVM* vm, size_t length)
|
|
{
|
|
ObjString* string = ALLOCATE_FLEX(vm, ObjString, char, length + 1);
|
|
initObj(vm, &string->obj, OBJ_STRING, vm->stringClass);
|
|
string->length = (int)length;
|
|
string->value[length] = '\0';
|
|
|
|
return string;
|
|
}
|
|
|
|
// Calculates and stores the hash code for [string].
|
|
static void hashString(ObjString* string)
|
|
{
|
|
// FNV-1a hash. See: http://www.isthe.com/chongo/tech/comp/fnv/
|
|
uint32_t hash = 2166136261u;
|
|
|
|
// This is O(n) on the length of the string, but we only call this when a new
|
|
// string is created. Since the creation is also O(n) (to copy/initialize all
|
|
// the bytes), we allow this here.
|
|
for (uint32_t i = 0; i < string->length; i++)
|
|
{
|
|
hash ^= string->value[i];
|
|
hash *= 16777619;
|
|
}
|
|
|
|
string->hash = hash;
|
|
}
|
|
|
|
// Creates a new string object from the null-terminated C string [text].
Value wrenNewString(WrenVM* vm, const char* text)
{
  size_t length = strlen(text);
  return wrenNewStringLength(vm, text, length);
}
|
|
|
|
// Creates a new string object holding a copy of the first [length] bytes of
// [text]. [text] may be NULL only when [length] is zero.
Value wrenNewStringLength(WrenVM* vm, const char* text, size_t length)
{
  // NULL is tolerated for an empty string because byte buffers don't
  // allocate any characters for a zero-length string.
  ASSERT(length == 0 || text != NULL, "Unexpected NULL string.");

  ObjString* created = allocateString(vm, length);

  // Copy the bytes over, if there are any to copy.
  if (text != NULL && length > 0)
  {
    memcpy(created->value, text, length);
  }

  hashString(created);
  return OBJ_VAL(created);
}
|
|
|
|
|
|
// Creates a new string from [count] code points of [source], taken at byte
// offsets [start], [start] + [step], [start] + 2 * [step], and so on.
// Offsets that do not decode to a code point contribute nothing.
Value wrenNewStringFromRange(WrenVM* vm, ObjString* source, int start,
                             uint32_t count, int step)
{
  uint8_t* input = (uint8_t*)source->value;

  // First pass: add up the encoded size of every selected code point so the
  // result can be allocated in one go.
  int byteLength = 0;
  for (uint32_t i = 0; i < count; i++)
  {
    byteLength += wrenUtf8DecodeNumBytes(input[start + i * step]);
  }

  ObjString* created = allocateString(vm, byteLength);
  created->value[byteLength] = '\0';

  // Second pass: decode each selected code point and re-encode it into the
  // result, skipping offsets that fail to decode.
  uint8_t* output = (uint8_t*)created->value;
  for (uint32_t i = 0; i < count; i++)
  {
    int offset = start + i * step;
    int decoded = wrenUtf8Decode(input + offset, source->length - offset);
    if (decoded == -1) continue;

    output += wrenUtf8Encode(decoded, output);
  }

  hashString(created);
  return OBJ_VAL(created);
}
|
|
|
|
// Creates a string holding the textual form of [value]: "nan", "infinity",
// or "-infinity" for the non-finite cases, and "%.14g" formatting otherwise.
Value wrenNumToString(WrenVM* vm, double value)
{
  // Different versions of libc format NaN and infinity differently (some
  // include a sign and some don't), so those cases are spelled out here to
  // get reliable output.
  if (isnan(value)) return CONST_STRING(vm, "nan");

  if (isinf(value))
  {
    if (value > 0.0) return CONST_STRING(vm, "infinity");
    return CONST_STRING(vm, "-infinity");
  }

  // Sized to hold any double formatted with "%.14g", for example:
  //
  //     -1.12345678901234e-1022
  //
  // which budgets a sign, a leading digit, the ".", 14 decimal digits, "e",
  // the exponent's sign, 4 exponent digits, and the terminating "\0":
  // 1 + 1 + 1 + 14 + 1 + 1 + 4 + 1 = 24.
  char text[24];
  int written = sprintf(text, "%.14g", value);
  return wrenNewStringLength(vm, text, written);
}
|
|
|
|
// Creates a new string containing the UTF-8 encoding of the code point
// [value]. Asserts that the code point is encodable.
Value wrenStringFromCodePoint(WrenVM* vm, int value)
{
  int encodedSize = wrenUtf8EncodeNumBytes(value);
  ASSERT(encodedSize != 0, "Value out of range.");

  ObjString* created = allocateString(vm, encodedSize);
  wrenUtf8Encode(value, (uint8_t*)created->value);

  hashString(created);
  return OBJ_VAL(created);
}
|
|
|
|
// Creates a new one-byte string whose only byte is [value].
Value wrenStringFromByte(WrenVM *vm, uint8_t value)
{
  ObjString* created = allocateString(vm, 1);
  created->value[0] = value;

  hashString(created);
  return OBJ_VAL(created);
}
|
|
|
|
// Builds a new string from [format] and the variadic arguments that follow.
//
// Each "$" in [format] is replaced by the next argument taken as a
// null-terminated C string, and each "@" by the next argument taken as a
// Value holding a string object. Every other character is copied literally.
Value wrenStringFormat(WrenVM* vm, const char* format, ...)
{
  va_list args;

  // First pass: measure the result so the final string can be created with a
  // single allocation.
  size_t totalLength = 0;
  va_start(args, format);
  for (const char* c = format; *c != '\0'; c++)
  {
    if (*c == '$')
    {
      totalLength += strlen(va_arg(args, const char*));
    }
    else if (*c == '@')
    {
      totalLength += AS_STRING(va_arg(args, Value))->length;
    }
    else
    {
      // A literal character takes exactly one byte.
      totalLength++;
    }
  }
  va_end(args);

  ObjString* created = allocateString(vm, totalLength);

  // Second pass: walk the format again and write the pieces one after
  // another into the result.
  char* output = created->value;
  va_start(args, format);
  for (const char* c = format; *c != '\0'; c++)
  {
    if (*c == '$')
    {
      const char* text = va_arg(args, const char*);
      size_t textLength = strlen(text);
      memcpy(output, text, textLength);
      output += textLength;
    }
    else if (*c == '@')
    {
      ObjString* text = AS_STRING(va_arg(args, Value));
      memcpy(output, text->value, text->length);
      output += text->length;
    }
    else
    {
      // Literal character.
      *output++ = *c;
    }
  }
  va_end(args);

  hashString(created);
  return OBJ_VAL(created);
}
|
|
|
|
// Creates a new string holding the code point that starts at byte [index]
// of [string]. If the bytes there are not a valid UTF-8 sequence, the result
// is a one-byte string holding just the byte at [index].
Value wrenStringCodePointAt(WrenVM* vm, ObjString* string, uint32_t index)
{
  ASSERT(index < string->length, "Index out of bounds.");

  uint8_t* at = (uint8_t*)string->value + index;
  int decoded = wrenUtf8Decode(at, string->length - index);

  if (decoded != -1) return wrenStringFromCodePoint(vm, decoded);

  // Not valid UTF-8 here, so fall back to the single raw byte.
  char raw[2];
  raw[0] = string->value[index];
  raw[1] = '\0';
  return wrenNewStringLength(vm, raw, 1);
}
|
|
|
|
// Uses the Boyer-Moore-Horspool string matching algorithm.
|
|
uint32_t wrenStringFind(ObjString* haystack, ObjString* needle, uint32_t start)
|
|
{
|
|
// Edge case: An empty needle is always found.
|
|
if (needle->length == 0) return start;
|
|
|
|
// If the needle goes past the haystack it won't be found.
|
|
if (start + needle->length > haystack->length) return UINT32_MAX;
|
|
|
|
// If the startIndex is too far it also won't be found.
|
|
if (start >= haystack->length) return UINT32_MAX;
|
|
|
|
// Pre-calculate the shift table. For each character (8-bit value), we
|
|
// determine how far the search window can be advanced if that character is
|
|
// the last character in the haystack where we are searching for the needle
|
|
// and the needle doesn't match there.
|
|
uint32_t shift[UINT8_MAX];
|
|
uint32_t needleEnd = needle->length - 1;
|
|
|
|
// By default, we assume the character is not the needle at all. In that case
|
|
// case, if a match fails on that character, we can advance one whole needle
|
|
// width since.
|
|
for (uint32_t index = 0; index < UINT8_MAX; index++)
|
|
{
|
|
shift[index] = needle->length;
|
|
}
|
|
|
|
// Then, for every character in the needle, determine how far it is from the
|
|
// end. If a match fails on that character, we can advance the window such
|
|
// that it the last character in it lines up with the last place we could
|
|
// find it in the needle.
|
|
for (uint32_t index = 0; index < needleEnd; index++)
|
|
{
|
|
char c = needle->value[index];
|
|
shift[(uint8_t)c] = needleEnd - index;
|
|
}
|
|
|
|
// Slide the needle across the haystack, looking for the first match or
|
|
// stopping if the needle goes off the end.
|
|
char lastChar = needle->value[needleEnd];
|
|
uint32_t range = haystack->length - needle->length;
|
|
|
|
for (uint32_t index = start; index <= range; )
|
|
{
|
|
// Compare the last character in the haystack's window to the last character
|
|
// in the needle. If it matches, see if the whole needle matches.
|
|
char c = haystack->value[index + needleEnd];
|
|
if (lastChar == c &&
|
|
memcmp(haystack->value + index, needle->value, needleEnd) == 0)
|
|
{
|
|
// Found a match.
|
|
return index;
|
|
}
|
|
|
|
// Otherwise, slide the needle forward.
|
|
index += shift[(uint8_t)c];
|
|
}
|
|
|
|
// Not found.
|
|
return UINT32_MAX;
|
|
}
|
|
|
|
// Creates a new upvalue that points at [value]. Its closed-over storage
// starts as null and it is not linked to any other upvalue.
ObjUpvalue* wrenNewUpvalue(WrenVM* vm, Value* value)
{
  ObjUpvalue* created = ALLOCATE(vm, ObjUpvalue);

  // Upvalues are never used as first-class objects, so they get no class.
  initObj(vm, &created->obj, OBJ_UPVALUE, NULL);

  created->next = NULL;
  created->closed = NULL_VAL;
  created->value = value;

  return created;
}
|
|
|
|
// Marks [obj] as reached and queues it on the VM's gray list so that the
// objects it refers to can be explored later. Does nothing for NULL or for
// an object that is already marked.
void wrenGrayObj(WrenVM* vm, Obj* obj)
{
  // Skipping already-dark objects is what keeps reference cycles from
  // looping forever.
  if (obj == NULL || obj->isDark) return;

  // It's been reached.
  obj->isDark = true;

  // Grow the gray list when it is full. This goes straight to the
  // configured reallocate function.
  if (vm->grayCount >= vm->grayCapacity)
  {
    vm->grayCapacity = vm->grayCount * 2;
    vm->gray = (Obj**)vm->config.reallocateFn(vm->gray,
                                              vm->grayCapacity * sizeof(Obj*),
                                              vm->config.userData);
  }

  // Queue it for blackening.
  vm->gray[vm->grayCount] = obj;
  vm->grayCount++;
}
|
|
|
|
// Grays the object held by [value]. Values that do not hold an object are
// ignored.
void wrenGrayValue(WrenVM* vm, Value value)
{
  if (IS_OBJ(value))
  {
    wrenGrayObj(vm, AS_OBJ(value));
  }
}
|
|
|
|
// Grays every value stored in [buffer].
void wrenGrayBuffer(WrenVM* vm, ValueBuffer* buffer)
{
  int position = 0;
  while (position < buffer->count)
  {
    wrenGrayValue(vm, buffer->data[position]);
    position++;
  }
}
|
|
|
|
// Grays everything [classObj] refers to (its metaclass, superclass, block
// methods, name, and non-null attributes) and adds its memory to the VM's
// running byte count.
static void blackenClass(WrenVM* vm, ObjClass* classObj)
{
  // The metaclass, then the superclass.
  wrenGrayObj(vm, (Obj*)classObj->obj.classObj);
  wrenGrayObj(vm, (Obj*)classObj->superclass);

  // Only block methods hold a closure that needs marking.
  for (int i = 0; i < classObj->methods.count; i++)
  {
    Method* method = &classObj->methods.data[i];
    if (method->type != METHOD_BLOCK) continue;

    wrenGrayObj(vm, (Obj*)method->as.closure);
  }

  wrenGrayObj(vm, (Obj*)classObj->name);

  if (!IS_NULL(classObj->attributes))
  {
    wrenGrayObj(vm, AS_OBJ(classObj->attributes));
  }

  // Keep track of how much memory is still in use.
  vm->bytesAllocated += sizeof(ObjClass);
  vm->bytesAllocated += classObj->methods.capacity * sizeof(Method);
}
|
|
|
|
// Grays the function and upvalues that [closure] refers to and adds its
// memory to the VM's running byte count.
static void blackenClosure(WrenVM* vm, ObjClosure* closure)
{
  // The wrapped function.
  wrenGrayObj(vm, (Obj*)closure->fn);

  // One upvalue per upvalue the function declares.
  int upvalueIndex = 0;
  while (upvalueIndex < closure->fn->numUpvalues)
  {
    wrenGrayObj(vm, (Obj*)closure->upvalues[upvalueIndex]);
    upvalueIndex++;
  }

  // Keep track of how much memory is still in use.
  vm->bytesAllocated += sizeof(ObjClosure);
  vm->bytesAllocated += sizeof(ObjUpvalue*) * closure->fn->numUpvalues;
}
|
|
|
|
// Grays everything [fiber] refers to (the closure of each call frame, every
// value on its stack, its open upvalues, its caller, and its error) and adds
// its memory to the VM's running byte count.
static void blackenFiber(WrenVM* vm, ObjFiber* fiber)
{
  // The closure being run by each call frame.
  for (int frame = 0; frame < fiber->numFrames; frame++)
  {
    wrenGrayObj(vm, (Obj*)fiber->frames[frame].closure);
  }

  // Every value between the bottom and the top of the stack.
  Value* slot = fiber->stack;
  while (slot < fiber->stackTop)
  {
    wrenGrayValue(vm, *slot);
    slot++;
  }

  // The chain of open upvalues.
  for (ObjUpvalue* open = fiber->openUpvalues; open != NULL; open = open->next)
  {
    wrenGrayObj(vm, (Obj*)open);
  }

  // The calling fiber and the stored error.
  wrenGrayObj(vm, (Obj*)fiber->caller);
  wrenGrayValue(vm, fiber->error);

  // Keep track of how much memory is still in use.
  vm->bytesAllocated += sizeof(ObjFiber);
  vm->bytesAllocated += fiber->frameCapacity * sizeof(CallFrame);
  vm->bytesAllocated += fiber->stackCapacity * sizeof(Value);
}
|
|
|
|
// Grays the constants and module that [fn] refers to and adds its memory to
// the VM's running byte count.
static void blackenFn(WrenVM* vm, ObjFn* fn)
{
  // The constant table.
  wrenGrayBuffer(vm, &fn->constants);

  // The owning module, in case it's been unloaded.
  wrenGrayObj(vm, (Obj*)fn->module);

  // Keep track of how much memory is still in use: the object itself, its
  // bytecode, and its constants.
  vm->bytesAllocated += sizeof(ObjFn);
  vm->bytesAllocated += sizeof(uint8_t) * fn->code.capacity;
  vm->bytesAllocated += sizeof(Value) * fn->constants.capacity;

  // The debug line number buffer, counted as one int per bytecode slot.
  vm->bytesAllocated += sizeof(int) * fn->code.capacity;
  // TODO: What about the function name?
}
|
|
|
|
// Blackens a foreign object. This currently marks nothing and counts no
// memory.
static void blackenForeign(WrenVM* vm, ObjForeign* foreign)
{
  // Neither parameter is used yet.
  (void)vm;
  (void)foreign;

  // TODO: Keep track of how much memory the foreign object uses. We can store
  // this in each foreign object, but it will balloon the size. We may not want
  // that much overhead. One option would be to let the foreign class register
  // a C function that returns a size for the object. That way the VM doesn't
  // always have to explicitly store it.
}
|
|
|
|
// Grays the class and every field of [instance] and adds its memory to the
// VM's running byte count.
static void blackenInstance(WrenVM* vm, ObjInstance* instance)
{
  wrenGrayObj(vm, (Obj*)instance->obj.classObj);

  // The class says how many fields there are to mark.
  int field = 0;
  while (field < instance->obj.classObj->numFields)
  {
    wrenGrayValue(vm, instance->fields[field]);
    field++;
  }

  // Keep track of how much memory is still in use.
  vm->bytesAllocated += sizeof(ObjInstance);
  vm->bytesAllocated += sizeof(Value) * instance->obj.classObj->numFields;
}
|
|
|
|
// Grays every element of [list] and adds its memory to the VM's running
// byte count.
static void blackenList(WrenVM* vm, ObjList* list)
{
  wrenGrayBuffer(vm, &list->elements);

  // Keep track of how much memory is still in use: the list object plus its
  // element storage at full capacity.
  vm->bytesAllocated += sizeof(ObjList);
  vm->bytesAllocated += sizeof(Value) * list->elements.capacity;
}
|
|
|
|
// Grays the key and value of every live entry in [map] and adds its memory
// to the VM's running byte count.
static void blackenMap(WrenVM* vm, ObjMap* map)
{
  // Slots with an undefined key hold nothing to mark.
  for (uint32_t slot = 0; slot < map->capacity; slot++)
  {
    MapEntry* entry = &map->entries[slot];
    if (!IS_UNDEFINED(entry->key))
    {
      wrenGrayValue(vm, entry->key);
      wrenGrayValue(vm, entry->value);
    }
  }

  // Keep track of how much memory is still in use.
  vm->bytesAllocated += sizeof(ObjMap);
  vm->bytesAllocated += sizeof(MapEntry) * map->capacity;
}
|
|
|
|
// Grays the top-level variables, variable names, and name of [module] and
// adds the module object's size to the VM's running byte count.
static void blackenModule(WrenVM* vm, ObjModule* module)
{
  // Top-level variable values.
  int variable = 0;
  while (variable < module->variables.count)
  {
    wrenGrayValue(vm, module->variables.data[variable]);
    variable++;
  }

  // Their names.
  wrenBlackenSymbolTable(vm, &module->variableNames);

  // The module's own name.
  wrenGrayObj(vm, (Obj*)module->name);

  // Keep track of how much memory is still in use.
  vm->bytesAllocated += sizeof(ObjModule);
}
|
|
|
|
// Blackens [range]. Nothing is marked here; only the range's size is added
// to the VM's running byte count.
static void blackenRange(WrenVM* vm, ObjRange* range)
{
  vm->bytesAllocated += sizeof(ObjRange);
}
|
|
|
|
// Blackens [string]. Nothing is marked here; only its size (the object, its
// bytes, and the terminator) is added to the VM's running byte count.
static void blackenString(WrenVM* vm, ObjString* string)
{
  vm->bytesAllocated += sizeof(ObjString) + string->length + 1;
}
|
|
|
|
// Grays the closed-over value of [upvalue] and adds its size to the VM's
// running byte count.
static void blackenUpvalue(WrenVM* vm, ObjUpvalue* upvalue)
{
  // The closed slot is marked unconditionally, in case the upvalue is closed.
  wrenGrayValue(vm, upvalue->closed);

  // Keep track of how much memory is still in use.
  vm->bytesAllocated += sizeof(ObjUpvalue);
}
|
|
|
|
// Dispatches [obj] to the blacken function that matches its type, which
// traverses the object's fields.
static void blackenObject(WrenVM* vm, Obj* obj)
{
#if WREN_DEBUG_TRACE_MEMORY
  printf("mark ");
  wrenDumpValue(OBJ_VAL(obj));
  printf(" @ %p\n", obj);
#endif

  switch (obj->type)
  {
    case OBJ_CLASS:
      blackenClass(vm, (ObjClass*)obj);
      break;

    case OBJ_CLOSURE:
      blackenClosure(vm, (ObjClosure*)obj);
      break;

    case OBJ_FIBER:
      blackenFiber(vm, (ObjFiber*)obj);
      break;

    case OBJ_FN:
      blackenFn(vm, (ObjFn*)obj);
      break;

    case OBJ_FOREIGN:
      blackenForeign(vm, (ObjForeign*)obj);
      break;

    case OBJ_INSTANCE:
      blackenInstance(vm, (ObjInstance*)obj);
      break;

    case OBJ_LIST:
      blackenList(vm, (ObjList*)obj);
      break;

    case OBJ_MAP:
      blackenMap(vm, (ObjMap*)obj);
      break;

    case OBJ_MODULE:
      blackenModule(vm, (ObjModule*)obj);
      break;

    case OBJ_RANGE:
      blackenRange(vm, (ObjRange*)obj);
      break;

    case OBJ_STRING:
      blackenString(vm, (ObjString*)obj);
      break;

    case OBJ_UPVALUE:
      blackenUpvalue(vm, (ObjUpvalue*)obj);
      break;
  }
}
|
|
|
|
// Drains the VM's gray list, blackening each object as it is taken off.
// Blackening may add further objects to the list; the loop runs until the
// list is empty.
void wrenBlackenObjects(WrenVM* vm)
{
  while (vm->grayCount > 0)
  {
    // Take the most recently added object off the end of the list.
    vm->grayCount--;
    Obj* next = vm->gray[vm->grayCount];

    blackenObject(vm, next);
  }
}
|
|
|
|
// Releases [obj] together with any auxiliary storage that objects of its
// type own. Foreign objects are handed to wrenFinalizeForeign() first.
void wrenFreeObj(WrenVM* vm, Obj* obj)
{
#if WREN_DEBUG_TRACE_MEMORY
  printf("free ");
  wrenDumpValue(OBJ_VAL(obj));
  printf(" @ %p\n", obj);
#endif

  // Free whatever the object owns besides its own allocation.
  switch (obj->type)
  {
    case OBJ_CLASS:
    {
      ObjClass* classObj = (ObjClass*)obj;
      wrenMethodBufferClear(vm, &classObj->methods);
      break;
    }

    case OBJ_FIBER:
    {
      ObjFiber* fiber = (ObjFiber*)obj;
      DEALLOCATE(vm, fiber->frames);
      DEALLOCATE(vm, fiber->stack);
      break;
    }

    case OBJ_FN:
    {
      // Constants, bytecode, and then the debug record and its contents.
      ObjFn* fn = (ObjFn*)obj;
      wrenValueBufferClear(vm, &fn->constants);
      wrenByteBufferClear(vm, &fn->code);
      wrenIntBufferClear(vm, &fn->debug->sourceLines);
      DEALLOCATE(vm, fn->debug->name);
      DEALLOCATE(vm, fn->debug);
      break;
    }

    case OBJ_FOREIGN:
      wrenFinalizeForeign(vm, (ObjForeign*)obj);
      break;

    case OBJ_LIST:
    {
      ObjList* list = (ObjList*)obj;
      wrenValueBufferClear(vm, &list->elements);
      break;
    }

    case OBJ_MAP:
    {
      ObjMap* map = (ObjMap*)obj;
      DEALLOCATE(vm, map->entries);
      break;
    }

    case OBJ_MODULE:
    {
      ObjModule* module = (ObjModule*)obj;
      wrenSymbolTableClear(vm, &module->variableNames);
      wrenValueBufferClear(vm, &module->variables);
      break;
    }

    // These types have nothing extra to free here.
    case OBJ_CLOSURE:
    case OBJ_INSTANCE:
    case OBJ_RANGE:
    case OBJ_STRING:
    case OBJ_UPVALUE:
      break;
  }

  // Finally, the object itself.
  DEALLOCATE(vm, obj);
}
|
|
|
|
// Returns the class of [value]. This is a plain function wrapper around
// wrenGetClassInline().
ObjClass* wrenGetClass(WrenVM* vm, Value value)
{
  ObjClass* classObj = wrenGetClassInline(vm, value);
  return classObj;
}
|
|
|
|
// Returns true if [a] and [b] are equal: either they are the same value
// according to wrenValuesSame(), or they are both ranges or both strings
// with equal contents.
bool wrenValuesEqual(Value a, Value b)
{
  if (wrenValuesSame(a, b)) return true;

  // Beyond this point only two heap-allocated immutable objects can still
  // be equal, so anything that is not an object is unequal.
  if (!(IS_OBJ(a) && IS_OBJ(b))) return false;

  Obj* left = AS_OBJ(a);
  Obj* right = AS_OBJ(b);

  // Objects of different types are never equal.
  if (left->type != right->type) return false;

  if (left->type == OBJ_RANGE)
  {
    // Ranges are equal when both endpoints and the inclusiveness match.
    ObjRange* leftRange = (ObjRange*)left;
    ObjRange* rightRange = (ObjRange*)right;
    return leftRange->from == rightRange->from &&
           leftRange->to == rightRange->to &&
           leftRange->isInclusive == rightRange->isInclusive;
  }

  if (left->type == OBJ_STRING)
  {
    // Compare the hashes first; the bytes are only compared when those agree.
    ObjString* leftString = (ObjString*)left;
    ObjString* rightString = (ObjString*)right;
    return leftString->hash == rightString->hash &&
           wrenStringEqualsCString(leftString, rightString->value,
                                   rightString->length);
  }

  // Every other type is only equal when it is the same object, which was
  // already ruled out above.
  return false;
}
|
|
// End file "wren_value.c"
|
|
// Begin file "wren_utils.c"
|
|
#include <string.h>
|
|
|
|
|
|
// Buffer instantiations for byte, int, and string-pointer elements. The
// DEFINE_BUFFER macro itself is defined outside this part of the file.
// NOTE(review): presumably these expand to the wrenByteBuffer*,
// wrenIntBuffer*, and wrenStringBuffer* functions called elsewhere in this
// file -- confirm against the macro definition.
DEFINE_BUFFER(Byte, uint8_t);
|
|
DEFINE_BUFFER(Int, int);
|
|
DEFINE_BUFFER(String, ObjString*);
|
|
|
|
// Initializes [symbols] by initializing it as a string buffer.
void wrenSymbolTableInit(SymbolTable* symbols)
{
  wrenStringBufferInit(symbols);
}
|
|
|
|
// Clears [symbols] by clearing it as a string buffer.
void wrenSymbolTableClear(WrenVM* vm, SymbolTable* symbols)
{
  wrenStringBufferClear(vm, symbols);
}
|
|
|
|
// Appends a new symbol, made from the first [length] bytes of [name], to
// the end of [symbols] and returns its index. No check is made for an
// existing symbol with the same name.
int wrenSymbolTableAdd(WrenVM* vm, SymbolTable* symbols,
                       const char* name, size_t length)
{
  Value text = wrenNewStringLength(vm, name, length);
  ObjString* added = AS_STRING(text);

  // Appending may grow the buffer, so keep the new string rooted meanwhile.
  wrenPushRoot(vm, &added->obj);
  wrenStringBufferWrite(vm, symbols, added);
  wrenPopRoot(vm);

  // The new symbol is the last one in the table.
  return symbols->count - 1;
}
|
|
|
|
// Returns the index of the symbol [name] (of [length] bytes) in [symbols],
// adding it first if it is not already present.
int wrenSymbolTableEnsure(WrenVM* vm, SymbolTable* symbols,
                          const char* name, size_t length)
{
  int index = wrenSymbolTableFind(symbols, name, length);

  // -1 means "not found", so this is a brand new symbol.
  if (index == -1) index = wrenSymbolTableAdd(vm, symbols, name, length);

  return index;
}
|
|
|
|
int wrenSymbolTableFind(const SymbolTable* symbols,
|
|
const char* name, size_t length)
|
|
{
|
|
// See if the symbol is already defined.
|
|
// TODO: O(n). Do something better.
|
|
for (int i = 0; i < symbols->count; i++)
|
|
{
|
|
if (wrenStringEqualsCString(symbols->data[i], name, length)) return i;
|
|
}
|
|
|
|
return -1;
|
|
}
|
|
|
|
// Grays every string stored in [symbolTable] and adds the size of the
// table's backing array to the VM's count of bytes in use.
void wrenBlackenSymbolTable(WrenVM* vm, SymbolTable* symbolTable)
{
  int index = 0;
  while (index < symbolTable->count)
  {
    ObjString* symbol = symbolTable->data[index];
    wrenGrayObj(vm, &symbol->obj);
    index++;
  }

  // Account for the memory held by the array itself (its full capacity, not
  // just the used portion).
  vm->bytesAllocated += symbolTable->capacity * sizeof(*symbolTable->data);
}
|
|
|
|
// Returns the number of bytes needed to encode code point [value] in UTF-8,
// or 0 if [value] is larger than 0x10ffff.
//
// [value] must not be negative.
int wrenUtf8EncodeNumBytes(int value)
{
  ASSERT(value >= 0, "Cannot encode a negative value.");

  int numBytes = 0;
  if (value <= 0x7f)
  {
    numBytes = 1;
  }
  else if (value <= 0x7ff)
  {
    numBytes = 2;
  }
  else if (value <= 0xffff)
  {
    numBytes = 3;
  }
  else if (value <= 0x10ffff)
  {
    numBytes = 4;
  }

  return numBytes;
}
|
|
|
|
// Writes the UTF-8 encoding of code point [value] to [bytes] and returns the
// number of bytes written (1 to 4).
//
// [bytes] must have room for the whole sequence. Values above 0x10ffff hit
// UNREACHABLE().
int wrenUtf8Encode(int value, uint8_t* bytes)
{
  if (value <= 0x7f)
  {
    // One byte: 0xxxxxxx (plain ASCII).
    bytes[0] = value & 0x7f;
    return 1;
  }

  if (value <= 0x7ff)
  {
    // Two bytes: 110xxxxx 10xxxxxx.
    bytes[0] = 0xc0 | ((value & 0x7c0) >> 6);
    bytes[1] = 0x80 | (value & 0x3f);
    return 2;
  }

  if (value <= 0xffff)
  {
    // Three bytes: 1110xxxx 10xxxxxx 10xxxxxx.
    bytes[0] = 0xe0 | ((value & 0xf000) >> 12);
    bytes[1] = 0x80 | ((value & 0xfc0) >> 6);
    bytes[2] = 0x80 | (value & 0x3f);
    return 3;
  }

  if (value <= 0x10ffff)
  {
    // Four bytes: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx.
    bytes[0] = 0xf0 | ((value & 0x1c0000) >> 18);
    bytes[1] = 0x80 | ((value & 0x3f000) >> 12);
    bytes[2] = 0x80 | ((value & 0xfc0) >> 6);
    bytes[3] = 0x80 | (value & 0x3f);
    return 4;
  }

  // Invalid Unicode value. See: http://tools.ietf.org/html/rfc3629
  UNREACHABLE();
  return 0;
}
|
|
|
|
// Decodes the UTF-8 sequence starting at [bytes], which has at most [length]
// bytes available, and returns its code point.
//
// Returns -1 if [length] is zero, if the leading byte is not a valid start
// byte, if the sequence is truncated by [length], or if a continuation byte
// is not of the form 10xxxxxx.
int wrenUtf8Decode(const uint8_t* bytes, uint32_t length)
{
  // An empty buffer has nothing to decode. Bailing out here also keeps the
  // unsigned `length - 1` below from wrapping around to UINT32_MAX, which
  // would disable the truncation check entirely.
  if (length == 0) return -1;

  // Single byte (i.e. fits in ASCII).
  if (*bytes <= 0x7f) return *bytes;

  int value;
  uint32_t remainingBytes;
  if ((*bytes & 0xe0) == 0xc0)
  {
    // Two byte sequence: 110xxxxx 10xxxxxx.
    value = *bytes & 0x1f;
    remainingBytes = 1;
  }
  else if ((*bytes & 0xf0) == 0xe0)
  {
    // Three byte sequence: 1110xxxx 10xxxxxx 10xxxxxx.
    value = *bytes & 0x0f;
    remainingBytes = 2;
  }
  else if ((*bytes & 0xf8) == 0xf0)
  {
    // Four byte sequence: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx.
    value = *bytes & 0x07;
    remainingBytes = 3;
  }
  else
  {
    // Invalid UTF-8 sequence.
    return -1;
  }

  // Don't read past the end of the buffer on truncated UTF-8.
  if (remainingBytes > length - 1) return -1;

  while (remainingBytes > 0)
  {
    bytes++;
    remainingBytes--;

    // Remaining bytes must be of form 10xxxxxx.
    if ((*bytes & 0xc0) != 0x80) return -1;

    value = value << 6 | (*bytes & 0x3f);
  }

  return value;
}
|
|
|
|
// Returns the length in bytes of the UTF-8 sequence that starts with [byte].
//
// Returns 0 for a continuation byte (10xxxxxx), since that is the middle of
// a sequence rather than the start of one. Any byte that is not a recognized
// multi-byte lead counts as a single byte.
int wrenUtf8DecodeNumBytes(uint8_t byte)
{
  // Continuation byte: not the start of anything.
  if ((byte & 0xc0) == 0x80) return 0;

  // The lead byte's high bits give the sequence length. The three patterns
  // are mutually exclusive, so the order of these tests does not matter.
  if ((byte & 0xe0) == 0xc0) return 2;
  if ((byte & 0xf0) == 0xe0) return 3;
  if ((byte & 0xf8) == 0xf0) return 4;

  return 1;
}
|
|
|
|
// Rounds [n] up to the next power of two; a value that is already a power of
// two is returned unchanged.
//
// From: http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
int wrenPowerOf2Ceil(int n)
{
  n--;

  // Smear the highest set bit into every lower position by OR-ing in shifts
  // of 1, 2, 4, 8 and 16 bits.
  for (int shift = 1; shift <= 16; shift <<= 1)
  {
    n |= n >> shift;
  }

  n++;
  return n;
}
|
|
|
|
// Validates [value] as an index into a sequence of [count] elements.
//
// A negative [value] counts back from the end. Returns the resulting
// zero-based index, or UINT32_MAX if it is out of bounds.
uint32_t wrenValidateIndex(uint32_t count, int64_t value)
{
  int64_t index = value;

  // Negative indices are relative to the end of the sequence.
  if (index < 0) index += count;

  bool inBounds = index >= 0 && index < count;
  return inBounds ? (uint32_t)index : UINT32_MAX;
}
|
|
// End file "wren_utils.c"
|
|
// Begin file "wren_vm.c"
|
|
#include <stdarg.h>
|
|
#include <string.h>
|
|
|
|
|
|
#if WREN_OPT_META
|
|
// Begin file "wren_opt_meta.h"
|
|
#ifndef wren_opt_meta_h
|
|
#define wren_opt_meta_h
|
|
|
|
|
|
// This module defines the Meta class and its associated methods.
|
|
#if WREN_OPT_META
|
|
|
|
const char* wrenMetaSource();
|
|
WrenForeignMethodFn wrenMetaBindForeignMethod(WrenVM* vm,
|
|
const char* className,
|
|
bool isStatic,
|
|
const char* signature);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
// End file "wren_opt_meta.h"
|
|
#endif
|
|
#if WREN_OPT_RANDOM
|
|
// Begin file "wren_opt_random.h"
|
|
#ifndef wren_opt_random_h
|
|
#define wren_opt_random_h
|
|
|
|
|
|
#if WREN_OPT_RANDOM
|
|
|
|
const char* wrenRandomSource();
|
|
WrenForeignClassMethods wrenRandomBindForeignClass(WrenVM* vm,
|
|
const char* module,
|
|
const char* className);
|
|
WrenForeignMethodFn wrenRandomBindForeignMethod(WrenVM* vm,
|
|
const char* className,
|
|
bool isStatic,
|
|
const char* signature);
|
|
|
|
#endif
|
|
|
|
#endif
|
|
// End file "wren_opt_random.h"
|
|
#endif
|
|
|
|
#if WREN_DEBUG_TRACE_MEMORY || WREN_DEBUG_TRACE_GC
|
|
#include <time.h>
|
|
#include <stdio.h>
|
|
#endif
|
|
|
|
// The allocator used when the host does not supply one.
//
// What realloc() does for a size of 0 is implementation defined: it may hand
// back a non-NULL pointer that must not be dereferenced yet still has to be
// freed. To sidestep that, a zero [newSize] frees [ptr] and returns NULL
// without ever calling realloc(). The last parameter (user data) is unused.
static void* defaultReallocate(void* ptr, size_t newSize, void* _)
{
  if (newSize != 0) return realloc(ptr, newSize);

  free(ptr);
  return NULL;
}
|
|
|
|
int wrenGetVersionNumber()
|
|
{
|
|
return WREN_VERSION_NUMBER;
|
|
}
|
|
|
|
// Fills [config] with default values: the default allocator, no host
// callbacks, no user data, and the default heap sizing parameters.
void wrenInitConfiguration(WrenConfiguration* config)
{
  // Memory management.
  config->reallocateFn = defaultReallocate;
  config->userData = NULL;

  // Heap sizing: 10 MB initial heap, 1 MB minimum, grow by 50%.
  config->initialHeapSize = 1024 * 1024 * 10;
  config->minHeapSize = 1024 * 1024;
  config->heapGrowthPercent = 50;

  // Module and foreign binding hooks are all unset by default.
  config->resolveModuleFn = NULL;
  config->loadModuleFn = NULL;
  config->bindForeignMethodFn = NULL;
  config->bindForeignClassFn = NULL;

  // Output hooks.
  config->writeFn = NULL;
  config->errorFn = NULL;
}
|
|
|
|
// Creates a new VM.
//
// If [config] is NULL the defaults from wrenInitConfiguration() are used.
// Otherwise [config] is copied into the VM (the caller's struct is never
// modified), with a NULL reallocateFn replaced by the default allocator.
//
// Returns NULL if the memory for the VM itself or for its gray stack cannot
// be allocated.
WrenVM* wrenNewVM(WrenConfiguration* config)
{
  WrenReallocateFn reallocate = defaultReallocate;
  void* userData = NULL;
  if (config != NULL)
  {
    userData = config->userData;
    if (config->reallocateFn != NULL) reallocate = config->reallocateFn;
  }

  WrenVM* vm = (WrenVM*)reallocate(NULL, sizeof(*vm), userData);

  // The allocator may fail; don't memset() through a NULL pointer.
  if (vm == NULL) return NULL;
  memset(vm, 0, sizeof(WrenVM));

  // Copy the configuration if given one.
  if (config != NULL)
  {
    memcpy(&vm->config, config, sizeof(WrenConfiguration));

    // We choose to set this after copying,
    // rather than modifying the user config pointer
    vm->config.reallocateFn = reallocate;
  }
  else
  {
    wrenInitConfiguration(&vm->config);
  }

  // TODO: Should we allocate and free this during a GC?
  vm->grayCount = 0;
  // TODO: Tune this.
  vm->grayCapacity = 4;
  vm->gray = (Obj**)reallocate(NULL, vm->grayCapacity * sizeof(Obj*), userData);
  if (vm->gray == NULL)
  {
    // Release the half-built VM (a new size of 0 frees the block) rather than
    // handing back one whose gray stack is missing.
    reallocate(vm, 0, userData);
    return NULL;
  }
  vm->nextGC = vm->config.initialHeapSize;

  wrenSymbolTableInit(&vm->methodNames);

  vm->modules = wrenNewMap(vm);
  wrenInitializeCore(vm);
  return vm;
}
|
|
|
|
// Frees [vm] along with every object it still owns.
//
// All handles must have been released by the host before calling this.
void wrenFreeVM(WrenVM* vm)
{
  ASSERT(vm->methodNames.count > 0, "VM appears to have already been freed.");

  // Walk the list of all objects, releasing each one. The next pointer is
  // read before the object that holds it is freed.
  for (Obj* obj = vm->first; obj != NULL;)
  {
    Obj* following = obj->next;
    wrenFreeObj(vm, obj);
    obj = following;
  }

  // Release the gray stack by reallocating it to zero bytes.
  vm->gray = (Obj**)vm->config.reallocateFn(vm->gray, 0, vm->config.userData);

  // Leftover handles are reported instead of being silently freed: the host
  // may still hold pointers to them and try to use them, so it is better to
  // surface the bug early.
  ASSERT(vm->handles == NULL, "All handles have not been released.");

  wrenSymbolTableClear(vm, &vm->methodNames);

  DEALLOCATE(vm, vm);
}
|
|
|
|
// Runs a full garbage collection: grays the roots, blackens everything
// reachable from them, frees whatever was not reached, and then schedules
// the next collection.
void wrenCollectGarbage(WrenVM* vm)
{
#if WREN_DEBUG_TRACE_MEMORY || WREN_DEBUG_TRACE_GC
  printf("-- gc --\n");

  size_t before = vm->bytesAllocated;
  double startTime = (double)clock() / CLOCKS_PER_SEC;
#endif

  // ---- Mark ----

  // Start the byte count over. Each object adds its size back as it gets
  // marked, which lets us know how much memory is live without needing the
  // size of each *freed* object.
  //
  // That matters because the size of an unmarked object is not always known
  // when it is freed. For example, an instance's size depends on its class,
  // and the class may already have been freed.
  vm->bytesAllocated = 0;

  // Root: the module registry.
  wrenGrayObj(vm, (Obj*)vm->modules);

  // Roots: temporary roots.
  int rootIndex = 0;
  while (rootIndex < vm->numTempRoots)
  {
    wrenGrayObj(vm, vm->tempRoots[rootIndex]);
    rootIndex++;
  }

  // Root: the current fiber.
  wrenGrayObj(vm, (Obj*)vm->fiber);

  // Roots: every handle held by the host.
  WrenHandle* handle = vm->handles;
  while (handle != NULL)
  {
    wrenGrayValue(vm, handle->value);
    handle = handle->next;
  }

  // Roots: anything the compiler is using, if one is active.
  if (vm->compiler != NULL) wrenMarkCompiler(vm, vm->compiler);

  // Roots: method names.
  wrenBlackenSymbolTable(vm, &vm->methodNames);

  // With the roots grayed, traverse everything reachable from them.
  wrenBlackenObjects(vm);

  // ---- Sweep ----

  // [link] always points at the pointer that refers to the object under
  // inspection, so unlinking is a single store.
  Obj** link = &vm->first;
  while (*link != NULL)
  {
    Obj* current = *link;
    if (current->isDark)
    {
      // Reached: clear the mark ready for the next collection and advance.
      current->isDark = false;
      link = &current->next;
    }
    else
    {
      // Not reached: unlink it from the list and free it.
      *link = current->next;
      wrenFreeObj(vm, current);
    }
  }

  // The next collection happens once usage exceeds the current live size
  // plus the configured growth percentage, but never below the minimum.
  vm->nextGC = vm->bytesAllocated + ((vm->bytesAllocated * vm->config.heapGrowthPercent) / 100);
  if (vm->nextGC < vm->config.minHeapSize) vm->nextGC = vm->config.minHeapSize;

#if WREN_DEBUG_TRACE_MEMORY || WREN_DEBUG_TRACE_GC
  double elapsed = ((double)clock() / CLOCKS_PER_SEC) - startTime;
  // size_t differs in width between 32-bit and 64-bit targets, so cast
  // explicitly to get one consistent type for the format string.
  printf("GC %lu before, %lu after (%lu collected), next at %lu. Took %.3fms.\n",
         (unsigned long)before,
         (unsigned long)vm->bytesAllocated,
         (unsigned long)(before - vm->bytesAllocated),
         (unsigned long)vm->nextGC,
         elapsed*1000.0);
#endif
}
|
|
|
|
// Resizes [memory] from [oldSize] to [newSize] bytes using the VM's
// configured allocator and returns the result.
//
// Updates the VM's running byte count, and may run a garbage collection
// before a non-zero-size allocation.
void* wrenReallocate(WrenVM* vm, void* memory, size_t oldSize, size_t newSize)
{
#if WREN_DEBUG_TRACE_MEMORY
  // size_t differs in width between 32-bit and 64-bit targets, so cast
  // explicitly to get one consistent type for the format string.
  printf("reallocate %p %lu -> %lu\n",
         memory, (unsigned long)oldSize, (unsigned long)newSize);
#endif

  // Count newly allocated bytes toward the total. Complete deallocations are
  // not tracked here (the original size isn't tracked); that gets handled
  // while marking during the next GC.
  vm->bytesAllocated += newSize - oldSize;

  // Collection itself calls this function to free things, so only consider
  // collecting when memory is actually being requested; otherwise we would
  // recurse.
#if WREN_DEBUG_GC_STRESS
  bool shouldCollect = newSize > 0;
#else
  bool shouldCollect = newSize > 0 && vm->bytesAllocated > vm->nextGC;
#endif
  if (shouldCollect) wrenCollectGarbage(vm);

  return vm->config.reallocateFn(memory, newSize, vm->config.userData);
}
|
|
|
|
// Captures the local variable [local] into an [Upvalue]. If that local is
|
|
// already in an upvalue, the existing one will be used. (This is important to
|
|
// ensure that multiple closures closing over the same variable actually see
|
|
// the same variable.) Otherwise, it will create a new open upvalue and add it
|
|
// the fiber's list of upvalues.
|
|
static ObjUpvalue* captureUpvalue(WrenVM* vm, ObjFiber* fiber, Value* local)
|
|
{
|
|
// If there are no open upvalues at all, we must need a new one.
|
|
if (fiber->openUpvalues == NULL)
|
|
{
|
|
fiber->openUpvalues = wrenNewUpvalue(vm, local);
|
|
return fiber->openUpvalues;
|
|
}
|
|
|
|
ObjUpvalue* prevUpvalue = NULL;
|
|
ObjUpvalue* upvalue = fiber->openUpvalues;
|
|
|
|
// Walk towards the bottom of the stack until we find a previously existing
|
|
// upvalue or pass where it should be.
|
|
while (upvalue != NULL && upvalue->value > local)
|
|
{
|
|
prevUpvalue = upvalue;
|
|
upvalue = upvalue->next;
|
|
}
|
|
|
|
// Found an existing upvalue for this local.
|
|
if (upvalue != NULL && upvalue->value == local) return upvalue;
|
|
|
|
// We've walked past this local on the stack, so there must not be an
|
|
// upvalue for it already. Make a new one and link it in in the right
|
|
// place to keep the list sorted.
|
|
ObjUpvalue* createdUpvalue = wrenNewUpvalue(vm, local);
|
|
if (prevUpvalue == NULL)
|
|
{
|
|
// The new one is the first one in the list.
|
|
fiber->openUpvalues = createdUpvalue;
|
|
}
|
|
else
|
|
{
|
|
prevUpvalue->next = createdUpvalue;
|
|
}
|
|
|
|
createdUpvalue->next = upvalue;
|
|
return createdUpvalue;
|
|
}
|
|
|
|
// Closes any open upvalues that have been created for stack slots at [last]
|
|
// and above.
|
|
static void closeUpvalues(ObjFiber* fiber, Value* last)
|
|
{
|
|
while (fiber->openUpvalues != NULL &&
|
|
fiber->openUpvalues->value >= last)
|
|
{
|
|
ObjUpvalue* upvalue = fiber->openUpvalues;
|
|
|
|
// Move the value into the upvalue itself and point the upvalue to it.
|
|
upvalue->closed = *upvalue->value;
|
|
upvalue->value = &upvalue->closed;
|
|
|
|
// Remove it from the open upvalue list.
|
|
fiber->openUpvalues = upvalue->next;
|
|
}
|
|
}
|
|
|
|
// Looks up a foreign method in [moduleName] on [className] with [signature].
|
|
//
|
|
// This will try the host's foreign method binder first. If that fails, it
|
|
// falls back to handling the built-in modules.
|
|
static WrenForeignMethodFn findForeignMethod(WrenVM* vm,
|
|
const char* moduleName,
|
|
const char* className,
|
|
bool isStatic,
|
|
const char* signature)
|
|
{
|
|
WrenForeignMethodFn method = NULL;
|
|
|
|
if (vm->config.bindForeignMethodFn != NULL)
|
|
{
|
|
method = vm->config.bindForeignMethodFn(vm, moduleName, className, isStatic,
|
|
signature);
|
|
}
|
|
|
|
// If the host didn't provide it, see if it's an optional one.
|
|
if (method == NULL)
|
|
{
|
|
#if WREN_OPT_META
|
|
if (strcmp(moduleName, "meta") == 0)
|
|
{
|
|
method = wrenMetaBindForeignMethod(vm, className, isStatic, signature);
|
|
}
|
|
#endif
|
|
#if WREN_OPT_RANDOM
|
|
if (strcmp(moduleName, "random") == 0)
|
|
{
|
|
method = wrenRandomBindForeignMethod(vm, className, isStatic, signature);
|
|
}
|
|
#endif
|
|
}
|
|
|
|
return method;
|
|
}
|
|
|
|
// Defines [methodValue] as a method on [classObj].
|
|
//
|
|
// Handles both foreign methods where [methodValue] is a string containing the
|
|
// method's signature and Wren methods where [methodValue] is a function.
|
|
//
|
|
// Aborts the current fiber if the method is a foreign method that could not be
|
|
// found.
|
|
static void bindMethod(WrenVM* vm, int methodType, int symbol,
|
|
ObjModule* module, ObjClass* classObj, Value methodValue)
|
|
{
|
|
const char* className = classObj->name->value;
|
|
if (methodType == CODE_METHOD_STATIC) classObj = classObj->obj.classObj;
|
|
|
|
Method method;
|
|
if (IS_STRING(methodValue))
|
|
{
|
|
const char* name = AS_CSTRING(methodValue);
|
|
method.type = METHOD_FOREIGN;
|
|
method.as.foreign = findForeignMethod(vm, module->name->value,
|
|
className,
|
|
methodType == CODE_METHOD_STATIC,
|
|
name);
|
|
|
|
if (method.as.foreign == NULL)
|
|
{
|
|
vm->fiber->error = wrenStringFormat(vm,
|
|
"Could not find foreign method '@' for class $ in module '$'.",
|
|
methodValue, classObj->name->value, module->name->value);
|
|
return;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
method.as.closure = AS_CLOSURE(methodValue);
|
|
method.type = METHOD_BLOCK;
|
|
|
|
// Patch up the bytecode now that we know the superclass.
|
|
wrenBindMethodCode(classObj, method.as.closure->fn);
|
|
}
|
|
|
|
wrenBindMethod(vm, classObj, symbol, method);
|
|
}
|
|
|
|
static void callForeign(WrenVM* vm, ObjFiber* fiber,
|
|
WrenForeignMethodFn foreign, int numArgs)
|
|
{
|
|
ASSERT(vm->apiStack == NULL, "Cannot already be in foreign call.");
|
|
vm->apiStack = fiber->stackTop - numArgs;
|
|
|
|
foreign(vm);
|
|
|
|
// Discard the stack slots for the arguments and temporaries but leave one
|
|
// for the result.
|
|
fiber->stackTop = vm->apiStack + 1;
|
|
|
|
vm->apiStack = NULL;
|
|
}
|
|
|
|
// Handles the current fiber having aborted because of an error.
|
|
//
|
|
// Walks the call chain of fibers, aborting each one until it hits a fiber that
|
|
// handles the error. If none do, tells the VM to stop.
|
|
static void runtimeError(WrenVM* vm)
|
|
{
|
|
ASSERT(wrenHasError(vm->fiber), "Should only call this after an error.");
|
|
|
|
ObjFiber* current = vm->fiber;
|
|
Value error = current->error;
|
|
|
|
while (current != NULL)
|
|
{
|
|
// Every fiber along the call chain gets aborted with the same error.
|
|
current->error = error;
|
|
|
|
// If the caller ran this fiber using "try", give it the error and stop.
|
|
if (current->state == FIBER_TRY)
|
|
{
|
|
// Make the caller's try method return the error message.
|
|
current->caller->stackTop[-1] = vm->fiber->error;
|
|
vm->fiber = current->caller;
|
|
return;
|
|
}
|
|
|
|
// Otherwise, unhook the caller since we will never resume and return to it.
|
|
ObjFiber* caller = current->caller;
|
|
current->caller = NULL;
|
|
current = caller;
|
|
}
|
|
|
|
// If we got here, nothing caught the error, so show the stack trace.
|
|
wrenDebugPrintStackTrace(vm);
|
|
vm->fiber = NULL;
|
|
vm->apiStack = NULL;
|
|
}
|
|
|
|
// Aborts the current fiber with an appropriate method not found error for a
|
|
// method with [symbol] on [classObj].
|
|
static void methodNotFound(WrenVM* vm, ObjClass* classObj, int symbol)
|
|
{
|
|
vm->fiber->error = wrenStringFormat(vm, "@ does not implement '$'.",
|
|
OBJ_VAL(classObj->name), vm->methodNames.data[symbol]->value);
|
|
}
|
|
|
|
// Looks up the previously loaded module with [name].
|
|
//
|
|
// Returns `NULL` if no module with that name has been loaded.
|
|
static ObjModule* getModule(WrenVM* vm, Value name)
|
|
{
|
|
Value moduleValue = wrenMapGet(vm->modules, name);
|
|
return !IS_UNDEFINED(moduleValue) ? AS_MODULE(moduleValue) : NULL;
|
|
}
|
|
|
|
// Compiles [source] in the module called [name] and returns the resulting
// code wrapped in a closure, or NULL if compilation fails.
//
// A module that has not been loaded yet is created, registered with the VM,
// and given a copy of every variable defined in the core module.
static ObjClosure* compileInModule(WrenVM* vm, Value name, const char* source,
                                   bool isExpression, bool printErrors)
{
  ObjModule* module = getModule(vm, name);
  if (module == NULL)
  {
    // First time this module is seen, so create it.
    module = wrenNewModule(vm, AS_STRING(name));

    // Registering the module below may resize the modules map, which can
    // trigger a GC that would collect the module we just made. Root it until
    // it is safely in the map.
    wrenPushRoot(vm, (Obj*)module);

    // Record it in the registry so the same module is never loaded twice.
    wrenMapSet(vm, vm->modules, name, OBJ_VAL(module));

    wrenPopRoot(vm);

    // Every module implicitly imports everything from core.
    ObjModule* coreModule = getModule(vm, NULL_VAL);
    int coreIndex = 0;
    while (coreIndex < coreModule->variables.count)
    {
      ObjString* variableName = coreModule->variableNames.data[coreIndex];
      wrenDefineVariable(vm, module,
                         variableName->value,
                         variableName->length,
                         coreModule->variables.data[coreIndex], NULL);
      coreIndex++;
    }
  }

  ObjFn* fn = wrenCompile(vm, module, source, isExpression, printErrors);

  // TODO: Should we still store the module even if it didn't compile?
  if (fn == NULL) return NULL;

  // A function is never used bare; it always gets wrapped in a closure. Root
  // the function while the closure is being allocated.
  wrenPushRoot(vm, (Obj*)fn);
  ObjClosure* closure = wrenNewClosure(vm, fn);
  wrenPopRoot(vm); // fn.

  return closure;
}
|
|
|
|
// Verifies that [superclassValue] is a valid object to inherit from. That
|
|
// means it must be a class and cannot be the class of any built-in type.
|
|
//
|
|
// Also validates that it doesn't result in a class with too many fields and
|
|
// the other limitations foreign classes have.
|
|
//
|
|
// If successful, returns `null`. Otherwise, returns a string for the runtime
|
|
// error message.
|
|
static Value validateSuperclass(WrenVM* vm, Value name, Value superclassValue,
|
|
int numFields)
|
|
{
|
|
// Make sure the superclass is a class.
|
|
if (!IS_CLASS(superclassValue))
|
|
{
|
|
return wrenStringFormat(vm,
|
|
"Class '@' cannot inherit from a non-class object.",
|
|
name);
|
|
}
|
|
|
|
// Make sure it doesn't inherit from a sealed built-in type. Primitive methods
|
|
// on these classes assume the instance is one of the other Obj___ types and
|
|
// will fail horribly if it's actually an ObjInstance.
|
|
ObjClass* superclass = AS_CLASS(superclassValue);
|
|
if (superclass == vm->classClass ||
|
|
superclass == vm->fiberClass ||
|
|
superclass == vm->fnClass || // Includes OBJ_CLOSURE.
|
|
superclass == vm->listClass ||
|
|
superclass == vm->mapClass ||
|
|
superclass == vm->rangeClass ||
|
|
superclass == vm->stringClass ||
|
|
superclass == vm->boolClass ||
|
|
superclass == vm->nullClass ||
|
|
superclass == vm->numClass)
|
|
{
|
|
return wrenStringFormat(vm,
|
|
"Class '@' cannot inherit from built-in class '@'.",
|
|
name, OBJ_VAL(superclass->name));
|
|
}
|
|
|
|
if (superclass->numFields == -1)
|
|
{
|
|
return wrenStringFormat(vm,
|
|
"Class '@' cannot inherit from foreign class '@'.",
|
|
name, OBJ_VAL(superclass->name));
|
|
}
|
|
|
|
if (numFields == -1 && superclass->numFields > 0)
|
|
{
|
|
return wrenStringFormat(vm,
|
|
"Foreign class '@' may not inherit from a class with fields.",
|
|
name);
|
|
}
|
|
|
|
if (superclass->numFields + numFields > MAX_FIELDS)
|
|
{
|
|
return wrenStringFormat(vm,
|
|
"Class '@' may not have more than 255 fields, including inherited "
|
|
"ones.", name);
|
|
}
|
|
|
|
return NULL_VAL;
|
|
}
|
|
|
|
// Binds the allocator and finalizer of the foreign class [classObj], which
// is declared in [module].
//
// The host's foreign class binder, if configured, is asked first. If it is
// absent or supplies neither function, the built-in "random" module is
// consulted (when compiled in).
//
// The "<allocate>" and "<finalize>" symbols are always added to the method
// name table, even when the corresponding function is missing.
static void bindForeignClass(WrenVM* vm, ObjClass* classObj, ObjModule* module)
{
  const char* moduleName = module->name->value;
  const char* className = classObj->name->value;

  WrenForeignClassMethods methods;
  methods.allocate = NULL;
  methods.finalize = NULL;

  // The host gets the first chance to provide the class.
  if (vm->config.bindForeignClassFn != NULL)
  {
    methods = vm->config.bindForeignClassFn(vm, moduleName, className);
  }

  // The host provided nothing, so see if a built-in optional module does.
  if (methods.allocate == NULL && methods.finalize == NULL)
  {
#if WREN_OPT_RANDOM
    if (strcmp(moduleName, "random") == 0)
    {
      methods = wrenRandomBindForeignClass(vm, moduleName, className);
    }
#endif
  }

  Method method;
  method.type = METHOD_FOREIGN;

  // Ensure the symbol exists whether or not there is an allocator, so that
  // it can always be found in the symbol table later.
  int allocateSymbol = wrenSymbolTableEnsure(vm, &vm->methodNames,
                                             "<allocate>", 10);
  if (methods.allocate != NULL)
  {
    method.as.foreign = methods.allocate;
    wrenBindMethod(vm, classObj, allocateSymbol, method);
  }

  // Likewise for the finalizer symbol.
  int finalizeSymbol = wrenSymbolTableEnsure(vm, &vm->methodNames,
                                             "<finalize>", 10);
  if (methods.finalize != NULL)
  {
    // The finalizer is stored in the foreign method slot; it is cast back to
    // its real type before being called.
    method.as.foreign = (WrenForeignMethodFn)methods.finalize;
    wrenBindMethod(vm, classObj, finalizeSymbol, method);
  }
}
|
|
|
|
// Completes the process for creating a new class.
|
|
//
|
|
// The class attributes instance and the class itself should be on the
|
|
// top of the fiber's stack.
|
|
//
|
|
// This process handles moving the attribute data for a class from
|
|
// compile time to runtime, since it now has all the attributes associated
|
|
// with a class, including for methods.
|
|
static void endClass(WrenVM* vm)
|
|
{
|
|
// Pull the attributes and class off the stack
|
|
Value attributes = vm->fiber->stackTop[-2];
|
|
Value classValue = vm->fiber->stackTop[-1];
|
|
|
|
// Remove the stack items
|
|
vm->fiber->stackTop -= 2;
|
|
|
|
ObjClass* classObj = AS_CLASS(classValue);
|
|
classObj->attributes = attributes;
|
|
}
|
|
|
|
// Creates a new class.
|
|
//
|
|
// If [numFields] is -1, the class is a foreign class. The name and superclass
|
|
// should be on top of the fiber's stack. After calling this, the top of the
|
|
// stack will contain the new class.
|
|
//
|
|
// Aborts the current fiber if an error occurs.
|
|
static void createClass(WrenVM* vm, int numFields, ObjModule* module)
|
|
{
|
|
// Pull the name and superclass off the stack.
|
|
Value name = vm->fiber->stackTop[-2];
|
|
Value superclass = vm->fiber->stackTop[-1];
|
|
|
|
// We have two values on the stack and we are going to leave one, so discard
|
|
// the other slot.
|
|
vm->fiber->stackTop--;
|
|
|
|
vm->fiber->error = validateSuperclass(vm, name, superclass, numFields);
|
|
if (wrenHasError(vm->fiber)) return;
|
|
|
|
ObjClass* classObj = wrenNewClass(vm, AS_CLASS(superclass), numFields,
|
|
AS_STRING(name));
|
|
vm->fiber->stackTop[-1] = OBJ_VAL(classObj);
|
|
|
|
if (numFields == -1) bindForeignClass(vm, classObj, module);
|
|
}
|
|
|
|
// Runs the "<allocate>" foreign method of the foreign class held in
// [stack][0], with [stack] exposed as the API stack so the allocator can see
// the constructor arguments too.
static void createForeign(WrenVM* vm, ObjFiber* fiber, Value* stack)
{
  ObjClass* classObj = AS_CLASS(stack[0]);
  ASSERT(classObj->numFields == -1, "Class must be a foreign class.");

  // TODO: Don't look up every time.
  int allocateSymbol = wrenSymbolTableFind(&vm->methodNames, "<allocate>", 10);
  ASSERT(allocateSymbol != -1, "Should have defined <allocate> symbol.");

  ASSERT(classObj->methods.count > allocateSymbol,
         "Class should have allocator.");
  Method* allocator = &classObj->methods.data[allocateSymbol];
  ASSERT(allocator->type == METHOD_FOREIGN, "Allocator should be foreign.");

  // Hand the constructor arguments to the allocator through the API stack.
  ASSERT(vm->apiStack == NULL, "Cannot already be in foreign call.");
  vm->apiStack = stack;

  allocator->as.foreign(vm);

  vm->apiStack = NULL;
}
|
|
|
|
// Calls the finalizer of [foreign]'s class on the object's data, if the
// class has one. Does nothing otherwise.
void wrenFinalizeForeign(WrenVM* vm, ObjForeign* foreign)
{
  // TODO: Don't look up every time.
  int finalizeSymbol = wrenSymbolTableFind(&vm->methodNames, "<finalize>", 10);
  ASSERT(finalizeSymbol != -1, "Should have defined <finalize> symbol.");

  // No finalize symbol at all means there are no finalizers anywhere.
  if (finalizeSymbol == -1) return;

  // The class's method table may be too short to hold a finalizer...
  ObjClass* classObj = foreign->obj.classObj;
  if (finalizeSymbol >= classObj->methods.count) return;

  // ...or the slot may simply be empty.
  Method* method = &classObj->methods.data[finalizeSymbol];
  if (method->type == METHOD_NONE) return;

  ASSERT(method->type == METHOD_FOREIGN, "Finalizer should be foreign.");

  // The finalizer lives in the foreign method slot; cast it back to its
  // real type before calling it.
  WrenFinalizerFn finalizer = (WrenFinalizerFn)method->as.foreign;
  finalizer(foreign->data);
}
|
|
|
|
// Let the host resolve an imported module name if it wants to.
|
|
static Value resolveModule(WrenVM* vm, Value name)
|
|
{
|
|
// If the host doesn't care to resolve, leave the name alone.
|
|
if (vm->config.resolveModuleFn == NULL) return name;
|
|
|
|
ObjFiber* fiber = vm->fiber;
|
|
ObjFn* fn = fiber->frames[fiber->numFrames - 1].closure->fn;
|
|
ObjString* importer = fn->module->name;
|
|
|
|
const char* resolved = vm->config.resolveModuleFn(vm, importer->value,
|
|
AS_CSTRING(name));
|
|
if (resolved == NULL)
|
|
{
|
|
vm->fiber->error = wrenStringFormat(vm,
|
|
"Could not resolve module '@' imported from '@'.",
|
|
name, OBJ_VAL(importer));
|
|
return NULL_VAL;
|
|
}
|
|
|
|
// If they resolved to the exact same string, we don't need to copy it.
|
|
if (resolved == AS_CSTRING(name)) return name;
|
|
|
|
// Copy the string into a Wren String object.
|
|
name = wrenNewString(vm, resolved);
|
|
DEALLOCATE(vm, (char*)resolved);
|
|
return name;
|
|
}
|
|
|
|
// Imports the module called [name].
//
// The name is first passed through resolveModule(). If a module is already
// registered under the resolved name, that value is returned. Otherwise the
// source is requested from the host's loader, falling back to the optional
// built-in modules, compiled, and the closure that executes the module is
// returned.
//
// On any failure (unresolvable name, no source, compile error) the current
// fiber's error is set and `null` is returned.
static Value importModule(WrenVM* vm, Value name)
{
  name = resolveModule(vm, name);

  // resolveModule() reports failure by setting the fiber's error and
  // returning null. Stop here in that case: null is also the key the core
  // module is stored under (see compileInModule()), so looking it up below
  // would wrongly hand back the core module.
  if (!IS_OBJ(name)) return NULL_VAL;

  // If the module is already loaded, we don't need to do anything.
  Value existing = wrenMapGet(vm->modules, name);
  if (!IS_UNDEFINED(existing)) return existing;

  wrenPushRoot(vm, AS_OBJ(name));

  WrenLoadModuleResult result = {0};

  // Let the host try to provide the module.
  if (vm->config.loadModuleFn != NULL)
  {
    result = vm->config.loadModuleFn(vm, AS_CSTRING(name));
  }

  // If the host didn't provide it, see if it's a built in optional module.
  if (result.source == NULL)
  {
    // Built-in sources need no completion callback.
    result.onComplete = NULL;
    ObjString* nameString = AS_STRING(name);
#if WREN_OPT_META
    if (strcmp(nameString->value, "meta") == 0) result.source = wrenMetaSource();
#endif
#if WREN_OPT_RANDOM
    if (strcmp(nameString->value, "random") == 0) result.source = wrenRandomSource();
#endif
  }

  if (result.source == NULL)
  {
    vm->fiber->error = wrenStringFormat(vm, "Could not load module '@'.", name);
    wrenPopRoot(vm); // name.
    return NULL_VAL;
  }

  ObjClosure* moduleClosure = compileInModule(vm, name, result.source, false, true);

  // Now that we're done, give the result back in case there's cleanup to do.
  if (result.onComplete) result.onComplete(vm, AS_CSTRING(name), result);

  if (moduleClosure == NULL)
  {
    vm->fiber->error = wrenStringFormat(vm,
                                        "Could not compile module '@'.", name);
    wrenPopRoot(vm); // name.
    return NULL_VAL;
  }

  wrenPopRoot(vm); // name.

  // Return the closure that executes the module.
  return OBJ_VAL(moduleClosure);
}
|
|
|
|
// Returns the value of the top-level variable called [variableName] in
// [module].
//
// If the module has no such variable, the current fiber's error is set and
// `null` is returned.
static Value getModuleVariable(WrenVM* vm, ObjModule* module,
                               Value variableName)
{
  ObjString* variable = AS_STRING(variableName);
  int variableEntry = wrenSymbolTableFind(&module->variableNames,
                                          variable->value,
                                          variable->length);

  // wrenSymbolTableFind() returns -1 for a name it does not know.
  if (variableEntry != -1) return module->variables.data[variableEntry];

  // It's a runtime error if the imported variable does not exist.
  vm->fiber->error = wrenStringFormat(vm,
      "Could not find a variable named '@' in module '@'.",
      variableName, OBJ_VAL(module->name));
  return NULL_VAL;
}
|
|
|
|
// Returns true if the closure [value] is being given enough arguments.
//
// [numArgs] counts the receiver (the function itself) as well, so one is
// subtracted before comparing against the function's arity. Only missing
// arguments matter; extras are fine. On failure the current fiber's error is
// set and false is returned.
inline static bool checkArity(WrenVM* vm, Value value, int numArgs)
{
  ASSERT(IS_CLOSURE(value), "Receiver must be a closure.");
  ObjFn* fn = AS_CLOSURE(value)->fn;

  if (numArgs - 1 < fn->arity)
  {
    vm->fiber->error = CONST_STRING(vm, "Function expects more arguments.");
    return false;
  }

  return true;
}
|
|
|
|
|
|
// The main bytecode interpreter loop. This is where the magic happens. It is
|
|
// also, as you can imagine, highly performance critical.
//
// Makes [fiber] the VM's current fiber, marks it as the root fiber, and then
// decodes one instruction at a time from the topmost call frame.
//
// Returns WREN_RESULT_SUCCESS when a fiber that has no caller returns from its
// last call frame (the result is left in that fiber's first stack slot), or
// when a primitive leaves the VM with no fiber to run. Returns
// WREN_RESULT_RUNTIME_ERROR when runtimeError() leaves the VM with no fiber.
|
|
static WrenInterpretResult runInterpreter(WrenVM* vm, register ObjFiber* fiber)
|
|
{
|
|
// Remember the current fiber so we can find it if a GC happens.
|
|
vm->fiber = fiber;
|
|
fiber->state = FIBER_ROOT;
|
|
|
|
// Hoist these into local variables. They are accessed frequently in the loop
|
|
// but assigned less frequently. Keeping them in locals and updating them when
|
|
// a call frame has been pushed or popped gives a large speed boost.
|
|
register CallFrame* frame;
|
|
register Value* stackStart;
|
|
register uint8_t* ip;
|
|
register ObjFn* fn;
|
|
|
|
// These macros are designed to only be invoked within this function.
|
|
#define PUSH(value) (*fiber->stackTop++ = value)
|
|
#define POP() (*(--fiber->stackTop))
|
|
#define DROP() (fiber->stackTop--)
|
|
#define PEEK() (*(fiber->stackTop - 1))
|
|
#define PEEK2() (*(fiber->stackTop - 2))
|
|
#define READ_BYTE() (*ip++)
|
|
// Reads a two-byte operand stored high byte first and advances [ip] past it.
#define READ_SHORT() (ip += 2, (uint16_t)((ip[-2] << 8) | ip[-1]))
|
|
|
|
// Use this before a CallFrame is pushed to store the local variables back
|
|
// into the current one.
|
|
#define STORE_FRAME() frame->ip = ip
|
|
|
|
// Use this after a CallFrame has been pushed or popped to refresh the local
|
|
// variables.
|
|
#define LOAD_FRAME() \
|
|
do \
|
|
{ \
|
|
frame = &fiber->frames[fiber->numFrames - 1]; \
|
|
stackStart = frame->stackStart; \
|
|
ip = frame->ip; \
|
|
fn = frame->closure->fn; \
|
|
} while (false)
|
|
|
|
// Terminates the current fiber with error string [error]. If another calling
|
|
// fiber is willing to catch the error, transfers control to it, otherwise
|
|
// exits the interpreter.
|
|
#define RUNTIME_ERROR() \
|
|
do \
|
|
{ \
|
|
STORE_FRAME(); \
|
|
runtimeError(vm); \
|
|
if (vm->fiber == NULL) return WREN_RESULT_RUNTIME_ERROR; \
|
|
fiber = vm->fiber; \
|
|
LOAD_FRAME(); \
|
|
DISPATCH(); \
|
|
} while (false)
|
|
|
|
#if WREN_DEBUG_TRACE_INSTRUCTIONS
|
|
// Prints the stack and instruction before each instruction is executed.
|
|
#define DEBUG_TRACE_INSTRUCTIONS() \
|
|
do \
|
|
{ \
|
|
wrenDumpStack(fiber); \
|
|
wrenDumpInstruction(vm, fn, (int)(ip - fn->code.data)); \
|
|
} while (false)
|
|
#else
|
|
#define DEBUG_TRACE_INSTRUCTIONS() do { } while (false)
|
|
#endif
|
|
|
|
#if WREN_COMPUTED_GOTO
|
|
|
|
// Table of handler label addresses, one entry per OPCODE() line below and in
// the same order. DISPATCH() indexes it with the opcode byte it just read.
static void* dispatchTable[] = {
|
|
#define OPCODE(name, _) &&code_##name,
|
|
// Begin file "wren_opcodes.h"
|
|
// This defines the bytecode instructions used by the VM. It does so by invoking
|
|
// an OPCODE() macro which is expected to be defined at the point that this is
|
|
// included. (See: http://en.wikipedia.org/wiki/X_Macro for more.)
|
|
//
|
|
// The first argument is the name of the opcode. The second is its "stack
|
|
// effect" -- the amount that the op code changes the size of the stack. A
|
|
// stack effect of 1 means it pushes a value and the stack grows one larger.
|
|
// -2 means it pops two values, etc.
|
|
//
|
|
// Note that the order of instructions here affects the order of the dispatch
|
|
// table in the VM's interpreter loop. That in turn affects caching which
|
|
// affects overall performance. Take care to run benchmarks if you change the
|
|
// order here.
|
|
|
|
// Load the constant at index [arg].
|
|
OPCODE(CONSTANT, 1)
|
|
|
|
// Push null onto the stack.
|
|
OPCODE(NULL, 1)
|
|
|
|
// Push false onto the stack.
|
|
OPCODE(FALSE, 1)
|
|
|
|
// Push true onto the stack.
|
|
OPCODE(TRUE, 1)
|
|
|
|
// Pushes the value in the given local slot.
|
|
OPCODE(LOAD_LOCAL_0, 1)
|
|
OPCODE(LOAD_LOCAL_1, 1)
|
|
OPCODE(LOAD_LOCAL_2, 1)
|
|
OPCODE(LOAD_LOCAL_3, 1)
|
|
OPCODE(LOAD_LOCAL_4, 1)
|
|
OPCODE(LOAD_LOCAL_5, 1)
|
|
OPCODE(LOAD_LOCAL_6, 1)
|
|
OPCODE(LOAD_LOCAL_7, 1)
|
|
OPCODE(LOAD_LOCAL_8, 1)
|
|
|
|
// Note: The compiler assumes the following _STORE instructions always
|
|
// immediately follow their corresponding _LOAD ones.
|
|
|
|
// Pushes the value in local slot [arg].
|
|
OPCODE(LOAD_LOCAL, 1)
|
|
|
|
// Stores the top of stack in local slot [arg]. Does not pop it.
|
|
OPCODE(STORE_LOCAL, 0)
|
|
|
|
// Pushes the value in upvalue [arg].
|
|
OPCODE(LOAD_UPVALUE, 1)
|
|
|
|
// Stores the top of stack in upvalue [arg]. Does not pop it.
|
|
OPCODE(STORE_UPVALUE, 0)
|
|
|
|
// Pushes the value of the top-level variable in slot [arg].
|
|
OPCODE(LOAD_MODULE_VAR, 1)
|
|
|
|
// Stores the top of stack in top-level variable slot [arg]. Does not pop it.
|
|
OPCODE(STORE_MODULE_VAR, 0)
|
|
|
|
// Pushes the value of the field in slot [arg] of the receiver of the current
|
|
// function. This is used for regular field accesses on "this" directly in
|
|
// methods. This instruction is faster than the more general CODE_LOAD_FIELD
|
|
// instruction.
|
|
OPCODE(LOAD_FIELD_THIS, 1)
|
|
|
|
// Stores the top of the stack in field slot [arg] in the receiver of the
|
|
// current function. Does not pop the value. This instruction is faster than
|
|
// the more general CODE_STORE_FIELD instruction.
|
|
OPCODE(STORE_FIELD_THIS, 0)
|
|
|
|
// Pops an instance and pushes the value of the field in slot [arg] of it.
|
|
OPCODE(LOAD_FIELD, 0)
|
|
|
|
// Pops an instance and stores the subsequent top of stack in field slot
|
|
// [arg] in it. Does not pop the value.
|
|
OPCODE(STORE_FIELD, -1)
|
|
|
|
// Pop and discard the top of stack.
|
|
OPCODE(POP, -1)
|
|
|
|
// Invoke the method with symbol [arg]. The number indicates the number of
|
|
// arguments (not including the receiver).
|
|
OPCODE(CALL_0, 0)
|
|
OPCODE(CALL_1, -1)
|
|
OPCODE(CALL_2, -2)
|
|
OPCODE(CALL_3, -3)
|
|
OPCODE(CALL_4, -4)
|
|
OPCODE(CALL_5, -5)
|
|
OPCODE(CALL_6, -6)
|
|
OPCODE(CALL_7, -7)
|
|
OPCODE(CALL_8, -8)
|
|
OPCODE(CALL_9, -9)
|
|
OPCODE(CALL_10, -10)
|
|
OPCODE(CALL_11, -11)
|
|
OPCODE(CALL_12, -12)
|
|
OPCODE(CALL_13, -13)
|
|
OPCODE(CALL_14, -14)
|
|
OPCODE(CALL_15, -15)
|
|
OPCODE(CALL_16, -16)
|
|
|
|
// Invoke a superclass method with symbol [arg]. The number indicates the
|
|
// number of arguments (not including the receiver).
|
|
OPCODE(SUPER_0, 0)
|
|
OPCODE(SUPER_1, -1)
|
|
OPCODE(SUPER_2, -2)
|
|
OPCODE(SUPER_3, -3)
|
|
OPCODE(SUPER_4, -4)
|
|
OPCODE(SUPER_5, -5)
|
|
OPCODE(SUPER_6, -6)
|
|
OPCODE(SUPER_7, -7)
|
|
OPCODE(SUPER_8, -8)
|
|
OPCODE(SUPER_9, -9)
|
|
OPCODE(SUPER_10, -10)
|
|
OPCODE(SUPER_11, -11)
|
|
OPCODE(SUPER_12, -12)
|
|
OPCODE(SUPER_13, -13)
|
|
OPCODE(SUPER_14, -14)
|
|
OPCODE(SUPER_15, -15)
|
|
OPCODE(SUPER_16, -16)
|
|
|
|
// Jump the instruction pointer [arg] forward.
|
|
OPCODE(JUMP, 0)
|
|
|
|
// Jump the instruction pointer [arg] backward.
|
|
OPCODE(LOOP, 0)
|
|
|
|
// Pop and if not truthy then jump the instruction pointer [arg] forward.
|
|
OPCODE(JUMP_IF, -1)
|
|
|
|
// If the top of the stack is false, jump [arg] forward. Otherwise, pop and
|
|
// continue.
|
|
OPCODE(AND, -1)
|
|
|
|
// If the top of the stack is non-false, jump [arg] forward. Otherwise, pop
|
|
// and continue.
|
|
OPCODE(OR, -1)
|
|
|
|
// Close the upvalue for the local on the top of the stack, then pop it.
|
|
OPCODE(CLOSE_UPVALUE, -1)
|
|
|
|
// Exit from the current function and return the value on the top of the
|
|
// stack.
|
|
OPCODE(RETURN, 0)
|
|
|
|
// Creates a closure for the function stored at [arg] in the constant table.
|
|
//
|
|
// Following the function argument is a number of arguments, two for each
|
|
// upvalue. The first is true if the variable being captured is a local (as
|
|
// opposed to an upvalue), and the second is the index of the local or
|
|
// upvalue being captured.
|
|
//
|
|
// Pushes the created closure.
|
|
OPCODE(CLOSURE, 1)
|
|
|
|
// Creates a new instance of a class.
|
|
//
|
|
// Assumes the class object is in slot zero, and replaces it with the new
|
|
// uninitialized instance of that class. This opcode is only emitted by the
|
|
// compiler-generated constructor metaclass methods.
|
|
OPCODE(CONSTRUCT, 0)
|
|
|
|
// Creates a new instance of a foreign class.
|
|
//
|
|
// Assumes the class object is in slot zero, and replaces it with the new
|
|
// uninitialized instance of that class. This opcode is only emitted by the
|
|
// compiler-generated constructor metaclass methods.
|
|
OPCODE(FOREIGN_CONSTRUCT, 0)
|
|
|
|
// Creates a class. Top of stack is the superclass. Below that is a string for
|
|
// the name of the class. Byte [arg] is the number of fields in the class.
|
|
OPCODE(CLASS, -1)
|
|
|
|
// Ends a class.
|
|
// Atm the stack contains the class and the ClassAttributes (or null).
|
|
OPCODE(END_CLASS, -2)
|
|
|
|
// Creates a foreign class. Top of stack is the superclass. Below that is a
|
|
// string for the name of the class.
|
|
OPCODE(FOREIGN_CLASS, -1)
|
|
|
|
// Define a method for symbol [arg]. The class receiving the method is popped
|
|
// off the stack, then the function defining the body is popped.
|
|
//
|
|
// If a foreign method is being defined, the "function" will be a string
|
|
// identifying the foreign method. Otherwise, it will be a function or
|
|
// closure.
|
|
OPCODE(METHOD_INSTANCE, -2)
|
|
|
|
// Define a method for symbol [arg]. The class whose metaclass will receive
|
|
// the method is popped off the stack, then the function defining the body is
|
|
// popped.
|
|
//
|
|
// If a foreign method is being defined, the "function" will be a string
|
|
// identifying the foreign method. Otherwise, it will be a function or
|
|
// closure.
|
|
OPCODE(METHOD_STATIC, -2)
|
|
|
|
// This is executed at the end of the module's body. Pushes NULL onto the stack
|
|
// as the "return value" of the import statement and stores the module as the
|
|
// most recently imported one.
|
|
OPCODE(END_MODULE, 1)
|
|
|
|
// Import a module whose name is the string stored at [arg] in the constant
|
|
// table.
|
|
//
|
|
// Pushes null onto the stack so that the fiber for the imported module can
|
|
// replace that with a dummy value when it returns. (Fibers always return a
|
|
// value when resuming a caller.)
|
|
OPCODE(IMPORT_MODULE, 1)
|
|
|
|
// Import a variable from the most recently imported module. The name of the
|
|
// variable to import is at [arg] in the constant table. Pushes the loaded
|
|
// variable's value.
|
|
OPCODE(IMPORT_VARIABLE, 1)
|
|
|
|
// This pseudo-instruction indicates the end of the bytecode. It should
|
|
// always be preceded by a `CODE_RETURN`, so is never actually executed.
|
|
OPCODE(END, 0)
|
|
// End file "wren_opcodes.h"
|
|
#undef OPCODE
|
|
};
|
|
|
|
#define INTERPRET_LOOP DISPATCH();
|
|
#define CASE_CODE(name) code_##name
|
|
|
|
#define DISPATCH() \
|
|
do \
|
|
{ \
|
|
DEBUG_TRACE_INSTRUCTIONS(); \
|
|
goto *dispatchTable[instruction = (Code)READ_BYTE()]; \
|
|
} while (false)
|
|
|
|
#else
|
|
|
|
#define INTERPRET_LOOP \
|
|
loop: \
|
|
DEBUG_TRACE_INSTRUCTIONS(); \
|
|
switch (instruction = (Code)READ_BYTE())
|
|
|
|
#define CASE_CODE(name) case CODE_##name
|
|
#define DISPATCH() goto loop
|
|
|
|
#endif
|
|
|
|
LOAD_FRAME();
|
|
|
|
Code instruction;
|
|
INTERPRET_LOOP
|
|
{
|
|
CASE_CODE(LOAD_LOCAL_0):
|
|
CASE_CODE(LOAD_LOCAL_1):
|
|
CASE_CODE(LOAD_LOCAL_2):
|
|
CASE_CODE(LOAD_LOCAL_3):
|
|
CASE_CODE(LOAD_LOCAL_4):
|
|
CASE_CODE(LOAD_LOCAL_5):
|
|
CASE_CODE(LOAD_LOCAL_6):
|
|
CASE_CODE(LOAD_LOCAL_7):
|
|
CASE_CODE(LOAD_LOCAL_8):
|
|
// The local slot index is the opcode's offset from CODE_LOAD_LOCAL_0.
PUSH(stackStart[instruction - CODE_LOAD_LOCAL_0]);
|
|
DISPATCH();
|
|
|
|
CASE_CODE(LOAD_LOCAL):
|
|
PUSH(stackStart[READ_BYTE()]);
|
|
DISPATCH();
|
|
|
|
CASE_CODE(LOAD_FIELD_THIS):
|
|
{
|
|
uint8_t field = READ_BYTE();
|
|
// The receiver lives in the first stack slot of the current call frame.
Value receiver = stackStart[0];
|
|
ASSERT(IS_INSTANCE(receiver), "Receiver should be instance.");
|
|
ObjInstance* instance = AS_INSTANCE(receiver);
|
|
ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field.");
|
|
PUSH(instance->fields[field]);
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(POP): DROP(); DISPATCH();
|
|
CASE_CODE(NULL): PUSH(NULL_VAL); DISPATCH();
|
|
CASE_CODE(FALSE): PUSH(FALSE_VAL); DISPATCH();
|
|
CASE_CODE(TRUE): PUSH(TRUE_VAL); DISPATCH();
|
|
|
|
CASE_CODE(STORE_LOCAL):
|
|
stackStart[READ_BYTE()] = PEEK();
|
|
DISPATCH();
|
|
|
|
CASE_CODE(CONSTANT):
|
|
PUSH(fn->constants.data[READ_SHORT()]);
|
|
DISPATCH();
|
|
|
|
{
|
|
// The opcodes for doing method and superclass calls share a lot of code.
|
|
// However, doing an if() test in the middle of the instruction sequence
|
|
// to handle the bit that is special to super calls makes the non-super
|
|
// call path noticeably slower.
|
|
//
|
|
// Instead, we do this old school using an explicit goto to share code for
|
|
// everything at the tail end of the call-handling code that is the same
|
|
// between normal and superclass calls.
|
|
int numArgs;
|
|
int symbol;
|
|
|
|
Value* args;
|
|
ObjClass* classObj;
|
|
|
|
Method* method;
|
|
|
|
CASE_CODE(CALL_0):
|
|
CASE_CODE(CALL_1):
|
|
CASE_CODE(CALL_2):
|
|
CASE_CODE(CALL_3):
|
|
CASE_CODE(CALL_4):
|
|
CASE_CODE(CALL_5):
|
|
CASE_CODE(CALL_6):
|
|
CASE_CODE(CALL_7):
|
|
CASE_CODE(CALL_8):
|
|
CASE_CODE(CALL_9):
|
|
CASE_CODE(CALL_10):
|
|
CASE_CODE(CALL_11):
|
|
CASE_CODE(CALL_12):
|
|
CASE_CODE(CALL_13):
|
|
CASE_CODE(CALL_14):
|
|
CASE_CODE(CALL_15):
|
|
CASE_CODE(CALL_16):
|
|
// Add one for the implicit receiver argument.
|
|
numArgs = instruction - CODE_CALL_0 + 1;
|
|
symbol = READ_SHORT();
|
|
|
|
// The receiver is the first argument.
|
|
args = fiber->stackTop - numArgs;
|
|
classObj = wrenGetClassInline(vm, args[0]);
|
|
goto completeCall;
|
|
|
|
CASE_CODE(SUPER_0):
|
|
CASE_CODE(SUPER_1):
|
|
CASE_CODE(SUPER_2):
|
|
CASE_CODE(SUPER_3):
|
|
CASE_CODE(SUPER_4):
|
|
CASE_CODE(SUPER_5):
|
|
CASE_CODE(SUPER_6):
|
|
CASE_CODE(SUPER_7):
|
|
CASE_CODE(SUPER_8):
|
|
CASE_CODE(SUPER_9):
|
|
CASE_CODE(SUPER_10):
|
|
CASE_CODE(SUPER_11):
|
|
CASE_CODE(SUPER_12):
|
|
CASE_CODE(SUPER_13):
|
|
CASE_CODE(SUPER_14):
|
|
CASE_CODE(SUPER_15):
|
|
CASE_CODE(SUPER_16):
|
|
// Add one for the implicit receiver argument.
|
|
numArgs = instruction - CODE_SUPER_0 + 1;
|
|
symbol = READ_SHORT();
|
|
|
|
// The receiver is the first argument.
|
|
args = fiber->stackTop - numArgs;
|
|
|
|
// The superclass is stored in a constant.
|
|
classObj = AS_CLASS(fn->constants.data[READ_SHORT()]);
|
|
goto completeCall;
|
|
|
|
completeCall:
|
|
// If the class's method table doesn't include the symbol, bail.
|
|
if (symbol >= classObj->methods.count ||
|
|
(method = &classObj->methods.data[symbol])->type == METHOD_NONE)
|
|
{
|
|
methodNotFound(vm, classObj, symbol);
|
|
RUNTIME_ERROR();
|
|
}
|
|
|
|
switch (method->type)
|
|
{
|
|
case METHOD_PRIMITIVE:
|
|
if (method->as.primitive(vm, args))
|
|
{
|
|
// The result is now in the first arg slot. Discard the other
|
|
// stack slots.
|
|
fiber->stackTop -= numArgs - 1;
|
|
} else {
|
|
// An error, fiber switch, or call frame change occurred.
|
|
STORE_FRAME();
|
|
|
|
// If we don't have a fiber to switch to, stop interpreting.
|
|
fiber = vm->fiber;
|
|
if (fiber == NULL) return WREN_RESULT_SUCCESS;
|
|
if (wrenHasError(fiber)) RUNTIME_ERROR();
|
|
LOAD_FRAME();
|
|
}
|
|
break;
|
|
|
|
case METHOD_FUNCTION_CALL:
|
|
// checkArity() sets the fiber's error when too few arguments were passed.
if (!checkArity(vm, args[0], numArgs)) {
|
|
RUNTIME_ERROR();
|
|
break;
|
|
}
|
|
|
|
STORE_FRAME();
|
|
method->as.primitive(vm, args);
|
|
LOAD_FRAME();
|
|
break;
|
|
|
|
case METHOD_FOREIGN:
|
|
callForeign(vm, fiber, method->as.foreign, numArgs);
|
|
if (wrenHasError(fiber)) RUNTIME_ERROR();
|
|
break;
|
|
|
|
case METHOD_BLOCK:
|
|
STORE_FRAME();
|
|
wrenCallFunction(vm, fiber, (ObjClosure*)method->as.closure, numArgs);
|
|
LOAD_FRAME();
|
|
break;
|
|
|
|
case METHOD_NONE:
|
|
UNREACHABLE();
|
|
break;
|
|
}
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(LOAD_UPVALUE):
|
|
{
|
|
ObjUpvalue** upvalues = frame->closure->upvalues;
|
|
PUSH(*upvalues[READ_BYTE()]->value);
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(STORE_UPVALUE):
|
|
{
|
|
ObjUpvalue** upvalues = frame->closure->upvalues;
|
|
*upvalues[READ_BYTE()]->value = PEEK();
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(LOAD_MODULE_VAR):
|
|
PUSH(fn->module->variables.data[READ_SHORT()]);
|
|
DISPATCH();
|
|
|
|
CASE_CODE(STORE_MODULE_VAR):
|
|
fn->module->variables.data[READ_SHORT()] = PEEK();
|
|
DISPATCH();
|
|
|
|
CASE_CODE(STORE_FIELD_THIS):
|
|
{
|
|
uint8_t field = READ_BYTE();
|
|
Value receiver = stackStart[0];
|
|
ASSERT(IS_INSTANCE(receiver), "Receiver should be instance.");
|
|
ObjInstance* instance = AS_INSTANCE(receiver);
|
|
ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field.");
|
|
instance->fields[field] = PEEK();
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(LOAD_FIELD):
|
|
{
|
|
uint8_t field = READ_BYTE();
|
|
Value receiver = POP();
|
|
ASSERT(IS_INSTANCE(receiver), "Receiver should be instance.");
|
|
ObjInstance* instance = AS_INSTANCE(receiver);
|
|
ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field.");
|
|
PUSH(instance->fields[field]);
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(STORE_FIELD):
|
|
{
|
|
uint8_t field = READ_BYTE();
|
|
Value receiver = POP();
|
|
ASSERT(IS_INSTANCE(receiver), "Receiver should be instance.");
|
|
ObjInstance* instance = AS_INSTANCE(receiver);
|
|
ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field.");
|
|
instance->fields[field] = PEEK();
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(JUMP):
|
|
{
|
|
uint16_t offset = READ_SHORT();
|
|
ip += offset;
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(LOOP):
|
|
{
|
|
// Jump back to the top of the loop.
|
|
uint16_t offset = READ_SHORT();
|
|
ip -= offset;
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(JUMP_IF):
|
|
{
|
|
uint16_t offset = READ_SHORT();
|
|
Value condition = POP();
|
|
|
|
if (wrenIsFalsyValue(condition)) ip += offset;
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(AND):
|
|
{
|
|
uint16_t offset = READ_SHORT();
|
|
Value condition = PEEK();
|
|
|
|
if (wrenIsFalsyValue(condition))
|
|
{
|
|
// Short-circuit the right hand side.
|
|
ip += offset;
|
|
}
|
|
else
|
|
{
|
|
// Discard the condition and evaluate the right hand side.
|
|
DROP();
|
|
}
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(OR):
|
|
{
|
|
uint16_t offset = READ_SHORT();
|
|
Value condition = PEEK();
|
|
|
|
if (wrenIsFalsyValue(condition))
|
|
{
|
|
// Discard the condition and evaluate the right hand side.
|
|
DROP();
|
|
}
|
|
else
|
|
{
|
|
// Short-circuit the right hand side.
|
|
ip += offset;
|
|
}
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(CLOSE_UPVALUE):
|
|
// Close the upvalue for the local if we have one.
|
|
closeUpvalues(fiber, fiber->stackTop - 1);
|
|
DROP();
|
|
DISPATCH();
|
|
|
|
CASE_CODE(RETURN):
|
|
{
|
|
Value result = POP();
|
|
fiber->numFrames--;
|
|
|
|
// Close any upvalues still in scope.
|
|
closeUpvalues(fiber, stackStart);
|
|
|
|
// If the fiber is complete, end it.
|
|
if (fiber->numFrames == 0)
|
|
{
|
|
// See if there's another fiber to return to. If not, we're done.
|
|
if (fiber->caller == NULL)
|
|
{
|
|
// Store the final result value at the beginning of the stack so the
|
|
// C API can get it.
|
|
fiber->stack[0] = result;
|
|
fiber->stackTop = fiber->stack + 1;
|
|
return WREN_RESULT_SUCCESS;
|
|
}
|
|
|
|
// Detach the finished fiber from its caller and continue in the caller.
ObjFiber* resumingFiber = fiber->caller;
|
|
fiber->caller = NULL;
|
|
fiber = resumingFiber;
|
|
vm->fiber = resumingFiber;
|
|
|
|
// Store the result in the resuming fiber.
|
|
fiber->stackTop[-1] = result;
|
|
}
|
|
else
|
|
{
|
|
// Store the result of the block in the first slot, which is where the
|
|
// caller expects it.
|
|
stackStart[0] = result;
|
|
|
|
// Discard the stack slots for the call frame (leaving one slot for the
|
|
// result).
|
|
fiber->stackTop = frame->stackStart + 1;
|
|
}
|
|
|
|
LOAD_FRAME();
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(CONSTRUCT):
|
|
ASSERT(IS_CLASS(stackStart[0]), "'this' should be a class.");
|
|
stackStart[0] = wrenNewInstance(vm, AS_CLASS(stackStart[0]));
|
|
DISPATCH();
|
|
|
|
CASE_CODE(FOREIGN_CONSTRUCT):
|
|
ASSERT(IS_CLASS(stackStart[0]), "'this' should be a class.");
|
|
createForeign(vm, fiber, stackStart);
|
|
if (wrenHasError(fiber)) RUNTIME_ERROR();
|
|
DISPATCH();
|
|
|
|
CASE_CODE(CLOSURE):
|
|
{
|
|
// Create the closure and push it on the stack before creating upvalues
|
|
// so that it doesn't get collected.
|
|
ObjFn* function = AS_FN(fn->constants.data[READ_SHORT()]);
|
|
ObjClosure* closure = wrenNewClosure(vm, function);
|
|
PUSH(OBJ_VAL(closure));
|
|
|
|
// Capture upvalues, if any.
|
|
for (int i = 0; i < function->numUpvalues; i++)
|
|
{
|
|
uint8_t isLocal = READ_BYTE();
|
|
uint8_t index = READ_BYTE();
|
|
if (isLocal)
|
|
{
|
|
// Make an new upvalue to close over the parent's local variable.
|
|
closure->upvalues[i] = captureUpvalue(vm, fiber,
|
|
frame->stackStart + index);
|
|
}
|
|
else
|
|
{
|
|
// Use the same upvalue as the current call frame.
|
|
closure->upvalues[i] = frame->closure->upvalues[index];
|
|
}
|
|
}
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(END_CLASS):
|
|
{
|
|
endClass(vm);
|
|
if (wrenHasError(fiber)) RUNTIME_ERROR();
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(CLASS):
|
|
{
|
|
createClass(vm, READ_BYTE(), NULL);
|
|
if (wrenHasError(fiber)) RUNTIME_ERROR();
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(FOREIGN_CLASS):
|
|
{
|
|
createClass(vm, -1, fn->module);
|
|
if (wrenHasError(fiber)) RUNTIME_ERROR();
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(METHOD_INSTANCE):
|
|
CASE_CODE(METHOD_STATIC):
|
|
{
|
|
uint16_t symbol = READ_SHORT();
|
|
// The class is on top of the stack with the method body just below it. Both
// are dropped only after the method has been bound without error.
ObjClass* classObj = AS_CLASS(PEEK());
|
|
Value method = PEEK2();
|
|
bindMethod(vm, instruction, symbol, fn->module, classObj, method);
|
|
if (wrenHasError(fiber)) RUNTIME_ERROR();
|
|
DROP();
|
|
DROP();
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(END_MODULE):
|
|
{
|
|
vm->lastModule = fn->module;
|
|
PUSH(NULL_VAL);
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(IMPORT_MODULE):
|
|
{
|
|
// Make a slot on the stack for the module's fiber to place the return
|
|
// value. It will be popped after this fiber is resumed. Store the
|
|
// imported module's closure in the slot in case a GC happens when
|
|
// invoking the closure.
|
|
PUSH(importModule(vm, fn->constants.data[READ_SHORT()]));
|
|
if (wrenHasError(fiber)) RUNTIME_ERROR();
|
|
|
|
// If we get a closure, call it to execute the module body.
|
|
if (IS_CLOSURE(PEEK()))
|
|
{
|
|
STORE_FRAME();
|
|
ObjClosure* closure = AS_CLOSURE(PEEK());
|
|
wrenCallFunction(vm, fiber, closure, 1);
|
|
LOAD_FRAME();
|
|
}
|
|
else
|
|
{
|
|
// The module has already been loaded. Remember it so we can import
|
|
// variables from it if needed.
|
|
vm->lastModule = AS_MODULE(PEEK());
|
|
}
|
|
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(IMPORT_VARIABLE):
|
|
{
|
|
Value variable = fn->constants.data[READ_SHORT()];
|
|
ASSERT(vm->lastModule != NULL, "Should have already imported module.");
|
|
// getModuleVariable() sets the fiber's error when the name is not found.
Value result = getModuleVariable(vm, vm->lastModule, variable);
|
|
if (wrenHasError(fiber)) RUNTIME_ERROR();
|
|
|
|
PUSH(result);
|
|
DISPATCH();
|
|
}
|
|
|
|
CASE_CODE(END):
|
|
// A CODE_END should always be preceded by a CODE_RETURN. If we get here,
|
|
// the compiler generated wrong code.
|
|
UNREACHABLE();
|
|
}
|
|
|
|
// We should only exit this function from an explicit return from CODE_RETURN
|
|
// or a runtime error.
|
|
UNREACHABLE();
|
|
return WREN_RESULT_RUNTIME_ERROR;
|
|
|
|
// NOTE(review): only READ_BYTE and READ_SHORT are undefined here; the other
// helper macros defined inside this function stay defined past its end.
#undef READ_BYTE
|
|
#undef READ_SHORT
|
|
}
|
|
|
|
// Creates a handle wrapping a small stub closure that invokes the method with
// the given [signature] on whatever receiver and arguments are on the stack.
//
// The parameter count is derived from the signature text by counting "_"
// placeholders, both in a trailing "(...)" list and in a leading "[...]"
// subscript. The returned handle is linked into the VM's handle list by
// wrenMakeHandle().
WrenHandle* wrenMakeCallHandle(WrenVM* vm, const char* signature)
{
  ASSERT(signature != NULL, "Signature cannot be NULL.");

  int signatureLength = (int)strlen(signature);
  ASSERT(signatureLength > 0, "Signature cannot be empty.");

  int numParams = 0;

  // Parameters in a trailing parenthesized list, scanned right to left until
  // the opening parenthesis.
  int lastIndex = signatureLength - 1;
  if (signature[lastIndex] == ')')
  {
    int i = lastIndex;
    while (i > 0 && signature[i] != '(')
    {
      if (signature[i] == '_') numParams++;
      i--;
    }
  }

  // Parameters in a leading subscript, scanned left to right until the closing
  // bracket.
  if (signature[0] == '[')
  {
    int i = 0;
    while (i < signatureLength && signature[i] != ']')
    {
      if (signature[i] == '_') numParams++;
      i++;
    }
  }

  // Make sure the signature has a symbol in the method table.
  int method = wrenSymbolTableEnsure(vm, &vm->methodNames,
                                     signature, signatureLength);

  // The stub assumes the receiver and arguments are already on the stack, so
  // it is created with one slot per parameter plus one for the receiver.
  ObjFn* fn = wrenNewFunction(vm, NULL, numParams + 1);

  // Put the function in a handle first, then swap in the closure that wraps
  // it, so that it doesn't get collected while the bytecode is filled in.
  WrenHandle* value = wrenMakeHandle(vm, OBJ_VAL(fn));
  value->value = OBJ_VAL(wrenNewClosure(vm, fn));

  // Bytecode: CALL_<numParams> <symbol high byte> <symbol low byte>, RETURN,
  // END.
  wrenByteBufferWrite(vm, &fn->code, (uint8_t)(CODE_CALL_0 + numParams));
  wrenByteBufferWrite(vm, &fn->code, (method >> 8) & 0xff);
  wrenByteBufferWrite(vm, &fn->code, method & 0xff);
  wrenByteBufferWrite(vm, &fn->code, CODE_RETURN);
  wrenByteBufferWrite(vm, &fn->code, CODE_END);

  // One source-line entry (all zero) for each of the five bytes written above.
  wrenIntBufferFill(vm, &fn->debug->sourceLines, 0, 5);
  wrenFunctionBindName(vm, fn, signature, signatureLength);

  return value;
}
|
|
|
|
// Invokes the stub closure held by [method] (see wrenMakeCallHandle()) on the
// receiver and arguments the host has already placed on the API stack, then
// runs the interpreter until the call finishes.
//
// Returns the interpreter's result. If the VM still has a fiber afterwards,
// the API stack is pointed back at the start of that fiber's stack so the host
// can read the return value.
WrenInterpretResult wrenCall(WrenVM* vm, WrenHandle* method)
{
  ASSERT(method != NULL, "Method cannot be NULL.");
  ASSERT(IS_CLOSURE(method->value), "Method must be a method handle.");
  ASSERT(vm->fiber != NULL, "Must set up arguments for call first.");
  ASSERT(vm->apiStack != NULL, "Must set up arguments for call first.");
  ASSERT(vm->fiber->numFrames == 0, "Can not call from a foreign method.");

  ObjClosure* closure = AS_CLOSURE(method->value);

  ASSERT(vm->fiber->stackTop - vm->fiber->stack >= closure->fn->arity,
         "Stack must have enough arguments for method.");

  // wrenCall() owns the stack from here on. A non-null API stack is how
  // re-entrant foreign calls are detected, so it has to be cleared before any
  // foreign method can run inside this call.
  vm->apiStack = NULL;

  // Drop any extra temporary slots: the stub function is taken to have exactly
  // one slot per argument.
  ObjFiber* fiber = vm->fiber;
  fiber->stackTop = fiber->stack + closure->fn->maxSlots;

  wrenCallFunction(vm, fiber, closure, 0);
  WrenInterpretResult result = runInterpreter(vm, vm->fiber);

  // Unless the call aborted and left no fiber, expose the start of the stack
  // so the host can fetch the call's return value.
  if (vm->fiber != NULL)
  {
    vm->apiStack = vm->fiber->stack;
  }

  return result;
}
|
|
|
|
// Allocates a new handle holding [value] and inserts it at the front of the
// VM's doubly linked list of handles. Returns the new handle.
WrenHandle* wrenMakeHandle(WrenVM* vm, Value value)
{
  // An object value is held as a temporary root for the duration of the
  // allocation (presumably because allocating may trigger a collection).
  bool isObject = IS_OBJ(value);
  if (isObject) wrenPushRoot(vm, AS_OBJ(value));

  WrenHandle* handle = ALLOCATE(vm, WrenHandle);
  handle->value = value;

  if (isObject) wrenPopRoot(vm);

  // Splice the handle in as the new head of the list.
  WrenHandle* oldHead = vm->handles;
  handle->prev = NULL;
  handle->next = oldHead;
  if (oldHead != NULL) oldHead->prev = handle;
  vm->handles = handle;

  return handle;
}
|
|
|
|
// Unlinks [handle] from the VM's list of handles and frees it.
void wrenReleaseHandle(WrenVM* vm, WrenHandle* handle)
{
  ASSERT(handle != NULL, "Handle cannot be NULL.");

  WrenHandle* before = handle->prev;
  WrenHandle* after = handle->next;

  // When the head is being released, the list now starts at its successor.
  if (vm->handles == handle) vm->handles = after;

  // Join the neighbours to each other.
  if (before != NULL) before->next = after;
  if (after != NULL) after->prev = before;

  // Scrub the handle before freeing it. Not strictly needed, but it makes
  // stale handles easier to spot when debugging.
  handle->prev = NULL;
  handle->next = NULL;
  handle->value = NULL_VAL;
  DEALLOCATE(vm, handle);
}
|
|
|
|
// Compiles [source] in the context of [module] and runs it in a new fiber.
//
// Returns WREN_RESULT_COMPILE_ERROR if compilation yields no closure;
// otherwise returns whatever the interpreter returns.
WrenInterpretResult wrenInterpret(WrenVM* vm, const char* module,
                                  const char* source)
{
  ObjClosure* compiled = wrenCompileSource(vm, module, source, false, true);
  if (compiled == NULL)
  {
    return WREN_RESULT_COMPILE_ERROR;
  }

  // Root the closure while the fiber that will run it is being created.
  wrenPushRoot(vm, (Obj*)compiled);
  ObjFiber* mainFiber = wrenNewFiber(vm, compiled);
  wrenPopRoot(vm); // compiled.

  vm->apiStack = NULL;

  WrenInterpretResult outcome = runInterpreter(vm, mainFiber);
  return outcome;
}
|
|
|
|
// Compiles [source] inside the module named [module] and returns the resulting
// closure (whatever compileInModule() returns).
//
// When [module] is NULL the module name passed on is the null value. Otherwise
// a string object is created for the name and kept rooted for the duration of
// the compile.
ObjClosure* wrenCompileSource(WrenVM* vm, const char* module, const char* source,
                              bool isExpression, bool printErrors)
{
  const bool hasName = (module != NULL);

  Value nameValue = NULL_VAL;
  if (hasName)
  {
    nameValue = wrenNewString(vm, module);
    wrenPushRoot(vm, AS_OBJ(nameValue));
  }

  ObjClosure* compiled = compileInModule(vm, nameValue, source,
                                         isExpression, printErrors);

  if (hasName)
  {
    wrenPopRoot(vm); // nameValue.
  }

  return compiled;
}
|
|
|
|
// Fetches the value of [variableName] from the already loaded module named
// [moduleName].
//
// If the module is not loaded, or (via getModuleVariable()) the variable does
// not exist, an error is stored on the current fiber and null is returned.
Value wrenGetModuleVariable(WrenVM* vm, Value moduleName, Value variableName)
{
  ObjModule* loaded = getModule(vm, moduleName);
  if (loaded != NULL)
  {
    return getModuleVariable(vm, loaded, variableName);
  }

  vm->fiber->error = wrenStringFormat(vm, "Module '@' is not loaded.",
                                      moduleName);
  return NULL_VAL;
}
|
|
|
|
// Looks up the top-level variable [name] in [module] and returns its current
// value.
//
// Returns null when [module] has no variable with that name. Callers that need
// to tell "missing" apart from "holds null" must check the symbol table
// themselves.
Value wrenFindVariable(WrenVM* vm, ObjModule* module, const char* name)
{
  int symbol = wrenSymbolTableFind(&module->variableNames, name, strlen(name));

  // A failed lookup yields a negative symbol (-1). Using it as an index would
  // read before the start of the variable buffer, so bail out instead.
  if (symbol < 0) return NULL_VAL;

  return module->variables.data[symbol];
}
|
|
|
|
// Implicitly declares a top-level variable called [name] in [module].
//
// Returns the new variable's symbol (the result of wrenSymbolTableAdd()), or
// -2 when the module already holds MAX_MODULE_VARS variables.
int wrenDeclareVariable(WrenVM* vm, ObjModule* module, const char* name,
                        size_t length, int line)
{
  if (module->variables.count == MAX_MODULE_VARS)
  {
    return -2;
  }

  // Until it is really defined, an implicitly declared variable holds the line
  // number of its first use, which is used later to report an error on the
  // right line.
  Value firstUseLine = NUM_VAL(line);
  wrenValueBufferWrite(vm, &module->variables, firstUseLine);

  int symbol = wrenSymbolTableAdd(vm, &module->variableNames, name, length);
  return symbol;
}
|
|
|
|
// Defines the top-level variable [name] in [module] with [value].
//
// Returns:
//   - the variable's symbol when it is new, or when it was only implicitly
//     declared and [name] is not a local name;
//   - -1 when the variable is already explicitly defined;
//   - -2 when the module already holds MAX_MODULE_VARS variables;
//   - -3 when an implicitly declared variable with a local name is defined
//     (it was referenced before this definition).
//
// When an implicit declaration is replaced and [line] is not NULL, [*line]
// receives the line number stored by that declaration.
int wrenDefineVariable(WrenVM* vm, ObjModule* module, const char* name,
                       size_t length, Value value, int* line)
{
  if (module->variables.count == MAX_MODULE_VARS) return -2;

  // Keep an object value rooted while the tables below may grow.
  const bool rooted = IS_OBJ(value);
  if (rooted) wrenPushRoot(vm, AS_OBJ(value));

  int result = wrenSymbolTableFind(&module->variableNames, name, length);

  if (result == -1)
  {
    // Never seen before: append both the name and the value.
    result = wrenSymbolTableAdd(vm, &module->variableNames, name, length);
    wrenValueBufferWrite(vm, &module->variables, value);
  }
  else
  {
    Value* existing = &module->variables.data[result];

    if (IS_NUM(*existing))
    {
      // An implicitly declared variable always holds a number (the line of its
      // first use). This is its real definition.
      if (line) *line = (int)AS_NUM(*existing);
      *existing = value;

      // A local name that was used before being defined is reported as an
      // error by the caller.
      if (wrenIsLocalName(name)) result = -3;
    }
    else
    {
      // Already explicitly defined.
      result = -1;
    }
  }

  if (rooted) wrenPopRoot(vm);

  return result;
}
|
|
|
|
// TODO: Inline?
|
|
void wrenPushRoot(WrenVM* vm, Obj* obj)
|
|
{
|
|
ASSERT(obj != NULL, "Can't root NULL.");
|
|
ASSERT(vm->numTempRoots < WREN_MAX_TEMP_ROOTS, "Too many temporary roots.");
|
|
|
|
vm->tempRoots[vm->numTempRoots++] = obj;
|
|
}
|
|
|
|
// Drops the most recently pushed temporary root.
void wrenPopRoot(WrenVM* vm)
{
  ASSERT(vm->numTempRoots > 0, "No temporary roots to release.");
  vm->numTempRoots -= 1;
}
|
|
|
|
// Returns the number of slots currently available on the API stack, or 0 when
// no API stack is set up.
int wrenGetSlotCount(WrenVM* vm)
{
  return (vm->apiStack == NULL)
      ? 0
      : (int)(vm->fiber->stackTop - vm->apiStack);
}
|
|
|
|
// Makes sure the API stack has at least [numSlots] slots, creating a fiber for
// the API to use when none is accessible and growing the fiber's stack when
// required. Never shrinks the slot count.
void wrenEnsureSlots(WrenVM* vm, int numSlots)
{
  // Without an API stack there is no fiber to work with, so make one.
  if (vm->apiStack == NULL)
  {
    vm->fiber = wrenNewFiber(vm, NULL);
    vm->apiStack = vm->fiber->stack;
  }

  int available = (int)(vm->fiber->stackTop - vm->apiStack);
  if (available < numSlots)
  {
    // The API stack may start part-way into the fiber's stack, so the required
    // capacity is that offset plus the requested slots.
    int baseOffset = (int)(vm->apiStack - vm->fiber->stack);
    wrenEnsureStack(vm, vm->fiber, baseOffset + numSlots);

    vm->fiber->stackTop = vm->apiStack + numSlots;
  }
}
|
|
|
|
// Ensures that [slot] is a valid index into the API's stack of slots.
|
|
static void validateApiSlot(WrenVM* vm, int slot)
|
|
{
|
|
ASSERT(slot >= 0, "Slot cannot be negative.");
|
|
ASSERT(slot < wrenGetSlotCount(vm), "Not that many slots.");
|
|
}
|
|
|
|
// Gets the type of the object in [slot].
|
|
WrenType wrenGetSlotType(WrenVM* vm, int slot)
|
|
{
|
|
validateApiSlot(vm, slot);
|
|
if (IS_BOOL(vm->apiStack[slot])) return WREN_TYPE_BOOL;
|
|
if (IS_NUM(vm->apiStack[slot])) return WREN_TYPE_NUM;
|
|
if (IS_FOREIGN(vm->apiStack[slot])) return WREN_TYPE_FOREIGN;
|
|
if (IS_LIST(vm->apiStack[slot])) return WREN_TYPE_LIST;
|
|
if (IS_MAP(vm->apiStack[slot])) return WREN_TYPE_MAP;
|
|
if (IS_NULL(vm->apiStack[slot])) return WREN_TYPE_NULL;
|
|
if (IS_STRING(vm->apiStack[slot])) return WREN_TYPE_STRING;
|
|
|
|
return WREN_TYPE_UNKNOWN;
|
|
}
|
|
|
|
// Reads the boolean stored in [slot]. The slot must hold a bool.
bool wrenGetSlotBool(WrenVM* vm, int slot)
{
  validateApiSlot(vm, slot);

  Value value = vm->apiStack[slot];
  ASSERT(IS_BOOL(value), "Slot must hold a bool.");

  return AS_BOOL(value);
}
|
|
|
|
// Returns a pointer to the bytes of the string in [slot] and writes the
// string's length to [length]. The slot must hold a string.
const char* wrenGetSlotBytes(WrenVM* vm, int slot, int* length)
{
  validateApiSlot(vm, slot);

  Value value = vm->apiStack[slot];
  ASSERT(IS_STRING(value), "Slot must hold a string.");

  ObjString* text = AS_STRING(value);
  *length = text->length;
  return text->value;
}
|
|
|
|
// Reads the number stored in [slot]. The slot must hold a number.
double wrenGetSlotDouble(WrenVM* vm, int slot)
{
  validateApiSlot(vm, slot);

  Value value = vm->apiStack[slot];
  ASSERT(IS_NUM(value), "Slot must hold a number.");

  return AS_NUM(value);
}
|
|
|
|
// Returns the data block of the foreign instance stored in [slot]. The slot
// must hold a foreign instance.
void* wrenGetSlotForeign(WrenVM* vm, int slot)
{
  validateApiSlot(vm, slot);

  Value value = vm->apiStack[slot];
  ASSERT(IS_FOREIGN(value),
         "Slot must hold a foreign instance.");

  ObjForeign* instance = AS_FOREIGN(value);
  return instance->data;
}
|
|
|
|
// Returns the C string for the string stored in [slot]. The slot must hold a
// string.
const char* wrenGetSlotString(WrenVM* vm, int slot)
{
  validateApiSlot(vm, slot);

  Value value = vm->apiStack[slot];
  ASSERT(IS_STRING(value), "Slot must hold a string.");

  return AS_CSTRING(value);
}
|
|
|
|
// Makes a new handle for the value currently stored in [slot].
WrenHandle* wrenGetSlotHandle(WrenVM* vm, int slot)
{
  validateApiSlot(vm, slot);

  Value value = vm->apiStack[slot];
  return wrenMakeHandle(vm, value);
}
|
|
|
|
// Stores [value] in [slot] in the foreign call stack.
|
|
static void setSlot(WrenVM* vm, int slot, Value value)
|
|
{
|
|
validateApiSlot(vm, slot);
|
|
vm->apiStack[slot] = value;
|
|
}
|
|
|
|
// Stores the boolean [value] in [slot].
void wrenSetSlotBool(WrenVM* vm, int slot, bool value)
{
  Value boxed = BOOL_VAL(value);
  setSlot(vm, slot, boxed);
}
|
|
|
|
// Creates a new string from the [length] bytes at [bytes] and stores it in
// [slot]. [bytes] must not be NULL.
void wrenSetSlotBytes(WrenVM* vm, int slot, const char* bytes, size_t length)
{
  ASSERT(bytes != NULL, "Byte array cannot be NULL.");

  Value string = wrenNewStringLength(vm, bytes, length);
  setSlot(vm, slot, string);
}
|
|
|
|
// Stores the number [value] in [slot].
void wrenSetSlotDouble(WrenVM* vm, int slot, double value)
{
  Value boxed = NUM_VAL(value);
  setSlot(vm, slot, boxed);
}
|
|
|
|
// Creates a new instance of the foreign class held in [classSlot] with
// [size] bytes of data, stores it in [slot], and returns a pointer to that
// data. [classSlot] must hold a foreign class (one whose numFields is -1).
void* wrenSetSlotNewForeign(WrenVM* vm, int slot, int classSlot, size_t size)
{
  validateApiSlot(vm, slot);
  validateApiSlot(vm, classSlot);

  Value classValue = vm->apiStack[classSlot];
  ASSERT(IS_CLASS(classValue), "Slot must hold a class.");

  ObjClass* foreignClass = AS_CLASS(classValue);
  ASSERT(foreignClass->numFields == -1, "Class must be a foreign class.");

  ObjForeign* instance = wrenNewForeign(vm, foreignClass, size);
  vm->apiStack[slot] = OBJ_VAL(instance);

  return (void*)instance->data;
}
|
|
|
|
// Stores a newly created list with zero elements in [slot].
void wrenSetSlotNewList(WrenVM* vm, int slot)
{
  ObjList* list = wrenNewList(vm, 0);
  setSlot(vm, slot, OBJ_VAL(list));
}
|
|
|
|
// Stores a newly created map in [slot].
void wrenSetSlotNewMap(WrenVM* vm, int slot)
{
  Value map = OBJ_VAL(wrenNewMap(vm));
  setSlot(vm, slot, map);
}
|
|
|
|
// Stores null in [slot].
void wrenSetSlotNull(WrenVM* vm, int slot)
{
  Value nothing = NULL_VAL;
  setSlot(vm, slot, nothing);
}
|
|
|
|
// Creates a new string from the C string [text] and stores it in [slot].
// [text] must not be NULL.
void wrenSetSlotString(WrenVM* vm, int slot, const char* text)
{
  ASSERT(text != NULL, "String cannot be NULL.");

  Value string = wrenNewString(vm, text);
  setSlot(vm, slot, string);
}
|
|
|
|
// Stores the value referenced by [handle] in [slot]. [handle] must not be
// NULL.
void wrenSetSlotHandle(WrenVM* vm, int slot, WrenHandle* handle)
{
  ASSERT(handle != NULL, "Handle cannot be NULL.");

  Value held = handle->value;
  setSlot(vm, slot, held);
}
|
|
|
|
// Returns the number of elements in the list stored in [slot]. The slot must
// hold a list.
int wrenGetListCount(WrenVM* vm, int slot)
{
  validateApiSlot(vm, slot);
  ASSERT(IS_LIST(vm->apiStack[slot]), "Slot must hold a list.");

  ObjList* list = AS_LIST(vm->apiStack[slot]);
  return list->elements.count;
}
|
|
|
|
// Copies element [index] of the list in [listSlot] into [elementSlot].
// [index] is resolved through wrenValidateIndex() and must be in bounds.
void wrenGetListElement(WrenVM* vm, int listSlot, int index, int elementSlot)
{
  validateApiSlot(vm, listSlot);
  validateApiSlot(vm, elementSlot);
  ASSERT(IS_LIST(vm->apiStack[listSlot]), "Slot must hold a list.");

  ObjList* list = AS_LIST(vm->apiStack[listSlot]);

  uint32_t position = wrenValidateIndex(list->elements.count, index);
  ASSERT(position != UINT32_MAX, "Index out of bounds.");

  vm->apiStack[elementSlot] = list->elements.data[position];
}
|
|
|
|
// Overwrites element [index] of the list in [listSlot] with the value in
// [elementSlot]. [index] is resolved through wrenValidateIndex() and must be
// in bounds.
void wrenSetListElement(WrenVM* vm, int listSlot, int index, int elementSlot)
{
  validateApiSlot(vm, listSlot);
  validateApiSlot(vm, elementSlot);
  ASSERT(IS_LIST(vm->apiStack[listSlot]), "Slot must hold a list.");

  ObjList* target = AS_LIST(vm->apiStack[listSlot]);

  uint32_t position = wrenValidateIndex(target->elements.count, index);
  ASSERT(position != UINT32_MAX, "Index out of bounds.");

  Value* cell = &target->elements.data[position];
  *cell = vm->apiStack[elementSlot];
}
|
|
|
|
// Inserts the value in [elementSlot] into the list in [listSlot] at [index].
//
// A negative [index] counts back from the end. Because an insert may target
// the position one past the last element, -1 resolves to the element count
// (an append). After that adjustment the index must lie in 0..count
// inclusive.
void wrenInsertInList(WrenVM* vm, int listSlot, int index, int elementSlot)
{
  validateApiSlot(vm, listSlot);
  validateApiSlot(vm, elementSlot);
  ASSERT(IS_LIST(vm->apiStack[listSlot]), "Must insert into a list.");

  ObjList* list = AS_LIST(vm->apiStack[listSlot]);

  // Negative indices count from the end.
  // We don't use wrenValidateIndex here because insert allows 1 past the end.
  if (index < 0) index = list->elements.count + 1 + index;

  // A negative index whose magnitude exceeds count + 1 is still negative at
  // this point, so the lower bound is checked as well as the upper one.
  ASSERT(index >= 0, "Index out of bounds.");
  ASSERT(index <= list->elements.count, "Index out of bounds.");

  wrenListInsert(vm, list, vm->apiStack[elementSlot], index);
}
|
|
|
|
// Returns the entry count of the map stored in [slot]. The slot must hold a
// map.
int wrenGetMapCount(WrenVM* vm, int slot)
{
  validateApiSlot(vm, slot);

  Value value = vm->apiStack[slot];
  ASSERT(IS_MAP(value), "Slot must hold a map.");

  return AS_MAP(value)->count;
}
|
|
|
|
// Reports whether the map in [mapSlot] has an entry for the key in
// [keySlot]. Returns false when validateKey() rejects the key.
bool wrenGetMapContainsKey(WrenVM* vm, int mapSlot, int keySlot)
{
  validateApiSlot(vm, mapSlot);
  validateApiSlot(vm, keySlot);
  ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Slot must hold a map.");

  Value key = vm->apiStack[keySlot];
  ASSERT(wrenMapIsValidKey(key), "Key must be a value type");
  if (!validateKey(vm, key)) return false;

  // A lookup that yields the undefined value means the key is absent.
  Value found = wrenMapGet(AS_MAP(vm->apiStack[mapSlot]), key);
  return !IS_UNDEFINED(found);
}
|
|
|
|
// Looks up the key in [keySlot] in the map in [mapSlot] and stores the
// result in [valueSlot]. A missing key stores null.
void wrenGetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot)
{
  validateApiSlot(vm, mapSlot);
  validateApiSlot(vm, keySlot);
  validateApiSlot(vm, valueSlot);
  ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Slot must hold a map.");

  Value found = wrenMapGet(AS_MAP(vm->apiStack[mapSlot]),
                           vm->apiStack[keySlot]);

  // Never expose the internal "undefined" marker to the caller.
  vm->apiStack[valueSlot] = IS_UNDEFINED(found) ? NULL_VAL : found;
}
|
|
|
|
// Associates the key in [keySlot] with the value in [valueSlot] in the map
// stored in [mapSlot]. Does nothing further when validateKey() rejects the
// key.
void wrenSetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot)
{
  validateApiSlot(vm, mapSlot);
  validateApiSlot(vm, keySlot);
  validateApiSlot(vm, valueSlot);
  ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Must insert into a map.");

  Value key = vm->apiStack[keySlot];
  ASSERT(wrenMapIsValidKey(key), "Key must be a value type");

  if (validateKey(vm, key))
  {
    ObjMap* target = AS_MAP(vm->apiStack[mapSlot]);
    wrenMapSet(vm, target, key, vm->apiStack[valueSlot]);
  }
}
|
|
|
|
// Removes the entry for the key in [keySlot] from the map in [mapSlot] and
// stores the result of wrenMapRemoveKey() in [removedValueSlot]. When
// validateKey() rejects the key, nothing is removed and [removedValueSlot]
// is left untouched.
void wrenRemoveMapValue(WrenVM* vm, int mapSlot, int keySlot,
                        int removedValueSlot)
{
  validateApiSlot(vm, mapSlot);
  validateApiSlot(vm, keySlot);
  ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Slot must hold a map.");

  Value key = vm->apiStack[keySlot];
  if (validateKey(vm, key))
  {
    ObjMap* source = AS_MAP(vm->apiStack[mapSlot]);
    Value taken = wrenMapRemoveKey(vm, source, key);
    setSlot(vm, removedValueSlot, taken);
  }
}
|
|
|
|
// Looks up the variable [name] in the module named [module] and stores its
// value in [slot]. Both the module and the variable must exist.
void wrenGetVariable(WrenVM* vm, const char* module, const char* name,
                     int slot)
{
  ASSERT(module != NULL, "Module cannot be NULL.");
  ASSERT(name != NULL, "Variable name cannot be NULL.");

  // Keep the temporary name string rooted while the module is looked up.
  Value moduleName = wrenStringFormat(vm, "$", module);
  wrenPushRoot(vm, AS_OBJ(moduleName));

  ObjModule* moduleObj = getModule(vm, moduleName);
  ASSERT(moduleObj != NULL, "Could not find module.");

  wrenPopRoot(vm); // moduleName.

  size_t nameLength = strlen(name);
  int symbol = wrenSymbolTableFind(&moduleObj->variableNames,
                                   name, nameLength);
  ASSERT(symbol != -1, "Could not find variable.");

  setSlot(vm, slot, moduleObj->variables.data[symbol]);
}
|
|
|
|
// Reports whether the module named [module] has a variable called [name].
//
// The module is expected to exist (this is asserted). If the assertion does
// not stop execution and the module is missing, the function returns false
// rather than dereferencing a NULL module.
bool wrenHasVariable(WrenVM* vm, const char* module, const char* name)
{
  ASSERT(module != NULL, "Module cannot be NULL.");
  ASSERT(name != NULL, "Variable name cannot be NULL.");

  // Keep the temporary name string rooted while the module is looked up.
  Value moduleName = wrenStringFormat(vm, "$", module);
  wrenPushRoot(vm, AS_OBJ(moduleName));

  //We don't use wrenHasModule since we want to use the module object.
  ObjModule* moduleObj = getModule(vm, moduleName);
  ASSERT(moduleObj != NULL, "Could not find module.");

  wrenPopRoot(vm); // moduleName.

  // A module that doesn't exist has no variables.
  if (moduleObj == NULL) return false;

  int variableSlot = wrenSymbolTableFind(&moduleObj->variableNames,
                                         name, strlen(name));

  return variableSlot != -1;
}
|
|
|
|
// Reports whether getModule() finds a module named [module].
bool wrenHasModule(WrenVM* vm, const char* module)
{
  ASSERT(module != NULL, "Module cannot be NULL.");

  // Keep the temporary name string rooted for the duration of the lookup.
  Value moduleName = wrenStringFormat(vm, "$", module);
  wrenPushRoot(vm, AS_OBJ(moduleName));

  bool found = getModule(vm, moduleName) != NULL;

  wrenPopRoot(vm); // moduleName.

  return found;
}
|
|
|
|
// Sets the current fiber's error to the value stored in [slot].
void wrenAbortFiber(WrenVM* vm, int slot)
{
  validateApiSlot(vm, slot);

  Value error = vm->apiStack[slot];
  vm->fiber->error = error;
}
|
|
|
|
// Returns the user data pointer stored in the VM's configuration.
void* wrenGetUserData(WrenVM* vm)
{
  void* userData = vm->config.userData;
  return userData;
}
|
|
|
|
// Replaces the user data pointer stored in the VM's configuration with
// [userData].
void wrenSetUserData(WrenVM* vm, void* userData)
{
  vm->config.userData = userData;
}
|
|
// End file "wren_vm.c"
|
|
// Begin file "wren_opt_random.c"
|
|
|
|
#if WREN_OPT_RANDOM
|
|
|
|
#include <string.h>
|
|
#include <time.h>
|
|
|
|
|
|
// Begin file "wren_opt_random.wren.inc"
|
|
// Generated automatically from src/optional/wren_opt_random.wren. Do not edit.
|
|
// Wren source text for the Random foreign class, stored as one C string
// literal and returned unchanged by wrenRandomSource(). The foreign methods
// it declares (seed_, float, int) are the signatures that
// wrenRandomBindForeignMethod() maps to C functions.
static const char* randomModuleSource =
|
|
"foreign class Random {\n"
|
|
" construct new() {\n"
|
|
" seed_()\n"
|
|
" }\n"
|
|
"\n"
|
|
" construct new(seed) {\n"
|
|
" if (seed is Num) {\n"
|
|
" seed_(seed)\n"
|
|
" } else if (seed is Sequence) {\n"
|
|
" if (seed.isEmpty) Fiber.abort(\"Sequence cannot be empty.\")\n"
|
|
"\n"
|
|
" // TODO: Empty sequence.\n"
|
|
" var seeds = []\n"
|
|
" for (element in seed) {\n"
|
|
" if (!(element is Num)) Fiber.abort(\"Sequence elements must all be numbers.\")\n"
|
|
"\n"
|
|
" seeds.add(element)\n"
|
|
" if (seeds.count == 16) break\n"
|
|
" }\n"
|
|
"\n"
|
|
" // Cycle the values to fill in any missing slots.\n"
|
|
" var i = 0\n"
|
|
" while (seeds.count < 16) {\n"
|
|
" seeds.add(seeds[i])\n"
|
|
" i = i + 1\n"
|
|
" }\n"
|
|
"\n"
|
|
" seed_(\n"
|
|
" seeds[0], seeds[1], seeds[2], seeds[3],\n"
|
|
" seeds[4], seeds[5], seeds[6], seeds[7],\n"
|
|
" seeds[8], seeds[9], seeds[10], seeds[11],\n"
|
|
" seeds[12], seeds[13], seeds[14], seeds[15])\n"
|
|
" } else {\n"
|
|
" Fiber.abort(\"Seed must be a number or a sequence of numbers.\")\n"
|
|
" }\n"
|
|
" }\n"
|
|
"\n"
|
|
" foreign seed_()\n"
|
|
" foreign seed_(seed)\n"
|
|
" foreign seed_(n1, n2, n3, n4, n5, n6, n7, n8, n9, n10, n11, n12, n13, n14, n15, n16)\n"
|
|
"\n"
|
|
" foreign float()\n"
|
|
" float(end) { float() * end }\n"
|
|
" float(start, end) { float() * (end - start) + start }\n"
|
|
"\n"
|
|
" foreign int()\n"
|
|
" int(end) { (float() * end).floor }\n"
|
|
" int(start, end) { (float() * (end - start)).floor + start }\n"
|
|
"\n"
|
|
" sample(list) {\n"
|
|
" if (list.count == 0) Fiber.abort(\"Not enough elements to sample.\")\n"
|
|
" return list[int(list.count)]\n"
|
|
" }\n"
|
|
" sample(list, count) {\n"
|
|
" if (count > list.count) Fiber.abort(\"Not enough elements to sample.\")\n"
|
|
"\n"
|
|
" var result = []\n"
|
|
"\n"
|
|
" // The algorithm described in \"Programming pearls: a sample of brilliance\".\n"
|
|
" // Use a hash map for sample sizes less than 1/4 of the population size and\n"
|
|
" // an array of booleans for larger samples. This simple heuristic improves\n"
|
|
" // performance for large sample sizes as well as reduces memory usage.\n"
|
|
" if (count * 4 < list.count) {\n"
|
|
" var picked = {}\n"
|
|
" for (i in list.count - count...list.count) {\n"
|
|
" var index = int(i + 1)\n"
|
|
" if (picked.containsKey(index)) index = i\n"
|
|
" picked[index] = true\n"
|
|
" result.add(list[index])\n"
|
|
" }\n"
|
|
" } else {\n"
|
|
" var picked = List.filled(list.count, false)\n"
|
|
" for (i in list.count - count...list.count) {\n"
|
|
" var index = int(i + 1)\n"
|
|
" if (picked[index]) index = i\n"
|
|
" picked[index] = true\n"
|
|
" result.add(list[index])\n"
|
|
" }\n"
|
|
" }\n"
|
|
"\n"
|
|
" return result\n"
|
|
" }\n"
|
|
"\n"
|
|
" shuffle(list) {\n"
|
|
" if (list.isEmpty) return\n"
|
|
"\n"
|
|
" // Fisher-Yates shuffle.\n"
|
|
" for (i in 0...list.count - 1) {\n"
|
|
" var from = int(i, list.count)\n"
|
|
" var temp = list[from]\n"
|
|
" list[from] = list[i]\n"
|
|
" list[i] = temp\n"
|
|
" }\n"
|
|
" }\n"
|
|
"}\n";
|
|
// End file "wren_opt_random.wren.inc"
|
|
|
|
// Implements the well equidistributed long-period linear PRNG (WELL512a).
|
|
//
|
|
// https://en.wikipedia.org/wiki/Well_equidistributed_long-period_linear
|
|
// Generator state stored inside each Random foreign instance.
typedef struct
|
|
{
|
|
// The 16 words of generator state, read and rewritten by advanceState().
uint32_t state[16];
|
|
// Position in [state] of the current word. randomAllocate() starts it at 0
// and advanceState() keeps it in 0..15 by masking with 15.
uint32_t index;
|
|
} Well512;
|
|
|
|
// Code from: http://www.lomont.org/Math/Papers/2008/Lomont_PRNG_2008.pdf
|
|
static uint32_t advanceState(Well512* well)
|
|
{
|
|
uint32_t a, b, c, d;
|
|
a = well->state[well->index];
|
|
c = well->state[(well->index + 13) & 15];
|
|
b = a ^ c ^ (a << 16) ^ (c << 15);
|
|
c = well->state[(well->index + 9) & 15];
|
|
c ^= (c >> 11);
|
|
a = well->state[well->index] = b ^ c;
|
|
d = a ^ ((a << 5) & 0xda442d24U);
|
|
|
|
well->index = (well->index + 15) & 15;
|
|
a = well->state[well->index];
|
|
well->state[well->index] = a ^ b ^ d ^ (a << 2) ^ (b << 18) ^ (c << 28);
|
|
return well->state[well->index];
|
|
}
|
|
|
|
// Allocator for Random: creates the foreign instance in slot 0 from the class
// in slot 0, sized to hold a Well512, and starts its index at 0. The state
// words are left for the seed_ methods to fill in.
static void randomAllocate(WrenVM* vm)
{
  void* storage = wrenSetSlotNewForeign(vm, 0, 0, sizeof(Well512));

  Well512* well = (Well512*)storage;
  well->index = 0;
}
|
|
|
|
// Implements seed_(): seeds the C library generator from the current time
// and fills all 16 state words of the receiver with values from rand().
static void randomSeed0(WrenVM* vm)
{
  Well512* well = (Well512*)wrenGetSlotForeign(vm, 0);

  srand((uint32_t)time(NULL));

  uint32_t* word = well->state;
  uint32_t* end = well->state + 16;
  while (word != end)
  {
    *word++ = rand();
  }
}
|
|
|
|
// Implements seed_(_): seeds the C library generator from the number in
// slot 1 and fills all 16 state words of the receiver with values from
// rand().
//
// The seed is reduced to 32 bits. Values that fit a 64-bit integer are
// truncated toward zero and wrapped modulo 2^32, which gives the same result
// as a direct cast for everything in 0..2^32-1. Anything else (NaN,
// infinities, magnitudes of 2^63 or more) seeds with 0.
static void randomSeed1(WrenVM* vm)
{
  Well512* well = (Well512*)wrenGetSlotForeign(vm, 0);

  double seed = wrenGetSlotDouble(vm, 1);

  // Casting an out-of-range double straight to uint32_t is undefined, so go
  // through int64_t after checking the range. NaN fails both comparisons.
  uint32_t seedBits = 0;
  if (seed >= -9223372036854775808.0 && seed < 9223372036854775808.0)
  {
    seedBits = (uint32_t)(int64_t)seed;
  }

  srand(seedBits);
  for (int i = 0; i < 16; i++)
  {
    well->state[i] = rand();
  }
}
|
|
|
|
// Implements the 16-argument seed_: copies the numbers in slots 1..16 into
// the receiver's 16 state words.
//
// Each number is reduced to 32 bits. Values that fit a 64-bit integer are
// truncated toward zero and wrapped modulo 2^32, which gives the same result
// as a direct cast for everything in 0..2^32-1. Anything else (NaN,
// infinities, magnitudes of 2^63 or more) becomes 0.
static void randomSeed16(WrenVM* vm)
{
  Well512* well = (Well512*)wrenGetSlotForeign(vm, 0);

  for (int i = 0; i < 16; i++)
  {
    double seed = wrenGetSlotDouble(vm, i + 1);

    // Casting an out-of-range double straight to uint32_t is undefined, so
    // go through int64_t after checking the range. NaN fails the check.
    bool representable = seed >= -9223372036854775808.0 &&
                         seed < 9223372036854775808.0;
    well->state[i] = representable ? (uint32_t)(int64_t)seed : 0;
  }
}
|
|
|
|
// Implements float(): stores a double built from 53 random bits, scaled by
// 1 / 2^53, in slot 0.
static void randomFloat(WrenVM* vm)
{
  Well512* well = (Well512*)wrenGetSlotForeign(vm, 0);

  // A double carries 53 bits of mantissa precision, so draw 53 random bits:
  // a full 32-bit word for the top part and the low 21 bits of a second word
  // for the bottom part.
  uint32_t upper = advanceState(well);
  uint32_t lower = advanceState(well) & ((1 << 21) - 1);

  // Shift the first word left by 21 bits, then add the remaining 21 bits.
  double result = (double)upper * (1 << 21);
  result += (double)lower;

  // Divide by 2^53 to bring the value into the range from 0 to 1.0
  // (half-inclusive).
  result /= 9007199254740992.0;

  wrenSetSlotDouble(vm, 0, result);
}
|
|
|
|
// Implements int(): stores the next 32-bit generator output in slot 0 as a
// number.
static void randomInt0(WrenVM* vm)
{
  Well512* well = (Well512*)wrenGetSlotForeign(vm, 0);

  uint32_t bits = advanceState(well);
  wrenSetSlotDouble(vm, 0, (double)bits);
}
|
|
|
|
const char* wrenRandomSource()
|
|
{
|
|
return randomModuleSource;
|
|
}
|
|
|
|
// Supplies the allocate/finalize pair for the Random foreign class: the
// allocator is randomAllocate and there is no finalizer. [className] must be
// "Random".
WrenForeignClassMethods wrenRandomBindForeignClass(WrenVM* vm,
                                                   const char* module,
                                                   const char* className)
{
  ASSERT(strcmp(className, "Random") == 0, "Should be in Random class.");

  WrenForeignClassMethods methods;
  methods.finalize = NULL;
  methods.allocate = randomAllocate;
  return methods;
}
|
|
|
|
// Maps a foreign method [signature] of the Random class to the C function
// that implements it. An unrecognised signature trips an assertion and
// yields NULL. [className] must be "Random".
WrenForeignMethodFn wrenRandomBindForeignMethod(WrenVM* vm,
                                                const char* className,
                                                bool isStatic,
                                                const char* signature)
{
  ASSERT(strcmp(className, "Random") == 0, "Should be in Random class.");

  // Signature text paired with its implementation, tried in order.
  static const struct
  {
    const char* text;
    WrenForeignMethodFn fn;
  } bindings[] =
  {
    { "<allocate>", randomAllocate },
    { "seed_()", randomSeed0 },
    { "seed_(_)", randomSeed1 },
    { "seed_(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", randomSeed16 },
    { "float()", randomFloat },
    { "int()", randomInt0 },
  };

  for (size_t i = 0; i < sizeof(bindings) / sizeof(bindings[0]); i++)
  {
    if (strcmp(signature, bindings[i].text) == 0) return bindings[i].fn;
  }

  ASSERT(false, "Unknown method.");
  return NULL;
}
|
|
|
|
#endif
|
|
// End file "wren_opt_random.c"
|
|
// Begin file "wren_opt_meta.c"
|
|
|
|
#if WREN_OPT_META
|
|
|
|
#include <string.h>
|
|
|
|
// Begin file "wren_opt_meta.wren.inc"
|
|
// Generated automatically from src/optional/wren_opt_meta.wren. Do not edit.
|
|
// Wren source text for the Meta class, stored as one C string literal and
// returned unchanged by wrenMetaSource(). The two foreign static methods it
// declares (compile_ and getModuleVariables_) are the signatures that
// wrenMetaBindForeignMethod() maps to C functions.
static const char* metaModuleSource =
|
|
"class Meta {\n"
|
|
" static getModuleVariables(module) {\n"
|
|
" if (!(module is String)) Fiber.abort(\"Module name must be a string.\")\n"
|
|
" var result = getModuleVariables_(module)\n"
|
|
" if (result != null) return result\n"
|
|
"\n"
|
|
" Fiber.abort(\"Could not find a module named '%(module)'.\")\n"
|
|
" }\n"
|
|
"\n"
|
|
" static eval(source) {\n"
|
|
" if (!(source is String)) Fiber.abort(\"Source code must be a string.\")\n"
|
|
"\n"
|
|
" var closure = compile_(source, false, false)\n"
|
|
" // TODO: Include compile errors.\n"
|
|
" if (closure == null) Fiber.abort(\"Could not compile source code.\")\n"
|
|
"\n"
|
|
" closure.call()\n"
|
|
" }\n"
|
|
"\n"
|
|
" static compileExpression(source) {\n"
|
|
" if (!(source is String)) Fiber.abort(\"Source code must be a string.\")\n"
|
|
" return compile_(source, true, true)\n"
|
|
" }\n"
|
|
"\n"
|
|
" static compile(source) {\n"
|
|
" if (!(source is String)) Fiber.abort(\"Source code must be a string.\")\n"
|
|
" return compile_(source, false, true)\n"
|
|
" }\n"
|
|
"\n"
|
|
" foreign static compile_(source, isExpression, printErrors)\n"
|
|
" foreign static getModuleVariables_(module)\n"
|
|
"}\n";
|
|
// End file "wren_opt_meta.wren.inc"
|
|
|
|
// Implements Meta.compile_(_,_,_): compiles the source string in slot 1 in
// the module of the calling code and stores the resulting closure in slot 0,
// or null when compilation produced no closure. Slot 2 says whether the
// source is an expression and slot 3 whether errors are printed.
void metaCompile(WrenVM* vm)
{
  const char* source = wrenGetSlotString(vm, 1);
  bool isExpression = wrenGetSlotBool(vm, 2);
  bool printErrors = wrenGetSlotBool(vm, 3);

  // TODO: Allow passing in module?
  // Find the module surrounding the callsite by reading the frame two below
  // the top of the current fiber's call stack. This is brittle: it assumes
  // the meta module adds exactly one level of indirection above the user's
  // code, so any change to meta may require the constant to be tweaked.
  ObjFiber* fiber = vm->fiber;
  ObjFn* callerFn = fiber->frames[fiber->numFrames - 2].closure->fn;
  ObjString* moduleName = callerFn->module->name;

  ObjClosure* compiled = wrenCompileSource(vm, moduleName->value, source,
                                           isExpression, printErrors);

  // The public slot API can't store a bare ObjClosure*, so write the result
  // to the API stack directly.
  vm->apiStack[0] = (compiled != NULL) ? OBJ_VAL(compiled) : NULL_VAL;
}
|
|
|
|
// Implements Meta.getModuleVariables_(_): looks up the module keyed by the
// value in slot 1 and stores a list of its variable names in slot 0, or null
// when no such module is registered.
void metaGetModuleVariables(WrenVM* vm)
{
  wrenEnsureSlots(vm, 3);

  Value found = wrenMapGet(vm->modules, vm->apiStack[1]);
  if (IS_UNDEFINED(found))
  {
    vm->apiStack[0] = NULL_VAL;
    return;
  }

  ObjModule* module = AS_MODULE(found);
  ObjList* names = wrenNewList(vm, module->variableNames.count);
  vm->apiStack[0] = OBJ_VAL(names);

  Value* cells = names->elements.data;
  int total = names->elements.count;

  // First pass: set every element to null so the list never holds
  // uninitialized values.
  for (int i = 0; i < total; i++)
  {
    cells[i] = NULL_VAL;
  }

  // Second pass: store each variable name object in its element.
  for (int i = 0; i < total; i++)
  {
    cells[i] = OBJ_VAL(module->variableNames.data[i]);
  }
}
|
|
|
|
const char* wrenMetaSource()
|
|
{
|
|
return metaModuleSource;
|
|
}
|
|
|
|
// Maps a foreign method [signature] of the Meta class to the C function that
// implements it. Both foreign methods are static, so [isStatic] must be
// true and [className] must be "Meta". An unrecognised signature trips an
// assertion and yields NULL.
WrenForeignMethodFn wrenMetaBindForeignMethod(WrenVM* vm,
                                              const char* className,
                                              bool isStatic,
                                              const char* signature)
{
  ASSERT(strcmp(className, "Meta") == 0, "Should be in Meta class.");
  ASSERT(isStatic, "Should be static.");

  WrenForeignMethodFn bound = NULL;

  if (strcmp(signature, "compile_(_,_,_)") == 0)
  {
    bound = metaCompile;
  }
  else if (strcmp(signature, "getModuleVariables_(_)") == 0)
  {
    bound = metaGetModuleVariables;
  }
  else
  {
    ASSERT(false, "Unknown method.");
  }

  return bound;
}
|
|
|
|
#endif
|
|
// End file "wren_opt_meta.c"
|
|
|
|
// End of wren.c
|
|
|