Async version.
This commit is contained in:
parent
a3ab7b0586
commit
486d78610c
@ -27,6 +27,7 @@ from aiohttp import web
|
|||||||
class AsyncDataSet:
|
class AsyncDataSet:
|
||||||
"""
|
"""
|
||||||
Distributed AsyncDataSet with client-server model over Unix sockets.
|
Distributed AsyncDataSet with client-server model over Unix sockets.
|
||||||
|
Enhanced with dataset API compatibility.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- file: Path to the SQLite database file
|
- file: Path to the SQLite database file
|
||||||
@ -43,6 +44,23 @@ class AsyncDataSet:
|
|||||||
"deleted_at": "TEXT",
|
"deleted_at": "TEXT",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Operator mappings for complex WHERE conditions
|
||||||
|
_OPERATORS = {
|
||||||
|
'eq': '=',
|
||||||
|
'ne': '!=',
|
||||||
|
'lt': '<',
|
||||||
|
'le': '<=',
|
||||||
|
'gt': '>',
|
||||||
|
'ge': '>=',
|
||||||
|
'like': 'LIKE',
|
||||||
|
'ilike': 'LIKE', # SQLite LIKE is case-insensitive by default
|
||||||
|
'in': 'IN',
|
||||||
|
'notin': 'NOT IN',
|
||||||
|
'between': 'BETWEEN',
|
||||||
|
'is': 'IS',
|
||||||
|
'isnot': 'IS NOT',
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1
|
self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1
|
||||||
):
|
):
|
||||||
@ -58,6 +76,7 @@ class AsyncDataSet:
|
|||||||
self._max_concurrent_queries = max_concurrent_queries
|
self._max_concurrent_queries = max_concurrent_queries
|
||||||
self._db_semaphore = None # Will be created when needed
|
self._db_semaphore = None # Will be created when needed
|
||||||
self._db_connection = None # Persistent database connection
|
self._db_connection = None # Persistent database connection
|
||||||
|
self._in_transaction = False # Track transaction state
|
||||||
|
|
||||||
async def _ensure_initialized(self):
|
async def _ensure_initialized(self):
|
||||||
"""Ensure the instance is initialized as server or client."""
|
"""Ensure the instance is initialized as server or client."""
|
||||||
@ -131,6 +150,15 @@ class AsyncDataSet:
|
|||||||
app.router.add_post("/create_table", self._handle_create_table)
|
app.router.add_post("/create_table", self._handle_create_table)
|
||||||
app.router.add_post("/insert_unique", self._handle_insert_unique)
|
app.router.add_post("/insert_unique", self._handle_insert_unique)
|
||||||
app.router.add_post("/aggregate", self._handle_aggregate)
|
app.router.add_post("/aggregate", self._handle_aggregate)
|
||||||
|
app.router.add_post("/distinct", self._handle_distinct)
|
||||||
|
app.router.add_post("/drop_table", self._handle_drop_table)
|
||||||
|
app.router.add_post("/has_table", self._handle_has_table)
|
||||||
|
app.router.add_post("/has_column", self._handle_has_column)
|
||||||
|
app.router.add_post("/get_tables", self._handle_get_tables)
|
||||||
|
app.router.add_post("/get_columns", self._handle_get_columns)
|
||||||
|
app.router.add_post("/begin", self._handle_begin)
|
||||||
|
app.router.add_post("/commit", self._handle_commit)
|
||||||
|
app.router.add_post("/rollback", self._handle_rollback)
|
||||||
|
|
||||||
self._runner = web.AppRunner(app)
|
self._runner = web.AppRunner(app)
|
||||||
await self._runner.setup()
|
await self._runner.setup()
|
||||||
@ -161,6 +189,77 @@ class AsyncDataSet:
|
|||||||
raise Exception(result["error"])
|
raise Exception(result["error"])
|
||||||
return result.get("result")
|
return result.get("result")
|
||||||
|
|
||||||
|
# New server handlers
|
||||||
|
async def _handle_distinct(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_distinct(
|
||||||
|
data["table"], data["columns"], data.get("where")
|
||||||
|
)
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_drop_table(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
await self._server_drop_table(data["table"])
|
||||||
|
return web.json_response({"result": None})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_has_table(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_has_table(data["table"])
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_has_column(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_has_column(data["table"], data["column"])
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_get_tables(self, request):
|
||||||
|
try:
|
||||||
|
result = await self._server_get_tables()
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_get_columns(self, request):
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
result = await self._server_get_columns(data["table"])
|
||||||
|
return web.json_response({"result": result})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_begin(self, request):
|
||||||
|
try:
|
||||||
|
await self._server_begin()
|
||||||
|
return web.json_response({"result": None})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_commit(self, request):
|
||||||
|
try:
|
||||||
|
await self._server_commit()
|
||||||
|
return web.json_response({"result": None})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
async def _handle_rollback(self, request):
|
||||||
|
try:
|
||||||
|
await self._server_rollback()
|
||||||
|
return web.json_response({"result": None})
|
||||||
|
except Exception as e:
|
||||||
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
# Server handlers
|
# Server handlers
|
||||||
async def _handle_insert(self, request):
|
async def _handle_insert(self, request):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
@ -407,6 +506,19 @@ class AsyncDataSet:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# New alias for find() to match dataset API
|
||||||
|
async def all(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
where: Optional[Dict[str, Any]] = None,
|
||||||
|
*,
|
||||||
|
_limit: int = 0,
|
||||||
|
offset: int = 0,
|
||||||
|
order_by: Optional[str] = None,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Alias for find() to match dataset API."""
|
||||||
|
return await self.find(table, where, _limit=_limit, offset=offset, order_by=order_by)
|
||||||
|
|
||||||
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||||
await self._ensure_initialized()
|
await self._ensure_initialized()
|
||||||
if self._is_server:
|
if self._is_server:
|
||||||
@ -520,14 +632,121 @@ class AsyncDataSet:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def transaction(self):
|
# New public methods
|
||||||
"""Context manager for transactions."""
|
async def distinct(
|
||||||
|
self,
|
||||||
|
table: str,
|
||||||
|
columns: Union[str, List[str]],
|
||||||
|
where: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[Any]:
|
||||||
|
"""Get distinct values for specified columns."""
|
||||||
await self._ensure_initialized()
|
await self._ensure_initialized()
|
||||||
|
if isinstance(columns, str):
|
||||||
|
columns = [columns]
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_distinct(table, columns, where)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"distinct", {"table": table, "columns": columns, "where": where}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def drop(self, table: str) -> None:
|
||||||
|
"""Drop a table from the database."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_drop_table(table)
|
||||||
|
else:
|
||||||
|
return await self._make_request("drop_table", {"table": table})
|
||||||
|
|
||||||
|
async def drop_table(self, table: str) -> None:
|
||||||
|
"""Alias for drop() for clarity."""
|
||||||
|
return await self.drop(table)
|
||||||
|
|
||||||
|
async def has_table(self, table: str) -> bool:
|
||||||
|
"""Check if a table exists."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_has_table(table)
|
||||||
|
else:
|
||||||
|
return await self._make_request("has_table", {"table": table})
|
||||||
|
|
||||||
|
async def has_column(self, table: str, column: str) -> bool:
|
||||||
|
"""Check if a column exists in a table."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_has_column(table, column)
|
||||||
|
else:
|
||||||
|
return await self._make_request(
|
||||||
|
"has_column", {"table": table, "column": column}
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def tables(self) -> List[str]:
|
||||||
|
"""Get list of all tables in the database."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_get_tables()
|
||||||
|
else:
|
||||||
|
return await self._make_request("get_tables", {})
|
||||||
|
|
||||||
|
async def columns(self, table: str) -> List[str]:
|
||||||
|
"""Get list of columns for a table."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_get_columns(table)
|
||||||
|
else:
|
||||||
|
return await self._make_request("get_columns", {"table": table})
|
||||||
|
|
||||||
|
async def begin(self) -> None:
|
||||||
|
"""Begin a transaction."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_begin()
|
||||||
|
else:
|
||||||
|
return await self._make_request("begin", {})
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
"""Commit a transaction."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_commit()
|
||||||
|
else:
|
||||||
|
return await self._make_request("commit", {})
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
"""Rollback a transaction."""
|
||||||
|
await self._ensure_initialized()
|
||||||
|
if self._is_server:
|
||||||
|
return await self._server_rollback()
|
||||||
|
else:
|
||||||
|
return await self._make_request("rollback", {})
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Context manager entry - begins transaction."""
|
||||||
|
await self.begin()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Context manager exit - commits or rolls back transaction."""
|
||||||
|
if exc_type is None:
|
||||||
|
await self.commit()
|
||||||
|
else:
|
||||||
|
await self.rollback()
|
||||||
|
|
||||||
|
def transaction(self):
|
||||||
|
"""Context manager for transactions. Note: Must call _ensure_initialized before use."""
|
||||||
|
if not self._initialized:
|
||||||
|
raise RuntimeError("Must await _ensure_initialized() before using transaction()")
|
||||||
if self._is_server:
|
if self._is_server:
|
||||||
return TransactionContext(self._db_connection, self._db_semaphore)
|
return TransactionContext(self._db_connection, self._db_semaphore)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Transactions not supported in client mode")
|
raise NotImplementedError("Transactions not supported in client mode")
|
||||||
|
|
||||||
|
# Enhanced query method with alias
|
||||||
|
async def query(self, sql, *args):
|
||||||
|
"""Alias for query_raw for compatibility."""
|
||||||
|
return await self.query_raw(sql, args if args else None)
|
||||||
|
|
||||||
# Server implementation methods (original logic)
|
# Server implementation methods (original logic)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _utc_iso() -> str:
|
def _utc_iso() -> str:
|
||||||
@ -582,6 +801,7 @@ class AsyncDataSet:
|
|||||||
await self._db_connection.execute(
|
await self._db_connection.execute(
|
||||||
f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
|
f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
|
||||||
)
|
)
|
||||||
|
if not self._in_transaction:
|
||||||
await self._db_connection.commit()
|
await self._db_connection.commit()
|
||||||
await self._invalidate_column_cache(table)
|
await self._invalidate_column_cache(table)
|
||||||
except aiosqlite.OperationalError as e:
|
except aiosqlite.OperationalError as e:
|
||||||
@ -591,8 +811,22 @@ class AsyncDataSet:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
|
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
|
||||||
|
# Check if any column in col_sources might be a primary key (heuristic check)
|
||||||
|
has_primary_key = any(
|
||||||
|
k.lower() == 'id' or k.lower().endswith('_id')
|
||||||
|
for k in col_sources.keys()
|
||||||
|
)
|
||||||
|
|
||||||
# Always include default columns
|
# Always include default columns
|
||||||
cols = self._DEFAULT_COLUMNS.copy()
|
cols = {}
|
||||||
|
|
||||||
|
# Add default columns, but modify uid if there might be another primary key
|
||||||
|
for col, dtype in self._DEFAULT_COLUMNS.items():
|
||||||
|
if col == "uid" and has_primary_key:
|
||||||
|
# Use TEXT UNIQUE instead of PRIMARY KEY if there might be another primary key
|
||||||
|
cols[col] = "TEXT UNIQUE"
|
||||||
|
else:
|
||||||
|
cols[col] = dtype
|
||||||
|
|
||||||
# Add columns from col_sources
|
# Add columns from col_sources
|
||||||
for key, val in col_sources.items():
|
for key, val in col_sources.items():
|
||||||
@ -604,6 +838,7 @@ class AsyncDataSet:
|
|||||||
await self._db_connection.execute(
|
await self._db_connection.execute(
|
||||||
f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
|
f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
|
||||||
)
|
)
|
||||||
|
if not self._in_transaction:
|
||||||
await self._db_connection.commit()
|
await self._db_connection.commit()
|
||||||
await self._invalidate_column_cache(table)
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
@ -645,6 +880,7 @@ class AsyncDataSet:
|
|||||||
try:
|
try:
|
||||||
async with self._db_semaphore:
|
async with self._db_semaphore:
|
||||||
cursor = await self._db_connection.execute(sql, params)
|
cursor = await self._db_connection.execute(sql, params)
|
||||||
|
if not self._in_transaction:
|
||||||
await self._db_connection.commit()
|
await self._db_connection.commit()
|
||||||
return cursor
|
return cursor
|
||||||
except aiosqlite.OperationalError as err:
|
except aiosqlite.OperationalError as err:
|
||||||
@ -742,15 +978,97 @@ class AsyncDataSet:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
|
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
|
||||||
|
"""
|
||||||
|
Build WHERE clause with support for comparison operators.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- Simple equality: {"name": "John"}
|
||||||
|
- Operators: {"age__gt": 25, "name__like": "%John%"}
|
||||||
|
- IN queries: {"id__in": [1, 2, 3]}
|
||||||
|
- BETWEEN: {"age__between": [20, 30]}
|
||||||
|
- Complex conditions with $or, $and
|
||||||
|
"""
|
||||||
if not where:
|
if not where:
|
||||||
return "", []
|
return "", []
|
||||||
clauses, vals = zip(
|
|
||||||
*[
|
clauses = []
|
||||||
(f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v)
|
vals = []
|
||||||
for k, v in where.items()
|
|
||||||
]
|
for key, value in where.items():
|
||||||
)
|
# Handle special keys
|
||||||
return " WHERE " + " AND ".join(clauses), [v for v in vals if v is not None]
|
if key == "$or":
|
||||||
|
or_clauses = []
|
||||||
|
for or_condition in value:
|
||||||
|
or_where, or_vals = AsyncDataSet._build_where(or_condition)
|
||||||
|
if or_where:
|
||||||
|
or_clauses.append(or_where.replace("WHERE ", ""))
|
||||||
|
vals.extend(or_vals)
|
||||||
|
if or_clauses:
|
||||||
|
clauses.append(f"({' OR '.join(or_clauses)})")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key == "$and":
|
||||||
|
for and_condition in value:
|
||||||
|
and_where, and_vals = AsyncDataSet._build_where(and_condition)
|
||||||
|
if and_where:
|
||||||
|
clauses.append(and_where.replace("WHERE ", ""))
|
||||||
|
vals.extend(and_vals)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse field and operator
|
||||||
|
if "__" in key:
|
||||||
|
field, op = key.rsplit("__", 1)
|
||||||
|
op = op.lower()
|
||||||
|
else:
|
||||||
|
field = key
|
||||||
|
op = "eq"
|
||||||
|
|
||||||
|
# Build clause based on operator
|
||||||
|
if op in AsyncDataSet._OPERATORS:
|
||||||
|
sql_op = AsyncDataSet._OPERATORS[op]
|
||||||
|
|
||||||
|
if op == "in":
|
||||||
|
if value:
|
||||||
|
placeholders = ", ".join(["?" for _ in value])
|
||||||
|
clauses.append(f"`{field}` IN ({placeholders})")
|
||||||
|
vals.extend(value)
|
||||||
|
elif op == "notin":
|
||||||
|
if value:
|
||||||
|
placeholders = ", ".join(["?" for _ in value])
|
||||||
|
clauses.append(f"`{field}` NOT IN ({placeholders})")
|
||||||
|
vals.extend(value)
|
||||||
|
elif op == "between":
|
||||||
|
clauses.append(f"`{field}` BETWEEN ? AND ?")
|
||||||
|
vals.extend(value)
|
||||||
|
elif op in ["is", "isnot"]:
|
||||||
|
if value is None:
|
||||||
|
clauses.append(f"`{field}` {sql_op} NULL")
|
||||||
|
else:
|
||||||
|
clauses.append(f"`{field}` {sql_op} ?")
|
||||||
|
vals.append(value)
|
||||||
|
elif op in ["like", "ilike"]:
|
||||||
|
clauses.append(f"`{field}` {sql_op} ?")
|
||||||
|
vals.append(value)
|
||||||
|
else:
|
||||||
|
if value is None:
|
||||||
|
if op == "eq":
|
||||||
|
clauses.append(f"`{field}` IS NULL")
|
||||||
|
elif op == "ne":
|
||||||
|
clauses.append(f"`{field}` IS NOT NULL")
|
||||||
|
else:
|
||||||
|
clauses.append(f"`{field}` {sql_op} ?")
|
||||||
|
vals.append(value)
|
||||||
|
else:
|
||||||
|
# Default to equality
|
||||||
|
if value is None:
|
||||||
|
clauses.append(f"`{field}` IS NULL")
|
||||||
|
else:
|
||||||
|
clauses.append(f"`{field}` = ?")
|
||||||
|
vals.append(value)
|
||||||
|
|
||||||
|
if clauses:
|
||||||
|
return " WHERE " + " AND ".join(clauses), vals
|
||||||
|
return "", []
|
||||||
|
|
||||||
async def _server_insert(
|
async def _server_insert(
|
||||||
self, table: str, args: Dict[str, Any], return_id: bool = False
|
self, table: str, args: Dict[str, Any], return_id: bool = False
|
||||||
@ -766,38 +1084,44 @@ class AsyncDataSet:
|
|||||||
**args,
|
**args,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Ensure table exists with all needed columns
|
|
||||||
await self._ensure_table(table, record)
|
|
||||||
|
|
||||||
# Handle auto-increment ID if requested
|
# Handle auto-increment ID if requested
|
||||||
if return_id and "id" not in args:
|
if return_id and "id" not in args:
|
||||||
|
# For tables with auto-increment ID, ensure the table is created properly
|
||||||
|
# Check if table exists first
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
# Create table with id as primary key and uid as unique
|
||||||
|
schema = {"id": "INTEGER PRIMARY KEY AUTOINCREMENT"}
|
||||||
|
await self._server_create_table(table, schema, None)
|
||||||
|
else:
|
||||||
# Ensure id column exists
|
# Ensure id column exists
|
||||||
|
columns = await self._get_table_columns(table)
|
||||||
|
if "id" not in columns:
|
||||||
async with self._db_semaphore:
|
async with self._db_semaphore:
|
||||||
# Add id column if it doesn't exist
|
# Add id column if it doesn't exist
|
||||||
try:
|
try:
|
||||||
await self._db_connection.execute(
|
await self._db_connection.execute(
|
||||||
f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT"
|
f"ALTER TABLE {table} ADD COLUMN id INTEGER"
|
||||||
)
|
)
|
||||||
|
if not self._in_transaction:
|
||||||
await self._db_connection.commit()
|
await self._db_connection.commit()
|
||||||
except aiosqlite.OperationalError as e:
|
except aiosqlite.OperationalError as e:
|
||||||
if "duplicate column name" not in str(e).lower():
|
if "duplicate column name" not in str(e).lower():
|
||||||
# Try without autoincrement constraint
|
raise
|
||||||
try:
|
|
||||||
await self._db_connection.execute(
|
|
||||||
f"ALTER TABLE {table} ADD COLUMN id INTEGER"
|
|
||||||
)
|
|
||||||
await self._db_connection.commit()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
await self._invalidate_column_cache(table)
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
# Insert and get lastrowid
|
# Insert and get lastrowid
|
||||||
cols = "`" + "`, `".join(record.keys()) + "`"
|
cols = "`" + "`, `".join(record.keys()) + "`"
|
||||||
qs = ", ".join(["?"] * len(record))
|
qs = ", ".join(["?"] * len(record))
|
||||||
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
|
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
|
||||||
|
try:
|
||||||
cursor = await self._safe_execute(table, sql, list(record.values()), record)
|
cursor = await self._safe_execute(table, sql, list(record.values()), record)
|
||||||
|
except Exception as ex:
|
||||||
|
print("Error during insert:", ex)
|
||||||
|
raise
|
||||||
return cursor.lastrowid
|
return cursor.lastrowid
|
||||||
|
else:
|
||||||
|
# Ensure table exists with all needed columns
|
||||||
|
await self._ensure_table(table, record)
|
||||||
|
|
||||||
cols = "`" + "`, `".join(record) + "`"
|
cols = "`" + "`, `".join(record) + "`"
|
||||||
qs = ", ".join(["?"] * len(record))
|
qs = ", ".join(["?"] * len(record))
|
||||||
@ -824,6 +1148,7 @@ class AsyncDataSet:
|
|||||||
all_cols = {**args, **(where or {})}
|
all_cols = {**args, **(where or {})}
|
||||||
await self._ensure_table(table, all_cols)
|
await self._ensure_table(table, all_cols)
|
||||||
for col, val in all_cols.items():
|
for col, val in all_cols.items():
|
||||||
|
if "__" not in col: # Skip operator-based keys
|
||||||
await self._ensure_column(table, col, val)
|
await self._ensure_column(table, col, val)
|
||||||
|
|
||||||
set_clause = ", ".join(f"`{k}` = ?" for k in args)
|
set_clause = ", ".join(f"`{k}` = ?" for k in args)
|
||||||
@ -940,18 +1265,14 @@ class AsyncDataSet:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
async def query(self, sql, *args):
|
|
||||||
return await self._server_query_raw(sql, args)
|
|
||||||
|
|
||||||
async def _server_execute_raw(
|
async def _server_execute_raw(
|
||||||
self, sql: str, params: Optional[Tuple] = None
|
self, sql: str, params: Optional[Tuple] = None
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Execute raw SQL for complex queries like JOINs."""
|
"""Execute raw SQL for complex queries like JOINs."""
|
||||||
async with self._db_semaphore:
|
async with self._db_semaphore:
|
||||||
cursor = await self._db_connection.execute(sql, params or ())
|
cursor = await self._db_connection.execute(sql, params or ())
|
||||||
|
if not self._in_transaction:
|
||||||
await self._db_connection.commit()
|
await self._db_connection.commit()
|
||||||
for x in range(10):
|
|
||||||
print(cursor)
|
|
||||||
return cursor
|
return cursor
|
||||||
|
|
||||||
async def _server_query_raw(
|
async def _server_query_raw(
|
||||||
@ -982,8 +1303,21 @@ class AsyncDataSet:
|
|||||||
constraints: Optional[List[str]] = None,
|
constraints: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
"""Create table with custom schema and constraints. Always includes default columns."""
|
"""Create table with custom schema and constraints. Always includes default columns."""
|
||||||
|
# Check if schema already has a primary key
|
||||||
|
has_primary_key = any("PRIMARY KEY" in dtype.upper() for dtype in schema.values())
|
||||||
|
|
||||||
# Merge default columns with custom schema
|
# Merge default columns with custom schema
|
||||||
full_schema = self._DEFAULT_COLUMNS.copy()
|
full_schema = {}
|
||||||
|
|
||||||
|
# Add default columns, but modify uid if there's already a primary key
|
||||||
|
for col, dtype in self._DEFAULT_COLUMNS.items():
|
||||||
|
if col == "uid" and has_primary_key:
|
||||||
|
# Remove PRIMARY KEY constraint from uid if another primary key exists
|
||||||
|
full_schema[col] = "TEXT UNIQUE"
|
||||||
|
else:
|
||||||
|
full_schema[col] = dtype
|
||||||
|
|
||||||
|
# Add custom schema columns (these override defaults if same name)
|
||||||
full_schema.update(schema)
|
full_schema.update(schema)
|
||||||
|
|
||||||
columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
|
columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
|
||||||
@ -995,6 +1329,7 @@ class AsyncDataSet:
|
|||||||
await self._db_connection.execute(
|
await self._db_connection.execute(
|
||||||
f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
|
f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})"
|
||||||
)
|
)
|
||||||
|
if not self._in_transaction:
|
||||||
await self._db_connection.commit()
|
await self._db_connection.commit()
|
||||||
await self._invalidate_column_cache(table)
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
@ -1026,6 +1361,78 @@ class AsyncDataSet:
|
|||||||
result = await self._server_query_one(sql, tuple(where_params))
|
result = await self._server_query_one(sql, tuple(where_params))
|
||||||
return result["result"] if result else None
|
return result["result"] if result else None
|
||||||
|
|
||||||
|
# New server implementation methods
|
||||||
|
async def _server_distinct(
|
||||||
|
self, table: str, columns: List[str], where: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[Any]:
|
||||||
|
"""Get distinct values for specified columns."""
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return []
|
||||||
|
|
||||||
|
cols = ", ".join(f"`{col}`" for col in columns)
|
||||||
|
where_clause, where_params = self._build_where(where)
|
||||||
|
sql = f"SELECT DISTINCT {cols} FROM {table}{where_clause}"
|
||||||
|
|
||||||
|
results = []
|
||||||
|
async for row in self._safe_query(table, sql, where_params, where or {}):
|
||||||
|
if len(columns) == 1:
|
||||||
|
results.append(row[columns[0]])
|
||||||
|
else:
|
||||||
|
results.append(tuple(row[col] for col in columns))
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _server_drop_table(self, table: str) -> None:
|
||||||
|
"""Drop a table from the database."""
|
||||||
|
async with self._db_semaphore:
|
||||||
|
await self._db_connection.execute(f"DROP TABLE IF EXISTS {table}")
|
||||||
|
if not self._in_transaction:
|
||||||
|
await self._db_connection.commit()
|
||||||
|
await self._invalidate_column_cache(table)
|
||||||
|
|
||||||
|
async def _server_has_table(self, table: str) -> bool:
|
||||||
|
"""Check if a table exists."""
|
||||||
|
return await self._table_exists(table)
|
||||||
|
|
||||||
|
async def _server_has_column(self, table: str, column: str) -> bool:
|
||||||
|
"""Check if a column exists in a table."""
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return False
|
||||||
|
columns = await self._get_table_columns(table)
|
||||||
|
return column in columns
|
||||||
|
|
||||||
|
async def _server_get_tables(self) -> List[str]:
|
||||||
|
"""Get list of all tables in the database."""
|
||||||
|
async with self._db_semaphore:
|
||||||
|
async with self._db_connection.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||||
|
) as cursor:
|
||||||
|
return [row[0] async for row in cursor if not row[0].startswith("sqlite_")]
|
||||||
|
|
||||||
|
async def _server_get_columns(self, table: str) -> List[str]:
|
||||||
|
"""Get list of columns for a table."""
|
||||||
|
if not await self._table_exists(table):
|
||||||
|
return []
|
||||||
|
columns = await self._get_table_columns(table)
|
||||||
|
return list(columns)
|
||||||
|
|
||||||
|
async def _server_begin(self) -> None:
|
||||||
|
"""Begin a transaction."""
|
||||||
|
async with self._db_semaphore:
|
||||||
|
await self._db_connection.execute("BEGIN")
|
||||||
|
self._in_transaction = True
|
||||||
|
|
||||||
|
async def _server_commit(self) -> None:
|
||||||
|
"""Commit a transaction."""
|
||||||
|
async with self._db_semaphore:
|
||||||
|
await self._db_connection.commit()
|
||||||
|
self._in_transaction = False
|
||||||
|
|
||||||
|
async def _server_rollback(self) -> None:
|
||||||
|
"""Rollback a transaction."""
|
||||||
|
async with self._db_semaphore:
|
||||||
|
await self._db_connection.rollback()
|
||||||
|
self._in_transaction = False
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Close the connection or server."""
|
"""Close the connection or server."""
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
@ -1072,7 +1479,7 @@ class TransactionContext:
|
|||||||
self.semaphore.release()
|
self.semaphore.release()
|
||||||
|
|
||||||
|
|
||||||
# Test cases remain the same but with additional tests for new functionality
|
# Enhanced test cases with new functionality
|
||||||
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
@ -1111,6 +1518,177 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(results[0]["name"], "Charlie")
|
self.assertEqual(results[0]["name"], "Charlie")
|
||||||
self.assertEqual(results[-1]["name"], "Bob")
|
self.assertEqual(results[-1]["name"], "Bob")
|
||||||
|
|
||||||
|
async def test_comparison_operators(self):
|
||||||
|
"""Test new comparison operators in WHERE clauses."""
|
||||||
|
await self.connector.insert("people", {"name": "Alice", "age": 25})
|
||||||
|
await self.connector.insert("people", {"name": "Bob", "age": 30})
|
||||||
|
await self.connector.insert("people", {"name": "Charlie", "age": 35})
|
||||||
|
|
||||||
|
# Test greater than
|
||||||
|
results = await self.connector.find("people", {"age__gt": 25})
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
|
||||||
|
# Test less than or equal
|
||||||
|
results = await self.connector.find("people", {"age__le": 30})
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
|
||||||
|
# Test not equal
|
||||||
|
results = await self.connector.find("people", {"age__ne": 30})
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
|
||||||
|
# Test LIKE
|
||||||
|
results = await self.connector.find("people", {"name__like": "A%"})
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertEqual(results[0]["name"], "Alice")
|
||||||
|
|
||||||
|
async def test_in_operator(self):
|
||||||
|
"""Test IN and NOT IN operators."""
|
||||||
|
await self.connector.insert("people", {"name": "Alice", "age": 25})
|
||||||
|
await self.connector.insert("people", {"name": "Bob", "age": 30})
|
||||||
|
await self.connector.insert("people", {"name": "Charlie", "age": 35})
|
||||||
|
|
||||||
|
# Test IN
|
||||||
|
results = await self.connector.find("people", {"age__in": [25, 35]})
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
|
||||||
|
# Test NOT IN
|
||||||
|
results = await self.connector.find("people", {"age__notin": [25, 35]})
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertEqual(results[0]["name"], "Bob")
|
||||||
|
|
||||||
|
async def test_between_operator(self):
|
||||||
|
"""Test BETWEEN operator."""
|
||||||
|
await self.connector.insert("people", {"name": "Alice", "age": 25})
|
||||||
|
await self.connector.insert("people", {"name": "Bob", "age": 30})
|
||||||
|
await self.connector.insert("people", {"name": "Charlie", "age": 35})
|
||||||
|
|
||||||
|
results = await self.connector.find("people", {"age__between": [26, 34]})
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertEqual(results[0]["name"], "Bob")
|
||||||
|
|
||||||
|
async def test_complex_where_conditions(self):
|
||||||
|
"""Test complex WHERE conditions with $or and $and."""
|
||||||
|
await self.connector.insert("people", {"name": "Alice", "age": 25, "city": "NYC"})
|
||||||
|
await self.connector.insert("people", {"name": "Bob", "age": 30, "city": "LA"})
|
||||||
|
await self.connector.insert("people", {"name": "Charlie", "age": 35, "city": "NYC"})
|
||||||
|
|
||||||
|
# Test $or
|
||||||
|
results = await self.connector.find("people", {
|
||||||
|
"$or": [{"age": 25}, {"city": "LA"}]
|
||||||
|
})
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
|
||||||
|
# Test $and with operators
|
||||||
|
results = await self.connector.find("people", {
|
||||||
|
"$and": [{"age__gt": 25}, {"city": "NYC"}]
|
||||||
|
})
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertEqual(results[0]["name"], "Charlie")
|
||||||
|
|
||||||
|
async def test_distinct(self):
|
||||||
|
"""Test distinct values."""
|
||||||
|
await self.connector.insert("people", {"name": "Alice", "city": "NYC"})
|
||||||
|
await self.connector.insert("people", {"name": "Bob", "city": "LA"})
|
||||||
|
await self.connector.insert("people", {"name": "Charlie", "city": "NYC"})
|
||||||
|
|
||||||
|
cities = await self.connector.distinct("people", "city")
|
||||||
|
self.assertEqual(set(cities), {"NYC", "LA"})
|
||||||
|
|
||||||
|
# Test multiple columns
|
||||||
|
results = await self.connector.distinct("people", ["city", "name"])
|
||||||
|
self.assertEqual(len(results), 3)
|
||||||
|
|
||||||
|
async def test_all_alias(self):
|
||||||
|
"""Test that all() is an alias for find()."""
|
||||||
|
await self.connector.insert("people", {"name": "Alice", "age": 25})
|
||||||
|
await self.connector.insert("people", {"name": "Bob", "age": 30})
|
||||||
|
|
||||||
|
find_results = await self.connector.find("people")
|
||||||
|
all_results = await self.connector.all("people")
|
||||||
|
|
||||||
|
self.assertEqual(find_results, all_results)
|
||||||
|
|
||||||
|
async def test_table_operations(self):
|
||||||
|
"""Test table-related operations."""
|
||||||
|
# Create a table
|
||||||
|
await self.connector.insert("test_table", {"field": "value"})
|
||||||
|
|
||||||
|
# Check if table exists
|
||||||
|
self.assertTrue(await self.connector.has_table("test_table"))
|
||||||
|
self.assertFalse(await self.connector.has_table("nonexistent"))
|
||||||
|
|
||||||
|
# Get tables list
|
||||||
|
tables = await self.connector.tables
|
||||||
|
self.assertIn("test_table", tables)
|
||||||
|
|
||||||
|
# Drop table
|
||||||
|
await self.connector.drop("test_table")
|
||||||
|
self.assertFalse(await self.connector.has_table("test_table"))
|
||||||
|
|
||||||
|
async def test_column_operations(self):
|
||||||
|
"""Test column-related operations."""
|
||||||
|
await self.connector.insert("people", {"name": "Alice", "age": 25})
|
||||||
|
|
||||||
|
# Check if column exists
|
||||||
|
self.assertTrue(await self.connector.has_column("people", "name"))
|
||||||
|
self.assertTrue(await self.connector.has_column("people", "age"))
|
||||||
|
self.assertFalse(await self.connector.has_column("people", "nonexistent"))
|
||||||
|
|
||||||
|
# Get columns list
|
||||||
|
columns = await self.connector.columns("people")
|
||||||
|
self.assertIn("name", columns)
|
||||||
|
self.assertIn("age", columns)
|
||||||
|
self.assertIn("uid", columns)
|
||||||
|
|
||||||
|
async def test_transaction_methods(self):
|
||||||
|
"""Test transaction methods."""
|
||||||
|
await self.connector._ensure_initialized()
|
||||||
|
if self.connector._is_server:
|
||||||
|
# Begin transaction
|
||||||
|
await self.connector.begin()
|
||||||
|
|
||||||
|
# Insert during transaction
|
||||||
|
await self.connector.insert("people", {"name": "Alice", "age": 25})
|
||||||
|
|
||||||
|
# Rollback
|
||||||
|
await self.connector.rollback()
|
||||||
|
|
||||||
|
# Check that data was not committed
|
||||||
|
result = await self.connector.get("people", {"name": "Alice"})
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
# Test commit
|
||||||
|
await self.connector.begin()
|
||||||
|
await self.connector.insert("people", {"name": "Bob", "age": 30})
|
||||||
|
await self.connector.commit()
|
||||||
|
|
||||||
|
# Check that data was committed
|
||||||
|
result = await self.connector.get("people", {"name": "Bob"})
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
|
||||||
|
async def test_context_manager_transaction(self):
|
||||||
|
"""Test transaction as context manager."""
|
||||||
|
await self.connector._ensure_initialized()
|
||||||
|
if self.connector._is_server:
|
||||||
|
# Test successful transaction
|
||||||
|
async with self.connector:
|
||||||
|
await self.connector.insert("people", {"name": "Alice", "age": 25})
|
||||||
|
|
||||||
|
result = await self.connector.get("people", {"name": "Alice"})
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
|
||||||
|
# Test failed transaction (should rollback)
|
||||||
|
try:
|
||||||
|
async with self.connector:
|
||||||
|
await self.connector.insert("people", {"name": "Bob", "age": 30})
|
||||||
|
raise Exception("Test exception")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = await self.connector.get("people", {"name": "Bob"})
|
||||||
|
# This might still exist because exception was after insert
|
||||||
|
# The behavior depends on SQLite's transaction handling
|
||||||
|
|
||||||
async def test_raw_query(self):
|
async def test_raw_query(self):
|
||||||
await self.connector.insert("people", {"name": "John", "age": 30})
|
await self.connector.insert("people", {"name": "John", "age": 30})
|
||||||
await self.connector.insert("people", {"name": "Jane", "age": 25})
|
await self.connector.insert("people", {"name": "Jane", "age": 25})
|
||||||
|
|||||||
1209
src/snek/system/ads.py.old
Normal file
1209
src/snek/system/ads.py.old
Normal file
File diff suppressed because it is too large
Load Diff
@ -93,6 +93,7 @@ class BaseMapper:
|
|||||||
return "insert" in sql or "update" in sql or "delete" in sql
|
return "insert" in sql or "update" in sql or "delete" in sql
|
||||||
|
|
||||||
async def query(self, sql, *args):
|
async def query(self, sql, *args):
|
||||||
|
print("XXXXXXXX", sql,args)
|
||||||
for record in await self.db.query(sql, *args):
|
for record in await self.db.query(sql, *args):
|
||||||
yield dict(record)
|
yield dict(record)
|
||||||
|
|
||||||
|
|||||||
225
src/snek/system/session.py
Normal file
225
src/snek/system/session.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from aiohttp_session import AbstractStorage, Session
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteSessionStorage(AbstractStorage):
|
||||||
|
"""SQLite storage using aiosqlite"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cookie_name: str = "AIOHTTP_SESSION",
|
||||||
|
*,
|
||||||
|
db_path: str = "sessions.db",
|
||||||
|
domain: Optional[str] = None,
|
||||||
|
max_age: Optional[int] = None,
|
||||||
|
path: str = "/",
|
||||||
|
secure: Optional[bool] = None,
|
||||||
|
httponly: bool = True,
|
||||||
|
samesite: Optional[str] = None,
|
||||||
|
key_factory: Callable[[], str] = lambda: uuid.uuid4().hex,
|
||||||
|
encoder: Callable[[object], str] = json.dumps,
|
||||||
|
decoder: Callable[[str], Any] = json.loads,
|
||||||
|
) -> None:
|
||||||
|
# Debug logging
|
||||||
|
print(f"SQLiteStorage init - cookie_name type: {type(cookie_name)}, value: {cookie_name!r}")
|
||||||
|
|
||||||
|
# Ensure all string parameters are actually strings, not bytes
|
||||||
|
if isinstance(cookie_name, bytes):
|
||||||
|
cookie_name = cookie_name.decode('utf-8')
|
||||||
|
if isinstance(domain, bytes):
|
||||||
|
domain = domain.decode('utf-8')
|
||||||
|
if isinstance(path, bytes):
|
||||||
|
path = path.decode('utf-8')
|
||||||
|
if isinstance(samesite, bytes):
|
||||||
|
samesite = samesite.decode('utf-8')
|
||||||
|
|
||||||
|
# Convert to string explicitly
|
||||||
|
cookie_name = str(cookie_name)
|
||||||
|
path = str(path)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
cookie_name=cookie_name,
|
||||||
|
domain=domain,
|
||||||
|
max_age=max_age,
|
||||||
|
path=path,
|
||||||
|
secure=secure,
|
||||||
|
httponly=httponly,
|
||||||
|
samesite=samesite,
|
||||||
|
encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
|
)
|
||||||
|
self._key_factory = key_factory
|
||||||
|
self._db_path = db_path
|
||||||
|
self._db: Optional[aiosqlite.Connection] = None
|
||||||
|
|
||||||
|
def save_cookie(
|
||||||
|
self,
|
||||||
|
response: web.StreamResponse,
|
||||||
|
cookie_data: str,
|
||||||
|
*,
|
||||||
|
max_age: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Override to ensure proper string handling and add debugging."""
|
||||||
|
# Ensure cookie_data is a string
|
||||||
|
if isinstance(cookie_data, bytes):
|
||||||
|
cookie_data = cookie_data.decode('utf-8')
|
||||||
|
elif cookie_data is not None:
|
||||||
|
cookie_data = str(cookie_data)
|
||||||
|
|
||||||
|
# Ensure cookie name is a string
|
||||||
|
cookie_name = self._cookie_name
|
||||||
|
if isinstance(cookie_name, bytes):
|
||||||
|
print(f"WARNING: _cookie_name is bytes: {cookie_name!r}")
|
||||||
|
cookie_name = cookie_name.decode('utf-8')
|
||||||
|
cookie_name = str(cookie_name)
|
||||||
|
|
||||||
|
# Get cookie params and ensure they're all strings
|
||||||
|
params = self._cookie_params.copy()
|
||||||
|
if max_age is not None:
|
||||||
|
params["max_age"] = max_age
|
||||||
|
t = time.gmtime(time.time() + max_age)
|
||||||
|
params["expires"] = time.strftime("%a, %d-%b-%Y %T GMT", t)
|
||||||
|
|
||||||
|
# Ensure all string params in the dict are strings, not bytes
|
||||||
|
for key in ['domain', 'path', 'samesite']:
|
||||||
|
if key in params and isinstance(params[key], bytes):
|
||||||
|
params[key] = params[key].decode('utf-8')
|
||||||
|
|
||||||
|
if not cookie_data:
|
||||||
|
response.del_cookie(
|
||||||
|
cookie_name, domain=params.get("domain"), path=params.get("path")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response.set_cookie(cookie_name, cookie_data, **params)
|
||||||
|
|
||||||
|
async def setup(self) -> None:
|
||||||
|
"""Initialize the database connection and create table if needed."""
|
||||||
|
self._db = await aiosqlite.connect(self._db_path)
|
||||||
|
await self._db.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
data TEXT NOT NULL,
|
||||||
|
expires_at REAL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
# Create index for expiry cleanup
|
||||||
|
await self._db.execute("""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_expires_at
|
||||||
|
ON sessions(expires_at)
|
||||||
|
WHERE expires_at IS NOT NULL
|
||||||
|
""")
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def cleanup(self) -> None:
|
||||||
|
"""Close the database connection and cleanup expired sessions."""
|
||||||
|
if self._db:
|
||||||
|
# Clean up expired sessions before closing
|
||||||
|
await self._cleanup_expired()
|
||||||
|
await self._db.close()
|
||||||
|
self._db = None
|
||||||
|
|
||||||
|
async def _cleanup_expired(self) -> None:
|
||||||
|
"""Remove expired sessions from the database."""
|
||||||
|
if self._db:
|
||||||
|
current_time = time.time()
|
||||||
|
await self._db.execute(
|
||||||
|
"DELETE FROM sessions WHERE expires_at IS NOT NULL AND expires_at < ?",
|
||||||
|
(current_time,)
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
|
|
||||||
|
async def load_session(self, request: web.Request) -> Session:
|
||||||
|
"""Load session from SQLite database."""
|
||||||
|
cookie = self.load_cookie(request)
|
||||||
|
if cookie is None:
|
||||||
|
return Session(None, data=None, new=True, max_age=self.max_age)
|
||||||
|
|
||||||
|
if not self._db:
|
||||||
|
raise RuntimeError("Storage not initialized. Call setup() first.")
|
||||||
|
|
||||||
|
# Periodically clean up expired sessions
|
||||||
|
# You might want to do this less frequently in production
|
||||||
|
await self._cleanup_expired()
|
||||||
|
|
||||||
|
key = str(cookie)
|
||||||
|
session_key = f"{self._cookie_name}_{key}"
|
||||||
|
|
||||||
|
async with self._db.execute(
|
||||||
|
"SELECT data, expires_at FROM sessions WHERE key = ?",
|
||||||
|
(session_key,)
|
||||||
|
) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return Session(None, data=None, new=True, max_age=self.max_age)
|
||||||
|
|
||||||
|
data_str, expires_at = row
|
||||||
|
|
||||||
|
# Check if session has expired
|
||||||
|
if expires_at is not None and expires_at < time.time():
|
||||||
|
# Session expired, delete it
|
||||||
|
await self._db.execute("DELETE FROM sessions WHERE key = ?", (session_key,))
|
||||||
|
await self._db.commit()
|
||||||
|
return Session(None, data=None, new=True, max_age=self.max_age)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = self._decoder(data_str)
|
||||||
|
except ValueError:
|
||||||
|
data = None
|
||||||
|
|
||||||
|
return Session(key, data=data, new=False, max_age=self.max_age)
|
||||||
|
|
||||||
|
async def save_session(
|
||||||
|
self, request: web.Request, response: web.StreamResponse, session: Session
|
||||||
|
) -> None:
|
||||||
|
"""Save session to SQLite database."""
|
||||||
|
if not self._db:
|
||||||
|
raise RuntimeError("Storage not initialized. Call setup() first.")
|
||||||
|
|
||||||
|
key = session.identity
|
||||||
|
if key is None:
|
||||||
|
key = self._key_factory()
|
||||||
|
# Ensure key is a string before passing to save_cookie
|
||||||
|
if isinstance(key, bytes):
|
||||||
|
key = key.decode('utf-8')
|
||||||
|
else:
|
||||||
|
key = str(key)
|
||||||
|
self.save_cookie(response, key, max_age=session.max_age)
|
||||||
|
else:
|
||||||
|
# Ensure key is a string
|
||||||
|
if isinstance(key, bytes):
|
||||||
|
key = key.decode('utf-8')
|
||||||
|
else:
|
||||||
|
key = str(key)
|
||||||
|
|
||||||
|
if session.empty:
|
||||||
|
self.save_cookie(response, "", max_age=session.max_age)
|
||||||
|
# Delete the session from database
|
||||||
|
session_key = f"{self._cookie_name}_{key}"
|
||||||
|
await self._db.execute("DELETE FROM sessions WHERE key = ?", (session_key,))
|
||||||
|
await self._db.commit()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.save_cookie(response, key, max_age=session.max_age)
|
||||||
|
|
||||||
|
data_str = self._encoder(self._get_session_data(session))
|
||||||
|
session_key = f"{self._cookie_name}_{key}"
|
||||||
|
|
||||||
|
# Calculate expiry time if max_age is set
|
||||||
|
expires_at = None
|
||||||
|
if session.max_age is not None:
|
||||||
|
expires_at = time.time() + session.max_age
|
||||||
|
|
||||||
|
# Use INSERT OR REPLACE to handle both new and existing sessions
|
||||||
|
await self._db.execute(
|
||||||
|
"INSERT OR REPLACE INTO sessions (key, data, expires_at) VALUES (?, ?, ?)",
|
||||||
|
(session_key, data_str, expires_at)
|
||||||
|
)
|
||||||
|
await self._db.commit()
|
||||||
Loading…
Reference in New Issue
Block a user