# NOTE: the next four lines are residue from a web file viewer, kept as comments.
# 586 lines
# 19 KiB
# Python
# Raw Normal View History

#!/usr/bin/env python3
import asyncio
import logging
import os
import sys
import subprocess
import urllib.parse
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import pathlib
import ast
from types import SimpleNamespace
import time
from aiohttp import web
import aiohttp
# --- Project-local dependency (imported unconditionally): AsyncDataSet backs app["db"] ---
from pr.ads import AsyncDataSet # noqa: E402
import sys
from http import cookies
import urllib.parse
import os
import io
import json
import urllib.request
import pathlib
import pickle
import zlib
import asyncio
import time
from functools import wraps
from pathlib import Path
import pickle
import zlib
class Cache:
def __init__(self, base_dir: Path | str = "."):
self.base_dir = Path(base_dir).resolve()
self.base_dir.mkdir(exist_ok=True, parents=True)
def is_cacheable(self, obj):
return isinstance(obj, (int, str, bool, float))
def generate_key(self, *args, **kwargs):
return zlib.crc32(
json.dumps(
{
"args": [arg for arg in args if self.is_cacheable(arg)],
"kwargs": {k: v for k, v in kwargs.items() if self.is_cacheable(v)},
},
sort_keys=True,
default=str,
).encode()
)
def set(self, key, value):
key = self.generate_key(key)
data = {"value": value, "timestamp": time.time()}
serialized_data = pickle.dumps(data)
with open(self.base_dir.joinpath(f"{key}.cache"), "wb") as f:
f.write(serialized_data)
def get(self, key, default=None):
key = self.generate_key(key)
try:
with open(self.base_dir.joinpath(f"{key}.cache"), "rb") as f:
data = pickle.loads(f.read())
return data
except FileNotFoundError:
return default
def is_cacheable(self, obj):
return isinstance(obj, (int, str, bool, float))
def cached(self, expiration=60):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
# if not all(self.is_cacheable(arg) for arg in args) or not all(self.is_cacheable(v) for v in kwargs.values()):
# return await func(*args, **kwargs)
key = self.generate_key(*args, **kwargs)
print("Cache hit:", key)
cached_data = self.get(key)
if cached_data is not None:
if expiration is None or (time.time() - cached_data["timestamp"]) < expiration:
return cached_data["value"]
result = await func(*args, **kwargs)
self.set(key, result)
return result
return wrapper
return decorator
cache = Cache()
class CGI:
    """Request/response helper for scripts executed as CGI programs.

    In a CGI/1.1 environment the query string is parsed on construction and
    response output is buffered in memory; :meth:`flush` (or garbage
    collection via ``__del__``) writes a ``Status:`` line, the headers and the
    body to stdout. Outside a CGI environment ``__init__`` returns right after
    storing ``environ`` and the instance holds no response state.
    """

    def __init__(self):
        self.environ = os.environ
        if not self.gateway_interface == "CGI/1.1":
            return
        self.status = 200
        self.headers = {}
        self.cache = Cache(pathlib.Path(__file__).parent.joinpath("cache/cgi"))
        self.headers["Content-Length"] = "0"
        self.headers["Cache-Control"] = "no-cache"
        self.headers["Access-Control-Allow-Origin"] = "*"
        self.headers["Content-Type"] = "application/json; charset=utf-8"
        self.cookies = cookies.SimpleCookie()
        self.query = urllib.parse.parse_qs(os.environ["QUERY_STRING"])
        self.file = io.BytesIO()
        # Becomes True once the buffered response has been written to stdout.
        self._flushed = False

    def validate(self, val, fn, error):
        """Return True if ``fn(val)`` is truthy; otherwise send a JSON error and exit.

        NOTE(review): 421 means "Misdirected Request"; 400/422 was probably
        intended, but the code is kept because clients may depend on it.
        """
        if fn(val):
            return True
        self.status = 421
        self.write({"type": "error", "message": error})
        # Send the error now instead of relying on __del__ at interpreter
        # shutdown, then terminate. sys.exit also works when the site module
        # (which provides the exit() builtin) is disabled.
        self.flush()
        sys.exit()

    def __getitem__(self, key):
        return self.get(key)

    def get(self, key, default=None):
        """Return a query parameter: the single value, or a list if repeated."""
        result = self.query.get(key, [default])
        if len(result) == 1:
            return result[0]
        return result

    def get_bool(self, key):
        """Interpret a query parameter as a boolean ("true"/"yes"/"1", case-insensitive)."""
        return str(self.get(key)).lower() in ["true", "yes", "1"]

    @property
    def cache_key(self):
        return self.environ["CACHE_KEY"]

    @property
    def env(self):
        """Return a copy of the process environment."""
        return os.environ.copy()

    @property
    def gateway_interface(self):
        return self.env.get("GATEWAY_INTERFACE", "")

    @property
    def request_method(self):
        return self.env["REQUEST_METHOD"]

    @property
    def query_string(self):
        return self.env["QUERY_STRING"]

    @property
    def script_name(self):
        return self.env["SCRIPT_NAME"]

    @property
    def path_info(self):
        return self.env["PATH_INFO"]

    @property
    def server_name(self):
        return self.env["SERVER_NAME"]

    @property
    def server_port(self):
        return self.env["SERVER_PORT"]

    @property
    def server_protocol(self):
        return self.env["SERVER_PROTOCOL"]

    @property
    def remote_addr(self):
        return self.env["REMOTE_ADDR"]

    def validate_get(self, key):
        """Fail the request (see :meth:`validate`) unless query parameter ``key`` is truthy."""
        self.validate(self.get(key), lambda x: bool(x), f"Missing {key}")

    def print(self, data):
        """Alias for :meth:`write`."""
        self.write(data)

    def write(self, data):
        """Append ``data`` to the response body and update Content-Length.

        ``bytes`` are written unchanged, ``str`` is UTF-8 encoded and any
        other value is serialised with ``json.dumps(..., default=str, indent=4)``.
        """
        if not isinstance(data, (bytes, str)):
            data = json.dumps(data, default=str, indent=4)
        if isinstance(data, str):
            data = data.encode()
        self.file.write(data)
        # The buffer is only ever appended to, so the position is its length.
        self.headers["Content-Length"] = str(self.file.tell())

    @property
    def http_status(self):
        """Return the CGI ``Status:`` header line (RFC 3875, section 6.3.3).

        The server's CGI output parser only recognises ``Name: value`` lines;
        a bare ``HTTP/1.1 <code>`` line has no colon and would be skipped,
        losing the status code.
        """
        return f"Status: {self.status}\r\n".encode("utf-8")

    @property
    def http_headers(self):
        """Return the header block, terminated by the blank separator line."""
        lines = [f"{name}: {value}\r\n".encode("utf-8") for name, value in self.headers.items()]
        return b"".join(lines) + b"\r\n"

    @property
    def http_body(self):
        return self.file.getvalue()

    @property
    def http_response(self):
        return self.http_status + self.http_headers + self.http_body

    def flush(self, response=None):
        """Write output to stdout.

        With a truthy ``response`` (str or bytes) only that payload is
        written. Otherwise the buffered status, headers and body are written
        and the instance is marked as flushed so ``__del__`` does not send
        the same response again.
        """
        out = sys.stdout.buffer
        if response:
            if isinstance(response, str):
                response = response.encode()
            out.write(response)
            out.flush()
            return
        out.write(self.http_response)
        out.flush()
        self._flushed = True

    def __del__(self):
        """Send a buffered, not yet flushed response when the helper is collected."""
        # Outside a CGI environment __init__ returns before `file` exists.
        buffer = getattr(self, "file", None)
        if buffer is None or getattr(self, "_flushed", False):
            return
        # During interpreter shutdown sys.stdout may already be gone.
        if sys.stdout is None:
            return
        if buffer.getvalue():
            self.flush()
        # No exit() here: exceptions raised in __del__ (SystemExit included)
        # are ignored by the interpreter and only print a warning.
