Updates.
This commit is contained in:
parent
7334c59aa7
commit
49517fc71d
751
ads.py
751
ads.py
@ -2,15 +2,39 @@ import re
|
|||||||
import json
|
import json
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Dict, Iterable, List, Optional, AsyncGenerator, Union, Tuple, Set
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
AsyncGenerator,
|
||||||
|
Union,
|
||||||
|
Tuple,
|
||||||
|
Set,
|
||||||
|
)
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
import socket
|
||||||
|
import os
|
||||||
|
import atexit
|
||||||
|
|
||||||
|
|
||||||
class AsyncDataSet:
|
class AsyncDataSet:
|
||||||
|
"""
|
||||||
|
Distributed AsyncDataSet with client-server model over Unix sockets.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- file: Path to the SQLite database file
|
||||||
|
- socket_path: Path to Unix socket for inter-process communication (default: "ads.sock")
|
||||||
|
- max_concurrent_queries: Maximum concurrent database operations (default: 1)
|
||||||
|
Set to 1 to prevent "database is locked" errors with SQLite
|
||||||
|
"""
|
||||||
|
|
||||||
_KV_TABLE = "__kv_store"
|
_KV_TABLE = "__kv_store"
|
||||||
_DEFAULT_COLUMNS = {
|
_DEFAULT_COLUMNS = {
|
||||||
@ -20,10 +44,487 @@ class AsyncDataSet:
|
|||||||
"deleted_at": "TEXT",
|
"deleted_at": "TEXT",
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, file: str):
|
def __init__(
|
||||||
|
self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1
|
||||||
|
):
|
||||||
self._file = file
|
self._file = file
|
||||||
|
self._socket_path = socket_path
|
||||||
self._table_columns_cache: Dict[str, Set[str]] = {}
|
self._table_columns_cache: Dict[str, Set[str]] = {}
|
||||||
|
self._is_server = None # None means not initialized yet
|
||||||
|
self._server = None
|
||||||
|
self._runner = None
|
||||||
|
self._client_session = None
|
||||||
|
self._initialized = False
|
||||||
|
self._init_lock = None # Will be created when needed
|
||||||
|
self._max_concurrent_queries = max_concurrent_queries
|
||||||
|
self._db_semaphore = None # Will be created when needed
|
||||||
|
self._db_connection = None # Persistent database connection
|
||||||
|
|
||||||
|
async def _ensure_initialized(self):
|
||||||
|
"""Ensure the instance is initialized as server or client."""
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create lock if needed (first time in async context)
|
||||||
|
if self._init_lock is None:
|
||||||
|
self._init_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async with self._init_lock:
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._initialize()
|
||||||
|
|
||||||
|
# Create semaphore for database operations (server only)
|
||||||
|
if self._is_server:
|
||||||
|
self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries)
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
async def _initialize(self):
|
||||||
|
"""Initialize as server or client."""
|
||||||
|
try:
|
||||||
|
# Try to create Unix socket
|
||||||
|
if os.path.exists(self._socket_path):
|
||||||
|
# Check if socket is active
|
||||||
|
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
|
try:
|
||||||
|
sock.connect(self._socket_path)
|
||||||
|
sock.close()
|
||||||
|
# Socket is active, we're a client
|
||||||
|
self._is_server = False
|
||||||
|
await self._setup_client()
|
||||||
|
except (ConnectionRefusedError, FileNotFoundError):
|
||||||
|
# Socket file exists but not active, remove and become server
|
||||||
|
os.unlink(self._socket_path)
|
||||||
|
self._is_server = True
|
||||||
|
await self._setup_server()
|
||||||
|
else:
|
||||||
|
# No socket file, become server
|
||||||
|
self._is_server = True
|
||||||
|
await self._setup_server()
|
||||||
|
except Exception as e:
|
||||||
|
# Fallback to client mode
|
||||||
|
self._is_server = False
|
||||||
|
await self._setup_client()
|
||||||
|
|
||||||
|
# Establish persistent database connection
|
||||||
|
self._db_connection = await aiosqlite.connect(self._file)
|
||||||
|
|
||||||
|
async def _setup_server(self):
|
||||||
|
"""Setup aiohttp server."""
|
||||||
|
app = web.Application()
|
||||||
|
|
||||||
|
# Add routes for all methods
|
||||||
|
app.router.add_post("/insert", self._handle_insert)
|
||||||
|
app.router.add_post("/update", self._handle_update)
|
||||||
|
app.router.add_post("/delete", self._handle_delete)
|
||||||
|
app.router.add_post("/upsert", self._handle_upsert)
|
||||||
|
app.router.add_post("/get", self._handle_get)
|
||||||
|
app.router.add_post("/find", self._handle_find)
|
||||||
|
app.router.add_post("/count", self._handle_count)
|
||||||
|
app.router.add_post("/exists", self._handle_exists)
|
||||||
|
app.router.add_post("/kv_set", self._handle_kv_set)
|
||||||
|
app.router.add_post("/kv_get", self._handle_kv_get)
|
||||||
|
app.router.add_post("/execute_raw", self._handle_execute_raw)
|
||||||
|
app.router.add_post("/query_raw", self._handle_query_raw)
|
||||||
|
app.router.add_post("/query_one", self._handle_query_one)
|
||||||
|
app.router.add_post("/create_table", self._handle_create_table)
|
||||||
|
app.router.add_post("/insert_unique", self._handle_insert_unique)
|
||||||
|
app.router.add_post("/aggregate", self._handle_aggregate)
|
||||||
|
|
||||||
|
self._runner = web.AppRunner(app)
|
||||||
|
await self._runner.setup()
|
||||||
|
self._server = web.UnixSite(self._runner, self._socket_path)
|
||||||
|
await self._server.start()
|
||||||
|
|
||||||
|
# Register cleanup
|
||||||
|
def cleanup():
|
||||||
|
if os.path.exists(self._socket_path):
|
||||||
|
os.unlink(self._socket_path)
|
||||||
|
|
||||||
|
atexit.register(cleanup)
|
||||||
|
|
||||||
|
async def _setup_client(self):
|
||||||
|
"""Setup aiohttp client."""
|
||||||
|
connector = aiohttp.UnixConnector(path=self._socket_path)
|
||||||
|
self._client_session = aiohttp.ClientSession(connector=connector)
|
||||||
|
|
||||||
|
async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any:
|
||||||
|
"""Make HTTP request to server."""
|
||||||
|
if not self._client_session:
|
||||||
|
await self._setup_client()
|
||||||
|
|
||||||
|
url = f"http://localhost/{endpoint}"
|
||||||
|
async with self._client_session.post(url, json=data) as resp:
|
||||||
|
result = await resp.json()
|
||||||
|
if result.get("error"):
|
||||||
|
raise Exception(result["error"])
|
||||||
|
return result.get("result")
|
||||||
|
|
||||||
|
# Server handlers
|
||||||
|
async def _handle_insert(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_insert(
|
||||||
|
data["table"], data["args"], data.get("return_id", False)
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_update(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_update(
|
||||||
|
data["table"], data["args"], data.get("where")
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_delete(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_delete(data["table"], data.get("where"))
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_upsert(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_upsert(
|
||||||
|
data["table"], data["args"], data.get("where")
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_get(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_get(data["table"], data.get("where"))
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_find(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_find(
|
||||||
|
data["table"],
|
||||||
|
data.get("where"),
|
||||||
|
limit=data.get("limit", 0),
|
||||||
|
offset=data.get("offset", 0),
|
||||||
|
order_by=data.get("order_by"),
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_count(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_count(data["table"], data.get("where"))
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_exists(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_exists(data["table"], data["where"])
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_kv_set(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
await self._server_kv_set(
|
||||||
|
data["key"], data["value"], table=data.get("table")
|
||||||
|
)
|
||||||
|
return web.json_response({"result": None})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_kv_get(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_kv_get(
|
||||||
|
data["key"], default=data.get("default"), table=data.get("table")
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_execute_raw(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_execute_raw(
|
||||||
|
data["sql"],
|
||||||
|
tuple(data.get("params", [])) if data.get("params") else None,
|
||||||
|
)
|
||||||
|
return web.json_response(
|
||||||
|
{"result": result.rowcount if hasattr(result, "rowcount") else None}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_query_raw(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_query_raw(
|
||||||
|
data["sql"],
|
||||||
|
tuple(data.get("params", [])) if data.get("params") else None,
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_query_one(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_query_one(
|
||||||
|
data["sql"],
|
||||||
|
tuple(data.get("params", [])) if data.get("params") else None,
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_create_table(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
await self._server_create_table(
|
||||||
|
data["table"], data["schema"], data.get("constraints")
|
||||||
|
)
|
||||||
|
return web.json_response({"result": None})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_insert_unique(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_insert_unique(
|
||||||
|
data["table"], data["args"], data["unique_fields"]
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_aggregate(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_aggregate(
|
||||||
|
data["table"],
|
||||||
|
data["function"],
|
||||||
|
data.get("column", "*"),
|
||||||
|
data.get("where"),
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
# Public methods that delegate to server or client
|
||||||
|
async def insert(
|
||||||
|
self, table: str, args: Dict[str, Any], return_id: bool = False
|
||||||
|
) -> Union[str, int]:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_insert(table, args, return_id)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"insert", {"table": table, "args": args, "return_id": return_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update(
|
||||||
|
self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
|
||||||
|
) -> int:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_update(table, args, where)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"update", {"table": table, "args": args, "where": where}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_delete(table, where)
|
||||||
|
else:
|
||||||
|
return await self._make_request("delete", {"table": table, "where": where})
|
||||||
|
|
||||||
|
async def upsert(
|
||||||
|
self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
|
||||||
|
) -> str | None:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_upsert(table, args, where)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"upsert", {"table": table, "args": args, "where": where}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self, table: str, where: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_get(table, where)
|
||||||
|
else:
|
||||||
|
return await self._make_request("get", {"table": table, "where": where})
|
||||||
|
|
||||||
|
async def find(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
where: Optional[Dict[str, Any]] = None,
|
||||||
|
*,
|
||||||
|
limit: int = 0,
|
||||||
|
offset: int = 0,
|
||||||
|
order_by: Optional[str] = None,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_find(
|
||||||
|
table, where, limit=limit, offset=offset, order_by=order_by
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"find",
|
||||||
|
{
|
||||||
|
"table": table,
|
||||||
|
"where": where,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
"order_by": order_by,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_count(table, where)
|
||||||
|
else:
|
||||||
|
return await self._make_request("count", {"table": table, "where": where})
|
||||||
|
|
||||||
|
async def exists(self, table: str, where: Dict[str, Any]) -> bool:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_exists(table, where)
|
||||||
|
else:
|
||||||
|
return await self._make_request("exists", {"table": table, "where": where})
|
||||||
|
|
||||||
|
async def kv_set(self, key: str, value: Any, *, table: str | None = None) -> None:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_kv_set(key, value, table=table)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"kv_set", {"key": key, "value": value, "table": table}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def kv_get(
|
||||||
|
self, key: str, *, default: Any = None, table: str | None = None
|
||||||
|
) -> Any:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_kv_get(key, default=default, table=table)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"kv_get", {"key": key, "default": default, "table": table}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_execute_raw(sql, params)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"execute_raw", {"sql": sql, "params": list(params) if params else None}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def query_raw(
|
||||||
|
self, sql: str, params: Optional[Tuple] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_query_raw(sql, params)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"query_raw", {"sql": sql, "params": list(params) if params else None}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def query_one(
|
||||||
|
self, sql: str, params: Optional[Tuple] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_query_one(sql, params)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"query_one", {"sql": sql, "params": list(params) if params else None}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_table(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
schema: Dict[str, str],
|
||||||
|
constraints: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_create_table(table, schema, constraints)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"create_table",
|
||||||
|
{"table": table, "schema": schema, "constraints": constraints},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def insert_unique(
|
||||||
|
self, table: str, args: Dict[str, Any], unique_fields: List[str]
|
||||||
|
) -> Union[str, None]:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_insert_unique(table, args, unique_fields)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"insert_unique",
|
||||||
|
{"table": table, "args": args, "unique_fields": unique_fields},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aggregate(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
function: str,
|
||||||
|
column: str = "*",
|
||||||
|
where: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Any:
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_aggregate(table, function, column, where)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"aggregate",
|
||||||
|
{
|
||||||
|
"table": table,
|
||||||
|
"function": function,
|
||||||
|
"column": column,
|
||||||
|
"where": where,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def transaction(self):
|
||||||
|
"""Context manager for transactions."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return TransactionContext(self._db_connection, self._db_semaphore)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Transactions not supported in client mode")
|
||||||
|
|
||||||
|
# Server implementation methods (original logic)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _utc_iso() -> str:
|
def _utc_iso() -> str:
|
||||||
return (
|
return (
|
||||||
@ -54,8 +555,10 @@ class AsyncDataSet:
|
|||||||
|
|
||||||
columns = set()
|
columns = set()
|
||||||
try:
|
try:
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
async with db.execute(f"PRAGMA table_info({table})") as cursor:
|
async with self._db_connection.execute(
|
||||||
|
f"PRAGMA table_info({table})"
|
||||||
|
) as cursor:
|
||||||
async for row in cursor:
|
async for row in cursor:
|
||||||
columns.add(row[1]) # Column name is at index 1
|
columns.add(row[1]) # Column name is at index 1
|
||||||
self._table_columns_cache[table] = columns
|
self._table_columns_cache[table] = columns
|
||||||
@ -71,9 +574,11 @@ class AsyncDataSet:
|
|||||||
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
|
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
|
||||||
col_type = self._py_to_sqlite_type(value)
|
col_type = self._py_to_sqlite_type(value)
|
||||||
try:
|
try:
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}")
|
await self._db_connection.execute(
|
||||||
await db.commit()
|
f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
|
||||||
|
)
|
||||||
|
await self._db_connection.commit()
|
||||||
await self._invalidate_column_cache(table)
|
await self._invalidate_column_cache(table)
|
||||||
except aiosqlite.OperationalError as e:
|
except aiosqlite.OperationalError as e:
|
||||||
if "duplicate column name" in str(e).lower():
|
if "duplicate column name" in str(e).lower():
|
||||||
@ -91,15 +596,17 @@ class AsyncDataSet:
|
|||||||
cols[key] = self._py_to_sqlite_type(val)
|
cols[key] = self._py_to_sqlite_type(val)
|
||||||
|
|
||||||
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
|
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
await self._db_connection.execute(
|
||||||
await db.commit()
|
f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
|
||||||
|
)
|
||||||
|
await self._db_connection.commit()
|
||||||
await self._invalidate_column_cache(table)
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
async def _table_exists(self, table: str) -> bool:
|
async def _table_exists(self, table: str) -> bool:
|
||||||
"""Check if a table exists."""
|
"""Check if a table exists."""
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
async with db.execute(
|
async with self._db_connection.execute(
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
|
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
|
||||||
) as cursor:
|
) as cursor:
|
||||||
return await cursor.fetchone() is not None
|
return await cursor.fetchone() is not None
|
||||||
@ -127,14 +634,14 @@ class AsyncDataSet:
|
|||||||
sql: str,
|
sql: str,
|
||||||
params: Iterable[Any],
|
params: Iterable[Any],
|
||||||
col_sources: Dict[str, Any],
|
col_sources: Dict[str, Any],
|
||||||
max_retries: int = 10
|
max_retries: int = 10,
|
||||||
) -> aiosqlite.Cursor:
|
) -> aiosqlite.Cursor:
|
||||||
retries = 0
|
retries = 0
|
||||||
while retries < max_retries:
|
while retries < max_retries:
|
||||||
try:
|
try:
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
cursor = await db.execute(sql, params)
|
cursor = await self._db_connection.execute(sql, params)
|
||||||
await db.commit()
|
await self._db_connection.commit()
|
||||||
return cursor
|
return cursor
|
||||||
except aiosqlite.OperationalError as err:
|
except aiosqlite.OperationalError as err:
|
||||||
retries += 1
|
retries += 1
|
||||||
@ -163,7 +670,9 @@ class AsyncDataSet:
|
|||||||
if match:
|
if match:
|
||||||
col_name = match.group(1)
|
col_name = match.group(1)
|
||||||
if col_name in col_sources:
|
if col_name in col_sources:
|
||||||
await self._ensure_column(table, col_name, col_sources[col_name])
|
await self._ensure_column(
|
||||||
|
table, col_name, col_sources[col_name]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await self._ensure_column(table, col_name, None)
|
await self._ensure_column(table, col_name, None)
|
||||||
continue
|
continue
|
||||||
@ -171,7 +680,9 @@ class AsyncDataSet:
|
|||||||
raise
|
raise
|
||||||
raise Exception(f"Max retries ({max_retries}) exceeded")
|
raise Exception(f"Max retries ({max_retries}) exceeded")
|
||||||
|
|
||||||
async def _filter_existing_columns(self, table: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
async def _filter_existing_columns(
|
||||||
|
self, table: str, data: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Filter data to only include columns that exist in the table."""
|
"""Filter data to only include columns that exist in the table."""
|
||||||
if not await self._table_exists(table):
|
if not await self._table_exists(table):
|
||||||
return data
|
return data
|
||||||
@ -198,10 +709,14 @@ class AsyncDataSet:
|
|||||||
|
|
||||||
while retries < max_retries:
|
while retries < max_retries:
|
||||||
try:
|
try:
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
db.row_factory = aiosqlite.Row
|
self._db_connection.row_factory = aiosqlite.Row
|
||||||
async with db.execute(sql, params) as cursor:
|
async with self._db_connection.execute(sql, params) as cursor:
|
||||||
async for row in cursor:
|
# Fetch all rows while holding the semaphore
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
|
# Yield rows after releasing the semaphore
|
||||||
|
for row in rows:
|
||||||
yield dict(row)
|
yield dict(row)
|
||||||
return
|
return
|
||||||
except aiosqlite.OperationalError as err:
|
except aiosqlite.OperationalError as err:
|
||||||
@ -228,7 +743,9 @@ class AsyncDataSet:
|
|||||||
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()])
|
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()])
|
||||||
return " WHERE " + " AND ".join(clauses), list(vals)
|
return " WHERE " + " AND ".join(clauses), list(vals)
|
||||||
|
|
||||||
async def insert(self, table: str, args: Dict[str, Any], return_id: bool = False) -> Union[str, int]:
|
async def _server_insert(
|
||||||
|
self, table: str, args: Dict[str, Any], return_id: bool = False
|
||||||
|
) -> Union[str, int]:
|
||||||
"""Insert a record. If return_id=True, returns auto-incremented ID instead of UUID."""
|
"""Insert a record. If return_id=True, returns auto-incremented ID instead of UUID."""
|
||||||
uid = str(uuid4())
|
uid = str(uuid4())
|
||||||
now = self._utc_iso()
|
now = self._utc_iso()
|
||||||
@ -244,19 +761,23 @@ class AsyncDataSet:
|
|||||||
await self._ensure_table(table, record)
|
await self._ensure_table(table, record)
|
||||||
|
|
||||||
# Handle auto-increment ID if requested
|
# Handle auto-increment ID if requested
|
||||||
if return_id and 'id' not in args:
|
if return_id and "id" not in args:
|
||||||
# Ensure id column exists
|
# Ensure id column exists
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
# Add id column if it doesn't exist
|
# Add id column if it doesn't exist
|
||||||
try:
|
try:
|
||||||
await db.execute(f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT")
|
await self._db_connection.execute(
|
||||||
await db.commit()
|
f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT"
|
||||||
|
)
|
||||||
|
await self._db_connection.commit()
|
||||||
except aiosqlite.OperationalError as e:
|
except aiosqlite.OperationalError as e:
|
||||||
if "duplicate column name" not in str(e).lower():
|
if "duplicate column name" not in str(e).lower():
|
||||||
# Try without autoincrement constraint
|
# Try without autoincrement constraint
|
||||||
try:
|
try:
|
||||||
await db.execute(f"ALTER TABLE {table} ADD COLUMN id INTEGER")
|
await self._db_connection.execute(
|
||||||
await db.commit()
|
f"ALTER TABLE {table} ADD COLUMN id INTEGER"
|
||||||
|
)
|
||||||
|
await self._db_connection.commit()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -275,7 +796,7 @@ class AsyncDataSet:
|
|||||||
await self._safe_execute(table, sql, list(record.values()), record)
|
await self._safe_execute(table, sql, list(record.values()), record)
|
||||||
return uid
|
return uid
|
||||||
|
|
||||||
async def update(
|
async def _server_update(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
args: Dict[str, Any],
|
args: Dict[str, Any],
|
||||||
@ -303,7 +824,9 @@ class AsyncDataSet:
|
|||||||
cur = await self._safe_execute(table, sql, params, all_cols)
|
cur = await self._safe_execute(table, sql, params, all_cols)
|
||||||
return cur.rowcount
|
return cur.rowcount
|
||||||
|
|
||||||
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
async def _server_delete(
|
||||||
|
self, table: str, where: Optional[Dict[str, Any]] = None
|
||||||
|
) -> int:
|
||||||
# Check if table exists
|
# Check if table exists
|
||||||
if not await self._table_exists(table):
|
if not await self._table_exists(table):
|
||||||
return 0
|
return 0
|
||||||
@ -313,7 +836,7 @@ class AsyncDataSet:
|
|||||||
cur = await self._safe_execute(table, sql, where_params, where or {})
|
cur = await self._safe_execute(table, sql, where_params, where or {})
|
||||||
return cur.rowcount
|
return cur.rowcount
|
||||||
|
|
||||||
async def upsert(
|
async def _server_upsert(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
args: Dict[str, Any],
|
args: Dict[str, Any],
|
||||||
@ -321,15 +844,15 @@ class AsyncDataSet:
|
|||||||
) -> str | None:
|
) -> str | None:
|
||||||
if not args:
|
if not args:
|
||||||
raise ValueError("Nothing to update. Empty dict given.")
|
raise ValueError("Nothing to update. Empty dict given.")
|
||||||
args['updated_at'] = self._utc_iso()
|
args["updated_at"] = self._utc_iso()
|
||||||
affected = await self.update(table, args, where)
|
affected = await self._server_update(table, args, where)
|
||||||
if affected:
|
if affected:
|
||||||
rec = await self.get(table, where)
|
rec = await self._server_get(table, where)
|
||||||
return rec.get("uid") if rec else None
|
return rec.get("uid") if rec else None
|
||||||
merged = {**(where or {}), **args}
|
merged = {**(where or {}), **args}
|
||||||
return await self.insert(table, merged)
|
return await self._server_insert(table, merged)
|
||||||
|
|
||||||
async def get(
|
async def _server_get(
|
||||||
self, table: str, where: Optional[Dict[str, Any]] = None
|
self, table: str, where: Optional[Dict[str, Any]] = None
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
where_clause, where_params = self._build_where(where)
|
where_clause, where_params = self._build_where(where)
|
||||||
@ -338,7 +861,7 @@ class AsyncDataSet:
|
|||||||
return row
|
return row
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def find(
|
async def _server_find(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
where: Optional[Dict[str, Any]] = None,
|
where: Optional[Dict[str, Any]] = None,
|
||||||
@ -358,7 +881,9 @@ class AsyncDataSet:
|
|||||||
row async for row in self._safe_query(table, sql, where_params, where or {})
|
row async for row in self._safe_query(table, sql, where_params, where or {})
|
||||||
]
|
]
|
||||||
|
|
||||||
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
async def _server_count(
|
||||||
|
self, table: str, where: Optional[Dict[str, Any]] = None
|
||||||
|
) -> int:
|
||||||
# Check if table exists
|
# Check if table exists
|
||||||
if not await self._table_exists(table):
|
if not await self._table_exists(table):
|
||||||
return 0
|
return 0
|
||||||
@ -370,10 +895,10 @@ class AsyncDataSet:
|
|||||||
return next(iter(row.values()), 0)
|
return next(iter(row.values()), 0)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def exists(self, table: str, where: Dict[str, Any]) -> bool:
|
async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool:
|
||||||
return (await self.count(table, where)) > 0
|
return (await self._server_count(table, where)) > 0
|
||||||
|
|
||||||
async def kv_set(
|
async def _server_kv_set(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
value: Any,
|
value: Any,
|
||||||
@ -382,9 +907,9 @@ class AsyncDataSet:
|
|||||||
) -> None:
|
) -> None:
|
||||||
tbl = table or self._KV_TABLE
|
tbl = table or self._KV_TABLE
|
||||||
json_val = json.dumps(value, default=str)
|
json_val = json.dumps(value, default=str)
|
||||||
await self.upsert(tbl, {"value": json_val}, {"key": key})
|
await self._server_upsert(tbl, {"value": json_val}, {"key": key})
|
||||||
|
|
||||||
async def kv_get(
|
async def _server_kv_get(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
*,
|
*,
|
||||||
@ -392,7 +917,7 @@ class AsyncDataSet:
|
|||||||
table: str | None = None,
|
table: str | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
tbl = table or self._KV_TABLE
|
tbl = table or self._KV_TABLE
|
||||||
row = await self.get(tbl, {"key": key})
|
row = await self._server_get(tbl, {"key": key})
|
||||||
if not row:
|
if not row:
|
||||||
return default
|
return default
|
||||||
try:
|
try:
|
||||||
@ -400,30 +925,41 @@ class AsyncDataSet:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
|
async def _server_execute_raw(
|
||||||
|
self, sql: str, params: Optional[Tuple] = None
|
||||||
|
) -> Any:
|
||||||
"""Execute raw SQL for complex queries like JOINs."""
|
"""Execute raw SQL for complex queries like JOINs."""
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
cursor = await db.execute(sql, params or ())
|
cursor = await self._db_connection.execute(sql, params or ())
|
||||||
await db.commit()
|
await self._db_connection.commit()
|
||||||
return cursor
|
return cursor
|
||||||
|
|
||||||
async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]:
|
async def _server_query_raw(
|
||||||
|
self, sql: str, params: Optional[Tuple] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Execute raw SQL query and return results as list of dicts."""
|
"""Execute raw SQL query and return results as list of dicts."""
|
||||||
try:
|
try:
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
db.row_factory = aiosqlite.Row
|
self._db_connection.row_factory = aiosqlite.Row
|
||||||
async with db.execute(sql, params or ()) as cursor:
|
async with self._db_connection.execute(sql, params or ()) as cursor:
|
||||||
return [dict(row) async for row in cursor]
|
return [dict(row) async for row in cursor]
|
||||||
except aiosqlite.OperationalError:
|
except aiosqlite.OperationalError:
|
||||||
# Return empty list if query fails
|
# Return empty list if query fails
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def query_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Dict[str, Any]]:
|
async def _server_query_one(
|
||||||
|
self, sql: str, params: Optional[Tuple] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""Execute raw SQL query and return single result."""
|
"""Execute raw SQL query and return single result."""
|
||||||
results = await self.query_raw(sql + " LIMIT 1", params)
|
results = await self._server_query_raw(sql + " LIMIT 1", params)
|
||||||
return results[0] if results else None
|
return results[0] if results else None
|
||||||
|
|
||||||
async def create_table(self, table: str, schema: Dict[str, str], constraints: Optional[List[str]] = None):
|
async def _server_create_table(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
schema: Dict[str, str],
|
||||||
|
constraints: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
"""Create table with custom schema and constraints. Always includes default columns."""
|
"""Create table with custom schema and constraints. Always includes default columns."""
|
||||||
# Merge default columns with custom schema
|
# Merge default columns with custom schema
|
||||||
full_schema = self._DEFAULT_COLUMNS.copy()
|
full_schema = self._DEFAULT_COLUMNS.copy()
|
||||||
@ -434,25 +970,31 @@ class AsyncDataSet:
|
|||||||
columns.extend(constraints)
|
columns.extend(constraints)
|
||||||
columns_sql = ", ".join(columns)
|
columns_sql = ", ".join(columns)
|
||||||
|
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with self._db_semaphore:
|
||||||
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
await self._db_connection.execute(
|
||||||
await db.commit()
|
f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
|
||||||
|
)
|
||||||
|
await self._db_connection.commit()
|
||||||
await self._invalidate_column_cache(table)
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
async def insert_unique(self, table: str, args: Dict[str, Any], unique_fields: List[str]) -> Union[str, None]:
|
async def _server_insert_unique(
|
||||||
|
self, table: str, args: Dict[str, Any], unique_fields: List[str]
|
||||||
|
) -> Union[str, None]:
|
||||||
"""Insert with unique constraint handling. Returns uid on success, None if duplicate."""
|
"""Insert with unique constraint handling. Returns uid on success, None if duplicate."""
|
||||||
try:
|
try:
|
||||||
return await self.insert(table, args)
|
return await self._server_insert(table, args)
|
||||||
except aiosqlite.IntegrityError as e:
|
except aiosqlite.IntegrityError as e:
|
||||||
if "UNIQUE" in str(e):
|
if "UNIQUE" in str(e):
|
||||||
return None
|
return None
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def transaction(self):
|
async def _server_aggregate(
|
||||||
"""Context manager for transactions."""
|
self,
|
||||||
return TransactionContext(self._file)
|
table: str,
|
||||||
|
function: str,
|
||||||
async def aggregate(self, table: str, function: str, column: str = "*", where: Optional[Dict[str, Any]] = None) -> Any:
|
column: str = "*",
|
||||||
|
where: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Any:
|
||||||
"""Perform aggregate functions like SUM, AVG, MAX, MIN."""
|
"""Perform aggregate functions like SUM, AVG, MAX, MIN."""
|
||||||
# Check if table exists
|
# Check if table exists
|
||||||
if not await self._table_exists(table):
|
if not await self._table_exists(table):
|
||||||
@ -460,29 +1002,53 @@ class AsyncDataSet:
|
|||||||
|
|
||||||
where_clause, where_params = self._build_where(where)
|
where_clause, where_params = self._build_where(where)
|
||||||
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
|
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
|
||||||
result = await self.query_one(sql, tuple(where_params))
|
result = await self._server_query_one(sql, tuple(where_params))
|
||||||
return result['result'] if result else None
|
return result["result"] if result else None
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the connection or server."""
|
||||||
|
if not self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._client_session:
|
||||||
|
await self._client_session.close()
|
||||||
|
if self._runner:
|
||||||
|
await self._runner.cleanup()
|
||||||
|
if os.path.exists(self._socket_path) and self._is_server:
|
||||||
|
os.unlink(self._socket_path)
|
||||||
|
if self._db_connection:
|
||||||
|
await self._db_connection.close()
|
||||||
|
|
||||||
|
|
||||||
class TransactionContext:
|
class TransactionContext:
|
||||||
"""Context manager for database transactions."""
|
"""Context manager for database transactions."""
|
||||||
|
|
||||||
def __init__(self, db_file: str):
|
def __init__(
|
||||||
self.db_file = db_file
|
self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None
|
||||||
self.conn = None
|
):
|
||||||
|
self.db_connection = db_connection
|
||||||
|
self.semaphore = semaphore
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self.conn = await aiosqlite.connect(self.db_file)
|
if self.semaphore:
|
||||||
self.conn.row_factory = aiosqlite.Row
|
await self.semaphore.acquire()
|
||||||
await self.conn.execute("BEGIN")
|
try:
|
||||||
return self.conn
|
await self.db_connection.execute("BEGIN")
|
||||||
|
return self.db_connection
|
||||||
|
except:
|
||||||
|
if self.semaphore:
|
||||||
|
self.semaphore.release()
|
||||||
|
raise
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
try:
|
||||||
if exc_type is None:
|
if exc_type is None:
|
||||||
await self.conn.commit()
|
await self.db_connection.commit()
|
||||||
else:
|
else:
|
||||||
await self.conn.rollback()
|
await self.db_connection.rollback()
|
||||||
await self.conn.close()
|
finally:
|
||||||
|
if self.semaphore:
|
||||||
|
self.semaphore.release()
|
||||||
|
|
||||||
|
|
||||||
# Test cases remain the same but with additional tests for new functionality
|
# Test cases remain the same but with additional tests for new functionality
|
||||||
@ -492,9 +1058,10 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.db_path = Path("temp_test.db")
|
self.db_path = Path("temp_test.db")
|
||||||
if self.db_path.exists():
|
if self.db_path.exists():
|
||||||
self.db_path.unlink()
|
self.db_path.unlink()
|
||||||
self.connector = AsyncDataSet(str(self.db_path))
|
self.connector = AsyncDataSet(str(self.db_path), max_concurrent_queries=1)
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
async def asyncTearDown(self):
|
||||||
|
await self.connector.close()
|
||||||
if self.db_path.exists():
|
if self.db_path.exists():
|
||||||
self.db_path.unlink()
|
self.db_path.unlink()
|
||||||
|
|
||||||
@ -551,9 +1118,13 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(id2, id1 + 1)
|
self.assertEqual(id2, id1 + 1)
|
||||||
|
|
||||||
async def test_transaction(self):
|
async def test_transaction(self):
|
||||||
|
await self.connector._ensure_initialized()
|
||||||
|
if self.connector._is_server:
|
||||||
async with self.connector.transaction() as conn:
|
async with self.connector.transaction() as conn:
|
||||||
await conn.execute("INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
|
await conn.execute(
|
||||||
("test-uid", "John", 30, "2024-01-01", "2024-01-01"))
|
"INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
("test-uid", "John", 30, "2024-01-01", "2024-01-01"),
|
||||||
|
)
|
||||||
# Transaction will be committed
|
# Transaction will be committed
|
||||||
|
|
||||||
rec = await self.connector.get("people", {"name": "John"})
|
rec = await self.connector.get("people", {"name": "John"})
|
||||||
@ -564,7 +1135,7 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
|||||||
"id": "INTEGER PRIMARY KEY AUTOINCREMENT",
|
"id": "INTEGER PRIMARY KEY AUTOINCREMENT",
|
||||||
"username": "TEXT NOT NULL",
|
"username": "TEXT NOT NULL",
|
||||||
"email": "TEXT NOT NULL",
|
"email": "TEXT NOT NULL",
|
||||||
"score": "INTEGER DEFAULT 0"
|
"score": "INTEGER DEFAULT 0",
|
||||||
}
|
}
|
||||||
constraints = ["UNIQUE(username)", "UNIQUE(email)"]
|
constraints = ["UNIQUE(username)", "UNIQUE(email)"]
|
||||||
|
|
||||||
@ -574,7 +1145,7 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
|||||||
result = await self.connector.insert_unique(
|
result = await self.connector.insert_unique(
|
||||||
"users",
|
"users",
|
||||||
{"username": "john", "email": "john@example.com"},
|
{"username": "john", "email": "john@example.com"},
|
||||||
["username", "email"]
|
["username", "email"],
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(result)
|
self.assertIsNotNone(result)
|
||||||
|
|
||||||
@ -582,7 +1153,7 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
|||||||
result = await self.connector.insert_unique(
|
result = await self.connector.insert_unique(
|
||||||
"users",
|
"users",
|
||||||
{"username": "john", "email": "different@example.com"},
|
{"username": "john", "email": "different@example.com"},
|
||||||
["username", "email"]
|
["username", "email"],
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(result)
|
||||||
|
|
||||||
@ -593,14 +1164,20 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertIsNone(await self.connector.get("nonexistent"))
|
self.assertIsNone(await self.connector.get("nonexistent"))
|
||||||
self.assertFalse(await self.connector.exists("nonexistent", {"id": 1}))
|
self.assertFalse(await self.connector.exists("nonexistent", {"id": 1}))
|
||||||
self.assertEqual(await self.connector.delete("nonexistent"), 0)
|
self.assertEqual(await self.connector.delete("nonexistent"), 0)
|
||||||
self.assertEqual(await self.connector.update("nonexistent", {"name": "test"}), 0)
|
self.assertEqual(
|
||||||
|
await self.connector.update("nonexistent", {"name": "test"}), 0
|
||||||
|
)
|
||||||
|
|
||||||
async def test_auto_column_creation(self):
|
async def test_auto_column_creation(self):
|
||||||
# Insert with new columns that don't exist yet
|
# Insert with new columns that don't exist yet
|
||||||
await self.connector.insert("dynamic", {"col1": "value1", "col2": 42, "col3": 3.14})
|
await self.connector.insert(
|
||||||
|
"dynamic", {"col1": "value1", "col2": 42, "col3": 3.14}
|
||||||
|
)
|
||||||
|
|
||||||
# Add more columns in next insert
|
# Add more columns in next insert
|
||||||
await self.connector.insert("dynamic", {"col1": "value2", "col4": True, "col5": None})
|
await self.connector.insert(
|
||||||
|
"dynamic", {"col1": "value2", "col4": True, "col5": None}
|
||||||
|
)
|
||||||
|
|
||||||
# All records should be retrievable
|
# All records should be retrievable
|
||||||
records = await self.connector.find("dynamic")
|
records = await self.connector.find("dynamic")
|
||||||
|
156
ads_test_server.py
Normal file
156
ads_test_server.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
from ads import AsyncDataSet
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Initialize AsyncDataSet
|
||||||
|
db = AsyncDataSet(":memory:", socket_path="test_workers.sock")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/write")
|
||||||
|
async def write_data():
|
||||||
|
"""Write process and thread ID to database."""
|
||||||
|
process_id = os.getpid()
|
||||||
|
thread_id = threading.get_ident()
|
||||||
|
timestamp = datetime.now().isoformat()
|
||||||
|
|
||||||
|
# Insert data with process and thread info
|
||||||
|
await db.insert(
|
||||||
|
"requests",
|
||||||
|
{
|
||||||
|
"process_id": process_id,
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"endpoint": "/write",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"process_id": process_id,
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/data", response_class=HTMLResponse)
|
||||||
|
async def get_all_data():
|
||||||
|
"""Return all request data as HTML table."""
|
||||||
|
# Get all records ordered by timestamp
|
||||||
|
records = await db.find("requests", order_by="timestamp DESC")
|
||||||
|
|
||||||
|
# Build HTML table
|
||||||
|
html = """
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Worker Test Results</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: Arial, sans-serif; margin: 20px; }
|
||||||
|
table { border-collapse: collapse; width: 100%; }
|
||||||
|
th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
|
||||||
|
th { background-color: #4CAF50; color: white; }
|
||||||
|
tr:nth-child(even) { background-color: #f2f2f2; }
|
||||||
|
h1 { color: #333; }
|
||||||
|
.stats { margin: 20px 0; padding: 10px; background-color: #e7f3ff; border-radius: 5px; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>AsyncDataSet Multi-Worker Test Results</h1>
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Calculate statistics
|
||||||
|
if records:
|
||||||
|
unique_processes = len(set(r["process_id"] for r in records))
|
||||||
|
unique_threads = len(set((r["process_id"], r["thread_id"]) for r in records))
|
||||||
|
total_requests = len(records)
|
||||||
|
|
||||||
|
html += f"""
|
||||||
|
<div class="stats">
|
||||||
|
<strong>Statistics:</strong><br>
|
||||||
|
Total Requests: {total_requests}<br>
|
||||||
|
Unique Processes: {unique_processes}<br>
|
||||||
|
Unique Process-Thread Combinations: {unique_threads}<br>
|
||||||
|
</div>
|
||||||
|
"""
|
||||||
|
|
||||||
|
html += """
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<th>ID</th>
|
||||||
|
<th>Process ID</th>
|
||||||
|
<th>Thread ID</th>
|
||||||
|
<th>Timestamp</th>
|
||||||
|
<th>Endpoint</th>
|
||||||
|
</tr>
|
||||||
|
"""
|
||||||
|
|
||||||
|
for record in records:
|
||||||
|
html += f"""
|
||||||
|
<tr>
|
||||||
|
<td>{record.get('uid', 'N/A')[:8]}...</td>
|
||||||
|
<td>{record.get('process_id', 'N/A')}</td>
|
||||||
|
<td>{record.get('thread_id', 'N/A')}</td>
|
||||||
|
<td>{record.get('timestamp', 'N/A')}</td>
|
||||||
|
<td>{record.get('endpoint', 'N/A')}</td>
|
||||||
|
</tr>
|
||||||
|
"""
|
||||||
|
|
||||||
|
html += """
|
||||||
|
</table>
|
||||||
|
<p style="margin-top: 20px;">
|
||||||
|
<a href="/write" target="_blank">Make a write request</a> |
|
||||||
|
<a href="/data">Refresh</a>
|
||||||
|
</p>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
|
|
||||||
|
return html
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""Root endpoint showing process and thread info."""
|
||||||
|
process_id = os.getpid()
|
||||||
|
thread_id = threading.get_ident()
|
||||||
|
|
||||||
|
# Log this request too
|
||||||
|
await db.insert(
|
||||||
|
"requests",
|
||||||
|
{
|
||||||
|
"process_id": process_id,
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"endpoint": "/",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": "FastAPI Multi-Worker Test Server",
|
||||||
|
"current_process_id": process_id,
|
||||||
|
"current_thread_id": thread_id,
|
||||||
|
"endpoints": {
|
||||||
|
"/": "This endpoint",
|
||||||
|
"/write": "Write process/thread data (POST)",
|
||||||
|
"/data": "View all data as HTML table (GET)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run with multiple workers
|
||||||
|
# You can also run this from command line:
|
||||||
|
# uvicorn test_server:app --workers 4 --host 0.0.0.0 --port 8000
|
||||||
|
uvicorn.run(
|
||||||
|
"ads_test_server:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=8000,
|
||||||
|
workers=50, # Run with 4 worker processes
|
||||||
|
reload=False,
|
||||||
|
)
|
656
adsc.py
Normal file
656
adsc.py
Normal file
@ -0,0 +1,656 @@
|
|||||||
|
import re
|
||||||
|
import json
|
||||||
|
from uuid import uuid4
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
AsyncGenerator,
|
||||||
|
Union,
|
||||||
|
Tuple,
|
||||||
|
Set,
|
||||||
|
)
|
||||||
|
from pathlib import Path
|
||||||
|
import aiosqlite
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncDataSet:
|
||||||
|
|
||||||
|
_KV_TABLE = "__kv_store"
|
||||||
|
_DEFAULT_COLUMNS = {
|
||||||
|
"uid": "TEXT PRIMARY KEY",
|
||||||
|
"created_at": "TEXT",
|
||||||
|
"updated_at": "TEXT",
|
||||||
|
"deleted_at": "TEXT",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, file: str):
|
||||||
|
self._file = file
|
||||||
|
self._table_columns_cache: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _utc_iso() -> str:
|
||||||
|
return (
|
||||||
|
datetime.now(timezone.utc)
|
||||||
|
.replace(microsecond=0)
|
||||||
|
.isoformat()
|
||||||
|
.replace("+00:00", "Z")
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _py_to_sqlite_type(value: Any) -> str:
|
||||||
|
if value is None:
|
||||||
|
return "TEXT"
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return "INTEGER"
|
||||||
|
if isinstance(value, int):
|
||||||
|
return "INTEGER"
|
||||||
|
if isinstance(value, float):
|
||||||
|
return "REAL"
|
||||||
|
if isinstance(value, (bytes, bytearray, memoryview)):
|
||||||
|
return "BLOB"
|
||||||
|
return "TEXT"
|
||||||
|
|
||||||
|
async def _get_table_columns(self, table: str) -> Set[str]:
|
||||||
|
"""Get actual columns that exist in the table."""
|
||||||
|
if table in self._table_columns_cache:
|
||||||
|
return self._table_columns_cache[table]
|
||||||
|
|
||||||
|
columns = set()
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
async with db.execute(f"PRAGMA table_info({table})") as cursor:
|
||||||
|
async for row in cursor:
|
||||||
|
columns.add(row[1]) # Column name is at index 1
|
||||||
|
self._table_columns_cache[table] = columns
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return columns
|
||||||
|
|
||||||
|
async def _invalidate_column_cache(self, table: str):
|
||||||
|
"""Invalidate column cache for a table."""
|
||||||
|
if table in self._table_columns_cache:
|
||||||
|
del self._table_columns_cache[table]
|
||||||
|
|
||||||
|
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
|
||||||
|
col_type = self._py_to_sqlite_type(value)
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}")
|
||||||
|
await db.commit()
|
||||||
|
await self._invalidate_column_cache(table)
|
||||||
|
except aiosqlite.OperationalError as e:
|
||||||
|
if "duplicate column name" in str(e).lower():
|
||||||
|
pass # Column already exists
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
|
||||||
|
# Always include default columns
|
||||||
|
cols = self._DEFAULT_COLUMNS.copy()
|
||||||
|
|
||||||
|
# Add columns from col_sources
|
||||||
|
for key, val in col_sources.items():
|
||||||
|
if key not in cols:
|
||||||
|
cols[key] = self._py_to_sqlite_type(val)
|
||||||
|
|
||||||
|
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
||||||
|
await db.commit()
|
||||||
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
|
async def _table_exists(self, table: str) -> bool:
|
||||||
|
"""Check if a table exists."""
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
async with db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
|
||||||
|
) as cursor:
|
||||||
|
return await cursor.fetchone() is not None
|
||||||
|
|
||||||
|
_RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
|
||||||
|
_RE_NO_TABLE = re.compile(r"no such table: (\w+)")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _missing_column_from_error(
|
||||||
|
cls, err: aiosqlite.OperationalError
|
||||||
|
) -> Optional[str]:
|
||||||
|
m = cls._RE_NO_COLUMN.search(str(err))
|
||||||
|
return m.group(1) if m else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _missing_table_from_error(
|
||||||
|
cls, err: aiosqlite.OperationalError
|
||||||
|
) -> Optional[str]:
|
||||||
|
m = cls._RE_NO_TABLE.search(str(err))
|
||||||
|
return m.group(1) if m else None
|
||||||
|
|
||||||
|
async def _safe_execute(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
sql: str,
|
||||||
|
params: Iterable[Any],
|
||||||
|
col_sources: Dict[str, Any],
|
||||||
|
max_retries: int = 10,
|
||||||
|
) -> aiosqlite.Cursor:
|
||||||
|
retries = 0
|
||||||
|
while retries < max_retries:
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
cursor = await db.execute(sql, params)
|
||||||
|
await db.commit()
|
||||||
|
return cursor
|
||||||
|
except aiosqlite.OperationalError as err:
|
||||||
|
retries += 1
|
||||||
|
err_str = str(err).lower()
|
||||||
|
|
||||||
|
# Handle missing column
|
||||||
|
col = self._missing_column_from_error(err)
|
||||||
|
if col:
|
||||||
|
if col in col_sources:
|
||||||
|
await self._ensure_column(table, col, col_sources[col])
|
||||||
|
else:
|
||||||
|
# Column not in sources, ensure it with NULL/TEXT type
|
||||||
|
await self._ensure_column(table, col, None)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle missing table
|
||||||
|
tbl = self._missing_table_from_error(err)
|
||||||
|
if tbl:
|
||||||
|
await self._ensure_table(tbl, col_sources)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle other column-related errors
|
||||||
|
if "has no column named" in err_str:
|
||||||
|
# Extract column name differently
|
||||||
|
match = re.search(r"table \w+ has no column named (\w+)", err_str)
|
||||||
|
if match:
|
||||||
|
col_name = match.group(1)
|
||||||
|
if col_name in col_sources:
|
||||||
|
await self._ensure_column(
|
||||||
|
table, col_name, col_sources[col_name]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self._ensure_column(table, col_name, None)
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise
|
||||||
|
raise Exception(f"Max retries ({max_retries}) exceeded")
|
||||||
|
|
||||||
|
async def _filter_existing_columns(
|
||||||
|
self, table: str, data: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Filter data to only include columns that exist in the table."""
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return data
|
||||||
|
|
||||||
|
existing_columns = await self._get_table_columns(table)
|
||||||
|
if not existing_columns:
|
||||||
|
return data
|
||||||
|
|
||||||
|
return {k: v for k, v in data.items() if k in existing_columns}
|
||||||
|
|
||||||
|
async def _safe_query(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
sql: str,
|
||||||
|
params: Iterable[Any],
|
||||||
|
col_sources: Dict[str, Any],
|
||||||
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
|
# Check if table exists first
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return
|
||||||
|
|
||||||
|
max_retries = 10
|
||||||
|
retries = 0
|
||||||
|
|
||||||
|
while retries < max_retries:
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
async with db.execute(sql, params) as cursor:
|
||||||
|
async for row in cursor:
|
||||||
|
yield dict(row)
|
||||||
|
return
|
||||||
|
except aiosqlite.OperationalError as err:
|
||||||
|
retries += 1
|
||||||
|
err_str = str(err).lower()
|
||||||
|
|
||||||
|
# Handle missing table
|
||||||
|
tbl = self._missing_table_from_error(err)
|
||||||
|
if tbl:
|
||||||
|
# For queries, if table doesn't exist, just return empty
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle missing column in WHERE clause or SELECT
|
||||||
|
if "no such column" in err_str:
|
||||||
|
# For queries with missing columns, return empty
|
||||||
|
return
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
|
||||||
|
if not where:
|
||||||
|
return "", []
|
||||||
|
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()])
|
||||||
|
return " WHERE " + " AND ".join(clauses), list(vals)
|
||||||
|
|
||||||
|
async def insert(
|
||||||
|
self, table: str, args: Dict[str, Any], return_id: bool = False
|
||||||
|
) -> Union[str, int]:
|
||||||
|
"""Insert a record. If return_id=True, returns auto-incremented ID instead of UUID."""
|
||||||
|
uid = str(uuid4())
|
||||||
|
now = self._utc_iso()
|
||||||
|
record = {
|
||||||
|
"uid": uid,
|
||||||
|
"created_at": now,
|
||||||
|
"updated_at": now,
|
||||||
|
"deleted_at": None,
|
||||||
|
**args,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ensure table exists with all needed columns
|
||||||
|
await self._ensure_table(table, record)
|
||||||
|
|
||||||
|
# Handle auto-increment ID if requested
|
||||||
|
if return_id and "id" not in args:
|
||||||
|
# Ensure id column exists
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
# Add id column if it doesn't exist
|
||||||
|
try:
|
||||||
|
await db.execute(
|
||||||
|
f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT"
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
except aiosqlite.OperationalError as e:
|
||||||
|
if "duplicate column name" not in str(e).lower():
|
||||||
|
# Try without autoincrement constraint
|
||||||
|
try:
|
||||||
|
await db.execute(
|
||||||
|
f"ALTER TABLE {table} ADD COLUMN id INTEGER"
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
|
# Insert and get lastrowid
|
||||||
|
cols = "`" + "`, `".join(record.keys()) + "`"
|
||||||
|
qs = ", ".join(["?"] * len(record))
|
||||||
|
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
|
||||||
|
cursor = await self._safe_execute(table, sql, list(record.values()), record)
|
||||||
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
cols = "`" + "`, `".join(record) + "`"
|
||||||
|
qs = ", ".join(["?"] * len(record))
|
||||||
|
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
|
||||||
|
await self._safe_execute(table, sql, list(record.values()), record)
|
||||||
|
return uid
|
||||||
|
|
||||||
|
async def update(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
args: Dict[str, Any],
|
||||||
|
where: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> int:
|
||||||
|
if not args:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Check if table exists
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
args["updated_at"] = self._utc_iso()
|
||||||
|
|
||||||
|
# Ensure all columns exist
|
||||||
|
all_cols = {**args, **(where or {})}
|
||||||
|
await self._ensure_table(table, all_cols)
|
||||||
|
for col, val in all_cols.items():
|
||||||
|
await self._ensure_column(table, col, val)
|
||||||
|
|
||||||
|
set_clause = ", ".join(f"`{k}` = ?" for k in args)
|
||||||
|
where_clause, where_params = self._build_where(where)
|
||||||
|
sql = f"UPDATE {table} SET {set_clause}{where_clause}"
|
||||||
|
params = list(args.values()) + where_params
|
||||||
|
cur = await self._safe_execute(table, sql, params, all_cols)
|
||||||
|
return cur.rowcount
|
||||||
|
|
||||||
|
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||||
|
# Check if table exists
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
where_clause, where_params = self._build_where(where)
|
||||||
|
sql = f"DELETE FROM {table}{where_clause}"
|
||||||
|
cur = await self._safe_execute(table, sql, where_params, where or {})
|
||||||
|
return cur.rowcount
|
||||||
|
|
||||||
|
async def upsert(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
args: Dict[str, Any],
|
||||||
|
where: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> str | None:
|
||||||
|
if not args:
|
||||||
|
raise ValueError("Nothing to update. Empty dict given.")
|
||||||
|
args["updated_at"] = self._utc_iso()
|
||||||
|
affected = await self.update(table, args, where)
|
||||||
|
if affected:
|
||||||
|
rec = await self.get(table, where)
|
||||||
|
return rec.get("uid") if rec else None
|
||||||
|
merged = {**(where or {}), **args}
|
||||||
|
return await self.insert(table, merged)
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self, table: str, where: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
where_clause, where_params = self._build_where(where)
|
||||||
|
sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
|
||||||
|
async for row in self._safe_query(table, sql, where_params, where or {}):
|
||||||
|
return row
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def find(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
where: Optional[Dict[str, Any]] = None,
|
||||||
|
*,
|
||||||
|
limit: int = 0,
|
||||||
|
offset: int = 0,
|
||||||
|
order_by: Optional[str] = None,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Find records with optional ordering."""
|
||||||
|
where_clause, where_params = self._build_where(where)
|
||||||
|
order_clause = f" ORDER BY {order_by}" if order_by else ""
|
||||||
|
extra = (f" LIMIT {limit}" if limit else "") + (
|
||||||
|
f" OFFSET {offset}" if offset else ""
|
||||||
|
)
|
||||||
|
sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}"
|
||||||
|
return [
|
||||||
|
row async for row in self._safe_query(table, sql, where_params, where or {})
|
||||||
|
]
|
||||||
|
|
||||||
|
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||||
|
# Check if table exists
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
where_clause, where_params = self._build_where(where)
|
||||||
|
sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
|
||||||
|
gen = self._safe_query(table, sql, where_params, where or {})
|
||||||
|
async for row in gen:
|
||||||
|
return next(iter(row.values()), 0)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def exists(self, table: str, where: Dict[str, Any]) -> bool:
|
||||||
|
return (await self.count(table, where)) > 0
|
||||||
|
|
||||||
|
async def kv_set(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Any,
|
||||||
|
*,
|
||||||
|
table: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
tbl = table or self._KV_TABLE
|
||||||
|
json_val = json.dumps(value, default=str)
|
||||||
|
await self.upsert(tbl, {"value": json_val}, {"key": key})
|
||||||
|
|
||||||
|
async def kv_get(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
*,
|
||||||
|
default: Any = None,
|
||||||
|
table: str | None = None,
|
||||||
|
) -> Any:
|
||||||
|
tbl = table or self._KV_TABLE
|
||||||
|
row = await self.get(tbl, {"key": key})
|
||||||
|
if not row:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return json.loads(row["value"])
|
||||||
|
except Exception:
|
||||||
|
return default
|
||||||
|
|
||||||
|
    async def execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
        """Execute raw SQL for complex queries like JOINs.

        Opens a fresh connection, executes *sql* with *params*, commits,
        and returns the resulting cursor.

        NOTE(review): the returned cursor belongs to a connection that is
        closed when the ``async with`` block exits, so fetching rows from
        it afterwards is unlikely to work — confirm callers only read
        already-populated attributes such as ``rowcount``/``lastrowid``.
        """
        async with aiosqlite.connect(self._file) as db:
            cursor = await db.execute(sql, params or ())
            await db.commit()
            return cursor
|
||||||
|
|
||||||
|
async def query_raw(
|
||||||
|
self, sql: str, params: Optional[Tuple] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Execute raw SQL query and return results as list of dicts."""
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
async with db.execute(sql, params or ()) as cursor:
|
||||||
|
return [dict(row) async for row in cursor]
|
||||||
|
except aiosqlite.OperationalError:
|
||||||
|
# Return empty list if query fails
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def query_one(
|
||||||
|
self, sql: str, params: Optional[Tuple] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Execute raw SQL query and return single result."""
|
||||||
|
results = await self.query_raw(sql + " LIMIT 1", params)
|
||||||
|
return results[0] if results else None
|
||||||
|
|
||||||
|
async def create_table(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
schema: Dict[str, str],
|
||||||
|
constraints: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""Create table with custom schema and constraints. Always includes default columns."""
|
||||||
|
# Merge default columns with custom schema
|
||||||
|
full_schema = self._DEFAULT_COLUMNS.copy()
|
||||||
|
full_schema.update(schema)
|
||||||
|
|
||||||
|
columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
|
||||||
|
if constraints:
|
||||||
|
columns.extend(constraints)
|
||||||
|
columns_sql = ", ".join(columns)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
||||||
|
await db.commit()
|
||||||
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
|
async def insert_unique(
|
||||||
|
self, table: str, args: Dict[str, Any], unique_fields: List[str]
|
||||||
|
) -> Union[str, None]:
|
||||||
|
"""Insert with unique constraint handling. Returns uid on success, None if duplicate."""
|
||||||
|
try:
|
||||||
|
return await self.insert(table, args)
|
||||||
|
except aiosqlite.IntegrityError as e:
|
||||||
|
if "UNIQUE" in str(e):
|
||||||
|
return None
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def transaction(self):
|
||||||
|
"""Context manager for transactions."""
|
||||||
|
return TransactionContext(self._file)
|
||||||
|
|
||||||
|
async def aggregate(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
function: str,
|
||||||
|
column: str = "*",
|
||||||
|
where: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Perform aggregate functions like SUM, AVG, MAX, MIN."""
|
||||||
|
# Check if table exists
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return None
|
||||||
|
|
||||||
|
where_clause, where_params = self._build_where(where)
|
||||||
|
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
|
||||||
|
result = await self.query_one(sql, tuple(where_params))
|
||||||
|
return result["result"] if result else None
|
||||||
|
|
||||||
|
|
||||||
|
class TransactionContext:
    """Async context manager wrapping a single SQLite transaction.

    Commits on clean exit, rolls back when the body raised, and always
    closes the connection.
    """

    def __init__(self, db_file: str):
        # Path to the SQLite database file.
        self.db_file = db_file
        # Live aiosqlite connection while inside the context, else None.
        self.conn = None

    async def __aenter__(self):
        self.conn = await aiosqlite.connect(self.db_file)
        self.conn.row_factory = aiosqlite.Row
        await self.conn.execute("BEGIN")
        return self.conn

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # BUG FIX: close the connection even when commit()/rollback()
        # itself raises, so a failed commit cannot leak the connection.
        try:
            if exc_type is None:
                await self.conn.commit()
            else:
                await self.conn.rollback()
        finally:
            await self.conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases remain the same but with additional tests for new functionality
|
||||||
|
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
    """End-to-end tests for AsyncDataSet against a throwaway SQLite file."""

    async def asyncSetUp(self):
        # Fresh database file for every test; remove any stale leftover.
        self.db_path = Path("temp_test.db")
        if self.db_path.exists():
            self.db_path.unlink()
        self.connector = AsyncDataSet(str(self.db_path))

    async def asyncTearDown(self):
        # Remove the database file so tests stay independent.
        if self.db_path.exists():
            self.db_path.unlink()

    async def test_insert_and_get(self):
        """Insert creates the table implicitly and get finds the row."""
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertIsNotNone(rec)
        self.assertEqual(rec["name"], "John Doe")

    async def test_get_nonexistent(self):
        """get returns None when no row matches."""
        result = await self.connector.get("people", {"name": "Jane Doe"})
        self.assertIsNone(result)

    async def test_update(self):
        """update changes the matched row in place."""
        await self.connector.insert("people", {"name": "John Doe", "age": 30})
        await self.connector.update("people", {"age": 31}, {"name": "John Doe"})
        rec = await self.connector.get("people", {"name": "John Doe"})
        self.assertEqual(rec["age"], 31)

    async def test_order_by(self):
        """find honors the order_by clause."""
        await self.connector.insert("people", {"name": "Alice", "age": 25})
        await self.connector.insert("people", {"name": "Bob", "age": 30})
        await self.connector.insert("people", {"name": "Charlie", "age": 20})

        results = await self.connector.find("people", order_by="age ASC")
        self.assertEqual(results[0]["name"], "Charlie")
        self.assertEqual(results[-1]["name"], "Bob")

    async def test_raw_query(self):
        """query_raw supports parameterized SELECTs."""
        await self.connector.insert("people", {"name": "John", "age": 30})
        await self.connector.insert("people", {"name": "Jane", "age": 25})

        results = await self.connector.query_raw(
            "SELECT * FROM people WHERE age > ?", (26,)
        )
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["name"], "John")

    async def test_aggregate(self):
        """aggregate computes AVG and MAX over a column."""
        await self.connector.insert("people", {"name": "John", "age": 30})
        await self.connector.insert("people", {"name": "Jane", "age": 25})
        await self.connector.insert("people", {"name": "Bob", "age": 35})

        avg_age = await self.connector.aggregate("people", "AVG", "age")
        self.assertEqual(avg_age, 30)

        max_age = await self.connector.aggregate("people", "MAX", "age")
        self.assertEqual(max_age, 35)

    async def test_insert_with_auto_id(self):
        # Test auto-increment ID functionality
        id1 = await self.connector.insert("posts", {"title": "First"}, return_id=True)
        id2 = await self.connector.insert("posts", {"title": "Second"}, return_id=True)
        self.assertEqual(id2, id1 + 1)

    async def test_transaction(self):
        """transaction() yields a connection usable via async with."""
        async with self.connector.transaction() as conn:
            await conn.execute(
                "INSERT INTO people (uid, name, age, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
                ("test-uid", "John", 30, "2024-01-01", "2024-01-01"),
            )
            # Transaction will be committed

        rec = await self.connector.get("people", {"name": "John"})
        self.assertIsNotNone(rec)

    async def test_create_custom_table(self):
        """create_table applies custom schema and UNIQUE constraints."""
        schema = {
            "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
            "username": "TEXT NOT NULL",
            "email": "TEXT NOT NULL",
            "score": "INTEGER DEFAULT 0",
        }
        constraints = ["UNIQUE(username)", "UNIQUE(email)"]

        await self.connector.create_table("users", schema, constraints)

        # Test that table was created with constraints
        result = await self.connector.insert_unique(
            "users",
            {"username": "john", "email": "john@example.com"},
            ["username", "email"],
        )
        self.assertIsNotNone(result)

        # Test duplicate insert
        result = await self.connector.insert_unique(
            "users",
            {"username": "john", "email": "different@example.com"},
            ["username", "email"],
        )
        self.assertIsNone(result)

    async def test_missing_table_operations(self):
        # Test operations on non-existent tables
        self.assertEqual(await self.connector.count("nonexistent"), 0)
        self.assertEqual(await self.connector.find("nonexistent"), [])
        self.assertIsNone(await self.connector.get("nonexistent"))
        self.assertFalse(await self.connector.exists("nonexistent", {"id": 1}))
        self.assertEqual(await self.connector.delete("nonexistent"), 0)
        self.assertEqual(
            await self.connector.update("nonexistent", {"name": "test"}), 0
        )

    async def test_auto_column_creation(self):
        # Insert with new columns that don't exist yet
        await self.connector.insert(
            "dynamic", {"col1": "value1", "col2": 42, "col3": 3.14}
        )

        # Add more columns in next insert
        await self.connector.insert(
            "dynamic", {"col1": "value2", "col4": True, "col5": None}
        )

        # All records should be retrievable
        records = await self.connector.find("dynamic")
        self.assertEqual(len(records), 2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the embedded test suite when this module is executed directly.
    unittest.main()
|
Loading…
Reference in New Issue
Block a user