commit 9f27915cab0b5384bec59580bd1080383e237417 Author: retoor Date: Tue Jul 29 14:35:38 2025 +0200 Requests working good. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..88a4b05 --- /dev/null +++ b/Makefile @@ -0,0 +1,12 @@ + +make: + gcc main.c wren.c -lcurl -lm -lpthread -o wren + + +make3: + g++ main3.c wren.c -lcurl -lm -o wren3 + + + +run: + ./wren_requests_example diff --git a/async_http.c b/async_http.c new file mode 100644 index 0000000..585c8dd --- /dev/null +++ b/async_http.c @@ -0,0 +1,79 @@ +#include "httplib.h" +#include "wren.h" + +// A struct to hold the context for an asynchronous HTTP request +struct RequestContext { + std::string url; + WrenHandle* callback; + WrenVM* vm; + std::string response; + bool error; +}; + +// A class to manage asynchronous HTTP requests +class AsyncHttp { +public: + AsyncHttp(WrenVM* vm) : vm_(vm), running_(true) { + // Create a pool of worker threads + for (int i = 0; i < 4; ++i) { + threads_.emplace_back([this] { + while (running_) { + RequestContext* context = requestQueue_.pop(); + if (!running_) break; + + httplib::Client cli("http://example.com"); + if (auto res = cli.Get(context->url.c_str())) { + context->response = res->body; + context->error = false; + } else { + context->response = "Error: " + to_string(res.error()); + context->error = true; + } + + completionQueue_.push(context); + } + }); + } + } + + ~AsyncHttp() { + running_ = false; + // Add dummy requests to unblock worker threads + for (size_t i = 0; i < threads_.size(); ++i) { + requestQueue_.push(nullptr); + } + for (auto& thread : threads_) { + thread.join(); + } + } + + void request(const std::string& url, WrenHandle* callback) { + RequestContext* context = new RequestContext{url, callback, vm_}; + requestQueue_.push(context); + } + + void processCompletions() { + while (!completionQueue_.empty()) { + RequestContext* context = completionQueue_.pop(); + + // Create a handle for the callback function + WrenHandle* callHandle = wrenMakeCallHandle(vm_, "call(_)"); + + wrenEnsureSlots(vm_, 2); + wrenSetSlotHandle(vm_, 0, context->callback); + wrenSetSlotString(vm_, 1, context->response.c_str()); + wrenCall(vm_, callHandle); + + wrenReleaseHandle(vm_, callHandle); + wrenReleaseHandle(vm_, context->callback); + delete context; + } + } + +private: + WrenVM* vm_; + bool running_; + std::vector threads_; + ThreadSafeQueue requestQueue_; + ThreadSafeQueue completionQueue_; +}; diff --git a/backend.cpp b/backend.cpp new file mode 100644 index 0000000..5526e2b --- /dev/null +++ b/backend.cpp @@ -0,0 +1,135 @@ +// backend.cpp (Corrected) +#include "httplib.h" +#include "wren.h" +#include +#include + +// A struct to hold the response data for our foreign object +struct ResponseData { + bool isError; + int statusCode; + std::string body; +}; + +// --- Response Class Foreign Methods --- + +void responseAllocate(WrenVM* vm) { + // This is the constructor for the Response class. + ResponseData* data = (ResponseData*)wrenSetSlotNewForeign(vm, 0, 0, sizeof(ResponseData)); + data->isError = false; + data->statusCode = 0; +} + +void responseIsError(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBool(vm, 0, data->isError); +} + +void responseStatusCode(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotDouble(vm, 0, data->statusCode); +} + +void responseBody(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBytes(vm, 0, data->body.c_str(), data->body.length()); +} + +void responseJson(WrenVM* vm) { + // For a real implementation, you would use a JSON library here. + // For this example, we just return the body text. + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBytes(vm, 0, data->body.c_str(), data->body.length()); +} + +// --- Requests Class Foreign Methods --- + +void requestsGet(WrenVM* vm) { + const char* url = wrenGetSlotString(vm, 1); + // TODO: Handle headers from slot 2. + + httplib::Client cli("jsonplaceholder.typicode.com"); + auto res = cli.Get("/posts/1"); + + // CHANGED: We need two slots: one for the Response class, one for the new instance. + wrenEnsureSlots(vm, 2); + + // CHANGED: Get the 'Response' class from the 'requests' module and put it in slot 1. + wrenGetVariable(vm, "requests", "Response", 1); + + // CHANGED: Create a new foreign object instance of the class in slot 1. + // The new instance is placed in slot 0, which becomes the return value. + ResponseData* data = (ResponseData*)wrenSetSlotNewForeign(vm, 0, 1, sizeof(ResponseData)); + + if (res) { + data->isError = false; + data->statusCode = res->status; + data->body = res->body; + } else { + data->isError = true; + data->statusCode = -1; + data->body = "GET request failed."; + } +} + +void requestsPost(WrenVM* vm) { + const char* url = wrenGetSlotString(vm, 1); + const char* body = wrenGetSlotString(vm, 2); + const char* contentType = wrenGetSlotString(vm, 3); + // TODO: Handle headers from slot 4. + + httplib::Client cli("jsonplaceholder.typicode.com"); + auto res = cli.Post("/posts", body, contentType); + + // CHANGED: We need two slots: one for the Response class, one for the new instance. + wrenEnsureSlots(vm, 2); + + // CHANGED: Get the 'Response' class from the 'requests' module and put it in slot 1. + wrenGetVariable(vm, "requests", "Response", 1); + + // CHANGED: Create a new foreign object instance of the class in slot 1. + // The new instance is placed in slot 0, which becomes the return value. + ResponseData* data = (ResponseData*)wrenSetSlotNewForeign(vm, 0, 1, sizeof(ResponseData)); + + if (res) { + data->isError = false; + data->statusCode = res->status; + data->body = res->body; + } else { + data->isError = true; + data->statusCode = -1; + data->body = "POST request failed."; + } +} + + +// --- FFI Binding Functions --- + +WrenForeignMethodFn bindForeignMethod(WrenVM* vm, const char* module, + const char* className, bool isStatic, const char* signature) { + if (strcmp(module, "requests") != 0) return NULL; + + if (strcmp(className, "Requests") == 0 && isStatic) { + if (strcmp(signature, "get_(_,_)") == 0) return requestsGet; + if (strcmp(signature, "post_(_,_,_,_)") == 0) return requestsPost; + } + + if (strcmp(className, "Response") == 0 && !isStatic) { + if (strcmp(signature, "isError") == 0) return responseIsError; + if (strcmp(signature, "statusCode") == 0) return responseStatusCode; + if (strcmp(signature, "body") == 0) return responseBody; + if (strcmp(signature, "json()") == 0) return responseJson; + } + + return NULL; +} + +WrenForeignClassMethods bindForeignClass(WrenVM* vm, const char* module, const char* className) { + WrenForeignClassMethods methods = {0}; + if (strcmp(module, "requests") == 0) { + if (strcmp(className, "Response") == 0) { + methods.allocate = responseAllocate; + } + } + return methods; +} diff --git a/httplib.h b/httplib.h new file mode 100644 index 0000000..130d200 --- /dev/null +++ b/httplib.h @@ -0,0 +1,11573 @@ +// +// httplib.h +// +// Copyright (c) 2025 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +#define CPPHTTPLIB_VERSION "0.23.1" +#define CPPHTTPLIB_VERSION_NUM "0x001701" + +/* + * Platform compatibility check + */ + +#if defined(_WIN32) && !defined(_WIN64) +#error \ + "cpp-httplib doesn't support 32-bit Windows. Please use a 64-bit compiler." +#elif defined(__SIZEOF_POINTER__) && __SIZEOF_POINTER__ < 8 +#warning \ + "cpp-httplib doesn't support 32-bit platforms. Please use a 64-bit compiler." +#elif defined(__SIZEOF_SIZE_T__) && __SIZEOF_SIZE_T__ < 8 +#warning \ + "cpp-httplib doesn't support platforms where size_t is less than 64 bits." +#endif + +#ifdef _WIN32 +#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0602 +#error \ + "cpp-httplib doesn't support Windows 8 or lower. Please use Windows 10 or later." +#endif +#endif + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND +#define CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND +#define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_USECOND +#ifdef _WIN64 +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 1000 +#else +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 0 +#endif +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH +#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_COUNT +#define CPPHTTPLIB_HEADER_MAX_COUNT 100 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT +#define CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_RANGE_MAX_COUNT +#define CPPHTTPLIB_RANGE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_TCP_NODELAY +#define CPPHTTPLIB_TCP_NODELAY false +#endif + +#ifndef CPPHTTPLIB_IPV6_V6ONLY +#define CPPHTTPLIB_IPV6_V6ONLY false +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_SEND_BUFSIZ +#define CPPHTTPLIB_SEND_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ +#define CPPHTTPLIB_COMPRESSION_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 \ + ? std::thread::hardware_concurrency() - 1 \ + : 0)) +#endif + +#ifndef CPPHTTPLIB_RECV_FLAGS +#define CPPHTTPLIB_RECV_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_SEND_FLAGS +#define CPPHTTPLIB_SEND_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_LISTEN_BACKLOG +#define CPPHTTPLIB_LISTEN_BACKLOG 5 +#endif + +#ifndef CPPHTTPLIB_MAX_LINE_LENGTH +#define CPPHTTPLIB_MAX_LINE_LENGTH 32768 +#endif + +/* + * Headers + */ + +#ifdef _WIN64 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#if _MSC_VER < 1900 +#error Sorry, Visual Studio versions prior to 2015 are not supported +#endif + +#pragma comment(lib, "ws2_32.lib") + +using ssize_t = __int64; +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m) & S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m) & S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#if defined(__has_include) +#if __has_include() +// afunix.h uses types declared in winsock2.h, so has to be included after it. +#include +#define CPPHTTPLIB_HAVE_AFUNIX_H 1 +#endif +#endif + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +using nfds_t = unsigned long; +using socket_t = SOCKET; +using socklen_t = int; + +#else // not _WIN64 + +#include +#if !defined(_AIX) && !defined(__MVS__) +#include +#endif +#ifdef __MVS__ +#include +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif +#endif +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +using socket_t = int; +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#endif //_WIN64 + +#if defined(__APPLE__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \ + defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +#if TARGET_OS_OSX +#include +#include +#endif +#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO or + // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN64 +#include + +// these are defined in wincrypt.h and it breaks compilation if BoringSSL is +// used +#undef X509_NAME +#undef X509_CERT_PAIR +#undef X509_EXTENSIONS +#undef PKCS7_SIGNER_INFO + +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#endif // _WIN64 + +#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +#if TARGET_OS_OSX +#include +#endif +#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO + +#include +#include +#include +#include + +#if defined(_WIN64) && defined(OPENSSL_USE_APPLINK) +#include +#endif + +#include +#include + +#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER) +#if OPENSSL_VERSION_NUMBER < 0x1010107f +#error Please use OpenSSL or a current version of BoringSSL +#endif +#define SSL_get1_peer_certificate SSL_get_peer_certificate +#elif OPENSSL_VERSION_NUMBER < 0x30000000L +#error Sorry, OpenSSL versions prior to 3.0.0 are not supported +#endif + +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +#include +#include +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +#include +#endif + +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +/* + * Backport std::make_unique from C++14. + * + * NOTE: This code came up with the following stackoverflow post: + * https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique + * + */ + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(std::size_t n) { + typedef typename std::remove_extent::type RT; + return std::unique_ptr(new RT[n]); +} + +namespace case_ignore { + +inline unsigned char to_lower(int c) { + const static unsigned char table[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, + 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, + 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, + 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, + 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, + 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 224, 225, 226, + 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, 224, + 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, + 255, + }; + return table[(unsigned char)(char)c]; +} + +inline bool equal(const std::string &a, const std::string &b) { + return a.size() == b.size() && + std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { + return to_lower(ca) == to_lower(cb); + }); +} + +struct equal_to { + bool operator()(const std::string &a, const std::string &b) const { + return equal(a, b); + } +}; + +struct hash { + size_t operator()(const std::string &key) const { + return hash_core(key.data(), key.size(), 0); + } + + size_t hash_core(const char *s, size_t l, size_t h) const { + return (l == 0) ? h + : hash_core(s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no + // overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(to_lower(*s))); + } +}; + +template +using unordered_set = std::unordered_set; + +} // namespace case_ignore + +// This is based on +// "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". + +struct scope_exit { + explicit scope_exit(std::function &&f) + : exit_function(std::move(f)), execute_on_destruction{true} {} + + scope_exit(scope_exit &&rhs) noexcept + : exit_function(std::move(rhs.exit_function)), + execute_on_destruction{rhs.execute_on_destruction} { + rhs.release(); + } + + ~scope_exit() { + if (execute_on_destruction) { this->exit_function(); } + } + + void release() { this->execute_on_destruction = false; } + +private: + scope_exit(const scope_exit &) = delete; + void operator=(const scope_exit &) = delete; + scope_exit &operator=(scope_exit &&) = delete; + + std::function exit_function; + bool execute_on_destruction; +}; + +} // namespace detail + +enum SSLVerifierResponse { + // no decision has been made, use the built-in certificate verifier + NoDecisionMade, + // connection certificate is verified and accepted + CertificateAccepted, + // connection certificate was processed but is rejected + CertificateRejected +}; + +enum StatusCode { + // Information responses + Continue_100 = 100, + SwitchingProtocol_101 = 101, + Processing_102 = 102, + EarlyHints_103 = 103, + + // Successful responses + OK_200 = 200, + Created_201 = 201, + Accepted_202 = 202, + NonAuthoritativeInformation_203 = 203, + NoContent_204 = 204, + ResetContent_205 = 205, + PartialContent_206 = 206, + MultiStatus_207 = 207, + AlreadyReported_208 = 208, + IMUsed_226 = 226, + + // Redirection messages + MultipleChoices_300 = 300, + MovedPermanently_301 = 301, + Found_302 = 302, + SeeOther_303 = 303, + NotModified_304 = 304, + UseProxy_305 = 305, + unused_306 = 306, + TemporaryRedirect_307 = 307, + PermanentRedirect_308 = 308, + + // Client error responses + BadRequest_400 = 400, + Unauthorized_401 = 401, + PaymentRequired_402 = 402, + Forbidden_403 = 403, + NotFound_404 = 404, + MethodNotAllowed_405 = 405, + NotAcceptable_406 = 406, + ProxyAuthenticationRequired_407 = 407, + RequestTimeout_408 = 408, + Conflict_409 = 409, + Gone_410 = 410, + LengthRequired_411 = 411, + PreconditionFailed_412 = 412, + PayloadTooLarge_413 = 413, + UriTooLong_414 = 414, + UnsupportedMediaType_415 = 415, + RangeNotSatisfiable_416 = 416, + ExpectationFailed_417 = 417, + ImATeapot_418 = 418, + MisdirectedRequest_421 = 421, + UnprocessableContent_422 = 422, + Locked_423 = 423, + FailedDependency_424 = 424, + TooEarly_425 = 425, + UpgradeRequired_426 = 426, + PreconditionRequired_428 = 428, + TooManyRequests_429 = 429, + RequestHeaderFieldsTooLarge_431 = 431, + UnavailableForLegalReasons_451 = 451, + + // Server error responses + InternalServerError_500 = 500, + NotImplemented_501 = 501, + BadGateway_502 = 502, + ServiceUnavailable_503 = 503, + GatewayTimeout_504 = 504, + HttpVersionNotSupported_505 = 505, + VariantAlsoNegotiates_506 = 506, + InsufficientStorage_507 = 507, + LoopDetected_508 = 508, + NotExtended_510 = 510, + NetworkAuthenticationRequired_511 = 511, +}; + +using Headers = + std::unordered_multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using DownloadProgress = std::function; +using UploadProgress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct FormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; + Headers headers; +}; + +struct FormField { + std::string name; + std::string content; + Headers headers; +}; +using FormFields = std::multimap; + +using FormFiles = std::multimap; + +struct MultipartFormData { + FormFields fields; // Text fields from multipart + FormFiles files; // Files from multipart + + // Text field access + std::string get_field(const std::string &key, size_t id = 0) const; + std::vector get_fields(const std::string &key) const; + bool has_field(const std::string &key) const; + size_t get_field_count(const std::string &key) const; + + // File access + FormData get_file(const std::string &key, size_t id = 0) const; + std::vector get_files(const std::string &key) const; + bool has_file(const std::string &key) const; + size_t get_file_count(const std::string &key) const; +}; + +struct UploadFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using UploadFormDataItems = std::vector; + +class DataSink { +public: + DataSink() : os(&sb_), sb_(*this) {} + + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function is_writable; + std::function done; + std::function done_with_trailer; + std::ostream os; + +private: + class data_sink_streambuf final : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + + protected: + std::streamsize xsputn(const char *s, std::streamsize n) override { + sink_.write(s, static_cast(n)); + return n; + } + + private: + DataSink &sink_; + }; + + data_sink_streambuf sb_; +}; + +using ContentProvider = + std::function; + +using ContentProviderWithoutLength = + std::function; + +using ContentProviderResourceReleaser = std::function; + +struct FormDataProvider { + std::string name; + ContentProviderWithoutLength provider; + std::string filename; + std::string content_type; +}; +using FormDataProviderItems = std::vector; + +using ContentReceiverWithProgress = std::function; + +using ContentReceiver = + std::function; + +using FormDataHeader = std::function; + +class ContentReader { +public: + using Reader = std::function; + using FormDataReader = + std::function; + + ContentReader(Reader reader, FormDataReader multipart_reader) + : reader_(std::move(reader)), + formdata_reader_(std::move(multipart_reader)) {} + + bool operator()(FormDataHeader header, ContentReceiver receiver) const { + return formdata_reader_(std::move(header), std::move(receiver)); + } + + bool operator()(ContentReceiver receiver) const { + return reader_(std::move(receiver)); + } + + Reader reader_; + FormDataReader formdata_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + std::string matched_route; + Params params; + Headers headers; + Headers trailers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + std::string local_addr; + int local_port = -1; + + // for server + std::string version; + std::string target; + MultipartFormData form; + Ranges ranges; + Match matches; + std::unordered_map path_params; + std::function is_connection_closed = []() { return true; }; + + // for client + std::vector accept_content_types; + ResponseHandler response_handler; + ContentReceiverWithProgress content_receiver; + DownloadProgress download_progress; + UploadProgress upload_progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl = nullptr; +#endif + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + size_t get_header_value_u64(const std::string &key, size_t def = 0, + size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_trailer(const std::string &key) const; + std::string get_trailer_value(const std::string &key, size_t id = 0) const; + size_t get_trailer_value_count(const std::string &key) const; + + bool has_param(const std::string &key) const; + std::string get_param_value(const std::string &key, size_t id = 0) const; + size_t get_param_value_count(const std::string &key) const; + + bool is_multipart_form_data() const; + + // private members... + size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; + size_t content_length_ = 0; + ContentProvider content_provider_; + bool is_chunked_content_provider_ = false; + size_t authorization_count_ = 0; + std::chrono::time_point start_time_ = + (std::chrono::steady_clock::time_point::min)(); +}; + +struct Response { + std::string version; + int status = -1; + std::string reason; + Headers headers; + Headers trailers; + std::string body; + std::string location; // Redirect location + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + size_t get_header_value_u64(const std::string &key, size_t def = 0, + size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_trailer(const std::string &key) const; + std::string get_trailer_value(const std::string &key, size_t id = 0) const; + size_t get_trailer_value_count(const std::string &key) const; + + void set_redirect(const std::string &url, int status = StatusCode::Found_302); + void set_content(const char *s, size_t n, const std::string &content_type); + void set_content(const std::string &s, const std::string &content_type); + void set_content(std::string &&s, const std::string &content_type); + + void set_content_provider( + size_t length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_file_content(const std::string &path, + const std::string &content_type); + void set_file_content(const std::string &path); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(content_provider_success_); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + ContentProviderResourceReleaser content_provider_resource_releaser_; + bool is_chunked_content_provider_ = false; + bool content_provider_success_ = false; + std::string file_content_path_; + std::string file_content_content_type_; +}; + +class Stream { +public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool wait_readable() const = 0; + virtual bool wait_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; + virtual socket_t socket() const = 0; + + virtual time_t duration() const = 0; + + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { +public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual bool enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle() {} +}; + +class ThreadPool final : public TaskQueue { +public: + explicit ThreadPool(size_t n, size_t mqr = 0) + : shutdown_(false), max_queued_requests_(mqr) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + bool enqueue(std::function fn) override { + { + std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + return true; + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + +private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ + !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + size_t max_queued_requests_ = 0; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +using SocketOptions = std::function; + +namespace detail { + +bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen); +bool set_socket_opt(socket_t sock, int level, int optname, int opt); +bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, + time_t usec); + +} // namespace detail + +void default_socket_options(socket_t sock); + +const char *status_message(int status); + +std::string get_bearer_token_auth(const Request &req); + +namespace detail { + +class MatcherBase { +public: + MatcherBase(std::string pattern) : pattern_(pattern) {} + virtual ~MatcherBase() = default; + + const std::string &pattern() const { return pattern_; } + + // Match request path and populate its matches and + virtual bool match(Request &request) const = 0; + +private: + std::string pattern_; +}; + +/** + * Captures parameters in request path and stores them in Request::path_params + * + * Capture name is a substring of a pattern from : to /. + * The rest of the pattern is matched against the request path directly + * Parameters are captured starting from the next character after + * the end of the last matched static pattern fragment until the next /. + * + * Example pattern: + * "/path/fragments/:capture/more/fragments/:second_capture" + * Static fragments: + * "/path/fragments/", "more/fragments/" + * + * Given the following request path: + * "/path/fragments/:1/more/fragments/:2" + * the resulting capture will be + * {{"capture", "1"}, {"second_capture", "2"}} + */ +class PathParamsMatcher final : public MatcherBase { +public: + PathParamsMatcher(const std::string &pattern); + + bool match(Request &request) const override; + +private: + // Treat segment separators as the end of path parameter capture + // Does not need to handle query parameters as they are parsed before path + // matching + static constexpr char separator = '/'; + + // Contains static path fragments to match against, excluding the '/' after + // path params + // Fragments are separated by path params + std::vector static_fragments_; + // Stores the names of the path parameters to be used as keys in the + // Request::path_params map + std::vector param_names_; +}; + +/** + * Performs std::regex_match on request path + * and stores the result in Request::matches + * + * Note that regex match is performed directly on the whole request. + * This means that wildcard patterns may match multiple path segments with /: + * "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end". + */ +class RegexMatcher final : public MatcherBase { +public: + RegexMatcher(const std::string &pattern) + : MatcherBase(pattern), regex_(pattern) {} + + bool match(Request &request) const override; + +private: + std::regex regex_; +}; + +ssize_t write_headers(Stream &strm, const Headers &headers); + +} // namespace detail + +class Server { +public: + using Handler = std::function; + + using ExceptionHandler = + std::function; + + enum class HandlerResponse { + Handled, + Unhandled, + }; + using HandlerWithResponse = + std::function; + + using HandlerWithContentReader = std::function; + + using Expect100ContinueHandler = + std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, HandlerWithContentReader handler); + Server &Put(const std::string &pattern, Handler handler); + Server &Put(const std::string &pattern, HandlerWithContentReader handler); + Server &Patch(const std::string &pattern, Handler handler); + Server &Patch(const std::string &pattern, HandlerWithContentReader handler); + Server &Delete(const std::string &pattern, Handler handler); + Server &Delete(const std::string &pattern, HandlerWithContentReader handler); + Server &Options(const std::string &pattern, Handler handler); + + bool set_base_dir(const std::string &dir, + const std::string &mount_point = std::string()); + bool set_mount_point(const std::string &mount_point, const std::string &dir, + Headers headers = Headers()); + bool remove_mount_point(const std::string &mount_point); + Server &set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime); + Server &set_default_file_mimetype(const std::string &mime); + Server &set_file_request_handler(Handler handler); + + template + Server &set_error_handler(ErrorHandlerFunc &&handler) { + return set_error_handler_core( + std::forward(handler), + std::is_convertible{}); + } + + Server &set_exception_handler(ExceptionHandler handler); + + Server &set_pre_routing_handler(HandlerWithResponse handler); + Server &set_post_routing_handler(Handler handler); + + Server &set_pre_request_handler(HandlerWithResponse handler); + + Server &set_expect_100_continue_handler(Expect100ContinueHandler handler); + Server &set_logger(Logger logger); + Server &set_pre_compression_logger(Logger logger); + + Server &set_address_family(int family); + Server &set_tcp_nodelay(bool on); + Server &set_ipv6_v6only(bool on); + Server &set_socket_options(SocketOptions socket_options); + + Server &set_default_headers(Headers headers); + Server & + set_header_writer(std::function const &writer); + + Server &set_keep_alive_max_count(size_t count); + Server &set_keep_alive_timeout(time_t sec); + + Server &set_read_timeout(time_t sec, time_t usec = 0); + template + Server &set_read_timeout(const std::chrono::duration &duration); + + Server &set_write_timeout(time_t sec, time_t usec = 0); + template + Server &set_write_timeout(const std::chrono::duration &duration); + + Server &set_idle_interval(time_t sec, time_t usec = 0); + template + Server &set_idle_interval(const std::chrono::duration &duration); + + Server &set_payload_max_length(size_t length); + + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); + int bind_to_any_port(const std::string &host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const std::string &host, int port, int socket_flags = 0); + + bool is_running() const; + void wait_until_ready() const; + void stop(); + void decommission(); + + std::function new_task_queue; + +protected: + bool process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, + bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_{INVALID_SOCKET}; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + +private: + using Handlers = + std::vector, Handler>>; + using HandlersForContentReader = + std::vector, + HandlerWithContentReader>>; + + static std::unique_ptr + make_matcher(const std::string &pattern); + + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); + Server &set_error_handler_core(Handler handler, std::false_type); + + socket_t create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const std::string &host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(const Request &req, Response &res); + bool dispatch_request(Request &req, Response &res, + const Handlers &handlers) const; + bool dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const; + + bool parse_request_line(const char *s, Request &req) const; + void apply_ranges(const Request &req, Response &res, + std::string &content_type, std::string &boundary) const; + bool write_response(Stream &strm, bool close_connection, Request &req, + Response &res); + bool write_response_with_content(Stream &strm, bool close_connection, + const Request &req, Response &res); + bool write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool read_content_with_content_receiver(Stream &strm, Request &req, + Response &res, + ContentReceiver receiver, + FormDataHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + FormDataHeader multipart_header, + ContentReceiver multipart_receiver) const; + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_{false}; + std::atomic is_decommissioned{false}; + + struct MountPointEntry { + std::string mount_point; + std::string base_dir; + Headers headers; + }; + std::vector base_dirs_; + std::map file_extension_and_mimetype_map_; + std::string default_file_mimetype_ = "application/octet-stream"; + Handler file_request_handler_; + + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + + HandlerWithResponse error_handler_; + ExceptionHandler exception_handler_; + HandlerWithResponse pre_routing_handler_; + Handler post_routing_handler_; + HandlerWithResponse pre_request_handler_; + Expect100ContinueHandler expect_100_continue_handler_; + + Logger logger_; + Logger pre_compression_logger_; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = default_socket_options; + + Headers default_headers_; + std::function header_writer_ = + detail::write_headers; +}; + +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + +class Result { +public: + Result() = default; + Result(std::unique_ptr &&res, Error err, + Headers &&request_headers = Headers{}) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)) {} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {} + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error, unsigned long ssl_openssl_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error), + ssl_openssl_error_(ssl_openssl_error) {} +#endif + // Response + operator bool() const { return res_ != nullptr; } + bool operator==(std::nullptr_t) const { return res_ == nullptr; } + bool operator!=(std::nullptr_t) const { return res_ != nullptr; } + const Response &value() const { return *res_; } + Response &value() { return *res_; } + const Response &operator*() const { return *res_; } + Response &operator*() { return *res_; } + const Response *operator->() const { return res_.get(); } + Response *operator->() { return res_.get(); } + + // Error + Error error() const { return err_; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // SSL Error + int ssl_error() const { return ssl_error_; } + // OpenSSL Error + unsigned long ssl_openssl_error() const { return ssl_openssl_error_; } +#endif + + // Request Headers + bool has_request_header(const std::string &key) const; + std::string get_request_header_value(const std::string &key, + const char *def = "", + size_t id = 0) const; + size_t get_request_header_value_u64(const std::string &key, size_t def = 0, + size_t id = 0) const; + size_t get_request_header_value_count(const std::string &key) const; + +private: + std::unique_ptr res_; + Error err_ = Error::Unknown; + Headers request_headers_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + int ssl_error_ = 0; + unsigned long ssl_openssl_error_ = 0; +#endif +}; + +class ClientImpl { +public: + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + // clang-format off + Result Get(const std::string &path, DownloadProgress progress = nullptr); + Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Params ¶ms); + Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Delete(const std::string &path, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Params ¶ms, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const Params ¶ms, DownloadProgress progress = nullptr); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + // clang-format on + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_ipv6_v6only(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_max_timeout(time_t msec); + template + void set_max_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_path_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + void set_ca_cert_store(X509_STORE *ca_cert_store); + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier( + std::function verifier); +#endif + + void set_logger(Logger logger); + +protected: + struct Socket { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + }; + + virtual bool create_and_connect_socket(Socket &socket, Error &error); + + // All of: + // shutdown_ssl + // shutdown_socket + // close_socket + // should ONLY be called when socket_mutex_ is locked. + // Also, shutdown_ssl and close_socket should also NOT be called concurrently + // with a DIFFERENT thread sending requests using that socket. + virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); + void shutdown_socket(Socket &socket) const; + void close_socket(Socket &socket); + + bool process_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + + bool write_content_with_provider(Stream &strm, const Request &req, + Error &error) const; + + void copy_settings(const ClientImpl &rhs); + + // Socket endpoint information + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; + + // These are all protected under socket_mutex + size_t socket_requests_in_flight_ = 0; + std::thread::id socket_requests_are_from_thread_ = std::thread::id(); + bool socket_should_be_closed_when_request_is_done_ = false; + + // Hostname-IP map + std::map addr_map_; + + // Default headers + Headers default_headers_; + + // Header writer + std::function header_writer_ = + detail::write_headers; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + time_t max_timeout_msec_ = CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND; + + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool keep_alive_ = false; + bool follow_location_ = false; + + bool path_encode_ = true; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = nullptr; + + bool compress_ = false; + bool decompress_ = true; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_ = -1; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + + X509_STORE *ca_cert_store_ = nullptr; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::function server_certificate_verifier_; +#endif + + Logger logger_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + int last_ssl_error_ = 0; + unsigned long last_openssl_error_ = 0; +#endif + +private: + bool send_(Request &req, Response &res, Error &error); + Result send_(Request &&req); + + socket_t create_client_socket(Error &error) const; + bool read_response_line(Stream &strm, const Request &req, + Response &res) const; + bool write_request(Stream &strm, Request &req, bool close_connection, + Error &error); + bool redirect(Request &req, Response &res, Error &error); + bool create_redirect_client(const std::string &scheme, + const std::string &host, int port, Request &req, + Response &res, const std::string &path, + const std::string &location, Error &error); + template void setup_redirect_client(ClientType &client); + bool handle_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + std::unique_ptr send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error); + Result send_with_content_provider( + const std::string &method, const std::string &path, + const Headers &headers, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, UploadProgress progress); + ContentProviderWithoutLength get_multipart_content_provider( + const std::string &boundary, const UploadFormDataItems &items, + const FormDataProviderItems &provider_items) const; + + std::string adjust_host_string(const std::string &host) const; + + virtual bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback); + virtual bool is_ssl() const; +}; + +class Client { +public: + // Universal interface + explicit Client(const std::string &scheme_host_port); + + explicit Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + Client(Client &&) = default; + Client &operator=(Client &&) = default; + + ~Client(); + + bool is_valid() const; + + // clang-format off + Result Get(const std::string &path, DownloadProgress progress = nullptr); + Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Params ¶ms); + Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Delete(const std::string &path, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Params ¶ms, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const Params ¶ms, DownloadProgress progress = nullptr); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + // clang-format on + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_max_timeout(time_t msec); + template + void set_max_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_path_encode(bool on); + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier( + std::function verifier); +#endif + + void set_logger(Logger logger); + + // SSL +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; +#endif + +private: + std::unique_ptr cli_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_ = false; +#endif +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { +public: + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr, + const char *private_key_password = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + + SSLServer( + const std::function &setup_ssl_ctx_callback); + + ~SSLServer() override; + + bool is_valid() const override; + + SSL_CTX *ssl_context() const; + + void update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + +private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + int last_ssl_error_ = 0; +#endif +}; + +class SSLClient final : public ClientImpl { +public: + explicit SSLClient(const std::string &host); + + explicit SSLClient(const std::string &host, int port); + + explicit SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password = std::string()); + + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + +private: + bool create_and_connect_socket(Socket &socket, Error &error) override; + void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; + void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); + + bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback) override; + bool is_ssl() const override; + + bool connect_with_proxy( + Socket &sock, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error); + bool initialize_ssl(Socket &socket, Error &error); + + bool load_certs(); + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; + + std::vector host_components_; + + long verify_result_ = 0; + + friend class ClientImpl; +}; +#endif + +/* + * Implementation of template methods. + */ + +namespace detail { + +template +inline void duration_to_sec_and_usec(const T &duration, U callback) { + auto sec = std::chrono::duration_cast(duration).count(); + auto usec = std::chrono::duration_cast( + duration - std::chrono::seconds(sec)) + .count(); + callback(static_cast(sec), static_cast(usec)); +} + +template inline constexpr size_t str_len(const char (&)[N]) { + return N - 1; +} + +inline bool is_numeric(const std::string &str) { + return !str.empty() && + std::all_of(str.cbegin(), str.cend(), + [](unsigned char c) { return std::isdigit(c); }); +} + +inline size_t get_header_value_u64(const Headers &headers, + const std::string &key, size_t def, + size_t id, bool &is_invalid_value) { + is_invalid_value = false; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + if (is_numeric(it->second)) { + return std::strtoull(it->second.data(), nullptr, 10); + } else { + is_invalid_value = true; + } + } + return def; +} + +inline size_t get_header_value_u64(const Headers &headers, + const std::string &key, size_t def, + size_t id) { + bool dummy = false; + return get_header_value_u64(headers, key, def, id, dummy); +} + +} // namespace detail + +inline size_t Request::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +inline size_t Response::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +namespace detail { + +inline bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen) { + return setsockopt(sock, level, optname, +#ifdef _WIN64 + reinterpret_cast(optval), +#else + optval, +#endif + optlen) == 0; +} + +inline bool set_socket_opt(socket_t sock, int level, int optname, int optval) { + return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); +} + +inline bool set_socket_opt_time(socket_t sock, int level, int optname, + time_t sec, time_t usec) { +#ifdef _WIN64 + auto timeout = static_cast(sec * 1000 + usec / 1000); +#else + timeval timeout; + timeout.tv_sec = static_cast(sec); + timeout.tv_usec = static_cast(usec); +#endif + return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); +} + +} // namespace detail + +inline void default_socket_options(socket_t sock) { + detail::set_socket_opt(sock, SOL_SOCKET, +#ifdef SO_REUSEPORT + SO_REUSEPORT, +#else + SO_REUSEADDR, +#endif + 1); +} + +inline const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: return "Continue"; + case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; + case StatusCode::Processing_102: return "Processing"; + case StatusCode::EarlyHints_103: return "Early Hints"; + case StatusCode::OK_200: return "OK"; + case StatusCode::Created_201: return "Created"; + case StatusCode::Accepted_202: return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: return "No Content"; + case StatusCode::ResetContent_205: return "Reset Content"; + case StatusCode::PartialContent_206: return "Partial Content"; + case StatusCode::MultiStatus_207: return "Multi-Status"; + case StatusCode::AlreadyReported_208: return "Already Reported"; + case StatusCode::IMUsed_226: return "IM Used"; + case StatusCode::MultipleChoices_300: return "Multiple Choices"; + case StatusCode::MovedPermanently_301: return "Moved Permanently"; + case StatusCode::Found_302: return "Found"; + case StatusCode::SeeOther_303: return "See Other"; + case StatusCode::NotModified_304: return "Not Modified"; + case StatusCode::UseProxy_305: return "Use Proxy"; + case StatusCode::unused_306: return "unused"; + case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; + case StatusCode::BadRequest_400: return "Bad Request"; + case StatusCode::Unauthorized_401: return "Unauthorized"; + case StatusCode::PaymentRequired_402: return "Payment Required"; + case StatusCode::Forbidden_403: return "Forbidden"; + case StatusCode::NotFound_404: return "Not Found"; + case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: return "Request Timeout"; + case StatusCode::Conflict_409: return "Conflict"; + case StatusCode::Gone_410: return "Gone"; + case StatusCode::LengthRequired_411: return "Length Required"; + case StatusCode::PreconditionFailed_412: return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; + case StatusCode::UriTooLong_414: return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: return "Expectation Failed"; + case StatusCode::ImATeapot_418: return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; + case StatusCode::Locked_423: return "Locked"; + case StatusCode::FailedDependency_424: return "Failed Dependency"; + case StatusCode::TooEarly_425: return "Too Early"; + case StatusCode::UpgradeRequired_426: return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: return "Precondition Required"; + case StatusCode::TooManyRequests_429: return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: return "Not Implemented"; + case StatusCode::BadGateway_502: return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; + case StatusCode::LoopDetected_508: return "Loop Detected"; + case StatusCode::NotExtended_510: return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: return "Internal Server Error"; + } +} + +inline std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); + return req.get_header_value("Authorization") + .substr(bearer_header_prefix_len); + } + return ""; +} + +template +inline Server & +Server::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_idle_interval(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_idle_interval(sec, usec); }); + return *this; +} + +inline std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success (no error)"; + case Error::Connection: return "Could not establish connection"; + case Error::BindIPAddress: return "Failed to bind IP address"; + case Error::Read: return "Failed to read connection"; + case Error::Write: return "Failed to write connection"; + case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; + case Error::Canceled: return "Connection handling canceled"; + case Error::SSLConnection: return "SSL connection failed"; + case Error::SSLLoadingCerts: return "SSL certificate loading failed"; + case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: return "Compression failed"; + case Error::ConnectionTimeout: return "Connection timed out"; + case Error::ProxyConnection: return "Proxy connection failed"; + case Error::Unknown: return "Unknown"; + default: break; + } + + return "Invalid"; +} + +inline std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + +inline size_t Result::get_request_header_value_u64(const std::string &key, + size_t def, + size_t id) const { + return detail::get_header_value_u64(request_headers_, key, def, id); +} + +template +inline void ClientImpl::set_connection_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { + set_connection_timeout(sec, usec); + }); +} + +template +inline void ClientImpl::set_read_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_write_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_max_timeout( + const std::chrono::duration &duration) { + auto msec = + std::chrono::duration_cast(duration).count(); + set_max_timeout(msec); +} + +template +inline void Client::set_connection_timeout( + const std::chrono::duration &duration) { + cli_->set_connection_timeout(duration); +} + +template +inline void +Client::set_read_timeout(const std::chrono::duration &duration) { + cli_->set_read_timeout(duration); +} + +template +inline void +Client::set_write_timeout(const std::chrono::duration &duration) { + cli_->set_write_timeout(duration); +} + +inline void Client::set_max_timeout(time_t msec) { + cli_->set_max_timeout(msec); +} + +template +inline void +Client::set_max_timeout(const std::chrono::duration &duration) { + cli_->set_max_timeout(duration); +} + +/* + * Forward declarations and types that will be part of the .h file if split into + * .h + .cc. + */ + +std::string hosted_at(const std::string &hostname); + +void hosted_at(const std::string &hostname, std::vector &addrs); + +std::string encode_uri_component(const std::string &value); + +std::string encode_uri(const std::string &value); + +std::string decode_uri_component(const std::string &value); + +std::string decode_uri(const std::string &value); + +std::string encode_query_param(const std::string &value); + +std::string append_query_params(const std::string &path, const Params ¶ms); + +std::pair make_range_header(const Ranges &ranges); + +std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false); + +namespace detail { + +#if defined(_WIN64) +inline std::wstring u8string_to_wstring(const char *s) { + std::wstring ws; + auto len = static_cast(strlen(s)); + auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0); + if (wlen > 0) { + ws.resize(wlen); + wlen = ::MultiByteToWideChar( + CP_UTF8, 0, s, len, + const_cast(reinterpret_cast(ws.data())), wlen); + if (wlen != static_cast(ws.size())) { ws.clear(); } + } + return ws; +} +#endif + +struct FileStat { + FileStat(const std::string &path); + bool is_file() const; + bool is_dir() const; + +private: +#if defined(_WIN64) + struct _stat st_; +#else + struct stat st_; +#endif + int ret_ = -1; +}; + +std::string decode_path(const std::string &s, bool convert_plus_to_space); + +std::string trim_copy(const std::string &s); + +void divide( + const char *data, std::size_t size, char d, + std::function + fn); + +void divide( + const std::string &str, char d, + std::function + fn); + +void split(const char *b, const char *e, char d, + std::function fn); + +void split(const char *b, const char *e, char d, size_t m, + std::function fn); + +bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, + std::function callback); + +socket_t create_client_socket(const std::string &host, const std::string &ip, + int port, int address_family, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, + time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + const std::string &intf, Error &error); + +const char *get_header_value(const Headers &headers, const std::string &key, + const char *def, size_t id); + +std::string params_to_query_str(const Params ¶ms); + +void parse_query_text(const char *data, std::size_t size, Params ¶ms); + +void parse_query_text(const std::string &s, Params ¶ms); + +bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary); + +bool parse_range_header(const std::string &s, Ranges &ranges); + +bool parse_accept_header(const std::string &s, + std::vector &content_types); + +int close_socket(socket_t sock); + +ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); + +ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); + +enum class EncodingType { None = 0, Gzip, Brotli, Zstd }; + +EncodingType encoding_type(const Request &req, const Response &res); + +class BufferStream final : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + size_t position = 0; +}; + +class compressor { +public: + virtual ~compressor() = default; + + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, + Callback callback) = 0; +}; + +class decompressor { +public: + virtual ~decompressor() = default; + + virtual bool is_valid() const = 0; + + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, + Callback callback) = 0; +}; + +class nocompressor final : public compressor { +public: + ~nocompressor() override = default; + + bool compress(const char *data, size_t data_length, bool /*last*/, + Callback callback) override; +}; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +class gzip_compressor final : public compressor { +public: + gzip_compressor(); + ~gzip_compressor() override; + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; + +class gzip_decompressor final : public decompressor { +public: + gzip_decompressor(); + ~gzip_decompressor() override; + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +class brotli_compressor final : public compressor { +public: + brotli_compressor(); + ~brotli_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + BrotliEncoderState *state_ = nullptr; +}; + +class brotli_decompressor final : public decompressor { +public: + brotli_decompressor(); + ~brotli_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; +}; +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +class zstd_compressor : public compressor { +public: + zstd_compressor(); + ~zstd_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + ZSTD_CCtx *ctx_ = nullptr; +}; + +class zstd_decompressor : public decompressor { +public: + zstd_decompressor(); + ~zstd_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + ZSTD_DCtx *ctx_ = nullptr; +}; +#endif + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { +public: + stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size); + const char *ptr() const; + size_t size() const; + bool end_with_crlf() const; + bool getline(); + +private: + void append(char c); + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string growable_buffer_; +}; + +class mmap { +public: + mmap(const char *path); + ~mmap(); + + bool open(const char *path); + void close(); + + bool is_open() const; + size_t size() const; + const char *data() const; + +private: +#if defined(_WIN64) + HANDLE hFile_ = NULL; + HANDLE hMapping_ = NULL; +#else + int fd_ = -1; +#endif + size_t size_ = 0; + void *addr_ = nullptr; + bool is_open_empty_file = false; +}; + +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +inline bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; +} + +inline bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +inline bool is_field_name(const std::string &s) { return is_token(s); } + +inline bool is_vchar(char c) { return c >= 33 && c <= 126; } + +inline bool is_obs_text(char c) { return 128 <= static_cast(c); } + +inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +inline bool is_field_content(const std::string &s) { + if (s.empty()) { return true; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +inline bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + +} // namespace detail + +// ---------------------------------------------------------------------------- + +/* + * Implementation that will be part of the .cc file if split into .h + .cc. + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + auto v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + static const auto charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = static_cast(code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + auto val = 0; + auto valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + if (path[i] == '\0') { + return false; + } else if (path[i] == '\\') { + return false; + } + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline FileStat::FileStat(const std::string &path) { +#if defined(_WIN64) + auto wpath = u8string_to_wstring(path.c_str()); + ret_ = _wstat(wpath.c_str(), &st_); +#else + ret_ = stat(path.c_str(), &st_); +#endif +} +inline bool FileStat::is_file() const { + return ret_ >= 0 && S_ISREG(st_.st_mode); +} +inline bool FileStat::is_dir() const { + return ret_ >= 0 && S_ISDIR(st_.st_mode); +} + +inline std::string encode_path(const std::string &s) { + std::string result; + result.reserve(s.size()); + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_path(const std::string &s, + bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + auto val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + auto val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + thread_local auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); +} + +inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } + +inline std::pair trim(const char *b, const char *e, size_t left, + size_t right) { + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); +} + +inline std::string trim_copy(const std::string &s) { + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); +} + +inline std::string trim_double_quotes_copy(const std::string &s) { + if (s.length() >= 2 && s.front() == '"' && s.back() == '"') { + return s.substr(1, s.size() - 2); + } + return s; +} + +inline void +divide(const char *data, std::size_t size, char d, + std::function + fn) { + const auto it = std::find(data, data + size, d); + const auto found = static_cast(it != data + size); + const auto lhs_data = data; + const auto lhs_size = static_cast(it - data); + const auto rhs_data = it + found; + const auto rhs_size = size - lhs_size - found; + + fn(lhs_data, lhs_size, rhs_data, rhs_size); +} + +inline void +divide(const std::string &str, char d, + std::function + fn) { + divide(str.data(), str.size(), d, std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, + std::function fn) { + return split(b, e, d, (std::numeric_limits::max)(), std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, size_t m, + std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + } +} + +inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + +inline const char *stream_line_reader::ptr() const { + if (growable_buffer_.empty()) { + return fixed_buffer_; + } else { + return growable_buffer_.data(); + } +} + +inline size_t stream_line_reader::size() const { + if (growable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return growable_buffer_.size(); + } +} + +inline bool stream_line_reader::end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; +} + +inline bool stream_line_reader::getline() { + fixed_buffer_used_size_ = 0; + growable_buffer_.clear(); + +#ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + char prev_byte = 0; +#endif + + for (size_t i = 0;; i++) { + if (size() >= CPPHTTPLIB_MAX_LINE_LENGTH) { + // Treat exceptionally long lines as an error to + // prevent infinite loops/memory exhaustion + return false; + } + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + if (byte == '\n') { break; } +#else + if (prev_byte == '\r' && byte == '\n') { break; } + prev_byte = byte; +#endif + } + + return true; +} + +inline void stream_line_reader::append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (growable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + growable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + growable_buffer_ += c; + } +} + +inline mmap::mmap(const char *path) { open(path); } + +inline mmap::~mmap() { close(); } + +inline bool mmap::open(const char *path) { + close(); + +#if defined(_WIN64) + auto wpath = u8string_to_wstring(path); + if (wpath.empty()) { return false; } + + hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, + OPEN_EXISTING, NULL); + + if (hFile_ == INVALID_HANDLE_VALUE) { return false; } + + LARGE_INTEGER size{}; + if (!::GetFileSizeEx(hFile_, &size)) { return false; } + // If the following line doesn't compile due to QuadPart, update Windows SDK. + // See: + // https://github.com/yhirose/cpp-httplib/issues/1903#issuecomment-2316520721 + if (static_cast(size.QuadPart) > + (std::numeric_limits::max)()) { + // `size_t` might be 32-bits, on 32-bits Windows. + return false; + } + size_ = static_cast(size.QuadPart); + + hMapping_ = + ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); + + // Special treatment for an empty file... + if (hMapping_ == NULL && size_ == 0) { + close(); + is_open_empty_file = true; + return true; + } + + if (hMapping_ == NULL) { + close(); + return false; + } + + addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0); + + if (addr_ == nullptr) { + close(); + return false; + } +#else + fd_ = ::open(path, O_RDONLY); + if (fd_ == -1) { return false; } + + struct stat sb; + if (fstat(fd_, &sb) == -1) { + close(); + return false; + } + size_ = static_cast(sb.st_size); + + addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); + + // Special treatment for an empty file... + if (addr_ == MAP_FAILED && size_ == 0) { + close(); + is_open_empty_file = true; + return false; + } +#endif + + return true; +} + +inline bool mmap::is_open() const { + return is_open_empty_file ? true : addr_ != nullptr; +} + +inline size_t mmap::size() const { return size_; } + +inline const char *mmap::data() const { + return is_open_empty_file ? "" : static_cast(addr_); +} + +inline void mmap::close() { +#if defined(_WIN64) + if (addr_) { + ::UnmapViewOfFile(addr_); + addr_ = nullptr; + } + + if (hMapping_) { + ::CloseHandle(hMapping_); + hMapping_ = NULL; + } + + if (hFile_ != INVALID_HANDLE_VALUE) { + ::CloseHandle(hFile_); + hFile_ = INVALID_HANDLE_VALUE; + } + + is_open_empty_file = false; +#else + if (addr_ != nullptr) { + munmap(addr_, size_); + addr_ = nullptr; + } + + if (fd_ != -1) { + ::close(fd_); + fd_ = -1; + } +#endif + size_ = 0; +} +inline int close_socket(socket_t sock) { +#ifdef _WIN64 + return closesocket(sock); +#else + return close(sock); +#endif +} + +template inline ssize_t handle_EINTR(T fn) { + ssize_t res = 0; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } + break; + } + return res; +} + +inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return recv(sock, +#ifdef _WIN64 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, + int flags) { + return handle_EINTR([&]() { + return send(sock, +#ifdef _WIN64 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline int poll_wrapper(struct pollfd *fds, nfds_t nfds, int timeout) { +#ifdef _WIN64 + return ::WSAPoll(fds, nfds, timeout); +#else + return ::poll(fds, nfds, timeout); +#endif +} + +template +inline ssize_t select_impl(socket_t sock, time_t sec, time_t usec) { +#ifdef __APPLE__ + if (sock >= FD_SETSIZE) { return -1; } + + fd_set fds, *rfds, *wfds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + rfds = (Read ? &fds : nullptr); + wfds = (Read ? nullptr : &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { + return select(static_cast(sock + 1), rfds, wfds, nullptr, &tv); + }); +#else + struct pollfd pfd; + pfd.fd = sock; + pfd.events = (Read ? POLLIN : POLLOUT); + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll_wrapper(&pfd, 1, timeout); }); +#endif +} + +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { + return select_impl(sock, sec, usec); +} + +inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { + return select_impl(sock, sec, usec); +} + +inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, + time_t usec) { +#ifdef __APPLE__ + if (sock >= FD_SETSIZE) { return Error::Connection; } + + fd_set fdsr, fdsw; + FD_ZERO(&fdsr); + FD_ZERO(&fdsw); + FD_SET(sock, &fdsr); + FD_SET(sock, &fdsw); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { + return select(static_cast(sock + 1), &fdsr, &fdsw, nullptr, &tv); + }); + + if (ret == 0) { return Error::ConnectionTimeout; } + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +#else + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = + handle_EINTR([&]() { return poll_wrapper(&pfd_read, 1, timeout); }); + + if (poll_res == 0) { return Error::ConnectionTimeout; } + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +#endif +} + +inline bool is_socket_alive(socket_t sock) { + const auto val = detail::select_read(sock, 0, 0); + if (val == 0) { + return true; + } else if (val < 0 && errno == EBADF) { + return false; + } + char buf[1]; + return detail::read_socket(sock, &buf[0], sizeof(buf), MSG_PEEK) > 0; +} + +class SocketStream final : public Stream { +public: + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + (std::chrono::steady_clock::time_point::min)()); + ~SocketStream() override; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time_; + + std::vector read_buff_; + size_t read_buff_off_ = 0; + size_t read_buff_content_size_ = 0; + + static const size_t read_buff_size_ = 1024l * 4; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream final : public Stream { +public: + SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + (std::chrono::steady_clock::time_point::min)()); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time_; +}; +#endif + +inline bool keep_alive(const std::atomic &svr_sock, socket_t sock, + time_t keep_alive_timeout_sec) { + using namespace std::chrono; + + const auto interval_usec = + CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND; + + // Avoid expensive `steady_clock::now()` call for the first time + if (select_read(sock, 0, interval_usec) > 0) { return true; } + + const auto start = steady_clock::now() - microseconds{interval_usec}; + const auto timeout = seconds{keep_alive_timeout_sec}; + + while (true) { + if (svr_sock == INVALID_SOCKET) { + break; // Server socket is closed + } + + auto val = select_read(sock, 0, interval_usec); + if (val < 0) { + break; // Ssocket error + } else if (val == 0) { + if (steady_clock::now() - start > timeout) { + break; // Timeout + } + } else { + return true; // Ready for read + } + } + + return false; +} + +template +inline bool +process_server_socket_core(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, T callback) { + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { break; } + count--; + } + return ret; +} + +template +inline bool +process_server_socket(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +inline bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, + std::function callback) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); + return callback(strm); +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN64 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +inline std::string escape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '\0') { + auto ret = s; + ret[0] = '@'; + return ret; + } + return s; +} + +inline std::string +unescape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '@') { + auto ret = s; + ret[0] = '\0'; + return ret; + } + return s; +} + +inline int getaddrinfo_with_timeout(const char *node, const char *service, + const struct addrinfo *hints, + struct addrinfo **res, time_t timeout_sec) { +#ifdef CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO + if (timeout_sec <= 0) { + // No timeout specified, use standard getaddrinfo + return getaddrinfo(node, service, hints, res); + } + +#ifdef _WIN64 + // Windows-specific implementation using GetAddrInfoEx with overlapped I/O + OVERLAPPED overlapped = {0}; + HANDLE event = CreateEventW(nullptr, TRUE, FALSE, nullptr); + if (!event) { return EAI_FAIL; } + + overlapped.hEvent = event; + + PADDRINFOEXW result_addrinfo = nullptr; + HANDLE cancel_handle = nullptr; + + ADDRINFOEXW hints_ex = {0}; + if (hints) { + hints_ex.ai_flags = hints->ai_flags; + hints_ex.ai_family = hints->ai_family; + hints_ex.ai_socktype = hints->ai_socktype; + hints_ex.ai_protocol = hints->ai_protocol; + } + + auto wnode = u8string_to_wstring(node); + auto wservice = u8string_to_wstring(service); + + auto ret = ::GetAddrInfoExW(wnode.data(), wservice.data(), NS_DNS, nullptr, + hints ? &hints_ex : nullptr, &result_addrinfo, + nullptr, &overlapped, nullptr, &cancel_handle); + + if (ret == WSA_IO_PENDING) { + auto wait_result = + ::WaitForSingleObject(event, static_cast(timeout_sec * 1000)); + if (wait_result == WAIT_TIMEOUT) { + if (cancel_handle) { ::GetAddrInfoExCancel(&cancel_handle); } + ::CloseHandle(event); + return EAI_AGAIN; + } + + DWORD bytes_returned; + if (!::GetOverlappedResult((HANDLE)INVALID_SOCKET, &overlapped, + &bytes_returned, FALSE)) { + ::CloseHandle(event); + return ::WSAGetLastError(); + } + } + + ::CloseHandle(event); + + if (ret == NO_ERROR || ret == WSA_IO_PENDING) { + *res = reinterpret_cast(result_addrinfo); + return 0; + } + + return ret; +#elif defined(TARGET_OS_OSX) + // macOS implementation using CFHost API for asynchronous DNS resolution + CFStringRef hostname_ref = CFStringCreateWithCString( + kCFAllocatorDefault, node, kCFStringEncodingUTF8); + if (!hostname_ref) { return EAI_MEMORY; } + + CFHostRef host_ref = CFHostCreateWithName(kCFAllocatorDefault, hostname_ref); + CFRelease(hostname_ref); + if (!host_ref) { return EAI_MEMORY; } + + // Set up context for callback + struct CFHostContext { + bool completed = false; + bool success = false; + CFArrayRef addresses = nullptr; + std::mutex mutex; + std::condition_variable cv; + } context; + + CFHostClientContext client_context; + memset(&client_context, 0, sizeof(client_context)); + client_context.info = &context; + + // Set callback + auto callback = [](CFHostRef theHost, CFHostInfoType /*typeInfo*/, + const CFStreamError *error, void *info) { + auto ctx = static_cast(info); + std::lock_guard lock(ctx->mutex); + + if (error && error->error != 0) { + ctx->success = false; + } else { + Boolean hasBeenResolved; + ctx->addresses = CFHostGetAddressing(theHost, &hasBeenResolved); + if (ctx->addresses && hasBeenResolved) { + CFRetain(ctx->addresses); + ctx->success = true; + } else { + ctx->success = false; + } + } + ctx->completed = true; + ctx->cv.notify_one(); + }; + + if (!CFHostSetClient(host_ref, callback, &client_context)) { + CFRelease(host_ref); + return EAI_SYSTEM; + } + + // Schedule on run loop + CFRunLoopRef run_loop = CFRunLoopGetCurrent(); + CFHostScheduleWithRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode); + + // Start resolution + CFStreamError stream_error; + if (!CFHostStartInfoResolution(host_ref, kCFHostAddresses, &stream_error)) { + CFHostUnscheduleFromRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode); + CFRelease(host_ref); + return EAI_FAIL; + } + + // Wait for completion with timeout + auto timeout_time = + std::chrono::steady_clock::now() + std::chrono::seconds(timeout_sec); + bool timed_out = false; + + { + std::unique_lock lock(context.mutex); + + while (!context.completed) { + auto now = std::chrono::steady_clock::now(); + if (now >= timeout_time) { + timed_out = true; + break; + } + + // Run the runloop for a short time + lock.unlock(); + CFRunLoopRunInMode(kCFRunLoopDefaultMode, 0.1, true); + lock.lock(); + } + } + + // Clean up + CFHostUnscheduleFromRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode); + CFHostSetClient(host_ref, nullptr, nullptr); + + if (timed_out || !context.completed) { + CFHostCancelInfoResolution(host_ref, kCFHostAddresses); + CFRelease(host_ref); + return EAI_AGAIN; + } + + if (!context.success || !context.addresses) { + CFRelease(host_ref); + return EAI_NODATA; + } + + // Convert CFArray to addrinfo + CFIndex count = CFArrayGetCount(context.addresses); + if (count == 0) { + CFRelease(context.addresses); + CFRelease(host_ref); + return EAI_NODATA; + } + + struct addrinfo *result_addrinfo = nullptr; + struct addrinfo **current = &result_addrinfo; + + for (CFIndex i = 0; i < count; i++) { + CFDataRef addr_data = + static_cast(CFArrayGetValueAtIndex(context.addresses, i)); + if (!addr_data) continue; + + const struct sockaddr *sockaddr_ptr = + reinterpret_cast(CFDataGetBytePtr(addr_data)); + socklen_t sockaddr_len = static_cast(CFDataGetLength(addr_data)); + + // Allocate addrinfo structure + *current = static_cast(malloc(sizeof(struct addrinfo))); + if (!*current) { + freeaddrinfo(result_addrinfo); + CFRelease(context.addresses); + CFRelease(host_ref); + return EAI_MEMORY; + } + + memset(*current, 0, sizeof(struct addrinfo)); + + // Set up addrinfo fields + (*current)->ai_family = sockaddr_ptr->sa_family; + (*current)->ai_socktype = hints ? hints->ai_socktype : SOCK_STREAM; + (*current)->ai_protocol = hints ? hints->ai_protocol : IPPROTO_TCP; + (*current)->ai_addrlen = sockaddr_len; + + // Copy sockaddr + (*current)->ai_addr = static_cast(malloc(sockaddr_len)); + if (!(*current)->ai_addr) { + freeaddrinfo(result_addrinfo); + CFRelease(context.addresses); + CFRelease(host_ref); + return EAI_MEMORY; + } + memcpy((*current)->ai_addr, sockaddr_ptr, sockaddr_len); + + // Set port if service is specified + if (service && strlen(service) > 0) { + int port = atoi(service); + if (port > 0) { + if (sockaddr_ptr->sa_family == AF_INET) { + reinterpret_cast((*current)->ai_addr) + ->sin_port = htons(static_cast(port)); + } else if (sockaddr_ptr->sa_family == AF_INET6) { + reinterpret_cast((*current)->ai_addr) + ->sin6_port = htons(static_cast(port)); + } + } + } + + current = &((*current)->ai_next); + } + + CFRelease(context.addresses); + CFRelease(host_ref); + + *res = result_addrinfo; + return 0; +#elif defined(_GNU_SOURCE) && defined(__GLIBC__) && \ + (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 2)) + // Linux implementation using getaddrinfo_a for asynchronous DNS resolution + struct gaicb request; + struct gaicb *requests[1] = {&request}; + struct sigevent sevp; + struct timespec timeout; + + // Initialize the request structure + memset(&request, 0, sizeof(request)); + request.ar_name = node; + request.ar_service = service; + request.ar_request = hints; + + // Set up timeout + timeout.tv_sec = timeout_sec; + timeout.tv_nsec = 0; + + // Initialize sigevent structure (not used, but required) + memset(&sevp, 0, sizeof(sevp)); + sevp.sigev_notify = SIGEV_NONE; + + // Start asynchronous resolution + int start_result = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp); + if (start_result != 0) { return start_result; } + + // Wait for completion with timeout + int wait_result = + gai_suspend((const struct gaicb *const *)requests, 1, &timeout); + + if (wait_result == 0) { + // Completed successfully, get the result + int gai_result = gai_error(&request); + if (gai_result == 0) { + *res = request.ar_result; + return 0; + } else { + // Clean up on error + if (request.ar_result) { freeaddrinfo(request.ar_result); } + return gai_result; + } + } else if (wait_result == EAI_AGAIN) { + // Timeout occurred, cancel the request + gai_cancel(&request); + return EAI_AGAIN; + } else { + // Other error occurred + gai_cancel(&request); + return wait_result; + } +#else + // Fallback implementation using thread-based timeout for other Unix systems + std::mutex result_mutex; + std::condition_variable result_cv; + auto completed = false; + auto result = EAI_SYSTEM; + struct addrinfo *result_addrinfo = nullptr; + + std::thread resolve_thread([&]() { + auto thread_result = getaddrinfo(node, service, hints, &result_addrinfo); + + std::lock_guard lock(result_mutex); + result = thread_result; + completed = true; + result_cv.notify_one(); + }); + + // Wait for completion or timeout + std::unique_lock lock(result_mutex); + auto finished = result_cv.wait_for(lock, std::chrono::seconds(timeout_sec), + [&] { return completed; }); + + if (finished) { + // Operation completed within timeout + resolve_thread.join(); + *res = result_addrinfo; + return result; + } else { + // Timeout occurred + resolve_thread.detach(); // Let the thread finish in background + return EAI_AGAIN; // Return timeout error + } +#endif +#else + (void)(timeout_sec); // Unused parameter for non-blocking getaddrinfo + return getaddrinfo(node, service, hints, res); +#endif +} + +template +socket_t create_socket(const std::string &host, const std::string &ip, int port, + int address_family, int socket_flags, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + BindOrConnect bind_or_connect, time_t timeout_sec = 0) { + // Get address info + const char *node = nullptr; + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_IP; + + if (!ip.empty()) { + node = ip.c_str(); + // Ask getaddrinfo to convert IP in c-string to address + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + } else { + if (!host.empty()) { node = host.c_str(); } + hints.ai_family = address_family; + hints.ai_flags = socket_flags; + } + +#if !defined(_WIN64) || defined(CPPHTTPLIB_HAVE_AFUNIX_H) + if (hints.ai_family == AF_UNIX) { + const auto addrlen = host.length(); + if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } + +#ifdef SOCK_CLOEXEC + auto sock = socket(hints.ai_family, hints.ai_socktype | SOCK_CLOEXEC, + hints.ai_protocol); +#else + auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); +#endif + + if (sock != INVALID_SOCKET) { + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + + auto unescaped_host = unescape_abstract_namespace_unix_domain(host); + std::copy(unescaped_host.begin(), unescaped_host.end(), addr.sun_path); + + hints.ai_addr = reinterpret_cast(&addr); + hints.ai_addrlen = static_cast( + sizeof(addr) - sizeof(addr.sun_path) + addrlen); + +#ifndef SOCK_CLOEXEC +#ifndef _WIN64 + fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif +#endif + + if (socket_options) { socket_options(sock); } + +#ifdef _WIN64 + // Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so + // remove the option. + detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); +#endif + + bool dummy; + if (!bind_or_connect(sock, hints, dummy)) { + close_socket(sock); + sock = INVALID_SOCKET; + } + } + return sock; + } +#endif + + auto service = std::to_string(port); + + if (getaddrinfo_with_timeout(node, service.c_str(), &hints, &result, + timeout_sec)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return INVALID_SOCKET; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN64 + auto sock = + WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, + WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + +#ifdef SOCK_CLOEXEC + auto sock = + socket(rp->ai_family, rp->ai_socktype | SOCK_CLOEXEC, rp->ai_protocol); +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + +#endif + if (sock == INVALID_SOCKET) { continue; } + +#if !defined _WIN64 && !defined SOCK_CLOEXEC + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + close_socket(sock); + continue; + } +#endif + + if (tcp_nodelay) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); } + + if (rp->ai_family == AF_INET6) { + set_socket_opt(sock, IPPROTO_IPV6, IPV6_V6ONLY, ipv6_v6only ? 1 : 0); + } + + if (socket_options) { socket_options(sock); } + + // bind or connect + auto quit = false; + if (bind_or_connect(sock, *rp, quit)) { return sock; } + + close_socket(sock); + + if (quit) { break; } + } + + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN64 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN64 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const std::string &host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo_with_timeout(host.c_str(), "0", &hints, &result, 0)) { + return false; + } + + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + return ret; +} + +#if !defined _WIN64 && !defined ANDROID && !defined _AIX && !defined __MVS__ +#define USE_IF2IP +#endif + +#ifdef USE_IF2IP +inline std::string if2ip(int address_family, const std::string &ifn) { + struct ifaddrs *ifap; + getifaddrs(&ifap); + auto se = detail::scope_exit([&] { freeifaddrs(ifap); }); + + std::string addr_candidate; + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name && + (AF_UNSPEC == address_family || + ifa->ifa_addr->sa_family == address_family)) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + return std::string(buf, INET_ADDRSTRLEN); + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + auto sa = reinterpret_cast(ifa->ifa_addr); + if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) { + char buf[INET6_ADDRSTRLEN] = {}; + if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) { + // equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL + auto s6_addr_head = sa->sin6_addr.s6_addr[0]; + if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { + addr_candidate = std::string(buf, INET6_ADDRSTRLEN); + } else { + return std::string(buf, INET6_ADDRSTRLEN); + } + } + } + } + } + } + return addr_candidate; +} +#endif + +inline socket_t create_client_socket( + const std::string &host, const std::string &ip, int port, + int address_family, bool tcp_nodelay, bool ipv6_v6only, + SocketOptions socket_options, time_t connection_timeout_sec, + time_t connection_timeout_usec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, const std::string &intf, Error &error) { + auto sock = create_socket( + host, ip, port, address_family, 0, tcp_nodelay, ipv6_v6only, + std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai, bool &quit) -> bool { + if (!intf.empty()) { +#ifdef USE_IF2IP + auto ip_from_if = if2ip(address_family, intf); + if (ip_from_if.empty()) { ip_from_if = intf; } + if (!bind_ip_address(sock2, ip_from_if)) { + error = Error::BindIPAddress; + return false; + } +#endif + } + + set_nonblocking(sock2, true); + + auto ret = + ::connect(sock2, ai.ai_addr, static_cast(ai.ai_addrlen)); + + if (ret < 0) { + if (is_connection_error()) { + error = Error::Connection; + return false; + } + error = wait_until_socket_is_ready(sock2, connection_timeout_sec, + connection_timeout_usec); + if (error != Error::Success) { + if (error == Error::ConnectionTimeout) { quit = true; } + return false; + } + } + + set_nonblocking(sock2, false); + set_socket_opt_time(sock2, SOL_SOCKET, SO_RCVTIMEO, read_timeout_sec, + read_timeout_usec); + set_socket_opt_time(sock2, SOL_SOCKET, SO_SNDTIMEO, write_timeout_sec, + write_timeout_usec); + + error = Error::Success; + return true; + }, + connection_timeout_sec); // Pass DNS timeout + + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { error = Error::Connection; } + } + + return sock; +} + +inline bool get_ip_and_port(const struct sockaddr_storage &addr, + socklen_t addr_len, std::string &ip, int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return false; + } + + std::array ipstr{}; + if (getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + return false; + } + + ip = ipstr.data(); + return true; +} + +inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (!getsockname(sock, reinterpret_cast(&addr), + &addr_len)) { + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { +#ifndef _WIN64 + if (addr.ss_family == AF_UNIX) { +#if defined(__linux__) + struct ucred ucred; + socklen_t len = sizeof(ucred); + if (getsockopt(sock, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == 0) { + port = ucred.pid; + } +#elif defined(SOL_LOCAL) && defined(SO_PEERPID) + pid_t pid; + socklen_t len = sizeof(pid); + if (getsockopt(sock, SOL_LOCAL, SO_PEERPID, &pid, &len) == 0) { + port = pid; + } +#endif + return; + } +#endif + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline constexpr unsigned int str2tag_core(const char *s, size_t l, + unsigned int h) { + return (l == 0) + ? h + : str2tag_core( + s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(*s)); +} + +inline unsigned int str2tag(const std::string &s) { + return str2tag_core(s.data(), s.size(), 0); +} + +namespace udl { + +inline constexpr unsigned int operator""_t(const char *s, size_t l) { + return str2tag_core(s, l, 0); +} + +} // namespace udl + +inline std::string +find_content_type(const std::string &path, + const std::map &user_data, + const std::string &default_content_type) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second; } + + using udl::operator""_t; + + switch (str2tag(ext)) { + default: return default_content_type; + + case "css"_t: return "text/css"; + case "csv"_t: return "text/csv"; + case "htm"_t: + case "html"_t: return "text/html"; + case "js"_t: + case "mjs"_t: return "text/javascript"; + case "txt"_t: return "text/plain"; + case "vtt"_t: return "text/vtt"; + + case "apng"_t: return "image/apng"; + case "avif"_t: return "image/avif"; + case "bmp"_t: return "image/bmp"; + case "gif"_t: return "image/gif"; + case "png"_t: return "image/png"; + case "svg"_t: return "image/svg+xml"; + case "webp"_t: return "image/webp"; + case "ico"_t: return "image/x-icon"; + case "tif"_t: return "image/tiff"; + case "tiff"_t: return "image/tiff"; + case "jpg"_t: + case "jpeg"_t: return "image/jpeg"; + + case "mp4"_t: return "video/mp4"; + case "mpeg"_t: return "video/mpeg"; + case "webm"_t: return "video/webm"; + + case "mp3"_t: return "audio/mp3"; + case "mpga"_t: return "audio/mpeg"; + case "weba"_t: return "audio/webm"; + case "wav"_t: return "audio/wave"; + + case "otf"_t: return "font/otf"; + case "ttf"_t: return "font/ttf"; + case "woff"_t: return "font/woff"; + case "woff2"_t: return "font/woff2"; + + case "7z"_t: return "application/x-7z-compressed"; + case "atom"_t: return "application/atom+xml"; + case "pdf"_t: return "application/pdf"; + case "json"_t: return "application/json"; + case "rss"_t: return "application/rss+xml"; + case "tar"_t: return "application/x-tar"; + case "xht"_t: + case "xhtml"_t: return "application/xhtml+xml"; + case "xslt"_t: return "application/xslt+xml"; + case "xml"_t: return "application/xml"; + case "gz"_t: return "application/gzip"; + case "zip"_t: return "application/zip"; + case "wasm"_t: return "application/wasm"; + } +} + +inline bool can_compress_content_type(const std::string &content_type) { + using udl::operator""_t; + + auto tag = str2tag(content_type); + + switch (tag) { + case "image/svg+xml"_t: + case "application/javascript"_t: + case "application/json"_t: + case "application/xml"_t: + case "application/protobuf"_t: + case "application/xhtml+xml"_t: return true; + + case "text/event-stream"_t: return false; + + default: return !content_type.rfind("text/", 0); + } +} + +inline EncodingType encoding_type(const Request &req, const Response &res) { + auto ret = + detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { return EncodingType::None; } + + const auto &s = req.get_header_value("Accept-Encoding"); + (void)(s); + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { return EncodingType::Brotli; } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { return EncodingType::Gzip; } +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + // TODO: 'Accept-Encoding' has zstd, not zstd;q=0 + ret = s.find("zstd") != std::string::npos; + if (ret) { return EncodingType::Zstd; } +#endif + + return EncodingType::None; +} + +inline bool nocompressor::compress(const char *data, size_t data_length, + bool /*last*/, Callback callback) { + if (!data_length) { return true; } + return callback(data, data_length); +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline gzip_compressor::gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY) == Z_OK; +} + +inline gzip_compressor::~gzip_compressor() { deflateEnd(&strm_); } + +inline bool gzip_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + assert(is_valid_); + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + auto flush = (last && data_length == 0) ? Z_FINISH : Z_NO_FLUSH; + auto ret = Z_OK; + + std::array buff{}; + do { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = deflate(&strm_, flush); + if (ret == Z_STREAM_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); + + assert((flush == Z_FINISH && ret == Z_STREAM_END) || + (flush == Z_NO_FLUSH && ret == Z_OK)); + assert(strm_.avail_in == 0); + } while (data_length > 0); + + return true; +} + +inline gzip_decompressor::gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; +} + +inline gzip_decompressor::~gzip_decompressor() { inflateEnd(&strm_); } + +inline bool gzip_decompressor::is_valid() const { return is_valid_; } + +inline bool gzip_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + assert(is_valid_); + + auto ret = Z_OK; + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + std::array buff{}; + while (strm_.avail_in > 0 && ret == Z_OK) { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm_, Z_NO_FLUSH); + + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm_); return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } + + if (ret != Z_OK && ret != Z_STREAM_END) { return false; } + + } while (data_length > 0); + + return true; +} +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +inline brotli_compressor::brotli_compressor() { + state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); +} + +inline brotli_compressor::~brotli_compressor() { + BrotliEncoderDestroyInstance(state_); +} + +inline bool brotli_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { break; } + } else { + if (!available_in) { break; } + } + + auto available_out = buff.size(); + auto next_out = buff.data(); + + if (!BrotliEncoderCompressStream(state_, operation, &available_in, &next_in, + &available_out, &next_out, nullptr)) { + return false; + } + + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } + + return true; +} + +inline brotli_decompressor::brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT + : BROTLI_DECODER_RESULT_ERROR; +} + +inline brotli_decompressor::~brotli_decompressor() { + if (decoder_s) { BrotliDecoderDestroyInstance(decoder_s); } +} + +inline bool brotli_decompressor::is_valid() const { return decoder_s; } + +inline bool brotli_decompressor::decompress(const char *data, + size_t data_length, + Callback callback) { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } + + auto next_in = reinterpret_cast(data); + size_t avail_in = data_length; + size_t total_out; + + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); + + decoder_r = BrotliDecoderDecompressStream( + decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); + + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - avail_out)) { return false; } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; +} +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +inline zstd_compressor::zstd_compressor() { + ctx_ = ZSTD_createCCtx(); + ZSTD_CCtx_setParameter(ctx_, ZSTD_c_compressionLevel, ZSTD_fast); +} + +inline zstd_compressor::~zstd_compressor() { ZSTD_freeCCtx(ctx_); } + +inline bool zstd_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + ZSTD_EndDirective mode = last ? ZSTD_e_end : ZSTD_e_continue; + ZSTD_inBuffer input = {data, data_length, 0}; + + bool finished; + do { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_compressStream2(ctx_, &output, &input, mode); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + + finished = last ? (remaining == 0) : (input.pos == input.size); + + } while (!finished); + + return true; +} + +inline zstd_decompressor::zstd_decompressor() { ctx_ = ZSTD_createDCtx(); } + +inline zstd_decompressor::~zstd_decompressor() { ZSTD_freeDCtx(ctx_); } + +inline bool zstd_decompressor::is_valid() const { return ctx_ != nullptr; } + +inline bool zstd_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + std::array buff{}; + ZSTD_inBuffer input = {data, data_length, 0}; + + while (input.pos < input.size) { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_decompressStream(ctx_, &output, &input); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + } + + return true; +} +#endif + +inline bool has_header(const Headers &headers, const std::string &key) { + return headers.find(key) != headers.end(); +} + +inline const char *get_header_value(const Headers &headers, + const std::string &key, const char *def, + size_t id) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second.c_str(); } + return def; +} + +template +inline bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + + if (p == end) { return false; } + + auto key_end = p; + + if (*p++ != ':') { return false; } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { return false; } + + auto key = std::string(beg, key_end); + auto val = std::string(p, end); + + if (!detail::fields::is_field_value(val)) { return false; } + + if (case_ignore::equal(key, "Location") || + case_ignore::equal(key, "Referer")) { + fn(key, val); + } else { + fn(key, decode_path(val, false)); + } + + return true; + } + + return false; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + size_t header_count = 0; + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + auto line_terminator_len = 2; + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } else { +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + // Blank line indicates end of headers. + if (line_reader.size() == 1) { break; } + line_terminator_len = 1; +#else + continue; // Skip invalid line. +#endif + } + + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + + // Check header count limit + if (header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { return false; } + + // Exclude line terminator + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_reader.ptr(), end, + [&](const std::string &key, const std::string &val) { + headers.emplace(key, val); + })) { + return false; + } + + header_count++; + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, size_t len, + DownloadProgress progress, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + size_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, static_cast(n), r, len)) { return false; } + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, size_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + size_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += static_cast(n); + } +} + +enum class ReadContentResult { + Success, // Successfully read the content + PayloadTooLarge, // The content exceeds the specified payload limit + Error // An error occurred while reading the content +}; + +inline ReadContentResult +read_content_without_length(Stream &strm, size_t payload_max_length, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + size_t r = 0; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n == 0) { return ReadContentResult::Success; } + if (n < 0) { return ReadContentResult::Error; } + + // Check if adding this data would exceed the payload limit + if (r > payload_max_length || + payload_max_length - r < static_cast(n)) { + return ReadContentResult::PayloadTooLarge; + } + + if (!out(buf, static_cast(n), r, 0)) { + return ReadContentResult::Error; + } + r += static_cast(n); + } + + return ReadContentResult::Success; +} + +template +inline ReadContentResult read_content_chunked(Stream &strm, T &x, + size_t payload_max_length, + ContentReceiverWithProgress out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return ReadContentResult::Error; } + + unsigned long chunk_len; + size_t total_len = 0; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { return ReadContentResult::Error; } + if (chunk_len == ULONG_MAX) { return ReadContentResult::Error; } + + if (chunk_len == 0) { break; } + + // Check if adding this chunk would exceed the payload limit + if (total_len > payload_max_length || + payload_max_length - total_len < chunk_len) { + return ReadContentResult::PayloadTooLarge; + } + + total_len += chunk_len; + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return ReadContentResult::Error; + } + + if (!line_reader.getline()) { return ReadContentResult::Error; } + + if (strcmp(line_reader.ptr(), "\r\n") != 0) { + return ReadContentResult::Error; + } + + if (!line_reader.getline()) { return ReadContentResult::Error; } + } + + assert(chunk_len == 0); + + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // does't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. + if (!line_reader.getline()) { return ReadContentResult::Success; } + + // RFC 7230 Section 4.1.2 - Headers prohibited in trailers + thread_local case_ignore::unordered_set prohibited_trailers = { + // Message framing + "transfer-encoding", "content-length", + + // Routing + "host", + + // Authentication + "authorization", "www-authenticate", "proxy-authenticate", + "proxy-authorization", "cookie", "set-cookie", + + // Request modifiers + "cache-control", "expect", "max-forwards", "pragma", "range", "te", + + // Response control + "age", "expires", "date", "location", "retry-after", "vary", "warning", + + // Payload processing + "content-encoding", "content-type", "content-range", "trailer"}; + + // Parse declared trailer headers once for performance + case_ignore::unordered_set declared_trailers; + if (has_header(x.headers, "Trailer")) { + auto trailer_header = get_header_value(x.headers, "Trailer", "", 0); + auto len = std::strlen(trailer_header); + + split(trailer_header, trailer_header + len, ',', + [&](const char *b, const char *e) { + std::string key(b, e); + if (prohibited_trailers.find(key) == prohibited_trailers.end()) { + declared_trailers.insert(key); + } + }); + } + + size_t trailer_header_count = 0; + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return ReadContentResult::Error; + } + + // Check trailer header count limit + if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { + return ReadContentResult::Error; + } + + // Exclude line terminator + constexpr auto line_terminator_len = 2; + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, + [&](const std::string &key, const std::string &val) { + if (declared_trailers.find(key) != declared_trailers.end()) { + x.trailers.emplace(key, val); + trailer_header_count++; + } + }); + + if (!line_reader.getline()) { return ReadContentResult::Error; } + } + + return ReadContentResult::Success; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return case_ignore::equal( + get_header_value(headers, "Transfer-Encoding", "", 0), "chunked"); +} + +template +bool prepare_content_receiver(T &x, int &status, + ContentReceiverWithProgress receiver, + bool decompress, U callback) { + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding == "zstd") { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } + + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiverWithProgress out = [&](const char *buf, size_t n, + size_t off, size_t len) { + return decompressor->decompress(buf, n, + [&](const char *buf2, size_t n2) { + return receiver(buf2, n2, off, len); + }); + }; + return callback(std::move(out)); + } else { + status = StatusCode::InternalServerError_500; + return false; + } + } + } + + ContentReceiverWithProgress out = [&](const char *buf, size_t n, size_t off, + size_t len) { + return receiver(buf, n, off, len); + }; + return callback(std::move(out)); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + DownloadProgress progress, + ContentReceiverWithProgress receiver, bool decompress) { + return prepare_content_receiver( + x, status, std::move(receiver), decompress, + [&](const ContentReceiverWithProgress &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + auto result = read_content_chunked(strm, x, payload_max_length, out); + if (result == ReadContentResult::Success) { + ret = true; + } else if (result == ReadContentResult::PayloadTooLarge) { + exceed_payload_max_length = true; + ret = false; + } else { + ret = false; + } + } else if (!has_header(x.headers, "Content-Length")) { + auto result = + read_content_without_length(strm, payload_max_length, out); + if (result == ReadContentResult::Success) { + ret = true; + } else if (result == ReadContentResult::PayloadTooLarge) { + exceed_payload_max_length = true; + ret = false; + } else { + ret = false; + } + } else { + auto is_invalid_value = false; + auto len = get_header_value_u64(x.headers, "Content-Length", + (std::numeric_limits::max)(), + 0, is_invalid_value); + + if (is_invalid_value) { + ret = false; + } else if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, std::move(progress), out); + } + } + + if (!ret) { + status = exceed_payload_max_length ? StatusCode::PayloadTooLarge_413 + : StatusCode::BadRequest_400; + } + return ret; + }); +} + +inline ssize_t write_request_line(Stream &strm, const std::string &method, + const std::string &path) { + std::string s = method; + s += " "; + s += path; + s += " HTTP/1.1\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_response_line(Stream &strm, int status) { + std::string s = "HTTP/1.1 "; + s += std::to_string(status); + s += " "; + s += httplib::status_message(status); + s += "\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_headers(Stream &strm, const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : headers) { + std::string s; + s = x.first; + s += ": "; + s += x.second; + s += "\r\n"; + + auto len = strm.write(s.data(), s.size()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline bool write_data(Stream &strm, const char *d, size_t l) { + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { return false; } + offset += static_cast(length); + } + return true; +} + +template +inline bool write_content_with_progress(Stream &strm, + const ContentProvider &content_provider, + size_t offset, size_t length, + T is_shutting_down, + const UploadProgress &upload_progress, + Error &error) { + size_t end_offset = offset + length; + size_t start_offset = offset; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + if (write_data(strm, d, l)) { + offset += l; + + if (upload_progress && length > 0) { + size_t current_written = offset - start_offset; + if (!upload_progress(current_written, length)) { + ok = false; + return false; + } + } + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + + while (offset < end_offset && !is_shutting_down()) { + if (!strm.wait_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, end_offset - offset, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, T is_shutting_down, + Error &error) { + return write_content_with_progress(strm, content_provider, offset, length, + is_shutting_down, nullptr, error); +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, + const T &is_shutting_down) { + auto error = Error::Success; + return write_content(strm, content_provider, offset, length, is_shutting_down, + error); +} + +template +inline bool +write_content_without_length(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { ok = false; } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + + data_sink.done = [&](void) { data_available = false; }; + + while (data_available && !is_shutting_down()) { + if (!strm.wait_writable()) { + return false; + } else if (!content_provider(offset, 0, data_sink)) { + return false; + } else if (!ok) { + return false; + } + } + return true; +} + +template +inline bool +write_content_chunked(Stream &strm, const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor, Error &error) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + data_available = l > 0; + offset += l; + + std::string payload; + if (compressor.compress(d, l, false, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = + from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; } + } + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + + auto done_with_trailer = [&](const Headers *trailer) { + if (!ok) { return; } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!write_data(strm, chunk.data(), chunk.size())) { + ok = false; + return; + } + } + + constexpr const char done_marker[] = "0\r\n"; + if (!write_data(strm, done_marker, str_len(done_marker))) { ok = false; } + + // Trailer + if (trailer) { + for (const auto &kv : *trailer) { + std::string field_line = kv.first + ": " + kv.second + "\r\n"; + if (!write_data(strm, field_line.data(), field_line.size())) { + ok = false; + } + } + } + + constexpr const char crlf[] = "\r\n"; + if (!write_data(strm, crlf, str_len(crlf))) { ok = false; } + }; + + data_sink.done = [&](void) { done_with_trailer(nullptr); }; + + data_sink.done_with_trailer = [&](const Headers &trailer) { + done_with_trailer(&trailer); + }; + + while (data_available && !is_shutting_down()) { + if (!strm.wait_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, 0, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content_chunked(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor) { + auto error = Error::Success; + return write_content_chunked(strm, content_provider, is_shutting_down, + compressor, error); +} + +template +inline bool redirect(T &cli, Request &req, Response &res, + const std::string &path, const std::string &location, + Error &error) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count_ -= 1; + + if (res.status == StatusCode::SeeOther_303 && + (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res, error); + if (ret) { + req = new_req; + res = new_res; + + if (res.location.empty()) { res.location = location; } + } + return ret; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += httplib::encode_uri_component(it->second); + } + return query; +} + +inline void parse_query_text(const char *data, std::size_t size, + Params ¶ms) { + std::set cache; + split(data, data + size, '&', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(std::move(kv)); + + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, + std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + params.emplace(decode_path(key, true), decode_path(val, true)); + } + }); +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + parse_query_text(s.data(), s.size(), params); +} + +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto boundary_keyword = "boundary="; + auto pos = content_type.find(boundary_keyword); + if (pos == std::string::npos) { return false; } + auto end = content_type.find(';', pos); + auto beg = pos + strlen(boundary_keyword); + boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg)); + return !boundary.empty(); +} + +inline void parse_disposition_params(const std::string &s, Params ¶ms) { + std::set cache; + split(s.data(), s.data() + s.size(), ';', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(kv); + + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + + if (!key.empty()) { + params.emplace(trim_double_quotes_copy((key)), + trim_double_quotes_copy((val))); + } + }); +} + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +inline bool parse_range_header(const std::string &s, Ranges &ranges) { +#else +inline bool parse_range_header(const std::string &s, Ranges &ranges) try { +#endif + auto is_valid = [](const std::string &str) { + return std::all_of(str.cbegin(), str.cend(), + [](unsigned char c) { return std::isdigit(c); }); + }; + + if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) { + const auto pos = static_cast(6); + const auto len = static_cast(s.size() - 6); + auto all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) { return; } + + const auto it = std::find(b, e, '-'); + if (it == e) { + all_valid_ranges = false; + return; + } + + const auto lhs = std::string(b, it); + const auto rhs = std::string(it + 1, e); + if (!is_valid(lhs) || !is_valid(rhs)) { + all_valid_ranges = false; + return; + } + + const auto first = + static_cast(lhs.empty() ? -1 : std::stoll(lhs)); + const auto last = + static_cast(rhs.empty() ? -1 : std::stoll(rhs)); + if ((first == -1 && last == -1) || + (first != -1 && last != -1 && first > last)) { + all_valid_ranges = false; + return; + } + + ranges.emplace_back(first, last); + }); + return all_valid_ranges && !ranges.empty(); + } + return false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +} +#else +} catch (...) { return false; } +#endif + +inline bool parse_accept_header(const std::string &s, + std::vector &content_types) { + content_types.clear(); + + // Empty string is considered valid (no preference) + if (s.empty()) { return true; } + + // Check for invalid patterns: leading/trailing commas or consecutive commas + if (s.front() == ',' || s.back() == ',' || + s.find(",,") != std::string::npos) { + return false; + } + + struct AcceptEntry { + std::string media_type; + double quality; + int order; // Original order in header + }; + + std::vector entries; + int order = 0; + bool has_invalid_entry = false; + + // Split by comma and parse each entry + split(s.data(), s.data() + s.size(), ',', [&](const char *b, const char *e) { + std::string entry(b, e); + entry = trim_copy(entry); + + if (entry.empty()) { + has_invalid_entry = true; + return; + } + + AcceptEntry accept_entry; + accept_entry.quality = 1.0; // Default quality + accept_entry.order = order++; + + // Find q= parameter + auto q_pos = entry.find(";q="); + if (q_pos == std::string::npos) { q_pos = entry.find("; q="); } + + if (q_pos != std::string::npos) { + // Extract media type (before q parameter) + accept_entry.media_type = trim_copy(entry.substr(0, q_pos)); + + // Extract quality value + auto q_start = entry.find('=', q_pos) + 1; + auto q_end = entry.find(';', q_start); + if (q_end == std::string::npos) { q_end = entry.length(); } + + std::string quality_str = + trim_copy(entry.substr(q_start, q_end - q_start)); + if (quality_str.empty()) { + has_invalid_entry = true; + return; + } + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + { + std::istringstream iss(quality_str); + iss >> accept_entry.quality; + + // Check if conversion was successful and entire string was consumed + if (iss.fail() || !iss.eof()) { + has_invalid_entry = true; + return; + } + } +#else + try { + accept_entry.quality = std::stod(quality_str); + } catch (...) { + has_invalid_entry = true; + return; + } +#endif + // Check if quality is in valid range [0.0, 1.0] + if (accept_entry.quality < 0.0 || accept_entry.quality > 1.0) { + has_invalid_entry = true; + return; + } + } else { + // No quality parameter, use entire entry as media type + accept_entry.media_type = entry; + } + + // Remove additional parameters from media type + auto param_pos = accept_entry.media_type.find(';'); + if (param_pos != std::string::npos) { + accept_entry.media_type = + trim_copy(accept_entry.media_type.substr(0, param_pos)); + } + + // Basic validation of media type format + if (accept_entry.media_type.empty()) { + has_invalid_entry = true; + return; + } + + // Check for basic media type format (should contain '/' or be '*') + if (accept_entry.media_type != "*" && + accept_entry.media_type.find('/') == std::string::npos) { + has_invalid_entry = true; + return; + } + + entries.push_back(accept_entry); + }); + + // Return false if any invalid entry was found + if (has_invalid_entry) { return false; } + + // Sort by quality (descending), then by original order (ascending) + std::sort(entries.begin(), entries.end(), + [](const AcceptEntry &a, const AcceptEntry &b) { + if (a.quality != b.quality) { + return a.quality > b.quality; // Higher quality first + } + return a.order < b.order; // Earlier order first for same quality + }); + + // Extract sorted media types + content_types.reserve(entries.size()); + for (const auto &entry : entries) { + content_types.push_back(entry.media_type); + } + + return true; +} + +class FormDataParser { +public: + FormDataParser() = default; + + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + dash_boundary_crlf_ = dash_ + boundary_ + crlf_; + crlf_dash_boundary_ = crlf_ + dash_ + boundary_; + } + + bool is_valid() const { return is_valid_; } + + bool parse(const char *buf, size_t n, const FormDataHeader &header_callback, + const ContentReceiver &content_callback) { + + buf_append(buf, n); + + while (buf_size() > 0) { + switch (state_) { + case 0: { // Initial boundary + auto pos = buf_find(dash_boundary_crlf_); + if (pos == buf_size()) { return true; } + buf_erase(pos + dash_boundary_crlf_.size()); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_find(crlf_); + if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + while (pos < buf_size()) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_erase(crlf_.size()); + state_ = 3; + break; + } + + const auto header = buf_head(pos); + + if (!parse_header(header.data(), header.data() + header.size(), + [&](const std::string &, const std::string &) {})) { + is_valid_ = false; + return false; + } + + // Parse and emplace space trimmed headers into a map + if (!parse_header( + header.data(), header.data() + header.size(), + [&](const std::string &key, const std::string &val) { + file_.headers.emplace(key, val); + })) { + is_valid_ = false; + return false; + } + + constexpr const char header_content_type[] = "Content-Type:"; + + if (start_with_case_ignore(header, header_content_type)) { + file_.content_type = + trim_copy(header.substr(str_len(header_content_type))); + } else { + thread_local const std::regex re_content_disposition( + R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", + std::regex_constants::icase); + + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + Params params; + parse_disposition_params(m[1], params); + + auto it = params.find("name"); + if (it != params.end()) { + file_.name = it->second; + } else { + is_valid_ = false; + return false; + } + + it = params.find("filename"); + if (it != params.end()) { file_.filename = it->second; } + + it = params.find("filename*"); + if (it != params.end()) { + // Only allow UTF-8 encoding... + thread_local const std::regex re_rfc5987_encoding( + R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); + + std::smatch m2; + if (std::regex_match(it->second, m2, re_rfc5987_encoding)) { + file_.filename = decode_path(m2[1], false); // override... + } else { + is_valid_ = false; + return false; + } + } + } + } + buf_erase(pos + crlf_.size()); + pos = buf_find(crlf_); + } + if (state_ != 3) { return true; } + break; + } + case 3: { // Body + if (crlf_dash_boundary_.size() > buf_size()) { return true; } + auto pos = buf_find(crlf_dash_boundary_); + if (pos < buf_size()) { + if (!content_callback(buf_data(), pos)) { + is_valid_ = false; + return false; + } + buf_erase(pos + crlf_dash_boundary_.size()); + state_ = 4; + } else { + auto len = buf_size() - crlf_dash_boundary_.size(); + if (len > 0) { + if (!content_callback(buf_data(), len)) { + is_valid_ = false; + return false; + } + buf_erase(len); + } + return true; + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_size()) { return true; } + if (buf_start_with(crlf_)) { + buf_erase(crlf_.size()); + state_ = 1; + } else { + if (dash_.size() > buf_size()) { return true; } + if (buf_start_with(dash_)) { + buf_erase(dash_.size()); + is_valid_ = true; + buf_erase(buf_size()); // Remove epilogue + } else { + return true; + } + } + break; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + file_.headers.clear(); + } + + bool start_with_case_ignore(const std::string &a, const char *b) const { + const auto b_len = strlen(b); + if (a.size() < b_len) { return false; } + for (size_t i = 0; i < b_len; i++) { + if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + return false; + } + } + return true; + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + std::string dash_boundary_crlf_; + std::string crlf_dash_boundary_; + + size_t state_ = 0; + bool is_valid_ = false; + FormData file_; + + // Buffer + bool start_with(const std::string &a, size_t spos, size_t epos, + const std::string &b) const { + if (epos - spos < b.size()) { return false; } + for (size_t i = 0; i < b.size(); i++) { + if (a[i + spos] != b[i]) { return false; } + } + return true; + } + + size_t buf_size() const { return buf_epos_ - buf_spos_; } + + const char *buf_data() const { return &buf_[buf_spos_]; } + + std::string buf_head(size_t l) const { return buf_.substr(buf_spos_, l); } + + bool buf_start_with(const std::string &s) const { + return start_with(buf_, buf_spos_, buf_epos_, s); + } + + size_t buf_find(const std::string &s) const { + auto c = s.front(); + + size_t off = buf_spos_; + while (off < buf_epos_) { + auto pos = off; + while (true) { + if (pos == buf_epos_) { return buf_size(); } + if (buf_[pos] == c) { break; } + pos++; + } + + auto remaining_size = buf_epos_ - pos; + if (s.size() > remaining_size) { return buf_size(); } + + if (start_with(buf_, pos, buf_epos_, s)) { return pos - buf_spos_; } + + off = pos + 1; + } + + return buf_size(); + } + + void buf_append(const char *data, size_t n) { + auto remaining_size = buf_size(); + if (remaining_size > 0 && buf_spos_ > 0) { + for (size_t i = 0; i < remaining_size; i++) { + buf_[i] = buf_[buf_spos_ + i]; + } + } + buf_spos_ = 0; + buf_epos_ = remaining_size; + + if (remaining_size + n > buf_.size()) { buf_.resize(remaining_size + n); } + + for (size_t i = 0; i < n; i++) { + buf_[buf_epos_ + i] = data[i]; + } + buf_epos_ += n; + } + + void buf_erase(size_t size) { buf_spos_ += size; } + + std::string buf_; + size_t buf_spos_ = 0; + size_t buf_epos_ = 0; +}; + +inline std::string random_string(size_t length) { + constexpr const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + thread_local auto engine([]() { + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + std::random_device seed_gen; + // Request 128 bits of entropy for initialization + std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + return std::mt19937(seed_sequence); + }()); + + std::string result; + for (size_t i = 0; i < length; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + return result; +} + +inline std::string make_multipart_data_boundary() { + return "--cpp-httplib-multipart-data-" + detail::random_string(16); +} + +inline bool is_multipart_boundary_chars_valid(const std::string &boundary) { + auto valid = true; + for (size_t i = 0; i < boundary.size(); i++) { + auto c = boundary[i]; + if (!std::isalnum(c) && c != '-' && c != '_') { + valid = false; + break; + } + } + return valid; +} + +template +inline std::string +serialize_multipart_formdata_item_begin(const T &item, + const std::string &boundary) { + std::string body = "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + + return body; +} + +inline std::string serialize_multipart_formdata_item_end() { return "\r\n"; } + +inline std::string +serialize_multipart_formdata_finish(const std::string &boundary) { + return "--" + boundary + "--\r\n"; +} + +inline std::string +serialize_multipart_formdata_get_content_type(const std::string &boundary) { + return "multipart/form-data; boundary=" + boundary; +} + +inline std::string +serialize_multipart_formdata(const UploadFormDataItems &items, + const std::string &boundary, bool finish = true) { + std::string body; + + for (const auto &item : items) { + body += serialize_multipart_formdata_item_begin(item, boundary); + body += item.content + serialize_multipart_formdata_item_end(); + } + + if (finish) { body += serialize_multipart_formdata_finish(boundary); } + + return body; +} + +inline void coalesce_ranges(Ranges &ranges, size_t content_length) { + if (ranges.size() <= 1) return; + + // Sort ranges by start position + std::sort(ranges.begin(), ranges.end(), + [](const Range &a, const Range &b) { return a.first < b.first; }); + + Ranges coalesced; + coalesced.reserve(ranges.size()); + + for (auto &r : ranges) { + auto first_pos = r.first; + auto last_pos = r.second; + + // Handle special cases like in range_error + if (first_pos == -1 && last_pos == -1) { + first_pos = 0; + last_pos = static_cast(content_length); + } + + if (first_pos == -1) { + first_pos = static_cast(content_length) - last_pos; + last_pos = static_cast(content_length) - 1; + } + + if (last_pos == -1 || last_pos >= static_cast(content_length)) { + last_pos = static_cast(content_length) - 1; + } + + // Skip invalid ranges + if (!(0 <= first_pos && first_pos <= last_pos && + last_pos < static_cast(content_length))) { + continue; + } + + // Coalesce with previous range if overlapping or adjacent (but not + // identical) + if (!coalesced.empty()) { + auto &prev = coalesced.back(); + // Check if current range overlaps or is adjacent to previous range + // but don't coalesce identical ranges (allow duplicates) + if (first_pos <= prev.second + 1 && + !(first_pos == prev.first && last_pos == prev.second)) { + // Extend the previous range + prev.second = (std::max)(prev.second, last_pos); + continue; + } + } + + // Add new range + coalesced.emplace_back(first_pos, last_pos); + } + + ranges = std::move(coalesced); +} + +inline bool range_error(Request &req, Response &res) { + if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { + ssize_t content_len = static_cast( + res.content_length_ ? res.content_length_ : res.body.size()); + + std::vector> processed_ranges; + size_t overwrapping_count = 0; + + // NOTE: The following Range check is based on '14.2. Range' in RFC 9110 + // 'HTTP Semantics' to avoid potential denial-of-service attacks. + // https://www.rfc-editor.org/rfc/rfc9110#section-14.2 + + // Too many ranges + if (req.ranges.size() > CPPHTTPLIB_RANGE_MAX_COUNT) { return true; } + + for (auto &r : req.ranges) { + auto &first_pos = r.first; + auto &last_pos = r.second; + + if (first_pos == -1 && last_pos == -1) { + first_pos = 0; + last_pos = content_len; + } + + if (first_pos == -1) { + first_pos = content_len - last_pos; + last_pos = content_len - 1; + } + + // NOTE: RFC-9110 '14.1.2. Byte Ranges': + // A client can limit the number of bytes requested without knowing the + // size of the selected representation. If the last-pos value is absent, + // or if the value is greater than or equal to the current length of the + // representation data, the byte range is interpreted as the remainder of + // the representation (i.e., the server replaces the value of last-pos + // with a value that is one less than the current length of the selected + // representation). + // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6 + if (last_pos == -1 || last_pos >= content_len) { + last_pos = content_len - 1; + } + + // Range must be within content length + if (!(0 <= first_pos && first_pos <= last_pos && + last_pos <= content_len - 1)) { + return true; + } + + // Request must not have more than two overlapping ranges + for (const auto &processed_range : processed_ranges) { + if (!(last_pos < processed_range.first || + first_pos > processed_range.second)) { + overwrapping_count++; + if (overwrapping_count > 2) { return true; } + break; // Only count once per range + } + } + + processed_ranges.emplace_back(first_pos, last_pos); + } + + // After validation, coalesce overlapping ranges as per RFC 9110 + coalesce_ranges(req.ranges, static_cast(content_len)); + } + + return false; +} + +inline std::pair +get_range_offset_and_length(Range r, size_t content_length) { + assert(r.first != -1 && r.second != -1); + assert(0 <= r.first && r.first < static_cast(content_length)); + assert(r.first <= r.second && + r.second < static_cast(content_length)); + (void)(content_length); + return std::make_pair(r.first, static_cast(r.second - r.first) + 1); +} + +inline std::string make_content_range_header_field( + const std::pair &offset_and_length, size_t content_length) { + auto st = offset_and_length.first; + auto ed = st + offset_and_length.second - 1; + + std::string field = "bytes "; + field += std::to_string(st); + field += "-"; + field += std::to_string(ed); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length, SToken stoken, + CToken ctoken, Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offset_and_length = + get_range_offset_and_length(req.ranges[i], content_length); + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset_and_length, content_length)); + ctoken("\r\n"); + ctoken("\r\n"); + + if (!content(offset_and_length.first, offset_and_length.second)) { + return false; + } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--"); + + return true; +} + +inline void make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, + std::string &data) { + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data += token; }, + [&](const std::string &token) { data += token; }, + [&](size_t offset, size_t length) { + assert(offset + length <= content_length); + data += res.body.substr(offset, length); + return true; + }); +} + +inline size_t get_multipart_ranges_data_length(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data_length += token.size(); }, + [&](const std::string &token) { data_length += token.size(); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +template +inline bool +write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, const T &is_shutting_down) { + return process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { strm.write(token); }, + [&](const std::string &token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, + is_shutting_down); + }); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "DELETE") { + return true; + } + if (req.has_header("Content-Length") && + req.get_header_value_u64("Content-Length") > 0) { + return true; + } + if (is_chunked_transfer_encoding(req.headers)) { return true; } + return false; +} + +inline bool has_crlf(const std::string &s) { + auto p = s.c_str(); + while (*p) { + if (*p == '\r' || *p == '\n') { return true; } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest(s, EVP_md5()); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, EVP_sha256()); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, EVP_sha512()); +} + +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +inline bool is_ssl_peer_could_be_closed(SSL *ssl, socket_t sock) { + detail::set_nonblocking(sock, true); + auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + char buf[1]; + return !SSL_peek(ssl, buf, 1) && + SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} + +#ifdef _WIN64 +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store +inline bool load_system_certs_on_windows(X509_STORE *store) { + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + if (!hStore) { return false; } + + auto result = false; + PCCERT_CONTEXT pContext = NULL; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + auto encoded_cert = + static_cast(pContext->pbCertEncoded); + + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + + return result; +} +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && \ + defined(TARGET_OS_OSX) +template +using CFObjectPtr = + std::unique_ptr::type, void (*)(CFTypeRef)>; + +inline void cf_object_ptr_deleter(CFTypeRef obj) { + if (obj) { CFRelease(obj); } +} + +inline bool retrieve_certs_from_keychain(CFObjectPtr &certs) { + CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; + CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, + kCFBooleanTrue}; + + CFObjectPtr query( + CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, + sizeof(keys) / sizeof(keys[0]), + &kCFTypeDictionaryKeyCallBacks, + &kCFTypeDictionaryValueCallBacks), + cf_object_ptr_deleter); + + if (!query) { return false; } + + CFTypeRef security_items = nullptr; + if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || + CFArrayGetTypeID() != CFGetTypeID(security_items)) { + return false; + } + + certs.reset(reinterpret_cast(security_items)); + return true; +} + +inline bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { + CFArrayRef root_security_items = nullptr; + if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { + return false; + } + + certs.reset(root_security_items); + return true; +} + +inline bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { + auto result = false; + for (auto i = 0; i < CFArrayGetCount(certs); ++i) { + const auto cert = reinterpret_cast( + CFArrayGetValueAtIndex(certs, i)); + + if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { continue; } + + CFDataRef cert_data = nullptr; + if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != + errSecSuccess) { + continue; + } + + CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); + + auto encoded_cert = static_cast( + CFDataGetBytePtr(cert_data_ptr.get())); + + auto x509 = + d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); + + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + return result; +} + +inline bool load_system_certs_on_macos(X509_STORE *store) { + auto result = false; + CFObjectPtr certs(nullptr, cf_object_ptr_deleter); + if (retrieve_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store); + } + + if (retrieve_root_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store) || result; + } + + return result; +} +#endif // _WIN64 +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef _WIN64 +class WSInit { +public: + WSInit() { + WSADATA wsaData; + if (WSAStartup(0x0002, &wsaData) == 0) is_valid_ = true; + } + + ~WSInit() { + if (is_valid_) WSACleanup(); + } + + bool is_valid_ = false; +}; + +static WSInit wsinit_; +#endif + +inline bool parse_www_authenticate(const Response &res, + std::map &auth, + bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + thread_local auto re = + std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + const auto &m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +class ContentProviderAdapter { +public: + explicit ContentProviderAdapter( + ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) {} + + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); + } + +private: + ContentProviderWithoutLength content_provider_; +}; + +} // namespace detail + +inline std::string hosted_at(const std::string &hostname) { + std::vector addrs; + hosted_at(hostname, addrs); + if (addrs.empty()) { return std::string(); } + return addrs[0]; +} + +inline void hosted_at(const std::string &hostname, + std::vector &addrs) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (detail::getaddrinfo_with_timeout(hostname.c_str(), nullptr, &hints, + &result, 0)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &addr = + *reinterpret_cast(rp->ai_addr); + std::string ip; + auto dummy = -1; + if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, + dummy)) { + addrs.push_back(ip); + } + } +} + +inline std::string encode_uri_component(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || + c == ')') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) + << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string encode_uri(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || + c == ')' || c == ';' || c == '/' || c == '?' || c == ':' || c == '@' || + c == '&' || c == '=' || c == '+' || c == '$' || c == ',' || c == '#') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) + << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string decode_uri_component(const std::string &value) { + std::string result; + + for (size_t i = 0; i < value.size(); i++) { + if (value[i] == '%' && i + 2 < value.size()) { + auto val = 0; + if (detail::from_hex_to_i(value, i + 1, 2, val)) { + result += static_cast(val); + i += 2; + } else { + result += value[i]; + } + } else { + result += value[i]; + } + } + + return result; +} + +inline std::string decode_uri(const std::string &value) { + std::string result; + + for (size_t i = 0; i < value.size(); i++) { + if (value[i] == '%' && i + 2 < value.size()) { + auto val = 0; + if (detail::from_hex_to_i(value, i + 1, 2, val)) { + result += static_cast(val); + i += 2; + } else { + result += value[i]; + } + } else { + result += value[i]; + } + } + + return result; +} + +[[deprecated("Use encode_uri_component instead")]] +inline std::string encode_query_param(const std::string &value) { + return encode_uri_component(value); +} + +inline std::string append_query_params(const std::string &path, + const Params ¶ms) { + std::string path_with_query = path; + thread_local const std::regex re("[^?]+\\?.*"); + auto delm = std::regex_match(path, re) ? '&' : '?'; + path_with_query += delm + detail::params_to_query_str(params); + return path_with_query; +} + +// Header utilities +inline std::pair +make_range_header(const Ranges &ranges) { + std::string field = "bytes="; + auto i = 0; + for (const auto &r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", std::move(field)); +} + +inline std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, bool is_proxy) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +inline std::pair +make_bearer_token_authentication_header(const std::string &token, + bool is_proxy = false) { + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +// Request implementation +inline bool Request::has_header(const std::string &key) const { + return detail::has_header(headers, key); +} + +inline std::string Request::get_header_value(const std::string &key, + const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Request::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const std::string &key, + const std::string &val) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { + headers.emplace(key, val); + } +} + +inline bool Request::has_trailer(const std::string &key) const { + return trailers.find(key) != trailers.end(); +} + +inline std::string Request::get_trailer_value(const std::string &key, + size_t id) const { + auto rng = trailers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return std::string(); +} + +inline size_t Request::get_trailer_value_count(const std::string &key) const { + auto r = trailers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::has_param(const std::string &key) const { + return params.find(key) != params.end(); +} + +inline std::string Request::get_param_value(const std::string &key, + size_t id) const { + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const std::string &key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.rfind("multipart/form-data", 0); +} + +// Multipart FormData implementation +inline std::string MultipartFormData::get_field(const std::string &key, + size_t id) const { + auto rng = fields.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second.content; } + return std::string(); +} + +inline std::vector +MultipartFormData::get_fields(const std::string &key) const { + std::vector values; + auto rng = fields.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second.content); + } + return values; +} + +inline bool MultipartFormData::has_field(const std::string &key) const { + return fields.find(key) != fields.end(); +} + +inline size_t MultipartFormData::get_field_count(const std::string &key) const { + auto r = fields.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline FormData MultipartFormData::get_file(const std::string &key, + size_t id) const { + auto rng = files.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return FormData(); +} + +inline std::vector +MultipartFormData::get_files(const std::string &key) const { + std::vector values; + auto rng = files.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second); + } + return values; +} + +inline bool MultipartFormData::has_file(const std::string &key) const { + return files.find(key) != files.end(); +} + +inline size_t MultipartFormData::get_file_count(const std::string &key) const { + auto r = files.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Response implementation +inline bool Response::has_header(const std::string &key) const { + return headers.find(key) != headers.end(); +} + +inline std::string Response::get_header_value(const std::string &key, + const char *def, + size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Response::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const std::string &key, + const std::string &val) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { + headers.emplace(key, val); + } +} +inline bool Response::has_trailer(const std::string &key) const { + return trailers.find(key) != trailers.end(); +} + +inline std::string Response::get_trailer_value(const std::string &key, + size_t id) const { + auto rng = trailers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return std::string(); +} + +inline size_t Response::get_trailer_value_count(const std::string &key) const { + auto r = trailers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_redirect(const std::string &url, int stat) { + if (detail::fields::is_field_value(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = StatusCode::Found_302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, + const std::string &content_type) { + body.assign(s, n); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, + const std::string &content_type) { + set_content(s.data(), s.size(), content_type); +} + +inline void Response::set_content(std::string &&s, + const std::string &content_type) { + body = std::move(s); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t in_length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = in_length; + if (in_length > 0) { content_provider_ = std::move(provider); } + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = true; +} + +inline void Response::set_file_content(const std::string &path, + const std::string &content_type) { + file_content_path_ = path; + file_content_content_type_ = content_type; +} + +inline void Response::set_file_content(const std::string &path) { + file_content_path_ = path; +} + +// Result implementation +inline bool Result::has_request_header(const std::string &key) const { + return request_headers_.find(key) != request_headers_.end(); +} + +inline std::string Result::get_request_header_value(const std::string &key, + const char *def, + size_t id) const { + return detail::get_header_value(request_headers_, key, def, id); +} + +inline size_t +Result::get_request_header_value_count(const std::string &key) const { + auto r = request_headers_.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Stream implementation +inline ssize_t Stream::write(const char *ptr) { + return write(ptr, strlen(ptr)); +} + +inline ssize_t Stream::write(const std::string &s) { + return write(s.data(), s.size()); +} + +namespace detail { + +inline void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, + time_t timeout_sec, time_t timeout_usec, + time_t &actual_timeout_sec, + time_t &actual_timeout_usec) { + auto timeout_msec = (timeout_sec * 1000) + (timeout_usec / 1000); + + auto actual_timeout_msec = + (std::min)(max_timeout_msec - duration_msec, timeout_msec); + + if (actual_timeout_msec < 0) { actual_timeout_msec = 0; } + + actual_timeout_sec = actual_timeout_msec / 1000; + actual_timeout_usec = (actual_timeout_msec % 1000) * 1000; +} + +// Socket stream implementation +inline SocketStream::SocketStream( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time_(start_time), + read_buff_(read_buff_size_, 0) {} + +inline SocketStream::~SocketStream() = default; + +inline bool SocketStream::is_readable() const { + return read_buff_off_ < read_buff_content_size_; +} + +inline bool SocketStream::wait_readable() const { + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; +} + +inline bool SocketStream::wait_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_); +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { +#ifdef _WIN64 + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#else + size = (std::min)(size, + static_cast((std::numeric_limits::max)())); +#endif + + if (read_buff_off_ < read_buff_content_size_) { + auto remaining_size = read_buff_content_size_ - read_buff_off_; + if (size <= remaining_size) { + memcpy(ptr, read_buff_.data() + read_buff_off_, size); + read_buff_off_ += size; + return static_cast(size); + } else { + memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size); + read_buff_off_ += remaining_size; + return static_cast(remaining_size); + } + } + + if (!wait_readable()) { return -1; } + + read_buff_off_ = 0; + read_buff_content_size_ = 0; + + if (size < read_buff_size_) { + auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, + CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + return n; + } else if (n <= static_cast(size)) { + memcpy(ptr, read_buff_.data(), static_cast(n)); + return n; + } else { + memcpy(ptr, read_buff_.data(), size); + read_buff_off_ = size; + read_buff_content_size_ = static_cast(n); + return static_cast(size); + } + } else { + return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + } +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!wait_writable()) { return -1; } + +#if defined(_WIN64) && !defined(_WIN64) + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + return detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SocketStream::socket() const { return sock_; } + +inline time_t SocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time_) + .count(); +} + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::wait_readable() const { return true; } + +inline bool BufferStream::wait_writable() const { return true; } + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1910 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline socket_t BufferStream::socket() const { return 0; } + +inline time_t BufferStream::duration() const { return 0; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) + : MatcherBase(pattern) { + constexpr const char marker[] = "/:"; + + // One past the last ending position of a path param substring + std::size_t last_param_end = 0; + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + // Needed to ensure that parameter names are unique during matcher + // construction + // If exceptions are disabled, only last duplicate path + // parameter will be set + std::unordered_set param_name_set; +#endif + + while (true) { + const auto marker_pos = pattern.find( + marker, last_param_end == 0 ? last_param_end : last_param_end - 1); + if (marker_pos == std::string::npos) { break; } + + static_fragments_.push_back( + pattern.substr(last_param_end, marker_pos - last_param_end + 1)); + + const auto param_name_start = marker_pos + str_len(marker); + + auto sep_pos = pattern.find(separator, param_name_start); + if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } + + auto param_name = + pattern.substr(param_name_start, sep_pos - param_name_start); + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (param_name_set.find(param_name) != param_name_set.cend()) { + std::string msg = "Encountered path parameter '" + param_name + + "' multiple times in route pattern '" + pattern + "'."; + throw std::invalid_argument(msg); + } +#endif + + param_names_.push_back(std::move(param_name)); + + last_param_end = sep_pos + 1; + } + + if (last_param_end < pattern.length()) { + static_fragments_.push_back(pattern.substr(last_param_end)); + } +} + +inline bool PathParamsMatcher::match(Request &request) const { + request.matches = std::smatch(); + request.path_params.clear(); + request.path_params.reserve(param_names_.size()); + + // One past the position at which the path matched the pattern last time + std::size_t starting_pos = 0; + for (size_t i = 0; i < static_fragments_.size(); ++i) { + const auto &fragment = static_fragments_[i]; + + if (starting_pos + fragment.length() > request.path.length()) { + return false; + } + + // Avoid unnecessary allocation by using strncmp instead of substr + + // comparison + if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(), + fragment.length()) != 0) { + return false; + } + + starting_pos += fragment.length(); + + // Should only happen when we have a static fragment after a param + // Example: '/users/:id/subscriptions' + // The 'subscriptions' fragment here does not have a corresponding param + if (i >= param_names_.size()) { continue; } + + auto sep_pos = request.path.find(separator, starting_pos); + if (sep_pos == std::string::npos) { sep_pos = request.path.length(); } + + const auto ¶m_name = param_names_[i]; + + request.path_params.emplace( + param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); + + // Mark everything up to '/' as matched + starting_pos = sep_pos + 1; + } + // Returns false if the path is longer than the pattern + return starting_pos >= request.path.length(); +} + +inline bool RegexMatcher::match(Request &request) const { + request.path_params.clear(); + return std::regex_match(request.path, request.matches, regex_); +} + +} // namespace detail + +// HTTP server implementation +inline Server::Server() + : new_task_queue( + [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { +#ifndef _WIN64 + signal(SIGPIPE, SIG_IGN); +#endif +} + +inline Server::~Server() = default; + +inline std::unique_ptr +Server::make_matcher(const std::string &pattern) { + if (pattern.find("/:") != std::string::npos) { + return detail::make_unique(pattern); + } else { + return detail::make_unique(pattern); + } +} + +inline Server &Server::Get(const std::string &pattern, Handler handler) { + get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, Handler handler) { + post_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, Handler handler) { + put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, Handler handler) { + patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, Handler handler) { + delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, + HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Options(const std::string &pattern, Handler handler) { + options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline bool Server::set_base_dir(const std::string &dir, + const std::string &mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const std::string &mount_point, + const std::string &dir, Headers headers) { + detail::FileStat stat(dir); + if (stat.is_dir()) { + std::string mnt = !mount_point.empty() ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.push_back({mnt, dir, std::move(headers)}); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const std::string &mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->mount_point == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline Server & +Server::set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime) { + file_extension_and_mimetype_map_[ext] = mime; + return *this; +} + +inline Server &Server::set_default_file_mimetype(const std::string &mime) { + default_file_mimetype_ = mime; + return *this; +} + +inline Server &Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(HandlerWithResponse handler, + std::true_type) { + error_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(Handler handler, + std::false_type) { + error_handler_ = [handler](const Request &req, Response &res) { + handler(req, res); + return HandlerResponse::Handled; + }; + return *this; +} + +inline Server &Server::set_exception_handler(ExceptionHandler handler) { + exception_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) { + pre_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_post_routing_handler(Handler handler) { + post_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_request_handler(HandlerWithResponse handler) { + pre_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_logger(Logger logger) { + logger_ = std::move(logger); + return *this; +} + +inline Server &Server::set_pre_compression_logger(Logger logger) { + pre_compression_logger_ = std::move(logger); + return *this; +} + +inline Server & +Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_address_family(int family) { + address_family_ = family; + return *this; +} + +inline Server &Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; + return *this; +} + +inline Server &Server::set_ipv6_v6only(bool on) { + ipv6_v6only_ = on; + return *this; +} + +inline Server &Server::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); + return *this; +} + +inline Server &Server::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); + return *this; +} + +inline Server &Server::set_header_writer( + std::function const &writer) { + header_writer_ = writer; + return *this; +} + +inline Server &Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + return *this; +} + +inline Server &Server::set_keep_alive_timeout(time_t sec) { + keep_alive_timeout_sec_ = sec; + return *this; +} + +inline Server &Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_idle_interval(time_t sec, time_t usec) { + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; + return *this; +} + +inline Server &Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + return *this; +} + +inline bool Server::bind_to_port(const std::string &host, int port, + int socket_flags) { + auto ret = bind_internal(host, port, socket_flags); + if (ret == -1) { is_decommissioned = true; } + return ret >= 0; +} +inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { + auto ret = bind_internal(host, 0, socket_flags); + if (ret == -1) { is_decommissioned = true; } + return ret; +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const std::string &host, int port, + int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::wait_until_ready() const { + while (!is_running_ && !is_decommissioned) { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } +} + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + is_decommissioned = false; +} + +inline void Server::decommission() { is_decommissioned = true; } + +inline bool Server::parse_request_line(const char *s, Request &req) const { + auto len = strlen(s); + if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; } + len -= 2; + + { + size_t count = 0; + + detail::split(s, s + len, ' ', [&](const char *b, const char *e) { + switch (count) { + case 0: req.method = std::string(b, e); break; + case 1: req.target = std::string(b, e); break; + case 2: req.version = std::string(b, e); break; + default: break; + } + count++; + }); + + if (count != 3) { return false; } + } + + thread_local const std::set methods{ + "GET", "HEAD", "POST", "PUT", "DELETE", + "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; + + if (methods.find(req.method) == methods.end()) { return false; } + + if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { return false; } + + { + // Skip URL fragment + for (size_t i = 0; i < req.target.size(); i++) { + if (req.target[i] == '#') { + req.target.erase(i); + break; + } + } + + detail::divide(req.target, '?', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + req.path = detail::decode_path( + std::string(lhs_data, lhs_size), false); + detail::parse_query_text(rhs_data, rhs_size, req.params); + }); + } + + return true; +} + +inline bool Server::write_response(Stream &strm, bool close_connection, + Request &req, Response &res) { + // NOTE: `req.ranges` should be empty, otherwise it will be applied + // incorrectly to the error content. + req.ranges.clear(); + return write_response_core(strm, close_connection, req, res, false); +} + +inline bool Server::write_response_with_content(Stream &strm, + bool close_connection, + const Request &req, + Response &res) { + return write_response_core(strm, close_connection, req, res, true); +} + +inline bool Server::write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_ && + error_handler_(req, res) == HandlerResponse::Handled) { + need_apply_ranges = true; + } + + std::string content_type; + std::string boundary; + if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } + + // Prepare additional headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::string s = "timeout="; + s += std::to_string(keep_alive_timeout_sec_); + s += ", max="; + s += std::to_string(keep_alive_max_count_); + res.set_header("Keep-Alive", s); + } + + if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) && + !res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } + + if (res.body.empty() && !res.content_length_ && !res.content_provider_ && + !res.has_header("Content-Length")) { + res.set_header("Content-Length", "0"); + } + + if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } + + if (post_routing_handler_) { post_routing_handler_(req, res); } + + // Response line and headers + { + detail::BufferStream bstrm; + if (!detail::write_response_line(bstrm, res.status)) { return false; } + if (!header_writer_(bstrm, res.headers)) { return false; } + + // Flush buffer + auto &data = bstrm.get_buffer(); + detail::write_data(strm, data.data(), data.size()); + } + + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!detail::write_data(strm, res.body.data(), res.body.size())) { + ret = false; + } + } else if (res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return ret; +} + +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + return detail::write_content(strm, res.content_provider_, 0, + res.content_length_, is_shutting_down); + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + return detail::write_content(strm, res.content_provider_, + offset_and_length.first, + offset_and_length.second, is_shutting_down); + } else { + return detail::write_multipart_ranges_data( + strm, req, res, boundary, content_type, res.content_length_, + is_shutting_down); + } + } else { + if (res.is_chunked_content_provider_) { + auto type = detail::encoding_type(req, res); + + std::unique_ptr compressor; + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); +#endif + } else { + compressor = detail::make_unique(); + } + assert(compressor != nullptr); + + return detail::write_content_chunked(strm, res.content_provider_, + is_shutting_down, *compressor); + } else { + return detail::write_content_without_length(strm, res.content_provider_, + is_shutting_down); + } + } +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + FormFields::iterator cur_field; + FormFiles::iterator cur_file; + auto is_text_field = false; + size_t count = 0; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart FormData + [&](const FormData &file) { + if (count++ == CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT) { + return false; + } + + if (file.filename.empty()) { + cur_field = req.form.fields.emplace( + file.name, FormField{file.name, file.content, file.headers}); + is_text_field = true; + } else { + cur_file = req.form.files.emplace(file.name, file); + is_text_field = false; + } + return true; + }, + [&](const char *buf, size_t n) { + if (is_text_field) { + auto &content = cur_field->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + } else { + auto &content = cur_file->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + } + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { + res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414? + return false; + } + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver( + Stream &strm, Request &req, Response &res, ContentReceiver receiver, + FormDataHeader multipart_header, ContentReceiver multipart_receiver) { + return read_content_core(strm, req, res, std::move(receiver), + std::move(multipart_header), + std::move(multipart_receiver)); +} + +inline bool Server::read_content_core( + Stream &strm, Request &req, Response &res, ContentReceiver receiver, + FormDataHeader multipart_header, ContentReceiver multipart_receiver) const { + detail::FormDataParser multipart_form_data_parser; + ContentReceiverWithProgress out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = StatusCode::BadRequest_400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n, size_t /*off*/, size_t /*len*/) { + return multipart_form_data_parser.parse(buf, n, multipart_header, + multipart_receiver); + }; + } else { + out = [receiver](const char *buf, size_t n, size_t /*off*/, + size_t /*len*/) { return receiver(buf, n); }; + } + + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, + out, true)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = StatusCode::BadRequest_400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(const Request &req, Response &res) { + for (const auto &entry : base_dirs_) { + // Prefix match + if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { + std::string sub_path = "/" + req.path.substr(entry.mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = entry.base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } + + detail::FileStat stat(path); + + if (stat.is_dir()) { + res.set_redirect(sub_path + "/", StatusCode::MovedPermanently_301); + return true; + } + + if (stat.is_file()) { + for (const auto &kv : entry.headers) { + res.set_header(kv.first, kv.second); + } + + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { return false; } + + res.set_content_provider( + mm->size(), + detail::find_content_type(path, file_extension_and_mimetype_map_, + default_file_mimetype_), + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + + if (req.method != "HEAD" && file_request_handler_) { + file_request_handler_(req, res); + } + + return true; + } + } + } + } + return false; +} + +inline socket_t +Server::create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const { + return detail::create_socket( + host, std::string(), port, address_family_, socket_flags, tcp_nodelay_, + ipv6_v6only_, std::move(socket_options), + [](socket_t sock, struct addrinfo &ai, bool & /*quit*/) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG)) { return false; } + return true; + }); +} + +inline int Server::bind_internal(const std::string &host, int port, + int socket_flags) { + if (is_decommissioned) { return -1; } + + if (!is_valid()) { return -1; } + + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { return -1; } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + if (is_decommissioned) { return false; } + + auto ret = true; + is_running_ = true; + auto se = detail::scope_exit([&]() { is_running_ = false; }); + + { + std::unique_ptr task_queue(new_task_queue()); + + while (svr_sock_ != INVALID_SOCKET) { +#ifndef _WIN64 + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { +#endif + auto val = detail::select_read(svr_sock_, idle_interval_sec_, + idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } +#ifndef _WIN64 + } +#endif + +#if defined _WIN64 + // sockets connected via WASAccept inherit flags NO_HANDLE_INHERIT, + // OVERLAPPED + socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); +#elif defined SOCK_CLOEXEC + socket_t sock = accept4(svr_sock_, nullptr, nullptr, SOCK_CLOEXEC); +#else + socket_t sock = accept(svr_sock_, nullptr, nullptr); +#endif + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } else if (errno == EINTR || errno == EAGAIN) { + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + + detail::set_socket_opt_time(sock, SOL_SOCKET, SO_RCVTIMEO, + read_timeout_sec_, read_timeout_usec_); + detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO, + write_timeout_sec_, write_timeout_usec_); + + if (!task_queue->enqueue( + [this, sock]() { process_and_close_socket(sock); })) { + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + } + + task_queue->shutdown(); + } + + is_decommissioned = !ret; + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + if (pre_routing_handler_ && + pre_routing_handler_(req, res) == HandlerResponse::Handled) { + return true; + } + + // File handler + if ((req.method == "GET" || req.method == "HEAD") && + handle_file_request(req, res)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, req, res, std::move(receiver), nullptr, nullptr); + }, + [&](FormDataHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + std::move(header), + std::move(receiver)); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = StatusCode::BadRequest_400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, + const Handlers &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + req.matched_route = matcher->pattern(); + if (!pre_request_handler_ || + pre_request_handler_(req, res) != HandlerResponse::Handled) { + handler(req, res); + } + return true; + } + } + return false; +} + +inline void Server::apply_ranges(const Request &req, Response &res, + std::string &content_type, + std::string &boundary) const { + if (req.ranges.size() > 1 && res.status == StatusCode::PartialContent_206) { + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + boundary = detail::make_multipart_data_boundary(); + + res.set_header("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length( + req, boundary, content_type, res.content_length_); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider_) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } else if (type == detail::EncodingType::Zstd) { + res.set_header("Content-Encoding", "zstd"); + } + } + } + } + } else { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + ; + } else if (req.ranges.size() == 1) { + auto offset_and_length = + detail::get_range_offset_and_length(req.ranges[0], res.body.size()); + auto offset = offset_and_length.first; + auto length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.body.size()); + res.set_header("Content-Range", content_range); + + assert(offset + length <= res.body.size()); + res.body = res.body.substr(offset, length); + } else { + std::string data; + detail::make_multipart_ranges_data(req, res, boundary, content_type, + res.body.size(), data); + res.body.swap(data); + } + + if (type != detail::EncodingType::None) { + if (pre_compression_logger_) { pre_compression_logger_(req, res); } + + std::unique_ptr compressor; + std::string content_encoding; + + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); + content_encoding = "gzip"; +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); + content_encoding = "br"; +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); + content_encoding = "zstd"; +#endif + } + + if (compressor) { + std::string compressed; + if (compressor->compress(res.body.data(), res.body.size(), true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + res.body.swap(compressed); + res.set_header("Content-Encoding", content_encoding); + } + } + } + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } +} + +inline bool Server::dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + req.matched_route = matcher->pattern(); + if (!pre_request_handler_ || + pre_request_handler_(req, res) != HandlerResponse::Handled) { + handler(req, res, content_reader); + } + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, + bool &connection_closed, + const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + + Response res; + res.version = "HTTP/1.1"; + res.headers = default_headers_; + +#ifdef __APPLE__ + // Socket file descriptor exceeded FD_SETSIZE... + if (strm.socket() >= FD_SETSIZE) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::InternalServerError_500; + return write_response(strm, close_connection, req, res); + } +#endif + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + + // Check if the request URI doesn't exceed the limit + if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::UriTooLong_414; + return write_response(strm, close_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } + + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + req.local_addr = local_addr; + req.local_port = local_port; + req.set_header("LOCAL_ADDR", req.local_addr); + req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + + if (req.has_header("Accept")) { + const auto &accept_header = req.get_header_value("Accept"); + if (!detail::parse_accept_header(accept_header, req.accept_content_types)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + } + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + } + + if (setup_request) { setup_request(req); } + + if (req.get_header_value("Expect") == "100-continue") { + int status = StatusCode::Continue_100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case StatusCode::Continue_100: + case StatusCode::ExpectationFailed_417: + detail::write_response_line(strm, status); + strm.write("\r\n"); + break; + default: + connection_closed = true; + return write_response(strm, true, req, res); + } + } + + // Setup `is_connection_closed` method + auto sock = strm.socket(); + req.is_connection_closed = [sock]() { + return !detail::is_socket_alive(sock); + }; + + // Routing + auto routed = false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + routed = routing(req, res, strm); +#else + try { + routed = routing(req, res, strm); + } catch (std::exception &e) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + std::string val; + auto s = e.what(); + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case '\r': val += "\\r"; break; + case '\n': val += "\\n"; break; + default: val += s[i]; break; + } + } + res.set_header("EXCEPTION_WHAT", val); + } + } catch (...) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + } +#endif + if (routed) { + if (res.status == -1) { + res.status = req.ranges.empty() ? StatusCode::OK_200 + : StatusCode::PartialContent_206; + } + + // Serve file content by using a content provider + if (!res.file_content_path_.empty()) { + const auto &path = res.file_content_path_; + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::NotFound_404; + return write_response(strm, close_connection, req, res); + } + + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type( + path, file_extension_and_mimetype_map_, default_file_mimetype_); + } + + res.set_content_provider( + mm->size(), content_type, + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + } + + if (detail::range_error(req, res)) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + + return write_response_with_content(strm, close_connection, req, res); + } else { + if (res.status == -1) { res.status = StatusCode::NotFound_404; } + + return write_response(strm, close_connection, req, res); + } +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + auto ret = detail::process_server_socket( + svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, connection_closed, + nullptr); + }); + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// HTTP client implementation +inline ClientImpl::ClientImpl(const std::string &host) + : ClientImpl(host, 80, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port) + : ClientImpl(host, port, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), + host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + +inline ClientImpl::~ClientImpl() { + // Wait until all the requests in flight are handled. + size_t retry_count = 10; + while (retry_count-- > 0) { + { + std::lock_guard guard(socket_mutex_); + if (socket_requests_in_flight_ == 0) { break; } + } + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } + + std::lock_guard guard(socket_mutex_); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline bool ClientImpl::is_valid() const { return true; } + +inline void ClientImpl::copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + max_timeout_msec_ = rhs.max_timeout_msec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + path_encode_ = rhs.path_encode_; + address_family_ = rhs.address_family_; + tcp_nodelay_ = rhs.tcp_nodelay_; + ipv6_v6only_ = rhs.ipv6_v6only_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + ca_cert_store_ = rhs.ca_cert_store_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; + server_certificate_verifier_ = rhs.server_certificate_verifier_; +#endif + logger_ = rhs.logger_; +} + +inline socket_t ClientImpl::create_client_socket(Error &error) const { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket( + proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, + ipv6_v6only_, socket_options_, connection_timeout_sec_, + connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, interface_, error); + } + + // Check is custom IP specified for host_ + std::string ip; + auto it = addr_map_.find(host_); + if (it != addr_map_.end()) { ip = it->second; } + + return detail::create_client_socket( + host_, ip, port_, address_family_, tcp_nodelay_, ipv6_v6only_, + socket_options_, connection_timeout_sec_, connection_timeout_usec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, interface_, error); +} + +inline bool ClientImpl::create_and_connect_socket(Socket &socket, + Error &error) { + auto sock = create_client_socket(error); + if (sock == INVALID_SOCKET) { return false; } + socket.sock = sock; + return true; +} + +inline void ClientImpl::shutdown_ssl(Socket & /*socket*/, + bool /*shutdown_gracefully*/) { + // If there are any requests in flight from threads other than us, then it's + // a thread-unsafe race because individual ssl* objects are not thread-safe. + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); +} + +inline void ClientImpl::shutdown_socket(Socket &socket) const { + if (socket.sock == INVALID_SOCKET) { return; } + detail::shutdown_socket(socket.sock); +} + +inline void ClientImpl::close_socket(Socket &socket) { + // If there are requests in flight in another thread, usually closing + // the socket will be fine and they will simply receive an error when + // using the closed socket, but it is still a bug since rarely the OS + // may reassign the socket id to be used for a new socket, and then + // suddenly they will be operating on a live socket that is different + // than the one they intended! + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); + + // It is also a bug if this happens while SSL is still active +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + assert(socket.ssl == nullptr); +#endif + if (socket.sock == INVALID_SOCKET) { return; } + detail::close_socket(socket.sock); + socket.sock = INVALID_SOCKET; +} + +inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, + Response &res) const { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); +#else + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); +#endif + + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return req.method == "CONNECT"; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + + // Ignore '100 Continue' + while (res.status == StatusCode::Continue_100) { + if (!line_reader.getline()) { return false; } // CRLF + if (!line_reader.getline()) { return false; } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; +} + +inline bool ClientImpl::send(Request &req, Response &res, Error &error) { + std::lock_guard request_mutex_guard(request_mutex_); + auto ret = send_(req, res, error); + if (error == Error::SSLPeerCouldBeClosed_) { + assert(!ret); + ret = send_(req, res, error); + } + return ret; +} + +inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { + { + std::lock_guard guard(socket_mutex_); + + // Set this to false immediately - if it ever gets set to true by the end of + // the request, we know another thread instructed us to close the socket. + socket_should_be_closed_when_request_is_done_ = false; + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + is_alive = false; + } + } +#endif + + if (!is_alive) { + // Attempt to avoid sigpipe by shutting down non-gracefully if it seems + // like the other side has already closed the connection Also, there + // cannot be any requests in flight from other threads since we locked + // request_mutex_, so safe to close everything immediately + const bool shutdown_gracefully = false; + shutdown_ssl(socket_, shutdown_gracefully); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!create_and_connect_socket(socket_, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + auto success = false; + if (!scli.connect_with_proxy(socket_, req.start_time_, res, success, + error)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_, error)) { return false; } + } +#endif + } + + // Mark the current socket as being in use so that it cannot be closed by + // anyone else while this request is ongoing, even though we will be + // releasing the mutex. + if (socket_requests_in_flight_ > 1) { + assert(socket_requests_are_from_thread_ == std::this_thread::get_id()); + } + socket_requests_in_flight_ += 1; + socket_requests_are_from_thread_ = std::this_thread::get_id(); + } + + for (const auto &header : default_headers_) { + if (req.headers.find(header.first) == req.headers.end()) { + req.headers.insert(header); + } + } + + auto ret = false; + auto close_connection = !keep_alive_; + + auto se = detail::scope_exit([&]() { + // Briefly lock mutex in order to mark that a request is no longer ongoing + std::lock_guard guard(socket_mutex_); + socket_requests_in_flight_ -= 1; + if (socket_requests_in_flight_ <= 0) { + assert(socket_requests_in_flight_ == 0); + socket_requests_are_from_thread_ = std::thread::id(); + } + + if (socket_should_be_closed_when_request_is_done_ || close_connection || + !ret) { + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + }); + + ret = process_socket(socket_, req.start_time_, [&](Stream &strm) { + return handle_request(strm, req, res, close_connection, error); + }); + + if (!ret) { + if (error == Error::Success) { error = Error::Unknown; } + } + + return ret; +} + +inline Result ClientImpl::send(const Request &req) { + auto req2 = req; + return send_(std::move(req2)); +} + +inline Result ClientImpl::send_(Request &&req) { + auto res = detail::make_unique(); + auto error = Error::Success; + auto ret = send(req, *res, error); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers), + last_ssl_error_, last_openssl_error_}; +#else + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; +#endif +} + +inline bool ClientImpl::handle_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + if (req.path.empty()) { + error = Error::Connection; + return false; + } + + auto req_save = req; + + bool ret; + + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection, error); + req = req2; + req.path = req_save.path; + } else { + ret = process_request(strm, req, res, close_connection, error); + } + + if (!ret) { return false; } + + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + // TODO this requires a not-entirely-obvious chain of calls to be correct + // for this to be safe. + + // This is safe to call because handle_request is only called by send_ + // which locks the request mutex during the process. It would be a bug + // to call it from a different thread since it's a thread-safety issue + // to do these things to the socket if another thread is using the socket. + std::lock_guard guard(socket_mutex_); + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + + if (300 < res.status && res.status < 400 && follow_location_) { + req = req_save; + ret = redirect(req, res, error); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if ((res.status == StatusCode::Unauthorized_401 || + res.status == StatusCode::ProxyAuthenticationRequired_407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == StatusCode::ProxyAuthenticationRequired_407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + new_req.headers.erase(is_proxy ? "Proxy-Authorization" + : "Authorization"); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), + username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res, error); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; +} + +inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { + if (req.redirect_count_ == 0) { + error = Error::ExceedRedirectCount; + return false; + } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + thread_local const std::regex re( + R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + if (next_host.empty()) { next_host = m[3].str(); } + auto port_str = m[4].str(); + auto next_path = m[5].str(); + auto next_query = m[6].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + auto path = detail::decode_path(next_path, true) + next_query; + + // Same host redirect - use current client + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, path, location, error); + } + + // Cross-host/scheme redirect - create new client with robust setup + return create_redirect_client(next_scheme, next_host, next_port, req, res, + path, location, error); +} + +// New method for robust redirect client creation +inline bool ClientImpl::create_redirect_client( + const std::string &scheme, const std::string &host, int port, Request &req, + Response &res, const std::string &path, const std::string &location, + Error &error) { + // Determine if we need SSL + auto need_ssl = (scheme == "https"); + + // Clean up request headers that are host/client specific + // Remove headers that should not be carried over to new host + auto headers_to_remove = + std::vector{"Host", "Proxy-Authorization", "Authorization"}; + + for (const auto &header_name : headers_to_remove) { + auto it = req.headers.find(header_name); + while (it != req.headers.end()) { + it = req.headers.erase(it); + it = req.headers.find(header_name); + } + } + + // Create appropriate client type and handle redirect + if (need_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Create SSL client for HTTPS redirect + SSLClient redirect_client(host, port); + + // Setup basic client configuration first + setup_redirect_client(redirect_client); + + // SSL-specific configuration for proxy environments + if (!proxy_host_.empty() && proxy_port_ != -1) { + // Critical: Disable SSL verification for proxy environments + redirect_client.enable_server_certificate_verification(false); + redirect_client.enable_server_hostname_verification(false); + } else { + // For direct SSL connections, copy SSL verification settings + redirect_client.enable_server_certificate_verification( + server_certificate_verification_); + redirect_client.enable_server_hostname_verification( + server_hostname_verification_); + } + + // Handle CA certificate store and paths if available + if (ca_cert_store_) { redirect_client.set_ca_cert_store(ca_cert_store_); } + if (!ca_cert_file_path_.empty()) { + redirect_client.set_ca_cert_path(ca_cert_file_path_, ca_cert_dir_path_); + } + + // Client certificates are set through constructor for SSLClient + // NOTE: SSLClient constructor already takes client_cert_path and + // client_key_path so we need to create it properly if client certs are + // needed + + // Execute the redirect + return detail::redirect(redirect_client, req, res, path, location, error); +#else + // SSL not supported - set appropriate error + error = Error::SSLConnection; + return false; +#endif + } else { + // HTTP redirect + ClientImpl redirect_client(host, port); + + // Setup client with robust configuration + setup_redirect_client(redirect_client); + + // Execute the redirect + return detail::redirect(redirect_client, req, res, path, location, error); + } +} + +// New method for robust client setup (based on basic_manual_redirect.cpp logic) +template +inline void ClientImpl::setup_redirect_client(ClientType &client) { + // Copy basic settings first + client.set_connection_timeout(connection_timeout_sec_); + client.set_read_timeout(read_timeout_sec_, read_timeout_usec_); + client.set_write_timeout(write_timeout_sec_, write_timeout_usec_); + client.set_keep_alive(keep_alive_); + client.set_follow_location( + true); // Enable redirects to handle multi-step redirects + client.set_path_encode(path_encode_); + client.set_compress(compress_); + client.set_decompress(decompress_); + + // Copy authentication settings BEFORE proxy setup + if (!basic_auth_username_.empty()) { + client.set_basic_auth(basic_auth_username_, basic_auth_password_); + } + if (!bearer_token_auth_token_.empty()) { + client.set_bearer_token_auth(bearer_token_auth_token_); + } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!digest_auth_username_.empty()) { + client.set_digest_auth(digest_auth_username_, digest_auth_password_); + } +#endif + + // Setup proxy configuration (CRITICAL ORDER - proxy must be set + // before proxy auth) + if (!proxy_host_.empty() && proxy_port_ != -1) { + // First set proxy host and port + client.set_proxy(proxy_host_, proxy_port_); + + // Then set proxy authentication (order matters!) + if (!proxy_basic_auth_username_.empty()) { + client.set_proxy_basic_auth(proxy_basic_auth_username_, + proxy_basic_auth_password_); + } + if (!proxy_bearer_token_auth_token_.empty()) { + client.set_proxy_bearer_token_auth(proxy_bearer_token_auth_token_); + } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!proxy_digest_auth_username_.empty()) { + client.set_proxy_digest_auth(proxy_digest_auth_username_, + proxy_digest_auth_password_); + } +#endif + } + + // Copy network and socket settings + client.set_address_family(address_family_); + client.set_tcp_nodelay(tcp_nodelay_); + client.set_ipv6_v6only(ipv6_v6only_); + if (socket_options_) { client.set_socket_options(socket_options_); } + if (!interface_.empty()) { client.set_interface(interface_); } + + // Copy logging and headers + if (logger_) { client.set_logger(logger_); } + + // NOTE: DO NOT copy default_headers_ as they may contain stale Host headers + // Each new client should generate its own headers based on its target host +} + +inline bool ClientImpl::write_content_with_provider(Stream &strm, + const Request &req, + Error &error) const { + auto is_shutting_down = []() { return false; }; + + if (req.is_chunked_content_provider_) { + // TODO: Brotli support + std::unique_ptr compressor; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + compressor = detail::make_unique(); + } else +#endif + { + compressor = detail::make_unique(); + } + + return detail::write_content_chunked(strm, req.content_provider_, + is_shutting_down, *compressor, error); + } else { + return detail::write_content_with_progress( + strm, req.content_provider_, 0, req.content_length_, is_shutting_down, + req.upload_progress, error); + } +} + +inline bool ClientImpl::write_request(Stream &strm, Request &req, + bool close_connection, Error &error) { + // Prepare additional headers + if (close_connection) { + if (!req.has_header("Connection")) { + req.set_header("Connection", "close"); + } + } + + if (!req.has_header("Host")) { + // For Unix socket connections, use "localhost" as Host header (similar to + // curl behavior) + if (address_family_ == AF_UNIX) { + req.set_header("Host", "localhost"); + } else if (is_ssl()) { + if (port_ == 443) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } else { + if (port_ == 80) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } + + if (!req.content_receiver) { + if (!req.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; +#endif + req.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + }; + + if (req.body.empty()) { + if (req.content_provider_) { + if (!req.is_chunked_content_provider_) { + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.content_length_); + req.set_header("Content-Length", length); + } + } + } else { + if (req.method == "POST" || req.method == "PUT" || + req.method == "PATCH") { + req.set_header("Content-Length", "0"); + } + } + } else { + if (!req.has_header("Content-Type")) { + req.set_header("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + req.set_header("Content-Length", length); + } + } + + if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + } + + if (!bearer_token_auth_token_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + bearer_token_auth_token_, false)); + } + } + + if (!proxy_bearer_token_auth_token_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + proxy_bearer_token_auth_token_, true)); + } + } + + // Request line and headers + { + detail::BufferStream bstrm; + + const auto &path_with_query = + req.params.empty() ? req.path + : append_query_params(req.path, req.params); + + const auto &path = + path_encode_ ? detail::encode_path(path_with_query) : path_with_query; + + detail::write_request_line(bstrm, req.method, path); + + header_writer_(bstrm, req.headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error = Error::Write; + return false; + } + } + + // Body + if (req.body.empty()) { + return write_content_with_provider(strm, req, error); + } + + if (req.upload_progress) { + auto body_size = req.body.size(); + size_t written = 0; + auto data = req.body.data(); + + while (written < body_size) { + size_t to_write = (std::min)(CPPHTTPLIB_SEND_BUFSIZ, body_size - written); + if (!detail::write_data(strm, data + written, to_write)) { + error = Error::Write; + return false; + } + written += to_write; + + if (!req.upload_progress(written, body_size)) { + error = Error::Canceled; + return false; + } + } + } else { + if (!detail::write_data(strm, req.body.data(), req.body.size())) { + error = Error::Write; + return false; + } + } + + return true; +} + +inline std::unique_ptr ClientImpl::send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error) { + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { req.set_header("Content-Encoding", "gzip"); } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_ && !content_provider_without_length) { + // TODO: Brotli support + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + DataSink data_sink; + + data_sink.write = [&](const char *data, size_t data_len) -> bool { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = compressor.compress( + data, data_len, last, + [&](const char *compressed_data, size_t compressed_data_len) { + req.body.append(compressed_data, compressed_data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + return ok; + }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body, content_length, true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + error = Error::Compression; + return nullptr; + } + } + } else +#endif + { + if (content_provider) { + req.content_length_ = content_length; + req.content_provider_ = std::move(content_provider); + req.is_chunked_content_provider_ = false; + } else if (content_provider_without_length) { + req.content_length_ = 0; + req.content_provider_ = detail::ContentProviderAdapter( + std::move(content_provider_without_length)); + req.is_chunked_content_provider_ = true; + req.set_header("Transfer-Encoding", "chunked"); + } else { + req.body.assign(body, content_length); + } + } + + auto res = detail::make_unique(); + return send(req, *res, error) ? std::move(res) : nullptr; +} + +inline Result ClientImpl::send_with_content_provider( + const std::string &method, const std::string &path, const Headers &headers, + const char *body, size_t content_length, ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, UploadProgress progress) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + req.upload_progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + auto error = Error::Success; + + auto res = send_with_content_provider( + req, body, content_length, std::move(content_provider), + std::move(content_provider_without_length), content_type, error); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + return Result{std::move(res), error, std::move(req.headers), last_ssl_error_, + last_openssl_error_}; +#else + return Result{std::move(res), error, std::move(req.headers)}; +#endif +} + +inline std::string +ClientImpl::adjust_host_string(const std::string &host) const { + if (host.find(':') != std::string::npos) { return "[" + host + "]"; } + return host; +} + +inline bool ClientImpl::process_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + // Send request + if (!write_request(strm, req, close_connection, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; + if (!is_proxy_enabled) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + error = Error::SSLPeerCouldBeClosed_; + return false; + } + } + } +#endif + + // Receive response and headers + if (!read_response_line(strm, req, res) || + !detail::read_headers(strm, res.headers)) { + error = Error::Read; + return false; + } + + // Body + if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && + req.method != "CONNECT") { + auto redirect = 300 < res.status && res.status < 400 && + res.status != StatusCode::NotModified_304 && + follow_location_; + + if (req.response_handler && !redirect) { + if (!req.response_handler(res)) { + error = Error::Canceled; + return false; + } + } + + auto out = + req.content_receiver + ? static_cast( + [&](const char *buf, size_t n, size_t off, size_t len) { + if (redirect) { return true; } + auto ret = req.content_receiver(buf, n, off, len); + if (!ret) { error = Error::Canceled; } + return ret; + }) + : static_cast( + [&](const char *buf, size_t n, size_t /*off*/, + size_t /*len*/) { + assert(res.body.size() + n <= res.body.max_size()); + res.body.append(buf, n); + return true; + }); + + auto progress = [&](size_t current, size_t total) { + if (!req.download_progress || redirect) { return true; } + auto ret = req.download_progress(current, total); + if (!ret) { error = Error::Canceled; } + return ret; + }; + + if (res.has_header("Content-Length")) { + if (!req.content_receiver) { + auto len = res.get_header_value_u64("Content-Length"); + if (len > res.body.max_size()) { + error = Error::Read; + return false; + } + res.body.reserve(static_cast(len)); + } + } + + if (res.status != StatusCode::NotModified_304) { + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, std::move(progress), + std::move(out), decompress_)) { + if (error != Error::Canceled) { error = Error::Read; } + return false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; +} + +inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( + const std::string &boundary, const UploadFormDataItems &items, + const FormDataProviderItems &provider_items) const { + size_t cur_item = 0; + size_t cur_start = 0; + // cur_item and cur_start are copied to within the std::function and maintain + // state between successive calls + return [&, cur_item, cur_start](size_t offset, + DataSink &sink) mutable -> bool { + if (!offset && !items.empty()) { + sink.os << detail::serialize_multipart_formdata(items, boundary, false); + return true; + } else if (cur_item < provider_items.size()) { + if (!cur_start) { + const auto &begin = detail::serialize_multipart_formdata_item_begin( + provider_items[cur_item], boundary); + offset += begin.size(); + cur_start = offset; + sink.os << begin; + } + + DataSink cur_sink; + auto has_data = true; + cur_sink.write = sink.write; + cur_sink.done = [&]() { has_data = false; }; + + if (!provider_items[cur_item].provider(offset - cur_start, cur_sink)) { + return false; + } + + if (!has_data) { + sink.os << detail::serialize_multipart_formdata_item_end(); + cur_item++; + cur_start = 0; + } + return true; + } else { + sink.os << detail::serialize_multipart_formdata_finish(boundary); + sink.done(); + return true; + } + }; +} + +inline bool ClientImpl::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { + return detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, max_timeout_msec_, start_time, std::move(callback)); +} + +inline bool ClientImpl::is_ssl() const { return false; } + +inline Result ClientImpl::Get(const std::string &path, + DownloadProgress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + DownloadProgress progress) { + if (params.empty()) { return Get(path, headers); } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + DownloadProgress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.download_progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, + ContentReceiver content_receiver, + DownloadProgress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, + DownloadProgress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.download_progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + DownloadProgress progress) { + return Get(path, params, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + if (params.empty()) { + return Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Head(const std::string &path) { + return Head(path, Headers()); +} + +inline Result ClientImpl::Head(const std::string &path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Post(const std::string &path) { + return Post(path, std::string(), std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers) { + return Post(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return Post(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return Post(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline Result ClientImpl::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return Post(path, Headers(), content_length, std::move(content_provider), + content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return Post(path, Headers(), std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Post(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return Post(path, Headers(), items, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("POST", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("POST", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("POST", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "POST", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + req.body = body; + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.download_progress = std::move(progress); + + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Put(const std::string &path) { + return Put(path, std::string(), std::string()); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers) { + return Put(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return Put(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return Put(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline Result ClientImpl::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return Put(path, Headers(), content_length, std::move(content_provider), + content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return Put(path, Headers(), std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Put(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return Put(path, Headers(), items, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PUT", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PUT", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PUT", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "PUT", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + Request req; + req.method = "PUT"; + req.path = path; + req.headers = headers; + req.body = body; + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.download_progress = std::move(progress); + + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Patch(const std::string &path) { + return Patch(path, std::string(), std::string()); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + UploadProgress progress) { + return Patch(path, headers, nullptr, 0, std::string(), progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return Patch(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return Patch(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Params ¶ms) { + return Patch(path, Headers(), params); +} + +inline Result ClientImpl::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return Patch(path, Headers(), std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Patch(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Patch(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return Patch(path, Headers(), items, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Patch(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Patch(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PATCH", path, headers, body, + content_length, nullptr, nullptr, + content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PATCH", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PATCH", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "PATCH", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + Request req; + req.method = "PATCH"; + req.path = path; + req.headers = headers; + req.body = body; + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.download_progress = std::move(progress); + + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Delete(const std::string &path, + DownloadProgress progress) { + return Delete(path, Headers(), std::string(), std::string(), progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + DownloadProgress progress) { + return Delete(path, headers, std::string(), std::string(), progress); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + DownloadProgress progress) { + return Delete(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const std::string &body, + const std::string &content_type, + DownloadProgress progress) { + return Delete(path, Headers(), body.data(), body.size(), content_type, + progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type, + DownloadProgress progress) { + return Delete(path, headers, body.data(), body.size(), content_type, + progress); +} + +inline Result ClientImpl::Delete(const std::string &path, const Params ¶ms, + DownloadProgress progress) { + return Delete(path, Headers(), params, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const Params ¶ms, + DownloadProgress progress) { + auto query = detail::params_to_query_str(params); + return Delete(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const char *body, + size_t content_length, + const std::string &content_type, + DownloadProgress progress) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + req.download_progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + req.body.assign(body, content_length); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Options(const std::string &path) { + return Options(path, Headers()); +} + +inline Result ClientImpl::Options(const std::string &path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.headers = headers; + req.path = path; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline void ClientImpl::stop() { + std::lock_guard guard(socket_mutex_); + + // If there is anything ongoing right now, the ONLY thread-safe thing we can + // do is to shutdown_socket, so that threads using this socket suddenly + // discover they can't read/write any more and error out. Everything else + // (closing the socket, shutting ssl down) is unsafe because these actions are + // not thread-safe. + if (socket_requests_in_flight_ > 0) { + shutdown_socket(socket_); + + // Aside from that, we set a flag for the socket to be closed when we're + // done. + socket_should_be_closed_when_request_is_done_ = true; + return; + } + + // Otherwise, still holding the mutex, we can shut everything down ourselves + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline std::string ClientImpl::host() const { return host_; } + +inline int ClientImpl::port() const { return port_; } + +inline size_t ClientImpl::is_socket_open() const { + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); +} + +inline socket_t ClientImpl::socket() const { return socket_.sock; } + +inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; +} + +inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +inline void ClientImpl::set_max_timeout(time_t msec) { + max_timeout_msec_ = msec; +} + +inline void ClientImpl::set_basic_auth(const std::string &username, + const std::string &password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +inline void ClientImpl::set_bearer_token_auth(const std::string &token) { + bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_digest_auth(const std::string &username, + const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } + +inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } + +inline void ClientImpl::set_path_encode(bool on) { path_encode_ = on; } + +inline void +ClientImpl::set_hostname_addr_map(std::map addr_map) { + addr_map_ = std::move(addr_map); +} + +inline void ClientImpl::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); +} + +inline void ClientImpl::set_header_writer( + std::function const &writer) { + header_writer_ = writer; +} + +inline void ClientImpl::set_address_family(int family) { + address_family_ = family; +} + +inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } + +inline void ClientImpl::set_ipv6_v6only(bool on) { ipv6_v6only_ = on; } + +inline void ClientImpl::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + +inline void ClientImpl::set_compress(bool on) { compress_ = on; } + +inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } + +inline void ClientImpl::set_interface(const std::string &intf) { + interface_ = intf; +} + +inline void ClientImpl::set_proxy(const std::string &host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void ClientImpl::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +inline void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { + proxy_bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} + +inline void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + ca_cert_file_path_ = ca_cert_file_path; + ca_cert_dir_path_ = ca_cert_dir_path; +} + +inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store && ca_cert_store != ca_cert_store_) { + ca_cert_store_ = ca_cert_store; + } +} + +inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, + std::size_t size) const { + auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + auto se = detail::scope_exit([&] { BIO_free_all(mem); }); + if (!mem) { return nullptr; } + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { return nullptr; } + + auto cts = X509_STORE_new(); + if (cts) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { continue; } + + if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); } + if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + return cts; +} + +inline void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline void ClientImpl::enable_server_hostname_verification(bool enabled) { + server_hostname_verification_ = enabled; +} + +inline void ClientImpl::set_server_certificate_verifier( + std::function verifier) { + server_certificate_verifier_ = verifier; +} +#endif + +inline void ClientImpl::set_logger(Logger logger) { + logger_ = std::move(logger); +} + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, + U SSL_connect_or_accept, V setup) { + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (ssl) { + set_nonblocking(sock, true); + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + BIO_set_nbio(bio, 1); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + set_nonblocking(sock, false); + return nullptr; + } + BIO_set_nbio(bio, 0); + set_nonblocking(sock, false); + } + + return ssl; +} + +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, + bool shutdown_gracefully) { + // sometimes we may want to skip this to try to avoid SIGPIPE if we know + // the remote has closed the network connection + // Note that it is not always possible to avoid SIGPIPE, this is merely a + // best-efforts. + if (shutdown_gracefully) { + (void)(sock); + // SSL_shutdown() returns 0 on first call (indicating close_notify alert + // sent) and 1 on subsequent call (indicating close_notify alert received) + if (SSL_shutdown(ssl) == 0) { + // Expected to return 1, but even if it doesn't, we free ssl + SSL_shutdown(ssl); + } + } + + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); +} + +template +bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, + U ssl_connect_or_accept, + time_t timeout_sec, time_t timeout_usec, + int *ssl_error) { + auto res = 0; + while ((res = ssl_connect_or_accept(ssl)) != 1) { + auto err = SSL_get_error(ssl, res); + switch (err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + default: break; + } + if (ssl_error) { *ssl_error = err; } + return false; + } + return true; +} + +template +inline bool process_server_socket_ssl( + const std::atomic &svr_sock, SSL *ssl, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +inline bool process_client_socket_ssl( + SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, T callback) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); + return callback(strm); +} + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time_(start_time) { + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); +} + +inline SSLSocketStream::~SSLSocketStream() = default; + +inline bool SSLSocketStream::is_readable() const { + return SSL_pending(ssl_) > 0; +} + +inline bool SSLSocketStream::wait_readable() const { + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; +} + +inline bool SSLSocketStream::wait_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (wait_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN64 + while (--n >= 0 && (err == SSL_ERROR_WANT_READ || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (wait_readable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + break; + } + } + assert(ret < 0); + } + return ret; + } else { + return -1; + } +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (wait_writable()) { + auto handle_size = static_cast( + std::min(size, (std::numeric_limits::max)())); + + auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN64 + while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { +#endif + if (wait_writable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + break; + } + } + assert(ret < 0); + } + return ret; + } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SSLSocketStream::socket() const { return sock_; } + +inline time_t SSLSocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time_) + .count(); +} + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, + const char *private_key_password) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (private_key_password != nullptr && (private_key_password[0] != '\0')) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, + reinterpret_cast(const_cast(private_key_password))); + } + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1 || + SSL_CTX_check_private_key(ctx_) != 1) { + last_ssl_error_ = static_cast(ERR_get_error()); + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer( + const std::function &setup_ssl_ctx_callback) { + ctx_ = SSL_CTX_new(TLS_method()); + if (ctx_) { + if (!setup_ssl_ctx_callback(*ctx_)) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } + +inline void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + + std::lock_guard guard(ctx_mutex_); + + SSL_CTX_use_certificate(ctx_, cert); + SSL_CTX_use_PrivateKey(ctx_, private_key); + + if (client_ca_cert_store != nullptr) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + } +} + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + auto ssl = detail::ssl_new( + sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + return detail::ssl_connect_or_accept_nonblocking( + sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_, + &last_ssl_error_); + }, + [](SSL * /*ssl2*/) { return true; }); + + auto ret = false; + if (ssl) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + ret = detail::process_server_socket_ssl( + svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, + connection_closed, + [&](Request &req) { req.ssl = ssl; }); + }); + + // Shutdown gracefully if the result seemed successful, non-gracefully if + // the connection appeared to be closed. + const bool shutdown_gracefully = ret; + detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); + } + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host) + : SSLClient(host, 443, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port) + : SSLClient(host, port, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + last_openssl_error_ = ERR_get_error(); + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, + X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) + : ClientImpl(host, port) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (client_cert != nullptr && client_key != nullptr) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + last_openssl_error_ = ERR_get_error(); + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + if (ctx_) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { + // Free memory allocated for old cert and use new store `ca_cert_store` + SSL_CTX_set_cert_store(ctx_, ca_cert_store); + } + } else { + X509_STORE_free(ca_cert_store); + } + } +} + +inline void SSLClient::load_ca_cert_store(const char *ca_cert, + std::size_t size) { + set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } + +inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + return is_valid() && ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in flight +inline bool SSLClient::connect_with_proxy( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + if (max_timeout_msec_ > 0) { + req2.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + // Close the current socket and create a new one for the authenticated + // request + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + + // Create a new socket for the authenticated CONNECT request + if (!create_and_connect_socket(socket, error)) { + success = false; + return false; + } + + proxy_res = Response(); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + if (max_timeout_msec_ > 0) { + req3.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +inline bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), + nullptr)) { + last_openssl_error_ = ERR_get_error(); + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, + ca_cert_dir_path_.c_str())) { + last_openssl_error_ = ERR_get_error(); + ret = false; + } + } else { + auto loaded = false; +#ifdef _WIN64 + loaded = + detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && \ + defined(TARGET_OS_OSX) + loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); +#endif // _WIN64 + if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); } + } + }); + + return ret; +} + +inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); + } + + if (!detail::ssl_connect_or_accept_nonblocking( + socket.sock, ssl2, SSL_connect, connection_timeout_sec_, + connection_timeout_usec_, &last_ssl_error_)) { + error = Error::SSLConnection; + return false; + } + + if (server_certificate_verification_) { + auto verification_status = SSLVerifierResponse::NoDecisionMade; + + if (server_certificate_verifier_) { + verification_status = server_certificate_verifier_(ssl2); + } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + last_openssl_error_ = ERR_get_error(); + error = Error::SSLServerVerification; + return false; + } + + if (verification_status == SSLVerifierResponse::NoDecisionMade) { + verify_result_ = SSL_get_verify_result(ssl2); + + if (verify_result_ != X509_V_OK) { + last_openssl_error_ = static_cast(verify_result_); + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + auto se = detail::scope_exit([&] { X509_free(server_cert); }); + + if (server_cert == nullptr) { + last_openssl_error_ = ERR_get_error(); + error = Error::SSLServerVerification; + return false; + } + + if (server_hostname_verification_) { + if (!verify_host(server_cert)) { + last_openssl_error_ = X509_V_ERR_HOSTNAME_MISMATCH; + error = Error::SSLServerHostnameVerification; + return false; + } + } + } + } + + return true; + }, + [&](SSL *ssl2) { +#if defined(OPENSSL_IS_BORINGSSL) + SSL_set_tlsext_host_name(ssl2, host_.c_str()); +#else + // NOTE: Direct call instead of using the OpenSSL macro to suppress + // -Wold-style-cast warning + SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(host_.c_str()))); +#endif + return true; + }); + + if (ssl) { + socket.ssl = ssl; + return true; + } + + shutdown_socket(socket); + close_socket(socket); + return false; +} + +inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +inline void SSLClient::shutdown_ssl_impl(Socket &socket, + bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, + shutdown_gracefully); + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +inline bool SSLClient::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, + std::move(callback)); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6 = {}; + struct in_addr addr = {}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = + reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + } + + if (dsn_matched || ip_matched) { ret = true; } + } + + GENERAL_NAMES_free(const_cast( + reinterpret_cast(alt_names))); + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(b, e); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} +#endif + +// Universal client implementation +inline Client::Client(const std::string &scheme_host_port) + : Client(scheme_host_port, std::string(), std::string()) {} + +inline Client::Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path) { + const static std::regex re( + R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + + std::smatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { +#else + if (!scheme.empty() && scheme != "http") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "https"; + + auto host = m[2].str(); + if (host.empty()) { host = m[3].str(); } + + auto port_str = m[4].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + if (is_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + is_ssl_ = is_ssl; +#endif + } else { + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + } + } else { + // NOTE: Update TEST(UniversalClientImplTest, Ipv6LiteralAddress) + // if port param below changes. + cli_ = detail::make_unique(scheme_host_port, 80, + client_cert_path, client_key_path); + } +} // namespace detail + +inline Client::Client(const std::string &host, int port) + : cli_(detail::make_unique(host, port)) {} + +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : cli_(detail::make_unique(host, port, client_cert_path, + client_key_path)) {} + +inline Client::~Client() = default; + +inline bool Client::is_valid() const { + return cli_ != nullptr && cli_->is_valid(); +} + +inline Result Client::Get(const std::string &path, DownloadProgress progress) { + return cli_->Get(path, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + DownloadProgress progress) { + return cli_->Get(path, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, DownloadProgress progress) { + return cli_->Get(path, params, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, params, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, params, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result Client::Head(const std::string &path) { return cli_->Head(path); } +inline Result Client::Head(const std::string &path, const Headers &headers) { + return cli_->Head(path, headers); +} + +inline Result Client::Post(const std::string &path) { return cli_->Post(path); } +inline Result Client::Post(const std::string &path, const Headers &headers) { + return cli_->Post(path, headers); +} +inline Result Client::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, body, content_length, content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Post(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, headers, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, std::move(content_provider), content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, headers, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, headers, std::move(content_provider), content_type, + progress); +} +inline Result Client::Post(const std::string &path, const Params ¶ms) { + return cli_->Post(path, params); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Post(path, headers, params); +} +inline Result Client::Post(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Post(path, items, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Post(path, headers, items, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + return cli_->Post(path, headers, items, boundary, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + return cli_->Post(path, headers, items, provider_items, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, body, content_type, content_receiver, + progress); +} + +inline Result Client::Put(const std::string &path) { return cli_->Put(path); } +inline Result Client::Put(const std::string &path, const Headers &headers) { + return cli_->Put(path, headers); +} +inline Result Client::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, body, content_length, content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Put(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, headers, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, std::move(content_provider), content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, headers, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, headers, std::move(content_provider), content_type, + progress); +} +inline Result Client::Put(const std::string &path, const Params ¶ms) { + return cli_->Put(path, params); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Put(path, headers, params); +} +inline Result Client::Put(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Put(path, items, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Put(path, headers, items, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + return cli_->Put(path, headers, items, boundary, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + return cli_->Put(path, headers, items, provider_items, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Put(path, headers, body, content_type, content_receiver, + progress); +} + +inline Result Client::Patch(const std::string &path) { + return cli_->Patch(path); +} +inline Result Client::Patch(const std::string &path, const Headers &headers) { + return cli_->Patch(path, headers); +} +inline Result Client::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, body, content_length, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Patch(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, headers, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, std::move(content_provider), content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, headers, std::move(content_provider), content_type, + progress); +} +inline Result Client::Patch(const std::string &path, const Params ¶ms) { + return cli_->Patch(path, params); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Patch(path, headers, params); +} +inline Result Client::Patch(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Patch(path, items, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Patch(path, headers, items, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + return cli_->Patch(path, headers, items, boundary, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + return cli_->Patch(path, headers, items, provider_items, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Patch(path, headers, body, content_type, content_receiver, + progress); +} + +inline Result Client::Delete(const std::string &path, + DownloadProgress progress) { + return cli_->Delete(path, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + DownloadProgress progress) { + return cli_->Delete(path, headers, progress); +} +inline Result Client::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + DownloadProgress progress) { + return cli_->Delete(path, body, content_length, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + DownloadProgress progress) { + return cli_->Delete(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Delete(const std::string &path, const std::string &body, + const std::string &content_type, + DownloadProgress progress) { + return cli_->Delete(path, body, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + DownloadProgress progress) { + return cli_->Delete(path, headers, body, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Params ¶ms, + DownloadProgress progress) { + return cli_->Delete(path, params, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const Params ¶ms, DownloadProgress progress) { + return cli_->Delete(path, headers, params, progress); +} + +inline Result Client::Options(const std::string &path) { + return cli_->Options(path); +} +inline Result Client::Options(const std::string &path, const Headers &headers) { + return cli_->Options(path, headers); +} + +inline bool Client::send(Request &req, Response &res, Error &error) { + return cli_->send(req, res, error); +} + +inline Result Client::send(const Request &req) { return cli_->send(req); } + +inline void Client::stop() { cli_->stop(); } + +inline std::string Client::host() const { return cli_->host(); } + +inline int Client::port() const { return cli_->port(); } + +inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } + +inline socket_t Client::socket() const { return cli_->socket(); } + +inline void +Client::set_hostname_addr_map(std::map addr_map) { + cli_->set_hostname_addr_map(std::move(addr_map)); +} + +inline void Client::set_default_headers(Headers headers) { + cli_->set_default_headers(std::move(headers)); +} + +inline void Client::set_header_writer( + std::function const &writer) { + cli_->set_header_writer(writer); +} + +inline void Client::set_address_family(int family) { + cli_->set_address_family(family); +} + +inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } + +inline void Client::set_socket_options(SocketOptions socket_options) { + cli_->set_socket_options(std::move(socket_options)); +} + +inline void Client::set_connection_timeout(time_t sec, time_t usec) { + cli_->set_connection_timeout(sec, usec); +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + cli_->set_read_timeout(sec, usec); +} + +inline void Client::set_write_timeout(time_t sec, time_t usec) { + cli_->set_write_timeout(sec, usec); +} + +inline void Client::set_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_basic_auth(username, password); +} +inline void Client::set_bearer_token_auth(const std::string &token) { + cli_->set_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_digest_auth(username, password); +} +#endif + +inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_follow_location(bool on) { + cli_->set_follow_location(on); +} + +inline void Client::set_path_encode(bool on) { cli_->set_path_encode(on); } + +[[deprecated("Use set_path_encode instead")]] +inline void Client::set_url_encode(bool on) { + cli_->set_path_encode(on); +} + +inline void Client::set_compress(bool on) { cli_->set_compress(on); } + +inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } + +inline void Client::set_interface(const std::string &intf) { + cli_->set_interface(intf); +} + +inline void Client::set_proxy(const std::string &host, int port) { + cli_->set_proxy(host, port); +} +inline void Client::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_basic_auth(username, password); +} +inline void Client::set_proxy_bearer_token_auth(const std::string &token) { + cli_->set_proxy_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} + +inline void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +inline void Client::set_server_certificate_verifier( + std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} +#endif + +inline void Client::set_logger(Logger logger) { + cli_->set_logger(std::move(logger)); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); +} + +inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } else { + cli_->set_ca_cert_store(ca_cert_store); + } +} + +inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); +} + +inline long Client::get_openssl_verify_result() const { + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} + +inline SSL_CTX *Client::ssl_context() const { + if (is_ssl_) { return static_cast(*cli_).ssl_context(); } + return nullptr; +} +#endif + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#endif // CPPHTTPLIB_HTTPLIB_H diff --git a/main.c b/main.c new file mode 100644 index 0000000..376abb7 --- /dev/null +++ b/main.c @@ -0,0 +1,206 @@ +#include +#include +#include +#include +#include + +#ifdef _WIN32 + #include +#else + #include +#endif + +#include "wren.h" +#include "requests_backend.c" +#include "socket_backend.c" + +// --- Global flag to control the main loop --- +static volatile bool g_mainFiberIsDone = false; + +// --- Foreign function for Wren to signal the host to exit --- +void hostSignalDone(WrenVM* vm) { + (void)vm; + g_mainFiberIsDone = true; +} + +// --- File/VM Setup --- +static char* readFile(const char* path) { + FILE* file = fopen(path, "rb"); + if (file == NULL) return NULL; + fseek(file, 0L, SEEK_END); + size_t fileSize = ftell(file); + rewind(file); + char* buffer = (char*)malloc(fileSize + 1); + if (!buffer) { fclose(file); return NULL; } + size_t bytesRead = fread(buffer, sizeof(char), fileSize, file); + if (bytesRead < fileSize) { + free(buffer); + fclose(file); + return NULL; + } + buffer[bytesRead] = '\0'; + fclose(file); + return buffer; +} + +static void writeFn(WrenVM* vm, const char* text) { (void)vm; printf("%s", text); } + +static void errorFn(WrenVM* vm, WrenErrorType type, const char* module, int line, const char* message) { + (void)vm; + switch (type) { + case WREN_ERROR_COMPILE: + fprintf(stderr, "[%s line %d] [Error] %s\n", module, line, message); + break; + case WREN_ERROR_RUNTIME: + fprintf(stderr, "[Runtime Error] %s\n", message); + g_mainFiberIsDone = true; // Stop on runtime errors + break; + case WREN_ERROR_STACK_TRACE: + fprintf(stderr, "[%s line %d] in %s\n", module, line, message); + break; + } +} + +static void onModuleComplete(WrenVM* vm, const char* name, WrenLoadModuleResult result) { + (void)vm; (void)name; + if (result.source) free((void*)result.source); +} + +static WrenLoadModuleResult loadModule(WrenVM* vm, const char* name) { + (void)vm; + WrenLoadModuleResult result = {0}; + char path[256]; + snprintf(path, sizeof(path), "%s.wren", name); + char* source = readFile(path); + if (source != NULL) { + result.source = source; + result.onComplete = onModuleComplete; + } + return result; +} + +// --- Combined Foreign Function Binders --- +WrenForeignMethodFn combinedBindForeignMethod(WrenVM* vm, const char* module, const char* className, bool isStatic, const char* signature) { + // Delegate to the socket backend's binder + if (strcmp(module, "socket") == 0) { + return bindSocketForeignMethod(vm, module, className, isStatic, signature); + } + + // Delegate to the requests backend's binder + if (strcmp(module, "requests") == 0) { + return bindForeignMethod(vm, module, className, isStatic, signature); + } + + // Handle host-specific methods + if (strcmp(module, "main") == 0 && strcmp(className, "Host") == 0 && isStatic) { + if (strcmp(signature, "signalDone()") == 0) return hostSignalDone; + } + + return NULL; +} + +WrenForeignClassMethods combinedBindForeignClass(WrenVM* vm, const char* module, const char* className) { + // Delegate to the socket backend's class binder + if (strcmp(module, "socket") == 0) { + return bindSocketForeignClass(vm, module, className); + } + + // Delegate to the requests backend's class binder + if (strcmp(module, "requests") == 0) { + return bindForeignClass(vm, module, className); + } + + WrenForeignClassMethods methods = {0, 0}; + return methods; +} + + +// --- Main Application Entry Point --- +int main(int argc, char* argv[]) { + if (argc < 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + + // Initialize libcurl for the requests module + curl_global_init(CURL_GLOBAL_ALL); + + WrenConfiguration config; + wrenInitConfiguration(&config); + config.writeFn = writeFn; + config.errorFn = errorFn; + config.bindForeignMethodFn = combinedBindForeignMethod; + config.bindForeignClassFn = combinedBindForeignClass; + config.loadModuleFn = loadModule; + + WrenVM* vm = wrenNewVM(&config); + + // ** Initialize BOTH managers ** + socketManager_create(vm); + httpManager_create(vm); + + char* mainSource = readFile(argv[1]); + if (!mainSource) { + fprintf(stderr, "Could not open script: %s\n", argv[1]); + socketManager_destroy(); + httpManager_destroy(); + wrenFreeVM(vm); + curl_global_cleanup(); + return 1; + } + + wrenInterpret(vm, "main", mainSource); + free(mainSource); + + if (g_mainFiberIsDone) { + socketManager_destroy(); + httpManager_destroy(); + wrenFreeVM(vm); + curl_global_cleanup(); + return 1; + } + + wrenEnsureSlots(vm, 1); + wrenGetVariable(vm, "main", "mainFiber", 0); + WrenHandle* mainFiberHandle = wrenGetSlotHandle(vm, 0); + WrenHandle* callHandle = wrenMakeCallHandle(vm, "call()"); + + // === Main Event Loop === + while (!g_mainFiberIsDone) { + // ** Process completions for BOTH managers ** + socketManager_processCompletions(); + httpManager_processCompletions(); + + // Resume the main Wren fiber + wrenEnsureSlots(vm, 1); + wrenSetSlotHandle(vm, 0, mainFiberHandle); + WrenInterpretResult result = wrenCall(vm, callHandle); + if (result == WREN_RESULT_RUNTIME_ERROR) { + g_mainFiberIsDone = true; + } + + // Prevent 100% CPU usage + #ifdef _WIN32 + Sleep(1); + #else + usleep(1000); // 1ms + #endif + } + + // Process any final completions before shutting down + socketManager_processCompletions(); + httpManager_processCompletions(); + + wrenReleaseHandle(vm, mainFiberHandle); + wrenReleaseHandle(vm, callHandle); + + // ** Destroy BOTH managers ** + socketManager_destroy(); + httpManager_destroy(); + + wrenFreeVM(vm); + curl_global_cleanup(); + + printf("\nHost application finished.\n"); + return 0; +} diff --git a/merged_source_files.txt b/merged_source_files.txt new file mode 100644 index 0000000..be0c4a3 --- /dev/null +++ b/merged_source_files.txt @@ -0,0 +1,27070 @@ +// Start of socket_backend.c +// socket_backend.c (Corrected with better handle safety and non-blocking I/O) +#include "wren.h" +#include +#include +#include +#include +#include + +// Platform-specific includes and definitions +#ifdef _WIN32 + #include + #include + #include + #pragma comment(lib, "ws2_32.lib") + typedef SOCKET socket_t; + typedef int socklen_t; + typedef HANDLE thread_t; + typedef CRITICAL_SECTION mutex_t; + typedef CONDITION_VARIABLE cond_t; + #define IS_SOCKET_VALID(s) ((s) != INVALID_SOCKET) + #define CLOSE_SOCKET(s) closesocket(s) +#else + #include + #include + #include + #include + #include + #include + #include + #include + #include + typedef int socket_t; + typedef pthread_t thread_t; + typedef pthread_mutex_t mutex_t; + typedef pthread_cond_t cond_t; + #define INVALID_SOCKET -1 + #define IS_SOCKET_VALID(s) ((s) >= 0) + #define CLOSE_SOCKET(s) close(s) +#endif + +// --- Forward Declarations --- +typedef struct SocketContext SocketContext; + +// --- Socket Data Structures --- + +typedef enum { + SOCKET_OP_CONNECT, + SOCKET_OP_READ, + SOCKET_OP_WRITE, +} SocketOp; + +typedef struct { + socket_t sock; + bool isListener; +} SocketData; + +struct SocketContext { + SocketOp operation; + WrenVM* vm; + WrenHandle* socketHandle; + WrenHandle* callback; + + char* host; + int port; + char* data; + size_t dataLength; + + bool success; + char* resultData; + size_t resultDataLength; + char* errorMessage; + socket_t newSocket; + struct SocketContext* next; +}; + +// --- Thread-Safe Queue Implementation in C --- +typedef struct { + SocketContext *head, *tail; + mutex_t mutex; + cond_t cond; +} ThreadSafeQueueSocket; + +void queue_init(ThreadSafeQueueSocket* q) { + q->head = q->tail = NULL; + #ifdef _WIN32 + InitializeCriticalSection(&q->mutex); + InitializeConditionVariable(&q->cond); + #else + pthread_mutex_init(&q->mutex, NULL); + pthread_cond_init(&q->cond, NULL); + #endif +} + +void queue_destroy(ThreadSafeQueueSocket* q) { + #ifdef _WIN32 + DeleteCriticalSection(&q->mutex); + #else + pthread_mutex_destroy(&q->mutex); + pthread_cond_destroy(&q->cond); + #endif +} + +void queue_push(ThreadSafeQueueSocket* q, SocketContext* context) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + #endif + + if (context) { + context->next = NULL; + } + + if (q->tail) { + q->tail->next = context; + } else { + q->head = context; + } + q->tail = context; + + #ifdef _WIN32 + WakeConditionVariable(&q->cond); + LeaveCriticalSection(&q->mutex); + #else + pthread_cond_signal(&q->cond); + pthread_mutex_unlock(&q->mutex); + #endif +} + +SocketContext* queue_pop(ThreadSafeQueueSocket* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + while (q->head == NULL) { + SleepConditionVariableCS(&q->cond, &q->mutex, INFINITE); + } + #else + pthread_mutex_lock(&q->mutex); + while (q->head == NULL) { + pthread_cond_wait(&q->cond, &q->mutex); + } + #endif + + SocketContext* context = q->head; + q->head = q->head->next; + if (q->head == NULL) { + q->tail = NULL; + } + + #ifdef _WIN32 + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_unlock(&q->mutex); + #endif + + return context; +} + +bool queue_empty(ThreadSafeQueueSocket* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + bool empty = (q->head == NULL); + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + bool empty = (q->head == NULL); + pthread_mutex_unlock(&q->mutex); + #endif + return empty; +} + +// --- Asynchronous Socket Manager --- + +#define MAX_LISTENERS 64 + +typedef struct { + WrenVM* vm; + volatile bool running; + thread_t worker_threads[4]; + thread_t listener_thread; + + ThreadSafeQueueSocket requestQueue; + ThreadSafeQueueSocket completionQueue; + ThreadSafeQueueSocket acceptQueue; + + mutex_t listener_mutex; + socket_t listener_sockets[MAX_LISTENERS]; + int listener_count; + #ifndef _WIN32 + socket_t wake_pipe[2]; + #endif +} AsyncSocketManager; + +static AsyncSocketManager* socketManager = NULL; + +void free_socket_context_data(SocketContext* context) { + if (!context) return; + free(context->host); + free(context->data); + free(context->resultData); + free(context->errorMessage); + free(context); +} + +#ifdef _WIN32 +DWORD WINAPI workerThread(LPVOID arg); +DWORD WINAPI listenerThread(LPVOID arg); +#else +void* workerThread(void* arg); +void* listenerThread(void* arg); +#endif + +// --- Worker and Listener Thread Implementations --- + +#ifdef _WIN32 +DWORD WINAPI listenerThread(LPVOID arg) { +#else +void* listenerThread(void* arg) { +#endif + AsyncSocketManager* manager = (AsyncSocketManager*)arg; + + while (manager->running) { + fd_set read_fds; + FD_ZERO(&read_fds); + + socket_t max_fd = 0; + + #ifndef _WIN32 + FD_SET(manager->wake_pipe[0], &read_fds); + max_fd = manager->wake_pipe[0]; + #endif + + #ifdef _WIN32 + EnterCriticalSection(&manager->listener_mutex); + #else + pthread_mutex_lock(&manager->listener_mutex); + #endif + + for (int i = 0; i < manager->listener_count; i++) { + socket_t sock = manager->listener_sockets[i]; + if (IS_SOCKET_VALID(sock)) { + FD_SET(sock, &read_fds); + if (sock > max_fd) { + max_fd = sock; + } + } + } + + #ifdef _WIN32 + LeaveCriticalSection(&manager->listener_mutex); + #else + pthread_mutex_unlock(&manager->listener_mutex); + #endif + + struct timeval timeout; + timeout.tv_sec = 1; + timeout.tv_usec = 0; + + int activity = select(max_fd + 1, &read_fds, NULL, NULL, &timeout); + + if (!manager->running) break; + if (activity < 0) { + #ifndef _WIN32 + if (errno != EINTR) { + perror("select error"); + } + #endif + continue; + } + if (activity == 0) continue; + + #ifndef _WIN32 + if (FD_ISSET(manager->wake_pipe[0], &read_fds)) { + char buffer[1]; + read(manager->wake_pipe[0], buffer, 1); + } + #endif + + #ifdef _WIN32 + EnterCriticalSection(&manager->listener_mutex); + #else + pthread_mutex_lock(&manager->listener_mutex); + #endif + + for (int i = 0; i < manager->listener_count; i++) { + socket_t sock = manager->listener_sockets[i]; + if (IS_SOCKET_VALID(sock) && FD_ISSET(sock, &read_fds)) { + if (!queue_empty(&manager->acceptQueue)) { + SocketContext* context = queue_pop(&manager->acceptQueue); + context->newSocket = accept(sock, NULL, NULL); + context->success = IS_SOCKET_VALID(context->newSocket); + if (!context->success) { + context->errorMessage = strdup("Accept failed."); + } + queue_push(&manager->completionQueue, context); + } + } + } + + #ifdef _WIN32 + LeaveCriticalSection(&manager->listener_mutex); + #else + pthread_mutex_unlock(&manager->listener_mutex); + #endif + } + return 0; +} + +#ifdef _WIN32 +DWORD WINAPI workerThread(LPVOID arg) { +#else +void* workerThread(void* arg) { +#endif + AsyncSocketManager* manager = (AsyncSocketManager*)arg; + + while (manager->running) { + SocketContext* context = queue_pop(&manager->requestQueue); + if (!context || !manager->running) { + if (context) free_socket_context_data(context); + break; + } + + wrenEnsureSlots(context->vm, 1); + wrenSetSlotHandle(context->vm, 0, context->socketHandle); + SocketData* socketData = (wrenGetSlotType(context->vm, 0) == WREN_TYPE_FOREIGN) + ? (SocketData*)wrenGetSlotForeign(context->vm, 0) + : NULL; + + if (!socketData) { + context->success = false; + context->errorMessage = strdup("Invalid socket object."); + queue_push(&manager->completionQueue, context); + continue; + } + + switch (context->operation) { + case SOCKET_OP_CONNECT: { + struct addrinfo hints = {0}, *addrs; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + char port_str[6]; + snprintf(port_str, 6, "%d", context->port); + if (getaddrinfo(context->host, port_str, &hints, &addrs) != 0) { + context->success = false; + context->errorMessage = strdup("Host lookup failed."); + break; + } + + socket_t sock = INVALID_SOCKET; + for (struct addrinfo* addr = addrs; addr; addr = addr->ai_next) { + sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (!IS_SOCKET_VALID(sock)) continue; + if (connect(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break; + CLOSE_SOCKET(sock); + sock = INVALID_SOCKET; + } + freeaddrinfo(addrs); + + if (IS_SOCKET_VALID(sock)) { + socketData->sock = sock; + socketData->isListener = false; + context->success = true; + } else { + context->success = false; + context->errorMessage = strdup("Connection failed."); + } + break; + } + case SOCKET_OP_READ: { + if (socketData->isListener) { + context->success = false; + context->errorMessage = strdup("Cannot read from a listening socket."); + break; + } + + fd_set read_fds; + FD_ZERO(&read_fds); + FD_SET(socketData->sock, &read_fds); + struct timeval timeout = { .tv_sec = 5, .tv_usec = 0 }; // 5-second timeout + + int activity = select(socketData->sock + 1, &read_fds, NULL, NULL, &timeout); + if (activity > 0 && FD_ISSET(socketData->sock, &read_fds)) { + char buf[4096]; + ssize_t len = recv(socketData->sock, buf, sizeof(buf), 0); + if (len > 0) { + context->resultData = (char*)malloc(len); + memcpy(context->resultData, buf, len); + context->resultDataLength = len; + context->success = true; + } else { + context->success = false; + context->errorMessage = strdup("Read failed or connection closed."); + } + } else { + context->success = false; + context->errorMessage = strdup("Read timeout or error."); + } + break; + } + case SOCKET_OP_WRITE: { + if (socketData->isListener) { + context->success = false; + context->errorMessage = strdup("Cannot write to a listening socket."); + break; + } + ssize_t written = send(socketData->sock, context->data, context->dataLength, 0); + context->success = (written == (ssize_t)context->dataLength); + if(!context->success) context->errorMessage = strdup("Write failed."); + break; + } + } + queue_push(&manager->completionQueue, context); + } + return 0; +} + +// --- Manager Lifecycle --- + +void socketManager_create(WrenVM* vm) { + socketManager = (AsyncSocketManager*)malloc(sizeof(AsyncSocketManager)); + socketManager->vm = vm; + socketManager->running = true; + socketManager->listener_count = 0; + + queue_init(&socketManager->requestQueue); + queue_init(&socketManager->completionQueue); + queue_init(&socketManager->acceptQueue); + + #ifdef _WIN32 + InitializeCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_init(&socketManager->listener_mutex, NULL); + #endif + + #ifndef _WIN32 + if (pipe(socketManager->wake_pipe) == -1) { + perror("pipe"); + exit(1); + } + #endif + + for (int i = 0; i < 4; i++) { + #ifdef _WIN32 + socketManager->worker_threads[i] = CreateThread(NULL, 0, workerThread, socketManager, 0, NULL); + #else + pthread_create(&socketManager->worker_threads[i], NULL, workerThread, socketManager); + #endif + } + + #ifdef _WIN32 + socketManager->listener_thread = CreateThread(NULL, 0, listenerThread, socketManager, 0, NULL); + #else + pthread_create(&socketManager->listener_thread, NULL, listenerThread, socketManager); + #endif +} + +void socketManager_destroy() { + socketManager->running = false; + + #ifndef _WIN32 + write(socketManager->wake_pipe[1], "w", 1); + #endif + + for (int i = 0; i < 4; i++) { + queue_push(&socketManager->requestQueue, NULL); + } + + #ifdef _WIN32 + WaitForSingleObject(socketManager->listener_thread, INFINITE); + CloseHandle(socketManager->listener_thread); + for (int i = 0; i < 4; i++) { + WaitForSingleObject(socketManager->worker_threads[i], INFINITE); + CloseHandle(socketManager->worker_threads[i]); + } + #else + pthread_join(socketManager->listener_thread, NULL); + for (int i = 0; i < 4; i++) { + pthread_join(socketManager->worker_threads[i], NULL); + } + close(socketManager->wake_pipe[0]); + close(socketManager->wake_pipe[1]); + #endif + + queue_destroy(&socketManager->requestQueue); + queue_destroy(&socketManager->completionQueue); + queue_destroy(&socketManager->acceptQueue); + + #ifdef _WIN32 + DeleteCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_destroy(&socketManager->listener_mutex); + #endif + + free(socketManager); +} + +void socketManager_processCompletions() { + WrenHandle* callHandle = wrenMakeCallHandle(socketManager->vm, "call(_,_)"); + while (!queue_empty(&socketManager->completionQueue)) { + SocketContext* context = queue_pop(&socketManager->completionQueue); + + wrenEnsureSlots(socketManager->vm, 3); + wrenSetSlotHandle(socketManager->vm, 0, context->callback); + if (context->success) { + wrenSetSlotNull(socketManager->vm, 1); + if (IS_SOCKET_VALID(context->newSocket)) { + wrenGetVariable(socketManager->vm, "socket", "Socket", 2); + void* foreign = wrenSetSlotNewForeign(socketManager->vm, 2, 2, sizeof(SocketData)); + SocketData* clientData = (SocketData*)foreign; + clientData->sock = context->newSocket; + clientData->isListener = false; + } else if (context->resultData) { + wrenSetSlotBytes(socketManager->vm, 2, context->resultData, context->resultDataLength); + } else { + wrenSetSlotNull(socketManager->vm, 2); + } + } else { + wrenSetSlotString(socketManager->vm, 1, context->errorMessage ? context->errorMessage : "Unknown error."); + wrenSetSlotNull(socketManager->vm, 2); + } + + wrenCall(socketManager->vm, callHandle); + + // Safely release handles here on the main thread + wrenReleaseHandle(socketManager->vm, context->socketHandle); + wrenReleaseHandle(socketManager->vm, context->callback); + free_socket_context_data(context); + } + wrenReleaseHandle(socketManager->vm, callHandle); +} + +// ... (The rest of the foreign functions from socketAllocate onwards are identical to the previous response) ... +void socketAllocate(WrenVM* vm) { + SocketData* data = (SocketData*)wrenSetSlotNewForeign(vm, 0, 0, sizeof(SocketData)); + data->sock = INVALID_SOCKET; + data->isListener = false; +} + +void socketConnect(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_CONNECT; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->host = strdup(wrenGetSlotString(vm, 1)); + context->port = (int)wrenGetSlotDouble(vm, 2); + context->callback = wrenGetSlotHandle(vm, 3); + queue_push(&socketManager->requestQueue, context); +} + +void socketListen(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + const char* host = wrenGetSlotString(vm, 1); + int port = (int)wrenGetSlotDouble(vm, 2); + int backlog = (int)wrenGetSlotDouble(vm, 3); + + struct addrinfo hints = {0}, *addrs; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE; + + char port_str[6]; + snprintf(port_str, 6, "%d", port); + if (getaddrinfo(host, port_str, &hints, &addrs) != 0) { + wrenSetSlotBool(vm, 0, false); + return; + } + + socket_t sock = INVALID_SOCKET; + for (struct addrinfo* addr = addrs; addr; addr = addr->ai_next) { + sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (!IS_SOCKET_VALID(sock)) continue; + + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char*)&yes, sizeof(yes)); + + if (bind(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break; + + CLOSE_SOCKET(sock); + sock = INVALID_SOCKET; + } + freeaddrinfo(addrs); + + if (IS_SOCKET_VALID(sock) && listen(sock, backlog) == 0) { + data->sock = sock; + data->isListener = true; + + #ifdef _WIN32 + EnterCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_lock(&socketManager->listener_mutex); + #endif + + if (socketManager->listener_count < MAX_LISTENERS) { + socketManager->listener_sockets[socketManager->listener_count++] = sock; + } + + #ifdef _WIN32 + LeaveCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_unlock(&socketManager->listener_mutex); + #endif + + #ifndef _WIN32 + write(socketManager->wake_pipe[1], "w", 1); + #endif + + wrenSetSlotBool(vm, 0, true); + } else { + if(IS_SOCKET_VALID(sock)) CLOSE_SOCKET(sock); + wrenSetSlotBool(vm, 0, false); + } +} + +void socketAccept(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->callback = wrenGetSlotHandle(vm, 1); + queue_push(&socketManager->acceptQueue, context); +} + +void socketRead(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_READ; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->callback = wrenGetSlotHandle(vm, 1); + queue_push(&socketManager->requestQueue, context); +} + +void socketWrite(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_WRITE; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + int len; + const char* bytes = wrenGetSlotBytes(vm, 1, &len); + context->data = (char*)malloc(len); + memcpy(context->data, bytes, len); + context->dataLength = len; + context->callback = wrenGetSlotHandle(vm, 2); + queue_push(&socketManager->requestQueue, context); +} + +void socketClose(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (IS_SOCKET_VALID(data->sock)) { + if (data->isListener) { + #ifdef _WIN32 + EnterCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_lock(&socketManager->listener_mutex); + #endif + + for (int i = 0; i < socketManager->listener_count; i++) { + if (socketManager->listener_sockets[i] == data->sock) { + socketManager->listener_sockets[i] = socketManager->listener_sockets[socketManager->listener_count - 1]; + socketManager->listener_count--; + break; + } + } + + #ifdef _WIN32 + LeaveCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_unlock(&socketManager->listener_mutex); + #endif + } + CLOSE_SOCKET(data->sock); + data->sock = INVALID_SOCKET; + } +} + +void socketIsOpen(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBool(vm, 0, IS_SOCKET_VALID(data->sock)); +} + +void socketRemoteAddress(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (!IS_SOCKET_VALID(data->sock) || data->isListener) { + wrenSetSlotNull(vm, 0); + return; + } + + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); + char ipstr[INET6_ADDRSTRLEN]; + + if (getpeername(data->sock, (struct sockaddr*)&addr, &len) == 0) { + if (addr.ss_family == AF_INET) { + inet_ntop(AF_INET, &((struct sockaddr_in*)&addr)->sin_addr, ipstr, sizeof(ipstr)); + } else { + inet_ntop(AF_INET6, &((struct sockaddr_in6*)&addr)->sin6_addr, ipstr, sizeof(ipstr)); + } + wrenSetSlotString(vm, 0, ipstr); + } else { + wrenSetSlotNull(vm, 0); + } +} + +void socketRemotePort(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (!IS_SOCKET_VALID(data->sock) || data->isListener) { + wrenSetSlotNull(vm, 0); + return; + } + + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); + + if (getpeername(data->sock, (struct sockaddr*)&addr, &len) == 0) { + int port = 0; + if (addr.ss_family == AF_INET) { + port = ntohs(((struct sockaddr_in*)&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = ntohs(((struct sockaddr_in6*)&addr)->sin6_port); + } + wrenSetSlotDouble(vm, 0, (double)port); + } else { + wrenSetSlotNull(vm, 0); + } +} + +WrenForeignMethodFn bindSocketForeignMethod(WrenVM* vm, const char* module, const char* className, bool isStatic, const char* signature) { + if (strcmp(module, "socket") != 0) return NULL; + if (strcmp(className, "Socket") == 0 && !isStatic) { + if (strcmp(signature, "connect(_,_,_)") == 0) return socketConnect; + if (strcmp(signature, "listen(_,_,_)") == 0) return socketListen; + if (strcmp(signature, "accept(_)") == 0) return socketAccept; + // NOTE: The signature for read() in Wren takes one argument (the callback) now. + if (strcmp(signature, "read(_)") == 0) return socketRead; + if (strcmp(signature, "write_(_,_)") == 0) return socketWrite; + if (strcmp(signature, "close()") == 0) return socketClose; + if (strcmp(signature, "isOpen") == 0) return socketIsOpen; + if (strcmp(signature, "remoteAddress") == 0) return socketRemoteAddress; + if (strcmp(signature, "remotePort") == 0) return socketRemotePort; + } + return NULL; +} + +WrenForeignClassMethods bindSocketForeignClass(WrenVM* vm, const char* module, const char* className) { + WrenForeignClassMethods methods = {0, 0}; + if (strcmp(module, "socket") == 0 && strcmp(className, "Socket") == 0) { + methods.allocate = socketAllocate; + } + return methods; +} + +// End of socket_backend.c + +// Start of httplib.h +// +// httplib.h +// +// Copyright (c) 2025 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +#define CPPHTTPLIB_VERSION "0.23.1" +#define CPPHTTPLIB_VERSION_NUM "0x001701" + +/* + * Platform compatibility check + */ + +#if defined(_WIN32) && !defined(_WIN64) +#error \ + "cpp-httplib doesn't support 32-bit Windows. Please use a 64-bit compiler." +#elif defined(__SIZEOF_POINTER__) && __SIZEOF_POINTER__ < 8 +#warning \ + "cpp-httplib doesn't support 32-bit platforms. Please use a 64-bit compiler." +#elif defined(__SIZEOF_SIZE_T__) && __SIZEOF_SIZE_T__ < 8 +#warning \ + "cpp-httplib doesn't support platforms where size_t is less than 64 bits." +#endif + +#ifdef _WIN32 +#if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0602 +#error \ + "cpp-httplib doesn't support Windows 8 or lower. Please use Windows 10 or later." +#endif +#endif + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND +#define CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND +#define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_USECOND +#ifdef _WIN64 +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 1000 +#else +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 0 +#endif +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH +#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_COUNT +#define CPPHTTPLIB_HEADER_MAX_COUNT 100 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT +#define CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_RANGE_MAX_COUNT +#define CPPHTTPLIB_RANGE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_TCP_NODELAY +#define CPPHTTPLIB_TCP_NODELAY false +#endif + +#ifndef CPPHTTPLIB_IPV6_V6ONLY +#define CPPHTTPLIB_IPV6_V6ONLY false +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_SEND_BUFSIZ +#define CPPHTTPLIB_SEND_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ +#define CPPHTTPLIB_COMPRESSION_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 \ + ? std::thread::hardware_concurrency() - 1 \ + : 0)) +#endif + +#ifndef CPPHTTPLIB_RECV_FLAGS +#define CPPHTTPLIB_RECV_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_SEND_FLAGS +#define CPPHTTPLIB_SEND_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_LISTEN_BACKLOG +#define CPPHTTPLIB_LISTEN_BACKLOG 5 +#endif + +#ifndef CPPHTTPLIB_MAX_LINE_LENGTH +#define CPPHTTPLIB_MAX_LINE_LENGTH 32768 +#endif + +/* + * Headers + */ + +#ifdef _WIN64 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#if _MSC_VER < 1900 +#error Sorry, Visual Studio versions prior to 2015 are not supported +#endif + +#pragma comment(lib, "ws2_32.lib") + +using ssize_t = __int64; +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m) & S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m) & S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#if defined(__has_include) +#if __has_include() +// afunix.h uses types declared in winsock2.h, so has to be included after it. +#include +#define CPPHTTPLIB_HAVE_AFUNIX_H 1 +#endif +#endif + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +using nfds_t = unsigned long; +using socket_t = SOCKET; +using socklen_t = int; + +#else // not _WIN64 + +#include +#if !defined(_AIX) && !defined(__MVS__) +#include +#endif +#ifdef __MVS__ +#include +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif +#endif +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +using socket_t = int; +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#endif //_WIN64 + +#if defined(__APPLE__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \ + defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +#if TARGET_OS_OSX +#include +#include +#endif +#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO or + // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN64 +#include + +// these are defined in wincrypt.h and it breaks compilation if BoringSSL is +// used +#undef X509_NAME +#undef X509_CERT_PAIR +#undef X509_EXTENSIONS +#undef PKCS7_SIGNER_INFO + +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#endif // _WIN64 + +#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +#if TARGET_OS_OSX +#include +#endif +#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO + +#include +#include +#include +#include + +#if defined(_WIN64) && defined(OPENSSL_USE_APPLINK) +#include +#endif + +#include +#include + +#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER) +#if OPENSSL_VERSION_NUMBER < 0x1010107f +#error Please use OpenSSL or a current version of BoringSSL +#endif +#define SSL_get1_peer_certificate SSL_get_peer_certificate +#elif OPENSSL_VERSION_NUMBER < 0x30000000L +#error Sorry, OpenSSL versions prior to 3.0.0 are not supported +#endif + +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +#include +#include +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +#include +#endif + +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +/* + * Backport std::make_unique from C++14. + * + * NOTE: This code came up with the following stackoverflow post: + * https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique + * + */ + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(std::size_t n) { + typedef typename std::remove_extent::type RT; + return std::unique_ptr(new RT[n]); +} + +namespace case_ignore { + +inline unsigned char to_lower(int c) { + const static unsigned char table[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, + 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, + 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, + 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, + 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, + 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 224, 225, 226, + 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, 224, + 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, + 255, + }; + return table[(unsigned char)(char)c]; +} + +inline bool equal(const std::string &a, const std::string &b) { + return a.size() == b.size() && + std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { + return to_lower(ca) == to_lower(cb); + }); +} + +struct equal_to { + bool operator()(const std::string &a, const std::string &b) const { + return equal(a, b); + } +}; + +struct hash { + size_t operator()(const std::string &key) const { + return hash_core(key.data(), key.size(), 0); + } + + size_t hash_core(const char *s, size_t l, size_t h) const { + return (l == 0) ? h + : hash_core(s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no + // overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(to_lower(*s))); + } +}; + +template +using unordered_set = std::unordered_set; + +} // namespace case_ignore + +// This is based on +// "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". + +struct scope_exit { + explicit scope_exit(std::function &&f) + : exit_function(std::move(f)), execute_on_destruction{true} {} + + scope_exit(scope_exit &&rhs) noexcept + : exit_function(std::move(rhs.exit_function)), + execute_on_destruction{rhs.execute_on_destruction} { + rhs.release(); + } + + ~scope_exit() { + if (execute_on_destruction) { this->exit_function(); } + } + + void release() { this->execute_on_destruction = false; } + +private: + scope_exit(const scope_exit &) = delete; + void operator=(const scope_exit &) = delete; + scope_exit &operator=(scope_exit &&) = delete; + + std::function exit_function; + bool execute_on_destruction; +}; + +} // namespace detail + +enum SSLVerifierResponse { + // no decision has been made, use the built-in certificate verifier + NoDecisionMade, + // connection certificate is verified and accepted + CertificateAccepted, + // connection certificate was processed but is rejected + CertificateRejected +}; + +enum StatusCode { + // Information responses + Continue_100 = 100, + SwitchingProtocol_101 = 101, + Processing_102 = 102, + EarlyHints_103 = 103, + + // Successful responses + OK_200 = 200, + Created_201 = 201, + Accepted_202 = 202, + NonAuthoritativeInformation_203 = 203, + NoContent_204 = 204, + ResetContent_205 = 205, + PartialContent_206 = 206, + MultiStatus_207 = 207, + AlreadyReported_208 = 208, + IMUsed_226 = 226, + + // Redirection messages + MultipleChoices_300 = 300, + MovedPermanently_301 = 301, + Found_302 = 302, + SeeOther_303 = 303, + NotModified_304 = 304, + UseProxy_305 = 305, + unused_306 = 306, + TemporaryRedirect_307 = 307, + PermanentRedirect_308 = 308, + + // Client error responses + BadRequest_400 = 400, + Unauthorized_401 = 401, + PaymentRequired_402 = 402, + Forbidden_403 = 403, + NotFound_404 = 404, + MethodNotAllowed_405 = 405, + NotAcceptable_406 = 406, + ProxyAuthenticationRequired_407 = 407, + RequestTimeout_408 = 408, + Conflict_409 = 409, + Gone_410 = 410, + LengthRequired_411 = 411, + PreconditionFailed_412 = 412, + PayloadTooLarge_413 = 413, + UriTooLong_414 = 414, + UnsupportedMediaType_415 = 415, + RangeNotSatisfiable_416 = 416, + ExpectationFailed_417 = 417, + ImATeapot_418 = 418, + MisdirectedRequest_421 = 421, + UnprocessableContent_422 = 422, + Locked_423 = 423, + FailedDependency_424 = 424, + TooEarly_425 = 425, + UpgradeRequired_426 = 426, + PreconditionRequired_428 = 428, + TooManyRequests_429 = 429, + RequestHeaderFieldsTooLarge_431 = 431, + UnavailableForLegalReasons_451 = 451, + + // Server error responses + InternalServerError_500 = 500, + NotImplemented_501 = 501, + BadGateway_502 = 502, + ServiceUnavailable_503 = 503, + GatewayTimeout_504 = 504, + HttpVersionNotSupported_505 = 505, + VariantAlsoNegotiates_506 = 506, + InsufficientStorage_507 = 507, + LoopDetected_508 = 508, + NotExtended_510 = 510, + NetworkAuthenticationRequired_511 = 511, +}; + +using Headers = + std::unordered_multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using DownloadProgress = std::function; +using UploadProgress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct FormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; + Headers headers; +}; + +struct FormField { + std::string name; + std::string content; + Headers headers; +}; +using FormFields = std::multimap; + +using FormFiles = std::multimap; + +struct MultipartFormData { + FormFields fields; // Text fields from multipart + FormFiles files; // Files from multipart + + // Text field access + std::string get_field(const std::string &key, size_t id = 0) const; + std::vector get_fields(const std::string &key) const; + bool has_field(const std::string &key) const; + size_t get_field_count(const std::string &key) const; + + // File access + FormData get_file(const std::string &key, size_t id = 0) const; + std::vector get_files(const std::string &key) const; + bool has_file(const std::string &key) const; + size_t get_file_count(const std::string &key) const; +}; + +struct UploadFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using UploadFormDataItems = std::vector; + +class DataSink { +public: + DataSink() : os(&sb_), sb_(*this) {} + + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function is_writable; + std::function done; + std::function done_with_trailer; + std::ostream os; + +private: + class data_sink_streambuf final : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + + protected: + std::streamsize xsputn(const char *s, std::streamsize n) override { + sink_.write(s, static_cast(n)); + return n; + } + + private: + DataSink &sink_; + }; + + data_sink_streambuf sb_; +}; + +using ContentProvider = + std::function; + +using ContentProviderWithoutLength = + std::function; + +using ContentProviderResourceReleaser = std::function; + +struct FormDataProvider { + std::string name; + ContentProviderWithoutLength provider; + std::string filename; + std::string content_type; +}; +using FormDataProviderItems = std::vector; + +using ContentReceiverWithProgress = std::function; + +using ContentReceiver = + std::function; + +using FormDataHeader = std::function; + +class ContentReader { +public: + using Reader = std::function; + using FormDataReader = + std::function; + + ContentReader(Reader reader, FormDataReader multipart_reader) + : reader_(std::move(reader)), + formdata_reader_(std::move(multipart_reader)) {} + + bool operator()(FormDataHeader header, ContentReceiver receiver) const { + return formdata_reader_(std::move(header), std::move(receiver)); + } + + bool operator()(ContentReceiver receiver) const { + return reader_(std::move(receiver)); + } + + Reader reader_; + FormDataReader formdata_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + std::string matched_route; + Params params; + Headers headers; + Headers trailers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + std::string local_addr; + int local_port = -1; + + // for server + std::string version; + std::string target; + MultipartFormData form; + Ranges ranges; + Match matches; + std::unordered_map path_params; + std::function is_connection_closed = []() { return true; }; + + // for client + std::vector accept_content_types; + ResponseHandler response_handler; + ContentReceiverWithProgress content_receiver; + DownloadProgress download_progress; + UploadProgress upload_progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl = nullptr; +#endif + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + size_t get_header_value_u64(const std::string &key, size_t def = 0, + size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_trailer(const std::string &key) const; + std::string get_trailer_value(const std::string &key, size_t id = 0) const; + size_t get_trailer_value_count(const std::string &key) const; + + bool has_param(const std::string &key) const; + std::string get_param_value(const std::string &key, size_t id = 0) const; + size_t get_param_value_count(const std::string &key) const; + + bool is_multipart_form_data() const; + + // private members... + size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; + size_t content_length_ = 0; + ContentProvider content_provider_; + bool is_chunked_content_provider_ = false; + size_t authorization_count_ = 0; + std::chrono::time_point start_time_ = + (std::chrono::steady_clock::time_point::min)(); +}; + +struct Response { + std::string version; + int status = -1; + std::string reason; + Headers headers; + Headers trailers; + std::string body; + std::string location; // Redirect location + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + size_t get_header_value_u64(const std::string &key, size_t def = 0, + size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_trailer(const std::string &key) const; + std::string get_trailer_value(const std::string &key, size_t id = 0) const; + size_t get_trailer_value_count(const std::string &key) const; + + void set_redirect(const std::string &url, int status = StatusCode::Found_302); + void set_content(const char *s, size_t n, const std::string &content_type); + void set_content(const std::string &s, const std::string &content_type); + void set_content(std::string &&s, const std::string &content_type); + + void set_content_provider( + size_t length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_file_content(const std::string &path, + const std::string &content_type); + void set_file_content(const std::string &path); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(content_provider_success_); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + ContentProviderResourceReleaser content_provider_resource_releaser_; + bool is_chunked_content_provider_ = false; + bool content_provider_success_ = false; + std::string file_content_path_; + std::string file_content_content_type_; +}; + +class Stream { +public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool wait_readable() const = 0; + virtual bool wait_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; + virtual socket_t socket() const = 0; + + virtual time_t duration() const = 0; + + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { +public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual bool enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle() {} +}; + +class ThreadPool final : public TaskQueue { +public: + explicit ThreadPool(size_t n, size_t mqr = 0) + : shutdown_(false), max_queued_requests_(mqr) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + bool enqueue(std::function fn) override { + { + std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + return true; + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + +private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ + !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + size_t max_queued_requests_ = 0; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +using SocketOptions = std::function; + +namespace detail { + +bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen); +bool set_socket_opt(socket_t sock, int level, int optname, int opt); +bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, + time_t usec); + +} // namespace detail + +void default_socket_options(socket_t sock); + +const char *status_message(int status); + +std::string get_bearer_token_auth(const Request &req); + +namespace detail { + +class MatcherBase { +public: + MatcherBase(std::string pattern) : pattern_(pattern) {} + virtual ~MatcherBase() = default; + + const std::string &pattern() const { return pattern_; } + + // Match request path and populate its matches and + virtual bool match(Request &request) const = 0; + +private: + std::string pattern_; +}; + +/** + * Captures parameters in request path and stores them in Request::path_params + * + * Capture name is a substring of a pattern from : to /. + * The rest of the pattern is matched against the request path directly + * Parameters are captured starting from the next character after + * the end of the last matched static pattern fragment until the next /. + * + * Example pattern: + * "/path/fragments/:capture/more/fragments/:second_capture" + * Static fragments: + * "/path/fragments/", "more/fragments/" + * + * Given the following request path: + * "/path/fragments/:1/more/fragments/:2" + * the resulting capture will be + * {{"capture", "1"}, {"second_capture", "2"}} + */ +class PathParamsMatcher final : public MatcherBase { +public: + PathParamsMatcher(const std::string &pattern); + + bool match(Request &request) const override; + +private: + // Treat segment separators as the end of path parameter capture + // Does not need to handle query parameters as they are parsed before path + // matching + static constexpr char separator = '/'; + + // Contains static path fragments to match against, excluding the '/' after + // path params + // Fragments are separated by path params + std::vector static_fragments_; + // Stores the names of the path parameters to be used as keys in the + // Request::path_params map + std::vector param_names_; +}; + +/** + * Performs std::regex_match on request path + * and stores the result in Request::matches + * + * Note that regex match is performed directly on the whole request. + * This means that wildcard patterns may match multiple path segments with /: + * "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end". + */ +class RegexMatcher final : public MatcherBase { +public: + RegexMatcher(const std::string &pattern) + : MatcherBase(pattern), regex_(pattern) {} + + bool match(Request &request) const override; + +private: + std::regex regex_; +}; + +ssize_t write_headers(Stream &strm, const Headers &headers); + +} // namespace detail + +class Server { +public: + using Handler = std::function; + + using ExceptionHandler = + std::function; + + enum class HandlerResponse { + Handled, + Unhandled, + }; + using HandlerWithResponse = + std::function; + + using HandlerWithContentReader = std::function; + + using Expect100ContinueHandler = + std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, HandlerWithContentReader handler); + Server &Put(const std::string &pattern, Handler handler); + Server &Put(const std::string &pattern, HandlerWithContentReader handler); + Server &Patch(const std::string &pattern, Handler handler); + Server &Patch(const std::string &pattern, HandlerWithContentReader handler); + Server &Delete(const std::string &pattern, Handler handler); + Server &Delete(const std::string &pattern, HandlerWithContentReader handler); + Server &Options(const std::string &pattern, Handler handler); + + bool set_base_dir(const std::string &dir, + const std::string &mount_point = std::string()); + bool set_mount_point(const std::string &mount_point, const std::string &dir, + Headers headers = Headers()); + bool remove_mount_point(const std::string &mount_point); + Server &set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime); + Server &set_default_file_mimetype(const std::string &mime); + Server &set_file_request_handler(Handler handler); + + template + Server &set_error_handler(ErrorHandlerFunc &&handler) { + return set_error_handler_core( + std::forward(handler), + std::is_convertible{}); + } + + Server &set_exception_handler(ExceptionHandler handler); + + Server &set_pre_routing_handler(HandlerWithResponse handler); + Server &set_post_routing_handler(Handler handler); + + Server &set_pre_request_handler(HandlerWithResponse handler); + + Server &set_expect_100_continue_handler(Expect100ContinueHandler handler); + Server &set_logger(Logger logger); + Server &set_pre_compression_logger(Logger logger); + + Server &set_address_family(int family); + Server &set_tcp_nodelay(bool on); + Server &set_ipv6_v6only(bool on); + Server &set_socket_options(SocketOptions socket_options); + + Server &set_default_headers(Headers headers); + Server & + set_header_writer(std::function const &writer); + + Server &set_keep_alive_max_count(size_t count); + Server &set_keep_alive_timeout(time_t sec); + + Server &set_read_timeout(time_t sec, time_t usec = 0); + template + Server &set_read_timeout(const std::chrono::duration &duration); + + Server &set_write_timeout(time_t sec, time_t usec = 0); + template + Server &set_write_timeout(const std::chrono::duration &duration); + + Server &set_idle_interval(time_t sec, time_t usec = 0); + template + Server &set_idle_interval(const std::chrono::duration &duration); + + Server &set_payload_max_length(size_t length); + + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); + int bind_to_any_port(const std::string &host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const std::string &host, int port, int socket_flags = 0); + + bool is_running() const; + void wait_until_ready() const; + void stop(); + void decommission(); + + std::function new_task_queue; + +protected: + bool process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, + bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_{INVALID_SOCKET}; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + +private: + using Handlers = + std::vector, Handler>>; + using HandlersForContentReader = + std::vector, + HandlerWithContentReader>>; + + static std::unique_ptr + make_matcher(const std::string &pattern); + + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); + Server &set_error_handler_core(Handler handler, std::false_type); + + socket_t create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const std::string &host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(const Request &req, Response &res); + bool dispatch_request(Request &req, Response &res, + const Handlers &handlers) const; + bool dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const; + + bool parse_request_line(const char *s, Request &req) const; + void apply_ranges(const Request &req, Response &res, + std::string &content_type, std::string &boundary) const; + bool write_response(Stream &strm, bool close_connection, Request &req, + Response &res); + bool write_response_with_content(Stream &strm, bool close_connection, + const Request &req, Response &res); + bool write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool read_content_with_content_receiver(Stream &strm, Request &req, + Response &res, + ContentReceiver receiver, + FormDataHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + FormDataHeader multipart_header, + ContentReceiver multipart_receiver) const; + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_{false}; + std::atomic is_decommissioned{false}; + + struct MountPointEntry { + std::string mount_point; + std::string base_dir; + Headers headers; + }; + std::vector base_dirs_; + std::map file_extension_and_mimetype_map_; + std::string default_file_mimetype_ = "application/octet-stream"; + Handler file_request_handler_; + + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + + HandlerWithResponse error_handler_; + ExceptionHandler exception_handler_; + HandlerWithResponse pre_routing_handler_; + Handler post_routing_handler_; + HandlerWithResponse pre_request_handler_; + Expect100ContinueHandler expect_100_continue_handler_; + + Logger logger_; + Logger pre_compression_logger_; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = default_socket_options; + + Headers default_headers_; + std::function header_writer_ = + detail::write_headers; +}; + +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + +class Result { +public: + Result() = default; + Result(std::unique_ptr &&res, Error err, + Headers &&request_headers = Headers{}) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)) {} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {} + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error, unsigned long ssl_openssl_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error), + ssl_openssl_error_(ssl_openssl_error) {} +#endif + // Response + operator bool() const { return res_ != nullptr; } + bool operator==(std::nullptr_t) const { return res_ == nullptr; } + bool operator!=(std::nullptr_t) const { return res_ != nullptr; } + const Response &value() const { return *res_; } + Response &value() { return *res_; } + const Response &operator*() const { return *res_; } + Response &operator*() { return *res_; } + const Response *operator->() const { return res_.get(); } + Response *operator->() { return res_.get(); } + + // Error + Error error() const { return err_; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // SSL Error + int ssl_error() const { return ssl_error_; } + // OpenSSL Error + unsigned long ssl_openssl_error() const { return ssl_openssl_error_; } +#endif + + // Request Headers + bool has_request_header(const std::string &key) const; + std::string get_request_header_value(const std::string &key, + const char *def = "", + size_t id = 0) const; + size_t get_request_header_value_u64(const std::string &key, size_t def = 0, + size_t id = 0) const; + size_t get_request_header_value_count(const std::string &key) const; + +private: + std::unique_ptr res_; + Error err_ = Error::Unknown; + Headers request_headers_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + int ssl_error_ = 0; + unsigned long ssl_openssl_error_ = 0; +#endif +}; + +class ClientImpl { +public: + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + // clang-format off + Result Get(const std::string &path, DownloadProgress progress = nullptr); + Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Params ¶ms); + Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Delete(const std::string &path, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Params ¶ms, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const Params ¶ms, DownloadProgress progress = nullptr); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + // clang-format on + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_ipv6_v6only(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_max_timeout(time_t msec); + template + void set_max_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_path_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + void set_ca_cert_store(X509_STORE *ca_cert_store); + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier( + std::function verifier); +#endif + + void set_logger(Logger logger); + +protected: + struct Socket { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + }; + + virtual bool create_and_connect_socket(Socket &socket, Error &error); + + // All of: + // shutdown_ssl + // shutdown_socket + // close_socket + // should ONLY be called when socket_mutex_ is locked. + // Also, shutdown_ssl and close_socket should also NOT be called concurrently + // with a DIFFERENT thread sending requests using that socket. + virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); + void shutdown_socket(Socket &socket) const; + void close_socket(Socket &socket); + + bool process_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + + bool write_content_with_provider(Stream &strm, const Request &req, + Error &error) const; + + void copy_settings(const ClientImpl &rhs); + + // Socket endpoint information + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; + + // These are all protected under socket_mutex + size_t socket_requests_in_flight_ = 0; + std::thread::id socket_requests_are_from_thread_ = std::thread::id(); + bool socket_should_be_closed_when_request_is_done_ = false; + + // Hostname-IP map + std::map addr_map_; + + // Default headers + Headers default_headers_; + + // Header writer + std::function header_writer_ = + detail::write_headers; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + time_t max_timeout_msec_ = CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND; + + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool keep_alive_ = false; + bool follow_location_ = false; + + bool path_encode_ = true; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = nullptr; + + bool compress_ = false; + bool decompress_ = true; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_ = -1; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + + X509_STORE *ca_cert_store_ = nullptr; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::function server_certificate_verifier_; +#endif + + Logger logger_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + int last_ssl_error_ = 0; + unsigned long last_openssl_error_ = 0; +#endif + +private: + bool send_(Request &req, Response &res, Error &error); + Result send_(Request &&req); + + socket_t create_client_socket(Error &error) const; + bool read_response_line(Stream &strm, const Request &req, + Response &res) const; + bool write_request(Stream &strm, Request &req, bool close_connection, + Error &error); + bool redirect(Request &req, Response &res, Error &error); + bool create_redirect_client(const std::string &scheme, + const std::string &host, int port, Request &req, + Response &res, const std::string &path, + const std::string &location, Error &error); + template void setup_redirect_client(ClientType &client); + bool handle_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + std::unique_ptr send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error); + Result send_with_content_provider( + const std::string &method, const std::string &path, + const Headers &headers, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, UploadProgress progress); + ContentProviderWithoutLength get_multipart_content_provider( + const std::string &boundary, const UploadFormDataItems &items, + const FormDataProviderItems &provider_items) const; + + std::string adjust_host_string(const std::string &host) const; + + virtual bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback); + virtual bool is_ssl() const; +}; + +class Client { +public: + // Universal interface + explicit Client(const std::string &scheme_host_port); + + explicit Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + Client(Client &&) = default; + Client &operator=(Client &&) = default; + + ~Client(); + + bool is_valid() const; + + // clang-format off + Result Get(const std::string &path, DownloadProgress progress = nullptr); + Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Params ¶ms); + Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const FormDataProviderItems &provider_items, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); + + Result Delete(const std::string &path, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Params ¶ms, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, DownloadProgress progress = nullptr); + Result Delete(const std::string &path, const Headers &headers, const Params ¶ms, DownloadProgress progress = nullptr); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + // clang-format on + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_max_timeout(time_t msec); + template + void set_max_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_path_encode(bool on); + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier( + std::function verifier); +#endif + + void set_logger(Logger logger); + + // SSL +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; +#endif + +private: + std::unique_ptr cli_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_ = false; +#endif +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { +public: + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr, + const char *private_key_password = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + + SSLServer( + const std::function &setup_ssl_ctx_callback); + + ~SSLServer() override; + + bool is_valid() const override; + + SSL_CTX *ssl_context() const; + + void update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + +private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + int last_ssl_error_ = 0; +#endif +}; + +class SSLClient final : public ClientImpl { +public: + explicit SSLClient(const std::string &host); + + explicit SSLClient(const std::string &host, int port); + + explicit SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password = std::string()); + + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + +private: + bool create_and_connect_socket(Socket &socket, Error &error) override; + void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; + void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); + + bool + process_socket(const Socket &socket, + std::chrono::time_point start_time, + std::function callback) override; + bool is_ssl() const override; + + bool connect_with_proxy( + Socket &sock, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error); + bool initialize_ssl(Socket &socket, Error &error); + + bool load_certs(); + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; + + std::vector host_components_; + + long verify_result_ = 0; + + friend class ClientImpl; +}; +#endif + +/* + * Implementation of template methods. + */ + +namespace detail { + +template +inline void duration_to_sec_and_usec(const T &duration, U callback) { + auto sec = std::chrono::duration_cast(duration).count(); + auto usec = std::chrono::duration_cast( + duration - std::chrono::seconds(sec)) + .count(); + callback(static_cast(sec), static_cast(usec)); +} + +template inline constexpr size_t str_len(const char (&)[N]) { + return N - 1; +} + +inline bool is_numeric(const std::string &str) { + return !str.empty() && + std::all_of(str.cbegin(), str.cend(), + [](unsigned char c) { return std::isdigit(c); }); +} + +inline size_t get_header_value_u64(const Headers &headers, + const std::string &key, size_t def, + size_t id, bool &is_invalid_value) { + is_invalid_value = false; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + if (is_numeric(it->second)) { + return std::strtoull(it->second.data(), nullptr, 10); + } else { + is_invalid_value = true; + } + } + return def; +} + +inline size_t get_header_value_u64(const Headers &headers, + const std::string &key, size_t def, + size_t id) { + bool dummy = false; + return get_header_value_u64(headers, key, def, id, dummy); +} + +} // namespace detail + +inline size_t Request::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +inline size_t Response::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +namespace detail { + +inline bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen) { + return setsockopt(sock, level, optname, +#ifdef _WIN64 + reinterpret_cast(optval), +#else + optval, +#endif + optlen) == 0; +} + +inline bool set_socket_opt(socket_t sock, int level, int optname, int optval) { + return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); +} + +inline bool set_socket_opt_time(socket_t sock, int level, int optname, + time_t sec, time_t usec) { +#ifdef _WIN64 + auto timeout = static_cast(sec * 1000 + usec / 1000); +#else + timeval timeout; + timeout.tv_sec = static_cast(sec); + timeout.tv_usec = static_cast(usec); +#endif + return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); +} + +} // namespace detail + +inline void default_socket_options(socket_t sock) { + detail::set_socket_opt(sock, SOL_SOCKET, +#ifdef SO_REUSEPORT + SO_REUSEPORT, +#else + SO_REUSEADDR, +#endif + 1); +} + +inline const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: return "Continue"; + case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; + case StatusCode::Processing_102: return "Processing"; + case StatusCode::EarlyHints_103: return "Early Hints"; + case StatusCode::OK_200: return "OK"; + case StatusCode::Created_201: return "Created"; + case StatusCode::Accepted_202: return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: return "No Content"; + case StatusCode::ResetContent_205: return "Reset Content"; + case StatusCode::PartialContent_206: return "Partial Content"; + case StatusCode::MultiStatus_207: return "Multi-Status"; + case StatusCode::AlreadyReported_208: return "Already Reported"; + case StatusCode::IMUsed_226: return "IM Used"; + case StatusCode::MultipleChoices_300: return "Multiple Choices"; + case StatusCode::MovedPermanently_301: return "Moved Permanently"; + case StatusCode::Found_302: return "Found"; + case StatusCode::SeeOther_303: return "See Other"; + case StatusCode::NotModified_304: return "Not Modified"; + case StatusCode::UseProxy_305: return "Use Proxy"; + case StatusCode::unused_306: return "unused"; + case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; + case StatusCode::BadRequest_400: return "Bad Request"; + case StatusCode::Unauthorized_401: return "Unauthorized"; + case StatusCode::PaymentRequired_402: return "Payment Required"; + case StatusCode::Forbidden_403: return "Forbidden"; + case StatusCode::NotFound_404: return "Not Found"; + case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: return "Request Timeout"; + case StatusCode::Conflict_409: return "Conflict"; + case StatusCode::Gone_410: return "Gone"; + case StatusCode::LengthRequired_411: return "Length Required"; + case StatusCode::PreconditionFailed_412: return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; + case StatusCode::UriTooLong_414: return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: return "Expectation Failed"; + case StatusCode::ImATeapot_418: return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; + case StatusCode::Locked_423: return "Locked"; + case StatusCode::FailedDependency_424: return "Failed Dependency"; + case StatusCode::TooEarly_425: return "Too Early"; + case StatusCode::UpgradeRequired_426: return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: return "Precondition Required"; + case StatusCode::TooManyRequests_429: return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: return "Not Implemented"; + case StatusCode::BadGateway_502: return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; + case StatusCode::LoopDetected_508: return "Loop Detected"; + case StatusCode::NotExtended_510: return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: return "Internal Server Error"; + } +} + +inline std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); + return req.get_header_value("Authorization") + .substr(bearer_header_prefix_len); + } + return ""; +} + +template +inline Server & +Server::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_idle_interval(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_idle_interval(sec, usec); }); + return *this; +} + +inline std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success (no error)"; + case Error::Connection: return "Could not establish connection"; + case Error::BindIPAddress: return "Failed to bind IP address"; + case Error::Read: return "Failed to read connection"; + case Error::Write: return "Failed to write connection"; + case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; + case Error::Canceled: return "Connection handling canceled"; + case Error::SSLConnection: return "SSL connection failed"; + case Error::SSLLoadingCerts: return "SSL certificate loading failed"; + case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: return "Compression failed"; + case Error::ConnectionTimeout: return "Connection timed out"; + case Error::ProxyConnection: return "Proxy connection failed"; + case Error::Unknown: return "Unknown"; + default: break; + } + + return "Invalid"; +} + +inline std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + +inline size_t Result::get_request_header_value_u64(const std::string &key, + size_t def, + size_t id) const { + return detail::get_header_value_u64(request_headers_, key, def, id); +} + +template +inline void ClientImpl::set_connection_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { + set_connection_timeout(sec, usec); + }); +} + +template +inline void ClientImpl::set_read_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_write_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_max_timeout( + const std::chrono::duration &duration) { + auto msec = + std::chrono::duration_cast(duration).count(); + set_max_timeout(msec); +} + +template +inline void Client::set_connection_timeout( + const std::chrono::duration &duration) { + cli_->set_connection_timeout(duration); +} + +template +inline void +Client::set_read_timeout(const std::chrono::duration &duration) { + cli_->set_read_timeout(duration); +} + +template +inline void +Client::set_write_timeout(const std::chrono::duration &duration) { + cli_->set_write_timeout(duration); +} + +inline void Client::set_max_timeout(time_t msec) { + cli_->set_max_timeout(msec); +} + +template +inline void +Client::set_max_timeout(const std::chrono::duration &duration) { + cli_->set_max_timeout(duration); +} + +/* + * Forward declarations and types that will be part of the .h file if split into + * .h + .cc. + */ + +std::string hosted_at(const std::string &hostname); + +void hosted_at(const std::string &hostname, std::vector &addrs); + +std::string encode_uri_component(const std::string &value); + +std::string encode_uri(const std::string &value); + +std::string decode_uri_component(const std::string &value); + +std::string decode_uri(const std::string &value); + +std::string encode_query_param(const std::string &value); + +std::string append_query_params(const std::string &path, const Params ¶ms); + +std::pair make_range_header(const Ranges &ranges); + +std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false); + +namespace detail { + +#if defined(_WIN64) +inline std::wstring u8string_to_wstring(const char *s) { + std::wstring ws; + auto len = static_cast(strlen(s)); + auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0); + if (wlen > 0) { + ws.resize(wlen); + wlen = ::MultiByteToWideChar( + CP_UTF8, 0, s, len, + const_cast(reinterpret_cast(ws.data())), wlen); + if (wlen != static_cast(ws.size())) { ws.clear(); } + } + return ws; +} +#endif + +struct FileStat { + FileStat(const std::string &path); + bool is_file() const; + bool is_dir() const; + +private: +#if defined(_WIN64) + struct _stat st_; +#else + struct stat st_; +#endif + int ret_ = -1; +}; + +std::string decode_path(const std::string &s, bool convert_plus_to_space); + +std::string trim_copy(const std::string &s); + +void divide( + const char *data, std::size_t size, char d, + std::function + fn); + +void divide( + const std::string &str, char d, + std::function + fn); + +void split(const char *b, const char *e, char d, + std::function fn); + +void split(const char *b, const char *e, char d, size_t m, + std::function fn); + +bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, + std::function callback); + +socket_t create_client_socket(const std::string &host, const std::string &ip, + int port, int address_family, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, + time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + const std::string &intf, Error &error); + +const char *get_header_value(const Headers &headers, const std::string &key, + const char *def, size_t id); + +std::string params_to_query_str(const Params ¶ms); + +void parse_query_text(const char *data, std::size_t size, Params ¶ms); + +void parse_query_text(const std::string &s, Params ¶ms); + +bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary); + +bool parse_range_header(const std::string &s, Ranges &ranges); + +bool parse_accept_header(const std::string &s, + std::vector &content_types); + +int close_socket(socket_t sock); + +ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); + +ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); + +enum class EncodingType { None = 0, Gzip, Brotli, Zstd }; + +EncodingType encoding_type(const Request &req, const Response &res); + +class BufferStream final : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + size_t position = 0; +}; + +class compressor { +public: + virtual ~compressor() = default; + + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, + Callback callback) = 0; +}; + +class decompressor { +public: + virtual ~decompressor() = default; + + virtual bool is_valid() const = 0; + + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, + Callback callback) = 0; +}; + +class nocompressor final : public compressor { +public: + ~nocompressor() override = default; + + bool compress(const char *data, size_t data_length, bool /*last*/, + Callback callback) override; +}; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +class gzip_compressor final : public compressor { +public: + gzip_compressor(); + ~gzip_compressor() override; + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; + +class gzip_decompressor final : public decompressor { +public: + gzip_decompressor(); + ~gzip_decompressor() override; + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +class brotli_compressor final : public compressor { +public: + brotli_compressor(); + ~brotli_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + BrotliEncoderState *state_ = nullptr; +}; + +class brotli_decompressor final : public decompressor { +public: + brotli_decompressor(); + ~brotli_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; +}; +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +class zstd_compressor : public compressor { +public: + zstd_compressor(); + ~zstd_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + ZSTD_CCtx *ctx_ = nullptr; +}; + +class zstd_decompressor : public decompressor { +public: + zstd_decompressor(); + ~zstd_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + ZSTD_DCtx *ctx_ = nullptr; +}; +#endif + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { +public: + stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size); + const char *ptr() const; + size_t size() const; + bool end_with_crlf() const; + bool getline(); + +private: + void append(char c); + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string growable_buffer_; +}; + +class mmap { +public: + mmap(const char *path); + ~mmap(); + + bool open(const char *path); + void close(); + + bool is_open() const; + size_t size() const; + const char *data() const; + +private: +#if defined(_WIN64) + HANDLE hFile_ = NULL; + HANDLE hMapping_ = NULL; +#else + int fd_ = -1; +#endif + size_t size_ = 0; + void *addr_ = nullptr; + bool is_open_empty_file = false; +}; + +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +inline bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; +} + +inline bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +inline bool is_field_name(const std::string &s) { return is_token(s); } + +inline bool is_vchar(char c) { return c >= 33 && c <= 126; } + +inline bool is_obs_text(char c) { return 128 <= static_cast(c); } + +inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +inline bool is_field_content(const std::string &s) { + if (s.empty()) { return true; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +inline bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + +} // namespace detail + +// ---------------------------------------------------------------------------- + +/* + * Implementation that will be part of the .cc file if split into .h + .cc. + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + auto v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + static const auto charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = static_cast(code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + auto val = 0; + auto valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + if (path[i] == '\0') { + return false; + } else if (path[i] == '\\') { + return false; + } + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline FileStat::FileStat(const std::string &path) { +#if defined(_WIN64) + auto wpath = u8string_to_wstring(path.c_str()); + ret_ = _wstat(wpath.c_str(), &st_); +#else + ret_ = stat(path.c_str(), &st_); +#endif +} +inline bool FileStat::is_file() const { + return ret_ >= 0 && S_ISREG(st_.st_mode); +} +inline bool FileStat::is_dir() const { + return ret_ >= 0 && S_ISDIR(st_.st_mode); +} + +inline std::string encode_path(const std::string &s) { + std::string result; + result.reserve(s.size()); + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_path(const std::string &s, + bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + auto val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + auto val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + thread_local auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); +} + +inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } + +inline std::pair trim(const char *b, const char *e, size_t left, + size_t right) { + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); +} + +inline std::string trim_copy(const std::string &s) { + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); +} + +inline std::string trim_double_quotes_copy(const std::string &s) { + if (s.length() >= 2 && s.front() == '"' && s.back() == '"') { + return s.substr(1, s.size() - 2); + } + return s; +} + +inline void +divide(const char *data, std::size_t size, char d, + std::function + fn) { + const auto it = std::find(data, data + size, d); + const auto found = static_cast(it != data + size); + const auto lhs_data = data; + const auto lhs_size = static_cast(it - data); + const auto rhs_data = it + found; + const auto rhs_size = size - lhs_size - found; + + fn(lhs_data, lhs_size, rhs_data, rhs_size); +} + +inline void +divide(const std::string &str, char d, + std::function + fn) { + divide(str.data(), str.size(), d, std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, + std::function fn) { + return split(b, e, d, (std::numeric_limits::max)(), std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, size_t m, + std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + } +} + +inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + +inline const char *stream_line_reader::ptr() const { + if (growable_buffer_.empty()) { + return fixed_buffer_; + } else { + return growable_buffer_.data(); + } +} + +inline size_t stream_line_reader::size() const { + if (growable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return growable_buffer_.size(); + } +} + +inline bool stream_line_reader::end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; +} + +inline bool stream_line_reader::getline() { + fixed_buffer_used_size_ = 0; + growable_buffer_.clear(); + +#ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + char prev_byte = 0; +#endif + + for (size_t i = 0;; i++) { + if (size() >= CPPHTTPLIB_MAX_LINE_LENGTH) { + // Treat exceptionally long lines as an error to + // prevent infinite loops/memory exhaustion + return false; + } + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + if (byte == '\n') { break; } +#else + if (prev_byte == '\r' && byte == '\n') { break; } + prev_byte = byte; +#endif + } + + return true; +} + +inline void stream_line_reader::append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (growable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + growable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + growable_buffer_ += c; + } +} + +inline mmap::mmap(const char *path) { open(path); } + +inline mmap::~mmap() { close(); } + +inline bool mmap::open(const char *path) { + close(); + +#if defined(_WIN64) + auto wpath = u8string_to_wstring(path); + if (wpath.empty()) { return false; } + + hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, + OPEN_EXISTING, NULL); + + if (hFile_ == INVALID_HANDLE_VALUE) { return false; } + + LARGE_INTEGER size{}; + if (!::GetFileSizeEx(hFile_, &size)) { return false; } + // If the following line doesn't compile due to QuadPart, update Windows SDK. + // See: + // https://github.com/yhirose/cpp-httplib/issues/1903#issuecomment-2316520721 + if (static_cast(size.QuadPart) > + (std::numeric_limits::max)()) { + // `size_t` might be 32-bits, on 32-bits Windows. + return false; + } + size_ = static_cast(size.QuadPart); + + hMapping_ = + ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); + + // Special treatment for an empty file... + if (hMapping_ == NULL && size_ == 0) { + close(); + is_open_empty_file = true; + return true; + } + + if (hMapping_ == NULL) { + close(); + return false; + } + + addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0); + + if (addr_ == nullptr) { + close(); + return false; + } +#else + fd_ = ::open(path, O_RDONLY); + if (fd_ == -1) { return false; } + + struct stat sb; + if (fstat(fd_, &sb) == -1) { + close(); + return false; + } + size_ = static_cast(sb.st_size); + + addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); + + // Special treatment for an empty file... + if (addr_ == MAP_FAILED && size_ == 0) { + close(); + is_open_empty_file = true; + return false; + } +#endif + + return true; +} + +inline bool mmap::is_open() const { + return is_open_empty_file ? true : addr_ != nullptr; +} + +inline size_t mmap::size() const { return size_; } + +inline const char *mmap::data() const { + return is_open_empty_file ? "" : static_cast(addr_); +} + +inline void mmap::close() { +#if defined(_WIN64) + if (addr_) { + ::UnmapViewOfFile(addr_); + addr_ = nullptr; + } + + if (hMapping_) { + ::CloseHandle(hMapping_); + hMapping_ = NULL; + } + + if (hFile_ != INVALID_HANDLE_VALUE) { + ::CloseHandle(hFile_); + hFile_ = INVALID_HANDLE_VALUE; + } + + is_open_empty_file = false; +#else + if (addr_ != nullptr) { + munmap(addr_, size_); + addr_ = nullptr; + } + + if (fd_ != -1) { + ::close(fd_); + fd_ = -1; + } +#endif + size_ = 0; +} +inline int close_socket(socket_t sock) { +#ifdef _WIN64 + return closesocket(sock); +#else + return close(sock); +#endif +} + +template inline ssize_t handle_EINTR(T fn) { + ssize_t res = 0; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } + break; + } + return res; +} + +inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return recv(sock, +#ifdef _WIN64 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, + int flags) { + return handle_EINTR([&]() { + return send(sock, +#ifdef _WIN64 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline int poll_wrapper(struct pollfd *fds, nfds_t nfds, int timeout) { +#ifdef _WIN64 + return ::WSAPoll(fds, nfds, timeout); +#else + return ::poll(fds, nfds, timeout); +#endif +} + +template +inline ssize_t select_impl(socket_t sock, time_t sec, time_t usec) { +#ifdef __APPLE__ + if (sock >= FD_SETSIZE) { return -1; } + + fd_set fds, *rfds, *wfds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + rfds = (Read ? &fds : nullptr); + wfds = (Read ? nullptr : &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { + return select(static_cast(sock + 1), rfds, wfds, nullptr, &tv); + }); +#else + struct pollfd pfd; + pfd.fd = sock; + pfd.events = (Read ? POLLIN : POLLOUT); + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll_wrapper(&pfd, 1, timeout); }); +#endif +} + +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { + return select_impl(sock, sec, usec); +} + +inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { + return select_impl(sock, sec, usec); +} + +inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, + time_t usec) { +#ifdef __APPLE__ + if (sock >= FD_SETSIZE) { return Error::Connection; } + + fd_set fdsr, fdsw; + FD_ZERO(&fdsr); + FD_ZERO(&fdsw); + FD_SET(sock, &fdsr); + FD_SET(sock, &fdsw); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { + return select(static_cast(sock + 1), &fdsr, &fdsw, nullptr, &tv); + }); + + if (ret == 0) { return Error::ConnectionTimeout; } + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +#else + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = + handle_EINTR([&]() { return poll_wrapper(&pfd_read, 1, timeout); }); + + if (poll_res == 0) { return Error::ConnectionTimeout; } + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +#endif +} + +inline bool is_socket_alive(socket_t sock) { + const auto val = detail::select_read(sock, 0, 0); + if (val == 0) { + return true; + } else if (val < 0 && errno == EBADF) { + return false; + } + char buf[1]; + return detail::read_socket(sock, &buf[0], sizeof(buf), MSG_PEEK) > 0; +} + +class SocketStream final : public Stream { +public: + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + (std::chrono::steady_clock::time_point::min)()); + ~SocketStream() override; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time_; + + std::vector read_buff_; + size_t read_buff_off_ = 0; + size_t read_buff_content_size_ = 0; + + static const size_t read_buff_size_ = 1024l * 4; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream final : public Stream { +public: + SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + (std::chrono::steady_clock::time_point::min)()); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time_; +}; +#endif + +inline bool keep_alive(const std::atomic &svr_sock, socket_t sock, + time_t keep_alive_timeout_sec) { + using namespace std::chrono; + + const auto interval_usec = + CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND; + + // Avoid expensive `steady_clock::now()` call for the first time + if (select_read(sock, 0, interval_usec) > 0) { return true; } + + const auto start = steady_clock::now() - microseconds{interval_usec}; + const auto timeout = seconds{keep_alive_timeout_sec}; + + while (true) { + if (svr_sock == INVALID_SOCKET) { + break; // Server socket is closed + } + + auto val = select_read(sock, 0, interval_usec); + if (val < 0) { + break; // Ssocket error + } else if (val == 0) { + if (steady_clock::now() - start > timeout) { + break; // Timeout + } + } else { + return true; // Ready for read + } + } + + return false; +} + +template +inline bool +process_server_socket_core(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, T callback) { + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { break; } + count--; + } + return ret; +} + +template +inline bool +process_server_socket(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +inline bool process_client_socket( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, + std::function callback) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); + return callback(strm); +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN64 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +inline std::string escape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '\0') { + auto ret = s; + ret[0] = '@'; + return ret; + } + return s; +} + +inline std::string +unescape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '@') { + auto ret = s; + ret[0] = '\0'; + return ret; + } + return s; +} + +inline int getaddrinfo_with_timeout(const char *node, const char *service, + const struct addrinfo *hints, + struct addrinfo **res, time_t timeout_sec) { +#ifdef CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO + if (timeout_sec <= 0) { + // No timeout specified, use standard getaddrinfo + return getaddrinfo(node, service, hints, res); + } + +#ifdef _WIN64 + // Windows-specific implementation using GetAddrInfoEx with overlapped I/O + OVERLAPPED overlapped = {0}; + HANDLE event = CreateEventW(nullptr, TRUE, FALSE, nullptr); + if (!event) { return EAI_FAIL; } + + overlapped.hEvent = event; + + PADDRINFOEXW result_addrinfo = nullptr; + HANDLE cancel_handle = nullptr; + + ADDRINFOEXW hints_ex = {0}; + if (hints) { + hints_ex.ai_flags = hints->ai_flags; + hints_ex.ai_family = hints->ai_family; + hints_ex.ai_socktype = hints->ai_socktype; + hints_ex.ai_protocol = hints->ai_protocol; + } + + auto wnode = u8string_to_wstring(node); + auto wservice = u8string_to_wstring(service); + + auto ret = ::GetAddrInfoExW(wnode.data(), wservice.data(), NS_DNS, nullptr, + hints ? &hints_ex : nullptr, &result_addrinfo, + nullptr, &overlapped, nullptr, &cancel_handle); + + if (ret == WSA_IO_PENDING) { + auto wait_result = + ::WaitForSingleObject(event, static_cast(timeout_sec * 1000)); + if (wait_result == WAIT_TIMEOUT) { + if (cancel_handle) { ::GetAddrInfoExCancel(&cancel_handle); } + ::CloseHandle(event); + return EAI_AGAIN; + } + + DWORD bytes_returned; + if (!::GetOverlappedResult((HANDLE)INVALID_SOCKET, &overlapped, + &bytes_returned, FALSE)) { + ::CloseHandle(event); + return ::WSAGetLastError(); + } + } + + ::CloseHandle(event); + + if (ret == NO_ERROR || ret == WSA_IO_PENDING) { + *res = reinterpret_cast(result_addrinfo); + return 0; + } + + return ret; +#elif defined(TARGET_OS_OSX) + // macOS implementation using CFHost API for asynchronous DNS resolution + CFStringRef hostname_ref = CFStringCreateWithCString( + kCFAllocatorDefault, node, kCFStringEncodingUTF8); + if (!hostname_ref) { return EAI_MEMORY; } + + CFHostRef host_ref = CFHostCreateWithName(kCFAllocatorDefault, hostname_ref); + CFRelease(hostname_ref); + if (!host_ref) { return EAI_MEMORY; } + + // Set up context for callback + struct CFHostContext { + bool completed = false; + bool success = false; + CFArrayRef addresses = nullptr; + std::mutex mutex; + std::condition_variable cv; + } context; + + CFHostClientContext client_context; + memset(&client_context, 0, sizeof(client_context)); + client_context.info = &context; + + // Set callback + auto callback = [](CFHostRef theHost, CFHostInfoType /*typeInfo*/, + const CFStreamError *error, void *info) { + auto ctx = static_cast(info); + std::lock_guard lock(ctx->mutex); + + if (error && error->error != 0) { + ctx->success = false; + } else { + Boolean hasBeenResolved; + ctx->addresses = CFHostGetAddressing(theHost, &hasBeenResolved); + if (ctx->addresses && hasBeenResolved) { + CFRetain(ctx->addresses); + ctx->success = true; + } else { + ctx->success = false; + } + } + ctx->completed = true; + ctx->cv.notify_one(); + }; + + if (!CFHostSetClient(host_ref, callback, &client_context)) { + CFRelease(host_ref); + return EAI_SYSTEM; + } + + // Schedule on run loop + CFRunLoopRef run_loop = CFRunLoopGetCurrent(); + CFHostScheduleWithRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode); + + // Start resolution + CFStreamError stream_error; + if (!CFHostStartInfoResolution(host_ref, kCFHostAddresses, &stream_error)) { + CFHostUnscheduleFromRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode); + CFRelease(host_ref); + return EAI_FAIL; + } + + // Wait for completion with timeout + auto timeout_time = + std::chrono::steady_clock::now() + std::chrono::seconds(timeout_sec); + bool timed_out = false; + + { + std::unique_lock lock(context.mutex); + + while (!context.completed) { + auto now = std::chrono::steady_clock::now(); + if (now >= timeout_time) { + timed_out = true; + break; + } + + // Run the runloop for a short time + lock.unlock(); + CFRunLoopRunInMode(kCFRunLoopDefaultMode, 0.1, true); + lock.lock(); + } + } + + // Clean up + CFHostUnscheduleFromRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode); + CFHostSetClient(host_ref, nullptr, nullptr); + + if (timed_out || !context.completed) { + CFHostCancelInfoResolution(host_ref, kCFHostAddresses); + CFRelease(host_ref); + return EAI_AGAIN; + } + + if (!context.success || !context.addresses) { + CFRelease(host_ref); + return EAI_NODATA; + } + + // Convert CFArray to addrinfo + CFIndex count = CFArrayGetCount(context.addresses); + if (count == 0) { + CFRelease(context.addresses); + CFRelease(host_ref); + return EAI_NODATA; + } + + struct addrinfo *result_addrinfo = nullptr; + struct addrinfo **current = &result_addrinfo; + + for (CFIndex i = 0; i < count; i++) { + CFDataRef addr_data = + static_cast(CFArrayGetValueAtIndex(context.addresses, i)); + if (!addr_data) continue; + + const struct sockaddr *sockaddr_ptr = + reinterpret_cast(CFDataGetBytePtr(addr_data)); + socklen_t sockaddr_len = static_cast(CFDataGetLength(addr_data)); + + // Allocate addrinfo structure + *current = static_cast(malloc(sizeof(struct addrinfo))); + if (!*current) { + freeaddrinfo(result_addrinfo); + CFRelease(context.addresses); + CFRelease(host_ref); + return EAI_MEMORY; + } + + memset(*current, 0, sizeof(struct addrinfo)); + + // Set up addrinfo fields + (*current)->ai_family = sockaddr_ptr->sa_family; + (*current)->ai_socktype = hints ? hints->ai_socktype : SOCK_STREAM; + (*current)->ai_protocol = hints ? hints->ai_protocol : IPPROTO_TCP; + (*current)->ai_addrlen = sockaddr_len; + + // Copy sockaddr + (*current)->ai_addr = static_cast(malloc(sockaddr_len)); + if (!(*current)->ai_addr) { + freeaddrinfo(result_addrinfo); + CFRelease(context.addresses); + CFRelease(host_ref); + return EAI_MEMORY; + } + memcpy((*current)->ai_addr, sockaddr_ptr, sockaddr_len); + + // Set port if service is specified + if (service && strlen(service) > 0) { + int port = atoi(service); + if (port > 0) { + if (sockaddr_ptr->sa_family == AF_INET) { + reinterpret_cast((*current)->ai_addr) + ->sin_port = htons(static_cast(port)); + } else if (sockaddr_ptr->sa_family == AF_INET6) { + reinterpret_cast((*current)->ai_addr) + ->sin6_port = htons(static_cast(port)); + } + } + } + + current = &((*current)->ai_next); + } + + CFRelease(context.addresses); + CFRelease(host_ref); + + *res = result_addrinfo; + return 0; +#elif defined(_GNU_SOURCE) && defined(__GLIBC__) && \ + (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 2)) + // Linux implementation using getaddrinfo_a for asynchronous DNS resolution + struct gaicb request; + struct gaicb *requests[1] = {&request}; + struct sigevent sevp; + struct timespec timeout; + + // Initialize the request structure + memset(&request, 0, sizeof(request)); + request.ar_name = node; + request.ar_service = service; + request.ar_request = hints; + + // Set up timeout + timeout.tv_sec = timeout_sec; + timeout.tv_nsec = 0; + + // Initialize sigevent structure (not used, but required) + memset(&sevp, 0, sizeof(sevp)); + sevp.sigev_notify = SIGEV_NONE; + + // Start asynchronous resolution + int start_result = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp); + if (start_result != 0) { return start_result; } + + // Wait for completion with timeout + int wait_result = + gai_suspend((const struct gaicb *const *)requests, 1, &timeout); + + if (wait_result == 0) { + // Completed successfully, get the result + int gai_result = gai_error(&request); + if (gai_result == 0) { + *res = request.ar_result; + return 0; + } else { + // Clean up on error + if (request.ar_result) { freeaddrinfo(request.ar_result); } + return gai_result; + } + } else if (wait_result == EAI_AGAIN) { + // Timeout occurred, cancel the request + gai_cancel(&request); + return EAI_AGAIN; + } else { + // Other error occurred + gai_cancel(&request); + return wait_result; + } +#else + // Fallback implementation using thread-based timeout for other Unix systems + std::mutex result_mutex; + std::condition_variable result_cv; + auto completed = false; + auto result = EAI_SYSTEM; + struct addrinfo *result_addrinfo = nullptr; + + std::thread resolve_thread([&]() { + auto thread_result = getaddrinfo(node, service, hints, &result_addrinfo); + + std::lock_guard lock(result_mutex); + result = thread_result; + completed = true; + result_cv.notify_one(); + }); + + // Wait for completion or timeout + std::unique_lock lock(result_mutex); + auto finished = result_cv.wait_for(lock, std::chrono::seconds(timeout_sec), + [&] { return completed; }); + + if (finished) { + // Operation completed within timeout + resolve_thread.join(); + *res = result_addrinfo; + return result; + } else { + // Timeout occurred + resolve_thread.detach(); // Let the thread finish in background + return EAI_AGAIN; // Return timeout error + } +#endif +#else + (void)(timeout_sec); // Unused parameter for non-blocking getaddrinfo + return getaddrinfo(node, service, hints, res); +#endif +} + +template +socket_t create_socket(const std::string &host, const std::string &ip, int port, + int address_family, int socket_flags, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + BindOrConnect bind_or_connect, time_t timeout_sec = 0) { + // Get address info + const char *node = nullptr; + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_IP; + + if (!ip.empty()) { + node = ip.c_str(); + // Ask getaddrinfo to convert IP in c-string to address + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + } else { + if (!host.empty()) { node = host.c_str(); } + hints.ai_family = address_family; + hints.ai_flags = socket_flags; + } + +#if !defined(_WIN64) || defined(CPPHTTPLIB_HAVE_AFUNIX_H) + if (hints.ai_family == AF_UNIX) { + const auto addrlen = host.length(); + if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } + +#ifdef SOCK_CLOEXEC + auto sock = socket(hints.ai_family, hints.ai_socktype | SOCK_CLOEXEC, + hints.ai_protocol); +#else + auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); +#endif + + if (sock != INVALID_SOCKET) { + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + + auto unescaped_host = unescape_abstract_namespace_unix_domain(host); + std::copy(unescaped_host.begin(), unescaped_host.end(), addr.sun_path); + + hints.ai_addr = reinterpret_cast(&addr); + hints.ai_addrlen = static_cast( + sizeof(addr) - sizeof(addr.sun_path) + addrlen); + +#ifndef SOCK_CLOEXEC +#ifndef _WIN64 + fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif +#endif + + if (socket_options) { socket_options(sock); } + +#ifdef _WIN64 + // Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so + // remove the option. + detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); +#endif + + bool dummy; + if (!bind_or_connect(sock, hints, dummy)) { + close_socket(sock); + sock = INVALID_SOCKET; + } + } + return sock; + } +#endif + + auto service = std::to_string(port); + + if (getaddrinfo_with_timeout(node, service.c_str(), &hints, &result, + timeout_sec)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return INVALID_SOCKET; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN64 + auto sock = + WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, + WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + +#ifdef SOCK_CLOEXEC + auto sock = + socket(rp->ai_family, rp->ai_socktype | SOCK_CLOEXEC, rp->ai_protocol); +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + +#endif + if (sock == INVALID_SOCKET) { continue; } + +#if !defined _WIN64 && !defined SOCK_CLOEXEC + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + close_socket(sock); + continue; + } +#endif + + if (tcp_nodelay) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); } + + if (rp->ai_family == AF_INET6) { + set_socket_opt(sock, IPPROTO_IPV6, IPV6_V6ONLY, ipv6_v6only ? 1 : 0); + } + + if (socket_options) { socket_options(sock); } + + // bind or connect + auto quit = false; + if (bind_or_connect(sock, *rp, quit)) { return sock; } + + close_socket(sock); + + if (quit) { break; } + } + + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN64 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN64 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const std::string &host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo_with_timeout(host.c_str(), "0", &hints, &result, 0)) { + return false; + } + + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + return ret; +} + +#if !defined _WIN64 && !defined ANDROID && !defined _AIX && !defined __MVS__ +#define USE_IF2IP +#endif + +#ifdef USE_IF2IP +inline std::string if2ip(int address_family, const std::string &ifn) { + struct ifaddrs *ifap; + getifaddrs(&ifap); + auto se = detail::scope_exit([&] { freeifaddrs(ifap); }); + + std::string addr_candidate; + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name && + (AF_UNSPEC == address_family || + ifa->ifa_addr->sa_family == address_family)) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + return std::string(buf, INET_ADDRSTRLEN); + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + auto sa = reinterpret_cast(ifa->ifa_addr); + if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) { + char buf[INET6_ADDRSTRLEN] = {}; + if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) { + // equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL + auto s6_addr_head = sa->sin6_addr.s6_addr[0]; + if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { + addr_candidate = std::string(buf, INET6_ADDRSTRLEN); + } else { + return std::string(buf, INET6_ADDRSTRLEN); + } + } + } + } + } + } + return addr_candidate; +} +#endif + +inline socket_t create_client_socket( + const std::string &host, const std::string &ip, int port, + int address_family, bool tcp_nodelay, bool ipv6_v6only, + SocketOptions socket_options, time_t connection_timeout_sec, + time_t connection_timeout_usec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, const std::string &intf, Error &error) { + auto sock = create_socket( + host, ip, port, address_family, 0, tcp_nodelay, ipv6_v6only, + std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai, bool &quit) -> bool { + if (!intf.empty()) { +#ifdef USE_IF2IP + auto ip_from_if = if2ip(address_family, intf); + if (ip_from_if.empty()) { ip_from_if = intf; } + if (!bind_ip_address(sock2, ip_from_if)) { + error = Error::BindIPAddress; + return false; + } +#endif + } + + set_nonblocking(sock2, true); + + auto ret = + ::connect(sock2, ai.ai_addr, static_cast(ai.ai_addrlen)); + + if (ret < 0) { + if (is_connection_error()) { + error = Error::Connection; + return false; + } + error = wait_until_socket_is_ready(sock2, connection_timeout_sec, + connection_timeout_usec); + if (error != Error::Success) { + if (error == Error::ConnectionTimeout) { quit = true; } + return false; + } + } + + set_nonblocking(sock2, false); + set_socket_opt_time(sock2, SOL_SOCKET, SO_RCVTIMEO, read_timeout_sec, + read_timeout_usec); + set_socket_opt_time(sock2, SOL_SOCKET, SO_SNDTIMEO, write_timeout_sec, + write_timeout_usec); + + error = Error::Success; + return true; + }, + connection_timeout_sec); // Pass DNS timeout + + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { error = Error::Connection; } + } + + return sock; +} + +inline bool get_ip_and_port(const struct sockaddr_storage &addr, + socklen_t addr_len, std::string &ip, int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return false; + } + + std::array ipstr{}; + if (getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + return false; + } + + ip = ipstr.data(); + return true; +} + +inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (!getsockname(sock, reinterpret_cast(&addr), + &addr_len)) { + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { +#ifndef _WIN64 + if (addr.ss_family == AF_UNIX) { +#if defined(__linux__) + struct ucred ucred; + socklen_t len = sizeof(ucred); + if (getsockopt(sock, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == 0) { + port = ucred.pid; + } +#elif defined(SOL_LOCAL) && defined(SO_PEERPID) + pid_t pid; + socklen_t len = sizeof(pid); + if (getsockopt(sock, SOL_LOCAL, SO_PEERPID, &pid, &len) == 0) { + port = pid; + } +#endif + return; + } +#endif + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline constexpr unsigned int str2tag_core(const char *s, size_t l, + unsigned int h) { + return (l == 0) + ? h + : str2tag_core( + s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(*s)); +} + +inline unsigned int str2tag(const std::string &s) { + return str2tag_core(s.data(), s.size(), 0); +} + +namespace udl { + +inline constexpr unsigned int operator""_t(const char *s, size_t l) { + return str2tag_core(s, l, 0); +} + +} // namespace udl + +inline std::string +find_content_type(const std::string &path, + const std::map &user_data, + const std::string &default_content_type) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second; } + + using udl::operator""_t; + + switch (str2tag(ext)) { + default: return default_content_type; + + case "css"_t: return "text/css"; + case "csv"_t: return "text/csv"; + case "htm"_t: + case "html"_t: return "text/html"; + case "js"_t: + case "mjs"_t: return "text/javascript"; + case "txt"_t: return "text/plain"; + case "vtt"_t: return "text/vtt"; + + case "apng"_t: return "image/apng"; + case "avif"_t: return "image/avif"; + case "bmp"_t: return "image/bmp"; + case "gif"_t: return "image/gif"; + case "png"_t: return "image/png"; + case "svg"_t: return "image/svg+xml"; + case "webp"_t: return "image/webp"; + case "ico"_t: return "image/x-icon"; + case "tif"_t: return "image/tiff"; + case "tiff"_t: return "image/tiff"; + case "jpg"_t: + case "jpeg"_t: return "image/jpeg"; + + case "mp4"_t: return "video/mp4"; + case "mpeg"_t: return "video/mpeg"; + case "webm"_t: return "video/webm"; + + case "mp3"_t: return "audio/mp3"; + case "mpga"_t: return "audio/mpeg"; + case "weba"_t: return "audio/webm"; + case "wav"_t: return "audio/wave"; + + case "otf"_t: return "font/otf"; + case "ttf"_t: return "font/ttf"; + case "woff"_t: return "font/woff"; + case "woff2"_t: return "font/woff2"; + + case "7z"_t: return "application/x-7z-compressed"; + case "atom"_t: return "application/atom+xml"; + case "pdf"_t: return "application/pdf"; + case "json"_t: return "application/json"; + case "rss"_t: return "application/rss+xml"; + case "tar"_t: return "application/x-tar"; + case "xht"_t: + case "xhtml"_t: return "application/xhtml+xml"; + case "xslt"_t: return "application/xslt+xml"; + case "xml"_t: return "application/xml"; + case "gz"_t: return "application/gzip"; + case "zip"_t: return "application/zip"; + case "wasm"_t: return "application/wasm"; + } +} + +inline bool can_compress_content_type(const std::string &content_type) { + using udl::operator""_t; + + auto tag = str2tag(content_type); + + switch (tag) { + case "image/svg+xml"_t: + case "application/javascript"_t: + case "application/json"_t: + case "application/xml"_t: + case "application/protobuf"_t: + case "application/xhtml+xml"_t: return true; + + case "text/event-stream"_t: return false; + + default: return !content_type.rfind("text/", 0); + } +} + +inline EncodingType encoding_type(const Request &req, const Response &res) { + auto ret = + detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { return EncodingType::None; } + + const auto &s = req.get_header_value("Accept-Encoding"); + (void)(s); + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { return EncodingType::Brotli; } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { return EncodingType::Gzip; } +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + // TODO: 'Accept-Encoding' has zstd, not zstd;q=0 + ret = s.find("zstd") != std::string::npos; + if (ret) { return EncodingType::Zstd; } +#endif + + return EncodingType::None; +} + +inline bool nocompressor::compress(const char *data, size_t data_length, + bool /*last*/, Callback callback) { + if (!data_length) { return true; } + return callback(data, data_length); +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline gzip_compressor::gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY) == Z_OK; +} + +inline gzip_compressor::~gzip_compressor() { deflateEnd(&strm_); } + +inline bool gzip_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + assert(is_valid_); + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + auto flush = (last && data_length == 0) ? Z_FINISH : Z_NO_FLUSH; + auto ret = Z_OK; + + std::array buff{}; + do { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = deflate(&strm_, flush); + if (ret == Z_STREAM_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); + + assert((flush == Z_FINISH && ret == Z_STREAM_END) || + (flush == Z_NO_FLUSH && ret == Z_OK)); + assert(strm_.avail_in == 0); + } while (data_length > 0); + + return true; +} + +inline gzip_decompressor::gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; +} + +inline gzip_decompressor::~gzip_decompressor() { inflateEnd(&strm_); } + +inline bool gzip_decompressor::is_valid() const { return is_valid_; } + +inline bool gzip_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + assert(is_valid_); + + auto ret = Z_OK; + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + std::array buff{}; + while (strm_.avail_in > 0 && ret == Z_OK) { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm_, Z_NO_FLUSH); + + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm_); return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } + + if (ret != Z_OK && ret != Z_STREAM_END) { return false; } + + } while (data_length > 0); + + return true; +} +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +inline brotli_compressor::brotli_compressor() { + state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); +} + +inline brotli_compressor::~brotli_compressor() { + BrotliEncoderDestroyInstance(state_); +} + +inline bool brotli_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { break; } + } else { + if (!available_in) { break; } + } + + auto available_out = buff.size(); + auto next_out = buff.data(); + + if (!BrotliEncoderCompressStream(state_, operation, &available_in, &next_in, + &available_out, &next_out, nullptr)) { + return false; + } + + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } + + return true; +} + +inline brotli_decompressor::brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT + : BROTLI_DECODER_RESULT_ERROR; +} + +inline brotli_decompressor::~brotli_decompressor() { + if (decoder_s) { BrotliDecoderDestroyInstance(decoder_s); } +} + +inline bool brotli_decompressor::is_valid() const { return decoder_s; } + +inline bool brotli_decompressor::decompress(const char *data, + size_t data_length, + Callback callback) { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } + + auto next_in = reinterpret_cast(data); + size_t avail_in = data_length; + size_t total_out; + + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); + + decoder_r = BrotliDecoderDecompressStream( + decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); + + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - avail_out)) { return false; } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; +} +#endif + +#ifdef CPPHTTPLIB_ZSTD_SUPPORT +inline zstd_compressor::zstd_compressor() { + ctx_ = ZSTD_createCCtx(); + ZSTD_CCtx_setParameter(ctx_, ZSTD_c_compressionLevel, ZSTD_fast); +} + +inline zstd_compressor::~zstd_compressor() { ZSTD_freeCCtx(ctx_); } + +inline bool zstd_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + ZSTD_EndDirective mode = last ? ZSTD_e_end : ZSTD_e_continue; + ZSTD_inBuffer input = {data, data_length, 0}; + + bool finished; + do { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_compressStream2(ctx_, &output, &input, mode); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + + finished = last ? (remaining == 0) : (input.pos == input.size); + + } while (!finished); + + return true; +} + +inline zstd_decompressor::zstd_decompressor() { ctx_ = ZSTD_createDCtx(); } + +inline zstd_decompressor::~zstd_decompressor() { ZSTD_freeDCtx(ctx_); } + +inline bool zstd_decompressor::is_valid() const { return ctx_ != nullptr; } + +inline bool zstd_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + std::array buff{}; + ZSTD_inBuffer input = {data, data_length, 0}; + + while (input.pos < input.size) { + ZSTD_outBuffer output = {buff.data(), CPPHTTPLIB_COMPRESSION_BUFSIZ, 0}; + size_t const remaining = ZSTD_decompressStream(ctx_, &output, &input); + + if (ZSTD_isError(remaining)) { return false; } + + if (!callback(buff.data(), output.pos)) { return false; } + } + + return true; +} +#endif + +inline bool has_header(const Headers &headers, const std::string &key) { + return headers.find(key) != headers.end(); +} + +inline const char *get_header_value(const Headers &headers, + const std::string &key, const char *def, + size_t id) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second.c_str(); } + return def; +} + +template +inline bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + + if (p == end) { return false; } + + auto key_end = p; + + if (*p++ != ':') { return false; } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { return false; } + + auto key = std::string(beg, key_end); + auto val = std::string(p, end); + + if (!detail::fields::is_field_value(val)) { return false; } + + if (case_ignore::equal(key, "Location") || + case_ignore::equal(key, "Referer")) { + fn(key, val); + } else { + fn(key, decode_path(val, false)); + } + + return true; + } + + return false; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + size_t header_count = 0; + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + auto line_terminator_len = 2; + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } else { +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + // Blank line indicates end of headers. + if (line_reader.size() == 1) { break; } + line_terminator_len = 1; +#else + continue; // Skip invalid line. +#endif + } + + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + + // Check header count limit + if (header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { return false; } + + // Exclude line terminator + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_reader.ptr(), end, + [&](const std::string &key, const std::string &val) { + headers.emplace(key, val); + })) { + return false; + } + + header_count++; + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, size_t len, + DownloadProgress progress, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + size_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, static_cast(n), r, len)) { return false; } + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, size_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + size_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += static_cast(n); + } +} + +enum class ReadContentResult { + Success, // Successfully read the content + PayloadTooLarge, // The content exceeds the specified payload limit + Error // An error occurred while reading the content +}; + +inline ReadContentResult +read_content_without_length(Stream &strm, size_t payload_max_length, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + size_t r = 0; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n == 0) { return ReadContentResult::Success; } + if (n < 0) { return ReadContentResult::Error; } + + // Check if adding this data would exceed the payload limit + if (r > payload_max_length || + payload_max_length - r < static_cast(n)) { + return ReadContentResult::PayloadTooLarge; + } + + if (!out(buf, static_cast(n), r, 0)) { + return ReadContentResult::Error; + } + r += static_cast(n); + } + + return ReadContentResult::Success; +} + +template +inline ReadContentResult read_content_chunked(Stream &strm, T &x, + size_t payload_max_length, + ContentReceiverWithProgress out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return ReadContentResult::Error; } + + unsigned long chunk_len; + size_t total_len = 0; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { return ReadContentResult::Error; } + if (chunk_len == ULONG_MAX) { return ReadContentResult::Error; } + + if (chunk_len == 0) { break; } + + // Check if adding this chunk would exceed the payload limit + if (total_len > payload_max_length || + payload_max_length - total_len < chunk_len) { + return ReadContentResult::PayloadTooLarge; + } + + total_len += chunk_len; + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return ReadContentResult::Error; + } + + if (!line_reader.getline()) { return ReadContentResult::Error; } + + if (strcmp(line_reader.ptr(), "\r\n") != 0) { + return ReadContentResult::Error; + } + + if (!line_reader.getline()) { return ReadContentResult::Error; } + } + + assert(chunk_len == 0); + + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // does't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. + if (!line_reader.getline()) { return ReadContentResult::Success; } + + // RFC 7230 Section 4.1.2 - Headers prohibited in trailers + thread_local case_ignore::unordered_set prohibited_trailers = { + // Message framing + "transfer-encoding", "content-length", + + // Routing + "host", + + // Authentication + "authorization", "www-authenticate", "proxy-authenticate", + "proxy-authorization", "cookie", "set-cookie", + + // Request modifiers + "cache-control", "expect", "max-forwards", "pragma", "range", "te", + + // Response control + "age", "expires", "date", "location", "retry-after", "vary", "warning", + + // Payload processing + "content-encoding", "content-type", "content-range", "trailer"}; + + // Parse declared trailer headers once for performance + case_ignore::unordered_set declared_trailers; + if (has_header(x.headers, "Trailer")) { + auto trailer_header = get_header_value(x.headers, "Trailer", "", 0); + auto len = std::strlen(trailer_header); + + split(trailer_header, trailer_header + len, ',', + [&](const char *b, const char *e) { + std::string key(b, e); + if (prohibited_trailers.find(key) == prohibited_trailers.end()) { + declared_trailers.insert(key); + } + }); + } + + size_t trailer_header_count = 0; + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return ReadContentResult::Error; + } + + // Check trailer header count limit + if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { + return ReadContentResult::Error; + } + + // Exclude line terminator + constexpr auto line_terminator_len = 2; + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, + [&](const std::string &key, const std::string &val) { + if (declared_trailers.find(key) != declared_trailers.end()) { + x.trailers.emplace(key, val); + trailer_header_count++; + } + }); + + if (!line_reader.getline()) { return ReadContentResult::Error; } + } + + return ReadContentResult::Success; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return case_ignore::equal( + get_header_value(headers, "Transfer-Encoding", "", 0), "chunked"); +} + +template +bool prepare_content_receiver(T &x, int &status, + ContentReceiverWithProgress receiver, + bool decompress, U callback) { + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding == "zstd") { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } + + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiverWithProgress out = [&](const char *buf, size_t n, + size_t off, size_t len) { + return decompressor->decompress(buf, n, + [&](const char *buf2, size_t n2) { + return receiver(buf2, n2, off, len); + }); + }; + return callback(std::move(out)); + } else { + status = StatusCode::InternalServerError_500; + return false; + } + } + } + + ContentReceiverWithProgress out = [&](const char *buf, size_t n, size_t off, + size_t len) { + return receiver(buf, n, off, len); + }; + return callback(std::move(out)); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + DownloadProgress progress, + ContentReceiverWithProgress receiver, bool decompress) { + return prepare_content_receiver( + x, status, std::move(receiver), decompress, + [&](const ContentReceiverWithProgress &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + auto result = read_content_chunked(strm, x, payload_max_length, out); + if (result == ReadContentResult::Success) { + ret = true; + } else if (result == ReadContentResult::PayloadTooLarge) { + exceed_payload_max_length = true; + ret = false; + } else { + ret = false; + } + } else if (!has_header(x.headers, "Content-Length")) { + auto result = + read_content_without_length(strm, payload_max_length, out); + if (result == ReadContentResult::Success) { + ret = true; + } else if (result == ReadContentResult::PayloadTooLarge) { + exceed_payload_max_length = true; + ret = false; + } else { + ret = false; + } + } else { + auto is_invalid_value = false; + auto len = get_header_value_u64(x.headers, "Content-Length", + (std::numeric_limits::max)(), + 0, is_invalid_value); + + if (is_invalid_value) { + ret = false; + } else if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, std::move(progress), out); + } + } + + if (!ret) { + status = exceed_payload_max_length ? StatusCode::PayloadTooLarge_413 + : StatusCode::BadRequest_400; + } + return ret; + }); +} + +inline ssize_t write_request_line(Stream &strm, const std::string &method, + const std::string &path) { + std::string s = method; + s += " "; + s += path; + s += " HTTP/1.1\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_response_line(Stream &strm, int status) { + std::string s = "HTTP/1.1 "; + s += std::to_string(status); + s += " "; + s += httplib::status_message(status); + s += "\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_headers(Stream &strm, const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : headers) { + std::string s; + s = x.first; + s += ": "; + s += x.second; + s += "\r\n"; + + auto len = strm.write(s.data(), s.size()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline bool write_data(Stream &strm, const char *d, size_t l) { + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { return false; } + offset += static_cast(length); + } + return true; +} + +template +inline bool write_content_with_progress(Stream &strm, + const ContentProvider &content_provider, + size_t offset, size_t length, + T is_shutting_down, + const UploadProgress &upload_progress, + Error &error) { + size_t end_offset = offset + length; + size_t start_offset = offset; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + if (write_data(strm, d, l)) { + offset += l; + + if (upload_progress && length > 0) { + size_t current_written = offset - start_offset; + if (!upload_progress(current_written, length)) { + ok = false; + return false; + } + } + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + + while (offset < end_offset && !is_shutting_down()) { + if (!strm.wait_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, end_offset - offset, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, T is_shutting_down, + Error &error) { + return write_content_with_progress(strm, content_provider, offset, length, + is_shutting_down, nullptr, error); +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, + const T &is_shutting_down) { + auto error = Error::Success; + return write_content(strm, content_provider, offset, length, is_shutting_down, + error); +} + +template +inline bool +write_content_without_length(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + offset += l; + if (!write_data(strm, d, l)) { ok = false; } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + + data_sink.done = [&](void) { data_available = false; }; + + while (data_available && !is_shutting_down()) { + if (!strm.wait_writable()) { + return false; + } else if (!content_provider(offset, 0, data_sink)) { + return false; + } else if (!ok) { + return false; + } + } + return true; +} + +template +inline bool +write_content_chunked(Stream &strm, const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor, Error &error) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + data_available = l > 0; + offset += l; + + std::string payload; + if (compressor.compress(d, l, false, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = + from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; } + } + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + + auto done_with_trailer = [&](const Headers *trailer) { + if (!ok) { return; } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!write_data(strm, chunk.data(), chunk.size())) { + ok = false; + return; + } + } + + constexpr const char done_marker[] = "0\r\n"; + if (!write_data(strm, done_marker, str_len(done_marker))) { ok = false; } + + // Trailer + if (trailer) { + for (const auto &kv : *trailer) { + std::string field_line = kv.first + ": " + kv.second + "\r\n"; + if (!write_data(strm, field_line.data(), field_line.size())) { + ok = false; + } + } + } + + constexpr const char crlf[] = "\r\n"; + if (!write_data(strm, crlf, str_len(crlf))) { ok = false; } + }; + + data_sink.done = [&](void) { done_with_trailer(nullptr); }; + + data_sink.done_with_trailer = [&](const Headers &trailer) { + done_with_trailer(&trailer); + }; + + while (data_available && !is_shutting_down()) { + if (!strm.wait_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, 0, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content_chunked(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor) { + auto error = Error::Success; + return write_content_chunked(strm, content_provider, is_shutting_down, + compressor, error); +} + +template +inline bool redirect(T &cli, Request &req, Response &res, + const std::string &path, const std::string &location, + Error &error) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count_ -= 1; + + if (res.status == StatusCode::SeeOther_303 && + (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res, error); + if (ret) { + req = new_req; + res = new_res; + + if (res.location.empty()) { res.location = location; } + } + return ret; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += httplib::encode_uri_component(it->second); + } + return query; +} + +inline void parse_query_text(const char *data, std::size_t size, + Params ¶ms) { + std::set cache; + split(data, data + size, '&', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(std::move(kv)); + + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, + std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + params.emplace(decode_path(key, true), decode_path(val, true)); + } + }); +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + parse_query_text(s.data(), s.size(), params); +} + +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto boundary_keyword = "boundary="; + auto pos = content_type.find(boundary_keyword); + if (pos == std::string::npos) { return false; } + auto end = content_type.find(';', pos); + auto beg = pos + strlen(boundary_keyword); + boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg)); + return !boundary.empty(); +} + +inline void parse_disposition_params(const std::string &s, Params ¶ms) { + std::set cache; + split(s.data(), s.data() + s.size(), ';', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(kv); + + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + + if (!key.empty()) { + params.emplace(trim_double_quotes_copy((key)), + trim_double_quotes_copy((val))); + } + }); +} + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +inline bool parse_range_header(const std::string &s, Ranges &ranges) { +#else +inline bool parse_range_header(const std::string &s, Ranges &ranges) try { +#endif + auto is_valid = [](const std::string &str) { + return std::all_of(str.cbegin(), str.cend(), + [](unsigned char c) { return std::isdigit(c); }); + }; + + if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) { + const auto pos = static_cast(6); + const auto len = static_cast(s.size() - 6); + auto all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) { return; } + + const auto it = std::find(b, e, '-'); + if (it == e) { + all_valid_ranges = false; + return; + } + + const auto lhs = std::string(b, it); + const auto rhs = std::string(it + 1, e); + if (!is_valid(lhs) || !is_valid(rhs)) { + all_valid_ranges = false; + return; + } + + const auto first = + static_cast(lhs.empty() ? -1 : std::stoll(lhs)); + const auto last = + static_cast(rhs.empty() ? -1 : std::stoll(rhs)); + if ((first == -1 && last == -1) || + (first != -1 && last != -1 && first > last)) { + all_valid_ranges = false; + return; + } + + ranges.emplace_back(first, last); + }); + return all_valid_ranges && !ranges.empty(); + } + return false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +} +#else +} catch (...) { return false; } +#endif + +inline bool parse_accept_header(const std::string &s, + std::vector &content_types) { + content_types.clear(); + + // Empty string is considered valid (no preference) + if (s.empty()) { return true; } + + // Check for invalid patterns: leading/trailing commas or consecutive commas + if (s.front() == ',' || s.back() == ',' || + s.find(",,") != std::string::npos) { + return false; + } + + struct AcceptEntry { + std::string media_type; + double quality; + int order; // Original order in header + }; + + std::vector entries; + int order = 0; + bool has_invalid_entry = false; + + // Split by comma and parse each entry + split(s.data(), s.data() + s.size(), ',', [&](const char *b, const char *e) { + std::string entry(b, e); + entry = trim_copy(entry); + + if (entry.empty()) { + has_invalid_entry = true; + return; + } + + AcceptEntry accept_entry; + accept_entry.quality = 1.0; // Default quality + accept_entry.order = order++; + + // Find q= parameter + auto q_pos = entry.find(";q="); + if (q_pos == std::string::npos) { q_pos = entry.find("; q="); } + + if (q_pos != std::string::npos) { + // Extract media type (before q parameter) + accept_entry.media_type = trim_copy(entry.substr(0, q_pos)); + + // Extract quality value + auto q_start = entry.find('=', q_pos) + 1; + auto q_end = entry.find(';', q_start); + if (q_end == std::string::npos) { q_end = entry.length(); } + + std::string quality_str = + trim_copy(entry.substr(q_start, q_end - q_start)); + if (quality_str.empty()) { + has_invalid_entry = true; + return; + } + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + { + std::istringstream iss(quality_str); + iss >> accept_entry.quality; + + // Check if conversion was successful and entire string was consumed + if (iss.fail() || !iss.eof()) { + has_invalid_entry = true; + return; + } + } +#else + try { + accept_entry.quality = std::stod(quality_str); + } catch (...) { + has_invalid_entry = true; + return; + } +#endif + // Check if quality is in valid range [0.0, 1.0] + if (accept_entry.quality < 0.0 || accept_entry.quality > 1.0) { + has_invalid_entry = true; + return; + } + } else { + // No quality parameter, use entire entry as media type + accept_entry.media_type = entry; + } + + // Remove additional parameters from media type + auto param_pos = accept_entry.media_type.find(';'); + if (param_pos != std::string::npos) { + accept_entry.media_type = + trim_copy(accept_entry.media_type.substr(0, param_pos)); + } + + // Basic validation of media type format + if (accept_entry.media_type.empty()) { + has_invalid_entry = true; + return; + } + + // Check for basic media type format (should contain '/' or be '*') + if (accept_entry.media_type != "*" && + accept_entry.media_type.find('/') == std::string::npos) { + has_invalid_entry = true; + return; + } + + entries.push_back(accept_entry); + }); + + // Return false if any invalid entry was found + if (has_invalid_entry) { return false; } + + // Sort by quality (descending), then by original order (ascending) + std::sort(entries.begin(), entries.end(), + [](const AcceptEntry &a, const AcceptEntry &b) { + if (a.quality != b.quality) { + return a.quality > b.quality; // Higher quality first + } + return a.order < b.order; // Earlier order first for same quality + }); + + // Extract sorted media types + content_types.reserve(entries.size()); + for (const auto &entry : entries) { + content_types.push_back(entry.media_type); + } + + return true; +} + +class FormDataParser { +public: + FormDataParser() = default; + + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + dash_boundary_crlf_ = dash_ + boundary_ + crlf_; + crlf_dash_boundary_ = crlf_ + dash_ + boundary_; + } + + bool is_valid() const { return is_valid_; } + + bool parse(const char *buf, size_t n, const FormDataHeader &header_callback, + const ContentReceiver &content_callback) { + + buf_append(buf, n); + + while (buf_size() > 0) { + switch (state_) { + case 0: { // Initial boundary + auto pos = buf_find(dash_boundary_crlf_); + if (pos == buf_size()) { return true; } + buf_erase(pos + dash_boundary_crlf_.size()); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_find(crlf_); + if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + while (pos < buf_size()) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_erase(crlf_.size()); + state_ = 3; + break; + } + + const auto header = buf_head(pos); + + if (!parse_header(header.data(), header.data() + header.size(), + [&](const std::string &, const std::string &) {})) { + is_valid_ = false; + return false; + } + + // Parse and emplace space trimmed headers into a map + if (!parse_header( + header.data(), header.data() + header.size(), + [&](const std::string &key, const std::string &val) { + file_.headers.emplace(key, val); + })) { + is_valid_ = false; + return false; + } + + constexpr const char header_content_type[] = "Content-Type:"; + + if (start_with_case_ignore(header, header_content_type)) { + file_.content_type = + trim_copy(header.substr(str_len(header_content_type))); + } else { + thread_local const std::regex re_content_disposition( + R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", + std::regex_constants::icase); + + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + Params params; + parse_disposition_params(m[1], params); + + auto it = params.find("name"); + if (it != params.end()) { + file_.name = it->second; + } else { + is_valid_ = false; + return false; + } + + it = params.find("filename"); + if (it != params.end()) { file_.filename = it->second; } + + it = params.find("filename*"); + if (it != params.end()) { + // Only allow UTF-8 encoding... + thread_local const std::regex re_rfc5987_encoding( + R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); + + std::smatch m2; + if (std::regex_match(it->second, m2, re_rfc5987_encoding)) { + file_.filename = decode_path(m2[1], false); // override... + } else { + is_valid_ = false; + return false; + } + } + } + } + buf_erase(pos + crlf_.size()); + pos = buf_find(crlf_); + } + if (state_ != 3) { return true; } + break; + } + case 3: { // Body + if (crlf_dash_boundary_.size() > buf_size()) { return true; } + auto pos = buf_find(crlf_dash_boundary_); + if (pos < buf_size()) { + if (!content_callback(buf_data(), pos)) { + is_valid_ = false; + return false; + } + buf_erase(pos + crlf_dash_boundary_.size()); + state_ = 4; + } else { + auto len = buf_size() - crlf_dash_boundary_.size(); + if (len > 0) { + if (!content_callback(buf_data(), len)) { + is_valid_ = false; + return false; + } + buf_erase(len); + } + return true; + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_size()) { return true; } + if (buf_start_with(crlf_)) { + buf_erase(crlf_.size()); + state_ = 1; + } else { + if (dash_.size() > buf_size()) { return true; } + if (buf_start_with(dash_)) { + buf_erase(dash_.size()); + is_valid_ = true; + buf_erase(buf_size()); // Remove epilogue + } else { + return true; + } + } + break; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + file_.headers.clear(); + } + + bool start_with_case_ignore(const std::string &a, const char *b) const { + const auto b_len = strlen(b); + if (a.size() < b_len) { return false; } + for (size_t i = 0; i < b_len; i++) { + if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + return false; + } + } + return true; + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + std::string dash_boundary_crlf_; + std::string crlf_dash_boundary_; + + size_t state_ = 0; + bool is_valid_ = false; + FormData file_; + + // Buffer + bool start_with(const std::string &a, size_t spos, size_t epos, + const std::string &b) const { + if (epos - spos < b.size()) { return false; } + for (size_t i = 0; i < b.size(); i++) { + if (a[i + spos] != b[i]) { return false; } + } + return true; + } + + size_t buf_size() const { return buf_epos_ - buf_spos_; } + + const char *buf_data() const { return &buf_[buf_spos_]; } + + std::string buf_head(size_t l) const { return buf_.substr(buf_spos_, l); } + + bool buf_start_with(const std::string &s) const { + return start_with(buf_, buf_spos_, buf_epos_, s); + } + + size_t buf_find(const std::string &s) const { + auto c = s.front(); + + size_t off = buf_spos_; + while (off < buf_epos_) { + auto pos = off; + while (true) { + if (pos == buf_epos_) { return buf_size(); } + if (buf_[pos] == c) { break; } + pos++; + } + + auto remaining_size = buf_epos_ - pos; + if (s.size() > remaining_size) { return buf_size(); } + + if (start_with(buf_, pos, buf_epos_, s)) { return pos - buf_spos_; } + + off = pos + 1; + } + + return buf_size(); + } + + void buf_append(const char *data, size_t n) { + auto remaining_size = buf_size(); + if (remaining_size > 0 && buf_spos_ > 0) { + for (size_t i = 0; i < remaining_size; i++) { + buf_[i] = buf_[buf_spos_ + i]; + } + } + buf_spos_ = 0; + buf_epos_ = remaining_size; + + if (remaining_size + n > buf_.size()) { buf_.resize(remaining_size + n); } + + for (size_t i = 0; i < n; i++) { + buf_[buf_epos_ + i] = data[i]; + } + buf_epos_ += n; + } + + void buf_erase(size_t size) { buf_spos_ += size; } + + std::string buf_; + size_t buf_spos_ = 0; + size_t buf_epos_ = 0; +}; + +inline std::string random_string(size_t length) { + constexpr const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + thread_local auto engine([]() { + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + std::random_device seed_gen; + // Request 128 bits of entropy for initialization + std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + return std::mt19937(seed_sequence); + }()); + + std::string result; + for (size_t i = 0; i < length; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + return result; +} + +inline std::string make_multipart_data_boundary() { + return "--cpp-httplib-multipart-data-" + detail::random_string(16); +} + +inline bool is_multipart_boundary_chars_valid(const std::string &boundary) { + auto valid = true; + for (size_t i = 0; i < boundary.size(); i++) { + auto c = boundary[i]; + if (!std::isalnum(c) && c != '-' && c != '_') { + valid = false; + break; + } + } + return valid; +} + +template +inline std::string +serialize_multipart_formdata_item_begin(const T &item, + const std::string &boundary) { + std::string body = "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + + return body; +} + +inline std::string serialize_multipart_formdata_item_end() { return "\r\n"; } + +inline std::string +serialize_multipart_formdata_finish(const std::string &boundary) { + return "--" + boundary + "--\r\n"; +} + +inline std::string +serialize_multipart_formdata_get_content_type(const std::string &boundary) { + return "multipart/form-data; boundary=" + boundary; +} + +inline std::string +serialize_multipart_formdata(const UploadFormDataItems &items, + const std::string &boundary, bool finish = true) { + std::string body; + + for (const auto &item : items) { + body += serialize_multipart_formdata_item_begin(item, boundary); + body += item.content + serialize_multipart_formdata_item_end(); + } + + if (finish) { body += serialize_multipart_formdata_finish(boundary); } + + return body; +} + +inline void coalesce_ranges(Ranges &ranges, size_t content_length) { + if (ranges.size() <= 1) return; + + // Sort ranges by start position + std::sort(ranges.begin(), ranges.end(), + [](const Range &a, const Range &b) { return a.first < b.first; }); + + Ranges coalesced; + coalesced.reserve(ranges.size()); + + for (auto &r : ranges) { + auto first_pos = r.first; + auto last_pos = r.second; + + // Handle special cases like in range_error + if (first_pos == -1 && last_pos == -1) { + first_pos = 0; + last_pos = static_cast(content_length); + } + + if (first_pos == -1) { + first_pos = static_cast(content_length) - last_pos; + last_pos = static_cast(content_length) - 1; + } + + if (last_pos == -1 || last_pos >= static_cast(content_length)) { + last_pos = static_cast(content_length) - 1; + } + + // Skip invalid ranges + if (!(0 <= first_pos && first_pos <= last_pos && + last_pos < static_cast(content_length))) { + continue; + } + + // Coalesce with previous range if overlapping or adjacent (but not + // identical) + if (!coalesced.empty()) { + auto &prev = coalesced.back(); + // Check if current range overlaps or is adjacent to previous range + // but don't coalesce identical ranges (allow duplicates) + if (first_pos <= prev.second + 1 && + !(first_pos == prev.first && last_pos == prev.second)) { + // Extend the previous range + prev.second = (std::max)(prev.second, last_pos); + continue; + } + } + + // Add new range + coalesced.emplace_back(first_pos, last_pos); + } + + ranges = std::move(coalesced); +} + +inline bool range_error(Request &req, Response &res) { + if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { + ssize_t content_len = static_cast( + res.content_length_ ? res.content_length_ : res.body.size()); + + std::vector> processed_ranges; + size_t overwrapping_count = 0; + + // NOTE: The following Range check is based on '14.2. Range' in RFC 9110 + // 'HTTP Semantics' to avoid potential denial-of-service attacks. + // https://www.rfc-editor.org/rfc/rfc9110#section-14.2 + + // Too many ranges + if (req.ranges.size() > CPPHTTPLIB_RANGE_MAX_COUNT) { return true; } + + for (auto &r : req.ranges) { + auto &first_pos = r.first; + auto &last_pos = r.second; + + if (first_pos == -1 && last_pos == -1) { + first_pos = 0; + last_pos = content_len; + } + + if (first_pos == -1) { + first_pos = content_len - last_pos; + last_pos = content_len - 1; + } + + // NOTE: RFC-9110 '14.1.2. Byte Ranges': + // A client can limit the number of bytes requested without knowing the + // size of the selected representation. If the last-pos value is absent, + // or if the value is greater than or equal to the current length of the + // representation data, the byte range is interpreted as the remainder of + // the representation (i.e., the server replaces the value of last-pos + // with a value that is one less than the current length of the selected + // representation). + // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6 + if (last_pos == -1 || last_pos >= content_len) { + last_pos = content_len - 1; + } + + // Range must be within content length + if (!(0 <= first_pos && first_pos <= last_pos && + last_pos <= content_len - 1)) { + return true; + } + + // Request must not have more than two overlapping ranges + for (const auto &processed_range : processed_ranges) { + if (!(last_pos < processed_range.first || + first_pos > processed_range.second)) { + overwrapping_count++; + if (overwrapping_count > 2) { return true; } + break; // Only count once per range + } + } + + processed_ranges.emplace_back(first_pos, last_pos); + } + + // After validation, coalesce overlapping ranges as per RFC 9110 + coalesce_ranges(req.ranges, static_cast(content_len)); + } + + return false; +} + +inline std::pair +get_range_offset_and_length(Range r, size_t content_length) { + assert(r.first != -1 && r.second != -1); + assert(0 <= r.first && r.first < static_cast(content_length)); + assert(r.first <= r.second && + r.second < static_cast(content_length)); + (void)(content_length); + return std::make_pair(r.first, static_cast(r.second - r.first) + 1); +} + +inline std::string make_content_range_header_field( + const std::pair &offset_and_length, size_t content_length) { + auto st = offset_and_length.first; + auto ed = st + offset_and_length.second - 1; + + std::string field = "bytes "; + field += std::to_string(st); + field += "-"; + field += std::to_string(ed); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length, SToken stoken, + CToken ctoken, Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offset_and_length = + get_range_offset_and_length(req.ranges[i], content_length); + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset_and_length, content_length)); + ctoken("\r\n"); + ctoken("\r\n"); + + if (!content(offset_and_length.first, offset_and_length.second)) { + return false; + } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--"); + + return true; +} + +inline void make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, + std::string &data) { + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data += token; }, + [&](const std::string &token) { data += token; }, + [&](size_t offset, size_t length) { + assert(offset + length <= content_length); + data += res.body.substr(offset, length); + return true; + }); +} + +inline size_t get_multipart_ranges_data_length(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data_length += token.size(); }, + [&](const std::string &token) { data_length += token.size(); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +template +inline bool +write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, const T &is_shutting_down) { + return process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { strm.write(token); }, + [&](const std::string &token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, + is_shutting_down); + }); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "DELETE") { + return true; + } + if (req.has_header("Content-Length") && + req.get_header_value_u64("Content-Length") > 0) { + return true; + } + if (is_chunked_transfer_encoding(req.headers)) { return true; } + return false; +} + +inline bool has_crlf(const std::string &s) { + auto p = s.c_str(); + while (*p) { + if (*p == '\r' || *p == '\n') { return true; } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest(s, EVP_md5()); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, EVP_sha256()); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, EVP_sha512()); +} + +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +inline bool is_ssl_peer_could_be_closed(SSL *ssl, socket_t sock) { + detail::set_nonblocking(sock, true); + auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + char buf[1]; + return !SSL_peek(ssl, buf, 1) && + SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} + +#ifdef _WIN64 +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store +inline bool load_system_certs_on_windows(X509_STORE *store) { + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + if (!hStore) { return false; } + + auto result = false; + PCCERT_CONTEXT pContext = NULL; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + auto encoded_cert = + static_cast(pContext->pbCertEncoded); + + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + + return result; +} +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && \ + defined(TARGET_OS_OSX) +template +using CFObjectPtr = + std::unique_ptr::type, void (*)(CFTypeRef)>; + +inline void cf_object_ptr_deleter(CFTypeRef obj) { + if (obj) { CFRelease(obj); } +} + +inline bool retrieve_certs_from_keychain(CFObjectPtr &certs) { + CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; + CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, + kCFBooleanTrue}; + + CFObjectPtr query( + CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, + sizeof(keys) / sizeof(keys[0]), + &kCFTypeDictionaryKeyCallBacks, + &kCFTypeDictionaryValueCallBacks), + cf_object_ptr_deleter); + + if (!query) { return false; } + + CFTypeRef security_items = nullptr; + if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || + CFArrayGetTypeID() != CFGetTypeID(security_items)) { + return false; + } + + certs.reset(reinterpret_cast(security_items)); + return true; +} + +inline bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { + CFArrayRef root_security_items = nullptr; + if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { + return false; + } + + certs.reset(root_security_items); + return true; +} + +inline bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { + auto result = false; + for (auto i = 0; i < CFArrayGetCount(certs); ++i) { + const auto cert = reinterpret_cast( + CFArrayGetValueAtIndex(certs, i)); + + if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { continue; } + + CFDataRef cert_data = nullptr; + if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != + errSecSuccess) { + continue; + } + + CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); + + auto encoded_cert = static_cast( + CFDataGetBytePtr(cert_data_ptr.get())); + + auto x509 = + d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); + + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + return result; +} + +inline bool load_system_certs_on_macos(X509_STORE *store) { + auto result = false; + CFObjectPtr certs(nullptr, cf_object_ptr_deleter); + if (retrieve_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store); + } + + if (retrieve_root_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store) || result; + } + + return result; +} +#endif // _WIN64 +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef _WIN64 +class WSInit { +public: + WSInit() { + WSADATA wsaData; + if (WSAStartup(0x0002, &wsaData) == 0) is_valid_ = true; + } + + ~WSInit() { + if (is_valid_) WSACleanup(); + } + + bool is_valid_ = false; +}; + +static WSInit wsinit_; +#endif + +inline bool parse_www_authenticate(const Response &res, + std::map &auth, + bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + thread_local auto re = + std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + const auto &m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +class ContentProviderAdapter { +public: + explicit ContentProviderAdapter( + ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) {} + + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); + } + +private: + ContentProviderWithoutLength content_provider_; +}; + +} // namespace detail + +inline std::string hosted_at(const std::string &hostname) { + std::vector addrs; + hosted_at(hostname, addrs); + if (addrs.empty()) { return std::string(); } + return addrs[0]; +} + +inline void hosted_at(const std::string &hostname, + std::vector &addrs) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (detail::getaddrinfo_with_timeout(hostname.c_str(), nullptr, &hints, + &result, 0)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &addr = + *reinterpret_cast(rp->ai_addr); + std::string ip; + auto dummy = -1; + if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, + dummy)) { + addrs.push_back(ip); + } + } +} + +inline std::string encode_uri_component(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || + c == ')') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) + << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string encode_uri(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || + c == ')' || c == ';' || c == '/' || c == '?' || c == ':' || c == '@' || + c == '&' || c == '=' || c == '+' || c == '$' || c == ',' || c == '#') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) + << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string decode_uri_component(const std::string &value) { + std::string result; + + for (size_t i = 0; i < value.size(); i++) { + if (value[i] == '%' && i + 2 < value.size()) { + auto val = 0; + if (detail::from_hex_to_i(value, i + 1, 2, val)) { + result += static_cast(val); + i += 2; + } else { + result += value[i]; + } + } else { + result += value[i]; + } + } + + return result; +} + +inline std::string decode_uri(const std::string &value) { + std::string result; + + for (size_t i = 0; i < value.size(); i++) { + if (value[i] == '%' && i + 2 < value.size()) { + auto val = 0; + if (detail::from_hex_to_i(value, i + 1, 2, val)) { + result += static_cast(val); + i += 2; + } else { + result += value[i]; + } + } else { + result += value[i]; + } + } + + return result; +} + +[[deprecated("Use encode_uri_component instead")]] +inline std::string encode_query_param(const std::string &value) { + return encode_uri_component(value); +} + +inline std::string append_query_params(const std::string &path, + const Params ¶ms) { + std::string path_with_query = path; + thread_local const std::regex re("[^?]+\\?.*"); + auto delm = std::regex_match(path, re) ? '&' : '?'; + path_with_query += delm + detail::params_to_query_str(params); + return path_with_query; +} + +// Header utilities +inline std::pair +make_range_header(const Ranges &ranges) { + std::string field = "bytes="; + auto i = 0; + for (const auto &r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", std::move(field)); +} + +inline std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, bool is_proxy) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +inline std::pair +make_bearer_token_authentication_header(const std::string &token, + bool is_proxy = false) { + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +// Request implementation +inline bool Request::has_header(const std::string &key) const { + return detail::has_header(headers, key); +} + +inline std::string Request::get_header_value(const std::string &key, + const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Request::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const std::string &key, + const std::string &val) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { + headers.emplace(key, val); + } +} + +inline bool Request::has_trailer(const std::string &key) const { + return trailers.find(key) != trailers.end(); +} + +inline std::string Request::get_trailer_value(const std::string &key, + size_t id) const { + auto rng = trailers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return std::string(); +} + +inline size_t Request::get_trailer_value_count(const std::string &key) const { + auto r = trailers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::has_param(const std::string &key) const { + return params.find(key) != params.end(); +} + +inline std::string Request::get_param_value(const std::string &key, + size_t id) const { + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const std::string &key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.rfind("multipart/form-data", 0); +} + +// Multipart FormData implementation +inline std::string MultipartFormData::get_field(const std::string &key, + size_t id) const { + auto rng = fields.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second.content; } + return std::string(); +} + +inline std::vector +MultipartFormData::get_fields(const std::string &key) const { + std::vector values; + auto rng = fields.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second.content); + } + return values; +} + +inline bool MultipartFormData::has_field(const std::string &key) const { + return fields.find(key) != fields.end(); +} + +inline size_t MultipartFormData::get_field_count(const std::string &key) const { + auto r = fields.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline FormData MultipartFormData::get_file(const std::string &key, + size_t id) const { + auto rng = files.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return FormData(); +} + +inline std::vector +MultipartFormData::get_files(const std::string &key) const { + std::vector values; + auto rng = files.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second); + } + return values; +} + +inline bool MultipartFormData::has_file(const std::string &key) const { + return files.find(key) != files.end(); +} + +inline size_t MultipartFormData::get_file_count(const std::string &key) const { + auto r = files.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Response implementation +inline bool Response::has_header(const std::string &key) const { + return headers.find(key) != headers.end(); +} + +inline std::string Response::get_header_value(const std::string &key, + const char *def, + size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Response::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const std::string &key, + const std::string &val) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { + headers.emplace(key, val); + } +} +inline bool Response::has_trailer(const std::string &key) const { + return trailers.find(key) != trailers.end(); +} + +inline std::string Response::get_trailer_value(const std::string &key, + size_t id) const { + auto rng = trailers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return std::string(); +} + +inline size_t Response::get_trailer_value_count(const std::string &key) const { + auto r = trailers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_redirect(const std::string &url, int stat) { + if (detail::fields::is_field_value(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = StatusCode::Found_302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, + const std::string &content_type) { + body.assign(s, n); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, + const std::string &content_type) { + set_content(s.data(), s.size(), content_type); +} + +inline void Response::set_content(std::string &&s, + const std::string &content_type) { + body = std::move(s); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t in_length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = in_length; + if (in_length > 0) { content_provider_ = std::move(provider); } + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = true; +} + +inline void Response::set_file_content(const std::string &path, + const std::string &content_type) { + file_content_path_ = path; + file_content_content_type_ = content_type; +} + +inline void Response::set_file_content(const std::string &path) { + file_content_path_ = path; +} + +// Result implementation +inline bool Result::has_request_header(const std::string &key) const { + return request_headers_.find(key) != request_headers_.end(); +} + +inline std::string Result::get_request_header_value(const std::string &key, + const char *def, + size_t id) const { + return detail::get_header_value(request_headers_, key, def, id); +} + +inline size_t +Result::get_request_header_value_count(const std::string &key) const { + auto r = request_headers_.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Stream implementation +inline ssize_t Stream::write(const char *ptr) { + return write(ptr, strlen(ptr)); +} + +inline ssize_t Stream::write(const std::string &s) { + return write(s.data(), s.size()); +} + +namespace detail { + +inline void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, + time_t timeout_sec, time_t timeout_usec, + time_t &actual_timeout_sec, + time_t &actual_timeout_usec) { + auto timeout_msec = (timeout_sec * 1000) + (timeout_usec / 1000); + + auto actual_timeout_msec = + (std::min)(max_timeout_msec - duration_msec, timeout_msec); + + if (actual_timeout_msec < 0) { actual_timeout_msec = 0; } + + actual_timeout_sec = actual_timeout_msec / 1000; + actual_timeout_usec = (actual_timeout_msec % 1000) * 1000; +} + +// Socket stream implementation +inline SocketStream::SocketStream( + socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time_(start_time), + read_buff_(read_buff_size_, 0) {} + +inline SocketStream::~SocketStream() = default; + +inline bool SocketStream::is_readable() const { + return read_buff_off_ < read_buff_content_size_; +} + +inline bool SocketStream::wait_readable() const { + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; +} + +inline bool SocketStream::wait_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_); +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { +#ifdef _WIN64 + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#else + size = (std::min)(size, + static_cast((std::numeric_limits::max)())); +#endif + + if (read_buff_off_ < read_buff_content_size_) { + auto remaining_size = read_buff_content_size_ - read_buff_off_; + if (size <= remaining_size) { + memcpy(ptr, read_buff_.data() + read_buff_off_, size); + read_buff_off_ += size; + return static_cast(size); + } else { + memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size); + read_buff_off_ += remaining_size; + return static_cast(remaining_size); + } + } + + if (!wait_readable()) { return -1; } + + read_buff_off_ = 0; + read_buff_content_size_ = 0; + + if (size < read_buff_size_) { + auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, + CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + return n; + } else if (n <= static_cast(size)) { + memcpy(ptr, read_buff_.data(), static_cast(n)); + return n; + } else { + memcpy(ptr, read_buff_.data(), size); + read_buff_off_ = size; + read_buff_content_size_ = static_cast(n); + return static_cast(size); + } + } else { + return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + } +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!wait_writable()) { return -1; } + +#if defined(_WIN64) && !defined(_WIN64) + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + return detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SocketStream::socket() const { return sock_; } + +inline time_t SocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time_) + .count(); +} + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::wait_readable() const { return true; } + +inline bool BufferStream::wait_writable() const { return true; } + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1910 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline socket_t BufferStream::socket() const { return 0; } + +inline time_t BufferStream::duration() const { return 0; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) + : MatcherBase(pattern) { + constexpr const char marker[] = "/:"; + + // One past the last ending position of a path param substring + std::size_t last_param_end = 0; + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + // Needed to ensure that parameter names are unique during matcher + // construction + // If exceptions are disabled, only last duplicate path + // parameter will be set + std::unordered_set param_name_set; +#endif + + while (true) { + const auto marker_pos = pattern.find( + marker, last_param_end == 0 ? last_param_end : last_param_end - 1); + if (marker_pos == std::string::npos) { break; } + + static_fragments_.push_back( + pattern.substr(last_param_end, marker_pos - last_param_end + 1)); + + const auto param_name_start = marker_pos + str_len(marker); + + auto sep_pos = pattern.find(separator, param_name_start); + if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } + + auto param_name = + pattern.substr(param_name_start, sep_pos - param_name_start); + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (param_name_set.find(param_name) != param_name_set.cend()) { + std::string msg = "Encountered path parameter '" + param_name + + "' multiple times in route pattern '" + pattern + "'."; + throw std::invalid_argument(msg); + } +#endif + + param_names_.push_back(std::move(param_name)); + + last_param_end = sep_pos + 1; + } + + if (last_param_end < pattern.length()) { + static_fragments_.push_back(pattern.substr(last_param_end)); + } +} + +inline bool PathParamsMatcher::match(Request &request) const { + request.matches = std::smatch(); + request.path_params.clear(); + request.path_params.reserve(param_names_.size()); + + // One past the position at which the path matched the pattern last time + std::size_t starting_pos = 0; + for (size_t i = 0; i < static_fragments_.size(); ++i) { + const auto &fragment = static_fragments_[i]; + + if (starting_pos + fragment.length() > request.path.length()) { + return false; + } + + // Avoid unnecessary allocation by using strncmp instead of substr + + // comparison + if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(), + fragment.length()) != 0) { + return false; + } + + starting_pos += fragment.length(); + + // Should only happen when we have a static fragment after a param + // Example: '/users/:id/subscriptions' + // The 'subscriptions' fragment here does not have a corresponding param + if (i >= param_names_.size()) { continue; } + + auto sep_pos = request.path.find(separator, starting_pos); + if (sep_pos == std::string::npos) { sep_pos = request.path.length(); } + + const auto ¶m_name = param_names_[i]; + + request.path_params.emplace( + param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); + + // Mark everything up to '/' as matched + starting_pos = sep_pos + 1; + } + // Returns false if the path is longer than the pattern + return starting_pos >= request.path.length(); +} + +inline bool RegexMatcher::match(Request &request) const { + request.path_params.clear(); + return std::regex_match(request.path, request.matches, regex_); +} + +} // namespace detail + +// HTTP server implementation +inline Server::Server() + : new_task_queue( + [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { +#ifndef _WIN64 + signal(SIGPIPE, SIG_IGN); +#endif +} + +inline Server::~Server() = default; + +inline std::unique_ptr +Server::make_matcher(const std::string &pattern) { + if (pattern.find("/:") != std::string::npos) { + return detail::make_unique(pattern); + } else { + return detail::make_unique(pattern); + } +} + +inline Server &Server::Get(const std::string &pattern, Handler handler) { + get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, Handler handler) { + post_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, Handler handler) { + put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, Handler handler) { + patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, Handler handler) { + delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, + HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Options(const std::string &pattern, Handler handler) { + options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline bool Server::set_base_dir(const std::string &dir, + const std::string &mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const std::string &mount_point, + const std::string &dir, Headers headers) { + detail::FileStat stat(dir); + if (stat.is_dir()) { + std::string mnt = !mount_point.empty() ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.push_back({mnt, dir, std::move(headers)}); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const std::string &mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->mount_point == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline Server & +Server::set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime) { + file_extension_and_mimetype_map_[ext] = mime; + return *this; +} + +inline Server &Server::set_default_file_mimetype(const std::string &mime) { + default_file_mimetype_ = mime; + return *this; +} + +inline Server &Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(HandlerWithResponse handler, + std::true_type) { + error_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(Handler handler, + std::false_type) { + error_handler_ = [handler](const Request &req, Response &res) { + handler(req, res); + return HandlerResponse::Handled; + }; + return *this; +} + +inline Server &Server::set_exception_handler(ExceptionHandler handler) { + exception_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) { + pre_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_post_routing_handler(Handler handler) { + post_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_request_handler(HandlerWithResponse handler) { + pre_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_logger(Logger logger) { + logger_ = std::move(logger); + return *this; +} + +inline Server &Server::set_pre_compression_logger(Logger logger) { + pre_compression_logger_ = std::move(logger); + return *this; +} + +inline Server & +Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_address_family(int family) { + address_family_ = family; + return *this; +} + +inline Server &Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; + return *this; +} + +inline Server &Server::set_ipv6_v6only(bool on) { + ipv6_v6only_ = on; + return *this; +} + +inline Server &Server::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); + return *this; +} + +inline Server &Server::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); + return *this; +} + +inline Server &Server::set_header_writer( + std::function const &writer) { + header_writer_ = writer; + return *this; +} + +inline Server &Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + return *this; +} + +inline Server &Server::set_keep_alive_timeout(time_t sec) { + keep_alive_timeout_sec_ = sec; + return *this; +} + +inline Server &Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_idle_interval(time_t sec, time_t usec) { + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; + return *this; +} + +inline Server &Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + return *this; +} + +inline bool Server::bind_to_port(const std::string &host, int port, + int socket_flags) { + auto ret = bind_internal(host, port, socket_flags); + if (ret == -1) { is_decommissioned = true; } + return ret >= 0; +} +inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { + auto ret = bind_internal(host, 0, socket_flags); + if (ret == -1) { is_decommissioned = true; } + return ret; +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const std::string &host, int port, + int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::wait_until_ready() const { + while (!is_running_ && !is_decommissioned) { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } +} + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + is_decommissioned = false; +} + +inline void Server::decommission() { is_decommissioned = true; } + +inline bool Server::parse_request_line(const char *s, Request &req) const { + auto len = strlen(s); + if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; } + len -= 2; + + { + size_t count = 0; + + detail::split(s, s + len, ' ', [&](const char *b, const char *e) { + switch (count) { + case 0: req.method = std::string(b, e); break; + case 1: req.target = std::string(b, e); break; + case 2: req.version = std::string(b, e); break; + default: break; + } + count++; + }); + + if (count != 3) { return false; } + } + + thread_local const std::set methods{ + "GET", "HEAD", "POST", "PUT", "DELETE", + "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; + + if (methods.find(req.method) == methods.end()) { return false; } + + if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { return false; } + + { + // Skip URL fragment + for (size_t i = 0; i < req.target.size(); i++) { + if (req.target[i] == '#') { + req.target.erase(i); + break; + } + } + + detail::divide(req.target, '?', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + req.path = detail::decode_path( + std::string(lhs_data, lhs_size), false); + detail::parse_query_text(rhs_data, rhs_size, req.params); + }); + } + + return true; +} + +inline bool Server::write_response(Stream &strm, bool close_connection, + Request &req, Response &res) { + // NOTE: `req.ranges` should be empty, otherwise it will be applied + // incorrectly to the error content. + req.ranges.clear(); + return write_response_core(strm, close_connection, req, res, false); +} + +inline bool Server::write_response_with_content(Stream &strm, + bool close_connection, + const Request &req, + Response &res) { + return write_response_core(strm, close_connection, req, res, true); +} + +inline bool Server::write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_ && + error_handler_(req, res) == HandlerResponse::Handled) { + need_apply_ranges = true; + } + + std::string content_type; + std::string boundary; + if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } + + // Prepare additional headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::string s = "timeout="; + s += std::to_string(keep_alive_timeout_sec_); + s += ", max="; + s += std::to_string(keep_alive_max_count_); + res.set_header("Keep-Alive", s); + } + + if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) && + !res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } + + if (res.body.empty() && !res.content_length_ && !res.content_provider_ && + !res.has_header("Content-Length")) { + res.set_header("Content-Length", "0"); + } + + if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } + + if (post_routing_handler_) { post_routing_handler_(req, res); } + + // Response line and headers + { + detail::BufferStream bstrm; + if (!detail::write_response_line(bstrm, res.status)) { return false; } + if (!header_writer_(bstrm, res.headers)) { return false; } + + // Flush buffer + auto &data = bstrm.get_buffer(); + detail::write_data(strm, data.data(), data.size()); + } + + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!detail::write_data(strm, res.body.data(), res.body.size())) { + ret = false; + } + } else if (res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return ret; +} + +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + return detail::write_content(strm, res.content_provider_, 0, + res.content_length_, is_shutting_down); + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + return detail::write_content(strm, res.content_provider_, + offset_and_length.first, + offset_and_length.second, is_shutting_down); + } else { + return detail::write_multipart_ranges_data( + strm, req, res, boundary, content_type, res.content_length_, + is_shutting_down); + } + } else { + if (res.is_chunked_content_provider_) { + auto type = detail::encoding_type(req, res); + + std::unique_ptr compressor; + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); +#endif + } else { + compressor = detail::make_unique(); + } + assert(compressor != nullptr); + + return detail::write_content_chunked(strm, res.content_provider_, + is_shutting_down, *compressor); + } else { + return detail::write_content_without_length(strm, res.content_provider_, + is_shutting_down); + } + } +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + FormFields::iterator cur_field; + FormFiles::iterator cur_file; + auto is_text_field = false; + size_t count = 0; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart FormData + [&](const FormData &file) { + if (count++ == CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT) { + return false; + } + + if (file.filename.empty()) { + cur_field = req.form.fields.emplace( + file.name, FormField{file.name, file.content, file.headers}); + is_text_field = true; + } else { + cur_file = req.form.files.emplace(file.name, file); + is_text_field = false; + } + return true; + }, + [&](const char *buf, size_t n) { + if (is_text_field) { + auto &content = cur_field->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + } else { + auto &content = cur_file->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + } + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { + res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414? + return false; + } + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver( + Stream &strm, Request &req, Response &res, ContentReceiver receiver, + FormDataHeader multipart_header, ContentReceiver multipart_receiver) { + return read_content_core(strm, req, res, std::move(receiver), + std::move(multipart_header), + std::move(multipart_receiver)); +} + +inline bool Server::read_content_core( + Stream &strm, Request &req, Response &res, ContentReceiver receiver, + FormDataHeader multipart_header, ContentReceiver multipart_receiver) const { + detail::FormDataParser multipart_form_data_parser; + ContentReceiverWithProgress out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = StatusCode::BadRequest_400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n, size_t /*off*/, size_t /*len*/) { + return multipart_form_data_parser.parse(buf, n, multipart_header, + multipart_receiver); + }; + } else { + out = [receiver](const char *buf, size_t n, size_t /*off*/, + size_t /*len*/) { return receiver(buf, n); }; + } + + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, + out, true)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = StatusCode::BadRequest_400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(const Request &req, Response &res) { + for (const auto &entry : base_dirs_) { + // Prefix match + if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { + std::string sub_path = "/" + req.path.substr(entry.mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = entry.base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } + + detail::FileStat stat(path); + + if (stat.is_dir()) { + res.set_redirect(sub_path + "/", StatusCode::MovedPermanently_301); + return true; + } + + if (stat.is_file()) { + for (const auto &kv : entry.headers) { + res.set_header(kv.first, kv.second); + } + + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { return false; } + + res.set_content_provider( + mm->size(), + detail::find_content_type(path, file_extension_and_mimetype_map_, + default_file_mimetype_), + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + + if (req.method != "HEAD" && file_request_handler_) { + file_request_handler_(req, res); + } + + return true; + } + } + } + } + return false; +} + +inline socket_t +Server::create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const { + return detail::create_socket( + host, std::string(), port, address_family_, socket_flags, tcp_nodelay_, + ipv6_v6only_, std::move(socket_options), + [](socket_t sock, struct addrinfo &ai, bool & /*quit*/) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG)) { return false; } + return true; + }); +} + +inline int Server::bind_internal(const std::string &host, int port, + int socket_flags) { + if (is_decommissioned) { return -1; } + + if (!is_valid()) { return -1; } + + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { return -1; } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + if (is_decommissioned) { return false; } + + auto ret = true; + is_running_ = true; + auto se = detail::scope_exit([&]() { is_running_ = false; }); + + { + std::unique_ptr task_queue(new_task_queue()); + + while (svr_sock_ != INVALID_SOCKET) { +#ifndef _WIN64 + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { +#endif + auto val = detail::select_read(svr_sock_, idle_interval_sec_, + idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } +#ifndef _WIN64 + } +#endif + +#if defined _WIN64 + // sockets connected via WASAccept inherit flags NO_HANDLE_INHERIT, + // OVERLAPPED + socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); +#elif defined SOCK_CLOEXEC + socket_t sock = accept4(svr_sock_, nullptr, nullptr, SOCK_CLOEXEC); +#else + socket_t sock = accept(svr_sock_, nullptr, nullptr); +#endif + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } else if (errno == EINTR || errno == EAGAIN) { + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + + detail::set_socket_opt_time(sock, SOL_SOCKET, SO_RCVTIMEO, + read_timeout_sec_, read_timeout_usec_); + detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO, + write_timeout_sec_, write_timeout_usec_); + + if (!task_queue->enqueue( + [this, sock]() { process_and_close_socket(sock); })) { + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + } + + task_queue->shutdown(); + } + + is_decommissioned = !ret; + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + if (pre_routing_handler_ && + pre_routing_handler_(req, res) == HandlerResponse::Handled) { + return true; + } + + // File handler + if ((req.method == "GET" || req.method == "HEAD") && + handle_file_request(req, res)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, req, res, std::move(receiver), nullptr, nullptr); + }, + [&](FormDataHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + std::move(header), + std::move(receiver)); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = StatusCode::BadRequest_400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, + const Handlers &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + req.matched_route = matcher->pattern(); + if (!pre_request_handler_ || + pre_request_handler_(req, res) != HandlerResponse::Handled) { + handler(req, res); + } + return true; + } + } + return false; +} + +inline void Server::apply_ranges(const Request &req, Response &res, + std::string &content_type, + std::string &boundary) const { + if (req.ranges.size() > 1 && res.status == StatusCode::PartialContent_206) { + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + boundary = detail::make_multipart_data_boundary(); + + res.set_header("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length( + req, boundary, content_type, res.content_length_); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider_) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } else if (type == detail::EncodingType::Zstd) { + res.set_header("Content-Encoding", "zstd"); + } + } + } + } + } else { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + ; + } else if (req.ranges.size() == 1) { + auto offset_and_length = + detail::get_range_offset_and_length(req.ranges[0], res.body.size()); + auto offset = offset_and_length.first; + auto length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.body.size()); + res.set_header("Content-Range", content_range); + + assert(offset + length <= res.body.size()); + res.body = res.body.substr(offset, length); + } else { + std::string data; + detail::make_multipart_ranges_data(req, res, boundary, content_type, + res.body.size(), data); + res.body.swap(data); + } + + if (type != detail::EncodingType::None) { + if (pre_compression_logger_) { pre_compression_logger_(req, res); } + + std::unique_ptr compressor; + std::string content_encoding; + + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); + content_encoding = "gzip"; +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); + content_encoding = "br"; +#endif + } else if (type == detail::EncodingType::Zstd) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + compressor = detail::make_unique(); + content_encoding = "zstd"; +#endif + } + + if (compressor) { + std::string compressed; + if (compressor->compress(res.body.data(), res.body.size(), true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + res.body.swap(compressed); + res.set_header("Content-Encoding", content_encoding); + } + } + } + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } +} + +inline bool Server::dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + req.matched_route = matcher->pattern(); + if (!pre_request_handler_ || + pre_request_handler_(req, res) != HandlerResponse::Handled) { + handler(req, res, content_reader); + } + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, + bool &connection_closed, + const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + + Response res; + res.version = "HTTP/1.1"; + res.headers = default_headers_; + +#ifdef __APPLE__ + // Socket file descriptor exceeded FD_SETSIZE... + if (strm.socket() >= FD_SETSIZE) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::InternalServerError_500; + return write_response(strm, close_connection, req, res); + } +#endif + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + + // Check if the request URI doesn't exceed the limit + if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::UriTooLong_414; + return write_response(strm, close_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } + + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + req.local_addr = local_addr; + req.local_port = local_port; + req.set_header("LOCAL_ADDR", req.local_addr); + req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + + if (req.has_header("Accept")) { + const auto &accept_header = req.get_header_value("Accept"); + if (!detail::parse_accept_header(accept_header, req.accept_content_types)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + } + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + } + + if (setup_request) { setup_request(req); } + + if (req.get_header_value("Expect") == "100-continue") { + int status = StatusCode::Continue_100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case StatusCode::Continue_100: + case StatusCode::ExpectationFailed_417: + detail::write_response_line(strm, status); + strm.write("\r\n"); + break; + default: + connection_closed = true; + return write_response(strm, true, req, res); + } + } + + // Setup `is_connection_closed` method + auto sock = strm.socket(); + req.is_connection_closed = [sock]() { + return !detail::is_socket_alive(sock); + }; + + // Routing + auto routed = false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + routed = routing(req, res, strm); +#else + try { + routed = routing(req, res, strm); + } catch (std::exception &e) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + std::string val; + auto s = e.what(); + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case '\r': val += "\\r"; break; + case '\n': val += "\\n"; break; + default: val += s[i]; break; + } + } + res.set_header("EXCEPTION_WHAT", val); + } + } catch (...) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + } +#endif + if (routed) { + if (res.status == -1) { + res.status = req.ranges.empty() ? StatusCode::OK_200 + : StatusCode::PartialContent_206; + } + + // Serve file content by using a content provider + if (!res.file_content_path_.empty()) { + const auto &path = res.file_content_path_; + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::NotFound_404; + return write_response(strm, close_connection, req, res); + } + + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type( + path, file_extension_and_mimetype_map_, default_file_mimetype_); + } + + res.set_content_provider( + mm->size(), content_type, + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + } + + if (detail::range_error(req, res)) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + + return write_response_with_content(strm, close_connection, req, res); + } else { + if (res.status == -1) { res.status = StatusCode::NotFound_404; } + + return write_response(strm, close_connection, req, res); + } +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + auto ret = detail::process_server_socket( + svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, connection_closed, + nullptr); + }); + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// HTTP client implementation +inline ClientImpl::ClientImpl(const std::string &host) + : ClientImpl(host, 80, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port) + : ClientImpl(host, port, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), + host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + +inline ClientImpl::~ClientImpl() { + // Wait until all the requests in flight are handled. + size_t retry_count = 10; + while (retry_count-- > 0) { + { + std::lock_guard guard(socket_mutex_); + if (socket_requests_in_flight_ == 0) { break; } + } + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } + + std::lock_guard guard(socket_mutex_); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline bool ClientImpl::is_valid() const { return true; } + +inline void ClientImpl::copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + max_timeout_msec_ = rhs.max_timeout_msec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + path_encode_ = rhs.path_encode_; + address_family_ = rhs.address_family_; + tcp_nodelay_ = rhs.tcp_nodelay_; + ipv6_v6only_ = rhs.ipv6_v6only_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + ca_cert_store_ = rhs.ca_cert_store_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; + server_certificate_verifier_ = rhs.server_certificate_verifier_; +#endif + logger_ = rhs.logger_; +} + +inline socket_t ClientImpl::create_client_socket(Error &error) const { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket( + proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, + ipv6_v6only_, socket_options_, connection_timeout_sec_, + connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, interface_, error); + } + + // Check is custom IP specified for host_ + std::string ip; + auto it = addr_map_.find(host_); + if (it != addr_map_.end()) { ip = it->second; } + + return detail::create_client_socket( + host_, ip, port_, address_family_, tcp_nodelay_, ipv6_v6only_, + socket_options_, connection_timeout_sec_, connection_timeout_usec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, interface_, error); +} + +inline bool ClientImpl::create_and_connect_socket(Socket &socket, + Error &error) { + auto sock = create_client_socket(error); + if (sock == INVALID_SOCKET) { return false; } + socket.sock = sock; + return true; +} + +inline void ClientImpl::shutdown_ssl(Socket & /*socket*/, + bool /*shutdown_gracefully*/) { + // If there are any requests in flight from threads other than us, then it's + // a thread-unsafe race because individual ssl* objects are not thread-safe. + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); +} + +inline void ClientImpl::shutdown_socket(Socket &socket) const { + if (socket.sock == INVALID_SOCKET) { return; } + detail::shutdown_socket(socket.sock); +} + +inline void ClientImpl::close_socket(Socket &socket) { + // If there are requests in flight in another thread, usually closing + // the socket will be fine and they will simply receive an error when + // using the closed socket, but it is still a bug since rarely the OS + // may reassign the socket id to be used for a new socket, and then + // suddenly they will be operating on a live socket that is different + // than the one they intended! + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); + + // It is also a bug if this happens while SSL is still active +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + assert(socket.ssl == nullptr); +#endif + if (socket.sock == INVALID_SOCKET) { return; } + detail::close_socket(socket.sock); + socket.sock = INVALID_SOCKET; +} + +inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, + Response &res) const { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); +#else + thread_local const std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); +#endif + + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return req.method == "CONNECT"; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + + // Ignore '100 Continue' + while (res.status == StatusCode::Continue_100) { + if (!line_reader.getline()) { return false; } // CRLF + if (!line_reader.getline()) { return false; } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; +} + +inline bool ClientImpl::send(Request &req, Response &res, Error &error) { + std::lock_guard request_mutex_guard(request_mutex_); + auto ret = send_(req, res, error); + if (error == Error::SSLPeerCouldBeClosed_) { + assert(!ret); + ret = send_(req, res, error); + } + return ret; +} + +inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { + { + std::lock_guard guard(socket_mutex_); + + // Set this to false immediately - if it ever gets set to true by the end of + // the request, we know another thread instructed us to close the socket. + socket_should_be_closed_when_request_is_done_ = false; + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + is_alive = false; + } + } +#endif + + if (!is_alive) { + // Attempt to avoid sigpipe by shutting down non-gracefully if it seems + // like the other side has already closed the connection Also, there + // cannot be any requests in flight from other threads since we locked + // request_mutex_, so safe to close everything immediately + const bool shutdown_gracefully = false; + shutdown_ssl(socket_, shutdown_gracefully); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!create_and_connect_socket(socket_, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + auto success = false; + if (!scli.connect_with_proxy(socket_, req.start_time_, res, success, + error)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_, error)) { return false; } + } +#endif + } + + // Mark the current socket as being in use so that it cannot be closed by + // anyone else while this request is ongoing, even though we will be + // releasing the mutex. + if (socket_requests_in_flight_ > 1) { + assert(socket_requests_are_from_thread_ == std::this_thread::get_id()); + } + socket_requests_in_flight_ += 1; + socket_requests_are_from_thread_ = std::this_thread::get_id(); + } + + for (const auto &header : default_headers_) { + if (req.headers.find(header.first) == req.headers.end()) { + req.headers.insert(header); + } + } + + auto ret = false; + auto close_connection = !keep_alive_; + + auto se = detail::scope_exit([&]() { + // Briefly lock mutex in order to mark that a request is no longer ongoing + std::lock_guard guard(socket_mutex_); + socket_requests_in_flight_ -= 1; + if (socket_requests_in_flight_ <= 0) { + assert(socket_requests_in_flight_ == 0); + socket_requests_are_from_thread_ = std::thread::id(); + } + + if (socket_should_be_closed_when_request_is_done_ || close_connection || + !ret) { + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + }); + + ret = process_socket(socket_, req.start_time_, [&](Stream &strm) { + return handle_request(strm, req, res, close_connection, error); + }); + + if (!ret) { + if (error == Error::Success) { error = Error::Unknown; } + } + + return ret; +} + +inline Result ClientImpl::send(const Request &req) { + auto req2 = req; + return send_(std::move(req2)); +} + +inline Result ClientImpl::send_(Request &&req) { + auto res = detail::make_unique(); + auto error = Error::Success; + auto ret = send(req, *res, error); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers), + last_ssl_error_, last_openssl_error_}; +#else + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; +#endif +} + +inline bool ClientImpl::handle_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + if (req.path.empty()) { + error = Error::Connection; + return false; + } + + auto req_save = req; + + bool ret; + + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection, error); + req = req2; + req.path = req_save.path; + } else { + ret = process_request(strm, req, res, close_connection, error); + } + + if (!ret) { return false; } + + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + // TODO this requires a not-entirely-obvious chain of calls to be correct + // for this to be safe. + + // This is safe to call because handle_request is only called by send_ + // which locks the request mutex during the process. It would be a bug + // to call it from a different thread since it's a thread-safety issue + // to do these things to the socket if another thread is using the socket. + std::lock_guard guard(socket_mutex_); + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + + if (300 < res.status && res.status < 400 && follow_location_) { + req = req_save; + ret = redirect(req, res, error); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if ((res.status == StatusCode::Unauthorized_401 || + res.status == StatusCode::ProxyAuthenticationRequired_407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == StatusCode::ProxyAuthenticationRequired_407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + new_req.headers.erase(is_proxy ? "Proxy-Authorization" + : "Authorization"); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), + username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res, error); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; +} + +inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { + if (req.redirect_count_ == 0) { + error = Error::ExceedRedirectCount; + return false; + } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + thread_local const std::regex re( + R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + if (next_host.empty()) { next_host = m[3].str(); } + auto port_str = m[4].str(); + auto next_path = m[5].str(); + auto next_query = m[6].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + auto path = detail::decode_path(next_path, true) + next_query; + + // Same host redirect - use current client + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, path, location, error); + } + + // Cross-host/scheme redirect - create new client with robust setup + return create_redirect_client(next_scheme, next_host, next_port, req, res, + path, location, error); +} + +// New method for robust redirect client creation +inline bool ClientImpl::create_redirect_client( + const std::string &scheme, const std::string &host, int port, Request &req, + Response &res, const std::string &path, const std::string &location, + Error &error) { + // Determine if we need SSL + auto need_ssl = (scheme == "https"); + + // Clean up request headers that are host/client specific + // Remove headers that should not be carried over to new host + auto headers_to_remove = + std::vector{"Host", "Proxy-Authorization", "Authorization"}; + + for (const auto &header_name : headers_to_remove) { + auto it = req.headers.find(header_name); + while (it != req.headers.end()) { + it = req.headers.erase(it); + it = req.headers.find(header_name); + } + } + + // Create appropriate client type and handle redirect + if (need_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Create SSL client for HTTPS redirect + SSLClient redirect_client(host, port); + + // Setup basic client configuration first + setup_redirect_client(redirect_client); + + // SSL-specific configuration for proxy environments + if (!proxy_host_.empty() && proxy_port_ != -1) { + // Critical: Disable SSL verification for proxy environments + redirect_client.enable_server_certificate_verification(false); + redirect_client.enable_server_hostname_verification(false); + } else { + // For direct SSL connections, copy SSL verification settings + redirect_client.enable_server_certificate_verification( + server_certificate_verification_); + redirect_client.enable_server_hostname_verification( + server_hostname_verification_); + } + + // Handle CA certificate store and paths if available + if (ca_cert_store_) { redirect_client.set_ca_cert_store(ca_cert_store_); } + if (!ca_cert_file_path_.empty()) { + redirect_client.set_ca_cert_path(ca_cert_file_path_, ca_cert_dir_path_); + } + + // Client certificates are set through constructor for SSLClient + // NOTE: SSLClient constructor already takes client_cert_path and + // client_key_path so we need to create it properly if client certs are + // needed + + // Execute the redirect + return detail::redirect(redirect_client, req, res, path, location, error); +#else + // SSL not supported - set appropriate error + error = Error::SSLConnection; + return false; +#endif + } else { + // HTTP redirect + ClientImpl redirect_client(host, port); + + // Setup client with robust configuration + setup_redirect_client(redirect_client); + + // Execute the redirect + return detail::redirect(redirect_client, req, res, path, location, error); + } +} + +// New method for robust client setup (based on basic_manual_redirect.cpp logic) +template +inline void ClientImpl::setup_redirect_client(ClientType &client) { + // Copy basic settings first + client.set_connection_timeout(connection_timeout_sec_); + client.set_read_timeout(read_timeout_sec_, read_timeout_usec_); + client.set_write_timeout(write_timeout_sec_, write_timeout_usec_); + client.set_keep_alive(keep_alive_); + client.set_follow_location( + true); // Enable redirects to handle multi-step redirects + client.set_path_encode(path_encode_); + client.set_compress(compress_); + client.set_decompress(decompress_); + + // Copy authentication settings BEFORE proxy setup + if (!basic_auth_username_.empty()) { + client.set_basic_auth(basic_auth_username_, basic_auth_password_); + } + if (!bearer_token_auth_token_.empty()) { + client.set_bearer_token_auth(bearer_token_auth_token_); + } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!digest_auth_username_.empty()) { + client.set_digest_auth(digest_auth_username_, digest_auth_password_); + } +#endif + + // Setup proxy configuration (CRITICAL ORDER - proxy must be set + // before proxy auth) + if (!proxy_host_.empty() && proxy_port_ != -1) { + // First set proxy host and port + client.set_proxy(proxy_host_, proxy_port_); + + // Then set proxy authentication (order matters!) + if (!proxy_basic_auth_username_.empty()) { + client.set_proxy_basic_auth(proxy_basic_auth_username_, + proxy_basic_auth_password_); + } + if (!proxy_bearer_token_auth_token_.empty()) { + client.set_proxy_bearer_token_auth(proxy_bearer_token_auth_token_); + } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!proxy_digest_auth_username_.empty()) { + client.set_proxy_digest_auth(proxy_digest_auth_username_, + proxy_digest_auth_password_); + } +#endif + } + + // Copy network and socket settings + client.set_address_family(address_family_); + client.set_tcp_nodelay(tcp_nodelay_); + client.set_ipv6_v6only(ipv6_v6only_); + if (socket_options_) { client.set_socket_options(socket_options_); } + if (!interface_.empty()) { client.set_interface(interface_); } + + // Copy logging and headers + if (logger_) { client.set_logger(logger_); } + + // NOTE: DO NOT copy default_headers_ as they may contain stale Host headers + // Each new client should generate its own headers based on its target host +} + +inline bool ClientImpl::write_content_with_provider(Stream &strm, + const Request &req, + Error &error) const { + auto is_shutting_down = []() { return false; }; + + if (req.is_chunked_content_provider_) { + // TODO: Brotli support + std::unique_ptr compressor; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + compressor = detail::make_unique(); + } else +#endif + { + compressor = detail::make_unique(); + } + + return detail::write_content_chunked(strm, req.content_provider_, + is_shutting_down, *compressor, error); + } else { + return detail::write_content_with_progress( + strm, req.content_provider_, 0, req.content_length_, is_shutting_down, + req.upload_progress, error); + } +} + +inline bool ClientImpl::write_request(Stream &strm, Request &req, + bool close_connection, Error &error) { + // Prepare additional headers + if (close_connection) { + if (!req.has_header("Connection")) { + req.set_header("Connection", "close"); + } + } + + if (!req.has_header("Host")) { + // For Unix socket connections, use "localhost" as Host header (similar to + // curl behavior) + if (address_family_ == AF_UNIX) { + req.set_header("Host", "localhost"); + } else if (is_ssl()) { + if (port_ == 443) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } else { + if (port_ == 80) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } + + if (!req.content_receiver) { + if (!req.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; +#endif + req.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + }; + + if (req.body.empty()) { + if (req.content_provider_) { + if (!req.is_chunked_content_provider_) { + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.content_length_); + req.set_header("Content-Length", length); + } + } + } else { + if (req.method == "POST" || req.method == "PUT" || + req.method == "PATCH") { + req.set_header("Content-Length", "0"); + } + } + } else { + if (!req.has_header("Content-Type")) { + req.set_header("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + req.set_header("Content-Length", length); + } + } + + if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + } + + if (!bearer_token_auth_token_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + bearer_token_auth_token_, false)); + } + } + + if (!proxy_bearer_token_auth_token_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + proxy_bearer_token_auth_token_, true)); + } + } + + // Request line and headers + { + detail::BufferStream bstrm; + + const auto &path_with_query = + req.params.empty() ? req.path + : append_query_params(req.path, req.params); + + const auto &path = + path_encode_ ? detail::encode_path(path_with_query) : path_with_query; + + detail::write_request_line(bstrm, req.method, path); + + header_writer_(bstrm, req.headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error = Error::Write; + return false; + } + } + + // Body + if (req.body.empty()) { + return write_content_with_provider(strm, req, error); + } + + if (req.upload_progress) { + auto body_size = req.body.size(); + size_t written = 0; + auto data = req.body.data(); + + while (written < body_size) { + size_t to_write = (std::min)(CPPHTTPLIB_SEND_BUFSIZ, body_size - written); + if (!detail::write_data(strm, data + written, to_write)) { + error = Error::Write; + return false; + } + written += to_write; + + if (!req.upload_progress(written, body_size)) { + error = Error::Canceled; + return false; + } + } + } else { + if (!detail::write_data(strm, req.body.data(), req.body.size())) { + error = Error::Write; + return false; + } + } + + return true; +} + +inline std::unique_ptr ClientImpl::send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error) { + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { req.set_header("Content-Encoding", "gzip"); } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_ && !content_provider_without_length) { + // TODO: Brotli support + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + DataSink data_sink; + + data_sink.write = [&](const char *data, size_t data_len) -> bool { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = compressor.compress( + data, data_len, last, + [&](const char *compressed_data, size_t compressed_data_len) { + req.body.append(compressed_data, compressed_data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + return ok; + }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body, content_length, true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + error = Error::Compression; + return nullptr; + } + } + } else +#endif + { + if (content_provider) { + req.content_length_ = content_length; + req.content_provider_ = std::move(content_provider); + req.is_chunked_content_provider_ = false; + } else if (content_provider_without_length) { + req.content_length_ = 0; + req.content_provider_ = detail::ContentProviderAdapter( + std::move(content_provider_without_length)); + req.is_chunked_content_provider_ = true; + req.set_header("Transfer-Encoding", "chunked"); + } else { + req.body.assign(body, content_length); + } + } + + auto res = detail::make_unique(); + return send(req, *res, error) ? std::move(res) : nullptr; +} + +inline Result ClientImpl::send_with_content_provider( + const std::string &method, const std::string &path, const Headers &headers, + const char *body, size_t content_length, ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, UploadProgress progress) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + req.upload_progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + auto error = Error::Success; + + auto res = send_with_content_provider( + req, body, content_length, std::move(content_provider), + std::move(content_provider_without_length), content_type, error); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + return Result{std::move(res), error, std::move(req.headers), last_ssl_error_, + last_openssl_error_}; +#else + return Result{std::move(res), error, std::move(req.headers)}; +#endif +} + +inline std::string +ClientImpl::adjust_host_string(const std::string &host) const { + if (host.find(':') != std::string::npos) { return "[" + host + "]"; } + return host; +} + +inline bool ClientImpl::process_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + // Send request + if (!write_request(strm, req, close_connection, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; + if (!is_proxy_enabled) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + error = Error::SSLPeerCouldBeClosed_; + return false; + } + } + } +#endif + + // Receive response and headers + if (!read_response_line(strm, req, res) || + !detail::read_headers(strm, res.headers)) { + error = Error::Read; + return false; + } + + // Body + if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && + req.method != "CONNECT") { + auto redirect = 300 < res.status && res.status < 400 && + res.status != StatusCode::NotModified_304 && + follow_location_; + + if (req.response_handler && !redirect) { + if (!req.response_handler(res)) { + error = Error::Canceled; + return false; + } + } + + auto out = + req.content_receiver + ? static_cast( + [&](const char *buf, size_t n, size_t off, size_t len) { + if (redirect) { return true; } + auto ret = req.content_receiver(buf, n, off, len); + if (!ret) { error = Error::Canceled; } + return ret; + }) + : static_cast( + [&](const char *buf, size_t n, size_t /*off*/, + size_t /*len*/) { + assert(res.body.size() + n <= res.body.max_size()); + res.body.append(buf, n); + return true; + }); + + auto progress = [&](size_t current, size_t total) { + if (!req.download_progress || redirect) { return true; } + auto ret = req.download_progress(current, total); + if (!ret) { error = Error::Canceled; } + return ret; + }; + + if (res.has_header("Content-Length")) { + if (!req.content_receiver) { + auto len = res.get_header_value_u64("Content-Length"); + if (len > res.body.max_size()) { + error = Error::Read; + return false; + } + res.body.reserve(static_cast(len)); + } + } + + if (res.status != StatusCode::NotModified_304) { + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, std::move(progress), + std::move(out), decompress_)) { + if (error != Error::Canceled) { error = Error::Read; } + return false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; +} + +inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( + const std::string &boundary, const UploadFormDataItems &items, + const FormDataProviderItems &provider_items) const { + size_t cur_item = 0; + size_t cur_start = 0; + // cur_item and cur_start are copied to within the std::function and maintain + // state between successive calls + return [&, cur_item, cur_start](size_t offset, + DataSink &sink) mutable -> bool { + if (!offset && !items.empty()) { + sink.os << detail::serialize_multipart_formdata(items, boundary, false); + return true; + } else if (cur_item < provider_items.size()) { + if (!cur_start) { + const auto &begin = detail::serialize_multipart_formdata_item_begin( + provider_items[cur_item], boundary); + offset += begin.size(); + cur_start = offset; + sink.os << begin; + } + + DataSink cur_sink; + auto has_data = true; + cur_sink.write = sink.write; + cur_sink.done = [&]() { has_data = false; }; + + if (!provider_items[cur_item].provider(offset - cur_start, cur_sink)) { + return false; + } + + if (!has_data) { + sink.os << detail::serialize_multipart_formdata_item_end(); + cur_item++; + cur_start = 0; + } + return true; + } else { + sink.os << detail::serialize_multipart_formdata_finish(boundary); + sink.done(); + return true; + } + }; +} + +inline bool ClientImpl::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { + return detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, max_timeout_msec_, start_time, std::move(callback)); +} + +inline bool ClientImpl::is_ssl() const { return false; } + +inline Result ClientImpl::Get(const std::string &path, + DownloadProgress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + DownloadProgress progress) { + if (params.empty()) { return Get(path, headers); } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + DownloadProgress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.download_progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, + ContentReceiver content_receiver, + DownloadProgress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, + DownloadProgress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.download_progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + DownloadProgress progress) { + return Get(path, params, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + if (params.empty()) { + return Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Head(const std::string &path) { + return Head(path, Headers()); +} + +inline Result ClientImpl::Head(const std::string &path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Post(const std::string &path) { + return Post(path, std::string(), std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers) { + return Post(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return Post(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return Post(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline Result ClientImpl::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return Post(path, Headers(), content_length, std::move(content_provider), + content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return Post(path, Headers(), std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Post(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return Post(path, Headers(), items, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("POST", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("POST", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("POST", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "POST", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + req.body = body; + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.download_progress = std::move(progress); + + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Put(const std::string &path) { + return Put(path, std::string(), std::string()); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers) { + return Put(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return Put(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return Put(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline Result ClientImpl::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return Put(path, Headers(), content_length, std::move(content_provider), + content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return Put(path, Headers(), std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Put(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return Put(path, Headers(), items, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PUT", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PUT", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PUT", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "PUT", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + Request req; + req.method = "PUT"; + req.path = path; + req.headers = headers; + req.body = body; + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.download_progress = std::move(progress); + + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Patch(const std::string &path) { + return Patch(path, std::string(), std::string()); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + UploadProgress progress) { + return Patch(path, headers, nullptr, 0, std::string(), progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return Patch(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return Patch(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Params ¶ms) { + return Patch(path, Headers(), params); +} + +inline Result ClientImpl::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return Patch(path, Headers(), std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Patch(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Patch(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return Patch(path, Headers(), items, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Patch(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Patch(path, headers, body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PATCH", path, headers, body, + content_length, nullptr, nullptr, + content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PATCH", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PATCH", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "PATCH", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + Request req; + req.method = "PATCH"; + req.path = path; + req.headers = headers; + req.body = body; + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.download_progress = std::move(progress); + + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + + return send_(std::move(req)); +} + +inline Result ClientImpl::Delete(const std::string &path, + DownloadProgress progress) { + return Delete(path, Headers(), std::string(), std::string(), progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + DownloadProgress progress) { + return Delete(path, headers, std::string(), std::string(), progress); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + DownloadProgress progress) { + return Delete(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const std::string &body, + const std::string &content_type, + DownloadProgress progress) { + return Delete(path, Headers(), body.data(), body.size(), content_type, + progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type, + DownloadProgress progress) { + return Delete(path, headers, body.data(), body.size(), content_type, + progress); +} + +inline Result ClientImpl::Delete(const std::string &path, const Params ¶ms, + DownloadProgress progress) { + return Delete(path, Headers(), params, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const Params ¶ms, + DownloadProgress progress) { + auto query = detail::params_to_query_str(params); + return Delete(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const char *body, + size_t content_length, + const std::string &content_type, + DownloadProgress progress) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + req.download_progress = std::move(progress); + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + req.body.assign(body, content_length); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Options(const std::string &path) { + return Options(path, Headers()); +} + +inline Result ClientImpl::Options(const std::string &path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.headers = headers; + req.path = path; + if (max_timeout_msec_ > 0) { + req.start_time_ = std::chrono::steady_clock::now(); + } + + return send_(std::move(req)); +} + +inline void ClientImpl::stop() { + std::lock_guard guard(socket_mutex_); + + // If there is anything ongoing right now, the ONLY thread-safe thing we can + // do is to shutdown_socket, so that threads using this socket suddenly + // discover they can't read/write any more and error out. Everything else + // (closing the socket, shutting ssl down) is unsafe because these actions are + // not thread-safe. + if (socket_requests_in_flight_ > 0) { + shutdown_socket(socket_); + + // Aside from that, we set a flag for the socket to be closed when we're + // done. + socket_should_be_closed_when_request_is_done_ = true; + return; + } + + // Otherwise, still holding the mutex, we can shut everything down ourselves + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline std::string ClientImpl::host() const { return host_; } + +inline int ClientImpl::port() const { return port_; } + +inline size_t ClientImpl::is_socket_open() const { + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); +} + +inline socket_t ClientImpl::socket() const { return socket_.sock; } + +inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; +} + +inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +inline void ClientImpl::set_max_timeout(time_t msec) { + max_timeout_msec_ = msec; +} + +inline void ClientImpl::set_basic_auth(const std::string &username, + const std::string &password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +inline void ClientImpl::set_bearer_token_auth(const std::string &token) { + bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_digest_auth(const std::string &username, + const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } + +inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } + +inline void ClientImpl::set_path_encode(bool on) { path_encode_ = on; } + +inline void +ClientImpl::set_hostname_addr_map(std::map addr_map) { + addr_map_ = std::move(addr_map); +} + +inline void ClientImpl::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); +} + +inline void ClientImpl::set_header_writer( + std::function const &writer) { + header_writer_ = writer; +} + +inline void ClientImpl::set_address_family(int family) { + address_family_ = family; +} + +inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } + +inline void ClientImpl::set_ipv6_v6only(bool on) { ipv6_v6only_ = on; } + +inline void ClientImpl::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + +inline void ClientImpl::set_compress(bool on) { compress_ = on; } + +inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } + +inline void ClientImpl::set_interface(const std::string &intf) { + interface_ = intf; +} + +inline void ClientImpl::set_proxy(const std::string &host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void ClientImpl::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +inline void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { + proxy_bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} + +inline void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + ca_cert_file_path_ = ca_cert_file_path; + ca_cert_dir_path_ = ca_cert_dir_path; +} + +inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store && ca_cert_store != ca_cert_store_) { + ca_cert_store_ = ca_cert_store; + } +} + +inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, + std::size_t size) const { + auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + auto se = detail::scope_exit([&] { BIO_free_all(mem); }); + if (!mem) { return nullptr; } + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { return nullptr; } + + auto cts = X509_STORE_new(); + if (cts) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { continue; } + + if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); } + if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + return cts; +} + +inline void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline void ClientImpl::enable_server_hostname_verification(bool enabled) { + server_hostname_verification_ = enabled; +} + +inline void ClientImpl::set_server_certificate_verifier( + std::function verifier) { + server_certificate_verifier_ = verifier; +} +#endif + +inline void ClientImpl::set_logger(Logger logger) { + logger_ = std::move(logger); +} + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, + U SSL_connect_or_accept, V setup) { + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (ssl) { + set_nonblocking(sock, true); + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + BIO_set_nbio(bio, 1); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + set_nonblocking(sock, false); + return nullptr; + } + BIO_set_nbio(bio, 0); + set_nonblocking(sock, false); + } + + return ssl; +} + +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, + bool shutdown_gracefully) { + // sometimes we may want to skip this to try to avoid SIGPIPE if we know + // the remote has closed the network connection + // Note that it is not always possible to avoid SIGPIPE, this is merely a + // best-efforts. + if (shutdown_gracefully) { + (void)(sock); + // SSL_shutdown() returns 0 on first call (indicating close_notify alert + // sent) and 1 on subsequent call (indicating close_notify alert received) + if (SSL_shutdown(ssl) == 0) { + // Expected to return 1, but even if it doesn't, we free ssl + SSL_shutdown(ssl); + } + } + + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); +} + +template +bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, + U ssl_connect_or_accept, + time_t timeout_sec, time_t timeout_usec, + int *ssl_error) { + auto res = 0; + while ((res = ssl_connect_or_accept(ssl)) != 1) { + auto err = SSL_get_error(ssl, res); + switch (err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + default: break; + } + if (ssl_error) { *ssl_error = err; } + return false; + } + return true; +} + +template +inline bool process_server_socket_ssl( + const std::atomic &svr_sock, SSL *ssl, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +inline bool process_client_socket_ssl( + SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time, T callback) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); + return callback(strm); +} + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream( + socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + time_t max_timeout_msec, + std::chrono::time_point start_time) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time_(start_time) { + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); +} + +inline SSLSocketStream::~SSLSocketStream() = default; + +inline bool SSLSocketStream::is_readable() const { + return SSL_pending(ssl_) > 0; +} + +inline bool SSLSocketStream::wait_readable() const { + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; +} + +inline bool SSLSocketStream::wait_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (wait_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN64 + while (--n >= 0 && (err == SSL_ERROR_WANT_READ || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (wait_readable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + break; + } + } + assert(ret < 0); + } + return ret; + } else { + return -1; + } +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (wait_writable()) { + auto handle_size = static_cast( + std::min(size, (std::numeric_limits::max)())); + + auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN64 + while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { +#endif + if (wait_writable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + break; + } + } + assert(ret < 0); + } + return ret; + } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SSLSocketStream::socket() const { return sock_; } + +inline time_t SSLSocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time_) + .count(); +} + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, + const char *private_key_password) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (private_key_password != nullptr && (private_key_password[0] != '\0')) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, + reinterpret_cast(const_cast(private_key_password))); + } + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1 || + SSL_CTX_check_private_key(ctx_) != 1) { + last_ssl_error_ = static_cast(ERR_get_error()); + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer( + const std::function &setup_ssl_ctx_callback) { + ctx_ = SSL_CTX_new(TLS_method()); + if (ctx_) { + if (!setup_ssl_ctx_callback(*ctx_)) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } + +inline void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + + std::lock_guard guard(ctx_mutex_); + + SSL_CTX_use_certificate(ctx_, cert); + SSL_CTX_use_PrivateKey(ctx_, private_key); + + if (client_ca_cert_store != nullptr) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + } +} + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + auto ssl = detail::ssl_new( + sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + return detail::ssl_connect_or_accept_nonblocking( + sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_, + &last_ssl_error_); + }, + [](SSL * /*ssl2*/) { return true; }); + + auto ret = false; + if (ssl) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + ret = detail::process_server_socket_ssl( + svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, + connection_closed, + [&](Request &req) { req.ssl = ssl; }); + }); + + // Shutdown gracefully if the result seemed successful, non-gracefully if + // the connection appeared to be closed. + const bool shutdown_gracefully = ret; + detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); + } + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host) + : SSLClient(host, 443, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port) + : SSLClient(host, port, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + last_openssl_error_ = ERR_get_error(); + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, + X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) + : ClientImpl(host, port) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (client_cert != nullptr && client_key != nullptr) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + last_openssl_error_ = ERR_get_error(); + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + if (ctx_) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { + // Free memory allocated for old cert and use new store `ca_cert_store` + SSL_CTX_set_cert_store(ctx_, ca_cert_store); + } + } else { + X509_STORE_free(ca_cert_store); + } + } +} + +inline void SSLClient::load_ca_cert_store(const char *ca_cert, + std::size_t size) { + set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } + +inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + return is_valid() && ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in flight +inline bool SSLClient::connect_with_proxy( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + if (max_timeout_msec_ > 0) { + req2.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + // Close the current socket and create a new one for the authenticated + // request + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + + // Create a new socket for the authenticated CONNECT request + if (!create_and_connect_socket(socket, error)) { + success = false; + return false; + } + + proxy_res = Response(); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + if (max_timeout_msec_ > 0) { + req3.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +inline bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), + nullptr)) { + last_openssl_error_ = ERR_get_error(); + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, + ca_cert_dir_path_.c_str())) { + last_openssl_error_ = ERR_get_error(); + ret = false; + } + } else { + auto loaded = false; +#ifdef _WIN64 + loaded = + detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && \ + defined(TARGET_OS_OSX) + loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); +#endif // _WIN64 + if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); } + } + }); + + return ret; +} + +inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); + } + + if (!detail::ssl_connect_or_accept_nonblocking( + socket.sock, ssl2, SSL_connect, connection_timeout_sec_, + connection_timeout_usec_, &last_ssl_error_)) { + error = Error::SSLConnection; + return false; + } + + if (server_certificate_verification_) { + auto verification_status = SSLVerifierResponse::NoDecisionMade; + + if (server_certificate_verifier_) { + verification_status = server_certificate_verifier_(ssl2); + } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + last_openssl_error_ = ERR_get_error(); + error = Error::SSLServerVerification; + return false; + } + + if (verification_status == SSLVerifierResponse::NoDecisionMade) { + verify_result_ = SSL_get_verify_result(ssl2); + + if (verify_result_ != X509_V_OK) { + last_openssl_error_ = static_cast(verify_result_); + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + auto se = detail::scope_exit([&] { X509_free(server_cert); }); + + if (server_cert == nullptr) { + last_openssl_error_ = ERR_get_error(); + error = Error::SSLServerVerification; + return false; + } + + if (server_hostname_verification_) { + if (!verify_host(server_cert)) { + last_openssl_error_ = X509_V_ERR_HOSTNAME_MISMATCH; + error = Error::SSLServerHostnameVerification; + return false; + } + } + } + } + + return true; + }, + [&](SSL *ssl2) { +#if defined(OPENSSL_IS_BORINGSSL) + SSL_set_tlsext_host_name(ssl2, host_.c_str()); +#else + // NOTE: Direct call instead of using the OpenSSL macro to suppress + // -Wold-style-cast warning + SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(host_.c_str()))); +#endif + return true; + }); + + if (ssl) { + socket.ssl = ssl; + return true; + } + + shutdown_socket(socket); + close_socket(socket); + return false; +} + +inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +inline void SSLClient::shutdown_ssl_impl(Socket &socket, + bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, + shutdown_gracefully); + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +inline bool SSLClient::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, + std::move(callback)); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6 = {}; + struct in_addr addr = {}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = + reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + } + + if (dsn_matched || ip_matched) { ret = true; } + } + + GENERAL_NAMES_free(const_cast( + reinterpret_cast(alt_names))); + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(b, e); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} +#endif + +// Universal client implementation +inline Client::Client(const std::string &scheme_host_port) + : Client(scheme_host_port, std::string(), std::string()) {} + +inline Client::Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path) { + const static std::regex re( + R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + + std::smatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { +#else + if (!scheme.empty() && scheme != "http") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "https"; + + auto host = m[2].str(); + if (host.empty()) { host = m[3].str(); } + + auto port_str = m[4].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + if (is_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + is_ssl_ = is_ssl; +#endif + } else { + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + } + } else { + // NOTE: Update TEST(UniversalClientImplTest, Ipv6LiteralAddress) + // if port param below changes. + cli_ = detail::make_unique(scheme_host_port, 80, + client_cert_path, client_key_path); + } +} // namespace detail + +inline Client::Client(const std::string &host, int port) + : cli_(detail::make_unique(host, port)) {} + +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : cli_(detail::make_unique(host, port, client_cert_path, + client_key_path)) {} + +inline Client::~Client() = default; + +inline bool Client::is_valid() const { + return cli_ != nullptr && cli_->is_valid(); +} + +inline Result Client::Get(const std::string &path, DownloadProgress progress) { + return cli_->Get(path, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + DownloadProgress progress) { + return cli_->Get(path, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, DownloadProgress progress) { + return cli_->Get(path, params, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, params, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Get(path, params, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result Client::Head(const std::string &path) { return cli_->Head(path); } +inline Result Client::Head(const std::string &path, const Headers &headers) { + return cli_->Head(path, headers); +} + +inline Result Client::Post(const std::string &path) { return cli_->Post(path); } +inline Result Client::Post(const std::string &path, const Headers &headers) { + return cli_->Post(path, headers); +} +inline Result Client::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, body, content_length, content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Post(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, headers, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, std::move(content_provider), content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, headers, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Post(path, headers, std::move(content_provider), content_type, + progress); +} +inline Result Client::Post(const std::string &path, const Params ¶ms) { + return cli_->Post(path, params); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Post(path, headers, params); +} +inline Result Client::Post(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Post(path, items, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Post(path, headers, items, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + return cli_->Post(path, headers, items, boundary, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + return cli_->Post(path, headers, items, provider_items, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, body, content_type, content_receiver, + progress); +} + +inline Result Client::Put(const std::string &path) { return cli_->Put(path); } +inline Result Client::Put(const std::string &path, const Headers &headers) { + return cli_->Put(path, headers); +} +inline Result Client::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, body, content_length, content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Put(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, headers, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, std::move(content_provider), content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, headers, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Put(path, headers, std::move(content_provider), content_type, + progress); +} +inline Result Client::Put(const std::string &path, const Params ¶ms) { + return cli_->Put(path, params); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Put(path, headers, params); +} +inline Result Client::Put(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Put(path, items, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Put(path, headers, items, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + return cli_->Put(path, headers, items, boundary, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + return cli_->Put(path, headers, items, provider_items, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Put(path, headers, body, content_type, content_receiver, + progress); +} + +inline Result Client::Patch(const std::string &path) { + return cli_->Patch(path); +} +inline Result Client::Patch(const std::string &path, const Headers &headers) { + return cli_->Patch(path, headers); +} +inline Result Client::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, body, content_length, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Patch(const std::string &path, const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, headers, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, std::move(content_provider), content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), + content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + UploadProgress progress) { + return cli_->Patch(path, headers, std::move(content_provider), content_type, + progress); +} +inline Result Client::Patch(const std::string &path, const Params ¶ms) { + return cli_->Patch(path, params); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Patch(path, headers, params); +} +inline Result Client::Patch(const std::string &path, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Patch(path, items, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + UploadProgress progress) { + return cli_->Patch(path, headers, items, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const std::string &boundary, + UploadProgress progress) { + return cli_->Patch(path, headers, items, boundary, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const UploadFormDataItems &items, + const FormDataProviderItems &provider_items, + UploadProgress progress) { + return cli_->Patch(path, headers, items, provider_items, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Patch(path, headers, body, content_type, content_receiver, + progress); +} + +inline Result Client::Delete(const std::string &path, + DownloadProgress progress) { + return cli_->Delete(path, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + DownloadProgress progress) { + return cli_->Delete(path, headers, progress); +} +inline Result Client::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + DownloadProgress progress) { + return cli_->Delete(path, body, content_length, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + DownloadProgress progress) { + return cli_->Delete(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Delete(const std::string &path, const std::string &body, + const std::string &content_type, + DownloadProgress progress) { + return cli_->Delete(path, body, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + DownloadProgress progress) { + return cli_->Delete(path, headers, body, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Params ¶ms, + DownloadProgress progress) { + return cli_->Delete(path, params, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const Params ¶ms, DownloadProgress progress) { + return cli_->Delete(path, headers, params, progress); +} + +inline Result Client::Options(const std::string &path) { + return cli_->Options(path); +} +inline Result Client::Options(const std::string &path, const Headers &headers) { + return cli_->Options(path, headers); +} + +inline bool Client::send(Request &req, Response &res, Error &error) { + return cli_->send(req, res, error); +} + +inline Result Client::send(const Request &req) { return cli_->send(req); } + +inline void Client::stop() { cli_->stop(); } + +inline std::string Client::host() const { return cli_->host(); } + +inline int Client::port() const { return cli_->port(); } + +inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } + +inline socket_t Client::socket() const { return cli_->socket(); } + +inline void +Client::set_hostname_addr_map(std::map addr_map) { + cli_->set_hostname_addr_map(std::move(addr_map)); +} + +inline void Client::set_default_headers(Headers headers) { + cli_->set_default_headers(std::move(headers)); +} + +inline void Client::set_header_writer( + std::function const &writer) { + cli_->set_header_writer(writer); +} + +inline void Client::set_address_family(int family) { + cli_->set_address_family(family); +} + +inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } + +inline void Client::set_socket_options(SocketOptions socket_options) { + cli_->set_socket_options(std::move(socket_options)); +} + +inline void Client::set_connection_timeout(time_t sec, time_t usec) { + cli_->set_connection_timeout(sec, usec); +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + cli_->set_read_timeout(sec, usec); +} + +inline void Client::set_write_timeout(time_t sec, time_t usec) { + cli_->set_write_timeout(sec, usec); +} + +inline void Client::set_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_basic_auth(username, password); +} +inline void Client::set_bearer_token_auth(const std::string &token) { + cli_->set_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_digest_auth(username, password); +} +#endif + +inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_follow_location(bool on) { + cli_->set_follow_location(on); +} + +inline void Client::set_path_encode(bool on) { cli_->set_path_encode(on); } + +[[deprecated("Use set_path_encode instead")]] +inline void Client::set_url_encode(bool on) { + cli_->set_path_encode(on); +} + +inline void Client::set_compress(bool on) { cli_->set_compress(on); } + +inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } + +inline void Client::set_interface(const std::string &intf) { + cli_->set_interface(intf); +} + +inline void Client::set_proxy(const std::string &host, int port) { + cli_->set_proxy(host, port); +} +inline void Client::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_basic_auth(username, password); +} +inline void Client::set_proxy_bearer_token_auth(const std::string &token) { + cli_->set_proxy_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} + +inline void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +inline void Client::set_server_certificate_verifier( + std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} +#endif + +inline void Client::set_logger(Logger logger) { + cli_->set_logger(std::move(logger)); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); +} + +inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } else { + cli_->set_ca_cert_store(ca_cert_store); + } +} + +inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); +} + +inline long Client::get_openssl_verify_result() const { + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} + +inline SSL_CTX *Client::ssl_context() const { + if (is_ssl_) { return static_cast(*cli_).ssl_context(); } + return nullptr; +} +#endif + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#endif // CPPHTTPLIB_HTTPLIB_H + +// End of httplib.h + +// Start of main.c +#include +#include +#include +#include +#include + +#ifdef _WIN32 + #include +#else + #include +#endif + +#include "wren.h" +#include "requests_backend.c" +#include "socket_backend.c" + +// --- Global flag to control the main loop --- +static volatile bool g_mainFiberIsDone = false; + +// --- Foreign function for Wren to signal the host to exit --- +void hostSignalDone(WrenVM* vm) { + (void)vm; + g_mainFiberIsDone = true; +} + +// --- File/VM Setup --- +static char* readFile(const char* path) { + FILE* file = fopen(path, "rb"); + if (file == NULL) return NULL; + fseek(file, 0L, SEEK_END); + size_t fileSize = ftell(file); + rewind(file); + char* buffer = (char*)malloc(fileSize + 1); + if (!buffer) { fclose(file); return NULL; } + size_t bytesRead = fread(buffer, sizeof(char), fileSize, file); + if (bytesRead < fileSize) { + free(buffer); + fclose(file); + return NULL; + } + buffer[bytesRead] = '\0'; + fclose(file); + return buffer; +} + +static void writeFn(WrenVM* vm, const char* text) { (void)vm; printf("%s", text); } + +static void errorFn(WrenVM* vm, WrenErrorType type, const char* module, int line, const char* message) { + (void)vm; + switch (type) { + case WREN_ERROR_COMPILE: + fprintf(stderr, "[%s line %d] [Error] %s\n", module, line, message); + break; + case WREN_ERROR_RUNTIME: + fprintf(stderr, "[Runtime Error] %s\n", message); + g_mainFiberIsDone = true; // Stop on runtime errors + break; + case WREN_ERROR_STACK_TRACE: + fprintf(stderr, "[%s line %d] in %s\n", module, line, message); + break; + } +} + +static void onModuleComplete(WrenVM* vm, const char* name, WrenLoadModuleResult result) { + (void)vm; (void)name; + if (result.source) free((void*)result.source); +} + +static WrenLoadModuleResult loadModule(WrenVM* vm, const char* name) { + (void)vm; + WrenLoadModuleResult result = {0}; + char path[256]; + snprintf(path, sizeof(path), "%s.wren", name); + char* source = readFile(path); + if (source != NULL) { + result.source = source; + result.onComplete = onModuleComplete; + } + return result; +} + +// --- Combined Foreign Function Binders --- +WrenForeignMethodFn combinedBindForeignMethod(WrenVM* vm, const char* module, const char* className, bool isStatic, const char* signature) { + // Delegate to the socket backend's binder + if (strcmp(module, "socket") == 0) { + return bindSocketForeignMethod(vm, module, className, isStatic, signature); + } + + // Delegate to the requests backend's binder + if (strcmp(module, "requests") == 0) { + return bindForeignMethod(vm, module, className, isStatic, signature); + } + + // Handle host-specific methods + if (strcmp(module, "main") == 0 && strcmp(className, "Host") == 0 && isStatic) { + if (strcmp(signature, "signalDone()") == 0) return hostSignalDone; + } + + return NULL; +} + +WrenForeignClassMethods combinedBindForeignClass(WrenVM* vm, const char* module, const char* className) { + // Delegate to the socket backend's class binder + if (strcmp(module, "socket") == 0) { + return bindSocketForeignClass(vm, module, className); + } + + // Delegate to the requests backend's class binder + if (strcmp(module, "requests") == 0) { + return bindForeignClass(vm, module, className); + } + + WrenForeignClassMethods methods = {0, 0}; + return methods; +} + + +// --- Main Application Entry Point --- +int main(int argc, char* argv[]) { + if (argc < 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + + // Initialize libcurl for the requests module + curl_global_init(CURL_GLOBAL_ALL); + + WrenConfiguration config; + wrenInitConfiguration(&config); + config.writeFn = writeFn; + config.errorFn = errorFn; + config.bindForeignMethodFn = combinedBindForeignMethod; + config.bindForeignClassFn = combinedBindForeignClass; + config.loadModuleFn = loadModule; + + WrenVM* vm = wrenNewVM(&config); + + // ** Initialize BOTH managers ** + socketManager_create(vm); + httpManager_create(vm); + + char* mainSource = readFile(argv[1]); + if (!mainSource) { + fprintf(stderr, "Could not open script: %s\n", argv[1]); + socketManager_destroy(); + httpManager_destroy(); + wrenFreeVM(vm); + curl_global_cleanup(); + return 1; + } + + wrenInterpret(vm, "main", mainSource); + free(mainSource); + + if (g_mainFiberIsDone) { + socketManager_destroy(); + httpManager_destroy(); + wrenFreeVM(vm); + curl_global_cleanup(); + return 1; + } + + wrenEnsureSlots(vm, 1); + wrenGetVariable(vm, "main", "mainFiber", 0); + WrenHandle* mainFiberHandle = wrenGetSlotHandle(vm, 0); + WrenHandle* callHandle = wrenMakeCallHandle(vm, "call()"); + + // === Main Event Loop === + while (!g_mainFiberIsDone) { + // ** Process completions for BOTH managers ** + socketManager_processCompletions(); + httpManager_processCompletions(); + + // Resume the main Wren fiber + wrenEnsureSlots(vm, 1); + wrenSetSlotHandle(vm, 0, mainFiberHandle); + WrenInterpretResult result = wrenCall(vm, callHandle); + if (result == WREN_RESULT_RUNTIME_ERROR) { + g_mainFiberIsDone = true; + } + + // Prevent 100% CPU usage + #ifdef _WIN32 + Sleep(1); + #else + usleep(1000); // 1ms + #endif + } + + // Process any final completions before shutting down + socketManager_processCompletions(); + httpManager_processCompletions(); + + wrenReleaseHandle(vm, mainFiberHandle); + wrenReleaseHandle(vm, callHandle); + + // ** Destroy BOTH managers ** + socketManager_destroy(); + httpManager_destroy(); + + wrenFreeVM(vm); + curl_global_cleanup(); + + printf("\nHost application finished.\n"); + return 0; +} + +// End of main.c + +// Start of requests_backend.c +// http_backend.c (Corrected) +#include "wren.h" +#include +#include +#include +#include + +#ifdef _WIN32 + #include + typedef HANDLE thread_t; + typedef CRITICAL_SECTION mutex_t; + typedef CONDITION_VARIABLE cond_t; +#else + #include + typedef pthread_t thread_t; + typedef pthread_mutex_t mutex_t; + typedef pthread_cond_t cond_t; +#endif + +// --- Data Structures --- + +typedef struct { + int isError; + long statusCode; + char* body; + size_t body_len; +} ResponseData; + +typedef struct { + char* memory; + size_t size; +} MemoryStruct; + +typedef struct HttpContext { + WrenVM* vm; + WrenHandle* callback; + + char* url; + char* method; + char* body; + struct curl_slist* headers; + + bool success; + char* response_body; + size_t response_body_len; + long status_code; + char* error_message; + struct HttpContext* next; +} HttpContext; + + +// --- Thread-Safe Queue --- + +typedef struct { + HttpContext *head, *tail; + mutex_t mutex; + cond_t cond; +} ThreadSafeQueue; + +void http_queue_init(ThreadSafeQueue* q) { + q->head = q->tail = NULL; + #ifdef _WIN32 + InitializeCriticalSection(&q->mutex); + InitializeConditionVariable(&q->cond); + #else + pthread_mutex_init(&q->mutex, NULL); + pthread_cond_init(&q->cond, NULL); + #endif +} + +void http_queue_destroy(ThreadSafeQueue* q) { + #ifdef _WIN32 + DeleteCriticalSection(&q->mutex); + #else + pthread_mutex_destroy(&q->mutex); + pthread_cond_destroy(&q->cond); + #endif +} + +void http_queue_push(ThreadSafeQueue* q, HttpContext* context) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + #endif + + if(context) context->next = NULL; + if (q->tail) q->tail->next = context; + else q->head = context; + q->tail = context; + + #ifdef _WIN32 + WakeConditionVariable(&q->cond); + LeaveCriticalSection(&q->mutex); + #else + pthread_cond_signal(&q->cond); + pthread_mutex_unlock(&q->mutex); + #endif +} + +HttpContext* http_queue_pop(ThreadSafeQueue* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + while (q->head == NULL) { + SleepConditionVariableCS(&q->cond, &q->mutex, INFINITE); + } + #else + pthread_mutex_lock(&q->mutex); + while (q->head == NULL) { + pthread_cond_wait(&q->cond, &q->mutex); + } + #endif + + HttpContext* context = q->head; + q->head = q->head->next; + if (q->head == NULL) q->tail = NULL; + + #ifdef _WIN32 + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_unlock(&q->mutex); + #endif + + return context; +} + +bool http_queue_empty(ThreadSafeQueue* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + bool empty = (q->head == NULL); + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + bool empty = (q->head == NULL); + pthread_mutex_unlock(&q->mutex); + #endif + return empty; +} + + +// --- libcurl Helpers --- +static size_t write_memory_callback(void *contents, size_t size, size_t nmemb, void *userp) { + size_t realsize = size * nmemb; + MemoryStruct *mem = (MemoryStruct *)userp; + char *ptr = (char*)realloc(mem->memory, mem->size + realsize + 1); + if (ptr == NULL) return 0; + mem->memory = ptr; + memcpy(&(mem->memory[mem->size]), contents, realsize); + mem->size += realsize; + mem->memory[mem->size] = 0; + return realsize; +} + +// --- Async HTTP Manager --- + +typedef struct { + WrenVM* vm; + volatile bool running; + thread_t threads[4]; + ThreadSafeQueue requestQueue; + ThreadSafeQueue completionQueue; +} AsyncHttpManager; + +static AsyncHttpManager* httpManager = NULL; + +void free_http_context(HttpContext* context) { + if (!context) return; + free(context->url); + free(context->method); + free(context->body); + curl_slist_free_all(context->headers); + free(context->response_body); + free(context->error_message); + free(context); +} + +#ifdef _WIN32 +DWORD WINAPI httpWorkerThread(LPVOID arg) { +#else +void* httpWorkerThread(void* arg) { +#endif + AsyncHttpManager* manager = (AsyncHttpManager*)arg; + while (manager->running) { + HttpContext* context = http_queue_pop(&manager->requestQueue); + if (!context || !manager->running) { + if (context) free_http_context(context); + break; + } + + CURL *curl = curl_easy_init(); + if (curl) { + MemoryStruct chunk; + chunk.memory = (char*)malloc(1); + chunk.size = 0; + + curl_easy_setopt(curl, CURLOPT_URL, context->url); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_memory_callback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void *)&chunk); + curl_easy_setopt(curl, CURLOPT_USERAGENT, "wren-curl-agent/1.0"); + + if (strcmp(context->method, "POST") == 0) { + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, context->body); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, context->headers); + } + + CURLcode res = curl_easy_perform(curl); + + if (res == CURLE_OK) { + context->success = true; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &context->status_code); + context->response_body = chunk.memory; + context->response_body_len = chunk.size; + } else { + context->success = false; + context->status_code = -1; + context->error_message = strdup(curl_easy_strerror(res)); + free(chunk.memory); + } + curl_easy_cleanup(curl); + } else { + context->success = false; + context->error_message = strdup("Failed to initialize cURL handle."); + } + http_queue_push(&manager->completionQueue, context); + } + return 0; +} + +void httpManager_create(WrenVM* vm) { + httpManager = (AsyncHttpManager*)malloc(sizeof(AsyncHttpManager)); + httpManager->vm = vm; + httpManager->running = true; + http_queue_init(&httpManager->requestQueue); + http_queue_init(&httpManager->completionQueue); + for (int i = 0; i < 4; ++i) { + #ifdef _WIN32 + httpManager->threads[i] = CreateThread(NULL, 0, httpWorkerThread, httpManager, 0, NULL); + #else + pthread_create(&httpManager->threads[i], NULL, httpWorkerThread, httpManager); + #endif + } +} + +void httpManager_destroy() { + httpManager->running = false; + for (int i = 0; i < 4; ++i) { + http_queue_push(&httpManager->requestQueue, NULL); + } + for (int i = 0; i < 4; ++i) { + #ifdef _WIN32 + WaitForSingleObject(httpManager->threads[i], INFINITE); + CloseHandle(httpManager->threads[i]); + #else + pthread_join(httpManager->threads[i], NULL); + #endif + } + http_queue_destroy(&httpManager->requestQueue); + http_queue_destroy(&httpManager->completionQueue); + free(httpManager); +} + +void httpManager_processCompletions() { + while (!http_queue_empty(&httpManager->completionQueue)) { + HttpContext* context = http_queue_pop(&httpManager->completionQueue); + + WrenHandle* callHandle = wrenMakeCallHandle(httpManager->vm, "call(_,_)"); + wrenEnsureSlots(httpManager->vm, 3); + wrenSetSlotHandle(httpManager->vm, 0, context->callback); + + if (context->success) { + wrenSetSlotNull(httpManager->vm, 1); + + wrenGetVariable(httpManager->vm, "requests", "Response", 2); + void* foreign = wrenSetSlotNewForeign(httpManager->vm, 2, 2, sizeof(ResponseData)); + ResponseData* data = (ResponseData*)foreign; + data->isError = false; + data->statusCode = context->status_code; + data->body = context->response_body; + data->body_len = context->response_body_len; + context->response_body = NULL; + } else { + wrenSetSlotString(httpManager->vm, 1, context->error_message); + wrenSetSlotNull(httpManager->vm, 2); + } + + wrenCall(httpManager->vm, callHandle); + wrenReleaseHandle(httpManager->vm, context->callback); + wrenReleaseHandle(httpManager->vm, callHandle); + free_http_context(context); + } +} + +void httpManager_submit(HttpContext* context) { + http_queue_push(&httpManager->requestQueue, context); +} + +// --- Wren Foreign Methods --- + +void responseFinalize(void* data) { + ResponseData* response = (ResponseData*)data; + free(response->body); +} + +void responseAllocate(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenSetSlotNewForeign(vm, 0, 0, sizeof(ResponseData)); + data->isError = 0; + data->statusCode = 0; + data->body = NULL; + data->body_len = 0; +} + +void responseIsError(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBool(vm, 0, data->isError ? true : false); +} + +void responseStatusCode(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotDouble(vm, 0, (double)data->statusCode); +} + +void responseBody(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBytes(vm, 0, data->body ? data->body : "", data->body_len); +} + +void responseJson(WrenVM* vm) { + // CORRECTED: Replaced incorrect call with the actual logic. + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBytes(vm, 0, data->body ? data->body : "", data->body_len); +} + +void requestsGet(WrenVM* vm) { + HttpContext* context = (HttpContext*)calloc(1, sizeof(HttpContext)); + context->vm = vm; + context->method = strdup("GET"); + context->url = strdup(wrenGetSlotString(vm, 1)); + context->callback = wrenGetSlotHandle(vm, 3); + httpManager_submit(context); +} + +void requestsPost(WrenVM* vm) { + HttpContext* context = (HttpContext*)calloc(1, sizeof(HttpContext)); + context->vm = vm; + context->method = strdup("POST"); + context->url = strdup(wrenGetSlotString(vm, 1)); + context->body = strdup(wrenGetSlotString(vm, 2)); + const char* contentType = wrenGetSlotString(vm, 3); + char contentTypeHeader[256]; + snprintf(contentTypeHeader, sizeof(contentTypeHeader), "Content-Type: %s", contentType); + context->headers = curl_slist_append(NULL, contentTypeHeader); + context->callback = wrenGetSlotHandle(vm, 5); + httpManager_submit(context); +} + +// --- FFI Binding Functions --- + +WrenForeignMethodFn bindForeignMethod(WrenVM* vm, const char* module, + const char* className, bool isStatic, const char* signature) { + if (strcmp(module, "requests") != 0) return NULL; + + if (strcmp(className, "Requests") == 0 && isStatic) { + if (strcmp(signature, "get_(_,_,_)") == 0) return requestsGet; + if (strcmp(signature, "post_(_,_,_,_,_)") == 0) return requestsPost; + } + + if (strcmp(className, "Response") == 0 && !isStatic) { + if (strcmp(signature, "isError") == 0) return responseIsError; + if (strcmp(signature, "statusCode") == 0) return responseStatusCode; + if (strcmp(signature, "body") == 0) return responseBody; + if (strcmp(signature, "json()") == 0) return responseJson; + } + + return NULL; +} + +WrenForeignClassMethods bindForeignClass(WrenVM* vm, const char* module, const char* className) { + WrenForeignClassMethods methods = {0, 0}; + if (strcmp(module, "requests") == 0) { + if (strcmp(className, "Response") == 0) { + methods.allocate = responseAllocate; + methods.finalize = responseFinalize; + } + } + return methods; +} + +// End of requests_backend.c + +// Start of wren.h +#ifndef wren_h +#define wren_h + +#include +#include +#include + +// The Wren semantic version number components. +#define WREN_VERSION_MAJOR 0 +#define WREN_VERSION_MINOR 4 +#define WREN_VERSION_PATCH 0 + +// A human-friendly string representation of the version. +#define WREN_VERSION_STRING "0.4.0" + +// A monotonically increasing numeric representation of the version number. Use +// this if you want to do range checks over versions. +#define WREN_VERSION_NUMBER (WREN_VERSION_MAJOR * 1000000 + \ + WREN_VERSION_MINOR * 1000 + \ + WREN_VERSION_PATCH) + +#ifndef WREN_API + #if defined(_MSC_VER) && defined(WREN_API_DLLEXPORT) + #define WREN_API __declspec( dllexport ) + #else + #define WREN_API + #endif +#endif //WREN_API + +// A single virtual machine for executing Wren code. +// +// Wren has no global state, so all state stored by a running interpreter lives +// here. +typedef struct WrenVM WrenVM; + +// A handle to a Wren object. +// +// This lets code outside of the VM hold a persistent reference to an object. +// After a handle is acquired, and until it is released, this ensures the +// garbage collector will not reclaim the object it references. +typedef struct WrenHandle WrenHandle; + +// A generic allocation function that handles all explicit memory management +// used by Wren. It's used like so: +// +// - To allocate new memory, [memory] is NULL and [newSize] is the desired +// size. It should return the allocated memory or NULL on failure. +// +// - To attempt to grow an existing allocation, [memory] is the memory, and +// [newSize] is the desired size. It should return [memory] if it was able to +// grow it in place, or a new pointer if it had to move it. +// +// - To shrink memory, [memory] and [newSize] are the same as above but it will +// always return [memory]. +// +// - To free memory, [memory] will be the memory to free and [newSize] will be +// zero. It should return NULL. +typedef void* (*WrenReallocateFn)(void* memory, size_t newSize, void* userData); + +// A function callable from Wren code, but implemented in C. +typedef void (*WrenForeignMethodFn)(WrenVM* vm); + +// A finalizer function for freeing resources owned by an instance of a foreign +// class. Unlike most foreign methods, finalizers do not have access to the VM +// and should not interact with it since it's in the middle of a garbage +// collection. +typedef void (*WrenFinalizerFn)(void* data); + +// Gives the host a chance to canonicalize the imported module name, +// potentially taking into account the (previously resolved) name of the module +// that contains the import. Typically, this is used to implement relative +// imports. +typedef const char* (*WrenResolveModuleFn)(WrenVM* vm, + const char* importer, const char* name); + +// Forward declare +struct WrenLoadModuleResult; + +// Called after loadModuleFn is called for module [name]. The original returned result +// is handed back to you in this callback, so that you can free memory if appropriate. +typedef void (*WrenLoadModuleCompleteFn)(WrenVM* vm, const char* name, struct WrenLoadModuleResult result); + +// The result of a loadModuleFn call. +// [source] is the source code for the module, or NULL if the module is not found. +// [onComplete] an optional callback that will be called once Wren is done with the result. +typedef struct WrenLoadModuleResult +{ + const char* source; + WrenLoadModuleCompleteFn onComplete; + void* userData; +} WrenLoadModuleResult; + +// Loads and returns the source code for the module [name]. +typedef WrenLoadModuleResult (*WrenLoadModuleFn)(WrenVM* vm, const char* name); + +// Returns a pointer to a foreign method on [className] in [module] with +// [signature]. +typedef WrenForeignMethodFn (*WrenBindForeignMethodFn)(WrenVM* vm, + const char* module, const char* className, bool isStatic, + const char* signature); + +// Displays a string of text to the user. +typedef void (*WrenWriteFn)(WrenVM* vm, const char* text); + +typedef enum +{ + // A syntax or resolution error detected at compile time. + WREN_ERROR_COMPILE, + + // The error message for a runtime error. + WREN_ERROR_RUNTIME, + + // One entry of a runtime error's stack trace. + WREN_ERROR_STACK_TRACE +} WrenErrorType; + +// Reports an error to the user. +// +// An error detected during compile time is reported by calling this once with +// [type] `WREN_ERROR_COMPILE`, the resolved name of the [module] and [line] +// where the error occurs, and the compiler's error [message]. +// +// A runtime error is reported by calling this once with [type] +// `WREN_ERROR_RUNTIME`, no [module] or [line], and the runtime error's +// [message]. After that, a series of [type] `WREN_ERROR_STACK_TRACE` calls are +// made for each line in the stack trace. Each of those has the resolved +// [module] and [line] where the method or function is defined and [message] is +// the name of the method or function. +typedef void (*WrenErrorFn)( + WrenVM* vm, WrenErrorType type, const char* module, int line, + const char* message); + +typedef struct +{ + // The callback invoked when the foreign object is created. + // + // This must be provided. Inside the body of this, it must call + // [wrenSetSlotNewForeign()] exactly once. + WrenForeignMethodFn allocate; + + // The callback invoked when the garbage collector is about to collect a + // foreign object's memory. + // + // This may be `NULL` if the foreign class does not need to finalize. + WrenFinalizerFn finalize; +} WrenForeignClassMethods; + +// Returns a pair of pointers to the foreign methods used to allocate and +// finalize the data for instances of [className] in resolved [module]. +typedef WrenForeignClassMethods (*WrenBindForeignClassFn)( + WrenVM* vm, const char* module, const char* className); + +typedef struct +{ + // The callback Wren will use to allocate, reallocate, and deallocate memory. + // + // If `NULL`, defaults to a built-in function that uses `realloc` and `free`. + WrenReallocateFn reallocateFn; + + // The callback Wren uses to resolve a module name. + // + // Some host applications may wish to support "relative" imports, where the + // meaning of an import string depends on the module that contains it. To + // support that without baking any policy into Wren itself, the VM gives the + // host a chance to resolve an import string. + // + // Before an import is loaded, it calls this, passing in the name of the + // module that contains the import and the import string. The host app can + // look at both of those and produce a new "canonical" string that uniquely + // identifies the module. This string is then used as the name of the module + // going forward. It is what is passed to [loadModuleFn], how duplicate + // imports of the same module are detected, and how the module is reported in + // stack traces. + // + // If you leave this function NULL, then the original import string is + // treated as the resolved string. + // + // If an import cannot be resolved by the embedder, it should return NULL and + // Wren will report that as a runtime error. + // + // Wren will take ownership of the string you return and free it for you, so + // it should be allocated using the same allocation function you provide + // above. + WrenResolveModuleFn resolveModuleFn; + + // The callback Wren uses to load a module. + // + // Since Wren does not talk directly to the file system, it relies on the + // embedder to physically locate and read the source code for a module. The + // first time an import appears, Wren will call this and pass in the name of + // the module being imported. The method will return a result, which contains + // the source code for that module. Memory for the source is owned by the + // host application, and can be freed using the onComplete callback. + // + // This will only be called once for any given module name. Wren caches the + // result internally so subsequent imports of the same module will use the + // previous source and not call this. + // + // If a module with the given name could not be found by the embedder, it + // should return NULL and Wren will report that as a runtime error. + WrenLoadModuleFn loadModuleFn; + + // The callback Wren uses to find a foreign method and bind it to a class. + // + // When a foreign method is declared in a class, this will be called with the + // foreign method's module, class, and signature when the class body is + // executed. It should return a pointer to the foreign function that will be + // bound to that method. + // + // If the foreign function could not be found, this should return NULL and + // Wren will report it as runtime error. + WrenBindForeignMethodFn bindForeignMethodFn; + + // The callback Wren uses to find a foreign class and get its foreign methods. + // + // When a foreign class is declared, this will be called with the class's + // module and name when the class body is executed. It should return the + // foreign functions uses to allocate and (optionally) finalize the bytes + // stored in the foreign object when an instance is created. + WrenBindForeignClassFn bindForeignClassFn; + + // The callback Wren uses to display text when `System.print()` or the other + // related functions are called. + // + // If this is `NULL`, Wren discards any printed text. + WrenWriteFn writeFn; + + // The callback Wren uses to report errors. + // + // When an error occurs, this will be called with the module name, line + // number, and an error message. If this is `NULL`, Wren doesn't report any + // errors. + WrenErrorFn errorFn; + + // The number of bytes Wren will allocate before triggering the first garbage + // collection. + // + // If zero, defaults to 10MB. + size_t initialHeapSize; + + // After a collection occurs, the threshold for the next collection is + // determined based on the number of bytes remaining in use. This allows Wren + // to shrink its memory usage automatically after reclaiming a large amount + // of memory. + // + // This can be used to ensure that the heap does not get too small, which can + // in turn lead to a large number of collections afterwards as the heap grows + // back to a usable size. + // + // If zero, defaults to 1MB. + size_t minHeapSize; + + // Wren will resize the heap automatically as the number of bytes + // remaining in use after a collection changes. This number determines the + // amount of additional memory Wren will use after a collection, as a + // percentage of the current heap size. + // + // For example, say that this is 50. After a garbage collection, when there + // are 400 bytes of memory still in use, the next collection will be triggered + // after a total of 600 bytes are allocated (including the 400 already in + // use.) + // + // Setting this to a smaller number wastes less memory, but triggers more + // frequent garbage collections. + // + // If zero, defaults to 50. + int heapGrowthPercent; + + // User-defined data associated with the VM. + void* userData; + +} WrenConfiguration; + +typedef enum +{ + WREN_RESULT_SUCCESS, + WREN_RESULT_COMPILE_ERROR, + WREN_RESULT_RUNTIME_ERROR +} WrenInterpretResult; + +// The type of an object stored in a slot. +// +// This is not necessarily the object's *class*, but instead its low level +// representation type. +typedef enum +{ + WREN_TYPE_BOOL, + WREN_TYPE_NUM, + WREN_TYPE_FOREIGN, + WREN_TYPE_LIST, + WREN_TYPE_MAP, + WREN_TYPE_NULL, + WREN_TYPE_STRING, + + // The object is of a type that isn't accessible by the C API. + WREN_TYPE_UNKNOWN +} WrenType; + +// Get the current wren version number. +// +// Can be used to range checks over versions. +WREN_API int wrenGetVersionNumber(); + +// Initializes [configuration] with all of its default values. +// +// Call this before setting the particular fields you care about. +WREN_API void wrenInitConfiguration(WrenConfiguration* configuration); + +// Creates a new Wren virtual machine using the given [configuration]. Wren +// will copy the configuration data, so the argument passed to this can be +// freed after calling this. If [configuration] is `NULL`, uses a default +// configuration. +WREN_API WrenVM* wrenNewVM(WrenConfiguration* configuration); + +// Disposes of all resources is use by [vm], which was previously created by a +// call to [wrenNewVM]. +WREN_API void wrenFreeVM(WrenVM* vm); + +// Immediately run the garbage collector to free unused memory. +WREN_API void wrenCollectGarbage(WrenVM* vm); + +// Runs [source], a string of Wren source code in a new fiber in [vm] in the +// context of resolved [module]. +WREN_API WrenInterpretResult wrenInterpret(WrenVM* vm, const char* module, + const char* source); + +// Creates a handle that can be used to invoke a method with [signature] on +// using a receiver and arguments that are set up on the stack. +// +// This handle can be used repeatedly to directly invoke that method from C +// code using [wrenCall]. +// +// When you are done with this handle, it must be released using +// [wrenReleaseHandle]. +WREN_API WrenHandle* wrenMakeCallHandle(WrenVM* vm, const char* signature); + +// Calls [method], using the receiver and arguments previously set up on the +// stack. +// +// [method] must have been created by a call to [wrenMakeCallHandle]. The +// arguments to the method must be already on the stack. The receiver should be +// in slot 0 with the remaining arguments following it, in order. It is an +// error if the number of arguments provided does not match the method's +// signature. +// +// After this returns, you can access the return value from slot 0 on the stack. +WREN_API WrenInterpretResult wrenCall(WrenVM* vm, WrenHandle* method); + +// Releases the reference stored in [handle]. After calling this, [handle] can +// no longer be used. +WREN_API void wrenReleaseHandle(WrenVM* vm, WrenHandle* handle); + +// The following functions are intended to be called from foreign methods or +// finalizers. The interface Wren provides to a foreign method is like a +// register machine: you are given a numbered array of slots that values can be +// read from and written to. Values always live in a slot (unless explicitly +// captured using wrenGetSlotHandle(), which ensures the garbage collector can +// find them. +// +// When your foreign function is called, you are given one slot for the receiver +// and each argument to the method. The receiver is in slot 0 and the arguments +// are in increasingly numbered slots after that. You are free to read and +// write to those slots as you want. If you want more slots to use as scratch +// space, you can call wrenEnsureSlots() to add more. +// +// When your function returns, every slot except slot zero is discarded and the +// value in slot zero is used as the return value of the method. If you don't +// store a return value in that slot yourself, it will retain its previous +// value, the receiver. +// +// While Wren is dynamically typed, C is not. This means the C interface has to +// support the various types of primitive values a Wren variable can hold: bool, +// double, string, etc. If we supported this for every operation in the C API, +// there would be a combinatorial explosion of functions, like "get a +// double-valued element from a list", "insert a string key and double value +// into a map", etc. +// +// To avoid that, the only way to convert to and from a raw C value is by going +// into and out of a slot. All other functions work with values already in a +// slot. So, to add an element to a list, you put the list in one slot, and the +// element in another. Then there is a single API function wrenInsertInList() +// that takes the element out of that slot and puts it into the list. +// +// The goal of this API is to be easy to use while not compromising performance. +// The latter means it does not do type or bounds checking at runtime except +// using assertions which are generally removed from release builds. C is an +// unsafe language, so it's up to you to be careful to use it correctly. In +// return, you get a very fast FFI. + +// Returns the number of slots available to the current foreign method. +WREN_API int wrenGetSlotCount(WrenVM* vm); + +// Ensures that the foreign method stack has at least [numSlots] available for +// use, growing the stack if needed. +// +// Does not shrink the stack if it has more than enough slots. +// +// It is an error to call this from a finalizer. +WREN_API void wrenEnsureSlots(WrenVM* vm, int numSlots); + +// Gets the type of the object in [slot]. +WREN_API WrenType wrenGetSlotType(WrenVM* vm, int slot); + +// Reads a boolean value from [slot]. +// +// It is an error to call this if the slot does not contain a boolean value. +WREN_API bool wrenGetSlotBool(WrenVM* vm, int slot); + +// Reads a byte array from [slot]. +// +// The memory for the returned string is owned by Wren. You can inspect it +// while in your foreign method, but cannot keep a pointer to it after the +// function returns, since the garbage collector may reclaim it. +// +// Returns a pointer to the first byte of the array and fill [length] with the +// number of bytes in the array. +// +// It is an error to call this if the slot does not contain a string. +WREN_API const char* wrenGetSlotBytes(WrenVM* vm, int slot, int* length); + +// Reads a number from [slot]. +// +// It is an error to call this if the slot does not contain a number. +WREN_API double wrenGetSlotDouble(WrenVM* vm, int slot); + +// Reads a foreign object from [slot] and returns a pointer to the foreign data +// stored with it. +// +// It is an error to call this if the slot does not contain an instance of a +// foreign class. +WREN_API void* wrenGetSlotForeign(WrenVM* vm, int slot); + +// Reads a string from [slot]. +// +// The memory for the returned string is owned by Wren. You can inspect it +// while in your foreign method, but cannot keep a pointer to it after the +// function returns, since the garbage collector may reclaim it. +// +// It is an error to call this if the slot does not contain a string. +WREN_API const char* wrenGetSlotString(WrenVM* vm, int slot); + +// Creates a handle for the value stored in [slot]. +// +// This will prevent the object that is referred to from being garbage collected +// until the handle is released by calling [wrenReleaseHandle()]. +WREN_API WrenHandle* wrenGetSlotHandle(WrenVM* vm, int slot); + +// Stores the boolean [value] in [slot]. +WREN_API void wrenSetSlotBool(WrenVM* vm, int slot, bool value); + +// Stores the array [length] of [bytes] in [slot]. +// +// The bytes are copied to a new string within Wren's heap, so you can free +// memory used by them after this is called. +WREN_API void wrenSetSlotBytes(WrenVM* vm, int slot, const char* bytes, size_t length); + +// Stores the numeric [value] in [slot]. +WREN_API void wrenSetSlotDouble(WrenVM* vm, int slot, double value); + +// Creates a new instance of the foreign class stored in [classSlot] with [size] +// bytes of raw storage and places the resulting object in [slot]. +// +// This does not invoke the foreign class's constructor on the new instance. If +// you need that to happen, call the constructor from Wren, which will then +// call the allocator foreign method. In there, call this to create the object +// and then the constructor will be invoked when the allocator returns. +// +// Returns a pointer to the foreign object's data. +WREN_API void* wrenSetSlotNewForeign(WrenVM* vm, int slot, int classSlot, size_t size); + +// Stores a new empty list in [slot]. +WREN_API void wrenSetSlotNewList(WrenVM* vm, int slot); + +// Stores a new empty map in [slot]. +WREN_API void wrenSetSlotNewMap(WrenVM* vm, int slot); + +// Stores null in [slot]. +WREN_API void wrenSetSlotNull(WrenVM* vm, int slot); + +// Stores the string [text] in [slot]. +// +// The [text] is copied to a new string within Wren's heap, so you can free +// memory used by it after this is called. The length is calculated using +// [strlen()]. If the string may contain any null bytes in the middle, then you +// should use [wrenSetSlotBytes()] instead. +WREN_API void wrenSetSlotString(WrenVM* vm, int slot, const char* text); + +// Stores the value captured in [handle] in [slot]. +// +// This does not release the handle for the value. +WREN_API void wrenSetSlotHandle(WrenVM* vm, int slot, WrenHandle* handle); + +// Returns the number of elements in the list stored in [slot]. +WREN_API int wrenGetListCount(WrenVM* vm, int slot); + +// Reads element [index] from the list in [listSlot] and stores it in +// [elementSlot]. +WREN_API void wrenGetListElement(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Sets the value stored at [index] in the list at [listSlot], +// to the value from [elementSlot]. +WREN_API void wrenSetListElement(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Takes the value stored at [elementSlot] and inserts it into the list stored +// at [listSlot] at [index]. +// +// As in Wren, negative indexes can be used to insert from the end. To append +// an element, use `-1` for the index. +WREN_API void wrenInsertInList(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Returns the number of entries in the map stored in [slot]. +WREN_API int wrenGetMapCount(WrenVM* vm, int slot); + +// Returns true if the key in [keySlot] is found in the map placed in [mapSlot]. +WREN_API bool wrenGetMapContainsKey(WrenVM* vm, int mapSlot, int keySlot); + +// Retrieves a value with the key in [keySlot] from the map in [mapSlot] and +// stores it in [valueSlot]. +WREN_API void wrenGetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot); + +// Takes the value stored at [valueSlot] and inserts it into the map stored +// at [mapSlot] with key [keySlot]. +WREN_API void wrenSetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot); + +// Removes a value from the map in [mapSlot], with the key from [keySlot], +// and place it in [removedValueSlot]. If not found, [removedValueSlot] is +// set to null, the same behaviour as the Wren Map API. +WREN_API void wrenRemoveMapValue(WrenVM* vm, int mapSlot, int keySlot, + int removedValueSlot); + +// Looks up the top level variable with [name] in resolved [module] and stores +// it in [slot]. +WREN_API void wrenGetVariable(WrenVM* vm, const char* module, const char* name, + int slot); + +// Looks up the top level variable with [name] in resolved [module], +// returns false if not found. The module must be imported at the time, +// use wrenHasModule to ensure that before calling. +WREN_API bool wrenHasVariable(WrenVM* vm, const char* module, const char* name); + +// Returns true if [module] has been imported/resolved before, false if not. +WREN_API bool wrenHasModule(WrenVM* vm, const char* module); + +// Sets the current fiber to be aborted, and uses the value in [slot] as the +// runtime error object. +WREN_API void wrenAbortFiber(WrenVM* vm, int slot); + +// Returns the user data associated with the WrenVM. +WREN_API void* wrenGetUserData(WrenVM* vm); + +// Sets user data associated with the WrenVM. +WREN_API void wrenSetUserData(WrenVM* vm, void* userData); + +#endif + +// End of wren.h + +// Start of async_http.c +#include "httplib.h" +#include "wren.h" + +// A struct to hold the context for an asynchronous HTTP request +struct RequestContext { + std::string url; + WrenHandle* callback; + WrenVM* vm; + std::string response; + bool error; +}; + +// A class to manage asynchronous HTTP requests +class AsyncHttp { +public: + AsyncHttp(WrenVM* vm) : vm_(vm), running_(true) { + // Create a pool of worker threads + for (int i = 0; i < 4; ++i) { + threads_.emplace_back([this] { + while (running_) { + RequestContext* context = requestQueue_.pop(); + if (!running_) break; + + httplib::Client cli("http://example.com"); + if (auto res = cli.Get(context->url.c_str())) { + context->response = res->body; + context->error = false; + } else { + context->response = "Error: " + to_string(res.error()); + context->error = true; + } + + completionQueue_.push(context); + } + }); + } + } + + ~AsyncHttp() { + running_ = false; + // Add dummy requests to unblock worker threads + for (size_t i = 0; i < threads_.size(); ++i) { + requestQueue_.push(nullptr); + } + for (auto& thread : threads_) { + thread.join(); + } + } + + void request(const std::string& url, WrenHandle* callback) { + RequestContext* context = new RequestContext{url, callback, vm_}; + requestQueue_.push(context); + } + + void processCompletions() { + while (!completionQueue_.empty()) { + RequestContext* context = completionQueue_.pop(); + + // Create a handle for the callback function + WrenHandle* callHandle = wrenMakeCallHandle(vm_, "call(_)"); + + wrenEnsureSlots(vm_, 2); + wrenSetSlotHandle(vm_, 0, context->callback); + wrenSetSlotString(vm_, 1, context->response.c_str()); + wrenCall(vm_, callHandle); + + wrenReleaseHandle(vm_, callHandle); + wrenReleaseHandle(vm_, context->callback); + delete context; + } + } + +private: + WrenVM* vm_; + bool running_; + std::vector threads_; + ThreadSafeQueue requestQueue_; + ThreadSafeQueue completionQueue_; +}; + +// End of async_http.c + +// Start of wren.c +// MIT License +// +// Copyright (c) 2013-2021 Robert Nystrom and Wren Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Begin file "wren.h" +#ifndef wren_h +#define wren_h + +#include +#include +#include + +// The Wren semantic version number components. +#define WREN_VERSION_MAJOR 0 +#define WREN_VERSION_MINOR 4 +#define WREN_VERSION_PATCH 0 + +// A human-friendly string representation of the version. +#define WREN_VERSION_STRING "0.4.0" + +// A monotonically increasing numeric representation of the version number. Use +// this if you want to do range checks over versions. +#define WREN_VERSION_NUMBER (WREN_VERSION_MAJOR * 1000000 + \ + WREN_VERSION_MINOR * 1000 + \ + WREN_VERSION_PATCH) + +#ifndef WREN_API + #if defined(_MSC_VER) && defined(WREN_API_DLLEXPORT) + #define WREN_API __declspec( dllexport ) + #else + #define WREN_API + #endif +#endif //WREN_API + +// A single virtual machine for executing Wren code. +// +// Wren has no global state, so all state stored by a running interpreter lives +// here. +typedef struct WrenVM WrenVM; + +// A handle to a Wren object. +// +// This lets code outside of the VM hold a persistent reference to an object. +// After a handle is acquired, and until it is released, this ensures the +// garbage collector will not reclaim the object it references. +typedef struct WrenHandle WrenHandle; + +// A generic allocation function that handles all explicit memory management +// used by Wren. It's used like so: +// +// - To allocate new memory, [memory] is NULL and [newSize] is the desired +// size. It should return the allocated memory or NULL on failure. +// +// - To attempt to grow an existing allocation, [memory] is the memory, and +// [newSize] is the desired size. It should return [memory] if it was able to +// grow it in place, or a new pointer if it had to move it. +// +// - To shrink memory, [memory] and [newSize] are the same as above but it will +// always return [memory]. +// +// - To free memory, [memory] will be the memory to free and [newSize] will be +// zero. It should return NULL. +typedef void* (*WrenReallocateFn)(void* memory, size_t newSize, void* userData); + +// A function callable from Wren code, but implemented in C. +typedef void (*WrenForeignMethodFn)(WrenVM* vm); + +// A finalizer function for freeing resources owned by an instance of a foreign +// class. Unlike most foreign methods, finalizers do not have access to the VM +// and should not interact with it since it's in the middle of a garbage +// collection. +typedef void (*WrenFinalizerFn)(void* data); + +// Gives the host a chance to canonicalize the imported module name, +// potentially taking into account the (previously resolved) name of the module +// that contains the import. Typically, this is used to implement relative +// imports. +typedef const char* (*WrenResolveModuleFn)(WrenVM* vm, + const char* importer, const char* name); + +// Forward declare +struct WrenLoadModuleResult; + +// Called after loadModuleFn is called for module [name]. The original returned result +// is handed back to you in this callback, so that you can free memory if appropriate. +typedef void (*WrenLoadModuleCompleteFn)(WrenVM* vm, const char* name, struct WrenLoadModuleResult result); + +// The result of a loadModuleFn call. +// [source] is the source code for the module, or NULL if the module is not found. +// [onComplete] an optional callback that will be called once Wren is done with the result. +typedef struct WrenLoadModuleResult +{ + const char* source; + WrenLoadModuleCompleteFn onComplete; + void* userData; +} WrenLoadModuleResult; + +// Loads and returns the source code for the module [name]. +typedef WrenLoadModuleResult (*WrenLoadModuleFn)(WrenVM* vm, const char* name); + +// Returns a pointer to a foreign method on [className] in [module] with +// [signature]. +typedef WrenForeignMethodFn (*WrenBindForeignMethodFn)(WrenVM* vm, + const char* module, const char* className, bool isStatic, + const char* signature); + +// Displays a string of text to the user. +typedef void (*WrenWriteFn)(WrenVM* vm, const char* text); + +typedef enum +{ + // A syntax or resolution error detected at compile time. + WREN_ERROR_COMPILE, + + // The error message for a runtime error. + WREN_ERROR_RUNTIME, + + // One entry of a runtime error's stack trace. + WREN_ERROR_STACK_TRACE +} WrenErrorType; + +// Reports an error to the user. +// +// An error detected during compile time is reported by calling this once with +// [type] `WREN_ERROR_COMPILE`, the resolved name of the [module] and [line] +// where the error occurs, and the compiler's error [message]. +// +// A runtime error is reported by calling this once with [type] +// `WREN_ERROR_RUNTIME`, no [module] or [line], and the runtime error's +// [message]. After that, a series of [type] `WREN_ERROR_STACK_TRACE` calls are +// made for each line in the stack trace. Each of those has the resolved +// [module] and [line] where the method or function is defined and [message] is +// the name of the method or function. +typedef void (*WrenErrorFn)( + WrenVM* vm, WrenErrorType type, const char* module, int line, + const char* message); + +typedef struct +{ + // The callback invoked when the foreign object is created. + // + // This must be provided. Inside the body of this, it must call + // [wrenSetSlotNewForeign()] exactly once. + WrenForeignMethodFn allocate; + + // The callback invoked when the garbage collector is about to collect a + // foreign object's memory. + // + // This may be `NULL` if the foreign class does not need to finalize. + WrenFinalizerFn finalize; +} WrenForeignClassMethods; + +// Returns a pair of pointers to the foreign methods used to allocate and +// finalize the data for instances of [className] in resolved [module]. +typedef WrenForeignClassMethods (*WrenBindForeignClassFn)( + WrenVM* vm, const char* module, const char* className); + +typedef struct +{ + // The callback Wren will use to allocate, reallocate, and deallocate memory. + // + // If `NULL`, defaults to a built-in function that uses `realloc` and `free`. + WrenReallocateFn reallocateFn; + + // The callback Wren uses to resolve a module name. + // + // Some host applications may wish to support "relative" imports, where the + // meaning of an import string depends on the module that contains it. To + // support that without baking any policy into Wren itself, the VM gives the + // host a chance to resolve an import string. + // + // Before an import is loaded, it calls this, passing in the name of the + // module that contains the import and the import string. The host app can + // look at both of those and produce a new "canonical" string that uniquely + // identifies the module. This string is then used as the name of the module + // going forward. It is what is passed to [loadModuleFn], how duplicate + // imports of the same module are detected, and how the module is reported in + // stack traces. + // + // If you leave this function NULL, then the original import string is + // treated as the resolved string. + // + // If an import cannot be resolved by the embedder, it should return NULL and + // Wren will report that as a runtime error. + // + // Wren will take ownership of the string you return and free it for you, so + // it should be allocated using the same allocation function you provide + // above. + WrenResolveModuleFn resolveModuleFn; + + // The callback Wren uses to load a module. + // + // Since Wren does not talk directly to the file system, it relies on the + // embedder to physically locate and read the source code for a module. The + // first time an import appears, Wren will call this and pass in the name of + // the module being imported. The method will return a result, which contains + // the source code for that module. Memory for the source is owned by the + // host application, and can be freed using the onComplete callback. + // + // This will only be called once for any given module name. Wren caches the + // result internally so subsequent imports of the same module will use the + // previous source and not call this. + // + // If a module with the given name could not be found by the embedder, it + // should return NULL and Wren will report that as a runtime error. + WrenLoadModuleFn loadModuleFn; + + // The callback Wren uses to find a foreign method and bind it to a class. + // + // When a foreign method is declared in a class, this will be called with the + // foreign method's module, class, and signature when the class body is + // executed. It should return a pointer to the foreign function that will be + // bound to that method. + // + // If the foreign function could not be found, this should return NULL and + // Wren will report it as runtime error. + WrenBindForeignMethodFn bindForeignMethodFn; + + // The callback Wren uses to find a foreign class and get its foreign methods. + // + // When a foreign class is declared, this will be called with the class's + // module and name when the class body is executed. It should return the + // foreign functions uses to allocate and (optionally) finalize the bytes + // stored in the foreign object when an instance is created. + WrenBindForeignClassFn bindForeignClassFn; + + // The callback Wren uses to display text when `System.print()` or the other + // related functions are called. + // + // If this is `NULL`, Wren discards any printed text. + WrenWriteFn writeFn; + + // The callback Wren uses to report errors. + // + // When an error occurs, this will be called with the module name, line + // number, and an error message. If this is `NULL`, Wren doesn't report any + // errors. + WrenErrorFn errorFn; + + // The number of bytes Wren will allocate before triggering the first garbage + // collection. + // + // If zero, defaults to 10MB. + size_t initialHeapSize; + + // After a collection occurs, the threshold for the next collection is + // determined based on the number of bytes remaining in use. This allows Wren + // to shrink its memory usage automatically after reclaiming a large amount + // of memory. + // + // This can be used to ensure that the heap does not get too small, which can + // in turn lead to a large number of collections afterwards as the heap grows + // back to a usable size. + // + // If zero, defaults to 1MB. + size_t minHeapSize; + + // Wren will resize the heap automatically as the number of bytes + // remaining in use after a collection changes. This number determines the + // amount of additional memory Wren will use after a collection, as a + // percentage of the current heap size. + // + // For example, say that this is 50. After a garbage collection, when there + // are 400 bytes of memory still in use, the next collection will be triggered + // after a total of 600 bytes are allocated (including the 400 already in + // use.) + // + // Setting this to a smaller number wastes less memory, but triggers more + // frequent garbage collections. + // + // If zero, defaults to 50. + int heapGrowthPercent; + + // User-defined data associated with the VM. + void* userData; + +} WrenConfiguration; + +typedef enum +{ + WREN_RESULT_SUCCESS, + WREN_RESULT_COMPILE_ERROR, + WREN_RESULT_RUNTIME_ERROR +} WrenInterpretResult; + +// The type of an object stored in a slot. +// +// This is not necessarily the object's *class*, but instead its low level +// representation type. +typedef enum +{ + WREN_TYPE_BOOL, + WREN_TYPE_NUM, + WREN_TYPE_FOREIGN, + WREN_TYPE_LIST, + WREN_TYPE_MAP, + WREN_TYPE_NULL, + WREN_TYPE_STRING, + + // The object is of a type that isn't accessible by the C API. + WREN_TYPE_UNKNOWN +} WrenType; + +// Get the current wren version number. +// +// Can be used to range checks over versions. +WREN_API int wrenGetVersionNumber(); + +// Initializes [configuration] with all of its default values. +// +// Call this before setting the particular fields you care about. +WREN_API void wrenInitConfiguration(WrenConfiguration* configuration); + +// Creates a new Wren virtual machine using the given [configuration]. Wren +// will copy the configuration data, so the argument passed to this can be +// freed after calling this. If [configuration] is `NULL`, uses a default +// configuration. +WREN_API WrenVM* wrenNewVM(WrenConfiguration* configuration); + +// Disposes of all resources is use by [vm], which was previously created by a +// call to [wrenNewVM]. +WREN_API void wrenFreeVM(WrenVM* vm); + +// Immediately run the garbage collector to free unused memory. +WREN_API void wrenCollectGarbage(WrenVM* vm); + +// Runs [source], a string of Wren source code in a new fiber in [vm] in the +// context of resolved [module]. +WREN_API WrenInterpretResult wrenInterpret(WrenVM* vm, const char* module, + const char* source); + +// Creates a handle that can be used to invoke a method with [signature] on +// using a receiver and arguments that are set up on the stack. +// +// This handle can be used repeatedly to directly invoke that method from C +// code using [wrenCall]. +// +// When you are done with this handle, it must be released using +// [wrenReleaseHandle]. +WREN_API WrenHandle* wrenMakeCallHandle(WrenVM* vm, const char* signature); + +// Calls [method], using the receiver and arguments previously set up on the +// stack. +// +// [method] must have been created by a call to [wrenMakeCallHandle]. The +// arguments to the method must be already on the stack. The receiver should be +// in slot 0 with the remaining arguments following it, in order. It is an +// error if the number of arguments provided does not match the method's +// signature. +// +// After this returns, you can access the return value from slot 0 on the stack. +WREN_API WrenInterpretResult wrenCall(WrenVM* vm, WrenHandle* method); + +// Releases the reference stored in [handle]. After calling this, [handle] can +// no longer be used. +WREN_API void wrenReleaseHandle(WrenVM* vm, WrenHandle* handle); + +// The following functions are intended to be called from foreign methods or +// finalizers. The interface Wren provides to a foreign method is like a +// register machine: you are given a numbered array of slots that values can be +// read from and written to. Values always live in a slot (unless explicitly +// captured using wrenGetSlotHandle(), which ensures the garbage collector can +// find them. +// +// When your foreign function is called, you are given one slot for the receiver +// and each argument to the method. The receiver is in slot 0 and the arguments +// are in increasingly numbered slots after that. You are free to read and +// write to those slots as you want. If you want more slots to use as scratch +// space, you can call wrenEnsureSlots() to add more. +// +// When your function returns, every slot except slot zero is discarded and the +// value in slot zero is used as the return value of the method. If you don't +// store a return value in that slot yourself, it will retain its previous +// value, the receiver. +// +// While Wren is dynamically typed, C is not. This means the C interface has to +// support the various types of primitive values a Wren variable can hold: bool, +// double, string, etc. If we supported this for every operation in the C API, +// there would be a combinatorial explosion of functions, like "get a +// double-valued element from a list", "insert a string key and double value +// into a map", etc. +// +// To avoid that, the only way to convert to and from a raw C value is by going +// into and out of a slot. All other functions work with values already in a +// slot. So, to add an element to a list, you put the list in one slot, and the +// element in another. Then there is a single API function wrenInsertInList() +// that takes the element out of that slot and puts it into the list. +// +// The goal of this API is to be easy to use while not compromising performance. +// The latter means it does not do type or bounds checking at runtime except +// using assertions which are generally removed from release builds. C is an +// unsafe language, so it's up to you to be careful to use it correctly. In +// return, you get a very fast FFI. + +// Returns the number of slots available to the current foreign method. +WREN_API int wrenGetSlotCount(WrenVM* vm); + +// Ensures that the foreign method stack has at least [numSlots] available for +// use, growing the stack if needed. +// +// Does not shrink the stack if it has more than enough slots. +// +// It is an error to call this from a finalizer. +WREN_API void wrenEnsureSlots(WrenVM* vm, int numSlots); + +// Gets the type of the object in [slot]. +WREN_API WrenType wrenGetSlotType(WrenVM* vm, int slot); + +// Reads a boolean value from [slot]. +// +// It is an error to call this if the slot does not contain a boolean value. +WREN_API bool wrenGetSlotBool(WrenVM* vm, int slot); + +// Reads a byte array from [slot]. +// +// The memory for the returned string is owned by Wren. You can inspect it +// while in your foreign method, but cannot keep a pointer to it after the +// function returns, since the garbage collector may reclaim it. +// +// Returns a pointer to the first byte of the array and fill [length] with the +// number of bytes in the array. +// +// It is an error to call this if the slot does not contain a string. +WREN_API const char* wrenGetSlotBytes(WrenVM* vm, int slot, int* length); + +// Reads a number from [slot]. +// +// It is an error to call this if the slot does not contain a number. +WREN_API double wrenGetSlotDouble(WrenVM* vm, int slot); + +// Reads a foreign object from [slot] and returns a pointer to the foreign data +// stored with it. +// +// It is an error to call this if the slot does not contain an instance of a +// foreign class. +WREN_API void* wrenGetSlotForeign(WrenVM* vm, int slot); + +// Reads a string from [slot]. +// +// The memory for the returned string is owned by Wren. You can inspect it +// while in your foreign method, but cannot keep a pointer to it after the +// function returns, since the garbage collector may reclaim it. +// +// It is an error to call this if the slot does not contain a string. +WREN_API const char* wrenGetSlotString(WrenVM* vm, int slot); + +// Creates a handle for the value stored in [slot]. +// +// This will prevent the object that is referred to from being garbage collected +// until the handle is released by calling [wrenReleaseHandle()]. +WREN_API WrenHandle* wrenGetSlotHandle(WrenVM* vm, int slot); + +// Stores the boolean [value] in [slot]. +WREN_API void wrenSetSlotBool(WrenVM* vm, int slot, bool value); + +// Stores the array [length] of [bytes] in [slot]. +// +// The bytes are copied to a new string within Wren's heap, so you can free +// memory used by them after this is called. +WREN_API void wrenSetSlotBytes(WrenVM* vm, int slot, const char* bytes, size_t length); + +// Stores the numeric [value] in [slot]. +WREN_API void wrenSetSlotDouble(WrenVM* vm, int slot, double value); + +// Creates a new instance of the foreign class stored in [classSlot] with [size] +// bytes of raw storage and places the resulting object in [slot]. +// +// This does not invoke the foreign class's constructor on the new instance. If +// you need that to happen, call the constructor from Wren, which will then +// call the allocator foreign method. In there, call this to create the object +// and then the constructor will be invoked when the allocator returns. +// +// Returns a pointer to the foreign object's data. +WREN_API void* wrenSetSlotNewForeign(WrenVM* vm, int slot, int classSlot, size_t size); + +// Stores a new empty list in [slot]. +WREN_API void wrenSetSlotNewList(WrenVM* vm, int slot); + +// Stores a new empty map in [slot]. +WREN_API void wrenSetSlotNewMap(WrenVM* vm, int slot); + +// Stores null in [slot]. +WREN_API void wrenSetSlotNull(WrenVM* vm, int slot); + +// Stores the string [text] in [slot]. +// +// The [text] is copied to a new string within Wren's heap, so you can free +// memory used by it after this is called. The length is calculated using +// [strlen()]. If the string may contain any null bytes in the middle, then you +// should use [wrenSetSlotBytes()] instead. +WREN_API void wrenSetSlotString(WrenVM* vm, int slot, const char* text); + +// Stores the value captured in [handle] in [slot]. +// +// This does not release the handle for the value. +WREN_API void wrenSetSlotHandle(WrenVM* vm, int slot, WrenHandle* handle); + +// Returns the number of elements in the list stored in [slot]. +WREN_API int wrenGetListCount(WrenVM* vm, int slot); + +// Reads element [index] from the list in [listSlot] and stores it in +// [elementSlot]. +WREN_API void wrenGetListElement(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Sets the value stored at [index] in the list at [listSlot], +// to the value from [elementSlot]. +WREN_API void wrenSetListElement(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Takes the value stored at [elementSlot] and inserts it into the list stored +// at [listSlot] at [index]. +// +// As in Wren, negative indexes can be used to insert from the end. To append +// an element, use `-1` for the index. +WREN_API void wrenInsertInList(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Returns the number of entries in the map stored in [slot]. +WREN_API int wrenGetMapCount(WrenVM* vm, int slot); + +// Returns true if the key in [keySlot] is found in the map placed in [mapSlot]. +WREN_API bool wrenGetMapContainsKey(WrenVM* vm, int mapSlot, int keySlot); + +// Retrieves a value with the key in [keySlot] from the map in [mapSlot] and +// stores it in [valueSlot]. +WREN_API void wrenGetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot); + +// Takes the value stored at [valueSlot] and inserts it into the map stored +// at [mapSlot] with key [keySlot]. +WREN_API void wrenSetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot); + +// Removes a value from the map in [mapSlot], with the key from [keySlot], +// and place it in [removedValueSlot]. If not found, [removedValueSlot] is +// set to null, the same behaviour as the Wren Map API. +WREN_API void wrenRemoveMapValue(WrenVM* vm, int mapSlot, int keySlot, + int removedValueSlot); + +// Looks up the top level variable with [name] in resolved [module] and stores +// it in [slot]. +WREN_API void wrenGetVariable(WrenVM* vm, const char* module, const char* name, + int slot); + +// Looks up the top level variable with [name] in resolved [module], +// returns false if not found. The module must be imported at the time, +// use wrenHasModule to ensure that before calling. +WREN_API bool wrenHasVariable(WrenVM* vm, const char* module, const char* name); + +// Returns true if [module] has been imported/resolved before, false if not. +WREN_API bool wrenHasModule(WrenVM* vm, const char* module); + +// Sets the current fiber to be aborted, and uses the value in [slot] as the +// runtime error object. +WREN_API void wrenAbortFiber(WrenVM* vm, int slot); + +// Returns the user data associated with the WrenVM. +WREN_API void* wrenGetUserData(WrenVM* vm); + +// Sets user data associated with the WrenVM. +WREN_API void wrenSetUserData(WrenVM* vm, void* userData); + +#endif +// End file "wren.h" +// Begin file "wren_debug.h" +#ifndef wren_debug_h +#define wren_debug_h + +// Begin file "wren_value.h" +#ifndef wren_value_h +#define wren_value_h + +#include +#include + +// Begin file "wren_common.h" +#ifndef wren_common_h +#define wren_common_h + +// This header contains macros and defines used across the entire Wren +// implementation. In particular, it contains "configuration" defines that +// control how Wren works. Some of these are only used while hacking on Wren +// itself. +// +// This header is *not* intended to be included by code outside of Wren itself. + +// Wren pervasively uses the C99 integer types (uint16_t, etc.) along with some +// of the associated limit constants (UINT32_MAX, etc.). The constants are not +// part of standard C++, so aren't included by default by C++ compilers when you +// include unless __STDC_LIMIT_MACROS is defined. +#define __STDC_LIMIT_MACROS +#include + +// These flags let you control some details of the interpreter's implementation. +// Usually they trade-off a bit of portability for speed. They default to the +// most efficient behavior. + +// If true, then Wren uses a NaN-tagged double for its core value +// representation. Otherwise, it uses a larger more conventional struct. The +// former is significantly faster and more compact. The latter is useful for +// debugging and may be more portable. +// +// Defaults to on. +#ifndef WREN_NAN_TAGGING + #define WREN_NAN_TAGGING 1 +#endif + +// If true, the VM's interpreter loop uses computed gotos. See this for more: +// http://gcc.gnu.org/onlinedocs/gcc-3.1.1/gcc/Labels-as-Values.html +// Enabling this speeds up the main dispatch loop a bit, but requires compiler +// support. +// see https://bullno1.com/blog/switched-goto for alternative +// Defaults to true on supported compilers. +#ifndef WREN_COMPUTED_GOTO + #if defined(_MSC_VER) && !defined(__clang__) + // No computed gotos in Visual Studio. + #define WREN_COMPUTED_GOTO 0 + #else + #define WREN_COMPUTED_GOTO 1 + #endif +#endif + +// The VM includes a number of optional modules. You can choose to include +// these or not. By default, they are all available. To disable one, set the +// corresponding `WREN_OPT_` define to `0`. +#ifndef WREN_OPT_META + #define WREN_OPT_META 1 +#endif + +#ifndef WREN_OPT_RANDOM + #define WREN_OPT_RANDOM 1 +#endif + +// These flags are useful for debugging and hacking on Wren itself. They are not +// intended to be used for production code. They default to off. + +// Set this to true to stress test the GC. It will perform a collection before +// every allocation. This is useful to ensure that memory is always correctly +// reachable. +#define WREN_DEBUG_GC_STRESS 0 + +// Set this to true to log memory operations as they occur. +#define WREN_DEBUG_TRACE_MEMORY 0 + +// Set this to true to log garbage collections as they occur. +#define WREN_DEBUG_TRACE_GC 0 + +// Set this to true to print out the compiled bytecode of each function. +#define WREN_DEBUG_DUMP_COMPILED_CODE 0 + +// Set this to trace each instruction as it's executed. +#define WREN_DEBUG_TRACE_INSTRUCTIONS 0 + +// The maximum number of module-level variables that may be defined at one time. +// This limitation comes from the 16 bits used for the arguments to +// `CODE_LOAD_MODULE_VAR` and `CODE_STORE_MODULE_VAR`. +#define MAX_MODULE_VARS 65536 + +// The maximum number of arguments that can be passed to a method. Note that +// this limitation is hardcoded in other places in the VM, in particular, the +// `CODE_CALL_XX` instructions assume a certain maximum number. +#define MAX_PARAMETERS 16 + +// The maximum name of a method, not including the signature. This is an +// arbitrary but enforced maximum just so we know how long the method name +// strings need to be in the parser. +#define MAX_METHOD_NAME 64 + +// The maximum length of a method signature. Signatures look like: +// +// foo // Getter. +// foo() // No-argument method. +// foo(_) // One-argument method. +// foo(_,_) // Two-argument method. +// init foo() // Constructor initializer. +// +// The maximum signature length takes into account the longest method name, the +// maximum number of parameters with separators between them, "init ", and "()". +#define MAX_METHOD_SIGNATURE (MAX_METHOD_NAME + (MAX_PARAMETERS * 2) + 6) + +// The maximum length of an identifier. The only real reason for this limitation +// is so that error messages mentioning variables can be stack allocated. +#define MAX_VARIABLE_NAME 64 + +// The maximum number of fields a class can have, including inherited fields. +// This is explicit in the bytecode since `CODE_CLASS` and `CODE_SUBCLASS` take +// a single byte for the number of fields. Note that it's 255 and not 256 +// because creating a class takes the *number* of fields, not the *highest +// field index*. +#define MAX_FIELDS 255 + +// Use the VM's allocator to allocate an object of [type]. +#define ALLOCATE(vm, type) \ + ((type*)wrenReallocate(vm, NULL, 0, sizeof(type))) + +// Use the VM's allocator to allocate an object of [mainType] containing a +// flexible array of [count] objects of [arrayType]. +#define ALLOCATE_FLEX(vm, mainType, arrayType, count) \ + ((mainType*)wrenReallocate(vm, NULL, 0, \ + sizeof(mainType) + sizeof(arrayType) * (count))) + +// Use the VM's allocator to allocate an array of [count] elements of [type]. +#define ALLOCATE_ARRAY(vm, type, count) \ + ((type*)wrenReallocate(vm, NULL, 0, sizeof(type) * (count))) + +// Use the VM's allocator to free the previously allocated memory at [pointer]. +#define DEALLOCATE(vm, pointer) wrenReallocate(vm, pointer, 0, 0) + +// The Microsoft compiler does not support the "inline" modifier when compiling +// as plain C. +#if defined( _MSC_VER ) && !defined(__cplusplus) + #define inline _inline +#endif + +// This is used to clearly mark flexible-sized arrays that appear at the end of +// some dynamically-allocated structs, known as the "struct hack". +#if __STDC_VERSION__ >= 199901L + // In C99, a flexible array member is just "[]". + #define FLEXIBLE_ARRAY +#else + // Elsewhere, use a zero-sized array. It's technically undefined behavior, + // but works reliably in most known compilers. + #define FLEXIBLE_ARRAY 0 +#endif + +// Assertions are used to validate program invariants. They indicate things the +// program expects to be true about its internal state during execution. If an +// assertion fails, there is a bug in Wren. +// +// Assertions add significant overhead, so are only enabled in debug builds. +#ifdef DEBUG + + #include + + #define ASSERT(condition, message) \ + do \ + { \ + if (!(condition)) \ + { \ + fprintf(stderr, "[%s:%d] Assert failed in %s(): %s\n", \ + __FILE__, __LINE__, __func__, message); \ + abort(); \ + } \ + } while (false) + + // Indicates that we know execution should never reach this point in the + // program. In debug mode, we assert this fact because it's a bug to get here. + // + // In release mode, we use compiler-specific built in functions to tell the + // compiler the code can't be reached. This avoids "missing return" warnings + // in some cases and also lets it perform some optimizations by assuming the + // code is never reached. + #define UNREACHABLE() \ + do \ + { \ + fprintf(stderr, "[%s:%d] This code should not be reached in %s()\n", \ + __FILE__, __LINE__, __func__); \ + abort(); \ + } while (false) + +#else + + #define ASSERT(condition, message) do { } while (false) + + // Tell the compiler that this part of the code will never be reached. + #if defined( _MSC_VER ) + #define UNREACHABLE() __assume(0) + #elif (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5)) + #define UNREACHABLE() __builtin_unreachable() + #else + #define UNREACHABLE() + #endif + +#endif + +#endif +// End file "wren_common.h" +// Begin file "wren_math.h" +#ifndef wren_math_h +#define wren_math_h + +#include +#include + +// A union to let us reinterpret a double as raw bits and back. +typedef union +{ + uint64_t bits64; + uint32_t bits32[2]; + double num; +} WrenDoubleBits; + +#define WREN_DOUBLE_QNAN_POS_MIN_BITS (UINT64_C(0x7FF8000000000000)) +#define WREN_DOUBLE_QNAN_POS_MAX_BITS (UINT64_C(0x7FFFFFFFFFFFFFFF)) + +#define WREN_DOUBLE_NAN (wrenDoubleFromBits(WREN_DOUBLE_QNAN_POS_MIN_BITS)) + +static inline double wrenDoubleFromBits(uint64_t bits) +{ + WrenDoubleBits data; + data.bits64 = bits; + return data.num; +} + +static inline uint64_t wrenDoubleToBits(double num) +{ + WrenDoubleBits data; + data.num = num; + return data.bits64; +} + +#endif +// End file "wren_math.h" +// Begin file "wren_utils.h" +#ifndef wren_utils_h +#define wren_utils_h + + +// Reusable data structures and other utility functions. + +// Forward declare this here to break a cycle between wren_utils.h and +// wren_value.h. +typedef struct sObjString ObjString; + +// We need buffers of a few different types. To avoid lots of casting between +// void* and back, we'll use the preprocessor as a poor man's generics and let +// it generate a few type-specific ones. +#define DECLARE_BUFFER(name, type) \ + typedef struct \ + { \ + type* data; \ + int count; \ + int capacity; \ + } name##Buffer; \ + void wren##name##BufferInit(name##Buffer* buffer); \ + void wren##name##BufferClear(WrenVM* vm, name##Buffer* buffer); \ + void wren##name##BufferFill(WrenVM* vm, name##Buffer* buffer, type data, \ + int count); \ + void wren##name##BufferWrite(WrenVM* vm, name##Buffer* buffer, type data) + +// This should be used once for each type instantiation, somewhere in a .c file. +#define DEFINE_BUFFER(name, type) \ + void wren##name##BufferInit(name##Buffer* buffer) \ + { \ + buffer->data = NULL; \ + buffer->capacity = 0; \ + buffer->count = 0; \ + } \ + \ + void wren##name##BufferClear(WrenVM* vm, name##Buffer* buffer) \ + { \ + wrenReallocate(vm, buffer->data, 0, 0); \ + wren##name##BufferInit(buffer); \ + } \ + \ + void wren##name##BufferFill(WrenVM* vm, name##Buffer* buffer, type data, \ + int count) \ + { \ + if (buffer->capacity < buffer->count + count) \ + { \ + int capacity = wrenPowerOf2Ceil(buffer->count + count); \ + buffer->data = (type*)wrenReallocate(vm, buffer->data, \ + buffer->capacity * sizeof(type), capacity * sizeof(type)); \ + buffer->capacity = capacity; \ + } \ + \ + for (int i = 0; i < count; i++) \ + { \ + buffer->data[buffer->count++] = data; \ + } \ + } \ + \ + void wren##name##BufferWrite(WrenVM* vm, name##Buffer* buffer, type data) \ + { \ + wren##name##BufferFill(vm, buffer, data, 1); \ + } + +DECLARE_BUFFER(Byte, uint8_t); +DECLARE_BUFFER(Int, int); +DECLARE_BUFFER(String, ObjString*); + +// TODO: Change this to use a map. +typedef StringBuffer SymbolTable; + +// Initializes the symbol table. +void wrenSymbolTableInit(SymbolTable* symbols); + +// Frees all dynamically allocated memory used by the symbol table, but not the +// SymbolTable itself. +void wrenSymbolTableClear(WrenVM* vm, SymbolTable* symbols); + +// Adds name to the symbol table. Returns the index of it in the table. +int wrenSymbolTableAdd(WrenVM* vm, SymbolTable* symbols, + const char* name, size_t length); + +// Adds name to the symbol table. Returns the index of it in the table. Will +// use an existing symbol if already present. +int wrenSymbolTableEnsure(WrenVM* vm, SymbolTable* symbols, + const char* name, size_t length); + +// Looks up name in the symbol table. Returns its index if found or -1 if not. +int wrenSymbolTableFind(const SymbolTable* symbols, + const char* name, size_t length); + +void wrenBlackenSymbolTable(WrenVM* vm, SymbolTable* symbolTable); + +// Returns the number of bytes needed to encode [value] in UTF-8. +// +// Returns 0 if [value] is too large to encode. +int wrenUtf8EncodeNumBytes(int value); + +// Encodes value as a series of bytes in [bytes], which is assumed to be large +// enough to hold the encoded result. +// +// Returns the number of written bytes. +int wrenUtf8Encode(int value, uint8_t* bytes); + +// Decodes the UTF-8 sequence starting at [bytes] (which has max [length]), +// returning the code point. +// +// Returns -1 if the bytes are not a valid UTF-8 sequence. +int wrenUtf8Decode(const uint8_t* bytes, uint32_t length); + +// Returns the number of bytes in the UTF-8 sequence starting with [byte]. +// +// If the character at that index is not the beginning of a UTF-8 sequence, +// returns 0. +int wrenUtf8DecodeNumBytes(uint8_t byte); + +// Returns the smallest power of two that is equal to or greater than [n]. +int wrenPowerOf2Ceil(int n); + +// Validates that [value] is within `[0, count)`. Also allows +// negative indices which map backwards from the end. Returns the valid positive +// index value. If invalid, returns `UINT32_MAX`. +uint32_t wrenValidateIndex(uint32_t count, int64_t value); + +#endif +// End file "wren_utils.h" + +// This defines the built-in types and their core representations in memory. +// Since Wren is dynamically typed, any variable can hold a value of any type, +// and the type can change at runtime. Implementing this efficiently is +// critical for performance. +// +// The main type exposed by this is [Value]. A C variable of that type is a +// storage location that can hold any Wren value. The stack, module variables, +// and instance fields are all implemented in C as variables of type Value. +// +// The built-in types for booleans, numbers, and null are unboxed: their value +// is stored directly in the Value, and copying a Value copies the value. Other +// types--classes, instances of classes, functions, lists, and strings--are all +// reference types. They are stored on the heap and the Value just stores a +// pointer to it. Copying the Value copies a reference to the same object. The +// Wren implementation calls these "Obj", or objects, though to a user, all +// values are objects. +// +// There is also a special singleton value "undefined". It is used internally +// but never appears as a real value to a user. It has two uses: +// +// - It is used to identify module variables that have been implicitly declared +// by use in a forward reference but not yet explicitly declared. These only +// exist during compilation and do not appear at runtime. +// +// - It is used to represent unused map entries in an ObjMap. +// +// There are two supported Value representations. The main one uses a technique +// called "NaN tagging" (explained in detail below) to store a number, any of +// the value types, or a pointer, all inside one double-precision floating +// point number. A larger, slower, Value type that uses a struct to store these +// is also supported, and is useful for debugging the VM. +// +// The representation is controlled by the `WREN_NAN_TAGGING` define. If that's +// defined, Nan tagging is used. + +// These macros cast a Value to one of the specific object types. These do *not* +// perform any validation, so must only be used after the Value has been +// ensured to be the right type. +#define AS_CLASS(value) ((ObjClass*)AS_OBJ(value)) // ObjClass* +#define AS_CLOSURE(value) ((ObjClosure*)AS_OBJ(value)) // ObjClosure* +#define AS_FIBER(v) ((ObjFiber*)AS_OBJ(v)) // ObjFiber* +#define AS_FN(value) ((ObjFn*)AS_OBJ(value)) // ObjFn* +#define AS_FOREIGN(v) ((ObjForeign*)AS_OBJ(v)) // ObjForeign* +#define AS_INSTANCE(value) ((ObjInstance*)AS_OBJ(value)) // ObjInstance* +#define AS_LIST(value) ((ObjList*)AS_OBJ(value)) // ObjList* +#define AS_MAP(value) ((ObjMap*)AS_OBJ(value)) // ObjMap* +#define AS_MODULE(value) ((ObjModule*)AS_OBJ(value)) // ObjModule* +#define AS_NUM(value) (wrenValueToNum(value)) // double +#define AS_RANGE(v) ((ObjRange*)AS_OBJ(v)) // ObjRange* +#define AS_STRING(v) ((ObjString*)AS_OBJ(v)) // ObjString* +#define AS_CSTRING(v) (AS_STRING(v)->value) // const char* + +// These macros promote a primitive C value to a full Wren Value. There are +// more defined below that are specific to the Nan tagged or other +// representation. +#define BOOL_VAL(boolean) ((boolean) ? TRUE_VAL : FALSE_VAL) // boolean +#define NUM_VAL(num) (wrenNumToValue(num)) // double +#define OBJ_VAL(obj) (wrenObjectToValue((Obj*)(obj))) // Any Obj___* + +// These perform type tests on a Value, returning `true` if the Value is of the +// given type. +#define IS_BOOL(value) (wrenIsBool(value)) // Bool +#define IS_CLASS(value) (wrenIsObjType(value, OBJ_CLASS)) // ObjClass +#define IS_CLOSURE(value) (wrenIsObjType(value, OBJ_CLOSURE)) // ObjClosure +#define IS_FIBER(value) (wrenIsObjType(value, OBJ_FIBER)) // ObjFiber +#define IS_FN(value) (wrenIsObjType(value, OBJ_FN)) // ObjFn +#define IS_FOREIGN(value) (wrenIsObjType(value, OBJ_FOREIGN)) // ObjForeign +#define IS_INSTANCE(value) (wrenIsObjType(value, OBJ_INSTANCE)) // ObjInstance +#define IS_LIST(value) (wrenIsObjType(value, OBJ_LIST)) // ObjList +#define IS_MAP(value) (wrenIsObjType(value, OBJ_MAP)) // ObjMap +#define IS_RANGE(value) (wrenIsObjType(value, OBJ_RANGE)) // ObjRange +#define IS_STRING(value) (wrenIsObjType(value, OBJ_STRING)) // ObjString + +// Creates a new string object from [text], which should be a bare C string +// literal. This determines the length of the string automatically at compile +// time based on the size of the character array (-1 for the terminating '\0'). +#define CONST_STRING(vm, text) wrenNewStringLength((vm), (text), sizeof(text) - 1) + +// Identifies which specific type a heap-allocated object is. +typedef enum { + OBJ_CLASS, + OBJ_CLOSURE, + OBJ_FIBER, + OBJ_FN, + OBJ_FOREIGN, + OBJ_INSTANCE, + OBJ_LIST, + OBJ_MAP, + OBJ_MODULE, + OBJ_RANGE, + OBJ_STRING, + OBJ_UPVALUE +} ObjType; + +typedef struct sObjClass ObjClass; + +// Base struct for all heap-allocated objects. +typedef struct sObj Obj; +struct sObj +{ + ObjType type; + bool isDark; + + // The object's class. + ObjClass* classObj; + + // The next object in the linked list of all currently allocated objects. + struct sObj* next; +}; + +#if WREN_NAN_TAGGING + +typedef uint64_t Value; + +#else + +typedef enum +{ + VAL_FALSE, + VAL_NULL, + VAL_NUM, + VAL_TRUE, + VAL_UNDEFINED, + VAL_OBJ +} ValueType; + +typedef struct +{ + ValueType type; + union + { + double num; + Obj* obj; + } as; +} Value; + +#endif + +DECLARE_BUFFER(Value, Value); + +// A heap-allocated string object. +struct sObjString +{ + Obj obj; + + // Number of bytes in the string, not including the null terminator. + uint32_t length; + + // The hash value of the string's contents. + uint32_t hash; + + // Inline array of the string's bytes followed by a null terminator. + char value[FLEXIBLE_ARRAY]; +}; + +// The dynamically allocated data structure for a variable that has been used +// by a closure. Whenever a function accesses a variable declared in an +// enclosing function, it will get to it through this. +// +// An upvalue can be either "closed" or "open". An open upvalue points directly +// to a [Value] that is still stored on the fiber's stack because the local +// variable is still in scope in the function where it's declared. +// +// When that local variable goes out of scope, the upvalue pointing to it will +// be closed. When that happens, the value gets copied off the stack into the +// upvalue itself. That way, it can have a longer lifetime than the stack +// variable. +typedef struct sObjUpvalue +{ + // The object header. Note that upvalues have this because they are garbage + // collected, but they are not first class Wren objects. + Obj obj; + + // Pointer to the variable this upvalue is referencing. + Value* value; + + // If the upvalue is closed (i.e. the local variable it was pointing to has + // been popped off the stack) then the closed-over value will be hoisted out + // of the stack into here. [value] will then be changed to point to this. + Value closed; + + // Open upvalues are stored in a linked list by the fiber. This points to the + // next upvalue in that list. + struct sObjUpvalue* next; +} ObjUpvalue; + +// The type of a primitive function. +// +// Primitives are similar to foreign functions, but have more direct access to +// VM internals. It is passed the arguments in [args]. If it returns a value, +// it places it in `args[0]` and returns `true`. If it causes a runtime error +// or modifies the running fiber, it returns `false`. +typedef bool (*Primitive)(WrenVM* vm, Value* args); + +// TODO: See if it's actually a perf improvement to have this in a separate +// struct instead of in ObjFn. +// Stores debugging information for a function used for things like stack +// traces. +typedef struct +{ + // The name of the function. Heap allocated and owned by the FnDebug. + char* name; + + // An array of line numbers. There is one element in this array for each + // bytecode in the function's bytecode array. The value of that element is + // the line in the source code that generated that instruction. + IntBuffer sourceLines; +} FnDebug; + +// A loaded module and the top-level variables it defines. +// +// While this is an Obj and is managed by the GC, it never appears as a +// first-class object in Wren. +typedef struct +{ + Obj obj; + + // The currently defined top-level variables. + ValueBuffer variables; + + // Symbol table for the names of all module variables. Indexes here directly + // correspond to entries in [variables]. + SymbolTable variableNames; + + // The name of the module. + ObjString* name; +} ObjModule; + +// A function object. It wraps and owns the bytecode and other debug information +// for a callable chunk of code. +// +// Function objects are not passed around and invoked directly. Instead, they +// are always referenced by an [ObjClosure] which is the real first-class +// representation of a function. This isn't strictly necessary if they function +// has no upvalues, but lets the rest of the VM assume all called objects will +// be closures. +typedef struct +{ + Obj obj; + + ByteBuffer code; + ValueBuffer constants; + + // The module where this function was defined. + ObjModule* module; + + // The maximum number of stack slots this function may use. + int maxSlots; + + // The number of upvalues this function closes over. + int numUpvalues; + + // The number of parameters this function expects. Used to ensure that .call + // handles a mismatch between number of parameters and arguments. This will + // only be set for fns, and not ObjFns that represent methods or scripts. + int arity; + FnDebug* debug; +} ObjFn; + +// An instance of a first-class function and the environment it has closed over. +// Unlike [ObjFn], this has captured the upvalues that the function accesses. +typedef struct +{ + Obj obj; + + // The function that this closure is an instance of. + ObjFn* fn; + + // The upvalues this function has closed over. + ObjUpvalue* upvalues[FLEXIBLE_ARRAY]; +} ObjClosure; + +typedef struct +{ + // Pointer to the current (really next-to-be-executed) instruction in the + // function's bytecode. + uint8_t* ip; + + // The closure being executed. + ObjClosure* closure; + + // Pointer to the first stack slot used by this call frame. This will contain + // the receiver, followed by the function's parameters, then local variables + // and temporaries. + Value* stackStart; +} CallFrame; + +// Tracks how this fiber has been invoked, aside from the ways that can be +// detected from the state of other fields in the fiber. +typedef enum +{ + // The fiber is being run from another fiber using a call to `try()`. + FIBER_TRY, + + // The fiber was directly invoked by `runInterpreter()`. This means it's the + // initial fiber used by a call to `wrenCall()` or `wrenInterpret()`. + FIBER_ROOT, + + // The fiber is invoked some other way. If [caller] is `NULL` then the fiber + // was invoked using `call()`. If [numFrames] is zero, then the fiber has + // finished running and is done. If [numFrames] is one and that frame's `ip` + // points to the first byte of code, the fiber has not been started yet. + FIBER_OTHER, +} FiberState; + +typedef struct sObjFiber +{ + Obj obj; + + // The stack of value slots. This is used for holding local variables and + // temporaries while the fiber is executing. It is heap-allocated and grown + // as needed. + Value* stack; + + // A pointer to one past the top-most value on the stack. + Value* stackTop; + + // The number of allocated slots in the stack array. + int stackCapacity; + + // The stack of call frames. This is a dynamic array that grows as needed but + // never shrinks. + CallFrame* frames; + + // The number of frames currently in use in [frames]. + int numFrames; + + // The number of [frames] allocated. + int frameCapacity; + + // Pointer to the first node in the linked list of open upvalues that are + // pointing to values still on the stack. The head of the list will be the + // upvalue closest to the top of the stack, and then the list works downwards. + ObjUpvalue* openUpvalues; + + // The fiber that ran this one. If this fiber is yielded, control will resume + // to this one. May be `NULL`. + struct sObjFiber* caller; + + // If the fiber failed because of a runtime error, this will contain the + // error object. Otherwise, it will be null. + Value error; + + FiberState state; +} ObjFiber; + +typedef enum +{ + // A primitive method implemented in C in the VM. Unlike foreign methods, + // this can directly manipulate the fiber's stack. + METHOD_PRIMITIVE, + + // A primitive that handles .call on Fn. + METHOD_FUNCTION_CALL, + + // A externally-defined C method. + METHOD_FOREIGN, + + // A normal user-defined method. + METHOD_BLOCK, + + // No method for the given symbol. + METHOD_NONE +} MethodType; + +typedef struct +{ + MethodType type; + + // The method function itself. The [type] determines which field of the union + // is used. + union + { + Primitive primitive; + WrenForeignMethodFn foreign; + ObjClosure* closure; + } as; +} Method; + +DECLARE_BUFFER(Method, Method); + +struct sObjClass +{ + Obj obj; + ObjClass* superclass; + + // The number of fields needed for an instance of this class, including all + // of its superclass fields. + int numFields; + + // The table of methods that are defined in or inherited by this class. + // Methods are called by symbol, and the symbol directly maps to an index in + // this table. This makes method calls fast at the expense of empty cells in + // the list for methods the class doesn't support. + // + // You can think of it as a hash table that never has collisions but has a + // really low load factor. Since methods are pretty small (just a type and a + // pointer), this should be a worthwhile trade-off. + MethodBuffer methods; + + // The name of the class. + ObjString* name; + + // The ClassAttribute for the class, if any + Value attributes; +}; + +typedef struct +{ + Obj obj; + uint8_t data[FLEXIBLE_ARRAY]; +} ObjForeign; + +typedef struct +{ + Obj obj; + Value fields[FLEXIBLE_ARRAY]; +} ObjInstance; + +typedef struct +{ + Obj obj; + + // The elements in the list. + ValueBuffer elements; +} ObjList; + +typedef struct +{ + // The entry's key, or UNDEFINED_VAL if the entry is not in use. + Value key; + + // The value associated with the key. If the key is UNDEFINED_VAL, this will + // be false to indicate an open available entry or true to indicate a + // tombstone -- an entry that was previously in use but was then deleted. + Value value; +} MapEntry; + +// A hash table mapping keys to values. +// +// We use something very simple: open addressing with linear probing. The hash +// table is an array of entries. Each entry is a key-value pair. If the key is +// the special UNDEFINED_VAL, it indicates no value is currently in that slot. +// Otherwise, it's a valid key, and the value is the value associated with it. +// +// When entries are added, the array is dynamically scaled by GROW_FACTOR to +// keep the number of filled slots under MAP_LOAD_PERCENT. Likewise, if the map +// gets empty enough, it will be resized to a smaller array. When this happens, +// all existing entries are rehashed and re-added to the new array. +// +// When an entry is removed, its slot is replaced with a "tombstone". This is an +// entry whose key is UNDEFINED_VAL and whose value is TRUE_VAL. When probing +// for a key, we will continue past tombstones, because the desired key may be +// found after them if the key that was removed was part of a prior collision. +// When the array gets resized, all tombstones are discarded. +typedef struct +{ + Obj obj; + + // The number of entries allocated. + uint32_t capacity; + + // The number of entries in the map. + uint32_t count; + + // Pointer to a contiguous array of [capacity] entries. + MapEntry* entries; +} ObjMap; + +typedef struct +{ + Obj obj; + + // The beginning of the range. + double from; + + // The end of the range. May be greater or less than [from]. + double to; + + // True if [to] is included in the range. + bool isInclusive; +} ObjRange; + +// An IEEE 754 double-precision float is a 64-bit value with bits laid out like: +// +// 1 Sign bit +// | 11 Exponent bits +// | | 52 Mantissa (i.e. fraction) bits +// | | | +// S[Exponent-][Mantissa------------------------------------------] +// +// The details of how these are used to represent numbers aren't really +// relevant here as long we don't interfere with them. The important bit is NaN. +// +// An IEEE double can represent a few magical values like NaN ("not a number"), +// Infinity, and -Infinity. A NaN is any value where all exponent bits are set: +// +// v--NaN bits +// -11111111111---------------------------------------------------- +// +// Here, "-" means "doesn't matter". Any bit sequence that matches the above is +// a NaN. With all of those "-", it obvious there are a *lot* of different +// bit patterns that all mean the same thing. NaN tagging takes advantage of +// this. We'll use those available bit patterns to represent things other than +// numbers without giving up any valid numeric values. +// +// NaN values come in two flavors: "signalling" and "quiet". The former are +// intended to halt execution, while the latter just flow through arithmetic +// operations silently. We want the latter. Quiet NaNs are indicated by setting +// the highest mantissa bit: +// +// v--Highest mantissa bit +// -[NaN ]1--------------------------------------------------- +// +// If all of the NaN bits are set, it's not a number. Otherwise, it is. +// That leaves all of the remaining bits as available for us to play with. We +// stuff a few different kinds of things here: special singleton values like +// "true", "false", and "null", and pointers to objects allocated on the heap. +// We'll use the sign bit to distinguish singleton values from pointers. If +// it's set, it's a pointer. +// +// v--Pointer or singleton? +// S[NaN ]1--------------------------------------------------- +// +// For singleton values, we just enumerate the different values. We'll use the +// low bits of the mantissa for that, and only need a few: +// +// 3 Type bits--v +// 0[NaN ]1------------------------------------------------[T] +// +// For pointers, we are left with 51 bits of mantissa to store an address. +// That's more than enough room for a 32-bit address. Even 64-bit machines +// only actually use 48 bits for addresses, so we've got plenty. We just stuff +// the address right into the mantissa. +// +// Ta-da, double precision numbers, pointers, and a bunch of singleton values, +// all stuffed into a single 64-bit sequence. Even better, we don't have to +// do any masking or work to extract number values: they are unmodified. This +// means math on numbers is fast. +#if WREN_NAN_TAGGING + +// A mask that selects the sign bit. +#define SIGN_BIT ((uint64_t)1 << 63) + +// The bits that must be set to indicate a quiet NaN. +#define QNAN ((uint64_t)0x7ffc000000000000) + +// If the NaN bits are set, it's not a number. +#define IS_NUM(value) (((value) & QNAN) != QNAN) + +// An object pointer is a NaN with a set sign bit. +#define IS_OBJ(value) (((value) & (QNAN | SIGN_BIT)) == (QNAN | SIGN_BIT)) + +#define IS_FALSE(value) ((value) == FALSE_VAL) +#define IS_NULL(value) ((value) == NULL_VAL) +#define IS_UNDEFINED(value) ((value) == UNDEFINED_VAL) + +// Masks out the tag bits used to identify the singleton value. +#define MASK_TAG (7) + +// Tag values for the different singleton values. +#define TAG_NAN (0) +#define TAG_NULL (1) +#define TAG_FALSE (2) +#define TAG_TRUE (3) +#define TAG_UNDEFINED (4) +#define TAG_UNUSED2 (5) +#define TAG_UNUSED3 (6) +#define TAG_UNUSED4 (7) + +// Value -> 0 or 1. +#define AS_BOOL(value) ((value) == TRUE_VAL) + +// Value -> Obj*. +#define AS_OBJ(value) ((Obj*)(uintptr_t)((value) & ~(SIGN_BIT | QNAN))) + +// Singleton values. +#define NULL_VAL ((Value)(uint64_t)(QNAN | TAG_NULL)) +#define FALSE_VAL ((Value)(uint64_t)(QNAN | TAG_FALSE)) +#define TRUE_VAL ((Value)(uint64_t)(QNAN | TAG_TRUE)) +#define UNDEFINED_VAL ((Value)(uint64_t)(QNAN | TAG_UNDEFINED)) + +// Gets the singleton type tag for a Value (which must be a singleton). +#define GET_TAG(value) ((int)((value) & MASK_TAG)) + +#else + +// Value -> 0 or 1. +#define AS_BOOL(value) ((value).type == VAL_TRUE) + +// Value -> Obj*. +#define AS_OBJ(v) ((v).as.obj) + +// Determines if [value] is a garbage-collected object or not. +#define IS_OBJ(value) ((value).type == VAL_OBJ) + +#define IS_FALSE(value) ((value).type == VAL_FALSE) +#define IS_NULL(value) ((value).type == VAL_NULL) +#define IS_NUM(value) ((value).type == VAL_NUM) +#define IS_UNDEFINED(value) ((value).type == VAL_UNDEFINED) + +// Singleton values. +#define FALSE_VAL ((Value){ VAL_FALSE, { 0 } }) +#define NULL_VAL ((Value){ VAL_NULL, { 0 } }) +#define TRUE_VAL ((Value){ VAL_TRUE, { 0 } }) +#define UNDEFINED_VAL ((Value){ VAL_UNDEFINED, { 0 } }) + +#endif + +// Creates a new "raw" class. It has no metaclass or superclass whatsoever. +// This is only used for bootstrapping the initial Object and Class classes, +// which are a little special. +ObjClass* wrenNewSingleClass(WrenVM* vm, int numFields, ObjString* name); + +// Makes [superclass] the superclass of [subclass], and causes subclass to +// inherit its methods. This should be called before any methods are defined +// on subclass. +void wrenBindSuperclass(WrenVM* vm, ObjClass* subclass, ObjClass* superclass); + +// Creates a new class object as well as its associated metaclass. +ObjClass* wrenNewClass(WrenVM* vm, ObjClass* superclass, int numFields, + ObjString* name); + +void wrenBindMethod(WrenVM* vm, ObjClass* classObj, int symbol, Method method); + +// Creates a new closure object that invokes [fn]. Allocates room for its +// upvalues, but assumes outside code will populate it. +ObjClosure* wrenNewClosure(WrenVM* vm, ObjFn* fn); + +// Creates a new fiber object that will invoke [closure]. +ObjFiber* wrenNewFiber(WrenVM* vm, ObjClosure* closure); + +// Adds a new [CallFrame] to [fiber] invoking [closure] whose stack starts at +// [stackStart]. +static inline void wrenAppendCallFrame(WrenVM* vm, ObjFiber* fiber, + ObjClosure* closure, Value* stackStart) +{ + // The caller should have ensured we already have enough capacity. + ASSERT(fiber->frameCapacity > fiber->numFrames, "No memory for call frame."); + + CallFrame* frame = &fiber->frames[fiber->numFrames++]; + frame->stackStart = stackStart; + frame->closure = closure; + frame->ip = closure->fn->code.data; +} + +// Ensures [fiber]'s stack has at least [needed] slots. +void wrenEnsureStack(WrenVM* vm, ObjFiber* fiber, int needed); + +static inline bool wrenHasError(const ObjFiber* fiber) +{ + return !IS_NULL(fiber->error); +} + +ObjForeign* wrenNewForeign(WrenVM* vm, ObjClass* classObj, size_t size); + +// Creates a new empty function. Before being used, it must have code, +// constants, etc. added to it. +ObjFn* wrenNewFunction(WrenVM* vm, ObjModule* module, int maxSlots); + +void wrenFunctionBindName(WrenVM* vm, ObjFn* fn, const char* name, int length); + +// Creates a new instance of the given [classObj]. +Value wrenNewInstance(WrenVM* vm, ObjClass* classObj); + +// Creates a new list with [numElements] elements (which are left +// uninitialized.) +ObjList* wrenNewList(WrenVM* vm, uint32_t numElements); + +// Inserts [value] in [list] at [index], shifting down the other elements. +void wrenListInsert(WrenVM* vm, ObjList* list, Value value, uint32_t index); + +// Removes and returns the item at [index] from [list]. +Value wrenListRemoveAt(WrenVM* vm, ObjList* list, uint32_t index); + +// Searches for [value] in [list], returns the index or -1 if not found. +int wrenListIndexOf(WrenVM* vm, ObjList* list, Value value); + +// Creates a new empty map. +ObjMap* wrenNewMap(WrenVM* vm); + +// Validates that [arg] is a valid object for use as a map key. Returns true if +// it is and returns false otherwise. Use validateKey usually, for a runtime error. +// This separation exists to aid the API in surfacing errors to the developer as well. +static inline bool wrenMapIsValidKey(Value arg); + +// Looks up [key] in [map]. If found, returns the value. Otherwise, returns +// `UNDEFINED_VAL`. +Value wrenMapGet(ObjMap* map, Value key); + +// Associates [key] with [value] in [map]. +void wrenMapSet(WrenVM* vm, ObjMap* map, Value key, Value value); + +void wrenMapClear(WrenVM* vm, ObjMap* map); + +// Removes [key] from [map], if present. Returns the value for the key if found +// or `NULL_VAL` otherwise. +Value wrenMapRemoveKey(WrenVM* vm, ObjMap* map, Value key); + +// Creates a new module. +ObjModule* wrenNewModule(WrenVM* vm, ObjString* name); + +// Creates a new range from [from] to [to]. +Value wrenNewRange(WrenVM* vm, double from, double to, bool isInclusive); + +// Creates a new string object and copies [text] into it. +// +// [text] must be non-NULL. +Value wrenNewString(WrenVM* vm, const char* text); + +// Creates a new string object of [length] and copies [text] into it. +// +// [text] may be NULL if [length] is zero. +Value wrenNewStringLength(WrenVM* vm, const char* text, size_t length); + +// Creates a new string object by taking a range of characters from [source]. +// The range starts at [start], contains [count] bytes, and increments by +// [step]. +Value wrenNewStringFromRange(WrenVM* vm, ObjString* source, int start, + uint32_t count, int step); + +// Produces a string representation of [value]. +Value wrenNumToString(WrenVM* vm, double value); + +// Creates a new formatted string from [format] and any additional arguments +// used in the format string. +// +// This is a very restricted flavor of formatting, intended only for internal +// use by the VM. Two formatting characters are supported, each of which reads +// the next argument as a certain type: +// +// $ - A C string. +// @ - A Wren string object. +Value wrenStringFormat(WrenVM* vm, const char* format, ...); + +// Creates a new string containing the UTF-8 encoding of [value]. +Value wrenStringFromCodePoint(WrenVM* vm, int value); + +// Creates a new string from the integer representation of a byte +Value wrenStringFromByte(WrenVM* vm, uint8_t value); + +// Creates a new string containing the code point in [string] starting at byte +// [index]. If [index] points into the middle of a UTF-8 sequence, returns an +// empty string. +Value wrenStringCodePointAt(WrenVM* vm, ObjString* string, uint32_t index); + +// Search for the first occurence of [needle] within [haystack] and returns its +// zero-based offset. Returns `UINT32_MAX` if [haystack] does not contain +// [needle]. +uint32_t wrenStringFind(ObjString* haystack, ObjString* needle, + uint32_t startIndex); + +// Returns true if [a] and [b] represent the same string. +static inline bool wrenStringEqualsCString(const ObjString* a, + const char* b, size_t length) +{ + return a->length == length && memcmp(a->value, b, length) == 0; +} + +// Creates a new open upvalue pointing to [value] on the stack. +ObjUpvalue* wrenNewUpvalue(WrenVM* vm, Value* value); + +// Mark [obj] as reachable and still in use. This should only be called +// during the sweep phase of a garbage collection. +void wrenGrayObj(WrenVM* vm, Obj* obj); + +// Mark [value] as reachable and still in use. This should only be called +// during the sweep phase of a garbage collection. +void wrenGrayValue(WrenVM* vm, Value value); + +// Mark the values in [buffer] as reachable and still in use. This should only +// be called during the sweep phase of a garbage collection. +void wrenGrayBuffer(WrenVM* vm, ValueBuffer* buffer); + +// Processes every object in the gray stack until all reachable objects have +// been marked. After that, all objects are either white (freeable) or black +// (in use and fully traversed). +void wrenBlackenObjects(WrenVM* vm); + +// Releases all memory owned by [obj], including [obj] itself. +void wrenFreeObj(WrenVM* vm, Obj* obj); + +// Returns the class of [value]. +// +// Unlike wrenGetClassInline in wren_vm.h, this is not inlined. Inlining helps +// performance (significantly) in some cases, but degrades it in others. The +// ones used by the implementation were chosen to give the best results in the +// benchmarks. +ObjClass* wrenGetClass(WrenVM* vm, Value value); + +// Returns true if [a] and [b] are strictly the same value. This is identity +// for object values, and value equality for unboxed values. +static inline bool wrenValuesSame(Value a, Value b) +{ +#if WREN_NAN_TAGGING + // Value types have unique bit representations and we compare object types + // by identity (i.e. pointer), so all we need to do is compare the bits. + return a == b; +#else + if (a.type != b.type) return false; + if (a.type == VAL_NUM) return a.as.num == b.as.num; + return a.as.obj == b.as.obj; +#endif +} + +// Returns true if [a] and [b] are equivalent. Immutable values (null, bools, +// numbers, ranges, and strings) are equal if they have the same data. All +// other values are equal if they are identical objects. +bool wrenValuesEqual(Value a, Value b); + +// Returns true if [value] is a bool. Do not call this directly, instead use +// [IS_BOOL]. +static inline bool wrenIsBool(Value value) +{ +#if WREN_NAN_TAGGING + return value == TRUE_VAL || value == FALSE_VAL; +#else + return value.type == VAL_FALSE || value.type == VAL_TRUE; +#endif +} + +// Returns true if [value] is an object of type [type]. Do not call this +// directly, instead use the [IS___] macro for the type in question. +static inline bool wrenIsObjType(Value value, ObjType type) +{ + return IS_OBJ(value) && AS_OBJ(value)->type == type; +} + +// Converts the raw object pointer [obj] to a [Value]. +static inline Value wrenObjectToValue(Obj* obj) +{ +#if WREN_NAN_TAGGING + // The triple casting is necessary here to satisfy some compilers: + // 1. (uintptr_t) Convert the pointer to a number of the right size. + // 2. (uint64_t) Pad it up to 64 bits in 32-bit builds. + // 3. Or in the bits to make a tagged Nan. + // 4. Cast to a typedef'd value. + return (Value)(SIGN_BIT | QNAN | (uint64_t)(uintptr_t)(obj)); +#else + Value value; + value.type = VAL_OBJ; + value.as.obj = obj; + return value; +#endif +} + +// Interprets [value] as a [double]. +static inline double wrenValueToNum(Value value) +{ +#if WREN_NAN_TAGGING + return wrenDoubleFromBits(value); +#else + return value.as.num; +#endif +} + +// Converts [num] to a [Value]. +static inline Value wrenNumToValue(double num) +{ +#if WREN_NAN_TAGGING + return wrenDoubleToBits(num); +#else + Value value; + value.type = VAL_NUM; + value.as.num = num; + return value; +#endif +} + +static inline bool wrenMapIsValidKey(Value arg) +{ + return IS_BOOL(arg) + || IS_CLASS(arg) + || IS_NULL(arg) + || IS_NUM(arg) + || IS_RANGE(arg) + || IS_STRING(arg); +} + +#endif +// End file "wren_value.h" +// Begin file "wren_vm.h" +#ifndef wren_vm_h +#define wren_vm_h + +// Begin file "wren_compiler.h" +#ifndef wren_compiler_h +#define wren_compiler_h + + +typedef struct sCompiler Compiler; + +// This module defines the compiler for Wren. It takes a string of source code +// and lexes, parses, and compiles it. Wren uses a single-pass compiler. It +// does not build an actual AST during parsing and then consume that to +// generate code. Instead, the parser directly emits bytecode. +// +// This forces a few restrictions on the grammar and semantics of the language. +// Things like forward references and arbitrary lookahead are much harder. We +// get a lot in return for that, though. +// +// The implementation is much simpler since we don't need to define a bunch of +// AST data structures. More so, we don't have to deal with managing memory for +// AST objects. The compiler does almost no dynamic allocation while running. +// +// Compilation is also faster since we don't create a bunch of temporary data +// structures and destroy them after generating code. + +// Compiles [source], a string of Wren source code located in [module], to an +// [ObjFn] that will execute that code when invoked. Returns `NULL` if the +// source contains any syntax errors. +// +// If [isExpression] is `true`, [source] should be a single expression, and +// this compiles it to a function that evaluates and returns that expression. +// Otherwise, [source] should be a series of top level statements. +// +// If [printErrors] is `true`, any compile errors are output to stderr. +// Otherwise, they are silently discarded. +ObjFn* wrenCompile(WrenVM* vm, ObjModule* module, const char* source, + bool isExpression, bool printErrors); + +// When a class is defined, its superclass is not known until runtime since +// class definitions are just imperative statements. Most of the bytecode for a +// a method doesn't care, but there are two places where it matters: +// +// - To load or store a field, we need to know the index of the field in the +// instance's field array. We need to adjust this so that subclass fields +// are positioned after superclass fields, and we don't know this until the +// superclass is known. +// +// - Superclass calls need to know which superclass to dispatch to. +// +// We could handle this dynamically, but that adds overhead. Instead, when a +// method is bound, we walk the bytecode for the function and patch it up. +void wrenBindMethodCode(ObjClass* classObj, ObjFn* fn); + +// Reaches all of the heap-allocated objects in use by [compiler] (and all of +// its parents) so that they are not collected by the GC. +void wrenMarkCompiler(WrenVM* vm, Compiler* compiler); + +#endif +// End file "wren_compiler.h" + +// The maximum number of temporary objects that can be made visible to the GC +// at one time. +#define WREN_MAX_TEMP_ROOTS 8 + +typedef enum +{ + #define OPCODE(name, _) CODE_##name, +// Begin file "wren_opcodes.h" +// This defines the bytecode instructions used by the VM. It does so by invoking +// an OPCODE() macro which is expected to be defined at the point that this is +// included. (See: http://en.wikipedia.org/wiki/X_Macro for more.) +// +// The first argument is the name of the opcode. The second is its "stack +// effect" -- the amount that the op code changes the size of the stack. A +// stack effect of 1 means it pushes a value and the stack grows one larger. +// -2 means it pops two values, etc. +// +// Note that the order of instructions here affects the order of the dispatch +// table in the VM's interpreter loop. That in turn affects caching which +// affects overall performance. Take care to run benchmarks if you change the +// order here. + +// Load the constant at index [arg]. +OPCODE(CONSTANT, 1) + +// Push null onto the stack. +OPCODE(NULL, 1) + +// Push false onto the stack. +OPCODE(FALSE, 1) + +// Push true onto the stack. +OPCODE(TRUE, 1) + +// Pushes the value in the given local slot. +OPCODE(LOAD_LOCAL_0, 1) +OPCODE(LOAD_LOCAL_1, 1) +OPCODE(LOAD_LOCAL_2, 1) +OPCODE(LOAD_LOCAL_3, 1) +OPCODE(LOAD_LOCAL_4, 1) +OPCODE(LOAD_LOCAL_5, 1) +OPCODE(LOAD_LOCAL_6, 1) +OPCODE(LOAD_LOCAL_7, 1) +OPCODE(LOAD_LOCAL_8, 1) + +// Note: The compiler assumes the following _STORE instructions always +// immediately follow their corresponding _LOAD ones. + +// Pushes the value in local slot [arg]. +OPCODE(LOAD_LOCAL, 1) + +// Stores the top of stack in local slot [arg]. Does not pop it. +OPCODE(STORE_LOCAL, 0) + +// Pushes the value in upvalue [arg]. +OPCODE(LOAD_UPVALUE, 1) + +// Stores the top of stack in upvalue [arg]. Does not pop it. +OPCODE(STORE_UPVALUE, 0) + +// Pushes the value of the top-level variable in slot [arg]. +OPCODE(LOAD_MODULE_VAR, 1) + +// Stores the top of stack in top-level variable slot [arg]. Does not pop it. +OPCODE(STORE_MODULE_VAR, 0) + +// Pushes the value of the field in slot [arg] of the receiver of the current +// function. This is used for regular field accesses on "this" directly in +// methods. This instruction is faster than the more general CODE_LOAD_FIELD +// instruction. +OPCODE(LOAD_FIELD_THIS, 1) + +// Stores the top of the stack in field slot [arg] in the receiver of the +// current value. Does not pop the value. This instruction is faster than the +// more general CODE_LOAD_FIELD instruction. +OPCODE(STORE_FIELD_THIS, 0) + +// Pops an instance and pushes the value of the field in slot [arg] of it. +OPCODE(LOAD_FIELD, 0) + +// Pops an instance and stores the subsequent top of stack in field slot +// [arg] in it. Does not pop the value. +OPCODE(STORE_FIELD, -1) + +// Pop and discard the top of stack. +OPCODE(POP, -1) + +// Invoke the method with symbol [arg]. The number indicates the number of +// arguments (not including the receiver). +OPCODE(CALL_0, 0) +OPCODE(CALL_1, -1) +OPCODE(CALL_2, -2) +OPCODE(CALL_3, -3) +OPCODE(CALL_4, -4) +OPCODE(CALL_5, -5) +OPCODE(CALL_6, -6) +OPCODE(CALL_7, -7) +OPCODE(CALL_8, -8) +OPCODE(CALL_9, -9) +OPCODE(CALL_10, -10) +OPCODE(CALL_11, -11) +OPCODE(CALL_12, -12) +OPCODE(CALL_13, -13) +OPCODE(CALL_14, -14) +OPCODE(CALL_15, -15) +OPCODE(CALL_16, -16) + +// Invoke a superclass method with symbol [arg]. The number indicates the +// number of arguments (not including the receiver). +OPCODE(SUPER_0, 0) +OPCODE(SUPER_1, -1) +OPCODE(SUPER_2, -2) +OPCODE(SUPER_3, -3) +OPCODE(SUPER_4, -4) +OPCODE(SUPER_5, -5) +OPCODE(SUPER_6, -6) +OPCODE(SUPER_7, -7) +OPCODE(SUPER_8, -8) +OPCODE(SUPER_9, -9) +OPCODE(SUPER_10, -10) +OPCODE(SUPER_11, -11) +OPCODE(SUPER_12, -12) +OPCODE(SUPER_13, -13) +OPCODE(SUPER_14, -14) +OPCODE(SUPER_15, -15) +OPCODE(SUPER_16, -16) + +// Jump the instruction pointer [arg] forward. +OPCODE(JUMP, 0) + +// Jump the instruction pointer [arg] backward. +OPCODE(LOOP, 0) + +// Pop and if not truthy then jump the instruction pointer [arg] forward. +OPCODE(JUMP_IF, -1) + +// If the top of the stack is false, jump [arg] forward. Otherwise, pop and +// continue. +OPCODE(AND, -1) + +// If the top of the stack is non-false, jump [arg] forward. Otherwise, pop +// and continue. +OPCODE(OR, -1) + +// Close the upvalue for the local on the top of the stack, then pop it. +OPCODE(CLOSE_UPVALUE, -1) + +// Exit from the current function and return the value on the top of the +// stack. +OPCODE(RETURN, 0) + +// Creates a closure for the function stored at [arg] in the constant table. +// +// Following the function argument is a number of arguments, two for each +// upvalue. The first is true if the variable being captured is a local (as +// opposed to an upvalue), and the second is the index of the local or +// upvalue being captured. +// +// Pushes the created closure. +OPCODE(CLOSURE, 1) + +// Creates a new instance of a class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(CONSTRUCT, 0) + +// Creates a new instance of a foreign class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(FOREIGN_CONSTRUCT, 0) + +// Creates a class. Top of stack is the superclass. Below that is a string for +// the name of the class. Byte [arg] is the number of fields in the class. +OPCODE(CLASS, -1) + +// Ends a class. +// Atm the stack contains the class and the ClassAttributes (or null). +OPCODE(END_CLASS, -2) + +// Creates a foreign class. Top of stack is the superclass. Below that is a +// string for the name of the class. +OPCODE(FOREIGN_CLASS, -1) + +// Define a method for symbol [arg]. The class receiving the method is popped +// off the stack, then the function defining the body is popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_INSTANCE, -2) + +// Define a method for symbol [arg]. The class whose metaclass will receive +// the method is popped off the stack, then the function defining the body is +// popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_STATIC, -2) + +// This is executed at the end of the module's body. Pushes NULL onto the stack +// as the "return value" of the import statement and stores the module as the +// most recently imported one. +OPCODE(END_MODULE, 1) + +// Import a module whose name is the string stored at [arg] in the constant +// table. +// +// Pushes null onto the stack so that the fiber for the imported module can +// replace that with a dummy value when it returns. (Fibers always return a +// value when resuming a caller.) +OPCODE(IMPORT_MODULE, 1) + +// Import a variable from the most recently imported module. The name of the +// variable to import is at [arg] in the constant table. Pushes the loaded +// variable's value. +OPCODE(IMPORT_VARIABLE, 1) + +// This pseudo-instruction indicates the end of the bytecode. It should +// always be preceded by a `CODE_RETURN`, so is never actually executed. +OPCODE(END, 0) +// End file "wren_opcodes.h" + #undef OPCODE +} Code; + +// A handle to a value, basically just a linked list of extra GC roots. +// +// Note that even non-heap-allocated values can be stored here. +struct WrenHandle +{ + Value value; + + WrenHandle* prev; + WrenHandle* next; +}; + +struct WrenVM +{ + ObjClass* boolClass; + ObjClass* classClass; + ObjClass* fiberClass; + ObjClass* fnClass; + ObjClass* listClass; + ObjClass* mapClass; + ObjClass* nullClass; + ObjClass* numClass; + ObjClass* objectClass; + ObjClass* rangeClass; + ObjClass* stringClass; + + // The fiber that is currently running. + ObjFiber* fiber; + + // The loaded modules. Each key is an ObjString (except for the main module, + // whose key is null) for the module's name and the value is the ObjModule + // for the module. + ObjMap* modules; + + // The most recently imported module. More specifically, the module whose + // code has most recently finished executing. + // + // Not treated like a GC root since the module is already in [modules]. + ObjModule* lastModule; + + // Memory management data: + + // The number of bytes that are known to be currently allocated. Includes all + // memory that was proven live after the last GC, as well as any new bytes + // that were allocated since then. Does *not* include bytes for objects that + // were freed since the last GC. + size_t bytesAllocated; + + // The number of total allocated bytes that will trigger the next GC. + size_t nextGC; + + // The first object in the linked list of all currently allocated objects. + Obj* first; + + // The "gray" set for the garbage collector. This is the stack of unprocessed + // objects while a garbage collection pass is in process. + Obj** gray; + int grayCount; + int grayCapacity; + + // The list of temporary roots. This is for temporary or new objects that are + // not otherwise reachable but should not be collected. + // + // They are organized as a stack of pointers stored in this array. This + // implies that temporary roots need to have stack semantics: only the most + // recently pushed object can be released. + Obj* tempRoots[WREN_MAX_TEMP_ROOTS]; + + int numTempRoots; + + // Pointer to the first node in the linked list of active handles or NULL if + // there are none. + WrenHandle* handles; + + // Pointer to the bottom of the range of stack slots available for use from + // the C API. During a foreign method, this will be in the stack of the fiber + // that is executing a method. + // + // If not in a foreign method, this is initially NULL. If the user requests + // slots by calling wrenEnsureSlots(), a stack is created and this is + // initialized. + Value* apiStack; + + WrenConfiguration config; + + // Compiler and debugger data: + + // The compiler that is currently compiling code. This is used so that heap + // allocated objects used by the compiler can be found if a GC is kicked off + // in the middle of a compile. + Compiler* compiler; + + // There is a single global symbol table for all method names on all classes. + // Method calls are dispatched directly by index in this table. + SymbolTable methodNames; +}; + +// A generic allocation function that handles all explicit memory management. +// It's used like so: +// +// - To allocate new memory, [memory] is NULL and [oldSize] is zero. It should +// return the allocated memory or NULL on failure. +// +// - To attempt to grow an existing allocation, [memory] is the memory, +// [oldSize] is its previous size, and [newSize] is the desired size. +// It should return [memory] if it was able to grow it in place, or a new +// pointer if it had to move it. +// +// - To shrink memory, [memory], [oldSize], and [newSize] are the same as above +// but it will always return [memory]. +// +// - To free memory, [memory] will be the memory to free and [newSize] and +// [oldSize] will be zero. It should return NULL. +void* wrenReallocate(WrenVM* vm, void* memory, size_t oldSize, size_t newSize); + +// Invoke the finalizer for the foreign object referenced by [foreign]. +void wrenFinalizeForeign(WrenVM* vm, ObjForeign* foreign); + +// Creates a new [WrenHandle] for [value]. +WrenHandle* wrenMakeHandle(WrenVM* vm, Value value); + +// Compile [source] in the context of [module] and wrap in a fiber that can +// execute it. +// +// Returns NULL if a compile error occurred. +ObjClosure* wrenCompileSource(WrenVM* vm, const char* module, + const char* source, bool isExpression, + bool printErrors); + +// Looks up a variable from a previously-loaded module. +// +// Aborts the current fiber if the module or variable could not be found. +Value wrenGetModuleVariable(WrenVM* vm, Value moduleName, Value variableName); + +// Returns the value of the module-level variable named [name] in the main +// module. +Value wrenFindVariable(WrenVM* vm, ObjModule* module, const char* name); + +// Adds a new implicitly declared top-level variable named [name] to [module] +// based on a use site occurring on [line]. +// +// Does not check to see if a variable with that name is already declared or +// defined. Returns the symbol for the new variable or -2 if there are too many +// variables defined. +int wrenDeclareVariable(WrenVM* vm, ObjModule* module, const char* name, + size_t length, int line); + +// Adds a new top-level variable named [name] to [module], and optionally +// populates line with the line of the implicit first use (line can be NULL). +// +// Returns the symbol for the new variable, -1 if a variable with the given name +// is already defined, or -2 if there are too many variables defined. +// Returns -3 if this is a top-level lowercase variable (localname) that was +// used before being defined. +int wrenDefineVariable(WrenVM* vm, ObjModule* module, const char* name, + size_t length, Value value, int* line); + +// Pushes [closure] onto [fiber]'s callstack to invoke it. Expects [numArgs] +// arguments (including the receiver) to be on the top of the stack already. +static inline void wrenCallFunction(WrenVM* vm, ObjFiber* fiber, + ObjClosure* closure, int numArgs) +{ + // Grow the call frame array if needed. + if (fiber->numFrames + 1 > fiber->frameCapacity) + { + int max = fiber->frameCapacity * 2; + fiber->frames = (CallFrame*)wrenReallocate(vm, fiber->frames, + sizeof(CallFrame) * fiber->frameCapacity, sizeof(CallFrame) * max); + fiber->frameCapacity = max; + } + + // Grow the stack if needed. + int stackSize = (int)(fiber->stackTop - fiber->stack); + int needed = stackSize + closure->fn->maxSlots; + wrenEnsureStack(vm, fiber, needed); + + wrenAppendCallFrame(vm, fiber, closure, fiber->stackTop - numArgs); +} + +// Marks [obj] as a GC root so that it doesn't get collected. +void wrenPushRoot(WrenVM* vm, Obj* obj); + +// Removes the most recently pushed temporary root. +void wrenPopRoot(WrenVM* vm); + +// Returns the class of [value]. +// +// Defined here instead of in wren_value.h because it's critical that this be +// inlined. That means it must be defined in the header, but the wren_value.h +// header doesn't have a full definitely of WrenVM yet. +static inline ObjClass* wrenGetClassInline(WrenVM* vm, Value value) +{ + if (IS_NUM(value)) return vm->numClass; + if (IS_OBJ(value)) return AS_OBJ(value)->classObj; + +#if WREN_NAN_TAGGING + switch (GET_TAG(value)) + { + case TAG_FALSE: return vm->boolClass; break; + case TAG_NAN: return vm->numClass; break; + case TAG_NULL: return vm->nullClass; break; + case TAG_TRUE: return vm->boolClass; break; + case TAG_UNDEFINED: UNREACHABLE(); + } +#else + switch (value.type) + { + case VAL_FALSE: return vm->boolClass; + case VAL_NULL: return vm->nullClass; + case VAL_NUM: return vm->numClass; + case VAL_TRUE: return vm->boolClass; + case VAL_OBJ: return AS_OBJ(value)->classObj; + case VAL_UNDEFINED: UNREACHABLE(); + } +#endif + + UNREACHABLE(); + return NULL; +} + +// Returns `true` if [name] is a local variable name (starts with a lowercase +// letter). +static inline bool wrenIsLocalName(const char* name) +{ + return name[0] >= 'a' && name[0] <= 'z'; +} + +static inline bool wrenIsFalsyValue(Value value) +{ + return IS_FALSE(value) || IS_NULL(value); +} + +#endif +// End file "wren_vm.h" + +// Prints the stack trace for the current fiber. +// +// Used when a fiber throws a runtime error which is not caught. +void wrenDebugPrintStackTrace(WrenVM* vm); + +// The "dump" functions are used for debugging Wren itself. Normal code paths +// will not call them unless one of the various DEBUG_ flags is enabled. + +// Prints a representation of [value] to stdout. +void wrenDumpValue(Value value); + +// Prints a representation of the bytecode for [fn] at instruction [i]. +int wrenDumpInstruction(WrenVM* vm, ObjFn* fn, int i); + +// Prints the disassembled code for [fn] to stdout. +void wrenDumpCode(WrenVM* vm, ObjFn* fn); + +// Prints the contents of the current stack for [fiber] to stdout. +void wrenDumpStack(ObjFiber* fiber); + +#endif +// End file "wren_debug.h" +// Begin file "wren_debug.c" +#include + + +void wrenDebugPrintStackTrace(WrenVM* vm) +{ + // Bail if the host doesn't enable printing errors. + if (vm->config.errorFn == NULL) return; + + ObjFiber* fiber = vm->fiber; + if (IS_STRING(fiber->error)) + { + vm->config.errorFn(vm, WREN_ERROR_RUNTIME, + NULL, -1, AS_CSTRING(fiber->error)); + } + else + { + // TODO: Print something a little useful here. Maybe the name of the error's + // class? + vm->config.errorFn(vm, WREN_ERROR_RUNTIME, + NULL, -1, "[error object]"); + } + + for (int i = fiber->numFrames - 1; i >= 0; i--) + { + CallFrame* frame = &fiber->frames[i]; + ObjFn* fn = frame->closure->fn; + + // Skip over stub functions for calling methods from the C API. + if (fn->module == NULL) continue; + + // The built-in core module has no name. We explicitly omit it from stack + // traces since we don't want to highlight to a user the implementation + // detail of what part of the core module is written in C and what is Wren. + if (fn->module->name == NULL) continue; + + // -1 because IP has advanced past the instruction that it just executed. + int line = fn->debug->sourceLines.data[frame->ip - fn->code.data - 1]; + vm->config.errorFn(vm, WREN_ERROR_STACK_TRACE, + fn->module->name->value, line, + fn->debug->name); + } +} + +static void dumpObject(Obj* obj) +{ + switch (obj->type) + { + case OBJ_CLASS: + printf("[class %s %p]", ((ObjClass*)obj)->name->value, obj); + break; + case OBJ_CLOSURE: printf("[closure %p]", obj); break; + case OBJ_FIBER: printf("[fiber %p]", obj); break; + case OBJ_FN: printf("[fn %p]", obj); break; + case OBJ_FOREIGN: printf("[foreign %p]", obj); break; + case OBJ_INSTANCE: printf("[instance %p]", obj); break; + case OBJ_LIST: printf("[list %p]", obj); break; + case OBJ_MAP: printf("[map %p]", obj); break; + case OBJ_MODULE: printf("[module %p]", obj); break; + case OBJ_RANGE: printf("[range %p]", obj); break; + case OBJ_STRING: printf("%s", ((ObjString*)obj)->value); break; + case OBJ_UPVALUE: printf("[upvalue %p]", obj); break; + default: printf("[unknown object %d]", obj->type); break; + } +} + +void wrenDumpValue(Value value) +{ +#if WREN_NAN_TAGGING + if (IS_NUM(value)) + { + printf("%.14g", AS_NUM(value)); + } + else if (IS_OBJ(value)) + { + dumpObject(AS_OBJ(value)); + } + else + { + switch (GET_TAG(value)) + { + case TAG_FALSE: printf("false"); break; + case TAG_NAN: printf("NaN"); break; + case TAG_NULL: printf("null"); break; + case TAG_TRUE: printf("true"); break; + case TAG_UNDEFINED: UNREACHABLE(); + } + } +#else + switch (value.type) + { + case VAL_FALSE: printf("false"); break; + case VAL_NULL: printf("null"); break; + case VAL_NUM: printf("%.14g", AS_NUM(value)); break; + case VAL_TRUE: printf("true"); break; + case VAL_OBJ: dumpObject(AS_OBJ(value)); break; + case VAL_UNDEFINED: UNREACHABLE(); + } +#endif +} + +static int dumpInstruction(WrenVM* vm, ObjFn* fn, int i, int* lastLine) +{ + int start = i; + uint8_t* bytecode = fn->code.data; + Code code = (Code)bytecode[i]; + + int line = fn->debug->sourceLines.data[i]; + if (lastLine == NULL || *lastLine != line) + { + printf("%4d:", line); + if (lastLine != NULL) *lastLine = line; + } + else + { + printf(" "); + } + + printf(" %04d ", i++); + + #define READ_BYTE() (bytecode[i++]) + #define READ_SHORT() (i += 2, (bytecode[i - 2] << 8) | bytecode[i - 1]) + + #define BYTE_INSTRUCTION(name) \ + printf("%-16s %5d\n", name, READ_BYTE()); \ + break + + switch (code) + { + case CODE_CONSTANT: + { + int constant = READ_SHORT(); + printf("%-16s %5d '", "CONSTANT", constant); + wrenDumpValue(fn->constants.data[constant]); + printf("'\n"); + break; + } + + case CODE_NULL: printf("NULL\n"); break; + case CODE_FALSE: printf("FALSE\n"); break; + case CODE_TRUE: printf("TRUE\n"); break; + + case CODE_LOAD_LOCAL_0: printf("LOAD_LOCAL_0\n"); break; + case CODE_LOAD_LOCAL_1: printf("LOAD_LOCAL_1\n"); break; + case CODE_LOAD_LOCAL_2: printf("LOAD_LOCAL_2\n"); break; + case CODE_LOAD_LOCAL_3: printf("LOAD_LOCAL_3\n"); break; + case CODE_LOAD_LOCAL_4: printf("LOAD_LOCAL_4\n"); break; + case CODE_LOAD_LOCAL_5: printf("LOAD_LOCAL_5\n"); break; + case CODE_LOAD_LOCAL_6: printf("LOAD_LOCAL_6\n"); break; + case CODE_LOAD_LOCAL_7: printf("LOAD_LOCAL_7\n"); break; + case CODE_LOAD_LOCAL_8: printf("LOAD_LOCAL_8\n"); break; + + case CODE_LOAD_LOCAL: BYTE_INSTRUCTION("LOAD_LOCAL"); + case CODE_STORE_LOCAL: BYTE_INSTRUCTION("STORE_LOCAL"); + case CODE_LOAD_UPVALUE: BYTE_INSTRUCTION("LOAD_UPVALUE"); + case CODE_STORE_UPVALUE: BYTE_INSTRUCTION("STORE_UPVALUE"); + + case CODE_LOAD_MODULE_VAR: + { + int slot = READ_SHORT(); + printf("%-16s %5d '%s'\n", "LOAD_MODULE_VAR", slot, + fn->module->variableNames.data[slot]->value); + break; + } + + case CODE_STORE_MODULE_VAR: + { + int slot = READ_SHORT(); + printf("%-16s %5d '%s'\n", "STORE_MODULE_VAR", slot, + fn->module->variableNames.data[slot]->value); + break; + } + + case CODE_LOAD_FIELD_THIS: BYTE_INSTRUCTION("LOAD_FIELD_THIS"); + case CODE_STORE_FIELD_THIS: BYTE_INSTRUCTION("STORE_FIELD_THIS"); + case CODE_LOAD_FIELD: BYTE_INSTRUCTION("LOAD_FIELD"); + case CODE_STORE_FIELD: BYTE_INSTRUCTION("STORE_FIELD"); + + case CODE_POP: printf("POP\n"); break; + + case CODE_CALL_0: + case CODE_CALL_1: + case CODE_CALL_2: + case CODE_CALL_3: + case CODE_CALL_4: + case CODE_CALL_5: + case CODE_CALL_6: + case CODE_CALL_7: + case CODE_CALL_8: + case CODE_CALL_9: + case CODE_CALL_10: + case CODE_CALL_11: + case CODE_CALL_12: + case CODE_CALL_13: + case CODE_CALL_14: + case CODE_CALL_15: + case CODE_CALL_16: + { + int numArgs = bytecode[i - 1] - CODE_CALL_0; + int symbol = READ_SHORT(); + printf("CALL_%-11d %5d '%s'\n", numArgs, symbol, + vm->methodNames.data[symbol]->value); + break; + } + + case CODE_SUPER_0: + case CODE_SUPER_1: + case CODE_SUPER_2: + case CODE_SUPER_3: + case CODE_SUPER_4: + case CODE_SUPER_5: + case CODE_SUPER_6: + case CODE_SUPER_7: + case CODE_SUPER_8: + case CODE_SUPER_9: + case CODE_SUPER_10: + case CODE_SUPER_11: + case CODE_SUPER_12: + case CODE_SUPER_13: + case CODE_SUPER_14: + case CODE_SUPER_15: + case CODE_SUPER_16: + { + int numArgs = bytecode[i - 1] - CODE_SUPER_0; + int symbol = READ_SHORT(); + int superclass = READ_SHORT(); + printf("SUPER_%-10d %5d '%s' %5d\n", numArgs, symbol, + vm->methodNames.data[symbol]->value, superclass); + break; + } + + case CODE_JUMP: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "JUMP", offset, i + offset); + break; + } + + case CODE_LOOP: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "LOOP", offset, i - offset); + break; + } + + case CODE_JUMP_IF: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "JUMP_IF", offset, i + offset); + break; + } + + case CODE_AND: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "AND", offset, i + offset); + break; + } + + case CODE_OR: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "OR", offset, i + offset); + break; + } + + case CODE_CLOSE_UPVALUE: printf("CLOSE_UPVALUE\n"); break; + case CODE_RETURN: printf("RETURN\n"); break; + + case CODE_CLOSURE: + { + int constant = READ_SHORT(); + printf("%-16s %5d ", "CLOSURE", constant); + wrenDumpValue(fn->constants.data[constant]); + printf(" "); + ObjFn* loadedFn = AS_FN(fn->constants.data[constant]); + for (int j = 0; j < loadedFn->numUpvalues; j++) + { + int isLocal = READ_BYTE(); + int index = READ_BYTE(); + if (j > 0) printf(", "); + printf("%s %d", isLocal ? "local" : "upvalue", index); + } + printf("\n"); + break; + } + + case CODE_CONSTRUCT: printf("CONSTRUCT\n"); break; + case CODE_FOREIGN_CONSTRUCT: printf("FOREIGN_CONSTRUCT\n"); break; + + case CODE_CLASS: + { + int numFields = READ_BYTE(); + printf("%-16s %5d fields\n", "CLASS", numFields); + break; + } + + case CODE_FOREIGN_CLASS: printf("FOREIGN_CLASS\n"); break; + case CODE_END_CLASS: printf("END_CLASS\n"); break; + + case CODE_METHOD_INSTANCE: + { + int symbol = READ_SHORT(); + printf("%-16s %5d '%s'\n", "METHOD_INSTANCE", symbol, + vm->methodNames.data[symbol]->value); + break; + } + + case CODE_METHOD_STATIC: + { + int symbol = READ_SHORT(); + printf("%-16s %5d '%s'\n", "METHOD_STATIC", symbol, + vm->methodNames.data[symbol]->value); + break; + } + + case CODE_END_MODULE: + printf("END_MODULE\n"); + break; + + case CODE_IMPORT_MODULE: + { + int name = READ_SHORT(); + printf("%-16s %5d '", "IMPORT_MODULE", name); + wrenDumpValue(fn->constants.data[name]); + printf("'\n"); + break; + } + + case CODE_IMPORT_VARIABLE: + { + int variable = READ_SHORT(); + printf("%-16s %5d '", "IMPORT_VARIABLE", variable); + wrenDumpValue(fn->constants.data[variable]); + printf("'\n"); + break; + } + + case CODE_END: + printf("END\n"); + break; + + default: + printf("UKNOWN! [%d]\n", bytecode[i - 1]); + break; + } + + // Return how many bytes this instruction takes, or -1 if it's an END. + if (code == CODE_END) return -1; + return i - start; + + #undef READ_BYTE + #undef READ_SHORT +} + +int wrenDumpInstruction(WrenVM* vm, ObjFn* fn, int i) +{ + return dumpInstruction(vm, fn, i, NULL); +} + +void wrenDumpCode(WrenVM* vm, ObjFn* fn) +{ + printf("%s: %s\n", + fn->module->name == NULL ? "" : fn->module->name->value, + fn->debug->name); + + int i = 0; + int lastLine = -1; + for (;;) + { + int offset = dumpInstruction(vm, fn, i, &lastLine); + if (offset == -1) break; + i += offset; + } + + printf("\n"); +} + +void wrenDumpStack(ObjFiber* fiber) +{ + printf("(fiber %p) ", fiber); + for (Value* slot = fiber->stack; slot < fiber->stackTop; slot++) + { + wrenDumpValue(*slot); + printf(" | "); + } + printf("\n"); +} +// End file "wren_debug.c" +// Begin file "wren_compiler.c" +#include +#include +#include +#include + + +#if WREN_DEBUG_DUMP_COMPILED_CODE +#endif + +// This is written in bottom-up order, so the tokenization comes first, then +// parsing/code generation. This minimizes the number of explicit forward +// declarations needed. + +// The maximum number of local (i.e. not module level) variables that can be +// declared in a single function, method, or chunk of top level code. This is +// the maximum number of variables in scope at one time, and spans block scopes. +// +// Note that this limitation is also explicit in the bytecode. Since +// `CODE_LOAD_LOCAL` and `CODE_STORE_LOCAL` use a single argument byte to +// identify the local, only 256 can be in scope at one time. +#define MAX_LOCALS 256 + +// The maximum number of upvalues (i.e. variables from enclosing functions) +// that a function can close over. +#define MAX_UPVALUES 256 + +// The maximum number of distinct constants that a function can contain. This +// value is explicit in the bytecode since `CODE_CONSTANT` only takes a single +// two-byte argument. +#define MAX_CONSTANTS (1 << 16) + +// The maximum distance a CODE_JUMP or CODE_JUMP_IF instruction can move the +// instruction pointer. +#define MAX_JUMP (1 << 16) + +// The maximum depth that interpolation can nest. For example, this string has +// three levels: +// +// "outside %(one + "%(two + "%(three)")")" +#define MAX_INTERPOLATION_NESTING 8 + +// The buffer size used to format a compile error message, excluding the header +// with the module name and error location. Using a hardcoded buffer for this +// is kind of hairy, but fortunately we can control what the longest possible +// message is and handle that. Ideally, we'd use `snprintf()`, but that's not +// available in standard C++98. +#define ERROR_MESSAGE_SIZE (80 + MAX_VARIABLE_NAME + 15) + +typedef enum +{ + TOKEN_LEFT_PAREN, + TOKEN_RIGHT_PAREN, + TOKEN_LEFT_BRACKET, + TOKEN_RIGHT_BRACKET, + TOKEN_LEFT_BRACE, + TOKEN_RIGHT_BRACE, + TOKEN_COLON, + TOKEN_DOT, + TOKEN_DOTDOT, + TOKEN_DOTDOTDOT, + TOKEN_COMMA, + TOKEN_STAR, + TOKEN_SLASH, + TOKEN_PERCENT, + TOKEN_HASH, + TOKEN_PLUS, + TOKEN_MINUS, + TOKEN_LTLT, + TOKEN_GTGT, + TOKEN_PIPE, + TOKEN_PIPEPIPE, + TOKEN_CARET, + TOKEN_AMP, + TOKEN_AMPAMP, + TOKEN_BANG, + TOKEN_TILDE, + TOKEN_QUESTION, + TOKEN_EQ, + TOKEN_LT, + TOKEN_GT, + TOKEN_LTEQ, + TOKEN_GTEQ, + TOKEN_EQEQ, + TOKEN_BANGEQ, + + TOKEN_BREAK, + TOKEN_CONTINUE, + TOKEN_CLASS, + TOKEN_CONSTRUCT, + TOKEN_ELSE, + TOKEN_FALSE, + TOKEN_FOR, + TOKEN_FOREIGN, + TOKEN_IF, + TOKEN_IMPORT, + TOKEN_AS, + TOKEN_IN, + TOKEN_IS, + TOKEN_NULL, + TOKEN_RETURN, + TOKEN_STATIC, + TOKEN_SUPER, + TOKEN_THIS, + TOKEN_TRUE, + TOKEN_VAR, + TOKEN_WHILE, + + TOKEN_FIELD, + TOKEN_STATIC_FIELD, + TOKEN_NAME, + TOKEN_NUMBER, + + // A string literal without any interpolation, or the last section of a + // string following the last interpolated expression. + TOKEN_STRING, + + // A portion of a string literal preceding an interpolated expression. This + // string: + // + // "a %(b) c %(d) e" + // + // is tokenized to: + // + // TOKEN_INTERPOLATION "a " + // TOKEN_NAME b + // TOKEN_INTERPOLATION " c " + // TOKEN_NAME d + // TOKEN_STRING " e" + TOKEN_INTERPOLATION, + + TOKEN_LINE, + + TOKEN_ERROR, + TOKEN_EOF +} TokenType; + +typedef struct +{ + TokenType type; + + // The beginning of the token, pointing directly into the source. + const char* start; + + // The length of the token in characters. + int length; + + // The 1-based line where the token appears. + int line; + + // The parsed value if the token is a literal. + Value value; +} Token; + +typedef struct +{ + WrenVM* vm; + + // The module being parsed. + ObjModule* module; + + // The source code being parsed. + const char* source; + + // The beginning of the currently-being-lexed token in [source]. + const char* tokenStart; + + // The current character being lexed in [source]. + const char* currentChar; + + // The 1-based line number of [currentChar]. + int currentLine; + + // The upcoming token. + Token next; + + // The most recently lexed token. + Token current; + + // The most recently consumed/advanced token. + Token previous; + + // Tracks the lexing state when tokenizing interpolated strings. + // + // Interpolated strings make the lexer not strictly regular: we don't know + // whether a ")" should be treated as a RIGHT_PAREN token or as ending an + // interpolated expression unless we know whether we are inside a string + // interpolation and how many unmatched "(" there are. This is particularly + // complex because interpolation can nest: + // + // " %( " %( inner ) " ) " + // + // This tracks that state. The parser maintains a stack of ints, one for each + // level of current interpolation nesting. Each value is the number of + // unmatched "(" that are waiting to be closed. + int parens[MAX_INTERPOLATION_NESTING]; + int numParens; + + // Whether compile errors should be printed to stderr or discarded. + bool printErrors; + + // If a syntax or compile error has occurred. + bool hasError; +} Parser; + +typedef struct +{ + // The name of the local variable. This points directly into the original + // source code string. + const char* name; + + // The length of the local variable's name. + int length; + + // The depth in the scope chain that this variable was declared at. Zero is + // the outermost scope--parameters for a method, or the first local block in + // top level code. One is the scope within that, etc. + int depth; + + // If this local variable is being used as an upvalue. + bool isUpvalue; +} Local; + +typedef struct +{ + // True if this upvalue is capturing a local variable from the enclosing + // function. False if it's capturing an upvalue. + bool isLocal; + + // The index of the local or upvalue being captured in the enclosing function. + int index; +} CompilerUpvalue; + +// Bookkeeping information for the current loop being compiled. +typedef struct sLoop +{ + // Index of the instruction that the loop should jump back to. + int start; + + // Index of the argument for the CODE_JUMP_IF instruction used to exit the + // loop. Stored so we can patch it once we know where the loop ends. + int exitJump; + + // Index of the first instruction of the body of the loop. + int body; + + // Depth of the scope(s) that need to be exited if a break is hit inside the + // loop. + int scopeDepth; + + // The loop enclosing this one, or NULL if this is the outermost loop. + struct sLoop* enclosing; +} Loop; + +// The different signature syntaxes for different kinds of methods. +typedef enum +{ + // A name followed by a (possibly empty) parenthesized parameter list. Also + // used for binary operators. + SIG_METHOD, + + // Just a name. Also used for unary operators. + SIG_GETTER, + + // A name followed by "=". + SIG_SETTER, + + // A square bracketed parameter list. + SIG_SUBSCRIPT, + + // A square bracketed parameter list followed by "=". + SIG_SUBSCRIPT_SETTER, + + // A constructor initializer function. This has a distinct signature to + // prevent it from being invoked directly outside of the constructor on the + // metaclass. + SIG_INITIALIZER +} SignatureType; + +typedef struct +{ + const char* name; + int length; + SignatureType type; + int arity; +} Signature; + +// Bookkeeping information for compiling a class definition. +typedef struct +{ + // The name of the class. + ObjString* name; + + // Attributes for the class itself + ObjMap* classAttributes; + // Attributes for methods in this class + ObjMap* methodAttributes; + + // Symbol table for the fields of the class. + SymbolTable fields; + + // Symbols for the methods defined by the class. Used to detect duplicate + // method definitions. + IntBuffer methods; + IntBuffer staticMethods; + + // True if the class being compiled is a foreign class. + bool isForeign; + + // True if the current method being compiled is static. + bool inStatic; + + // The signature of the method being compiled. + Signature* signature; +} ClassInfo; + +struct sCompiler +{ + Parser* parser; + + // The compiler for the function enclosing this one, or NULL if it's the + // top level. + struct sCompiler* parent; + + // The currently in scope local variables. + Local locals[MAX_LOCALS]; + + // The number of local variables currently in scope. + int numLocals; + + // The upvalues that this function has captured from outer scopes. The count + // of them is stored in [numUpvalues]. + CompilerUpvalue upvalues[MAX_UPVALUES]; + + // The current level of block scope nesting, where zero is no nesting. A -1 + // here means top-level code is being compiled and there is no block scope + // in effect at all. Any variables declared will be module-level. + int scopeDepth; + + // The current number of slots (locals and temporaries) in use. + // + // We use this and maxSlots to track the maximum number of additional slots + // a function may need while executing. When the function is called, the + // fiber will check to ensure its stack has enough room to cover that worst + // case and grow the stack if needed. + // + // This value here doesn't include parameters to the function. Since those + // are already pushed onto the stack by the caller and tracked there, we + // don't need to double count them here. + int numSlots; + + // The current innermost loop being compiled, or NULL if not in a loop. + Loop* loop; + + // If this is a compiler for a method, keeps track of the class enclosing it. + ClassInfo* enclosingClass; + + // The function being compiled. + ObjFn* fn; + + // The constants for the function being compiled. + ObjMap* constants; + + // Whether or not the compiler is for a constructor initializer + bool isInitializer; + + // The number of attributes seen while parsing. + // We track this separately as compile time attributes + // are not stored, so we can't rely on attributes->count + // to enforce an error message when attributes are used + // anywhere other than methods or classes. + int numAttributes; + // Attributes for the next class or method. + ObjMap* attributes; +}; + +// Describes where a variable is declared. +typedef enum +{ + // A local variable in the current function. + SCOPE_LOCAL, + + // A local variable declared in an enclosing function. + SCOPE_UPVALUE, + + // A top-level module variable. + SCOPE_MODULE +} Scope; + +// A reference to a variable and the scope where it is defined. This contains +// enough information to emit correct code to load or store the variable. +typedef struct +{ + // The stack slot, upvalue slot, or module symbol defining the variable. + int index; + + // Where the variable is declared. + Scope scope; +} Variable; + +// Forward declarations +static void disallowAttributes(Compiler* compiler); +static void addToAttributeGroup(Compiler* compiler, Value group, Value key, Value value); +static void emitClassAttributes(Compiler* compiler, ClassInfo* classInfo); +static void copyAttributes(Compiler* compiler, ObjMap* into); +static void copyMethodAttributes(Compiler* compiler, bool isForeign, + bool isStatic, const char* fullSignature, int32_t length); + +// The stack effect of each opcode. The index in the array is the opcode, and +// the value is the stack effect of that instruction. +static const int stackEffects[] = { + #define OPCODE(_, effect) effect, +// Begin file "wren_opcodes.h" +// This defines the bytecode instructions used by the VM. It does so by invoking +// an OPCODE() macro which is expected to be defined at the point that this is +// included. (See: http://en.wikipedia.org/wiki/X_Macro for more.) +// +// The first argument is the name of the opcode. The second is its "stack +// effect" -- the amount that the op code changes the size of the stack. A +// stack effect of 1 means it pushes a value and the stack grows one larger. +// -2 means it pops two values, etc. +// +// Note that the order of instructions here affects the order of the dispatch +// table in the VM's interpreter loop. That in turn affects caching which +// affects overall performance. Take care to run benchmarks if you change the +// order here. + +// Load the constant at index [arg]. +OPCODE(CONSTANT, 1) + +// Push null onto the stack. +OPCODE(NULL, 1) + +// Push false onto the stack. +OPCODE(FALSE, 1) + +// Push true onto the stack. +OPCODE(TRUE, 1) + +// Pushes the value in the given local slot. +OPCODE(LOAD_LOCAL_0, 1) +OPCODE(LOAD_LOCAL_1, 1) +OPCODE(LOAD_LOCAL_2, 1) +OPCODE(LOAD_LOCAL_3, 1) +OPCODE(LOAD_LOCAL_4, 1) +OPCODE(LOAD_LOCAL_5, 1) +OPCODE(LOAD_LOCAL_6, 1) +OPCODE(LOAD_LOCAL_7, 1) +OPCODE(LOAD_LOCAL_8, 1) + +// Note: The compiler assumes the following _STORE instructions always +// immediately follow their corresponding _LOAD ones. + +// Pushes the value in local slot [arg]. +OPCODE(LOAD_LOCAL, 1) + +// Stores the top of stack in local slot [arg]. Does not pop it. +OPCODE(STORE_LOCAL, 0) + +// Pushes the value in upvalue [arg]. +OPCODE(LOAD_UPVALUE, 1) + +// Stores the top of stack in upvalue [arg]. Does not pop it. +OPCODE(STORE_UPVALUE, 0) + +// Pushes the value of the top-level variable in slot [arg]. +OPCODE(LOAD_MODULE_VAR, 1) + +// Stores the top of stack in top-level variable slot [arg]. Does not pop it. +OPCODE(STORE_MODULE_VAR, 0) + +// Pushes the value of the field in slot [arg] of the receiver of the current +// function. This is used for regular field accesses on "this" directly in +// methods. This instruction is faster than the more general CODE_LOAD_FIELD +// instruction. +OPCODE(LOAD_FIELD_THIS, 1) + +// Stores the top of the stack in field slot [arg] in the receiver of the +// current value. Does not pop the value. This instruction is faster than the +// more general CODE_LOAD_FIELD instruction. +OPCODE(STORE_FIELD_THIS, 0) + +// Pops an instance and pushes the value of the field in slot [arg] of it. +OPCODE(LOAD_FIELD, 0) + +// Pops an instance and stores the subsequent top of stack in field slot +// [arg] in it. Does not pop the value. +OPCODE(STORE_FIELD, -1) + +// Pop and discard the top of stack. +OPCODE(POP, -1) + +// Invoke the method with symbol [arg]. The number indicates the number of +// arguments (not including the receiver). +OPCODE(CALL_0, 0) +OPCODE(CALL_1, -1) +OPCODE(CALL_2, -2) +OPCODE(CALL_3, -3) +OPCODE(CALL_4, -4) +OPCODE(CALL_5, -5) +OPCODE(CALL_6, -6) +OPCODE(CALL_7, -7) +OPCODE(CALL_8, -8) +OPCODE(CALL_9, -9) +OPCODE(CALL_10, -10) +OPCODE(CALL_11, -11) +OPCODE(CALL_12, -12) +OPCODE(CALL_13, -13) +OPCODE(CALL_14, -14) +OPCODE(CALL_15, -15) +OPCODE(CALL_16, -16) + +// Invoke a superclass method with symbol [arg]. The number indicates the +// number of arguments (not including the receiver). +OPCODE(SUPER_0, 0) +OPCODE(SUPER_1, -1) +OPCODE(SUPER_2, -2) +OPCODE(SUPER_3, -3) +OPCODE(SUPER_4, -4) +OPCODE(SUPER_5, -5) +OPCODE(SUPER_6, -6) +OPCODE(SUPER_7, -7) +OPCODE(SUPER_8, -8) +OPCODE(SUPER_9, -9) +OPCODE(SUPER_10, -10) +OPCODE(SUPER_11, -11) +OPCODE(SUPER_12, -12) +OPCODE(SUPER_13, -13) +OPCODE(SUPER_14, -14) +OPCODE(SUPER_15, -15) +OPCODE(SUPER_16, -16) + +// Jump the instruction pointer [arg] forward. +OPCODE(JUMP, 0) + +// Jump the instruction pointer [arg] backward. +OPCODE(LOOP, 0) + +// Pop and if not truthy then jump the instruction pointer [arg] forward. +OPCODE(JUMP_IF, -1) + +// If the top of the stack is false, jump [arg] forward. Otherwise, pop and +// continue. +OPCODE(AND, -1) + +// If the top of the stack is non-false, jump [arg] forward. Otherwise, pop +// and continue. +OPCODE(OR, -1) + +// Close the upvalue for the local on the top of the stack, then pop it. +OPCODE(CLOSE_UPVALUE, -1) + +// Exit from the current function and return the value on the top of the +// stack. +OPCODE(RETURN, 0) + +// Creates a closure for the function stored at [arg] in the constant table. +// +// Following the function argument is a number of arguments, two for each +// upvalue. The first is true if the variable being captured is a local (as +// opposed to an upvalue), and the second is the index of the local or +// upvalue being captured. +// +// Pushes the created closure. +OPCODE(CLOSURE, 1) + +// Creates a new instance of a class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(CONSTRUCT, 0) + +// Creates a new instance of a foreign class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(FOREIGN_CONSTRUCT, 0) + +// Creates a class. Top of stack is the superclass. Below that is a string for +// the name of the class. Byte [arg] is the number of fields in the class. +OPCODE(CLASS, -1) + +// Ends a class. +// Atm the stack contains the class and the ClassAttributes (or null). +OPCODE(END_CLASS, -2) + +// Creates a foreign class. Top of stack is the superclass. Below that is a +// string for the name of the class. +OPCODE(FOREIGN_CLASS, -1) + +// Define a method for symbol [arg]. The class receiving the method is popped +// off the stack, then the function defining the body is popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_INSTANCE, -2) + +// Define a method for symbol [arg]. The class whose metaclass will receive +// the method is popped off the stack, then the function defining the body is +// popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_STATIC, -2) + +// This is executed at the end of the module's body. Pushes NULL onto the stack +// as the "return value" of the import statement and stores the module as the +// most recently imported one. +OPCODE(END_MODULE, 1) + +// Import a module whose name is the string stored at [arg] in the constant +// table. +// +// Pushes null onto the stack so that the fiber for the imported module can +// replace that with a dummy value when it returns. (Fibers always return a +// value when resuming a caller.) +OPCODE(IMPORT_MODULE, 1) + +// Import a variable from the most recently imported module. The name of the +// variable to import is at [arg] in the constant table. Pushes the loaded +// variable's value. +OPCODE(IMPORT_VARIABLE, 1) + +// This pseudo-instruction indicates the end of the bytecode. It should +// always be preceded by a `CODE_RETURN`, so is never actually executed. +OPCODE(END, 0) +// End file "wren_opcodes.h" + #undef OPCODE +}; + +static void printError(Parser* parser, int line, const char* label, + const char* format, va_list args) +{ + parser->hasError = true; + if (!parser->printErrors) return; + + // Only report errors if there is a WrenErrorFn to handle them. + if (parser->vm->config.errorFn == NULL) return; + + // Format the label and message. + char message[ERROR_MESSAGE_SIZE]; + int length = sprintf(message, "%s: ", label); + length += vsprintf(message + length, format, args); + ASSERT(length < ERROR_MESSAGE_SIZE, "Error should not exceed buffer."); + + ObjString* module = parser->module->name; + const char* module_name = module ? module->value : ""; + + parser->vm->config.errorFn(parser->vm, WREN_ERROR_COMPILE, + module_name, line, message); +} + +// Outputs a lexical error. +static void lexError(Parser* parser, const char* format, ...) +{ + va_list args; + va_start(args, format); + printError(parser, parser->currentLine, "Error", format, args); + va_end(args); +} + +// Outputs a compile or syntax error. This also marks the compilation as having +// an error, which ensures that the resulting code will be discarded and never +// run. This means that after calling error(), it's fine to generate whatever +// invalid bytecode you want since it won't be used. +// +// You'll note that most places that call error() continue to parse and compile +// after that. That's so that we can try to find as many compilation errors in +// one pass as possible instead of just bailing at the first one. +static void error(Compiler* compiler, const char* format, ...) +{ + Token* token = &compiler->parser->previous; + + // If the parse error was caused by an error token, the lexer has already + // reported it. + if (token->type == TOKEN_ERROR) return; + + va_list args; + va_start(args, format); + if (token->type == TOKEN_LINE) + { + printError(compiler->parser, token->line, "Error at newline", format, args); + } + else if (token->type == TOKEN_EOF) + { + printError(compiler->parser, token->line, + "Error at end of file", format, args); + } + else + { + // Make sure we don't exceed the buffer with a very long token. + char label[10 + MAX_VARIABLE_NAME + 4 + 1]; + if (token->length <= MAX_VARIABLE_NAME) + { + sprintf(label, "Error at '%.*s'", token->length, token->start); + } + else + { + sprintf(label, "Error at '%.*s...'", MAX_VARIABLE_NAME, token->start); + } + printError(compiler->parser, token->line, label, format, args); + } + va_end(args); +} + +// Adds [constant] to the constant pool and returns its index. +static int addConstant(Compiler* compiler, Value constant) +{ + if (compiler->parser->hasError) return -1; + + // See if we already have a constant for the value. If so, reuse it. + if (compiler->constants != NULL) + { + Value existing = wrenMapGet(compiler->constants, constant); + if (IS_NUM(existing)) return (int)AS_NUM(existing); + } + + // It's a new constant. + if (compiler->fn->constants.count < MAX_CONSTANTS) + { + if (IS_OBJ(constant)) wrenPushRoot(compiler->parser->vm, AS_OBJ(constant)); + wrenValueBufferWrite(compiler->parser->vm, &compiler->fn->constants, + constant); + if (IS_OBJ(constant)) wrenPopRoot(compiler->parser->vm); + + if (compiler->constants == NULL) + { + compiler->constants = wrenNewMap(compiler->parser->vm); + } + wrenMapSet(compiler->parser->vm, compiler->constants, constant, + NUM_VAL(compiler->fn->constants.count - 1)); + } + else + { + error(compiler, "A function may only contain %d unique constants.", + MAX_CONSTANTS); + } + + return compiler->fn->constants.count - 1; +} + +// Initializes [compiler]. +static void initCompiler(Compiler* compiler, Parser* parser, Compiler* parent, + bool isMethod) +{ + compiler->parser = parser; + compiler->parent = parent; + compiler->loop = NULL; + compiler->enclosingClass = NULL; + compiler->isInitializer = false; + + // Initialize these to NULL before allocating in case a GC gets triggered in + // the middle of initializing the compiler. + compiler->fn = NULL; + compiler->constants = NULL; + compiler->attributes = NULL; + + parser->vm->compiler = compiler; + + // Declare a local slot for either the closure or method receiver so that we + // don't try to reuse that slot for a user-defined local variable. For + // methods, we name it "this", so that we can resolve references to that like + // a normal variable. For functions, they have no explicit "this", so we use + // an empty name. That way references to "this" inside a function walks up + // the parent chain to find a method enclosing the function whose "this" we + // can close over. + compiler->numLocals = 1; + compiler->numSlots = compiler->numLocals; + + if (isMethod) + { + compiler->locals[0].name = "this"; + compiler->locals[0].length = 4; + } + else + { + compiler->locals[0].name = NULL; + compiler->locals[0].length = 0; + } + + compiler->locals[0].depth = -1; + compiler->locals[0].isUpvalue = false; + + if (parent == NULL) + { + // Compiling top-level code, so the initial scope is module-level. + compiler->scopeDepth = -1; + } + else + { + // The initial scope for functions and methods is local scope. + compiler->scopeDepth = 0; + } + + compiler->numAttributes = 0; + compiler->attributes = wrenNewMap(parser->vm); + compiler->fn = wrenNewFunction(parser->vm, parser->module, + compiler->numLocals); +} + +// Lexing ---------------------------------------------------------------------- + +typedef struct +{ + const char* identifier; + size_t length; + TokenType tokenType; +} Keyword; + +// The table of reserved words and their associated token types. +static Keyword keywords[] = +{ + {"break", 5, TOKEN_BREAK}, + {"continue", 8, TOKEN_CONTINUE}, + {"class", 5, TOKEN_CLASS}, + {"construct", 9, TOKEN_CONSTRUCT}, + {"else", 4, TOKEN_ELSE}, + {"false", 5, TOKEN_FALSE}, + {"for", 3, TOKEN_FOR}, + {"foreign", 7, TOKEN_FOREIGN}, + {"if", 2, TOKEN_IF}, + {"import", 6, TOKEN_IMPORT}, + {"as", 2, TOKEN_AS}, + {"in", 2, TOKEN_IN}, + {"is", 2, TOKEN_IS}, + {"null", 4, TOKEN_NULL}, + {"return", 6, TOKEN_RETURN}, + {"static", 6, TOKEN_STATIC}, + {"super", 5, TOKEN_SUPER}, + {"this", 4, TOKEN_THIS}, + {"true", 4, TOKEN_TRUE}, + {"var", 3, TOKEN_VAR}, + {"while", 5, TOKEN_WHILE}, + {NULL, 0, TOKEN_EOF} // Sentinel to mark the end of the array. +}; + +// Returns true if [c] is a valid (non-initial) identifier character. +static bool isName(char c) +{ + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'; +} + +// Returns true if [c] is a digit. +static bool isDigit(char c) +{ + return c >= '0' && c <= '9'; +} + +// Returns the current character the parser is sitting on. +static char peekChar(Parser* parser) +{ + return *parser->currentChar; +} + +// Returns the character after the current character. +static char peekNextChar(Parser* parser) +{ + // If we're at the end of the source, don't read past it. + if (peekChar(parser) == '\0') return '\0'; + return *(parser->currentChar + 1); +} + +// Advances the parser forward one character. +static char nextChar(Parser* parser) +{ + char c = peekChar(parser); + parser->currentChar++; + if (c == '\n') parser->currentLine++; + return c; +} + +// If the current character is [c], consumes it and returns `true`. +static bool matchChar(Parser* parser, char c) +{ + if (peekChar(parser) != c) return false; + nextChar(parser); + return true; +} + +// Sets the parser's current token to the given [type] and current character +// range. +static void makeToken(Parser* parser, TokenType type) +{ + parser->next.type = type; + parser->next.start = parser->tokenStart; + parser->next.length = (int)(parser->currentChar - parser->tokenStart); + parser->next.line = parser->currentLine; + + // Make line tokens appear on the line containing the "\n". + if (type == TOKEN_LINE) parser->next.line--; +} + +// If the current character is [c], then consumes it and makes a token of type +// [two]. Otherwise makes a token of type [one]. +static void twoCharToken(Parser* parser, char c, TokenType two, TokenType one) +{ + makeToken(parser, matchChar(parser, c) ? two : one); +} + +// Skips the rest of the current line. +static void skipLineComment(Parser* parser) +{ + while (peekChar(parser) != '\n' && peekChar(parser) != '\0') + { + nextChar(parser); + } +} + +// Skips the rest of a block comment. +static void skipBlockComment(Parser* parser) +{ + int nesting = 1; + while (nesting > 0) + { + if (peekChar(parser) == '\0') + { + lexError(parser, "Unterminated block comment."); + return; + } + + if (peekChar(parser) == '/' && peekNextChar(parser) == '*') + { + nextChar(parser); + nextChar(parser); + nesting++; + continue; + } + + if (peekChar(parser) == '*' && peekNextChar(parser) == '/') + { + nextChar(parser); + nextChar(parser); + nesting--; + continue; + } + + // Regular comment character. + nextChar(parser); + } +} + +// Reads the next character, which should be a hex digit (0-9, a-f, or A-F) and +// returns its numeric value. If the character isn't a hex digit, returns -1. +static int readHexDigit(Parser* parser) +{ + char c = nextChar(parser); + if (c >= '0' && c <= '9') return c - '0'; + if (c >= 'a' && c <= 'f') return c - 'a' + 10; + if (c >= 'A' && c <= 'F') return c - 'A' + 10; + + // Don't consume it if it isn't expected. Keeps us from reading past the end + // of an unterminated string. + parser->currentChar--; + return -1; +} + +// Parses the numeric value of the current token. +static void makeNumber(Parser* parser, bool isHex) +{ + errno = 0; + + if (isHex) + { + parser->next.value = NUM_VAL((double)strtoll(parser->tokenStart, NULL, 16)); + } + else + { + parser->next.value = NUM_VAL(strtod(parser->tokenStart, NULL)); + } + + if (errno == ERANGE) + { + lexError(parser, "Number literal was too large (%d).", sizeof(long int)); + parser->next.value = NUM_VAL(0); + } + + // We don't check that the entire token is consumed after calling strtoll() + // or strtod() because we've already scanned it ourselves and know it's valid. + + makeToken(parser, TOKEN_NUMBER); +} + +// Finishes lexing a hexadecimal number literal. +static void readHexNumber(Parser* parser) +{ + // Skip past the `x` used to denote a hexadecimal literal. + nextChar(parser); + + // Iterate over all the valid hexadecimal digits found. + while (readHexDigit(parser) != -1) continue; + + makeNumber(parser, true); +} + +// Finishes lexing a number literal. +static void readNumber(Parser* parser) +{ + while (isDigit(peekChar(parser))) nextChar(parser); + + // See if it has a floating point. Make sure there is a digit after the "." + // so we don't get confused by method calls on number literals. + if (peekChar(parser) == '.' && isDigit(peekNextChar(parser))) + { + nextChar(parser); + while (isDigit(peekChar(parser))) nextChar(parser); + } + + // See if the number is in scientific notation. + if (matchChar(parser, 'e') || matchChar(parser, 'E')) + { + // Allow a single positive/negative exponent symbol. + if(!matchChar(parser, '+')) + { + matchChar(parser, '-'); + } + + if (!isDigit(peekChar(parser))) + { + lexError(parser, "Unterminated scientific notation."); + } + + while (isDigit(peekChar(parser))) nextChar(parser); + } + + makeNumber(parser, false); +} + +// Finishes lexing an identifier. Handles reserved words. +static void readName(Parser* parser, TokenType type, char firstChar) +{ + ByteBuffer string; + wrenByteBufferInit(&string); + wrenByteBufferWrite(parser->vm, &string, firstChar); + + while (isName(peekChar(parser)) || isDigit(peekChar(parser))) + { + char c = nextChar(parser); + wrenByteBufferWrite(parser->vm, &string, c); + } + + // Update the type if it's a keyword. + size_t length = parser->currentChar - parser->tokenStart; + for (int i = 0; keywords[i].identifier != NULL; i++) + { + if (length == keywords[i].length && + memcmp(parser->tokenStart, keywords[i].identifier, length) == 0) + { + type = keywords[i].tokenType; + break; + } + } + + parser->next.value = wrenNewStringLength(parser->vm, + (char*)string.data, string.count); + + wrenByteBufferClear(parser->vm, &string); + makeToken(parser, type); +} + +// Reads [digits] hex digits in a string literal and returns their number value. +static int readHexEscape(Parser* parser, int digits, const char* description) +{ + int value = 0; + for (int i = 0; i < digits; i++) + { + if (peekChar(parser) == '"' || peekChar(parser) == '\0') + { + lexError(parser, "Incomplete %s escape sequence.", description); + + // Don't consume it if it isn't expected. Keeps us from reading past the + // end of an unterminated string. + parser->currentChar--; + break; + } + + int digit = readHexDigit(parser); + if (digit == -1) + { + lexError(parser, "Invalid %s escape sequence.", description); + break; + } + + value = (value * 16) | digit; + } + + return value; +} + +// Reads a hex digit Unicode escape sequence in a string literal. +static void readUnicodeEscape(Parser* parser, ByteBuffer* string, int length) +{ + int value = readHexEscape(parser, length, "Unicode"); + + // Grow the buffer enough for the encoded result. + int numBytes = wrenUtf8EncodeNumBytes(value); + if (numBytes != 0) + { + wrenByteBufferFill(parser->vm, string, 0, numBytes); + wrenUtf8Encode(value, string->data + string->count - numBytes); + } +} + +static void readRawString(Parser* parser) +{ + ByteBuffer string; + wrenByteBufferInit(&string); + TokenType type = TOKEN_STRING; + + //consume the second and third " + nextChar(parser); + nextChar(parser); + + int skipStart = 0; + int firstNewline = -1; + + int skipEnd = -1; + int lastNewline = -1; + + for (;;) + { + char c = nextChar(parser); + char c1 = peekChar(parser); + char c2 = peekNextChar(parser); + + if (c == '\r') continue; + + if (c == '\n') { + lastNewline = string.count; + skipEnd = lastNewline; + firstNewline = firstNewline == -1 ? string.count : firstNewline; + } + + if (c == '"' && c1 == '"' && c2 == '"') break; + + bool isWhitespace = c == ' ' || c == '\t'; + skipEnd = c == '\n' || isWhitespace ? skipEnd : -1; + + // If we haven't seen a newline or other character yet, + // and still seeing whitespace, count the characters + // as skippable till we know otherwise + bool skippable = skipStart != -1 && isWhitespace && firstNewline == -1; + skipStart = skippable ? string.count + 1 : skipStart; + + // We've counted leading whitespace till we hit something else, + // but it's not a newline, so we reset skipStart since we need these characters + if (firstNewline == -1 && !isWhitespace && c != '\n') skipStart = -1; + + if (c == '\0' || c1 == '\0' || c2 == '\0') + { + lexError(parser, "Unterminated raw string."); + + // Don't consume it if it isn't expected. Keeps us from reading past the + // end of an unterminated string. + parser->currentChar--; + break; + } + + wrenByteBufferWrite(parser->vm, &string, c); + } + + //consume the second and third " + nextChar(parser); + nextChar(parser); + + int offset = 0; + int count = string.count; + + if(firstNewline != -1 && skipStart == firstNewline) offset = firstNewline + 1; + if(lastNewline != -1 && skipEnd == lastNewline) count = lastNewline; + + count -= (offset > count) ? count : offset; + + parser->next.value = wrenNewStringLength(parser->vm, + ((char*)string.data) + offset, count); + + wrenByteBufferClear(parser->vm, &string); + makeToken(parser, type); +} + +// Finishes lexing a string literal. +static void readString(Parser* parser) +{ + ByteBuffer string; + TokenType type = TOKEN_STRING; + wrenByteBufferInit(&string); + + for (;;) + { + char c = nextChar(parser); + if (c == '"') break; + if (c == '\r') continue; + + if (c == '\0') + { + lexError(parser, "Unterminated string."); + + // Don't consume it if it isn't expected. Keeps us from reading past the + // end of an unterminated string. + parser->currentChar--; + break; + } + + if (c == '%') + { + if (parser->numParens < MAX_INTERPOLATION_NESTING) + { + // TODO: Allow format string. + if (nextChar(parser) != '(') lexError(parser, "Expect '(' after '%%'."); + + parser->parens[parser->numParens++] = 1; + type = TOKEN_INTERPOLATION; + break; + } + + lexError(parser, "Interpolation may only nest %d levels deep.", + MAX_INTERPOLATION_NESTING); + } + + if (c == '\\') + { + switch (nextChar(parser)) + { + case '"': wrenByteBufferWrite(parser->vm, &string, '"'); break; + case '\\': wrenByteBufferWrite(parser->vm, &string, '\\'); break; + case '%': wrenByteBufferWrite(parser->vm, &string, '%'); break; + case '0': wrenByteBufferWrite(parser->vm, &string, '\0'); break; + case 'a': wrenByteBufferWrite(parser->vm, &string, '\a'); break; + case 'b': wrenByteBufferWrite(parser->vm, &string, '\b'); break; + case 'e': wrenByteBufferWrite(parser->vm, &string, '\33'); break; + case 'f': wrenByteBufferWrite(parser->vm, &string, '\f'); break; + case 'n': wrenByteBufferWrite(parser->vm, &string, '\n'); break; + case 'r': wrenByteBufferWrite(parser->vm, &string, '\r'); break; + case 't': wrenByteBufferWrite(parser->vm, &string, '\t'); break; + case 'u': readUnicodeEscape(parser, &string, 4); break; + case 'U': readUnicodeEscape(parser, &string, 8); break; + case 'v': wrenByteBufferWrite(parser->vm, &string, '\v'); break; + case 'x': + wrenByteBufferWrite(parser->vm, &string, + (uint8_t)readHexEscape(parser, 2, "byte")); + break; + + default: + lexError(parser, "Invalid escape character '%c'.", + *(parser->currentChar - 1)); + break; + } + } + else + { + wrenByteBufferWrite(parser->vm, &string, c); + } + } + + parser->next.value = wrenNewStringLength(parser->vm, + (char*)string.data, string.count); + + wrenByteBufferClear(parser->vm, &string); + makeToken(parser, type); +} + +// Lex the next token and store it in [parser.next]. +static void nextToken(Parser* parser) +{ + parser->previous = parser->current; + parser->current = parser->next; + + // If we are out of tokens, don't try to tokenize any more. We *do* still + // copy the TOKEN_EOF to previous so that code that expects it to be consumed + // will still work. + if (parser->next.type == TOKEN_EOF) return; + if (parser->current.type == TOKEN_EOF) return; + + while (peekChar(parser) != '\0') + { + parser->tokenStart = parser->currentChar; + + char c = nextChar(parser); + switch (c) + { + case '(': + // If we are inside an interpolated expression, count the unmatched "(". + if (parser->numParens > 0) parser->parens[parser->numParens - 1]++; + makeToken(parser, TOKEN_LEFT_PAREN); + return; + + case ')': + // If we are inside an interpolated expression, count the ")". + if (parser->numParens > 0 && + --parser->parens[parser->numParens - 1] == 0) + { + // This is the final ")", so the interpolation expression has ended. + // This ")" now begins the next section of the template string. + parser->numParens--; + readString(parser); + return; + } + + makeToken(parser, TOKEN_RIGHT_PAREN); + return; + + case '[': makeToken(parser, TOKEN_LEFT_BRACKET); return; + case ']': makeToken(parser, TOKEN_RIGHT_BRACKET); return; + case '{': makeToken(parser, TOKEN_LEFT_BRACE); return; + case '}': makeToken(parser, TOKEN_RIGHT_BRACE); return; + case ':': makeToken(parser, TOKEN_COLON); return; + case ',': makeToken(parser, TOKEN_COMMA); return; + case '*': makeToken(parser, TOKEN_STAR); return; + case '%': makeToken(parser, TOKEN_PERCENT); return; + case '#': { + // Ignore shebang on the first line. + if (parser->currentLine == 1 && peekChar(parser) == '!' && peekNextChar(parser) == '/') + { + skipLineComment(parser); + break; + } + // Otherwise we treat it as a token + makeToken(parser, TOKEN_HASH); + return; + } + case '^': makeToken(parser, TOKEN_CARET); return; + case '+': makeToken(parser, TOKEN_PLUS); return; + case '-': makeToken(parser, TOKEN_MINUS); return; + case '~': makeToken(parser, TOKEN_TILDE); return; + case '?': makeToken(parser, TOKEN_QUESTION); return; + + case '|': twoCharToken(parser, '|', TOKEN_PIPEPIPE, TOKEN_PIPE); return; + case '&': twoCharToken(parser, '&', TOKEN_AMPAMP, TOKEN_AMP); return; + case '=': twoCharToken(parser, '=', TOKEN_EQEQ, TOKEN_EQ); return; + case '!': twoCharToken(parser, '=', TOKEN_BANGEQ, TOKEN_BANG); return; + + case '.': + if (matchChar(parser, '.')) + { + twoCharToken(parser, '.', TOKEN_DOTDOTDOT, TOKEN_DOTDOT); + return; + } + + makeToken(parser, TOKEN_DOT); + return; + + case '/': + if (matchChar(parser, '/')) + { + skipLineComment(parser); + break; + } + + if (matchChar(parser, '*')) + { + skipBlockComment(parser); + break; + } + + makeToken(parser, TOKEN_SLASH); + return; + + case '<': + if (matchChar(parser, '<')) + { + makeToken(parser, TOKEN_LTLT); + } + else + { + twoCharToken(parser, '=', TOKEN_LTEQ, TOKEN_LT); + } + return; + + case '>': + if (matchChar(parser, '>')) + { + makeToken(parser, TOKEN_GTGT); + } + else + { + twoCharToken(parser, '=', TOKEN_GTEQ, TOKEN_GT); + } + return; + + case '\n': + makeToken(parser, TOKEN_LINE); + return; + + case ' ': + case '\r': + case '\t': + // Skip forward until we run out of whitespace. + while (peekChar(parser) == ' ' || + peekChar(parser) == '\r' || + peekChar(parser) == '\t') + { + nextChar(parser); + } + break; + + case '"': { + if(peekChar(parser) == '"' && peekNextChar(parser) == '"') { + readRawString(parser); + return; + } + readString(parser); return; + } + case '_': + readName(parser, + peekChar(parser) == '_' ? TOKEN_STATIC_FIELD : TOKEN_FIELD, c); + return; + + case '0': + if (peekChar(parser) == 'x') + { + readHexNumber(parser); + return; + } + + readNumber(parser); + return; + + default: + if (isName(c)) + { + readName(parser, TOKEN_NAME, c); + } + else if (isDigit(c)) + { + readNumber(parser); + } + else + { + if (c >= 32 && c <= 126) + { + lexError(parser, "Invalid character '%c'.", c); + } + else + { + // Don't show non-ASCII values since we didn't UTF-8 decode the + // bytes. Since there are no non-ASCII byte values that are + // meaningful code units in Wren, the lexer works on raw bytes, + // even though the source code and console output are UTF-8. + lexError(parser, "Invalid byte 0x%x.", (uint8_t)c); + } + parser->next.type = TOKEN_ERROR; + parser->next.length = 0; + } + return; + } + } + + // If we get here, we're out of source, so just make EOF tokens. + parser->tokenStart = parser->currentChar; + makeToken(parser, TOKEN_EOF); +} + +// Parsing --------------------------------------------------------------------- + +// Returns the type of the current token. +static TokenType peek(Compiler* compiler) +{ + return compiler->parser->current.type; +} + +// Returns the type of the current token. +static TokenType peekNext(Compiler* compiler) +{ + return compiler->parser->next.type; +} + +// Consumes the current token if its type is [expected]. Returns true if a +// token was consumed. +static bool match(Compiler* compiler, TokenType expected) +{ + if (peek(compiler) != expected) return false; + + nextToken(compiler->parser); + return true; +} + +// Consumes the current token. Emits an error if its type is not [expected]. +static void consume(Compiler* compiler, TokenType expected, + const char* errorMessage) +{ + nextToken(compiler->parser); + if (compiler->parser->previous.type != expected) + { + error(compiler, errorMessage); + + // If the next token is the one we want, assume the current one is just a + // spurious error and discard it to minimize the number of cascaded errors. + if (compiler->parser->current.type == expected) nextToken(compiler->parser); + } +} + +// Matches one or more newlines. Returns true if at least one was found. +static bool matchLine(Compiler* compiler) +{ + if (!match(compiler, TOKEN_LINE)) return false; + + while (match(compiler, TOKEN_LINE)); + return true; +} + +// Discards any newlines starting at the current token. +static void ignoreNewlines(Compiler* compiler) +{ + matchLine(compiler); +} + +// Consumes the current token. Emits an error if it is not a newline. Then +// discards any duplicate newlines following it. +static void consumeLine(Compiler* compiler, const char* errorMessage) +{ + consume(compiler, TOKEN_LINE, errorMessage); + ignoreNewlines(compiler); +} + +static void allowLineBeforeDot(Compiler* compiler) { + if (peek(compiler) == TOKEN_LINE && peekNext(compiler) == TOKEN_DOT) { + nextToken(compiler->parser); + } +} + +// Variables and scopes -------------------------------------------------------- + +// Emits one single-byte argument. Returns its index. +static int emitByte(Compiler* compiler, int byte) +{ + wrenByteBufferWrite(compiler->parser->vm, &compiler->fn->code, (uint8_t)byte); + + // Assume the instruction is associated with the most recently consumed token. + wrenIntBufferWrite(compiler->parser->vm, &compiler->fn->debug->sourceLines, + compiler->parser->previous.line); + + return compiler->fn->code.count - 1; +} + +// Emits one bytecode instruction. +static void emitOp(Compiler* compiler, Code instruction) +{ + emitByte(compiler, instruction); + + // Keep track of the stack's high water mark. + compiler->numSlots += stackEffects[instruction]; + if (compiler->numSlots > compiler->fn->maxSlots) + { + compiler->fn->maxSlots = compiler->numSlots; + } +} + +// Emits one 16-bit argument, which will be written big endian. +static void emitShort(Compiler* compiler, int arg) +{ + emitByte(compiler, (arg >> 8) & 0xff); + emitByte(compiler, arg & 0xff); +} + +// Emits one bytecode instruction followed by a 8-bit argument. Returns the +// index of the argument in the bytecode. +static int emitByteArg(Compiler* compiler, Code instruction, int arg) +{ + emitOp(compiler, instruction); + return emitByte(compiler, arg); +} + +// Emits one bytecode instruction followed by a 16-bit argument, which will be +// written big endian. +static void emitShortArg(Compiler* compiler, Code instruction, int arg) +{ + emitOp(compiler, instruction); + emitShort(compiler, arg); +} + +// Emits [instruction] followed by a placeholder for a jump offset. The +// placeholder can be patched by calling [jumpPatch]. Returns the index of the +// placeholder. +static int emitJump(Compiler* compiler, Code instruction) +{ + emitOp(compiler, instruction); + emitByte(compiler, 0xff); + return emitByte(compiler, 0xff) - 1; +} + +// Creates a new constant for the current value and emits the bytecode to load +// it from the constant table. +static void emitConstant(Compiler* compiler, Value value) +{ + int constant = addConstant(compiler, value); + + // Compile the code to load the constant. + emitShortArg(compiler, CODE_CONSTANT, constant); +} + +// Create a new local variable with [name]. Assumes the current scope is local +// and the name is unique. +static int addLocal(Compiler* compiler, const char* name, int length) +{ + Local* local = &compiler->locals[compiler->numLocals]; + local->name = name; + local->length = length; + local->depth = compiler->scopeDepth; + local->isUpvalue = false; + return compiler->numLocals++; +} + +// Declares a variable in the current scope whose name is the given token. +// +// If [token] is `NULL`, uses the previously consumed token. Returns its symbol. +static int declareVariable(Compiler* compiler, Token* token) +{ + if (token == NULL) token = &compiler->parser->previous; + + if (token->length > MAX_VARIABLE_NAME) + { + error(compiler, "Variable name cannot be longer than %d characters.", + MAX_VARIABLE_NAME); + } + + // Top-level module scope. + if (compiler->scopeDepth == -1) + { + int line = -1; + int symbol = wrenDefineVariable(compiler->parser->vm, + compiler->parser->module, + token->start, token->length, + NULL_VAL, &line); + + if (symbol == -1) + { + error(compiler, "Module variable is already defined."); + } + else if (symbol == -2) + { + error(compiler, "Too many module variables defined."); + } + else if (symbol == -3) + { + error(compiler, + "Variable '%.*s' referenced before this definition (first use at line %d).", + token->length, token->start, line); + } + + return symbol; + } + + // See if there is already a variable with this name declared in the current + // scope. (Outer scopes are OK: those get shadowed.) + for (int i = compiler->numLocals - 1; i >= 0; i--) + { + Local* local = &compiler->locals[i]; + + // Once we escape this scope and hit an outer one, we can stop. + if (local->depth < compiler->scopeDepth) break; + + if (local->length == token->length && + memcmp(local->name, token->start, token->length) == 0) + { + error(compiler, "Variable is already declared in this scope."); + return i; + } + } + + if (compiler->numLocals == MAX_LOCALS) + { + error(compiler, "Cannot declare more than %d variables in one scope.", + MAX_LOCALS); + return -1; + } + + return addLocal(compiler, token->start, token->length); +} + +// Parses a name token and declares a variable in the current scope with that +// name. Returns its slot. +static int declareNamedVariable(Compiler* compiler) +{ + consume(compiler, TOKEN_NAME, "Expect variable name."); + return declareVariable(compiler, NULL); +} + +// Stores a variable with the previously defined symbol in the current scope. +static void defineVariable(Compiler* compiler, int symbol) +{ + // Store the variable. If it's a local, the result of the initializer is + // in the correct slot on the stack already so we're done. + if (compiler->scopeDepth >= 0) return; + + // It's a module-level variable, so store the value in the module slot and + // then discard the temporary for the initializer. + emitShortArg(compiler, CODE_STORE_MODULE_VAR, symbol); + emitOp(compiler, CODE_POP); +} + +// Starts a new local block scope. +static void pushScope(Compiler* compiler) +{ + compiler->scopeDepth++; +} + +// Generates code to discard local variables at [depth] or greater. Does *not* +// actually undeclare variables or pop any scopes, though. This is called +// directly when compiling "break" statements to ditch the local variables +// before jumping out of the loop even though they are still in scope *past* +// the break instruction. +// +// Returns the number of local variables that were eliminated. +static int discardLocals(Compiler* compiler, int depth) +{ + ASSERT(compiler->scopeDepth > -1, "Cannot exit top-level scope."); + + int local = compiler->numLocals - 1; + while (local >= 0 && compiler->locals[local].depth >= depth) + { + // If the local was closed over, make sure the upvalue gets closed when it + // goes out of scope on the stack. We use emitByte() and not emitOp() here + // because we don't want to track that stack effect of these pops since the + // variables are still in scope after the break. + if (compiler->locals[local].isUpvalue) + { + emitByte(compiler, CODE_CLOSE_UPVALUE); + } + else + { + emitByte(compiler, CODE_POP); + } + + + local--; + } + + return compiler->numLocals - local - 1; +} + +// Closes the last pushed block scope and discards any local variables declared +// in that scope. This should only be called in a statement context where no +// temporaries are still on the stack. +static void popScope(Compiler* compiler) +{ + int popped = discardLocals(compiler, compiler->scopeDepth); + compiler->numLocals -= popped; + compiler->numSlots -= popped; + compiler->scopeDepth--; +} + +// Attempts to look up the name in the local variables of [compiler]. If found, +// returns its index, otherwise returns -1. +static int resolveLocal(Compiler* compiler, const char* name, int length) +{ + // Look it up in the local scopes. Look in reverse order so that the most + // nested variable is found first and shadows outer ones. + for (int i = compiler->numLocals - 1; i >= 0; i--) + { + if (compiler->locals[i].length == length && + memcmp(name, compiler->locals[i].name, length) == 0) + { + return i; + } + } + + return -1; +} + +// Adds an upvalue to [compiler]'s function with the given properties. Does not +// add one if an upvalue for that variable is already in the list. Returns the +// index of the upvalue. +static int addUpvalue(Compiler* compiler, bool isLocal, int index) +{ + // Look for an existing one. + for (int i = 0; i < compiler->fn->numUpvalues; i++) + { + CompilerUpvalue* upvalue = &compiler->upvalues[i]; + if (upvalue->index == index && upvalue->isLocal == isLocal) return i; + } + + // If we got here, it's a new upvalue. + compiler->upvalues[compiler->fn->numUpvalues].isLocal = isLocal; + compiler->upvalues[compiler->fn->numUpvalues].index = index; + return compiler->fn->numUpvalues++; +} + +// Attempts to look up [name] in the functions enclosing the one being compiled +// by [compiler]. If found, it adds an upvalue for it to this compiler's list +// of upvalues (unless it's already in there) and returns its index. If not +// found, returns -1. +// +// If the name is found outside of the immediately enclosing function, this +// will flatten the closure and add upvalues to all of the intermediate +// functions so that it gets walked down to this one. +// +// If it reaches a method boundary, this stops and returns -1 since methods do +// not close over local variables. +static int findUpvalue(Compiler* compiler, const char* name, int length) +{ + // If we are at the top level, we didn't find it. + if (compiler->parent == NULL) return -1; + + // If we hit the method boundary (and the name isn't a static field), then + // stop looking for it. We'll instead treat it as a self send. + if (name[0] != '_' && compiler->parent->enclosingClass != NULL) return -1; + + // See if it's a local variable in the immediately enclosing function. + int local = resolveLocal(compiler->parent, name, length); + if (local != -1) + { + // Mark the local as an upvalue so we know to close it when it goes out of + // scope. + compiler->parent->locals[local].isUpvalue = true; + + return addUpvalue(compiler, true, local); + } + + // See if it's an upvalue in the immediately enclosing function. In other + // words, if it's a local variable in a non-immediately enclosing function. + // This "flattens" closures automatically: it adds upvalues to all of the + // intermediate functions to get from the function where a local is declared + // all the way into the possibly deeply nested function that is closing over + // it. + int upvalue = findUpvalue(compiler->parent, name, length); + if (upvalue != -1) + { + return addUpvalue(compiler, false, upvalue); + } + + // If we got here, we walked all the way up the parent chain and couldn't + // find it. + return -1; +} + +// Look up [name] in the current scope to see what variable it refers to. +// Returns the variable either in local scope, or the enclosing function's +// upvalue list. Does not search the module scope. Returns a variable with +// index -1 if not found. +static Variable resolveNonmodule(Compiler* compiler, + const char* name, int length) +{ + // Look it up in the local scopes. + Variable variable; + variable.scope = SCOPE_LOCAL; + variable.index = resolveLocal(compiler, name, length); + if (variable.index != -1) return variable; + + // Tt's not a local, so guess that it's an upvalue. + variable.scope = SCOPE_UPVALUE; + variable.index = findUpvalue(compiler, name, length); + return variable; +} + +// Look up [name] in the current scope to see what variable it refers to. +// Returns the variable either in module scope, local scope, or the enclosing +// function's upvalue list. Returns a variable with index -1 if not found. +static Variable resolveName(Compiler* compiler, const char* name, int length) +{ + Variable variable = resolveNonmodule(compiler, name, length); + if (variable.index != -1) return variable; + + variable.scope = SCOPE_MODULE; + variable.index = wrenSymbolTableFind(&compiler->parser->module->variableNames, + name, length); + return variable; +} + +static void loadLocal(Compiler* compiler, int slot) +{ + if (slot <= 8) + { + emitOp(compiler, (Code)(CODE_LOAD_LOCAL_0 + slot)); + return; + } + + emitByteArg(compiler, CODE_LOAD_LOCAL, slot); +} + +// Finishes [compiler], which is compiling a function, method, or chunk of top +// level code. If there is a parent compiler, then this emits code in the +// parent compiler to load the resulting function. +static ObjFn* endCompiler(Compiler* compiler, + const char* debugName, int debugNameLength) +{ + // If we hit an error, don't finish the function since it's borked anyway. + if (compiler->parser->hasError) + { + compiler->parser->vm->compiler = compiler->parent; + return NULL; + } + + // Mark the end of the bytecode. Since it may contain multiple early returns, + // we can't rely on CODE_RETURN to tell us we're at the end. + emitOp(compiler, CODE_END); + + wrenFunctionBindName(compiler->parser->vm, compiler->fn, + debugName, debugNameLength); + + // In the function that contains this one, load the resulting function object. + if (compiler->parent != NULL) + { + int constant = addConstant(compiler->parent, OBJ_VAL(compiler->fn)); + + // Wrap the function in a closure. We do this even if it has no upvalues so + // that the VM can uniformly assume all called objects are closures. This + // makes creating a function a little slower, but makes invoking them + // faster. Given that functions are invoked more often than they are + // created, this is a win. + emitShortArg(compiler->parent, CODE_CLOSURE, constant); + + // Emit arguments for each upvalue to know whether to capture a local or + // an upvalue. + for (int i = 0; i < compiler->fn->numUpvalues; i++) + { + emitByte(compiler->parent, compiler->upvalues[i].isLocal ? 1 : 0); + emitByte(compiler->parent, compiler->upvalues[i].index); + } + } + + // Pop this compiler off the stack. + compiler->parser->vm->compiler = compiler->parent; + + #if WREN_DEBUG_DUMP_COMPILED_CODE + wrenDumpCode(compiler->parser->vm, compiler->fn); + #endif + + return compiler->fn; +} + +// Grammar --------------------------------------------------------------------- + +typedef enum +{ + PREC_NONE, + PREC_LOWEST, + PREC_ASSIGNMENT, // = + PREC_CONDITIONAL, // ?: + PREC_LOGICAL_OR, // || + PREC_LOGICAL_AND, // && + PREC_EQUALITY, // == != + PREC_IS, // is + PREC_COMPARISON, // < > <= >= + PREC_BITWISE_OR, // | + PREC_BITWISE_XOR, // ^ + PREC_BITWISE_AND, // & + PREC_BITWISE_SHIFT, // << >> + PREC_RANGE, // .. ... + PREC_TERM, // + - + PREC_FACTOR, // * / % + PREC_UNARY, // unary - ! ~ + PREC_CALL, // . () [] + PREC_PRIMARY +} Precedence; + +typedef void (*GrammarFn)(Compiler*, bool canAssign); + +typedef void (*SignatureFn)(Compiler* compiler, Signature* signature); + +typedef struct +{ + GrammarFn prefix; + GrammarFn infix; + SignatureFn method; + Precedence precedence; + const char* name; +} GrammarRule; + +// Forward declarations since the grammar is recursive. +static GrammarRule* getRule(TokenType type); +static void expression(Compiler* compiler); +static void statement(Compiler* compiler); +static void definition(Compiler* compiler); +static void parsePrecedence(Compiler* compiler, Precedence precedence); + +// Replaces the placeholder argument for a previous CODE_JUMP or CODE_JUMP_IF +// instruction with an offset that jumps to the current end of bytecode. +static void patchJump(Compiler* compiler, int offset) +{ + // -2 to adjust for the bytecode for the jump offset itself. + int jump = compiler->fn->code.count - offset - 2; + if (jump > MAX_JUMP) error(compiler, "Too much code to jump over."); + + compiler->fn->code.data[offset] = (jump >> 8) & 0xff; + compiler->fn->code.data[offset + 1] = jump & 0xff; +} + +// Parses a block body, after the initial "{" has been consumed. +// +// Returns true if it was a expression body, false if it was a statement body. +// (More precisely, returns true if a value was left on the stack. An empty +// block returns false.) +static bool finishBlock(Compiler* compiler) +{ + // Empty blocks do nothing. + if (match(compiler, TOKEN_RIGHT_BRACE)) return false; + + // If there's no line after the "{", it's a single-expression body. + if (!matchLine(compiler)) + { + expression(compiler); + consume(compiler, TOKEN_RIGHT_BRACE, "Expect '}' at end of block."); + return true; + } + + // Empty blocks (with just a newline inside) do nothing. + if (match(compiler, TOKEN_RIGHT_BRACE)) return false; + + // Compile the definition list. + do + { + definition(compiler); + consumeLine(compiler, "Expect newline after statement."); + } + while (peek(compiler) != TOKEN_RIGHT_BRACE && peek(compiler) != TOKEN_EOF); + + consume(compiler, TOKEN_RIGHT_BRACE, "Expect '}' at end of block."); + return false; +} + +// Parses a method or function body, after the initial "{" has been consumed. +// +// If [Compiler->isInitializer] is `true`, this is the body of a constructor +// initializer. In that case, this adds the code to ensure it returns `this`. +static void finishBody(Compiler* compiler) +{ + bool isExpressionBody = finishBlock(compiler); + + if (compiler->isInitializer) + { + // If the initializer body evaluates to a value, discard it. + if (isExpressionBody) emitOp(compiler, CODE_POP); + + // The receiver is always stored in the first local slot. + emitOp(compiler, CODE_LOAD_LOCAL_0); + } + else if (!isExpressionBody) + { + // Implicitly return null in statement bodies. + emitOp(compiler, CODE_NULL); + } + + emitOp(compiler, CODE_RETURN); +} + +// The VM can only handle a certain number of parameters, so check that we +// haven't exceeded that and give a usable error. +static void validateNumParameters(Compiler* compiler, int numArgs) +{ + if (numArgs == MAX_PARAMETERS + 1) + { + // Only show an error at exactly max + 1 so that we can keep parsing the + // parameters and minimize cascaded errors. + error(compiler, "Methods cannot have more than %d parameters.", + MAX_PARAMETERS); + } +} + +// Parses the rest of a comma-separated parameter list after the opening +// delimeter. Updates `arity` in [signature] with the number of parameters. +static void finishParameterList(Compiler* compiler, Signature* signature) +{ + do + { + ignoreNewlines(compiler); + validateNumParameters(compiler, ++signature->arity); + + // Define a local variable in the method for the parameter. + declareNamedVariable(compiler); + } + while (match(compiler, TOKEN_COMMA)); +} + +// Gets the symbol for a method [name] with [length]. +static int methodSymbol(Compiler* compiler, const char* name, int length) +{ + return wrenSymbolTableEnsure(compiler->parser->vm, + &compiler->parser->vm->methodNames, name, length); +} + +// Appends characters to [name] (and updates [length]) for [numParams] "_" +// surrounded by [leftBracket] and [rightBracket]. +static void signatureParameterList(char name[MAX_METHOD_SIGNATURE], int* length, + int numParams, char leftBracket, char rightBracket) +{ + name[(*length)++] = leftBracket; + + // This function may be called with too many parameters. When that happens, + // a compile error has already been reported, but we need to make sure we + // don't overflow the string too, hence the MAX_PARAMETERS check. + for (int i = 0; i < numParams && i < MAX_PARAMETERS; i++) + { + if (i > 0) name[(*length)++] = ','; + name[(*length)++] = '_'; + } + name[(*length)++] = rightBracket; +} + +// Fills [name] with the stringified version of [signature] and updates +// [length] to the resulting length. +static void signatureToString(Signature* signature, + char name[MAX_METHOD_SIGNATURE], int* length) +{ + *length = 0; + + // Build the full name from the signature. + memcpy(name + *length, signature->name, signature->length); + *length += signature->length; + + switch (signature->type) + { + case SIG_METHOD: + signatureParameterList(name, length, signature->arity, '(', ')'); + break; + + case SIG_GETTER: + // The signature is just the name. + break; + + case SIG_SETTER: + name[(*length)++] = '='; + signatureParameterList(name, length, 1, '(', ')'); + break; + + case SIG_SUBSCRIPT: + signatureParameterList(name, length, signature->arity, '[', ']'); + break; + + case SIG_SUBSCRIPT_SETTER: + signatureParameterList(name, length, signature->arity - 1, '[', ']'); + name[(*length)++] = '='; + signatureParameterList(name, length, 1, '(', ')'); + break; + + case SIG_INITIALIZER: + memcpy(name, "init ", 5); + memcpy(name + 5, signature->name, signature->length); + *length = 5 + signature->length; + signatureParameterList(name, length, signature->arity, '(', ')'); + break; + } + + name[*length] = '\0'; +} + +// Gets the symbol for a method with [signature]. +static int signatureSymbol(Compiler* compiler, Signature* signature) +{ + // Build the full name from the signature. + char name[MAX_METHOD_SIGNATURE]; + int length; + signatureToString(signature, name, &length); + + return methodSymbol(compiler, name, length); +} + +// Returns a signature with [type] whose name is from the last consumed token. +static Signature signatureFromToken(Compiler* compiler, SignatureType type) +{ + Signature signature; + + // Get the token for the method name. + Token* token = &compiler->parser->previous; + signature.name = token->start; + signature.length = token->length; + signature.type = type; + signature.arity = 0; + + if (signature.length > MAX_METHOD_NAME) + { + error(compiler, "Method names cannot be longer than %d characters.", + MAX_METHOD_NAME); + signature.length = MAX_METHOD_NAME; + } + + return signature; +} + +// Parses a comma-separated list of arguments. Modifies [signature] to include +// the arity of the argument list. +static void finishArgumentList(Compiler* compiler, Signature* signature) +{ + do + { + ignoreNewlines(compiler); + validateNumParameters(compiler, ++signature->arity); + expression(compiler); + } + while (match(compiler, TOKEN_COMMA)); + + // Allow a newline before the closing delimiter. + ignoreNewlines(compiler); +} + +// Compiles a method call with [signature] using [instruction]. +static void callSignature(Compiler* compiler, Code instruction, + Signature* signature) +{ + int symbol = signatureSymbol(compiler, signature); + emitShortArg(compiler, (Code)(instruction + signature->arity), symbol); + + if (instruction == CODE_SUPER_0) + { + // Super calls need to be statically bound to the class's superclass. This + // ensures we call the right method even when a method containing a super + // call is inherited by another subclass. + // + // We bind it at class definition time by storing a reference to the + // superclass in a constant. So, here, we create a slot in the constant + // table and store NULL in it. When the method is bound, we'll look up the + // superclass then and store it in the constant slot. + emitShort(compiler, addConstant(compiler, NULL_VAL)); + } +} + +// Compiles a method call with [numArgs] for a method with [name] with [length]. +static void callMethod(Compiler* compiler, int numArgs, const char* name, + int length) +{ + int symbol = methodSymbol(compiler, name, length); + emitShortArg(compiler, (Code)(CODE_CALL_0 + numArgs), symbol); +} + +// Compiles an (optional) argument list for a method call with [methodSignature] +// and then calls it. +static void methodCall(Compiler* compiler, Code instruction, + Signature* signature) +{ + // Make a new signature that contains the updated arity and type based on + // the arguments we find. + Signature called = { signature->name, signature->length, SIG_GETTER, 0 }; + + // Parse the argument list, if any. + if (match(compiler, TOKEN_LEFT_PAREN)) + { + called.type = SIG_METHOD; + + // Allow new line before an empty argument list + ignoreNewlines(compiler); + + // Allow empty an argument list. + if (peek(compiler) != TOKEN_RIGHT_PAREN) + { + finishArgumentList(compiler, &called); + } + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after arguments."); + } + + // Parse the block argument, if any. + if (match(compiler, TOKEN_LEFT_BRACE)) + { + // Include the block argument in the arity. + called.type = SIG_METHOD; + called.arity++; + + Compiler fnCompiler; + initCompiler(&fnCompiler, compiler->parser, compiler, false); + + // Make a dummy signature to track the arity. + Signature fnSignature = { "", 0, SIG_METHOD, 0 }; + + // Parse the parameter list, if any. + if (match(compiler, TOKEN_PIPE)) + { + finishParameterList(&fnCompiler, &fnSignature); + consume(compiler, TOKEN_PIPE, "Expect '|' after function parameters."); + } + + fnCompiler.fn->arity = fnSignature.arity; + + finishBody(&fnCompiler); + + // Name the function based on the method its passed to. + char blockName[MAX_METHOD_SIGNATURE + 15]; + int blockLength; + signatureToString(&called, blockName, &blockLength); + memmove(blockName + blockLength, " block argument", 16); + + endCompiler(&fnCompiler, blockName, blockLength + 15); + } + + // TODO: Allow Grace-style mixfix methods? + + // If this is a super() call for an initializer, make sure we got an actual + // argument list. + if (signature->type == SIG_INITIALIZER) + { + if (called.type != SIG_METHOD) + { + error(compiler, "A superclass constructor must have an argument list."); + } + + called.type = SIG_INITIALIZER; + } + + callSignature(compiler, instruction, &called); +} + +// Compiles a call whose name is the previously consumed token. This includes +// getters, method calls with arguments, and setter calls. +static void namedCall(Compiler* compiler, bool canAssign, Code instruction) +{ + // Get the token for the method name. + Signature signature = signatureFromToken(compiler, SIG_GETTER); + + if (canAssign && match(compiler, TOKEN_EQ)) + { + ignoreNewlines(compiler); + + // Build the setter signature. + signature.type = SIG_SETTER; + signature.arity = 1; + + // Compile the assigned value. + expression(compiler); + callSignature(compiler, instruction, &signature); + } + else + { + methodCall(compiler, instruction, &signature); + allowLineBeforeDot(compiler); + } +} + +// Emits the code to load [variable] onto the stack. +static void loadVariable(Compiler* compiler, Variable variable) +{ + switch (variable.scope) + { + case SCOPE_LOCAL: + loadLocal(compiler, variable.index); + break; + case SCOPE_UPVALUE: + emitByteArg(compiler, CODE_LOAD_UPVALUE, variable.index); + break; + case SCOPE_MODULE: + emitShortArg(compiler, CODE_LOAD_MODULE_VAR, variable.index); + break; + default: + UNREACHABLE(); + } +} + +// Loads the receiver of the currently enclosing method. Correctly handles +// functions defined inside methods. +static void loadThis(Compiler* compiler) +{ + loadVariable(compiler, resolveNonmodule(compiler, "this", 4)); +} + +// Pushes the value for a module-level variable implicitly imported from core. +static void loadCoreVariable(Compiler* compiler, const char* name) +{ + int symbol = wrenSymbolTableFind(&compiler->parser->module->variableNames, + name, strlen(name)); + ASSERT(symbol != -1, "Should have already defined core name."); + emitShortArg(compiler, CODE_LOAD_MODULE_VAR, symbol); +} + +// A parenthesized expression. +static void grouping(Compiler* compiler, bool canAssign) +{ + expression(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after expression."); +} + +// A list literal. +static void list(Compiler* compiler, bool canAssign) +{ + // Instantiate a new list. + loadCoreVariable(compiler, "List"); + callMethod(compiler, 0, "new()", 5); + + // Compile the list elements. Each one compiles to a ".add()" call. + do + { + ignoreNewlines(compiler); + + // Stop if we hit the end of the list. + if (peek(compiler) == TOKEN_RIGHT_BRACKET) break; + + // The element. + expression(compiler); + callMethod(compiler, 1, "addCore_(_)", 11); + } while (match(compiler, TOKEN_COMMA)); + + // Allow newlines before the closing ']'. + ignoreNewlines(compiler); + consume(compiler, TOKEN_RIGHT_BRACKET, "Expect ']' after list elements."); +} + +// A map literal. +static void map(Compiler* compiler, bool canAssign) +{ + // Instantiate a new map. + loadCoreVariable(compiler, "Map"); + callMethod(compiler, 0, "new()", 5); + + // Compile the map elements. Each one is compiled to just invoke the + // subscript setter on the map. + do + { + ignoreNewlines(compiler); + + // Stop if we hit the end of the map. + if (peek(compiler) == TOKEN_RIGHT_BRACE) break; + + // The key. + parsePrecedence(compiler, PREC_UNARY); + consume(compiler, TOKEN_COLON, "Expect ':' after map key."); + ignoreNewlines(compiler); + + // The value. + expression(compiler); + callMethod(compiler, 2, "addCore_(_,_)", 13); + } while (match(compiler, TOKEN_COMMA)); + + // Allow newlines before the closing '}'. + ignoreNewlines(compiler); + consume(compiler, TOKEN_RIGHT_BRACE, "Expect '}' after map entries."); +} + +// Unary operators like `-foo`. +static void unaryOp(Compiler* compiler, bool canAssign) +{ + GrammarRule* rule = getRule(compiler->parser->previous.type); + + ignoreNewlines(compiler); + + // Compile the argument. + parsePrecedence(compiler, (Precedence)(PREC_UNARY + 1)); + + // Call the operator method on the left-hand side. + callMethod(compiler, 0, rule->name, 1); +} + +static void boolean(Compiler* compiler, bool canAssign) +{ + emitOp(compiler, + compiler->parser->previous.type == TOKEN_FALSE ? CODE_FALSE : CODE_TRUE); +} + +// Walks the compiler chain to find the compiler for the nearest class +// enclosing this one. Returns NULL if not currently inside a class definition. +static Compiler* getEnclosingClassCompiler(Compiler* compiler) +{ + while (compiler != NULL) + { + if (compiler->enclosingClass != NULL) return compiler; + compiler = compiler->parent; + } + + return NULL; +} + +// Walks the compiler chain to find the nearest class enclosing this one. +// Returns NULL if not currently inside a class definition. +static ClassInfo* getEnclosingClass(Compiler* compiler) +{ + compiler = getEnclosingClassCompiler(compiler); + return compiler == NULL ? NULL : compiler->enclosingClass; +} + +static void field(Compiler* compiler, bool canAssign) +{ + // Initialize it with a fake value so we can keep parsing and minimize the + // number of cascaded errors. + int field = MAX_FIELDS; + + ClassInfo* enclosingClass = getEnclosingClass(compiler); + + if (enclosingClass == NULL) + { + error(compiler, "Cannot reference a field outside of a class definition."); + } + else if (enclosingClass->isForeign) + { + error(compiler, "Cannot define fields in a foreign class."); + } + else if (enclosingClass->inStatic) + { + error(compiler, "Cannot use an instance field in a static method."); + } + else + { + // Look up the field, or implicitly define it. + field = wrenSymbolTableEnsure(compiler->parser->vm, &enclosingClass->fields, + compiler->parser->previous.start, + compiler->parser->previous.length); + + if (field >= MAX_FIELDS) + { + error(compiler, "A class can only have %d fields.", MAX_FIELDS); + } + } + + // If there's an "=" after a field name, it's an assignment. + bool isLoad = true; + if (canAssign && match(compiler, TOKEN_EQ)) + { + // Compile the right-hand side. + expression(compiler); + isLoad = false; + } + + // If we're directly inside a method, use a more optimal instruction. + if (compiler->parent != NULL && + compiler->parent->enclosingClass == enclosingClass) + { + emitByteArg(compiler, isLoad ? CODE_LOAD_FIELD_THIS : CODE_STORE_FIELD_THIS, + field); + } + else + { + loadThis(compiler); + emitByteArg(compiler, isLoad ? CODE_LOAD_FIELD : CODE_STORE_FIELD, field); + } + + allowLineBeforeDot(compiler); +} + +// Compiles a read or assignment to [variable]. +static void bareName(Compiler* compiler, bool canAssign, Variable variable) +{ + // If there's an "=" after a bare name, it's a variable assignment. + if (canAssign && match(compiler, TOKEN_EQ)) + { + // Compile the right-hand side. + expression(compiler); + + // Emit the store instruction. + switch (variable.scope) + { + case SCOPE_LOCAL: + emitByteArg(compiler, CODE_STORE_LOCAL, variable.index); + break; + case SCOPE_UPVALUE: + emitByteArg(compiler, CODE_STORE_UPVALUE, variable.index); + break; + case SCOPE_MODULE: + emitShortArg(compiler, CODE_STORE_MODULE_VAR, variable.index); + break; + default: + UNREACHABLE(); + } + return; + } + + // Emit the load instruction. + loadVariable(compiler, variable); + + allowLineBeforeDot(compiler); +} + +static void staticField(Compiler* compiler, bool canAssign) +{ + Compiler* classCompiler = getEnclosingClassCompiler(compiler); + if (classCompiler == NULL) + { + error(compiler, "Cannot use a static field outside of a class definition."); + return; + } + + // Look up the name in the scope chain. + Token* token = &compiler->parser->previous; + + // If this is the first time we've seen this static field, implicitly + // define it as a variable in the scope surrounding the class definition. + if (resolveLocal(classCompiler, token->start, token->length) == -1) + { + int symbol = declareVariable(classCompiler, NULL); + + // Implicitly initialize it to null. + emitOp(classCompiler, CODE_NULL); + defineVariable(classCompiler, symbol); + } + + // It definitely exists now, so resolve it properly. This is different from + // the above resolveLocal() call because we may have already closed over it + // as an upvalue. + Variable variable = resolveName(compiler, token->start, token->length); + bareName(compiler, canAssign, variable); +} + +// Compiles a variable name or method call with an implicit receiver. +static void name(Compiler* compiler, bool canAssign) +{ + // Look for the name in the scope chain up to the nearest enclosing method. + Token* token = &compiler->parser->previous; + + Variable variable = resolveNonmodule(compiler, token->start, token->length); + if (variable.index != -1) + { + bareName(compiler, canAssign, variable); + return; + } + + // TODO: The fact that we return above here if the variable is known and parse + // an optional argument list below if not means that the grammar is not + // context-free. A line of code in a method like "someName(foo)" is a parse + // error if "someName" is a defined variable in the surrounding scope and not + // if it isn't. Fix this. One option is to have "someName(foo)" always + // resolve to a self-call if there is an argument list, but that makes + // getters a little confusing. + + // If we're inside a method and the name is lowercase, treat it as a method + // on this. + if (wrenIsLocalName(token->start) && getEnclosingClass(compiler) != NULL) + { + loadThis(compiler); + namedCall(compiler, canAssign, CODE_CALL_0); + return; + } + + // Otherwise, look for a module-level variable with the name. + variable.scope = SCOPE_MODULE; + variable.index = wrenSymbolTableFind(&compiler->parser->module->variableNames, + token->start, token->length); + if (variable.index == -1) + { + // Implicitly define a module-level variable in + // the hopes that we get a real definition later. + variable.index = wrenDeclareVariable(compiler->parser->vm, + compiler->parser->module, + token->start, token->length, + token->line); + + if (variable.index == -2) + { + error(compiler, "Too many module variables defined."); + } + } + + bareName(compiler, canAssign, variable); +} + +static void null(Compiler* compiler, bool canAssign) +{ + emitOp(compiler, CODE_NULL); +} + +// A number or string literal. +static void literal(Compiler* compiler, bool canAssign) +{ + emitConstant(compiler, compiler->parser->previous.value); +} + +// A string literal that contains interpolated expressions. +// +// Interpolation is syntactic sugar for calling ".join()" on a list. So the +// string: +// +// "a %(b + c) d" +// +// is compiled roughly like: +// +// ["a ", b + c, " d"].join() +static void stringInterpolation(Compiler* compiler, bool canAssign) +{ + // Instantiate a new list. + loadCoreVariable(compiler, "List"); + callMethod(compiler, 0, "new()", 5); + + do + { + // The opening string part. + literal(compiler, false); + callMethod(compiler, 1, "addCore_(_)", 11); + + // The interpolated expression. + ignoreNewlines(compiler); + expression(compiler); + callMethod(compiler, 1, "addCore_(_)", 11); + + ignoreNewlines(compiler); + } while (match(compiler, TOKEN_INTERPOLATION)); + + // The trailing string part. + consume(compiler, TOKEN_STRING, "Expect end of string interpolation."); + literal(compiler, false); + callMethod(compiler, 1, "addCore_(_)", 11); + + // The list of interpolated parts. + callMethod(compiler, 0, "join()", 6); +} + +static void super_(Compiler* compiler, bool canAssign) +{ + ClassInfo* enclosingClass = getEnclosingClass(compiler); + if (enclosingClass == NULL) + { + error(compiler, "Cannot use 'super' outside of a method."); + } + + loadThis(compiler); + + // TODO: Super operator calls. + // TODO: There's no syntax for invoking a superclass constructor with a + // different name from the enclosing one. Figure that out. + + // See if it's a named super call, or an unnamed one. + if (match(compiler, TOKEN_DOT)) + { + // Compile the superclass call. + consume(compiler, TOKEN_NAME, "Expect method name after 'super.'."); + namedCall(compiler, canAssign, CODE_SUPER_0); + } + else if (enclosingClass != NULL) + { + // No explicit name, so use the name of the enclosing method. Make sure we + // check that enclosingClass isn't NULL first. We've already reported the + // error, but we don't want to crash here. + methodCall(compiler, CODE_SUPER_0, enclosingClass->signature); + } +} + +static void this_(Compiler* compiler, bool canAssign) +{ + if (getEnclosingClass(compiler) == NULL) + { + error(compiler, "Cannot use 'this' outside of a method."); + return; + } + + loadThis(compiler); +} + +// Subscript or "array indexing" operator like `foo[bar]`. +static void subscript(Compiler* compiler, bool canAssign) +{ + Signature signature = { "", 0, SIG_SUBSCRIPT, 0 }; + + // Parse the argument list. + finishArgumentList(compiler, &signature); + consume(compiler, TOKEN_RIGHT_BRACKET, "Expect ']' after arguments."); + + allowLineBeforeDot(compiler); + + if (canAssign && match(compiler, TOKEN_EQ)) + { + signature.type = SIG_SUBSCRIPT_SETTER; + + // Compile the assigned value. + validateNumParameters(compiler, ++signature.arity); + expression(compiler); + } + + callSignature(compiler, CODE_CALL_0, &signature); +} + +static void call(Compiler* compiler, bool canAssign) +{ + ignoreNewlines(compiler); + consume(compiler, TOKEN_NAME, "Expect method name after '.'."); + namedCall(compiler, canAssign, CODE_CALL_0); +} + +static void and_(Compiler* compiler, bool canAssign) +{ + ignoreNewlines(compiler); + + // Skip the right argument if the left is false. + int jump = emitJump(compiler, CODE_AND); + parsePrecedence(compiler, PREC_LOGICAL_AND); + patchJump(compiler, jump); +} + +static void or_(Compiler* compiler, bool canAssign) +{ + ignoreNewlines(compiler); + + // Skip the right argument if the left is true. + int jump = emitJump(compiler, CODE_OR); + parsePrecedence(compiler, PREC_LOGICAL_OR); + patchJump(compiler, jump); +} + +static void conditional(Compiler* compiler, bool canAssign) +{ + // Ignore newline after '?'. + ignoreNewlines(compiler); + + // Jump to the else branch if the condition is false. + int ifJump = emitJump(compiler, CODE_JUMP_IF); + + // Compile the then branch. + parsePrecedence(compiler, PREC_CONDITIONAL); + + consume(compiler, TOKEN_COLON, + "Expect ':' after then branch of conditional operator."); + ignoreNewlines(compiler); + + // Jump over the else branch when the if branch is taken. + int elseJump = emitJump(compiler, CODE_JUMP); + + // Compile the else branch. + patchJump(compiler, ifJump); + + parsePrecedence(compiler, PREC_ASSIGNMENT); + + // Patch the jump over the else. + patchJump(compiler, elseJump); +} + +void infixOp(Compiler* compiler, bool canAssign) +{ + GrammarRule* rule = getRule(compiler->parser->previous.type); + + // An infix operator cannot end an expression. + ignoreNewlines(compiler); + + // Compile the right-hand side. + parsePrecedence(compiler, (Precedence)(rule->precedence + 1)); + + // Call the operator method on the left-hand side. + Signature signature = { rule->name, (int)strlen(rule->name), SIG_METHOD, 1 }; + callSignature(compiler, CODE_CALL_0, &signature); +} + +// Compiles a method signature for an infix operator. +void infixSignature(Compiler* compiler, Signature* signature) +{ + // Add the RHS parameter. + signature->type = SIG_METHOD; + signature->arity = 1; + + // Parse the parameter name. + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after operator name."); + declareNamedVariable(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameter name."); +} + +// Compiles a method signature for an unary operator (i.e. "!"). +void unarySignature(Compiler* compiler, Signature* signature) +{ + // Do nothing. The name is already complete. + signature->type = SIG_GETTER; +} + +// Compiles a method signature for an operator that can either be unary or +// infix (i.e. "-"). +void mixedSignature(Compiler* compiler, Signature* signature) +{ + signature->type = SIG_GETTER; + + // If there is a parameter, it's an infix operator, otherwise it's unary. + if (match(compiler, TOKEN_LEFT_PAREN)) + { + // Add the RHS parameter. + signature->type = SIG_METHOD; + signature->arity = 1; + + // Parse the parameter name. + declareNamedVariable(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameter name."); + } +} + +// Compiles an optional setter parameter in a method [signature]. +// +// Returns `true` if it was a setter. +static bool maybeSetter(Compiler* compiler, Signature* signature) +{ + // See if it's a setter. + if (!match(compiler, TOKEN_EQ)) return false; + + // It's a setter. + if (signature->type == SIG_SUBSCRIPT) + { + signature->type = SIG_SUBSCRIPT_SETTER; + } + else + { + signature->type = SIG_SETTER; + } + + // Parse the value parameter. + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after '='."); + declareNamedVariable(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameter name."); + + signature->arity++; + + return true; +} + +// Compiles a method signature for a subscript operator. +void subscriptSignature(Compiler* compiler, Signature* signature) +{ + signature->type = SIG_SUBSCRIPT; + + // The signature currently has "[" as its name since that was the token that + // matched it. Clear that out. + signature->length = 0; + + // Parse the parameters inside the subscript. + finishParameterList(compiler, signature); + consume(compiler, TOKEN_RIGHT_BRACKET, "Expect ']' after parameters."); + + maybeSetter(compiler, signature); +} + +// Parses an optional parenthesized parameter list. Updates `type` and `arity` +// in [signature] to match what was parsed. +static void parameterList(Compiler* compiler, Signature* signature) +{ + // The parameter list is optional. + if (!match(compiler, TOKEN_LEFT_PAREN)) return; + + signature->type = SIG_METHOD; + + // Allow new line before an empty argument list + ignoreNewlines(compiler); + + // Allow an empty parameter list. + if (match(compiler, TOKEN_RIGHT_PAREN)) return; + + finishParameterList(compiler, signature); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameters."); +} + +// Compiles a method signature for a named method or setter. +void namedSignature(Compiler* compiler, Signature* signature) +{ + signature->type = SIG_GETTER; + + // If it's a setter, it can't also have a parameter list. + if (maybeSetter(compiler, signature)) return; + + // Regular named method with an optional parameter list. + parameterList(compiler, signature); +} + +// Compiles a method signature for a constructor. +void constructorSignature(Compiler* compiler, Signature* signature) +{ + consume(compiler, TOKEN_NAME, "Expect constructor name after 'construct'."); + + // Capture the name. + *signature = signatureFromToken(compiler, SIG_INITIALIZER); + + if (match(compiler, TOKEN_EQ)) + { + error(compiler, "A constructor cannot be a setter."); + } + + if (!match(compiler, TOKEN_LEFT_PAREN)) + { + error(compiler, "A constructor cannot be a getter."); + return; + } + + // Allow an empty parameter list. + if (match(compiler, TOKEN_RIGHT_PAREN)) return; + + finishParameterList(compiler, signature); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameters."); +} + +// This table defines all of the parsing rules for the prefix and infix +// expressions in the grammar. Expressions are parsed using a Pratt parser. +// +// See: http://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/ +#define UNUSED { NULL, NULL, NULL, PREC_NONE, NULL } +#define PREFIX(fn) { fn, NULL, NULL, PREC_NONE, NULL } +#define INFIX(prec, fn) { NULL, fn, NULL, prec, NULL } +#define INFIX_OPERATOR(prec, name) { NULL, infixOp, infixSignature, prec, name } +#define PREFIX_OPERATOR(name) { unaryOp, NULL, unarySignature, PREC_NONE, name } +#define OPERATOR(name) { unaryOp, infixOp, mixedSignature, PREC_TERM, name } + +GrammarRule rules[] = +{ + /* TOKEN_LEFT_PAREN */ PREFIX(grouping), + /* TOKEN_RIGHT_PAREN */ UNUSED, + /* TOKEN_LEFT_BRACKET */ { list, subscript, subscriptSignature, PREC_CALL, NULL }, + /* TOKEN_RIGHT_BRACKET */ UNUSED, + /* TOKEN_LEFT_BRACE */ PREFIX(map), + /* TOKEN_RIGHT_BRACE */ UNUSED, + /* TOKEN_COLON */ UNUSED, + /* TOKEN_DOT */ INFIX(PREC_CALL, call), + /* TOKEN_DOTDOT */ INFIX_OPERATOR(PREC_RANGE, ".."), + /* TOKEN_DOTDOTDOT */ INFIX_OPERATOR(PREC_RANGE, "..."), + /* TOKEN_COMMA */ UNUSED, + /* TOKEN_STAR */ INFIX_OPERATOR(PREC_FACTOR, "*"), + /* TOKEN_SLASH */ INFIX_OPERATOR(PREC_FACTOR, "/"), + /* TOKEN_PERCENT */ INFIX_OPERATOR(PREC_FACTOR, "%"), + /* TOKEN_HASH */ UNUSED, + /* TOKEN_PLUS */ INFIX_OPERATOR(PREC_TERM, "+"), + /* TOKEN_MINUS */ OPERATOR("-"), + /* TOKEN_LTLT */ INFIX_OPERATOR(PREC_BITWISE_SHIFT, "<<"), + /* TOKEN_GTGT */ INFIX_OPERATOR(PREC_BITWISE_SHIFT, ">>"), + /* TOKEN_PIPE */ INFIX_OPERATOR(PREC_BITWISE_OR, "|"), + /* TOKEN_PIPEPIPE */ INFIX(PREC_LOGICAL_OR, or_), + /* TOKEN_CARET */ INFIX_OPERATOR(PREC_BITWISE_XOR, "^"), + /* TOKEN_AMP */ INFIX_OPERATOR(PREC_BITWISE_AND, "&"), + /* TOKEN_AMPAMP */ INFIX(PREC_LOGICAL_AND, and_), + /* TOKEN_BANG */ PREFIX_OPERATOR("!"), + /* TOKEN_TILDE */ PREFIX_OPERATOR("~"), + /* TOKEN_QUESTION */ INFIX(PREC_ASSIGNMENT, conditional), + /* TOKEN_EQ */ UNUSED, + /* TOKEN_LT */ INFIX_OPERATOR(PREC_COMPARISON, "<"), + /* TOKEN_GT */ INFIX_OPERATOR(PREC_COMPARISON, ">"), + /* TOKEN_LTEQ */ INFIX_OPERATOR(PREC_COMPARISON, "<="), + /* TOKEN_GTEQ */ INFIX_OPERATOR(PREC_COMPARISON, ">="), + /* TOKEN_EQEQ */ INFIX_OPERATOR(PREC_EQUALITY, "=="), + /* TOKEN_BANGEQ */ INFIX_OPERATOR(PREC_EQUALITY, "!="), + /* TOKEN_BREAK */ UNUSED, + /* TOKEN_CONTINUE */ UNUSED, + /* TOKEN_CLASS */ UNUSED, + /* TOKEN_CONSTRUCT */ { NULL, NULL, constructorSignature, PREC_NONE, NULL }, + /* TOKEN_ELSE */ UNUSED, + /* TOKEN_FALSE */ PREFIX(boolean), + /* TOKEN_FOR */ UNUSED, + /* TOKEN_FOREIGN */ UNUSED, + /* TOKEN_IF */ UNUSED, + /* TOKEN_IMPORT */ UNUSED, + /* TOKEN_AS */ UNUSED, + /* TOKEN_IN */ UNUSED, + /* TOKEN_IS */ INFIX_OPERATOR(PREC_IS, "is"), + /* TOKEN_NULL */ PREFIX(null), + /* TOKEN_RETURN */ UNUSED, + /* TOKEN_STATIC */ UNUSED, + /* TOKEN_SUPER */ PREFIX(super_), + /* TOKEN_THIS */ PREFIX(this_), + /* TOKEN_TRUE */ PREFIX(boolean), + /* TOKEN_VAR */ UNUSED, + /* TOKEN_WHILE */ UNUSED, + /* TOKEN_FIELD */ PREFIX(field), + /* TOKEN_STATIC_FIELD */ PREFIX(staticField), + /* TOKEN_NAME */ { name, NULL, namedSignature, PREC_NONE, NULL }, + /* TOKEN_NUMBER */ PREFIX(literal), + /* TOKEN_STRING */ PREFIX(literal), + /* TOKEN_INTERPOLATION */ PREFIX(stringInterpolation), + /* TOKEN_LINE */ UNUSED, + /* TOKEN_ERROR */ UNUSED, + /* TOKEN_EOF */ UNUSED +}; + +// Gets the [GrammarRule] associated with tokens of [type]. +static GrammarRule* getRule(TokenType type) +{ + return &rules[type]; +} + +// The main entrypoint for the top-down operator precedence parser. +void parsePrecedence(Compiler* compiler, Precedence precedence) +{ + nextToken(compiler->parser); + GrammarFn prefix = rules[compiler->parser->previous.type].prefix; + + if (prefix == NULL) + { + error(compiler, "Expected expression."); + return; + } + + // Track if the precendence of the surrounding expression is low enough to + // allow an assignment inside this one. We can't compile an assignment like + // a normal expression because it requires us to handle the LHS specially -- + // it needs to be an lvalue, not an rvalue. So, for each of the kinds of + // expressions that are valid lvalues -- names, subscripts, fields, etc. -- + // we pass in whether or not it appears in a context loose enough to allow + // "=". If so, it will parse the "=" itself and handle it appropriately. + bool canAssign = precedence <= PREC_CONDITIONAL; + prefix(compiler, canAssign); + + while (precedence <= rules[compiler->parser->current.type].precedence) + { + nextToken(compiler->parser); + GrammarFn infix = rules[compiler->parser->previous.type].infix; + infix(compiler, canAssign); + } +} + +// Parses an expression. Unlike statements, expressions leave a resulting value +// on the stack. +void expression(Compiler* compiler) +{ + parsePrecedence(compiler, PREC_LOWEST); +} + +// Returns the number of bytes for the arguments to the instruction +// at [ip] in [fn]'s bytecode. +static int getByteCountForArguments(const uint8_t* bytecode, + const Value* constants, int ip) +{ + Code instruction = (Code)bytecode[ip]; + switch (instruction) + { + case CODE_NULL: + case CODE_FALSE: + case CODE_TRUE: + case CODE_POP: + case CODE_CLOSE_UPVALUE: + case CODE_RETURN: + case CODE_END: + case CODE_LOAD_LOCAL_0: + case CODE_LOAD_LOCAL_1: + case CODE_LOAD_LOCAL_2: + case CODE_LOAD_LOCAL_3: + case CODE_LOAD_LOCAL_4: + case CODE_LOAD_LOCAL_5: + case CODE_LOAD_LOCAL_6: + case CODE_LOAD_LOCAL_7: + case CODE_LOAD_LOCAL_8: + case CODE_CONSTRUCT: + case CODE_FOREIGN_CONSTRUCT: + case CODE_FOREIGN_CLASS: + case CODE_END_MODULE: + case CODE_END_CLASS: + return 0; + + case CODE_LOAD_LOCAL: + case CODE_STORE_LOCAL: + case CODE_LOAD_UPVALUE: + case CODE_STORE_UPVALUE: + case CODE_LOAD_FIELD_THIS: + case CODE_STORE_FIELD_THIS: + case CODE_LOAD_FIELD: + case CODE_STORE_FIELD: + case CODE_CLASS: + return 1; + + case CODE_CONSTANT: + case CODE_LOAD_MODULE_VAR: + case CODE_STORE_MODULE_VAR: + case CODE_CALL_0: + case CODE_CALL_1: + case CODE_CALL_2: + case CODE_CALL_3: + case CODE_CALL_4: + case CODE_CALL_5: + case CODE_CALL_6: + case CODE_CALL_7: + case CODE_CALL_8: + case CODE_CALL_9: + case CODE_CALL_10: + case CODE_CALL_11: + case CODE_CALL_12: + case CODE_CALL_13: + case CODE_CALL_14: + case CODE_CALL_15: + case CODE_CALL_16: + case CODE_JUMP: + case CODE_LOOP: + case CODE_JUMP_IF: + case CODE_AND: + case CODE_OR: + case CODE_METHOD_INSTANCE: + case CODE_METHOD_STATIC: + case CODE_IMPORT_MODULE: + case CODE_IMPORT_VARIABLE: + return 2; + + case CODE_SUPER_0: + case CODE_SUPER_1: + case CODE_SUPER_2: + case CODE_SUPER_3: + case CODE_SUPER_4: + case CODE_SUPER_5: + case CODE_SUPER_6: + case CODE_SUPER_7: + case CODE_SUPER_8: + case CODE_SUPER_9: + case CODE_SUPER_10: + case CODE_SUPER_11: + case CODE_SUPER_12: + case CODE_SUPER_13: + case CODE_SUPER_14: + case CODE_SUPER_15: + case CODE_SUPER_16: + return 4; + + case CODE_CLOSURE: + { + int constant = (bytecode[ip + 1] << 8) | bytecode[ip + 2]; + ObjFn* loadedFn = AS_FN(constants[constant]); + + // There are two bytes for the constant, then two for each upvalue. + return 2 + (loadedFn->numUpvalues * 2); + } + } + + UNREACHABLE(); + return 0; +} + +// Marks the beginning of a loop. Keeps track of the current instruction so we +// know what to loop back to at the end of the body. +static void startLoop(Compiler* compiler, Loop* loop) +{ + loop->enclosing = compiler->loop; + loop->start = compiler->fn->code.count - 1; + loop->scopeDepth = compiler->scopeDepth; + compiler->loop = loop; +} + +// Emits the [CODE_JUMP_IF] instruction used to test the loop condition and +// potentially exit the loop. Keeps track of the instruction so we can patch it +// later once we know where the end of the body is. +static void testExitLoop(Compiler* compiler) +{ + compiler->loop->exitJump = emitJump(compiler, CODE_JUMP_IF); +} + +// Compiles the body of the loop and tracks its extent so that contained "break" +// statements can be handled correctly. +static void loopBody(Compiler* compiler) +{ + compiler->loop->body = compiler->fn->code.count; + statement(compiler); +} + +// Ends the current innermost loop. Patches up all jumps and breaks now that +// we know where the end of the loop is. +static void endLoop(Compiler* compiler) +{ + // We don't check for overflow here since the forward jump over the loop body + // will report an error for the same problem. + int loopOffset = compiler->fn->code.count - compiler->loop->start + 2; + emitShortArg(compiler, CODE_LOOP, loopOffset); + + patchJump(compiler, compiler->loop->exitJump); + + // Find any break placeholder instructions (which will be CODE_END in the + // bytecode) and replace them with real jumps. + int i = compiler->loop->body; + while (i < compiler->fn->code.count) + { + if (compiler->fn->code.data[i] == CODE_END) + { + compiler->fn->code.data[i] = CODE_JUMP; + patchJump(compiler, i + 1); + i += 3; + } + else + { + // Skip this instruction and its arguments. + i += 1 + getByteCountForArguments(compiler->fn->code.data, + compiler->fn->constants.data, i); + } + } + + compiler->loop = compiler->loop->enclosing; +} + +static void forStatement(Compiler* compiler) +{ + // A for statement like: + // + // for (i in sequence.expression) { + // System.print(i) + // } + // + // Is compiled to bytecode almost as if the source looked like this: + // + // { + // var seq_ = sequence.expression + // var iter_ + // while (iter_ = seq_.iterate(iter_)) { + // var i = seq_.iteratorValue(iter_) + // System.print(i) + // } + // } + // + // It's not exactly this, because the synthetic variables `seq_` and `iter_` + // actually get names that aren't valid Wren identfiers, but that's the basic + // idea. + // + // The important parts are: + // - The sequence expression is only evaluated once. + // - The .iterate() method is used to advance the iterator and determine if + // it should exit the loop. + // - The .iteratorValue() method is used to get the value at the current + // iterator position. + + // Create a scope for the hidden local variables used for the iterator. + pushScope(compiler); + + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after 'for'."); + consume(compiler, TOKEN_NAME, "Expect for loop variable name."); + + // Remember the name of the loop variable. + const char* name = compiler->parser->previous.start; + int length = compiler->parser->previous.length; + + consume(compiler, TOKEN_IN, "Expect 'in' after loop variable."); + ignoreNewlines(compiler); + + // Evaluate the sequence expression and store it in a hidden local variable. + // The space in the variable name ensures it won't collide with a user-defined + // variable. + expression(compiler); + + // Verify that there is space to hidden local variables. + // Note that we expect only two addLocal calls next to each other in the + // following code. + if (compiler->numLocals + 2 > MAX_LOCALS) + { + error(compiler, "Cannot declare more than %d variables in one scope. (Not enough space for for-loops internal variables)", + MAX_LOCALS); + return; + } + int seqSlot = addLocal(compiler, "seq ", 4); + + // Create another hidden local for the iterator object. + null(compiler, false); + int iterSlot = addLocal(compiler, "iter ", 5); + + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after loop expression."); + + Loop loop; + startLoop(compiler, &loop); + + // Advance the iterator by calling the ".iterate" method on the sequence. + loadLocal(compiler, seqSlot); + loadLocal(compiler, iterSlot); + + // Update and test the iterator. + callMethod(compiler, 1, "iterate(_)", 10); + emitByteArg(compiler, CODE_STORE_LOCAL, iterSlot); + testExitLoop(compiler); + + // Get the current value in the sequence by calling ".iteratorValue". + loadLocal(compiler, seqSlot); + loadLocal(compiler, iterSlot); + callMethod(compiler, 1, "iteratorValue(_)", 16); + + // Bind the loop variable in its own scope. This ensures we get a fresh + // variable each iteration so that closures for it don't all see the same one. + pushScope(compiler); + addLocal(compiler, name, length); + + loopBody(compiler); + + // Loop variable. + popScope(compiler); + + endLoop(compiler); + + // Hidden variables. + popScope(compiler); +} + +static void ifStatement(Compiler* compiler) +{ + // Compile the condition. + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after 'if'."); + expression(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after if condition."); + + // Jump to the else branch if the condition is false. + int ifJump = emitJump(compiler, CODE_JUMP_IF); + + // Compile the then branch. + statement(compiler); + + // Compile the else branch if there is one. + if (match(compiler, TOKEN_ELSE)) + { + // Jump over the else branch when the if branch is taken. + int elseJump = emitJump(compiler, CODE_JUMP); + patchJump(compiler, ifJump); + + statement(compiler); + + // Patch the jump over the else. + patchJump(compiler, elseJump); + } + else + { + patchJump(compiler, ifJump); + } +} + +static void whileStatement(Compiler* compiler) +{ + Loop loop; + startLoop(compiler, &loop); + + // Compile the condition. + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after 'while'."); + expression(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after while condition."); + + testExitLoop(compiler); + loopBody(compiler); + endLoop(compiler); +} + +// Compiles a simple statement. These can only appear at the top-level or +// within curly blocks. Simple statements exclude variable binding statements +// like "var" and "class" which are not allowed directly in places like the +// branches of an "if" statement. +// +// Unlike expressions, statements do not leave a value on the stack. +void statement(Compiler* compiler) +{ + if (match(compiler, TOKEN_BREAK)) + { + if (compiler->loop == NULL) + { + error(compiler, "Cannot use 'break' outside of a loop."); + return; + } + + // Since we will be jumping out of the scope, make sure any locals in it + // are discarded first. + discardLocals(compiler, compiler->loop->scopeDepth + 1); + + // Emit a placeholder instruction for the jump to the end of the body. When + // we're done compiling the loop body and know where the end is, we'll + // replace these with `CODE_JUMP` instructions with appropriate offsets. + // We use `CODE_END` here because that can't occur in the middle of + // bytecode. + emitJump(compiler, CODE_END); + } + else if (match(compiler, TOKEN_CONTINUE)) + { + if (compiler->loop == NULL) + { + error(compiler, "Cannot use 'continue' outside of a loop."); + return; + } + + // Since we will be jumping out of the scope, make sure any locals in it + // are discarded first. + discardLocals(compiler, compiler->loop->scopeDepth + 1); + + // emit a jump back to the top of the loop + int loopOffset = compiler->fn->code.count - compiler->loop->start + 2; + emitShortArg(compiler, CODE_LOOP, loopOffset); + } + else if (match(compiler, TOKEN_FOR)) + { + forStatement(compiler); + } + else if (match(compiler, TOKEN_IF)) + { + ifStatement(compiler); + } + else if (match(compiler, TOKEN_RETURN)) + { + // Compile the return value. + if (peek(compiler) == TOKEN_LINE) + { + // If there's no expression after return, initializers should + // return 'this' and regular methods should return null + Code result = compiler->isInitializer ? CODE_LOAD_LOCAL_0 : CODE_NULL; + emitOp(compiler, result); + } + else + { + if (compiler->isInitializer) + { + error(compiler, "A constructor cannot return a value."); + } + + expression(compiler); + } + + emitOp(compiler, CODE_RETURN); + } + else if (match(compiler, TOKEN_WHILE)) + { + whileStatement(compiler); + } + else if (match(compiler, TOKEN_LEFT_BRACE)) + { + // Block statement. + pushScope(compiler); + if (finishBlock(compiler)) + { + // Block was an expression, so discard it. + emitOp(compiler, CODE_POP); + } + popScope(compiler); + } + else + { + // Expression statement. + expression(compiler); + emitOp(compiler, CODE_POP); + } +} + +// Creates a matching constructor method for an initializer with [signature] +// and [initializerSymbol]. +// +// Construction is a two-stage process in Wren that involves two separate +// methods. There is a static method that allocates a new instance of the class. +// It then invokes an initializer method on the new instance, forwarding all of +// the constructor arguments to it. +// +// The allocator method always has a fixed implementation: +// +// CODE_CONSTRUCT - Replace the class in slot 0 with a new instance of it. +// CODE_CALL - Invoke the initializer on the new instance. +// +// This creates that method and calls the initializer with [initializerSymbol]. +static void createConstructor(Compiler* compiler, Signature* signature, + int initializerSymbol) +{ + Compiler methodCompiler; + initCompiler(&methodCompiler, compiler->parser, compiler, true); + + // Allocate the instance. + emitOp(&methodCompiler, compiler->enclosingClass->isForeign + ? CODE_FOREIGN_CONSTRUCT : CODE_CONSTRUCT); + + // Run its initializer. + emitShortArg(&methodCompiler, (Code)(CODE_CALL_0 + signature->arity), + initializerSymbol); + + // Return the instance. + emitOp(&methodCompiler, CODE_RETURN); + + endCompiler(&methodCompiler, "", 0); +} + +// Loads the enclosing class onto the stack and then binds the function already +// on the stack as a method on that class. +static void defineMethod(Compiler* compiler, Variable classVariable, + bool isStatic, int methodSymbol) +{ + // Load the class. We have to do this for each method because we can't + // keep the class on top of the stack. If there are static fields, they + // will be locals above the initial variable slot for the class on the + // stack. To skip past those, we just load the class each time right before + // defining a method. + loadVariable(compiler, classVariable); + + // Define the method. + Code instruction = isStatic ? CODE_METHOD_STATIC : CODE_METHOD_INSTANCE; + emitShortArg(compiler, instruction, methodSymbol); +} + +// Declares a method in the enclosing class with [signature]. +// +// Reports an error if a method with that signature is already declared. +// Returns the symbol for the method. +static int declareMethod(Compiler* compiler, Signature* signature, + const char* name, int length) +{ + int symbol = signatureSymbol(compiler, signature); + + // See if the class has already declared method with this signature. + ClassInfo* classInfo = compiler->enclosingClass; + IntBuffer* methods = classInfo->inStatic + ? &classInfo->staticMethods : &classInfo->methods; + for (int i = 0; i < methods->count; i++) + { + if (methods->data[i] == symbol) + { + const char* staticPrefix = classInfo->inStatic ? "static " : ""; + error(compiler, "Class %s already defines a %smethod '%s'.", + &compiler->enclosingClass->name->value, staticPrefix, name); + break; + } + } + + wrenIntBufferWrite(compiler->parser->vm, methods, symbol); + return symbol; +} + +static Value consumeLiteral(Compiler* compiler, const char* message) +{ + if(match(compiler, TOKEN_FALSE)) return FALSE_VAL; + if(match(compiler, TOKEN_TRUE)) return TRUE_VAL; + if(match(compiler, TOKEN_NUMBER)) return compiler->parser->previous.value; + if(match(compiler, TOKEN_STRING)) return compiler->parser->previous.value; + if(match(compiler, TOKEN_NAME)) return compiler->parser->previous.value; + + error(compiler, message); + nextToken(compiler->parser); + return NULL_VAL; +} + +static bool matchAttribute(Compiler* compiler) { + + if(match(compiler, TOKEN_HASH)) + { + compiler->numAttributes++; + bool runtimeAccess = match(compiler, TOKEN_BANG); + if(match(compiler, TOKEN_NAME)) + { + Value group = compiler->parser->previous.value; + TokenType ahead = peek(compiler); + if(ahead == TOKEN_EQ || ahead == TOKEN_LINE) + { + Value key = group; + Value value = NULL_VAL; + if(match(compiler, TOKEN_EQ)) + { + value = consumeLiteral(compiler, "Expect a Bool, Num, String or Identifier literal for an attribute value."); + } + if(runtimeAccess) addToAttributeGroup(compiler, NULL_VAL, key, value); + } + else if(match(compiler, TOKEN_LEFT_PAREN)) + { + ignoreNewlines(compiler); + if(match(compiler, TOKEN_RIGHT_PAREN)) + { + error(compiler, "Expected attributes in group, group cannot be empty."); + } + else + { + while(peek(compiler) != TOKEN_RIGHT_PAREN) + { + consume(compiler, TOKEN_NAME, "Expect name for attribute key."); + Value key = compiler->parser->previous.value; + Value value = NULL_VAL; + if(match(compiler, TOKEN_EQ)) + { + value = consumeLiteral(compiler, "Expect a Bool, Num, String or Identifier literal for an attribute value."); + } + if(runtimeAccess) addToAttributeGroup(compiler, group, key, value); + ignoreNewlines(compiler); + if(!match(compiler, TOKEN_COMMA)) break; + ignoreNewlines(compiler); + } + + ignoreNewlines(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, + "Expected ')' after grouped attributes."); + } + } + else + { + error(compiler, "Expect an equal, newline or grouping after an attribute key."); + } + } + else + { + error(compiler, "Expect an attribute definition after #."); + } + + consumeLine(compiler, "Expect newline after attribute."); + return true; + } + + return false; +} + +// Compiles a method definition inside a class body. +// +// Returns `true` if it compiled successfully, or `false` if the method couldn't +// be parsed. +static bool method(Compiler* compiler, Variable classVariable) +{ + // Parse any attributes before the method and store them + if(matchAttribute(compiler)) { + return method(compiler, classVariable); + } + + // TODO: What about foreign constructors? + bool isForeign = match(compiler, TOKEN_FOREIGN); + bool isStatic = match(compiler, TOKEN_STATIC); + compiler->enclosingClass->inStatic = isStatic; + + SignatureFn signatureFn = rules[compiler->parser->current.type].method; + nextToken(compiler->parser); + + if (signatureFn == NULL) + { + error(compiler, "Expect method definition."); + return false; + } + + // Build the method signature. + Signature signature = signatureFromToken(compiler, SIG_GETTER); + compiler->enclosingClass->signature = &signature; + + Compiler methodCompiler; + initCompiler(&methodCompiler, compiler->parser, compiler, true); + + // Compile the method signature. + signatureFn(&methodCompiler, &signature); + + methodCompiler.isInitializer = signature.type == SIG_INITIALIZER; + + if (isStatic && signature.type == SIG_INITIALIZER) + { + error(compiler, "A constructor cannot be static."); + } + + // Include the full signature in debug messages in stack traces. + char fullSignature[MAX_METHOD_SIGNATURE]; + int length; + signatureToString(&signature, fullSignature, &length); + + // Copy any attributes the compiler collected into the enclosing class + copyMethodAttributes(compiler, isForeign, isStatic, fullSignature, length); + + // Check for duplicate methods. Doesn't matter that it's already been + // defined, error will discard bytecode anyway. + // Check if the method table already contains this symbol + int methodSymbol = declareMethod(compiler, &signature, fullSignature, length); + + if (isForeign) + { + // Define a constant for the signature. + emitConstant(compiler, wrenNewStringLength(compiler->parser->vm, + fullSignature, length)); + + // We don't need the function we started compiling in the parameter list + // any more. + methodCompiler.parser->vm->compiler = methodCompiler.parent; + } + else + { + consume(compiler, TOKEN_LEFT_BRACE, "Expect '{' to begin method body."); + finishBody(&methodCompiler); + endCompiler(&methodCompiler, fullSignature, length); + } + + // Define the method. For a constructor, this defines the instance + // initializer method. + defineMethod(compiler, classVariable, isStatic, methodSymbol); + + if (signature.type == SIG_INITIALIZER) + { + // Also define a matching constructor method on the metaclass. + signature.type = SIG_METHOD; + int constructorSymbol = signatureSymbol(compiler, &signature); + + createConstructor(compiler, &signature, methodSymbol); + defineMethod(compiler, classVariable, true, constructorSymbol); + } + + return true; +} + +// Compiles a class definition. Assumes the "class" token has already been +// consumed (along with a possibly preceding "foreign" token). +static void classDefinition(Compiler* compiler, bool isForeign) +{ + // Create a variable to store the class in. + Variable classVariable; + classVariable.scope = compiler->scopeDepth == -1 ? SCOPE_MODULE : SCOPE_LOCAL; + classVariable.index = declareNamedVariable(compiler); + + // Create shared class name value + Value classNameString = wrenNewStringLength(compiler->parser->vm, + compiler->parser->previous.start, compiler->parser->previous.length); + + // Create class name string to track method duplicates + ObjString* className = AS_STRING(classNameString); + + // Make a string constant for the name. + emitConstant(compiler, classNameString); + + // Load the superclass (if there is one). + if (match(compiler, TOKEN_IS)) + { + parsePrecedence(compiler, PREC_CALL); + } + else + { + // Implicitly inherit from Object. + loadCoreVariable(compiler, "Object"); + } + + // Store a placeholder for the number of fields argument. We don't know the + // count until we've compiled all the methods to see which fields are used. + int numFieldsInstruction = -1; + if (isForeign) + { + emitOp(compiler, CODE_FOREIGN_CLASS); + } + else + { + numFieldsInstruction = emitByteArg(compiler, CODE_CLASS, 255); + } + + // Store it in its name. + defineVariable(compiler, classVariable.index); + + // Push a local variable scope. Static fields in a class body are hoisted out + // into local variables declared in this scope. Methods that use them will + // have upvalues referencing them. + pushScope(compiler); + + ClassInfo classInfo; + classInfo.isForeign = isForeign; + classInfo.name = className; + + // Allocate attribute maps if necessary. + // A method will allocate the methods one if needed + classInfo.classAttributes = compiler->attributes->count > 0 + ? wrenNewMap(compiler->parser->vm) + : NULL; + classInfo.methodAttributes = NULL; + // Copy any existing attributes into the class + copyAttributes(compiler, classInfo.classAttributes); + + // Set up a symbol table for the class's fields. We'll initially compile + // them to slots starting at zero. When the method is bound to the class, the + // bytecode will be adjusted by [wrenBindMethod] to take inherited fields + // into account. + wrenSymbolTableInit(&classInfo.fields); + + // Set up symbol buffers to track duplicate static and instance methods. + wrenIntBufferInit(&classInfo.methods); + wrenIntBufferInit(&classInfo.staticMethods); + compiler->enclosingClass = &classInfo; + + // Compile the method definitions. + consume(compiler, TOKEN_LEFT_BRACE, "Expect '{' after class declaration."); + matchLine(compiler); + + while (!match(compiler, TOKEN_RIGHT_BRACE)) + { + if (!method(compiler, classVariable)) break; + + // Don't require a newline after the last definition. + if (match(compiler, TOKEN_RIGHT_BRACE)) break; + + consumeLine(compiler, "Expect newline after definition in class."); + } + + // If any attributes are present, + // instantiate a ClassAttributes instance for the class + // and send it over to CODE_END_CLASS + bool hasAttr = classInfo.classAttributes != NULL || + classInfo.methodAttributes != NULL; + if(hasAttr) { + emitClassAttributes(compiler, &classInfo); + loadVariable(compiler, classVariable); + // At the moment, we don't have other uses for CODE_END_CLASS, + // so we put it inside this condition. Later, we can always + // emit it and use it as needed. + emitOp(compiler, CODE_END_CLASS); + } + + // Update the class with the number of fields. + if (!isForeign) + { + compiler->fn->code.data[numFieldsInstruction] = + (uint8_t)classInfo.fields.count; + } + + // Clear symbol tables for tracking field and method names. + wrenSymbolTableClear(compiler->parser->vm, &classInfo.fields); + wrenIntBufferClear(compiler->parser->vm, &classInfo.methods); + wrenIntBufferClear(compiler->parser->vm, &classInfo.staticMethods); + compiler->enclosingClass = NULL; + popScope(compiler); +} + +// Compiles an "import" statement. +// +// An import compiles to a series of instructions. Given: +// +// import "foo" for Bar, Baz +// +// We compile a single IMPORT_MODULE "foo" instruction to load the module +// itself. When that finishes executing the imported module, it leaves the +// ObjModule in vm->lastModule. Then, for Bar and Baz, we: +// +// * Declare a variable in the current scope with that name. +// * Emit an IMPORT_VARIABLE instruction to load the variable's value from the +// other module. +// * Compile the code to store that value in the variable in this scope. +static void import(Compiler* compiler) +{ + ignoreNewlines(compiler); + consume(compiler, TOKEN_STRING, "Expect a string after 'import'."); + int moduleConstant = addConstant(compiler, compiler->parser->previous.value); + + // Load the module. + emitShortArg(compiler, CODE_IMPORT_MODULE, moduleConstant); + + // Discard the unused result value from calling the module body's closure. + emitOp(compiler, CODE_POP); + + // The for clause is optional. + if (!match(compiler, TOKEN_FOR)) return; + + // Compile the comma-separated list of variables to import. + do + { + ignoreNewlines(compiler); + + consume(compiler, TOKEN_NAME, "Expect variable name."); + + // We need to hold onto the source variable, + // in order to reference it in the import later + Token sourceVariableToken = compiler->parser->previous; + + // Define a string constant for the original variable name. + int sourceVariableConstant = addConstant(compiler, + wrenNewStringLength(compiler->parser->vm, + sourceVariableToken.start, + sourceVariableToken.length)); + + // Store the symbol we care about for the variable + int slot = -1; + if(match(compiler, TOKEN_AS)) + { + //import "module" for Source as Dest + //Use 'Dest' as the name by declaring a new variable for it. + //This parses a name after the 'as' and defines it. + slot = declareNamedVariable(compiler); + } + else + { + //import "module" for Source + //Uses 'Source' as the name directly + slot = declareVariable(compiler, &sourceVariableToken); + } + + // Load the variable from the other module. + emitShortArg(compiler, CODE_IMPORT_VARIABLE, sourceVariableConstant); + + // Store the result in the variable here. + defineVariable(compiler, slot); + } while (match(compiler, TOKEN_COMMA)); +} + +// Compiles a "var" variable definition statement. +static void variableDefinition(Compiler* compiler) +{ + // Grab its name, but don't declare it yet. A (local) variable shouldn't be + // in scope in its own initializer. + consume(compiler, TOKEN_NAME, "Expect variable name."); + Token nameToken = compiler->parser->previous; + + // Compile the initializer. + if (match(compiler, TOKEN_EQ)) + { + ignoreNewlines(compiler); + expression(compiler); + } + else + { + // Default initialize it to null. + null(compiler, false); + } + + // Now put it in scope. + int symbol = declareVariable(compiler, &nameToken); + defineVariable(compiler, symbol); +} + +// Compiles a "definition". These are the statements that bind new variables. +// They can only appear at the top level of a block and are prohibited in places +// like the non-curly body of an if or while. +void definition(Compiler* compiler) +{ + if(matchAttribute(compiler)) { + definition(compiler); + return; + } + + if (match(compiler, TOKEN_CLASS)) + { + classDefinition(compiler, false); + return; + } + else if (match(compiler, TOKEN_FOREIGN)) + { + consume(compiler, TOKEN_CLASS, "Expect 'class' after 'foreign'."); + classDefinition(compiler, true); + return; + } + + disallowAttributes(compiler); + + if (match(compiler, TOKEN_IMPORT)) + { + import(compiler); + } + else if (match(compiler, TOKEN_VAR)) + { + variableDefinition(compiler); + } + else + { + statement(compiler); + } +} + +ObjFn* wrenCompile(WrenVM* vm, ObjModule* module, const char* source, + bool isExpression, bool printErrors) +{ + // Skip the UTF-8 BOM if there is one. + if (strncmp(source, "\xEF\xBB\xBF", 3) == 0) source += 3; + + Parser parser; + parser.vm = vm; + parser.module = module; + parser.source = source; + + parser.tokenStart = source; + parser.currentChar = source; + parser.currentLine = 1; + parser.numParens = 0; + + // Zero-init the current token. This will get copied to previous when + // nextToken() is called below. + parser.next.type = TOKEN_ERROR; + parser.next.start = source; + parser.next.length = 0; + parser.next.line = 0; + parser.next.value = UNDEFINED_VAL; + + parser.printErrors = printErrors; + parser.hasError = false; + + // Read the first token into next + nextToken(&parser); + // Copy next -> current + nextToken(&parser); + + int numExistingVariables = module->variables.count; + + Compiler compiler; + initCompiler(&compiler, &parser, NULL, false); + ignoreNewlines(&compiler); + + if (isExpression) + { + expression(&compiler); + consume(&compiler, TOKEN_EOF, "Expect end of expression."); + } + else + { + while (!match(&compiler, TOKEN_EOF)) + { + definition(&compiler); + + // If there is no newline, it must be the end of file on the same line. + if (!matchLine(&compiler)) + { + consume(&compiler, TOKEN_EOF, "Expect end of file."); + break; + } + } + + emitOp(&compiler, CODE_END_MODULE); + } + + emitOp(&compiler, CODE_RETURN); + + // See if there are any implicitly declared module-level variables that never + // got an explicit definition. They will have values that are numbers + // indicating the line where the variable was first used. + for (int i = numExistingVariables; i < parser.module->variables.count; i++) + { + if (IS_NUM(parser.module->variables.data[i])) + { + // Synthesize a token for the original use site. + parser.previous.type = TOKEN_NAME; + parser.previous.start = parser.module->variableNames.data[i]->value; + parser.previous.length = parser.module->variableNames.data[i]->length; + parser.previous.line = (int)AS_NUM(parser.module->variables.data[i]); + error(&compiler, "Variable is used but not defined."); + } + } + + return endCompiler(&compiler, "(script)", 8); +} + +void wrenBindMethodCode(ObjClass* classObj, ObjFn* fn) +{ + int ip = 0; + for (;;) + { + Code instruction = (Code)fn->code.data[ip]; + switch (instruction) + { + case CODE_LOAD_FIELD: + case CODE_STORE_FIELD: + case CODE_LOAD_FIELD_THIS: + case CODE_STORE_FIELD_THIS: + // Shift this class's fields down past the inherited ones. We don't + // check for overflow here because we'll see if the number of fields + // overflows when the subclass is created. + fn->code.data[ip + 1] += classObj->superclass->numFields; + break; + + case CODE_SUPER_0: + case CODE_SUPER_1: + case CODE_SUPER_2: + case CODE_SUPER_3: + case CODE_SUPER_4: + case CODE_SUPER_5: + case CODE_SUPER_6: + case CODE_SUPER_7: + case CODE_SUPER_8: + case CODE_SUPER_9: + case CODE_SUPER_10: + case CODE_SUPER_11: + case CODE_SUPER_12: + case CODE_SUPER_13: + case CODE_SUPER_14: + case CODE_SUPER_15: + case CODE_SUPER_16: + { + // Fill in the constant slot with a reference to the superclass. + int constant = (fn->code.data[ip + 3] << 8) | fn->code.data[ip + 4]; + fn->constants.data[constant] = OBJ_VAL(classObj->superclass); + break; + } + + case CODE_CLOSURE: + { + // Bind the nested closure too. + int constant = (fn->code.data[ip + 1] << 8) | fn->code.data[ip + 2]; + wrenBindMethodCode(classObj, AS_FN(fn->constants.data[constant])); + break; + } + + case CODE_END: + return; + + default: + // Other instructions are unaffected, so just skip over them. + break; + } + ip += 1 + getByteCountForArguments(fn->code.data, fn->constants.data, ip); + } +} + +void wrenMarkCompiler(WrenVM* vm, Compiler* compiler) +{ + wrenGrayValue(vm, compiler->parser->current.value); + wrenGrayValue(vm, compiler->parser->previous.value); + wrenGrayValue(vm, compiler->parser->next.value); + + // Walk up the parent chain to mark the outer compilers too. The VM only + // tracks the innermost one. + do + { + wrenGrayObj(vm, (Obj*)compiler->fn); + wrenGrayObj(vm, (Obj*)compiler->constants); + wrenGrayObj(vm, (Obj*)compiler->attributes); + + if (compiler->enclosingClass != NULL) + { + wrenBlackenSymbolTable(vm, &compiler->enclosingClass->fields); + + if(compiler->enclosingClass->methodAttributes != NULL) + { + wrenGrayObj(vm, (Obj*)compiler->enclosingClass->methodAttributes); + } + if(compiler->enclosingClass->classAttributes != NULL) + { + wrenGrayObj(vm, (Obj*)compiler->enclosingClass->classAttributes); + } + } + + compiler = compiler->parent; + } + while (compiler != NULL); +} + +// Helpers for Attributes + +// Throw an error if any attributes were found preceding, +// and clear the attributes so the error doesn't keep happening. +static void disallowAttributes(Compiler* compiler) +{ + if (compiler->numAttributes > 0) + { + error(compiler, "Attributes can only specified before a class or a method"); + wrenMapClear(compiler->parser->vm, compiler->attributes); + compiler->numAttributes = 0; + } +} + +// Add an attribute to a given group in the compiler attribues map +static void addToAttributeGroup(Compiler* compiler, + Value group, Value key, Value value) +{ + WrenVM* vm = compiler->parser->vm; + + if(IS_OBJ(group)) wrenPushRoot(vm, AS_OBJ(group)); + if(IS_OBJ(key)) wrenPushRoot(vm, AS_OBJ(key)); + if(IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + Value groupMapValue = wrenMapGet(compiler->attributes, group); + if(IS_UNDEFINED(groupMapValue)) + { + groupMapValue = OBJ_VAL(wrenNewMap(vm)); + wrenMapSet(vm, compiler->attributes, group, groupMapValue); + } + + //we store them as a map per so we can maintain duplicate keys + //group = { key:[value, ...], } + ObjMap* groupMap = AS_MAP(groupMapValue); + + //var keyItems = group[key] + //if(!keyItems) keyItems = group[key] = [] + Value keyItemsValue = wrenMapGet(groupMap, key); + if(IS_UNDEFINED(keyItemsValue)) + { + keyItemsValue = OBJ_VAL(wrenNewList(vm, 0)); + wrenMapSet(vm, groupMap, key, keyItemsValue); + } + + //keyItems.add(value) + ObjList* keyItems = AS_LIST(keyItemsValue); + wrenValueBufferWrite(vm, &keyItems->elements, value); + + if(IS_OBJ(group)) wrenPopRoot(vm); + if(IS_OBJ(key)) wrenPopRoot(vm); + if(IS_OBJ(value)) wrenPopRoot(vm); +} + + +// Emit the attributes in the give map onto the stack +static void emitAttributes(Compiler* compiler, ObjMap* attributes) +{ + // Instantiate a new map for the attributes + loadCoreVariable(compiler, "Map"); + callMethod(compiler, 0, "new()", 5); + + // The attributes are stored as group = { key:[value, value, ...] } + // so our first level is the group map + for(uint32_t groupIdx = 0; groupIdx < attributes->capacity; groupIdx++) + { + const MapEntry* groupEntry = &attributes->entries[groupIdx]; + if(IS_UNDEFINED(groupEntry->key)) continue; + //group key + emitConstant(compiler, groupEntry->key); + + //group value is gonna be a map + loadCoreVariable(compiler, "Map"); + callMethod(compiler, 0, "new()", 5); + + ObjMap* groupItems = AS_MAP(groupEntry->value); + for(uint32_t itemIdx = 0; itemIdx < groupItems->capacity; itemIdx++) + { + const MapEntry* itemEntry = &groupItems->entries[itemIdx]; + if(IS_UNDEFINED(itemEntry->key)) continue; + + emitConstant(compiler, itemEntry->key); + // Attribute key value, key = [] + loadCoreVariable(compiler, "List"); + callMethod(compiler, 0, "new()", 5); + // Add the items to the key list + ObjList* items = AS_LIST(itemEntry->value); + for(int itemIdx = 0; itemIdx < items->elements.count; ++itemIdx) + { + emitConstant(compiler, items->elements.data[itemIdx]); + callMethod(compiler, 1, "addCore_(_)", 11); + } + // Add the list to the map + callMethod(compiler, 2, "addCore_(_,_)", 13); + } + + // Add the key/value to the map + callMethod(compiler, 2, "addCore_(_,_)", 13); + } + +} + +// Methods are stored as method <-> attributes, so we have to have +// an indirection to resolve for methods +static void emitAttributeMethods(Compiler* compiler, ObjMap* attributes) +{ + // Instantiate a new map for the attributes + loadCoreVariable(compiler, "Map"); + callMethod(compiler, 0, "new()", 5); + + for(uint32_t methodIdx = 0; methodIdx < attributes->capacity; methodIdx++) + { + const MapEntry* methodEntry = &attributes->entries[methodIdx]; + if(IS_UNDEFINED(methodEntry->key)) continue; + emitConstant(compiler, methodEntry->key); + ObjMap* attributeMap = AS_MAP(methodEntry->value); + emitAttributes(compiler, attributeMap); + callMethod(compiler, 2, "addCore_(_,_)", 13); + } +} + + +// Emit the final ClassAttributes that exists at runtime +static void emitClassAttributes(Compiler* compiler, ClassInfo* classInfo) +{ + loadCoreVariable(compiler, "ClassAttributes"); + + classInfo->classAttributes + ? emitAttributes(compiler, classInfo->classAttributes) + : null(compiler, false); + + classInfo->methodAttributes + ? emitAttributeMethods(compiler, classInfo->methodAttributes) + : null(compiler, false); + + callMethod(compiler, 2, "new(_,_)", 8); +} + +// Copy the current attributes stored in the compiler into a destination map +// This also resets the counter, since the intent is to consume the attributes +static void copyAttributes(Compiler* compiler, ObjMap* into) +{ + compiler->numAttributes = 0; + + if(compiler->attributes->count == 0) return; + if(into == NULL) return; + + WrenVM* vm = compiler->parser->vm; + + // Note we copy the actual values as is since we'll take ownership + // and clear the original map + for(uint32_t attrIdx = 0; attrIdx < compiler->attributes->capacity; attrIdx++) + { + const MapEntry* attrEntry = &compiler->attributes->entries[attrIdx]; + if(IS_UNDEFINED(attrEntry->key)) continue; + wrenMapSet(vm, into, attrEntry->key, attrEntry->value); + } + + wrenMapClear(vm, compiler->attributes); +} + +// Copy the current attributes stored in the compiler into the method specific +// attributes for the current enclosingClass. +// This also resets the counter, since the intent is to consume the attributes +static void copyMethodAttributes(Compiler* compiler, bool isForeign, + bool isStatic, const char* fullSignature, int32_t length) +{ + compiler->numAttributes = 0; + + if(compiler->attributes->count == 0) return; + + WrenVM* vm = compiler->parser->vm; + + // Make a map for this method to copy into + ObjMap* methodAttr = wrenNewMap(vm); + wrenPushRoot(vm, (Obj*)methodAttr); + copyAttributes(compiler, methodAttr); + + // Include 'foreign static ' in front as needed + int32_t fullLength = length; + if(isForeign) fullLength += 8; + if(isStatic) fullLength += 7; + char fullSignatureWithPrefix[MAX_METHOD_SIGNATURE + 8 + 7]; + const char* foreignPrefix = isForeign ? "foreign " : ""; + const char* staticPrefix = isStatic ? "static " : ""; + sprintf(fullSignatureWithPrefix, "%s%s%.*s", foreignPrefix, staticPrefix, + length, fullSignature); + fullSignatureWithPrefix[fullLength] = '\0'; + + if(compiler->enclosingClass->methodAttributes == NULL) { + compiler->enclosingClass->methodAttributes = wrenNewMap(vm); + } + + // Store the method attributes in the class map + Value key = wrenNewStringLength(vm, fullSignatureWithPrefix, fullLength); + wrenMapSet(vm, compiler->enclosingClass->methodAttributes, key, OBJ_VAL(methodAttr)); + + wrenPopRoot(vm); +} +// End file "wren_compiler.c" +// Begin file "wren_primitive.c" +// Begin file "wren_primitive.h" +#ifndef wren_primitive_h +#define wren_primitive_h + + +// Binds a primitive method named [name] (in Wren) implemented using C function +// [fn] to `ObjClass` [cls]. +#define PRIMITIVE(cls, name, function) \ + do \ + { \ + int symbol = wrenSymbolTableEnsure(vm, \ + &vm->methodNames, name, strlen(name)); \ + Method method; \ + method.type = METHOD_PRIMITIVE; \ + method.as.primitive = prim_##function; \ + wrenBindMethod(vm, cls, symbol, method); \ + } while (false) + +// Binds a primitive method named [name] (in Wren) implemented using C function +// [fn] to `ObjClass` [cls], but as a FN call. +#define FUNCTION_CALL(cls, name, function) \ + do \ + { \ + int symbol = wrenSymbolTableEnsure(vm, \ + &vm->methodNames, name, strlen(name)); \ + Method method; \ + method.type = METHOD_FUNCTION_CALL; \ + method.as.primitive = prim_##function; \ + wrenBindMethod(vm, cls, symbol, method); \ + } while (false) + +// Defines a primitive method whose C function name is [name]. This abstracts +// the actual type signature of a primitive function and makes it clear which C +// functions are invoked as primitives. +#define DEF_PRIMITIVE(name) \ + static bool prim_##name(WrenVM* vm, Value* args) + +#define RETURN_VAL(value) \ + do \ + { \ + args[0] = value; \ + return true; \ + } while (false) + +#define RETURN_OBJ(obj) RETURN_VAL(OBJ_VAL(obj)) +#define RETURN_BOOL(value) RETURN_VAL(BOOL_VAL(value)) +#define RETURN_FALSE RETURN_VAL(FALSE_VAL) +#define RETURN_NULL RETURN_VAL(NULL_VAL) +#define RETURN_NUM(value) RETURN_VAL(NUM_VAL(value)) +#define RETURN_TRUE RETURN_VAL(TRUE_VAL) + +#define RETURN_ERROR(msg) \ + do \ + { \ + vm->fiber->error = wrenNewStringLength(vm, msg, sizeof(msg) - 1); \ + return false; \ + } while (false) + +#define RETURN_ERROR_FMT(...) \ + do \ + { \ + vm->fiber->error = wrenStringFormat(vm, __VA_ARGS__); \ + return false; \ + } while (false) + +// Validates that the given [arg] is a function. Returns true if it is. If not, +// reports an error and returns false. +bool validateFn(WrenVM* vm, Value arg, const char* argName); + +// Validates that the given [arg] is a Num. Returns true if it is. If not, +// reports an error and returns false. +bool validateNum(WrenVM* vm, Value arg, const char* argName); + +// Validates that [value] is an integer. Returns true if it is. If not, reports +// an error and returns false. +bool validateIntValue(WrenVM* vm, double value, const char* argName); + +// Validates that the given [arg] is an integer. Returns true if it is. If not, +// reports an error and returns false. +bool validateInt(WrenVM* vm, Value arg, const char* argName); + +// Validates that [arg] is a valid object for use as a map key. Returns true if +// it is. If not, reports an error and returns false. +bool validateKey(WrenVM* vm, Value arg); + +// Validates that the argument at [argIndex] is an integer within `[0, count)`. +// Also allows negative indices which map backwards from the end. Returns the +// valid positive index value. If invalid, reports an error and returns +// `UINT32_MAX`. +uint32_t validateIndex(WrenVM* vm, Value arg, uint32_t count, + const char* argName); + +// Validates that the given [arg] is a String. Returns true if it is. If not, +// reports an error and returns false. +bool validateString(WrenVM* vm, Value arg, const char* argName); + +// Given a [range] and the [length] of the object being operated on, determines +// the series of elements that should be chosen from the underlying object. +// Handles ranges that count backwards from the end as well as negative ranges. +// +// Returns the index from which the range should start or `UINT32_MAX` if the +// range is invalid. After calling, [length] will be updated with the number of +// elements in the resulting sequence. [step] will be direction that the range +// is going: `1` if the range is increasing from the start index or `-1` if the +// range is decreasing. +uint32_t calculateRange(WrenVM* vm, ObjRange* range, uint32_t* length, + int* step); + +#endif +// End file "wren_primitive.h" + +#include + +// Validates that [value] is an integer within `[0, count)`. Also allows +// negative indices which map backwards from the end. Returns the valid positive +// index value. If invalid, reports an error and returns `UINT32_MAX`. +static uint32_t validateIndexValue(WrenVM* vm, uint32_t count, double value, + const char* argName) +{ + if (!validateIntValue(vm, value, argName)) return UINT32_MAX; + + // Negative indices count from the end. + if (value < 0) value = count + value; + + // Check bounds. + if (value >= 0 && value < count) return (uint32_t)value; + + vm->fiber->error = wrenStringFormat(vm, "$ out of bounds.", argName); + return UINT32_MAX; +} + +bool validateFn(WrenVM* vm, Value arg, const char* argName) +{ + if (IS_CLOSURE(arg)) return true; + RETURN_ERROR_FMT("$ must be a function.", argName); +} + +bool validateNum(WrenVM* vm, Value arg, const char* argName) +{ + if (IS_NUM(arg)) return true; + RETURN_ERROR_FMT("$ must be a number.", argName); +} + +bool validateIntValue(WrenVM* vm, double value, const char* argName) +{ + if (trunc(value) == value) return true; + RETURN_ERROR_FMT("$ must be an integer.", argName); +} + +bool validateInt(WrenVM* vm, Value arg, const char* argName) +{ + // Make sure it's a number first. + if (!validateNum(vm, arg, argName)) return false; + return validateIntValue(vm, AS_NUM(arg), argName); +} + +bool validateKey(WrenVM* vm, Value arg) +{ + if (wrenMapIsValidKey(arg)) return true; + + RETURN_ERROR("Key must be a value type."); +} + +uint32_t validateIndex(WrenVM* vm, Value arg, uint32_t count, + const char* argName) +{ + if (!validateNum(vm, arg, argName)) return UINT32_MAX; + return validateIndexValue(vm, count, AS_NUM(arg), argName); +} + +bool validateString(WrenVM* vm, Value arg, const char* argName) +{ + if (IS_STRING(arg)) return true; + RETURN_ERROR_FMT("$ must be a string.", argName); +} + +uint32_t calculateRange(WrenVM* vm, ObjRange* range, uint32_t* length, + int* step) +{ + *step = 0; + + // Edge case: an empty range is allowed at the end of a sequence. This way, + // list[0..-1] and list[0...list.count] can be used to copy a list even when + // empty. + if (range->from == *length && + range->to == (range->isInclusive ? -1.0 : (double)*length)) + { + *length = 0; + return 0; + } + + uint32_t from = validateIndexValue(vm, *length, range->from, "Range start"); + if (from == UINT32_MAX) return UINT32_MAX; + + // Bounds check the end manually to handle exclusive ranges. + double value = range->to; + if (!validateIntValue(vm, value, "Range end")) return UINT32_MAX; + + // Negative indices count from the end. + if (value < 0) value = *length + value; + + // Convert the exclusive range to an inclusive one. + if (!range->isInclusive) + { + // An exclusive range with the same start and end points is empty. + if (value == from) + { + *length = 0; + return from; + } + + // Shift the endpoint to make it inclusive, handling both increasing and + // decreasing ranges. + value += value >= from ? -1 : 1; + } + + // Check bounds. + if (value < 0 || value >= *length) + { + vm->fiber->error = CONST_STRING(vm, "Range end out of bounds."); + return UINT32_MAX; + } + + uint32_t to = (uint32_t)value; + *length = abs((int)(from - to)) + 1; + *step = from < to ? 1 : -1; + return from; +} +// End file "wren_primitive.c" +// Begin file "wren_core.c" +#include +#include +#include +#include +#include +#include + +// Begin file "wren_core.h" +#ifndef wren_core_h +#define wren_core_h + + +// This module defines the built-in classes and their primitives methods that +// are implemented directly in C code. Some languages try to implement as much +// of the core module itself in the primary language instead of in the host +// language. +// +// With Wren, we try to do as much of it in C as possible. Primitive methods +// are always faster than code written in Wren, and it minimizes startup time +// since we don't have to parse, compile, and execute Wren code. +// +// There is one limitation, though. Methods written in C cannot call Wren ones. +// They can only be the top of the callstack, and immediately return. This +// makes it difficult to have primitive methods that rely on polymorphic +// behavior. For example, `System.print` should call `toString` on its argument, +// including user-defined `toString` methods on user-defined classes. + +void wrenInitializeCore(WrenVM* vm); + +#endif +// End file "wren_core.h" + +// Begin file "wren_core.wren.inc" +// Generated automatically from src/vm/wren_core.wren. Do not edit. +static const char* coreModuleSource = +"class Bool {}\n" +"class Fiber {}\n" +"class Fn {}\n" +"class Null {}\n" +"class Num {}\n" +"\n" +"class Sequence {\n" +" all(f) {\n" +" var result = true\n" +" for (element in this) {\n" +" result = f.call(element)\n" +" if (!result) return result\n" +" }\n" +" return result\n" +" }\n" +"\n" +" any(f) {\n" +" var result = false\n" +" for (element in this) {\n" +" result = f.call(element)\n" +" if (result) return result\n" +" }\n" +" return result\n" +" }\n" +"\n" +" contains(element) {\n" +" for (item in this) {\n" +" if (element == item) return true\n" +" }\n" +" return false\n" +" }\n" +"\n" +" count {\n" +" var result = 0\n" +" for (element in this) {\n" +" result = result + 1\n" +" }\n" +" return result\n" +" }\n" +"\n" +" count(f) {\n" +" var result = 0\n" +" for (element in this) {\n" +" if (f.call(element)) result = result + 1\n" +" }\n" +" return result\n" +" }\n" +"\n" +" each(f) {\n" +" for (element in this) {\n" +" f.call(element)\n" +" }\n" +" }\n" +"\n" +" isEmpty { iterate(null) ? false : true }\n" +"\n" +" map(transformation) { MapSequence.new(this, transformation) }\n" +"\n" +" skip(count) {\n" +" if (!(count is Num) || !count.isInteger || count < 0) {\n" +" Fiber.abort(\"Count must be a non-negative integer.\")\n" +" }\n" +"\n" +" return SkipSequence.new(this, count)\n" +" }\n" +"\n" +" take(count) {\n" +" if (!(count is Num) || !count.isInteger || count < 0) {\n" +" Fiber.abort(\"Count must be a non-negative integer.\")\n" +" }\n" +"\n" +" return TakeSequence.new(this, count)\n" +" }\n" +"\n" +" where(predicate) { WhereSequence.new(this, predicate) }\n" +"\n" +" reduce(acc, f) {\n" +" for (element in this) {\n" +" acc = f.call(acc, element)\n" +" }\n" +" return acc\n" +" }\n" +"\n" +" reduce(f) {\n" +" var iter = iterate(null)\n" +" if (!iter) Fiber.abort(\"Can't reduce an empty sequence.\")\n" +"\n" +" // Seed with the first element.\n" +" var result = iteratorValue(iter)\n" +" while (iter = iterate(iter)) {\n" +" result = f.call(result, iteratorValue(iter))\n" +" }\n" +"\n" +" return result\n" +" }\n" +"\n" +" join() { join(\"\") }\n" +"\n" +" join(sep) {\n" +" var first = true\n" +" var result = \"\"\n" +"\n" +" for (element in this) {\n" +" if (!first) result = result + sep\n" +" first = false\n" +" result = result + element.toString\n" +" }\n" +"\n" +" return result\n" +" }\n" +"\n" +" toList {\n" +" var result = List.new()\n" +" for (element in this) {\n" +" result.add(element)\n" +" }\n" +" return result\n" +" }\n" +"}\n" +"\n" +"class MapSequence is Sequence {\n" +" construct new(sequence, fn) {\n" +" _sequence = sequence\n" +" _fn = fn\n" +" }\n" +"\n" +" iterate(iterator) { _sequence.iterate(iterator) }\n" +" iteratorValue(iterator) { _fn.call(_sequence.iteratorValue(iterator)) }\n" +"}\n" +"\n" +"class SkipSequence is Sequence {\n" +" construct new(sequence, count) {\n" +" _sequence = sequence\n" +" _count = count\n" +" }\n" +"\n" +" iterate(iterator) {\n" +" if (iterator) {\n" +" return _sequence.iterate(iterator)\n" +" } else {\n" +" iterator = _sequence.iterate(iterator)\n" +" var count = _count\n" +" while (count > 0 && iterator) {\n" +" iterator = _sequence.iterate(iterator)\n" +" count = count - 1\n" +" }\n" +" return iterator\n" +" }\n" +" }\n" +"\n" +" iteratorValue(iterator) { _sequence.iteratorValue(iterator) }\n" +"}\n" +"\n" +"class TakeSequence is Sequence {\n" +" construct new(sequence, count) {\n" +" _sequence = sequence\n" +" _count = count\n" +" }\n" +"\n" +" iterate(iterator) {\n" +" if (!iterator) _taken = 1 else _taken = _taken + 1\n" +" return _taken > _count ? null : _sequence.iterate(iterator)\n" +" }\n" +"\n" +" iteratorValue(iterator) { _sequence.iteratorValue(iterator) }\n" +"}\n" +"\n" +"class WhereSequence is Sequence {\n" +" construct new(sequence, fn) {\n" +" _sequence = sequence\n" +" _fn = fn\n" +" }\n" +"\n" +" iterate(iterator) {\n" +" while (iterator = _sequence.iterate(iterator)) {\n" +" if (_fn.call(_sequence.iteratorValue(iterator))) break\n" +" }\n" +" return iterator\n" +" }\n" +"\n" +" iteratorValue(iterator) { _sequence.iteratorValue(iterator) }\n" +"}\n" +"\n" +"class String is Sequence {\n" +" bytes { StringByteSequence.new(this) }\n" +" codePoints { StringCodePointSequence.new(this) }\n" +"\n" +" split(delimiter) {\n" +" if (!(delimiter is String) || delimiter.isEmpty) {\n" +" Fiber.abort(\"Delimiter must be a non-empty string.\")\n" +" }\n" +"\n" +" var result = []\n" +"\n" +" var last = 0\n" +" var index = 0\n" +"\n" +" var delimSize = delimiter.byteCount_\n" +" var size = byteCount_\n" +"\n" +" while (last < size && (index = indexOf(delimiter, last)) != -1) {\n" +" result.add(this[last...index])\n" +" last = index + delimSize\n" +" }\n" +"\n" +" if (last < size) {\n" +" result.add(this[last..-1])\n" +" } else {\n" +" result.add(\"\")\n" +" }\n" +" return result\n" +" }\n" +"\n" +" replace(from, to) {\n" +" if (!(from is String) || from.isEmpty) {\n" +" Fiber.abort(\"From must be a non-empty string.\")\n" +" } else if (!(to is String)) {\n" +" Fiber.abort(\"To must be a string.\")\n" +" }\n" +"\n" +" var result = \"\"\n" +"\n" +" var last = 0\n" +" var index = 0\n" +"\n" +" var fromSize = from.byteCount_\n" +" var size = byteCount_\n" +"\n" +" while (last < size && (index = indexOf(from, last)) != -1) {\n" +" result = result + this[last...index] + to\n" +" last = index + fromSize\n" +" }\n" +"\n" +" if (last < size) result = result + this[last..-1]\n" +"\n" +" return result\n" +" }\n" +"\n" +" trim() { trim_(\"\\t\\r\\n \", true, true) }\n" +" trim(chars) { trim_(chars, true, true) }\n" +" trimEnd() { trim_(\"\\t\\r\\n \", false, true) }\n" +" trimEnd(chars) { trim_(chars, false, true) }\n" +" trimStart() { trim_(\"\\t\\r\\n \", true, false) }\n" +" trimStart(chars) { trim_(chars, true, false) }\n" +"\n" +" trim_(chars, trimStart, trimEnd) {\n" +" if (!(chars is String)) {\n" +" Fiber.abort(\"Characters must be a string.\")\n" +" }\n" +"\n" +" var codePoints = chars.codePoints.toList\n" +"\n" +" var start\n" +" if (trimStart) {\n" +" while (start = iterate(start)) {\n" +" if (!codePoints.contains(codePointAt_(start))) break\n" +" }\n" +"\n" +" if (start == false) return \"\"\n" +" } else {\n" +" start = 0\n" +" }\n" +"\n" +" var end\n" +" if (trimEnd) {\n" +" end = byteCount_ - 1\n" +" while (end >= start) {\n" +" var codePoint = codePointAt_(end)\n" +" if (codePoint != -1 && !codePoints.contains(codePoint)) break\n" +" end = end - 1\n" +" }\n" +"\n" +" if (end < start) return \"\"\n" +" } else {\n" +" end = -1\n" +" }\n" +"\n" +" return this[start..end]\n" +" }\n" +"\n" +" *(count) {\n" +" if (!(count is Num) || !count.isInteger || count < 0) {\n" +" Fiber.abort(\"Count must be a non-negative integer.\")\n" +" }\n" +"\n" +" var result = \"\"\n" +" for (i in 0...count) {\n" +" result = result + this\n" +" }\n" +" return result\n" +" }\n" +"}\n" +"\n" +"class StringByteSequence is Sequence {\n" +" construct new(string) {\n" +" _string = string\n" +" }\n" +"\n" +" [index] { _string.byteAt_(index) }\n" +" iterate(iterator) { _string.iterateByte_(iterator) }\n" +" iteratorValue(iterator) { _string.byteAt_(iterator) }\n" +"\n" +" count { _string.byteCount_ }\n" +"}\n" +"\n" +"class StringCodePointSequence is Sequence {\n" +" construct new(string) {\n" +" _string = string\n" +" }\n" +"\n" +" [index] { _string.codePointAt_(index) }\n" +" iterate(iterator) { _string.iterate(iterator) }\n" +" iteratorValue(iterator) { _string.codePointAt_(iterator) }\n" +"\n" +" count { _string.count }\n" +"}\n" +"\n" +"class List is Sequence {\n" +" addAll(other) {\n" +" for (element in other) {\n" +" add(element)\n" +" }\n" +" return other\n" +" }\n" +"\n" +" sort() { sort {|low, high| low < high } }\n" +"\n" +" sort(comparer) {\n" +" if (!(comparer is Fn)) {\n" +" Fiber.abort(\"Comparer must be a function.\")\n" +" }\n" +" quicksort_(0, count - 1, comparer)\n" +" return this\n" +" }\n" +"\n" +" quicksort_(low, high, comparer) {\n" +" if (low < high) {\n" +" var p = partition_(low, high, comparer)\n" +" quicksort_(low, p - 1, comparer)\n" +" quicksort_(p + 1, high, comparer)\n" +" }\n" +" }\n" +"\n" +" partition_(low, high, comparer) {\n" +" var p = this[high]\n" +" var i = low - 1\n" +" for (j in low..(high-1)) {\n" +" if (comparer.call(this[j], p)) { \n" +" i = i + 1\n" +" var t = this[i]\n" +" this[i] = this[j]\n" +" this[j] = t\n" +" }\n" +" }\n" +" var t = this[i+1]\n" +" this[i+1] = this[high]\n" +" this[high] = t\n" +" return i+1\n" +" }\n" +"\n" +" toString { \"[%(join(\", \"))]\" }\n" +"\n" +" +(other) {\n" +" var result = this[0..-1]\n" +" for (element in other) {\n" +" result.add(element)\n" +" }\n" +" return result\n" +" }\n" +"\n" +" *(count) {\n" +" if (!(count is Num) || !count.isInteger || count < 0) {\n" +" Fiber.abort(\"Count must be a non-negative integer.\")\n" +" }\n" +"\n" +" var result = []\n" +" for (i in 0...count) {\n" +" result.addAll(this)\n" +" }\n" +" return result\n" +" }\n" +"}\n" +"\n" +"class Map is Sequence {\n" +" keys { MapKeySequence.new(this) }\n" +" values { MapValueSequence.new(this) }\n" +"\n" +" toString {\n" +" var first = true\n" +" var result = \"{\"\n" +"\n" +" for (key in keys) {\n" +" if (!first) result = result + \", \"\n" +" first = false\n" +" result = result + \"%(key): %(this[key])\"\n" +" }\n" +"\n" +" return result + \"}\"\n" +" }\n" +"\n" +" iteratorValue(iterator) {\n" +" return MapEntry.new(\n" +" keyIteratorValue_(iterator),\n" +" valueIteratorValue_(iterator))\n" +" }\n" +"}\n" +"\n" +"class MapEntry {\n" +" construct new(key, value) {\n" +" _key = key\n" +" _value = value\n" +" }\n" +"\n" +" key { _key }\n" +" value { _value }\n" +"\n" +" toString { \"%(_key):%(_value)\" }\n" +"}\n" +"\n" +"class MapKeySequence is Sequence {\n" +" construct new(map) {\n" +" _map = map\n" +" }\n" +"\n" +" iterate(n) { _map.iterate(n) }\n" +" iteratorValue(iterator) { _map.keyIteratorValue_(iterator) }\n" +"}\n" +"\n" +"class MapValueSequence is Sequence {\n" +" construct new(map) {\n" +" _map = map\n" +" }\n" +"\n" +" iterate(n) { _map.iterate(n) }\n" +" iteratorValue(iterator) { _map.valueIteratorValue_(iterator) }\n" +"}\n" +"\n" +"class Range is Sequence {}\n" +"\n" +"class System {\n" +" static print() {\n" +" writeString_(\"\\n\")\n" +" }\n" +"\n" +" static print(obj) {\n" +" writeObject_(obj)\n" +" writeString_(\"\\n\")\n" +" return obj\n" +" }\n" +"\n" +" static printAll(sequence) {\n" +" for (object in sequence) writeObject_(object)\n" +" writeString_(\"\\n\")\n" +" }\n" +"\n" +" static write(obj) {\n" +" writeObject_(obj)\n" +" return obj\n" +" }\n" +"\n" +" static writeAll(sequence) {\n" +" for (object in sequence) writeObject_(object)\n" +" }\n" +"\n" +" static writeObject_(obj) {\n" +" var string = obj.toString\n" +" if (string is String) {\n" +" writeString_(string)\n" +" } else {\n" +" writeString_(\"[invalid toString]\")\n" +" }\n" +" }\n" +"}\n" +"\n" +"class ClassAttributes {\n" +" self { _attributes }\n" +" methods { _methods }\n" +" construct new(attributes, methods) {\n" +" _attributes = attributes\n" +" _methods = methods\n" +" }\n" +" toString { \"attributes:%(_attributes) methods:%(_methods)\" }\n" +"}\n"; +// End file "wren_core.wren.inc" + +DEF_PRIMITIVE(bool_not) +{ + RETURN_BOOL(!AS_BOOL(args[0])); +} + +DEF_PRIMITIVE(bool_toString) +{ + if (AS_BOOL(args[0])) + { + RETURN_VAL(CONST_STRING(vm, "true")); + } + else + { + RETURN_VAL(CONST_STRING(vm, "false")); + } +} + +DEF_PRIMITIVE(class_name) +{ + RETURN_OBJ(AS_CLASS(args[0])->name); +} + +DEF_PRIMITIVE(class_supertype) +{ + ObjClass* classObj = AS_CLASS(args[0]); + + // Object has no superclass. + if (classObj->superclass == NULL) RETURN_NULL; + + RETURN_OBJ(classObj->superclass); +} + +DEF_PRIMITIVE(class_toString) +{ + RETURN_OBJ(AS_CLASS(args[0])->name); +} + +DEF_PRIMITIVE(class_attributes) +{ + RETURN_VAL(AS_CLASS(args[0])->attributes); +} + +DEF_PRIMITIVE(fiber_new) +{ + if (!validateFn(vm, args[1], "Argument")) return false; + + ObjClosure* closure = AS_CLOSURE(args[1]); + if (closure->fn->arity > 1) + { + RETURN_ERROR("Function cannot take more than one parameter."); + } + + RETURN_OBJ(wrenNewFiber(vm, closure)); +} + +DEF_PRIMITIVE(fiber_abort) +{ + vm->fiber->error = args[1]; + + // If the error is explicitly null, it's not really an abort. + return IS_NULL(args[1]); +} + +// Transfer execution to [fiber] coming from the current fiber whose stack has +// [args]. +// +// [isCall] is true if [fiber] is being called and not transferred. +// +// [hasValue] is true if a value in [args] is being passed to the new fiber. +// Otherwise, `null` is implicitly being passed. +static bool runFiber(WrenVM* vm, ObjFiber* fiber, Value* args, bool isCall, + bool hasValue, const char* verb) +{ + + if (wrenHasError(fiber)) + { + RETURN_ERROR_FMT("Cannot $ an aborted fiber.", verb); + } + + if (isCall) + { + // You can't call a called fiber, but you can transfer directly to it, + // which is why this check is gated on `isCall`. This way, after resuming a + // suspended fiber, it will run and then return to the fiber that called it + // and so on. + if (fiber->caller != NULL) RETURN_ERROR("Fiber has already been called."); + + if (fiber->state == FIBER_ROOT) RETURN_ERROR("Cannot call root fiber."); + + // Remember who ran it. + fiber->caller = vm->fiber; + } + + if (fiber->numFrames == 0) + { + RETURN_ERROR_FMT("Cannot $ a finished fiber.", verb); + } + + // When the calling fiber resumes, we'll store the result of the call in its + // stack. If the call has two arguments (the fiber and the value), we only + // need one slot for the result, so discard the other slot now. + if (hasValue) vm->fiber->stackTop--; + + if (fiber->numFrames == 1 && + fiber->frames[0].ip == fiber->frames[0].closure->fn->code.data) + { + // The fiber is being started for the first time. If its function takes a + // parameter, bind an argument to it. + if (fiber->frames[0].closure->fn->arity == 1) + { + fiber->stackTop[0] = hasValue ? args[1] : NULL_VAL; + fiber->stackTop++; + } + } + else + { + // The fiber is being resumed, make yield() or transfer() return the result. + fiber->stackTop[-1] = hasValue ? args[1] : NULL_VAL; + } + + vm->fiber = fiber; + return false; +} + +DEF_PRIMITIVE(fiber_call) +{ + return runFiber(vm, AS_FIBER(args[0]), args, true, false, "call"); +} + +DEF_PRIMITIVE(fiber_call1) +{ + return runFiber(vm, AS_FIBER(args[0]), args, true, true, "call"); +} + +DEF_PRIMITIVE(fiber_current) +{ + RETURN_OBJ(vm->fiber); +} + +DEF_PRIMITIVE(fiber_error) +{ + RETURN_VAL(AS_FIBER(args[0])->error); +} + +DEF_PRIMITIVE(fiber_isDone) +{ + ObjFiber* runFiber = AS_FIBER(args[0]); + RETURN_BOOL(runFiber->numFrames == 0 || wrenHasError(runFiber)); +} + +DEF_PRIMITIVE(fiber_suspend) +{ + // Switching to a null fiber tells the interpreter to stop and exit. + vm->fiber = NULL; + vm->apiStack = NULL; + return false; +} + +DEF_PRIMITIVE(fiber_transfer) +{ + return runFiber(vm, AS_FIBER(args[0]), args, false, false, "transfer to"); +} + +DEF_PRIMITIVE(fiber_transfer1) +{ + return runFiber(vm, AS_FIBER(args[0]), args, false, true, "transfer to"); +} + +DEF_PRIMITIVE(fiber_transferError) +{ + runFiber(vm, AS_FIBER(args[0]), args, false, true, "transfer to"); + vm->fiber->error = args[1]; + return false; +} + +DEF_PRIMITIVE(fiber_try) +{ + runFiber(vm, AS_FIBER(args[0]), args, true, false, "try"); + + // If we're switching to a valid fiber to try, remember that we're trying it. + if (!wrenHasError(vm->fiber)) vm->fiber->state = FIBER_TRY; + return false; +} + +DEF_PRIMITIVE(fiber_try1) +{ + runFiber(vm, AS_FIBER(args[0]), args, true, true, "try"); + + // If we're switching to a valid fiber to try, remember that we're trying it. + if (!wrenHasError(vm->fiber)) vm->fiber->state = FIBER_TRY; + return false; +} + +DEF_PRIMITIVE(fiber_yield) +{ + ObjFiber* current = vm->fiber; + vm->fiber = current->caller; + + // Unhook this fiber from the one that called it. + current->caller = NULL; + current->state = FIBER_OTHER; + + if (vm->fiber != NULL) + { + // Make the caller's run method return null. + vm->fiber->stackTop[-1] = NULL_VAL; + } + + return false; +} + +DEF_PRIMITIVE(fiber_yield1) +{ + ObjFiber* current = vm->fiber; + vm->fiber = current->caller; + + // Unhook this fiber from the one that called it. + current->caller = NULL; + current->state = FIBER_OTHER; + + if (vm->fiber != NULL) + { + // Make the caller's run method return the argument passed to yield. + vm->fiber->stackTop[-1] = args[1]; + + // When the yielding fiber resumes, we'll store the result of the yield + // call in its stack. Since Fiber.yield(value) has two arguments (the Fiber + // class and the value) and we only need one slot for the result, discard + // the other slot now. + current->stackTop--; + } + + return false; +} + +DEF_PRIMITIVE(fn_new) +{ + if (!validateFn(vm, args[1], "Argument")) return false; + + // The block argument is already a function, so just return it. + RETURN_VAL(args[1]); +} + +DEF_PRIMITIVE(fn_arity) +{ + RETURN_NUM(AS_CLOSURE(args[0])->fn->arity); +} + +static void call_fn(WrenVM* vm, Value* args, int numArgs) +{ + // +1 to include the function itself. + wrenCallFunction(vm, vm->fiber, AS_CLOSURE(args[0]), numArgs + 1); +} + +#define DEF_FN_CALL(numArgs) \ + DEF_PRIMITIVE(fn_call##numArgs) \ + { \ + call_fn(vm, args, numArgs); \ + return false; \ + } + +DEF_FN_CALL(0) +DEF_FN_CALL(1) +DEF_FN_CALL(2) +DEF_FN_CALL(3) +DEF_FN_CALL(4) +DEF_FN_CALL(5) +DEF_FN_CALL(6) +DEF_FN_CALL(7) +DEF_FN_CALL(8) +DEF_FN_CALL(9) +DEF_FN_CALL(10) +DEF_FN_CALL(11) +DEF_FN_CALL(12) +DEF_FN_CALL(13) +DEF_FN_CALL(14) +DEF_FN_CALL(15) +DEF_FN_CALL(16) + +DEF_PRIMITIVE(fn_toString) +{ + RETURN_VAL(CONST_STRING(vm, "")); +} + +// Creates a new list of size args[1], with all elements initialized to args[2]. +DEF_PRIMITIVE(list_filled) +{ + if (!validateInt(vm, args[1], "Size")) return false; + if (AS_NUM(args[1]) < 0) RETURN_ERROR("Size cannot be negative."); + + uint32_t size = (uint32_t)AS_NUM(args[1]); + ObjList* list = wrenNewList(vm, size); + + for (uint32_t i = 0; i < size; i++) + { + list->elements.data[i] = args[2]; + } + + RETURN_OBJ(list); +} + +DEF_PRIMITIVE(list_new) +{ + RETURN_OBJ(wrenNewList(vm, 0)); +} + +DEF_PRIMITIVE(list_add) +{ + wrenValueBufferWrite(vm, &AS_LIST(args[0])->elements, args[1]); + RETURN_VAL(args[1]); +} + +// Adds an element to the list and then returns the list itself. This is called +// by the compiler when compiling list literals instead of using add() to +// minimize stack churn. +DEF_PRIMITIVE(list_addCore) +{ + wrenValueBufferWrite(vm, &AS_LIST(args[0])->elements, args[1]); + + // Return the list. + RETURN_VAL(args[0]); +} + +DEF_PRIMITIVE(list_clear) +{ + wrenValueBufferClear(vm, &AS_LIST(args[0])->elements); + RETURN_NULL; +} + +DEF_PRIMITIVE(list_count) +{ + RETURN_NUM(AS_LIST(args[0])->elements.count); +} + +DEF_PRIMITIVE(list_insert) +{ + ObjList* list = AS_LIST(args[0]); + + // count + 1 here so you can "insert" at the very end. + uint32_t index = validateIndex(vm, args[1], list->elements.count + 1, + "Index"); + if (index == UINT32_MAX) return false; + + wrenListInsert(vm, list, args[2], index); + RETURN_VAL(args[2]); +} + +DEF_PRIMITIVE(list_iterate) +{ + ObjList* list = AS_LIST(args[0]); + + // If we're starting the iteration, return the first index. + if (IS_NULL(args[1])) + { + if (list->elements.count == 0) RETURN_FALSE; + RETURN_NUM(0); + } + + if (!validateInt(vm, args[1], "Iterator")) return false; + + // Stop if we're out of bounds. + double index = AS_NUM(args[1]); + if (index < 0 || index >= list->elements.count - 1) RETURN_FALSE; + + // Otherwise, move to the next index. + RETURN_NUM(index + 1); +} + +DEF_PRIMITIVE(list_iteratorValue) +{ + ObjList* list = AS_LIST(args[0]); + uint32_t index = validateIndex(vm, args[1], list->elements.count, "Iterator"); + if (index == UINT32_MAX) return false; + + RETURN_VAL(list->elements.data[index]); +} + +DEF_PRIMITIVE(list_removeAt) +{ + ObjList* list = AS_LIST(args[0]); + uint32_t index = validateIndex(vm, args[1], list->elements.count, "Index"); + if (index == UINT32_MAX) return false; + + RETURN_VAL(wrenListRemoveAt(vm, list, index)); +} + +DEF_PRIMITIVE(list_removeValue) { + ObjList* list = AS_LIST(args[0]); + int index = wrenListIndexOf(vm, list, args[1]); + if(index == -1) RETURN_NULL; + RETURN_VAL(wrenListRemoveAt(vm, list, index)); +} + +DEF_PRIMITIVE(list_indexOf) +{ + ObjList* list = AS_LIST(args[0]); + RETURN_NUM(wrenListIndexOf(vm, list, args[1])); +} + +DEF_PRIMITIVE(list_swap) +{ + ObjList* list = AS_LIST(args[0]); + uint32_t indexA = validateIndex(vm, args[1], list->elements.count, "Index 0"); + if (indexA == UINT32_MAX) return false; + uint32_t indexB = validateIndex(vm, args[2], list->elements.count, "Index 1"); + if (indexB == UINT32_MAX) return false; + + Value a = list->elements.data[indexA]; + list->elements.data[indexA] = list->elements.data[indexB]; + list->elements.data[indexB] = a; + + RETURN_NULL; +} + +DEF_PRIMITIVE(list_subscript) +{ + ObjList* list = AS_LIST(args[0]); + + if (IS_NUM(args[1])) + { + uint32_t index = validateIndex(vm, args[1], list->elements.count, + "Subscript"); + if (index == UINT32_MAX) return false; + + RETURN_VAL(list->elements.data[index]); + } + + if (!IS_RANGE(args[1])) + { + RETURN_ERROR("Subscript must be a number or a range."); + } + + int step; + uint32_t count = list->elements.count; + uint32_t start = calculateRange(vm, AS_RANGE(args[1]), &count, &step); + if (start == UINT32_MAX) return false; + + ObjList* result = wrenNewList(vm, count); + for (uint32_t i = 0; i < count; i++) + { + result->elements.data[i] = list->elements.data[start + i * step]; + } + + RETURN_OBJ(result); +} + +DEF_PRIMITIVE(list_subscriptSetter) +{ + ObjList* list = AS_LIST(args[0]); + uint32_t index = validateIndex(vm, args[1], list->elements.count, + "Subscript"); + if (index == UINT32_MAX) return false; + + list->elements.data[index] = args[2]; + RETURN_VAL(args[2]); +} + +DEF_PRIMITIVE(map_new) +{ + RETURN_OBJ(wrenNewMap(vm)); +} + +DEF_PRIMITIVE(map_subscript) +{ + if (!validateKey(vm, args[1])) return false; + + ObjMap* map = AS_MAP(args[0]); + Value value = wrenMapGet(map, args[1]); + if (IS_UNDEFINED(value)) RETURN_NULL; + + RETURN_VAL(value); +} + +DEF_PRIMITIVE(map_subscriptSetter) +{ + if (!validateKey(vm, args[1])) return false; + + wrenMapSet(vm, AS_MAP(args[0]), args[1], args[2]); + RETURN_VAL(args[2]); +} + +// Adds an entry to the map and then returns the map itself. This is called by +// the compiler when compiling map literals instead of using [_]=(_) to +// minimize stack churn. +DEF_PRIMITIVE(map_addCore) +{ + if (!validateKey(vm, args[1])) return false; + + wrenMapSet(vm, AS_MAP(args[0]), args[1], args[2]); + + // Return the map itself. + RETURN_VAL(args[0]); +} + +DEF_PRIMITIVE(map_clear) +{ + wrenMapClear(vm, AS_MAP(args[0])); + RETURN_NULL; +} + +DEF_PRIMITIVE(map_containsKey) +{ + if (!validateKey(vm, args[1])) return false; + + RETURN_BOOL(!IS_UNDEFINED(wrenMapGet(AS_MAP(args[0]), args[1]))); +} + +DEF_PRIMITIVE(map_count) +{ + RETURN_NUM(AS_MAP(args[0])->count); +} + +DEF_PRIMITIVE(map_iterate) +{ + ObjMap* map = AS_MAP(args[0]); + + if (map->count == 0) RETURN_FALSE; + + // If we're starting the iteration, start at the first used entry. + uint32_t index = 0; + + // Otherwise, start one past the last entry we stopped at. + if (!IS_NULL(args[1])) + { + if (!validateInt(vm, args[1], "Iterator")) return false; + + if (AS_NUM(args[1]) < 0) RETURN_FALSE; + index = (uint32_t)AS_NUM(args[1]); + + if (index >= map->capacity) RETURN_FALSE; + + // Advance the iterator. + index++; + } + + // Find a used entry, if any. + for (; index < map->capacity; index++) + { + if (!IS_UNDEFINED(map->entries[index].key)) RETURN_NUM(index); + } + + // If we get here, walked all of the entries. + RETURN_FALSE; +} + +DEF_PRIMITIVE(map_remove) +{ + if (!validateKey(vm, args[1])) return false; + + RETURN_VAL(wrenMapRemoveKey(vm, AS_MAP(args[0]), args[1])); +} + +DEF_PRIMITIVE(map_keyIteratorValue) +{ + ObjMap* map = AS_MAP(args[0]); + uint32_t index = validateIndex(vm, args[1], map->capacity, "Iterator"); + if (index == UINT32_MAX) return false; + + MapEntry* entry = &map->entries[index]; + if (IS_UNDEFINED(entry->key)) + { + RETURN_ERROR("Invalid map iterator."); + } + + RETURN_VAL(entry->key); +} + +DEF_PRIMITIVE(map_valueIteratorValue) +{ + ObjMap* map = AS_MAP(args[0]); + uint32_t index = validateIndex(vm, args[1], map->capacity, "Iterator"); + if (index == UINT32_MAX) return false; + + MapEntry* entry = &map->entries[index]; + if (IS_UNDEFINED(entry->key)) + { + RETURN_ERROR("Invalid map iterator."); + } + + RETURN_VAL(entry->value); +} + +DEF_PRIMITIVE(null_not) +{ + RETURN_VAL(TRUE_VAL); +} + +DEF_PRIMITIVE(null_toString) +{ + RETURN_VAL(CONST_STRING(vm, "null")); +} + +DEF_PRIMITIVE(num_fromString) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[1]); + + // Corner case: Can't parse an empty string. + if (string->length == 0) RETURN_NULL; + + errno = 0; + char* end; + double number = strtod(string->value, &end); + + // Skip past any trailing whitespace. + while (*end != '\0' && isspace((unsigned char)*end)) end++; + + if (errno == ERANGE) RETURN_ERROR("Number literal is too large."); + + // We must have consumed the entire string. Otherwise, it contains non-number + // characters and we can't parse it. + if (end < string->value + string->length) RETURN_NULL; + + RETURN_NUM(number); +} + +// Defines a primitive on Num that calls infix [op] and returns [type]. +#define DEF_NUM_CONSTANT(name, value) \ + DEF_PRIMITIVE(num_##name) \ + { \ + RETURN_NUM(value); \ + } + +DEF_NUM_CONSTANT(infinity, INFINITY) +DEF_NUM_CONSTANT(nan, WREN_DOUBLE_NAN) +DEF_NUM_CONSTANT(pi, 3.14159265358979323846264338327950288) +DEF_NUM_CONSTANT(tau, 6.28318530717958647692528676655900577) + +DEF_NUM_CONSTANT(largest, DBL_MAX) +DEF_NUM_CONSTANT(smallest, DBL_MIN) + +DEF_NUM_CONSTANT(maxSafeInteger, 9007199254740991.0) +DEF_NUM_CONSTANT(minSafeInteger, -9007199254740991.0) + +// Defines a primitive on Num that calls infix [op] and returns [type]. +#define DEF_NUM_INFIX(name, op, type) \ + DEF_PRIMITIVE(num_##name) \ + { \ + if (!validateNum(vm, args[1], "Right operand")) return false; \ + RETURN_##type(AS_NUM(args[0]) op AS_NUM(args[1])); \ + } + +DEF_NUM_INFIX(minus, -, NUM) +DEF_NUM_INFIX(plus, +, NUM) +DEF_NUM_INFIX(multiply, *, NUM) +DEF_NUM_INFIX(divide, /, NUM) +DEF_NUM_INFIX(lt, <, BOOL) +DEF_NUM_INFIX(gt, >, BOOL) +DEF_NUM_INFIX(lte, <=, BOOL) +DEF_NUM_INFIX(gte, >=, BOOL) + +// Defines a primitive on Num that call infix bitwise [op]. +#define DEF_NUM_BITWISE(name, op) \ + DEF_PRIMITIVE(num_bitwise##name) \ + { \ + if (!validateNum(vm, args[1], "Right operand")) return false; \ + uint32_t left = (uint32_t)AS_NUM(args[0]); \ + uint32_t right = (uint32_t)AS_NUM(args[1]); \ + RETURN_NUM(left op right); \ + } + +DEF_NUM_BITWISE(And, &) +DEF_NUM_BITWISE(Or, |) +DEF_NUM_BITWISE(Xor, ^) +DEF_NUM_BITWISE(LeftShift, <<) +DEF_NUM_BITWISE(RightShift, >>) + +// Defines a primitive method on Num that returns the result of [fn]. +#define DEF_NUM_FN(name, fn) \ + DEF_PRIMITIVE(num_##name) \ + { \ + RETURN_NUM(fn(AS_NUM(args[0]))); \ + } + +DEF_NUM_FN(abs, fabs) +DEF_NUM_FN(acos, acos) +DEF_NUM_FN(asin, asin) +DEF_NUM_FN(atan, atan) +DEF_NUM_FN(cbrt, cbrt) +DEF_NUM_FN(ceil, ceil) +DEF_NUM_FN(cos, cos) +DEF_NUM_FN(floor, floor) +DEF_NUM_FN(negate, -) +DEF_NUM_FN(round, round) +DEF_NUM_FN(sin, sin) +DEF_NUM_FN(sqrt, sqrt) +DEF_NUM_FN(tan, tan) +DEF_NUM_FN(log, log) +DEF_NUM_FN(log2, log2) +DEF_NUM_FN(exp, exp) + +DEF_PRIMITIVE(num_mod) +{ + if (!validateNum(vm, args[1], "Right operand")) return false; + RETURN_NUM(fmod(AS_NUM(args[0]), AS_NUM(args[1]))); +} + +DEF_PRIMITIVE(num_eqeq) +{ + if (!IS_NUM(args[1])) RETURN_FALSE; + RETURN_BOOL(AS_NUM(args[0]) == AS_NUM(args[1])); +} + +DEF_PRIMITIVE(num_bangeq) +{ + if (!IS_NUM(args[1])) RETURN_TRUE; + RETURN_BOOL(AS_NUM(args[0]) != AS_NUM(args[1])); +} + +DEF_PRIMITIVE(num_bitwiseNot) +{ + // Bitwise operators always work on 32-bit unsigned ints. + RETURN_NUM(~(uint32_t)AS_NUM(args[0])); +} + +DEF_PRIMITIVE(num_dotDot) +{ + if (!validateNum(vm, args[1], "Right hand side of range")) return false; + + double from = AS_NUM(args[0]); + double to = AS_NUM(args[1]); + RETURN_VAL(wrenNewRange(vm, from, to, true)); +} + +DEF_PRIMITIVE(num_dotDotDot) +{ + if (!validateNum(vm, args[1], "Right hand side of range")) return false; + + double from = AS_NUM(args[0]); + double to = AS_NUM(args[1]); + RETURN_VAL(wrenNewRange(vm, from, to, false)); +} + +DEF_PRIMITIVE(num_atan2) +{ + if (!validateNum(vm, args[1], "x value")) return false; + + RETURN_NUM(atan2(AS_NUM(args[0]), AS_NUM(args[1]))); +} + +DEF_PRIMITIVE(num_min) +{ + if (!validateNum(vm, args[1], "Other value")) return false; + + double value = AS_NUM(args[0]); + double other = AS_NUM(args[1]); + RETURN_NUM(value <= other ? value : other); +} + +DEF_PRIMITIVE(num_max) +{ + if (!validateNum(vm, args[1], "Other value")) return false; + + double value = AS_NUM(args[0]); + double other = AS_NUM(args[1]); + RETURN_NUM(value > other ? value : other); +} + +DEF_PRIMITIVE(num_clamp) +{ + if (!validateNum(vm, args[1], "Min value")) return false; + if (!validateNum(vm, args[2], "Max value")) return false; + + double value = AS_NUM(args[0]); + double min = AS_NUM(args[1]); + double max = AS_NUM(args[2]); + double result = (value < min) ? min : ((value > max) ? max : value); + RETURN_NUM(result); +} + +DEF_PRIMITIVE(num_pow) +{ + if (!validateNum(vm, args[1], "Power value")) return false; + + RETURN_NUM(pow(AS_NUM(args[0]), AS_NUM(args[1]))); +} + +DEF_PRIMITIVE(num_fraction) +{ + double unused; + RETURN_NUM(modf(AS_NUM(args[0]) , &unused)); +} + +DEF_PRIMITIVE(num_isInfinity) +{ + RETURN_BOOL(isinf(AS_NUM(args[0]))); +} + +DEF_PRIMITIVE(num_isInteger) +{ + double value = AS_NUM(args[0]); + if (isnan(value) || isinf(value)) RETURN_FALSE; + RETURN_BOOL(trunc(value) == value); +} + +DEF_PRIMITIVE(num_isNan) +{ + RETURN_BOOL(isnan(AS_NUM(args[0]))); +} + +DEF_PRIMITIVE(num_sign) +{ + double value = AS_NUM(args[0]); + if (value > 0) + { + RETURN_NUM(1); + } + else if (value < 0) + { + RETURN_NUM(-1); + } + else + { + RETURN_NUM(0); + } +} + +DEF_PRIMITIVE(num_toString) +{ + RETURN_VAL(wrenNumToString(vm, AS_NUM(args[0]))); +} + +DEF_PRIMITIVE(num_truncate) +{ + double integer; + modf(AS_NUM(args[0]) , &integer); + RETURN_NUM(integer); +} + +DEF_PRIMITIVE(object_same) +{ + RETURN_BOOL(wrenValuesEqual(args[1], args[2])); +} + +DEF_PRIMITIVE(object_not) +{ + RETURN_VAL(FALSE_VAL); +} + +DEF_PRIMITIVE(object_eqeq) +{ + RETURN_BOOL(wrenValuesEqual(args[0], args[1])); +} + +DEF_PRIMITIVE(object_bangeq) +{ + RETURN_BOOL(!wrenValuesEqual(args[0], args[1])); +} + +DEF_PRIMITIVE(object_is) +{ + if (!IS_CLASS(args[1])) + { + RETURN_ERROR("Right operand must be a class."); + } + + ObjClass *classObj = wrenGetClass(vm, args[0]); + ObjClass *baseClassObj = AS_CLASS(args[1]); + + // Walk the superclass chain looking for the class. + do + { + if (baseClassObj == classObj) RETURN_BOOL(true); + + classObj = classObj->superclass; + } + while (classObj != NULL); + + RETURN_BOOL(false); +} + +DEF_PRIMITIVE(object_toString) +{ + Obj* obj = AS_OBJ(args[0]); + Value name = OBJ_VAL(obj->classObj->name); + RETURN_VAL(wrenStringFormat(vm, "instance of @", name)); +} + +DEF_PRIMITIVE(object_type) +{ + RETURN_OBJ(wrenGetClass(vm, args[0])); +} + +DEF_PRIMITIVE(range_from) +{ + RETURN_NUM(AS_RANGE(args[0])->from); +} + +DEF_PRIMITIVE(range_to) +{ + RETURN_NUM(AS_RANGE(args[0])->to); +} + +DEF_PRIMITIVE(range_min) +{ + ObjRange* range = AS_RANGE(args[0]); + RETURN_NUM(fmin(range->from, range->to)); +} + +DEF_PRIMITIVE(range_max) +{ + ObjRange* range = AS_RANGE(args[0]); + RETURN_NUM(fmax(range->from, range->to)); +} + +DEF_PRIMITIVE(range_isInclusive) +{ + RETURN_BOOL(AS_RANGE(args[0])->isInclusive); +} + +DEF_PRIMITIVE(range_iterate) +{ + ObjRange* range = AS_RANGE(args[0]); + + // Special case: empty range. + if (range->from == range->to && !range->isInclusive) RETURN_FALSE; + + // Start the iteration. + if (IS_NULL(args[1])) RETURN_NUM(range->from); + + if (!validateNum(vm, args[1], "Iterator")) return false; + + double iterator = AS_NUM(args[1]); + + // Iterate towards [to] from [from]. + if (range->from < range->to) + { + iterator++; + if (iterator > range->to) RETURN_FALSE; + } + else + { + iterator--; + if (iterator < range->to) RETURN_FALSE; + } + + if (!range->isInclusive && iterator == range->to) RETURN_FALSE; + + RETURN_NUM(iterator); +} + +DEF_PRIMITIVE(range_iteratorValue) +{ + // Assume the iterator is a number so that is the value of the range. + RETURN_VAL(args[1]); +} + +DEF_PRIMITIVE(range_toString) +{ + ObjRange* range = AS_RANGE(args[0]); + + Value from = wrenNumToString(vm, range->from); + wrenPushRoot(vm, AS_OBJ(from)); + + Value to = wrenNumToString(vm, range->to); + wrenPushRoot(vm, AS_OBJ(to)); + + Value result = wrenStringFormat(vm, "@$@", from, + range->isInclusive ? ".." : "...", to); + + wrenPopRoot(vm); + wrenPopRoot(vm); + RETURN_VAL(result); +} + +DEF_PRIMITIVE(string_fromCodePoint) +{ + if (!validateInt(vm, args[1], "Code point")) return false; + + int codePoint = (int)AS_NUM(args[1]); + if (codePoint < 0) + { + RETURN_ERROR("Code point cannot be negative."); + } + else if (codePoint > 0x10ffff) + { + RETURN_ERROR("Code point cannot be greater than 0x10ffff."); + } + + RETURN_VAL(wrenStringFromCodePoint(vm, codePoint)); +} + +DEF_PRIMITIVE(string_fromByte) +{ + if (!validateInt(vm, args[1], "Byte")) return false; + int byte = (int) AS_NUM(args[1]); + if (byte < 0) + { + RETURN_ERROR("Byte cannot be negative."); + } + else if (byte > 0xff) + { + RETURN_ERROR("Byte cannot be greater than 0xff."); + } + RETURN_VAL(wrenStringFromByte(vm, (uint8_t) byte)); +} + +DEF_PRIMITIVE(string_byteAt) +{ + ObjString* string = AS_STRING(args[0]); + + uint32_t index = validateIndex(vm, args[1], string->length, "Index"); + if (index == UINT32_MAX) return false; + + RETURN_NUM((uint8_t)string->value[index]); +} + +DEF_PRIMITIVE(string_byteCount) +{ + RETURN_NUM(AS_STRING(args[0])->length); +} + +DEF_PRIMITIVE(string_codePointAt) +{ + ObjString* string = AS_STRING(args[0]); + + uint32_t index = validateIndex(vm, args[1], string->length, "Index"); + if (index == UINT32_MAX) return false; + + // If we are in the middle of a UTF-8 sequence, indicate that. + const uint8_t* bytes = (uint8_t*)string->value; + if ((bytes[index] & 0xc0) == 0x80) RETURN_NUM(-1); + + // Decode the UTF-8 sequence. + RETURN_NUM(wrenUtf8Decode((uint8_t*)string->value + index, + string->length - index)); +} + +DEF_PRIMITIVE(string_contains) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + + RETURN_BOOL(wrenStringFind(string, search, 0) != UINT32_MAX); +} + +DEF_PRIMITIVE(string_endsWith) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + + // Edge case: If the search string is longer then return false right away. + if (search->length > string->length) RETURN_FALSE; + + RETURN_BOOL(memcmp(string->value + string->length - search->length, + search->value, search->length) == 0); +} + +DEF_PRIMITIVE(string_indexOf1) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + + uint32_t index = wrenStringFind(string, search, 0); + RETURN_NUM(index == UINT32_MAX ? -1 : (int)index); +} + +DEF_PRIMITIVE(string_indexOf2) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + uint32_t start = validateIndex(vm, args[2], string->length, "Start"); + if (start == UINT32_MAX) return false; + + uint32_t index = wrenStringFind(string, search, start); + RETURN_NUM(index == UINT32_MAX ? -1 : (int)index); +} + +DEF_PRIMITIVE(string_iterate) +{ + ObjString* string = AS_STRING(args[0]); + + // If we're starting the iteration, return the first index. + if (IS_NULL(args[1])) + { + if (string->length == 0) RETURN_FALSE; + RETURN_NUM(0); + } + + if (!validateInt(vm, args[1], "Iterator")) return false; + + if (AS_NUM(args[1]) < 0) RETURN_FALSE; + uint32_t index = (uint32_t)AS_NUM(args[1]); + + // Advance to the beginning of the next UTF-8 sequence. + do + { + index++; + if (index >= string->length) RETURN_FALSE; + } while ((string->value[index] & 0xc0) == 0x80); + + RETURN_NUM(index); +} + +DEF_PRIMITIVE(string_iterateByte) +{ + ObjString* string = AS_STRING(args[0]); + + // If we're starting the iteration, return the first index. + if (IS_NULL(args[1])) + { + if (string->length == 0) RETURN_FALSE; + RETURN_NUM(0); + } + + if (!validateInt(vm, args[1], "Iterator")) return false; + + if (AS_NUM(args[1]) < 0) RETURN_FALSE; + uint32_t index = (uint32_t)AS_NUM(args[1]); + + // Advance to the next byte. + index++; + if (index >= string->length) RETURN_FALSE; + + RETURN_NUM(index); +} + +DEF_PRIMITIVE(string_iteratorValue) +{ + ObjString* string = AS_STRING(args[0]); + uint32_t index = validateIndex(vm, args[1], string->length, "Iterator"); + if (index == UINT32_MAX) return false; + + RETURN_VAL(wrenStringCodePointAt(vm, string, index)); +} + +DEF_PRIMITIVE(string_startsWith) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + + // Edge case: If the search string is longer then return false right away. + if (search->length > string->length) RETURN_FALSE; + + RETURN_BOOL(memcmp(string->value, search->value, search->length) == 0); +} + +DEF_PRIMITIVE(string_plus) +{ + if (!validateString(vm, args[1], "Right operand")) return false; + RETURN_VAL(wrenStringFormat(vm, "@@", args[0], args[1])); +} + +DEF_PRIMITIVE(string_subscript) +{ + ObjString* string = AS_STRING(args[0]); + + if (IS_NUM(args[1])) + { + int index = validateIndex(vm, args[1], string->length, "Subscript"); + if (index == -1) return false; + + RETURN_VAL(wrenStringCodePointAt(vm, string, index)); + } + + if (!IS_RANGE(args[1])) + { + RETURN_ERROR("Subscript must be a number or a range."); + } + + int step; + uint32_t count = string->length; + int start = calculateRange(vm, AS_RANGE(args[1]), &count, &step); + if (start == -1) return false; + + RETURN_VAL(wrenNewStringFromRange(vm, string, start, count, step)); +} + +DEF_PRIMITIVE(string_toString) +{ + RETURN_VAL(args[0]); +} + +DEF_PRIMITIVE(system_clock) +{ + RETURN_NUM((double)clock() / CLOCKS_PER_SEC); +} + +DEF_PRIMITIVE(system_gc) +{ + wrenCollectGarbage(vm); + RETURN_NULL; +} + +DEF_PRIMITIVE(system_writeString) +{ + if (vm->config.writeFn != NULL) + { + vm->config.writeFn(vm, AS_CSTRING(args[1])); + } + + RETURN_VAL(args[1]); +} + +// Creates either the Object or Class class in the core module with [name]. +static ObjClass* defineClass(WrenVM* vm, ObjModule* module, const char* name) +{ + ObjString* nameString = AS_STRING(wrenNewString(vm, name)); + wrenPushRoot(vm, (Obj*)nameString); + + ObjClass* classObj = wrenNewSingleClass(vm, 0, nameString); + + wrenDefineVariable(vm, module, name, nameString->length, OBJ_VAL(classObj), NULL); + + wrenPopRoot(vm); + return classObj; +} + +void wrenInitializeCore(WrenVM* vm) +{ + ObjModule* coreModule = wrenNewModule(vm, NULL); + wrenPushRoot(vm, (Obj*)coreModule); + + // The core module's key is null in the module map. + wrenMapSet(vm, vm->modules, NULL_VAL, OBJ_VAL(coreModule)); + wrenPopRoot(vm); // coreModule. + + // Define the root Object class. This has to be done a little specially + // because it has no superclass. + vm->objectClass = defineClass(vm, coreModule, "Object"); + PRIMITIVE(vm->objectClass, "!", object_not); + PRIMITIVE(vm->objectClass, "==(_)", object_eqeq); + PRIMITIVE(vm->objectClass, "!=(_)", object_bangeq); + PRIMITIVE(vm->objectClass, "is(_)", object_is); + PRIMITIVE(vm->objectClass, "toString", object_toString); + PRIMITIVE(vm->objectClass, "type", object_type); + + // Now we can define Class, which is a subclass of Object. + vm->classClass = defineClass(vm, coreModule, "Class"); + wrenBindSuperclass(vm, vm->classClass, vm->objectClass); + PRIMITIVE(vm->classClass, "name", class_name); + PRIMITIVE(vm->classClass, "supertype", class_supertype); + PRIMITIVE(vm->classClass, "toString", class_toString); + PRIMITIVE(vm->classClass, "attributes", class_attributes); + + // Finally, we can define Object's metaclass which is a subclass of Class. + ObjClass* objectMetaclass = defineClass(vm, coreModule, "Object metaclass"); + + // Wire up the metaclass relationships now that all three classes are built. + vm->objectClass->obj.classObj = objectMetaclass; + objectMetaclass->obj.classObj = vm->classClass; + vm->classClass->obj.classObj = vm->classClass; + + // Do this after wiring up the metaclasses so objectMetaclass doesn't get + // collected. + wrenBindSuperclass(vm, objectMetaclass, vm->classClass); + + PRIMITIVE(objectMetaclass, "same(_,_)", object_same); + + // The core class diagram ends up looking like this, where single lines point + // to a class's superclass, and double lines point to its metaclass: + // + // .------------------------------------. .====. + // | .---------------. | # # + // v | v | v # + // .---------. .-------------------. .-------. # + // | Object |==>| Object metaclass |==>| Class |==" + // '---------' '-------------------' '-------' + // ^ ^ ^ ^ ^ + // | .--------------' # | # + // | | # | # + // .---------. .-------------------. # | # -. + // | Base |==>| Base metaclass |======" | # | + // '---------' '-------------------' | # | + // ^ | # | + // | .------------------' # | Example classes + // | | # | + // .---------. .-------------------. # | + // | Derived |==>| Derived metaclass |==========" | + // '---------' '-------------------' -' + + // The rest of the classes can now be defined normally. + wrenInterpret(vm, NULL, coreModuleSource); + + vm->boolClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Bool")); + PRIMITIVE(vm->boolClass, "toString", bool_toString); + PRIMITIVE(vm->boolClass, "!", bool_not); + + vm->fiberClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Fiber")); + PRIMITIVE(vm->fiberClass->obj.classObj, "new(_)", fiber_new); + PRIMITIVE(vm->fiberClass->obj.classObj, "abort(_)", fiber_abort); + PRIMITIVE(vm->fiberClass->obj.classObj, "current", fiber_current); + PRIMITIVE(vm->fiberClass->obj.classObj, "suspend()", fiber_suspend); + PRIMITIVE(vm->fiberClass->obj.classObj, "yield()", fiber_yield); + PRIMITIVE(vm->fiberClass->obj.classObj, "yield(_)", fiber_yield1); + PRIMITIVE(vm->fiberClass, "call()", fiber_call); + PRIMITIVE(vm->fiberClass, "call(_)", fiber_call1); + PRIMITIVE(vm->fiberClass, "error", fiber_error); + PRIMITIVE(vm->fiberClass, "isDone", fiber_isDone); + PRIMITIVE(vm->fiberClass, "transfer()", fiber_transfer); + PRIMITIVE(vm->fiberClass, "transfer(_)", fiber_transfer1); + PRIMITIVE(vm->fiberClass, "transferError(_)", fiber_transferError); + PRIMITIVE(vm->fiberClass, "try()", fiber_try); + PRIMITIVE(vm->fiberClass, "try(_)", fiber_try1); + + vm->fnClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Fn")); + PRIMITIVE(vm->fnClass->obj.classObj, "new(_)", fn_new); + + PRIMITIVE(vm->fnClass, "arity", fn_arity); + + FUNCTION_CALL(vm->fnClass, "call()", fn_call0); + FUNCTION_CALL(vm->fnClass, "call(_)", fn_call1); + FUNCTION_CALL(vm->fnClass, "call(_,_)", fn_call2); + FUNCTION_CALL(vm->fnClass, "call(_,_,_)", fn_call3); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_)", fn_call4); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_)", fn_call5); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_)", fn_call6); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_)", fn_call7); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_)", fn_call8); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_)", fn_call9); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_)", fn_call10); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_)", fn_call11); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_)", fn_call12); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call13); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call14); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call15); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call16); + + PRIMITIVE(vm->fnClass, "toString", fn_toString); + + vm->nullClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Null")); + PRIMITIVE(vm->nullClass, "!", null_not); + PRIMITIVE(vm->nullClass, "toString", null_toString); + + vm->numClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Num")); + PRIMITIVE(vm->numClass->obj.classObj, "fromString(_)", num_fromString); + PRIMITIVE(vm->numClass->obj.classObj, "infinity", num_infinity); + PRIMITIVE(vm->numClass->obj.classObj, "nan", num_nan); + PRIMITIVE(vm->numClass->obj.classObj, "pi", num_pi); + PRIMITIVE(vm->numClass->obj.classObj, "tau", num_tau); + PRIMITIVE(vm->numClass->obj.classObj, "largest", num_largest); + PRIMITIVE(vm->numClass->obj.classObj, "smallest", num_smallest); + PRIMITIVE(vm->numClass->obj.classObj, "maxSafeInteger", num_maxSafeInteger); + PRIMITIVE(vm->numClass->obj.classObj, "minSafeInteger", num_minSafeInteger); + PRIMITIVE(vm->numClass, "-(_)", num_minus); + PRIMITIVE(vm->numClass, "+(_)", num_plus); + PRIMITIVE(vm->numClass, "*(_)", num_multiply); + PRIMITIVE(vm->numClass, "/(_)", num_divide); + PRIMITIVE(vm->numClass, "<(_)", num_lt); + PRIMITIVE(vm->numClass, ">(_)", num_gt); + PRIMITIVE(vm->numClass, "<=(_)", num_lte); + PRIMITIVE(vm->numClass, ">=(_)", num_gte); + PRIMITIVE(vm->numClass, "&(_)", num_bitwiseAnd); + PRIMITIVE(vm->numClass, "|(_)", num_bitwiseOr); + PRIMITIVE(vm->numClass, "^(_)", num_bitwiseXor); + PRIMITIVE(vm->numClass, "<<(_)", num_bitwiseLeftShift); + PRIMITIVE(vm->numClass, ">>(_)", num_bitwiseRightShift); + PRIMITIVE(vm->numClass, "abs", num_abs); + PRIMITIVE(vm->numClass, "acos", num_acos); + PRIMITIVE(vm->numClass, "asin", num_asin); + PRIMITIVE(vm->numClass, "atan", num_atan); + PRIMITIVE(vm->numClass, "cbrt", num_cbrt); + PRIMITIVE(vm->numClass, "ceil", num_ceil); + PRIMITIVE(vm->numClass, "cos", num_cos); + PRIMITIVE(vm->numClass, "floor", num_floor); + PRIMITIVE(vm->numClass, "-", num_negate); + PRIMITIVE(vm->numClass, "round", num_round); + PRIMITIVE(vm->numClass, "min(_)", num_min); + PRIMITIVE(vm->numClass, "max(_)", num_max); + PRIMITIVE(vm->numClass, "clamp(_,_)", num_clamp); + PRIMITIVE(vm->numClass, "sin", num_sin); + PRIMITIVE(vm->numClass, "sqrt", num_sqrt); + PRIMITIVE(vm->numClass, "tan", num_tan); + PRIMITIVE(vm->numClass, "log", num_log); + PRIMITIVE(vm->numClass, "log2", num_log2); + PRIMITIVE(vm->numClass, "exp", num_exp); + PRIMITIVE(vm->numClass, "%(_)", num_mod); + PRIMITIVE(vm->numClass, "~", num_bitwiseNot); + PRIMITIVE(vm->numClass, "..(_)", num_dotDot); + PRIMITIVE(vm->numClass, "...(_)", num_dotDotDot); + PRIMITIVE(vm->numClass, "atan(_)", num_atan2); + PRIMITIVE(vm->numClass, "pow(_)", num_pow); + PRIMITIVE(vm->numClass, "fraction", num_fraction); + PRIMITIVE(vm->numClass, "isInfinity", num_isInfinity); + PRIMITIVE(vm->numClass, "isInteger", num_isInteger); + PRIMITIVE(vm->numClass, "isNan", num_isNan); + PRIMITIVE(vm->numClass, "sign", num_sign); + PRIMITIVE(vm->numClass, "toString", num_toString); + PRIMITIVE(vm->numClass, "truncate", num_truncate); + + // These are defined just so that 0 and -0 are equal, which is specified by + // IEEE 754 even though they have different bit representations. + PRIMITIVE(vm->numClass, "==(_)", num_eqeq); + PRIMITIVE(vm->numClass, "!=(_)", num_bangeq); + + vm->stringClass = AS_CLASS(wrenFindVariable(vm, coreModule, "String")); + PRIMITIVE(vm->stringClass->obj.classObj, "fromCodePoint(_)", string_fromCodePoint); + PRIMITIVE(vm->stringClass->obj.classObj, "fromByte(_)", string_fromByte); + PRIMITIVE(vm->stringClass, "+(_)", string_plus); + PRIMITIVE(vm->stringClass, "[_]", string_subscript); + PRIMITIVE(vm->stringClass, "byteAt_(_)", string_byteAt); + PRIMITIVE(vm->stringClass, "byteCount_", string_byteCount); + PRIMITIVE(vm->stringClass, "codePointAt_(_)", string_codePointAt); + PRIMITIVE(vm->stringClass, "contains(_)", string_contains); + PRIMITIVE(vm->stringClass, "endsWith(_)", string_endsWith); + PRIMITIVE(vm->stringClass, "indexOf(_)", string_indexOf1); + PRIMITIVE(vm->stringClass, "indexOf(_,_)", string_indexOf2); + PRIMITIVE(vm->stringClass, "iterate(_)", string_iterate); + PRIMITIVE(vm->stringClass, "iterateByte_(_)", string_iterateByte); + PRIMITIVE(vm->stringClass, "iteratorValue(_)", string_iteratorValue); + PRIMITIVE(vm->stringClass, "startsWith(_)", string_startsWith); + PRIMITIVE(vm->stringClass, "toString", string_toString); + + vm->listClass = AS_CLASS(wrenFindVariable(vm, coreModule, "List")); + PRIMITIVE(vm->listClass->obj.classObj, "filled(_,_)", list_filled); + PRIMITIVE(vm->listClass->obj.classObj, "new()", list_new); + PRIMITIVE(vm->listClass, "[_]", list_subscript); + PRIMITIVE(vm->listClass, "[_]=(_)", list_subscriptSetter); + PRIMITIVE(vm->listClass, "add(_)", list_add); + PRIMITIVE(vm->listClass, "addCore_(_)", list_addCore); + PRIMITIVE(vm->listClass, "clear()", list_clear); + PRIMITIVE(vm->listClass, "count", list_count); + PRIMITIVE(vm->listClass, "insert(_,_)", list_insert); + PRIMITIVE(vm->listClass, "iterate(_)", list_iterate); + PRIMITIVE(vm->listClass, "iteratorValue(_)", list_iteratorValue); + PRIMITIVE(vm->listClass, "removeAt(_)", list_removeAt); + PRIMITIVE(vm->listClass, "remove(_)", list_removeValue); + PRIMITIVE(vm->listClass, "indexOf(_)", list_indexOf); + PRIMITIVE(vm->listClass, "swap(_,_)", list_swap); + + vm->mapClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Map")); + PRIMITIVE(vm->mapClass->obj.classObj, "new()", map_new); + PRIMITIVE(vm->mapClass, "[_]", map_subscript); + PRIMITIVE(vm->mapClass, "[_]=(_)", map_subscriptSetter); + PRIMITIVE(vm->mapClass, "addCore_(_,_)", map_addCore); + PRIMITIVE(vm->mapClass, "clear()", map_clear); + PRIMITIVE(vm->mapClass, "containsKey(_)", map_containsKey); + PRIMITIVE(vm->mapClass, "count", map_count); + PRIMITIVE(vm->mapClass, "remove(_)", map_remove); + PRIMITIVE(vm->mapClass, "iterate(_)", map_iterate); + PRIMITIVE(vm->mapClass, "keyIteratorValue_(_)", map_keyIteratorValue); + PRIMITIVE(vm->mapClass, "valueIteratorValue_(_)", map_valueIteratorValue); + + vm->rangeClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Range")); + PRIMITIVE(vm->rangeClass, "from", range_from); + PRIMITIVE(vm->rangeClass, "to", range_to); + PRIMITIVE(vm->rangeClass, "min", range_min); + PRIMITIVE(vm->rangeClass, "max", range_max); + PRIMITIVE(vm->rangeClass, "isInclusive", range_isInclusive); + PRIMITIVE(vm->rangeClass, "iterate(_)", range_iterate); + PRIMITIVE(vm->rangeClass, "iteratorValue(_)", range_iteratorValue); + PRIMITIVE(vm->rangeClass, "toString", range_toString); + + ObjClass* systemClass = AS_CLASS(wrenFindVariable(vm, coreModule, "System")); + PRIMITIVE(systemClass->obj.classObj, "clock", system_clock); + PRIMITIVE(systemClass->obj.classObj, "gc()", system_gc); + PRIMITIVE(systemClass->obj.classObj, "writeString_(_)", system_writeString); + + // While bootstrapping the core types and running the core module, a number + // of string objects have been created, many of which were instantiated + // before stringClass was stored in the VM. Some of them *must* be created + // first -- the ObjClass for string itself has a reference to the ObjString + // for its name. + // + // These all currently have a NULL classObj pointer, so go back and assign + // them now that the string class is known. + for (Obj* obj = vm->first; obj != NULL; obj = obj->next) + { + if (obj->type == OBJ_STRING) obj->classObj = vm->stringClass; + } +} +// End file "wren_core.c" +// Begin file "wren_value.c" +#include +#include +#include +#include + + +#if WREN_DEBUG_TRACE_MEMORY +#endif + +// TODO: Tune these. +// The initial (and minimum) capacity of a non-empty list or map object. +#define MIN_CAPACITY 16 + +// The rate at which a collection's capacity grows when the size exceeds the +// current capacity. The new capacity will be determined by *multiplying* the +// old capacity by this. Growing geometrically is necessary to ensure that +// adding to a collection has O(1) amortized complexity. +#define GROW_FACTOR 2 + +// The maximum percentage of map entries that can be filled before the map is +// grown. A lower load takes more memory but reduces collisions which makes +// lookup faster. +#define MAP_LOAD_PERCENT 75 + +// The number of call frames initially allocated when a fiber is created. Making +// this smaller makes fibers use less memory (at first) but spends more time +// reallocating when the call stack grows. +#define INITIAL_CALL_FRAMES 4 + +DEFINE_BUFFER(Value, Value); +DEFINE_BUFFER(Method, Method); + +static void initObj(WrenVM* vm, Obj* obj, ObjType type, ObjClass* classObj) +{ + obj->type = type; + obj->isDark = false; + obj->classObj = classObj; + obj->next = vm->first; + vm->first = obj; +} + +ObjClass* wrenNewSingleClass(WrenVM* vm, int numFields, ObjString* name) +{ + ObjClass* classObj = ALLOCATE(vm, ObjClass); + initObj(vm, &classObj->obj, OBJ_CLASS, NULL); + classObj->superclass = NULL; + classObj->numFields = numFields; + classObj->name = name; + classObj->attributes = NULL_VAL; + + wrenPushRoot(vm, (Obj*)classObj); + wrenMethodBufferInit(&classObj->methods); + wrenPopRoot(vm); + + return classObj; +} + +void wrenBindSuperclass(WrenVM* vm, ObjClass* subclass, ObjClass* superclass) +{ + ASSERT(superclass != NULL, "Must have superclass."); + + subclass->superclass = superclass; + + // Include the superclass in the total number of fields. + if (subclass->numFields != -1) + { + subclass->numFields += superclass->numFields; + } + else + { + ASSERT(superclass->numFields == 0, + "A foreign class cannot inherit from a class with fields."); + } + + // Inherit methods from its superclass. + for (int i = 0; i < superclass->methods.count; i++) + { + wrenBindMethod(vm, subclass, i, superclass->methods.data[i]); + } +} + +ObjClass* wrenNewClass(WrenVM* vm, ObjClass* superclass, int numFields, + ObjString* name) +{ + // Create the metaclass. + Value metaclassName = wrenStringFormat(vm, "@ metaclass", OBJ_VAL(name)); + wrenPushRoot(vm, AS_OBJ(metaclassName)); + + ObjClass* metaclass = wrenNewSingleClass(vm, 0, AS_STRING(metaclassName)); + metaclass->obj.classObj = vm->classClass; + + wrenPopRoot(vm); + + // Make sure the metaclass isn't collected when we allocate the class. + wrenPushRoot(vm, (Obj*)metaclass); + + // Metaclasses always inherit Class and do not parallel the non-metaclass + // hierarchy. + wrenBindSuperclass(vm, metaclass, vm->classClass); + + ObjClass* classObj = wrenNewSingleClass(vm, numFields, name); + + // Make sure the class isn't collected while the inherited methods are being + // bound. + wrenPushRoot(vm, (Obj*)classObj); + + classObj->obj.classObj = metaclass; + wrenBindSuperclass(vm, classObj, superclass); + + wrenPopRoot(vm); + wrenPopRoot(vm); + + return classObj; +} + +void wrenBindMethod(WrenVM* vm, ObjClass* classObj, int symbol, Method method) +{ + // Make sure the buffer is big enough to contain the symbol's index. + if (symbol >= classObj->methods.count) + { + Method noMethod; + noMethod.type = METHOD_NONE; + wrenMethodBufferFill(vm, &classObj->methods, noMethod, + symbol - classObj->methods.count + 1); + } + + classObj->methods.data[symbol] = method; +} + +ObjClosure* wrenNewClosure(WrenVM* vm, ObjFn* fn) +{ + ObjClosure* closure = ALLOCATE_FLEX(vm, ObjClosure, + ObjUpvalue*, fn->numUpvalues); + initObj(vm, &closure->obj, OBJ_CLOSURE, vm->fnClass); + + closure->fn = fn; + + // Clear the upvalue array. We need to do this in case a GC is triggered + // after the closure is created but before the upvalue array is populated. + for (int i = 0; i < fn->numUpvalues; i++) closure->upvalues[i] = NULL; + + return closure; +} + +ObjFiber* wrenNewFiber(WrenVM* vm, ObjClosure* closure) +{ + // Allocate the arrays before the fiber in case it triggers a GC. + CallFrame* frames = ALLOCATE_ARRAY(vm, CallFrame, INITIAL_CALL_FRAMES); + + // Add one slot for the unused implicit receiver slot that the compiler + // assumes all functions have. + int stackCapacity = closure == NULL + ? 1 + : wrenPowerOf2Ceil(closure->fn->maxSlots + 1); + Value* stack = ALLOCATE_ARRAY(vm, Value, stackCapacity); + + ObjFiber* fiber = ALLOCATE(vm, ObjFiber); + initObj(vm, &fiber->obj, OBJ_FIBER, vm->fiberClass); + + fiber->stack = stack; + fiber->stackTop = fiber->stack; + fiber->stackCapacity = stackCapacity; + + fiber->frames = frames; + fiber->frameCapacity = INITIAL_CALL_FRAMES; + fiber->numFrames = 0; + + fiber->openUpvalues = NULL; + fiber->caller = NULL; + fiber->error = NULL_VAL; + fiber->state = FIBER_OTHER; + + if (closure != NULL) + { + // Initialize the first call frame. + wrenAppendCallFrame(vm, fiber, closure, fiber->stack); + + // The first slot always holds the closure. + fiber->stackTop[0] = OBJ_VAL(closure); + fiber->stackTop++; + } + + return fiber; +} + +void wrenEnsureStack(WrenVM* vm, ObjFiber* fiber, int needed) +{ + if (fiber->stackCapacity >= needed) return; + + int capacity = wrenPowerOf2Ceil(needed); + + Value* oldStack = fiber->stack; + fiber->stack = (Value*)wrenReallocate(vm, fiber->stack, + sizeof(Value) * fiber->stackCapacity, + sizeof(Value) * capacity); + fiber->stackCapacity = capacity; + + // If the reallocation moves the stack, then we need to recalculate every + // pointer that points into the old stack to into the same relative distance + // in the new stack. We have to be a little careful about how these are + // calculated because pointer subtraction is only well-defined within a + // single array, hence the slightly redundant-looking arithmetic below. + if (fiber->stack != oldStack) + { + // Top of the stack. + if (vm->apiStack >= oldStack && vm->apiStack <= fiber->stackTop) + { + vm->apiStack = fiber->stack + (vm->apiStack - oldStack); + } + + // Stack pointer for each call frame. + for (int i = 0; i < fiber->numFrames; i++) + { + CallFrame* frame = &fiber->frames[i]; + frame->stackStart = fiber->stack + (frame->stackStart - oldStack); + } + + // Open upvalues. + for (ObjUpvalue* upvalue = fiber->openUpvalues; + upvalue != NULL; + upvalue = upvalue->next) + { + upvalue->value = fiber->stack + (upvalue->value - oldStack); + } + + fiber->stackTop = fiber->stack + (fiber->stackTop - oldStack); + } +} + +ObjForeign* wrenNewForeign(WrenVM* vm, ObjClass* classObj, size_t size) +{ + ObjForeign* object = ALLOCATE_FLEX(vm, ObjForeign, uint8_t, size); + initObj(vm, &object->obj, OBJ_FOREIGN, classObj); + + // Zero out the bytes. + memset(object->data, 0, size); + return object; +} + +ObjFn* wrenNewFunction(WrenVM* vm, ObjModule* module, int maxSlots) +{ + FnDebug* debug = ALLOCATE(vm, FnDebug); + debug->name = NULL; + wrenIntBufferInit(&debug->sourceLines); + + ObjFn* fn = ALLOCATE(vm, ObjFn); + initObj(vm, &fn->obj, OBJ_FN, vm->fnClass); + + wrenValueBufferInit(&fn->constants); + wrenByteBufferInit(&fn->code); + fn->module = module; + fn->maxSlots = maxSlots; + fn->numUpvalues = 0; + fn->arity = 0; + fn->debug = debug; + + return fn; +} + +void wrenFunctionBindName(WrenVM* vm, ObjFn* fn, const char* name, int length) +{ + fn->debug->name = ALLOCATE_ARRAY(vm, char, length + 1); + memcpy(fn->debug->name, name, length); + fn->debug->name[length] = '\0'; +} + +Value wrenNewInstance(WrenVM* vm, ObjClass* classObj) +{ + ObjInstance* instance = ALLOCATE_FLEX(vm, ObjInstance, + Value, classObj->numFields); + initObj(vm, &instance->obj, OBJ_INSTANCE, classObj); + + // Initialize fields to null. + for (int i = 0; i < classObj->numFields; i++) + { + instance->fields[i] = NULL_VAL; + } + + return OBJ_VAL(instance); +} + +ObjList* wrenNewList(WrenVM* vm, uint32_t numElements) +{ + // Allocate this before the list object in case it triggers a GC which would + // free the list. + Value* elements = NULL; + if (numElements > 0) + { + elements = ALLOCATE_ARRAY(vm, Value, numElements); + } + + ObjList* list = ALLOCATE(vm, ObjList); + initObj(vm, &list->obj, OBJ_LIST, vm->listClass); + list->elements.capacity = numElements; + list->elements.count = numElements; + list->elements.data = elements; + return list; +} + +void wrenListInsert(WrenVM* vm, ObjList* list, Value value, uint32_t index) +{ + if (IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + // Add a slot at the end of the list. + wrenValueBufferWrite(vm, &list->elements, NULL_VAL); + + if (IS_OBJ(value)) wrenPopRoot(vm); + + // Shift the existing elements down. + for (uint32_t i = list->elements.count - 1; i > index; i--) + { + list->elements.data[i] = list->elements.data[i - 1]; + } + + // Store the new element. + list->elements.data[index] = value; +} + +int wrenListIndexOf(WrenVM* vm, ObjList* list, Value value) +{ + int count = list->elements.count; + for (int i = 0; i < count; i++) + { + Value item = list->elements.data[i]; + if(wrenValuesEqual(item, value)) { + return i; + } + } + return -1; +} + +Value wrenListRemoveAt(WrenVM* vm, ObjList* list, uint32_t index) +{ + Value removed = list->elements.data[index]; + + if (IS_OBJ(removed)) wrenPushRoot(vm, AS_OBJ(removed)); + + // Shift items up. + for (int i = index; i < list->elements.count - 1; i++) + { + list->elements.data[i] = list->elements.data[i + 1]; + } + + // If we have too much excess capacity, shrink it. + if (list->elements.capacity / GROW_FACTOR >= list->elements.count) + { + list->elements.data = (Value*)wrenReallocate(vm, list->elements.data, + sizeof(Value) * list->elements.capacity, + sizeof(Value) * (list->elements.capacity / GROW_FACTOR)); + list->elements.capacity /= GROW_FACTOR; + } + + if (IS_OBJ(removed)) wrenPopRoot(vm); + + list->elements.count--; + return removed; +} + +ObjMap* wrenNewMap(WrenVM* vm) +{ + ObjMap* map = ALLOCATE(vm, ObjMap); + initObj(vm, &map->obj, OBJ_MAP, vm->mapClass); + map->capacity = 0; + map->count = 0; + map->entries = NULL; + return map; +} + +static inline uint32_t hashBits(uint64_t hash) +{ + // From v8's ComputeLongHash() which in turn cites: + // Thomas Wang, Integer Hash Functions. + // http://www.concentric.net/~Ttwang/tech/inthash.htm + hash = ~hash + (hash << 18); // hash = (hash << 18) - hash - 1; + hash = hash ^ (hash >> 31); + hash = hash * 21; // hash = (hash + (hash << 2)) + (hash << 4); + hash = hash ^ (hash >> 11); + hash = hash + (hash << 6); + hash = hash ^ (hash >> 22); + return (uint32_t)(hash & 0x3fffffff); +} + +// Generates a hash code for [num]. +static inline uint32_t hashNumber(double num) +{ + // Hash the raw bits of the value. + return hashBits(wrenDoubleToBits(num)); +} + +// Generates a hash code for [object]. +static uint32_t hashObject(Obj* object) +{ + switch (object->type) + { + case OBJ_CLASS: + // Classes just use their name. + return hashObject((Obj*)((ObjClass*)object)->name); + + // Allow bare (non-closure) functions so that we can use a map to find + // existing constants in a function's constant table. This is only used + // internally. Since user code never sees a non-closure function, they + // cannot use them as map keys. + case OBJ_FN: + { + ObjFn* fn = (ObjFn*)object; + return hashNumber(fn->arity) ^ hashNumber(fn->code.count); + } + + case OBJ_RANGE: + { + ObjRange* range = (ObjRange*)object; + return hashNumber(range->from) ^ hashNumber(range->to); + } + + case OBJ_STRING: + return ((ObjString*)object)->hash; + + default: + ASSERT(false, "Only immutable objects can be hashed."); + return 0; + } +} + +// Generates a hash code for [value], which must be one of the built-in +// immutable types: null, bool, class, num, range, or string. +static uint32_t hashValue(Value value) +{ + // TODO: We'll probably want to randomize this at some point. + +#if WREN_NAN_TAGGING + if (IS_OBJ(value)) return hashObject(AS_OBJ(value)); + + // Hash the raw bits of the unboxed value. + return hashBits(value); +#else + switch (value.type) + { + case VAL_FALSE: return 0; + case VAL_NULL: return 1; + case VAL_NUM: return hashNumber(AS_NUM(value)); + case VAL_TRUE: return 2; + case VAL_OBJ: return hashObject(AS_OBJ(value)); + default: UNREACHABLE(); + } + + return 0; +#endif +} + +// Looks for an entry with [key] in an array of [capacity] [entries]. +// +// If found, sets [result] to point to it and returns `true`. Otherwise, +// returns `false` and points [result] to the entry where the key/value pair +// should be inserted. +static bool findEntry(MapEntry* entries, uint32_t capacity, Value key, + MapEntry** result) +{ + // If there is no entry array (an empty map), we definitely won't find it. + if (capacity == 0) return false; + + // Figure out where to insert it in the table. Use open addressing and + // basic linear probing. + uint32_t startIndex = hashValue(key) % capacity; + uint32_t index = startIndex; + + // If we pass a tombstone and don't end up finding the key, its entry will + // be re-used for the insert. + MapEntry* tombstone = NULL; + + // Walk the probe sequence until we've tried every slot. + do + { + MapEntry* entry = &entries[index]; + + if (IS_UNDEFINED(entry->key)) + { + // If we found an empty slot, the key is not in the table. If we found a + // slot that contains a deleted key, we have to keep looking. + if (IS_FALSE(entry->value)) + { + // We found an empty slot, so we've reached the end of the probe + // sequence without finding the key. If we passed a tombstone, then + // that's where we should insert the item, otherwise, put it here at + // the end of the sequence. + *result = tombstone != NULL ? tombstone : entry; + return false; + } + else + { + // We found a tombstone. We need to keep looking in case the key is + // after it, but we'll use this entry as the insertion point if the + // key ends up not being found. + if (tombstone == NULL) tombstone = entry; + } + } + else if (wrenValuesEqual(entry->key, key)) + { + // We found the key. + *result = entry; + return true; + } + + // Try the next slot. + index = (index + 1) % capacity; + } + while (index != startIndex); + + // If we get here, the table is full of tombstones. Return the first one we + // found. + ASSERT(tombstone != NULL, "Map should have tombstones or empty entries."); + *result = tombstone; + return false; +} + +// Inserts [key] and [value] in the array of [entries] with the given +// [capacity]. +// +// Returns `true` if this is the first time [key] was added to the map. +static bool insertEntry(MapEntry* entries, uint32_t capacity, + Value key, Value value) +{ + ASSERT(entries != NULL, "Should ensure capacity before inserting."); + + MapEntry* entry; + if (findEntry(entries, capacity, key, &entry)) + { + // Already present, so just replace the value. + entry->value = value; + return false; + } + else + { + entry->key = key; + entry->value = value; + return true; + } +} + +// Updates [map]'s entry array to [capacity]. +static void resizeMap(WrenVM* vm, ObjMap* map, uint32_t capacity) +{ + // Create the new empty hash table. + MapEntry* entries = ALLOCATE_ARRAY(vm, MapEntry, capacity); + for (uint32_t i = 0; i < capacity; i++) + { + entries[i].key = UNDEFINED_VAL; + entries[i].value = FALSE_VAL; + } + + // Re-add the existing entries. + if (map->capacity > 0) + { + for (uint32_t i = 0; i < map->capacity; i++) + { + MapEntry* entry = &map->entries[i]; + + // Don't copy empty entries or tombstones. + if (IS_UNDEFINED(entry->key)) continue; + + insertEntry(entries, capacity, entry->key, entry->value); + } + } + + // Replace the array. + DEALLOCATE(vm, map->entries); + map->entries = entries; + map->capacity = capacity; +} + +Value wrenMapGet(ObjMap* map, Value key) +{ + MapEntry* entry; + if (findEntry(map->entries, map->capacity, key, &entry)) return entry->value; + + return UNDEFINED_VAL; +} + +void wrenMapSet(WrenVM* vm, ObjMap* map, Value key, Value value) +{ + // If the map is getting too full, make room first. + if (map->count + 1 > map->capacity * MAP_LOAD_PERCENT / 100) + { + // Figure out the new hash table size. + uint32_t capacity = map->capacity * GROW_FACTOR; + if (capacity < MIN_CAPACITY) capacity = MIN_CAPACITY; + + resizeMap(vm, map, capacity); + } + + if (insertEntry(map->entries, map->capacity, key, value)) + { + // A new key was added. + map->count++; + } +} + +void wrenMapClear(WrenVM* vm, ObjMap* map) +{ + DEALLOCATE(vm, map->entries); + map->entries = NULL; + map->capacity = 0; + map->count = 0; +} + +Value wrenMapRemoveKey(WrenVM* vm, ObjMap* map, Value key) +{ + MapEntry* entry; + if (!findEntry(map->entries, map->capacity, key, &entry)) return NULL_VAL; + + // Remove the entry from the map. Set this value to true, which marks it as a + // deleted slot. When searching for a key, we will stop on empty slots, but + // continue past deleted slots. + Value value = entry->value; + entry->key = UNDEFINED_VAL; + entry->value = TRUE_VAL; + + if (IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + map->count--; + + if (map->count == 0) + { + // Removed the last item, so free the array. + wrenMapClear(vm, map); + } + else if (map->capacity > MIN_CAPACITY && + map->count < map->capacity / GROW_FACTOR * MAP_LOAD_PERCENT / 100) + { + uint32_t capacity = map->capacity / GROW_FACTOR; + if (capacity < MIN_CAPACITY) capacity = MIN_CAPACITY; + + // The map is getting empty, so shrink the entry array back down. + // TODO: Should we do this less aggressively than we grow? + resizeMap(vm, map, capacity); + } + + if (IS_OBJ(value)) wrenPopRoot(vm); + return value; +} + +ObjModule* wrenNewModule(WrenVM* vm, ObjString* name) +{ + ObjModule* module = ALLOCATE(vm, ObjModule); + + // Modules are never used as first-class objects, so don't need a class. + initObj(vm, (Obj*)module, OBJ_MODULE, NULL); + + wrenPushRoot(vm, (Obj*)module); + + wrenSymbolTableInit(&module->variableNames); + wrenValueBufferInit(&module->variables); + + module->name = name; + + wrenPopRoot(vm); + return module; +} + +Value wrenNewRange(WrenVM* vm, double from, double to, bool isInclusive) +{ + ObjRange* range = ALLOCATE(vm, ObjRange); + initObj(vm, &range->obj, OBJ_RANGE, vm->rangeClass); + range->from = from; + range->to = to; + range->isInclusive = isInclusive; + + return OBJ_VAL(range); +} + +// Creates a new string object with a null-terminated buffer large enough to +// hold a string of [length] but does not fill in the bytes. +// +// The caller is expected to fill in the buffer and then calculate the string's +// hash. +static ObjString* allocateString(WrenVM* vm, size_t length) +{ + ObjString* string = ALLOCATE_FLEX(vm, ObjString, char, length + 1); + initObj(vm, &string->obj, OBJ_STRING, vm->stringClass); + string->length = (int)length; + string->value[length] = '\0'; + + return string; +} + +// Calculates and stores the hash code for [string]. +static void hashString(ObjString* string) +{ + // FNV-1a hash. See: http://www.isthe.com/chongo/tech/comp/fnv/ + uint32_t hash = 2166136261u; + + // This is O(n) on the length of the string, but we only call this when a new + // string is created. Since the creation is also O(n) (to copy/initialize all + // the bytes), we allow this here. + for (uint32_t i = 0; i < string->length; i++) + { + hash ^= string->value[i]; + hash *= 16777619; + } + + string->hash = hash; +} + +Value wrenNewString(WrenVM* vm, const char* text) +{ + return wrenNewStringLength(vm, text, strlen(text)); +} + +Value wrenNewStringLength(WrenVM* vm, const char* text, size_t length) +{ + // Allow NULL if the string is empty since byte buffers don't allocate any + // characters for a zero-length string. + ASSERT(length == 0 || text != NULL, "Unexpected NULL string."); + + ObjString* string = allocateString(vm, length); + + // Copy the string (if given one). + if (length > 0 && text != NULL) memcpy(string->value, text, length); + + hashString(string); + return OBJ_VAL(string); +} + + +Value wrenNewStringFromRange(WrenVM* vm, ObjString* source, int start, + uint32_t count, int step) +{ + uint8_t* from = (uint8_t*)source->value; + int length = 0; + for (uint32_t i = 0; i < count; i++) + { + length += wrenUtf8DecodeNumBytes(from[start + i * step]); + } + + ObjString* result = allocateString(vm, length); + result->value[length] = '\0'; + + uint8_t* to = (uint8_t*)result->value; + for (uint32_t i = 0; i < count; i++) + { + int index = start + i * step; + int codePoint = wrenUtf8Decode(from + index, source->length - index); + + if (codePoint != -1) + { + to += wrenUtf8Encode(codePoint, to); + } + } + + hashString(result); + return OBJ_VAL(result); +} + +Value wrenNumToString(WrenVM* vm, double value) +{ + // Edge case: If the value is NaN or infinity, different versions of libc + // produce different outputs (some will format it signed and some won't). To + // get reliable output, handle it ourselves. + if (isnan(value)) return CONST_STRING(vm, "nan"); + if (isinf(value)) + { + if (value > 0.0) + { + return CONST_STRING(vm, "infinity"); + } + else + { + return CONST_STRING(vm, "-infinity"); + } + } + + // This is large enough to hold any double converted to a string using + // "%.14g". Example: + // + // -1.12345678901234e-1022 + // + // So we have: + // + // + 1 char for sign + // + 1 char for digit + // + 1 char for "." + // + 14 chars for decimal digits + // + 1 char for "e" + // + 1 char for "-" or "+" + // + 4 chars for exponent + // + 1 char for "\0" + // = 24 + char buffer[24]; + int length = sprintf(buffer, "%.14g", value); + return wrenNewStringLength(vm, buffer, length); +} + +Value wrenStringFromCodePoint(WrenVM* vm, int value) +{ + int length = wrenUtf8EncodeNumBytes(value); + ASSERT(length != 0, "Value out of range."); + + ObjString* string = allocateString(vm, length); + + wrenUtf8Encode(value, (uint8_t*)string->value); + hashString(string); + + return OBJ_VAL(string); +} + +Value wrenStringFromByte(WrenVM *vm, uint8_t value) +{ + int length = 1; + ObjString* string = allocateString(vm, length); + string->value[0] = value; + hashString(string); + return OBJ_VAL(string); +} + +Value wrenStringFormat(WrenVM* vm, const char* format, ...) +{ + va_list argList; + + // Calculate the length of the result string. Do this up front so we can + // create the final string with a single allocation. + va_start(argList, format); + size_t totalLength = 0; + for (const char* c = format; *c != '\0'; c++) + { + switch (*c) + { + case '$': + totalLength += strlen(va_arg(argList, const char*)); + break; + + case '@': + totalLength += AS_STRING(va_arg(argList, Value))->length; + break; + + default: + // Any other character is interpreted literally. + totalLength++; + } + } + va_end(argList); + + // Concatenate the string. + ObjString* result = allocateString(vm, totalLength); + + va_start(argList, format); + char* start = result->value; + for (const char* c = format; *c != '\0'; c++) + { + switch (*c) + { + case '$': + { + const char* string = va_arg(argList, const char*); + size_t length = strlen(string); + memcpy(start, string, length); + start += length; + break; + } + + case '@': + { + ObjString* string = AS_STRING(va_arg(argList, Value)); + memcpy(start, string->value, string->length); + start += string->length; + break; + } + + default: + // Any other character is interpreted literally. + *start++ = *c; + } + } + va_end(argList); + + hashString(result); + + return OBJ_VAL(result); +} + +Value wrenStringCodePointAt(WrenVM* vm, ObjString* string, uint32_t index) +{ + ASSERT(index < string->length, "Index out of bounds."); + + int codePoint = wrenUtf8Decode((uint8_t*)string->value + index, + string->length - index); + if (codePoint == -1) + { + // If it isn't a valid UTF-8 sequence, treat it as a single raw byte. + char bytes[2]; + bytes[0] = string->value[index]; + bytes[1] = '\0'; + return wrenNewStringLength(vm, bytes, 1); + } + + return wrenStringFromCodePoint(vm, codePoint); +} + +// Uses the Boyer-Moore-Horspool string matching algorithm. +uint32_t wrenStringFind(ObjString* haystack, ObjString* needle, uint32_t start) +{ + // Edge case: An empty needle is always found. + if (needle->length == 0) return start; + + // If the needle goes past the haystack it won't be found. + if (start + needle->length > haystack->length) return UINT32_MAX; + + // If the startIndex is too far it also won't be found. + if (start >= haystack->length) return UINT32_MAX; + + // Pre-calculate the shift table. For each character (8-bit value), we + // determine how far the search window can be advanced if that character is + // the last character in the haystack where we are searching for the needle + // and the needle doesn't match there. + uint32_t shift[UINT8_MAX]; + uint32_t needleEnd = needle->length - 1; + + // By default, we assume the character is not the needle at all. In that case + // case, if a match fails on that character, we can advance one whole needle + // width since. + for (uint32_t index = 0; index < UINT8_MAX; index++) + { + shift[index] = needle->length; + } + + // Then, for every character in the needle, determine how far it is from the + // end. If a match fails on that character, we can advance the window such + // that it the last character in it lines up with the last place we could + // find it in the needle. + for (uint32_t index = 0; index < needleEnd; index++) + { + char c = needle->value[index]; + shift[(uint8_t)c] = needleEnd - index; + } + + // Slide the needle across the haystack, looking for the first match or + // stopping if the needle goes off the end. + char lastChar = needle->value[needleEnd]; + uint32_t range = haystack->length - needle->length; + + for (uint32_t index = start; index <= range; ) + { + // Compare the last character in the haystack's window to the last character + // in the needle. If it matches, see if the whole needle matches. + char c = haystack->value[index + needleEnd]; + if (lastChar == c && + memcmp(haystack->value + index, needle->value, needleEnd) == 0) + { + // Found a match. + return index; + } + + // Otherwise, slide the needle forward. + index += shift[(uint8_t)c]; + } + + // Not found. + return UINT32_MAX; +} + +ObjUpvalue* wrenNewUpvalue(WrenVM* vm, Value* value) +{ + ObjUpvalue* upvalue = ALLOCATE(vm, ObjUpvalue); + + // Upvalues are never used as first-class objects, so don't need a class. + initObj(vm, &upvalue->obj, OBJ_UPVALUE, NULL); + + upvalue->value = value; + upvalue->closed = NULL_VAL; + upvalue->next = NULL; + return upvalue; +} + +void wrenGrayObj(WrenVM* vm, Obj* obj) +{ + if (obj == NULL) return; + + // Stop if the object is already darkened so we don't get stuck in a cycle. + if (obj->isDark) return; + + // It's been reached. + obj->isDark = true; + + // Add it to the gray list so it can be recursively explored for + // more marks later. + if (vm->grayCount >= vm->grayCapacity) + { + vm->grayCapacity = vm->grayCount * 2; + vm->gray = (Obj**)vm->config.reallocateFn(vm->gray, + vm->grayCapacity * sizeof(Obj*), + vm->config.userData); + } + + vm->gray[vm->grayCount++] = obj; +} + +void wrenGrayValue(WrenVM* vm, Value value) +{ + if (!IS_OBJ(value)) return; + wrenGrayObj(vm, AS_OBJ(value)); +} + +void wrenGrayBuffer(WrenVM* vm, ValueBuffer* buffer) +{ + for (int i = 0; i < buffer->count; i++) + { + wrenGrayValue(vm, buffer->data[i]); + } +} + +static void blackenClass(WrenVM* vm, ObjClass* classObj) +{ + // The metaclass. + wrenGrayObj(vm, (Obj*)classObj->obj.classObj); + + // The superclass. + wrenGrayObj(vm, (Obj*)classObj->superclass); + + // Method function objects. + for (int i = 0; i < classObj->methods.count; i++) + { + if (classObj->methods.data[i].type == METHOD_BLOCK) + { + wrenGrayObj(vm, (Obj*)classObj->methods.data[i].as.closure); + } + } + + wrenGrayObj(vm, (Obj*)classObj->name); + + if(!IS_NULL(classObj->attributes)) wrenGrayObj(vm, AS_OBJ(classObj->attributes)); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjClass); + vm->bytesAllocated += classObj->methods.capacity * sizeof(Method); +} + +static void blackenClosure(WrenVM* vm, ObjClosure* closure) +{ + // Mark the function. + wrenGrayObj(vm, (Obj*)closure->fn); + + // Mark the upvalues. + for (int i = 0; i < closure->fn->numUpvalues; i++) + { + wrenGrayObj(vm, (Obj*)closure->upvalues[i]); + } + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjClosure); + vm->bytesAllocated += sizeof(ObjUpvalue*) * closure->fn->numUpvalues; +} + +static void blackenFiber(WrenVM* vm, ObjFiber* fiber) +{ + // Stack functions. + for (int i = 0; i < fiber->numFrames; i++) + { + wrenGrayObj(vm, (Obj*)fiber->frames[i].closure); + } + + // Stack variables. + for (Value* slot = fiber->stack; slot < fiber->stackTop; slot++) + { + wrenGrayValue(vm, *slot); + } + + // Open upvalues. + ObjUpvalue* upvalue = fiber->openUpvalues; + while (upvalue != NULL) + { + wrenGrayObj(vm, (Obj*)upvalue); + upvalue = upvalue->next; + } + + // The caller. + wrenGrayObj(vm, (Obj*)fiber->caller); + wrenGrayValue(vm, fiber->error); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjFiber); + vm->bytesAllocated += fiber->frameCapacity * sizeof(CallFrame); + vm->bytesAllocated += fiber->stackCapacity * sizeof(Value); +} + +static void blackenFn(WrenVM* vm, ObjFn* fn) +{ + // Mark the constants. + wrenGrayBuffer(vm, &fn->constants); + + // Mark the module it belongs to, in case it's been unloaded. + wrenGrayObj(vm, (Obj*)fn->module); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjFn); + vm->bytesAllocated += sizeof(uint8_t) * fn->code.capacity; + vm->bytesAllocated += sizeof(Value) * fn->constants.capacity; + + // The debug line number buffer. + vm->bytesAllocated += sizeof(int) * fn->code.capacity; + // TODO: What about the function name? +} + +static void blackenForeign(WrenVM* vm, ObjForeign* foreign) +{ + // TODO: Keep track of how much memory the foreign object uses. We can store + // this in each foreign object, but it will balloon the size. We may not want + // that much overhead. One option would be to let the foreign class register + // a C function that returns a size for the object. That way the VM doesn't + // always have to explicitly store it. +} + +static void blackenInstance(WrenVM* vm, ObjInstance* instance) +{ + wrenGrayObj(vm, (Obj*)instance->obj.classObj); + + // Mark the fields. + for (int i = 0; i < instance->obj.classObj->numFields; i++) + { + wrenGrayValue(vm, instance->fields[i]); + } + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjInstance); + vm->bytesAllocated += sizeof(Value) * instance->obj.classObj->numFields; +} + +static void blackenList(WrenVM* vm, ObjList* list) +{ + // Mark the elements. + wrenGrayBuffer(vm, &list->elements); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjList); + vm->bytesAllocated += sizeof(Value) * list->elements.capacity; +} + +static void blackenMap(WrenVM* vm, ObjMap* map) +{ + // Mark the entries. + for (uint32_t i = 0; i < map->capacity; i++) + { + MapEntry* entry = &map->entries[i]; + if (IS_UNDEFINED(entry->key)) continue; + + wrenGrayValue(vm, entry->key); + wrenGrayValue(vm, entry->value); + } + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjMap); + vm->bytesAllocated += sizeof(MapEntry) * map->capacity; +} + +static void blackenModule(WrenVM* vm, ObjModule* module) +{ + // Top-level variables. + for (int i = 0; i < module->variables.count; i++) + { + wrenGrayValue(vm, module->variables.data[i]); + } + + wrenBlackenSymbolTable(vm, &module->variableNames); + + wrenGrayObj(vm, (Obj*)module->name); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjModule); +} + +static void blackenRange(WrenVM* vm, ObjRange* range) +{ + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjRange); +} + +static void blackenString(WrenVM* vm, ObjString* string) +{ + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjString) + string->length + 1; +} + +static void blackenUpvalue(WrenVM* vm, ObjUpvalue* upvalue) +{ + // Mark the closed-over object (in case it is closed). + wrenGrayValue(vm, upvalue->closed); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjUpvalue); +} + +static void blackenObject(WrenVM* vm, Obj* obj) +{ +#if WREN_DEBUG_TRACE_MEMORY + printf("mark "); + wrenDumpValue(OBJ_VAL(obj)); + printf(" @ %p\n", obj); +#endif + + // Traverse the object's fields. + switch (obj->type) + { + case OBJ_CLASS: blackenClass( vm, (ObjClass*) obj); break; + case OBJ_CLOSURE: blackenClosure( vm, (ObjClosure*) obj); break; + case OBJ_FIBER: blackenFiber( vm, (ObjFiber*) obj); break; + case OBJ_FN: blackenFn( vm, (ObjFn*) obj); break; + case OBJ_FOREIGN: blackenForeign( vm, (ObjForeign*) obj); break; + case OBJ_INSTANCE: blackenInstance(vm, (ObjInstance*)obj); break; + case OBJ_LIST: blackenList( vm, (ObjList*) obj); break; + case OBJ_MAP: blackenMap( vm, (ObjMap*) obj); break; + case OBJ_MODULE: blackenModule( vm, (ObjModule*) obj); break; + case OBJ_RANGE: blackenRange( vm, (ObjRange*) obj); break; + case OBJ_STRING: blackenString( vm, (ObjString*) obj); break; + case OBJ_UPVALUE: blackenUpvalue( vm, (ObjUpvalue*) obj); break; + } +} + +void wrenBlackenObjects(WrenVM* vm) +{ + while (vm->grayCount > 0) + { + // Pop an item from the gray stack. + Obj* obj = vm->gray[--vm->grayCount]; + blackenObject(vm, obj); + } +} + +void wrenFreeObj(WrenVM* vm, Obj* obj) +{ +#if WREN_DEBUG_TRACE_MEMORY + printf("free "); + wrenDumpValue(OBJ_VAL(obj)); + printf(" @ %p\n", obj); +#endif + + switch (obj->type) + { + case OBJ_CLASS: + wrenMethodBufferClear(vm, &((ObjClass*)obj)->methods); + break; + + case OBJ_FIBER: + { + ObjFiber* fiber = (ObjFiber*)obj; + DEALLOCATE(vm, fiber->frames); + DEALLOCATE(vm, fiber->stack); + break; + } + + case OBJ_FN: + { + ObjFn* fn = (ObjFn*)obj; + wrenValueBufferClear(vm, &fn->constants); + wrenByteBufferClear(vm, &fn->code); + wrenIntBufferClear(vm, &fn->debug->sourceLines); + DEALLOCATE(vm, fn->debug->name); + DEALLOCATE(vm, fn->debug); + break; + } + + case OBJ_FOREIGN: + wrenFinalizeForeign(vm, (ObjForeign*)obj); + break; + + case OBJ_LIST: + wrenValueBufferClear(vm, &((ObjList*)obj)->elements); + break; + + case OBJ_MAP: + DEALLOCATE(vm, ((ObjMap*)obj)->entries); + break; + + case OBJ_MODULE: + wrenSymbolTableClear(vm, &((ObjModule*)obj)->variableNames); + wrenValueBufferClear(vm, &((ObjModule*)obj)->variables); + break; + + case OBJ_CLOSURE: + case OBJ_INSTANCE: + case OBJ_RANGE: + case OBJ_STRING: + case OBJ_UPVALUE: + break; + } + + DEALLOCATE(vm, obj); +} + +ObjClass* wrenGetClass(WrenVM* vm, Value value) +{ + return wrenGetClassInline(vm, value); +} + +bool wrenValuesEqual(Value a, Value b) +{ + if (wrenValuesSame(a, b)) return true; + + // If we get here, it's only possible for two heap-allocated immutable objects + // to be equal. + if (!IS_OBJ(a) || !IS_OBJ(b)) return false; + + Obj* aObj = AS_OBJ(a); + Obj* bObj = AS_OBJ(b); + + // Must be the same type. + if (aObj->type != bObj->type) return false; + + switch (aObj->type) + { + case OBJ_RANGE: + { + ObjRange* aRange = (ObjRange*)aObj; + ObjRange* bRange = (ObjRange*)bObj; + return aRange->from == bRange->from && + aRange->to == bRange->to && + aRange->isInclusive == bRange->isInclusive; + } + + case OBJ_STRING: + { + ObjString* aString = (ObjString*)aObj; + ObjString* bString = (ObjString*)bObj; + return aString->hash == bString->hash && + wrenStringEqualsCString(aString, bString->value, bString->length); + } + + default: + // All other types are only equal if they are same, which they aren't if + // we get here. + return false; + } +} +// End file "wren_value.c" +// Begin file "wren_utils.c" +#include + + +DEFINE_BUFFER(Byte, uint8_t); +DEFINE_BUFFER(Int, int); +DEFINE_BUFFER(String, ObjString*); + +void wrenSymbolTableInit(SymbolTable* symbols) +{ + wrenStringBufferInit(symbols); +} + +void wrenSymbolTableClear(WrenVM* vm, SymbolTable* symbols) +{ + wrenStringBufferClear(vm, symbols); +} + +int wrenSymbolTableAdd(WrenVM* vm, SymbolTable* symbols, + const char* name, size_t length) +{ + ObjString* symbol = AS_STRING(wrenNewStringLength(vm, name, length)); + + wrenPushRoot(vm, &symbol->obj); + wrenStringBufferWrite(vm, symbols, symbol); + wrenPopRoot(vm); + + return symbols->count - 1; +} + +int wrenSymbolTableEnsure(WrenVM* vm, SymbolTable* symbols, + const char* name, size_t length) +{ + // See if the symbol is already defined. + int existing = wrenSymbolTableFind(symbols, name, length); + if (existing != -1) return existing; + + // New symbol, so add it. + return wrenSymbolTableAdd(vm, symbols, name, length); +} + +int wrenSymbolTableFind(const SymbolTable* symbols, + const char* name, size_t length) +{ + // See if the symbol is already defined. + // TODO: O(n). Do something better. + for (int i = 0; i < symbols->count; i++) + { + if (wrenStringEqualsCString(symbols->data[i], name, length)) return i; + } + + return -1; +} + +void wrenBlackenSymbolTable(WrenVM* vm, SymbolTable* symbolTable) +{ + for (int i = 0; i < symbolTable->count; i++) + { + wrenGrayObj(vm, &symbolTable->data[i]->obj); + } + + // Keep track of how much memory is still in use. + vm->bytesAllocated += symbolTable->capacity * sizeof(*symbolTable->data); +} + +int wrenUtf8EncodeNumBytes(int value) +{ + ASSERT(value >= 0, "Cannot encode a negative value."); + + if (value <= 0x7f) return 1; + if (value <= 0x7ff) return 2; + if (value <= 0xffff) return 3; + if (value <= 0x10ffff) return 4; + return 0; +} + +int wrenUtf8Encode(int value, uint8_t* bytes) +{ + if (value <= 0x7f) + { + // Single byte (i.e. fits in ASCII). + *bytes = value & 0x7f; + return 1; + } + else if (value <= 0x7ff) + { + // Two byte sequence: 110xxxxx 10xxxxxx. + *bytes = 0xc0 | ((value & 0x7c0) >> 6); + bytes++; + *bytes = 0x80 | (value & 0x3f); + return 2; + } + else if (value <= 0xffff) + { + // Three byte sequence: 1110xxxx 10xxxxxx 10xxxxxx. + *bytes = 0xe0 | ((value & 0xf000) >> 12); + bytes++; + *bytes = 0x80 | ((value & 0xfc0) >> 6); + bytes++; + *bytes = 0x80 | (value & 0x3f); + return 3; + } + else if (value <= 0x10ffff) + { + // Four byte sequence: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx. + *bytes = 0xf0 | ((value & 0x1c0000) >> 18); + bytes++; + *bytes = 0x80 | ((value & 0x3f000) >> 12); + bytes++; + *bytes = 0x80 | ((value & 0xfc0) >> 6); + bytes++; + *bytes = 0x80 | (value & 0x3f); + return 4; + } + + // Invalid Unicode value. See: http://tools.ietf.org/html/rfc3629 + UNREACHABLE(); + return 0; +} + +int wrenUtf8Decode(const uint8_t* bytes, uint32_t length) +{ + // Single byte (i.e. fits in ASCII). + if (*bytes <= 0x7f) return *bytes; + + int value; + uint32_t remainingBytes; + if ((*bytes & 0xe0) == 0xc0) + { + // Two byte sequence: 110xxxxx 10xxxxxx. + value = *bytes & 0x1f; + remainingBytes = 1; + } + else if ((*bytes & 0xf0) == 0xe0) + { + // Three byte sequence: 1110xxxx 10xxxxxx 10xxxxxx. + value = *bytes & 0x0f; + remainingBytes = 2; + } + else if ((*bytes & 0xf8) == 0xf0) + { + // Four byte sequence: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx. + value = *bytes & 0x07; + remainingBytes = 3; + } + else + { + // Invalid UTF-8 sequence. + return -1; + } + + // Don't read past the end of the buffer on truncated UTF-8. + if (remainingBytes > length - 1) return -1; + + while (remainingBytes > 0) + { + bytes++; + remainingBytes--; + + // Remaining bytes must be of form 10xxxxxx. + if ((*bytes & 0xc0) != 0x80) return -1; + + value = value << 6 | (*bytes & 0x3f); + } + + return value; +} + +int wrenUtf8DecodeNumBytes(uint8_t byte) +{ + // If the byte starts with 10xxxxx, it's the middle of a UTF-8 sequence, so + // don't count it at all. + if ((byte & 0xc0) == 0x80) return 0; + + // The first byte's high bits tell us how many bytes are in the UTF-8 + // sequence. + if ((byte & 0xf8) == 0xf0) return 4; + if ((byte & 0xf0) == 0xe0) return 3; + if ((byte & 0xe0) == 0xc0) return 2; + return 1; +} + +// From: http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2Float +int wrenPowerOf2Ceil(int n) +{ + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n++; + + return n; +} + +uint32_t wrenValidateIndex(uint32_t count, int64_t value) +{ + // Negative indices count from the end. + if (value < 0) value = count + value; + + // Check bounds. + if (value >= 0 && value < count) return (uint32_t)value; + + return UINT32_MAX; +} +// End file "wren_utils.c" +// Begin file "wren_vm.c" +#include +#include + + +#if WREN_OPT_META +// Begin file "wren_opt_meta.h" +#ifndef wren_opt_meta_h +#define wren_opt_meta_h + + +// This module defines the Meta class and its associated methods. +#if WREN_OPT_META + +const char* wrenMetaSource(); +WrenForeignMethodFn wrenMetaBindForeignMethod(WrenVM* vm, + const char* className, + bool isStatic, + const char* signature); + +#endif + +#endif +// End file "wren_opt_meta.h" +#endif +#if WREN_OPT_RANDOM +// Begin file "wren_opt_random.h" +#ifndef wren_opt_random_h +#define wren_opt_random_h + + +#if WREN_OPT_RANDOM + +const char* wrenRandomSource(); +WrenForeignClassMethods wrenRandomBindForeignClass(WrenVM* vm, + const char* module, + const char* className); +WrenForeignMethodFn wrenRandomBindForeignMethod(WrenVM* vm, + const char* className, + bool isStatic, + const char* signature); + +#endif + +#endif +// End file "wren_opt_random.h" +#endif + +#if WREN_DEBUG_TRACE_MEMORY || WREN_DEBUG_TRACE_GC + #include + #include +#endif + +// The behavior of realloc() when the size is 0 is implementation defined. It +// may return a non-NULL pointer which must not be dereferenced but nevertheless +// should be freed. To prevent that, we avoid calling realloc() with a zero +// size. +static void* defaultReallocate(void* ptr, size_t newSize, void* _) +{ + if (newSize == 0) + { + free(ptr); + return NULL; + } + + return realloc(ptr, newSize); +} + +int wrenGetVersionNumber() +{ + return WREN_VERSION_NUMBER; +} + +void wrenInitConfiguration(WrenConfiguration* config) +{ + config->reallocateFn = defaultReallocate; + config->resolveModuleFn = NULL; + config->loadModuleFn = NULL; + config->bindForeignMethodFn = NULL; + config->bindForeignClassFn = NULL; + config->writeFn = NULL; + config->errorFn = NULL; + config->initialHeapSize = 1024 * 1024 * 10; + config->minHeapSize = 1024 * 1024; + config->heapGrowthPercent = 50; + config->userData = NULL; +} + +WrenVM* wrenNewVM(WrenConfiguration* config) +{ + WrenReallocateFn reallocate = defaultReallocate; + void* userData = NULL; + if (config != NULL) { + userData = config->userData; + reallocate = config->reallocateFn ? config->reallocateFn : defaultReallocate; + } + + WrenVM* vm = (WrenVM*)reallocate(NULL, sizeof(*vm), userData); + memset(vm, 0, sizeof(WrenVM)); + + // Copy the configuration if given one. + if (config != NULL) + { + memcpy(&vm->config, config, sizeof(WrenConfiguration)); + + // We choose to set this after copying, + // rather than modifying the user config pointer + vm->config.reallocateFn = reallocate; + } + else + { + wrenInitConfiguration(&vm->config); + } + + // TODO: Should we allocate and free this during a GC? + vm->grayCount = 0; + // TODO: Tune this. + vm->grayCapacity = 4; + vm->gray = (Obj**)reallocate(NULL, vm->grayCapacity * sizeof(Obj*), userData); + vm->nextGC = vm->config.initialHeapSize; + + wrenSymbolTableInit(&vm->methodNames); + + vm->modules = wrenNewMap(vm); + wrenInitializeCore(vm); + return vm; +} + +void wrenFreeVM(WrenVM* vm) +{ + ASSERT(vm->methodNames.count > 0, "VM appears to have already been freed."); + + // Free all of the GC objects. + Obj* obj = vm->first; + while (obj != NULL) + { + Obj* next = obj->next; + wrenFreeObj(vm, obj); + obj = next; + } + + // Free up the GC gray set. + vm->gray = (Obj**)vm->config.reallocateFn(vm->gray, 0, vm->config.userData); + + // Tell the user if they didn't free any handles. We don't want to just free + // them here because the host app may still have pointers to them that they + // may try to use. Better to tell them about the bug early. + ASSERT(vm->handles == NULL, "All handles have not been released."); + + wrenSymbolTableClear(vm, &vm->methodNames); + + DEALLOCATE(vm, vm); +} + +void wrenCollectGarbage(WrenVM* vm) +{ +#if WREN_DEBUG_TRACE_MEMORY || WREN_DEBUG_TRACE_GC + printf("-- gc --\n"); + + size_t before = vm->bytesAllocated; + double startTime = (double)clock() / CLOCKS_PER_SEC; +#endif + + // Mark all reachable objects. + + // Reset this. As we mark objects, their size will be counted again so that + // we can track how much memory is in use without needing to know the size + // of each *freed* object. + // + // This is important because when freeing an unmarked object, we don't always + // know how much memory it is using. For example, when freeing an instance, + // we need to know its class to know how big it is, but its class may have + // already been freed. + vm->bytesAllocated = 0; + + wrenGrayObj(vm, (Obj*)vm->modules); + + // Temporary roots. + for (int i = 0; i < vm->numTempRoots; i++) + { + wrenGrayObj(vm, vm->tempRoots[i]); + } + + // The current fiber. + wrenGrayObj(vm, (Obj*)vm->fiber); + + // The handles. + for (WrenHandle* handle = vm->handles; + handle != NULL; + handle = handle->next) + { + wrenGrayValue(vm, handle->value); + } + + // Any object the compiler is using (if there is one). + if (vm->compiler != NULL) wrenMarkCompiler(vm, vm->compiler); + + // Method names. + wrenBlackenSymbolTable(vm, &vm->methodNames); + + // Now that we have grayed the roots, do a depth-first search over all of the + // reachable objects. + wrenBlackenObjects(vm); + + // Collect the white objects. + Obj** obj = &vm->first; + while (*obj != NULL) + { + if (!((*obj)->isDark)) + { + // This object wasn't reached, so remove it from the list and free it. + Obj* unreached = *obj; + *obj = unreached->next; + wrenFreeObj(vm, unreached); + } + else + { + // This object was reached, so unmark it (for the next GC) and move on to + // the next. + (*obj)->isDark = false; + obj = &(*obj)->next; + } + } + + // Calculate the next gc point, this is the current allocation plus + // a configured percentage of the current allocation. + vm->nextGC = vm->bytesAllocated + ((vm->bytesAllocated * vm->config.heapGrowthPercent) / 100); + if (vm->nextGC < vm->config.minHeapSize) vm->nextGC = vm->config.minHeapSize; + +#if WREN_DEBUG_TRACE_MEMORY || WREN_DEBUG_TRACE_GC + double elapsed = ((double)clock() / CLOCKS_PER_SEC) - startTime; + // Explicit cast because size_t has different sizes on 32-bit and 64-bit and + // we need a consistent type for the format string. + printf("GC %lu before, %lu after (%lu collected), next at %lu. Took %.3fms.\n", + (unsigned long)before, + (unsigned long)vm->bytesAllocated, + (unsigned long)(before - vm->bytesAllocated), + (unsigned long)vm->nextGC, + elapsed*1000.0); +#endif +} + +void* wrenReallocate(WrenVM* vm, void* memory, size_t oldSize, size_t newSize) +{ +#if WREN_DEBUG_TRACE_MEMORY + // Explicit cast because size_t has different sizes on 32-bit and 64-bit and + // we need a consistent type for the format string. + printf("reallocate %p %lu -> %lu\n", + memory, (unsigned long)oldSize, (unsigned long)newSize); +#endif + + // If new bytes are being allocated, add them to the total count. If objects + // are being completely deallocated, we don't track that (since we don't + // track the original size). Instead, that will be handled while marking + // during the next GC. + vm->bytesAllocated += newSize - oldSize; + +#if WREN_DEBUG_GC_STRESS + // Since collecting calls this function to free things, make sure we don't + // recurse. + if (newSize > 0) wrenCollectGarbage(vm); +#else + if (newSize > 0 && vm->bytesAllocated > vm->nextGC) wrenCollectGarbage(vm); +#endif + + return vm->config.reallocateFn(memory, newSize, vm->config.userData); +} + +// Captures the local variable [local] into an [Upvalue]. If that local is +// already in an upvalue, the existing one will be used. (This is important to +// ensure that multiple closures closing over the same variable actually see +// the same variable.) Otherwise, it will create a new open upvalue and add it +// the fiber's list of upvalues. +static ObjUpvalue* captureUpvalue(WrenVM* vm, ObjFiber* fiber, Value* local) +{ + // If there are no open upvalues at all, we must need a new one. + if (fiber->openUpvalues == NULL) + { + fiber->openUpvalues = wrenNewUpvalue(vm, local); + return fiber->openUpvalues; + } + + ObjUpvalue* prevUpvalue = NULL; + ObjUpvalue* upvalue = fiber->openUpvalues; + + // Walk towards the bottom of the stack until we find a previously existing + // upvalue or pass where it should be. + while (upvalue != NULL && upvalue->value > local) + { + prevUpvalue = upvalue; + upvalue = upvalue->next; + } + + // Found an existing upvalue for this local. + if (upvalue != NULL && upvalue->value == local) return upvalue; + + // We've walked past this local on the stack, so there must not be an + // upvalue for it already. Make a new one and link it in in the right + // place to keep the list sorted. + ObjUpvalue* createdUpvalue = wrenNewUpvalue(vm, local); + if (prevUpvalue == NULL) + { + // The new one is the first one in the list. + fiber->openUpvalues = createdUpvalue; + } + else + { + prevUpvalue->next = createdUpvalue; + } + + createdUpvalue->next = upvalue; + return createdUpvalue; +} + +// Closes any open upvalues that have been created for stack slots at [last] +// and above. +static void closeUpvalues(ObjFiber* fiber, Value* last) +{ + while (fiber->openUpvalues != NULL && + fiber->openUpvalues->value >= last) + { + ObjUpvalue* upvalue = fiber->openUpvalues; + + // Move the value into the upvalue itself and point the upvalue to it. + upvalue->closed = *upvalue->value; + upvalue->value = &upvalue->closed; + + // Remove it from the open upvalue list. + fiber->openUpvalues = upvalue->next; + } +} + +// Looks up a foreign method in [moduleName] on [className] with [signature]. +// +// This will try the host's foreign method binder first. If that fails, it +// falls back to handling the built-in modules. +static WrenForeignMethodFn findForeignMethod(WrenVM* vm, + const char* moduleName, + const char* className, + bool isStatic, + const char* signature) +{ + WrenForeignMethodFn method = NULL; + + if (vm->config.bindForeignMethodFn != NULL) + { + method = vm->config.bindForeignMethodFn(vm, moduleName, className, isStatic, + signature); + } + + // If the host didn't provide it, see if it's an optional one. + if (method == NULL) + { +#if WREN_OPT_META + if (strcmp(moduleName, "meta") == 0) + { + method = wrenMetaBindForeignMethod(vm, className, isStatic, signature); + } +#endif +#if WREN_OPT_RANDOM + if (strcmp(moduleName, "random") == 0) + { + method = wrenRandomBindForeignMethod(vm, className, isStatic, signature); + } +#endif + } + + return method; +} + +// Defines [methodValue] as a method on [classObj]. +// +// Handles both foreign methods where [methodValue] is a string containing the +// method's signature and Wren methods where [methodValue] is a function. +// +// Aborts the current fiber if the method is a foreign method that could not be +// found. +static void bindMethod(WrenVM* vm, int methodType, int symbol, + ObjModule* module, ObjClass* classObj, Value methodValue) +{ + const char* className = classObj->name->value; + if (methodType == CODE_METHOD_STATIC) classObj = classObj->obj.classObj; + + Method method; + if (IS_STRING(methodValue)) + { + const char* name = AS_CSTRING(methodValue); + method.type = METHOD_FOREIGN; + method.as.foreign = findForeignMethod(vm, module->name->value, + className, + methodType == CODE_METHOD_STATIC, + name); + + if (method.as.foreign == NULL) + { + vm->fiber->error = wrenStringFormat(vm, + "Could not find foreign method '@' for class $ in module '$'.", + methodValue, classObj->name->value, module->name->value); + return; + } + } + else + { + method.as.closure = AS_CLOSURE(methodValue); + method.type = METHOD_BLOCK; + + // Patch up the bytecode now that we know the superclass. + wrenBindMethodCode(classObj, method.as.closure->fn); + } + + wrenBindMethod(vm, classObj, symbol, method); +} + +static void callForeign(WrenVM* vm, ObjFiber* fiber, + WrenForeignMethodFn foreign, int numArgs) +{ + ASSERT(vm->apiStack == NULL, "Cannot already be in foreign call."); + vm->apiStack = fiber->stackTop - numArgs; + + foreign(vm); + + // Discard the stack slots for the arguments and temporaries but leave one + // for the result. + fiber->stackTop = vm->apiStack + 1; + + vm->apiStack = NULL; +} + +// Handles the current fiber having aborted because of an error. +// +// Walks the call chain of fibers, aborting each one until it hits a fiber that +// handles the error. If none do, tells the VM to stop. +static void runtimeError(WrenVM* vm) +{ + ASSERT(wrenHasError(vm->fiber), "Should only call this after an error."); + + ObjFiber* current = vm->fiber; + Value error = current->error; + + while (current != NULL) + { + // Every fiber along the call chain gets aborted with the same error. + current->error = error; + + // If the caller ran this fiber using "try", give it the error and stop. + if (current->state == FIBER_TRY) + { + // Make the caller's try method return the error message. + current->caller->stackTop[-1] = vm->fiber->error; + vm->fiber = current->caller; + return; + } + + // Otherwise, unhook the caller since we will never resume and return to it. + ObjFiber* caller = current->caller; + current->caller = NULL; + current = caller; + } + + // If we got here, nothing caught the error, so show the stack trace. + wrenDebugPrintStackTrace(vm); + vm->fiber = NULL; + vm->apiStack = NULL; +} + +// Aborts the current fiber with an appropriate method not found error for a +// method with [symbol] on [classObj]. +static void methodNotFound(WrenVM* vm, ObjClass* classObj, int symbol) +{ + vm->fiber->error = wrenStringFormat(vm, "@ does not implement '$'.", + OBJ_VAL(classObj->name), vm->methodNames.data[symbol]->value); +} + +// Looks up the previously loaded module with [name]. +// +// Returns `NULL` if no module with that name has been loaded. +static ObjModule* getModule(WrenVM* vm, Value name) +{ + Value moduleValue = wrenMapGet(vm->modules, name); + return !IS_UNDEFINED(moduleValue) ? AS_MODULE(moduleValue) : NULL; +} + +static ObjClosure* compileInModule(WrenVM* vm, Value name, const char* source, + bool isExpression, bool printErrors) +{ + // See if the module has already been loaded. + ObjModule* module = getModule(vm, name); + if (module == NULL) + { + module = wrenNewModule(vm, AS_STRING(name)); + + // It's possible for the wrenMapSet below to resize the modules map, + // and trigger a GC while doing so. When this happens it will collect + // the module we've just created. Once in the map it is safe. + wrenPushRoot(vm, (Obj*)module); + + // Store it in the VM's module registry so we don't load the same module + // multiple times. + wrenMapSet(vm, vm->modules, name, OBJ_VAL(module)); + + wrenPopRoot(vm); + + // Implicitly import the core module. + ObjModule* coreModule = getModule(vm, NULL_VAL); + for (int i = 0; i < coreModule->variables.count; i++) + { + wrenDefineVariable(vm, module, + coreModule->variableNames.data[i]->value, + coreModule->variableNames.data[i]->length, + coreModule->variables.data[i], NULL); + } + } + + ObjFn* fn = wrenCompile(vm, module, source, isExpression, printErrors); + if (fn == NULL) + { + // TODO: Should we still store the module even if it didn't compile? + return NULL; + } + + // Functions are always wrapped in closures. + wrenPushRoot(vm, (Obj*)fn); + ObjClosure* closure = wrenNewClosure(vm, fn); + wrenPopRoot(vm); // fn. + + return closure; +} + +// Verifies that [superclassValue] is a valid object to inherit from. That +// means it must be a class and cannot be the class of any built-in type. +// +// Also validates that it doesn't result in a class with too many fields and +// the other limitations foreign classes have. +// +// If successful, returns `null`. Otherwise, returns a string for the runtime +// error message. +static Value validateSuperclass(WrenVM* vm, Value name, Value superclassValue, + int numFields) +{ + // Make sure the superclass is a class. + if (!IS_CLASS(superclassValue)) + { + return wrenStringFormat(vm, + "Class '@' cannot inherit from a non-class object.", + name); + } + + // Make sure it doesn't inherit from a sealed built-in type. Primitive methods + // on these classes assume the instance is one of the other Obj___ types and + // will fail horribly if it's actually an ObjInstance. + ObjClass* superclass = AS_CLASS(superclassValue); + if (superclass == vm->classClass || + superclass == vm->fiberClass || + superclass == vm->fnClass || // Includes OBJ_CLOSURE. + superclass == vm->listClass || + superclass == vm->mapClass || + superclass == vm->rangeClass || + superclass == vm->stringClass || + superclass == vm->boolClass || + superclass == vm->nullClass || + superclass == vm->numClass) + { + return wrenStringFormat(vm, + "Class '@' cannot inherit from built-in class '@'.", + name, OBJ_VAL(superclass->name)); + } + + if (superclass->numFields == -1) + { + return wrenStringFormat(vm, + "Class '@' cannot inherit from foreign class '@'.", + name, OBJ_VAL(superclass->name)); + } + + if (numFields == -1 && superclass->numFields > 0) + { + return wrenStringFormat(vm, + "Foreign class '@' may not inherit from a class with fields.", + name); + } + + if (superclass->numFields + numFields > MAX_FIELDS) + { + return wrenStringFormat(vm, + "Class '@' may not have more than 255 fields, including inherited " + "ones.", name); + } + + return NULL_VAL; +} + +static void bindForeignClass(WrenVM* vm, ObjClass* classObj, ObjModule* module) +{ + WrenForeignClassMethods methods; + methods.allocate = NULL; + methods.finalize = NULL; + + // Check the optional built-in module first so the host can override it. + + if (vm->config.bindForeignClassFn != NULL) + { + methods = vm->config.bindForeignClassFn(vm, module->name->value, + classObj->name->value); + } + + // If the host didn't provide it, see if it's a built in optional module. + if (methods.allocate == NULL && methods.finalize == NULL) + { +#if WREN_OPT_RANDOM + if (strcmp(module->name->value, "random") == 0) + { + methods = wrenRandomBindForeignClass(vm, module->name->value, + classObj->name->value); + } +#endif + } + + Method method; + method.type = METHOD_FOREIGN; + + // Add the symbol even if there is no allocator so we can ensure that the + // symbol itself is always in the symbol table. + int symbol = wrenSymbolTableEnsure(vm, &vm->methodNames, "", 10); + if (methods.allocate != NULL) + { + method.as.foreign = methods.allocate; + wrenBindMethod(vm, classObj, symbol, method); + } + + // Add the symbol even if there is no finalizer so we can ensure that the + // symbol itself is always in the symbol table. + symbol = wrenSymbolTableEnsure(vm, &vm->methodNames, "", 10); + if (methods.finalize != NULL) + { + method.as.foreign = (WrenForeignMethodFn)methods.finalize; + wrenBindMethod(vm, classObj, symbol, method); + } +} + +// Completes the process for creating a new class. +// +// The class attributes instance and the class itself should be on the +// top of the fiber's stack. +// +// This process handles moving the attribute data for a class from +// compile time to runtime, since it now has all the attributes associated +// with a class, including for methods. +static void endClass(WrenVM* vm) +{ + // Pull the attributes and class off the stack + Value attributes = vm->fiber->stackTop[-2]; + Value classValue = vm->fiber->stackTop[-1]; + + // Remove the stack items + vm->fiber->stackTop -= 2; + + ObjClass* classObj = AS_CLASS(classValue); + classObj->attributes = attributes; +} + +// Creates a new class. +// +// If [numFields] is -1, the class is a foreign class. The name and superclass +// should be on top of the fiber's stack. After calling this, the top of the +// stack will contain the new class. +// +// Aborts the current fiber if an error occurs. +static void createClass(WrenVM* vm, int numFields, ObjModule* module) +{ + // Pull the name and superclass off the stack. + Value name = vm->fiber->stackTop[-2]; + Value superclass = vm->fiber->stackTop[-1]; + + // We have two values on the stack and we are going to leave one, so discard + // the other slot. + vm->fiber->stackTop--; + + vm->fiber->error = validateSuperclass(vm, name, superclass, numFields); + if (wrenHasError(vm->fiber)) return; + + ObjClass* classObj = wrenNewClass(vm, AS_CLASS(superclass), numFields, + AS_STRING(name)); + vm->fiber->stackTop[-1] = OBJ_VAL(classObj); + + if (numFields == -1) bindForeignClass(vm, classObj, module); +} + +static void createForeign(WrenVM* vm, ObjFiber* fiber, Value* stack) +{ + ObjClass* classObj = AS_CLASS(stack[0]); + ASSERT(classObj->numFields == -1, "Class must be a foreign class."); + + // TODO: Don't look up every time. + int symbol = wrenSymbolTableFind(&vm->methodNames, "", 10); + ASSERT(symbol != -1, "Should have defined symbol."); + + ASSERT(classObj->methods.count > symbol, "Class should have allocator."); + Method* method = &classObj->methods.data[symbol]; + ASSERT(method->type == METHOD_FOREIGN, "Allocator should be foreign."); + + // Pass the constructor arguments to the allocator as well. + ASSERT(vm->apiStack == NULL, "Cannot already be in foreign call."); + vm->apiStack = stack; + + method->as.foreign(vm); + + vm->apiStack = NULL; +} + +void wrenFinalizeForeign(WrenVM* vm, ObjForeign* foreign) +{ + // TODO: Don't look up every time. + int symbol = wrenSymbolTableFind(&vm->methodNames, "", 10); + ASSERT(symbol != -1, "Should have defined symbol."); + + // If there are no finalizers, don't finalize it. + if (symbol == -1) return; + + // If the class doesn't have a finalizer, bail out. + ObjClass* classObj = foreign->obj.classObj; + if (symbol >= classObj->methods.count) return; + + Method* method = &classObj->methods.data[symbol]; + if (method->type == METHOD_NONE) return; + + ASSERT(method->type == METHOD_FOREIGN, "Finalizer should be foreign."); + + WrenFinalizerFn finalizer = (WrenFinalizerFn)method->as.foreign; + finalizer(foreign->data); +} + +// Let the host resolve an imported module name if it wants to. +static Value resolveModule(WrenVM* vm, Value name) +{ + // If the host doesn't care to resolve, leave the name alone. + if (vm->config.resolveModuleFn == NULL) return name; + + ObjFiber* fiber = vm->fiber; + ObjFn* fn = fiber->frames[fiber->numFrames - 1].closure->fn; + ObjString* importer = fn->module->name; + + const char* resolved = vm->config.resolveModuleFn(vm, importer->value, + AS_CSTRING(name)); + if (resolved == NULL) + { + vm->fiber->error = wrenStringFormat(vm, + "Could not resolve module '@' imported from '@'.", + name, OBJ_VAL(importer)); + return NULL_VAL; + } + + // If they resolved to the exact same string, we don't need to copy it. + if (resolved == AS_CSTRING(name)) return name; + + // Copy the string into a Wren String object. + name = wrenNewString(vm, resolved); + DEALLOCATE(vm, (char*)resolved); + return name; +} + +static Value importModule(WrenVM* vm, Value name) +{ + name = resolveModule(vm, name); + + // If the module is already loaded, we don't need to do anything. + Value existing = wrenMapGet(vm->modules, name); + if (!IS_UNDEFINED(existing)) return existing; + + wrenPushRoot(vm, AS_OBJ(name)); + + WrenLoadModuleResult result = {0}; + const char* source = NULL; + + // Let the host try to provide the module. + if (vm->config.loadModuleFn != NULL) + { + result = vm->config.loadModuleFn(vm, AS_CSTRING(name)); + } + + // If the host didn't provide it, see if it's a built in optional module. + if (result.source == NULL) + { + result.onComplete = NULL; + ObjString* nameString = AS_STRING(name); +#if WREN_OPT_META + if (strcmp(nameString->value, "meta") == 0) result.source = wrenMetaSource(); +#endif +#if WREN_OPT_RANDOM + if (strcmp(nameString->value, "random") == 0) result.source = wrenRandomSource(); +#endif + } + + if (result.source == NULL) + { + vm->fiber->error = wrenStringFormat(vm, "Could not load module '@'.", name); + wrenPopRoot(vm); // name. + return NULL_VAL; + } + + ObjClosure* moduleClosure = compileInModule(vm, name, result.source, false, true); + + // Now that we're done, give the result back in case there's cleanup to do. + if(result.onComplete) result.onComplete(vm, AS_CSTRING(name), result); + + if (moduleClosure == NULL) + { + vm->fiber->error = wrenStringFormat(vm, + "Could not compile module '@'.", name); + wrenPopRoot(vm); // name. + return NULL_VAL; + } + + wrenPopRoot(vm); // name. + + // Return the closure that executes the module. + return OBJ_VAL(moduleClosure); +} + +static Value getModuleVariable(WrenVM* vm, ObjModule* module, + Value variableName) +{ + ObjString* variable = AS_STRING(variableName); + uint32_t variableEntry = wrenSymbolTableFind(&module->variableNames, + variable->value, + variable->length); + + // It's a runtime error if the imported variable does not exist. + if (variableEntry != UINT32_MAX) + { + return module->variables.data[variableEntry]; + } + + vm->fiber->error = wrenStringFormat(vm, + "Could not find a variable named '@' in module '@'.", + variableName, OBJ_VAL(module->name)); + return NULL_VAL; +} + +inline static bool checkArity(WrenVM* vm, Value value, int numArgs) +{ + ASSERT(IS_CLOSURE(value), "Receiver must be a closure."); + ObjFn* fn = AS_CLOSURE(value)->fn; + + // We only care about missing arguments, not extras. The "- 1" is because + // numArgs includes the receiver, the function itself, which we don't want to + // count. + if (numArgs - 1 >= fn->arity) return true; + + vm->fiber->error = CONST_STRING(vm, "Function expects more arguments."); + return false; +} + + +// The main bytecode interpreter loop. This is where the magic happens. It is +// also, as you can imagine, highly performance critical. +static WrenInterpretResult runInterpreter(WrenVM* vm, register ObjFiber* fiber) +{ + // Remember the current fiber so we can find it if a GC happens. + vm->fiber = fiber; + fiber->state = FIBER_ROOT; + + // Hoist these into local variables. They are accessed frequently in the loop + // but assigned less frequently. Keeping them in locals and updating them when + // a call frame has been pushed or popped gives a large speed boost. + register CallFrame* frame; + register Value* stackStart; + register uint8_t* ip; + register ObjFn* fn; + + // These macros are designed to only be invoked within this function. + #define PUSH(value) (*fiber->stackTop++ = value) + #define POP() (*(--fiber->stackTop)) + #define DROP() (fiber->stackTop--) + #define PEEK() (*(fiber->stackTop - 1)) + #define PEEK2() (*(fiber->stackTop - 2)) + #define READ_BYTE() (*ip++) + #define READ_SHORT() (ip += 2, (uint16_t)((ip[-2] << 8) | ip[-1])) + + // Use this before a CallFrame is pushed to store the local variables back + // into the current one. + #define STORE_FRAME() frame->ip = ip + + // Use this after a CallFrame has been pushed or popped to refresh the local + // variables. + #define LOAD_FRAME() \ + do \ + { \ + frame = &fiber->frames[fiber->numFrames - 1]; \ + stackStart = frame->stackStart; \ + ip = frame->ip; \ + fn = frame->closure->fn; \ + } while (false) + + // Terminates the current fiber with error string [error]. If another calling + // fiber is willing to catch the error, transfers control to it, otherwise + // exits the interpreter. + #define RUNTIME_ERROR() \ + do \ + { \ + STORE_FRAME(); \ + runtimeError(vm); \ + if (vm->fiber == NULL) return WREN_RESULT_RUNTIME_ERROR; \ + fiber = vm->fiber; \ + LOAD_FRAME(); \ + DISPATCH(); \ + } while (false) + + #if WREN_DEBUG_TRACE_INSTRUCTIONS + // Prints the stack and instruction before each instruction is executed. + #define DEBUG_TRACE_INSTRUCTIONS() \ + do \ + { \ + wrenDumpStack(fiber); \ + wrenDumpInstruction(vm, fn, (int)(ip - fn->code.data)); \ + } while (false) + #else + #define DEBUG_TRACE_INSTRUCTIONS() do { } while (false) + #endif + + #if WREN_COMPUTED_GOTO + + static void* dispatchTable[] = { + #define OPCODE(name, _) &&code_##name, +// Begin file "wren_opcodes.h" +// This defines the bytecode instructions used by the VM. It does so by invoking +// an OPCODE() macro which is expected to be defined at the point that this is +// included. (See: http://en.wikipedia.org/wiki/X_Macro for more.) +// +// The first argument is the name of the opcode. The second is its "stack +// effect" -- the amount that the op code changes the size of the stack. A +// stack effect of 1 means it pushes a value and the stack grows one larger. +// -2 means it pops two values, etc. +// +// Note that the order of instructions here affects the order of the dispatch +// table in the VM's interpreter loop. That in turn affects caching which +// affects overall performance. Take care to run benchmarks if you change the +// order here. + +// Load the constant at index [arg]. +OPCODE(CONSTANT, 1) + +// Push null onto the stack. +OPCODE(NULL, 1) + +// Push false onto the stack. +OPCODE(FALSE, 1) + +// Push true onto the stack. +OPCODE(TRUE, 1) + +// Pushes the value in the given local slot. +OPCODE(LOAD_LOCAL_0, 1) +OPCODE(LOAD_LOCAL_1, 1) +OPCODE(LOAD_LOCAL_2, 1) +OPCODE(LOAD_LOCAL_3, 1) +OPCODE(LOAD_LOCAL_4, 1) +OPCODE(LOAD_LOCAL_5, 1) +OPCODE(LOAD_LOCAL_6, 1) +OPCODE(LOAD_LOCAL_7, 1) +OPCODE(LOAD_LOCAL_8, 1) + +// Note: The compiler assumes the following _STORE instructions always +// immediately follow their corresponding _LOAD ones. + +// Pushes the value in local slot [arg]. +OPCODE(LOAD_LOCAL, 1) + +// Stores the top of stack in local slot [arg]. Does not pop it. +OPCODE(STORE_LOCAL, 0) + +// Pushes the value in upvalue [arg]. +OPCODE(LOAD_UPVALUE, 1) + +// Stores the top of stack in upvalue [arg]. Does not pop it. +OPCODE(STORE_UPVALUE, 0) + +// Pushes the value of the top-level variable in slot [arg]. +OPCODE(LOAD_MODULE_VAR, 1) + +// Stores the top of stack in top-level variable slot [arg]. Does not pop it. +OPCODE(STORE_MODULE_VAR, 0) + +// Pushes the value of the field in slot [arg] of the receiver of the current +// function. This is used for regular field accesses on "this" directly in +// methods. This instruction is faster than the more general CODE_LOAD_FIELD +// instruction. +OPCODE(LOAD_FIELD_THIS, 1) + +// Stores the top of the stack in field slot [arg] in the receiver of the +// current value. Does not pop the value. This instruction is faster than the +// more general CODE_LOAD_FIELD instruction. +OPCODE(STORE_FIELD_THIS, 0) + +// Pops an instance and pushes the value of the field in slot [arg] of it. +OPCODE(LOAD_FIELD, 0) + +// Pops an instance and stores the subsequent top of stack in field slot +// [arg] in it. Does not pop the value. +OPCODE(STORE_FIELD, -1) + +// Pop and discard the top of stack. +OPCODE(POP, -1) + +// Invoke the method with symbol [arg]. The number indicates the number of +// arguments (not including the receiver). +OPCODE(CALL_0, 0) +OPCODE(CALL_1, -1) +OPCODE(CALL_2, -2) +OPCODE(CALL_3, -3) +OPCODE(CALL_4, -4) +OPCODE(CALL_5, -5) +OPCODE(CALL_6, -6) +OPCODE(CALL_7, -7) +OPCODE(CALL_8, -8) +OPCODE(CALL_9, -9) +OPCODE(CALL_10, -10) +OPCODE(CALL_11, -11) +OPCODE(CALL_12, -12) +OPCODE(CALL_13, -13) +OPCODE(CALL_14, -14) +OPCODE(CALL_15, -15) +OPCODE(CALL_16, -16) + +// Invoke a superclass method with symbol [arg]. The number indicates the +// number of arguments (not including the receiver). +OPCODE(SUPER_0, 0) +OPCODE(SUPER_1, -1) +OPCODE(SUPER_2, -2) +OPCODE(SUPER_3, -3) +OPCODE(SUPER_4, -4) +OPCODE(SUPER_5, -5) +OPCODE(SUPER_6, -6) +OPCODE(SUPER_7, -7) +OPCODE(SUPER_8, -8) +OPCODE(SUPER_9, -9) +OPCODE(SUPER_10, -10) +OPCODE(SUPER_11, -11) +OPCODE(SUPER_12, -12) +OPCODE(SUPER_13, -13) +OPCODE(SUPER_14, -14) +OPCODE(SUPER_15, -15) +OPCODE(SUPER_16, -16) + +// Jump the instruction pointer [arg] forward. +OPCODE(JUMP, 0) + +// Jump the instruction pointer [arg] backward. +OPCODE(LOOP, 0) + +// Pop and if not truthy then jump the instruction pointer [arg] forward. +OPCODE(JUMP_IF, -1) + +// If the top of the stack is false, jump [arg] forward. Otherwise, pop and +// continue. +OPCODE(AND, -1) + +// If the top of the stack is non-false, jump [arg] forward. Otherwise, pop +// and continue. +OPCODE(OR, -1) + +// Close the upvalue for the local on the top of the stack, then pop it. +OPCODE(CLOSE_UPVALUE, -1) + +// Exit from the current function and return the value on the top of the +// stack. +OPCODE(RETURN, 0) + +// Creates a closure for the function stored at [arg] in the constant table. +// +// Following the function argument is a number of arguments, two for each +// upvalue. The first is true if the variable being captured is a local (as +// opposed to an upvalue), and the second is the index of the local or +// upvalue being captured. +// +// Pushes the created closure. +OPCODE(CLOSURE, 1) + +// Creates a new instance of a class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(CONSTRUCT, 0) + +// Creates a new instance of a foreign class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(FOREIGN_CONSTRUCT, 0) + +// Creates a class. Top of stack is the superclass. Below that is a string for +// the name of the class. Byte [arg] is the number of fields in the class. +OPCODE(CLASS, -1) + +// Ends a class. +// Atm the stack contains the class and the ClassAttributes (or null). +OPCODE(END_CLASS, -2) + +// Creates a foreign class. Top of stack is the superclass. Below that is a +// string for the name of the class. +OPCODE(FOREIGN_CLASS, -1) + +// Define a method for symbol [arg]. The class receiving the method is popped +// off the stack, then the function defining the body is popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_INSTANCE, -2) + +// Define a method for symbol [arg]. The class whose metaclass will receive +// the method is popped off the stack, then the function defining the body is +// popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_STATIC, -2) + +// This is executed at the end of the module's body. Pushes NULL onto the stack +// as the "return value" of the import statement and stores the module as the +// most recently imported one. +OPCODE(END_MODULE, 1) + +// Import a module whose name is the string stored at [arg] in the constant +// table. +// +// Pushes null onto the stack so that the fiber for the imported module can +// replace that with a dummy value when it returns. (Fibers always return a +// value when resuming a caller.) +OPCODE(IMPORT_MODULE, 1) + +// Import a variable from the most recently imported module. The name of the +// variable to import is at [arg] in the constant table. Pushes the loaded +// variable's value. +OPCODE(IMPORT_VARIABLE, 1) + +// This pseudo-instruction indicates the end of the bytecode. It should +// always be preceded by a `CODE_RETURN`, so is never actually executed. +OPCODE(END, 0) +// End file "wren_opcodes.h" + #undef OPCODE + }; + + #define INTERPRET_LOOP DISPATCH(); + #define CASE_CODE(name) code_##name + + #define DISPATCH() \ + do \ + { \ + DEBUG_TRACE_INSTRUCTIONS(); \ + goto *dispatchTable[instruction = (Code)READ_BYTE()]; \ + } while (false) + + #else + + #define INTERPRET_LOOP \ + loop: \ + DEBUG_TRACE_INSTRUCTIONS(); \ + switch (instruction = (Code)READ_BYTE()) + + #define CASE_CODE(name) case CODE_##name + #define DISPATCH() goto loop + + #endif + + LOAD_FRAME(); + + Code instruction; + INTERPRET_LOOP + { + CASE_CODE(LOAD_LOCAL_0): + CASE_CODE(LOAD_LOCAL_1): + CASE_CODE(LOAD_LOCAL_2): + CASE_CODE(LOAD_LOCAL_3): + CASE_CODE(LOAD_LOCAL_4): + CASE_CODE(LOAD_LOCAL_5): + CASE_CODE(LOAD_LOCAL_6): + CASE_CODE(LOAD_LOCAL_7): + CASE_CODE(LOAD_LOCAL_8): + PUSH(stackStart[instruction - CODE_LOAD_LOCAL_0]); + DISPATCH(); + + CASE_CODE(LOAD_LOCAL): + PUSH(stackStart[READ_BYTE()]); + DISPATCH(); + + CASE_CODE(LOAD_FIELD_THIS): + { + uint8_t field = READ_BYTE(); + Value receiver = stackStart[0]; + ASSERT(IS_INSTANCE(receiver), "Receiver should be instance."); + ObjInstance* instance = AS_INSTANCE(receiver); + ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field."); + PUSH(instance->fields[field]); + DISPATCH(); + } + + CASE_CODE(POP): DROP(); DISPATCH(); + CASE_CODE(NULL): PUSH(NULL_VAL); DISPATCH(); + CASE_CODE(FALSE): PUSH(FALSE_VAL); DISPATCH(); + CASE_CODE(TRUE): PUSH(TRUE_VAL); DISPATCH(); + + CASE_CODE(STORE_LOCAL): + stackStart[READ_BYTE()] = PEEK(); + DISPATCH(); + + CASE_CODE(CONSTANT): + PUSH(fn->constants.data[READ_SHORT()]); + DISPATCH(); + + { + // The opcodes for doing method and superclass calls share a lot of code. + // However, doing an if() test in the middle of the instruction sequence + // to handle the bit that is special to super calls makes the non-super + // call path noticeably slower. + // + // Instead, we do this old school using an explicit goto to share code for + // everything at the tail end of the call-handling code that is the same + // between normal and superclass calls. + int numArgs; + int symbol; + + Value* args; + ObjClass* classObj; + + Method* method; + + CASE_CODE(CALL_0): + CASE_CODE(CALL_1): + CASE_CODE(CALL_2): + CASE_CODE(CALL_3): + CASE_CODE(CALL_4): + CASE_CODE(CALL_5): + CASE_CODE(CALL_6): + CASE_CODE(CALL_7): + CASE_CODE(CALL_8): + CASE_CODE(CALL_9): + CASE_CODE(CALL_10): + CASE_CODE(CALL_11): + CASE_CODE(CALL_12): + CASE_CODE(CALL_13): + CASE_CODE(CALL_14): + CASE_CODE(CALL_15): + CASE_CODE(CALL_16): + // Add one for the implicit receiver argument. + numArgs = instruction - CODE_CALL_0 + 1; + symbol = READ_SHORT(); + + // The receiver is the first argument. + args = fiber->stackTop - numArgs; + classObj = wrenGetClassInline(vm, args[0]); + goto completeCall; + + CASE_CODE(SUPER_0): + CASE_CODE(SUPER_1): + CASE_CODE(SUPER_2): + CASE_CODE(SUPER_3): + CASE_CODE(SUPER_4): + CASE_CODE(SUPER_5): + CASE_CODE(SUPER_6): + CASE_CODE(SUPER_7): + CASE_CODE(SUPER_8): + CASE_CODE(SUPER_9): + CASE_CODE(SUPER_10): + CASE_CODE(SUPER_11): + CASE_CODE(SUPER_12): + CASE_CODE(SUPER_13): + CASE_CODE(SUPER_14): + CASE_CODE(SUPER_15): + CASE_CODE(SUPER_16): + // Add one for the implicit receiver argument. + numArgs = instruction - CODE_SUPER_0 + 1; + symbol = READ_SHORT(); + + // The receiver is the first argument. + args = fiber->stackTop - numArgs; + + // The superclass is stored in a constant. + classObj = AS_CLASS(fn->constants.data[READ_SHORT()]); + goto completeCall; + + completeCall: + // If the class's method table doesn't include the symbol, bail. + if (symbol >= classObj->methods.count || + (method = &classObj->methods.data[symbol])->type == METHOD_NONE) + { + methodNotFound(vm, classObj, symbol); + RUNTIME_ERROR(); + } + + switch (method->type) + { + case METHOD_PRIMITIVE: + if (method->as.primitive(vm, args)) + { + // The result is now in the first arg slot. Discard the other + // stack slots. + fiber->stackTop -= numArgs - 1; + } else { + // An error, fiber switch, or call frame change occurred. + STORE_FRAME(); + + // If we don't have a fiber to switch to, stop interpreting. + fiber = vm->fiber; + if (fiber == NULL) return WREN_RESULT_SUCCESS; + if (wrenHasError(fiber)) RUNTIME_ERROR(); + LOAD_FRAME(); + } + break; + + case METHOD_FUNCTION_CALL: + if (!checkArity(vm, args[0], numArgs)) { + RUNTIME_ERROR(); + break; + } + + STORE_FRAME(); + method->as.primitive(vm, args); + LOAD_FRAME(); + break; + + case METHOD_FOREIGN: + callForeign(vm, fiber, method->as.foreign, numArgs); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + break; + + case METHOD_BLOCK: + STORE_FRAME(); + wrenCallFunction(vm, fiber, (ObjClosure*)method->as.closure, numArgs); + LOAD_FRAME(); + break; + + case METHOD_NONE: + UNREACHABLE(); + break; + } + DISPATCH(); + } + + CASE_CODE(LOAD_UPVALUE): + { + ObjUpvalue** upvalues = frame->closure->upvalues; + PUSH(*upvalues[READ_BYTE()]->value); + DISPATCH(); + } + + CASE_CODE(STORE_UPVALUE): + { + ObjUpvalue** upvalues = frame->closure->upvalues; + *upvalues[READ_BYTE()]->value = PEEK(); + DISPATCH(); + } + + CASE_CODE(LOAD_MODULE_VAR): + PUSH(fn->module->variables.data[READ_SHORT()]); + DISPATCH(); + + CASE_CODE(STORE_MODULE_VAR): + fn->module->variables.data[READ_SHORT()] = PEEK(); + DISPATCH(); + + CASE_CODE(STORE_FIELD_THIS): + { + uint8_t field = READ_BYTE(); + Value receiver = stackStart[0]; + ASSERT(IS_INSTANCE(receiver), "Receiver should be instance."); + ObjInstance* instance = AS_INSTANCE(receiver); + ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field."); + instance->fields[field] = PEEK(); + DISPATCH(); + } + + CASE_CODE(LOAD_FIELD): + { + uint8_t field = READ_BYTE(); + Value receiver = POP(); + ASSERT(IS_INSTANCE(receiver), "Receiver should be instance."); + ObjInstance* instance = AS_INSTANCE(receiver); + ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field."); + PUSH(instance->fields[field]); + DISPATCH(); + } + + CASE_CODE(STORE_FIELD): + { + uint8_t field = READ_BYTE(); + Value receiver = POP(); + ASSERT(IS_INSTANCE(receiver), "Receiver should be instance."); + ObjInstance* instance = AS_INSTANCE(receiver); + ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field."); + instance->fields[field] = PEEK(); + DISPATCH(); + } + + CASE_CODE(JUMP): + { + uint16_t offset = READ_SHORT(); + ip += offset; + DISPATCH(); + } + + CASE_CODE(LOOP): + { + // Jump back to the top of the loop. + uint16_t offset = READ_SHORT(); + ip -= offset; + DISPATCH(); + } + + CASE_CODE(JUMP_IF): + { + uint16_t offset = READ_SHORT(); + Value condition = POP(); + + if (wrenIsFalsyValue(condition)) ip += offset; + DISPATCH(); + } + + CASE_CODE(AND): + { + uint16_t offset = READ_SHORT(); + Value condition = PEEK(); + + if (wrenIsFalsyValue(condition)) + { + // Short-circuit the right hand side. + ip += offset; + } + else + { + // Discard the condition and evaluate the right hand side. + DROP(); + } + DISPATCH(); + } + + CASE_CODE(OR): + { + uint16_t offset = READ_SHORT(); + Value condition = PEEK(); + + if (wrenIsFalsyValue(condition)) + { + // Discard the condition and evaluate the right hand side. + DROP(); + } + else + { + // Short-circuit the right hand side. + ip += offset; + } + DISPATCH(); + } + + CASE_CODE(CLOSE_UPVALUE): + // Close the upvalue for the local if we have one. + closeUpvalues(fiber, fiber->stackTop - 1); + DROP(); + DISPATCH(); + + CASE_CODE(RETURN): + { + Value result = POP(); + fiber->numFrames--; + + // Close any upvalues still in scope. + closeUpvalues(fiber, stackStart); + + // If the fiber is complete, end it. + if (fiber->numFrames == 0) + { + // See if there's another fiber to return to. If not, we're done. + if (fiber->caller == NULL) + { + // Store the final result value at the beginning of the stack so the + // C API can get it. + fiber->stack[0] = result; + fiber->stackTop = fiber->stack + 1; + return WREN_RESULT_SUCCESS; + } + + ObjFiber* resumingFiber = fiber->caller; + fiber->caller = NULL; + fiber = resumingFiber; + vm->fiber = resumingFiber; + + // Store the result in the resuming fiber. + fiber->stackTop[-1] = result; + } + else + { + // Store the result of the block in the first slot, which is where the + // caller expects it. + stackStart[0] = result; + + // Discard the stack slots for the call frame (leaving one slot for the + // result). + fiber->stackTop = frame->stackStart + 1; + } + + LOAD_FRAME(); + DISPATCH(); + } + + CASE_CODE(CONSTRUCT): + ASSERT(IS_CLASS(stackStart[0]), "'this' should be a class."); + stackStart[0] = wrenNewInstance(vm, AS_CLASS(stackStart[0])); + DISPATCH(); + + CASE_CODE(FOREIGN_CONSTRUCT): + ASSERT(IS_CLASS(stackStart[0]), "'this' should be a class."); + createForeign(vm, fiber, stackStart); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DISPATCH(); + + CASE_CODE(CLOSURE): + { + // Create the closure and push it on the stack before creating upvalues + // so that it doesn't get collected. + ObjFn* function = AS_FN(fn->constants.data[READ_SHORT()]); + ObjClosure* closure = wrenNewClosure(vm, function); + PUSH(OBJ_VAL(closure)); + + // Capture upvalues, if any. + for (int i = 0; i < function->numUpvalues; i++) + { + uint8_t isLocal = READ_BYTE(); + uint8_t index = READ_BYTE(); + if (isLocal) + { + // Make an new upvalue to close over the parent's local variable. + closure->upvalues[i] = captureUpvalue(vm, fiber, + frame->stackStart + index); + } + else + { + // Use the same upvalue as the current call frame. + closure->upvalues[i] = frame->closure->upvalues[index]; + } + } + DISPATCH(); + } + + CASE_CODE(END_CLASS): + { + endClass(vm); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DISPATCH(); + } + + CASE_CODE(CLASS): + { + createClass(vm, READ_BYTE(), NULL); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DISPATCH(); + } + + CASE_CODE(FOREIGN_CLASS): + { + createClass(vm, -1, fn->module); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DISPATCH(); + } + + CASE_CODE(METHOD_INSTANCE): + CASE_CODE(METHOD_STATIC): + { + uint16_t symbol = READ_SHORT(); + ObjClass* classObj = AS_CLASS(PEEK()); + Value method = PEEK2(); + bindMethod(vm, instruction, symbol, fn->module, classObj, method); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DROP(); + DROP(); + DISPATCH(); + } + + CASE_CODE(END_MODULE): + { + vm->lastModule = fn->module; + PUSH(NULL_VAL); + DISPATCH(); + } + + CASE_CODE(IMPORT_MODULE): + { + // Make a slot on the stack for the module's fiber to place the return + // value. It will be popped after this fiber is resumed. Store the + // imported module's closure in the slot in case a GC happens when + // invoking the closure. + PUSH(importModule(vm, fn->constants.data[READ_SHORT()])); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + + // If we get a closure, call it to execute the module body. + if (IS_CLOSURE(PEEK())) + { + STORE_FRAME(); + ObjClosure* closure = AS_CLOSURE(PEEK()); + wrenCallFunction(vm, fiber, closure, 1); + LOAD_FRAME(); + } + else + { + // The module has already been loaded. Remember it so we can import + // variables from it if needed. + vm->lastModule = AS_MODULE(PEEK()); + } + + DISPATCH(); + } + + CASE_CODE(IMPORT_VARIABLE): + { + Value variable = fn->constants.data[READ_SHORT()]; + ASSERT(vm->lastModule != NULL, "Should have already imported module."); + Value result = getModuleVariable(vm, vm->lastModule, variable); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + + PUSH(result); + DISPATCH(); + } + + CASE_CODE(END): + // A CODE_END should always be preceded by a CODE_RETURN. If we get here, + // the compiler generated wrong code. + UNREACHABLE(); + } + + // We should only exit this function from an explicit return from CODE_RETURN + // or a runtime error. + UNREACHABLE(); + return WREN_RESULT_RUNTIME_ERROR; + + #undef READ_BYTE + #undef READ_SHORT +} + +WrenHandle* wrenMakeCallHandle(WrenVM* vm, const char* signature) +{ + ASSERT(signature != NULL, "Signature cannot be NULL."); + + int signatureLength = (int)strlen(signature); + ASSERT(signatureLength > 0, "Signature cannot be empty."); + + // Count the number parameters the method expects. + int numParams = 0; + if (signature[signatureLength - 1] == ')') + { + for (int i = signatureLength - 1; i > 0 && signature[i] != '('; i--) + { + if (signature[i] == '_') numParams++; + } + } + + // Count subscript arguments. + if (signature[0] == '[') + { + for (int i = 0; i < signatureLength && signature[i] != ']'; i++) + { + if (signature[i] == '_') numParams++; + } + } + + // Add the signatue to the method table. + int method = wrenSymbolTableEnsure(vm, &vm->methodNames, + signature, signatureLength); + + // Create a little stub function that assumes the arguments are on the stack + // and calls the method. + ObjFn* fn = wrenNewFunction(vm, NULL, numParams + 1); + + // Wrap the function in a closure and then in a handle. Do this here so it + // doesn't get collected as we fill it in. + WrenHandle* value = wrenMakeHandle(vm, OBJ_VAL(fn)); + value->value = OBJ_VAL(wrenNewClosure(vm, fn)); + + wrenByteBufferWrite(vm, &fn->code, (uint8_t)(CODE_CALL_0 + numParams)); + wrenByteBufferWrite(vm, &fn->code, (method >> 8) & 0xff); + wrenByteBufferWrite(vm, &fn->code, method & 0xff); + wrenByteBufferWrite(vm, &fn->code, CODE_RETURN); + wrenByteBufferWrite(vm, &fn->code, CODE_END); + wrenIntBufferFill(vm, &fn->debug->sourceLines, 0, 5); + wrenFunctionBindName(vm, fn, signature, signatureLength); + + return value; +} + +WrenInterpretResult wrenCall(WrenVM* vm, WrenHandle* method) +{ + ASSERT(method != NULL, "Method cannot be NULL."); + ASSERT(IS_CLOSURE(method->value), "Method must be a method handle."); + ASSERT(vm->fiber != NULL, "Must set up arguments for call first."); + ASSERT(vm->apiStack != NULL, "Must set up arguments for call first."); + ASSERT(vm->fiber->numFrames == 0, "Can not call from a foreign method."); + + ObjClosure* closure = AS_CLOSURE(method->value); + + ASSERT(vm->fiber->stackTop - vm->fiber->stack >= closure->fn->arity, + "Stack must have enough arguments for method."); + + // Clear the API stack. Now that wrenCall() has control, we no longer need + // it. We use this being non-null to tell if re-entrant calls to foreign + // methods are happening, so it's important to clear it out now so that you + // can call foreign methods from within calls to wrenCall(). + vm->apiStack = NULL; + + // Discard any extra temporary slots. We take for granted that the stub + // function has exactly one slot for each argument. + vm->fiber->stackTop = &vm->fiber->stack[closure->fn->maxSlots]; + + wrenCallFunction(vm, vm->fiber, closure, 0); + WrenInterpretResult result = runInterpreter(vm, vm->fiber); + + // If the call didn't abort, then set up the API stack to point to the + // beginning of the stack so the host can access the call's return value. + if (vm->fiber != NULL) vm->apiStack = vm->fiber->stack; + + return result; +} + +WrenHandle* wrenMakeHandle(WrenVM* vm, Value value) +{ + if (IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + // Make a handle for it. + WrenHandle* handle = ALLOCATE(vm, WrenHandle); + handle->value = value; + + if (IS_OBJ(value)) wrenPopRoot(vm); + + // Add it to the front of the linked list of handles. + if (vm->handles != NULL) vm->handles->prev = handle; + handle->prev = NULL; + handle->next = vm->handles; + vm->handles = handle; + + return handle; +} + +void wrenReleaseHandle(WrenVM* vm, WrenHandle* handle) +{ + ASSERT(handle != NULL, "Handle cannot be NULL."); + + // Update the VM's head pointer if we're releasing the first handle. + if (vm->handles == handle) vm->handles = handle->next; + + // Unlink it from the list. + if (handle->prev != NULL) handle->prev->next = handle->next; + if (handle->next != NULL) handle->next->prev = handle->prev; + + // Clear it out. This isn't strictly necessary since we're going to free it, + // but it makes for easier debugging. + handle->prev = NULL; + handle->next = NULL; + handle->value = NULL_VAL; + DEALLOCATE(vm, handle); +} + +WrenInterpretResult wrenInterpret(WrenVM* vm, const char* module, + const char* source) +{ + ObjClosure* closure = wrenCompileSource(vm, module, source, false, true); + if (closure == NULL) return WREN_RESULT_COMPILE_ERROR; + + wrenPushRoot(vm, (Obj*)closure); + ObjFiber* fiber = wrenNewFiber(vm, closure); + wrenPopRoot(vm); // closure. + vm->apiStack = NULL; + + return runInterpreter(vm, fiber); +} + +ObjClosure* wrenCompileSource(WrenVM* vm, const char* module, const char* source, + bool isExpression, bool printErrors) +{ + Value nameValue = NULL_VAL; + if (module != NULL) + { + nameValue = wrenNewString(vm, module); + wrenPushRoot(vm, AS_OBJ(nameValue)); + } + + ObjClosure* closure = compileInModule(vm, nameValue, source, + isExpression, printErrors); + + if (module != NULL) wrenPopRoot(vm); // nameValue. + return closure; +} + +Value wrenGetModuleVariable(WrenVM* vm, Value moduleName, Value variableName) +{ + ObjModule* module = getModule(vm, moduleName); + if (module == NULL) + { + vm->fiber->error = wrenStringFormat(vm, "Module '@' is not loaded.", + moduleName); + return NULL_VAL; + } + + return getModuleVariable(vm, module, variableName); +} + +Value wrenFindVariable(WrenVM* vm, ObjModule* module, const char* name) +{ + int symbol = wrenSymbolTableFind(&module->variableNames, name, strlen(name)); + return module->variables.data[symbol]; +} + +int wrenDeclareVariable(WrenVM* vm, ObjModule* module, const char* name, + size_t length, int line) +{ + if (module->variables.count == MAX_MODULE_VARS) return -2; + + // Implicitly defined variables get a "value" that is the line where the + // variable is first used. We'll use that later to report an error on the + // right line. + wrenValueBufferWrite(vm, &module->variables, NUM_VAL(line)); + return wrenSymbolTableAdd(vm, &module->variableNames, name, length); +} + +int wrenDefineVariable(WrenVM* vm, ObjModule* module, const char* name, + size_t length, Value value, int* line) +{ + if (module->variables.count == MAX_MODULE_VARS) return -2; + + if (IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + // See if the variable is already explicitly or implicitly declared. + int symbol = wrenSymbolTableFind(&module->variableNames, name, length); + + if (symbol == -1) + { + // Brand new variable. + symbol = wrenSymbolTableAdd(vm, &module->variableNames, name, length); + wrenValueBufferWrite(vm, &module->variables, value); + } + else if (IS_NUM(module->variables.data[symbol])) + { + // An implicitly declared variable's value will always be a number. + // Now we have a real definition. + if(line) *line = (int)AS_NUM(module->variables.data[symbol]); + module->variables.data[symbol] = value; + + // If this was a localname we want to error if it was + // referenced before this definition. + if (wrenIsLocalName(name)) symbol = -3; + } + else + { + // Already explicitly declared. + symbol = -1; + } + + if (IS_OBJ(value)) wrenPopRoot(vm); + + return symbol; +} + +// TODO: Inline? +void wrenPushRoot(WrenVM* vm, Obj* obj) +{ + ASSERT(obj != NULL, "Can't root NULL."); + ASSERT(vm->numTempRoots < WREN_MAX_TEMP_ROOTS, "Too many temporary roots."); + + vm->tempRoots[vm->numTempRoots++] = obj; +} + +void wrenPopRoot(WrenVM* vm) +{ + ASSERT(vm->numTempRoots > 0, "No temporary roots to release."); + vm->numTempRoots--; +} + +int wrenGetSlotCount(WrenVM* vm) +{ + if (vm->apiStack == NULL) return 0; + + return (int)(vm->fiber->stackTop - vm->apiStack); +} + +void wrenEnsureSlots(WrenVM* vm, int numSlots) +{ + // If we don't have a fiber accessible, create one for the API to use. + if (vm->apiStack == NULL) + { + vm->fiber = wrenNewFiber(vm, NULL); + vm->apiStack = vm->fiber->stack; + } + + int currentSize = (int)(vm->fiber->stackTop - vm->apiStack); + if (currentSize >= numSlots) return; + + // Grow the stack if needed. + int needed = (int)(vm->apiStack - vm->fiber->stack) + numSlots; + wrenEnsureStack(vm, vm->fiber, needed); + + vm->fiber->stackTop = vm->apiStack + numSlots; +} + +// Ensures that [slot] is a valid index into the API's stack of slots. +static void validateApiSlot(WrenVM* vm, int slot) +{ + ASSERT(slot >= 0, "Slot cannot be negative."); + ASSERT(slot < wrenGetSlotCount(vm), "Not that many slots."); +} + +// Gets the type of the object in [slot]. +WrenType wrenGetSlotType(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + if (IS_BOOL(vm->apiStack[slot])) return WREN_TYPE_BOOL; + if (IS_NUM(vm->apiStack[slot])) return WREN_TYPE_NUM; + if (IS_FOREIGN(vm->apiStack[slot])) return WREN_TYPE_FOREIGN; + if (IS_LIST(vm->apiStack[slot])) return WREN_TYPE_LIST; + if (IS_MAP(vm->apiStack[slot])) return WREN_TYPE_MAP; + if (IS_NULL(vm->apiStack[slot])) return WREN_TYPE_NULL; + if (IS_STRING(vm->apiStack[slot])) return WREN_TYPE_STRING; + + return WREN_TYPE_UNKNOWN; +} + +bool wrenGetSlotBool(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_BOOL(vm->apiStack[slot]), "Slot must hold a bool."); + + return AS_BOOL(vm->apiStack[slot]); +} + +const char* wrenGetSlotBytes(WrenVM* vm, int slot, int* length) +{ + validateApiSlot(vm, slot); + ASSERT(IS_STRING(vm->apiStack[slot]), "Slot must hold a string."); + + ObjString* string = AS_STRING(vm->apiStack[slot]); + *length = string->length; + return string->value; +} + +double wrenGetSlotDouble(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_NUM(vm->apiStack[slot]), "Slot must hold a number."); + + return AS_NUM(vm->apiStack[slot]); +} + +void* wrenGetSlotForeign(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_FOREIGN(vm->apiStack[slot]), + "Slot must hold a foreign instance."); + + return AS_FOREIGN(vm->apiStack[slot])->data; +} + +const char* wrenGetSlotString(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_STRING(vm->apiStack[slot]), "Slot must hold a string."); + + return AS_CSTRING(vm->apiStack[slot]); +} + +WrenHandle* wrenGetSlotHandle(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + return wrenMakeHandle(vm, vm->apiStack[slot]); +} + +// Stores [value] in [slot] in the foreign call stack. +static void setSlot(WrenVM* vm, int slot, Value value) +{ + validateApiSlot(vm, slot); + vm->apiStack[slot] = value; +} + +void wrenSetSlotBool(WrenVM* vm, int slot, bool value) +{ + setSlot(vm, slot, BOOL_VAL(value)); +} + +void wrenSetSlotBytes(WrenVM* vm, int slot, const char* bytes, size_t length) +{ + ASSERT(bytes != NULL, "Byte array cannot be NULL."); + setSlot(vm, slot, wrenNewStringLength(vm, bytes, length)); +} + +void wrenSetSlotDouble(WrenVM* vm, int slot, double value) +{ + setSlot(vm, slot, NUM_VAL(value)); +} + +void* wrenSetSlotNewForeign(WrenVM* vm, int slot, int classSlot, size_t size) +{ + validateApiSlot(vm, slot); + validateApiSlot(vm, classSlot); + ASSERT(IS_CLASS(vm->apiStack[classSlot]), "Slot must hold a class."); + + ObjClass* classObj = AS_CLASS(vm->apiStack[classSlot]); + ASSERT(classObj->numFields == -1, "Class must be a foreign class."); + + ObjForeign* foreign = wrenNewForeign(vm, classObj, size); + vm->apiStack[slot] = OBJ_VAL(foreign); + + return (void*)foreign->data; +} + +void wrenSetSlotNewList(WrenVM* vm, int slot) +{ + setSlot(vm, slot, OBJ_VAL(wrenNewList(vm, 0))); +} + +void wrenSetSlotNewMap(WrenVM* vm, int slot) +{ + setSlot(vm, slot, OBJ_VAL(wrenNewMap(vm))); +} + +void wrenSetSlotNull(WrenVM* vm, int slot) +{ + setSlot(vm, slot, NULL_VAL); +} + +void wrenSetSlotString(WrenVM* vm, int slot, const char* text) +{ + ASSERT(text != NULL, "String cannot be NULL."); + + setSlot(vm, slot, wrenNewString(vm, text)); +} + +void wrenSetSlotHandle(WrenVM* vm, int slot, WrenHandle* handle) +{ + ASSERT(handle != NULL, "Handle cannot be NULL."); + + setSlot(vm, slot, handle->value); +} + +int wrenGetListCount(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_LIST(vm->apiStack[slot]), "Slot must hold a list."); + + ValueBuffer elements = AS_LIST(vm->apiStack[slot])->elements; + return elements.count; +} + +void wrenGetListElement(WrenVM* vm, int listSlot, int index, int elementSlot) +{ + validateApiSlot(vm, listSlot); + validateApiSlot(vm, elementSlot); + ASSERT(IS_LIST(vm->apiStack[listSlot]), "Slot must hold a list."); + + ValueBuffer elements = AS_LIST(vm->apiStack[listSlot])->elements; + + uint32_t usedIndex = wrenValidateIndex(elements.count, index); + ASSERT(usedIndex != UINT32_MAX, "Index out of bounds."); + + vm->apiStack[elementSlot] = elements.data[usedIndex]; +} + +void wrenSetListElement(WrenVM* vm, int listSlot, int index, int elementSlot) +{ + validateApiSlot(vm, listSlot); + validateApiSlot(vm, elementSlot); + ASSERT(IS_LIST(vm->apiStack[listSlot]), "Slot must hold a list."); + + ObjList* list = AS_LIST(vm->apiStack[listSlot]); + + uint32_t usedIndex = wrenValidateIndex(list->elements.count, index); + ASSERT(usedIndex != UINT32_MAX, "Index out of bounds."); + + list->elements.data[usedIndex] = vm->apiStack[elementSlot]; +} + +void wrenInsertInList(WrenVM* vm, int listSlot, int index, int elementSlot) +{ + validateApiSlot(vm, listSlot); + validateApiSlot(vm, elementSlot); + ASSERT(IS_LIST(vm->apiStack[listSlot]), "Must insert into a list."); + + ObjList* list = AS_LIST(vm->apiStack[listSlot]); + + // Negative indices count from the end. + // We don't use wrenValidateIndex here because insert allows 1 past the end. + if (index < 0) index = list->elements.count + 1 + index; + + ASSERT(index <= list->elements.count, "Index out of bounds."); + + wrenListInsert(vm, list, vm->apiStack[elementSlot], index); +} + +int wrenGetMapCount(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_MAP(vm->apiStack[slot]), "Slot must hold a map."); + + ObjMap* map = AS_MAP(vm->apiStack[slot]); + return map->count; +} + +bool wrenGetMapContainsKey(WrenVM* vm, int mapSlot, int keySlot) +{ + validateApiSlot(vm, mapSlot); + validateApiSlot(vm, keySlot); + ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Slot must hold a map."); + + Value key = vm->apiStack[keySlot]; + ASSERT(wrenMapIsValidKey(key), "Key must be a value type"); + if (!validateKey(vm, key)) return false; + + ObjMap* map = AS_MAP(vm->apiStack[mapSlot]); + Value value = wrenMapGet(map, key); + + return !IS_UNDEFINED(value); +} + +void wrenGetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot) +{ + validateApiSlot(vm, mapSlot); + validateApiSlot(vm, keySlot); + validateApiSlot(vm, valueSlot); + ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Slot must hold a map."); + + ObjMap* map = AS_MAP(vm->apiStack[mapSlot]); + Value value = wrenMapGet(map, vm->apiStack[keySlot]); + if (IS_UNDEFINED(value)) { + value = NULL_VAL; + } + + vm->apiStack[valueSlot] = value; +} + +void wrenSetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot) +{ + validateApiSlot(vm, mapSlot); + validateApiSlot(vm, keySlot); + validateApiSlot(vm, valueSlot); + ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Must insert into a map."); + + Value key = vm->apiStack[keySlot]; + ASSERT(wrenMapIsValidKey(key), "Key must be a value type"); + + if (!validateKey(vm, key)) { + return; + } + + Value value = vm->apiStack[valueSlot]; + ObjMap* map = AS_MAP(vm->apiStack[mapSlot]); + + wrenMapSet(vm, map, key, value); +} + +void wrenRemoveMapValue(WrenVM* vm, int mapSlot, int keySlot, + int removedValueSlot) +{ + validateApiSlot(vm, mapSlot); + validateApiSlot(vm, keySlot); + ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Slot must hold a map."); + + Value key = vm->apiStack[keySlot]; + if (!validateKey(vm, key)) { + return; + } + + ObjMap* map = AS_MAP(vm->apiStack[mapSlot]); + Value removed = wrenMapRemoveKey(vm, map, key); + setSlot(vm, removedValueSlot, removed); +} + +void wrenGetVariable(WrenVM* vm, const char* module, const char* name, + int slot) +{ + ASSERT(module != NULL, "Module cannot be NULL."); + ASSERT(name != NULL, "Variable name cannot be NULL."); + + Value moduleName = wrenStringFormat(vm, "$", module); + wrenPushRoot(vm, AS_OBJ(moduleName)); + + ObjModule* moduleObj = getModule(vm, moduleName); + ASSERT(moduleObj != NULL, "Could not find module."); + + wrenPopRoot(vm); // moduleName. + + int variableSlot = wrenSymbolTableFind(&moduleObj->variableNames, + name, strlen(name)); + ASSERT(variableSlot != -1, "Could not find variable."); + + setSlot(vm, slot, moduleObj->variables.data[variableSlot]); +} + +bool wrenHasVariable(WrenVM* vm, const char* module, const char* name) +{ + ASSERT(module != NULL, "Module cannot be NULL."); + ASSERT(name != NULL, "Variable name cannot be NULL."); + + Value moduleName = wrenStringFormat(vm, "$", module); + wrenPushRoot(vm, AS_OBJ(moduleName)); + + //We don't use wrenHasModule since we want to use the module object. + ObjModule* moduleObj = getModule(vm, moduleName); + ASSERT(moduleObj != NULL, "Could not find module."); + + wrenPopRoot(vm); // moduleName. + + int variableSlot = wrenSymbolTableFind(&moduleObj->variableNames, + name, strlen(name)); + + return variableSlot != -1; +} + +bool wrenHasModule(WrenVM* vm, const char* module) +{ + ASSERT(module != NULL, "Module cannot be NULL."); + + Value moduleName = wrenStringFormat(vm, "$", module); + wrenPushRoot(vm, AS_OBJ(moduleName)); + + ObjModule* moduleObj = getModule(vm, moduleName); + + wrenPopRoot(vm); // moduleName. + + return moduleObj != NULL; +} + +void wrenAbortFiber(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + vm->fiber->error = vm->apiStack[slot]; +} + +void* wrenGetUserData(WrenVM* vm) +{ + return vm->config.userData; +} + +void wrenSetUserData(WrenVM* vm, void* userData) +{ + vm->config.userData = userData; +} +// End file "wren_vm.c" +// Begin file "wren_opt_random.c" + +#if WREN_OPT_RANDOM + +#include +#include + + +// Begin file "wren_opt_random.wren.inc" +// Generated automatically from src/optional/wren_opt_random.wren. Do not edit. +static const char* randomModuleSource = +"foreign class Random {\n" +" construct new() {\n" +" seed_()\n" +" }\n" +"\n" +" construct new(seed) {\n" +" if (seed is Num) {\n" +" seed_(seed)\n" +" } else if (seed is Sequence) {\n" +" if (seed.isEmpty) Fiber.abort(\"Sequence cannot be empty.\")\n" +"\n" +" // TODO: Empty sequence.\n" +" var seeds = []\n" +" for (element in seed) {\n" +" if (!(element is Num)) Fiber.abort(\"Sequence elements must all be numbers.\")\n" +"\n" +" seeds.add(element)\n" +" if (seeds.count == 16) break\n" +" }\n" +"\n" +" // Cycle the values to fill in any missing slots.\n" +" var i = 0\n" +" while (seeds.count < 16) {\n" +" seeds.add(seeds[i])\n" +" i = i + 1\n" +" }\n" +"\n" +" seed_(\n" +" seeds[0], seeds[1], seeds[2], seeds[3],\n" +" seeds[4], seeds[5], seeds[6], seeds[7],\n" +" seeds[8], seeds[9], seeds[10], seeds[11],\n" +" seeds[12], seeds[13], seeds[14], seeds[15])\n" +" } else {\n" +" Fiber.abort(\"Seed must be a number or a sequence of numbers.\")\n" +" }\n" +" }\n" +"\n" +" foreign seed_()\n" +" foreign seed_(seed)\n" +" foreign seed_(n1, n2, n3, n4, n5, n6, n7, n8, n9, n10, n11, n12, n13, n14, n15, n16)\n" +"\n" +" foreign float()\n" +" float(end) { float() * end }\n" +" float(start, end) { float() * (end - start) + start }\n" +"\n" +" foreign int()\n" +" int(end) { (float() * end).floor }\n" +" int(start, end) { (float() * (end - start)).floor + start }\n" +"\n" +" sample(list) {\n" +" if (list.count == 0) Fiber.abort(\"Not enough elements to sample.\")\n" +" return list[int(list.count)]\n" +" }\n" +" sample(list, count) {\n" +" if (count > list.count) Fiber.abort(\"Not enough elements to sample.\")\n" +"\n" +" var result = []\n" +"\n" +" // The algorithm described in \"Programming pearls: a sample of brilliance\".\n" +" // Use a hash map for sample sizes less than 1/4 of the population size and\n" +" // an array of booleans for larger samples. This simple heuristic improves\n" +" // performance for large sample sizes as well as reduces memory usage.\n" +" if (count * 4 < list.count) {\n" +" var picked = {}\n" +" for (i in list.count - count...list.count) {\n" +" var index = int(i + 1)\n" +" if (picked.containsKey(index)) index = i\n" +" picked[index] = true\n" +" result.add(list[index])\n" +" }\n" +" } else {\n" +" var picked = List.filled(list.count, false)\n" +" for (i in list.count - count...list.count) {\n" +" var index = int(i + 1)\n" +" if (picked[index]) index = i\n" +" picked[index] = true\n" +" result.add(list[index])\n" +" }\n" +" }\n" +"\n" +" return result\n" +" }\n" +"\n" +" shuffle(list) {\n" +" if (list.isEmpty) return\n" +"\n" +" // Fisher-Yates shuffle.\n" +" for (i in 0...list.count - 1) {\n" +" var from = int(i, list.count)\n" +" var temp = list[from]\n" +" list[from] = list[i]\n" +" list[i] = temp\n" +" }\n" +" }\n" +"}\n"; +// End file "wren_opt_random.wren.inc" + +// Implements the well equidistributed long-period linear PRNG (WELL512a). +// +// https://en.wikipedia.org/wiki/Well_equidistributed_long-period_linear +typedef struct +{ + uint32_t state[16]; + uint32_t index; +} Well512; + +// Code from: http://www.lomont.org/Math/Papers/2008/Lomont_PRNG_2008.pdf +static uint32_t advanceState(Well512* well) +{ + uint32_t a, b, c, d; + a = well->state[well->index]; + c = well->state[(well->index + 13) & 15]; + b = a ^ c ^ (a << 16) ^ (c << 15); + c = well->state[(well->index + 9) & 15]; + c ^= (c >> 11); + a = well->state[well->index] = b ^ c; + d = a ^ ((a << 5) & 0xda442d24U); + + well->index = (well->index + 15) & 15; + a = well->state[well->index]; + well->state[well->index] = a ^ b ^ d ^ (a << 2) ^ (b << 18) ^ (c << 28); + return well->state[well->index]; +} + +static void randomAllocate(WrenVM* vm) +{ + Well512* well = (Well512*)wrenSetSlotNewForeign(vm, 0, 0, sizeof(Well512)); + well->index = 0; +} + +static void randomSeed0(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + srand((uint32_t)time(NULL)); + for (int i = 0; i < 16; i++) + { + well->state[i] = rand(); + } +} + +static void randomSeed1(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + srand((uint32_t)wrenGetSlotDouble(vm, 1)); + for (int i = 0; i < 16; i++) + { + well->state[i] = rand(); + } +} + +static void randomSeed16(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + for (int i = 0; i < 16; i++) + { + well->state[i] = (uint32_t)wrenGetSlotDouble(vm, i + 1); + } +} + +static void randomFloat(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + // A double has 53 bits of precision in its mantissa, and we'd like to take + // full advantage of that, so we need 53 bits of random source data. + + // First, start with 32 random bits, shifted to the left 21 bits. + double result = (double)advanceState(well) * (1 << 21); + + // Then add another 21 random bits. + result += (double)(advanceState(well) & ((1 << 21) - 1)); + + // Now we have a number from 0 - (2^53). Divide be the range to get a double + // from 0 to 1.0 (half-inclusive). + result /= 9007199254740992.0; + + wrenSetSlotDouble(vm, 0, result); +} + +static void randomInt0(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + wrenSetSlotDouble(vm, 0, (double)advanceState(well)); +} + +const char* wrenRandomSource() +{ + return randomModuleSource; +} + +WrenForeignClassMethods wrenRandomBindForeignClass(WrenVM* vm, + const char* module, + const char* className) +{ + ASSERT(strcmp(className, "Random") == 0, "Should be in Random class."); + WrenForeignClassMethods methods; + methods.allocate = randomAllocate; + methods.finalize = NULL; + return methods; +} + +WrenForeignMethodFn wrenRandomBindForeignMethod(WrenVM* vm, + const char* className, + bool isStatic, + const char* signature) +{ + ASSERT(strcmp(className, "Random") == 0, "Should be in Random class."); + + if (strcmp(signature, "") == 0) return randomAllocate; + if (strcmp(signature, "seed_()") == 0) return randomSeed0; + if (strcmp(signature, "seed_(_)") == 0) return randomSeed1; + + if (strcmp(signature, "seed_(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)") == 0) + { + return randomSeed16; + } + + if (strcmp(signature, "float()") == 0) return randomFloat; + if (strcmp(signature, "int()") == 0) return randomInt0; + + ASSERT(false, "Unknown method."); + return NULL; +} + +#endif +// End file "wren_opt_random.c" +// Begin file "wren_opt_meta.c" + +#if WREN_OPT_META + +#include + +// Begin file "wren_opt_meta.wren.inc" +// Generated automatically from src/optional/wren_opt_meta.wren. Do not edit. +static const char* metaModuleSource = +"class Meta {\n" +" static getModuleVariables(module) {\n" +" if (!(module is String)) Fiber.abort(\"Module name must be a string.\")\n" +" var result = getModuleVariables_(module)\n" +" if (result != null) return result\n" +"\n" +" Fiber.abort(\"Could not find a module named '%(module)'.\")\n" +" }\n" +"\n" +" static eval(source) {\n" +" if (!(source is String)) Fiber.abort(\"Source code must be a string.\")\n" +"\n" +" var closure = compile_(source, false, false)\n" +" // TODO: Include compile errors.\n" +" if (closure == null) Fiber.abort(\"Could not compile source code.\")\n" +"\n" +" closure.call()\n" +" }\n" +"\n" +" static compileExpression(source) {\n" +" if (!(source is String)) Fiber.abort(\"Source code must be a string.\")\n" +" return compile_(source, true, true)\n" +" }\n" +"\n" +" static compile(source) {\n" +" if (!(source is String)) Fiber.abort(\"Source code must be a string.\")\n" +" return compile_(source, false, true)\n" +" }\n" +"\n" +" foreign static compile_(source, isExpression, printErrors)\n" +" foreign static getModuleVariables_(module)\n" +"}\n"; +// End file "wren_opt_meta.wren.inc" + +void metaCompile(WrenVM* vm) +{ + const char* source = wrenGetSlotString(vm, 1); + bool isExpression = wrenGetSlotBool(vm, 2); + bool printErrors = wrenGetSlotBool(vm, 3); + + // TODO: Allow passing in module? + // Look up the module surrounding the callsite. This is brittle. The -2 walks + // up the callstack assuming that the meta module has one level of + // indirection before hitting the user's code. Any change to meta may require + // this constant to be tweaked. + ObjFiber* currentFiber = vm->fiber; + ObjFn* fn = currentFiber->frames[currentFiber->numFrames - 2].closure->fn; + ObjString* module = fn->module->name; + + ObjClosure* closure = wrenCompileSource(vm, module->value, source, + isExpression, printErrors); + + // Return the result. We can't use the public API for this since we have a + // bare ObjClosure*. + if (closure == NULL) + { + vm->apiStack[0] = NULL_VAL; + } + else + { + vm->apiStack[0] = OBJ_VAL(closure); + } +} + +void metaGetModuleVariables(WrenVM* vm) +{ + wrenEnsureSlots(vm, 3); + + Value moduleValue = wrenMapGet(vm->modules, vm->apiStack[1]); + if (IS_UNDEFINED(moduleValue)) + { + vm->apiStack[0] = NULL_VAL; + return; + } + + ObjModule* module = AS_MODULE(moduleValue); + ObjList* names = wrenNewList(vm, module->variableNames.count); + vm->apiStack[0] = OBJ_VAL(names); + + // Initialize the elements to null in case a collection happens when we + // allocate the strings below. + for (int i = 0; i < names->elements.count; i++) + { + names->elements.data[i] = NULL_VAL; + } + + for (int i = 0; i < names->elements.count; i++) + { + names->elements.data[i] = OBJ_VAL(module->variableNames.data[i]); + } +} + +const char* wrenMetaSource() +{ + return metaModuleSource; +} + +WrenForeignMethodFn wrenMetaBindForeignMethod(WrenVM* vm, + const char* className, + bool isStatic, + const char* signature) +{ + // There is only one foreign method in the meta module. + ASSERT(strcmp(className, "Meta") == 0, "Should be in Meta class."); + ASSERT(isStatic, "Should be static."); + + if (strcmp(signature, "compile_(_,_,_)") == 0) + { + return metaCompile; + } + + if (strcmp(signature, "getModuleVariables_(_)") == 0) + { + return metaGetModuleVariables; + } + + ASSERT(false, "Unknown method."); + return NULL; +} + +#endif +// End file "wren_opt_meta.c" + +// End of wren.c + diff --git a/requests.wren b/requests.wren new file mode 100644 index 0000000..2d1619a --- /dev/null +++ b/requests.wren @@ -0,0 +1,31 @@ +// requests.wren +foreign class Response { + construct new() {} + + foreign isError + foreign statusCode + foreign body + foreign json() +} + +class Requests { + // Foreign methods now expect a callback function. + foreign static get_(url, headers, callback) + foreign static post_(url, body, contentType, headers, callback) + + static get(url, headers, callback) { + if (!(callback is Fn) || callback.arity != 2) { + Fiber.abort("Callback must be a function that accepts 2 arguments: (error, response).") + } + get_(url, headers, callback) + } + + static post(url, body, headers, callback) { + if (!(callback is Fn) || callback.arity != 2) { + Fiber.abort("Callback must be a function that accepts 2 arguments: (error, response).") + } + var contentType = headers != null && headers.containsKey("Content-Type") ? + headers["Content-Type"] : "application/json" + post_(url, body, contentType, headers, callback) + } +} diff --git a/requests_backend.c b/requests_backend.c new file mode 100644 index 0000000..0e8aac4 --- /dev/null +++ b/requests_backend.c @@ -0,0 +1,385 @@ +#include "wren.h" +#include +#include +#include +#include + +#ifdef _WIN32 + #include + typedef HANDLE thread_t; + typedef CRITICAL_SECTION mutex_t; + typedef CONDITION_VARIABLE cond_t; +#else + #include + typedef pthread_t thread_t; + typedef pthread_mutex_t mutex_t; + typedef pthread_cond_t cond_t; +#endif + +// --- Data Structures --- + +typedef struct { + int isError; + long statusCode; + char* body; + size_t body_len; +} ResponseData; + +typedef struct { + char* memory; + size_t size; +} MemoryStruct; + +typedef struct HttpContext { + WrenVM* vm; + WrenHandle* callback; + + char* url; + char* method; + char* body; + struct curl_slist* headers; + + bool success; + char* response_body; + size_t response_body_len; + long status_code; + char* error_message; + struct HttpContext* next; +} HttpContext; + + +// --- Thread-Safe Queue --- + +typedef struct { + HttpContext *head, *tail; + mutex_t mutex; + cond_t cond; +} ThreadSafeQueue; + +void http_queue_init(ThreadSafeQueue* q) { + q->head = q->tail = NULL; + #ifdef _WIN32 + InitializeCriticalSection(&q->mutex); + InitializeConditionVariable(&q->cond); + #else + pthread_mutex_init(&q->mutex, NULL); + pthread_cond_init(&q->cond, NULL); + #endif +} + +void http_queue_destroy(ThreadSafeQueue* q) { + #ifdef _WIN32 + DeleteCriticalSection(&q->mutex); + #else + pthread_mutex_destroy(&q->mutex); + pthread_cond_destroy(&q->cond); + #endif +} + +void http_queue_push(ThreadSafeQueue* q, HttpContext* context) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + #endif + + if(context) context->next = NULL; + if (q->tail) q->tail->next = context; + else q->head = context; + q->tail = context; + + #ifdef _WIN32 + WakeConditionVariable(&q->cond); + LeaveCriticalSection(&q->mutex); + #else + pthread_cond_signal(&q->cond); + pthread_mutex_unlock(&q->mutex); + #endif +} + +HttpContext* http_queue_pop(ThreadSafeQueue* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + while (q->head == NULL) { + SleepConditionVariableCS(&q->cond, &q->mutex, INFINITE); + } + #else + pthread_mutex_lock(&q->mutex); + while (q->head == NULL) { + pthread_cond_wait(&q->cond, &q->mutex); + } + #endif + + HttpContext* context = q->head; + q->head = q->head->next; + if (q->head == NULL) q->tail = NULL; + + #ifdef _WIN32 + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_unlock(&q->mutex); + #endif + + return context; +} + +bool http_queue_empty(ThreadSafeQueue* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + bool empty = (q->head == NULL); + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + bool empty = (q->head == NULL); + pthread_mutex_unlock(&q->mutex); + #endif + return empty; +} + + +// --- libcurl Helpers --- +static size_t write_memory_callback(void *contents, size_t size, size_t nmemb, void *userp) { + size_t realsize = size * nmemb; + MemoryStruct *mem = (MemoryStruct *)userp; + char *ptr = (char*)realloc(mem->memory, mem->size + realsize + 1); + if (ptr == NULL) return 0; + mem->memory = ptr; + memcpy(&(mem->memory[mem->size]), contents, realsize); + mem->size += realsize; + mem->memory[mem->size] = 0; + return realsize; +} + +// --- Async HTTP Manager --- + +typedef struct { + WrenVM* vm; + volatile bool running; + thread_t threads[4]; + ThreadSafeQueue requestQueue; + ThreadSafeQueue completionQueue; +} AsyncHttpManager; + +static AsyncHttpManager* httpManager = NULL; + +void free_http_context(HttpContext* context) { + if (!context) return; + free(context->url); + free(context->method); + free(context->body); + curl_slist_free_all(context->headers); + free(context->response_body); + free(context->error_message); + free(context); +} + +#ifdef _WIN32 +DWORD WINAPI httpWorkerThread(LPVOID arg) { +#else +void* httpWorkerThread(void* arg) { +#endif + AsyncHttpManager* manager = (AsyncHttpManager*)arg; + while (manager->running) { + HttpContext* context = http_queue_pop(&manager->requestQueue); + if (!context || !manager->running) { + if (context) free_http_context(context); + break; + } + + CURL *curl = curl_easy_init(); + if (curl) { + MemoryStruct chunk; + chunk.memory = (char*)malloc(1); + chunk.size = 0; + + curl_easy_setopt(curl, CURLOPT_URL, context->url); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_memory_callback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, (void *)&chunk); + curl_easy_setopt(curl, CURLOPT_USERAGENT, "wren-curl-agent/1.0"); + + if (strcmp(context->method, "POST") == 0) { + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, context->body); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, context->headers); + } + + CURLcode res = curl_easy_perform(curl); + + if (res == CURLE_OK) { + context->success = true; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &context->status_code); + context->response_body = chunk.memory; + context->response_body_len = chunk.size; + } else { + context->success = false; + context->status_code = -1; + context->error_message = strdup(curl_easy_strerror(res)); + free(chunk.memory); + } + curl_easy_cleanup(curl); + } else { + context->success = false; + context->error_message = strdup("Failed to initialize cURL handle."); + } + http_queue_push(&manager->completionQueue, context); + } + return 0; +} + +void httpManager_create(WrenVM* vm) { + httpManager = (AsyncHttpManager*)malloc(sizeof(AsyncHttpManager)); + httpManager->vm = vm; + httpManager->running = true; + http_queue_init(&httpManager->requestQueue); + http_queue_init(&httpManager->completionQueue); + for (int i = 0; i < 4; ++i) { + #ifdef _WIN32 + httpManager->threads[i] = CreateThread(NULL, 0, httpWorkerThread, httpManager, 0, NULL); + #else + pthread_create(&httpManager->threads[i], NULL, httpWorkerThread, httpManager); + #endif + } +} + +void httpManager_destroy() { + httpManager->running = false; + for (int i = 0; i < 4; ++i) { + http_queue_push(&httpManager->requestQueue, NULL); + } + for (int i = 0; i < 4; ++i) { + #ifdef _WIN32 + WaitForSingleObject(httpManager->threads[i], INFINITE); + CloseHandle(httpManager->threads[i]); + #else + pthread_join(httpManager->threads[i], NULL); + #endif + } + http_queue_destroy(&httpManager->requestQueue); + http_queue_destroy(&httpManager->completionQueue); + free(httpManager); +} + +void httpManager_processCompletions() { + while (!http_queue_empty(&httpManager->completionQueue)) { + HttpContext* context = http_queue_pop(&httpManager->completionQueue); + + WrenHandle* callHandle = wrenMakeCallHandle(httpManager->vm, "call(_,_)"); + wrenEnsureSlots(httpManager->vm, 3); + wrenSetSlotHandle(httpManager->vm, 0, context->callback); + + if (context->success) { + wrenSetSlotNull(httpManager->vm, 1); + + wrenGetVariable(httpManager->vm, "requests", "Response", 2); + void* foreign = wrenSetSlotNewForeign(httpManager->vm, 2, 2, sizeof(ResponseData)); + ResponseData* data = (ResponseData*)foreign; + data->isError = false; + data->statusCode = context->status_code; + data->body = context->response_body; + data->body_len = context->response_body_len; + context->response_body = NULL; + } else { + wrenSetSlotString(httpManager->vm, 1, context->error_message); + wrenSetSlotNull(httpManager->vm, 2); + } + + wrenCall(httpManager->vm, callHandle); + wrenReleaseHandle(httpManager->vm, context->callback); + wrenReleaseHandle(httpManager->vm, callHandle); + free_http_context(context); + } +} + +void httpManager_submit(HttpContext* context) { + http_queue_push(&httpManager->requestQueue, context); +} + +// --- Wren Foreign Methods --- + +void responseFinalize(void* data) { + ResponseData* response = (ResponseData*)data; + free(response->body); +} + +void responseAllocate(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenSetSlotNewForeign(vm, 0, 0, sizeof(ResponseData)); + data->isError = 0; + data->statusCode = 0; + data->body = NULL; + data->body_len = 0; +} + +void responseIsError(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBool(vm, 0, data->isError ? true : false); +} + +void responseStatusCode(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotDouble(vm, 0, (double)data->statusCode); +} + +void responseBody(WrenVM* vm) { + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBytes(vm, 0, data->body ? data->body : "", data->body_len); +} + +void responseJson(WrenVM* vm) { + // CORRECTED: Replaced incorrect call with the actual logic. + ResponseData* data = (ResponseData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBytes(vm, 0, data->body ? data->body : "", data->body_len); +} + +void requestsGet(WrenVM* vm) { + HttpContext* context = (HttpContext*)calloc(1, sizeof(HttpContext)); + context->vm = vm; + context->method = strdup("GET"); + context->url = strdup(wrenGetSlotString(vm, 1)); + context->callback = wrenGetSlotHandle(vm, 3); + httpManager_submit(context); +} + +void requestsPost(WrenVM* vm) { + HttpContext* context = (HttpContext*)calloc(1, sizeof(HttpContext)); + context->vm = vm; + context->method = strdup("POST"); + context->url = strdup(wrenGetSlotString(vm, 1)); + context->body = strdup(wrenGetSlotString(vm, 2)); + const char* contentType = wrenGetSlotString(vm, 3); + char contentTypeHeader[256]; + snprintf(contentTypeHeader, sizeof(contentTypeHeader), "Content-Type: %s", contentType); + context->headers = curl_slist_append(NULL, contentTypeHeader); + context->callback = wrenGetSlotHandle(vm, 5); + httpManager_submit(context); +} + +// --- FFI Binding Functions --- + +WrenForeignMethodFn bindForeignMethod(WrenVM* vm, const char* module, + const char* className, bool isStatic, const char* signature) { + if (strcmp(module, "requests") != 0) return NULL; + + if (strcmp(className, "Requests") == 0 && isStatic) { + if (strcmp(signature, "get_(_,_,_)") == 0) return requestsGet; + if (strcmp(signature, "post_(_,_,_,_,_)") == 0) return requestsPost; + } + + if (strcmp(className, "Response") == 0 && !isStatic) { + if (strcmp(signature, "isError") == 0) return responseIsError; + if (strcmp(signature, "statusCode") == 0) return responseStatusCode; + if (strcmp(signature, "body") == 0) return responseBody; + if (strcmp(signature, "json()") == 0) return responseJson; + } + + return NULL; +} + +WrenForeignClassMethods bindForeignClass(WrenVM* vm, const char* module, const char* className) { + WrenForeignClassMethods methods = {0, 0}; + if (strcmp(module, "requests") == 0) { + if (strcmp(className, "Response") == 0) { + methods.allocate = responseAllocate; + methods.finalize = responseFinalize; + } + } + return methods; +} diff --git a/requests_example.wren b/requests_example.wren new file mode 100644 index 0000000..384de49 --- /dev/null +++ b/requests_example.wren @@ -0,0 +1,59 @@ +// main.wren (Corrected) +import "requests" for Requests + +// This class provides a hook back into the C host +foreign class Host { + foreign static signalDone() +} + +var mainFiber = Fiber.new { + System.print("--- Running 20 CONCURRENT GET requests ---") + + var completed = 0 + + for (i in 1..1000) { + Requests.get("https://jsonplaceholder.typicode.com/posts/%(i)", null) { |err, res| + if (err) { + System.print("Request #%(i) [GET] Error: %(err)") + } else { + System.print("Request #%(i) [GET] Status: %(res.statusCode)") + } + // CORRECTED: Create a new fiber for each atomic operation + Fiber.new { completed = completed + 1 }.call() + } + } + + // Wait for GET requests to finish + while (completed < 1000) { + Fiber.yield() + } + + System.print("\n--- Running 20 CONCURRENT POST requests ---") + + var postBody = """ + { "title": "wren-test" } + """ + var headers = { "Content-Type": "application/json; charset=UTF-8" } + var postCompleted = 0 + for (i in 1..1000) { + Requests.post("https://jsonplaceholder.typicode.com/posts", postBody, headers) { |err, res| + if (err) { + System.print("Request #%(i) [POST] Error: %(err)") + } else { + System.print("Request #%(i) [POST] Response Code: %(res.statusCode)") + } + // CORRECTED: Create a new fiber for each atomic operation + Fiber.new { postCompleted = postCompleted + 1 }.call() + } + } + + // Wait for POST requests to finish + while (postCompleted < 1000) { + Fiber.yield() + } + + System.print("\n--- All concurrent requests finished. ---") + + // Tell the C host that we are done. + Host.signalDone() +} diff --git a/socket.wren b/socket.wren new file mode 100644 index 0000000..62392fd --- /dev/null +++ b/socket.wren @@ -0,0 +1,49 @@ +// socket.wren (Corrected) +foreign class Socket { + // CORRECTED: Changed 'new_' to 'new' to match the standard convention. + construct new() {} + + foreign connect(host, port, callback) + foreign listen(host, port, backlog) + foreign accept(callback) + foreign read(bytes) + foreign close() + + foreign isOpen + foreign remoteAddress + foreign remotePort + + // Implemented in Wren + write(data, callback) { + write_(data, callback) + } + + readUntil(delimiter, callback) { + var buffer = "" + var readChunk + readChunk = Fn.new { + this.read(4096) { |err, data| + if (err) { + callback.call(err, null) + return + } + + buffer = buffer + data + var index = buffer.indexOf(delimiter) + + if (index != -1) { + var result = buffer.substring(0, index + delimiter.count) + callback.call(null, result) + } else { + // Delimiter not found, read more data. + readChunk.call() + } + } + } + // Start reading. + readChunk.call() + } + + // Private foreign method for writing + foreign write_(data, callback) +} diff --git a/socket_backend.c b/socket_backend.c new file mode 100644 index 0000000..6bc8eea --- /dev/null +++ b/socket_backend.c @@ -0,0 +1,596 @@ +#include "wren.h" +#include +#include +#include +#include +#include + +// Platform-specific includes and definitions +#ifdef _WIN32 + #include + #include + #include + #pragma comment(lib, "ws2_32.lib") + typedef SOCKET socket_t; + typedef int socklen_t; + typedef HANDLE thread_t; + typedef CRITICAL_SECTION mutex_t; + typedef CONDITION_VARIABLE cond_t; + #define IS_SOCKET_VALID(s) ((s) != INVALID_SOCKET) + #define CLOSE_SOCKET(s) closesocket(s) +#else + #include + #include + #include + #include + #include + #include + #include + #include + typedef int socket_t; + typedef pthread_t thread_t; + typedef pthread_mutex_t mutex_t; + typedef pthread_cond_t cond_t; + #define INVALID_SOCKET -1 + #define IS_SOCKET_VALID(s) ((s) >= 0) + #define CLOSE_SOCKET(s) close(s) +#endif + +// --- Forward Declarations --- +typedef struct SocketContext SocketContext; + +// --- Socket Data Structures --- + +typedef enum { + SOCKET_OP_CONNECT, + SOCKET_OP_ACCEPT, + SOCKET_OP_READ, + SOCKET_OP_WRITE, +} SocketOp; + +typedef struct { + socket_t sock; + bool isListener; +} SocketData; + +struct SocketContext { + SocketOp operation; + WrenVM* vm; + WrenHandle* socketHandle; + WrenHandle* callback; + + // For connect + char* host; + int port; + + // For write + char* data; + size_t dataLength; + + // Results + bool success; + char* resultData; + size_t resultDataLength; + char* errorMessage; + socket_t newSocket; // For accept + struct SocketContext* next; +}; + +// --- Thread-Safe Queue Implementation --- +typedef struct { + SocketContext *head, *tail; + mutex_t mutex; + cond_t cond; +} ThreadSafeQueueSocket; + +void queue_init(ThreadSafeQueueSocket* q) { + q->head = q->tail = NULL; + #ifdef _WIN32 + InitializeCriticalSection(&q->mutex); + InitializeConditionVariable(&q->cond); + #else + pthread_mutex_init(&q->mutex, NULL); + pthread_cond_init(&q->cond, NULL); + #endif +} + +void queue_destroy(ThreadSafeQueueSocket* q) { + #ifdef _WIN32 + DeleteCriticalSection(&q->mutex); + #else + pthread_mutex_destroy(&q->mutex); + pthread_cond_destroy(&q->cond); + #endif +} + +void queue_push(ThreadSafeQueueSocket* q, SocketContext* context) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + #endif + + if (context) { + context->next = NULL; + } + + if (q->tail) { + q->tail->next = context; + } else { + q->head = context; + } + q->tail = context; + + #ifdef _WIN32 + WakeConditionVariable(&q->cond); + LeaveCriticalSection(&q->mutex); + #else + pthread_cond_signal(&q->cond); + pthread_mutex_unlock(&q->mutex); + #endif +} + +SocketContext* queue_pop(ThreadSafeQueueSocket* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + while (q->head == NULL) { + SleepConditionVariableCS(&q->cond, &q->mutex, INFINITE); + } + #else + pthread_mutex_lock(&q->mutex); + while (q->head == NULL) { + pthread_cond_wait(&q->cond, &q->mutex); + } + #endif + + SocketContext* context = q->head; + if (context) { + q->head = q->head->next; + if (q->head == NULL) { + q->tail = NULL; + } + } + + #ifdef _WIN32 + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_unlock(&q->mutex); + #endif + + return context; +} + +bool queue_empty(ThreadSafeQueueSocket* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + bool empty = (q->head == NULL); + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + bool empty = (q->head == NULL); + pthread_mutex_unlock(&q->mutex); + #endif + return empty; +} + +// --- Asynchronous Socket Manager --- + +typedef struct { + WrenVM* vm; + volatile bool running; + thread_t worker_threads[4]; + ThreadSafeQueueSocket requestQueue; + ThreadSafeQueueSocket completionQueue; +} AsyncSocketManager; + +static AsyncSocketManager* socketManager = NULL; + +void free_socket_context_data(SocketContext* context) { + if (!context) return; + free(context->host); + free(context->data); + free(context->resultData); + free(context->errorMessage); + free(context); +} + +#ifdef _WIN32 +DWORD WINAPI workerThread(LPVOID arg); +#else +void* workerThread(void* arg); +#endif + +// --- Worker Thread Implementation --- + +#ifdef _WIN32 +DWORD WINAPI workerThread(LPVOID arg) { +#else +void* workerThread(void* arg) { +#endif + AsyncSocketManager* manager = (AsyncSocketManager*)arg; + + while (manager->running) { + SocketContext* context = queue_pop(&manager->requestQueue); + if (!context || !manager->running) { + if (context) free_socket_context_data(context); + break; + } + + wrenEnsureSlots(context->vm, 1); + wrenSetSlotHandle(context->vm, 0, context->socketHandle); + SocketData* socketData = (wrenGetSlotType(context->vm, 0) == WREN_TYPE_FOREIGN) + ? (SocketData*)wrenGetSlotForeign(context->vm, 0) + : NULL; + + if (!socketData || !IS_SOCKET_VALID(socketData->sock)) { + context->success = false; + context->errorMessage = strdup("Invalid or closed socket object."); + queue_push(&manager->completionQueue, context); + continue; + } + + switch (context->operation) { + case SOCKET_OP_CONNECT: { + struct addrinfo hints = {0}, *addrs; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + char port_str[6]; + snprintf(port_str, 6, "%d", context->port); + if (getaddrinfo(context->host, port_str, &hints, &addrs) != 0) { + context->success = false; + context->errorMessage = strdup("Host lookup failed."); + break; + } + + socket_t sock = INVALID_SOCKET; + for (struct addrinfo* addr = addrs; addr; addr = addr->ai_next) { + sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (!IS_SOCKET_VALID(sock)) continue; + if (connect(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break; + CLOSE_SOCKET(sock); + sock = INVALID_SOCKET; + } + freeaddrinfo(addrs); + + if (IS_SOCKET_VALID(sock)) { + socketData->sock = sock; + socketData->isListener = false; + context->success = true; + } else { + context->success = false; + context->errorMessage = strdup("Connection failed."); + } + break; + } + case SOCKET_OP_ACCEPT: { + if (!socketData->isListener) { + context->success = false; + context->errorMessage = strdup("Cannot accept on a non-listening socket."); + break; + } + + // This is a blocking call. The worker thread will wait here. + context->newSocket = accept(socketData->sock, NULL, NULL); + context->success = IS_SOCKET_VALID(context->newSocket); + if (!context->success) { + #ifdef _WIN32 + // TODO: A more descriptive error using FormatMessageA + context->errorMessage = strdup("Accept failed."); + #else + context->errorMessage = strdup(strerror(errno)); + #endif + } + break; + } + case SOCKET_OP_READ: { + if (socketData->isListener) { + context->success = false; + context->errorMessage = strdup("Cannot read from a listening socket."); + break; + } + + char buf[4096]; + // This is a blocking call. + ssize_t len = recv(socketData->sock, buf, sizeof(buf), 0); + if (len > 0) { + context->resultData = (char*)malloc(len); + memcpy(context->resultData, buf, len); + context->resultDataLength = len; + context->success = true; + } else { + context->success = false; + if (len == 0) { + context->errorMessage = strdup("Connection closed by peer."); + } else { + #ifdef _WIN32 + context->errorMessage = strdup("Read failed."); + #else + context->errorMessage = strdup(strerror(errno)); + #endif + } + } + break; + } + case SOCKET_OP_WRITE: { + if (socketData->isListener) { + context->success = false; + context->errorMessage = strdup("Cannot write to a listening socket."); + break; + } + ssize_t written = send(socketData->sock, context->data, context->dataLength, 0); + context->success = (written == (ssize_t)context->dataLength); + if(!context->success) context->errorMessage = strdup("Write failed."); + break; + } + } + queue_push(&manager->completionQueue, context); + } + return 0; +} + +// --- Manager Lifecycle --- + +void socketManager_create(WrenVM* vm) { + if (socketManager != NULL) return; + socketManager = (AsyncSocketManager*)malloc(sizeof(AsyncSocketManager)); + socketManager->vm = vm; + socketManager->running = true; + + queue_init(&socketManager->requestQueue); + queue_init(&socketManager->completionQueue); + + for (int i = 0; i < 4; i++) { + #ifdef _WIN32 + socketManager->worker_threads[i] = CreateThread(NULL, 0, workerThread, socketManager, 0, NULL); + #else + pthread_create(&socketManager->worker_threads[i], NULL, workerThread, socketManager); + #endif + } +} + +void socketManager_destroy() { + if (!socketManager) return; + socketManager->running = false; + + // Unblock all worker threads + for (int i = 0; i < 4; i++) { + queue_push(&socketManager->requestQueue, NULL); + } + + // Wait for threads to finish + for (int i = 0; i < 4; i++) { + #ifdef _WIN32 + WaitForSingleObject(socketManager->worker_threads[i], INFINITE); + CloseHandle(socketManager->worker_threads[i]); + #else + pthread_join(socketManager->worker_threads[i], NULL); + #endif + } + + // Clean up any remaining contexts in queues + while (!queue_empty(&socketManager->requestQueue)) { + free_socket_context_data(queue_pop(&socketManager->requestQueue)); + } + while (!queue_empty(&socketManager->completionQueue)) { + free_socket_context_data(queue_pop(&socketManager->completionQueue)); + } + + queue_destroy(&socketManager->requestQueue); + queue_destroy(&socketManager->completionQueue); + + free(socketManager); + socketManager = NULL; +} + +void socketManager_processCompletions() { + if (!socketManager || queue_empty(&socketManager->completionQueue)) return; + + WrenHandle* callHandle = wrenMakeCallHandle(socketManager->vm, "call(_,_)"); + while (!queue_empty(&socketManager->completionQueue)) { + SocketContext* context = queue_pop(&socketManager->completionQueue); + + wrenEnsureSlots(socketManager->vm, 3); + wrenSetSlotHandle(socketManager->vm, 0, context->callback); + if (context->success) { + wrenSetSlotNull(socketManager->vm, 1); // error slot + if (IS_SOCKET_VALID(context->newSocket)) { // Accept succeeded + wrenGetVariable(socketManager->vm, "socket", "Socket", 2); + void* foreign = wrenSetSlotNewForeign(socketManager->vm, 2, 2, sizeof(SocketData)); + SocketData* clientData = (SocketData*)foreign; + clientData->sock = context->newSocket; + clientData->isListener = false; + } else if (context->resultData) { // Read succeeded + wrenSetSlotBytes(socketManager->vm, 2, context->resultData, context->resultDataLength); + } else { // Other successes (connect, write) + wrenSetSlotNull(socketManager->vm, 2); + } + } else { + wrenSetSlotString(socketManager->vm, 1, context->errorMessage ? context->errorMessage : "Unknown error."); + wrenSetSlotNull(socketManager->vm, 2); + } + + wrenCall(socketManager->vm, callHandle); + + wrenReleaseHandle(socketManager->vm, context->socketHandle); + wrenReleaseHandle(socketManager->vm, context->callback); + free_socket_context_data(context); + } + wrenReleaseHandle(socketManager->vm, callHandle); +} + +// --- Wren Foreign Methods --- + +void socketAllocate(WrenVM* vm) { + SocketData* data = (SocketData*)wrenSetSlotNewForeign(vm, 0, 0, sizeof(SocketData)); + data->sock = INVALID_SOCKET; + data->isListener = false; +} + +void socketConnect(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_CONNECT; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->host = strdup(wrenGetSlotString(vm, 1)); + context->port = (int)wrenGetSlotDouble(vm, 2); + context->callback = wrenGetSlotHandle(vm, 3); + queue_push(&socketManager->requestQueue, context); +} + +void socketListen(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + const char* host = wrenGetSlotString(vm, 1); + int port = (int)wrenGetSlotDouble(vm, 2); + int backlog = (int)wrenGetSlotDouble(vm, 3); + + struct addrinfo hints = {0}, *addrs; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE; + + char port_str[6]; + snprintf(port_str, 6, "%d", port); + if (getaddrinfo(host, port_str, &hints, &addrs) != 0) { + wrenSetSlotBool(vm, 0, false); + return; + } + + socket_t sock = INVALID_SOCKET; + for (struct addrinfo* addr = addrs; addr; addr = addr->ai_next) { + sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (!IS_SOCKET_VALID(sock)) continue; + + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char*)&yes, sizeof(yes)); + + if (bind(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break; + + CLOSE_SOCKET(sock); + sock = INVALID_SOCKET; + } + freeaddrinfo(addrs); + + if (IS_SOCKET_VALID(sock) && listen(sock, backlog) == 0) { + data->sock = sock; + data->isListener = true; + wrenSetSlotBool(vm, 0, true); + } else { + if(IS_SOCKET_VALID(sock)) CLOSE_SOCKET(sock); + wrenSetSlotBool(vm, 0, false); + } +} + +void socketAccept(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_ACCEPT; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->callback = wrenGetSlotHandle(vm, 1); + queue_push(&socketManager->requestQueue, context); +} + +void socketRead(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_READ; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->callback = wrenGetSlotHandle(vm, 1); + queue_push(&socketManager->requestQueue, context); +} + +void socketWrite(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_WRITE; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + int len; + const char* bytes = wrenGetSlotBytes(vm, 1, &len); + context->data = (char*)malloc(len); + memcpy(context->data, bytes, len); + context->dataLength = len; + context->callback = wrenGetSlotHandle(vm, 2); + queue_push(&socketManager->requestQueue, context); +} + +void socketClose(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (IS_SOCKET_VALID(data->sock)) { + CLOSE_SOCKET(data->sock); + data->sock = INVALID_SOCKET; + } +} + +void socketIsOpen(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBool(vm, 0, IS_SOCKET_VALID(data->sock)); +} + +void socketRemoteAddress(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (!IS_SOCKET_VALID(data->sock) || data->isListener) { + wrenSetSlotNull(vm, 0); + return; + } + + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); + char ipstr[INET6_ADDRSTRLEN]; + + if (getpeername(data->sock, (struct sockaddr*)&addr, &len) == 0) { + if (addr.ss_family == AF_INET) { + inet_ntop(AF_INET, &((struct sockaddr_in*)&addr)->sin_addr, ipstr, sizeof(ipstr)); + } else { + inet_ntop(AF_INET6, &((struct sockaddr_in6*)&addr)->sin6_addr, ipstr, sizeof(ipstr)); + } + wrenSetSlotString(vm, 0, ipstr); + } else { + wrenSetSlotNull(vm, 0); + } +} + +void socketRemotePort(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (!IS_SOCKET_VALID(data->sock) || data->isListener) { + wrenSetSlotNull(vm, 0); + return; + } + + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); + + if (getpeername(data->sock, (struct sockaddr*)&addr, &len) == 0) { + int port = 0; + if (addr.ss_family == AF_INET) { + port = ntohs(((struct sockaddr_in*)&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = ntohs(((struct sockaddr_in6*)&addr)->sin6_port); + } + wrenSetSlotDouble(vm, 0, (double)port); + } else { + wrenSetSlotNull(vm, 0); + } +} + +WrenForeignMethodFn bindSocketForeignMethod(WrenVM* vm, const char* module, const char* className, bool isStatic, const char* signature) { + if (strcmp(module, "socket") != 0) return NULL; + if (strcmp(className, "Socket") == 0 && !isStatic) { + if (strcmp(signature, "connect(_,_,_)") == 0) return socketConnect; + if (strcmp(signature, "listen(_,_,_)") == 0) return socketListen; + if (strcmp(signature, "accept(_)") == 0) return socketAccept; + if (strcmp(signature, "read(_)") == 0) return socketRead; + if (strcmp(signature, "write_(_,_)") == 0) return socketWrite; + if (strcmp(signature, "close()") == 0) return socketClose; + if (strcmp(signature, "isOpen") == 0) return socketIsOpen; + if (strcmp(signature, "remoteAddress") == 0) return socketRemoteAddress; + if (strcmp(signature, "remotePort") == 0) return socketRemotePort; + } + return NULL; +} + +WrenForeignClassMethods bindSocketForeignClass(WrenVM* vm, const char* module, const char* className) { + WrenForeignClassMethods methods = {0, 0}; + if (strcmp(module, "socket") == 0 && strcmp(className, "Socket") == 0) { + methods.allocate = socketAllocate; + } + return methods; +} diff --git a/socket_backend.cpp b/socket_backend.cpp new file mode 100644 index 0000000..1af211d --- /dev/null +++ b/socket_backend.cpp @@ -0,0 +1,384 @@ +// socket_backend.cpp (Native Sockets Implementation) +#include "wren.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 + #include + #include + #pragma comment(lib, "ws2_32.lib") + typedef SOCKET socket_t; + #define IS_SOCKET_VALID(s) ((s) != INVALID_SOCKET) + #define CLOSE_SOCKET(s) closesocket(s) +#else + #include + #include + #include + #include + #include + #include + #include + typedef int socket_t; + #define INVALID_SOCKET -1 + #define IS_SOCKET_VALID(s) ((s) >= 0) + #define CLOSE_SOCKET(s) close(s) +#endif + + +// --- Thread-Safe Queue for Asynchronous Operations --- + +template +class ThreadSafeQueue { +public: + void push(T item) { + std::lock_guard lock(mutex_); + queue_.push(item); + cond_.notify_one(); + } + + T pop() { + std::unique_lock lock(mutex_); + cond_.wait(lock, [this] { return !queue_.empty(); }); + T item = queue_.front(); + queue_.pop(); + return item; + } + + bool empty() { + std::lock_guard lock(mutex_); + return queue_.empty(); + } + +private: + std::queue queue_; + std::mutex mutex_; + std::condition_variable cond_; +}; + +// --- Context for Asynchronous Socket Operations --- + +enum class SocketOp { + CONNECT, + ACCEPT, + READ, + WRITE +}; + +struct SocketData { + socket_t sock; + bool isListener; +}; + +struct SocketContext { + SocketOp operation; + WrenVM* vm; + WrenHandle* socketHandle; + WrenHandle* callback; + + // Operation-specific data + std::string host; + int port; + std::string data; + int bytesToRead; + + // Result data + bool success; + std::string resultData; + std::string errorMessage; + socket_t newSocket; // For accepted connections +}; + + +// --- Asynchronous Socket Manager --- + +class AsyncSocketManager { +public: + AsyncSocketManager(WrenVM* vm) : vm_(vm), running_(true) { + for (int i = 0; i < 4; ++i) { + threads_.emplace_back([this] { workerThread(); }); + } + } + + ~AsyncSocketManager() { + running_ = false; + for (size_t i = 0; i < threads_.size(); ++i) { + requestQueue_.push(nullptr); + } + for (auto& thread : threads_) { + thread.join(); + } + } + + void submit(SocketContext* context) { + requestQueue_.push(context); + } + + bool completionQueueEmpty() { + return completionQueue_.empty(); + } + + void processCompletions() { + while (!completionQueue_.empty()) { + SocketContext* context = completionQueue_.pop(); + + WrenHandle* callHandle = wrenMakeCallHandle(vm_, "call(_,_)"); + wrenEnsureSlots(vm_, 3); + wrenSetSlotHandle(vm_, 0, context->callback); + + if (context->success) { + wrenSetSlotNull(vm_, 1); // No error + switch (context->operation) { + case SocketOp::ACCEPT: { + wrenGetVariable(vm_, "socket", "Socket", 2); + void* foreign = wrenSetSlotNewForeign(vm_, 2, 2, sizeof(SocketData)); + SocketData* clientData = (SocketData*)foreign; + clientData->sock = context->newSocket; + clientData->isListener = false; + break; + } + case SocketOp::READ: + wrenSetSlotBytes(vm_, 2, context->resultData.c_str(), context->resultData.length()); + break; + default: + wrenSetSlotNull(vm_, 2); + break; + } + } else { + wrenSetSlotString(vm_, 1, context->errorMessage.c_str()); + wrenSetSlotNull(vm_, 2); + } + + wrenCall(vm_, callHandle); + + wrenReleaseHandle(vm_, context->socketHandle); + wrenReleaseHandle(vm_, context->callback); + wrenReleaseHandle(vm_, callHandle); + delete context; + } + } + +private: + void workerThread() { + while (running_) { + SocketContext* context = requestQueue_.pop(); + if (!context || !running_) break; + + wrenEnsureSlots(context->vm, 1); + wrenSetSlotHandle(context->vm, 0, context->socketHandle); + SocketData* socketData = (wrenGetSlotType(context->vm, 0) == WREN_TYPE_FOREIGN) + ? (SocketData*)wrenGetSlotForeign(context->vm, 0) + : nullptr; + + if (!socketData) { + context->success = false; + context->errorMessage = "Invalid socket object."; + completionQueue_.push(context); + continue; + } + + switch (context->operation) { + case SocketOp::CONNECT: { + addrinfo hints = {}, *addrs; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + getaddrinfo(context->host.c_str(), std::to_string(context->port).c_str(), &hints, &addrs); + + socket_t sock = INVALID_SOCKET; + for (addrinfo* addr = addrs; addr; addr = addr->ai_next) { + sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (!IS_SOCKET_VALID(sock)) continue; + if (connect(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break; + CLOSE_SOCKET(sock); + sock = INVALID_SOCKET; + } + freeaddrinfo(addrs); + + if (IS_SOCKET_VALID(sock)) { + socketData->sock = sock; + socketData->isListener = false; + context->success = true; + } else { + context->success = false; + context->errorMessage = "Connection failed."; + } + break; + } + case SocketOp::ACCEPT: { + if (socketData->isListener) { + context->newSocket = accept(socketData->sock, nullptr, nullptr); + context->success = IS_SOCKET_VALID(context->newSocket); + if (!context->success) context->errorMessage = "Accept failed."; + } else { + context->success = false; + context->errorMessage = "Cannot accept on a non-listening socket."; + } + break; + } + case SocketOp::READ: { + if (!socketData->isListener) { + char buf[4096]; + ssize_t len = recv(socketData->sock, buf, sizeof(buf), 0); + if (len > 0) { + context->resultData.assign(buf, len); + context->success = true; + } else { + context->success = false; + context->errorMessage = "Read failed or connection closed."; + } + } + break; + } + case SocketOp::WRITE: { + if (!socketData->isListener) { + ssize_t written = send(socketData->sock, context->data.c_str(), context->data.length(), 0); + context->success = (written == (ssize_t)context->data.length()); + if(!context->success) context->errorMessage = "Write failed."; + } + break; + } + } + completionQueue_.push(context); + } + } + + WrenVM* vm_; + std::atomic running_; + std::vector threads_; + ThreadSafeQueue requestQueue_; + ThreadSafeQueue completionQueue_; +}; + +static AsyncSocketManager* socketManager = nullptr; + +// --- Socket Foreign Class/Methods --- + +void socketAllocate(WrenVM* vm) { + SocketData* data = (SocketData*)wrenSetSlotNewForeign(vm, 0, 0, sizeof(SocketData)); + data->sock = INVALID_SOCKET; + data->isListener = false; +} + +void socketConnect(WrenVM* vm) { + SocketContext* context = new SocketContext(); + context->operation = SocketOp::CONNECT; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->host = wrenGetSlotString(vm, 1); + context->port = (int)wrenGetSlotDouble(vm, 2); + context->callback = wrenGetSlotHandle(vm, 3); + socketManager->submit(context); + wrenSetSlotNull(vm, 0); +} + +void socketListen(WrenVM* vm) { + const char* host = wrenGetSlotString(vm, 1); + int port = (int)wrenGetSlotDouble(vm, 2); + int backlog = (int)wrenGetSlotDouble(vm, 3); + + addrinfo hints = {}, *addrs; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE; + getaddrinfo(host, std::to_string(port).c_str(), &hints, &addrs); + + socket_t sock = INVALID_SOCKET; + for (addrinfo* addr = addrs; addr; addr = addr->ai_next) { + sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (!IS_SOCKET_VALID(sock)) continue; + + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char*)&yes, sizeof(yes)); + + if (bind(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break; + + CLOSE_SOCKET(sock); + sock = INVALID_SOCKET; + } + freeaddrinfo(addrs); + + if (IS_SOCKET_VALID(sock) && listen(sock, backlog) == 0) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + data->sock = sock; + data->isListener = true; + wrenSetSlotBool(vm, 0, true); + } else { + if(IS_SOCKET_VALID(sock)) CLOSE_SOCKET(sock); + wrenSetSlotBool(vm, 0, false); + } +} + +void socketAccept(WrenVM* vm) { + SocketContext* context = new SocketContext(); + context->operation = SocketOp::ACCEPT; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->callback = wrenGetSlotHandle(vm, 1); + socketManager->submit(context); + wrenSetSlotNull(vm, 0); +} + +void socketRead(WrenVM* vm) { + SocketContext* context = new SocketContext(); + context->operation = SocketOp::READ; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->bytesToRead = (int)wrenGetSlotDouble(vm, 1); + context->callback = wrenGetSlotHandle(vm, 2); + socketManager->submit(context); + wrenSetSlotNull(vm, 0); +} + +void socketWrite(WrenVM* vm) { + SocketContext* context = new SocketContext(); + context->operation = SocketOp::WRITE; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + int len; + const char* bytes = wrenGetSlotBytes(vm, 1, &len); + context->data.assign(bytes, len); + context->callback = wrenGetSlotHandle(vm, 2); + socketManager->submit(context); + wrenSetSlotNull(vm, 0); +} + +void socketClose(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (IS_SOCKET_VALID(data->sock)) { + CLOSE_SOCKET(data->sock); + data->sock = INVALID_SOCKET; + } +} + +void socketIsOpen(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBool(vm, 0, IS_SOCKET_VALID(data->sock)); +} + +WrenForeignMethodFn bindSocketForeignMethod(WrenVM* vm, const char* module, const char* className, bool isStatic, const char* signature) { + if (strcmp(module, "socket") != 0) return NULL; + if (strcmp(className, "Socket") == 0 && !isStatic) { + if (strcmp(signature, "connect(_,_,_)") == 0) return socketConnect; + if (strcmp(signature, "listen(_,_,_)") == 0) return socketListen; + if (strcmp(signature, "accept(_)") == 0) return socketAccept; + if (strcmp(signature, "read(_,_)") == 0) return socketRead; + if (strcmp(signature, "write(_,_)") == 0) return socketWrite; + if (strcmp(signature, "close()") == 0) return socketClose; + if (strcmp(signature, "isOpen") == 0) return socketIsOpen; + } + return NULL; +} + +WrenForeignClassMethods bindSocketForeignClass(WrenVM* vm, const char* module, const char* className) { + WrenForeignClassMethods methods = {0}; + if (strcmp(module, "socket") == 0 && strcmp(className, "Socket") == 0) { + methods.allocate = socketAllocate; + } + return methods; +} diff --git a/socket_backend_20250729_141833.c b/socket_backend_20250729_141833.c new file mode 100644 index 0000000..6bb0b1c --- /dev/null +++ b/socket_backend_20250729_141833.c @@ -0,0 +1,747 @@ +// socket_backend.c (Corrected with better handle safety and non-blocking I/O) +#include "wren.h" +#include +#include +#include +#include +#include + +// Platform-specific includes and definitions +#ifdef _WIN32 + #include + #include + #include + #pragma comment(lib, "ws2_32.lib") + typedef SOCKET socket_t; + typedef int socklen_t; + typedef HANDLE thread_t; + typedef CRITICAL_SECTION mutex_t; + typedef CONDITION_VARIABLE cond_t; + #define IS_SOCKET_VALID(s) ((s) != INVALID_SOCKET) + #define CLOSE_SOCKET(s) closesocket(s) +#else + #include + #include + #include + #include + #include + #include + #include + #include + #include + typedef int socket_t; + typedef pthread_t thread_t; + typedef pthread_mutex_t mutex_t; + typedef pthread_cond_t cond_t; + #define INVALID_SOCKET -1 + #define IS_SOCKET_VALID(s) ((s) >= 0) + #define CLOSE_SOCKET(s) close(s) +#endif + +// --- Forward Declarations --- +typedef struct SocketContext SocketContext; + +// --- Socket Data Structures --- + +typedef enum { + SOCKET_OP_CONNECT, + SOCKET_OP_READ, + SOCKET_OP_WRITE, +} SocketOp; + +typedef struct { + socket_t sock; + bool isListener; +} SocketData; + +struct SocketContext { + SocketOp operation; + WrenVM* vm; + WrenHandle* socketHandle; + WrenHandle* callback; + + char* host; + int port; + char* data; + size_t dataLength; + + bool success; + char* resultData; + size_t resultDataLength; + char* errorMessage; + socket_t newSocket; + struct SocketContext* next; +}; + +// --- Thread-Safe Queue Implementation in C --- +typedef struct { + SocketContext *head, *tail; + mutex_t mutex; + cond_t cond; +} ThreadSafeQueueSocket; + +void queue_init(ThreadSafeQueueSocket* q) { + q->head = q->tail = NULL; + #ifdef _WIN32 + InitializeCriticalSection(&q->mutex); + InitializeConditionVariable(&q->cond); + #else + pthread_mutex_init(&q->mutex, NULL); + pthread_cond_init(&q->cond, NULL); + #endif +} + +void queue_destroy(ThreadSafeQueueSocket* q) { + #ifdef _WIN32 + DeleteCriticalSection(&q->mutex); + #else + pthread_mutex_destroy(&q->mutex); + pthread_cond_destroy(&q->cond); + #endif +} + +void queue_push(ThreadSafeQueueSocket* q, SocketContext* context) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + #endif + + if (context) { + context->next = NULL; + } + + if (q->tail) { + q->tail->next = context; + } else { + q->head = context; + } + q->tail = context; + + #ifdef _WIN32 + WakeConditionVariable(&q->cond); + LeaveCriticalSection(&q->mutex); + #else + pthread_cond_signal(&q->cond); + pthread_mutex_unlock(&q->mutex); + #endif +} + +SocketContext* queue_pop(ThreadSafeQueueSocket* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + while (q->head == NULL) { + SleepConditionVariableCS(&q->cond, &q->mutex, INFINITE); + } + #else + pthread_mutex_lock(&q->mutex); + while (q->head == NULL) { + pthread_cond_wait(&q->cond, &q->mutex); + } + #endif + + SocketContext* context = q->head; + q->head = q->head->next; + if (q->head == NULL) { + q->tail = NULL; + } + + #ifdef _WIN32 + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_unlock(&q->mutex); + #endif + + return context; +} + +bool queue_empty(ThreadSafeQueueSocket* q) { + #ifdef _WIN32 + EnterCriticalSection(&q->mutex); + bool empty = (q->head == NULL); + LeaveCriticalSection(&q->mutex); + #else + pthread_mutex_lock(&q->mutex); + bool empty = (q->head == NULL); + pthread_mutex_unlock(&q->mutex); + #endif + return empty; +} + +// --- Asynchronous Socket Manager --- + +#define MAX_LISTENERS 64 + +typedef struct { + WrenVM* vm; + volatile bool running; + thread_t worker_threads[4]; + thread_t listener_thread; + + ThreadSafeQueueSocket requestQueue; + ThreadSafeQueueSocket completionQueue; + ThreadSafeQueueSocket acceptQueue; + + mutex_t listener_mutex; + socket_t listener_sockets[MAX_LISTENERS]; + int listener_count; + #ifndef _WIN32 + socket_t wake_pipe[2]; + #endif +} AsyncSocketManager; + +static AsyncSocketManager* socketManager = NULL; + +void free_socket_context_data(SocketContext* context) { + if (!context) return; + free(context->host); + free(context->data); + free(context->resultData); + free(context->errorMessage); + free(context); +} + +#ifdef _WIN32 +DWORD WINAPI workerThread(LPVOID arg); +DWORD WINAPI listenerThread(LPVOID arg); +#else +void* workerThread(void* arg); +void* listenerThread(void* arg); +#endif + +// --- Worker and Listener Thread Implementations --- + +#ifdef _WIN32 +DWORD WINAPI listenerThread(LPVOID arg) { +#else +void* listenerThread(void* arg) { +#endif + AsyncSocketManager* manager = (AsyncSocketManager*)arg; + + while (manager->running) { + fd_set read_fds; + FD_ZERO(&read_fds); + + socket_t max_fd = 0; + + #ifndef _WIN32 + FD_SET(manager->wake_pipe[0], &read_fds); + max_fd = manager->wake_pipe[0]; + #endif + + #ifdef _WIN32 + EnterCriticalSection(&manager->listener_mutex); + #else + pthread_mutex_lock(&manager->listener_mutex); + #endif + + for (int i = 0; i < manager->listener_count; i++) { + socket_t sock = manager->listener_sockets[i]; + if (IS_SOCKET_VALID(sock)) { + FD_SET(sock, &read_fds); + if (sock > max_fd) { + max_fd = sock; + } + } + } + + #ifdef _WIN32 + LeaveCriticalSection(&manager->listener_mutex); + #else + pthread_mutex_unlock(&manager->listener_mutex); + #endif + + struct timeval timeout; + timeout.tv_sec = 1; + timeout.tv_usec = 0; + + int activity = select(max_fd + 1, &read_fds, NULL, NULL, &timeout); + + if (!manager->running) break; + if (activity < 0) { + #ifndef _WIN32 + if (errno != EINTR) { + perror("select error"); + } + #endif + continue; + } + if (activity == 0) continue; + + #ifndef _WIN32 + if (FD_ISSET(manager->wake_pipe[0], &read_fds)) { + char buffer[1]; + read(manager->wake_pipe[0], buffer, 1); + } + #endif + + #ifdef _WIN32 + EnterCriticalSection(&manager->listener_mutex); + #else + pthread_mutex_lock(&manager->listener_mutex); + #endif + + for (int i = 0; i < manager->listener_count; i++) { + socket_t sock = manager->listener_sockets[i]; + if (IS_SOCKET_VALID(sock) && FD_ISSET(sock, &read_fds)) { + if (!queue_empty(&manager->acceptQueue)) { + SocketContext* context = queue_pop(&manager->acceptQueue); + context->newSocket = accept(sock, NULL, NULL); + context->success = IS_SOCKET_VALID(context->newSocket); + if (!context->success) { + context->errorMessage = strdup("Accept failed."); + } + queue_push(&manager->completionQueue, context); + } + } + } + + #ifdef _WIN32 + LeaveCriticalSection(&manager->listener_mutex); + #else + pthread_mutex_unlock(&manager->listener_mutex); + #endif + } + return 0; +} + +#ifdef _WIN32 +DWORD WINAPI workerThread(LPVOID arg) { +#else +void* workerThread(void* arg) { +#endif + AsyncSocketManager* manager = (AsyncSocketManager*)arg; + + while (manager->running) { + SocketContext* context = queue_pop(&manager->requestQueue); + if (!context || !manager->running) { + if (context) free_socket_context_data(context); + break; + } + + wrenEnsureSlots(context->vm, 1); + wrenSetSlotHandle(context->vm, 0, context->socketHandle); + SocketData* socketData = (wrenGetSlotType(context->vm, 0) == WREN_TYPE_FOREIGN) + ? (SocketData*)wrenGetSlotForeign(context->vm, 0) + : NULL; + + if (!socketData) { + context->success = false; + context->errorMessage = strdup("Invalid socket object."); + queue_push(&manager->completionQueue, context); + continue; + } + + switch (context->operation) { + case SOCKET_OP_CONNECT: { + struct addrinfo hints = {0}, *addrs; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + char port_str[6]; + snprintf(port_str, 6, "%d", context->port); + if (getaddrinfo(context->host, port_str, &hints, &addrs) != 0) { + context->success = false; + context->errorMessage = strdup("Host lookup failed."); + break; + } + + socket_t sock = INVALID_SOCKET; + for (struct addrinfo* addr = addrs; addr; addr = addr->ai_next) { + sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (!IS_SOCKET_VALID(sock)) continue; + if (connect(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break; + CLOSE_SOCKET(sock); + sock = INVALID_SOCKET; + } + freeaddrinfo(addrs); + + if (IS_SOCKET_VALID(sock)) { + socketData->sock = sock; + socketData->isListener = false; + context->success = true; + } else { + context->success = false; + context->errorMessage = strdup("Connection failed."); + } + break; + } + case SOCKET_OP_READ: { + if (socketData->isListener) { + context->success = false; + context->errorMessage = strdup("Cannot read from a listening socket."); + break; + } + + fd_set read_fds; + FD_ZERO(&read_fds); + FD_SET(socketData->sock, &read_fds); + struct timeval timeout = { .tv_sec = 5, .tv_usec = 0 }; // 5-second timeout + + int activity = select(socketData->sock + 1, &read_fds, NULL, NULL, &timeout); + if (activity > 0 && FD_ISSET(socketData->sock, &read_fds)) { + char buf[4096]; + ssize_t len = recv(socketData->sock, buf, sizeof(buf), 0); + if (len > 0) { + context->resultData = (char*)malloc(len); + memcpy(context->resultData, buf, len); + context->resultDataLength = len; + context->success = true; + } else { + context->success = false; + context->errorMessage = strdup("Read failed or connection closed."); + } + } else { + context->success = false; + context->errorMessage = strdup("Read timeout or error."); + } + break; + } + case SOCKET_OP_WRITE: { + if (socketData->isListener) { + context->success = false; + context->errorMessage = strdup("Cannot write to a listening socket."); + break; + } + ssize_t written = send(socketData->sock, context->data, context->dataLength, 0); + context->success = (written == (ssize_t)context->dataLength); + if(!context->success) context->errorMessage = strdup("Write failed."); + break; + } + } + queue_push(&manager->completionQueue, context); + } + return 0; +} + +// --- Manager Lifecycle --- + +void socketManager_create(WrenVM* vm) { + socketManager = (AsyncSocketManager*)malloc(sizeof(AsyncSocketManager)); + socketManager->vm = vm; + socketManager->running = true; + socketManager->listener_count = 0; + + queue_init(&socketManager->requestQueue); + queue_init(&socketManager->completionQueue); + queue_init(&socketManager->acceptQueue); + + #ifdef _WIN32 + InitializeCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_init(&socketManager->listener_mutex, NULL); + #endif + + #ifndef _WIN32 + if (pipe(socketManager->wake_pipe) == -1) { + perror("pipe"); + exit(1); + } + #endif + + for (int i = 0; i < 4; i++) { + #ifdef _WIN32 + socketManager->worker_threads[i] = CreateThread(NULL, 0, workerThread, socketManager, 0, NULL); + #else + pthread_create(&socketManager->worker_threads[i], NULL, workerThread, socketManager); + #endif + } + + #ifdef _WIN32 + socketManager->listener_thread = CreateThread(NULL, 0, listenerThread, socketManager, 0, NULL); + #else + pthread_create(&socketManager->listener_thread, NULL, listenerThread, socketManager); + #endif +} + +void socketManager_destroy() { + socketManager->running = false; + + #ifndef _WIN32 + write(socketManager->wake_pipe[1], "w", 1); + #endif + + for (int i = 0; i < 4; i++) { + queue_push(&socketManager->requestQueue, NULL); + } + + #ifdef _WIN32 + WaitForSingleObject(socketManager->listener_thread, INFINITE); + CloseHandle(socketManager->listener_thread); + for (int i = 0; i < 4; i++) { + WaitForSingleObject(socketManager->worker_threads[i], INFINITE); + CloseHandle(socketManager->worker_threads[i]); + } + #else + pthread_join(socketManager->listener_thread, NULL); + for (int i = 0; i < 4; i++) { + pthread_join(socketManager->worker_threads[i], NULL); + } + close(socketManager->wake_pipe[0]); + close(socketManager->wake_pipe[1]); + #endif + + queue_destroy(&socketManager->requestQueue); + queue_destroy(&socketManager->completionQueue); + queue_destroy(&socketManager->acceptQueue); + + #ifdef _WIN32 + DeleteCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_destroy(&socketManager->listener_mutex); + #endif + + free(socketManager); +} + +void socketManager_processCompletions() { + WrenHandle* callHandle = wrenMakeCallHandle(socketManager->vm, "call(_,_)"); + while (!queue_empty(&socketManager->completionQueue)) { + SocketContext* context = queue_pop(&socketManager->completionQueue); + + wrenEnsureSlots(socketManager->vm, 3); + wrenSetSlotHandle(socketManager->vm, 0, context->callback); + if (context->success) { + wrenSetSlotNull(socketManager->vm, 1); + if (IS_SOCKET_VALID(context->newSocket)) { + wrenGetVariable(socketManager->vm, "socket", "Socket", 2); + void* foreign = wrenSetSlotNewForeign(socketManager->vm, 2, 2, sizeof(SocketData)); + SocketData* clientData = (SocketData*)foreign; + clientData->sock = context->newSocket; + clientData->isListener = false; + } else if (context->resultData) { + wrenSetSlotBytes(socketManager->vm, 2, context->resultData, context->resultDataLength); + } else { + wrenSetSlotNull(socketManager->vm, 2); + } + } else { + wrenSetSlotString(socketManager->vm, 1, context->errorMessage ? context->errorMessage : "Unknown error."); + wrenSetSlotNull(socketManager->vm, 2); + } + + wrenCall(socketManager->vm, callHandle); + + // Safely release handles here on the main thread + wrenReleaseHandle(socketManager->vm, context->socketHandle); + wrenReleaseHandle(socketManager->vm, context->callback); + free_socket_context_data(context); + } + wrenReleaseHandle(socketManager->vm, callHandle); +} + +// ... (The rest of the foreign functions from socketAllocate onwards are identical to the previous response) ... +void socketAllocate(WrenVM* vm) { + SocketData* data = (SocketData*)wrenSetSlotNewForeign(vm, 0, 0, sizeof(SocketData)); + data->sock = INVALID_SOCKET; + data->isListener = false; +} + +void socketConnect(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_CONNECT; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->host = strdup(wrenGetSlotString(vm, 1)); + context->port = (int)wrenGetSlotDouble(vm, 2); + context->callback = wrenGetSlotHandle(vm, 3); + queue_push(&socketManager->requestQueue, context); +} + +void socketListen(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + const char* host = wrenGetSlotString(vm, 1); + int port = (int)wrenGetSlotDouble(vm, 2); + int backlog = (int)wrenGetSlotDouble(vm, 3); + + struct addrinfo hints = {0}, *addrs; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE; + + char port_str[6]; + snprintf(port_str, 6, "%d", port); + if (getaddrinfo(host, port_str, &hints, &addrs) != 0) { + wrenSetSlotBool(vm, 0, false); + return; + } + + socket_t sock = INVALID_SOCKET; + for (struct addrinfo* addr = addrs; addr; addr = addr->ai_next) { + sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); + if (!IS_SOCKET_VALID(sock)) continue; + + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char*)&yes, sizeof(yes)); + + if (bind(sock, addr->ai_addr, (int)addr->ai_addrlen) == 0) break; + + CLOSE_SOCKET(sock); + sock = INVALID_SOCKET; + } + freeaddrinfo(addrs); + + if (IS_SOCKET_VALID(sock) && listen(sock, backlog) == 0) { + data->sock = sock; + data->isListener = true; + + #ifdef _WIN32 + EnterCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_lock(&socketManager->listener_mutex); + #endif + + if (socketManager->listener_count < MAX_LISTENERS) { + socketManager->listener_sockets[socketManager->listener_count++] = sock; + } + + #ifdef _WIN32 + LeaveCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_unlock(&socketManager->listener_mutex); + #endif + + #ifndef _WIN32 + write(socketManager->wake_pipe[1], "w", 1); + #endif + + wrenSetSlotBool(vm, 0, true); + } else { + if(IS_SOCKET_VALID(sock)) CLOSE_SOCKET(sock); + wrenSetSlotBool(vm, 0, false); + } +} + +void socketAccept(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->callback = wrenGetSlotHandle(vm, 1); + queue_push(&socketManager->acceptQueue, context); +} + +void socketRead(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_READ; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + context->callback = wrenGetSlotHandle(vm, 1); + queue_push(&socketManager->requestQueue, context); +} + +void socketWrite(WrenVM* vm) { + SocketContext* context = (SocketContext*)calloc(1, sizeof(SocketContext)); + context->operation = SOCKET_OP_WRITE; + context->vm = vm; + context->socketHandle = wrenGetSlotHandle(vm, 0); + int len; + const char* bytes = wrenGetSlotBytes(vm, 1, &len); + context->data = (char*)malloc(len); + memcpy(context->data, bytes, len); + context->dataLength = len; + context->callback = wrenGetSlotHandle(vm, 2); + queue_push(&socketManager->requestQueue, context); +} + +void socketClose(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (IS_SOCKET_VALID(data->sock)) { + if (data->isListener) { + #ifdef _WIN32 + EnterCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_lock(&socketManager->listener_mutex); + #endif + + for (int i = 0; i < socketManager->listener_count; i++) { + if (socketManager->listener_sockets[i] == data->sock) { + socketManager->listener_sockets[i] = socketManager->listener_sockets[socketManager->listener_count - 1]; + socketManager->listener_count--; + break; + } + } + + #ifdef _WIN32 + LeaveCriticalSection(&socketManager->listener_mutex); + #else + pthread_mutex_unlock(&socketManager->listener_mutex); + #endif + } + CLOSE_SOCKET(data->sock); + data->sock = INVALID_SOCKET; + } +} + +void socketIsOpen(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + wrenSetSlotBool(vm, 0, IS_SOCKET_VALID(data->sock)); +} + +void socketRemoteAddress(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (!IS_SOCKET_VALID(data->sock) || data->isListener) { + wrenSetSlotNull(vm, 0); + return; + } + + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); + char ipstr[INET6_ADDRSTRLEN]; + + if (getpeername(data->sock, (struct sockaddr*)&addr, &len) == 0) { + if (addr.ss_family == AF_INET) { + inet_ntop(AF_INET, &((struct sockaddr_in*)&addr)->sin_addr, ipstr, sizeof(ipstr)); + } else { + inet_ntop(AF_INET6, &((struct sockaddr_in6*)&addr)->sin6_addr, ipstr, sizeof(ipstr)); + } + wrenSetSlotString(vm, 0, ipstr); + } else { + wrenSetSlotNull(vm, 0); + } +} + +void socketRemotePort(WrenVM* vm) { + SocketData* data = (SocketData*)wrenGetSlotForeign(vm, 0); + if (!IS_SOCKET_VALID(data->sock) || data->isListener) { + wrenSetSlotNull(vm, 0); + return; + } + + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); + + if (getpeername(data->sock, (struct sockaddr*)&addr, &len) == 0) { + int port = 0; + if (addr.ss_family == AF_INET) { + port = ntohs(((struct sockaddr_in*)&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = ntohs(((struct sockaddr_in6*)&addr)->sin6_port); + } + wrenSetSlotDouble(vm, 0, (double)port); + } else { + wrenSetSlotNull(vm, 0); + } +} + +WrenForeignMethodFn bindSocketForeignMethod(WrenVM* vm, const char* module, const char* className, bool isStatic, const char* signature) { + if (strcmp(module, "socket") != 0) return NULL; + if (strcmp(className, "Socket") == 0 && !isStatic) { + if (strcmp(signature, "connect(_,_,_)") == 0) return socketConnect; + if (strcmp(signature, "listen(_,_,_)") == 0) return socketListen; + if (strcmp(signature, "accept(_)") == 0) return socketAccept; + // NOTE: The signature for read() in Wren takes one argument (the callback) now. + if (strcmp(signature, "read(_)") == 0) return socketRead; + if (strcmp(signature, "write_(_,_)") == 0) return socketWrite; + if (strcmp(signature, "close()") == 0) return socketClose; + if (strcmp(signature, "isOpen") == 0) return socketIsOpen; + if (strcmp(signature, "remoteAddress") == 0) return socketRemoteAddress; + if (strcmp(signature, "remotePort") == 0) return socketRemotePort; + } + return NULL; +} + +WrenForeignClassMethods bindSocketForeignClass(WrenVM* vm, const char* module, const char* className) { + WrenForeignClassMethods methods = {0, 0}; + if (strcmp(module, "socket") == 0 && strcmp(className, "Socket") == 0) { + methods.allocate = socketAllocate; + } + return methods; +} diff --git a/socket_example.wren b/socket_example.wren new file mode 100644 index 0000000..9192951 --- /dev/null +++ b/socket_example.wren @@ -0,0 +1,81 @@ +// socket_example.wren (Corrected) +import "socket" for Socket + +System.print("--- Wren Socket Echo Server and Client ---") + +var serverFiber = Fiber.new { + var server = Socket.new() + if (server.listen("localhost", 8080, 5)) { + System.print("Server listening on localhost:8080") + while (server.isOpen) { + server.accept { |err, client| + + if (err) { + System.print("Accept error: %(err)") + return + } + + System.print("Client connected!") + Fiber.new { + while (client.isOpen) { + client.read(4096) { |readErr, data| + if (readErr) { + System.print("Client disconnected.") + client.close() + return + } + System.print("Received: %(data)") + // CORRECTED: Replaced '_' with 'result' + client.write("Echo: %(data)") { |writeErr, result| + if (writeErr) System.print("Write error: %(writeErr)") + } + } + } + }.call() + } + } + } else { + System.print("Failed to start server.") + } +} + +var clientFiber = Fiber.new { + var client = Socket.new() + // CORRECTED: Replaced '_' with 'result' + client.connect("localhost", 8080) { |err, result| + if (err) { + System.print("Client connection error: %(err)") + return + } + + System.print("Client connected to server.") + // CORRECTED: Replaced '_' with 'result' + client.write("Hello from Wren!") { |writeErr, result| + if (writeErr) { + System.print("Client write error: %(writeErr)") + return + } + + client.read(1024) { |readErr, data| + if (readErr) { + System.print("Client read error: %(readErr)") + } else { + System.print("Client received: %(data)") + } + client.close() + } + } + } +} + +// Start the server +serverFiber.call() + +// Give the server a moment to start up before connecting the client +Fiber.sleep(100) + +// Start the client +clientFiber.call() + +// Let the operations complete +Fiber.sleep(1000) diff --git a/wren b/wren new file mode 100755 index 0000000..1023333 Binary files /dev/null and b/wren differ diff --git a/wren.c b/wren.c new file mode 100644 index 0000000..f28c6ea --- /dev/null +++ b/wren.c @@ -0,0 +1,13497 @@ +// MIT License +// +// Copyright (c) 2013-2021 Robert Nystrom and Wren Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Begin file "wren.h" +#ifndef wren_h +#define wren_h + +#include +#include +#include + +// The Wren semantic version number components. +#define WREN_VERSION_MAJOR 0 +#define WREN_VERSION_MINOR 4 +#define WREN_VERSION_PATCH 0 + +// A human-friendly string representation of the version. +#define WREN_VERSION_STRING "0.4.0" + +// A monotonically increasing numeric representation of the version number. Use +// this if you want to do range checks over versions. +#define WREN_VERSION_NUMBER (WREN_VERSION_MAJOR * 1000000 + \ + WREN_VERSION_MINOR * 1000 + \ + WREN_VERSION_PATCH) + +#ifndef WREN_API + #if defined(_MSC_VER) && defined(WREN_API_DLLEXPORT) + #define WREN_API __declspec( dllexport ) + #else + #define WREN_API + #endif +#endif //WREN_API + +// A single virtual machine for executing Wren code. +// +// Wren has no global state, so all state stored by a running interpreter lives +// here. +typedef struct WrenVM WrenVM; + +// A handle to a Wren object. +// +// This lets code outside of the VM hold a persistent reference to an object. +// After a handle is acquired, and until it is released, this ensures the +// garbage collector will not reclaim the object it references. +typedef struct WrenHandle WrenHandle; + +// A generic allocation function that handles all explicit memory management +// used by Wren. It's used like so: +// +// - To allocate new memory, [memory] is NULL and [newSize] is the desired +// size. It should return the allocated memory or NULL on failure. +// +// - To attempt to grow an existing allocation, [memory] is the memory, and +// [newSize] is the desired size. It should return [memory] if it was able to +// grow it in place, or a new pointer if it had to move it. +// +// - To shrink memory, [memory] and [newSize] are the same as above but it will +// always return [memory]. +// +// - To free memory, [memory] will be the memory to free and [newSize] will be +// zero. It should return NULL. +typedef void* (*WrenReallocateFn)(void* memory, size_t newSize, void* userData); + +// A function callable from Wren code, but implemented in C. +typedef void (*WrenForeignMethodFn)(WrenVM* vm); + +// A finalizer function for freeing resources owned by an instance of a foreign +// class. Unlike most foreign methods, finalizers do not have access to the VM +// and should not interact with it since it's in the middle of a garbage +// collection. +typedef void (*WrenFinalizerFn)(void* data); + +// Gives the host a chance to canonicalize the imported module name, +// potentially taking into account the (previously resolved) name of the module +// that contains the import. Typically, this is used to implement relative +// imports. +typedef const char* (*WrenResolveModuleFn)(WrenVM* vm, + const char* importer, const char* name); + +// Forward declare +struct WrenLoadModuleResult; + +// Called after loadModuleFn is called for module [name]. The original returned result +// is handed back to you in this callback, so that you can free memory if appropriate. +typedef void (*WrenLoadModuleCompleteFn)(WrenVM* vm, const char* name, struct WrenLoadModuleResult result); + +// The result of a loadModuleFn call. +// [source] is the source code for the module, or NULL if the module is not found. +// [onComplete] an optional callback that will be called once Wren is done with the result. +typedef struct WrenLoadModuleResult +{ + const char* source; + WrenLoadModuleCompleteFn onComplete; + void* userData; +} WrenLoadModuleResult; + +// Loads and returns the source code for the module [name]. +typedef WrenLoadModuleResult (*WrenLoadModuleFn)(WrenVM* vm, const char* name); + +// Returns a pointer to a foreign method on [className] in [module] with +// [signature]. +typedef WrenForeignMethodFn (*WrenBindForeignMethodFn)(WrenVM* vm, + const char* module, const char* className, bool isStatic, + const char* signature); + +// Displays a string of text to the user. +typedef void (*WrenWriteFn)(WrenVM* vm, const char* text); + +typedef enum +{ + // A syntax or resolution error detected at compile time. + WREN_ERROR_COMPILE, + + // The error message for a runtime error. + WREN_ERROR_RUNTIME, + + // One entry of a runtime error's stack trace. + WREN_ERROR_STACK_TRACE +} WrenErrorType; + +// Reports an error to the user. +// +// An error detected during compile time is reported by calling this once with +// [type] `WREN_ERROR_COMPILE`, the resolved name of the [module] and [line] +// where the error occurs, and the compiler's error [message]. +// +// A runtime error is reported by calling this once with [type] +// `WREN_ERROR_RUNTIME`, no [module] or [line], and the runtime error's +// [message]. After that, a series of [type] `WREN_ERROR_STACK_TRACE` calls are +// made for each line in the stack trace. Each of those has the resolved +// [module] and [line] where the method or function is defined and [message] is +// the name of the method or function. +typedef void (*WrenErrorFn)( + WrenVM* vm, WrenErrorType type, const char* module, int line, + const char* message); + +typedef struct +{ + // The callback invoked when the foreign object is created. + // + // This must be provided. Inside the body of this, it must call + // [wrenSetSlotNewForeign()] exactly once. + WrenForeignMethodFn allocate; + + // The callback invoked when the garbage collector is about to collect a + // foreign object's memory. + // + // This may be `NULL` if the foreign class does not need to finalize. + WrenFinalizerFn finalize; +} WrenForeignClassMethods; + +// Returns a pair of pointers to the foreign methods used to allocate and +// finalize the data for instances of [className] in resolved [module]. +typedef WrenForeignClassMethods (*WrenBindForeignClassFn)( + WrenVM* vm, const char* module, const char* className); + +typedef struct +{ + // The callback Wren will use to allocate, reallocate, and deallocate memory. + // + // If `NULL`, defaults to a built-in function that uses `realloc` and `free`. + WrenReallocateFn reallocateFn; + + // The callback Wren uses to resolve a module name. + // + // Some host applications may wish to support "relative" imports, where the + // meaning of an import string depends on the module that contains it. To + // support that without baking any policy into Wren itself, the VM gives the + // host a chance to resolve an import string. + // + // Before an import is loaded, it calls this, passing in the name of the + // module that contains the import and the import string. The host app can + // look at both of those and produce a new "canonical" string that uniquely + // identifies the module. This string is then used as the name of the module + // going forward. It is what is passed to [loadModuleFn], how duplicate + // imports of the same module are detected, and how the module is reported in + // stack traces. + // + // If you leave this function NULL, then the original import string is + // treated as the resolved string. + // + // If an import cannot be resolved by the embedder, it should return NULL and + // Wren will report that as a runtime error. + // + // Wren will take ownership of the string you return and free it for you, so + // it should be allocated using the same allocation function you provide + // above. + WrenResolveModuleFn resolveModuleFn; + + // The callback Wren uses to load a module. + // + // Since Wren does not talk directly to the file system, it relies on the + // embedder to physically locate and read the source code for a module. The + // first time an import appears, Wren will call this and pass in the name of + // the module being imported. The method will return a result, which contains + // the source code for that module. Memory for the source is owned by the + // host application, and can be freed using the onComplete callback. + // + // This will only be called once for any given module name. Wren caches the + // result internally so subsequent imports of the same module will use the + // previous source and not call this. + // + // If a module with the given name could not be found by the embedder, it + // should return NULL and Wren will report that as a runtime error. + WrenLoadModuleFn loadModuleFn; + + // The callback Wren uses to find a foreign method and bind it to a class. + // + // When a foreign method is declared in a class, this will be called with the + // foreign method's module, class, and signature when the class body is + // executed. It should return a pointer to the foreign function that will be + // bound to that method. + // + // If the foreign function could not be found, this should return NULL and + // Wren will report it as runtime error. + WrenBindForeignMethodFn bindForeignMethodFn; + + // The callback Wren uses to find a foreign class and get its foreign methods. + // + // When a foreign class is declared, this will be called with the class's + // module and name when the class body is executed. It should return the + // foreign functions uses to allocate and (optionally) finalize the bytes + // stored in the foreign object when an instance is created. + WrenBindForeignClassFn bindForeignClassFn; + + // The callback Wren uses to display text when `System.print()` or the other + // related functions are called. + // + // If this is `NULL`, Wren discards any printed text. + WrenWriteFn writeFn; + + // The callback Wren uses to report errors. + // + // When an error occurs, this will be called with the module name, line + // number, and an error message. If this is `NULL`, Wren doesn't report any + // errors. + WrenErrorFn errorFn; + + // The number of bytes Wren will allocate before triggering the first garbage + // collection. + // + // If zero, defaults to 10MB. + size_t initialHeapSize; + + // After a collection occurs, the threshold for the next collection is + // determined based on the number of bytes remaining in use. This allows Wren + // to shrink its memory usage automatically after reclaiming a large amount + // of memory. + // + // This can be used to ensure that the heap does not get too small, which can + // in turn lead to a large number of collections afterwards as the heap grows + // back to a usable size. + // + // If zero, defaults to 1MB. + size_t minHeapSize; + + // Wren will resize the heap automatically as the number of bytes + // remaining in use after a collection changes. This number determines the + // amount of additional memory Wren will use after a collection, as a + // percentage of the current heap size. + // + // For example, say that this is 50. After a garbage collection, when there + // are 400 bytes of memory still in use, the next collection will be triggered + // after a total of 600 bytes are allocated (including the 400 already in + // use.) + // + // Setting this to a smaller number wastes less memory, but triggers more + // frequent garbage collections. + // + // If zero, defaults to 50. + int heapGrowthPercent; + + // User-defined data associated with the VM. + void* userData; + +} WrenConfiguration; + +typedef enum +{ + WREN_RESULT_SUCCESS, + WREN_RESULT_COMPILE_ERROR, + WREN_RESULT_RUNTIME_ERROR +} WrenInterpretResult; + +// The type of an object stored in a slot. +// +// This is not necessarily the object's *class*, but instead its low level +// representation type. +typedef enum +{ + WREN_TYPE_BOOL, + WREN_TYPE_NUM, + WREN_TYPE_FOREIGN, + WREN_TYPE_LIST, + WREN_TYPE_MAP, + WREN_TYPE_NULL, + WREN_TYPE_STRING, + + // The object is of a type that isn't accessible by the C API. + WREN_TYPE_UNKNOWN +} WrenType; + +// Get the current wren version number. +// +// Can be used to range checks over versions. +WREN_API int wrenGetVersionNumber(); + +// Initializes [configuration] with all of its default values. +// +// Call this before setting the particular fields you care about. +WREN_API void wrenInitConfiguration(WrenConfiguration* configuration); + +// Creates a new Wren virtual machine using the given [configuration]. Wren +// will copy the configuration data, so the argument passed to this can be +// freed after calling this. If [configuration] is `NULL`, uses a default +// configuration. +WREN_API WrenVM* wrenNewVM(WrenConfiguration* configuration); + +// Disposes of all resources is use by [vm], which was previously created by a +// call to [wrenNewVM]. +WREN_API void wrenFreeVM(WrenVM* vm); + +// Immediately run the garbage collector to free unused memory. +WREN_API void wrenCollectGarbage(WrenVM* vm); + +// Runs [source], a string of Wren source code in a new fiber in [vm] in the +// context of resolved [module]. +WREN_API WrenInterpretResult wrenInterpret(WrenVM* vm, const char* module, + const char* source); + +// Creates a handle that can be used to invoke a method with [signature] on +// using a receiver and arguments that are set up on the stack. +// +// This handle can be used repeatedly to directly invoke that method from C +// code using [wrenCall]. +// +// When you are done with this handle, it must be released using +// [wrenReleaseHandle]. +WREN_API WrenHandle* wrenMakeCallHandle(WrenVM* vm, const char* signature); + +// Calls [method], using the receiver and arguments previously set up on the +// stack. +// +// [method] must have been created by a call to [wrenMakeCallHandle]. The +// arguments to the method must be already on the stack. The receiver should be +// in slot 0 with the remaining arguments following it, in order. It is an +// error if the number of arguments provided does not match the method's +// signature. +// +// After this returns, you can access the return value from slot 0 on the stack. +WREN_API WrenInterpretResult wrenCall(WrenVM* vm, WrenHandle* method); + +// Releases the reference stored in [handle]. After calling this, [handle] can +// no longer be used. +WREN_API void wrenReleaseHandle(WrenVM* vm, WrenHandle* handle); + +// The following functions are intended to be called from foreign methods or +// finalizers. The interface Wren provides to a foreign method is like a +// register machine: you are given a numbered array of slots that values can be +// read from and written to. Values always live in a slot (unless explicitly +// captured using wrenGetSlotHandle(), which ensures the garbage collector can +// find them. +// +// When your foreign function is called, you are given one slot for the receiver +// and each argument to the method. The receiver is in slot 0 and the arguments +// are in increasingly numbered slots after that. You are free to read and +// write to those slots as you want. If you want more slots to use as scratch +// space, you can call wrenEnsureSlots() to add more. +// +// When your function returns, every slot except slot zero is discarded and the +// value in slot zero is used as the return value of the method. If you don't +// store a return value in that slot yourself, it will retain its previous +// value, the receiver. +// +// While Wren is dynamically typed, C is not. This means the C interface has to +// support the various types of primitive values a Wren variable can hold: bool, +// double, string, etc. If we supported this for every operation in the C API, +// there would be a combinatorial explosion of functions, like "get a +// double-valued element from a list", "insert a string key and double value +// into a map", etc. +// +// To avoid that, the only way to convert to and from a raw C value is by going +// into and out of a slot. All other functions work with values already in a +// slot. So, to add an element to a list, you put the list in one slot, and the +// element in another. Then there is a single API function wrenInsertInList() +// that takes the element out of that slot and puts it into the list. +// +// The goal of this API is to be easy to use while not compromising performance. +// The latter means it does not do type or bounds checking at runtime except +// using assertions which are generally removed from release builds. C is an +// unsafe language, so it's up to you to be careful to use it correctly. In +// return, you get a very fast FFI. + +// Returns the number of slots available to the current foreign method. +WREN_API int wrenGetSlotCount(WrenVM* vm); + +// Ensures that the foreign method stack has at least [numSlots] available for +// use, growing the stack if needed. +// +// Does not shrink the stack if it has more than enough slots. +// +// It is an error to call this from a finalizer. +WREN_API void wrenEnsureSlots(WrenVM* vm, int numSlots); + +// Gets the type of the object in [slot]. +WREN_API WrenType wrenGetSlotType(WrenVM* vm, int slot); + +// Reads a boolean value from [slot]. +// +// It is an error to call this if the slot does not contain a boolean value. +WREN_API bool wrenGetSlotBool(WrenVM* vm, int slot); + +// Reads a byte array from [slot]. +// +// The memory for the returned string is owned by Wren. You can inspect it +// while in your foreign method, but cannot keep a pointer to it after the +// function returns, since the garbage collector may reclaim it. +// +// Returns a pointer to the first byte of the array and fill [length] with the +// number of bytes in the array. +// +// It is an error to call this if the slot does not contain a string. +WREN_API const char* wrenGetSlotBytes(WrenVM* vm, int slot, int* length); + +// Reads a number from [slot]. +// +// It is an error to call this if the slot does not contain a number. +WREN_API double wrenGetSlotDouble(WrenVM* vm, int slot); + +// Reads a foreign object from [slot] and returns a pointer to the foreign data +// stored with it. +// +// It is an error to call this if the slot does not contain an instance of a +// foreign class. +WREN_API void* wrenGetSlotForeign(WrenVM* vm, int slot); + +// Reads a string from [slot]. +// +// The memory for the returned string is owned by Wren. You can inspect it +// while in your foreign method, but cannot keep a pointer to it after the +// function returns, since the garbage collector may reclaim it. +// +// It is an error to call this if the slot does not contain a string. +WREN_API const char* wrenGetSlotString(WrenVM* vm, int slot); + +// Creates a handle for the value stored in [slot]. +// +// This will prevent the object that is referred to from being garbage collected +// until the handle is released by calling [wrenReleaseHandle()]. +WREN_API WrenHandle* wrenGetSlotHandle(WrenVM* vm, int slot); + +// Stores the boolean [value] in [slot]. +WREN_API void wrenSetSlotBool(WrenVM* vm, int slot, bool value); + +// Stores the array [length] of [bytes] in [slot]. +// +// The bytes are copied to a new string within Wren's heap, so you can free +// memory used by them after this is called. +WREN_API void wrenSetSlotBytes(WrenVM* vm, int slot, const char* bytes, size_t length); + +// Stores the numeric [value] in [slot]. +WREN_API void wrenSetSlotDouble(WrenVM* vm, int slot, double value); + +// Creates a new instance of the foreign class stored in [classSlot] with [size] +// bytes of raw storage and places the resulting object in [slot]. +// +// This does not invoke the foreign class's constructor on the new instance. If +// you need that to happen, call the constructor from Wren, which will then +// call the allocator foreign method. In there, call this to create the object +// and then the constructor will be invoked when the allocator returns. +// +// Returns a pointer to the foreign object's data. +WREN_API void* wrenSetSlotNewForeign(WrenVM* vm, int slot, int classSlot, size_t size); + +// Stores a new empty list in [slot]. +WREN_API void wrenSetSlotNewList(WrenVM* vm, int slot); + +// Stores a new empty map in [slot]. +WREN_API void wrenSetSlotNewMap(WrenVM* vm, int slot); + +// Stores null in [slot]. +WREN_API void wrenSetSlotNull(WrenVM* vm, int slot); + +// Stores the string [text] in [slot]. +// +// The [text] is copied to a new string within Wren's heap, so you can free +// memory used by it after this is called. The length is calculated using +// [strlen()]. If the string may contain any null bytes in the middle, then you +// should use [wrenSetSlotBytes()] instead. +WREN_API void wrenSetSlotString(WrenVM* vm, int slot, const char* text); + +// Stores the value captured in [handle] in [slot]. +// +// This does not release the handle for the value. +WREN_API void wrenSetSlotHandle(WrenVM* vm, int slot, WrenHandle* handle); + +// Returns the number of elements in the list stored in [slot]. +WREN_API int wrenGetListCount(WrenVM* vm, int slot); + +// Reads element [index] from the list in [listSlot] and stores it in +// [elementSlot]. +WREN_API void wrenGetListElement(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Sets the value stored at [index] in the list at [listSlot], +// to the value from [elementSlot]. +WREN_API void wrenSetListElement(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Takes the value stored at [elementSlot] and inserts it into the list stored +// at [listSlot] at [index]. +// +// As in Wren, negative indexes can be used to insert from the end. To append +// an element, use `-1` for the index. +WREN_API void wrenInsertInList(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Returns the number of entries in the map stored in [slot]. +WREN_API int wrenGetMapCount(WrenVM* vm, int slot); + +// Returns true if the key in [keySlot] is found in the map placed in [mapSlot]. +WREN_API bool wrenGetMapContainsKey(WrenVM* vm, int mapSlot, int keySlot); + +// Retrieves a value with the key in [keySlot] from the map in [mapSlot] and +// stores it in [valueSlot]. +WREN_API void wrenGetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot); + +// Takes the value stored at [valueSlot] and inserts it into the map stored +// at [mapSlot] with key [keySlot]. +WREN_API void wrenSetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot); + +// Removes a value from the map in [mapSlot], with the key from [keySlot], +// and place it in [removedValueSlot]. If not found, [removedValueSlot] is +// set to null, the same behaviour as the Wren Map API. +WREN_API void wrenRemoveMapValue(WrenVM* vm, int mapSlot, int keySlot, + int removedValueSlot); + +// Looks up the top level variable with [name] in resolved [module] and stores +// it in [slot]. +WREN_API void wrenGetVariable(WrenVM* vm, const char* module, const char* name, + int slot); + +// Looks up the top level variable with [name] in resolved [module], +// returns false if not found. The module must be imported at the time, +// use wrenHasModule to ensure that before calling. +WREN_API bool wrenHasVariable(WrenVM* vm, const char* module, const char* name); + +// Returns true if [module] has been imported/resolved before, false if not. +WREN_API bool wrenHasModule(WrenVM* vm, const char* module); + +// Sets the current fiber to be aborted, and uses the value in [slot] as the +// runtime error object. +WREN_API void wrenAbortFiber(WrenVM* vm, int slot); + +// Returns the user data associated with the WrenVM. +WREN_API void* wrenGetUserData(WrenVM* vm); + +// Sets user data associated with the WrenVM. +WREN_API void wrenSetUserData(WrenVM* vm, void* userData); + +#endif +// End file "wren.h" +// Begin file "wren_debug.h" +#ifndef wren_debug_h +#define wren_debug_h + +// Begin file "wren_value.h" +#ifndef wren_value_h +#define wren_value_h + +#include +#include + +// Begin file "wren_common.h" +#ifndef wren_common_h +#define wren_common_h + +// This header contains macros and defines used across the entire Wren +// implementation. In particular, it contains "configuration" defines that +// control how Wren works. Some of these are only used while hacking on Wren +// itself. +// +// This header is *not* intended to be included by code outside of Wren itself. + +// Wren pervasively uses the C99 integer types (uint16_t, etc.) along with some +// of the associated limit constants (UINT32_MAX, etc.). The constants are not +// part of standard C++, so aren't included by default by C++ compilers when you +// include unless __STDC_LIMIT_MACROS is defined. +#define __STDC_LIMIT_MACROS +#include + +// These flags let you control some details of the interpreter's implementation. +// Usually they trade-off a bit of portability for speed. They default to the +// most efficient behavior. + +// If true, then Wren uses a NaN-tagged double for its core value +// representation. Otherwise, it uses a larger more conventional struct. The +// former is significantly faster and more compact. The latter is useful for +// debugging and may be more portable. +// +// Defaults to on. +#ifndef WREN_NAN_TAGGING + #define WREN_NAN_TAGGING 1 +#endif + +// If true, the VM's interpreter loop uses computed gotos. See this for more: +// http://gcc.gnu.org/onlinedocs/gcc-3.1.1/gcc/Labels-as-Values.html +// Enabling this speeds up the main dispatch loop a bit, but requires compiler +// support. +// see https://bullno1.com/blog/switched-goto for alternative +// Defaults to true on supported compilers. +#ifndef WREN_COMPUTED_GOTO + #if defined(_MSC_VER) && !defined(__clang__) + // No computed gotos in Visual Studio. + #define WREN_COMPUTED_GOTO 0 + #else + #define WREN_COMPUTED_GOTO 1 + #endif +#endif + +// The VM includes a number of optional modules. You can choose to include +// these or not. By default, they are all available. To disable one, set the +// corresponding `WREN_OPT_` define to `0`. +#ifndef WREN_OPT_META + #define WREN_OPT_META 1 +#endif + +#ifndef WREN_OPT_RANDOM + #define WREN_OPT_RANDOM 1 +#endif + +// These flags are useful for debugging and hacking on Wren itself. They are not +// intended to be used for production code. They default to off. + +// Set this to true to stress test the GC. It will perform a collection before +// every allocation. This is useful to ensure that memory is always correctly +// reachable. +#define WREN_DEBUG_GC_STRESS 0 + +// Set this to true to log memory operations as they occur. +#define WREN_DEBUG_TRACE_MEMORY 0 + +// Set this to true to log garbage collections as they occur. +#define WREN_DEBUG_TRACE_GC 0 + +// Set this to true to print out the compiled bytecode of each function. +#define WREN_DEBUG_DUMP_COMPILED_CODE 0 + +// Set this to trace each instruction as it's executed. +#define WREN_DEBUG_TRACE_INSTRUCTIONS 0 + +// The maximum number of module-level variables that may be defined at one time. +// This limitation comes from the 16 bits used for the arguments to +// `CODE_LOAD_MODULE_VAR` and `CODE_STORE_MODULE_VAR`. +#define MAX_MODULE_VARS 65536 + +// The maximum number of arguments that can be passed to a method. Note that +// this limitation is hardcoded in other places in the VM, in particular, the +// `CODE_CALL_XX` instructions assume a certain maximum number. +#define MAX_PARAMETERS 16 + +// The maximum name of a method, not including the signature. This is an +// arbitrary but enforced maximum just so we know how long the method name +// strings need to be in the parser. +#define MAX_METHOD_NAME 64 + +// The maximum length of a method signature. Signatures look like: +// +// foo // Getter. +// foo() // No-argument method. +// foo(_) // One-argument method. +// foo(_,_) // Two-argument method. +// init foo() // Constructor initializer. +// +// The maximum signature length takes into account the longest method name, the +// maximum number of parameters with separators between them, "init ", and "()". +#define MAX_METHOD_SIGNATURE (MAX_METHOD_NAME + (MAX_PARAMETERS * 2) + 6) + +// The maximum length of an identifier. The only real reason for this limitation +// is so that error messages mentioning variables can be stack allocated. +#define MAX_VARIABLE_NAME 64 + +// The maximum number of fields a class can have, including inherited fields. +// This is explicit in the bytecode since `CODE_CLASS` and `CODE_SUBCLASS` take +// a single byte for the number of fields. Note that it's 255 and not 256 +// because creating a class takes the *number* of fields, not the *highest +// field index*. +#define MAX_FIELDS 255 + +// Use the VM's allocator to allocate an object of [type]. +#define ALLOCATE(vm, type) \ + ((type*)wrenReallocate(vm, NULL, 0, sizeof(type))) + +// Use the VM's allocator to allocate an object of [mainType] containing a +// flexible array of [count] objects of [arrayType]. +#define ALLOCATE_FLEX(vm, mainType, arrayType, count) \ + ((mainType*)wrenReallocate(vm, NULL, 0, \ + sizeof(mainType) + sizeof(arrayType) * (count))) + +// Use the VM's allocator to allocate an array of [count] elements of [type]. +#define ALLOCATE_ARRAY(vm, type, count) \ + ((type*)wrenReallocate(vm, NULL, 0, sizeof(type) * (count))) + +// Use the VM's allocator to free the previously allocated memory at [pointer]. +#define DEALLOCATE(vm, pointer) wrenReallocate(vm, pointer, 0, 0) + +// The Microsoft compiler does not support the "inline" modifier when compiling +// as plain C. +#if defined( _MSC_VER ) && !defined(__cplusplus) + #define inline _inline +#endif + +// This is used to clearly mark flexible-sized arrays that appear at the end of +// some dynamically-allocated structs, known as the "struct hack". +#if __STDC_VERSION__ >= 199901L + // In C99, a flexible array member is just "[]". + #define FLEXIBLE_ARRAY +#else + // Elsewhere, use a zero-sized array. It's technically undefined behavior, + // but works reliably in most known compilers. + #define FLEXIBLE_ARRAY 0 +#endif + +// Assertions are used to validate program invariants. They indicate things the +// program expects to be true about its internal state during execution. If an +// assertion fails, there is a bug in Wren. +// +// Assertions add significant overhead, so are only enabled in debug builds. +#ifdef DEBUG + + #include + + #define ASSERT(condition, message) \ + do \ + { \ + if (!(condition)) \ + { \ + fprintf(stderr, "[%s:%d] Assert failed in %s(): %s\n", \ + __FILE__, __LINE__, __func__, message); \ + abort(); \ + } \ + } while (false) + + // Indicates that we know execution should never reach this point in the + // program. In debug mode, we assert this fact because it's a bug to get here. + // + // In release mode, we use compiler-specific built in functions to tell the + // compiler the code can't be reached. This avoids "missing return" warnings + // in some cases and also lets it perform some optimizations by assuming the + // code is never reached. + #define UNREACHABLE() \ + do \ + { \ + fprintf(stderr, "[%s:%d] This code should not be reached in %s()\n", \ + __FILE__, __LINE__, __func__); \ + abort(); \ + } while (false) + +#else + + #define ASSERT(condition, message) do { } while (false) + + // Tell the compiler that this part of the code will never be reached. + #if defined( _MSC_VER ) + #define UNREACHABLE() __assume(0) + #elif (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5)) + #define UNREACHABLE() __builtin_unreachable() + #else + #define UNREACHABLE() + #endif + +#endif + +#endif +// End file "wren_common.h" +// Begin file "wren_math.h" +#ifndef wren_math_h +#define wren_math_h + +#include +#include + +// A union to let us reinterpret a double as raw bits and back. +typedef union +{ + uint64_t bits64; + uint32_t bits32[2]; + double num; +} WrenDoubleBits; + +#define WREN_DOUBLE_QNAN_POS_MIN_BITS (UINT64_C(0x7FF8000000000000)) +#define WREN_DOUBLE_QNAN_POS_MAX_BITS (UINT64_C(0x7FFFFFFFFFFFFFFF)) + +#define WREN_DOUBLE_NAN (wrenDoubleFromBits(WREN_DOUBLE_QNAN_POS_MIN_BITS)) + +static inline double wrenDoubleFromBits(uint64_t bits) +{ + WrenDoubleBits data; + data.bits64 = bits; + return data.num; +} + +static inline uint64_t wrenDoubleToBits(double num) +{ + WrenDoubleBits data; + data.num = num; + return data.bits64; +} + +#endif +// End file "wren_math.h" +// Begin file "wren_utils.h" +#ifndef wren_utils_h +#define wren_utils_h + + +// Reusable data structures and other utility functions. + +// Forward declare this here to break a cycle between wren_utils.h and +// wren_value.h. +typedef struct sObjString ObjString; + +// We need buffers of a few different types. To avoid lots of casting between +// void* and back, we'll use the preprocessor as a poor man's generics and let +// it generate a few type-specific ones. +#define DECLARE_BUFFER(name, type) \ + typedef struct \ + { \ + type* data; \ + int count; \ + int capacity; \ + } name##Buffer; \ + void wren##name##BufferInit(name##Buffer* buffer); \ + void wren##name##BufferClear(WrenVM* vm, name##Buffer* buffer); \ + void wren##name##BufferFill(WrenVM* vm, name##Buffer* buffer, type data, \ + int count); \ + void wren##name##BufferWrite(WrenVM* vm, name##Buffer* buffer, type data) + +// This should be used once for each type instantiation, somewhere in a .c file. +#define DEFINE_BUFFER(name, type) \ + void wren##name##BufferInit(name##Buffer* buffer) \ + { \ + buffer->data = NULL; \ + buffer->capacity = 0; \ + buffer->count = 0; \ + } \ + \ + void wren##name##BufferClear(WrenVM* vm, name##Buffer* buffer) \ + { \ + wrenReallocate(vm, buffer->data, 0, 0); \ + wren##name##BufferInit(buffer); \ + } \ + \ + void wren##name##BufferFill(WrenVM* vm, name##Buffer* buffer, type data, \ + int count) \ + { \ + if (buffer->capacity < buffer->count + count) \ + { \ + int capacity = wrenPowerOf2Ceil(buffer->count + count); \ + buffer->data = (type*)wrenReallocate(vm, buffer->data, \ + buffer->capacity * sizeof(type), capacity * sizeof(type)); \ + buffer->capacity = capacity; \ + } \ + \ + for (int i = 0; i < count; i++) \ + { \ + buffer->data[buffer->count++] = data; \ + } \ + } \ + \ + void wren##name##BufferWrite(WrenVM* vm, name##Buffer* buffer, type data) \ + { \ + wren##name##BufferFill(vm, buffer, data, 1); \ + } + +DECLARE_BUFFER(Byte, uint8_t); +DECLARE_BUFFER(Int, int); +DECLARE_BUFFER(String, ObjString*); + +// TODO: Change this to use a map. +typedef StringBuffer SymbolTable; + +// Initializes the symbol table. +void wrenSymbolTableInit(SymbolTable* symbols); + +// Frees all dynamically allocated memory used by the symbol table, but not the +// SymbolTable itself. +void wrenSymbolTableClear(WrenVM* vm, SymbolTable* symbols); + +// Adds name to the symbol table. Returns the index of it in the table. +int wrenSymbolTableAdd(WrenVM* vm, SymbolTable* symbols, + const char* name, size_t length); + +// Adds name to the symbol table. Returns the index of it in the table. Will +// use an existing symbol if already present. +int wrenSymbolTableEnsure(WrenVM* vm, SymbolTable* symbols, + const char* name, size_t length); + +// Looks up name in the symbol table. Returns its index if found or -1 if not. +int wrenSymbolTableFind(const SymbolTable* symbols, + const char* name, size_t length); + +void wrenBlackenSymbolTable(WrenVM* vm, SymbolTable* symbolTable); + +// Returns the number of bytes needed to encode [value] in UTF-8. +// +// Returns 0 if [value] is too large to encode. +int wrenUtf8EncodeNumBytes(int value); + +// Encodes value as a series of bytes in [bytes], which is assumed to be large +// enough to hold the encoded result. +// +// Returns the number of written bytes. +int wrenUtf8Encode(int value, uint8_t* bytes); + +// Decodes the UTF-8 sequence starting at [bytes] (which has max [length]), +// returning the code point. +// +// Returns -1 if the bytes are not a valid UTF-8 sequence. +int wrenUtf8Decode(const uint8_t* bytes, uint32_t length); + +// Returns the number of bytes in the UTF-8 sequence starting with [byte]. +// +// If the character at that index is not the beginning of a UTF-8 sequence, +// returns 0. +int wrenUtf8DecodeNumBytes(uint8_t byte); + +// Returns the smallest power of two that is equal to or greater than [n]. +int wrenPowerOf2Ceil(int n); + +// Validates that [value] is within `[0, count)`. Also allows +// negative indices which map backwards from the end. Returns the valid positive +// index value. If invalid, returns `UINT32_MAX`. +uint32_t wrenValidateIndex(uint32_t count, int64_t value); + +#endif +// End file "wren_utils.h" + +// This defines the built-in types and their core representations in memory. +// Since Wren is dynamically typed, any variable can hold a value of any type, +// and the type can change at runtime. Implementing this efficiently is +// critical for performance. +// +// The main type exposed by this is [Value]. A C variable of that type is a +// storage location that can hold any Wren value. The stack, module variables, +// and instance fields are all implemented in C as variables of type Value. +// +// The built-in types for booleans, numbers, and null are unboxed: their value +// is stored directly in the Value, and copying a Value copies the value. Other +// types--classes, instances of classes, functions, lists, and strings--are all +// reference types. They are stored on the heap and the Value just stores a +// pointer to it. Copying the Value copies a reference to the same object. The +// Wren implementation calls these "Obj", or objects, though to a user, all +// values are objects. +// +// There is also a special singleton value "undefined". It is used internally +// but never appears as a real value to a user. It has two uses: +// +// - It is used to identify module variables that have been implicitly declared +// by use in a forward reference but not yet explicitly declared. These only +// exist during compilation and do not appear at runtime. +// +// - It is used to represent unused map entries in an ObjMap. +// +// There are two supported Value representations. The main one uses a technique +// called "NaN tagging" (explained in detail below) to store a number, any of +// the value types, or a pointer, all inside one double-precision floating +// point number. A larger, slower, Value type that uses a struct to store these +// is also supported, and is useful for debugging the VM. +// +// The representation is controlled by the `WREN_NAN_TAGGING` define. If that's +// defined, Nan tagging is used. + +// These macros cast a Value to one of the specific object types. These do *not* +// perform any validation, so must only be used after the Value has been +// ensured to be the right type. +#define AS_CLASS(value) ((ObjClass*)AS_OBJ(value)) // ObjClass* +#define AS_CLOSURE(value) ((ObjClosure*)AS_OBJ(value)) // ObjClosure* +#define AS_FIBER(v) ((ObjFiber*)AS_OBJ(v)) // ObjFiber* +#define AS_FN(value) ((ObjFn*)AS_OBJ(value)) // ObjFn* +#define AS_FOREIGN(v) ((ObjForeign*)AS_OBJ(v)) // ObjForeign* +#define AS_INSTANCE(value) ((ObjInstance*)AS_OBJ(value)) // ObjInstance* +#define AS_LIST(value) ((ObjList*)AS_OBJ(value)) // ObjList* +#define AS_MAP(value) ((ObjMap*)AS_OBJ(value)) // ObjMap* +#define AS_MODULE(value) ((ObjModule*)AS_OBJ(value)) // ObjModule* +#define AS_NUM(value) (wrenValueToNum(value)) // double +#define AS_RANGE(v) ((ObjRange*)AS_OBJ(v)) // ObjRange* +#define AS_STRING(v) ((ObjString*)AS_OBJ(v)) // ObjString* +#define AS_CSTRING(v) (AS_STRING(v)->value) // const char* + +// These macros promote a primitive C value to a full Wren Value. There are +// more defined below that are specific to the Nan tagged or other +// representation. +#define BOOL_VAL(boolean) ((boolean) ? TRUE_VAL : FALSE_VAL) // boolean +#define NUM_VAL(num) (wrenNumToValue(num)) // double +#define OBJ_VAL(obj) (wrenObjectToValue((Obj*)(obj))) // Any Obj___* + +// These perform type tests on a Value, returning `true` if the Value is of the +// given type. +#define IS_BOOL(value) (wrenIsBool(value)) // Bool +#define IS_CLASS(value) (wrenIsObjType(value, OBJ_CLASS)) // ObjClass +#define IS_CLOSURE(value) (wrenIsObjType(value, OBJ_CLOSURE)) // ObjClosure +#define IS_FIBER(value) (wrenIsObjType(value, OBJ_FIBER)) // ObjFiber +#define IS_FN(value) (wrenIsObjType(value, OBJ_FN)) // ObjFn +#define IS_FOREIGN(value) (wrenIsObjType(value, OBJ_FOREIGN)) // ObjForeign +#define IS_INSTANCE(value) (wrenIsObjType(value, OBJ_INSTANCE)) // ObjInstance +#define IS_LIST(value) (wrenIsObjType(value, OBJ_LIST)) // ObjList +#define IS_MAP(value) (wrenIsObjType(value, OBJ_MAP)) // ObjMap +#define IS_RANGE(value) (wrenIsObjType(value, OBJ_RANGE)) // ObjRange +#define IS_STRING(value) (wrenIsObjType(value, OBJ_STRING)) // ObjString + +// Creates a new string object from [text], which should be a bare C string +// literal. This determines the length of the string automatically at compile +// time based on the size of the character array (-1 for the terminating '\0'). +#define CONST_STRING(vm, text) wrenNewStringLength((vm), (text), sizeof(text) - 1) + +// Identifies which specific type a heap-allocated object is. +typedef enum { + OBJ_CLASS, + OBJ_CLOSURE, + OBJ_FIBER, + OBJ_FN, + OBJ_FOREIGN, + OBJ_INSTANCE, + OBJ_LIST, + OBJ_MAP, + OBJ_MODULE, + OBJ_RANGE, + OBJ_STRING, + OBJ_UPVALUE +} ObjType; + +typedef struct sObjClass ObjClass; + +// Base struct for all heap-allocated objects. +typedef struct sObj Obj; +struct sObj +{ + ObjType type; + bool isDark; + + // The object's class. + ObjClass* classObj; + + // The next object in the linked list of all currently allocated objects. + struct sObj* next; +}; + +#if WREN_NAN_TAGGING + +typedef uint64_t Value; + +#else + +typedef enum +{ + VAL_FALSE, + VAL_NULL, + VAL_NUM, + VAL_TRUE, + VAL_UNDEFINED, + VAL_OBJ +} ValueType; + +typedef struct +{ + ValueType type; + union + { + double num; + Obj* obj; + } as; +} Value; + +#endif + +DECLARE_BUFFER(Value, Value); + +// A heap-allocated string object. +struct sObjString +{ + Obj obj; + + // Number of bytes in the string, not including the null terminator. + uint32_t length; + + // The hash value of the string's contents. + uint32_t hash; + + // Inline array of the string's bytes followed by a null terminator. + char value[FLEXIBLE_ARRAY]; +}; + +// The dynamically allocated data structure for a variable that has been used +// by a closure. Whenever a function accesses a variable declared in an +// enclosing function, it will get to it through this. +// +// An upvalue can be either "closed" or "open". An open upvalue points directly +// to a [Value] that is still stored on the fiber's stack because the local +// variable is still in scope in the function where it's declared. +// +// When that local variable goes out of scope, the upvalue pointing to it will +// be closed. When that happens, the value gets copied off the stack into the +// upvalue itself. That way, it can have a longer lifetime than the stack +// variable. +typedef struct sObjUpvalue +{ + // The object header. Note that upvalues have this because they are garbage + // collected, but they are not first class Wren objects. + Obj obj; + + // Pointer to the variable this upvalue is referencing. + Value* value; + + // If the upvalue is closed (i.e. the local variable it was pointing to has + // been popped off the stack) then the closed-over value will be hoisted out + // of the stack into here. [value] will then be changed to point to this. + Value closed; + + // Open upvalues are stored in a linked list by the fiber. This points to the + // next upvalue in that list. + struct sObjUpvalue* next; +} ObjUpvalue; + +// The type of a primitive function. +// +// Primitives are similar to foreign functions, but have more direct access to +// VM internals. It is passed the arguments in [args]. If it returns a value, +// it places it in `args[0]` and returns `true`. If it causes a runtime error +// or modifies the running fiber, it returns `false`. +typedef bool (*Primitive)(WrenVM* vm, Value* args); + +// TODO: See if it's actually a perf improvement to have this in a separate +// struct instead of in ObjFn. +// Stores debugging information for a function used for things like stack +// traces. +typedef struct +{ + // The name of the function. Heap allocated and owned by the FnDebug. + char* name; + + // An array of line numbers. There is one element in this array for each + // bytecode in the function's bytecode array. The value of that element is + // the line in the source code that generated that instruction. + IntBuffer sourceLines; +} FnDebug; + +// A loaded module and the top-level variables it defines. +// +// While this is an Obj and is managed by the GC, it never appears as a +// first-class object in Wren. +typedef struct +{ + Obj obj; + + // The currently defined top-level variables. + ValueBuffer variables; + + // Symbol table for the names of all module variables. Indexes here directly + // correspond to entries in [variables]. + SymbolTable variableNames; + + // The name of the module. + ObjString* name; +} ObjModule; + +// A function object. It wraps and owns the bytecode and other debug information +// for a callable chunk of code. +// +// Function objects are not passed around and invoked directly. Instead, they +// are always referenced by an [ObjClosure] which is the real first-class +// representation of a function. This isn't strictly necessary if they function +// has no upvalues, but lets the rest of the VM assume all called objects will +// be closures. +typedef struct +{ + Obj obj; + + ByteBuffer code; + ValueBuffer constants; + + // The module where this function was defined. + ObjModule* module; + + // The maximum number of stack slots this function may use. + int maxSlots; + + // The number of upvalues this function closes over. + int numUpvalues; + + // The number of parameters this function expects. Used to ensure that .call + // handles a mismatch between number of parameters and arguments. This will + // only be set for fns, and not ObjFns that represent methods or scripts. + int arity; + FnDebug* debug; +} ObjFn; + +// An instance of a first-class function and the environment it has closed over. +// Unlike [ObjFn], this has captured the upvalues that the function accesses. +typedef struct +{ + Obj obj; + + // The function that this closure is an instance of. + ObjFn* fn; + + // The upvalues this function has closed over. + ObjUpvalue* upvalues[FLEXIBLE_ARRAY]; +} ObjClosure; + +typedef struct +{ + // Pointer to the current (really next-to-be-executed) instruction in the + // function's bytecode. + uint8_t* ip; + + // The closure being executed. + ObjClosure* closure; + + // Pointer to the first stack slot used by this call frame. This will contain + // the receiver, followed by the function's parameters, then local variables + // and temporaries. + Value* stackStart; +} CallFrame; + +// Tracks how this fiber has been invoked, aside from the ways that can be +// detected from the state of other fields in the fiber. +typedef enum +{ + // The fiber is being run from another fiber using a call to `try()`. + FIBER_TRY, + + // The fiber was directly invoked by `runInterpreter()`. This means it's the + // initial fiber used by a call to `wrenCall()` or `wrenInterpret()`. + FIBER_ROOT, + + // The fiber is invoked some other way. If [caller] is `NULL` then the fiber + // was invoked using `call()`. If [numFrames] is zero, then the fiber has + // finished running and is done. If [numFrames] is one and that frame's `ip` + // points to the first byte of code, the fiber has not been started yet. + FIBER_OTHER, +} FiberState; + +typedef struct sObjFiber +{ + Obj obj; + + // The stack of value slots. This is used for holding local variables and + // temporaries while the fiber is executing. It is heap-allocated and grown + // as needed. + Value* stack; + + // A pointer to one past the top-most value on the stack. + Value* stackTop; + + // The number of allocated slots in the stack array. + int stackCapacity; + + // The stack of call frames. This is a dynamic array that grows as needed but + // never shrinks. + CallFrame* frames; + + // The number of frames currently in use in [frames]. + int numFrames; + + // The number of [frames] allocated. + int frameCapacity; + + // Pointer to the first node in the linked list of open upvalues that are + // pointing to values still on the stack. The head of the list will be the + // upvalue closest to the top of the stack, and then the list works downwards. + ObjUpvalue* openUpvalues; + + // The fiber that ran this one. If this fiber is yielded, control will resume + // to this one. May be `NULL`. + struct sObjFiber* caller; + + // If the fiber failed because of a runtime error, this will contain the + // error object. Otherwise, it will be null. + Value error; + + FiberState state; +} ObjFiber; + +typedef enum +{ + // A primitive method implemented in C in the VM. Unlike foreign methods, + // this can directly manipulate the fiber's stack. + METHOD_PRIMITIVE, + + // A primitive that handles .call on Fn. + METHOD_FUNCTION_CALL, + + // A externally-defined C method. + METHOD_FOREIGN, + + // A normal user-defined method. + METHOD_BLOCK, + + // No method for the given symbol. + METHOD_NONE +} MethodType; + +typedef struct +{ + MethodType type; + + // The method function itself. The [type] determines which field of the union + // is used. + union + { + Primitive primitive; + WrenForeignMethodFn foreign; + ObjClosure* closure; + } as; +} Method; + +DECLARE_BUFFER(Method, Method); + +struct sObjClass +{ + Obj obj; + ObjClass* superclass; + + // The number of fields needed for an instance of this class, including all + // of its superclass fields. + int numFields; + + // The table of methods that are defined in or inherited by this class. + // Methods are called by symbol, and the symbol directly maps to an index in + // this table. This makes method calls fast at the expense of empty cells in + // the list for methods the class doesn't support. + // + // You can think of it as a hash table that never has collisions but has a + // really low load factor. Since methods are pretty small (just a type and a + // pointer), this should be a worthwhile trade-off. + MethodBuffer methods; + + // The name of the class. + ObjString* name; + + // The ClassAttribute for the class, if any + Value attributes; +}; + +typedef struct +{ + Obj obj; + uint8_t data[FLEXIBLE_ARRAY]; +} ObjForeign; + +typedef struct +{ + Obj obj; + Value fields[FLEXIBLE_ARRAY]; +} ObjInstance; + +typedef struct +{ + Obj obj; + + // The elements in the list. + ValueBuffer elements; +} ObjList; + +typedef struct +{ + // The entry's key, or UNDEFINED_VAL if the entry is not in use. + Value key; + + // The value associated with the key. If the key is UNDEFINED_VAL, this will + // be false to indicate an open available entry or true to indicate a + // tombstone -- an entry that was previously in use but was then deleted. + Value value; +} MapEntry; + +// A hash table mapping keys to values. +// +// We use something very simple: open addressing with linear probing. The hash +// table is an array of entries. Each entry is a key-value pair. If the key is +// the special UNDEFINED_VAL, it indicates no value is currently in that slot. +// Otherwise, it's a valid key, and the value is the value associated with it. +// +// When entries are added, the array is dynamically scaled by GROW_FACTOR to +// keep the number of filled slots under MAP_LOAD_PERCENT. Likewise, if the map +// gets empty enough, it will be resized to a smaller array. When this happens, +// all existing entries are rehashed and re-added to the new array. +// +// When an entry is removed, its slot is replaced with a "tombstone". This is an +// entry whose key is UNDEFINED_VAL and whose value is TRUE_VAL. When probing +// for a key, we will continue past tombstones, because the desired key may be +// found after them if the key that was removed was part of a prior collision. +// When the array gets resized, all tombstones are discarded. +typedef struct +{ + Obj obj; + + // The number of entries allocated. + uint32_t capacity; + + // The number of entries in the map. + uint32_t count; + + // Pointer to a contiguous array of [capacity] entries. + MapEntry* entries; +} ObjMap; + +typedef struct +{ + Obj obj; + + // The beginning of the range. + double from; + + // The end of the range. May be greater or less than [from]. + double to; + + // True if [to] is included in the range. + bool isInclusive; +} ObjRange; + +// An IEEE 754 double-precision float is a 64-bit value with bits laid out like: +// +// 1 Sign bit +// | 11 Exponent bits +// | | 52 Mantissa (i.e. fraction) bits +// | | | +// S[Exponent-][Mantissa------------------------------------------] +// +// The details of how these are used to represent numbers aren't really +// relevant here as long we don't interfere with them. The important bit is NaN. +// +// An IEEE double can represent a few magical values like NaN ("not a number"), +// Infinity, and -Infinity. A NaN is any value where all exponent bits are set: +// +// v--NaN bits +// -11111111111---------------------------------------------------- +// +// Here, "-" means "doesn't matter". Any bit sequence that matches the above is +// a NaN. With all of those "-", it obvious there are a *lot* of different +// bit patterns that all mean the same thing. NaN tagging takes advantage of +// this. We'll use those available bit patterns to represent things other than +// numbers without giving up any valid numeric values. +// +// NaN values come in two flavors: "signalling" and "quiet". The former are +// intended to halt execution, while the latter just flow through arithmetic +// operations silently. We want the latter. Quiet NaNs are indicated by setting +// the highest mantissa bit: +// +// v--Highest mantissa bit +// -[NaN ]1--------------------------------------------------- +// +// If all of the NaN bits are set, it's not a number. Otherwise, it is. +// That leaves all of the remaining bits as available for us to play with. We +// stuff a few different kinds of things here: special singleton values like +// "true", "false", and "null", and pointers to objects allocated on the heap. +// We'll use the sign bit to distinguish singleton values from pointers. If +// it's set, it's a pointer. +// +// v--Pointer or singleton? +// S[NaN ]1--------------------------------------------------- +// +// For singleton values, we just enumerate the different values. We'll use the +// low bits of the mantissa for that, and only need a few: +// +// 3 Type bits--v +// 0[NaN ]1------------------------------------------------[T] +// +// For pointers, we are left with 51 bits of mantissa to store an address. +// That's more than enough room for a 32-bit address. Even 64-bit machines +// only actually use 48 bits for addresses, so we've got plenty. We just stuff +// the address right into the mantissa. +// +// Ta-da, double precision numbers, pointers, and a bunch of singleton values, +// all stuffed into a single 64-bit sequence. Even better, we don't have to +// do any masking or work to extract number values: they are unmodified. This +// means math on numbers is fast. +#if WREN_NAN_TAGGING + +// A mask that selects the sign bit. +#define SIGN_BIT ((uint64_t)1 << 63) + +// The bits that must be set to indicate a quiet NaN. +#define QNAN ((uint64_t)0x7ffc000000000000) + +// If the NaN bits are set, it's not a number. +#define IS_NUM(value) (((value) & QNAN) != QNAN) + +// An object pointer is a NaN with a set sign bit. +#define IS_OBJ(value) (((value) & (QNAN | SIGN_BIT)) == (QNAN | SIGN_BIT)) + +#define IS_FALSE(value) ((value) == FALSE_VAL) +#define IS_NULL(value) ((value) == NULL_VAL) +#define IS_UNDEFINED(value) ((value) == UNDEFINED_VAL) + +// Masks out the tag bits used to identify the singleton value. +#define MASK_TAG (7) + +// Tag values for the different singleton values. +#define TAG_NAN (0) +#define TAG_NULL (1) +#define TAG_FALSE (2) +#define TAG_TRUE (3) +#define TAG_UNDEFINED (4) +#define TAG_UNUSED2 (5) +#define TAG_UNUSED3 (6) +#define TAG_UNUSED4 (7) + +// Value -> 0 or 1. +#define AS_BOOL(value) ((value) == TRUE_VAL) + +// Value -> Obj*. +#define AS_OBJ(value) ((Obj*)(uintptr_t)((value) & ~(SIGN_BIT | QNAN))) + +// Singleton values. +#define NULL_VAL ((Value)(uint64_t)(QNAN | TAG_NULL)) +#define FALSE_VAL ((Value)(uint64_t)(QNAN | TAG_FALSE)) +#define TRUE_VAL ((Value)(uint64_t)(QNAN | TAG_TRUE)) +#define UNDEFINED_VAL ((Value)(uint64_t)(QNAN | TAG_UNDEFINED)) + +// Gets the singleton type tag for a Value (which must be a singleton). +#define GET_TAG(value) ((int)((value) & MASK_TAG)) + +#else + +// Value -> 0 or 1. +#define AS_BOOL(value) ((value).type == VAL_TRUE) + +// Value -> Obj*. +#define AS_OBJ(v) ((v).as.obj) + +// Determines if [value] is a garbage-collected object or not. +#define IS_OBJ(value) ((value).type == VAL_OBJ) + +#define IS_FALSE(value) ((value).type == VAL_FALSE) +#define IS_NULL(value) ((value).type == VAL_NULL) +#define IS_NUM(value) ((value).type == VAL_NUM) +#define IS_UNDEFINED(value) ((value).type == VAL_UNDEFINED) + +// Singleton values. +#define FALSE_VAL ((Value){ VAL_FALSE, { 0 } }) +#define NULL_VAL ((Value){ VAL_NULL, { 0 } }) +#define TRUE_VAL ((Value){ VAL_TRUE, { 0 } }) +#define UNDEFINED_VAL ((Value){ VAL_UNDEFINED, { 0 } }) + +#endif + +// Creates a new "raw" class. It has no metaclass or superclass whatsoever. +// This is only used for bootstrapping the initial Object and Class classes, +// which are a little special. +ObjClass* wrenNewSingleClass(WrenVM* vm, int numFields, ObjString* name); + +// Makes [superclass] the superclass of [subclass], and causes subclass to +// inherit its methods. This should be called before any methods are defined +// on subclass. +void wrenBindSuperclass(WrenVM* vm, ObjClass* subclass, ObjClass* superclass); + +// Creates a new class object as well as its associated metaclass. +ObjClass* wrenNewClass(WrenVM* vm, ObjClass* superclass, int numFields, + ObjString* name); + +void wrenBindMethod(WrenVM* vm, ObjClass* classObj, int symbol, Method method); + +// Creates a new closure object that invokes [fn]. Allocates room for its +// upvalues, but assumes outside code will populate it. +ObjClosure* wrenNewClosure(WrenVM* vm, ObjFn* fn); + +// Creates a new fiber object that will invoke [closure]. +ObjFiber* wrenNewFiber(WrenVM* vm, ObjClosure* closure); + +// Adds a new [CallFrame] to [fiber] invoking [closure] whose stack starts at +// [stackStart]. +static inline void wrenAppendCallFrame(WrenVM* vm, ObjFiber* fiber, + ObjClosure* closure, Value* stackStart) +{ + // The caller should have ensured we already have enough capacity. + ASSERT(fiber->frameCapacity > fiber->numFrames, "No memory for call frame."); + + CallFrame* frame = &fiber->frames[fiber->numFrames++]; + frame->stackStart = stackStart; + frame->closure = closure; + frame->ip = closure->fn->code.data; +} + +// Ensures [fiber]'s stack has at least [needed] slots. +void wrenEnsureStack(WrenVM* vm, ObjFiber* fiber, int needed); + +static inline bool wrenHasError(const ObjFiber* fiber) +{ + return !IS_NULL(fiber->error); +} + +ObjForeign* wrenNewForeign(WrenVM* vm, ObjClass* classObj, size_t size); + +// Creates a new empty function. Before being used, it must have code, +// constants, etc. added to it. +ObjFn* wrenNewFunction(WrenVM* vm, ObjModule* module, int maxSlots); + +void wrenFunctionBindName(WrenVM* vm, ObjFn* fn, const char* name, int length); + +// Creates a new instance of the given [classObj]. +Value wrenNewInstance(WrenVM* vm, ObjClass* classObj); + +// Creates a new list with [numElements] elements (which are left +// uninitialized.) +ObjList* wrenNewList(WrenVM* vm, uint32_t numElements); + +// Inserts [value] in [list] at [index], shifting down the other elements. +void wrenListInsert(WrenVM* vm, ObjList* list, Value value, uint32_t index); + +// Removes and returns the item at [index] from [list]. +Value wrenListRemoveAt(WrenVM* vm, ObjList* list, uint32_t index); + +// Searches for [value] in [list], returns the index or -1 if not found. +int wrenListIndexOf(WrenVM* vm, ObjList* list, Value value); + +// Creates a new empty map. +ObjMap* wrenNewMap(WrenVM* vm); + +// Validates that [arg] is a valid object for use as a map key. Returns true if +// it is and returns false otherwise. Use validateKey usually, for a runtime error. +// This separation exists to aid the API in surfacing errors to the developer as well. +static inline bool wrenMapIsValidKey(Value arg); + +// Looks up [key] in [map]. If found, returns the value. Otherwise, returns +// `UNDEFINED_VAL`. +Value wrenMapGet(ObjMap* map, Value key); + +// Associates [key] with [value] in [map]. +void wrenMapSet(WrenVM* vm, ObjMap* map, Value key, Value value); + +void wrenMapClear(WrenVM* vm, ObjMap* map); + +// Removes [key] from [map], if present. Returns the value for the key if found +// or `NULL_VAL` otherwise. +Value wrenMapRemoveKey(WrenVM* vm, ObjMap* map, Value key); + +// Creates a new module. +ObjModule* wrenNewModule(WrenVM* vm, ObjString* name); + +// Creates a new range from [from] to [to]. +Value wrenNewRange(WrenVM* vm, double from, double to, bool isInclusive); + +// Creates a new string object and copies [text] into it. +// +// [text] must be non-NULL. +Value wrenNewString(WrenVM* vm, const char* text); + +// Creates a new string object of [length] and copies [text] into it. +// +// [text] may be NULL if [length] is zero. +Value wrenNewStringLength(WrenVM* vm, const char* text, size_t length); + +// Creates a new string object by taking a range of characters from [source]. +// The range starts at [start], contains [count] bytes, and increments by +// [step]. +Value wrenNewStringFromRange(WrenVM* vm, ObjString* source, int start, + uint32_t count, int step); + +// Produces a string representation of [value]. +Value wrenNumToString(WrenVM* vm, double value); + +// Creates a new formatted string from [format] and any additional arguments +// used in the format string. +// +// This is a very restricted flavor of formatting, intended only for internal +// use by the VM. Two formatting characters are supported, each of which reads +// the next argument as a certain type: +// +// $ - A C string. +// @ - A Wren string object. +Value wrenStringFormat(WrenVM* vm, const char* format, ...); + +// Creates a new string containing the UTF-8 encoding of [value]. +Value wrenStringFromCodePoint(WrenVM* vm, int value); + +// Creates a new string from the integer representation of a byte +Value wrenStringFromByte(WrenVM* vm, uint8_t value); + +// Creates a new string containing the code point in [string] starting at byte +// [index]. If [index] points into the middle of a UTF-8 sequence, returns an +// empty string. +Value wrenStringCodePointAt(WrenVM* vm, ObjString* string, uint32_t index); + +// Search for the first occurence of [needle] within [haystack] and returns its +// zero-based offset. Returns `UINT32_MAX` if [haystack] does not contain +// [needle]. +uint32_t wrenStringFind(ObjString* haystack, ObjString* needle, + uint32_t startIndex); + +// Returns true if [a] and [b] represent the same string. +static inline bool wrenStringEqualsCString(const ObjString* a, + const char* b, size_t length) +{ + return a->length == length && memcmp(a->value, b, length) == 0; +} + +// Creates a new open upvalue pointing to [value] on the stack. +ObjUpvalue* wrenNewUpvalue(WrenVM* vm, Value* value); + +// Mark [obj] as reachable and still in use. This should only be called +// during the sweep phase of a garbage collection. +void wrenGrayObj(WrenVM* vm, Obj* obj); + +// Mark [value] as reachable and still in use. This should only be called +// during the sweep phase of a garbage collection. +void wrenGrayValue(WrenVM* vm, Value value); + +// Mark the values in [buffer] as reachable and still in use. This should only +// be called during the sweep phase of a garbage collection. +void wrenGrayBuffer(WrenVM* vm, ValueBuffer* buffer); + +// Processes every object in the gray stack until all reachable objects have +// been marked. After that, all objects are either white (freeable) or black +// (in use and fully traversed). +void wrenBlackenObjects(WrenVM* vm); + +// Releases all memory owned by [obj], including [obj] itself. +void wrenFreeObj(WrenVM* vm, Obj* obj); + +// Returns the class of [value]. +// +// Unlike wrenGetClassInline in wren_vm.h, this is not inlined. Inlining helps +// performance (significantly) in some cases, but degrades it in others. The +// ones used by the implementation were chosen to give the best results in the +// benchmarks. +ObjClass* wrenGetClass(WrenVM* vm, Value value); + +// Returns true if [a] and [b] are strictly the same value. This is identity +// for object values, and value equality for unboxed values. +static inline bool wrenValuesSame(Value a, Value b) +{ +#if WREN_NAN_TAGGING + // Value types have unique bit representations and we compare object types + // by identity (i.e. pointer), so all we need to do is compare the bits. + return a == b; +#else + if (a.type != b.type) return false; + if (a.type == VAL_NUM) return a.as.num == b.as.num; + return a.as.obj == b.as.obj; +#endif +} + +// Returns true if [a] and [b] are equivalent. Immutable values (null, bools, +// numbers, ranges, and strings) are equal if they have the same data. All +// other values are equal if they are identical objects. +bool wrenValuesEqual(Value a, Value b); + +// Returns true if [value] is a bool. Do not call this directly, instead use +// [IS_BOOL]. +static inline bool wrenIsBool(Value value) +{ +#if WREN_NAN_TAGGING + return value == TRUE_VAL || value == FALSE_VAL; +#else + return value.type == VAL_FALSE || value.type == VAL_TRUE; +#endif +} + +// Returns true if [value] is an object of type [type]. Do not call this +// directly, instead use the [IS___] macro for the type in question. +static inline bool wrenIsObjType(Value value, ObjType type) +{ + return IS_OBJ(value) && AS_OBJ(value)->type == type; +} + +// Converts the raw object pointer [obj] to a [Value]. +static inline Value wrenObjectToValue(Obj* obj) +{ +#if WREN_NAN_TAGGING + // The triple casting is necessary here to satisfy some compilers: + // 1. (uintptr_t) Convert the pointer to a number of the right size. + // 2. (uint64_t) Pad it up to 64 bits in 32-bit builds. + // 3. Or in the bits to make a tagged Nan. + // 4. Cast to a typedef'd value. + return (Value)(SIGN_BIT | QNAN | (uint64_t)(uintptr_t)(obj)); +#else + Value value; + value.type = VAL_OBJ; + value.as.obj = obj; + return value; +#endif +} + +// Interprets [value] as a [double]. +static inline double wrenValueToNum(Value value) +{ +#if WREN_NAN_TAGGING + return wrenDoubleFromBits(value); +#else + return value.as.num; +#endif +} + +// Converts [num] to a [Value]. +static inline Value wrenNumToValue(double num) +{ +#if WREN_NAN_TAGGING + return wrenDoubleToBits(num); +#else + Value value; + value.type = VAL_NUM; + value.as.num = num; + return value; +#endif +} + +static inline bool wrenMapIsValidKey(Value arg) +{ + return IS_BOOL(arg) + || IS_CLASS(arg) + || IS_NULL(arg) + || IS_NUM(arg) + || IS_RANGE(arg) + || IS_STRING(arg); +} + +#endif +// End file "wren_value.h" +// Begin file "wren_vm.h" +#ifndef wren_vm_h +#define wren_vm_h + +// Begin file "wren_compiler.h" +#ifndef wren_compiler_h +#define wren_compiler_h + + +typedef struct sCompiler Compiler; + +// This module defines the compiler for Wren. It takes a string of source code +// and lexes, parses, and compiles it. Wren uses a single-pass compiler. It +// does not build an actual AST during parsing and then consume that to +// generate code. Instead, the parser directly emits bytecode. +// +// This forces a few restrictions on the grammar and semantics of the language. +// Things like forward references and arbitrary lookahead are much harder. We +// get a lot in return for that, though. +// +// The implementation is much simpler since we don't need to define a bunch of +// AST data structures. More so, we don't have to deal with managing memory for +// AST objects. The compiler does almost no dynamic allocation while running. +// +// Compilation is also faster since we don't create a bunch of temporary data +// structures and destroy them after generating code. + +// Compiles [source], a string of Wren source code located in [module], to an +// [ObjFn] that will execute that code when invoked. Returns `NULL` if the +// source contains any syntax errors. +// +// If [isExpression] is `true`, [source] should be a single expression, and +// this compiles it to a function that evaluates and returns that expression. +// Otherwise, [source] should be a series of top level statements. +// +// If [printErrors] is `true`, any compile errors are output to stderr. +// Otherwise, they are silently discarded. +ObjFn* wrenCompile(WrenVM* vm, ObjModule* module, const char* source, + bool isExpression, bool printErrors); + +// When a class is defined, its superclass is not known until runtime since +// class definitions are just imperative statements. Most of the bytecode for a +// a method doesn't care, but there are two places where it matters: +// +// - To load or store a field, we need to know the index of the field in the +// instance's field array. We need to adjust this so that subclass fields +// are positioned after superclass fields, and we don't know this until the +// superclass is known. +// +// - Superclass calls need to know which superclass to dispatch to. +// +// We could handle this dynamically, but that adds overhead. Instead, when a +// method is bound, we walk the bytecode for the function and patch it up. +void wrenBindMethodCode(ObjClass* classObj, ObjFn* fn); + +// Reaches all of the heap-allocated objects in use by [compiler] (and all of +// its parents) so that they are not collected by the GC. +void wrenMarkCompiler(WrenVM* vm, Compiler* compiler); + +#endif +// End file "wren_compiler.h" + +// The maximum number of temporary objects that can be made visible to the GC +// at one time. +#define WREN_MAX_TEMP_ROOTS 8 + +typedef enum +{ + #define OPCODE(name, _) CODE_##name, +// Begin file "wren_opcodes.h" +// This defines the bytecode instructions used by the VM. It does so by invoking +// an OPCODE() macro which is expected to be defined at the point that this is +// included. (See: http://en.wikipedia.org/wiki/X_Macro for more.) +// +// The first argument is the name of the opcode. The second is its "stack +// effect" -- the amount that the op code changes the size of the stack. A +// stack effect of 1 means it pushes a value and the stack grows one larger. +// -2 means it pops two values, etc. +// +// Note that the order of instructions here affects the order of the dispatch +// table in the VM's interpreter loop. That in turn affects caching which +// affects overall performance. Take care to run benchmarks if you change the +// order here. + +// Load the constant at index [arg]. +OPCODE(CONSTANT, 1) + +// Push null onto the stack. +OPCODE(NULL, 1) + +// Push false onto the stack. +OPCODE(FALSE, 1) + +// Push true onto the stack. +OPCODE(TRUE, 1) + +// Pushes the value in the given local slot. +OPCODE(LOAD_LOCAL_0, 1) +OPCODE(LOAD_LOCAL_1, 1) +OPCODE(LOAD_LOCAL_2, 1) +OPCODE(LOAD_LOCAL_3, 1) +OPCODE(LOAD_LOCAL_4, 1) +OPCODE(LOAD_LOCAL_5, 1) +OPCODE(LOAD_LOCAL_6, 1) +OPCODE(LOAD_LOCAL_7, 1) +OPCODE(LOAD_LOCAL_8, 1) + +// Note: The compiler assumes the following _STORE instructions always +// immediately follow their corresponding _LOAD ones. + +// Pushes the value in local slot [arg]. +OPCODE(LOAD_LOCAL, 1) + +// Stores the top of stack in local slot [arg]. Does not pop it. +OPCODE(STORE_LOCAL, 0) + +// Pushes the value in upvalue [arg]. +OPCODE(LOAD_UPVALUE, 1) + +// Stores the top of stack in upvalue [arg]. Does not pop it. +OPCODE(STORE_UPVALUE, 0) + +// Pushes the value of the top-level variable in slot [arg]. +OPCODE(LOAD_MODULE_VAR, 1) + +// Stores the top of stack in top-level variable slot [arg]. Does not pop it. +OPCODE(STORE_MODULE_VAR, 0) + +// Pushes the value of the field in slot [arg] of the receiver of the current +// function. This is used for regular field accesses on "this" directly in +// methods. This instruction is faster than the more general CODE_LOAD_FIELD +// instruction. +OPCODE(LOAD_FIELD_THIS, 1) + +// Stores the top of the stack in field slot [arg] in the receiver of the +// current value. Does not pop the value. This instruction is faster than the +// more general CODE_LOAD_FIELD instruction. +OPCODE(STORE_FIELD_THIS, 0) + +// Pops an instance and pushes the value of the field in slot [arg] of it. +OPCODE(LOAD_FIELD, 0) + +// Pops an instance and stores the subsequent top of stack in field slot +// [arg] in it. Does not pop the value. +OPCODE(STORE_FIELD, -1) + +// Pop and discard the top of stack. +OPCODE(POP, -1) + +// Invoke the method with symbol [arg]. The number indicates the number of +// arguments (not including the receiver). +OPCODE(CALL_0, 0) +OPCODE(CALL_1, -1) +OPCODE(CALL_2, -2) +OPCODE(CALL_3, -3) +OPCODE(CALL_4, -4) +OPCODE(CALL_5, -5) +OPCODE(CALL_6, -6) +OPCODE(CALL_7, -7) +OPCODE(CALL_8, -8) +OPCODE(CALL_9, -9) +OPCODE(CALL_10, -10) +OPCODE(CALL_11, -11) +OPCODE(CALL_12, -12) +OPCODE(CALL_13, -13) +OPCODE(CALL_14, -14) +OPCODE(CALL_15, -15) +OPCODE(CALL_16, -16) + +// Invoke a superclass method with symbol [arg]. The number indicates the +// number of arguments (not including the receiver). +OPCODE(SUPER_0, 0) +OPCODE(SUPER_1, -1) +OPCODE(SUPER_2, -2) +OPCODE(SUPER_3, -3) +OPCODE(SUPER_4, -4) +OPCODE(SUPER_5, -5) +OPCODE(SUPER_6, -6) +OPCODE(SUPER_7, -7) +OPCODE(SUPER_8, -8) +OPCODE(SUPER_9, -9) +OPCODE(SUPER_10, -10) +OPCODE(SUPER_11, -11) +OPCODE(SUPER_12, -12) +OPCODE(SUPER_13, -13) +OPCODE(SUPER_14, -14) +OPCODE(SUPER_15, -15) +OPCODE(SUPER_16, -16) + +// Jump the instruction pointer [arg] forward. +OPCODE(JUMP, 0) + +// Jump the instruction pointer [arg] backward. +OPCODE(LOOP, 0) + +// Pop and if not truthy then jump the instruction pointer [arg] forward. +OPCODE(JUMP_IF, -1) + +// If the top of the stack is false, jump [arg] forward. Otherwise, pop and +// continue. +OPCODE(AND, -1) + +// If the top of the stack is non-false, jump [arg] forward. Otherwise, pop +// and continue. +OPCODE(OR, -1) + +// Close the upvalue for the local on the top of the stack, then pop it. +OPCODE(CLOSE_UPVALUE, -1) + +// Exit from the current function and return the value on the top of the +// stack. +OPCODE(RETURN, 0) + +// Creates a closure for the function stored at [arg] in the constant table. +// +// Following the function argument is a number of arguments, two for each +// upvalue. The first is true if the variable being captured is a local (as +// opposed to an upvalue), and the second is the index of the local or +// upvalue being captured. +// +// Pushes the created closure. +OPCODE(CLOSURE, 1) + +// Creates a new instance of a class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(CONSTRUCT, 0) + +// Creates a new instance of a foreign class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(FOREIGN_CONSTRUCT, 0) + +// Creates a class. Top of stack is the superclass. Below that is a string for +// the name of the class. Byte [arg] is the number of fields in the class. +OPCODE(CLASS, -1) + +// Ends a class. +// Atm the stack contains the class and the ClassAttributes (or null). +OPCODE(END_CLASS, -2) + +// Creates a foreign class. Top of stack is the superclass. Below that is a +// string for the name of the class. +OPCODE(FOREIGN_CLASS, -1) + +// Define a method for symbol [arg]. The class receiving the method is popped +// off the stack, then the function defining the body is popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_INSTANCE, -2) + +// Define a method for symbol [arg]. The class whose metaclass will receive +// the method is popped off the stack, then the function defining the body is +// popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_STATIC, -2) + +// This is executed at the end of the module's body. Pushes NULL onto the stack +// as the "return value" of the import statement and stores the module as the +// most recently imported one. +OPCODE(END_MODULE, 1) + +// Import a module whose name is the string stored at [arg] in the constant +// table. +// +// Pushes null onto the stack so that the fiber for the imported module can +// replace that with a dummy value when it returns. (Fibers always return a +// value when resuming a caller.) +OPCODE(IMPORT_MODULE, 1) + +// Import a variable from the most recently imported module. The name of the +// variable to import is at [arg] in the constant table. Pushes the loaded +// variable's value. +OPCODE(IMPORT_VARIABLE, 1) + +// This pseudo-instruction indicates the end of the bytecode. It should +// always be preceded by a `CODE_RETURN`, so is never actually executed. +OPCODE(END, 0) +// End file "wren_opcodes.h" + #undef OPCODE +} Code; + +// A handle to a value, basically just a linked list of extra GC roots. +// +// Note that even non-heap-allocated values can be stored here. +struct WrenHandle +{ + Value value; + + WrenHandle* prev; + WrenHandle* next; +}; + +struct WrenVM +{ + ObjClass* boolClass; + ObjClass* classClass; + ObjClass* fiberClass; + ObjClass* fnClass; + ObjClass* listClass; + ObjClass* mapClass; + ObjClass* nullClass; + ObjClass* numClass; + ObjClass* objectClass; + ObjClass* rangeClass; + ObjClass* stringClass; + + // The fiber that is currently running. + ObjFiber* fiber; + + // The loaded modules. Each key is an ObjString (except for the main module, + // whose key is null) for the module's name and the value is the ObjModule + // for the module. + ObjMap* modules; + + // The most recently imported module. More specifically, the module whose + // code has most recently finished executing. + // + // Not treated like a GC root since the module is already in [modules]. + ObjModule* lastModule; + + // Memory management data: + + // The number of bytes that are known to be currently allocated. Includes all + // memory that was proven live after the last GC, as well as any new bytes + // that were allocated since then. Does *not* include bytes for objects that + // were freed since the last GC. + size_t bytesAllocated; + + // The number of total allocated bytes that will trigger the next GC. + size_t nextGC; + + // The first object in the linked list of all currently allocated objects. + Obj* first; + + // The "gray" set for the garbage collector. This is the stack of unprocessed + // objects while a garbage collection pass is in process. + Obj** gray; + int grayCount; + int grayCapacity; + + // The list of temporary roots. This is for temporary or new objects that are + // not otherwise reachable but should not be collected. + // + // They are organized as a stack of pointers stored in this array. This + // implies that temporary roots need to have stack semantics: only the most + // recently pushed object can be released. + Obj* tempRoots[WREN_MAX_TEMP_ROOTS]; + + int numTempRoots; + + // Pointer to the first node in the linked list of active handles or NULL if + // there are none. + WrenHandle* handles; + + // Pointer to the bottom of the range of stack slots available for use from + // the C API. During a foreign method, this will be in the stack of the fiber + // that is executing a method. + // + // If not in a foreign method, this is initially NULL. If the user requests + // slots by calling wrenEnsureSlots(), a stack is created and this is + // initialized. + Value* apiStack; + + WrenConfiguration config; + + // Compiler and debugger data: + + // The compiler that is currently compiling code. This is used so that heap + // allocated objects used by the compiler can be found if a GC is kicked off + // in the middle of a compile. + Compiler* compiler; + + // There is a single global symbol table for all method names on all classes. + // Method calls are dispatched directly by index in this table. + SymbolTable methodNames; +}; + +// A generic allocation function that handles all explicit memory management. +// It's used like so: +// +// - To allocate new memory, [memory] is NULL and [oldSize] is zero. It should +// return the allocated memory or NULL on failure. +// +// - To attempt to grow an existing allocation, [memory] is the memory, +// [oldSize] is its previous size, and [newSize] is the desired size. +// It should return [memory] if it was able to grow it in place, or a new +// pointer if it had to move it. +// +// - To shrink memory, [memory], [oldSize], and [newSize] are the same as above +// but it will always return [memory]. +// +// - To free memory, [memory] will be the memory to free and [newSize] and +// [oldSize] will be zero. It should return NULL. +void* wrenReallocate(WrenVM* vm, void* memory, size_t oldSize, size_t newSize); + +// Invoke the finalizer for the foreign object referenced by [foreign]. +void wrenFinalizeForeign(WrenVM* vm, ObjForeign* foreign); + +// Creates a new [WrenHandle] for [value]. +WrenHandle* wrenMakeHandle(WrenVM* vm, Value value); + +// Compile [source] in the context of [module] and wrap in a fiber that can +// execute it. +// +// Returns NULL if a compile error occurred. +ObjClosure* wrenCompileSource(WrenVM* vm, const char* module, + const char* source, bool isExpression, + bool printErrors); + +// Looks up a variable from a previously-loaded module. +// +// Aborts the current fiber if the module or variable could not be found. +Value wrenGetModuleVariable(WrenVM* vm, Value moduleName, Value variableName); + +// Returns the value of the module-level variable named [name] in the main +// module. +Value wrenFindVariable(WrenVM* vm, ObjModule* module, const char* name); + +// Adds a new implicitly declared top-level variable named [name] to [module] +// based on a use site occurring on [line]. +// +// Does not check to see if a variable with that name is already declared or +// defined. Returns the symbol for the new variable or -2 if there are too many +// variables defined. +int wrenDeclareVariable(WrenVM* vm, ObjModule* module, const char* name, + size_t length, int line); + +// Adds a new top-level variable named [name] to [module], and optionally +// populates line with the line of the implicit first use (line can be NULL). +// +// Returns the symbol for the new variable, -1 if a variable with the given name +// is already defined, or -2 if there are too many variables defined. +// Returns -3 if this is a top-level lowercase variable (localname) that was +// used before being defined. +int wrenDefineVariable(WrenVM* vm, ObjModule* module, const char* name, + size_t length, Value value, int* line); + +// Pushes [closure] onto [fiber]'s callstack to invoke it. Expects [numArgs] +// arguments (including the receiver) to be on the top of the stack already. +static inline void wrenCallFunction(WrenVM* vm, ObjFiber* fiber, + ObjClosure* closure, int numArgs) +{ + // Grow the call frame array if needed. + if (fiber->numFrames + 1 > fiber->frameCapacity) + { + int max = fiber->frameCapacity * 2; + fiber->frames = (CallFrame*)wrenReallocate(vm, fiber->frames, + sizeof(CallFrame) * fiber->frameCapacity, sizeof(CallFrame) * max); + fiber->frameCapacity = max; + } + + // Grow the stack if needed. + int stackSize = (int)(fiber->stackTop - fiber->stack); + int needed = stackSize + closure->fn->maxSlots; + wrenEnsureStack(vm, fiber, needed); + + wrenAppendCallFrame(vm, fiber, closure, fiber->stackTop - numArgs); +} + +// Marks [obj] as a GC root so that it doesn't get collected. +void wrenPushRoot(WrenVM* vm, Obj* obj); + +// Removes the most recently pushed temporary root. +void wrenPopRoot(WrenVM* vm); + +// Returns the class of [value]. +// +// Defined here instead of in wren_value.h because it's critical that this be +// inlined. That means it must be defined in the header, but the wren_value.h +// header doesn't have a full definitely of WrenVM yet. +static inline ObjClass* wrenGetClassInline(WrenVM* vm, Value value) +{ + if (IS_NUM(value)) return vm->numClass; + if (IS_OBJ(value)) return AS_OBJ(value)->classObj; + +#if WREN_NAN_TAGGING + switch (GET_TAG(value)) + { + case TAG_FALSE: return vm->boolClass; break; + case TAG_NAN: return vm->numClass; break; + case TAG_NULL: return vm->nullClass; break; + case TAG_TRUE: return vm->boolClass; break; + case TAG_UNDEFINED: UNREACHABLE(); + } +#else + switch (value.type) + { + case VAL_FALSE: return vm->boolClass; + case VAL_NULL: return vm->nullClass; + case VAL_NUM: return vm->numClass; + case VAL_TRUE: return vm->boolClass; + case VAL_OBJ: return AS_OBJ(value)->classObj; + case VAL_UNDEFINED: UNREACHABLE(); + } +#endif + + UNREACHABLE(); + return NULL; +} + +// Returns `true` if [name] is a local variable name (starts with a lowercase +// letter). +static inline bool wrenIsLocalName(const char* name) +{ + return name[0] >= 'a' && name[0] <= 'z'; +} + +static inline bool wrenIsFalsyValue(Value value) +{ + return IS_FALSE(value) || IS_NULL(value); +} + +#endif +// End file "wren_vm.h" + +// Prints the stack trace for the current fiber. +// +// Used when a fiber throws a runtime error which is not caught. +void wrenDebugPrintStackTrace(WrenVM* vm); + +// The "dump" functions are used for debugging Wren itself. Normal code paths +// will not call them unless one of the various DEBUG_ flags is enabled. + +// Prints a representation of [value] to stdout. +void wrenDumpValue(Value value); + +// Prints a representation of the bytecode for [fn] at instruction [i]. +int wrenDumpInstruction(WrenVM* vm, ObjFn* fn, int i); + +// Prints the disassembled code for [fn] to stdout. +void wrenDumpCode(WrenVM* vm, ObjFn* fn); + +// Prints the contents of the current stack for [fiber] to stdout. +void wrenDumpStack(ObjFiber* fiber); + +#endif +// End file "wren_debug.h" +// Begin file "wren_debug.c" +#include + + +void wrenDebugPrintStackTrace(WrenVM* vm) +{ + // Bail if the host doesn't enable printing errors. + if (vm->config.errorFn == NULL) return; + + ObjFiber* fiber = vm->fiber; + if (IS_STRING(fiber->error)) + { + vm->config.errorFn(vm, WREN_ERROR_RUNTIME, + NULL, -1, AS_CSTRING(fiber->error)); + } + else + { + // TODO: Print something a little useful here. Maybe the name of the error's + // class? + vm->config.errorFn(vm, WREN_ERROR_RUNTIME, + NULL, -1, "[error object]"); + } + + for (int i = fiber->numFrames - 1; i >= 0; i--) + { + CallFrame* frame = &fiber->frames[i]; + ObjFn* fn = frame->closure->fn; + + // Skip over stub functions for calling methods from the C API. + if (fn->module == NULL) continue; + + // The built-in core module has no name. We explicitly omit it from stack + // traces since we don't want to highlight to a user the implementation + // detail of what part of the core module is written in C and what is Wren. + if (fn->module->name == NULL) continue; + + // -1 because IP has advanced past the instruction that it just executed. + int line = fn->debug->sourceLines.data[frame->ip - fn->code.data - 1]; + vm->config.errorFn(vm, WREN_ERROR_STACK_TRACE, + fn->module->name->value, line, + fn->debug->name); + } +} + +static void dumpObject(Obj* obj) +{ + switch (obj->type) + { + case OBJ_CLASS: + printf("[class %s %p]", ((ObjClass*)obj)->name->value, obj); + break; + case OBJ_CLOSURE: printf("[closure %p]", obj); break; + case OBJ_FIBER: printf("[fiber %p]", obj); break; + case OBJ_FN: printf("[fn %p]", obj); break; + case OBJ_FOREIGN: printf("[foreign %p]", obj); break; + case OBJ_INSTANCE: printf("[instance %p]", obj); break; + case OBJ_LIST: printf("[list %p]", obj); break; + case OBJ_MAP: printf("[map %p]", obj); break; + case OBJ_MODULE: printf("[module %p]", obj); break; + case OBJ_RANGE: printf("[range %p]", obj); break; + case OBJ_STRING: printf("%s", ((ObjString*)obj)->value); break; + case OBJ_UPVALUE: printf("[upvalue %p]", obj); break; + default: printf("[unknown object %d]", obj->type); break; + } +} + +void wrenDumpValue(Value value) +{ +#if WREN_NAN_TAGGING + if (IS_NUM(value)) + { + printf("%.14g", AS_NUM(value)); + } + else if (IS_OBJ(value)) + { + dumpObject(AS_OBJ(value)); + } + else + { + switch (GET_TAG(value)) + { + case TAG_FALSE: printf("false"); break; + case TAG_NAN: printf("NaN"); break; + case TAG_NULL: printf("null"); break; + case TAG_TRUE: printf("true"); break; + case TAG_UNDEFINED: UNREACHABLE(); + } + } +#else + switch (value.type) + { + case VAL_FALSE: printf("false"); break; + case VAL_NULL: printf("null"); break; + case VAL_NUM: printf("%.14g", AS_NUM(value)); break; + case VAL_TRUE: printf("true"); break; + case VAL_OBJ: dumpObject(AS_OBJ(value)); break; + case VAL_UNDEFINED: UNREACHABLE(); + } +#endif +} + +static int dumpInstruction(WrenVM* vm, ObjFn* fn, int i, int* lastLine) +{ + int start = i; + uint8_t* bytecode = fn->code.data; + Code code = (Code)bytecode[i]; + + int line = fn->debug->sourceLines.data[i]; + if (lastLine == NULL || *lastLine != line) + { + printf("%4d:", line); + if (lastLine != NULL) *lastLine = line; + } + else + { + printf(" "); + } + + printf(" %04d ", i++); + + #define READ_BYTE() (bytecode[i++]) + #define READ_SHORT() (i += 2, (bytecode[i - 2] << 8) | bytecode[i - 1]) + + #define BYTE_INSTRUCTION(name) \ + printf("%-16s %5d\n", name, READ_BYTE()); \ + break + + switch (code) + { + case CODE_CONSTANT: + { + int constant = READ_SHORT(); + printf("%-16s %5d '", "CONSTANT", constant); + wrenDumpValue(fn->constants.data[constant]); + printf("'\n"); + break; + } + + case CODE_NULL: printf("NULL\n"); break; + case CODE_FALSE: printf("FALSE\n"); break; + case CODE_TRUE: printf("TRUE\n"); break; + + case CODE_LOAD_LOCAL_0: printf("LOAD_LOCAL_0\n"); break; + case CODE_LOAD_LOCAL_1: printf("LOAD_LOCAL_1\n"); break; + case CODE_LOAD_LOCAL_2: printf("LOAD_LOCAL_2\n"); break; + case CODE_LOAD_LOCAL_3: printf("LOAD_LOCAL_3\n"); break; + case CODE_LOAD_LOCAL_4: printf("LOAD_LOCAL_4\n"); break; + case CODE_LOAD_LOCAL_5: printf("LOAD_LOCAL_5\n"); break; + case CODE_LOAD_LOCAL_6: printf("LOAD_LOCAL_6\n"); break; + case CODE_LOAD_LOCAL_7: printf("LOAD_LOCAL_7\n"); break; + case CODE_LOAD_LOCAL_8: printf("LOAD_LOCAL_8\n"); break; + + case CODE_LOAD_LOCAL: BYTE_INSTRUCTION("LOAD_LOCAL"); + case CODE_STORE_LOCAL: BYTE_INSTRUCTION("STORE_LOCAL"); + case CODE_LOAD_UPVALUE: BYTE_INSTRUCTION("LOAD_UPVALUE"); + case CODE_STORE_UPVALUE: BYTE_INSTRUCTION("STORE_UPVALUE"); + + case CODE_LOAD_MODULE_VAR: + { + int slot = READ_SHORT(); + printf("%-16s %5d '%s'\n", "LOAD_MODULE_VAR", slot, + fn->module->variableNames.data[slot]->value); + break; + } + + case CODE_STORE_MODULE_VAR: + { + int slot = READ_SHORT(); + printf("%-16s %5d '%s'\n", "STORE_MODULE_VAR", slot, + fn->module->variableNames.data[slot]->value); + break; + } + + case CODE_LOAD_FIELD_THIS: BYTE_INSTRUCTION("LOAD_FIELD_THIS"); + case CODE_STORE_FIELD_THIS: BYTE_INSTRUCTION("STORE_FIELD_THIS"); + case CODE_LOAD_FIELD: BYTE_INSTRUCTION("LOAD_FIELD"); + case CODE_STORE_FIELD: BYTE_INSTRUCTION("STORE_FIELD"); + + case CODE_POP: printf("POP\n"); break; + + case CODE_CALL_0: + case CODE_CALL_1: + case CODE_CALL_2: + case CODE_CALL_3: + case CODE_CALL_4: + case CODE_CALL_5: + case CODE_CALL_6: + case CODE_CALL_7: + case CODE_CALL_8: + case CODE_CALL_9: + case CODE_CALL_10: + case CODE_CALL_11: + case CODE_CALL_12: + case CODE_CALL_13: + case CODE_CALL_14: + case CODE_CALL_15: + case CODE_CALL_16: + { + int numArgs = bytecode[i - 1] - CODE_CALL_0; + int symbol = READ_SHORT(); + printf("CALL_%-11d %5d '%s'\n", numArgs, symbol, + vm->methodNames.data[symbol]->value); + break; + } + + case CODE_SUPER_0: + case CODE_SUPER_1: + case CODE_SUPER_2: + case CODE_SUPER_3: + case CODE_SUPER_4: + case CODE_SUPER_5: + case CODE_SUPER_6: + case CODE_SUPER_7: + case CODE_SUPER_8: + case CODE_SUPER_9: + case CODE_SUPER_10: + case CODE_SUPER_11: + case CODE_SUPER_12: + case CODE_SUPER_13: + case CODE_SUPER_14: + case CODE_SUPER_15: + case CODE_SUPER_16: + { + int numArgs = bytecode[i - 1] - CODE_SUPER_0; + int symbol = READ_SHORT(); + int superclass = READ_SHORT(); + printf("SUPER_%-10d %5d '%s' %5d\n", numArgs, symbol, + vm->methodNames.data[symbol]->value, superclass); + break; + } + + case CODE_JUMP: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "JUMP", offset, i + offset); + break; + } + + case CODE_LOOP: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "LOOP", offset, i - offset); + break; + } + + case CODE_JUMP_IF: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "JUMP_IF", offset, i + offset); + break; + } + + case CODE_AND: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "AND", offset, i + offset); + break; + } + + case CODE_OR: + { + int offset = READ_SHORT(); + printf("%-16s %5d to %d\n", "OR", offset, i + offset); + break; + } + + case CODE_CLOSE_UPVALUE: printf("CLOSE_UPVALUE\n"); break; + case CODE_RETURN: printf("RETURN\n"); break; + + case CODE_CLOSURE: + { + int constant = READ_SHORT(); + printf("%-16s %5d ", "CLOSURE", constant); + wrenDumpValue(fn->constants.data[constant]); + printf(" "); + ObjFn* loadedFn = AS_FN(fn->constants.data[constant]); + for (int j = 0; j < loadedFn->numUpvalues; j++) + { + int isLocal = READ_BYTE(); + int index = READ_BYTE(); + if (j > 0) printf(", "); + printf("%s %d", isLocal ? "local" : "upvalue", index); + } + printf("\n"); + break; + } + + case CODE_CONSTRUCT: printf("CONSTRUCT\n"); break; + case CODE_FOREIGN_CONSTRUCT: printf("FOREIGN_CONSTRUCT\n"); break; + + case CODE_CLASS: + { + int numFields = READ_BYTE(); + printf("%-16s %5d fields\n", "CLASS", numFields); + break; + } + + case CODE_FOREIGN_CLASS: printf("FOREIGN_CLASS\n"); break; + case CODE_END_CLASS: printf("END_CLASS\n"); break; + + case CODE_METHOD_INSTANCE: + { + int symbol = READ_SHORT(); + printf("%-16s %5d '%s'\n", "METHOD_INSTANCE", symbol, + vm->methodNames.data[symbol]->value); + break; + } + + case CODE_METHOD_STATIC: + { + int symbol = READ_SHORT(); + printf("%-16s %5d '%s'\n", "METHOD_STATIC", symbol, + vm->methodNames.data[symbol]->value); + break; + } + + case CODE_END_MODULE: + printf("END_MODULE\n"); + break; + + case CODE_IMPORT_MODULE: + { + int name = READ_SHORT(); + printf("%-16s %5d '", "IMPORT_MODULE", name); + wrenDumpValue(fn->constants.data[name]); + printf("'\n"); + break; + } + + case CODE_IMPORT_VARIABLE: + { + int variable = READ_SHORT(); + printf("%-16s %5d '", "IMPORT_VARIABLE", variable); + wrenDumpValue(fn->constants.data[variable]); + printf("'\n"); + break; + } + + case CODE_END: + printf("END\n"); + break; + + default: + printf("UKNOWN! [%d]\n", bytecode[i - 1]); + break; + } + + // Return how many bytes this instruction takes, or -1 if it's an END. + if (code == CODE_END) return -1; + return i - start; + + #undef READ_BYTE + #undef READ_SHORT +} + +int wrenDumpInstruction(WrenVM* vm, ObjFn* fn, int i) +{ + return dumpInstruction(vm, fn, i, NULL); +} + +void wrenDumpCode(WrenVM* vm, ObjFn* fn) +{ + printf("%s: %s\n", + fn->module->name == NULL ? "" : fn->module->name->value, + fn->debug->name); + + int i = 0; + int lastLine = -1; + for (;;) + { + int offset = dumpInstruction(vm, fn, i, &lastLine); + if (offset == -1) break; + i += offset; + } + + printf("\n"); +} + +void wrenDumpStack(ObjFiber* fiber) +{ + printf("(fiber %p) ", fiber); + for (Value* slot = fiber->stack; slot < fiber->stackTop; slot++) + { + wrenDumpValue(*slot); + printf(" | "); + } + printf("\n"); +} +// End file "wren_debug.c" +// Begin file "wren_compiler.c" +#include +#include +#include +#include + + +#if WREN_DEBUG_DUMP_COMPILED_CODE +#endif + +// This is written in bottom-up order, so the tokenization comes first, then +// parsing/code generation. This minimizes the number of explicit forward +// declarations needed. + +// The maximum number of local (i.e. not module level) variables that can be +// declared in a single function, method, or chunk of top level code. This is +// the maximum number of variables in scope at one time, and spans block scopes. +// +// Note that this limitation is also explicit in the bytecode. Since +// `CODE_LOAD_LOCAL` and `CODE_STORE_LOCAL` use a single argument byte to +// identify the local, only 256 can be in scope at one time. +#define MAX_LOCALS 256 + +// The maximum number of upvalues (i.e. variables from enclosing functions) +// that a function can close over. +#define MAX_UPVALUES 256 + +// The maximum number of distinct constants that a function can contain. This +// value is explicit in the bytecode since `CODE_CONSTANT` only takes a single +// two-byte argument. +#define MAX_CONSTANTS (1 << 16) + +// The maximum distance a CODE_JUMP or CODE_JUMP_IF instruction can move the +// instruction pointer. +#define MAX_JUMP (1 << 16) + +// The maximum depth that interpolation can nest. For example, this string has +// three levels: +// +// "outside %(one + "%(two + "%(three)")")" +#define MAX_INTERPOLATION_NESTING 8 + +// The buffer size used to format a compile error message, excluding the header +// with the module name and error location. Using a hardcoded buffer for this +// is kind of hairy, but fortunately we can control what the longest possible +// message is and handle that. Ideally, we'd use `snprintf()`, but that's not +// available in standard C++98. +#define ERROR_MESSAGE_SIZE (80 + MAX_VARIABLE_NAME + 15) + +typedef enum +{ + TOKEN_LEFT_PAREN, + TOKEN_RIGHT_PAREN, + TOKEN_LEFT_BRACKET, + TOKEN_RIGHT_BRACKET, + TOKEN_LEFT_BRACE, + TOKEN_RIGHT_BRACE, + TOKEN_COLON, + TOKEN_DOT, + TOKEN_DOTDOT, + TOKEN_DOTDOTDOT, + TOKEN_COMMA, + TOKEN_STAR, + TOKEN_SLASH, + TOKEN_PERCENT, + TOKEN_HASH, + TOKEN_PLUS, + TOKEN_MINUS, + TOKEN_LTLT, + TOKEN_GTGT, + TOKEN_PIPE, + TOKEN_PIPEPIPE, + TOKEN_CARET, + TOKEN_AMP, + TOKEN_AMPAMP, + TOKEN_BANG, + TOKEN_TILDE, + TOKEN_QUESTION, + TOKEN_EQ, + TOKEN_LT, + TOKEN_GT, + TOKEN_LTEQ, + TOKEN_GTEQ, + TOKEN_EQEQ, + TOKEN_BANGEQ, + + TOKEN_BREAK, + TOKEN_CONTINUE, + TOKEN_CLASS, + TOKEN_CONSTRUCT, + TOKEN_ELSE, + TOKEN_FALSE, + TOKEN_FOR, + TOKEN_FOREIGN, + TOKEN_IF, + TOKEN_IMPORT, + TOKEN_AS, + TOKEN_IN, + TOKEN_IS, + TOKEN_NULL, + TOKEN_RETURN, + TOKEN_STATIC, + TOKEN_SUPER, + TOKEN_THIS, + TOKEN_TRUE, + TOKEN_VAR, + TOKEN_WHILE, + + TOKEN_FIELD, + TOKEN_STATIC_FIELD, + TOKEN_NAME, + TOKEN_NUMBER, + + // A string literal without any interpolation, or the last section of a + // string following the last interpolated expression. + TOKEN_STRING, + + // A portion of a string literal preceding an interpolated expression. This + // string: + // + // "a %(b) c %(d) e" + // + // is tokenized to: + // + // TOKEN_INTERPOLATION "a " + // TOKEN_NAME b + // TOKEN_INTERPOLATION " c " + // TOKEN_NAME d + // TOKEN_STRING " e" + TOKEN_INTERPOLATION, + + TOKEN_LINE, + + TOKEN_ERROR, + TOKEN_EOF +} TokenType; + +typedef struct +{ + TokenType type; + + // The beginning of the token, pointing directly into the source. + const char* start; + + // The length of the token in characters. + int length; + + // The 1-based line where the token appears. + int line; + + // The parsed value if the token is a literal. + Value value; +} Token; + +typedef struct +{ + WrenVM* vm; + + // The module being parsed. + ObjModule* module; + + // The source code being parsed. + const char* source; + + // The beginning of the currently-being-lexed token in [source]. + const char* tokenStart; + + // The current character being lexed in [source]. + const char* currentChar; + + // The 1-based line number of [currentChar]. + int currentLine; + + // The upcoming token. + Token next; + + // The most recently lexed token. + Token current; + + // The most recently consumed/advanced token. + Token previous; + + // Tracks the lexing state when tokenizing interpolated strings. + // + // Interpolated strings make the lexer not strictly regular: we don't know + // whether a ")" should be treated as a RIGHT_PAREN token or as ending an + // interpolated expression unless we know whether we are inside a string + // interpolation and how many unmatched "(" there are. This is particularly + // complex because interpolation can nest: + // + // " %( " %( inner ) " ) " + // + // This tracks that state. The parser maintains a stack of ints, one for each + // level of current interpolation nesting. Each value is the number of + // unmatched "(" that are waiting to be closed. + int parens[MAX_INTERPOLATION_NESTING]; + int numParens; + + // Whether compile errors should be printed to stderr or discarded. + bool printErrors; + + // If a syntax or compile error has occurred. + bool hasError; +} Parser; + +typedef struct +{ + // The name of the local variable. This points directly into the original + // source code string. + const char* name; + + // The length of the local variable's name. + int length; + + // The depth in the scope chain that this variable was declared at. Zero is + // the outermost scope--parameters for a method, or the first local block in + // top level code. One is the scope within that, etc. + int depth; + + // If this local variable is being used as an upvalue. + bool isUpvalue; +} Local; + +typedef struct +{ + // True if this upvalue is capturing a local variable from the enclosing + // function. False if it's capturing an upvalue. + bool isLocal; + + // The index of the local or upvalue being captured in the enclosing function. + int index; +} CompilerUpvalue; + +// Bookkeeping information for the current loop being compiled. +typedef struct sLoop +{ + // Index of the instruction that the loop should jump back to. + int start; + + // Index of the argument for the CODE_JUMP_IF instruction used to exit the + // loop. Stored so we can patch it once we know where the loop ends. + int exitJump; + + // Index of the first instruction of the body of the loop. + int body; + + // Depth of the scope(s) that need to be exited if a break is hit inside the + // loop. + int scopeDepth; + + // The loop enclosing this one, or NULL if this is the outermost loop. + struct sLoop* enclosing; +} Loop; + +// The different signature syntaxes for different kinds of methods. +typedef enum +{ + // A name followed by a (possibly empty) parenthesized parameter list. Also + // used for binary operators. + SIG_METHOD, + + // Just a name. Also used for unary operators. + SIG_GETTER, + + // A name followed by "=". + SIG_SETTER, + + // A square bracketed parameter list. + SIG_SUBSCRIPT, + + // A square bracketed parameter list followed by "=". + SIG_SUBSCRIPT_SETTER, + + // A constructor initializer function. This has a distinct signature to + // prevent it from being invoked directly outside of the constructor on the + // metaclass. + SIG_INITIALIZER +} SignatureType; + +typedef struct +{ + const char* name; + int length; + SignatureType type; + int arity; +} Signature; + +// Bookkeeping information for compiling a class definition. +typedef struct +{ + // The name of the class. + ObjString* name; + + // Attributes for the class itself + ObjMap* classAttributes; + // Attributes for methods in this class + ObjMap* methodAttributes; + + // Symbol table for the fields of the class. + SymbolTable fields; + + // Symbols for the methods defined by the class. Used to detect duplicate + // method definitions. + IntBuffer methods; + IntBuffer staticMethods; + + // True if the class being compiled is a foreign class. + bool isForeign; + + // True if the current method being compiled is static. + bool inStatic; + + // The signature of the method being compiled. + Signature* signature; +} ClassInfo; + +struct sCompiler +{ + Parser* parser; + + // The compiler for the function enclosing this one, or NULL if it's the + // top level. + struct sCompiler* parent; + + // The currently in scope local variables. + Local locals[MAX_LOCALS]; + + // The number of local variables currently in scope. + int numLocals; + + // The upvalues that this function has captured from outer scopes. The count + // of them is stored in [numUpvalues]. + CompilerUpvalue upvalues[MAX_UPVALUES]; + + // The current level of block scope nesting, where zero is no nesting. A -1 + // here means top-level code is being compiled and there is no block scope + // in effect at all. Any variables declared will be module-level. + int scopeDepth; + + // The current number of slots (locals and temporaries) in use. + // + // We use this and maxSlots to track the maximum number of additional slots + // a function may need while executing. When the function is called, the + // fiber will check to ensure its stack has enough room to cover that worst + // case and grow the stack if needed. + // + // This value here doesn't include parameters to the function. Since those + // are already pushed onto the stack by the caller and tracked there, we + // don't need to double count them here. + int numSlots; + + // The current innermost loop being compiled, or NULL if not in a loop. + Loop* loop; + + // If this is a compiler for a method, keeps track of the class enclosing it. + ClassInfo* enclosingClass; + + // The function being compiled. + ObjFn* fn; + + // The constants for the function being compiled. + ObjMap* constants; + + // Whether or not the compiler is for a constructor initializer + bool isInitializer; + + // The number of attributes seen while parsing. + // We track this separately as compile time attributes + // are not stored, so we can't rely on attributes->count + // to enforce an error message when attributes are used + // anywhere other than methods or classes. + int numAttributes; + // Attributes for the next class or method. + ObjMap* attributes; +}; + +// Describes where a variable is declared. +typedef enum +{ + // A local variable in the current function. + SCOPE_LOCAL, + + // A local variable declared in an enclosing function. + SCOPE_UPVALUE, + + // A top-level module variable. + SCOPE_MODULE +} Scope; + +// A reference to a variable and the scope where it is defined. This contains +// enough information to emit correct code to load or store the variable. +typedef struct +{ + // The stack slot, upvalue slot, or module symbol defining the variable. + int index; + + // Where the variable is declared. + Scope scope; +} Variable; + +// Forward declarations +static void disallowAttributes(Compiler* compiler); +static void addToAttributeGroup(Compiler* compiler, Value group, Value key, Value value); +static void emitClassAttributes(Compiler* compiler, ClassInfo* classInfo); +static void copyAttributes(Compiler* compiler, ObjMap* into); +static void copyMethodAttributes(Compiler* compiler, bool isForeign, + bool isStatic, const char* fullSignature, int32_t length); + +// The stack effect of each opcode. The index in the array is the opcode, and +// the value is the stack effect of that instruction. +static const int stackEffects[] = { + #define OPCODE(_, effect) effect, +// Begin file "wren_opcodes.h" +// This defines the bytecode instructions used by the VM. It does so by invoking +// an OPCODE() macro which is expected to be defined at the point that this is +// included. (See: http://en.wikipedia.org/wiki/X_Macro for more.) +// +// The first argument is the name of the opcode. The second is its "stack +// effect" -- the amount that the op code changes the size of the stack. A +// stack effect of 1 means it pushes a value and the stack grows one larger. +// -2 means it pops two values, etc. +// +// Note that the order of instructions here affects the order of the dispatch +// table in the VM's interpreter loop. That in turn affects caching which +// affects overall performance. Take care to run benchmarks if you change the +// order here. + +// Load the constant at index [arg]. +OPCODE(CONSTANT, 1) + +// Push null onto the stack. +OPCODE(NULL, 1) + +// Push false onto the stack. +OPCODE(FALSE, 1) + +// Push true onto the stack. +OPCODE(TRUE, 1) + +// Pushes the value in the given local slot. +OPCODE(LOAD_LOCAL_0, 1) +OPCODE(LOAD_LOCAL_1, 1) +OPCODE(LOAD_LOCAL_2, 1) +OPCODE(LOAD_LOCAL_3, 1) +OPCODE(LOAD_LOCAL_4, 1) +OPCODE(LOAD_LOCAL_5, 1) +OPCODE(LOAD_LOCAL_6, 1) +OPCODE(LOAD_LOCAL_7, 1) +OPCODE(LOAD_LOCAL_8, 1) + +// Note: The compiler assumes the following _STORE instructions always +// immediately follow their corresponding _LOAD ones. + +// Pushes the value in local slot [arg]. +OPCODE(LOAD_LOCAL, 1) + +// Stores the top of stack in local slot [arg]. Does not pop it. +OPCODE(STORE_LOCAL, 0) + +// Pushes the value in upvalue [arg]. +OPCODE(LOAD_UPVALUE, 1) + +// Stores the top of stack in upvalue [arg]. Does not pop it. +OPCODE(STORE_UPVALUE, 0) + +// Pushes the value of the top-level variable in slot [arg]. +OPCODE(LOAD_MODULE_VAR, 1) + +// Stores the top of stack in top-level variable slot [arg]. Does not pop it. +OPCODE(STORE_MODULE_VAR, 0) + +// Pushes the value of the field in slot [arg] of the receiver of the current +// function. This is used for regular field accesses on "this" directly in +// methods. This instruction is faster than the more general CODE_LOAD_FIELD +// instruction. +OPCODE(LOAD_FIELD_THIS, 1) + +// Stores the top of the stack in field slot [arg] in the receiver of the +// current value. Does not pop the value. This instruction is faster than the +// more general CODE_LOAD_FIELD instruction. +OPCODE(STORE_FIELD_THIS, 0) + +// Pops an instance and pushes the value of the field in slot [arg] of it. +OPCODE(LOAD_FIELD, 0) + +// Pops an instance and stores the subsequent top of stack in field slot +// [arg] in it. Does not pop the value. +OPCODE(STORE_FIELD, -1) + +// Pop and discard the top of stack. +OPCODE(POP, -1) + +// Invoke the method with symbol [arg]. The number indicates the number of +// arguments (not including the receiver). +OPCODE(CALL_0, 0) +OPCODE(CALL_1, -1) +OPCODE(CALL_2, -2) +OPCODE(CALL_3, -3) +OPCODE(CALL_4, -4) +OPCODE(CALL_5, -5) +OPCODE(CALL_6, -6) +OPCODE(CALL_7, -7) +OPCODE(CALL_8, -8) +OPCODE(CALL_9, -9) +OPCODE(CALL_10, -10) +OPCODE(CALL_11, -11) +OPCODE(CALL_12, -12) +OPCODE(CALL_13, -13) +OPCODE(CALL_14, -14) +OPCODE(CALL_15, -15) +OPCODE(CALL_16, -16) + +// Invoke a superclass method with symbol [arg]. The number indicates the +// number of arguments (not including the receiver). +OPCODE(SUPER_0, 0) +OPCODE(SUPER_1, -1) +OPCODE(SUPER_2, -2) +OPCODE(SUPER_3, -3) +OPCODE(SUPER_4, -4) +OPCODE(SUPER_5, -5) +OPCODE(SUPER_6, -6) +OPCODE(SUPER_7, -7) +OPCODE(SUPER_8, -8) +OPCODE(SUPER_9, -9) +OPCODE(SUPER_10, -10) +OPCODE(SUPER_11, -11) +OPCODE(SUPER_12, -12) +OPCODE(SUPER_13, -13) +OPCODE(SUPER_14, -14) +OPCODE(SUPER_15, -15) +OPCODE(SUPER_16, -16) + +// Jump the instruction pointer [arg] forward. +OPCODE(JUMP, 0) + +// Jump the instruction pointer [arg] backward. +OPCODE(LOOP, 0) + +// Pop and if not truthy then jump the instruction pointer [arg] forward. +OPCODE(JUMP_IF, -1) + +// If the top of the stack is false, jump [arg] forward. Otherwise, pop and +// continue. +OPCODE(AND, -1) + +// If the top of the stack is non-false, jump [arg] forward. Otherwise, pop +// and continue. +OPCODE(OR, -1) + +// Close the upvalue for the local on the top of the stack, then pop it. +OPCODE(CLOSE_UPVALUE, -1) + +// Exit from the current function and return the value on the top of the +// stack. +OPCODE(RETURN, 0) + +// Creates a closure for the function stored at [arg] in the constant table. +// +// Following the function argument is a number of arguments, two for each +// upvalue. The first is true if the variable being captured is a local (as +// opposed to an upvalue), and the second is the index of the local or +// upvalue being captured. +// +// Pushes the created closure. +OPCODE(CLOSURE, 1) + +// Creates a new instance of a class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(CONSTRUCT, 0) + +// Creates a new instance of a foreign class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(FOREIGN_CONSTRUCT, 0) + +// Creates a class. Top of stack is the superclass. Below that is a string for +// the name of the class. Byte [arg] is the number of fields in the class. +OPCODE(CLASS, -1) + +// Ends a class. +// Atm the stack contains the class and the ClassAttributes (or null). +OPCODE(END_CLASS, -2) + +// Creates a foreign class. Top of stack is the superclass. Below that is a +// string for the name of the class. +OPCODE(FOREIGN_CLASS, -1) + +// Define a method for symbol [arg]. The class receiving the method is popped +// off the stack, then the function defining the body is popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_INSTANCE, -2) + +// Define a method for symbol [arg]. The class whose metaclass will receive +// the method is popped off the stack, then the function defining the body is +// popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_STATIC, -2) + +// This is executed at the end of the module's body. Pushes NULL onto the stack +// as the "return value" of the import statement and stores the module as the +// most recently imported one. +OPCODE(END_MODULE, 1) + +// Import a module whose name is the string stored at [arg] in the constant +// table. +// +// Pushes null onto the stack so that the fiber for the imported module can +// replace that with a dummy value when it returns. (Fibers always return a +// value when resuming a caller.) +OPCODE(IMPORT_MODULE, 1) + +// Import a variable from the most recently imported module. The name of the +// variable to import is at [arg] in the constant table. Pushes the loaded +// variable's value. +OPCODE(IMPORT_VARIABLE, 1) + +// This pseudo-instruction indicates the end of the bytecode. It should +// always be preceded by a `CODE_RETURN`, so is never actually executed. +OPCODE(END, 0) +// End file "wren_opcodes.h" + #undef OPCODE +}; + +static void printError(Parser* parser, int line, const char* label, + const char* format, va_list args) +{ + parser->hasError = true; + if (!parser->printErrors) return; + + // Only report errors if there is a WrenErrorFn to handle them. + if (parser->vm->config.errorFn == NULL) return; + + // Format the label and message. + char message[ERROR_MESSAGE_SIZE]; + int length = sprintf(message, "%s: ", label); + length += vsprintf(message + length, format, args); + ASSERT(length < ERROR_MESSAGE_SIZE, "Error should not exceed buffer."); + + ObjString* module = parser->module->name; + const char* module_name = module ? module->value : ""; + + parser->vm->config.errorFn(parser->vm, WREN_ERROR_COMPILE, + module_name, line, message); +} + +// Outputs a lexical error. +static void lexError(Parser* parser, const char* format, ...) +{ + va_list args; + va_start(args, format); + printError(parser, parser->currentLine, "Error", format, args); + va_end(args); +} + +// Outputs a compile or syntax error. This also marks the compilation as having +// an error, which ensures that the resulting code will be discarded and never +// run. This means that after calling error(), it's fine to generate whatever +// invalid bytecode you want since it won't be used. +// +// You'll note that most places that call error() continue to parse and compile +// after that. That's so that we can try to find as many compilation errors in +// one pass as possible instead of just bailing at the first one. +static void error(Compiler* compiler, const char* format, ...) +{ + Token* token = &compiler->parser->previous; + + // If the parse error was caused by an error token, the lexer has already + // reported it. + if (token->type == TOKEN_ERROR) return; + + va_list args; + va_start(args, format); + if (token->type == TOKEN_LINE) + { + printError(compiler->parser, token->line, "Error at newline", format, args); + } + else if (token->type == TOKEN_EOF) + { + printError(compiler->parser, token->line, + "Error at end of file", format, args); + } + else + { + // Make sure we don't exceed the buffer with a very long token. + char label[10 + MAX_VARIABLE_NAME + 4 + 1]; + if (token->length <= MAX_VARIABLE_NAME) + { + sprintf(label, "Error at '%.*s'", token->length, token->start); + } + else + { + sprintf(label, "Error at '%.*s...'", MAX_VARIABLE_NAME, token->start); + } + printError(compiler->parser, token->line, label, format, args); + } + va_end(args); +} + +// Adds [constant] to the constant pool and returns its index. +static int addConstant(Compiler* compiler, Value constant) +{ + if (compiler->parser->hasError) return -1; + + // See if we already have a constant for the value. If so, reuse it. + if (compiler->constants != NULL) + { + Value existing = wrenMapGet(compiler->constants, constant); + if (IS_NUM(existing)) return (int)AS_NUM(existing); + } + + // It's a new constant. + if (compiler->fn->constants.count < MAX_CONSTANTS) + { + if (IS_OBJ(constant)) wrenPushRoot(compiler->parser->vm, AS_OBJ(constant)); + wrenValueBufferWrite(compiler->parser->vm, &compiler->fn->constants, + constant); + if (IS_OBJ(constant)) wrenPopRoot(compiler->parser->vm); + + if (compiler->constants == NULL) + { + compiler->constants = wrenNewMap(compiler->parser->vm); + } + wrenMapSet(compiler->parser->vm, compiler->constants, constant, + NUM_VAL(compiler->fn->constants.count - 1)); + } + else + { + error(compiler, "A function may only contain %d unique constants.", + MAX_CONSTANTS); + } + + return compiler->fn->constants.count - 1; +} + +// Initializes [compiler]. +static void initCompiler(Compiler* compiler, Parser* parser, Compiler* parent, + bool isMethod) +{ + compiler->parser = parser; + compiler->parent = parent; + compiler->loop = NULL; + compiler->enclosingClass = NULL; + compiler->isInitializer = false; + + // Initialize these to NULL before allocating in case a GC gets triggered in + // the middle of initializing the compiler. + compiler->fn = NULL; + compiler->constants = NULL; + compiler->attributes = NULL; + + parser->vm->compiler = compiler; + + // Declare a local slot for either the closure or method receiver so that we + // don't try to reuse that slot for a user-defined local variable. For + // methods, we name it "this", so that we can resolve references to that like + // a normal variable. For functions, they have no explicit "this", so we use + // an empty name. That way references to "this" inside a function walks up + // the parent chain to find a method enclosing the function whose "this" we + // can close over. + compiler->numLocals = 1; + compiler->numSlots = compiler->numLocals; + + if (isMethod) + { + compiler->locals[0].name = "this"; + compiler->locals[0].length = 4; + } + else + { + compiler->locals[0].name = NULL; + compiler->locals[0].length = 0; + } + + compiler->locals[0].depth = -1; + compiler->locals[0].isUpvalue = false; + + if (parent == NULL) + { + // Compiling top-level code, so the initial scope is module-level. + compiler->scopeDepth = -1; + } + else + { + // The initial scope for functions and methods is local scope. + compiler->scopeDepth = 0; + } + + compiler->numAttributes = 0; + compiler->attributes = wrenNewMap(parser->vm); + compiler->fn = wrenNewFunction(parser->vm, parser->module, + compiler->numLocals); +} + +// Lexing ---------------------------------------------------------------------- + +typedef struct +{ + const char* identifier; + size_t length; + TokenType tokenType; +} Keyword; + +// The table of reserved words and their associated token types. +static Keyword keywords[] = +{ + {"break", 5, TOKEN_BREAK}, + {"continue", 8, TOKEN_CONTINUE}, + {"class", 5, TOKEN_CLASS}, + {"construct", 9, TOKEN_CONSTRUCT}, + {"else", 4, TOKEN_ELSE}, + {"false", 5, TOKEN_FALSE}, + {"for", 3, TOKEN_FOR}, + {"foreign", 7, TOKEN_FOREIGN}, + {"if", 2, TOKEN_IF}, + {"import", 6, TOKEN_IMPORT}, + {"as", 2, TOKEN_AS}, + {"in", 2, TOKEN_IN}, + {"is", 2, TOKEN_IS}, + {"null", 4, TOKEN_NULL}, + {"return", 6, TOKEN_RETURN}, + {"static", 6, TOKEN_STATIC}, + {"super", 5, TOKEN_SUPER}, + {"this", 4, TOKEN_THIS}, + {"true", 4, TOKEN_TRUE}, + {"var", 3, TOKEN_VAR}, + {"while", 5, TOKEN_WHILE}, + {NULL, 0, TOKEN_EOF} // Sentinel to mark the end of the array. +}; + +// Returns true if [c] is a valid (non-initial) identifier character. +static bool isName(char c) +{ + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'; +} + +// Returns true if [c] is a digit. +static bool isDigit(char c) +{ + return c >= '0' && c <= '9'; +} + +// Returns the current character the parser is sitting on. +static char peekChar(Parser* parser) +{ + return *parser->currentChar; +} + +// Returns the character after the current character. +static char peekNextChar(Parser* parser) +{ + // If we're at the end of the source, don't read past it. + if (peekChar(parser) == '\0') return '\0'; + return *(parser->currentChar + 1); +} + +// Advances the parser forward one character. +static char nextChar(Parser* parser) +{ + char c = peekChar(parser); + parser->currentChar++; + if (c == '\n') parser->currentLine++; + return c; +} + +// If the current character is [c], consumes it and returns `true`. +static bool matchChar(Parser* parser, char c) +{ + if (peekChar(parser) != c) return false; + nextChar(parser); + return true; +} + +// Sets the parser's current token to the given [type] and current character +// range. +static void makeToken(Parser* parser, TokenType type) +{ + parser->next.type = type; + parser->next.start = parser->tokenStart; + parser->next.length = (int)(parser->currentChar - parser->tokenStart); + parser->next.line = parser->currentLine; + + // Make line tokens appear on the line containing the "\n". + if (type == TOKEN_LINE) parser->next.line--; +} + +// If the current character is [c], then consumes it and makes a token of type +// [two]. Otherwise makes a token of type [one]. +static void twoCharToken(Parser* parser, char c, TokenType two, TokenType one) +{ + makeToken(parser, matchChar(parser, c) ? two : one); +} + +// Skips the rest of the current line. +static void skipLineComment(Parser* parser) +{ + while (peekChar(parser) != '\n' && peekChar(parser) != '\0') + { + nextChar(parser); + } +} + +// Skips the rest of a block comment. +static void skipBlockComment(Parser* parser) +{ + int nesting = 1; + while (nesting > 0) + { + if (peekChar(parser) == '\0') + { + lexError(parser, "Unterminated block comment."); + return; + } + + if (peekChar(parser) == '/' && peekNextChar(parser) == '*') + { + nextChar(parser); + nextChar(parser); + nesting++; + continue; + } + + if (peekChar(parser) == '*' && peekNextChar(parser) == '/') + { + nextChar(parser); + nextChar(parser); + nesting--; + continue; + } + + // Regular comment character. + nextChar(parser); + } +} + +// Reads the next character, which should be a hex digit (0-9, a-f, or A-F) and +// returns its numeric value. If the character isn't a hex digit, returns -1. +static int readHexDigit(Parser* parser) +{ + char c = nextChar(parser); + if (c >= '0' && c <= '9') return c - '0'; + if (c >= 'a' && c <= 'f') return c - 'a' + 10; + if (c >= 'A' && c <= 'F') return c - 'A' + 10; + + // Don't consume it if it isn't expected. Keeps us from reading past the end + // of an unterminated string. + parser->currentChar--; + return -1; +} + +// Parses the numeric value of the current token. +static void makeNumber(Parser* parser, bool isHex) +{ + errno = 0; + + if (isHex) + { + parser->next.value = NUM_VAL((double)strtoll(parser->tokenStart, NULL, 16)); + } + else + { + parser->next.value = NUM_VAL(strtod(parser->tokenStart, NULL)); + } + + if (errno == ERANGE) + { + lexError(parser, "Number literal was too large (%d).", sizeof(long int)); + parser->next.value = NUM_VAL(0); + } + + // We don't check that the entire token is consumed after calling strtoll() + // or strtod() because we've already scanned it ourselves and know it's valid. + + makeToken(parser, TOKEN_NUMBER); +} + +// Finishes lexing a hexadecimal number literal. +static void readHexNumber(Parser* parser) +{ + // Skip past the `x` used to denote a hexadecimal literal. + nextChar(parser); + + // Iterate over all the valid hexadecimal digits found. + while (readHexDigit(parser) != -1) continue; + + makeNumber(parser, true); +} + +// Finishes lexing a number literal. +static void readNumber(Parser* parser) +{ + while (isDigit(peekChar(parser))) nextChar(parser); + + // See if it has a floating point. Make sure there is a digit after the "." + // so we don't get confused by method calls on number literals. + if (peekChar(parser) == '.' && isDigit(peekNextChar(parser))) + { + nextChar(parser); + while (isDigit(peekChar(parser))) nextChar(parser); + } + + // See if the number is in scientific notation. + if (matchChar(parser, 'e') || matchChar(parser, 'E')) + { + // Allow a single positive/negative exponent symbol. + if(!matchChar(parser, '+')) + { + matchChar(parser, '-'); + } + + if (!isDigit(peekChar(parser))) + { + lexError(parser, "Unterminated scientific notation."); + } + + while (isDigit(peekChar(parser))) nextChar(parser); + } + + makeNumber(parser, false); +} + +// Finishes lexing an identifier. Handles reserved words. +static void readName(Parser* parser, TokenType type, char firstChar) +{ + ByteBuffer string; + wrenByteBufferInit(&string); + wrenByteBufferWrite(parser->vm, &string, firstChar); + + while (isName(peekChar(parser)) || isDigit(peekChar(parser))) + { + char c = nextChar(parser); + wrenByteBufferWrite(parser->vm, &string, c); + } + + // Update the type if it's a keyword. + size_t length = parser->currentChar - parser->tokenStart; + for (int i = 0; keywords[i].identifier != NULL; i++) + { + if (length == keywords[i].length && + memcmp(parser->tokenStart, keywords[i].identifier, length) == 0) + { + type = keywords[i].tokenType; + break; + } + } + + parser->next.value = wrenNewStringLength(parser->vm, + (char*)string.data, string.count); + + wrenByteBufferClear(parser->vm, &string); + makeToken(parser, type); +} + +// Reads [digits] hex digits in a string literal and returns their number value. +static int readHexEscape(Parser* parser, int digits, const char* description) +{ + int value = 0; + for (int i = 0; i < digits; i++) + { + if (peekChar(parser) == '"' || peekChar(parser) == '\0') + { + lexError(parser, "Incomplete %s escape sequence.", description); + + // Don't consume it if it isn't expected. Keeps us from reading past the + // end of an unterminated string. + parser->currentChar--; + break; + } + + int digit = readHexDigit(parser); + if (digit == -1) + { + lexError(parser, "Invalid %s escape sequence.", description); + break; + } + + value = (value * 16) | digit; + } + + return value; +} + +// Reads a hex digit Unicode escape sequence in a string literal. +static void readUnicodeEscape(Parser* parser, ByteBuffer* string, int length) +{ + int value = readHexEscape(parser, length, "Unicode"); + + // Grow the buffer enough for the encoded result. + int numBytes = wrenUtf8EncodeNumBytes(value); + if (numBytes != 0) + { + wrenByteBufferFill(parser->vm, string, 0, numBytes); + wrenUtf8Encode(value, string->data + string->count - numBytes); + } +} + +static void readRawString(Parser* parser) +{ + ByteBuffer string; + wrenByteBufferInit(&string); + TokenType type = TOKEN_STRING; + + //consume the second and third " + nextChar(parser); + nextChar(parser); + + int skipStart = 0; + int firstNewline = -1; + + int skipEnd = -1; + int lastNewline = -1; + + for (;;) + { + char c = nextChar(parser); + char c1 = peekChar(parser); + char c2 = peekNextChar(parser); + + if (c == '\r') continue; + + if (c == '\n') { + lastNewline = string.count; + skipEnd = lastNewline; + firstNewline = firstNewline == -1 ? string.count : firstNewline; + } + + if (c == '"' && c1 == '"' && c2 == '"') break; + + bool isWhitespace = c == ' ' || c == '\t'; + skipEnd = c == '\n' || isWhitespace ? skipEnd : -1; + + // If we haven't seen a newline or other character yet, + // and still seeing whitespace, count the characters + // as skippable till we know otherwise + bool skippable = skipStart != -1 && isWhitespace && firstNewline == -1; + skipStart = skippable ? string.count + 1 : skipStart; + + // We've counted leading whitespace till we hit something else, + // but it's not a newline, so we reset skipStart since we need these characters + if (firstNewline == -1 && !isWhitespace && c != '\n') skipStart = -1; + + if (c == '\0' || c1 == '\0' || c2 == '\0') + { + lexError(parser, "Unterminated raw string."); + + // Don't consume it if it isn't expected. Keeps us from reading past the + // end of an unterminated string. + parser->currentChar--; + break; + } + + wrenByteBufferWrite(parser->vm, &string, c); + } + + //consume the second and third " + nextChar(parser); + nextChar(parser); + + int offset = 0; + int count = string.count; + + if(firstNewline != -1 && skipStart == firstNewline) offset = firstNewline + 1; + if(lastNewline != -1 && skipEnd == lastNewline) count = lastNewline; + + count -= (offset > count) ? count : offset; + + parser->next.value = wrenNewStringLength(parser->vm, + ((char*)string.data) + offset, count); + + wrenByteBufferClear(parser->vm, &string); + makeToken(parser, type); +} + +// Finishes lexing a string literal. +static void readString(Parser* parser) +{ + ByteBuffer string; + TokenType type = TOKEN_STRING; + wrenByteBufferInit(&string); + + for (;;) + { + char c = nextChar(parser); + if (c == '"') break; + if (c == '\r') continue; + + if (c == '\0') + { + lexError(parser, "Unterminated string."); + + // Don't consume it if it isn't expected. Keeps us from reading past the + // end of an unterminated string. + parser->currentChar--; + break; + } + + if (c == '%') + { + if (parser->numParens < MAX_INTERPOLATION_NESTING) + { + // TODO: Allow format string. + if (nextChar(parser) != '(') lexError(parser, "Expect '(' after '%%'."); + + parser->parens[parser->numParens++] = 1; + type = TOKEN_INTERPOLATION; + break; + } + + lexError(parser, "Interpolation may only nest %d levels deep.", + MAX_INTERPOLATION_NESTING); + } + + if (c == '\\') + { + switch (nextChar(parser)) + { + case '"': wrenByteBufferWrite(parser->vm, &string, '"'); break; + case '\\': wrenByteBufferWrite(parser->vm, &string, '\\'); break; + case '%': wrenByteBufferWrite(parser->vm, &string, '%'); break; + case '0': wrenByteBufferWrite(parser->vm, &string, '\0'); break; + case 'a': wrenByteBufferWrite(parser->vm, &string, '\a'); break; + case 'b': wrenByteBufferWrite(parser->vm, &string, '\b'); break; + case 'e': wrenByteBufferWrite(parser->vm, &string, '\33'); break; + case 'f': wrenByteBufferWrite(parser->vm, &string, '\f'); break; + case 'n': wrenByteBufferWrite(parser->vm, &string, '\n'); break; + case 'r': wrenByteBufferWrite(parser->vm, &string, '\r'); break; + case 't': wrenByteBufferWrite(parser->vm, &string, '\t'); break; + case 'u': readUnicodeEscape(parser, &string, 4); break; + case 'U': readUnicodeEscape(parser, &string, 8); break; + case 'v': wrenByteBufferWrite(parser->vm, &string, '\v'); break; + case 'x': + wrenByteBufferWrite(parser->vm, &string, + (uint8_t)readHexEscape(parser, 2, "byte")); + break; + + default: + lexError(parser, "Invalid escape character '%c'.", + *(parser->currentChar - 1)); + break; + } + } + else + { + wrenByteBufferWrite(parser->vm, &string, c); + } + } + + parser->next.value = wrenNewStringLength(parser->vm, + (char*)string.data, string.count); + + wrenByteBufferClear(parser->vm, &string); + makeToken(parser, type); +} + +// Lex the next token and store it in [parser.next]. +static void nextToken(Parser* parser) +{ + parser->previous = parser->current; + parser->current = parser->next; + + // If we are out of tokens, don't try to tokenize any more. We *do* still + // copy the TOKEN_EOF to previous so that code that expects it to be consumed + // will still work. + if (parser->next.type == TOKEN_EOF) return; + if (parser->current.type == TOKEN_EOF) return; + + while (peekChar(parser) != '\0') + { + parser->tokenStart = parser->currentChar; + + char c = nextChar(parser); + switch (c) + { + case '(': + // If we are inside an interpolated expression, count the unmatched "(". + if (parser->numParens > 0) parser->parens[parser->numParens - 1]++; + makeToken(parser, TOKEN_LEFT_PAREN); + return; + + case ')': + // If we are inside an interpolated expression, count the ")". + if (parser->numParens > 0 && + --parser->parens[parser->numParens - 1] == 0) + { + // This is the final ")", so the interpolation expression has ended. + // This ")" now begins the next section of the template string. + parser->numParens--; + readString(parser); + return; + } + + makeToken(parser, TOKEN_RIGHT_PAREN); + return; + + case '[': makeToken(parser, TOKEN_LEFT_BRACKET); return; + case ']': makeToken(parser, TOKEN_RIGHT_BRACKET); return; + case '{': makeToken(parser, TOKEN_LEFT_BRACE); return; + case '}': makeToken(parser, TOKEN_RIGHT_BRACE); return; + case ':': makeToken(parser, TOKEN_COLON); return; + case ',': makeToken(parser, TOKEN_COMMA); return; + case '*': makeToken(parser, TOKEN_STAR); return; + case '%': makeToken(parser, TOKEN_PERCENT); return; + case '#': { + // Ignore shebang on the first line. + if (parser->currentLine == 1 && peekChar(parser) == '!' && peekNextChar(parser) == '/') + { + skipLineComment(parser); + break; + } + // Otherwise we treat it as a token + makeToken(parser, TOKEN_HASH); + return; + } + case '^': makeToken(parser, TOKEN_CARET); return; + case '+': makeToken(parser, TOKEN_PLUS); return; + case '-': makeToken(parser, TOKEN_MINUS); return; + case '~': makeToken(parser, TOKEN_TILDE); return; + case '?': makeToken(parser, TOKEN_QUESTION); return; + + case '|': twoCharToken(parser, '|', TOKEN_PIPEPIPE, TOKEN_PIPE); return; + case '&': twoCharToken(parser, '&', TOKEN_AMPAMP, TOKEN_AMP); return; + case '=': twoCharToken(parser, '=', TOKEN_EQEQ, TOKEN_EQ); return; + case '!': twoCharToken(parser, '=', TOKEN_BANGEQ, TOKEN_BANG); return; + + case '.': + if (matchChar(parser, '.')) + { + twoCharToken(parser, '.', TOKEN_DOTDOTDOT, TOKEN_DOTDOT); + return; + } + + makeToken(parser, TOKEN_DOT); + return; + + case '/': + if (matchChar(parser, '/')) + { + skipLineComment(parser); + break; + } + + if (matchChar(parser, '*')) + { + skipBlockComment(parser); + break; + } + + makeToken(parser, TOKEN_SLASH); + return; + + case '<': + if (matchChar(parser, '<')) + { + makeToken(parser, TOKEN_LTLT); + } + else + { + twoCharToken(parser, '=', TOKEN_LTEQ, TOKEN_LT); + } + return; + + case '>': + if (matchChar(parser, '>')) + { + makeToken(parser, TOKEN_GTGT); + } + else + { + twoCharToken(parser, '=', TOKEN_GTEQ, TOKEN_GT); + } + return; + + case '\n': + makeToken(parser, TOKEN_LINE); + return; + + case ' ': + case '\r': + case '\t': + // Skip forward until we run out of whitespace. + while (peekChar(parser) == ' ' || + peekChar(parser) == '\r' || + peekChar(parser) == '\t') + { + nextChar(parser); + } + break; + + case '"': { + if(peekChar(parser) == '"' && peekNextChar(parser) == '"') { + readRawString(parser); + return; + } + readString(parser); return; + } + case '_': + readName(parser, + peekChar(parser) == '_' ? TOKEN_STATIC_FIELD : TOKEN_FIELD, c); + return; + + case '0': + if (peekChar(parser) == 'x') + { + readHexNumber(parser); + return; + } + + readNumber(parser); + return; + + default: + if (isName(c)) + { + readName(parser, TOKEN_NAME, c); + } + else if (isDigit(c)) + { + readNumber(parser); + } + else + { + if (c >= 32 && c <= 126) + { + lexError(parser, "Invalid character '%c'.", c); + } + else + { + // Don't show non-ASCII values since we didn't UTF-8 decode the + // bytes. Since there are no non-ASCII byte values that are + // meaningful code units in Wren, the lexer works on raw bytes, + // even though the source code and console output are UTF-8. + lexError(parser, "Invalid byte 0x%x.", (uint8_t)c); + } + parser->next.type = TOKEN_ERROR; + parser->next.length = 0; + } + return; + } + } + + // If we get here, we're out of source, so just make EOF tokens. + parser->tokenStart = parser->currentChar; + makeToken(parser, TOKEN_EOF); +} + +// Parsing --------------------------------------------------------------------- + +// Returns the type of the current token. +static TokenType peek(Compiler* compiler) +{ + return compiler->parser->current.type; +} + +// Returns the type of the current token. +static TokenType peekNext(Compiler* compiler) +{ + return compiler->parser->next.type; +} + +// Consumes the current token if its type is [expected]. Returns true if a +// token was consumed. +static bool match(Compiler* compiler, TokenType expected) +{ + if (peek(compiler) != expected) return false; + + nextToken(compiler->parser); + return true; +} + +// Consumes the current token. Emits an error if its type is not [expected]. +static void consume(Compiler* compiler, TokenType expected, + const char* errorMessage) +{ + nextToken(compiler->parser); + if (compiler->parser->previous.type != expected) + { + error(compiler, errorMessage); + + // If the next token is the one we want, assume the current one is just a + // spurious error and discard it to minimize the number of cascaded errors. + if (compiler->parser->current.type == expected) nextToken(compiler->parser); + } +} + +// Matches one or more newlines. Returns true if at least one was found. +static bool matchLine(Compiler* compiler) +{ + if (!match(compiler, TOKEN_LINE)) return false; + + while (match(compiler, TOKEN_LINE)); + return true; +} + +// Discards any newlines starting at the current token. +static void ignoreNewlines(Compiler* compiler) +{ + matchLine(compiler); +} + +// Consumes the current token. Emits an error if it is not a newline. Then +// discards any duplicate newlines following it. +static void consumeLine(Compiler* compiler, const char* errorMessage) +{ + consume(compiler, TOKEN_LINE, errorMessage); + ignoreNewlines(compiler); +} + +static void allowLineBeforeDot(Compiler* compiler) { + if (peek(compiler) == TOKEN_LINE && peekNext(compiler) == TOKEN_DOT) { + nextToken(compiler->parser); + } +} + +// Variables and scopes -------------------------------------------------------- + +// Emits one single-byte argument. Returns its index. +static int emitByte(Compiler* compiler, int byte) +{ + wrenByteBufferWrite(compiler->parser->vm, &compiler->fn->code, (uint8_t)byte); + + // Assume the instruction is associated with the most recently consumed token. + wrenIntBufferWrite(compiler->parser->vm, &compiler->fn->debug->sourceLines, + compiler->parser->previous.line); + + return compiler->fn->code.count - 1; +} + +// Emits one bytecode instruction. +static void emitOp(Compiler* compiler, Code instruction) +{ + emitByte(compiler, instruction); + + // Keep track of the stack's high water mark. + compiler->numSlots += stackEffects[instruction]; + if (compiler->numSlots > compiler->fn->maxSlots) + { + compiler->fn->maxSlots = compiler->numSlots; + } +} + +// Emits one 16-bit argument, which will be written big endian. +static void emitShort(Compiler* compiler, int arg) +{ + emitByte(compiler, (arg >> 8) & 0xff); + emitByte(compiler, arg & 0xff); +} + +// Emits one bytecode instruction followed by a 8-bit argument. Returns the +// index of the argument in the bytecode. +static int emitByteArg(Compiler* compiler, Code instruction, int arg) +{ + emitOp(compiler, instruction); + return emitByte(compiler, arg); +} + +// Emits one bytecode instruction followed by a 16-bit argument, which will be +// written big endian. +static void emitShortArg(Compiler* compiler, Code instruction, int arg) +{ + emitOp(compiler, instruction); + emitShort(compiler, arg); +} + +// Emits [instruction] followed by a placeholder for a jump offset. The +// placeholder can be patched by calling [jumpPatch]. Returns the index of the +// placeholder. +static int emitJump(Compiler* compiler, Code instruction) +{ + emitOp(compiler, instruction); + emitByte(compiler, 0xff); + return emitByte(compiler, 0xff) - 1; +} + +// Creates a new constant for the current value and emits the bytecode to load +// it from the constant table. +static void emitConstant(Compiler* compiler, Value value) +{ + int constant = addConstant(compiler, value); + + // Compile the code to load the constant. + emitShortArg(compiler, CODE_CONSTANT, constant); +} + +// Create a new local variable with [name]. Assumes the current scope is local +// and the name is unique. +static int addLocal(Compiler* compiler, const char* name, int length) +{ + Local* local = &compiler->locals[compiler->numLocals]; + local->name = name; + local->length = length; + local->depth = compiler->scopeDepth; + local->isUpvalue = false; + return compiler->numLocals++; +} + +// Declares a variable in the current scope whose name is the given token. +// +// If [token] is `NULL`, uses the previously consumed token. Returns its symbol. +static int declareVariable(Compiler* compiler, Token* token) +{ + if (token == NULL) token = &compiler->parser->previous; + + if (token->length > MAX_VARIABLE_NAME) + { + error(compiler, "Variable name cannot be longer than %d characters.", + MAX_VARIABLE_NAME); + } + + // Top-level module scope. + if (compiler->scopeDepth == -1) + { + int line = -1; + int symbol = wrenDefineVariable(compiler->parser->vm, + compiler->parser->module, + token->start, token->length, + NULL_VAL, &line); + + if (symbol == -1) + { + error(compiler, "Module variable is already defined."); + } + else if (symbol == -2) + { + error(compiler, "Too many module variables defined."); + } + else if (symbol == -3) + { + error(compiler, + "Variable '%.*s' referenced before this definition (first use at line %d).", + token->length, token->start, line); + } + + return symbol; + } + + // See if there is already a variable with this name declared in the current + // scope. (Outer scopes are OK: those get shadowed.) + for (int i = compiler->numLocals - 1; i >= 0; i--) + { + Local* local = &compiler->locals[i]; + + // Once we escape this scope and hit an outer one, we can stop. + if (local->depth < compiler->scopeDepth) break; + + if (local->length == token->length && + memcmp(local->name, token->start, token->length) == 0) + { + error(compiler, "Variable is already declared in this scope."); + return i; + } + } + + if (compiler->numLocals == MAX_LOCALS) + { + error(compiler, "Cannot declare more than %d variables in one scope.", + MAX_LOCALS); + return -1; + } + + return addLocal(compiler, token->start, token->length); +} + +// Parses a name token and declares a variable in the current scope with that +// name. Returns its slot. +static int declareNamedVariable(Compiler* compiler) +{ + consume(compiler, TOKEN_NAME, "Expect variable name."); + return declareVariable(compiler, NULL); +} + +// Stores a variable with the previously defined symbol in the current scope. +static void defineVariable(Compiler* compiler, int symbol) +{ + // Store the variable. If it's a local, the result of the initializer is + // in the correct slot on the stack already so we're done. + if (compiler->scopeDepth >= 0) return; + + // It's a module-level variable, so store the value in the module slot and + // then discard the temporary for the initializer. + emitShortArg(compiler, CODE_STORE_MODULE_VAR, symbol); + emitOp(compiler, CODE_POP); +} + +// Starts a new local block scope. +static void pushScope(Compiler* compiler) +{ + compiler->scopeDepth++; +} + +// Generates code to discard local variables at [depth] or greater. Does *not* +// actually undeclare variables or pop any scopes, though. This is called +// directly when compiling "break" statements to ditch the local variables +// before jumping out of the loop even though they are still in scope *past* +// the break instruction. +// +// Returns the number of local variables that were eliminated. +static int discardLocals(Compiler* compiler, int depth) +{ + ASSERT(compiler->scopeDepth > -1, "Cannot exit top-level scope."); + + int local = compiler->numLocals - 1; + while (local >= 0 && compiler->locals[local].depth >= depth) + { + // If the local was closed over, make sure the upvalue gets closed when it + // goes out of scope on the stack. We use emitByte() and not emitOp() here + // because we don't want to track that stack effect of these pops since the + // variables are still in scope after the break. + if (compiler->locals[local].isUpvalue) + { + emitByte(compiler, CODE_CLOSE_UPVALUE); + } + else + { + emitByte(compiler, CODE_POP); + } + + + local--; + } + + return compiler->numLocals - local - 1; +} + +// Closes the last pushed block scope and discards any local variables declared +// in that scope. This should only be called in a statement context where no +// temporaries are still on the stack. +static void popScope(Compiler* compiler) +{ + int popped = discardLocals(compiler, compiler->scopeDepth); + compiler->numLocals -= popped; + compiler->numSlots -= popped; + compiler->scopeDepth--; +} + +// Attempts to look up the name in the local variables of [compiler]. If found, +// returns its index, otherwise returns -1. +static int resolveLocal(Compiler* compiler, const char* name, int length) +{ + // Look it up in the local scopes. Look in reverse order so that the most + // nested variable is found first and shadows outer ones. + for (int i = compiler->numLocals - 1; i >= 0; i--) + { + if (compiler->locals[i].length == length && + memcmp(name, compiler->locals[i].name, length) == 0) + { + return i; + } + } + + return -1; +} + +// Adds an upvalue to [compiler]'s function with the given properties. Does not +// add one if an upvalue for that variable is already in the list. Returns the +// index of the upvalue. +static int addUpvalue(Compiler* compiler, bool isLocal, int index) +{ + // Look for an existing one. + for (int i = 0; i < compiler->fn->numUpvalues; i++) + { + CompilerUpvalue* upvalue = &compiler->upvalues[i]; + if (upvalue->index == index && upvalue->isLocal == isLocal) return i; + } + + // If we got here, it's a new upvalue. + compiler->upvalues[compiler->fn->numUpvalues].isLocal = isLocal; + compiler->upvalues[compiler->fn->numUpvalues].index = index; + return compiler->fn->numUpvalues++; +} + +// Attempts to look up [name] in the functions enclosing the one being compiled +// by [compiler]. If found, it adds an upvalue for it to this compiler's list +// of upvalues (unless it's already in there) and returns its index. If not +// found, returns -1. +// +// If the name is found outside of the immediately enclosing function, this +// will flatten the closure and add upvalues to all of the intermediate +// functions so that it gets walked down to this one. +// +// If it reaches a method boundary, this stops and returns -1 since methods do +// not close over local variables. +static int findUpvalue(Compiler* compiler, const char* name, int length) +{ + // If we are at the top level, we didn't find it. + if (compiler->parent == NULL) return -1; + + // If we hit the method boundary (and the name isn't a static field), then + // stop looking for it. We'll instead treat it as a self send. + if (name[0] != '_' && compiler->parent->enclosingClass != NULL) return -1; + + // See if it's a local variable in the immediately enclosing function. + int local = resolveLocal(compiler->parent, name, length); + if (local != -1) + { + // Mark the local as an upvalue so we know to close it when it goes out of + // scope. + compiler->parent->locals[local].isUpvalue = true; + + return addUpvalue(compiler, true, local); + } + + // See if it's an upvalue in the immediately enclosing function. In other + // words, if it's a local variable in a non-immediately enclosing function. + // This "flattens" closures automatically: it adds upvalues to all of the + // intermediate functions to get from the function where a local is declared + // all the way into the possibly deeply nested function that is closing over + // it. + int upvalue = findUpvalue(compiler->parent, name, length); + if (upvalue != -1) + { + return addUpvalue(compiler, false, upvalue); + } + + // If we got here, we walked all the way up the parent chain and couldn't + // find it. + return -1; +} + +// Look up [name] in the current scope to see what variable it refers to. +// Returns the variable either in local scope, or the enclosing function's +// upvalue list. Does not search the module scope. Returns a variable with +// index -1 if not found. +static Variable resolveNonmodule(Compiler* compiler, + const char* name, int length) +{ + // Look it up in the local scopes. + Variable variable; + variable.scope = SCOPE_LOCAL; + variable.index = resolveLocal(compiler, name, length); + if (variable.index != -1) return variable; + + // Tt's not a local, so guess that it's an upvalue. + variable.scope = SCOPE_UPVALUE; + variable.index = findUpvalue(compiler, name, length); + return variable; +} + +// Look up [name] in the current scope to see what variable it refers to. +// Returns the variable either in module scope, local scope, or the enclosing +// function's upvalue list. Returns a variable with index -1 if not found. +static Variable resolveName(Compiler* compiler, const char* name, int length) +{ + Variable variable = resolveNonmodule(compiler, name, length); + if (variable.index != -1) return variable; + + variable.scope = SCOPE_MODULE; + variable.index = wrenSymbolTableFind(&compiler->parser->module->variableNames, + name, length); + return variable; +} + +static void loadLocal(Compiler* compiler, int slot) +{ + if (slot <= 8) + { + emitOp(compiler, (Code)(CODE_LOAD_LOCAL_0 + slot)); + return; + } + + emitByteArg(compiler, CODE_LOAD_LOCAL, slot); +} + +// Finishes [compiler], which is compiling a function, method, or chunk of top +// level code. If there is a parent compiler, then this emits code in the +// parent compiler to load the resulting function. +static ObjFn* endCompiler(Compiler* compiler, + const char* debugName, int debugNameLength) +{ + // If we hit an error, don't finish the function since it's borked anyway. + if (compiler->parser->hasError) + { + compiler->parser->vm->compiler = compiler->parent; + return NULL; + } + + // Mark the end of the bytecode. Since it may contain multiple early returns, + // we can't rely on CODE_RETURN to tell us we're at the end. + emitOp(compiler, CODE_END); + + wrenFunctionBindName(compiler->parser->vm, compiler->fn, + debugName, debugNameLength); + + // In the function that contains this one, load the resulting function object. + if (compiler->parent != NULL) + { + int constant = addConstant(compiler->parent, OBJ_VAL(compiler->fn)); + + // Wrap the function in a closure. We do this even if it has no upvalues so + // that the VM can uniformly assume all called objects are closures. This + // makes creating a function a little slower, but makes invoking them + // faster. Given that functions are invoked more often than they are + // created, this is a win. + emitShortArg(compiler->parent, CODE_CLOSURE, constant); + + // Emit arguments for each upvalue to know whether to capture a local or + // an upvalue. + for (int i = 0; i < compiler->fn->numUpvalues; i++) + { + emitByte(compiler->parent, compiler->upvalues[i].isLocal ? 1 : 0); + emitByte(compiler->parent, compiler->upvalues[i].index); + } + } + + // Pop this compiler off the stack. + compiler->parser->vm->compiler = compiler->parent; + + #if WREN_DEBUG_DUMP_COMPILED_CODE + wrenDumpCode(compiler->parser->vm, compiler->fn); + #endif + + return compiler->fn; +} + +// Grammar --------------------------------------------------------------------- + +typedef enum +{ + PREC_NONE, + PREC_LOWEST, + PREC_ASSIGNMENT, // = + PREC_CONDITIONAL, // ?: + PREC_LOGICAL_OR, // || + PREC_LOGICAL_AND, // && + PREC_EQUALITY, // == != + PREC_IS, // is + PREC_COMPARISON, // < > <= >= + PREC_BITWISE_OR, // | + PREC_BITWISE_XOR, // ^ + PREC_BITWISE_AND, // & + PREC_BITWISE_SHIFT, // << >> + PREC_RANGE, // .. ... + PREC_TERM, // + - + PREC_FACTOR, // * / % + PREC_UNARY, // unary - ! ~ + PREC_CALL, // . () [] + PREC_PRIMARY +} Precedence; + +typedef void (*GrammarFn)(Compiler*, bool canAssign); + +typedef void (*SignatureFn)(Compiler* compiler, Signature* signature); + +typedef struct +{ + GrammarFn prefix; + GrammarFn infix; + SignatureFn method; + Precedence precedence; + const char* name; +} GrammarRule; + +// Forward declarations since the grammar is recursive. +static GrammarRule* getRule(TokenType type); +static void expression(Compiler* compiler); +static void statement(Compiler* compiler); +static void definition(Compiler* compiler); +static void parsePrecedence(Compiler* compiler, Precedence precedence); + +// Replaces the placeholder argument for a previous CODE_JUMP or CODE_JUMP_IF +// instruction with an offset that jumps to the current end of bytecode. +static void patchJump(Compiler* compiler, int offset) +{ + // -2 to adjust for the bytecode for the jump offset itself. + int jump = compiler->fn->code.count - offset - 2; + if (jump > MAX_JUMP) error(compiler, "Too much code to jump over."); + + compiler->fn->code.data[offset] = (jump >> 8) & 0xff; + compiler->fn->code.data[offset + 1] = jump & 0xff; +} + +// Parses a block body, after the initial "{" has been consumed. +// +// Returns true if it was a expression body, false if it was a statement body. +// (More precisely, returns true if a value was left on the stack. An empty +// block returns false.) +static bool finishBlock(Compiler* compiler) +{ + // Empty blocks do nothing. + if (match(compiler, TOKEN_RIGHT_BRACE)) return false; + + // If there's no line after the "{", it's a single-expression body. + if (!matchLine(compiler)) + { + expression(compiler); + consume(compiler, TOKEN_RIGHT_BRACE, "Expect '}' at end of block."); + return true; + } + + // Empty blocks (with just a newline inside) do nothing. + if (match(compiler, TOKEN_RIGHT_BRACE)) return false; + + // Compile the definition list. + do + { + definition(compiler); + consumeLine(compiler, "Expect newline after statement."); + } + while (peek(compiler) != TOKEN_RIGHT_BRACE && peek(compiler) != TOKEN_EOF); + + consume(compiler, TOKEN_RIGHT_BRACE, "Expect '}' at end of block."); + return false; +} + +// Parses a method or function body, after the initial "{" has been consumed. +// +// If [Compiler->isInitializer] is `true`, this is the body of a constructor +// initializer. In that case, this adds the code to ensure it returns `this`. +static void finishBody(Compiler* compiler) +{ + bool isExpressionBody = finishBlock(compiler); + + if (compiler->isInitializer) + { + // If the initializer body evaluates to a value, discard it. + if (isExpressionBody) emitOp(compiler, CODE_POP); + + // The receiver is always stored in the first local slot. + emitOp(compiler, CODE_LOAD_LOCAL_0); + } + else if (!isExpressionBody) + { + // Implicitly return null in statement bodies. + emitOp(compiler, CODE_NULL); + } + + emitOp(compiler, CODE_RETURN); +} + +// The VM can only handle a certain number of parameters, so check that we +// haven't exceeded that and give a usable error. +static void validateNumParameters(Compiler* compiler, int numArgs) +{ + if (numArgs == MAX_PARAMETERS + 1) + { + // Only show an error at exactly max + 1 so that we can keep parsing the + // parameters and minimize cascaded errors. + error(compiler, "Methods cannot have more than %d parameters.", + MAX_PARAMETERS); + } +} + +// Parses the rest of a comma-separated parameter list after the opening +// delimeter. Updates `arity` in [signature] with the number of parameters. +static void finishParameterList(Compiler* compiler, Signature* signature) +{ + do + { + ignoreNewlines(compiler); + validateNumParameters(compiler, ++signature->arity); + + // Define a local variable in the method for the parameter. + declareNamedVariable(compiler); + } + while (match(compiler, TOKEN_COMMA)); +} + +// Gets the symbol for a method [name] with [length]. +static int methodSymbol(Compiler* compiler, const char* name, int length) +{ + return wrenSymbolTableEnsure(compiler->parser->vm, + &compiler->parser->vm->methodNames, name, length); +} + +// Appends characters to [name] (and updates [length]) for [numParams] "_" +// surrounded by [leftBracket] and [rightBracket]. +static void signatureParameterList(char name[MAX_METHOD_SIGNATURE], int* length, + int numParams, char leftBracket, char rightBracket) +{ + name[(*length)++] = leftBracket; + + // This function may be called with too many parameters. When that happens, + // a compile error has already been reported, but we need to make sure we + // don't overflow the string too, hence the MAX_PARAMETERS check. + for (int i = 0; i < numParams && i < MAX_PARAMETERS; i++) + { + if (i > 0) name[(*length)++] = ','; + name[(*length)++] = '_'; + } + name[(*length)++] = rightBracket; +} + +// Fills [name] with the stringified version of [signature] and updates +// [length] to the resulting length. +static void signatureToString(Signature* signature, + char name[MAX_METHOD_SIGNATURE], int* length) +{ + *length = 0; + + // Build the full name from the signature. + memcpy(name + *length, signature->name, signature->length); + *length += signature->length; + + switch (signature->type) + { + case SIG_METHOD: + signatureParameterList(name, length, signature->arity, '(', ')'); + break; + + case SIG_GETTER: + // The signature is just the name. + break; + + case SIG_SETTER: + name[(*length)++] = '='; + signatureParameterList(name, length, 1, '(', ')'); + break; + + case SIG_SUBSCRIPT: + signatureParameterList(name, length, signature->arity, '[', ']'); + break; + + case SIG_SUBSCRIPT_SETTER: + signatureParameterList(name, length, signature->arity - 1, '[', ']'); + name[(*length)++] = '='; + signatureParameterList(name, length, 1, '(', ')'); + break; + + case SIG_INITIALIZER: + memcpy(name, "init ", 5); + memcpy(name + 5, signature->name, signature->length); + *length = 5 + signature->length; + signatureParameterList(name, length, signature->arity, '(', ')'); + break; + } + + name[*length] = '\0'; +} + +// Gets the symbol for a method with [signature]. +static int signatureSymbol(Compiler* compiler, Signature* signature) +{ + // Build the full name from the signature. + char name[MAX_METHOD_SIGNATURE]; + int length; + signatureToString(signature, name, &length); + + return methodSymbol(compiler, name, length); +} + +// Returns a signature with [type] whose name is from the last consumed token. +static Signature signatureFromToken(Compiler* compiler, SignatureType type) +{ + Signature signature; + + // Get the token for the method name. + Token* token = &compiler->parser->previous; + signature.name = token->start; + signature.length = token->length; + signature.type = type; + signature.arity = 0; + + if (signature.length > MAX_METHOD_NAME) + { + error(compiler, "Method names cannot be longer than %d characters.", + MAX_METHOD_NAME); + signature.length = MAX_METHOD_NAME; + } + + return signature; +} + +// Parses a comma-separated list of arguments. Modifies [signature] to include +// the arity of the argument list. +static void finishArgumentList(Compiler* compiler, Signature* signature) +{ + do + { + ignoreNewlines(compiler); + validateNumParameters(compiler, ++signature->arity); + expression(compiler); + } + while (match(compiler, TOKEN_COMMA)); + + // Allow a newline before the closing delimiter. + ignoreNewlines(compiler); +} + +// Compiles a method call with [signature] using [instruction]. +static void callSignature(Compiler* compiler, Code instruction, + Signature* signature) +{ + int symbol = signatureSymbol(compiler, signature); + emitShortArg(compiler, (Code)(instruction + signature->arity), symbol); + + if (instruction == CODE_SUPER_0) + { + // Super calls need to be statically bound to the class's superclass. This + // ensures we call the right method even when a method containing a super + // call is inherited by another subclass. + // + // We bind it at class definition time by storing a reference to the + // superclass in a constant. So, here, we create a slot in the constant + // table and store NULL in it. When the method is bound, we'll look up the + // superclass then and store it in the constant slot. + emitShort(compiler, addConstant(compiler, NULL_VAL)); + } +} + +// Compiles a method call with [numArgs] for a method with [name] with [length]. +static void callMethod(Compiler* compiler, int numArgs, const char* name, + int length) +{ + int symbol = methodSymbol(compiler, name, length); + emitShortArg(compiler, (Code)(CODE_CALL_0 + numArgs), symbol); +} + +// Compiles an (optional) argument list for a method call with [methodSignature] +// and then calls it. +static void methodCall(Compiler* compiler, Code instruction, + Signature* signature) +{ + // Make a new signature that contains the updated arity and type based on + // the arguments we find. + Signature called = { signature->name, signature->length, SIG_GETTER, 0 }; + + // Parse the argument list, if any. + if (match(compiler, TOKEN_LEFT_PAREN)) + { + called.type = SIG_METHOD; + + // Allow new line before an empty argument list + ignoreNewlines(compiler); + + // Allow empty an argument list. + if (peek(compiler) != TOKEN_RIGHT_PAREN) + { + finishArgumentList(compiler, &called); + } + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after arguments."); + } + + // Parse the block argument, if any. + if (match(compiler, TOKEN_LEFT_BRACE)) + { + // Include the block argument in the arity. + called.type = SIG_METHOD; + called.arity++; + + Compiler fnCompiler; + initCompiler(&fnCompiler, compiler->parser, compiler, false); + + // Make a dummy signature to track the arity. + Signature fnSignature = { "", 0, SIG_METHOD, 0 }; + + // Parse the parameter list, if any. + if (match(compiler, TOKEN_PIPE)) + { + finishParameterList(&fnCompiler, &fnSignature); + consume(compiler, TOKEN_PIPE, "Expect '|' after function parameters."); + } + + fnCompiler.fn->arity = fnSignature.arity; + + finishBody(&fnCompiler); + + // Name the function based on the method its passed to. + char blockName[MAX_METHOD_SIGNATURE + 15]; + int blockLength; + signatureToString(&called, blockName, &blockLength); + memmove(blockName + blockLength, " block argument", 16); + + endCompiler(&fnCompiler, blockName, blockLength + 15); + } + + // TODO: Allow Grace-style mixfix methods? + + // If this is a super() call for an initializer, make sure we got an actual + // argument list. + if (signature->type == SIG_INITIALIZER) + { + if (called.type != SIG_METHOD) + { + error(compiler, "A superclass constructor must have an argument list."); + } + + called.type = SIG_INITIALIZER; + } + + callSignature(compiler, instruction, &called); +} + +// Compiles a call whose name is the previously consumed token. This includes +// getters, method calls with arguments, and setter calls. +static void namedCall(Compiler* compiler, bool canAssign, Code instruction) +{ + // Get the token for the method name. + Signature signature = signatureFromToken(compiler, SIG_GETTER); + + if (canAssign && match(compiler, TOKEN_EQ)) + { + ignoreNewlines(compiler); + + // Build the setter signature. + signature.type = SIG_SETTER; + signature.arity = 1; + + // Compile the assigned value. + expression(compiler); + callSignature(compiler, instruction, &signature); + } + else + { + methodCall(compiler, instruction, &signature); + allowLineBeforeDot(compiler); + } +} + +// Emits the code to load [variable] onto the stack. +static void loadVariable(Compiler* compiler, Variable variable) +{ + switch (variable.scope) + { + case SCOPE_LOCAL: + loadLocal(compiler, variable.index); + break; + case SCOPE_UPVALUE: + emitByteArg(compiler, CODE_LOAD_UPVALUE, variable.index); + break; + case SCOPE_MODULE: + emitShortArg(compiler, CODE_LOAD_MODULE_VAR, variable.index); + break; + default: + UNREACHABLE(); + } +} + +// Loads the receiver of the currently enclosing method. Correctly handles +// functions defined inside methods. +static void loadThis(Compiler* compiler) +{ + loadVariable(compiler, resolveNonmodule(compiler, "this", 4)); +} + +// Pushes the value for a module-level variable implicitly imported from core. +static void loadCoreVariable(Compiler* compiler, const char* name) +{ + int symbol = wrenSymbolTableFind(&compiler->parser->module->variableNames, + name, strlen(name)); + ASSERT(symbol != -1, "Should have already defined core name."); + emitShortArg(compiler, CODE_LOAD_MODULE_VAR, symbol); +} + +// A parenthesized expression. +static void grouping(Compiler* compiler, bool canAssign) +{ + expression(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after expression."); +} + +// A list literal. +static void list(Compiler* compiler, bool canAssign) +{ + // Instantiate a new list. + loadCoreVariable(compiler, "List"); + callMethod(compiler, 0, "new()", 5); + + // Compile the list elements. Each one compiles to a ".add()" call. + do + { + ignoreNewlines(compiler); + + // Stop if we hit the end of the list. + if (peek(compiler) == TOKEN_RIGHT_BRACKET) break; + + // The element. + expression(compiler); + callMethod(compiler, 1, "addCore_(_)", 11); + } while (match(compiler, TOKEN_COMMA)); + + // Allow newlines before the closing ']'. + ignoreNewlines(compiler); + consume(compiler, TOKEN_RIGHT_BRACKET, "Expect ']' after list elements."); +} + +// A map literal. +static void map(Compiler* compiler, bool canAssign) +{ + // Instantiate a new map. + loadCoreVariable(compiler, "Map"); + callMethod(compiler, 0, "new()", 5); + + // Compile the map elements. Each one is compiled to just invoke the + // subscript setter on the map. + do + { + ignoreNewlines(compiler); + + // Stop if we hit the end of the map. + if (peek(compiler) == TOKEN_RIGHT_BRACE) break; + + // The key. + parsePrecedence(compiler, PREC_UNARY); + consume(compiler, TOKEN_COLON, "Expect ':' after map key."); + ignoreNewlines(compiler); + + // The value. + expression(compiler); + callMethod(compiler, 2, "addCore_(_,_)", 13); + } while (match(compiler, TOKEN_COMMA)); + + // Allow newlines before the closing '}'. + ignoreNewlines(compiler); + consume(compiler, TOKEN_RIGHT_BRACE, "Expect '}' after map entries."); +} + +// Unary operators like `-foo`. +static void unaryOp(Compiler* compiler, bool canAssign) +{ + GrammarRule* rule = getRule(compiler->parser->previous.type); + + ignoreNewlines(compiler); + + // Compile the argument. + parsePrecedence(compiler, (Precedence)(PREC_UNARY + 1)); + + // Call the operator method on the left-hand side. + callMethod(compiler, 0, rule->name, 1); +} + +static void boolean(Compiler* compiler, bool canAssign) +{ + emitOp(compiler, + compiler->parser->previous.type == TOKEN_FALSE ? CODE_FALSE : CODE_TRUE); +} + +// Walks the compiler chain to find the compiler for the nearest class +// enclosing this one. Returns NULL if not currently inside a class definition. +static Compiler* getEnclosingClassCompiler(Compiler* compiler) +{ + while (compiler != NULL) + { + if (compiler->enclosingClass != NULL) return compiler; + compiler = compiler->parent; + } + + return NULL; +} + +// Walks the compiler chain to find the nearest class enclosing this one. +// Returns NULL if not currently inside a class definition. +static ClassInfo* getEnclosingClass(Compiler* compiler) +{ + compiler = getEnclosingClassCompiler(compiler); + return compiler == NULL ? NULL : compiler->enclosingClass; +} + +static void field(Compiler* compiler, bool canAssign) +{ + // Initialize it with a fake value so we can keep parsing and minimize the + // number of cascaded errors. + int field = MAX_FIELDS; + + ClassInfo* enclosingClass = getEnclosingClass(compiler); + + if (enclosingClass == NULL) + { + error(compiler, "Cannot reference a field outside of a class definition."); + } + else if (enclosingClass->isForeign) + { + error(compiler, "Cannot define fields in a foreign class."); + } + else if (enclosingClass->inStatic) + { + error(compiler, "Cannot use an instance field in a static method."); + } + else + { + // Look up the field, or implicitly define it. + field = wrenSymbolTableEnsure(compiler->parser->vm, &enclosingClass->fields, + compiler->parser->previous.start, + compiler->parser->previous.length); + + if (field >= MAX_FIELDS) + { + error(compiler, "A class can only have %d fields.", MAX_FIELDS); + } + } + + // If there's an "=" after a field name, it's an assignment. + bool isLoad = true; + if (canAssign && match(compiler, TOKEN_EQ)) + { + // Compile the right-hand side. + expression(compiler); + isLoad = false; + } + + // If we're directly inside a method, use a more optimal instruction. + if (compiler->parent != NULL && + compiler->parent->enclosingClass == enclosingClass) + { + emitByteArg(compiler, isLoad ? CODE_LOAD_FIELD_THIS : CODE_STORE_FIELD_THIS, + field); + } + else + { + loadThis(compiler); + emitByteArg(compiler, isLoad ? CODE_LOAD_FIELD : CODE_STORE_FIELD, field); + } + + allowLineBeforeDot(compiler); +} + +// Compiles a read or assignment to [variable]. +static void bareName(Compiler* compiler, bool canAssign, Variable variable) +{ + // If there's an "=" after a bare name, it's a variable assignment. + if (canAssign && match(compiler, TOKEN_EQ)) + { + // Compile the right-hand side. + expression(compiler); + + // Emit the store instruction. + switch (variable.scope) + { + case SCOPE_LOCAL: + emitByteArg(compiler, CODE_STORE_LOCAL, variable.index); + break; + case SCOPE_UPVALUE: + emitByteArg(compiler, CODE_STORE_UPVALUE, variable.index); + break; + case SCOPE_MODULE: + emitShortArg(compiler, CODE_STORE_MODULE_VAR, variable.index); + break; + default: + UNREACHABLE(); + } + return; + } + + // Emit the load instruction. + loadVariable(compiler, variable); + + allowLineBeforeDot(compiler); +} + +static void staticField(Compiler* compiler, bool canAssign) +{ + Compiler* classCompiler = getEnclosingClassCompiler(compiler); + if (classCompiler == NULL) + { + error(compiler, "Cannot use a static field outside of a class definition."); + return; + } + + // Look up the name in the scope chain. + Token* token = &compiler->parser->previous; + + // If this is the first time we've seen this static field, implicitly + // define it as a variable in the scope surrounding the class definition. + if (resolveLocal(classCompiler, token->start, token->length) == -1) + { + int symbol = declareVariable(classCompiler, NULL); + + // Implicitly initialize it to null. + emitOp(classCompiler, CODE_NULL); + defineVariable(classCompiler, symbol); + } + + // It definitely exists now, so resolve it properly. This is different from + // the above resolveLocal() call because we may have already closed over it + // as an upvalue. + Variable variable = resolveName(compiler, token->start, token->length); + bareName(compiler, canAssign, variable); +} + +// Compiles a variable name or method call with an implicit receiver. +static void name(Compiler* compiler, bool canAssign) +{ + // Look for the name in the scope chain up to the nearest enclosing method. + Token* token = &compiler->parser->previous; + + Variable variable = resolveNonmodule(compiler, token->start, token->length); + if (variable.index != -1) + { + bareName(compiler, canAssign, variable); + return; + } + + // TODO: The fact that we return above here if the variable is known and parse + // an optional argument list below if not means that the grammar is not + // context-free. A line of code in a method like "someName(foo)" is a parse + // error if "someName" is a defined variable in the surrounding scope and not + // if it isn't. Fix this. One option is to have "someName(foo)" always + // resolve to a self-call if there is an argument list, but that makes + // getters a little confusing. + + // If we're inside a method and the name is lowercase, treat it as a method + // on this. + if (wrenIsLocalName(token->start) && getEnclosingClass(compiler) != NULL) + { + loadThis(compiler); + namedCall(compiler, canAssign, CODE_CALL_0); + return; + } + + // Otherwise, look for a module-level variable with the name. + variable.scope = SCOPE_MODULE; + variable.index = wrenSymbolTableFind(&compiler->parser->module->variableNames, + token->start, token->length); + if (variable.index == -1) + { + // Implicitly define a module-level variable in + // the hopes that we get a real definition later. + variable.index = wrenDeclareVariable(compiler->parser->vm, + compiler->parser->module, + token->start, token->length, + token->line); + + if (variable.index == -2) + { + error(compiler, "Too many module variables defined."); + } + } + + bareName(compiler, canAssign, variable); +} + +static void null(Compiler* compiler, bool canAssign) +{ + emitOp(compiler, CODE_NULL); +} + +// A number or string literal. +static void literal(Compiler* compiler, bool canAssign) +{ + emitConstant(compiler, compiler->parser->previous.value); +} + +// A string literal that contains interpolated expressions. +// +// Interpolation is syntactic sugar for calling ".join()" on a list. So the +// string: +// +// "a %(b + c) d" +// +// is compiled roughly like: +// +// ["a ", b + c, " d"].join() +static void stringInterpolation(Compiler* compiler, bool canAssign) +{ + // Instantiate a new list. + loadCoreVariable(compiler, "List"); + callMethod(compiler, 0, "new()", 5); + + do + { + // The opening string part. + literal(compiler, false); + callMethod(compiler, 1, "addCore_(_)", 11); + + // The interpolated expression. + ignoreNewlines(compiler); + expression(compiler); + callMethod(compiler, 1, "addCore_(_)", 11); + + ignoreNewlines(compiler); + } while (match(compiler, TOKEN_INTERPOLATION)); + + // The trailing string part. + consume(compiler, TOKEN_STRING, "Expect end of string interpolation."); + literal(compiler, false); + callMethod(compiler, 1, "addCore_(_)", 11); + + // The list of interpolated parts. + callMethod(compiler, 0, "join()", 6); +} + +static void super_(Compiler* compiler, bool canAssign) +{ + ClassInfo* enclosingClass = getEnclosingClass(compiler); + if (enclosingClass == NULL) + { + error(compiler, "Cannot use 'super' outside of a method."); + } + + loadThis(compiler); + + // TODO: Super operator calls. + // TODO: There's no syntax for invoking a superclass constructor with a + // different name from the enclosing one. Figure that out. + + // See if it's a named super call, or an unnamed one. + if (match(compiler, TOKEN_DOT)) + { + // Compile the superclass call. + consume(compiler, TOKEN_NAME, "Expect method name after 'super.'."); + namedCall(compiler, canAssign, CODE_SUPER_0); + } + else if (enclosingClass != NULL) + { + // No explicit name, so use the name of the enclosing method. Make sure we + // check that enclosingClass isn't NULL first. We've already reported the + // error, but we don't want to crash here. + methodCall(compiler, CODE_SUPER_0, enclosingClass->signature); + } +} + +static void this_(Compiler* compiler, bool canAssign) +{ + if (getEnclosingClass(compiler) == NULL) + { + error(compiler, "Cannot use 'this' outside of a method."); + return; + } + + loadThis(compiler); +} + +// Subscript or "array indexing" operator like `foo[bar]`. +static void subscript(Compiler* compiler, bool canAssign) +{ + Signature signature = { "", 0, SIG_SUBSCRIPT, 0 }; + + // Parse the argument list. + finishArgumentList(compiler, &signature); + consume(compiler, TOKEN_RIGHT_BRACKET, "Expect ']' after arguments."); + + allowLineBeforeDot(compiler); + + if (canAssign && match(compiler, TOKEN_EQ)) + { + signature.type = SIG_SUBSCRIPT_SETTER; + + // Compile the assigned value. + validateNumParameters(compiler, ++signature.arity); + expression(compiler); + } + + callSignature(compiler, CODE_CALL_0, &signature); +} + +static void call(Compiler* compiler, bool canAssign) +{ + ignoreNewlines(compiler); + consume(compiler, TOKEN_NAME, "Expect method name after '.'."); + namedCall(compiler, canAssign, CODE_CALL_0); +} + +static void and_(Compiler* compiler, bool canAssign) +{ + ignoreNewlines(compiler); + + // Skip the right argument if the left is false. + int jump = emitJump(compiler, CODE_AND); + parsePrecedence(compiler, PREC_LOGICAL_AND); + patchJump(compiler, jump); +} + +static void or_(Compiler* compiler, bool canAssign) +{ + ignoreNewlines(compiler); + + // Skip the right argument if the left is true. + int jump = emitJump(compiler, CODE_OR); + parsePrecedence(compiler, PREC_LOGICAL_OR); + patchJump(compiler, jump); +} + +static void conditional(Compiler* compiler, bool canAssign) +{ + // Ignore newline after '?'. + ignoreNewlines(compiler); + + // Jump to the else branch if the condition is false. + int ifJump = emitJump(compiler, CODE_JUMP_IF); + + // Compile the then branch. + parsePrecedence(compiler, PREC_CONDITIONAL); + + consume(compiler, TOKEN_COLON, + "Expect ':' after then branch of conditional operator."); + ignoreNewlines(compiler); + + // Jump over the else branch when the if branch is taken. + int elseJump = emitJump(compiler, CODE_JUMP); + + // Compile the else branch. + patchJump(compiler, ifJump); + + parsePrecedence(compiler, PREC_ASSIGNMENT); + + // Patch the jump over the else. + patchJump(compiler, elseJump); +} + +void infixOp(Compiler* compiler, bool canAssign) +{ + GrammarRule* rule = getRule(compiler->parser->previous.type); + + // An infix operator cannot end an expression. + ignoreNewlines(compiler); + + // Compile the right-hand side. + parsePrecedence(compiler, (Precedence)(rule->precedence + 1)); + + // Call the operator method on the left-hand side. + Signature signature = { rule->name, (int)strlen(rule->name), SIG_METHOD, 1 }; + callSignature(compiler, CODE_CALL_0, &signature); +} + +// Compiles a method signature for an infix operator. +void infixSignature(Compiler* compiler, Signature* signature) +{ + // Add the RHS parameter. + signature->type = SIG_METHOD; + signature->arity = 1; + + // Parse the parameter name. + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after operator name."); + declareNamedVariable(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameter name."); +} + +// Compiles a method signature for an unary operator (i.e. "!"). +void unarySignature(Compiler* compiler, Signature* signature) +{ + // Do nothing. The name is already complete. + signature->type = SIG_GETTER; +} + +// Compiles a method signature for an operator that can either be unary or +// infix (i.e. "-"). +void mixedSignature(Compiler* compiler, Signature* signature) +{ + signature->type = SIG_GETTER; + + // If there is a parameter, it's an infix operator, otherwise it's unary. + if (match(compiler, TOKEN_LEFT_PAREN)) + { + // Add the RHS parameter. + signature->type = SIG_METHOD; + signature->arity = 1; + + // Parse the parameter name. + declareNamedVariable(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameter name."); + } +} + +// Compiles an optional setter parameter in a method [signature]. +// +// Returns `true` if it was a setter. +static bool maybeSetter(Compiler* compiler, Signature* signature) +{ + // See if it's a setter. + if (!match(compiler, TOKEN_EQ)) return false; + + // It's a setter. + if (signature->type == SIG_SUBSCRIPT) + { + signature->type = SIG_SUBSCRIPT_SETTER; + } + else + { + signature->type = SIG_SETTER; + } + + // Parse the value parameter. + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after '='."); + declareNamedVariable(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameter name."); + + signature->arity++; + + return true; +} + +// Compiles a method signature for a subscript operator. +void subscriptSignature(Compiler* compiler, Signature* signature) +{ + signature->type = SIG_SUBSCRIPT; + + // The signature currently has "[" as its name since that was the token that + // matched it. Clear that out. + signature->length = 0; + + // Parse the parameters inside the subscript. + finishParameterList(compiler, signature); + consume(compiler, TOKEN_RIGHT_BRACKET, "Expect ']' after parameters."); + + maybeSetter(compiler, signature); +} + +// Parses an optional parenthesized parameter list. Updates `type` and `arity` +// in [signature] to match what was parsed. +static void parameterList(Compiler* compiler, Signature* signature) +{ + // The parameter list is optional. + if (!match(compiler, TOKEN_LEFT_PAREN)) return; + + signature->type = SIG_METHOD; + + // Allow new line before an empty argument list + ignoreNewlines(compiler); + + // Allow an empty parameter list. + if (match(compiler, TOKEN_RIGHT_PAREN)) return; + + finishParameterList(compiler, signature); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameters."); +} + +// Compiles a method signature for a named method or setter. +void namedSignature(Compiler* compiler, Signature* signature) +{ + signature->type = SIG_GETTER; + + // If it's a setter, it can't also have a parameter list. + if (maybeSetter(compiler, signature)) return; + + // Regular named method with an optional parameter list. + parameterList(compiler, signature); +} + +// Compiles a method signature for a constructor. +void constructorSignature(Compiler* compiler, Signature* signature) +{ + consume(compiler, TOKEN_NAME, "Expect constructor name after 'construct'."); + + // Capture the name. + *signature = signatureFromToken(compiler, SIG_INITIALIZER); + + if (match(compiler, TOKEN_EQ)) + { + error(compiler, "A constructor cannot be a setter."); + } + + if (!match(compiler, TOKEN_LEFT_PAREN)) + { + error(compiler, "A constructor cannot be a getter."); + return; + } + + // Allow an empty parameter list. + if (match(compiler, TOKEN_RIGHT_PAREN)) return; + + finishParameterList(compiler, signature); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after parameters."); +} + +// This table defines all of the parsing rules for the prefix and infix +// expressions in the grammar. Expressions are parsed using a Pratt parser. +// +// See: http://journal.stuffwithstuff.com/2011/03/19/pratt-parsers-expression-parsing-made-easy/ +#define UNUSED { NULL, NULL, NULL, PREC_NONE, NULL } +#define PREFIX(fn) { fn, NULL, NULL, PREC_NONE, NULL } +#define INFIX(prec, fn) { NULL, fn, NULL, prec, NULL } +#define INFIX_OPERATOR(prec, name) { NULL, infixOp, infixSignature, prec, name } +#define PREFIX_OPERATOR(name) { unaryOp, NULL, unarySignature, PREC_NONE, name } +#define OPERATOR(name) { unaryOp, infixOp, mixedSignature, PREC_TERM, name } + +GrammarRule rules[] = +{ + /* TOKEN_LEFT_PAREN */ PREFIX(grouping), + /* TOKEN_RIGHT_PAREN */ UNUSED, + /* TOKEN_LEFT_BRACKET */ { list, subscript, subscriptSignature, PREC_CALL, NULL }, + /* TOKEN_RIGHT_BRACKET */ UNUSED, + /* TOKEN_LEFT_BRACE */ PREFIX(map), + /* TOKEN_RIGHT_BRACE */ UNUSED, + /* TOKEN_COLON */ UNUSED, + /* TOKEN_DOT */ INFIX(PREC_CALL, call), + /* TOKEN_DOTDOT */ INFIX_OPERATOR(PREC_RANGE, ".."), + /* TOKEN_DOTDOTDOT */ INFIX_OPERATOR(PREC_RANGE, "..."), + /* TOKEN_COMMA */ UNUSED, + /* TOKEN_STAR */ INFIX_OPERATOR(PREC_FACTOR, "*"), + /* TOKEN_SLASH */ INFIX_OPERATOR(PREC_FACTOR, "/"), + /* TOKEN_PERCENT */ INFIX_OPERATOR(PREC_FACTOR, "%"), + /* TOKEN_HASH */ UNUSED, + /* TOKEN_PLUS */ INFIX_OPERATOR(PREC_TERM, "+"), + /* TOKEN_MINUS */ OPERATOR("-"), + /* TOKEN_LTLT */ INFIX_OPERATOR(PREC_BITWISE_SHIFT, "<<"), + /* TOKEN_GTGT */ INFIX_OPERATOR(PREC_BITWISE_SHIFT, ">>"), + /* TOKEN_PIPE */ INFIX_OPERATOR(PREC_BITWISE_OR, "|"), + /* TOKEN_PIPEPIPE */ INFIX(PREC_LOGICAL_OR, or_), + /* TOKEN_CARET */ INFIX_OPERATOR(PREC_BITWISE_XOR, "^"), + /* TOKEN_AMP */ INFIX_OPERATOR(PREC_BITWISE_AND, "&"), + /* TOKEN_AMPAMP */ INFIX(PREC_LOGICAL_AND, and_), + /* TOKEN_BANG */ PREFIX_OPERATOR("!"), + /* TOKEN_TILDE */ PREFIX_OPERATOR("~"), + /* TOKEN_QUESTION */ INFIX(PREC_ASSIGNMENT, conditional), + /* TOKEN_EQ */ UNUSED, + /* TOKEN_LT */ INFIX_OPERATOR(PREC_COMPARISON, "<"), + /* TOKEN_GT */ INFIX_OPERATOR(PREC_COMPARISON, ">"), + /* TOKEN_LTEQ */ INFIX_OPERATOR(PREC_COMPARISON, "<="), + /* TOKEN_GTEQ */ INFIX_OPERATOR(PREC_COMPARISON, ">="), + /* TOKEN_EQEQ */ INFIX_OPERATOR(PREC_EQUALITY, "=="), + /* TOKEN_BANGEQ */ INFIX_OPERATOR(PREC_EQUALITY, "!="), + /* TOKEN_BREAK */ UNUSED, + /* TOKEN_CONTINUE */ UNUSED, + /* TOKEN_CLASS */ UNUSED, + /* TOKEN_CONSTRUCT */ { NULL, NULL, constructorSignature, PREC_NONE, NULL }, + /* TOKEN_ELSE */ UNUSED, + /* TOKEN_FALSE */ PREFIX(boolean), + /* TOKEN_FOR */ UNUSED, + /* TOKEN_FOREIGN */ UNUSED, + /* TOKEN_IF */ UNUSED, + /* TOKEN_IMPORT */ UNUSED, + /* TOKEN_AS */ UNUSED, + /* TOKEN_IN */ UNUSED, + /* TOKEN_IS */ INFIX_OPERATOR(PREC_IS, "is"), + /* TOKEN_NULL */ PREFIX(null), + /* TOKEN_RETURN */ UNUSED, + /* TOKEN_STATIC */ UNUSED, + /* TOKEN_SUPER */ PREFIX(super_), + /* TOKEN_THIS */ PREFIX(this_), + /* TOKEN_TRUE */ PREFIX(boolean), + /* TOKEN_VAR */ UNUSED, + /* TOKEN_WHILE */ UNUSED, + /* TOKEN_FIELD */ PREFIX(field), + /* TOKEN_STATIC_FIELD */ PREFIX(staticField), + /* TOKEN_NAME */ { name, NULL, namedSignature, PREC_NONE, NULL }, + /* TOKEN_NUMBER */ PREFIX(literal), + /* TOKEN_STRING */ PREFIX(literal), + /* TOKEN_INTERPOLATION */ PREFIX(stringInterpolation), + /* TOKEN_LINE */ UNUSED, + /* TOKEN_ERROR */ UNUSED, + /* TOKEN_EOF */ UNUSED +}; + +// Gets the [GrammarRule] associated with tokens of [type]. +static GrammarRule* getRule(TokenType type) +{ + return &rules[type]; +} + +// The main entrypoint for the top-down operator precedence parser. +void parsePrecedence(Compiler* compiler, Precedence precedence) +{ + nextToken(compiler->parser); + GrammarFn prefix = rules[compiler->parser->previous.type].prefix; + + if (prefix == NULL) + { + error(compiler, "Expected expression."); + return; + } + + // Track if the precendence of the surrounding expression is low enough to + // allow an assignment inside this one. We can't compile an assignment like + // a normal expression because it requires us to handle the LHS specially -- + // it needs to be an lvalue, not an rvalue. So, for each of the kinds of + // expressions that are valid lvalues -- names, subscripts, fields, etc. -- + // we pass in whether or not it appears in a context loose enough to allow + // "=". If so, it will parse the "=" itself and handle it appropriately. + bool canAssign = precedence <= PREC_CONDITIONAL; + prefix(compiler, canAssign); + + while (precedence <= rules[compiler->parser->current.type].precedence) + { + nextToken(compiler->parser); + GrammarFn infix = rules[compiler->parser->previous.type].infix; + infix(compiler, canAssign); + } +} + +// Parses an expression. Unlike statements, expressions leave a resulting value +// on the stack. +void expression(Compiler* compiler) +{ + parsePrecedence(compiler, PREC_LOWEST); +} + +// Returns the number of bytes for the arguments to the instruction +// at [ip] in [fn]'s bytecode. +static int getByteCountForArguments(const uint8_t* bytecode, + const Value* constants, int ip) +{ + Code instruction = (Code)bytecode[ip]; + switch (instruction) + { + case CODE_NULL: + case CODE_FALSE: + case CODE_TRUE: + case CODE_POP: + case CODE_CLOSE_UPVALUE: + case CODE_RETURN: + case CODE_END: + case CODE_LOAD_LOCAL_0: + case CODE_LOAD_LOCAL_1: + case CODE_LOAD_LOCAL_2: + case CODE_LOAD_LOCAL_3: + case CODE_LOAD_LOCAL_4: + case CODE_LOAD_LOCAL_5: + case CODE_LOAD_LOCAL_6: + case CODE_LOAD_LOCAL_7: + case CODE_LOAD_LOCAL_8: + case CODE_CONSTRUCT: + case CODE_FOREIGN_CONSTRUCT: + case CODE_FOREIGN_CLASS: + case CODE_END_MODULE: + case CODE_END_CLASS: + return 0; + + case CODE_LOAD_LOCAL: + case CODE_STORE_LOCAL: + case CODE_LOAD_UPVALUE: + case CODE_STORE_UPVALUE: + case CODE_LOAD_FIELD_THIS: + case CODE_STORE_FIELD_THIS: + case CODE_LOAD_FIELD: + case CODE_STORE_FIELD: + case CODE_CLASS: + return 1; + + case CODE_CONSTANT: + case CODE_LOAD_MODULE_VAR: + case CODE_STORE_MODULE_VAR: + case CODE_CALL_0: + case CODE_CALL_1: + case CODE_CALL_2: + case CODE_CALL_3: + case CODE_CALL_4: + case CODE_CALL_5: + case CODE_CALL_6: + case CODE_CALL_7: + case CODE_CALL_8: + case CODE_CALL_9: + case CODE_CALL_10: + case CODE_CALL_11: + case CODE_CALL_12: + case CODE_CALL_13: + case CODE_CALL_14: + case CODE_CALL_15: + case CODE_CALL_16: + case CODE_JUMP: + case CODE_LOOP: + case CODE_JUMP_IF: + case CODE_AND: + case CODE_OR: + case CODE_METHOD_INSTANCE: + case CODE_METHOD_STATIC: + case CODE_IMPORT_MODULE: + case CODE_IMPORT_VARIABLE: + return 2; + + case CODE_SUPER_0: + case CODE_SUPER_1: + case CODE_SUPER_2: + case CODE_SUPER_3: + case CODE_SUPER_4: + case CODE_SUPER_5: + case CODE_SUPER_6: + case CODE_SUPER_7: + case CODE_SUPER_8: + case CODE_SUPER_9: + case CODE_SUPER_10: + case CODE_SUPER_11: + case CODE_SUPER_12: + case CODE_SUPER_13: + case CODE_SUPER_14: + case CODE_SUPER_15: + case CODE_SUPER_16: + return 4; + + case CODE_CLOSURE: + { + int constant = (bytecode[ip + 1] << 8) | bytecode[ip + 2]; + ObjFn* loadedFn = AS_FN(constants[constant]); + + // There are two bytes for the constant, then two for each upvalue. + return 2 + (loadedFn->numUpvalues * 2); + } + } + + UNREACHABLE(); + return 0; +} + +// Marks the beginning of a loop. Keeps track of the current instruction so we +// know what to loop back to at the end of the body. +static void startLoop(Compiler* compiler, Loop* loop) +{ + loop->enclosing = compiler->loop; + loop->start = compiler->fn->code.count - 1; + loop->scopeDepth = compiler->scopeDepth; + compiler->loop = loop; +} + +// Emits the [CODE_JUMP_IF] instruction used to test the loop condition and +// potentially exit the loop. Keeps track of the instruction so we can patch it +// later once we know where the end of the body is. +static void testExitLoop(Compiler* compiler) +{ + compiler->loop->exitJump = emitJump(compiler, CODE_JUMP_IF); +} + +// Compiles the body of the loop and tracks its extent so that contained "break" +// statements can be handled correctly. +static void loopBody(Compiler* compiler) +{ + compiler->loop->body = compiler->fn->code.count; + statement(compiler); +} + +// Ends the current innermost loop. Patches up all jumps and breaks now that +// we know where the end of the loop is. +static void endLoop(Compiler* compiler) +{ + // We don't check for overflow here since the forward jump over the loop body + // will report an error for the same problem. + int loopOffset = compiler->fn->code.count - compiler->loop->start + 2; + emitShortArg(compiler, CODE_LOOP, loopOffset); + + patchJump(compiler, compiler->loop->exitJump); + + // Find any break placeholder instructions (which will be CODE_END in the + // bytecode) and replace them with real jumps. + int i = compiler->loop->body; + while (i < compiler->fn->code.count) + { + if (compiler->fn->code.data[i] == CODE_END) + { + compiler->fn->code.data[i] = CODE_JUMP; + patchJump(compiler, i + 1); + i += 3; + } + else + { + // Skip this instruction and its arguments. + i += 1 + getByteCountForArguments(compiler->fn->code.data, + compiler->fn->constants.data, i); + } + } + + compiler->loop = compiler->loop->enclosing; +} + +static void forStatement(Compiler* compiler) +{ + // A for statement like: + // + // for (i in sequence.expression) { + // System.print(i) + // } + // + // Is compiled to bytecode almost as if the source looked like this: + // + // { + // var seq_ = sequence.expression + // var iter_ + // while (iter_ = seq_.iterate(iter_)) { + // var i = seq_.iteratorValue(iter_) + // System.print(i) + // } + // } + // + // It's not exactly this, because the synthetic variables `seq_` and `iter_` + // actually get names that aren't valid Wren identfiers, but that's the basic + // idea. + // + // The important parts are: + // - The sequence expression is only evaluated once. + // - The .iterate() method is used to advance the iterator and determine if + // it should exit the loop. + // - The .iteratorValue() method is used to get the value at the current + // iterator position. + + // Create a scope for the hidden local variables used for the iterator. + pushScope(compiler); + + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after 'for'."); + consume(compiler, TOKEN_NAME, "Expect for loop variable name."); + + // Remember the name of the loop variable. + const char* name = compiler->parser->previous.start; + int length = compiler->parser->previous.length; + + consume(compiler, TOKEN_IN, "Expect 'in' after loop variable."); + ignoreNewlines(compiler); + + // Evaluate the sequence expression and store it in a hidden local variable. + // The space in the variable name ensures it won't collide with a user-defined + // variable. + expression(compiler); + + // Verify that there is space to hidden local variables. + // Note that we expect only two addLocal calls next to each other in the + // following code. + if (compiler->numLocals + 2 > MAX_LOCALS) + { + error(compiler, "Cannot declare more than %d variables in one scope. (Not enough space for for-loops internal variables)", + MAX_LOCALS); + return; + } + int seqSlot = addLocal(compiler, "seq ", 4); + + // Create another hidden local for the iterator object. + null(compiler, false); + int iterSlot = addLocal(compiler, "iter ", 5); + + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after loop expression."); + + Loop loop; + startLoop(compiler, &loop); + + // Advance the iterator by calling the ".iterate" method on the sequence. + loadLocal(compiler, seqSlot); + loadLocal(compiler, iterSlot); + + // Update and test the iterator. + callMethod(compiler, 1, "iterate(_)", 10); + emitByteArg(compiler, CODE_STORE_LOCAL, iterSlot); + testExitLoop(compiler); + + // Get the current value in the sequence by calling ".iteratorValue". + loadLocal(compiler, seqSlot); + loadLocal(compiler, iterSlot); + callMethod(compiler, 1, "iteratorValue(_)", 16); + + // Bind the loop variable in its own scope. This ensures we get a fresh + // variable each iteration so that closures for it don't all see the same one. + pushScope(compiler); + addLocal(compiler, name, length); + + loopBody(compiler); + + // Loop variable. + popScope(compiler); + + endLoop(compiler); + + // Hidden variables. + popScope(compiler); +} + +static void ifStatement(Compiler* compiler) +{ + // Compile the condition. + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after 'if'."); + expression(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after if condition."); + + // Jump to the else branch if the condition is false. + int ifJump = emitJump(compiler, CODE_JUMP_IF); + + // Compile the then branch. + statement(compiler); + + // Compile the else branch if there is one. + if (match(compiler, TOKEN_ELSE)) + { + // Jump over the else branch when the if branch is taken. + int elseJump = emitJump(compiler, CODE_JUMP); + patchJump(compiler, ifJump); + + statement(compiler); + + // Patch the jump over the else. + patchJump(compiler, elseJump); + } + else + { + patchJump(compiler, ifJump); + } +} + +static void whileStatement(Compiler* compiler) +{ + Loop loop; + startLoop(compiler, &loop); + + // Compile the condition. + consume(compiler, TOKEN_LEFT_PAREN, "Expect '(' after 'while'."); + expression(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, "Expect ')' after while condition."); + + testExitLoop(compiler); + loopBody(compiler); + endLoop(compiler); +} + +// Compiles a simple statement. These can only appear at the top-level or +// within curly blocks. Simple statements exclude variable binding statements +// like "var" and "class" which are not allowed directly in places like the +// branches of an "if" statement. +// +// Unlike expressions, statements do not leave a value on the stack. +void statement(Compiler* compiler) +{ + if (match(compiler, TOKEN_BREAK)) + { + if (compiler->loop == NULL) + { + error(compiler, "Cannot use 'break' outside of a loop."); + return; + } + + // Since we will be jumping out of the scope, make sure any locals in it + // are discarded first. + discardLocals(compiler, compiler->loop->scopeDepth + 1); + + // Emit a placeholder instruction for the jump to the end of the body. When + // we're done compiling the loop body and know where the end is, we'll + // replace these with `CODE_JUMP` instructions with appropriate offsets. + // We use `CODE_END` here because that can't occur in the middle of + // bytecode. + emitJump(compiler, CODE_END); + } + else if (match(compiler, TOKEN_CONTINUE)) + { + if (compiler->loop == NULL) + { + error(compiler, "Cannot use 'continue' outside of a loop."); + return; + } + + // Since we will be jumping out of the scope, make sure any locals in it + // are discarded first. + discardLocals(compiler, compiler->loop->scopeDepth + 1); + + // emit a jump back to the top of the loop + int loopOffset = compiler->fn->code.count - compiler->loop->start + 2; + emitShortArg(compiler, CODE_LOOP, loopOffset); + } + else if (match(compiler, TOKEN_FOR)) + { + forStatement(compiler); + } + else if (match(compiler, TOKEN_IF)) + { + ifStatement(compiler); + } + else if (match(compiler, TOKEN_RETURN)) + { + // Compile the return value. + if (peek(compiler) == TOKEN_LINE) + { + // If there's no expression after return, initializers should + // return 'this' and regular methods should return null + Code result = compiler->isInitializer ? CODE_LOAD_LOCAL_0 : CODE_NULL; + emitOp(compiler, result); + } + else + { + if (compiler->isInitializer) + { + error(compiler, "A constructor cannot return a value."); + } + + expression(compiler); + } + + emitOp(compiler, CODE_RETURN); + } + else if (match(compiler, TOKEN_WHILE)) + { + whileStatement(compiler); + } + else if (match(compiler, TOKEN_LEFT_BRACE)) + { + // Block statement. + pushScope(compiler); + if (finishBlock(compiler)) + { + // Block was an expression, so discard it. + emitOp(compiler, CODE_POP); + } + popScope(compiler); + } + else + { + // Expression statement. + expression(compiler); + emitOp(compiler, CODE_POP); + } +} + +// Creates a matching constructor method for an initializer with [signature] +// and [initializerSymbol]. +// +// Construction is a two-stage process in Wren that involves two separate +// methods. There is a static method that allocates a new instance of the class. +// It then invokes an initializer method on the new instance, forwarding all of +// the constructor arguments to it. +// +// The allocator method always has a fixed implementation: +// +// CODE_CONSTRUCT - Replace the class in slot 0 with a new instance of it. +// CODE_CALL - Invoke the initializer on the new instance. +// +// This creates that method and calls the initializer with [initializerSymbol]. +static void createConstructor(Compiler* compiler, Signature* signature, + int initializerSymbol) +{ + Compiler methodCompiler; + initCompiler(&methodCompiler, compiler->parser, compiler, true); + + // Allocate the instance. + emitOp(&methodCompiler, compiler->enclosingClass->isForeign + ? CODE_FOREIGN_CONSTRUCT : CODE_CONSTRUCT); + + // Run its initializer. + emitShortArg(&methodCompiler, (Code)(CODE_CALL_0 + signature->arity), + initializerSymbol); + + // Return the instance. + emitOp(&methodCompiler, CODE_RETURN); + + endCompiler(&methodCompiler, "", 0); +} + +// Loads the enclosing class onto the stack and then binds the function already +// on the stack as a method on that class. +static void defineMethod(Compiler* compiler, Variable classVariable, + bool isStatic, int methodSymbol) +{ + // Load the class. We have to do this for each method because we can't + // keep the class on top of the stack. If there are static fields, they + // will be locals above the initial variable slot for the class on the + // stack. To skip past those, we just load the class each time right before + // defining a method. + loadVariable(compiler, classVariable); + + // Define the method. + Code instruction = isStatic ? CODE_METHOD_STATIC : CODE_METHOD_INSTANCE; + emitShortArg(compiler, instruction, methodSymbol); +} + +// Declares a method in the enclosing class with [signature]. +// +// Reports an error if a method with that signature is already declared. +// Returns the symbol for the method. +static int declareMethod(Compiler* compiler, Signature* signature, + const char* name, int length) +{ + int symbol = signatureSymbol(compiler, signature); + + // See if the class has already declared method with this signature. + ClassInfo* classInfo = compiler->enclosingClass; + IntBuffer* methods = classInfo->inStatic + ? &classInfo->staticMethods : &classInfo->methods; + for (int i = 0; i < methods->count; i++) + { + if (methods->data[i] == symbol) + { + const char* staticPrefix = classInfo->inStatic ? "static " : ""; + error(compiler, "Class %s already defines a %smethod '%s'.", + &compiler->enclosingClass->name->value, staticPrefix, name); + break; + } + } + + wrenIntBufferWrite(compiler->parser->vm, methods, symbol); + return symbol; +} + +static Value consumeLiteral(Compiler* compiler, const char* message) +{ + if(match(compiler, TOKEN_FALSE)) return FALSE_VAL; + if(match(compiler, TOKEN_TRUE)) return TRUE_VAL; + if(match(compiler, TOKEN_NUMBER)) return compiler->parser->previous.value; + if(match(compiler, TOKEN_STRING)) return compiler->parser->previous.value; + if(match(compiler, TOKEN_NAME)) return compiler->parser->previous.value; + + error(compiler, message); + nextToken(compiler->parser); + return NULL_VAL; +} + +static bool matchAttribute(Compiler* compiler) { + + if(match(compiler, TOKEN_HASH)) + { + compiler->numAttributes++; + bool runtimeAccess = match(compiler, TOKEN_BANG); + if(match(compiler, TOKEN_NAME)) + { + Value group = compiler->parser->previous.value; + TokenType ahead = peek(compiler); + if(ahead == TOKEN_EQ || ahead == TOKEN_LINE) + { + Value key = group; + Value value = NULL_VAL; + if(match(compiler, TOKEN_EQ)) + { + value = consumeLiteral(compiler, "Expect a Bool, Num, String or Identifier literal for an attribute value."); + } + if(runtimeAccess) addToAttributeGroup(compiler, NULL_VAL, key, value); + } + else if(match(compiler, TOKEN_LEFT_PAREN)) + { + ignoreNewlines(compiler); + if(match(compiler, TOKEN_RIGHT_PAREN)) + { + error(compiler, "Expected attributes in group, group cannot be empty."); + } + else + { + while(peek(compiler) != TOKEN_RIGHT_PAREN) + { + consume(compiler, TOKEN_NAME, "Expect name for attribute key."); + Value key = compiler->parser->previous.value; + Value value = NULL_VAL; + if(match(compiler, TOKEN_EQ)) + { + value = consumeLiteral(compiler, "Expect a Bool, Num, String or Identifier literal for an attribute value."); + } + if(runtimeAccess) addToAttributeGroup(compiler, group, key, value); + ignoreNewlines(compiler); + if(!match(compiler, TOKEN_COMMA)) break; + ignoreNewlines(compiler); + } + + ignoreNewlines(compiler); + consume(compiler, TOKEN_RIGHT_PAREN, + "Expected ')' after grouped attributes."); + } + } + else + { + error(compiler, "Expect an equal, newline or grouping after an attribute key."); + } + } + else + { + error(compiler, "Expect an attribute definition after #."); + } + + consumeLine(compiler, "Expect newline after attribute."); + return true; + } + + return false; +} + +// Compiles a method definition inside a class body. +// +// Returns `true` if it compiled successfully, or `false` if the method couldn't +// be parsed. +static bool method(Compiler* compiler, Variable classVariable) +{ + // Parse any attributes before the method and store them + if(matchAttribute(compiler)) { + return method(compiler, classVariable); + } + + // TODO: What about foreign constructors? + bool isForeign = match(compiler, TOKEN_FOREIGN); + bool isStatic = match(compiler, TOKEN_STATIC); + compiler->enclosingClass->inStatic = isStatic; + + SignatureFn signatureFn = rules[compiler->parser->current.type].method; + nextToken(compiler->parser); + + if (signatureFn == NULL) + { + error(compiler, "Expect method definition."); + return false; + } + + // Build the method signature. + Signature signature = signatureFromToken(compiler, SIG_GETTER); + compiler->enclosingClass->signature = &signature; + + Compiler methodCompiler; + initCompiler(&methodCompiler, compiler->parser, compiler, true); + + // Compile the method signature. + signatureFn(&methodCompiler, &signature); + + methodCompiler.isInitializer = signature.type == SIG_INITIALIZER; + + if (isStatic && signature.type == SIG_INITIALIZER) + { + error(compiler, "A constructor cannot be static."); + } + + // Include the full signature in debug messages in stack traces. + char fullSignature[MAX_METHOD_SIGNATURE]; + int length; + signatureToString(&signature, fullSignature, &length); + + // Copy any attributes the compiler collected into the enclosing class + copyMethodAttributes(compiler, isForeign, isStatic, fullSignature, length); + + // Check for duplicate methods. Doesn't matter that it's already been + // defined, error will discard bytecode anyway. + // Check if the method table already contains this symbol + int methodSymbol = declareMethod(compiler, &signature, fullSignature, length); + + if (isForeign) + { + // Define a constant for the signature. + emitConstant(compiler, wrenNewStringLength(compiler->parser->vm, + fullSignature, length)); + + // We don't need the function we started compiling in the parameter list + // any more. + methodCompiler.parser->vm->compiler = methodCompiler.parent; + } + else + { + consume(compiler, TOKEN_LEFT_BRACE, "Expect '{' to begin method body."); + finishBody(&methodCompiler); + endCompiler(&methodCompiler, fullSignature, length); + } + + // Define the method. For a constructor, this defines the instance + // initializer method. + defineMethod(compiler, classVariable, isStatic, methodSymbol); + + if (signature.type == SIG_INITIALIZER) + { + // Also define a matching constructor method on the metaclass. + signature.type = SIG_METHOD; + int constructorSymbol = signatureSymbol(compiler, &signature); + + createConstructor(compiler, &signature, methodSymbol); + defineMethod(compiler, classVariable, true, constructorSymbol); + } + + return true; +} + +// Compiles a class definition. Assumes the "class" token has already been +// consumed (along with a possibly preceding "foreign" token). +static void classDefinition(Compiler* compiler, bool isForeign) +{ + // Create a variable to store the class in. + Variable classVariable; + classVariable.scope = compiler->scopeDepth == -1 ? SCOPE_MODULE : SCOPE_LOCAL; + classVariable.index = declareNamedVariable(compiler); + + // Create shared class name value + Value classNameString = wrenNewStringLength(compiler->parser->vm, + compiler->parser->previous.start, compiler->parser->previous.length); + + // Create class name string to track method duplicates + ObjString* className = AS_STRING(classNameString); + + // Make a string constant for the name. + emitConstant(compiler, classNameString); + + // Load the superclass (if there is one). + if (match(compiler, TOKEN_IS)) + { + parsePrecedence(compiler, PREC_CALL); + } + else + { + // Implicitly inherit from Object. + loadCoreVariable(compiler, "Object"); + } + + // Store a placeholder for the number of fields argument. We don't know the + // count until we've compiled all the methods to see which fields are used. + int numFieldsInstruction = -1; + if (isForeign) + { + emitOp(compiler, CODE_FOREIGN_CLASS); + } + else + { + numFieldsInstruction = emitByteArg(compiler, CODE_CLASS, 255); + } + + // Store it in its name. + defineVariable(compiler, classVariable.index); + + // Push a local variable scope. Static fields in a class body are hoisted out + // into local variables declared in this scope. Methods that use them will + // have upvalues referencing them. + pushScope(compiler); + + ClassInfo classInfo; + classInfo.isForeign = isForeign; + classInfo.name = className; + + // Allocate attribute maps if necessary. + // A method will allocate the methods one if needed + classInfo.classAttributes = compiler->attributes->count > 0 + ? wrenNewMap(compiler->parser->vm) + : NULL; + classInfo.methodAttributes = NULL; + // Copy any existing attributes into the class + copyAttributes(compiler, classInfo.classAttributes); + + // Set up a symbol table for the class's fields. We'll initially compile + // them to slots starting at zero. When the method is bound to the class, the + // bytecode will be adjusted by [wrenBindMethod] to take inherited fields + // into account. + wrenSymbolTableInit(&classInfo.fields); + + // Set up symbol buffers to track duplicate static and instance methods. + wrenIntBufferInit(&classInfo.methods); + wrenIntBufferInit(&classInfo.staticMethods); + compiler->enclosingClass = &classInfo; + + // Compile the method definitions. + consume(compiler, TOKEN_LEFT_BRACE, "Expect '{' after class declaration."); + matchLine(compiler); + + while (!match(compiler, TOKEN_RIGHT_BRACE)) + { + if (!method(compiler, classVariable)) break; + + // Don't require a newline after the last definition. + if (match(compiler, TOKEN_RIGHT_BRACE)) break; + + consumeLine(compiler, "Expect newline after definition in class."); + } + + // If any attributes are present, + // instantiate a ClassAttributes instance for the class + // and send it over to CODE_END_CLASS + bool hasAttr = classInfo.classAttributes != NULL || + classInfo.methodAttributes != NULL; + if(hasAttr) { + emitClassAttributes(compiler, &classInfo); + loadVariable(compiler, classVariable); + // At the moment, we don't have other uses for CODE_END_CLASS, + // so we put it inside this condition. Later, we can always + // emit it and use it as needed. + emitOp(compiler, CODE_END_CLASS); + } + + // Update the class with the number of fields. + if (!isForeign) + { + compiler->fn->code.data[numFieldsInstruction] = + (uint8_t)classInfo.fields.count; + } + + // Clear symbol tables for tracking field and method names. + wrenSymbolTableClear(compiler->parser->vm, &classInfo.fields); + wrenIntBufferClear(compiler->parser->vm, &classInfo.methods); + wrenIntBufferClear(compiler->parser->vm, &classInfo.staticMethods); + compiler->enclosingClass = NULL; + popScope(compiler); +} + +// Compiles an "import" statement. +// +// An import compiles to a series of instructions. Given: +// +// import "foo" for Bar, Baz +// +// We compile a single IMPORT_MODULE "foo" instruction to load the module +// itself. When that finishes executing the imported module, it leaves the +// ObjModule in vm->lastModule. Then, for Bar and Baz, we: +// +// * Declare a variable in the current scope with that name. +// * Emit an IMPORT_VARIABLE instruction to load the variable's value from the +// other module. +// * Compile the code to store that value in the variable in this scope. +static void import(Compiler* compiler) +{ + ignoreNewlines(compiler); + consume(compiler, TOKEN_STRING, "Expect a string after 'import'."); + int moduleConstant = addConstant(compiler, compiler->parser->previous.value); + + // Load the module. + emitShortArg(compiler, CODE_IMPORT_MODULE, moduleConstant); + + // Discard the unused result value from calling the module body's closure. + emitOp(compiler, CODE_POP); + + // The for clause is optional. + if (!match(compiler, TOKEN_FOR)) return; + + // Compile the comma-separated list of variables to import. + do + { + ignoreNewlines(compiler); + + consume(compiler, TOKEN_NAME, "Expect variable name."); + + // We need to hold onto the source variable, + // in order to reference it in the import later + Token sourceVariableToken = compiler->parser->previous; + + // Define a string constant for the original variable name. + int sourceVariableConstant = addConstant(compiler, + wrenNewStringLength(compiler->parser->vm, + sourceVariableToken.start, + sourceVariableToken.length)); + + // Store the symbol we care about for the variable + int slot = -1; + if(match(compiler, TOKEN_AS)) + { + //import "module" for Source as Dest + //Use 'Dest' as the name by declaring a new variable for it. + //This parses a name after the 'as' and defines it. + slot = declareNamedVariable(compiler); + } + else + { + //import "module" for Source + //Uses 'Source' as the name directly + slot = declareVariable(compiler, &sourceVariableToken); + } + + // Load the variable from the other module. + emitShortArg(compiler, CODE_IMPORT_VARIABLE, sourceVariableConstant); + + // Store the result in the variable here. + defineVariable(compiler, slot); + } while (match(compiler, TOKEN_COMMA)); +} + +// Compiles a "var" variable definition statement. +static void variableDefinition(Compiler* compiler) +{ + // Grab its name, but don't declare it yet. A (local) variable shouldn't be + // in scope in its own initializer. + consume(compiler, TOKEN_NAME, "Expect variable name."); + Token nameToken = compiler->parser->previous; + + // Compile the initializer. + if (match(compiler, TOKEN_EQ)) + { + ignoreNewlines(compiler); + expression(compiler); + } + else + { + // Default initialize it to null. + null(compiler, false); + } + + // Now put it in scope. + int symbol = declareVariable(compiler, &nameToken); + defineVariable(compiler, symbol); +} + +// Compiles a "definition". These are the statements that bind new variables. +// They can only appear at the top level of a block and are prohibited in places +// like the non-curly body of an if or while. +void definition(Compiler* compiler) +{ + if(matchAttribute(compiler)) { + definition(compiler); + return; + } + + if (match(compiler, TOKEN_CLASS)) + { + classDefinition(compiler, false); + return; + } + else if (match(compiler, TOKEN_FOREIGN)) + { + consume(compiler, TOKEN_CLASS, "Expect 'class' after 'foreign'."); + classDefinition(compiler, true); + return; + } + + disallowAttributes(compiler); + + if (match(compiler, TOKEN_IMPORT)) + { + import(compiler); + } + else if (match(compiler, TOKEN_VAR)) + { + variableDefinition(compiler); + } + else + { + statement(compiler); + } +} + +ObjFn* wrenCompile(WrenVM* vm, ObjModule* module, const char* source, + bool isExpression, bool printErrors) +{ + // Skip the UTF-8 BOM if there is one. + if (strncmp(source, "\xEF\xBB\xBF", 3) == 0) source += 3; + + Parser parser; + parser.vm = vm; + parser.module = module; + parser.source = source; + + parser.tokenStart = source; + parser.currentChar = source; + parser.currentLine = 1; + parser.numParens = 0; + + // Zero-init the current token. This will get copied to previous when + // nextToken() is called below. + parser.next.type = TOKEN_ERROR; + parser.next.start = source; + parser.next.length = 0; + parser.next.line = 0; + parser.next.value = UNDEFINED_VAL; + + parser.printErrors = printErrors; + parser.hasError = false; + + // Read the first token into next + nextToken(&parser); + // Copy next -> current + nextToken(&parser); + + int numExistingVariables = module->variables.count; + + Compiler compiler; + initCompiler(&compiler, &parser, NULL, false); + ignoreNewlines(&compiler); + + if (isExpression) + { + expression(&compiler); + consume(&compiler, TOKEN_EOF, "Expect end of expression."); + } + else + { + while (!match(&compiler, TOKEN_EOF)) + { + definition(&compiler); + + // If there is no newline, it must be the end of file on the same line. + if (!matchLine(&compiler)) + { + consume(&compiler, TOKEN_EOF, "Expect end of file."); + break; + } + } + + emitOp(&compiler, CODE_END_MODULE); + } + + emitOp(&compiler, CODE_RETURN); + + // See if there are any implicitly declared module-level variables that never + // got an explicit definition. They will have values that are numbers + // indicating the line where the variable was first used. + for (int i = numExistingVariables; i < parser.module->variables.count; i++) + { + if (IS_NUM(parser.module->variables.data[i])) + { + // Synthesize a token for the original use site. + parser.previous.type = TOKEN_NAME; + parser.previous.start = parser.module->variableNames.data[i]->value; + parser.previous.length = parser.module->variableNames.data[i]->length; + parser.previous.line = (int)AS_NUM(parser.module->variables.data[i]); + error(&compiler, "Variable is used but not defined."); + } + } + + return endCompiler(&compiler, "(script)", 8); +} + +void wrenBindMethodCode(ObjClass* classObj, ObjFn* fn) +{ + int ip = 0; + for (;;) + { + Code instruction = (Code)fn->code.data[ip]; + switch (instruction) + { + case CODE_LOAD_FIELD: + case CODE_STORE_FIELD: + case CODE_LOAD_FIELD_THIS: + case CODE_STORE_FIELD_THIS: + // Shift this class's fields down past the inherited ones. We don't + // check for overflow here because we'll see if the number of fields + // overflows when the subclass is created. + fn->code.data[ip + 1] += classObj->superclass->numFields; + break; + + case CODE_SUPER_0: + case CODE_SUPER_1: + case CODE_SUPER_2: + case CODE_SUPER_3: + case CODE_SUPER_4: + case CODE_SUPER_5: + case CODE_SUPER_6: + case CODE_SUPER_7: + case CODE_SUPER_8: + case CODE_SUPER_9: + case CODE_SUPER_10: + case CODE_SUPER_11: + case CODE_SUPER_12: + case CODE_SUPER_13: + case CODE_SUPER_14: + case CODE_SUPER_15: + case CODE_SUPER_16: + { + // Fill in the constant slot with a reference to the superclass. + int constant = (fn->code.data[ip + 3] << 8) | fn->code.data[ip + 4]; + fn->constants.data[constant] = OBJ_VAL(classObj->superclass); + break; + } + + case CODE_CLOSURE: + { + // Bind the nested closure too. + int constant = (fn->code.data[ip + 1] << 8) | fn->code.data[ip + 2]; + wrenBindMethodCode(classObj, AS_FN(fn->constants.data[constant])); + break; + } + + case CODE_END: + return; + + default: + // Other instructions are unaffected, so just skip over them. + break; + } + ip += 1 + getByteCountForArguments(fn->code.data, fn->constants.data, ip); + } +} + +void wrenMarkCompiler(WrenVM* vm, Compiler* compiler) +{ + wrenGrayValue(vm, compiler->parser->current.value); + wrenGrayValue(vm, compiler->parser->previous.value); + wrenGrayValue(vm, compiler->parser->next.value); + + // Walk up the parent chain to mark the outer compilers too. The VM only + // tracks the innermost one. + do + { + wrenGrayObj(vm, (Obj*)compiler->fn); + wrenGrayObj(vm, (Obj*)compiler->constants); + wrenGrayObj(vm, (Obj*)compiler->attributes); + + if (compiler->enclosingClass != NULL) + { + wrenBlackenSymbolTable(vm, &compiler->enclosingClass->fields); + + if(compiler->enclosingClass->methodAttributes != NULL) + { + wrenGrayObj(vm, (Obj*)compiler->enclosingClass->methodAttributes); + } + if(compiler->enclosingClass->classAttributes != NULL) + { + wrenGrayObj(vm, (Obj*)compiler->enclosingClass->classAttributes); + } + } + + compiler = compiler->parent; + } + while (compiler != NULL); +} + +// Helpers for Attributes + +// Throw an error if any attributes were found preceding, +// and clear the attributes so the error doesn't keep happening. +static void disallowAttributes(Compiler* compiler) +{ + if (compiler->numAttributes > 0) + { + error(compiler, "Attributes can only specified before a class or a method"); + wrenMapClear(compiler->parser->vm, compiler->attributes); + compiler->numAttributes = 0; + } +} + +// Add an attribute to a given group in the compiler attribues map +static void addToAttributeGroup(Compiler* compiler, + Value group, Value key, Value value) +{ + WrenVM* vm = compiler->parser->vm; + + if(IS_OBJ(group)) wrenPushRoot(vm, AS_OBJ(group)); + if(IS_OBJ(key)) wrenPushRoot(vm, AS_OBJ(key)); + if(IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + Value groupMapValue = wrenMapGet(compiler->attributes, group); + if(IS_UNDEFINED(groupMapValue)) + { + groupMapValue = OBJ_VAL(wrenNewMap(vm)); + wrenMapSet(vm, compiler->attributes, group, groupMapValue); + } + + //we store them as a map per so we can maintain duplicate keys + //group = { key:[value, ...], } + ObjMap* groupMap = AS_MAP(groupMapValue); + + //var keyItems = group[key] + //if(!keyItems) keyItems = group[key] = [] + Value keyItemsValue = wrenMapGet(groupMap, key); + if(IS_UNDEFINED(keyItemsValue)) + { + keyItemsValue = OBJ_VAL(wrenNewList(vm, 0)); + wrenMapSet(vm, groupMap, key, keyItemsValue); + } + + //keyItems.add(value) + ObjList* keyItems = AS_LIST(keyItemsValue); + wrenValueBufferWrite(vm, &keyItems->elements, value); + + if(IS_OBJ(group)) wrenPopRoot(vm); + if(IS_OBJ(key)) wrenPopRoot(vm); + if(IS_OBJ(value)) wrenPopRoot(vm); +} + + +// Emit the attributes in the give map onto the stack +static void emitAttributes(Compiler* compiler, ObjMap* attributes) +{ + // Instantiate a new map for the attributes + loadCoreVariable(compiler, "Map"); + callMethod(compiler, 0, "new()", 5); + + // The attributes are stored as group = { key:[value, value, ...] } + // so our first level is the group map + for(uint32_t groupIdx = 0; groupIdx < attributes->capacity; groupIdx++) + { + const MapEntry* groupEntry = &attributes->entries[groupIdx]; + if(IS_UNDEFINED(groupEntry->key)) continue; + //group key + emitConstant(compiler, groupEntry->key); + + //group value is gonna be a map + loadCoreVariable(compiler, "Map"); + callMethod(compiler, 0, "new()", 5); + + ObjMap* groupItems = AS_MAP(groupEntry->value); + for(uint32_t itemIdx = 0; itemIdx < groupItems->capacity; itemIdx++) + { + const MapEntry* itemEntry = &groupItems->entries[itemIdx]; + if(IS_UNDEFINED(itemEntry->key)) continue; + + emitConstant(compiler, itemEntry->key); + // Attribute key value, key = [] + loadCoreVariable(compiler, "List"); + callMethod(compiler, 0, "new()", 5); + // Add the items to the key list + ObjList* items = AS_LIST(itemEntry->value); + for(int itemIdx = 0; itemIdx < items->elements.count; ++itemIdx) + { + emitConstant(compiler, items->elements.data[itemIdx]); + callMethod(compiler, 1, "addCore_(_)", 11); + } + // Add the list to the map + callMethod(compiler, 2, "addCore_(_,_)", 13); + } + + // Add the key/value to the map + callMethod(compiler, 2, "addCore_(_,_)", 13); + } + +} + +// Methods are stored as method <-> attributes, so we have to have +// an indirection to resolve for methods +static void emitAttributeMethods(Compiler* compiler, ObjMap* attributes) +{ + // Instantiate a new map for the attributes + loadCoreVariable(compiler, "Map"); + callMethod(compiler, 0, "new()", 5); + + for(uint32_t methodIdx = 0; methodIdx < attributes->capacity; methodIdx++) + { + const MapEntry* methodEntry = &attributes->entries[methodIdx]; + if(IS_UNDEFINED(methodEntry->key)) continue; + emitConstant(compiler, methodEntry->key); + ObjMap* attributeMap = AS_MAP(methodEntry->value); + emitAttributes(compiler, attributeMap); + callMethod(compiler, 2, "addCore_(_,_)", 13); + } +} + + +// Emit the final ClassAttributes that exists at runtime +static void emitClassAttributes(Compiler* compiler, ClassInfo* classInfo) +{ + loadCoreVariable(compiler, "ClassAttributes"); + + classInfo->classAttributes + ? emitAttributes(compiler, classInfo->classAttributes) + : null(compiler, false); + + classInfo->methodAttributes + ? emitAttributeMethods(compiler, classInfo->methodAttributes) + : null(compiler, false); + + callMethod(compiler, 2, "new(_,_)", 8); +} + +// Copy the current attributes stored in the compiler into a destination map +// This also resets the counter, since the intent is to consume the attributes +static void copyAttributes(Compiler* compiler, ObjMap* into) +{ + compiler->numAttributes = 0; + + if(compiler->attributes->count == 0) return; + if(into == NULL) return; + + WrenVM* vm = compiler->parser->vm; + + // Note we copy the actual values as is since we'll take ownership + // and clear the original map + for(uint32_t attrIdx = 0; attrIdx < compiler->attributes->capacity; attrIdx++) + { + const MapEntry* attrEntry = &compiler->attributes->entries[attrIdx]; + if(IS_UNDEFINED(attrEntry->key)) continue; + wrenMapSet(vm, into, attrEntry->key, attrEntry->value); + } + + wrenMapClear(vm, compiler->attributes); +} + +// Copy the current attributes stored in the compiler into the method specific +// attributes for the current enclosingClass. +// This also resets the counter, since the intent is to consume the attributes +static void copyMethodAttributes(Compiler* compiler, bool isForeign, + bool isStatic, const char* fullSignature, int32_t length) +{ + compiler->numAttributes = 0; + + if(compiler->attributes->count == 0) return; + + WrenVM* vm = compiler->parser->vm; + + // Make a map for this method to copy into + ObjMap* methodAttr = wrenNewMap(vm); + wrenPushRoot(vm, (Obj*)methodAttr); + copyAttributes(compiler, methodAttr); + + // Include 'foreign static ' in front as needed + int32_t fullLength = length; + if(isForeign) fullLength += 8; + if(isStatic) fullLength += 7; + char fullSignatureWithPrefix[MAX_METHOD_SIGNATURE + 8 + 7]; + const char* foreignPrefix = isForeign ? "foreign " : ""; + const char* staticPrefix = isStatic ? "static " : ""; + sprintf(fullSignatureWithPrefix, "%s%s%.*s", foreignPrefix, staticPrefix, + length, fullSignature); + fullSignatureWithPrefix[fullLength] = '\0'; + + if(compiler->enclosingClass->methodAttributes == NULL) { + compiler->enclosingClass->methodAttributes = wrenNewMap(vm); + } + + // Store the method attributes in the class map + Value key = wrenNewStringLength(vm, fullSignatureWithPrefix, fullLength); + wrenMapSet(vm, compiler->enclosingClass->methodAttributes, key, OBJ_VAL(methodAttr)); + + wrenPopRoot(vm); +} +// End file "wren_compiler.c" +// Begin file "wren_primitive.c" +// Begin file "wren_primitive.h" +#ifndef wren_primitive_h +#define wren_primitive_h + + +// Binds a primitive method named [name] (in Wren) implemented using C function +// [fn] to `ObjClass` [cls]. +#define PRIMITIVE(cls, name, function) \ + do \ + { \ + int symbol = wrenSymbolTableEnsure(vm, \ + &vm->methodNames, name, strlen(name)); \ + Method method; \ + method.type = METHOD_PRIMITIVE; \ + method.as.primitive = prim_##function; \ + wrenBindMethod(vm, cls, symbol, method); \ + } while (false) + +// Binds a primitive method named [name] (in Wren) implemented using C function +// [fn] to `ObjClass` [cls], but as a FN call. +#define FUNCTION_CALL(cls, name, function) \ + do \ + { \ + int symbol = wrenSymbolTableEnsure(vm, \ + &vm->methodNames, name, strlen(name)); \ + Method method; \ + method.type = METHOD_FUNCTION_CALL; \ + method.as.primitive = prim_##function; \ + wrenBindMethod(vm, cls, symbol, method); \ + } while (false) + +// Defines a primitive method whose C function name is [name]. This abstracts +// the actual type signature of a primitive function and makes it clear which C +// functions are invoked as primitives. +#define DEF_PRIMITIVE(name) \ + static bool prim_##name(WrenVM* vm, Value* args) + +#define RETURN_VAL(value) \ + do \ + { \ + args[0] = value; \ + return true; \ + } while (false) + +#define RETURN_OBJ(obj) RETURN_VAL(OBJ_VAL(obj)) +#define RETURN_BOOL(value) RETURN_VAL(BOOL_VAL(value)) +#define RETURN_FALSE RETURN_VAL(FALSE_VAL) +#define RETURN_NULL RETURN_VAL(NULL_VAL) +#define RETURN_NUM(value) RETURN_VAL(NUM_VAL(value)) +#define RETURN_TRUE RETURN_VAL(TRUE_VAL) + +#define RETURN_ERROR(msg) \ + do \ + { \ + vm->fiber->error = wrenNewStringLength(vm, msg, sizeof(msg) - 1); \ + return false; \ + } while (false) + +#define RETURN_ERROR_FMT(...) \ + do \ + { \ + vm->fiber->error = wrenStringFormat(vm, __VA_ARGS__); \ + return false; \ + } while (false) + +// Validates that the given [arg] is a function. Returns true if it is. If not, +// reports an error and returns false. +bool validateFn(WrenVM* vm, Value arg, const char* argName); + +// Validates that the given [arg] is a Num. Returns true if it is. If not, +// reports an error and returns false. +bool validateNum(WrenVM* vm, Value arg, const char* argName); + +// Validates that [value] is an integer. Returns true if it is. If not, reports +// an error and returns false. +bool validateIntValue(WrenVM* vm, double value, const char* argName); + +// Validates that the given [arg] is an integer. Returns true if it is. If not, +// reports an error and returns false. +bool validateInt(WrenVM* vm, Value arg, const char* argName); + +// Validates that [arg] is a valid object for use as a map key. Returns true if +// it is. If not, reports an error and returns false. +bool validateKey(WrenVM* vm, Value arg); + +// Validates that the argument at [argIndex] is an integer within `[0, count)`. +// Also allows negative indices which map backwards from the end. Returns the +// valid positive index value. If invalid, reports an error and returns +// `UINT32_MAX`. +uint32_t validateIndex(WrenVM* vm, Value arg, uint32_t count, + const char* argName); + +// Validates that the given [arg] is a String. Returns true if it is. If not, +// reports an error and returns false. +bool validateString(WrenVM* vm, Value arg, const char* argName); + +// Given a [range] and the [length] of the object being operated on, determines +// the series of elements that should be chosen from the underlying object. +// Handles ranges that count backwards from the end as well as negative ranges. +// +// Returns the index from which the range should start or `UINT32_MAX` if the +// range is invalid. After calling, [length] will be updated with the number of +// elements in the resulting sequence. [step] will be direction that the range +// is going: `1` if the range is increasing from the start index or `-1` if the +// range is decreasing. +uint32_t calculateRange(WrenVM* vm, ObjRange* range, uint32_t* length, + int* step); + +#endif +// End file "wren_primitive.h" + +#include + +// Validates that [value] is an integer within `[0, count)`. Also allows +// negative indices which map backwards from the end. Returns the valid positive +// index value. If invalid, reports an error and returns `UINT32_MAX`. +static uint32_t validateIndexValue(WrenVM* vm, uint32_t count, double value, + const char* argName) +{ + if (!validateIntValue(vm, value, argName)) return UINT32_MAX; + + // Negative indices count from the end. + if (value < 0) value = count + value; + + // Check bounds. + if (value >= 0 && value < count) return (uint32_t)value; + + vm->fiber->error = wrenStringFormat(vm, "$ out of bounds.", argName); + return UINT32_MAX; +} + +bool validateFn(WrenVM* vm, Value arg, const char* argName) +{ + if (IS_CLOSURE(arg)) return true; + RETURN_ERROR_FMT("$ must be a function.", argName); +} + +bool validateNum(WrenVM* vm, Value arg, const char* argName) +{ + if (IS_NUM(arg)) return true; + RETURN_ERROR_FMT("$ must be a number.", argName); +} + +bool validateIntValue(WrenVM* vm, double value, const char* argName) +{ + if (trunc(value) == value) return true; + RETURN_ERROR_FMT("$ must be an integer.", argName); +} + +bool validateInt(WrenVM* vm, Value arg, const char* argName) +{ + // Make sure it's a number first. + if (!validateNum(vm, arg, argName)) return false; + return validateIntValue(vm, AS_NUM(arg), argName); +} + +bool validateKey(WrenVM* vm, Value arg) +{ + if (wrenMapIsValidKey(arg)) return true; + + RETURN_ERROR("Key must be a value type."); +} + +uint32_t validateIndex(WrenVM* vm, Value arg, uint32_t count, + const char* argName) +{ + if (!validateNum(vm, arg, argName)) return UINT32_MAX; + return validateIndexValue(vm, count, AS_NUM(arg), argName); +} + +bool validateString(WrenVM* vm, Value arg, const char* argName) +{ + if (IS_STRING(arg)) return true; + RETURN_ERROR_FMT("$ must be a string.", argName); +} + +uint32_t calculateRange(WrenVM* vm, ObjRange* range, uint32_t* length, + int* step) +{ + *step = 0; + + // Edge case: an empty range is allowed at the end of a sequence. This way, + // list[0..-1] and list[0...list.count] can be used to copy a list even when + // empty. + if (range->from == *length && + range->to == (range->isInclusive ? -1.0 : (double)*length)) + { + *length = 0; + return 0; + } + + uint32_t from = validateIndexValue(vm, *length, range->from, "Range start"); + if (from == UINT32_MAX) return UINT32_MAX; + + // Bounds check the end manually to handle exclusive ranges. + double value = range->to; + if (!validateIntValue(vm, value, "Range end")) return UINT32_MAX; + + // Negative indices count from the end. + if (value < 0) value = *length + value; + + // Convert the exclusive range to an inclusive one. + if (!range->isInclusive) + { + // An exclusive range with the same start and end points is empty. + if (value == from) + { + *length = 0; + return from; + } + + // Shift the endpoint to make it inclusive, handling both increasing and + // decreasing ranges. + value += value >= from ? -1 : 1; + } + + // Check bounds. + if (value < 0 || value >= *length) + { + vm->fiber->error = CONST_STRING(vm, "Range end out of bounds."); + return UINT32_MAX; + } + + uint32_t to = (uint32_t)value; + *length = abs((int)(from - to)) + 1; + *step = from < to ? 1 : -1; + return from; +} +// End file "wren_primitive.c" +// Begin file "wren_core.c" +#include +#include +#include +#include +#include +#include + +// Begin file "wren_core.h" +#ifndef wren_core_h +#define wren_core_h + + +// This module defines the built-in classes and their primitives methods that +// are implemented directly in C code. Some languages try to implement as much +// of the core module itself in the primary language instead of in the host +// language. +// +// With Wren, we try to do as much of it in C as possible. Primitive methods +// are always faster than code written in Wren, and it minimizes startup time +// since we don't have to parse, compile, and execute Wren code. +// +// There is one limitation, though. Methods written in C cannot call Wren ones. +// They can only be the top of the callstack, and immediately return. This +// makes it difficult to have primitive methods that rely on polymorphic +// behavior. For example, `System.print` should call `toString` on its argument, +// including user-defined `toString` methods on user-defined classes. + +void wrenInitializeCore(WrenVM* vm); + +#endif +// End file "wren_core.h" + +// Begin file "wren_core.wren.inc" +// Generated automatically from src/vm/wren_core.wren. Do not edit. +static const char* coreModuleSource = +"class Bool {}\n" +"class Fiber {}\n" +"class Fn {}\n" +"class Null {}\n" +"class Num {}\n" +"\n" +"class Sequence {\n" +" all(f) {\n" +" var result = true\n" +" for (element in this) {\n" +" result = f.call(element)\n" +" if (!result) return result\n" +" }\n" +" return result\n" +" }\n" +"\n" +" any(f) {\n" +" var result = false\n" +" for (element in this) {\n" +" result = f.call(element)\n" +" if (result) return result\n" +" }\n" +" return result\n" +" }\n" +"\n" +" contains(element) {\n" +" for (item in this) {\n" +" if (element == item) return true\n" +" }\n" +" return false\n" +" }\n" +"\n" +" count {\n" +" var result = 0\n" +" for (element in this) {\n" +" result = result + 1\n" +" }\n" +" return result\n" +" }\n" +"\n" +" count(f) {\n" +" var result = 0\n" +" for (element in this) {\n" +" if (f.call(element)) result = result + 1\n" +" }\n" +" return result\n" +" }\n" +"\n" +" each(f) {\n" +" for (element in this) {\n" +" f.call(element)\n" +" }\n" +" }\n" +"\n" +" isEmpty { iterate(null) ? false : true }\n" +"\n" +" map(transformation) { MapSequence.new(this, transformation) }\n" +"\n" +" skip(count) {\n" +" if (!(count is Num) || !count.isInteger || count < 0) {\n" +" Fiber.abort(\"Count must be a non-negative integer.\")\n" +" }\n" +"\n" +" return SkipSequence.new(this, count)\n" +" }\n" +"\n" +" take(count) {\n" +" if (!(count is Num) || !count.isInteger || count < 0) {\n" +" Fiber.abort(\"Count must be a non-negative integer.\")\n" +" }\n" +"\n" +" return TakeSequence.new(this, count)\n" +" }\n" +"\n" +" where(predicate) { WhereSequence.new(this, predicate) }\n" +"\n" +" reduce(acc, f) {\n" +" for (element in this) {\n" +" acc = f.call(acc, element)\n" +" }\n" +" return acc\n" +" }\n" +"\n" +" reduce(f) {\n" +" var iter = iterate(null)\n" +" if (!iter) Fiber.abort(\"Can't reduce an empty sequence.\")\n" +"\n" +" // Seed with the first element.\n" +" var result = iteratorValue(iter)\n" +" while (iter = iterate(iter)) {\n" +" result = f.call(result, iteratorValue(iter))\n" +" }\n" +"\n" +" return result\n" +" }\n" +"\n" +" join() { join(\"\") }\n" +"\n" +" join(sep) {\n" +" var first = true\n" +" var result = \"\"\n" +"\n" +" for (element in this) {\n" +" if (!first) result = result + sep\n" +" first = false\n" +" result = result + element.toString\n" +" }\n" +"\n" +" return result\n" +" }\n" +"\n" +" toList {\n" +" var result = List.new()\n" +" for (element in this) {\n" +" result.add(element)\n" +" }\n" +" return result\n" +" }\n" +"}\n" +"\n" +"class MapSequence is Sequence {\n" +" construct new(sequence, fn) {\n" +" _sequence = sequence\n" +" _fn = fn\n" +" }\n" +"\n" +" iterate(iterator) { _sequence.iterate(iterator) }\n" +" iteratorValue(iterator) { _fn.call(_sequence.iteratorValue(iterator)) }\n" +"}\n" +"\n" +"class SkipSequence is Sequence {\n" +" construct new(sequence, count) {\n" +" _sequence = sequence\n" +" _count = count\n" +" }\n" +"\n" +" iterate(iterator) {\n" +" if (iterator) {\n" +" return _sequence.iterate(iterator)\n" +" } else {\n" +" iterator = _sequence.iterate(iterator)\n" +" var count = _count\n" +" while (count > 0 && iterator) {\n" +" iterator = _sequence.iterate(iterator)\n" +" count = count - 1\n" +" }\n" +" return iterator\n" +" }\n" +" }\n" +"\n" +" iteratorValue(iterator) { _sequence.iteratorValue(iterator) }\n" +"}\n" +"\n" +"class TakeSequence is Sequence {\n" +" construct new(sequence, count) {\n" +" _sequence = sequence\n" +" _count = count\n" +" }\n" +"\n" +" iterate(iterator) {\n" +" if (!iterator) _taken = 1 else _taken = _taken + 1\n" +" return _taken > _count ? null : _sequence.iterate(iterator)\n" +" }\n" +"\n" +" iteratorValue(iterator) { _sequence.iteratorValue(iterator) }\n" +"}\n" +"\n" +"class WhereSequence is Sequence {\n" +" construct new(sequence, fn) {\n" +" _sequence = sequence\n" +" _fn = fn\n" +" }\n" +"\n" +" iterate(iterator) {\n" +" while (iterator = _sequence.iterate(iterator)) {\n" +" if (_fn.call(_sequence.iteratorValue(iterator))) break\n" +" }\n" +" return iterator\n" +" }\n" +"\n" +" iteratorValue(iterator) { _sequence.iteratorValue(iterator) }\n" +"}\n" +"\n" +"class String is Sequence {\n" +" bytes { StringByteSequence.new(this) }\n" +" codePoints { StringCodePointSequence.new(this) }\n" +"\n" +" split(delimiter) {\n" +" if (!(delimiter is String) || delimiter.isEmpty) {\n" +" Fiber.abort(\"Delimiter must be a non-empty string.\")\n" +" }\n" +"\n" +" var result = []\n" +"\n" +" var last = 0\n" +" var index = 0\n" +"\n" +" var delimSize = delimiter.byteCount_\n" +" var size = byteCount_\n" +"\n" +" while (last < size && (index = indexOf(delimiter, last)) != -1) {\n" +" result.add(this[last...index])\n" +" last = index + delimSize\n" +" }\n" +"\n" +" if (last < size) {\n" +" result.add(this[last..-1])\n" +" } else {\n" +" result.add(\"\")\n" +" }\n" +" return result\n" +" }\n" +"\n" +" replace(from, to) {\n" +" if (!(from is String) || from.isEmpty) {\n" +" Fiber.abort(\"From must be a non-empty string.\")\n" +" } else if (!(to is String)) {\n" +" Fiber.abort(\"To must be a string.\")\n" +" }\n" +"\n" +" var result = \"\"\n" +"\n" +" var last = 0\n" +" var index = 0\n" +"\n" +" var fromSize = from.byteCount_\n" +" var size = byteCount_\n" +"\n" +" while (last < size && (index = indexOf(from, last)) != -1) {\n" +" result = result + this[last...index] + to\n" +" last = index + fromSize\n" +" }\n" +"\n" +" if (last < size) result = result + this[last..-1]\n" +"\n" +" return result\n" +" }\n" +"\n" +" trim() { trim_(\"\\t\\r\\n \", true, true) }\n" +" trim(chars) { trim_(chars, true, true) }\n" +" trimEnd() { trim_(\"\\t\\r\\n \", false, true) }\n" +" trimEnd(chars) { trim_(chars, false, true) }\n" +" trimStart() { trim_(\"\\t\\r\\n \", true, false) }\n" +" trimStart(chars) { trim_(chars, true, false) }\n" +"\n" +" trim_(chars, trimStart, trimEnd) {\n" +" if (!(chars is String)) {\n" +" Fiber.abort(\"Characters must be a string.\")\n" +" }\n" +"\n" +" var codePoints = chars.codePoints.toList\n" +"\n" +" var start\n" +" if (trimStart) {\n" +" while (start = iterate(start)) {\n" +" if (!codePoints.contains(codePointAt_(start))) break\n" +" }\n" +"\n" +" if (start == false) return \"\"\n" +" } else {\n" +" start = 0\n" +" }\n" +"\n" +" var end\n" +" if (trimEnd) {\n" +" end = byteCount_ - 1\n" +" while (end >= start) {\n" +" var codePoint = codePointAt_(end)\n" +" if (codePoint != -1 && !codePoints.contains(codePoint)) break\n" +" end = end - 1\n" +" }\n" +"\n" +" if (end < start) return \"\"\n" +" } else {\n" +" end = -1\n" +" }\n" +"\n" +" return this[start..end]\n" +" }\n" +"\n" +" *(count) {\n" +" if (!(count is Num) || !count.isInteger || count < 0) {\n" +" Fiber.abort(\"Count must be a non-negative integer.\")\n" +" }\n" +"\n" +" var result = \"\"\n" +" for (i in 0...count) {\n" +" result = result + this\n" +" }\n" +" return result\n" +" }\n" +"}\n" +"\n" +"class StringByteSequence is Sequence {\n" +" construct new(string) {\n" +" _string = string\n" +" }\n" +"\n" +" [index] { _string.byteAt_(index) }\n" +" iterate(iterator) { _string.iterateByte_(iterator) }\n" +" iteratorValue(iterator) { _string.byteAt_(iterator) }\n" +"\n" +" count { _string.byteCount_ }\n" +"}\n" +"\n" +"class StringCodePointSequence is Sequence {\n" +" construct new(string) {\n" +" _string = string\n" +" }\n" +"\n" +" [index] { _string.codePointAt_(index) }\n" +" iterate(iterator) { _string.iterate(iterator) }\n" +" iteratorValue(iterator) { _string.codePointAt_(iterator) }\n" +"\n" +" count { _string.count }\n" +"}\n" +"\n" +"class List is Sequence {\n" +" addAll(other) {\n" +" for (element in other) {\n" +" add(element)\n" +" }\n" +" return other\n" +" }\n" +"\n" +" sort() { sort {|low, high| low < high } }\n" +"\n" +" sort(comparer) {\n" +" if (!(comparer is Fn)) {\n" +" Fiber.abort(\"Comparer must be a function.\")\n" +" }\n" +" quicksort_(0, count - 1, comparer)\n" +" return this\n" +" }\n" +"\n" +" quicksort_(low, high, comparer) {\n" +" if (low < high) {\n" +" var p = partition_(low, high, comparer)\n" +" quicksort_(low, p - 1, comparer)\n" +" quicksort_(p + 1, high, comparer)\n" +" }\n" +" }\n" +"\n" +" partition_(low, high, comparer) {\n" +" var p = this[high]\n" +" var i = low - 1\n" +" for (j in low..(high-1)) {\n" +" if (comparer.call(this[j], p)) { \n" +" i = i + 1\n" +" var t = this[i]\n" +" this[i] = this[j]\n" +" this[j] = t\n" +" }\n" +" }\n" +" var t = this[i+1]\n" +" this[i+1] = this[high]\n" +" this[high] = t\n" +" return i+1\n" +" }\n" +"\n" +" toString { \"[%(join(\", \"))]\" }\n" +"\n" +" +(other) {\n" +" var result = this[0..-1]\n" +" for (element in other) {\n" +" result.add(element)\n" +" }\n" +" return result\n" +" }\n" +"\n" +" *(count) {\n" +" if (!(count is Num) || !count.isInteger || count < 0) {\n" +" Fiber.abort(\"Count must be a non-negative integer.\")\n" +" }\n" +"\n" +" var result = []\n" +" for (i in 0...count) {\n" +" result.addAll(this)\n" +" }\n" +" return result\n" +" }\n" +"}\n" +"\n" +"class Map is Sequence {\n" +" keys { MapKeySequence.new(this) }\n" +" values { MapValueSequence.new(this) }\n" +"\n" +" toString {\n" +" var first = true\n" +" var result = \"{\"\n" +"\n" +" for (key in keys) {\n" +" if (!first) result = result + \", \"\n" +" first = false\n" +" result = result + \"%(key): %(this[key])\"\n" +" }\n" +"\n" +" return result + \"}\"\n" +" }\n" +"\n" +" iteratorValue(iterator) {\n" +" return MapEntry.new(\n" +" keyIteratorValue_(iterator),\n" +" valueIteratorValue_(iterator))\n" +" }\n" +"}\n" +"\n" +"class MapEntry {\n" +" construct new(key, value) {\n" +" _key = key\n" +" _value = value\n" +" }\n" +"\n" +" key { _key }\n" +" value { _value }\n" +"\n" +" toString { \"%(_key):%(_value)\" }\n" +"}\n" +"\n" +"class MapKeySequence is Sequence {\n" +" construct new(map) {\n" +" _map = map\n" +" }\n" +"\n" +" iterate(n) { _map.iterate(n) }\n" +" iteratorValue(iterator) { _map.keyIteratorValue_(iterator) }\n" +"}\n" +"\n" +"class MapValueSequence is Sequence {\n" +" construct new(map) {\n" +" _map = map\n" +" }\n" +"\n" +" iterate(n) { _map.iterate(n) }\n" +" iteratorValue(iterator) { _map.valueIteratorValue_(iterator) }\n" +"}\n" +"\n" +"class Range is Sequence {}\n" +"\n" +"class System {\n" +" static print() {\n" +" writeString_(\"\\n\")\n" +" }\n" +"\n" +" static print(obj) {\n" +" writeObject_(obj)\n" +" writeString_(\"\\n\")\n" +" return obj\n" +" }\n" +"\n" +" static printAll(sequence) {\n" +" for (object in sequence) writeObject_(object)\n" +" writeString_(\"\\n\")\n" +" }\n" +"\n" +" static write(obj) {\n" +" writeObject_(obj)\n" +" return obj\n" +" }\n" +"\n" +" static writeAll(sequence) {\n" +" for (object in sequence) writeObject_(object)\n" +" }\n" +"\n" +" static writeObject_(obj) {\n" +" var string = obj.toString\n" +" if (string is String) {\n" +" writeString_(string)\n" +" } else {\n" +" writeString_(\"[invalid toString]\")\n" +" }\n" +" }\n" +"}\n" +"\n" +"class ClassAttributes {\n" +" self { _attributes }\n" +" methods { _methods }\n" +" construct new(attributes, methods) {\n" +" _attributes = attributes\n" +" _methods = methods\n" +" }\n" +" toString { \"attributes:%(_attributes) methods:%(_methods)\" }\n" +"}\n"; +// End file "wren_core.wren.inc" + +DEF_PRIMITIVE(bool_not) +{ + RETURN_BOOL(!AS_BOOL(args[0])); +} + +DEF_PRIMITIVE(bool_toString) +{ + if (AS_BOOL(args[0])) + { + RETURN_VAL(CONST_STRING(vm, "true")); + } + else + { + RETURN_VAL(CONST_STRING(vm, "false")); + } +} + +DEF_PRIMITIVE(class_name) +{ + RETURN_OBJ(AS_CLASS(args[0])->name); +} + +DEF_PRIMITIVE(class_supertype) +{ + ObjClass* classObj = AS_CLASS(args[0]); + + // Object has no superclass. + if (classObj->superclass == NULL) RETURN_NULL; + + RETURN_OBJ(classObj->superclass); +} + +DEF_PRIMITIVE(class_toString) +{ + RETURN_OBJ(AS_CLASS(args[0])->name); +} + +DEF_PRIMITIVE(class_attributes) +{ + RETURN_VAL(AS_CLASS(args[0])->attributes); +} + +DEF_PRIMITIVE(fiber_new) +{ + if (!validateFn(vm, args[1], "Argument")) return false; + + ObjClosure* closure = AS_CLOSURE(args[1]); + if (closure->fn->arity > 1) + { + RETURN_ERROR("Function cannot take more than one parameter."); + } + + RETURN_OBJ(wrenNewFiber(vm, closure)); +} + +DEF_PRIMITIVE(fiber_abort) +{ + vm->fiber->error = args[1]; + + // If the error is explicitly null, it's not really an abort. + return IS_NULL(args[1]); +} + +// Transfer execution to [fiber] coming from the current fiber whose stack has +// [args]. +// +// [isCall] is true if [fiber] is being called and not transferred. +// +// [hasValue] is true if a value in [args] is being passed to the new fiber. +// Otherwise, `null` is implicitly being passed. +static bool runFiber(WrenVM* vm, ObjFiber* fiber, Value* args, bool isCall, + bool hasValue, const char* verb) +{ + + if (wrenHasError(fiber)) + { + RETURN_ERROR_FMT("Cannot $ an aborted fiber.", verb); + } + + if (isCall) + { + // You can't call a called fiber, but you can transfer directly to it, + // which is why this check is gated on `isCall`. This way, after resuming a + // suspended fiber, it will run and then return to the fiber that called it + // and so on. + if (fiber->caller != NULL) RETURN_ERROR("Fiber has already been called."); + + if (fiber->state == FIBER_ROOT) RETURN_ERROR("Cannot call root fiber."); + + // Remember who ran it. + fiber->caller = vm->fiber; + } + + if (fiber->numFrames == 0) + { + RETURN_ERROR_FMT("Cannot $ a finished fiber.", verb); + } + + // When the calling fiber resumes, we'll store the result of the call in its + // stack. If the call has two arguments (the fiber and the value), we only + // need one slot for the result, so discard the other slot now. + if (hasValue) vm->fiber->stackTop--; + + if (fiber->numFrames == 1 && + fiber->frames[0].ip == fiber->frames[0].closure->fn->code.data) + { + // The fiber is being started for the first time. If its function takes a + // parameter, bind an argument to it. + if (fiber->frames[0].closure->fn->arity == 1) + { + fiber->stackTop[0] = hasValue ? args[1] : NULL_VAL; + fiber->stackTop++; + } + } + else + { + // The fiber is being resumed, make yield() or transfer() return the result. + fiber->stackTop[-1] = hasValue ? args[1] : NULL_VAL; + } + + vm->fiber = fiber; + return false; +} + +DEF_PRIMITIVE(fiber_call) +{ + return runFiber(vm, AS_FIBER(args[0]), args, true, false, "call"); +} + +DEF_PRIMITIVE(fiber_call1) +{ + return runFiber(vm, AS_FIBER(args[0]), args, true, true, "call"); +} + +DEF_PRIMITIVE(fiber_current) +{ + RETURN_OBJ(vm->fiber); +} + +DEF_PRIMITIVE(fiber_error) +{ + RETURN_VAL(AS_FIBER(args[0])->error); +} + +DEF_PRIMITIVE(fiber_isDone) +{ + ObjFiber* runFiber = AS_FIBER(args[0]); + RETURN_BOOL(runFiber->numFrames == 0 || wrenHasError(runFiber)); +} + +DEF_PRIMITIVE(fiber_suspend) +{ + // Switching to a null fiber tells the interpreter to stop and exit. + vm->fiber = NULL; + vm->apiStack = NULL; + return false; +} + +DEF_PRIMITIVE(fiber_transfer) +{ + return runFiber(vm, AS_FIBER(args[0]), args, false, false, "transfer to"); +} + +DEF_PRIMITIVE(fiber_transfer1) +{ + return runFiber(vm, AS_FIBER(args[0]), args, false, true, "transfer to"); +} + +DEF_PRIMITIVE(fiber_transferError) +{ + runFiber(vm, AS_FIBER(args[0]), args, false, true, "transfer to"); + vm->fiber->error = args[1]; + return false; +} + +DEF_PRIMITIVE(fiber_try) +{ + runFiber(vm, AS_FIBER(args[0]), args, true, false, "try"); + + // If we're switching to a valid fiber to try, remember that we're trying it. + if (!wrenHasError(vm->fiber)) vm->fiber->state = FIBER_TRY; + return false; +} + +DEF_PRIMITIVE(fiber_try1) +{ + runFiber(vm, AS_FIBER(args[0]), args, true, true, "try"); + + // If we're switching to a valid fiber to try, remember that we're trying it. + if (!wrenHasError(vm->fiber)) vm->fiber->state = FIBER_TRY; + return false; +} + +DEF_PRIMITIVE(fiber_yield) +{ + ObjFiber* current = vm->fiber; + vm->fiber = current->caller; + + // Unhook this fiber from the one that called it. + current->caller = NULL; + current->state = FIBER_OTHER; + + if (vm->fiber != NULL) + { + // Make the caller's run method return null. + vm->fiber->stackTop[-1] = NULL_VAL; + } + + return false; +} + +DEF_PRIMITIVE(fiber_yield1) +{ + ObjFiber* current = vm->fiber; + vm->fiber = current->caller; + + // Unhook this fiber from the one that called it. + current->caller = NULL; + current->state = FIBER_OTHER; + + if (vm->fiber != NULL) + { + // Make the caller's run method return the argument passed to yield. + vm->fiber->stackTop[-1] = args[1]; + + // When the yielding fiber resumes, we'll store the result of the yield + // call in its stack. Since Fiber.yield(value) has two arguments (the Fiber + // class and the value) and we only need one slot for the result, discard + // the other slot now. + current->stackTop--; + } + + return false; +} + +DEF_PRIMITIVE(fn_new) +{ + if (!validateFn(vm, args[1], "Argument")) return false; + + // The block argument is already a function, so just return it. + RETURN_VAL(args[1]); +} + +DEF_PRIMITIVE(fn_arity) +{ + RETURN_NUM(AS_CLOSURE(args[0])->fn->arity); +} + +static void call_fn(WrenVM* vm, Value* args, int numArgs) +{ + // +1 to include the function itself. + wrenCallFunction(vm, vm->fiber, AS_CLOSURE(args[0]), numArgs + 1); +} + +#define DEF_FN_CALL(numArgs) \ + DEF_PRIMITIVE(fn_call##numArgs) \ + { \ + call_fn(vm, args, numArgs); \ + return false; \ + } + +DEF_FN_CALL(0) +DEF_FN_CALL(1) +DEF_FN_CALL(2) +DEF_FN_CALL(3) +DEF_FN_CALL(4) +DEF_FN_CALL(5) +DEF_FN_CALL(6) +DEF_FN_CALL(7) +DEF_FN_CALL(8) +DEF_FN_CALL(9) +DEF_FN_CALL(10) +DEF_FN_CALL(11) +DEF_FN_CALL(12) +DEF_FN_CALL(13) +DEF_FN_CALL(14) +DEF_FN_CALL(15) +DEF_FN_CALL(16) + +DEF_PRIMITIVE(fn_toString) +{ + RETURN_VAL(CONST_STRING(vm, "")); +} + +// Creates a new list of size args[1], with all elements initialized to args[2]. +DEF_PRIMITIVE(list_filled) +{ + if (!validateInt(vm, args[1], "Size")) return false; + if (AS_NUM(args[1]) < 0) RETURN_ERROR("Size cannot be negative."); + + uint32_t size = (uint32_t)AS_NUM(args[1]); + ObjList* list = wrenNewList(vm, size); + + for (uint32_t i = 0; i < size; i++) + { + list->elements.data[i] = args[2]; + } + + RETURN_OBJ(list); +} + +DEF_PRIMITIVE(list_new) +{ + RETURN_OBJ(wrenNewList(vm, 0)); +} + +DEF_PRIMITIVE(list_add) +{ + wrenValueBufferWrite(vm, &AS_LIST(args[0])->elements, args[1]); + RETURN_VAL(args[1]); +} + +// Adds an element to the list and then returns the list itself. This is called +// by the compiler when compiling list literals instead of using add() to +// minimize stack churn. +DEF_PRIMITIVE(list_addCore) +{ + wrenValueBufferWrite(vm, &AS_LIST(args[0])->elements, args[1]); + + // Return the list. + RETURN_VAL(args[0]); +} + +DEF_PRIMITIVE(list_clear) +{ + wrenValueBufferClear(vm, &AS_LIST(args[0])->elements); + RETURN_NULL; +} + +DEF_PRIMITIVE(list_count) +{ + RETURN_NUM(AS_LIST(args[0])->elements.count); +} + +DEF_PRIMITIVE(list_insert) +{ + ObjList* list = AS_LIST(args[0]); + + // count + 1 here so you can "insert" at the very end. + uint32_t index = validateIndex(vm, args[1], list->elements.count + 1, + "Index"); + if (index == UINT32_MAX) return false; + + wrenListInsert(vm, list, args[2], index); + RETURN_VAL(args[2]); +} + +DEF_PRIMITIVE(list_iterate) +{ + ObjList* list = AS_LIST(args[0]); + + // If we're starting the iteration, return the first index. + if (IS_NULL(args[1])) + { + if (list->elements.count == 0) RETURN_FALSE; + RETURN_NUM(0); + } + + if (!validateInt(vm, args[1], "Iterator")) return false; + + // Stop if we're out of bounds. + double index = AS_NUM(args[1]); + if (index < 0 || index >= list->elements.count - 1) RETURN_FALSE; + + // Otherwise, move to the next index. + RETURN_NUM(index + 1); +} + +DEF_PRIMITIVE(list_iteratorValue) +{ + ObjList* list = AS_LIST(args[0]); + uint32_t index = validateIndex(vm, args[1], list->elements.count, "Iterator"); + if (index == UINT32_MAX) return false; + + RETURN_VAL(list->elements.data[index]); +} + +DEF_PRIMITIVE(list_removeAt) +{ + ObjList* list = AS_LIST(args[0]); + uint32_t index = validateIndex(vm, args[1], list->elements.count, "Index"); + if (index == UINT32_MAX) return false; + + RETURN_VAL(wrenListRemoveAt(vm, list, index)); +} + +DEF_PRIMITIVE(list_removeValue) { + ObjList* list = AS_LIST(args[0]); + int index = wrenListIndexOf(vm, list, args[1]); + if(index == -1) RETURN_NULL; + RETURN_VAL(wrenListRemoveAt(vm, list, index)); +} + +DEF_PRIMITIVE(list_indexOf) +{ + ObjList* list = AS_LIST(args[0]); + RETURN_NUM(wrenListIndexOf(vm, list, args[1])); +} + +DEF_PRIMITIVE(list_swap) +{ + ObjList* list = AS_LIST(args[0]); + uint32_t indexA = validateIndex(vm, args[1], list->elements.count, "Index 0"); + if (indexA == UINT32_MAX) return false; + uint32_t indexB = validateIndex(vm, args[2], list->elements.count, "Index 1"); + if (indexB == UINT32_MAX) return false; + + Value a = list->elements.data[indexA]; + list->elements.data[indexA] = list->elements.data[indexB]; + list->elements.data[indexB] = a; + + RETURN_NULL; +} + +DEF_PRIMITIVE(list_subscript) +{ + ObjList* list = AS_LIST(args[0]); + + if (IS_NUM(args[1])) + { + uint32_t index = validateIndex(vm, args[1], list->elements.count, + "Subscript"); + if (index == UINT32_MAX) return false; + + RETURN_VAL(list->elements.data[index]); + } + + if (!IS_RANGE(args[1])) + { + RETURN_ERROR("Subscript must be a number or a range."); + } + + int step; + uint32_t count = list->elements.count; + uint32_t start = calculateRange(vm, AS_RANGE(args[1]), &count, &step); + if (start == UINT32_MAX) return false; + + ObjList* result = wrenNewList(vm, count); + for (uint32_t i = 0; i < count; i++) + { + result->elements.data[i] = list->elements.data[start + i * step]; + } + + RETURN_OBJ(result); +} + +DEF_PRIMITIVE(list_subscriptSetter) +{ + ObjList* list = AS_LIST(args[0]); + uint32_t index = validateIndex(vm, args[1], list->elements.count, + "Subscript"); + if (index == UINT32_MAX) return false; + + list->elements.data[index] = args[2]; + RETURN_VAL(args[2]); +} + +DEF_PRIMITIVE(map_new) +{ + RETURN_OBJ(wrenNewMap(vm)); +} + +DEF_PRIMITIVE(map_subscript) +{ + if (!validateKey(vm, args[1])) return false; + + ObjMap* map = AS_MAP(args[0]); + Value value = wrenMapGet(map, args[1]); + if (IS_UNDEFINED(value)) RETURN_NULL; + + RETURN_VAL(value); +} + +DEF_PRIMITIVE(map_subscriptSetter) +{ + if (!validateKey(vm, args[1])) return false; + + wrenMapSet(vm, AS_MAP(args[0]), args[1], args[2]); + RETURN_VAL(args[2]); +} + +// Adds an entry to the map and then returns the map itself. This is called by +// the compiler when compiling map literals instead of using [_]=(_) to +// minimize stack churn. +DEF_PRIMITIVE(map_addCore) +{ + if (!validateKey(vm, args[1])) return false; + + wrenMapSet(vm, AS_MAP(args[0]), args[1], args[2]); + + // Return the map itself. + RETURN_VAL(args[0]); +} + +DEF_PRIMITIVE(map_clear) +{ + wrenMapClear(vm, AS_MAP(args[0])); + RETURN_NULL; +} + +DEF_PRIMITIVE(map_containsKey) +{ + if (!validateKey(vm, args[1])) return false; + + RETURN_BOOL(!IS_UNDEFINED(wrenMapGet(AS_MAP(args[0]), args[1]))); +} + +DEF_PRIMITIVE(map_count) +{ + RETURN_NUM(AS_MAP(args[0])->count); +} + +DEF_PRIMITIVE(map_iterate) +{ + ObjMap* map = AS_MAP(args[0]); + + if (map->count == 0) RETURN_FALSE; + + // If we're starting the iteration, start at the first used entry. + uint32_t index = 0; + + // Otherwise, start one past the last entry we stopped at. + if (!IS_NULL(args[1])) + { + if (!validateInt(vm, args[1], "Iterator")) return false; + + if (AS_NUM(args[1]) < 0) RETURN_FALSE; + index = (uint32_t)AS_NUM(args[1]); + + if (index >= map->capacity) RETURN_FALSE; + + // Advance the iterator. + index++; + } + + // Find a used entry, if any. + for (; index < map->capacity; index++) + { + if (!IS_UNDEFINED(map->entries[index].key)) RETURN_NUM(index); + } + + // If we get here, walked all of the entries. + RETURN_FALSE; +} + +DEF_PRIMITIVE(map_remove) +{ + if (!validateKey(vm, args[1])) return false; + + RETURN_VAL(wrenMapRemoveKey(vm, AS_MAP(args[0]), args[1])); +} + +DEF_PRIMITIVE(map_keyIteratorValue) +{ + ObjMap* map = AS_MAP(args[0]); + uint32_t index = validateIndex(vm, args[1], map->capacity, "Iterator"); + if (index == UINT32_MAX) return false; + + MapEntry* entry = &map->entries[index]; + if (IS_UNDEFINED(entry->key)) + { + RETURN_ERROR("Invalid map iterator."); + } + + RETURN_VAL(entry->key); +} + +DEF_PRIMITIVE(map_valueIteratorValue) +{ + ObjMap* map = AS_MAP(args[0]); + uint32_t index = validateIndex(vm, args[1], map->capacity, "Iterator"); + if (index == UINT32_MAX) return false; + + MapEntry* entry = &map->entries[index]; + if (IS_UNDEFINED(entry->key)) + { + RETURN_ERROR("Invalid map iterator."); + } + + RETURN_VAL(entry->value); +} + +DEF_PRIMITIVE(null_not) +{ + RETURN_VAL(TRUE_VAL); +} + +DEF_PRIMITIVE(null_toString) +{ + RETURN_VAL(CONST_STRING(vm, "null")); +} + +DEF_PRIMITIVE(num_fromString) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[1]); + + // Corner case: Can't parse an empty string. + if (string->length == 0) RETURN_NULL; + + errno = 0; + char* end; + double number = strtod(string->value, &end); + + // Skip past any trailing whitespace. + while (*end != '\0' && isspace((unsigned char)*end)) end++; + + if (errno == ERANGE) RETURN_ERROR("Number literal is too large."); + + // We must have consumed the entire string. Otherwise, it contains non-number + // characters and we can't parse it. + if (end < string->value + string->length) RETURN_NULL; + + RETURN_NUM(number); +} + +// Defines a primitive on Num that calls infix [op] and returns [type]. +#define DEF_NUM_CONSTANT(name, value) \ + DEF_PRIMITIVE(num_##name) \ + { \ + RETURN_NUM(value); \ + } + +DEF_NUM_CONSTANT(infinity, INFINITY) +DEF_NUM_CONSTANT(nan, WREN_DOUBLE_NAN) +DEF_NUM_CONSTANT(pi, 3.14159265358979323846264338327950288) +DEF_NUM_CONSTANT(tau, 6.28318530717958647692528676655900577) + +DEF_NUM_CONSTANT(largest, DBL_MAX) +DEF_NUM_CONSTANT(smallest, DBL_MIN) + +DEF_NUM_CONSTANT(maxSafeInteger, 9007199254740991.0) +DEF_NUM_CONSTANT(minSafeInteger, -9007199254740991.0) + +// Defines a primitive on Num that calls infix [op] and returns [type]. +#define DEF_NUM_INFIX(name, op, type) \ + DEF_PRIMITIVE(num_##name) \ + { \ + if (!validateNum(vm, args[1], "Right operand")) return false; \ + RETURN_##type(AS_NUM(args[0]) op AS_NUM(args[1])); \ + } + +DEF_NUM_INFIX(minus, -, NUM) +DEF_NUM_INFIX(plus, +, NUM) +DEF_NUM_INFIX(multiply, *, NUM) +DEF_NUM_INFIX(divide, /, NUM) +DEF_NUM_INFIX(lt, <, BOOL) +DEF_NUM_INFIX(gt, >, BOOL) +DEF_NUM_INFIX(lte, <=, BOOL) +DEF_NUM_INFIX(gte, >=, BOOL) + +// Defines a primitive on Num that call infix bitwise [op]. +#define DEF_NUM_BITWISE(name, op) \ + DEF_PRIMITIVE(num_bitwise##name) \ + { \ + if (!validateNum(vm, args[1], "Right operand")) return false; \ + uint32_t left = (uint32_t)AS_NUM(args[0]); \ + uint32_t right = (uint32_t)AS_NUM(args[1]); \ + RETURN_NUM(left op right); \ + } + +DEF_NUM_BITWISE(And, &) +DEF_NUM_BITWISE(Or, |) +DEF_NUM_BITWISE(Xor, ^) +DEF_NUM_BITWISE(LeftShift, <<) +DEF_NUM_BITWISE(RightShift, >>) + +// Defines a primitive method on Num that returns the result of [fn]. +#define DEF_NUM_FN(name, fn) \ + DEF_PRIMITIVE(num_##name) \ + { \ + RETURN_NUM(fn(AS_NUM(args[0]))); \ + } + +DEF_NUM_FN(abs, fabs) +DEF_NUM_FN(acos, acos) +DEF_NUM_FN(asin, asin) +DEF_NUM_FN(atan, atan) +DEF_NUM_FN(cbrt, cbrt) +DEF_NUM_FN(ceil, ceil) +DEF_NUM_FN(cos, cos) +DEF_NUM_FN(floor, floor) +DEF_NUM_FN(negate, -) +DEF_NUM_FN(round, round) +DEF_NUM_FN(sin, sin) +DEF_NUM_FN(sqrt, sqrt) +DEF_NUM_FN(tan, tan) +DEF_NUM_FN(log, log) +DEF_NUM_FN(log2, log2) +DEF_NUM_FN(exp, exp) + +DEF_PRIMITIVE(num_mod) +{ + if (!validateNum(vm, args[1], "Right operand")) return false; + RETURN_NUM(fmod(AS_NUM(args[0]), AS_NUM(args[1]))); +} + +DEF_PRIMITIVE(num_eqeq) +{ + if (!IS_NUM(args[1])) RETURN_FALSE; + RETURN_BOOL(AS_NUM(args[0]) == AS_NUM(args[1])); +} + +DEF_PRIMITIVE(num_bangeq) +{ + if (!IS_NUM(args[1])) RETURN_TRUE; + RETURN_BOOL(AS_NUM(args[0]) != AS_NUM(args[1])); +} + +DEF_PRIMITIVE(num_bitwiseNot) +{ + // Bitwise operators always work on 32-bit unsigned ints. + RETURN_NUM(~(uint32_t)AS_NUM(args[0])); +} + +DEF_PRIMITIVE(num_dotDot) +{ + if (!validateNum(vm, args[1], "Right hand side of range")) return false; + + double from = AS_NUM(args[0]); + double to = AS_NUM(args[1]); + RETURN_VAL(wrenNewRange(vm, from, to, true)); +} + +DEF_PRIMITIVE(num_dotDotDot) +{ + if (!validateNum(vm, args[1], "Right hand side of range")) return false; + + double from = AS_NUM(args[0]); + double to = AS_NUM(args[1]); + RETURN_VAL(wrenNewRange(vm, from, to, false)); +} + +DEF_PRIMITIVE(num_atan2) +{ + if (!validateNum(vm, args[1], "x value")) return false; + + RETURN_NUM(atan2(AS_NUM(args[0]), AS_NUM(args[1]))); +} + +DEF_PRIMITIVE(num_min) +{ + if (!validateNum(vm, args[1], "Other value")) return false; + + double value = AS_NUM(args[0]); + double other = AS_NUM(args[1]); + RETURN_NUM(value <= other ? value : other); +} + +DEF_PRIMITIVE(num_max) +{ + if (!validateNum(vm, args[1], "Other value")) return false; + + double value = AS_NUM(args[0]); + double other = AS_NUM(args[1]); + RETURN_NUM(value > other ? value : other); +} + +DEF_PRIMITIVE(num_clamp) +{ + if (!validateNum(vm, args[1], "Min value")) return false; + if (!validateNum(vm, args[2], "Max value")) return false; + + double value = AS_NUM(args[0]); + double min = AS_NUM(args[1]); + double max = AS_NUM(args[2]); + double result = (value < min) ? min : ((value > max) ? max : value); + RETURN_NUM(result); +} + +DEF_PRIMITIVE(num_pow) +{ + if (!validateNum(vm, args[1], "Power value")) return false; + + RETURN_NUM(pow(AS_NUM(args[0]), AS_NUM(args[1]))); +} + +DEF_PRIMITIVE(num_fraction) +{ + double unused; + RETURN_NUM(modf(AS_NUM(args[0]) , &unused)); +} + +DEF_PRIMITIVE(num_isInfinity) +{ + RETURN_BOOL(isinf(AS_NUM(args[0]))); +} + +DEF_PRIMITIVE(num_isInteger) +{ + double value = AS_NUM(args[0]); + if (isnan(value) || isinf(value)) RETURN_FALSE; + RETURN_BOOL(trunc(value) == value); +} + +DEF_PRIMITIVE(num_isNan) +{ + RETURN_BOOL(isnan(AS_NUM(args[0]))); +} + +DEF_PRIMITIVE(num_sign) +{ + double value = AS_NUM(args[0]); + if (value > 0) + { + RETURN_NUM(1); + } + else if (value < 0) + { + RETURN_NUM(-1); + } + else + { + RETURN_NUM(0); + } +} + +DEF_PRIMITIVE(num_toString) +{ + RETURN_VAL(wrenNumToString(vm, AS_NUM(args[0]))); +} + +DEF_PRIMITIVE(num_truncate) +{ + double integer; + modf(AS_NUM(args[0]) , &integer); + RETURN_NUM(integer); +} + +DEF_PRIMITIVE(object_same) +{ + RETURN_BOOL(wrenValuesEqual(args[1], args[2])); +} + +DEF_PRIMITIVE(object_not) +{ + RETURN_VAL(FALSE_VAL); +} + +DEF_PRIMITIVE(object_eqeq) +{ + RETURN_BOOL(wrenValuesEqual(args[0], args[1])); +} + +DEF_PRIMITIVE(object_bangeq) +{ + RETURN_BOOL(!wrenValuesEqual(args[0], args[1])); +} + +DEF_PRIMITIVE(object_is) +{ + if (!IS_CLASS(args[1])) + { + RETURN_ERROR("Right operand must be a class."); + } + + ObjClass *classObj = wrenGetClass(vm, args[0]); + ObjClass *baseClassObj = AS_CLASS(args[1]); + + // Walk the superclass chain looking for the class. + do + { + if (baseClassObj == classObj) RETURN_BOOL(true); + + classObj = classObj->superclass; + } + while (classObj != NULL); + + RETURN_BOOL(false); +} + +DEF_PRIMITIVE(object_toString) +{ + Obj* obj = AS_OBJ(args[0]); + Value name = OBJ_VAL(obj->classObj->name); + RETURN_VAL(wrenStringFormat(vm, "instance of @", name)); +} + +DEF_PRIMITIVE(object_type) +{ + RETURN_OBJ(wrenGetClass(vm, args[0])); +} + +DEF_PRIMITIVE(range_from) +{ + RETURN_NUM(AS_RANGE(args[0])->from); +} + +DEF_PRIMITIVE(range_to) +{ + RETURN_NUM(AS_RANGE(args[0])->to); +} + +DEF_PRIMITIVE(range_min) +{ + ObjRange* range = AS_RANGE(args[0]); + RETURN_NUM(fmin(range->from, range->to)); +} + +DEF_PRIMITIVE(range_max) +{ + ObjRange* range = AS_RANGE(args[0]); + RETURN_NUM(fmax(range->from, range->to)); +} + +DEF_PRIMITIVE(range_isInclusive) +{ + RETURN_BOOL(AS_RANGE(args[0])->isInclusive); +} + +DEF_PRIMITIVE(range_iterate) +{ + ObjRange* range = AS_RANGE(args[0]); + + // Special case: empty range. + if (range->from == range->to && !range->isInclusive) RETURN_FALSE; + + // Start the iteration. + if (IS_NULL(args[1])) RETURN_NUM(range->from); + + if (!validateNum(vm, args[1], "Iterator")) return false; + + double iterator = AS_NUM(args[1]); + + // Iterate towards [to] from [from]. + if (range->from < range->to) + { + iterator++; + if (iterator > range->to) RETURN_FALSE; + } + else + { + iterator--; + if (iterator < range->to) RETURN_FALSE; + } + + if (!range->isInclusive && iterator == range->to) RETURN_FALSE; + + RETURN_NUM(iterator); +} + +DEF_PRIMITIVE(range_iteratorValue) +{ + // Assume the iterator is a number so that is the value of the range. + RETURN_VAL(args[1]); +} + +DEF_PRIMITIVE(range_toString) +{ + ObjRange* range = AS_RANGE(args[0]); + + Value from = wrenNumToString(vm, range->from); + wrenPushRoot(vm, AS_OBJ(from)); + + Value to = wrenNumToString(vm, range->to); + wrenPushRoot(vm, AS_OBJ(to)); + + Value result = wrenStringFormat(vm, "@$@", from, + range->isInclusive ? ".." : "...", to); + + wrenPopRoot(vm); + wrenPopRoot(vm); + RETURN_VAL(result); +} + +DEF_PRIMITIVE(string_fromCodePoint) +{ + if (!validateInt(vm, args[1], "Code point")) return false; + + int codePoint = (int)AS_NUM(args[1]); + if (codePoint < 0) + { + RETURN_ERROR("Code point cannot be negative."); + } + else if (codePoint > 0x10ffff) + { + RETURN_ERROR("Code point cannot be greater than 0x10ffff."); + } + + RETURN_VAL(wrenStringFromCodePoint(vm, codePoint)); +} + +DEF_PRIMITIVE(string_fromByte) +{ + if (!validateInt(vm, args[1], "Byte")) return false; + int byte = (int) AS_NUM(args[1]); + if (byte < 0) + { + RETURN_ERROR("Byte cannot be negative."); + } + else if (byte > 0xff) + { + RETURN_ERROR("Byte cannot be greater than 0xff."); + } + RETURN_VAL(wrenStringFromByte(vm, (uint8_t) byte)); +} + +DEF_PRIMITIVE(string_byteAt) +{ + ObjString* string = AS_STRING(args[0]); + + uint32_t index = validateIndex(vm, args[1], string->length, "Index"); + if (index == UINT32_MAX) return false; + + RETURN_NUM((uint8_t)string->value[index]); +} + +DEF_PRIMITIVE(string_byteCount) +{ + RETURN_NUM(AS_STRING(args[0])->length); +} + +DEF_PRIMITIVE(string_codePointAt) +{ + ObjString* string = AS_STRING(args[0]); + + uint32_t index = validateIndex(vm, args[1], string->length, "Index"); + if (index == UINT32_MAX) return false; + + // If we are in the middle of a UTF-8 sequence, indicate that. + const uint8_t* bytes = (uint8_t*)string->value; + if ((bytes[index] & 0xc0) == 0x80) RETURN_NUM(-1); + + // Decode the UTF-8 sequence. + RETURN_NUM(wrenUtf8Decode((uint8_t*)string->value + index, + string->length - index)); +} + +DEF_PRIMITIVE(string_contains) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + + RETURN_BOOL(wrenStringFind(string, search, 0) != UINT32_MAX); +} + +DEF_PRIMITIVE(string_endsWith) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + + // Edge case: If the search string is longer then return false right away. + if (search->length > string->length) RETURN_FALSE; + + RETURN_BOOL(memcmp(string->value + string->length - search->length, + search->value, search->length) == 0); +} + +DEF_PRIMITIVE(string_indexOf1) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + + uint32_t index = wrenStringFind(string, search, 0); + RETURN_NUM(index == UINT32_MAX ? -1 : (int)index); +} + +DEF_PRIMITIVE(string_indexOf2) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + uint32_t start = validateIndex(vm, args[2], string->length, "Start"); + if (start == UINT32_MAX) return false; + + uint32_t index = wrenStringFind(string, search, start); + RETURN_NUM(index == UINT32_MAX ? -1 : (int)index); +} + +DEF_PRIMITIVE(string_iterate) +{ + ObjString* string = AS_STRING(args[0]); + + // If we're starting the iteration, return the first index. + if (IS_NULL(args[1])) + { + if (string->length == 0) RETURN_FALSE; + RETURN_NUM(0); + } + + if (!validateInt(vm, args[1], "Iterator")) return false; + + if (AS_NUM(args[1]) < 0) RETURN_FALSE; + uint32_t index = (uint32_t)AS_NUM(args[1]); + + // Advance to the beginning of the next UTF-8 sequence. + do + { + index++; + if (index >= string->length) RETURN_FALSE; + } while ((string->value[index] & 0xc0) == 0x80); + + RETURN_NUM(index); +} + +DEF_PRIMITIVE(string_iterateByte) +{ + ObjString* string = AS_STRING(args[0]); + + // If we're starting the iteration, return the first index. + if (IS_NULL(args[1])) + { + if (string->length == 0) RETURN_FALSE; + RETURN_NUM(0); + } + + if (!validateInt(vm, args[1], "Iterator")) return false; + + if (AS_NUM(args[1]) < 0) RETURN_FALSE; + uint32_t index = (uint32_t)AS_NUM(args[1]); + + // Advance to the next byte. + index++; + if (index >= string->length) RETURN_FALSE; + + RETURN_NUM(index); +} + +DEF_PRIMITIVE(string_iteratorValue) +{ + ObjString* string = AS_STRING(args[0]); + uint32_t index = validateIndex(vm, args[1], string->length, "Iterator"); + if (index == UINT32_MAX) return false; + + RETURN_VAL(wrenStringCodePointAt(vm, string, index)); +} + +DEF_PRIMITIVE(string_startsWith) +{ + if (!validateString(vm, args[1], "Argument")) return false; + + ObjString* string = AS_STRING(args[0]); + ObjString* search = AS_STRING(args[1]); + + // Edge case: If the search string is longer then return false right away. + if (search->length > string->length) RETURN_FALSE; + + RETURN_BOOL(memcmp(string->value, search->value, search->length) == 0); +} + +DEF_PRIMITIVE(string_plus) +{ + if (!validateString(vm, args[1], "Right operand")) return false; + RETURN_VAL(wrenStringFormat(vm, "@@", args[0], args[1])); +} + +DEF_PRIMITIVE(string_subscript) +{ + ObjString* string = AS_STRING(args[0]); + + if (IS_NUM(args[1])) + { + int index = validateIndex(vm, args[1], string->length, "Subscript"); + if (index == -1) return false; + + RETURN_VAL(wrenStringCodePointAt(vm, string, index)); + } + + if (!IS_RANGE(args[1])) + { + RETURN_ERROR("Subscript must be a number or a range."); + } + + int step; + uint32_t count = string->length; + int start = calculateRange(vm, AS_RANGE(args[1]), &count, &step); + if (start == -1) return false; + + RETURN_VAL(wrenNewStringFromRange(vm, string, start, count, step)); +} + +DEF_PRIMITIVE(string_toString) +{ + RETURN_VAL(args[0]); +} + +DEF_PRIMITIVE(system_clock) +{ + RETURN_NUM((double)clock() / CLOCKS_PER_SEC); +} + +DEF_PRIMITIVE(system_gc) +{ + wrenCollectGarbage(vm); + RETURN_NULL; +} + +DEF_PRIMITIVE(system_writeString) +{ + if (vm->config.writeFn != NULL) + { + vm->config.writeFn(vm, AS_CSTRING(args[1])); + } + + RETURN_VAL(args[1]); +} + +// Creates either the Object or Class class in the core module with [name]. +static ObjClass* defineClass(WrenVM* vm, ObjModule* module, const char* name) +{ + ObjString* nameString = AS_STRING(wrenNewString(vm, name)); + wrenPushRoot(vm, (Obj*)nameString); + + ObjClass* classObj = wrenNewSingleClass(vm, 0, nameString); + + wrenDefineVariable(vm, module, name, nameString->length, OBJ_VAL(classObj), NULL); + + wrenPopRoot(vm); + return classObj; +} + +void wrenInitializeCore(WrenVM* vm) +{ + ObjModule* coreModule = wrenNewModule(vm, NULL); + wrenPushRoot(vm, (Obj*)coreModule); + + // The core module's key is null in the module map. + wrenMapSet(vm, vm->modules, NULL_VAL, OBJ_VAL(coreModule)); + wrenPopRoot(vm); // coreModule. + + // Define the root Object class. This has to be done a little specially + // because it has no superclass. + vm->objectClass = defineClass(vm, coreModule, "Object"); + PRIMITIVE(vm->objectClass, "!", object_not); + PRIMITIVE(vm->objectClass, "==(_)", object_eqeq); + PRIMITIVE(vm->objectClass, "!=(_)", object_bangeq); + PRIMITIVE(vm->objectClass, "is(_)", object_is); + PRIMITIVE(vm->objectClass, "toString", object_toString); + PRIMITIVE(vm->objectClass, "type", object_type); + + // Now we can define Class, which is a subclass of Object. + vm->classClass = defineClass(vm, coreModule, "Class"); + wrenBindSuperclass(vm, vm->classClass, vm->objectClass); + PRIMITIVE(vm->classClass, "name", class_name); + PRIMITIVE(vm->classClass, "supertype", class_supertype); + PRIMITIVE(vm->classClass, "toString", class_toString); + PRIMITIVE(vm->classClass, "attributes", class_attributes); + + // Finally, we can define Object's metaclass which is a subclass of Class. + ObjClass* objectMetaclass = defineClass(vm, coreModule, "Object metaclass"); + + // Wire up the metaclass relationships now that all three classes are built. + vm->objectClass->obj.classObj = objectMetaclass; + objectMetaclass->obj.classObj = vm->classClass; + vm->classClass->obj.classObj = vm->classClass; + + // Do this after wiring up the metaclasses so objectMetaclass doesn't get + // collected. + wrenBindSuperclass(vm, objectMetaclass, vm->classClass); + + PRIMITIVE(objectMetaclass, "same(_,_)", object_same); + + // The core class diagram ends up looking like this, where single lines point + // to a class's superclass, and double lines point to its metaclass: + // + // .------------------------------------. .====. + // | .---------------. | # # + // v | v | v # + // .---------. .-------------------. .-------. # + // | Object |==>| Object metaclass |==>| Class |==" + // '---------' '-------------------' '-------' + // ^ ^ ^ ^ ^ + // | .--------------' # | # + // | | # | # + // .---------. .-------------------. # | # -. + // | Base |==>| Base metaclass |======" | # | + // '---------' '-------------------' | # | + // ^ | # | + // | .------------------' # | Example classes + // | | # | + // .---------. .-------------------. # | + // | Derived |==>| Derived metaclass |==========" | + // '---------' '-------------------' -' + + // The rest of the classes can now be defined normally. + wrenInterpret(vm, NULL, coreModuleSource); + + vm->boolClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Bool")); + PRIMITIVE(vm->boolClass, "toString", bool_toString); + PRIMITIVE(vm->boolClass, "!", bool_not); + + vm->fiberClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Fiber")); + PRIMITIVE(vm->fiberClass->obj.classObj, "new(_)", fiber_new); + PRIMITIVE(vm->fiberClass->obj.classObj, "abort(_)", fiber_abort); + PRIMITIVE(vm->fiberClass->obj.classObj, "current", fiber_current); + PRIMITIVE(vm->fiberClass->obj.classObj, "suspend()", fiber_suspend); + PRIMITIVE(vm->fiberClass->obj.classObj, "yield()", fiber_yield); + PRIMITIVE(vm->fiberClass->obj.classObj, "yield(_)", fiber_yield1); + PRIMITIVE(vm->fiberClass, "call()", fiber_call); + PRIMITIVE(vm->fiberClass, "call(_)", fiber_call1); + PRIMITIVE(vm->fiberClass, "error", fiber_error); + PRIMITIVE(vm->fiberClass, "isDone", fiber_isDone); + PRIMITIVE(vm->fiberClass, "transfer()", fiber_transfer); + PRIMITIVE(vm->fiberClass, "transfer(_)", fiber_transfer1); + PRIMITIVE(vm->fiberClass, "transferError(_)", fiber_transferError); + PRIMITIVE(vm->fiberClass, "try()", fiber_try); + PRIMITIVE(vm->fiberClass, "try(_)", fiber_try1); + + vm->fnClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Fn")); + PRIMITIVE(vm->fnClass->obj.classObj, "new(_)", fn_new); + + PRIMITIVE(vm->fnClass, "arity", fn_arity); + + FUNCTION_CALL(vm->fnClass, "call()", fn_call0); + FUNCTION_CALL(vm->fnClass, "call(_)", fn_call1); + FUNCTION_CALL(vm->fnClass, "call(_,_)", fn_call2); + FUNCTION_CALL(vm->fnClass, "call(_,_,_)", fn_call3); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_)", fn_call4); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_)", fn_call5); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_)", fn_call6); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_)", fn_call7); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_)", fn_call8); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_)", fn_call9); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_)", fn_call10); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_)", fn_call11); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_)", fn_call12); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call13); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call14); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call15); + FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call16); + + PRIMITIVE(vm->fnClass, "toString", fn_toString); + + vm->nullClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Null")); + PRIMITIVE(vm->nullClass, "!", null_not); + PRIMITIVE(vm->nullClass, "toString", null_toString); + + vm->numClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Num")); + PRIMITIVE(vm->numClass->obj.classObj, "fromString(_)", num_fromString); + PRIMITIVE(vm->numClass->obj.classObj, "infinity", num_infinity); + PRIMITIVE(vm->numClass->obj.classObj, "nan", num_nan); + PRIMITIVE(vm->numClass->obj.classObj, "pi", num_pi); + PRIMITIVE(vm->numClass->obj.classObj, "tau", num_tau); + PRIMITIVE(vm->numClass->obj.classObj, "largest", num_largest); + PRIMITIVE(vm->numClass->obj.classObj, "smallest", num_smallest); + PRIMITIVE(vm->numClass->obj.classObj, "maxSafeInteger", num_maxSafeInteger); + PRIMITIVE(vm->numClass->obj.classObj, "minSafeInteger", num_minSafeInteger); + PRIMITIVE(vm->numClass, "-(_)", num_minus); + PRIMITIVE(vm->numClass, "+(_)", num_plus); + PRIMITIVE(vm->numClass, "*(_)", num_multiply); + PRIMITIVE(vm->numClass, "/(_)", num_divide); + PRIMITIVE(vm->numClass, "<(_)", num_lt); + PRIMITIVE(vm->numClass, ">(_)", num_gt); + PRIMITIVE(vm->numClass, "<=(_)", num_lte); + PRIMITIVE(vm->numClass, ">=(_)", num_gte); + PRIMITIVE(vm->numClass, "&(_)", num_bitwiseAnd); + PRIMITIVE(vm->numClass, "|(_)", num_bitwiseOr); + PRIMITIVE(vm->numClass, "^(_)", num_bitwiseXor); + PRIMITIVE(vm->numClass, "<<(_)", num_bitwiseLeftShift); + PRIMITIVE(vm->numClass, ">>(_)", num_bitwiseRightShift); + PRIMITIVE(vm->numClass, "abs", num_abs); + PRIMITIVE(vm->numClass, "acos", num_acos); + PRIMITIVE(vm->numClass, "asin", num_asin); + PRIMITIVE(vm->numClass, "atan", num_atan); + PRIMITIVE(vm->numClass, "cbrt", num_cbrt); + PRIMITIVE(vm->numClass, "ceil", num_ceil); + PRIMITIVE(vm->numClass, "cos", num_cos); + PRIMITIVE(vm->numClass, "floor", num_floor); + PRIMITIVE(vm->numClass, "-", num_negate); + PRIMITIVE(vm->numClass, "round", num_round); + PRIMITIVE(vm->numClass, "min(_)", num_min); + PRIMITIVE(vm->numClass, "max(_)", num_max); + PRIMITIVE(vm->numClass, "clamp(_,_)", num_clamp); + PRIMITIVE(vm->numClass, "sin", num_sin); + PRIMITIVE(vm->numClass, "sqrt", num_sqrt); + PRIMITIVE(vm->numClass, "tan", num_tan); + PRIMITIVE(vm->numClass, "log", num_log); + PRIMITIVE(vm->numClass, "log2", num_log2); + PRIMITIVE(vm->numClass, "exp", num_exp); + PRIMITIVE(vm->numClass, "%(_)", num_mod); + PRIMITIVE(vm->numClass, "~", num_bitwiseNot); + PRIMITIVE(vm->numClass, "..(_)", num_dotDot); + PRIMITIVE(vm->numClass, "...(_)", num_dotDotDot); + PRIMITIVE(vm->numClass, "atan(_)", num_atan2); + PRIMITIVE(vm->numClass, "pow(_)", num_pow); + PRIMITIVE(vm->numClass, "fraction", num_fraction); + PRIMITIVE(vm->numClass, "isInfinity", num_isInfinity); + PRIMITIVE(vm->numClass, "isInteger", num_isInteger); + PRIMITIVE(vm->numClass, "isNan", num_isNan); + PRIMITIVE(vm->numClass, "sign", num_sign); + PRIMITIVE(vm->numClass, "toString", num_toString); + PRIMITIVE(vm->numClass, "truncate", num_truncate); + + // These are defined just so that 0 and -0 are equal, which is specified by + // IEEE 754 even though they have different bit representations. + PRIMITIVE(vm->numClass, "==(_)", num_eqeq); + PRIMITIVE(vm->numClass, "!=(_)", num_bangeq); + + vm->stringClass = AS_CLASS(wrenFindVariable(vm, coreModule, "String")); + PRIMITIVE(vm->stringClass->obj.classObj, "fromCodePoint(_)", string_fromCodePoint); + PRIMITIVE(vm->stringClass->obj.classObj, "fromByte(_)", string_fromByte); + PRIMITIVE(vm->stringClass, "+(_)", string_plus); + PRIMITIVE(vm->stringClass, "[_]", string_subscript); + PRIMITIVE(vm->stringClass, "byteAt_(_)", string_byteAt); + PRIMITIVE(vm->stringClass, "byteCount_", string_byteCount); + PRIMITIVE(vm->stringClass, "codePointAt_(_)", string_codePointAt); + PRIMITIVE(vm->stringClass, "contains(_)", string_contains); + PRIMITIVE(vm->stringClass, "endsWith(_)", string_endsWith); + PRIMITIVE(vm->stringClass, "indexOf(_)", string_indexOf1); + PRIMITIVE(vm->stringClass, "indexOf(_,_)", string_indexOf2); + PRIMITIVE(vm->stringClass, "iterate(_)", string_iterate); + PRIMITIVE(vm->stringClass, "iterateByte_(_)", string_iterateByte); + PRIMITIVE(vm->stringClass, "iteratorValue(_)", string_iteratorValue); + PRIMITIVE(vm->stringClass, "startsWith(_)", string_startsWith); + PRIMITIVE(vm->stringClass, "toString", string_toString); + + vm->listClass = AS_CLASS(wrenFindVariable(vm, coreModule, "List")); + PRIMITIVE(vm->listClass->obj.classObj, "filled(_,_)", list_filled); + PRIMITIVE(vm->listClass->obj.classObj, "new()", list_new); + PRIMITIVE(vm->listClass, "[_]", list_subscript); + PRIMITIVE(vm->listClass, "[_]=(_)", list_subscriptSetter); + PRIMITIVE(vm->listClass, "add(_)", list_add); + PRIMITIVE(vm->listClass, "addCore_(_)", list_addCore); + PRIMITIVE(vm->listClass, "clear()", list_clear); + PRIMITIVE(vm->listClass, "count", list_count); + PRIMITIVE(vm->listClass, "insert(_,_)", list_insert); + PRIMITIVE(vm->listClass, "iterate(_)", list_iterate); + PRIMITIVE(vm->listClass, "iteratorValue(_)", list_iteratorValue); + PRIMITIVE(vm->listClass, "removeAt(_)", list_removeAt); + PRIMITIVE(vm->listClass, "remove(_)", list_removeValue); + PRIMITIVE(vm->listClass, "indexOf(_)", list_indexOf); + PRIMITIVE(vm->listClass, "swap(_,_)", list_swap); + + vm->mapClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Map")); + PRIMITIVE(vm->mapClass->obj.classObj, "new()", map_new); + PRIMITIVE(vm->mapClass, "[_]", map_subscript); + PRIMITIVE(vm->mapClass, "[_]=(_)", map_subscriptSetter); + PRIMITIVE(vm->mapClass, "addCore_(_,_)", map_addCore); + PRIMITIVE(vm->mapClass, "clear()", map_clear); + PRIMITIVE(vm->mapClass, "containsKey(_)", map_containsKey); + PRIMITIVE(vm->mapClass, "count", map_count); + PRIMITIVE(vm->mapClass, "remove(_)", map_remove); + PRIMITIVE(vm->mapClass, "iterate(_)", map_iterate); + PRIMITIVE(vm->mapClass, "keyIteratorValue_(_)", map_keyIteratorValue); + PRIMITIVE(vm->mapClass, "valueIteratorValue_(_)", map_valueIteratorValue); + + vm->rangeClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Range")); + PRIMITIVE(vm->rangeClass, "from", range_from); + PRIMITIVE(vm->rangeClass, "to", range_to); + PRIMITIVE(vm->rangeClass, "min", range_min); + PRIMITIVE(vm->rangeClass, "max", range_max); + PRIMITIVE(vm->rangeClass, "isInclusive", range_isInclusive); + PRIMITIVE(vm->rangeClass, "iterate(_)", range_iterate); + PRIMITIVE(vm->rangeClass, "iteratorValue(_)", range_iteratorValue); + PRIMITIVE(vm->rangeClass, "toString", range_toString); + + ObjClass* systemClass = AS_CLASS(wrenFindVariable(vm, coreModule, "System")); + PRIMITIVE(systemClass->obj.classObj, "clock", system_clock); + PRIMITIVE(systemClass->obj.classObj, "gc()", system_gc); + PRIMITIVE(systemClass->obj.classObj, "writeString_(_)", system_writeString); + + // While bootstrapping the core types and running the core module, a number + // of string objects have been created, many of which were instantiated + // before stringClass was stored in the VM. Some of them *must* be created + // first -- the ObjClass for string itself has a reference to the ObjString + // for its name. + // + // These all currently have a NULL classObj pointer, so go back and assign + // them now that the string class is known. + for (Obj* obj = vm->first; obj != NULL; obj = obj->next) + { + if (obj->type == OBJ_STRING) obj->classObj = vm->stringClass; + } +} +// End file "wren_core.c" +// Begin file "wren_value.c" +#include +#include +#include +#include + + +#if WREN_DEBUG_TRACE_MEMORY +#endif + +// TODO: Tune these. +// The initial (and minimum) capacity of a non-empty list or map object. +#define MIN_CAPACITY 16 + +// The rate at which a collection's capacity grows when the size exceeds the +// current capacity. The new capacity will be determined by *multiplying* the +// old capacity by this. Growing geometrically is necessary to ensure that +// adding to a collection has O(1) amortized complexity. +#define GROW_FACTOR 2 + +// The maximum percentage of map entries that can be filled before the map is +// grown. A lower load takes more memory but reduces collisions which makes +// lookup faster. +#define MAP_LOAD_PERCENT 75 + +// The number of call frames initially allocated when a fiber is created. Making +// this smaller makes fibers use less memory (at first) but spends more time +// reallocating when the call stack grows. +#define INITIAL_CALL_FRAMES 4 + +DEFINE_BUFFER(Value, Value); +DEFINE_BUFFER(Method, Method); + +static void initObj(WrenVM* vm, Obj* obj, ObjType type, ObjClass* classObj) +{ + obj->type = type; + obj->isDark = false; + obj->classObj = classObj; + obj->next = vm->first; + vm->first = obj; +} + +ObjClass* wrenNewSingleClass(WrenVM* vm, int numFields, ObjString* name) +{ + ObjClass* classObj = ALLOCATE(vm, ObjClass); + initObj(vm, &classObj->obj, OBJ_CLASS, NULL); + classObj->superclass = NULL; + classObj->numFields = numFields; + classObj->name = name; + classObj->attributes = NULL_VAL; + + wrenPushRoot(vm, (Obj*)classObj); + wrenMethodBufferInit(&classObj->methods); + wrenPopRoot(vm); + + return classObj; +} + +void wrenBindSuperclass(WrenVM* vm, ObjClass* subclass, ObjClass* superclass) +{ + ASSERT(superclass != NULL, "Must have superclass."); + + subclass->superclass = superclass; + + // Include the superclass in the total number of fields. + if (subclass->numFields != -1) + { + subclass->numFields += superclass->numFields; + } + else + { + ASSERT(superclass->numFields == 0, + "A foreign class cannot inherit from a class with fields."); + } + + // Inherit methods from its superclass. + for (int i = 0; i < superclass->methods.count; i++) + { + wrenBindMethod(vm, subclass, i, superclass->methods.data[i]); + } +} + +ObjClass* wrenNewClass(WrenVM* vm, ObjClass* superclass, int numFields, + ObjString* name) +{ + // Create the metaclass. + Value metaclassName = wrenStringFormat(vm, "@ metaclass", OBJ_VAL(name)); + wrenPushRoot(vm, AS_OBJ(metaclassName)); + + ObjClass* metaclass = wrenNewSingleClass(vm, 0, AS_STRING(metaclassName)); + metaclass->obj.classObj = vm->classClass; + + wrenPopRoot(vm); + + // Make sure the metaclass isn't collected when we allocate the class. + wrenPushRoot(vm, (Obj*)metaclass); + + // Metaclasses always inherit Class and do not parallel the non-metaclass + // hierarchy. + wrenBindSuperclass(vm, metaclass, vm->classClass); + + ObjClass* classObj = wrenNewSingleClass(vm, numFields, name); + + // Make sure the class isn't collected while the inherited methods are being + // bound. + wrenPushRoot(vm, (Obj*)classObj); + + classObj->obj.classObj = metaclass; + wrenBindSuperclass(vm, classObj, superclass); + + wrenPopRoot(vm); + wrenPopRoot(vm); + + return classObj; +} + +void wrenBindMethod(WrenVM* vm, ObjClass* classObj, int symbol, Method method) +{ + // Make sure the buffer is big enough to contain the symbol's index. + if (symbol >= classObj->methods.count) + { + Method noMethod; + noMethod.type = METHOD_NONE; + wrenMethodBufferFill(vm, &classObj->methods, noMethod, + symbol - classObj->methods.count + 1); + } + + classObj->methods.data[symbol] = method; +} + +ObjClosure* wrenNewClosure(WrenVM* vm, ObjFn* fn) +{ + ObjClosure* closure = ALLOCATE_FLEX(vm, ObjClosure, + ObjUpvalue*, fn->numUpvalues); + initObj(vm, &closure->obj, OBJ_CLOSURE, vm->fnClass); + + closure->fn = fn; + + // Clear the upvalue array. We need to do this in case a GC is triggered + // after the closure is created but before the upvalue array is populated. + for (int i = 0; i < fn->numUpvalues; i++) closure->upvalues[i] = NULL; + + return closure; +} + +ObjFiber* wrenNewFiber(WrenVM* vm, ObjClosure* closure) +{ + // Allocate the arrays before the fiber in case it triggers a GC. + CallFrame* frames = ALLOCATE_ARRAY(vm, CallFrame, INITIAL_CALL_FRAMES); + + // Add one slot for the unused implicit receiver slot that the compiler + // assumes all functions have. + int stackCapacity = closure == NULL + ? 1 + : wrenPowerOf2Ceil(closure->fn->maxSlots + 1); + Value* stack = ALLOCATE_ARRAY(vm, Value, stackCapacity); + + ObjFiber* fiber = ALLOCATE(vm, ObjFiber); + initObj(vm, &fiber->obj, OBJ_FIBER, vm->fiberClass); + + fiber->stack = stack; + fiber->stackTop = fiber->stack; + fiber->stackCapacity = stackCapacity; + + fiber->frames = frames; + fiber->frameCapacity = INITIAL_CALL_FRAMES; + fiber->numFrames = 0; + + fiber->openUpvalues = NULL; + fiber->caller = NULL; + fiber->error = NULL_VAL; + fiber->state = FIBER_OTHER; + + if (closure != NULL) + { + // Initialize the first call frame. + wrenAppendCallFrame(vm, fiber, closure, fiber->stack); + + // The first slot always holds the closure. + fiber->stackTop[0] = OBJ_VAL(closure); + fiber->stackTop++; + } + + return fiber; +} + +void wrenEnsureStack(WrenVM* vm, ObjFiber* fiber, int needed) +{ + if (fiber->stackCapacity >= needed) return; + + int capacity = wrenPowerOf2Ceil(needed); + + Value* oldStack = fiber->stack; + fiber->stack = (Value*)wrenReallocate(vm, fiber->stack, + sizeof(Value) * fiber->stackCapacity, + sizeof(Value) * capacity); + fiber->stackCapacity = capacity; + + // If the reallocation moves the stack, then we need to recalculate every + // pointer that points into the old stack to into the same relative distance + // in the new stack. We have to be a little careful about how these are + // calculated because pointer subtraction is only well-defined within a + // single array, hence the slightly redundant-looking arithmetic below. + if (fiber->stack != oldStack) + { + // Top of the stack. + if (vm->apiStack >= oldStack && vm->apiStack <= fiber->stackTop) + { + vm->apiStack = fiber->stack + (vm->apiStack - oldStack); + } + + // Stack pointer for each call frame. + for (int i = 0; i < fiber->numFrames; i++) + { + CallFrame* frame = &fiber->frames[i]; + frame->stackStart = fiber->stack + (frame->stackStart - oldStack); + } + + // Open upvalues. + for (ObjUpvalue* upvalue = fiber->openUpvalues; + upvalue != NULL; + upvalue = upvalue->next) + { + upvalue->value = fiber->stack + (upvalue->value - oldStack); + } + + fiber->stackTop = fiber->stack + (fiber->stackTop - oldStack); + } +} + +ObjForeign* wrenNewForeign(WrenVM* vm, ObjClass* classObj, size_t size) +{ + ObjForeign* object = ALLOCATE_FLEX(vm, ObjForeign, uint8_t, size); + initObj(vm, &object->obj, OBJ_FOREIGN, classObj); + + // Zero out the bytes. + memset(object->data, 0, size); + return object; +} + +ObjFn* wrenNewFunction(WrenVM* vm, ObjModule* module, int maxSlots) +{ + FnDebug* debug = ALLOCATE(vm, FnDebug); + debug->name = NULL; + wrenIntBufferInit(&debug->sourceLines); + + ObjFn* fn = ALLOCATE(vm, ObjFn); + initObj(vm, &fn->obj, OBJ_FN, vm->fnClass); + + wrenValueBufferInit(&fn->constants); + wrenByteBufferInit(&fn->code); + fn->module = module; + fn->maxSlots = maxSlots; + fn->numUpvalues = 0; + fn->arity = 0; + fn->debug = debug; + + return fn; +} + +void wrenFunctionBindName(WrenVM* vm, ObjFn* fn, const char* name, int length) +{ + fn->debug->name = ALLOCATE_ARRAY(vm, char, length + 1); + memcpy(fn->debug->name, name, length); + fn->debug->name[length] = '\0'; +} + +Value wrenNewInstance(WrenVM* vm, ObjClass* classObj) +{ + ObjInstance* instance = ALLOCATE_FLEX(vm, ObjInstance, + Value, classObj->numFields); + initObj(vm, &instance->obj, OBJ_INSTANCE, classObj); + + // Initialize fields to null. + for (int i = 0; i < classObj->numFields; i++) + { + instance->fields[i] = NULL_VAL; + } + + return OBJ_VAL(instance); +} + +ObjList* wrenNewList(WrenVM* vm, uint32_t numElements) +{ + // Allocate this before the list object in case it triggers a GC which would + // free the list. + Value* elements = NULL; + if (numElements > 0) + { + elements = ALLOCATE_ARRAY(vm, Value, numElements); + } + + ObjList* list = ALLOCATE(vm, ObjList); + initObj(vm, &list->obj, OBJ_LIST, vm->listClass); + list->elements.capacity = numElements; + list->elements.count = numElements; + list->elements.data = elements; + return list; +} + +void wrenListInsert(WrenVM* vm, ObjList* list, Value value, uint32_t index) +{ + if (IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + // Add a slot at the end of the list. + wrenValueBufferWrite(vm, &list->elements, NULL_VAL); + + if (IS_OBJ(value)) wrenPopRoot(vm); + + // Shift the existing elements down. + for (uint32_t i = list->elements.count - 1; i > index; i--) + { + list->elements.data[i] = list->elements.data[i - 1]; + } + + // Store the new element. + list->elements.data[index] = value; +} + +int wrenListIndexOf(WrenVM* vm, ObjList* list, Value value) +{ + int count = list->elements.count; + for (int i = 0; i < count; i++) + { + Value item = list->elements.data[i]; + if(wrenValuesEqual(item, value)) { + return i; + } + } + return -1; +} + +Value wrenListRemoveAt(WrenVM* vm, ObjList* list, uint32_t index) +{ + Value removed = list->elements.data[index]; + + if (IS_OBJ(removed)) wrenPushRoot(vm, AS_OBJ(removed)); + + // Shift items up. + for (int i = index; i < list->elements.count - 1; i++) + { + list->elements.data[i] = list->elements.data[i + 1]; + } + + // If we have too much excess capacity, shrink it. + if (list->elements.capacity / GROW_FACTOR >= list->elements.count) + { + list->elements.data = (Value*)wrenReallocate(vm, list->elements.data, + sizeof(Value) * list->elements.capacity, + sizeof(Value) * (list->elements.capacity / GROW_FACTOR)); + list->elements.capacity /= GROW_FACTOR; + } + + if (IS_OBJ(removed)) wrenPopRoot(vm); + + list->elements.count--; + return removed; +} + +ObjMap* wrenNewMap(WrenVM* vm) +{ + ObjMap* map = ALLOCATE(vm, ObjMap); + initObj(vm, &map->obj, OBJ_MAP, vm->mapClass); + map->capacity = 0; + map->count = 0; + map->entries = NULL; + return map; +} + +static inline uint32_t hashBits(uint64_t hash) +{ + // From v8's ComputeLongHash() which in turn cites: + // Thomas Wang, Integer Hash Functions. + // http://www.concentric.net/~Ttwang/tech/inthash.htm + hash = ~hash + (hash << 18); // hash = (hash << 18) - hash - 1; + hash = hash ^ (hash >> 31); + hash = hash * 21; // hash = (hash + (hash << 2)) + (hash << 4); + hash = hash ^ (hash >> 11); + hash = hash + (hash << 6); + hash = hash ^ (hash >> 22); + return (uint32_t)(hash & 0x3fffffff); +} + +// Generates a hash code for [num]. +static inline uint32_t hashNumber(double num) +{ + // Hash the raw bits of the value. + return hashBits(wrenDoubleToBits(num)); +} + +// Generates a hash code for [object]. +static uint32_t hashObject(Obj* object) +{ + switch (object->type) + { + case OBJ_CLASS: + // Classes just use their name. + return hashObject((Obj*)((ObjClass*)object)->name); + + // Allow bare (non-closure) functions so that we can use a map to find + // existing constants in a function's constant table. This is only used + // internally. Since user code never sees a non-closure function, they + // cannot use them as map keys. + case OBJ_FN: + { + ObjFn* fn = (ObjFn*)object; + return hashNumber(fn->arity) ^ hashNumber(fn->code.count); + } + + case OBJ_RANGE: + { + ObjRange* range = (ObjRange*)object; + return hashNumber(range->from) ^ hashNumber(range->to); + } + + case OBJ_STRING: + return ((ObjString*)object)->hash; + + default: + ASSERT(false, "Only immutable objects can be hashed."); + return 0; + } +} + +// Generates a hash code for [value], which must be one of the built-in +// immutable types: null, bool, class, num, range, or string. +static uint32_t hashValue(Value value) +{ + // TODO: We'll probably want to randomize this at some point. + +#if WREN_NAN_TAGGING + if (IS_OBJ(value)) return hashObject(AS_OBJ(value)); + + // Hash the raw bits of the unboxed value. + return hashBits(value); +#else + switch (value.type) + { + case VAL_FALSE: return 0; + case VAL_NULL: return 1; + case VAL_NUM: return hashNumber(AS_NUM(value)); + case VAL_TRUE: return 2; + case VAL_OBJ: return hashObject(AS_OBJ(value)); + default: UNREACHABLE(); + } + + return 0; +#endif +} + +// Looks for an entry with [key] in an array of [capacity] [entries]. +// +// If found, sets [result] to point to it and returns `true`. Otherwise, +// returns `false` and points [result] to the entry where the key/value pair +// should be inserted. +static bool findEntry(MapEntry* entries, uint32_t capacity, Value key, + MapEntry** result) +{ + // If there is no entry array (an empty map), we definitely won't find it. + if (capacity == 0) return false; + + // Figure out where to insert it in the table. Use open addressing and + // basic linear probing. + uint32_t startIndex = hashValue(key) % capacity; + uint32_t index = startIndex; + + // If we pass a tombstone and don't end up finding the key, its entry will + // be re-used for the insert. + MapEntry* tombstone = NULL; + + // Walk the probe sequence until we've tried every slot. + do + { + MapEntry* entry = &entries[index]; + + if (IS_UNDEFINED(entry->key)) + { + // If we found an empty slot, the key is not in the table. If we found a + // slot that contains a deleted key, we have to keep looking. + if (IS_FALSE(entry->value)) + { + // We found an empty slot, so we've reached the end of the probe + // sequence without finding the key. If we passed a tombstone, then + // that's where we should insert the item, otherwise, put it here at + // the end of the sequence. + *result = tombstone != NULL ? tombstone : entry; + return false; + } + else + { + // We found a tombstone. We need to keep looking in case the key is + // after it, but we'll use this entry as the insertion point if the + // key ends up not being found. + if (tombstone == NULL) tombstone = entry; + } + } + else if (wrenValuesEqual(entry->key, key)) + { + // We found the key. + *result = entry; + return true; + } + + // Try the next slot. + index = (index + 1) % capacity; + } + while (index != startIndex); + + // If we get here, the table is full of tombstones. Return the first one we + // found. + ASSERT(tombstone != NULL, "Map should have tombstones or empty entries."); + *result = tombstone; + return false; +} + +// Inserts [key] and [value] in the array of [entries] with the given +// [capacity]. +// +// Returns `true` if this is the first time [key] was added to the map. +static bool insertEntry(MapEntry* entries, uint32_t capacity, + Value key, Value value) +{ + ASSERT(entries != NULL, "Should ensure capacity before inserting."); + + MapEntry* entry; + if (findEntry(entries, capacity, key, &entry)) + { + // Already present, so just replace the value. + entry->value = value; + return false; + } + else + { + entry->key = key; + entry->value = value; + return true; + } +} + +// Updates [map]'s entry array to [capacity]. +static void resizeMap(WrenVM* vm, ObjMap* map, uint32_t capacity) +{ + // Create the new empty hash table. + MapEntry* entries = ALLOCATE_ARRAY(vm, MapEntry, capacity); + for (uint32_t i = 0; i < capacity; i++) + { + entries[i].key = UNDEFINED_VAL; + entries[i].value = FALSE_VAL; + } + + // Re-add the existing entries. + if (map->capacity > 0) + { + for (uint32_t i = 0; i < map->capacity; i++) + { + MapEntry* entry = &map->entries[i]; + + // Don't copy empty entries or tombstones. + if (IS_UNDEFINED(entry->key)) continue; + + insertEntry(entries, capacity, entry->key, entry->value); + } + } + + // Replace the array. + DEALLOCATE(vm, map->entries); + map->entries = entries; + map->capacity = capacity; +} + +Value wrenMapGet(ObjMap* map, Value key) +{ + MapEntry* entry; + if (findEntry(map->entries, map->capacity, key, &entry)) return entry->value; + + return UNDEFINED_VAL; +} + +void wrenMapSet(WrenVM* vm, ObjMap* map, Value key, Value value) +{ + // If the map is getting too full, make room first. + if (map->count + 1 > map->capacity * MAP_LOAD_PERCENT / 100) + { + // Figure out the new hash table size. + uint32_t capacity = map->capacity * GROW_FACTOR; + if (capacity < MIN_CAPACITY) capacity = MIN_CAPACITY; + + resizeMap(vm, map, capacity); + } + + if (insertEntry(map->entries, map->capacity, key, value)) + { + // A new key was added. + map->count++; + } +} + +void wrenMapClear(WrenVM* vm, ObjMap* map) +{ + DEALLOCATE(vm, map->entries); + map->entries = NULL; + map->capacity = 0; + map->count = 0; +} + +Value wrenMapRemoveKey(WrenVM* vm, ObjMap* map, Value key) +{ + MapEntry* entry; + if (!findEntry(map->entries, map->capacity, key, &entry)) return NULL_VAL; + + // Remove the entry from the map. Set this value to true, which marks it as a + // deleted slot. When searching for a key, we will stop on empty slots, but + // continue past deleted slots. + Value value = entry->value; + entry->key = UNDEFINED_VAL; + entry->value = TRUE_VAL; + + if (IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + map->count--; + + if (map->count == 0) + { + // Removed the last item, so free the array. + wrenMapClear(vm, map); + } + else if (map->capacity > MIN_CAPACITY && + map->count < map->capacity / GROW_FACTOR * MAP_LOAD_PERCENT / 100) + { + uint32_t capacity = map->capacity / GROW_FACTOR; + if (capacity < MIN_CAPACITY) capacity = MIN_CAPACITY; + + // The map is getting empty, so shrink the entry array back down. + // TODO: Should we do this less aggressively than we grow? + resizeMap(vm, map, capacity); + } + + if (IS_OBJ(value)) wrenPopRoot(vm); + return value; +} + +ObjModule* wrenNewModule(WrenVM* vm, ObjString* name) +{ + ObjModule* module = ALLOCATE(vm, ObjModule); + + // Modules are never used as first-class objects, so don't need a class. + initObj(vm, (Obj*)module, OBJ_MODULE, NULL); + + wrenPushRoot(vm, (Obj*)module); + + wrenSymbolTableInit(&module->variableNames); + wrenValueBufferInit(&module->variables); + + module->name = name; + + wrenPopRoot(vm); + return module; +} + +Value wrenNewRange(WrenVM* vm, double from, double to, bool isInclusive) +{ + ObjRange* range = ALLOCATE(vm, ObjRange); + initObj(vm, &range->obj, OBJ_RANGE, vm->rangeClass); + range->from = from; + range->to = to; + range->isInclusive = isInclusive; + + return OBJ_VAL(range); +} + +// Creates a new string object with a null-terminated buffer large enough to +// hold a string of [length] but does not fill in the bytes. +// +// The caller is expected to fill in the buffer and then calculate the string's +// hash. +static ObjString* allocateString(WrenVM* vm, size_t length) +{ + ObjString* string = ALLOCATE_FLEX(vm, ObjString, char, length + 1); + initObj(vm, &string->obj, OBJ_STRING, vm->stringClass); + string->length = (int)length; + string->value[length] = '\0'; + + return string; +} + +// Calculates and stores the hash code for [string]. +static void hashString(ObjString* string) +{ + // FNV-1a hash. See: http://www.isthe.com/chongo/tech/comp/fnv/ + uint32_t hash = 2166136261u; + + // This is O(n) on the length of the string, but we only call this when a new + // string is created. Since the creation is also O(n) (to copy/initialize all + // the bytes), we allow this here. + for (uint32_t i = 0; i < string->length; i++) + { + hash ^= string->value[i]; + hash *= 16777619; + } + + string->hash = hash; +} + +Value wrenNewString(WrenVM* vm, const char* text) +{ + return wrenNewStringLength(vm, text, strlen(text)); +} + +Value wrenNewStringLength(WrenVM* vm, const char* text, size_t length) +{ + // Allow NULL if the string is empty since byte buffers don't allocate any + // characters for a zero-length string. + ASSERT(length == 0 || text != NULL, "Unexpected NULL string."); + + ObjString* string = allocateString(vm, length); + + // Copy the string (if given one). + if (length > 0 && text != NULL) memcpy(string->value, text, length); + + hashString(string); + return OBJ_VAL(string); +} + + +Value wrenNewStringFromRange(WrenVM* vm, ObjString* source, int start, + uint32_t count, int step) +{ + uint8_t* from = (uint8_t*)source->value; + int length = 0; + for (uint32_t i = 0; i < count; i++) + { + length += wrenUtf8DecodeNumBytes(from[start + i * step]); + } + + ObjString* result = allocateString(vm, length); + result->value[length] = '\0'; + + uint8_t* to = (uint8_t*)result->value; + for (uint32_t i = 0; i < count; i++) + { + int index = start + i * step; + int codePoint = wrenUtf8Decode(from + index, source->length - index); + + if (codePoint != -1) + { + to += wrenUtf8Encode(codePoint, to); + } + } + + hashString(result); + return OBJ_VAL(result); +} + +Value wrenNumToString(WrenVM* vm, double value) +{ + // Edge case: If the value is NaN or infinity, different versions of libc + // produce different outputs (some will format it signed and some won't). To + // get reliable output, handle it ourselves. + if (isnan(value)) return CONST_STRING(vm, "nan"); + if (isinf(value)) + { + if (value > 0.0) + { + return CONST_STRING(vm, "infinity"); + } + else + { + return CONST_STRING(vm, "-infinity"); + } + } + + // This is large enough to hold any double converted to a string using + // "%.14g". Example: + // + // -1.12345678901234e-1022 + // + // So we have: + // + // + 1 char for sign + // + 1 char for digit + // + 1 char for "." + // + 14 chars for decimal digits + // + 1 char for "e" + // + 1 char for "-" or "+" + // + 4 chars for exponent + // + 1 char for "\0" + // = 24 + char buffer[24]; + int length = sprintf(buffer, "%.14g", value); + return wrenNewStringLength(vm, buffer, length); +} + +Value wrenStringFromCodePoint(WrenVM* vm, int value) +{ + int length = wrenUtf8EncodeNumBytes(value); + ASSERT(length != 0, "Value out of range."); + + ObjString* string = allocateString(vm, length); + + wrenUtf8Encode(value, (uint8_t*)string->value); + hashString(string); + + return OBJ_VAL(string); +} + +Value wrenStringFromByte(WrenVM *vm, uint8_t value) +{ + int length = 1; + ObjString* string = allocateString(vm, length); + string->value[0] = value; + hashString(string); + return OBJ_VAL(string); +} + +Value wrenStringFormat(WrenVM* vm, const char* format, ...) +{ + va_list argList; + + // Calculate the length of the result string. Do this up front so we can + // create the final string with a single allocation. + va_start(argList, format); + size_t totalLength = 0; + for (const char* c = format; *c != '\0'; c++) + { + switch (*c) + { + case '$': + totalLength += strlen(va_arg(argList, const char*)); + break; + + case '@': + totalLength += AS_STRING(va_arg(argList, Value))->length; + break; + + default: + // Any other character is interpreted literally. + totalLength++; + } + } + va_end(argList); + + // Concatenate the string. + ObjString* result = allocateString(vm, totalLength); + + va_start(argList, format); + char* start = result->value; + for (const char* c = format; *c != '\0'; c++) + { + switch (*c) + { + case '$': + { + const char* string = va_arg(argList, const char*); + size_t length = strlen(string); + memcpy(start, string, length); + start += length; + break; + } + + case '@': + { + ObjString* string = AS_STRING(va_arg(argList, Value)); + memcpy(start, string->value, string->length); + start += string->length; + break; + } + + default: + // Any other character is interpreted literally. + *start++ = *c; + } + } + va_end(argList); + + hashString(result); + + return OBJ_VAL(result); +} + +Value wrenStringCodePointAt(WrenVM* vm, ObjString* string, uint32_t index) +{ + ASSERT(index < string->length, "Index out of bounds."); + + int codePoint = wrenUtf8Decode((uint8_t*)string->value + index, + string->length - index); + if (codePoint == -1) + { + // If it isn't a valid UTF-8 sequence, treat it as a single raw byte. + char bytes[2]; + bytes[0] = string->value[index]; + bytes[1] = '\0'; + return wrenNewStringLength(vm, bytes, 1); + } + + return wrenStringFromCodePoint(vm, codePoint); +} + +// Uses the Boyer-Moore-Horspool string matching algorithm. +uint32_t wrenStringFind(ObjString* haystack, ObjString* needle, uint32_t start) +{ + // Edge case: An empty needle is always found. + if (needle->length == 0) return start; + + // If the needle goes past the haystack it won't be found. + if (start + needle->length > haystack->length) return UINT32_MAX; + + // If the startIndex is too far it also won't be found. + if (start >= haystack->length) return UINT32_MAX; + + // Pre-calculate the shift table. For each character (8-bit value), we + // determine how far the search window can be advanced if that character is + // the last character in the haystack where we are searching for the needle + // and the needle doesn't match there. + uint32_t shift[UINT8_MAX]; + uint32_t needleEnd = needle->length - 1; + + // By default, we assume the character is not the needle at all. In that case + // case, if a match fails on that character, we can advance one whole needle + // width since. + for (uint32_t index = 0; index < UINT8_MAX; index++) + { + shift[index] = needle->length; + } + + // Then, for every character in the needle, determine how far it is from the + // end. If a match fails on that character, we can advance the window such + // that it the last character in it lines up with the last place we could + // find it in the needle. + for (uint32_t index = 0; index < needleEnd; index++) + { + char c = needle->value[index]; + shift[(uint8_t)c] = needleEnd - index; + } + + // Slide the needle across the haystack, looking for the first match or + // stopping if the needle goes off the end. + char lastChar = needle->value[needleEnd]; + uint32_t range = haystack->length - needle->length; + + for (uint32_t index = start; index <= range; ) + { + // Compare the last character in the haystack's window to the last character + // in the needle. If it matches, see if the whole needle matches. + char c = haystack->value[index + needleEnd]; + if (lastChar == c && + memcmp(haystack->value + index, needle->value, needleEnd) == 0) + { + // Found a match. + return index; + } + + // Otherwise, slide the needle forward. + index += shift[(uint8_t)c]; + } + + // Not found. + return UINT32_MAX; +} + +ObjUpvalue* wrenNewUpvalue(WrenVM* vm, Value* value) +{ + ObjUpvalue* upvalue = ALLOCATE(vm, ObjUpvalue); + + // Upvalues are never used as first-class objects, so don't need a class. + initObj(vm, &upvalue->obj, OBJ_UPVALUE, NULL); + + upvalue->value = value; + upvalue->closed = NULL_VAL; + upvalue->next = NULL; + return upvalue; +} + +void wrenGrayObj(WrenVM* vm, Obj* obj) +{ + if (obj == NULL) return; + + // Stop if the object is already darkened so we don't get stuck in a cycle. + if (obj->isDark) return; + + // It's been reached. + obj->isDark = true; + + // Add it to the gray list so it can be recursively explored for + // more marks later. + if (vm->grayCount >= vm->grayCapacity) + { + vm->grayCapacity = vm->grayCount * 2; + vm->gray = (Obj**)vm->config.reallocateFn(vm->gray, + vm->grayCapacity * sizeof(Obj*), + vm->config.userData); + } + + vm->gray[vm->grayCount++] = obj; +} + +void wrenGrayValue(WrenVM* vm, Value value) +{ + if (!IS_OBJ(value)) return; + wrenGrayObj(vm, AS_OBJ(value)); +} + +void wrenGrayBuffer(WrenVM* vm, ValueBuffer* buffer) +{ + for (int i = 0; i < buffer->count; i++) + { + wrenGrayValue(vm, buffer->data[i]); + } +} + +static void blackenClass(WrenVM* vm, ObjClass* classObj) +{ + // The metaclass. + wrenGrayObj(vm, (Obj*)classObj->obj.classObj); + + // The superclass. + wrenGrayObj(vm, (Obj*)classObj->superclass); + + // Method function objects. + for (int i = 0; i < classObj->methods.count; i++) + { + if (classObj->methods.data[i].type == METHOD_BLOCK) + { + wrenGrayObj(vm, (Obj*)classObj->methods.data[i].as.closure); + } + } + + wrenGrayObj(vm, (Obj*)classObj->name); + + if(!IS_NULL(classObj->attributes)) wrenGrayObj(vm, AS_OBJ(classObj->attributes)); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjClass); + vm->bytesAllocated += classObj->methods.capacity * sizeof(Method); +} + +static void blackenClosure(WrenVM* vm, ObjClosure* closure) +{ + // Mark the function. + wrenGrayObj(vm, (Obj*)closure->fn); + + // Mark the upvalues. + for (int i = 0; i < closure->fn->numUpvalues; i++) + { + wrenGrayObj(vm, (Obj*)closure->upvalues[i]); + } + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjClosure); + vm->bytesAllocated += sizeof(ObjUpvalue*) * closure->fn->numUpvalues; +} + +static void blackenFiber(WrenVM* vm, ObjFiber* fiber) +{ + // Stack functions. + for (int i = 0; i < fiber->numFrames; i++) + { + wrenGrayObj(vm, (Obj*)fiber->frames[i].closure); + } + + // Stack variables. + for (Value* slot = fiber->stack; slot < fiber->stackTop; slot++) + { + wrenGrayValue(vm, *slot); + } + + // Open upvalues. + ObjUpvalue* upvalue = fiber->openUpvalues; + while (upvalue != NULL) + { + wrenGrayObj(vm, (Obj*)upvalue); + upvalue = upvalue->next; + } + + // The caller. + wrenGrayObj(vm, (Obj*)fiber->caller); + wrenGrayValue(vm, fiber->error); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjFiber); + vm->bytesAllocated += fiber->frameCapacity * sizeof(CallFrame); + vm->bytesAllocated += fiber->stackCapacity * sizeof(Value); +} + +static void blackenFn(WrenVM* vm, ObjFn* fn) +{ + // Mark the constants. + wrenGrayBuffer(vm, &fn->constants); + + // Mark the module it belongs to, in case it's been unloaded. + wrenGrayObj(vm, (Obj*)fn->module); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjFn); + vm->bytesAllocated += sizeof(uint8_t) * fn->code.capacity; + vm->bytesAllocated += sizeof(Value) * fn->constants.capacity; + + // The debug line number buffer. + vm->bytesAllocated += sizeof(int) * fn->code.capacity; + // TODO: What about the function name? +} + +static void blackenForeign(WrenVM* vm, ObjForeign* foreign) +{ + // TODO: Keep track of how much memory the foreign object uses. We can store + // this in each foreign object, but it will balloon the size. We may not want + // that much overhead. One option would be to let the foreign class register + // a C function that returns a size for the object. That way the VM doesn't + // always have to explicitly store it. +} + +static void blackenInstance(WrenVM* vm, ObjInstance* instance) +{ + wrenGrayObj(vm, (Obj*)instance->obj.classObj); + + // Mark the fields. + for (int i = 0; i < instance->obj.classObj->numFields; i++) + { + wrenGrayValue(vm, instance->fields[i]); + } + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjInstance); + vm->bytesAllocated += sizeof(Value) * instance->obj.classObj->numFields; +} + +static void blackenList(WrenVM* vm, ObjList* list) +{ + // Mark the elements. + wrenGrayBuffer(vm, &list->elements); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjList); + vm->bytesAllocated += sizeof(Value) * list->elements.capacity; +} + +static void blackenMap(WrenVM* vm, ObjMap* map) +{ + // Mark the entries. + for (uint32_t i = 0; i < map->capacity; i++) + { + MapEntry* entry = &map->entries[i]; + if (IS_UNDEFINED(entry->key)) continue; + + wrenGrayValue(vm, entry->key); + wrenGrayValue(vm, entry->value); + } + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjMap); + vm->bytesAllocated += sizeof(MapEntry) * map->capacity; +} + +static void blackenModule(WrenVM* vm, ObjModule* module) +{ + // Top-level variables. + for (int i = 0; i < module->variables.count; i++) + { + wrenGrayValue(vm, module->variables.data[i]); + } + + wrenBlackenSymbolTable(vm, &module->variableNames); + + wrenGrayObj(vm, (Obj*)module->name); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjModule); +} + +static void blackenRange(WrenVM* vm, ObjRange* range) +{ + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjRange); +} + +static void blackenString(WrenVM* vm, ObjString* string) +{ + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjString) + string->length + 1; +} + +static void blackenUpvalue(WrenVM* vm, ObjUpvalue* upvalue) +{ + // Mark the closed-over object (in case it is closed). + wrenGrayValue(vm, upvalue->closed); + + // Keep track of how much memory is still in use. + vm->bytesAllocated += sizeof(ObjUpvalue); +} + +static void blackenObject(WrenVM* vm, Obj* obj) +{ +#if WREN_DEBUG_TRACE_MEMORY + printf("mark "); + wrenDumpValue(OBJ_VAL(obj)); + printf(" @ %p\n", obj); +#endif + + // Traverse the object's fields. + switch (obj->type) + { + case OBJ_CLASS: blackenClass( vm, (ObjClass*) obj); break; + case OBJ_CLOSURE: blackenClosure( vm, (ObjClosure*) obj); break; + case OBJ_FIBER: blackenFiber( vm, (ObjFiber*) obj); break; + case OBJ_FN: blackenFn( vm, (ObjFn*) obj); break; + case OBJ_FOREIGN: blackenForeign( vm, (ObjForeign*) obj); break; + case OBJ_INSTANCE: blackenInstance(vm, (ObjInstance*)obj); break; + case OBJ_LIST: blackenList( vm, (ObjList*) obj); break; + case OBJ_MAP: blackenMap( vm, (ObjMap*) obj); break; + case OBJ_MODULE: blackenModule( vm, (ObjModule*) obj); break; + case OBJ_RANGE: blackenRange( vm, (ObjRange*) obj); break; + case OBJ_STRING: blackenString( vm, (ObjString*) obj); break; + case OBJ_UPVALUE: blackenUpvalue( vm, (ObjUpvalue*) obj); break; + } +} + +void wrenBlackenObjects(WrenVM* vm) +{ + while (vm->grayCount > 0) + { + // Pop an item from the gray stack. + Obj* obj = vm->gray[--vm->grayCount]; + blackenObject(vm, obj); + } +} + +void wrenFreeObj(WrenVM* vm, Obj* obj) +{ +#if WREN_DEBUG_TRACE_MEMORY + printf("free "); + wrenDumpValue(OBJ_VAL(obj)); + printf(" @ %p\n", obj); +#endif + + switch (obj->type) + { + case OBJ_CLASS: + wrenMethodBufferClear(vm, &((ObjClass*)obj)->methods); + break; + + case OBJ_FIBER: + { + ObjFiber* fiber = (ObjFiber*)obj; + DEALLOCATE(vm, fiber->frames); + DEALLOCATE(vm, fiber->stack); + break; + } + + case OBJ_FN: + { + ObjFn* fn = (ObjFn*)obj; + wrenValueBufferClear(vm, &fn->constants); + wrenByteBufferClear(vm, &fn->code); + wrenIntBufferClear(vm, &fn->debug->sourceLines); + DEALLOCATE(vm, fn->debug->name); + DEALLOCATE(vm, fn->debug); + break; + } + + case OBJ_FOREIGN: + wrenFinalizeForeign(vm, (ObjForeign*)obj); + break; + + case OBJ_LIST: + wrenValueBufferClear(vm, &((ObjList*)obj)->elements); + break; + + case OBJ_MAP: + DEALLOCATE(vm, ((ObjMap*)obj)->entries); + break; + + case OBJ_MODULE: + wrenSymbolTableClear(vm, &((ObjModule*)obj)->variableNames); + wrenValueBufferClear(vm, &((ObjModule*)obj)->variables); + break; + + case OBJ_CLOSURE: + case OBJ_INSTANCE: + case OBJ_RANGE: + case OBJ_STRING: + case OBJ_UPVALUE: + break; + } + + DEALLOCATE(vm, obj); +} + +ObjClass* wrenGetClass(WrenVM* vm, Value value) +{ + return wrenGetClassInline(vm, value); +} + +bool wrenValuesEqual(Value a, Value b) +{ + if (wrenValuesSame(a, b)) return true; + + // If we get here, it's only possible for two heap-allocated immutable objects + // to be equal. + if (!IS_OBJ(a) || !IS_OBJ(b)) return false; + + Obj* aObj = AS_OBJ(a); + Obj* bObj = AS_OBJ(b); + + // Must be the same type. + if (aObj->type != bObj->type) return false; + + switch (aObj->type) + { + case OBJ_RANGE: + { + ObjRange* aRange = (ObjRange*)aObj; + ObjRange* bRange = (ObjRange*)bObj; + return aRange->from == bRange->from && + aRange->to == bRange->to && + aRange->isInclusive == bRange->isInclusive; + } + + case OBJ_STRING: + { + ObjString* aString = (ObjString*)aObj; + ObjString* bString = (ObjString*)bObj; + return aString->hash == bString->hash && + wrenStringEqualsCString(aString, bString->value, bString->length); + } + + default: + // All other types are only equal if they are same, which they aren't if + // we get here. + return false; + } +} +// End file "wren_value.c" +// Begin file "wren_utils.c" +#include + + +DEFINE_BUFFER(Byte, uint8_t); +DEFINE_BUFFER(Int, int); +DEFINE_BUFFER(String, ObjString*); + +void wrenSymbolTableInit(SymbolTable* symbols) +{ + wrenStringBufferInit(symbols); +} + +void wrenSymbolTableClear(WrenVM* vm, SymbolTable* symbols) +{ + wrenStringBufferClear(vm, symbols); +} + +int wrenSymbolTableAdd(WrenVM* vm, SymbolTable* symbols, + const char* name, size_t length) +{ + ObjString* symbol = AS_STRING(wrenNewStringLength(vm, name, length)); + + wrenPushRoot(vm, &symbol->obj); + wrenStringBufferWrite(vm, symbols, symbol); + wrenPopRoot(vm); + + return symbols->count - 1; +} + +int wrenSymbolTableEnsure(WrenVM* vm, SymbolTable* symbols, + const char* name, size_t length) +{ + // See if the symbol is already defined. + int existing = wrenSymbolTableFind(symbols, name, length); + if (existing != -1) return existing; + + // New symbol, so add it. + return wrenSymbolTableAdd(vm, symbols, name, length); +} + +int wrenSymbolTableFind(const SymbolTable* symbols, + const char* name, size_t length) +{ + // See if the symbol is already defined. + // TODO: O(n). Do something better. + for (int i = 0; i < symbols->count; i++) + { + if (wrenStringEqualsCString(symbols->data[i], name, length)) return i; + } + + return -1; +} + +void wrenBlackenSymbolTable(WrenVM* vm, SymbolTable* symbolTable) +{ + for (int i = 0; i < symbolTable->count; i++) + { + wrenGrayObj(vm, &symbolTable->data[i]->obj); + } + + // Keep track of how much memory is still in use. + vm->bytesAllocated += symbolTable->capacity * sizeof(*symbolTable->data); +} + +int wrenUtf8EncodeNumBytes(int value) +{ + ASSERT(value >= 0, "Cannot encode a negative value."); + + if (value <= 0x7f) return 1; + if (value <= 0x7ff) return 2; + if (value <= 0xffff) return 3; + if (value <= 0x10ffff) return 4; + return 0; +} + +int wrenUtf8Encode(int value, uint8_t* bytes) +{ + if (value <= 0x7f) + { + // Single byte (i.e. fits in ASCII). + *bytes = value & 0x7f; + return 1; + } + else if (value <= 0x7ff) + { + // Two byte sequence: 110xxxxx 10xxxxxx. + *bytes = 0xc0 | ((value & 0x7c0) >> 6); + bytes++; + *bytes = 0x80 | (value & 0x3f); + return 2; + } + else if (value <= 0xffff) + { + // Three byte sequence: 1110xxxx 10xxxxxx 10xxxxxx. + *bytes = 0xe0 | ((value & 0xf000) >> 12); + bytes++; + *bytes = 0x80 | ((value & 0xfc0) >> 6); + bytes++; + *bytes = 0x80 | (value & 0x3f); + return 3; + } + else if (value <= 0x10ffff) + { + // Four byte sequence: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx. + *bytes = 0xf0 | ((value & 0x1c0000) >> 18); + bytes++; + *bytes = 0x80 | ((value & 0x3f000) >> 12); + bytes++; + *bytes = 0x80 | ((value & 0xfc0) >> 6); + bytes++; + *bytes = 0x80 | (value & 0x3f); + return 4; + } + + // Invalid Unicode value. See: http://tools.ietf.org/html/rfc3629 + UNREACHABLE(); + return 0; +} + +int wrenUtf8Decode(const uint8_t* bytes, uint32_t length) +{ + // Single byte (i.e. fits in ASCII). + if (*bytes <= 0x7f) return *bytes; + + int value; + uint32_t remainingBytes; + if ((*bytes & 0xe0) == 0xc0) + { + // Two byte sequence: 110xxxxx 10xxxxxx. + value = *bytes & 0x1f; + remainingBytes = 1; + } + else if ((*bytes & 0xf0) == 0xe0) + { + // Three byte sequence: 1110xxxx 10xxxxxx 10xxxxxx. + value = *bytes & 0x0f; + remainingBytes = 2; + } + else if ((*bytes & 0xf8) == 0xf0) + { + // Four byte sequence: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx. + value = *bytes & 0x07; + remainingBytes = 3; + } + else + { + // Invalid UTF-8 sequence. + return -1; + } + + // Don't read past the end of the buffer on truncated UTF-8. + if (remainingBytes > length - 1) return -1; + + while (remainingBytes > 0) + { + bytes++; + remainingBytes--; + + // Remaining bytes must be of form 10xxxxxx. + if ((*bytes & 0xc0) != 0x80) return -1; + + value = value << 6 | (*bytes & 0x3f); + } + + return value; +} + +int wrenUtf8DecodeNumBytes(uint8_t byte) +{ + // If the byte starts with 10xxxxx, it's the middle of a UTF-8 sequence, so + // don't count it at all. + if ((byte & 0xc0) == 0x80) return 0; + + // The first byte's high bits tell us how many bytes are in the UTF-8 + // sequence. + if ((byte & 0xf8) == 0xf0) return 4; + if ((byte & 0xf0) == 0xe0) return 3; + if ((byte & 0xe0) == 0xc0) return 2; + return 1; +} + +// From: http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2Float +int wrenPowerOf2Ceil(int n) +{ + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n++; + + return n; +} + +uint32_t wrenValidateIndex(uint32_t count, int64_t value) +{ + // Negative indices count from the end. + if (value < 0) value = count + value; + + // Check bounds. + if (value >= 0 && value < count) return (uint32_t)value; + + return UINT32_MAX; +} +// End file "wren_utils.c" +// Begin file "wren_vm.c" +#include +#include + + +#if WREN_OPT_META +// Begin file "wren_opt_meta.h" +#ifndef wren_opt_meta_h +#define wren_opt_meta_h + + +// This module defines the Meta class and its associated methods. +#if WREN_OPT_META + +const char* wrenMetaSource(); +WrenForeignMethodFn wrenMetaBindForeignMethod(WrenVM* vm, + const char* className, + bool isStatic, + const char* signature); + +#endif + +#endif +// End file "wren_opt_meta.h" +#endif +#if WREN_OPT_RANDOM +// Begin file "wren_opt_random.h" +#ifndef wren_opt_random_h +#define wren_opt_random_h + + +#if WREN_OPT_RANDOM + +const char* wrenRandomSource(); +WrenForeignClassMethods wrenRandomBindForeignClass(WrenVM* vm, + const char* module, + const char* className); +WrenForeignMethodFn wrenRandomBindForeignMethod(WrenVM* vm, + const char* className, + bool isStatic, + const char* signature); + +#endif + +#endif +// End file "wren_opt_random.h" +#endif + +#if WREN_DEBUG_TRACE_MEMORY || WREN_DEBUG_TRACE_GC + #include + #include +#endif + +// The behavior of realloc() when the size is 0 is implementation defined. It +// may return a non-NULL pointer which must not be dereferenced but nevertheless +// should be freed. To prevent that, we avoid calling realloc() with a zero +// size. +static void* defaultReallocate(void* ptr, size_t newSize, void* _) +{ + if (newSize == 0) + { + free(ptr); + return NULL; + } + + return realloc(ptr, newSize); +} + +int wrenGetVersionNumber() +{ + return WREN_VERSION_NUMBER; +} + +void wrenInitConfiguration(WrenConfiguration* config) +{ + config->reallocateFn = defaultReallocate; + config->resolveModuleFn = NULL; + config->loadModuleFn = NULL; + config->bindForeignMethodFn = NULL; + config->bindForeignClassFn = NULL; + config->writeFn = NULL; + config->errorFn = NULL; + config->initialHeapSize = 1024 * 1024 * 10; + config->minHeapSize = 1024 * 1024; + config->heapGrowthPercent = 50; + config->userData = NULL; +} + +WrenVM* wrenNewVM(WrenConfiguration* config) +{ + WrenReallocateFn reallocate = defaultReallocate; + void* userData = NULL; + if (config != NULL) { + userData = config->userData; + reallocate = config->reallocateFn ? config->reallocateFn : defaultReallocate; + } + + WrenVM* vm = (WrenVM*)reallocate(NULL, sizeof(*vm), userData); + memset(vm, 0, sizeof(WrenVM)); + + // Copy the configuration if given one. + if (config != NULL) + { + memcpy(&vm->config, config, sizeof(WrenConfiguration)); + + // We choose to set this after copying, + // rather than modifying the user config pointer + vm->config.reallocateFn = reallocate; + } + else + { + wrenInitConfiguration(&vm->config); + } + + // TODO: Should we allocate and free this during a GC? + vm->grayCount = 0; + // TODO: Tune this. + vm->grayCapacity = 4; + vm->gray = (Obj**)reallocate(NULL, vm->grayCapacity * sizeof(Obj*), userData); + vm->nextGC = vm->config.initialHeapSize; + + wrenSymbolTableInit(&vm->methodNames); + + vm->modules = wrenNewMap(vm); + wrenInitializeCore(vm); + return vm; +} + +void wrenFreeVM(WrenVM* vm) +{ + ASSERT(vm->methodNames.count > 0, "VM appears to have already been freed."); + + // Free all of the GC objects. + Obj* obj = vm->first; + while (obj != NULL) + { + Obj* next = obj->next; + wrenFreeObj(vm, obj); + obj = next; + } + + // Free up the GC gray set. + vm->gray = (Obj**)vm->config.reallocateFn(vm->gray, 0, vm->config.userData); + + // Tell the user if they didn't free any handles. We don't want to just free + // them here because the host app may still have pointers to them that they + // may try to use. Better to tell them about the bug early. + ASSERT(vm->handles == NULL, "All handles have not been released."); + + wrenSymbolTableClear(vm, &vm->methodNames); + + DEALLOCATE(vm, vm); +} + +void wrenCollectGarbage(WrenVM* vm) +{ +#if WREN_DEBUG_TRACE_MEMORY || WREN_DEBUG_TRACE_GC + printf("-- gc --\n"); + + size_t before = vm->bytesAllocated; + double startTime = (double)clock() / CLOCKS_PER_SEC; +#endif + + // Mark all reachable objects. + + // Reset this. As we mark objects, their size will be counted again so that + // we can track how much memory is in use without needing to know the size + // of each *freed* object. + // + // This is important because when freeing an unmarked object, we don't always + // know how much memory it is using. For example, when freeing an instance, + // we need to know its class to know how big it is, but its class may have + // already been freed. + vm->bytesAllocated = 0; + + wrenGrayObj(vm, (Obj*)vm->modules); + + // Temporary roots. + for (int i = 0; i < vm->numTempRoots; i++) + { + wrenGrayObj(vm, vm->tempRoots[i]); + } + + // The current fiber. + wrenGrayObj(vm, (Obj*)vm->fiber); + + // The handles. + for (WrenHandle* handle = vm->handles; + handle != NULL; + handle = handle->next) + { + wrenGrayValue(vm, handle->value); + } + + // Any object the compiler is using (if there is one). + if (vm->compiler != NULL) wrenMarkCompiler(vm, vm->compiler); + + // Method names. + wrenBlackenSymbolTable(vm, &vm->methodNames); + + // Now that we have grayed the roots, do a depth-first search over all of the + // reachable objects. + wrenBlackenObjects(vm); + + // Collect the white objects. + Obj** obj = &vm->first; + while (*obj != NULL) + { + if (!((*obj)->isDark)) + { + // This object wasn't reached, so remove it from the list and free it. + Obj* unreached = *obj; + *obj = unreached->next; + wrenFreeObj(vm, unreached); + } + else + { + // This object was reached, so unmark it (for the next GC) and move on to + // the next. + (*obj)->isDark = false; + obj = &(*obj)->next; + } + } + + // Calculate the next gc point, this is the current allocation plus + // a configured percentage of the current allocation. + vm->nextGC = vm->bytesAllocated + ((vm->bytesAllocated * vm->config.heapGrowthPercent) / 100); + if (vm->nextGC < vm->config.minHeapSize) vm->nextGC = vm->config.minHeapSize; + +#if WREN_DEBUG_TRACE_MEMORY || WREN_DEBUG_TRACE_GC + double elapsed = ((double)clock() / CLOCKS_PER_SEC) - startTime; + // Explicit cast because size_t has different sizes on 32-bit and 64-bit and + // we need a consistent type for the format string. + printf("GC %lu before, %lu after (%lu collected), next at %lu. Took %.3fms.\n", + (unsigned long)before, + (unsigned long)vm->bytesAllocated, + (unsigned long)(before - vm->bytesAllocated), + (unsigned long)vm->nextGC, + elapsed*1000.0); +#endif +} + +void* wrenReallocate(WrenVM* vm, void* memory, size_t oldSize, size_t newSize) +{ +#if WREN_DEBUG_TRACE_MEMORY + // Explicit cast because size_t has different sizes on 32-bit and 64-bit and + // we need a consistent type for the format string. + printf("reallocate %p %lu -> %lu\n", + memory, (unsigned long)oldSize, (unsigned long)newSize); +#endif + + // If new bytes are being allocated, add them to the total count. If objects + // are being completely deallocated, we don't track that (since we don't + // track the original size). Instead, that will be handled while marking + // during the next GC. + vm->bytesAllocated += newSize - oldSize; + +#if WREN_DEBUG_GC_STRESS + // Since collecting calls this function to free things, make sure we don't + // recurse. + if (newSize > 0) wrenCollectGarbage(vm); +#else + if (newSize > 0 && vm->bytesAllocated > vm->nextGC) wrenCollectGarbage(vm); +#endif + + return vm->config.reallocateFn(memory, newSize, vm->config.userData); +} + +// Captures the local variable [local] into an [Upvalue]. If that local is +// already in an upvalue, the existing one will be used. (This is important to +// ensure that multiple closures closing over the same variable actually see +// the same variable.) Otherwise, it will create a new open upvalue and add it +// the fiber's list of upvalues. +static ObjUpvalue* captureUpvalue(WrenVM* vm, ObjFiber* fiber, Value* local) +{ + // If there are no open upvalues at all, we must need a new one. + if (fiber->openUpvalues == NULL) + { + fiber->openUpvalues = wrenNewUpvalue(vm, local); + return fiber->openUpvalues; + } + + ObjUpvalue* prevUpvalue = NULL; + ObjUpvalue* upvalue = fiber->openUpvalues; + + // Walk towards the bottom of the stack until we find a previously existing + // upvalue or pass where it should be. + while (upvalue != NULL && upvalue->value > local) + { + prevUpvalue = upvalue; + upvalue = upvalue->next; + } + + // Found an existing upvalue for this local. + if (upvalue != NULL && upvalue->value == local) return upvalue; + + // We've walked past this local on the stack, so there must not be an + // upvalue for it already. Make a new one and link it in in the right + // place to keep the list sorted. + ObjUpvalue* createdUpvalue = wrenNewUpvalue(vm, local); + if (prevUpvalue == NULL) + { + // The new one is the first one in the list. + fiber->openUpvalues = createdUpvalue; + } + else + { + prevUpvalue->next = createdUpvalue; + } + + createdUpvalue->next = upvalue; + return createdUpvalue; +} + +// Closes any open upvalues that have been created for stack slots at [last] +// and above. +static void closeUpvalues(ObjFiber* fiber, Value* last) +{ + while (fiber->openUpvalues != NULL && + fiber->openUpvalues->value >= last) + { + ObjUpvalue* upvalue = fiber->openUpvalues; + + // Move the value into the upvalue itself and point the upvalue to it. + upvalue->closed = *upvalue->value; + upvalue->value = &upvalue->closed; + + // Remove it from the open upvalue list. + fiber->openUpvalues = upvalue->next; + } +} + +// Looks up a foreign method in [moduleName] on [className] with [signature]. +// +// This will try the host's foreign method binder first. If that fails, it +// falls back to handling the built-in modules. +static WrenForeignMethodFn findForeignMethod(WrenVM* vm, + const char* moduleName, + const char* className, + bool isStatic, + const char* signature) +{ + WrenForeignMethodFn method = NULL; + + if (vm->config.bindForeignMethodFn != NULL) + { + method = vm->config.bindForeignMethodFn(vm, moduleName, className, isStatic, + signature); + } + + // If the host didn't provide it, see if it's an optional one. + if (method == NULL) + { +#if WREN_OPT_META + if (strcmp(moduleName, "meta") == 0) + { + method = wrenMetaBindForeignMethod(vm, className, isStatic, signature); + } +#endif +#if WREN_OPT_RANDOM + if (strcmp(moduleName, "random") == 0) + { + method = wrenRandomBindForeignMethod(vm, className, isStatic, signature); + } +#endif + } + + return method; +} + +// Defines [methodValue] as a method on [classObj]. +// +// Handles both foreign methods where [methodValue] is a string containing the +// method's signature and Wren methods where [methodValue] is a function. +// +// Aborts the current fiber if the method is a foreign method that could not be +// found. +static void bindMethod(WrenVM* vm, int methodType, int symbol, + ObjModule* module, ObjClass* classObj, Value methodValue) +{ + const char* className = classObj->name->value; + if (methodType == CODE_METHOD_STATIC) classObj = classObj->obj.classObj; + + Method method; + if (IS_STRING(methodValue)) + { + const char* name = AS_CSTRING(methodValue); + method.type = METHOD_FOREIGN; + method.as.foreign = findForeignMethod(vm, module->name->value, + className, + methodType == CODE_METHOD_STATIC, + name); + + if (method.as.foreign == NULL) + { + vm->fiber->error = wrenStringFormat(vm, + "Could not find foreign method '@' for class $ in module '$'.", + methodValue, classObj->name->value, module->name->value); + return; + } + } + else + { + method.as.closure = AS_CLOSURE(methodValue); + method.type = METHOD_BLOCK; + + // Patch up the bytecode now that we know the superclass. + wrenBindMethodCode(classObj, method.as.closure->fn); + } + + wrenBindMethod(vm, classObj, symbol, method); +} + +static void callForeign(WrenVM* vm, ObjFiber* fiber, + WrenForeignMethodFn foreign, int numArgs) +{ + ASSERT(vm->apiStack == NULL, "Cannot already be in foreign call."); + vm->apiStack = fiber->stackTop - numArgs; + + foreign(vm); + + // Discard the stack slots for the arguments and temporaries but leave one + // for the result. + fiber->stackTop = vm->apiStack + 1; + + vm->apiStack = NULL; +} + +// Handles the current fiber having aborted because of an error. +// +// Walks the call chain of fibers, aborting each one until it hits a fiber that +// handles the error. If none do, tells the VM to stop. +static void runtimeError(WrenVM* vm) +{ + ASSERT(wrenHasError(vm->fiber), "Should only call this after an error."); + + ObjFiber* current = vm->fiber; + Value error = current->error; + + while (current != NULL) + { + // Every fiber along the call chain gets aborted with the same error. + current->error = error; + + // If the caller ran this fiber using "try", give it the error and stop. + if (current->state == FIBER_TRY) + { + // Make the caller's try method return the error message. + current->caller->stackTop[-1] = vm->fiber->error; + vm->fiber = current->caller; + return; + } + + // Otherwise, unhook the caller since we will never resume and return to it. + ObjFiber* caller = current->caller; + current->caller = NULL; + current = caller; + } + + // If we got here, nothing caught the error, so show the stack trace. + wrenDebugPrintStackTrace(vm); + vm->fiber = NULL; + vm->apiStack = NULL; +} + +// Aborts the current fiber with an appropriate method not found error for a +// method with [symbol] on [classObj]. +static void methodNotFound(WrenVM* vm, ObjClass* classObj, int symbol) +{ + vm->fiber->error = wrenStringFormat(vm, "@ does not implement '$'.", + OBJ_VAL(classObj->name), vm->methodNames.data[symbol]->value); +} + +// Looks up the previously loaded module with [name]. +// +// Returns `NULL` if no module with that name has been loaded. +static ObjModule* getModule(WrenVM* vm, Value name) +{ + Value moduleValue = wrenMapGet(vm->modules, name); + return !IS_UNDEFINED(moduleValue) ? AS_MODULE(moduleValue) : NULL; +} + +static ObjClosure* compileInModule(WrenVM* vm, Value name, const char* source, + bool isExpression, bool printErrors) +{ + // See if the module has already been loaded. + ObjModule* module = getModule(vm, name); + if (module == NULL) + { + module = wrenNewModule(vm, AS_STRING(name)); + + // It's possible for the wrenMapSet below to resize the modules map, + // and trigger a GC while doing so. When this happens it will collect + // the module we've just created. Once in the map it is safe. + wrenPushRoot(vm, (Obj*)module); + + // Store it in the VM's module registry so we don't load the same module + // multiple times. + wrenMapSet(vm, vm->modules, name, OBJ_VAL(module)); + + wrenPopRoot(vm); + + // Implicitly import the core module. + ObjModule* coreModule = getModule(vm, NULL_VAL); + for (int i = 0; i < coreModule->variables.count; i++) + { + wrenDefineVariable(vm, module, + coreModule->variableNames.data[i]->value, + coreModule->variableNames.data[i]->length, + coreModule->variables.data[i], NULL); + } + } + + ObjFn* fn = wrenCompile(vm, module, source, isExpression, printErrors); + if (fn == NULL) + { + // TODO: Should we still store the module even if it didn't compile? + return NULL; + } + + // Functions are always wrapped in closures. + wrenPushRoot(vm, (Obj*)fn); + ObjClosure* closure = wrenNewClosure(vm, fn); + wrenPopRoot(vm); // fn. + + return closure; +} + +// Verifies that [superclassValue] is a valid object to inherit from. That +// means it must be a class and cannot be the class of any built-in type. +// +// Also validates that it doesn't result in a class with too many fields and +// the other limitations foreign classes have. +// +// If successful, returns `null`. Otherwise, returns a string for the runtime +// error message. +static Value validateSuperclass(WrenVM* vm, Value name, Value superclassValue, + int numFields) +{ + // Make sure the superclass is a class. + if (!IS_CLASS(superclassValue)) + { + return wrenStringFormat(vm, + "Class '@' cannot inherit from a non-class object.", + name); + } + + // Make sure it doesn't inherit from a sealed built-in type. Primitive methods + // on these classes assume the instance is one of the other Obj___ types and + // will fail horribly if it's actually an ObjInstance. + ObjClass* superclass = AS_CLASS(superclassValue); + if (superclass == vm->classClass || + superclass == vm->fiberClass || + superclass == vm->fnClass || // Includes OBJ_CLOSURE. + superclass == vm->listClass || + superclass == vm->mapClass || + superclass == vm->rangeClass || + superclass == vm->stringClass || + superclass == vm->boolClass || + superclass == vm->nullClass || + superclass == vm->numClass) + { + return wrenStringFormat(vm, + "Class '@' cannot inherit from built-in class '@'.", + name, OBJ_VAL(superclass->name)); + } + + if (superclass->numFields == -1) + { + return wrenStringFormat(vm, + "Class '@' cannot inherit from foreign class '@'.", + name, OBJ_VAL(superclass->name)); + } + + if (numFields == -1 && superclass->numFields > 0) + { + return wrenStringFormat(vm, + "Foreign class '@' may not inherit from a class with fields.", + name); + } + + if (superclass->numFields + numFields > MAX_FIELDS) + { + return wrenStringFormat(vm, + "Class '@' may not have more than 255 fields, including inherited " + "ones.", name); + } + + return NULL_VAL; +} + +static void bindForeignClass(WrenVM* vm, ObjClass* classObj, ObjModule* module) +{ + WrenForeignClassMethods methods; + methods.allocate = NULL; + methods.finalize = NULL; + + // Check the optional built-in module first so the host can override it. + + if (vm->config.bindForeignClassFn != NULL) + { + methods = vm->config.bindForeignClassFn(vm, module->name->value, + classObj->name->value); + } + + // If the host didn't provide it, see if it's a built in optional module. + if (methods.allocate == NULL && methods.finalize == NULL) + { +#if WREN_OPT_RANDOM + if (strcmp(module->name->value, "random") == 0) + { + methods = wrenRandomBindForeignClass(vm, module->name->value, + classObj->name->value); + } +#endif + } + + Method method; + method.type = METHOD_FOREIGN; + + // Add the symbol even if there is no allocator so we can ensure that the + // symbol itself is always in the symbol table. + int symbol = wrenSymbolTableEnsure(vm, &vm->methodNames, "", 10); + if (methods.allocate != NULL) + { + method.as.foreign = methods.allocate; + wrenBindMethod(vm, classObj, symbol, method); + } + + // Add the symbol even if there is no finalizer so we can ensure that the + // symbol itself is always in the symbol table. + symbol = wrenSymbolTableEnsure(vm, &vm->methodNames, "", 10); + if (methods.finalize != NULL) + { + method.as.foreign = (WrenForeignMethodFn)methods.finalize; + wrenBindMethod(vm, classObj, symbol, method); + } +} + +// Completes the process for creating a new class. +// +// The class attributes instance and the class itself should be on the +// top of the fiber's stack. +// +// This process handles moving the attribute data for a class from +// compile time to runtime, since it now has all the attributes associated +// with a class, including for methods. +static void endClass(WrenVM* vm) +{ + // Pull the attributes and class off the stack + Value attributes = vm->fiber->stackTop[-2]; + Value classValue = vm->fiber->stackTop[-1]; + + // Remove the stack items + vm->fiber->stackTop -= 2; + + ObjClass* classObj = AS_CLASS(classValue); + classObj->attributes = attributes; +} + +// Creates a new class. +// +// If [numFields] is -1, the class is a foreign class. The name and superclass +// should be on top of the fiber's stack. After calling this, the top of the +// stack will contain the new class. +// +// Aborts the current fiber if an error occurs. +static void createClass(WrenVM* vm, int numFields, ObjModule* module) +{ + // Pull the name and superclass off the stack. + Value name = vm->fiber->stackTop[-2]; + Value superclass = vm->fiber->stackTop[-1]; + + // We have two values on the stack and we are going to leave one, so discard + // the other slot. + vm->fiber->stackTop--; + + vm->fiber->error = validateSuperclass(vm, name, superclass, numFields); + if (wrenHasError(vm->fiber)) return; + + ObjClass* classObj = wrenNewClass(vm, AS_CLASS(superclass), numFields, + AS_STRING(name)); + vm->fiber->stackTop[-1] = OBJ_VAL(classObj); + + if (numFields == -1) bindForeignClass(vm, classObj, module); +} + +static void createForeign(WrenVM* vm, ObjFiber* fiber, Value* stack) +{ + ObjClass* classObj = AS_CLASS(stack[0]); + ASSERT(classObj->numFields == -1, "Class must be a foreign class."); + + // TODO: Don't look up every time. + int symbol = wrenSymbolTableFind(&vm->methodNames, "", 10); + ASSERT(symbol != -1, "Should have defined symbol."); + + ASSERT(classObj->methods.count > symbol, "Class should have allocator."); + Method* method = &classObj->methods.data[symbol]; + ASSERT(method->type == METHOD_FOREIGN, "Allocator should be foreign."); + + // Pass the constructor arguments to the allocator as well. + ASSERT(vm->apiStack == NULL, "Cannot already be in foreign call."); + vm->apiStack = stack; + + method->as.foreign(vm); + + vm->apiStack = NULL; +} + +void wrenFinalizeForeign(WrenVM* vm, ObjForeign* foreign) +{ + // TODO: Don't look up every time. + int symbol = wrenSymbolTableFind(&vm->methodNames, "", 10); + ASSERT(symbol != -1, "Should have defined symbol."); + + // If there are no finalizers, don't finalize it. + if (symbol == -1) return; + + // If the class doesn't have a finalizer, bail out. + ObjClass* classObj = foreign->obj.classObj; + if (symbol >= classObj->methods.count) return; + + Method* method = &classObj->methods.data[symbol]; + if (method->type == METHOD_NONE) return; + + ASSERT(method->type == METHOD_FOREIGN, "Finalizer should be foreign."); + + WrenFinalizerFn finalizer = (WrenFinalizerFn)method->as.foreign; + finalizer(foreign->data); +} + +// Let the host resolve an imported module name if it wants to. +static Value resolveModule(WrenVM* vm, Value name) +{ + // If the host doesn't care to resolve, leave the name alone. + if (vm->config.resolveModuleFn == NULL) return name; + + ObjFiber* fiber = vm->fiber; + ObjFn* fn = fiber->frames[fiber->numFrames - 1].closure->fn; + ObjString* importer = fn->module->name; + + const char* resolved = vm->config.resolveModuleFn(vm, importer->value, + AS_CSTRING(name)); + if (resolved == NULL) + { + vm->fiber->error = wrenStringFormat(vm, + "Could not resolve module '@' imported from '@'.", + name, OBJ_VAL(importer)); + return NULL_VAL; + } + + // If they resolved to the exact same string, we don't need to copy it. + if (resolved == AS_CSTRING(name)) return name; + + // Copy the string into a Wren String object. + name = wrenNewString(vm, resolved); + DEALLOCATE(vm, (char*)resolved); + return name; +} + +static Value importModule(WrenVM* vm, Value name) +{ + name = resolveModule(vm, name); + + // If the module is already loaded, we don't need to do anything. + Value existing = wrenMapGet(vm->modules, name); + if (!IS_UNDEFINED(existing)) return existing; + + wrenPushRoot(vm, AS_OBJ(name)); + + WrenLoadModuleResult result = {0}; + const char* source = NULL; + + // Let the host try to provide the module. + if (vm->config.loadModuleFn != NULL) + { + result = vm->config.loadModuleFn(vm, AS_CSTRING(name)); + } + + // If the host didn't provide it, see if it's a built in optional module. + if (result.source == NULL) + { + result.onComplete = NULL; + ObjString* nameString = AS_STRING(name); +#if WREN_OPT_META + if (strcmp(nameString->value, "meta") == 0) result.source = wrenMetaSource(); +#endif +#if WREN_OPT_RANDOM + if (strcmp(nameString->value, "random") == 0) result.source = wrenRandomSource(); +#endif + } + + if (result.source == NULL) + { + vm->fiber->error = wrenStringFormat(vm, "Could not load module '@'.", name); + wrenPopRoot(vm); // name. + return NULL_VAL; + } + + ObjClosure* moduleClosure = compileInModule(vm, name, result.source, false, true); + + // Now that we're done, give the result back in case there's cleanup to do. + if(result.onComplete) result.onComplete(vm, AS_CSTRING(name), result); + + if (moduleClosure == NULL) + { + vm->fiber->error = wrenStringFormat(vm, + "Could not compile module '@'.", name); + wrenPopRoot(vm); // name. + return NULL_VAL; + } + + wrenPopRoot(vm); // name. + + // Return the closure that executes the module. + return OBJ_VAL(moduleClosure); +} + +static Value getModuleVariable(WrenVM* vm, ObjModule* module, + Value variableName) +{ + ObjString* variable = AS_STRING(variableName); + uint32_t variableEntry = wrenSymbolTableFind(&module->variableNames, + variable->value, + variable->length); + + // It's a runtime error if the imported variable does not exist. + if (variableEntry != UINT32_MAX) + { + return module->variables.data[variableEntry]; + } + + vm->fiber->error = wrenStringFormat(vm, + "Could not find a variable named '@' in module '@'.", + variableName, OBJ_VAL(module->name)); + return NULL_VAL; +} + +inline static bool checkArity(WrenVM* vm, Value value, int numArgs) +{ + ASSERT(IS_CLOSURE(value), "Receiver must be a closure."); + ObjFn* fn = AS_CLOSURE(value)->fn; + + // We only care about missing arguments, not extras. The "- 1" is because + // numArgs includes the receiver, the function itself, which we don't want to + // count. + if (numArgs - 1 >= fn->arity) return true; + + vm->fiber->error = CONST_STRING(vm, "Function expects more arguments."); + return false; +} + + +// The main bytecode interpreter loop. This is where the magic happens. It is +// also, as you can imagine, highly performance critical. +static WrenInterpretResult runInterpreter(WrenVM* vm, register ObjFiber* fiber) +{ + // Remember the current fiber so we can find it if a GC happens. + vm->fiber = fiber; + fiber->state = FIBER_ROOT; + + // Hoist these into local variables. They are accessed frequently in the loop + // but assigned less frequently. Keeping them in locals and updating them when + // a call frame has been pushed or popped gives a large speed boost. + register CallFrame* frame; + register Value* stackStart; + register uint8_t* ip; + register ObjFn* fn; + + // These macros are designed to only be invoked within this function. + #define PUSH(value) (*fiber->stackTop++ = value) + #define POP() (*(--fiber->stackTop)) + #define DROP() (fiber->stackTop--) + #define PEEK() (*(fiber->stackTop - 1)) + #define PEEK2() (*(fiber->stackTop - 2)) + #define READ_BYTE() (*ip++) + #define READ_SHORT() (ip += 2, (uint16_t)((ip[-2] << 8) | ip[-1])) + + // Use this before a CallFrame is pushed to store the local variables back + // into the current one. + #define STORE_FRAME() frame->ip = ip + + // Use this after a CallFrame has been pushed or popped to refresh the local + // variables. + #define LOAD_FRAME() \ + do \ + { \ + frame = &fiber->frames[fiber->numFrames - 1]; \ + stackStart = frame->stackStart; \ + ip = frame->ip; \ + fn = frame->closure->fn; \ + } while (false) + + // Terminates the current fiber with error string [error]. If another calling + // fiber is willing to catch the error, transfers control to it, otherwise + // exits the interpreter. + #define RUNTIME_ERROR() \ + do \ + { \ + STORE_FRAME(); \ + runtimeError(vm); \ + if (vm->fiber == NULL) return WREN_RESULT_RUNTIME_ERROR; \ + fiber = vm->fiber; \ + LOAD_FRAME(); \ + DISPATCH(); \ + } while (false) + + #if WREN_DEBUG_TRACE_INSTRUCTIONS + // Prints the stack and instruction before each instruction is executed. + #define DEBUG_TRACE_INSTRUCTIONS() \ + do \ + { \ + wrenDumpStack(fiber); \ + wrenDumpInstruction(vm, fn, (int)(ip - fn->code.data)); \ + } while (false) + #else + #define DEBUG_TRACE_INSTRUCTIONS() do { } while (false) + #endif + + #if WREN_COMPUTED_GOTO + + static void* dispatchTable[] = { + #define OPCODE(name, _) &&code_##name, +// Begin file "wren_opcodes.h" +// This defines the bytecode instructions used by the VM. It does so by invoking +// an OPCODE() macro which is expected to be defined at the point that this is +// included. (See: http://en.wikipedia.org/wiki/X_Macro for more.) +// +// The first argument is the name of the opcode. The second is its "stack +// effect" -- the amount that the op code changes the size of the stack. A +// stack effect of 1 means it pushes a value and the stack grows one larger. +// -2 means it pops two values, etc. +// +// Note that the order of instructions here affects the order of the dispatch +// table in the VM's interpreter loop. That in turn affects caching which +// affects overall performance. Take care to run benchmarks if you change the +// order here. + +// Load the constant at index [arg]. +OPCODE(CONSTANT, 1) + +// Push null onto the stack. +OPCODE(NULL, 1) + +// Push false onto the stack. +OPCODE(FALSE, 1) + +// Push true onto the stack. +OPCODE(TRUE, 1) + +// Pushes the value in the given local slot. +OPCODE(LOAD_LOCAL_0, 1) +OPCODE(LOAD_LOCAL_1, 1) +OPCODE(LOAD_LOCAL_2, 1) +OPCODE(LOAD_LOCAL_3, 1) +OPCODE(LOAD_LOCAL_4, 1) +OPCODE(LOAD_LOCAL_5, 1) +OPCODE(LOAD_LOCAL_6, 1) +OPCODE(LOAD_LOCAL_7, 1) +OPCODE(LOAD_LOCAL_8, 1) + +// Note: The compiler assumes the following _STORE instructions always +// immediately follow their corresponding _LOAD ones. + +// Pushes the value in local slot [arg]. +OPCODE(LOAD_LOCAL, 1) + +// Stores the top of stack in local slot [arg]. Does not pop it. +OPCODE(STORE_LOCAL, 0) + +// Pushes the value in upvalue [arg]. +OPCODE(LOAD_UPVALUE, 1) + +// Stores the top of stack in upvalue [arg]. Does not pop it. +OPCODE(STORE_UPVALUE, 0) + +// Pushes the value of the top-level variable in slot [arg]. +OPCODE(LOAD_MODULE_VAR, 1) + +// Stores the top of stack in top-level variable slot [arg]. Does not pop it. +OPCODE(STORE_MODULE_VAR, 0) + +// Pushes the value of the field in slot [arg] of the receiver of the current +// function. This is used for regular field accesses on "this" directly in +// methods. This instruction is faster than the more general CODE_LOAD_FIELD +// instruction. +OPCODE(LOAD_FIELD_THIS, 1) + +// Stores the top of the stack in field slot [arg] in the receiver of the +// current value. Does not pop the value. This instruction is faster than the +// more general CODE_LOAD_FIELD instruction. +OPCODE(STORE_FIELD_THIS, 0) + +// Pops an instance and pushes the value of the field in slot [arg] of it. +OPCODE(LOAD_FIELD, 0) + +// Pops an instance and stores the subsequent top of stack in field slot +// [arg] in it. Does not pop the value. +OPCODE(STORE_FIELD, -1) + +// Pop and discard the top of stack. +OPCODE(POP, -1) + +// Invoke the method with symbol [arg]. The number indicates the number of +// arguments (not including the receiver). +OPCODE(CALL_0, 0) +OPCODE(CALL_1, -1) +OPCODE(CALL_2, -2) +OPCODE(CALL_3, -3) +OPCODE(CALL_4, -4) +OPCODE(CALL_5, -5) +OPCODE(CALL_6, -6) +OPCODE(CALL_7, -7) +OPCODE(CALL_8, -8) +OPCODE(CALL_9, -9) +OPCODE(CALL_10, -10) +OPCODE(CALL_11, -11) +OPCODE(CALL_12, -12) +OPCODE(CALL_13, -13) +OPCODE(CALL_14, -14) +OPCODE(CALL_15, -15) +OPCODE(CALL_16, -16) + +// Invoke a superclass method with symbol [arg]. The number indicates the +// number of arguments (not including the receiver). +OPCODE(SUPER_0, 0) +OPCODE(SUPER_1, -1) +OPCODE(SUPER_2, -2) +OPCODE(SUPER_3, -3) +OPCODE(SUPER_4, -4) +OPCODE(SUPER_5, -5) +OPCODE(SUPER_6, -6) +OPCODE(SUPER_7, -7) +OPCODE(SUPER_8, -8) +OPCODE(SUPER_9, -9) +OPCODE(SUPER_10, -10) +OPCODE(SUPER_11, -11) +OPCODE(SUPER_12, -12) +OPCODE(SUPER_13, -13) +OPCODE(SUPER_14, -14) +OPCODE(SUPER_15, -15) +OPCODE(SUPER_16, -16) + +// Jump the instruction pointer [arg] forward. +OPCODE(JUMP, 0) + +// Jump the instruction pointer [arg] backward. +OPCODE(LOOP, 0) + +// Pop and if not truthy then jump the instruction pointer [arg] forward. +OPCODE(JUMP_IF, -1) + +// If the top of the stack is false, jump [arg] forward. Otherwise, pop and +// continue. +OPCODE(AND, -1) + +// If the top of the stack is non-false, jump [arg] forward. Otherwise, pop +// and continue. +OPCODE(OR, -1) + +// Close the upvalue for the local on the top of the stack, then pop it. +OPCODE(CLOSE_UPVALUE, -1) + +// Exit from the current function and return the value on the top of the +// stack. +OPCODE(RETURN, 0) + +// Creates a closure for the function stored at [arg] in the constant table. +// +// Following the function argument is a number of arguments, two for each +// upvalue. The first is true if the variable being captured is a local (as +// opposed to an upvalue), and the second is the index of the local or +// upvalue being captured. +// +// Pushes the created closure. +OPCODE(CLOSURE, 1) + +// Creates a new instance of a class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(CONSTRUCT, 0) + +// Creates a new instance of a foreign class. +// +// Assumes the class object is in slot zero, and replaces it with the new +// uninitialized instance of that class. This opcode is only emitted by the +// compiler-generated constructor metaclass methods. +OPCODE(FOREIGN_CONSTRUCT, 0) + +// Creates a class. Top of stack is the superclass. Below that is a string for +// the name of the class. Byte [arg] is the number of fields in the class. +OPCODE(CLASS, -1) + +// Ends a class. +// Atm the stack contains the class and the ClassAttributes (or null). +OPCODE(END_CLASS, -2) + +// Creates a foreign class. Top of stack is the superclass. Below that is a +// string for the name of the class. +OPCODE(FOREIGN_CLASS, -1) + +// Define a method for symbol [arg]. The class receiving the method is popped +// off the stack, then the function defining the body is popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_INSTANCE, -2) + +// Define a method for symbol [arg]. The class whose metaclass will receive +// the method is popped off the stack, then the function defining the body is +// popped. +// +// If a foreign method is being defined, the "function" will be a string +// identifying the foreign method. Otherwise, it will be a function or +// closure. +OPCODE(METHOD_STATIC, -2) + +// This is executed at the end of the module's body. Pushes NULL onto the stack +// as the "return value" of the import statement and stores the module as the +// most recently imported one. +OPCODE(END_MODULE, 1) + +// Import a module whose name is the string stored at [arg] in the constant +// table. +// +// Pushes null onto the stack so that the fiber for the imported module can +// replace that with a dummy value when it returns. (Fibers always return a +// value when resuming a caller.) +OPCODE(IMPORT_MODULE, 1) + +// Import a variable from the most recently imported module. The name of the +// variable to import is at [arg] in the constant table. Pushes the loaded +// variable's value. +OPCODE(IMPORT_VARIABLE, 1) + +// This pseudo-instruction indicates the end of the bytecode. It should +// always be preceded by a `CODE_RETURN`, so is never actually executed. +OPCODE(END, 0) +// End file "wren_opcodes.h" + #undef OPCODE + }; + + #define INTERPRET_LOOP DISPATCH(); + #define CASE_CODE(name) code_##name + + #define DISPATCH() \ + do \ + { \ + DEBUG_TRACE_INSTRUCTIONS(); \ + goto *dispatchTable[instruction = (Code)READ_BYTE()]; \ + } while (false) + + #else + + #define INTERPRET_LOOP \ + loop: \ + DEBUG_TRACE_INSTRUCTIONS(); \ + switch (instruction = (Code)READ_BYTE()) + + #define CASE_CODE(name) case CODE_##name + #define DISPATCH() goto loop + + #endif + + LOAD_FRAME(); + + Code instruction; + INTERPRET_LOOP + { + CASE_CODE(LOAD_LOCAL_0): + CASE_CODE(LOAD_LOCAL_1): + CASE_CODE(LOAD_LOCAL_2): + CASE_CODE(LOAD_LOCAL_3): + CASE_CODE(LOAD_LOCAL_4): + CASE_CODE(LOAD_LOCAL_5): + CASE_CODE(LOAD_LOCAL_6): + CASE_CODE(LOAD_LOCAL_7): + CASE_CODE(LOAD_LOCAL_8): + PUSH(stackStart[instruction - CODE_LOAD_LOCAL_0]); + DISPATCH(); + + CASE_CODE(LOAD_LOCAL): + PUSH(stackStart[READ_BYTE()]); + DISPATCH(); + + CASE_CODE(LOAD_FIELD_THIS): + { + uint8_t field = READ_BYTE(); + Value receiver = stackStart[0]; + ASSERT(IS_INSTANCE(receiver), "Receiver should be instance."); + ObjInstance* instance = AS_INSTANCE(receiver); + ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field."); + PUSH(instance->fields[field]); + DISPATCH(); + } + + CASE_CODE(POP): DROP(); DISPATCH(); + CASE_CODE(NULL): PUSH(NULL_VAL); DISPATCH(); + CASE_CODE(FALSE): PUSH(FALSE_VAL); DISPATCH(); + CASE_CODE(TRUE): PUSH(TRUE_VAL); DISPATCH(); + + CASE_CODE(STORE_LOCAL): + stackStart[READ_BYTE()] = PEEK(); + DISPATCH(); + + CASE_CODE(CONSTANT): + PUSH(fn->constants.data[READ_SHORT()]); + DISPATCH(); + + { + // The opcodes for doing method and superclass calls share a lot of code. + // However, doing an if() test in the middle of the instruction sequence + // to handle the bit that is special to super calls makes the non-super + // call path noticeably slower. + // + // Instead, we do this old school using an explicit goto to share code for + // everything at the tail end of the call-handling code that is the same + // between normal and superclass calls. + int numArgs; + int symbol; + + Value* args; + ObjClass* classObj; + + Method* method; + + CASE_CODE(CALL_0): + CASE_CODE(CALL_1): + CASE_CODE(CALL_2): + CASE_CODE(CALL_3): + CASE_CODE(CALL_4): + CASE_CODE(CALL_5): + CASE_CODE(CALL_6): + CASE_CODE(CALL_7): + CASE_CODE(CALL_8): + CASE_CODE(CALL_9): + CASE_CODE(CALL_10): + CASE_CODE(CALL_11): + CASE_CODE(CALL_12): + CASE_CODE(CALL_13): + CASE_CODE(CALL_14): + CASE_CODE(CALL_15): + CASE_CODE(CALL_16): + // Add one for the implicit receiver argument. + numArgs = instruction - CODE_CALL_0 + 1; + symbol = READ_SHORT(); + + // The receiver is the first argument. + args = fiber->stackTop - numArgs; + classObj = wrenGetClassInline(vm, args[0]); + goto completeCall; + + CASE_CODE(SUPER_0): + CASE_CODE(SUPER_1): + CASE_CODE(SUPER_2): + CASE_CODE(SUPER_3): + CASE_CODE(SUPER_4): + CASE_CODE(SUPER_5): + CASE_CODE(SUPER_6): + CASE_CODE(SUPER_7): + CASE_CODE(SUPER_8): + CASE_CODE(SUPER_9): + CASE_CODE(SUPER_10): + CASE_CODE(SUPER_11): + CASE_CODE(SUPER_12): + CASE_CODE(SUPER_13): + CASE_CODE(SUPER_14): + CASE_CODE(SUPER_15): + CASE_CODE(SUPER_16): + // Add one for the implicit receiver argument. + numArgs = instruction - CODE_SUPER_0 + 1; + symbol = READ_SHORT(); + + // The receiver is the first argument. + args = fiber->stackTop - numArgs; + + // The superclass is stored in a constant. + classObj = AS_CLASS(fn->constants.data[READ_SHORT()]); + goto completeCall; + + completeCall: + // If the class's method table doesn't include the symbol, bail. + if (symbol >= classObj->methods.count || + (method = &classObj->methods.data[symbol])->type == METHOD_NONE) + { + methodNotFound(vm, classObj, symbol); + RUNTIME_ERROR(); + } + + switch (method->type) + { + case METHOD_PRIMITIVE: + if (method->as.primitive(vm, args)) + { + // The result is now in the first arg slot. Discard the other + // stack slots. + fiber->stackTop -= numArgs - 1; + } else { + // An error, fiber switch, or call frame change occurred. + STORE_FRAME(); + + // If we don't have a fiber to switch to, stop interpreting. + fiber = vm->fiber; + if (fiber == NULL) return WREN_RESULT_SUCCESS; + if (wrenHasError(fiber)) RUNTIME_ERROR(); + LOAD_FRAME(); + } + break; + + case METHOD_FUNCTION_CALL: + if (!checkArity(vm, args[0], numArgs)) { + RUNTIME_ERROR(); + break; + } + + STORE_FRAME(); + method->as.primitive(vm, args); + LOAD_FRAME(); + break; + + case METHOD_FOREIGN: + callForeign(vm, fiber, method->as.foreign, numArgs); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + break; + + case METHOD_BLOCK: + STORE_FRAME(); + wrenCallFunction(vm, fiber, (ObjClosure*)method->as.closure, numArgs); + LOAD_FRAME(); + break; + + case METHOD_NONE: + UNREACHABLE(); + break; + } + DISPATCH(); + } + + CASE_CODE(LOAD_UPVALUE): + { + ObjUpvalue** upvalues = frame->closure->upvalues; + PUSH(*upvalues[READ_BYTE()]->value); + DISPATCH(); + } + + CASE_CODE(STORE_UPVALUE): + { + ObjUpvalue** upvalues = frame->closure->upvalues; + *upvalues[READ_BYTE()]->value = PEEK(); + DISPATCH(); + } + + CASE_CODE(LOAD_MODULE_VAR): + PUSH(fn->module->variables.data[READ_SHORT()]); + DISPATCH(); + + CASE_CODE(STORE_MODULE_VAR): + fn->module->variables.data[READ_SHORT()] = PEEK(); + DISPATCH(); + + CASE_CODE(STORE_FIELD_THIS): + { + uint8_t field = READ_BYTE(); + Value receiver = stackStart[0]; + ASSERT(IS_INSTANCE(receiver), "Receiver should be instance."); + ObjInstance* instance = AS_INSTANCE(receiver); + ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field."); + instance->fields[field] = PEEK(); + DISPATCH(); + } + + CASE_CODE(LOAD_FIELD): + { + uint8_t field = READ_BYTE(); + Value receiver = POP(); + ASSERT(IS_INSTANCE(receiver), "Receiver should be instance."); + ObjInstance* instance = AS_INSTANCE(receiver); + ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field."); + PUSH(instance->fields[field]); + DISPATCH(); + } + + CASE_CODE(STORE_FIELD): + { + uint8_t field = READ_BYTE(); + Value receiver = POP(); + ASSERT(IS_INSTANCE(receiver), "Receiver should be instance."); + ObjInstance* instance = AS_INSTANCE(receiver); + ASSERT(field < instance->obj.classObj->numFields, "Out of bounds field."); + instance->fields[field] = PEEK(); + DISPATCH(); + } + + CASE_CODE(JUMP): + { + uint16_t offset = READ_SHORT(); + ip += offset; + DISPATCH(); + } + + CASE_CODE(LOOP): + { + // Jump back to the top of the loop. + uint16_t offset = READ_SHORT(); + ip -= offset; + DISPATCH(); + } + + CASE_CODE(JUMP_IF): + { + uint16_t offset = READ_SHORT(); + Value condition = POP(); + + if (wrenIsFalsyValue(condition)) ip += offset; + DISPATCH(); + } + + CASE_CODE(AND): + { + uint16_t offset = READ_SHORT(); + Value condition = PEEK(); + + if (wrenIsFalsyValue(condition)) + { + // Short-circuit the right hand side. + ip += offset; + } + else + { + // Discard the condition and evaluate the right hand side. + DROP(); + } + DISPATCH(); + } + + CASE_CODE(OR): + { + uint16_t offset = READ_SHORT(); + Value condition = PEEK(); + + if (wrenIsFalsyValue(condition)) + { + // Discard the condition and evaluate the right hand side. + DROP(); + } + else + { + // Short-circuit the right hand side. + ip += offset; + } + DISPATCH(); + } + + CASE_CODE(CLOSE_UPVALUE): + // Close the upvalue for the local if we have one. + closeUpvalues(fiber, fiber->stackTop - 1); + DROP(); + DISPATCH(); + + CASE_CODE(RETURN): + { + Value result = POP(); + fiber->numFrames--; + + // Close any upvalues still in scope. + closeUpvalues(fiber, stackStart); + + // If the fiber is complete, end it. + if (fiber->numFrames == 0) + { + // See if there's another fiber to return to. If not, we're done. + if (fiber->caller == NULL) + { + // Store the final result value at the beginning of the stack so the + // C API can get it. + fiber->stack[0] = result; + fiber->stackTop = fiber->stack + 1; + return WREN_RESULT_SUCCESS; + } + + ObjFiber* resumingFiber = fiber->caller; + fiber->caller = NULL; + fiber = resumingFiber; + vm->fiber = resumingFiber; + + // Store the result in the resuming fiber. + fiber->stackTop[-1] = result; + } + else + { + // Store the result of the block in the first slot, which is where the + // caller expects it. + stackStart[0] = result; + + // Discard the stack slots for the call frame (leaving one slot for the + // result). + fiber->stackTop = frame->stackStart + 1; + } + + LOAD_FRAME(); + DISPATCH(); + } + + CASE_CODE(CONSTRUCT): + ASSERT(IS_CLASS(stackStart[0]), "'this' should be a class."); + stackStart[0] = wrenNewInstance(vm, AS_CLASS(stackStart[0])); + DISPATCH(); + + CASE_CODE(FOREIGN_CONSTRUCT): + ASSERT(IS_CLASS(stackStart[0]), "'this' should be a class."); + createForeign(vm, fiber, stackStart); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DISPATCH(); + + CASE_CODE(CLOSURE): + { + // Create the closure and push it on the stack before creating upvalues + // so that it doesn't get collected. + ObjFn* function = AS_FN(fn->constants.data[READ_SHORT()]); + ObjClosure* closure = wrenNewClosure(vm, function); + PUSH(OBJ_VAL(closure)); + + // Capture upvalues, if any. + for (int i = 0; i < function->numUpvalues; i++) + { + uint8_t isLocal = READ_BYTE(); + uint8_t index = READ_BYTE(); + if (isLocal) + { + // Make an new upvalue to close over the parent's local variable. + closure->upvalues[i] = captureUpvalue(vm, fiber, + frame->stackStart + index); + } + else + { + // Use the same upvalue as the current call frame. + closure->upvalues[i] = frame->closure->upvalues[index]; + } + } + DISPATCH(); + } + + CASE_CODE(END_CLASS): + { + endClass(vm); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DISPATCH(); + } + + CASE_CODE(CLASS): + { + createClass(vm, READ_BYTE(), NULL); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DISPATCH(); + } + + CASE_CODE(FOREIGN_CLASS): + { + createClass(vm, -1, fn->module); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DISPATCH(); + } + + CASE_CODE(METHOD_INSTANCE): + CASE_CODE(METHOD_STATIC): + { + uint16_t symbol = READ_SHORT(); + ObjClass* classObj = AS_CLASS(PEEK()); + Value method = PEEK2(); + bindMethod(vm, instruction, symbol, fn->module, classObj, method); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + DROP(); + DROP(); + DISPATCH(); + } + + CASE_CODE(END_MODULE): + { + vm->lastModule = fn->module; + PUSH(NULL_VAL); + DISPATCH(); + } + + CASE_CODE(IMPORT_MODULE): + { + // Make a slot on the stack for the module's fiber to place the return + // value. It will be popped after this fiber is resumed. Store the + // imported module's closure in the slot in case a GC happens when + // invoking the closure. + PUSH(importModule(vm, fn->constants.data[READ_SHORT()])); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + + // If we get a closure, call it to execute the module body. + if (IS_CLOSURE(PEEK())) + { + STORE_FRAME(); + ObjClosure* closure = AS_CLOSURE(PEEK()); + wrenCallFunction(vm, fiber, closure, 1); + LOAD_FRAME(); + } + else + { + // The module has already been loaded. Remember it so we can import + // variables from it if needed. + vm->lastModule = AS_MODULE(PEEK()); + } + + DISPATCH(); + } + + CASE_CODE(IMPORT_VARIABLE): + { + Value variable = fn->constants.data[READ_SHORT()]; + ASSERT(vm->lastModule != NULL, "Should have already imported module."); + Value result = getModuleVariable(vm, vm->lastModule, variable); + if (wrenHasError(fiber)) RUNTIME_ERROR(); + + PUSH(result); + DISPATCH(); + } + + CASE_CODE(END): + // A CODE_END should always be preceded by a CODE_RETURN. If we get here, + // the compiler generated wrong code. + UNREACHABLE(); + } + + // We should only exit this function from an explicit return from CODE_RETURN + // or a runtime error. + UNREACHABLE(); + return WREN_RESULT_RUNTIME_ERROR; + + #undef READ_BYTE + #undef READ_SHORT +} + +WrenHandle* wrenMakeCallHandle(WrenVM* vm, const char* signature) +{ + ASSERT(signature != NULL, "Signature cannot be NULL."); + + int signatureLength = (int)strlen(signature); + ASSERT(signatureLength > 0, "Signature cannot be empty."); + + // Count the number parameters the method expects. + int numParams = 0; + if (signature[signatureLength - 1] == ')') + { + for (int i = signatureLength - 1; i > 0 && signature[i] != '('; i--) + { + if (signature[i] == '_') numParams++; + } + } + + // Count subscript arguments. + if (signature[0] == '[') + { + for (int i = 0; i < signatureLength && signature[i] != ']'; i++) + { + if (signature[i] == '_') numParams++; + } + } + + // Add the signatue to the method table. + int method = wrenSymbolTableEnsure(vm, &vm->methodNames, + signature, signatureLength); + + // Create a little stub function that assumes the arguments are on the stack + // and calls the method. + ObjFn* fn = wrenNewFunction(vm, NULL, numParams + 1); + + // Wrap the function in a closure and then in a handle. Do this here so it + // doesn't get collected as we fill it in. + WrenHandle* value = wrenMakeHandle(vm, OBJ_VAL(fn)); + value->value = OBJ_VAL(wrenNewClosure(vm, fn)); + + wrenByteBufferWrite(vm, &fn->code, (uint8_t)(CODE_CALL_0 + numParams)); + wrenByteBufferWrite(vm, &fn->code, (method >> 8) & 0xff); + wrenByteBufferWrite(vm, &fn->code, method & 0xff); + wrenByteBufferWrite(vm, &fn->code, CODE_RETURN); + wrenByteBufferWrite(vm, &fn->code, CODE_END); + wrenIntBufferFill(vm, &fn->debug->sourceLines, 0, 5); + wrenFunctionBindName(vm, fn, signature, signatureLength); + + return value; +} + +WrenInterpretResult wrenCall(WrenVM* vm, WrenHandle* method) +{ + ASSERT(method != NULL, "Method cannot be NULL."); + ASSERT(IS_CLOSURE(method->value), "Method must be a method handle."); + ASSERT(vm->fiber != NULL, "Must set up arguments for call first."); + ASSERT(vm->apiStack != NULL, "Must set up arguments for call first."); + ASSERT(vm->fiber->numFrames == 0, "Can not call from a foreign method."); + + ObjClosure* closure = AS_CLOSURE(method->value); + + ASSERT(vm->fiber->stackTop - vm->fiber->stack >= closure->fn->arity, + "Stack must have enough arguments for method."); + + // Clear the API stack. Now that wrenCall() has control, we no longer need + // it. We use this being non-null to tell if re-entrant calls to foreign + // methods are happening, so it's important to clear it out now so that you + // can call foreign methods from within calls to wrenCall(). + vm->apiStack = NULL; + + // Discard any extra temporary slots. We take for granted that the stub + // function has exactly one slot for each argument. + vm->fiber->stackTop = &vm->fiber->stack[closure->fn->maxSlots]; + + wrenCallFunction(vm, vm->fiber, closure, 0); + WrenInterpretResult result = runInterpreter(vm, vm->fiber); + + // If the call didn't abort, then set up the API stack to point to the + // beginning of the stack so the host can access the call's return value. + if (vm->fiber != NULL) vm->apiStack = vm->fiber->stack; + + return result; +} + +WrenHandle* wrenMakeHandle(WrenVM* vm, Value value) +{ + if (IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + // Make a handle for it. + WrenHandle* handle = ALLOCATE(vm, WrenHandle); + handle->value = value; + + if (IS_OBJ(value)) wrenPopRoot(vm); + + // Add it to the front of the linked list of handles. + if (vm->handles != NULL) vm->handles->prev = handle; + handle->prev = NULL; + handle->next = vm->handles; + vm->handles = handle; + + return handle; +} + +void wrenReleaseHandle(WrenVM* vm, WrenHandle* handle) +{ + ASSERT(handle != NULL, "Handle cannot be NULL."); + + // Update the VM's head pointer if we're releasing the first handle. + if (vm->handles == handle) vm->handles = handle->next; + + // Unlink it from the list. + if (handle->prev != NULL) handle->prev->next = handle->next; + if (handle->next != NULL) handle->next->prev = handle->prev; + + // Clear it out. This isn't strictly necessary since we're going to free it, + // but it makes for easier debugging. + handle->prev = NULL; + handle->next = NULL; + handle->value = NULL_VAL; + DEALLOCATE(vm, handle); +} + +WrenInterpretResult wrenInterpret(WrenVM* vm, const char* module, + const char* source) +{ + ObjClosure* closure = wrenCompileSource(vm, module, source, false, true); + if (closure == NULL) return WREN_RESULT_COMPILE_ERROR; + + wrenPushRoot(vm, (Obj*)closure); + ObjFiber* fiber = wrenNewFiber(vm, closure); + wrenPopRoot(vm); // closure. + vm->apiStack = NULL; + + return runInterpreter(vm, fiber); +} + +ObjClosure* wrenCompileSource(WrenVM* vm, const char* module, const char* source, + bool isExpression, bool printErrors) +{ + Value nameValue = NULL_VAL; + if (module != NULL) + { + nameValue = wrenNewString(vm, module); + wrenPushRoot(vm, AS_OBJ(nameValue)); + } + + ObjClosure* closure = compileInModule(vm, nameValue, source, + isExpression, printErrors); + + if (module != NULL) wrenPopRoot(vm); // nameValue. + return closure; +} + +Value wrenGetModuleVariable(WrenVM* vm, Value moduleName, Value variableName) +{ + ObjModule* module = getModule(vm, moduleName); + if (module == NULL) + { + vm->fiber->error = wrenStringFormat(vm, "Module '@' is not loaded.", + moduleName); + return NULL_VAL; + } + + return getModuleVariable(vm, module, variableName); +} + +Value wrenFindVariable(WrenVM* vm, ObjModule* module, const char* name) +{ + int symbol = wrenSymbolTableFind(&module->variableNames, name, strlen(name)); + return module->variables.data[symbol]; +} + +int wrenDeclareVariable(WrenVM* vm, ObjModule* module, const char* name, + size_t length, int line) +{ + if (module->variables.count == MAX_MODULE_VARS) return -2; + + // Implicitly defined variables get a "value" that is the line where the + // variable is first used. We'll use that later to report an error on the + // right line. + wrenValueBufferWrite(vm, &module->variables, NUM_VAL(line)); + return wrenSymbolTableAdd(vm, &module->variableNames, name, length); +} + +int wrenDefineVariable(WrenVM* vm, ObjModule* module, const char* name, + size_t length, Value value, int* line) +{ + if (module->variables.count == MAX_MODULE_VARS) return -2; + + if (IS_OBJ(value)) wrenPushRoot(vm, AS_OBJ(value)); + + // See if the variable is already explicitly or implicitly declared. + int symbol = wrenSymbolTableFind(&module->variableNames, name, length); + + if (symbol == -1) + { + // Brand new variable. + symbol = wrenSymbolTableAdd(vm, &module->variableNames, name, length); + wrenValueBufferWrite(vm, &module->variables, value); + } + else if (IS_NUM(module->variables.data[symbol])) + { + // An implicitly declared variable's value will always be a number. + // Now we have a real definition. + if(line) *line = (int)AS_NUM(module->variables.data[symbol]); + module->variables.data[symbol] = value; + + // If this was a localname we want to error if it was + // referenced before this definition. + if (wrenIsLocalName(name)) symbol = -3; + } + else + { + // Already explicitly declared. + symbol = -1; + } + + if (IS_OBJ(value)) wrenPopRoot(vm); + + return symbol; +} + +// TODO: Inline? +void wrenPushRoot(WrenVM* vm, Obj* obj) +{ + ASSERT(obj != NULL, "Can't root NULL."); + ASSERT(vm->numTempRoots < WREN_MAX_TEMP_ROOTS, "Too many temporary roots."); + + vm->tempRoots[vm->numTempRoots++] = obj; +} + +void wrenPopRoot(WrenVM* vm) +{ + ASSERT(vm->numTempRoots > 0, "No temporary roots to release."); + vm->numTempRoots--; +} + +int wrenGetSlotCount(WrenVM* vm) +{ + if (vm->apiStack == NULL) return 0; + + return (int)(vm->fiber->stackTop - vm->apiStack); +} + +void wrenEnsureSlots(WrenVM* vm, int numSlots) +{ + // If we don't have a fiber accessible, create one for the API to use. + if (vm->apiStack == NULL) + { + vm->fiber = wrenNewFiber(vm, NULL); + vm->apiStack = vm->fiber->stack; + } + + int currentSize = (int)(vm->fiber->stackTop - vm->apiStack); + if (currentSize >= numSlots) return; + + // Grow the stack if needed. + int needed = (int)(vm->apiStack - vm->fiber->stack) + numSlots; + wrenEnsureStack(vm, vm->fiber, needed); + + vm->fiber->stackTop = vm->apiStack + numSlots; +} + +// Ensures that [slot] is a valid index into the API's stack of slots. +static void validateApiSlot(WrenVM* vm, int slot) +{ + ASSERT(slot >= 0, "Slot cannot be negative."); + ASSERT(slot < wrenGetSlotCount(vm), "Not that many slots."); +} + +// Gets the type of the object in [slot]. +WrenType wrenGetSlotType(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + if (IS_BOOL(vm->apiStack[slot])) return WREN_TYPE_BOOL; + if (IS_NUM(vm->apiStack[slot])) return WREN_TYPE_NUM; + if (IS_FOREIGN(vm->apiStack[slot])) return WREN_TYPE_FOREIGN; + if (IS_LIST(vm->apiStack[slot])) return WREN_TYPE_LIST; + if (IS_MAP(vm->apiStack[slot])) return WREN_TYPE_MAP; + if (IS_NULL(vm->apiStack[slot])) return WREN_TYPE_NULL; + if (IS_STRING(vm->apiStack[slot])) return WREN_TYPE_STRING; + + return WREN_TYPE_UNKNOWN; +} + +bool wrenGetSlotBool(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_BOOL(vm->apiStack[slot]), "Slot must hold a bool."); + + return AS_BOOL(vm->apiStack[slot]); +} + +const char* wrenGetSlotBytes(WrenVM* vm, int slot, int* length) +{ + validateApiSlot(vm, slot); + ASSERT(IS_STRING(vm->apiStack[slot]), "Slot must hold a string."); + + ObjString* string = AS_STRING(vm->apiStack[slot]); + *length = string->length; + return string->value; +} + +double wrenGetSlotDouble(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_NUM(vm->apiStack[slot]), "Slot must hold a number."); + + return AS_NUM(vm->apiStack[slot]); +} + +void* wrenGetSlotForeign(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_FOREIGN(vm->apiStack[slot]), + "Slot must hold a foreign instance."); + + return AS_FOREIGN(vm->apiStack[slot])->data; +} + +const char* wrenGetSlotString(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_STRING(vm->apiStack[slot]), "Slot must hold a string."); + + return AS_CSTRING(vm->apiStack[slot]); +} + +WrenHandle* wrenGetSlotHandle(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + return wrenMakeHandle(vm, vm->apiStack[slot]); +} + +// Stores [value] in [slot] in the foreign call stack. +static void setSlot(WrenVM* vm, int slot, Value value) +{ + validateApiSlot(vm, slot); + vm->apiStack[slot] = value; +} + +void wrenSetSlotBool(WrenVM* vm, int slot, bool value) +{ + setSlot(vm, slot, BOOL_VAL(value)); +} + +void wrenSetSlotBytes(WrenVM* vm, int slot, const char* bytes, size_t length) +{ + ASSERT(bytes != NULL, "Byte array cannot be NULL."); + setSlot(vm, slot, wrenNewStringLength(vm, bytes, length)); +} + +void wrenSetSlotDouble(WrenVM* vm, int slot, double value) +{ + setSlot(vm, slot, NUM_VAL(value)); +} + +void* wrenSetSlotNewForeign(WrenVM* vm, int slot, int classSlot, size_t size) +{ + validateApiSlot(vm, slot); + validateApiSlot(vm, classSlot); + ASSERT(IS_CLASS(vm->apiStack[classSlot]), "Slot must hold a class."); + + ObjClass* classObj = AS_CLASS(vm->apiStack[classSlot]); + ASSERT(classObj->numFields == -1, "Class must be a foreign class."); + + ObjForeign* foreign = wrenNewForeign(vm, classObj, size); + vm->apiStack[slot] = OBJ_VAL(foreign); + + return (void*)foreign->data; +} + +void wrenSetSlotNewList(WrenVM* vm, int slot) +{ + setSlot(vm, slot, OBJ_VAL(wrenNewList(vm, 0))); +} + +void wrenSetSlotNewMap(WrenVM* vm, int slot) +{ + setSlot(vm, slot, OBJ_VAL(wrenNewMap(vm))); +} + +void wrenSetSlotNull(WrenVM* vm, int slot) +{ + setSlot(vm, slot, NULL_VAL); +} + +void wrenSetSlotString(WrenVM* vm, int slot, const char* text) +{ + ASSERT(text != NULL, "String cannot be NULL."); + + setSlot(vm, slot, wrenNewString(vm, text)); +} + +void wrenSetSlotHandle(WrenVM* vm, int slot, WrenHandle* handle) +{ + ASSERT(handle != NULL, "Handle cannot be NULL."); + + setSlot(vm, slot, handle->value); +} + +int wrenGetListCount(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_LIST(vm->apiStack[slot]), "Slot must hold a list."); + + ValueBuffer elements = AS_LIST(vm->apiStack[slot])->elements; + return elements.count; +} + +void wrenGetListElement(WrenVM* vm, int listSlot, int index, int elementSlot) +{ + validateApiSlot(vm, listSlot); + validateApiSlot(vm, elementSlot); + ASSERT(IS_LIST(vm->apiStack[listSlot]), "Slot must hold a list."); + + ValueBuffer elements = AS_LIST(vm->apiStack[listSlot])->elements; + + uint32_t usedIndex = wrenValidateIndex(elements.count, index); + ASSERT(usedIndex != UINT32_MAX, "Index out of bounds."); + + vm->apiStack[elementSlot] = elements.data[usedIndex]; +} + +void wrenSetListElement(WrenVM* vm, int listSlot, int index, int elementSlot) +{ + validateApiSlot(vm, listSlot); + validateApiSlot(vm, elementSlot); + ASSERT(IS_LIST(vm->apiStack[listSlot]), "Slot must hold a list."); + + ObjList* list = AS_LIST(vm->apiStack[listSlot]); + + uint32_t usedIndex = wrenValidateIndex(list->elements.count, index); + ASSERT(usedIndex != UINT32_MAX, "Index out of bounds."); + + list->elements.data[usedIndex] = vm->apiStack[elementSlot]; +} + +void wrenInsertInList(WrenVM* vm, int listSlot, int index, int elementSlot) +{ + validateApiSlot(vm, listSlot); + validateApiSlot(vm, elementSlot); + ASSERT(IS_LIST(vm->apiStack[listSlot]), "Must insert into a list."); + + ObjList* list = AS_LIST(vm->apiStack[listSlot]); + + // Negative indices count from the end. + // We don't use wrenValidateIndex here because insert allows 1 past the end. + if (index < 0) index = list->elements.count + 1 + index; + + ASSERT(index <= list->elements.count, "Index out of bounds."); + + wrenListInsert(vm, list, vm->apiStack[elementSlot], index); +} + +int wrenGetMapCount(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + ASSERT(IS_MAP(vm->apiStack[slot]), "Slot must hold a map."); + + ObjMap* map = AS_MAP(vm->apiStack[slot]); + return map->count; +} + +bool wrenGetMapContainsKey(WrenVM* vm, int mapSlot, int keySlot) +{ + validateApiSlot(vm, mapSlot); + validateApiSlot(vm, keySlot); + ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Slot must hold a map."); + + Value key = vm->apiStack[keySlot]; + ASSERT(wrenMapIsValidKey(key), "Key must be a value type"); + if (!validateKey(vm, key)) return false; + + ObjMap* map = AS_MAP(vm->apiStack[mapSlot]); + Value value = wrenMapGet(map, key); + + return !IS_UNDEFINED(value); +} + +void wrenGetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot) +{ + validateApiSlot(vm, mapSlot); + validateApiSlot(vm, keySlot); + validateApiSlot(vm, valueSlot); + ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Slot must hold a map."); + + ObjMap* map = AS_MAP(vm->apiStack[mapSlot]); + Value value = wrenMapGet(map, vm->apiStack[keySlot]); + if (IS_UNDEFINED(value)) { + value = NULL_VAL; + } + + vm->apiStack[valueSlot] = value; +} + +void wrenSetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot) +{ + validateApiSlot(vm, mapSlot); + validateApiSlot(vm, keySlot); + validateApiSlot(vm, valueSlot); + ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Must insert into a map."); + + Value key = vm->apiStack[keySlot]; + ASSERT(wrenMapIsValidKey(key), "Key must be a value type"); + + if (!validateKey(vm, key)) { + return; + } + + Value value = vm->apiStack[valueSlot]; + ObjMap* map = AS_MAP(vm->apiStack[mapSlot]); + + wrenMapSet(vm, map, key, value); +} + +void wrenRemoveMapValue(WrenVM* vm, int mapSlot, int keySlot, + int removedValueSlot) +{ + validateApiSlot(vm, mapSlot); + validateApiSlot(vm, keySlot); + ASSERT(IS_MAP(vm->apiStack[mapSlot]), "Slot must hold a map."); + + Value key = vm->apiStack[keySlot]; + if (!validateKey(vm, key)) { + return; + } + + ObjMap* map = AS_MAP(vm->apiStack[mapSlot]); + Value removed = wrenMapRemoveKey(vm, map, key); + setSlot(vm, removedValueSlot, removed); +} + +void wrenGetVariable(WrenVM* vm, const char* module, const char* name, + int slot) +{ + ASSERT(module != NULL, "Module cannot be NULL."); + ASSERT(name != NULL, "Variable name cannot be NULL."); + + Value moduleName = wrenStringFormat(vm, "$", module); + wrenPushRoot(vm, AS_OBJ(moduleName)); + + ObjModule* moduleObj = getModule(vm, moduleName); + ASSERT(moduleObj != NULL, "Could not find module."); + + wrenPopRoot(vm); // moduleName. + + int variableSlot = wrenSymbolTableFind(&moduleObj->variableNames, + name, strlen(name)); + ASSERT(variableSlot != -1, "Could not find variable."); + + setSlot(vm, slot, moduleObj->variables.data[variableSlot]); +} + +bool wrenHasVariable(WrenVM* vm, const char* module, const char* name) +{ + ASSERT(module != NULL, "Module cannot be NULL."); + ASSERT(name != NULL, "Variable name cannot be NULL."); + + Value moduleName = wrenStringFormat(vm, "$", module); + wrenPushRoot(vm, AS_OBJ(moduleName)); + + //We don't use wrenHasModule since we want to use the module object. + ObjModule* moduleObj = getModule(vm, moduleName); + ASSERT(moduleObj != NULL, "Could not find module."); + + wrenPopRoot(vm); // moduleName. + + int variableSlot = wrenSymbolTableFind(&moduleObj->variableNames, + name, strlen(name)); + + return variableSlot != -1; +} + +bool wrenHasModule(WrenVM* vm, const char* module) +{ + ASSERT(module != NULL, "Module cannot be NULL."); + + Value moduleName = wrenStringFormat(vm, "$", module); + wrenPushRoot(vm, AS_OBJ(moduleName)); + + ObjModule* moduleObj = getModule(vm, moduleName); + + wrenPopRoot(vm); // moduleName. + + return moduleObj != NULL; +} + +void wrenAbortFiber(WrenVM* vm, int slot) +{ + validateApiSlot(vm, slot); + vm->fiber->error = vm->apiStack[slot]; +} + +void* wrenGetUserData(WrenVM* vm) +{ + return vm->config.userData; +} + +void wrenSetUserData(WrenVM* vm, void* userData) +{ + vm->config.userData = userData; +} +// End file "wren_vm.c" +// Begin file "wren_opt_random.c" + +#if WREN_OPT_RANDOM + +#include +#include + + +// Begin file "wren_opt_random.wren.inc" +// Generated automatically from src/optional/wren_opt_random.wren. Do not edit. +static const char* randomModuleSource = +"foreign class Random {\n" +" construct new() {\n" +" seed_()\n" +" }\n" +"\n" +" construct new(seed) {\n" +" if (seed is Num) {\n" +" seed_(seed)\n" +" } else if (seed is Sequence) {\n" +" if (seed.isEmpty) Fiber.abort(\"Sequence cannot be empty.\")\n" +"\n" +" // TODO: Empty sequence.\n" +" var seeds = []\n" +" for (element in seed) {\n" +" if (!(element is Num)) Fiber.abort(\"Sequence elements must all be numbers.\")\n" +"\n" +" seeds.add(element)\n" +" if (seeds.count == 16) break\n" +" }\n" +"\n" +" // Cycle the values to fill in any missing slots.\n" +" var i = 0\n" +" while (seeds.count < 16) {\n" +" seeds.add(seeds[i])\n" +" i = i + 1\n" +" }\n" +"\n" +" seed_(\n" +" seeds[0], seeds[1], seeds[2], seeds[3],\n" +" seeds[4], seeds[5], seeds[6], seeds[7],\n" +" seeds[8], seeds[9], seeds[10], seeds[11],\n" +" seeds[12], seeds[13], seeds[14], seeds[15])\n" +" } else {\n" +" Fiber.abort(\"Seed must be a number or a sequence of numbers.\")\n" +" }\n" +" }\n" +"\n" +" foreign seed_()\n" +" foreign seed_(seed)\n" +" foreign seed_(n1, n2, n3, n4, n5, n6, n7, n8, n9, n10, n11, n12, n13, n14, n15, n16)\n" +"\n" +" foreign float()\n" +" float(end) { float() * end }\n" +" float(start, end) { float() * (end - start) + start }\n" +"\n" +" foreign int()\n" +" int(end) { (float() * end).floor }\n" +" int(start, end) { (float() * (end - start)).floor + start }\n" +"\n" +" sample(list) {\n" +" if (list.count == 0) Fiber.abort(\"Not enough elements to sample.\")\n" +" return list[int(list.count)]\n" +" }\n" +" sample(list, count) {\n" +" if (count > list.count) Fiber.abort(\"Not enough elements to sample.\")\n" +"\n" +" var result = []\n" +"\n" +" // The algorithm described in \"Programming pearls: a sample of brilliance\".\n" +" // Use a hash map for sample sizes less than 1/4 of the population size and\n" +" // an array of booleans for larger samples. This simple heuristic improves\n" +" // performance for large sample sizes as well as reduces memory usage.\n" +" if (count * 4 < list.count) {\n" +" var picked = {}\n" +" for (i in list.count - count...list.count) {\n" +" var index = int(i + 1)\n" +" if (picked.containsKey(index)) index = i\n" +" picked[index] = true\n" +" result.add(list[index])\n" +" }\n" +" } else {\n" +" var picked = List.filled(list.count, false)\n" +" for (i in list.count - count...list.count) {\n" +" var index = int(i + 1)\n" +" if (picked[index]) index = i\n" +" picked[index] = true\n" +" result.add(list[index])\n" +" }\n" +" }\n" +"\n" +" return result\n" +" }\n" +"\n" +" shuffle(list) {\n" +" if (list.isEmpty) return\n" +"\n" +" // Fisher-Yates shuffle.\n" +" for (i in 0...list.count - 1) {\n" +" var from = int(i, list.count)\n" +" var temp = list[from]\n" +" list[from] = list[i]\n" +" list[i] = temp\n" +" }\n" +" }\n" +"}\n"; +// End file "wren_opt_random.wren.inc" + +// Implements the well equidistributed long-period linear PRNG (WELL512a). +// +// https://en.wikipedia.org/wiki/Well_equidistributed_long-period_linear +typedef struct +{ + uint32_t state[16]; + uint32_t index; +} Well512; + +// Code from: http://www.lomont.org/Math/Papers/2008/Lomont_PRNG_2008.pdf +static uint32_t advanceState(Well512* well) +{ + uint32_t a, b, c, d; + a = well->state[well->index]; + c = well->state[(well->index + 13) & 15]; + b = a ^ c ^ (a << 16) ^ (c << 15); + c = well->state[(well->index + 9) & 15]; + c ^= (c >> 11); + a = well->state[well->index] = b ^ c; + d = a ^ ((a << 5) & 0xda442d24U); + + well->index = (well->index + 15) & 15; + a = well->state[well->index]; + well->state[well->index] = a ^ b ^ d ^ (a << 2) ^ (b << 18) ^ (c << 28); + return well->state[well->index]; +} + +static void randomAllocate(WrenVM* vm) +{ + Well512* well = (Well512*)wrenSetSlotNewForeign(vm, 0, 0, sizeof(Well512)); + well->index = 0; +} + +static void randomSeed0(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + srand((uint32_t)time(NULL)); + for (int i = 0; i < 16; i++) + { + well->state[i] = rand(); + } +} + +static void randomSeed1(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + srand((uint32_t)wrenGetSlotDouble(vm, 1)); + for (int i = 0; i < 16; i++) + { + well->state[i] = rand(); + } +} + +static void randomSeed16(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + for (int i = 0; i < 16; i++) + { + well->state[i] = (uint32_t)wrenGetSlotDouble(vm, i + 1); + } +} + +static void randomFloat(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + // A double has 53 bits of precision in its mantissa, and we'd like to take + // full advantage of that, so we need 53 bits of random source data. + + // First, start with 32 random bits, shifted to the left 21 bits. + double result = (double)advanceState(well) * (1 << 21); + + // Then add another 21 random bits. + result += (double)(advanceState(well) & ((1 << 21) - 1)); + + // Now we have a number from 0 - (2^53). Divide be the range to get a double + // from 0 to 1.0 (half-inclusive). + result /= 9007199254740992.0; + + wrenSetSlotDouble(vm, 0, result); +} + +static void randomInt0(WrenVM* vm) +{ + Well512* well = (Well512*)wrenGetSlotForeign(vm, 0); + + wrenSetSlotDouble(vm, 0, (double)advanceState(well)); +} + +const char* wrenRandomSource() +{ + return randomModuleSource; +} + +WrenForeignClassMethods wrenRandomBindForeignClass(WrenVM* vm, + const char* module, + const char* className) +{ + ASSERT(strcmp(className, "Random") == 0, "Should be in Random class."); + WrenForeignClassMethods methods; + methods.allocate = randomAllocate; + methods.finalize = NULL; + return methods; +} + +WrenForeignMethodFn wrenRandomBindForeignMethod(WrenVM* vm, + const char* className, + bool isStatic, + const char* signature) +{ + ASSERT(strcmp(className, "Random") == 0, "Should be in Random class."); + + if (strcmp(signature, "") == 0) return randomAllocate; + if (strcmp(signature, "seed_()") == 0) return randomSeed0; + if (strcmp(signature, "seed_(_)") == 0) return randomSeed1; + + if (strcmp(signature, "seed_(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)") == 0) + { + return randomSeed16; + } + + if (strcmp(signature, "float()") == 0) return randomFloat; + if (strcmp(signature, "int()") == 0) return randomInt0; + + ASSERT(false, "Unknown method."); + return NULL; +} + +#endif +// End file "wren_opt_random.c" +// Begin file "wren_opt_meta.c" + +#if WREN_OPT_META + +#include + +// Begin file "wren_opt_meta.wren.inc" +// Generated automatically from src/optional/wren_opt_meta.wren. Do not edit. +static const char* metaModuleSource = +"class Meta {\n" +" static getModuleVariables(module) {\n" +" if (!(module is String)) Fiber.abort(\"Module name must be a string.\")\n" +" var result = getModuleVariables_(module)\n" +" if (result != null) return result\n" +"\n" +" Fiber.abort(\"Could not find a module named '%(module)'.\")\n" +" }\n" +"\n" +" static eval(source) {\n" +" if (!(source is String)) Fiber.abort(\"Source code must be a string.\")\n" +"\n" +" var closure = compile_(source, false, false)\n" +" // TODO: Include compile errors.\n" +" if (closure == null) Fiber.abort(\"Could not compile source code.\")\n" +"\n" +" closure.call()\n" +" }\n" +"\n" +" static compileExpression(source) {\n" +" if (!(source is String)) Fiber.abort(\"Source code must be a string.\")\n" +" return compile_(source, true, true)\n" +" }\n" +"\n" +" static compile(source) {\n" +" if (!(source is String)) Fiber.abort(\"Source code must be a string.\")\n" +" return compile_(source, false, true)\n" +" }\n" +"\n" +" foreign static compile_(source, isExpression, printErrors)\n" +" foreign static getModuleVariables_(module)\n" +"}\n"; +// End file "wren_opt_meta.wren.inc" + +void metaCompile(WrenVM* vm) +{ + const char* source = wrenGetSlotString(vm, 1); + bool isExpression = wrenGetSlotBool(vm, 2); + bool printErrors = wrenGetSlotBool(vm, 3); + + // TODO: Allow passing in module? + // Look up the module surrounding the callsite. This is brittle. The -2 walks + // up the callstack assuming that the meta module has one level of + // indirection before hitting the user's code. Any change to meta may require + // this constant to be tweaked. + ObjFiber* currentFiber = vm->fiber; + ObjFn* fn = currentFiber->frames[currentFiber->numFrames - 2].closure->fn; + ObjString* module = fn->module->name; + + ObjClosure* closure = wrenCompileSource(vm, module->value, source, + isExpression, printErrors); + + // Return the result. We can't use the public API for this since we have a + // bare ObjClosure*. + if (closure == NULL) + { + vm->apiStack[0] = NULL_VAL; + } + else + { + vm->apiStack[0] = OBJ_VAL(closure); + } +} + +void metaGetModuleVariables(WrenVM* vm) +{ + wrenEnsureSlots(vm, 3); + + Value moduleValue = wrenMapGet(vm->modules, vm->apiStack[1]); + if (IS_UNDEFINED(moduleValue)) + { + vm->apiStack[0] = NULL_VAL; + return; + } + + ObjModule* module = AS_MODULE(moduleValue); + ObjList* names = wrenNewList(vm, module->variableNames.count); + vm->apiStack[0] = OBJ_VAL(names); + + // Initialize the elements to null in case a collection happens when we + // allocate the strings below. + for (int i = 0; i < names->elements.count; i++) + { + names->elements.data[i] = NULL_VAL; + } + + for (int i = 0; i < names->elements.count; i++) + { + names->elements.data[i] = OBJ_VAL(module->variableNames.data[i]); + } +} + +const char* wrenMetaSource() +{ + return metaModuleSource; +} + +WrenForeignMethodFn wrenMetaBindForeignMethod(WrenVM* vm, + const char* className, + bool isStatic, + const char* signature) +{ + // There is only one foreign method in the meta module. + ASSERT(strcmp(className, "Meta") == 0, "Should be in Meta class."); + ASSERT(isStatic, "Should be static."); + + if (strcmp(signature, "compile_(_,_,_)") == 0) + { + return metaCompile; + } + + if (strcmp(signature, "getModuleVariables_(_)") == 0) + { + return metaGetModuleVariables; + } + + ASSERT(false, "Unknown method."); + return NULL; +} + +#endif +// End file "wren_opt_meta.c" diff --git a/wren.h b/wren.h new file mode 100644 index 0000000..7845911 --- /dev/null +++ b/wren.h @@ -0,0 +1,554 @@ +#ifndef wren_h +#define wren_h + +#include +#include +#include + +// The Wren semantic version number components. +#define WREN_VERSION_MAJOR 0 +#define WREN_VERSION_MINOR 4 +#define WREN_VERSION_PATCH 0 + +// A human-friendly string representation of the version. +#define WREN_VERSION_STRING "0.4.0" + +// A monotonically increasing numeric representation of the version number. Use +// this if you want to do range checks over versions. +#define WREN_VERSION_NUMBER (WREN_VERSION_MAJOR * 1000000 + \ + WREN_VERSION_MINOR * 1000 + \ + WREN_VERSION_PATCH) + +#ifndef WREN_API + #if defined(_MSC_VER) && defined(WREN_API_DLLEXPORT) + #define WREN_API __declspec( dllexport ) + #else + #define WREN_API + #endif +#endif //WREN_API + +// A single virtual machine for executing Wren code. +// +// Wren has no global state, so all state stored by a running interpreter lives +// here. +typedef struct WrenVM WrenVM; + +// A handle to a Wren object. +// +// This lets code outside of the VM hold a persistent reference to an object. +// After a handle is acquired, and until it is released, this ensures the +// garbage collector will not reclaim the object it references. +typedef struct WrenHandle WrenHandle; + +// A generic allocation function that handles all explicit memory management +// used by Wren. It's used like so: +// +// - To allocate new memory, [memory] is NULL and [newSize] is the desired +// size. It should return the allocated memory or NULL on failure. +// +// - To attempt to grow an existing allocation, [memory] is the memory, and +// [newSize] is the desired size. It should return [memory] if it was able to +// grow it in place, or a new pointer if it had to move it. +// +// - To shrink memory, [memory] and [newSize] are the same as above but it will +// always return [memory]. +// +// - To free memory, [memory] will be the memory to free and [newSize] will be +// zero. It should return NULL. +typedef void* (*WrenReallocateFn)(void* memory, size_t newSize, void* userData); + +// A function callable from Wren code, but implemented in C. +typedef void (*WrenForeignMethodFn)(WrenVM* vm); + +// A finalizer function for freeing resources owned by an instance of a foreign +// class. Unlike most foreign methods, finalizers do not have access to the VM +// and should not interact with it since it's in the middle of a garbage +// collection. +typedef void (*WrenFinalizerFn)(void* data); + +// Gives the host a chance to canonicalize the imported module name, +// potentially taking into account the (previously resolved) name of the module +// that contains the import. Typically, this is used to implement relative +// imports. +typedef const char* (*WrenResolveModuleFn)(WrenVM* vm, + const char* importer, const char* name); + +// Forward declare +struct WrenLoadModuleResult; + +// Called after loadModuleFn is called for module [name]. The original returned result +// is handed back to you in this callback, so that you can free memory if appropriate. +typedef void (*WrenLoadModuleCompleteFn)(WrenVM* vm, const char* name, struct WrenLoadModuleResult result); + +// The result of a loadModuleFn call. +// [source] is the source code for the module, or NULL if the module is not found. +// [onComplete] an optional callback that will be called once Wren is done with the result. +typedef struct WrenLoadModuleResult +{ + const char* source; + WrenLoadModuleCompleteFn onComplete; + void* userData; +} WrenLoadModuleResult; + +// Loads and returns the source code for the module [name]. +typedef WrenLoadModuleResult (*WrenLoadModuleFn)(WrenVM* vm, const char* name); + +// Returns a pointer to a foreign method on [className] in [module] with +// [signature]. +typedef WrenForeignMethodFn (*WrenBindForeignMethodFn)(WrenVM* vm, + const char* module, const char* className, bool isStatic, + const char* signature); + +// Displays a string of text to the user. +typedef void (*WrenWriteFn)(WrenVM* vm, const char* text); + +typedef enum +{ + // A syntax or resolution error detected at compile time. + WREN_ERROR_COMPILE, + + // The error message for a runtime error. + WREN_ERROR_RUNTIME, + + // One entry of a runtime error's stack trace. + WREN_ERROR_STACK_TRACE +} WrenErrorType; + +// Reports an error to the user. +// +// An error detected during compile time is reported by calling this once with +// [type] `WREN_ERROR_COMPILE`, the resolved name of the [module] and [line] +// where the error occurs, and the compiler's error [message]. +// +// A runtime error is reported by calling this once with [type] +// `WREN_ERROR_RUNTIME`, no [module] or [line], and the runtime error's +// [message]. After that, a series of [type] `WREN_ERROR_STACK_TRACE` calls are +// made for each line in the stack trace. Each of those has the resolved +// [module] and [line] where the method or function is defined and [message] is +// the name of the method or function. +typedef void (*WrenErrorFn)( + WrenVM* vm, WrenErrorType type, const char* module, int line, + const char* message); + +typedef struct +{ + // The callback invoked when the foreign object is created. + // + // This must be provided. Inside the body of this, it must call + // [wrenSetSlotNewForeign()] exactly once. + WrenForeignMethodFn allocate; + + // The callback invoked when the garbage collector is about to collect a + // foreign object's memory. + // + // This may be `NULL` if the foreign class does not need to finalize. + WrenFinalizerFn finalize; +} WrenForeignClassMethods; + +// Returns a pair of pointers to the foreign methods used to allocate and +// finalize the data for instances of [className] in resolved [module]. +typedef WrenForeignClassMethods (*WrenBindForeignClassFn)( + WrenVM* vm, const char* module, const char* className); + +typedef struct +{ + // The callback Wren will use to allocate, reallocate, and deallocate memory. + // + // If `NULL`, defaults to a built-in function that uses `realloc` and `free`. + WrenReallocateFn reallocateFn; + + // The callback Wren uses to resolve a module name. + // + // Some host applications may wish to support "relative" imports, where the + // meaning of an import string depends on the module that contains it. To + // support that without baking any policy into Wren itself, the VM gives the + // host a chance to resolve an import string. + // + // Before an import is loaded, it calls this, passing in the name of the + // module that contains the import and the import string. The host app can + // look at both of those and produce a new "canonical" string that uniquely + // identifies the module. This string is then used as the name of the module + // going forward. It is what is passed to [loadModuleFn], how duplicate + // imports of the same module are detected, and how the module is reported in + // stack traces. + // + // If you leave this function NULL, then the original import string is + // treated as the resolved string. + // + // If an import cannot be resolved by the embedder, it should return NULL and + // Wren will report that as a runtime error. + // + // Wren will take ownership of the string you return and free it for you, so + // it should be allocated using the same allocation function you provide + // above. + WrenResolveModuleFn resolveModuleFn; + + // The callback Wren uses to load a module. + // + // Since Wren does not talk directly to the file system, it relies on the + // embedder to physically locate and read the source code for a module. The + // first time an import appears, Wren will call this and pass in the name of + // the module being imported. The method will return a result, which contains + // the source code for that module. Memory for the source is owned by the + // host application, and can be freed using the onComplete callback. + // + // This will only be called once for any given module name. Wren caches the + // result internally so subsequent imports of the same module will use the + // previous source and not call this. + // + // If a module with the given name could not be found by the embedder, it + // should return NULL and Wren will report that as a runtime error. + WrenLoadModuleFn loadModuleFn; + + // The callback Wren uses to find a foreign method and bind it to a class. + // + // When a foreign method is declared in a class, this will be called with the + // foreign method's module, class, and signature when the class body is + // executed. It should return a pointer to the foreign function that will be + // bound to that method. + // + // If the foreign function could not be found, this should return NULL and + // Wren will report it as runtime error. + WrenBindForeignMethodFn bindForeignMethodFn; + + // The callback Wren uses to find a foreign class and get its foreign methods. + // + // When a foreign class is declared, this will be called with the class's + // module and name when the class body is executed. It should return the + // foreign functions uses to allocate and (optionally) finalize the bytes + // stored in the foreign object when an instance is created. + WrenBindForeignClassFn bindForeignClassFn; + + // The callback Wren uses to display text when `System.print()` or the other + // related functions are called. + // + // If this is `NULL`, Wren discards any printed text. + WrenWriteFn writeFn; + + // The callback Wren uses to report errors. + // + // When an error occurs, this will be called with the module name, line + // number, and an error message. If this is `NULL`, Wren doesn't report any + // errors. + WrenErrorFn errorFn; + + // The number of bytes Wren will allocate before triggering the first garbage + // collection. + // + // If zero, defaults to 10MB. + size_t initialHeapSize; + + // After a collection occurs, the threshold for the next collection is + // determined based on the number of bytes remaining in use. This allows Wren + // to shrink its memory usage automatically after reclaiming a large amount + // of memory. + // + // This can be used to ensure that the heap does not get too small, which can + // in turn lead to a large number of collections afterwards as the heap grows + // back to a usable size. + // + // If zero, defaults to 1MB. + size_t minHeapSize; + + // Wren will resize the heap automatically as the number of bytes + // remaining in use after a collection changes. This number determines the + // amount of additional memory Wren will use after a collection, as a + // percentage of the current heap size. + // + // For example, say that this is 50. After a garbage collection, when there + // are 400 bytes of memory still in use, the next collection will be triggered + // after a total of 600 bytes are allocated (including the 400 already in + // use.) + // + // Setting this to a smaller number wastes less memory, but triggers more + // frequent garbage collections. + // + // If zero, defaults to 50. + int heapGrowthPercent; + + // User-defined data associated with the VM. + void* userData; + +} WrenConfiguration; + +typedef enum +{ + WREN_RESULT_SUCCESS, + WREN_RESULT_COMPILE_ERROR, + WREN_RESULT_RUNTIME_ERROR +} WrenInterpretResult; + +// The type of an object stored in a slot. +// +// This is not necessarily the object's *class*, but instead its low level +// representation type. +typedef enum +{ + WREN_TYPE_BOOL, + WREN_TYPE_NUM, + WREN_TYPE_FOREIGN, + WREN_TYPE_LIST, + WREN_TYPE_MAP, + WREN_TYPE_NULL, + WREN_TYPE_STRING, + + // The object is of a type that isn't accessible by the C API. + WREN_TYPE_UNKNOWN +} WrenType; + +// Get the current wren version number. +// +// Can be used to range checks over versions. +WREN_API int wrenGetVersionNumber(); + +// Initializes [configuration] with all of its default values. +// +// Call this before setting the particular fields you care about. +WREN_API void wrenInitConfiguration(WrenConfiguration* configuration); + +// Creates a new Wren virtual machine using the given [configuration]. Wren +// will copy the configuration data, so the argument passed to this can be +// freed after calling this. If [configuration] is `NULL`, uses a default +// configuration. +WREN_API WrenVM* wrenNewVM(WrenConfiguration* configuration); + +// Disposes of all resources is use by [vm], which was previously created by a +// call to [wrenNewVM]. +WREN_API void wrenFreeVM(WrenVM* vm); + +// Immediately run the garbage collector to free unused memory. +WREN_API void wrenCollectGarbage(WrenVM* vm); + +// Runs [source], a string of Wren source code in a new fiber in [vm] in the +// context of resolved [module]. +WREN_API WrenInterpretResult wrenInterpret(WrenVM* vm, const char* module, + const char* source); + +// Creates a handle that can be used to invoke a method with [signature] on +// using a receiver and arguments that are set up on the stack. +// +// This handle can be used repeatedly to directly invoke that method from C +// code using [wrenCall]. +// +// When you are done with this handle, it must be released using +// [wrenReleaseHandle]. +WREN_API WrenHandle* wrenMakeCallHandle(WrenVM* vm, const char* signature); + +// Calls [method], using the receiver and arguments previously set up on the +// stack. +// +// [method] must have been created by a call to [wrenMakeCallHandle]. The +// arguments to the method must be already on the stack. The receiver should be +// in slot 0 with the remaining arguments following it, in order. It is an +// error if the number of arguments provided does not match the method's +// signature. +// +// After this returns, you can access the return value from slot 0 on the stack. +WREN_API WrenInterpretResult wrenCall(WrenVM* vm, WrenHandle* method); + +// Releases the reference stored in [handle]. After calling this, [handle] can +// no longer be used. +WREN_API void wrenReleaseHandle(WrenVM* vm, WrenHandle* handle); + +// The following functions are intended to be called from foreign methods or +// finalizers. The interface Wren provides to a foreign method is like a +// register machine: you are given a numbered array of slots that values can be +// read from and written to. Values always live in a slot (unless explicitly +// captured using wrenGetSlotHandle(), which ensures the garbage collector can +// find them. +// +// When your foreign function is called, you are given one slot for the receiver +// and each argument to the method. The receiver is in slot 0 and the arguments +// are in increasingly numbered slots after that. You are free to read and +// write to those slots as you want. If you want more slots to use as scratch +// space, you can call wrenEnsureSlots() to add more. +// +// When your function returns, every slot except slot zero is discarded and the +// value in slot zero is used as the return value of the method. If you don't +// store a return value in that slot yourself, it will retain its previous +// value, the receiver. +// +// While Wren is dynamically typed, C is not. This means the C interface has to +// support the various types of primitive values a Wren variable can hold: bool, +// double, string, etc. If we supported this for every operation in the C API, +// there would be a combinatorial explosion of functions, like "get a +// double-valued element from a list", "insert a string key and double value +// into a map", etc. +// +// To avoid that, the only way to convert to and from a raw C value is by going +// into and out of a slot. All other functions work with values already in a +// slot. So, to add an element to a list, you put the list in one slot, and the +// element in another. Then there is a single API function wrenInsertInList() +// that takes the element out of that slot and puts it into the list. +// +// The goal of this API is to be easy to use while not compromising performance. +// The latter means it does not do type or bounds checking at runtime except +// using assertions which are generally removed from release builds. C is an +// unsafe language, so it's up to you to be careful to use it correctly. In +// return, you get a very fast FFI. + +// Returns the number of slots available to the current foreign method. +WREN_API int wrenGetSlotCount(WrenVM* vm); + +// Ensures that the foreign method stack has at least [numSlots] available for +// use, growing the stack if needed. +// +// Does not shrink the stack if it has more than enough slots. +// +// It is an error to call this from a finalizer. +WREN_API void wrenEnsureSlots(WrenVM* vm, int numSlots); + +// Gets the type of the object in [slot]. +WREN_API WrenType wrenGetSlotType(WrenVM* vm, int slot); + +// Reads a boolean value from [slot]. +// +// It is an error to call this if the slot does not contain a boolean value. +WREN_API bool wrenGetSlotBool(WrenVM* vm, int slot); + +// Reads a byte array from [slot]. +// +// The memory for the returned string is owned by Wren. You can inspect it +// while in your foreign method, but cannot keep a pointer to it after the +// function returns, since the garbage collector may reclaim it. +// +// Returns a pointer to the first byte of the array and fill [length] with the +// number of bytes in the array. +// +// It is an error to call this if the slot does not contain a string. +WREN_API const char* wrenGetSlotBytes(WrenVM* vm, int slot, int* length); + +// Reads a number from [slot]. +// +// It is an error to call this if the slot does not contain a number. +WREN_API double wrenGetSlotDouble(WrenVM* vm, int slot); + +// Reads a foreign object from [slot] and returns a pointer to the foreign data +// stored with it. +// +// It is an error to call this if the slot does not contain an instance of a +// foreign class. +WREN_API void* wrenGetSlotForeign(WrenVM* vm, int slot); + +// Reads a string from [slot]. +// +// The memory for the returned string is owned by Wren. You can inspect it +// while in your foreign method, but cannot keep a pointer to it after the +// function returns, since the garbage collector may reclaim it. +// +// It is an error to call this if the slot does not contain a string. +WREN_API const char* wrenGetSlotString(WrenVM* vm, int slot); + +// Creates a handle for the value stored in [slot]. +// +// This will prevent the object that is referred to from being garbage collected +// until the handle is released by calling [wrenReleaseHandle()]. +WREN_API WrenHandle* wrenGetSlotHandle(WrenVM* vm, int slot); + +// Stores the boolean [value] in [slot]. +WREN_API void wrenSetSlotBool(WrenVM* vm, int slot, bool value); + +// Stores the array [length] of [bytes] in [slot]. +// +// The bytes are copied to a new string within Wren's heap, so you can free +// memory used by them after this is called. +WREN_API void wrenSetSlotBytes(WrenVM* vm, int slot, const char* bytes, size_t length); + +// Stores the numeric [value] in [slot]. +WREN_API void wrenSetSlotDouble(WrenVM* vm, int slot, double value); + +// Creates a new instance of the foreign class stored in [classSlot] with [size] +// bytes of raw storage and places the resulting object in [slot]. +// +// This does not invoke the foreign class's constructor on the new instance. If +// you need that to happen, call the constructor from Wren, which will then +// call the allocator foreign method. In there, call this to create the object +// and then the constructor will be invoked when the allocator returns. +// +// Returns a pointer to the foreign object's data. +WREN_API void* wrenSetSlotNewForeign(WrenVM* vm, int slot, int classSlot, size_t size); + +// Stores a new empty list in [slot]. +WREN_API void wrenSetSlotNewList(WrenVM* vm, int slot); + +// Stores a new empty map in [slot]. +WREN_API void wrenSetSlotNewMap(WrenVM* vm, int slot); + +// Stores null in [slot]. +WREN_API void wrenSetSlotNull(WrenVM* vm, int slot); + +// Stores the string [text] in [slot]. +// +// The [text] is copied to a new string within Wren's heap, so you can free +// memory used by it after this is called. The length is calculated using +// [strlen()]. If the string may contain any null bytes in the middle, then you +// should use [wrenSetSlotBytes()] instead. +WREN_API void wrenSetSlotString(WrenVM* vm, int slot, const char* text); + +// Stores the value captured in [handle] in [slot]. +// +// This does not release the handle for the value. +WREN_API void wrenSetSlotHandle(WrenVM* vm, int slot, WrenHandle* handle); + +// Returns the number of elements in the list stored in [slot]. +WREN_API int wrenGetListCount(WrenVM* vm, int slot); + +// Reads element [index] from the list in [listSlot] and stores it in +// [elementSlot]. +WREN_API void wrenGetListElement(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Sets the value stored at [index] in the list at [listSlot], +// to the value from [elementSlot]. +WREN_API void wrenSetListElement(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Takes the value stored at [elementSlot] and inserts it into the list stored +// at [listSlot] at [index]. +// +// As in Wren, negative indexes can be used to insert from the end. To append +// an element, use `-1` for the index. +WREN_API void wrenInsertInList(WrenVM* vm, int listSlot, int index, int elementSlot); + +// Returns the number of entries in the map stored in [slot]. +WREN_API int wrenGetMapCount(WrenVM* vm, int slot); + +// Returns true if the key in [keySlot] is found in the map placed in [mapSlot]. +WREN_API bool wrenGetMapContainsKey(WrenVM* vm, int mapSlot, int keySlot); + +// Retrieves a value with the key in [keySlot] from the map in [mapSlot] and +// stores it in [valueSlot]. +WREN_API void wrenGetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot); + +// Takes the value stored at [valueSlot] and inserts it into the map stored +// at [mapSlot] with key [keySlot]. +WREN_API void wrenSetMapValue(WrenVM* vm, int mapSlot, int keySlot, int valueSlot); + +// Removes a value from the map in [mapSlot], with the key from [keySlot], +// and place it in [removedValueSlot]. If not found, [removedValueSlot] is +// set to null, the same behaviour as the Wren Map API. +WREN_API void wrenRemoveMapValue(WrenVM* vm, int mapSlot, int keySlot, + int removedValueSlot); + +// Looks up the top level variable with [name] in resolved [module] and stores +// it in [slot]. +WREN_API void wrenGetVariable(WrenVM* vm, const char* module, const char* name, + int slot); + +// Looks up the top level variable with [name] in resolved [module], +// returns false if not found. The module must be imported at the time, +// use wrenHasModule to ensure that before calling. +WREN_API bool wrenHasVariable(WrenVM* vm, const char* module, const char* name); + +// Returns true if [module] has been imported/resolved before, false if not. +WREN_API bool wrenHasModule(WrenVM* vm, const char* module); + +// Sets the current fiber to be aborted, and uses the value in [slot] as the +// runtime error object. +WREN_API void wrenAbortFiber(WrenVM* vm, int slot); + +// Returns the user data associated with the WrenVM. +WREN_API void* wrenGetUserData(WrenVM* vm); + +// Sets user data associated with the WrenVM. +WREN_API void wrenSetUserData(WrenVM* vm, void* userData); + +#endif diff --git a/wren3 b/wren3 new file mode 100755 index 0000000..114f72a Binary files /dev/null and b/wren3 differ