# -------------------------------
# Utilities
# -------------------------------
def get_function_source(name: str, directory: str = ".") -> List[Tuple[str, str]]:
    """Find ``async def <name>`` definitions in ``.py`` files below ``directory``.

    Walks ``directory`` recursively in sorted order and returns one
    ``(path, source)`` pair per file containing an ``async def`` called
    ``name`` (the first one that ``ast.walk`` reaches in that file). Plain
    ``def`` functions are ignored. Nested definitions such as methods also
    match; their source segment keeps the original indentation of
    continuation lines.

    Files that cannot be read or parsed (I/O errors such as dangling
    symlinks, non-UTF-8 content, syntax errors, null bytes) are skipped.
    """
    matches: List[Tuple[str, str]] = []
    for root, dirs, files in os.walk(directory):
        # Sort in place so traversal, and therefore the order of results
        # (callers use the first match), does not depend on the filesystem.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(".py"):
                continue
            path = os.path.join(root, file)
            try:
                with open(path, "r", encoding="utf-8") as fh:
                    source = fh.read()
                tree = ast.parse(source, filename=path)
            except (OSError, SyntaxError, ValueError):
                # ValueError covers UnicodeDecodeError and, on older Pythons,
                # "source code string cannot contain null bytes".
                continue
            for node in ast.walk(tree):
                if isinstance(node, ast.AsyncFunctionDef) and node.name == name:
                    func_src = ast.get_source_segment(source, node)
                    if func_src:
                        matches.append((path, func_src))
                        break
    return matches
