Working version, perfect.
This commit is contained in:
parent
88474ba4df
commit
efcc4a14b6
237
ads.py
237
ads.py
@ -2,7 +2,7 @@ import re
|
|||||||
import json
|
import json
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Dict, Iterable, List, Optional, AsyncGenerator, Union, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, AsyncGenerator, Union, Tuple, Set
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import unittest
|
import unittest
|
||||||
@ -13,9 +13,16 @@ import asyncio
|
|||||||
class AsyncDataSet:
|
class AsyncDataSet:
|
||||||
|
|
||||||
_KV_TABLE = "__kv_store"
|
_KV_TABLE = "__kv_store"
|
||||||
|
_DEFAULT_COLUMNS = {
|
||||||
|
"uid": "TEXT PRIMARY KEY",
|
||||||
|
"created_at": "TEXT",
|
||||||
|
"updated_at": "TEXT",
|
||||||
|
"deleted_at": "TEXT",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, file: str):
|
def __init__(self, file: str):
|
||||||
self._file = file
|
self._file = file
|
||||||
|
self._table_columns_cache: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _utc_iso() -> str:
|
def _utc_iso() -> str:
|
||||||
@ -40,26 +47,62 @@ class AsyncDataSet:
|
|||||||
return "BLOB"
|
return "BLOB"
|
||||||
return "TEXT"
|
return "TEXT"
|
||||||
|
|
||||||
|
async def _get_table_columns(self, table: str) -> Set[str]:
|
||||||
|
"""Get actual columns that exist in the table."""
|
||||||
|
if table in self._table_columns_cache:
|
||||||
|
return self._table_columns_cache[table]
|
||||||
|
|
||||||
|
columns = set()
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
async with db.execute(f"PRAGMA table_info({table})") as cursor:
|
||||||
|
async for row in cursor:
|
||||||
|
columns.add(row[1]) # Column name is at index 1
|
||||||
|
self._table_columns_cache[table] = columns
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return columns
|
||||||
|
|
||||||
|
async def _invalidate_column_cache(self, table: str):
|
||||||
|
"""Invalidate column cache for a table."""
|
||||||
|
if table in self._table_columns_cache:
|
||||||
|
del self._table_columns_cache[table]
|
||||||
|
|
||||||
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
|
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
|
||||||
col_type = self._py_to_sqlite_type(value)
|
col_type = self._py_to_sqlite_type(value)
|
||||||
async with aiosqlite.connect(self._file) as db:
|
try:
|
||||||
await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}")
|
async with aiosqlite.connect(self._file) as db:
|
||||||
await db.commit()
|
await db.execute(f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}")
|
||||||
|
await db.commit()
|
||||||
|
await self._invalidate_column_cache(table)
|
||||||
|
except aiosqlite.OperationalError as e:
|
||||||
|
if "duplicate column name" in str(e).lower():
|
||||||
|
pass # Column already exists
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
|
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
|
||||||
cols: Dict[str, str] = {
|
# Always include default columns
|
||||||
"uid": "TEXT PRIMARY KEY",
|
cols = self._DEFAULT_COLUMNS.copy()
|
||||||
"created_at": "TEXT",
|
|
||||||
"updated_at": "TEXT",
|
# Add columns from col_sources
|
||||||
"deleted_at": "TEXT",
|
|
||||||
}
|
|
||||||
for key, val in col_sources.items():
|
for key, val in col_sources.items():
|
||||||
if key not in cols:
|
if key not in cols:
|
||||||
cols[key] = self._py_to_sqlite_type(val)
|
cols[key] = self._py_to_sqlite_type(val)
|
||||||
|
|
||||||
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
|
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with aiosqlite.connect(self._file) as db:
|
||||||
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
|
async def _table_exists(self, table: str) -> bool:
|
||||||
|
"""Check if a table exists."""
|
||||||
|
async with aiosqlite.connect(self._file) as db:
|
||||||
|
async with db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
|
||||||
|
) as cursor:
|
||||||
|
return await cursor.fetchone() is not None
|
||||||
|
|
||||||
_RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
|
_RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
|
||||||
_RE_NO_TABLE = re.compile(r"no such table: (\w+)")
|
_RE_NO_TABLE = re.compile(r"no such table: (\w+)")
|
||||||
@ -84,25 +127,60 @@ class AsyncDataSet:
|
|||||||
sql: str,
|
sql: str,
|
||||||
params: Iterable[Any],
|
params: Iterable[Any],
|
||||||
col_sources: Dict[str, Any],
|
col_sources: Dict[str, Any],
|
||||||
|
max_retries: int = 10
|
||||||
) -> aiosqlite.Cursor:
|
) -> aiosqlite.Cursor:
|
||||||
while True:
|
retries = 0
|
||||||
|
while retries < max_retries:
|
||||||
try:
|
try:
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with aiosqlite.connect(self._file) as db:
|
||||||
cursor = await db.execute(sql, params)
|
cursor = await db.execute(sql, params)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return cursor
|
return cursor
|
||||||
except aiosqlite.OperationalError as err:
|
except aiosqlite.OperationalError as err:
|
||||||
|
retries += 1
|
||||||
|
err_str = str(err).lower()
|
||||||
|
|
||||||
|
# Handle missing column
|
||||||
col = self._missing_column_from_error(err)
|
col = self._missing_column_from_error(err)
|
||||||
if col:
|
if col:
|
||||||
if col not in col_sources:
|
if col in col_sources:
|
||||||
raise
|
await self._ensure_column(table, col, col_sources[col])
|
||||||
await self._ensure_column(table, col, col_sources[col])
|
else:
|
||||||
|
# Column not in sources, ensure it with NULL/TEXT type
|
||||||
|
await self._ensure_column(table, col, None)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Handle missing table
|
||||||
tbl = self._missing_table_from_error(err)
|
tbl = self._missing_table_from_error(err)
|
||||||
if tbl:
|
if tbl:
|
||||||
await self._ensure_table(tbl, col_sources)
|
await self._ensure_table(tbl, col_sources)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Handle other column-related errors
|
||||||
|
if "has no column named" in err_str:
|
||||||
|
# Extract column name differently
|
||||||
|
match = re.search(r"table \w+ has no column named (\w+)", err_str)
|
||||||
|
if match:
|
||||||
|
col_name = match.group(1)
|
||||||
|
if col_name in col_sources:
|
||||||
|
await self._ensure_column(table, col_name, col_sources[col_name])
|
||||||
|
else:
|
||||||
|
await self._ensure_column(table, col_name, None)
|
||||||
|
continue
|
||||||
|
|
||||||
raise
|
raise
|
||||||
|
raise Exception(f"Max retries ({max_retries}) exceeded")
|
||||||
|
|
||||||
|
async def _filter_existing_columns(self, table: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Filter data to only include columns that exist in the table."""
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return data
|
||||||
|
|
||||||
|
existing_columns = await self._get_table_columns(table)
|
||||||
|
if not existing_columns:
|
||||||
|
return data
|
||||||
|
|
||||||
|
return {k: v for k, v in data.items() if k in existing_columns}
|
||||||
|
|
||||||
async def _safe_query(
|
async def _safe_query(
|
||||||
self,
|
self,
|
||||||
@ -111,7 +189,14 @@ class AsyncDataSet:
|
|||||||
params: Iterable[Any],
|
params: Iterable[Any],
|
||||||
col_sources: Dict[str, Any],
|
col_sources: Dict[str, Any],
|
||||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
while True:
|
# Check if table exists first
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return
|
||||||
|
|
||||||
|
max_retries = 10
|
||||||
|
retries = 0
|
||||||
|
|
||||||
|
while retries < max_retries:
|
||||||
try:
|
try:
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with aiosqlite.connect(self._file) as db:
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = aiosqlite.Row
|
||||||
@ -120,10 +205,20 @@ class AsyncDataSet:
|
|||||||
yield dict(row)
|
yield dict(row)
|
||||||
return
|
return
|
||||||
except aiosqlite.OperationalError as err:
|
except aiosqlite.OperationalError as err:
|
||||||
|
retries += 1
|
||||||
|
err_str = str(err).lower()
|
||||||
|
|
||||||
|
# Handle missing table
|
||||||
tbl = self._missing_table_from_error(err)
|
tbl = self._missing_table_from_error(err)
|
||||||
if tbl:
|
if tbl:
|
||||||
await self._ensure_table(tbl, col_sources)
|
# For queries, if table doesn't exist, just return empty
|
||||||
continue
|
return
|
||||||
|
|
||||||
|
# Handle missing column in WHERE clause or SELECT
|
||||||
|
if "no such column" in err_str:
|
||||||
|
# For queries with missing columns, return empty
|
||||||
|
return
|
||||||
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -145,27 +240,34 @@ class AsyncDataSet:
|
|||||||
**args,
|
**args,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Ensure table exists with all needed columns
|
||||||
|
await self._ensure_table(table, record)
|
||||||
|
|
||||||
# Handle auto-increment ID if requested
|
# Handle auto-increment ID if requested
|
||||||
if return_id and 'id' not in args:
|
if return_id and 'id' not in args:
|
||||||
|
# Ensure id column exists
|
||||||
async with aiosqlite.connect(self._file) as db:
|
async with aiosqlite.connect(self._file) as db:
|
||||||
# Create table with autoincrement ID if needed
|
# Add id column if it doesn't exist
|
||||||
await db.execute(f"""
|
try:
|
||||||
CREATE TABLE IF NOT EXISTS {table} (
|
await db.execute(f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT")
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
await db.commit()
|
||||||
uid TEXT UNIQUE,
|
except aiosqlite.OperationalError as e:
|
||||||
created_at TEXT,
|
if "duplicate column name" not in str(e).lower():
|
||||||
updated_at TEXT,
|
# Try without autoincrement constraint
|
||||||
deleted_at TEXT
|
try:
|
||||||
)
|
await db.execute(f"ALTER TABLE {table} ADD COLUMN id INTEGER")
|
||||||
""")
|
await db.commit()
|
||||||
await db.commit()
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
# Insert and get lastrowid
|
# Insert and get lastrowid
|
||||||
cols = "`" + "`, `".join(record.keys()) + "`"
|
cols = "`" + "`, `".join(record.keys()) + "`"
|
||||||
qs = ", ".join(["?"] * len(record))
|
qs = ", ".join(["?"] * len(record))
|
||||||
cursor = await db.execute(f"INSERT INTO {table} ({cols}) VALUES ({qs})", list(record.values()))
|
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
|
||||||
await db.commit()
|
cursor = await self._safe_execute(table, sql, list(record.values()), record)
|
||||||
return cursor.lastrowid
|
return cursor.lastrowid
|
||||||
|
|
||||||
cols = "`" + "`, `".join(record) + "`"
|
cols = "`" + "`, `".join(record) + "`"
|
||||||
qs = ", ".join(["?"] * len(record))
|
qs = ", ".join(["?"] * len(record))
|
||||||
@ -181,15 +283,31 @@ class AsyncDataSet:
|
|||||||
) -> int:
|
) -> int:
|
||||||
if not args:
|
if not args:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
# Check if table exists
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return 0
|
||||||
|
|
||||||
args["updated_at"] = self._utc_iso()
|
args["updated_at"] = self._utc_iso()
|
||||||
|
|
||||||
|
# Ensure all columns exist
|
||||||
|
all_cols = {**args, **(where or {})}
|
||||||
|
await self._ensure_table(table, all_cols)
|
||||||
|
for col, val in all_cols.items():
|
||||||
|
await self._ensure_column(table, col, val)
|
||||||
|
|
||||||
set_clause = ", ".join(f"`{k}` = ?" for k in args)
|
set_clause = ", ".join(f"`{k}` = ?" for k in args)
|
||||||
where_clause, where_params = self._build_where(where)
|
where_clause, where_params = self._build_where(where)
|
||||||
sql = f"UPDATE {table} SET {set_clause}{where_clause}"
|
sql = f"UPDATE {table} SET {set_clause}{where_clause}"
|
||||||
params = list(args.values()) + where_params
|
params = list(args.values()) + where_params
|
||||||
cur = await self._safe_execute(table, sql, params, {**args, **(where or {})})
|
cur = await self._safe_execute(table, sql, params, all_cols)
|
||||||
return cur.rowcount
|
return cur.rowcount
|
||||||
|
|
||||||
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||||
|
# Check if table exists
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return 0
|
||||||
|
|
||||||
where_clause, where_params = self._build_where(where)
|
where_clause, where_params = self._build_where(where)
|
||||||
sql = f"DELETE FROM {table}{where_clause}"
|
sql = f"DELETE FROM {table}{where_clause}"
|
||||||
cur = await self._safe_execute(table, sql, where_params, where or {})
|
cur = await self._safe_execute(table, sql, where_params, where or {})
|
||||||
@ -241,6 +359,10 @@ class AsyncDataSet:
|
|||||||
]
|
]
|
||||||
|
|
||||||
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||||
|
# Check if table exists
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return 0
|
||||||
|
|
||||||
where_clause, where_params = self._build_where(where)
|
where_clause, where_params = self._build_where(where)
|
||||||
sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
|
sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
|
||||||
gen = self._safe_query(table, sql, where_params, where or {})
|
gen = self._safe_query(table, sql, where_params, where or {})
|
||||||
@ -287,10 +409,14 @@ class AsyncDataSet:
|
|||||||
|
|
||||||
async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]:
|
async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]:
|
||||||
"""Execute raw SQL query and return results as list of dicts."""
|
"""Execute raw SQL query and return results as list of dicts."""
|
||||||
async with aiosqlite.connect(self._file) as db:
|
try:
|
||||||
db.row_factory = aiosqlite.Row
|
async with aiosqlite.connect(self._file) as db:
|
||||||
async with db.execute(sql, params or ()) as cursor:
|
db.row_factory = aiosqlite.Row
|
||||||
return [dict(row) async for row in cursor]
|
async with db.execute(sql, params or ()) as cursor:
|
||||||
|
return [dict(row) async for row in cursor]
|
||||||
|
except aiosqlite.OperationalError:
|
||||||
|
# Return empty list if query fails
|
||||||
|
return []
|
||||||
|
|
||||||
async def query_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Dict[str, Any]]:
|
async def query_one(self, sql: str, params: Optional[Tuple] = None) -> Optional[Dict[str, Any]]:
|
||||||
"""Execute raw SQL query and return single result."""
|
"""Execute raw SQL query and return single result."""
|
||||||
@ -298,8 +424,12 @@ class AsyncDataSet:
|
|||||||
return results[0] if results else None
|
return results[0] if results else None
|
||||||
|
|
||||||
async def create_table(self, table: str, schema: Dict[str, str], constraints: Optional[List[str]] = None):
|
async def create_table(self, table: str, schema: Dict[str, str], constraints: Optional[List[str]] = None):
|
||||||
"""Create table with custom schema and constraints."""
|
"""Create table with custom schema and constraints. Always includes default columns."""
|
||||||
columns = [f"`{col}` {dtype}" for col, dtype in schema.items()]
|
# Merge default columns with custom schema
|
||||||
|
full_schema = self._DEFAULT_COLUMNS.copy()
|
||||||
|
full_schema.update(schema)
|
||||||
|
|
||||||
|
columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
|
||||||
if constraints:
|
if constraints:
|
||||||
columns.extend(constraints)
|
columns.extend(constraints)
|
||||||
columns_sql = ", ".join(columns)
|
columns_sql = ", ".join(columns)
|
||||||
@ -307,6 +437,7 @@ class AsyncDataSet:
|
|||||||
async with aiosqlite.connect(self._file) as db:
|
async with aiosqlite.connect(self._file) as db:
|
||||||
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
await db.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
async def insert_unique(self, table: str, args: Dict[str, Any], unique_fields: List[str]) -> Union[str, None]:
|
async def insert_unique(self, table: str, args: Dict[str, Any], unique_fields: List[str]) -> Union[str, None]:
|
||||||
"""Insert with unique constraint handling. Returns uid on success, None if duplicate."""
|
"""Insert with unique constraint handling. Returns uid on success, None if duplicate."""
|
||||||
@ -323,6 +454,10 @@ class AsyncDataSet:
|
|||||||
|
|
||||||
async def aggregate(self, table: str, function: str, column: str = "*", where: Optional[Dict[str, Any]] = None) -> Any:
|
async def aggregate(self, table: str, function: str, column: str = "*", where: Optional[Dict[str, Any]] = None) -> Any:
|
||||||
"""Perform aggregate functions like SUM, AVG, MAX, MIN."""
|
"""Perform aggregate functions like SUM, AVG, MAX, MIN."""
|
||||||
|
# Check if table exists
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return None
|
||||||
|
|
||||||
where_clause, where_params = self._build_where(where)
|
where_clause, where_params = self._build_where(where)
|
||||||
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
|
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
|
||||||
result = await self.query_one(sql, tuple(where_params))
|
result = await self.query_one(sql, tuple(where_params))
|
||||||
@ -450,6 +585,26 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
|||||||
["username", "email"]
|
["username", "email"]
|
||||||
)
|
)
|
||||||
self.assertIsNone(result)
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
async def test_missing_table_operations(self):
|
||||||
|
# Test operations on non-existent tables
|
||||||
|
self.assertEqual(await self.connector.count("nonexistent"), 0)
|
||||||
|
self.assertEqual(await self.connector.find("nonexistent"), [])
|
||||||
|
self.assertIsNone(await self.connector.get("nonexistent"))
|
||||||
|
self.assertFalse(await self.connector.exists("nonexistent", {"id": 1}))
|
||||||
|
self.assertEqual(await self.connector.delete("nonexistent"), 0)
|
||||||
|
self.assertEqual(await self.connector.update("nonexistent", {"name": "test"}), 0)
|
||||||
|
|
||||||
|
async def test_auto_column_creation(self):
|
||||||
|
# Insert with new columns that don't exist yet
|
||||||
|
await self.connector.insert("dynamic", {"col1": "value1", "col2": 42, "col3": 3.14})
|
||||||
|
|
||||||
|
# Add more columns in next insert
|
||||||
|
await self.connector.insert("dynamic", {"col1": "value2", "col4": True, "col5": None})
|
||||||
|
|
||||||
|
# All records should be retrievable
|
||||||
|
records = await self.connector.find("dynamic")
|
||||||
|
self.assertEqual(len(records), 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user