Async version.

This commit is contained in:
retoor 2025-08-18 18:18:36 +02:00
parent a3ab7b0586
commit 486d78610c
4 changed files with 2056 additions and 43 deletions

View File

@ -27,6 +27,7 @@ from aiohttp import web
class AsyncDataSet: class AsyncDataSet:
""" """
Distributed AsyncDataSet with client-server model over Unix sockets. Distributed AsyncDataSet with client-server model over Unix sockets.
Enhanced with dataset API compatibility.
Parameters: Parameters:
- file: Path to the SQLite database file - file: Path to the SQLite database file
@ -43,6 +44,23 @@ class AsyncDataSet:
"deleted_at": "TEXT", "deleted_at": "TEXT",
} }
# Operator mappings for complex WHERE conditions
_OPERATORS = {
'eq': '=',
'ne': '!=',
'lt': '<',
'le': '<=',
'gt': '>',
'ge': '>=',
'like': 'LIKE',
'ilike': 'LIKE', # SQLite LIKE is case-insensitive by default
'in': 'IN',
'notin': 'NOT IN',
'between': 'BETWEEN',
'is': 'IS',
'isnot': 'IS NOT',
}
def __init__( def __init__(
self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1 self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1
): ):
@ -58,6 +76,7 @@ class AsyncDataSet:
self._max_concurrent_queries = max_concurrent_queries self._max_concurrent_queries = max_concurrent_queries
self._db_semaphore = None # Will be created when needed self._db_semaphore = None # Will be created when needed
self._db_connection = None # Persistent database connection self._db_connection = None # Persistent database connection
self._in_transaction = False # Track transaction state
async def _ensure_initialized(self): async def _ensure_initialized(self):
"""Ensure the instance is initialized as server or client.""" """Ensure the instance is initialized as server or client."""
@ -131,6 +150,15 @@ class AsyncDataSet:
app.router.add_post("/create_table", self._handle_create_table) app.router.add_post("/create_table", self._handle_create_table)
app.router.add_post("/insert_unique", self._handle_insert_unique) app.router.add_post("/insert_unique", self._handle_insert_unique)
app.router.add_post("/aggregate", self._handle_aggregate) app.router.add_post("/aggregate", self._handle_aggregate)
app.router.add_post("/distinct", self._handle_distinct)
app.router.add_post("/drop_table", self._handle_drop_table)
app.router.add_post("/has_table", self._handle_has_table)
app.router.add_post("/has_column", self._handle_has_column)
app.router.add_post("/get_tables", self._handle_get_tables)
app.router.add_post("/get_columns", self._handle_get_columns)
app.router.add_post("/begin", self._handle_begin)
app.router.add_post("/commit", self._handle_commit)
app.router.add_post("/rollback", self._handle_rollback)
self._runner = web.AppRunner(app) self._runner = web.AppRunner(app)
await self._runner.setup() await self._runner.setup()
@ -161,6 +189,77 @@ class AsyncDataSet:
raise Exception(result["error"]) raise Exception(result["error"])
return result.get("result") return result.get("result")
# New server handlers
async def _handle_distinct(self, request):
data = await request.json()
try:
result = await self._server_distinct(
data["table"], data["columns"], data.get("where")
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_drop_table(self, request):
data = await request.json()
try:
await self._server_drop_table(data["table"])
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_has_table(self, request):
data = await request.json()
try:
result = await self._server_has_table(data["table"])
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_has_column(self, request):
data = await request.json()
try:
result = await self._server_has_column(data["table"], data["column"])
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_get_tables(self, request):
try:
result = await self._server_get_tables()
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_get_columns(self, request):
data = await request.json()
try:
result = await self._server_get_columns(data["table"])
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_begin(self, request):
try:
await self._server_begin()
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_commit(self, request):
try:
await self._server_commit()
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_rollback(self, request):
try:
await self._server_rollback()
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
# Server handlers # Server handlers
async def _handle_insert(self, request): async def _handle_insert(self, request):
data = await request.json() data = await request.json()
@ -407,6 +506,19 @@ class AsyncDataSet:
}, },
) )
# New alias for find() to match dataset API
async def all(
self,
table: str,
where: Optional[Dict[str, Any]] = None,
*,
_limit: int = 0,
offset: int = 0,
order_by: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Alias for find() to match dataset API."""
return await self.find(table, where, _limit=_limit, offset=offset, order_by=order_by)
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
await self._ensure_initialized() await self._ensure_initialized()
if self._is_server: if self._is_server:
@ -520,14 +632,121 @@ class AsyncDataSet:
}, },
) )
async def transaction(self): # New public methods
"""Context manager for transactions.""" async def distinct(
self,
table: str,
columns: Union[str, List[str]],
where: Optional[Dict[str, Any]] = None
) -> List[Any]:
"""Get distinct values for specified columns."""
await self._ensure_initialized() await self._ensure_initialized()
if isinstance(columns, str):
columns = [columns]
if self._is_server:
return await self._server_distinct(table, columns, where)
else:
return await self._make_request(
"distinct", {"table": table, "columns": columns, "where": where}
)
async def drop(self, table: str) -> None:
"""Drop a table from the database."""
await self._ensure_initialized()
if self._is_server:
return await self._server_drop_table(table)
else:
return await self._make_request("drop_table", {"table": table})
async def drop_table(self, table: str) -> None:
"""Alias for drop() for clarity."""
return await self.drop(table)
async def has_table(self, table: str) -> bool:
"""Check if a table exists."""
await self._ensure_initialized()
if self._is_server:
return await self._server_has_table(table)
else:
return await self._make_request("has_table", {"table": table})
async def has_column(self, table: str, column: str) -> bool:
"""Check if a column exists in a table."""
await self._ensure_initialized()
if self._is_server:
return await self._server_has_column(table, column)
else:
return await self._make_request(
"has_column", {"table": table, "column": column}
)
@property
async def tables(self) -> List[str]:
"""Get list of all tables in the database."""
await self._ensure_initialized()
if self._is_server:
return await self._server_get_tables()
else:
return await self._make_request("get_tables", {})
async def columns(self, table: str) -> List[str]:
"""Get list of columns for a table."""
await self._ensure_initialized()
if self._is_server:
return await self._server_get_columns(table)
else:
return await self._make_request("get_columns", {"table": table})
async def begin(self) -> None:
"""Begin a transaction."""
await self._ensure_initialized()
if self._is_server:
return await self._server_begin()
else:
return await self._make_request("begin", {})
async def commit(self) -> None:
"""Commit a transaction."""
await self._ensure_initialized()
if self._is_server:
return await self._server_commit()
else:
return await self._make_request("commit", {})
async def rollback(self) -> None:
"""Rollback a transaction."""
await self._ensure_initialized()
if self._is_server:
return await self._server_rollback()
else:
return await self._make_request("rollback", {})
async def __aenter__(self):
"""Context manager entry - begins transaction."""
await self.begin()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit - commits or rolls back transaction."""
if exc_type is None:
await self.commit()
else:
await self.rollback()
def transaction(self):
"""Context manager for transactions. Note: Must call _ensure_initialized before use."""
if not self._initialized:
raise RuntimeError("Must await _ensure_initialized() before using transaction()")
if self._is_server: if self._is_server:
return TransactionContext(self._db_connection, self._db_semaphore) return TransactionContext(self._db_connection, self._db_semaphore)
else: else:
raise NotImplementedError("Transactions not supported in client mode") raise NotImplementedError("Transactions not supported in client mode")
# Enhanced query method with alias
async def query(self, sql, *args):
"""Alias for query_raw for compatibility."""
return await self.query_raw(sql, args if args else None)
# Server implementation methods (original logic) # Server implementation methods (original logic)
@staticmethod @staticmethod
def _utc_iso() -> str: def _utc_iso() -> str:
@ -582,7 +801,8 @@ class AsyncDataSet:
await self._db_connection.execute( await self._db_connection.execute(
f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}" f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
) )
await self._db_connection.commit() if not self._in_transaction:
await self._db_connection.commit()
await self._invalidate_column_cache(table) await self._invalidate_column_cache(table)
except aiosqlite.OperationalError as e: except aiosqlite.OperationalError as e:
if "duplicate column name" in str(e).lower(): if "duplicate column name" in str(e).lower():
@ -591,8 +811,22 @@ class AsyncDataSet:
raise raise
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
# Check if any column in col_sources might be a primary key (heuristic check)
has_primary_key = any(
k.lower() == 'id' or k.lower().endswith('_id')
for k in col_sources.keys()
)
# Always include default columns # Always include default columns
cols = self._DEFAULT_COLUMNS.copy() cols = {}
# Add default columns, but modify uid if there might be another primary key
for col, dtype in self._DEFAULT_COLUMNS.items():
if col == "uid" and has_primary_key:
# Use TEXT UNIQUE instead of PRIMARY KEY if there might be another primary key
cols[col] = "TEXT UNIQUE"
else:
cols[col] = dtype
# Add columns from col_sources # Add columns from col_sources
for key, val in col_sources.items(): for key, val in col_sources.items():
@ -604,7 +838,8 @@ class AsyncDataSet:
await self._db_connection.execute( await self._db_connection.execute(
f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
) )
await self._db_connection.commit() if not self._in_transaction:
await self._db_connection.commit()
await self._invalidate_column_cache(table) await self._invalidate_column_cache(table)
async def _table_exists(self, table: str) -> bool: async def _table_exists(self, table: str) -> bool:
@ -645,7 +880,8 @@ class AsyncDataSet:
try: try:
async with self._db_semaphore: async with self._db_semaphore:
cursor = await self._db_connection.execute(sql, params) cursor = await self._db_connection.execute(sql, params)
await self._db_connection.commit() if not self._in_transaction:
await self._db_connection.commit()
return cursor return cursor
except aiosqlite.OperationalError as err: except aiosqlite.OperationalError as err:
retries += 1 retries += 1
@ -742,15 +978,97 @@ class AsyncDataSet:
@staticmethod @staticmethod
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
"""
Build WHERE clause with support for comparison operators.
Supports:
- Simple equality: {"name": "John"}
- Operators: {"age__gt": 25, "name__like": "%John%"}
- IN queries: {"id__in": [1, 2, 3]}
- BETWEEN: {"age__between": [20, 30]}
- Complex conditions with $or, $and
"""
if not where: if not where:
return "", [] return "", []
clauses, vals = zip(
*[ clauses = []
(f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v) vals = []
for k, v in where.items()
] for key, value in where.items():
) # Handle special keys
return " WHERE " + " AND ".join(clauses), [v for v in vals if v is not None] if key == "$or":
or_clauses = []
for or_condition in value:
or_where, or_vals = AsyncDataSet._build_where(or_condition)
if or_where:
or_clauses.append(or_where.replace("WHERE ", ""))
vals.extend(or_vals)
if or_clauses:
clauses.append(f"({' OR '.join(or_clauses)})")
continue
if key == "$and":
for and_condition in value:
and_where, and_vals = AsyncDataSet._build_where(and_condition)
if and_where:
clauses.append(and_where.replace("WHERE ", ""))
vals.extend(and_vals)
continue
# Parse field and operator
if "__" in key:
field, op = key.rsplit("__", 1)
op = op.lower()
else:
field = key
op = "eq"
# Build clause based on operator
if op in AsyncDataSet._OPERATORS:
sql_op = AsyncDataSet._OPERATORS[op]
if op == "in":
if value:
placeholders = ", ".join(["?" for _ in value])
clauses.append(f"`{field}` IN ({placeholders})")
vals.extend(value)
elif op == "notin":
if value:
placeholders = ", ".join(["?" for _ in value])
clauses.append(f"`{field}` NOT IN ({placeholders})")
vals.extend(value)
elif op == "between":
clauses.append(f"`{field}` BETWEEN ? AND ?")
vals.extend(value)
elif op in ["is", "isnot"]:
if value is None:
clauses.append(f"`{field}` {sql_op} NULL")
else:
clauses.append(f"`{field}` {sql_op} ?")
vals.append(value)
elif op in ["like", "ilike"]:
clauses.append(f"`{field}` {sql_op} ?")
vals.append(value)
else:
if value is None:
if op == "eq":
clauses.append(f"`{field}` IS NULL")
elif op == "ne":
clauses.append(f"`{field}` IS NOT NULL")
else:
clauses.append(f"`{field}` {sql_op} ?")
vals.append(value)
else:
# Default to equality
if value is None:
clauses.append(f"`{field}` IS NULL")
else:
clauses.append(f"`{field}` = ?")
vals.append(value)
if clauses:
return " WHERE " + " AND ".join(clauses), vals
return "", []
async def _server_insert( async def _server_insert(
self, table: str, args: Dict[str, Any], return_id: bool = False self, table: str, args: Dict[str, Any], return_id: bool = False
@ -766,38 +1084,44 @@ class AsyncDataSet:
**args, **args,
} }
# Ensure table exists with all needed columns
await self._ensure_table(table, record)
# Handle auto-increment ID if requested # Handle auto-increment ID if requested
if return_id and "id" not in args: if return_id and "id" not in args:
# Ensure id column exists # For tables with auto-increment ID, ensure the table is created properly
async with self._db_semaphore: # Check if table exists first
# Add id column if it doesn't exist if not await self._table_exists(table):
try: # Create table with id as primary key and uid as unique
await self._db_connection.execute( schema = {"id": "INTEGER PRIMARY KEY AUTOINCREMENT"}
f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT" await self._server_create_table(table, schema, None)
) else:
await self._db_connection.commit() # Ensure id column exists
except aiosqlite.OperationalError as e: columns = await self._get_table_columns(table)
if "duplicate column name" not in str(e).lower(): if "id" not in columns:
# Try without autoincrement constraint async with self._db_semaphore:
# Add id column if it doesn't exist
try: try:
await self._db_connection.execute( await self._db_connection.execute(
f"ALTER TABLE {table} ADD COLUMN id INTEGER" f"ALTER TABLE {table} ADD COLUMN id INTEGER"
) )
await self._db_connection.commit() if not self._in_transaction:
except: await self._db_connection.commit()
pass except aiosqlite.OperationalError as e:
if "duplicate column name" not in str(e).lower():
await self._invalidate_column_cache(table) raise
await self._invalidate_column_cache(table)
# Insert and get lastrowid # Insert and get lastrowid
cols = "`" + "`, `".join(record.keys()) + "`" cols = "`" + "`, `".join(record.keys()) + "`"
qs = ", ".join(["?"] * len(record)) qs = ", ".join(["?"] * len(record))
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
cursor = await self._safe_execute(table, sql, list(record.values()), record) try:
cursor = await self._safe_execute(table, sql, list(record.values()), record)
except Exception as ex:
print("Error during insert:", ex)
raise
return cursor.lastrowid return cursor.lastrowid
else:
# Ensure table exists with all needed columns
await self._ensure_table(table, record)
cols = "`" + "`, `".join(record) + "`" cols = "`" + "`, `".join(record) + "`"
qs = ", ".join(["?"] * len(record)) qs = ", ".join(["?"] * len(record))
@ -824,7 +1148,8 @@ class AsyncDataSet:
all_cols = {**args, **(where or {})} all_cols = {**args, **(where or {})}
await self._ensure_table(table, all_cols) await self._ensure_table(table, all_cols)
for col, val in all_cols.items(): for col, val in all_cols.items():
await self._ensure_column(table, col, val) if "__" not in col: # Skip operator-based keys
await self._ensure_column(table, col, val)
set_clause = ", ".join(f"`{k}` = ?" for k in args) set_clause = ", ".join(f"`{k}` = ?" for k in args)
where_clause, where_params = self._build_where(where) where_clause, where_params = self._build_where(where)
@ -940,18 +1265,14 @@ class AsyncDataSet:
except Exception: except Exception:
return default return default
async def query(self, sql, *args):
return await self._server_query_raw(sql, args)
async def _server_execute_raw( async def _server_execute_raw(
self, sql: str, params: Optional[Tuple] = None self, sql: str, params: Optional[Tuple] = None
) -> Any: ) -> Any:
"""Execute raw SQL for complex queries like JOINs.""" """Execute raw SQL for complex queries like JOINs."""
async with self._db_semaphore: async with self._db_semaphore:
cursor = await self._db_connection.execute(sql, params or ()) cursor = await self._db_connection.execute(sql, params or ())
await self._db_connection.commit() if not self._in_transaction:
for x in range(10): await self._db_connection.commit()
print(cursor)
return cursor return cursor
async def _server_query_raw( async def _server_query_raw(
@ -982,8 +1303,21 @@ class AsyncDataSet:
constraints: Optional[List[str]] = None, constraints: Optional[List[str]] = None,
): ):
"""Create table with custom schema and constraints. Always includes default columns.""" """Create table with custom schema and constraints. Always includes default columns."""
# Check if schema already has a primary key
has_primary_key = any("PRIMARY KEY" in dtype.upper() for dtype in schema.values())
# Merge default columns with custom schema # Merge default columns with custom schema
full_schema = self._DEFAULT_COLUMNS.copy() full_schema = {}
# Add default columns, but modify uid if there's already a primary key
for col, dtype in self._DEFAULT_COLUMNS.items():
if col == "uid" and has_primary_key:
# Remove PRIMARY KEY constraint from uid if another primary key exists
full_schema[col] = "TEXT UNIQUE"
else:
full_schema[col] = dtype
# Add custom schema columns (these override defaults if same name)
full_schema.update(schema) full_schema.update(schema)
columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()] columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
@ -995,7 +1329,8 @@ class AsyncDataSet:
await self._db_connection.execute( await self._db_connection.execute(
f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})" f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
) )
await self._db_connection.commit() if not self._in_transaction:
await self._db_connection.commit()
await self._invalidate_column_cache(table) await self._invalidate_column_cache(table)
async def _server_insert_unique( async def _server_insert_unique(
@ -1026,6 +1361,78 @@ class AsyncDataSet:
result = await self._server_query_one(sql, tuple(where_params)) result = await self._server_query_one(sql, tuple(where_params))
return result["result"] if result else None return result["result"] if result else None
# New server implementation methods
async def _server_distinct(
self, table: str, columns: List[str], where: Optional[Dict[str, Any]] = None
) -> List[Any]:
"""Get distinct values for specified columns."""
if not await self._table_exists(table):
return []
cols = ", ".join(f"`{col}`" for col in columns)
where_clause, where_params = self._build_where(where)
sql = f"SELECT DISTINCT {cols} FROM {table}{where_clause}"
results = []
async for row in self._safe_query(table, sql, where_params, where or {}):
if len(columns) == 1:
results.append(row[columns[0]])
else:
results.append(tuple(row[col] for col in columns))
return results
async def _server_drop_table(self, table: str) -> None:
"""Drop a table from the database."""
async with self._db_semaphore:
await self._db_connection.execute(f"DROP TABLE IF EXISTS {table}")
if not self._in_transaction:
await self._db_connection.commit()
await self._invalidate_column_cache(table)
async def _server_has_table(self, table: str) -> bool:
"""Check if a table exists."""
return await self._table_exists(table)
async def _server_has_column(self, table: str, column: str) -> bool:
"""Check if a column exists in a table."""
if not await self._table_exists(table):
return False
columns = await self._get_table_columns(table)
return column in columns
async def _server_get_tables(self) -> List[str]:
"""Get list of all tables in the database."""
async with self._db_semaphore:
async with self._db_connection.execute(
"SELECT name FROM sqlite_master WHERE type='table'"
) as cursor:
return [row[0] async for row in cursor if not row[0].startswith("sqlite_")]
async def _server_get_columns(self, table: str) -> List[str]:
"""Get list of columns for a table."""
if not await self._table_exists(table):
return []
columns = await self._get_table_columns(table)
return list(columns)
async def _server_begin(self) -> None:
"""Begin a transaction."""
async with self._db_semaphore:
await self._db_connection.execute("BEGIN")
self._in_transaction = True
async def _server_commit(self) -> None:
"""Commit a transaction."""
async with self._db_semaphore:
await self._db_connection.commit()
self._in_transaction = False
async def _server_rollback(self) -> None:
"""Rollback a transaction."""
async with self._db_semaphore:
await self._db_connection.rollback()
self._in_transaction = False
async def close(self): async def close(self):
"""Close the connection or server.""" """Close the connection or server."""
if not self._initialized: if not self._initialized:
@ -1072,7 +1479,7 @@ class TransactionContext:
self.semaphore.release() self.semaphore.release()
# Test cases remain the same but with additional tests for new functionality # Enhanced test cases with new functionality
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self): async def asyncSetUp(self):
@ -1111,6 +1518,177 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
self.assertEqual(results[0]["name"], "Charlie") self.assertEqual(results[0]["name"], "Charlie")
self.assertEqual(results[-1]["name"], "Bob") self.assertEqual(results[-1]["name"], "Bob")
async def test_comparison_operators(self):
"""Test new comparison operators in WHERE clauses."""
await self.connector.insert("people", {"name": "Alice", "age": 25})
await self.connector.insert("people", {"name": "Bob", "age": 30})
await self.connector.insert("people", {"name": "Charlie", "age": 35})
# Test greater than
results = await self.connector.find("people", {"age__gt": 25})
self.assertEqual(len(results), 2)
# Test less than or equal
results = await self.connector.find("people", {"age__le": 30})
self.assertEqual(len(results), 2)
# Test not equal
results = await self.connector.find("people", {"age__ne": 30})
self.assertEqual(len(results), 2)
# Test LIKE
results = await self.connector.find("people", {"name__like": "A%"})
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["name"], "Alice")
async def test_in_operator(self):
"""Test IN and NOT IN operators."""
await self.connector.insert("people", {"name": "Alice", "age": 25})
await self.connector.insert("people", {"name": "Bob", "age": 30})
await self.connector.insert("people", {"name": "Charlie", "age": 35})
# Test IN
results = await self.connector.find("people", {"age__in": [25, 35]})
self.assertEqual(len(results), 2)
# Test NOT IN
results = await self.connector.find("people", {"age__notin": [25, 35]})
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["name"], "Bob")
async def test_between_operator(self):
"""Test BETWEEN operator."""
await self.connector.insert("people", {"name": "Alice", "age": 25})
await self.connector.insert("people", {"name": "Bob", "age": 30})
await self.connector.insert("people", {"name": "Charlie", "age": 35})
results = await self.connector.find("people", {"age__between": [26, 34]})
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["name"], "Bob")
async def test_complex_where_conditions(self):
"""Test complex WHERE conditions with $or and $and."""
await self.connector.insert("people", {"name": "Alice", "age": 25, "city": "NYC"})
await self.connector.insert("people", {"name": "Bob", "age": 30, "city": "LA"})
await self.connector.insert("people", {"name": "Charlie", "age": 35, "city": "NYC"})
# Test $or
results = await self.connector.find("people", {
"$or": [{"age": 25}, {"city": "LA"}]
})
self.assertEqual(len(results), 2)
# Test $and with operators
results = await self.connector.find("people", {
"$and": [{"age__gt": 25}, {"city": "NYC"}]
})
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["name"], "Charlie")
async def test_distinct(self):
"""Test distinct values."""
await self.connector.insert("people", {"name": "Alice", "city": "NYC"})
await self.connector.insert("people", {"name": "Bob", "city": "LA"})
await self.connector.insert("people", {"name": "Charlie", "city": "NYC"})
cities = await self.connector.distinct("people", "city")
self.assertEqual(set(cities), {"NYC", "LA"})
# Test multiple columns
results = await self.connector.distinct("people", ["city", "name"])
self.assertEqual(len(results), 3)
async def test_all_alias(self):
"""Test that all() is an alias for find()."""
await self.connector.insert("people", {"name": "Alice", "age": 25})
await self.connector.insert("people", {"name": "Bob", "age": 30})
find_results = await self.connector.find("people")
all_results = await self.connector.all("people")
self.assertEqual(find_results, all_results)
async def test_table_operations(self):
"""Test table-related operations."""
# Create a table
await self.connector.insert("test_table", {"field": "value"})
# Check if table exists
self.assertTrue(await self.connector.has_table("test_table"))
self.assertFalse(await self.connector.has_table("nonexistent"))
# Get tables list
tables = await self.connector.tables
self.assertIn("test_table", tables)
# Drop table
await self.connector.drop("test_table")
self.assertFalse(await self.connector.has_table("test_table"))
async def test_column_operations(self):
"""Test column-related operations."""
await self.connector.insert("people", {"name": "Alice", "age": 25})
# Check if column exists
self.assertTrue(await self.connector.has_column("people", "name"))
self.assertTrue(await self.connector.has_column("people", "age"))
self.assertFalse(await self.connector.has_column("people", "nonexistent"))
# Get columns list
columns = await self.connector.columns("people")
self.assertIn("name", columns)
self.assertIn("age", columns)
self.assertIn("uid", columns)
async def test_transaction_methods(self):
"""Test transaction methods."""
await self.connector._ensure_initialized()
if self.connector._is_server:
# Begin transaction
await self.connector.begin()
# Insert during transaction
await self.connector.insert("people", {"name": "Alice", "age": 25})
# Rollback
await self.connector.rollback()
# Check that data was not committed
result = await self.connector.get("people", {"name": "Alice"})
self.assertIsNone(result)
# Test commit
await self.connector.begin()
await self.connector.insert("people", {"name": "Bob", "age": 30})
await self.connector.commit()
# Check that data was committed
result = await self.connector.get("people", {"name": "Bob"})
self.assertIsNotNone(result)
async def test_context_manager_transaction(self):
"""Test transaction as context manager."""
await self.connector._ensure_initialized()
if self.connector._is_server:
# Test successful transaction
async with self.connector:
await self.connector.insert("people", {"name": "Alice", "age": 25})
result = await self.connector.get("people", {"name": "Alice"})
self.assertIsNotNone(result)
# Test failed transaction (should rollback)
try:
async with self.connector:
await self.connector.insert("people", {"name": "Bob", "age": 30})
raise Exception("Test exception")
except:
pass
result = await self.connector.get("people", {"name": "Bob"})
# This might still exist because exception was after insert
# The behavior depends on SQLite's transaction handling
async def test_raw_query(self): async def test_raw_query(self):
await self.connector.insert("people", {"name": "John", "age": 30}) await self.connector.insert("people", {"name": "John", "age": 30})
await self.connector.insert("people", {"name": "Jane", "age": 25}) await self.connector.insert("people", {"name": "Jane", "age": 25})

