Added ads.py.
This commit is contained in:
parent
7b4c1def45
commit
695efe6d2f
335
ads.py
Normal file
335
ads.py
Normal file
@ -0,0 +1,335 @@
|
||||
import re
|
||||
import json
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Iterable, List, Optional, AsyncGenerator
|
||||
from pathlib import Path
|
||||
import aiosqlite
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
import asyncio
|
||||
|
||||
|
||||
class AsyncDataSetConnector:
    """Schema-less async wrapper around a single SQLite file (via aiosqlite).

    Tables and columns are created lazily: when an INSERT/UPDATE/SELECT hits
    a missing table or column, the operational error is parsed and the schema
    is repaired (CREATE TABLE / ALTER TABLE) before retrying, with column
    types inferred from the Python values supplied.

    Every row automatically carries ``uid`` (primary key), ``created_at``,
    ``updated_at`` and ``deleted_at`` columns.

    NOTE(review): table and column *names* are interpolated directly into the
    SQL text (values are always bound as parameters).  Do not pass untrusted
    identifiers as table/column names.
    """

    # Default table used by the kv_set / kv_get helpers.
    _KV_TABLE = "__kv_store"

    def __init__(self, file: str):
        # Path of the SQLite database file.  A fresh connection is opened per
        # operation, so the instance itself never holds an open handle.
        self._file = file

    @staticmethod
    def _utc_iso() -> str:
        """Return the current UTC time as an ISO-8601 string ending in 'Z'."""
        return (
            datetime.now(timezone.utc)
            .replace(microsecond=0)
            .isoformat()
            .replace("+00:00", "Z")
        )

    @staticmethod
    def _py_to_sqlite_type(value: Any) -> str:
        """Map a Python value to the SQLite column type used when the column
        is auto-created.  Unknown types (and None) fall back to TEXT.

        ``bool`` must be checked before ``int`` because bool is a subclass
        of int.
        """
        if value is None:
            return "TEXT"
        if isinstance(value, bool):
            return "INTEGER"
        if isinstance(value, int):
            return "INTEGER"
        if isinstance(value, float):
            return "REAL"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return "BLOB"
        return "TEXT"

    async def _ensure_column(self, table: str, name: str, value: Any) -> None:
        """Add the missing column *name* to *table*, typed from *value*."""
        col_type = self._py_to_sqlite_type(value)
        async with aiosqlite.connect(self._file) as db:
            await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}")
            await db.commit()

    async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
        """Create *table* with the bookkeeping columns plus one column per
        entry of *col_sources* (types inferred from the values)."""
        cols: Dict[str, str] = {
            "uid": "TEXT PRIMARY KEY",
            "created_at": "TEXT",
            "updated_at": "TEXT",
            "deleted_at": "TEXT",
        }
        for key, val in col_sources.items():
            if key not in cols:
                cols[key] = self._py_to_sqlite_type(val)
        columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
        async with aiosqlite.connect(self._file) as db:
            await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
            await db.commit()

    # Patterns matching sqlite's "missing column/table" error text; used to
    # decide which schema repair to attempt before retrying a statement.
    _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
    _RE_NO_TABLE = re.compile(r"no such table: (\w+)")

    @classmethod
    def _missing_column_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the missing column name parsed from *err*, or None."""
        m = cls._RE_NO_COLUMN.search(str(err))
        return m.group(1) if m else None

    @classmethod
    def _missing_table_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the missing table name parsed from *err*, or None."""
        m = cls._RE_NO_TABLE.search(str(err))
        return m.group(1) if m else None

    async def _safe_execute(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> aiosqlite.Cursor:
        """Execute a write statement, auto-creating the table or any missing
        column named in *col_sources* and retrying until it succeeds.

        Returns the cursor; its connection is already closed, but attributes
        such as ``rowcount`` remain readable (they are plain attributes on
        the underlying sqlite3 cursor).
        """
        while True:
            try:
                async with aiosqlite.connect(self._file) as db:
                    cursor = await db.execute(sql, params)
                    await db.commit()
                    return cursor
            except aiosqlite.OperationalError as err:
                col = self._missing_column_from_error(err)
                if col:
                    # Only repair columns we know a value (hence a type) for;
                    # anything else is a genuine error.
                    if col not in col_sources:
                        raise
                    await self._ensure_column(table, col, col_sources[col])
                    continue
                tbl = self._missing_table_from_error(err)
                if tbl:
                    await self._ensure_table(tbl, col_sources)
                    continue
                raise

    async def _safe_query(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Yield result rows as dicts, auto-creating a missing table (which
        then yields zero rows) instead of raising."""
        while True:
            try:
                async with aiosqlite.connect(self._file) as db:
                    db.row_factory = aiosqlite.Row
                    async with db.execute(sql, params) as cursor:
                        async for row in cursor:
                            yield dict(row)
                    return
            except aiosqlite.OperationalError as err:
                tbl = self._missing_table_from_error(err)
                if tbl:
                    await self._ensure_table(tbl, col_sources)
                    continue
                raise

    @staticmethod
    def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
        """Build a parameterized ' WHERE ...' clause from an equality dict.

        Returns ("", []) for an empty/None filter.
        NOTE(review): a None value produces '`k` = ?' bound to NULL, which
        never matches in SQL (would need IS NULL); preserved as-is.
        """
        if not where:
            return "", []
        clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()])
        return " WHERE " + " AND ".join(clauses), list(vals)

    async def insert(self, table: str, args: Dict[str, Any]) -> str:
        """Insert *args* as a new row and return its generated uid."""
        uid = str(uuid4())
        now = self._utc_iso()
        record = {
            "uid": uid,
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **args,
        }
        cols = "`" + "`, `".join(record) + "`"
        qs = ", ".join(["?"] * len(record))
        sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
        await self._safe_execute(table, sql, list(record.values()), record)
        return uid

    async def update(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Update matching rows with *args* (plus a fresh updated_at) and
        return the number of affected rows.  No-op for empty *args*."""
        if not args:
            return 0
        # Build a copy instead of mutating the caller's dict (the original
        # wrote updated_at into *args* in place, a surprising side effect).
        payload = {**args, "updated_at": self._utc_iso()}
        set_clause = ", ".join(f"`{k}` = ?" for k in payload)
        where_clause, where_params = self._build_where(where)
        sql = f"UPDATE {table} SET {set_clause}{where_clause}"
        params = list(payload.values()) + where_params
        cur = await self._safe_execute(table, sql, params, {**payload, **(where or {})})
        return cur.rowcount

    async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Delete matching rows (all rows if *where* is empty); return count."""
        where_clause, where_params = self._build_where(where)
        sql = f"DELETE FROM {table}{where_clause}"
        cur = await self._safe_execute(table, sql, where_params, where or {})
        return cur.rowcount

    async def upsert(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> str | None:
        """Update the first matching row, or insert *where*+*args* as a new
        row.  Returns the row's uid (None only if the updated row vanished
        before re-read)."""
        affected = await self.update(table, args, where)
        if affected:
            rec = await self.get(table, where)
            return rec.get("uid") if rec else None
        merged = {**(where or {}), **args}
        return await self.insert(table, merged)

    async def get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the first matching row as a dict, or None."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
        async for row in self._safe_query(table, sql, where_params, where or {}):
            return row
        return None

    async def find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """Return all matching rows, optionally paginated via limit/offset."""
        where_clause, where_params = self._build_where(where)
        extra = ""
        if limit or offset:
            # SQLite only accepts OFFSET after LIMIT; LIMIT -1 means
            # "unlimited", which fixes the original's invalid SQL for
            # offset-without-limit.  int() guards against injection through
            # non-integer arguments.
            extra = f" LIMIT {int(limit) if limit else -1}"
            if offset:
                extra += f" OFFSET {int(offset)}"
        sql = f"SELECT * FROM {table}{where_clause}{extra}"
        return [
            row async for row in self._safe_query(table, sql, where_params, where or {})
        ]

    async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of matching rows (0 for a missing table)."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
        gen = self._safe_query(table, sql, where_params, where or {})
        async for row in gen:
            # Single-row, single-column result; the column name varies, so
            # take the first value.
            return next(iter(row.values()), 0)
        return 0

    async def exists(self, table: str, where: Dict[str, Any]) -> bool:
        """True if at least one row matches *where*."""
        return (await self.count(table, where)) > 0

    async def kv_set(
        self,
        key: str,
        value: Any,
        *,
        table: str | None = None,
    ) -> None:
        """Store *value* (JSON-serialized; non-JSON types fall back to str)
        under *key* in the key-value table."""
        tbl = table or self._KV_TABLE
        json_val = json.dumps(value, default=str)
        await self.upsert(tbl, {"value": json_val}, {"key": key})

    async def kv_get(
        self,
        key: str,
        *,
        default: Any = None,
        table: str | None = None,
    ) -> Any:
        """Fetch and JSON-decode the value stored under *key*; return
        *default* when the key is missing or the stored text is not valid
        JSON."""
        tbl = table or self._KV_TABLE
        row = await self.get(tbl, {"key": key})
        if not row:
            return default
        try:
            return json.loads(row["value"])
        except Exception:
            # Corrupt/non-JSON payloads degrade to the default rather than
            # propagating a decode error to callers.
            return default
|
||||
|
||||
|
||||
|
||||
|
||||
class TestAsyncDataSetConnector(unittest.IsolatedAsyncioTestCase):
    """End-to-end tests for AsyncDataSetConnector against a scratch file."""

    async def asyncSetUp(self):
        # Start from a clean database file for every test.
        self.db_path = Path("temp_test.db")
        if self.db_path.exists():
            self.db_path.unlink()
        self.connector = AsyncDataSetConnector(str(self.db_path))

    async def asyncTearDown(self):
        # Remove the scratch database so test runs don't leak files/state
        # (the original left temp_test.db behind after the last test).
        if self.db_path.exists():
            self.db_path.unlink()

    async def test_insert_and_get(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["name"], "John Doe")

    async def test_get_nonexistent(self):
        result = await self.connector.get("people", {"name": "Jane Doe"})
        self.assertIsNone(result)

    async def test_update(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.update("people", {"age": 31}, {"name": "John Doe"})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertEqual(rec["age"], 31)

    async def test_upsert_insert(self):
        uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
        rec = await self.connector.get("people", {"uid": uid})
        self.assertEqual(rec["name"], "Alice")
        self.assertEqual(rec["age"], 22)
        # (dropped the original's dead `self.alice_uid = uid` — test methods
        # run on isolated instances, so the attribute was never read)

    async def test_upsert_update(self):
        uid = await self.connector.upsert("people", {"age": 22}, {"name": "Alice"})
        uid_same = await self.connector.upsert("people", {"age": 23}, {"name": "Alice"})
        self.assertEqual(uid_same, uid)
        rec = await self.connector.get("people", {"uid": uid})
        self.assertEqual(rec["age"], 23)

    async def test_count_and_exists(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        total = await self.connector.count("people")
        self.assertEqual(total, 2)
        self.assertTrue(await self.connector.exists("people", {"name": "Alice"}))
        self.assertFalse(await self.connector.exists("people", {"name": "Nobody"}))

    async def test_find_pagination(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        rows = await self.connector.find("people", limit=1, offset=1)
        self.assertEqual(len(rows), 1)
        self.assertIn(rows[0]["name"], {"John Doe", "Alice"})

    async def test_kv_set_and_get(self):
        await self.connector.kv_set("config:threshold", {"limit": 10, "mode": "fast"})
        cfg = await self.connector.kv_get("config:threshold")
        self.assertEqual(cfg, {"limit": 10, "mode": "fast"})

    async def test_delete(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.insert("people", {"name": "Alice", "age": 23})
        await self.connector.delete("people", {"name": "John Doe"})
        self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
        count = await self.connector.count("people")
        self.assertEqual(count, 1)

    async def test_inject(self):
        # "index" is an SQL keyword; backtick-quoted identifiers must cope.
        await self.connector.insert("people", {"name": "John Doe", "index": 30})
        await self.connector.insert("people", {"name": "Alice", "index": 23})
        await self.connector.delete("people", {"name": "John Doe", "index": 30})
        self.assertIsNone(await self.connector.get("people", {"name": "John Doe"}))
        count = await self.connector.count("people")
        self.assertEqual(count, 1)

    async def test_insert_binary(self):
        # Assert the BLOB round-trips instead of just printing it (the
        # original only printed the fetched record).
        await self.connector.insert("binaries", {"data": b"1234"})
        rec = await self.connector.get("binaries", {"data": b"1234"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["data"], b"1234")
|
||||
if __name__ == "__main__":
    # Run the test suite when this module is executed directly.
    unittest.main()
|
Loading…
Reference in New Issue
Block a user