feat: implement distributed ads with unix sockets
This commit is contained in:
parent
f59894c65a
commit
66a45eb5e6
@ -2,6 +2,14 @@
|
||||
|
||||
|
||||
|
||||
|
||||
## Version 1.5.0 - 2025-11-06
|
||||
|
||||
Agents can now use datasets stored across multiple locations. This allows for larger datasets and improved performance.
|
||||
|
||||
**Changes:** 2 files, 10 lines
|
||||
**Languages:** Markdown (8 lines), TOML (2 lines)
|
||||
|
||||
## Version 1.4.0 - 2025-11-06
|
||||
|
||||
Agents can now share data more efficiently using a new distributed dataset system. This improves performance and allows agents to work together on larger tasks.
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
__version__ = "1.0.0"
|
||||
|
||||
from pr.core import Assistant
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__all__ = ["Assistant"]
|
||||
|
||||
@ -1,13 +1,18 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from pr import __version__
|
||||
from pr.core import Assistant
|
||||
|
||||
|
||||
def main():
|
||||
async def main_async():
|
||||
import tracemalloc
|
||||
|
||||
tracemalloc.start()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PR Assistant - Professional CLI AI assistant with autonomous execution",
|
||||
description="PR Assistant - Professional CLI AI assistant with visual effects, cost tracking, and autonomous execution",
|
||||
epilog="""
|
||||
Examples:
|
||||
pr "What is Python?" # Single query
|
||||
@ -18,6 +23,12 @@ Examples:
|
||||
pr --list-sessions # List all sessions
|
||||
pr --usage # Show token usage stats
|
||||
|
||||
Features:
|
||||
• Visual progress indicators during AI calls
|
||||
• Real-time cost tracking for each query
|
||||
• Sophisticated CLI with colors and effects
|
||||
• Tool execution with status updates
|
||||
|
||||
Commands in interactive mode:
|
||||
/auto [task] - Enter autonomous mode
|
||||
/reset - Clear message history
|
||||
@ -156,7 +167,11 @@ Commands in interactive mode:
|
||||
return
|
||||
|
||||
assistant = Assistant(args)
|
||||
assistant.run()
|
||||
await assistant.run()
|
||||
|
||||
|
||||
def main():
|
||||
return asyncio.run(main_async())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
969
pr/ads.py
969
pr/ads.py
@ -1,969 +0,0 @@
|
||||
import re
|
||||
import json
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
AsyncGenerator,
|
||||
Union,
|
||||
Tuple,
|
||||
Set,
|
||||
)
|
||||
import aiosqlite
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
import socket
|
||||
import os
|
||||
import atexit
|
||||
|
||||
|
||||
class AsyncDataSet:
|
||||
"""
|
||||
Distributed AsyncDataSet with client-server model over Unix sockets.
|
||||
|
||||
Parameters:
|
||||
- file: Path to the SQLite database file
|
||||
- socket_path: Path to Unix socket for inter-process communication (default: "ads.sock")
|
||||
- max_concurrent_queries: Maximum concurrent database operations (default: 1)
|
||||
Set to 1 to prevent "database is locked" errors with SQLite
|
||||
"""
|
||||
|
||||
_KV_TABLE = "__kv_store"
|
||||
_DEFAULT_COLUMNS = {
|
||||
"uid": "TEXT PRIMARY KEY",
|
||||
"created_at": "TEXT",
|
||||
"updated_at": "TEXT",
|
||||
"deleted_at": "TEXT",
|
||||
}
|
||||
|
||||
def __init__(self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1):
|
||||
self._file = file
|
||||
self._socket_path = socket_path
|
||||
self._table_columns_cache: Dict[str, Set[str]] = {}
|
||||
self._is_server = None # None means not initialized yet
|
||||
self._server = None
|
||||
self._runner = None
|
||||
self._client_session = None
|
||||
self._initialized = False
|
||||
self._init_lock = None # Will be created when needed
|
||||
self._max_concurrent_queries = max_concurrent_queries
|
||||
self._db_semaphore = None # Will be created when needed
|
||||
self._db_connection = None # Persistent database connection
|
||||
|
||||
async def _ensure_initialized(self):
|
||||
"""Ensure the instance is initialized as server or client."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# Create lock if needed (first time in async context)
|
||||
if self._init_lock is None:
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
async with self._init_lock:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
await self._initialize()
|
||||
|
||||
# Create semaphore for database operations (server only)
|
||||
if self._is_server:
|
||||
self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries)
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def _initialize(self):
|
||||
"""Initialize as server or client."""
|
||||
try:
|
||||
# Try to create Unix socket
|
||||
if os.path.exists(self._socket_path):
|
||||
# Check if socket is active
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.connect(self._socket_path)
|
||||
sock.close()
|
||||
# Socket is active, we're a client
|
||||
self._is_server = False
|
||||
await self._setup_client()
|
||||
except (ConnectionRefusedError, FileNotFoundError):
|
||||
# Socket file exists but not active, remove and become server
|
||||
os.unlink(self._socket_path)
|
||||
self._is_server = True
|
||||
await self._setup_server()
|
||||
else:
|
||||
# No socket file, become server
|
||||
self._is_server = True
|
||||
await self._setup_server()
|
||||
except Exception:
|
||||
# Fallback to client mode
|
||||
self._is_server = False
|
||||
await self._setup_client()
|
||||
|
||||
# Establish persistent database connection
|
||||
self._db_connection = await aiosqlite.connect(self._file)
|
||||
|
||||
async def _setup_server(self):
|
||||
"""Setup aiohttp server."""
|
||||
app = web.Application()
|
||||
|
||||
# Add routes for all methods
|
||||
app.router.add_post("/insert", self._handle_insert)
|
||||
app.router.add_post("/update", self._handle_update)
|
||||
app.router.add_post("/delete", self._handle_delete)
|
||||
app.router.add_post("/upsert", self._handle_upsert)
|
||||
app.router.add_post("/get", self._handle_get)
|
||||
app.router.add_post("/find", self._handle_find)
|
||||
app.router.add_post("/count", self._handle_count)
|
||||
app.router.add_post("/exists", self._handle_exists)
|
||||
app.router.add_post("/kv_set", self._handle_kv_set)
|
||||
app.router.add_post("/kv_get", self._handle_kv_get)
|
||||
app.router.add_post("/execute_raw", self._handle_execute_raw)
|
||||
app.router.add_post("/query_raw", self._handle_query_raw)
|
||||
app.router.add_post("/query_one", self._handle_query_one)
|
||||
app.router.add_post("/create_table", self._handle_create_table)
|
||||
app.router.add_post("/insert_unique", self._handle_insert_unique)
|
||||
app.router.add_post("/aggregate", self._handle_aggregate)
|
||||
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
self._server = web.UnixSite(self._runner, self._socket_path)
|
||||
await self._server.start()
|
||||
|
||||
# Register cleanup
|
||||
def cleanup():
|
||||
if os.path.exists(self._socket_path):
|
||||
os.unlink(self._socket_path)
|
||||
|
||||
atexit.register(cleanup)
|
||||
|
||||
async def _setup_client(self):
|
||||
"""Setup aiohttp client."""
|
||||
connector = aiohttp.UnixConnector(path=self._socket_path)
|
||||
self._client_session = aiohttp.ClientSession(connector=connector)
|
||||
|
||||
async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any:
|
||||
"""Make HTTP request to server."""
|
||||
if not self._client_session:
|
||||
await self._setup_client()
|
||||
|
||||
url = f"http://localhost/{endpoint}"
|
||||
async with self._client_session.post(url, json=data) as resp:
|
||||
result = await resp.json()
|
||||
if result.get("error"):
|
||||
raise Exception(result["error"])
|
||||
return result.get("result")
|
||||
|
||||
# Server handlers
|
||||
async def _handle_insert(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_insert(
|
||||
data["table"], data["args"], data.get("return_id", False)
|
||||
)
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_update(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_update(data["table"], data["args"], data.get("where"))
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_delete(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_delete(data["table"], data.get("where"))
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_upsert(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_upsert(data["table"], data["args"], data.get("where"))
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_get(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_get(data["table"], data.get("where"))
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_find(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_find(
|
||||
data["table"],
|
||||
data.get("where"),
|
||||
limit=data.get("limit", 0),
|
||||
offset=data.get("offset", 0),
|
||||
order_by=data.get("order_by"),
|
||||
)
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_count(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_count(data["table"], data.get("where"))
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_exists(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_exists(data["table"], data["where"])
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_kv_set(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
await self._server_kv_set(data["key"], data["value"], table=data.get("table"))
|
||||
return web.json_response({"result": None})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_kv_get(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_kv_get(
|
||||
data["key"], default=data.get("default"), table=data.get("table")
|
||||
)
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_execute_raw(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_execute_raw(
|
||||
data["sql"],
|
||||
tuple(data.get("params", [])) if data.get("params") else None,
|
||||
)
|
||||
return web.json_response(
|
||||
{"result": result.rowcount if hasattr(result, "rowcount") else None}
|
||||
)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_query_raw(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_query_raw(
|
||||
data["sql"],
|
||||
tuple(data.get("params", [])) if data.get("params") else None,
|
||||
)
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_query_one(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_query_one(
|
||||
data["sql"],
|
||||
tuple(data.get("params", [])) if data.get("params") else None,
|
||||
)
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_create_table(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
await self._server_create_table(data["table"], data["schema"], data.get("constraints"))
|
||||
return web.json_response({"result": None})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_insert_unique(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_insert_unique(
|
||||
data["table"], data["args"], data["unique_fields"]
|
||||
)
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_aggregate(self, request):
|
||||
data = await request.json()
|
||||
try:
|
||||
result = await self._server_aggregate(
|
||||
data["table"],
|
||||
data["function"],
|
||||
data.get("column", "*"),
|
||||
data.get("where"),
|
||||
)
|
||||
return web.json_response({"result": result})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
# Public methods that delegate to server or client
|
||||
async def insert(
|
||||
self, table: str, args: Dict[str, Any], return_id: bool = False
|
||||
) -> Union[str, int]:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_insert(table, args, return_id)
|
||||
else:
|
||||
return await self._make_request(
|
||||
"insert", {"table": table, "args": args, "return_id": return_id}
|
||||
)
|
||||
|
||||
async def update(
|
||||
self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
|
||||
) -> int:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_update(table, args, where)
|
||||
else:
|
||||
return await self._make_request(
|
||||
"update", {"table": table, "args": args, "where": where}
|
||||
)
|
||||
|
||||
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_delete(table, where)
|
||||
else:
|
||||
return await self._make_request("delete", {"table": table, "where": where})
|
||||
|
||||
async def get(
|
||||
self, table: str, where: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_get(table, where)
|
||||
else:
|
||||
return await self._make_request("get", {"table": table, "where": where})
|
||||
|
||||
async def find(
|
||||
self,
|
||||
table: str,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
limit: int = 0,
|
||||
offset: int = 0,
|
||||
order_by: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_find(
|
||||
table, where, limit=limit, offset=offset, order_by=order_by
|
||||
)
|
||||
else:
|
||||
return await self._make_request(
|
||||
"find",
|
||||
{
|
||||
"table": table,
|
||||
"where": where,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"order_by": order_by,
|
||||
},
|
||||
)
|
||||
|
||||
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_count(table, where)
|
||||
else:
|
||||
return await self._make_request("count", {"table": table, "where": where})
|
||||
|
||||
async def exists(self, table: str, where: Dict[str, Any]) -> bool:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_exists(table, where)
|
||||
else:
|
||||
return await self._make_request("exists", {"table": table, "where": where})
|
||||
|
||||
async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_query_raw(sql, params)
|
||||
else:
|
||||
return await self._make_request(
|
||||
"query_raw", {"sql": sql, "params": list(params) if params else None}
|
||||
)
|
||||
|
||||
async def create_table(
|
||||
self,
|
||||
table: str,
|
||||
schema: Dict[str, str],
|
||||
constraints: Optional[List[str]] = None,
|
||||
):
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_create_table(table, schema, constraints)
|
||||
else:
|
||||
return await self._make_request(
|
||||
"create_table",
|
||||
{"table": table, "schema": schema, "constraints": constraints},
|
||||
)
|
||||
|
||||
async def insert_unique(
|
||||
self, table: str, args: Dict[str, Any], unique_fields: List[str]
|
||||
) -> Union[str, None]:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_insert_unique(table, args, unique_fields)
|
||||
else:
|
||||
return await self._make_request(
|
||||
"insert_unique",
|
||||
{"table": table, "args": args, "unique_fields": unique_fields},
|
||||
)
|
||||
|
||||
async def aggregate(
|
||||
self,
|
||||
table: str,
|
||||
function: str,
|
||||
column: str = "*",
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return await self._server_aggregate(table, function, column, where)
|
||||
else:
|
||||
return await self._make_request(
|
||||
"aggregate",
|
||||
{
|
||||
"table": table,
|
||||
"function": function,
|
||||
"column": column,
|
||||
"where": where,
|
||||
},
|
||||
)
|
||||
|
||||
async def transaction(self):
|
||||
"""Context manager for transactions."""
|
||||
await self._ensure_initialized()
|
||||
if self._is_server:
|
||||
return TransactionContext(self._db_connection, self._db_semaphore)
|
||||
else:
|
||||
raise NotImplementedError("Transactions not supported in client mode")
|
||||
|
||||
# Server implementation methods (original logic)
|
||||
@staticmethod
|
||||
def _utc_iso() -> str:
|
||||
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
||||
|
||||
@staticmethod
|
||||
def _py_to_sqlite_type(value: Any) -> str:
|
||||
if value is None:
|
||||
return "TEXT"
|
||||
if isinstance(value, bool):
|
||||
return "INTEGER"
|
||||
if isinstance(value, int):
|
||||
return "INTEGER"
|
||||
if isinstance(value, float):
|
||||
return "REAL"
|
||||
if isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return "BLOB"
|
||||
return "TEXT"
|
||||
|
||||
async def _get_table_columns(self, table: str) -> Set[str]:
|
||||
"""Get actual columns that exist in the table."""
|
||||
if table in self._table_columns_cache:
|
||||
return self._table_columns_cache[table]
|
||||
|
||||
columns = set()
|
||||
try:
|
||||
async with self._db_semaphore:
|
||||
async with self._db_connection.execute(f"PRAGMA table_info({table})") as cursor:
|
||||
async for row in cursor:
|
||||
columns.add(row[1]) # Column name is at index 1
|
||||
self._table_columns_cache[table] = columns
|
||||
except:
|
||||
pass
|
||||
return columns
|
||||
|
||||
async def _invalidate_column_cache(self, table: str):
|
||||
"""Invalidate column cache for a table."""
|
||||
if table in self._table_columns_cache:
|
||||
del self._table_columns_cache[table]
|
||||
|
||||
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
|
||||
col_type = self._py_to_sqlite_type(value)
|
||||
try:
|
||||
async with self._db_semaphore:
|
||||
await self._db_connection.execute(
|
||||
f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
|
||||
)
|
||||
await self._db_connection.commit()
|
||||
await self._invalidate_column_cache(table)
|
||||
except aiosqlite.OperationalError as e:
|
||||
if "duplicate column name" in str(e).lower():
|
||||
pass # Column already exists
|
||||
else:
|
||||
raise
|
||||
|
||||
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
|
||||
# Always include default columns
|
||||
cols = self._DEFAULT_COLUMNS.copy()
|
||||
|
||||
# Add columns from col_sources
|
||||
for key, val in col_sources.items():
|
||||
if key not in cols:
|
||||
cols[key] = self._py_to_sqlite_type(val)
|
||||
|
||||
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
|
||||
async with self._db_semaphore:
|
||||
await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
||||
await self._db_connection.commit()
|
||||
await self._invalidate_column_cache(table)
|
||||
|
||||
async def _table_exists(self, table: str) -> bool:
|
||||
"""Check if a table exists."""
|
||||
async with self._db_semaphore:
|
||||
async with self._db_connection.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
|
||||
) as cursor:
|
||||
return await cursor.fetchone() is not None
|
||||
|
||||
_RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
|
||||
_RE_NO_TABLE = re.compile(r"no such table: (\w+)")
|
||||
|
||||
@classmethod
|
||||
def _missing_column_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]:
|
||||
m = cls._RE_NO_COLUMN.search(str(err))
|
||||
return m.group(1) if m else None
|
||||
|
||||
@classmethod
|
||||
def _missing_table_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]:
|
||||
m = cls._RE_NO_TABLE.search(str(err))
|
||||
return m.group(1) if m else None
|
||||
|
||||
async def _safe_execute(
|
||||
self,
|
||||
table: str,
|
||||
sql: str,
|
||||
params: Iterable[Any],
|
||||
col_sources: Dict[str, Any],
|
||||
max_retries: int = 10,
|
||||
) -> aiosqlite.Cursor:
|
||||
retries = 0
|
||||
while retries < max_retries:
|
||||
try:
|
||||
async with self._db_semaphore:
|
||||
cursor = await self._db_connection.execute(sql, params)
|
||||
await self._db_connection.commit()
|
||||
return cursor
|
||||
except aiosqlite.OperationalError as err:
|
||||
retries += 1
|
||||
err_str = str(err).lower()
|
||||
|
||||
# Handle missing column
|
||||
col = self._missing_column_from_error(err)
|
||||
if col:
|
||||
if col in col_sources:
|
||||
await self._ensure_column(table, col, col_sources[col])
|
||||
else:
|
||||
# Column not in sources, ensure it with NULL/TEXT type
|
||||
await self._ensure_column(table, col, None)
|
||||
continue
|
||||
|
||||
# Handle missing table
|
||||
tbl = self._missing_table_from_error(err)
|
||||
if tbl:
|
||||
await self._ensure_table(tbl, col_sources)
|
||||
continue
|
||||
|
||||
# Handle other column-related errors
|
||||
if "has no column named" in err_str:
|
||||
# Extract column name differently
|
||||
match = re.search(r"table \w+ has no column named (\w+)", err_str)
|
||||
if match:
|
||||
col_name = match.group(1)
|
||||
if col_name in col_sources:
|
||||
await self._ensure_column(table, col_name, col_sources[col_name])
|
||||
else:
|
||||
await self._ensure_column(table, col_name, None)
|
||||
continue
|
||||
|
||||
raise
|
||||
raise Exception(f"Max retries ({max_retries}) exceeded")
|
||||
|
||||
async def _safe_query(
|
||||
self,
|
||||
table: str,
|
||||
sql: str,
|
||||
params: Iterable[Any],
|
||||
col_sources: Dict[str, Any],
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
# Check if table exists first
|
||||
if not await self._table_exists(table):
|
||||
return
|
||||
|
||||
max_retries = 10
|
||||
retries = 0
|
||||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
async with self._db_semaphore:
|
||||
self._db_connection.row_factory = aiosqlite.Row
|
||||
async with self._db_connection.execute(sql, params) as cursor:
|
||||
# Fetch all rows while holding the semaphore
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
# Yield rows after releasing the semaphore
|
||||
for row in rows:
|
||||
yield dict(row)
|
||||
return
|
||||
except aiosqlite.OperationalError as err:
|
||||
retries += 1
|
||||
err_str = str(err).lower()
|
||||
|
||||
# Handle missing table
|
||||
tbl = self._missing_table_from_error(err)
|
||||
if tbl:
|
||||
# For queries, if table doesn't exist, just return empty
|
||||
return
|
||||
|
||||
# Handle missing column in WHERE clause or SELECT
|
||||
if "no such column" in err_str:
|
||||
# For queries with missing columns, return empty
|
||||
return
|
||||
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
|
||||
if not where:
|
||||
return "", []
|
||||
|
||||
clauses = []
|
||||
vals = []
|
||||
|
||||
op_map = {"$eq": "=", "$ne": "!=", "$gt": ">", "$gte": ">=", "$lt": "<", "$lte": "<="}
|
||||
|
||||
for k, v in where.items():
|
||||
if isinstance(v, dict):
|
||||
# Handle operators like {"$lt": 10}
|
||||
op_key = next(iter(v))
|
||||
if op_key in op_map:
|
||||
operator = op_map[op_key]
|
||||
value = v[op_key]
|
||||
clauses.append(f"`{k}` {operator} ?")
|
||||
vals.append(value)
|
||||
else:
|
||||
# Fallback for unexpected dict values
|
||||
clauses.append(f"`{k}` = ?")
|
||||
vals.append(json.dumps(v))
|
||||
else:
|
||||
# Default to equality
|
||||
clauses.append(f"`{k}` = ?")
|
||||
vals.append(v)
|
||||
|
||||
return " WHERE " + " AND ".join(clauses), vals
|
||||
|
||||
async def _server_insert(
|
||||
self, table: str, args: Dict[str, Any], return_id: bool = False
|
||||
) -> Union[str, int]:
|
||||
"""Insert a record. If return_id=True, returns auto-incremented ID instead of UUID."""
|
||||
uid = str(uuid4())
|
||||
now = self._utc_iso()
|
||||
record = {
|
||||
"uid": uid,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"deleted_at": None,
|
||||
**args,
|
||||
}
|
||||
|
||||
# Ensure table exists with all needed columns
|
||||
await self._ensure_table(table, record)
|
||||
|
||||
# Handle auto-increment ID if requested
|
||||
if return_id and "id" not in args:
|
||||
# Ensure id column exists
|
||||
async with self._db_semaphore:
|
||||
# Add id column if it doesn't exist
|
||||
try:
|
||||
await self._db_connection.execute(
|
||||
f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT"
|
||||
)
|
||||
await self._db_connection.commit()
|
||||
except aiosqlite.OperationalError as e:
|
||||
if "duplicate column name" not in str(e).lower():
|
||||
# Try without autoincrement constraint
|
||||
try:
|
||||
await self._db_connection.execute(
|
||||
f"ALTER TABLE {table} ADD COLUMN id INTEGER"
|
||||
)
|
||||
await self._db_connection.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
await self._invalidate_column_cache(table)
|
||||
|
||||
# Insert and get lastrowid
|
||||
cols = "`" + "`, `".join(record.keys()) + "`"
|
||||
qs = ", ".join(["?"] * len(record))
|
||||
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
|
||||
cursor = await self._safe_execute(table, sql, list(record.values()), record)
|
||||
return cursor.lastrowid
|
||||
|
||||
cols = "`" + "`, `".join(record) + "`"
|
||||
qs = ", ".join(["?"] * len(record))
|
||||
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
|
||||
await self._safe_execute(table, sql, list(record.values()), record)
|
||||
return uid
|
||||
|
||||
async def _server_update(
|
||||
self,
|
||||
table: str,
|
||||
args: Dict[str, Any],
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
) -> int:
|
||||
if not args:
|
||||
return 0
|
||||
|
||||
# Check if table exists
|
||||
if not await self._table_exists(table):
|
||||
return 0
|
||||
|
||||
args["updated_at"] = self._utc_iso()
|
||||
|
||||
# Ensure all columns exist
|
||||
all_cols = {**args, **(where or {})}
|
||||
await self._ensure_table(table, all_cols)
|
||||
for col, val in all_cols.items():
|
||||
await self._ensure_column(table, col, val)
|
||||
|
||||
set_clause = ", ".join(f"`{k}` = ?" for k in args)
|
||||
where_clause, where_params = self._build_where(where)
|
||||
sql = f"UPDATE {table} SET {set_clause}{where_clause}"
|
||||
params = list(args.values()) + where_params
|
||||
cur = await self._safe_execute(table, sql, params, all_cols)
|
||||
return cur.rowcount
|
||||
|
||||
async def _server_delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||
# Check if table exists
|
||||
if not await self._table_exists(table):
|
||||
return 0
|
||||
|
||||
where_clause, where_params = self._build_where(where)
|
||||
sql = f"DELETE FROM {table}{where_clause}"
|
||||
cur = await self._safe_execute(table, sql, where_params, where or {})
|
||||
return cur.rowcount
|
||||
|
||||
async def _server_upsert(
|
||||
self,
|
||||
table: str,
|
||||
args: Dict[str, Any],
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
) -> str | None:
|
||||
if not args:
|
||||
raise ValueError("Nothing to update. Empty dict given.")
|
||||
args["updated_at"] = self._utc_iso()
|
||||
affected = await self._server_update(table, args, where)
|
||||
if affected:
|
||||
rec = await self._server_get(table, where)
|
||||
return rec.get("uid") if rec else None
|
||||
merged = {**(where or {}), **args}
|
||||
return await self._server_insert(table, merged)
|
||||
|
||||
async def _server_get(
|
||||
self, table: str, where: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
where_clause, where_params = self._build_where(where)
|
||||
sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
|
||||
async for row in self._safe_query(table, sql, where_params, where or {}):
|
||||
return row
|
||||
return None
|
||||
|
||||
async def _server_find(
|
||||
self,
|
||||
table: str,
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
limit: int = 0,
|
||||
offset: int = 0,
|
||||
order_by: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Find records with optional ordering."""
|
||||
where_clause, where_params = self._build_where(where)
|
||||
order_clause = f" ORDER BY {order_by}" if order_by else ""
|
||||
extra = (f" LIMIT {limit}" if limit else "") + (f" OFFSET {offset}" if offset else "")
|
||||
sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}"
|
||||
return [row async for row in self._safe_query(table, sql, where_params, where or {})]
|
||||
|
||||
async def _server_count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
|
||||
# Check if table exists
|
||||
if not await self._table_exists(table):
|
||||
return 0
|
||||
|
||||
where_clause, where_params = self._build_where(where)
|
||||
sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
|
||||
gen = self._safe_query(table, sql, where_params, where or {})
|
||||
async for row in gen:
|
||||
return next(iter(row.values()), 0)
|
||||
return 0
|
||||
|
||||
async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool:
|
||||
return (await self._server_count(table, where)) > 0
|
||||
|
||||
async def _server_kv_set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
*,
|
||||
table: str | None = None,
|
||||
) -> None:
|
||||
tbl = table or self._KV_TABLE
|
||||
json_val = json.dumps(value, default=str)
|
||||
await self._server_upsert(tbl, {"value": json_val}, {"key": key})
|
||||
|
||||
async def _server_kv_get(
|
||||
self,
|
||||
key: str,
|
||||
*,
|
||||
default: Any = None,
|
||||
table: str | None = None,
|
||||
) -> Any:
|
||||
tbl = table or self._KV_TABLE
|
||||
row = await self._server_get(tbl, {"key": key})
|
||||
if not row:
|
||||
return default
|
||||
try:
|
||||
return json.loads(row["value"])
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
async def _server_execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
|
||||
"""Execute raw SQL for complex queries like JOINs."""
|
||||
async with self._db_semaphore:
|
||||
cursor = await self._db_connection.execute(sql, params or ())
|
||||
await self._db_connection.commit()
|
||||
return cursor
|
||||
|
||||
async def _server_query_raw(
|
||||
self, sql: str, params: Optional[Tuple] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Execute raw SQL query and return results as list of dicts."""
|
||||
try:
|
||||
async with self._db_semaphore:
|
||||
self._db_connection.row_factory = aiosqlite.Row
|
||||
async with self._db_connection.execute(sql, params or ()) as cursor:
|
||||
return [dict(row) async for row in cursor]
|
||||
except aiosqlite.OperationalError:
|
||||
# Return empty list if query fails
|
||||
return []
|
||||
|
||||
async def _server_query_one(
|
||||
self, sql: str, params: Optional[Tuple] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Execute raw SQL query and return single result."""
|
||||
results = await self._server_query_raw(sql + " LIMIT 1", params)
|
||||
return results[0] if results else None
|
||||
|
||||
async def _server_create_table(
|
||||
self,
|
||||
table: str,
|
||||
schema: Dict[str, str],
|
||||
constraints: Optional[List[str]] = None,
|
||||
):
|
||||
"""Create table with custom schema and constraints. Always includes default columns."""
|
||||
# Merge default columns with custom schema
|
||||
full_schema = self._DEFAULT_COLUMNS.copy()
|
||||
full_schema.update(schema)
|
||||
|
||||
columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
|
||||
if constraints:
|
||||
columns.extend(constraints)
|
||||
columns_sql = ", ".join(columns)
|
||||
|
||||
async with self._db_semaphore:
|
||||
await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
|
||||
await self._db_connection.commit()
|
||||
await self._invalidate_column_cache(table)
|
||||
|
||||
async def _server_insert_unique(
|
||||
self, table: str, args: Dict[str, Any], unique_fields: List[str]
|
||||
) -> Union[str, None]:
|
||||
"""Insert with unique constraint handling. Returns uid on success, None if duplicate."""
|
||||
try:
|
||||
return await self._server_insert(table, args)
|
||||
except aiosqlite.IntegrityError as e:
|
||||
if "UNIQUE" in str(e):
|
||||
return None
|
||||
raise
|
||||
|
||||
async def _server_aggregate(
|
||||
self,
|
||||
table: str,
|
||||
function: str,
|
||||
column: str = "*",
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
"""Perform aggregate functions like SUM, AVG, MAX, MIN."""
|
||||
# Check if table exists
|
||||
if not await self._table_exists(table):
|
||||
return None
|
||||
|
||||
where_clause, where_params = self._build_where(where)
|
||||
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
|
||||
result = await self._server_query_one(sql, tuple(where_params))
|
||||
return result["result"] if result else None
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection or server."""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
if self._client_session:
|
||||
await self._client_session.close()
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
if os.path.exists(self._socket_path) and self._is_server:
|
||||
os.unlink(self._socket_path)
|
||||
if self._db_connection:
|
||||
await self._db_connection.close()
|
||||
|
||||
|
||||
class TransactionContext:
|
||||
"""Context manager for database transactions."""
|
||||
|
||||
def __init__(self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None):
|
||||
self.db_connection = db_connection
|
||||
self.semaphore = semaphore
|
||||
|
||||
async def __aenter__(self):
|
||||
if self.semaphore:
|
||||
await self.semaphore.acquire()
|
||||
try:
|
||||
await self.db_connection.execute("BEGIN")
|
||||
return self.db_connection
|
||||
except:
|
||||
if self.semaphore:
|
||||
self.semaphore.release()
|
||||
raise
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
try:
|
||||
if exc_type is None:
|
||||
await self.db_connection.commit()
|
||||
else:
|
||||
await self.db_connection.rollback()
|
||||
finally:
|
||||
if self.semaphore:
|
||||
self.semaphore.release()
|
||||
|
||||
|
||||
# Test cases remain the same but with additional tests for new functionality
|
||||
@ -1,9 +1,13 @@
|
||||
import asyncio
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from pr.autonomous.detection import is_task_complete
|
||||
from pr.core.api import call_api
|
||||
from pr.core.context import truncate_tool_result
|
||||
from pr.tools.base import get_tools_definition
|
||||
from pr.ui import Colors, display_tool_call
|
||||
|
||||
logger = logging.getLogger("pr")
|
||||
@ -35,10 +39,15 @@ def run_autonomous_mode(assistant, task):
|
||||
|
||||
logger.debug(f"Messages after context management: {len(assistant.messages)}")
|
||||
|
||||
from pr.core.api import call_api
|
||||
from pr.tools.base import get_tools_definition
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an event loop, use run_coroutine_threadsafe
|
||||
import concurrent.futures
|
||||
|
||||
response = call_api(
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
call_api(
|
||||
assistant.messages,
|
||||
assistant.model,
|
||||
assistant.api_url,
|
||||
@ -46,6 +55,22 @@ def run_autonomous_mode(assistant, task):
|
||||
assistant.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=assistant.verbose,
|
||||
),
|
||||
loop,
|
||||
)
|
||||
response = future.result()
|
||||
except RuntimeError:
|
||||
# No event loop running, use asyncio.run
|
||||
response = asyncio.run(
|
||||
call_api(
|
||||
assistant.messages,
|
||||
assistant.model,
|
||||
assistant.api_url,
|
||||
assistant.api_key,
|
||||
assistant.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=assistant.verbose,
|
||||
)
|
||||
)
|
||||
|
||||
if "error" in response:
|
||||
@ -118,10 +143,16 @@ def process_response_autonomous(assistant, response):
|
||||
|
||||
for result in tool_results:
|
||||
assistant.messages.append(result)
|
||||
from pr.core.api import call_api
|
||||
from pr.tools.base import get_tools_definition
|
||||
|
||||
follow_up = call_api(
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an event loop, use run_coroutine_threadsafe
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
call_api(
|
||||
assistant.messages,
|
||||
assistant.model,
|
||||
assistant.api_url,
|
||||
@ -129,6 +160,22 @@ def process_response_autonomous(assistant, response):
|
||||
assistant.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=assistant.verbose,
|
||||
),
|
||||
loop,
|
||||
)
|
||||
follow_up = future.result()
|
||||
except RuntimeError:
|
||||
# No event loop running, use asyncio.run
|
||||
follow_up = asyncio.run(
|
||||
call_api(
|
||||
assistant.messages,
|
||||
assistant.model,
|
||||
assistant.api_url,
|
||||
assistant.api_key,
|
||||
assistant.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=assistant.verbose,
|
||||
)
|
||||
)
|
||||
return process_response_autonomous(assistant, follow_up)
|
||||
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from pr.commands.multiplexer_commands import MULTIPLEXER_COMMANDS
|
||||
@ -40,7 +42,18 @@ def handle_command(assistant, command):
|
||||
if prompt_text.strip():
|
||||
from pr.core.assistant import process_message
|
||||
|
||||
process_message(assistant, prompt_text)
|
||||
task = asyncio.create_task(process_message(assistant, prompt_text))
|
||||
assistant.background_tasks.add(task)
|
||||
task.add_done_callback(
|
||||
lambda t: (
|
||||
assistant.background_tasks.discard(t),
|
||||
(
|
||||
logging.error(f"Background task failed: {t.exception()}")
|
||||
if t.exception()
|
||||
else assistant.background_tasks.discard(t)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif cmd == "/auto":
|
||||
if len(command_parts) < 2:
|
||||
@ -201,7 +214,18 @@ def review_file(assistant, filename):
|
||||
message = f"Please review this file and provide feedback:\n\n{result['content']}"
|
||||
from pr.core.assistant import process_message
|
||||
|
||||
process_message(assistant, message)
|
||||
task = asyncio.create_task(process_message(assistant, message))
|
||||
assistant.background_tasks.add(task)
|
||||
task.add_done_callback(
|
||||
lambda t: (
|
||||
assistant.background_tasks.discard(t),
|
||||
(
|
||||
logging.error(f"Background task failed: {t.exception()}")
|
||||
if t.exception()
|
||||
else assistant.background_tasks.discard(t)
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
|
||||
|
||||
@ -212,7 +236,18 @@ def refactor_file(assistant, filename):
|
||||
message = f"Please refactor this code to improve its quality:\n\n{result['content']}"
|
||||
from pr.core.assistant import process_message
|
||||
|
||||
process_message(assistant, message)
|
||||
task = asyncio.create_task(process_message(assistant, message))
|
||||
assistant.background_tasks.add(task)
|
||||
task.add_done_callback(
|
||||
lambda t: (
|
||||
assistant.background_tasks.discard(t),
|
||||
(
|
||||
logging.error(f"Background task failed: {t.exception()}")
|
||||
if t.exception()
|
||||
else assistant.background_tasks.discard(t)
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
|
||||
|
||||
@ -223,7 +258,18 @@ def obfuscate_file(assistant, filename):
|
||||
message = f"Please obfuscate this code:\n\n{result['content']}"
|
||||
from pr.core.assistant import process_message
|
||||
|
||||
process_message(assistant, message)
|
||||
task = asyncio.create_task(process_message(assistant, message))
|
||||
assistant.background_tasks.add(task)
|
||||
task.add_done_callback(
|
||||
lambda t: (
|
||||
assistant.background_tasks.discard(t),
|
||||
(
|
||||
logging.error(f"Background task failed: {t.exception()}")
|
||||
if t.exception()
|
||||
else assistant.background_tasks.discard(t)
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
|
||||
|
||||
|
||||
@ -1,15 +1,14 @@
|
||||
import json
|
||||
import logging
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
from pr.config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE
|
||||
from pr.core.context import auto_slim_messages
|
||||
from pr.core.http_client import http_client
|
||||
|
||||
logger = logging.getLogger("pr")
|
||||
|
||||
|
||||
def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
|
||||
async def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
|
||||
try:
|
||||
messages = auto_slim_messages(messages, verbose=verbose)
|
||||
|
||||
@ -42,16 +41,26 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
|
||||
data["tool_choice"] = "auto"
|
||||
logger.debug(f"Tool calling enabled with {len(tools_definition)} tools")
|
||||
|
||||
request_json = json.dumps(data)
|
||||
request_json = data
|
||||
logger.debug(f"Request payload size: {len(request_json)} bytes")
|
||||
|
||||
req = urllib.request.Request(
|
||||
api_url, data=request_json.encode("utf-8"), headers=headers, method="POST"
|
||||
)
|
||||
|
||||
logger.debug("Sending HTTP request...")
|
||||
with urllib.request.urlopen(req) as response:
|
||||
response_data = response.read().decode("utf-8")
|
||||
response = await http_client.post(api_url, headers=headers, json_data=request_json)
|
||||
|
||||
if response.get("error"):
|
||||
if "status" in response:
|
||||
logger.error(f"API HTTP Error: {response['status']} - {response.get('text', '')}")
|
||||
logger.debug("=== API CALL FAILED ===")
|
||||
return {
|
||||
"error": f"API Error: {response['status']}",
|
||||
"message": response.get("text", ""),
|
||||
}
|
||||
else:
|
||||
logger.error(f"API call failed: {response.get('exception', 'Unknown error')}")
|
||||
logger.debug("=== API CALL FAILED ===")
|
||||
return {"error": response.get("exception", "Unknown error")}
|
||||
|
||||
response_data = response["text"]
|
||||
logger.debug(f"Response received: {len(response_data)} bytes")
|
||||
result = json.loads(response_data)
|
||||
|
||||
@ -73,32 +82,29 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
|
||||
usage = result["usage"]
|
||||
input_t = usage.get("prompt_tokens", 0)
|
||||
output_t = usage.get("completion_tokens", 0)
|
||||
cost = UsageTracker._calculate_cost(model, input_t, output_t)
|
||||
print(f"API call cost: €{cost:.4f}")
|
||||
UsageTracker._calculate_cost(model, input_t, output_t)
|
||||
|
||||
logger.debug("=== API CALL END ===")
|
||||
return result
|
||||
|
||||
except urllib.error.HTTPError as e:
|
||||
error_body = e.read().decode("utf-8")
|
||||
logger.error(f"API HTTP Error: {e.code} - {error_body}")
|
||||
logger.debug("=== API CALL FAILED ===")
|
||||
return {"error": f"API Error: {e.code}", "message": error_body}
|
||||
except Exception as e:
|
||||
logger.error(f"API call failed: {e}")
|
||||
logger.debug("=== API CALL FAILED ===")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
def list_models(model_list_url, api_key):
|
||||
async def list_models(model_list_url, api_key):
|
||||
try:
|
||||
req = urllib.request.Request(model_list_url)
|
||||
headers = {}
|
||||
if api_key:
|
||||
req.add_header("Authorization", f"Bearer {api_key}")
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
with urllib.request.urlopen(req) as response:
|
||||
data = json.loads(response.read().decode("utf-8"))
|
||||
response = await http_client.get(model_list_url, headers=headers)
|
||||
|
||||
if response.get("error"):
|
||||
return {"error": response.get("text", "HTTP error")}
|
||||
|
||||
data = json.loads(response["text"])
|
||||
return data.get("data", [])
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
@ -20,7 +20,7 @@ from pr.config import (
|
||||
)
|
||||
from pr.core.api import call_api
|
||||
from pr.core.autonomous_interactions import (
|
||||
get_global_autonomous,
|
||||
start_global_autonomous,
|
||||
stop_global_autonomous,
|
||||
)
|
||||
from pr.core.background_monitor import (
|
||||
@ -29,6 +29,7 @@ from pr.core.background_monitor import (
|
||||
stop_global_monitor,
|
||||
)
|
||||
from pr.core.context import init_system_message, truncate_tool_result
|
||||
from pr.core.usage_tracker import UsageTracker
|
||||
from pr.tools import get_tools_definition
|
||||
from pr.tools.agents import (
|
||||
collaborate_agents,
|
||||
@ -71,7 +72,7 @@ from pr.tools.memory import (
|
||||
from pr.tools.patch import apply_patch, create_diff, display_file_diff
|
||||
from pr.tools.python_exec import python_exec
|
||||
from pr.tools.web import http_fetch, web_search, web_search_news
|
||||
from pr.ui import Colors, render_markdown
|
||||
from pr.ui import Colors, Spinner, render_markdown
|
||||
|
||||
logger = logging.getLogger("pr")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@ -109,6 +110,8 @@ class Assistant:
|
||||
self.autonomous_mode = False
|
||||
self.autonomous_iterations = 0
|
||||
self.background_monitoring = False
|
||||
self.usage_tracker = UsageTracker()
|
||||
self.background_tasks = set()
|
||||
self.init_database()
|
||||
self.messages.append(init_system_message(args))
|
||||
|
||||
@ -125,8 +128,7 @@ class Assistant:
|
||||
# Initialize background monitoring components
|
||||
try:
|
||||
start_global_monitor()
|
||||
autonomous = get_global_autonomous()
|
||||
autonomous.start(llm_callback=self._handle_background_updates)
|
||||
start_global_autonomous(llm_callback=self._handle_background_updates)
|
||||
self.background_monitoring = True
|
||||
if self.debug:
|
||||
logger.debug("Background monitoring initialized")
|
||||
@ -236,7 +238,7 @@ class Assistant:
|
||||
if self.debug:
|
||||
print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}")
|
||||
|
||||
def execute_tool_calls(self, tool_calls):
|
||||
async def execute_tool_calls(self, tool_calls):
|
||||
results = []
|
||||
|
||||
logger.debug(f"Executing {len(tool_calls)} tool call(s)")
|
||||
@ -337,7 +339,7 @@ class Assistant:
|
||||
|
||||
return results
|
||||
|
||||
def process_response(self, response):
|
||||
async def process_response(self, response):
|
||||
if "error" in response:
|
||||
return f"Error: {response['error']}"
|
||||
|
||||
@ -348,15 +350,17 @@ class Assistant:
|
||||
self.messages.append(message)
|
||||
|
||||
if "tool_calls" in message and message["tool_calls"]:
|
||||
if self.verbose:
|
||||
print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}")
|
||||
tool_count = len(message["tool_calls"])
|
||||
print(f"{Colors.BLUE}🔧 Executing {tool_count} tool call(s)...{Colors.RESET}")
|
||||
|
||||
tool_results = self.execute_tool_calls(message["tool_calls"])
|
||||
tool_results = await self.execute_tool_calls(message["tool_calls"])
|
||||
|
||||
print(f"{Colors.GREEN}✅ Tool execution completed.{Colors.RESET}")
|
||||
|
||||
for result in tool_results:
|
||||
self.messages.append(result)
|
||||
|
||||
follow_up = call_api(
|
||||
follow_up = await call_api(
|
||||
self.messages,
|
||||
self.model,
|
||||
self.api_url,
|
||||
@ -365,7 +369,7 @@ class Assistant:
|
||||
get_tools_definition(),
|
||||
verbose=self.verbose,
|
||||
)
|
||||
return self.process_response(follow_up)
|
||||
return await self.process_response(follow_up)
|
||||
|
||||
content = message.get("content", "")
|
||||
return render_markdown(content, self.syntax_highlighting)
|
||||
@ -439,12 +443,23 @@ class Assistant:
|
||||
readline.set_completer(completer)
|
||||
readline.parse_and_bind("tab: complete")
|
||||
|
||||
def run_repl(self):
|
||||
async def run_repl(self):
|
||||
self.setup_readline()
|
||||
signal.signal(signal.SIGINT, self.signal_handler)
|
||||
|
||||
print(f"{Colors.BOLD}r{Colors.RESET}")
|
||||
print(f"Type 'help' for commands or start chatting")
|
||||
print(
|
||||
f"{Colors.BOLD}{Colors.CYAN}╔══════════════════════════════════════════════╗{Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f"{Colors.BOLD}{Colors.CYAN}║{Colors.RESET}{Colors.BOLD} PR Assistant v{__import__('pr').__version__} {Colors.RESET}{Colors.BOLD}{Colors.CYAN}║{Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f"{Colors.BOLD}{Colors.CYAN}╚══════════════════════════════════════════════╝{Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f"{Colors.GRAY}Type 'help' for commands, 'exit' to quit, or start chatting.{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.GRAY}AI calls will show costs and progress indicators.{Colors.RESET}\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -480,7 +495,7 @@ class Assistant:
|
||||
elif cmd_result is True:
|
||||
continue
|
||||
|
||||
process_message(self, user_input)
|
||||
await process_message(self, user_input)
|
||||
|
||||
except EOFError:
|
||||
break
|
||||
@ -490,7 +505,7 @@ class Assistant:
|
||||
print(f"{Colors.RED}Error: {e}{Colors.RESET}")
|
||||
logging.error(f"REPL error: {e}\n{traceback.format_exc()}")
|
||||
|
||||
def run_single(self):
|
||||
async def run_single(self):
|
||||
if self.args.message:
|
||||
message = self.args.message
|
||||
else:
|
||||
@ -498,7 +513,7 @@ class Assistant:
|
||||
|
||||
from pr.autonomous.mode import run_autonomous_mode
|
||||
|
||||
run_autonomous_mode(self, message)
|
||||
await run_autonomous_mode(self, message)
|
||||
|
||||
def cleanup(self):
|
||||
if hasattr(self, "enhanced") and self.enhanced:
|
||||
@ -525,22 +540,22 @@ class Assistant:
|
||||
if self.db_conn:
|
||||
self.db_conn.close()
|
||||
|
||||
def run(self):
|
||||
async def run(self):
|
||||
try:
|
||||
print(
|
||||
f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}"
|
||||
)
|
||||
if self.args.interactive or (not self.args.message and sys.stdin.isatty()):
|
||||
print("DEBUG: calling run_repl")
|
||||
self.run_repl()
|
||||
await self.run_repl()
|
||||
else:
|
||||
print("DEBUG: calling run_single")
|
||||
self.run_single()
|
||||
await self.run_single()
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
|
||||
def process_message(assistant, message):
|
||||
async def process_message(assistant, message):
|
||||
from pr.core.knowledge_context import inject_knowledge_context
|
||||
|
||||
inject_knowledge_context(assistant, message)
|
||||
@ -550,10 +565,11 @@ def process_message(assistant, message):
|
||||
logger.debug(f"Processing user message: {message[:100]}...")
|
||||
logger.debug(f"Current message count: {len(assistant.messages)}")
|
||||
|
||||
if assistant.verbose:
|
||||
print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}")
|
||||
# Start spinner for AI call
|
||||
spinner = Spinner("Querying AI...")
|
||||
await spinner.start()
|
||||
|
||||
response = call_api(
|
||||
response = await call_api(
|
||||
assistant.messages,
|
||||
assistant.model,
|
||||
assistant.api_url,
|
||||
@ -562,6 +578,19 @@ def process_message(assistant, message):
|
||||
get_tools_definition(),
|
||||
verbose=assistant.verbose,
|
||||
)
|
||||
result = assistant.process_response(response)
|
||||
|
||||
await spinner.stop()
|
||||
|
||||
# Track usage and display cost
|
||||
if "usage" in response:
|
||||
usage = response["usage"]
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
assistant.usage_tracker.track_request(assistant.model, input_tokens, output_tokens)
|
||||
cost = UsageTracker._calculate_cost(assistant.model, input_tokens, output_tokens)
|
||||
total_cost = assistant.usage_tracker.session_usage["estimated_cost"]
|
||||
print(f"{Colors.YELLOW}💰 Cost: ${cost:.4f} | Total: ${total_cost:.4f}{Colors.RESET}")
|
||||
|
||||
result = await assistant.process_response(response)
|
||||
|
||||
print(f"\n{Colors.GREEN}r:{Colors.RESET} {result}\n")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
@ -128,7 +129,15 @@ class EnhancedAssistant:
|
||||
def _api_caller_for_agent(
|
||||
self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int
|
||||
) -> Dict[str, Any]:
|
||||
return call_api(
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an event loop, use run_coroutine_threadsafe
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
call_api(
|
||||
messages,
|
||||
self.base.model,
|
||||
self.base.api_url,
|
||||
@ -136,6 +145,22 @@ class EnhancedAssistant:
|
||||
use_tools=False,
|
||||
tools_definition=[],
|
||||
verbose=self.base.verbose,
|
||||
),
|
||||
loop,
|
||||
)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
# No event loop running, use asyncio.run
|
||||
return asyncio.run(
|
||||
call_api(
|
||||
messages,
|
||||
self.base.model,
|
||||
self.base.api_url,
|
||||
self.base.api_key,
|
||||
use_tools=False,
|
||||
tools_definition=[],
|
||||
verbose=self.base.verbose,
|
||||
)
|
||||
)
|
||||
|
||||
def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
@ -145,7 +170,15 @@ class EnhancedAssistant:
|
||||
logger.debug("API cache hit")
|
||||
return cached_response
|
||||
|
||||
response = call_api(
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an event loop, use run_coroutine_threadsafe
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
call_api(
|
||||
messages,
|
||||
self.base.model,
|
||||
self.base.api_url,
|
||||
@ -153,6 +186,22 @@ class EnhancedAssistant:
|
||||
self.base.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=self.base.verbose,
|
||||
),
|
||||
loop,
|
||||
)
|
||||
response = future.result()
|
||||
except RuntimeError:
|
||||
# No event loop running, use asyncio.run
|
||||
response = asyncio.run(
|
||||
call_api(
|
||||
messages,
|
||||
self.base.model,
|
||||
self.base.api_url,
|
||||
self.base.api_key,
|
||||
self.base.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=self.base.verbose,
|
||||
)
|
||||
)
|
||||
|
||||
if self.api_cache and CACHE_ENABLED and "error" not in response:
|
||||
|
||||
136
pr/core/http_client.py
Normal file
136
pr/core/http_client.py
Normal file
@ -0,0 +1,136 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger("pr")
|
||||
|
||||
|
||||
class AsyncHTTPClient:
|
||||
def __init__(self):
|
||||
self.session_headers = {}
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
data: Optional[bytes] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
timeout: float = 30.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an async HTTP request using urllib in a thread executor with retry logic."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Prepare headers
|
||||
request_headers = {**self.session_headers}
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
# Prepare data
|
||||
request_data = data
|
||||
if json_data is not None:
|
||||
request_data = json.dumps(json_data).encode("utf-8")
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
# Create request object
|
||||
req = urllib.request.Request(url, data=request_data, headers=request_headers, method=method)
|
||||
|
||||
attempt = 0
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
attempt += 1
|
||||
try:
|
||||
# Execute in thread pool
|
||||
response = await loop.run_in_executor(
|
||||
None, lambda: urllib.request.urlopen(req, timeout=timeout)
|
||||
)
|
||||
response_data = await loop.run_in_executor(None, response.read)
|
||||
response_text = response_data.decode("utf-8")
|
||||
|
||||
return {
|
||||
"status": response.status,
|
||||
"headers": dict(response.headers),
|
||||
"text": response_text,
|
||||
"json": lambda: json.loads(response_text) if response_text else None,
|
||||
}
|
||||
|
||||
except urllib.error.HTTPError as e:
|
||||
error_body = await loop.run_in_executor(None, e.read)
|
||||
error_text = error_body.decode("utf-8")
|
||||
return {
|
||||
"status": e.code,
|
||||
"error": True,
|
||||
"text": error_text,
|
||||
"json": lambda: json.loads(error_text) if error_text else None,
|
||||
}
|
||||
except socket.timeout:
|
||||
# Handle socket timeouts specifically
|
||||
elapsed = time.time() - start_time
|
||||
elapsed_minutes = int(elapsed // 60)
|
||||
elapsed_seconds = elapsed % 60
|
||||
duration_str = (
|
||||
f"{elapsed_minutes}m {elapsed_seconds:.1f}s"
|
||||
if elapsed_minutes > 0
|
||||
else f"{elapsed_seconds:.1f}s"
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Request timed out (attempt {attempt}, "
|
||||
f"duration: {duration_str}). Retrying in {attempt} second(s)..."
|
||||
)
|
||||
|
||||
# Exponential backoff starting at 1 second
|
||||
await asyncio.sleep(attempt)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# For other exceptions, check if they might be timeout-related
|
||||
if "timed out" in error_msg.lower() or "timeout" in error_msg.lower():
|
||||
elapsed = time.time() - start_time
|
||||
elapsed_minutes = int(elapsed // 60)
|
||||
elapsed_seconds = elapsed % 60
|
||||
duration_str = (
|
||||
f"{elapsed_minutes}m {elapsed_seconds:.1f}s"
|
||||
if elapsed_minutes > 0
|
||||
else f"{elapsed_seconds:.1f}s"
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Request timed out (attempt {attempt}, "
|
||||
f"duration: {duration_str}). Retrying in {attempt} second(s)..."
|
||||
)
|
||||
|
||||
# Exponential backoff starting at 1 second
|
||||
await asyncio.sleep(attempt)
|
||||
else:
|
||||
# Non-timeout errors should not be retried
|
||||
return {"error": True, "exception": error_msg}
|
||||
|
||||
async def get(
|
||||
self, url: str, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0
|
||||
) -> Dict[str, Any]:
|
||||
return await self.request("GET", url, headers=headers, timeout=timeout)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
data: Optional[bytes] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
timeout: float = 30.0,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.request(
|
||||
"POST", url, headers=headers, data=data, json_data=json_data, timeout=timeout
|
||||
)
|
||||
|
||||
def set_default_headers(self, headers: Dict[str, str]):
|
||||
self.session_headers.update(headers)
|
||||
|
||||
|
||||
# Global client instance
|
||||
http_client = AsyncHTTPClient()
|
||||
585
pr/server.py
585
pr/server.py
@ -1,585 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
import pathlib
|
||||
import ast
|
||||
from types import SimpleNamespace
|
||||
import time
|
||||
from aiohttp import web
|
||||
import aiohttp
|
||||
|
||||
# --- Optional external dependency used in your original code ---
|
||||
from pr.ads import AsyncDataSet # noqa: E402
|
||||
|
||||
import sys
|
||||
from http import cookies
|
||||
import urllib.parse
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import urllib.request
|
||||
import pathlib
|
||||
import pickle
|
||||
import zlib
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import zlib
|
||||
|
||||
|
||||
class Cache:
|
||||
def __init__(self, base_dir: Path | str = "."):
|
||||
self.base_dir = Path(base_dir).resolve()
|
||||
self.base_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def is_cacheable(self, obj):
|
||||
return isinstance(obj, (int, str, bool, float))
|
||||
|
||||
def generate_key(self, *args, **kwargs):
|
||||
|
||||
return zlib.crc32(
|
||||
json.dumps(
|
||||
{
|
||||
"args": [arg for arg in args if self.is_cacheable(arg)],
|
||||
"kwargs": {k: v for k, v in kwargs.items() if self.is_cacheable(v)},
|
||||
},
|
||||
sort_keys=True,
|
||||
default=str,
|
||||
).encode()
|
||||
)
|
||||
|
||||
def set(self, key, value):
|
||||
key = self.generate_key(key)
|
||||
data = {"value": value, "timestamp": time.time()}
|
||||
serialized_data = pickle.dumps(data)
|
||||
with open(self.base_dir.joinpath(f"{key}.cache"), "wb") as f:
|
||||
f.write(serialized_data)
|
||||
|
||||
def get(self, key, default=None):
|
||||
key = self.generate_key(key)
|
||||
try:
|
||||
with open(self.base_dir.joinpath(f"{key}.cache"), "rb") as f:
|
||||
data = pickle.loads(f.read())
|
||||
return data
|
||||
except FileNotFoundError:
|
||||
return default
|
||||
|
||||
def is_cacheable(self, obj):
|
||||
return isinstance(obj, (int, str, bool, float))
|
||||
|
||||
def cached(self, expiration=60):
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# if not all(self.is_cacheable(arg) for arg in args) or not all(self.is_cacheable(v) for v in kwargs.values()):
|
||||
# return await func(*args, **kwargs)
|
||||
key = self.generate_key(*args, **kwargs)
|
||||
print("Cache hit:", key)
|
||||
cached_data = self.get(key)
|
||||
if cached_data is not None:
|
||||
if expiration is None or (time.time() - cached_data["timestamp"]) < expiration:
|
||||
return cached_data["value"]
|
||||
result = await func(*args, **kwargs)
|
||||
self.set(key, result)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
cache = Cache()
|
||||
|
||||
|
||||
class CGI:
|
||||
|
||||
def __init__(self):
|
||||
|
||||
self.environ = os.environ
|
||||
if not self.gateway_interface == "CGI/1.1":
|
||||
return
|
||||
self.status = 200
|
||||
self.headers = {}
|
||||
self.cache = Cache(pathlib.Path(__file__).parent.joinpath("cache/cgi"))
|
||||
self.headers["Content-Length"] = "0"
|
||||
self.headers["Cache-Control"] = "no-cache"
|
||||
self.headers["Access-Control-Allow-Origin"] = "*"
|
||||
self.headers["Content-Type"] = "application/json; charset=utf-8"
|
||||
self.cookies = cookies.SimpleCookie()
|
||||
|
||||
self.query = urllib.parse.parse_qs(os.environ["QUERY_STRING"])
|
||||
self.file = io.BytesIO()
|
||||
|
||||
def validate(self, val, fn, error):
|
||||
if fn(val):
|
||||
return True
|
||||
self.status = 421
|
||||
self.write({"type": "error", "message": error})
|
||||
exit()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.get(key)
|
||||
|
||||
def get(self, key, default=None):
|
||||
result = self.query.get(key, [default])
|
||||
if len(result) == 1:
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
def get_bool(self, key):
|
||||
return str(self.get(key)).lower() in ["true", "yes", "1"]
|
||||
|
||||
@property
|
||||
def cache_key(self):
|
||||
return self.environ["CACHE_KEY"]
|
||||
|
||||
@property
|
||||
def env(self):
|
||||
env = os.environ.copy()
|
||||
return env
|
||||
|
||||
@property
|
||||
def gateway_interface(self):
|
||||
return self.env.get("GATEWAY_INTERFACE", "")
|
||||
|
||||
@property
|
||||
def request_method(self):
|
||||
return self.env["REQUEST_METHOD"]
|
||||
|
||||
@property
|
||||
def query_string(self):
|
||||
return self.env["QUERY_STRING"]
|
||||
|
||||
@property
|
||||
def script_name(self):
|
||||
return self.env["SCRIPT_NAME"]
|
||||
|
||||
@property
|
||||
def path_info(self):
|
||||
return self.env["PATH_INFO"]
|
||||
|
||||
@property
|
||||
def server_name(self):
|
||||
return self.env["SERVER_NAME"]
|
||||
|
||||
@property
|
||||
def server_port(self):
|
||||
return self.env["SERVER_PORT"]
|
||||
|
||||
@property
|
||||
def server_protocol(self):
|
||||
return self.env["SERVER_PROTOCOL"]
|
||||
|
||||
@property
|
||||
def remote_addr(self):
|
||||
return self.env["REMOTE_ADDR"]
|
||||
|
||||
def validate_get(self, key):
|
||||
self.validate(self.get(key), lambda x: bool(x), f"Missing {key}")
|
||||
|
||||
def print(self, data):
|
||||
self.write(data)
|
||||
|
||||
def write(self, data):
|
||||
if not isinstance(data, bytes) and not isinstance(data, str):
|
||||
data = json.dumps(data, default=str, indent=4)
|
||||
try:
|
||||
data = data.encode()
|
||||
except:
|
||||
pass
|
||||
self.file.write(data)
|
||||
self.headers["Content-Length"] = str(len(self.file.getvalue()))
|
||||
|
||||
@property
|
||||
def http_status(self):
|
||||
return f"HTTP/1.1 {self.status}\r\n".encode("utf-8")
|
||||
|
||||
@property
|
||||
def http_headers(self):
|
||||
headers = io.BytesIO()
|
||||
for header in self.headers:
|
||||
headers.write(f"{header}: {self.headers[header]}\r\n".encode("utf-8"))
|
||||
headers.write(b"\r\n")
|
||||
return headers.getvalue()
|
||||
|
||||
@property
|
||||
def http_body(self):
|
||||
return self.file.getvalue()
|
||||
|
||||
@property
|
||||
def http_response(self):
|
||||
return self.http_status + self.http_headers + self.http_body
|
||||
|
||||
def flush(self, response=None):
|
||||
|
||||
if response:
|
||||
try:
|
||||
response = response.encode()
|
||||
except:
|
||||
pass
|
||||
sys.stdout.buffer.write(response)
|
||||
sys.stdout.buffer.flush()
|
||||
return
|
||||
sys.stdout.buffer.write(self.http_response)
|
||||
sys.stdout.buffer.flush()
|
||||
|
||||
def __del__(self):
|
||||
if self.http_body:
|
||||
self.flush()
|
||||
exit()
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Utilities
|
||||
# -------------------------------
|
||||
def get_function_source(name: str, directory: str = ".") -> List[Tuple[str, str]]:
|
||||
matches: List[Tuple[str, str]] = []
|
||||
for root, _, files in os.walk(directory):
|
||||
for file in files:
|
||||
if not file.endswith(".py"):
|
||||
continue
|
||||
path = os.path.join(root, file)
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as fh:
|
||||
source = fh.read()
|
||||
tree = ast.parse(source, filename=path)
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.AsyncFunctionDef) and node.name == name:
|
||||
func_src = ast.get_source_segment(source, node)
|
||||
if func_src:
|
||||
matches.append((path, func_src))
|
||||
break
|
||||
except (SyntaxError, UnicodeDecodeError):
|
||||
continue
|
||||
return matches
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Server
|
||||
# -------------------------------
|
||||
class Static(SimpleNamespace):
|
||||
pass
|
||||
|
||||
|
||||
class View(web.View):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.db = self.request.app["db"]
|
||||
self.server = self.request.app["server"]
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
return aiohttp.ClientSession()
|
||||
|
||||
|
||||
class RetoorServer:
|
||||
def __init__(self, base_dir: Path | str = ".", port: int = 8118) -> None:
|
||||
self.base_dir = Path(base_dir).resolve()
|
||||
self.port = port
|
||||
|
||||
self.static = Static()
|
||||
self.db = AsyncDataSet(".default.db")
|
||||
|
||||
self._logger = logging.getLogger("retoor.server")
|
||||
self._logger.setLevel(logging.INFO)
|
||||
if not self._logger.handlers:
|
||||
h = logging.StreamHandler(sys.stdout)
|
||||
fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
|
||||
h.setFormatter(fmt)
|
||||
self._logger.addHandler(h)
|
||||
|
||||
self._func_cache: Dict[str, Tuple[str, str]] = {}
|
||||
self._compiled_cache: Dict[str, object] = {}
|
||||
self._cgi_cache: Dict[str, Tuple[Optional[Path], Optional[str], Optional[str]]] = {}
|
||||
self._static_path_cache: Dict[str, Path] = {}
|
||||
self._base_dir_str = str(self.base_dir)
|
||||
|
||||
self.app = web.Application()
|
||||
|
||||
app = self.app
|
||||
|
||||
app["db"] = self.db
|
||||
for path in pathlib.Path(__file__).parent.joinpath("web").joinpath("views").iterdir():
|
||||
if path.suffix == ".py":
|
||||
exec(open(path).read(), globals(), locals())
|
||||
|
||||
self.app["server"] = self
|
||||
|
||||
# from rpanel import create_app as create_rpanel_app
|
||||
# self.app.add_subapp("/api/{tail:.*}", create_rpanel_app())
|
||||
|
||||
self.app.router.add_route("*", "/{tail:.*}", self.handle_any)
|
||||
|
||||
# ---------------------------
|
||||
# Simple endpoints
|
||||
# ---------------------------
|
||||
async def handle_root(self, request: web.Request) -> web.Response:
|
||||
return web.Response(text="Static file and cgi server from retoor.")
|
||||
|
||||
async def handle_hello(self, request: web.Request) -> web.Response:
|
||||
return web.Response(text="Welcome to the custom HTTP server!")
|
||||
|
||||
# ---------------------------
|
||||
# Dynamic function dispatch
|
||||
# ---------------------------
|
||||
def _path_to_funcname(self, path: str) -> str:
|
||||
# /foo/bar -> foo_bar ; strip leading/trailing slashes safely
|
||||
return path.strip("/").replace("/", "_")
|
||||
|
||||
async def _maybe_dispatch_dynamic(self, request: web.Request) -> Optional[web.StreamResponse]:
|
||||
relpath = request.path.strip("/")
|
||||
relpath = str(pathlib.Path(__file__).joinpath("cgi").joinpath(relpath).resolve()).replace(
|
||||
"..", ""
|
||||
)
|
||||
if not relpath:
|
||||
return None
|
||||
|
||||
last_part = relpath.split("/")[-1]
|
||||
if "." in last_part:
|
||||
return None
|
||||
|
||||
funcname = self._path_to_funcname(request.path)
|
||||
|
||||
if funcname not in self._compiled_cache:
|
||||
entry = self._func_cache.get(funcname)
|
||||
if entry is None:
|
||||
matches = get_function_source(funcname, directory=str(self.base_dir))
|
||||
if not matches:
|
||||
return None
|
||||
entry = matches[0]
|
||||
self._func_cache[funcname] = entry
|
||||
|
||||
filepath, source = entry
|
||||
code_obj = compile(source, filepath, "exec")
|
||||
self._compiled_cache[funcname] = (code_obj, filepath)
|
||||
|
||||
code_obj, filepath = self._compiled_cache[funcname]
|
||||
ctx = {
|
||||
"static": self.static,
|
||||
"request": request,
|
||||
"relpath": relpath,
|
||||
"os": os,
|
||||
"sys": sys,
|
||||
"web": web,
|
||||
"asyncio": asyncio,
|
||||
"subprocess": subprocess,
|
||||
"urllib": urllib.parse,
|
||||
"app": request.app,
|
||||
"db": self.db,
|
||||
}
|
||||
exec(code_obj, ctx, ctx)
|
||||
coro = ctx.get(funcname)
|
||||
if not callable(coro):
|
||||
return web.Response(status=500, text=f"Dynamic handler '{funcname}' is not callable.")
|
||||
return await coro(request)
|
||||
|
||||
# ---------------------------
|
||||
# Static files
|
||||
# ---------------------------
|
||||
async def handle_static(self, request: web.Request) -> web.StreamResponse:
|
||||
path = request.path
|
||||
|
||||
if path in self._static_path_cache:
|
||||
cached_path = self._static_path_cache[path]
|
||||
if cached_path is None:
|
||||
return web.Response(status=404, text="Not found")
|
||||
try:
|
||||
return web.FileResponse(path=cached_path)
|
||||
except Exception as e:
|
||||
return web.Response(status=500, text=f"Failed to send file: {e!r}")
|
||||
|
||||
relpath = path.replace('..",', "").lstrip("/").rstrip("/")
|
||||
abspath = (self.base_dir / relpath).resolve()
|
||||
|
||||
if not str(abspath).startswith(self._base_dir_str):
|
||||
return web.Response(status=403, text="Forbidden")
|
||||
|
||||
if not abspath.exists():
|
||||
self._static_path_cache[path] = None
|
||||
return web.Response(status=404, text="Not found")
|
||||
|
||||
if abspath.is_dir():
|
||||
return web.Response(status=403, text="Directory listing forbidden")
|
||||
|
||||
self._static_path_cache[path] = abspath
|
||||
try:
|
||||
return web.FileResponse(path=abspath)
|
||||
except Exception as e:
|
||||
return web.Response(status=500, text=f"Failed to send file: {e!r}")
|
||||
|
||||
# ---------------------------
|
||||
# CGI
|
||||
# ---------------------------
|
||||
def _find_cgi_script(
|
||||
self, path_only: str
|
||||
) -> Tuple[Optional[Path], Optional[str], Optional[str]]:
|
||||
path_only = "pr/cgi/" + path_only
|
||||
if path_only in self._cgi_cache:
|
||||
return self._cgi_cache[path_only]
|
||||
|
||||
split_path = [p for p in path_only.split("/") if p]
|
||||
for i in range(len(split_path), -1, -1):
|
||||
candidate = "/" + "/".join(split_path[:i])
|
||||
candidate_fs = (self.base_dir / candidate.lstrip("/")).resolve()
|
||||
if (
|
||||
str(candidate_fs).startswith(self._base_dir_str)
|
||||
and candidate_fs.parent.is_dir()
|
||||
and candidate_fs.parent.name == "cgi"
|
||||
and candidate_fs.suffix == ".bin"
|
||||
and candidate_fs.is_file()
|
||||
and os.access(candidate_fs, os.X_OK)
|
||||
):
|
||||
script_name = candidate if candidate.startswith("/") else "/" + candidate
|
||||
path_info = path_only[len(script_name) :]
|
||||
result = (candidate_fs, script_name, path_info)
|
||||
self._cgi_cache[path_only] = result
|
||||
return result
|
||||
|
||||
result = (None, None, None)
|
||||
self._cgi_cache[path_only] = result
|
||||
return result
|
||||
|
||||
async def _run_cgi(
|
||||
self, request: web.Request, script_path: Path, script_name: str, path_info: str
|
||||
) -> web.Response:
|
||||
start_time = time.time()
|
||||
method = request.method
|
||||
query = request.query_string
|
||||
|
||||
env = os.environ.copy()
|
||||
env["CACHE_KEY"] = json.dumps(
|
||||
{"method": method, "query": query, "script_name": script_name, "path_info": path_info},
|
||||
default=str,
|
||||
)
|
||||
env["GATEWAY_INTERFACE"] = "CGI/1.1"
|
||||
env["REQUEST_METHOD"] = method
|
||||
env["QUERY_STRING"] = query
|
||||
env["SCRIPT_NAME"] = script_name
|
||||
env["PATH_INFO"] = path_info
|
||||
|
||||
host_parts = request.host.split(":", 1)
|
||||
env["SERVER_NAME"] = host_parts[0]
|
||||
env["SERVER_PORT"] = host_parts[1] if len(host_parts) > 1 else "80"
|
||||
env["SERVER_PROTOCOL"] = "HTTP/1.1"
|
||||
|
||||
peername = request.transport.get_extra_info("peername")
|
||||
env["REMOTE_ADDR"] = request.headers.get("HOST_X_FORWARDED_FOR", peername[0])
|
||||
|
||||
for hk, hv in request.headers.items():
|
||||
hk_lower = hk.lower()
|
||||
if hk_lower == "content-type":
|
||||
env["CONTENT_TYPE"] = hv
|
||||
elif hk_lower == "content-length":
|
||||
env["CONTENT_LENGTH"] = hv
|
||||
else:
|
||||
env["HTTP_" + hk.upper().replace("-", "_")] = hv
|
||||
|
||||
post_data = None
|
||||
if method in {"POST", "PUT", "PATCH"}:
|
||||
post_data = await request.read()
|
||||
env.setdefault("CONTENT_LENGTH", str(len(post_data)))
|
||||
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
str(script_path),
|
||||
stdin=asyncio.subprocess.PIPE if post_data else None,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
stdout, stderr = await proc.communicate(input=post_data)
|
||||
|
||||
if proc.returncode != 0:
|
||||
msg = stderr.decode(errors="ignore")
|
||||
return web.Response(status=500, text=f"CGI script error:\n{msg}")
|
||||
|
||||
header_end = stdout.find(b"\r\n\r\n")
|
||||
if header_end != -1:
|
||||
offset = 4
|
||||
else:
|
||||
header_end = stdout.find(b"\n\n")
|
||||
offset = 2
|
||||
|
||||
if header_end != -1:
|
||||
body = stdout[header_end + offset :]
|
||||
headers_blob = stdout[:header_end].decode(errors="ignore")
|
||||
|
||||
status_code = 200
|
||||
headers: Dict[str, str] = {}
|
||||
for line in headers_blob.splitlines():
|
||||
line = line.strip()
|
||||
if not line or ":" not in line:
|
||||
continue
|
||||
|
||||
k, v = line.split(":", 1)
|
||||
k_stripped = k.strip()
|
||||
if k_stripped.lower() == "status":
|
||||
try:
|
||||
status_code = int(v.strip().split()[0])
|
||||
except Exception:
|
||||
status_code = 200
|
||||
else:
|
||||
headers[k_stripped] = v.strip()
|
||||
|
||||
end_time = time.time()
|
||||
headers["X-Powered-By"] = "retoor"
|
||||
headers["X-Date"] = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime())
|
||||
headers["X-Time"] = str(end_time)
|
||||
headers["X-Duration"] = str(end_time - start_time)
|
||||
if "error" in str(body).lower():
|
||||
print(headers)
|
||||
print(body)
|
||||
return web.Response(body=body, headers=headers, status=status_code)
|
||||
else:
|
||||
return web.Response(body=stdout)
|
||||
|
||||
except Exception as e:
|
||||
return web.Response(status=500, text=f"Failed to execute CGI script.\n{e!r}")
|
||||
|
||||
# ---------------------------
|
||||
# Main router
|
||||
# ---------------------------
|
||||
async def handle_any(self, request: web.Request) -> web.StreamResponse:
|
||||
path = request.path
|
||||
|
||||
if path == "/":
|
||||
return await self.handle_root(request)
|
||||
if path == "/hello":
|
||||
return await self.handle_hello(request)
|
||||
|
||||
script_path, script_name, path_info = self._find_cgi_script(path)
|
||||
if script_path:
|
||||
return await self._run_cgi(request, script_path, script_name or "", path_info or "")
|
||||
|
||||
dyn = await self._maybe_dispatch_dynamic(request)
|
||||
if dyn is not None:
|
||||
return dyn
|
||||
|
||||
return await self.handle_static(request)
|
||||
|
||||
# ---------------------------
|
||||
# Runner
|
||||
# ---------------------------
|
||||
def run(self) -> None:
|
||||
self._logger.info("Serving at port %d (base_dir=%s)", self.port, self.base_dir)
|
||||
web.run_app(self.app, port=self.port)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
server = RetoorServer(base_dir=".", port=8118)
|
||||
server.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
else:
|
||||
cgi = CGI()
|
||||
@ -43,11 +43,6 @@ from pr.tools.memory import (
|
||||
from pr.tools.patch import apply_patch, create_diff
|
||||
from pr.tools.python_exec import python_exec
|
||||
from pr.tools.web import http_fetch, web_search, web_search_news
|
||||
from pr.tools.context_modifier import (
|
||||
modify_context_add,
|
||||
modify_context_replace,
|
||||
modify_context_delete,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"add_knowledge_entry",
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@ -16,7 +17,15 @@ def _create_api_wrapper():
|
||||
tools_definition = get_tools_definition() if use_tools else []
|
||||
|
||||
def api_wrapper(messages, temperature=None, max_tokens=None, **kwargs):
|
||||
return call_api(
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an event loop, use run_coroutine_threadsafe
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
call_api(
|
||||
messages=messages,
|
||||
model=model,
|
||||
api_url=api_url,
|
||||
@ -24,6 +33,22 @@ def _create_api_wrapper():
|
||||
use_tools=use_tools,
|
||||
tools_definition=tools_definition,
|
||||
verbose=False,
|
||||
),
|
||||
loop,
|
||||
)
|
||||
return future.result()
|
||||
except RuntimeError:
|
||||
# No event loop running, use asyncio.run
|
||||
return asyncio.run(
|
||||
call_api(
|
||||
messages=messages,
|
||||
model=model,
|
||||
api_url=api_url,
|
||||
api_key=api_key,
|
||||
use_tools=use_tools,
|
||||
tools_definition=tools_definition,
|
||||
verbose=False,
|
||||
)
|
||||
)
|
||||
|
||||
return api_wrapper
|
||||
|
||||
@ -1,18 +1,21 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
CONTEXT_FILE = '/home/retoor/.local/share/rp/.rcontext.txt'
|
||||
CONTEXT_FILE = "/home/retoor/.local/share/rp/.rcontext.txt"
|
||||
|
||||
|
||||
def _read_context() -> str:
|
||||
if not os.path.exists(CONTEXT_FILE):
|
||||
raise FileNotFoundError(f"Context file {CONTEXT_FILE} not found.")
|
||||
with open(CONTEXT_FILE, 'r') as f:
|
||||
with open(CONTEXT_FILE, "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _write_context(content: str):
|
||||
with open(CONTEXT_FILE, 'w') as f:
|
||||
with open(CONTEXT_FILE, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def modify_context_add(new_content: str, position: Optional[str] = None) -> str:
|
||||
"""
|
||||
Add new content to the .rcontext.txt file.
|
||||
@ -25,13 +28,14 @@ def modify_context_add(new_content: str, position: Optional[str] = None) -> str:
|
||||
if position and position in current:
|
||||
# Insert before the position
|
||||
parts = current.split(position, 1)
|
||||
updated = parts[0] + new_content + '\n\n' + position + parts[1]
|
||||
updated = parts[0] + new_content + "\n\n" + position + parts[1]
|
||||
else:
|
||||
# Append at the end
|
||||
updated = current + '\n\n' + new_content
|
||||
updated = current + "\n\n" + new_content
|
||||
_write_context(updated)
|
||||
return f"Added: {new_content[:100]}... (full addition applied). Consequences: Enhances functionality as requested."
|
||||
|
||||
|
||||
def modify_context_replace(old_content: str, new_content: str) -> str:
|
||||
"""
|
||||
Replace old content with new content in .rcontext.txt.
|
||||
@ -47,6 +51,7 @@ def modify_context_replace(old_content: str, new_content: str) -> str:
|
||||
_write_context(updated)
|
||||
return f"Replaced: '{old_content[:50]}...' with '{new_content[:50]}...'. Consequences: Changes behavior as specified; verify for unintended effects."
|
||||
|
||||
|
||||
def modify_context_delete(content_to_delete: str, confirmed: bool = False) -> str:
|
||||
"""
|
||||
Delete content from .rcontext.txt, but only if confirmed.
|
||||
@ -56,10 +61,12 @@ def modify_context_delete(content_to_delete: str, confirmed: bool = False) -> st
|
||||
confirmed: Must be True to proceed with deletion.
|
||||
"""
|
||||
if not confirmed:
|
||||
raise PermissionError(f"Deletion not confirmed. To delete '{content_to_delete[:50]}...', you must explicitly confirm. Are you sure? This may affect system behavior permanently.")
|
||||
raise PermissionError(
|
||||
f"Deletion not confirmed. To delete '{content_to_delete[:50]}...', you must explicitly confirm. Are you sure? This may affect system behavior permanently."
|
||||
)
|
||||
current = _read_context()
|
||||
if content_to_delete not in current:
|
||||
raise ValueError(f"Content to delete not found: {content_to_delete[:50]}...")
|
||||
updated = current.replace(content_to_delete, '', 1)
|
||||
updated = current.replace(content_to_delete, "", 1)
|
||||
_write_context(updated)
|
||||
return f"Deleted: '{content_to_delete[:50]}...'. Consequences: Removed specified content; system may lose referenced rules or guidelines."
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from pr.ui.colors import Colors
|
||||
from pr.ui.colors import Colors, Spinner
|
||||
from pr.ui.display import display_tool_call, print_autonomous_header
|
||||
from pr.ui.rendering import highlight_code, render_markdown
|
||||
|
||||
__all__ = [
|
||||
"Colors",
|
||||
"Spinner",
|
||||
"highlight_code",
|
||||
"render_markdown",
|
||||
"display_tool_call",
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
@ -12,3 +15,30 @@ class Colors:
|
||||
BG_BLUE = "\033[44m"
|
||||
BG_GREEN = "\033[42m"
|
||||
BG_RED = "\033[41m"
|
||||
|
||||
|
||||
class Spinner:
|
||||
def __init__(self, message="Processing...", spinner_chars="|/-\\"):
|
||||
self.message = message
|
||||
self.spinner_chars = spinner_chars
|
||||
self.running = False
|
||||
self.task = None
|
||||
|
||||
async def start(self):
|
||||
self.running = True
|
||||
self.task = asyncio.create_task(self._spin())
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
if self.task:
|
||||
await self.task
|
||||
# Clear the line
|
||||
print("\r" + " " * (len(self.message) + 2) + "\r", end="", flush=True)
|
||||
|
||||
async def _spin(self):
|
||||
i = 0
|
||||
while self.running:
|
||||
char = self.spinner_chars[i % len(self.spinner_chars)]
|
||||
print(f"\r{Colors.CYAN}{char}{Colors.RESET} {self.message}", end="", flush=True)
|
||||
i += 1
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "rp"
|
||||
version = "1.4.0"
|
||||
version = "1.5.0"
|
||||
description = "R python edition. The ultimate autonomous AI CLI."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
@ -14,7 +14,6 @@ authors = [
|
||||
{name = "retoor", email = "retoor@molodetz.nl"}
|
||||
]
|
||||
dependencies = [
|
||||
"aiohttp>=3.13.2",
|
||||
"pydantic>=2.12.3",
|
||||
"prompt_toolkit>=3.0.0",
|
||||
]
|
||||
@ -32,8 +31,6 @@ classifiers = [
|
||||
dev = [
|
||||
"pytest>=8.3.0",
|
||||
"pytest-asyncio>=1.2.0",
|
||||
"pytest-aiohttp>=1.1.0",
|
||||
"aiohttp>=3.13.2",
|
||||
"pytest-cov>=7.0.0",
|
||||
"black>=25.9.0",
|
||||
"flake8>=7.3.0",
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pr.ads import AsyncDataSet
|
||||
from pr.web.app import create_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -44,77 +41,3 @@ def sample_context_file(temp_dir):
|
||||
with open(context_path, "w") as f:
|
||||
f.write("Sample context content\n")
|
||||
return context_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(aiohttp_client, monkeypatch):
|
||||
"""Create a test client for the app."""
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp:
|
||||
temp_db_file = tmp.name
|
||||
|
||||
# Monkeypatch the db
|
||||
monkeypatch.setattr("pr.web.views.base.db", AsyncDataSet(temp_db_file, f"{temp_db_file}.sock"))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmp_path = Path(tmpdir)
|
||||
static_dir = tmp_path / "static"
|
||||
templates_dir = tmp_path / "templates"
|
||||
repos_dir = tmp_path / "repos"
|
||||
|
||||
static_dir.mkdir()
|
||||
templates_dir.mkdir()
|
||||
repos_dir.mkdir()
|
||||
|
||||
# Monkeypatch the directories
|
||||
monkeypatch.setattr("pr.web.config.STATIC_DIR", static_dir)
|
||||
monkeypatch.setattr("pr.web.config.TEMPLATES_DIR", templates_dir)
|
||||
monkeypatch.setattr("pr.web.config.REPOS_DIR", repos_dir)
|
||||
monkeypatch.setattr("pr.web.config.REPOS_DIR", repos_dir)
|
||||
|
||||
# Create minimal templates
|
||||
(templates_dir / "index.html").write_text("<html>Index</html>")
|
||||
(templates_dir / "login.html").write_text("<html>Login</html>")
|
||||
(templates_dir / "register.html").write_text("<html>Register</html>")
|
||||
(templates_dir / "dashboard.html").write_text("<html>Dashboard</html>")
|
||||
(templates_dir / "repos.html").write_text("<html>Repos</html>")
|
||||
(templates_dir / "api_keys.html").write_text("<html>API Keys</html>")
|
||||
(templates_dir / "repo.html").write_text("<html>Repo</html>")
|
||||
(templates_dir / "file.html").write_text("<html>File</html>")
|
||||
(templates_dir / "edit_file.html").write_text("<html>Edit File</html>")
|
||||
(templates_dir / "deploy.html").write_text("<html>Deploy</html>")
|
||||
|
||||
app_instance = await create_app()
|
||||
client = await aiohttp_client(app_instance)
|
||||
|
||||
yield client
|
||||
|
||||
# Cleanup db
|
||||
os.unlink(temp_db_file)
|
||||
sock_file = f"{temp_db_file}.sock"
|
||||
if os.path.exists(sock_file):
|
||||
os.unlink(sock_file)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def authenticated_client(client):
|
||||
"""Create a client with an authenticated user."""
|
||||
# Register a user
|
||||
resp = await client.post(
|
||||
"/register",
|
||||
data={
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"password": "password123",
|
||||
"confirm_password": "password123",
|
||||
},
|
||||
allow_redirects=False,
|
||||
)
|
||||
assert resp.status == 302 # Redirect to login
|
||||
|
||||
# Login
|
||||
resp = await client.post(
|
||||
"/login", data={"username": "testuser", "password": "password123"}, allow_redirects=False
|
||||
)
|
||||
assert resp.status == 302 # Redirect to dashboard
|
||||
|
||||
return client
|
||||
|
||||
Loading…
Reference in New Issue
Block a user