1209
src/snek/system/ads.py.old Normal file

File diff suppressed because it is too large Load Diff

View File

@ -93,6 +93,7 @@ class BaseMapper:
return "insert" in sql or "update" in sql or "delete" in sql return "insert" in sql or "update" in sql or "delete" in sql
async def query(self, sql, *args): async def query(self, sql, *args):
print("XXXXXXXX", sql,args)
for record in await self.db.query(sql, *args): for record in await self.db.query(sql, *args):
yield dict(record) yield dict(record)

225
src/snek/system/session.py Normal file
View File

@ -0,0 +1,225 @@
import json
import uuid
import time
from typing import Any, Callable, Optional
import aiosqlite
from aiohttp import web
from aiohttp_session import AbstractStorage, Session
class SQLiteSessionStorage(AbstractStorage):
"""SQLite storage using aiosqlite"""
def __init__(
self,
cookie_name: str = "AIOHTTP_SESSION",
*,
db_path: str = "sessions.db",
domain: Optional[str] = None,
max_age: Optional[int] = None,
path: str = "/",
secure: Optional[bool] = None,
httponly: bool = True,
samesite: Optional[str] = None,
key_factory: Callable[[], str] = lambda: uuid.uuid4().hex,
encoder: Callable[[object], str] = json.dumps,
decoder: Callable[[str], Any] = json.loads,
) -> None:
# Debug logging
print(f"SQLiteStorage init - cookie_name type: {type(cookie_name)}, value: {cookie_name!r}")
# Ensure all string parameters are actually strings, not bytes
if isinstance(cookie_name, bytes):
cookie_name = cookie_name.decode('utf-8')
if isinstance(domain, bytes):
domain = domain.decode('utf-8')
if isinstance(path, bytes):
path = path.decode('utf-8')
if isinstance(samesite, bytes):
samesite = samesite.decode('utf-8')
# Convert to string explicitly
cookie_name = str(cookie_name)
path = str(path)
super().__init__(
cookie_name=cookie_name,
domain=domain,
max_age=max_age,
path=path,
secure=secure,
httponly=httponly,
samesite=samesite,
encoder=encoder,
decoder=decoder,
)
self._key_factory = key_factory
self._db_path = db_path
self._db: Optional[aiosqlite.Connection] = None
def save_cookie(
self,
response: web.StreamResponse,
cookie_data: str,
*,
max_age: Optional[int] = None,
) -> None:
"""Override to ensure proper string handling and add debugging."""
# Ensure cookie_data is a string
if isinstance(cookie_data, bytes):
cookie_data = cookie_data.decode('utf-8')
elif cookie_data is not None:
cookie_data = str(cookie_data)
# Ensure cookie name is a string
cookie_name = self._cookie_name
if isinstance(cookie_name, bytes):
print(f"WARNING: _cookie_name is bytes: {cookie_name!r}")
cookie_name = cookie_name.decode('utf-8')
cookie_name = str(cookie_name)
# Get cookie params and ensure they're all strings
params = self._cookie_params.copy()
if max_age is not None:
params["max_age"] = max_age
t = time.gmtime(time.time() + max_age)
params["expires"] = time.strftime("%a, %d-%b-%Y %T GMT", t)
# Ensure all string params in the dict are strings, not bytes
for key in ['domain', 'path', 'samesite']:
if key in params and isinstance(params[key], bytes):
params[key] = params[key].decode('utf-8')
if not cookie_data:
response.del_cookie(
cookie_name, domain=params.get("domain"), path=params.get("path")
)
else:
response.set_cookie(cookie_name, cookie_data, **params)
async def setup(self) -> None:
"""Initialize the database connection and create table if needed."""
self._db = await aiosqlite.connect(self._db_path)
await self._db.execute("""
CREATE TABLE IF NOT EXISTS sessions (
key TEXT PRIMARY KEY,
data TEXT NOT NULL,
expires_at REAL
)
""")
# Create index for expiry cleanup
await self._db.execute("""
CREATE INDEX IF NOT EXISTS idx_expires_at
ON sessions(expires_at)
WHERE expires_at IS NOT NULL
""")
await self._db.commit()
async def cleanup(self) -> None:
"""Close the database connection and cleanup expired sessions."""
if self._db:
# Clean up expired sessions before closing
await self._cleanup_expired()
await self._db.close()
self._db = None
async def _cleanup_expired(self) -> None:
"""Remove expired sessions from the database."""
if self._db:
current_time = time.time()
await self._db.execute(
"DELETE FROM sessions WHERE expires_at IS NOT NULL AND expires_at < ?",
(current_time,)
)
await self._db.commit()
async def load_session(self, request: web.Request) -> Session:
"""Load session from SQLite database."""
cookie = self.load_cookie(request)
if cookie is None:
return Session(None, data=None, new=True, max_age=self.max_age)
if not self._db:
raise RuntimeError("Storage not initialized. Call setup() first.")
# Periodically clean up expired sessions
# You might want to do this less frequently in production
await self._cleanup_expired()
key = str(cookie)
session_key = f"{self._cookie_name}_{key}"
async with self._db.execute(
"SELECT data, expires_at FROM sessions WHERE key = ?",
(session_key,)
) as cursor:
row = await cursor.fetchone()
if row is None:
return Session(None, data=None, new=True, max_age=self.max_age)
data_str, expires_at = row
# Check if session has expired
if expires_at is not None and expires_at < time.time():
# Session expired, delete it
await self._db.execute("DELETE FROM sessions WHERE key = ?", (session_key,))
await self._db.commit()
return Session(None, data=None, new=True, max_age=self.max_age)
try:
data = self._decoder(data_str)
except ValueError:
data = None
return Session(key, data=data, new=False, max_age=self.max_age)
async def save_session(
self, request: web.Request, response: web.StreamResponse, session: Session
) -> None:
"""Save session to SQLite database."""
if not self._db:
raise RuntimeError("Storage not initialized. Call setup() first.")
key = session.identity
if key is None:
key = self._key_factory()
# Ensure key is a string before passing to save_cookie
if isinstance(key, bytes):
key = key.decode('utf-8')
else:
key = str(key)
self.save_cookie(response, key, max_age=session.max_age)
else:
# Ensure key is a string
if isinstance(key, bytes):
key = key.decode('utf-8')
else:
key = str(key)
if session.empty:
self.save_cookie(response, "", max_age=session.max_age)
# Delete the session from database
session_key = f"{self._cookie_name}_{key}"
await self._db.execute("DELETE FROM sessions WHERE key = ?", (session_key,))
await self._db.commit()
return
else:
self.save_cookie(response, key, max_age=session.max_age)
data_str = self._encoder(self._get_session_data(session))
session_key = f"{self._cookie_name}_{key}"
# Calculate expiry time if max_age is set
expires_at = None
if session.max_age is not None:
expires_at = time.time() + session.max_age
# Use INSERT OR REPLACE to handle both new and existing sessions
await self._db.execute(
"INSERT OR REPLACE INTO sessions (key, data, expires_at) VALUES (?, ?, ?)",
(session_key, data_str, expires_at)
)
await self._db.commit()