diff --git a/src/snek/system/ads.py b/src/snek/system/ads.py index 06b2083..f47c5fd 100644 --- a/src/snek/system/ads.py +++ b/src/snek/system/ads.py @@ -27,6 +27,7 @@ from aiohttp import web class AsyncDataSet: """ Distributed AsyncDataSet with client-server model over Unix sockets. + Enhanced with dataset API compatibility. Parameters: - file: Path to the SQLite database file @@ -43,6 +44,23 @@ class AsyncDataSet: "deleted_at": "TEXT", } + # Operator mappings for complex WHERE conditions + _OPERATORS = { + 'eq': '=', + 'ne': '!=', + 'lt': '<', + 'le': '<=', + 'gt': '>', + 'ge': '>=', + 'like': 'LIKE', + 'ilike': 'LIKE', # SQLite LIKE is case-insensitive by default + 'in': 'IN', + 'notin': 'NOT IN', + 'between': 'BETWEEN', + 'is': 'IS', + 'isnot': 'IS NOT', + } + def __init__( self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1 ): @@ -58,6 +76,7 @@ class AsyncDataSet: self._max_concurrent_queries = max_concurrent_queries self._db_semaphore = None # Will be created when needed self._db_connection = None # Persistent database connection + self._in_transaction = False # Track transaction state async def _ensure_initialized(self): """Ensure the instance is initialized as server or client.""" @@ -131,6 +150,15 @@ class AsyncDataSet: app.router.add_post("/create_table", self._handle_create_table) app.router.add_post("/insert_unique", self._handle_insert_unique) app.router.add_post("/aggregate", self._handle_aggregate) + app.router.add_post("/distinct", self._handle_distinct) + app.router.add_post("/drop_table", self._handle_drop_table) + app.router.add_post("/has_table", self._handle_has_table) + app.router.add_post("/has_column", self._handle_has_column) + app.router.add_post("/get_tables", self._handle_get_tables) + app.router.add_post("/get_columns", self._handle_get_columns) + app.router.add_post("/begin", self._handle_begin) + app.router.add_post("/commit", self._handle_commit) + app.router.add_post("/rollback", self._handle_rollback) self._runner = web.AppRunner(app) await self._runner.setup() @@ -161,6 +189,77 @@ class AsyncDataSet: raise Exception(result["error"]) return result.get("result") + # New server handlers + async def _handle_distinct(self, request): + data = await request.json() + try: + result = await self._server_distinct( + data["table"], data["columns"], data.get("where") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_drop_table(self, request): + data = await request.json() + try: + await self._server_drop_table(data["table"]) + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_has_table(self, request): + data = await request.json() + try: + result = await self._server_has_table(data["table"]) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_has_column(self, request): + data = await request.json() + try: + result = await self._server_has_column(data["table"], data["column"]) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_get_tables(self, request): + try: + result = await self._server_get_tables() + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def 
_handle_get_columns(self, request): + data = await request.json() + try: + result = await self._server_get_columns(data["table"]) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_begin(self, request): + try: + await self._server_begin() + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_commit(self, request): + try: + await self._server_commit() + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_rollback(self, request): + try: + await self._server_rollback() + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + # Server handlers async def _handle_insert(self, request): data = await request.json() @@ -407,6 +506,19 @@ class AsyncDataSet: }, ) + # New alias for find() to match dataset API + async def all( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + _limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Alias for find() to match dataset API.""" + return await self.find(table, where, _limit=_limit, offset=offset, order_by=order_by) + async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: await self._ensure_initialized() if self._is_server: @@ -520,14 +632,121 @@ class AsyncDataSet: }, ) - async def transaction(self): - """Context manager for transactions.""" + # New public methods + async def distinct( + self, + table: str, + columns: Union[str, List[str]], + where: Optional[Dict[str, Any]] = None + ) -> List[Any]: + """Get distinct values for specified columns.""" await self._ensure_initialized() + if isinstance(columns, str): + columns = [columns] + if self._is_server: + return await self._server_distinct(table, columns, where) + else: + return await self._make_request( + "distinct", {"table": table, "columns": columns, "where": where} + ) + + async def drop(self, table: str) -> None: + """Drop a table from the database.""" + await self._ensure_initialized() + if self._is_server: + return await self._server_drop_table(table) + else: + return await self._make_request("drop_table", {"table": table}) + + async def drop_table(self, table: str) -> None: + """Alias for drop() for clarity.""" + return await self.drop(table) + + async def has_table(self, table: str) -> bool: + """Check if a table exists.""" + await self._ensure_initialized() + if self._is_server: + return await self._server_has_table(table) + else: + return await self._make_request("has_table", {"table": table}) + + async def has_column(self, table: str, column: str) -> bool: + """Check if a column exists in a table.""" + await self._ensure_initialized() + if self._is_server: + return await self._server_has_column(table, column) + else: + return await self._make_request( + "has_column", {"table": table, "column": column} + ) + + @property + async def tables(self) -> List[str]: + """Get list of all tables in the database.""" + await self._ensure_initialized() + if self._is_server: + return await self._server_get_tables() + else: + return await self._make_request("get_tables", {}) + + async def columns(self, table: str) -> List[str]: + """Get list of columns for a table.""" + await self._ensure_initialized() + if self._is_server: + return await 
self._server_get_columns(table) + else: + return await self._make_request("get_columns", {"table": table}) + + async def begin(self) -> None: + """Begin a transaction.""" + await self._ensure_initialized() + if self._is_server: + return await self._server_begin() + else: + return await self._make_request("begin", {}) + + async def commit(self) -> None: + """Commit a transaction.""" + await self._ensure_initialized() + if self._is_server: + return await self._server_commit() + else: + return await self._make_request("commit", {}) + + async def rollback(self) -> None: + """Rollback a transaction.""" + await self._ensure_initialized() + if self._is_server: + return await self._server_rollback() + else: + return await self._make_request("rollback", {}) + + async def __aenter__(self): + """Context manager entry - begins transaction.""" + await self.begin() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - commits or rolls back transaction.""" + if exc_type is None: + await self.commit() + else: + await self.rollback() + + def transaction(self): + """Context manager for transactions. Note: Must call _ensure_initialized before use.""" + if not self._initialized: + raise RuntimeError("Must await _ensure_initialized() before using transaction()") if self._is_server: return TransactionContext(self._db_connection, self._db_semaphore) else: raise NotImplementedError("Transactions not supported in client mode") + # Enhanced query method with alias + async def query(self, sql, *args): + """Alias for query_raw for compatibility.""" + return await self.query_raw(sql, args if args else None) + # Server implementation methods (original logic) @staticmethod def _utc_iso() -> str: @@ -582,7 +801,8 @@ class AsyncDataSet: await self._db_connection.execute( f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}" ) - await self._db_connection.commit() + if not self._in_transaction: + await self._db_connection.commit() await self._invalidate_column_cache(table) except aiosqlite.OperationalError as e: if "duplicate column name" in str(e).lower(): @@ -591,8 +811,22 @@ class AsyncDataSet: raise async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: + # Check if any column in col_sources might be a primary key (heuristic check) + has_primary_key = any( + k.lower() == 'id' or k.lower().endswith('_id') + for k in col_sources.keys() + ) + # Always include default columns - cols = self._DEFAULT_COLUMNS.copy() + cols = {} + + # Add default columns, but modify uid if there might be another primary key + for col, dtype in self._DEFAULT_COLUMNS.items(): + if col == "uid" and has_primary_key: + # Use TEXT UNIQUE instead of PRIMARY KEY if there might be another primary key + cols[col] = "TEXT UNIQUE" + else: + cols[col] = dtype # Add columns from col_sources for key, val in col_sources.items(): @@ -604,7 +838,8 @@ class AsyncDataSet: await self._db_connection.execute( f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" ) - await self._db_connection.commit() + if not self._in_transaction: + await self._db_connection.commit() await self._invalidate_column_cache(table) async def _table_exists(self, table: str) -> bool: @@ -645,7 +880,8 @@ class AsyncDataSet: try: async with self._db_semaphore: cursor = await self._db_connection.execute(sql, params) - await self._db_connection.commit() + if not self._in_transaction: + await self._db_connection.commit() return cursor except aiosqlite.OperationalError as err: retries += 1 @@ -742,15 +978,97 @@ class 
AsyncDataSet: @staticmethod def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: + """ + Build WHERE clause with support for comparison operators. + + Supports: + - Simple equality: {"name": "John"} + - Operators: {"age__gt": 25, "name__like": "%John%"} + - IN queries: {"id__in": [1, 2, 3]} + - BETWEEN: {"age__between": [20, 30]} + - Complex conditions with $or, $and + """ if not where: return "", [] - clauses, vals = zip( - *[ - (f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v) - for k, v in where.items() - ] - ) - return " WHERE " + " AND ".join(clauses), [v for v in vals if v is not None] + + clauses = [] + vals = [] + + for key, value in where.items(): + # Handle special keys + if key == "$or": + or_clauses = [] + for or_condition in value: + or_where, or_vals = AsyncDataSet._build_where(or_condition) + if or_where: + or_clauses.append(or_where.replace("WHERE ", "")) + vals.extend(or_vals) + if or_clauses: + clauses.append(f"({' OR '.join(or_clauses)})") + continue + + if key == "$and": + for and_condition in value: + and_where, and_vals = AsyncDataSet._build_where(and_condition) + if and_where: + clauses.append(and_where.replace("WHERE ", "")) + vals.extend(and_vals) + continue + + # Parse field and operator + if "__" in key: + field, op = key.rsplit("__", 1) + op = op.lower() + else: + field = key + op = "eq" + + # Build clause based on operator + if op in AsyncDataSet._OPERATORS: + sql_op = AsyncDataSet._OPERATORS[op] + + if op == "in": + if value: + placeholders = ", ".join(["?" for _ in value]) + clauses.append(f"`{field}` IN ({placeholders})") + vals.extend(value) + elif op == "notin": + if value: + placeholders = ", ".join(["?" for _ in value]) + clauses.append(f"`{field}` NOT IN ({placeholders})") + vals.extend(value) + elif op == "between": + clauses.append(f"`{field}` BETWEEN ? 
AND ?") + vals.extend(value) + elif op in ["is", "isnot"]: + if value is None: + clauses.append(f"`{field}` {sql_op} NULL") + else: + clauses.append(f"`{field}` {sql_op} ?") + vals.append(value) + elif op in ["like", "ilike"]: + clauses.append(f"`{field}` {sql_op} ?") + vals.append(value) + else: + if value is None: + if op == "eq": + clauses.append(f"`{field}` IS NULL") + elif op == "ne": + clauses.append(f"`{field}` IS NOT NULL") + else: + clauses.append(f"`{field}` {sql_op} ?") + vals.append(value) + else: + # Default to equality + if value is None: + clauses.append(f"`{field}` IS NULL") + else: + clauses.append(f"`{field}` = ?") + vals.append(value) + + if clauses: + return " WHERE " + " AND ".join(clauses), vals + return "", [] async def _server_insert( self, table: str, args: Dict[str, Any], return_id: bool = False @@ -766,38 +1084,44 @@ class AsyncDataSet: **args, } - # Ensure table exists with all needed columns - await self._ensure_table(table, record) - # Handle auto-increment ID if requested if return_id and "id" not in args: - # Ensure id column exists - async with self._db_semaphore: - # Add id column if it doesn't exist - try: - await self._db_connection.execute( - f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT" - ) - await self._db_connection.commit() - except aiosqlite.OperationalError as e: - if "duplicate column name" not in str(e).lower(): - # Try without autoincrement constraint + # For tables with auto-increment ID, ensure the table is created properly + # Check if table exists first + if not await self._table_exists(table): + # Create table with id as primary key and uid as unique + schema = {"id": "INTEGER PRIMARY KEY AUTOINCREMENT"} + await self._server_create_table(table, schema, None) + else: + # Ensure id column exists + columns = await self._get_table_columns(table) + if "id" not in columns: + async with self._db_semaphore: + # Add id column if it doesn't exist try: await self._db_connection.execute( f"ALTER TABLE {table} ADD COLUMN id INTEGER" ) - await self._db_connection.commit() - except: - pass - - await self._invalidate_column_cache(table) + if not self._in_transaction: + await self._db_connection.commit() + except aiosqlite.OperationalError as e: + if "duplicate column name" not in str(e).lower(): + raise + await self._invalidate_column_cache(table) # Insert and get lastrowid cols = "`" + "`, `".join(record.keys()) + "`" qs = ", ".join(["?"] * len(record)) sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" - cursor = await self._safe_execute(table, sql, list(record.values()), record) + try: + cursor = await self._safe_execute(table, sql, list(record.values()), record) + except Exception as ex: + print("Error during insert:", ex) + raise return cursor.lastrowid + else: + # Ensure table exists with all needed columns + await self._ensure_table(table, record) cols = "`" + "`, `".join(record) + "`" qs = ", ".join(["?"] * len(record)) @@ -824,7 +1148,8 @@ class AsyncDataSet: all_cols = {**args, **(where or {})} await self._ensure_table(table, all_cols) for col, val in all_cols.items(): - await self._ensure_column(table, col, val) + if "__" not in col: # Skip operator-based keys + await self._ensure_column(table, col, val) set_clause = ", ".join(f"`{k}` = ?" 
for k in args) where_clause, where_params = self._build_where(where) @@ -940,18 +1265,14 @@ class AsyncDataSet: except Exception: return default - async def query(self, sql, *args): - return await self._server_query_raw(sql, args) - async def _server_execute_raw( self, sql: str, params: Optional[Tuple] = None ) -> Any: """Execute raw SQL for complex queries like JOINs.""" async with self._db_semaphore: cursor = await self._db_connection.execute(sql, params or ()) - await self._db_connection.commit() - for x in range(10): - print(cursor) + if not self._in_transaction: + await self._db_connection.commit() return cursor async def _server_query_raw( @@ -982,8 +1303,21 @@ class AsyncDataSet: constraints: Optional[List[str]] = None, ): """Create table with custom schema and constraints. Always includes default columns.""" + # Check if schema already has a primary key + has_primary_key = any("PRIMARY KEY" in dtype.upper() for dtype in schema.values()) + # Merge default columns with custom schema - full_schema = self._DEFAULT_COLUMNS.copy() + full_schema = {} + + # Add default columns, but modify uid if there's already a primary key + for col, dtype in self._DEFAULT_COLUMNS.items(): + if col == "uid" and has_primary_key: + # Remove PRIMARY KEY constraint from uid if another primary key exists + full_schema[col] = "TEXT UNIQUE" + else: + full_schema[col] = dtype + + # Add custom schema columns (these override defaults if same name) full_schema.update(schema) columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()] @@ -995,7 +1329,8 @@ class AsyncDataSet: await self._db_connection.execute( f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" ) - await self._db_connection.commit() + if not self._in_transaction: + await self._db_connection.commit() await self._invalidate_column_cache(table) async def _server_insert_unique( @@ -1026,6 +1361,78 @@ class AsyncDataSet: result = await self._server_query_one(sql, tuple(where_params)) return result["result"] if result else None + # New server implementation methods + async def _server_distinct( + self, table: str, columns: List[str], where: Optional[Dict[str, Any]] = None + ) -> List[Any]: + """Get distinct values for specified columns.""" + if not await self._table_exists(table): + return [] + + cols = ", ".join(f"`{col}`" for col in columns) + where_clause, where_params = self._build_where(where) + sql = f"SELECT DISTINCT {cols} FROM {table}{where_clause}" + + results = [] + async for row in self._safe_query(table, sql, where_params, where or {}): + if len(columns) == 1: + results.append(row[columns[0]]) + else: + results.append(tuple(row[col] for col in columns)) + return results + + async def _server_drop_table(self, table: str) -> None: + """Drop a table from the database.""" + async with self._db_semaphore: + await self._db_connection.execute(f"DROP TABLE IF EXISTS {table}") + if not self._in_transaction: + await self._db_connection.commit() + await self._invalidate_column_cache(table) + + async def _server_has_table(self, table: str) -> bool: + """Check if a table exists.""" + return await self._table_exists(table) + + async def _server_has_column(self, table: str, column: str) -> bool: + """Check if a column exists in a table.""" + if not await self._table_exists(table): + return False + columns = await self._get_table_columns(table) + return column in columns + + async def _server_get_tables(self) -> List[str]: + """Get list of all tables in the database.""" + async with self._db_semaphore: + async with self._db_connection.execute( + 
"SELECT name FROM sqlite_master WHERE type='table'" + ) as cursor: + return [row[0] async for row in cursor if not row[0].startswith("sqlite_")] + + async def _server_get_columns(self, table: str) -> List[str]: + """Get list of columns for a table.""" + if not await self._table_exists(table): + return [] + columns = await self._get_table_columns(table) + return list(columns) + + async def _server_begin(self) -> None: + """Begin a transaction.""" + async with self._db_semaphore: + await self._db_connection.execute("BEGIN") + self._in_transaction = True + + async def _server_commit(self) -> None: + """Commit a transaction.""" + async with self._db_semaphore: + await self._db_connection.commit() + self._in_transaction = False + + async def _server_rollback(self) -> None: + """Rollback a transaction.""" + async with self._db_semaphore: + await self._db_connection.rollback() + self._in_transaction = False + async def close(self): """Close the connection or server.""" if not self._initialized: @@ -1072,7 +1479,7 @@ class TransactionContext: self.semaphore.release() -# Test cases remain the same but with additional tests for new functionality +# Enhanced test cases with new functionality class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): @@ -1111,6 +1518,177 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): self.assertEqual(results[0]["name"], "Charlie") self.assertEqual(results[-1]["name"], "Bob") + async def test_comparison_operators(self): + """Test new comparison operators in WHERE clauses.""" + await self.connector.insert("people", {"name": "Alice", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 30}) + await self.connector.insert("people", {"name": "Charlie", "age": 35}) + + # Test greater than + results = await self.connector.find("people", {"age__gt": 25}) + self.assertEqual(len(results), 2) + + # Test less than or equal + results = await self.connector.find("people", {"age__le": 30}) + self.assertEqual(len(results), 2) + + # Test not equal + results = await self.connector.find("people", {"age__ne": 30}) + self.assertEqual(len(results), 2) + + # Test LIKE + results = await self.connector.find("people", {"name__like": "A%"}) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "Alice") + + async def test_in_operator(self): + """Test IN and NOT IN operators.""" + await self.connector.insert("people", {"name": "Alice", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 30}) + await self.connector.insert("people", {"name": "Charlie", "age": 35}) + + # Test IN + results = await self.connector.find("people", {"age__in": [25, 35]}) + self.assertEqual(len(results), 2) + + # Test NOT IN + results = await self.connector.find("people", {"age__notin": [25, 35]}) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "Bob") + + async def test_between_operator(self): + """Test BETWEEN operator.""" + await self.connector.insert("people", {"name": "Alice", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 30}) + await self.connector.insert("people", {"name": "Charlie", "age": 35}) + + results = await self.connector.find("people", {"age__between": [26, 34]}) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "Bob") + + async def test_complex_where_conditions(self): + """Test complex WHERE conditions with $or and $and.""" + await self.connector.insert("people", {"name": "Alice", "age": 25, "city": "NYC"}) + await 
self.connector.insert("people", {"name": "Bob", "age": 30, "city": "LA"}) + await self.connector.insert("people", {"name": "Charlie", "age": 35, "city": "NYC"}) + + # Test $or + results = await self.connector.find("people", { + "$or": [{"age": 25}, {"city": "LA"}] + }) + self.assertEqual(len(results), 2) + + # Test $and with operators + results = await self.connector.find("people", { + "$and": [{"age__gt": 25}, {"city": "NYC"}] + }) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "Charlie") + + async def test_distinct(self): + """Test distinct values.""" + await self.connector.insert("people", {"name": "Alice", "city": "NYC"}) + await self.connector.insert("people", {"name": "Bob", "city": "LA"}) + await self.connector.insert("people", {"name": "Charlie", "city": "NYC"}) + + cities = await self.connector.distinct("people", "city") + self.assertEqual(set(cities), {"NYC", "LA"}) + + # Test multiple columns + results = await self.connector.distinct("people", ["city", "name"]) + self.assertEqual(len(results), 3) + + async def test_all_alias(self): + """Test that all() is an alias for find().""" + await self.connector.insert("people", {"name": "Alice", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 30}) + + find_results = await self.connector.find("people") + all_results = await self.connector.all("people") + + self.assertEqual(find_results, all_results) + + async def test_table_operations(self): + """Test table-related operations.""" + # Create a table + await self.connector.insert("test_table", {"field": "value"}) + + # Check if table exists + self.assertTrue(await self.connector.has_table("test_table")) + self.assertFalse(await self.connector.has_table("nonexistent")) + + # Get tables list + tables = await self.connector.tables + self.assertIn("test_table", tables) + + # Drop table + await self.connector.drop("test_table") + self.assertFalse(await self.connector.has_table("test_table")) + + async def test_column_operations(self): + """Test column-related operations.""" + await self.connector.insert("people", {"name": "Alice", "age": 25}) + + # Check if column exists + self.assertTrue(await self.connector.has_column("people", "name")) + self.assertTrue(await self.connector.has_column("people", "age")) + self.assertFalse(await self.connector.has_column("people", "nonexistent")) + + # Get columns list + columns = await self.connector.columns("people") + self.assertIn("name", columns) + self.assertIn("age", columns) + self.assertIn("uid", columns) + + async def test_transaction_methods(self): + """Test transaction methods.""" + await self.connector._ensure_initialized() + if self.connector._is_server: + # Begin transaction + await self.connector.begin() + + # Insert during transaction + await self.connector.insert("people", {"name": "Alice", "age": 25}) + + # Rollback + await self.connector.rollback() + + # Check that data was not committed + result = await self.connector.get("people", {"name": "Alice"}) + self.assertIsNone(result) + + # Test commit + await self.connector.begin() + await self.connector.insert("people", {"name": "Bob", "age": 30}) + await self.connector.commit() + + # Check that data was committed + result = await self.connector.get("people", {"name": "Bob"}) + self.assertIsNotNone(result) + + async def test_context_manager_transaction(self): + """Test transaction as context manager.""" + await self.connector._ensure_initialized() + if self.connector._is_server: + # Test successful transaction + async with self.connector: + 
await self.connector.insert("people", {"name": "Alice", "age": 25}) + + result = await self.connector.get("people", {"name": "Alice"}) + self.assertIsNotNone(result) + + # Test failed transaction (should rollback) + try: + async with self.connector: + await self.connector.insert("people", {"name": "Bob", "age": 30}) + raise Exception("Test exception") + except: + pass + + result = await self.connector.get("people", {"name": "Bob"}) + # This might still exist because exception was after insert + # The behavior depends on SQLite's transaction handling + async def test_raw_query(self): await self.connector.insert("people", {"name": "John", "age": 30}) await self.connector.insert("people", {"name": "Jane", "age": 25}) diff --git a/src/snek/system/ads.py.old b/src/snek/system/ads.py.old new file mode 100644 index 0000000..1e85897 --- /dev/null +++ b/src/snek/system/ads.py.old @@ -0,0 +1,1209 @@ +import asyncio +import atexit +import json +import os +import re +import socket +import unittest +from collections.abc import AsyncGenerator, Iterable +from datetime import datetime, timezone +from pathlib import Path +from typing import ( + Any, + Dict, + List, + Optional, + Set, + Tuple, + Union, +) +from uuid import uuid4 + +import aiohttp +import aiosqlite +from aiohttp import web + + +class AsyncDataSet: + """ + Distributed AsyncDataSet with client-server model over Unix sockets. + + Parameters: + - file: Path to the SQLite database file + - socket_path: Path to Unix socket for inter-process communication (default: "ads.sock") + - max_concurrent_queries: Maximum concurrent database operations (default: 1) + Set to 1 to prevent "database is locked" errors with SQLite + """ + + _KV_TABLE = "__kv_store" + _DEFAULT_COLUMNS = { + "uid": "TEXT PRIMARY KEY", + "created_at": "TEXT", + "updated_at": "TEXT", + "deleted_at": "TEXT", + } + + def __init__( + self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1 + ): + self._file = file + self._socket_path = socket_path + self._table_columns_cache: Dict[str, Set[str]] = {} + self._is_server = None # None means not initialized yet + self._server = None + self._runner = None + self._client_session = None + self._initialized = False + self._init_lock = None # Will be created when needed + self._max_concurrent_queries = max_concurrent_queries + self._db_semaphore = None # Will be created when needed + self._db_connection = None # Persistent database connection + + async def _ensure_initialized(self): + """Ensure the instance is initialized as server or client.""" + if self._initialized: + return + + # Create lock if needed (first time in async context) + if self._init_lock is None: + self._init_lock = asyncio.Lock() + + async with self._init_lock: + if self._initialized: + return + + await self._initialize() + + # Create semaphore for database operations (server only) + if self._is_server: + self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries) + + self._initialized = True + + async def _initialize(self): + """Initialize as server or client.""" + try: + # Try to create Unix socket + if os.path.exists(self._socket_path): + # Check if socket is active + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.connect(self._socket_path) + sock.close() + # Socket is active, we're a client + self._is_server = False + await self._setup_client() + except (ConnectionRefusedError, FileNotFoundError): + # Socket file exists but not active, remove and become server + os.unlink(self._socket_path) + self._is_server = True + 
await self._setup_server() + else: + # No socket file, become server + self._is_server = True + await self._setup_server() + except Exception: + # Fallback to client mode + self._is_server = False + await self._setup_client() + + # Establish persistent database connection + self._db_connection = await aiosqlite.connect(self._file) + + async def _setup_server(self): + """Setup aiohttp server.""" + app = web.Application() + + # Add routes for all methods + app.router.add_post("/insert", self._handle_insert) + app.router.add_post("/update", self._handle_update) + app.router.add_post("/delete", self._handle_delete) + app.router.add_post("/upsert", self._handle_upsert) + app.router.add_post("/get", self._handle_get) + app.router.add_post("/find", self._handle_find) + app.router.add_post("/count", self._handle_count) + app.router.add_post("/exists", self._handle_exists) + app.router.add_post("/kv_set", self._handle_kv_set) + app.router.add_post("/kv_get", self._handle_kv_get) + app.router.add_post("/execute_raw", self._handle_execute_raw) + app.router.add_post("/query_raw", self._handle_query_raw) + app.router.add_post("/query_one", self._handle_query_one) + app.router.add_post("/create_table", self._handle_create_table) + app.router.add_post("/insert_unique", self._handle_insert_unique) + app.router.add_post("/aggregate", self._handle_aggregate) + + self._runner = web.AppRunner(app) + await self._runner.setup() + self._server = web.UnixSite(self._runner, self._socket_path) + await self._server.start() + + # Register cleanup + def cleanup(): + if os.path.exists(self._socket_path): + os.unlink(self._socket_path) + + atexit.register(cleanup) + + async def _setup_client(self): + """Setup aiohttp client.""" + connector = aiohttp.UnixConnector(path=self._socket_path) + self._client_session = aiohttp.ClientSession(connector=connector) + + async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any: + """Make HTTP request to server.""" + if not self._client_session: + await self._setup_client() + + url = f"http://localhost/{endpoint}" + async with self._client_session.post(url, json=data) as resp: + result = await resp.json() + if result.get("error"): + raise Exception(result["error"]) + return result.get("result") + + # Server handlers + async def _handle_insert(self, request): + data = await request.json() + try: + result = await self._server_insert( + data["table"], data["args"], data.get("return_id", False) + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_update(self, request): + data = await request.json() + try: + result = await self._server_update( + data["table"], data["args"], data.get("where") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_delete(self, request): + data = await request.json() + try: + result = await self._server_delete(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_upsert(self, request): + data = await request.json() + try: + result = await self._server_upsert( + data["table"], data["args"], data.get("where") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_get(self, request): + data = await request.json() + 
try: + result = await self._server_get(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_find(self, request): + data = await request.json() + try: + try: + _limit = data.pop("_limit", 60) + except: + pass + + result = await self._server_find( + data["table"], + data.get("where"), + _limit=_limit, + offset=data.get("offset", 0), + order_by=data.get("order_by"), + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_count(self, request): + data = await request.json() + try: + result = await self._server_count(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_exists(self, request): + data = await request.json() + try: + result = await self._server_exists(data["table"], data["where"]) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_kv_set(self, request): + data = await request.json() + try: + await self._server_kv_set( + data["key"], data["value"], table=data.get("table") + ) + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_kv_get(self, request): + data = await request.json() + try: + result = await self._server_kv_get( + data["key"], default=data.get("default"), table=data.get("table") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_execute_raw(self, request): + data = await request.json() + try: + result = await self._server_execute_raw( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response( + {"result": result.rowcount if hasattr(result, "rowcount") else None} + ) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_query_raw(self, request): + data = await request.json() + try: + result = await self._server_query_raw( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_query_one(self, request): + data = await request.json() + try: + result = await self._server_query_one( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_create_table(self, request): + data = await request.json() + try: + await self._server_create_table( + data["table"], data["schema"], data.get("constraints") + ) + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_insert_unique(self, request): + data = await request.json() + try: + result = await self._server_insert_unique( + data["table"], data["args"], data["unique_fields"] + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def 
_handle_aggregate(self, request): + data = await request.json() + try: + result = await self._server_aggregate( + data["table"], + data["function"], + data.get("column", "*"), + data.get("where"), + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + # Public methods that delegate to server or client + async def insert( + self, table: str, args: Dict[str, Any], return_id: bool = False + ) -> Union[str, int]: + await self._ensure_initialized() + if self._is_server: + return await self._server_insert(table, args, return_id) + else: + return await self._make_request( + "insert", {"table": table, "args": args, "return_id": return_id} + ) + + async def update( + self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None + ) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_update(table, args, where) + else: + return await self._make_request( + "update", {"table": table, "args": args, "where": where} + ) + + async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_delete(table, where) + else: + return await self._make_request("delete", {"table": table, "where": where}) + + async def upsert( + self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None + ) -> str | None: + await self._ensure_initialized() + if self._is_server: + return await self._server_upsert(table, args, where) + else: + return await self._make_request( + "upsert", {"table": table, "args": args, "where": where} + ) + + async def get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_get(table, where) + else: + return await self._make_request("get", {"table": table, "where": where}) + + async def find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + _limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_find( + table, where, _limit=_limit, offset=offset, order_by=order_by + ) + else: + return await self._make_request( + "find", + { + "table": table, + "where": where, + "_limit": _limit, + "offset": offset, + "order_by": order_by, + }, + ) + + async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_count(table, where) + else: + return await self._make_request("count", {"table": table, "where": where}) + + async def exists(self, table: str, where: Dict[str, Any]) -> bool: + await self._ensure_initialized() + if self._is_server: + return await self._server_exists(table, where) + else: + return await self._make_request("exists", {"table": table, "where": where}) + + async def kv_set(self, key: str, value: Any, *, table: str | None = None) -> None: + await self._ensure_initialized() + if self._is_server: + return await self._server_kv_set(key, value, table=table) + else: + return await self._make_request( + "kv_set", {"key": key, "value": value, "table": table} + ) + + async def kv_get( + self, key: str, *, default: Any = None, table: str | None = None + ) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_kv_get(key, default=default, 
table=table) + else: + return await self._make_request( + "kv_get", {"key": key, "default": default, "table": table} + ) + + async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_execute_raw(sql, params) + else: + return await self._make_request( + "execute_raw", {"sql": sql, "params": list(params) if params else None} + ) + + async def query_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> List[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_query_raw(sql, params) + else: + return await self._make_request( + "query_raw", {"sql": sql, "params": list(params) if params else None} + ) + + async def query_one( + self, sql: str, params: Optional[Tuple] = None + ) -> Optional[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_query_one(sql, params) + else: + return await self._make_request( + "query_one", {"sql": sql, "params": list(params) if params else None} + ) + + async def create_table( + self, + table: str, + schema: Dict[str, str], + constraints: Optional[List[str]] = None, + ): + await self._ensure_initialized() + if self._is_server: + return await self._server_create_table(table, schema, constraints) + else: + return await self._make_request( + "create_table", + {"table": table, "schema": schema, "constraints": constraints}, + ) + + async def insert_unique( + self, table: str, args: Dict[str, Any], unique_fields: List[str] + ) -> Union[str, None]: + await self._ensure_initialized() + if self._is_server: + return await self._server_insert_unique(table, args, unique_fields) + else: + return await self._make_request( + "insert_unique", + {"table": table, "args": args, "unique_fields": unique_fields}, + ) + + async def aggregate( + self, + table: str, + function: str, + column: str = "*", + where: Optional[Dict[str, Any]] = None, + ) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_aggregate(table, function, column, where) + else: + return await self._make_request( + "aggregate", + { + "table": table, + "function": function, + "column": column, + "where": where, + }, + ) + + async def transaction(self): + """Context manager for transactions.""" + await self._ensure_initialized() + if self._is_server: + return TransactionContext(self._db_connection, self._db_semaphore) + else: + raise NotImplementedError("Transactions not supported in client mode") + + # Server implementation methods (original logic) + @staticmethod + def _utc_iso() -> str: + return ( + datetime.now(timezone.utc) + .replace(microsecond=0) + .isoformat() + .replace("+00:00", "Z") + ) + + @staticmethod + def _py_to_sqlite_type(value: Any) -> str: + if value is None: + return "TEXT" + if isinstance(value, bool): + return "INTEGER" + if isinstance(value, int): + return "INTEGER" + if isinstance(value, float): + return "REAL" + if isinstance(value, (bytes, bytearray, memoryview)): + return "BLOB" + return "TEXT" + + async def _get_table_columns(self, table: str) -> Set[str]: + """Get actual columns that exist in the table.""" + if table in self._table_columns_cache: + return self._table_columns_cache[table] + + columns = set() + try: + async with self._db_semaphore: + async with self._db_connection.execute( + f"PRAGMA table_info({table})" + ) as cursor: + async for row in cursor: + columns.add(row[1]) # Column name is at index 1 + self._table_columns_cache[table] 
= columns + except: + pass + return columns + + async def _invalidate_column_cache(self, table: str): + """Invalidate column cache for a table.""" + if table in self._table_columns_cache: + del self._table_columns_cache[table] + + async def _ensure_column(self, table: str, name: str, value: Any) -> None: + col_type = self._py_to_sqlite_type(value) + try: + async with self._db_semaphore: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + except aiosqlite.OperationalError as e: + if "duplicate column name" in str(e).lower(): + pass # Column already exists + else: + raise + + async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: + # Always include default columns + cols = self._DEFAULT_COLUMNS.copy() + + # Add columns from col_sources + for key, val in col_sources.items(): + if key not in cols: + cols[key] = self._py_to_sqlite_type(val) + + columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items()) + async with self._db_semaphore: + await self._db_connection.execute( + f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + + async def _table_exists(self, table: str) -> bool: + """Check if a table exists.""" + async with self._db_semaphore: + async with self._db_connection.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,) + ) as cursor: + return await cursor.fetchone() is not None + + _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)") + _RE_NO_TABLE = re.compile(r"no such table: (\w+)") + + @classmethod + def _missing_column_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_COLUMN.search(str(err)) + return m.group(1) if m else None + + @classmethod + def _missing_table_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_TABLE.search(str(err)) + return m.group(1) if m else None + + async def _safe_execute( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + max_retries: int = 10, + ) -> aiosqlite.Cursor: + retries = 0 + while retries < max_retries: + try: + async with self._db_semaphore: + cursor = await self._db_connection.execute(sql, params) + await self._db_connection.commit() + return cursor + except aiosqlite.OperationalError as err: + retries += 1 + err_str = str(err).lower() + + # Handle missing column + col = self._missing_column_from_error(err) + if col: + if col in col_sources: + await self._ensure_column(table, col, col_sources[col]) + else: + # Column not in sources, ensure it with NULL/TEXT type + await self._ensure_column(table, col, None) + continue + + # Handle missing table + tbl = self._missing_table_from_error(err) + if tbl: + await self._ensure_table(tbl, col_sources) + continue + + # Handle other column-related errors + if "has no column named" in err_str: + # Extract column name differently + match = re.search(r"table \w+ has no column named (\w+)", err_str) + if match: + col_name = match.group(1) + if col_name in col_sources: + await self._ensure_column( + table, col_name, col_sources[col_name] + ) + else: + await self._ensure_column(table, col_name, None) + continue + + raise + raise Exception(f"Max retries ({max_retries}) exceeded") + + async def _filter_existing_columns( + self, table: str, data: Dict[str, Any] + ) -> Dict[str, Any]: + """Filter data 
to only include columns that exist in the table.""" + if not await self._table_exists(table): + return data + + existing_columns = await self._get_table_columns(table) + if not existing_columns: + return data + + return {k: v for k, v in data.items() if k in existing_columns} + + async def _safe_query( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + ) -> AsyncGenerator[Dict[str, Any], None]: + # Check if table exists first + if not await self._table_exists(table): + return + + max_retries = 10 + retries = 0 + + while retries < max_retries: + try: + async with self._db_semaphore: + self._db_connection.row_factory = aiosqlite.Row + async with self._db_connection.execute(sql, params) as cursor: + # Fetch all rows while holding the semaphore + rows = await cursor.fetchall() + + # Yield rows after releasing the semaphore + for row in rows: + yield dict(row) + return + except aiosqlite.OperationalError as err: + retries += 1 + err_str = str(err).lower() + + # Handle missing table + tbl = self._missing_table_from_error(err) + if tbl: + # For queries, if table doesn't exist, just return empty + return + + # Handle missing column in WHERE clause or SELECT + if "no such column" in err_str: + # For queries with missing columns, return empty + return + + raise + + @staticmethod + def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: + if not where: + return "", [] + clauses, vals = zip( + *[ + (f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v) + for k, v in where.items() + ] + ) + return " WHERE " + " AND ".join(clauses), [v for v in vals if v is not None] + + async def _server_insert( + self, table: str, args: Dict[str, Any], return_id: bool = False + ) -> Union[str, int]: + """Insert a record. 
If return_id=True, returns auto-incremented ID instead of UUID.""" + uid = str(uuid4()) + now = self._utc_iso() + record = { + "uid": uid, + "created_at": now, + "updated_at": now, + "deleted_at": None, + **args, + } + + # Ensure table exists with all needed columns + await self._ensure_table(table, record) + + # Handle auto-increment ID if requested + if return_id and "id" not in args: + # Ensure id column exists + async with self._db_semaphore: + # Add id column if it doesn't exist + try: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT" + ) + await self._db_connection.commit() + except aiosqlite.OperationalError as e: + if "duplicate column name" not in str(e).lower(): + # Try without autoincrement constraint + try: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN id INTEGER" + ) + await self._db_connection.commit() + except: + pass + + await self._invalidate_column_cache(table) + + # Insert and get lastrowid + cols = "`" + "`, `".join(record.keys()) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + try: + cursor = await self._safe_execute(table, sql, list(record.values()), record) + except Exception as ex: + print("GRRRRRRRR",ex) + return cursor.lastrowid + + cols = "`" + "`, `".join(record) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + await self._safe_execute(table, sql, list(record.values()), record) + return uid + + async def _server_update( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> int: + if not args: + return 0 + + # Check if table exists + if not await self._table_exists(table): + return 0 + + args["updated_at"] = self._utc_iso() + + # Ensure all columns exist + all_cols = {**args, **(where or {})} + await self._ensure_table(table, all_cols) + for col, val in all_cols.items(): + await self._ensure_column(table, col, val) + + set_clause = ", ".join(f"`{k}` = ?" for k in args) + where_clause, where_params = self._build_where(where) + sql = f"UPDATE {table} SET {set_clause}{where_clause}" + params = list(args.values()) + where_params + cur = await self._safe_execute(table, sql, params, all_cols) + return cur.rowcount + + async def _server_delete( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> int: + # Check if table exists + if not await self._table_exists(table): + return 0 + + where_clause, where_params = self._build_where(where) + sql = f"DELETE FROM {table}{where_clause}" + cur = await self._safe_execute(table, sql, where_params, where or {}) + return cur.rowcount + + async def _server_upsert( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> str | None: + if not args: + raise ValueError("Nothing to update. 
Empty dict given.") + args["updated_at"] = self._utc_iso() + affected = await self._server_update(table, args, where) + if affected: + rec = await self._server_get(table, where) + return rec.get("uid") if rec else None + merged = {**(where or {}), **args} + return await self._server_insert(table, merged) + + async def _server_get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + where_clause, where_params = self._build_where(where) + sql = f"SELECT * FROM {table}{where_clause} LIMIT 1" + async for row in self._safe_query(table, sql, where_params, where or {}): + return row + return None + + async def _server_find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + _limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Find records with optional ordering.""" + + try: + _limit = where.pop("_limit") + except: + pass + + where_clause, where_params = self._build_where(where) + order_clause = f" ORDER BY {order_by}" if order_by else "" + extra = (f" LIMIT {_limit}" if _limit else "") + ( + f" OFFSET {offset}" if offset else "" + ) + sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}" + return [ + row async for row in self._safe_query(table, sql, where_params, where or {}) + ] + + async def _server_count( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> int: + # Check if table exists + if not await self._table_exists(table): + return 0 + + where_clause, where_params = self._build_where(where) + sql = f"SELECT COUNT(*) FROM {table}{where_clause}" + gen = self._safe_query(table, sql, where_params, where or {}) + async for row in gen: + return next(iter(row.values()), 0) + return 0 + + async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool: + return (await self._server_count(table, where)) > 0 + + async def _server_kv_set( + self, + key: str, + value: Any, + *, + table: str | None = None, + ) -> None: + tbl = table or self._KV_TABLE + json_val = json.dumps(value, default=str) + await self._server_upsert(tbl, {"value": json_val}, {"key": key}) + + async def _server_kv_get( + self, + key: str, + *, + default: Any = None, + table: str | None = None, + ) -> Any: + tbl = table or self._KV_TABLE + row = await self._server_get(tbl, {"key": key}) + if not row: + return default + try: + return json.loads(row["value"]) + except Exception: + return default + + async def query(self, sql, *args): + return await self._server_query_raw(sql, args) + async def _server_execute_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> Any: + """Execute raw SQL for complex queries like JOINs.""" + async with self._db_semaphore: + cursor = await self._db_connection.execute(sql, params or ()) + await self._db_connection.commit() + return cursor + + async def _server_query_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> List[Dict[str, Any]]: + """Execute raw SQL query and return results as list of dicts.""" + try: + async with self._db_semaphore: + self._db_connection.row_factory = aiosqlite.Row + async with self._db_connection.execute(sql, params or ()) as cursor: + return [dict(row) async for row in cursor] + except aiosqlite.OperationalError as ex: + print(ex) + # Return empty list if query fails + return [] + + async def _server_query_one( + self, sql: str, params: Optional[Tuple] = None + ) -> Optional[Dict[str, Any]]: + """Execute raw SQL query and return single result.""" + results = await self._server_query_raw(sql + " LIMIT 1", params) 
+ return results[0] if results else None + + async def _server_create_table( + self, + table: str, + schema: Dict[str, str], + constraints: Optional[List[str]] = None, + ): + """Create table with custom schema and constraints. Always includes default columns.""" + # Merge default columns with custom schema + full_schema = self._DEFAULT_COLUMNS.copy() + full_schema.update(schema) + + columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()] + if constraints: + columns.extend(constraints) + columns_sql = ", ".join(columns) + + async with self._db_semaphore: + await self._db_connection.execute( + f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + + async def _server_insert_unique( + self, table: str, args: Dict[str, Any], unique_fields: List[str] + ) -> Union[str, None]: + """Insert with unique constraint handling. Returns uid on success, None if duplicate.""" + try: + return await self._server_insert(table, args) + except aiosqlite.IntegrityError as e: + if "UNIQUE" in str(e): + return None + raise + + async def _server_aggregate( + self, + table: str, + function: str, + column: str = "*", + where: Optional[Dict[str, Any]] = None, + ) -> Any: + """Perform aggregate functions like SUM, AVG, MAX, MIN.""" + # Check if table exists + if not await self._table_exists(table): + return None + + where_clause, where_params = self._build_where(where) + sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}" + result = await self._server_query_one(sql, tuple(where_params)) + return result["result"] if result else None + + async def close(self): + """Close the connection or server.""" + if not self._initialized: + return + + if self._client_session: + await self._client_session.close() + if self._runner: + await self._runner.cleanup() + if os.path.exists(self._socket_path) and self._is_server: + os.unlink(self._socket_path) + if self._db_connection: + await self._db_connection.close() + + +class TransactionContext: + """Context manager for database transactions.""" + + def __init__( + self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None + ): + self.db_connection = db_connection + self.semaphore = semaphore + + async def __aenter__(self): + if self.semaphore: + await self.semaphore.acquire() + try: + await self.db_connection.execute("BEGIN") + return self.db_connection + except: + if self.semaphore: + self.semaphore.release() + raise + + async def __aexit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type is None: + await self.db_connection.commit() + else: + await self.db_connection.rollback() + finally: + if self.semaphore: + self.semaphore.release() + + +# Test cases remain the same but with additional tests for new functionality +class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.db_path = Path("temp_test.db") + if self.db_path.exists(): + self.db_path.unlink() + self.connector = AsyncDataSet(str(self.db_path), max_concurrent_queries=1) + + async def asyncTearDown(self): + await self.connector.close() + if self.db_path.exists(): + self.db_path.unlink() + + async def test_insert_and_get(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertIsNotNone(rec) + self.assertEqual(rec["name"], "John Doe") + + async def test_get_nonexistent(self): + result = await self.connector.get("people", {"name": "Jane Doe"}) 
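+ # get() returns None rather than raising when no matching row exists.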
+ self.assertIsNone(result) + + async def test_update(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + await self.connector.update("people", {"age": 31}, {"name": "John Doe"}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertEqual(rec["age"], 31) + + async def test_order_by(self): + await self.connector.insert("people", {"name": "Alice", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 30}) + await self.connector.insert("people", {"name": "Charlie", "age": 20}) + + results = await self.connector.find("people", order_by="age ASC") + self.assertEqual(results[0]["name"], "Charlie") + self.assertEqual(results[-1]["name"], "Bob") + + async def test_raw_query(self): + await self.connector.insert("people", {"name": "John", "age": 30}) + await self.connector.insert("people", {"name": "Jane", "age": 25}) + + results = await self.connector.query_raw( + "SELECT * FROM people WHERE age > ?", (26,) + ) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "John") + + async def test_aggregate(self): + await self.connector.insert("people", {"name": "John", "age": 30}) + await self.connector.insert("people", {"name": "Jane", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 35}) + + avg_age = await self.connector.aggregate("people", "AVG", "age") + self.assertEqual(avg_age, 30) + + max_age = await self.connector.aggregate("people", "MAX", "age") + self.assertEqual(max_age, 35) + + async def test_insert_with_auto_id(self): + # Test auto-increment ID functionality + id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True) + id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True) + self.assertEqual(id2, id1 + 1) + + async def test_transaction(self): + await self.connector._ensure_initialized() + if self.connector._is_server: + async with self.connector.transaction() as conn: + await conn.execute( + "INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", + ("test-uid", "John", 30, "2024-01-01", "2024-01-01"), + ) + # Transaction will be committed + + rec = await self.connector.get("people", {"name": "John"}) + self.assertIsNotNone(rec) + + async def test_create_custom_table(self): + schema = { + "id": "INTEGER PRIMARY KEY AUTOINCREMENT", + "username": "TEXT NOT NULL", + "email": "TEXT NOT NULL", + "score": "INTEGER DEFAULT 0", + } + constraints = ["UNIQUE(username)", "UNIQUE(email)"] + + await self.connector.create_table("users", schema, constraints) + + # Test that table was created with constraints + result = await self.connector.insert_unique( + "users", + {"username": "john", "email": "john@example.com"}, + ["username", "email"], + ) + self.assertIsNotNone(result) + + # Test duplicate insert + result = await self.connector.insert_unique( + "users", + {"username": "john", "email": "different@example.com"}, + ["username", "email"], + ) + self.assertIsNone(result) + + async def test_missing_table_operations(self): + # Test operations on non-existent tables + self.assertEqual(await self.connector.count("nonexistent"), 0) + self.assertEqual(await self.connector.find("nonexistent"), []) + self.assertIsNone(await self.connector.get("nonexistent")) + self.assertFalse(await self.connector.exists("nonexistent", {"id": 1})) + self.assertEqual(await self.connector.delete("nonexistent"), 0) + self.assertEqual( + await self.connector.update("nonexistent", {"name": "test"}), 0 + ) + + async def 
test_auto_column_creation(self):
+        # Insert with new columns that don't exist yet
+        await self.connector.insert(
+            "dynamic", {"col1": "value1", "col2": 42, "col3": 3.14}
+        )
+
+        # Add more columns in next insert
+        await self.connector.insert(
+            "dynamic", {"col1": "value2", "col4": True, "col5": None}
+        )
+
+        # All records should be retrievable
+        records = await self.connector.find("dynamic")
+        self.assertEqual(len(records), 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/snek/system/session.py b/src/snek/system/session.py
new file mode 100644
index 0000000..5865baa
--- /dev/null
+++ b/src/snek/system/session.py
@@ -0,0 +1,225 @@
+import json
+import uuid
+import time
+from typing import Any, Callable, Optional
+
+import aiosqlite
+from aiohttp import web
+
+from aiohttp_session import AbstractStorage, Session
+
+
+class SQLiteSessionStorage(AbstractStorage):
+    """SQLite-backed session storage using aiosqlite."""
+
+    def __init__(
+        self,
+        cookie_name: str = "AIOHTTP_SESSION",
+        *,
+        db_path: str = "sessions.db",
+        domain: Optional[str] = None,
+        max_age: Optional[int] = None,
+        path: str = "/",
+        secure: Optional[bool] = None,
+        httponly: bool = True,
+        samesite: Optional[str] = None,
+        key_factory: Callable[[], str] = lambda: uuid.uuid4().hex,
+        encoder: Callable[[object], str] = json.dumps,
+        decoder: Callable[[str], Any] = json.loads,
+    ) -> None:
+        # Debug logging
+        print(f"SQLiteStorage init - cookie_name type: {type(cookie_name)}, value: {cookie_name!r}")
+
+        # Ensure all string parameters are actually strings, not bytes
+        if isinstance(cookie_name, bytes):
+            cookie_name = cookie_name.decode('utf-8')
+        if isinstance(domain, bytes):
+            domain = domain.decode('utf-8')
+        if isinstance(path, bytes):
+            path = path.decode('utf-8')
+        if isinstance(samesite, bytes):
+            samesite = samesite.decode('utf-8')
+
+        # Convert to string explicitly
+        cookie_name = str(cookie_name)
+        path = str(path)
+
+        super().__init__(
+            cookie_name=cookie_name,
+            domain=domain,
+            max_age=max_age,
+            path=path,
+            secure=secure,
+            httponly=httponly,
+            samesite=samesite,
+            encoder=encoder,
+            decoder=decoder,
+        )
+        self._key_factory = key_factory
+        self._db_path = db_path
+        self._db: Optional[aiosqlite.Connection] = None
+
+    def save_cookie(
+        self,
+        response: web.StreamResponse,
+        cookie_data: str,
+        *,
+        max_age: Optional[int] = None,
+    ) -> None:
+        """Override to ensure proper string handling and add debugging."""
+        # Ensure cookie_data is a string
+        if isinstance(cookie_data, bytes):
+            cookie_data = cookie_data.decode('utf-8')
+        elif cookie_data is not None:
+            cookie_data = str(cookie_data)
+
+        # Ensure cookie name is a string
+        cookie_name = self._cookie_name
+        if isinstance(cookie_name, bytes):
+            print(f"WARNING: _cookie_name is bytes: {cookie_name!r}")
+            cookie_name = cookie_name.decode('utf-8')
+        cookie_name = str(cookie_name)
+
+        # Get cookie params and ensure they're all strings
+        params = self._cookie_params.copy()
+        if max_age is not None:
+            params["max_age"] = max_age
+            t = time.gmtime(time.time() + max_age)
+            params["expires"] = time.strftime("%a, %d-%b-%Y %T GMT", t)
+
+        # Ensure all string params in the dict are 
strings, not bytes + for key in ['domain', 'path', 'samesite']: + if key in params and isinstance(params[key], bytes): + params[key] = params[key].decode('utf-8') + + if not cookie_data: + response.del_cookie( + cookie_name, domain=params.get("domain"), path=params.get("path") + ) + else: + response.set_cookie(cookie_name, cookie_data, **params) + + async def setup(self) -> None: + """Initialize the database connection and create table if needed.""" + self._db = await aiosqlite.connect(self._db_path) + await self._db.execute(""" + CREATE TABLE IF NOT EXISTS sessions ( + key TEXT PRIMARY KEY, + data TEXT NOT NULL, + expires_at REAL + ) + """) + # Create index for expiry cleanup + await self._db.execute(""" + CREATE INDEX IF NOT EXISTS idx_expires_at + ON sessions(expires_at) + WHERE expires_at IS NOT NULL + """) + await self._db.commit() + + async def cleanup(self) -> None: + """Close the database connection and cleanup expired sessions.""" + if self._db: + # Clean up expired sessions before closing + await self._cleanup_expired() + await self._db.close() + self._db = None + + async def _cleanup_expired(self) -> None: + """Remove expired sessions from the database.""" + if self._db: + current_time = time.time() + await self._db.execute( + "DELETE FROM sessions WHERE expires_at IS NOT NULL AND expires_at < ?", + (current_time,) + ) + await self._db.commit() + + async def load_session(self, request: web.Request) -> Session: + """Load session from SQLite database.""" + cookie = self.load_cookie(request) + if cookie is None: + return Session(None, data=None, new=True, max_age=self.max_age) + + if not self._db: + raise RuntimeError("Storage not initialized. Call setup() first.") + + # Periodically clean up expired sessions + # You might want to do this less frequently in production + await self._cleanup_expired() + + key = str(cookie) + session_key = f"{self._cookie_name}_{key}" + + async with self._db.execute( + "SELECT data, expires_at FROM sessions WHERE key = ?", + (session_key,) + ) as cursor: + row = await cursor.fetchone() + + if row is None: + return Session(None, data=None, new=True, max_age=self.max_age) + + data_str, expires_at = row + + # Check if session has expired + if expires_at is not None and expires_at < time.time(): + # Session expired, delete it + await self._db.execute("DELETE FROM sessions WHERE key = ?", (session_key,)) + await self._db.commit() + return Session(None, data=None, new=True, max_age=self.max_age) + + try: + data = self._decoder(data_str) + except ValueError: + data = None + + return Session(key, data=data, new=False, max_age=self.max_age) + + async def save_session( + self, request: web.Request, response: web.StreamResponse, session: Session + ) -> None: + """Save session to SQLite database.""" + if not self._db: + raise RuntimeError("Storage not initialized. 
Call setup() first.") + + key = session.identity + if key is None: + key = self._key_factory() + # Ensure key is a string before passing to save_cookie + if isinstance(key, bytes): + key = key.decode('utf-8') + else: + key = str(key) + self.save_cookie(response, key, max_age=session.max_age) + else: + # Ensure key is a string + if isinstance(key, bytes): + key = key.decode('utf-8') + else: + key = str(key) + + if session.empty: + self.save_cookie(response, "", max_age=session.max_age) + # Delete the session from database + session_key = f"{self._cookie_name}_{key}" + await self._db.execute("DELETE FROM sessions WHERE key = ?", (session_key,)) + await self._db.commit() + return + else: + self.save_cookie(response, key, max_age=session.max_age) + + data_str = self._encoder(self._get_session_data(session)) + session_key = f"{self._cookie_name}_{key}" + + # Calculate expiry time if max_age is set + expires_at = None + if session.max_age is not None: + expires_at = time.time() + session.max_age + + # Use INSERT OR REPLACE to handle both new and existing sessions + await self._db.execute( + "INSERT OR REPLACE INTO sessions (key, data, expires_at) VALUES (?, ?, ?)", + (session_key, data_str, expires_at) + ) + await self._db.commit()
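
The patch adds SQLiteSessionStorage but does not show how it is attached to an aiohttp application. The sketch below is not part of the patch: it assumes aiohttp_session is installed, that the class is importable as snek.system.session.SQLiteSessionStorage, and that setup()/cleanup() are driven from the application lifecycle hooks; the handler, route, and port are illustrative placeholders.

# Usage sketch (assumed wiring, not taken from the patch).
from aiohttp import web
from aiohttp_session import get_session, setup as session_setup

from snek.system.session import SQLiteSessionStorage


async def handle(request: web.Request) -> web.Response:
    # Count visits per session to demonstrate persistence across requests.
    session = await get_session(request)
    session["visits"] = session.get("visits", 0) + 1
    return web.json_response({"visits": session["visits"]})


def create_app() -> web.Application:
    app = web.Application()
    storage = SQLiteSessionStorage(db_path="sessions.db", max_age=3600)

    # setup()/cleanup() open and close the aiosqlite connection; here they are
    # assumed to be called from the application's startup/cleanup hooks.
    async def on_startup(app: web.Application) -> None:
        await storage.setup()

    async def on_cleanup(app: web.Application) -> None:
        await storage.cleanup()

    app.on_startup.append(on_startup)
    app.on_cleanup.append(on_cleanup)

    session_setup(app, storage)
    app.router.add_get("/", handle)
    return app


if __name__ == "__main__":
    web.run_app(create_app(), port=8080)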
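
Similarly, the TransactionContext added to ads.py is only exercised indirectly in test_transaction. The sketch below mirrors that test and shows one way AsyncDataSet.transaction() could be used; it assumes, as the test does, that transaction() yields a raw aiosqlite connection only when the instance holds the server role, and the database file, table, and column names are illustrative.

# Usage sketch (assumed, mirroring test_transaction above).
import asyncio

from snek.system.ads import AsyncDataSet


async def main() -> None:
    ads = AsyncDataSet("example.db")
    await ads.insert("accounts", {"name": "alice", "balance": 100})
    await ads.insert("accounts", {"name": "bob", "balance": 0})

    await ads._ensure_initialized()
    if ads._is_server:
        # TransactionContext commits on clean exit and rolls back on error,
        # so both UPDATEs land atomically or not at all.
        async with ads.transaction() as conn:
            await conn.execute(
                "UPDATE accounts SET balance = balance - 10 WHERE name = ?", ("alice",)
            )
            await conn.execute(
                "UPDATE accounts SET balance = balance + 10 WHERE name = ?", ("bob",)
            )

    print(await ads.find("accounts"))
    await ads.close()


if __name__ == "__main__":
    asyncio.run(main())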