|
import asyncio
import json
import re
import tempfile
import unittest
from datetime import datetime, timezone
from pathlib import Path
from types import SimpleNamespace
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional
from uuid import uuid4

import aiosqlite
|
|
|
|
|
|
class AsyncDataSet:
    """Schemaless, dict-based CRUD and key/value store on top of one SQLite file.

    Tables are created the first time they are touched, and columns are added
    on demand when a write references one that does not exist yet. Column
    types are inferred from the Python values (see ``_py_to_sqlite_type``).
    Every table carries the bookkeeping columns ``uid`` (primary key),
    ``created_at``, ``updated_at`` and ``deleted_at``. Timestamps are UTC
    ISO-8601 strings. Each database round trip opens its own ``aiosqlite``
    connection.

    Table names are interpolated into SQL verbatim and must be trusted
    identifiers. Column names are quoted, with embedded backticks escaped.
    """

    # Table used by kv_set / kv_get when no explicit table is given.
    _KV_TABLE = "__kv_store"

    def __init__(self, file: str):
        # Path of the SQLite database file; a new connection is opened per operation.
        self._file = file

    @staticmethod
    def _utc_iso() -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``."""
        return (
            datetime.now(timezone.utc)
            .replace(microsecond=0)
            .isoformat()
            .replace("+00:00", "Z")
        )

    @staticmethod
    def _quote(name: Any) -> str:
        """Quote *name* as a backtick-delimited SQLite identifier.

        Embedded backticks are doubled, which is SQLite's escape for a quote
        character inside a quoted identifier. Plain names render exactly as
        ``name`` wrapped in backticks.
        """
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _py_to_sqlite_type(value: Any) -> str:
        """Map a Python value to the SQLite column type used to store it.

        ``bool``/``int`` -> INTEGER, ``float`` -> REAL, bytes-like -> BLOB,
        anything else (including ``None``) -> TEXT.
        """
        if value is None:
            return "TEXT"
        # bool is a subclass of int, so it is covered here as well.
        if isinstance(value, int):
            return "INTEGER"
        if isinstance(value, float):
            return "REAL"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return "BLOB"
        return "TEXT"

    async def _ensure_column(self, table: str, name: str, value: Any) -> None:
        """Add column *name* to *table*, typed after the sample *value*.

        A "duplicate column name" error is ignored. It means another
        coroutine added the same column in the meantime, which is already the
        desired end state.
        """
        col_type = self._py_to_sqlite_type(value)
        sql = f"ALTER TABLE {table} ADD COLUMN {self._quote(name)} {col_type}"
        async with aiosqlite.connect(self._file) as db:
            try:
                await db.execute(sql)
            except aiosqlite.OperationalError as err:
                if "duplicate column name" not in str(err):
                    raise
                return
            await db.commit()

    async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
        """Create *table* if it does not exist.

        The table gets the bookkeeping columns plus one column per key of
        *col_sources*, typed after the corresponding value. Bookkeeping
        columns keep their fixed types even if *col_sources* names them.
        """
        cols: Dict[str, str] = {
            "uid": "TEXT PRIMARY KEY",
            "created_at": "TEXT",
            "updated_at": "TEXT",
            "deleted_at": "TEXT",
        }
        for key, val in col_sources.items():
            if key not in cols:
                cols[key] = self._py_to_sqlite_type(val)
        columns_sql = ", ".join(f"{self._quote(k)} {t}" for k, t in cols.items())
        async with aiosqlite.connect(self._file) as db:
            await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
            await db.commit()

    # Error-message shapes used to detect a missing column or table so the
    # schema can be repaired and the statement retried.
    _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
    _RE_NO_TABLE = re.compile(r"no such table: (\w+)")

    @classmethod
    def _missing_column_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the column name a "missing column" error refers to, else None."""
        m = cls._RE_NO_COLUMN.search(str(err))
        return m.group(1) if m else None

    @classmethod
    def _missing_table_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the table name a "no such table" error refers to, else None."""
        m = cls._RE_NO_TABLE.search(str(err))
        return m.group(1) if m else None

    async def _repair_schema(
        self,
        table: str,
        err: aiosqlite.OperationalError,
        col_sources: Dict[str, Any],
        attempted: set[tuple[str, str]],
        *,
        add_columns: bool = True,
    ) -> bool:
        """Try to fix the schema problem reported by *err*.

        Returns True when a repair was made and the statement should be
        retried. Returns False when the error is not a fixable schema error,
        when the missing column is not described by *col_sources*, when
        *add_columns* is False, or when the same repair was already attempted.
        The last case stops a retry loop from spinning forever.
        """
        col = self._missing_column_from_error(err)
        if col:
            key = ("column", col)
            if not add_columns or col not in col_sources or key in attempted:
                return False
            attempted.add(key)
            await self._ensure_column(table, col, col_sources[col])
            return True
        tbl = self._missing_table_from_error(err)
        if tbl:
            key = ("table", tbl)
            if key in attempted:
                return False
            attempted.add(key)
            await self._ensure_table(tbl, col_sources)
            return True
        return False

    async def _safe_execute(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> aiosqlite.Cursor:
        """Execute and commit a write statement, creating missing tables/columns.

        *col_sources* maps column names to sample values used to type any
        table or column that has to be created before retrying. The returned
        cursor's connection is already closed, so only attributes such as
        ``rowcount`` should be read from it.

        Raises:
            aiosqlite.OperationalError: for errors that are not repairable
                schema errors, or when a repair does not resolve the error.
        """
        # Materialize once: a one-shot iterator would be empty on retry.
        params = list(params)
        attempted: set[tuple[str, str]] = set()
        while True:
            try:
                async with aiosqlite.connect(self._file) as db:
                    cursor = await db.execute(sql, params)
                    await db.commit()
                    return cursor
            except aiosqlite.OperationalError as err:
                if not await self._repair_schema(table, err, col_sources, attempted):
                    raise

    async def _safe_query(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Run a SELECT and yield each row as a plain dict.

        A missing table is created from *col_sources* and the query retried,
        so a query against a new table yields no rows. Missing columns are
        not added on reads; that error propagates. A retry only happens
        before the first row is yielded, so callers never receive duplicates.
        """
        params = list(params)
        attempted: set[tuple[str, str]] = set()
        while True:
            yielded = False
            try:
                async with aiosqlite.connect(self._file) as db:
                    db.row_factory = aiosqlite.Row
                    async with db.execute(sql, params) as cursor:
                        async for row in cursor:
                            yielded = True
                            yield dict(row)
                return
            except aiosqlite.OperationalError as err:
                if yielded or not await self._repair_schema(
                    table, err, col_sources, attempted, add_columns=False
                ):
                    raise

    @classmethod
    def _build_where(cls, where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
        """Build an ``AND``-joined equality WHERE clause and its parameters.

        Returns ``("", [])`` for an empty/None filter. A ``None`` value is
        rendered as ``IS NULL``, because ``col = NULL`` is never true in SQL.
        """
        if not where:
            return "", []
        clauses: List[str] = []
        params: List[Any] = []
        for key, val in where.items():
            col = cls._quote(key)
            if val is None:
                clauses.append(f"{col} IS NULL")
            else:
                clauses.append(f"{col} = ?")
                params.append(val)
        return " WHERE " + " AND ".join(clauses), params

    async def insert(self, table: str, args: Dict[str, Any]) -> str:
        """Insert *args* as a new row and return the row's ``uid``.

        ``uid``, ``created_at`` and ``updated_at`` are generated, and
        ``deleted_at`` is NULL. Keys in *args* override these defaults. The
        table and any new columns are created as needed.
        """
        now = self._utc_iso()
        record = {
            "uid": str(uuid4()),
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **args,
        }
        cols = ", ".join(self._quote(k) for k in record)
        placeholders = ", ".join(["?"] * len(record))
        sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders})"
        await self._safe_execute(table, sql, list(record.values()), record)
        # args may override "uid"; report the key that was actually stored.
        return record["uid"]

    async def update(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Set the columns in *args* on rows matching *where*.

        ``updated_at`` is refreshed as well. Returns the number of affected
        rows, or 0 without touching the database when *args* is empty. A
        None/empty *where* updates every row. *args* itself is not modified.
        """
        if not args:
            return 0
        # Copy instead of mutating the caller's dict.
        values = {**args, "updated_at": self._utc_iso()}
        set_clause = ", ".join(f"{self._quote(k)} = ?" for k in values)
        where_clause, where_params = self._build_where(where)
        sql = f"UPDATE {table} SET {set_clause}{where_clause}"
        params = [*values.values(), *where_params]
        cur = await self._safe_execute(table, sql, params, {**values, **(where or {})})
        return cur.rowcount

    async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Physically delete rows matching *where* and return how many were removed.

        A None/empty *where* deletes every row. ``deleted_at`` is not used.
        """
        where_clause, where_params = self._build_where(where)
        sql = f"DELETE FROM {table}{where_clause}"
        cur = await self._safe_execute(table, sql, where_params, where or {})
        return cur.rowcount

    async def upsert(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> str | None:
        """Update rows matching *where* with *args*, or insert ``where | args``.

        Returns the ``uid`` of the first matching row when one exists,
        otherwise the ``uid`` of the newly inserted row.
        """
        if not args:
            # update() is a no-op for empty args and would report 0 rows,
            # which would insert a duplicate; look for an existing match instead.
            existing = await self.get(table, where)
            if existing:
                return existing.get("uid")
        elif await self.update(table, args, where):
            rec = await self.get(table, where)
            return rec.get("uid") if rec else None
        return await self.insert(table, {**(where or {}), **args})

    async def get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the first row matching *where* as a dict, or None if none match."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
        # Drain the (at most one-row) generator rather than returning from
        # inside the loop, so its connection is closed before we return.
        rows = [
            row async for row in self._safe_query(table, sql, where_params, where or {})
        ]
        return rows[0] if rows else None

    async def find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """Return all rows matching *where*, optionally paginated.

        ``limit=0`` means no limit and ``offset=0`` means no offset. No ORDER
        BY is applied, so row order is whatever SQLite returns.
        """
        where_clause, params = self._build_where(where)
        sql = f"SELECT * FROM {table}{where_clause}"
        if limit or offset:
            # SQLite accepts OFFSET only after LIMIT; LIMIT -1 means "unbounded".
            sql += " LIMIT ?"
            params.append(limit if limit else -1)
            if offset:
                sql += " OFFSET ?"
                params.append(offset)
        return [row async for row in self._safe_query(table, sql, params, where or {})]

    async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of rows matching *where* (0 for a new table)."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
        rows = [
            row async for row in self._safe_query(table, sql, where_params, where or {})
        ]
        return next(iter(rows[0].values()), 0) if rows else 0

    async def exists(self, table: str, where: Dict[str, Any]) -> bool:
        """Return True if at least one row matches *where*."""
        return (await self.count(table, where)) > 0

    async def kv_set(
        self,
        key: str,
        value: Any,
        *,
        table: str | None = None,
    ) -> None:
        """Store *value* JSON-encoded under *key*, replacing any previous value.

        Objects JSON cannot encode natively are stored as their ``str()``.
        """
        tbl = table or self._KV_TABLE
        json_val = json.dumps(value, default=str)
        await self.upsert(tbl, {"value": json_val}, {"key": key})

    async def kv_get(
        self,
        key: str,
        *,
        default: Any = None,
        table: str | None = None,
    ) -> Any:
        """Return the JSON-decoded value stored under *key*.

        Returns *default* if the key is absent or its stored value is not
        valid JSON (or is NULL).
        """
        tbl = table or self._KV_TABLE
        row = await self.get(tbl, {"key": key})
        if not row:
            return default
        try:
            return json.loads(row["value"])
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; TypeError covers NULL.
            return default
|
|
|
|
|
|
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
    """Integration tests for AsyncDataSet, each against a fresh SQLite file."""

    async def asyncSetUp(self):
        # A private temporary directory per test keeps tests isolated and
        # leaves no database file behind in the working directory.
        tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        self.addCleanup(tmpdir.cleanup)
        self.db_path = Path(tmpdir.name) / "temp_test.db"
        self.connector = AsyncDataSet(str(self.db_path))

    async def test_insert_and_get(self):
        uid = await self.connector.insert("people", {"name": "John Doe", "age": 30})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["name"], "John Doe")
        self.assertEqual(rec["uid"], uid)

    async def test_get_nonexistent(self):
        result = await self.connector.get("people", {"name": "Jane Doe"})
        self.assertIsNone(result)

    async def test_update(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        affected = await self.connector.update(
            "people", {"age": 31}, {"name": "John Doe"}
        )
        self.assertEqual(affected, 1)
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertEqual(rec["age"], 31)

    async def test_upsert_insert(self):
        uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
        rec = await self.connector.get("people", {"uid": uid})
        self.assertEqual(rec["name"], "Alice")
        self.assertEqual(rec["age"], 22)

    async def test_upsert_update(self):
        uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
        uid_same = await self.connector.upsert("people", {"age": 23}, {"name": "Alice"})
        self.assertEqual(uid_same, uid)
        rec = await self.connector.get("people", {"uid": uid})
        self.assertEqual(rec["age"], 23)

    async def test_count_and_exists(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        self.assertEqual(await self.connector.count("people"), 2)
        self.assertTrue(await self.connector.exists("people", {"name": "Alice"}))
        self.assertFalse(await self.connector.exists("people", {"name": "Nobody"}))

    async def test_find_pagination(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        rows = await self.connector.find("people", limit=1, offset=1)
        self.assertEqual(len(rows), 1)
        self.assertIn(rows[0]["name"], {"John Doe", "Alice"})

    async def test_kv_set_and_get(self):
        await self.connector.kv_set("config:threshold", {"limit": 10, "mode": "fast"})
        cfg = await self.connector.kv_get("config:threshold")
        self.assertEqual(cfg, {"limit": 10, "mode": "fast"})
        self.assertEqual(await self.connector.kv_get("missing", default=5), 5)

    async def test_delete(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        removed = await self.connector.delete("people", {"name": "John Doe"})
        self.assertEqual(removed, 1)
        self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
        self.assertEqual(await self.connector.count("people"), 1)

    async def test_inject(self):
        """A column named after an SQL keyword (``index``) must be quoted correctly."""
        await self.connector.insert("people", {"name": "John Doe", "index": 30})
        await self.connector.insert("people", {"name": "Alice", "index": 23})
        await self.connector.delete("people", {"name": "John Doe", "index": 30})
        self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
        self.assertEqual(await self.connector.count("people"), 1)

    async def test_insert_binary(self):
        await self.connector.insert("binaries", {"data": b"1234"})
        rec = await self.connector.get("binaries", {"data": b"1234"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["data"], b"1234")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the TestCase classes defined in this module when executed as a script.
    unittest.main(module="__main__")
|