import asyncio
import json
import re
import tempfile
import unittest
from contextlib import aclosing
from datetime import datetime, timezone
from pathlib import Path
from types import SimpleNamespace
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional
from uuid import uuid4

import aiosqlite


def _quote_ident(name: str) -> str:
    """Return *name* as a backtick-quoted SQLite identifier.

    Embedded backticks are doubled (SQLite's escape inside a quoted
    identifier), so caller-supplied table/column names cannot break out of
    the quotes. Backticks are used instead of double quotes because SQLite
    may silently reinterpret a double-quoted name that matches no column as
    a string literal, which would hide the "no such column" errors the
    auto-migration logic below depends on.
    """
    return "`" + str(name).replace("`", "``") + "`"


class AsyncDataSet:
    """Minimal async record store on top of a single SQLite file.

    Every operation opens its own ``aiosqlite`` connection. Schema is created
    lazily: when a statement fails with "no such table" or "no such column",
    the table/column is created from the values being written or matched
    (typed via ``_py_to_sqlite_type``) and the statement is retried.

    Every table created here has four bookkeeping columns: ``uid`` (primary
    key, a UUID4 string), ``created_at`` and ``updated_at`` (UTC timestamps
    from ``_utc_iso``) and ``deleted_at`` (written as NULL on insert; nothing
    in this class sets it).
    """

    _KV_TABLE = "__kv_store"

    def __init__(self, file: str):
        # Path of the SQLite database file; each call opens its own connection.
        self._file = file

    @staticmethod
    def _utc_iso() -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``."""
        return (
            datetime.now(timezone.utc)
            .replace(microsecond=0)
            .isoformat()
            .replace("+00:00", "Z")
        )

    @staticmethod
    def _py_to_sqlite_type(value: Any) -> str:
        """Map a sample Python value to the SQLite column type to create.

        ``None`` and unrecognised types fall back to TEXT. ``bool`` is a
        subclass of ``int`` and maps to INTEGER as well.
        """
        if value is None:
            return "TEXT"
        if isinstance(value, (bool, int)):
            return "INTEGER"
        if isinstance(value, float):
            return "REAL"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return "BLOB"
        return "TEXT"

    async def _ensure_column(self, table: str, name: str, value: Any) -> None:
        """Add column *name* to *table*, typed from the sample *value*.

        A "duplicate column name" error is ignored: it means the column now
        exists (e.g. another task added it first), which is the goal here.
        """
        col_type = self._py_to_sqlite_type(value)
        sql = (
            f"ALTER TABLE {_quote_ident(table)} "
            f"ADD COLUMN {_quote_ident(name)} {col_type}"
        )
        async with aiosqlite.connect(self._file) as db:
            try:
                await db.execute(sql)
            except aiosqlite.OperationalError as err:
                if "duplicate column name" not in str(err):
                    raise
            await db.commit()

    async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
        """Create *table* if it does not exist.

        Columns are the four bookkeeping columns plus one per key of
        *col_sources*, typed from that key's value.
        """
        cols: Dict[str, str] = {
            "uid": "TEXT PRIMARY KEY",
            "created_at": "TEXT",
            "updated_at": "TEXT",
            "deleted_at": "TEXT",
        }
        for key, val in col_sources.items():
            cols.setdefault(key, self._py_to_sqlite_type(val))
        columns_sql = ", ".join(f"{_quote_ident(k)} {t}" for k, t in cols.items())
        async with aiosqlite.connect(self._file) as db:
            await db.execute(
                f"CREATE TABLE IF NOT EXISTS {_quote_ident(table)} ({columns_sql})"
            )
            await db.commit()

    _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
    _RE_NO_TABLE = re.compile(r"no such table: (\w+)")

    @classmethod
    def _missing_column_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the column name from a missing-column error, else None."""
        m = cls._RE_NO_COLUMN.search(str(err))
        return m.group(1) if m else None

    @classmethod
    def _missing_table_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the table name from a "no such table" error, else None."""
        m = cls._RE_NO_TABLE.search(str(err))
        return m.group(1) if m else None

    async def _repair_schema(
        self,
        table: str,
        err: aiosqlite.OperationalError,
        col_sources: Dict[str, Any],
    ) -> bool:
        """Try to fix the schema problem reported by *err*.

        Returns True when a column or table was created and the statement
        should be retried. Returns False when the error is not a repairable
        schema error, or names a column with no sample value in
        *col_sources*; the caller should then re-raise.
        """
        col = self._missing_column_from_error(err)
        if col is not None:
            if col not in col_sources:
                return False
            await self._ensure_column(table, col, col_sources[col])
            return True
        if self._missing_table_from_error(err) is not None:
            # Create the table the statement targets, not whatever name the
            # message mentions (which may differ, e.g. from a trigger).
            await self._ensure_table(table, col_sources)
            return True
        return False

    @staticmethod
    def _max_repairs(col_sources: Dict[str, Any]) -> int:
        """Upper bound on schema repairs for one statement.

        At most one table creation plus one ALTER per known column is ever
        legitimate. The extra slack only guards against looping forever on
        an error a repair cannot fix.
        """
        return len(col_sources) + 2

    async def _safe_execute(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> aiosqlite.Cursor:
        """Execute and commit *sql*, creating missing schema and retrying.

        *col_sources* maps column names to sample values used to create
        missing columns/tables. Returns the (already closed-over) cursor so
        callers can read ``rowcount``. Raises the original
        ``OperationalError`` when the problem cannot be repaired.
        """
        params = list(params)  # materialise once; the statement may be retried
        failures = 0
        while True:
            try:
                async with aiosqlite.connect(self._file) as db:
                    cursor = await db.execute(sql, params)
                    await db.commit()
                    return cursor
            except aiosqlite.OperationalError as err:
                failures += 1
                if failures > self._max_repairs(col_sources):
                    raise
                if not await self._repair_schema(table, err, col_sources):
                    raise

    async def _safe_query(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Yield rows of *sql* as dicts, creating missing schema and retrying.

        Callers that stop iterating early should wrap this generator in
        ``contextlib.aclosing`` so its connection is closed promptly.
        """
        params = list(params)  # materialise once; the query may be retried
        failures = 0
        while True:
            try:
                async with aiosqlite.connect(self._file) as db:
                    db.row_factory = aiosqlite.Row
                    async with db.execute(sql, params) as cursor:
                        async for row in cursor:
                            yield dict(row)
                return
            except aiosqlite.OperationalError as err:
                # Missing table/column errors are raised when the statement is
                # prepared, i.e. before any row has been yielded.
                failures += 1
                if failures > self._max_repairs(col_sources):
                    raise
                if not await self._repair_schema(table, err, col_sources):
                    raise

    @staticmethod
    def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
        """Build an AND-ed ``WHERE`` clause and its parameter list.

        A ``None`` value becomes ``col IS NULL``; ``col = NULL`` would never
        match anything in SQL.
        """
        if not where:
            return "", []
        clauses: List[str] = []
        params: List[Any] = []
        for key, val in where.items():
            col = _quote_ident(key)
            if val is None:
                clauses.append(f"{col} IS NULL")
            else:
                clauses.append(f"{col} = ?")
                params.append(val)
        return " WHERE " + " AND ".join(clauses), params

    async def insert(self, table: str, args: Dict[str, Any]) -> str:
        """Insert a row built from *args* plus bookkeeping columns.

        Returns the stored ``uid``. This is a fresh UUID4 unless *args*
        supplies its own ``uid``.
        """
        now = self._utc_iso()
        record = {
            "uid": str(uuid4()),
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **args,
        }
        cols = ", ".join(_quote_ident(k) for k in record)
        qs = ", ".join(["?"] * len(record))
        sql = f"INSERT INTO {_quote_ident(table)} ({cols}) VALUES ({qs})"
        await self._safe_execute(table, sql, list(record.values()), record)
        # *args* may override "uid"; report the key that was actually stored.
        return record["uid"]

    async def update(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Set *args* (and a fresh ``updated_at``) on rows matching *where*.

        Returns the affected row count; 0 without touching the database
        when *args* is empty. *args* itself is not modified.
        """
        if not args:
            return 0
        values = {**args, "updated_at": self._utc_iso()}
        set_clause = ", ".join(f"{_quote_ident(k)} = ?" for k in values)
        where_clause, where_params = self._build_where(where)
        sql = f"UPDATE {_quote_ident(table)} SET {set_clause}{where_clause}"
        params = [*values.values(), *where_params]
        cur = await self._safe_execute(
            table, sql, params, {**values, **(where or {})}
        )
        return cur.rowcount

    async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Delete rows matching *where* (all rows if None); return the count."""
        where_clause, where_params = self._build_where(where)
        sql = f"DELETE FROM {_quote_ident(table)}{where_clause}"
        cur = await self._safe_execute(table, sql, where_params, where or {})
        return cur.rowcount

    async def upsert(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> str | None:
        """Update rows matching *where* with *args*, or insert one if none match.

        Returns the ``uid`` of a matching row after the update, or the uid of
        the newly inserted row (built from *where* merged with *args*).
        """
        if args:
            affected = await self.update(table, args, where)
        else:
            # update() is a no-op for empty args and would report 0 rows,
            # which used to cause a duplicate insert on every call.
            affected = await self.count(table, where)
        if affected:
            rec = await self.get(table, where)
            return rec.get("uid") if rec else None
        return await self.insert(table, {**(where or {}), **args})

    async def get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the first row matching *where* as a dict, or None."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT * FROM {_quote_ident(table)}{where_clause} LIMIT 1"
        # aclosing() closes the generator (and its connection) as soon as we
        # return, rather than whenever it happens to be garbage-collected.
        async with aclosing(
            self._safe_query(table, sql, where_params, where or {})
        ) as rows:
            async for row in rows:
                return row
        return None

    async def find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """Return all rows matching *where*, optionally paginated.

        ``limit=0`` means no limit. *offset* may be given without *limit*.
        """
        where_clause, params = self._build_where(where)
        sql = f"SELECT * FROM {_quote_ident(table)}{where_clause}"
        if limit or offset:
            # SQLite only accepts OFFSET after a LIMIT clause; a negative
            # LIMIT means "no upper bound", so offset can be used alone.
            sql += " LIMIT ? OFFSET ?"
            params += [limit or -1, offset]
        return [
            row async for row in self._safe_query(table, sql, params, where or {})
        ]

    async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of rows matching *where* (all rows if None)."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT COUNT(*) FROM {_quote_ident(table)}{where_clause}"
        async with aclosing(
            self._safe_query(table, sql, where_params, where or {})
        ) as rows:
            async for row in rows:
                return next(iter(row.values()), 0)
        return 0

    async def exists(self, table: str, where: Dict[str, Any]) -> bool:
        """Return True if at least one row matches *where*."""
        return (await self.count(table, where)) > 0

    async def kv_set(
        self,
        key: str,
        value: Any,
        *,
        table: str | None = None,
    ) -> None:
        """Store *value* as JSON under *key* in the key/value table.

        Values that are not JSON-serialisable are stored via ``str()``.
        """
        tbl = table or self._KV_TABLE
        json_val = json.dumps(value, default=str)
        await self.upsert(tbl, {"value": json_val}, {"key": key})

    async def kv_get(
        self,
        key: str,
        *,
        default: Any = None,
        table: str | None = None,
    ) -> Any:
        """Return the JSON-decoded value stored under *key*, else *default*.

        *default* is also returned when the stored value is missing or is
        not valid JSON.
        """
        tbl = table or self._KV_TABLE
        row = await self.get(tbl, {"key": key})
        if not row:
            return default
        try:
            return json.loads(row["value"])
        except (KeyError, TypeError, ValueError):
            # KeyError: no "value" column; TypeError: NULL value;
            # ValueError (incl. JSONDecodeError): malformed JSON.
            return default


class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        # Each test gets a private temp directory that is removed afterwards,
        # so tests neither depend on nor litter the working directory.
        tmpdir = tempfile.TemporaryDirectory()
        self.addCleanup(tmpdir.cleanup)
        self.db_path = Path(tmpdir.name) / "temp_test.db"
        self.connector = AsyncDataSet(str(self.db_path))

    async def test_insert_and_get(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["name"], "John Doe")

    async def test_get_nonexistent(self):
        result = await self.connector.get("people", {"name": "Jane Doe"})
        self.assertIsNone(result)

    async def test_update(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.update("people", {"age": 31}, {"name": "John Doe"})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertEqual(rec["age"], 31)

    async def test_update_does_not_mutate_args(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        args = {"age": 31}
        affected = await self.connector.update("people", args, {"name": "John Doe"})
        self.assertEqual(affected, 1)
        self.assertEqual(args, {"age": 31})

    async def test_upsert_insert(self):
        uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
        rec = await self.connector.get("people", {"uid": uid})
        self.assertEqual(rec["name"], "Alice")
        self.assertEqual(rec["age"], 22)

    async def test_upsert_update(self):
        uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
        uid_same = await self.connector.upsert("people", {"age": 23}, {"name": "Alice"})
        self.assertEqual(uid_same, uid)
        rec = await self.connector.get("people", {"uid": uid})
        self.assertEqual(rec["age"], 23)

    async def test_upsert_without_args_is_idempotent(self):
        uid = await self.connector.upsert("people", {}, {"name": "Alice"})
        uid_same = await self.connector.upsert("people", {}, {"name": "Alice"})
        self.assertEqual(uid_same, uid)
        self.assertEqual(await self.connector.count("people"), 1)

    async def test_insert_returns_stored_uid(self):
        uid = await self.connector.insert("people", {"uid": "fixed", "name": "X"})
        self.assertEqual(uid, "fixed")
        self.assertIsNotNone(await self.connector.get("people", {"uid": "fixed"}))

    async def test_count_and_exists(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        total = await self.connector.count("people")
        self.assertEqual(total, 2)
        self.assertTrue(await self.connector.exists("people", {"name": "Alice"}))
        self.assertFalse(await self.connector.exists("people", {"name": "Nobody"}))

    async def test_where_none_matches_null(self):
        await self.connector.insert("people", {"name": "Bob", "nick": None})
        await self.connector.insert("people", {"name": "Carol", "nick": "C"})
        rec = await self.connector.get("people", {"nick": None})
        self.assertEqual(rec["name"], "Bob")
        self.assertEqual(await self.connector.count("people", {"deleted_at": None}), 2)

    async def test_get_unknown_column_returns_none(self):
        await self.connector.insert("people", {"name": "John Doe"})
        self.assertIsNone(
            await self.connector.get("people", {"email": "x@example.com"})
        )

    async def test_find_pagination(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        rows = await self.connector.find("people", limit=1, offset=1)
        self.assertEqual(len(rows), 1)
        self.assertIn(rows[0]["name"], {"John Doe", "Alice"})

    async def test_find_offset_without_limit(self):
        for name in ("A", "B", "C"):
            await self.connector.insert("people", {"name": name})
        rows = await self.connector.find("people", offset=1)
        self.assertEqual(len(rows), 2)

    async def test_kv_set_and_get(self):
        await self.connector.kv_set("config:threshold", {"limit": 10, "mode": "fast"})
        cfg = await self.connector.kv_get("config:threshold")
        self.assertEqual(cfg, {"limit": 10, "mode": "fast"})

    async def test_delete(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        await self.connector.delete("people", {"name": "John Doe"})
        self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
        count = await self.connector.count("people")
        self.assertEqual(count, 1)

    async def test_inject(self):
        await self.connector.insert("people", {"name": "John Doe", "index": 30})
        await self.connector.insert("people", {"name": "Alice", "index": 23})
        await self.connector.delete("people", {"name": "John Doe", "index": 30})
        self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
        count = await self.connector.count("people")
        self.assertEqual(count, 1)

    async def test_reserved_word_table_name(self):
        uid = await self.connector.insert("order", {"item": "book"})
        rec = await self.connector.get("order", {"uid": uid})
        self.assertEqual(rec["item"], "book")

    async def test_insert_binary(self):
        await self.connector.insert("binaries", {"data": b"1234"})
        rec = await self.connector.get("binaries", {"data": b"1234"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["data"], b"1234")


if __name__ == "__main__":
    unittest.main()