From 66a45eb5e60610a82a11ed1dabf77950f114f2e6 Mon Sep 17 00:00:00 2001 From: retoor Date: Thu, 6 Nov 2025 16:44:41 +0100 Subject: [PATCH] feat: implement distributed ads with unix sockets --- CHANGELOG.md | 8 + pr/__init__.py | 3 +- pr/__main__.py | 21 +- pr/ads.py | 969 ---------------------------------- pr/autonomous/mode.py | 91 +++- pr/commands/handlers.py | 54 +- pr/core/api.py | 86 +-- pr/core/assistant.py | 79 ++- pr/core/enhanced_assistant.py | 85 ++- pr/core/http_client.py | 136 +++++ pr/server.py | 585 -------------------- pr/tools/__init__.py | 5 - pr/tools/agents.py | 43 +- pr/tools/context_modifier.py | 21 +- pr/ui/__init__.py | 3 +- pr/ui/colors.py | 30 ++ pyproject.toml | 5 +- rp.py | 1 - tests/conftest.py | 77 --- 19 files changed, 531 insertions(+), 1771 deletions(-) delete mode 100644 pr/ads.py create mode 100644 pr/core/http_client.py delete mode 100644 pr/server.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2866228..4e6e38d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ + +## Version 1.5.0 - 2025-11-06 + +Agents can now use datasets stored across multiple locations. This allows for larger datasets and improved performance. + +**Changes:** 2 files, 10 lines +**Languages:** Markdown (8 lines), TOML (2 lines) + ## Version 1.4.0 - 2025-11-06 Agents can now share data more efficiently using a new distributed dataset system. This improves performance and allows agents to work together on larger tasks. diff --git a/pr/__init__.py b/pr/__init__.py index f41af4c..6283a51 100644 --- a/pr/__init__.py +++ b/pr/__init__.py @@ -1,4 +1,5 @@ +__version__ = "1.0.0" + from pr.core import Assistant -__version__ = "1.0.0" __all__ = ["Assistant"] diff --git a/pr/__main__.py b/pr/__main__.py index d2eb216..96f842c 100644 --- a/pr/__main__.py +++ b/pr/__main__.py @@ -1,13 +1,18 @@ import argparse +import asyncio import sys from pr import __version__ from pr.core import Assistant -def main(): +async def main_async(): + import tracemalloc + + tracemalloc.start() + parser = argparse.ArgumentParser( - description="PR Assistant - Professional CLI AI assistant with autonomous execution", + description="PR Assistant - Professional CLI AI assistant with visual effects, cost tracking, and autonomous execution", epilog=""" Examples: pr "What is Python?" # Single query @@ -18,6 +23,12 @@ Examples: pr --list-sessions # List all sessions pr --usage # Show token usage stats +Features: + • Visual progress indicators during AI calls + • Real-time cost tracking for each query + • Sophisticated CLI with colors and effects + • Tool execution with status updates + Commands in interactive mode: /auto [task] - Enter autonomous mode /reset - Clear message history @@ -156,7 +167,11 @@ Commands in interactive mode: return assistant = Assistant(args) - assistant.run() + await assistant.run() + + +def main(): + return asyncio.run(main_async()) if __name__ == "__main__": diff --git a/pr/ads.py b/pr/ads.py deleted file mode 100644 index a7941c2..0000000 --- a/pr/ads.py +++ /dev/null @@ -1,969 +0,0 @@ -import re -import json -from uuid import uuid4 -from datetime import datetime, timezone -from typing import ( - Any, - Dict, - Iterable, - List, - Optional, - AsyncGenerator, - Union, - Tuple, - Set, -) -import aiosqlite -import asyncio -import aiohttp -from aiohttp import web -import socket -import os -import atexit - - -class AsyncDataSet: - """ - Distributed AsyncDataSet with client-server model over Unix sockets. - - Parameters: - - file: Path to the SQLite database file - - socket_path: Path to Unix socket for inter-process communication (default: "ads.sock") - - max_concurrent_queries: Maximum concurrent database operations (default: 1) - Set to 1 to prevent "database is locked" errors with SQLite - """ - - _KV_TABLE = "__kv_store" - _DEFAULT_COLUMNS = { - "uid": "TEXT PRIMARY KEY", - "created_at": "TEXT", - "updated_at": "TEXT", - "deleted_at": "TEXT", - } - - def __init__(self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1): - self._file = file - self._socket_path = socket_path - self._table_columns_cache: Dict[str, Set[str]] = {} - self._is_server = None # None means not initialized yet - self._server = None - self._runner = None - self._client_session = None - self._initialized = False - self._init_lock = None # Will be created when needed - self._max_concurrent_queries = max_concurrent_queries - self._db_semaphore = None # Will be created when needed - self._db_connection = None # Persistent database connection - - async def _ensure_initialized(self): - """Ensure the instance is initialized as server or client.""" - if self._initialized: - return - - # Create lock if needed (first time in async context) - if self._init_lock is None: - self._init_lock = asyncio.Lock() - - async with self._init_lock: - if self._initialized: - return - - await self._initialize() - - # Create semaphore for database operations (server only) - if self._is_server: - self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries) - - self._initialized = True - - async def _initialize(self): - """Initialize as server or client.""" - try: - # Try to create Unix socket - if os.path.exists(self._socket_path): - # Check if socket is active - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - try: - sock.connect(self._socket_path) - sock.close() - # Socket is active, we're a client - self._is_server = False - await self._setup_client() - except (ConnectionRefusedError, FileNotFoundError): - # Socket file exists but not active, remove and become server - os.unlink(self._socket_path) - self._is_server = True - await self._setup_server() - else: - # No socket file, become server - self._is_server = True - await self._setup_server() - except Exception: - # Fallback to client mode - self._is_server = False - await self._setup_client() - - # Establish persistent database connection - self._db_connection = await aiosqlite.connect(self._file) - - async def _setup_server(self): - """Setup aiohttp server.""" - app = web.Application() - - # Add routes for all methods - app.router.add_post("/insert", self._handle_insert) - app.router.add_post("/update", self._handle_update) - app.router.add_post("/delete", self._handle_delete) - app.router.add_post("/upsert", self._handle_upsert) - app.router.add_post("/get", self._handle_get) - app.router.add_post("/find", self._handle_find) - app.router.add_post("/count", self._handle_count) - app.router.add_post("/exists", self._handle_exists) - app.router.add_post("/kv_set", self._handle_kv_set) - app.router.add_post("/kv_get", self._handle_kv_get) - app.router.add_post("/execute_raw", self._handle_execute_raw) - app.router.add_post("/query_raw", self._handle_query_raw) - app.router.add_post("/query_one", self._handle_query_one) - app.router.add_post("/create_table", self._handle_create_table) - app.router.add_post("/insert_unique", self._handle_insert_unique) - app.router.add_post("/aggregate", self._handle_aggregate) - - self._runner = web.AppRunner(app) - await self._runner.setup() - self._server = web.UnixSite(self._runner, self._socket_path) - await self._server.start() - - # Register cleanup - def cleanup(): - if os.path.exists(self._socket_path): - os.unlink(self._socket_path) - - atexit.register(cleanup) - - async def _setup_client(self): - """Setup aiohttp client.""" - connector = aiohttp.UnixConnector(path=self._socket_path) - self._client_session = aiohttp.ClientSession(connector=connector) - - async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any: - """Make HTTP request to server.""" - if not self._client_session: - await self._setup_client() - - url = f"http://localhost/{endpoint}" - async with self._client_session.post(url, json=data) as resp: - result = await resp.json() - if result.get("error"): - raise Exception(result["error"]) - return result.get("result") - - # Server handlers - async def _handle_insert(self, request): - data = await request.json() - try: - result = await self._server_insert( - data["table"], data["args"], data.get("return_id", False) - ) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_update(self, request): - data = await request.json() - try: - result = await self._server_update(data["table"], data["args"], data.get("where")) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_delete(self, request): - data = await request.json() - try: - result = await self._server_delete(data["table"], data.get("where")) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_upsert(self, request): - data = await request.json() - try: - result = await self._server_upsert(data["table"], data["args"], data.get("where")) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_get(self, request): - data = await request.json() - try: - result = await self._server_get(data["table"], data.get("where")) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_find(self, request): - data = await request.json() - try: - result = await self._server_find( - data["table"], - data.get("where"), - limit=data.get("limit", 0), - offset=data.get("offset", 0), - order_by=data.get("order_by"), - ) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_count(self, request): - data = await request.json() - try: - result = await self._server_count(data["table"], data.get("where")) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_exists(self, request): - data = await request.json() - try: - result = await self._server_exists(data["table"], data["where"]) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_kv_set(self, request): - data = await request.json() - try: - await self._server_kv_set(data["key"], data["value"], table=data.get("table")) - return web.json_response({"result": None}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_kv_get(self, request): - data = await request.json() - try: - result = await self._server_kv_get( - data["key"], default=data.get("default"), table=data.get("table") - ) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_execute_raw(self, request): - data = await request.json() - try: - result = await self._server_execute_raw( - data["sql"], - tuple(data.get("params", [])) if data.get("params") else None, - ) - return web.json_response( - {"result": result.rowcount if hasattr(result, "rowcount") else None} - ) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_query_raw(self, request): - data = await request.json() - try: - result = await self._server_query_raw( - data["sql"], - tuple(data.get("params", [])) if data.get("params") else None, - ) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_query_one(self, request): - data = await request.json() - try: - result = await self._server_query_one( - data["sql"], - tuple(data.get("params", [])) if data.get("params") else None, - ) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_create_table(self, request): - data = await request.json() - try: - await self._server_create_table(data["table"], data["schema"], data.get("constraints")) - return web.json_response({"result": None}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_insert_unique(self, request): - data = await request.json() - try: - result = await self._server_insert_unique( - data["table"], data["args"], data["unique_fields"] - ) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - async def _handle_aggregate(self, request): - data = await request.json() - try: - result = await self._server_aggregate( - data["table"], - data["function"], - data.get("column", "*"), - data.get("where"), - ) - return web.json_response({"result": result}) - except Exception as e: - return web.json_response({"error": str(e)}, status=500) - - # Public methods that delegate to server or client - async def insert( - self, table: str, args: Dict[str, Any], return_id: bool = False - ) -> Union[str, int]: - await self._ensure_initialized() - if self._is_server: - return await self._server_insert(table, args, return_id) - else: - return await self._make_request( - "insert", {"table": table, "args": args, "return_id": return_id} - ) - - async def update( - self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None - ) -> int: - await self._ensure_initialized() - if self._is_server: - return await self._server_update(table, args, where) - else: - return await self._make_request( - "update", {"table": table, "args": args, "where": where} - ) - - async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: - await self._ensure_initialized() - if self._is_server: - return await self._server_delete(table, where) - else: - return await self._make_request("delete", {"table": table, "where": where}) - - async def get( - self, table: str, where: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: - await self._ensure_initialized() - if self._is_server: - return await self._server_get(table, where) - else: - return await self._make_request("get", {"table": table, "where": where}) - - async def find( - self, - table: str, - where: Optional[Dict[str, Any]] = None, - *, - limit: int = 0, - offset: int = 0, - order_by: Optional[str] = None, - ) -> List[Dict[str, Any]]: - await self._ensure_initialized() - if self._is_server: - return await self._server_find( - table, where, limit=limit, offset=offset, order_by=order_by - ) - else: - return await self._make_request( - "find", - { - "table": table, - "where": where, - "limit": limit, - "offset": offset, - "order_by": order_by, - }, - ) - - async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: - await self._ensure_initialized() - if self._is_server: - return await self._server_count(table, where) - else: - return await self._make_request("count", {"table": table, "where": where}) - - async def exists(self, table: str, where: Dict[str, Any]) -> bool: - await self._ensure_initialized() - if self._is_server: - return await self._server_exists(table, where) - else: - return await self._make_request("exists", {"table": table, "where": where}) - - async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]: - await self._ensure_initialized() - if self._is_server: - return await self._server_query_raw(sql, params) - else: - return await self._make_request( - "query_raw", {"sql": sql, "params": list(params) if params else None} - ) - - async def create_table( - self, - table: str, - schema: Dict[str, str], - constraints: Optional[List[str]] = None, - ): - await self._ensure_initialized() - if self._is_server: - return await self._server_create_table(table, schema, constraints) - else: - return await self._make_request( - "create_table", - {"table": table, "schema": schema, "constraints": constraints}, - ) - - async def insert_unique( - self, table: str, args: Dict[str, Any], unique_fields: List[str] - ) -> Union[str, None]: - await self._ensure_initialized() - if self._is_server: - return await self._server_insert_unique(table, args, unique_fields) - else: - return await self._make_request( - "insert_unique", - {"table": table, "args": args, "unique_fields": unique_fields}, - ) - - async def aggregate( - self, - table: str, - function: str, - column: str = "*", - where: Optional[Dict[str, Any]] = None, - ) -> Any: - await self._ensure_initialized() - if self._is_server: - return await self._server_aggregate(table, function, column, where) - else: - return await self._make_request( - "aggregate", - { - "table": table, - "function": function, - "column": column, - "where": where, - }, - ) - - async def transaction(self): - """Context manager for transactions.""" - await self._ensure_initialized() - if self._is_server: - return TransactionContext(self._db_connection, self._db_semaphore) - else: - raise NotImplementedError("Transactions not supported in client mode") - - # Server implementation methods (original logic) - @staticmethod - def _utc_iso() -> str: - return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") - - @staticmethod - def _py_to_sqlite_type(value: Any) -> str: - if value is None: - return "TEXT" - if isinstance(value, bool): - return "INTEGER" - if isinstance(value, int): - return "INTEGER" - if isinstance(value, float): - return "REAL" - if isinstance(value, (bytes, bytearray, memoryview)): - return "BLOB" - return "TEXT" - - async def _get_table_columns(self, table: str) -> Set[str]: - """Get actual columns that exist in the table.""" - if table in self._table_columns_cache: - return self._table_columns_cache[table] - - columns = set() - try: - async with self._db_semaphore: - async with self._db_connection.execute(f"PRAGMA table_info({table})") as cursor: - async for row in cursor: - columns.add(row[1]) # Column name is at index 1 - self._table_columns_cache[table] = columns - except: - pass - return columns - - async def _invalidate_column_cache(self, table: str): - """Invalidate column cache for a table.""" - if table in self._table_columns_cache: - del self._table_columns_cache[table] - - async def _ensure_column(self, table: str, name: str, value: Any) -> None: - col_type = self._py_to_sqlite_type(value) - try: - async with self._db_semaphore: - await self._db_connection.execute( - f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}" - ) - await self._db_connection.commit() - await self._invalidate_column_cache(table) - except aiosqlite.OperationalError as e: - if "duplicate column name" in str(e).lower(): - pass # Column already exists - else: - raise - - async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: - # Always include default columns - cols = self._DEFAULT_COLUMNS.copy() - - # Add columns from col_sources - for key, val in col_sources.items(): - if key not in cols: - cols[key] = self._py_to_sqlite_type(val) - - columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items()) - async with self._db_semaphore: - await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})") - await self._db_connection.commit() - await self._invalidate_column_cache(table) - - async def _table_exists(self, table: str) -> bool: - """Check if a table exists.""" - async with self._db_semaphore: - async with self._db_connection.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,) - ) as cursor: - return await cursor.fetchone() is not None - - _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)") - _RE_NO_TABLE = re.compile(r"no such table: (\w+)") - - @classmethod - def _missing_column_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]: - m = cls._RE_NO_COLUMN.search(str(err)) - return m.group(1) if m else None - - @classmethod - def _missing_table_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]: - m = cls._RE_NO_TABLE.search(str(err)) - return m.group(1) if m else None - - async def _safe_execute( - self, - table: str, - sql: str, - params: Iterable[Any], - col_sources: Dict[str, Any], - max_retries: int = 10, - ) -> aiosqlite.Cursor: - retries = 0 - while retries < max_retries: - try: - async with self._db_semaphore: - cursor = await self._db_connection.execute(sql, params) - await self._db_connection.commit() - return cursor - except aiosqlite.OperationalError as err: - retries += 1 - err_str = str(err).lower() - - # Handle missing column - col = self._missing_column_from_error(err) - if col: - if col in col_sources: - await self._ensure_column(table, col, col_sources[col]) - else: - # Column not in sources, ensure it with NULL/TEXT type - await self._ensure_column(table, col, None) - continue - - # Handle missing table - tbl = self._missing_table_from_error(err) - if tbl: - await self._ensure_table(tbl, col_sources) - continue - - # Handle other column-related errors - if "has no column named" in err_str: - # Extract column name differently - match = re.search(r"table \w+ has no column named (\w+)", err_str) - if match: - col_name = match.group(1) - if col_name in col_sources: - await self._ensure_column(table, col_name, col_sources[col_name]) - else: - await self._ensure_column(table, col_name, None) - continue - - raise - raise Exception(f"Max retries ({max_retries}) exceeded") - - async def _safe_query( - self, - table: str, - sql: str, - params: Iterable[Any], - col_sources: Dict[str, Any], - ) -> AsyncGenerator[Dict[str, Any], None]: - # Check if table exists first - if not await self._table_exists(table): - return - - max_retries = 10 - retries = 0 - - while retries < max_retries: - try: - async with self._db_semaphore: - self._db_connection.row_factory = aiosqlite.Row - async with self._db_connection.execute(sql, params) as cursor: - # Fetch all rows while holding the semaphore - rows = await cursor.fetchall() - - # Yield rows after releasing the semaphore - for row in rows: - yield dict(row) - return - except aiosqlite.OperationalError as err: - retries += 1 - err_str = str(err).lower() - - # Handle missing table - tbl = self._missing_table_from_error(err) - if tbl: - # For queries, if table doesn't exist, just return empty - return - - # Handle missing column in WHERE clause or SELECT - if "no such column" in err_str: - # For queries with missing columns, return empty - return - - raise - - @staticmethod - def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: - if not where: - return "", [] - - clauses = [] - vals = [] - - op_map = {"$eq": "=", "$ne": "!=", "$gt": ">", "$gte": ">=", "$lt": "<", "$lte": "<="} - - for k, v in where.items(): - if isinstance(v, dict): - # Handle operators like {"$lt": 10} - op_key = next(iter(v)) - if op_key in op_map: - operator = op_map[op_key] - value = v[op_key] - clauses.append(f"`{k}` {operator} ?") - vals.append(value) - else: - # Fallback for unexpected dict values - clauses.append(f"`{k}` = ?") - vals.append(json.dumps(v)) - else: - # Default to equality - clauses.append(f"`{k}` = ?") - vals.append(v) - - return " WHERE " + " AND ".join(clauses), vals - - async def _server_insert( - self, table: str, args: Dict[str, Any], return_id: bool = False - ) -> Union[str, int]: - """Insert a record. If return_id=True, returns auto-incremented ID instead of UUID.""" - uid = str(uuid4()) - now = self._utc_iso() - record = { - "uid": uid, - "created_at": now, - "updated_at": now, - "deleted_at": None, - **args, - } - - # Ensure table exists with all needed columns - await self._ensure_table(table, record) - - # Handle auto-increment ID if requested - if return_id and "id" not in args: - # Ensure id column exists - async with self._db_semaphore: - # Add id column if it doesn't exist - try: - await self._db_connection.execute( - f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT" - ) - await self._db_connection.commit() - except aiosqlite.OperationalError as e: - if "duplicate column name" not in str(e).lower(): - # Try without autoincrement constraint - try: - await self._db_connection.execute( - f"ALTER TABLE {table} ADD COLUMN id INTEGER" - ) - await self._db_connection.commit() - except: - pass - - await self._invalidate_column_cache(table) - - # Insert and get lastrowid - cols = "`" + "`, `".join(record.keys()) + "`" - qs = ", ".join(["?"] * len(record)) - sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" - cursor = await self._safe_execute(table, sql, list(record.values()), record) - return cursor.lastrowid - - cols = "`" + "`, `".join(record) + "`" - qs = ", ".join(["?"] * len(record)) - sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" - await self._safe_execute(table, sql, list(record.values()), record) - return uid - - async def _server_update( - self, - table: str, - args: Dict[str, Any], - where: Optional[Dict[str, Any]] = None, - ) -> int: - if not args: - return 0 - - # Check if table exists - if not await self._table_exists(table): - return 0 - - args["updated_at"] = self._utc_iso() - - # Ensure all columns exist - all_cols = {**args, **(where or {})} - await self._ensure_table(table, all_cols) - for col, val in all_cols.items(): - await self._ensure_column(table, col, val) - - set_clause = ", ".join(f"`{k}` = ?" for k in args) - where_clause, where_params = self._build_where(where) - sql = f"UPDATE {table} SET {set_clause}{where_clause}" - params = list(args.values()) + where_params - cur = await self._safe_execute(table, sql, params, all_cols) - return cur.rowcount - - async def _server_delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: - # Check if table exists - if not await self._table_exists(table): - return 0 - - where_clause, where_params = self._build_where(where) - sql = f"DELETE FROM {table}{where_clause}" - cur = await self._safe_execute(table, sql, where_params, where or {}) - return cur.rowcount - - async def _server_upsert( - self, - table: str, - args: Dict[str, Any], - where: Optional[Dict[str, Any]] = None, - ) -> str | None: - if not args: - raise ValueError("Nothing to update. Empty dict given.") - args["updated_at"] = self._utc_iso() - affected = await self._server_update(table, args, where) - if affected: - rec = await self._server_get(table, where) - return rec.get("uid") if rec else None - merged = {**(where or {}), **args} - return await self._server_insert(table, merged) - - async def _server_get( - self, table: str, where: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: - where_clause, where_params = self._build_where(where) - sql = f"SELECT * FROM {table}{where_clause} LIMIT 1" - async for row in self._safe_query(table, sql, where_params, where or {}): - return row - return None - - async def _server_find( - self, - table: str, - where: Optional[Dict[str, Any]] = None, - *, - limit: int = 0, - offset: int = 0, - order_by: Optional[str] = None, - ) -> List[Dict[str, Any]]: - """Find records with optional ordering.""" - where_clause, where_params = self._build_where(where) - order_clause = f" ORDER BY {order_by}" if order_by else "" - extra = (f" LIMIT {limit}" if limit else "") + (f" OFFSET {offset}" if offset else "") - sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}" - return [row async for row in self._safe_query(table, sql, where_params, where or {})] - - async def _server_count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: - # Check if table exists - if not await self._table_exists(table): - return 0 - - where_clause, where_params = self._build_where(where) - sql = f"SELECT COUNT(*) FROM {table}{where_clause}" - gen = self._safe_query(table, sql, where_params, where or {}) - async for row in gen: - return next(iter(row.values()), 0) - return 0 - - async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool: - return (await self._server_count(table, where)) > 0 - - async def _server_kv_set( - self, - key: str, - value: Any, - *, - table: str | None = None, - ) -> None: - tbl = table or self._KV_TABLE - json_val = json.dumps(value, default=str) - await self._server_upsert(tbl, {"value": json_val}, {"key": key}) - - async def _server_kv_get( - self, - key: str, - *, - default: Any = None, - table: str | None = None, - ) -> Any: - tbl = table or self._KV_TABLE - row = await self._server_get(tbl, {"key": key}) - if not row: - return default - try: - return json.loads(row["value"]) - except Exception: - return default - - async def _server_execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any: - """Execute raw SQL for complex queries like JOINs.""" - async with self._db_semaphore: - cursor = await self._db_connection.execute(sql, params or ()) - await self._db_connection.commit() - return cursor - - async def _server_query_raw( - self, sql: str, params: Optional[Tuple] = None - ) -> List[Dict[str, Any]]: - """Execute raw SQL query and return results as list of dicts.""" - try: - async with self._db_semaphore: - self._db_connection.row_factory = aiosqlite.Row - async with self._db_connection.execute(sql, params or ()) as cursor: - return [dict(row) async for row in cursor] - except aiosqlite.OperationalError: - # Return empty list if query fails - return [] - - async def _server_query_one( - self, sql: str, params: Optional[Tuple] = None - ) -> Optional[Dict[str, Any]]: - """Execute raw SQL query and return single result.""" - results = await self._server_query_raw(sql + " LIMIT 1", params) - return results[0] if results else None - - async def _server_create_table( - self, - table: str, - schema: Dict[str, str], - constraints: Optional[List[str]] = None, - ): - """Create table with custom schema and constraints. Always includes default columns.""" - # Merge default columns with custom schema - full_schema = self._DEFAULT_COLUMNS.copy() - full_schema.update(schema) - - columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()] - if constraints: - columns.extend(constraints) - columns_sql = ", ".join(columns) - - async with self._db_semaphore: - await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})") - await self._db_connection.commit() - await self._invalidate_column_cache(table) - - async def _server_insert_unique( - self, table: str, args: Dict[str, Any], unique_fields: List[str] - ) -> Union[str, None]: - """Insert with unique constraint handling. Returns uid on success, None if duplicate.""" - try: - return await self._server_insert(table, args) - except aiosqlite.IntegrityError as e: - if "UNIQUE" in str(e): - return None - raise - - async def _server_aggregate( - self, - table: str, - function: str, - column: str = "*", - where: Optional[Dict[str, Any]] = None, - ) -> Any: - """Perform aggregate functions like SUM, AVG, MAX, MIN.""" - # Check if table exists - if not await self._table_exists(table): - return None - - where_clause, where_params = self._build_where(where) - sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}" - result = await self._server_query_one(sql, tuple(where_params)) - return result["result"] if result else None - - async def close(self): - """Close the connection or server.""" - if not self._initialized: - return - - if self._client_session: - await self._client_session.close() - if self._runner: - await self._runner.cleanup() - if os.path.exists(self._socket_path) and self._is_server: - os.unlink(self._socket_path) - if self._db_connection: - await self._db_connection.close() - - -class TransactionContext: - """Context manager for database transactions.""" - - def __init__(self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None): - self.db_connection = db_connection - self.semaphore = semaphore - - async def __aenter__(self): - if self.semaphore: - await self.semaphore.acquire() - try: - await self.db_connection.execute("BEGIN") - return self.db_connection - except: - if self.semaphore: - self.semaphore.release() - raise - - async def __aexit__(self, exc_type, exc_val, exc_tb): - try: - if exc_type is None: - await self.db_connection.commit() - else: - await self.db_connection.rollback() - finally: - if self.semaphore: - self.semaphore.release() - - -# Test cases remain the same but with additional tests for new functionality diff --git a/pr/autonomous/mode.py b/pr/autonomous/mode.py index a02910a..c70dbb1 100644 --- a/pr/autonomous/mode.py +++ b/pr/autonomous/mode.py @@ -1,9 +1,13 @@ +import asyncio +import asyncio import json import logging import time from pr.autonomous.detection import is_task_complete +from pr.core.api import call_api from pr.core.context import truncate_tool_result +from pr.tools.base import get_tools_definition from pr.ui import Colors, display_tool_call logger = logging.getLogger("pr") @@ -35,18 +39,39 @@ def run_autonomous_mode(assistant, task): logger.debug(f"Messages after context management: {len(assistant.messages)}") - from pr.core.api import call_api - from pr.tools.base import get_tools_definition + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # If we're in an event loop, use run_coroutine_threadsafe + import concurrent.futures - response = call_api( - assistant.messages, - assistant.model, - assistant.api_url, - assistant.api_key, - assistant.use_tools, - get_tools_definition(), - verbose=assistant.verbose, - ) + with concurrent.futures.ThreadPoolExecutor() as executor: + future = asyncio.run_coroutine_threadsafe( + call_api( + assistant.messages, + assistant.model, + assistant.api_url, + assistant.api_key, + assistant.use_tools, + get_tools_definition(), + verbose=assistant.verbose, + ), + loop, + ) + response = future.result() + except RuntimeError: + # No event loop running, use asyncio.run + response = asyncio.run( + call_api( + assistant.messages, + assistant.model, + assistant.api_url, + assistant.api_key, + assistant.use_tools, + get_tools_definition(), + verbose=assistant.verbose, + ) + ) if "error" in response: logger.error(f"API error in autonomous mode: {response['error']}") @@ -118,18 +143,40 @@ def process_response_autonomous(assistant, response): for result in tool_results: assistant.messages.append(result) - from pr.core.api import call_api - from pr.tools.base import get_tools_definition - follow_up = call_api( - assistant.messages, - assistant.model, - assistant.api_url, - assistant.api_key, - assistant.use_tools, - get_tools_definition(), - verbose=assistant.verbose, - ) + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # If we're in an event loop, use run_coroutine_threadsafe + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = asyncio.run_coroutine_threadsafe( + call_api( + assistant.messages, + assistant.model, + assistant.api_url, + assistant.api_key, + assistant.use_tools, + get_tools_definition(), + verbose=assistant.verbose, + ), + loop, + ) + follow_up = future.result() + except RuntimeError: + # No event loop running, use asyncio.run + follow_up = asyncio.run( + call_api( + assistant.messages, + assistant.model, + assistant.api_url, + assistant.api_key, + assistant.use_tools, + get_tools_definition(), + verbose=assistant.verbose, + ) + ) return process_response_autonomous(assistant, follow_up) content = message.get("content", "") diff --git a/pr/commands/handlers.py b/pr/commands/handlers.py index f26ce4b..08674cd 100644 --- a/pr/commands/handlers.py +++ b/pr/commands/handlers.py @@ -1,4 +1,6 @@ +import asyncio import json +import logging import time from pr.commands.multiplexer_commands import MULTIPLEXER_COMMANDS @@ -40,7 +42,18 @@ def handle_command(assistant, command): if prompt_text.strip(): from pr.core.assistant import process_message - process_message(assistant, prompt_text) + task = asyncio.create_task(process_message(assistant, prompt_text)) + assistant.background_tasks.add(task) + task.add_done_callback( + lambda t: ( + assistant.background_tasks.discard(t), + ( + logging.error(f"Background task failed: {t.exception()}") + if t.exception() + else assistant.background_tasks.discard(t) + ), + ) + ) elif cmd == "/auto": if len(command_parts) < 2: @@ -201,7 +214,18 @@ def review_file(assistant, filename): message = f"Please review this file and provide feedback:\n\n{result['content']}" from pr.core.assistant import process_message - process_message(assistant, message) + task = asyncio.create_task(process_message(assistant, message)) + assistant.background_tasks.add(task) + task.add_done_callback( + lambda t: ( + assistant.background_tasks.discard(t), + ( + logging.error(f"Background task failed: {t.exception()}") + if t.exception() + else assistant.background_tasks.discard(t) + ), + ) + ) else: print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}") @@ -212,7 +236,18 @@ def refactor_file(assistant, filename): message = f"Please refactor this code to improve its quality:\n\n{result['content']}" from pr.core.assistant import process_message - process_message(assistant, message) + task = asyncio.create_task(process_message(assistant, message)) + assistant.background_tasks.add(task) + task.add_done_callback( + lambda t: ( + assistant.background_tasks.discard(t), + ( + logging.error(f"Background task failed: {t.exception()}") + if t.exception() + else assistant.background_tasks.discard(t) + ), + ) + ) else: print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}") @@ -223,7 +258,18 @@ def obfuscate_file(assistant, filename): message = f"Please obfuscate this code:\n\n{result['content']}" from pr.core.assistant import process_message - process_message(assistant, message) + task = asyncio.create_task(process_message(assistant, message)) + assistant.background_tasks.add(task) + task.add_done_callback( + lambda t: ( + assistant.background_tasks.discard(t), + ( + logging.error(f"Background task failed: {t.exception()}") + if t.exception() + else assistant.background_tasks.discard(t) + ), + ) + ) else: print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}") diff --git a/pr/core/api.py b/pr/core/api.py index 3e88120..1c236e5 100644 --- a/pr/core/api.py +++ b/pr/core/api.py @@ -1,15 +1,14 @@ import json import logging -import urllib.error -import urllib.request from pr.config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE from pr.core.context import auto_slim_messages +from pr.core.http_client import http_client logger = logging.getLogger("pr") -def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False): +async def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False): try: messages = auto_slim_messages(messages, verbose=verbose) @@ -42,63 +41,70 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver data["tool_choice"] = "auto" logger.debug(f"Tool calling enabled with {len(tools_definition)} tools") - request_json = json.dumps(data) + request_json = data logger.debug(f"Request payload size: {len(request_json)} bytes") - req = urllib.request.Request( - api_url, data=request_json.encode("utf-8"), headers=headers, method="POST" - ) - logger.debug("Sending HTTP request...") - with urllib.request.urlopen(req) as response: - response_data = response.read().decode("utf-8") - logger.debug(f"Response received: {len(response_data)} bytes") - result = json.loads(response_data) + response = await http_client.post(api_url, headers=headers, json_data=request_json) - if "usage" in result: - logger.debug(f"Token usage: {result['usage']}") - if "choices" in result and result["choices"]: - choice = result["choices"][0] - if "message" in choice: - msg = choice["message"] - logger.debug(f"Response role: {msg.get('role', 'N/A')}") - if "content" in msg and msg["content"]: - logger.debug(f"Response content length: {len(msg['content'])} chars") - if "tool_calls" in msg: - logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)") + if response.get("error"): + if "status" in response: + logger.error(f"API HTTP Error: {response['status']} - {response.get('text', '')}") + logger.debug("=== API CALL FAILED ===") + return { + "error": f"API Error: {response['status']}", + "message": response.get("text", ""), + } + else: + logger.error(f"API call failed: {response.get('exception', 'Unknown error')}") + logger.debug("=== API CALL FAILED ===") + return {"error": response.get("exception", "Unknown error")} - if verbose and "usage" in result: - from pr.core.usage_tracker import UsageTracker + response_data = response["text"] + logger.debug(f"Response received: {len(response_data)} bytes") + result = json.loads(response_data) - usage = result["usage"] - input_t = usage.get("prompt_tokens", 0) - output_t = usage.get("completion_tokens", 0) - cost = UsageTracker._calculate_cost(model, input_t, output_t) - print(f"API call cost: €{cost:.4f}") + if "usage" in result: + logger.debug(f"Token usage: {result['usage']}") + if "choices" in result and result["choices"]: + choice = result["choices"][0] + if "message" in choice: + msg = choice["message"] + logger.debug(f"Response role: {msg.get('role', 'N/A')}") + if "content" in msg and msg["content"]: + logger.debug(f"Response content length: {len(msg['content'])} chars") + if "tool_calls" in msg: + logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)") + + if verbose and "usage" in result: + from pr.core.usage_tracker import UsageTracker + + usage = result["usage"] + input_t = usage.get("prompt_tokens", 0) + output_t = usage.get("completion_tokens", 0) + UsageTracker._calculate_cost(model, input_t, output_t) logger.debug("=== API CALL END ===") return result - except urllib.error.HTTPError as e: - error_body = e.read().decode("utf-8") - logger.error(f"API HTTP Error: {e.code} - {error_body}") - logger.debug("=== API CALL FAILED ===") - return {"error": f"API Error: {e.code}", "message": error_body} except Exception as e: logger.error(f"API call failed: {e}") logger.debug("=== API CALL FAILED ===") return {"error": str(e)} -def list_models(model_list_url, api_key): +async def list_models(model_list_url, api_key): try: - req = urllib.request.Request(model_list_url) + headers = {} if api_key: - req.add_header("Authorization", f"Bearer {api_key}") + headers["Authorization"] = f"Bearer {api_key}" - with urllib.request.urlopen(req) as response: - data = json.loads(response.read().decode("utf-8")) + response = await http_client.get(model_list_url, headers=headers) + if response.get("error"): + return {"error": response.get("text", "HTTP error")} + + data = json.loads(response["text"]) return data.get("data", []) except Exception as e: return {"error": str(e)} diff --git a/pr/core/assistant.py b/pr/core/assistant.py index 4be6c61..825a892 100644 --- a/pr/core/assistant.py +++ b/pr/core/assistant.py @@ -20,7 +20,7 @@ from pr.config import ( ) from pr.core.api import call_api from pr.core.autonomous_interactions import ( - get_global_autonomous, + start_global_autonomous, stop_global_autonomous, ) from pr.core.background_monitor import ( @@ -29,6 +29,7 @@ from pr.core.background_monitor import ( stop_global_monitor, ) from pr.core.context import init_system_message, truncate_tool_result +from pr.core.usage_tracker import UsageTracker from pr.tools import get_tools_definition from pr.tools.agents import ( collaborate_agents, @@ -71,7 +72,7 @@ from pr.tools.memory import ( from pr.tools.patch import apply_patch, create_diff, display_file_diff from pr.tools.python_exec import python_exec from pr.tools.web import http_fetch, web_search, web_search_news -from pr.ui import Colors, render_markdown +from pr.ui import Colors, Spinner, render_markdown logger = logging.getLogger("pr") logger.setLevel(logging.DEBUG) @@ -109,6 +110,8 @@ class Assistant: self.autonomous_mode = False self.autonomous_iterations = 0 self.background_monitoring = False + self.usage_tracker = UsageTracker() + self.background_tasks = set() self.init_database() self.messages.append(init_system_message(args)) @@ -125,8 +128,7 @@ class Assistant: # Initialize background monitoring components try: start_global_monitor() - autonomous = get_global_autonomous() - autonomous.start(llm_callback=self._handle_background_updates) + start_global_autonomous(llm_callback=self._handle_background_updates) self.background_monitoring = True if self.debug: logger.debug("Background monitoring initialized") @@ -236,7 +238,7 @@ class Assistant: if self.debug: print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}") - def execute_tool_calls(self, tool_calls): + async def execute_tool_calls(self, tool_calls): results = [] logger.debug(f"Executing {len(tool_calls)} tool call(s)") @@ -337,7 +339,7 @@ class Assistant: return results - def process_response(self, response): + async def process_response(self, response): if "error" in response: return f"Error: {response['error']}" @@ -348,15 +350,17 @@ class Assistant: self.messages.append(message) if "tool_calls" in message and message["tool_calls"]: - if self.verbose: - print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}") + tool_count = len(message["tool_calls"]) + print(f"{Colors.BLUE}🔧 Executing {tool_count} tool call(s)...{Colors.RESET}") - tool_results = self.execute_tool_calls(message["tool_calls"]) + tool_results = await self.execute_tool_calls(message["tool_calls"]) + + print(f"{Colors.GREEN}✅ Tool execution completed.{Colors.RESET}") for result in tool_results: self.messages.append(result) - follow_up = call_api( + follow_up = await call_api( self.messages, self.model, self.api_url, @@ -365,7 +369,7 @@ class Assistant: get_tools_definition(), verbose=self.verbose, ) - return self.process_response(follow_up) + return await self.process_response(follow_up) content = message.get("content", "") return render_markdown(content, self.syntax_highlighting) @@ -439,12 +443,23 @@ class Assistant: readline.set_completer(completer) readline.parse_and_bind("tab: complete") - def run_repl(self): + async def run_repl(self): self.setup_readline() signal.signal(signal.SIGINT, self.signal_handler) - print(f"{Colors.BOLD}r{Colors.RESET}") - print(f"Type 'help' for commands or start chatting") + print( + f"{Colors.BOLD}{Colors.CYAN}╔══════════════════════════════════════════════╗{Colors.RESET}" + ) + print( + f"{Colors.BOLD}{Colors.CYAN}║{Colors.RESET}{Colors.BOLD} PR Assistant v{__import__('pr').__version__} {Colors.RESET}{Colors.BOLD}{Colors.CYAN}║{Colors.RESET}" + ) + print( + f"{Colors.BOLD}{Colors.CYAN}╚══════════════════════════════════════════════╝{Colors.RESET}" + ) + print( + f"{Colors.GRAY}Type 'help' for commands, 'exit' to quit, or start chatting.{Colors.RESET}" + ) + print(f"{Colors.GRAY}AI calls will show costs and progress indicators.{Colors.RESET}\n") while True: try: @@ -480,7 +495,7 @@ class Assistant: elif cmd_result is True: continue - process_message(self, user_input) + await process_message(self, user_input) except EOFError: break @@ -490,7 +505,7 @@ class Assistant: print(f"{Colors.RED}Error: {e}{Colors.RESET}") logging.error(f"REPL error: {e}\n{traceback.format_exc()}") - def run_single(self): + async def run_single(self): if self.args.message: message = self.args.message else: @@ -498,7 +513,7 @@ class Assistant: from pr.autonomous.mode import run_autonomous_mode - run_autonomous_mode(self, message) + await run_autonomous_mode(self, message) def cleanup(self): if hasattr(self, "enhanced") and self.enhanced: @@ -525,22 +540,22 @@ class Assistant: if self.db_conn: self.db_conn.close() - def run(self): + async def run(self): try: print( f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}" ) if self.args.interactive or (not self.args.message and sys.stdin.isatty()): print("DEBUG: calling run_repl") - self.run_repl() + await self.run_repl() else: print("DEBUG: calling run_single") - self.run_single() + await self.run_single() finally: self.cleanup() -def process_message(assistant, message): +async def process_message(assistant, message): from pr.core.knowledge_context import inject_knowledge_context inject_knowledge_context(assistant, message) @@ -550,10 +565,11 @@ def process_message(assistant, message): logger.debug(f"Processing user message: {message[:100]}...") logger.debug(f"Current message count: {len(assistant.messages)}") - if assistant.verbose: - print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}") + # Start spinner for AI call + spinner = Spinner("Querying AI...") + await spinner.start() - response = call_api( + response = await call_api( assistant.messages, assistant.model, assistant.api_url, @@ -562,6 +578,19 @@ def process_message(assistant, message): get_tools_definition(), verbose=assistant.verbose, ) - result = assistant.process_response(response) + + await spinner.stop() + + # Track usage and display cost + if "usage" in response: + usage = response["usage"] + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + assistant.usage_tracker.track_request(assistant.model, input_tokens, output_tokens) + cost = UsageTracker._calculate_cost(assistant.model, input_tokens, output_tokens) + total_cost = assistant.usage_tracker.session_usage["estimated_cost"] + print(f"{Colors.YELLOW}💰 Cost: ${cost:.4f} | Total: ${total_cost:.4f}{Colors.RESET}") + + result = await assistant.process_response(response) print(f"\n{Colors.GREEN}r:{Colors.RESET} {result}\n") diff --git a/pr/core/enhanced_assistant.py b/pr/core/enhanced_assistant.py index e3fd1a4..0098ff7 100644 --- a/pr/core/enhanced_assistant.py +++ b/pr/core/enhanced_assistant.py @@ -1,3 +1,4 @@ +import asyncio import json import logging import uuid @@ -128,15 +129,39 @@ class EnhancedAssistant: def _api_caller_for_agent( self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int ) -> Dict[str, Any]: - return call_api( - messages, - self.base.model, - self.base.api_url, - self.base.api_key, - use_tools=False, - tools_definition=[], - verbose=self.base.verbose, - ) + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # If we're in an event loop, use run_coroutine_threadsafe + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = asyncio.run_coroutine_threadsafe( + call_api( + messages, + self.base.model, + self.base.api_url, + self.base.api_key, + use_tools=False, + tools_definition=[], + verbose=self.base.verbose, + ), + loop, + ) + return future.result() + except RuntimeError: + # No event loop running, use asyncio.run + return asyncio.run( + call_api( + messages, + self.base.model, + self.base.api_url, + self.base.api_key, + use_tools=False, + tools_definition=[], + verbose=self.base.verbose, + ) + ) def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]: if self.api_cache and CACHE_ENABLED: @@ -145,15 +170,39 @@ class EnhancedAssistant: logger.debug("API cache hit") return cached_response - response = call_api( - messages, - self.base.model, - self.base.api_url, - self.base.api_key, - self.base.use_tools, - get_tools_definition(), - verbose=self.base.verbose, - ) + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # If we're in an event loop, use run_coroutine_threadsafe + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = asyncio.run_coroutine_threadsafe( + call_api( + messages, + self.base.model, + self.base.api_url, + self.base.api_key, + self.base.use_tools, + get_tools_definition(), + verbose=self.base.verbose, + ), + loop, + ) + response = future.result() + except RuntimeError: + # No event loop running, use asyncio.run + response = asyncio.run( + call_api( + messages, + self.base.model, + self.base.api_url, + self.base.api_key, + self.base.use_tools, + get_tools_definition(), + verbose=self.base.verbose, + ) + ) if self.api_cache and CACHE_ENABLED and "error" not in response: token_count = response.get("usage", {}).get("total_tokens", 0) diff --git a/pr/core/http_client.py b/pr/core/http_client.py new file mode 100644 index 0000000..3af3db5 --- /dev/null +++ b/pr/core/http_client.py @@ -0,0 +1,136 @@ +import asyncio +import json +import logging +import socket +import time +import urllib.error +import urllib.request +from typing import Dict, Any, Optional + +logger = logging.getLogger("pr") + + +class AsyncHTTPClient: + def __init__(self): + self.session_headers = {} + + async def request( + self, + method: str, + url: str, + headers: Optional[Dict[str, str]] = None, + data: Optional[bytes] = None, + json_data: Optional[Dict[str, Any]] = None, + timeout: float = 30.0, + ) -> Dict[str, Any]: + """Make an async HTTP request using urllib in a thread executor with retry logic.""" + loop = asyncio.get_event_loop() + + # Prepare headers + request_headers = {**self.session_headers} + if headers: + request_headers.update(headers) + + # Prepare data + request_data = data + if json_data is not None: + request_data = json.dumps(json_data).encode("utf-8") + request_headers["Content-Type"] = "application/json" + + # Create request object + req = urllib.request.Request(url, data=request_data, headers=request_headers, method=method) + + attempt = 0 + start_time = time.time() + + while True: + attempt += 1 + try: + # Execute in thread pool + response = await loop.run_in_executor( + None, lambda: urllib.request.urlopen(req, timeout=timeout) + ) + response_data = await loop.run_in_executor(None, response.read) + response_text = response_data.decode("utf-8") + + return { + "status": response.status, + "headers": dict(response.headers), + "text": response_text, + "json": lambda: json.loads(response_text) if response_text else None, + } + + except urllib.error.HTTPError as e: + error_body = await loop.run_in_executor(None, e.read) + error_text = error_body.decode("utf-8") + return { + "status": e.code, + "error": True, + "text": error_text, + "json": lambda: json.loads(error_text) if error_text else None, + } + except socket.timeout: + # Handle socket timeouts specifically + elapsed = time.time() - start_time + elapsed_minutes = int(elapsed // 60) + elapsed_seconds = elapsed % 60 + duration_str = ( + f"{elapsed_minutes}m {elapsed_seconds:.1f}s" + if elapsed_minutes > 0 + else f"{elapsed_seconds:.1f}s" + ) + + logger.warning( + f"Request timed out (attempt {attempt}, " + f"duration: {duration_str}). Retrying in {attempt} second(s)..." + ) + + # Exponential backoff starting at 1 second + await asyncio.sleep(attempt) + except Exception as e: + error_msg = str(e) + # For other exceptions, check if they might be timeout-related + if "timed out" in error_msg.lower() or "timeout" in error_msg.lower(): + elapsed = time.time() - start_time + elapsed_minutes = int(elapsed // 60) + elapsed_seconds = elapsed % 60 + duration_str = ( + f"{elapsed_minutes}m {elapsed_seconds:.1f}s" + if elapsed_minutes > 0 + else f"{elapsed_seconds:.1f}s" + ) + + logger.warning( + f"Request timed out (attempt {attempt}, " + f"duration: {duration_str}). Retrying in {attempt} second(s)..." + ) + + # Exponential backoff starting at 1 second + await asyncio.sleep(attempt) + else: + # Non-timeout errors should not be retried + return {"error": True, "exception": error_msg} + + async def get( + self, url: str, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0 + ) -> Dict[str, Any]: + return await self.request("GET", url, headers=headers, timeout=timeout) + + async def post( + self, + url: str, + headers: Optional[Dict[str, str]] = None, + data: Optional[bytes] = None, + json_data: Optional[Dict[str, Any]] = None, + timeout: float = 30.0, + ) -> Dict[str, Any]: + return await self.request( + "POST", url, headers=headers, data=data, json_data=json_data, timeout=timeout + ) + + def set_default_headers(self, headers: Dict[str, str]): + self.session_headers.update(headers) + + +# Global client instance +http_client = AsyncHTTPClient() diff --git a/pr/server.py b/pr/server.py deleted file mode 100644 index 10bc816..0000000 --- a/pr/server.py +++ /dev/null @@ -1,585 +0,0 @@ -#!/usr/bin/env python3 -import asyncio -import logging -import os -import sys -import subprocess -import urllib.parse -from pathlib import Path -from typing import Dict, List, Tuple, Optional -import pathlib -import ast -from types import SimpleNamespace -import time -from aiohttp import web -import aiohttp - -# --- Optional external dependency used in your original code --- -from pr.ads import AsyncDataSet # noqa: E402 - -import sys -from http import cookies -import urllib.parse -import os -import io -import json -import urllib.request -import pathlib -import pickle -import zlib - -import asyncio -import time -from functools import wraps -from pathlib import Path -import pickle -import zlib - - -class Cache: - def __init__(self, base_dir: Path | str = "."): - self.base_dir = Path(base_dir).resolve() - self.base_dir.mkdir(exist_ok=True, parents=True) - - def is_cacheable(self, obj): - return isinstance(obj, (int, str, bool, float)) - - def generate_key(self, *args, **kwargs): - - return zlib.crc32( - json.dumps( - { - "args": [arg for arg in args if self.is_cacheable(arg)], - "kwargs": {k: v for k, v in kwargs.items() if self.is_cacheable(v)}, - }, - sort_keys=True, - default=str, - ).encode() - ) - - def set(self, key, value): - key = self.generate_key(key) - data = {"value": value, "timestamp": time.time()} - serialized_data = pickle.dumps(data) - with open(self.base_dir.joinpath(f"{key}.cache"), "wb") as f: - f.write(serialized_data) - - def get(self, key, default=None): - key = self.generate_key(key) - try: - with open(self.base_dir.joinpath(f"{key}.cache"), "rb") as f: - data = pickle.loads(f.read()) - return data - except FileNotFoundError: - return default - - def is_cacheable(self, obj): - return isinstance(obj, (int, str, bool, float)) - - def cached(self, expiration=60): - def decorator(func): - @wraps(func) - async def wrapper(*args, **kwargs): - # if not all(self.is_cacheable(arg) for arg in args) or not all(self.is_cacheable(v) for v in kwargs.values()): - # return await func(*args, **kwargs) - key = self.generate_key(*args, **kwargs) - print("Cache hit:", key) - cached_data = self.get(key) - if cached_data is not None: - if expiration is None or (time.time() - cached_data["timestamp"]) < expiration: - return cached_data["value"] - result = await func(*args, **kwargs) - self.set(key, result) - return result - - return wrapper - - return decorator - - -cache = Cache() - - -class CGI: - - def __init__(self): - - self.environ = os.environ - if not self.gateway_interface == "CGI/1.1": - return - self.status = 200 - self.headers = {} - self.cache = Cache(pathlib.Path(__file__).parent.joinpath("cache/cgi")) - self.headers["Content-Length"] = "0" - self.headers["Cache-Control"] = "no-cache" - self.headers["Access-Control-Allow-Origin"] = "*" - self.headers["Content-Type"] = "application/json; charset=utf-8" - self.cookies = cookies.SimpleCookie() - - self.query = urllib.parse.parse_qs(os.environ["QUERY_STRING"]) - self.file = io.BytesIO() - - def validate(self, val, fn, error): - if fn(val): - return True - self.status = 421 - self.write({"type": "error", "message": error}) - exit() - - def __getitem__(self, key): - return self.get(key) - - def get(self, key, default=None): - result = self.query.get(key, [default]) - if len(result) == 1: - return result[0] - return result - - def get_bool(self, key): - return str(self.get(key)).lower() in ["true", "yes", "1"] - - @property - def cache_key(self): - return self.environ["CACHE_KEY"] - - @property - def env(self): - env = os.environ.copy() - return env - - @property - def gateway_interface(self): - return self.env.get("GATEWAY_INTERFACE", "") - - @property - def request_method(self): - return self.env["REQUEST_METHOD"] - - @property - def query_string(self): - return self.env["QUERY_STRING"] - - @property - def script_name(self): - return self.env["SCRIPT_NAME"] - - @property - def path_info(self): - return self.env["PATH_INFO"] - - @property - def server_name(self): - return self.env["SERVER_NAME"] - - @property - def server_port(self): - return self.env["SERVER_PORT"] - - @property - def server_protocol(self): - return self.env["SERVER_PROTOCOL"] - - @property - def remote_addr(self): - return self.env["REMOTE_ADDR"] - - def validate_get(self, key): - self.validate(self.get(key), lambda x: bool(x), f"Missing {key}") - - def print(self, data): - self.write(data) - - def write(self, data): - if not isinstance(data, bytes) and not isinstance(data, str): - data = json.dumps(data, default=str, indent=4) - try: - data = data.encode() - except: - pass - self.file.write(data) - self.headers["Content-Length"] = str(len(self.file.getvalue())) - - @property - def http_status(self): - return f"HTTP/1.1 {self.status}\r\n".encode("utf-8") - - @property - def http_headers(self): - headers = io.BytesIO() - for header in self.headers: - headers.write(f"{header}: {self.headers[header]}\r\n".encode("utf-8")) - headers.write(b"\r\n") - return headers.getvalue() - - @property - def http_body(self): - return self.file.getvalue() - - @property - def http_response(self): - return self.http_status + self.http_headers + self.http_body - - def flush(self, response=None): - - if response: - try: - response = response.encode() - except: - pass - sys.stdout.buffer.write(response) - sys.stdout.buffer.flush() - return - sys.stdout.buffer.write(self.http_response) - sys.stdout.buffer.flush() - - def __del__(self): - if self.http_body: - self.flush() - exit() - - -# ------------------------------- -# Utilities -# ------------------------------- -def get_function_source(name: str, directory: str = ".") -> List[Tuple[str, str]]: - matches: List[Tuple[str, str]] = [] - for root, _, files in os.walk(directory): - for file in files: - if not file.endswith(".py"): - continue - path = os.path.join(root, file) - try: - with open(path, "r", encoding="utf-8") as fh: - source = fh.read() - tree = ast.parse(source, filename=path) - for node in ast.walk(tree): - if isinstance(node, ast.AsyncFunctionDef) and node.name == name: - func_src = ast.get_source_segment(source, node) - if func_src: - matches.append((path, func_src)) - break - except (SyntaxError, UnicodeDecodeError): - continue - return matches - - -# ------------------------------- -# Server -# ------------------------------- -class Static(SimpleNamespace): - pass - - -class View(web.View): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.db = self.request.app["db"] - self.server = self.request.app["server"] - - @property - def client(self): - return aiohttp.ClientSession() - - -class RetoorServer: - def __init__(self, base_dir: Path | str = ".", port: int = 8118) -> None: - self.base_dir = Path(base_dir).resolve() - self.port = port - - self.static = Static() - self.db = AsyncDataSet(".default.db") - - self._logger = logging.getLogger("retoor.server") - self._logger.setLevel(logging.INFO) - if not self._logger.handlers: - h = logging.StreamHandler(sys.stdout) - fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s") - h.setFormatter(fmt) - self._logger.addHandler(h) - - self._func_cache: Dict[str, Tuple[str, str]] = {} - self._compiled_cache: Dict[str, object] = {} - self._cgi_cache: Dict[str, Tuple[Optional[Path], Optional[str], Optional[str]]] = {} - self._static_path_cache: Dict[str, Path] = {} - self._base_dir_str = str(self.base_dir) - - self.app = web.Application() - - app = self.app - - app["db"] = self.db - for path in pathlib.Path(__file__).parent.joinpath("web").joinpath("views").iterdir(): - if path.suffix == ".py": - exec(open(path).read(), globals(), locals()) - - self.app["server"] = self - - # from rpanel import create_app as create_rpanel_app - # self.app.add_subapp("/api/{tail:.*}", create_rpanel_app()) - - self.app.router.add_route("*", "/{tail:.*}", self.handle_any) - - # --------------------------- - # Simple endpoints - # --------------------------- - async def handle_root(self, request: web.Request) -> web.Response: - return web.Response(text="Static file and cgi server from retoor.") - - async def handle_hello(self, request: web.Request) -> web.Response: - return web.Response(text="Welcome to the custom HTTP server!") - - # --------------------------- - # Dynamic function dispatch - # --------------------------- - def _path_to_funcname(self, path: str) -> str: - # /foo/bar -> foo_bar ; strip leading/trailing slashes safely - return path.strip("/").replace("/", "_") - - async def _maybe_dispatch_dynamic(self, request: web.Request) -> Optional[web.StreamResponse]: - relpath = request.path.strip("/") - relpath = str(pathlib.Path(__file__).joinpath("cgi").joinpath(relpath).resolve()).replace( - "..", "" - ) - if not relpath: - return None - - last_part = relpath.split("/")[-1] - if "." in last_part: - return None - - funcname = self._path_to_funcname(request.path) - - if funcname not in self._compiled_cache: - entry = self._func_cache.get(funcname) - if entry is None: - matches = get_function_source(funcname, directory=str(self.base_dir)) - if not matches: - return None - entry = matches[0] - self._func_cache[funcname] = entry - - filepath, source = entry - code_obj = compile(source, filepath, "exec") - self._compiled_cache[funcname] = (code_obj, filepath) - - code_obj, filepath = self._compiled_cache[funcname] - ctx = { - "static": self.static, - "request": request, - "relpath": relpath, - "os": os, - "sys": sys, - "web": web, - "asyncio": asyncio, - "subprocess": subprocess, - "urllib": urllib.parse, - "app": request.app, - "db": self.db, - } - exec(code_obj, ctx, ctx) - coro = ctx.get(funcname) - if not callable(coro): - return web.Response(status=500, text=f"Dynamic handler '{funcname}' is not callable.") - return await coro(request) - - # --------------------------- - # Static files - # --------------------------- - async def handle_static(self, request: web.Request) -> web.StreamResponse: - path = request.path - - if path in self._static_path_cache: - cached_path = self._static_path_cache[path] - if cached_path is None: - return web.Response(status=404, text="Not found") - try: - return web.FileResponse(path=cached_path) - except Exception as e: - return web.Response(status=500, text=f"Failed to send file: {e!r}") - - relpath = path.replace('..",', "").lstrip("/").rstrip("/") - abspath = (self.base_dir / relpath).resolve() - - if not str(abspath).startswith(self._base_dir_str): - return web.Response(status=403, text="Forbidden") - - if not abspath.exists(): - self._static_path_cache[path] = None - return web.Response(status=404, text="Not found") - - if abspath.is_dir(): - return web.Response(status=403, text="Directory listing forbidden") - - self._static_path_cache[path] = abspath - try: - return web.FileResponse(path=abspath) - except Exception as e: - return web.Response(status=500, text=f"Failed to send file: {e!r}") - - # --------------------------- - # CGI - # --------------------------- - def _find_cgi_script( - self, path_only: str - ) -> Tuple[Optional[Path], Optional[str], Optional[str]]: - path_only = "pr/cgi/" + path_only - if path_only in self._cgi_cache: - return self._cgi_cache[path_only] - - split_path = [p for p in path_only.split("/") if p] - for i in range(len(split_path), -1, -1): - candidate = "/" + "/".join(split_path[:i]) - candidate_fs = (self.base_dir / candidate.lstrip("/")).resolve() - if ( - str(candidate_fs).startswith(self._base_dir_str) - and candidate_fs.parent.is_dir() - and candidate_fs.parent.name == "cgi" - and candidate_fs.suffix == ".bin" - and candidate_fs.is_file() - and os.access(candidate_fs, os.X_OK) - ): - script_name = candidate if candidate.startswith("/") else "/" + candidate - path_info = path_only[len(script_name) :] - result = (candidate_fs, script_name, path_info) - self._cgi_cache[path_only] = result - return result - - result = (None, None, None) - self._cgi_cache[path_only] = result - return result - - async def _run_cgi( - self, request: web.Request, script_path: Path, script_name: str, path_info: str - ) -> web.Response: - start_time = time.time() - method = request.method - query = request.query_string - - env = os.environ.copy() - env["CACHE_KEY"] = json.dumps( - {"method": method, "query": query, "script_name": script_name, "path_info": path_info}, - default=str, - ) - env["GATEWAY_INTERFACE"] = "CGI/1.1" - env["REQUEST_METHOD"] = method - env["QUERY_STRING"] = query - env["SCRIPT_NAME"] = script_name - env["PATH_INFO"] = path_info - - host_parts = request.host.split(":", 1) - env["SERVER_NAME"] = host_parts[0] - env["SERVER_PORT"] = host_parts[1] if len(host_parts) > 1 else "80" - env["SERVER_PROTOCOL"] = "HTTP/1.1" - - peername = request.transport.get_extra_info("peername") - env["REMOTE_ADDR"] = request.headers.get("HOST_X_FORWARDED_FOR", peername[0]) - - for hk, hv in request.headers.items(): - hk_lower = hk.lower() - if hk_lower == "content-type": - env["CONTENT_TYPE"] = hv - elif hk_lower == "content-length": - env["CONTENT_LENGTH"] = hv - else: - env["HTTP_" + hk.upper().replace("-", "_")] = hv - - post_data = None - if method in {"POST", "PUT", "PATCH"}: - post_data = await request.read() - env.setdefault("CONTENT_LENGTH", str(len(post_data))) - - try: - proc = await asyncio.create_subprocess_exec( - str(script_path), - stdin=asyncio.subprocess.PIPE if post_data else None, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env=env, - ) - stdout, stderr = await proc.communicate(input=post_data) - - if proc.returncode != 0: - msg = stderr.decode(errors="ignore") - return web.Response(status=500, text=f"CGI script error:\n{msg}") - - header_end = stdout.find(b"\r\n\r\n") - if header_end != -1: - offset = 4 - else: - header_end = stdout.find(b"\n\n") - offset = 2 - - if header_end != -1: - body = stdout[header_end + offset :] - headers_blob = stdout[:header_end].decode(errors="ignore") - - status_code = 200 - headers: Dict[str, str] = {} - for line in headers_blob.splitlines(): - line = line.strip() - if not line or ":" not in line: - continue - - k, v = line.split(":", 1) - k_stripped = k.strip() - if k_stripped.lower() == "status": - try: - status_code = int(v.strip().split()[0]) - except Exception: - status_code = 200 - else: - headers[k_stripped] = v.strip() - - end_time = time.time() - headers["X-Powered-By"] = "retoor" - headers["X-Date"] = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime()) - headers["X-Time"] = str(end_time) - headers["X-Duration"] = str(end_time - start_time) - if "error" in str(body).lower(): - print(headers) - print(body) - return web.Response(body=body, headers=headers, status=status_code) - else: - return web.Response(body=stdout) - - except Exception as e: - return web.Response(status=500, text=f"Failed to execute CGI script.\n{e!r}") - - # --------------------------- - # Main router - # --------------------------- - async def handle_any(self, request: web.Request) -> web.StreamResponse: - path = request.path - - if path == "/": - return await self.handle_root(request) - if path == "/hello": - return await self.handle_hello(request) - - script_path, script_name, path_info = self._find_cgi_script(path) - if script_path: - return await self._run_cgi(request, script_path, script_name or "", path_info or "") - - dyn = await self._maybe_dispatch_dynamic(request) - if dyn is not None: - return dyn - - return await self.handle_static(request) - - # --------------------------- - # Runner - # --------------------------- - def run(self) -> None: - self._logger.info("Serving at port %d (base_dir=%s)", self.port, self.base_dir) - web.run_app(self.app, port=self.port) - - -def main() -> None: - server = RetoorServer(base_dir=".", port=8118) - server.run() - - -if __name__ == "__main__": - main() -else: - cgi = CGI() diff --git a/pr/tools/__init__.py b/pr/tools/__init__.py index 929822b..f8a92a3 100644 --- a/pr/tools/__init__.py +++ b/pr/tools/__init__.py @@ -43,11 +43,6 @@ from pr.tools.memory import ( from pr.tools.patch import apply_patch, create_diff from pr.tools.python_exec import python_exec from pr.tools.web import http_fetch, web_search, web_search_news -from pr.tools.context_modifier import ( - modify_context_add, - modify_context_replace, - modify_context_delete, -) __all__ = [ "add_knowledge_entry", diff --git a/pr/tools/agents.py b/pr/tools/agents.py index 865ac44..8c47793 100644 --- a/pr/tools/agents.py +++ b/pr/tools/agents.py @@ -1,3 +1,4 @@ +import asyncio import os from typing import Any, Dict, List @@ -16,15 +17,39 @@ def _create_api_wrapper(): tools_definition = get_tools_definition() if use_tools else [] def api_wrapper(messages, temperature=None, max_tokens=None, **kwargs): - return call_api( - messages=messages, - model=model, - api_url=api_url, - api_key=api_key, - use_tools=use_tools, - tools_definition=tools_definition, - verbose=False, - ) + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # If we're in an event loop, use run_coroutine_threadsafe + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = asyncio.run_coroutine_threadsafe( + call_api( + messages=messages, + model=model, + api_url=api_url, + api_key=api_key, + use_tools=use_tools, + tools_definition=tools_definition, + verbose=False, + ), + loop, + ) + return future.result() + except RuntimeError: + # No event loop running, use asyncio.run + return asyncio.run( + call_api( + messages=messages, + model=model, + api_url=api_url, + api_key=api_key, + use_tools=use_tools, + tools_definition=tools_definition, + verbose=False, + ) + ) return api_wrapper diff --git a/pr/tools/context_modifier.py b/pr/tools/context_modifier.py index 4d62900..2ee61de 100644 --- a/pr/tools/context_modifier.py +++ b/pr/tools/context_modifier.py @@ -1,18 +1,21 @@ import os from typing import Optional -CONTEXT_FILE = '/home/retoor/.local/share/rp/.rcontext.txt' +CONTEXT_FILE = "/home/retoor/.local/share/rp/.rcontext.txt" + def _read_context() -> str: if not os.path.exists(CONTEXT_FILE): raise FileNotFoundError(f"Context file {CONTEXT_FILE} not found.") - with open(CONTEXT_FILE, 'r') as f: + with open(CONTEXT_FILE, "r") as f: return f.read() + def _write_context(content: str): - with open(CONTEXT_FILE, 'w') as f: + with open(CONTEXT_FILE, "w") as f: f.write(content) + def modify_context_add(new_content: str, position: Optional[str] = None) -> str: """ Add new content to the .rcontext.txt file. @@ -25,13 +28,14 @@ def modify_context_add(new_content: str, position: Optional[str] = None) -> str: if position and position in current: # Insert before the position parts = current.split(position, 1) - updated = parts[0] + new_content + '\n\n' + position + parts[1] + updated = parts[0] + new_content + "\n\n" + position + parts[1] else: # Append at the end - updated = current + '\n\n' + new_content + updated = current + "\n\n" + new_content _write_context(updated) return f"Added: {new_content[:100]}... (full addition applied). Consequences: Enhances functionality as requested." + def modify_context_replace(old_content: str, new_content: str) -> str: """ Replace old content with new content in .rcontext.txt. @@ -47,6 +51,7 @@ def modify_context_replace(old_content: str, new_content: str) -> str: _write_context(updated) return f"Replaced: '{old_content[:50]}...' with '{new_content[:50]}...'. Consequences: Changes behavior as specified; verify for unintended effects." + def modify_context_delete(content_to_delete: str, confirmed: bool = False) -> str: """ Delete content from .rcontext.txt, but only if confirmed. @@ -56,10 +61,12 @@ def modify_context_delete(content_to_delete: str, confirmed: bool = False) -> st confirmed: Must be True to proceed with deletion. """ if not confirmed: - raise PermissionError(f"Deletion not confirmed. To delete '{content_to_delete[:50]}...', you must explicitly confirm. Are you sure? This may affect system behavior permanently.") + raise PermissionError( + f"Deletion not confirmed. To delete '{content_to_delete[:50]}...', you must explicitly confirm. Are you sure? This may affect system behavior permanently." + ) current = _read_context() if content_to_delete not in current: raise ValueError(f"Content to delete not found: {content_to_delete[:50]}...") - updated = current.replace(content_to_delete, '', 1) + updated = current.replace(content_to_delete, "", 1) _write_context(updated) return f"Deleted: '{content_to_delete[:50]}...'. Consequences: Removed specified content; system may lose referenced rules or guidelines." diff --git a/pr/ui/__init__.py b/pr/ui/__init__.py index f2520da..d2341ec 100644 --- a/pr/ui/__init__.py +++ b/pr/ui/__init__.py @@ -1,9 +1,10 @@ -from pr.ui.colors import Colors +from pr.ui.colors import Colors, Spinner from pr.ui.display import display_tool_call, print_autonomous_header from pr.ui.rendering import highlight_code, render_markdown __all__ = [ "Colors", + "Spinner", "highlight_code", "render_markdown", "display_tool_call", diff --git a/pr/ui/colors.py b/pr/ui/colors.py index 2c882c6..46ffd96 100644 --- a/pr/ui/colors.py +++ b/pr/ui/colors.py @@ -1,3 +1,6 @@ +import asyncio + + class Colors: RESET = "\033[0m" BOLD = "\033[1m" @@ -12,3 +15,30 @@ class Colors: BG_BLUE = "\033[44m" BG_GREEN = "\033[42m" BG_RED = "\033[41m" + + +class Spinner: + def __init__(self, message="Processing...", spinner_chars="|/-\\"): + self.message = message + self.spinner_chars = spinner_chars + self.running = False + self.task = None + + async def start(self): + self.running = True + self.task = asyncio.create_task(self._spin()) + + async def stop(self): + self.running = False + if self.task: + await self.task + # Clear the line + print("\r" + " " * (len(self.message) + 2) + "\r", end="", flush=True) + + async def _spin(self): + i = 0 + while self.running: + char = self.spinner_chars[i % len(self.spinner_chars)] + print(f"\r{Colors.CYAN}{char}{Colors.RESET} {self.message}", end="", flush=True) + i += 1 + await asyncio.sleep(0.1) diff --git a/pyproject.toml b/pyproject.toml index 21cd928..f1e0409 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "rp" -version = "1.4.0" +version = "1.5.0" description = "R python edition. The ultimate autonomous AI CLI." readme = "README.md" requires-python = ">=3.12" @@ -14,7 +14,6 @@ authors = [ {name = "retoor", email = "retoor@molodetz.nl"} ] dependencies = [ - "aiohttp>=3.13.2", "pydantic>=2.12.3", "prompt_toolkit>=3.0.0", ] @@ -32,8 +31,6 @@ classifiers = [ dev = [ "pytest>=8.3.0", "pytest-asyncio>=1.2.0", - "pytest-aiohttp>=1.1.0", - "aiohttp>=3.13.2", "pytest-cov>=7.0.0", "black>=25.9.0", "flake8>=7.3.0", diff --git a/rp.py b/rp.py index 0d7b96e..4e91f29 100755 --- a/rp.py +++ b/rp.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 # Trigger build - import sys import os diff --git a/tests/conftest.py b/tests/conftest.py index 9122e80..6227450 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,8 @@ import os import tempfile -from pathlib import Path from unittest.mock import MagicMock import pytest -from pr.ads import AsyncDataSet -from pr.web.app import create_app @pytest.fixture @@ -44,77 +41,3 @@ def sample_context_file(temp_dir): with open(context_path, "w") as f: f.write("Sample context content\n") return context_path - - -@pytest.fixture -async def client(aiohttp_client, monkeypatch): - """Create a test client for the app.""" - with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp: - temp_db_file = tmp.name - - # Monkeypatch the db - monkeypatch.setattr("pr.web.views.base.db", AsyncDataSet(temp_db_file, f"{temp_db_file}.sock")) - - with tempfile.TemporaryDirectory() as tmpdir: - tmp_path = Path(tmpdir) - static_dir = tmp_path / "static" - templates_dir = tmp_path / "templates" - repos_dir = tmp_path / "repos" - - static_dir.mkdir() - templates_dir.mkdir() - repos_dir.mkdir() - - # Monkeypatch the directories - monkeypatch.setattr("pr.web.config.STATIC_DIR", static_dir) - monkeypatch.setattr("pr.web.config.TEMPLATES_DIR", templates_dir) - monkeypatch.setattr("pr.web.config.REPOS_DIR", repos_dir) - monkeypatch.setattr("pr.web.config.REPOS_DIR", repos_dir) - - # Create minimal templates - (templates_dir / "index.html").write_text("Index") - (templates_dir / "login.html").write_text("Login") - (templates_dir / "register.html").write_text("Register") - (templates_dir / "dashboard.html").write_text("Dashboard") - (templates_dir / "repos.html").write_text("Repos") - (templates_dir / "api_keys.html").write_text("API Keys") - (templates_dir / "repo.html").write_text("Repo") - (templates_dir / "file.html").write_text("File") - (templates_dir / "edit_file.html").write_text("Edit File") - (templates_dir / "deploy.html").write_text("Deploy") - - app_instance = await create_app() - client = await aiohttp_client(app_instance) - - yield client - - # Cleanup db - os.unlink(temp_db_file) - sock_file = f"{temp_db_file}.sock" - if os.path.exists(sock_file): - os.unlink(sock_file) - - -@pytest.fixture -async def authenticated_client(client): - """Create a client with an authenticated user.""" - # Register a user - resp = await client.post( - "/register", - data={ - "username": "testuser", - "email": "test@example.com", - "password": "password123", - "confirm_password": "password123", - }, - allow_redirects=False, - ) - assert resp.status == 302 # Redirect to login - - # Login - resp = await client.post( - "/login", data={"username": "testuser", "password": "password123"}, allow_redirects=False - ) - assert resp.status == 302 # Redirect to dashboard - - return client