import asyncio
import json
import re
import tempfile
import unittest
from datetime import datetime, timezone
from pathlib import Path
from types import SimpleNamespace
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional
from uuid import uuid4

import aiosqlite
class AsyncDataSet:
    """Schemaless asynchronous record store backed by a single SQLite file.

    Tables and columns are created lazily. When a statement fails because a
    table or column does not exist yet, the schema is extended from the values
    being written (or filtered on) and the statement is retried. Every table
    created here also gets ``uid``, ``created_at``, ``updated_at`` and
    ``deleted_at`` columns. ``deleted_at`` is only ever set to NULL by
    :meth:`insert`; nothing in this class reads it. Each operation opens and
    closes its own ``aiosqlite`` connection to the database file.
    """

    # Table used by kv_set / kv_get when no explicit table is given.
    _KV_TABLE = "__kv_store"

    # Recognise SQLite's "missing column" / "missing table" error messages.
    # The captured group is the rest of the message (the name as SQLite
    # reports it).
    _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (.+)")
    _RE_NO_TABLE = re.compile(r"no such table: (.+)")

    def __init__(self, file: str):
        # Path of the SQLite database file; a new connection is opened per call.
        self._file = file

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _utc_iso() -> str:
        """Return the current UTC time as ISO-8601, second precision, ``Z`` suffix."""
        return (
            datetime.now(timezone.utc)
            .replace(microsecond=0)
            .isoformat()
            .replace("+00:00", "Z")
        )

    @staticmethod
    def _quote_ident(name: Any) -> str:
        """Quote *name* for use as an SQLite table or column identifier.

        Backticks are used instead of double quotes. SQLite may silently treat
        an unknown double-quoted identifier as a string literal, which would
        hide "no such column" errors. Embedded backticks are doubled (SQLite's
        escape inside quoted identifiers), so caller-supplied keys cannot
        break out of the quoting.
        """
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _py_to_sqlite_type(value: Any) -> str:
        """Return the SQLite column type used when a column is created from *value*."""
        # bool is an int subclass; it is listed only to make the mapping explicit.
        if isinstance(value, (bool, int)):
            return "INTEGER"
        if isinstance(value, float):
            return "REAL"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return "BLOB"
        # None and every other type are stored as TEXT.
        return "TEXT"

    @classmethod
    def _missing_column_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the column name from a missing-column error, or None."""
        m = cls._RE_NO_COLUMN.search(str(err))
        return m.group(1) if m else None

    @classmethod
    def _missing_table_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the table name from a "no such table" error, or None."""
        m = cls._RE_NO_TABLE.search(str(err))
        return m.group(1) if m else None

    @staticmethod
    def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
        """Build an AND-ed equality WHERE clause and its bound parameters.

        A ``None`` value becomes ``IS NULL`` because ``col = NULL`` is never
        true in SQL. Returns ``("", [])`` for an empty or missing filter.
        """
        if not where:
            return "", []
        clauses: List[str] = []
        params: List[Any] = []
        for key, value in where.items():
            column = AsyncDataSet._quote_ident(key)
            if value is None:
                clauses.append(f"{column} IS NULL")
            else:
                clauses.append(f"{column} = ?")
                params.append(value)
        return " WHERE " + " AND ".join(clauses), params

    # ------------------------------------------------------------ schema upkeep

    async def _ensure_column(self, table: str, name: str, value: Any) -> bool:
        """Add column *name* to *table*, typed from *value*.

        Returns True if the column was added. Returns False if SQLite reports
        it already exists; SQLite names are case-insensitive, so this can
        happen for a key differing only in case. Any other error propagates.
        """
        col_type = self._py_to_sqlite_type(value)
        sql = (
            f"ALTER TABLE {self._quote_ident(table)} "
            f"ADD COLUMN {self._quote_ident(name)} {col_type}"
        )
        async with aiosqlite.connect(self._file) as db:
            try:
                await db.execute(sql)
            except aiosqlite.OperationalError as err:
                if "duplicate column name" in str(err):
                    return False
                raise
            await db.commit()
        return True

    async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
        """Create *table* if absent.

        The new table gets the bookkeeping columns plus one column per key of
        *col_sources*, typed from its value.
        """
        cols: Dict[str, str] = {
            "uid": "TEXT PRIMARY KEY",
            "created_at": "TEXT",
            "updated_at": "TEXT",
            "deleted_at": "TEXT",
        }
        for key, val in col_sources.items():
            if key not in cols:
                cols[key] = self._py_to_sqlite_type(val)
        columns_sql = ", ".join(f"{self._quote_ident(k)} {t}" for k, t in cols.items())
        sql = f"CREATE TABLE IF NOT EXISTS {self._quote_ident(table)} ({columns_sql})"
        async with aiosqlite.connect(self._file) as db:
            await db.execute(sql)
            await db.commit()

    async def _repair_schema(
        self,
        err: aiosqlite.OperationalError,
        table: str,
        col_sources: Dict[str, Any],
    ) -> bool:
        """Try to fix the missing table/column behind *err*.

        If *err* is not a missing-table/column error, nothing is done.
        Otherwise *table* is created when it does not exist; if it does exist,
        every key of *col_sources* that is not yet a column is added.

        Returns True only when the schema actually changed, so the caller can
        retry. False means a retry cannot help and the error should be
        re-raised. Inspecting the real schema, instead of trusting the name
        parsed from the message, avoids retrying forever on names the message
        pattern does not capture exactly.
        """
        if (
            self._missing_table_from_error(err) is None
            and self._missing_column_from_error(err) is None
        ):
            return False
        async with aiosqlite.connect(self._file) as db:
            async with db.execute(
                f"PRAGMA table_info({self._quote_ident(table)})"
            ) as cursor:
                # table_info rows are (cid, name, type, notnull, dflt_value, pk).
                existing = {row[1] async for row in cursor}
        if not existing:
            await self._ensure_table(table, col_sources)
            return True
        changed = False
        for name, value in col_sources.items():
            if str(name) not in existing and await self._ensure_column(
                table, name, value
            ):
                changed = True
        return changed

    async def _safe_execute(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> aiosqlite.Cursor:
        """Execute and commit one write statement, growing the schema as needed.

        On a missing-table/column error the schema is repaired from
        *col_sources* and the statement retried. Any other error, or one that
        cannot be repaired, is re-raised. Returns the cursor so callers can
        read ``rowcount``.
        """
        params = list(params)  # the statement may run more than once
        while True:
            try:
                async with aiosqlite.connect(self._file) as db:
                    cursor = await db.execute(sql, params)
                    await db.commit()
                    return cursor
            except aiosqlite.OperationalError as err:
                if not await self._repair_schema(err, table, col_sources):
                    raise

    async def _safe_query(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Yield the rows of a SELECT as dicts, growing the schema as needed.

        Missing tables/columns named in *col_sources* are created and the
        query retried, mirroring :meth:`_safe_execute`. The connection stays
        open while the generator is suspended, so consumers should exhaust it
        (or ``aclose()`` it) promptly.
        """
        params = list(params)  # the query may run more than once
        while True:
            try:
                async with aiosqlite.connect(self._file) as db:
                    db.row_factory = aiosqlite.Row
                    async with db.execute(sql, params) as cursor:
                        async for row in cursor:
                            yield dict(row)
                return
            except aiosqlite.OperationalError as err:
                if not await self._repair_schema(err, table, col_sources):
                    raise

    # --------------------------------------------------------------- public API

    async def insert(self, table: str, args: Dict[str, Any]) -> str:
        """Insert one row and return its ``uid``.

        Keys in *args* override the generated bookkeeping values (including
        ``uid``). The returned uid is the one actually stored.
        """
        now = self._utc_iso()
        record = {
            "uid": str(uuid4()),
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **args,
        }
        cols = ", ".join(self._quote_ident(k) for k in record)
        placeholders = ", ".join(["?"] * len(record))
        sql = f"INSERT INTO {self._quote_ident(table)} ({cols}) VALUES ({placeholders})"
        await self._safe_execute(table, sql, list(record.values()), record)
        return record["uid"]

    async def update(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Set the columns in *args* on rows matching *where*; return the row count.

        ``updated_at`` is refreshed on every matched row. *args* itself is not
        modified. With empty *args* nothing is executed and 0 is returned.
        """
        if not args:
            return 0
        values = {**args, "updated_at": self._utc_iso()}
        set_clause = ", ".join(f"{self._quote_ident(k)} = ?" for k in values)
        where_clause, where_params = self._build_where(where)
        sql = f"UPDATE {self._quote_ident(table)} SET {set_clause}{where_clause}"
        params = [*values.values(), *where_params]
        cur = await self._safe_execute(
            table, sql, params, {**values, **(where or {})}
        )
        return cur.rowcount

    async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Delete rows matching *where* (all rows if None); return the row count."""
        where_clause, where_params = self._build_where(where)
        sql = f"DELETE FROM {self._quote_ident(table)}{where_clause}"
        cur = await self._safe_execute(table, sql, where_params, where or {})
        return cur.rowcount

    async def upsert(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> str | None:
        """Update rows matching *where* with *args*, or insert if none match.

        Returns the uid of the first matching row, or of the newly inserted
        row (built from *where* merged with *args*). The lookup happens before
        the update, so *args* may change columns used in *where*. An empty
        *args* does not insert a duplicate when a match exists.
        """
        existing = await self.get(table, where)
        if existing is None:
            return await self.insert(table, {**(where or {}), **args})
        await self.update(table, args, where)
        return existing.get("uid")

    async def get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the first row matching *where* as a dict, or None."""
        # find() drains the row generator, so its connection is closed at once.
        rows = await self.find(table, where, limit=1)
        return rows[0] if rows else None

    async def find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """Return all rows matching *where*, optionally paginated.

        ``limit=0`` means no limit; ``offset`` skips that many rows and works
        with or without a limit.
        """
        where_clause, params = self._build_where(where)
        sql = f"SELECT * FROM {self._quote_ident(table)}{where_clause}"
        if limit or offset:
            # SQLite only accepts OFFSET after LIMIT; LIMIT -1 means "no limit".
            sql += " LIMIT ? OFFSET ?"
            params += [limit if limit else -1, offset]
        return [
            row async for row in self._safe_query(table, sql, params, where or {})
        ]

    async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of rows matching *where*."""
        where_clause, params = self._build_where(where)
        sql = f"SELECT COUNT(*) AS n FROM {self._quote_ident(table)}{where_clause}"
        # Drain the generator fully so its connection is closed deterministically.
        rows = [
            row async for row in self._safe_query(table, sql, params, where or {})
        ]
        return rows[0]["n"] if rows else 0

    async def exists(self, table: str, where: Dict[str, Any]) -> bool:
        """Return True if at least one row matches *where*."""
        return (await self.get(table, where)) is not None

    async def kv_set(
        self,
        key: str,
        value: Any,
        *,
        table: str | None = None,
    ) -> None:
        """Store *value* as JSON under *key* (in ``__kv_store`` unless *table* is given).

        Values json cannot encode are stringified via ``default=str``.
        """
        tbl = table or self._KV_TABLE
        json_val = json.dumps(value, default=str)
        await self.upsert(tbl, {"value": json_val}, {"key": key})

    async def kv_get(
        self,
        key: str,
        *,
        default: Any = None,
        table: str | None = None,
    ) -> Any:
        """Return the JSON-decoded value stored under *key*, or *default*.

        *default* is returned when the key is absent or its stored value is
        missing, NULL, or not valid JSON.
        """
        tbl = table or self._KV_TABLE
        row = await self.get(tbl, {"key": key})
        if not row:
            return default
        try:
            return json.loads(row["value"])
        except (KeyError, TypeError, ValueError):
            # KeyError: no "value" column; TypeError: NULL/non-text value;
            # ValueError (incl. JSONDecodeError): malformed JSON.
            return default
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
    """End-to-end tests for AsyncDataSet against a throwaway SQLite file."""

    async def asyncSetUp(self):
        # Each test gets a private temporary directory, so runs never write to
        # the working directory or share state; it is removed after the test.
        # Cleanup errors are ignored in case a connection still holds the file
        # open when cleanup runs (which matters on Windows).
        tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        self.addCleanup(tmpdir.cleanup)
        self.db_path = Path(tmpdir.name) / "temp_test.db"
        self.connector = AsyncDataSet(str(self.db_path))

    async def test_insert_and_get(self):
        uid = await self.connector.insert("people", {"name": "John Doe", "age": 30})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["name"], "John Doe")
        self.assertEqual(rec["uid"], uid)

    async def test_get_nonexistent(self):
        result = await self.connector.get("people", {"name": "Jane Doe"})
        self.assertIsNone(result)

    async def test_update(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        updated = await self.connector.update(
            "people", {"age": 31}, {"name": "John Doe"}
        )
        self.assertEqual(updated, 1)
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertEqual(rec["age"], 31)

    async def test_upsert_insert(self):
        uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
        rec = await self.connector.get("people", {"uid": uid})
        self.assertEqual(rec["name"], "Alice")
        self.assertEqual(rec["age"], 22)

    async def test_upsert_update(self):
        uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
        uid_same = await self.connector.upsert("people", {"age": 23}, {"name": "Alice"})
        self.assertEqual(uid_same, uid)
        rec = await self.connector.get("people", {"uid": uid})
        self.assertEqual(rec["age"], 23)

    async def test_count_and_exists(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        total = await self.connector.count("people")
        self.assertEqual(total, 2)
        self.assertTrue(await self.connector.exists("people", {"name": "Alice"}))
        self.assertFalse(await self.connector.exists("people", {"name": "Nobody"}))

    async def test_find_pagination(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        rows = await self.connector.find("people", limit=1, offset=1)
        self.assertEqual(len(rows), 1)
        self.assertIn(rows[0]["name"], {"John Doe", "Alice"})

    async def test_kv_set_and_get(self):
        await self.connector.kv_set("config:threshold", {"limit": 10, "mode": "fast"})
        cfg = await self.connector.kv_get("config:threshold")
        self.assertEqual(cfg, {"limit": 10, "mode": "fast"})

    async def test_delete(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        deleted = await self.connector.delete("people", {"name": "John Doe"})
        self.assertEqual(deleted, 1)
        self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
        count = await self.connector.count("people")
        self.assertEqual(count, 1)

    async def test_inject(self):
        # "index" is an SQL keyword; it must still work as a column name.
        await self.connector.insert("people", {"name": "John Doe", "index": 30})
        await self.connector.insert("people", {"name": "Alice", "index": 23})
        await self.connector.delete("people", {"name": "John Doe", "index": 30})
        self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
        count = await self.connector.count("people")
        self.assertEqual(count, 1)

    async def test_insert_binary(self):
        await self.connector.insert("binaries", {"data": b"1234"})
        rec = await self.connector.get("binaries", {"data": b"1234"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["data"], b"1234")
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main(module="__main__")