feat: add extensive memory implementations
parent 27bcc7409e
commit 31d272daa3
@@ -1,5 +1,13 @@
# Changelog

## Version 1.3.0 - 2025-11-05

This release updates how the software finds configuration files and handles communication with agents. These changes improve reliability and allow for more flexible configuration options.

**Changes:** 32 files, 2964 lines

**Languages:** Other (706 lines), Python (2214 lines), TOML (44 lines)

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
Makefile (2 lines changed)
@@ -72,7 +72,7 @@ serve:
implode: build
-	rpi rp.py -o rp
+	cp rp.py rp
	chmod +x rp
	if [ -d /home/retoor/bin ]; then cp rp /home/retoor/bin/rp; fi

pr/ads.py (new file, 969 lines)
@@ -0,0 +1,969 @@
import re
import json
from uuid import uuid4
from datetime import datetime, timezone
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    AsyncGenerator,
    Union,
    Tuple,
    Set,
)
import aiosqlite
import asyncio
import aiohttp
from aiohttp import web
import socket
import os
import atexit


class AsyncDataSet:
    """
    Distributed AsyncDataSet with client-server model over Unix sockets.

    Parameters:
    - file: Path to the SQLite database file
    - socket_path: Path to Unix socket for inter-process communication (default: "ads.sock")
    - max_concurrent_queries: Maximum concurrent database operations (default: 1)
      Set to 1 to prevent "database is locked" errors with SQLite
    """

    _KV_TABLE = "__kv_store"
    _DEFAULT_COLUMNS = {
        "uid": "TEXT PRIMARY KEY",
        "created_at": "TEXT",
        "updated_at": "TEXT",
        "deleted_at": "TEXT",
    }

    def __init__(self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1):
        self._file = file
        self._socket_path = socket_path
        self._table_columns_cache: Dict[str, Set[str]] = {}
        self._is_server = None  # None means not initialized yet
        self._server = None
        self._runner = None
        self._client_session = None
        self._initialized = False
        self._init_lock = None  # Will be created when needed
        self._max_concurrent_queries = max_concurrent_queries
        self._db_semaphore = None  # Will be created when needed
        self._db_connection = None  # Persistent database connection

    async def _ensure_initialized(self):
        """Ensure the instance is initialized as server or client."""
        if self._initialized:
            return

        # Create lock if needed (first time in async context)
        if self._init_lock is None:
            self._init_lock = asyncio.Lock()

        async with self._init_lock:
            if self._initialized:
                return

            await self._initialize()

            # Create semaphore for database operations (server only)
            if self._is_server:
                self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries)

            self._initialized = True

    async def _initialize(self):
        """Initialize as server or client."""
        try:
            # Try to create Unix socket
            if os.path.exists(self._socket_path):
                # Check if socket is active
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                try:
                    sock.connect(self._socket_path)
                    sock.close()
                    # Socket is active, we're a client
                    self._is_server = False
                    await self._setup_client()
                except (ConnectionRefusedError, FileNotFoundError):
                    # Socket file exists but not active, remove and become server
                    os.unlink(self._socket_path)
                    self._is_server = True
                    await self._setup_server()
            else:
                # No socket file, become server
                self._is_server = True
                await self._setup_server()
        except Exception:
            # Fallback to client mode
            self._is_server = False
            await self._setup_client()

        # Establish persistent database connection
        self._db_connection = await aiosqlite.connect(self._file)

    async def _setup_server(self):
        """Setup aiohttp server."""
        app = web.Application()

        # Add routes for all methods
        app.router.add_post("/insert", self._handle_insert)
        app.router.add_post("/update", self._handle_update)
        app.router.add_post("/delete", self._handle_delete)
        app.router.add_post("/upsert", self._handle_upsert)
        app.router.add_post("/get", self._handle_get)
        app.router.add_post("/find", self._handle_find)
        app.router.add_post("/count", self._handle_count)
        app.router.add_post("/exists", self._handle_exists)
        app.router.add_post("/kv_set", self._handle_kv_set)
        app.router.add_post("/kv_get", self._handle_kv_get)
        app.router.add_post("/execute_raw", self._handle_execute_raw)
        app.router.add_post("/query_raw", self._handle_query_raw)
        app.router.add_post("/query_one", self._handle_query_one)
        app.router.add_post("/create_table", self._handle_create_table)
        app.router.add_post("/insert_unique", self._handle_insert_unique)
        app.router.add_post("/aggregate", self._handle_aggregate)

        self._runner = web.AppRunner(app)
        await self._runner.setup()
        self._server = web.UnixSite(self._runner, self._socket_path)
        await self._server.start()

        # Register cleanup
        def cleanup():
            if os.path.exists(self._socket_path):
                os.unlink(self._socket_path)

        atexit.register(cleanup)

    async def _setup_client(self):
        """Setup aiohttp client."""
        connector = aiohttp.UnixConnector(path=self._socket_path)
        self._client_session = aiohttp.ClientSession(connector=connector)

    async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any:
        """Make HTTP request to server."""
        if not self._client_session:
            await self._setup_client()

        url = f"http://localhost/{endpoint}"
        async with self._client_session.post(url, json=data) as resp:
            result = await resp.json()
            if result.get("error"):
                raise Exception(result["error"])
            return result.get("result")
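
    # Wire-format sketch (illustrative, derived from the handlers below):
    # POST /insert with JSON {"table": "users", "args": {"name": "alice"}}
    # answers {"result": "<uid>"} on success, or {"error": "..."} with
    # HTTP status 500 on failure.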

    # Server handlers
    async def _handle_insert(self, request):
        data = await request.json()
        try:
            result = await self._server_insert(
                data["table"], data["args"], data.get("return_id", False)
            )
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_update(self, request):
        data = await request.json()
        try:
            result = await self._server_update(data["table"], data["args"], data.get("where"))
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_delete(self, request):
        data = await request.json()
        try:
            result = await self._server_delete(data["table"], data.get("where"))
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_upsert(self, request):
        data = await request.json()
        try:
            result = await self._server_upsert(data["table"], data["args"], data.get("where"))
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_get(self, request):
        data = await request.json()
        try:
            result = await self._server_get(data["table"], data.get("where"))
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_find(self, request):
        data = await request.json()
        try:
            result = await self._server_find(
                data["table"],
                data.get("where"),
                limit=data.get("limit", 0),
                offset=data.get("offset", 0),
                order_by=data.get("order_by"),
            )
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_count(self, request):
        data = await request.json()
        try:
            result = await self._server_count(data["table"], data.get("where"))
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_exists(self, request):
        data = await request.json()
        try:
            result = await self._server_exists(data["table"], data["where"])
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_kv_set(self, request):
        data = await request.json()
        try:
            await self._server_kv_set(data["key"], data["value"], table=data.get("table"))
            return web.json_response({"result": None})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_kv_get(self, request):
        data = await request.json()
        try:
            result = await self._server_kv_get(
                data["key"], default=data.get("default"), table=data.get("table")
            )
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_execute_raw(self, request):
        data = await request.json()
        try:
            result = await self._server_execute_raw(
                data["sql"],
                tuple(data.get("params", [])) if data.get("params") else None,
            )
            return web.json_response(
                {"result": result.rowcount if hasattr(result, "rowcount") else None}
            )
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_query_raw(self, request):
        data = await request.json()
        try:
            result = await self._server_query_raw(
                data["sql"],
                tuple(data.get("params", [])) if data.get("params") else None,
            )
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_query_one(self, request):
        data = await request.json()
        try:
            result = await self._server_query_one(
                data["sql"],
                tuple(data.get("params", [])) if data.get("params") else None,
            )
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_create_table(self, request):
        data = await request.json()
        try:
            await self._server_create_table(data["table"], data["schema"], data.get("constraints"))
            return web.json_response({"result": None})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_insert_unique(self, request):
        data = await request.json()
        try:
            result = await self._server_insert_unique(
                data["table"], data["args"], data["unique_fields"]
            )
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_aggregate(self, request):
        data = await request.json()
        try:
            result = await self._server_aggregate(
                data["table"],
                data["function"],
                data.get("column", "*"),
                data.get("where"),
            )
            return web.json_response({"result": result})
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)

    # Public methods that delegate to server or client
    async def insert(
        self, table: str, args: Dict[str, Any], return_id: bool = False
    ) -> Union[str, int]:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_insert(table, args, return_id)
        else:
            return await self._make_request(
                "insert", {"table": table, "args": args, "return_id": return_id}
            )

    async def update(
        self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
    ) -> int:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_update(table, args, where)
        else:
            return await self._make_request(
                "update", {"table": table, "args": args, "where": where}
            )

    async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_delete(table, where)
        else:
            return await self._make_request("delete", {"table": table, "where": where})

    async def get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_get(table, where)
        else:
            return await self._make_request("get", {"table": table, "where": where})

    async def find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
        order_by: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_find(
                table, where, limit=limit, offset=offset, order_by=order_by
            )
        else:
            return await self._make_request(
                "find",
                {
                    "table": table,
                    "where": where,
                    "limit": limit,
                    "offset": offset,
                    "order_by": order_by,
                },
            )

    async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_count(table, where)
        else:
            return await self._make_request("count", {"table": table, "where": where})

    async def exists(self, table: str, where: Dict[str, Any]) -> bool:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_exists(table, where)
        else:
            return await self._make_request("exists", {"table": table, "where": where})

    async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_query_raw(sql, params)
        else:
            return await self._make_request(
                "query_raw", {"sql": sql, "params": list(params) if params else None}
            )

    async def create_table(
        self,
        table: str,
        schema: Dict[str, str],
        constraints: Optional[List[str]] = None,
    ):
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_create_table(table, schema, constraints)
        else:
            return await self._make_request(
                "create_table",
                {"table": table, "schema": schema, "constraints": constraints},
            )

    async def insert_unique(
        self, table: str, args: Dict[str, Any], unique_fields: List[str]
    ) -> Union[str, None]:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_insert_unique(table, args, unique_fields)
        else:
            return await self._make_request(
                "insert_unique",
                {"table": table, "args": args, "unique_fields": unique_fields},
            )

    async def aggregate(
        self,
        table: str,
        function: str,
        column: str = "*",
        where: Optional[Dict[str, Any]] = None,
    ) -> Any:
        await self._ensure_initialized()
        if self._is_server:
            return await self._server_aggregate(table, function, column, where)
        else:
            return await self._make_request(
                "aggregate",
                {
                    "table": table,
                    "function": function,
                    "column": column,
                    "where": where,
                },
            )

    async def transaction(self):
        """Context manager for transactions."""
        await self._ensure_initialized()
        if self._is_server:
            return TransactionContext(self._db_connection, self._db_semaphore)
        else:
            raise NotImplementedError("Transactions not supported in client mode")

    # Server implementation methods (original logic)
    @staticmethod
    def _utc_iso() -> str:
        return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")

    @staticmethod
    def _py_to_sqlite_type(value: Any) -> str:
        if value is None:
            return "TEXT"
        if isinstance(value, bool):
            return "INTEGER"
        if isinstance(value, int):
            return "INTEGER"
        if isinstance(value, float):
            return "REAL"
        if isinstance(value, (bytes, bytearray, memoryview)):
            return "BLOB"
        return "TEXT"

    async def _get_table_columns(self, table: str) -> Set[str]:
        """Get actual columns that exist in the table."""
        if table in self._table_columns_cache:
            return self._table_columns_cache[table]

        columns = set()
        try:
            async with self._db_semaphore:
                async with self._db_connection.execute(f"PRAGMA table_info({table})") as cursor:
                    async for row in cursor:
                        columns.add(row[1])  # Column name is at index 1
            self._table_columns_cache[table] = columns
        except Exception:
            pass
        return columns

    async def _invalidate_column_cache(self, table: str):
        """Invalidate column cache for a table."""
        if table in self._table_columns_cache:
            del self._table_columns_cache[table]

    async def _ensure_column(self, table: str, name: str, value: Any) -> None:
        col_type = self._py_to_sqlite_type(value)
        try:
            async with self._db_semaphore:
                await self._db_connection.execute(
                    f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
                )
                await self._db_connection.commit()
            await self._invalidate_column_cache(table)
        except aiosqlite.OperationalError as e:
            if "duplicate column name" in str(e).lower():
                pass  # Column already exists
            else:
                raise

    async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
        # Always include default columns
        cols = self._DEFAULT_COLUMNS.copy()

        # Add columns from col_sources
        for key, val in col_sources.items():
            if key not in cols:
                cols[key] = self._py_to_sqlite_type(val)

        columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
        async with self._db_semaphore:
            await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
            await self._db_connection.commit()
        await self._invalidate_column_cache(table)

    async def _table_exists(self, table: str) -> bool:
        """Check if a table exists."""
        async with self._db_semaphore:
            async with self._db_connection.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
            ) as cursor:
                return await cursor.fetchone() is not None

    _RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
    _RE_NO_TABLE = re.compile(r"no such table: (\w+)")

    @classmethod
    def _missing_column_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]:
        m = cls._RE_NO_COLUMN.search(str(err))
        return m.group(1) if m else None

    @classmethod
    def _missing_table_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]:
        m = cls._RE_NO_TABLE.search(str(err))
        return m.group(1) if m else None

    async def _safe_execute(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
        max_retries: int = 10,
    ) -> aiosqlite.Cursor:
        retries = 0
        while retries < max_retries:
            try:
                async with self._db_semaphore:
                    cursor = await self._db_connection.execute(sql, params)
                    await self._db_connection.commit()
                    return cursor
            except aiosqlite.OperationalError as err:
                retries += 1
                err_str = str(err).lower()

                # Handle missing column
                col = self._missing_column_from_error(err)
                if col:
                    if col in col_sources:
                        await self._ensure_column(table, col, col_sources[col])
                    else:
                        # Column not in sources, ensure it with NULL/TEXT type
                        await self._ensure_column(table, col, None)
                    continue

                # Handle missing table
                tbl = self._missing_table_from_error(err)
                if tbl:
                    await self._ensure_table(tbl, col_sources)
                    continue

                # Handle other column-related errors
                if "has no column named" in err_str:
                    # Extract column name differently
                    match = re.search(r"table \w+ has no column named (\w+)", err_str)
                    if match:
                        col_name = match.group(1)
                        if col_name in col_sources:
                            await self._ensure_column(table, col_name, col_sources[col_name])
                        else:
                            await self._ensure_column(table, col_name, None)
                        continue

                raise
        raise Exception(f"Max retries ({max_retries}) exceeded")

    async def _safe_query(
        self,
        table: str,
        sql: str,
        params: Iterable[Any],
        col_sources: Dict[str, Any],
    ) -> AsyncGenerator[Dict[str, Any], None]:
        # Check if table exists first
        if not await self._table_exists(table):
            return

        max_retries = 10
        retries = 0

        while retries < max_retries:
            try:
                async with self._db_semaphore:
                    self._db_connection.row_factory = aiosqlite.Row
                    async with self._db_connection.execute(sql, params) as cursor:
                        # Fetch all rows while holding the semaphore
                        rows = await cursor.fetchall()

                # Yield rows after releasing the semaphore
                for row in rows:
                    yield dict(row)
                return
            except aiosqlite.OperationalError as err:
                retries += 1
                err_str = str(err).lower()

                # Handle missing table
                tbl = self._missing_table_from_error(err)
                if tbl:
                    # For queries, if table doesn't exist, just return empty
                    return

                # Handle missing column in WHERE clause or SELECT
                if "no such column" in err_str:
                    # For queries with missing columns, return empty
                    return

                raise

    @staticmethod
    def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
        if not where:
            return "", []

        clauses = []
        vals = []

        op_map = {"$eq": "=", "$ne": "!=", "$gt": ">", "$gte": ">=", "$lt": "<", "$lte": "<="}

        for k, v in where.items():
            if isinstance(v, dict):
                # Handle operators like {"$lt": 10}
                op_key = next(iter(v))
                if op_key in op_map:
                    operator = op_map[op_key]
                    value = v[op_key]
                    clauses.append(f"`{k}` {operator} ?")
                    vals.append(value)
                else:
                    # Fallback for unexpected dict values
                    clauses.append(f"`{k}` = ?")
                    vals.append(json.dumps(v))
            else:
                # Default to equality
                clauses.append(f"`{k}` = ?")
                vals.append(v)

        return " WHERE " + " AND ".join(clauses), vals

    async def _server_insert(
        self, table: str, args: Dict[str, Any], return_id: bool = False
    ) -> Union[str, int]:
        """Insert a record. If return_id=True, returns auto-incremented ID instead of UUID."""
        uid = str(uuid4())
        now = self._utc_iso()
        record = {
            "uid": uid,
            "created_at": now,
            "updated_at": now,
            "deleted_at": None,
            **args,
        }

        # Ensure table exists with all needed columns
        await self._ensure_table(table, record)

        # Handle auto-increment ID if requested
        if return_id and "id" not in args:
            # Ensure id column exists
            async with self._db_semaphore:
                # Add id column if it doesn't exist
                try:
                    await self._db_connection.execute(
                        f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT"
                    )
                    await self._db_connection.commit()
                except aiosqlite.OperationalError as e:
                    if "duplicate column name" not in str(e).lower():
                        # Try without autoincrement constraint
                        try:
                            await self._db_connection.execute(
                                f"ALTER TABLE {table} ADD COLUMN id INTEGER"
                            )
                            await self._db_connection.commit()
                        except Exception:
                            pass

            await self._invalidate_column_cache(table)

            # Insert and get lastrowid
            cols = "`" + "`, `".join(record.keys()) + "`"
            qs = ", ".join(["?"] * len(record))
            sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
            cursor = await self._safe_execute(table, sql, list(record.values()), record)
            return cursor.lastrowid

        cols = "`" + "`, `".join(record) + "`"
        qs = ", ".join(["?"] * len(record))
        sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
        await self._safe_execute(table, sql, list(record.values()), record)
        return uid

    async def _server_update(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> int:
        if not args:
            return 0

        # Check if table exists
        if not await self._table_exists(table):
            return 0

        args["updated_at"] = self._utc_iso()

        # Ensure all columns exist
        all_cols = {**args, **(where or {})}
        await self._ensure_table(table, all_cols)
        for col, val in all_cols.items():
            await self._ensure_column(table, col, val)

        set_clause = ", ".join(f"`{k}` = ?" for k in args)
        where_clause, where_params = self._build_where(where)
        sql = f"UPDATE {table} SET {set_clause}{where_clause}"
        params = list(args.values()) + where_params
        cur = await self._safe_execute(table, sql, params, all_cols)
        return cur.rowcount

    async def _server_delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        # Check if table exists
        if not await self._table_exists(table):
            return 0

        where_clause, where_params = self._build_where(where)
        sql = f"DELETE FROM {table}{where_clause}"
        cur = await self._safe_execute(table, sql, where_params, where or {})
        return cur.rowcount

    async def _server_upsert(
        self,
        table: str,
        args: Dict[str, Any],
        where: Optional[Dict[str, Any]] = None,
    ) -> str | None:
        if not args:
            raise ValueError("Nothing to update. Empty dict given.")
        args["updated_at"] = self._utc_iso()
        affected = await self._server_update(table, args, where)
        if affected:
            rec = await self._server_get(table, where)
            return rec.get("uid") if rec else None
        merged = {**(where or {}), **args}
        return await self._server_insert(table, merged)

    async def _server_get(
        self, table: str, where: Optional[Dict[str, Any]] = None
    ) -> Optional[Dict[str, Any]]:
        where_clause, where_params = self._build_where(where)
        sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
        async for row in self._safe_query(table, sql, where_params, where or {}):
            return row
        return None

    async def _server_find(
        self,
        table: str,
        where: Optional[Dict[str, Any]] = None,
        *,
        limit: int = 0,
        offset: int = 0,
        order_by: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Find records with optional ordering."""
        where_clause, where_params = self._build_where(where)
        order_clause = f" ORDER BY {order_by}" if order_by else ""
        extra = (f" LIMIT {limit}" if limit else "") + (f" OFFSET {offset}" if offset else "")
        sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}"
        return [row async for row in self._safe_query(table, sql, where_params, where or {})]

    async def _server_count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
        # Check if table exists
        if not await self._table_exists(table):
            return 0

        where_clause, where_params = self._build_where(where)
        sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
        gen = self._safe_query(table, sql, where_params, where or {})
        async for row in gen:
            return next(iter(row.values()), 0)
        return 0

    async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool:
        return (await self._server_count(table, where)) > 0

    async def _server_kv_set(
        self,
        key: str,
        value: Any,
        *,
        table: str | None = None,
    ) -> None:
        tbl = table or self._KV_TABLE
        json_val = json.dumps(value, default=str)
        await self._server_upsert(tbl, {"value": json_val}, {"key": key})

    async def _server_kv_get(
        self,
        key: str,
        *,
        default: Any = None,
        table: str | None = None,
    ) -> Any:
        tbl = table or self._KV_TABLE
        row = await self._server_get(tbl, {"key": key})
        if not row:
            return default
        try:
            return json.loads(row["value"])
        except Exception:
            return default

    async def _server_execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
        """Execute raw SQL for complex queries like JOINs."""
        async with self._db_semaphore:
            cursor = await self._db_connection.execute(sql, params or ())
            await self._db_connection.commit()
            return cursor

    async def _server_query_raw(
        self, sql: str, params: Optional[Tuple] = None
    ) -> List[Dict[str, Any]]:
        """Execute raw SQL query and return results as list of dicts."""
        try:
            async with self._db_semaphore:
                self._db_connection.row_factory = aiosqlite.Row
                async with self._db_connection.execute(sql, params or ()) as cursor:
                    return [dict(row) async for row in cursor]
        except aiosqlite.OperationalError:
            # Return empty list if query fails
            return []

    async def _server_query_one(
        self, sql: str, params: Optional[Tuple] = None
    ) -> Optional[Dict[str, Any]]:
        """Execute raw SQL query and return single result."""
        results = await self._server_query_raw(sql + " LIMIT 1", params)
        return results[0] if results else None

    async def _server_create_table(
        self,
        table: str,
        schema: Dict[str, str],
        constraints: Optional[List[str]] = None,
    ):
        """Create table with custom schema and constraints. Always includes default columns."""
        # Merge default columns with custom schema
        full_schema = self._DEFAULT_COLUMNS.copy()
        full_schema.update(schema)

        columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
        if constraints:
            columns.extend(constraints)
        columns_sql = ", ".join(columns)

        async with self._db_semaphore:
            await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
            await self._db_connection.commit()
        await self._invalidate_column_cache(table)

    async def _server_insert_unique(
        self, table: str, args: Dict[str, Any], unique_fields: List[str]
    ) -> Union[str, None]:
        """Insert with unique constraint handling. Returns uid on success, None if duplicate."""
        try:
            return await self._server_insert(table, args)
        except aiosqlite.IntegrityError as e:
            if "UNIQUE" in str(e):
                return None
            raise

    async def _server_aggregate(
        self,
        table: str,
        function: str,
        column: str = "*",
        where: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Perform aggregate functions like SUM, AVG, MAX, MIN."""
        # Check if table exists
        if not await self._table_exists(table):
            return None

        where_clause, where_params = self._build_where(where)
        sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
        result = await self._server_query_one(sql, tuple(where_params))
        return result["result"] if result else None

    async def close(self):
        """Close the connection or server."""
        if not self._initialized:
            return

        if self._client_session:
            await self._client_session.close()
        if self._runner:
            await self._runner.cleanup()
        if os.path.exists(self._socket_path) and self._is_server:
            os.unlink(self._socket_path)
        if self._db_connection:
            await self._db_connection.close()


class TransactionContext:
    """Context manager for database transactions."""

    def __init__(self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None):
        self.db_connection = db_connection
        self.semaphore = semaphore

    async def __aenter__(self):
        if self.semaphore:
            await self.semaphore.acquire()
        try:
            await self.db_connection.execute("BEGIN")
            return self.db_connection
        except:
            if self.semaphore:
                self.semaphore.release()
            raise

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is None:
                await self.db_connection.commit()
            else:
                await self.db_connection.rollback()
        finally:
            if self.semaphore:
                self.semaphore.release()
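
# Illustrative usage sketch: transaction() is itself a coroutine, so the
# context manager it returns is entered with an extra await on the server side.
#
#     ctx = await ds.transaction()
#     async with ctx as conn:
#         await conn.execute("UPDATE users SET name = ? WHERE uid = ?", ("bob", uid))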


# Test cases remain the same but with additional tests for new functionality

@@ -99,7 +99,7 @@ class AgentCommunicationBus:
        self.conn.commit()

-    def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
+    def receive_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
        cursor = self.conn.cursor()
        if unread_only:
            cursor.execute(

@@ -153,9 +153,6 @@ class AgentCommunicationBus:
    def close(self):
        self.conn.close()
-
-    def receive_messages(self, agent_id: str) -> List[AgentMessage]:
-        return self.get_messages(agent_id, unread_only=True)

    def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]:
        cursor = self.conn.cursor()
        cursor.execute(

@@ -125,7 +125,7 @@ class AgentManager:
        return message.message_id

    def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
-        return self.communication_bus.get_messages(agent_id, unread_only)
+        return self.communication_bus.receive_messages(agent_id, unread_only)

    def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]):
        orchestrator = self.get_agent(orchestrator_id)

@@ -229,45 +229,3 @@ def get_agent_role(role_name: str) -> AgentRole:
def list_agent_roles() -> Dict[str, AgentRole]:
    return AGENT_ROLES.copy()
-
-
-def get_recommended_agent(task_description: str) -> str:
-    task_lower = task_description.lower()
-
-    code_keywords = [
-        "code",
-        "implement",
-        "function",
-        "class",
-        "bug",
-        "debug",
-        "refactor",
-        "optimize",
-    ]
-    research_keywords = [
-        "search",
-        "find",
-        "research",
-        "information",
-        "analyze",
-        "investigate",
-    ]
-    data_keywords = ["data", "database", "query", "statistics", "analyze", "process"]
-    planning_keywords = ["plan", "organize", "workflow", "steps", "coordinate"]
-    testing_keywords = ["test", "verify", "validate", "check", "quality"]
-    doc_keywords = ["document", "documentation", "explain", "guide", "manual"]
-
-    if any(keyword in task_lower for keyword in code_keywords):
-        return "coding"
-    elif any(keyword in task_lower for keyword in research_keywords):
-        return "research"
-    elif any(keyword in task_lower for keyword in data_keywords):
-        return "data_analysis"
-    elif any(keyword in task_lower for keyword in planning_keywords):
-        return "planning"
-    elif any(keyword in task_lower for keyword in testing_keywords):
-        return "testing"
-    elif any(keyword in task_lower for keyword in doc_keywords):
-        return "documentation"
-    else:
-        return "general"

pr/cache/tool_cache.py (12 lines changed, vendored)
@@ -122,18 +122,6 @@ class ToolCache:
        conn.commit()
        conn.close()
-
-    def invalidate_tool(self, tool_name: str):
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
-
-        cursor.execute("DELETE FROM tool_cache WHERE tool_name = ?", (tool_name,))
-        deleted_count = cursor.rowcount
-
-        conn.commit()
-        conn.close()
-
-        return deleted_count

    def clear_expired(self):
        current_time = int(time.time())

@@ -1,6 +1,7 @@
import json
import time

+from pr.commands.multiplexer_commands import MULTIPLEXER_COMMANDS
from pr.autonomous import run_autonomous_mode
from pr.core.api import list_models
from pr.tools import read_file

@@ -13,6 +14,12 @@ def handle_command(assistant, command):
    command_parts = command.strip().split(maxsplit=1)
    cmd = command_parts[0].lower()

+    if cmd in MULTIPLEXER_COMMANDS:
+        # Pass the assistant object and the remaining arguments to the multiplexer command
+        return MULTIPLEXER_COMMANDS[cmd](
+            assistant, command_parts[1:] if len(command_parts) > 1 else []
+        )
+
    if cmd == "/edit":
        rp_editor = RPEditor(command_parts[1] if len(command_parts) > 1 else None)
        rp_editor.start()

@@ -23,6 +30,18 @@ def handle_command(assistant, command):
        if task:
            run_autonomous_mode(assistant, task)

+    elif cmd == "/prompt":
+        rp_editor = RPEditor(command_parts[1] if len(command_parts) > 1 else None)
+        rp_editor.start()
+        rp_editor.thread.join()
+        prompt_text = str(rp_editor.get_text())
+        rp_editor.stop()
+        rp_editor = None
+        if prompt_text.strip():
+            from pr.core.assistant import process_message
+
+            process_message(assistant, prompt_text)
+
    elif cmd == "/auto":
        if len(command_parts) < 2:
            print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}")

pr/commands/help_docs.py (new file, 507 lines)
@@ -0,0 +1,507 @@
from pr.ui import Colors


def get_workflow_help():
    return f"""
{Colors.BOLD}WORKFLOWS - AUTOMATED TASK EXECUTION{Colors.RESET}

{Colors.BOLD}Overview:{Colors.RESET}
  Workflows enable automated execution of multi-step tasks by chaining together
  tool calls with defined execution logic, error handling, and conditional branching.

{Colors.BOLD}Execution Modes:{Colors.RESET}
  {Colors.CYAN}sequential{Colors.RESET} - Steps execute one after another in order
  {Colors.CYAN}parallel{Colors.RESET} - All steps execute simultaneously
  {Colors.CYAN}conditional{Colors.RESET} - Steps execute based on conditions and paths

{Colors.BOLD}Workflow Components:{Colors.RESET}
  {Colors.CYAN}Steps{Colors.RESET} - Individual tool executions with parameters
  {Colors.CYAN}Variables{Colors.RESET} - Dynamic values available across workflow
  {Colors.CYAN}Conditions{Colors.RESET} - Boolean expressions controlling step execution
  {Colors.CYAN}Paths{Colors.RESET} - Success/failure routing between steps
  {Colors.CYAN}Retries{Colors.RESET} - Automatic retry count for failed steps
  {Colors.CYAN}Timeouts{Colors.RESET} - Maximum execution time per step (seconds)

{Colors.BOLD}Variable Substitution:{Colors.RESET}
  Reference previous step results: ${{step.step_id}}
  Reference workflow variables: ${{var.variable_name}}

  Example: {{"path": "${{step.get_path}}/output.txt"}}

{Colors.BOLD}Workflow Commands:{Colors.RESET}
  {Colors.CYAN}/workflows{Colors.RESET} - List all available workflows
  {Colors.CYAN}/workflow <name>{Colors.RESET} - Execute a specific workflow

{Colors.BOLD}Workflow Structure:{Colors.RESET}
  Each workflow consists of:
  - Name and description
  - Execution mode (sequential/parallel/conditional)
  - List of steps with tool calls
  - Optional variables for dynamic values
  - Optional tags for categorization

{Colors.BOLD}Step Configuration:{Colors.RESET}
  Each step includes (see the example step below):
  - tool_name: Which tool to execute
  - arguments: Parameters passed to the tool
  - step_id: Unique identifier for the step
  - condition: Optional boolean expression (conditional mode)
  - on_success: List of step IDs to execute on success
  - on_failure: List of step IDs to execute on failure
  - retry_count: Number of automatic retries (default: 0)
  - timeout_seconds: Maximum execution time (default: 300)
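
  Example step (illustrative; tool and argument names are hypothetical):
    {{
      "step_id": "fetch_user_data",
      "tool_name": "http_fetch",
      "arguments": {{"url": "${{var.base_url}}/users"}},
      "on_success": ["process_data"],
      "on_failure": ["report_error"],
      "retry_count": 2,
      "timeout_seconds": 60
    }}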

{Colors.BOLD}Conditional Execution:{Colors.RESET}
  Conditions are Python expressions evaluated with access to:
  - variables: All workflow variables
  - results: All step results by step_id

  Example condition: "results['check_file']['status'] == 'success'"

{Colors.BOLD}Success/Failure Paths:{Colors.RESET}
  Steps can define:
  - on_success: List of step IDs to execute if step succeeds
  - on_failure: List of step IDs to execute if step fails

  This enables error recovery and branching logic within workflows, as in the
  sketch below.
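
  Error-recovery sketch (illustrative; step and tool names are hypothetical):
    {{"step_id": "try_main", "tool_name": "run_build", "on_failure": ["fallback"]}},
    {{"step_id": "fallback", "tool_name": "run_build", "arguments": {{"mode": "safe"}}}}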

{Colors.BOLD}Workflow Storage:{Colors.RESET}
  Workflows are stored in the SQLite database and persist across sessions.
  Execution history and statistics are tracked for each workflow.

{Colors.BOLD}Example Workflow Scenarios:{Colors.RESET}
  - Sequential: Data pipeline (fetch → process → save → validate)
  - Parallel: Multi-source aggregation (fetch A + fetch B + fetch C)
  - Conditional: Error recovery (try main → on failure → fallback)

{Colors.BOLD}Best Practices:{Colors.RESET}
  - Use descriptive step_id names (e.g., "fetch_user_data")
  - Set appropriate timeouts for long-running operations
  - Add retry_count for network or flaky operations
  - Use conditional paths for error handling
  - Keep workflows focused on a single logical task
  - Use variables for values shared across steps
"""


def get_agent_help():
    return f"""
{Colors.BOLD}AGENTS - SPECIALIZED AI ASSISTANTS{Colors.RESET}

{Colors.BOLD}Overview:{Colors.RESET}
  Agents are specialized AI instances with specific roles, capabilities, and
  system prompts. Each agent maintains its own conversation history and can
  collaborate with other agents to complete complex tasks.

{Colors.BOLD}Available Agent Roles:{Colors.RESET}
  {Colors.CYAN}coding{Colors.RESET} - Code writing, review, debugging, refactoring
    Temperature: 0.3 (precise)

  {Colors.CYAN}research{Colors.RESET} - Information gathering, analysis, documentation
    Temperature: 0.5 (balanced)

  {Colors.CYAN}data_analysis{Colors.RESET} - Data processing, statistical analysis, queries
    Temperature: 0.3 (precise)

  {Colors.CYAN}planning{Colors.RESET} - Task decomposition, workflow design, coordination
    Temperature: 0.6 (creative)

  {Colors.CYAN}testing{Colors.RESET} - Test design, quality assurance, validation
    Temperature: 0.4 (methodical)

  {Colors.CYAN}documentation{Colors.RESET} - Technical writing, API docs, user guides
    Temperature: 0.6 (creative)

  {Colors.CYAN}orchestrator{Colors.RESET} - Multi-agent coordination and task delegation
    Temperature: 0.5 (balanced)

  {Colors.CYAN}general{Colors.RESET} - General purpose assistance across domains
    Temperature: 0.7 (versatile)

{Colors.BOLD}Agent Commands:{Colors.RESET}
  {Colors.CYAN}/agent <role> <task>{Colors.RESET} - Create agent and assign task
  {Colors.CYAN}/agents{Colors.RESET} - Show all active agents
  {Colors.CYAN}/collaborate <task>{Colors.RESET} - Multi-agent collaboration

{Colors.BOLD}Single Agent Usage:{Colors.RESET}
  Create a specialized agent and assign it a task:

    /agent coding Implement a binary search function in Python
    /agent research Find latest best practices for API security
    /agent testing Create test cases for user authentication

  Each agent:
  - Maintains its own conversation history
  - Has access to role-specific tools
  - Uses specialized system prompts
  - Tracks task completion count
  - Integrates with knowledge base

{Colors.BOLD}Multi-Agent Collaboration:{Colors.RESET}
  Use multiple agents to work together on complex tasks:

    /collaborate Build a REST API with tests and documentation

  This automatically:
  1. Creates an orchestrator agent
  2. Spawns coding, research, and planning agents
  3. Orchestrator delegates subtasks to specialists
  4. Agents communicate results back to orchestrator
  5. Orchestrator integrates results and coordinates

{Colors.BOLD}Agent Capabilities by Role:{Colors.RESET}

  {Colors.CYAN}Coding Agent:{Colors.RESET}
  - Read, write, and modify files
  - Execute Python code
  - Run commands and build tools
  - Index source directories
  - Code review and refactoring

  {Colors.CYAN}Research Agent:{Colors.RESET}
  - Web search and HTTP fetching
  - Read documentation and files
  - Query databases for information
  - Analyze and synthesize findings

  {Colors.CYAN}Data Analysis Agent:{Colors.RESET}
  - Database queries and operations
  - Python execution for analysis
  - File read/write for data
  - Statistical processing

  {Colors.CYAN}Planning Agent:{Colors.RESET}
  - Task decomposition and organization
  - Database storage for plans
  - File operations for documentation
  - Dependency identification

  {Colors.CYAN}Testing Agent:{Colors.RESET}
  - Test execution and validation
  - Code reading and analysis
  - Command execution for test runners
  - Quality verification

  {Colors.CYAN}Documentation Agent:{Colors.RESET}
  - Technical writing and formatting
  - Web research for best practices
  - File operations for docs
  - Documentation organization

{Colors.BOLD}Agent Session Management:{Colors.RESET}
  - Each session has a unique ID
  - Agents persist within a session
  - View active agents with /agents
  - Agents share knowledge base access
  - Communication bus for agent messages

{Colors.BOLD}Knowledge Base Integration:{Colors.RESET}
  When agents execute tasks:
  - Relevant knowledge entries are automatically retrieved
  - Top 3 matches are injected into agent context
  - Knowledge access count is tracked
  - Agents can update knowledge base

{Colors.BOLD}Agent Communication:{Colors.RESET}
  Agents can send messages to each other:
  - Request type: Ask another agent for help
  - Response type: Answer another agent's request
  - Notification type: Inform about events
  - Coordination type: Synchronize work

{Colors.BOLD}Best Practices:{Colors.RESET}
  - Use specialized agents for focused tasks
  - Use collaboration for complex multi-domain tasks
  - Coding agent: Precise, deterministic code generation
  - Research agent: Thorough information gathering
  - Planning agent: Breaking down complex tasks
  - Testing agent: Comprehensive test coverage
  - Documentation agent: Clear user-facing docs

{Colors.BOLD}Performance Characteristics:{Colors.RESET}
  - Lower temperature: More deterministic (coding, data_analysis)
  - Higher temperature: More creative (planning, documentation)
  - Max tokens: 4096 for all agents
  - Parallel execution: Multiple agents work simultaneously
  - Message history: Full conversation context maintained

{Colors.BOLD}Example Usage Patterns:{Colors.RESET}

  Single specialized agent:
    /agent coding Refactor authentication module for better security

  Multi-agent collaboration:
    /collaborate Create a data visualization dashboard with tests

  Agent inspection:
    /agents # View all active agents and their stats
"""


def get_knowledge_help():
    return f"""
{Colors.BOLD}KNOWLEDGE BASE - PERSISTENT INFORMATION STORAGE{Colors.RESET}

{Colors.BOLD}Overview:{Colors.RESET}
  The knowledge base provides persistent storage for information, facts, and
  context that can be accessed across sessions and by all agents.

{Colors.BOLD}Commands:{Colors.RESET}
  {Colors.CYAN}/remember <content>{Colors.RESET} - Store information in knowledge base
  {Colors.CYAN}/knowledge <query>{Colors.RESET} - Search knowledge base
  {Colors.CYAN}/stats{Colors.RESET} - Show knowledge base statistics
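
  Example (illustrative):
    /remember The deploy script expects DATABASE_URL to be set
    /knowledge deploy script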

{Colors.BOLD}Features:{Colors.RESET}
  - Automatic categorization of content
  - TF-IDF based semantic search
  - Access frequency tracking
  - Importance scoring
  - Temporal relevance
  - Category-based organization

{Colors.BOLD}Categories:{Colors.RESET}
  Content is automatically categorized as:
  - code: Code snippets and programming concepts
  - api: API documentation and endpoints
  - configuration: Settings and configurations
  - documentation: General documentation
  - fact: Factual information
  - procedure: Step-by-step procedures
  - general: Uncategorized information

{Colors.BOLD}Search Capabilities:{Colors.RESET}
  - Keyword matching with TF-IDF scoring
  - Returns top relevant entries
  - Considers access frequency
  - Factors in recency
  - Searches across all categories

{Colors.BOLD}Integration:{Colors.RESET}
  - Agents automatically query knowledge base
  - Top matches injected into agent context
  - Conversation history can be stored
  - Facts extracted from interactions
  - Persistent across sessions
"""


def get_cache_help():
    return f"""
{Colors.BOLD}CACHING SYSTEM - PERFORMANCE OPTIMIZATION{Colors.RESET}

{Colors.BOLD}Overview:{Colors.RESET}
  The caching system optimizes performance by storing API responses and tool
  results, reducing redundant API calls and expensive operations.

{Colors.BOLD}Commands:{Colors.RESET}
  {Colors.CYAN}/cache{Colors.RESET} - Show cache statistics
  {Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches

{Colors.BOLD}Cache Types:{Colors.RESET}

  {Colors.CYAN}API Cache:{Colors.RESET}
  - Stores LLM API responses
  - Key: Hash of messages + model + parameters
  - Tracks token usage
  - TTL: Configurable expiration
  - Reduces API costs and latency

  {Colors.CYAN}Tool Cache:{Colors.RESET}
  - Stores tool execution results
  - Key: Tool name + arguments hash
  - Per-tool statistics
  - Hit count tracking
  - Reduces redundant operations

{Colors.BOLD}Cache Statistics:{Colors.RESET}
  - Total entries (valid + expired)
  - Valid entry count
  - Expired entry count
  - Total cached tokens (API cache)
  - Cache hits per tool (Tool cache)
  - Per-tool breakdown

{Colors.BOLD}Benefits:{Colors.RESET}
  - Reduced API costs
  - Faster response times
  - Lower rate limit usage
  - Consistent results for identical queries
  - Tool execution optimization
"""


def get_background_help():
    return f"""
{Colors.BOLD}BACKGROUND SESSIONS - CONCURRENT TASK EXECUTION{Colors.RESET}

{Colors.BOLD}Overview:{Colors.RESET}
  Background sessions allow running long-running commands and processes while
  continuing to interact with the assistant. Sessions are monitored and can
  be managed through the assistant.

{Colors.BOLD}Commands:{Colors.RESET}
  {Colors.CYAN}/bg start <command>{Colors.RESET} - Start command in background
  {Colors.CYAN}/bg list{Colors.RESET} - List all background sessions
  {Colors.CYAN}/bg status <name>{Colors.RESET} - Show session status
  {Colors.CYAN}/bg output <name>{Colors.RESET} - View session output
  {Colors.CYAN}/bg input <name> <text>{Colors.RESET} - Send input to session
  {Colors.CYAN}/bg kill <name>{Colors.RESET} - Terminate session
  {Colors.CYAN}/bg events{Colors.RESET} - Show recent events
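
  Example (illustrative):
    /bg start python -m http.server 8000
    /bg output <name>
    /bg kill <name>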

{Colors.BOLD}Features:{Colors.RESET}
  - Automatic session monitoring
  - Output capture and buffering
  - Interactive input support
  - Event notification system
  - Session lifecycle management
  - Background event detection

{Colors.BOLD}Event Types:{Colors.RESET}
  - session_started: New session created
  - session_ended: Session terminated
  - output_received: New output available
  - possible_input_needed: May require user input
  - high_output_volume: Large amount of output
  - inactive_session: No activity for period

{Colors.BOLD}Status Indicators:{Colors.RESET}
  The prompt shows background session count: You[2bg]>
  - Number indicates active background sessions
  - Updates automatically as sessions start/stop

{Colors.BOLD}Use Cases:{Colors.RESET}
  - Long-running build processes
  - Web servers and daemons
  - File watching and monitoring
  - Batch processing tasks
  - Test suite execution
  - Development servers
"""
|
||||


def get_full_help():
    return f"""
{Colors.BOLD}R - PROFESSIONAL AI ASSISTANT{Colors.RESET}

{Colors.BOLD}BASIC COMMANDS{Colors.RESET}
exit, quit, q - Exit the assistant
/help - Show this help message
/help workflows - Detailed workflow documentation
/help agents - Detailed agent documentation
/help knowledge - Knowledge base documentation
/help cache - Caching system documentation
/help background - Background sessions documentation
/reset - Clear message history
/dump - Show message history as JSON
/verbose - Toggle verbose mode
/models - List available models
/tools - List available tools

{Colors.BOLD}FILE OPERATIONS{Colors.RESET}
/review <file> - Review a file
/refactor <file> - Refactor code in a file
/obfuscate <file> - Obfuscate code in a file

{Colors.BOLD}AUTONOMOUS MODE{Colors.RESET}
{Colors.CYAN}/auto <task>{Colors.RESET} - Enter autonomous mode for task completion
Assistant works continuously until task is complete
Max 50 iterations with automatic context management
Press Ctrl+C twice to force exit

{Colors.BOLD}WORKFLOWS{Colors.RESET}
{Colors.CYAN}/workflows{Colors.RESET} - List all available workflows
{Colors.CYAN}/workflow <name>{Colors.RESET} - Execute a specific workflow

Workflows enable automated multi-step task execution with:
- Sequential, parallel, or conditional execution
- Variable substitution and step dependencies
- Error handling and retry logic
- Success/failure path routing

For detailed documentation: /help workflows

{Colors.BOLD}AGENTS{Colors.RESET}
{Colors.CYAN}/agent <role> <task>{Colors.RESET} - Create specialized agent
{Colors.CYAN}/agents{Colors.RESET} - Show active agents
{Colors.CYAN}/collaborate <task>{Colors.RESET} - Multi-agent collaboration

Available roles: coding, research, data_analysis, planning,
testing, documentation, orchestrator, general

Each agent has specialized capabilities and system prompts.
Agents can work independently or collaborate on complex tasks.

For detailed documentation: /help agents

{Colors.BOLD}KNOWLEDGE BASE{Colors.RESET}
{Colors.CYAN}/knowledge <query>{Colors.RESET} - Search knowledge base
{Colors.CYAN}/remember <content>{Colors.RESET} - Store information

Persistent storage for facts, procedures, and context.
Automatic categorization and TF-IDF search.
Integrated with agents for context injection.

For detailed documentation: /help knowledge

{Colors.BOLD}SESSION MANAGEMENT{Colors.RESET}
{Colors.CYAN}/history{Colors.RESET} - Show conversation history
{Colors.CYAN}/cache{Colors.RESET} - Show cache statistics
{Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches
{Colors.CYAN}/stats{Colors.RESET} - Show system statistics

{Colors.BOLD}BACKGROUND SESSIONS{Colors.RESET}
{Colors.CYAN}/bg start <command>{Colors.RESET} - Start background session
{Colors.CYAN}/bg list{Colors.RESET} - List active sessions
{Colors.CYAN}/bg output <name>{Colors.RESET} - View session output
{Colors.CYAN}/bg kill <name>{Colors.RESET} - Terminate session

Run long-running processes while maintaining interactivity.
Automatic monitoring and event notifications.

For detailed documentation: /help background

{Colors.BOLD}CONTEXT FILES{Colors.RESET}
.rcontext.txt - Local project context (auto-loaded)
~/.rcontext.txt - Global context (auto-loaded)
-c, --context FILE - Additional context files (command line)

{Colors.BOLD}ENVIRONMENT VARIABLES{Colors.RESET}
OPENROUTER_API_KEY - API key for OpenRouter
AI_MODEL - Default model to use
API_URL - Custom API endpoint
USE_TOOLS - Enable/disable tools (default: 1)

{Colors.BOLD}COMMAND-LINE FLAGS{Colors.RESET}
-i, --interactive - Start in interactive mode
-v, --verbose - Enable verbose output
--debug - Enable debug logging
-m, --model MODEL - Specify AI model
--no-syntax - Disable syntax highlighting

{Colors.BOLD}DATA STORAGE{Colors.RESET}
~/.assistant_db.sqlite - SQLite database for persistence
~/.assistant_history - Command history
~/.assistant_error.log - Error logs

{Colors.BOLD}AVAILABLE TOOLS{Colors.RESET}
Tools are functions the AI can call to interact with the system:
- File operations: read, write, list, mkdir, search/replace
- Command execution: run_command, interactive sessions
- Web operations: http_fetch, web_search, web_search_news
- Database: db_set, db_get, db_query
- Python execution: python_exec with persistent globals
- Code editing: open_editor, insert/replace text, diff/patch
- Process management: tail, kill, interactive control
- Agents: create, execute tasks, collaborate
- Knowledge: add, search, categorize entries

{Colors.BOLD}GETTING HELP{Colors.RESET}
/help - This help message
/help workflows - Workflow system details
/help agents - Agent system details
/help knowledge - Knowledge base details
/help cache - Cache system details
/help background - Background sessions details
/tools - List all available tools
/models - List all available models
"""
@@ -1,8 +1,9 @@
import os

DEFAULT_MODEL = "x-ai/grok-code-fast-1"
DEFAULT_API_URL = "http://localhost:8118/ai/chat"
MODEL_LIST_URL = "http://localhost:8118/ai/models"
DEFAULT_API_URL = "https://static.molodetz.nl/rp.cgi/api/v1/chat/completions"
MODEL_LIST_URL = "https://static.molodetz.nl/rp.cgi/api/v1/models"


config_directory = os.path.expanduser("~/.local/share/rp")
os.makedirs(config_directory, exist_ok=True)

@@ -7,44 +7,45 @@ class AdvancedContextManager:
        self.knowledge_store = knowledge_store
        self.conversation_memory = conversation_memory

    def adaptive_context_window(
        self, messages: List[Dict[str, Any]], task_complexity: str = "medium"
    ) -> int:
        complexity_thresholds = {
            "simple": 10,
            "medium": 20,
            "complex": 35,
            "very_complex": 50,
    def adaptive_context_window(self, messages: List[Dict[str, Any]], complexity: str) -> int:
        """Calculate adaptive context window size based on message complexity."""
        base_window = 10

        complexity_multipliers = {
            "simple": 1.0,
            "medium": 2.0,
            "complex": 3.5,
            "very_complex": 5.0,
        }

        base_threshold = complexity_thresholds.get(task_complexity, 20)

        message_complexity_score = self._analyze_message_complexity(messages)

        if message_complexity_score > 0.7:
            adjusted = int(base_threshold * 1.5)
        elif message_complexity_score < 0.3:
            adjusted = int(base_threshold * 0.7)
        else:
            adjusted = base_threshold

        return max(base_threshold, adjusted)
        multiplier = complexity_multipliers.get(complexity, 2.0)
        return int(base_window * multiplier)

    def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float:
        total_length = sum(len(msg.get("content", "")) for msg in messages)
        avg_length = total_length / len(messages) if messages else 0
        """Analyze the complexity of messages and return a score between 0.0 and 1.0."""
        if not messages:
            return 0.0

        unique_words = set()
        for msg in messages:
            content = msg.get("content", "")
            words = re.findall(r"\b\w+\b", content.lower())
            unique_words.update(words)
        total_complexity = 0.0
        for message in messages:
            content = message.get("content", "")
            if not content:
                continue

        vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0
            # Calculate complexity based on various factors
            word_count = len(content.split())
            sentence_count = len(re.split(r"[.!?]+", content))
            avg_word_length = sum(len(word) for word in content.split()) / max(word_count, 1)

        # Simple complexity score based on length and richness
        complexity = min(1.0, (avg_length / 100) + vocabulary_richness)
        return complexity
            # Complexity score based on length, vocabulary, and structure
            length_score = min(1.0, word_count / 100)  # Normalize to 0-1
            structure_score = min(1.0, sentence_count / 10)
            vocabulary_score = min(1.0, avg_word_length / 8)

            message_complexity = (length_score + structure_score + vocabulary_score) / 3
            total_complexity += message_complexity

        return min(1.0, total_complexity / len(messages))

    def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]:
        if not text.strip():
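Worth noting: with base_window = 10, the new multiplier table reproduces the old per-complexity thresholds exactly. A minimal sketch of the expected return values (illustrative only):

# Expected window sizes under the new scheme; unknown labels fall back to the 2.0 multiplier.
for label, expected in [("simple", 10), ("medium", 20), ("complex", 35), ("very_complex", 50), ("unknown", 20)]:
    multiplier = {"simple": 1.0, "medium": 2.0, "complex": 3.5, "very_complex": 5.0}.get(label, 2.0)
    assert int(10 * multiplier) == expected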
@@ -82,3 +83,38 @@ class AdvancedContextManager:
            return 0.0

        return len(intersection) / len(union)

    def create_enhanced_context(
        self, messages: List[Dict[str, Any]], user_message: str, include_knowledge: bool = True
    ) -> tuple:
        """Create enhanced context with knowledge base integration."""
        working_messages = messages.copy()

        if include_knowledge and self.knowledge_store:
            # Search knowledge base for relevant information
            search_results = self.knowledge_store.search_entries(user_message, top_k=3)

            if search_results:
                knowledge_parts = []
                for idx, entry in enumerate(search_results, 1):
                    content = entry.content
                    if len(content) > 2000:
                        content = content[:2000] + "..."

                    knowledge_parts.append(f"Match {idx} (Category: {entry.category}):\n{content}")

                knowledge_message_content = (
                    "[KNOWLEDGE_BASE_CONTEXT]\nRelevant knowledge base entries:\n\n"
                    + "\n\n".join(knowledge_parts)
                )

                knowledge_message = {"role": "user", "content": knowledge_message_content}
                working_messages.append(knowledge_message)

                context_info = f"Added {len(search_results)} knowledge base entries"
            else:
                context_info = "No relevant knowledge base entries found"
        else:
            context_info = "Knowledge base integration disabled"

        return working_messages, context_info

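A minimal usage sketch of the new method (the ctx_manager instance and the message content are hypothetical):

# Returns the augmented message list plus a human-readable summary string.
messages = [{"role": "user", "content": "How do I deploy?"}]
working, info = ctx_manager.create_enhanced_context(messages, "How do I deploy?")
# 'working' may now end with a [KNOWLEDGE_BASE_CONTEXT] message; 'info' explains what happened.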
@@ -67,6 +67,15 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
    if "tool_calls" in msg:
        logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")

    if verbose and "usage" in result:
        from pr.core.usage_tracker import UsageTracker

        usage = result["usage"]
        input_t = usage.get("prompt_tokens", 0)
        output_t = usage.get("completion_tokens", 0)
        cost = UsageTracker._calculate_cost(model, input_t, output_t)
        print(f"API call cost: €{cost:.4f}")

    logger.debug("=== API CALL END ===")
    return result


@@ -29,34 +29,28 @@ from pr.core.background_monitor import (
    stop_global_monitor,
)
from pr.core.context import init_system_message, truncate_tool_result
from pr.tools import (
    apply_patch,
    chdir,
    create_diff,
    db_get,
    db_query,
    db_set,
    getpwd,
    http_fetch,
    index_source_directory,
    kill_process,
    list_directory,
    mkdir,
    python_exec,
    read_file,
    run_command,
    search_replace,
    tail_process,
    web_search,
    web_search_news,
    write_file,
    post_image,
from pr.tools import get_tools_definition
from pr.tools.agents import (
    collaborate_agents,
    create_agent,
    execute_agent_task,
    list_agents,
    remove_agent,
)
from pr.tools.base import get_tools_definition
from pr.tools.command import kill_process, run_command, tail_process
from pr.tools.database import db_get, db_query, db_set
from pr.tools.filesystem import (
    chdir,
    clear_edit_tracker,
    display_edit_summary,
    display_edit_timeline,
    getpwd,
    index_source_directory,
    list_directory,
    mkdir,
    read_file,
    search_replace,
    write_file,
)
from pr.tools.interactive_control import (
    close_interactive_session,
@@ -65,7 +59,18 @@ from pr.tools.interactive_control import (
    send_input_to_session,
    start_interactive_session,
)
from pr.tools.patch import display_file_diff
from pr.tools.memory import (
    add_knowledge_entry,
    delete_knowledge_entry,
    get_knowledge_by_category,
    get_knowledge_entry,
    get_knowledge_statistics,
    search_knowledge,
    update_knowledge_importance,
)
from pr.tools.patch import apply_patch, create_diff, display_file_diff
from pr.tools.python_exec import python_exec
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.ui import Colors, render_markdown

logger = logging.getLogger("pr")
@@ -97,7 +102,7 @@ class Assistant:
            "MODEL_LIST_URL", MODEL_LIST_URL
        )
        self.use_tools = os.environ.get("USE_TOOLS", "1") == "1"
        self.strict_mode = os.environ.get("STRICT_MODE", "0") == "1"

        self.interrupt_count = 0
        self.python_globals = {}
        self.db_conn = None
@@ -245,7 +250,6 @@ class Assistant:
        logger.debug(f"Tool call: {func_name} with arguments: {arguments}")

        func_map = {
            "post_image": lambda **kw: post_image(**kw),
            "http_fetch": lambda **kw: http_fetch(**kw),
            "run_command": lambda **kw: run_command(**kw),
            "tail_process": lambda **kw: tail_process(**kw),

@@ -27,15 +27,6 @@ class BackgroundMonitor:
        if self.monitor_thread:
            self.monitor_thread.join(timeout=2)

    def add_event_callback(self, callback):
        """Add a callback function to be called when events are detected."""
        self.event_callbacks.append(callback)

    def remove_event_callback(self, callback):
        """Remove an event callback."""
        if callback in self.event_callbacks:
            self.event_callbacks.remove(callback)

    def get_pending_events(self):
        """Get all pending events from the queue."""
        events = []
@@ -224,30 +215,3 @@ def stop_global_monitor():
    global _global_monitor
    if _global_monitor:
        _global_monitor.stop()


# Global monitor instance
_global_monitor = None


def start_global_monitor():
    """Start the global background monitor."""
    global _global_monitor
    if _global_monitor is None:
        _global_monitor = BackgroundMonitor()
        _global_monitor.start()
    return _global_monitor


def stop_global_monitor():
    """Stop the global background monitor."""
    global _global_monitor
    if _global_monitor:
        _global_monitor.stop()
        _global_monitor = None


def get_global_monitor():
    """Get the global background monitor instance."""
    global _global_monitor
    return _global_monitor

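The intended lifecycle of the module-level singleton, as a short sketch (the print callback is illustrative):

monitor = start_global_monitor()      # lazily creates and starts the singleton
monitor.add_event_callback(print)     # any callable accepting an event works here
assert get_global_monitor() is monitor
stop_global_monitor()                 # stops the monitor and resets the singleton to None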
@@ -1,7 +1,6 @@
import configparser
import os
from typing import Any, Dict
import uuid
from pr.core.logging import get_logger

logger = get_logger("config")
@@ -11,29 +10,6 @@ CONFIG_FILE = os.path.join(CONFIG_DIRECTORY, ".prrc")
LOCAL_CONFIG_FILE = ".prrc"


def load_config() -> Dict[str, Any]:
    os.makedirs(CONFIG_DIRECTORY, exist_ok=True)
    config = {
        "api": {},
        "autonomous": {},
        "ui": {},
        "output": {},
        "session": {},
        "api_key": "rp-" + str(uuid.uuid4()),
    }

    global_config = _load_config_file(CONFIG_FILE)
    local_config = _load_config_file(LOCAL_CONFIG_FILE)

    for section in config.keys():
        if section in global_config:
            config[section].update(global_config[section])
        if section in local_config:
            config[section].update(local_config[section])

    return config


def _load_config_file(filepath: str) -> Dict[str, Dict[str, Any]]:
    if not os.path.exists(filepath):
        return {}

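To illustrate the merge order in load_config: global sections are applied first, then the local .prrc overrides them key by key. A hypothetical pair of INI files (configparser values parse as strings):

# ~/.local/share/rp/.prrc (global)
[api]
timeout = 30
model = x-ai/grok-code-fast-1

# ./.prrc (local)
[api]
timeout = 5

# After load_config(): config["api"] == {"timeout": "5", "model": "x-ai/grok-code-fast-1"}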
@@ -12,7 +12,6 @@ from pr.config import (
    CONVERSATION_SUMMARY_THRESHOLD,
    DB_PATH,
    KNOWLEDGE_SEARCH_LIMIT,
    MEMORY_AUTO_SUMMARIZE,
    TOOL_CACHE_TTL,
    WORKFLOW_EXECUTOR_MAX_WORKERS,
)
@@ -169,9 +168,9 @@ class EnhancedAssistant:
            self.current_conversation_id, str(uuid.uuid4())[:16], "user", user_message
        )

        if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0:
            # Automatically extract and store facts from every user message
            facts = self.fact_extractor.extract_facts(user_message)
            for fact in facts[:3]:
            for fact in facts[:5]:  # Store up to 5 facts per message
                entry_id = str(uuid.uuid4())[:16]
                import time

@@ -182,7 +181,11 @@ class EnhancedAssistant:
                    entry_id=entry_id,
                    category=categories[0] if categories else "general",
                    content=fact["text"],
                    metadata={"type": fact["type"], "confidence": fact["confidence"]},
                    metadata={
                        "type": fact["type"],
                        "confidence": fact["confidence"],
                        "source": "user_message",
                    },
                    created_at=time.time(),
                    updated_at=time.time(),
                )

148 pr/core/knowledge_context.py Normal file
@@ -0,0 +1,148 @@
import logging

logger = logging.getLogger("pr")

KNOWLEDGE_MESSAGE_MARKER = "[KNOWLEDGE_BASE_CONTEXT]"


def inject_knowledge_context(assistant, user_message):
    if not hasattr(assistant, "enhanced") or not assistant.enhanced:
        return

    messages = assistant.messages

    # Remove any existing knowledge context messages
    for i in range(len(messages) - 1, -1, -1):
        if messages[i].get("role") == "user" and KNOWLEDGE_MESSAGE_MARKER in messages[i].get(
            "content", ""
        ):
            del messages[i]
            logger.debug(f"Removed existing knowledge base message at index {i}")
            break

    try:
        # Search knowledge base with enhanced FTS + semantic search
        knowledge_results = assistant.enhanced.knowledge_store.search_entries(user_message, top_k=5)

        # Search conversation history for related content
        conversation_results = []
        if hasattr(assistant.enhanced, "conversation_memory"):
            history_results = assistant.enhanced.conversation_memory.search_conversations(
                user_message, limit=3
            )
            for conv in history_results:
                # Extract relevant messages from conversation
                conv_messages = assistant.enhanced.conversation_memory.get_conversation_messages(
                    conv["conversation_id"]
                )
                for msg in conv_messages[-5:]:  # Last 5 messages from each conversation
                    if msg["role"] == "user" and msg["content"] != user_message:
                        # Calculate relevance score
                        relevance = calculate_text_similarity(user_message, msg["content"])
                        if relevance > 0.3:  # Only include relevant matches
                            conversation_results.append(
                                {
                                    "content": msg["content"],
                                    "score": relevance,
                                    "source": f"Previous conversation: {conv['conversation_id'][:8]}",
                                }
                            )

        # Combine and sort results by relevance score
        all_results = []

        # Add knowledge base results
        for entry in knowledge_results:
            score = entry.metadata.get("search_score", 0.5)
            all_results.append(
                {
                    "content": entry.content,
                    "score": score,
                    "source": f"Knowledge Base ({entry.category})",
                    "type": "knowledge",
                }
            )

        # Add conversation results
        for conv in conversation_results:
            all_results.append(
                {
                    "content": conv["content"],
                    "score": conv["score"],
                    "source": conv["source"],
                    "type": "conversation",
                }
            )

        # Sort by score and take top 5
        all_results.sort(key=lambda x: x["score"], reverse=True)
        top_results = all_results[:5]

        if not top_results:
            logger.debug("No relevant knowledge or conversation matches found")
            return

        # Format context for LLM
        knowledge_parts = []
        for idx, result in enumerate(top_results, 1):
            content = result["content"]
            if len(content) > 1500:  # Shorter limit for multiple results
                content = content[:1500] + "..."

            score_indicator = f"({result['score']:.2f})" if result["score"] < 1.0 else "(exact)"
            knowledge_parts.append(
                f"Match {idx} {score_indicator} - {result['source']}:\n{content}"
            )

        knowledge_message_content = (
            f"{KNOWLEDGE_MESSAGE_MARKER}\nRelevant information from knowledge base and conversation history:\n\n"
            + "\n\n".join(knowledge_parts)
        )

        knowledge_message = {"role": "user", "content": knowledge_message_content}

        messages.append(knowledge_message)
        logger.debug(f"Injected enhanced context message with {len(top_results)} matches")

    except Exception as e:
        logger.error(f"Error injecting knowledge context: {e}")


def calculate_text_similarity(text1: str, text2: str) -> float:
    """Calculate similarity between two texts using word overlap and sequence matching."""
    import re

    # Normalize texts
    text1_lower = text1.lower()
    text2_lower = text2.lower()

    # Exact substring match gets highest score
    if text1_lower in text2_lower or text2_lower in text1_lower:
        return 1.0

    # Word-level similarity
    words1 = set(re.findall(r"\b\w+\b", text1_lower))
    words2 = set(re.findall(r"\b\w+\b", text2_lower))

    if not words1 or not words2:
        return 0.0

    intersection = words1 & words2
    union = words1 | words2

    word_similarity = len(intersection) / len(union)

    # Bonus for consecutive word sequences (partial sentences)
    consecutive_bonus = 0.0
    # Use tokens in their original order; a set would scramble phrase order
    words1_list = re.findall(r"\b\w+\b", text1_lower)

    for i in range(len(words1_list) - 1):
        for j in range(i + 2, min(i + 5, len(words1_list) + 1)):
            phrase = " ".join(words1_list[i:j])
            if phrase in text2_lower:
                consecutive_bonus += 0.1 * (j - i)

    total_similarity = min(1.0, word_similarity + consecutive_bonus)

    return total_similarity
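Illustrative scores, following directly from the rules above (inputs are assumed):

calculate_text_similarity("deploy the app", "how do i deploy the app?")  # 1.0 - exact substring
calculate_text_similarity("deploy app", "app deployment guide")          # 0.25 - one shared word of four unique
calculate_text_similarity("", "anything")                                # 0.0 - no words on one side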
@@ -9,8 +9,10 @@ logger = get_logger("usage")

USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json")

EXCHANGE_RATE = 1.0  # Keep in USD

MODEL_COSTS = {
    "x-ai/grok-code-fast-1": {"input": 0.0, "output": 0.0},
    "x-ai/grok-code-fast-1": {"input": 0.0002, "output": 0.0015},  # per 1000 tokens in USD
    "gpt-4": {"input": 0.03, "output": 0.06},
    "gpt-4-turbo": {"input": 0.01, "output": 0.03},
    "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
@@ -64,9 +66,10 @@ class UsageTracker:

        self._save_to_history(model, input_tokens, output_tokens, cost)

        logger.debug(f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}")
        logger.debug(f"Tracked request: {model}, tokens: {total_tokens}, cost: €{cost:.4f}")

    def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
    @staticmethod
    def _calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
        if model not in MODEL_COSTS:
            base_model = model.split("/")[0] if "/" in model else model
            if base_model not in MODEL_COSTS:
@@ -76,8 +79,8 @@ class UsageTracker:
        else:
            costs = MODEL_COSTS[model]

        input_cost = (input_tokens / 1000) * costs["input"]
        output_cost = (output_tokens / 1000) * costs["output"]
        input_cost = (input_tokens / 1000) * costs["input"] * EXCHANGE_RATE
        output_cost = (output_tokens / 1000) * costs["output"] * EXCHANGE_RATE

        return input_cost + output_cost


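A worked example with the new non-zero grok pricing (rates are per 1,000 tokens; EXCHANGE_RATE = 1.0 leaves the value unchanged):

# 12,000 prompt tokens + 2,000 completion tokens on x-ai/grok-code-fast-1:
input_cost = (12000 / 1000) * 0.0002   # 0.0024
output_cost = (2000 / 1000) * 0.0015   # 0.0030
total = input_cost + output_cost       # 0.0054, displayed with a euro sign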
@@ -370,8 +370,10 @@ class RPEditor:
                    self.cursor_x = 0
                elif key == ord("u"):
                    self.undo()
                elif key == ord("r") and self.prev_key == 18:  # Ctrl-R
                elif key == 18:  # Ctrl-R
                    self.redo()
                elif key == 19:  # Ctrl-S
                    self._save_file()

                self.prev_key = key
            except Exception:

589 pr/implode.py Normal file
@@ -0,0 +1,589 @@
#!/usr/bin/env python3

"""
impLODE: A Python script to consolidate a multi-file Python project
into a single, runnable file.

It intelligently resolves local imports, hoists external dependencies to the top,
and preserves the core logic, using AST for safe transformations.
"""

import os
import sys
import ast
import argparse
import logging
import py_compile  # Added
from typing import Set, Dict, Optional, TextIO

# Set up the main logger
logger = logging.getLogger("impLODE")


class ImportTransformer(ast.NodeTransformer):
    """
    An AST transformer that visits Import and ImportFrom nodes.

    On Pass 1 (Dry Run):
    - Identifies all local vs. external imports.
    - Recursively calls the main resolver for local modules.
    - Stores external and __future__ imports in the Imploder instance.

    On Pass 2 (Write Run):
    - Recursively calls the main resolver for local modules.
    - Removes all import statements (since they were hoisted in Pass 1).
    """

    def __init__(
        self,
        imploder: "Imploder",
        current_file_path: str,
        f_out: Optional[TextIO],
        is_dry_run: bool,
        indent_level: int = 0,
    ):
        self.imploder = imploder
        self.current_file_path = current_file_path
        self.current_dir = os.path.dirname(current_file_path)
        self.f_out = f_out
        self.is_dry_run = is_dry_run
        self.indent = " " * indent_level
        self.logger = logging.getLogger(self.__class__.__name__)

    def _log_debug(self, msg: str):
        """Helper for indented debug logging."""
        self.logger.debug(f"{self.indent} > {msg}")

    def _find_local_module(self, module_name: str, level: int) -> Optional[str]:
        """
        Tries to find the absolute path for a given module name and relative level.
        Returns None if it's not a local module *and* cannot be found in site-packages.
        """
        import importlib.util  # Added
        import importlib.machinery  # Added

        if not module_name:
            # Handle cases like `from . import foo`
            # The module_name is 'foo', handled by visit_ImportFrom's aliases.
            # If it's just `from . import`, module_name is None.
            # We'll resolve from the alias names.

            # Determine base path for relative import
            base_path = self.current_dir
            if level > 0:
                # Go up 'level' directories (minus one for current dir)
                for _ in range(level - 1):
                    base_path = os.path.dirname(base_path)
            return base_path

        base_path = self.current_dir
        if level > 0:
            # Go up 'level' directories (minus one for current dir)
            for _ in range(level - 1):
                base_path = os.path.dirname(base_path)
        else:
            # Absolute import, start from the root
            base_path = self.imploder.root_dir

        module_parts = module_name.split(".")
        module_path = os.path.join(base_path, *module_parts)

        # Check for package `.../module/__init__.py`
        package_init = os.path.join(module_path, "__init__.py")
        if os.path.isfile(package_init):
            self._log_debug(
                f"Resolved '{module_name}' to local package: {os.path.relpath(package_init, self.imploder.root_dir)}"
            )
            return package_init

        # Check for module `.../module.py`
        module_py = module_path + ".py"
        if os.path.isfile(module_py):
            self._log_debug(
                f"Resolved '{module_name}' to local module: {os.path.relpath(module_py, self.imploder.root_dir)}"
            )
            return module_py

        # --- FALLBACK: Deep search from root ---
        # This is less efficient but helps with non-standard project structures
        # where the "root" for imports isn't the same as the Imploder's root.
        if level == 0:
            self._log_debug(
                f"Module '{module_name}' not found at primary path. Starting deep fallback search from {self.imploder.root_dir}..."
            )

            target_path_py = os.path.join(*module_parts) + ".py"
            target_path_init = os.path.join(*module_parts, "__init__.py")

            for dirpath, dirnames, filenames in os.walk(self.imploder.root_dir, topdown=True):
                # Prune common virtual envs and metadata dirs to speed up search
                dirnames[:] = [
                    d
                    for d in dirnames
                    if not d.startswith(".")
                    and d not in ("venv", "env", ".venv", ".env", "__pycache__", "node_modules")
                ]

                # Check for .../module/file.py
                check_file_py = os.path.join(dirpath, target_path_py)
                if os.path.isfile(check_file_py):
                    self._log_debug(
                        f"Fallback search found module: {os.path.relpath(check_file_py, self.imploder.root_dir)}"
                    )
                    return check_file_py

                # Check for .../module/file/__init__.py
                check_file_init = os.path.join(dirpath, target_path_init)
                if os.path.isfile(check_file_init):
                    self._log_debug(
                        f"Fallback search found package: {os.path.relpath(check_file_init, self.imploder.root_dir)}"
                    )
                    return check_file_init

        # --- Removed site-packages/external module inlining logic ---
        # As requested, the script will now only resolve local modules.

        # Not found locally
        return None

    def visit_Import(self, node: ast.Import) -> Optional[ast.AST]:
        """Handles `import foo` or `import foo.bar`."""
        for alias in node.names:
            module_path = self._find_local_module(alias.name, level=0)

            if module_path:
                # Local module
                self._log_debug(f"Resolving local import: `import {alias.name}`")
                self.imploder.resolve_file(
                    file_abs_path=module_path,
                    f_out=self.f_out,
                    is_dry_run=self.is_dry_run,
                    indent_level=self.imploder.current_indent_level,
                )
            else:
                # External module
                self._log_debug(f"Found external import: `import {alias.name}`")
                if self.is_dry_run:
                    # On dry run, collect it for hoisting
                    key = f"import {alias.name}"
                    if key not in self.imploder.external_imports:
                        self.imploder.external_imports[key] = node

        # On Pass 2 (write run), we remove this node by replacing it
        # with a dummy function call to prevent IndentationErrors.
        # On Pass 1 (dry run), we also replace it.
        module_names = ", ".join([a.name for a in node.names])
        new_call = ast.Call(
            func=ast.Name(id="_implode_log_import", ctx=ast.Load()),
            args=[
                ast.Constant(value=module_names),
                ast.Constant(value=0),
                ast.Constant(value="import"),
            ],
            keywords=[],
        )
        return ast.Expr(value=new_call)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> Optional[ast.AST]:
        """Handles `from foo import bar` or `from .foo import bar`."""

        # --- NEW: Create the replacement node first ---
        module_name_str = node.module or ""
        import_type = "from-import"
        if module_name_str == "__future__":
            import_type = "future-import"

        new_call = ast.Call(
            func=ast.Name(id="_implode_log_import", ctx=ast.Load()),
            args=[
                ast.Constant(value=module_name_str),
                ast.Constant(value=node.level),
                ast.Constant(value=import_type),
            ],
            keywords=[],
        )
        replacement_node = ast.Expr(value=new_call)

        # Handle `from __future__ import ...`
        if node.module == "__future__":
            self._log_debug("Found __future__ import. Hoisting to top.")
            if self.is_dry_run:
                key = ast.unparse(node)
                self.imploder.future_imports[key] = node
            return replacement_node  # Return replacement

        module_path = self._find_local_module(node.module or "", node.level)

        if module_path and os.path.isdir(module_path):
            # This is a package import like `from . import foo`
            # `module_path` is the directory of the package (e.g., self.current_dir)
            self._log_debug(f"Resolving package import: `from {node.module or '.'} import ...`")
            for alias in node.names:
                # --- FIX ---
                # Resolve `foo` relative to the package directory (`module_path`)

                # Check for module `.../package_dir/foo.py`
                package_module_py = os.path.join(module_path, alias.name + ".py")

                # Check for sub-package `.../package_dir/foo/__init__.py`
                package_module_init = os.path.join(module_path, alias.name, "__init__.py")

                if os.path.isfile(package_module_py):
                    self._log_debug(
                        f"Found sub-module: {os.path.relpath(package_module_py, self.imploder.root_dir)}"
                    )
                    self.imploder.resolve_file(
                        file_abs_path=package_module_py,
                        f_out=self.f_out,
                        is_dry_run=self.is_dry_run,
                        indent_level=self.imploder.current_indent_level,
                    )
                elif os.path.isfile(package_module_init):
                    self._log_debug(
                        f"Found sub-package: {os.path.relpath(package_module_init, self.imploder.root_dir)}"
                    )
                    self.imploder.resolve_file(
                        file_abs_path=package_module_init,
                        f_out=self.f_out,
                        is_dry_run=self.is_dry_run,
                        indent_level=self.imploder.current_indent_level,
                    )
                else:
                    # Could be another sub-package, etc.
                    # This logic can be expanded, but for now, log a warning
                    self.logger.warning(
                        f"{self.indent} > Could not resolve sub-module '{alias.name}' in package '{module_path}'"
                    )
            return replacement_node  # Return replacement
        # --- END FIX ---

        if module_path:
            # Local module (e.g., `from .foo import bar` or `from myapp.foo import bar`)
            self._log_debug(f"Resolving local from-import: `from {node.module or '.'} ...`")
            self.imploder.resolve_file(
                file_abs_path=module_path,
                f_out=self.f_out,
                is_dry_run=self.is_dry_run,
                indent_level=self.imploder.current_indent_level,
            )
        else:
            # External module
            self._log_debug(f"Found external from-import: `from {node.module or '.'} ...`")
            if self.is_dry_run:
                # On dry run, collect it for hoisting
                key = ast.unparse(node)
                if key not in self.imploder.external_imports:
                    self.imploder.external_imports[key] = node

        # On Pass 2 (write run), we remove this node entirely by replacing it
        # On Pass 1 (dry run), we also replace it
        return replacement_node


class Imploder:
    """
    Core class for handling the implosion process.
    Manages state, file processing, and the two-pass analysis.
    """

    def __init__(self, root_dir: str, enable_import_logging: bool = False):
        self.root_dir = os.path.realpath(root_dir)
        self.processed_files: Set[str] = set()
        self.external_imports: Dict[str, ast.AST] = {}
        self.future_imports: Dict[str, ast.AST] = {}
        self.current_indent_level = 0
        self.enable_import_logging = enable_import_logging  # Added
        logger.info(f"Initialized Imploder with root: {self.root_dir}")

    def implode(self, main_file_abs_path: str, output_file_path: str):
        """
        Runs the full two-pass implosion process.
        """
        if not os.path.isfile(main_file_abs_path):
            logger.critical(f"Main file not found: {main_file_abs_path}")
            sys.exit(1)

        # --- Pass 1: Analyze Dependencies (Dry Run) ---
        logger.info(
            f"--- PASS 1: Analyzing dependencies from {os.path.relpath(main_file_abs_path, self.root_dir)} ---"
        )
        self.processed_files.clear()
        self.external_imports.clear()
        self.future_imports.clear()

        try:
            self.resolve_file(main_file_abs_path, f_out=None, is_dry_run=True, indent_level=0)
        except Exception as e:
            logger.critical(f"Error during analysis pass: {e}", exc_info=True)
            sys.exit(1)

        logger.info(
            f"--- Analysis complete. Found {len(self.future_imports)} __future__ imports and {len(self.external_imports)} external modules. ---"
        )

        # --- Pass 2: Write Imploded File ---
        logger.info(f"--- PASS 2: Writing imploded file to {output_file_path} ---")
        self.processed_files.clear()

        try:
            with open(output_file_path, "w", encoding="utf-8") as f_out:
                # Write headers
                f_out.write(f"#!/usr/bin/env python3\n")

                f_out.write(f"# -*- coding: utf-8 -*-\n")
                f_out.write(f"import logging\n")
                f_out.write(f"\n# --- IMPLODED FILE: Generated by impLODE --- #\n")
                f_out.write(
                    f"# --- Original main file: {os.path.relpath(main_file_abs_path, self.root_dir)} --- #\n"
                )

                # Write __future__ imports
                if self.future_imports:
                    f_out.write("\n# --- Hoisted __future__ Imports --- #\n")
                    for node in self.future_imports.values():
                        f_out.write(f"{ast.unparse(node)}\n")
                    f_out.write("# --- End __future__ Imports --- #\n")

                # --- NEW: Write Helper Function ---
                enable_logging_str = "True" if self.enable_import_logging else "False"
                f_out.write("\n# --- impLODE Helper Function --- #\n")
                f_out.write(f"_IMPLODE_LOGGING_ENABLED_ = {enable_logging_str}\n")
                f_out.write("def _implode_log_import(module_name, level, import_type):\n")
                f_out.write(
                    '    """Dummy function to replace imports and prevent IndentationErrors."""\n'
                )
                f_out.write("    if _IMPLODE_LOGGING_ENABLED_:\n")
                f_out.write(
                    "        print(f\"[impLODE Logger]: Skipped {import_type}: module='{module_name}', level={level}\")\n"
                )
                f_out.write("    pass\n")
                f_out.write("# --- End Helper Function --- #\n")

                # Write external imports
                if self.external_imports:
                    f_out.write("\n# --- Hoisted External Imports --- #\n")
                    for node in self.external_imports.values():
                        # As requested, wrap each external import in a try/except block
                        f_out.write("try:\n")
                        f_out.write(f"    {ast.unparse(node)}\n")
                        f_out.write("except ImportError:\n")
                        f_out.write("    pass\n")
                    f_out.write("# --- End External Imports --- #\n")

                # Write the actual code
                self.resolve_file(main_file_abs_path, f_out=f_out, is_dry_run=False, indent_level=0)

        except IOError as e:
            logger.critical(
                f"Could not write to output file {output_file_path}: {e}", exc_info=True
            )
            sys.exit(1)
        except Exception as e:
            logger.critical(f"Error during write pass: {e}", exc_info=True)
            sys.exit(1)

        logger.info(f"--- Implosion complete! Output saved to {output_file_path} ---")

    def resolve_file(
        self, file_abs_path: str, f_out: Optional[TextIO], is_dry_run: bool, indent_level: int = 0
    ):
        """
        Recursively resolves a single file.
        - `is_dry_run=True`: Analyzes imports, populating `external_imports`.
        - `is_dry_run=False`: Writes transformed code to `f_out`.
        """
        self.current_indent_level = indent_level
        indent = " " * indent_level

        try:
            file_real_path = os.path.realpath(file_abs_path)
            # --- Reverted path logic as we only handle local files ---
            rel_path = os.path.relpath(file_real_path, self.root_dir)
            # --- End Revert ---
        except ValueError:
            # Happens if file is on a different drive than root_dir on Windows
            logger.warning(
                f"{indent}Cannot calculate relative path for {file_abs_path}. Using absolute."
            )
            rel_path = file_abs_path

        # --- 1. Circular Dependency Check ---
        if file_real_path in self.processed_files:
            logger.debug(f"{indent}Skipping already processed file: {rel_path}")
            return

        logger.info(f"{indent}Processing: {rel_path}")
        self.processed_files.add(file_real_path)

        # --- 2. Read File ---
        try:
            with open(file_real_path, "r", encoding="utf-8") as f:
                code = f.read()
        except FileNotFoundError:
            logger.error(f"{indent}File not found: {file_real_path}")
            return
        except UnicodeDecodeError:
            logger.error(f"{indent}Could not decode file (not utf-8): {file_real_path}")
            return
        except Exception as e:
            logger.error(f"{indent}Could not read file {file_real_path}: {e}")
            return

        # --- 2.5. Validate Syntax with py_compile ---
        try:
            # As requested, validate syntax using py_compile before AST parsing
            py_compile.compile(file_real_path, doraise=True, quiet=1)
            logger.debug(f"{indent}Syntax OK (py_compile): {rel_path}")
        except py_compile.PyCompileError as e:
            # This error includes file, line, and message
            logger.error(
                f"{indent}Syntax error (py_compile) in {e.file} on line {e.lineno}: {e.msg}"
            )
            return
        except Exception as e:
            # Catch other potential errors like permission issues
            logger.error(f"{indent}Error during py_compile for {rel_path}: {e}")
            return

        # --- 3. Parse AST ---
        try:
            tree = ast.parse(code, filename=file_real_path)
        except SyntaxError as e:
            logger.error(f"{indent}Syntax error in {rel_path} on line {e.lineno}: {e.msg}")
            return
        except Exception as e:
            logger.error(f"{indent}Could not parse AST for {rel_path}: {e}")
            return

        # --- 4. Transform AST ---
        transformer = ImportTransformer(
            imploder=self,
            current_file_path=file_real_path,
            f_out=f_out,
            is_dry_run=is_dry_run,
            indent_level=indent_level,
        )

        try:
            new_tree = transformer.visit(tree)
        except Exception as e:
            logger.error(f"{indent}Error transforming AST for {rel_path}: {e}", exc_info=True)
            return

        # --- 5. Write Content (on Pass 2) ---
        if not is_dry_run and f_out:
            try:
                ast.fix_missing_locations(new_tree)
                f_out.write(f"\n\n# --- Content from {rel_path} --- #\n")
                f_out.write(ast.unparse(new_tree))
                f_out.write(f"\n# --- End of {rel_path} --- #\n")
                logger.info(f"{indent}Successfully wrote content from: {rel_path}")
            except Exception as e:
                logger.error(
                    f"{indent}Could not unparse or write AST for {rel_path}: {e}", exc_info=True
                )

        self.current_indent_level = indent_level


def setup_logging(level: int):
    """Configures the root logger."""
    handler = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    handler.setFormatter(formatter)

    # Configure the 'impLODE' logger
    logger.setLevel(level)
    logger.addHandler(handler)

    # Configure the 'ImportTransformer' logger
    transformer_logger = logging.getLogger("ImportTransformer")
    transformer_logger.setLevel(level)
    transformer_logger.addHandler(handler)


def main():
    """Main entry point for the script."""
    parser = argparse.ArgumentParser(
        description="impLODE: Consolidate a multi-file Python project into one file.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "main_file", type=str, help="The main entry point .py file of your project."
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="imploded.py",
        help="Path for the combined output file. (default: imploded.py)",
    )
    parser.add_argument(
        "-r",
        "--root",
        type=str,
        default=".",
        help="The root directory of the project for resolving absolute imports. (default: current directory)",
    )

    log_group = parser.add_mutually_exclusive_group()
    log_group.add_argument(
        "-v",
        "--verbose",
        action="store_const",
        dest="log_level",
        const=logging.DEBUG,
        default=logging.INFO,
        help="Enable verbose DEBUG logging.",
    )
    log_group.add_argument(
        "-q",
        "--quiet",
        action="store_const",
        dest="log_level",
        const=logging.WARNING,
        help="Suppress INFO logs, showing only WARNINGS and ERRORS.",
    )

    parser.add_argument(
        "--enable-import-logging",
        action="store_true",
        help="Enable runtime logging for removed import statements in the final imploded script.",
    )

    args = parser.parse_args()

    # --- Setup ---
    setup_logging(args.log_level)

    root_dir = os.path.abspath(args.root)
    main_file_path = os.path.abspath(args.main_file)
    output_file_path = os.path.abspath(args.output)

    # --- Validation ---
    if not os.path.isdir(root_dir):
        logger.critical(f"Root directory not found: {root_dir}")
        sys.exit(1)

    if not os.path.isfile(main_file_path):
        logger.critical(f"Main file not found: {main_file_path}")
        sys.exit(1)

    if not main_file_path.startswith(root_dir):
        logger.warning(f"Main file {main_file_path} is outside the specified root {root_dir}.")
        logger.warning("This may cause issues with absolute import resolution.")

    if main_file_path == output_file_path:
        logger.critical("Output file cannot be the same as the main file.")
        sys.exit(1)

    # --- Run ---
    imploder = Imploder(
        root_dir=root_dir, enable_import_logging=args.enable_import_logging  # Added
    )
    imploder.implode(main_file_abs_path=main_file_path, output_file_path=output_file_path)


if __name__ == "__main__":
    main()
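A typical invocation, assuming a project rooted in the current directory with main.py as its entry point (paths are illustrative):

python pr/implode.py main.py -o imploded.py -r . -v --enable-import-logging

Pass 1 walks the import graph and collects the external imports to hoist; Pass 2 writes the single imploded.py with local modules inlined recursively.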
@@ -56,10 +56,14 @@ class FactExtractor:
            else:
                if len(current_phrase) >= 2:
                    phrases.append(" ".join(current_phrase))
                elif len(current_phrase) == 1:
                    phrases.append(current_phrase[0])  # Single capitalized words
                current_phrase = []

        if len(current_phrase) >= 2:
            phrases.append(" ".join(current_phrase))
        elif len(current_phrase) == 1:
            phrases.append(current_phrase[0])  # Single capitalized words

        return list(set(phrases))

@@ -165,13 +169,14 @@ class FactExtractor:
        return relationships

    def extract_metadata(self, text: str) -> Dict[str, Any]:
        word_count = len(text.split())
        sentence_count = len(re.split(r"[.!?]", text))
        word_count = len(text.split()) if text.strip() else 0
        sentences = re.split(r"[.!?]", text.strip())
        sentence_count = len([s for s in sentences if s.strip()]) if text.strip() else 0

        urls = re.findall(r"https?://[^\s]+", text)
        email_addresses = re.findall(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text)
        dates = re.findall(
            r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text
            r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b|\b\d{4}\b", text
        )
        numbers = re.findall(r"\b\d+(?:,\d{3})*(?:\.\d+)?\b", text)


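The widened dates pattern now also captures bare four-digit years; a quick check (the example string is hypothetical):

import re
pattern = r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b|\b\d{4}\b"
re.findall(pattern, "Released 2025-11-05, announced in 2024")
# ['2025-11-05', '2024']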
@@ -1,8 +1,9 @@
import json
import sqlite3
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

from .semantic_index import SemanticIndex

@@ -35,11 +36,13 @@ class KnowledgeStore:
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self.lock = threading.Lock()
        self.semantic_index = SemanticIndex()
        self._initialize_store()
        self._load_index()

    def _initialize_store(self):
        with self.lock:
            cursor = self.conn.cursor()

            cursor.execute(
@@ -76,6 +79,7 @@ class KnowledgeStore:
        self.conn.commit()

    def _load_index(self):
        with self.lock:
            cursor = self.conn.cursor()

            cursor.execute("SELECT entry_id, content FROM knowledge_entries")
@@ -83,6 +87,7 @@ class KnowledgeStore:
                self.semantic_index.add_document(row[0], row[1])

    def add_entry(self, entry: KnowledgeEntry):
        with self.lock:
            cursor = self.conn.cursor()

            cursor.execute(
@@ -108,6 +113,7 @@ class KnowledgeStore:
        self.semantic_index.add_document(entry.entry_id, entry.content)

    def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
        with self.lock:
            cursor = self.conn.cursor()

            cursor.execute(
@@ -148,12 +154,34 @@ class KnowledgeStore:
    def search_entries(
        self, query: str, category: Optional[str] = None, top_k: int = 5
    ) -> List[KnowledgeEntry]:
        search_results = self.semantic_index.search(query, top_k * 2)
        # Combine semantic search with exact matching
        semantic_results = self.semantic_index.search(query, top_k * 2)

        # Add FTS (Full Text Search) with exact word/phrase matching
        fts_results = self._fts_search(query, top_k * 2)

        # Combine and deduplicate results with weighted scoring
        combined_results = {}

        # Add semantic results with weight 0.7
        for entry_id, score in semantic_results:
            combined_results[entry_id] = score * 0.7

        # Add FTS results with weight 1.0 (higher priority for exact matches)
        for entry_id, score in fts_results:
            if entry_id in combined_results:
                combined_results[entry_id] = max(combined_results[entry_id], score * 1.0)
            else:
                combined_results[entry_id] = score * 1.0

        # Sort by combined score
        sorted_results = sorted(combined_results.items(), key=lambda x: x[1], reverse=True)

        with self.lock:
            cursor = self.conn.cursor()

            entries = []
            for entry_id, score in search_results:
            for entry_id, score in sorted_results[:top_k]:
                if category:
                    cursor.execute(
                        """
@@ -185,14 +213,71 @@ class KnowledgeStore:
                    access_count=row[6],
                    importance_score=row[7],
                )
                # Add search score to metadata for context
                entry.metadata["search_score"] = score
                entries.append(entry)

                if len(entries) >= top_k:
                    break

        return entries

    def _fts_search(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Full Text Search with exact word and partial sentence matching."""
        with self.lock:
            cursor = self.conn.cursor()

            # Prepare query for FTS
            query_lower = query.lower()
            query_words = query_lower.split()

            # Search for exact phrase matches first
            cursor.execute(
                """
                SELECT entry_id, content
                FROM knowledge_entries
                WHERE LOWER(content) LIKE ?
                """,
                (f"%{query_lower}%",),
            )

            exact_matches = []
            partial_matches = []

            for row in cursor.fetchall():
                entry_id, content = row
                content_lower = content.lower()

                # Exact phrase match gets highest score
                if query_lower in content_lower:
                    exact_matches.append((entry_id, 1.0))
                    continue

                # Count matching words
                content_words = set(content_lower.split())
                query_word_set = set(query_words)
                matching_words = len(query_word_set & content_words)

                if matching_words > 0:
                    # Score based on word overlap and position
                    word_overlap_score = matching_words / len(query_word_set)

                    # Bonus for consecutive word sequences
                    consecutive_bonus = 0.0
                    for i in range(len(query_words)):
                        for j in range(i + 1, min(i + 4, len(query_words) + 1)):
                            phrase = " ".join(query_words[i:j])
                            if phrase in content_lower:
                                consecutive_bonus += 0.2 * (j - i)

                    total_score = min(0.99, word_overlap_score + consecutive_bonus)
                    partial_matches.append((entry_id, total_score))

            # Combine results, prioritizing exact matches
            all_results = exact_matches + partial_matches
            all_results.sort(key=lambda x: x[1], reverse=True)

            return all_results[:top_k]
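How the two result sets combine, with hypothetical scores:

# Semantic hits are discounted to 0.7x; FTS hits keep full weight, and an entry
# found by both searches keeps whichever weighted score is higher:
semantic = [("e1", 0.9), ("e2", 0.4)]   # weighted -> e1: 0.63, e2: 0.28
fts = [("e1", 0.5), ("e3", 1.0)]        # weighted -> e1: 0.50, e3: 1.00
# combined ranking: e3 = 1.00 (exact phrase), e1 = max(0.63, 0.50) = 0.63, e2 = 0.28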

    def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
        with self.lock:
            cursor = self.conn.cursor()

            cursor.execute(
@@ -224,6 +309,7 @@ class KnowledgeStore:
        return entries

    def update_importance(self, entry_id: str, importance_score: float):
        with self.lock:
            cursor = self.conn.cursor()

            cursor.execute(
@@ -238,6 +324,7 @@ class KnowledgeStore:
        self.conn.commit()

    def delete_entry(self, entry_id: str) -> bool:
        with self.lock:
            cursor = self.conn.cursor()

            cursor.execute("DELETE FROM knowledge_entries WHERE entry_id = ?", (entry_id,))
@@ -251,6 +338,7 @@ class KnowledgeStore:
        return deleted

    def get_statistics(self) -> Dict[str, Any]:
        with self.lock:
            cursor = self.conn.cursor()

            cursor.execute("SELECT COUNT(*) FROM knowledge_entries")

@@ -9,7 +9,7 @@ class SemanticIndex:
        self.documents: Dict[str, str] = {}
        self.vocabulary: Set[str] = set()
        self.idf_scores: Dict[str, float] = {}
        self.doc_vectors: Dict[str, Dict[str, float]] = {}
        self.doc_tf_scores: Dict[str, Dict[str, float]] = {}

    def _tokenize(self, text: str) -> List[str]:
        text = text.lower()
@@ -46,22 +46,26 @@ class SemanticIndex:
        tokens = self._tokenize(text)
        self.vocabulary.update(tokens)

        self._compute_idf()

        tf_scores = self._compute_tf(tokens)
        self.doc_vectors[doc_id] = {
            token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0) for token in tokens
        }
        self.doc_tf_scores[doc_id] = tf_scores

        self._compute_idf()

    def remove_document(self, doc_id: str):
        if doc_id in self.documents:
            del self.documents[doc_id]
        if doc_id in self.doc_vectors:
            del self.doc_vectors[doc_id]
        if doc_id in self.doc_tf_scores:
            del self.doc_tf_scores[doc_id]
        self._compute_idf()

    def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        if not query.strip():
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        query_tf = self._compute_tf(query_tokens)

        query_vector = {
@@ -69,7 +73,10 @@ class SemanticIndex:
        }

        scores = []
        for doc_id, doc_vector in self.doc_vectors.items():
        for doc_id, doc_tf in self.doc_tf_scores.items():
            doc_vector = {
                token: doc_tf.get(token, 0) * self.idf_scores.get(token, 0) for token in doc_tf
            }
            similarity = self._cosine_similarity(query_vector, doc_vector)
            scores.append((doc_id, similarity))


384 pr/multiplexer.py.bak Normal file
@@ -0,0 +1,384 @@
import queue
import subprocess
import sys
import threading
import time

from pr.tools.process_handlers import detect_process_type, get_handler_for_process
from pr.tools.prompt_detection import get_global_detector
from pr.ui import Colors


class TerminalMultiplexer:
    def __init__(self, name, show_output=True):
        self.name = name
        self.show_output = show_output
        self.stdout_buffer = []
        self.stderr_buffer = []
        self.stdout_queue = queue.Queue()
        self.stderr_queue = queue.Queue()
        self.active = True
        self.lock = threading.Lock()
        self.metadata = {
            "start_time": time.time(),
            "last_activity": time.time(),
            "interaction_count": 0,
            "process_type": "unknown",
            "state": "active",
        }
        self.handler = None
        self.prompt_detector = get_global_detector()

        if self.show_output:
            self.display_thread = threading.Thread(target=self._display_worker, daemon=True)
            self.display_thread.start()

    def _display_worker(self):
        while self.active:
            try:
                line = self.stdout_queue.get(timeout=0.1)
                if line:
                    if self.metadata.get("process_type") in ["vim", "ssh"]:
                        sys.stdout.write(line)
                    else:
                        sys.stdout.write(f"{Colors.GRAY}[{self.name}]{Colors.RESET} {line}\n")
                    sys.stdout.flush()
            except queue.Empty:
                pass

            try:
                line = self.stderr_queue.get(timeout=0.1)
                if line:
                    if self.metadata.get("process_type") in ["vim", "ssh"]:
                        sys.stderr.write(line)
                    else:
                        sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}\n")
                    sys.stderr.flush()
            except queue.Empty:
                pass

    def write_stdout(self, data):
        with self.lock:
            self.stdout_buffer.append(data)
            self.metadata["last_activity"] = time.time()
            # Update handler state if available
            if self.handler:
                self.handler.update_state(data)
            # Update prompt detector
            self.prompt_detector.update_session_state(
                self.name, data, self.metadata["process_type"]
            )
        if self.show_output:
            self.stdout_queue.put(data)

    def write_stderr(self, data):
        with self.lock:
            self.stderr_buffer.append(data)
            self.metadata["last_activity"] = time.time()
            # Update handler state if available
            if self.handler:
                self.handler.update_state(data)
            # Update prompt detector
            self.prompt_detector.update_session_state(
                self.name, data, self.metadata["process_type"]
            )
        if self.show_output:
            self.stderr_queue.put(data)

    def get_stdout(self):
        with self.lock:
            return "".join(self.stdout_buffer)

    def get_stderr(self):
        with self.lock:
            return "".join(self.stderr_buffer)

    def get_all_output(self):
        with self.lock:
            return {
                "stdout": "".join(self.stdout_buffer),
                "stderr": "".join(self.stderr_buffer),
            }

    def get_metadata(self):
        with self.lock:
            return self.metadata.copy()

    def update_metadata(self, key, value):
        with self.lock:
            self.metadata[key] = value

    def set_process_type(self, process_type):
        """Set the process type and initialize appropriate handler."""
        with self.lock:
            self.metadata["process_type"] = process_type
            self.handler = get_handler_for_process(process_type, self)

    def send_input(self, input_data):
        if hasattr(self, "process") and self.process.poll() is None:
            try:
                self.process.stdin.write(input_data + "\n")
                self.process.stdin.flush()
                with self.lock:
                    self.metadata["last_activity"] = time.time()
                    self.metadata["interaction_count"] += 1
            except Exception as e:
|
||||
self.write_stderr(f"Error sending input: {e}")
|
||||
else:
|
||||
# This will be implemented when we have a process attached
|
||||
# For now, just update activity
|
||||
with self.lock:
|
||||
self.metadata["last_activity"] = time.time()
|
||||
self.metadata["interaction_count"] += 1
|
||||
|
||||
def close(self):
|
||||
self.active = False
|
||||
if hasattr(self, "display_thread"):
|
||||
self.display_thread.join(timeout=1)
|
||||
|
||||
|
||||
_multiplexers = {}
|
||||
_mux_counter = 0
|
||||
_mux_lock = threading.Lock()
|
||||
_background_monitor = None
|
||||
_monitor_active = False
|
||||
_monitor_interval = 0.2 # 200ms
|
||||
|
||||
|
||||
def create_multiplexer(name=None, show_output=True):
|
||||
global _mux_counter
|
||||
with _mux_lock:
|
||||
if name is None:
|
||||
_mux_counter += 1
|
||||
name = f"process-{_mux_counter}"
|
||||
mux = TerminalMultiplexer(name, show_output)
|
||||
_multiplexers[name] = mux
|
||||
return name, mux
|
||||
|
||||
|
||||
def get_multiplexer(name):
|
||||
return _multiplexers.get(name)
|
||||
|
||||
|
||||
def close_multiplexer(name):
|
||||
mux = _multiplexers.get(name)
|
||||
if mux:
|
||||
mux.close()
|
||||
del _multiplexers[name]
|
||||
|
||||
|
||||
def get_all_multiplexer_states():
|
||||
with _mux_lock:
|
||||
states = {}
|
||||
for name, mux in _multiplexers.items():
|
||||
states[name] = {
|
||||
"metadata": mux.get_metadata(),
|
||||
"output_summary": {
|
||||
"stdout_lines": len(mux.stdout_buffer),
|
||||
"stderr_lines": len(mux.stderr_buffer),
|
||||
},
|
||||
}
|
||||
return states
|
||||
|
||||
|
||||
def cleanup_all_multiplexers():
|
||||
for mux in list(_multiplexers.values()):
|
||||
mux.close()
|
||||
_multiplexers.clear()
|
||||
|
||||
|
||||
# Background process management
|
||||
_background_processes = {}
|
||||
_process_lock = threading.Lock()
|
||||
|
||||
|
||||
class BackgroundProcess:
|
||||
def __init__(self, name, command):
|
||||
self.name = name
|
||||
self.command = command
|
||||
self.process = None
|
||||
self.multiplexer = None
|
||||
self.status = "starting"
|
||||
self.start_time = time.time()
|
||||
self.end_time = None
|
||||
|
||||
def start(self):
|
||||
"""Start the background process."""
|
||||
try:
|
||||
# Create multiplexer for this process
|
||||
mux_name, mux = create_multiplexer(self.name, show_output=False)
|
||||
self.multiplexer = mux
|
||||
|
||||
# Detect process type
|
||||
process_type = detect_process_type(self.command)
|
||||
mux.set_process_type(process_type)
|
||||
|
||||
# Start the subprocess
|
||||
self.process = subprocess.Popen(
|
||||
self.command,
|
||||
shell=True,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
universal_newlines=True,
|
||||
)
|
||||
|
||||
self.status = "running"
|
||||
|
||||
# Start output monitoring threads
|
||||
threading.Thread(target=self._monitor_stdout, daemon=True).start()
|
||||
threading.Thread(target=self._monitor_stderr, daemon=True).start()
|
||||
|
||||
return {"status": "success", "pid": self.process.pid}
|
||||
|
||||
except Exception as e:
|
||||
self.status = "error"
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def _monitor_stdout(self):
|
||||
"""Monitor stdout from the process."""
|
||||
try:
|
||||
for line in iter(self.process.stdout.readline, ""):
|
||||
if line:
|
||||
self.multiplexer.write_stdout(line.rstrip("\n\r"))
|
||||
except Exception as e:
|
||||
self.write_stderr(f"Error reading stdout: {e}")
|
||||
finally:
|
||||
self._check_completion()
|
||||
|
||||
def _monitor_stderr(self):
|
||||
"""Monitor stderr from the process."""
|
||||
try:
|
||||
for line in iter(self.process.stderr.readline, ""):
|
||||
if line:
|
||||
self.multiplexer.write_stderr(line.rstrip("\n\r"))
|
||||
except Exception as e:
|
||||
self.write_stderr(f"Error reading stderr: {e}")
|
||||
|
||||
def _check_completion(self):
|
||||
"""Check if process has completed."""
|
||||
if self.process and self.process.poll() is not None:
|
||||
self.status = "completed"
|
||||
self.end_time = time.time()
|
||||
|
||||
def get_info(self):
|
||||
"""Get process information."""
|
||||
self._check_completion()
|
||||
return {
|
||||
"name": self.name,
|
||||
"command": self.command,
|
||||
"status": self.status,
|
||||
"pid": self.process.pid if self.process else None,
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"runtime": (
|
||||
time.time() - self.start_time
|
||||
if not self.end_time
|
||||
else self.end_time - self.start_time
|
||||
),
|
||||
}
|
||||
|
||||
def get_output(self, lines=None):
|
||||
"""Get process output."""
|
||||
if not self.multiplexer:
|
||||
return []
|
||||
|
||||
all_output = self.multiplexer.get_all_output()
|
||||
stdout_lines = all_output["stdout"].split("\n") if all_output["stdout"] else []
|
||||
stderr_lines = all_output["stderr"].split("\n") if all_output["stderr"] else []
|
||||
|
||||
combined = stdout_lines + stderr_lines
|
||||
if lines:
|
||||
combined = combined[-lines:]
|
||||
|
||||
return [line for line in combined if line.strip()]
|
||||
|
||||
def send_input(self, input_text):
|
||||
"""Send input to the process."""
|
||||
if self.process and self.status == "running":
|
||||
try:
|
||||
self.process.stdin.write(input_text + "\n")
|
||||
self.process.stdin.flush()
|
||||
return {"status": "success"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
return {"status": "error", "error": "Process not running or no stdin"}
|
||||
|
||||
def kill(self):
|
||||
"""Kill the process."""
|
||||
if self.process and self.status == "running":
|
||||
try:
|
||||
self.process.terminate()
|
||||
# Wait a bit for graceful termination
|
||||
time.sleep(0.1)
|
||||
if self.process.poll() is None:
|
||||
self.process.kill()
|
||||
self.status = "killed"
|
||||
self.end_time = time.time()
|
||||
return {"status": "success"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
return {"status": "error", "error": "Process not running"}
|
||||
|
||||
|
||||
def start_background_process(name, command):
|
||||
"""Start a background process."""
|
||||
with _process_lock:
|
||||
if name in _background_processes:
|
||||
return {"status": "error", "error": f"Process {name} already exists"}
|
||||
|
||||
process = BackgroundProcess(name, command)
|
||||
result = process.start()
|
||||
|
||||
if result["status"] == "success":
|
||||
_background_processes[name] = process
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_all_sessions():
|
||||
"""Get all background process sessions."""
|
||||
with _process_lock:
|
||||
sessions = {}
|
||||
for name, process in _background_processes.items():
|
||||
sessions[name] = process.get_info()
|
||||
return sessions
|
||||
|
||||
|
||||
def get_session_info(name):
|
||||
"""Get information about a specific session."""
|
||||
with _process_lock:
|
||||
process = _background_processes.get(name)
|
||||
return process.get_info() if process else None
|
||||
|
||||
|
||||
def get_session_output(name, lines=None):
|
||||
"""Get output from a specific session."""
|
||||
with _process_lock:
|
||||
process = _background_processes.get(name)
|
||||
return process.get_output(lines) if process else None
|
||||
|
||||
|
||||
def send_input_to_session(name, input_text):
|
||||
"""Send input to a background session."""
|
||||
with _process_lock:
|
||||
process = _background_processes.get(name)
|
||||
return (
|
||||
process.send_input(input_text)
|
||||
if process
|
||||
else {"status": "error", "error": "Session not found"}
|
||||
)
|
||||
|
||||
|
||||
def kill_session(name):
|
||||
"""Kill a background session."""
|
||||
with _process_lock:
|
||||
process = _background_processes.get(name)
|
||||
if process:
|
||||
result = process.kill()
|
||||
if result["status"] == "success":
|
||||
del _background_processes[name]
|
||||
return result
|
||||
return {"status": "error", "error": "Session not found"}
|
||||
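A minimal usage sketch of the background-process helpers above; the session name and command are illustrative:

from pr.multiplexer import start_background_process, get_session_output, kill_session

# Start a long-running command in its own multiplexed session.
result = start_background_process("demo-ping", "ping -c 10 localhost")
if result["status"] == "success":
    print("started pid", result["pid"])
    # Later: fetch the last 5 non-empty output lines, then stop the session.
    print(get_session_output("demo-ping", lines=5))
    print(kill_session("demo-ping"))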
585
pr/server.py
Normal file
@ -0,0 +1,585 @@
#!/usr/bin/env python3
import asyncio
import logging
import os
import sys
import subprocess
import urllib.parse
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import pathlib
import ast
from types import SimpleNamespace
import time
from aiohttp import web
import aiohttp

# --- Optional external dependency used in your original code ---
from pr.ads import AsyncDataSet  # noqa: E402

import sys
from http import cookies
import urllib.parse
import os
import io
import json
import urllib.request
import pathlib
import pickle
import zlib

import asyncio
import time
from functools import wraps
from pathlib import Path
import pickle
import zlib


class Cache:
    def __init__(self, base_dir: Path | str = "."):
        self.base_dir = Path(base_dir).resolve()
        self.base_dir.mkdir(exist_ok=True, parents=True)

    def is_cacheable(self, obj):
        return isinstance(obj, (int, str, bool, float))

    def generate_key(self, *args, **kwargs):
        return zlib.crc32(
            json.dumps(
                {
                    "args": [arg for arg in args if self.is_cacheable(arg)],
                    "kwargs": {k: v for k, v in kwargs.items() if self.is_cacheable(v)},
                },
                sort_keys=True,
                default=str,
            ).encode()
        )

    def set(self, key, value):
        key = self.generate_key(key)
        data = {"value": value, "timestamp": time.time()}
        serialized_data = pickle.dumps(data)
        with open(self.base_dir.joinpath(f"{key}.cache"), "wb") as f:
            f.write(serialized_data)

    def get(self, key, default=None):
        key = self.generate_key(key)
        try:
            with open(self.base_dir.joinpath(f"{key}.cache"), "rb") as f:
                data = pickle.loads(f.read())
                return data
        except FileNotFoundError:
            return default

    def cached(self, expiration=60):
        def decorator(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                # if not all(self.is_cacheable(arg) for arg in args) or not all(self.is_cacheable(v) for v in kwargs.values()):
                #     return await func(*args, **kwargs)
                key = self.generate_key(*args, **kwargs)
                cached_data = self.get(key)
                if cached_data is not None:
                    if expiration is None or (time.time() - cached_data["timestamp"]) < expiration:
                        print("Cache hit:", key)
                        return cached_data["value"]
                result = await func(*args, **kwargs)
                self.set(key, result)
                return result

            return wrapper

        return decorator


cache = Cache()
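A short sketch of how the decorator above would be applied; the coroutine and its argument are illustrative:

# Hypothetical example: cache an expensive coroutine's result for 60 seconds.
@cache.cached(expiration=60)
async def slow_lookup(term: str):
    await asyncio.sleep(1)  # stands in for real work
    return {"term": term}

# The first call does the work and writes a .cache file; a second call
# with the same argument within 60 seconds is served from disk.
# asyncio.run(slow_lookup("memory"))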
class CGI:
    def __init__(self):
        self.environ = os.environ
        if self.gateway_interface != "CGI/1.1":
            return
        self.status = 200
        self.headers = {}
        self.cache = Cache(pathlib.Path(__file__).parent.joinpath("cache/cgi"))
        self.headers["Content-Length"] = "0"
        self.headers["Cache-Control"] = "no-cache"
        self.headers["Access-Control-Allow-Origin"] = "*"
        self.headers["Content-Type"] = "application/json; charset=utf-8"
        self.cookies = cookies.SimpleCookie()

        self.query = urllib.parse.parse_qs(os.environ["QUERY_STRING"])
        self.file = io.BytesIO()

    def validate(self, val, fn, error):
        if fn(val):
            return True
        self.status = 421
        self.write({"type": "error", "message": error})
        exit()

    def __getitem__(self, key):
        return self.get(key)

    def get(self, key, default=None):
        result = self.query.get(key, [default])
        if len(result) == 1:
            return result[0]
        return result

    def get_bool(self, key):
        return str(self.get(key)).lower() in ["true", "yes", "1"]

    @property
    def cache_key(self):
        return self.environ["CACHE_KEY"]

    @property
    def env(self):
        env = os.environ.copy()
        return env

    @property
    def gateway_interface(self):
        return self.env.get("GATEWAY_INTERFACE", "")

    @property
    def request_method(self):
        return self.env["REQUEST_METHOD"]

    @property
    def query_string(self):
        return self.env["QUERY_STRING"]

    @property
    def script_name(self):
        return self.env["SCRIPT_NAME"]

    @property
    def path_info(self):
        return self.env["PATH_INFO"]

    @property
    def server_name(self):
        return self.env["SERVER_NAME"]

    @property
    def server_port(self):
        return self.env["SERVER_PORT"]

    @property
    def server_protocol(self):
        return self.env["SERVER_PROTOCOL"]

    @property
    def remote_addr(self):
        return self.env["REMOTE_ADDR"]

    def validate_get(self, key):
        self.validate(self.get(key), lambda x: bool(x), f"Missing {key}")

    def print(self, data):
        self.write(data)

    def write(self, data):
        if not isinstance(data, bytes) and not isinstance(data, str):
            data = json.dumps(data, default=str, indent=4)
        try:
            data = data.encode()
        except AttributeError:
            pass
        self.file.write(data)
        self.headers["Content-Length"] = str(len(self.file.getvalue()))

    @property
    def http_status(self):
        return f"HTTP/1.1 {self.status}\r\n".encode("utf-8")

    @property
    def http_headers(self):
        headers = io.BytesIO()
        for header in self.headers:
            headers.write(f"{header}: {self.headers[header]}\r\n".encode("utf-8"))
        headers.write(b"\r\n")
        return headers.getvalue()

    @property
    def http_body(self):
        return self.file.getvalue()

    @property
    def http_response(self):
        return self.http_status + self.http_headers + self.http_body

    def flush(self, response=None):
        if response:
            try:
                response = response.encode()
            except AttributeError:
                pass
            sys.stdout.buffer.write(response)
            sys.stdout.buffer.flush()
            return
        sys.stdout.buffer.write(self.http_response)
        sys.stdout.buffer.flush()

    def __del__(self):
        if self.http_body:
            self.flush()
            exit()
# -------------------------------
# Utilities
# -------------------------------
def get_function_source(name: str, directory: str = ".") -> List[Tuple[str, str]]:
    matches: List[Tuple[str, str]] = []
    for root, _, files in os.walk(directory):
        for file in files:
            if not file.endswith(".py"):
                continue
            path = os.path.join(root, file)
            try:
                with open(path, "r", encoding="utf-8") as fh:
                    source = fh.read()
                tree = ast.parse(source, filename=path)
                for node in ast.walk(tree):
                    if isinstance(node, ast.AsyncFunctionDef) and node.name == name:
                        func_src = ast.get_source_segment(source, node)
                        if func_src:
                            matches.append((path, func_src))
                        break
            except (SyntaxError, UnicodeDecodeError):
                continue
    return matches
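A quick usage sketch of the helper above, assuming some file under the search root defines an async handler named hello_world:

# Locate the source of every `async def hello_world(...)` in the tree.
for path, src in get_function_source("hello_world", directory="."):
    print(path)
    print(src)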
# -------------------------------
# Server
# -------------------------------
class Static(SimpleNamespace):
    pass


class View(web.View):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.db = self.request.app["db"]
        self.server = self.request.app["server"]

    @property
    def client(self):
        return aiohttp.ClientSession()


class RetoorServer:
    def __init__(self, base_dir: Path | str = ".", port: int = 8118) -> None:
        self.base_dir = Path(base_dir).resolve()
        self.port = port

        self.static = Static()
        self.db = AsyncDataSet(".default.db")

        self._logger = logging.getLogger("retoor.server")
        self._logger.setLevel(logging.INFO)
        if not self._logger.handlers:
            h = logging.StreamHandler(sys.stdout)
            fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
            h.setFormatter(fmt)
            self._logger.addHandler(h)

        self._func_cache: Dict[str, Tuple[str, str]] = {}
        self._compiled_cache: Dict[str, object] = {}
        self._cgi_cache: Dict[str, Tuple[Optional[Path], Optional[str], Optional[str]]] = {}
        self._static_path_cache: Dict[str, Path] = {}
        self._base_dir_str = str(self.base_dir)

        self.app = web.Application()

        app = self.app

        app["db"] = self.db
        for path in pathlib.Path(__file__).parent.joinpath("web").joinpath("views").iterdir():
            if path.suffix == ".py":
                exec(open(path).read(), globals(), locals())

        self.app["server"] = self

        # from rpanel import create_app as create_rpanel_app
        # self.app.add_subapp("/api/{tail:.*}", create_rpanel_app())

        self.app.router.add_route("*", "/{tail:.*}", self.handle_any)

    # ---------------------------
    # Simple endpoints
    # ---------------------------
    async def handle_root(self, request: web.Request) -> web.Response:
        return web.Response(text="Static file and cgi server from retoor.")

    async def handle_hello(self, request: web.Request) -> web.Response:
        return web.Response(text="Welcome to the custom HTTP server!")

    # ---------------------------
    # Dynamic function dispatch
    # ---------------------------
    def _path_to_funcname(self, path: str) -> str:
        # /foo/bar -> foo_bar ; strip leading/trailing slashes safely
        return path.strip("/").replace("/", "_")

    async def _maybe_dispatch_dynamic(self, request: web.Request) -> Optional[web.StreamResponse]:
        relpath = request.path.strip("/")
        relpath = str(pathlib.Path(__file__).joinpath("cgi").joinpath(relpath).resolve()).replace(
            "..", ""
        )
        if not relpath:
            return None

        last_part = relpath.split("/")[-1]
        if "." in last_part:
            return None

        funcname = self._path_to_funcname(request.path)

        if funcname not in self._compiled_cache:
            entry = self._func_cache.get(funcname)
            if entry is None:
                matches = get_function_source(funcname, directory=str(self.base_dir))
                if not matches:
                    return None
                entry = matches[0]
                self._func_cache[funcname] = entry

            filepath, source = entry
            code_obj = compile(source, filepath, "exec")
            self._compiled_cache[funcname] = (code_obj, filepath)

        code_obj, filepath = self._compiled_cache[funcname]
        ctx = {
            "static": self.static,
            "request": request,
            "relpath": relpath,
            "os": os,
            "sys": sys,
            "web": web,
            "asyncio": asyncio,
            "subprocess": subprocess,
            "urllib": urllib.parse,
            "app": request.app,
            "db": self.db,
        }
        exec(code_obj, ctx, ctx)
        coro = ctx.get(funcname)
        if not callable(coro):
            return web.Response(status=500, text=f"Dynamic handler '{funcname}' is not callable.")
        return await coro(request)

    # ---------------------------
    # Static files
    # ---------------------------
    async def handle_static(self, request: web.Request) -> web.StreamResponse:
        path = request.path

        if path in self._static_path_cache:
            cached_path = self._static_path_cache[path]
            if cached_path is None:
                return web.Response(status=404, text="Not found")
            try:
                return web.FileResponse(path=cached_path)
            except Exception as e:
                return web.Response(status=500, text=f"Failed to send file: {e!r}")

        relpath = path.replace("..", "").lstrip("/").rstrip("/")
        abspath = (self.base_dir / relpath).resolve()

        if not str(abspath).startswith(self._base_dir_str):
            return web.Response(status=403, text="Forbidden")

        if not abspath.exists():
            self._static_path_cache[path] = None
            return web.Response(status=404, text="Not found")

        if abspath.is_dir():
            return web.Response(status=403, text="Directory listing forbidden")

        self._static_path_cache[path] = abspath
        try:
            return web.FileResponse(path=abspath)
        except Exception as e:
            return web.Response(status=500, text=f"Failed to send file: {e!r}")

    # ---------------------------
    # CGI
    # ---------------------------
    def _find_cgi_script(
        self, path_only: str
    ) -> Tuple[Optional[Path], Optional[str], Optional[str]]:
        path_only = "pr/cgi/" + path_only
        if path_only in self._cgi_cache:
            return self._cgi_cache[path_only]

        split_path = [p for p in path_only.split("/") if p]
        for i in range(len(split_path), -1, -1):
            candidate = "/" + "/".join(split_path[:i])
            candidate_fs = (self.base_dir / candidate.lstrip("/")).resolve()
            if (
                str(candidate_fs).startswith(self._base_dir_str)
                and candidate_fs.parent.is_dir()
                and candidate_fs.parent.name == "cgi"
                and candidate_fs.suffix == ".bin"
                and candidate_fs.is_file()
                and os.access(candidate_fs, os.X_OK)
            ):
                script_name = candidate if candidate.startswith("/") else "/" + candidate
                path_info = path_only[len(script_name) :]
                result = (candidate_fs, script_name, path_info)
                self._cgi_cache[path_only] = result
                return result

        result = (None, None, None)
        self._cgi_cache[path_only] = result
        return result

    async def _run_cgi(
        self, request: web.Request, script_path: Path, script_name: str, path_info: str
    ) -> web.Response:
        start_time = time.time()
        method = request.method
        query = request.query_string

        env = os.environ.copy()
        env["CACHE_KEY"] = json.dumps(
            {"method": method, "query": query, "script_name": script_name, "path_info": path_info},
            default=str,
        )
        env["GATEWAY_INTERFACE"] = "CGI/1.1"
        env["REQUEST_METHOD"] = method
        env["QUERY_STRING"] = query
        env["SCRIPT_NAME"] = script_name
        env["PATH_INFO"] = path_info

        host_parts = request.host.split(":", 1)
        env["SERVER_NAME"] = host_parts[0]
        env["SERVER_PORT"] = host_parts[1] if len(host_parts) > 1 else "80"
        env["SERVER_PROTOCOL"] = "HTTP/1.1"

        peername = request.transport.get_extra_info("peername")
        env["REMOTE_ADDR"] = request.headers.get("HOST_X_FORWARDED_FOR", peername[0])

        for hk, hv in request.headers.items():
            hk_lower = hk.lower()
            if hk_lower == "content-type":
                env["CONTENT_TYPE"] = hv
            elif hk_lower == "content-length":
                env["CONTENT_LENGTH"] = hv
            else:
                env["HTTP_" + hk.upper().replace("-", "_")] = hv

        post_data = None
        if method in {"POST", "PUT", "PATCH"}:
            post_data = await request.read()
            env.setdefault("CONTENT_LENGTH", str(len(post_data)))

        try:
            proc = await asyncio.create_subprocess_exec(
                str(script_path),
                stdin=asyncio.subprocess.PIPE if post_data else None,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=env,
            )
            stdout, stderr = await proc.communicate(input=post_data)

            if proc.returncode != 0:
                msg = stderr.decode(errors="ignore")
                return web.Response(status=500, text=f"CGI script error:\n{msg}")

            header_end = stdout.find(b"\r\n\r\n")
            if header_end != -1:
                offset = 4
            else:
                header_end = stdout.find(b"\n\n")
                offset = 2

            if header_end != -1:
                body = stdout[header_end + offset :]
                headers_blob = stdout[:header_end].decode(errors="ignore")

                status_code = 200
                headers: Dict[str, str] = {}
                for line in headers_blob.splitlines():
                    line = line.strip()
                    if not line or ":" not in line:
                        continue

                    k, v = line.split(":", 1)
                    k_stripped = k.strip()
                    if k_stripped.lower() == "status":
                        try:
                            status_code = int(v.strip().split()[0])
                        except Exception:
                            status_code = 200
                    else:
                        headers[k_stripped] = v.strip()

                end_time = time.time()
                headers["X-Powered-By"] = "retoor"
                headers["X-Date"] = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime())
                headers["X-Time"] = str(end_time)
                headers["X-Duration"] = str(end_time - start_time)
                if "error" in str(body).lower():
                    print(headers)
                    print(body)
                return web.Response(body=body, headers=headers, status=status_code)
            else:
                return web.Response(body=stdout)

        except Exception as e:
            return web.Response(status=500, text=f"Failed to execute CGI script.\n{e!r}")

    # ---------------------------
    # Main router
    # ---------------------------
    async def handle_any(self, request: web.Request) -> web.StreamResponse:
        path = request.path

        if path == "/":
            return await self.handle_root(request)
        if path == "/hello":
            return await self.handle_hello(request)

        script_path, script_name, path_info = self._find_cgi_script(path)
        if script_path:
            return await self._run_cgi(request, script_path, script_name or "", path_info or "")

        dyn = await self._maybe_dispatch_dynamic(request)
        if dyn is not None:
            return dyn

        return await self.handle_static(request)

    # ---------------------------
    # Runner
    # ---------------------------
    def run(self) -> None:
        self._logger.info("Serving at port %d (base_dir=%s)", self.port, self.base_dir)
        web.run_app(self.app, port=self.port)


def main() -> None:
    server = RetoorServer(base_dir=".", port=8118)
    server.run()


if __name__ == "__main__":
    main()
else:
    cgi = CGI()
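The module doubles as server entry point and CGI helper: run directly it serves; imported by a CGI script it constructs a ready-to-use cgi object. A minimal launch sketch, with the port chosen for illustration:

from pr.server import RetoorServer

# Routes fall through in order: /, /hello, CGI scripts, dynamic async
# handlers discovered via get_function_source, then static files.
RetoorServer(base_dir=".", port=8080).run()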
@ -43,6 +43,11 @@ from pr.tools.memory import (
from pr.tools.patch import apply_patch, create_diff
from pr.tools.python_exec import python_exec
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.tools.context_modifier import (
    modify_context_add,
    modify_context_replace,
    modify_context_delete,
)

__all__ = [
    "add_knowledge_entry",
738
pr/tools/base.py
@ -1,596 +1,144 @@
import inspect
from typing import get_type_hints, get_origin, get_args
import pr.tools


def _type_to_json_schema(py_type):
    """Convert Python type to JSON Schema type."""
    if py_type == str:
        return {"type": "string"}
    elif py_type == int:
        return {"type": "integer"}
    elif py_type == float:
        return {"type": "number"}
    elif py_type == bool:
        return {"type": "boolean"}
    elif get_origin(py_type) == list:
        return {"type": "array", "items": _type_to_json_schema(get_args(py_type)[0])}
    elif get_origin(py_type) == dict:
        return {"type": "object"}
    else:
        # Default to string for unknown types
        return {"type": "string"}


def _generate_tool_schema(func):
    """Generate JSON Schema for a tool function."""
    sig = inspect.signature(func)
    docstring = func.__doc__ or ""

    # Extract description from docstring
    description = docstring.strip().split("\n")[0] if docstring else ""

    # Get type hints
    type_hints = get_type_hints(func)

    properties = {}
    required = []

    for param_name, param in sig.parameters.items():
        if param_name in ["db_conn", "python_globals"]:  # Skip internal parameters
            continue

        param_type = type_hints.get(param_name, str)
        schema = _type_to_json_schema(param_type)

        # Add description from docstring if available
        param_doc = ""
        if docstring:
            lines = docstring.split("\n")
            in_args = False
            for line in lines:
                line = line.strip()
                if line.startswith("Args:") or line.startswith("Arguments:"):
                    in_args = True
                    continue
                elif in_args and line.startswith(param_name + ":"):
                    param_doc = line.split(":", 1)[1].strip()
                    break
                elif in_args and line == "":
                    continue
                elif in_args and not line.startswith(" "):
                    break

        if param_doc:
            schema["description"] = param_doc

        # Set default if available
        if param.default != inspect.Parameter.empty:
            schema["default"] = param.default

        properties[param_name] = schema

        # Required if no default
        if param.default == inspect.Parameter.empty:
            required.append(param_name)

    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


def get_tools_definition():
    return [
        {
            "type": "function",
            "function": {
                "name": "kill_process",
                "description": "Terminate a background process by its PID. Use this to stop processes started with run_command that exceeded their timeout.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "pid": {
                            "type": "integer",
                            "description": "The process ID returned by run_command when status is 'running'.",
                        }
                    },
                    "required": ["pid"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "tail_process",
                "description": "Monitor and retrieve output from a background process by its PID. Use this to check on processes started with run_command that exceeded their timeout.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "pid": {
                            "type": "integer",
                            "description": "The process ID returned by run_command when status is 'running'.",
                        },
                        "timeout": {
                            "type": "integer",
                            "description": "Maximum seconds to wait for process completion. Returns partial output if still running.",
                            "default": 30,
                        },
                    },
                    "required": ["pid"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "http_fetch",
                "description": "Fetch content from an HTTP URL",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "url": {"type": "string", "description": "The URL to fetch"},
                        "headers": {
                            "type": "object",
                            "description": "Optional HTTP headers",
                        },
                    },
                    "required": ["url"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "run_command",
                "description": "Execute a shell command and capture output. Returns immediately after timeout with PID if still running. Use tail_process to monitor or kill_process to terminate long-running commands.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "The shell command to execute",
                        },
                        "timeout": {
                            "type": "integer",
                            "description": "Maximum seconds to wait for completion",
                            "default": 30,
                        },
                    },
                    "required": ["command"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "start_interactive_session",
                "description": "Execute an interactive terminal command that requires user input or displays UI. The command runs in a dedicated session and returns a session name.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "command": {
                            "type": "string",
                            "description": "The interactive command to execute (e.g., vim, nano, top)",
                        }
                    },
                    "required": ["command"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "send_input_to_session",
                "description": "Send input to an interactive session.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "session_name": {
                            "type": "string",
                            "description": "The name of the session",
                        },
                        "input_data": {
                            "type": "string",
                            "description": "The input to send to the session",
                        },
                    },
                    "required": ["session_name", "input_data"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "read_session_output",
                "description": "Read output from an interactive session.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "session_name": {
                            "type": "string",
                            "description": "The name of the session",
                        }
                    },
                    "required": ["session_name"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "close_interactive_session",
                "description": "Close an interactive session.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "session_name": {
                            "type": "string",
                            "description": "The name of the session",
                        }
                    },
                    "required": ["session_name"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "read_file",
                "description": "Read contents of a file",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file",
                        }
                    },
                    "required": ["filepath"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "write_file",
                "description": "Write content to a file",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file",
                        },
                        "content": {
                            "type": "string",
                            "description": "Content to write",
                        },
                    },
                    "required": ["filepath", "content"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "list_directory",
                "description": "List directory contents",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Directory path",
                            "default": ".",
                        },
                        "recursive": {
                            "type": "boolean",
                            "description": "List recursively",
                            "default": False,
                        },
                    },
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "mkdir",
                "description": "Create a new directory",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Path of the directory to create",
                        }
                    },
                    "required": ["path"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "chdir",
                "description": "Change the current working directory",
                "parameters": {
                    "type": "object",
                    "properties": {"path": {"type": "string", "description": "Path to change to"}},
                    "required": ["path"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "getpwd",
                "description": "Get the current working directory",
                "parameters": {"type": "object", "properties": {}},
            },
        },
        {
            "type": "function",
            "function": {
                "name": "db_set",
                "description": "Set a key-value pair in the database",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "key": {"type": "string", "description": "The key"},
                        "value": {"type": "string", "description": "The value"},
                    },
                    "required": ["key", "value"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "db_get",
                "description": "Get a value from the database",
                "parameters": {
                    "type": "object",
                    "properties": {"key": {"type": "string", "description": "The key"}},
                    "required": ["key"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "db_query",
                "description": "Execute a database query",
                "parameters": {
                    "type": "object",
                    "properties": {"query": {"type": "string", "description": "SQL query"}},
                    "required": ["query"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "web_search",
                "description": "Perform a web search",
                "parameters": {
                    "type": "object",
                    "properties": {"query": {"type": "string", "description": "Search query"}},
                    "required": ["query"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "web_search_news",
                "description": "Perform a web search for news",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query for news",
                        }
                    },
                    "required": ["query"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "python_exec",
                "description": "Execute Python code",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "Python code to execute",
                        }
                    },
                    "required": ["code"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "index_source_directory",
                "description": "Index directory recursively and read all source files.",
                "parameters": {
                    "type": "object",
                    "properties": {"path": {"type": "string", "description": "Path to index"}},
                    "required": ["path"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "search_replace",
                "description": "Search and replace text in a file",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file",
                        },
                        "old_string": {
                            "type": "string",
                            "description": "String to replace",
                        },
                        "new_string": {
                            "type": "string",
                            "description": "Replacement string",
                        },
                    },
                    "required": ["filepath", "old_string", "new_string"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "apply_patch",
                "description": "Apply a patch to a file, especially for source code",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file to patch",
                        },
                        "patch_content": {
                            "type": "string",
                            "description": "The patch content as a string",
                        },
                    },
                    "required": ["filepath", "patch_content"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "create_diff",
                "description": "Create a unified diff between two files",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "file1": {
                            "type": "string",
                            "description": "Path to the first file",
                        },
                        "file2": {
                            "type": "string",
                            "description": "Path to the second file",
                        },
                        "fromfile": {
                            "type": "string",
                            "description": "Label for the first file",
                            "default": "file1",
                        },
                        "tofile": {
                            "type": "string",
                            "description": "Label for the second file",
                            "default": "file2",
                        },
                    },
                    "required": ["file1", "file2"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "open_editor",
                "description": "Open the RPEditor for a file",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file",
                        }
                    },
                    "required": ["filepath"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "close_editor",
                "description": "Close the RPEditor. Always close files when finished editing.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file",
                        }
                    },
                    "required": ["filepath"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "editor_insert_text",
                "description": "Insert text at cursor position in the editor",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file",
                        },
                        "text": {"type": "string", "description": "Text to insert"},
                        "line": {
                            "type": "integer",
                            "description": "Line number (optional)",
                        },
                        "col": {
                            "type": "integer",
                            "description": "Column number (optional)",
                        },
                    },
                    "required": ["filepath", "text"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "editor_replace_text",
                "description": "Replace text in a range",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file",
                        },
                        "start_line": {"type": "integer", "description": "Start line"},
                        "start_col": {"type": "integer", "description": "Start column"},
                        "end_line": {"type": "integer", "description": "End line"},
                        "end_col": {"type": "integer", "description": "End column"},
                        "new_text": {"type": "string", "description": "New text"},
                    },
                    "required": [
                        "filepath",
                        "start_line",
                        "start_col",
                        "end_line",
                        "end_col",
                        "new_text",
                    ],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "editor_search",
                "description": "Search for a pattern in the file",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath": {
                            "type": "string",
                            "description": "Path to the file",
                        },
                        "pattern": {"type": "string", "description": "Regex pattern"},
                        "start_line": {
                            "type": "integer",
                            "description": "Start line",
                            "default": 0,
                        },
                    },
                    "required": ["filepath", "pattern"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "display_file_diff",
                "description": "Display a visual colored diff between two files with syntax highlighting and statistics",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filepath1": {
                            "type": "string",
                            "description": "Path to the original file",
                        },
                        "filepath2": {
                            "type": "string",
                            "description": "Path to the modified file",
                        },
                        "format_type": {
                            "type": "string",
                            "description": "Display format: 'unified' or 'side-by-side'",
                            "default": "unified",
                        },
                    },
                    "required": ["filepath1", "filepath2"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "display_edit_summary",
                "description": "Display a summary of all edit operations performed during the session",
                "parameters": {"type": "object", "properties": {}},
            },
        },
        {
            "type": "function",
            "function": {
                "name": "display_edit_timeline",
                "description": "Display a timeline of all edit operations with details",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "show_content": {
                            "type": "boolean",
                            "description": "Show content previews",
                            "default": False,
                        }
                    },
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "clear_edit_tracker",
                "description": "Clear the edit tracker to start fresh",
                "parameters": {"type": "object", "properties": {}},
            },
        },
    ]
    """Dynamically generate tool definitions from all tool functions."""
    tools = []

    # Get all functions from pr.tools modules
    for name in dir(pr.tools):
        if name.startswith("_"):
            continue

        obj = getattr(pr.tools, name)
        if callable(obj) and hasattr(obj, "__module__") and obj.__module__.startswith("pr.tools."):
            # Check if it's a tool function (has docstring and proper signature)
            if obj.__doc__:
                try:
                    schema = _generate_tool_schema(obj)
                    tools.append(schema)
                except Exception as e:
                    print(f"Warning: Could not generate schema for {name}: {e}")
                    continue

    return tools


def get_func_map(db_conn=None, python_globals=None):
    """Dynamically generate function map for tool execution."""
    func_map = {}

    # Include all functions from __all__ in pr.tools
    for name in getattr(pr.tools, "__all__", []):
        if name.startswith("_"):
            continue

        obj = getattr(pr.tools, name, None)
        if callable(obj) and hasattr(obj, "__module__") and obj.__module__.startswith("pr.tools."):
            sig = inspect.signature(obj)
            params = list(sig.parameters.keys())

            # Create wrapper based on parameters
            if "db_conn" in params and "python_globals" in params:
                func_map[name] = (
                    lambda func=obj, db_conn=db_conn, python_globals=python_globals, **kw: func(
                        **kw, db_conn=db_conn, python_globals=python_globals
                    )
                )
            elif "db_conn" in params:
                func_map[name] = lambda func=obj, db_conn=db_conn, **kw: func(**kw, db_conn=db_conn)
            elif "python_globals" in params:
                func_map[name] = lambda func=obj, python_globals=python_globals, **kw: func(
                    **kw, python_globals=python_globals
                )
            else:
                func_map[name] = lambda func=obj, **kw: func(**kw)

    return func_map
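To see what the schema generator produces, a small sketch; the sample function is hypothetical, while _generate_tool_schema is the helper defined above in pr/tools/base.py:

import json
from pr.tools.base import _generate_tool_schema

def greet(name: str, times: int = 1):
    """Return a greeting.

    Args:
        name: Who to greet.
        times: How many times to repeat the greeting.
    """
    return " ".join([f"hello {name}"] * times)

# Yields an OpenAI-style function schema: name and description from the
# docstring, string/integer types from the hints, and "times" left out of
# "required" because it has a default.
print(json.dumps(_generate_tool_schema(greet), indent=2))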
@ -2,7 +2,6 @@ import select
import subprocess
import time

from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer

_processes = {}

@ -99,6 +98,19 @@ def tail_process(pid: int, timeout: int = 30):


def run_command(command, timeout=30, monitored=False, cwd=None):
    """Execute a shell command and return the output.

    Args:
        command: The shell command to execute.
        timeout: Maximum time in seconds to wait for completion.
        monitored: Whether to monitor the process (unused).
        cwd: Working directory for the command.

    Returns:
        Dict with status, stdout, stderr, returncode, and optionally pid if still running.
    """
    from pr.multiplexer import close_multiplexer, create_multiplexer

    mux_name = None
    try:
        process = subprocess.Popen(
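A short sketch of the run/monitor/kill flow the docstring above describes; the command is illustrative:

from pr.tools.command import run_command, tail_process, kill_process

result = run_command("sleep 60 && echo done", timeout=5)
if result["status"] == "running":
    # Timed out but still alive: poll once more, then give up and kill it.
    progress = tail_process(result["pid"], timeout=5)
    if progress["status"] == "running":
        print(kill_process(result["pid"]))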
176
pr/tools/command.py.bak
Normal file
@ -0,0 +1,176 @@
|
||||
print(f"Executing command: {command}") print(f"Killing process: {pid}")import os
|
||||
import select
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
|
||||
|
||||
_processes = {}
|
||||
|
||||
|
||||
def _register_process(pid: int, process):
|
||||
_processes[pid] = process
|
||||
return _processes
|
||||
|
||||
|
||||
def _get_process(pid: int):
|
||||
return _processes.get(pid)
|
||||
|
||||
|
||||
def kill_process(pid: int):
|
||||
try:
|
||||
process = _get_process(pid)
|
||||
if process:
|
||||
process.kill()
|
||||
_processes.pop(pid)
|
||||
|
||||
mux_name = f"cmd-{pid}"
|
||||
if get_multiplexer(mux_name):
|
||||
close_multiplexer(mux_name)
|
||||
|
||||
return {"status": "success", "message": f"Process {pid} has been killed"}
|
||||
else:
|
||||
return {"status": "error", "error": f"Process {pid} not found"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def tail_process(pid: int, timeout: int = 30):
    process = _get_process(pid)
    if process:
        mux_name = f"cmd-{pid}"
        mux = get_multiplexer(mux_name)

        if not mux:
            mux_name, mux = create_multiplexer(mux_name, show_output=True)

        try:
            start_time = time.time()
            timeout_duration = timeout
            stdout_content = ""
            stderr_content = ""

            while True:
                if process.poll() is not None:
                    remaining_stdout, remaining_stderr = process.communicate()
                    if remaining_stdout:
                        mux.write_stdout(remaining_stdout)
                        stdout_content += remaining_stdout
                    if remaining_stderr:
                        mux.write_stderr(remaining_stderr)
                        stderr_content += remaining_stderr

                    if pid in _processes:
                        _processes.pop(pid)

                    close_multiplexer(mux_name)

                    return {
                        "status": "success",
                        "stdout": stdout_content,
                        "stderr": stderr_content,
                        "returncode": process.returncode,
                    }

                if time.time() - start_time > timeout_duration:
                    return {
                        "status": "running",
                        "message": "Process is still running. Call tail_process again to continue monitoring.",
                        "stdout_so_far": stdout_content,
                        "stderr_so_far": stderr_content,
                        "pid": pid,
                    }

                ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
                for pipe in ready:
                    if pipe == process.stdout:
                        line = process.stdout.readline()
                        if line:
                            mux.write_stdout(line)
                            stdout_content += line
                    elif pipe == process.stderr:
                        line = process.stderr.readline()
                        if line:
                            mux.write_stderr(line)
                            stderr_content += line
        except Exception as e:
            return {"status": "error", "error": str(e)}
    else:
        return {"status": "error", "error": f"Process {pid} not found"}


def run_command(command, timeout=30, monitored=False):
    mux_name = None
    try:
        process = subprocess.Popen(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        _register_process(process.pid, process)

        mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True)

        start_time = time.time()
        timeout_duration = timeout
        stdout_content = ""
        stderr_content = ""

        while True:
            if process.poll() is not None:
                remaining_stdout, remaining_stderr = process.communicate()
                if remaining_stdout:
                    mux.write_stdout(remaining_stdout)
                    stdout_content += remaining_stdout
                if remaining_stderr:
                    mux.write_stderr(remaining_stderr)
                    stderr_content += remaining_stderr

                if process.pid in _processes:
                    _processes.pop(process.pid)

                close_multiplexer(mux_name)

                return {
                    "status": "success",
                    "stdout": stdout_content,
                    "stderr": stderr_content,
                    "returncode": process.returncode,
                }

            if time.time() - start_time > timeout_duration:
                return {
                    "status": "running",
                    "message": f"Process still running after {timeout}s timeout. Use tail_process({process.pid}) to monitor or kill_process({process.pid}) to terminate.",
                    "stdout_so_far": stdout_content,
                    "stderr_so_far": stderr_content,
                    "pid": process.pid,
                    "mux_name": mux_name,
                }

            ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
            for pipe in ready:
                if pipe == process.stdout:
                    line = process.stdout.readline()
                    if line:
                        mux.write_stdout(line)
                        stdout_content += line
                elif pipe == process.stderr:
                    line = process.stderr.readline()
                    if line:
                        mux.write_stderr(line)
                        stderr_content += line
    except Exception as e:
        if mux_name:
            close_multiplexer(mux_name)
        return {"status": "error", "error": str(e)}


def run_command_interactive(command):
    try:
        return_code = os.system(command)
        return {"status": "success", "returncode": return_code}
    except Exception as e:
        return {"status": "error", "error": str(e)}
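A minimal usage sketch of the two helpers above, assuming they are exposed from pr.tools.command (the module path is not shown in this diff); the call site below is hypothetical:

# Hypothetical usage sketch; module path assumed.
from pr.tools.command import run_command, tail_process

result = run_command("sleep 60 && echo done", timeout=5)
if result["status"] == "running":
    # The command outlived the timeout; keep following its output by pid.
    result = tail_process(result["pid"], timeout=30)
print(result.get("returncode"), result.get("stdout", ""))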
65 pr/tools/context_modifier.py Normal file
@@ -0,0 +1,65 @@
import os
from typing import Optional

CONTEXT_FILE = '/home/retoor/.local/share/rp/.rcontext.txt'


def _read_context() -> str:
    if not os.path.exists(CONTEXT_FILE):
        raise FileNotFoundError(f"Context file {CONTEXT_FILE} not found.")
    with open(CONTEXT_FILE, 'r') as f:
        return f.read()


def _write_context(content: str):
    with open(CONTEXT_FILE, 'w') as f:
        f.write(content)


def modify_context_add(new_content: str, position: Optional[str] = None) -> str:
    """
    Add new content to the .rcontext.txt file.

    Args:
        new_content: The content to add.
        position: Optional marker to insert before (e.g., '***').
    """
    current = _read_context()
    if position and position in current:
        # Insert before the position marker
        parts = current.split(position, 1)
        updated = parts[0] + new_content + '\n\n' + position + parts[1]
    else:
        # Append at the end
        updated = current + '\n\n' + new_content
    _write_context(updated)
    return f"Added: {new_content[:100]}... (full addition applied). Consequences: Enhances functionality as requested."


def modify_context_replace(old_content: str, new_content: str) -> str:
    """
    Replace old content with new content in .rcontext.txt.

    Args:
        old_content: The content to replace.
        new_content: The replacement content.
    """
    current = _read_context()
    if old_content not in current:
        raise ValueError(f"Old content not found: {old_content[:50]}...")
    updated = current.replace(old_content, new_content, 1)  # Replace first occurrence
    _write_context(updated)
    return f"Replaced: '{old_content[:50]}...' with '{new_content[:50]}...'. Consequences: Changes behavior as specified; verify for unintended effects."


def modify_context_delete(content_to_delete: str, confirmed: bool = False) -> str:
    """
    Delete content from .rcontext.txt, but only if confirmed.

    Args:
        content_to_delete: The content to delete.
        confirmed: Must be True to proceed with deletion.
    """
    if not confirmed:
        raise PermissionError(
            f"Deletion not confirmed. To delete '{content_to_delete[:50]}...', you must explicitly confirm. "
            "Are you sure? This may affect system behavior permanently."
        )
    current = _read_context()
    if content_to_delete not in current:
        raise ValueError(f"Content to delete not found: {content_to_delete[:50]}...")
    updated = current.replace(content_to_delete, '', 1)
    _write_context(updated)
    return f"Deleted: '{content_to_delete[:50]}...'. Consequences: Removed specified content; system may lose referenced rules or guidelines."
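A short sketch of how the three modifiers above compose, using a hypothetical rule string; the import path matches the new file pr/tools/context_modifier.py:

# Hypothetical usage sketch.
from pr.tools.context_modifier import (
    modify_context_add,
    modify_context_replace,
    modify_context_delete,
)

modify_context_add("Always answer in English.", position="***")
modify_context_replace("Always answer in English.", "Prefer concise English answers.")
# Deletion raises PermissionError unless explicitly confirmed.
modify_context_delete("Prefer concise English answers.", confirmed=True)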
@@ -2,6 +2,16 @@ import time


def db_set(key, value, db_conn):
    """Set a key-value pair in the database.

    Args:
        key: The key to set.
        value: The value to store.
        db_conn: Database connection.

    Returns:
        Dict with status and message.
    """
    if not db_conn:
        return {"status": "error", "error": "Database not initialized"}

@@ -19,6 +29,15 @@ def db_set(key, value, db_conn):


def db_get(key, db_conn):
    """Get a value from the database.

    Args:
        key: The key to retrieve.
        db_conn: Database connection.

    Returns:
        Dict with status and value.
    """
    if not db_conn:
        return {"status": "error", "error": "Database not initialized"}

@@ -35,6 +54,15 @@ def db_get(key, db_conn):


def db_query(query, db_conn):
    """Execute a database query.

    Args:
        query: SQL query to execute.
        db_conn: Database connection.

    Returns:
        Dict with status and query results.
    """
    if not db_conn:
        return {"status": "error", "error": "Database not initialized"}
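The bodies of db_set, db_get, and db_query are elided in these hunks, so the sketch below only exercises the documented interface; the connection setup, database filename, and key names are assumptions:

# Hypothetical usage sketch; assumes the backing key-value table already exists.
import sqlite3

from pr.tools.database import db_set, db_get, db_query

db_conn = sqlite3.connect("rp.db")
db_set("last_model", "gpt-4", db_conn)
print(db_get("last_model", db_conn))
print(db_query("SELECT COUNT(*) FROM sqlite_master", db_conn))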
838 pr/tools/debugging.py Normal file
@@ -0,0 +1,838 @@
import sys
import os
import ast
import inspect
import time
import threading
import gc
import weakref
import linecache
import re
import json
import subprocess
from collections import defaultdict
from datetime import datetime


class MemoryTracker:
    def __init__(self):
        self.allocations = defaultdict(list)
        self.references = weakref.WeakValueDictionary()
        self.peak_memory = 0
        self.current_memory = 0

    def track_object(self, obj, location):
        try:
            obj_id = id(obj)
            obj_size = sys.getsizeof(obj)
            self.allocations[location].append(
                {
                    "id": obj_id,
                    "type": type(obj).__name__,
                    "size": obj_size,
                    "timestamp": time.time(),
                }
            )
            self.current_memory += obj_size
            if self.current_memory > self.peak_memory:
                self.peak_memory = self.current_memory
        except:
            pass

    def analyze_leaks(self):
        gc.collect()
        leaks = []
        for obj in gc.get_objects():
            if sys.getrefcount(obj) > 10:
                try:
                    leaks.append(
                        {
                            "type": type(obj).__name__,
                            "refcount": sys.getrefcount(obj),
                            "size": sys.getsizeof(obj),
                        }
                    )
                except:
                    pass
        return sorted(leaks, key=lambda x: x["refcount"], reverse=True)[:20]

    def get_report(self):
        return {
            "peak_memory": self.peak_memory,
            "current_memory": self.current_memory,
            "allocation_count": sum(len(v) for v in self.allocations.values()),
            "leaks": self.analyze_leaks(),
        }


class PerformanceProfiler:
    def __init__(self):
        self.function_times = defaultdict(lambda: {"calls": 0, "total_time": 0.0, "self_time": 0.0})
        self.call_stack = []
        self.start_times = {}

    def enter_function(self, frame):
        func_name = self._get_function_name(frame)
        self.call_stack.append(func_name)
        self.start_times[id(frame)] = time.perf_counter()

    def exit_function(self, frame):
        func_name = self._get_function_name(frame)
        frame_id = id(frame)

        if frame_id in self.start_times:
            elapsed = time.perf_counter() - self.start_times[frame_id]
            self.function_times[func_name]["calls"] += 1
            self.function_times[func_name]["total_time"] += elapsed
            self.function_times[func_name]["self_time"] += elapsed
            del self.start_times[frame_id]

        if self.call_stack:
            self.call_stack.pop()

    def _get_function_name(self, frame):
        return f"{frame.f_code.co_filename}:{frame.f_code.co_name}:{frame.f_lineno}"

    def get_hotspots(self, limit=20):
        sorted_funcs = sorted(
            self.function_times.items(), key=lambda x: x[1]["total_time"], reverse=True
        )
        return sorted_funcs[:limit]


class StaticAnalyzer(ast.NodeVisitor):
    def __init__(self):
        self.issues = []
        self.complexity = 0
        self.unused_vars = set()
        self.undefined_vars = set()
        self.defined_vars = set()
        self.used_vars = set()
        self.functions = {}
        self.classes = {}
        self.imports = []

    def visit_FunctionDef(self, node):
        self.functions[node.name] = {
            "lineno": node.lineno,
            "args": [arg.arg for arg in node.args.args],
            "decorators": [
                d.id if isinstance(d, ast.Name) else "complex" for d in node.decorator_list
            ],
            "complexity": self._calculate_complexity(node),
        }

        if len(node.body) == 0:
            self.issues.append(f"Line {node.lineno}: Empty function '{node.name}'")

        if len(node.args.args) > 7:
            self.issues.append(
                f"Line {node.lineno}: Function '{node.name}' has too many parameters ({len(node.args.args)})"
            )

        self.generic_visit(node)

    def visit_ClassDef(self, node):
        self.classes[node.name] = {
            "lineno": node.lineno,
            "bases": [b.id if isinstance(b, ast.Name) else "complex" for b in node.bases],
            "methods": [],
        }
        self.generic_visit(node)

    def visit_Import(self, node):
        for alias in node.names:
            self.imports.append(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        if node.module:
            self.imports.append(node.module)
        self.generic_visit(node)

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Store):
            self.defined_vars.add(node.id)
        elif isinstance(node.ctx, ast.Load):
            self.used_vars.add(node.id)
        self.generic_visit(node)

    def visit_If(self, node):
        self.complexity += 1
        self.generic_visit(node)

    def visit_For(self, node):
        self.complexity += 1
        self.generic_visit(node)

    def visit_While(self, node):
        self.complexity += 1
        self.generic_visit(node)

    def visit_ExceptHandler(self, node):
        self.complexity += 1
        if node.type is None:
            self.issues.append(f"Line {node.lineno}: Bare except clause (catches all exceptions)")
        self.generic_visit(node)

    def visit_BinOp(self, node):
        if isinstance(node.op, ast.Add):
            if isinstance(node.left, ast.Str) or isinstance(node.right, ast.Str):
                self.issues.append(
                    f"Line {node.lineno}: String concatenation with '+' (use f-strings or join)"
                )
        self.generic_visit(node)

    def _calculate_complexity(self, node):
        complexity = 1
        for child in ast.walk(node):
            if isinstance(child, (ast.If, ast.For, ast.While, ast.ExceptHandler)):
                complexity += 1
        return complexity

    def finalize(self):
        self.unused_vars = self.defined_vars - self.used_vars
        self.undefined_vars = self.used_vars - self.defined_vars - set(dir(__builtins__))

        for var in self.unused_vars:
            self.issues.append(f"Unused variable: '{var}'")

    def analyze_code(self, source_code):
        try:
            tree = ast.parse(source_code)
            self.visit(tree)
            self.finalize()
            return {
                "issues": self.issues,
                "complexity": self.complexity,
                "functions": self.functions,
                "classes": self.classes,
                "imports": self.imports,
                "unused_vars": list(self.unused_vars),
            }
        except SyntaxError as e:
            return {"error": f"Syntax error at line {e.lineno}: {e.msg}"}

class DynamicTracer:
    def __init__(self):
        self.execution_trace = []
        self.exception_trace = []
        self.variable_changes = defaultdict(list)
        self.line_coverage = set()
        self.function_calls = defaultdict(int)
        self.max_trace_length = 10000
        self.memory_tracker = MemoryTracker()
        self.profiler = PerformanceProfiler()
        self.trace_active = False

    def trace_calls(self, frame, event, arg):
        if not self.trace_active:
            return

        if len(self.execution_trace) >= self.max_trace_length:
            return self.trace_calls

        co = frame.f_code
        func_name = co.co_name
        filename = co.co_filename
        line_no = frame.f_lineno

        if "site-packages" in filename or filename.startswith("<"):
            return self.trace_calls

        trace_entry = {
            "event": event,
            "function": func_name,
            "filename": filename,
            "lineno": line_no,
            "timestamp": time.time(),
        }

        if event == "call":
            self.function_calls[f"{filename}:{func_name}"] += 1
            self.profiler.enter_function(frame)
            trace_entry["locals"] = {
                k: repr(v)[:100] for k, v in frame.f_locals.items() if not k.startswith("__")
            }

        elif event == "return":
            self.profiler.exit_function(frame)
            trace_entry["return_value"] = repr(arg)[:100] if arg else None

        elif event == "line":
            self.line_coverage.add((filename, line_no))
            line_code = linecache.getline(filename, line_no).strip()
            trace_entry["code"] = line_code

            for var, value in frame.f_locals.items():
                if not var.startswith("__"):
                    self.variable_changes[var].append(
                        {"line": line_no, "value": repr(value)[:100], "timestamp": time.time()}
                    )
                    self.memory_tracker.track_object(value, f"{filename}:{line_no}")

        elif event == "exception":
            exc_type, exc_value, exc_tb = arg
            self.exception_trace.append(
                {
                    "type": exc_type.__name__,
                    "message": str(exc_value),
                    "filename": filename,
                    "function": func_name,
                    "lineno": line_no,
                    "timestamp": time.time(),
                }
            )
            trace_entry["exception"] = {"type": exc_type.__name__, "message": str(exc_value)}

        self.execution_trace.append(trace_entry)
        return self.trace_calls

    def start_tracing(self):
        self.trace_active = True
        sys.settrace(self.trace_calls)
        threading.settrace(self.trace_calls)

    def stop_tracing(self):
        self.trace_active = False
        sys.settrace(None)
        threading.settrace(None)

    def get_trace_report(self):
        return {
            "execution_trace": self.execution_trace[-100:],
            "exception_trace": self.exception_trace,
            "line_coverage": list(self.line_coverage),
            "function_calls": dict(self.function_calls),
            "variable_changes": {k: v[-10:] for k, v in self.variable_changes.items()},
            "hotspots": self.profiler.get_hotspots(20),
            "memory_report": self.memory_tracker.get_report(),
        }


class GitBisectAutomator:
    def __init__(self, repo_path="."):
        self.repo_path = repo_path

    def is_git_repo(self):
        try:
            result = subprocess.run(
                ["git", "rev-parse", "--git-dir"],
                cwd=self.repo_path,
                capture_output=True,
                text=True,
            )
            return result.returncode == 0
        except:
            return False

    def get_commit_history(self, limit=50):
        try:
            result = subprocess.run(
                ["git", "log", f"--max-count={limit}", "--oneline"],
                cwd=self.repo_path,
                capture_output=True,
                text=True,
            )
            if result.returncode == 0:
                commits = []
                for line in result.stdout.strip().split("\n"):
                    parts = line.split(" ", 1)
                    if len(parts) == 2:
                        commits.append({"hash": parts[0], "message": parts[1]})
                return commits
        except:
            pass
        return []

    def blame_file(self, filepath):
        try:
            result = subprocess.run(
                ["git", "blame", "-L", "1,50", filepath],
                cwd=self.repo_path,
                capture_output=True,
                text=True,
            )
            if result.returncode == 0:
                return result.stdout
        except:
            pass
        return None

class LogAnalyzer:
    def __init__(self):
        self.log_patterns = {
            "error": re.compile(r"error|exception|fail|critical", re.IGNORECASE),
            "warning": re.compile(r"warn|caution", re.IGNORECASE),
            "debug": re.compile(r"debug|trace", re.IGNORECASE),
            # Raw string needs single backslashes; the doubled backslashes in the
            # original would match a literal "\d" instead of a digit.
            "timestamp": re.compile(r"\d{4}-\d{2}-\d{2}[\s_T]\d{2}:\d{2}:\d{2}"),
        }
        self.anomalies = []

    def analyze_logs(self, log_content):
        lines = log_content.split("\n")
        errors = []
        warnings = []
        timestamps = []

        for i, line in enumerate(lines):
            if self.log_patterns["error"].search(line):
                errors.append({"line": i + 1, "content": line})
            elif self.log_patterns["warning"].search(line):
                warnings.append({"line": i + 1, "content": line})

            ts_match = self.log_patterns["timestamp"].search(line)
            if ts_match:
                timestamps.append(ts_match.group())

        return {
            "total_lines": len(lines),
            "errors": errors[:50],
            "warnings": warnings[:50],
            "error_count": len(errors),
            "warning_count": len(warnings),
            "timestamps": timestamps[:20],
        }

class ExceptionAnalyzer:
    def __init__(self):
        self.exceptions = []
        self.exception_counts = defaultdict(int)

    def capture_exception(self, exc_type, exc_value, exc_traceback):
        exc_info = {
            "type": exc_type.__name__,
            "message": str(exc_value),
            "timestamp": time.time(),
            "traceback": [],
        }

        tb = exc_traceback
        while tb is not None:
            frame = tb.tb_frame
            exc_info["traceback"].append(
                {
                    "filename": frame.f_code.co_filename,
                    "function": frame.f_code.co_name,
                    "lineno": tb.tb_lineno,
                    "locals": {
                        k: repr(v)[:100]
                        for k, v in frame.f_locals.items()
                        if not k.startswith("__")
                    },
                }
            )
            tb = tb.tb_next

        self.exceptions.append(exc_info)
        self.exception_counts[exc_type.__name__] += 1
        return exc_info

    def get_report(self):
        return {
            "total_exceptions": len(self.exceptions),
            "exception_types": dict(self.exception_counts),
            "recent_exceptions": self.exceptions[-10:],
        }


class TestGenerator:
    def __init__(self):
        self.test_cases = []

    def generate_tests_for_function(self, func_name, func_signature):
        test_template = f"""def test_{func_name}_basic():
    result = {func_name}()
    assert result is not None

def test_{func_name}_edge_cases():
    pass

def test_{func_name}_exceptions():
    pass
"""
        return test_template

    def analyze_function_for_tests(self, func_obj):
        sig = inspect.signature(func_obj)
        test_inputs = []

        for param_name, param in sig.parameters.items():
            if param.annotation != inspect.Parameter.empty:
                param_type = param.annotation
                if param_type == int:
                    test_inputs.append([0, 1, -1, 100])
                elif param_type == str:
                    test_inputs.append(["", "test", "a" * 100])
                elif param_type == list:
                    test_inputs.append([[], [1], [1, 2, 3]])
                else:
                    test_inputs.append([None])
            else:
                test_inputs.append([None, 0, "", []])

        return test_inputs


class CodeFlowVisualizer:
    def __init__(self):
        self.flow_graph = defaultdict(list)
        self.call_hierarchy = defaultdict(set)

    def build_flow_from_trace(self, execution_trace):
        for i in range(len(execution_trace) - 1):
            current = execution_trace[i]
            next_step = execution_trace[i + 1]

            current_node = f"{current['function']}:{current['lineno']}"
            next_node = f"{next_step['function']}:{next_step['lineno']}"

            self.flow_graph[current_node].append(next_node)

            if current["event"] == "call":
                caller = current["function"]
                callee = next_step["function"]
                self.call_hierarchy[caller].add(callee)

    def generate_text_visualization(self):
        output = []
        output.append("Call Hierarchy:")
        for caller, callees in sorted(self.call_hierarchy.items()):
            output.append(f"  {caller}")
            for callee in sorted(callees):
                output.append(f"    -> {callee}")
        return "\n".join(output)


class AutomatedDebugger:
    def __init__(self):
        self.static_analyzer = StaticAnalyzer()
        self.dynamic_tracer = DynamicTracer()
        self.exception_analyzer = ExceptionAnalyzer()
        self.log_analyzer = LogAnalyzer()
        self.git_automator = GitBisectAutomator()
        self.test_generator = TestGenerator()
        self.flow_visualizer = CodeFlowVisualizer()
        self.original_excepthook = sys.excepthook

    def analyze_source_file(self, filepath):
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                source_code = f.read()
            return self.static_analyzer.analyze_code(source_code)
        except Exception as e:
            return {"error": str(e)}

    def run_with_tracing(self, target_func, *args, **kwargs):
        self.dynamic_tracer.start_tracing()
        result = None
        exception = None

        try:
            result = target_func(*args, **kwargs)
        except Exception as e:
            exception = self.exception_analyzer.capture_exception(type(e), e, e.__traceback__)
        finally:
            self.dynamic_tracer.stop_tracing()

        self.flow_visualizer.build_flow_from_trace(self.dynamic_tracer.execution_trace)

        return {
            "result": result,
            "exception": exception,
            "trace": self.dynamic_tracer.get_trace_report(),
            "flow": self.flow_visualizer.generate_text_visualization(),
        }

    def analyze_logs(self, log_file_or_content):
        if os.path.isfile(log_file_or_content):
            with open(log_file_or_content, "r", encoding="utf-8") as f:
                content = f.read()
        else:
            content = log_file_or_content

        return self.log_analyzer.analyze_logs(content)

    def generate_debug_report(self, output_file="debug_report.json"):
        report = {
            "timestamp": datetime.now().isoformat(),
            "static_analysis": self.static_analyzer.issues,
            "dynamic_trace": self.dynamic_tracer.get_trace_report(),
            "exceptions": self.exception_analyzer.get_report(),
            "git_info": (
                self.git_automator.get_commit_history(10)
                if self.git_automator.is_git_repo()
                else None
            ),
            "flow_visualization": self.flow_visualizer.generate_text_visualization(),
        }

        with open(output_file, "w") as f:
            json.dump(report, f, indent=2, default=str)

        return report

    def auto_debug_function(self, func, test_inputs=None):
        results = []

        if test_inputs is None:
            test_inputs = self.test_generator.analyze_function_for_tests(func)

        for input_set in test_inputs:
            try:
                if isinstance(input_set, list):
                    result = self.run_with_tracing(func, *input_set)
                else:
                    result = self.run_with_tracing(func, input_set)
                results.append(
                    {
                        "input": input_set,
                        "success": result["exception"] is None,
                        "output": result["result"],
                        "trace_summary": {
                            "function_calls": len(result["trace"]["function_calls"]),
                            "exceptions": len(result["trace"]["exception_trace"]),
                        },
                    }
                )
            except Exception as e:
                results.append({"input": input_set, "success": False, "error": str(e)})

        return results


# Global instances for tools
_memory_tracker = MemoryTracker()
_performance_profiler = PerformanceProfiler()
_static_analyzer = StaticAnalyzer()
_dynamic_tracer = DynamicTracer()
_git_automator = GitBisectAutomator()
_log_analyzer = LogAnalyzer()
_exception_analyzer = ExceptionAnalyzer()
_test_generator = TestGenerator()
_code_flow_visualizer = CodeFlowVisualizer()
_automated_debugger = AutomatedDebugger()

# Tool functions
def track_memory_allocation(location: str = "manual") -> dict:
    """Track current memory allocation at a specific location."""
    try:
        _memory_tracker.track_object({}, location)
        return {
            "status": "success",
            "current_memory": _memory_tracker.current_memory,
            "peak_memory": _memory_tracker.peak_memory,
            "location": location,
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}


def analyze_memory_leaks() -> dict:
    """Analyze potential memory leaks in the current process."""
    try:
        leaks = _memory_tracker.analyze_leaks()
        return {"status": "success", "leaks_found": len(leaks), "top_leaks": leaks[:10]}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def get_memory_report() -> dict:
    """Get a comprehensive memory usage report."""
    try:
        return {"status": "success", "report": _memory_tracker.get_report()}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def start_performance_profiling() -> dict:
    """Start performance profiling."""
    global _performance_profiler
    try:
        # Reset the shared profiler; the original constructed a new instance
        # and immediately discarded it, so profiling never actually restarted.
        _performance_profiler = PerformanceProfiler()
        return {"status": "success", "message": "Performance profiling started"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def stop_performance_profiling() -> dict:
    """Stop performance profiling and get hotspots."""
    try:
        hotspots = _performance_profiler.get_hotspots(20)
        return {"status": "success", "hotspots": hotspots}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def analyze_source_code(source_code: str) -> dict:
    """Perform static analysis on Python source code."""
    try:
        analyzer = StaticAnalyzer()
        result = analyzer.analyze_code(source_code)
        return {"status": "success", "analysis": result}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def analyze_source_file(filepath: str) -> dict:
    """Analyze a Python source file statically."""
    try:
        result = _automated_debugger.analyze_source_file(filepath)
        return {"status": "success", "filepath": filepath, "analysis": result}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def start_dynamic_tracing() -> dict:
    """Start dynamic execution tracing."""
    try:
        _dynamic_tracer.start_tracing()
        return {"status": "success", "message": "Dynamic tracing started"}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def stop_dynamic_tracing() -> dict:
    """Stop dynamic tracing and get trace report."""
    try:
        _dynamic_tracer.stop_tracing()
        report = _dynamic_tracer.get_trace_report()
        return {"status": "success", "trace_report": report}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def get_git_commit_history(limit: int = 50) -> dict:
    """Get recent git commit history."""
    try:
        commits = _git_automator.get_commit_history(limit)
        return {
            "status": "success",
            "commits": commits,
            "is_git_repo": _git_automator.is_git_repo(),
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}


def blame_file(filepath: str) -> dict:
    """Get git blame information for a file."""
    try:
        blame_output = _git_automator.blame_file(filepath)
        return {"status": "success", "filepath": filepath, "blame": blame_output}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def analyze_log_content(log_content: str) -> dict:
    """Analyze log content for errors, warnings, and patterns."""
    try:
        analysis = _log_analyzer.analyze_logs(log_content)
        return {"status": "success", "analysis": analysis}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def analyze_log_file(filepath: str) -> dict:
    """Analyze a log file."""
    try:
        analysis = _automated_debugger.analyze_logs(filepath)
        return {"status": "success", "filepath": filepath, "analysis": analysis}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def get_exception_report() -> dict:
    """Get a report of captured exceptions."""
    try:
        report = _exception_analyzer.get_report()
        return {"status": "success", "exception_report": report}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def generate_tests_for_function(func_name: str, func_signature: str = "") -> dict:
    """Generate test templates for a function."""
    try:
        test_code = _test_generator.generate_tests_for_function(func_name, func_signature)
        return {"status": "success", "func_name": func_name, "test_code": test_code}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def visualize_code_flow_from_trace(execution_trace) -> dict:
    """Visualize code flow from an execution trace."""
    try:
        visualizer = CodeFlowVisualizer()
        visualizer.build_flow_from_trace(execution_trace)
        visualization = visualizer.generate_text_visualization()
        return {"status": "success", "flow_visualization": visualization}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def run_function_with_debugging(func_code: str, *args, **kwargs) -> dict:
    """Execute a function with full debugging."""
    try:
        # Compile and execute the function definition
        local_vars = {}
        exec(func_code, globals(), local_vars)

        # Find the function (the first public callable defined)
        func = None
        for name, obj in local_vars.items():
            if callable(obj) and not name.startswith("_"):
                func = obj
                break

        if func is None:
            return {"status": "error", "error": "No function found in code"}

        result = _automated_debugger.run_with_tracing(func, *args, **kwargs)
        return {"status": "success", "debug_result": result}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def generate_comprehensive_debug_report(output_file: str = "debug_report.json") -> dict:
    """Generate a comprehensive debug report."""
    try:
        report = _automated_debugger.generate_debug_report(output_file)
        return {
            "status": "success",
            "output_file": output_file,
            "report_summary": {
                "static_issues": len(report.get("static_analysis", [])),
                "exceptions": report.get("exceptions", {}).get("total_exceptions", 0),
                "function_calls": len(report.get("dynamic_trace", {}).get("function_calls", {})),
            },
        }
    except Exception as e:
        return {"status": "error", "error": str(e)}


def auto_debug_function(func_code: str, test_inputs: list = None) -> dict:
    """Automatically debug a function with test inputs."""
    try:
        local_vars = {}
        exec(func_code, globals(), local_vars)

        func = None
        for name, obj in local_vars.items():
            if callable(obj) and not name.startswith("_"):
                func = obj
                break

        if func is None:
            return {"status": "error", "error": "No function found in code"}

        if test_inputs is None:
            test_inputs = _test_generator.analyze_function_for_tests(func)

        results = _automated_debugger.auto_debug_function(func, test_inputs)
        return {"status": "success", "debug_results": results}
    except Exception as e:
        return {"status": "error", "error": str(e)}
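A rough sketch of driving the new module end to end; the sample function is hypothetical, and the auto-generated inputs are crude per-parameter value lists, so many cases will simply record success=False rather than meaningful results:

# Hypothetical usage sketch.
from pr.tools.debugging import analyze_source_code, auto_debug_function

code = "def add(a: int, b: int):\n    return a + b\n"
print(analyze_source_code(code)["analysis"]["functions"])

report = auto_debug_function(code)
for case in report["debug_results"]:
    print(case["input"], case["success"])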
@@ -2,7 +2,7 @@ import os
import os.path

from pr.editor import RPEditor
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer


from ..tools.patch import display_content_diff
from ..ui.edit_feedback import track_edit, tracker
@@ -17,6 +17,8 @@ def get_editor(filepath):


def close_editor(filepath):
    from pr.multiplexer import close_multiplexer, get_multiplexer

    try:
        path = os.path.expanduser(filepath)
        editor = get_editor(path)
@@ -34,6 +36,8 @@ def close_editor(filepath):


def open_editor(filepath):
    from pr.multiplexer import create_multiplexer

    try:
        path = os.path.expanduser(filepath)
        editor = RPEditor(path)
@@ -53,6 +57,8 @@ def open_editor(filepath):


def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
    from pr.multiplexer import get_multiplexer

    try:
        path = os.path.expanduser(filepath)

@@ -98,6 +104,8 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
def editor_replace_text(
    filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True
):
    from pr.multiplexer import get_multiplexer

    try:
        path = os.path.expanduser(filepath)

@@ -148,6 +156,8 @@ def editor_replace_text(


def editor_search(filepath, pattern, start_line=0):
    from pr.multiplexer import get_multiplexer

    try:
        path = os.path.expanduser(filepath)
        editor = RPEditor(path)

@@ -1,6 +1,7 @@
import hashlib
import os
import time
from typing import Optional, Any

from pr.editor import RPEditor

@@ -17,7 +18,17 @@ def get_uid():
    return _id


def read_file(filepath, db_conn=None):
def read_file(filepath: str, db_conn: Optional[Any] = None) -> dict:
    """
    Read the contents of a file.

    Args:
        filepath: Path to the file to read
        db_conn: Optional database connection for tracking

    Returns:
        dict: Status and content or error
    """
    try:
        path = os.path.expanduser(filepath)
        with open(path) as f:
@@ -31,7 +42,22 @@ def read_file(filepath, db_conn=None):
        return {"status": "error", "error": str(e)}


def write_file(filepath, content, db_conn=None, show_diff=True):
def write_file(
    filepath: str, content: str, db_conn: Optional[Any] = None, show_diff: bool = True
) -> dict:
    """
    Write content to a file.

    Args:
        filepath: Path to the file to write
        content: Content to write
        db_conn: Optional database connection for tracking
        show_diff: Whether to show diff of changes

    Returns:
        dict: Status and message or error
    """
    operation = None
    try:
        path = os.path.expanduser(filepath)
        old_content = ""
@@ -93,12 +119,13 @@ def write_file(filepath, content, db_conn=None, show_diff=True):

        return {"status": "success", "message": message}
    except Exception as e:
        if "operation" in locals():
        if operation is not None:
            tracker.mark_failed(operation)
        return {"status": "error", "error": str(e)}


def list_directory(path=".", recursive=False):
    """List files and directories in the specified path."""
    try:
        path = os.path.expanduser(path)
        items = []
@@ -139,6 +166,7 @@ def mkdir(path):


def chdir(path):
    """Change the current working directory."""
    try:
        os.chdir(os.path.expanduser(path))
        return {"status": "success", "new_path": os.getcwd()}
@@ -147,13 +175,23 @@ def chdir(path):


def getpwd():
    """Get the current working directory."""
    try:
        return {"status": "success", "path": os.getcwd()}
    except Exception as e:
        return {"status": "error", "error": str(e)}


def index_source_directory(path):
def index_source_directory(path: str) -> dict:
    """
    Index directory recursively and read all source files.

    Args:
        path: Path to index

    Returns:
        dict: Status and indexed files or error
    """
    extensions = [
        ".py",
        ".js",
@@ -189,7 +227,21 @@ def index_source_directory(path):
        return {"status": "error", "error": str(e)}


def search_replace(filepath, old_string, new_string, db_conn=None):
def search_replace(
    filepath: str, old_string: str, new_string: str, db_conn: Optional[Any] = None
) -> dict:
    """
    Search and replace text in a file.

    Args:
        filepath: Path to the file
        old_string: String to replace
        new_string: Replacement string
        db_conn: Optional database connection for tracking

    Returns:
        dict: Status and message or error
    """
    try:
        path = os.path.expanduser(filepath)
        if not os.path.exists(path):
@@ -246,6 +298,7 @@ def open_editor(filepath):


def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None):
    operation = None
    try:
        path = os.path.expanduser(filepath)
        if db_conn:
@@ -283,7 +336,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c
            tracker.mark_completed(operation)
        return {"status": "success", "message": f"Inserted text in {path}"}
    except Exception as e:
        if "operation" in locals():
        if operation is not None:
            tracker.mark_failed(operation)
        return {"status": "error", "error": str(e)}

@@ -299,6 +352,7 @@ def editor_replace_text(
    db_conn=None,
):
    try:
        operation = None
        path = os.path.expanduser(filepath)
        if db_conn:
            from pr.tools.database import db_get
@@ -341,7 +395,7 @@ def editor_replace_text(
            tracker.mark_completed(operation)
        return {"status": "success", "message": f"Replaced text in {path}"}
    except Exception as e:
        if "operation" in locals():
        if operation is not None:
            tracker.mark_failed(operation)
        return {"status": "error", "error": str(e)}

@@ -1,12 +1,17 @@
import subprocess
import threading
import importlib

from pr.multiplexer import (
    close_multiplexer,
    create_multiplexer,
    get_all_multiplexer_states,
    get_multiplexer,
)

def _get_multiplexer_functions():
    """Lazy import multiplexer functions to avoid circular imports."""
    multiplexer = importlib.import_module("pr.multiplexer")
    return {
        "create_multiplexer": multiplexer.create_multiplexer,
        "get_multiplexer": multiplexer.get_multiplexer,
        "close_multiplexer": multiplexer.close_multiplexer,
        "get_all_multiplexer_states": multiplexer.get_all_multiplexer_states,
    }


def start_interactive_session(command, session_name=None, process_type="generic", cwd=None):
@@ -22,7 +27,8 @@ def start_interactive_session(command, session_name=None, process_type="generic"
    Returns:
        session_name: The name of the created session
    """
    name, mux = create_multiplexer(session_name)
    funcs = _get_multiplexer_functions()
    name, mux = funcs["create_multiplexer"](session_name)
    mux.update_metadata("process_type", process_type)

    # Start the process
@@ -65,7 +71,7 @@ def start_interactive_session(command, session_name=None, process_type="generic"

        return name
    except Exception as e:
        close_multiplexer(name)
        funcs["close_multiplexer"](name)
        raise e


@@ -97,7 +103,8 @@ def send_input_to_session(session_name, input_data):
        session_name: Name of the session
        input_data: Input string to send
    """
    mux = get_multiplexer(session_name)
    funcs = _get_multiplexer_functions()
    mux = funcs["get_multiplexer"](session_name)
    if not mux:
        raise ValueError(f"Session {session_name} not found")

@@ -112,6 +119,7 @@ def send_input_to_session(session_name, input_data):


def read_session_output(session_name, lines=None):
    funcs = _get_multiplexer_functions()
    """
    Read output from a session.

@@ -122,7 +130,7 @@ def read_session_output(session_name, lines=None):
    Returns:
        dict: {'stdout': str, 'stderr': str}
    """
    mux = get_multiplexer(session_name)
    mux = funcs["get_multiplexer"](session_name)
    if not mux:
        raise ValueError(f"Session {session_name} not found")

@@ -142,7 +150,8 @@ def list_active_sessions():
    Returns:
        dict: Session states
    """
    return get_all_multiplexer_states()
    funcs = _get_multiplexer_functions()
    return funcs["get_all_multiplexer_states"]()


def get_session_status(session_name):
@@ -155,7 +164,8 @@ def get_session_status(session_name):
    Returns:
        dict: Session metadata and status
    """
    mux = get_multiplexer(session_name)
    funcs = _get_multiplexer_functions()
    mux = funcs["get_multiplexer"](session_name)
    if not mux:
        return None

@@ -175,10 +185,11 @@ def close_interactive_session(session_name):
    Close an interactive session.
    """
    try:
        mux = get_multiplexer(session_name)
        funcs = _get_multiplexer_functions()
        mux = funcs["get_multiplexer"](session_name)
        if mux:
            mux.process.kill()
        close_multiplexer(session_name)
        funcs["close_multiplexer"](session_name)
        return {"status": "success"}
    except Exception as e:
        return {"status": "error", "error": str(e)}

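A sketch of the session lifecycle above; the module path is an assumption (this diff shows only the hunks, not the filename), and missing sessions raise ValueError:

# Hypothetical usage sketch; module path assumed.
from pr.tools.interactive_control import (
    start_interactive_session,
    send_input_to_session,
    read_session_output,
    close_interactive_session,
)

name = start_interactive_session("python3 -i", process_type="repl")
send_input_to_session(name, "print(1 + 1)\n")
print(read_session_output(name)["stdout"])
close_interactive_session(name)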
@@ -7,6 +7,16 @@ from ..ui.diff_display import display_diff, get_diff_stats


def apply_patch(filepath, patch_content, db_conn=None):
    """Apply a patch to a file.

    Args:
        filepath: Path to the file to patch.
        patch_content: The patch content as a string.
        db_conn: Database connection (optional).

    Returns:
        Dict with status and output.
    """
    try:
        path = os.path.expanduser(filepath)
        if db_conn:
@@ -41,6 +51,19 @@ def apply_patch(filepath, patch_content, db_conn=None):
def create_diff(
    file1, file2, fromfile="file1", tofile="file2", visual=False, format_type="unified"
):
    """Create a unified diff between two files.

    Args:
        file1: Path to the first file.
        file2: Path to the second file.
        fromfile: Label for the first file.
        tofile: Label for the second file.
        visual: Whether to include visual diff.
        format_type: Diff format type.

    Returns:
        Dict with status and diff content.
    """
    try:
        path1 = os.path.expanduser(file1)
        path2 = os.path.expanduser(file2)

@@ -5,6 +5,16 @@ from io import StringIO


def python_exec(code, python_globals, cwd=None):
    """Execute Python code and capture the output.

    Args:
        code: The Python code to execute.
        python_globals: Dictionary of global variables for execution.
        cwd: Working directory for execution.

    Returns:
        Dict with status and output, or error information.
    """
    try:
        original_cwd = None
        if cwd:

19 pr/tools/vision.py Normal file
@@ -0,0 +1,19 @@
from pr.vision import post_image as vision_post_image
import functools


@functools.lru_cache()
def post_image(path: str, prompt: str = None):
    """Post an image for analysis.

    Args:
        path: Path to the image file.
        prompt: Optional prompt for analysis.

    Returns:
        Analysis result.
    """
    try:
        # pr.vision.post_image takes image_path, not path; the original keyword
        # would raise TypeError on every call.
        return vision_post_image(image_path=path, prompt=prompt)
    except Exception:
        raise
@@ -5,6 +5,15 @@ import urllib.request


def http_fetch(url, headers=None):
    """Fetch content from an HTTP URL.

    Args:
        url: The URL to fetch.
        headers: Optional HTTP headers.

    Returns:
        Dict with status and content.
    """
    try:
        req = urllib.request.Request(url)
        if headers:
@@ -30,10 +39,26 @@ def _perform_search(base_url, query, params=None):


def web_search(query):
    """Perform a web search.

    Args:
        query: Search query.

    Returns:
        Dict with status and search results.
    """
    base_url = "https://search.molodetz.nl/search"
    return _perform_search(base_url, query)


def web_search_news(query):
    """Perform a web search for news.

    Args:
        query: Search query for news.

    Returns:
        Dict with status and news search results.
    """
    base_url = "https://search.molodetz.nl/search"
    return _perform_search(base_url, query)

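A sketch against the documented interface above; the module path is an assumption:

# Hypothetical usage sketch; module path assumed.
from pr.tools.web import http_fetch, web_search

page = http_fetch("https://example.com", headers={"Accept": "text/html"})
results = web_search("sqlite wal mode")
if results.get("status") == "success":
    print(results)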
50 pr/vision.py Executable file
@@ -0,0 +1,50 @@
#!/usr/bin/env python

import argparse
import base64
import http.client
import json
import pathlib

DEFAULT_URL = "https://static.molodetz.nl/rp.vision.cgi"


def post_image(image_path: str, prompt: str = "", url: str = DEFAULT_URL):
    image_path = str(pathlib.Path(image_path).resolve().absolute())
    if not url:
        url = DEFAULT_URL
    url_parts = url.split("/")
    host = url_parts[2]
    path = "/" + "/".join(url_parts[3:])

    with open(image_path, "rb") as file:
        image_data = file.read()
    base64_data = base64.b64encode(image_data).decode("utf-8")

    payload = {"data": base64_data, "path": image_path, "prompt": prompt}
    body = json.dumps(payload).encode("utf-8")

    headers = {
        "Content-Type": "application/json",
        "Content-Length": str(len(body)),
        "User-Agent": "Python http.client",
    }

    conn = http.client.HTTPSConnection(host)
    conn.request("POST", path, body, headers)
    resp = conn.getresponse()
    data = resp.read()
    print("Status:", resp.status, resp.reason)
    print(data.decode())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("image_path")
    parser.add_argument("--prompt", default="")
    parser.add_argument("--url", default=DEFAULT_URL)
    args = parser.parse_args()
    # Arguments must match the signature (image_path, prompt, url); the
    # original call passed the URL where the image path belongs.
    post_image(args.image_path, args.prompt, args.url)
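Beyond the CLI, the function can be called directly; the image path below is a placeholder, and the function prints the HTTP status and raw response body rather than returning parsed data:

# Hypothetical usage sketch; "screenshot.png" is a placeholder path.
from pr.vision import post_image

post_image("screenshot.png", prompt="Describe this screenshot")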
@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"

[project]
name = "rp"
version = "1.2.0"
version = "1.3.0"
description = "R python edition. The ultimate autonomous AI CLI."
readme = "README.md"
requires-python = ">=3.13.3"
requires-python = ">=3.12"
license = {text = "MIT"}
keywords = ["ai", "assistant", "cli", "automation", "openrouter", "autonomous"]
authors = [
@@ -16,6 +16,7 @@ authors = [
dependencies = [
    "aiohttp>=3.13.2",
    "pydantic>=2.12.3",
    "prompt_toolkit>=3.0.0",
]
classifiers = [
    "Development Status :: 4 - Beta",

6 rp.py
@@ -2,6 +2,12 @@

# Trigger build

import sys
import os

# Add current directory to path to ensure imports work
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from pr.__main__ import main

if __name__ == "__main__":

@@ -1,10 +1,8 @@
from unittest.mock import mock_open, patch
from unittest.mock import patch

from pr.core.config_loader import (
    _load_config_file,
    _parse_value,
    create_default_config,
    load_config,
)


@@ -36,36 +34,3 @@ def test_parse_value_bool_upper():
def test_load_config_file_not_exists(mock_exists):
    config = _load_config_file("test.ini")
    assert config == {}


@patch("os.path.exists", return_value=True)
@patch("configparser.ConfigParser")
def test_load_config_file_exists(mock_parser_class, mock_exists):
    mock_parser = mock_parser_class.return_value
    mock_parser.sections.return_value = ["api"]
    mock_parser.items.return_value = [("key", "value")]
    config = _load_config_file("test.ini")
    assert "api" in config
    assert config["api"]["key"] == "value"


@patch("pr.core.config_loader._load_config_file")
def test_load_config(mock_load):
    mock_load.side_effect = [{"api": {"key": "global"}}, {"api": {"key": "local"}}]
    config = load_config()
    assert config["api"]["key"] == "local"


@patch("builtins.open", new_callable=mock_open)
def test_create_default_config(mock_file):
    result = create_default_config("test.ini")
    assert result == True
    mock_file.assert_called_once_with("test.ini", "w")
    handle = mock_file()
    handle.write.assert_called_once()


@patch("builtins.open", side_effect=Exception("error"))
def test_create_default_config_error(mock_file):
    result = create_default_config("test.ini")
    assert result == False

372 tests/test_conversation_memory.py Normal file
@@ -0,0 +1,372 @@
import sqlite3
import tempfile
import os
import time

from pr.memory.conversation_memory import ConversationMemory


class TestConversationMemory:
    def setup_method(self):
        """Set up test database for each test."""
        self.db_fd, self.db_path = tempfile.mkstemp()
        self.memory = ConversationMemory(self.db_path)

    def teardown_method(self):
        """Clean up test database after each test."""
        self.memory = None
        os.close(self.db_fd)
        os.unlink(self.db_path)

    def test_init(self):
        """Test ConversationMemory initialization."""
        assert self.memory.db_path == self.db_path
        # Verify tables were created
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cursor.fetchall()]
        assert "conversation_history" in tables
        assert "conversation_messages" in tables
        conn.close()

    def test_create_conversation(self):
        """Test creating a new conversation."""
        conversation_id = "test_conv_123"
        session_id = "test_session_456"

        self.memory.create_conversation(conversation_id, session_id)

        # Verify conversation was created
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        row = cursor.fetchone()
        assert row[0] == conversation_id
        assert row[1] == session_id
        conn.close()

    def test_create_conversation_without_session(self):
        """Test creating a conversation without session ID."""
        conversation_id = "test_conv_no_session"

        self.memory.create_conversation(conversation_id)

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        row = cursor.fetchone()
        assert row[0] == conversation_id
        assert row[1] is None
        conn.close()

    def test_create_conversation_with_metadata(self):
        """Test creating a conversation with metadata."""
        conversation_id = "test_conv_metadata"
        metadata = {"topic": "test", "priority": "high"}

        self.memory.create_conversation(conversation_id, metadata=metadata)

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT conversation_id, metadata FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        row = cursor.fetchone()
        assert row[0] == conversation_id
        assert row[1] is not None
        conn.close()

    def test_add_message(self):
        """Test adding a message to a conversation."""
        conversation_id = "test_conv_msg"
        message_id = "test_msg_123"
        role = "user"
        content = "Hello, world!"

        # Create conversation first
        self.memory.create_conversation(conversation_id)

        # Add message
        self.memory.add_message(conversation_id, message_id, role, content)

        # Verify message was added
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT message_id, conversation_id, role, content FROM conversation_messages WHERE message_id = ?",
            (message_id,),
        )
        row = cursor.fetchone()
        assert row[0] == message_id
        assert row[1] == conversation_id
        assert row[2] == role
        assert row[3] == content

        # Verify message count was updated
        cursor.execute(
            "SELECT message_count FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        count_row = cursor.fetchone()
        assert count_row[0] == 1
        conn.close()

    def test_add_message_with_tool_calls(self):
        """Test adding a message with tool calls."""
        conversation_id = "test_conv_tools"
        message_id = "test_msg_tools"
        role = "assistant"
        content = "I'll help you with that."
        tool_calls = [{"function": "test_func", "args": {"param": "value"}}]

        self.memory.create_conversation(conversation_id)
        self.memory.add_message(conversation_id, message_id, role, content, tool_calls=tool_calls)

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT tool_calls FROM conversation_messages WHERE message_id = ?", (message_id,)
        )
        row = cursor.fetchone()
        assert row[0] is not None
        conn.close()

    def test_add_message_with_metadata(self):
        """Test adding a message with metadata."""
        conversation_id = "test_conv_meta"
        message_id = "test_msg_meta"
        role = "user"
        content = "Test message"
        metadata = {"tokens": 5, "model": "gpt-4"}

        self.memory.create_conversation(conversation_id)
        self.memory.add_message(conversation_id, message_id, role, content, metadata=metadata)

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT metadata FROM conversation_messages WHERE message_id = ?", (message_id,)
        )
        row = cursor.fetchone()
        assert row[0] is not None
        conn.close()

    def test_get_conversation_messages(self):
        """Test retrieving conversation messages."""
        conversation_id = "test_conv_get"
        self.memory.create_conversation(conversation_id)

        # Add multiple messages
        messages = [
            ("msg1", "user", "Hello"),
            ("msg2", "assistant", "Hi there"),
            ("msg3", "user", "How are you?"),
        ]

        for msg_id, role, content in messages:
            self.memory.add_message(conversation_id, msg_id, role, content)

        # Retrieve all messages
        retrieved = self.memory.get_conversation_messages(conversation_id)
        assert len(retrieved) == 3
        assert retrieved[0]["message_id"] == "msg1"
        assert retrieved[1]["message_id"] == "msg2"
        assert retrieved[2]["message_id"] == "msg3"

    def test_get_conversation_messages_limited(self):
        """Test retrieving a limited number of conversation messages."""
        conversation_id = "test_conv_limit"
        self.memory.create_conversation(conversation_id)

        # Add multiple messages
        for i in range(5):
            self.memory.add_message(conversation_id, f"msg{i}", "user", f"Message {i}")

        # Retrieve limited messages
        retrieved = self.memory.get_conversation_messages(conversation_id, limit=3)
        assert len(retrieved) == 3
        # Should return most recent messages first due to DESC order
        assert retrieved[0]["message_id"] == "msg4"
        assert retrieved[1]["message_id"] == "msg3"
        assert retrieved[2]["message_id"] == "msg2"

    def test_update_conversation_summary(self):
        """Test updating conversation summary."""
        conversation_id = "test_conv_summary"
        self.memory.create_conversation(conversation_id)

        summary = "This is a test conversation summary"
        topics = ["testing", "memory", "conversation"]

        self.memory.update_conversation_summary(conversation_id, summary, topics)

        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT summary, topics, ended_at FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        row = cursor.fetchone()
        assert row[0] == summary
        assert row[1] is not None  # topics should be stored
        assert row[2] is not None  # ended_at should be set
        conn.close()

    def test_search_conversations(self):
        """Test searching conversations by content."""
        # Create conversations with different content
        conv1 = "conv_search_1"
        conv2 = "conv_search_2"
        conv3 = "conv_search_3"

        self.memory.create_conversation(conv1)
        self.memory.create_conversation(conv2)
        self.memory.create_conversation(conv3)

        # Add messages with searchable content
        self.memory.add_message(conv1, "msg1", "user", "Python programming tutorial")
        self.memory.add_message(conv2, "msg2", "user", "JavaScript development guide")
        self.memory.add_message(conv3, "msg3", "user", "Database design principles")

        # Search for "programming"
        results = self.memory.search_conversations("programming")
        assert len(results) == 1
        assert results[0]["conversation_id"] == conv1

        # Search for "development"
        results = self.memory.search_conversations("development")
        assert len(results) == 1
        assert results[0]["conversation_id"] == conv2

    def test_get_recent_conversations(self):
        """Test getting recent conversations."""
        # Create conversations at different times
        conv1 = "conv_recent_1"
        conv2 = "conv_recent_2"
        conv3 = "conv_recent_3"

        self.memory.create_conversation(conv1)
        time.sleep(0.01)  # Small delay to ensure different timestamps
        self.memory.create_conversation(conv2)
        time.sleep(0.01)
        self.memory.create_conversation(conv3)

        # Get recent conversations
|
||||
recent = self.memory.get_recent_conversations(limit=2)
|
||||
assert len(recent) == 2
|
||||
# Should be ordered by started_at DESC
|
||||
assert recent[0]["conversation_id"] == conv3
|
||||
assert recent[1]["conversation_id"] == conv2
|
||||
|
||||
def test_get_recent_conversations_by_session(self):
|
||||
"""Test getting recent conversations for a specific session."""
|
||||
session1 = "session_1"
|
||||
session2 = "session_2"
|
||||
|
||||
conv1 = "conv_session_1"
|
||||
conv2 = "conv_session_2"
|
||||
conv3 = "conv_session_3"
|
||||
|
||||
self.memory.create_conversation(conv1, session1)
|
||||
self.memory.create_conversation(conv2, session2)
|
||||
self.memory.create_conversation(conv3, session1)
|
||||
|
||||
# Get conversations for session1
|
||||
session_convs = self.memory.get_recent_conversations(session_id=session1)
|
||||
assert len(session_convs) == 2
|
||||
conversation_ids = [c["conversation_id"] for c in session_convs]
|
||||
assert conv1 in conversation_ids
|
||||
assert conv3 in conversation_ids
|
||||
assert conv2 not in conversation_ids
|
||||
|
||||
def test_delete_conversation(self):
|
||||
"""Test deleting a conversation."""
|
||||
conversation_id = "conv_delete"
|
||||
self.memory.create_conversation(conversation_id)
|
||||
|
||||
# Add some messages
|
||||
self.memory.add_message(conversation_id, "msg1", "user", "Test message")
|
||||
self.memory.add_message(conversation_id, "msg2", "assistant", "Response")
|
||||
|
||||
# Delete conversation
|
||||
result = self.memory.delete_conversation(conversation_id)
|
||||
assert result is True
|
||||
|
||||
# Verify conversation and messages are gone
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM conversation_history WHERE conversation_id = ?",
|
||||
(conversation_id,),
|
||||
)
|
||||
assert cursor.fetchone()[0] == 0
|
||||
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM conversation_messages WHERE conversation_id = ?",
|
||||
(conversation_id,),
|
||||
)
|
||||
assert cursor.fetchone()[0] == 0
|
||||
conn.close()
|
||||
|
||||
def test_delete_nonexistent_conversation(self):
|
||||
"""Test deleting a non-existent conversation."""
|
||||
result = self.memory.delete_conversation("nonexistent")
|
||||
assert result is False
|
||||
|
||||
def test_get_statistics(self):
|
||||
"""Test getting memory statistics."""
|
||||
# Create some conversations and messages
|
||||
for i in range(3):
|
||||
conv_id = f"conv_stats_{i}"
|
||||
self.memory.create_conversation(conv_id)
|
||||
for j in range(2):
|
||||
self.memory.add_message(conv_id, f"msg_{i}_{j}", "user", f"Message {j}")
|
||||
|
||||
stats = self.memory.get_statistics()
|
||||
assert stats["total_conversations"] == 3
|
||||
assert stats["total_messages"] == 6
|
||||
assert stats["average_messages_per_conversation"] == 2.0
|
||||
|
||||
def test_thread_safety(self):
|
||||
"""Test that the memory can handle concurrent access."""
|
||||
import threading
|
||||
import queue
|
||||
|
||||
results = queue.Queue()
|
||||
|
||||
def worker(worker_id):
|
||||
try:
|
||||
conv_id = f"conv_thread_{worker_id}"
|
||||
self.memory.create_conversation(conv_id)
|
||||
self.memory.add_message(conv_id, f"msg_{worker_id}", "user", f"Worker {worker_id}")
|
||||
results.put(True)
|
||||
except Exception as e:
|
||||
results.put(e)
|
||||
|
||||
# Start multiple threads
|
||||
threads = []
|
||||
for i in range(5):
|
||||
t = threading.Thread(target=worker, args=(i,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
# Wait for all threads
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Check results
|
||||
for _ in range(5):
|
||||
result = results.get()
|
||||
assert result is True
|
||||
|
||||
# Verify all conversations were created
|
||||
recent = self.memory.get_recent_conversations(limit=10)
|
||||
assert len(recent) >= 5
|
||||
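Taken together, the queries in these tests pin down the storage contract the conversation memory backend must provide. As a reading aid, here is a minimal DDL sketch that would satisfy every SELECT above; the column types and defaults are assumptions, and the authoritative schema is whatever pr/memory actually creates:

import sqlite3

# Hypothetical schema inferred from the queries in the tests above.
SCHEMA = """
CREATE TABLE IF NOT EXISTS conversation_history (
    conversation_id TEXT PRIMARY KEY,
    session_id TEXT,
    started_at TEXT,
    ended_at TEXT,
    summary TEXT,
    topics TEXT,                    -- JSON-encoded list
    metadata TEXT,                  -- JSON-encoded dict
    message_count INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS conversation_messages (
    message_id TEXT PRIMARY KEY,
    conversation_id TEXT REFERENCES conversation_history(conversation_id),
    role TEXT,
    content TEXT,
    tool_calls TEXT,                -- JSON-encoded, nullable
    metadata TEXT                   -- JSON-encoded, nullable
);
"""


def init_schema(path):
    """Create the two tables the tests query, if they do not already exist."""
    conn = sqlite3.connect(path)
    conn.executescript(SCHEMA)
    conn.close()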
242
tests/test_fact_extractor.py
Normal file
@ -0,0 +1,242 @@
from pr.memory.fact_extractor import FactExtractor


class TestFactExtractor:
    def setup_method(self):
        """Set up test fixture."""
        self.extractor = FactExtractor()

    def test_init(self):
        """Test FactExtractor initialization."""
        assert self.extractor.fact_patterns is not None
        assert len(self.extractor.fact_patterns) > 0

    def test_extract_facts_definition(self):
        """Test extracting definition facts."""
        text = "John Smith is a software engineer. Python is a programming language."
        facts = self.extractor.extract_facts(text)

        assert len(facts) >= 2
        # Check for definition pattern matches
        definition_facts = [f for f in facts if f["type"] == "definition"]
        assert len(definition_facts) >= 1

    def test_extract_facts_temporal(self):
        """Test extracting temporal facts."""
        text = "John was born in 1990. The company was founded in 2010."
        facts = self.extractor.extract_facts(text)

        temporal_facts = [f for f in facts if f["type"] == "temporal"]
        assert len(temporal_facts) >= 1

    def test_extract_facts_attribution(self):
        """Test extracting attribution facts."""
        text = "John invented the widget. Mary developed the software."
        facts = self.extractor.extract_facts(text)

        attribution_facts = [f for f in facts if f["type"] == "attribution"]
        assert len(attribution_facts) >= 1

    def test_extract_facts_numeric(self):
        """Test extracting numeric facts."""
        text = "The car costs $25,000. The house is worth $500,000."
        facts = self.extractor.extract_facts(text)

        numeric_facts = [f for f in facts if f["type"] == "numeric"]
        assert len(numeric_facts) >= 1

    def test_extract_facts_location(self):
        """Test extracting location facts."""
        text = "John lives in San Francisco. The office is located in New York."
        facts = self.extractor.extract_facts(text)

        location_facts = [f for f in facts if f["type"] == "location"]
        assert len(location_facts) >= 1

    def test_extract_facts_entity(self):
        """Test extracting entity facts from noun phrases."""
        text = "John Smith works at Google Inc. He uses Python programming."
        facts = self.extractor.extract_facts(text)

        entity_facts = [f for f in facts if f["type"] == "entity"]
        assert len(entity_facts) >= 1

    def test_extract_noun_phrases(self):
        """Test noun phrase extraction."""
        text = "John Smith is a software engineer at Google. He works on Python Projects."
        phrases = self.extractor._extract_noun_phrases(text)

        assert "John Smith" in phrases
        assert "Google" in phrases
        assert "Python Projects" in phrases

    def test_extract_noun_phrases_capitalized(self):
        """Test that only capitalized noun phrases are extracted."""
        text = "the quick brown fox jumps over the lazy dog"
        phrases = self.extractor._extract_noun_phrases(text)

        # Should be empty since no capitalized words
        assert len(phrases) == 0

    def test_extract_key_terms(self):
        """Test key term extraction."""
        text = "Python is a programming language used for software development and data analysis."
        terms = self.extractor.extract_key_terms(text, top_k=5)

        assert len(terms) <= 5
        # Should contain programming, language, software, development, data, analysis
        term_words = [term[0] for term in terms]
        assert "programming" in term_words
        assert "language" in term_words
        assert "software" in term_words

    def test_extract_key_terms_stopwords_filtered(self):
        """Test that stopwords are filtered from key terms."""
        text = "This is a test of the system that should work properly."
        terms = self.extractor.extract_key_terms(text)

        term_words = [term[0] for term in terms]
        # Stopwords should not appear
        assert "this" not in term_words
        assert "is" not in term_words
        assert "a" not in term_words
        assert "of" not in term_words
        assert "the" not in term_words

    def test_extract_relationships_employment(self):
        """Test extracting employment relationships."""
        text = "John works for Google. Mary is employed by Microsoft."
        relationships = self.extractor.extract_relationships(text)

        employment_rels = [r for r in relationships if r["type"] == "employment"]
        assert len(employment_rels) >= 1

    def test_extract_relationships_ownership(self):
        """Test extracting ownership relationships."""
        text = "John owns a car. Mary has a house."
        relationships = self.extractor.extract_relationships(text)

        ownership_rels = [r for r in relationships if r["type"] == "ownership"]
        assert len(ownership_rels) >= 1

    def test_extract_relationships_location(self):
        """Test extracting location relationships."""
        text = "John located in New York. The factory belongs to Google."
        relationships = self.extractor.extract_relationships(text)

        location_rels = [r for r in relationships if r["type"] == "location"]
        assert len(location_rels) >= 1

    def test_extract_relationships_usage(self):
        """Test extracting usage relationships."""
        text = "John uses Python. The company implements agile methodology."
        relationships = self.extractor.extract_relationships(text)

        usage_rels = [r for r in relationships if r["type"] == "usage"]
        assert len(usage_rels) >= 1

    def test_extract_metadata(self):
        """Test metadata extraction."""
        text = "This is a test document. It contains some information about Python programming. You can visit https://python.org for more details. Contact john@example.com for questions. The project started in 2020 and costs $10,000."
        metadata = self.extractor.extract_metadata(text)

        assert metadata["word_count"] > 0
        assert metadata["sentence_count"] > 0
        assert metadata["avg_words_per_sentence"] > 0
        assert len(metadata["urls"]) > 0
        assert len(metadata["email_addresses"]) > 0
        assert len(metadata["dates"]) > 0
        assert len(metadata["numeric_values"]) > 0
        assert metadata["has_code"] is False  # No code in this text

    def test_extract_metadata_with_code(self):
        """Test metadata extraction with code content."""
        text = "Here is a function: def hello(): print('Hello, world!')"
        metadata = self.extractor.extract_metadata(text)

        assert metadata["has_code"] is True

    def test_extract_metadata_with_questions(self):
        """Test metadata extraction with questions."""
        text = "What is Python? How does it work? Why use it?"
        metadata = self.extractor.extract_metadata(text)

        assert metadata["has_questions"] is True

    def test_categorize_content_programming(self):
        """Test content categorization for programming."""
        text = "Python is a programming language used for code development and debugging."
        categories = self.extractor.categorize_content(text)

        assert "programming" in categories

    def test_categorize_content_data(self):
        """Test content categorization for data."""
        text = "The database contains records and tables with statistical analysis."
        categories = self.extractor.categorize_content(text)

        assert "data" in categories

    def test_categorize_content_documentation(self):
        """Test content categorization for documentation."""
        text = "This guide explains how to use the tutorial and manual."
        categories = self.extractor.categorize_content(text)

        assert "documentation" in categories

    def test_categorize_content_configuration(self):
        """Test content categorization for configuration."""
        text = "Configure the settings and setup the deployment environment."
        categories = self.extractor.categorize_content(text)

        assert "configuration" in categories

    def test_categorize_content_testing(self):
        """Test content categorization for testing."""
        text = "Run the tests to validate the functionality and verify quality."
        categories = self.extractor.categorize_content(text)

        assert "testing" in categories

    def test_categorize_content_research(self):
        """Test content categorization for research."""
        text = "The study investigates findings and results from the analysis."
        categories = self.extractor.categorize_content(text)

        assert "research" in categories

    def test_categorize_content_planning(self):
        """Test content categorization for planning."""
        text = "Plan the project schedule with milestones and timeline."
        categories = self.extractor.categorize_content(text)

        assert "planning" in categories

    def test_categorize_content_general(self):
        """Test content categorization defaults to general."""
        text = "This is some random text without specific keywords."
        categories = self.extractor.categorize_content(text)

        assert "general" in categories

    def test_extract_facts_empty_text(self):
        """Test fact extraction with empty text."""
        facts = self.extractor.extract_facts("")
        assert len(facts) == 0

    def test_extract_key_terms_empty_text(self):
        """Test key term extraction with empty text."""
        terms = self.extractor.extract_key_terms("")
        assert len(terms) == 0

    def test_extract_relationships_empty_text(self):
        """Test relationship extraction with empty text."""
        relationships = self.extractor.extract_relationships("")
        assert len(relationships) == 0

    def test_extract_metadata_empty_text(self):
        """Test metadata extraction with empty text."""
        metadata = self.extractor.extract_metadata("")
        assert metadata["word_count"] == 0
        assert metadata["sentence_count"] == 0
        assert metadata["avg_words_per_sentence"] == 0.0
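The tests above treat `fact_patterns` as an opaque collection of typed patterns keyed by fact type. For orientation, a rough sketch of the regex-driven approach they imply; the pattern strings here are illustrative guesses, not the ones in pr/memory/fact_extractor.py:

import re

# Hypothetical typed patterns; the real FactExtractor.fact_patterns may differ.
FACT_PATTERNS = {
    "definition": re.compile(r"([A-Z][\w ]+?) is (?:a|an|the) ([\w ]+)"),
    "temporal": re.compile(r"([\w ]+?) (?:was born|was founded|started) in (\d{4})"),
    "numeric": re.compile(r"([\w ]+?) (?:costs|is worth) (\$[\d,]+)"),
    "location": re.compile(r"([\w ]+?) (?:lives|is located) in ([A-Z][\w ]+)"),
}


def extract_facts(text):
    """Return one {type, groups} dict per pattern match found in text."""
    facts = []
    for fact_type, pattern in FACT_PATTERNS.items():
        for match in pattern.finditer(text):
            facts.append({"type": fact_type, "groups": match.groups()})
    return facts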
344
tests/test_knowledge_store.py
Normal file
@ -0,0 +1,344 @@
import sqlite3
import tempfile
import os
import time

from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry


class TestKnowledgeStore:
    def setup_method(self):
        """Set up test database for each test."""
        self.db_fd, self.db_path = tempfile.mkstemp()
        self.store = KnowledgeStore(self.db_path)

    def teardown_method(self):
        """Clean up test database after each test."""
        self.store = None
        os.close(self.db_fd)
        os.unlink(self.db_path)

    def test_init(self):
        """Test KnowledgeStore initialization."""
        assert self.store.db_path == self.db_path
        # Verify tables were created
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cursor.fetchall()]
        assert "knowledge_entries" in tables
        conn.close()

    def test_add_entry(self):
        """Test adding a knowledge entry."""
        entry = KnowledgeEntry(
            entry_id="test_1",
            category="test",
            content="This is a test entry",
            metadata={"source": "test"},
            created_at=time.time(),
            updated_at=time.time(),
        )

        self.store.add_entry(entry)

        # Verify entry was added
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT entry_id, category, content FROM knowledge_entries WHERE entry_id = ?",
            ("test_1",),
        )
        row = cursor.fetchone()
        assert row[0] == "test_1"
        assert row[1] == "test"
        assert row[2] == "This is a test entry"
        conn.close()

    def test_get_entry(self):
        """Test retrieving a knowledge entry."""
        entry = KnowledgeEntry(
            entry_id="test_get",
            category="test",
            content="Content to retrieve",
            metadata={},
            created_at=time.time(),
            updated_at=time.time(),
        )

        self.store.add_entry(entry)
        retrieved = self.store.get_entry("test_get")

        assert retrieved is not None
        assert retrieved.entry_id == "test_get"
        assert retrieved.content == "Content to retrieve"
        assert retrieved.access_count == 1  # Should be incremented

    def test_get_entry_not_found(self):
        """Test retrieving a non-existent entry."""
        retrieved = self.store.get_entry("nonexistent")
        assert retrieved is None

    def test_search_entries_semantic(self):
        """Test semantic search."""
        entries = [
            KnowledgeEntry(
                "entry1", "personal", "John is a software engineer", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "entry2", "personal", "Mary works as a designer", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "entry3", "tech", "Python is a programming language", {}, time.time(), time.time()
            ),
        ]

        for entry in entries:
            self.store.add_entry(entry)

        results = self.store.search_entries("software engineer", top_k=2)
        assert len(results) >= 1
        # Should find the most relevant entry
        found_ids = [r.entry_id for r in results]
        assert "entry1" in found_ids

    def test_search_entries_fts_exact(self):
        """Test full-text search with exact matches."""
        entries = [
            KnowledgeEntry(
                "exact1", "test", "Python programming language", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "exact2", "test", "Java programming language", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "exact3", "test", "Database design principles", {}, time.time(), time.time()
            ),
        ]

        for entry in entries:
            self.store.add_entry(entry)

        results = self.store.search_entries("programming language", top_k=3)
        assert len(results) >= 2
        found_ids = [r.entry_id for r in results]
        assert "exact1" in found_ids
        assert "exact2" in found_ids

    def test_search_entries_by_category(self):
        """Test searching entries by category."""
        entries = [
            KnowledgeEntry("cat1", "personal", "John's info", {}, time.time(), time.time()),
            KnowledgeEntry("cat2", "tech", "Python info", {}, time.time(), time.time()),
            KnowledgeEntry("cat3", "personal", "Jane's info", {}, time.time(), time.time()),
        ]

        for entry in entries:
            self.store.add_entry(entry)

        results = self.store.search_entries("info", category="personal", top_k=5)
        assert len(results) == 2
        found_ids = [r.entry_id for r in results]
        assert "cat1" in found_ids
        assert "cat3" in found_ids
        assert "cat2" not in found_ids

    def test_get_by_category(self):
        """Test getting entries by category."""
        entries = [
            KnowledgeEntry("get1", "personal", "Entry 1", {}, time.time(), time.time()),
            KnowledgeEntry("get2", "tech", "Entry 2", {}, time.time(), time.time()),
            KnowledgeEntry("get3", "personal", "Entry 3", {}, time.time() + 1, time.time() + 1),
        ]

        for entry in entries:
            self.store.add_entry(entry)

        personal_entries = self.store.get_by_category("personal")
        assert len(personal_entries) == 2
        # Should be ordered by importance_score DESC, created_at DESC
        assert personal_entries[0].entry_id == "get3"  # More recent
        assert personal_entries[1].entry_id == "get1"

    def test_update_importance(self):
        """Test updating entry importance."""
        entry = KnowledgeEntry(
            "importance_test", "test", "Test content", {}, time.time(), time.time()
        )
        self.store.add_entry(entry)

        self.store.update_importance("importance_test", 0.8)

        retrieved = self.store.get_entry("importance_test")
        assert retrieved.importance_score == 0.8

    def test_delete_entry(self):
        """Test deleting an entry."""
        entry = KnowledgeEntry("delete_test", "test", "To be deleted", {}, time.time(), time.time())
        self.store.add_entry(entry)

        result = self.store.delete_entry("delete_test")
        assert result is True

        # Verify it's gone
        retrieved = self.store.get_entry("delete_test")
        assert retrieved is None

    def test_delete_entry_not_found(self):
        """Test deleting a non-existent entry."""
        result = self.store.delete_entry("nonexistent")
        assert result is False

    def test_get_statistics(self):
        """Test getting store statistics."""
        entries = [
            KnowledgeEntry("stat1", "personal", "Personal info", {}, time.time(), time.time()),
            KnowledgeEntry("stat2", "tech", "Tech info", {}, time.time(), time.time()),
            KnowledgeEntry("stat3", "personal", "More personal info", {}, time.time(), time.time()),
        ]

        for entry in entries:
            self.store.add_entry(entry)

        stats = self.store.get_statistics()
        assert stats["total_entries"] == 3
        assert stats["total_categories"] == 2
        assert stats["category_distribution"]["personal"] == 2
        assert stats["category_distribution"]["tech"] == 1
        assert stats["vocabulary_size"] > 0

    def test_fts_search_exact_phrase(self):
        """Test FTS exact phrase matching."""
        entries = [
            KnowledgeEntry(
                "fts1", "test", "Python is great for programming", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "fts2", "test", "Java is also good for programming", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "fts3", "test", "Database management is important", {}, time.time(), time.time()
            ),
        ]

        for entry in entries:
            self.store.add_entry(entry)

        fts_results = self.store._fts_search("programming")
        assert len(fts_results) == 2
        entry_ids = [entry_id for entry_id, score in fts_results]
        assert "fts1" in entry_ids
        assert "fts2" in entry_ids

    def test_fts_search_partial_match(self):
        """Test FTS partial word matching."""
        entries = [
            KnowledgeEntry("partial1", "test", "Python programming", {}, time.time(), time.time()),
            KnowledgeEntry("partial2", "test", "Java programming", {}, time.time(), time.time()),
            KnowledgeEntry("partial3", "test", "Database design", {}, time.time(), time.time()),
        ]

        for entry in entries:
            self.store.add_entry(entry)

        fts_results = self.store._fts_search("program")
        assert len(fts_results) >= 2

    def test_combined_search_scoring(self):
        """Test that combined semantic + FTS search produces proper scoring."""
        entries = [
            KnowledgeEntry(
                "combined1", "test", "Python programming language", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "combined2", "test", "Java programming language", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "combined3", "test", "Database management system", {}, time.time(), time.time()
            ),
        ]

        for entry in entries:
            self.store.add_entry(entry)

        results = self.store.search_entries("programming language")
        assert len(results) >= 2

        # Check that results have search scores (at least one should have a positive score)
        has_positive_score = False
        for result in results:
            assert "search_score" in result.metadata
            if result.metadata["search_score"] > 0:
                has_positive_score = True
        assert has_positive_score

    def test_empty_search(self):
        """Test searching with no matches."""
        results = self.store.search_entries("nonexistent topic")
        assert len(results) == 0

    def test_search_empty_query(self):
        """Test searching with an empty query."""
        results = self.store.search_entries("")
        assert len(results) == 0

    def test_thread_safety(self):
        """Test that the store can handle concurrent access."""
        import threading
        import queue

        results = queue.Queue()

        def worker(worker_id):
            try:
                entry = KnowledgeEntry(
                    f"thread_{worker_id}",
                    "test",
                    f"Content from worker {worker_id}",
                    {},
                    time.time(),
                    time.time(),
                )
                self.store.add_entry(entry)
                retrieved = self.store.get_entry(f"thread_{worker_id}")
                assert retrieved is not None
                results.put(True)
            except Exception as e:
                results.put(e)

        # Start multiple threads
        threads = []
        for i in range(5):
            t = threading.Thread(target=worker, args=(i,))
            threads.append(t)
            t.start()

        # Wait for all threads
        for t in threads:
            t.join()

        # Check results
        for _ in range(5):
            result = results.get()
            assert result is True

    def test_entry_to_dict(self):
        """Test KnowledgeEntry to_dict method."""
        entry = KnowledgeEntry(
            entry_id="dict_test",
            category="test",
            content="Test content",
            metadata={"key": "value"},
            created_at=1234567890.0,
            updated_at=1234567891.0,
            access_count=5,
            importance_score=0.8,
        )

        entry_dict = entry.to_dict()
        assert entry_dict["entry_id"] == "dict_test"
        assert entry_dict["category"] == "test"
        assert entry_dict["content"] == "Test content"
        assert entry_dict["metadata"]["key"] == "value"
        assert entry_dict["access_count"] == 5
        assert entry_dict["importance_score"] == 0.8
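These tests construct `KnowledgeEntry` positionally, which fixes its field order, and test_entry_to_dict pins down the keys `to_dict` must emit. A sketch of the dataclass shape that implies; the defaults for `access_count` and `importance_score` are guesses, and the real class in pr/memory/knowledge_store.py may carry extra behavior:

from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class KnowledgeEntry:
    entry_id: str
    category: str
    content: str
    metadata: Dict[str, Any]
    created_at: float
    updated_at: float
    access_count: int = 0          # default is a guess; tests only set it explicitly
    importance_score: float = 1.0  # default is a guess

    def to_dict(self) -> Dict[str, Any]:
        # Field-for-field dict, matching the keys asserted in test_entry_to_dict.
        return asdict(self)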
280
tests/test_semantic_index.py
Normal file
@ -0,0 +1,280 @@
import math

import pytest

from pr.memory.semantic_index import SemanticIndex


class TestSemanticIndex:
    def test_init(self):
        """Test SemanticIndex initialization."""
        index = SemanticIndex()
        assert index.documents == {}
        assert index.vocabulary == set()
        assert index.idf_scores == {}
        assert index.doc_tf_scores == {}

    def test_tokenize_basic(self):
        """Test basic tokenization functionality."""
        index = SemanticIndex()
        tokens = index._tokenize("Hello, world! This is a test.")
        expected = ["hello", "world", "this", "is", "a", "test"]
        assert tokens == expected

    def test_tokenize_special_characters(self):
        """Test tokenization with special characters."""
        index = SemanticIndex()
        tokens = index._tokenize("Hello@world.com test-case_123")
        expected = ["hello", "world", "com", "test", "case", "123"]
        assert tokens == expected

    def test_tokenize_empty_string(self):
        """Test tokenization of an empty string."""
        index = SemanticIndex()
        tokens = index._tokenize("")
        assert tokens == []

    def test_tokenize_only_special_chars(self):
        """Test tokenization with only special characters."""
        index = SemanticIndex()
        tokens = index._tokenize("!@#$%^&*()")
        assert tokens == []

    def test_compute_tf_basic(self):
        """Test TF computation for the basic case."""
        index = SemanticIndex()
        tokens = ["hello", "world", "hello", "test"]
        tf_scores = index._compute_tf(tokens)
        expected = {"hello": 2 / 4, "world": 1 / 4, "test": 1 / 4}  # 0.5, 0.25, 0.25
        assert tf_scores == expected

    def test_compute_tf_empty(self):
        """Test TF computation for empty tokens."""
        index = SemanticIndex()
        tf_scores = index._compute_tf([])
        assert tf_scores == {}

    def test_compute_tf_single_token(self):
        """Test TF computation for a single token."""
        index = SemanticIndex()
        tokens = ["hello"]
        tf_scores = index._compute_tf(tokens)
        assert tf_scores == {"hello": 1.0}

    def test_compute_idf_single_document(self):
        """Test IDF computation with a single document."""
        index = SemanticIndex()
        index.documents = {"doc1": "hello world"}
        index._compute_idf()
        assert index.idf_scores == {"hello": 1.0, "world": 1.0}

    def test_compute_idf_multiple_documents(self):
        """Test IDF computation with multiple documents."""
        index = SemanticIndex()
        index.documents = {"doc1": "hello world", "doc2": "hello test", "doc3": "world test"}
        index._compute_idf()
        expected = {
            "hello": math.log(3 / 2),  # appears in 2/3 docs
            "world": math.log(3 / 2),  # appears in 2/3 docs
            "test": math.log(3 / 2),  # appears in 2/3 docs
        }
        assert index.idf_scores == expected

    def test_compute_idf_empty_documents(self):
        """Test IDF computation with no documents."""
        index = SemanticIndex()
        index._compute_idf()
        assert index.idf_scores == {}

    def test_add_document_basic(self):
        """Test adding a basic document."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world")

        assert "doc1" in index.documents
        assert index.documents["doc1"] == "hello world"
        assert "hello" in index.vocabulary
        assert "world" in index.vocabulary
        assert "doc1" in index.doc_tf_scores

    def test_add_document_updates_vocabulary(self):
        """Test that adding documents updates vocabulary."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world")
        assert index.vocabulary == {"hello", "world"}

        index.add_document("doc2", "hello test")
        assert index.vocabulary == {"hello", "world", "test"}

    def test_add_document_updates_idf(self):
        """Test that adding documents updates IDF scores."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world")
        assert index.idf_scores == {"hello": 1.0, "world": 1.0}

        index.add_document("doc2", "hello test")
        expected_idf = {
            "hello": math.log(2 / 2),  # appears in both docs
            "world": math.log(2 / 1),  # appears in 1/2 docs
            "test": math.log(2 / 1),  # appears in 1/2 docs
        }
        assert index.idf_scores == expected_idf

    def test_add_document_tf_computation(self):
        """Test TF score computation when adding a document."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world hello")

        # TF: hello=2/3, world=1/3
        expected_tf = {"hello": 2 / 3, "world": 1 / 3}
        assert index.doc_tf_scores["doc1"] == expected_tf
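    # Aside: a reference sketch of the TF arithmetic the tests above pin down
    # (an illustrative assumption, not the actual SemanticIndex code). Note
    # the IDF expectations: a single-document index yields 1.0 rather than
    # math.log(1) == 0, so the implementation evidently special-cases the
    # lone-document corpus to keep scores from collapsing to zero.
    @staticmethod
    def _reference_tf(tokens):
        """Term frequency: per-token count normalized by total token count."""
        if not tokens:
            return {}
        return {tok: tokens.count(tok) / len(tokens) for tok in set(tokens)}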
    def test_remove_document_existing(self):
        """Test removing an existing document."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world")
        index.add_document("doc2", "hello test")

        initial_vocab = index.vocabulary.copy()
        initial_idf = index.idf_scores.copy()

        index.remove_document("doc1")

        assert "doc1" not in index.documents
        assert "doc1" not in index.doc_tf_scores
        # Vocabulary should still contain all words
        assert index.vocabulary == initial_vocab
        # IDF should be recomputed
        assert index.idf_scores != initial_idf
        assert index.idf_scores == {"hello": 1.0, "test": 1.0}

    def test_remove_document_nonexistent(self):
        """Test removing a non-existent document."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world")

        initial_state = {
            "documents": index.documents.copy(),
            "vocabulary": index.vocabulary.copy(),
            "idf_scores": index.idf_scores.copy(),
            "doc_tf_scores": index.doc_tf_scores.copy(),
        }

        index.remove_document("nonexistent")

        assert index.documents == initial_state["documents"]
        assert index.vocabulary == initial_state["vocabulary"]
        assert index.idf_scores == initial_state["idf_scores"]
        assert index.doc_tf_scores == initial_state["doc_tf_scores"]

    def test_search_basic(self):
        """Test basic search functionality."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world")
        index.add_document("doc2", "hello test")
        index.add_document("doc3", "world test")

        results = index.search("hello", top_k=5)
        assert len(results) == 3  # All documents are returned with similarity scores

        # Results should be sorted by similarity (descending)
        scores = {doc_id: score for doc_id, score in results}
        assert scores["doc1"] > 0  # doc1 contains "hello"
        assert scores["doc2"] > 0  # doc2 contains "hello"
        assert scores["doc3"] == 0  # doc3 does not contain "hello"

    def test_search_empty_query(self):
        """Test search with an empty query."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world")

        results = index.search("", top_k=5)
        assert results == []

    def test_search_no_documents(self):
        """Test search when no documents exist."""
        index = SemanticIndex()
        results = index.search("hello", top_k=5)
        assert results == []

    def test_search_top_k_limit(self):
        """Test that search respects the top_k parameter."""
        index = SemanticIndex()
        for i in range(10):
            index.add_document(f"doc{i}", "hello world content")

        results = index.search("content", top_k=3)
        assert len(results) == 3

    def test_cosine_similarity_identical_vectors(self):
        """Test cosine similarity with identical vectors."""
        index = SemanticIndex()
        vec1 = {"hello": 1.0, "world": 0.5}
        vec2 = {"hello": 1.0, "world": 0.5}
        similarity = index._cosine_similarity(vec1, vec2)
        assert similarity == pytest.approx(1.0)

    def test_cosine_similarity_orthogonal_vectors(self):
        """Test cosine similarity with orthogonal vectors."""
        index = SemanticIndex()
        vec1 = {"hello": 1.0}
        vec2 = {"world": 1.0}
        similarity = index._cosine_similarity(vec1, vec2)
        assert similarity == 0.0

    def test_cosine_similarity_zero_vector(self):
        """Test cosine similarity with a zero vector."""
        index = SemanticIndex()
        vec1 = {"hello": 1.0}
        vec2 = {}
        similarity = index._cosine_similarity(vec1, vec2)
        assert similarity == 0.0

    def test_cosine_similarity_empty_vectors(self):
        """Test cosine similarity with empty vectors."""
        index = SemanticIndex()
        similarity = index._cosine_similarity({}, {})
        assert similarity == 0.0

    def test_search_relevance_ordering(self):
        """Test that search results are ordered by relevance."""
        index = SemanticIndex()
        index.add_document("doc1", "hello hello hello")  # High TF for "hello"
        index.add_document("doc2", "hello world")  # Medium relevance
        index.add_document("doc3", "world test")  # No "hello"

        results = index.search("hello", top_k=5)
        assert len(results) == 3  # All documents are returned

        # doc1 should have a higher score than doc2, and doc3 should have 0
        scores = {doc_id: score for doc_id, score in results}
        assert scores["doc1"] > scores["doc2"] > scores["doc3"]
        assert scores["doc3"] == 0

    def test_vocabulary_persistence(self):
        """Test that vocabulary persists even after document removal."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world")
        index.add_document("doc2", "test case")

        assert index.vocabulary == {"hello", "world", "test", "case"}

        index.remove_document("doc1")
        # Vocabulary should still contain all words
        assert index.vocabulary == {"hello", "world", "test", "case"}

    def test_idf_recomputation_after_removal(self):
        """Test IDF recomputation after document removal."""
        index = SemanticIndex()
        index.add_document("doc1", "hello world")
        index.add_document("doc2", "hello test")
        index.add_document("doc3", "world test")

        # Remove doc3, leaving doc1 and doc2
        index.remove_document("doc3")

        # "hello" appears in both remaining docs, "world" and "test" in one each
        expected_idf = {
            "hello": math.log(2 / 2),  # 2 docs, appears in 2
            "world": math.log(2 / 1),  # 2 docs, appears in 1
            "test": math.log(2 / 1),  # 2 docs, appears in 1
        }
        assert index.idf_scores == expected_idf
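Taken together, the four cosine tests above fully determine the sparse-vector similarity. For reference, a minimal sketch that satisfies all of them; assume the real `_cosine_similarity` in pr/memory/semantic_index.py may differ in detail:

import math


def cosine_similarity(vec1, vec2):
    # Dot product over the shared keys of two sparse {term: weight} vectors.
    dot = sum(vec1[t] * vec2[t] for t in vec1.keys() & vec2.keys())
    norm1 = math.sqrt(sum(v * v for v in vec1.values()))
    norm2 = math.sqrt(sum(v * v for v in vec2.values()))
    # Empty or zero vectors yield 0.0, matching the zero/empty-vector tests.
    if norm1 == 0.0 or norm2 == 0.0:
        return 0.0
    return dot / (norm1 * norm2)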
@ -168,7 +168,6 @@ class TestToolDefinitions:
        tool_names = [t["function"]["name"] for t in tools]

        assert "run_command" in tool_names
        assert "start_interactive_session" in tool_names

    def test_python_exec_present(self):
        tools = get_tools_definition()