diff --git a/ads.py b/ads.py new file mode 100644 index 0000000..7dba74e --- /dev/null +++ b/ads.py @@ -0,0 +1,456 @@ +import re +import json +from uuid import uuid4 +from datetime import datetime, timezone +from typing import Any, Dict, Iterable, List, Optional, AsyncGenerator, Union, Tuple +from pathlib import Path +import aiosqlite +import unittest +from types import SimpleNamespace +import asyncio + + +class AsyncDataSet: + + _KV_TABLE = "__kv_store" + + def __init__(self, file: str): + self._file = file + + @staticmethod + def _utc_iso() -> str: + return ( + datetime.now(timezone.utc) + .replace(microsecond=0) + .isoformat() + .replace("+00:00", "Z") + ) + + @staticmethod + def _py_to_sqlite_type(value: Any) -> str: + if value is None: + return "TEXT" + if isinstance(value, bool): + return "INTEGER" + if isinstance(value, int): + return "INTEGER" + if isinstance(value, float): + return "REAL" + if isinstance(value, (bytes, bytearray, memoryview)): + return "BLOB" + return "TEXT" + + async def _ensure_column(self, table: str, name: str, value: Any) -> None: + col_type = self._py_to_sqlite_type(value) + async with aiosqlite.connect(self._file) as db: + await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}") + await db.commit() + + async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: + cols: Dict[str, str] = { + "uid": "TEXT PRIMARY KEY", + "created_at": "TEXT", + "updated_at": "TEXT", + "deleted_at": "TEXT", + } + for key, val in col_sources.items(): + if key not in cols: + cols[key] = self._py_to_sqlite_type(val) + columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items()) + async with aiosqlite.connect(self._file) as db: + await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})") + await db.commit() + + _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)") + _RE_NO_TABLE = re.compile(r"no such table: (\w+)") + + @classmethod + def _missing_column_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_COLUMN.search(str(err)) + return m.group(1) if m else None + + @classmethod + def _missing_table_from_error( + cls, err: aiosqlite.OperationalError + ) -> Optional[str]: + m = cls._RE_NO_TABLE.search(str(err)) + return m.group(1) if m else None + + async def _safe_execute( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + ) -> aiosqlite.Cursor: + while True: + try: + async with aiosqlite.connect(self._file) as db: + cursor = await db.execute(sql, params) + await db.commit() + return cursor + except aiosqlite.OperationalError as err: + col = self._missing_column_from_error(err) + if col: + if col not in col_sources: + raise + await self._ensure_column(table, col, col_sources[col]) + continue + tbl = self._missing_table_from_error(err) + if tbl: + await self._ensure_table(tbl, col_sources) + continue + raise + + async def _safe_query( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + ) -> AsyncGenerator[Dict[str, Any], None]: + while True: + try: + async with aiosqlite.connect(self._file) as db: + db.row_factory = aiosqlite.Row + async with db.execute(sql, params) as cursor: + async for row in cursor: + yield dict(row) + return + except aiosqlite.OperationalError as err: + tbl = self._missing_table_from_error(err) + if tbl: + await self._ensure_table(tbl, col_sources) + continue + raise + + @staticmethod + def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: + if not where: + return "", [] + clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) + return " WHERE " + " AND ".join(clauses), list(vals) + + async def insert(self, table: str, args: Dict[str, Any], return_id: bool = False) -> Union[str, int]: + """Insert a record. If return_id=True, returns auto-incremented ID instead of UUID.""" + uid = str(uuid4()) + now = self._utc_iso() + record = { + "uid": uid, + "created_at": now, + "updated_at": now, + "deleted_at": None, + **args, + } + + # Handle auto-increment ID if requested + if return_id and 'id' not in args: + async with aiosqlite.connect(self._file) as db: + # Create table with autoincrement ID if needed + await db.execute(f""" + CREATE TABLE IF NOT EXISTS {table} ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + uid TEXT UNIQUE, + created_at TEXT, + updated_at TEXT, + deleted_at TEXT + ) + """) + await db.commit() + + # Insert and get lastrowid + cols = "`" + "`, `".join(record.keys()) + "`" + qs = ", ".join(["?"] * len(record)) + cursor = await db.execute(f"INSERT INTO {table} ({cols}) VALUES ({qs})", list(record.values())) + await db.commit() + return cursor.lastrowid + + cols = "`" + "`, `".join(record) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + await self._safe_execute(table, sql, list(record.values()), record) + return uid + + async def update( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> int: + if not args: + return 0 + args["updated_at"] = self._utc_iso() + set_clause = ", ".join(f"`{k}` = ?" for k in args) + where_clause, where_params = self._build_where(where) + sql = f"UPDATE {table} SET {set_clause}{where_clause}" + params = list(args.values()) + where_params + cur = await self._safe_execute(table, sql, params, {**args, **(where or {})}) + return cur.rowcount + + async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + where_clause, where_params = self._build_where(where) + sql = f"DELETE FROM {table}{where_clause}" + cur = await self._safe_execute(table, sql, where_params, where or {}) + return cur.rowcount + + async def upsert( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> str | None: + if not args: + raise ValueError("Nothing to update. Empty dict given.") + args['updated_at'] = self._utc_iso() + affected = await self.update(table, args, where) + if affected: + rec = await self.get(table, where) + return rec.get("uid") if rec else None + merged = {**(where or {}), **args} + return await self.insert(table, merged) + + async def get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + where_clause, where_params = self._build_where(where) + sql = f"SELECT * FROM {table}{where_clause} LIMIT 1" + async for row in self._safe_query(table, sql, where_params, where or {}): + return row + return None + + async def find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Find records with optional ordering.""" + where_clause, where_params = self._build_where(where) + order_clause = f" ORDER BY {order_by}" if order_by else "" + extra = (f" LIMIT {limit}" if limit else "") + ( + f" OFFSET {offset}" if offset else "" + ) + sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}" + return [ + row async for row in self._safe_query(table, sql, where_params, where or {}) + ] + + async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + where_clause, where_params = self._build_where(where) + sql = f"SELECT COUNT(*) FROM {table}{where_clause}" + gen = self._safe_query(table, sql, where_params, where or {}) + async for row in gen: + return next(iter(row.values()), 0) + return 0 + + async def exists(self, table: str, where: Dict[str, Any]) -> bool: + return (await self.count(table, where)) > 0 + + async def kv_set( + self, + key: str, + value: Any, + *, + table: str | None = None, + ) -> None: + tbl = table or self._KV_TABLE + json_val = json.dumps(value, default=str) + await self.upsert(tbl, {"value": json_val}, {"key": key}) + + async def kv_get( + self, + key: str, + *, + default: Any = None, + table: str | None = None, + ) -> Any: + tbl = table or self._KV_TABLE + row = await self.get(tbl, {"key": key}) + if not row: + return default + try: + return json.loads(row["value"]) + except Exception: + return default + + async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any: + """Execute raw SQL for complex queries like JOINs.""" + async with aiosqlite.connect(self._file) as db: + cursor = await db.execute(sql, params or ()) + await db.commit() + return cursor + + async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]: + """Execute raw SQL query and return results as list of dicts.""" + async with aiosqlite.connect(self._file) as db: + db.row_factory = aiosqlite.Row + async with db.execute(sql, params or ()) as cursor: + return [dict(row) async for row in cursor] + + async def query_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Dict[str, Any]]: + """Execute raw SQL query and return single result.""" + results = await self.query_raw(sql + " LIMIT 1", params) + return results[0] if results else None + + async def create_table(self, table: str, schema: Dict[str, str], constraints: Optional[List[str]] = None): + """Create table with custom schema and constraints.""" + columns = [f"`{col}` {dtype}" for col, dtype in schema.items()] + if constraints: + columns.extend(constraints) + columns_sql = ", ".join(columns) + + async with aiosqlite.connect(self._file) as db: + await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})") + await db.commit() + + async def insert_unique(self, table: str, args: Dict[str, Any], unique_fields: List[str]) -> Union[str, None]: + """Insert with unique constraint handling. Returns uid on success, None if duplicate.""" + try: + return await self.insert(table, args) + except aiosqlite.IntegrityError as e: + if "UNIQUE" in str(e): + return None + raise + + async def transaction(self): + """Context manager for transactions.""" + return TransactionContext(self._file) + + async def aggregate(self, table: str, function: str, column: str = "*", where: Optional[Dict[str, Any]] = None) -> Any: + """Perform aggregate functions like SUM, AVG, MAX, MIN.""" + where_clause, where_params = self._build_where(where) + sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}" + result = await self.query_one(sql, tuple(where_params)) + return result['result'] if result else None + + +class TransactionContext: + """Context manager for database transactions.""" + + def __init__(self, db_file: str): + self.db_file = db_file + self.conn = None + + async def __aenter__(self): + self.conn = await aiosqlite.connect(self.db_file) + self.conn.row_factory = aiosqlite.Row + await self.conn.execute("BEGIN") + return self.conn + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + await self.conn.commit() + else: + await self.conn.rollback() + await self.conn.close() + + +# Test cases remain the same but with additional tests for new functionality +class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.db_path = Path("temp_test.db") + if self.db_path.exists(): + self.db_path.unlink() + self.connector = AsyncDataSet(str(self.db_path)) + + async def asyncTearDown(self): + if self.db_path.exists(): + self.db_path.unlink() + + async def test_insert_and_get(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertIsNotNone(rec) + self.assertEqual(rec["name"], "John Doe") + + async def test_get_nonexistent(self): + result = await self.connector.get("people", {"name": "Jane Doe"}) + self.assertIsNone(result) + + async def test_update(self): + await self.connector.insert("people", {"name": "John Doe", "age": 30}) + await self.connector.update("people", {"age": 31}, {"name": "John Doe"}) + rec = await self.connector.get("people", {"name": "John Doe"}) + self.assertEqual(rec["age"], 31) + + async def test_order_by(self): + await self.connector.insert("people", {"name": "Alice", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 30}) + await self.connector.insert("people", {"name": "Charlie", "age": 20}) + + results = await self.connector.find("people", order_by="age ASC") + self.assertEqual(results[0]["name"], "Charlie") + self.assertEqual(results[-1]["name"], "Bob") + + async def test_raw_query(self): + await self.connector.insert("people", {"name": "John", "age": 30}) + await self.connector.insert("people", {"name": "Jane", "age": 25}) + + results = await self.connector.query_raw( + "SELECT * FROM people WHERE age > ?", (26,) + ) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "John") + + async def test_aggregate(self): + await self.connector.insert("people", {"name": "John", "age": 30}) + await self.connector.insert("people", {"name": "Jane", "age": 25}) + await self.connector.insert("people", {"name": "Bob", "age": 35}) + + avg_age = await self.connector.aggregate("people", "AVG", "age") + self.assertEqual(avg_age, 30) + + max_age = await self.connector.aggregate("people", "MAX", "age") + self.assertEqual(max_age, 35) + + async def test_insert_with_auto_id(self): + # Test auto-increment ID functionality + id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True) + id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True) + self.assertEqual(id2, id1 + 1) + + async def test_transaction(self): + async with self.connector.transaction() as conn: + await conn.execute("INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", + ("test-uid", "John", 30, "2024-01-01", "2024-01-01")) + # Transaction will be committed + + rec = await self.connector.get("people", {"name": "John"}) + self.assertIsNotNone(rec) + + async def test_create_custom_table(self): + schema = { + "id": "INTEGER PRIMARY KEY AUTOINCREMENT", + "username": "TEXT NOT NULL", + "email": "TEXT NOT NULL", + "score": "INTEGER DEFAULT 0" + } + constraints = ["UNIQUE(username)", "UNIQUE(email)"] + + await self.connector.create_table("users", schema, constraints) + + # Test that table was created with constraints + result = await self.connector.insert_unique( + "users", + {"username": "john", "email": "john@example.com"}, + ["username", "email"] + ) + self.assertIsNotNone(result) + + # Test duplicate insert + result = await self.connector.insert_unique( + "users", + {"username": "john", "email": "different@example.com"}, + ["username", "email"] + ) + self.assertIsNone(result) + + +if __name__ == "__main__": + unittest.main()