feat: add extensive memory implementations

retoor 2025-11-06 15:15:06 +01:00
parent 27bcc7409e
commit 31d272daa3
48 changed files with 6395 additions and 1062 deletions

View File

@@ -1,5 +1,13 @@
# Changelog
## Version 1.3.0 - 2025-11-05
This release updates how the software finds configuration files and handles communication with agents. These changes improve reliability and allow for more flexible configuration options.
**Changes:** 32 files, 2964 lines
**Languages:** Other (706 lines), Python (2214 lines), TOML (44 lines)
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

View File

@@ -72,7 +72,7 @@ serve:
implode: build
rpi rp.py -o rp
cp rp.py rp
chmod +x rp
if [ -d /home/retoor/bin ]; then cp rp /home/retoor/bin/rp; fi

pr/ads.py (new file, 969 lines)
View File

@@ -0,0 +1,969 @@
import re
import json
from uuid import uuid4
from datetime import datetime, timezone
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
AsyncGenerator,
Union,
Tuple,
Set,
)
import aiosqlite
import asyncio
import aiohttp
from aiohttp import web
import socket
import os
import atexit
class AsyncDataSet:
"""
Distributed AsyncDataSet with client-server model over Unix sockets.
Parameters:
- file: Path to the SQLite database file
- socket_path: Path to Unix socket for inter-process communication (default: "ads.sock")
- max_concurrent_queries: Maximum concurrent database operations (default: 1)
Set to 1 to prevent "database is locked" errors with SQLite
"""
_KV_TABLE = "__kv_store"
_DEFAULT_COLUMNS = {
"uid": "TEXT PRIMARY KEY",
"created_at": "TEXT",
"updated_at": "TEXT",
"deleted_at": "TEXT",
}
def __init__(self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1):
self._file = file
self._socket_path = socket_path
self._table_columns_cache: Dict[str, Set[str]] = {}
self._is_server = None # None means not initialized yet
self._server = None
self._runner = None
self._client_session = None
self._initialized = False
self._init_lock = None # Will be created when needed
self._max_concurrent_queries = max_concurrent_queries
self._db_semaphore = None # Will be created when needed
self._db_connection = None # Persistent database connection
async def _ensure_initialized(self):
"""Ensure the instance is initialized as server or client."""
if self._initialized:
return
# Create lock if needed (first time in async context)
if self._init_lock is None:
self._init_lock = asyncio.Lock()
async with self._init_lock:
if self._initialized:
return
await self._initialize()
# Create semaphore for database operations (server only)
if self._is_server:
self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries)
self._initialized = True
async def _initialize(self):
"""Initialize as server or client."""
try:
# Try to create Unix socket
if os.path.exists(self._socket_path):
# Check if socket is active
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.connect(self._socket_path)
sock.close()
# Socket is active, we're a client
self._is_server = False
await self._setup_client()
except (ConnectionRefusedError, FileNotFoundError):
# Socket file exists but not active, remove and become server
os.unlink(self._socket_path)
self._is_server = True
await self._setup_server()
else:
# No socket file, become server
self._is_server = True
await self._setup_server()
except Exception:
# Fallback to client mode
self._is_server = False
await self._setup_client()
# Establish persistent database connection
self._db_connection = await aiosqlite.connect(self._file)
async def _setup_server(self):
"""Setup aiohttp server."""
app = web.Application()
# Add routes for all methods
app.router.add_post("/insert", self._handle_insert)
app.router.add_post("/update", self._handle_update)
app.router.add_post("/delete", self._handle_delete)
app.router.add_post("/upsert", self._handle_upsert)
app.router.add_post("/get", self._handle_get)
app.router.add_post("/find", self._handle_find)
app.router.add_post("/count", self._handle_count)
app.router.add_post("/exists", self._handle_exists)
app.router.add_post("/kv_set", self._handle_kv_set)
app.router.add_post("/kv_get", self._handle_kv_get)
app.router.add_post("/execute_raw", self._handle_execute_raw)
app.router.add_post("/query_raw", self._handle_query_raw)
app.router.add_post("/query_one", self._handle_query_one)
app.router.add_post("/create_table", self._handle_create_table)
app.router.add_post("/insert_unique", self._handle_insert_unique)
app.router.add_post("/aggregate", self._handle_aggregate)
self._runner = web.AppRunner(app)
await self._runner.setup()
self._server = web.UnixSite(self._runner, self._socket_path)
await self._server.start()
# Register cleanup
def cleanup():
if os.path.exists(self._socket_path):
os.unlink(self._socket_path)
atexit.register(cleanup)
async def _setup_client(self):
"""Setup aiohttp client."""
connector = aiohttp.UnixConnector(path=self._socket_path)
self._client_session = aiohttp.ClientSession(connector=connector)
async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any:
"""Make HTTP request to server."""
if not self._client_session:
await self._setup_client()
url = f"http://localhost/{endpoint}"
async with self._client_session.post(url, json=data) as resp:
result = await resp.json()
if result.get("error"):
raise Exception(result["error"])
return result.get("result")
# Server handlers
async def _handle_insert(self, request):
data = await request.json()
try:
result = await self._server_insert(
data["table"], data["args"], data.get("return_id", False)
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_update(self, request):
data = await request.json()
try:
result = await self._server_update(data["table"], data["args"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_delete(self, request):
data = await request.json()
try:
result = await self._server_delete(data["table"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_upsert(self, request):
data = await request.json()
try:
result = await self._server_upsert(data["table"], data["args"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_get(self, request):
data = await request.json()
try:
result = await self._server_get(data["table"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_find(self, request):
data = await request.json()
try:
result = await self._server_find(
data["table"],
data.get("where"),
limit=data.get("limit", 0),
offset=data.get("offset", 0),
order_by=data.get("order_by"),
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_count(self, request):
data = await request.json()
try:
result = await self._server_count(data["table"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_exists(self, request):
data = await request.json()
try:
result = await self._server_exists(data["table"], data["where"])
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_kv_set(self, request):
data = await request.json()
try:
await self._server_kv_set(data["key"], data["value"], table=data.get("table"))
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_kv_get(self, request):
data = await request.json()
try:
result = await self._server_kv_get(
data["key"], default=data.get("default"), table=data.get("table")
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_execute_raw(self, request):
data = await request.json()
try:
result = await self._server_execute_raw(
data["sql"],
tuple(data.get("params", [])) if data.get("params") else None,
)
return web.json_response(
{"result": result.rowcount if hasattr(result, "rowcount") else None}
)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_query_raw(self, request):
data = await request.json()
try:
result = await self._server_query_raw(
data["sql"],
tuple(data.get("params", [])) if data.get("params") else None,
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_query_one(self, request):
data = await request.json()
try:
result = await self._server_query_one(
data["sql"],
tuple(data.get("params", [])) if data.get("params") else None,
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_create_table(self, request):
data = await request.json()
try:
await self._server_create_table(data["table"], data["schema"], data.get("constraints"))
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_insert_unique(self, request):
data = await request.json()
try:
result = await self._server_insert_unique(
data["table"], data["args"], data["unique_fields"]
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_aggregate(self, request):
data = await request.json()
try:
result = await self._server_aggregate(
data["table"],
data["function"],
data.get("column", "*"),
data.get("where"),
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
# Public methods that delegate to server or client
async def insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
) -> Union[str, int]:
await self._ensure_initialized()
if self._is_server:
return await self._server_insert(table, args, return_id)
else:
return await self._make_request(
"insert", {"table": table, "args": args, "return_id": return_id}
)
async def update(
self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
) -> int:
await self._ensure_initialized()
if self._is_server:
return await self._server_update(table, args, where)
else:
return await self._make_request(
"update", {"table": table, "args": args, "where": where}
)
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
await self._ensure_initialized()
if self._is_server:
return await self._server_delete(table, where)
else:
return await self._make_request("delete", {"table": table, "where": where})
async def get(
self, table: str, where: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_get(table, where)
else:
return await self._make_request("get", {"table": table, "where": where})
async def find(
self,
table: str,
where: Optional[Dict[str, Any]] = None,
*,
limit: int = 0,
offset: int = 0,
order_by: Optional[str] = None,
) -> List[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_find(
table, where, limit=limit, offset=offset, order_by=order_by
)
else:
return await self._make_request(
"find",
{
"table": table,
"where": where,
"limit": limit,
"offset": offset,
"order_by": order_by,
},
)
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
await self._ensure_initialized()
if self._is_server:
return await self._server_count(table, where)
else:
return await self._make_request("count", {"table": table, "where": where})
async def exists(self, table: str, where: Dict[str, Any]) -> bool:
await self._ensure_initialized()
if self._is_server:
return await self._server_exists(table, where)
else:
return await self._make_request("exists", {"table": table, "where": where})
async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_query_raw(sql, params)
else:
return await self._make_request(
"query_raw", {"sql": sql, "params": list(params) if params else None}
)
async def create_table(
self,
table: str,
schema: Dict[str, str],
constraints: Optional[List[str]] = None,
):
await self._ensure_initialized()
if self._is_server:
return await self._server_create_table(table, schema, constraints)
else:
return await self._make_request(
"create_table",
{"table": table, "schema": schema, "constraints": constraints},
)
async def insert_unique(
self, table: str, args: Dict[str, Any], unique_fields: List[str]
) -> Union[str, None]:
await self._ensure_initialized()
if self._is_server:
return await self._server_insert_unique(table, args, unique_fields)
else:
return await self._make_request(
"insert_unique",
{"table": table, "args": args, "unique_fields": unique_fields},
)
async def aggregate(
self,
table: str,
function: str,
column: str = "*",
where: Optional[Dict[str, Any]] = None,
) -> Any:
await self._ensure_initialized()
if self._is_server:
return await self._server_aggregate(table, function, column, where)
else:
return await self._make_request(
"aggregate",
{
"table": table,
"function": function,
"column": column,
"where": where,
},
)
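# Example (table and column names are illustrative):
#     total = await ads.aggregate("orders", "SUM", "amount", {"status": "paid"})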
async def transaction(self):
"""Context manager for transactions."""
await self._ensure_initialized()
if self._is_server:
return TransactionContext(self._db_connection, self._db_semaphore)
else:
raise NotImplementedError("Transactions not supported in client mode")
# Server implementation methods (original logic)
@staticmethod
def _utc_iso() -> str:
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
@staticmethod
def _py_to_sqlite_type(value: Any) -> str:
if value is None:
return "TEXT"
if isinstance(value, bool):
return "INTEGER"
if isinstance(value, int):
return "INTEGER"
if isinstance(value, float):
return "REAL"
if isinstance(value, (bytes, bytearray, memoryview)):
return "BLOB"
return "TEXT"
async def _get_table_columns(self, table: str) -> Set[str]:
"""Get actual columns that exist in the table."""
if table in self._table_columns_cache:
return self._table_columns_cache[table]
columns = set()
try:
async with self._db_semaphore:
async with self._db_connection.execute(f"PRAGMA table_info({table})") as cursor:
async for row in cursor:
columns.add(row[1]) # Column name is at index 1
self._table_columns_cache[table] = columns
except aiosqlite.OperationalError:
pass  # table may not exist yet; return the empty column set
return columns
async def _invalidate_column_cache(self, table: str):
"""Invalidate column cache for a table."""
if table in self._table_columns_cache:
del self._table_columns_cache[table]
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
col_type = self._py_to_sqlite_type(value)
try:
async with self._db_semaphore:
await self._db_connection.execute(
f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
)
await self._db_connection.commit()
await self._invalidate_column_cache(table)
except aiosqlite.OperationalError as e:
if "duplicate column name" in str(e).lower():
pass # Column already exists
else:
raise
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
# Always include default columns
cols = self._DEFAULT_COLUMNS.copy()
# Add columns from col_sources
for key, val in col_sources.items():
if key not in cols:
cols[key] = self._py_to_sqlite_type(val)
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
async with self._db_semaphore:
await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
await self._db_connection.commit()
await self._invalidate_column_cache(table)
async def _table_exists(self, table: str) -> bool:
"""Check if a table exists."""
async with self._db_semaphore:
async with self._db_connection.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
) as cursor:
return await cursor.fetchone() is not None
_RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
_RE_NO_TABLE = re.compile(r"no such table: (\w+)")
@classmethod
def _missing_column_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]:
m = cls._RE_NO_COLUMN.search(str(err))
return m.group(1) if m else None
@classmethod
def _missing_table_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]:
m = cls._RE_NO_TABLE.search(str(err))
return m.group(1) if m else None
async def _safe_execute(
self,
table: str,
sql: str,
params: Iterable[Any],
col_sources: Dict[str, Any],
max_retries: int = 10,
) -> aiosqlite.Cursor:
retries = 0
while retries < max_retries:
try:
async with self._db_semaphore:
cursor = await self._db_connection.execute(sql, params)
await self._db_connection.commit()
return cursor
except aiosqlite.OperationalError as err:
retries += 1
err_str = str(err).lower()
# Handle missing column
col = self._missing_column_from_error(err)
if col:
if col in col_sources:
await self._ensure_column(table, col, col_sources[col])
else:
# Column not in sources, ensure it with NULL/TEXT type
await self._ensure_column(table, col, None)
continue
# Handle missing table
tbl = self._missing_table_from_error(err)
if tbl:
await self._ensure_table(tbl, col_sources)
continue
# Handle other column-related errors
if "has no column named" in err_str:
# Extract column name differently
match = re.search(r"table \w+ has no column named (\w+)", err_str)
if match:
col_name = match.group(1)
if col_name in col_sources:
await self._ensure_column(table, col_name, col_sources[col_name])
else:
await self._ensure_column(table, col_name, None)
continue
raise
raise Exception(f"Max retries ({max_retries}) exceeded")
async def _safe_query(
self,
table: str,
sql: str,
params: Iterable[Any],
col_sources: Dict[str, Any],
) -> AsyncGenerator[Dict[str, Any], None]:
# Check if table exists first
if not await self._table_exists(table):
return
max_retries = 10
retries = 0
while retries < max_retries:
try:
async with self._db_semaphore:
self._db_connection.row_factory = aiosqlite.Row
async with self._db_connection.execute(sql, params) as cursor:
# Fetch all rows while holding the semaphore
rows = await cursor.fetchall()
# Yield rows after releasing the semaphore
for row in rows:
yield dict(row)
return
except aiosqlite.OperationalError as err:
retries += 1
err_str = str(err).lower()
# Handle missing table
tbl = self._missing_table_from_error(err)
if tbl:
# For queries, if table doesn't exist, just return empty
return
# Handle missing column in WHERE clause or SELECT
if "no such column" in err_str:
# For queries with missing columns, return empty
return
raise
@staticmethod
def _build_where(where: Optional[Dict[str, Any]]) -> Tuple[str, List[Any]]:
if not where:
return "", []
clauses = []
vals = []
op_map = {"$eq": "=", "$ne": "!=", "$gt": ">", "$gte": ">=", "$lt": "<", "$lte": "<="}
for k, v in where.items():
if isinstance(v, dict):
# Handle operators like {"$lt": 10}
op_key = next(iter(v))
if op_key in op_map:
operator = op_map[op_key]
value = v[op_key]
clauses.append(f"`{k}` {operator} ?")
vals.append(value)
else:
# Fallback for unexpected dict values
clauses.append(f"`{k}` = ?")
vals.append(json.dumps(v))
else:
# Default to equality
clauses.append(f"`{k}` = ?")
vals.append(v)
return " WHERE " + " AND ".join(clauses), vals
async def _server_insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
) -> Union[str, int]:
"""Insert a record. If return_id=True, returns auto-incremented ID instead of UUID."""
uid = str(uuid4())
now = self._utc_iso()
record = {
"uid": uid,
"created_at": now,
"updated_at": now,
"deleted_at": None,
**args,
}
# Ensure table exists with all needed columns
await self._ensure_table(table, record)
# Handle auto-increment ID if requested
if return_id and "id" not in args:
# Ensure id column exists
async with self._db_semaphore:
# Add id column if it doesn't exist
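# Note: SQLite cannot add a PRIMARY KEY column via ALTER TABLE, so this
# first attempt is expected to fail on existing tables and the plain
# INTEGER fallback below is used instead.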
try:
await self._db_connection.execute(
f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT"
)
await self._db_connection.commit()
except aiosqlite.OperationalError as e:
if "duplicate column name" not in str(e).lower():
# Try without autoincrement constraint
try:
await self._db_connection.execute(
f"ALTER TABLE {table} ADD COLUMN id INTEGER"
)
await self._db_connection.commit()
except aiosqlite.OperationalError:
pass  # adding the fallback id column failed; proceed without it
await self._invalidate_column_cache(table)
# Insert and get lastrowid
cols = "`" + "`, `".join(record.keys()) + "`"
qs = ", ".join(["?"] * len(record))
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
cursor = await self._safe_execute(table, sql, list(record.values()), record)
return cursor.lastrowid
cols = "`" + "`, `".join(record) + "`"
qs = ", ".join(["?"] * len(record))
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
await self._safe_execute(table, sql, list(record.values()), record)
return uid
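# Example: insert("logs", {"msg": "hi"}) returns a UUID string, while
# insert("logs", {"msg": "hi"}, return_id=True) returns the integer rowid.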
async def _server_update(
self,
table: str,
args: Dict[str, Any],
where: Optional[Dict[str, Any]] = None,
) -> int:
if not args:
return 0
# Check if table exists
if not await self._table_exists(table):
return 0
args["updated_at"] = self._utc_iso()
# Ensure all columns exist
all_cols = {**args, **(where or {})}
await self._ensure_table(table, all_cols)
for col, val in all_cols.items():
await self._ensure_column(table, col, val)
set_clause = ", ".join(f"`{k}` = ?" for k in args)
where_clause, where_params = self._build_where(where)
sql = f"UPDATE {table} SET {set_clause}{where_clause}"
params = list(args.values()) + where_params
cur = await self._safe_execute(table, sql, params, all_cols)
return cur.rowcount
async def _server_delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
# Check if table exists
if not await self._table_exists(table):
return 0
where_clause, where_params = self._build_where(where)
sql = f"DELETE FROM {table}{where_clause}"
cur = await self._safe_execute(table, sql, where_params, where or {})
return cur.rowcount
async def _server_upsert(
self,
table: str,
args: Dict[str, Any],
where: Optional[Dict[str, Any]] = None,
) -> str | None:
if not args:
raise ValueError("Nothing to update. Empty dict given.")
args["updated_at"] = self._utc_iso()
affected = await self._server_update(table, args, where)
if affected:
rec = await self._server_get(table, where)
return rec.get("uid") if rec else None
merged = {**(where or {}), **args}
return await self._server_insert(table, merged)
async def _server_get(
self, table: str, where: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
where_clause, where_params = self._build_where(where)
sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
async for row in self._safe_query(table, sql, where_params, where or {}):
return row
return None
async def _server_find(
self,
table: str,
where: Optional[Dict[str, Any]] = None,
*,
limit: int = 0,
offset: int = 0,
order_by: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Find records with optional ordering."""
where_clause, where_params = self._build_where(where)
order_clause = f" ORDER BY {order_by}" if order_by else ""
extra = (f" LIMIT {limit}" if limit else "") + (f" OFFSET {offset}" if offset else "")
sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}"
return [row async for row in self._safe_query(table, sql, where_params, where or {})]
async def _server_count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
# Check if table exists
if not await self._table_exists(table):
return 0
where_clause, where_params = self._build_where(where)
sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
gen = self._safe_query(table, sql, where_params, where or {})
async for row in gen:
return next(iter(row.values()), 0)
return 0
async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool:
return (await self._server_count(table, where)) > 0
async def _server_kv_set(
self,
key: str,
value: Any,
*,
table: str | None = None,
) -> None:
tbl = table or self._KV_TABLE
json_val = json.dumps(value, default=str)
await self._server_upsert(tbl, {"value": json_val}, {"key": key})
async def _server_kv_get(
self,
key: str,
*,
default: Any = None,
table: str | None = None,
) -> Any:
tbl = table or self._KV_TABLE
row = await self._server_get(tbl, {"key": key})
if not row:
return default
try:
return json.loads(row["value"])
except Exception:
return default
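# Example (values round-trip through json.dumps/json.loads; key name is illustrative):
#     await ads._server_kv_set("last_run", {"ok": True})
#     state = await ads._server_kv_get("last_run", default={})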
async def _server_execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
"""Execute raw SQL for complex queries like JOINs."""
async with self._db_semaphore:
cursor = await self._db_connection.execute(sql, params or ())
await self._db_connection.commit()
return cursor
async def _server_query_raw(
self, sql: str, params: Optional[Tuple] = None
) -> List[Dict[str, Any]]:
"""Execute raw SQL query and return results as list of dicts."""
try:
async with self._db_semaphore:
self._db_connection.row_factory = aiosqlite.Row
async with self._db_connection.execute(sql, params or ()) as cursor:
return [dict(row) async for row in cursor]
except aiosqlite.OperationalError:
# Return empty list if query fails
return []
async def _server_query_one(
self, sql: str, params: Optional[Tuple] = None
) -> Optional[Dict[str, Any]]:
"""Execute raw SQL query and return single result."""
results = await self._server_query_raw(sql + " LIMIT 1", params)
return results[0] if results else None
async def _server_create_table(
self,
table: str,
schema: Dict[str, str],
constraints: Optional[List[str]] = None,
):
"""Create table with custom schema and constraints. Always includes default columns."""
# Merge default columns with custom schema
full_schema = self._DEFAULT_COLUMNS.copy()
full_schema.update(schema)
columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
if constraints:
columns.extend(constraints)
columns_sql = ", ".join(columns)
async with self._db_semaphore:
await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
await self._db_connection.commit()
await self._invalidate_column_cache(table)
async def _server_insert_unique(
self, table: str, args: Dict[str, Any], unique_fields: List[str]
) -> Union[str, None]:
"""Insert with unique constraint handling. Returns uid on success, None if duplicate."""
try:
return await self._server_insert(table, args)
except aiosqlite.IntegrityError as e:
if "UNIQUE" in str(e):
return None
raise
async def _server_aggregate(
self,
table: str,
function: str,
column: str = "*",
where: Optional[Dict[str, Any]] = None,
) -> Any:
"""Perform aggregate functions like SUM, AVG, MAX, MIN."""
# Check if table exists
if not await self._table_exists(table):
return None
where_clause, where_params = self._build_where(where)
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
result = await self._server_query_one(sql, tuple(where_params))
return result["result"] if result else None
async def close(self):
"""Close the connection or server."""
if not self._initialized:
return
if self._client_session:
await self._client_session.close()
if self._runner:
await self._runner.cleanup()
if os.path.exists(self._socket_path) and self._is_server:
os.unlink(self._socket_path)
if self._db_connection:
await self._db_connection.close()
class TransactionContext:
"""Context manager for database transactions."""
def __init__(self, db_connection: aiosqlite.Connection, semaphore: Optional[asyncio.Semaphore] = None):
self.db_connection = db_connection
self.semaphore = semaphore
async def __aenter__(self):
if self.semaphore:
await self.semaphore.acquire()
try:
await self.db_connection.execute("BEGIN")
return self.db_connection
except BaseException:  # release the semaphore on any failure, including cancellation
if self.semaphore:
self.semaphore.release()
raise
async def __aexit__(self, exc_type, exc_val, exc_tb):
try:
if exc_type is None:
await self.db_connection.commit()
else:
await self.db_connection.rollback()
finally:
if self.semaphore:
self.semaphore.release()
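# Usage sketch (server mode only; client mode raises NotImplementedError):
#     ctx = await ads.transaction()
#     async with ctx as conn:
#         await conn.execute("UPDATE ...")  # committed on success, rolled back on error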
# Test cases remain the same but with additional tests for new functionality

View File

@@ -99,7 +99,7 @@ class AgentCommunicationBus:
self.conn.commit()
def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
def receive_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
cursor = self.conn.cursor()
if unread_only:
cursor.execute(
@@ -153,9 +153,6 @@ class AgentCommunicationBus:
def close(self):
self.conn.close()
def receive_messages(self, agent_id: str) -> List[AgentMessage]:
return self.get_messages(agent_id, unread_only=True)
def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]:
cursor = self.conn.cursor()
cursor.execute(

View File

@@ -125,7 +125,7 @@ class AgentManager:
return message.message_id
def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
return self.communication_bus.get_messages(agent_id, unread_only)
return self.communication_bus.receive_messages(agent_id, unread_only)
def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]):
orchestrator = self.get_agent(orchestrator_id)

View File

@@ -229,45 +229,3 @@ def get_agent_role(role_name: str) -> AgentRole:
def list_agent_roles() -> Dict[str, AgentRole]:
return AGENT_ROLES.copy()
def get_recommended_agent(task_description: str) -> str:
task_lower = task_description.lower()
code_keywords = [
"code",
"implement",
"function",
"class",
"bug",
"debug",
"refactor",
"optimize",
]
research_keywords = [
"search",
"find",
"research",
"information",
"analyze",
"investigate",
]
data_keywords = ["data", "database", "query", "statistics", "analyze", "process"]
planning_keywords = ["plan", "organize", "workflow", "steps", "coordinate"]
testing_keywords = ["test", "verify", "validate", "check", "quality"]
doc_keywords = ["document", "documentation", "explain", "guide", "manual"]
if any(keyword in task_lower for keyword in code_keywords):
return "coding"
elif any(keyword in task_lower for keyword in research_keywords):
return "research"
elif any(keyword in task_lower for keyword in data_keywords):
return "data_analysis"
elif any(keyword in task_lower for keyword in planning_keywords):
return "planning"
elif any(keyword in task_lower for keyword in testing_keywords):
return "testing"
elif any(keyword in task_lower for keyword in doc_keywords):
return "documentation"
else:
return "general"

View File

@@ -122,18 +122,6 @@ class ToolCache:
conn.commit()
conn.close()
def invalidate_tool(self, tool_name: str):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute("DELETE FROM tool_cache WHERE tool_name = ?", (tool_name,))
deleted_count = cursor.rowcount
conn.commit()
conn.close()
return deleted_count
def clear_expired(self):
current_time = int(time.time())

View File

@@ -1,6 +1,7 @@
import json
import time
from pr.commands.multiplexer_commands import MULTIPLEXER_COMMANDS
from pr.autonomous import run_autonomous_mode
from pr.core.api import list_models
from pr.tools import read_file
@@ -13,6 +14,12 @@ def handle_command(assistant, command):
command_parts = command.strip().split(maxsplit=1)
cmd = command_parts[0].lower()
if cmd in MULTIPLEXER_COMMANDS:
# Pass the assistant object and the remaining arguments to the multiplexer command
return MULTIPLEXER_COMMANDS[cmd](
assistant, command_parts[1:] if len(command_parts) > 1 else []
)
if cmd == "/edit":
rp_editor = RPEditor(command_parts[1] if len(command_parts) > 1 else None)
rp_editor.start()
@@ -23,6 +30,18 @@ def handle_command(assistant, command):
if task:
run_autonomous_mode(assistant, task)
elif cmd == "/prompt":
rp_editor = RPEditor(command_parts[1] if len(command_parts) > 1 else None)
rp_editor.start()
rp_editor.thread.join()
prompt_text = str(rp_editor.get_text())
rp_editor.stop()
rp_editor = None
if prompt_text.strip():
from pr.core.assistant import process_message
process_message(assistant, prompt_text)
elif cmd == "/auto":
if len(command_parts) < 2:
print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}")

pr/commands/help_docs.py (new file, 507 lines)
View File

@@ -0,0 +1,507 @@
from pr.ui import Colors
def get_workflow_help():
return f"""
{Colors.BOLD}WORKFLOWS - AUTOMATED TASK EXECUTION{Colors.RESET}
{Colors.BOLD}Overview:{Colors.RESET}
Workflows enable automated execution of multi-step tasks by chaining together
tool calls with defined execution logic, error handling, and conditional branching.
{Colors.BOLD}Execution Modes:{Colors.RESET}
{Colors.CYAN}sequential{Colors.RESET} - Steps execute one after another in order
{Colors.CYAN}parallel{Colors.RESET} - All steps execute simultaneously
{Colors.CYAN}conditional{Colors.RESET} - Steps execute based on conditions and paths
{Colors.BOLD}Workflow Components:{Colors.RESET}
{Colors.CYAN}Steps{Colors.RESET} - Individual tool executions with parameters
{Colors.CYAN}Variables{Colors.RESET} - Dynamic values available across workflow
{Colors.CYAN}Conditions{Colors.RESET} - Boolean expressions controlling step execution
{Colors.CYAN}Paths{Colors.RESET} - Success/failure routing between steps
{Colors.CYAN}Retries{Colors.RESET} - Automatic retry count for failed steps
{Colors.CYAN}Timeouts{Colors.RESET} - Maximum execution time per step (seconds)
{Colors.BOLD}Variable Substitution:{Colors.RESET}
Reference previous step results: ${{step.step_id}}
Reference workflow variables: ${{var.variable_name}}
Example: {{"path": "${{step.get_path}}/output.txt"}}
{Colors.BOLD}Workflow Commands:{Colors.RESET}
{Colors.CYAN}/workflows{Colors.RESET} - List all available workflows
{Colors.CYAN}/workflow <name>{Colors.RESET} - Execute a specific workflow
{Colors.BOLD}Workflow Structure:{Colors.RESET}
Each workflow consists of:
- Name and description
- Execution mode (sequential/parallel/conditional)
- List of steps with tool calls
- Optional variables for dynamic values
- Optional tags for categorization
{Colors.BOLD}Step Configuration:{Colors.RESET}
Each step includes:
- tool_name: Which tool to execute
- arguments: Parameters passed to the tool
- step_id: Unique identifier for the step
- condition: Optional boolean expression (conditional mode)
- on_success: List of step IDs to execute on success
- on_failure: List of step IDs to execute on failure
- retry_count: Number of automatic retries (default: 0)
- timeout_seconds: Maximum execution time (default: 300)
{Colors.BOLD}Conditional Execution:{Colors.RESET}
Conditions are Python expressions evaluated with access to:
- variables: All workflow variables
- results: All step results by step_id
Example condition: "results['check_file']['status'] == 'success'"
{Colors.BOLD}Success/Failure Paths:{Colors.RESET}
Steps can define:
- on_success: List of step IDs to execute if step succeeds
- on_failure: List of step IDs to execute if step fails
This enables error recovery and branching logic within workflows.
{Colors.BOLD}Workflow Storage:{Colors.RESET}
Workflows are stored in the SQLite database and persist across sessions.
Execution history and statistics are tracked for each workflow.
{Colors.BOLD}Example Workflow Scenarios:{Colors.RESET}
- Sequential: Data pipeline (fetch → process → save → validate)
- Parallel: Multi-source aggregation (fetch A + fetch B + fetch C)
- Conditional: Error recovery (try main → on failure → fallback)
{Colors.BOLD}Best Practices:{Colors.RESET}
- Use descriptive step_id names (e.g., "fetch_user_data")
- Set appropriate timeouts for long-running operations
- Add retry_count for network or flaky operations
- Use conditional paths for error handling
- Keep workflows focused on a single logical task
- Use variables for values shared across steps
"""
def get_agent_help():
return f"""
{Colors.BOLD}AGENTS - SPECIALIZED AI ASSISTANTS{Colors.RESET}
{Colors.BOLD}Overview:{Colors.RESET}
Agents are specialized AI instances with specific roles, capabilities, and
system prompts. Each agent maintains its own conversation history and can
collaborate with other agents to complete complex tasks.
{Colors.BOLD}Available Agent Roles:{Colors.RESET}
{Colors.CYAN}coding{Colors.RESET} - Code writing, review, debugging, refactoring
Temperature: 0.3 (precise)
{Colors.CYAN}research{Colors.RESET} - Information gathering, analysis, documentation
Temperature: 0.5 (balanced)
{Colors.CYAN}data_analysis{Colors.RESET} - Data processing, statistical analysis, queries
Temperature: 0.3 (precise)
{Colors.CYAN}planning{Colors.RESET} - Task decomposition, workflow design, coordination
Temperature: 0.6 (creative)
{Colors.CYAN}testing{Colors.RESET} - Test design, quality assurance, validation
Temperature: 0.4 (methodical)
{Colors.CYAN}documentation{Colors.RESET} - Technical writing, API docs, user guides
Temperature: 0.6 (creative)
{Colors.CYAN}orchestrator{Colors.RESET} - Multi-agent coordination and task delegation
Temperature: 0.5 (balanced)
{Colors.CYAN}general{Colors.RESET} - General purpose assistance across domains
Temperature: 0.7 (versatile)
{Colors.BOLD}Agent Commands:{Colors.RESET}
{Colors.CYAN}/agent <role> <task>{Colors.RESET} - Create agent and assign task
{Colors.CYAN}/agents{Colors.RESET} - Show all active agents
{Colors.CYAN}/collaborate <task>{Colors.RESET} - Multi-agent collaboration
{Colors.BOLD}Single Agent Usage:{Colors.RESET}
Create a specialized agent and assign it a task:
/agent coding Implement a binary search function in Python
/agent research Find latest best practices for API security
/agent testing Create test cases for user authentication
Each agent:
- Maintains its own conversation history
- Has access to role-specific tools
- Uses specialized system prompts
- Tracks task completion count
- Integrates with knowledge base
{Colors.BOLD}Multi-Agent Collaboration:{Colors.RESET}
Use multiple agents to work together on complex tasks:
/collaborate Build a REST API with tests and documentation
This automatically:
1. Creates an orchestrator agent
2. Spawns coding, research, and planning agents
3. Orchestrator delegates subtasks to specialists
4. Agents communicate results back to orchestrator
5. Orchestrator integrates results and coordinates
{Colors.BOLD}Agent Capabilities by Role:{Colors.RESET}
{Colors.CYAN}Coding Agent:{Colors.RESET}
- Read, write, and modify files
- Execute Python code
- Run commands and build tools
- Index source directories
- Code review and refactoring
{Colors.CYAN}Research Agent:{Colors.RESET}
- Web search and HTTP fetching
- Read documentation and files
- Query databases for information
- Analyze and synthesize findings
{Colors.CYAN}Data Analysis Agent:{Colors.RESET}
- Database queries and operations
- Python execution for analysis
- File read/write for data
- Statistical processing
{Colors.CYAN}Planning Agent:{Colors.RESET}
- Task decomposition and organization
- Database storage for plans
- File operations for documentation
- Dependency identification
{Colors.CYAN}Testing Agent:{Colors.RESET}
- Test execution and validation
- Code reading and analysis
- Command execution for test runners
- Quality verification
{Colors.CYAN}Documentation Agent:{Colors.RESET}
- Technical writing and formatting
- Web research for best practices
- File operations for docs
- Documentation organization
{Colors.BOLD}Agent Session Management:{Colors.RESET}
- Each session has a unique ID
- Agents persist within a session
- View active agents with /agents
- Agents share knowledge base access
- Communication bus for agent messages
{Colors.BOLD}Knowledge Base Integration:{Colors.RESET}
When agents execute tasks:
- Relevant knowledge entries are automatically retrieved
- Top 3 matches are injected into agent context
- Knowledge access count is tracked
- Agents can update knowledge base
{Colors.BOLD}Agent Communication:{Colors.RESET}
Agents can send messages to each other:
- Request type: Ask another agent for help
- Response type: Answer another agent's request
- Notification type: Inform about events
- Coordination type: Synchronize work
{Colors.BOLD}Best Practices:{Colors.RESET}
- Use specialized agents for focused tasks
- Use collaboration for complex multi-domain tasks
- Coding agent: Precise, deterministic code generation
- Research agent: Thorough information gathering
- Planning agent: Breaking down complex tasks
- Testing agent: Comprehensive test coverage
- Documentation agent: Clear user-facing docs
{Colors.BOLD}Performance Characteristics:{Colors.RESET}
- Lower temperature: More deterministic (coding, data_analysis)
- Higher temperature: More creative (planning, documentation)
- Max tokens: 4096 for all agents
- Parallel execution: Multiple agents work simultaneously
- Message history: Full conversation context maintained
{Colors.BOLD}Example Usage Patterns:{Colors.RESET}
Single specialized agent:
/agent coding Refactor authentication module for better security
Multi-agent collaboration:
/collaborate Create a data visualization dashboard with tests
Agent inspection:
/agents # View all active agents and their stats
"""
def get_knowledge_help():
return f"""
{Colors.BOLD}KNOWLEDGE BASE - PERSISTENT INFORMATION STORAGE{Colors.RESET}
{Colors.BOLD}Overview:{Colors.RESET}
The knowledge base provides persistent storage for information, facts, and
context that can be accessed across sessions and by all agents.
{Colors.BOLD}Commands:{Colors.RESET}
{Colors.CYAN}/remember <content>{Colors.RESET} - Store information in knowledge base
{Colors.CYAN}/knowledge <query>{Colors.RESET} - Search knowledge base
{Colors.CYAN}/stats{Colors.RESET} - Show knowledge base statistics
{Colors.BOLD}Features:{Colors.RESET}
- Automatic categorization of content
- TF-IDF based semantic search
- Access frequency tracking
- Importance scoring
- Temporal relevance
- Category-based organization
{Colors.BOLD}Categories:{Colors.RESET}
Content is automatically categorized as:
- code: Code snippets and programming concepts
- api: API documentation and endpoints
- configuration: Settings and configurations
- documentation: General documentation
- fact: Factual information
- procedure: Step-by-step procedures
- general: Uncategorized information
{Colors.BOLD}Search Capabilities:{Colors.RESET}
- Keyword matching with TF-IDF scoring
- Returns top relevant entries
- Considers access frequency
- Factors in recency
- Searches across all categories
{Colors.BOLD}Integration:{Colors.RESET}
- Agents automatically query knowledge base
- Top matches injected into agent context
- Conversation history can be stored
- Facts extracted from interactions
- Persistent across sessions
"""
def get_cache_help():
return f"""
{Colors.BOLD}CACHING SYSTEM - PERFORMANCE OPTIMIZATION{Colors.RESET}
{Colors.BOLD}Overview:{Colors.RESET}
The caching system optimizes performance by storing API responses and tool
results, reducing redundant API calls and expensive operations.
{Colors.BOLD}Commands:{Colors.RESET}
{Colors.CYAN}/cache{Colors.RESET} - Show cache statistics
{Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches
{Colors.BOLD}Cache Types:{Colors.RESET}
{Colors.CYAN}API Cache:{Colors.RESET}
- Stores LLM API responses
- Key: Hash of messages + model + parameters
- Tracks token usage
- TTL: Configurable expiration
- Reduces API costs and latency
{Colors.CYAN}Tool Cache:{Colors.RESET}
- Stores tool execution results
- Key: Tool name + arguments hash
- Per-tool statistics
- Hit count tracking
- Reduces redundant operations
{Colors.BOLD}Cache Statistics:{Colors.RESET}
- Total entries (valid + expired)
- Valid entry count
- Expired entry count
- Total cached tokens (API cache)
- Cache hits per tool (Tool cache)
- Per-tool breakdown
{Colors.BOLD}Benefits:{Colors.RESET}
- Reduced API costs
- Faster response times
- Lower rate limit usage
- Consistent results for identical queries
- Tool execution optimization
"""
def get_background_help():
return f"""
{Colors.BOLD}BACKGROUND SESSIONS - CONCURRENT TASK EXECUTION{Colors.RESET}
{Colors.BOLD}Overview:{Colors.RESET}
Background sessions allow running long-running commands and processes while
continuing to interact with the assistant. Sessions are monitored and can
be managed through the assistant.
{Colors.BOLD}Commands:{Colors.RESET}
{Colors.CYAN}/bg start <command>{Colors.RESET} - Start command in background
{Colors.CYAN}/bg list{Colors.RESET} - List all background sessions
{Colors.CYAN}/bg status <name>{Colors.RESET} - Show session status
{Colors.CYAN}/bg output <name>{Colors.RESET} - View session output
{Colors.CYAN}/bg input <name> <text>{Colors.RESET} - Send input to session
{Colors.CYAN}/bg kill <name>{Colors.RESET} - Terminate session
{Colors.CYAN}/bg events{Colors.RESET} - Show recent events
{Colors.BOLD}Features:{Colors.RESET}
- Automatic session monitoring
- Output capture and buffering
- Interactive input support
- Event notification system
- Session lifecycle management
- Background event detection
{Colors.BOLD}Event Types:{Colors.RESET}
- session_started: New session created
- session_ended: Session terminated
- output_received: New output available
- possible_input_needed: May require user input
- high_output_volume: Large amount of output
- inactive_session: No activity for period
{Colors.BOLD}Status Indicators:{Colors.RESET}
The prompt shows background session count: You[2bg]>
- Number indicates active background sessions
- Updates automatically as sessions start/stop
{Colors.BOLD}Use Cases:{Colors.RESET}
- Long-running build processes
- Web servers and daemons
- File watching and monitoring
- Batch processing tasks
- Test suites execution
- Development servers
"""
def get_full_help():
return f"""
{Colors.BOLD}R - PROFESSIONAL AI ASSISTANT{Colors.RESET}
{Colors.BOLD}BASIC COMMANDS{Colors.RESET}
exit, quit, q - Exit the assistant
/help - Show this help message
/help workflows - Detailed workflow documentation
/help agents - Detailed agent documentation
/help knowledge - Knowledge base documentation
/help cache - Caching system documentation
/help background - Background sessions documentation
/reset - Clear message history
/dump - Show message history as JSON
/verbose - Toggle verbose mode
/models - List available models
/tools - List available tools
{Colors.BOLD}FILE OPERATIONS{Colors.RESET}
/review <file> - Review a file
/refactor <file> - Refactor code in a file
/obfuscate <file> - Obfuscate code in a file
{Colors.BOLD}AUTONOMOUS MODE{Colors.RESET}
{Colors.CYAN}/auto <task>{Colors.RESET} - Enter autonomous mode for task completion
Assistant works continuously until task is complete
Max 50 iterations with automatic context management
Press Ctrl+C twice to force exit
{Colors.BOLD}WORKFLOWS{Colors.RESET}
{Colors.CYAN}/workflows{Colors.RESET} - List all available workflows
{Colors.CYAN}/workflow <name>{Colors.RESET} - Execute a specific workflow
Workflows enable automated multi-step task execution with:
- Sequential, parallel, or conditional execution
- Variable substitution and step dependencies
- Error handling and retry logic
- Success/failure path routing
For detailed documentation: /help workflows
{Colors.BOLD}AGENTS{Colors.RESET}
{Colors.CYAN}/agent <role> <task>{Colors.RESET} - Create specialized agent
{Colors.CYAN}/agents{Colors.RESET} - Show active agents
{Colors.CYAN}/collaborate <task>{Colors.RESET} - Multi-agent collaboration
Available roles: coding, research, data_analysis, planning,
testing, documentation, orchestrator, general
Each agent has specialized capabilities and system prompts.
Agents can work independently or collaborate on complex tasks.
For detailed documentation: /help agents
{Colors.BOLD}KNOWLEDGE BASE{Colors.RESET}
{Colors.CYAN}/knowledge <query>{Colors.RESET} - Search knowledge base
{Colors.CYAN}/remember <content>{Colors.RESET} - Store information
Persistent storage for facts, procedures, and context.
Automatic categorization and TF-IDF search.
Integrated with agents for context injection.
For detailed documentation: /help knowledge
{Colors.BOLD}SESSION MANAGEMENT{Colors.RESET}
{Colors.CYAN}/history{Colors.RESET} - Show conversation history
{Colors.CYAN}/cache{Colors.RESET} - Show cache statistics
{Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches
{Colors.CYAN}/stats{Colors.RESET} - Show system statistics
{Colors.BOLD}BACKGROUND SESSIONS{Colors.RESET}
{Colors.CYAN}/bg start <command>{Colors.RESET} - Start background session
{Colors.CYAN}/bg list{Colors.RESET} - List active sessions
{Colors.CYAN}/bg output <name>{Colors.RESET} - View session output
{Colors.CYAN}/bg kill <name>{Colors.RESET} - Terminate session
Run long-running processes while maintaining interactivity.
Automatic monitoring and event notifications.
For detailed documentation: /help background
{Colors.BOLD}CONTEXT FILES{Colors.RESET}
.rcontext.txt - Local project context (auto-loaded)
~/.rcontext.txt - Global context (auto-loaded)
-c, --context FILE - Additional context files (command line)
{Colors.BOLD}ENVIRONMENT VARIABLES{Colors.RESET}
OPENROUTER_API_KEY - API key for OpenRouter
AI_MODEL - Default model to use
API_URL - Custom API endpoint
USE_TOOLS - Enable/disable tools (default: 1)
{Colors.BOLD}COMMAND-LINE FLAGS{Colors.RESET}
-i, --interactive - Start in interactive mode
-v, --verbose - Enable verbose output
--debug - Enable debug logging
-m, --model MODEL - Specify AI model
--no-syntax - Disable syntax highlighting
{Colors.BOLD}DATA STORAGE{Colors.RESET}
~/.assistant_db.sqlite - SQLite database for persistence
~/.assistant_history - Command history
~/.assistant_error.log - Error logs
{Colors.BOLD}AVAILABLE TOOLS{Colors.RESET}
Tools are functions the AI can call to interact with the system:
- File operations: read, write, list, mkdir, search/replace
- Command execution: run_command, interactive sessions
- Web operations: http_fetch, web_search, web_search_news
- Database: db_set, db_get, db_query
- Python execution: python_exec with persistent globals
- Code editing: open_editor, insert/replace text, diff/patch
- Process management: tail, kill, interactive control
- Agents: create, execute tasks, collaborate
- Knowledge: add, search, categorize entries
{Colors.BOLD}GETTING HELP{Colors.RESET}
/help - This help message
/help workflows - Workflow system details
/help agents - Agent system details
/help knowledge - Knowledge base details
/help cache - Cache system details
/help background - Background sessions details
/tools - List all available tools
/models - List all available models
"""

View File

@@ -1,8 +1,9 @@
import os
DEFAULT_MODEL = "x-ai/grok-code-fast-1"
DEFAULT_API_URL = "http://localhost:8118/ai/chat"
MODEL_LIST_URL = "http://localhost:8118/ai/models"
DEFAULT_API_URL = "https://static.molodetz.nl/rp.cgi/api/v1/chat/completions"
MODEL_LIST_URL = "https://static.molodetz.nl/rp.cgi/api/v1/models"
config_directory = os.path.expanduser("~/.local/share/rp")
os.makedirs(config_directory, exist_ok=True)

View File

@@ -7,44 +7,45 @@ class AdvancedContextManager:
self.knowledge_store = knowledge_store
self.conversation_memory = conversation_memory
def adaptive_context_window(
self, messages: List[Dict[str, Any]], task_complexity: str = "medium"
) -> int:
complexity_thresholds = {
"simple": 10,
"medium": 20,
"complex": 35,
"very_complex": 50,
def adaptive_context_window(self, messages: List[Dict[str, Any]], complexity: str) -> int:
"""Calculate adaptive context window size based on message complexity."""
base_window = 10
complexity_multipliers = {
"simple": 1.0,
"medium": 2.0,
"complex": 3.5,
"very_complex": 5.0,
}
base_threshold = complexity_thresholds.get(task_complexity, 20)
message_complexity_score = self._analyze_message_complexity(messages)
if message_complexity_score > 0.7:
adjusted = int(base_threshold * 1.5)
elif message_complexity_score < 0.3:
adjusted = int(base_threshold * 0.7)
else:
adjusted = base_threshold
return max(base_threshold, adjusted)
multiplier = complexity_multipliers.get(complexity, 2.0)
return int(base_window * multiplier)
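# Worked example: adaptive_context_window(msgs, "complex") -> int(10 * 3.5) = 35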
def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float:
total_length = sum(len(msg.get("content", "")) for msg in messages)
avg_length = total_length / len(messages) if messages else 0
"""Analyze the complexity of messages and return a score between 0.0 and 1.0."""
if not messages:
return 0.0
unique_words = set()
for msg in messages:
content = msg.get("content", "")
words = re.findall(r"\b\w+\b", content.lower())
unique_words.update(words)
total_complexity = 0.0
for message in messages:
content = message.get("content", "")
if not content:
continue
vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0
# Calculate complexity based on various factors
word_count = len(content.split())
sentence_count = len(re.split(r"[.!?]+", content))
avg_word_length = sum(len(word) for word in content.split()) / max(word_count, 1)
# Simple complexity score based on length and richness
complexity = min(1.0, (avg_length / 100) + vocabulary_richness)
return complexity
# Complexity score based on length, vocabulary, and structure
length_score = min(1.0, word_count / 100) # Normalize to 0-1
structure_score = min(1.0, sentence_count / 10)
vocabulary_score = min(1.0, avg_word_length / 8)
message_complexity = (length_score + structure_score + vocabulary_score) / 3
total_complexity += message_complexity
return min(1.0, total_complexity / len(messages))
def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]:
if not text.strip():
@@ -82,3 +83,38 @@ class AdvancedContextManager:
return 0.0
return len(intersection) / len(union)
def create_enhanced_context(
self, messages: List[Dict[str, Any]], user_message: str, include_knowledge: bool = True
) -> tuple:
"""Create enhanced context with knowledge base integration."""
working_messages = messages.copy()
if include_knowledge and self.knowledge_store:
# Search knowledge base for relevant information
search_results = self.knowledge_store.search_entries(user_message, top_k=3)
if search_results:
knowledge_parts = []
for idx, entry in enumerate(search_results, 1):
content = entry.content
if len(content) > 2000:
content = content[:2000] + "..."
knowledge_parts.append(f"Match {idx} (Category: {entry.category}):\n{content}")
knowledge_message_content = (
"[KNOWLEDGE_BASE_CONTEXT]\nRelevant knowledge base entries:\n\n"
+ "\n\n".join(knowledge_parts)
)
knowledge_message = {"role": "user", "content": knowledge_message_content}
working_messages.append(knowledge_message)
context_info = f"Added {len(search_results)} knowledge base entries"
else:
context_info = "No relevant knowledge base entries found"
else:
context_info = "Knowledge base integration disabled"
return working_messages, context_info

View File

@@ -67,6 +67,15 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
if "tool_calls" in msg:
logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")
if verbose and "usage" in result:
from pr.core.usage_tracker import UsageTracker
usage = result["usage"]
input_t = usage.get("prompt_tokens", 0)
output_t = usage.get("completion_tokens", 0)
cost = UsageTracker._calculate_cost(model, input_t, output_t)
print(f"API call cost: €{cost:.4f}")
logger.debug("=== API CALL END ===")
return result

View File

@@ -29,34 +29,28 @@ from pr.core.background_monitor import (
stop_global_monitor,
)
from pr.core.context import init_system_message, truncate_tool_result
from pr.tools import (
apply_patch,
chdir,
create_diff,
db_get,
db_query,
db_set,
getpwd,
http_fetch,
index_source_directory,
kill_process,
list_directory,
mkdir,
python_exec,
read_file,
run_command,
search_replace,
tail_process,
web_search,
web_search_news,
write_file,
post_image,
from pr.tools import get_tools_definition
from pr.tools.agents import (
collaborate_agents,
create_agent,
execute_agent_task,
list_agents,
remove_agent,
)
from pr.tools.base import get_tools_definition
from pr.tools.command import kill_process, run_command, tail_process
from pr.tools.database import db_get, db_query, db_set
from pr.tools.filesystem import (
chdir,
clear_edit_tracker,
display_edit_summary,
display_edit_timeline,
getpwd,
index_source_directory,
list_directory,
mkdir,
read_file,
search_replace,
write_file,
)
from pr.tools.interactive_control import (
close_interactive_session,
@@ -65,7 +59,18 @@ from pr.tools.interactive_control import (
send_input_to_session,
start_interactive_session,
)
from pr.tools.patch import display_file_diff
from pr.tools.memory import (
add_knowledge_entry,
delete_knowledge_entry,
get_knowledge_by_category,
get_knowledge_entry,
get_knowledge_statistics,
search_knowledge,
update_knowledge_importance,
)
from pr.tools.patch import apply_patch, create_diff, display_file_diff
from pr.tools.python_exec import python_exec
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.ui import Colors, render_markdown
logger = logging.getLogger("pr")
@ -97,7 +102,7 @@ class Assistant:
"MODEL_LIST_URL", MODEL_LIST_URL
)
self.use_tools = os.environ.get("USE_TOOLS", "1") == "1"
self.strict_mode = os.environ.get("STRICT_MODE", "0") == "1"
self.interrupt_count = 0
self.python_globals = {}
self.db_conn = None
@ -245,7 +250,6 @@ class Assistant:
logger.debug(f"Tool call: {func_name} with arguments: {arguments}")
func_map = {
"post_image": lambda **kw: post_image(**kw),
"http_fetch": lambda **kw: http_fetch(**kw),
"run_command": lambda **kw: run_command(**kw),
"tail_process": lambda **kw: tail_process(**kw),


@ -27,15 +27,6 @@ class BackgroundMonitor:
if self.monitor_thread:
self.monitor_thread.join(timeout=2)
def add_event_callback(self, callback):
"""Add a callback function to be called when events are detected."""
self.event_callbacks.append(callback)
def remove_event_callback(self, callback):
"""Remove an event callback."""
if callback in self.event_callbacks:
self.event_callbacks.remove(callback)
def get_pending_events(self):
"""Get all pending events from the queue."""
events = []
@ -224,30 +215,3 @@ def stop_global_monitor():
global _global_monitor
if _global_monitor:
_global_monitor.stop()
# Global monitor instance
_global_monitor = None
def start_global_monitor():
"""Start the global background monitor."""
global _global_monitor
if _global_monitor is None:
_global_monitor = BackgroundMonitor()
_global_monitor.start()
return _global_monitor
def stop_global_monitor():
"""Stop the global background monitor."""
global _global_monitor
if _global_monitor:
_global_monitor.stop()
_global_monitor = None
def get_global_monitor():
"""Get the global background monitor instance."""
global _global_monitor
return _global_monitor


@ -1,7 +1,6 @@
import configparser
import os
from typing import Any, Dict
import uuid
from pr.core.logging import get_logger
logger = get_logger("config")
@ -11,29 +10,6 @@ CONFIG_FILE = os.path.join(CONFIG_DIRECTORY, ".prrc")
LOCAL_CONFIG_FILE = ".prrc"
def load_config() -> Dict[str, Any]:
os.makedirs(CONFIG_DIRECTORY, exist_ok=True)
config = {
"api": {},
"autonomous": {},
"ui": {},
"output": {},
"session": {},
"api_key": "rp-" + str(uuid.uuid4()),
}
global_config = _load_config_file(CONFIG_FILE)
local_config = _load_config_file(LOCAL_CONFIG_FILE)
for section in config.keys():
if section in global_config:
config[section].update(global_config[section])
if section in local_config:
config[section].update(local_config[section])
return config
def _load_config_file(filepath: str) -> Dict[str, Dict[str, Any]]:
if not os.path.exists(filepath):
return {}


@ -12,7 +12,6 @@ from pr.config import (
CONVERSATION_SUMMARY_THRESHOLD,
DB_PATH,
KNOWLEDGE_SEARCH_LIMIT,
MEMORY_AUTO_SUMMARIZE,
TOOL_CACHE_TTL,
WORKFLOW_EXECUTOR_MAX_WORKERS,
)
@ -169,9 +168,9 @@ class EnhancedAssistant:
self.current_conversation_id, str(uuid.uuid4())[:16], "user", user_message
)
if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0:
# Automatically extract and store facts from every user message
facts = self.fact_extractor.extract_facts(user_message)
for fact in facts[:3]:
for fact in facts[:5]: # Store up to 5 facts per message
entry_id = str(uuid.uuid4())[:16]
import time
@ -182,7 +181,11 @@ class EnhancedAssistant:
entry_id=entry_id,
category=categories[0] if categories else "general",
content=fact["text"],
metadata={"type": fact["type"], "confidence": fact["confidence"]},
metadata={
"type": fact["type"],
"confidence": fact["confidence"],
"source": "user_message",
},
created_at=time.time(),
updated_at=time.time(),
)


@ -0,0 +1,148 @@
import logging
logger = logging.getLogger("pr")
KNOWLEDGE_MESSAGE_MARKER = "[KNOWLEDGE_BASE_CONTEXT]"
def inject_knowledge_context(assistant, user_message):
if not hasattr(assistant, "enhanced") or not assistant.enhanced:
return
messages = assistant.messages
# Remove any existing knowledge context messages
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == "user" and KNOWLEDGE_MESSAGE_MARKER in messages[i].get(
"content", ""
):
del messages[i]
logger.debug(f"Removed existing knowledge base message at index {i}")
break
try:
# Search knowledge base with enhanced FTS + semantic search
knowledge_results = assistant.enhanced.knowledge_store.search_entries(user_message, top_k=5)
# Search conversation history for related content
conversation_results = []
if hasattr(assistant.enhanced, "conversation_memory"):
history_results = assistant.enhanced.conversation_memory.search_conversations(
user_message, limit=3
)
for conv in history_results:
# Extract relevant messages from conversation
conv_messages = assistant.enhanced.conversation_memory.get_conversation_messages(
conv["conversation_id"]
)
for msg in conv_messages[-5:]: # Last 5 messages from each conversation
if msg["role"] == "user" and msg["content"] != user_message:
# Calculate relevance score
relevance = calculate_text_similarity(user_message, msg["content"])
if relevance > 0.3: # Only include relevant matches
conversation_results.append(
{
"content": msg["content"],
"score": relevance,
"source": f"Previous conversation: {conv['conversation_id'][:8]}",
}
)
# Combine and sort results by relevance score
all_results = []
# Add knowledge base results
for entry in knowledge_results:
score = entry.metadata.get("search_score", 0.5)
all_results.append(
{
"content": entry.content,
"score": score,
"source": f"Knowledge Base ({entry.category})",
"type": "knowledge",
}
)
# Add conversation results
for conv in conversation_results:
all_results.append(
{
"content": conv["content"],
"score": conv["score"],
"source": conv["source"],
"type": "conversation",
}
)
# Sort by score and take top 5
all_results.sort(key=lambda x: x["score"], reverse=True)
top_results = all_results[:5]
if not top_results:
logger.debug("No relevant knowledge or conversation matches found")
return
# Format context for LLM
knowledge_parts = []
for idx, result in enumerate(top_results, 1):
content = result["content"]
if len(content) > 1500: # Shorter limit for multiple results
content = content[:1500] + "..."
score_indicator = f"({result['score']:.2f})" if result["score"] < 1.0 else "(exact)"
knowledge_parts.append(
f"Match {idx} {score_indicator} - {result['source']}:\n{content}"
)
knowledge_message_content = (
f"{KNOWLEDGE_MESSAGE_MARKER}\nRelevant information from knowledge base and conversation history:\n\n"
+ "\n\n".join(knowledge_parts)
)
knowledge_message = {"role": "user", "content": knowledge_message_content}
messages.append(knowledge_message)
logger.debug(f"Injected enhanced context message with {len(top_results)} matches")
except Exception as e:
logger.error(f"Error injecting knowledge context: {e}")
def calculate_text_similarity(text1: str, text2: str) -> float:
"""Calculate similarity between two texts using word overlap and sequence matching."""
import re
# Normalize texts
text1_lower = text1.lower()
text2_lower = text2.lower()
# Exact substring match gets highest score
if text1_lower in text2_lower or text2_lower in text1_lower:
return 1.0
# Word-level similarity
words1 = set(re.findall(r"\b\w+\b", text1_lower))
words2 = set(re.findall(r"\b\w+\b", text2_lower))
if not words1 or not words2:
return 0.0
intersection = words1 & words2
union = words1 | words2
word_similarity = len(intersection) / len(union)
# Bonus for consecutive word sequences (partial sentences)
consecutive_bonus = 0.0
words1_list = list(words1)
for i in range(len(words1_list) - 1):
for j in range(i + 2, min(i + 5, len(words1_list) + 1)):
phrase = " ".join(words1_list[i:j])
if phrase in text2_lower:
consecutive_bonus += 0.1 * (j - i)
total_similarity = min(1.0, word_similarity + consecutive_bonus)
return total_similarity
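A quick sanity check of the scoring above (the substring rule short-circuits before any word-level math):

calculate_text_similarity(
    "configure the server port",
    "How do I configure the server port?",
)  # -> 1.0, exact substring match after lowercasing
# Texts without a containment relation fall through to word overlap plus
# the consecutive-phrase bonus, capped at 1.0 by min().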


@ -9,8 +9,10 @@ logger = get_logger("usage")
USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json")
EXCHANGE_RATE = 1.0 # Keep in USD
MODEL_COSTS = {
"x-ai/grok-code-fast-1": {"input": 0.0, "output": 0.0},
"x-ai/grok-code-fast-1": {"input": 0.0002, "output": 0.0015}, # per 1000 tokens in USD
"gpt-4": {"input": 0.03, "output": 0.06},
"gpt-4-turbo": {"input": 0.01, "output": 0.03},
"gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
@ -64,9 +66,10 @@ class UsageTracker:
self._save_to_history(model, input_tokens, output_tokens, cost)
logger.debug(f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}")
logger.debug(f"Tracked request: {model}, tokens: {total_tokens}, cost: {cost:.4f}")
def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
@staticmethod
def _calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
if model not in MODEL_COSTS:
base_model = model.split("/")[0] if "/" in model else model
if base_model not in MODEL_COSTS:
@ -76,8 +79,8 @@ class UsageTracker:
else:
costs = MODEL_COSTS[model]
input_cost = (input_tokens / 1000) * costs["input"]
output_cost = (output_tokens / 1000) * costs["output"]
input_cost = (input_tokens / 1000) * costs["input"] * EXCHANGE_RATE
output_cost = (output_tokens / 1000) * costs["output"] * EXCHANGE_RATE
return input_cost + output_cost
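A worked example of the per-1000-token pricing above, using the grok-code-fast-1 rates and EXCHANGE_RATE = 1.0:

# 12,000 prompt tokens and 2,000 completion tokens:
#   input  = (12000 / 1000) * 0.0002 = 0.0024
#   output = ( 2000 / 1000) * 0.0015 = 0.0030
cost = UsageTracker._calculate_cost("x-ai/grok-code-fast-1", 12000, 2000)
# cost == 0.0054 (USD)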


@ -370,8 +370,10 @@ class RPEditor:
self.cursor_x = 0
elif key == ord("u"):
self.undo()
elif key == ord("r") and self.prev_key == 18: # Ctrl-R
elif key == 18: # Ctrl-R
self.redo()
elif key == 19: # Ctrl-S
self._save_file()
self.prev_key = key
except Exception:
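The raw key codes above follow the usual terminal convention: Ctrl+<letter> arrives as ord(letter) & 0x1f, so Ctrl-R is 18 and Ctrl-S is 19.

assert ord("r") & 0x1f == 18  # Ctrl-R -> redo
assert ord("s") & 0x1f == 19  # Ctrl-S -> save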

589 pr/implode.py Normal file

@ -0,0 +1,589 @@
#!/usr/bin/env python3
"""
impLODE: A Python script to consolidate a multi-file Python project
into a single, runnable file.
It intelligently resolves local imports, hoists external dependencies to the top,
and preserves the core logic, using AST for safe transformations.
"""
import os
import sys
import ast
import argparse
import logging
import py_compile # Added
from typing import Set, Dict, Optional, TextIO
# Set up the main logger
logger = logging.getLogger("impLODE")
class ImportTransformer(ast.NodeTransformer):
"""
An AST transformer that visits Import and ImportFrom nodes.
On Pass 1 (Dry Run):
- Identifies all local vs. external imports.
- Recursively calls the main resolver for local modules.
- Stores external and __future__ imports in the Imploder instance.
On Pass 2 (Write Run):
- Recursively calls the main resolver for local modules.
- Removes all import statements (since they were hoisted in Pass 1).
"""
def __init__(
self,
imploder: "Imploder",
current_file_path: str,
f_out: Optional[TextIO],
is_dry_run: bool,
indent_level: int = 0,
):
self.imploder = imploder
self.current_file_path = current_file_path
self.current_dir = os.path.dirname(current_file_path)
self.f_out = f_out
self.is_dry_run = is_dry_run
self.indent = " " * indent_level
self.logger = logging.getLogger(self.__class__.__name__)
def _log_debug(self, msg: str):
"""Helper for indented debug logging."""
self.logger.debug(f"{self.indent} > {msg}")
def _find_local_module(self, module_name: str, level: int) -> Optional[str]:
"""
Tries to find the absolute path for a given module name and relative level.
Returns None if it's not a local module *and* cannot be found in site-packages.
"""
import importlib.util # Added
import importlib.machinery # Added
if not module_name:
# Handle cases like `from . import foo`
# The module_name is 'foo', handled by visit_ImportFrom's aliases.
# If it's just `from . import`, module_name is None.
# We'll resolve from the alias names.
# Determine base path for relative import
base_path = self.current_dir
if level > 0:
# Go up 'level' directories (minus one for current dir)
for _ in range(level - 1):
base_path = os.path.dirname(base_path)
return base_path
base_path = self.current_dir
if level > 0:
# Go up 'level' directories (minus one for current dir)
for _ in range(level - 1):
base_path = os.path.dirname(base_path)
else:
# Absolute import, start from the root
base_path = self.imploder.root_dir
module_parts = module_name.split(".")
module_path = os.path.join(base_path, *module_parts)
# Check for package `.../module/__init__.py`
package_init = os.path.join(module_path, "__init__.py")
if os.path.isfile(package_init):
self._log_debug(
f"Resolved '{module_name}' to local package: {os.path.relpath(package_init, self.imploder.root_dir)}"
)
return package_init
# Check for module `.../module.py`
module_py = module_path + ".py"
if os.path.isfile(module_py):
self._log_debug(
f"Resolved '{module_name}' to local module: {os.path.relpath(module_py, self.imploder.root_dir)}"
)
return module_py
# --- FALLBACK: Deep search from root ---
# This is less efficient but helps with non-standard project structures
# where the "root" for imports isn't the same as the Imploder's root.
if level == 0:
self._log_debug(
f"Module '{module_name}' not found at primary path. Starting deep fallback search from {self.imploder.root_dir}..."
)
target_path_py = os.path.join(*module_parts) + ".py"
target_path_init = os.path.join(*module_parts, "__init__.py")
for dirpath, dirnames, filenames in os.walk(self.imploder.root_dir, topdown=True):
# Prune common virtual envs and metadata dirs to speed up search
dirnames[:] = [
d
for d in dirnames
if not d.startswith(".")
and d not in ("venv", "env", ".venv", ".env", "__pycache__", "node_modules")
]
# Check for .../module/file.py
check_file_py = os.path.join(dirpath, target_path_py)
if os.path.isfile(check_file_py):
self._log_debug(
f"Fallback search found module: {os.path.relpath(check_file_py, self.imploder.root_dir)}"
)
return check_file_py
# Check for .../module/file/__init__.py
check_file_init = os.path.join(dirpath, target_path_init)
if os.path.isfile(check_file_init):
self._log_debug(
f"Fallback search found package: {os.path.relpath(check_file_init, self.imploder.root_dir)}"
)
return check_file_init
# --- Removed site-packages/external module inlining logic ---
# As requested, the script will now only resolve local modules.
# Not found locally
return None
def visit_Import(self, node: ast.Import) -> Optional[ast.AST]:
"""Handles `import foo` or `import foo.bar`."""
for alias in node.names:
module_path = self._find_local_module(alias.name, level=0)
if module_path:
# Local module
self._log_debug(f"Resolving local import: `import {alias.name}`")
self.imploder.resolve_file(
file_abs_path=module_path,
f_out=self.f_out,
is_dry_run=self.is_dry_run,
indent_level=self.imploder.current_indent_level,
)
else:
# External module
self._log_debug(f"Found external import: `import {alias.name}`")
if self.is_dry_run:
# On dry run, collect it for hoisting
key = f"import {alias.name}"
if key not in self.imploder.external_imports:
self.imploder.external_imports[key] = node
# On Pass 2 (write run), we remove this node by replacing it
# with a dummy function call to prevent IndentationErrors.
# On Pass 1 (dry run), we also replace it.
module_names = ", ".join([a.name for a in node.names])
new_call = ast.Call(
func=ast.Name(id="_implode_log_import", ctx=ast.Load()),
args=[
ast.Constant(value=module_names),
ast.Constant(value=0),
ast.Constant(value="import"),
],
keywords=[],
)
return ast.Expr(value=new_call)
def visit_ImportFrom(self, node: ast.ImportFrom) -> Optional[ast.AST]:
"""Handles `from foo import bar` or `from .foo import bar`."""
# --- NEW: Create the replacement node first ---
module_name_str = node.module or ""
import_type = "from-import"
if module_name_str == "__future__":
import_type = "future-import"
new_call = ast.Call(
func=ast.Name(id="_implode_log_import", ctx=ast.Load()),
args=[
ast.Constant(value=module_name_str),
ast.Constant(value=node.level),
ast.Constant(value=import_type),
],
keywords=[],
)
replacement_node = ast.Expr(value=new_call)
# Handle `from __future__ import ...`
if node.module == "__future__":
self._log_debug("Found __future__ import. Hoisting to top.")
if self.is_dry_run:
key = ast.unparse(node)
self.imploder.future_imports[key] = node
return replacement_node # Return replacement
module_path = self._find_local_module(node.module or "", node.level)
if module_path and os.path.isdir(module_path):
# This is a package import like `from . import foo`
# `module_path` is the directory of the package (e.g., self.current_dir)
self._log_debug(f"Resolving package import: `from {node.module or '.'} import ...`")
for alias in node.names:
# --- FIX ---
# Resolve `foo` relative to the package directory (`module_path`)
# Check for module `.../package_dir/foo.py`
package_module_py = os.path.join(module_path, alias.name + ".py")
# Check for sub-package `.../package_dir/foo/__init__.py`
package_module_init = os.path.join(module_path, alias.name, "__init__.py")
if os.path.isfile(package_module_py):
self._log_debug(
f"Found sub-module: {os.path.relpath(package_module_py, self.imploder.root_dir)}"
)
self.imploder.resolve_file(
file_abs_path=package_module_py,
f_out=self.f_out,
is_dry_run=self.is_dry_run,
indent_level=self.imploder.current_indent_level,
)
elif os.path.isfile(package_module_init):
self._log_debug(
f"Found sub-package: {os.path.relpath(package_module_init, self.imploder.root_dir)}"
)
self.imploder.resolve_file(
file_abs_path=package_module_init,
f_out=self.f_out,
is_dry_run=self.is_dry_run,
indent_level=self.imploder.current_indent_level,
)
else:
# Could be another sub-package, etc.
# This logic can be expanded, but for now, log a warning
self.logger.warning(
f"{self.indent} > Could not resolve sub-module '{alias.name}' in package '{module_path}'"
)
return replacement_node # Return replacement
# --- END FIX ---
if module_path:
# Local module (e.g., `from .foo import bar` or `from myapp.foo import bar`)
self._log_debug(f"Resolving local from-import: `from {node.module or '.'} ...`")
self.imploder.resolve_file(
file_abs_path=module_path,
f_out=self.f_out,
is_dry_run=self.is_dry_run,
indent_level=self.imploder.current_indent_level,
)
else:
# External module
self._log_debug(f"Found external from-import: `from {node.module or '.'} ...`")
if self.is_dry_run:
# On dry run, collect it for hoisting
key = ast.unparse(node)
if key not in self.imploder.external_imports:
self.imploder.external_imports[key] = node
# On Pass 2 (write run), we remove this node entirely by replacing it
# On Pass 1 (dry run), we also replace it
return replacement_node
class Imploder:
"""
Core class for handling the implosion process.
Manages state, file processing, and the two-pass analysis.
"""
def __init__(self, root_dir: str, enable_import_logging: bool = False):
self.root_dir = os.path.realpath(root_dir)
self.processed_files: Set[str] = set()
self.external_imports: Dict[str, ast.AST] = {}
self.future_imports: Dict[str, ast.AST] = {}
self.current_indent_level = 0
self.enable_import_logging = enable_import_logging # Added
logger.info(f"Initialized Imploder with root: {self.root_dir}")
def implode(self, main_file_abs_path: str, output_file_path: str):
"""
Runs the full two-pass implosion process.
"""
if not os.path.isfile(main_file_abs_path):
logger.critical(f"Main file not found: {main_file_abs_path}")
sys.exit(1)
# --- Pass 1: Analyze Dependencies (Dry Run) ---
logger.info(
f"--- PASS 1: Analyzing dependencies from {os.path.relpath(main_file_abs_path, self.root_dir)} ---"
)
self.processed_files.clear()
self.external_imports.clear()
self.future_imports.clear()
try:
self.resolve_file(main_file_abs_path, f_out=None, is_dry_run=True, indent_level=0)
except Exception as e:
logger.critical(f"Error during analysis pass: {e}", exc_info=True)
sys.exit(1)
logger.info(
f"--- Analysis complete. Found {len(self.future_imports)} __future__ imports and {len(self.external_imports)} external modules. ---"
)
# --- Pass 2: Write Imploded File ---
logger.info(f"--- PASS 2: Writing imploded file to {output_file_path} ---")
self.processed_files.clear()
try:
with open(output_file_path, "w", encoding="utf-8") as f_out:
# Write headers
f_out.write(f"#!/usr/bin/env python3\n")
f_out.write(f"# -*- coding: utf-8 -*-\n")
f_out.write(f"import logging\n")
f_out.write(f"\n# --- IMPLODED FILE: Generated by impLODE --- #\n")
f_out.write(
f"# --- Original main file: {os.path.relpath(main_file_abs_path, self.root_dir)} --- #\n"
)
# Write __future__ imports
if self.future_imports:
f_out.write("\n# --- Hoisted __future__ Imports --- #\n")
for node in self.future_imports.values():
f_out.write(f"{ast.unparse(node)}\n")
f_out.write("# --- End __future__ Imports --- #\n")
# --- NEW: Write Helper Function ---
enable_logging_str = "True" if self.enable_import_logging else "False"
f_out.write("\n# --- impLODE Helper Function --- #\n")
f_out.write(f"_IMPLODE_LOGGING_ENABLED_ = {enable_logging_str}\n")
f_out.write("def _implode_log_import(module_name, level, import_type):\n")
f_out.write(
' """Dummy function to replace imports and prevent IndentationErrors."""\n'
)
f_out.write(" if _IMPLODE_LOGGING_ENABLED_:\n")
f_out.write(
" print(f\"[impLODE Logger]: Skipped {import_type}: module='{module_name}', level={level}\")\n"
)
f_out.write(" pass\n")
f_out.write("# --- End Helper Function --- #\n")
# Write external imports
if self.external_imports:
f_out.write("\n# --- Hoisted External Imports --- #\n")
for node in self.external_imports.values():
# As requested, wrap each external import in a try/except block
f_out.write("try:\n")
f_out.write(f" {ast.unparse(node)}\n")
f_out.write("except ImportError:\n")
f_out.write(" pass\n")
f_out.write("# --- End External Imports --- #\n")
# Write the actual code
self.resolve_file(main_file_abs_path, f_out=f_out, is_dry_run=False, indent_level=0)
except IOError as e:
logger.critical(
f"Could not write to output file {output_file_path}: {e}", exc_info=True
)
sys.exit(1)
except Exception as e:
logger.critical(f"Error during write pass: {e}", exc_info=True)
sys.exit(1)
logger.info(f"--- Implosion complete! Output saved to {output_file_path} ---")
def resolve_file(
self, file_abs_path: str, f_out: Optional[TextIO], is_dry_run: bool, indent_level: int = 0
):
"""
Recursively resolves a single file.
- `is_dry_run=True`: Analyzes imports, populating `external_imports`.
- `is_dry_run=False`: Writes transformed code to `f_out`.
"""
self.current_indent_level = indent_level
indent = " " * indent_level
try:
file_real_path = os.path.realpath(file_abs_path)
# --- Reverted path logic as we only handle local files ---
rel_path = os.path.relpath(file_real_path, self.root_dir)
# --- End Revert ---
except ValueError:
# Happens if file is on a different drive than root_dir on Windows
logger.warning(
f"{indent}Cannot calculate relative path for {file_abs_path}. Using absolute."
)
rel_path = file_abs_path
# --- 1. Circular Dependency Check ---
if file_real_path in self.processed_files:
logger.debug(f"{indent}Skipping already processed file: {rel_path}")
return
logger.info(f"{indent}Processing: {rel_path}")
self.processed_files.add(file_real_path)
# --- 2. Read File ---
try:
with open(file_real_path, "r", encoding="utf-8") as f:
code = f.read()
except FileNotFoundError:
logger.error(f"{indent}File not found: {file_real_path}")
return
except UnicodeDecodeError:
logger.error(f"{indent}Could not decode file (not utf-8): {file_real_path}")
return
except Exception as e:
logger.error(f"{indent}Could not read file {file_real_path}: {e}")
return
# --- 2.5. Validate Syntax with py_compile ---
try:
# As requested, validate syntax using py_compile before AST parsing
py_compile.compile(file_real_path, doraise=True, quiet=1)
logger.debug(f"{indent}Syntax OK (py_compile): {rel_path}")
except py_compile.PyCompileError as e:
# This error includes file, line, and message
logger.error(
f"{indent}Syntax error (py_compile) in {e.file} on line {e.lineno}: {e.msg}"
)
return
except Exception as e:
# Catch other potential errors like permission issues
logger.error(f"{indent}Error during py_compile for {rel_path}: {e}")
return
# --- 3. Parse AST ---
try:
tree = ast.parse(code, filename=file_real_path)
except SyntaxError as e:
logger.error(f"{indent}Syntax error in {rel_path} on line {e.lineno}: {e.msg}")
return
except Exception as e:
logger.error(f"{indent}Could not parse AST for {rel_path}: {e}")
return
# --- 4. Transform AST ---
transformer = ImportTransformer(
imploder=self,
current_file_path=file_real_path,
f_out=f_out,
is_dry_run=is_dry_run,
indent_level=indent_level,
)
try:
new_tree = transformer.visit(tree)
except Exception as e:
logger.error(f"{indent}Error transforming AST for {rel_path}: {e}", exc_info=True)
return
# --- 5. Write Content (on Pass 2) ---
if not is_dry_run and f_out:
try:
ast.fix_missing_locations(new_tree)
f_out.write(f"\n\n# --- Content from {rel_path} --- #\n")
f_out.write(ast.unparse(new_tree))
f_out.write(f"\n# --- End of {rel_path} --- #\n")
logger.info(f"{indent}Successfully wrote content from: {rel_path}")
except Exception as e:
logger.error(
f"{indent}Could not unparse or write AST for {rel_path}: {e}", exc_info=True
)
self.current_indent_level = indent_level
def setup_logging(level: int):
"""Configures the root logger."""
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
# Configure the 'impLODE' logger
logger.setLevel(level)
logger.addHandler(handler)
# Configure the 'ImportTransformer' logger
transformer_logger = logging.getLogger("ImportTransformer")
transformer_logger.setLevel(level)
transformer_logger.addHandler(handler)
def main():
"""Main entry point for the script."""
parser = argparse.ArgumentParser(
description="impLODE: Consolidate a multi-file Python project into one file.",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"main_file", type=str, help="The main entry point .py file of your project."
)
parser.add_argument(
"-o",
"--output",
type=str,
default="imploded.py",
help="Path for the combined output file. (default: imploded.py)",
)
parser.add_argument(
"-r",
"--root",
type=str,
default=".",
help="The root directory of the project for resolving absolute imports. (default: current directory)",
)
log_group = parser.add_mutually_exclusive_group()
log_group.add_argument(
"-v",
"--verbose",
action="store_const",
dest="log_level",
const=logging.DEBUG,
default=logging.INFO,
help="Enable verbose DEBUG logging.",
)
log_group.add_argument(
"-q",
"--quiet",
action="store_const",
dest="log_level",
const=logging.WARNING,
help="Suppress INFO logs, showing only WARNINGS and ERRORS.",
)
parser.add_argument(
"--enable-import-logging",
action="store_true",
help="Enable runtime logging for removed import statements in the final imploded script.",
)
args = parser.parse_args()
# --- Setup ---
setup_logging(args.log_level)
root_dir = os.path.abspath(args.root)
main_file_path = os.path.abspath(args.main_file)
output_file_path = os.path.abspath(args.output)
# --- Validation ---
if not os.path.isdir(root_dir):
logger.critical(f"Root directory not found: {root_dir}")
sys.exit(1)
if not os.path.isfile(main_file_path):
logger.critical(f"Main file not found: {main_file_path}")
sys.exit(1)
if not main_file_path.startswith(root_dir):
logger.warning(f"Main file {main_file_path} is outside the specified root {root_dir}.")
logger.warning("This may cause issues with absolute import resolution.")
if main_file_path == output_file_path:
logger.critical("Output file cannot be the same as the main file.")
sys.exit(1)
# --- Run ---
imploder = Imploder(
root_dir=root_dir, enable_import_logging=args.enable_import_logging # Added
)
imploder.implode(main_file_abs_path=main_file_path, output_file_path=output_file_path)
if __name__ == "__main__":
main()
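An example invocation using only the flags defined above (paths are illustrative):

#   python pr/implode.py app/main.py -o imploded.py -r . -v --enable-import-logging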


@ -56,10 +56,14 @@ class FactExtractor:
else:
if len(current_phrase) >= 2:
phrases.append(" ".join(current_phrase))
elif len(current_phrase) == 1:
phrases.append(current_phrase[0]) # Single capitalized words
current_phrase = []
if len(current_phrase) >= 2:
phrases.append(" ".join(current_phrase))
elif len(current_phrase) == 1:
phrases.append(current_phrase[0]) # Single capitalized words
return list(set(phrases))
@ -165,13 +169,14 @@ class FactExtractor:
return relationships
def extract_metadata(self, text: str) -> Dict[str, Any]:
word_count = len(text.split())
sentence_count = len(re.split(r"[.!?]", text))
word_count = len(text.split()) if text.strip() else 0
sentences = re.split(r"[.!?]", text.strip())
sentence_count = len([s for s in sentences if s.strip()]) if text.strip() else 0
urls = re.findall(r"https?://[^\s]+", text)
email_addresses = re.findall(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text)
dates = re.findall(
r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text
r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b|\b\d{4}\b", text
)
numbers = re.findall(r"\b\d+(?:,\d{3})*(?:\.\d+)?\b", text)


@ -1,8 +1,9 @@
import json
import sqlite3
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from .semantic_index import SemanticIndex
@ -35,11 +36,13 @@ class KnowledgeStore:
def __init__(self, db_path: str):
self.db_path = db_path
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
self.lock = threading.Lock()
self.semantic_index = SemanticIndex()
self._initialize_store()
self._load_index()
def _initialize_store(self):
with self.lock:
cursor = self.conn.cursor()
cursor.execute(
@ -76,6 +79,7 @@ class KnowledgeStore:
self.conn.commit()
def _load_index(self):
with self.lock:
cursor = self.conn.cursor()
cursor.execute("SELECT entry_id, content FROM knowledge_entries")
@ -83,6 +87,7 @@ class KnowledgeStore:
self.semantic_index.add_document(row[0], row[1])
def add_entry(self, entry: KnowledgeEntry):
with self.lock:
cursor = self.conn.cursor()
cursor.execute(
@ -108,6 +113,7 @@ class KnowledgeStore:
self.semantic_index.add_document(entry.entry_id, entry.content)
def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
with self.lock:
cursor = self.conn.cursor()
cursor.execute(
@ -148,12 +154,34 @@ class KnowledgeStore:
def search_entries(
self, query: str, category: Optional[str] = None, top_k: int = 5
) -> List[KnowledgeEntry]:
search_results = self.semantic_index.search(query, top_k * 2)
# Combine semantic search with exact matching
semantic_results = self.semantic_index.search(query, top_k * 2)
# Add FTS (Full Text Search) with exact word/phrase matching
fts_results = self._fts_search(query, top_k * 2)
# Combine and deduplicate results with weighted scoring
combined_results = {}
# Add semantic results with weight 0.7
for entry_id, score in semantic_results:
combined_results[entry_id] = score * 0.7
# Add FTS results with weight 1.0 (higher priority for exact matches)
for entry_id, score in fts_results:
if entry_id in combined_results:
combined_results[entry_id] = max(combined_results[entry_id], score * 1.0)
else:
combined_results[entry_id] = score * 1.0
# Sort by combined score
sorted_results = sorted(combined_results.items(), key=lambda x: x[1], reverse=True)
with self.lock:
cursor = self.conn.cursor()
entries = []
for entry_id, score in search_results:
for entry_id, score in sorted_results[:top_k]:
if category:
cursor.execute(
"""
@ -185,14 +213,71 @@ class KnowledgeStore:
access_count=row[6],
importance_score=row[7],
)
# Add search score to metadata for context
entry.metadata["search_score"] = score
entries.append(entry)
if len(entries) >= top_k:
break
return entries
def _fts_search(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
"""Full Text Search with exact word and partial sentence matching."""
with self.lock:
cursor = self.conn.cursor()
# Prepare query for FTS
query_lower = query.lower()
query_words = query_lower.split()
# Search for exact phrase matches first
cursor.execute(
"""
SELECT entry_id, content
FROM knowledge_entries
WHERE LOWER(content) LIKE ?
""",
(f"%{query_lower}%",),
)
exact_matches = []
partial_matches = []
for row in cursor.fetchall():
entry_id, content = row
content_lower = content.lower()
# Exact phrase match gets highest score
if query_lower in content_lower:
exact_matches.append((entry_id, 1.0))
continue
# Count matching words
content_words = set(content_lower.split())
query_word_set = set(query_words)
matching_words = len(query_word_set & content_words)
if matching_words > 0:
# Score based on word overlap and position
word_overlap_score = matching_words / len(query_word_set)
# Bonus for consecutive word sequences
consecutive_bonus = 0.0
for i in range(len(query_words)):
for j in range(i + 1, min(i + 4, len(query_words) + 1)):
phrase = " ".join(query_words[i:j])
if phrase in content_lower:
consecutive_bonus += 0.2 * (j - i)
total_score = min(0.99, word_overlap_score + consecutive_bonus)
partial_matches.append((entry_id, total_score))
# Combine results, prioritizing exact matches
all_results = exact_matches + partial_matches
all_results.sort(key=lambda x: x[1], reverse=True)
return all_results[:top_k]
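A minimal sketch of the weighted merge performed in search_entries above (entry ids are hypothetical):

semantic = [("e1", 0.80), ("e2", 0.60)]  # from the semantic index
fts = [("e2", 1.00)]                     # exact phrase hit
combined = {}
for entry_id, score in semantic:
    combined[entry_id] = score * 0.7
for entry_id, score in fts:
    combined[entry_id] = max(combined.get(entry_id, 0.0), score * 1.0)
# e2 scores 1.0 and outranks e1 at ~0.56: exact matches win over
# purely semantic ones.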
def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
with self.lock:
cursor = self.conn.cursor()
cursor.execute(
@ -224,6 +309,7 @@ class KnowledgeStore:
return entries
def update_importance(self, entry_id: str, importance_score: float):
with self.lock:
cursor = self.conn.cursor()
cursor.execute(
@ -238,6 +324,7 @@ class KnowledgeStore:
self.conn.commit()
def delete_entry(self, entry_id: str) -> bool:
with self.lock:
cursor = self.conn.cursor()
cursor.execute("DELETE FROM knowledge_entries WHERE entry_id = ?", (entry_id,))
@ -251,6 +338,7 @@ class KnowledgeStore:
return deleted
def get_statistics(self) -> Dict[str, Any]:
with self.lock:
cursor = self.conn.cursor()
cursor.execute("SELECT COUNT(*) FROM knowledge_entries")


@ -9,7 +9,7 @@ class SemanticIndex:
self.documents: Dict[str, str] = {}
self.vocabulary: Set[str] = set()
self.idf_scores: Dict[str, float] = {}
self.doc_vectors: Dict[str, Dict[str, float]] = {}
self.doc_tf_scores: Dict[str, Dict[str, float]] = {}
def _tokenize(self, text: str) -> List[str]:
text = text.lower()
@ -46,22 +46,26 @@ class SemanticIndex:
tokens = self._tokenize(text)
self.vocabulary.update(tokens)
self._compute_idf()
tf_scores = self._compute_tf(tokens)
self.doc_vectors[doc_id] = {
token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0) for token in tokens
}
self.doc_tf_scores[doc_id] = tf_scores
self._compute_idf()
def remove_document(self, doc_id: str):
if doc_id in self.documents:
del self.documents[doc_id]
if doc_id in self.doc_vectors:
del self.doc_vectors[doc_id]
if doc_id in self.doc_tf_scores:
del self.doc_tf_scores[doc_id]
self._compute_idf()
def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
if not query.strip():
return []
query_tokens = self._tokenize(query)
if not query_tokens:
return []
query_tf = self._compute_tf(query_tokens)
query_vector = {
@ -69,7 +73,10 @@ class SemanticIndex:
}
scores = []
for doc_id, doc_vector in self.doc_vectors.items():
for doc_id, doc_tf in self.doc_tf_scores.items():
doc_vector = {
token: doc_tf.get(token, 0) * self.idf_scores.get(token, 0) for token in doc_tf
}
similarity = self._cosine_similarity(query_vector, doc_vector)
scores.append((doc_id, similarity))

384 pr/multiplexer.py.bak Normal file

@ -0,0 +1,384 @@
import queue
import subprocess
import sys
import threading
import time
from pr.tools.process_handlers import detect_process_type, get_handler_for_process
from pr.tools.prompt_detection import get_global_detector
from pr.ui import Colors
class TerminalMultiplexer:
def __init__(self, name, show_output=True):
self.name = name
self.show_output = show_output
self.stdout_buffer = []
self.stderr_buffer = []
self.stdout_queue = queue.Queue()
self.stderr_queue = queue.Queue()
self.active = True
self.lock = threading.Lock()
self.metadata = {
"start_time": time.time(),
"last_activity": time.time(),
"interaction_count": 0,
"process_type": "unknown",
"state": "active",
}
self.handler = None
self.prompt_detector = get_global_detector()
if self.show_output:
self.display_thread = threading.Thread(target=self._display_worker, daemon=True)
self.display_thread.start()
def _display_worker(self):
while self.active:
try:
line = self.stdout_queue.get(timeout=0.1)
if line:
if self.metadata.get("process_type") in ["vim", "ssh"]:
sys.stdout.write(line)
else:
sys.stdout.write(f"{Colors.GRAY}[{self.name}]{Colors.RESET} {line}\n")
sys.stdout.flush()
except queue.Empty:
pass
try:
line = self.stderr_queue.get(timeout=0.1)
if line:
if self.metadata.get("process_type") in ["vim", "ssh"]:
sys.stderr.write(line)
else:
sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}\n")
sys.stderr.flush()
except queue.Empty:
pass
def write_stdout(self, data):
with self.lock:
self.stdout_buffer.append(data)
self.metadata["last_activity"] = time.time()
# Update handler state if available
if self.handler:
self.handler.update_state(data)
# Update prompt detector
self.prompt_detector.update_session_state(
self.name, data, self.metadata["process_type"]
)
if self.show_output:
self.stdout_queue.put(data)
def write_stderr(self, data):
with self.lock:
self.stderr_buffer.append(data)
self.metadata["last_activity"] = time.time()
# Update handler state if available
if self.handler:
self.handler.update_state(data)
# Update prompt detector
self.prompt_detector.update_session_state(
self.name, data, self.metadata["process_type"]
)
if self.show_output:
self.stderr_queue.put(data)
def get_stdout(self):
with self.lock:
return "".join(self.stdout_buffer)
def get_stderr(self):
with self.lock:
return "".join(self.stderr_buffer)
def get_all_output(self):
with self.lock:
return {
"stdout": "".join(self.stdout_buffer),
"stderr": "".join(self.stderr_buffer),
}
def get_metadata(self):
with self.lock:
return self.metadata.copy()
def update_metadata(self, key, value):
with self.lock:
self.metadata[key] = value
def set_process_type(self, process_type):
"""Set the process type and initialize appropriate handler."""
with self.lock:
self.metadata["process_type"] = process_type
self.handler = get_handler_for_process(process_type, self)
def send_input(self, input_data):
if hasattr(self, "process") and self.process.poll() is None:
try:
self.process.stdin.write(input_data + "\n")
self.process.stdin.flush()
with self.lock:
self.metadata["last_activity"] = time.time()
self.metadata["interaction_count"] += 1
except Exception as e:
self.write_stderr(f"Error sending input: {e}")
else:
# This will be implemented when we have a process attached
# For now, just update activity
with self.lock:
self.metadata["last_activity"] = time.time()
self.metadata["interaction_count"] += 1
def close(self):
self.active = False
if hasattr(self, "display_thread"):
self.display_thread.join(timeout=1)
_multiplexers = {}
_mux_counter = 0
_mux_lock = threading.Lock()
_background_monitor = None
_monitor_active = False
_monitor_interval = 0.2 # 200ms
def create_multiplexer(name=None, show_output=True):
global _mux_counter
with _mux_lock:
if name is None:
_mux_counter += 1
name = f"process-{_mux_counter}"
mux = TerminalMultiplexer(name, show_output)
_multiplexers[name] = mux
return name, mux
def get_multiplexer(name):
return _multiplexers.get(name)
def close_multiplexer(name):
mux = _multiplexers.get(name)
if mux:
mux.close()
del _multiplexers[name]
def get_all_multiplexer_states():
with _mux_lock:
states = {}
for name, mux in _multiplexers.items():
states[name] = {
"metadata": mux.get_metadata(),
"output_summary": {
"stdout_lines": len(mux.stdout_buffer),
"stderr_lines": len(mux.stderr_buffer),
},
}
return states
def cleanup_all_multiplexers():
for mux in list(_multiplexers.values()):
mux.close()
_multiplexers.clear()
# Background process management
_background_processes = {}
_process_lock = threading.Lock()
class BackgroundProcess:
def __init__(self, name, command):
self.name = name
self.command = command
self.process = None
self.multiplexer = None
self.status = "starting"
self.start_time = time.time()
self.end_time = None
def start(self):
"""Start the background process."""
try:
# Create multiplexer for this process
mux_name, mux = create_multiplexer(self.name, show_output=False)
self.multiplexer = mux
# Detect process type
process_type = detect_process_type(self.command)
mux.set_process_type(process_type)
# Start the subprocess
self.process = subprocess.Popen(
self.command,
shell=True,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1,
universal_newlines=True,
)
self.status = "running"
# Start output monitoring threads
threading.Thread(target=self._monitor_stdout, daemon=True).start()
threading.Thread(target=self._monitor_stderr, daemon=True).start()
return {"status": "success", "pid": self.process.pid}
except Exception as e:
self.status = "error"
return {"status": "error", "error": str(e)}
def _monitor_stdout(self):
"""Monitor stdout from the process."""
try:
for line in iter(self.process.stdout.readline, ""):
if line:
self.multiplexer.write_stdout(line.rstrip("\n\r"))
except Exception as e:
self.multiplexer.write_stderr(f"Error reading stdout: {e}")
finally:
self._check_completion()
def _monitor_stderr(self):
"""Monitor stderr from the process."""
try:
for line in iter(self.process.stderr.readline, ""):
if line:
self.multiplexer.write_stderr(line.rstrip("\n\r"))
except Exception as e:
self.multiplexer.write_stderr(f"Error reading stderr: {e}")
def _check_completion(self):
"""Check if process has completed."""
if self.process and self.process.poll() is not None:
self.status = "completed"
self.end_time = time.time()
def get_info(self):
"""Get process information."""
self._check_completion()
return {
"name": self.name,
"command": self.command,
"status": self.status,
"pid": self.process.pid if self.process else None,
"start_time": self.start_time,
"end_time": self.end_time,
"runtime": (
time.time() - self.start_time
if not self.end_time
else self.end_time - self.start_time
),
}
def get_output(self, lines=None):
"""Get process output."""
if not self.multiplexer:
return []
all_output = self.multiplexer.get_all_output()
stdout_lines = all_output["stdout"].split("\n") if all_output["stdout"] else []
stderr_lines = all_output["stderr"].split("\n") if all_output["stderr"] else []
combined = stdout_lines + stderr_lines
if lines:
combined = combined[-lines:]
return [line for line in combined if line.strip()]
def send_input(self, input_text):
"""Send input to the process."""
if self.process and self.status == "running":
try:
self.process.stdin.write(input_text + "\n")
self.process.stdin.flush()
return {"status": "success"}
except Exception as e:
return {"status": "error", "error": str(e)}
return {"status": "error", "error": "Process not running or no stdin"}
def kill(self):
"""Kill the process."""
if self.process and self.status == "running":
try:
self.process.terminate()
# Wait a bit for graceful termination
time.sleep(0.1)
if self.process.poll() is None:
self.process.kill()
self.status = "killed"
self.end_time = time.time()
return {"status": "success"}
except Exception as e:
return {"status": "error", "error": str(e)}
return {"status": "error", "error": "Process not running"}
def start_background_process(name, command):
"""Start a background process."""
with _process_lock:
if name in _background_processes:
return {"status": "error", "error": f"Process {name} already exists"}
process = BackgroundProcess(name, command)
result = process.start()
if result["status"] == "success":
_background_processes[name] = process
return result
def get_all_sessions():
"""Get all background process sessions."""
with _process_lock:
sessions = {}
for name, process in _background_processes.items():
sessions[name] = process.get_info()
return sessions
def get_session_info(name):
"""Get information about a specific session."""
with _process_lock:
process = _background_processes.get(name)
return process.get_info() if process else None
def get_session_output(name, lines=None):
"""Get output from a specific session."""
with _process_lock:
process = _background_processes.get(name)
return process.get_output(lines) if process else None
def send_input_to_session(name, input_text):
"""Send input to a background session."""
with _process_lock:
process = _background_processes.get(name)
return (
process.send_input(input_text)
if process
else {"status": "error", "error": "Session not found"}
)
def kill_session(name):
"""Kill a background session."""
with _process_lock:
process = _background_processes.get(name)
if process:
result = process.kill()
if result["status"] == "success":
del _background_processes[name]
return result
return {"status": "error", "error": "Session not found"}

585 pr/server.py Normal file

@ -0,0 +1,585 @@
#!/usr/bin/env python3
import asyncio
import logging
import os
import sys
import subprocess
import urllib.parse
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import pathlib
import ast
from types import SimpleNamespace
import time
from aiohttp import web
import aiohttp
# --- Optional external dependency used in your original code ---
from pr.ads import AsyncDataSet # noqa: E402
from http import cookies
import io
import json
import urllib.request
import pickle
import zlib
from functools import wraps
class Cache:
def __init__(self, base_dir: Path | str = "."):
self.base_dir = Path(base_dir).resolve()
self.base_dir.mkdir(exist_ok=True, parents=True)
def is_cacheable(self, obj):
return isinstance(obj, (int, str, bool, float))
def generate_key(self, *args, **kwargs):
return zlib.crc32(
json.dumps(
{
"args": [arg for arg in args if self.is_cacheable(arg)],
"kwargs": {k: v for k, v in kwargs.items() if self.is_cacheable(v)},
},
sort_keys=True,
default=str,
).encode()
)
def set(self, key, value):
key = self.generate_key(key)
data = {"value": value, "timestamp": time.time()}
serialized_data = pickle.dumps(data)
with open(self.base_dir.joinpath(f"{key}.cache"), "wb") as f:
f.write(serialized_data)
def get(self, key, default=None):
key = self.generate_key(key)
try:
with open(self.base_dir.joinpath(f"{key}.cache"), "rb") as f:
data = pickle.loads(f.read())
return data
except FileNotFoundError:
return default
def cached(self, expiration=60):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
# if not all(self.is_cacheable(arg) for arg in args) or not all(self.is_cacheable(v) for v in kwargs.values()):
# return await func(*args, **kwargs)
key = self.generate_key(*args, **kwargs)
print("Cache hit:", key)
cached_data = self.get(key)
if cached_data is not None:
if expiration is None or (time.time() - cached_data["timestamp"]) < expiration:
return cached_data["value"]
result = await func(*args, **kwargs)
self.set(key, result)
return result
return wrapper
return decorator
cache = Cache()
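A usage sketch for the decorator above (the function and its body are hypothetical):

@cache.cached(expiration=300)
async def fetch_stats(user_id: str) -> dict:
    # expensive work; the pickled result is reused for 5 minutes
    return {"user": user_id}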
class CGI:
def __init__(self):
self.environ = os.environ
if not self.gateway_interface == "CGI/1.1":
return
self.status = 200
self.headers = {}
self.cache = Cache(pathlib.Path(__file__).parent.joinpath("cache/cgi"))
self.headers["Content-Length"] = "0"
self.headers["Cache-Control"] = "no-cache"
self.headers["Access-Control-Allow-Origin"] = "*"
self.headers["Content-Type"] = "application/json; charset=utf-8"
self.cookies = cookies.SimpleCookie()
self.query = urllib.parse.parse_qs(os.environ["QUERY_STRING"])
self.file = io.BytesIO()
def validate(self, val, fn, error):
if fn(val):
return True
self.status = 421
self.write({"type": "error", "message": error})
exit()
def __getitem__(self, key):
return self.get(key)
def get(self, key, default=None):
result = self.query.get(key, [default])
if len(result) == 1:
return result[0]
return result
def get_bool(self, key):
return str(self.get(key)).lower() in ["true", "yes", "1"]
@property
def cache_key(self):
return self.environ["CACHE_KEY"]
@property
def env(self):
env = os.environ.copy()
return env
@property
def gateway_interface(self):
return self.env.get("GATEWAY_INTERFACE", "")
@property
def request_method(self):
return self.env["REQUEST_METHOD"]
@property
def query_string(self):
return self.env["QUERY_STRING"]
@property
def script_name(self):
return self.env["SCRIPT_NAME"]
@property
def path_info(self):
return self.env["PATH_INFO"]
@property
def server_name(self):
return self.env["SERVER_NAME"]
@property
def server_port(self):
return self.env["SERVER_PORT"]
@property
def server_protocol(self):
return self.env["SERVER_PROTOCOL"]
@property
def remote_addr(self):
return self.env["REMOTE_ADDR"]
def validate_get(self, key):
self.validate(self.get(key), lambda x: bool(x), f"Missing {key}")
def print(self, data):
self.write(data)
def write(self, data):
if not isinstance(data, bytes) and not isinstance(data, str):
data = json.dumps(data, default=str, indent=4)
try:
data = data.encode()
except:
pass
self.file.write(data)
self.headers["Content-Length"] = str(len(self.file.getvalue()))
@property
def http_status(self):
return f"HTTP/1.1 {self.status}\r\n".encode("utf-8")
@property
def http_headers(self):
headers = io.BytesIO()
for header in self.headers:
headers.write(f"{header}: {self.headers[header]}\r\n".encode("utf-8"))
headers.write(b"\r\n")
return headers.getvalue()
@property
def http_body(self):
return self.file.getvalue()
@property
def http_response(self):
return self.http_status + self.http_headers + self.http_body
def flush(self, response=None):
if response:
try:
response = response.encode()
except:
pass
sys.stdout.buffer.write(response)
sys.stdout.buffer.flush()
return
sys.stdout.buffer.write(self.http_response)
sys.stdout.buffer.flush()
def __del__(self):
if self.http_body:
self.flush()
exit()
# -------------------------------
# Utilities
# -------------------------------
def get_function_source(name: str, directory: str = ".") -> List[Tuple[str, str]]:
matches: List[Tuple[str, str]] = []
for root, _, files in os.walk(directory):
for file in files:
if not file.endswith(".py"):
continue
path = os.path.join(root, file)
try:
with open(path, "r", encoding="utf-8") as fh:
source = fh.read()
tree = ast.parse(source, filename=path)
for node in ast.walk(tree):
if isinstance(node, ast.AsyncFunctionDef) and node.name == name:
func_src = ast.get_source_segment(source, node)
if func_src:
matches.append((path, func_src))
break
except (SyntaxError, UnicodeDecodeError):
continue
return matches
# -------------------------------
# Server
# -------------------------------
class Static(SimpleNamespace):
pass
class View(web.View):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.db = self.request.app["db"]
self.server = self.request.app["server"]
@property
def client(self):
return aiohttp.ClientSession()
class RetoorServer:
def __init__(self, base_dir: Path | str = ".", port: int = 8118) -> None:
self.base_dir = Path(base_dir).resolve()
self.port = port
self.static = Static()
self.db = AsyncDataSet(".default.db")
self._logger = logging.getLogger("retoor.server")
self._logger.setLevel(logging.INFO)
if not self._logger.handlers:
h = logging.StreamHandler(sys.stdout)
fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
h.setFormatter(fmt)
self._logger.addHandler(h)
self._func_cache: Dict[str, Tuple[str, str]] = {}
self._compiled_cache: Dict[str, object] = {}
self._cgi_cache: Dict[str, Tuple[Optional[Path], Optional[str], Optional[str]]] = {}
self._static_path_cache: Dict[str, Path] = {}
self._base_dir_str = str(self.base_dir)
self.app = web.Application()
app = self.app
app["db"] = self.db
for path in pathlib.Path(__file__).parent.joinpath("web").joinpath("views").iterdir():
if path.suffix == ".py":
exec(open(path).read(), globals(), locals())
self.app["server"] = self
# from rpanel import create_app as create_rpanel_app
# self.app.add_subapp("/api/{tail:.*}", create_rpanel_app())
self.app.router.add_route("*", "/{tail:.*}", self.handle_any)
# ---------------------------
# Simple endpoints
# ---------------------------
async def handle_root(self, request: web.Request) -> web.Response:
return web.Response(text="Static file and cgi server from retoor.")
async def handle_hello(self, request: web.Request) -> web.Response:
return web.Response(text="Welcome to the custom HTTP server!")
# ---------------------------
# Dynamic function dispatch
# ---------------------------
def _path_to_funcname(self, path: str) -> str:
# /foo/bar -> foo_bar ; strip leading/trailing slashes safely
return path.strip("/").replace("/", "_")
async def _maybe_dispatch_dynamic(self, request: web.Request) -> Optional[web.StreamResponse]:
relpath = request.path.strip("/")
relpath = str(pathlib.Path(__file__).joinpath("cgi").joinpath(relpath).resolve()).replace(
"..", ""
)
if not relpath:
return None
last_part = relpath.split("/")[-1]
if "." in last_part:
return None
funcname = self._path_to_funcname(request.path)
if funcname not in self._compiled_cache:
entry = self._func_cache.get(funcname)
if entry is None:
matches = get_function_source(funcname, directory=str(self.base_dir))
if not matches:
return None
entry = matches[0]
self._func_cache[funcname] = entry
filepath, source = entry
code_obj = compile(source, filepath, "exec")
self._compiled_cache[funcname] = (code_obj, filepath)
code_obj, filepath = self._compiled_cache[funcname]
ctx = {
"static": self.static,
"request": request,
"relpath": relpath,
"os": os,
"sys": sys,
"web": web,
"asyncio": asyncio,
"subprocess": subprocess,
"urllib": urllib.parse,
"app": request.app,
"db": self.db,
}
exec(code_obj, ctx, ctx)
coro = ctx.get(funcname)
if not callable(coro):
return web.Response(status=500, text=f"Dynamic handler '{funcname}' is not callable.")
return await coro(request)
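What the dispatcher above resolves, sketched (handler file and name are hypothetical): any .py under base_dir that defines a matching async function is compiled and executed on demand.

#   # somewhere under base_dir, e.g. handlers.py
#   async def foo_bar(request):
#       return web.Response(text="served dynamically")
#
# GET /foo/bar -> _path_to_funcname("/foo/bar") == "foo_bar" -> await foo_bar(request)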
# ---------------------------
# Static files
# ---------------------------
async def handle_static(self, request: web.Request) -> web.StreamResponse:
path = request.path
if path in self._static_path_cache:
cached_path = self._static_path_cache[path]
if cached_path is None:
return web.Response(status=404, text="Not found")
try:
return web.FileResponse(path=cached_path)
except Exception as e:
return web.Response(status=500, text=f"Failed to send file: {e!r}")
relpath = path.replace("..", "").lstrip("/").rstrip("/")
abspath = (self.base_dir / relpath).resolve()
if not str(abspath).startswith(self._base_dir_str):
return web.Response(status=403, text="Forbidden")
if not abspath.exists():
self._static_path_cache[path] = None
return web.Response(status=404, text="Not found")
if abspath.is_dir():
return web.Response(status=403, text="Directory listing forbidden")
self._static_path_cache[path] = abspath
try:
return web.FileResponse(path=abspath)
except Exception as e:
return web.Response(status=500, text=f"Failed to send file: {e!r}")
# ---------------------------
# CGI
# ---------------------------
def _find_cgi_script(
self, path_only: str
) -> Tuple[Optional[Path], Optional[str], Optional[str]]:
path_only = "pr/cgi/" + path_only
if path_only in self._cgi_cache:
return self._cgi_cache[path_only]
split_path = [p for p in path_only.split("/") if p]
for i in range(len(split_path), -1, -1):
candidate = "/" + "/".join(split_path[:i])
candidate_fs = (self.base_dir / candidate.lstrip("/")).resolve()
if (
str(candidate_fs).startswith(self._base_dir_str)
and candidate_fs.parent.is_dir()
and candidate_fs.parent.name == "cgi"
and candidate_fs.suffix == ".bin"
and candidate_fs.is_file()
and os.access(candidate_fs, os.X_OK)
):
script_name = candidate if candidate.startswith("/") else "/" + candidate
path_info = path_only[len(script_name) :]
result = (candidate_fs, script_name, path_info)
self._cgi_cache[path_only] = result
return result
result = (None, None, None)
self._cgi_cache[path_only] = result
return result
async def _run_cgi(
self, request: web.Request, script_path: Path, script_name: str, path_info: str
) -> web.Response:
start_time = time.time()
method = request.method
query = request.query_string
env = os.environ.copy()
env["CACHE_KEY"] = json.dumps(
{"method": method, "query": query, "script_name": script_name, "path_info": path_info},
default=str,
)
env["GATEWAY_INTERFACE"] = "CGI/1.1"
env["REQUEST_METHOD"] = method
env["QUERY_STRING"] = query
env["SCRIPT_NAME"] = script_name
env["PATH_INFO"] = path_info
host_parts = request.host.split(":", 1)
env["SERVER_NAME"] = host_parts[0]
env["SERVER_PORT"] = host_parts[1] if len(host_parts) > 1 else "80"
env["SERVER_PROTOCOL"] = "HTTP/1.1"
peername = request.transport.get_extra_info("peername")
env["REMOTE_ADDR"] = request.headers.get("HOST_X_FORWARDED_FOR", peername[0])
for hk, hv in request.headers.items():
hk_lower = hk.lower()
if hk_lower == "content-type":
env["CONTENT_TYPE"] = hv
elif hk_lower == "content-length":
env["CONTENT_LENGTH"] = hv
else:
env["HTTP_" + hk.upper().replace("-", "_")] = hv
post_data = None
if method in {"POST", "PUT", "PATCH"}:
post_data = await request.read()
env.setdefault("CONTENT_LENGTH", str(len(post_data)))
try:
proc = await asyncio.create_subprocess_exec(
str(script_path),
stdin=asyncio.subprocess.PIPE if post_data else None,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
stdout, stderr = await proc.communicate(input=post_data)
if proc.returncode != 0:
msg = stderr.decode(errors="ignore")
return web.Response(status=500, text=f"CGI script error:\n{msg}")
header_end = stdout.find(b"\r\n\r\n")
if header_end != -1:
offset = 4
else:
header_end = stdout.find(b"\n\n")
offset = 2
if header_end != -1:
body = stdout[header_end + offset :]
headers_blob = stdout[:header_end].decode(errors="ignore")
status_code = 200
headers: Dict[str, str] = {}
for line in headers_blob.splitlines():
line = line.strip()
if not line or ":" not in line:
continue
k, v = line.split(":", 1)
k_stripped = k.strip()
if k_stripped.lower() == "status":
try:
status_code = int(v.strip().split()[0])
except Exception:
status_code = 200
else:
headers[k_stripped] = v.strip()
end_time = time.time()
headers["X-Powered-By"] = "retoor"
headers["X-Date"] = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime())
headers["X-Time"] = str(end_time)
headers["X-Duration"] = str(end_time - start_time)
if "error" in str(body).lower():
print(headers)
print(body)
return web.Response(body=body, headers=headers, status=status_code)
else:
return web.Response(body=stdout)
except Exception as e:
return web.Response(status=500, text=f"Failed to execute CGI script.\n{e!r}")
# ---------------------------
# Main router
# ---------------------------
async def handle_any(self, request: web.Request) -> web.StreamResponse:
path = request.path
if path == "/":
return await self.handle_root(request)
if path == "/hello":
return await self.handle_hello(request)
script_path, script_name, path_info = self._find_cgi_script(path)
if script_path:
return await self._run_cgi(request, script_path, script_name or "", path_info or "")
dyn = await self._maybe_dispatch_dynamic(request)
if dyn is not None:
return dyn
return await self.handle_static(request)
# ---------------------------
# Runner
# ---------------------------
def run(self) -> None:
self._logger.info("Serving at port %d (base_dir=%s)", self.port, self.base_dir)
web.run_app(self.app, port=self.port)
def main() -> None:
server = RetoorServer(base_dir=".", port=8118)
server.run()
if __name__ == "__main__":
main()
else:
cgi = CGI()

View File

@ -43,6 +43,11 @@ from pr.tools.memory import (
from pr.tools.patch import apply_patch, create_diff
from pr.tools.python_exec import python_exec
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.tools.context_modifier import (
modify_context_add,
modify_context_replace,
modify_context_delete,
)
__all__ = [
"add_knowledge_entry",

View File

@ -1,596 +1,144 @@
import inspect
from typing import get_type_hints, get_origin, get_args
import pr.tools
def _type_to_json_schema(py_type):
"""Convert Python type to JSON Schema type."""
if py_type == str:
return {"type": "string"}
elif py_type == int:
return {"type": "integer"}
elif py_type == float:
return {"type": "number"}
elif py_type == bool:
return {"type": "boolean"}
elif get_origin(py_type) == list:
return {"type": "array", "items": _type_to_json_schema(get_args(py_type)[0])}
elif get_origin(py_type) == dict:
return {"type": "object"}
else:
# Default to string for unknown types
return {"type": "string"}
def _generate_tool_schema(func):
"""Generate JSON Schema for a tool function."""
sig = inspect.signature(func)
docstring = func.__doc__ or ""
# Extract description from docstring
description = docstring.strip().split("\n")[0] if docstring else ""
# Get type hints
type_hints = get_type_hints(func)
properties = {}
required = []
for param_name, param in sig.parameters.items():
if param_name in ["db_conn", "python_globals"]: # Skip internal parameters
continue
param_type = type_hints.get(param_name, str)
schema = _type_to_json_schema(param_type)
# Add description from docstring if available
param_doc = ""
if docstring:
lines = docstring.split("\n")
in_args = False
for line in lines:
line = line.strip()
if line.startswith("Args:") or line.startswith("Arguments:"):
in_args = True
continue
elif in_args and line.startswith(param_name + ":"):
param_doc = line.split(":", 1)[1].strip()
break
elif in_args and line == "":
continue
elif in_args and not line.startswith(" "):
break
if param_doc:
schema["description"] = param_doc
# Set default if available
if param.default != inspect.Parameter.empty:
schema["default"] = param.default
properties[param_name] = schema
# Required if no default
if param.default == inspect.Parameter.empty:
required.append(param_name)
return {
"type": "function",
"function": {
"name": func.__name__,
"description": description,
"parameters": {
"type": "object",
"properties": properties,
"required": required,
},
},
}
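def _example_schema_usage():
    """A hypothetical sketch of what _generate_tool_schema derives from a
    type-hinted, docstring-documented function; names are illustrative only."""

    def greet(name: str, times: int = 1):
        """Greet a user by name.

        Args:
            name: Who to greet.
            times: How many repetitions.
        """
        return "hello " * times

    schema = _generate_tool_schema(greet)
    assert schema["function"]["name"] == "greet"
    assert schema["function"]["parameters"]["required"] == ["name"]
    assert schema["function"]["parameters"]["properties"]["times"]["default"] == 1
    return schema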
def get_tools_definition():
return [
{
"type": "function",
"function": {
"name": "kill_process",
"description": "Terminate a background process by its PID. Use this to stop processes started with run_command that exceeded their timeout.",
"parameters": {
"type": "object",
"properties": {
"pid": {
"type": "integer",
"description": "The process ID returned by run_command when status is 'running'.",
}
},
"required": ["pid"],
},
},
},
{
"type": "function",
"function": {
"name": "tail_process",
"description": "Monitor and retrieve output from a background process by its PID. Use this to check on processes started with run_command that exceeded their timeout.",
"parameters": {
"type": "object",
"properties": {
"pid": {
"type": "integer",
"description": "The process ID returned by run_command when status is 'running'.",
},
"timeout": {
"type": "integer",
"description": "Maximum seconds to wait for process completion. Returns partial output if still running.",
"default": 30,
},
},
"required": ["pid"],
},
},
},
{
"type": "function",
"function": {
"name": "http_fetch",
"description": "Fetch content from an HTTP URL",
"parameters": {
"type": "object",
"properties": {
"url": {"type": "string", "description": "The URL to fetch"},
"headers": {
"type": "object",
"description": "Optional HTTP headers",
},
},
"required": ["url"],
},
},
},
{
"type": "function",
"function": {
"name": "run_command",
"description": "Execute a shell command and capture output. Returns immediately after timeout with PID if still running. Use tail_process to monitor or kill_process to terminate long-running commands.",
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute",
},
"timeout": {
"type": "integer",
"description": "Maximum seconds to wait for completion",
"default": 30,
},
},
"required": ["command"],
},
},
},
{
"type": "function",
"function": {
"name": "start_interactive_session",
"description": "Execute an interactive terminal command that requires user input or displays UI. The command runs in a dedicated session and returns a session name.",
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The interactive command to execute (e.g., vim, nano, top)",
}
},
"required": ["command"],
},
},
},
{
"type": "function",
"function": {
"name": "send_input_to_session",
"description": "Send input to an interactive session.",
"parameters": {
"type": "object",
"properties": {
"session_name": {
"type": "string",
"description": "The name of the session",
},
"input_data": {
"type": "string",
"description": "The input to send to the session",
},
},
"required": ["session_name", "input_data"],
},
},
},
{
"type": "function",
"function": {
"name": "read_session_output",
"description": "Read output from an interactive session.",
"parameters": {
"type": "object",
"properties": {
"session_name": {
"type": "string",
"description": "The name of the session",
}
},
"required": ["session_name"],
},
},
},
{
"type": "function",
"function": {
"name": "close_interactive_session",
"description": "Close an interactive session.",
"parameters": {
"type": "object",
"properties": {
"session_name": {
"type": "string",
"description": "The name of the session",
}
},
"required": ["session_name"],
},
},
},
{
"type": "function",
"function": {
"name": "read_file",
"description": "Read contents of a file",
"parameters": {
"type": "object",
"properties": {
"filepath": {
"type": "string",
"description": "Path to the file",
}
},
"required": ["filepath"],
},
},
},
{
"type": "function",
"function": {
"name": "write_file",
"description": "Write content to a file",
"parameters": {
"type": "object",
"properties": {
"filepath": {
"type": "string",
"description": "Path to the file",
},
"content": {
"type": "string",
"description": "Content to write",
},
},
"required": ["filepath", "content"],
},
},
},
{
"type": "function",
"function": {
"name": "list_directory",
"description": "List directory contents",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Directory path",
"default": ".",
},
"recursive": {
"type": "boolean",
"description": "List recursively",
"default": False,
},
},
},
},
},
{
"type": "function",
"function": {
"name": "mkdir",
"description": "Create a new directory",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path of the directory to create",
}
},
"required": ["path"],
},
},
},
{
"type": "function",
"function": {
"name": "chdir",
"description": "Change the current working directory",
"parameters": {
"type": "object",
"properties": {"path": {"type": "string", "description": "Path to change to"}},
"required": ["path"],
},
},
},
{
"type": "function",
"function": {
"name": "getpwd",
"description": "Get the current working directory",
"parameters": {"type": "object", "properties": {}},
},
},
{
"type": "function",
"function": {
"name": "db_set",
"description": "Set a key-value pair in the database",
"parameters": {
"type": "object",
"properties": {
"key": {"type": "string", "description": "The key"},
"value": {"type": "string", "description": "The value"},
},
"required": ["key", "value"],
},
},
},
{
"type": "function",
"function": {
"name": "db_get",
"description": "Get a value from the database",
"parameters": {
"type": "object",
"properties": {"key": {"type": "string", "description": "The key"}},
"required": ["key"],
},
},
},
{
"type": "function",
"function": {
"name": "db_query",
"description": "Execute a database query",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string", "description": "SQL query"}},
"required": ["query"],
},
},
},
{
"type": "function",
"function": {
"name": "web_search",
"description": "Perform a web search",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string", "description": "Search query"}},
"required": ["query"],
},
},
},
{
"type": "function",
"function": {
"name": "web_search_news",
"description": "Perform a web search for news",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query for news",
}
},
"required": ["query"],
},
},
},
{
"type": "function",
"function": {
"name": "python_exec",
"description": "Execute Python code",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "Python code to execute",
}
},
"required": ["code"],
},
},
},
{
"type": "function",
"function": {
"name": "index_source_directory",
"description": "Index directory recursively and read all source files.",
"parameters": {
"type": "object",
"properties": {"path": {"type": "string", "description": "Path to index"}},
"required": ["path"],
},
},
},
{
"type": "function",
"function": {
"name": "search_replace",
"description": "Search and replace text in a file",
"parameters": {
"type": "object",
"properties": {
"filepath": {
"type": "string",
"description": "Path to the file",
},
"old_string": {
"type": "string",
"description": "String to replace",
},
"new_string": {
"type": "string",
"description": "Replacement string",
},
},
"required": ["filepath", "old_string", "new_string"],
},
},
},
{
"type": "function",
"function": {
"name": "apply_patch",
"description": "Apply a patch to a file, especially for source code",
"parameters": {
"type": "object",
"properties": {
"filepath": {
"type": "string",
"description": "Path to the file to patch",
},
"patch_content": {
"type": "string",
"description": "The patch content as a string",
},
},
"required": ["filepath", "patch_content"],
},
},
},
{
"type": "function",
"function": {
"name": "create_diff",
"description": "Create a unified diff between two files",
"parameters": {
"type": "object",
"properties": {
"file1": {
"type": "string",
"description": "Path to the first file",
},
"file2": {
"type": "string",
"description": "Path to the second file",
},
"fromfile": {
"type": "string",
"description": "Label for the first file",
"default": "file1",
},
"tofile": {
"type": "string",
"description": "Label for the second file",
"default": "file2",
},
},
"required": ["file1", "file2"],
},
},
},
{
"type": "function",
"function": {
"name": "open_editor",
"description": "Open the RPEditor for a file",
"parameters": {
"type": "object",
"properties": {
"filepath": {
"type": "string",
"description": "Path to the file",
}
},
"required": ["filepath"],
},
},
},
{
"type": "function",
"function": {
"name": "close_editor",
"description": "Close the RPEditor. Always close files when finished editing.",
"parameters": {
"type": "object",
"properties": {
"filepath": {
"type": "string",
"description": "Path to the file",
}
},
"required": ["filepath"],
},
},
},
{
"type": "function",
"function": {
"name": "editor_insert_text",
"description": "Insert text at cursor position in the editor",
"parameters": {
"type": "object",
"properties": {
"filepath": {
"type": "string",
"description": "Path to the file",
},
"text": {"type": "string", "description": "Text to insert"},
"line": {
"type": "integer",
"description": "Line number (optional)",
},
"col": {
"type": "integer",
"description": "Column number (optional)",
},
},
"required": ["filepath", "text"],
},
},
},
{
"type": "function",
"function": {
"name": "editor_replace_text",
"description": "Replace text in a range",
"parameters": {
"type": "object",
"properties": {
"filepath": {
"type": "string",
"description": "Path to the file",
},
"start_line": {"type": "integer", "description": "Start line"},
"start_col": {"type": "integer", "description": "Start column"},
"end_line": {"type": "integer", "description": "End line"},
"end_col": {"type": "integer", "description": "End column"},
"new_text": {"type": "string", "description": "New text"},
},
"required": [
"filepath",
"start_line",
"start_col",
"end_line",
"end_col",
"new_text",
],
},
},
},
{
"type": "function",
"function": {
"name": "editor_search",
"description": "Search for a pattern in the file",
"parameters": {
"type": "object",
"properties": {
"filepath": {
"type": "string",
"description": "Path to the file",
},
"pattern": {"type": "string", "description": "Regex pattern"},
"start_line": {
"type": "integer",
"description": "Start line",
"default": 0,
},
},
"required": ["filepath", "pattern"],
},
},
},
{
"type": "function",
"function": {
"name": "display_file_diff",
"description": "Display a visual colored diff between two files with syntax highlighting and statistics",
"parameters": {
"type": "object",
"properties": {
"filepath1": {
"type": "string",
"description": "Path to the original file",
},
"filepath2": {
"type": "string",
"description": "Path to the modified file",
},
"format_type": {
"type": "string",
"description": "Display format: 'unified' or 'side-by-side'",
"default": "unified",
},
},
"required": ["filepath1", "filepath2"],
},
},
},
{
"type": "function",
"function": {
"name": "display_edit_summary",
"description": "Display a summary of all edit operations performed during the session",
"parameters": {"type": "object", "properties": {}},
},
},
{
"type": "function",
"function": {
"name": "display_edit_timeline",
"description": "Display a timeline of all edit operations with details",
"parameters": {
"type": "object",
"properties": {
"show_content": {
"type": "boolean",
"description": "Show content previews",
"default": False,
}
},
},
},
},
{
"type": "function",
"function": {
"name": "clear_edit_tracker",
"description": "Clear the edit tracker to start fresh",
"parameters": {"type": "object", "properties": {}},
},
},
]
"""Dynamically generate tool definitions from all tool functions."""
tools = []
# Get all functions from pr.tools modules
for name in dir(pr.tools):
if name.startswith("_"):
continue
obj = getattr(pr.tools, name)
if callable(obj) and hasattr(obj, "__module__") and obj.__module__.startswith("pr.tools."):
# Check if it's a tool function (has docstring and proper signature)
if obj.__doc__:
try:
schema = _generate_tool_schema(obj)
tools.append(schema)
except Exception as e:
print(f"Warning: Could not generate schema for {name}: {e}")
continue
return tools
def get_func_map(db_conn=None, python_globals=None):
"""Dynamically generate function map for tool execution."""
func_map = {}
# Include all functions from __all__ in pr.tools
for name in getattr(pr.tools, "__all__", []):
if name.startswith("_"):
continue
obj = getattr(pr.tools, name, None)
if callable(obj) and hasattr(obj, "__module__") and obj.__module__.startswith("pr.tools."):
sig = inspect.signature(obj)
params = list(sig.parameters.keys())
# Create wrapper based on parameters
if "db_conn" in params and "python_globals" in params:
func_map[name] = (
lambda func=obj, db_conn=db_conn, python_globals=python_globals, **kw: func(
**kw, db_conn=db_conn, python_globals=python_globals
)
)
elif "db_conn" in params:
func_map[name] = lambda func=obj, db_conn=db_conn, **kw: func(**kw, db_conn=db_conn)
elif "python_globals" in params:
func_map[name] = lambda func=obj, python_globals=python_globals, **kw: func(
**kw, python_globals=python_globals
)
else:
func_map[name] = lambda func=obj, **kw: func(**kw)
return func_map
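# Minimal usage sketch (hypothetical; assumes "getpwd" is exported via
# pr.tools.__all__). Run this module directly to try the dispatch path:
if __name__ == "__main__":
    func_map = get_func_map()
    print(func_map["getpwd"]())  # dispatch a tool call by name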

View File

@ -2,7 +2,6 @@ import select
import subprocess
import time
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
_processes = {}
@ -99,6 +98,19 @@ def tail_process(pid: int, timeout: int = 30):
def run_command(command, timeout=30, monitored=False, cwd=None):
"""Execute a shell command and return the output.
Args:
command: The shell command to execute.
timeout: Maximum time in seconds to wait for completion.
monitored: Whether to monitor the process (unused).
cwd: Working directory for the command.
Returns:
Dict with status, stdout, stderr, returncode, and optionally pid if still running.
"""
from pr.multiplexer import close_multiplexer, create_multiplexer
mux_name = None
try:
process = subprocess.Popen(

176
pr/tools/command.py.bak Normal file
View File

@ -0,0 +1,176 @@
print(f"Executing command: {command}") print(f"Killing process: {pid}")import os
import select
import subprocess
import time
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
_processes = {}
def _register_process(pid: int, process):
_processes[pid] = process
return _processes
def _get_process(pid: int):
return _processes.get(pid)
def kill_process(pid: int):
try:
process = _get_process(pid)
if process:
process.kill()
_processes.pop(pid)
mux_name = f"cmd-{pid}"
if get_multiplexer(mux_name):
close_multiplexer(mux_name)
return {"status": "success", "message": f"Process {pid} has been killed"}
else:
return {"status": "error", "error": f"Process {pid} not found"}
except Exception as e:
return {"status": "error", "error": str(e)}
def tail_process(pid: int, timeout: int = 30):
process = _get_process(pid)
if process:
mux_name = f"cmd-{pid}"
mux = get_multiplexer(mux_name)
if not mux:
mux_name, mux = create_multiplexer(mux_name, show_output=True)
try:
start_time = time.time()
timeout_duration = timeout
stdout_content = ""
stderr_content = ""
while True:
if process.poll() is not None:
remaining_stdout, remaining_stderr = process.communicate()
if remaining_stdout:
mux.write_stdout(remaining_stdout)
stdout_content += remaining_stdout
if remaining_stderr:
mux.write_stderr(remaining_stderr)
stderr_content += remaining_stderr
if pid in _processes:
_processes.pop(pid)
close_multiplexer(mux_name)
return {
"status": "success",
"stdout": stdout_content,
"stderr": stderr_content,
"returncode": process.returncode,
}
if time.time() - start_time > timeout_duration:
return {
"status": "running",
"message": "Process is still running. Call tail_process again to continue monitoring.",
"stdout_so_far": stdout_content,
"stderr_so_far": stderr_content,
"pid": pid,
}
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
for pipe in ready:
if pipe == process.stdout:
line = process.stdout.readline()
if line:
mux.write_stdout(line)
stdout_content += line
elif pipe == process.stderr:
line = process.stderr.readline()
if line:
mux.write_stderr(line)
stderr_content += line
except Exception as e:
return {"status": "error", "error": str(e)}
else:
return {"status": "error", "error": f"Process {pid} not found"}
def run_command(command, timeout=30, monitored=False):
mux_name = None
try:
process = subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
_register_process(process.pid, process)
mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True)
start_time = time.time()
timeout_duration = timeout
stdout_content = ""
stderr_content = ""
while True:
if process.poll() is not None:
remaining_stdout, remaining_stderr = process.communicate()
if remaining_stdout:
mux.write_stdout(remaining_stdout)
stdout_content += remaining_stdout
if remaining_stderr:
mux.write_stderr(remaining_stderr)
stderr_content += remaining_stderr
if process.pid in _processes:
_processes.pop(process.pid)
close_multiplexer(mux_name)
return {
"status": "success",
"stdout": stdout_content,
"stderr": stderr_content,
"returncode": process.returncode,
}
if time.time() - start_time > timeout_duration:
return {
"status": "running",
"message": f"Process still running after {timeout}s timeout. Use tail_process({process.pid}) to monitor or kill_process({process.pid}) to terminate.",
"stdout_so_far": stdout_content,
"stderr_so_far": stderr_content,
"pid": process.pid,
"mux_name": mux_name,
}
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
for pipe in ready:
if pipe == process.stdout:
line = process.stdout.readline()
if line:
mux.write_stdout(line)
stdout_content += line
elif pipe == process.stderr:
line = process.stderr.readline()
if line:
mux.write_stderr(line)
stderr_content += line
except Exception as e:
if mux_name:
close_multiplexer(mux_name)
return {"status": "error", "error": str(e)}
def run_command_interactive(command):
try:
return_code = os.system(command)
return {"status": "success", "returncode": return_code}
except Exception as e:
return {"status": "error", "error": str(e)}

View File

@ -0,0 +1,65 @@
import os
from typing import Optional
CONTEXT_FILE = '/home/retoor/.local/share/rp/.rcontext.txt'
def _read_context() -> str:
if not os.path.exists(CONTEXT_FILE):
raise FileNotFoundError(f"Context file {CONTEXT_FILE} not found.")
with open(CONTEXT_FILE, 'r') as f:
return f.read()
def _write_context(content: str):
with open(CONTEXT_FILE, 'w') as f:
f.write(content)
def modify_context_add(new_content: str, position: Optional[str] = None) -> str:
"""
Add new content to the .rcontext.txt file.
Args:
new_content: The content to add.
position: Optional marker to insert before (e.g., '***').
"""
current = _read_context()
if position and position in current:
# Insert before the position
parts = current.split(position, 1)
updated = parts[0] + new_content + '\n\n' + position + parts[1]
else:
# Append at the end
updated = current + '\n\n' + new_content
_write_context(updated)
return f"Added: {new_content[:100]}... (full addition applied). Consequences: Enhances functionality as requested."
def modify_context_replace(old_content: str, new_content: str) -> str:
"""
Replace old content with new content in .rcontext.txt.
Args:
old_content: The content to replace.
new_content: The replacement content.
"""
current = _read_context()
if old_content not in current:
raise ValueError(f"Old content not found: {old_content[:50]}...")
updated = current.replace(old_content, new_content, 1) # Replace first occurrence
_write_context(updated)
return f"Replaced: '{old_content[:50]}...' with '{new_content[:50]}...'. Consequences: Changes behavior as specified; verify for unintended effects."
def modify_context_delete(content_to_delete: str, confirmed: bool = False) -> str:
"""
Delete content from .rcontext.txt, but only if confirmed.
Args:
content_to_delete: The content to delete.
confirmed: Must be True to proceed with deletion.
"""
if not confirmed:
raise PermissionError(f"Deletion not confirmed. To delete '{content_to_delete[:50]}...', you must explicitly confirm. Are you sure? This may affect system behavior permanently.")
current = _read_context()
if content_to_delete not in current:
raise ValueError(f"Content to delete not found: {content_to_delete[:50]}...")
updated = current.replace(content_to_delete, '', 1)
_write_context(updated)
return f"Deleted: '{content_to_delete[:50]}...'. Consequences: Removed specified content; system may lose referenced rules or guidelines."

View File

@ -2,6 +2,16 @@ import time
def db_set(key, value, db_conn):
"""Set a key-value pair in the database.
Args:
key: The key to set.
value: The value to store.
db_conn: Database connection.
Returns:
Dict with status and message.
"""
if not db_conn:
return {"status": "error", "error": "Database not initialized"}
@ -19,6 +29,15 @@ def db_set(key, value, db_conn):
def db_get(key, db_conn):
"""Get a value from the database.
Args:
key: The key to retrieve.
db_conn: Database connection.
Returns:
Dict with status and value.
"""
if not db_conn:
return {"status": "error", "error": "Database not initialized"}
@ -35,6 +54,15 @@ def db_get(key, db_conn):
def db_query(query, db_conn):
"""Execute a database query.
Args:
query: SQL query to execute.
db_conn: Database connection.
Returns:
Dict with status and query results.
"""
if not db_conn:
return {"status": "error", "error": "Database not initialized"}

838
pr/tools/debugging.py Normal file
View File

@ -0,0 +1,838 @@
import sys
import os
import ast
import inspect
import time
import threading
import gc
import weakref
import linecache
import re
import json
import subprocess
from collections import defaultdict
from datetime import datetime
class MemoryTracker:
def __init__(self):
self.allocations = defaultdict(list)
self.references = weakref.WeakValueDictionary()
self.peak_memory = 0
self.current_memory = 0
def track_object(self, obj, location):
try:
obj_id = id(obj)
obj_size = sys.getsizeof(obj)
self.allocations[location].append(
{
"id": obj_id,
"type": type(obj).__name__,
"size": obj_size,
"timestamp": time.time(),
}
)
self.current_memory += obj_size
if self.current_memory > self.peak_memory:
self.peak_memory = self.current_memory
        except Exception:
pass
def analyze_leaks(self):
gc.collect()
leaks = []
for obj in gc.get_objects():
if sys.getrefcount(obj) > 10:
try:
leaks.append(
{
"type": type(obj).__name__,
"refcount": sys.getrefcount(obj),
"size": sys.getsizeof(obj),
}
)
                except Exception:
pass
return sorted(leaks, key=lambda x: x["refcount"], reverse=True)[:20]
def get_report(self):
return {
"peak_memory": self.peak_memory,
"current_memory": self.current_memory,
"allocation_count": sum(len(v) for v in self.allocations.values()),
"leaks": self.analyze_leaks(),
}
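# Usage sketch (hypothetical call-site label): track one allocation, then report.
def _example_memory_tracking():
    tracker = MemoryTracker()
    payload = list(range(1000))
    tracker.track_object(payload, "example.py:1")
    return tracker.get_report()  # peak/current memory plus leak candidates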
class PerformanceProfiler:
def __init__(self):
self.function_times = defaultdict(lambda: {"calls": 0, "total_time": 0.0, "self_time": 0.0})
self.call_stack = []
self.start_times = {}
def enter_function(self, frame):
func_name = self._get_function_name(frame)
self.call_stack.append(func_name)
self.start_times[id(frame)] = time.perf_counter()
def exit_function(self, frame):
func_name = self._get_function_name(frame)
frame_id = id(frame)
if frame_id in self.start_times:
elapsed = time.perf_counter() - self.start_times[frame_id]
self.function_times[func_name]["calls"] += 1
self.function_times[func_name]["total_time"] += elapsed
self.function_times[func_name]["self_time"] += elapsed
del self.start_times[frame_id]
if self.call_stack:
self.call_stack.pop()
def _get_function_name(self, frame):
return f"{frame.f_code.co_filename}:{frame.f_code.co_name}:{frame.f_lineno}"
def get_hotspots(self, limit=20):
sorted_funcs = sorted(
self.function_times.items(), key=lambda x: x[1]["total_time"], reverse=True
)
return sorted_funcs[:limit]
class StaticAnalyzer(ast.NodeVisitor):
def __init__(self):
self.issues = []
self.complexity = 0
self.unused_vars = set()
self.undefined_vars = set()
self.defined_vars = set()
self.used_vars = set()
self.functions = {}
self.classes = {}
self.imports = []
def visit_FunctionDef(self, node):
self.functions[node.name] = {
"lineno": node.lineno,
"args": [arg.arg for arg in node.args.args],
"decorators": [
d.id if isinstance(d, ast.Name) else "complex" for d in node.decorator_list
],
"complexity": self._calculate_complexity(node),
}
if len(node.body) == 0:
self.issues.append(f"Line {node.lineno}: Empty function '{node.name}'")
if len(node.args.args) > 7:
self.issues.append(
f"Line {node.lineno}: Function '{node.name}' has too many parameters ({len(node.args.args)})"
)
self.generic_visit(node)
def visit_ClassDef(self, node):
self.classes[node.name] = {
"lineno": node.lineno,
"bases": [b.id if isinstance(b, ast.Name) else "complex" for b in node.bases],
"methods": [],
}
self.generic_visit(node)
def visit_Import(self, node):
for alias in node.names:
self.imports.append(alias.name)
self.generic_visit(node)
def visit_ImportFrom(self, node):
if node.module:
self.imports.append(node.module)
self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
self.defined_vars.add(node.id)
elif isinstance(node.ctx, ast.Load):
self.used_vars.add(node.id)
self.generic_visit(node)
def visit_If(self, node):
self.complexity += 1
self.generic_visit(node)
def visit_For(self, node):
self.complexity += 1
self.generic_visit(node)
def visit_While(self, node):
self.complexity += 1
self.generic_visit(node)
def visit_ExceptHandler(self, node):
self.complexity += 1
if node.type is None:
self.issues.append(f"Line {node.lineno}: Bare except clause (catches all exceptions)")
self.generic_visit(node)
def visit_BinOp(self, node):
if isinstance(node.op, ast.Add):
            # ast.Str was removed in Python 3.12; test ast.Constant instead
            if (isinstance(node.left, ast.Constant) and isinstance(node.left.value, str)) or (
                isinstance(node.right, ast.Constant) and isinstance(node.right.value, str)
            ):
self.issues.append(
f"Line {node.lineno}: String concatenation with '+' (use f-strings or join)"
)
self.generic_visit(node)
def _calculate_complexity(self, node):
complexity = 1
for child in ast.walk(node):
if isinstance(child, (ast.If, ast.For, ast.While, ast.ExceptHandler)):
complexity += 1
return complexity
def finalize(self):
self.unused_vars = self.defined_vars - self.used_vars
        import builtins  # __builtins__ may be a dict in imported modules

        self.undefined_vars = self.used_vars - self.defined_vars - set(dir(builtins))
for var in self.unused_vars:
self.issues.append(f"Unused variable: '{var}'")
def analyze_code(self, source_code):
try:
tree = ast.parse(source_code)
self.visit(tree)
self.finalize()
return {
"issues": self.issues,
"complexity": self.complexity,
"functions": self.functions,
"classes": self.classes,
"imports": self.imports,
"unused_vars": list(self.unused_vars),
}
except SyntaxError as e:
return {"error": f"Syntax error at line {e.lineno}: {e.msg}"}
class DynamicTracer:
def __init__(self):
self.execution_trace = []
self.exception_trace = []
self.variable_changes = defaultdict(list)
self.line_coverage = set()
self.function_calls = defaultdict(int)
self.max_trace_length = 10000
self.memory_tracker = MemoryTracker()
self.profiler = PerformanceProfiler()
self.trace_active = False
def trace_calls(self, frame, event, arg):
if not self.trace_active:
return
if len(self.execution_trace) >= self.max_trace_length:
return self.trace_calls
co = frame.f_code
func_name = co.co_name
filename = co.co_filename
line_no = frame.f_lineno
if "site-packages" in filename or filename.startswith("<"):
return self.trace_calls
trace_entry = {
"event": event,
"function": func_name,
"filename": filename,
"lineno": line_no,
"timestamp": time.time(),
}
if event == "call":
self.function_calls[f"{filename}:{func_name}"] += 1
self.profiler.enter_function(frame)
trace_entry["locals"] = {
k: repr(v)[:100] for k, v in frame.f_locals.items() if not k.startswith("__")
}
elif event == "return":
self.profiler.exit_function(frame)
trace_entry["return_value"] = repr(arg)[:100] if arg else None
elif event == "line":
self.line_coverage.add((filename, line_no))
line_code = linecache.getline(filename, line_no).strip()
trace_entry["code"] = line_code
for var, value in frame.f_locals.items():
if not var.startswith("__"):
self.variable_changes[var].append(
{"line": line_no, "value": repr(value)[:100], "timestamp": time.time()}
)
self.memory_tracker.track_object(value, f"{filename}:{line_no}")
elif event == "exception":
exc_type, exc_value, exc_tb = arg
self.exception_trace.append(
{
"type": exc_type.__name__,
"message": str(exc_value),
"filename": filename,
"function": func_name,
"lineno": line_no,
"timestamp": time.time(),
}
)
trace_entry["exception"] = {"type": exc_type.__name__, "message": str(exc_value)}
self.execution_trace.append(trace_entry)
return self.trace_calls
def start_tracing(self):
self.trace_active = True
sys.settrace(self.trace_calls)
threading.settrace(self.trace_calls)
def stop_tracing(self):
self.trace_active = False
sys.settrace(None)
threading.settrace(None)
def get_trace_report(self):
return {
"execution_trace": self.execution_trace[-100:],
"exception_trace": self.exception_trace,
"line_coverage": list(self.line_coverage),
"function_calls": dict(self.function_calls),
"variable_changes": {k: v[-10:] for k, v in self.variable_changes.items()},
"hotspots": self.profiler.get_hotspots(20),
"memory_report": self.memory_tracker.get_report(),
}
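# Usage sketch (hypothetical callable): trace one call and return its hotspots.
def _example_dynamic_trace(fn, *args):
    tracer = DynamicTracer()
    tracer.start_tracing()
    try:
        fn(*args)
    finally:
        tracer.stop_tracing()
    return tracer.get_trace_report()["hotspots"]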
class GitBisectAutomator:
def __init__(self, repo_path="."):
self.repo_path = repo_path
def is_git_repo(self):
try:
result = subprocess.run(
["git", "rev-parse", "--git-dir"],
cwd=self.repo_path,
capture_output=True,
text=True,
)
return result.returncode == 0
        except Exception:
return False
def get_commit_history(self, limit=50):
try:
result = subprocess.run(
["git", "log", f"--max-count={limit}", "--oneline"],
cwd=self.repo_path,
capture_output=True,
text=True,
)
if result.returncode == 0:
commits = []
for line in result.stdout.strip().split("\n"):
parts = line.split(" ", 1)
if len(parts) == 2:
commits.append({"hash": parts[0], "message": parts[1]})
return commits
        except Exception:
pass
return []
def blame_file(self, filepath):
try:
result = subprocess.run(
["git", "blame", "-L", "1,50", filepath],
cwd=self.repo_path,
capture_output=True,
text=True,
)
if result.returncode == 0:
return result.stdout
        except Exception:
pass
return None
class LogAnalyzer:
def __init__(self):
self.log_patterns = {
"error": re.compile(r"error|exception|fail|critical", re.IGNORECASE),
"warning": re.compile(r"warn|caution", re.IGNORECASE),
"debug": re.compile(r"debug|trace", re.IGNORECASE),
"timestamp": re.compile(r"\\d{4}-\\d{2}-\\d{2}[\\s_T]\\d{2}:\\d{2}:\\d{2}"),
}
self.anomalies = []
def analyze_logs(self, log_content):
lines = log_content.split("\n")
errors = []
warnings = []
timestamps = []
for i, line in enumerate(lines):
if self.log_patterns["error"].search(line):
errors.append({"line": i + 1, "content": line})
elif self.log_patterns["warning"].search(line):
warnings.append({"line": i + 1, "content": line})
ts_match = self.log_patterns["timestamp"].search(line)
if ts_match:
timestamps.append(ts_match.group())
return {
"total_lines": len(lines),
"errors": errors[:50],
"warnings": warnings[:50],
"error_count": len(errors),
"warning_count": len(warnings),
"timestamps": timestamps[:20],
}
class ExceptionAnalyzer:
def __init__(self):
self.exceptions = []
self.exception_counts = defaultdict(int)
def capture_exception(self, exc_type, exc_value, exc_traceback):
exc_info = {
"type": exc_type.__name__,
"message": str(exc_value),
"timestamp": time.time(),
"traceback": [],
}
tb = exc_traceback
while tb is not None:
frame = tb.tb_frame
exc_info["traceback"].append(
{
"filename": frame.f_code.co_filename,
"function": frame.f_code.co_name,
"lineno": tb.tb_lineno,
"locals": {
k: repr(v)[:100]
for k, v in frame.f_locals.items()
if not k.startswith("__")
},
}
)
tb = tb.tb_next
self.exceptions.append(exc_info)
self.exception_counts[exc_type.__name__] += 1
return exc_info
def get_report(self):
return {
"total_exceptions": len(self.exceptions),
"exception_types": dict(self.exception_counts),
"recent_exceptions": self.exceptions[-10:],
}
class TestGenerator:
def __init__(self):
self.test_cases = []
def generate_tests_for_function(self, func_name, func_signature):
test_template = f"""def test_{func_name}_basic():
result = {func_name}()
assert result is not None
def test_{func_name}_edge_cases():
pass
def test_{func_name}_exceptions():
pass
"""
return test_template
def analyze_function_for_tests(self, func_obj):
sig = inspect.signature(func_obj)
test_inputs = []
for param_name, param in sig.parameters.items():
if param.annotation != inspect.Parameter.empty:
param_type = param.annotation
if param_type == int:
test_inputs.append([0, 1, -1, 100])
elif param_type == str:
test_inputs.append(["", "test", "a" * 100])
elif param_type == list:
test_inputs.append([[], [1], [1, 2, 3]])
else:
test_inputs.append([None])
else:
test_inputs.append([None, 0, "", []])
return test_inputs
class CodeFlowVisualizer:
def __init__(self):
self.flow_graph = defaultdict(list)
self.call_hierarchy = defaultdict(set)
def build_flow_from_trace(self, execution_trace):
for i in range(len(execution_trace) - 1):
current = execution_trace[i]
next_step = execution_trace[i + 1]
current_node = f"{current['function']}:{current['lineno']}"
next_node = f"{next_step['function']}:{next_step['lineno']}"
self.flow_graph[current_node].append(next_node)
if current["event"] == "call":
caller = current["function"]
callee = next_step["function"]
self.call_hierarchy[caller].add(callee)
def generate_text_visualization(self):
output = []
output.append("Call Hierarchy:")
for caller, callees in sorted(self.call_hierarchy.items()):
output.append(f" {caller}")
for callee in sorted(callees):
output.append(f" -> {callee}")
return "\n".join(output)
class AutomatedDebugger:
def __init__(self):
self.static_analyzer = StaticAnalyzer()
self.dynamic_tracer = DynamicTracer()
self.exception_analyzer = ExceptionAnalyzer()
self.log_analyzer = LogAnalyzer()
self.git_automator = GitBisectAutomator()
self.test_generator = TestGenerator()
self.flow_visualizer = CodeFlowVisualizer()
self.original_excepthook = sys.excepthook
def analyze_source_file(self, filepath):
try:
with open(filepath, "r", encoding="utf-8") as f:
source_code = f.read()
return self.static_analyzer.analyze_code(source_code)
except Exception as e:
return {"error": str(e)}
def run_with_tracing(self, target_func, *args, **kwargs):
self.dynamic_tracer.start_tracing()
result = None
exception = None
try:
result = target_func(*args, **kwargs)
except Exception as e:
exception = self.exception_analyzer.capture_exception(type(e), e, e.__traceback__)
finally:
self.dynamic_tracer.stop_tracing()
self.flow_visualizer.build_flow_from_trace(self.dynamic_tracer.execution_trace)
return {
"result": result,
"exception": exception,
"trace": self.dynamic_tracer.get_trace_report(),
"flow": self.flow_visualizer.generate_text_visualization(),
}
def analyze_logs(self, log_file_or_content):
if os.path.isfile(log_file_or_content):
with open(log_file_or_content, "r", encoding="utf-8") as f:
content = f.read()
else:
content = log_file_or_content
return self.log_analyzer.analyze_logs(content)
def generate_debug_report(self, output_file="debug_report.json"):
report = {
"timestamp": datetime.now().isoformat(),
"static_analysis": self.static_analyzer.issues,
"dynamic_trace": self.dynamic_tracer.get_trace_report(),
"exceptions": self.exception_analyzer.get_report(),
"git_info": (
self.git_automator.get_commit_history(10)
if self.git_automator.is_git_repo()
else None
),
"flow_visualization": self.flow_visualizer.generate_text_visualization(),
}
with open(output_file, "w") as f:
json.dump(report, f, indent=2, default=str)
return report
def auto_debug_function(self, func, test_inputs=None):
results = []
if test_inputs is None:
test_inputs = self.test_generator.analyze_function_for_tests(func)
for input_set in test_inputs:
try:
if isinstance(input_set, list):
result = self.run_with_tracing(func, *input_set)
else:
result = self.run_with_tracing(func, input_set)
results.append(
{
"input": input_set,
"success": result["exception"] is None,
"output": result["result"],
"trace_summary": {
"function_calls": len(result["trace"]["function_calls"]),
"exceptions": len(result["trace"]["exception_trace"]),
},
}
)
except Exception as e:
results.append({"input": input_set, "success": False, "error": str(e)})
return results
# Global instances for tools
_memory_tracker = MemoryTracker()
_performance_profiler = PerformanceProfiler()
_static_analyzer = StaticAnalyzer()
_dynamic_tracer = DynamicTracer()
_git_automator = GitBisectAutomator()
_log_analyzer = LogAnalyzer()
_exception_analyzer = ExceptionAnalyzer()
_test_generator = TestGenerator()
_code_flow_visualizer = CodeFlowVisualizer()
_automated_debugger = AutomatedDebugger()
# Tool functions
def track_memory_allocation(location: str = "manual") -> dict:
"""Track current memory allocation at a specific location."""
try:
_memory_tracker.track_object({}, location)
return {
"status": "success",
"current_memory": _memory_tracker.current_memory,
"peak_memory": _memory_tracker.peak_memory,
"location": location,
}
except Exception as e:
return {"status": "error", "error": str(e)}
def analyze_memory_leaks() -> dict:
"""Analyze potential memory leaks in the current process."""
try:
leaks = _memory_tracker.analyze_leaks()
return {"status": "success", "leaks_found": len(leaks), "top_leaks": leaks[:10]}
except Exception as e:
return {"status": "error", "error": str(e)}
def get_memory_report() -> dict:
"""Get a comprehensive memory usage report."""
try:
return {"status": "success", "report": _memory_tracker.get_report()}
except Exception as e:
return {"status": "error", "error": str(e)}
def start_performance_profiling() -> dict:
"""Start performance profiling."""
try:
        global _performance_profiler
        _performance_profiler = PerformanceProfiler()
return {"status": "success", "message": "Performance profiling started"}
except Exception as e:
return {"status": "error", "error": str(e)}
def stop_performance_profiling() -> dict:
"""Stop performance profiling and get hotspots."""
try:
hotspots = _performance_profiler.get_hotspots(20)
return {"status": "success", "hotspots": hotspots}
except Exception as e:
return {"status": "error", "error": str(e)}
def analyze_source_code(source_code: str) -> dict:
"""Perform static analysis on Python source code."""
try:
analyzer = StaticAnalyzer()
result = analyzer.analyze_code(source_code)
return {"status": "success", "analysis": result}
except Exception as e:
return {"status": "error", "error": str(e)}
def analyze_source_file(filepath: str) -> dict:
"""Analyze a Python source file statically."""
try:
result = _automated_debugger.analyze_source_file(filepath)
return {"status": "success", "filepath": filepath, "analysis": result}
except Exception as e:
return {"status": "error", "error": str(e)}
def start_dynamic_tracing() -> dict:
"""Start dynamic execution tracing."""
try:
_dynamic_tracer.start_tracing()
return {"status": "success", "message": "Dynamic tracing started"}
except Exception as e:
return {"status": "error", "error": str(e)}
def stop_dynamic_tracing() -> dict:
"""Stop dynamic tracing and get trace report."""
try:
_dynamic_tracer.stop_tracing()
report = _dynamic_tracer.get_trace_report()
return {"status": "success", "trace_report": report}
except Exception as e:
return {"status": "error", "error": str(e)}
def get_git_commit_history(limit: int = 50) -> dict:
"""Get recent git commit history."""
try:
commits = _git_automator.get_commit_history(limit)
return {
"status": "success",
"commits": commits,
"is_git_repo": _git_automator.is_git_repo(),
}
except Exception as e:
return {"status": "error", "error": str(e)}
def blame_file(filepath: str) -> dict:
"""Get git blame information for a file."""
try:
blame_output = _git_automator.blame_file(filepath)
return {"status": "success", "filepath": filepath, "blame": blame_output}
except Exception as e:
return {"status": "error", "error": str(e)}
def analyze_log_content(log_content: str) -> dict:
"""Analyze log content for errors, warnings, and patterns."""
try:
analysis = _log_analyzer.analyze_logs(log_content)
return {"status": "success", "analysis": analysis}
except Exception as e:
return {"status": "error", "error": str(e)}
def analyze_log_file(filepath: str) -> dict:
"""Analyze a log file."""
try:
analysis = _automated_debugger.analyze_logs(filepath)
return {"status": "success", "filepath": filepath, "analysis": analysis}
except Exception as e:
return {"status": "error", "error": str(e)}
def get_exception_report() -> dict:
"""Get a report of captured exceptions."""
try:
report = _exception_analyzer.get_report()
return {"status": "success", "exception_report": report}
except Exception as e:
return {"status": "error", "error": str(e)}
def generate_tests_for_function(func_name: str, func_signature: str = "") -> dict:
"""Generate test templates for a function."""
try:
test_code = _test_generator.generate_tests_for_function(func_name, func_signature)
return {"status": "success", "func_name": func_name, "test_code": test_code}
except Exception as e:
return {"status": "error", "error": str(e)}
def visualize_code_flow_from_trace(execution_trace) -> dict:
"""Visualize code flow from execution trace."""
try:
visualizer = CodeFlowVisualizer()
visualizer.build_flow_from_trace(execution_trace)
visualization = visualizer.generate_text_visualization()
return {"status": "success", "flow_visualization": visualization}
except Exception as e:
return {"status": "error", "error": str(e)}
def run_function_with_debugging(func_code: str, *args, **kwargs) -> dict:
"""Execute a function with full debugging."""
try:
# Compile and execute the function
local_vars = {}
exec(func_code, globals(), local_vars)
# Find the function (assuming it's the last defined function)
func = None
for name, obj in local_vars.items():
if callable(obj) and not name.startswith("_"):
func = obj
break
if func is None:
return {"status": "error", "error": "No function found in code"}
result = _automated_debugger.run_with_tracing(func, *args, **kwargs)
return {"status": "success", "debug_result": result}
except Exception as e:
return {"status": "error", "error": str(e)}
def generate_comprehensive_debug_report(output_file: str = "debug_report.json") -> dict:
"""Generate a comprehensive debug report."""
try:
report = _automated_debugger.generate_debug_report(output_file)
return {
"status": "success",
"output_file": output_file,
"report_summary": {
"static_issues": len(report.get("static_analysis", [])),
"exceptions": report.get("exceptions", {}).get("total_exceptions", 0),
"function_calls": len(report.get("dynamic_trace", {}).get("function_calls", {})),
},
}
except Exception as e:
return {"status": "error", "error": str(e)}
def auto_debug_function(func_code: str, test_inputs: list = None) -> dict:
"""Automatically debug a function with test inputs."""
try:
local_vars = {}
exec(func_code, globals(), local_vars)
func = None
for name, obj in local_vars.items():
if callable(obj) and not name.startswith("_"):
func = obj
break
if func is None:
return {"status": "error", "error": "No function found in code"}
if test_inputs is None:
test_inputs = _test_generator.analyze_function_for_tests(func)
results = _automated_debugger.auto_debug_function(func, test_inputs)
return {"status": "success", "debug_results": results}
except Exception as e:
return {"status": "error", "error": str(e)}

View File

@ -2,7 +2,7 @@ import os
import os.path
from pr.editor import RPEditor
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
from ..tools.patch import display_content_diff
from ..ui.edit_feedback import track_edit, tracker
@ -17,6 +17,8 @@ def get_editor(filepath):
def close_editor(filepath):
from pr.multiplexer import close_multiplexer, get_multiplexer
try:
path = os.path.expanduser(filepath)
editor = get_editor(path)
@ -34,6 +36,8 @@ def close_editor(filepath):
def open_editor(filepath):
from pr.multiplexer import create_multiplexer
try:
path = os.path.expanduser(filepath)
editor = RPEditor(path)
@ -53,6 +57,8 @@ def open_editor(filepath):
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
from pr.multiplexer import get_multiplexer
try:
path = os.path.expanduser(filepath)
@ -98,6 +104,8 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
def editor_replace_text(
filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True
):
from pr.multiplexer import get_multiplexer
try:
path = os.path.expanduser(filepath)
@ -148,6 +156,8 @@ def editor_replace_text(
def editor_search(filepath, pattern, start_line=0):
from pr.multiplexer import get_multiplexer
try:
path = os.path.expanduser(filepath)
editor = RPEditor(path)

View File

@ -1,6 +1,7 @@
import hashlib
import os
import time
from typing import Optional, Any
from pr.editor import RPEditor
@ -17,7 +18,17 @@ def get_uid():
return _id
def read_file(filepath, db_conn=None):
def read_file(filepath: str, db_conn: Optional[Any] = None) -> dict:
"""
Read the contents of a file.
Args:
filepath: Path to the file to read
db_conn: Optional database connection for tracking
Returns:
dict: Status and content or error
"""
try:
path = os.path.expanduser(filepath)
with open(path) as f:
@ -31,7 +42,22 @@ def read_file(filepath, db_conn=None):
return {"status": "error", "error": str(e)}
def write_file(filepath, content, db_conn=None, show_diff=True):
def write_file(
filepath: str, content: str, db_conn: Optional[Any] = None, show_diff: bool = True
) -> dict:
"""
Write content to a file.
Args:
filepath: Path to the file to write
content: Content to write
db_conn: Optional database connection for tracking
show_diff: Whether to show diff of changes
Returns:
dict: Status and message or error
"""
operation = None
try:
path = os.path.expanduser(filepath)
old_content = ""
@ -93,12 +119,13 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
return {"status": "success", "message": message}
except Exception as e:
if "operation" in locals():
if operation is not None:
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def list_directory(path=".", recursive=False):
"""List files and directories in the specified path."""
try:
path = os.path.expanduser(path)
items = []
@ -139,6 +166,7 @@ def mkdir(path):
def chdir(path):
"""Change the current working directory."""
try:
os.chdir(os.path.expanduser(path))
return {"status": "success", "new_path": os.getcwd()}
@ -147,13 +175,23 @@ def chdir(path):
def getpwd():
"""Get the current working directory."""
try:
return {"status": "success", "path": os.getcwd()}
except Exception as e:
return {"status": "error", "error": str(e)}
def index_source_directory(path):
def index_source_directory(path: str) -> dict:
"""
Index directory recursively and read all source files.
Args:
path: Path to index
Returns:
dict: Status and indexed files or error
"""
extensions = [
".py",
".js",
@ -189,7 +227,21 @@ def index_source_directory(path):
return {"status": "error", "error": str(e)}
def search_replace(filepath, old_string, new_string, db_conn=None):
def search_replace(
filepath: str, old_string: str, new_string: str, db_conn: Optional[Any] = None
) -> dict:
"""
Search and replace text in a file.
Args:
filepath: Path to the file
old_string: String to replace
new_string: Replacement string
db_conn: Optional database connection for tracking
Returns:
dict: Status and message or error
"""
try:
path = os.path.expanduser(filepath)
if not os.path.exists(path):
@ -246,6 +298,7 @@ def open_editor(filepath):
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None):
operation = None
try:
path = os.path.expanduser(filepath)
if db_conn:
@ -283,7 +336,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c
tracker.mark_completed(operation)
return {"status": "success", "message": f"Inserted text in {path}"}
except Exception as e:
if "operation" in locals():
if operation is not None:
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
@ -299,6 +352,7 @@ def editor_replace_text(
db_conn=None,
):
try:
operation = None
path = os.path.expanduser(filepath)
if db_conn:
from pr.tools.database import db_get
@ -341,7 +395,7 @@ def editor_replace_text(
tracker.mark_completed(operation)
return {"status": "success", "message": f"Replaced text in {path}"}
except Exception as e:
if "operation" in locals():
if operation is not None:
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}

View File

@ -1,12 +1,17 @@
import subprocess
import threading
import importlib
from pr.multiplexer import (
close_multiplexer,
create_multiplexer,
get_all_multiplexer_states,
get_multiplexer,
)
def _get_multiplexer_functions():
"""Lazy import multiplexer functions to avoid circular imports."""
multiplexer = importlib.import_module("pr.multiplexer")
return {
"create_multiplexer": multiplexer.create_multiplexer,
"get_multiplexer": multiplexer.get_multiplexer,
"close_multiplexer": multiplexer.close_multiplexer,
"get_all_multiplexer_states": multiplexer.get_all_multiplexer_states,
}
def start_interactive_session(command, session_name=None, process_type="generic", cwd=None):
@ -22,7 +27,8 @@ def start_interactive_session(command, session_name=None, process_type="generic"
Returns:
session_name: The name of the created session
"""
name, mux = create_multiplexer(session_name)
funcs = _get_multiplexer_functions()
name, mux = funcs["create_multiplexer"](session_name)
mux.update_metadata("process_type", process_type)
# Start the process
@ -65,7 +71,7 @@ def start_interactive_session(command, session_name=None, process_type="generic"
return name
except Exception as e:
close_multiplexer(name)
funcs["close_multiplexer"](name)
raise e
@ -97,7 +103,8 @@ def send_input_to_session(session_name, input_data):
session_name: Name of the session
input_data: Input string to send
"""
mux = get_multiplexer(session_name)
funcs = _get_multiplexer_functions()
mux = funcs["get_multiplexer"](session_name)
if not mux:
raise ValueError(f"Session {session_name} not found")
@ -112,6 +119,7 @@ def send_input_to_session(session_name, input_data):
def read_session_output(session_name, lines=None):
    """
Read output from a session.
@ -122,7 +130,7 @@ def read_session_output(session_name, lines=None):
Returns:
dict: {'stdout': str, 'stderr': str}
"""
mux = get_multiplexer(session_name)
mux = funcs["get_multiplexer"](session_name)
if not mux:
raise ValueError(f"Session {session_name} not found")
@ -142,7 +150,8 @@ def list_active_sessions():
Returns:
dict: Session states
"""
return get_all_multiplexer_states()
funcs = _get_multiplexer_functions()
return funcs["get_all_multiplexer_states"]()
def get_session_status(session_name):
@ -155,7 +164,8 @@ def get_session_status(session_name):
Returns:
dict: Session metadata and status
"""
mux = get_multiplexer(session_name)
funcs = _get_multiplexer_functions()
mux = funcs["get_multiplexer"](session_name)
if not mux:
return None
@ -175,10 +185,11 @@ def close_interactive_session(session_name):
Close an interactive session.
"""
try:
mux = get_multiplexer(session_name)
funcs = _get_multiplexer_functions()
mux = funcs["get_multiplexer"](session_name)
if mux:
mux.process.kill()
close_multiplexer(session_name)
funcs["close_multiplexer"](session_name)
return {"status": "success"}
except Exception as e:
return {"status": "error", "error": str(e)}

View File

@ -7,6 +7,16 @@ from ..ui.diff_display import display_diff, get_diff_stats
def apply_patch(filepath, patch_content, db_conn=None):
"""Apply a patch to a file.
Args:
filepath: Path to the file to patch.
patch_content: The patch content as a string.
db_conn: Database connection (optional).
Returns:
Dict with status and output.
"""
try:
path = os.path.expanduser(filepath)
if db_conn:
@ -41,6 +51,19 @@ def apply_patch(filepath, patch_content, db_conn=None):
def create_diff(
file1, file2, fromfile="file1", tofile="file2", visual=False, format_type="unified"
):
"""Create a unified diff between two files.
Args:
file1: Path to the first file.
file2: Path to the second file.
fromfile: Label for the first file.
tofile: Label for the second file.
visual: Whether to include visual diff.
format_type: Diff format type.
Returns:
Dict with status and diff content.
"""
try:
path1 = os.path.expanduser(file1)
path2 = os.path.expanduser(file2)

View File

@ -5,6 +5,16 @@ from io import StringIO
def python_exec(code, python_globals, cwd=None):
"""Execute Python code and capture the output.
Args:
code: The Python code to execute.
python_globals: Dictionary of global variables for execution.
cwd: Working directory for execution.
Returns:
Dict with status and output, or error information.
"""
try:
original_cwd = None
if cwd:
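An illustrative call; python_globals is threaded through so state can persist across executions, and the result keys beyond "status" follow the docstring above:
shared_globals = {}
result = python_exec("x = 40 + 2\nprint(x)", shared_globals, cwd=None)
# shared_globals now holds x, so a follow-up call can reuse it
python_exec("print(x * 2)", shared_globals)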

19
pr/tools/vision.py Normal file
View File

@ -0,0 +1,19 @@
from pr.vision import post_image as vision_post_image
import functools
@functools.lru_cache()
def post_image(path: str, prompt: str | None = None):
"""Post an image for analysis.
Args:
path: Path to the image file.
prompt: Optional prompt for analysis.
Returns:
Analysis result.
"""
return vision_post_image(image_path=path, prompt=prompt)

View File

@ -5,6 +5,15 @@ import urllib.request
def http_fetch(url, headers=None):
"""Fetch content from an HTTP URL.
Args:
url: The URL to fetch.
headers: Optional HTTP headers.
Returns:
Dict with status and content.
"""
try:
req = urllib.request.Request(url)
if headers:
@ -30,10 +39,26 @@ def _perform_search(base_url, query, params=None):
def web_search(query):
"""Perform a web search.
Args:
query: Search query.
Returns:
Dict with status and search results.
"""
base_url = "https://search.molodetz.nl/search"
return _perform_search(base_url, query)
def web_search_news(query):
"""Perform a web search for news.
Args:
query: Search query for news.
Returns:
Dict with status and news search results.
"""
base_url = "https://search.molodetz.nl/search"
return _perform_search(base_url, query)
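Note that web_search and web_search_news currently call _perform_search with the same base URL and no extra params, so the two behave identically. An illustrative call (the payload shape beyond "status" depends on the search backend):
result = web_search("sqlite fts5 tutorial")
if result["status"] == "success":
    print(result)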

50
pr/vision.py Executable file
View File

@ -0,0 +1,50 @@
#!/usr/bin/env python
import http.client
import argparse
import base64
import json
import pathlib
DEFAULT_URL = "https://static.molodetz.nl/rp.vision.cgi"
def post_image(image_path: str, prompt: str = "", url: str = DEFAULT_URL):
image_path = str(pathlib.Path(image_path).resolve().absolute())
if not url:
url = DEFAULT_URL
url_parts = url.split("/")
host = url_parts[2]
path = "/" + "/".join(url_parts[3:])
with open(image_path, "rb") as file:
image_data = file.read()
base64_data = base64.b64encode(image_data).decode("utf-8")
payload = {"data": base64_data, "path": image_path, "prompt": prompt}
body = json.dumps(payload).encode("utf-8")
headers = {
"Content-Type": "application/json",
"Content-Length": str(len(body)),
"User-Agent": "Python http.client",
}
conn = http.client.HTTPSConnection(host)
conn.request("POST", path, body, headers)
resp = conn.getresponse()
data = resp.read()
print("Status:", resp.status, resp.reason)
print(data.decode())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("image_path")
parser.add_argument("--prompt", default="")
parser.add_argument("--url", default=DEFAULT_URL)
args = parser.parse_args()
post_image(args.image_path, args.prompt, args.url)
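An illustrative invocation of the CLI above; the image path and prompt are hypothetical:
python pr/vision.py screenshot.png --prompt "What does this dialog say?"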

View File

@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"
[project]
name = "rp"
version = "1.2.0"
version = "1.3.0"
description = "R python edition. The ultimate autonomous AI CLI."
readme = "README.md"
requires-python = ">=3.13.3"
requires-python = ">=3.12"
license = {text = "MIT"}
keywords = ["ai", "assistant", "cli", "automation", "openrouter", "autonomous"]
authors = [
@ -16,6 +16,7 @@ authors = [
dependencies = [
"aiohttp>=3.13.2",
"pydantic>=2.12.3",
"prompt_toolkit>=3.0.0",
]
classifiers = [
"Development Status :: 4 - Beta",

6
rp.py
View File

@ -2,6 +2,12 @@
# Trigger build
import sys
import os
# Add current directory to path to ensure imports work
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from pr.__main__ import main
if __name__ == "__main__":

View File

@ -1,10 +1,8 @@
from unittest.mock import mock_open, patch
from unittest.mock import patch
from pr.core.config_loader import (
_load_config_file,
_parse_value,
create_default_config,
load_config,
)
@ -36,36 +34,3 @@ def test_parse_value_bool_upper():
def test_load_config_file_not_exists(mock_exists):
config = _load_config_file("test.ini")
assert config == {}
@patch("os.path.exists", return_value=True)
@patch("configparser.ConfigParser")
def test_load_config_file_exists(mock_parser_class, mock_exists):
mock_parser = mock_parser_class.return_value
mock_parser.sections.return_value = ["api"]
mock_parser.items.return_value = [("key", "value")]
config = _load_config_file("test.ini")
assert "api" in config
assert config["api"]["key"] == "value"
@patch("pr.core.config_loader._load_config_file")
def test_load_config(mock_load):
mock_load.side_effect = [{"api": {"key": "global"}}, {"api": {"key": "local"}}]
config = load_config()
assert config["api"]["key"] == "local"
@patch("builtins.open", new_callable=mock_open)
def test_create_default_config(mock_file):
result = create_default_config("test.ini")
assert result == True
mock_file.assert_called_once_with("test.ini", "w")
handle = mock_file()
handle.write.assert_called_once()
@patch("builtins.open", side_effect=Exception("error"))
def test_create_default_config_error(mock_file):
result = create_default_config("test.ini")
assert result == False

View File

@ -0,0 +1,372 @@
import sqlite3
import tempfile
import os
import time
from pr.memory.conversation_memory import ConversationMemory
class TestConversationMemory:
def setup_method(self):
"""Set up test database for each test."""
self.db_fd, self.db_path = tempfile.mkstemp()
self.memory = ConversationMemory(self.db_path)
def teardown_method(self):
"""Clean up test database after each test."""
self.memory = None
os.close(self.db_fd)
os.unlink(self.db_path)
def test_init(self):
"""Test ConversationMemory initialization."""
assert self.memory.db_path == self.db_path
# Verify tables were created
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
tables = [row[0] for row in cursor.fetchall()]
assert "conversation_history" in tables
assert "conversation_messages" in tables
conn.close()
def test_create_conversation(self):
"""Test creating a new conversation."""
conversation_id = "test_conv_123"
session_id = "test_session_456"
self.memory.create_conversation(conversation_id, session_id)
# Verify conversation was created
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
(conversation_id,),
)
row = cursor.fetchone()
assert row[0] == conversation_id
assert row[1] == session_id
conn.close()
def test_create_conversation_without_session(self):
"""Test creating a conversation without session ID."""
conversation_id = "test_conv_no_session"
self.memory.create_conversation(conversation_id)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
(conversation_id,),
)
row = cursor.fetchone()
assert row[0] == conversation_id
assert row[1] is None
conn.close()
def test_create_conversation_with_metadata(self):
"""Test creating a conversation with metadata."""
conversation_id = "test_conv_metadata"
metadata = {"topic": "test", "priority": "high"}
self.memory.create_conversation(conversation_id, metadata=metadata)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT conversation_id, metadata FROM conversation_history WHERE conversation_id = ?",
(conversation_id,),
)
row = cursor.fetchone()
assert row[0] == conversation_id
assert row[1] is not None
conn.close()
def test_add_message(self):
"""Test adding a message to a conversation."""
conversation_id = "test_conv_msg"
message_id = "test_msg_123"
role = "user"
content = "Hello, world!"
# Create conversation first
self.memory.create_conversation(conversation_id)
# Add message
self.memory.add_message(conversation_id, message_id, role, content)
# Verify message was added
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT message_id, conversation_id, role, content FROM conversation_messages WHERE message_id = ?",
(message_id,),
)
row = cursor.fetchone()
assert row[0] == message_id
assert row[1] == conversation_id
assert row[2] == role
assert row[3] == content
# Verify message count was updated
cursor.execute(
"SELECT message_count FROM conversation_history WHERE conversation_id = ?",
(conversation_id,),
)
count_row = cursor.fetchone()
assert count_row[0] == 1
conn.close()
def test_add_message_with_tool_calls(self):
"""Test adding a message with tool calls."""
conversation_id = "test_conv_tools"
message_id = "test_msg_tools"
role = "assistant"
content = "I'll help you with that."
tool_calls = [{"function": "test_func", "args": {"param": "value"}}]
self.memory.create_conversation(conversation_id)
self.memory.add_message(conversation_id, message_id, role, content, tool_calls=tool_calls)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT tool_calls FROM conversation_messages WHERE message_id = ?", (message_id,)
)
row = cursor.fetchone()
assert row[0] is not None
conn.close()
def test_add_message_with_metadata(self):
"""Test adding a message with metadata."""
conversation_id = "test_conv_meta"
message_id = "test_msg_meta"
role = "user"
content = "Test message"
metadata = {"tokens": 5, "model": "gpt-4"}
self.memory.create_conversation(conversation_id)
self.memory.add_message(conversation_id, message_id, role, content, metadata=metadata)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT metadata FROM conversation_messages WHERE message_id = ?", (message_id,)
)
row = cursor.fetchone()
assert row[0] is not None
conn.close()
def test_get_conversation_messages(self):
"""Test retrieving conversation messages."""
conversation_id = "test_conv_get"
self.memory.create_conversation(conversation_id)
# Add multiple messages
messages = [
("msg1", "user", "Hello"),
("msg2", "assistant", "Hi there"),
("msg3", "user", "How are you?"),
]
for msg_id, role, content in messages:
self.memory.add_message(conversation_id, msg_id, role, content)
# Retrieve all messages
retrieved = self.memory.get_conversation_messages(conversation_id)
assert len(retrieved) == 3
assert retrieved[0]["message_id"] == "msg1"
assert retrieved[1]["message_id"] == "msg2"
assert retrieved[2]["message_id"] == "msg3"
def test_get_conversation_messages_limited(self):
"""Test retrieving limited number of conversation messages."""
conversation_id = "test_conv_limit"
self.memory.create_conversation(conversation_id)
# Add multiple messages
for i in range(5):
self.memory.add_message(conversation_id, f"msg{i}", "user", f"Message {i}")
# Retrieve limited messages
retrieved = self.memory.get_conversation_messages(conversation_id, limit=3)
assert len(retrieved) == 3
# Should return most recent messages first due to DESC order
assert retrieved[0]["message_id"] == "msg4"
assert retrieved[1]["message_id"] == "msg3"
assert retrieved[2]["message_id"] == "msg2"
def test_update_conversation_summary(self):
"""Test updating conversation summary."""
conversation_id = "test_conv_summary"
self.memory.create_conversation(conversation_id)
summary = "This is a test conversation summary"
topics = ["testing", "memory", "conversation"]
self.memory.update_conversation_summary(conversation_id, summary, topics)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT summary, topics, ended_at FROM conversation_history WHERE conversation_id = ?",
(conversation_id,),
)
row = cursor.fetchone()
assert row[0] == summary
assert row[1] is not None # topics should be stored
assert row[2] is not None # ended_at should be set
conn.close()
def test_search_conversations(self):
"""Test searching conversations by content."""
# Create conversations with different content
conv1 = "conv_search_1"
conv2 = "conv_search_2"
conv3 = "conv_search_3"
self.memory.create_conversation(conv1)
self.memory.create_conversation(conv2)
self.memory.create_conversation(conv3)
# Add messages with searchable content
self.memory.add_message(conv1, "msg1", "user", "Python programming tutorial")
self.memory.add_message(conv2, "msg2", "user", "JavaScript development guide")
self.memory.add_message(conv3, "msg3", "user", "Database design principles")
# Search for "programming"
results = self.memory.search_conversations("programming")
assert len(results) == 1
assert results[0]["conversation_id"] == conv1
# Search for "development"
results = self.memory.search_conversations("development")
assert len(results) == 1
assert results[0]["conversation_id"] == conv2
def test_get_recent_conversations(self):
"""Test getting recent conversations."""
# Create conversations at different times
conv1 = "conv_recent_1"
conv2 = "conv_recent_2"
conv3 = "conv_recent_3"
self.memory.create_conversation(conv1)
time.sleep(0.01) # Small delay to ensure different timestamps
self.memory.create_conversation(conv2)
time.sleep(0.01)
self.memory.create_conversation(conv3)
# Get recent conversations
recent = self.memory.get_recent_conversations(limit=2)
assert len(recent) == 2
# Should be ordered by started_at DESC
assert recent[0]["conversation_id"] == conv3
assert recent[1]["conversation_id"] == conv2
def test_get_recent_conversations_by_session(self):
"""Test getting recent conversations for a specific session."""
session1 = "session_1"
session2 = "session_2"
conv1 = "conv_session_1"
conv2 = "conv_session_2"
conv3 = "conv_session_3"
self.memory.create_conversation(conv1, session1)
self.memory.create_conversation(conv2, session2)
self.memory.create_conversation(conv3, session1)
# Get conversations for session1
session_convs = self.memory.get_recent_conversations(session_id=session1)
assert len(session_convs) == 2
conversation_ids = [c["conversation_id"] for c in session_convs]
assert conv1 in conversation_ids
assert conv3 in conversation_ids
assert conv2 not in conversation_ids
def test_delete_conversation(self):
"""Test deleting a conversation."""
conversation_id = "conv_delete"
self.memory.create_conversation(conversation_id)
# Add some messages
self.memory.add_message(conversation_id, "msg1", "user", "Test message")
self.memory.add_message(conversation_id, "msg2", "assistant", "Response")
# Delete conversation
result = self.memory.delete_conversation(conversation_id)
assert result is True
# Verify conversation and messages are gone
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT COUNT(*) FROM conversation_history WHERE conversation_id = ?",
(conversation_id,),
)
assert cursor.fetchone()[0] == 0
cursor.execute(
"SELECT COUNT(*) FROM conversation_messages WHERE conversation_id = ?",
(conversation_id,),
)
assert cursor.fetchone()[0] == 0
conn.close()
def test_delete_nonexistent_conversation(self):
"""Test deleting a non-existent conversation."""
result = self.memory.delete_conversation("nonexistent")
assert result is False
def test_get_statistics(self):
"""Test getting memory statistics."""
# Create some conversations and messages
for i in range(3):
conv_id = f"conv_stats_{i}"
self.memory.create_conversation(conv_id)
for j in range(2):
self.memory.add_message(conv_id, f"msg_{i}_{j}", "user", f"Message {j}")
stats = self.memory.get_statistics()
assert stats["total_conversations"] == 3
assert stats["total_messages"] == 6
assert stats["average_messages_per_conversation"] == 2.0
def test_thread_safety(self):
"""Test that the memory can handle concurrent access."""
import threading
import queue
results = queue.Queue()
def worker(worker_id):
try:
conv_id = f"conv_thread_{worker_id}"
self.memory.create_conversation(conv_id)
self.memory.add_message(conv_id, f"msg_{worker_id}", "user", f"Worker {worker_id}")
results.put(True)
except Exception as e:
results.put(e)
# Start multiple threads
threads = []
for i in range(5):
t = threading.Thread(target=worker, args=(i,))
threads.append(t)
t.start()
# Wait for all threads
for t in threads:
t.join()
# Check results
for _ in range(5):
result = results.get()
assert result is True
# Verify all conversations were created
recent = self.memory.get_recent_conversations(limit=10)
assert len(recent) >= 5
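Taken together, the tests above pin down the ConversationMemory API. A condensed usage sketch (identifiers and the database path are illustrative):
memory = ConversationMemory("conversations.db")
memory.create_conversation("conv-1", "session-1", metadata={"topic": "demo"})
memory.add_message("conv-1", "msg-1", "user", "Hello")
memory.add_message("conv-1", "msg-2", "assistant", "Hi there")
memory.update_conversation_summary("conv-1", "Greeting exchange", ["greeting"])
print(memory.search_conversations("Hello"))
print(memory.get_recent_conversations(limit=5))
print(memory.get_statistics())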

View File

@ -0,0 +1,242 @@
from pr.memory.fact_extractor import FactExtractor
class TestFactExtractor:
def setup_method(self):
"""Set up test fixture."""
self.extractor = FactExtractor()
def test_init(self):
"""Test FactExtractor initialization."""
assert self.extractor.fact_patterns is not None
assert len(self.extractor.fact_patterns) > 0
def test_extract_facts_definition(self):
"""Test extracting definition facts."""
text = "John Smith is a software engineer. Python is a programming language."
facts = self.extractor.extract_facts(text)
assert len(facts) >= 2
# Check for definition pattern matches
definition_facts = [f for f in facts if f["type"] == "definition"]
assert len(definition_facts) >= 1
def test_extract_facts_temporal(self):
"""Test extracting temporal facts."""
text = "John was born in 1990. The company was founded in 2010."
facts = self.extractor.extract_facts(text)
temporal_facts = [f for f in facts if f["type"] == "temporal"]
assert len(temporal_facts) >= 1
def test_extract_facts_attribution(self):
"""Test extracting attribution facts."""
text = "John invented the widget. Mary developed the software."
facts = self.extractor.extract_facts(text)
attribution_facts = [f for f in facts if f["type"] == "attribution"]
assert len(attribution_facts) >= 1
def test_extract_facts_numeric(self):
"""Test extracting numeric facts."""
text = "The car costs $25,000. The house is worth $500,000."
facts = self.extractor.extract_facts(text)
numeric_facts = [f for f in facts if f["type"] == "numeric"]
assert len(numeric_facts) >= 1
def test_extract_facts_location(self):
"""Test extracting location facts."""
text = "John lives in San Francisco. The office is located in New York."
facts = self.extractor.extract_facts(text)
location_facts = [f for f in facts if f["type"] == "location"]
assert len(location_facts) >= 1
def test_extract_facts_entity(self):
"""Test extracting entity facts from noun phrases."""
text = "John Smith works at Google Inc. He uses Python programming."
facts = self.extractor.extract_facts(text)
entity_facts = [f for f in facts if f["type"] == "entity"]
assert len(entity_facts) >= 1
def test_extract_noun_phrases(self):
"""Test noun phrase extraction."""
text = "John Smith is a software engineer at Google. He works on Python Projects."
phrases = self.extractor._extract_noun_phrases(text)
assert "John Smith" in phrases
assert "Google" in phrases
assert "Python Projects" in phrases
def test_extract_noun_phrases_capitalized(self):
"""Test that only capitalized noun phrases are extracted."""
text = "the quick brown fox jumps over the lazy dog"
phrases = self.extractor._extract_noun_phrases(text)
# Should be empty since no capitalized words
assert len(phrases) == 0
def test_extract_key_terms(self):
"""Test key term extraction."""
text = "Python is a programming language used for software development and data analysis."
terms = self.extractor.extract_key_terms(text, top_k=5)
assert len(terms) <= 5
# Should contain programming, language, software, development, data, analysis
term_words = [term[0] for term in terms]
assert "programming" in term_words
assert "language" in term_words
assert "software" in term_words
def test_extract_key_terms_stopwords_filtered(self):
"""Test that stopwords are filtered from key terms."""
text = "This is a test of the system that should work properly."
terms = self.extractor.extract_key_terms(text)
term_words = [term[0] for term in terms]
# Stopwords should not appear
assert "this" not in term_words
assert "is" not in term_words
assert "a" not in term_words
assert "of" not in term_words
assert "the" not in term_words
def test_extract_relationships_employment(self):
"""Test extracting employment relationships."""
text = "John works for Google. Mary is employed by Microsoft."
relationships = self.extractor.extract_relationships(text)
employment_rels = [r for r in relationships if r["type"] == "employment"]
assert len(employment_rels) >= 1
def test_extract_relationships_ownership(self):
"""Test extracting ownership relationships."""
text = "John owns a car. Mary has a house."
relationships = self.extractor.extract_relationships(text)
ownership_rels = [r for r in relationships if r["type"] == "ownership"]
assert len(ownership_rels) >= 1
def test_extract_relationships_location(self):
"""Test extracting location relationships."""
text = "John located in New York. The factory belongs to Google."
relationships = self.extractor.extract_relationships(text)
location_rels = [r for r in relationships if r["type"] == "location"]
assert len(location_rels) >= 1
def test_extract_relationships_usage(self):
"""Test extracting usage relationships."""
text = "John uses Python. The company implements agile methodology."
relationships = self.extractor.extract_relationships(text)
usage_rels = [r for r in relationships if r["type"] == "usage"]
assert len(usage_rels) >= 1
def test_extract_metadata(self):
"""Test metadata extraction."""
text = "This is a test document. It contains some information about Python programming. You can visit https://python.org for more details. Contact john@example.com for questions. The project started in 2020 and costs $10,000."
metadata = self.extractor.extract_metadata(text)
assert metadata["word_count"] > 0
assert metadata["sentence_count"] > 0
assert metadata["avg_words_per_sentence"] > 0
assert len(metadata["urls"]) > 0
assert len(metadata["email_addresses"]) > 0
assert len(metadata["dates"]) > 0
assert len(metadata["numeric_values"]) > 0
assert metadata["has_code"] is False # No code in this text
def test_extract_metadata_with_code(self):
"""Test metadata extraction with code content."""
text = "Here is a function: def hello(): print('Hello, world!')"
metadata = self.extractor.extract_metadata(text)
assert metadata["has_code"] is True
def test_extract_metadata_with_questions(self):
"""Test metadata extraction with questions."""
text = "What is Python? How does it work? Why use it?"
metadata = self.extractor.extract_metadata(text)
assert metadata["has_questions"] is True
def test_categorize_content_programming(self):
"""Test content categorization for programming."""
text = "Python is a programming language used for code development and debugging."
categories = self.extractor.categorize_content(text)
assert "programming" in categories
def test_categorize_content_data(self):
"""Test content categorization for data."""
text = "The database contains records and tables with statistical analysis."
categories = self.extractor.categorize_content(text)
assert "data" in categories
def test_categorize_content_documentation(self):
"""Test content categorization for documentation."""
text = "This guide explains how to use the tutorial and manual."
categories = self.extractor.categorize_content(text)
assert "documentation" in categories
def test_categorize_content_configuration(self):
"""Test content categorization for configuration."""
text = "Configure the settings and setup the deployment environment."
categories = self.extractor.categorize_content(text)
assert "configuration" in categories
def test_categorize_content_testing(self):
"""Test content categorization for testing."""
text = "Run the tests to validate the functionality and verify quality."
categories = self.extractor.categorize_content(text)
assert "testing" in categories
def test_categorize_content_research(self):
"""Test content categorization for research."""
text = "The study investigates findings and results from the analysis."
categories = self.extractor.categorize_content(text)
assert "research" in categories
def test_categorize_content_planning(self):
"""Test content categorization for planning."""
text = "Plan the project schedule with milestones and timeline."
categories = self.extractor.categorize_content(text)
assert "planning" in categories
def test_categorize_content_general(self):
"""Test content categorization defaults to general."""
text = "This is some random text without specific keywords."
categories = self.extractor.categorize_content(text)
assert "general" in categories
def test_extract_facts_empty_text(self):
"""Test fact extraction with empty text."""
facts = self.extractor.extract_facts("")
assert len(facts) == 0
def test_extract_key_terms_empty_text(self):
"""Test key term extraction with empty text."""
terms = self.extractor.extract_key_terms("")
assert len(terms) == 0
def test_extract_relationships_empty_text(self):
"""Test relationship extraction with empty text."""
relationships = self.extractor.extract_relationships("")
assert len(relationships) == 0
def test_extract_metadata_empty_text(self):
"""Test metadata extraction with empty text."""
metadata = self.extractor.extract_metadata("")
assert metadata["word_count"] == 0
assert metadata["sentence_count"] == 0
assert metadata["avg_words_per_sentence"] == 0.0

View File

@ -0,0 +1,344 @@
import sqlite3
import tempfile
import os
import time
from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry
class TestKnowledgeStore:
def setup_method(self):
"""Set up test database for each test."""
self.db_fd, self.db_path = tempfile.mkstemp()
self.store = KnowledgeStore(self.db_path)
def teardown_method(self):
"""Clean up test database after each test."""
self.store = None
os.close(self.db_fd)
os.unlink(self.db_path)
def test_init(self):
"""Test KnowledgeStore initialization."""
assert self.store.db_path == self.db_path
# Verify tables were created
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
tables = [row[0] for row in cursor.fetchall()]
assert "knowledge_entries" in tables
conn.close()
def test_add_entry(self):
"""Test adding a knowledge entry."""
entry = KnowledgeEntry(
entry_id="test_1",
category="test",
content="This is a test entry",
metadata={"source": "test"},
created_at=time.time(),
updated_at=time.time(),
)
self.store.add_entry(entry)
# Verify entry was added
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute(
"SELECT entry_id, category, content FROM knowledge_entries WHERE entry_id = ?",
("test_1",),
)
row = cursor.fetchone()
assert row[0] == "test_1"
assert row[1] == "test"
assert row[2] == "This is a test entry"
conn.close()
def test_get_entry(self):
"""Test retrieving a knowledge entry."""
entry = KnowledgeEntry(
entry_id="test_get",
category="test",
content="Content to retrieve",
metadata={},
created_at=time.time(),
updated_at=time.time(),
)
self.store.add_entry(entry)
retrieved = self.store.get_entry("test_get")
assert retrieved is not None
assert retrieved.entry_id == "test_get"
assert retrieved.content == "Content to retrieve"
assert retrieved.access_count == 1 # Should be incremented
def test_get_entry_not_found(self):
"""Test retrieving a non-existent entry."""
retrieved = self.store.get_entry("nonexistent")
assert retrieved is None
def test_search_entries_semantic(self):
"""Test semantic search."""
entries = [
KnowledgeEntry(
"entry1", "personal", "John is a software engineer", {}, time.time(), time.time()
),
KnowledgeEntry(
"entry2", "personal", "Mary works as a designer", {}, time.time(), time.time()
),
KnowledgeEntry(
"entry3", "tech", "Python is a programming language", {}, time.time(), time.time()
),
]
for entry in entries:
self.store.add_entry(entry)
results = self.store.search_entries("software engineer", top_k=2)
assert len(results) >= 1
# Should find the most relevant entry
found_ids = [r.entry_id for r in results]
assert "entry1" in found_ids
def test_search_entries_fts_exact(self):
"""Test full-text search with exact matches."""
entries = [
KnowledgeEntry(
"exact1", "test", "Python programming language", {}, time.time(), time.time()
),
KnowledgeEntry(
"exact2", "test", "Java programming language", {}, time.time(), time.time()
),
KnowledgeEntry(
"exact3", "test", "Database design principles", {}, time.time(), time.time()
),
]
for entry in entries:
self.store.add_entry(entry)
results = self.store.search_entries("programming language", top_k=3)
assert len(results) >= 2
found_ids = [r.entry_id for r in results]
assert "exact1" in found_ids
assert "exact2" in found_ids
def test_search_entries_by_category(self):
"""Test searching entries by category."""
entries = [
KnowledgeEntry("cat1", "personal", "John's info", {}, time.time(), time.time()),
KnowledgeEntry("cat2", "tech", "Python info", {}, time.time(), time.time()),
KnowledgeEntry("cat3", "personal", "Jane's info", {}, time.time(), time.time()),
]
for entry in entries:
self.store.add_entry(entry)
results = self.store.search_entries("info", category="personal", top_k=5)
assert len(results) == 2
found_ids = [r.entry_id for r in results]
assert "cat1" in found_ids
assert "cat3" in found_ids
assert "cat2" not in found_ids
def test_get_by_category(self):
"""Test getting entries by category."""
entries = [
KnowledgeEntry("get1", "personal", "Entry 1", {}, time.time(), time.time()),
KnowledgeEntry("get2", "tech", "Entry 2", {}, time.time(), time.time()),
KnowledgeEntry("get3", "personal", "Entry 3", {}, time.time() + 1, time.time() + 1),
]
for entry in entries:
self.store.add_entry(entry)
personal_entries = self.store.get_by_category("personal")
assert len(personal_entries) == 2
# Should be ordered by importance_score DESC, created_at DESC
assert personal_entries[0].entry_id == "get3" # More recent
assert personal_entries[1].entry_id == "get1"
def test_update_importance(self):
"""Test updating entry importance."""
entry = KnowledgeEntry(
"importance_test", "test", "Test content", {}, time.time(), time.time()
)
self.store.add_entry(entry)
self.store.update_importance("importance_test", 0.8)
retrieved = self.store.get_entry("importance_test")
assert retrieved.importance_score == 0.8
def test_delete_entry(self):
"""Test deleting an entry."""
entry = KnowledgeEntry("delete_test", "test", "To be deleted", {}, time.time(), time.time())
self.store.add_entry(entry)
result = self.store.delete_entry("delete_test")
assert result is True
# Verify it's gone
retrieved = self.store.get_entry("delete_test")
assert retrieved is None
def test_delete_entry_not_found(self):
"""Test deleting a non-existent entry."""
result = self.store.delete_entry("nonexistent")
assert result is False
def test_get_statistics(self):
"""Test getting store statistics."""
entries = [
KnowledgeEntry("stat1", "personal", "Personal info", {}, time.time(), time.time()),
KnowledgeEntry("stat2", "tech", "Tech info", {}, time.time(), time.time()),
KnowledgeEntry("stat3", "personal", "More personal info", {}, time.time(), time.time()),
]
for entry in entries:
self.store.add_entry(entry)
stats = self.store.get_statistics()
assert stats["total_entries"] == 3
assert stats["total_categories"] == 2
assert stats["category_distribution"]["personal"] == 2
assert stats["category_distribution"]["tech"] == 1
assert stats["vocabulary_size"] > 0
def test_fts_search_exact_phrase(self):
"""Test FTS exact phrase matching."""
entries = [
KnowledgeEntry(
"fts1", "test", "Python is great for programming", {}, time.time(), time.time()
),
KnowledgeEntry(
"fts2", "test", "Java is also good for programming", {}, time.time(), time.time()
),
KnowledgeEntry(
"fts3", "test", "Database management is important", {}, time.time(), time.time()
),
]
for entry in entries:
self.store.add_entry(entry)
fts_results = self.store._fts_search("programming")
assert len(fts_results) == 2
entry_ids = [entry_id for entry_id, score in fts_results]
assert "fts1" in entry_ids
assert "fts2" in entry_ids
def test_fts_search_partial_match(self):
"""Test FTS partial word matching."""
entries = [
KnowledgeEntry("partial1", "test", "Python programming", {}, time.time(), time.time()),
KnowledgeEntry("partial2", "test", "Java programming", {}, time.time(), time.time()),
KnowledgeEntry("partial3", "test", "Database design", {}, time.time(), time.time()),
]
for entry in entries:
self.store.add_entry(entry)
fts_results = self.store._fts_search("program")
assert len(fts_results) >= 2
def test_combined_search_scoring(self):
"""Test that combined semantic + FTS search produces proper scoring."""
entries = [
KnowledgeEntry(
"combined1", "test", "Python programming language", {}, time.time(), time.time()
),
KnowledgeEntry(
"combined2", "test", "Java programming language", {}, time.time(), time.time()
),
KnowledgeEntry(
"combined3", "test", "Database management system", {}, time.time(), time.time()
),
]
for entry in entries:
self.store.add_entry(entry)
results = self.store.search_entries("programming language")
assert len(results) >= 2
# Check that results have search scores (at least one should have a positive score)
has_positive_score = False
for result in results:
assert "search_score" in result.metadata
if result.metadata["search_score"] > 0:
has_positive_score = True
assert has_positive_score
def test_empty_search(self):
"""Test searching with no matches."""
results = self.store.search_entries("nonexistent topic")
assert len(results) == 0
def test_search_empty_query(self):
"""Test searching with empty query."""
results = self.store.search_entries("")
assert len(results) == 0
def test_thread_safety(self):
"""Test that the store can handle concurrent access."""
import threading
import queue
results = queue.Queue()
def worker(worker_id):
try:
entry = KnowledgeEntry(
f"thread_{worker_id}",
"test",
f"Content from worker {worker_id}",
{},
time.time(),
time.time(),
)
self.store.add_entry(entry)
retrieved = self.store.get_entry(f"thread_{worker_id}")
assert retrieved is not None
results.put(True)
except Exception as e:
results.put(e)
# Start multiple threads
threads = []
for i in range(5):
t = threading.Thread(target=worker, args=(i,))
threads.append(t)
t.start()
# Wait for all threads
for t in threads:
t.join()
# Check results
for _ in range(5):
result = results.get()
assert result is True
def test_entry_to_dict(self):
"""Test KnowledgeEntry to_dict method."""
entry = KnowledgeEntry(
entry_id="dict_test",
category="test",
content="Test content",
metadata={"key": "value"},
created_at=1234567890.0,
updated_at=1234567891.0,
access_count=5,
importance_score=0.8,
)
entry_dict = entry.to_dict()
assert entry_dict["entry_id"] == "dict_test"
assert entry_dict["category"] == "test"
assert entry_dict["content"] == "Test content"
assert entry_dict["metadata"]["key"] == "value"
assert entry_dict["access_count"] == 5
assert entry_dict["importance_score"] == 0.8

View File

@ -0,0 +1,280 @@
import math
import pytest
from pr.memory.semantic_index import SemanticIndex
class TestSemanticIndex:
def test_init(self):
"""Test SemanticIndex initialization."""
index = SemanticIndex()
assert index.documents == {}
assert index.vocabulary == set()
assert index.idf_scores == {}
assert index.doc_tf_scores == {}
def test_tokenize_basic(self):
"""Test basic tokenization functionality."""
index = SemanticIndex()
tokens = index._tokenize("Hello, world! This is a test.")
expected = ["hello", "world", "this", "is", "a", "test"]
assert tokens == expected
def test_tokenize_special_characters(self):
"""Test tokenization with special characters."""
index = SemanticIndex()
tokens = index._tokenize("Hello@world.com test-case_123")
expected = ["hello", "world", "com", "test", "case", "123"]
assert tokens == expected
def test_tokenize_empty_string(self):
"""Test tokenization of empty string."""
index = SemanticIndex()
tokens = index._tokenize("")
assert tokens == []
def test_tokenize_only_special_chars(self):
"""Test tokenization with only special characters."""
index = SemanticIndex()
tokens = index._tokenize("!@#$%^&*()")
assert tokens == []
def test_compute_tf_basic(self):
"""Test TF computation for basic case."""
index = SemanticIndex()
tokens = ["hello", "world", "hello", "test"]
tf_scores = index._compute_tf(tokens)
expected = {"hello": 2 / 4, "world": 1 / 4, "test": 1 / 4} # 0.5 # 0.25 # 0.25
assert tf_scores == expected
def test_compute_tf_empty(self):
"""Test TF computation for empty tokens."""
index = SemanticIndex()
tf_scores = index._compute_tf([])
assert tf_scores == {}
def test_compute_tf_single_token(self):
"""Test TF computation for single token."""
index = SemanticIndex()
tokens = ["hello"]
tf_scores = index._compute_tf(tokens)
assert tf_scores == {"hello": 1.0}
def test_compute_idf_single_document(self):
"""Test IDF computation with single document."""
index = SemanticIndex()
index.documents = {"doc1": "hello world"}
index._compute_idf()
assert index.idf_scores == {"hello": 1.0, "world": 1.0}
def test_compute_idf_multiple_documents(self):
"""Test IDF computation with multiple documents."""
index = SemanticIndex()
index.documents = {"doc1": "hello world", "doc2": "hello test", "doc3": "world test"}
index._compute_idf()
expected = {
"hello": math.log(3 / 2), # appears in 2/3 docs
"world": math.log(3 / 2), # appears in 2/3 docs
"test": math.log(3 / 2), # appears in 2/3 docs
}
assert index.idf_scores == expected
def test_compute_idf_empty_documents(self):
"""Test IDF computation with no documents."""
index = SemanticIndex()
index._compute_idf()
assert index.idf_scores == {}
def test_add_document_basic(self):
"""Test adding a basic document."""
index = SemanticIndex()
index.add_document("doc1", "hello world")
assert "doc1" in index.documents
assert index.documents["doc1"] == "hello world"
assert "hello" in index.vocabulary
assert "world" in index.vocabulary
assert "doc1" in index.doc_tf_scores
def test_add_document_updates_vocabulary(self):
"""Test that adding documents updates vocabulary."""
index = SemanticIndex()
index.add_document("doc1", "hello world")
assert index.vocabulary == {"hello", "world"}
index.add_document("doc2", "hello test")
assert index.vocabulary == {"hello", "world", "test"}
def test_add_document_updates_idf(self):
"""Test that adding documents updates IDF scores."""
index = SemanticIndex()
index.add_document("doc1", "hello world")
assert index.idf_scores == {"hello": 1.0, "world": 1.0}
index.add_document("doc2", "hello test")
expected_idf = {
"hello": math.log(2 / 2), # appears in both docs
"world": math.log(2 / 1), # appears in 1/2 docs
"test": math.log(2 / 1), # appears in 1/2 docs
}
assert index.idf_scores == expected_idf
def test_add_document_tf_computation(self):
"""Test TF score computation when adding document."""
index = SemanticIndex()
index.add_document("doc1", "hello world hello")
# TF: hello=2/3, world=1/3
expected_tf = {"hello": 2 / 3, "world": 1 / 3}
assert index.doc_tf_scores["doc1"] == expected_tf
def test_remove_document_existing(self):
"""Test removing an existing document."""
index = SemanticIndex()
index.add_document("doc1", "hello world")
index.add_document("doc2", "hello test")
initial_vocab = index.vocabulary.copy()
initial_idf = index.idf_scores.copy()
index.remove_document("doc1")
assert "doc1" not in index.documents
assert "doc1" not in index.doc_tf_scores
# Vocabulary should still contain all words
assert index.vocabulary == initial_vocab
# IDF should be recomputed
assert index.idf_scores != initial_idf
assert index.idf_scores == {"hello": 1.0, "test": 1.0}
def test_remove_document_nonexistent(self):
"""Test removing a non-existent document."""
index = SemanticIndex()
index.add_document("doc1", "hello world")
initial_state = {
"documents": index.documents.copy(),
"vocabulary": index.vocabulary.copy(),
"idf_scores": index.idf_scores.copy(),
"doc_tf_scores": index.doc_tf_scores.copy(),
}
index.remove_document("nonexistent")
assert index.documents == initial_state["documents"]
assert index.vocabulary == initial_state["vocabulary"]
assert index.idf_scores == initial_state["idf_scores"]
assert index.doc_tf_scores == initial_state["doc_tf_scores"]
def test_search_basic(self):
"""Test basic search functionality."""
index = SemanticIndex()
index.add_document("doc1", "hello world")
index.add_document("doc2", "hello test")
index.add_document("doc3", "world test")
results = index.search("hello", top_k=5)
assert len(results) == 3 # All documents are returned with similarity scores
# Results should be sorted by similarity (descending)
scores = {doc_id: score for doc_id, score in results}
assert scores["doc1"] > 0 # doc1 contains "hello"
assert scores["doc2"] > 0 # doc2 contains "hello"
assert scores["doc3"] == 0 # doc3 does not contain "hello"
def test_search_empty_query(self):
"""Test search with empty query."""
index = SemanticIndex()
index.add_document("doc1", "hello world")
results = index.search("", top_k=5)
assert results == []
def test_search_no_documents(self):
"""Test search when no documents exist."""
index = SemanticIndex()
results = index.search("hello", top_k=5)
assert results == []
def test_search_top_k_limit(self):
"""Test search respects top_k parameter."""
index = SemanticIndex()
for i in range(10):
index.add_document(f"doc{i}", f"hello world content")
results = index.search("content", top_k=3)
assert len(results) == 3
def test_cosine_similarity_identical_vectors(self):
"""Test cosine similarity with identical vectors."""
index = SemanticIndex()
vec1 = {"hello": 1.0, "world": 0.5}
vec2 = {"hello": 1.0, "world": 0.5}
similarity = index._cosine_similarity(vec1, vec2)
assert similarity == pytest.approx(1.0)
def test_cosine_similarity_orthogonal_vectors(self):
"""Test cosine similarity with orthogonal vectors."""
index = SemanticIndex()
vec1 = {"hello": 1.0}
vec2 = {"world": 1.0}
similarity = index._cosine_similarity(vec1, vec2)
assert similarity == 0.0
def test_cosine_similarity_zero_vector(self):
"""Test cosine similarity with zero vector."""
index = SemanticIndex()
vec1 = {"hello": 1.0}
vec2 = {}
similarity = index._cosine_similarity(vec1, vec2)
assert similarity == 0.0
def test_cosine_similarity_empty_vectors(self):
"""Test cosine similarity with empty vectors."""
index = SemanticIndex()
similarity = index._cosine_similarity({}, {})
assert similarity == 0.0
def test_search_relevance_ordering(self):
"""Test that search results are ordered by relevance."""
index = SemanticIndex()
index.add_document("doc1", "hello hello hello") # High TF for "hello"
index.add_document("doc2", "hello world") # Medium relevance
index.add_document("doc3", "world test") # No "hello"
results = index.search("hello", top_k=5)
assert len(results) == 3 # All documents are returned
# doc1 should have higher score than doc2, and doc3 should have 0
scores = {doc_id: score for doc_id, score in results}
assert scores["doc1"] > scores["doc2"] > scores["doc3"]
assert scores["doc3"] == 0
def test_vocabulary_persistence(self):
"""Test that vocabulary persists even after document removal."""
index = SemanticIndex()
index.add_document("doc1", "hello world")
index.add_document("doc2", "test case")
assert index.vocabulary == {"hello", "world", "test", "case"}
index.remove_document("doc1")
# Vocabulary should still contain all words
assert index.vocabulary == {"hello", "world", "test", "case"}
def test_idf_recomputation_after_removal(self):
"""Test IDF recomputation after document removal."""
index = SemanticIndex()
index.add_document("doc1", "hello world")
index.add_document("doc2", "hello test")
index.add_document("doc3", "world test")
# Remove doc3, leaving doc1 and doc2
index.remove_document("doc3")
# "hello" appears in both remaining docs, "world" and "test" in one each
expected_idf = {
"hello": math.log(2 / 2), # 2 docs, appears in 2
"world": math.log(2 / 1), # 2 docs, appears in 1
"test": math.log(2 / 1), # 2 docs, appears in 1
}
assert index.idf_scores == expected_idf
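As a worked check of the arithmetic these tests encode, assuming plain log(N / df) IDF for multi-document corpora, which the expectations above imply (the single-document tests pin IDF to 1.0, suggesting a special case there):
import math
# Three documents, each term appearing in exactly two of them:
idf = math.log(3 / 2)  # ~0.4055, matching test_compute_idf_multiple_documents
tf = 1 / 2             # two tokens per document, each appearing once
weight = tf * idf      # ~0.2027, the tf-idf weight of every term
print(round(weight, 4))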

View File

@ -168,7 +168,6 @@ class TestToolDefinitions:
tool_names = [t["function"]["name"] for t in tools]
assert "run_command" in tool_names
assert "start_interactive_session" in tool_names
def test_python_exec_present(self):
tools = get_tools_definition()