import re
import json
from uuid import uuid4
from datetime import datetime, timezone
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
AsyncGenerator,
Union,
Tuple,
Set,
)
from pathlib import Path
import aiosqlite
import unittest
from types import SimpleNamespace
import asyncio
class AsyncDataSet:
    """Schemaless, auto-migrating async CRUD layer over one SQLite file.

    ``insert``/``update``/``upsert`` create the target table and any missing
    column on demand; a new column's type is derived from the Python value
    that introduced it. Every method opens its own short-lived ``aiosqlite``
    connection, so an instance holds no connection between calls.

    Column names (dict keys) are quoted and all values are bound as query
    parameters. Table names, ``find(order_by=...)`` and the
    ``aggregate(function=..., column=...)`` arguments are interpolated into
    SQL verbatim: they must come from trusted code, never from user input.
    """

    _KV_TABLE = "__kv_store"
    # Columns every auto-created table carries, in this order.
    _DEFAULT_COLUMNS = {
        "uid": "TEXT PRIMARY KEY",
        "created_at": "TEXT",
        "updated_at": "TEXT",
        "deleted_at": "TEXT",
    }
    # Used to recover from sqlite "missing schema" errors.
    _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
    _RE_NO_TABLE = re.compile(r"no such table: (\w+)")
    # Detects a caller-declared primary key in create_table() input.
    _RE_PRIMARY_KEY = re.compile(r"\bPRIMARY\s+KEY\b", re.IGNORECASE)

    def __init__(self, file: str):
        """Remember the SQLite database path *file*; no connection is opened here."""
        self._file = file
        # table -> known column names. Empty results are never cached, so a
        # table created later (e.g. through raw SQL) is seen on the next lookup.
        self._table_columns_cache: Dict[str, Set[str]] = {}

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utc_iso() -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``."""
        return (
            datetime.now(timezone.utc)
            .replace(microsecond=0)
            .isoformat()
            .replace("+00:00", "Z")
        )

    @staticmethod
    def _quote_ident(name: Any) -> str:
        """Quote *name* as a SQLite identifier, doubling any embedded backtick."""
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _py_to_sqlite_type(value: Any) -> str:
        """Return the SQLite column type used when auto-creating a column for *value*."""
        if isinstance(value, int):  # also covers bool, a subclass of int
            return "INTEGER"
        if isinstance(value, float):
            return "REAL"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return "BLOB"
        return "TEXT"  # None, str and anything else

    @classmethod
    def _missing_column_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the column named in a missing-column error, or ``None``."""
        m = cls._RE_NO_COLUMN.search(str(err))
        return m.group(1) if m else None

    @classmethod
    def _missing_table_from_error(
        cls, err: aiosqlite.OperationalError
    ) -> Optional[str]:
        """Return the table named in a missing-table error, or ``None``."""
        m = cls._RE_NO_TABLE.search(str(err))
        return m.group(1) if m else None

    @classmethod
    def _build_where(cls, where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
        """Build an AND-ed equality WHERE clause and its parameter list.

        A ``None`` value becomes ``IS NULL`` (``col = NULL`` never matches in
        SQL) and contributes no parameter. Returns ``("", [])`` when *where*
        is empty or ``None``.
        """
        if not where:
            return "", []
        clauses: List[str] = []
        params: List[Any] = []
        for key, value in where.items():
            col = cls._quote_ident(key)
            if value is None:
                clauses.append(f"{col} IS NULL")
            else:
                clauses.append(f"{col} = ?")
                params.append(value)
        return " WHERE " + " AND ".join(clauses), params

    # ------------------------------------------------------------------
    # Schema management
    # ------------------------------------------------------------------

    async def _get_table_columns(self, table: str) -> Set[str]:
        """Return the column names of *table* (empty if it does not exist).

        Non-empty results are cached per instance. An OperationalError from
        the PRAGMA is treated as "unknown" and yields an empty set.
        """
        cached = self._table_columns_cache.get(table)
        if cached:
            return cached
        columns: Set[str] = set()
        try:
            async with aiosqlite.connect(self._file) as db:
                async with db.execute(f"PRAGMA table_info({table})") as cursor:
                    async for row in cursor:
                        columns.add(row[1])  # index 1 of table_info is the column name
        except aiosqlite.OperationalError:
            return set()
        if columns:
            self._table_columns_cache[table] = columns
        return columns

    async def _invalidate_column_cache(self, table: str) -> None:
        """Forget the cached column set for *table*."""
        self._table_columns_cache.pop(table, None)

    async def _ensure_column(self, table: str, name: str, value: Any) -> None:
        """Add column *name* (typed from *value*) to *table* unless it already exists.

        Raises:
            aiosqlite.OperationalError: for any ALTER TABLE failure other than
                "duplicate column name" (e.g. when the table does not exist).
        """
        col_type = self._py_to_sqlite_type(value)
        try:
            async with aiosqlite.connect(self._file) as db:
                await db.execute(
                    f"ALTER TABLE {table} ADD COLUMN {self._quote_ident(name)} {col_type}"
                )
                await db.commit()
        except aiosqlite.OperationalError as e:
            if "duplicate column name" not in str(e).lower():
                raise
        # Either we added the column or it already existed: the cache is stale.
        await self._invalidate_column_cache(table)

    async def _ensure_columns(self, table: str, col_sources: Dict[str, Any]) -> None:
        """Add every key of *col_sources* that is not yet a column of *table*.

        Doing this up front avoids one failed statement plus one retry per
        missing column in `_safe_execute`, whose retry budget would otherwise
        cap how many new columns a single write can introduce.
        """
        existing = await self._get_table_columns(table)
        for name, value in col_sources.items():
            if name not in existing:
                await self._ensure_column(table, name, value)

    async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
        """Create *table* if needed: default columns plus one per key of *col_sources*."""
        if self._table_columns_cache.get(table):
            # Known to exist; CREATE TABLE IF NOT EXISTS would be a no-op.
            return
        cols = dict(self._DEFAULT_COLUMNS)
        for key, val in col_sources.items():
            if key not in cols:
                cols[key] = self._py_to_sqlite_type(val)
        columns_sql = ", ".join(f"{self._quote_ident(k)} {t}" for k, t in cols.items())
        async with aiosqlite.connect(self._file) as db:
            await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
            await db.commit()
        await self._invalidate_column_cache(table)

    async def _table_exists(self, table: str) -> bool:
        """Return True if a table named *table* is present in ``sqlite_master``."""
        async with aiosqlite.connect(self._file) as db:
            async with db.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
            ) as cursor:
                return await cursor.fetchone() is not None

    async def _filter_existing_columns(
        self, table: str, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return *data* restricted to keys that are existing columns of *table*.

        *data* is returned unchanged when the table does not exist or its
        columns cannot be determined.
        """
        if not await self._table_exists(table):
            return data
        existing_columns = await self._get_table_columns(table)
        if not existing_columns:
            return data
        return {k: v for k, v in data.items() if k in existing_columns}

    # ------------------------------------------------------------------
    # Execution primitives
    # ------------------------------------------------------------------

    async def _safe_execute(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
        max_retries: int = 10,
    ) -> aiosqlite.Cursor:
        """Execute and commit a write, creating a missing table/column and retrying.

        Each failed attempt repairs one missing table or column. *col_sources*
        supplies sample values for typing new columns (unknown names -> TEXT).

        Returns:
            The cursor of the successful statement. Its connection is already
            closed, so only ``rowcount``/``lastrowid`` should be read from it.

        Raises:
            aiosqlite.OperationalError: for errors other than a missing
                table or column.
            Exception: when *max_retries* attempts all fail.
        """
        for _ in range(max_retries):
            try:
                async with aiosqlite.connect(self._file) as db:
                    cursor = await db.execute(sql, params)
                    await db.commit()
                    return cursor
            except aiosqlite.OperationalError as err:
                col = self._missing_column_from_error(err)
                if col:
                    await self._ensure_column(table, col, col_sources.get(col))
                    continue
                tbl = self._missing_table_from_error(err)
                if tbl:
                    # Drop any stale cache entry so _ensure_table really runs CREATE.
                    await self._invalidate_column_cache(tbl)
                    await self._ensure_table(tbl, col_sources)
                    continue
                raise
        raise Exception(f"Max retries ({max_retries}) exceeded")

    async def _safe_query(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Yield the rows of a read query as dicts; yield nothing if schema is missing.

        A missing table, or a reference to a missing column, is treated as
        "no rows" rather than an error. *col_sources* is currently unused.
        Callers should exhaust the generator so the connection closes promptly.
        """
        if not await self._table_exists(table):
            return
        try:
            async with aiosqlite.connect(self._file) as db:
                db.row_factory = aiosqlite.Row
                async with db.execute(sql, params) as cursor:
                    async for row in cursor:
                        yield dict(row)
        except aiosqlite.OperationalError as err:
            if self._missing_table_from_error(err) or "no such column" in str(err).lower():
                return
            raise

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    async def insert(
        self, table: str, args: Dict[str, Any], return_id: bool = False
    ) -> Union[str, int]:
        """Insert one row and return its ``uid`` (or an integer id, see below).

        ``uid``, ``created_at``, ``updated_at`` and ``deleted_at`` are filled
        automatically; keys in *args* override them. The table and any
        missing columns are created first.

        Args:
            return_id: when true and *args* has no ``"id"`` key, ensure an
                INTEGER ``id`` column, store the new row's SQLite rowid in it
                and return that integer. If *args* provides ``"id"``, the
                ``uid`` string is returned as usual. Unless the table declares
                ``id INTEGER PRIMARY KEY AUTOINCREMENT``, rowids can be reused
                after the highest row is deleted.
        """
        uid = str(uuid4())
        now = self._utc_iso()
        record = {
            "uid": uid,
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **args,
        }
        await self._ensure_table(table, record)
        auto_id = return_id and "id" not in args
        # SQLite cannot ADD a PRIMARY KEY column to an existing table, so the
        # auto id is a plain INTEGER column filled from the rowid after insert.
        await self._ensure_columns(table, {**record, "id": 0} if auto_id else record)

        cols = ", ".join(self._quote_ident(k) for k in record)
        placeholders = ", ".join(["?"] * len(record))
        sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders})"
        cursor = await self._safe_execute(table, sql, list(record.values()), record)
        if not auto_id:
            return uid

        rowid = cursor.lastrowid
        # No-op when `id` is an INTEGER PRIMARY KEY (it already equals rowid).
        await self._safe_execute(
            table,
            f"UPDATE {table} SET `id` = ? WHERE rowid = ? AND `id` IS NULL",
            [rowid, rowid],
            {"id": rowid},
        )
        return rowid

    async def update(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> int:
        """Update rows matching *where* (all rows if ``None``) with *args*.

        ``updated_at`` is set to the current time. Columns named in *args* or
        *where* that do not exist yet are added first. *args* is not mutated.

        Returns:
            The number of rows updated; 0 if *args* is empty or the table
            does not exist.
        """
        if not args:
            return 0
        if not await self._table_exists(table):
            return 0
        values = {**args, "updated_at": self._utc_iso()}
        all_cols = {**values, **(where or {})}
        await self._ensure_columns(table, all_cols)

        set_clause = ", ".join(f"{self._quote_ident(k)} = ?" for k in values)
        where_clause, where_params = self._build_where(where)
        sql = f"UPDATE {table} SET {set_clause}{where_clause}"
        params = list(values.values()) + where_params
        cur = await self._safe_execute(table, sql, params, all_cols)
        return cur.rowcount

    async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Delete rows matching *where* (all rows if ``None``); return the count.

        Returns 0 if the table does not exist.
        """
        if not await self._table_exists(table):
            return 0
        where_clause, where_params = self._build_where(where)
        sql = f"DELETE FROM {table}{where_clause}"
        cur = await self._safe_execute(table, sql, where_params, where or {})
        return cur.rowcount

    async def upsert(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> str | None:
        """Update rows matching *where*, or insert ``where | args`` if none matched.

        Returns the ``uid`` of the first matching row after an update, or the
        new row's ``uid`` after an insert. *args* is not mutated.

        Raises:
            ValueError: if *args* is empty.
        """
        if not args:
            raise ValueError("Nothing to update. Empty dict given.")
        values = {**args, "updated_at": self._utc_iso()}
        affected = await self.update(table, values, where)
        if affected:
            rec = await self.get(table, where)
            return rec.get("uid") if rec else None
        return await self.insert(table, {**(where or {}), **values})

    async def get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the first row matching *where*, or ``None`` if there is none."""
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
        # Exhaust the generator (at most one row) so its connection is closed now.
        rows = [row async for row in self._safe_query(table, sql, where_params, where or {})]
        return rows[0] if rows else None

    async def find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
        order_by: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return all rows matching *where*, optionally ordered and paginated.

        Args:
            limit: maximum number of rows; 0 means unlimited.
            offset: number of matching rows to skip (works without *limit*).
            order_by: raw SQL ORDER BY expression such as ``"age DESC"``
                (trusted input only).
        """
        where_clause, where_params = self._build_where(where)
        order_clause = f" ORDER BY {order_by}" if order_by else ""
        page_clause = ""
        if limit or offset:
            # SQLite only accepts OFFSET after LIMIT; a negative LIMIT means "no limit".
            page_clause = f" LIMIT {int(limit) if limit else -1}"
            if offset:
                page_clause += f" OFFSET {int(offset)}"
        sql = f"SELECT * FROM {table}{where_clause}{order_clause}{page_clause}"
        return [
            row async for row in self._safe_query(table, sql, where_params, where or {})
        ]

    async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        """Return the number of rows matching *where* (0 if the table is missing)."""
        if not await self._table_exists(table):
            return 0
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
        rows = [row async for row in self._safe_query(table, sql, where_params, where or {})]
        return next(iter(rows[0].values()), 0) if rows else 0

    async def exists(self, table: str, where: Dict[str, Any]) -> bool:
        """Return True if at least one row matches *where*."""
        return (await self.count(table, where)) > 0

    # ------------------------------------------------------------------
    # Key/value store
    # ------------------------------------------------------------------

    async def kv_set(
        self,
        key: str,
        value: Any,
        *,
        table: str | None = None,
    ) -> None:
        """Store *value* as JSON under *key* (non-JSON types are stringified)."""
        tbl = table or self._KV_TABLE
        json_val = json.dumps(value, default=str)
        await self.upsert(tbl, {"value": json_val}, {"key": key})

    async def kv_get(
        self,
        key: str,
        *,
        default: Any = None,
        table: str | None = None,
    ) -> Any:
        """Return the JSON-decoded value stored under *key*, or *default*.

        *default* is also returned when the stored value is missing, NULL or
        not valid JSON.
        """
        tbl = table or self._KV_TABLE
        row = await self.get(tbl, {"key": key})
        if not row:
            return default
        try:
            return json.loads(row["value"])
        except (KeyError, TypeError, ValueError):  # JSONDecodeError is a ValueError
            return default

    # ------------------------------------------------------------------
    # Raw SQL
    # ------------------------------------------------------------------

    async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
        """Execute and commit raw SQL; return the cursor.

        The cursor's connection has already been closed when it is returned,
        so read result rows with `query_raw`/`query_one` instead.
        """
        async with aiosqlite.connect(self._file) as db:
            cursor = await db.execute(sql, params or ())
            await db.commit()
            return cursor

    async def query_raw(
        self, sql: str, params: Optional[Tuple] = None
    ) -> List[Dict[str, Any]]:
        """Run a raw query and return all rows as dicts ([] on OperationalError)."""
        try:
            async with aiosqlite.connect(self._file) as db:
                db.row_factory = aiosqlite.Row
                async with db.execute(sql, params or ()) as cursor:
                    return [dict(row) async for row in cursor]
        except aiosqlite.OperationalError:
            # Best-effort: a failing query (e.g. missing table) reads as "no rows".
            return []

    async def query_one(
        self, sql: str, params: Optional[Tuple] = None
    ) -> Optional[Dict[str, Any]]:
        """Run a raw query and return its first row as a dict, or ``None``.

        *sql* is executed unchanged, so it may already contain LIMIT or a
        trailing semicolon. Returns ``None`` on OperationalError, like
        `query_raw` returns [].
        """
        try:
            async with aiosqlite.connect(self._file) as db:
                db.row_factory = aiosqlite.Row
                async with db.execute(sql, params or ()) as cursor:
                    row = await cursor.fetchone()
                    return dict(row) if row is not None else None
        except aiosqlite.OperationalError:
            return None

    # ------------------------------------------------------------------
    # Schema / constraints / transactions / aggregates
    # ------------------------------------------------------------------

    async def create_table(
        self,
        table: str,
        schema: Dict[str, str],
        constraints: Optional[List[str]] = None,
    ):
        """Create *table* (if absent) with the default columns plus *schema*.

        Args:
            schema: column name -> SQL type/definition, e.g. ``"TEXT NOT NULL"``.
            constraints: extra table constraints, e.g. ``"UNIQUE(email)"``.

        If *schema* or *constraints* declares its own PRIMARY KEY and does not
        redefine ``uid``, ``uid`` is created as ``TEXT UNIQUE``. SQLite allows
        only one primary key per table.
        """
        full_schema = dict(self._DEFAULT_COLUMNS)
        declared = [str(d) for d in schema.values()] + list(constraints or [])
        if "uid" not in schema and any(self._RE_PRIMARY_KEY.search(d) for d in declared):
            full_schema["uid"] = "TEXT UNIQUE"
        full_schema.update(schema)

        columns = [f"{self._quote_ident(col)} {dtype}" for col, dtype in full_schema.items()]
        if constraints:
            columns.extend(constraints)
        columns_sql = ", ".join(columns)
        async with aiosqlite.connect(self._file) as db:
            await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
            await db.commit()
        await self._invalidate_column_cache(table)

    async def insert_unique(
        self, table: str, args: Dict[str, Any], unique_fields: List[str]
    ) -> Union[str, None]:
        """Insert *args*; return its uid, or ``None`` if a UNIQUE constraint rejects it.

        *unique_fields* is currently unused. Uniqueness is enforced only by
        UNIQUE constraints declared on the table (e.g. through
        ``create_table(..., constraints=[...])``); without one, duplicates are
        inserted. Other integrity errors are re-raised.
        """
        try:
            return await self.insert(table, args)
        except aiosqlite.IntegrityError as e:
            if "UNIQUE" in str(e):
                return None
            raise

    def transaction(self) -> "TransactionContext":
        """Return an async context manager around one explicit transaction.

        Use as ``async with ds.transaction() as conn:``. The yielded
        ``aiosqlite`` connection commits on normal exit and rolls back if the
        body raises. ``async with await ds.transaction()`` (needed when this
        method was a coroutine) still works because the returned context
        manager is awaitable.
        """
        return TransactionContext(self._file)

    async def aggregate(
        self,
        table: str,
        function: str,
        column: str = "*",
        where: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Return ``function(column)`` over rows matching *where* (e.g. SUM, AVG, MAX).

        *function* and *column* are inserted into SQL verbatim (trusted input
        only). Returns ``None`` if the table does not exist or the query fails.
        """
        if not await self._table_exists(table):
            return None
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
        result = await self.query_one(sql, tuple(where_params))
        return result["result"] if result else None


class TransactionContext:
    """Async context manager: one aiosqlite connection inside BEGIN ... COMMIT/ROLLBACK."""

    def __init__(self, db_file: str):
        self.db_file = db_file
        self.conn = None  # aiosqlite connection while inside the ``async with``

    def __await__(self):
        """Make ``await ctx`` evaluate to the context itself (legacy call style)."""

        async def _resolve() -> "TransactionContext":
            return self

        return _resolve().__await__()

    async def __aenter__(self):
        """Open a connection (rows as ``aiosqlite.Row``), issue BEGIN, return it."""
        self.conn = await aiosqlite.connect(self.db_file)
        self.conn.row_factory = aiosqlite.Row
        try:
            await self.conn.execute("BEGIN")
        except BaseException:
            # Don't leak the connection if the transaction cannot start.
            await self.conn.close()
            self.conn = None
            raise
        return self.conn

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Commit on success, roll back on error; always close. Never suppresses."""
        if self.conn is None:
            return None
        try:
            if exc_type is None:
                await self.conn.commit()
            else:
                await self.conn.rollback()
        finally:
            # Close even if commit/rollback itself raised.
            await self.conn.close()
            self.conn = None
        return None
# Test cases: the original suite is unchanged, with additional tests covering the newer functionality.
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
    """Integration tests for AsyncDataSet, each run against a fresh SQLite file."""

    async def asyncSetUp(self):
        # Local import: only the tests need it. A private temp directory keeps
        # parallel runs from sharing one database file in the working directory.
        import tempfile

        self._tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        self.db_path = Path(self._tmpdir.name) / "temp_test.db"
        self.connector = AsyncDataSet(str(self.db_path))

    async def asyncTearDown(self):
        self._tmpdir.cleanup()

    async def test_insert_and_get(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["name"], "John Doe")

    async def test_get_nonexistent(self):
        result = await self.connector.get("people", {"name": "Jane Doe"})
        self.assertIsNone(result)

    async def test_update(self):
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.update("people", {"age": 31}, {"name": "John Doe"})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertEqual(rec["age"], 31)

    async def test_order_by(self):
        await self.connector.insert("people", {"name": "Alice", "age": 25})
        await self.connector.insert("people", {"name": "Bob", "age": 30})
        await self.connector.insert("people", {"name": "Charlie", "age": 20})
        results = await self.connector.find("people", order_by="age ASC")
        self.assertEqual(results[0]["name"], "Charlie")
        self.assertEqual(results[-1]["name"], "Bob")

    async def test_raw_query(self):
        await self.connector.insert("people", {"name": "John", "age": 30})
        await self.connector.insert("people", {"name": "Jane", "age": 25})
        results = await self.connector.query_raw(
            "SELECT * FROM people WHERE age > ?", (26,)
        )
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["name"], "John")

    async def test_aggregate(self):
        await self.connector.insert("people", {"name": "John", "age": 30})
        await self.connector.insert("people", {"name": "Jane", "age": 25})
        await self.connector.insert("people", {"name": "Bob", "age": 35})
        avg_age = await self.connector.aggregate("people", "AVG", "age")
        self.assertEqual(avg_age, 30)
        max_age = await self.connector.aggregate("people", "MAX", "age")
        self.assertEqual(max_age, 35)

    async def test_insert_with_auto_id(self):
        # Test auto-increment ID functionality
        id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True)
        id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True)
        self.assertEqual(id2, id1 + 1)

    async def test_transaction(self):
        # The raw INSERT below does not auto-create tables, so create it first.
        await self.connector.create_table("people", {"name": "TEXT", "age": "INTEGER"})
        async with self.connector.transaction() as conn:
            await conn.execute(
                "INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
                ("test-uid", "John", 30, "2024-01-01", "2024-01-01"),
            )
        # Leaving the block without an exception commits the transaction.
        rec = await self.connector.get("people", {"name": "John"})
        self.assertIsNotNone(rec)

    async def test_create_custom_table(self):
        schema = {
            "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
            "username": "TEXT NOT NULL",
            "email": "TEXT NOT NULL",
            "score": "INTEGER DEFAULT 0",
        }
        constraints = ["UNIQUE(username)", "UNIQUE(email)"]
        await self.connector.create_table("users", schema, constraints)
        # Test that table was created with constraints
        result = await self.connector.insert_unique(
            "users",
            {"username": "john", "email": "john@example.com"},
            ["username", "email"],
        )
        self.assertIsNotNone(result)
        # Test duplicate insert
        result = await self.connector.insert_unique(
            "users",
            {"username": "john", "email": "different@example.com"},
            ["username", "email"],
        )
        self.assertIsNone(result)

    async def test_missing_table_operations(self):
        # Test operations on non-existent tables
        self.assertEqual(await self.connector.count("nonexistent"), 0)
        self.assertEqual(await self.connector.find("nonexistent"), [])
        self.assertIsNone(await self.connector.get("nonexistent"))
        self.assertFalse(await self.connector.exists("nonexistent", {"id": 1}))
        self.assertEqual(await self.connector.delete("nonexistent"), 0)
        self.assertEqual(
            await self.connector.update("nonexistent", {"name": "test"}), 0
        )

    async def test_auto_column_creation(self):
        # Insert with new columns that don't exist yet
        await self.connector.insert(
            "dynamic", {"col1": "value1", "col2": 42, "col3": 3.14}
        )
        # Add more columns in next insert
        await self.connector.insert(
            "dynamic", {"col1": "value2", "col4": True, "col5": None}
        )
        # All records should be retrievable
        records = await self.connector.find("dynamic")
        self.assertEqual(len(records), 2)

    async def test_kv_roundtrip(self):
        # kv_set inserts on first use and updates in place afterwards.
        await self.connector.kv_set("config", {"debug": True, "level": 3})
        self.assertEqual(
            await self.connector.kv_get("config"), {"debug": True, "level": 3}
        )
        await self.connector.kv_set("config", [1, 2])
        self.assertEqual(await self.connector.kv_get("config"), [1, 2])
        self.assertEqual(await self.connector.kv_get("missing", default="x"), "x")
if __name__ == "__main__":
    # Script entry point: discover and run the TestCase classes in this module.
    unittest.main(module="__main__")