diff --git a/ads.py b/ads.py new file mode 100644 index 0000000..2c4a2ec --- /dev/null +++ b/ads.py @@ -0,0 +1,335 @@ +import re +import json +from uuid import uuid4 +from datetime import datetime, timezone +from typing import Any, Dict, Iterable, List, Optional, AsyncGenerator +from pathlib import Path +import aiosqlite +import unittest +from types import SimpleNamespace +import asyncio + + +class AsyncDataSetConnector: + + _KV_TABLE = "__kv_store" + + def __init__(self, file: str): + self._file = file + + @staticmethod + def _utc_iso() -> str: + return ( + datetime.now(timezone.utc) + .replace(microsecond=0) + .isoformat() + .replace("+00:00", "Z") + ) + + @staticmethod + def _py_to_sqlite_type(value: Any) -> str: + if value is None: + return "TEXT" + if isinstance(value, bool): + return "INTEGER" + if isinstance(value, int): + return "INTEGER" + if isinstance(value, float): + return "REAL" + if isinstance(value, (bytes, bytearray, memoryview)): + return "BLOB" + return "TEXT" + + async def _ensure_column(self, table: str, name: str, value: Any) -> None: + col_type = self._py_to_sqlite_type(value) + async with aiosqlite.connect(self._file) as db: + await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}") + await db.commit() + + async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: + cols: Dict[str, str] = { + "uid": "TEXT PRIMARY KEY", + "created_at": "TEXT", + "updated_at": "TEXT", + "deleted_at": "TEXT", + } + for key, val in col_sources.items(): + if key not in cols: + cols[key] = self._py_to_sqlite_type(val) + columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items()) + async with aiosqlite.connect(self._file) as db: + await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})") + await db.commit() + + _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)") + _RE_NO_TABLE = re.compile(r"no such table: (\w+)") + + @classmethod + def _missing_column_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_COLUMN.search(str(err)) + return m.group(1) if m else None + + @classmethod + def _missing_table_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_TABLE.search(str(err)) + return m.group(1) if m else None + + async def _safe_execute( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + ) -> aiosqlite.Cursor: + while True: + try: + async with aiosqlite.connect(self._file) as db: + cursor = await db.execute(sql, params) + await db.commit() + return cursor + except aiosqlite.OperationalError as err: + col = self._missing_column_from_error(err) + if col: + if col not in col_sources: + raise + await self._ensure_column(table, col, col_sources[col]) + continue + tbl = self._missing_table_from_error(err) + if tbl: + await self._ensure_table(tbl, col_sources) + continue + raise + + async def _safe_query( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + ) -> AsyncGenerator[Dict[str, Any], None]: + while True: + try: + async with aiosqlite.connect(self._file) as db: + db.row_factory = aiosqlite.Row + async with db.execute(sql, params) as cursor: + async for row in cursor: + yield dict(row) + return + except aiosqlite.OperationalError as err: + tbl = self._missing_table_from_error(err) + if tbl: + await self._ensure_table(tbl, col_sources) + continue + raise + + @staticmethod + def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: + if not where: + return "", [] + clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) + return " WHERE " + " AND ".join(clauses), list(vals) + + async def insert(self, table: str, args: Dict[str, Any]) -> str: + uid = str(uuid4()) + now = self._utc_iso() + record = { + "uid": uid, + "created_at": now, + "updated_at": now, + "deleted_at": None, + **args, + } + cols = "`" + "`, `".join(record) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + await self._safe_execute(table, sql, list(record.values()), record) + return uid + + async def update( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> int: + if not args: + return 0 + args["updated_at"] = self._utc_iso() + set_clause = ", ".join(f"`{k}` = ?" for k in args) + where_clause, where_params = self._build_where(where) + sql = f"UPDATE {table} SET {set_clause}{where_clause}" + params = list(args.values()) + where_params + cur = await self._safe_execute(table, sql, params, {**args, **(where or {})}) + return cur.rowcount + + async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + where_clause, where_params = self._build_where(where) + sql = f"DELETE FROM {table}{where_clause}" + cur = await self._safe_execute(table, sql, where_params, where or {}) + return cur.rowcount + + async def upsert( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> str | None: + affected = await self.update(table, args, where) + if affected: + rec = await self.get(table, where) + return rec.get("uid") if rec else None + merged = {**(where or {}), **args} + return await self.insert(table, merged) + + async def get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + where_clause, where_params = self._build_where(where) + sql = f"SELECT * FROM {table}{where_clause} LIMIT 1" + async for row in self._safe_query(table, sql, where_params, where or {}): + return row + return None + + async def find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + limit: int = 0, + offset: int = 0, + ) -> List[Dict[str, Any]]: + where_clause, where_params = self._build_where(where) + extra = (f" LIMIT {limit}" if limit else "") + ( + f" OFFSET {offset}" if offset else "" + ) + sql = f"SELECT * FROM {table}{where_clause}{extra}" + return [ + row async for row in self._safe_query(table, sql, where_params, where or {}) + ] + + async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + where_clause, where_params = self._build_where(where) + sql = f"SELECT COUNT(*) FROM {table}{where_clause}" + gen = self._safe_query(table, sql, where_params, where or {}) + async for row in gen: + return next(iter(row.values()), 0) + return 0 + + async def exists(self, table: str, where: Dict[str, Any]) -> bool: + return (await self.count(table, where)) > 0 + + async def kv_set( + self, + key: str, + value: Any, + *, + table: str | None = None, + ) -> None: + tbl = table or self._KV_TABLE + json_val = json.dumps(value, default=str) + await self.upsert(tbl, {"value": json_val}, {"key": key}) + + async def kv_get( + self, + key: str, + *, + default: Any = None, + table: str | None = None, + ) -> Any: + tbl = table or self._KV_TABLE + row = await self.get(tbl, {"key": key}) + if not row: + return default + try: + return json.loads(row["value"]) + except Exception: + return default + + + + + +class TestAsyncDataSetConnector(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.db_path = Path("temp_test.db") + if self.db_path.exists(): + self.db_path.unlink() + self.connector = AsyncDataSetConnector(str(self.db_path)) + + async def test_insert_and_get(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertIsNotNone(rec) + self.assertEqual(rec["name"], "John Doe") + + async def test_get_nonexistent(self): + result = await self.connector.get("people", {"name": "Jane Doe"}) + self.assertIsNone(result) + + async def test_update(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + await self.connector.update("people", {"age": 31}, {"name": "John Doe"}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertEqual(rec["age"], 31) + + async def test_upsert_insert(self): + uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"}) + rec = await self.connector.get("people", {"uid": uid}) + self.assertEqual(rec["name"], "Alice") + self.assertEqual(rec["age"], 22) + self.alice_uid = uid + + async def test_upsert_update(self): + uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"}) + uid_same = await self.connector.upsert("people", {"age": 23}, {"name": "Alice"}) + self.assertEqual(uid_same, uid) + rec = await self.connector.get("people", {"uid": uid}) + self.assertEqual(rec["age"], 23) + + async def test_count_and_exists(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + await self.connector.insert("people", {"name": "Alice", "age": 23}) + total = await self.connector.count("people") + self.assertEqual(total, 2) + self.assertTrue(await self.connector.exists("people", {"name": "Alice"})) + self.assertFalse(await self.connector.exists("people", {"name": "Nobody"})) + + async def test_find_pagination(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + await self.connector.insert("people", {"name": "Alice", "age": 23}) + rows = await self.connector.find("people", limit=1, offset=1) + self.assertEqual(len(rows), 1) + self.assertIn(rows[0]["name"], {"John Doe", "Alice"}) + + async def test_kv_set_and_get(self): + await self.connector.kv_set("config:threshold", {"limit": 10, "mode": "fast"}) + cfg = await self.connector.kv_get("config:threshold") + self.assertEqual(cfg, {"limit": 10, "mode": "fast"}) + + async def test_delete(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + await self.connector.insert("people", {"name": "Alice", "age": 23}) + await self.connector.delete("people", {"name": "John Doe"}) + self.assertIsNone(await self.connector.get("people", {"name": "John Doe"})) + count = await self.connector.count("people") + self.assertEqual(count, 1) + + async def test_inject(self): + await self.connector.insert("people", {"name": "John Doe", "index": 30}) + await self.connector.insert("people", {"name": "Alice", "index": 23}) + await self.connector.delete("people", {"name": "John Doe","index": 30}) + self.assertIsNone(await self.connector.get("people", {"name": "John Doe"})) + count = await self.connector.count("people") + self.assertEqual(count, 1) + + + async def test_insert_binary(self): + await self.connector.insert("binaries",{"data":b"1234"}) + print(await self.connector.get("binaries",{"data":b"1234"})) + +if __name__ == "__main__": + unittest.main()