import re import json from uuid import uuid4 from datetime import datetime, timezone from typing import ( Any, Dict, Iterable, List, Optional, AsyncGenerator, Union, Tuple, Set, ) from pathlib import Path import aiosqlite import unittest from types import SimpleNamespace import asyncio import aiohttp from aiohttp import web import socket import os import atexit class AsyncDataSet: """ Distributed AsyncDataSet with client-server model over Unix sockets. Parameters: - file: Path to the SQLite database file - socket_path: Path to Unix socket for inter-process communication (default: "ads.sock") - max_concurrent_queries: Maximum concurrent database operations (default: 1) Set to 1 to prevent "database is locked" errors with SQLite """ _KV_TABLE = "__kv_store" _DEFAULT_COLUMNS = { "uid": "TEXT PRIMARY KEY", "created_at": "TEXT", "updated_at": "TEXT", "deleted_at": "TEXT", } def __init__( self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1 ): self._file = file self._socket_path = socket_path self._table_columns_cache: Dict[str, Set[str]] = {} self._is_server = None # None means not initialized yet self._server = None self._runner = None self._client_session = None self._initialized = False self._init_lock = None # Will be created when needed self._max_concurrent_queries = max_concurrent_queries self._db_semaphore = None # Will be created when needed self._db_connection = None # Persistent database connection async def _ensure_initialized(self): """Ensure the instance is initialized as server or client.""" if self._initialized: return # Create lock if needed (first time in async context) if self._init_lock is None: self._init_lock = asyncio.Lock() async with self._init_lock: if self._initialized: return await self._initialize() # Create semaphore for database operations (server only) if self._is_server: self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries) self._initialized = True async def _initialize(self): """Initialize as server or client.""" try: # Try to create Unix socket if os.path.exists(self._socket_path): # Check if socket is active sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: sock.connect(self._socket_path) sock.close() # Socket is active, we're a client self._is_server = False await self._setup_client() except (ConnectionRefusedError, FileNotFoundError): # Socket file exists but not active, remove and become server os.unlink(self._socket_path) self._is_server = True await self._setup_server() else: # No socket file, become server self._is_server = True await self._setup_server() except Exception as e: # Fallback to client mode self._is_server = False await self._setup_client() # Establish persistent database connection self._db_connection = await aiosqlite.connect(self._file) async def _setup_server(self): """Setup aiohttp server.""" app = web.Application() # Add routes for all methods app.router.add_post("/insert", self._handle_insert) app.router.add_post("/update", self._handle_update) app.router.add_post("/delete", self._handle_delete) app.router.add_post("/upsert", self._handle_upsert) app.router.add_post("/get", self._handle_get) app.router.add_post("/find", self._handle_find) app.router.add_post("/count", self._handle_count) app.router.add_post("/exists", self._handle_exists) app.router.add_post("/kv_set", self._handle_kv_set) app.router.add_post("/kv_get", self._handle_kv_get) app.router.add_post("/execute_raw", self._handle_execute_raw) app.router.add_post("/query_raw", self._handle_query_raw) app.router.add_post("/query_one", self._handle_query_one) app.router.add_post("/create_table", self._handle_create_table) app.router.add_post("/insert_unique", self._handle_insert_unique) app.router.add_post("/aggregate", self._handle_aggregate) self._runner = web.AppRunner(app) await self._runner.setup() self._server = web.UnixSite(self._runner, self._socket_path) await self._server.start() # Register cleanup def cleanup(): if os.path.exists(self._socket_path): os.unlink(self._socket_path) atexit.register(cleanup) async def _setup_client(self): """Setup aiohttp client.""" connector = aiohttp.UnixConnector(path=self._socket_path) self._client_session = aiohttp.ClientSession(connector=connector) async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any: """Make HTTP request to server.""" if not self._client_session: await self._setup_client() url = f"http://localhost/{endpoint}" async with self._client_session.post(url, json=data) as resp: result = await resp.json() if result.get("error"): raise Exception(result["error"]) return result.get("result") # Server handlers async def _handle_insert(self, request): data = await request.json() try: result = await self._server_insert( data["table"], data["args"], data.get("return_id", False) ) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_update(self, request): data = await request.json() try: result = await self._server_update( data["table"], data["args"], data.get("where") ) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_delete(self, request): data = await request.json() try: result = await self._server_delete(data["table"], data.get("where")) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_upsert(self, request): data = await request.json() try: result = await self._server_upsert( data["table"], data["args"], data.get("where") ) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_get(self, request): data = await request.json() try: result = await self._server_get(data["table"], data.get("where")) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_find(self, request): data = await request.json() try: result = await self._server_find( data["table"], data.get("where"), limit=data.get("limit", 0), offset=data.get("offset", 0), order_by=data.get("order_by"), ) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_count(self, request): data = await request.json() try: result = await self._server_count(data["table"], data.get("where")) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_exists(self, request): data = await request.json() try: result = await self._server_exists(data["table"], data["where"]) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_kv_set(self, request): data = await request.json() try: await self._server_kv_set( data["key"], data["value"], table=data.get("table") ) return web.json_response({"result": None}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_kv_get(self, request): data = await request.json() try: result = await self._server_kv_get( data["key"], default=data.get("default"), table=data.get("table") ) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_execute_raw(self, request): data = await request.json() try: result = await self._server_execute_raw( data["sql"], tuple(data.get("params", [])) if data.get("params") else None, ) return web.json_response( {"result": result.rowcount if hasattr(result, "rowcount") else None} ) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_query_raw(self, request): data = await request.json() try: result = await self._server_query_raw( data["sql"], tuple(data.get("params", [])) if data.get("params") else None, ) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_query_one(self, request): data = await request.json() try: result = await self._server_query_one( data["sql"], tuple(data.get("params", [])) if data.get("params") else None, ) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_create_table(self, request): data = await request.json() try: await self._server_create_table( data["table"], data["schema"], data.get("constraints") ) return web.json_response({"result": None}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_insert_unique(self, request): data = await request.json() try: result = await self._server_insert_unique( data["table"], data["args"], data["unique_fields"] ) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) async def _handle_aggregate(self, request): data = await request.json() try: result = await self._server_aggregate( data["table"], data["function"], data.get("column", "*"), data.get("where"), ) return web.json_response({"result": result}) except Exception as e: return web.json_response({"error": str(e)}, status=500) # Public methods that delegate to server or client async def insert( self, table: str, args: Dict[str, Any], return_id: bool = False ) -> Union[str, int]: await self._ensure_initialized() if self._is_server: return await self._server_insert(table, args, return_id) else: return await self._make_request( "insert", {"table": table, "args": args, "return_id": return_id} ) async def update( self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None ) -> int: await self._ensure_initialized() if self._is_server: return await self._server_update(table, args, where) else: return await self._make_request( "update", {"table": table, "args": args, "where": where} ) async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: await self._ensure_initialized() if self._is_server: return await self._server_delete(table, where) else: return await self._make_request("delete", {"table": table, "where": where}) async def upsert( self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None ) -> str | None: await self._ensure_initialized() if self._is_server: return await self._server_upsert(table, args, where) else: return await self._make_request( "upsert", {"table": table, "args": args, "where": where} ) async def get( self, table: str, where: Optional[Dict[str, Any]] = None ) -> Optional[Dict[str, Any]]: await self._ensure_initialized() if self._is_server: return await self._server_get(table, where) else: return await self._make_request("get", {"table": table, "where": where}) async def find( self, table: str, where: Optional[Dict[str, Any]] = None, *, limit: int = 0, offset: int = 0, order_by: Optional[str] = None, ) -> List[Dict[str, Any]]: await self._ensure_initialized() if self._is_server: return await self._server_find( table, where, limit=limit, offset=offset, order_by=order_by ) else: return await self._make_request( "find", { "table": table, "where": where, "limit": limit, "offset": offset, "order_by": order_by, }, ) async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: await self._ensure_initialized() if self._is_server: return await self._server_count(table, where) else: return await self._make_request("count", {"table": table, "where": where}) async def exists(self, table: str, where: Dict[str, Any]) -> bool: await self._ensure_initialized() if self._is_server: return await self._server_exists(table, where) else: return await self._make_request("exists", {"table": table, "where": where}) async def kv_set(self, key: str, value: Any, *, table: str | None = None) -> None: await self._ensure_initialized() if self._is_server: return await self._server_kv_set(key, value, table=table) else: return await self._make_request( "kv_set", {"key": key, "value": value, "table": table} ) async def kv_get( self, key: str, *, default: Any = None, table: str | None = None ) -> Any: await self._ensure_initialized() if self._is_server: return await self._server_kv_get(key, default=default, table=table) else: return await self._make_request( "kv_get", {"key": key, "default": default, "table": table} ) async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any: await self._ensure_initialized() if self._is_server: return await self._server_execute_raw(sql, params) else: return await self._make_request( "execute_raw", {"sql": sql, "params": list(params) if params else None} ) async def query_raw( self, sql: str, params: Optional[Tuple] = None ) -> List[Dict[str, Any]]: await self._ensure_initialized() if self._is_server: return await self._server_query_raw(sql, params) else: return await self._make_request( "query_raw", {"sql": sql, "params": list(params) if params else None} ) async def query_one( self, sql: str, params: Optional[Tuple] = None ) -> Optional[Dict[str, Any]]: await self._ensure_initialized() if self._is_server: return await self._server_query_one(sql, params) else: return await self._make_request( "query_one", {"sql": sql, "params": list(params) if params else None} ) async def create_table( self, table: str, schema: Dict[str, str], constraints: Optional[List[str]] = None, ): await self._ensure_initialized() if self._is_server: return await self._server_create_table(table, schema, constraints) else: return await self._make_request( "create_table", {"table": table, "schema": schema, "constraints": constraints}, ) async def insert_unique( self, table: str, args: Dict[str, Any], unique_fields: List[str] ) -> Union[str, None]: await self._ensure_initialized() if self._is_server: return await self._server_insert_unique(table, args, unique_fields) else: return await self._make_request( "insert_unique", {"table": table, "args": args, "unique_fields": unique_fields}, ) async def aggregate( self, table: str, function: str, column: str = "*", where: Optional[Dict[str, Any]] = None, ) -> Any: await self._ensure_initialized() if self._is_server: return await self._server_aggregate(table, function, column, where) else: return await self._make_request( "aggregate", { "table": table, "function": function, "column": column, "where": where, }, ) async def transaction(self): """Context manager for transactions.""" await self._ensure_initialized() if self._is_server: return TransactionContext(self._db_connection, self._db_semaphore) else: raise NotImplementedError("Transactions not supported in client mode") # Server implementation methods (original logic) @staticmethod def _utc_iso() -> str: return ( datetime.now(timezone.utc) .replace(microsecond=0) .isoformat() .replace("+00:00", "Z") ) @staticmethod def _py_to_sqlite_type(value: Any) -> str: if value is None: return "TEXT" if isinstance(value, bool): return "INTEGER" if isinstance(value, int): return "INTEGER" if isinstance(value, float): return "REAL" if isinstance(value, (bytes, bytearray, memoryview)): return "BLOB" return "TEXT" async def _get_table_columns(self, table: str) -> Set[str]: """Get actual columns that exist in the table.""" if table in self._table_columns_cache: return self._table_columns_cache[table] columns = set() try: async with self._db_semaphore: async with self._db_connection.execute( f"PRAGMA table_info({table})" ) as cursor: async for row in cursor: columns.add(row[1]) # Column name is at index 1 self._table_columns_cache[table] = columns except: pass return columns async def _invalidate_column_cache(self, table: str): """Invalidate column cache for a table.""" if table in self._table_columns_cache: del self._table_columns_cache[table] async def _ensure_column(self, table: str, name: str, value: Any) -> None: col_type = self._py_to_sqlite_type(value) try: async with self._db_semaphore: await self._db_connection.execute( f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}" ) await self._db_connection.commit() await self._invalidate_column_cache(table) except aiosqlite.OperationalError as e: if "duplicate column name" in str(e).lower(): pass # Column already exists else: raise async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: # Always include default columns cols = self._DEFAULT_COLUMNS.copy() # Add columns from col_sources for key, val in col_sources.items(): if key not in cols: cols[key] = self._py_to_sqlite_type(val) columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items()) async with self._db_semaphore: await self._db_connection.execute( f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" ) await self._db_connection.commit() await self._invalidate_column_cache(table) async def _table_exists(self, table: str) -> bool: """Check if a table exists.""" async with self._db_semaphore: async with self._db_connection.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,) ) as cursor: return await cursor.fetchone() is not None _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)") _RE_NO_TABLE = re.compile(r"no such table: (\w+)") @classmethod def _missing_column_from_error( cls, err: aiosqlite.OperationalError ) -> Optional[str]: m = cls._RE_NO_COLUMN.search(str(err)) return m.group(1) if m else None @classmethod def _missing_table_from_error( cls, err: aiosqlite.OperationalError ) -> Optional[str]: m = cls._RE_NO_TABLE.search(str(err)) return m.group(1) if m else None async def _safe_execute( self, table: str, sql: str, params: Iterable[Any], col_sources: Dict[str, Any], max_retries: int = 10, ) -> aiosqlite.Cursor: retries = 0 while retries < max_retries: try: async with self._db_semaphore: cursor = await self._db_connection.execute(sql, params) await self._db_connection.commit() return cursor except aiosqlite.OperationalError as err: retries += 1 err_str = str(err).lower() # Handle missing column col = self._missing_column_from_error(err) if col: if col in col_sources: await self._ensure_column(table, col, col_sources[col]) else: # Column not in sources, ensure it with NULL/TEXT type await self._ensure_column(table, col, None) continue # Handle missing table tbl = self._missing_table_from_error(err) if tbl: await self._ensure_table(tbl, col_sources) continue # Handle other column-related errors if "has no column named" in err_str: # Extract column name differently match = re.search(r"table \w+ has no column named (\w+)", err_str) if match: col_name = match.group(1) if col_name in col_sources: await self._ensure_column( table, col_name, col_sources[col_name] ) else: await self._ensure_column(table, col_name, None) continue raise raise Exception(f"Max retries ({max_retries}) exceeded") async def _filter_existing_columns( self, table: str, data: Dict[str, Any] ) -> Dict[str, Any]: """Filter data to only include columns that exist in the table.""" if not await self._table_exists(table): return data existing_columns = await self._get_table_columns(table) if not existing_columns: return data return {k: v for k, v in data.items() if k in existing_columns} async def _safe_query( self, table: str, sql: str, params: Iterable[Any], col_sources: Dict[str, Any], ) -> AsyncGenerator[Dict[str, Any], None]: # Check if table exists first if not await self._table_exists(table): return max_retries = 10 retries = 0 while retries < max_retries: try: async with self._db_semaphore: self._db_connection.row_factory = aiosqlite.Row async with self._db_connection.execute(sql, params) as cursor: # Fetch all rows while holding the semaphore rows = await cursor.fetchall() # Yield rows after releasing the semaphore for row in rows: yield dict(row) return except aiosqlite.OperationalError as err: retries += 1 err_str = str(err).lower() # Handle missing table tbl = self._missing_table_from_error(err) if tbl: # For queries, if table doesn't exist, just return empty return # Handle missing column in WHERE clause or SELECT if "no such column" in err_str: # For queries with missing columns, return empty return raise @staticmethod def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: if not where: return "", [] clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) return " WHERE " + " AND ".join(clauses), list(vals) async def _server_insert( self, table: str, args: Dict[str, Any], return_id: bool = False ) -> Union[str, int]: """Insert a record. If return_id=True, returns auto-incremented ID instead of UUID.""" uid = str(uuid4()) now = self._utc_iso() record = { "uid": uid, "created_at": now, "updated_at": now, "deleted_at": None, **args, } # Ensure table exists with all needed columns await self._ensure_table(table, record) # Handle auto-increment ID if requested if return_id and "id" not in args: # Ensure id column exists async with self._db_semaphore: # Add id column if it doesn't exist try: await self._db_connection.execute( f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT" ) await self._db_connection.commit() except aiosqlite.OperationalError as e: if "duplicate column name" not in str(e).lower(): # Try without autoincrement constraint try: await self._db_connection.execute( f"ALTER TABLE {table} ADD COLUMN id INTEGER" ) await self._db_connection.commit() except: pass await self._invalidate_column_cache(table) # Insert and get lastrowid cols = "`" + "`, `".join(record.keys()) + "`" qs = ", ".join(["?"] * len(record)) sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" cursor = await self._safe_execute(table, sql, list(record.values()), record) return cursor.lastrowid cols = "`" + "`, `".join(record) + "`" qs = ", ".join(["?"] * len(record)) sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" await self._safe_execute(table, sql, list(record.values()), record) return uid async def _server_update( self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None, ) -> int: if not args: return 0 # Check if table exists if not await self._table_exists(table): return 0 args["updated_at"] = self._utc_iso() # Ensure all columns exist all_cols = {**args, **(where or {})} await self._ensure_table(table, all_cols) for col, val in all_cols.items(): await self._ensure_column(table, col, val) set_clause = ", ".join(f"`{k}` = ?" for k in args) where_clause, where_params = self._build_where(where) sql = f"UPDATE {table} SET {set_clause}{where_clause}" params = list(args.values()) + where_params cur = await self._safe_execute(table, sql, params, all_cols) return cur.rowcount async def _server_delete( self, table: str, where: Optional[Dict[str, Any]] = None ) -> int: # Check if table exists if not await self._table_exists(table): return 0 where_clause, where_params = self._build_where(where) sql = f"DELETE FROM {table}{where_clause}" cur = await self._safe_execute(table, sql, where_params, where or {}) return cur.rowcount async def _server_upsert( self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None, ) -> str | None: if not args: raise ValueError("Nothing to update. Empty dict given.") args["updated_at"] = self._utc_iso() affected = await self._server_update(table, args, where) if affected: rec = await self._server_get(table, where) return rec.get("uid") if rec else None merged = {**(where or {}), **args} return await self._server_insert(table, merged) async def _server_get( self, table: str, where: Optional[Dict[str, Any]] = None ) -> Optional[Dict[str, Any]]: where_clause, where_params = self._build_where(where) sql = f"SELECT * FROM {table}{where_clause} LIMIT 1" async for row in self._safe_query(table, sql, where_params, where or {}): return row return None async def _server_find( self, table: str, where: Optional[Dict[str, Any]] = None, *, limit: int = 0, offset: int = 0, order_by: Optional[str] = None, ) -> List[Dict[str, Any]]: """Find records with optional ordering.""" where_clause, where_params = self._build_where(where) order_clause = f" ORDER BY {order_by}" if order_by else "" extra = (f" LIMIT {limit}" if limit else "") + ( f" OFFSET {offset}" if offset else "" ) sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}" return [ row async for row in self._safe_query(table, sql, where_params, where or {}) ] async def _server_count( self, table: str, where: Optional[Dict[str, Any]] = None ) -> int: # Check if table exists if not await self._table_exists(table): return 0 where_clause, where_params = self._build_where(where) sql = f"SELECT COUNT(*) FROM {table}{where_clause}" gen = self._safe_query(table, sql, where_params, where or {}) async for row in gen: return next(iter(row.values()), 0) return 0 async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool: return (await self._server_count(table, where)) > 0 async def _server_kv_set( self, key: str, value: Any, *, table: str | None = None, ) -> None: tbl = table or self._KV_TABLE json_val = json.dumps(value, default=str) await self._server_upsert(tbl, {"value": json_val}, {"key": key}) async def _server_kv_get( self, key: str, *, default: Any = None, table: str | None = None, ) -> Any: tbl = table or self._KV_TABLE row = await self._server_get(tbl, {"key": key}) if not row: return default try: return json.loads(row["value"]) except Exception: return default async def _server_execute_raw( self, sql: str, params: Optional[Tuple] = None ) -> Any: """Execute raw SQL for complex queries like JOINs.""" async with self._db_semaphore: cursor = await self._db_connection.execute(sql, params or ()) await self._db_connection.commit() return cursor async def _server_query_raw( self, sql: str, params: Optional[Tuple] = None ) -> List[Dict[str, Any]]: """Execute raw SQL query and return results as list of dicts.""" try: async with self._db_semaphore: self._db_connection.row_factory = aiosqlite.Row async with self._db_connection.execute(sql, params or ()) as cursor: return [dict(row) async for row in cursor] except aiosqlite.OperationalError: # Return empty list if query fails return [] async def _server_query_one( self, sql: str, params: Optional[Tuple] = None ) -> Optional[Dict[str, Any]]: """Execute raw SQL query and return single result.""" results = await self._server_query_raw(sql + " LIMIT 1", params) return results[0] if results else None async def _server_create_table( self, table: str, schema: Dict[str, str], constraints: Optional[List[str]] = None, ): """Create table with custom schema and constraints. Always includes default columns.""" # Merge default columns with custom schema full_schema = self._DEFAULT_COLUMNS.copy() full_schema.update(schema) columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()] if constraints: columns.extend(constraints) columns_sql = ", ".join(columns) async with self._db_semaphore: await self._db_connection.execute( f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" ) await self._db_connection.commit() await self._invalidate_column_cache(table) async def _server_insert_unique( self, table: str, args: Dict[str, Any], unique_fields: List[str] ) -> Union[str, None]: """Insert with unique constraint handling. Returns uid on success, None if duplicate.""" try: return await self._server_insert(table, args) except aiosqlite.IntegrityError as e: if "UNIQUE" in str(e): return None raise async def _server_aggregate( self, table: str, function: str, column: str = "*", where: Optional[Dict[str, Any]] = None, ) -> Any: """Perform aggregate functions like SUM, AVG, MAX, MIN.""" # Check if table exists if not await self._table_exists(table): return None where_clause, where_params = self._build_where(where) sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}" result = await self._server_query_one(sql, tuple(where_params)) return result["result"] if result else None async def close(self): """Close the connection or server.""" if not self._initialized: return if self._client_session: await self._client_session.close() if self._runner: await self._runner.cleanup() if os.path.exists(self._socket_path) and self._is_server: os.unlink(self._socket_path) if self._db_connection: await self._db_connection.close() class TransactionContext: """Context manager for database transactions.""" def __init__( self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None ): self.db_connection = db_connection self.semaphore = semaphore async def __aenter__(self): if self.semaphore: await self.semaphore.acquire() try: await self.db_connection.execute("BEGIN") return self.db_connection except: if self.semaphore: self.semaphore.release() raise async def __aexit__(self, exc_type, exc_val, exc_tb): try: if exc_type is None: await self.db_connection.commit() else: await self.db_connection.rollback() finally: if self.semaphore: self.semaphore.release() # Test cases remain the same but with additional tests for new functionality class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.db_path = Path("temp_test.db") if self.db_path.exists(): self.db_path.unlink() self.connector = AsyncDataSet(str(self.db_path), max_concurrent_queries=1) async def asyncTearDown(self): await self.connector.close() if self.db_path.exists(): self.db_path.unlink() async def test_insert_and_get(self): await self.connector.insert("people", {"name": "John Doe", "age": 30}) rec = await self.connector.get("people", {"name": "John Doe"}) self.assertIsNotNone(rec) self.assertEqual(rec["name"], "John Doe") async def test_get_nonexistent(self): result = await self.connector.get("people", {"name": "Jane Doe"}) self.assertIsNone(result) async def test_update(self): await self.connector.insert("people", {"name": "John Doe", "age": 30}) await self.connector.update("people", {"age": 31}, {"name": "John Doe"}) rec = await self.connector.get("people", {"name": "John Doe"}) self.assertEqual(rec["age"], 31) async def test_order_by(self): await self.connector.insert("people", {"name": "Alice", "age": 25}) await self.connector.insert("people", {"name": "Bob", "age": 30}) await self.connector.insert("people", {"name": "Charlie", "age": 20}) results = await self.connector.find("people", order_by="age ASC") self.assertEqual(results[0]["name"], "Charlie") self.assertEqual(results[-1]["name"], "Bob") async def test_raw_query(self): await self.connector.insert("people", {"name": "John", "age": 30}) await self.connector.insert("people", {"name": "Jane", "age": 25}) results = await self.connector.query_raw( "SELECT * FROM people WHERE age > ?", (26,) ) self.assertEqual(len(results), 1) self.assertEqual(results[0]["name"], "John") async def test_aggregate(self): await self.connector.insert("people", {"name": "John", "age": 30}) await self.connector.insert("people", {"name": "Jane", "age": 25}) await self.connector.insert("people", {"name": "Bob", "age": 35}) avg_age = await self.connector.aggregate("people", "AVG", "age") self.assertEqual(avg_age, 30) max_age = await self.connector.aggregate("people", "MAX", "age") self.assertEqual(max_age, 35) async def test_insert_with_auto_id(self): # Test auto-increment ID functionality id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True) id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True) self.assertEqual(id2, id1 + 1) async def test_transaction(self): await self.connector._ensure_initialized() if self.connector._is_server: async with self.connector.transaction() as conn: await conn.execute( "INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", ("test-uid", "John", 30, "2024-01-01", "2024-01-01"), ) # Transaction will be committed rec = await self.connector.get("people", {"name": "John"}) self.assertIsNotNone(rec) async def test_create_custom_table(self): schema = { "id": "INTEGER PRIMARY KEY AUTOINCREMENT", "username": "TEXT NOT NULL", "email": "TEXT NOT NULL", "score": "INTEGER DEFAULT 0", } constraints = ["UNIQUE(username)", "UNIQUE(email)"] await self.connector.create_table("users", schema, constraints) # Test that table was created with constraints result = await self.connector.insert_unique( "users", {"username": "john", "email": "john@example.com"}, ["username", "email"], ) self.assertIsNotNone(result) # Test duplicate insert result = await self.connector.insert_unique( "users", {"username": "john", "email": "different@example.com"}, ["username", "email"], ) self.assertIsNone(result) async def test_missing_table_operations(self): # Test operations on non-existent tables self.assertEqual(await self.connector.count("nonexistent"), 0) self.assertEqual(await self.connector.find("nonexistent"), []) self.assertIsNone(await self.connector.get("nonexistent")) self.assertFalse(await self.connector.exists("nonexistent", {"id": 1})) self.assertEqual(await self.connector.delete("nonexistent"), 0) self.assertEqual( await self.connector.update("nonexistent", {"name": "test"}), 0 ) async def test_auto_column_creation(self): # Insert with new columns that don't exist yet await self.connector.insert( "dynamic", {"col1": "value1", "col2": 42, "col3": 3.14} ) # Add more columns in next insert await self.connector.insert( "dynamic", {"col1": "value2", "col4": True, "col5": None} ) # All records should be retrievable records = await self.connector.find("dynamic") self.assertEqual(len(records), 2) if __name__ == "__main__": unittest.main()