# -------------------------------
# Server
# -------------------------------
class Static(SimpleNamespace):
    """Free-form attribute bag handed to dynamic handlers under the name ``static``."""
class View(web.View):
    """Class-based aiohttp view with shortcuts to the application's shared objects.

    ``self.db`` and ``self.server`` are taken from ``app["db"]`` and
    ``app["server"]`` of the request's application.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        application = self.request.app
        self.db = application["db"]
        self.server = application["server"]

    @property
    def client(self):
        """Return a brand-new ``aiohttp.ClientSession`` on every access.

        NOTE(review): the session is neither stored nor closed here; callers
        must close it themselves (e.g. ``async with self.client as s: ...``).
        """
        return aiohttp.ClientSession()
class RetoorServer:
    """aiohttp server for CGI binaries, dynamic ``async def`` handlers and static files.

    :meth:`handle_any` routes every request in this order: the fixed ``/``
    and ``/hello`` endpoints, executable ``*.bin`` scripts below
    ``base_dir/pr/cgi``, ``async def`` handlers looked up by name below
    ``base_dir``, and finally static files below ``base_dir``.
    """

    # Upper bound for the per-request-path lookup caches. Their keys come from
    # client-controlled URLs, so without a limit they could grow forever.
    _PATH_CACHE_LIMIT = 4096

    def __init__(self, base_dir: Path | str = ".", port: int = 8118) -> None:
        self.base_dir = Path(base_dir).resolve()
        self.port = port
        self.static = Static()
        self.db = AsyncDataSet(".default.db")
        self._logger = logging.getLogger("retoor.server")
        self._logger.setLevel(logging.INFO)
        if not self._logger.handlers:
            h = logging.StreamHandler(sys.stdout)
            fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
            h.setFormatter(fmt)
            self._logger.addHandler(h)
        # funcname -> (file path, function source)
        self._func_cache: Dict[str, Tuple[str, str]] = {}
        # funcname -> (compiled code object, file path)
        self._compiled_cache: Dict[str, Tuple[object, str]] = {}
        # "pr/cgi/" + request path -> (script path, SCRIPT_NAME, PATH_INFO) or Nones
        self._cgi_cache: Dict[str, Tuple[Optional[Path], Optional[str], Optional[str]]] = {}
        # request path -> resolved file, or None for a remembered 404
        self._static_path_cache: Dict[str, Optional[Path]] = {}
        self._base_dir_str = str(self.base_dir)
        self.app = web.Application()
        app = self.app  # also visible to the exec'd view modules through locals()
        app["db"] = self.db
        views_dir = pathlib.Path(__file__).parent.joinpath("web").joinpath("views")
        # Sorted so view registration order does not depend on the filesystem.
        for path in sorted(views_dir.iterdir()):
            if path.suffix == ".py":
                # View modules are trusted code run with this module's globals
                # and __init__'s locals (e.g. `self`, `app`). read_text() closes
                # the file; compiling with the real filename keeps tracebacks useful.
                exec(compile(path.read_text(), str(path), "exec"), globals(), locals())
        self.app["server"] = self
        # from rpanel import create_app as create_rpanel_app
        # self.app.add_subapp("/api/{tail:.*}", create_rpanel_app())
        self.app.router.add_route("*", "/{tail:.*}", self.handle_any)

    # ---------------------------
    # Helpers
    # ---------------------------
    def _is_within_base(self, candidate: Path) -> bool:
        """Return True if the resolved ``candidate`` is ``base_dir`` or below it.

        A plain string-prefix test would also accept sibling directories such
        as ``/srv/www-old`` for ``base_dir=/srv/www``.
        """
        return candidate.is_relative_to(self.base_dir)

    def _remember(self, cache: dict, key, value) -> None:
        """Store ``key -> value`` in ``cache``, emptying it first once it is full."""
        if len(cache) >= self._PATH_CACHE_LIMIT:
            cache.clear()
        cache[key] = value

    # ---------------------------
    # Simple endpoints
    # ---------------------------
    async def handle_root(self, request: web.Request) -> web.Response:
        return web.Response(text="Static file and cgi server from retoor.")

    async def handle_hello(self, request: web.Request) -> web.Response:
        return web.Response(text="Welcome to the custom HTTP server!")

    # ---------------------------
    # Dynamic function dispatch
    # ---------------------------
    def _path_to_funcname(self, path: str) -> str:
        # /foo/bar -> foo_bar ; strip leading/trailing slashes safely
        return path.strip("/").replace("/", "_")

    async def _maybe_dispatch_dynamic(self, request: web.Request) -> Optional[web.StreamResponse]:
        """Run the ``async def`` named after the request path, if one exists.

        ``/foo/bar`` maps to ``foo_bar``, looked up in ``.py`` files below
        ``base_dir`` via :func:`get_function_source`. The function's source is
        executed in a fresh namespace providing ``static``, ``request``,
        ``relpath``, ``os``, ``sys``, ``web``, ``asyncio``, ``subprocess``,
        ``urllib`` (bound to ``urllib.parse``), ``app`` and ``db``, and the
        resulting function is awaited with the request.

        Returns ``None`` (so the caller falls through to static files) when
        the path is empty, its last segment contains a dot, or no matching
        function exists.

        NOTE(review): every ``async def`` below ``base_dir`` whose name can be
        spelled as a URL is reachable this way, and names with no match cause
        a rescan of ``base_dir`` on every request.
        """
        stripped = request.path.strip("/")
        if not stripped:
            # No function can be named "", so skip the directory scan.
            return None
        # Handlers receive this resolved absolute path as `relpath`; the value
        # is kept unchanged for compatibility with existing handlers.
        relpath = str(pathlib.Path(__file__).joinpath("cgi").joinpath(stripped).resolve()).replace(
            "..", ""
        )
        if "." in relpath.split("/")[-1]:
            return None
        funcname = self._path_to_funcname(request.path)
        compiled = self._compiled_cache.get(funcname)
        if compiled is None:
            entry = self._func_cache.get(funcname)
            if entry is None:
                # The scan reads and parses files; keep it off the event loop.
                matches = await asyncio.to_thread(
                    get_function_source, funcname, directory=str(self.base_dir)
                )
                if not matches:
                    return None
                entry = matches[0]
                self._func_cache[funcname] = entry
            filepath, source = entry
            try:
                code_obj = compile(source, filepath, "exec")
            except SyntaxError as exc:
                # e.g. a method matched inside a class body: its continuation
                # lines keep their indentation and do not compile standalone.
                self._logger.error(
                    "Cannot compile dynamic handler %s from %s: %s", funcname, filepath, exc
                )
                return web.Response(
                    status=500, text=f"Dynamic handler '{funcname}' could not be compiled."
                )
            compiled = (code_obj, filepath)
            self._compiled_cache[funcname] = compiled
        code_obj, _filepath = compiled
        ctx = {
            "static": self.static,
            "request": request,
            "relpath": relpath,
            "os": os,
            "sys": sys,
            "web": web,
            "asyncio": asyncio,
            "subprocess": subprocess,
            "urllib": urllib.parse,
            "app": request.app,
            "db": self.db,
        }
        exec(code_obj, ctx, ctx)
        coro = ctx.get(funcname)
        if not callable(coro):
            return web.Response(status=500, text=f"Dynamic handler '{funcname}' is not callable.")
        return await coro(request)

    # ---------------------------
    # Static files
    # ---------------------------
    @staticmethod
    def _file_response(path: Path) -> web.StreamResponse:
        """Wrap ``path`` in a FileResponse, turning construction errors into a 500."""
        try:
            return web.FileResponse(path=path)
        except Exception as e:
            return web.Response(status=500, text=f"Failed to send file: {e!r}")

    async def handle_static(self, request: web.Request) -> web.StreamResponse:
        """Serve the regular file below ``base_dir`` that matches the request path.

        Responds 403 for paths resolving outside ``base_dir`` or to a
        directory, and 404 for missing files. Results, including misses, are
        memoised per request path in a size-bounded cache.
        """
        path = request.path
        if path in self._static_path_cache:
            cached_path = self._static_path_cache[path]
            if cached_path is None:
                return web.Response(status=404, text="Not found")
            return self._file_response(cached_path)
        # resolve() collapses ".." segments and symlinks; containment is then
        # checked on the resolved path, so no textual ".." filtering is needed.
        abspath = (self.base_dir / path.strip("/")).resolve()
        if not self._is_within_base(abspath):
            return web.Response(status=403, text="Forbidden")
        if not abspath.exists():
            self._remember(self._static_path_cache, path, None)
            return web.Response(status=404, text="Not found")
        if abspath.is_dir():
            return web.Response(status=403, text="Directory listing forbidden")
        self._remember(self._static_path_cache, path, abspath)
        return self._file_response(abspath)

    # ---------------------------
    # CGI
    # ---------------------------
    def _find_cgi_script(
        self, path_only: str
    ) -> Tuple[Optional[Path], Optional[str], Optional[str]]:
        """Locate the executable CGI script that handles a request path.

        The path is looked up below ``base_dir/pr/cgi``. The longest prefix
        that resolves to an executable ``*.bin`` file directly inside a
        directory named ``cgi`` (and within ``base_dir``) wins.

        :returns: ``(script_path, SCRIPT_NAME, PATH_INFO)`` or
            ``(None, None, None)`` when no script matches.
        """
        path_only = "pr/cgi/" + path_only
        if path_only in self._cgi_cache:
            return self._cgi_cache[path_only]
        split_path = [p for p in path_only.split("/") if p]
        result: Tuple[Optional[Path], Optional[str], Optional[str]] = (None, None, None)
        for i in range(len(split_path), -1, -1):
            candidate = "/" + "/".join(split_path[:i])
            candidate_fs = (self.base_dir / candidate.lstrip("/")).resolve()
            # Cheap name checks first; filesystem checks only for real candidates.
            if (
                candidate_fs.suffix == ".bin"
                and candidate_fs.parent.name == "cgi"
                and self._is_within_base(candidate_fs)
                and candidate_fs.parent.is_dir()
                and candidate_fs.is_file()
                and os.access(candidate_fs, os.X_OK)
            ):
                script_name = candidate
                # NOTE(review): slicing by len(script_name) assumes the request
                # path starts with exactly one "/": "pr/cgi/" + "/x.bin" and
                # "/pr/cgi/x.bin" then have equal length.
                path_info = path_only[len(script_name):]
                result = (candidate_fs, script_name, path_info)
                break
        self._remember(self._cgi_cache, path_only, result)
        return result

    @staticmethod
    def _remote_addr(request: web.Request) -> str:
        """Return the client address to expose as REMOTE_ADDR.

        Uses the first entry of ``X-Forwarded-For`` (or the non-standard
        ``HOST_X_FORWARDED_FOR`` header, kept for compatibility) and falls
        back to the socket peer, or "" when no peer is available.

        NOTE(review): forwarded headers are client-controlled unless a trusted
        reverse proxy overwrites them.
        """
        forwarded = request.headers.get("X-Forwarded-For") or request.headers.get(
            "HOST_X_FORWARDED_FOR"
        )
        if forwarded:
            return forwarded.split(",")[0].strip()
        transport = request.transport
        peername = transport.get_extra_info("peername") if transport is not None else None
        if peername:
            return str(peername[0])
        return ""

    def _build_cgi_env(
        self, request: web.Request, script_name: str, path_info: str
    ) -> Dict[str, str]:
        """Return the server environment plus the CGI meta-variables for ``request``."""
        method = request.method
        query = request.query_string
        env = os.environ.copy()
        env["CACHE_KEY"] = json.dumps(
            {"method": method, "query": query, "script_name": script_name, "path_info": path_info},
            default=str,
        )
        env["GATEWAY_INTERFACE"] = "CGI/1.1"
        env["REQUEST_METHOD"] = method
        env["QUERY_STRING"] = query
        env["SCRIPT_NAME"] = script_name
        env["PATH_INFO"] = path_info
        host_parts = request.host.split(":", 1)
        env["SERVER_NAME"] = host_parts[0]
        env["SERVER_PORT"] = host_parts[1] if len(host_parts) > 1 else "80"
        env["SERVER_PROTOCOL"] = "HTTP/1.1"
        env["REMOTE_ADDR"] = self._remote_addr(request)
        for hk, hv in request.headers.items():
            hk_lower = hk.lower()
            if hk_lower == "content-type":
                env["CONTENT_TYPE"] = hv
            elif hk_lower == "content-length":
                env["CONTENT_LENGTH"] = hv
            elif hk_lower == "proxy":
                # httpoxy: a client "Proxy" header would become HTTP_PROXY,
                # which many HTTP client libraries honour as an outbound proxy.
                continue
            else:
                env["HTTP_" + hk.upper().replace("-", "_")] = hv
        return env

    def _parse_cgi_output(self, stdout: bytes, start_time: float) -> web.Response:
        """Turn raw CGI stdout into an aiohttp response.

        Headers and body are split at the first blank line (CRLF preferred,
        then LF). A ``Status:`` header sets the status code. Output without a
        blank line is returned unchanged as the body with status 200.
        """
        header_end = stdout.find(b"\r\n\r\n")
        offset = 4
        if header_end == -1:
            header_end = stdout.find(b"\n\n")
            offset = 2
        if header_end == -1:
            return web.Response(body=stdout)
        body = stdout[header_end + offset:]
        headers_blob = stdout[:header_end].decode(errors="ignore")
        status_code = 200
        headers: Dict[str, str] = {}
        for line in headers_blob.splitlines():
            line = line.strip()
            if not line or ":" not in line:
                continue
            k, v = line.split(":", 1)
            k_stripped = k.strip()
            if k_stripped.lower() == "status":
                try:
                    status_code = int(v.strip().split()[0])
                except (ValueError, IndexError):
                    status_code = 200
            else:
                headers[k_stripped] = v.strip()
        end_time = time.time()
        headers["X-Powered-By"] = "retoor"
        headers["X-Date"] = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime())
        headers["X-Time"] = str(end_time)
        headers["X-Duration"] = str(end_time - start_time)
        if "error" in str(body).lower():
            self._logger.warning("CGI output mentions an error: headers=%s body=%r", headers, body)
        return web.Response(body=body, headers=headers, status=status_code)

    async def _run_cgi(
        self, request: web.Request, script_path: Path, script_name: str, path_info: str
    ) -> web.Response:
        """Execute ``script_path`` as a CGI program for ``request``.

        The request body (for POST/PUT/PATCH) is piped to the script's stdin.
        A non-zero exit status yields a 500 containing the script's stderr;
        any other failure yields a 500 with the exception repr.
        """
        start_time = time.time()
        env = self._build_cgi_env(request, script_name, path_info)
        post_data = None
        if request.method in {"POST", "PUT", "PATCH"}:
            post_data = await request.read()
            env.setdefault("CONTENT_LENGTH", str(len(post_data)))
        proc = None
        try:
            proc = await asyncio.create_subprocess_exec(
                str(script_path),
                # Without a body the script gets an empty stdin instead of
                # inheriting (and possibly blocking on) the server's stdin.
                stdin=asyncio.subprocess.PIPE if post_data else asyncio.subprocess.DEVNULL,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=env,
            )
            stdout, stderr = await proc.communicate(input=post_data)
            if proc.returncode != 0:
                msg = stderr.decode(errors="ignore")
                return web.Response(status=500, text=f"CGI script error:\n{msg}")
            return self._parse_cgi_output(stdout, start_time)
        except Exception as e:
            return web.Response(status=500, text=f"Failed to execute CGI script.\n{e!r}")
        finally:
            # If the request was cancelled (e.g. client disconnect) or failed
            # mid-run, do not leave the child process behind.
            if proc is not None and proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass

    # ---------------------------
    # Main router
    # ---------------------------
    async def handle_any(self, request: web.Request) -> web.StreamResponse:
        """Route a request: fixed endpoints, CGI, dynamic handlers, then static files."""
        path = request.path
        if path == "/":
            return await self.handle_root(request)
        if path == "/hello":
            return await self.handle_hello(request)
        script_path, script_name, path_info = self._find_cgi_script(path)
        if script_path:
            return await self._run_cgi(request, script_path, script_name or "", path_info or "")
        dyn = await self._maybe_dispatch_dynamic(request)
        if dyn is not None:
            return dyn
        return await self.handle_static(request)

    # ---------------------------
    # Runner
    # ---------------------------
    def run(self) -> None:
        """Log the configuration and run the aiohttp app (blocking)."""
        self._logger.info("Serving at port %d (base_dir=%s)", self.port, self.base_dir)
        web.run_app(self.app, port=self.port)
def main() -> None:
    """Serve the current directory with RetoorServer on port 8118 (blocks)."""
    RetoorServer(base_dir=".", port=8118).run()
if __name__ != "__main__":
    # Imported as a module (e.g. by a CGI script): expose a ready-made
    # request/response helper. Outside a CGI/1.1 environment CGI.__init__
    # returns early and the helper holds no response state.
    cgi = CGI()
else:
    main()