diff --git a/CMakeLists.txt b/CMakeLists.txt index 579f779..b8a45a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,13 +4,11 @@ project(sinja) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -march=native -flto") -add_executable(sinja sinja.cpp) # Assuming your file is sinja.cpp +add_executable(sinja sinja.cpp) -# Add this line to tell the compiler where to find Inja and JSON headers +# Include directories for Inja and JSON target_include_directories(sinja PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/external") -target_link_libraries(sinja PRIVATE pthread) - - - +# Link with required libraries +target_link_libraries(sinja PRIVATE pthread sqlite3 cmark-gfm) diff --git a/build.sh b/build.sh index 9038f18..f0870b5 100755 --- a/build.sh +++ b/build.sh @@ -1,10 +1,10 @@ #sudo apt-get update #sudo apt-get install -y build-essential cmake - +sudo apt install -y cmark-gfm libcmark-gfm-dev mkdir -p build && cd build cmake -DCMAKE_BUILD_TYPE=Release .. cmake --build . -j ulimit -n 65536 # $(nproc) -./sinja --templates /home/retoor/projects/sinja/templates --address 0.0.0.0 --port 8083 --threads 1 +./sinja --templates /home/retoor/projects/sinja/templates --address 0.0.0.0 --port 8083 --threads 8 diff --git a/sinja.cpp b/sinja.cpp index b023ba5..6c00fb4 100644 --- a/sinja.cpp +++ b/sinja.cpp @@ -1,9 +1,11 @@ // sinja: blazing-fast, stable, production-grade JSON templating REST server // Architecture: Multi-threaded SO_REUSEPORT with robust, high-performance request handling. // Dependency: inja (which includes nlohmann::json) -// Build: see CMakeLists.txt (ensure you link with -lpthread) -// Run : ./sinja --templates /path/to/templates - +// Extra deps added for template callbacks: +// - SQLite3 (libsqlite3) [for scalar/query/paginate] +// - cmark-gfm [for markdown(md) -> HTML] +// Build: see CMakeLists.txt (ensure you link with -lpthread, -lsqlite3, -lcmark-gfm) +// Run : ./sinja --templates /path/to/templates #include #include #include @@ -29,131 +31,252 @@ #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + using json = nlohmann::json; namespace fs = std::filesystem; // ---- Tunables ---------------------------------------------------------------- static constexpr const char* kServerName = "sinja/5.0-stable"; -static constexpr size_t kMaxBodyBytes = 8 * 1024 * 1024; +static constexpr size_t kMaxBodyBytes = 8 * 1024 * 1024; static constexpr size_t kMaxHeaderBytes = 64 * 1024; -static constexpr int kRecvTimeoutSec = 10; -static constexpr int kSendTimeoutSec = 10; -static constexpr int kBacklog = 8192; // Increased for modern kernels +static constexpr int kRecvTimeoutSec = 0; // Disabled to prevent connection timeouts. +static constexpr int kSendTimeoutSec = 0; // Disabled to prevent connection timeouts. +static constexpr int kBacklog = 8192; +static constexpr size_t kMaxTemplateNameLen = 256; // ---- Global state for graceful shutdown -------------------------------------- static std::atomic g_running{true}; -static void on_signal(int){ g_running.store(false); } +static void on_signal(int) { g_running.store(false); } // ---- Config ------------------------------------------------------------------ struct Settings { - std::string address = "0.0.0.0"; - int port = 8080; - size_t threads = std::thread::hardware_concurrency() ? std::thread::hardware_concurrency() : 4; - fs::path template_root; + std::string address = "0.0.0.0"; + int port = 8080; + bool debug = true; + size_t threads = std::thread::hardware_concurrency() ? std::thread::hardware_concurrency() : 4; + fs::path template_root; }; -// ---- Small utils ------------------------------------------------------------- -static inline std::string ltrim(std::string s){ - size_t i=0; while(idebug) { + fprintf(stdout, C_BLU "[DEBUG] " C_RST); + va_list args; + va_start(args, format); + vfprintf(stdout, format, args); + va_end(args); + fprintf(stdout, "\n"); + fflush(stdout); + } } -// ---- Secure template validation ---------------------------------------------- -static bool validate_template_path(const std::string& req) { - return !(req.empty() || req[0] == '/' || req.find("..") != std::string::npos); +static void log_info(const char* format, ...) { + if (g_settings && g_settings->debug) { + fprintf(stdout, C_WHT "[INFO] " C_RST); + va_list args; + va_start(args, format); + vfprintf(stdout, format, args); + va_end(args); + fprintf(stdout, "\n"); + fflush(stdout); + } +} + +static void log_warning(const char* format, ...) { + if (g_settings && g_settings->debug) { + fprintf(stderr, C_ORG "[WARN] " C_RST); + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); + fprintf(stderr, "\n"); + fflush(stderr); + } +} + +static void log_error(const char* format, ...) { + if (g_settings && g_settings->debug) { + fprintf(stderr, C_RED "[ERROR] " C_RST); + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); + fprintf(stderr, "\n"); + fflush(stderr); + } +} + +// ---- Small utils ------------------------------------------------------------- +static inline std::string ltrim(std::string s) { + size_t i = 0; + while (i < s.size() && std::isspace((unsigned char)s[i])) ++i; + return s.substr(i); +} + +static inline std::string rtrim(std::string s) { + if (s.empty()) return s; + size_t i = s.size() - 1; + while (i < s.size() && std::isspace((unsigned char)s[i])) { + if (i == 0) return ""; + --i; + } + return s.substr(0, i + 1); +} + +static inline std::string trim(std::string s) { return rtrim(ltrim(std::move(s))); } + +static std::string to_lower(std::string s) { + for (auto& c : s) c = (char)std::tolower((unsigned char)c); + return s; +} + +// ---- Secure template validation (FIXED) ------------------------------------- +static bool validate_template_path(const std::string& req, const fs::path& root) { + if (req.empty() || req.size() > kMaxTemplateNameLen) return false; + + // Block absolute paths and parent directory references + if (req[0] == '/' || req[0] == '\\') return false; + if (req.find("..") != std::string::npos) return false; + if (req.find("./") != std::string::npos) return false; + if (req.find("\\") != std::string::npos) return false; + + // Only allow alphanumeric, dash, underscore, slash, and dot for extension + static const std::regex valid_chars("^[a-zA-Z0-9_/-]+(\\.\\w+)?$"); + if (!std::regex_match(req, valid_chars)) return false; + + // Verify the resolved path is within template root + try { + fs::path full = root / req; + fs::path canonical_root = fs::weakly_canonical(root); + fs::path canonical_full = fs::weakly_canonical(full); + + // Must be under root directory + auto [root_end, nothing] = std::mismatch(canonical_root.begin(), canonical_root.end(), + canonical_full.begin(), canonical_full.end()); + return root_end == canonical_root.end(); + } catch (...) { + return false; + } } // ---- HTTP primitives --------------------------------------------------------- struct Request { - std::string method, target, version; - std::map headers; - std::string body; - bool keep_alive = true; + std::string method, target, version; + std::map headers; + std::string body; + bool keep_alive = true; }; + struct Response { - int status = 200; - std::string reason = "OK"; - std::vector> headers; - std::string body; + int status = 200; + std::string reason = "OK"; + std::vector> headers; + std::string body; }; -static std::string status_reason(int code){ - switch(code){ - case 200: return "OK"; case 400: return "Bad Request"; case 404: return "Not Found"; - case 411: return "Length Required"; case 413: return "Payload Too Large"; - case 415: return "Unsupported Media Type"; case 431: return "Request Header Fields Too Large"; - case 500: return "Internal Server Error"; default: return "Error"; - } +static std::string status_reason(int code) { + switch (code) { + case 200: return "OK"; + case 400: return "Bad Request"; + case 404: return "Not Found"; + case 405: return "Method Not Allowed"; + case 411: return "Length Required"; + case 413: return "Payload Too Large"; + case 415: return "Unsupported Media Type"; + case 431: return "Request Header Fields Too Large"; + case 500: return "Internal Server Error"; + default: return "Error"; + } } -static void set_socket_timeouts(int fd, int recv_sec, int send_sec){ - timeval tv{.tv_sec = recv_sec, .tv_usec = 0}; - setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); - tv = timeval{.tv_sec = send_sec, .tv_usec = 0}; - setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); - int flag = 1; setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); +static void set_socket_timeouts(int fd, int recv_sec, int send_sec) { + timeval tv{.tv_sec = recv_sec, .tv_usec = 0}; + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + tv = timeval{.tv_sec = send_sec, .tv_usec = 0}; + setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); + int flag = 1; + setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)); } -static bool recv_append(int fd, std::string& buf){ - char tmp[8192]; - ssize_t n = ::recv(fd, tmp, sizeof(tmp), 0); - if (n <= 0) return false; - buf.append(tmp, (size_t)n); - return true; +static bool recv_append(int fd, std::string& buf) { + char tmp[8192]; + ssize_t n = ::recv(fd, tmp, sizeof(tmp), 0); + if (n <= 0) return false; + buf.append(tmp, (size_t)n); + return true; } -static bool send_all(int fd, const char* data, size_t len){ - size_t off = 0; - while (off < len) { - ssize_t n = ::send(fd, data + off, len - off, MSG_NOSIGNAL); - if (n < 0) { if (errno == EINTR) continue; return false; } - off += (size_t)n; - } - return true; +static bool send_all(int fd, const char* data, size_t len) { + size_t off = 0; + while (off < len) { + ssize_t n = ::send(fd, data + off, len - off, MSG_NOSIGNAL); + if (n < 0) { + if (errno == EINTR) continue; + return false; + } + off += (size_t)n; + } + return true; } -static std::string build_response(const Response& res, bool keep_alive){ - std::string out; - out.reserve(256 + res.body.size()); - out += "HTTP/1.1 " + std::to_string(res.status) + " " + (res.reason.empty()?status_reason(res.status):res.reason) + "\r\n"; - out += "Server: "; out += kServerName; out += "\r\n"; - out += "Content-Length: " + std::to_string(res.body.size()) + "\r\n"; - out += std::string("Connection: ") + (keep_alive ? "keep-alive" : "close") + "\r\n"; - for (const auto& h : res.headers) { out += h.first; out += ": "; out += h.second; out += "\r\n"; } - out += "\r\n"; - out += res.body; - return out; +static std::string build_response(const Response& res, bool keep_alive) { + std::string out; + out.reserve(256 + res.body.size()); + out += "HTTP/1.1 " + std::to_string(res.status) + " " + (res.reason.empty() ? status_reason(res.status) : res.reason) + "\r\n"; + out += "Server: "; out += kServerName; out += "\r\n"; + out += "Content-Length: " + std::to_string(res.body.size()) + "\r\n"; + out += std::string("Connection: ") + (keep_alive ? "keep-alive" : "close") + "\r\n"; + for (const auto& h : res.headers) { + out += h.first; + out += ": "; + out += h.second; + out += "\r\n"; + } + out += "\r\n"; + out += res.body; + return out; } -static bool is_json_ct(const std::map& h){ - auto it = h.find("content-type"); - if (it == h.end()) return false; - return to_lower(it->second).find("application/json") != std::string::npos; +static bool is_json_ct(const std::map& h) { + auto it = h.find("content-type"); + if (it == h.end()) return false; + return to_lower(it->second).find("application/json") != std::string::npos; } -static std::optional> -parse_chunked_body(const std::string& buf, size_t start_off) { +static std::optional> parse_chunked_body(const std::string& buf, size_t start_off) { size_t p = start_off; std::string out; while (true) { size_t eol = buf.find("\r\n", p); if (eol == std::string::npos) return std::nullopt; std::string size_line = buf.substr(p, eol - p); - size_t sc = size_line.find(';'); if (sc != std::string::npos) size_line.resize(sc); + size_t sc = size_line.find(';'); + if (sc != std::string::npos) size_line.resize(sc); size_line = trim(size_line); if (size_line.empty()) throw std::runtime_error("bad chunk size"); size_t chunk_size = 0; - try { chunk_size = std::stoul(size_line, nullptr, 16); } catch (...) { throw std::runtime_error("bad chunk size"); } + try { + chunk_size = std::stoul(size_line, nullptr, 16); + } catch (...) { + throw std::runtime_error("bad chunk size"); + } p = eol + 2; if (chunk_size == 0) { size_t trailer_end = buf.find("\r\n\r\n", p); @@ -164,7 +287,7 @@ parse_chunked_body(const std::string& buf, size_t start_off) { if (out.size() + chunk_size > kMaxBodyBytes) throw std::runtime_error("payload too large"); out.append(buf.data() + p, chunk_size); p += chunk_size; - if (!(buf[p] == '\r' && buf[p+1] == '\n')) throw std::runtime_error("bad chunk CRLF"); + if (p + 1 >= buf.size() || !(buf[p] == '\r' && buf[p + 1] == '\n')) throw std::runtime_error("bad chunk CRLF"); p += 2; } } @@ -216,14 +339,18 @@ static std::optional parse_request(std::string& inbuf) { size_t content_len = 0; auto cl_it = r.headers.find("content-length"); if (cl_it != r.headers.end()) { - try { content_len = std::stoull(cl_it->second); } catch (...) { throw std::runtime_error("invalid content-length"); } + try { + content_len = std::stoull(cl_it->second); + } catch (...) { + throw std::runtime_error("invalid content-length"); + } if (content_len > kMaxBodyBytes) throw std::runtime_error("payload too large"); } if (inbuf.size() < body_start + content_len) return std::nullopt; r.body = inbuf.substr(body_start, content_len); consumed_len = body_start + content_len; } - + inbuf.erase(0, consumed_len); auto conn_it = r.headers.find("connection"); @@ -235,221 +362,672 @@ static std::optional parse_request(std::string& inbuf) { return r; } -// ---- Core Server Logic ------------------------------------------------------- -class Server { - Settings cfg; - std::vector workers; - std::vector listen_fds; - std::mutex fds_mutex; +// ===================== Template callback infrastructure ================= +// --- Markdown (cmark-gfm) +static std::string md_to_html(const std::string& md) { + cmark_node* doc = cmark_parse_document(md.c_str(), md.size(), CMARK_OPT_DEFAULT); + if (!doc) throw std::runtime_error("cmark_parse_document failed"); + char* html = cmark_render_html(doc, CMARK_OPT_DEFAULT, nullptr); + cmark_node_free(doc); + if (!html) throw std::runtime_error("cmark_render_html failed"); + std::string out(html); + free(html); + return out; +} + +// --- Date helpers +static std::tm parse_date(const std::string& date_str, const std::string& parse_fmt = "%Y-%m-%d") { + std::tm tm{}; + std::istringstream ss(date_str); + ss >> std::get_time(&tm, parse_fmt.c_str()); + if (ss.fail()) throw std::runtime_error("Invalid date: " + date_str); + return tm; +} + +static std::string format_with_locale(const std::tm& tm, const std::string& fmt, const std::string& locale_name) { + std::ostringstream out; + try { + out.imbue(std::locale(locale_name.c_str())); + } catch (...) { + out.imbue(std::locale::classic()); + } + out << std::put_time(&tm, fmt.c_str()); + return out.str(); +} + +// --- SQLite RAII + connection pool (thread-safe) +struct Stmt { + sqlite3_stmt* ptr{}; + ~Stmt() { if (ptr) sqlite3_finalize(ptr); } +}; + +static void bind_params(sqlite3_stmt* stmt, const json& params) { + if (!params.is_array()) return; + for (int i = 0; i < static_cast(params.size()); ++i) { + int idx = i + 1; + const auto& v = params[i]; + if (v.is_null()) sqlite3_bind_null(stmt, idx); + else if (v.is_boolean()) sqlite3_bind_int(stmt, idx, v.get() ? 1 : 0); + else if (v.is_number_integer()) sqlite3_bind_int64(stmt, idx, v.get()); + else if (v.is_number_float()) sqlite3_bind_double(stmt, idx, v.get()); + else sqlite3_bind_text(stmt, idx, v.get_ref().c_str(), -1, SQLITE_TRANSIENT); + } +} + +static json step_scalar(sqlite3_stmt* stmt) { + int rc = sqlite3_step(stmt); + if (rc == SQLITE_ROW) { + switch (sqlite3_column_type(stmt, 0)) { + case SQLITE_NULL: return nullptr; + case SQLITE_INTEGER: return json(sqlite3_column_int64(stmt, 0)); + case SQLITE_FLOAT: return json(sqlite3_column_double(stmt, 0)); + case SQLITE_TEXT: return json(reinterpret_cast(sqlite3_column_text(stmt, 0))); + case SQLITE_BLOB: return ""; + } + } + if (rc == SQLITE_DONE) return nullptr; + throw std::runtime_error("SQLite step_scalar error: " + std::to_string(rc)); +} + +static json step_rows(sqlite3_stmt* stmt) { + json rows = json::array(); + int cols = sqlite3_column_count(stmt); + int rc; + while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) { + json row = json::object(); + for (int c = 0; c < cols; ++c) { + const char* name = sqlite3_column_name(stmt, c); + switch (sqlite3_column_type(stmt, c)) { + case SQLITE_NULL: row[name] = nullptr; break; + case SQLITE_INTEGER: row[name] = sqlite3_column_int64(stmt, c); break; + case SQLITE_FLOAT: row[name] = sqlite3_column_double(stmt, c); break; + case SQLITE_TEXT: row[name] = reinterpret_cast(sqlite3_column_text(stmt, c)); break; + case SQLITE_BLOB: row[name] = ""; break; + } + } + rows.push_back(std::move(row)); + } + if (rc != SQLITE_DONE) throw std::runtime_error("SQLite step_rows error: " + std::to_string(rc)); + return rows; +} + +class ConnectionPool { public: - explicit Server(Settings s): cfg(std::move(s)) {} + struct ConnDeleter { void operator()(sqlite3* db) const noexcept { if (db) sqlite3_close(db); } }; + using ConnPtr = std::unique_ptr; - void start() { - std::cout << "Starting " << kServerName << "...\n"; - for (size_t i = 0; i < cfg.threads; ++i) { - workers.emplace_back([this]{ worker_loop(); }); + ConnectionPool(const std::string& path, int size) : db_path_(path), size_(std::max(1, size)) { + for (int i = 0; i < size_; ++i) pool_.push_back(open_one()); } - std::cout << " Listening on http://" << cfg.address << ":" << cfg.port << " with " << cfg.threads << " workers.\n"; - } - void stop() { - g_running.store(false); - std::cout << "\nShutting down... finishing active connections.\n"; - std::lock_guard lock(fds_mutex); - for (int fd : listen_fds) { - ::shutdown(fd, SHUT_RDWR); - ::close(fd); - } - } + struct Lease { + sqlite3* db{}; + ConnectionPool* owner{}; + Lease() = default; + Lease(sqlite3* d, ConnectionPool* o) : db(d), owner(o) {} + Lease(const Lease&) = delete; + Lease& operator=(const Lease&) = delete; + Lease(Lease&& o) noexcept { db = o.db; owner = o.owner; o.db = nullptr; o.owner = nullptr; } + Lease& operator=(Lease&& o) noexcept { + if (this != &o) { + release(); + db = o.db; + owner = o.owner; + o.db = nullptr; + o.owner = nullptr; + } + return *this; + } + ~Lease() { release(); } + void release() { if (db && owner) owner->give_back(db); db = nullptr; owner = nullptr; } + }; - void join() { - for (auto& t : workers) { - if (t.joinable()) t.join(); + Lease take() { + std::unique_lock lk(mu_); + if (!cv_.wait_for(lk, std::chrono::seconds(5), [&] { return !pool_.empty(); })) { + throw std::runtime_error("Could not acquire database connection in time (pool exhausted)"); + } + sqlite3* db = pool_.back().release(); + pool_.pop_back(); + return Lease{db, this}; } - std::cout << "All workers stopped. Shutdown complete.\n"; - } private: - void send_json_error(int fd, int code, const std::string& msg) { - Response res; - res.status = code; res.reason = status_reason(code); - res.headers.push_back({"Content-Type","application/json; charset=utf-8"}); - res.headers.push_back({"Cache-Control","no-store"}); - res.body = json({{"error", msg}}).dump(); - auto raw = build_response(res, false); - send_all(fd, raw.data(), raw.size()); - } - - void handle_connection(int fd, inja::Environment& env) { - std::string inbuf; - inbuf.reserve(16*1024); - bool keep = true; - - while (keep && g_running.load(std::memory_order_relaxed)) { - std::optional reqOpt; - while (g_running.load(std::memory_order_relaxed)) { - try { - reqOpt = parse_request(inbuf); - if (reqOpt) break; - } catch (const std::exception& e) { - int code = 400; std::string what = e.what(); - if (what == "payload too large") code = 413; - else if (what == "headers too large") code = 431; - send_json_error(fd, code, what); - return; - } - if (!recv_append(fd, inbuf)) return; + ConnPtr open_one() { + sqlite3* db = nullptr; + int flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX; + if (sqlite3_open_v2(db_path_.c_str(), &db, flags, nullptr) != SQLITE_OK) { + std::string emsg = db ? sqlite3_errmsg(db) : "unknown"; + if (db) sqlite3_close(db); + throw std::runtime_error("Open DB failed: " + emsg); } - - if (!reqOpt) return; - Request req = std::move(*reqOpt); - keep = req.keep_alive; - - if (!(req.method == "POST" && req.target == "/render")) { - send_json_error(fd, 404, "Not Found"); - continue; - } - if (!is_json_ct(req.headers)) { - send_json_error(fd, 415, "Content-Type must be application/json"); - continue; - } - - json jreq; - try { - if (req.body.empty()) throw std::runtime_error("Request body is empty"); - jreq = json::parse(req.body); - } catch (const std::exception& e) { - send_json_error(fd, 400, std::string("Invalid JSON: ") + e.what()); - continue; - } - - if (!jreq.contains("template") || !jreq["template"].is_string()) { - send_json_error(fd, 400, R"(Missing "template" (string))"); continue; - } - if (jreq.contains("context") && !jreq["context"].is_object()) { - send_json_error(fd, 400, R"("context" must be an object)"); continue; - } - - std::string template_name = jreq["template"].get(); - if (!validate_template_path(template_name)) { - send_json_error(fd, 400, "Template error: invalid path"); continue; - } - - json ctx = jreq.value("context", json::object()); - Response res; - - auto t0 = std::chrono::steady_clock::now(); - try { - res.body = env.render_file(template_name, ctx); - auto t1 = std::chrono::steady_clock::now(); - auto us = std::chrono::duration_cast(t1 - t0).count(); - res.headers.push_back({"Content-Type","text/plain; charset=utf-8"}); - res.headers.push_back({"X-Render-Time-Us", std::to_string(us)}); - } catch (const std::exception& e) { - send_json_error(fd, 500, std::string("Render error: ") + e.what()); - continue; - } - - auto raw = build_response(res, keep); - if (!send_all(fd, raw.data(), raw.size())) return; + exec_sql(db, "PRAGMA journal_mode=WAL;"); + exec_sql(db, "PRAGMA synchronous=NORMAL;"); + exec_sql(db, "PRAGMA foreign_keys=ON;"); + return ConnPtr(db); } - } - void worker_loop() { - inja::Environment env(cfg.template_root.string()); - env.set_trim_blocks(true); - env.set_lstrip_blocks(true); + static void exec_sql(sqlite3* db, const char* sql) { + char* err = nullptr; + if (sqlite3_exec(db, sql, nullptr, nullptr, &err) != SQLITE_OK) { + std::string msg = err ? err : "unknown"; + if (err) sqlite3_free(err); + throw std::runtime_error(std::string("PRAGMA error: ") + msg); + } + } - int listen_fd = ::socket(AF_INET, SOCK_STREAM, 0); - if (listen_fd < 0) { perror("socket"); return; } + void give_back(sqlite3* db) { + std::lock_guard lk(mu_); + pool_.push_back(ConnPtr(db)); + cv_.notify_one(); + } - int yes=1; - setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)); + std::string db_path_; + int size_; + std::vector pool_; + std::mutex mu_; + std::condition_variable cv_; +}; + +// Register all callbacks into an Environment +static void register_template_callbacks(inja::Environment& env, + std::shared_ptr pool, + const std::string& default_locale = "nl_NL.UTF-8") { + // Dates + env.add_callback("format_date", 2, [default_locale](inja::Arguments& args) { + const auto date_str = args.at(0)->get(); + const auto fmt = args.at(1)->get(); + const auto tm = parse_date(date_str); + return format_with_locale(tm, fmt, default_locale); + }); + + env.add_callback("format_date_loc", 3, [](inja::Arguments& args) { + const auto date_str = args.at(0)->get(); + const auto fmt = args.at(1)->get(); + const auto loc = args.at(2)->get(); + const auto tm = parse_date(date_str); + return format_with_locale(tm, fmt, loc); + }); + + // Markdown + env.add_callback("markdown", 1, [](inja::Arguments& args) { + const std::string md = args.at(0)->get(); + return md_to_html(md); + }); + + // DB: scalar(sql, params?) + env.add_callback("scalar", [pool](inja::Arguments& args) { + if (args.size() < 1) throw std::runtime_error("scalar requires at least 1 argument"); + std::string sql = args.at(0)->get(); + json params = json::array(); + if (args.size() >= 2) params = args.at(1)->get(); + auto lease = pool->take(); + Stmt st; + if (sqlite3_prepare_v2(lease.db, sql.c_str(), -1, &st.ptr, nullptr) != SQLITE_OK) + throw std::runtime_error(std::string("Prepare scalar failed: ") + sqlite3_errmsg(lease.db)); + bind_params(st.ptr, params); + return step_scalar(st.ptr); + }); + + // DB: query(sql, params?) + env.add_callback("query", [pool](inja::Arguments& args) { + if (args.size() < 1) throw std::runtime_error("query requires at least 1 argument"); + std::string sql = args.at(0)->get(); + json params = json::array(); + if (args.size() >= 2) params = args.at(1)->get(); + auto lease = pool->take(); + Stmt st; + if (sqlite3_prepare_v2(lease.db, sql.c_str(), -1, &st.ptr, nullptr) != SQLITE_OK) + throw std::runtime_error(std::string("Prepare query failed: ") + sqlite3_errmsg(lease.db)); + bind_params(st.ptr, params); + return step_rows(st.ptr); + }); + + // DB: paginate(sql, params?, page, page_size) + env.add_callback("paginate", [pool](inja::Arguments& args) { + if (args.size() < 3 || args.size() > 4) throw std::runtime_error("paginate requires 3 or 4 arguments"); + int idx = 0; + std::string sql = args.at(idx++)->get(); + json params = json::array(); + bool has_params = (args.size() == 4); + if (has_params) params = args.at(idx++)->get(); + int page = args.at(idx++)->get(); + int page_size = args.at(idx++)->get(); + if (page < 1) page = 1; + if (page_size < 1) page_size = 10; + if (page_size > 1000) page_size = 1000; // Max limit + + auto lease = pool->take(); + + // Count total using parameterized query + std::string count_sql = "SELECT COUNT(*) FROM (" + sql + ") AS _subq"; + Stmt stc; + if (sqlite3_prepare_v2(lease.db, count_sql.c_str(), -1, &stc.ptr, nullptr) != SQLITE_OK) + throw std::runtime_error(std::string("Prepare count failed: ") + sqlite3_errmsg(lease.db)); + bind_params(stc.ptr, params); + json total_j = step_scalar(stc.ptr); + long long total = total_j.is_null() ? 0 : + (total_j.is_number_integer() ? total_j.get() + : static_cast(total_j.get())); + + // Prevent integer overflow + long long pages = 1; + if (total > 0 && page_size > 0) { + pages = (total + static_cast(page_size) - 1) / page_size; + if (pages < 1) pages = 1; + } + if (page > pages) page = static_cast(pages); + + long long offset = (static_cast(page) - 1) * page_size; + std::string page_sql = sql + " LIMIT ? OFFSET ?"; + + Stmt stp; + if (sqlite3_prepare_v2(lease.db, page_sql.c_str(), -1, &stp.ptr, nullptr) != SQLITE_OK) + throw std::runtime_error(std::string("Prepare page failed: ") + sqlite3_errmsg(lease.db)); + bind_params(stp.ptr, params); + int base = static_cast(params.is_array() ? params.size() : 0); + sqlite3_bind_int64(stp.ptr, base + 1, page_size); + sqlite3_bind_int64(stp.ptr, base + 2, offset); + json rows = step_rows(stp.ptr); + + json out; + out["rows"] = std::move(rows); + out["total"] = total; + out["pages"] = pages; + out["page"] = page; + out["page_size"] = page_size; + out["has_prev"] = page > 1; + out["has_next"] = page < pages; + out["prev_page"] = page > 1 ? json(page - 1) : json(nullptr); + out["next_page"] = page < pages ? json(page + 1) : json(nullptr); + return out; + }); +} + +// ============================================================================== + +// ---- Core Server Logic ------------------------------------------------------- +class Server { + Settings cfg; + std::vector workers; + std::vector listen_fds; + std::mutex fds_mutex; + + // Shared resources for callbacks + std::shared_ptr db_pool; + std::string default_locale = "nl_NL.UTF-8"; + +public: + explicit Server(Settings s) : cfg(std::move(s)) {} + + void start() { + // Initialize DB pool - check if DB exists, if not skip DB features + const fs::path db_path = (cfg.template_root / "sinja.db"); + int pool_size = static_cast(cfg.threads ? cfg.threads : 4); + + try { + if (fs::exists(db_path)) { + db_pool = std::make_shared(db_path.string(), pool_size); + log_info("Database: %s (connected)", db_path.c_str()); + } else { + log_warning("Database: No sinja.db found (DB features disabled)"); + } + } catch (const std::exception& e) { + log_error("Database: Failed to connect - %s (DB features disabled)", e.what()); + } + + log_info("Starting %s...", kServerName); + if (cfg.debug) { + log_debug("Debug mode is enabled."); + } + for (size_t i = 0; i < cfg.threads; ++i) { + workers.emplace_back([this, i] { + log_debug("Worker thread %zu starting...", i + 1); + worker_loop(); + log_debug("Worker thread %zu finished.", i + 1); + }); + } + log_info("Listening on http://%s:%d with %zu workers.", cfg.address.c_str(), cfg.port, cfg.threads); + } + + void stop() { + g_running.store(false); + log_info("\nShutting down... finishing active connections."); + std::lock_guard lock(fds_mutex); + for (int fd : listen_fds) { + ::shutdown(fd, SHUT_RDWR); + ::close(fd); + } + listen_fds.clear(); + } + + void join() { + for (auto& t : workers) { + if (t.joinable()) t.join(); + } + log_info("All workers stopped. Shutdown complete."); + } + +private: + void send_json_error(int fd, int code, const std::string& msg, bool keep_alive = false) { + log_warning("fd=%d: Sending JSON error %d: %s", fd, code, msg.c_str()); + Response res; + res.status = code; + res.reason = status_reason(code); + res.headers.push_back({"Content-Type", "application/json; charset=utf-8"}); + res.headers.push_back({"Cache-Control", "no-store"}); + res.body = json({{"error", msg}}).dump(); + auto raw = build_response(res, keep_alive); + send_all(fd, raw.data(), raw.size()); + } + + void send_html_error(int fd, int code, const std::string& msg, bool keep_alive = false) { + log_warning("fd=%d: Sending HTML error %d: %s", fd, code, msg.c_str()); + Response res; + res.status = code; + res.reason = status_reason(code); + res.headers.push_back({"Content-Type", "text/html; charset=utf-8"}); + res.headers.push_back({"Cache-Control", "no-store"}); + res.body = "" + std::to_string(code) + " " + status_reason(code) + + "

