import asyncio
import atexit
import contextlib
import json
import os
import re
import socket
import unittest
from collections.abc import AsyncGenerator, Iterable
from datetime import datetime, timezone
from pathlib import Path
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)
from uuid import uuid4

import aiohttp
import aiosqlite
from aiohttp import web


class AsyncDataSet:
    """
    Distributed AsyncDataSet with client-server model over Unix sockets.

    The first instance that finds no live server on ``socket_path`` becomes
    the server (it owns the SQLite connection and serves an aiohttp app on
    the socket); every other instance becomes a client that forwards each
    call as a JSON POST to that server.

    Parameters:
    - file: Path to the SQLite database file
    - socket_path: Path to Unix socket for inter-process communication (default: "ads.sock")
    - max_concurrent_queries: Maximum concurrent database operations (default: 1)
      Set to 1 to prevent "database is locked" errors with SQLite

    NOTE: table names, ``order_by`` and aggregate ``function``/``column``
    values are interpolated into SQL text unescaped; only pass trusted
    identifiers for those arguments.
    """

    _KV_TABLE = "__kv_store"
    _DEFAULT_COLUMNS = {
        "uid": "TEXT PRIMARY KEY",
        "created_at": "TEXT",
        "updated_at": "TEXT",
        "deleted_at": "TEXT",
    }
    # One POST route "/<name>" is registered per entry, served by "_handle_<name>".
    _ENDPOINTS = (
        "insert",
        "update",
        "delete",
        "upsert",
        "get",
        "find",
        "count",
        "exists",
        "kv_set",
        "kv_get",
        "execute_raw",
        "query_raw",
        "query_one",
        "create_table",
        "insert_unique",
        "aggregate",
    )

    def __init__(
        self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1
    ):
        self._file = file
        self._socket_path = socket_path
        self._table_columns_cache: Dict[str, Set[str]] = {}
        self._is_server = None  # None means not initialized yet
        self._server = None
        self._runner = None
        self._client_session = None
        self._initialized = False
        self._init_lock = None  # Created lazily inside a running event loop
        self._max_concurrent_queries = max_concurrent_queries
        self._db_semaphore = None  # Created lazily (server only)
        self._db_connection = None  # Persistent database connection
        self._atexit_cleanup = None  # atexit hook, so close() can unregister it

    async def _ensure_initialized(self):
        """Ensure the instance is initialized as server or client (once)."""
        if self._initialized:
            return

        # Create lock if needed (first time in async context)
        if self._init_lock is None:
            self._init_lock = asyncio.Lock()

        async with self._init_lock:
            # Re-check: another task may have finished while we waited.
            if self._initialized:
                return

            await self._initialize()

            # Semaphore bounding concurrent database operations (server only)
            if self._is_server:
                self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries)

            self._initialized = True

    def _socket_is_live(self) -> bool:
        """Return True if something accepts connections on the socket path.

        The probe socket is always closed, including when connect() fails.
        Errors other than "refused"/"not found" propagate to the caller.
        """
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as probe:
            try:
                probe.connect(self._socket_path)
            except (ConnectionRefusedError, FileNotFoundError):
                return False
        return True

    async def _initialize(self):
        """Decide the role (server or client) and open the SQLite connection."""
        try:
            if os.path.exists(self._socket_path) and self._socket_is_live():
                # Socket is active, we're a client
                self._is_server = False
                await self._setup_client()
            else:
                # No socket, or a stale socket file: remove it and serve.
                # The file may vanish between the probe and the unlink.
                with contextlib.suppress(FileNotFoundError):
                    os.unlink(self._socket_path)
                self._is_server = True
                await self._setup_server()
        except Exception:
            # Fallback to client mode
            self._is_server = False
            await self._setup_client()

        # Establish persistent database connection
        self._db_connection = await aiosqlite.connect(self._file)

    async def _setup_server(self):
        """Start the aiohttp application on the Unix socket."""
        app = web.Application()
        for name in self._ENDPOINTS:
            app.router.add_post(f"/{name}", getattr(self, f"_handle_{name}"))

        self._runner = web.AppRunner(app)
        await self._runner.setup()
        self._server = web.UnixSite(self._runner, self._socket_path)
        await self._server.start()

        # Remove the socket file at interpreter exit. Only the path is
        # captured, so the hook does not keep this instance alive.
        socket_path = self._socket_path

        def cleanup():
            with contextlib.suppress(FileNotFoundError):
                os.unlink(socket_path)

        self._atexit_cleanup = cleanup
        atexit.register(cleanup)

    async def _setup_client(self):
        """Create the aiohttp client session bound to the Unix socket."""
        connector = aiohttp.UnixConnector(path=self._socket_path)
        self._client_session = aiohttp.ClientSession(connector=connector)

    async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any:
        """POST ``data`` to the server endpoint and return its "result".

        Raises:
            Exception: with the server's message when the reply has "error".
        """
        if not self._client_session:
            await self._setup_client()

        url = f"http://localhost/{endpoint}"
        async with self._client_session.post(url, json=data) as resp:
            result = await resp.json()
            if result.get("error"):
                raise Exception(result["error"])
            return result.get("result")

    # Server handlers

    @staticmethod
    async def _respond(request, operation):
        """Run ``operation(payload)`` and wrap the outcome in a JSON response.

        ``operation`` receives the decoded JSON body and returns an awaitable.
        Any exception (including a missing key in the payload or a result
        that is not JSON-serializable) becomes ``{"error": ...}`` with 500.
        """
        data = await request.json()
        try:
            result = await operation(data)
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    @staticmethod
    def _request_params(data: Dict[str, Any]) -> Optional[Tuple]:
        """Return the payload's "params" as a tuple, or None when absent/empty."""
        params = data.get("params")
        return tuple(params) if params else None

    async def _handle_insert(self, request):
        return await self._respond(
            request,
            lambda d: self._server_insert(
                d["table"], d["args"], d.get("return_id", False)
            ),
        )

    async def _handle_update(self, request):
        return await self._respond(
            request,
            lambda d: self._server_update(d["table"], d["args"], d.get("where")),
        )

    async def _handle_delete(self, request):
        return await self._respond(
            request, lambda d: self._server_delete(d["table"], d.get("where"))
        )

    async def _handle_upsert(self, request):
        return await self._respond(
            request,
            lambda d: self._server_upsert(d["table"], d["args"], d.get("where")),
        )

    async def _handle_get(self, request):
        return await self._respond(
            request, lambda d: self._server_get(d["table"], d.get("where"))
        )

    async def _handle_find(self, request):
        return await self._respond(
            request,
            lambda d: self._server_find(
                d["table"],
                d.get("where"),
                limit=d.get("limit", 0),
                offset=d.get("offset", 0),
                order_by=d.get("order_by"),
            ),
        )

    async def _handle_count(self, request):
        return await self._respond(
            request, lambda d: self._server_count(d["table"], d.get("where"))
        )

    async def _handle_exists(self, request):
        return await self._respond(
            request, lambda d: self._server_exists(d["table"], d["where"])
        )

    async def _handle_kv_set(self, request):
        return await self._respond(
            request,
            lambda d: self._server_kv_set(d["key"], d["value"], table=d.get("table")),
        )

    async def _handle_kv_get(self, request):
        return await self._respond(
            request,
            lambda d: self._server_kv_get(
                d["key"], default=d.get("default"), table=d.get("table")
            ),
        )

    async def _handle_execute_raw(self, request):
        async def run(d):
            cursor = await self._server_execute_raw(d["sql"], self._request_params(d))
            # A cursor cannot be sent as JSON; report the affected row count.
            return getattr(cursor, "rowcount", None)

        return await self._respond(request, run)

    async def _handle_query_raw(self, request):
        return await self._respond(
            request,
            lambda d: self._server_query_raw(d["sql"], self._request_params(d)),
        )

    async def _handle_query_one(self, request):
        return await self._respond(
            request,
            lambda d: self._server_query_one(d["sql"], self._request_params(d)),
        )

    async def _handle_create_table(self, request):
        return await self._respond(
            request,
            lambda d: self._server_create_table(
                d["table"], d["schema"], d.get("constraints")
            ),
        )

    async def _handle_insert_unique(self, request):
        return await self._respond(
            request,
            lambda d: self._server_insert_unique(
                d["table"], d["args"], d["unique_fields"]
            ),
        )

    async def _handle_aggregate(self, request):
        return await self._respond(
            request,
            lambda d: self._server_aggregate(
                d["table"], d["function"], d.get("column", "*"), d.get("where")
            ),
        )

    # Public methods that delegate to server or client

    async def insert(
        self, table: str, args: Dict[str, Any], return_id: bool = False
    ) -> Union[str, int]:
        """Insert a record; returns its uid, or the row id if ``return_id``."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_insert(table, args, return_id)
        return await self._make_request(
            "insert", {"table": table, "args": args, "return_id": return_id}
        )

    async def update(
        self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
    ) -> int:
        """Update matching rows; returns the number of rows affected."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_update(table, args, where)
        return await self._make_request(
            "update", {"table": table, "args": args, "where": where}
        )

    async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Delete matching rows; returns the number of rows affected."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_delete(table, where)
        return await self._make_request("delete", {"table": table, "where": where})

    async def upsert(
        self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
    ) -> str | None:
        """Update matching rows, or insert ``where`` merged with ``args``."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_upsert(table, args, where)
        return await self._make_request(
            "upsert", {"table": table, "args": args, "where": where}
        )

    async def get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the first matching row as a dict, or None."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_get(table, where)
        return await self._make_request("get", {"table": table, "where": where})

    async def find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
        order_by: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return all matching rows; ``limit``/``offset`` of 0 mean "unset"."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_find(
                table, where, limit=limit, offset=offset, order_by=order_by
            )
        return await self._make_request(
            "find",
            {
                "table": table,
                "where": where,
                "limit": limit,
                "offset": offset,
                "order_by": order_by,
            },
        )

    async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of matching rows (0 if the table is missing)."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_count(table, where)
        return await self._make_request("count", {"table": table, "where": where})

    async def exists(self, table: str, where: Dict[str, Any]) -> bool:
        """Return True if at least one row matches ``where``."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_exists(table, where)
        return await self._make_request("exists", {"table": table, "where": where})

    async def kv_set(self, key: str, value: Any, *, table: str | None = None) -> None:
        """Store ``value`` (JSON-encoded) under ``key`` in the key/value table."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_kv_set(key, value, table=table)
        return await self._make_request(
            "kv_set", {"key": key, "value": value, "table": table}
        )

    async def kv_get(
        self, key: str, *, default: Any = None, table: str | None = None
    ) -> Any:
        """Return the decoded value stored under ``key``, or ``default``."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_kv_get(key, default=default, table=table)
        return await self._make_request(
            "kv_get", {"key": key, "default": default, "table": table}
        )

    async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
        """Execute raw SQL and commit.

        Returns the cursor in server mode and the affected row count in
        client mode (the cursor cannot cross the socket).
        """
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_execute_raw(sql, params)
        return await self._make_request(
            "execute_raw", {"sql": sql, "params": list(params) if params else None}
        )

    async def query_raw(
        self, sql: str, params: Optional[Tuple] = None
    ) -> List[Dict[str, Any]]:
        """Run a raw SQL query and return all rows as dicts."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_query_raw(sql, params)
        return await self._make_request(
            "query_raw", {"sql": sql, "params": list(params) if params else None}
        )

    async def query_one(
        self, sql: str, params: Optional[Tuple] = None
    ) -> Optional[Dict[str, Any]]:
        """Run a raw SQL query and return the first row as a dict, or None."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_query_one(sql, params)
        return await self._make_request(
            "query_one", {"sql": sql, "params": list(params) if params else None}
        )

    async def create_table(
        self,
        table: str,
        schema: Dict[str, str],
        constraints: Optional[List[str]] = None,
    ):
        """Create ``table`` (if missing) with the default plus custom columns."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_create_table(table, schema, constraints)
        return await self._make_request(
            "create_table",
            {"table": table, "schema": schema, "constraints": constraints},
        )

    async def insert_unique(
        self, table: str, args: Dict[str, Any], unique_fields: List[str]
    ) -> Union[str, None]:
        """Insert a record; returns None when a UNIQUE constraint rejects it."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_insert_unique(table, args, unique_fields)
        return await self._make_request(
            "insert_unique",
            {"table": table, "args": args, "unique_fields": unique_fields},
        )

    async def aggregate(
        self,
        table: str,
        function: str,
        column: str = "*",
        where: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Return ``function(column)`` over matching rows (e.g. SUM, AVG)."""
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_aggregate(table, function, column, where)
        return await self._make_request(
            "aggregate",
            {
                "table": table,
                "function": function,
                "column": column,
                "where": where,
            },
        )

    def transaction(self):
        """Context manager for transactions.

        Supports both ``async with ds.transaction() as conn:`` and the older
        ``ctx = await ds.transaction()`` followed by ``async with ctx``.

        Raises (when awaited/entered):
            NotImplementedError: in client mode.
        """
        return _LazyTransaction(self)

    # Server implementation methods (original logic)

    @staticmethod
    def _utc_iso() -> str:
        """Return the current UTC time as ISO-8601 with seconds precision and "Z"."""
        return (
            datetime.now(timezone.utc)
            .replace(microsecond=0)
            .isoformat()
            .replace("+00:00", "Z")
        )

    @staticmethod
    def _py_to_sqlite_type(value: Any) -> str:
        """Map a Python value to the SQLite column type used to store it."""
        if value is None:
            return "TEXT"
        # bool is a subclass of int, so one check covers both.
        if isinstance(value, int):
            return "INTEGER"
        if isinstance(value, float):
            return "REAL"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return "BLOB"
        return "TEXT"

    async def _get_table_columns(self, table: str) -> Set[str]:
        """Get actual columns that exist in the table (cached per table)."""
        if table in self._table_columns_cache:
            return self._table_columns_cache[table]

        columns = set()
        try:
            async with self._db_semaphore:
                async with self._db_connection.execute(
                    f"PRAGMA table_info({table})"
                ) as cursor:
                    async for row in cursor:
                        columns.add(row[1])  # Column name is at index 1
                    self._table_columns_cache[table] = columns
        except aiosqlite.Error:
            # Best effort: on a database error report what was read so far
            # and leave the cache untouched.
            pass
        return columns

    async def _invalidate_column_cache(self, table: str):
        """Invalidate column cache for a table."""
        self._table_columns_cache.pop(table, None)

    async def _ensure_column(self, table: str, name: str, value: Any) -> None:
        """Add column ``name`` typed after ``value``; no-op if it exists."""
        col_type = self._py_to_sqlite_type(value)
        try:
            async with self._db_semaphore:
                await self._db_connection.execute(
                    f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
                )
                await self._db_connection.commit()
            await self._invalidate_column_cache(table)
        except aiosqlite.OperationalError as e:
            if "duplicate column name" not in str(e).lower():
                raise
            # Column already exists

    async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
        """Create ``table`` if missing, with default columns plus ``col_sources`` keys."""
        # Always include default columns
        cols = self._DEFAULT_COLUMNS.copy()

        # Add columns from col_sources
        for key, val in col_sources.items():
            if key not in cols:
                cols[key] = self._py_to_sqlite_type(val)

        columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
        async with self._db_semaphore:
            await self._db_connection.execute(
                f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
            )
            await self._db_connection.commit()
        await self._invalidate_column_cache(table)

    async def _table_exists(self, table: str) -> bool:
        """Check if a table exists."""
        async with self._db_semaphore:
            async with self._db_connection.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
            ) as cursor:
                return await cursor.fetchone() is not None

    _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
    _RE_NO_TABLE = re.compile(r"no such table: (\w+)")

    @classmethod
    def _missing_column_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the column name from a "missing column" error, else None."""
        m = cls._RE_NO_COLUMN.search(str(err))
        return m.group(1) if m else None

    @classmethod
    def _missing_table_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the table name from a "no such table" error, else None."""
        m = cls._RE_NO_TABLE.search(str(err))
        return m.group(1) if m else None

    async def _safe_execute(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
        max_retries: int = 10,
    ) -> aiosqlite.Cursor:
        """Execute and commit ``sql``, creating missing tables/columns on demand.

        When SQLite reports a missing column or table, it is created (typed
        from ``col_sources`` when the name is known there) and the statement
        is retried, up to ``max_retries`` attempts.

        Raises:
            aiosqlite.OperationalError: for errors that are not repairable.
            Exception: when ``max_retries`` attempts were not enough.
        """
        for _ in range(max_retries):
            try:
                async with self._db_semaphore:
                    cursor = await self._db_connection.execute(sql, params)
                    await self._db_connection.commit()
                    return cursor
            except aiosqlite.OperationalError as err:
                # Handle missing column (unknown names get the NULL/TEXT type)
                col = self._missing_column_from_error(err)
                if col:
                    await self._ensure_column(table, col, col_sources.get(col))
                    continue

                # Handle missing table
                tbl = self._missing_table_from_error(err)
                if tbl:
                    await self._ensure_table(tbl, col_sources)
                    continue

                raise
        raise Exception(f"Max retries ({max_retries}) exceeded")

    async def _filter_existing_columns(
        self, table: str, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Filter data to only include columns that exist in the table."""
        if not await self._table_exists(table):
            return data

        existing_columns = await self._get_table_columns(table)
        if not existing_columns:
            return data

        return {k: v for k, v in data.items() if k in existing_columns}

    async def _safe_query(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Yield result rows as dicts; yields nothing for a missing table/column.

        ``col_sources`` is accepted for signature parity with
        ``_safe_execute`` and is not used: queries never alter the schema.
        """
        # Check if table exists first
        if not await self._table_exists(table):
            return

        try:
            async with self._db_semaphore:
                self._db_connection.row_factory = aiosqlite.Row
                async with self._db_connection.execute(sql, params) as cursor:
                    # Fetch all rows while holding the semaphore
                    rows = await cursor.fetchall()
        except aiosqlite.OperationalError as err:
            # For queries, a missing table or column just means "no rows".
            if (
                self._missing_table_from_error(err)
                or "no such column" in str(err).lower()
            ):
                return
            raise

        # Yield rows after releasing the semaphore
        for row in rows:
            yield dict(row)

    @staticmethod
    def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
        """Build an AND-ed WHERE clause and its parameter list.

        A ``None`` value is rendered as ``IS NULL`` because ``col = NULL``
        never matches in SQL; such entries add no parameter.
        """
        if not where:
            return "", []

        clauses: List[str] = []
        vals: List[Any] = []
        for col, val in where.items():
            if val is None:
                clauses.append(f"`{col}` IS NULL")
            else:
                clauses.append(f"`{col}` = ?")
                vals.append(val)
        return " WHERE " + " AND ".join(clauses), vals

    async def _server_insert(
        self, table: str, args: Dict[str, Any], return_id: bool = False
    ) -> Union[str, int]:
        """Insert a record. If return_id=True, returns auto-incremented ID instead of UUID.

        The returned ID is the cursor's ``lastrowid``. When ``args`` already
        carries an ``id``, the uid is returned even with ``return_id=True``.
        """
        uid = str(uuid4())
        now = self._utc_iso()
        record = {
            "uid": uid,
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **args,
        }

        # Ensure table exists with all needed columns
        await self._ensure_table(table, record)

        want_rowid = return_id and "id" not in args
        if want_rowid:
            # SQLite cannot add a PRIMARY KEY column via ALTER TABLE, so a
            # plain INTEGER column is added. This stays best effort: a real
            # schema problem surfaces from the INSERT below.
            try:
                await self._ensure_column(table, "id", 0)
            except aiosqlite.OperationalError:
                pass

        cols = ", ".join(f"`{name}`" for name in record)
        qs = ", ".join(["?"] * len(record))
        sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
        cursor = await self._safe_execute(table, sql, list(record.values()), record)
        return cursor.lastrowid if want_rowid else uid

    async def _server_update(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Update matching rows and stamp ``updated_at``; returns rows affected.

        Returns 0 for empty ``args`` or a missing table. ``args`` is not
        modified.
        """
        if not args:
            return 0

        # Check if table exists
        if not await self._table_exists(table):
            return 0

        # Work on a copy so the caller's dict is left untouched.
        values = {**args, "updated_at": self._utc_iso()}

        # Ensure all columns exist
        all_cols = {**values, **(where or {})}
        await self._ensure_table(table, all_cols)
        for col, val in all_cols.items():
            await self._ensure_column(table, col, val)

        set_clause = ", ".join(f"`{k}` = ?" for k in values)
        where_clause, where_params = self._build_where(where)

        sql = f"UPDATE {table} SET {set_clause}{where_clause}"
        params = list(values.values()) + where_params
        cur = await self._safe_execute(table, sql, params, all_cols)
        return cur.rowcount

    async def _server_delete(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> int:
        """Delete matching rows; returns rows affected (0 if table missing)."""
        # Check if table exists
        if not await self._table_exists(table):
            return 0

        where_clause, where_params = self._build_where(where)
        sql = f"DELETE FROM {table}{where_clause}"
        cur = await self._safe_execute(table, sql, where_params, where or {})
        return cur.rowcount

    async def _server_upsert(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> str | None:
        """Update rows matching ``where``; insert ``where`` + ``args`` if none did.

        Returns the uid of the first matching row after an update, or the
        uid of the inserted row. ``args`` is not modified.

        Raises:
            ValueError: if ``args`` is empty.
        """
        if not args:
            raise ValueError("Nothing to update. Empty dict given.")
        values = {**args, "updated_at": self._utc_iso()}
        affected = await self._server_update(table, values, where)
        if affected:
            rec = await self._server_get(table, where)
            return rec.get("uid") if rec else None
        merged = {**(where or {}), **values}
        return await self._server_insert(table, merged)

    async def _server_get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the first matching row as a dict, or None."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
        async for row in self._safe_query(table, sql, where_params, where or {}):
            return row
        return None

    async def _server_find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
        order_by: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Find records with optional ordering, limit and offset."""
        where_clause, where_params = self._build_where(where)
        order_clause = f" ORDER BY {order_by}" if order_by else ""

        # Coerce to int: these values are interpolated into the SQL text.
        limit = int(limit or 0)
        offset = int(offset or 0)
        extra = ""
        if limit or offset:
            # SQLite only accepts OFFSET after LIMIT; -1 means "no limit".
            extra = f" LIMIT {limit if limit else -1}"
            if offset:
                extra += f" OFFSET {offset}"

        sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}"
        return [
            row async for row in self._safe_query(table, sql, where_params, where or {})
        ]

    async def _server_count(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> int:
        """Return the number of matching rows (0 if the table is missing)."""
        # Check if table exists
        if not await self._table_exists(table):
            return 0

        where_clause, where_params = self._build_where(where)
        sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
        async for row in self._safe_query(table, sql, where_params, where or {}):
            return next(iter(row.values()), 0)
        return 0

    async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool:
        """Return True if at least one row matches ``where``."""
        return (await self._server_count(table, where)) > 0

    async def _server_kv_set(
        self,
        key: str,
        value: Any,
        *,
        table: str | None = None,
    ) -> None:
        """Upsert ``value`` as JSON under ``key`` (non-JSON types go through str())."""
        tbl = table or self._KV_TABLE
        json_val = json.dumps(value, default=str)
        await self._server_upsert(tbl, {"value": json_val}, {"key": key})

    async def _server_kv_get(
        self,
        key: str,
        *,
        default: Any = None,
        table: str | None = None,
    ) -> Any:
        """Return the decoded value for ``key``; ``default`` if missing or undecodable."""
        tbl = table or self._KV_TABLE
        row = await self._server_get(tbl, {"key": key})
        if not row:
            return default
        try:
            return json.loads(row["value"])
        except (KeyError, TypeError, ValueError):
            # No "value" column, a NULL value, or invalid JSON.
            return default

    async def _server_execute_raw(
        self, sql: str, params: Optional[Tuple] = None
    ) -> Any:
        """Execute raw SQL for complex queries like JOINs; commits and returns the cursor."""
        async with self._db_semaphore:
            cursor = await self._db_connection.execute(sql, params or ())
            await self._db_connection.commit()
            return cursor

    async def _server_query_raw(
        self, sql: str, params: Optional[Tuple] = None
    ) -> List[Dict[str, Any]]:
        """Execute raw SQL query and return results as list of dicts.

        Returns an empty list when SQLite raises an OperationalError.
        """
        try:
            async with self._db_semaphore:
                self._db_connection.row_factory = aiosqlite.Row
                async with self._db_connection.execute(sql, params or ()) as cursor:
                    return [dict(row) async for row in cursor]
        except aiosqlite.OperationalError:
            # Return empty list if query fails
            return []

    async def _server_query_one(
        self, sql: str, params: Optional[Tuple] = None
    ) -> Optional[Dict[str, Any]]:
        """Execute raw SQL query and return single result.

        Only the first row is fetched; the SQL text is left untouched, so
        statements that already end in ``LIMIT`` or ``;`` keep working.
        Returns None when there is no row or SQLite raises an
        OperationalError.
        """
        try:
            async with self._db_semaphore:
                self._db_connection.row_factory = aiosqlite.Row
                async with self._db_connection.execute(sql, params or ()) as cursor:
                    row = await cursor.fetchone()
        except aiosqlite.OperationalError:
            return None
        return dict(row) if row is not None else None

    async def _server_create_table(
        self,
        table: str,
        schema: Dict[str, str],
        constraints: Optional[List[str]] = None,
    ):
        """Create table with custom schema and constraints. Always includes default columns.

        SQLite allows a single PRIMARY KEY per table. When the custom schema
        or a constraint declares one, the default ``uid`` column is created
        as ``TEXT UNIQUE`` instead of ``TEXT PRIMARY KEY``.
        """
        # Merge default columns with custom schema
        full_schema = self._DEFAULT_COLUMNS.copy()
        full_schema.update(schema)

        custom_primary_key = any(
            "PRIMARY KEY" in str(dtype).upper()
            for col, dtype in schema.items()
            if col != "uid"
        ) or any("PRIMARY KEY" in str(c).upper() for c in (constraints or []))
        if custom_primary_key and "uid" not in schema:
            full_schema["uid"] = "TEXT UNIQUE"

        columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
        if constraints:
            columns.extend(constraints)

        columns_sql = ", ".join(columns)

        async with self._db_semaphore:
            await self._db_connection.execute(
                f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
            )
            await self._db_connection.commit()
        await self._invalidate_column_cache(table)

    async def _server_insert_unique(
        self, table: str, args: Dict[str, Any], unique_fields: List[str]
    ) -> Union[str, None]:
        """Insert with unique constraint handling. Returns uid on success, None if duplicate.

        ``unique_fields`` is not consulted: uniqueness is whatever the
        table's own constraints enforce.
        """
        try:
            return await self._server_insert(table, args)
        except aiosqlite.IntegrityError as e:
            if "UNIQUE" in str(e):
                return None
            raise

    async def _server_aggregate(
        self,
        table: str,
        function: str,
        column: str = "*",
        where: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Perform aggregate functions like SUM, AVG, MAX, MIN.

        Returns None when the table is missing or the query yields nothing.
        """
        # Check if table exists
        if not await self._table_exists(table):
            return None

        where_clause, where_params = self._build_where(where)
        # NOTE: function and column are interpolated into the SQL text as-is.
        sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"

        result = await self._server_query_one(sql, tuple(where_params))
        return result["result"] if result else None

    async def close(self):
        """Close the connection or server.

        Safe to call more than once. Afterwards the instance is back in its
        uninitialized state, so a later call initializes it again.
        """
        if not self._initialized:
            return

        if self._client_session:
            await self._client_session.close()
            self._client_session = None

        if self._runner:
            await self._runner.cleanup()
            self._runner = None
            self._server = None

        if self._is_server:
            with contextlib.suppress(FileNotFoundError):
                os.unlink(self._socket_path)

        # Drop the exit hook: once this server is closed, another process
        # may own the socket path and must not lose it when we exit.
        if self._atexit_cleanup is not None:
            atexit.unregister(self._atexit_cleanup)
            self._atexit_cleanup = None

        if self._db_connection:
            await self._db_connection.close()
            self._db_connection = None

        self._table_columns_cache.clear()
        self._db_semaphore = None
        self._init_lock = None
        self._is_server = None
        self._initialized = False


class TransactionContext:
    """Context manager for database transactions.

    Entering acquires the semaphore (if any) and issues ``BEGIN``; leaving
    commits on success, rolls back on error, and always releases the
    semaphore. The semaphore is held for the whole block, so other dataset
    calls that need it must not be awaited inside the block.
    """

    def __init__(
        self,
        db_connection: aiosqlite.Connection,
        semaphore: Optional[asyncio.Semaphore] = None,
    ):
        self.db_connection = db_connection
        self.semaphore = semaphore

    async def __aenter__(self):
        if self.semaphore:
            await self.semaphore.acquire()
        try:
            await self.db_connection.execute("BEGIN")
            return self.db_connection
        except BaseException:
            # Includes cancellation: never leave the semaphore held.
            if self.semaphore:
                self.semaphore.release()
            raise

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is None:
                await self.db_connection.commit()
            else:
                await self.db_connection.rollback()
        finally:
            if self.semaphore:
                self.semaphore.release()


class _LazyTransaction:
    """Handle returned by ``AsyncDataSet.transaction()``.

    Awaiting it yields a ``TransactionContext``; using it directly in
    ``async with`` initializes the dataset and enters that context.
    """

    def __init__(self, dataset: AsyncDataSet):
        self._dataset = dataset
        self._context: Optional[TransactionContext] = None

    async def _resolve(self) -> TransactionContext:
        dataset = self._dataset
        await dataset._ensure_initialized()
        if not dataset._is_server:
            raise NotImplementedError("Transactions not supported in client mode")
        return TransactionContext(dataset._db_connection, dataset._db_semaphore)

    def __await__(self):
        return self._resolve().__await__()

    async def __aenter__(self):
        self._context = await self._resolve()
        return await self._context.__aenter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return await self._context.__aexit__(exc_type, exc_val, exc_tb)


# Test cases remain the same but with additional tests for new functionality
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        self.db_path = Path("temp_test.db")
        if self.db_path.exists():
            self.db_path.unlink()
        self.connector = AsyncDataSet(str(self.db_path), max_concurrent_queries=1)

    async def asyncTearDown(self):
        await self.connector.close()
        if self.db_path.exists():
            self.db_path.unlink()

    async def test_insert_and_get(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["name"], "John Doe")

    async def test_get_nonexistent(self):
        result = await self.connector.get("people", {"name": "Jane Doe"})
        self.assertIsNone(result)

    async def test_update(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.update("people", {"age": 31}, {"name": "John Doe"})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertEqual(rec["age"], 31)

    async def test_order_by(self):
        await self.connector.insert("people", {"name": "Alice", "age": 25})
        await self.connector.insert("people", {"name": "Bob", "age": 30})
        await self.connector.insert("people", {"name": "Charlie", "age": 20})

        results = await self.connector.find("people", order_by="age ASC")
        self.assertEqual(results[0]["name"], "Charlie")
        self.assertEqual(results[-1]["name"], "Bob")

    async def test_raw_query(self):
        await self.connector.insert("people", {"name": "John", "age": 30})
        await self.connector.insert("people", {"name": "Jane", "age": 25})

        results = await self.connector.query_raw(
            "SELECT * FROM people WHERE age > ?", (26,)
        )
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["name"], "John")

    async def test_aggregate(self):
        await self.connector.insert("people", {"name": "John", "age": 30})
        await self.connector.insert("people", {"name": "Jane", "age": 25})
        await self.connector.insert("people", {"name": "Bob", "age": 35})

        avg_age = await self.connector.aggregate("people", "AVG", "age")
        self.assertEqual(avg_age, 30)

        max_age = await self.connector.aggregate("people", "MAX", "age")
        self.assertEqual(max_age, 35)

    async def test_insert_with_auto_id(self):
        # Test auto-increment ID functionality
        id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True)
        id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True)
        self.assertEqual(id2, id1 + 1)

    async def test_transaction(self):
        await self.connector._ensure_initialized()
        if self.connector._is_server:
            # The raw INSERT below needs the table (and its columns) to exist.
            await self.connector.insert("people", {"name": "Seed", "age": 1})

            async with self.connector.transaction() as conn:
                await conn.execute(
                    "INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
                    ("test-uid", "John", 30, "2024-01-01", "2024-01-01"),
                )
                # Transaction will be committed

            rec = await self.connector.get("people", {"name": "John"})
            self.assertIsNotNone(rec)

    async def test_create_custom_table(self):
        schema = {
            "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
            "username": "TEXT NOT NULL",
            "email": "TEXT NOT NULL",
            "score": "INTEGER DEFAULT 0",
        }
        constraints = ["UNIQUE(username)", "UNIQUE(email)"]

        await self.connector.create_table("users", schema, constraints)

        # Test that table was created with constraints
        result = await self.connector.insert_unique(
            "users",
            {"username": "john", "email": "john@example.com"},
            ["username", "email"],
        )
        self.assertIsNotNone(result)

        # Test duplicate insert
        result = await self.connector.insert_unique(
            "users",
            {"username": "john", "email": "different@example.com"},
            ["username", "email"],
        )
        self.assertIsNone(result)

    async def test_missing_table_operations(self):
        # Test operations on non-existent tables
        self.assertEqual(await self.connector.count("nonexistent"), 0)
        self.assertEqual(await self.connector.find("nonexistent"), [])
        self.assertIsNone(await self.connector.get("nonexistent"))
        self.assertFalse(await self.connector.exists("nonexistent", {"id": 1}))
        self.assertEqual(await self.connector.delete("nonexistent"), 0)
        self.assertEqual(
            await self.connector.update("nonexistent", {"name": "test"}), 0
        )

    async def test_auto_column_creation(self):
        # Insert with new columns that don't exist yet
        await self.connector.insert(
            "dynamic", {"col1": "value1", "col2": 42, "col3": 3.14}
        )

        # Add more columns in next insert
        await self.connector.insert(
            "dynamic", {"col1": "value2", "col4": True, "col5": None}
        )

        # All records should be retrievable
        records = await self.connector.find("dynamic")
        self.assertEqual(len(records), 2)


if __name__ == "__main__":
    unittest.main()