From ba3152f553afcfae318811a413cdea6f5be9f413 Mon Sep 17 00:00:00 2001 From: retoor Date: Tue, 13 May 2025 18:30:31 +0200 Subject: [PATCH] Update. --- src/snek/research/serpentarium.py | 312 ++++++++++++++++++++++++++++++ src/snek/research/serptest.py | 51 +++++ src/snek/schema.sql | 103 ++++++++++ src/snek/snode.py | 116 +++++++++++ 4 files changed, 582 insertions(+) create mode 100644 src/snek/research/serpentarium.py create mode 100644 src/snek/research/serptest.py create mode 100644 src/snek/schema.sql create mode 100644 src/snek/snode.py diff --git a/src/snek/research/serpentarium.py b/src/snek/research/serpentarium.py new file mode 100644 index 0000000..c87b70b --- /dev/null +++ b/src/snek/research/serpentarium.py @@ -0,0 +1,312 @@ + +import json +import asyncio +import aiohttp +from aiohttp import web +import dataset +import dataset.util +import traceback +import socket +import base64 +import uuid + +class DatasetMethod: + def __init__(self, dt, name): + self.dt = dt + self.name = name + + def __call__(self, *args, **kwargs): + return self.dt.ds.call( + self.dt.name, + self.name, + *args, + **kwargs + ) + + +class DatasetTable: + + def __init__(self, ds, name): + self.ds = ds + self.name = name + + def __getattr__(self, name): + return DatasetMethod(self, name) + +class WebSocketClient: + def __init__(self): + self.buffer = b'' + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.connect() + + def connect(self): + self.socket.connect(("127.0.0.1", 3131)) + key = base64.b64encode(b'1234123412341234').decode('utf-8') + handshake = ( + f"GET /db HTTP/1.1\r\n" + f"Host: localhost:3131\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {key}\r\n" + f"Sec-WebSocket-Version: 13\r\n\r\n" + ) + self.socket.sendall(handshake.encode('utf-8')) + response = self.read_until(b'\r\n\r\n') + if b'101 Switching Protocols' not in response: + raise Exception("Failed to connect to WebSocket") + + def write(self, message): + message_bytes = message.encode('utf-8') + length = len(message_bytes) + if length <= 125: + self.socket.sendall(b'\x81' + bytes([length]) + message_bytes) + elif length >= 126 and length <= 65535: + self.socket.sendall(b'\x81' + bytes([126]) + length.to_bytes(2, 'big') + message_bytes) + else: + self.socket.sendall(b'\x81' + bytes([127]) + length.to_bytes(8, 'big') + message_bytes) + + + def read_until(self, delimiter): + while True: + find_pos = self.buffer.find(delimiter) + if find_pos != -1: + data = self.buffer[:find_pos+4] + self.buffer = self.buffer[find_pos+4:] + return data + + chunk = self.socket.recv(1024) + if not chunk: + return None + self.buffer += chunk + + def read_exactly(self, length): + while len(self.buffer) < length: + chunk = self.socket.recv(length - len(self.buffer)) + if not chunk: + return None + self.buffer += chunk + response = self.buffer[: length] + self.buffer = self.buffer[length:] + return response + + def read(self): + frame = None + frame = self.read_exactly(2) + length = frame[1] & 127 + if length == 126: + length = int.from_bytes(self.read_exactly(2), 'big') + elif length == 127: + length = int.from_bytes(self.read_exactly(8), 'big') + message = self.read_exactly(length) + return message + + def close(self): + self.socket.close() + + + + +class WebSocketClient2: + def __init__(self, uri): + self.uri = uri + self.loop = asyncio.get_event_loop() + self.websocket = None + self.receive_queue = asyncio.Queue() + + # Schedule connection setup + if self.loop.is_running(): + # Schedule connect in the existing loop + self._connect_future = asyncio.run_coroutine_threadsafe(self._connect(), self.loop) + else: + # If loop isn't running, connect synchronously + self.loop.run_until_complete(self._connect()) + + async def _connect(self): + self.websocket = await websockets.connect(self.uri) + # Start listening for messages + asyncio.create_task(self._receive_loop()) + + async def _receive_loop(self): + try: + async for message in self.websocket: + await self.receive_queue.put(message) + except Exception: + pass # Handle exceptions as needed + + def send(self, message: str): + if self.loop.is_running(): + # Schedule send in the existing loop + asyncio.run_coroutine_threadsafe(self.websocket.send(message), self.loop) + else: + # If loop isn't running, run directly + self.loop.run_until_complete(self.websocket.send(message)) + + def receive(self): + # Wait for a message synchronously + future = asyncio.run_coroutine_threadsafe(self.receive_queue.get(), self.loop) + return future.result() + + def close(self): + if self.websocket: + if self.loop.is_running(): + asyncio.run_coroutine_threadsafe(self.websocket.close(), self.loop) + else: + self.loop.run_until_complete(self.websocket.close()) + + +import websockets + +class DatasetWrapper(object): + + def __init__(self): + self.ws = WebSocketClient() + + def begin(self): + self.call(None, 'begin') + + def commit(self): + self.call(None, 'commit') + + def __getitem__(self, name): + return DatasetTable(self, name) + + def query(self, *args, **kwargs): + return self.call(None, 'query', *args, **kwargs) + + def call(self, table, method, *args, **kwargs): + payload = {"table": table, "method": method, "args": args, "kwargs": kwargs,"call_uid":None} + #if method in ['find','find_one']: + payload["call_uid"] = str(uuid.uuid4()) + self.ws.write(json.dumps(payload)) + if payload["call_uid"]: + response = self.ws.read() + return json.loads(response)['result'] + return True + + + +class DatasetWebSocketView: + def __init__(self): + self.ws = None + self.db = dataset.connect('sqlite:///snek.db') + self.setattr(self, "db", self.get) + self.setattr(self, "db", self.set) + ) + super() + + def format_result(self, result): + + try: + return dict(result) + except: + pass + try: + return [dict(row) for row in result] + except: + pass + return result + + async def send_str(self, msg): + return await self.ws.send_str(msg) + + def get(self, key): + returnl loads(dict(self.db['_kv'].get(key=key)['value'])) + + def set(self, key, value): + return self.db['_kv'].upsert({'key': key, 'value': json.dumps(value)}, ['key']) + + + + async def handle(self, request): + ws = web.WebSocketResponse() + await ws.prepare(request) + self.ws = ws + + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(msg.data) + call_uid = data.get("call_uid") + method = data.get("method") + table_name = data.get("table") + args = data.get("args", {}) + kwargs = data.get("kwargs", {}) + + + function = getattr(self.db, method, None) + if table_name: + function = getattr(self.db[table_name], method, None) + + print(method, table_name, args, kwargs,flush=True) + + if function: + response = {} + try: + result = function(*args, **kwargs) + print(result) + response['result'] = self.format_result(result) + response["call_uid"] = call_uid + response["success"] = True + except Exception as e: + response["call_uid"] = call_uid + response["success"] = False + response["error"] = str(e) + response["traceback"] = traceback.format_exc() + + if call_uid: + await self.send_str(json.dumps(response,default=str)) + else: + await self.send_str(json.dumps({"status": "error", "error":"Method not found.","call_uid": call_uid})) + except Exception as e: + await self.send_str(json.dumps({"success": False,"call_uid": call_uid, "error": str(e), "error": str(e), "traceback": traceback.format_exc()},default=str)) + elif msg.type == aiohttp.WSMsgType.ERROR: + print('ws connection closed with exception %s' % ws.exception()) + + return ws + + + + + +app = web.Application() +view = DatasetWebSocketView() +app.router.add_get('/db', view.handle) + +async def run_server(): + + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 3131) + await site.start() + + print("Server started at http://localhost:8080") + await asyncio.Event().wait() + +async def client(): + print("x") + d = DatasetWrapper() + print("y") + + for x in range(100): + for x in range(100): + if d['test'].insert({"name": "test", "number":x}): + print(".",end="",flush=True) + print("") + print(d['test'].find_one(name="test", order_by="-number")) + + print("DONE") + + + +import time +async def main(): + await run_server() + +import sys + +if __name__ == '__main__': + if sys.argv[1] == 'server': + asyncio.run(main()) + if sys.argv[1] == 'client': + asyncio.run(client()) diff --git a/src/snek/research/serptest.py b/src/snek/research/serptest.py new file mode 100644 index 0000000..94e22be --- /dev/null +++ b/src/snek/research/serptest.py @@ -0,0 +1,51 @@ +import snek.serpentarium + +import time + +from concurrent.futures import ProcessPoolExecutor + +durations = [] + +def task1(): + global durations + client = snek.serpentarium.DatasetWrapper() + + start=time.time() + for x in range(1500): + + client['a'].delete() + client['a'].insert({"foo": x}) + client['a'].find(foo=x) + client['a'].find_one(foo=x) + client['a'].count() + #print(client['a'].find(foo=x) ) + #print(client['a'].find_one(foo=x) ) + #print(client['a'].count()) + client.close() + duration1 = f"{time.time()-start}" + durations.append(duration1) + print(durations) + +with ProcessPoolExecutor(max_workers=4) as executor: + tasks = [executor.submit(task1), + executor.submit(task1), + executor.submit(task1), + executor.submit(task1) + ] + for task in tasks: + task.result() + + +import dataset +client = dataset.connect("sqlite:///snek.db") +start=time.time() +for x in range(1500): + + client['a'].delete() + client['a'].insert({"foo": x}) + print([dict(row) for row in client['a'].find(foo=x)]) + print(dict(client['a'].find_one(foo=x) )) + print(client['a'].count()) +duration2 = f"{time.time()-start}" + +print(duration1,duration2) diff --git a/src/snek/schema.sql b/src/snek/schema.sql new file mode 100644 index 0000000..5b9c9a5 --- /dev/null +++ b/src/snek/schema.sql @@ -0,0 +1,103 @@ +CREATE TABLE IF NOT EXISTS http_access ( + id INTEGER NOT NULL, + created TEXT, + path TEXT, + duration FLOAT, + PRIMARY KEY (id) +); +CREATE TABLE IF NOT EXISTS user ( + id INTEGER NOT NULL, + color TEXT, + created_at TEXT, + deleted_at TEXT, + email TEXT, + is_admin TEXT, + last_ping TEXT, + nick TEXT, + password TEXT, + uid TEXT, + updated_at TEXT, + username TEXT, + PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS ix_user_e2577dd78b54fe28 ON user (uid); +CREATE TABLE IF NOT EXISTS channel ( + id INTEGER NOT NULL, + created_at TEXT, + created_by_uid TEXT, + deleted_at TEXT, + description TEXT, + "index" BIGINT, + is_listed BOOLEAN, + is_private BOOLEAN, + label TEXT, + last_message_on TEXT, + tag TEXT, + uid TEXT, + updated_at TEXT, + PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS ix_channel_e2577dd78b54fe28 ON channel (uid); +CREATE TABLE IF NOT EXISTS channel_member ( + id INTEGER NOT NULL, + channel_uid TEXT, + created_at TEXT, + deleted_at TEXT, + is_banned BOOLEAN, + is_moderator BOOLEAN, + is_muted BOOLEAN, + is_read_only BOOLEAN, + label TEXT, + new_count BIGINT, + uid TEXT, + updated_at TEXT, + user_uid TEXT, + PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS ix_channel_member_e2577dd78b54fe28 ON channel_member (uid); +CREATE TABLE IF NOT EXISTS broadcast ( + id INTEGER NOT NULL, + channel_uid TEXT, + message TEXT, + created_at TEXT, + PRIMARY KEY (id) +); +CREATE TABLE IF NOT EXISTS channel_message ( + id INTEGER NOT NULL, + channel_uid TEXT, + created_at TEXT, + deleted_at TEXT, + html TEXT, + message TEXT, + uid TEXT, + updated_at TEXT, + user_uid TEXT, + PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS ix_channel_message_e2577dd78b54fe28 ON channel_message (uid); +CREATE TABLE IF NOT EXISTS notification ( + id INTEGER NOT NULL, + created_at TEXT, + deleted_at TEXT, + message TEXT, + object_type TEXT, + object_uid TEXT, + read_at TEXT, + uid TEXT, + updated_at TEXT, + user_uid TEXT, + PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS ix_notification_e2577dd78b54fe28 ON notification (uid); +CREATE TABLE IF NOT EXISTS repository ( + id INTEGER NOT NULL, + created_at TEXT, + deleted_at TEXT, + is_private BIGINT, + name TEXT, + uid TEXT, + updated_at TEXT, + user_uid TEXT, + PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS ix_repository_e2577dd78b54fe28 ON repository (uid); diff --git a/src/snek/snode.py b/src/snek/snode.py new file mode 100644 index 0000000..ae42a1f --- /dev/null +++ b/src/snek/snode.py @@ -0,0 +1,116 @@ +import aiohttp + +ENABLED = False + +import aiohttp +import asyncio +from aiohttp import web + +import sqlite3 + +import dataset +from sqlalchemy import event +from sqlalchemy.engine import Engine + +import json + +queue = asyncio.Queue() + +class State: + do_not_sync = False + +async def sync_service(app): + if not ENABLED: + return + session = aiohttp.ClientSession() + async with session.ws_connect('http://localhost:3131/ws') as ws: + async def receive(): + + queries_synced = 0 + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(msg.data) + State.do_not_sync = True + app.db.execute(*data) + app.db.commit() + State.do_not_sync = False + queries_synced += 1 + print("queries synced: " + str(queries_synced)) + print(*data) + await app.services.socket.broadcast_event() + except Exception as e: + print(e) + pass + #print(f"Received: {msg.data}") + elif msg.type == aiohttp.WSMsgType.ERROR: + break + async def write(): + while True: + msg = await queue.get() + await ws.send_str(json.dumps(msg,default=str)) + queue.task_done() + + await asyncio.gather(receive(), write()) + + await session.close() + +queries_queued = 0 +# Attach a listener to log all executed statements +@event.listens_for(Engine, "before_cursor_execute") +def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): + if not ENABLED: + return + global queries_queued + if State.do_not_sync: + print(statement,parameters) + return + if statement.startswith("SELECT"): + return + queue.put_nowait((statement, parameters)) + queries_queued += 1 + print("Queries queued: " + str(queries_queued)) + +async def websocket_handler(request): + queries_broadcasted = 0 + ws = web.WebSocketResponse() + await ws.prepare(request) + request.app['websockets'].append(ws) + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + for client in request.app['websockets']: + if client != ws: + await client.send_str(msg.data) + cursor = request.app['db'].cursor() + data = json.loads(msg.data) + queries_broadcasted += 1 + + cursor.execute(*data) + cursor.close() + print("Queries broadcasted: " + str(queries_broadcasted)) + elif msg.type == aiohttp.WSMsgType.ERROR: + print(f'WebSocket connection closed with exception {ws.exception()}') + + request.app['websockets'].remove(ws) + return ws + +app = web.Application() +app['websockets'] = [] + +app.router.add_get('/ws', websocket_handler) + +async def on_startup(app): + app['db'] = sqlite3.connect('snek.db') + print("Server starting...") + +async def on_cleanup(app): + for ws in app['websockets']: + await ws.close() + app['db'].close() + +app.on_startup.append(on_startup) +app.on_cleanup.append(on_cleanup) + + +if __name__ == '__main__': + web.run_app(app, host='127.0.0.1', port=3131)