" + std::to_string(code) + " " + status_reason(code) + + "

" + msg + "

"; + auto raw = build_response(res, keep_alive); + send_all(fd, raw.data(), raw.size()); + } + + void handle_connection(int fd, inja::Environment& env) { + std::string inbuf; + inbuf.reserve(16 * 1024); + bool keep = true; + + while (keep && g_running.load(std::memory_order_relaxed)) { + std::optional reqOpt; + while (g_running.load(std::memory_order_relaxed)) { + try { + reqOpt = parse_request(inbuf); + if (reqOpt) break; + } catch (const std::exception& e) { + int code = 400; + std::string what = e.what(); + if (what == "payload too large") code = 413; + else if (what == "headers too large") code = 431; + log_warning("fd=%d: Parse exception: %s", fd, what.c_str()); + send_json_error(fd, code, what); + return; + } + if (!recv_append(fd, inbuf)) { + if (inbuf.empty()) { + log_debug("fd=%d: Connection closed or timed out by peer.", fd); + } else { + log_warning("fd=%d: Connection closed with partial data.", fd); + } + return; + } + } + + if (!reqOpt) return; + Request req = std::move(*reqOpt); + keep = req.keep_alive; + log_info("fd=%d: %s %s", fd, req.method.c_str(), req.target.c_str()); + + // Parse URL path and query + std::string path = req.target; + size_t query_pos = path.find('?'); + if (query_pos != std::string::npos) { + path = path.substr(0, query_pos); + } + + // Handle GET /[template-name] (NEW) + if (req.method == "GET" && path.size() > 1 && path[0] == '/') { + std::string template_name = path.substr(1); + + // Add .inja extension if not present + if (template_name.find('.') == std::string::npos) { + template_name += ".inja"; + } + + if (!validate_template_path(template_name, cfg.template_root)) { + log_warning("fd=%d: Invalid template path rejected: '%s'", fd, template_name.c_str()); + send_html_error(fd, 400, "Invalid template path", keep); + continue; + } + + // Check if template file exists + fs::path template_path = cfg.template_root / template_name; + if (!fs::exists(template_path)) { + log_warning("fd=%d: Template not found: %s", fd, template_path.c_str()); + send_html_error(fd, 404, "Template not found", keep); + continue; + } + + Response res; + auto t0 = std::chrono::steady_clock::now(); + try { + // Render with empty context for GET requests + json ctx = json::object(); + log_debug("fd=%d: Rendering GET request for template '%s'", fd, template_name.c_str()); + res.body = env.render_file(template_name, ctx); + auto t1 = std::chrono::steady_clock::now(); + auto us = std::chrono::duration_cast(t1 - t0).count(); + log_debug("fd=%d: Rendered '%s' in %ld us", fd, template_name.c_str(), us); + + // Detect content type based on template extension or content + std::string ct = "text/html; charset=utf-8"; + if (template_name.find(".json") != std::string::npos) { + ct = "application/json; charset=utf-8"; + } else if (template_name.find(".xml") != std::string::npos) { + ct = "application/xml; charset=utf-8"; + } else if (template_name.find(".txt") != std::string::npos) { + ct = "text/plain; charset=utf-8"; + } + + res.headers.push_back({"Content-Type", ct}); + res.headers.push_back({"X-Render-Time-Us", std::to_string(us)}); + } catch (const std::exception& e) { + log_error("fd=%d: Render error for '%s': %s", fd, template_name.c_str(), e.what()); + send_html_error(fd, 500, std::string("Render error: ") + e.what(), keep); + continue; + } + + auto raw = build_response(res, keep); + if (!send_all(fd, raw.data(), raw.size())) return; + continue; + } + + // Handle POST /render (existing) + if (req.method == "POST" && req.target == "/render") { + if (!is_json_ct(req.headers)) { + send_json_error(fd, 415, "Content-Type must be application/json", keep); + continue; + } + + json jreq; + try { + if (req.body.empty()) throw std::runtime_error("Request body is empty"); + jreq = json::parse(req.body); + } catch (const std::exception& e) { + send_json_error(fd, 400, std::string("Invalid JSON: ") + e.what() + " (" + req.body + ")", keep); + continue; + } + + if (!jreq.contains("template") || !jreq["template"].is_string()) { + send_json_error(fd, 400, R"(Missing "template" (string))", keep); + continue; + } + if (jreq.contains("context") && !jreq["context"].is_object()) { + send_json_error(fd, 400, R"("context" must be an object)", keep); + continue; + } + + std::string template_name = jreq["template"].get(); + if (!validate_template_path(template_name, cfg.template_root)) { + log_warning("fd=%d: Invalid template path rejected in POST: '%s'", fd, template_name.c_str()); + send_json_error(fd, 400, "Template error: invalid path", keep); + continue; + } + + json ctx = jreq.value("context", json::object()); + Response res; + + auto t0 = std::chrono::steady_clock::now(); + try { + log_debug("fd=%d: Rendering POST request for template '%s'", fd, template_name.c_str()); + res.body = env.render_file(template_name, ctx); + auto t1 = std::chrono::steady_clock::now(); + auto us = std::chrono::duration_cast(t1 - t0).count(); + log_debug("fd=%d: Rendered '%s' in %ld us", fd, template_name.c_str(), us); + res.headers.push_back({"Content-Type", "text/plain; charset=utf-8"}); + res.headers.push_back({"X-Render-Time-Us", std::to_string(us)}); + } catch (const std::exception& e) { + log_error("fd=%d: Render error for '%s': %s", fd, template_name.c_str(), e.what()); + send_json_error(fd, 500, std::string("Render error: ") + e.what(), keep); + continue; + } + + auto raw = build_response(res, keep); + if (!send_all(fd, raw.data(), raw.size())) return; + continue; + } + + // Invalid route + if (req.method == "GET" || req.method == "HEAD") { + send_html_error(fd, 404, "Not Found", keep); + } else { + send_json_error(fd, 404, "Not Found", keep); + } + } + } + + void worker_loop() { + inja::Environment env(cfg.template_root.string()); + env.set_trim_blocks(true); + env.set_lstrip_blocks(true); + + // Register template callbacks for this worker's Environment (if DB available) + if (db_pool) { + register_template_callbacks(env, db_pool, default_locale); + } + + int listen_fd = ::socket(AF_INET, SOCK_STREAM, 0); + if (listen_fd < 0) { log_error("socket: %s", strerror(errno)); return; } + + int yes = 1; + setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)); #ifdef SO_REUSEPORT - setsockopt(listen_fd, SOL_SOCKET, SO_REUSEPORT, &yes, sizeof(yes)); + setsockopt(listen_fd, SOL_SOCKET, SO_REUSEPORT, &yes, sizeof(yes)); #endif - sockaddr_in addr{}; - addr.sin_family = AF_INET; - addr.sin_port = htons(cfg.port); - if (::inet_pton(AF_INET, cfg.address.c_str(), &addr.sin_addr) != 1) { - std::cerr << "Invalid address\n"; ::close(listen_fd); return; - } - if (::bind(listen_fd, (sockaddr*)&addr, sizeof(addr)) < 0) { - perror("bind"); ::close(listen_fd); return; - } - if (::listen(listen_fd, kBacklog) < 0) { - perror("listen"); ::close(listen_fd); return; - } - - { - std::lock_guard lock(fds_mutex); - listen_fds.push_back(listen_fd); - } - - while (g_running.load(std::memory_order_relaxed)) { - int client_fd = ::accept(listen_fd, nullptr, nullptr); - if (client_fd < 0) { - if (g_running.load(std::memory_order_relaxed)) perror("accept"); - break; + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(static_cast(cfg.port)); + if (::inet_pton(AF_INET, cfg.address.c_str(), &addr.sin_addr) != 1) { + log_error("Invalid address: %s", cfg.address.c_str()); + ::close(listen_fd); + return; + } + if (::bind(listen_fd, reinterpret_cast(&addr), sizeof(addr)) < 0) { + log_error("bind: %s", strerror(errno)); + ::close(listen_fd); + return; + } + if (::listen(listen_fd, kBacklog) < 0) { + log_error("listen: %s", strerror(errno)); + ::close(listen_fd); + return; } - set_socket_timeouts(client_fd, kRecvTimeoutSec, kSendTimeoutSec); - handle_connection(client_fd, env); - ::shutdown(client_fd, SHUT_WR); - ::close(client_fd); - } - ::close(listen_fd); - } + { + std::lock_guard lock(fds_mutex); + listen_fds.push_back(listen_fd); + } + + while (g_running.load(std::memory_order_relaxed)) { + int client_fd = ::accept(listen_fd, nullptr, nullptr); + if (client_fd < 0) { + if (errno == EINTR) continue; + if (g_running.load(std::memory_order_relaxed) && errno != EBADF) { + log_warning("accept failed: %s", strerror(errno)); + } + break; + } + + // Handle each connection in a new thread to run in parallel. + std::thread([this, client_fd, &env] { + log_debug("Accepted new connection, fd=%d", client_fd); + set_socket_timeouts(client_fd, kRecvTimeoutSec, kSendTimeoutSec); + handle_connection(client_fd, env); + ::shutdown(client_fd, SHUT_WR); + ::close(client_fd); + log_debug("Closed connection, fd=%d", client_fd); + }).detach(); + } + + ::close(listen_fd); + } }; // ---- Args & main ------------------------------------------------------------- -static Settings parse_args(int argc, char** argv){ - Settings s; - for (int i=1;istd::string{ if (i+1>=argc) throw std::runtime_error("Missing value for "+a); return std::string(argv[++i]); }; - if (a=="--templates"||a=="-t") s.template_root = next(); - else if (a=="--address"||a=="-a") s.address = next(); - else if (a=="--port"||a=="-p") s.port = std::stoi(next()); - else if (a=="--threads"||a=="-w") s.threads = (size_t)std::stoul(next()); - else if (a=="--help"||a=="-h") { - std::cout << "Usage: sinja --templates DIR [options]\n" - " --address, -a Bind address (default 0.0.0.0)\n" - " --port, -p Port (default 8080)\n" - " --threads, -w Worker threads (default HW concurrency)\n"; - std::exit(0); - } else throw std::runtime_error("Unknown arg: " + a); - } - if (s.template_root.empty()) throw std::runtime_error("Missing required --templates DIR"); - if (!fs::is_directory(s.template_root)) throw std::runtime_error("Template directory invalid: " + s.template_root.string()); - (s.template_root) = fs::weakly_canonical(s.template_root); - return s; -} - -int main(int argc, char** argv){ - try { - auto cfg = parse_args(argc, argv); - std::signal(SIGPIPE, SIG_IGN); - std::signal(SIGINT, on_signal); - std::signal(SIGTERM, on_signal); - - Server srv(std::move(cfg)); - srv.start(); - - while (g_running.load()) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); +static Settings parse_args(int argc, char** argv) { + Settings s; + for (int i = 1; i < argc; i++) { + std::string a = argv[i]; + auto next = [&]() -> std::string { + if (i + 1 >= argc) throw std::runtime_error("Missing value for " + a); + return std::string(argv[++i]); + }; + if (a == "--templates" || a == "-t") s.template_root = next(); + else if (a == "--address" || a == "-a") s.address = next(); + else if (a == "--port" || a == "-p") { + int port = std::stoi(next()); + if (port < 1 || port > 65535) throw std::runtime_error("Invalid port number"); + s.port = port; + } + else if (a == "--threads" || a == "-w") { + size_t threads = static_cast(std::stoul(next())); + if (threads < 1 || threads > 1000) throw std::runtime_error("Invalid thread count"); + s.threads = threads; + } + else if (a == "--help" || a == "-h") { + std::cout << "Usage: sinja --templates DIR [options]\n" + " --address, -a Bind address (default 0.0.0.0)\n" + " --port, -p Port (default 8080)\n" + " --threads, -w Worker threads (default HW concurrency)\n" + "\nEndpoints:\n" + " GET /[template-name] Render template with empty context\n" + " POST /render Render template with JSON context\n"; + std::exit(0); + } else throw std::runtime_error("Unknown arg: " + a); } - - srv.stop(); - srv.join(); - - } catch (const std::exception& e) { - std::cerr << "Fatal: " << e.what() << "\n"; - return 1; - } - return 0; + if (s.template_root.empty()) throw std::runtime_error("Missing required --templates DIR"); + if (!fs::is_directory(s.template_root)) throw std::runtime_error("Template directory invalid: " + s.template_root.string()); + s.template_root = fs::weakly_canonical(s.template_root); + return s; } +int main(int argc, char** argv) { + try { + auto cfg = parse_args(argc, argv); + g_settings = &cfg; // Make settings globally available for logging + std::signal(SIGPIPE, SIG_IGN); + std::signal(SIGINT, on_signal); + std::signal(SIGTERM, on_signal); + + Server srv(std::move(cfg)); + srv.start(); + + while (g_running.load()) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + srv.stop(); + srv.join(); + + } catch (const std::exception& e) { + // Manually print error as logging might not be initialized + fprintf(stderr, C_RED "[FATAL] " C_RST "%s\n", e.what()); + return 1; + } + return 0; +}