This commit is contained in:
retoor 2025-08-07 21:13:56 +02:00
parent 7334c59aa7
commit 49517fc71d
3 changed files with 1529 additions and 140 deletions

751
ads.py
View File

@ -2,15 +2,39 @@ import re
import json import json
from uuid import uuid4 from uuid import uuid4
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Dict, Iterable, List, Optional, AsyncGenerator, Union, Tuple, Set from typing import (
Any,
Dict,
Iterable,
List,
Optional,
AsyncGenerator,
Union,
Tuple,
Set,
)
from pathlib import Path from pathlib import Path
import aiosqlite import aiosqlite
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import asyncio import asyncio
import aiohttp
from aiohttp import web
import socket
import os
import atexit
class AsyncDataSet: class AsyncDataSet:
"""
Distributed AsyncDataSet with client-server model over Unix sockets.
Parameters:
- file: Path to the SQLite database file
- socket_path: Path to Unix socket for inter-process communication (default: "ads.sock")
- max_concurrent_queries: Maximum concurrent database operations (default: 1)
Set to 1 to prevent "database is locked" errors with SQLite
"""
_KV_TABLE = "__kv_store" _KV_TABLE = "__kv_store"
_DEFAULT_COLUMNS = { _DEFAULT_COLUMNS = {
@ -20,10 +44,487 @@ class AsyncDataSet:
"deleted_at": "TEXT", "deleted_at": "TEXT",
} }
def __init__(self, file: str): def __init__(
self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1
):
self._file = file self._file = file
self._socket_path = socket_path
self._table_columns_cache: Dict[str, Set[str]] = {} self._table_columns_cache: Dict[str, Set[str]] = {}
self._is_server = None # None means not initialized yet
self._server = None
self._runner = None
self._client_session = None
self._initialized = False
self._init_lock = None # Will be created when needed
self._max_concurrent_queries = max_concurrent_queries
self._db_semaphore = None # Will be created when needed
self._db_connection = None # Persistent database connection
async def _ensure_initialized(self):
"""Ensure the instance is initialized as server or client."""
if self._initialized:
return
# Create lock if needed (first time in async context)
if self._init_lock is None:
self._init_lock = asyncio.Lock()
async with self._init_lock:
if self._initialized:
return
await self._initialize()
# Create semaphore for database operations (server only)
if self._is_server:
self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries)
self._initialized = True
async def _initialize(self):
"""Initialize as server or client."""
try:
# Try to create Unix socket
if os.path.exists(self._socket_path):
# Check if socket is active
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.connect(self._socket_path)
sock.close()
# Socket is active, we're a client
self._is_server = False
await self._setup_client()
except (ConnectionRefusedError, FileNotFoundError):
# Socket file exists but not active, remove and become server
os.unlink(self._socket_path)
self._is_server = True
await self._setup_server()
else:
# No socket file, become server
self._is_server = True
await self._setup_server()
except Exception as e:
# Fallback to client mode
self._is_server = False
await self._setup_client()
# Establish persistent database connection
self._db_connection = await aiosqlite.connect(self._file)
async def _setup_server(self):
"""Setup aiohttp server."""
app = web.Application()
# Add routes for all methods
app.router.add_post("/insert", self._handle_insert)
app.router.add_post("/update", self._handle_update)
app.router.add_post("/delete", self._handle_delete)
app.router.add_post("/upsert", self._handle_upsert)
app.router.add_post("/get", self._handle_get)
app.router.add_post("/find", self._handle_find)
app.router.add_post("/count", self._handle_count)
app.router.add_post("/exists", self._handle_exists)
app.router.add_post("/kv_set", self._handle_kv_set)
app.router.add_post("/kv_get", self._handle_kv_get)
app.router.add_post("/execute_raw", self._handle_execute_raw)
app.router.add_post("/query_raw", self._handle_query_raw)
app.router.add_post("/query_one", self._handle_query_one)
app.router.add_post("/create_table", self._handle_create_table)
app.router.add_post("/insert_unique", self._handle_insert_unique)
app.router.add_post("/aggregate", self._handle_aggregate)
self._runner = web.AppRunner(app)
await self._runner.setup()
self._server = web.UnixSite(self._runner, self._socket_path)
await self._server.start()
# Register cleanup
def cleanup():
if os.path.exists(self._socket_path):
os.unlink(self._socket_path)
atexit.register(cleanup)
async def _setup_client(self):
"""Setup aiohttp client."""
connector = aiohttp.UnixConnector(path=self._socket_path)
self._client_session = aiohttp.ClientSession(connector=connector)
async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any:
"""Make HTTP request to server."""
if not self._client_session:
await self._setup_client()
url = f"http://localhost/{endpoint}"
async with self._client_session.post(url, json=data) as resp:
result = await resp.json()
if result.get("error"):
raise Exception(result["error"])
return result.get("result")
# Server handlers
async def _handle_insert(self, request):
data = await request.json()
try:
result = await self._server_insert(
data["table"], data["args"], data.get("return_id", False)
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_update(self, request):
data = await request.json()
try:
result = await self._server_update(
data["table"], data["args"], data.get("where")
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_delete(self, request):
data = await request.json()
try:
result = await self._server_delete(data["table"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_upsert(self, request):
data = await request.json()
try:
result = await self._server_upsert(
data["table"], data["args"], data.get("where")
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_get(self, request):
data = await request.json()
try:
result = await self._server_get(data["table"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_find(self, request):
data = await request.json()
try:
result = await self._server_find(
data["table"],
data.get("where"),
limit=data.get("limit", 0),
offset=data.get("offset", 0),
order_by=data.get("order_by"),
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_count(self, request):
data = await request.json()
try:
result = await self._server_count(data["table"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_exists(self, request):
data = await request.json()
try:
result = await self._server_exists(data["table"], data["where"])
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_kv_set(self, request):
data = await request.json()
try:
await self._server_kv_set(
data["key"], data["value"], table=data.get("table")
)
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_kv_get(self, request):
data = await request.json()
try:
result = await self._server_kv_get(
data["key"], default=data.get("default"), table=data.get("table")
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_execute_raw(self, request):
data = await request.json()
try:
result = await self._server_execute_raw(
data["sql"],
tuple(data.get("params", [])) if data.get("params") else None,
)
return web.json_response(
{"result": result.rowcount if hasattr(result, "rowcount") else None}
)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_query_raw(self, request):
data = await request.json()
try:
result = await self._server_query_raw(
data["sql"],
tuple(data.get("params", [])) if data.get("params") else None,
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_query_one(self, request):
data = await request.json()
try:
result = await self._server_query_one(
data["sql"],
tuple(data.get("params", [])) if data.get("params") else None,
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_create_table(self, request):
data = await request.json()
try:
await self._server_create_table(
data["table"], data["schema"], data.get("constraints")
)
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_insert_unique(self, request):
data = await request.json()
try:
result = await self._server_insert_unique(
data["table"], data["args"], data["unique_fields"]
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_aggregate(self, request):
data = await request.json()
try:
result = await self._server_aggregate(
data["table"],
data["function"],
data.get("column", "*"),
data.get("where"),
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
# Public methods that delegate to server or client
async def insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
) -> Union[str, int]:
await self._ensure_initialized()
if self._is_server:
return await self._server_insert(table, args, return_id)
else:
return await self._make_request(
"insert", {"table": table, "args": args, "return_id": return_id}
)
async def update(
self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
) -> int:
await self._ensure_initialized()
if self._is_server:
return await self._server_update(table, args, where)
else:
return await self._make_request(
"update", {"table": table, "args": args, "where": where}
)
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
await self._ensure_initialized()
if self._is_server:
return await self._server_delete(table, where)
else:
return await self._make_request("delete", {"table": table, "where": where})
async def upsert(
self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
) -> str | None:
await self._ensure_initialized()
if self._is_server:
return await self._server_upsert(table, args, where)
else:
return await self._make_request(
"upsert", {"table": table, "args": args, "where": where}
)
async def get(
self, table: str, where: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_get(table, where)
else:
return await self._make_request("get", {"table": table, "where": where})
async def find(
self,
table: str,
where: Optional[Dict[str, Any]] = None,
*,
limit: int = 0,
offset: int = 0,
order_by: Optional[str] = None,
) -> List[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_find(
table, where, limit=limit, offset=offset, order_by=order_by
)
else:
return await self._make_request(
"find",
{
"table": table,
"where": where,
"limit": limit,
"offset": offset,
"order_by": order_by,
},
)
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
await self._ensure_initialized()
if self._is_server:
return await self._server_count(table, where)
else:
return await self._make_request("count", {"table": table, "where": where})
async def exists(self, table: str, where: Dict[str, Any]) -> bool:
await self._ensure_initialized()
if self._is_server:
return await self._server_exists(table, where)
else:
return await self._make_request("exists", {"table": table, "where": where})
async def kv_set(self, key: str, value: Any, *, table: str | None = None) -> None:
await self._ensure_initialized()
if self._is_server:
return await self._server_kv_set(key, value, table=table)
else:
return await self._make_request(
"kv_set", {"key": key, "value": value, "table": table}
)
async def kv_get(
self, key: str, *, default: Any = None, table: str | None = None
) -> Any:
await self._ensure_initialized()
if self._is_server:
return await self._server_kv_get(key, default=default, table=table)
else:
return await self._make_request(
"kv_get", {"key": key, "default": default, "table": table}
)
async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
await self._ensure_initialized()
if self._is_server:
return await self._server_execute_raw(sql, params)
else:
return await self._make_request(
"execute_raw", {"sql": sql, "params": list(params) if params else None}
)
async def query_raw(
self, sql: str, params: Optional[Tuple] = None
) -> List[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_query_raw(sql, params)
else:
return await self._make_request(
"query_raw", {"sql": sql, "params": list(params) if params else None}
)
async def query_one(
self, sql: str, params: Optional[Tuple] = None
) -> Optional[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_query_one(sql, params)
else:
return await self._make_request(
"query_one", {"sql": sql, "params": list(params) if params else None}
)
async def create_table(
self,
table: str,
schema: Dict[str, str],
constraints: Optional[List[str]] = None,
):
await self._ensure_initialized()
if self._is_server:
return await self._server_create_table(table, schema, constraints)
else:
return await self._make_request(
"create_table",
{"table": table, "schema": schema, "constraints": constraints},
)
async def insert_unique(
self, table: str, args: Dict[str, Any], unique_fields: List[str]
) -> Union[str, None]:
await self._ensure_initialized()
if self._is_server:
return await self._server_insert_unique(table, args, unique_fields)
else:
return await self._make_request(
"insert_unique",
{"table": table, "args": args, "unique_fields": unique_fields},
)
async def aggregate(
self,
table: str,
function: str,
column: str = "*",
where: Optional[Dict[str, Any]] = None,
) -> Any:
await self._ensure_initialized()
if self._is_server:
return await self._server_aggregate(table, function, column, where)
else:
return await self._make_request(
"aggregate",
{
"table": table,
"function": function,
"column": column,
"where": where,
},
)
async def transaction(self):
"""Context manager for transactions."""
await self._ensure_initialized()
if self._is_server:
return TransactionContext(self._db_connection, self._db_semaphore)
else:
raise NotImplementedError("Transactions not supported in client mode")
# Server implementation methods (original logic)
@staticmethod @staticmethod
def _utc_iso() -> str: def _utc_iso() -> str:
return ( return (
@ -54,8 +555,10 @@ class AsyncDataSet:
columns = set() columns = set()
try: try:
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
async with db.execute(f"PRAGMA table_info({table})") as cursor: async with self._db_connection.execute(
f"PRAGMA table_info({table})"
) as cursor:
async for row in cursor: async for row in cursor:
columns.add(row[1]) # Column name is at index 1 columns.add(row[1]) # Column name is at index 1
self._table_columns_cache[table] = columns self._table_columns_cache[table] = columns
@ -71,9 +574,11 @@ class AsyncDataSet:
async def _ensure_column(self, table: str, name: str, value: Any) -> None: async def _ensure_column(self, table: str, name: str, value: Any) -> None:
col_type = self._py_to_sqlite_type(value) col_type = self._py_to_sqlite_type(value)
try: try:
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}") await self._db_connection.execute(
await db.commit() f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
)
await self._db_connection.commit()
await self._invalidate_column_cache(table) await self._invalidate_column_cache(table)
except aiosqlite.OperationalError as e: except aiosqlite.OperationalError as e:
if "duplicate column name" in str(e).lower(): if "duplicate column name" in str(e).lower():
@ -91,15 +596,17 @@ class AsyncDataSet:
cols[key] = self._py_to_sqlite_type(val) cols[key] = self._py_to_sqlite_type(val)
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items()) columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})") await self._db_connection.execute(
await db.commit() f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
)
await self._db_connection.commit()
await self._invalidate_column_cache(table) await self._invalidate_column_cache(table)
async def _table_exists(self, table: str) -> bool: async def _table_exists(self, table: str) -> bool:
"""Check if a table exists.""" """Check if a table exists."""
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
async with db.execute( async with self._db_connection.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,) "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
) as cursor: ) as cursor:
return await cursor.fetchone() is not None return await cursor.fetchone() is not None
@ -127,14 +634,14 @@ class AsyncDataSet:
sql: str, sql: str,
params: Iterable[Any], params: Iterable[Any],
col_sources: Dict[str, Any], col_sources: Dict[str, Any],
max_retries: int = 10 max_retries: int = 10,
) -> aiosqlite.Cursor: ) -> aiosqlite.Cursor:
retries = 0 retries = 0
while retries < max_retries: while retries < max_retries:
try: try:
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
cursor = await db.execute(sql, params) cursor = await self._db_connection.execute(sql, params)
await db.commit() await self._db_connection.commit()
return cursor return cursor
except aiosqlite.OperationalError as err: except aiosqlite.OperationalError as err:
retries += 1 retries += 1
@ -163,7 +670,9 @@ class AsyncDataSet:
if match: if match:
col_name = match.group(1) col_name = match.group(1)
if col_name in col_sources: if col_name in col_sources:
await self._ensure_column(table, col_name, col_sources[col_name]) await self._ensure_column(
table, col_name, col_sources[col_name]
)
else: else:
await self._ensure_column(table, col_name, None) await self._ensure_column(table, col_name, None)
continue continue
@ -171,7 +680,9 @@ class AsyncDataSet:
raise raise
raise Exception(f"Max retries ({max_retries}) exceeded") raise Exception(f"Max retries ({max_retries}) exceeded")
async def _filter_existing_columns(self, table: str, data: Dict[str, Any]) -> Dict[str, Any]: async def _filter_existing_columns(
self, table: str, data: Dict[str, Any]
) -> Dict[str, Any]:
"""Filter data to only include columns that exist in the table.""" """Filter data to only include columns that exist in the table."""
if not await self._table_exists(table): if not await self._table_exists(table):
return data return data
@ -198,10 +709,14 @@ class AsyncDataSet:
while retries < max_retries: while retries < max_retries:
try: try:
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
db.row_factory = aiosqlite.Row self._db_connection.row_factory = aiosqlite.Row
async with db.execute(sql, params) as cursor: async with self._db_connection.execute(sql, params) as cursor:
async for row in cursor: # Fetch all rows while holding the semaphore
rows = await cursor.fetchall()
# Yield rows after releasing the semaphore
for row in rows:
yield dict(row) yield dict(row)
return return
except aiosqlite.OperationalError as err: except aiosqlite.OperationalError as err:
@ -228,7 +743,9 @@ class AsyncDataSet:
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()])
return " WHERE " + " AND ".join(clauses), list(vals) return " WHERE " + " AND ".join(clauses), list(vals)
async def insert(self, table: str, args: Dict[str, Any], return_id: bool = False) -> Union[str, int]: async def _server_insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
) -> Union[str, int]:
"""Insert a record. If return_id=True, returns auto-incremented ID instead of UUID.""" """Insert a record. If return_id=True, returns auto-incremented ID instead of UUID."""
uid = str(uuid4()) uid = str(uuid4())
now = self._utc_iso() now = self._utc_iso()
@ -244,19 +761,23 @@ class AsyncDataSet:
await self._ensure_table(table, record) await self._ensure_table(table, record)
# Handle auto-increment ID if requested # Handle auto-increment ID if requested
if return_id and 'id' not in args: if return_id and "id" not in args:
# Ensure id column exists # Ensure id column exists
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
# Add id column if it doesn't exist # Add id column if it doesn't exist
try: try:
await db.execute(f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT") await self._db_connection.execute(
await db.commit() f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT"
)
await self._db_connection.commit()
except aiosqlite.OperationalError as e: except aiosqlite.OperationalError as e:
if "duplicate column name" not in str(e).lower(): if "duplicate column name" not in str(e).lower():
# Try without autoincrement constraint # Try without autoincrement constraint
try: try:
await db.execute(f"ALTER TABLE {table} ADD COLUMN id INTEGER") await self._db_connection.execute(
await db.commit() f"ALTER TABLE {table} ADD COLUMN id INTEGER"
)
await self._db_connection.commit()
except: except:
pass pass
@ -275,7 +796,7 @@ class AsyncDataSet:
await self._safe_execute(table, sql, list(record.values()), record) await self._safe_execute(table, sql, list(record.values()), record)
return uid return uid
async def update( async def _server_update(
self, self,
table: str, table: str,
args: Dict[str, Any], args: Dict[str, Any],
@ -303,7 +824,9 @@ class AsyncDataSet:
cur = await self._safe_execute(table, sql, params, all_cols) cur = await self._safe_execute(table, sql, params, all_cols)
return cur.rowcount return cur.rowcount
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: async def _server_delete(
self, table: str, where: Optional[Dict[str, Any]] = None
) -> int:
# Check if table exists # Check if table exists
if not await self._table_exists(table): if not await self._table_exists(table):
return 0 return 0
@ -313,7 +836,7 @@ class AsyncDataSet:
cur = await self._safe_execute(table, sql, where_params, where or {}) cur = await self._safe_execute(table, sql, where_params, where or {})
return cur.rowcount return cur.rowcount
async def upsert( async def _server_upsert(
self, self,
table: str, table: str,
args: Dict[str, Any], args: Dict[str, Any],
@ -321,15 +844,15 @@ class AsyncDataSet:
) -> str | None: ) -> str | None:
if not args: if not args:
raise ValueError("Nothing to update. Empty dict given.") raise ValueError("Nothing to update. Empty dict given.")
args['updated_at'] = self._utc_iso() args["updated_at"] = self._utc_iso()
affected = await self.update(table, args, where) affected = await self._server_update(table, args, where)
if affected: if affected:
rec = await self.get(table, where) rec = await self._server_get(table, where)
return rec.get("uid") if rec else None return rec.get("uid") if rec else None
merged = {**(where or {}), **args} merged = {**(where or {}), **args}
return await self.insert(table, merged) return await self._server_insert(table, merged)
async def get( async def _server_get(
self, table: str, where: Optional[Dict[str, Any]] = None self, table: str, where: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
where_clause, where_params = self._build_where(where) where_clause, where_params = self._build_where(where)
@ -338,7 +861,7 @@ class AsyncDataSet:
return row return row
return None return None
async def find( async def _server_find(
self, self,
table: str, table: str,
where: Optional[Dict[str, Any]] = None, where: Optional[Dict[str, Any]] = None,
@ -358,7 +881,9 @@ class AsyncDataSet:
row async for row in self._safe_query(table, sql, where_params, where or {}) row async for row in self._safe_query(table, sql, where_params, where or {})
] ]
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: async def _server_count(
self, table: str, where: Optional[Dict[str, Any]] = None
) -> int:
# Check if table exists # Check if table exists
if not await self._table_exists(table): if not await self._table_exists(table):
return 0 return 0
@ -370,10 +895,10 @@ class AsyncDataSet:
return next(iter(row.values()), 0) return next(iter(row.values()), 0)
return 0 return 0
async def exists(self, table: str, where: Dict[str, Any]) -> bool: async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool:
return (await self.count(table, where)) > 0 return (await self._server_count(table, where)) > 0
async def kv_set( async def _server_kv_set(
self, self,
key: str, key: str,
value: Any, value: Any,
@ -382,9 +907,9 @@ class AsyncDataSet:
) -> None: ) -> None:
tbl = table or self._KV_TABLE tbl = table or self._KV_TABLE
json_val = json.dumps(value, default=str) json_val = json.dumps(value, default=str)
await self.upsert(tbl, {"value": json_val}, {"key": key}) await self._server_upsert(tbl, {"value": json_val}, {"key": key})
async def kv_get( async def _server_kv_get(
self, self,
key: str, key: str,
*, *,
@ -392,7 +917,7 @@ class AsyncDataSet:
table: str | None = None, table: str | None = None,
) -> Any: ) -> Any:
tbl = table or self._KV_TABLE tbl = table or self._KV_TABLE
row = await self.get(tbl, {"key": key}) row = await self._server_get(tbl, {"key": key})
if not row: if not row:
return default return default
try: try:
@ -400,30 +925,41 @@ class AsyncDataSet:
except Exception: except Exception:
return default return default
async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any: async def _server_execute_raw(
self, sql: str, params: Optional[Tuple] = None
) -> Any:
"""Execute raw SQL for complex queries like JOINs.""" """Execute raw SQL for complex queries like JOINs."""
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
cursor = await db.execute(sql, params or ()) cursor = await self._db_connection.execute(sql, params or ())
await db.commit() await self._db_connection.commit()
return cursor return cursor
async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]: async def _server_query_raw(
self, sql: str, params: Optional[Tuple] = None
) -> List[Dict[str, Any]]:
"""Execute raw SQL query and return results as list of dicts.""" """Execute raw SQL query and return results as list of dicts."""
try: try:
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
db.row_factory = aiosqlite.Row self._db_connection.row_factory = aiosqlite.Row
async with db.execute(sql, params or ()) as cursor: async with self._db_connection.execute(sql, params or ()) as cursor:
return [dict(row) async for row in cursor] return [dict(row) async for row in cursor]
except aiosqlite.OperationalError: except aiosqlite.OperationalError:
# Return empty list if query fails # Return empty list if query fails
return [] return []
async def query_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Dict[str, Any]]: async def _server_query_one(
self, sql: str, params: Optional[Tuple] = None
) -> Optional[Dict[str, Any]]:
"""Execute raw SQL query and return single result.""" """Execute raw SQL query and return single result."""
results = await self.query_raw(sql + " LIMIT 1", params) results = await self._server_query_raw(sql + " LIMIT 1", params)
return results[0] if results else None return results[0] if results else None
async def create_table(self, table: str, schema: Dict[str, str], constraints: Optional[List[str]] = None): async def _server_create_table(
self,
table: str,
schema: Dict[str, str],
constraints: Optional[List[str]] = None,
):
"""Create table with custom schema and constraints. Always includes default columns.""" """Create table with custom schema and constraints. Always includes default columns."""
# Merge default columns with custom schema # Merge default columns with custom schema
full_schema = self._DEFAULT_COLUMNS.copy() full_schema = self._DEFAULT_COLUMNS.copy()
@ -434,25 +970,31 @@ class AsyncDataSet:
columns.extend(constraints) columns.extend(constraints)
columns_sql = ", ".join(columns) columns_sql = ", ".join(columns)
async with aiosqlite.connect(self._file) as db: async with self._db_semaphore:
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})") await self._db_connection.execute(
await db.commit() f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
)
await self._db_connection.commit()
await self._invalidate_column_cache(table) await self._invalidate_column_cache(table)
async def insert_unique(self, table: str, args: Dict[str, Any], unique_fields: List[str]) -> Union[str, None]: async def _server_insert_unique(
self, table: str, args: Dict[str, Any], unique_fields: List[str]
) -> Union[str, None]:
"""Insert with unique constraint handling. Returns uid on success, None if duplicate.""" """Insert with unique constraint handling. Returns uid on success, None if duplicate."""
try: try:
return await self.insert(table, args) return await self._server_insert(table, args)
except aiosqlite.IntegrityError as e: except aiosqlite.IntegrityError as e:
if "UNIQUE" in str(e): if "UNIQUE" in str(e):
return None return None
raise raise
async def transaction(self): async def _server_aggregate(
"""Context manager for transactions.""" self,
return TransactionContext(self._file) table: str,
function: str,
async def aggregate(self, table: str, function: str, column: str = "*", where: Optional[Dict[str, Any]] = None) -> Any: column: str = "*",
where: Optional[Dict[str, Any]] = None,
) -> Any:
"""Perform aggregate functions like SUM, AVG, MAX, MIN.""" """Perform aggregate functions like SUM, AVG, MAX, MIN."""
# Check if table exists # Check if table exists
if not await self._table_exists(table): if not await self._table_exists(table):
@ -460,29 +1002,53 @@ class AsyncDataSet:
where_clause, where_params = self._build_where(where) where_clause, where_params = self._build_where(where)
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}" sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
result = await self.query_one(sql, tuple(where_params)) result = await self._server_query_one(sql, tuple(where_params))
return result['result'] if result else None return result["result"] if result else None
async def close(self):
"""Close the connection or server."""
if not self._initialized:
return
if self._client_session:
await self._client_session.close()
if self._runner:
await self._runner.cleanup()
if os.path.exists(self._socket_path) and self._is_server:
os.unlink(self._socket_path)
if self._db_connection:
await self._db_connection.close()
class TransactionContext: class TransactionContext:
"""Context manager for database transactions.""" """Context manager for database transactions."""
def __init__(self, db_file: str): def __init__(
self.db_file = db_file self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None
self.conn = None ):
self.db_connection = db_connection
self.semaphore = semaphore
async def __aenter__(self): async def __aenter__(self):
self.conn = await aiosqlite.connect(self.db_file) if self.semaphore:
self.conn.row_factory = aiosqlite.Row await self.semaphore.acquire()
await self.conn.execute("BEGIN") try:
return self.conn await self.db_connection.execute("BEGIN")
return self.db_connection
except:
if self.semaphore:
self.semaphore.release()
raise
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
try:
if exc_type is None: if exc_type is None:
await self.conn.commit() await self.db_connection.commit()
else: else:
await self.conn.rollback() await self.db_connection.rollback()
await self.conn.close() finally:
if self.semaphore:
self.semaphore.release()
# Test cases remain the same but with additional tests for new functionality # Test cases remain the same but with additional tests for new functionality
@ -492,9 +1058,10 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
self.db_path = Path("temp_test.db") self.db_path = Path("temp_test.db")
if self.db_path.exists(): if self.db_path.exists():
self.db_path.unlink() self.db_path.unlink()
self.connector = AsyncDataSet(str(self.db_path)) self.connector = AsyncDataSet(str(self.db_path), max_concurrent_queries=1)
async def asyncTearDown(self): async def asyncTearDown(self):
await self.connector.close()
if self.db_path.exists(): if self.db_path.exists():
self.db_path.unlink() self.db_path.unlink()
@ -551,9 +1118,13 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
self.assertEqual(id2, id1 + 1) self.assertEqual(id2, id1 + 1)
async def test_transaction(self): async def test_transaction(self):
await self.connector._ensure_initialized()
if self.connector._is_server:
async with self.connector.transaction() as conn: async with self.connector.transaction() as conn:
await conn.execute("INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", await conn.execute(
("test-uid", "John", 30, "2024-01-01", "2024-01-01")) "INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
("test-uid", "John", 30, "2024-01-01", "2024-01-01"),
)
# Transaction will be committed # Transaction will be committed
rec = await self.connector.get("people", {"name": "John"}) rec = await self.connector.get("people", {"name": "John"})
@ -564,7 +1135,7 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
"id": "INTEGER PRIMARY KEY AUTOINCREMENT", "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
"username": "TEXT NOT NULL", "username": "TEXT NOT NULL",
"email": "TEXT NOT NULL", "email": "TEXT NOT NULL",
"score": "INTEGER DEFAULT 0" "score": "INTEGER DEFAULT 0",
} }
constraints = ["UNIQUE(username)", "UNIQUE(email)"] constraints = ["UNIQUE(username)", "UNIQUE(email)"]
@ -574,7 +1145,7 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
result = await self.connector.insert_unique( result = await self.connector.insert_unique(
"users", "users",
{"username": "john", "email": "john@example.com"}, {"username": "john", "email": "john@example.com"},
["username", "email"] ["username", "email"],
) )
self.assertIsNotNone(result) self.assertIsNotNone(result)
@ -582,7 +1153,7 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
result = await self.connector.insert_unique( result = await self.connector.insert_unique(
"users", "users",
{"username": "john", "email": "different@example.com"}, {"username": "john", "email": "different@example.com"},
["username", "email"] ["username", "email"],
) )
self.assertIsNone(result) self.assertIsNone(result)
@ -593,14 +1164,20 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
self.assertIsNone(await self.connector.get("nonexistent")) self.assertIsNone(await self.connector.get("nonexistent"))
self.assertFalse(await self.connector.exists("nonexistent", {"id": 1})) self.assertFalse(await self.connector.exists("nonexistent", {"id": 1}))
self.assertEqual(await self.connector.delete("nonexistent"), 0) self.assertEqual(await self.connector.delete("nonexistent"), 0)
self.assertEqual(await self.connector.update("nonexistent", {"name": "test"}), 0) self.assertEqual(
await self.connector.update("nonexistent", {"name": "test"}), 0
)
async def test_auto_column_creation(self): async def test_auto_column_creation(self):
# Insert with new columns that don't exist yet # Insert with new columns that don't exist yet
await self.connector.insert("dynamic", {"col1": "value1", "col2": 42, "col3": 3.14}) await self.connector.insert(
"dynamic", {"col1": "value1", "col2": 42, "col3": 3.14}
)
# Add more columns in next insert # Add more columns in next insert
await self.connector.insert("dynamic", {"col1": "value2", "col4": True, "col5": None}) await self.connector.insert(
"dynamic", {"col1": "value2", "col4": True, "col5": None}
)
# All records should be retrievable # All records should be retrievable
records = await self.connector.find("dynamic") records = await self.connector.find("dynamic")

156
ads_test_server.py Normal file
View File

@ -0,0 +1,156 @@
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import os
import threading
from datetime import datetime
from ads import AsyncDataSet
import uvicorn
app = FastAPI()
# Initialize AsyncDataSet
db = AsyncDataSet(":memory:", socket_path="test_workers.sock")
@app.get("/write")
async def write_data():
"""Write process and thread ID to database."""
process_id = os.getpid()
thread_id = threading.get_ident()
timestamp = datetime.now().isoformat()
# Insert data with process and thread info
await db.insert(
"requests",
{
"process_id": process_id,
"thread_id": thread_id,
"timestamp": timestamp,
"endpoint": "/write",
},
)
return {
"status": "success",
"process_id": process_id,
"thread_id": thread_id,
"timestamp": timestamp,
}
@app.get("/data", response_class=HTMLResponse)
async def get_all_data():
"""Return all request data as HTML table."""
# Get all records ordered by timestamp
records = await db.find("requests", order_by="timestamp DESC")
# Build HTML table
html = """
<!DOCTYPE html>
<html>
<head>
<title>Worker Test Results</title>
<style>
body { font-family: Arial, sans-serif; margin: 20px; }
table { border-collapse: collapse; width: 100%; }
th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
th { background-color: #4CAF50; color: white; }
tr:nth-child(even) { background-color: #f2f2f2; }
h1 { color: #333; }
.stats { margin: 20px 0; padding: 10px; background-color: #e7f3ff; border-radius: 5px; }
</style>
</head>
<body>
<h1>AsyncDataSet Multi-Worker Test Results</h1>
"""
# Calculate statistics
if records:
unique_processes = len(set(r["process_id"] for r in records))
unique_threads = len(set((r["process_id"], r["thread_id"]) for r in records))
total_requests = len(records)
html += f"""
<div class="stats">
<strong>Statistics:</strong><br>
Total Requests: {total_requests}<br>
Unique Processes: {unique_processes}<br>
Unique Process-Thread Combinations: {unique_threads}<br>
</div>
"""
html += """
<table>
<tr>
<th>ID</th>
<th>Process ID</th>
<th>Thread ID</th>
<th>Timestamp</th>
<th>Endpoint</th>
</tr>
"""
for record in records:
html += f"""
<tr>
<td>{record.get('uid', 'N/A')[:8]}...</td>
<td>{record.get('process_id', 'N/A')}</td>
<td>{record.get('thread_id', 'N/A')}</td>
<td>{record.get('timestamp', 'N/A')}</td>
<td>{record.get('endpoint', 'N/A')}</td>
</tr>
"""
html += """
</table>
<p style="margin-top: 20px;">
<a href="/write" target="_blank">Make a write request</a> |
<a href="/data">Refresh</a>
</p>
</body>
</html>
"""
return html
@app.get("/")
async def root():
"""Root endpoint showing process and thread info."""
process_id = os.getpid()
thread_id = threading.get_ident()
# Log this request too
await db.insert(
"requests",
{
"process_id": process_id,
"thread_id": thread_id,
"timestamp": datetime.now().isoformat(),
"endpoint": "/",
},
)
return {
"message": "FastAPI Multi-Worker Test Server",
"current_process_id": process_id,
"current_thread_id": thread_id,
"endpoints": {
"/": "This endpoint",
"/write": "Write process/thread data (POST)",
"/data": "View all data as HTML table (GET)",
},
}
if __name__ == "__main__":
# Run with multiple workers
# You can also run this from command line:
# uvicorn test_server:app --workers 4 --host 0.0.0.0 --port 8000
uvicorn.run(
"ads_test_server:app",
host="0.0.0.0",
port=8000,
workers=50, # Run with 4 worker processes
reload=False,
)

656
adsc.py Normal file
View File

@ -0,0 +1,656 @@
import re
import json
from uuid import uuid4
from datetime import datetime, timezone
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
AsyncGenerator,
Union,
Tuple,
Set,
)
from pathlib import Path
import aiosqlite
import unittest
from types import SimpleNamespace
import asyncio
class AsyncDataSet:
_KV_TABLE = "__kv_store"
_DEFAULT_COLUMNS = {
"uid": "TEXT PRIMARY KEY",
"created_at": "TEXT",
"updated_at": "TEXT",
"deleted_at": "TEXT",
}
def __init__(self, file: str):
self._file = file
self._table_columns_cache: Dict[str, Set[str]] = {}
@staticmethod
def _utc_iso() -> str:
return (
datetime.now(timezone.utc)
.replace(microsecond=0)
.isoformat()
.replace("+00:00", "Z")
)
@staticmethod
def _py_to_sqlite_type(value: Any) -> str:
if value is None:
return "TEXT"
if isinstance(value, bool):
return "INTEGER"
if isinstance(value, int):
return "INTEGER"
if isinstance(value, float):
return "REAL"
if isinstance(value, (bytes, bytearray, memoryview)):
return "BLOB"
return "TEXT"
async def _get_table_columns(self, table: str) -> Set[str]:
"""Get actual columns that exist in the table."""
if table in self._table_columns_cache:
return self._table_columns_cache[table]
columns = set()
try:
async with aiosqlite.connect(self._file) as db:
async with db.execute(f"PRAGMA table_info({table})") as cursor:
async for row in cursor:
columns.add(row[1]) # Column name is at index 1
self._table_columns_cache[table] = columns
except:
pass
return columns
async def _invalidate_column_cache(self, table: str):
"""Invalidate column cache for a table."""
if table in self._table_columns_cache:
del self._table_columns_cache[table]
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
col_type = self._py_to_sqlite_type(value)
try:
async with aiosqlite.connect(self._file) as db:
await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}")
await db.commit()
await self._invalidate_column_cache(table)
except aiosqlite.OperationalError as e:
if "duplicate column name" in str(e).lower():
pass # Column already exists
else:
raise
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
# Always include default columns
cols = self._DEFAULT_COLUMNS.copy()
# Add columns from col_sources
for key, val in col_sources.items():
if key not in cols:
cols[key] = self._py_to_sqlite_type(val)
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
async with aiosqlite.connect(self._file) as db:
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
await db.commit()
await self._invalidate_column_cache(table)
async def _table_exists(self, table: str) -> bool:
"""Check if a table exists."""
async with aiosqlite.connect(self._file) as db:
async with db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
) as cursor:
return await cursor.fetchone() is not None
_RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
_RE_NO_TABLE = re.compile(r"no such table: (\w+)")
@classmethod
def _missing_column_from_error(
cls, err: aiosqlite.OperationalError
) -> Optional[str]:
m = cls._RE_NO_COLUMN.search(str(err))
return m.group(1) if m else None
@classmethod
def _missing_table_from_error(
cls, err: aiosqlite.OperationalError
) -> Optional[str]:
m = cls._RE_NO_TABLE.search(str(err))
return m.group(1) if m else None
async def _safe_execute(
self,
table: str,
sql: str,
params: Iterable[Any],
col_sources: Dict[str, Any],
max_retries: int = 10,
) -> aiosqlite.Cursor:
retries = 0
while retries < max_retries:
try:
async with aiosqlite.connect(self._file) as db:
cursor = await db.execute(sql, params)
await db.commit()
return cursor
except aiosqlite.OperationalError as err:
retries += 1
err_str = str(err).lower()
# Handle missing column
col = self._missing_column_from_error(err)
if col:
if col in col_sources:
await self._ensure_column(table, col, col_sources[col])
else:
# Column not in sources, ensure it with NULL/TEXT type
await self._ensure_column(table, col, None)
continue
# Handle missing table
tbl = self._missing_table_from_error(err)
if tbl:
await self._ensure_table(tbl, col_sources)
continue
# Handle other column-related errors
if "has no column named" in err_str:
# Extract column name differently
match = re.search(r"table \w+ has no column named (\w+)", err_str)
if match:
col_name = match.group(1)
if col_name in col_sources:
await self._ensure_column(
table, col_name, col_sources[col_name]
)
else:
await self._ensure_column(table, col_name, None)
continue
raise
raise Exception(f"Max retries ({max_retries}) exceeded")
async def _filter_existing_columns(
self, table: str, data: Dict[str, Any]
) -> Dict[str, Any]:
"""Filter data to only include columns that exist in the table."""
if not await self._table_exists(table):
return data
existing_columns = await self._get_table_columns(table)
if not existing_columns:
return data
return {k: v for k, v in data.items() if k in existing_columns}
async def _safe_query(
self,
table: str,
sql: str,
params: Iterable[Any],
col_sources: Dict[str, Any],
) -> AsyncGenerator[Dict[str, Any], None]:
# Check if table exists first
if not await self._table_exists(table):
return
max_retries = 10
retries = 0
while retries < max_retries:
try:
async with aiosqlite.connect(self._file) as db:
db.row_factory = aiosqlite.Row
async with db.execute(sql, params) as cursor:
async for row in cursor:
yield dict(row)
return
except aiosqlite.OperationalError as err:
retries += 1
err_str = str(err).lower()
# Handle missing table
tbl = self._missing_table_from_error(err)
if tbl:
# For queries, if table doesn't exist, just return empty
return
# Handle missing column in WHERE clause or SELECT
if "no such column" in err_str:
# For queries with missing columns, return empty
return
raise
@staticmethod
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
if not where:
return "", []
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()])
return " WHERE " + " AND ".join(clauses), list(vals)
async def insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
) -> Union[str, int]:
"""Insert a record. If return_id=True, returns auto-incremented ID instead of UUID."""
uid = str(uuid4())
now = self._utc_iso()
record = {
"uid": uid,
"created_at": now,
"updated_at": now,
"deleted_at": None,
**args,
}
# Ensure table exists with all needed columns
await self._ensure_table(table, record)
# Handle auto-increment ID if requested
if return_id and "id" not in args:
# Ensure id column exists
async with aiosqlite.connect(self._file) as db:
# Add id column if it doesn't exist
try:
await db.execute(
f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT"
)
await db.commit()
except aiosqlite.OperationalError as e:
if "duplicate column name" not in str(e).lower():
# Try without autoincrement constraint
try:
await db.execute(
f"ALTER TABLE {table} ADD COLUMN id INTEGER"
)
await db.commit()
except:
pass
await self._invalidate_column_cache(table)
# Insert and get lastrowid
cols = "`" + "`, `".join(record.keys()) + "`"
qs = ", ".join(["?"] * len(record))
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
cursor = await self._safe_execute(table, sql, list(record.values()), record)
return cursor.lastrowid
cols = "`" + "`, `".join(record) + "`"
qs = ", ".join(["?"] * len(record))
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
await self._safe_execute(table, sql, list(record.values()), record)
return uid
async def update(
self,
table: str,
args: Dict[str, Any],
where: Optional[Dict[str, Any]] = None,
) -> int:
if not args:
return 0
# Check if table exists
if not await self._table_exists(table):
return 0
args["updated_at"] = self._utc_iso()
# Ensure all columns exist
all_cols = {**args, **(where or {})}
await self._ensure_table(table, all_cols)
for col, val in all_cols.items():
await self._ensure_column(table, col, val)
set_clause = ", ".join(f"`{k}` = ?" for k in args)
where_clause, where_params = self._build_where(where)
sql = f"UPDATE {table} SET {set_clause}{where_clause}"
params = list(args.values()) + where_params
cur = await self._safe_execute(table, sql, params, all_cols)
return cur.rowcount
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
# Check if table exists
if not await self._table_exists(table):
return 0
where_clause, where_params = self._build_where(where)
sql = f"DELETE FROM {table}{where_clause}"
cur = await self._safe_execute(table, sql, where_params, where or {})
return cur.rowcount
async def upsert(
self,
table: str,
args: Dict[str, Any],
where: Optional[Dict[str, Any]] = None,
) -> str | None:
if not args:
raise ValueError("Nothing to update. Empty dict given.")
args["updated_at"] = self._utc_iso()
affected = await self.update(table, args, where)
if affected:
rec = await self.get(table, where)
return rec.get("uid") if rec else None
merged = {**(where or {}), **args}
return await self.insert(table, merged)
async def get(
self, table: str, where: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
where_clause, where_params = self._build_where(where)
sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
async for row in self._safe_query(table, sql, where_params, where or {}):
return row
return None
async def find(
self,
table: str,
where: Optional[Dict[str, Any]] = None,
*,
limit: int = 0,
offset: int = 0,
order_by: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Find records with optional ordering."""
where_clause, where_params = self._build_where(where)
order_clause = f" ORDER BY {order_by}" if order_by else ""
extra = (f" LIMIT {limit}" if limit else "") + (
f" OFFSET {offset}" if offset else ""
)
sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}"
return [
row async for row in self._safe_query(table, sql, where_params, where or {})
]
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
# Check if table exists
if not await self._table_exists(table):
return 0
where_clause, where_params = self._build_where(where)
sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
gen = self._safe_query(table, sql, where_params, where or {})
async for row in gen:
return next(iter(row.values()), 0)
return 0
async def exists(self, table: str, where: Dict[str, Any]) -> bool:
return (await self.count(table, where)) > 0
async def kv_set(
self,
key: str,
value: Any,
*,
table: str | None = None,
) -> None:
tbl = table or self._KV_TABLE
json_val = json.dumps(value, default=str)
await self.upsert(tbl, {"value": json_val}, {"key": key})
async def kv_get(
self,
key: str,
*,
default: Any = None,
table: str | None = None,
) -> Any:
tbl = table or self._KV_TABLE
row = await self.get(tbl, {"key": key})
if not row:
return default
try:
return json.loads(row["value"])
except Exception:
return default
async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
"""Execute raw SQL for complex queries like JOINs."""
async with aiosqlite.connect(self._file) as db:
cursor = await db.execute(sql, params or ())
await db.commit()
return cursor
async def query_raw(
self, sql: str, params: Optional[Tuple] = None
) -> List[Dict[str, Any]]:
"""Execute raw SQL query and return results as list of dicts."""
try:
async with aiosqlite.connect(self._file) as db:
db.row_factory = aiosqlite.Row
async with db.execute(sql, params or ()) as cursor:
return [dict(row) async for row in cursor]
except aiosqlite.OperationalError:
# Return empty list if query fails
return []
async def query_one(
self, sql: str, params: Optional[Tuple] = None
) -> Optional[Dict[str, Any]]:
"""Execute raw SQL query and return single result."""
results = await self.query_raw(sql + " LIMIT 1", params)
return results[0] if results else None
async def create_table(
self,
table: str,
schema: Dict[str, str],
constraints: Optional[List[str]] = None,
):
"""Create table with custom schema and constraints. Always includes default columns."""
# Merge default columns with custom schema
full_schema = self._DEFAULT_COLUMNS.copy()
full_schema.update(schema)
columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
if constraints:
columns.extend(constraints)
columns_sql = ", ".join(columns)
async with aiosqlite.connect(self._file) as db:
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
await db.commit()
await self._invalidate_column_cache(table)
async def insert_unique(
self, table: str, args: Dict[str, Any], unique_fields: List[str]
) -> Union[str, None]:
"""Insert with unique constraint handling. Returns uid on success, None if duplicate."""
try:
return await self.insert(table, args)
except aiosqlite.IntegrityError as e:
if "UNIQUE" in str(e):
return None
raise
async def transaction(self):
"""Context manager for transactions."""
return TransactionContext(self._file)
async def aggregate(
self,
table: str,
function: str,
column: str = "*",
where: Optional[Dict[str, Any]] = None,
) -> Any:
"""Perform aggregate functions like SUM, AVG, MAX, MIN."""
# Check if table exists
if not await self._table_exists(table):
return None
where_clause, where_params = self._build_where(where)
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
result = await self.query_one(sql, tuple(where_params))
return result["result"] if result else None
class TransactionContext:
"""Context manager for database transactions."""
def __init__(self, db_file: str):
self.db_file = db_file
self.conn = None
async def __aenter__(self):
self.conn = await aiosqlite.connect(self.db_file)
self.conn.row_factory = aiosqlite.Row
await self.conn.execute("BEGIN")
return self.conn
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
await self.conn.commit()
else:
await self.conn.rollback()
await self.conn.close()
# Test cases remain the same but with additional tests for new functionality
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.db_path = Path("temp_test.db")
if self.db_path.exists():
self.db_path.unlink()
self.connector = AsyncDataSet(str(self.db_path))
async def asyncTearDown(self):
if self.db_path.exists():
self.db_path.unlink()
async def test_insert_and_get(self):
await self.connector.insert("people", {"name": "John Doe", "age": 30})
rec = await self.connector.get("people", {"name": "John Doe"})
self.assertIsNotNone(rec)
self.assertEqual(rec["name"], "John Doe")
async def test_get_nonexistent(self):
result = await self.connector.get("people", {"name": "Jane Doe"})
self.assertIsNone(result)
async def test_update(self):
await self.connector.insert("people", {"name": "John Doe", "age": 30})
await self.connector.update("people", {"age": 31}, {"name": "John Doe"})
rec = await self.connector.get("people", {"name": "John Doe"})
self.assertEqual(rec["age"], 31)
async def test_order_by(self):
await self.connector.insert("people", {"name": "Alice", "age": 25})
await self.connector.insert("people", {"name": "Bob", "age": 30})
await self.connector.insert("people", {"name": "Charlie", "age": 20})
results = await self.connector.find("people", order_by="age ASC")
self.assertEqual(results[0]["name"], "Charlie")
self.assertEqual(results[-1]["name"], "Bob")
async def test_raw_query(self):
await self.connector.insert("people", {"name": "John", "age": 30})
await self.connector.insert("people", {"name": "Jane", "age": 25})
results = await self.connector.query_raw(
"SELECT * FROM people WHERE age > ?", (26,)
)
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["name"], "John")
async def test_aggregate(self):
await self.connector.insert("people", {"name": "John", "age": 30})
await self.connector.insert("people", {"name": "Jane", "age": 25})
await self.connector.insert("people", {"name": "Bob", "age": 35})
avg_age = await self.connector.aggregate("people", "AVG", "age")
self.assertEqual(avg_age, 30)
max_age = await self.connector.aggregate("people", "MAX", "age")
self.assertEqual(max_age, 35)
async def test_insert_with_auto_id(self):
# Test auto-increment ID functionality
id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True)
id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True)
self.assertEqual(id2, id1 + 1)
async def test_transaction(self):
async with self.connector.transaction() as conn:
await conn.execute(
"INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
("test-uid", "John", 30, "2024-01-01", "2024-01-01"),
)
# Transaction will be committed
rec = await self.connector.get("people", {"name": "John"})
self.assertIsNotNone(rec)
async def test_create_custom_table(self):
schema = {
"id": "INTEGER PRIMARY KEY AUTOINCREMENT",
"username": "TEXT NOT NULL",
"email": "TEXT NOT NULL",
"score": "INTEGER DEFAULT 0",
}
constraints = ["UNIQUE(username)", "UNIQUE(email)"]
await self.connector.create_table("users", schema, constraints)
# Test that table was created with constraints
result = await self.connector.insert_unique(
"users",
{"username": "john", "email": "john@example.com"},
["username", "email"],
)
self.assertIsNotNone(result)
# Test duplicate insert
result = await self.connector.insert_unique(
"users",
{"username": "john", "email": "different@example.com"},
["username", "email"],
)
self.assertIsNone(result)
async def test_missing_table_operations(self):
# Test operations on non-existent tables
self.assertEqual(await self.connector.count("nonexistent"), 0)
self.assertEqual(await self.connector.find("nonexistent"), [])
self.assertIsNone(await self.connector.get("nonexistent"))
self.assertFalse(await self.connector.exists("nonexistent", {"id": 1}))
self.assertEqual(await self.connector.delete("nonexistent"), 0)
self.assertEqual(
await self.connector.update("nonexistent", {"name": "test"}), 0
)
async def test_auto_column_creation(self):
# Insert with new columns that don't exist yet
await self.connector.insert(
"dynamic", {"col1": "value1", "col2": 42, "col3": 3.14}
)
# Add more columns in next insert
await self.connector.insert(
"dynamic", {"col1": "value2", "col4": True, "col5": None}
)
# All records should be retrievable
records = await self.connector.find("dynamic")
self.assertEqual(len(records), 2)
if __name__ == "__main__":
unittest.main()