diff --git a/CHANGELOG.md b/CHANGELOG.md index 63258a2..faf5e8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog + +## Version 1.3.0 - 2025-11-05 + +This release updates how the software finds configuration files and handles communication with agents. These changes improve reliability and allow for more flexible configuration options. + +**Changes:** 32 files, 2964 lines +**Languages:** Other (706 lines), Python (2214 lines), TOML (44 lines) + All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), diff --git a/Makefile b/Makefile index 9f72562..b39e87e 100644 --- a/Makefile +++ b/Makefile @@ -72,7 +72,7 @@ serve: implode: build - rpi rp.py -o rp + cp rp.py rp chmod +x rp if [ -d /home/retoor/bin ]; then cp rp /home/retoor/bin/rp; fi diff --git a/pr/ads.py b/pr/ads.py new file mode 100644 index 0000000..a7941c2 --- /dev/null +++ b/pr/ads.py @@ -0,0 +1,969 @@ +import re +import json +from uuid import uuid4 +from datetime import datetime, timezone +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + AsyncGenerator, + Union, + Tuple, + Set, +) +import aiosqlite +import asyncio +import aiohttp +from aiohttp import web +import socket +import os +import atexit + + +class AsyncDataSet: + """ + Distributed AsyncDataSet with client-server model over Unix sockets. + + Parameters: + - file: Path to the SQLite database file + - socket_path: Path to Unix socket for inter-process communication (default: "ads.sock") + - max_concurrent_queries: Maximum concurrent database operations (default: 1) + Set to 1 to prevent "database is locked" errors with SQLite + """ + + _KV_TABLE = "__kv_store" + _DEFAULT_COLUMNS = { + "uid": "TEXT PRIMARY KEY", + "created_at": "TEXT", + "updated_at": "TEXT", + "deleted_at": "TEXT", + } + + def __init__(self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1): + self._file = file + self._socket_path = socket_path + self._table_columns_cache: Dict[str, Set[str]] = {} + self._is_server = None # None means not initialized yet + self._server = None + self._runner = None + self._client_session = None + self._initialized = False + self._init_lock = None # Will be created when needed + self._max_concurrent_queries = max_concurrent_queries + self._db_semaphore = None # Will be created when needed + self._db_connection = None # Persistent database connection + + async def _ensure_initialized(self): + """Ensure the instance is initialized as server or client.""" + if self._initialized: + return + + # Create lock if needed (first time in async context) + if self._init_lock is None: + self._init_lock = asyncio.Lock() + + async with self._init_lock: + if self._initialized: + return + + await self._initialize() + + # Create semaphore for database operations (server only) + if self._is_server: + self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries) + + self._initialized = True + + async def _initialize(self): + """Initialize as server or client.""" + try: + # Try to create Unix socket + if os.path.exists(self._socket_path): + # Check if socket is active + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.connect(self._socket_path) + sock.close() + # Socket is active, we're a client + self._is_server = False + await self._setup_client() + except (ConnectionRefusedError, FileNotFoundError): + # Socket file exists but not active, remove and become server + os.unlink(self._socket_path) + 
self._is_server = True + await self._setup_server() + else: + # No socket file, become server + self._is_server = True + await self._setup_server() + except Exception: + # Fallback to client mode + self._is_server = False + await self._setup_client() + + # Establish persistent database connection + self._db_connection = await aiosqlite.connect(self._file) + + async def _setup_server(self): + """Setup aiohttp server.""" + app = web.Application() + + # Add routes for all methods + app.router.add_post("/insert", self._handle_insert) + app.router.add_post("/update", self._handle_update) + app.router.add_post("/delete", self._handle_delete) + app.router.add_post("/upsert", self._handle_upsert) + app.router.add_post("/get", self._handle_get) + app.router.add_post("/find", self._handle_find) + app.router.add_post("/count", self._handle_count) + app.router.add_post("/exists", self._handle_exists) + app.router.add_post("/kv_set", self._handle_kv_set) + app.router.add_post("/kv_get", self._handle_kv_get) + app.router.add_post("/execute_raw", self._handle_execute_raw) + app.router.add_post("/query_raw", self._handle_query_raw) + app.router.add_post("/query_one", self._handle_query_one) + app.router.add_post("/create_table", self._handle_create_table) + app.router.add_post("/insert_unique", self._handle_insert_unique) + app.router.add_post("/aggregate", self._handle_aggregate) + + self._runner = web.AppRunner(app) + await self._runner.setup() + self._server = web.UnixSite(self._runner, self._socket_path) + await self._server.start() + + # Register cleanup + def cleanup(): + if os.path.exists(self._socket_path): + os.unlink(self._socket_path) + + atexit.register(cleanup) + + async def _setup_client(self): + """Setup aiohttp client.""" + connector = aiohttp.UnixConnector(path=self._socket_path) + self._client_session = aiohttp.ClientSession(connector=connector) + + async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any: + """Make HTTP request to server.""" + if not self._client_session: + await self._setup_client() + + url = f"http://localhost/{endpoint}" + async with self._client_session.post(url, json=data) as resp: + result = await resp.json() + if result.get("error"): + raise Exception(result["error"]) + return result.get("result") + + # Server handlers + async def _handle_insert(self, request): + data = await request.json() + try: + result = await self._server_insert( + data["table"], data["args"], data.get("return_id", False) + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_update(self, request): + data = await request.json() + try: + result = await self._server_update(data["table"], data["args"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_delete(self, request): + data = await request.json() + try: + result = await self._server_delete(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_upsert(self, request): + data = await request.json() + try: + result = await self._server_upsert(data["table"], data["args"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_get(self, request): + data = await 
request.json() + try: + result = await self._server_get(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_find(self, request): + data = await request.json() + try: + result = await self._server_find( + data["table"], + data.get("where"), + limit=data.get("limit", 0), + offset=data.get("offset", 0), + order_by=data.get("order_by"), + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_count(self, request): + data = await request.json() + try: + result = await self._server_count(data["table"], data.get("where")) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_exists(self, request): + data = await request.json() + try: + result = await self._server_exists(data["table"], data["where"]) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_kv_set(self, request): + data = await request.json() + try: + await self._server_kv_set(data["key"], data["value"], table=data.get("table")) + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_kv_get(self, request): + data = await request.json() + try: + result = await self._server_kv_get( + data["key"], default=data.get("default"), table=data.get("table") + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_execute_raw(self, request): + data = await request.json() + try: + result = await self._server_execute_raw( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response( + {"result": result.rowcount if hasattr(result, "rowcount") else None} + ) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_query_raw(self, request): + data = await request.json() + try: + result = await self._server_query_raw( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_query_one(self, request): + data = await request.json() + try: + result = await self._server_query_one( + data["sql"], + tuple(data.get("params", [])) if data.get("params") else None, + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_create_table(self, request): + data = await request.json() + try: + await self._server_create_table(data["table"], data["schema"], data.get("constraints")) + return web.json_response({"result": None}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_insert_unique(self, request): + data = await request.json() + try: + result = await self._server_insert_unique( + data["table"], data["args"], data["unique_fields"] + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + async def _handle_aggregate(self, request): + data = await 
request.json() + try: + result = await self._server_aggregate( + data["table"], + data["function"], + data.get("column", "*"), + data.get("where"), + ) + return web.json_response({"result": result}) + except Exception as e: + return web.json_response({"error": str(e)}, status=500) + + # Public methods that delegate to server or client + async def insert( + self, table: str, args: Dict[str, Any], return_id: bool = False + ) -> Union[str, int]: + await self._ensure_initialized() + if self._is_server: + return await self._server_insert(table, args, return_id) + else: + return await self._make_request( + "insert", {"table": table, "args": args, "return_id": return_id} + ) + + async def update( + self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None + ) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_update(table, args, where) + else: + return await self._make_request( + "update", {"table": table, "args": args, "where": where} + ) + + async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_delete(table, where) + else: + return await self._make_request("delete", {"table": table, "where": where}) + + async def get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_get(table, where) + else: + return await self._make_request("get", {"table": table, "where": where}) + + async def find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_find( + table, where, limit=limit, offset=offset, order_by=order_by + ) + else: + return await self._make_request( + "find", + { + "table": table, + "where": where, + "limit": limit, + "offset": offset, + "order_by": order_by, + }, + ) + + async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + await self._ensure_initialized() + if self._is_server: + return await self._server_count(table, where) + else: + return await self._make_request("count", {"table": table, "where": where}) + + async def exists(self, table: str, where: Dict[str, Any]) -> bool: + await self._ensure_initialized() + if self._is_server: + return await self._server_exists(table, where) + else: + return await self._make_request("exists", {"table": table, "where": where}) + + async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]: + await self._ensure_initialized() + if self._is_server: + return await self._server_query_raw(sql, params) + else: + return await self._make_request( + "query_raw", {"sql": sql, "params": list(params) if params else None} + ) + + async def create_table( + self, + table: str, + schema: Dict[str, str], + constraints: Optional[List[str]] = None, + ): + await self._ensure_initialized() + if self._is_server: + return await self._server_create_table(table, schema, constraints) + else: + return await self._make_request( + "create_table", + {"table": table, "schema": schema, "constraints": constraints}, + ) + + async def insert_unique( + self, table: str, args: Dict[str, Any], unique_fields: List[str] + ) -> Union[str, None]: + await self._ensure_initialized() + if self._is_server: + return await 
self._server_insert_unique(table, args, unique_fields) + else: + return await self._make_request( + "insert_unique", + {"table": table, "args": args, "unique_fields": unique_fields}, + ) + + async def aggregate( + self, + table: str, + function: str, + column: str = "*", + where: Optional[Dict[str, Any]] = None, + ) -> Any: + await self._ensure_initialized() + if self._is_server: + return await self._server_aggregate(table, function, column, where) + else: + return await self._make_request( + "aggregate", + { + "table": table, + "function": function, + "column": column, + "where": where, + }, + ) + + async def transaction(self): + """Context manager for transactions.""" + await self._ensure_initialized() + if self._is_server: + return TransactionContext(self._db_connection, self._db_semaphore) + else: + raise NotImplementedError("Transactions not supported in client mode") + + # Server implementation methods (original logic) + @staticmethod + def _utc_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + @staticmethod + def _py_to_sqlite_type(value: Any) -> str: + if value is None: + return "TEXT" + if isinstance(value, bool): + return "INTEGER" + if isinstance(value, int): + return "INTEGER" + if isinstance(value, float): + return "REAL" + if isinstance(value, (bytes, bytearray, memoryview)): + return "BLOB" + return "TEXT" + + async def _get_table_columns(self, table: str) -> Set[str]: + """Get actual columns that exist in the table.""" + if table in self._table_columns_cache: + return self._table_columns_cache[table] + + columns = set() + try: + async with self._db_semaphore: + async with self._db_connection.execute(f"PRAGMA table_info({table})") as cursor: + async for row in cursor: + columns.add(row[1]) # Column name is at index 1 + self._table_columns_cache[table] = columns + except: + pass + return columns + + async def _invalidate_column_cache(self, table: str): + """Invalidate column cache for a table.""" + if table in self._table_columns_cache: + del self._table_columns_cache[table] + + async def _ensure_column(self, table: str, name: str, value: Any) -> None: + col_type = self._py_to_sqlite_type(value) + try: + async with self._db_semaphore: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}" + ) + await self._db_connection.commit() + await self._invalidate_column_cache(table) + except aiosqlite.OperationalError as e: + if "duplicate column name" in str(e).lower(): + pass # Column already exists + else: + raise + + async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None: + # Always include default columns + cols = self._DEFAULT_COLUMNS.copy() + + # Add columns from col_sources + for key, val in col_sources.items(): + if key not in cols: + cols[key] = self._py_to_sqlite_type(val) + + columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items()) + async with self._db_semaphore: + await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})") + await self._db_connection.commit() + await self._invalidate_column_cache(table) + + async def _table_exists(self, table: str) -> bool: + """Check if a table exists.""" + async with self._db_semaphore: + async with self._db_connection.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,) + ) as cursor: + return await cursor.fetchone() is not None + + _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)") + _RE_NO_TABLE = re.compile(r"no such 
table: (\w+)") + + @classmethod + def _missing_column_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]: + m = cls._RE_NO_COLUMN.search(str(err)) + return m.group(1) if m else None + + @classmethod + def _missing_table_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]: + m = cls._RE_NO_TABLE.search(str(err)) + return m.group(1) if m else None + + async def _safe_execute( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + max_retries: int = 10, + ) -> aiosqlite.Cursor: + retries = 0 + while retries < max_retries: + try: + async with self._db_semaphore: + cursor = await self._db_connection.execute(sql, params) + await self._db_connection.commit() + return cursor + except aiosqlite.OperationalError as err: + retries += 1 + err_str = str(err).lower() + + # Handle missing column + col = self._missing_column_from_error(err) + if col: + if col in col_sources: + await self._ensure_column(table, col, col_sources[col]) + else: + # Column not in sources, ensure it with NULL/TEXT type + await self._ensure_column(table, col, None) + continue + + # Handle missing table + tbl = self._missing_table_from_error(err) + if tbl: + await self._ensure_table(tbl, col_sources) + continue + + # Handle other column-related errors + if "has no column named" in err_str: + # Extract column name differently + match = re.search(r"table \w+ has no column named (\w+)", err_str) + if match: + col_name = match.group(1) + if col_name in col_sources: + await self._ensure_column(table, col_name, col_sources[col_name]) + else: + await self._ensure_column(table, col_name, None) + continue + + raise + raise Exception(f"Max retries ({max_retries}) exceeded") + + async def _safe_query( + self, + table: str, + sql: str, + params: Iterable[Any], + col_sources: Dict[str, Any], + ) -> AsyncGenerator[Dict[str, Any], None]: + # Check if table exists first + if not await self._table_exists(table): + return + + max_retries = 10 + retries = 0 + + while retries < max_retries: + try: + async with self._db_semaphore: + self._db_connection.row_factory = aiosqlite.Row + async with self._db_connection.execute(sql, params) as cursor: + # Fetch all rows while holding the semaphore + rows = await cursor.fetchall() + + # Yield rows after releasing the semaphore + for row in rows: + yield dict(row) + return + except aiosqlite.OperationalError as err: + retries += 1 + err_str = str(err).lower() + + # Handle missing table + tbl = self._missing_table_from_error(err) + if tbl: + # For queries, if table doesn't exist, just return empty + return + + # Handle missing column in WHERE clause or SELECT + if "no such column" in err_str: + # For queries with missing columns, return empty + return + + raise + + @staticmethod + def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: + if not where: + return "", [] + + clauses = [] + vals = [] + + op_map = {"$eq": "=", "$ne": "!=", "$gt": ">", "$gte": ">=", "$lt": "<", "$lte": "<="} + + for k, v in where.items(): + if isinstance(v, dict): + # Handle operators like {"$lt": 10} + op_key = next(iter(v)) + if op_key in op_map: + operator = op_map[op_key] + value = v[op_key] + clauses.append(f"`{k}` {operator} ?") + vals.append(value) + else: + # Fallback for unexpected dict values + clauses.append(f"`{k}` = ?") + vals.append(json.dumps(v)) + else: + # Default to equality + clauses.append(f"`{k}` = ?") + vals.append(v) + + return " WHERE " + " AND ".join(clauses), vals + + async def _server_insert( + self, table: str, args: 
Dict[str, Any], return_id: bool = False + ) -> Union[str, int]: + """Insert a record. If return_id=True, returns auto-incremented ID instead of UUID.""" + uid = str(uuid4()) + now = self._utc_iso() + record = { + "uid": uid, + "created_at": now, + "updated_at": now, + "deleted_at": None, + **args, + } + + # Ensure table exists with all needed columns + await self._ensure_table(table, record) + + # Handle auto-increment ID if requested + if return_id and "id" not in args: + # Ensure id column exists + async with self._db_semaphore: + # Add id column if it doesn't exist + try: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT" + ) + await self._db_connection.commit() + except aiosqlite.OperationalError as e: + if "duplicate column name" not in str(e).lower(): + # Try without autoincrement constraint + try: + await self._db_connection.execute( + f"ALTER TABLE {table} ADD COLUMN id INTEGER" + ) + await self._db_connection.commit() + except: + pass + + await self._invalidate_column_cache(table) + + # Insert and get lastrowid + cols = "`" + "`, `".join(record.keys()) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + cursor = await self._safe_execute(table, sql, list(record.values()), record) + return cursor.lastrowid + + cols = "`" + "`, `".join(record) + "`" + qs = ", ".join(["?"] * len(record)) + sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})" + await self._safe_execute(table, sql, list(record.values()), record) + return uid + + async def _server_update( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> int: + if not args: + return 0 + + # Check if table exists + if not await self._table_exists(table): + return 0 + + args["updated_at"] = self._utc_iso() + + # Ensure all columns exist + all_cols = {**args, **(where or {})} + await self._ensure_table(table, all_cols) + for col, val in all_cols.items(): + await self._ensure_column(table, col, val) + + set_clause = ", ".join(f"`{k}` = ?" for k in args) + where_clause, where_params = self._build_where(where) + sql = f"UPDATE {table} SET {set_clause}{where_clause}" + params = list(args.values()) + where_params + cur = await self._safe_execute(table, sql, params, all_cols) + return cur.rowcount + + async def _server_delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + # Check if table exists + if not await self._table_exists(table): + return 0 + + where_clause, where_params = self._build_where(where) + sql = f"DELETE FROM {table}{where_clause}" + cur = await self._safe_execute(table, sql, where_params, where or {}) + return cur.rowcount + + async def _server_upsert( + self, + table: str, + args: Dict[str, Any], + where: Optional[Dict[str, Any]] = None, + ) -> str | None: + if not args: + raise ValueError("Nothing to update. 
Empty dict given.") + args["updated_at"] = self._utc_iso() + affected = await self._server_update(table, args, where) + if affected: + rec = await self._server_get(table, where) + return rec.get("uid") if rec else None + merged = {**(where or {}), **args} + return await self._server_insert(table, merged) + + async def _server_get( + self, table: str, where: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: + where_clause, where_params = self._build_where(where) + sql = f"SELECT * FROM {table}{where_clause} LIMIT 1" + async for row in self._safe_query(table, sql, where_params, where or {}): + return row + return None + + async def _server_find( + self, + table: str, + where: Optional[Dict[str, Any]] = None, + *, + limit: int = 0, + offset: int = 0, + order_by: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Find records with optional ordering.""" + where_clause, where_params = self._build_where(where) + order_clause = f" ORDER BY {order_by}" if order_by else "" + extra = (f" LIMIT {limit}" if limit else "") + (f" OFFSET {offset}" if offset else "") + sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}" + return [row async for row in self._safe_query(table, sql, where_params, where or {})] + + async def _server_count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int: + # Check if table exists + if not await self._table_exists(table): + return 0 + + where_clause, where_params = self._build_where(where) + sql = f"SELECT COUNT(*) FROM {table}{where_clause}" + gen = self._safe_query(table, sql, where_params, where or {}) + async for row in gen: + return next(iter(row.values()), 0) + return 0 + + async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool: + return (await self._server_count(table, where)) > 0 + + async def _server_kv_set( + self, + key: str, + value: Any, + *, + table: str | None = None, + ) -> None: + tbl = table or self._KV_TABLE + json_val = json.dumps(value, default=str) + await self._server_upsert(tbl, {"value": json_val}, {"key": key}) + + async def _server_kv_get( + self, + key: str, + *, + default: Any = None, + table: str | None = None, + ) -> Any: + tbl = table or self._KV_TABLE + row = await self._server_get(tbl, {"key": key}) + if not row: + return default + try: + return json.loads(row["value"]) + except Exception: + return default + + async def _server_execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any: + """Execute raw SQL for complex queries like JOINs.""" + async with self._db_semaphore: + cursor = await self._db_connection.execute(sql, params or ()) + await self._db_connection.commit() + return cursor + + async def _server_query_raw( + self, sql: str, params: Optional[Tuple] = None + ) -> List[Dict[str, Any]]: + """Execute raw SQL query and return results as list of dicts.""" + try: + async with self._db_semaphore: + self._db_connection.row_factory = aiosqlite.Row + async with self._db_connection.execute(sql, params or ()) as cursor: + return [dict(row) async for row in cursor] + except aiosqlite.OperationalError: + # Return empty list if query fails + return [] + + async def _server_query_one( + self, sql: str, params: Optional[Tuple] = None + ) -> Optional[Dict[str, Any]]: + """Execute raw SQL query and return single result.""" + results = await self._server_query_raw(sql + " LIMIT 1", params) + return results[0] if results else None + + async def _server_create_table( + self, + table: str, + schema: Dict[str, str], + constraints: Optional[List[str]] = None, + ): + """Create table 
with custom schema and constraints. Always includes default columns.""" + # Merge default columns with custom schema + full_schema = self._DEFAULT_COLUMNS.copy() + full_schema.update(schema) + + columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()] + if constraints: + columns.extend(constraints) + columns_sql = ", ".join(columns) + + async with self._db_semaphore: + await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})") + await self._db_connection.commit() + await self._invalidate_column_cache(table) + + async def _server_insert_unique( + self, table: str, args: Dict[str, Any], unique_fields: List[str] + ) -> Union[str, None]: + """Insert with unique constraint handling. Returns uid on success, None if duplicate.""" + try: + return await self._server_insert(table, args) + except aiosqlite.IntegrityError as e: + if "UNIQUE" in str(e): + return None + raise + + async def _server_aggregate( + self, + table: str, + function: str, + column: str = "*", + where: Optional[Dict[str, Any]] = None, + ) -> Any: + """Perform aggregate functions like SUM, AVG, MAX, MIN.""" + # Check if table exists + if not await self._table_exists(table): + return None + + where_clause, where_params = self._build_where(where) + sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}" + result = await self._server_query_one(sql, tuple(where_params)) + return result["result"] if result else None + + async def close(self): + """Close the connection or server.""" + if not self._initialized: + return + + if self._client_session: + await self._client_session.close() + if self._runner: + await self._runner.cleanup() + if os.path.exists(self._socket_path) and self._is_server: + os.unlink(self._socket_path) + if self._db_connection: + await self._db_connection.close() + + +class TransactionContext: + """Context manager for database transactions.""" + + def __init__(self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None): + self.db_connection = db_connection + self.semaphore = semaphore + + async def __aenter__(self): + if self.semaphore: + await self.semaphore.acquire() + try: + await self.db_connection.execute("BEGIN") + return self.db_connection + except: + if self.semaphore: + self.semaphore.release() + raise + + async def __aexit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type is None: + await self.db_connection.commit() + else: + await self.db_connection.rollback() + finally: + if self.semaphore: + self.semaphore.release() + + +# Test cases remain the same but with additional tests for new functionality diff --git a/pr/agents/agent_communication.py b/pr/agents/agent_communication.py index daa18d7..27569fa 100644 --- a/pr/agents/agent_communication.py +++ b/pr/agents/agent_communication.py @@ -99,7 +99,7 @@ class AgentCommunicationBus: self.conn.commit() - def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]: + def receive_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]: cursor = self.conn.cursor() if unread_only: cursor.execute( @@ -153,9 +153,6 @@ class AgentCommunicationBus: def close(self): self.conn.close() - def receive_messages(self, agent_id: str) -> List[AgentMessage]: - return self.get_messages(agent_id, unread_only=True) - def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]: cursor = self.conn.cursor() cursor.execute( diff --git a/pr/agents/agent_manager.py b/pr/agents/agent_manager.py index 441ece3..65c9a6d 100644 --- 
a/pr/agents/agent_manager.py +++ b/pr/agents/agent_manager.py @@ -125,7 +125,7 @@ class AgentManager: return message.message_id def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]: - return self.communication_bus.get_messages(agent_id, unread_only) + return self.communication_bus.receive_messages(agent_id, unread_only) def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]): orchestrator = self.get_agent(orchestrator_id) diff --git a/pr/agents/agent_roles.py b/pr/agents/agent_roles.py index 1bdddde..851df2c 100644 --- a/pr/agents/agent_roles.py +++ b/pr/agents/agent_roles.py @@ -229,45 +229,3 @@ def get_agent_role(role_name: str) -> AgentRole: def list_agent_roles() -> Dict[str, AgentRole]: return AGENT_ROLES.copy() - - -def get_recommended_agent(task_description: str) -> str: - task_lower = task_description.lower() - - code_keywords = [ - "code", - "implement", - "function", - "class", - "bug", - "debug", - "refactor", - "optimize", - ] - research_keywords = [ - "search", - "find", - "research", - "information", - "analyze", - "investigate", - ] - data_keywords = ["data", "database", "query", "statistics", "analyze", "process"] - planning_keywords = ["plan", "organize", "workflow", "steps", "coordinate"] - testing_keywords = ["test", "verify", "validate", "check", "quality"] - doc_keywords = ["document", "documentation", "explain", "guide", "manual"] - - if any(keyword in task_lower for keyword in code_keywords): - return "coding" - elif any(keyword in task_lower for keyword in research_keywords): - return "research" - elif any(keyword in task_lower for keyword in data_keywords): - return "data_analysis" - elif any(keyword in task_lower for keyword in planning_keywords): - return "planning" - elif any(keyword in task_lower for keyword in testing_keywords): - return "testing" - elif any(keyword in task_lower for keyword in doc_keywords): - return "documentation" - else: - return "general" diff --git a/pr/cache/tool_cache.py b/pr/cache/tool_cache.py index b22b048..6dd37a8 100644 --- a/pr/cache/tool_cache.py +++ b/pr/cache/tool_cache.py @@ -122,18 +122,6 @@ class ToolCache: conn.commit() conn.close() - def invalidate_tool(self, tool_name: str): - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() - - cursor.execute("DELETE FROM tool_cache WHERE tool_name = ?", (tool_name,)) - deleted_count = cursor.rowcount - - conn.commit() - conn.close() - - return deleted_count - def clear_expired(self): current_time = int(time.time()) diff --git a/pr/commands/handlers.py b/pr/commands/handlers.py index 0fd9d55..f26ce4b 100644 --- a/pr/commands/handlers.py +++ b/pr/commands/handlers.py @@ -1,6 +1,7 @@ import json import time +from pr.commands.multiplexer_commands import MULTIPLEXER_COMMANDS from pr.autonomous import run_autonomous_mode from pr.core.api import list_models from pr.tools import read_file @@ -13,6 +14,12 @@ def handle_command(assistant, command): command_parts = command.strip().split(maxsplit=1) cmd = command_parts[0].lower() + if cmd in MULTIPLEXER_COMMANDS: + # Pass the assistant object and the remaining arguments to the multiplexer command + return MULTIPLEXER_COMMANDS[cmd]( + assistant, command_parts[1:] if len(command_parts) > 1 else [] + ) + if cmd == "/edit": rp_editor = RPEditor(command_parts[1] if len(command_parts) > 1 else None) rp_editor.start() @@ -23,6 +30,18 @@ def handle_command(assistant, command): if task: run_autonomous_mode(assistant, task) + elif cmd == "/prompt": + 
rp_editor = RPEditor(command_parts[1] if len(command_parts) > 1 else None) + rp_editor.start() + rp_editor.thread.join() + prompt_text = str(rp_editor.get_text()) + rp_editor.stop() + rp_editor = None + if prompt_text.strip(): + from pr.core.assistant import process_message + + process_message(assistant, prompt_text) + elif cmd == "/auto": if len(command_parts) < 2: print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}") diff --git a/pr/commands/help_docs.py b/pr/commands/help_docs.py new file mode 100644 index 0000000..61f3f82 --- /dev/null +++ b/pr/commands/help_docs.py @@ -0,0 +1,507 @@ +from pr.ui import Colors + + +def get_workflow_help(): + return f""" +{Colors.BOLD}WORKFLOWS - AUTOMATED TASK EXECUTION{Colors.RESET} + +{Colors.BOLD}Overview:{Colors.RESET} +Workflows enable automated execution of multi-step tasks by chaining together +tool calls with defined execution logic, error handling, and conditional branching. + +{Colors.BOLD}Execution Modes:{Colors.RESET} + {Colors.CYAN}sequential{Colors.RESET} - Steps execute one after another in order + {Colors.CYAN}parallel{Colors.RESET} - All steps execute simultaneously + {Colors.CYAN}conditional{Colors.RESET} - Steps execute based on conditions and paths + +{Colors.BOLD}Workflow Components:{Colors.RESET} + {Colors.CYAN}Steps{Colors.RESET} - Individual tool executions with parameters + {Colors.CYAN}Variables{Colors.RESET} - Dynamic values available across workflow + {Colors.CYAN}Conditions{Colors.RESET} - Boolean expressions controlling step execution + {Colors.CYAN}Paths{Colors.RESET} - Success/failure routing between steps + {Colors.CYAN}Retries{Colors.RESET} - Automatic retry count for failed steps + {Colors.CYAN}Timeouts{Colors.RESET} - Maximum execution time per step (seconds) + +{Colors.BOLD}Variable Substitution:{Colors.RESET} + Reference previous step results: ${{step.step_id}} + Reference workflow variables: ${{var.variable_name}} + + Example: {{"path": "${{step.get_path}}/output.txt"}} + +{Colors.BOLD}Workflow Commands:{Colors.RESET} + {Colors.CYAN}/workflows{Colors.RESET} - List all available workflows + {Colors.CYAN}/workflow {Colors.RESET} - Execute a specific workflow + +{Colors.BOLD}Workflow Structure:{Colors.RESET} +Each workflow consists of: + - Name and description + - Execution mode (sequential/parallel/conditional) + - List of steps with tool calls + - Optional variables for dynamic values + - Optional tags for categorization + +{Colors.BOLD}Step Configuration:{Colors.RESET} +Each step includes: + - tool_name: Which tool to execute + - arguments: Parameters passed to the tool + - step_id: Unique identifier for the step + - condition: Optional boolean expression (conditional mode) + - on_success: List of step IDs to execute on success + - on_failure: List of step IDs to execute on failure + - retry_count: Number of automatic retries (default: 0) + - timeout_seconds: Maximum execution time (default: 300) + +{Colors.BOLD}Conditional Execution:{Colors.RESET} +Conditions are Python expressions evaluated with access to: + - variables: All workflow variables + - results: All step results by step_id + + Example condition: "results['check_file']['status'] == 'success'" + +{Colors.BOLD}Success/Failure Paths:{Colors.RESET} +Steps can define: + - on_success: List of step IDs to execute if step succeeds + - on_failure: List of step IDs to execute if step fails + +This enables error recovery and branching logic within workflows. 
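To make the step fields above concrete, here is an illustrative sketch of a small conditional workflow built only from the documented fields (step_id, tool_name, arguments, condition, on_success, on_failure, retry_count, timeout_seconds) and the documented ${step.*} / ${var.*} substitution syntax. The exact schema expected by the workflow store is not shown in this diff, so the dict layout below is an assumption, not the shipped format.

```python
# Illustrative only: a conditional workflow using the fields documented above.
example_workflow = {
    "name": "fetch_and_save",
    "description": "Fetch a URL and write it to disk, with a fallback on failure",
    "execution_mode": "conditional",
    "variables": {"url": "https://example.com", "out_dir": "/tmp"},
    "steps": [
        {
            "step_id": "fetch_page",
            "tool_name": "http_fetch",
            "arguments": {"url": "${var.url}"},
            "retry_count": 2,
            "timeout_seconds": 60,
            "on_success": ["save_page"],
            "on_failure": ["log_failure"],
        },
        {
            "step_id": "save_page",
            "tool_name": "write_file",
            "arguments": {"path": "${var.out_dir}/page.html", "content": "${step.fetch_page}"},
            "condition": "results['fetch_page']['status'] == 'success'",
        },
        {
            "step_id": "log_failure",
            "tool_name": "write_file",
            "arguments": {"path": "${var.out_dir}/error.log", "content": "fetch failed"},
        },
    ],
}
```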
+ +{Colors.BOLD}Workflow Storage:{Colors.RESET} +Workflows are stored in the SQLite database and persist across sessions. +Execution history and statistics are tracked for each workflow. + +{Colors.BOLD}Example Workflow Scenarios:{Colors.RESET} + - Sequential: Data pipeline (fetch → process → save → validate) + - Parallel: Multi-source aggregation (fetch A + fetch B + fetch C) + - Conditional: Error recovery (try main → on failure → fallback) + +{Colors.BOLD}Best Practices:{Colors.RESET} + - Use descriptive step_id names (e.g., "fetch_user_data") + - Set appropriate timeouts for long-running operations + - Add retry_count for network or flaky operations + - Use conditional paths for error handling + - Keep workflows focused on a single logical task + - Use variables for values shared across steps +""" + + +def get_agent_help(): + return f""" +{Colors.BOLD}AGENTS - SPECIALIZED AI ASSISTANTS{Colors.RESET} + +{Colors.BOLD}Overview:{Colors.RESET} +Agents are specialized AI instances with specific roles, capabilities, and +system prompts. Each agent maintains its own conversation history and can +collaborate with other agents to complete complex tasks. + +{Colors.BOLD}Available Agent Roles:{Colors.RESET} + {Colors.CYAN}coding{Colors.RESET} - Code writing, review, debugging, refactoring + Temperature: 0.3 (precise) + + {Colors.CYAN}research{Colors.RESET} - Information gathering, analysis, documentation + Temperature: 0.5 (balanced) + + {Colors.CYAN}data_analysis{Colors.RESET} - Data processing, statistical analysis, queries + Temperature: 0.3 (precise) + + {Colors.CYAN}planning{Colors.RESET} - Task decomposition, workflow design, coordination + Temperature: 0.6 (creative) + + {Colors.CYAN}testing{Colors.RESET} - Test design, quality assurance, validation + Temperature: 0.4 (methodical) + + {Colors.CYAN}documentation{Colors.RESET} - Technical writing, API docs, user guides + Temperature: 0.6 (creative) + + {Colors.CYAN}orchestrator{Colors.RESET} - Multi-agent coordination and task delegation + Temperature: 0.5 (balanced) + + {Colors.CYAN}general{Colors.RESET} - General purpose assistance across domains + Temperature: 0.7 (versatile) + +{Colors.BOLD}Agent Commands:{Colors.RESET} + {Colors.CYAN}/agent {Colors.RESET} - Create agent and assign task + {Colors.CYAN}/agents{Colors.RESET} - Show all active agents + {Colors.CYAN}/collaborate {Colors.RESET} - Multi-agent collaboration + +{Colors.BOLD}Single Agent Usage:{Colors.RESET} +Create a specialized agent and assign it a task: + + /agent coding Implement a binary search function in Python + /agent research Find latest best practices for API security + /agent testing Create test cases for user authentication + +Each agent: + - Maintains its own conversation history + - Has access to role-specific tools + - Uses specialized system prompts + - Tracks task completion count + - Integrates with knowledge base + +{Colors.BOLD}Multi-Agent Collaboration:{Colors.RESET} +Use multiple agents to work together on complex tasks: + + /collaborate Build a REST API with tests and documentation + +This automatically: + 1. Creates an orchestrator agent + 2. Spawns coding, research, and planning agents + 3. Orchestrator delegates subtasks to specialists + 4. Agents communicate results back to orchestrator + 5. 
Orchestrator integrates results and coordinates + +{Colors.BOLD}Agent Capabilities by Role:{Colors.RESET} + +{Colors.CYAN}Coding Agent:{Colors.RESET} + - Read, write, and modify files + - Execute Python code + - Run commands and build tools + - Index source directories + - Code review and refactoring + +{Colors.CYAN}Research Agent:{Colors.RESET} + - Web search and HTTP fetching + - Read documentation and files + - Query databases for information + - Analyze and synthesize findings + +{Colors.CYAN}Data Analysis Agent:{Colors.RESET} + - Database queries and operations + - Python execution for analysis + - File read/write for data + - Statistical processing + +{Colors.CYAN}Planning Agent:{Colors.RESET} + - Task decomposition and organization + - Database storage for plans + - File operations for documentation + - Dependency identification + +{Colors.CYAN}Testing Agent:{Colors.RESET} + - Test execution and validation + - Code reading and analysis + - Command execution for test runners + - Quality verification + +{Colors.CYAN}Documentation Agent:{Colors.RESET} + - Technical writing and formatting + - Web research for best practices + - File operations for docs + - Documentation organization + +{Colors.BOLD}Agent Session Management:{Colors.RESET} + - Each session has a unique ID + - Agents persist within a session + - View active agents with /agents + - Agents share knowledge base access + - Communication bus for agent messages + +{Colors.BOLD}Knowledge Base Integration:{Colors.RESET} +When agents execute tasks: + - Relevant knowledge entries are automatically retrieved + - Top 3 matches are injected into agent context + - Knowledge access count is tracked + - Agents can update knowledge base + +{Colors.BOLD}Agent Communication:{Colors.RESET} +Agents can send messages to each other: + - Request type: Ask another agent for help + - Response type: Answer another agent's request + - Notification type: Inform about events + - Coordination type: Synchronize work + +{Colors.BOLD}Best Practices:{Colors.RESET} + - Use specialized agents for focused tasks + - Use collaboration for complex multi-domain tasks + - Coding agent: Precise, deterministic code generation + - Research agent: Thorough information gathering + - Planning agent: Breaking down complex tasks + - Testing agent: Comprehensive test coverage + - Documentation agent: Clear user-facing docs + +{Colors.BOLD}Performance Characteristics:{Colors.RESET} + - Lower temperature: More deterministic (coding, data_analysis) + - Higher temperature: More creative (planning, documentation) + - Max tokens: 4096 for all agents + - Parallel execution: Multiple agents work simultaneously + - Message history: Full conversation context maintained + +{Colors.BOLD}Example Usage Patterns:{Colors.RESET} + +Single specialized agent: + /agent coding Refactor authentication module for better security + +Multi-agent collaboration: + /collaborate Create a data visualization dashboard with tests + +Agent inspection: + /agents # View all active agents and their stats +""" + + +def get_knowledge_help(): + return f""" +{Colors.BOLD}KNOWLEDGE BASE - PERSISTENT INFORMATION STORAGE{Colors.RESET} + +{Colors.BOLD}Overview:{Colors.RESET} +The knowledge base provides persistent storage for information, facts, and +context that can be accessed across sessions and by all agents. 
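The "Search Capabilities" section below describes ranking that combines TF-IDF keyword matching with access frequency and recency. As a rough illustration of that idea only (not the actual knowledge store code; the weights and field names here are assumptions), a scoring function might look like:

```python
import math
import time


def score_entry(query_terms, entry_terms, doc_freq, total_docs,
                access_count, created_at, now=None):
    """Toy relevance score: TF-IDF keyword match boosted by access
    frequency and recency. All weights are illustrative assumptions."""
    now = now or time.time()

    # TF-IDF over the query terms that appear in the entry
    tfidf = 0.0
    for term in query_terms:
        tf = entry_terms.count(term) / max(len(entry_terms), 1)
        idf = math.log((1 + total_docs) / (1 + doc_freq.get(term, 0))) + 1
        tfidf += tf * idf

    frequency_boost = math.log1p(access_count)       # frequently accessed entries rank higher
    age_days = max((now - created_at) / 86400.0, 0.0)
    recency_boost = 1.0 / (1.0 + age_days / 30.0)    # newer entries rank higher

    return tfidf * (1.0 + 0.1 * frequency_boost) * recency_boost
```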
+ +{Colors.BOLD}Commands:{Colors.RESET} + {Colors.CYAN}/remember {Colors.RESET} - Store information in knowledge base + {Colors.CYAN}/knowledge {Colors.RESET} - Search knowledge base + {Colors.CYAN}/stats{Colors.RESET} - Show knowledge base statistics + +{Colors.BOLD}Features:{Colors.RESET} + - Automatic categorization of content + - TF-IDF based semantic search + - Access frequency tracking + - Importance scoring + - Temporal relevance + - Category-based organization + +{Colors.BOLD}Categories:{Colors.RESET} +Content is automatically categorized as: + - code: Code snippets and programming concepts + - api: API documentation and endpoints + - configuration: Settings and configurations + - documentation: General documentation + - fact: Factual information + - procedure: Step-by-step procedures + - general: Uncategorized information + +{Colors.BOLD}Search Capabilities:{Colors.RESET} + - Keyword matching with TF-IDF scoring + - Returns top relevant entries + - Considers access frequency + - Factors in recency + - Searches across all categories + +{Colors.BOLD}Integration:{Colors.RESET} + - Agents automatically query knowledge base + - Top matches injected into agent context + - Conversation history can be stored + - Facts extracted from interactions + - Persistent across sessions +""" + + +def get_cache_help(): + return f""" +{Colors.BOLD}CACHING SYSTEM - PERFORMANCE OPTIMIZATION{Colors.RESET} + +{Colors.BOLD}Overview:{Colors.RESET} +The caching system optimizes performance by storing API responses and tool +results, reducing redundant API calls and expensive operations. + +{Colors.BOLD}Commands:{Colors.RESET} + {Colors.CYAN}/cache{Colors.RESET} - Show cache statistics + {Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches + +{Colors.BOLD}Cache Types:{Colors.RESET} + +{Colors.CYAN}API Cache:{Colors.RESET} + - Stores LLM API responses + - Key: Hash of messages + model + parameters + - Tracks token usage + - TTL: Configurable expiration + - Reduces API costs and latency + +{Colors.CYAN}Tool Cache:{Colors.RESET} + - Stores tool execution results + - Key: Tool name + arguments hash + - Per-tool statistics + - Hit count tracking + - Reduces redundant operations + +{Colors.BOLD}Cache Statistics:{Colors.RESET} + - Total entries (valid + expired) + - Valid entry count + - Expired entry count + - Total cached tokens (API cache) + - Cache hits per tool (Tool cache) + - Per-tool breakdown + +{Colors.BOLD}Benefits:{Colors.RESET} + - Reduced API costs + - Faster response times + - Lower rate limit usage + - Consistent results for identical queries + - Tool execution optimization +""" + + +def get_background_help(): + return f""" +{Colors.BOLD}BACKGROUND SESSIONS - CONCURRENT TASK EXECUTION{Colors.RESET} + +{Colors.BOLD}Overview:{Colors.RESET} +Background sessions allow running long-running commands and processes while +continuing to interact with the assistant. Sessions are monitored and can +be managed through the assistant. 
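Before the command reference below, here is a minimal sketch of how a monitor loop might surface session events such as output_received to the prompt. It is only an illustration of the event-queue pattern, not the BackgroundMonitor implementation from this patch (only the get_pending_events() name is borrowed), and the event payloads shown are assumptions.

```python
import queue
import threading
import time


class MiniMonitor:
    """Illustrative stand-in for a background session monitor."""

    def __init__(self):
        self.events = queue.Queue()
        self._stop = threading.Event()

    def watch(self, session):
        # A real monitor would inspect the session's process and output buffer;
        # this placeholder loop only shows how events reach the queue.
        while not self._stop.is_set():
            if session.get("new_output"):
                self.events.put({"type": "output_received", "session_id": session["id"]})
                session["new_output"] = False
            time.sleep(1)

    def get_pending_events(self):
        # Drain the queue without blocking, mirroring get_pending_events()
        # on BackgroundMonitor in this diff.
        pending = []
        while True:
            try:
                pending.append(self.events.get_nowait())
            except queue.Empty:
                return pending
```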
+ +{Colors.BOLD}Commands:{Colors.RESET} + {Colors.CYAN}/bg start {Colors.RESET} - Start command in background + {Colors.CYAN}/bg list{Colors.RESET} - List all background sessions + {Colors.CYAN}/bg status {Colors.RESET} - Show session status + {Colors.CYAN}/bg output {Colors.RESET} - View session output + {Colors.CYAN}/bg input {Colors.RESET} - Send input to session + {Colors.CYAN}/bg kill {Colors.RESET} - Terminate session + {Colors.CYAN}/bg events{Colors.RESET} - Show recent events + +{Colors.BOLD}Features:{Colors.RESET} + - Automatic session monitoring + - Output capture and buffering + - Interactive input support + - Event notification system + - Session lifecycle management + - Background event detection + +{Colors.BOLD}Event Types:{Colors.RESET} + - session_started: New session created + - session_ended: Session terminated + - output_received: New output available + - possible_input_needed: May require user input + - high_output_volume: Large amount of output + - inactive_session: No activity for period + +{Colors.BOLD}Status Indicators:{Colors.RESET} +The prompt shows background session count: You[2bg]> + - Number indicates active background sessions + - Updates automatically as sessions start/stop + +{Colors.BOLD}Use Cases:{Colors.RESET} + - Long-running build processes + - Web servers and daemons + - File watching and monitoring + - Batch processing tasks + - Test suites execution + - Development servers +""" + + +def get_full_help(): + return f""" +{Colors.BOLD}R - PROFESSIONAL AI ASSISTANT{Colors.RESET} + +{Colors.BOLD}BASIC COMMANDS{Colors.RESET} + exit, quit, q - Exit the assistant + /help - Show this help message + /help workflows - Detailed workflow documentation + /help agents - Detailed agent documentation + /help knowledge - Knowledge base documentation + /help cache - Caching system documentation + /help background - Background sessions documentation + /reset - Clear message history + /dump - Show message history as JSON + /verbose - Toggle verbose mode + /models - List available models + /tools - List available tools + +{Colors.BOLD}FILE OPERATIONS{Colors.RESET} + /review - Review a file + /refactor - Refactor code in a file + /obfuscate - Obfuscate code in a file + +{Colors.BOLD}AUTONOMOUS MODE{Colors.RESET} + {Colors.CYAN}/auto {Colors.RESET} - Enter autonomous mode for task completion + Assistant works continuously until task is complete + Max 50 iterations with automatic context management + Press Ctrl+C twice to force exit + +{Colors.BOLD}WORKFLOWS{Colors.RESET} + {Colors.CYAN}/workflows{Colors.RESET} - List all available workflows + {Colors.CYAN}/workflow {Colors.RESET} - Execute a specific workflow + + Workflows enable automated multi-step task execution with: + - Sequential, parallel, or conditional execution + - Variable substitution and step dependencies + - Error handling and retry logic + - Success/failure path routing + + For detailed documentation: /help workflows + +{Colors.BOLD}AGENTS{Colors.RESET} + {Colors.CYAN}/agent {Colors.RESET} - Create specialized agent + {Colors.CYAN}/agents{Colors.RESET} - Show active agents + {Colors.CYAN}/collaborate {Colors.RESET} - Multi-agent collaboration + + Available roles: coding, research, data_analysis, planning, + testing, documentation, orchestrator, general + + Each agent has specialized capabilities and system prompts. + Agents can work independently or collaborate on complex tasks. 
+ + For detailed documentation: /help agents + +{Colors.BOLD}KNOWLEDGE BASE{Colors.RESET} + {Colors.CYAN}/knowledge {Colors.RESET} - Search knowledge base + {Colors.CYAN}/remember {Colors.RESET} - Store information + + Persistent storage for facts, procedures, and context. + Automatic categorization and TF-IDF search. + Integrated with agents for context injection. + + For detailed documentation: /help knowledge + +{Colors.BOLD}SESSION MANAGEMENT{Colors.RESET} + {Colors.CYAN}/history{Colors.RESET} - Show conversation history + {Colors.CYAN}/cache{Colors.RESET} - Show cache statistics + {Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches + {Colors.CYAN}/stats{Colors.RESET} - Show system statistics + +{Colors.BOLD}BACKGROUND SESSIONS{Colors.RESET} + {Colors.CYAN}/bg start {Colors.RESET} - Start background session + {Colors.CYAN}/bg list{Colors.RESET} - List active sessions + {Colors.CYAN}/bg output {Colors.RESET} - View session output + {Colors.CYAN}/bg kill {Colors.RESET} - Terminate session + + Run long-running processes while maintaining interactivity. + Automatic monitoring and event notifications. + + For detailed documentation: /help background + +{Colors.BOLD}CONTEXT FILES{Colors.RESET} + .rcontext.txt - Local project context (auto-loaded) + ~/.rcontext.txt - Global context (auto-loaded) + -c, --context FILE - Additional context files (command line) + +{Colors.BOLD}ENVIRONMENT VARIABLES{Colors.RESET} + OPENROUTER_API_KEY - API key for OpenRouter + AI_MODEL - Default model to use + API_URL - Custom API endpoint + USE_TOOLS - Enable/disable tools (default: 1) + +{Colors.BOLD}COMMAND-LINE FLAGS{Colors.RESET} + -i, --interactive - Start in interactive mode + -v, --verbose - Enable verbose output + --debug - Enable debug logging + -m, --model MODEL - Specify AI model + --no-syntax - Disable syntax highlighting + +{Colors.BOLD}DATA STORAGE{Colors.RESET} + ~/.assistant_db.sqlite - SQLite database for persistence + ~/.assistant_history - Command history + ~/.assistant_error.log - Error logs + +{Colors.BOLD}AVAILABLE TOOLS{Colors.RESET} +Tools are functions the AI can call to interact with the system: + - File operations: read, write, list, mkdir, search/replace + - Command execution: run_command, interactive sessions + - Web operations: http_fetch, web_search, web_search_news + - Database: db_set, db_get, db_query + - Python execution: python_exec with persistent globals + - Code editing: open_editor, insert/replace text, diff/patch + - Process management: tail, kill, interactive control + - Agents: create, execute tasks, collaborate + - Knowledge: add, search, categorize entries + +{Colors.BOLD}GETTING HELP{Colors.RESET} + /help - This help message + /help workflows - Workflow system details + /help agents - Agent system details + /help knowledge - Knowledge base details + /help cache - Cache system details + /help background - Background sessions details + /tools - List all available tools + /models - List all available models +""" diff --git a/pr/config.py b/pr/config.py index 5006ec0..b4ad98c 100644 --- a/pr/config.py +++ b/pr/config.py @@ -1,8 +1,9 @@ import os DEFAULT_MODEL = "x-ai/grok-code-fast-1" -DEFAULT_API_URL = "http://localhost:8118/ai/chat" -MODEL_LIST_URL = "http://localhost:8118/ai/models" +DEFAULT_API_URL = "https://static.molodetz.nl/rp.cgi/api/v1/chat/completions" +MODEL_LIST_URL = "https://static.molodetz.nl/rp.cgi/api/v1/models" + config_directory = os.path.expanduser("~/.local/share/rp") os.makedirs(config_directory, exist_ok=True) diff --git 
a/pr/core/advanced_context.py b/pr/core/advanced_context.py index cedba7e..8183559 100644 --- a/pr/core/advanced_context.py +++ b/pr/core/advanced_context.py @@ -7,44 +7,45 @@ class AdvancedContextManager: self.knowledge_store = knowledge_store self.conversation_memory = conversation_memory - def adaptive_context_window( - self, messages: List[Dict[str, Any]], task_complexity: str = "medium" - ) -> int: - complexity_thresholds = { - "simple": 10, - "medium": 20, - "complex": 35, - "very_complex": 50, + def adaptive_context_window(self, messages: List[Dict[str, Any]], complexity: str) -> int: + """Calculate adaptive context window size based on message complexity.""" + base_window = 10 + + complexity_multipliers = { + "simple": 1.0, + "medium": 2.0, + "complex": 3.5, + "very_complex": 5.0, } - base_threshold = complexity_thresholds.get(task_complexity, 20) - - message_complexity_score = self._analyze_message_complexity(messages) - - if message_complexity_score > 0.7: - adjusted = int(base_threshold * 1.5) - elif message_complexity_score < 0.3: - adjusted = int(base_threshold * 0.7) - else: - adjusted = base_threshold - - return max(base_threshold, adjusted) + multiplier = complexity_multipliers.get(complexity, 2.0) + return int(base_window * multiplier) def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float: - total_length = sum(len(msg.get("content", "")) for msg in messages) - avg_length = total_length / len(messages) if messages else 0 + """Analyze the complexity of messages and return a score between 0.0 and 1.0.""" + if not messages: + return 0.0 - unique_words = set() - for msg in messages: - content = msg.get("content", "") - words = re.findall(r"\b\w+\b", content.lower()) - unique_words.update(words) + total_complexity = 0.0 + for message in messages: + content = message.get("content", "") + if not content: + continue - vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0 + # Calculate complexity based on various factors + word_count = len(content.split()) + sentence_count = len(re.split(r"[.!?]+", content)) + avg_word_length = sum(len(word) for word in content.split()) / max(word_count, 1) - # Simple complexity score based on length and richness - complexity = min(1.0, (avg_length / 100) + vocabulary_richness) - return complexity + # Complexity score based on length, vocabulary, and structure + length_score = min(1.0, word_count / 100) # Normalize to 0-1 + structure_score = min(1.0, sentence_count / 10) + vocabulary_score = min(1.0, avg_word_length / 8) + + message_complexity = (length_score + structure_score + vocabulary_score) / 3 + total_complexity += message_complexity + + return min(1.0, total_complexity / len(messages)) def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]: if not text.strip(): @@ -82,3 +83,38 @@ class AdvancedContextManager: return 0.0 return len(intersection) / len(union) + + def create_enhanced_context( + self, messages: List[Dict[str, Any]], user_message: str, include_knowledge: bool = True + ) -> tuple: + """Create enhanced context with knowledge base integration.""" + working_messages = messages.copy() + + if include_knowledge and self.knowledge_store: + # Search knowledge base for relevant information + search_results = self.knowledge_store.search_entries(user_message, top_k=3) + + if search_results: + knowledge_parts = [] + for idx, entry in enumerate(search_results, 1): + content = entry.content + if len(content) > 2000: + content = content[:2000] + "..." 
+ + knowledge_parts.append(f"Match {idx} (Category: {entry.category}):\n{content}") + + knowledge_message_content = ( + "[KNOWLEDGE_BASE_CONTEXT]\nRelevant knowledge base entries:\n\n" + + "\n\n".join(knowledge_parts) + ) + + knowledge_message = {"role": "user", "content": knowledge_message_content} + working_messages.append(knowledge_message) + + context_info = f"Added {len(search_results)} knowledge base entries" + else: + context_info = "No relevant knowledge base entries found" + else: + context_info = "Knowledge base integration disabled" + + return working_messages, context_info diff --git a/pr/core/api.py b/pr/core/api.py index f65affd..3e88120 100644 --- a/pr/core/api.py +++ b/pr/core/api.py @@ -67,6 +67,15 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver if "tool_calls" in msg: logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)") + if verbose and "usage" in result: + from pr.core.usage_tracker import UsageTracker + + usage = result["usage"] + input_t = usage.get("prompt_tokens", 0) + output_t = usage.get("completion_tokens", 0) + cost = UsageTracker._calculate_cost(model, input_t, output_t) + print(f"API call cost: €{cost:.4f}") + logger.debug("=== API CALL END ===") return result diff --git a/pr/core/assistant.py b/pr/core/assistant.py index 10e7b1d..4be6c61 100644 --- a/pr/core/assistant.py +++ b/pr/core/assistant.py @@ -29,34 +29,28 @@ from pr.core.background_monitor import ( stop_global_monitor, ) from pr.core.context import init_system_message, truncate_tool_result -from pr.tools import ( - apply_patch, - chdir, - create_diff, - db_get, - db_query, - db_set, - getpwd, - http_fetch, - index_source_directory, - kill_process, - list_directory, - mkdir, - python_exec, - read_file, - run_command, - search_replace, - tail_process, - web_search, - web_search_news, - write_file, - post_image, +from pr.tools import get_tools_definition +from pr.tools.agents import ( + collaborate_agents, + create_agent, + execute_agent_task, + list_agents, + remove_agent, ) -from pr.tools.base import get_tools_definition +from pr.tools.command import kill_process, run_command, tail_process +from pr.tools.database import db_get, db_query, db_set from pr.tools.filesystem import ( + chdir, clear_edit_tracker, display_edit_summary, display_edit_timeline, + getpwd, + index_source_directory, + list_directory, + mkdir, + read_file, + search_replace, + write_file, ) from pr.tools.interactive_control import ( close_interactive_session, @@ -65,7 +59,18 @@ from pr.tools.interactive_control import ( send_input_to_session, start_interactive_session, ) -from pr.tools.patch import display_file_diff +from pr.tools.memory import ( + add_knowledge_entry, + delete_knowledge_entry, + get_knowledge_by_category, + get_knowledge_entry, + get_knowledge_statistics, + search_knowledge, + update_knowledge_importance, +) +from pr.tools.patch import apply_patch, create_diff, display_file_diff +from pr.tools.python_exec import python_exec +from pr.tools.web import http_fetch, web_search, web_search_news from pr.ui import Colors, render_markdown logger = logging.getLogger("pr") @@ -97,7 +102,7 @@ class Assistant: "MODEL_LIST_URL", MODEL_LIST_URL ) self.use_tools = os.environ.get("USE_TOOLS", "1") == "1" - self.strict_mode = os.environ.get("STRICT_MODE", "0") == "1" + self.interrupt_count = 0 self.python_globals = {} self.db_conn = None @@ -245,7 +250,6 @@ class Assistant: logger.debug(f"Tool call: {func_name} with arguments: {arguments}") func_map = { - "post_image": lambda 
**kw: post_image(**kw), "http_fetch": lambda **kw: http_fetch(**kw), "run_command": lambda **kw: run_command(**kw), "tail_process": lambda **kw: tail_process(**kw), diff --git a/pr/core/background_monitor.py b/pr/core/background_monitor.py index b30de63..a965614 100644 --- a/pr/core/background_monitor.py +++ b/pr/core/background_monitor.py @@ -27,15 +27,6 @@ class BackgroundMonitor: if self.monitor_thread: self.monitor_thread.join(timeout=2) - def add_event_callback(self, callback): - """Add a callback function to be called when events are detected.""" - self.event_callbacks.append(callback) - - def remove_event_callback(self, callback): - """Remove an event callback.""" - if callback in self.event_callbacks: - self.event_callbacks.remove(callback) - def get_pending_events(self): """Get all pending events from the queue.""" events = [] @@ -224,30 +215,3 @@ def stop_global_monitor(): global _global_monitor if _global_monitor: _global_monitor.stop() - - -# Global monitor instance -_global_monitor = None - - -def start_global_monitor(): - """Start the global background monitor.""" - global _global_monitor - if _global_monitor is None: - _global_monitor = BackgroundMonitor() - _global_monitor.start() - return _global_monitor - - -def stop_global_monitor(): - """Stop the global background monitor.""" - global _global_monitor - if _global_monitor: - _global_monitor.stop() - _global_monitor = None - - -def get_global_monitor(): - """Get the global background monitor instance.""" - global _global_monitor - return _global_monitor diff --git a/pr/core/config_loader.py b/pr/core/config_loader.py index b47780c..a677a11 100644 --- a/pr/core/config_loader.py +++ b/pr/core/config_loader.py @@ -1,7 +1,6 @@ import configparser import os from typing import Any, Dict -import uuid from pr.core.logging import get_logger logger = get_logger("config") @@ -11,29 +10,6 @@ CONFIG_FILE = os.path.join(CONFIG_DIRECTORY, ".prrc") LOCAL_CONFIG_FILE = ".prrc" -def load_config() -> Dict[str, Any]: - os.makedirs(CONFIG_DIRECTORY, exist_ok=True) - config = { - "api": {}, - "autonomous": {}, - "ui": {}, - "output": {}, - "session": {}, - "api_key": "rp-" + str(uuid.uuid4()), - } - - global_config = _load_config_file(CONFIG_FILE) - local_config = _load_config_file(LOCAL_CONFIG_FILE) - - for section in config.keys(): - if section in global_config: - config[section].update(global_config[section]) - if section in local_config: - config[section].update(local_config[section]) - - return config - - def _load_config_file(filepath: str) -> Dict[str, Dict[str, Any]]: if not os.path.exists(filepath): return {} diff --git a/pr/core/enhanced_assistant.py b/pr/core/enhanced_assistant.py index 2682fc6..e3fd1a4 100644 --- a/pr/core/enhanced_assistant.py +++ b/pr/core/enhanced_assistant.py @@ -12,7 +12,6 @@ from pr.config import ( CONVERSATION_SUMMARY_THRESHOLD, DB_PATH, KNOWLEDGE_SEARCH_LIMIT, - MEMORY_AUTO_SUMMARIZE, TOOL_CACHE_TTL, WORKFLOW_EXECUTOR_MAX_WORKERS, ) @@ -169,24 +168,28 @@ class EnhancedAssistant: self.current_conversation_id, str(uuid.uuid4())[:16], "user", user_message ) - if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0: - facts = self.fact_extractor.extract_facts(user_message) - for fact in facts[:3]: - entry_id = str(uuid.uuid4())[:16] - import time + # Automatically extract and store facts from every user message + facts = self.fact_extractor.extract_facts(user_message) + for fact in facts[:5]: # Store up to 5 facts per message + entry_id = str(uuid.uuid4())[:16] + import time - from pr.memory import 
KnowledgeEntry + from pr.memory import KnowledgeEntry - categories = self.fact_extractor.categorize_content(fact["text"]) - entry = KnowledgeEntry( - entry_id=entry_id, - category=categories[0] if categories else "general", - content=fact["text"], - metadata={"type": fact["type"], "confidence": fact["confidence"]}, - created_at=time.time(), - updated_at=time.time(), - ) - self.knowledge_store.add_entry(entry) + categories = self.fact_extractor.categorize_content(fact["text"]) + entry = KnowledgeEntry( + entry_id=entry_id, + category=categories[0] if categories else "general", + content=fact["text"], + metadata={ + "type": fact["type"], + "confidence": fact["confidence"], + "source": "user_message", + }, + created_at=time.time(), + updated_at=time.time(), + ) + self.knowledge_store.add_entry(entry) if self.context_manager and ADVANCED_CONTEXT_ENABLED: enhanced_messages, context_info = self.context_manager.create_enhanced_context( diff --git a/pr/core/knowledge_context.py b/pr/core/knowledge_context.py new file mode 100644 index 0000000..78977fc --- /dev/null +++ b/pr/core/knowledge_context.py @@ -0,0 +1,148 @@ +import logging + +logger = logging.getLogger("pr") + +KNOWLEDGE_MESSAGE_MARKER = "[KNOWLEDGE_BASE_CONTEXT]" + + +def inject_knowledge_context(assistant, user_message): + if not hasattr(assistant, "enhanced") or not assistant.enhanced: + return + + messages = assistant.messages + + # Remove any existing knowledge context messages + for i in range(len(messages) - 1, -1, -1): + if messages[i].get("role") == "user" and KNOWLEDGE_MESSAGE_MARKER in messages[i].get( + "content", "" + ): + del messages[i] + logger.debug(f"Removed existing knowledge base message at index {i}") + break + + try: + # Search knowledge base with enhanced FTS + semantic search + knowledge_results = assistant.enhanced.knowledge_store.search_entries(user_message, top_k=5) + + # Search conversation history for related content + conversation_results = [] + if hasattr(assistant.enhanced, "conversation_memory"): + history_results = assistant.enhanced.conversation_memory.search_conversations( + user_message, limit=3 + ) + for conv in history_results: + # Extract relevant messages from conversation + conv_messages = assistant.enhanced.conversation_memory.get_conversation_messages( + conv["conversation_id"] + ) + for msg in conv_messages[-5:]: # Last 5 messages from each conversation + if msg["role"] == "user" and msg["content"] != user_message: + # Calculate relevance score + relevance = calculate_text_similarity(user_message, msg["content"]) + if relevance > 0.3: # Only include relevant matches + conversation_results.append( + { + "content": msg["content"], + "score": relevance, + "source": f"Previous conversation: {conv['conversation_id'][:8]}", + } + ) + + # Combine and sort results by relevance score + all_results = [] + + # Add knowledge base results + for entry in knowledge_results: + score = entry.metadata.get("search_score", 0.5) + all_results.append( + { + "content": entry.content, + "score": score, + "source": f"Knowledge Base ({entry.category})", + "type": "knowledge", + } + ) + + # Add conversation results + for conv in conversation_results: + all_results.append( + { + "content": conv["content"], + "score": conv["score"], + "source": conv["source"], + "type": "conversation", + } + ) + + # Sort by score and take top 5 + all_results.sort(key=lambda x: x["score"], reverse=True) + top_results = all_results[:5] + + if not top_results: + logger.debug("No relevant knowledge or conversation matches found") + return + + 
# Format context for LLM + knowledge_parts = [] + for idx, result in enumerate(top_results, 1): + content = result["content"] + if len(content) > 1500: # Shorter limit for multiple results + content = content[:1500] + "..." + + score_indicator = f"({result['score']:.2f})" if result["score"] < 1.0 else "(exact)" + knowledge_parts.append( + f"Match {idx} {score_indicator} - {result['source']}:\n{content}" + ) + + knowledge_message_content = ( + f"{KNOWLEDGE_MESSAGE_MARKER}\nRelevant information from knowledge base and conversation history:\n\n" + + "\n\n".join(knowledge_parts) + ) + + knowledge_message = {"role": "user", "content": knowledge_message_content} + + messages.append(knowledge_message) + logger.debug(f"Injected enhanced context message with {len(top_results)} matches") + + except Exception as e: + logger.error(f"Error injecting knowledge context: {e}") + + +def calculate_text_similarity(text1: str, text2: str) -> float: + """Calculate similarity between two texts using word overlap and sequence matching.""" + import re + + # Normalize texts + text1_lower = text1.lower() + text2_lower = text2.lower() + + # Exact substring match gets highest score + if text1_lower in text2_lower or text2_lower in text1_lower: + return 1.0 + + # Word-level similarity + words1 = set(re.findall(r"\b\w+\b", text1_lower)) + words2 = set(re.findall(r"\b\w+\b", text2_lower)) + + if not words1 or not words2: + return 0.0 + + intersection = words1 & words2 + union = words1 | words2 + + word_similarity = len(intersection) / len(union) + + # Bonus for consecutive word sequences (partial sentences) + consecutive_bonus = 0.0 + words1_list = list(words1) + list(words2) + + for i in range(len(words1_list) - 1): + for j in range(i + 2, min(i + 5, len(words1_list) + 1)): + phrase = " ".join(words1_list[i:j]) + if phrase in text2_lower: + consecutive_bonus += 0.1 * (j - i) + + total_similarity = min(1.0, word_similarity + consecutive_bonus) + + return total_similarity diff --git a/pr/core/usage_tracker.py b/pr/core/usage_tracker.py index 7223263..c43fe44 100644 --- a/pr/core/usage_tracker.py +++ b/pr/core/usage_tracker.py @@ -9,8 +9,10 @@ logger = get_logger("usage") USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json") +EXCHANGE_RATE = 1.0 # Keep in USD + MODEL_COSTS = { - "x-ai/grok-code-fast-1": {"input": 0.0, "output": 0.0}, + "x-ai/grok-code-fast-1": {"input": 0.0002, "output": 0.0015}, # per 1000 tokens in USD "gpt-4": {"input": 0.03, "output": 0.06}, "gpt-4-turbo": {"input": 0.01, "output": 0.03}, "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015}, @@ -64,9 +66,10 @@ class UsageTracker: self._save_to_history(model, input_tokens, output_tokens, cost) - logger.debug(f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}") + logger.debug(f"Tracked request: {model}, tokens: {total_tokens}, cost: €{cost:.4f}") - def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float: + @staticmethod + def _calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float: if model not in MODEL_COSTS: base_model = model.split("/")[0] if "/" in model else model if base_model not in MODEL_COSTS: @@ -76,8 +79,8 @@ class UsageTracker: else: costs = MODEL_COSTS[model] - input_cost = (input_tokens / 1000) * costs["input"] - output_cost = (output_tokens / 1000) * costs["output"] + input_cost = (input_tokens / 1000) * costs["input"] * EXCHANGE_RATE + output_cost = (output_tokens / 1000) * costs["output"] * EXCHANGE_RATE return input_cost + output_cost diff --git a/pr/editor.py 
b/pr/editor.py index dfd0039..59511bc 100644 --- a/pr/editor.py +++ b/pr/editor.py @@ -370,8 +370,10 @@ class RPEditor: self.cursor_x = 0 elif key == ord("u"): self.undo() - elif key == ord("r") and self.prev_key == 18: # Ctrl-R + elif key == 18: # Ctrl-R self.redo() + elif key == 19: # Ctrl-S + self._save_file() self.prev_key = key except Exception: diff --git a/pr/implode.py b/pr/implode.py new file mode 100644 index 0000000..443b881 --- /dev/null +++ b/pr/implode.py @@ -0,0 +1,589 @@ +#!/usr/bin/env python3 + +""" +impLODE: A Python script to consolidate a multi-file Python project +into a single, runnable file. + +It intelligently resolves local imports, hoists external dependencies to the top, +and preserves the core logic, using AST for safe transformations. +""" + +import os +import sys +import ast +import argparse +import logging +import py_compile # Added +from typing import Set, Dict, Optional, TextIO + +# Set up the main logger +logger = logging.getLogger("impLODE") + + +class ImportTransformer(ast.NodeTransformer): + """ + An AST transformer that visits Import and ImportFrom nodes. + + On Pass 1 (Dry Run): + - Identifies all local vs. external imports. + - Recursively calls the main resolver for local modules. + - Stores external and __future__ imports in the Imploder instance. + + On Pass 2 (Write Run): + - Recursively calls the main resolver for local modules. + - Removes all import statements (since they were hoisted in Pass 1). + """ + + def __init__( + self, + imploder: "Imploder", + current_file_path: str, + f_out: Optional[TextIO], + is_dry_run: bool, + indent_level: int = 0, + ): + self.imploder = imploder + self.current_file_path = current_file_path + self.current_dir = os.path.dirname(current_file_path) + self.f_out = f_out + self.is_dry_run = is_dry_run + self.indent = " " * indent_level + self.logger = logging.getLogger(self.__class__.__name__) + + def _log_debug(self, msg: str): + """Helper for indented debug logging.""" + self.logger.debug(f"{self.indent} > {msg}") + + def _find_local_module(self, module_name: str, level: int) -> Optional[str]: + """ + Tries to find the absolute path for a given module name and relative level. + Returns None if it's not a local module *and* cannot be found in site-packages. + """ + import importlib.util # Added + import importlib.machinery # Added + + if not module_name: + # Handle cases like `from . import foo` + # The module_name is 'foo', handled by visit_ImportFrom's aliases. + # If it's just `from . import`, module_name is None. + # We'll resolve from the alias names. 
+ + # Determine base path for relative import + base_path = self.current_dir + if level > 0: + # Go up 'level' directories (minus one for current dir) + for _ in range(level - 1): + base_path = os.path.dirname(base_path) + return base_path + + base_path = self.current_dir + if level > 0: + # Go up 'level' directories (minus one for current dir) + for _ in range(level - 1): + base_path = os.path.dirname(base_path) + else: + # Absolute import, start from the root + base_path = self.imploder.root_dir + + module_parts = module_name.split(".") + module_path = os.path.join(base_path, *module_parts) + + # Check for package `.../module/__init__.py` + package_init = os.path.join(module_path, "__init__.py") + if os.path.isfile(package_init): + self._log_debug( + f"Resolved '{module_name}' to local package: {os.path.relpath(package_init, self.imploder.root_dir)}" + ) + return package_init + + # Check for module `.../module.py` + module_py = module_path + ".py" + if os.path.isfile(module_py): + self._log_debug( + f"Resolved '{module_name}' to local module: {os.path.relpath(module_py, self.imploder.root_dir)}" + ) + return module_py + + # --- FALLBACK: Deep search from root --- + # This is less efficient but helps with non-standard project structures + # where the "root" for imports isn't the same as the Imploder's root. + if level == 0: + self._log_debug( + f"Module '{module_name}' not found at primary path. Starting deep fallback search from {self.imploder.root_dir}..." + ) + + target_path_py = os.path.join(*module_parts) + ".py" + target_path_init = os.path.join(*module_parts, "__init__.py") + + for dirpath, dirnames, filenames in os.walk(self.imploder.root_dir, topdown=True): + # Prune common virtual envs and metadata dirs to speed up search + dirnames[:] = [ + d + for d in dirnames + if not d.startswith(".") + and d not in ("venv", "env", ".venv", ".env", "__pycache__", "node_modules") + ] + + # Check for .../module/file.py + check_file_py = os.path.join(dirpath, target_path_py) + if os.path.isfile(check_file_py): + self._log_debug( + f"Fallback search found module: {os.path.relpath(check_file_py, self.imploder.root_dir)}" + ) + return check_file_py + + # Check for .../module/file/__init__.py + check_file_init = os.path.join(dirpath, target_path_init) + if os.path.isfile(check_file_init): + self._log_debug( + f"Fallback search found package: {os.path.relpath(check_file_init, self.imploder.root_dir)}" + ) + return check_file_init + + # --- Removed site-packages/external module inlining logic --- + # As requested, the script will now only resolve local modules. + + # Not found locally + return None + + def visit_Import(self, node: ast.Import) -> Optional[ast.AST]: + """Handles `import foo` or `import foo.bar`.""" + for alias in node.names: + module_path = self._find_local_module(alias.name, level=0) + + if module_path: + # Local module + self._log_debug(f"Resolving local import: `import {alias.name}`") + self.imploder.resolve_file( + file_abs_path=module_path, + f_out=self.f_out, + is_dry_run=self.is_dry_run, + indent_level=self.imploder.current_indent_level, + ) + else: + # External module + self._log_debug(f"Found external import: `import {alias.name}`") + if self.is_dry_run: + # On dry run, collect it for hoisting + key = f"import {alias.name}" + if key not in self.imploder.external_imports: + self.imploder.external_imports[key] = node + + # On Pass 2 (write run), we remove this node by replacing it + # with a dummy function call to prevent IndentationErrors. 
+ # On Pass 1 (dry run), we also replace it. + module_names = ", ".join([a.name for a in node.names]) + new_call = ast.Call( + func=ast.Name(id="_implode_log_import", ctx=ast.Load()), + args=[ + ast.Constant(value=module_names), + ast.Constant(value=0), + ast.Constant(value="import"), + ], + keywords=[], + ) + return ast.Expr(value=new_call) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> Optional[ast.AST]: + """Handles `from foo import bar` or `from .foo import bar`.""" + + # --- NEW: Create the replacement node first --- + module_name_str = node.module or "" + import_type = "from-import" + if module_name_str == "__future__": + import_type = "future-import" + + new_call = ast.Call( + func=ast.Name(id="_implode_log_import", ctx=ast.Load()), + args=[ + ast.Constant(value=module_name_str), + ast.Constant(value=node.level), + ast.Constant(value=import_type), + ], + keywords=[], + ) + replacement_node = ast.Expr(value=new_call) + + # Handle `from __future__ import ...` + if node.module == "__future__": + self._log_debug("Found __future__ import. Hoisting to top.") + if self.is_dry_run: + key = ast.unparse(node) + self.imploder.future_imports[key] = node + return replacement_node # Return replacement + + module_path = self._find_local_module(node.module or "", node.level) + + if module_path and os.path.isdir(module_path): + # This is a package import like `from . import foo` + # `module_path` is the directory of the package (e.g., self.current_dir) + self._log_debug(f"Resolving package import: `from {node.module or '.'} import ...`") + for alias in node.names: + # --- FIX --- + # Resolve `foo` relative to the package directory (`module_path`) + + # Check for module `.../package_dir/foo.py` + package_module_py = os.path.join(module_path, alias.name + ".py") + + # Check for sub-package `.../package_dir/foo/__init__.py` + package_module_init = os.path.join(module_path, alias.name, "__init__.py") + + if os.path.isfile(package_module_py): + self._log_debug( + f"Found sub-module: {os.path.relpath(package_module_py, self.imploder.root_dir)}" + ) + self.imploder.resolve_file( + file_abs_path=package_module_py, + f_out=self.f_out, + is_dry_run=self.is_dry_run, + indent_level=self.imploder.current_indent_level, + ) + elif os.path.isfile(package_module_init): + self._log_debug( + f"Found sub-package: {os.path.relpath(package_module_init, self.imploder.root_dir)}" + ) + self.imploder.resolve_file( + file_abs_path=package_module_init, + f_out=self.f_out, + is_dry_run=self.is_dry_run, + indent_level=self.imploder.current_indent_level, + ) + else: + # Could be another sub-package, etc. 
+ # This logic can be expanded, but for now, log a warning + self.logger.warning( + f"{self.indent} > Could not resolve sub-module '{alias.name}' in package '{module_path}'" + ) + return replacement_node # Return replacement + # --- END FIX --- + + if module_path: + # Local module (e.g., `from .foo import bar` or `from myapp.foo import bar`) + self._log_debug(f"Resolving local from-import: `from {node.module or '.'} ...`") + self.imploder.resolve_file( + file_abs_path=module_path, + f_out=self.f_out, + is_dry_run=self.is_dry_run, + indent_level=self.imploder.current_indent_level, + ) + else: + # External module + self._log_debug(f"Found external from-import: `from {node.module or '.'} ...`") + if self.is_dry_run: + # On dry run, collect it for hoisting + key = ast.unparse(node) + if key not in self.imploder.external_imports: + self.imploder.external_imports[key] = node + + # On Pass 2 (write run), we remove this node entirely by replacing it + # On Pass 1 (dry run), we also replace it + return replacement_node + + +class Imploder: + """ + Core class for handling the implosion process. + Manages state, file processing, and the two-pass analysis. + """ + + def __init__(self, root_dir: str, enable_import_logging: bool = False): + self.root_dir = os.path.realpath(root_dir) + self.processed_files: Set[str] = set() + self.external_imports: Dict[str, ast.AST] = {} + self.future_imports: Dict[str, ast.AST] = {} + self.current_indent_level = 0 + self.enable_import_logging = enable_import_logging # Added + logger.info(f"Initialized Imploder with root: {self.root_dir}") + + def implode(self, main_file_abs_path: str, output_file_path: str): + """ + Runs the full two-pass implosion process. + """ + if not os.path.isfile(main_file_abs_path): + logger.critical(f"Main file not found: {main_file_abs_path}") + sys.exit(1) + + # --- Pass 1: Analyze Dependencies (Dry Run) --- + logger.info( + f"--- PASS 1: Analyzing dependencies from {os.path.relpath(main_file_abs_path, self.root_dir)} ---" + ) + self.processed_files.clear() + self.external_imports.clear() + self.future_imports.clear() + + try: + self.resolve_file(main_file_abs_path, f_out=None, is_dry_run=True, indent_level=0) + except Exception as e: + logger.critical(f"Error during analysis pass: {e}", exc_info=True) + sys.exit(1) + + logger.info( + f"--- Analysis complete. Found {len(self.future_imports)} __future__ imports and {len(self.external_imports)} external modules. 
---" + ) + + # --- Pass 2: Write Imploded File --- + logger.info(f"--- PASS 2: Writing imploded file to {output_file_path} ---") + self.processed_files.clear() + + try: + with open(output_file_path, "w", encoding="utf-8") as f_out: + # Write headers + f_out.write(f"#!/usr/bin/env python3\n") + + f_out.write(f"# -*- coding: utf-8 -*-\n") + f_out.write(f"import logging\n") + f_out.write(f"\n# --- IMPLODED FILE: Generated by impLODE --- #\n") + f_out.write( + f"# --- Original main file: {os.path.relpath(main_file_abs_path, self.root_dir)} --- #\n" + ) + + # Write __future__ imports + if self.future_imports: + f_out.write("\n# --- Hoisted __future__ Imports --- #\n") + for node in self.future_imports.values(): + f_out.write(f"{ast.unparse(node)}\n") + f_out.write("# --- End __future__ Imports --- #\n") + + # --- NEW: Write Helper Function --- + enable_logging_str = "True" if self.enable_import_logging else "False" + f_out.write("\n# --- impLODE Helper Function --- #\n") + f_out.write(f"_IMPLODE_LOGGING_ENABLED_ = {enable_logging_str}\n") + f_out.write("def _implode_log_import(module_name, level, import_type):\n") + f_out.write( + ' """Dummy function to replace imports and prevent IndentationErrors."""\n' + ) + f_out.write(" if _IMPLODE_LOGGING_ENABLED_:\n") + f_out.write( + " print(f\"[impLODE Logger]: Skipped {import_type}: module='{module_name}', level={level}\")\n" + ) + f_out.write(" pass\n") + f_out.write("# --- End Helper Function --- #\n") + + # Write external imports + if self.external_imports: + f_out.write("\n# --- Hoisted External Imports --- #\n") + for node in self.external_imports.values(): + # As requested, wrap each external import in a try/except block + f_out.write("try:\n") + f_out.write(f" {ast.unparse(node)}\n") + f_out.write("except ImportError:\n") + f_out.write(" pass\n") + f_out.write("# --- End External Imports --- #\n") + + # Write the actual code + self.resolve_file(main_file_abs_path, f_out=f_out, is_dry_run=False, indent_level=0) + + except IOError as e: + logger.critical( + f"Could not write to output file {output_file_path}: {e}", exc_info=True + ) + sys.exit(1) + except Exception as e: + logger.critical(f"Error during write pass: {e}", exc_info=True) + sys.exit(1) + + logger.info(f"--- Implosion complete! Output saved to {output_file_path} ---") + + def resolve_file( + self, file_abs_path: str, f_out: Optional[TextIO], is_dry_run: bool, indent_level: int = 0 + ): + """ + Recursively resolves a single file. + - `is_dry_run=True`: Analyzes imports, populating `external_imports`. + - `is_dry_run=False`: Writes transformed code to `f_out`. + """ + self.current_indent_level = indent_level + indent = " " * indent_level + + try: + file_real_path = os.path.realpath(file_abs_path) + # --- Reverted path logic as we only handle local files --- + rel_path = os.path.relpath(file_real_path, self.root_dir) + # --- End Revert --- + except ValueError: + # Happens if file is on a different drive than root_dir on Windows + logger.warning( + f"{indent}Cannot calculate relative path for {file_abs_path}. Using absolute." + ) + rel_path = file_abs_path + + # --- 1. Circular Dependency Check --- + if file_real_path in self.processed_files: + logger.debug(f"{indent}Skipping already processed file: {rel_path}") + return + + logger.info(f"{indent}Processing: {rel_path}") + self.processed_files.add(file_real_path) + + # --- 2. 
Read File --- + try: + with open(file_real_path, "r", encoding="utf-8") as f: + code = f.read() + except FileNotFoundError: + logger.error(f"{indent}File not found: {file_real_path}") + return + except UnicodeDecodeError: + logger.error(f"{indent}Could not decode file (not utf-8): {file_real_path}") + return + except Exception as e: + logger.error(f"{indent}Could not read file {file_real_path}: {e}") + return + + # --- 2.5. Validate Syntax with py_compile --- + try: + # As requested, validate syntax using py_compile before AST parsing + py_compile.compile(file_real_path, doraise=True, quiet=1) + logger.debug(f"{indent}Syntax OK (py_compile): {rel_path}") + except py_compile.PyCompileError as e: + # This error includes file, line, and message + logger.error( + f"{indent}Syntax error (py_compile) in {e.file} on line {e.lineno}: {e.msg}" + ) + return + except Exception as e: + # Catch other potential errors like permission issues + logger.error(f"{indent}Error during py_compile for {rel_path}: {e}") + return + + # --- 3. Parse AST --- + try: + tree = ast.parse(code, filename=file_real_path) + except SyntaxError as e: + logger.error(f"{indent}Syntax error in {rel_path} on line {e.lineno}: {e.msg}") + return + except Exception as e: + logger.error(f"{indent}Could not parse AST for {rel_path}: {e}") + return + + # --- 4. Transform AST --- + transformer = ImportTransformer( + imploder=self, + current_file_path=file_real_path, + f_out=f_out, + is_dry_run=is_dry_run, + indent_level=indent_level, + ) + + try: + new_tree = transformer.visit(tree) + except Exception as e: + logger.error(f"{indent}Error transforming AST for {rel_path}: {e}", exc_info=True) + return + + # --- 5. Write Content (on Pass 2) --- + if not is_dry_run and f_out: + try: + ast.fix_missing_locations(new_tree) + f_out.write(f"\n\n# --- Content from {rel_path} --- #\n") + f_out.write(ast.unparse(new_tree)) + f_out.write(f"\n# --- End of {rel_path} --- #\n") + logger.info(f"{indent}Successfully wrote content from: {rel_path}") + except Exception as e: + logger.error( + f"{indent}Could not unparse or write AST for {rel_path}: {e}", exc_info=True + ) + + self.current_indent_level = indent_level + + +def setup_logging(level: int): + """Configures the root logger.""" + handler = logging.StreamHandler() + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + + # Configure the 'impLODE' logger + logger.setLevel(level) + logger.addHandler(handler) + + # Configure the 'ImportTransformer' logger + transformer_logger = logging.getLogger("ImportTransformer") + transformer_logger.setLevel(level) + transformer_logger.addHandler(handler) + + +def main(): + """Main entry point for the script.""" + parser = argparse.ArgumentParser( + description="impLODE: Consolidate a multi-file Python project into one file.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "main_file", type=str, help="The main entry point .py file of your project." + ) + parser.add_argument( + "-o", + "--output", + type=str, + default="imploded.py", + help="Path for the combined output file. (default: imploded.py)", + ) + parser.add_argument( + "-r", + "--root", + type=str, + default=".", + help="The root directory of the project for resolving absolute imports. 
(default: current directory)", + ) + + log_group = parser.add_mutually_exclusive_group() + log_group.add_argument( + "-v", + "--verbose", + action="store_const", + dest="log_level", + const=logging.DEBUG, + default=logging.INFO, + help="Enable verbose DEBUG logging.", + ) + log_group.add_argument( + "-q", + "--quiet", + action="store_const", + dest="log_level", + const=logging.WARNING, + help="Suppress INFO logs, showing only WARNINGS and ERRORS.", + ) + + parser.add_argument( + "--enable-import-logging", + action="store_true", + help="Enable runtime logging for removed import statements in the final imploded script.", + ) + + args = parser.parse_args() + + # --- Setup --- + setup_logging(args.log_level) + + root_dir = os.path.abspath(args.root) + main_file_path = os.path.abspath(args.main_file) + output_file_path = os.path.abspath(args.output) + + # --- Validation --- + if not os.path.isdir(root_dir): + logger.critical(f"Root directory not found: {root_dir}") + sys.exit(1) + + if not os.path.isfile(main_file_path): + logger.critical(f"Main file not found: {main_file_path}") + sys.exit(1) + + if not main_file_path.startswith(root_dir): + logger.warning(f"Main file {main_file_path} is outside the specified root {root_dir}.") + logger.warning("This may cause issues with absolute import resolution.") + + if main_file_path == output_file_path: + logger.critical("Output file cannot be the same as the main file.") + sys.exit(1) + + # --- Run --- + imploder = Imploder( + root_dir=root_dir, enable_import_logging=args.enable_import_logging # Added + ) + imploder.implode(main_file_abs_path=main_file_path, output_file_path=output_file_path) + + +if __name__ == "__main__": + main() diff --git a/pr/memory/fact_extractor.py b/pr/memory/fact_extractor.py index ff28d52..7011af7 100644 --- a/pr/memory/fact_extractor.py +++ b/pr/memory/fact_extractor.py @@ -56,10 +56,14 @@ class FactExtractor: else: if len(current_phrase) >= 2: phrases.append(" ".join(current_phrase)) + elif len(current_phrase) == 1: + phrases.append(current_phrase[0]) # Single capitalized words current_phrase = [] if len(current_phrase) >= 2: phrases.append(" ".join(current_phrase)) + elif len(current_phrase) == 1: + phrases.append(current_phrase[0]) # Single capitalized words return list(set(phrases)) @@ -165,13 +169,14 @@ class FactExtractor: return relationships def extract_metadata(self, text: str) -> Dict[str, Any]: - word_count = len(text.split()) - sentence_count = len(re.split(r"[.!?]", text)) + word_count = len(text.split()) if text.strip() else 0 + sentences = re.split(r"[.!?]", text.strip()) + sentence_count = len([s for s in sentences if s.strip()]) if text.strip() else 0 urls = re.findall(r"https?://[^\s]+", text) email_addresses = re.findall(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text) dates = re.findall( - r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text + r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b|\b\d{4}\b", text ) numbers = re.findall(r"\b\d+(?:,\d{3})*(?:\.\d+)?\b", text) diff --git a/pr/memory/knowledge_store.py b/pr/memory/knowledge_store.py index 22191d5..35f574c 100644 --- a/pr/memory/knowledge_store.py +++ b/pr/memory/knowledge_store.py @@ -1,8 +1,9 @@ import json import sqlite3 +import threading import time from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from .semantic_index import SemanticIndex @@ -35,247 +36,334 @@ class KnowledgeStore: def __init__(self, 
db_path: str): self.db_path = db_path self.conn = sqlite3.connect(self.db_path, check_same_thread=False) + self.lock = threading.Lock() self.semantic_index = SemanticIndex() self._initialize_store() self._load_index() def _initialize_store(self): - cursor = self.conn.cursor() + with self.lock: + cursor = self.conn.cursor() - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS knowledge_entries ( - entry_id TEXT PRIMARY KEY, - category TEXT NOT NULL, - content TEXT NOT NULL, - metadata TEXT, - created_at REAL NOT NULL, - updated_at REAL NOT NULL, - access_count INTEGER DEFAULT 0, - importance_score REAL DEFAULT 1.0 - ) - """ - ) - - cursor.execute( - """ - CREATE INDEX IF NOT EXISTS idx_category ON knowledge_entries(category) - """ - ) - cursor.execute( - """ - CREATE INDEX IF NOT EXISTS idx_importance ON knowledge_entries(importance_score DESC) - """ - ) - cursor.execute( - """ - CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC) - """ - ) - - self.conn.commit() - - def _load_index(self): - cursor = self.conn.cursor() - - cursor.execute("SELECT entry_id, content FROM knowledge_entries") - for row in cursor.fetchall(): - self.semantic_index.add_document(row[0], row[1]) - - def add_entry(self, entry: KnowledgeEntry): - cursor = self.conn.cursor() - - cursor.execute( - """ - INSERT OR REPLACE INTO knowledge_entries - (entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - entry.entry_id, - entry.category, - entry.content, - json.dumps(entry.metadata), - entry.created_at, - entry.updated_at, - entry.access_count, - entry.importance_score, - ), - ) - - self.conn.commit() - - self.semantic_index.add_document(entry.entry_id, entry.content) - - def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]: - cursor = self.conn.cursor() - - cursor.execute( - """ - SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score - FROM knowledge_entries - WHERE entry_id = ? - """, - (entry_id,), - ) - - row = cursor.fetchone() - - if row: cursor.execute( """ - UPDATE knowledge_entries - SET access_count = access_count + 1 + CREATE TABLE IF NOT EXISTS knowledge_entries ( + entry_id TEXT PRIMARY KEY, + category TEXT NOT NULL, + content TEXT NOT NULL, + metadata TEXT, + created_at REAL NOT NULL, + updated_at REAL NOT NULL, + access_count INTEGER DEFAULT 0, + importance_score REAL DEFAULT 1.0 + ) + """ + ) + + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_category ON knowledge_entries(category) + """ + ) + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_importance ON knowledge_entries(importance_score DESC) + """ + ) + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC) + """ + ) + + self.conn.commit() + + def _load_index(self): + with self.lock: + cursor = self.conn.cursor() + + cursor.execute("SELECT entry_id, content FROM knowledge_entries") + for row in cursor.fetchall(): + self.semantic_index.add_document(row[0], row[1]) + + def add_entry(self, entry: KnowledgeEntry): + with self.lock: + cursor = self.conn.cursor() + + cursor.execute( + """ + INSERT OR REPLACE INTO knowledge_entries + (entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + entry.entry_id, + entry.category, + entry.content, + json.dumps(entry.metadata), + entry.created_at, + entry.updated_at, + entry.access_count, + entry.importance_score, + ), + ) + + self.conn.commit() + + self.semantic_index.add_document(entry.entry_id, entry.content) + + def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]: + with self.lock: + cursor = self.conn.cursor() + + cursor.execute( + """ + SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score + FROM knowledge_entries WHERE entry_id = ? """, (entry_id,), ) - self.conn.commit() - return KnowledgeEntry( - entry_id=row[0], - category=row[1], - content=row[2], - metadata=json.loads(row[3]) if row[3] else {}, - created_at=row[4], - updated_at=row[5], - access_count=row[6] + 1, - importance_score=row[7], - ) + row = cursor.fetchone() - return None - - def search_entries( - self, query: str, category: Optional[str] = None, top_k: int = 5 - ) -> List[KnowledgeEntry]: - search_results = self.semantic_index.search(query, top_k * 2) - - cursor = self.conn.cursor() - - entries = [] - for entry_id, score in search_results: - if category: + if row: cursor.execute( """ - SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score - FROM knowledge_entries - WHERE entry_id = ? AND category = ? - """, - (entry_id, category), - ) - else: - cursor.execute( - """ - SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score - FROM knowledge_entries + UPDATE knowledge_entries + SET access_count = access_count + 1 WHERE entry_id = ? """, (entry_id,), ) + self.conn.commit() - row = cursor.fetchone() - if row: - entry = KnowledgeEntry( + return KnowledgeEntry( entry_id=row[0], category=row[1], content=row[2], metadata=json.loads(row[3]) if row[3] else {}, created_at=row[4], updated_at=row[5], - access_count=row[6], + access_count=row[6] + 1, importance_score=row[7], ) - entries.append(entry) - if len(entries) >= top_k: - break + return None - return entries + def search_entries( + self, query: str, category: Optional[str] = None, top_k: int = 5 + ) -> List[KnowledgeEntry]: + # Combine semantic search with exact matching + semantic_results = self.semantic_index.search(query, top_k * 2) - def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]: - cursor = self.conn.cursor() + # Add FTS (Full Text Search) with exact word/phrase matching + fts_results = self._fts_search(query, top_k * 2) - cursor.execute( - """ - SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score - FROM knowledge_entries - WHERE category = ? - ORDER BY importance_score DESC, created_at DESC - LIMIT ? 
- """, - (category, limit), - ) + # Combine and deduplicate results with weighted scoring + combined_results = {} - entries = [] - for row in cursor.fetchall(): - entries.append( - KnowledgeEntry( - entry_id=row[0], - category=row[1], - content=row[2], - metadata=json.loads(row[3]) if row[3] else {}, - created_at=row[4], - updated_at=row[5], - access_count=row[6], - importance_score=row[7], - ) + # Add semantic results with weight 0.7 + for entry_id, score in semantic_results: + combined_results[entry_id] = score * 0.7 + + # Add FTS results with weight 1.0 (higher priority for exact matches) + for entry_id, score in fts_results: + if entry_id in combined_results: + combined_results[entry_id] = max(combined_results[entry_id], score * 1.0) + else: + combined_results[entry_id] = score * 1.0 + + # Sort by combined score + sorted_results = sorted(combined_results.items(), key=lambda x: x[1], reverse=True) + + with self.lock: + cursor = self.conn.cursor() + + entries = [] + for entry_id, score in sorted_results[:top_k]: + if category: + cursor.execute( + """ + SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score + FROM knowledge_entries + WHERE entry_id = ? AND category = ? + """, + (entry_id, category), + ) + else: + cursor.execute( + """ + SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score + FROM knowledge_entries + WHERE entry_id = ? + """, + (entry_id,), + ) + + row = cursor.fetchone() + if row: + entry = KnowledgeEntry( + entry_id=row[0], + category=row[1], + content=row[2], + metadata=json.loads(row[3]) if row[3] else {}, + created_at=row[4], + updated_at=row[5], + access_count=row[6], + importance_score=row[7], + ) + # Add search score to metadata for context + entry.metadata["search_score"] = score + entries.append(entry) + + return entries + + def _fts_search(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]: + """Full Text Search with exact word and partial sentence matching.""" + with self.lock: + cursor = self.conn.cursor() + + # Prepare query for FTS + query_lower = query.lower() + query_words = query_lower.split() + + # Search for exact phrase matches first + cursor.execute( + """ + SELECT entry_id, content + FROM knowledge_entries + WHERE LOWER(content) LIKE ? 
+ """, + (f"%{query_lower}%",), ) - return entries + exact_matches = [] + partial_matches = [] + + for row in cursor.fetchall(): + entry_id, content = row + content_lower = content.lower() + + # Exact phrase match gets highest score + if query_lower in content_lower: + exact_matches.append((entry_id, 1.0)) + continue + + # Count matching words + content_words = set(content_lower.split()) + query_word_set = set(query_words) + matching_words = len(query_word_set & content_words) + + if matching_words > 0: + # Score based on word overlap and position + word_overlap_score = matching_words / len(query_word_set) + + # Bonus for consecutive word sequences + consecutive_bonus = 0.0 + for i in range(len(query_words)): + for j in range(i + 1, min(i + 4, len(query_words) + 1)): + phrase = " ".join(query_words[i:j]) + if phrase in content_lower: + consecutive_bonus += 0.2 * (j - i) + + total_score = min(0.99, word_overlap_score + consecutive_bonus) + partial_matches.append((entry_id, total_score)) + + # Combine results, prioritizing exact matches + all_results = exact_matches + partial_matches + all_results.sort(key=lambda x: x[1], reverse=True) + + return all_results[:top_k] + + def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]: + with self.lock: + cursor = self.conn.cursor() + + cursor.execute( + """ + SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score + FROM knowledge_entries + WHERE category = ? + ORDER BY importance_score DESC, created_at DESC + LIMIT ? + """, + (category, limit), + ) + + entries = [] + for row in cursor.fetchall(): + entries.append( + KnowledgeEntry( + entry_id=row[0], + category=row[1], + content=row[2], + metadata=json.loads(row[3]) if row[3] else {}, + created_at=row[4], + updated_at=row[5], + access_count=row[6], + importance_score=row[7], + ) + ) + + return entries def update_importance(self, entry_id: str, importance_score: float): - cursor = self.conn.cursor() + with self.lock: + cursor = self.conn.cursor() - cursor.execute( - """ - UPDATE knowledge_entries - SET importance_score = ?, updated_at = ? - WHERE entry_id = ? - """, - (importance_score, time.time(), entry_id), - ) + cursor.execute( + """ + UPDATE knowledge_entries + SET importance_score = ?, updated_at = ? + WHERE entry_id = ? 
+ """, + (importance_score, time.time(), entry_id), + ) - self.conn.commit() + self.conn.commit() def delete_entry(self, entry_id: str) -> bool: - cursor = self.conn.cursor() + with self.lock: + cursor = self.conn.cursor() - cursor.execute("DELETE FROM knowledge_entries WHERE entry_id = ?", (entry_id,)) - deleted = cursor.rowcount > 0 + cursor.execute("DELETE FROM knowledge_entries WHERE entry_id = ?", (entry_id,)) + deleted = cursor.rowcount > 0 - self.conn.commit() + self.conn.commit() - if deleted: - self.semantic_index.remove_document(entry_id) + if deleted: + self.semantic_index.remove_document(entry_id) - return deleted + return deleted def get_statistics(self) -> Dict[str, Any]: - cursor = self.conn.cursor() + with self.lock: + cursor = self.conn.cursor() - cursor.execute("SELECT COUNT(*) FROM knowledge_entries") - total_entries = cursor.fetchone()[0] + cursor.execute("SELECT COUNT(*) FROM knowledge_entries") + total_entries = cursor.fetchone()[0] - cursor.execute("SELECT COUNT(DISTINCT category) FROM knowledge_entries") - total_categories = cursor.fetchone()[0] + cursor.execute("SELECT COUNT(DISTINCT category) FROM knowledge_entries") + total_categories = cursor.fetchone()[0] - cursor.execute( + cursor.execute( + """ + SELECT category, COUNT(*) as count + FROM knowledge_entries + GROUP BY category + ORDER BY count DESC """ - SELECT category, COUNT(*) as count - FROM knowledge_entries - GROUP BY category - ORDER BY count DESC - """ - ) - category_counts = {row[0]: row[1] for row in cursor.fetchall()} + ) + category_counts = {row[0]: row[1] for row in cursor.fetchall()} - cursor.execute("SELECT SUM(access_count) FROM knowledge_entries") - total_accesses = cursor.fetchone()[0] or 0 + cursor.execute("SELECT SUM(access_count) FROM knowledge_entries") + total_accesses = cursor.fetchone()[0] or 0 - return { - "total_entries": total_entries, - "total_categories": total_categories, - "category_distribution": category_counts, - "total_accesses": total_accesses, - "vocabulary_size": len(self.semantic_index.vocabulary), - } + return { + "total_entries": total_entries, + "total_categories": total_categories, + "category_distribution": category_counts, + "total_accesses": total_accesses, + "vocabulary_size": len(self.semantic_index.vocabulary), + } diff --git a/pr/memory/semantic_index.py b/pr/memory/semantic_index.py index bceedb5..e2015a8 100644 --- a/pr/memory/semantic_index.py +++ b/pr/memory/semantic_index.py @@ -9,7 +9,7 @@ class SemanticIndex: self.documents: Dict[str, str] = {} self.vocabulary: Set[str] = set() self.idf_scores: Dict[str, float] = {} - self.doc_vectors: Dict[str, Dict[str, float]] = {} + self.doc_tf_scores: Dict[str, Dict[str, float]] = {} def _tokenize(self, text: str) -> List[str]: text = text.lower() @@ -46,22 +46,26 @@ class SemanticIndex: tokens = self._tokenize(text) self.vocabulary.update(tokens) - self._compute_idf() - tf_scores = self._compute_tf(tokens) - self.doc_vectors[doc_id] = { - token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0) for token in tokens - } + self.doc_tf_scores[doc_id] = tf_scores + + self._compute_idf() def remove_document(self, doc_id: str): if doc_id in self.documents: del self.documents[doc_id] - if doc_id in self.doc_vectors: - del self.doc_vectors[doc_id] + if doc_id in self.doc_tf_scores: + del self.doc_tf_scores[doc_id] self._compute_idf() def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: + if not query.strip(): + return [] + query_tokens = self._tokenize(query) + if not query_tokens: + return [] + 
query_tf = self._compute_tf(query_tokens) query_vector = { @@ -69,7 +73,10 @@ class SemanticIndex: } scores = [] - for doc_id, doc_vector in self.doc_vectors.items(): + for doc_id, doc_tf in self.doc_tf_scores.items(): + doc_vector = { + token: doc_tf.get(token, 0) * self.idf_scores.get(token, 0) for token in doc_tf + } similarity = self._cosine_similarity(query_vector, doc_vector) scores.append((doc_id, similarity)) diff --git a/pr/multiplexer.py.bak b/pr/multiplexer.py.bak new file mode 100644 index 0000000..c2f2a35 --- /dev/null +++ b/pr/multiplexer.py.bak @@ -0,0 +1,384 @@ +import queue +import subprocess +import sys +import threading +import time + +from pr.tools.process_handlers import detect_process_type, get_handler_for_process +from pr.tools.prompt_detection import get_global_detector +from pr.ui import Colors + + +class TerminalMultiplexer: + def __init__(self, name, show_output=True): + self.name = name + self.show_output = show_output + self.stdout_buffer = [] + self.stderr_buffer = [] + self.stdout_queue = queue.Queue() + self.stderr_queue = queue.Queue() + self.active = True + self.lock = threading.Lock() + self.metadata = { + "start_time": time.time(), + "last_activity": time.time(), + "interaction_count": 0, + "process_type": "unknown", + "state": "active", + } + self.handler = None + self.prompt_detector = get_global_detector() + + if self.show_output: + self.display_thread = threading.Thread(target=self._display_worker, daemon=True) + self.display_thread.start() + + def _display_worker(self): + while self.active: + try: + line = self.stdout_queue.get(timeout=0.1) + if line: + if self.metadata.get("process_type") in ["vim", "ssh"]: + sys.stdout.write(line) + else: + sys.stdout.write(f"{Colors.GRAY}[{self.name}]{Colors.RESET} {line}\n") + sys.stdout.flush() + except queue.Empty: + pass + + try: + line = self.stderr_queue.get(timeout=0.1) + if line: + if self.metadata.get("process_type") in ["vim", "ssh"]: + sys.stderr.write(line) + else: + sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}\n") + sys.stderr.flush() + except queue.Empty: + pass + + def write_stdout(self, data): + with self.lock: + self.stdout_buffer.append(data) + self.metadata["last_activity"] = time.time() + # Update handler state if available + if self.handler: + self.handler.update_state(data) + # Update prompt detector + self.prompt_detector.update_session_state( + self.name, data, self.metadata["process_type"] + ) + if self.show_output: + self.stdout_queue.put(data) + + def write_stderr(self, data): + with self.lock: + self.stderr_buffer.append(data) + self.metadata["last_activity"] = time.time() + # Update handler state if available + if self.handler: + self.handler.update_state(data) + # Update prompt detector + self.prompt_detector.update_session_state( + self.name, data, self.metadata["process_type"] + ) + if self.show_output: + self.stderr_queue.put(data) + + def get_stdout(self): + with self.lock: + return "".join(self.stdout_buffer) + + def get_stderr(self): + with self.lock: + return "".join(self.stderr_buffer) + + def get_all_output(self): + with self.lock: + return { + "stdout": "".join(self.stdout_buffer), + "stderr": "".join(self.stderr_buffer), + } + + def get_metadata(self): + with self.lock: + return self.metadata.copy() + + def update_metadata(self, key, value): + with self.lock: + self.metadata[key] = value + + def set_process_type(self, process_type): + """Set the process type and initialize appropriate handler.""" + with self.lock: + self.metadata["process_type"] 
= process_type + self.handler = get_handler_for_process(process_type, self) + + def send_input(self, input_data): + if hasattr(self, "process") and self.process.poll() is None: + try: + self.process.stdin.write(input_data + "\n") + self.process.stdin.flush() + with self.lock: + self.metadata["last_activity"] = time.time() + self.metadata["interaction_count"] += 1 + except Exception as e: + self.write_stderr(f"Error sending input: {e}") + else: + # This will be implemented when we have a process attached + # For now, just update activity + with self.lock: + self.metadata["last_activity"] = time.time() + self.metadata["interaction_count"] += 1 + + def close(self): + self.active = False + if hasattr(self, "display_thread"): + self.display_thread.join(timeout=1) + + +_multiplexers = {} +_mux_counter = 0 +_mux_lock = threading.Lock() +_background_monitor = None +_monitor_active = False +_monitor_interval = 0.2 # 200ms + + +def create_multiplexer(name=None, show_output=True): + global _mux_counter + with _mux_lock: + if name is None: + _mux_counter += 1 + name = f"process-{_mux_counter}" + mux = TerminalMultiplexer(name, show_output) + _multiplexers[name] = mux + return name, mux + + +def get_multiplexer(name): + return _multiplexers.get(name) + + +def close_multiplexer(name): + mux = _multiplexers.get(name) + if mux: + mux.close() + del _multiplexers[name] + + +def get_all_multiplexer_states(): + with _mux_lock: + states = {} + for name, mux in _multiplexers.items(): + states[name] = { + "metadata": mux.get_metadata(), + "output_summary": { + "stdout_lines": len(mux.stdout_buffer), + "stderr_lines": len(mux.stderr_buffer), + }, + } + return states + + +def cleanup_all_multiplexers(): + for mux in list(_multiplexers.values()): + mux.close() + _multiplexers.clear() + + +# Background process management +_background_processes = {} +_process_lock = threading.Lock() + + +class BackgroundProcess: + def __init__(self, name, command): + self.name = name + self.command = command + self.process = None + self.multiplexer = None + self.status = "starting" + self.start_time = time.time() + self.end_time = None + + def start(self): + """Start the background process.""" + try: + # Create multiplexer for this process + mux_name, mux = create_multiplexer(self.name, show_output=False) + self.multiplexer = mux + + # Detect process type + process_type = detect_process_type(self.command) + mux.set_process_type(process_type) + + # Start the subprocess + self.process = subprocess.Popen( + self.command, + shell=True, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + universal_newlines=True, + ) + + self.status = "running" + + # Start output monitoring threads + threading.Thread(target=self._monitor_stdout, daemon=True).start() + threading.Thread(target=self._monitor_stderr, daemon=True).start() + + return {"status": "success", "pid": self.process.pid} + + except Exception as e: + self.status = "error" + return {"status": "error", "error": str(e)} + + def _monitor_stdout(self): + """Monitor stdout from the process.""" + try: + for line in iter(self.process.stdout.readline, ""): + if line: + self.multiplexer.write_stdout(line.rstrip("\n\r")) + except Exception as e: + self.write_stderr(f"Error reading stdout: {e}") + finally: + self._check_completion() + + def _monitor_stderr(self): + """Monitor stderr from the process.""" + try: + for line in iter(self.process.stderr.readline, ""): + if line: + self.multiplexer.write_stderr(line.rstrip("\n\r")) + except Exception as 
e: + self.write_stderr(f"Error reading stderr: {e}") + + def _check_completion(self): + """Check if process has completed.""" + if self.process and self.process.poll() is not None: + self.status = "completed" + self.end_time = time.time() + + def get_info(self): + """Get process information.""" + self._check_completion() + return { + "name": self.name, + "command": self.command, + "status": self.status, + "pid": self.process.pid if self.process else None, + "start_time": self.start_time, + "end_time": self.end_time, + "runtime": ( + time.time() - self.start_time + if not self.end_time + else self.end_time - self.start_time + ), + } + + def get_output(self, lines=None): + """Get process output.""" + if not self.multiplexer: + return [] + + all_output = self.multiplexer.get_all_output() + stdout_lines = all_output["stdout"].split("\n") if all_output["stdout"] else [] + stderr_lines = all_output["stderr"].split("\n") if all_output["stderr"] else [] + + combined = stdout_lines + stderr_lines + if lines: + combined = combined[-lines:] + + return [line for line in combined if line.strip()] + + def send_input(self, input_text): + """Send input to the process.""" + if self.process and self.status == "running": + try: + self.process.stdin.write(input_text + "\n") + self.process.stdin.flush() + return {"status": "success"} + except Exception as e: + return {"status": "error", "error": str(e)} + return {"status": "error", "error": "Process not running or no stdin"} + + def kill(self): + """Kill the process.""" + if self.process and self.status == "running": + try: + self.process.terminate() + # Wait a bit for graceful termination + time.sleep(0.1) + if self.process.poll() is None: + self.process.kill() + self.status = "killed" + self.end_time = time.time() + return {"status": "success"} + except Exception as e: + return {"status": "error", "error": str(e)} + return {"status": "error", "error": "Process not running"} + + +def start_background_process(name, command): + """Start a background process.""" + with _process_lock: + if name in _background_processes: + return {"status": "error", "error": f"Process {name} already exists"} + + process = BackgroundProcess(name, command) + result = process.start() + + if result["status"] == "success": + _background_processes[name] = process + + return result + + +def get_all_sessions(): + """Get all background process sessions.""" + with _process_lock: + sessions = {} + for name, process in _background_processes.items(): + sessions[name] = process.get_info() + return sessions + + +def get_session_info(name): + """Get information about a specific session.""" + with _process_lock: + process = _background_processes.get(name) + return process.get_info() if process else None + + +def get_session_output(name, lines=None): + """Get output from a specific session.""" + with _process_lock: + process = _background_processes.get(name) + return process.get_output(lines) if process else None + + +def send_input_to_session(name, input_text): + """Send input to a background session.""" + with _process_lock: + process = _background_processes.get(name) + return ( + process.send_input(input_text) + if process + else {"status": "error", "error": "Session not found"} + ) + + +def kill_session(name): + """Kill a background session.""" + with _process_lock: + process = _background_processes.get(name) + if process: + result = process.kill() + if result["status"] == "success": + del _background_processes[name] + return result + return {"status": "error", "error": "Session not found"} 
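For context, here is a minimal usage sketch of the background-process session helpers introduced above. It is illustrative only and not part of the patch: it assumes the helpers are importable from `pr.multiplexer` (the module this hunk extends) and uses a throwaway `ping` command purely as an example workload.

```python
# Illustrative sketch (not part of the patch): driving a long-running command
# through the background-process session helpers added in pr/multiplexer.py.
import time

from pr.multiplexer import (
    start_background_process,
    get_session_info,
    get_session_output,
    send_input_to_session,
    kill_session,
)

# Start a named session; the helper creates a TerminalMultiplexer and
# stdout/stderr monitor threads behind the scenes.
result = start_background_process("ping-demo", "ping -c 5 127.0.0.1")
if result["status"] != "success":
    raise RuntimeError(result.get("error", "failed to start"))

time.sleep(2)

# Poll status and fetch the most recent output lines.
print(get_session_info("ping-demo"))  # name, command, status, pid, runtime, ...
for line in get_session_output("ping-demo", lines=10) or []:
    print(line)

# For interactive child processes, input can be written to stdin like this
# (a no-op for ping; shown only to illustrate the call shape).
send_input_to_session("ping-demo", "")

# Terminate the session if it is still running; a session that already
# completed reports an error dict here instead.
print(kill_session("ping-demo"))
```

Note that a successful `kill_session` both terminates the child process and removes the entry from the session registry, so a subsequent `get_session_info` for the same name returns `None`.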
diff --git a/pr/server.py b/pr/server.py new file mode 100644 index 0000000..10bc816 --- /dev/null +++ b/pr/server.py @@ -0,0 +1,585 @@ +#!/usr/bin/env python3 +import asyncio +import logging +import os +import sys +import subprocess +import urllib.parse +from pathlib import Path +from typing import Dict, List, Tuple, Optional +import pathlib +import ast +from types import SimpleNamespace +import time +from aiohttp import web +import aiohttp + +# --- Optional external dependency used in your original code --- +from pr.ads import AsyncDataSet # noqa: E402 + +import sys +from http import cookies +import urllib.parse +import os +import io +import json +import urllib.request +import pathlib +import pickle +import zlib + +import asyncio +import time +from functools import wraps +from pathlib import Path +import pickle +import zlib + + +class Cache: + def __init__(self, base_dir: Path | str = "."): + self.base_dir = Path(base_dir).resolve() + self.base_dir.mkdir(exist_ok=True, parents=True) + + def is_cacheable(self, obj): + return isinstance(obj, (int, str, bool, float)) + + def generate_key(self, *args, **kwargs): + + return zlib.crc32( + json.dumps( + { + "args": [arg for arg in args if self.is_cacheable(arg)], + "kwargs": {k: v for k, v in kwargs.items() if self.is_cacheable(v)}, + }, + sort_keys=True, + default=str, + ).encode() + ) + + def set(self, key, value): + key = self.generate_key(key) + data = {"value": value, "timestamp": time.time()} + serialized_data = pickle.dumps(data) + with open(self.base_dir.joinpath(f"{key}.cache"), "wb") as f: + f.write(serialized_data) + + def get(self, key, default=None): + key = self.generate_key(key) + try: + with open(self.base_dir.joinpath(f"{key}.cache"), "rb") as f: + data = pickle.loads(f.read()) + return data + except FileNotFoundError: + return default + + def is_cacheable(self, obj): + return isinstance(obj, (int, str, bool, float)) + + def cached(self, expiration=60): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + # if not all(self.is_cacheable(arg) for arg in args) or not all(self.is_cacheable(v) for v in kwargs.values()): + # return await func(*args, **kwargs) + key = self.generate_key(*args, **kwargs) + print("Cache hit:", key) + cached_data = self.get(key) + if cached_data is not None: + if expiration is None or (time.time() - cached_data["timestamp"]) < expiration: + return cached_data["value"] + result = await func(*args, **kwargs) + self.set(key, result) + return result + + return wrapper + + return decorator + + +cache = Cache() + + +class CGI: + + def __init__(self): + + self.environ = os.environ + if not self.gateway_interface == "CGI/1.1": + return + self.status = 200 + self.headers = {} + self.cache = Cache(pathlib.Path(__file__).parent.joinpath("cache/cgi")) + self.headers["Content-Length"] = "0" + self.headers["Cache-Control"] = "no-cache" + self.headers["Access-Control-Allow-Origin"] = "*" + self.headers["Content-Type"] = "application/json; charset=utf-8" + self.cookies = cookies.SimpleCookie() + + self.query = urllib.parse.parse_qs(os.environ["QUERY_STRING"]) + self.file = io.BytesIO() + + def validate(self, val, fn, error): + if fn(val): + return True + self.status = 421 + self.write({"type": "error", "message": error}) + exit() + + def __getitem__(self, key): + return self.get(key) + + def get(self, key, default=None): + result = self.query.get(key, [default]) + if len(result) == 1: + return result[0] + return result + + def get_bool(self, key): + return str(self.get(key)).lower() in ["true", 
"yes", "1"] + + @property + def cache_key(self): + return self.environ["CACHE_KEY"] + + @property + def env(self): + env = os.environ.copy() + return env + + @property + def gateway_interface(self): + return self.env.get("GATEWAY_INTERFACE", "") + + @property + def request_method(self): + return self.env["REQUEST_METHOD"] + + @property + def query_string(self): + return self.env["QUERY_STRING"] + + @property + def script_name(self): + return self.env["SCRIPT_NAME"] + + @property + def path_info(self): + return self.env["PATH_INFO"] + + @property + def server_name(self): + return self.env["SERVER_NAME"] + + @property + def server_port(self): + return self.env["SERVER_PORT"] + + @property + def server_protocol(self): + return self.env["SERVER_PROTOCOL"] + + @property + def remote_addr(self): + return self.env["REMOTE_ADDR"] + + def validate_get(self, key): + self.validate(self.get(key), lambda x: bool(x), f"Missing {key}") + + def print(self, data): + self.write(data) + + def write(self, data): + if not isinstance(data, bytes) and not isinstance(data, str): + data = json.dumps(data, default=str, indent=4) + try: + data = data.encode() + except: + pass + self.file.write(data) + self.headers["Content-Length"] = str(len(self.file.getvalue())) + + @property + def http_status(self): + return f"HTTP/1.1 {self.status}\r\n".encode("utf-8") + + @property + def http_headers(self): + headers = io.BytesIO() + for header in self.headers: + headers.write(f"{header}: {self.headers[header]}\r\n".encode("utf-8")) + headers.write(b"\r\n") + return headers.getvalue() + + @property + def http_body(self): + return self.file.getvalue() + + @property + def http_response(self): + return self.http_status + self.http_headers + self.http_body + + def flush(self, response=None): + + if response: + try: + response = response.encode() + except: + pass + sys.stdout.buffer.write(response) + sys.stdout.buffer.flush() + return + sys.stdout.buffer.write(self.http_response) + sys.stdout.buffer.flush() + + def __del__(self): + if self.http_body: + self.flush() + exit() + + +# ------------------------------- +# Utilities +# ------------------------------- +def get_function_source(name: str, directory: str = ".") -> List[Tuple[str, str]]: + matches: List[Tuple[str, str]] = [] + for root, _, files in os.walk(directory): + for file in files: + if not file.endswith(".py"): + continue + path = os.path.join(root, file) + try: + with open(path, "r", encoding="utf-8") as fh: + source = fh.read() + tree = ast.parse(source, filename=path) + for node in ast.walk(tree): + if isinstance(node, ast.AsyncFunctionDef) and node.name == name: + func_src = ast.get_source_segment(source, node) + if func_src: + matches.append((path, func_src)) + break + except (SyntaxError, UnicodeDecodeError): + continue + return matches + + +# ------------------------------- +# Server +# ------------------------------- +class Static(SimpleNamespace): + pass + + +class View(web.View): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.db = self.request.app["db"] + self.server = self.request.app["server"] + + @property + def client(self): + return aiohttp.ClientSession() + + +class RetoorServer: + def __init__(self, base_dir: Path | str = ".", port: int = 8118) -> None: + self.base_dir = Path(base_dir).resolve() + self.port = port + + self.static = Static() + self.db = AsyncDataSet(".default.db") + + self._logger = logging.getLogger("retoor.server") + self._logger.setLevel(logging.INFO) + if not self._logger.handlers: + h = 
logging.StreamHandler(sys.stdout) + fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s") + h.setFormatter(fmt) + self._logger.addHandler(h) + + self._func_cache: Dict[str, Tuple[str, str]] = {} + self._compiled_cache: Dict[str, object] = {} + self._cgi_cache: Dict[str, Tuple[Optional[Path], Optional[str], Optional[str]]] = {} + self._static_path_cache: Dict[str, Path] = {} + self._base_dir_str = str(self.base_dir) + + self.app = web.Application() + + app = self.app + + app["db"] = self.db + for path in pathlib.Path(__file__).parent.joinpath("web").joinpath("views").iterdir(): + if path.suffix == ".py": + exec(open(path).read(), globals(), locals()) + + self.app["server"] = self + + # from rpanel import create_app as create_rpanel_app + # self.app.add_subapp("/api/{tail:.*}", create_rpanel_app()) + + self.app.router.add_route("*", "/{tail:.*}", self.handle_any) + + # --------------------------- + # Simple endpoints + # --------------------------- + async def handle_root(self, request: web.Request) -> web.Response: + return web.Response(text="Static file and cgi server from retoor.") + + async def handle_hello(self, request: web.Request) -> web.Response: + return web.Response(text="Welcome to the custom HTTP server!") + + # --------------------------- + # Dynamic function dispatch + # --------------------------- + def _path_to_funcname(self, path: str) -> str: + # /foo/bar -> foo_bar ; strip leading/trailing slashes safely + return path.strip("/").replace("/", "_") + + async def _maybe_dispatch_dynamic(self, request: web.Request) -> Optional[web.StreamResponse]: + relpath = request.path.strip("/") + relpath = str(pathlib.Path(__file__).joinpath("cgi").joinpath(relpath).resolve()).replace( + "..", "" + ) + if not relpath: + return None + + last_part = relpath.split("/")[-1] + if "." 
in last_part: + return None + + funcname = self._path_to_funcname(request.path) + + if funcname not in self._compiled_cache: + entry = self._func_cache.get(funcname) + if entry is None: + matches = get_function_source(funcname, directory=str(self.base_dir)) + if not matches: + return None + entry = matches[0] + self._func_cache[funcname] = entry + + filepath, source = entry + code_obj = compile(source, filepath, "exec") + self._compiled_cache[funcname] = (code_obj, filepath) + + code_obj, filepath = self._compiled_cache[funcname] + ctx = { + "static": self.static, + "request": request, + "relpath": relpath, + "os": os, + "sys": sys, + "web": web, + "asyncio": asyncio, + "subprocess": subprocess, + "urllib": urllib.parse, + "app": request.app, + "db": self.db, + } + exec(code_obj, ctx, ctx) + coro = ctx.get(funcname) + if not callable(coro): + return web.Response(status=500, text=f"Dynamic handler '{funcname}' is not callable.") + return await coro(request) + + # --------------------------- + # Static files + # --------------------------- + async def handle_static(self, request: web.Request) -> web.StreamResponse: + path = request.path + + if path in self._static_path_cache: + cached_path = self._static_path_cache[path] + if cached_path is None: + return web.Response(status=404, text="Not found") + try: + return web.FileResponse(path=cached_path) + except Exception as e: + return web.Response(status=500, text=f"Failed to send file: {e!r}") + + relpath = path.replace('..",', "").lstrip("/").rstrip("/") + abspath = (self.base_dir / relpath).resolve() + + if not str(abspath).startswith(self._base_dir_str): + return web.Response(status=403, text="Forbidden") + + if not abspath.exists(): + self._static_path_cache[path] = None + return web.Response(status=404, text="Not found") + + if abspath.is_dir(): + return web.Response(status=403, text="Directory listing forbidden") + + self._static_path_cache[path] = abspath + try: + return web.FileResponse(path=abspath) + except Exception as e: + return web.Response(status=500, text=f"Failed to send file: {e!r}") + + # --------------------------- + # CGI + # --------------------------- + def _find_cgi_script( + self, path_only: str + ) -> Tuple[Optional[Path], Optional[str], Optional[str]]: + path_only = "pr/cgi/" + path_only + if path_only in self._cgi_cache: + return self._cgi_cache[path_only] + + split_path = [p for p in path_only.split("/") if p] + for i in range(len(split_path), -1, -1): + candidate = "/" + "/".join(split_path[:i]) + candidate_fs = (self.base_dir / candidate.lstrip("/")).resolve() + if ( + str(candidate_fs).startswith(self._base_dir_str) + and candidate_fs.parent.is_dir() + and candidate_fs.parent.name == "cgi" + and candidate_fs.suffix == ".bin" + and candidate_fs.is_file() + and os.access(candidate_fs, os.X_OK) + ): + script_name = candidate if candidate.startswith("/") else "/" + candidate + path_info = path_only[len(script_name) :] + result = (candidate_fs, script_name, path_info) + self._cgi_cache[path_only] = result + return result + + result = (None, None, None) + self._cgi_cache[path_only] = result + return result + + async def _run_cgi( + self, request: web.Request, script_path: Path, script_name: str, path_info: str + ) -> web.Response: + start_time = time.time() + method = request.method + query = request.query_string + + env = os.environ.copy() + env["CACHE_KEY"] = json.dumps( + {"method": method, "query": query, "script_name": script_name, "path_info": path_info}, + default=str, + ) + env["GATEWAY_INTERFACE"] = "CGI/1.1" 
+ env["REQUEST_METHOD"] = method + env["QUERY_STRING"] = query + env["SCRIPT_NAME"] = script_name + env["PATH_INFO"] = path_info + + host_parts = request.host.split(":", 1) + env["SERVER_NAME"] = host_parts[0] + env["SERVER_PORT"] = host_parts[1] if len(host_parts) > 1 else "80" + env["SERVER_PROTOCOL"] = "HTTP/1.1" + + peername = request.transport.get_extra_info("peername") + env["REMOTE_ADDR"] = request.headers.get("HOST_X_FORWARDED_FOR", peername[0]) + + for hk, hv in request.headers.items(): + hk_lower = hk.lower() + if hk_lower == "content-type": + env["CONTENT_TYPE"] = hv + elif hk_lower == "content-length": + env["CONTENT_LENGTH"] = hv + else: + env["HTTP_" + hk.upper().replace("-", "_")] = hv + + post_data = None + if method in {"POST", "PUT", "PATCH"}: + post_data = await request.read() + env.setdefault("CONTENT_LENGTH", str(len(post_data))) + + try: + proc = await asyncio.create_subprocess_exec( + str(script_path), + stdin=asyncio.subprocess.PIPE if post_data else None, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + stdout, stderr = await proc.communicate(input=post_data) + + if proc.returncode != 0: + msg = stderr.decode(errors="ignore") + return web.Response(status=500, text=f"CGI script error:\n{msg}") + + header_end = stdout.find(b"\r\n\r\n") + if header_end != -1: + offset = 4 + else: + header_end = stdout.find(b"\n\n") + offset = 2 + + if header_end != -1: + body = stdout[header_end + offset :] + headers_blob = stdout[:header_end].decode(errors="ignore") + + status_code = 200 + headers: Dict[str, str] = {} + for line in headers_blob.splitlines(): + line = line.strip() + if not line or ":" not in line: + continue + + k, v = line.split(":", 1) + k_stripped = k.strip() + if k_stripped.lower() == "status": + try: + status_code = int(v.strip().split()[0]) + except Exception: + status_code = 200 + else: + headers[k_stripped] = v.strip() + + end_time = time.time() + headers["X-Powered-By"] = "retoor" + headers["X-Date"] = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime()) + headers["X-Time"] = str(end_time) + headers["X-Duration"] = str(end_time - start_time) + if "error" in str(body).lower(): + print(headers) + print(body) + return web.Response(body=body, headers=headers, status=status_code) + else: + return web.Response(body=stdout) + + except Exception as e: + return web.Response(status=500, text=f"Failed to execute CGI script.\n{e!r}") + + # --------------------------- + # Main router + # --------------------------- + async def handle_any(self, request: web.Request) -> web.StreamResponse: + path = request.path + + if path == "/": + return await self.handle_root(request) + if path == "/hello": + return await self.handle_hello(request) + + script_path, script_name, path_info = self._find_cgi_script(path) + if script_path: + return await self._run_cgi(request, script_path, script_name or "", path_info or "") + + dyn = await self._maybe_dispatch_dynamic(request) + if dyn is not None: + return dyn + + return await self.handle_static(request) + + # --------------------------- + # Runner + # --------------------------- + def run(self) -> None: + self._logger.info("Serving at port %d (base_dir=%s)", self.port, self.base_dir) + web.run_app(self.app, port=self.port) + + +def main() -> None: + server = RetoorServer(base_dir=".", port=8118) + server.run() + + +if __name__ == "__main__": + main() +else: + cgi = CGI() diff --git a/pr/tools/__init__.py b/pr/tools/__init__.py index f8a92a3..929822b 100644 --- a/pr/tools/__init__.py +++ 
b/pr/tools/__init__.py @@ -43,6 +43,11 @@ from pr.tools.memory import ( from pr.tools.patch import apply_patch, create_diff from pr.tools.python_exec import python_exec from pr.tools.web import http_fetch, web_search, web_search_news +from pr.tools.context_modifier import ( + modify_context_add, + modify_context_replace, + modify_context_delete, +) __all__ = [ "add_knowledge_entry", diff --git a/pr/tools/base.py b/pr/tools/base.py index 2d9d962..35fbe70 100644 --- a/pr/tools/base.py +++ b/pr/tools/base.py @@ -1,596 +1,144 @@ +import inspect +from typing import get_type_hints, get_origin, get_args +import pr.tools + + +def _type_to_json_schema(py_type): + """Convert Python type to JSON Schema type.""" + if py_type == str: + return {"type": "string"} + elif py_type == int: + return {"type": "integer"} + elif py_type == float: + return {"type": "number"} + elif py_type == bool: + return {"type": "boolean"} + elif get_origin(py_type) == list: + return {"type": "array", "items": _type_to_json_schema(get_args(py_type)[0])} + elif get_origin(py_type) == dict: + return {"type": "object"} + else: + # Default to string for unknown types + return {"type": "string"} + + +def _generate_tool_schema(func): + """Generate JSON Schema for a tool function.""" + sig = inspect.signature(func) + docstring = func.__doc__ or "" + + # Extract description from docstring + description = docstring.strip().split("\n")[0] if docstring else "" + + # Get type hints + type_hints = get_type_hints(func) + + properties = {} + required = [] + + for param_name, param in sig.parameters.items(): + if param_name in ["db_conn", "python_globals"]: # Skip internal parameters + continue + + param_type = type_hints.get(param_name, str) + schema = _type_to_json_schema(param_type) + + # Add description from docstring if available + param_doc = "" + if docstring: + lines = docstring.split("\n") + in_args = False + for line in lines: + line = line.strip() + if line.startswith("Args:") or line.startswith("Arguments:"): + in_args = True + continue + elif in_args and line.startswith(param_name + ":"): + param_doc = line.split(":", 1)[1].strip() + break + elif in_args and line == "": + continue + elif in_args and not line.startswith(" "): + break + + if param_doc: + schema["description"] = param_doc + + # Set default if available + if param.default != inspect.Parameter.empty: + schema["default"] = param.default + + properties[param_name] = schema + + # Required if no default + if param.default == inspect.Parameter.empty: + required.append(param_name) + + return { + "type": "function", + "function": { + "name": func.__name__, + "description": description, + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } + + def get_tools_definition(): - return [ - { - "type": "function", - "function": { - "name": "kill_process", - "description": "Terminate a background process by its PID. Use this to stop processes started with run_command that exceeded their timeout.", - "parameters": { - "type": "object", - "properties": { - "pid": { - "type": "integer", - "description": "The process ID returned by run_command when status is 'running'.", - } - }, - "required": ["pid"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "tail_process", - "description": "Monitor and retrieve output from a background process by its PID. 
Use this to check on processes started with run_command that exceeded their timeout.", - "parameters": { - "type": "object", - "properties": { - "pid": { - "type": "integer", - "description": "The process ID returned by run_command when status is 'running'.", - }, - "timeout": { - "type": "integer", - "description": "Maximum seconds to wait for process completion. Returns partial output if still running.", - "default": 30, - }, - }, - "required": ["pid"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "http_fetch", - "description": "Fetch content from an HTTP URL", - "parameters": { - "type": "object", - "properties": { - "url": {"type": "string", "description": "The URL to fetch"}, - "headers": { - "type": "object", - "description": "Optional HTTP headers", - }, - }, - "required": ["url"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "run_command", - "description": "Execute a shell command and capture output. Returns immediately after timeout with PID if still running. Use tail_process to monitor or kill_process to terminate long-running commands.", - "parameters": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "The shell command to execute", - }, - "timeout": { - "type": "integer", - "description": "Maximum seconds to wait for completion", - "default": 30, - }, - }, - "required": ["command"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "start_interactive_session", - "description": "Execute an interactive terminal command that requires user input or displays UI. The command runs in a dedicated session and returns a session name.", - "parameters": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "The interactive command to execute (e.g., vim, nano, top)", - } - }, - "required": ["command"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "send_input_to_session", - "description": "Send input to an interactive session.", - "parameters": { - "type": "object", - "properties": { - "session_name": { - "type": "string", - "description": "The name of the session", - }, - "input_data": { - "type": "string", - "description": "The input to send to the session", - }, - }, - "required": ["session_name", "input_data"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "read_session_output", - "description": "Read output from an interactive session.", - "parameters": { - "type": "object", - "properties": { - "session_name": { - "type": "string", - "description": "The name of the session", - } - }, - "required": ["session_name"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "close_interactive_session", - "description": "Close an interactive session.", - "parameters": { - "type": "object", - "properties": { - "session_name": { - "type": "string", - "description": "The name of the session", - } - }, - "required": ["session_name"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "read_file", - "description": "Read contents of a file", - "parameters": { - "type": "object", - "properties": { - "filepath": { - "type": "string", - "description": "Path to the file", - } - }, - "required": ["filepath"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "write_file", - "description": "Write content to a file", - "parameters": { - "type": "object", - "properties": { - "filepath": { - "type": "string", - "description": "Path to the file", - }, - "content": { - "type": 
"string", - "description": "Content to write", - }, - }, - "required": ["filepath", "content"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "list_directory", - "description": "List directory contents", - "parameters": { - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "Directory path", - "default": ".", - }, - "recursive": { - "type": "boolean", - "description": "List recursively", - "default": False, - }, - }, - }, - }, - }, - { - "type": "function", - "function": { - "name": "mkdir", - "description": "Create a new directory", - "parameters": { - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "Path of the directory to create", - } - }, - "required": ["path"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "chdir", - "description": "Change the current working directory", - "parameters": { - "type": "object", - "properties": {"path": {"type": "string", "description": "Path to change to"}}, - "required": ["path"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "getpwd", - "description": "Get the current working directory", - "parameters": {"type": "object", "properties": {}}, - }, - }, - { - "type": "function", - "function": { - "name": "db_set", - "description": "Set a key-value pair in the database", - "parameters": { - "type": "object", - "properties": { - "key": {"type": "string", "description": "The key"}, - "value": {"type": "string", "description": "The value"}, - }, - "required": ["key", "value"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "db_get", - "description": "Get a value from the database", - "parameters": { - "type": "object", - "properties": {"key": {"type": "string", "description": "The key"}}, - "required": ["key"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "db_query", - "description": "Execute a database query", - "parameters": { - "type": "object", - "properties": {"query": {"type": "string", "description": "SQL query"}}, - "required": ["query"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "web_search", - "description": "Perform a web search", - "parameters": { - "type": "object", - "properties": {"query": {"type": "string", "description": "Search query"}}, - "required": ["query"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "web_search_news", - "description": "Perform a web search for news", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query for news", - } - }, - "required": ["query"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "python_exec", - "description": "Execute Python code", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "Python code to execute", - } - }, - "required": ["code"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "index_source_directory", - "description": "Index directory recursively and read all source files.", - "parameters": { - "type": "object", - "properties": {"path": {"type": "string", "description": "Path to index"}}, - "required": ["path"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "search_replace", - "description": "Search and replace text in a file", - "parameters": { - "type": "object", - "properties": { - "filepath": { - "type": "string", - "description": "Path to the file", - }, - "old_string": { - "type": 
"string", - "description": "String to replace", - }, - "new_string": { - "type": "string", - "description": "Replacement string", - }, - }, - "required": ["filepath", "old_string", "new_string"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "apply_patch", - "description": "Apply a patch to a file, especially for source code", - "parameters": { - "type": "object", - "properties": { - "filepath": { - "type": "string", - "description": "Path to the file to patch", - }, - "patch_content": { - "type": "string", - "description": "The patch content as a string", - }, - }, - "required": ["filepath", "patch_content"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "create_diff", - "description": "Create a unified diff between two files", - "parameters": { - "type": "object", - "properties": { - "file1": { - "type": "string", - "description": "Path to the first file", - }, - "file2": { - "type": "string", - "description": "Path to the second file", - }, - "fromfile": { - "type": "string", - "description": "Label for the first file", - "default": "file1", - }, - "tofile": { - "type": "string", - "description": "Label for the second file", - "default": "file2", - }, - }, - "required": ["file1", "file2"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "open_editor", - "description": "Open the RPEditor for a file", - "parameters": { - "type": "object", - "properties": { - "filepath": { - "type": "string", - "description": "Path to the file", - } - }, - "required": ["filepath"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "close_editor", - "description": "Close the RPEditor. Always close files when finished editing.", - "parameters": { - "type": "object", - "properties": { - "filepath": { - "type": "string", - "description": "Path to the file", - } - }, - "required": ["filepath"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "editor_insert_text", - "description": "Insert text at cursor position in the editor", - "parameters": { - "type": "object", - "properties": { - "filepath": { - "type": "string", - "description": "Path to the file", - }, - "text": {"type": "string", "description": "Text to insert"}, - "line": { - "type": "integer", - "description": "Line number (optional)", - }, - "col": { - "type": "integer", - "description": "Column number (optional)", - }, - }, - "required": ["filepath", "text"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "editor_replace_text", - "description": "Replace text in a range", - "parameters": { - "type": "object", - "properties": { - "filepath": { - "type": "string", - "description": "Path to the file", - }, - "start_line": {"type": "integer", "description": "Start line"}, - "start_col": {"type": "integer", "description": "Start column"}, - "end_line": {"type": "integer", "description": "End line"}, - "end_col": {"type": "integer", "description": "End column"}, - "new_text": {"type": "string", "description": "New text"}, - }, - "required": [ - "filepath", - "start_line", - "start_col", - "end_line", - "end_col", - "new_text", - ], - }, - }, - }, - { - "type": "function", - "function": { - "name": "editor_search", - "description": "Search for a pattern in the file", - "parameters": { - "type": "object", - "properties": { - "filepath": { - "type": "string", - "description": "Path to the file", - }, - "pattern": {"type": "string", "description": "Regex pattern"}, - "start_line": { - "type": "integer", - "description": "Start line", - "default": 
0, - }, - }, - "required": ["filepath", "pattern"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "display_file_diff", - "description": "Display a visual colored diff between two files with syntax highlighting and statistics", - "parameters": { - "type": "object", - "properties": { - "filepath1": { - "type": "string", - "description": "Path to the original file", - }, - "filepath2": { - "type": "string", - "description": "Path to the modified file", - }, - "format_type": { - "type": "string", - "description": "Display format: 'unified' or 'side-by-side'", - "default": "unified", - }, - }, - "required": ["filepath1", "filepath2"], - }, - }, - }, - { - "type": "function", - "function": { - "name": "display_edit_summary", - "description": "Display a summary of all edit operations performed during the session", - "parameters": {"type": "object", "properties": {}}, - }, - }, - { - "type": "function", - "function": { - "name": "display_edit_timeline", - "description": "Display a timeline of all edit operations with details", - "parameters": { - "type": "object", - "properties": { - "show_content": { - "type": "boolean", - "description": "Show content previews", - "default": False, - } - }, - }, - }, - }, - { - "type": "function", - "function": { - "name": "clear_edit_tracker", - "description": "Clear the edit tracker to start fresh", - "parameters": {"type": "object", "properties": {}}, - }, - }, - ] + """Dynamically generate tool definitions from all tool functions.""" + tools = [] + + # Get all functions from pr.tools modules + for name in dir(pr.tools): + if name.startswith("_"): + continue + + obj = getattr(pr.tools, name) + if callable(obj) and hasattr(obj, "__module__") and obj.__module__.startswith("pr.tools."): + # Check if it's a tool function (has docstring and proper signature) + if obj.__doc__: + try: + schema = _generate_tool_schema(obj) + tools.append(schema) + except Exception as e: + print(f"Warning: Could not generate schema for {name}: {e}") + continue + + return tools + + +def get_func_map(db_conn=None, python_globals=None): + """Dynamically generate function map for tool execution.""" + func_map = {} + + # Include all functions from __all__ in pr.tools + for name in getattr(pr.tools, "__all__", []): + if name.startswith("_"): + continue + + obj = getattr(pr.tools, name, None) + if callable(obj) and hasattr(obj, "__module__") and obj.__module__.startswith("pr.tools."): + sig = inspect.signature(obj) + params = list(sig.parameters.keys()) + + # Create wrapper based on parameters + if "db_conn" in params and "python_globals" in params: + func_map[name] = ( + lambda func=obj, db_conn=db_conn, python_globals=python_globals, **kw: func( + **kw, db_conn=db_conn, python_globals=python_globals + ) + ) + elif "db_conn" in params: + func_map[name] = lambda func=obj, db_conn=db_conn, **kw: func(**kw, db_conn=db_conn) + elif "python_globals" in params: + func_map[name] = lambda func=obj, python_globals=python_globals, **kw: func( + **kw, python_globals=python_globals + ) + else: + func_map[name] = lambda func=obj, **kw: func(**kw) + + return func_map diff --git a/pr/tools/command.py b/pr/tools/command.py index e464a28..4e56d8c 100644 --- a/pr/tools/command.py +++ b/pr/tools/command.py @@ -2,7 +2,6 @@ import select import subprocess import time -from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer _processes = {} @@ -99,6 +98,19 @@ def tail_process(pid: int, timeout: int = 30): def run_command(command, timeout=30, monitored=False, 
cwd=None): + """Execute a shell command and return the output. + + Args: + command: The shell command to execute. + timeout: Maximum time in seconds to wait for completion. + monitored: Whether to monitor the process (unused). + cwd: Working directory for the command. + + Returns: + Dict with status, stdout, stderr, returncode, and optionally pid if still running. + """ + from pr.multiplexer import close_multiplexer, create_multiplexer + mux_name = None try: process = subprocess.Popen( diff --git a/pr/tools/command.py.bak b/pr/tools/command.py.bak new file mode 100644 index 0000000..838a4df --- /dev/null +++ b/pr/tools/command.py.bak @@ -0,0 +1,176 @@ + print(f"Executing command: {command}") print(f"Killing process: {pid}")import os +import select +import subprocess +import time + +from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer + +_processes = {} + + +def _register_process(pid: int, process): + _processes[pid] = process + return _processes + + +def _get_process(pid: int): + return _processes.get(pid) + + +def kill_process(pid: int): + try: + process = _get_process(pid) + if process: + process.kill() + _processes.pop(pid) + + mux_name = f"cmd-{pid}" + if get_multiplexer(mux_name): + close_multiplexer(mux_name) + + return {"status": "success", "message": f"Process {pid} has been killed"} + else: + return {"status": "error", "error": f"Process {pid} not found"} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def tail_process(pid: int, timeout: int = 30): + process = _get_process(pid) + if process: + mux_name = f"cmd-{pid}" + mux = get_multiplexer(mux_name) + + if not mux: + mux_name, mux = create_multiplexer(mux_name, show_output=True) + + try: + start_time = time.time() + timeout_duration = timeout + stdout_content = "" + stderr_content = "" + + while True: + if process.poll() is not None: + remaining_stdout, remaining_stderr = process.communicate() + if remaining_stdout: + mux.write_stdout(remaining_stdout) + stdout_content += remaining_stdout + if remaining_stderr: + mux.write_stderr(remaining_stderr) + stderr_content += remaining_stderr + + if pid in _processes: + _processes.pop(pid) + + close_multiplexer(mux_name) + + return { + "status": "success", + "stdout": stdout_content, + "stderr": stderr_content, + "returncode": process.returncode, + } + + if time.time() - start_time > timeout_duration: + return { + "status": "running", + "message": "Process is still running. 
Call tail_process again to continue monitoring.", + "stdout_so_far": stdout_content, + "stderr_so_far": stderr_content, + "pid": pid, + } + + ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1) + for pipe in ready: + if pipe == process.stdout: + line = process.stdout.readline() + if line: + mux.write_stdout(line) + stdout_content += line + elif pipe == process.stderr: + line = process.stderr.readline() + if line: + mux.write_stderr(line) + stderr_content += line + except Exception as e: + return {"status": "error", "error": str(e)} + else: + return {"status": "error", "error": f"Process {pid} not found"} + + +def run_command(command, timeout=30, monitored=False): + mux_name = None + try: + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + _register_process(process.pid, process) + + mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True) + + start_time = time.time() + timeout_duration = timeout + stdout_content = "" + stderr_content = "" + + while True: + if process.poll() is not None: + remaining_stdout, remaining_stderr = process.communicate() + if remaining_stdout: + mux.write_stdout(remaining_stdout) + stdout_content += remaining_stdout + if remaining_stderr: + mux.write_stderr(remaining_stderr) + stderr_content += remaining_stderr + + if process.pid in _processes: + _processes.pop(process.pid) + + close_multiplexer(mux_name) + + return { + "status": "success", + "stdout": stdout_content, + "stderr": stderr_content, + "returncode": process.returncode, + } + + if time.time() - start_time > timeout_duration: + return { + "status": "running", + "message": f"Process still running after {timeout}s timeout. Use tail_process({process.pid}) to monitor or kill_process({process.pid}) to terminate.", + "stdout_so_far": stdout_content, + "stderr_so_far": stderr_content, + "pid": process.pid, + "mux_name": mux_name, + } + + ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1) + for pipe in ready: + if pipe == process.stdout: + line = process.stdout.readline() + if line: + mux.write_stdout(line) + stdout_content += line + elif pipe == process.stderr: + line = process.stderr.readline() + if line: + mux.write_stderr(line) + stderr_content += line + except Exception as e: + if mux_name: + close_multiplexer(mux_name) + return {"status": "error", "error": str(e)} + + +def run_command_interactive(command): + try: + return_code = os.system(command) + return {"status": "success", "returncode": return_code} + except Exception as e: + return {"status": "error", "error": str(e)} \ No newline at end of file diff --git a/pr/tools/context_modifier.py b/pr/tools/context_modifier.py new file mode 100644 index 0000000..4d62900 --- /dev/null +++ b/pr/tools/context_modifier.py @@ -0,0 +1,65 @@ +import os +from typing import Optional + +CONTEXT_FILE = '/home/retoor/.local/share/rp/.rcontext.txt' + +def _read_context() -> str: + if not os.path.exists(CONTEXT_FILE): + raise FileNotFoundError(f"Context file {CONTEXT_FILE} not found.") + with open(CONTEXT_FILE, 'r') as f: + return f.read() + +def _write_context(content: str): + with open(CONTEXT_FILE, 'w') as f: + f.write(content) + +def modify_context_add(new_content: str, position: Optional[str] = None) -> str: + """ + Add new content to the .rcontext.txt file. + + Args: + new_content: The content to add. + position: Optional marker to insert before (e.g., '***'). 
+ """ + current = _read_context() + if position and position in current: + # Insert before the position + parts = current.split(position, 1) + updated = parts[0] + new_content + '\n\n' + position + parts[1] + else: + # Append at the end + updated = current + '\n\n' + new_content + _write_context(updated) + return f"Added: {new_content[:100]}... (full addition applied). Consequences: Enhances functionality as requested." + +def modify_context_replace(old_content: str, new_content: str) -> str: + """ + Replace old content with new content in .rcontext.txt. + + Args: + old_content: The content to replace. + new_content: The replacement content. + """ + current = _read_context() + if old_content not in current: + raise ValueError(f"Old content not found: {old_content[:50]}...") + updated = current.replace(old_content, new_content, 1) # Replace first occurrence + _write_context(updated) + return f"Replaced: '{old_content[:50]}...' with '{new_content[:50]}...'. Consequences: Changes behavior as specified; verify for unintended effects." + +def modify_context_delete(content_to_delete: str, confirmed: bool = False) -> str: + """ + Delete content from .rcontext.txt, but only if confirmed. + + Args: + content_to_delete: The content to delete. + confirmed: Must be True to proceed with deletion. + """ + if not confirmed: + raise PermissionError(f"Deletion not confirmed. To delete '{content_to_delete[:50]}...', you must explicitly confirm. Are you sure? This may affect system behavior permanently.") + current = _read_context() + if content_to_delete not in current: + raise ValueError(f"Content to delete not found: {content_to_delete[:50]}...") + updated = current.replace(content_to_delete, '', 1) + _write_context(updated) + return f"Deleted: '{content_to_delete[:50]}...'. Consequences: Removed specified content; system may lose referenced rules or guidelines." diff --git a/pr/tools/database.py b/pr/tools/database.py index d1f6426..81dde83 100644 --- a/pr/tools/database.py +++ b/pr/tools/database.py @@ -2,6 +2,16 @@ import time def db_set(key, value, db_conn): + """Set a key-value pair in the database. + + Args: + key: The key to set. + value: The value to store. + db_conn: Database connection. + + Returns: + Dict with status and message. + """ if not db_conn: return {"status": "error", "error": "Database not initialized"} @@ -19,6 +29,15 @@ def db_set(key, value, db_conn): def db_get(key, db_conn): + """Get a value from the database. + + Args: + key: The key to retrieve. + db_conn: Database connection. + + Returns: + Dict with status and value. + """ if not db_conn: return {"status": "error", "error": "Database not initialized"} @@ -35,6 +54,15 @@ def db_get(key, db_conn): def db_query(query, db_conn): + """Execute a database query. + + Args: + query: SQL query to execute. + db_conn: Database connection. + + Returns: + Dict with status and query results. 
+ """ if not db_conn: return {"status": "error", "error": "Database not initialized"} diff --git a/pr/tools/debugging.py b/pr/tools/debugging.py new file mode 100644 index 0000000..02d6686 --- /dev/null +++ b/pr/tools/debugging.py @@ -0,0 +1,838 @@ +import sys +import os +import ast +import inspect +import time +import threading +import gc +import weakref +import linecache +import re +import json +import subprocess +from collections import defaultdict +from datetime import datetime + + +class MemoryTracker: + def __init__(self): + self.allocations = defaultdict(list) + self.references = weakref.WeakValueDictionary() + self.peak_memory = 0 + self.current_memory = 0 + + def track_object(self, obj, location): + try: + obj_id = id(obj) + obj_size = sys.getsizeof(obj) + self.allocations[location].append( + { + "id": obj_id, + "type": type(obj).__name__, + "size": obj_size, + "timestamp": time.time(), + } + ) + self.current_memory += obj_size + if self.current_memory > self.peak_memory: + self.peak_memory = self.current_memory + except: + pass + + def analyze_leaks(self): + gc.collect() + leaks = [] + for obj in gc.get_objects(): + if sys.getrefcount(obj) > 10: + try: + leaks.append( + { + "type": type(obj).__name__, + "refcount": sys.getrefcount(obj), + "size": sys.getsizeof(obj), + } + ) + except: + pass + return sorted(leaks, key=lambda x: x["refcount"], reverse=True)[:20] + + def get_report(self): + return { + "peak_memory": self.peak_memory, + "current_memory": self.current_memory, + "allocation_count": sum(len(v) for v in self.allocations.values()), + "leaks": self.analyze_leaks(), + } + + +class PerformanceProfiler: + def __init__(self): + self.function_times = defaultdict(lambda: {"calls": 0, "total_time": 0.0, "self_time": 0.0}) + self.call_stack = [] + self.start_times = {} + + def enter_function(self, frame): + func_name = self._get_function_name(frame) + self.call_stack.append(func_name) + self.start_times[id(frame)] = time.perf_counter() + + def exit_function(self, frame): + func_name = self._get_function_name(frame) + frame_id = id(frame) + + if frame_id in self.start_times: + elapsed = time.perf_counter() - self.start_times[frame_id] + self.function_times[func_name]["calls"] += 1 + self.function_times[func_name]["total_time"] += elapsed + self.function_times[func_name]["self_time"] += elapsed + del self.start_times[frame_id] + + if self.call_stack: + self.call_stack.pop() + + def _get_function_name(self, frame): + return f"{frame.f_code.co_filename}:{frame.f_code.co_name}:{frame.f_lineno}" + + def get_hotspots(self, limit=20): + sorted_funcs = sorted( + self.function_times.items(), key=lambda x: x[1]["total_time"], reverse=True + ) + return sorted_funcs[:limit] + + +class StaticAnalyzer(ast.NodeVisitor): + def __init__(self): + self.issues = [] + self.complexity = 0 + self.unused_vars = set() + self.undefined_vars = set() + self.defined_vars = set() + self.used_vars = set() + self.functions = {} + self.classes = {} + self.imports = [] + + def visit_FunctionDef(self, node): + self.functions[node.name] = { + "lineno": node.lineno, + "args": [arg.arg for arg in node.args.args], + "decorators": [ + d.id if isinstance(d, ast.Name) else "complex" for d in node.decorator_list + ], + "complexity": self._calculate_complexity(node), + } + + if len(node.body) == 0: + self.issues.append(f"Line {node.lineno}: Empty function '{node.name}'") + + if len(node.args.args) > 7: + self.issues.append( + f"Line {node.lineno}: Function '{node.name}' has too many parameters ({len(node.args.args)})" + ) + 
+ self.generic_visit(node) + + def visit_ClassDef(self, node): + self.classes[node.name] = { + "lineno": node.lineno, + "bases": [b.id if isinstance(b, ast.Name) else "complex" for b in node.bases], + "methods": [], + } + self.generic_visit(node) + + def visit_Import(self, node): + for alias in node.names: + self.imports.append(alias.name) + self.generic_visit(node) + + def visit_ImportFrom(self, node): + if node.module: + self.imports.append(node.module) + self.generic_visit(node) + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Store): + self.defined_vars.add(node.id) + elif isinstance(node.ctx, ast.Load): + self.used_vars.add(node.id) + self.generic_visit(node) + + def visit_If(self, node): + self.complexity += 1 + self.generic_visit(node) + + def visit_For(self, node): + self.complexity += 1 + self.generic_visit(node) + + def visit_While(self, node): + self.complexity += 1 + self.generic_visit(node) + + def visit_ExceptHandler(self, node): + self.complexity += 1 + if node.type is None: + self.issues.append(f"Line {node.lineno}: Bare except clause (catches all exceptions)") + self.generic_visit(node) + + def visit_BinOp(self, node): + if isinstance(node.op, ast.Add): + if isinstance(node.left, ast.Str) or isinstance(node.right, ast.Str): + self.issues.append( + f"Line {node.lineno}: String concatenation with '+' (use f-strings or join)" + ) + self.generic_visit(node) + + def _calculate_complexity(self, node): + complexity = 1 + for child in ast.walk(node): + if isinstance(child, (ast.If, ast.For, ast.While, ast.ExceptHandler)): + complexity += 1 + return complexity + + def finalize(self): + self.unused_vars = self.defined_vars - self.used_vars + self.undefined_vars = self.used_vars - self.defined_vars - set(dir(__builtins__)) + + for var in self.unused_vars: + self.issues.append(f"Unused variable: '{var}'") + + def analyze_code(self, source_code): + try: + tree = ast.parse(source_code) + self.visit(tree) + self.finalize() + return { + "issues": self.issues, + "complexity": self.complexity, + "functions": self.functions, + "classes": self.classes, + "imports": self.imports, + "unused_vars": list(self.unused_vars), + } + except SyntaxError as e: + return {"error": f"Syntax error at line {e.lineno}: {e.msg}"} + + +class DynamicTracer: + def __init__(self): + self.execution_trace = [] + self.exception_trace = [] + self.variable_changes = defaultdict(list) + self.line_coverage = set() + self.function_calls = defaultdict(int) + self.max_trace_length = 10000 + self.memory_tracker = MemoryTracker() + self.profiler = PerformanceProfiler() + self.trace_active = False + + def trace_calls(self, frame, event, arg): + if not self.trace_active: + return + + if len(self.execution_trace) >= self.max_trace_length: + return self.trace_calls + + co = frame.f_code + func_name = co.co_name + filename = co.co_filename + line_no = frame.f_lineno + + if "site-packages" in filename or filename.startswith("<"): + return self.trace_calls + + trace_entry = { + "event": event, + "function": func_name, + "filename": filename, + "lineno": line_no, + "timestamp": time.time(), + } + + if event == "call": + self.function_calls[f"{filename}:{func_name}"] += 1 + self.profiler.enter_function(frame) + trace_entry["locals"] = { + k: repr(v)[:100] for k, v in frame.f_locals.items() if not k.startswith("__") + } + + elif event == "return": + self.profiler.exit_function(frame) + trace_entry["return_value"] = repr(arg)[:100] if arg else None + + elif event == "line": + self.line_coverage.add((filename, line_no)) + 
line_code = linecache.getline(filename, line_no).strip() + trace_entry["code"] = line_code + + for var, value in frame.f_locals.items(): + if not var.startswith("__"): + self.variable_changes[var].append( + {"line": line_no, "value": repr(value)[:100], "timestamp": time.time()} + ) + self.memory_tracker.track_object(value, f"{filename}:{line_no}") + + elif event == "exception": + exc_type, exc_value, exc_tb = arg + self.exception_trace.append( + { + "type": exc_type.__name__, + "message": str(exc_value), + "filename": filename, + "function": func_name, + "lineno": line_no, + "timestamp": time.time(), + } + ) + trace_entry["exception"] = {"type": exc_type.__name__, "message": str(exc_value)} + + self.execution_trace.append(trace_entry) + return self.trace_calls + + def start_tracing(self): + self.trace_active = True + sys.settrace(self.trace_calls) + threading.settrace(self.trace_calls) + + def stop_tracing(self): + self.trace_active = False + sys.settrace(None) + threading.settrace(None) + + def get_trace_report(self): + return { + "execution_trace": self.execution_trace[-100:], + "exception_trace": self.exception_trace, + "line_coverage": list(self.line_coverage), + "function_calls": dict(self.function_calls), + "variable_changes": {k: v[-10:] for k, v in self.variable_changes.items()}, + "hotspots": self.profiler.get_hotspots(20), + "memory_report": self.memory_tracker.get_report(), + } + + +class GitBisectAutomator: + def __init__(self, repo_path="."): + self.repo_path = repo_path + + def is_git_repo(self): + try: + result = subprocess.run( + ["git", "rev-parse", "--git-dir"], + cwd=self.repo_path, + capture_output=True, + text=True, + ) + return result.returncode == 0 + except: + return False + + def get_commit_history(self, limit=50): + try: + result = subprocess.run( + ["git", "log", f"--max-count={limit}", "--oneline"], + cwd=self.repo_path, + capture_output=True, + text=True, + ) + if result.returncode == 0: + commits = [] + for line in result.stdout.strip().split("\n"): + parts = line.split(" ", 1) + if len(parts) == 2: + commits.append({"hash": parts[0], "message": parts[1]}) + return commits + except: + pass + return [] + + def blame_file(self, filepath): + try: + result = subprocess.run( + ["git", "blame", "-L", "1,50", filepath], + cwd=self.repo_path, + capture_output=True, + text=True, + ) + if result.returncode == 0: + return result.stdout + except: + pass + return None + + +class LogAnalyzer: + def __init__(self): + self.log_patterns = { + "error": re.compile(r"error|exception|fail|critical", re.IGNORECASE), + "warning": re.compile(r"warn|caution", re.IGNORECASE), + "debug": re.compile(r"debug|trace", re.IGNORECASE), + "timestamp": re.compile(r"\\d{4}-\\d{2}-\\d{2}[\\s_T]\\d{2}:\\d{2}:\\d{2}"), + } + self.anomalies = [] + + def analyze_logs(self, log_content): + lines = log_content.split("\n") + errors = [] + warnings = [] + timestamps = [] + + for i, line in enumerate(lines): + if self.log_patterns["error"].search(line): + errors.append({"line": i + 1, "content": line}) + elif self.log_patterns["warning"].search(line): + warnings.append({"line": i + 1, "content": line}) + + ts_match = self.log_patterns["timestamp"].search(line) + if ts_match: + timestamps.append(ts_match.group()) + + return { + "total_lines": len(lines), + "errors": errors[:50], + "warnings": warnings[:50], + "error_count": len(errors), + "warning_count": len(warnings), + "timestamps": timestamps[:20], + } + + +class ExceptionAnalyzer: + def __init__(self): + self.exceptions = [] + self.exception_counts = 
defaultdict(int) + + def capture_exception(self, exc_type, exc_value, exc_traceback): + exc_info = { + "type": exc_type.__name__, + "message": str(exc_value), + "timestamp": time.time(), + "traceback": [], + } + + tb = exc_traceback + while tb is not None: + frame = tb.tb_frame + exc_info["traceback"].append( + { + "filename": frame.f_code.co_filename, + "function": frame.f_code.co_name, + "lineno": tb.tb_lineno, + "locals": { + k: repr(v)[:100] + for k, v in frame.f_locals.items() + if not k.startswith("__") + }, + } + ) + tb = tb.tb_next + + self.exceptions.append(exc_info) + self.exception_counts[exc_type.__name__] += 1 + return exc_info + + def get_report(self): + return { + "total_exceptions": len(self.exceptions), + "exception_types": dict(self.exception_counts), + "recent_exceptions": self.exceptions[-10:], + } + + +class TestGenerator: + def __init__(self): + self.test_cases = [] + + def generate_tests_for_function(self, func_name, func_signature): + test_template = f"""def test_{func_name}_basic(): + result = {func_name}() + assert result is not None + +def test_{func_name}_edge_cases(): + pass + +def test_{func_name}_exceptions(): + pass +""" + return test_template + + def analyze_function_for_tests(self, func_obj): + sig = inspect.signature(func_obj) + test_inputs = [] + + for param_name, param in sig.parameters.items(): + if param.annotation != inspect.Parameter.empty: + param_type = param.annotation + if param_type == int: + test_inputs.append([0, 1, -1, 100]) + elif param_type == str: + test_inputs.append(["", "test", "a" * 100]) + elif param_type == list: + test_inputs.append([[], [1], [1, 2, 3]]) + else: + test_inputs.append([None]) + else: + test_inputs.append([None, 0, "", []]) + + return test_inputs + + +class CodeFlowVisualizer: + def __init__(self): + self.flow_graph = defaultdict(list) + self.call_hierarchy = defaultdict(set) + + def build_flow_from_trace(self, execution_trace): + for i in range(len(execution_trace) - 1): + current = execution_trace[i] + next_step = execution_trace[i + 1] + + current_node = f"{current['function']}:{current['lineno']}" + next_node = f"{next_step['function']}:{next_step['lineno']}" + + self.flow_graph[current_node].append(next_node) + + if current["event"] == "call": + caller = current["function"] + callee = next_step["function"] + self.call_hierarchy[caller].add(callee) + + def generate_text_visualization(self): + output = [] + output.append("Call Hierarchy:") + for caller, callees in sorted(self.call_hierarchy.items()): + output.append(f" {caller}") + for callee in sorted(callees): + output.append(f" -> {callee}") + return "\n".join(output) + + +class AutomatedDebugger: + def __init__(self): + self.static_analyzer = StaticAnalyzer() + self.dynamic_tracer = DynamicTracer() + self.exception_analyzer = ExceptionAnalyzer() + self.log_analyzer = LogAnalyzer() + self.git_automator = GitBisectAutomator() + self.test_generator = TestGenerator() + self.flow_visualizer = CodeFlowVisualizer() + self.original_excepthook = sys.excepthook + + def analyze_source_file(self, filepath): + try: + with open(filepath, "r", encoding="utf-8") as f: + source_code = f.read() + return self.static_analyzer.analyze_code(source_code) + except Exception as e: + return {"error": str(e)} + + def run_with_tracing(self, target_func, *args, **kwargs): + self.dynamic_tracer.start_tracing() + result = None + exception = None + + try: + result = target_func(*args, **kwargs) + except Exception as e: + exception = self.exception_analyzer.capture_exception(type(e), e, 
e.__traceback__) + finally: + self.dynamic_tracer.stop_tracing() + + self.flow_visualizer.build_flow_from_trace(self.dynamic_tracer.execution_trace) + + return { + "result": result, + "exception": exception, + "trace": self.dynamic_tracer.get_trace_report(), + "flow": self.flow_visualizer.generate_text_visualization(), + } + + def analyze_logs(self, log_file_or_content): + if os.path.isfile(log_file_or_content): + with open(log_file_or_content, "r", encoding="utf-8") as f: + content = f.read() + else: + content = log_file_or_content + + return self.log_analyzer.analyze_logs(content) + + def generate_debug_report(self, output_file="debug_report.json"): + report = { + "timestamp": datetime.now().isoformat(), + "static_analysis": self.static_analyzer.issues, + "dynamic_trace": self.dynamic_tracer.get_trace_report(), + "exceptions": self.exception_analyzer.get_report(), + "git_info": ( + self.git_automator.get_commit_history(10) + if self.git_automator.is_git_repo() + else None + ), + "flow_visualization": self.flow_visualizer.generate_text_visualization(), + } + + with open(output_file, "w") as f: + json.dump(report, f, indent=2, default=str) + + return report + + def auto_debug_function(self, func, test_inputs=None): + results = [] + + if test_inputs is None: + test_inputs = self.test_generator.analyze_function_for_tests(func) + + for input_set in test_inputs: + try: + if isinstance(input_set, list): + result = self.run_with_tracing(func, *input_set) + else: + result = self.run_with_tracing(func, input_set) + results.append( + { + "input": input_set, + "success": result["exception"] is None, + "output": result["result"], + "trace_summary": { + "function_calls": len(result["trace"]["function_calls"]), + "exceptions": len(result["trace"]["exception_trace"]), + }, + } + ) + except Exception as e: + results.append({"input": input_set, "success": False, "error": str(e)}) + + return results + + +# Global instances for tools +_memory_tracker = MemoryTracker() +_performance_profiler = PerformanceProfiler() +_static_analyzer = StaticAnalyzer() +_dynamic_tracer = DynamicTracer() +_git_automator = GitBisectAutomator() +_log_analyzer = LogAnalyzer() +_exception_analyzer = ExceptionAnalyzer() +_test_generator = TestGenerator() +_code_flow_visualizer = CodeFlowVisualizer() +_automated_debugger = AutomatedDebugger() + + +# Tool functions +def track_memory_allocation(location: str = "manual") -> dict: + """Track current memory allocation at a specific location.""" + try: + _memory_tracker.track_object({}, location) + return { + "status": "success", + "current_memory": _memory_tracker.current_memory, + "peak_memory": _memory_tracker.peak_memory, + "location": location, + } + except Exception as e: + return {"status": "error", "error": str(e)} + + +def analyze_memory_leaks() -> dict: + """Analyze potential memory leaks in the current process.""" + try: + leaks = _memory_tracker.analyze_leaks() + return {"status": "success", "leaks_found": len(leaks), "top_leaks": leaks[:10]} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def get_memory_report() -> dict: + """Get a comprehensive memory usage report.""" + try: + return {"status": "success", "report": _memory_tracker.get_report()} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def start_performance_profiling() -> dict: + """Start performance profiling.""" + try: + PerformanceProfiler() + return {"status": "success", "message": "Performance profiling started"} + except Exception as e: + return {"status": 
"error", "error": str(e)} + + +def stop_performance_profiling() -> dict: + """Stop performance profiling and get hotspots.""" + try: + hotspots = _performance_profiler.get_hotspots(20) + return {"status": "success", "hotspots": hotspots} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def analyze_source_code(source_code: str) -> dict: + """Perform static analysis on Python source code.""" + try: + analyzer = StaticAnalyzer() + result = analyzer.analyze_code(source_code) + return {"status": "success", "analysis": result} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def analyze_source_file(filepath: str) -> dict: + """Analyze a Python source file statically.""" + try: + result = _automated_debugger.analyze_source_file(filepath) + return {"status": "success", "filepath": filepath, "analysis": result} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def start_dynamic_tracing() -> dict: + """Start dynamic execution tracing.""" + try: + _dynamic_tracer.start_tracing() + return {"status": "success", "message": "Dynamic tracing started"} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def stop_dynamic_tracing() -> dict: + """Stop dynamic tracing and get trace report.""" + try: + _dynamic_tracer.stop_tracing() + report = _dynamic_tracer.get_trace_report() + return {"status": "success", "trace_report": report} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def get_git_commit_history(limit: int = 50) -> dict: + """Get recent git commit history.""" + try: + commits = _git_automator.get_commit_history(limit) + return { + "status": "success", + "commits": commits, + "is_git_repo": _git_automator.is_git_repo(), + } + except Exception as e: + return {"status": "error", "error": str(e)} + + +def blame_file(filepath: str) -> dict: + """Get git blame information for a file.""" + try: + blame_output = _git_automator.blame_file(filepath) + return {"status": "success", "filepath": filepath, "blame": blame_output} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def analyze_log_content(log_content: str) -> dict: + """Analyze log content for errors, warnings, and patterns.""" + try: + analysis = _log_analyzer.analyze_logs(log_content) + return {"status": "success", "analysis": analysis} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def analyze_log_file(filepath: str) -> dict: + """Analyze a log file.""" + try: + analysis = _automated_debugger.analyze_logs(filepath) + return {"status": "success", "filepath": filepath, "analysis": analysis} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def get_exception_report() -> dict: + """Get a report of captured exceptions.""" + try: + report = _exception_analyzer.get_report() + return {"status": "success", "exception_report": report} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def generate_tests_for_function(func_name: str, func_signature: str = "") -> dict: + """Generate test templates for a function.""" + try: + test_code = _test_generator.generate_tests_for_function(func_name, func_signature) + return {"status": "success", "func_name": func_name, "test_code": test_code} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def visualize_code_flow_from_trace(execution_trace) -> dict: + """Visualize code flow from execution trace.""" + try: + visualizer = CodeFlowVisualizer() + 
visualizer.build_flow_from_trace(execution_trace) + visualization = visualizer.generate_text_visualization() + return {"status": "success", "flow_visualization": visualization} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def run_function_with_debugging(func_code: str, *args, **kwargs) -> dict: + """Execute a function with full debugging.""" + try: + # Compile and execute the function + local_vars = {} + exec(func_code, globals(), local_vars) + + # Find the function (assuming it's the last defined function) + func = None + for name, obj in local_vars.items(): + if callable(obj) and not name.startswith("_"): + func = obj + break + + if func is None: + return {"status": "error", "error": "No function found in code"} + + result = _automated_debugger.run_with_tracing(func, *args, **kwargs) + return {"status": "success", "debug_result": result} + except Exception as e: + return {"status": "error", "error": str(e)} + + +def generate_comprehensive_debug_report(output_file: str = "debug_report.json") -> dict: + """Generate a comprehensive debug report.""" + try: + report = _automated_debugger.generate_debug_report(output_file) + return { + "status": "success", + "output_file": output_file, + "report_summary": { + "static_issues": len(report.get("static_analysis", [])), + "exceptions": report.get("exceptions", {}).get("total_exceptions", 0), + "function_calls": len(report.get("dynamic_trace", {}).get("function_calls", {})), + }, + } + except Exception as e: + return {"status": "error", "error": str(e)} + + +def auto_debug_function(func_code: str, test_inputs: list = None) -> dict: + """Automatically debug a function with test inputs.""" + try: + local_vars = {} + exec(func_code, globals(), local_vars) + + func = None + for name, obj in local_vars.items(): + if callable(obj) and not name.startswith("_"): + func = obj + break + + if func is None: + return {"status": "error", "error": "No function found in code"} + + if test_inputs is None: + test_inputs = _test_generator.analyze_function_for_tests(func) + + results = _automated_debugger.auto_debug_function(func, test_inputs) + return {"status": "success", "debug_results": results} + except Exception as e: + return {"status": "error", "error": str(e)} diff --git a/pr/tools/editor.py b/pr/tools/editor.py index fc8135b..1112eeb 100644 --- a/pr/tools/editor.py +++ b/pr/tools/editor.py @@ -2,7 +2,7 @@ import os import os.path from pr.editor import RPEditor -from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer + from ..tools.patch import display_content_diff from ..ui.edit_feedback import track_edit, tracker @@ -17,6 +17,8 @@ def get_editor(filepath): def close_editor(filepath): + from pr.multiplexer import close_multiplexer, get_multiplexer + try: path = os.path.expanduser(filepath) editor = get_editor(path) @@ -34,6 +36,8 @@ def close_editor(filepath): def open_editor(filepath): + from pr.multiplexer import create_multiplexer + try: path = os.path.expanduser(filepath) editor = RPEditor(path) @@ -53,6 +57,8 @@ def open_editor(filepath): def editor_insert_text(filepath, text, line=None, col=None, show_diff=True): + from pr.multiplexer import get_multiplexer + try: path = os.path.expanduser(filepath) @@ -98,6 +104,8 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True): def editor_replace_text( filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True ): + from pr.multiplexer import get_multiplexer + try: path = os.path.expanduser(filepath) @@ -148,6 
+156,8 @@ def editor_replace_text( def editor_search(filepath, pattern, start_line=0): + from pr.multiplexer import get_multiplexer + try: path = os.path.expanduser(filepath) editor = RPEditor(path) diff --git a/pr/tools/filesystem.py b/pr/tools/filesystem.py index fcab959..4b53fd9 100644 --- a/pr/tools/filesystem.py +++ b/pr/tools/filesystem.py @@ -1,6 +1,7 @@ import hashlib import os import time +from typing import Optional, Any from pr.editor import RPEditor @@ -17,7 +18,17 @@ def get_uid(): return _id -def read_file(filepath, db_conn=None): +def read_file(filepath: str, db_conn: Optional[Any] = None) -> dict: + """ + Read the contents of a file. + + Args: + filepath: Path to the file to read + db_conn: Optional database connection for tracking + + Returns: + dict: Status and content or error + """ try: path = os.path.expanduser(filepath) with open(path) as f: @@ -31,7 +42,22 @@ def read_file(filepath, db_conn=None): return {"status": "error", "error": str(e)} -def write_file(filepath, content, db_conn=None, show_diff=True): +def write_file( + filepath: str, content: str, db_conn: Optional[Any] = None, show_diff: bool = True +) -> dict: + """ + Write content to a file. + + Args: + filepath: Path to the file to write + content: Content to write + db_conn: Optional database connection for tracking + show_diff: Whether to show diff of changes + + Returns: + dict: Status and message or error + """ + operation = None try: path = os.path.expanduser(filepath) old_content = "" @@ -93,12 +119,13 @@ def write_file(filepath, content, db_conn=None, show_diff=True): return {"status": "success", "message": message} except Exception as e: - if "operation" in locals(): + if operation is not None: tracker.mark_failed(operation) return {"status": "error", "error": str(e)} def list_directory(path=".", recursive=False): + """List files and directories in the specified path.""" try: path = os.path.expanduser(path) items = [] @@ -139,6 +166,7 @@ def mkdir(path): def chdir(path): + """Change the current working directory.""" try: os.chdir(os.path.expanduser(path)) return {"status": "success", "new_path": os.getcwd()} @@ -147,13 +175,23 @@ def chdir(path): def getpwd(): + """Get the current working directory.""" try: return {"status": "success", "path": os.getcwd()} except Exception as e: return {"status": "error", "error": str(e)} -def index_source_directory(path): +def index_source_directory(path: str) -> dict: + """ + Index directory recursively and read all source files. + + Args: + path: Path to index + + Returns: + dict: Status and indexed files or error + """ extensions = [ ".py", ".js", @@ -189,7 +227,21 @@ def index_source_directory(path): return {"status": "error", "error": str(e)} -def search_replace(filepath, old_string, new_string, db_conn=None): +def search_replace( + filepath: str, old_string: str, new_string: str, db_conn: Optional[Any] = None +) -> dict: + """ + Search and replace text in a file. 
+ + Args: + filepath: Path to the file + old_string: String to replace + new_string: Replacement string + db_conn: Optional database connection for tracking + + Returns: + dict: Status and message or error + """ try: path = os.path.expanduser(filepath) if not os.path.exists(path): @@ -246,6 +298,7 @@ def open_editor(filepath): def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None): + operation = None try: path = os.path.expanduser(filepath) if db_conn: @@ -283,7 +336,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c tracker.mark_completed(operation) return {"status": "success", "message": f"Inserted text in {path}"} except Exception as e: - if "operation" in locals(): + if operation is not None: tracker.mark_failed(operation) return {"status": "error", "error": str(e)} @@ -299,6 +352,7 @@ def editor_replace_text( db_conn=None, ): try: + operation = None path = os.path.expanduser(filepath) if db_conn: from pr.tools.database import db_get @@ -341,7 +395,7 @@ def editor_replace_text( tracker.mark_completed(operation) return {"status": "success", "message": f"Replaced text in {path}"} except Exception as e: - if "operation" in locals(): + if operation is not None: tracker.mark_failed(operation) return {"status": "error", "error": str(e)} diff --git a/pr/tools/interactive_control.py b/pr/tools/interactive_control.py index 3218228..7feb49d 100644 --- a/pr/tools/interactive_control.py +++ b/pr/tools/interactive_control.py @@ -1,12 +1,17 @@ import subprocess import threading +import importlib -from pr.multiplexer import ( - close_multiplexer, - create_multiplexer, - get_all_multiplexer_states, - get_multiplexer, -) + +def _get_multiplexer_functions(): + """Lazy import multiplexer functions to avoid circular imports.""" + multiplexer = importlib.import_module("pr.multiplexer") + return { + "create_multiplexer": multiplexer.create_multiplexer, + "get_multiplexer": multiplexer.get_multiplexer, + "close_multiplexer": multiplexer.close_multiplexer, + "get_all_multiplexer_states": multiplexer.get_all_multiplexer_states, + } def start_interactive_session(command, session_name=None, process_type="generic", cwd=None): @@ -22,7 +27,8 @@ def start_interactive_session(command, session_name=None, process_type="generic" Returns: session_name: The name of the created session """ - name, mux = create_multiplexer(session_name) + funcs = _get_multiplexer_functions() + name, mux = funcs["create_multiplexer"](session_name) mux.update_metadata("process_type", process_type) # Start the process @@ -65,7 +71,7 @@ def start_interactive_session(command, session_name=None, process_type="generic" return name except Exception as e: - close_multiplexer(name) + funcs["close_multiplexer"](name) raise e @@ -97,7 +103,8 @@ def send_input_to_session(session_name, input_data): session_name: Name of the session input_data: Input string to send """ - mux = get_multiplexer(session_name) + funcs = _get_multiplexer_functions() + mux = funcs["get_multiplexer"](session_name) if not mux: raise ValueError(f"Session {session_name} not found") @@ -112,6 +119,7 @@ def send_input_to_session(session_name, input_data): def read_session_output(session_name, lines=None): + funcs = _get_multiplexer_functions() """ Read output from a session. 
@@ -122,7 +130,7 @@ def read_session_output(session_name, lines=None): Returns: dict: {'stdout': str, 'stderr': str} """ - mux = get_multiplexer(session_name) + mux = funcs["get_multiplexer"](session_name) if not mux: raise ValueError(f"Session {session_name} not found") @@ -142,7 +150,8 @@ def list_active_sessions(): Returns: dict: Session states """ - return get_all_multiplexer_states() + funcs = _get_multiplexer_functions() + return funcs["get_all_multiplexer_states"]() def get_session_status(session_name): @@ -155,7 +164,8 @@ def get_session_status(session_name): Returns: dict: Session metadata and status """ - mux = get_multiplexer(session_name) + funcs = _get_multiplexer_functions() + mux = funcs["get_multiplexer"](session_name) if not mux: return None @@ -175,10 +185,11 @@ def close_interactive_session(session_name): Close an interactive session. """ try: - mux = get_multiplexer(session_name) + funcs = _get_multiplexer_functions() + mux = funcs["get_multiplexer"](session_name) if mux: mux.process.kill() - close_multiplexer(session_name) + funcs["close_multiplexer"](session_name) return {"status": "success"} except Exception as e: return {"status": "error", "error": str(e)} diff --git a/pr/tools/patch.py b/pr/tools/patch.py index 98a6d22..d0dac62 100644 --- a/pr/tools/patch.py +++ b/pr/tools/patch.py @@ -7,6 +7,16 @@ from ..ui.diff_display import display_diff, get_diff_stats def apply_patch(filepath, patch_content, db_conn=None): + """Apply a patch to a file. + + Args: + filepath: Path to the file to patch. + patch_content: The patch content as a string. + db_conn: Database connection (optional). + + Returns: + Dict with status and output. + """ try: path = os.path.expanduser(filepath) if db_conn: @@ -41,6 +51,19 @@ def apply_patch(filepath, patch_content, db_conn=None): def create_diff( file1, file2, fromfile="file1", tofile="file2", visual=False, format_type="unified" ): + """Create a unified diff between two files. + + Args: + file1: Path to the first file. + file2: Path to the second file. + fromfile: Label for the first file. + tofile: Label for the second file. + visual: Whether to include visual diff. + format_type: Diff format type. + + Returns: + Dict with status and diff content. + """ try: path1 = os.path.expanduser(file1) path2 = os.path.expanduser(file2) diff --git a/pr/tools/python_exec.py b/pr/tools/python_exec.py index 3b8a97f..aa242cc 100644 --- a/pr/tools/python_exec.py +++ b/pr/tools/python_exec.py @@ -5,6 +5,16 @@ from io import StringIO def python_exec(code, python_globals, cwd=None): + """Execute Python code and capture the output. + + Args: + code: The Python code to execute. + python_globals: Dictionary of global variables for execution. + cwd: Working directory for execution. + + Returns: + Dict with status and output, or error information. + """ try: original_cwd = None if cwd: diff --git a/pr/tools/vision.py b/pr/tools/vision.py new file mode 100644 index 0000000..0bc383c --- /dev/null +++ b/pr/tools/vision.py @@ -0,0 +1,19 @@ +from pr.vision import post_image as vision_post_image +import functools + + +@functools.lru_cache() +def post_image(path: str, prompt: str = None): + """Post an image for analysis. + + Args: + path: Path to the image file. + prompt: Optional prompt for analysis. + + Returns: + Analysis result. 
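+
+    Note: results are memoized by functools.lru_cache, so repeated calls with
+    the same path and prompt return the cached result.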
+    """
+    try:
+        return vision_post_image(image_path=path, prompt=prompt)
+    except Exception:
+        raise
diff --git a/pr/tools/web.py b/pr/tools/web.py
index d3d23a7..bc509db 100644
--- a/pr/tools/web.py
+++ b/pr/tools/web.py
@@ -5,6 +5,15 @@ import urllib.request
 
 
 def http_fetch(url, headers=None):
+    """Fetch content from an HTTP URL.
+
+    Args:
+        url: The URL to fetch.
+        headers: Optional HTTP headers.
+
+    Returns:
+        Dict with status and content.
+    """
     try:
         req = urllib.request.Request(url)
         if headers:
@@ -30,10 +39,26 @@ def _perform_search(base_url, query, params=None):
 
 
 def web_search(query):
+    """Perform a web search.
+
+    Args:
+        query: Search query.
+
+    Returns:
+        Dict with status and search results.
+    """
     base_url = "https://search.molodetz.nl/search"
     return _perform_search(base_url, query)
 
 
 def web_search_news(query):
+    """Perform a web search for news.
+
+    Args:
+        query: Search query for news.
+
+    Returns:
+        Dict with status and news search results.
+    """
     base_url = "https://search.molodetz.nl/search"
     return _perform_search(base_url, query)
diff --git a/pr/vision.py b/pr/vision.py
new file mode 100755
index 0000000..1702fca
--- /dev/null
+++ b/pr/vision.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+
+import http.client
+import argparse
+
+import base64
+import json
+import pathlib
+
+DEFAULT_URL = "https://static.molodetz.nl/rp.vision.cgi"
+
+
+def post_image(image_path: str, prompt: str = "", url: str = DEFAULT_URL):
+    image_path = str(pathlib.Path(image_path).resolve().absolute())
+    if not url:
+        url = DEFAULT_URL
+    url_parts = url.split("/")
+    host = url_parts[2]
+    path = "/" + "/".join(url_parts[3:])
+
+    with open(image_path, "rb") as file:
+        image_data = file.read()
+    base64_data = base64.b64encode(image_data).decode("utf-8")
+
+    payload = {"data": base64_data, "path": image_path, "prompt": prompt}
+    body = json.dumps(payload).encode("utf-8")
+
+    headers = {
+        "Content-Type": "application/json",
+        "Content-Length": str(len(body)),
+        "User-Agent": "Python http.client",
+    }
+
+    conn = http.client.HTTPSConnection(host)
+    conn.request("POST", path, body, headers)
+    resp = conn.getresponse()
+    data = resp.read()
+    print("Status:", resp.status, resp.reason)
+    print(data.decode())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("image_path")
+    parser.add_argument("--prompt", default="")
+    parser.add_argument("--url", default=DEFAULT_URL)
+    args = parser.parse_args()
+    post_image(args.image_path, prompt=args.prompt, url=args.url)
diff --git a/pyproject.toml b/pyproject.toml
index 2d57948..eaa3493 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "rp"
-version = "1.2.0"
+version = "1.3.0"
 description = "R python edition. The ultimate autonomous AI CLI."
readme = "README.md" -requires-python = ">=3.13.3" +requires-python = ">=3.12" license = {text = "MIT"} keywords = ["ai", "assistant", "cli", "automation", "openrouter", "autonomous"] authors = [ @@ -16,6 +16,7 @@ authors = [ dependencies = [ "aiohttp>=3.13.2", "pydantic>=2.12.3", + "prompt_toolkit>=3.0.0", ] classifiers = [ "Development Status :: 4 - Beta", diff --git a/rp.py b/rp.py index 68021e6..0d7b96e 100755 --- a/rp.py +++ b/rp.py @@ -2,6 +2,12 @@ # Trigger build +import sys +import os + +# Add current directory to path to ensure imports work +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from pr.__main__ import main if __name__ == "__main__": diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 4e62549..433fce0 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -1,10 +1,8 @@ -from unittest.mock import mock_open, patch +from unittest.mock import patch from pr.core.config_loader import ( _load_config_file, _parse_value, - create_default_config, - load_config, ) @@ -36,36 +34,3 @@ def test_parse_value_bool_upper(): def test_load_config_file_not_exists(mock_exists): config = _load_config_file("test.ini") assert config == {} - - -@patch("os.path.exists", return_value=True) -@patch("configparser.ConfigParser") -def test_load_config_file_exists(mock_parser_class, mock_exists): - mock_parser = mock_parser_class.return_value - mock_parser.sections.return_value = ["api"] - mock_parser.items.return_value = [("key", "value")] - config = _load_config_file("test.ini") - assert "api" in config - assert config["api"]["key"] == "value" - - -@patch("pr.core.config_loader._load_config_file") -def test_load_config(mock_load): - mock_load.side_effect = [{"api": {"key": "global"}}, {"api": {"key": "local"}}] - config = load_config() - assert config["api"]["key"] == "local" - - -@patch("builtins.open", new_callable=mock_open) -def test_create_default_config(mock_file): - result = create_default_config("test.ini") - assert result == True - mock_file.assert_called_once_with("test.ini", "w") - handle = mock_file() - handle.write.assert_called_once() - - -@patch("builtins.open", side_effect=Exception("error")) -def test_create_default_config_error(mock_file): - result = create_default_config("test.ini") - assert result == False diff --git a/tests/test_conversation_memory.py b/tests/test_conversation_memory.py new file mode 100644 index 0000000..b4378f5 --- /dev/null +++ b/tests/test_conversation_memory.py @@ -0,0 +1,372 @@ +import sqlite3 +import tempfile +import os +import time + +from pr.memory.conversation_memory import ConversationMemory + + +class TestConversationMemory: + def setup_method(self): + """Set up test database for each test.""" + self.db_fd, self.db_path = tempfile.mkstemp() + self.memory = ConversationMemory(self.db_path) + + def teardown_method(self): + """Clean up test database after each test.""" + self.memory = None + os.close(self.db_fd) + os.unlink(self.db_path) + + def test_init(self): + """Test ConversationMemory initialization.""" + assert self.memory.db_path == self.db_path + # Verify tables were created + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + assert "conversation_history" in tables + assert "conversation_messages" in tables + conn.close() + + def test_create_conversation(self): + """Test creating a new conversation.""" + conversation_id = "test_conv_123" + session_id = 
"test_session_456" + + self.memory.create_conversation(conversation_id, session_id) + + # Verify conversation was created + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?", + (conversation_id,), + ) + row = cursor.fetchone() + assert row[0] == conversation_id + assert row[1] == session_id + conn.close() + + def test_create_conversation_without_session(self): + """Test creating a conversation without session ID.""" + conversation_id = "test_conv_no_session" + + self.memory.create_conversation(conversation_id) + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?", + (conversation_id,), + ) + row = cursor.fetchone() + assert row[0] == conversation_id + assert row[1] is None + conn.close() + + def test_create_conversation_with_metadata(self): + """Test creating a conversation with metadata.""" + conversation_id = "test_conv_metadata" + metadata = {"topic": "test", "priority": "high"} + + self.memory.create_conversation(conversation_id, metadata=metadata) + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + "SELECT conversation_id, metadata FROM conversation_history WHERE conversation_id = ?", + (conversation_id,), + ) + row = cursor.fetchone() + assert row[0] == conversation_id + assert row[1] is not None + conn.close() + + def test_add_message(self): + """Test adding a message to a conversation.""" + conversation_id = "test_conv_msg" + message_id = "test_msg_123" + role = "user" + content = "Hello, world!" + + # Create conversation first + self.memory.create_conversation(conversation_id) + + # Add message + self.memory.add_message(conversation_id, message_id, role, content) + + # Verify message was added + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + "SELECT message_id, conversation_id, role, content FROM conversation_messages WHERE message_id = ?", + (message_id,), + ) + row = cursor.fetchone() + assert row[0] == message_id + assert row[1] == conversation_id + assert row[2] == role + assert row[3] == content + + # Verify message count was updated + cursor.execute( + "SELECT message_count FROM conversation_history WHERE conversation_id = ?", + (conversation_id,), + ) + count_row = cursor.fetchone() + assert count_row[0] == 1 + conn.close() + + def test_add_message_with_tool_calls(self): + """Test adding a message with tool calls.""" + conversation_id = "test_conv_tools" + message_id = "test_msg_tools" + role = "assistant" + content = "I'll help you with that." 
+ tool_calls = [{"function": "test_func", "args": {"param": "value"}}] + + self.memory.create_conversation(conversation_id) + self.memory.add_message(conversation_id, message_id, role, content, tool_calls=tool_calls) + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + "SELECT tool_calls FROM conversation_messages WHERE message_id = ?", (message_id,) + ) + row = cursor.fetchone() + assert row[0] is not None + conn.close() + + def test_add_message_with_metadata(self): + """Test adding a message with metadata.""" + conversation_id = "test_conv_meta" + message_id = "test_msg_meta" + role = "user" + content = "Test message" + metadata = {"tokens": 5, "model": "gpt-4"} + + self.memory.create_conversation(conversation_id) + self.memory.add_message(conversation_id, message_id, role, content, metadata=metadata) + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + "SELECT metadata FROM conversation_messages WHERE message_id = ?", (message_id,) + ) + row = cursor.fetchone() + assert row[0] is not None + conn.close() + + def test_get_conversation_messages(self): + """Test retrieving conversation messages.""" + conversation_id = "test_conv_get" + self.memory.create_conversation(conversation_id) + + # Add multiple messages + messages = [ + ("msg1", "user", "Hello"), + ("msg2", "assistant", "Hi there"), + ("msg3", "user", "How are you?"), + ] + + for msg_id, role, content in messages: + self.memory.add_message(conversation_id, msg_id, role, content) + + # Retrieve all messages + retrieved = self.memory.get_conversation_messages(conversation_id) + assert len(retrieved) == 3 + assert retrieved[0]["message_id"] == "msg1" + assert retrieved[1]["message_id"] == "msg2" + assert retrieved[2]["message_id"] == "msg3" + + def test_get_conversation_messages_limited(self): + """Test retrieving limited number of conversation messages.""" + conversation_id = "test_conv_limit" + self.memory.create_conversation(conversation_id) + + # Add multiple messages + for i in range(5): + self.memory.add_message(conversation_id, f"msg{i}", "user", f"Message {i}") + + # Retrieve limited messages + retrieved = self.memory.get_conversation_messages(conversation_id, limit=3) + assert len(retrieved) == 3 + # Should return most recent messages first due to DESC order + assert retrieved[0]["message_id"] == "msg4" + assert retrieved[1]["message_id"] == "msg3" + assert retrieved[2]["message_id"] == "msg2" + + def test_update_conversation_summary(self): + """Test updating conversation summary.""" + conversation_id = "test_conv_summary" + self.memory.create_conversation(conversation_id) + + summary = "This is a test conversation summary" + topics = ["testing", "memory", "conversation"] + + self.memory.update_conversation_summary(conversation_id, summary, topics) + + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + "SELECT summary, topics, ended_at FROM conversation_history WHERE conversation_id = ?", + (conversation_id,), + ) + row = cursor.fetchone() + assert row[0] == summary + assert row[1] is not None # topics should be stored + assert row[2] is not None # ended_at should be set + conn.close() + + def test_search_conversations(self): + """Test searching conversations by content.""" + # Create conversations with different content + conv1 = "conv_search_1" + conv2 = "conv_search_2" + conv3 = "conv_search_3" + + self.memory.create_conversation(conv1) + self.memory.create_conversation(conv2) + self.memory.create_conversation(conv3) + + # Add 
messages with searchable content + self.memory.add_message(conv1, "msg1", "user", "Python programming tutorial") + self.memory.add_message(conv2, "msg2", "user", "JavaScript development guide") + self.memory.add_message(conv3, "msg3", "user", "Database design principles") + + # Search for "programming" + results = self.memory.search_conversations("programming") + assert len(results) == 1 + assert results[0]["conversation_id"] == conv1 + + # Search for "development" + results = self.memory.search_conversations("development") + assert len(results) == 1 + assert results[0]["conversation_id"] == conv2 + + def test_get_recent_conversations(self): + """Test getting recent conversations.""" + # Create conversations at different times + conv1 = "conv_recent_1" + conv2 = "conv_recent_2" + conv3 = "conv_recent_3" + + self.memory.create_conversation(conv1) + time.sleep(0.01) # Small delay to ensure different timestamps + self.memory.create_conversation(conv2) + time.sleep(0.01) + self.memory.create_conversation(conv3) + + # Get recent conversations + recent = self.memory.get_recent_conversations(limit=2) + assert len(recent) == 2 + # Should be ordered by started_at DESC + assert recent[0]["conversation_id"] == conv3 + assert recent[1]["conversation_id"] == conv2 + + def test_get_recent_conversations_by_session(self): + """Test getting recent conversations for a specific session.""" + session1 = "session_1" + session2 = "session_2" + + conv1 = "conv_session_1" + conv2 = "conv_session_2" + conv3 = "conv_session_3" + + self.memory.create_conversation(conv1, session1) + self.memory.create_conversation(conv2, session2) + self.memory.create_conversation(conv3, session1) + + # Get conversations for session1 + session_convs = self.memory.get_recent_conversations(session_id=session1) + assert len(session_convs) == 2 + conversation_ids = [c["conversation_id"] for c in session_convs] + assert conv1 in conversation_ids + assert conv3 in conversation_ids + assert conv2 not in conversation_ids + + def test_delete_conversation(self): + """Test deleting a conversation.""" + conversation_id = "conv_delete" + self.memory.create_conversation(conversation_id) + + # Add some messages + self.memory.add_message(conversation_id, "msg1", "user", "Test message") + self.memory.add_message(conversation_id, "msg2", "assistant", "Response") + + # Delete conversation + result = self.memory.delete_conversation(conversation_id) + assert result is True + + # Verify conversation and messages are gone + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + "SELECT COUNT(*) FROM conversation_history WHERE conversation_id = ?", + (conversation_id,), + ) + assert cursor.fetchone()[0] == 0 + + cursor.execute( + "SELECT COUNT(*) FROM conversation_messages WHERE conversation_id = ?", + (conversation_id,), + ) + assert cursor.fetchone()[0] == 0 + conn.close() + + def test_delete_nonexistent_conversation(self): + """Test deleting a non-existent conversation.""" + result = self.memory.delete_conversation("nonexistent") + assert result is False + + def test_get_statistics(self): + """Test getting memory statistics.""" + # Create some conversations and messages + for i in range(3): + conv_id = f"conv_stats_{i}" + self.memory.create_conversation(conv_id) + for j in range(2): + self.memory.add_message(conv_id, f"msg_{i}_{j}", "user", f"Message {j}") + + stats = self.memory.get_statistics() + assert stats["total_conversations"] == 3 + assert stats["total_messages"] == 6 + assert stats["average_messages_per_conversation"] 
== 2.0 + + def test_thread_safety(self): + """Test that the memory can handle concurrent access.""" + import threading + import queue + + results = queue.Queue() + + def worker(worker_id): + try: + conv_id = f"conv_thread_{worker_id}" + self.memory.create_conversation(conv_id) + self.memory.add_message(conv_id, f"msg_{worker_id}", "user", f"Worker {worker_id}") + results.put(True) + except Exception as e: + results.put(e) + + # Start multiple threads + threads = [] + for i in range(5): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + # Wait for all threads + for t in threads: + t.join() + + # Check results + for _ in range(5): + result = results.get() + assert result is True + + # Verify all conversations were created + recent = self.memory.get_recent_conversations(limit=10) + assert len(recent) >= 5 diff --git a/tests/test_fact_extractor.py b/tests/test_fact_extractor.py new file mode 100644 index 0000000..a0955c2 --- /dev/null +++ b/tests/test_fact_extractor.py @@ -0,0 +1,242 @@ +from pr.memory.fact_extractor import FactExtractor + + +class TestFactExtractor: + def setup_method(self): + """Set up test fixture.""" + self.extractor = FactExtractor() + + def test_init(self): + """Test FactExtractor initialization.""" + assert self.extractor.fact_patterns is not None + assert len(self.extractor.fact_patterns) > 0 + + def test_extract_facts_definition(self): + """Test extracting definition facts.""" + text = "John Smith is a software engineer. Python is a programming language." + facts = self.extractor.extract_facts(text) + + assert len(facts) >= 2 + # Check for definition pattern matches + definition_facts = [f for f in facts if f["type"] == "definition"] + assert len(definition_facts) >= 1 + + def test_extract_facts_temporal(self): + """Test extracting temporal facts.""" + text = "John was born in 1990. The company was founded in 2010." + facts = self.extractor.extract_facts(text) + + temporal_facts = [f for f in facts if f["type"] == "temporal"] + assert len(temporal_facts) >= 1 + + def test_extract_facts_attribution(self): + """Test extracting attribution facts.""" + text = "John invented the widget. Mary developed the software." + facts = self.extractor.extract_facts(text) + + attribution_facts = [f for f in facts if f["type"] == "attribution"] + assert len(attribution_facts) >= 1 + + def test_extract_facts_numeric(self): + """Test extracting numeric facts.""" + text = "The car costs $25,000. The house is worth $500,000." + facts = self.extractor.extract_facts(text) + + numeric_facts = [f for f in facts if f["type"] == "numeric"] + assert len(numeric_facts) >= 1 + + def test_extract_facts_location(self): + """Test extracting location facts.""" + text = "John lives in San Francisco. The office is located in New York." + facts = self.extractor.extract_facts(text) + + location_facts = [f for f in facts if f["type"] == "location"] + assert len(location_facts) >= 1 + + def test_extract_facts_entity(self): + """Test extracting entity facts from noun phrases.""" + text = "John Smith works at Google Inc. He uses Python programming." + facts = self.extractor.extract_facts(text) + + entity_facts = [f for f in facts if f["type"] == "entity"] + assert len(entity_facts) >= 1 + + def test_extract_noun_phrases(self): + """Test noun phrase extraction.""" + text = "John Smith is a software engineer at Google. He works on Python Projects." 
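+        # The extractor is expected to return capitalized noun phrases such as
+        # "John Smith", "Google" and "Python Projects" from this sample text.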
+ phrases = self.extractor._extract_noun_phrases(text) + + assert "John Smith" in phrases + assert "Google" in phrases + assert "Python Projects" in phrases + + def test_extract_noun_phrases_capitalized(self): + """Test that only capitalized noun phrases are extracted.""" + text = "the quick brown fox jumps over the lazy dog" + phrases = self.extractor._extract_noun_phrases(text) + + # Should be empty since no capitalized words + assert len(phrases) == 0 + + def test_extract_key_terms(self): + """Test key term extraction.""" + text = "Python is a programming language used for software development and data analysis." + terms = self.extractor.extract_key_terms(text, top_k=5) + + assert len(terms) <= 5 + # Should contain programming, language, software, development, data, analysis + term_words = [term[0] for term in terms] + assert "programming" in term_words + assert "language" in term_words + assert "software" in term_words + + def test_extract_key_terms_stopwords_filtered(self): + """Test that stopwords are filtered from key terms.""" + text = "This is a test of the system that should work properly." + terms = self.extractor.extract_key_terms(text) + + term_words = [term[0] for term in terms] + # Stopwords should not appear + assert "this" not in term_words + assert "is" not in term_words + assert "a" not in term_words + assert "of" not in term_words + assert "the" not in term_words + + def test_extract_relationships_employment(self): + """Test extracting employment relationships.""" + text = "John works for Google. Mary is employed by Microsoft." + relationships = self.extractor.extract_relationships(text) + + employment_rels = [r for r in relationships if r["type"] == "employment"] + assert len(employment_rels) >= 1 + + def test_extract_relationships_ownership(self): + """Test extracting ownership relationships.""" + text = "John owns a car. Mary has a house." + relationships = self.extractor.extract_relationships(text) + + ownership_rels = [r for r in relationships if r["type"] == "ownership"] + assert len(ownership_rels) >= 1 + + def test_extract_relationships_location(self): + """Test extracting location relationships.""" + text = "John located in New York. The factory belongs to Google." + relationships = self.extractor.extract_relationships(text) + + location_rels = [r for r in relationships if r["type"] == "location"] + assert len(location_rels) >= 1 + + def test_extract_relationships_usage(self): + """Test extracting usage relationships.""" + text = "John uses Python. The company implements agile methodology." + relationships = self.extractor.extract_relationships(text) + + usage_rels = [r for r in relationships if r["type"] == "usage"] + assert len(usage_rels) >= 1 + + def test_extract_metadata(self): + """Test metadata extraction.""" + text = "This is a test document. It contains some information about Python programming. You can visit https://python.org for more details. Contact john@example.com for questions. The project started in 2020 and costs $10,000." 
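+        # The sample text includes a URL, an email address, a year and a dollar
+        # amount so that each metadata field asserted below is exercised.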
+ metadata = self.extractor.extract_metadata(text) + + assert metadata["word_count"] > 0 + assert metadata["sentence_count"] > 0 + assert metadata["avg_words_per_sentence"] > 0 + assert len(metadata["urls"]) > 0 + assert len(metadata["email_addresses"]) > 0 + assert len(metadata["dates"]) > 0 + assert len(metadata["numeric_values"]) > 0 + assert metadata["has_code"] is False # No code in this text + + def test_extract_metadata_with_code(self): + """Test metadata extraction with code content.""" + text = "Here is a function: def hello(): print('Hello, world!')" + metadata = self.extractor.extract_metadata(text) + + assert metadata["has_code"] is True + + def test_extract_metadata_with_questions(self): + """Test metadata extraction with questions.""" + text = "What is Python? How does it work? Why use it?" + metadata = self.extractor.extract_metadata(text) + + assert metadata["has_questions"] is True + + def test_categorize_content_programming(self): + """Test content categorization for programming.""" + text = "Python is a programming language used for code development and debugging." + categories = self.extractor.categorize_content(text) + + assert "programming" in categories + + def test_categorize_content_data(self): + """Test content categorization for data.""" + text = "The database contains records and tables with statistical analysis." + categories = self.extractor.categorize_content(text) + + assert "data" in categories + + def test_categorize_content_documentation(self): + """Test content categorization for documentation.""" + text = "This guide explains how to use the tutorial and manual." + categories = self.extractor.categorize_content(text) + + assert "documentation" in categories + + def test_categorize_content_configuration(self): + """Test content categorization for configuration.""" + text = "Configure the settings and setup the deployment environment." + categories = self.extractor.categorize_content(text) + + assert "configuration" in categories + + def test_categorize_content_testing(self): + """Test content categorization for testing.""" + text = "Run the tests to validate the functionality and verify quality." + categories = self.extractor.categorize_content(text) + + assert "testing" in categories + + def test_categorize_content_research(self): + """Test content categorization for research.""" + text = "The study investigates findings and results from the analysis." + categories = self.extractor.categorize_content(text) + + assert "research" in categories + + def test_categorize_content_planning(self): + """Test content categorization for planning.""" + text = "Plan the project schedule with milestones and timeline." + categories = self.extractor.categorize_content(text) + + assert "planning" in categories + + def test_categorize_content_general(self): + """Test content categorization defaults to general.""" + text = "This is some random text without specific keywords." 
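+        # With no category keywords present, categorize_content is expected to
+        # fall back to the "general" category.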
+ categories = self.extractor.categorize_content(text) + + assert "general" in categories + + def test_extract_facts_empty_text(self): + """Test fact extraction with empty text.""" + facts = self.extractor.extract_facts("") + assert len(facts) == 0 + + def test_extract_key_terms_empty_text(self): + """Test key term extraction with empty text.""" + terms = self.extractor.extract_key_terms("") + assert len(terms) == 0 + + def test_extract_relationships_empty_text(self): + """Test relationship extraction with empty text.""" + relationships = self.extractor.extract_relationships("") + assert len(relationships) == 0 + + def test_extract_metadata_empty_text(self): + """Test metadata extraction with empty text.""" + metadata = self.extractor.extract_metadata("") + assert metadata["word_count"] == 0 + assert metadata["sentence_count"] == 0 + assert metadata["avg_words_per_sentence"] == 0.0 diff --git a/tests/test_knowledge_store.py b/tests/test_knowledge_store.py new file mode 100644 index 0000000..baab6d8 --- /dev/null +++ b/tests/test_knowledge_store.py @@ -0,0 +1,344 @@ +import sqlite3 +import tempfile +import os +import time + +from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry + + +class TestKnowledgeStore: + def setup_method(self): + """Set up test database for each test.""" + self.db_fd, self.db_path = tempfile.mkstemp() + self.store = KnowledgeStore(self.db_path) + + def teardown_method(self): + """Clean up test database after each test.""" + self.store = None + os.close(self.db_fd) + os.unlink(self.db_path) + + def test_init(self): + """Test KnowledgeStore initialization.""" + assert self.store.db_path == self.db_path + # Verify tables were created + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + assert "knowledge_entries" in tables + conn.close() + + def test_add_entry(self): + """Test adding a knowledge entry.""" + entry = KnowledgeEntry( + entry_id="test_1", + category="test", + content="This is a test entry", + metadata={"source": "test"}, + created_at=time.time(), + updated_at=time.time(), + ) + + self.store.add_entry(entry) + + # Verify entry was added + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + cursor.execute( + "SELECT entry_id, category, content FROM knowledge_entries WHERE entry_id = ?", + ("test_1",), + ) + row = cursor.fetchone() + assert row[0] == "test_1" + assert row[1] == "test" + assert row[2] == "This is a test entry" + conn.close() + + def test_get_entry(self): + """Test retrieving a knowledge entry.""" + entry = KnowledgeEntry( + entry_id="test_get", + category="test", + content="Content to retrieve", + metadata={}, + created_at=time.time(), + updated_at=time.time(), + ) + + self.store.add_entry(entry) + retrieved = self.store.get_entry("test_get") + + assert retrieved is not None + assert retrieved.entry_id == "test_get" + assert retrieved.content == "Content to retrieve" + assert retrieved.access_count == 1 # Should be incremented + + def test_get_entry_not_found(self): + """Test retrieving a non-existent entry.""" + retrieved = self.store.get_entry("nonexistent") + assert retrieved is None + + def test_search_entries_semantic(self): + """Test semantic search.""" + entries = [ + KnowledgeEntry( + "entry1", "personal", "John is a software engineer", {}, time.time(), time.time() + ), + KnowledgeEntry( + "entry2", "personal", "Mary works as a designer", {}, time.time(), time.time() + ), + 
KnowledgeEntry( + "entry3", "tech", "Python is a programming language", {}, time.time(), time.time() + ), + ] + + for entry in entries: + self.store.add_entry(entry) + + results = self.store.search_entries("software engineer", top_k=2) + assert len(results) >= 1 + # Should find the most relevant entry + found_ids = [r.entry_id for r in results] + assert "entry1" in found_ids + + def test_search_entries_fts_exact(self): + """Test full-text search with exact matches.""" + entries = [ + KnowledgeEntry( + "exact1", "test", "Python programming language", {}, time.time(), time.time() + ), + KnowledgeEntry( + "exact2", "test", "Java programming language", {}, time.time(), time.time() + ), + KnowledgeEntry( + "exact3", "test", "Database design principles", {}, time.time(), time.time() + ), + ] + + for entry in entries: + self.store.add_entry(entry) + + results = self.store.search_entries("programming language", top_k=3) + assert len(results) >= 2 + found_ids = [r.entry_id for r in results] + assert "exact1" in found_ids + assert "exact2" in found_ids + + def test_search_entries_by_category(self): + """Test searching entries by category.""" + entries = [ + KnowledgeEntry("cat1", "personal", "John's info", {}, time.time(), time.time()), + KnowledgeEntry("cat2", "tech", "Python info", {}, time.time(), time.time()), + KnowledgeEntry("cat3", "personal", "Jane's info", {}, time.time(), time.time()), + ] + + for entry in entries: + self.store.add_entry(entry) + + results = self.store.search_entries("info", category="personal", top_k=5) + assert len(results) == 2 + found_ids = [r.entry_id for r in results] + assert "cat1" in found_ids + assert "cat3" in found_ids + assert "cat2" not in found_ids + + def test_get_by_category(self): + """Test getting entries by category.""" + entries = [ + KnowledgeEntry("get1", "personal", "Entry 1", {}, time.time(), time.time()), + KnowledgeEntry("get2", "tech", "Entry 2", {}, time.time(), time.time()), + KnowledgeEntry("get3", "personal", "Entry 3", {}, time.time() + 1, time.time() + 1), + ] + + for entry in entries: + self.store.add_entry(entry) + + personal_entries = self.store.get_by_category("personal") + assert len(personal_entries) == 2 + # Should be ordered by importance_score DESC, created_at DESC + assert personal_entries[0].entry_id == "get3" # More recent + assert personal_entries[1].entry_id == "get1" + + def test_update_importance(self): + """Test updating entry importance.""" + entry = KnowledgeEntry( + "importance_test", "test", "Test content", {}, time.time(), time.time() + ) + self.store.add_entry(entry) + + self.store.update_importance("importance_test", 0.8) + + retrieved = self.store.get_entry("importance_test") + assert retrieved.importance_score == 0.8 + + def test_delete_entry(self): + """Test deleting an entry.""" + entry = KnowledgeEntry("delete_test", "test", "To be deleted", {}, time.time(), time.time()) + self.store.add_entry(entry) + + result = self.store.delete_entry("delete_test") + assert result is True + + # Verify it's gone + retrieved = self.store.get_entry("delete_test") + assert retrieved is None + + def test_delete_entry_not_found(self): + """Test deleting a non-existent entry.""" + result = self.store.delete_entry("nonexistent") + assert result is False + + def test_get_statistics(self): + """Test getting store statistics.""" + entries = [ + KnowledgeEntry("stat1", "personal", "Personal info", {}, time.time(), time.time()), + KnowledgeEntry("stat2", "tech", "Tech info", {}, time.time(), time.time()), + KnowledgeEntry("stat3", 
"personal", "More personal info", {}, time.time(), time.time()), + ] + + for entry in entries: + self.store.add_entry(entry) + + stats = self.store.get_statistics() + assert stats["total_entries"] == 3 + assert stats["total_categories"] == 2 + assert stats["category_distribution"]["personal"] == 2 + assert stats["category_distribution"]["tech"] == 1 + assert stats["vocabulary_size"] > 0 + + def test_fts_search_exact_phrase(self): + """Test FTS exact phrase matching.""" + entries = [ + KnowledgeEntry( + "fts1", "test", "Python is great for programming", {}, time.time(), time.time() + ), + KnowledgeEntry( + "fts2", "test", "Java is also good for programming", {}, time.time(), time.time() + ), + KnowledgeEntry( + "fts3", "test", "Database management is important", {}, time.time(), time.time() + ), + ] + + for entry in entries: + self.store.add_entry(entry) + + fts_results = self.store._fts_search("programming") + assert len(fts_results) == 2 + entry_ids = [entry_id for entry_id, score in fts_results] + assert "fts1" in entry_ids + assert "fts2" in entry_ids + + def test_fts_search_partial_match(self): + """Test FTS partial word matching.""" + entries = [ + KnowledgeEntry("partial1", "test", "Python programming", {}, time.time(), time.time()), + KnowledgeEntry("partial2", "test", "Java programming", {}, time.time(), time.time()), + KnowledgeEntry("partial3", "test", "Database design", {}, time.time(), time.time()), + ] + + for entry in entries: + self.store.add_entry(entry) + + fts_results = self.store._fts_search("program") + assert len(fts_results) >= 2 + + def test_combined_search_scoring(self): + """Test that combined semantic + FTS search produces proper scoring.""" + entries = [ + KnowledgeEntry( + "combined1", "test", "Python programming language", {}, time.time(), time.time() + ), + KnowledgeEntry( + "combined2", "test", "Java programming language", {}, time.time(), time.time() + ), + KnowledgeEntry( + "combined3", "test", "Database management system", {}, time.time(), time.time() + ), + ] + + for entry in entries: + self.store.add_entry(entry) + + results = self.store.search_entries("programming language") + assert len(results) >= 2 + + # Check that results have search scores (at least one should have a positive score) + has_positive_score = False + for result in results: + assert "search_score" in result.metadata + if result.metadata["search_score"] > 0: + has_positive_score = True + assert has_positive_score + + def test_empty_search(self): + """Test searching with no matches.""" + results = self.store.search_entries("nonexistent topic") + assert len(results) == 0 + + def test_search_empty_query(self): + """Test searching with empty query.""" + results = self.store.search_entries("") + assert len(results) == 0 + + def test_thread_safety(self): + """Test that the store can handle concurrent access.""" + import threading + import queue + + results = queue.Queue() + + def worker(worker_id): + try: + entry = KnowledgeEntry( + f"thread_{worker_id}", + "test", + f"Content from worker {worker_id}", + {}, + time.time(), + time.time(), + ) + self.store.add_entry(entry) + retrieved = self.store.get_entry(f"thread_{worker_id}") + assert retrieved is not None + results.put(True) + except Exception as e: + results.put(e) + + # Start multiple threads + threads = [] + for i in range(5): + t = threading.Thread(target=worker, args=(i,)) + threads.append(t) + t.start() + + # Wait for all threads + for t in threads: + t.join() + + # Check results + for _ in range(5): + result = results.get() + assert 
result is True + + def test_entry_to_dict(self): + """Test KnowledgeEntry to_dict method.""" + entry = KnowledgeEntry( + entry_id="dict_test", + category="test", + content="Test content", + metadata={"key": "value"}, + created_at=1234567890.0, + updated_at=1234567891.0, + access_count=5, + importance_score=0.8, + ) + + entry_dict = entry.to_dict() + assert entry_dict["entry_id"] == "dict_test" + assert entry_dict["category"] == "test" + assert entry_dict["content"] == "Test content" + assert entry_dict["metadata"]["key"] == "value" + assert entry_dict["access_count"] == 5 + assert entry_dict["importance_score"] == 0.8 diff --git a/tests/test_semantic_index.py b/tests/test_semantic_index.py new file mode 100644 index 0000000..d50b893 --- /dev/null +++ b/tests/test_semantic_index.py @@ -0,0 +1,280 @@ +import math +import pytest +from pr.memory.semantic_index import SemanticIndex + + +class TestSemanticIndex: + def test_init(self): + """Test SemanticIndex initialization.""" + index = SemanticIndex() + assert index.documents == {} + assert index.vocabulary == set() + assert index.idf_scores == {} + assert index.doc_tf_scores == {} + + def test_tokenize_basic(self): + """Test basic tokenization functionality.""" + index = SemanticIndex() + tokens = index._tokenize("Hello, world! This is a test.") + expected = ["hello", "world", "this", "is", "a", "test"] + assert tokens == expected + + def test_tokenize_special_characters(self): + """Test tokenization with special characters.""" + index = SemanticIndex() + tokens = index._tokenize("Hello@world.com test-case_123") + expected = ["hello", "world", "com", "test", "case", "123"] + assert tokens == expected + + def test_tokenize_empty_string(self): + """Test tokenization of empty string.""" + index = SemanticIndex() + tokens = index._tokenize("") + assert tokens == [] + + def test_tokenize_only_special_chars(self): + """Test tokenization with only special characters.""" + index = SemanticIndex() + tokens = index._tokenize("!@#$%^&*()") + assert tokens == [] + + def test_compute_tf_basic(self): + """Test TF computation for basic case.""" + index = SemanticIndex() + tokens = ["hello", "world", "hello", "test"] + tf_scores = index._compute_tf(tokens) + expected = {"hello": 2 / 4, "world": 1 / 4, "test": 1 / 4} # 0.5 # 0.25 # 0.25 + assert tf_scores == expected + + def test_compute_tf_empty(self): + """Test TF computation for empty tokens.""" + index = SemanticIndex() + tf_scores = index._compute_tf([]) + assert tf_scores == {} + + def test_compute_tf_single_token(self): + """Test TF computation for single token.""" + index = SemanticIndex() + tokens = ["hello"] + tf_scores = index._compute_tf(tokens) + assert tf_scores == {"hello": 1.0} + + def test_compute_idf_single_document(self): + """Test IDF computation with single document.""" + index = SemanticIndex() + index.documents = {"doc1": "hello world"} + index._compute_idf() + assert index.idf_scores == {"hello": 1.0, "world": 1.0} + + def test_compute_idf_multiple_documents(self): + """Test IDF computation with multiple documents.""" + index = SemanticIndex() + index.documents = {"doc1": "hello world", "doc2": "hello test", "doc3": "world test"} + index._compute_idf() + expected = { + "hello": math.log(3 / 2), # appears in 2/3 docs + "world": math.log(3 / 2), # appears in 2/3 docs + "test": math.log(3 / 2), # appears in 2/3 docs + } + assert index.idf_scores == expected + + def test_compute_idf_empty_documents(self): + """Test IDF computation with no documents.""" + index = SemanticIndex() + 
diff --git a/tests/test_semantic_index.py b/tests/test_semantic_index.py
new file mode 100644
index 0000000..d50b893
--- /dev/null
+++ b/tests/test_semantic_index.py
@@ -0,0 +1,280 @@
+import math
+import pytest
+from pr.memory.semantic_index import SemanticIndex
+
+
+class TestSemanticIndex:
+    def test_init(self):
+        """Test SemanticIndex initialization."""
+        index = SemanticIndex()
+        assert index.documents == {}
+        assert index.vocabulary == set()
+        assert index.idf_scores == {}
+        assert index.doc_tf_scores == {}
+
+    def test_tokenize_basic(self):
+        """Test basic tokenization functionality."""
+        index = SemanticIndex()
+        tokens = index._tokenize("Hello, world! This is a test.")
+        expected = ["hello", "world", "this", "is", "a", "test"]
+        assert tokens == expected
+
+    def test_tokenize_special_characters(self):
+        """Test tokenization with special characters."""
+        index = SemanticIndex()
+        tokens = index._tokenize("Hello@world.com test-case_123")
+        expected = ["hello", "world", "com", "test", "case", "123"]
+        assert tokens == expected
+
+    def test_tokenize_empty_string(self):
+        """Test tokenization of empty string."""
+        index = SemanticIndex()
+        tokens = index._tokenize("")
+        assert tokens == []
+
+    def test_tokenize_only_special_chars(self):
+        """Test tokenization with only special characters."""
+        index = SemanticIndex()
+        tokens = index._tokenize("!@#$%^&*()")
+        assert tokens == []
+
+    def test_compute_tf_basic(self):
+        """Test TF computation for basic case."""
+        index = SemanticIndex()
+        tokens = ["hello", "world", "hello", "test"]
+        tf_scores = index._compute_tf(tokens)
+        expected = {"hello": 2 / 4, "world": 1 / 4, "test": 1 / 4}  # 0.5 # 0.25 # 0.25
+        assert tf_scores == expected
+
+    def test_compute_tf_empty(self):
+        """Test TF computation for empty tokens."""
+        index = SemanticIndex()
+        tf_scores = index._compute_tf([])
+        assert tf_scores == {}
+
+    def test_compute_tf_single_token(self):
+        """Test TF computation for single token."""
+        index = SemanticIndex()
+        tokens = ["hello"]
+        tf_scores = index._compute_tf(tokens)
+        assert tf_scores == {"hello": 1.0}
+
+    def test_compute_idf_single_document(self):
+        """Test IDF computation with single document."""
+        index = SemanticIndex()
+        index.documents = {"doc1": "hello world"}
+        index._compute_idf()
+        assert index.idf_scores == {"hello": 1.0, "world": 1.0}
+
+    def test_compute_idf_multiple_documents(self):
+        """Test IDF computation with multiple documents."""
+        index = SemanticIndex()
+        index.documents = {"doc1": "hello world", "doc2": "hello test", "doc3": "world test"}
+        index._compute_idf()
+        expected = {
+            "hello": math.log(3 / 2),  # appears in 2/3 docs
+            "world": math.log(3 / 2),  # appears in 2/3 docs
+            "test": math.log(3 / 2),  # appears in 2/3 docs
+        }
+        assert index.idf_scores == expected
+
+    def test_compute_idf_empty_documents(self):
+        """Test IDF computation with no documents."""
+        index = SemanticIndex()
+        index._compute_idf()
+        assert index.idf_scores == {}
+
+    def test_add_document_basic(self):
+        """Test adding a basic document."""
+        index = SemanticIndex()
+        index.add_document("doc1", "hello world")
+
+        assert "doc1" in index.documents
+        assert index.documents["doc1"] == "hello world"
+        assert "hello" in index.vocabulary
+        assert "world" in index.vocabulary
+        assert "doc1" in index.doc_tf_scores
+
+    def test_add_document_updates_vocabulary(self):
+        """Test that adding documents updates vocabulary."""
+        index = SemanticIndex()
+        index.add_document("doc1", "hello world")
+        assert index.vocabulary == {"hello", "world"}
+
+        index.add_document("doc2", "hello test")
+        assert index.vocabulary == {"hello", "world", "test"}
+
+    def test_add_document_updates_idf(self):
+        """Test that adding documents updates IDF scores."""
+        index = SemanticIndex()
+        index.add_document("doc1", "hello world")
+        assert index.idf_scores == {"hello": 1.0, "world": 1.0}
+
+        index.add_document("doc2", "hello test")
+        expected_idf = {
+            "hello": math.log(2 / 2),  # appears in both docs
+            "world": math.log(2 / 1),  # appears in 1/2 docs
+            "test": math.log(2 / 1),  # appears in 1/2 docs
+        }
+        assert index.idf_scores == expected_idf
+
+    def test_add_document_tf_computation(self):
+        """Test TF score computation when adding document."""
+        index = SemanticIndex()
+        index.add_document("doc1", "hello world hello")
+
+        # TF: hello=2/3, world=1/3
+        expected_tf = {"hello": 2 / 3, "world": 1 / 3}
+        assert index.doc_tf_scores["doc1"] == expected_tf
+
+    def test_remove_document_existing(self):
+        """Test removing an existing document."""
+        index = SemanticIndex()
+        index.add_document("doc1", "hello world")
+        index.add_document("doc2", "hello test")
+
+        initial_vocab = index.vocabulary.copy()
+        initial_idf = index.idf_scores.copy()
+
+        index.remove_document("doc1")
+
+        assert "doc1" not in index.documents
+        assert "doc1" not in index.doc_tf_scores
+        # Vocabulary should still contain all words
+        assert index.vocabulary == initial_vocab
+        # IDF should be recomputed
+        assert index.idf_scores != initial_idf
+        assert index.idf_scores == {"hello": 1.0, "test": 1.0}
+
+    def test_remove_document_nonexistent(self):
+        """Test removing a non-existent document."""
+        index = SemanticIndex()
+        index.add_document("doc1", "hello world")
+
+        initial_state = {
+            "documents": index.documents.copy(),
+            "vocabulary": index.vocabulary.copy(),
+            "idf_scores": index.idf_scores.copy(),
+            "doc_tf_scores": index.doc_tf_scores.copy(),
+        }
+
+        index.remove_document("nonexistent")
+
+        assert index.documents == initial_state["documents"]
+        assert index.vocabulary == initial_state["vocabulary"]
+        assert index.idf_scores == initial_state["idf_scores"]
+        assert index.doc_tf_scores == initial_state["doc_tf_scores"]
+
+    def test_search_basic(self):
+        """Test basic search functionality."""
+        index = SemanticIndex()
+        index.add_document("doc1", "hello world")
+        index.add_document("doc2", "hello test")
+        index.add_document("doc3", "world test")
+
+        results = index.search("hello", top_k=5)
+        assert len(results) == 3  # All documents are returned with similarity scores
+
+        # Results should be sorted by similarity (descending)
+        scores = {doc_id: score for doc_id, score in results}
+        assert scores["doc1"] > 0  # doc1 contains "hello"
+        assert scores["doc2"] > 0  # doc2 contains "hello"
+        assert scores["doc3"] == 0  # doc3 does not contain "hello"
+
+    def test_search_empty_query(self):
+        """Test search with empty query."""
+        index = SemanticIndex()
+        index.add_document("doc1", "hello world")
world") + + results = index.search("", top_k=5) + assert results == [] + + def test_search_no_documents(self): + """Test search when no documents exist.""" + index = SemanticIndex() + results = index.search("hello", top_k=5) + assert results == [] + + def test_search_top_k_limit(self): + """Test search respects top_k parameter.""" + index = SemanticIndex() + for i in range(10): + index.add_document(f"doc{i}", f"hello world content") + + results = index.search("content", top_k=3) + assert len(results) == 3 + + def test_cosine_similarity_identical_vectors(self): + """Test cosine similarity with identical vectors.""" + index = SemanticIndex() + vec1 = {"hello": 1.0, "world": 0.5} + vec2 = {"hello": 1.0, "world": 0.5} + similarity = index._cosine_similarity(vec1, vec2) + assert similarity == pytest.approx(1.0) + + def test_cosine_similarity_orthogonal_vectors(self): + """Test cosine similarity with orthogonal vectors.""" + index = SemanticIndex() + vec1 = {"hello": 1.0} + vec2 = {"world": 1.0} + similarity = index._cosine_similarity(vec1, vec2) + assert similarity == 0.0 + + def test_cosine_similarity_zero_vector(self): + """Test cosine similarity with zero vector.""" + index = SemanticIndex() + vec1 = {"hello": 1.0} + vec2 = {} + similarity = index._cosine_similarity(vec1, vec2) + assert similarity == 0.0 + + def test_cosine_similarity_empty_vectors(self): + """Test cosine similarity with empty vectors.""" + index = SemanticIndex() + similarity = index._cosine_similarity({}, {}) + assert similarity == 0.0 + + def test_search_relevance_ordering(self): + """Test that search results are ordered by relevance.""" + index = SemanticIndex() + index.add_document("doc1", "hello hello hello") # High TF for "hello" + index.add_document("doc2", "hello world") # Medium relevance + index.add_document("doc3", "world test") # No "hello" + + results = index.search("hello", top_k=5) + assert len(results) == 3 # All documents are returned + + # doc1 should have higher score than doc2, and doc3 should have 0 + scores = {doc_id: score for doc_id, score in results} + assert scores["doc1"] > scores["doc2"] > scores["doc3"] + assert scores["doc3"] == 0 + + def test_vocabulary_persistence(self): + """Test that vocabulary persists even after document removal.""" + index = SemanticIndex() + index.add_document("doc1", "hello world") + index.add_document("doc2", "test case") + + assert index.vocabulary == {"hello", "world", "test", "case"} + + index.remove_document("doc1") + # Vocabulary should still contain all words + assert index.vocabulary == {"hello", "world", "test", "case"} + + def test_idf_recomputation_after_removal(self): + """Test IDF recomputation after document removal.""" + index = SemanticIndex() + index.add_document("doc1", "hello world") + index.add_document("doc2", "hello test") + index.add_document("doc3", "world test") + + # Remove doc3, leaving doc1 and doc2 + index.remove_document("doc3") + + # "hello" appears in both remaining docs, "world" and "test" in one each + expected_idf = { + "hello": math.log(2 / 2), # 2 docs, appears in 2 + "world": math.log(2 / 1), # 2 docs, appears in 1 + "test": math.log(2 / 1), # 2 docs, appears in 1 + } + assert index.idf_scores == expected_idf diff --git a/tests/test_tools.py b/tests/test_tools.py index e6e1790..1c16f41 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -168,7 +168,6 @@ class TestToolDefinitions: tool_names = [t["function"]["name"] for t in tools] assert "run_command" in tool_names - assert "start_interactive_session" in tool_names def 
diff --git a/tests/test_tools.py b/tests/test_tools.py
index e6e1790..1c16f41 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -168,7 +168,6 @@ class TestToolDefinitions:
 
         tool_names = [t["function"]["name"] for t in tools]
         assert "run_command" in tool_names
-        assert "start_interactive_session" in tool_names
 
     def test_python_exec_present(self):
         tools = get_tools_definition()