import asyncio
import json
import re
import tempfile
import unittest
from datetime import datetime, timezone
from pathlib import Path
from types import SimpleNamespace
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Tuple, Union
from uuid import uuid4

import aiosqlite


class AsyncDataSet:
    """Async, schema-on-demand wrapper around a single SQLite database file.

    Tables and columns are created lazily: when a statement fails because a
    table or column is missing, the schema is extended using the Python types
    of the values being written (or filtered on) and the statement is retried.
    Every public call opens its own short-lived connection.

    Table names, ``order_by`` and the ``function``/``column`` arguments of
    :meth:`aggregate` are interpolated into SQL verbatim and must come from
    trusted code.  Record values are always bound as parameters.
    """

    # Default table used by kv_set / kv_get.
    _KV_TABLE = "__kv_store"

    # SQLite error messages identifying a missing column / table.
    _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
    _RE_NO_TABLE = re.compile(r"no such table: (\w+)")

    def __init__(self, file: str):
        """Bind the dataset to the SQLite database file at path *file*."""
        self._file = file

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _utc_iso() -> str:
        """Return the current UTC time, second precision, as ``YYYY-MM-DDTHH:MM:SSZ``."""
        return (
            datetime.now(timezone.utc)
            .replace(microsecond=0)
            .isoformat()
            .replace("+00:00", "Z")
        )

    @staticmethod
    def _py_to_sqlite_type(value: Any) -> str:
        """Map a Python value to the SQLite column type used when auto-creating it.

        ``None`` and unknown types map to TEXT; ``bool`` is stored as INTEGER.
        """
        if value is None:
            return "TEXT"
        if isinstance(value, bool):
            return "INTEGER"
        if isinstance(value, int):
            return "INTEGER"
        if isinstance(value, float):
            return "REAL"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return "BLOB"
        return "TEXT"

    async def _ensure_column(self, table: str, name: str, value: Any) -> None:
        """Add column *name* to *table*, typed from the sample *value*."""
        col_type = self._py_to_sqlite_type(value)
        async with aiosqlite.connect(self._file) as db:
            await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}")
            await db.commit()

    async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
        """Create *table* if absent with the bookkeeping columns plus one column
        per key of *col_sources* (typed from its value)."""
        cols: Dict[str, str] = {
            "uid": "TEXT PRIMARY KEY",
            "created_at": "TEXT",
            "updated_at": "TEXT",
            "deleted_at": "TEXT",
        }
        for key, val in col_sources.items():
            if key not in cols:
                cols[key] = self._py_to_sqlite_type(val)
        columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
        async with aiosqlite.connect(self._file) as db:
            await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
            await db.commit()

    async def _ensure_autoincrement_table(self, table: str) -> None:
        """Create *table* if absent with an ``id INTEGER PRIMARY KEY AUTOINCREMENT``
        column plus the bookkeeping columns (used by ``insert(return_id=True)``)."""
        async with aiosqlite.connect(self._file) as db:
            await db.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {table} (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    uid TEXT UNIQUE,
                    created_at TEXT,
                    updated_at TEXT,
                    deleted_at TEXT
                )
                """
            )
            await db.commit()

    @classmethod
    def _missing_column_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the column name an OperationalError complains about, if any."""
        m = cls._RE_NO_COLUMN.search(str(err))
        return m.group(1) if m else None

    @classmethod
    def _missing_table_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the table name an OperationalError complains about, if any."""
        m = cls._RE_NO_TABLE.search(str(err))
        return m.group(1) if m else None

    async def _heal_schema(
        self,
        table: str,
        err: aiosqlite.OperationalError,
        col_sources: Dict[str, Any],
        attempted: "set[tuple[str, str]]",
    ) -> bool:
        """Try to repair the schema problem described by *err*.

        Adds a missing column (only if *col_sources* has a sample value for it)
        or creates a missing table.  *attempted* records fixes already applied
        during this statement so the same fix is never retried endlessly.

        Returns True if a fix was applied and the statement should be retried.
        """
        col = self._missing_column_from_error(err)
        if col:
            key = ("column", col)
            if col not in col_sources or key in attempted:
                return False
            attempted.add(key)
            await self._ensure_column(table, col, col_sources[col])
            return True
        tbl = self._missing_table_from_error(err)
        if tbl:
            key = ("table", tbl)
            if key in attempted:
                return False
            attempted.add(key)
            await self._ensure_table(tbl, col_sources)
            return True
        return False

    async def _safe_execute(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> aiosqlite.Cursor:
        """Execute and commit a write statement, creating missing tables/columns.

        Returns the cursor of the successful execution.  Its connection is
        already closed, so only attributes such as ``rowcount``/``lastrowid``
        are meaningful.  Errors that cannot be healed are re-raised.
        """
        bound = list(params)  # materialize so every retry binds the same values
        attempted: "set[tuple[str, str]]" = set()
        while True:
            try:
                async with aiosqlite.connect(self._file) as db:
                    cursor = await db.execute(sql, bound)
                    await db.commit()
                    return cursor
            except aiosqlite.OperationalError as err:
                if not await self._heal_schema(table, err, col_sources, attempted):
                    raise

    async def _safe_query(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Run a SELECT, creating missing tables/columns, and yield rows as dicts.

        All rows are fetched and the connection is closed *before* the first
        row is yielded, so callers that stop iterating early (``get``,
        ``count``) never leave a connection open, and a retry can never
        re-yield rows.
        """
        bound = list(params)
        attempted: "set[tuple[str, str]]" = set()
        while True:
            try:
                async with aiosqlite.connect(self._file) as db:
                    db.row_factory = aiosqlite.Row
                    async with db.execute(sql, bound) as cursor:
                        rows = [dict(row) async for row in cursor]
                break
            except aiosqlite.OperationalError as err:
                if not await self._heal_schema(table, err, col_sources, attempted):
                    raise
        for row in rows:
            yield row

    @staticmethod
    def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
        """Build ``" WHERE a = ? AND b IS NULL"`` plus its parameter list.

        ``None`` values become ``IS NULL`` because ``col = NULL`` never matches
        in SQL.  Returns ``("", [])`` for an empty/None filter.
        """
        if not where:
            return "", []
        clauses: List[str] = []
        vals: List[Any] = []
        for key, value in where.items():
            if value is None:
                clauses.append(f"`{key}` IS NULL")
            else:
                clauses.append(f"`{key}` = ?")
                vals.append(value)
        return " WHERE " + " AND ".join(clauses), vals

    # ------------------------------------------------------------------ #
    # CRUD
    # ------------------------------------------------------------------ #

    async def insert(self, table: str, args: Dict[str, Any], return_id: bool = False) -> Union[str, int]:
        """Insert a record, creating the table/columns as needed.

        A fresh ``uid`` and ``created_at``/``updated_at`` timestamps are added;
        any of them present in *args* take precedence.

        If ``return_id=True`` and *args* has no ``id``, the table is created (if
        absent) with an AUTOINCREMENT ``id`` column and the new row's
        ``lastrowid`` is returned.  Otherwise the stored record's ``uid`` is
        returned.
        """
        now = self._utc_iso()
        record = {
            "uid": str(uuid4()),
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **args,
        }
        cols = ", ".join(f"`{c}`" for c in record)
        qs = ", ".join(["?"] * len(record))
        sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
        values = list(record.values())

        if return_id and "id" not in args:
            await self._ensure_autoincrement_table(table)
            # Go through _safe_execute so extra columns in *args* are added.
            cursor = await self._safe_execute(table, sql, values, record)
            return cursor.lastrowid

        await self._safe_execute(table, sql, values, record)
        # Return the uid actually stored (a caller-supplied uid wins).
        return record["uid"]

    async def update(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Set *args* (plus a fresh ``updated_at``) on rows matching *where*.

        With no *where*, every row is updated.  *args* is not modified.
        Returns the number of affected rows (0 if *args* is empty).
        """
        if not args:
            return 0
        values = {**args, "updated_at": self._utc_iso()}  # copy: don't mutate caller
        set_clause = ", ".join(f"`{k}` = ?" for k in values)
        where_clause, where_params = self._build_where(where)
        sql = f"UPDATE {table} SET {set_clause}{where_clause}"
        params = list(values.values()) + where_params
        cur = await self._safe_execute(table, sql, params, {**values, **(where or {})})
        return cur.rowcount

    async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Hard-delete rows matching *where* (all rows if None); return the count."""
        where_clause, where_params = self._build_where(where)
        sql = f"DELETE FROM {table}{where_clause}"
        cur = await self._safe_execute(table, sql, where_params, where or {})
        return cur.rowcount

    async def upsert(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> str | None:
        """Update rows matching *where* with *args*, or insert ``where | args``.

        Returns the uid of the first matching row after an update, or the uid
        of the newly inserted row.  Note that with no *where* an existing table
        has *all* rows updated.  *args* is not modified.

        Raises:
            ValueError: if *args* is empty.
        """
        if not args:
            raise ValueError("Nothing to update. Empty dict given.")
        payload = {**args, "updated_at": self._utc_iso()}
        affected = await self.update(table, payload, where)
        if affected:
            rec = await self.get(table, where)
            return rec.get("uid") if rec else None
        merged = {**(where or {}), **payload}
        return await self.insert(table, merged)

    async def get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the first row matching *where* as a dict, or None."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
        async for row in self._safe_query(table, sql, where_params, where or {}):
            return row
        return None

    async def find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
        order_by: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Find records with optional ordering and paging.

        ``limit=0`` means no limit.  *order_by* is raw SQL (e.g. ``"age DESC"``)
        and must be trusted.
        """
        where_clause, where_params = self._build_where(where)
        order_clause = f" ORDER BY {order_by}" if order_by else ""
        if limit:
            page = f" LIMIT {int(limit)}"
        elif offset:
            # SQLite only accepts OFFSET after LIMIT; -1 means "no limit".
            page = " LIMIT -1"
        else:
            page = ""
        if offset:
            page += f" OFFSET {int(offset)}"
        sql = f"SELECT * FROM {table}{where_clause}{order_clause}{page}"
        return [
            row
            async for row in self._safe_query(table, sql, where_params, where or {})
        ]

    async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of rows matching *where*."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
        async for row in self._safe_query(table, sql, where_params, where or {}):
            return next(iter(row.values()), 0)
        return 0

    async def exists(self, table: str, where: Dict[str, Any]) -> bool:
        """Return True if at least one row matches *where*."""
        return (await self.count(table, where)) > 0

    # ------------------------------------------------------------------ #
    # Key/value helpers
    # ------------------------------------------------------------------ #

    async def kv_set(
        self,
        key: str,
        value: Any,
        *,
        table: str | None = None,
    ) -> None:
        """Store *value* as JSON under *key* (non-JSON types go through ``str``)."""
        tbl = table or self._KV_TABLE
        json_val = json.dumps(value, default=str)
        await self.upsert(tbl, {"value": json_val}, {"key": key})

    async def kv_get(
        self,
        key: str,
        *,
        default: Any = None,
        table: str | None = None,
    ) -> Any:
        """Return the JSON-decoded value stored under *key*.

        Returns *default* if the key is absent or the stored value is missing
        or not valid JSON.
        """
        tbl = table or self._KV_TABLE
        row = await self.get(tbl, {"key": key})
        if not row:
            return default
        try:
            return json.loads(row["value"])
        except (KeyError, TypeError, ValueError):  # JSONDecodeError is a ValueError
            return default

    # ------------------------------------------------------------------ #
    # Raw SQL
    # ------------------------------------------------------------------ #

    async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
        """Execute and commit raw SQL.

        Returns the cursor.  Its connection is closed before returning, so it
        cannot be used to fetch rows; use :meth:`query_raw` for SELECTs.
        """
        async with aiosqlite.connect(self._file) as db:
            cursor = await db.execute(sql, params or ())
            await db.commit()
            return cursor

    async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]:
        """Execute a raw SQL query and return all rows as a list of dicts."""
        async with aiosqlite.connect(self._file) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(sql, params or ()) as cursor:
                return [dict(row) async for row in cursor]

    async def query_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Dict[str, Any]]:
        """Execute a raw SQL query and return its first row as a dict, or None.

        The SQL is executed unchanged (no ``LIMIT`` is appended), so statements
        ending in ``;`` or already containing ``LIMIT`` work.
        """
        async with aiosqlite.connect(self._file) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(sql, params or ()) as cursor:
                row = await cursor.fetchone()
                return dict(row) if row is not None else None

    async def create_table(self, table: str, schema: Dict[str, str], constraints: Optional[List[str]] = None):
        """Create *table* if absent from ``{column: type/definition}`` plus
        optional raw table constraints such as ``"UNIQUE(email)"``."""
        columns = [f"`{col}` {dtype}" for col, dtype in schema.items()]
        if constraints:
            columns.extend(constraints)
        columns_sql = ", ".join(columns)
        async with aiosqlite.connect(self._file) as db:
            await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
            await db.commit()

    async def insert_unique(self, table: str, args: Dict[str, Any], unique_fields: List[str]) -> Union[str, None]:
        """Insert with unique-constraint handling.

        Returns the uid on success, or None if the insert violated a UNIQUE
        constraint.  Uniqueness is enforced by the table's own constraints;
        *unique_fields* is currently not consulted.
        """
        try:
            return await self.insert(table, args)
        except aiosqlite.IntegrityError as e:
            if "UNIQUE" in str(e):
                return None
            raise

    def transaction(self) -> "TransactionContext":
        """Return an async context manager running its body in one transaction.

        Use as ``async with ds.transaction() as conn: ...``.  The legacy form
        ``async with await ds.transaction()`` is still accepted.
        """
        return TransactionContext(self._file)

    async def aggregate(self, table: str, function: str, column: str = "*", where: Optional[Dict[str, Any]] = None) -> Any:
        """Return ``function(column)`` (e.g. SUM, AVG, MAX, MIN) over matching rows.

        *function* and *column* are interpolated verbatim and must be trusted.
        """
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
        result = await self.query_one(sql, tuple(where_params))
        return result['result'] if result else None


class TransactionContext:
    """Async context manager holding one connection inside an explicit transaction.

    Entering opens a connection (rows come back as ``aiosqlite.Row``) and issues
    ``BEGIN``.  Leaving commits if the body succeeded, rolls back if it raised,
    and always closes the connection.  Exceptions from the body propagate.
    """

    def __init__(self, db_file: str):
        self.db_file = db_file
        self.conn: Optional[aiosqlite.Connection] = None

    def __await__(self):
        # Backward compatibility: transaction() used to be a coroutine, so some
        # callers write ``async with await ds.transaction():``.  Awaiting this
        # object completes immediately and yields the context manager itself.
        yield from ()
        return self

    async def __aenter__(self):
        conn = await aiosqlite.connect(self.db_file)
        try:
            conn.row_factory = aiosqlite.Row
            await conn.execute("BEGIN")
        except BaseException:
            await conn.close()  # don't leak the connection if BEGIN fails
            raise
        self.conn = conn
        return conn

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        conn, self.conn = self.conn, None
        if conn is None:
            return False
        try:
            if exc_type is None:
                await conn.commit()
            else:
                await conn.rollback()
        finally:
            await conn.close()  # close even if commit/rollback raises
        return False


class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
    _INSERT_PERSON = (
        "INSERT INTO people (uid, name, age, created_at, updated_at) "
        "VALUES (?, ?, ?, ?, ?)"
    )

    async def asyncSetUp(self):
        # A private temp directory per test: never clobbers a file in the CWD
        # and tests cannot see each other's data.
        self._tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        self.db_path = Path(self._tmpdir.name) / "temp_test.db"
        self.connector = AsyncDataSet(str(self.db_path))

    async def asyncTearDown(self):
        self._tmpdir.cleanup()

    async def _create_people_table(self):
        await self.connector.create_table(
            "people",
            {
                "uid": "TEXT PRIMARY KEY",
                "name": "TEXT",
                "age": "INTEGER",
                "created_at": "TEXT",
                "updated_at": "TEXT",
                "deleted_at": "TEXT",
            },
        )

    async def test_insert_and_get(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["name"], "John Doe")

    async def test_insert_keeps_explicit_uid(self):
        uid = await self.connector.insert("people", {"uid": "fixed-uid", "name": "X"})
        self.assertEqual(uid, "fixed-uid")
        self.assertIsNotNone(await self.connector.get("people", {"uid": "fixed-uid"}))

    async def test_get_nonexistent(self):
        result = await self.connector.get("people", {"name": "Jane Doe"})
        self.assertIsNone(result)

    async def test_update(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.update("people", {"age": 31}, {"name": "John Doe"})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertEqual(rec["age"], 31)

    async def test_update_does_not_mutate_args(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        args = {"age": 31}
        await self.connector.update("people", args, {"name": "John Doe"})
        self.assertEqual(args, {"age": 31})

    async def test_where_none_matches_null(self):
        await self.connector.insert("people", {"name": "A", "nick": None})
        await self.connector.insert("people", {"name": "B", "nick": "bee"})
        rec = await self.connector.get("people", {"nick": None})
        self.assertEqual(rec["name"], "A")
        self.assertEqual(await self.connector.count("people", {"deleted_at": None}), 2)

    async def test_upsert_inserts_then_updates(self):
        uid1 = await self.connector.upsert("people", {"age": 30}, {"name": "A"})
        uid2 = await self.connector.upsert("people", {"age": 31}, {"name": "A"})
        self.assertEqual(uid1, uid2)
        self.assertEqual(await self.connector.count("people"), 1)
        rec = await self.connector.get("people", {"name": "A"})
        self.assertEqual(rec["age"], 31)

    async def test_kv_roundtrip(self):
        await self.connector.kv_set("cfg", {"x": 1})
        self.assertEqual(await self.connector.kv_get("cfg"), {"x": 1})
        self.assertEqual(await self.connector.kv_get("missing", default=5), 5)

    async def test_order_by(self):
        await self.connector.insert("people", {"name": "Alice", "age": 25})
        await self.connector.insert("people", {"name": "Bob", "age": 30})
        await self.connector.insert("people", {"name": "Charlie", "age": 20})
        results = await self.connector.find("people", order_by="age ASC")
        self.assertEqual(results[0]["name"], "Charlie")
        self.assertEqual(results[-1]["name"], "Bob")

    async def test_find_offset_without_limit(self):
        await self.connector.insert("people", {"name": "Alice", "age": 25})
        await self.connector.insert("people", {"name": "Bob", "age": 30})
        await self.connector.insert("people", {"name": "Charlie", "age": 20})
        results = await self.connector.find("people", order_by="age ASC", offset=1)
        self.assertEqual([r["name"] for r in results], ["Alice", "Bob"])
        page = await self.connector.find("people", order_by="age ASC", limit=1, offset=1)
        self.assertEqual([r["name"] for r in page], ["Alice"])

    async def test_raw_query(self):
        await self.connector.insert("people", {"name": "John", "age": 30})
        await self.connector.insert("people", {"name": "Jane", "age": 25})
        results = await self.connector.query_raw(
            "SELECT * FROM people WHERE age > ?", (26,)
        )
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["name"], "John")

    async def test_query_one_tolerates_trailing_semicolon(self):
        await self.connector.insert("people", {"name": "John", "age": 30})
        await self.connector.insert("people", {"name": "Jane", "age": 25})
        row = await self.connector.query_one(
            "SELECT name FROM people ORDER BY age DESC;"
        )
        self.assertEqual(row, {"name": "John"})

    async def test_aggregate(self):
        await self.connector.insert("people", {"name": "John", "age": 30})
        await self.connector.insert("people", {"name": "Jane", "age": 25})
        await self.connector.insert("people", {"name": "Bob", "age": 35})
        avg_age = await self.connector.aggregate("people", "AVG", "age")
        self.assertEqual(avg_age, 30)
        max_age = await self.connector.aggregate("people", "MAX", "age")
        self.assertEqual(max_age, 35)

    async def test_insert_with_auto_id(self):
        # Test auto-increment ID functionality
        id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True)
        id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True)
        self.assertEqual(id2, id1 + 1)
        rec = await self.connector.get("posts", {"id": id2})
        self.assertEqual(rec["title"], "Second")

    async def test_transaction(self):
        await self._create_people_table()
        async with self.connector.transaction() as conn:
            await conn.execute(
                self._INSERT_PERSON,
                ("test-uid", "John", 30, "2024-01-01", "2024-01-01"),
            )
        # Transaction has been committed
        rec = await self.connector.get("people", {"name": "John"})
        self.assertIsNotNone(rec)

    async def test_transaction_rollback(self):
        await self._create_people_table()
        with self.assertRaises(RuntimeError):
            async with self.connector.transaction() as conn:
                await conn.execute(
                    self._INSERT_PERSON,
                    ("test-uid", "John", 30, "2024-01-01", "2024-01-01"),
                )
                raise RuntimeError("boom")
        self.assertIsNone(await self.connector.get("people", {"name": "John"}))

    async def test_transaction_legacy_await(self):
        await self._create_people_table()
        async with await self.connector.transaction() as conn:
            await conn.execute(
                self._INSERT_PERSON,
                ("legacy-uid", "Ann", 40, "2024-01-01", "2024-01-01"),
            )
        self.assertIsNotNone(await self.connector.get("people", {"name": "Ann"}))

    async def test_create_custom_table(self):
        schema = {
            "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
            "username": "TEXT NOT NULL",
            "email": "TEXT NOT NULL",
            "score": "INTEGER DEFAULT 0",
        }
        constraints = ["UNIQUE(username)", "UNIQUE(email)"]
        await self.connector.create_table("users", schema, constraints)
        # Test that table was created with constraints
        result = await self.connector.insert_unique(
            "users",
            {"username": "john", "email": "john@example.com"},
            ["username", "email"],
        )
        self.assertIsNotNone(result)
        # Test duplicate insert
        result = await self.connector.insert_unique(
            "users",
            {"username": "john", "email": "different@example.com"},
            ["username", "email"],
        )
        self.assertIsNone(result)


if __name__ == "__main__":
    unittest.main()