From 5e99e894e971434955e55910cdeb04a46759d008 Mon Sep 17 00:00:00 2001 From: retoor Date: Fri, 8 Aug 2025 17:28:35 +0200 Subject: [PATCH] Not working version, issues with get --- pyproject.toml | 3 +- src/snek/app.py | 19 +- src/snek/service/user.py | 2 + src/snek/service/user_property.py | 3 +- src/snek/system/ads.py | 1207 +++++++++++++++++++++++++++++ src/snek/system/mapper.py | 32 +- src/snek/system/service.py | 2 +- src/snek/view/web.py | 4 + 8 files changed, 1243 insertions(+), 29 deletions(-) create mode 100644 src/snek/system/ads.py diff --git a/pyproject.toml b/pyproject.toml index a41d4bd..8497a4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,8 @@ dependencies = [ "pillow-heif", "IP2Location", "bleach", - "sentry-sdk" + "sentry-sdk", + "aiosqlite" ] [tool.setuptools.packages.find] diff --git a/src/snek/app.py b/src/snek/app.py index 73a0742..f5c0f27 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -9,7 +9,7 @@ from contextlib import asynccontextmanager from snek import snode from snek.view.threads import ThreadsView - +from snek.system.ads import AsyncDataSet logging.basicConfig(level=logging.DEBUG) from concurrent.futures import ThreadPoolExecutor from ipaddress import ip_address @@ -142,6 +142,7 @@ class Application(BaseApplication): client_max_size=1024 * 1024 * 1024 * 5 * args, **kwargs, ) + self.db = AsyncDataSet(kwargs["db_path"].replace("sqlite:///", "")) session_setup(self, EncryptedCookieStorage(SESSION_KEY)) self.tasks = asyncio.Queue() self._middlewares.append(session_middleware) @@ -174,7 +175,7 @@ class Application(BaseApplication): self.on_startup.append(self.prepare_asyncio) self.on_startup.append(self.start_user_availability_service) self.on_startup.append(self.start_ssh_server) - self.on_startup.append(self.prepare_database) + #self.on_startup.append(self.prepare_database) async def prepare_stats(self, app): app['stats'] = create_stats_structure() @@ -245,18 +246,8 @@ class Application(BaseApplication): async def prepare_database(self, app): - self.db.query("PRAGMA journal_mode=WAL") - self.db.query("PRAGMA syncnorm=off") - - try: - if not self.db["user"].has_index("username"): - self.db["user"].create_index("username", unique=True) - if not self.db["channel_member"].has_index(["channel_uid", "user_uid"]): - self.db["channel_member"].create_index(["channel_uid", "user_uid"]) - if not self.db["channel_message"].has_index(["channel_uid", "user_uid"]): - self.db["channel_message"].create_index(["channel_uid", "user_uid"]) - except: - pass + await self.db.query_raw("PRAGMA journal_mode=WAL") + await self.db.query_raw("PRAGMA syncnorm=off") await self.services.drive.prepare_all() self.loop.create_task(self.task_runner()) diff --git a/src/snek/service/user.py b/src/snek/service/user.py index 16b317b..7b1b332 100644 --- a/src/snek/service/user.py +++ b/src/snek/service/user.py @@ -101,6 +101,8 @@ class UserService(BaseService): model.username.value = username model.password.value = await security.hash(password) if await self.save(model): + for x in range(10): + print("Jazeker!!!") if model: channel = await self.services.channel.ensure_public_channel( model["uid"] diff --git a/src/snek/service/user_property.py b/src/snek/service/user_property.py index 4d11fa8..acaf0bf 100644 --- a/src/snek/service/user_property.py +++ b/src/snek/service/user_property.py @@ -7,7 +7,8 @@ class UserPropertyService(BaseService): mapper_name = "user_property" async def set(self, user_uid, name, value): - self.mapper.db["user_property"].upsert( + self.mapper.db.upsert( + "user_property", { "user_uid": user_uid, "name": name, diff --git a/src/snek/system/ads.py b/src/snek/system/ads.py new file mode 100644 index 0000000..34ed414 --- /dev/null +++ b/src/snek/system/ads.py @@ -0,0 +1,1207 @@ +import re +import json +from uuid import uuid4 +from datetime import datetime, timezone +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + AsyncGenerator, + Union, + Tuple, + Set, +) +from pathlib import Path +import aiosqlite +import unittest +from types import SimpleNamespace +import asyncio +import aiohttp +from aiohttp import web +import socket +import os +import atexit + + +class AsyncDataSet: + """ + Distributed AsyncDataSet with client-server model over Unix sockets. + + Parameters: + - file: Path to the SQLite database file + - socket_path: Path to Unix socket for inter-process communication (default: "ads.sock") + - max_concurrent_queries: Maximum concurrent database operations (default: 1) + Set to 1 to prevent "database is locked" errors with SQLite + """ + + _KV_TABLE = "__kv_store" + _DEFAULT_COLUMNS = { + "uid": "TEXT PRIMARY KEY", + "created_at": "TEXT", + "updated_at": "TEXT", + "deleted_at": "TEXT", + } + + def __init__( + self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1 + ): + self._file = file + self._socket_path = socket_path + self._table_columns_cache: Dict[str, Set[str]] = {} + self._is_server = None # None means not initialized yet + self._server = None + self._runner = None + self._client_session = None + self._initialized = False + self._init_lock = None # Will be created when needed + self._max_concurrent_queries = max_concurrent_queries + self._db_semaphore = None # Will be created when needed + self._db_connection = None # Persistent database connection + + async def _ensure_initialized(self): + """Ensure the instance is initialized as server or client.""" + if self._initialized: + return + + # Create lock if needed (first time in async context) + if self._init_lock is None: + self._init_lock = asyncio.Lock() + + async with self._init_lock: + if self._initialized: + return + + await self._initialize() + + # Create semaphore for database operations (server only) + if self._is_server: + self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries) + + self._initialized = True + + async def _initialize(self): + """Initialize as server or client.""" + try: + # Try to create Unix socket + if os.path.exists(self._socket_path): + # Check if socket is active + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.connect(self._socket_path) + sock.close() + # Socket is active, we're a client + self._is_server = False + await self._setup_client() + except (ConnectionRefusedError, FileNotFoundError): + # Socket file exists but not active, remove and become server + os.unlink(self._socket_path) + self._is_server = True + await self._setup_server() + else: + # No socket file, become server + self._is_server = True + await self._setup_server() + except Exception as e: + # Fallback to client mode + self._is_server = False + await self._setup_client() + + # Establish persistent database connection + self._db_connection = await aiosqlite.connect(self._file) + + async def _setup_server(self): + """Setup aiohttp server.""" + app = web.Application() + + # Add routes for all methods + app.router.add_post("/insert", self._handle_insert) + app.router.add_post("/update", self._handle_update) + app.router.add_post("/delete", self._handle_delete) + app.router.add_post("/upsert", self._handle_upsert) + app.router.add_post("/get", self._handle_get) + app.router.add_post("/find", self._handle_find) + app.router.add_post("/count", self._handle_count) + app.router.add_post("/exists", self._handle_exists) + app.router.add_post("/kv_set", self._handle_kv_set) + app.router.add_post("/kv_get", self._handle_kv_get) + app.router.add_post("/execute_raw", self._handle_execute_raw) + app.router.add_post("/query_raw", self._handle_query_raw) + app.router.add_post("/query_one", self._handle_query_one) + app.router.add_post("/create_table", self._handle_create_table) + app.router.add_post("/insert_unique", self._handle_insert_unique) + app.router.add_post("/aggregate", self._handle_aggregate) + + self._runner = web.AppRunner(app) + await self._runner.setup() + self._server = web.UnixSite(self._runner, self._socket_path) + await self._server.start() + + # Register cleanup + def cleanup(): + if os.path.exists(self._socket_path): + os.unlink(self._socket_path) + + atexit.register(cleanup) + + async def _setup_client(self): + """Setup aiohttp client.""" + connector = aiohttp.UnixConnector(path=self._socket_path) + self._client_session = aiohttp.ClientSession(connector=connector) + + async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any: + """Make HTTP request to server.""" + if not self._client_session: + await self._setup_client() + + url = f"http://localhost/{endpoint}" + async with self._client_session.post(url, json=data) as resp: + result = await resp.json() + if result.get("error"): + raise Exception(result["error"]) + return result.get("result") + + # Server handlers + async def _handle_insert(self, request): + data = await request.json() + try: + result = await self._server_insert( + data["table"], data["args"], data.get("return_id", False) + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_update(self, request): + data = await request.json() + try: + result = await self._server_update( + data["table"], data["args"], data.get("where") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_delete(self, request): + data = await request.json() + try: + result = await self._server_delete(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_upsert(self, request): + data = await request.json() + try: + result = await self._server_upsert( + data["table"], data["args"], data.get("where") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_get(self, request): + data = await request.json() + try: + result = await self._server_get(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_find(self, request): + data = await request.json() + try: + try: + _limit = data.pop("_limit",60) + except: + pass + + result = await self._server_find( + data["table"], + data.get("where"), + _limit=_limit, + offset=data.get("offset", 0), + order_by=data.get("order_by"), + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_count(self, request): + data = await request.json() + try: + result = await self._server_count(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_exists(self, request): + data = await request.json() + try: + result = await self._server_exists(data["table"], data["where"]) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_kv_set(self, request): + data = await request.json() + try: + await self._server_kv_set( + data["key"], data["value"], table=data.get("table") + ) + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_kv_get(self, request): + data = await request.json() + try: + result = await self._server_kv_get( + data["key"], default=data.get("default"), table=data.get("table") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_execute_raw(self, request): + data = await request.json() + try: + result = await self._server_execute_raw( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response( + {"result": result.rowcount if hasattr(result, "rowcount") else None} + ) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_query_raw(self, request): + data = await request.json() + try: + result = await self._server_query_raw( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_query_one(self, request): + data = await request.json() + try: + result = await self._server_query_one( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_create_table(self, request): + data = await request.json() + try: + await self._server_create_table( + data["table"], data["schema"], data.get("constraints") + ) + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_insert_unique(self, request): + data = await request.json() + try: + result = await self._server_insert_unique( + data["table"], data["args"], data["unique_fields"] + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_aggregate(self, request): + data = await request.json() + try: + result = await self._server_aggregate( + data["table"], + data["function"], + data.get("column", "*"), + data.get("where"), + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + # Public methods that delegate to server or client + async def insert( + self, table: str, args: Dict[str, Any], return_id: bool = False + ) -> Union[str, int]: + await self._ensure_initialized() + if self._is_server: + return await self._server_insert(table, args, return_id) + else: + return await self._make_request( + "insert", {"table": table, "args": args, "return_id": return_id} + ) + + async def update( + self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None + ) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_update(table, args, where) + else: + return await self._make_request( + "update", {"table": table, "args": args, "where": where} + ) + + async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_delete(table, where) + else: + return await self._make_request("delete", {"table": table, "where": where}) + + async def upsert( + self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None + ) -> str | None: + await self._ensure_initialized() + if self._is_server: + return await self._server_upsert(table, args, where) + else: + return await self._make_request( + "upsert", {"table": table, "args": args, "where": where} + ) + + async def get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_get(table, where) + else: + return await self._make_request("get", {"table": table, "where": where}) + + async def find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + _limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_find( + table, where, _limit=_limit, offset=offset, order_by=order_by + ) + else: + return await self._make_request( + "find", + { + "table": table, + "where": where, + "_limit": _limit, + "offset": offset, + "order_by": order_by, + }, + ) + + async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_count(table, where) + else: + return await self._make_request("count", {"table": table, "where": where}) + + async def exists(self, table: str, where: Dict[str, Any]) -> bool: + await self._ensure_initialized() + if self._is_server: + return await self._server_exists(table, where) + else: + return await self._make_request("exists", {"table": table, "where": where}) + + async def kv_set(self, key: str, value: Any, *, table: str | None = None) -> None: + await self._ensure_initialized() + if self._is_server: + return await self._server_kv_set(key, value, table=table) + else: + return await self._make_request( + "kv_set", {"key": key, "value": value, "table": table} + ) + + async def kv_get( + self, key: str, *, default: Any = None, table: str | None = None + ) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_kv_get(key, default=default, table=table) + else: + return await self._make_request( + "kv_get", {"key": key, "default": default, "table": table} + ) + + async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_execute_raw(sql, params) + else: + return await self._make_request( + "execute_raw", {"sql": sql, "params": list(params) if params else None} + ) + + async def query_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> List[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_query_raw(sql, params) + else: + return await self._make_request( + "query_raw", {"sql": sql, "params": list(params) if params else None} + ) + + async def query_one( + self, sql: str, params: Optional[Tuple] = None + ) -> Optional[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_query_one(sql, params) + else: + return await self._make_request( + "query_one", {"sql": sql, "params": list(params) if params else None} + ) + + async def create_table( + self, + table: str, + schema: Dict[str, str], + constraints: Optional[List[str]] = None, + ): + await self._ensure_initialized() + if self._is_server: + return await self._server_create_table(table, schema, constraints) + else: + return await self._make_request( + "create_table", + {"table": table, "schema": schema, "constraints": constraints}, + ) + + async def insert_unique( + self, table: str, args: Dict[str, Any], unique_fields: List[str] + ) -> Union[str, None]: + await self._ensure_initialized() + if self._is_server: + return await self._server_insert_unique(table, args, unique_fields) + else: + return await self._make_request( + "insert_unique", + {"table": table, "args": args, "unique_fields": unique_fields}, + ) + + async def aggregate( + self, + table: str, + function: str, + column: str = "*", + where: Optional[Dict[str, Any]] = None, + ) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_aggregate(table, function, column, where) + else: + return await self._make_request( + "aggregate", + { + "table": table, + "function": function, + "column": column, + "where": where, + }, + ) + + async def transaction(self): + """Context manager for transactions.""" + await self._ensure_initialized() + if self._is_server: + return TransactionContext(self._db_connection, self._db_semaphore) + else: + raise NotImplementedError("Transactions not supported in client mode") + + # Server implementation methods (original logic) + @staticmethod + def _utc_iso() -> str: + return ( + datetime.now(timezone.utc) + .replace(microsecond=0) + .isoformat() + .replace("+00:00", "Z") + ) + + @staticmethod + def _py_to_sqlite_type(value: Any) -> str: + if value is None: + return "TEXT" + if isinstance(value, bool): + return "INTEGER" + if isinstance(value, int): + return "INTEGER" + if isinstance(value, float): + return "REAL" + if isinstance(value, (bytes, bytearray, memoryview)): + return "BLOB" + return "TEXT" + + async def _get_table_columns(self, table: str) -> Set[str]: + """Get actual columns that exist in the table.""" + if table in self._table_columns_cache: + return self._table_columns_cache[table] + + columns = set() + try: + async with self._db_semaphore: + async with self._db_connection.execute( + f"PRAGMA table_info({table})" + ) as cursor: + async for row in cursor: + columns.add(row[1]) # Column name is at index 1 + self._table_columns_cache[table] = columns + except: + pass + return columns + + async def _invalidate_column_cache(self, table: str): + """Invalidate column cache for a table.""" + if table in self._table_columns_cache: + del self._table_columns_cache[table] + + async def _ensure_column(self, table: str, name: str, value: Any) -> None: + col_type = self._py_to_sqlite_type(value) + try: + async with self._db_semaphore: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + except aiosqlite.OperationalError as e: + if "duplicate column name" in str(e).lower(): + pass # Column already exists + else: + raise + + async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: + # Always include default columns + cols = self._DEFAULT_COLUMNS.copy() + + # Add columns from col_sources + for key, val in col_sources.items(): + if key not in cols: + cols[key] = self._py_to_sqlite_type(val) + + columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items()) + async with self._db_semaphore: + await self._db_connection.execute( + f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + + async def _table_exists(self, table: str) -> bool: + """Check if a table exists.""" + async with self._db_semaphore: + async with self._db_connection.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,) + ) as cursor: + return await cursor.fetchone() is not None + + _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)") + _RE_NO_TABLE = re.compile(r"no such table: (\w+)") + + @classmethod + def _missing_column_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_COLUMN.search(str(err)) + return m.group(1) if m else None + + @classmethod + def _missing_table_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_TABLE.search(str(err)) + return m.group(1) if m else None + + async def _safe_execute( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + max_retries: int = 10, + ) -> aiosqlite.Cursor: + retries = 0 + while retries < max_retries: + try: + async with self._db_semaphore: + cursor = await self._db_connection.execute(sql, params) + await self._db_connection.commit() + return cursor + except aiosqlite.OperationalError as err: + retries += 1 + err_str = str(err).lower() + + # Handle missing column + col = self._missing_column_from_error(err) + if col: + if col in col_sources: + await self._ensure_column(table, col, col_sources[col]) + else: + # Column not in sources, ensure it with NULL/TEXT type + await self._ensure_column(table, col, None) + continue + + # Handle missing table + tbl = self._missing_table_from_error(err) + if tbl: + await self._ensure_table(tbl, col_sources) + continue + + # Handle other column-related errors + if "has no column named" in err_str: + # Extract column name differently + match = re.search(r"table \w+ has no column named (\w+)", err_str) + if match: + col_name = match.group(1) + if col_name in col_sources: + await self._ensure_column( + table, col_name, col_sources[col_name] + ) + else: + await self._ensure_column(table, col_name, None) + continue + + raise + raise Exception(f"Max retries ({max_retries}) exceeded") + + async def _filter_existing_columns( + self, table: str, data: Dict[str, Any] + ) -> Dict[str, Any]: + """Filter data to only include columns that exist in the table.""" + if not await self._table_exists(table): + return data + + existing_columns = await self._get_table_columns(table) + if not existing_columns: + return data + + return {k: v for k, v in data.items() if k in existing_columns} + + async def _safe_query( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + ) -> AsyncGenerator[Dict[str, Any], None]: + # Check if table exists first + if not await self._table_exists(table): + return + + max_retries = 10 + retries = 0 + + while retries < max_retries: + try: + async with self._db_semaphore: + self._db_connection.row_factory = aiosqlite.Row + async with self._db_connection.execute(sql, params) as cursor: + # Fetch all rows while holding the semaphore + rows = await cursor.fetchall() + + # Yield rows after releasing the semaphore + for row in rows: + yield dict(row) + return + except aiosqlite.OperationalError as err: + retries += 1 + err_str = str(err).lower() + + # Handle missing table + tbl = self._missing_table_from_error(err) + if tbl: + # For queries, if table doesn't exist, just return empty + return + + # Handle missing column in WHERE clause or SELECT + if "no such column" in err_str: + # For queries with missing columns, return empty + return + + raise + + @staticmethod + def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: + if not where: + return "", [] + clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) + return " WHERE " + " AND ".join(clauses), list(vals) + + async def _server_insert( + self, table: str, args: Dict[str, Any], return_id: bool = False + ) -> Union[str, int]: + """Insert a record. If return_id=True, returns auto-incremented ID instead of UUID.""" + uid = str(uuid4()) + now = self._utc_iso() + record = { + "uid": uid, + "created_at": now, + "updated_at": now, + "deleted_at": None, + **args, + } + + # Ensure table exists with all needed columns + await self._ensure_table(table, record) + + # Handle auto-increment ID if requested + if return_id and "id" not in args: + # Ensure id column exists + async with self._db_semaphore: + # Add id column if it doesn't exist + try: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT" + ) + await self._db_connection.commit() + except aiosqlite.OperationalError as e: + if "duplicate column name" not in str(e).lower(): + # Try without autoincrement constraint + try: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN id INTEGER" + ) + await self._db_connection.commit() + except: + pass + + await self._invalidate_column_cache(table) + + # Insert and get lastrowid + cols = "`" + "`, `".join(record.keys()) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + cursor = await self._safe_execute(table, sql, list(record.values()), record) + return cursor.lastrowid + + cols = "`" + "`, `".join(record) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + await self._safe_execute(table, sql, list(record.values()), record) + return uid + + async def _server_update( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> int: + if not args: + return 0 + + # Check if table exists + if not await self._table_exists(table): + return 0 + + args["updated_at"] = self._utc_iso() + + # Ensure all columns exist + all_cols = {**args, **(where or {})} + await self._ensure_table(table, all_cols) + for col, val in all_cols.items(): + await self._ensure_column(table, col, val) + + set_clause = ", ".join(f"`{k}` = ?" for k in args) + where_clause, where_params = self._build_where(where) + sql = f"UPDATE {table} SET {set_clause}{where_clause}" + params = list(args.values()) + where_params + cur = await self._safe_execute(table, sql, params, all_cols) + return cur.rowcount + + async def _server_delete( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> int: + # Check if table exists + if not await self._table_exists(table): + return 0 + + where_clause, where_params = self._build_where(where) + sql = f"DELETE FROM {table}{where_clause}" + cur = await self._safe_execute(table, sql, where_params, where or {}) + return cur.rowcount + + async def _server_upsert( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> str | None: + if not args: + raise ValueError("Nothing to update. Empty dict given.") + args["updated_at"] = self._utc_iso() + affected = await self._server_update(table, args, where) + if affected: + rec = await self._server_get(table, where) + return rec.get("uid") if rec else None + merged = {**(where or {}), **args} + return await self._server_insert(table, merged) + + async def _server_get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + where_clause, where_params = self._build_where(where) + sql = f"SELECT * FROM {table}{where_clause} LIMIT 1" + async for row in self._safe_query(table, sql, where_params, where or {}): + return row + return None + + async def _server_find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + _limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Find records with optional ordering.""" + + try: + _limit = where.pop("_limit") + except: + pass + + + where_clause, where_params = self._build_where(where) + order_clause = f" ORDER BY {order_by}" if order_by else "" + extra = (f" LIMIT {_limit}" if _limit else "") + ( + f" OFFSET {offset}" if offset else "" + ) + sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}" + return [ + row async for row in self._safe_query(table, sql, where_params, where or {}) + ] + + async def _server_count( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> int: + # Check if table exists + if not await self._table_exists(table): + return 0 + + where_clause, where_params = self._build_where(where) + sql = f"SELECT COUNT(*) FROM {table}{where_clause}" + gen = self._safe_query(table, sql, where_params, where or {}) + async for row in gen: + return next(iter(row.values()), 0) + return 0 + + async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool: + return (await self._server_count(table, where)) > 0 + + async def _server_kv_set( + self, + key: str, + value: Any, + *, + table: str | None = None, + ) -> None: + tbl = table or self._KV_TABLE + json_val = json.dumps(value, default=str) + await self._server_upsert(tbl, {"value": json_val}, {"key": key}) + + async def _server_kv_get( + self, + key: str, + *, + default: Any = None, + table: str | None = None, + ) -> Any: + tbl = table or self._KV_TABLE + row = await self._server_get(tbl, {"key": key}) + if not row: + return default + try: + return json.loads(row["value"]) + except Exception: + return default + + async def query(self, sql, *args): + return await self._server_query_raw(sql, args) + + async def _server_execute_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> Any: + """Execute raw SQL for complex queries like JOINs.""" + async with self._db_semaphore: + cursor = await self._db_connection.execute(sql, params or ()) + await self._db_connection.commit() + for x in range(10): + print(cursor) + return cursor + + async def _server_query_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> List[Dict[str, Any]]: + """Execute raw SQL query and return results as list of dicts.""" + try: + async with self._db_semaphore: + self._db_connection.row_factory = aiosqlite.Row + async with self._db_connection.execute(sql, params or ()) as cursor: + return [dict(row) async for row in cursor] + except aiosqlite.OperationalError as ex: + print(ex) + # Return empty list if query fails + return [] + + async def _server_query_one( + self, sql: str, params: Optional[Tuple] = None + ) -> Optional[Dict[str, Any]]: + """Execute raw SQL query and return single result.""" + results = await self._server_query_raw(sql + " LIMIT 1", params) + return results[0] if results else None + + async def _server_create_table( + self, + table: str, + schema: Dict[str, str], + constraints: Optional[List[str]] = None, + ): + """Create table with custom schema and constraints. Always includes default columns.""" + # Merge default columns with custom schema + full_schema = self._DEFAULT_COLUMNS.copy() + full_schema.update(schema) + + columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()] + if constraints: + columns.extend(constraints) + columns_sql = ", ".join(columns) + + async with self._db_semaphore: + await self._db_connection.execute( + f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + + async def _server_insert_unique( + self, table: str, args: Dict[str, Any], unique_fields: List[str] + ) -> Union[str, None]: + """Insert with unique constraint handling. Returns uid on success, None if duplicate.""" + try: + return await self._server_insert(table, args) + except aiosqlite.IntegrityError as e: + if "UNIQUE" in str(e): + return None + raise + + async def _server_aggregate( + self, + table: str, + function: str, + column: str = "*", + where: Optional[Dict[str, Any]] = None, + ) -> Any: + """Perform aggregate functions like SUM, AVG, MAX, MIN.""" + # Check if table exists + if not await self._table_exists(table): + return None + + where_clause, where_params = self._build_where(where) + sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}" + result = await self._server_query_one(sql, tuple(where_params)) + return result["result"] if result else None + + async def close(self): + """Close the connection or server.""" + if not self._initialized: + return + + if self._client_session: + await self._client_session.close() + if self._runner: + await self._runner.cleanup() + if os.path.exists(self._socket_path) and self._is_server: + os.unlink(self._socket_path) + if self._db_connection: + await self._db_connection.close() + + +class TransactionContext: + """Context manager for database transactions.""" + + def __init__( + self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None + ): + self.db_connection = db_connection + self.semaphore = semaphore + + async def __aenter__(self): + if self.semaphore: + await self.semaphore.acquire() + try: + await self.db_connection.execute("BEGIN") + return self.db_connection + except: + if self.semaphore: + self.semaphore.release() + raise + + async def __aexit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type is None: + await self.db_connection.commit() + else: + await self.db_connection.rollback() + finally: + if self.semaphore: + self.semaphore.release() + + +# Test cases remain the same but with additional tests for new functionality +class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.db_path = Path("temp_test.db") + if self.db_path.exists(): + self.db_path.unlink() + self.connector = AsyncDataSet(str(self.db_path), max_concurrent_queries=1) + + async def asyncTearDown(self): + await self.connector.close() + if self.db_path.exists(): + self.db_path.unlink() + + async def test_insert_and_get(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertIsNotNone(rec) + self.assertEqual(rec["name"], "John Doe") + + async def test_get_nonexistent(self): + result = await self.connector.get("people", {"name": "Jane Doe"}) + self.assertIsNone(result) + + async def test_update(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + await self.connector.update("people", {"age": 31}, {"name": "John Doe"}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertEqual(rec["age"], 31) + + async def test_order_by(self): + await self.connector.insert("people", {"name": "Alice", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 30}) + await self.connector.insert("people", {"name": "Charlie", "age": 20}) + + results = await self.connector.find("people", order_by="age ASC") + self.assertEqual(results[0]["name"], "Charlie") + self.assertEqual(results[-1]["name"], "Bob") + + async def test_raw_query(self): + await self.connector.insert("people", {"name": "John", "age": 30}) + await self.connector.insert("people", {"name": "Jane", "age": 25}) + + results = await self.connector.query_raw( + "SELECT * FROM people WHERE age > ?", (26,) + ) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "John") + + async def test_aggregate(self): + await self.connector.insert("people", {"name": "John", "age": 30}) + await self.connector.insert("people", {"name": "Jane", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 35}) + + avg_age = await self.connector.aggregate("people", "AVG", "age") + self.assertEqual(avg_age, 30) + + max_age = await self.connector.aggregate("people", "MAX", "age") + self.assertEqual(max_age, 35) + + async def test_insert_with_auto_id(self): + # Test auto-increment ID functionality + id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True) + id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True) + self.assertEqual(id2, id1 + 1) + + async def test_transaction(self): + await self.connector._ensure_initialized() + if self.connector._is_server: + async with self.connector.transaction() as conn: + await conn.execute( + "INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", + ("test-uid", "John", 30, "2024-01-01", "2024-01-01"), + ) + # Transaction will be committed + + rec = await self.connector.get("people", {"name": "John"}) + self.assertIsNotNone(rec) + + async def test_create_custom_table(self): + schema = { + "id": "INTEGER PRIMARY KEY AUTOINCREMENT", + "username": "TEXT NOT NULL", + "email": "TEXT NOT NULL", + "score": "INTEGER DEFAULT 0", + } + constraints = ["UNIQUE(username)", "UNIQUE(email)"] + + await self.connector.create_table("users", schema, constraints) + + # Test that table was created with constraints + result = await self.connector.insert_unique( + "users", + {"username": "john", "email": "john@example.com"}, + ["username", "email"], + ) + self.assertIsNotNone(result) + + # Test duplicate insert + result = await self.connector.insert_unique( + "users", + {"username": "john", "email": "different@example.com"}, + ["username", "email"], + ) + self.assertIsNone(result) + + async def test_missing_table_operations(self): + # Test operations on non-existent tables + self.assertEqual(await self.connector.count("nonexistent"), 0) + self.assertEqual(await self.connector.find("nonexistent"), []) + self.assertIsNone(await self.connector.get("nonexistent")) + self.assertFalse(await self.connector.exists("nonexistent", {"id": 1})) + self.assertEqual(await self.connector.delete("nonexistent"), 0) + self.assertEqual( + await self.connector.update("nonexistent", {"name": "test"}), 0 + ) + + async def test_auto_column_creation(self): + # Insert with new columns that don't exist yet + await self.connector.insert( + "dynamic", {"col1": "value1", "col2": 42, "col3": 3.14} + ) + + # Add more columns in next insert + await self.connector.insert( + "dynamic", {"col1": "value2", "col4": True, "col5": None} + ) + + # All records should be retrievable + records = await self.connector.find("dynamic") + self.assertEqual(len(records), 2) + + +if __name__ == "__main__": + unittest.main() + diff --git a/src/snek/system/mapper.py b/src/snek/system/mapper.py index 2e1c75d..c4f765e 100644 --- a/src/snek/system/mapper.py +++ b/src/snek/system/mapper.py @@ -1,7 +1,7 @@ DEFAULT_LIMIT = 30 import asyncio import typing - +import traceback from snek.system.model import BaseModel @@ -51,7 +51,9 @@ class BaseMapper: kwargs["uid"] = uid if not kwargs.get("deleted_at"): kwargs["deleted_at"] = None - record = await self.run_in_executor(self.table.find_one, **kwargs) + #traceback.print_exc() + + record = await self.db.get(self.table_name, kwargs) if not record: return None record = dict(record) @@ -61,23 +63,29 @@ class BaseMapper: return model async def exists(self, **kwargs): - return await self.run_in_executor(self.table.exists, **kwargs) + return await self.db.count(self.table_name, kwargs) + + + #return await self.run_in_executor(self.table.exists, **kwargs) async def count(self, **kwargs) -> int: - return await self.run_in_executor(self.table.count, **kwargs) + return await self.db.count(self.table_name,kwargs) + async def save(self, model: BaseModel) -> bool: if not model.record.get("uid"): raise Exception(f"Attempt to save without uid: {model.record}.") - model.updated_at.update() - return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True) + model.updated_at.update() + await self.upsert(model) + return model + #return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True) async def find(self, **kwargs) -> typing.AsyncGenerator: if not kwargs.get("_limit"): kwargs["_limit"] = self.default_limit if not kwargs.get("deleted_at"): kwargs["deleted_at"] = None - for record in await self.run_in_executor(self.table.find, **kwargs): + for record in await self.db.find(self.table_name, kwargs): model = await self.new() for key, value in record.items(): model[key] = value @@ -88,21 +96,21 @@ class BaseMapper: return "insert" in sql or "update" in sql or "delete" in sql async def query(self, sql, *args): - for record in await self.run_in_executor(self.db.query, sql, *args, use_semaphore=await self._use_semaphore(sql)): + for record in await self.db.query(sql, *args): yield dict(record) async def update(self, model): if not model["deleted_at"] is None: raise Exception("Can't update deleted record.") model.updated_at.update() - return await self.run_in_executor(self.table.update, model.record, ["uid"],use_semaphore=True) + return await self.db.update(self.table_name, model.record, {"uid": model["uid"]}) async def upsert(self, model): model.updated_at.update() - return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True) + await self.db.upsert(self.table_name, model.record, {"uid": model["uid"]}) + return model async def delete(self, **kwargs) -> int: if not kwargs or not isinstance(kwargs, dict): raise Exception("Can't execute delete with no filter.") - kwargs["use_semaphore"] = True - return await self.run_in_executor(self.table.delete, **kwargs) + return await self.db.delete(self.table_name, kwargs) diff --git a/src/snek/system/service.py b/src/snek/system/service.py index 01570d8..b47dbc4 100644 --- a/src/snek/system/service.py +++ b/src/snek/system/service.py @@ -50,7 +50,7 @@ class BaseService: if result and result.__class__ == self.mapper.model_class: return result kwargs["uid"] = uid - + print(kwargs,"ZZZZZZZ") result = await self.mapper.get(**kwargs) if result: await self.cache.set(result["uid"], result) diff --git a/src/snek/view/web.py b/src/snek/view/web.py index 77ff479..1152d4c 100644 --- a/src/snek/view/web.py +++ b/src/snek/view/web.py @@ -38,6 +38,10 @@ class WebView(BaseView): channel = await self.services.channel.get( uid=self.request.match_info.get("channel") ) + print(self.session.get("uid"),"ZZZZZZZZZZ") + qq = await self.services.user.get(uid=self.session.get("uid")) + + print("GGGGGGGGGG",qq) if not channel: user = await self.services.user.get( uid=self.request.match_info.get("channel")