|
import re
|
|
import json
|
|
from uuid import uuid4
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, Iterable, List, Optional, AsyncGenerator
|
|
from pathlib import Path
|
|
import aiosqlite
|
|
import unittest
|
|
from types import SimpleNamespace
|
|
import asyncio
|
|
|
|
|
|
class AsyncDataSet:
|
|
|
|
_KV_TABLE = "__kv_store"
|
|
|
|
def __init__(self, file: str):
|
|
self._file = file
|
|
|
|
@staticmethod
|
|
def _utc_iso() -> str:
|
|
return (
|
|
datetime.now(timezone.utc)
|
|
.replace(microsecond=0)
|
|
.isoformat()
|
|
.replace("+00:00", "Z")
|
|
)
|
|
|
|
@staticmethod
|
|
def _py_to_sqlite_type(value: Any) -> str:
|
|
if value is None:
|
|
return "TEXT"
|
|
if isinstance(value, bool):
|
|
return "INTEGER"
|
|
if isinstance(value, int):
|
|
return "INTEGER"
|
|
if isinstance(value, float):
|
|
return "REAL"
|
|
if isinstance(value, (bytes, bytearray, memoryview)):
|
|
return "BLOB"
|
|
return "TEXT"
|
|
|
|
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
|
|
col_type = self._py_to_sqlite_type(value)
|
|
async with aiosqlite.connect(self._file) as db:
|
|
await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}")
|
|
await db.commit()
|
|
|
|
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
|
|
cols: Dict[str, str] = {
|
|
"uid": "TEXT PRIMARY KEY",
|
|
"created_at": "TEXT",
|
|
"updated_at": "TEXT",
|
|
"deleted_at": "TEXT",
|
|
}
|
|
for key, val in col_sources.items():
|
|
if key not in cols:
|
|
cols[key] = self._py_to_sqlite_type(val)
|
|
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
|
|
async with aiosqlite.connect(self._file) as db:
|
|
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
|
await db.commit()
|
|
|
|
_RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
|
|
_RE_NO_TABLE = re.compile(r"no such table: (\w+)")
|
|
|
|
@classmethod
|
|
def _missing_column_from_error(
|
|
cls, err: aiosqlite.OperationalError
|
|
) -> Optional[str]:
|
|
m = cls._RE_NO_COLUMN.search(str(err))
|
|
return m.group(1) if m else None
|
|
|
|
@classmethod
|
|
def _missing_table_from_error(
|
|
cls, err: aiosqlite.OperationalError
|
|
) -> Optional[str]:
|
|
m = cls._RE_NO_TABLE.search(str(err))
|
|
return m.group(1) if m else None
|
|
|
|
async def _safe_execute(
|
|
self,
|
|
table: str,
|
|
sql: str,
|
|
params: Iterable[Any],
|
|
col_sources: Dict[str, Any],
|
|
) -> aiosqlite.Cursor:
|
|
while True:
|
|
try:
|
|
async with aiosqlite.connect(self._file) as db:
|
|
cursor = await db.execute(sql, params)
|
|
await db.commit()
|
|
return cursor
|
|
except aiosqlite.OperationalError as err:
|
|
col = self._missing_column_from_error(err)
|
|
if col:
|
|
if col not in col_sources:
|
|
raise
|
|
await self._ensure_column(table, col, col_sources[col])
|
|
continue
|
|
tbl = self._missing_table_from_error(err)
|
|
if tbl:
|
|
await self._ensure_table(tbl, col_sources)
|
|
continue
|
|
raise
|
|
|
|
async def _safe_query(
|
|
self,
|
|
table: str,
|
|
sql: str,
|
|
params: Iterable[Any],
|
|
col_sources: Dict[str, Any],
|
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
|
while True:
|
|
try:
|
|
async with aiosqlite.connect(self._file) as db:
|
|
db.row_factory = aiosqlite.Row
|
|
async with db.execute(sql, params) as cursor:
|
|
async for row in cursor:
|
|
yield dict(row)
|
|
return
|
|
except aiosqlite.OperationalError as err:
|
|
tbl = self._missing_table_from_error(err)
|
|
if tbl:
|
|
await self._ensure_table(tbl, col_sources)
|
|
continue
|
|
raise
|
|
|
|
@staticmethod
|
|
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
|
|
if not where:
|
|
return "", []
|
|
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()])
|
|
return " WHERE " + " AND ".join(clauses), list(vals)
|
|
|
|
async def insert(self, table: str, args: Dict[str, Any]) -> str:
|
|
uid = str(uuid4())
|
|
now = self._utc_iso()
|
|
record = {
|
|
"uid": uid,
|
|
"created_at": now,
|
|
"updated_at": now,
|
|
"deleted_at": None,
|
|
**args,
|
|
}
|
|
cols = "`" + "`, `".join(record) + "`"
|
|
qs = ", ".join(["?"] * len(record))
|
|
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
|
|
await self._safe_execute(table, sql, list(record.values()), record)
|
|
return uid
|
|
|
|
async def update(
|
|
self,
|
|
table: str,
|
|
args: Dict[str, Any],
|
|
where: Optional[Dict[str, Any]] = None,
|
|
) -> int:
|
|
if not args:
|
|
return 0
|
|
args["updated_at"] = self._utc_iso()
|
|
set_clause = ", ".join(f"`{k}` = ?" for k in args)
|
|
where_clause, where_params = self._build_where(where)
|
|
sql = f"UPDATE {table} SET {set_clause}{where_clause}"
|
|
params = list(args.values()) + where_params
|
|
cur = await self._safe_execute(table, sql, params, {**args, **(where or {})})
|
|
return cur.rowcount
|
|
|
|
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
|
where_clause, where_params = self._build_where(where)
|
|
sql = f"DELETE FROM {table}{where_clause}"
|
|
cur = await self._safe_execute(table, sql, where_params, where or {})
|
|
return cur.rowcount
|
|
|
|
async def upsert(
|
|
self,
|
|
table: str,
|
|
args: Dict[str, Any],
|
|
where: Optional[Dict[str, Any]] = None,
|
|
) -> str | None:
|
|
if not args:
|
|
raise ValueError("Nothing to update. Empty dict given.")
|
|
args['updated_at'] = str(datetime.now())
|
|
affected = await self.update(table, args, where)
|
|
if affected:
|
|
rec = await self.get(table, where)
|
|
return rec.get("uid") if rec else None
|
|
merged = {**(where or {}), **args}
|
|
return await self.insert(table, merged)
|
|
|
|
async def get(
|
|
self, table: str, where: Optional[Dict[str, Any]] = None
|
|
) -> Optional[Dict[str, Any]]:
|
|
where_clause, where_params = self._build_where(where)
|
|
sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
|
|
async for row in self._safe_query(table, sql, where_params, where or {}):
|
|
return row
|
|
return None
|
|
|
|
async def find(
|
|
self,
|
|
table: str,
|
|
where: Optional[Dict[str, Any]] = None,
|
|
*,
|
|
limit: int = 0,
|
|
offset: int = 0,
|
|
) -> List[Dict[str, Any]]:
|
|
where_clause, where_params = self._build_where(where)
|
|
extra = (f" LIMIT {limit}" if limit else "") + (
|
|
f" OFFSET {offset}" if offset else ""
|
|
)
|
|
sql = f"SELECT * FROM {table}{where_clause}{extra}"
|
|
return [
|
|
row async for row in self._safe_query(table, sql, where_params, where or {})
|
|
]
|
|
|
|
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
|
where_clause, where_params = self._build_where(where)
|
|
sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
|
|
gen = self._safe_query(table, sql, where_params, where or {})
|
|
async for row in gen:
|
|
return next(iter(row.values()), 0)
|
|
return 0
|
|
|
|
async def exists(self, table: str, where: Dict[str, Any]) -> bool:
|
|
return (await self.count(table, where)) > 0
|
|
|
|
async def kv_set(
|
|
self,
|
|
key: str,
|
|
value: Any,
|
|
*,
|
|
table: str | None = None,
|
|
) -> None:
|
|
tbl = table or self._KV_TABLE
|
|
json_val = json.dumps(value, default=str)
|
|
await self.upsert(tbl, {"value": json_val}, {"key": key})
|
|
|
|
async def kv_get(
|
|
self,
|
|
key: str,
|
|
*,
|
|
default: Any = None,
|
|
table: str | None = None,
|
|
) -> Any:
|
|
tbl = table or self._KV_TABLE
|
|
row = await self.get(tbl, {"key": key})
|
|
if not row:
|
|
return default
|
|
try:
|
|
return json.loads(row["value"])
|
|
except Exception:
|
|
return default
|
|
|
|
from typing import Any, Dict, Iterable, List, Optional, AsyncGenerator
|
|
|
|
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
|
|
|
async def asyncSetUp(self):
|
|
self.db_path = Path("temp_test.db")
|
|
if self.db_path.exists():
|
|
self.db_path.unlink()
|
|
self.connector = AsyncDataSet(str(self.db_path))
|
|
|
|
async def test_insert_and_get(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
rec = await self.connector.get("people", {"name": "John Doe"})
|
|
self.assertIsNotNone(rec)
|
|
self.assertEqual(rec["name"], "John Doe")
|
|
|
|
async def test_get_nonexistent(self):
|
|
result = await self.connector.get("people", {"name": "Jane Doe"})
|
|
self.assertIsNone(result)
|
|
|
|
async def test_update(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
await self.connector.update("people", {"age": 31}, {"name": "John Doe"})
|
|
rec = await self.connector.get("people", {"name": "John Doe"})
|
|
self.assertEqual(rec["age"], 31)
|
|
|
|
async def test_upsert_insert(self):
|
|
uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
|
|
rec = await self.connector.get("people", {"uid": uid})
|
|
self.assertEqual(rec["name"], "Alice")
|
|
self.assertEqual(rec["age"], 22)
|
|
self.alice_uid = uid
|
|
|
|
async def test_upsert_update(self):
|
|
uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
|
|
await asyncio.sleep(1.1)
|
|
uid_same = await self.connector.upsert("people", {"age": 23}, {"name": "Alice"})
|
|
self.assertEqual(uid_same, uid)
|
|
rec = await self.connector.get("people", {"uid": uid})
|
|
self.assertEqual(rec["age"], 23)
|
|
|
|
async def test_count_and_exists(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
await self.connector.insert("people", {"name": "Alice", "age": 23})
|
|
total = await self.connector.count("people")
|
|
self.assertEqual(total, 2)
|
|
self.assertTrue(await self.connector.exists("people", {"name": "Alice"}))
|
|
self.assertFalse(await self.connector.exists("people", {"name": "Nobody"}))
|
|
|
|
async def test_find_pagination(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
await self.connector.insert("people", {"name": "Alice", "age": 23})
|
|
rows = await self.connector.find("people", limit=1, offset=1)
|
|
self.assertEqual(len(rows), 1)
|
|
self.assertIn(rows[0]["name"], {"John Doe", "Alice"})
|
|
|
|
async def test_kv_set_and_get(self):
|
|
await self.connector.kv_set("config:threshold", {"limit": 10, "mode": "fast"})
|
|
cfg = await self.connector.kv_get("config:threshold")
|
|
self.assertEqual(cfg, {"limit": 10, "mode": "fast"})
|
|
|
|
async def test_delete(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
await self.connector.insert("people", {"name": "Alice", "age": 23})
|
|
await self.connector.delete("people", {"name": "John Doe"})
|
|
self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
|
|
count = await self.connector.count("people")
|
|
self.assertEqual(count, 1)
|
|
|
|
async def test_inject(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "index": 30})
|
|
await self.connector.insert("people", {"name": "Alice", "index": 23})
|
|
await self.connector.delete("people", {"name": "John Doe", "index": 30})
|
|
self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
|
|
count = await self.connector.count("people")
|
|
self.assertEqual(count, 1)
|
|
|
|
async def test_insert_binary(self):
|
|
await self.connector.insert("binaries", {"data": b"1234"})
|
|
print(await self.connector.get("binaries", {"data": b"1234"}))
|
|
|
|
async def asyncSetUp(self):
|
|
self.db_path = Path("temp_test.db")
|
|
if self.db_path.exists():
|
|
self.db_path.unlink()
|
|
self.connector = AsyncDataSet(str(self.db_path))
|
|
|
|
async def asyncTearDown(self):
|
|
if self.db_path.exists():
|
|
self.db_path.unlink()
|
|
|
|
async def test_insert_and_get(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
rec = await self.connector.get("people", {"name": "John Doe"})
|
|
self.assertIsNotNone(rec)
|
|
self.assertEqual(rec["name"], "John Doe")
|
|
self.assertEqual(rec["age"], 30)
|
|
self.assertIsNotNone(rec["uid"])
|
|
self.assertIsNotNone(rec["created_at"])
|
|
self.assertIsNotNone(rec["updated_at"])
|
|
self.assertIsNone(rec["deleted_at"])
|
|
|
|
async def test_get_nonexistent(self):
|
|
result = await self.connector.get("people", {"name": "Jane Doe"})
|
|
self.assertIsNone(result)
|
|
|
|
async def test_update(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
await asyncio.sleep(1.1)
|
|
await self.connector.update("people", {"age": 31}, {"name": "John Doe"})
|
|
rec = await self.connector.get("people", {"name": "John Doe"})
|
|
self.assertEqual(rec["age"], 31)
|
|
self.assertNotEqual(rec["created_at"], rec["updated_at"])
|
|
|
|
async def test_upsert_insert(self):
|
|
uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
|
|
rec = await self.connector.get("people", {"uid": uid})
|
|
self.assertEqual(rec["name"], "Alice")
|
|
self.assertEqual(rec["age"], 22)
|
|
self.assertIsNotNone(rec["uid"])
|
|
self.assertIsNotNone(rec["created_at"])
|
|
|
|
async def test_upsert_update(self):
|
|
uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
|
|
await asyncio.sleep(1.1)
|
|
uid_same = await self.connector.upsert("people", {"age": 23}, {"name": "Alice"})
|
|
self.assertEqual(uid_same, uid)
|
|
rec = await self.connector.get("people", {"uid": uid})
|
|
self.assertEqual(rec["age"], 23)
|
|
self.assertNotEqual(rec["created_at"], rec["updated_at"])
|
|
|
|
async def test_count_and_exists(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
await self.connector.insert("people", {"name": "Alice", "age": 23})
|
|
total = await self.connector.count("people")
|
|
self.assertEqual(total, 2)
|
|
self.assertTrue(await self.connector.exists("people", {"name": "Alice"}))
|
|
self.assertFalse(await self.connector.exists("people", {"name": "Nobody"}))
|
|
|
|
async def test_find_pagination(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
await self.connector.insert("people", {"name": "Alice", "age": 23})
|
|
rows = await self.connector.find("people", limit=1, offset=1)
|
|
self.assertEqual(len(rows), 1)
|
|
self.assertIn(rows[0]["name"], {"John Doe", "Alice"})
|
|
|
|
async def test_kv_set_and_get(self):
|
|
await self.connector.kv_set("config:threshold", {"limit": 10, "mode": "fast"})
|
|
cfg = await self.connector.kv_get("config:threshold")
|
|
self.assertEqual(cfg, {"limit": 10, "mode": "fast"})
|
|
|
|
async def test_delete(self):
|
|
await self.connector.insert("people", {"name": "John Doe", "age": 30})
|
|
await self.connector.insert("people", {"name": "Alice", "age": 23})
|
|
await self.connector.delete("people", {"name": "John Doe"})
|
|
self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
|
|
count = await self.connector.count("people")
|
|
self.assertEqual(count, 1)
|
|
|
|
async def test_insert_binary(self):
|
|
await self.connector.insert("binaries", {"data": b"1234"})
|
|
rec = await self.connector.get("binaries", {"data": b"1234"})
|
|
self.assertIsNotNone(rec)
|
|
self.assertEqual(rec["data"], b"1234")
|
|
|
|
async def test_insert_multiple_types(self):
|
|
data = {
|
|
"name": "Test",
|
|
"age": 25,
|
|
"is_active": True,
|
|
"score": 98.5,
|
|
"data": b"test",
|
|
"metadata": None
|
|
}
|
|
await self.connector.insert("mixed", data)
|
|
rec = await self.connector.get("mixed", {"name": "Test"})
|
|
self.assertEqual(rec["name"], "Test")
|
|
self.assertEqual(rec["age"], 25)
|
|
self.assertEqual(rec["is_active"], 1)
|
|
self.assertEqual(rec["score"], 98.5)
|
|
self.assertEqual(rec["data"], b"test")
|
|
self.assertIsNone(rec["metadata"])
|
|
|
|
async def test_update_nonexistent(self):
|
|
result = await self.connector.update("people", {"age": 40}, {"name": "Nobody"})
|
|
self.assertEqual(result, 0)
|
|
|
|
async def test_delete_nonexistent(self):
|
|
result = await self.connector.delete("people", {"name": "Nobody"})
|
|
self.assertEqual(result, 0)
|
|
|
|
async def test_kv_get_default(self):
|
|
result = await self.connector.kv_get("nonexistent_key", default="default")
|
|
self.assertEqual(result, "default")
|
|
|
|
async def test_kv_set_custom_table(self):
|
|
await self.connector.kv_set("test_key", "test_value", table="custom_kv")
|
|
result = await self.connector.kv_get("test_key", table="custom_kv")
|
|
self.assertEqual(result, "test_value")
|
|
|
|
async def test_find_empty_table(self):
|
|
result = await self.connector.find("empty_table")
|
|
self.assertEqual(result, [])
|
|
|
|
async def test_count_empty_table(self):
|
|
result = await self.connector.count("empty_table")
|
|
self.assertEqual(result, 0)
|
|
|
|
async def test_exists_empty_table(self):
|
|
result = await self.connector.exists("empty_table", {"name": "Test"})
|
|
self.assertFalse(result)
|
|
|
|
async def test_insert_duplicate_column_handling(self):
|
|
await self.connector.insert("people", {"name": "John", "age": 30})
|
|
await self.connector.insert("people", {"name": "Jane", "age": 25})
|
|
rec = await self.connector.get("people", {"name": "John"})
|
|
self.assertEqual(rec["age"], 30)
|
|
|
|
async def test_upsert_no_where(self):
|
|
uid = await self.connector.upsert("people", {"name": "Bob", "age": 40})
|
|
rec = await self.connector.get("people", {"uid": uid})
|
|
self.assertEqual(rec["name"], "Bob")
|
|
self.assertEqual(rec["age"], 40)
|
|
|
|
async def test_find_with_where(self):
|
|
await self.connector.insert("people", {"name": "John", "age": 30})
|
|
await self.connector.insert("people", {"name": "Jane", "age": 25})
|
|
await self.connector.insert("people", {"name": "Bob", "age": 30})
|
|
result = await self.connector.find("people", {"age": 30})
|
|
self.assertEqual(len(result), 2)
|
|
names = {row["name"] for row in result}
|
|
self.assertEqual(names, {"John", "Bob"})
|
|
|
|
async def test_update_multiple_columns(self):
|
|
await self.connector.insert("people", {"name": "John", "age": 30, "city": "NY"})
|
|
await self.connector.update("people", {"age": 31, "city": "LA"}, {"name": "John"})
|
|
rec = await self.connector.get("people", {"name": "John"})
|
|
self.assertEqual(rec["age"], 31)
|
|
self.assertEqual(rec["city"], "LA")
|
|
|
|
async def test_insert_invalid_column_name(self):
|
|
await self.connector.insert("people", {"invalid column": "value"})
|
|
|
|
async def test_kv_set_complex_object(self):
|
|
complex_data = {
|
|
"numbers": [1, 2, 3],
|
|
"nested": {"a": 1, "b": {"c": 2}},
|
|
"bool": True,
|
|
"null": None
|
|
}
|
|
await self.connector.kv_set("complex", complex_data)
|
|
result = await self.connector.kv_get("complex")
|
|
self.assertEqual(result, complex_data)
|
|
|
|
async def test_concurrent_operations(self):
|
|
async def insert_task(name, age):
|
|
await self.connector.insert("people", {"name": name, "age": age})
|
|
|
|
tasks = [
|
|
insert_task(f"Person_{i}", 20 + i) for i in range(5)
|
|
]
|
|
await asyncio.gather(*tasks)
|
|
count = await self.connector.count("people")
|
|
self.assertEqual(count, 5)
|
|
|
|
async def test_column_type_preservation(self):
|
|
data = {"name": "Test", "number": 42, "float": 3.14, "bool": True,"bytes":b"data"}
|
|
await self.connector.insert("types", data)
|
|
rec = await self.connector.get("types", {"name": "Test"})
|
|
self.assertIsInstance(rec["number"], int)
|
|
self.assertIsInstance(rec["float"], float)
|
|
self.assertIsInstance(rec["bool"], int) # SQLite stores bool as int
|
|
self.assertIsInstance(rec["bytes"], bytes) # SQLite stores bool as int
|
|
self.assertEqual(rec["bool"], 1)
|
|
|
|
async def test_empty_update(self):
|
|
await self.connector.insert("people", {"name": "John", "age": 30})
|
|
result = await self.connector.update("people", {}, {"name": "John"})
|
|
self.assertEqual(result, 0)
|
|
|
|
async def test_find_with_zero_limit(self):
|
|
await self.connector.insert("people", {"name": "John", "age": 30})
|
|
result = await self.connector.find("people", limit=0)
|
|
self.assertEqual(len(result), 1)
|
|
|
|
async def test_upsert_empty_args(self):
|
|
with self.assertRaises(ValueError):
|
|
await self.connector.upsert("people", {}, {"name": "John"})
|
|
|
|
async def test_invalid_table_name(self):
|
|
with self.assertRaises(aiosqlite.OperationalError):
|
|
await self.connector.insert("invalid table name", {"name": "John"})
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|