feat: implement distributed ads with unix sockets

This commit is contained in:
retoor 2025-11-06 16:44:41 +01:00
parent f59894c65a
commit 66a45eb5e6
19 changed files with 531 additions and 1771 deletions

View File

@@ -2,6 +2,14 @@
+## Version 1.5.0 - 2025-11-06
+Agents can now use datasets stored across multiple locations. This allows for larger datasets and improved performance.
+**Changes:** 2 files, 10 lines
+**Languages:** Markdown (8 lines), TOML (2 lines)
+
 ## Version 1.4.0 - 2025-11-06
 Agents can now share data more efficiently using a new distributed dataset system. This improves performance and allows agents to work together on larger tasks.
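A sketch of what sharing one dataset across processes looks like with the AsyncDataSet API shown later in this commit (file and socket paths hypothetical):

    # Process A (first to start) binds the socket and becomes the server:
    ds = AsyncDataSet("shared.db", socket_path="/tmp/ads.sock")
    await ds.insert("results", {"agent": "a", "score": 1})

    # Process B, started later, transparently becomes a client of A:
    ds = AsyncDataSet("shared.db", socket_path="/tmp/ads.sock")
    rows = await ds.find("results", {"agent": "a"})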

View File

@@ -1,4 +1,5 @@
-__version__ = "1.0.0"
 from pr.core import Assistant
+
+__version__ = "1.0.0"
 __all__ = ["Assistant"]

View File

@@ -1,13 +1,18 @@
 import argparse
+import asyncio
 import sys
 from pr import __version__
 from pr.core import Assistant

-def main():
+async def main_async():
+    import tracemalloc
+
+    tracemalloc.start()
     parser = argparse.ArgumentParser(
-        description="PR Assistant - Professional CLI AI assistant with autonomous execution",
+        description="PR Assistant - Professional CLI AI assistant with visual effects, cost tracking, and autonomous execution",
         epilog="""
 Examples:
   pr "What is Python?"              # Single query
@@ -18,6 +23,12 @@ Examples:
   pr --list-sessions                # List all sessions
   pr --usage                        # Show token usage stats
+
+Features:
+  Visual progress indicators during AI calls
+  Real-time cost tracking for each query
+  Sophisticated CLI with colors and effects
+  Tool execution with status updates

 Commands in interactive mode:
   /auto [task]  - Enter autonomous mode
   /reset        - Clear message history
@@ -156,7 +167,11 @@ Commands in interactive mode:
         return

     assistant = Assistant(args)
-    assistant.run()
+    await assistant.run()
+
+
+def main():
+    return asyncio.run(main_async())

 if __name__ == "__main__":

pr/ads.py (969 lines deleted)
View File

@@ -1,969 +0,0 @@
import re
import json
from uuid import uuid4
from datetime import datetime, timezone
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
AsyncGenerator,
Union,
Tuple,
Set,
)
import aiosqlite
import asyncio
import aiohttp
from aiohttp import web
import socket
import os
import atexit
class AsyncDataSet:
"""
Distributed AsyncDataSet with client-server model over Unix sockets.
Parameters:
- file: Path to the SQLite database file
- socket_path: Path to Unix socket for inter-process communication (default: "ads.sock")
- max_concurrent_queries: Maximum concurrent database operations (default: 1)
Set to 1 to prevent "database is locked" errors with SQLite
"""
_KV_TABLE = "__kv_store"
_DEFAULT_COLUMNS = {
"uid": "TEXT PRIMARY KEY",
"created_at": "TEXT",
"updated_at": "TEXT",
"deleted_at": "TEXT",
}
def __init__(self, file: str, socket_path: str = "ads.sock", max_concurrent_queries: int = 1):
self._file = file
self._socket_path = socket_path
self._table_columns_cache: Dict[str, Set[str]] = {}
self._is_server = None # None means not initialized yet
self._server = None
self._runner = None
self._client_session = None
self._initialized = False
self._init_lock = None # Will be created when needed
self._max_concurrent_queries = max_concurrent_queries
self._db_semaphore = None # Will be created when needed
self._db_connection = None # Persistent database connection
async def _ensure_initialized(self):
"""Ensure the instance is initialized as server or client."""
if self._initialized:
return
# Create lock if needed (first time in async context)
if self._init_lock is None:
self._init_lock = asyncio.Lock()
async with self._init_lock:
if self._initialized:
return
await self._initialize()
# Create semaphore for database operations (server only)
if self._is_server:
self._db_semaphore = asyncio.Semaphore(self._max_concurrent_queries)
self._initialized = True
async def _initialize(self):
"""Initialize as server or client."""
try:
# Try to create Unix socket
if os.path.exists(self._socket_path):
# Check if socket is active
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.connect(self._socket_path)
sock.close()
# Socket is active, we're a client
self._is_server = False
await self._setup_client()
except (ConnectionRefusedError, FileNotFoundError):
# Socket file exists but not active, remove and become server
os.unlink(self._socket_path)
self._is_server = True
await self._setup_server()
else:
# No socket file, become server
self._is_server = True
await self._setup_server()
except Exception:
# Fallback to client mode
self._is_server = False
await self._setup_client()
# Establish persistent database connection
self._db_connection = await aiosqlite.connect(self._file)
async def _setup_server(self):
"""Setup aiohttp server."""
app = web.Application()
# Add routes for all methods
app.router.add_post("/insert", self._handle_insert)
app.router.add_post("/update", self._handle_update)
app.router.add_post("/delete", self._handle_delete)
app.router.add_post("/upsert", self._handle_upsert)
app.router.add_post("/get", self._handle_get)
app.router.add_post("/find", self._handle_find)
app.router.add_post("/count", self._handle_count)
app.router.add_post("/exists", self._handle_exists)
app.router.add_post("/kv_set", self._handle_kv_set)
app.router.add_post("/kv_get", self._handle_kv_get)
app.router.add_post("/execute_raw", self._handle_execute_raw)
app.router.add_post("/query_raw", self._handle_query_raw)
app.router.add_post("/query_one", self._handle_query_one)
app.router.add_post("/create_table", self._handle_create_table)
app.router.add_post("/insert_unique", self._handle_insert_unique)
app.router.add_post("/aggregate", self._handle_aggregate)
self._runner = web.AppRunner(app)
await self._runner.setup()
self._server = web.UnixSite(self._runner, self._socket_path)
await self._server.start()
# Register cleanup
def cleanup():
if os.path.exists(self._socket_path):
os.unlink(self._socket_path)
atexit.register(cleanup)
async def _setup_client(self):
"""Setup aiohttp client."""
connector = aiohttp.UnixConnector(path=self._socket_path)
self._client_session = aiohttp.ClientSession(connector=connector)
async def _make_request(self, endpoint: str, data: Dict[str, Any]) -> Any:
"""Make HTTP request to server."""
if not self._client_session:
await self._setup_client()
url = f"http://localhost/{endpoint}"
async with self._client_session.post(url, json=data) as resp:
result = await resp.json()
if result.get("error"):
raise Exception(result["error"])
return result.get("result")
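# Wire protocol sketch: every client operation is POSTed to the server as a
# JSON body, e.g. {"table": "events", "args": {...}} to /insert; the server
# replies with {"result": ...} on success or {"error": "..."} (HTTP 500).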
# Server handlers
async def _handle_insert(self, request):
data = await request.json()
try:
result = await self._server_insert(
data["table"], data["args"], data.get("return_id", False)
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_update(self, request):
data = await request.json()
try:
result = await self._server_update(data["table"], data["args"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_delete(self, request):
data = await request.json()
try:
result = await self._server_delete(data["table"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_upsert(self, request):
data = await request.json()
try:
result = await self._server_upsert(data["table"], data["args"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_get(self, request):
data = await request.json()
try:
result = await self._server_get(data["table"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_find(self, request):
data = await request.json()
try:
result = await self._server_find(
data["table"],
data.get("where"),
limit=data.get("limit", 0),
offset=data.get("offset", 0),
order_by=data.get("order_by"),
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_count(self, request):
data = await request.json()
try:
result = await self._server_count(data["table"], data.get("where"))
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_exists(self, request):
data = await request.json()
try:
result = await self._server_exists(data["table"], data["where"])
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_kv_set(self, request):
data = await request.json()
try:
await self._server_kv_set(data["key"], data["value"], table=data.get("table"))
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_kv_get(self, request):
data = await request.json()
try:
result = await self._server_kv_get(
data["key"], default=data.get("default"), table=data.get("table")
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_execute_raw(self, request):
data = await request.json()
try:
result = await self._server_execute_raw(
data["sql"],
tuple(data.get("params", [])) if data.get("params") else None,
)
return web.json_response(
{"result": result.rowcount if hasattr(result, "rowcount") else None}
)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_query_raw(self, request):
data = await request.json()
try:
result = await self._server_query_raw(
data["sql"],
tuple(data.get("params", [])) if data.get("params") else None,
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_query_one(self, request):
data = await request.json()
try:
result = await self._server_query_one(
data["sql"],
tuple(data.get("params", [])) if data.get("params") else None,
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_create_table(self, request):
data = await request.json()
try:
await self._server_create_table(data["table"], data["schema"], data.get("constraints"))
return web.json_response({"result": None})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_insert_unique(self, request):
data = await request.json()
try:
result = await self._server_insert_unique(
data["table"], data["args"], data["unique_fields"]
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
async def _handle_aggregate(self, request):
data = await request.json()
try:
result = await self._server_aggregate(
data["table"],
data["function"],
data.get("column", "*"),
data.get("where"),
)
return web.json_response({"result": result})
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
# Public methods that delegate to server or client
async def insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
) -> Union[str, int]:
await self._ensure_initialized()
if self._is_server:
return await self._server_insert(table, args, return_id)
else:
return await self._make_request(
"insert", {"table": table, "args": args, "return_id": return_id}
)
async def update(
self, table: str, args: Dict[str, Any], where: Optional[Dict[str, Any]] = None
) -> int:
await self._ensure_initialized()
if self._is_server:
return await self._server_update(table, args, where)
else:
return await self._make_request(
"update", {"table": table, "args": args, "where": where}
)
async def delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
await self._ensure_initialized()
if self._is_server:
return await self._server_delete(table, where)
else:
return await self._make_request("delete", {"table": table, "where": where})
async def get(
self, table: str, where: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_get(table, where)
else:
return await self._make_request("get", {"table": table, "where": where})
async def find(
self,
table: str,
where: Optional[Dict[str, Any]] = None,
*,
limit: int = 0,
offset: int = 0,
order_by: Optional[str] = None,
) -> List[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_find(
table, where, limit=limit, offset=offset, order_by=order_by
)
else:
return await self._make_request(
"find",
{
"table": table,
"where": where,
"limit": limit,
"offset": offset,
"order_by": order_by,
},
)
async def count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
await self._ensure_initialized()
if self._is_server:
return await self._server_count(table, where)
else:
return await self._make_request("count", {"table": table, "where": where})
async def exists(self, table: str, where: Dict[str, Any]) -> bool:
await self._ensure_initialized()
if self._is_server:
return await self._server_exists(table, where)
else:
return await self._make_request("exists", {"table": table, "where": where})
async def query_raw(self, sql: str, params: Optional[Tuple] = None) -> List[Dict[str, Any]]:
await self._ensure_initialized()
if self._is_server:
return await self._server_query_raw(sql, params)
else:
return await self._make_request(
"query_raw", {"sql": sql, "params": list(params) if params else None}
)
async def create_table(
self,
table: str,
schema: Dict[str, str],
constraints: Optional[List[str]] = None,
):
await self._ensure_initialized()
if self._is_server:
return await self._server_create_table(table, schema, constraints)
else:
return await self._make_request(
"create_table",
{"table": table, "schema": schema, "constraints": constraints},
)
async def insert_unique(
self, table: str, args: Dict[str, Any], unique_fields: List[str]
) -> Union[str, None]:
await self._ensure_initialized()
if self._is_server:
return await self._server_insert_unique(table, args, unique_fields)
else:
return await self._make_request(
"insert_unique",
{"table": table, "args": args, "unique_fields": unique_fields},
)
async def aggregate(
self,
table: str,
function: str,
column: str = "*",
where: Optional[Dict[str, Any]] = None,
) -> Any:
await self._ensure_initialized()
if self._is_server:
return await self._server_aggregate(table, function, column, where)
else:
return await self._make_request(
"aggregate",
{
"table": table,
"function": function,
"column": column,
"where": where,
},
)
async def transaction(self):
"""Context manager for transactions."""
await self._ensure_initialized()
if self._is_server:
return TransactionContext(self._db_connection, self._db_semaphore)
else:
raise NotImplementedError("Transactions not supported in client mode")
# Server implementation methods (original logic)
@staticmethod
def _utc_iso() -> str:
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
@staticmethod
def _py_to_sqlite_type(value: Any) -> str:
if value is None:
return "TEXT"
if isinstance(value, bool):
return "INTEGER"
if isinstance(value, int):
return "INTEGER"
if isinstance(value, float):
return "REAL"
if isinstance(value, (bytes, bytearray, memoryview)):
return "BLOB"
return "TEXT"
async def _get_table_columns(self, table: str) -> Set[str]:
"""Get actual columns that exist in the table."""
if table in self._table_columns_cache:
return self._table_columns_cache[table]
columns = set()
try:
async with self._db_semaphore:
async with self._db_connection.execute(f"PRAGMA table_info({table})") as cursor:
async for row in cursor:
columns.add(row[1]) # Column name is at index 1
self._table_columns_cache[table] = columns
except Exception:
pass
return columns
async def _invalidate_column_cache(self, table: str):
"""Invalidate column cache for a table."""
if table in self._table_columns_cache:
del self._table_columns_cache[table]
async def _ensure_column(self, table: str, name: str, value: Any) -> None:
col_type = self._py_to_sqlite_type(value)
try:
async with self._db_semaphore:
await self._db_connection.execute(
f"ALTER TABLE {table} ADD COLUMN `{name}` {col_type}"
)
await self._db_connection.commit()
await self._invalidate_column_cache(table)
except aiosqlite.OperationalError as e:
if "duplicate column name" in str(e).lower():
pass # Column already exists
else:
raise
async def _ensure_table(self, table: str, col_sources: Dict[str, Any]) -> None:
# Always include default columns
cols = self._DEFAULT_COLUMNS.copy()
# Add columns from col_sources
for key, val in col_sources.items():
if key not in cols:
cols[key] = self._py_to_sqlite_type(val)
columns_sql = ", ".join(f"`{k}` {t}" for k, t in cols.items())
async with self._db_semaphore:
await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
await self._db_connection.commit()
await self._invalidate_column_cache(table)
async def _table_exists(self, table: str) -> bool:
"""Check if a table exists."""
async with self._db_semaphore:
async with self._db_connection.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,)
) as cursor:
return await cursor.fetchone() is not None
_RE_NO_COLUMN = re.compile(r"(?:no such column:|has no column named) (\w+)")
_RE_NO_TABLE = re.compile(r"no such table: (\w+)")
@classmethod
def _missing_column_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]:
m = cls._RE_NO_COLUMN.search(str(err))
return m.group(1) if m else None
@classmethod
def _missing_table_from_error(cls, err: aiosqlite.OperationalError) -> Optional[str]:
m = cls._RE_NO_TABLE.search(str(err))
return m.group(1) if m else None
async def _safe_execute(
self,
table: str,
sql: str,
params: Iterable[Any],
col_sources: Dict[str, Any],
max_retries: int = 10,
) -> aiosqlite.Cursor:
retries = 0
while retries < max_retries:
try:
async with self._db_semaphore:
cursor = await self._db_connection.execute(sql, params)
await self._db_connection.commit()
return cursor
except aiosqlite.OperationalError as err:
retries += 1
err_str = str(err).lower()
# Handle missing column
col = self._missing_column_from_error(err)
if col:
if col in col_sources:
await self._ensure_column(table, col, col_sources[col])
else:
# Column not in sources, ensure it with NULL/TEXT type
await self._ensure_column(table, col, None)
continue
# Handle missing table
tbl = self._missing_table_from_error(err)
if tbl:
await self._ensure_table(tbl, col_sources)
continue
# Handle other column-related errors
if "has no column named" in err_str:
# Extract column name differently
match = re.search(r"table \w+ has no column named (\w+)", err_str)
if match:
col_name = match.group(1)
if col_name in col_sources:
await self._ensure_column(table, col_name, col_sources[col_name])
else:
await self._ensure_column(table, col_name, None)
continue
raise
raise Exception(f"Max retries ({max_retries}) exceeded")
async def _safe_query(
self,
table: str,
sql: str,
params: Iterable[Any],
col_sources: Dict[str, Any],
) -> AsyncGenerator[Dict[str, Any], None]:
# Check if table exists first
if not await self._table_exists(table):
return
max_retries = 10
retries = 0
while retries < max_retries:
try:
async with self._db_semaphore:
self._db_connection.row_factory = aiosqlite.Row
async with self._db_connection.execute(sql, params) as cursor:
# Fetch all rows while holding the semaphore
rows = await cursor.fetchall()
# Yield rows after releasing the semaphore
for row in rows:
yield dict(row)
return
except aiosqlite.OperationalError as err:
retries += 1
err_str = str(err).lower()
# Handle missing table
tbl = self._missing_table_from_error(err)
if tbl:
# For queries, if table doesn't exist, just return empty
return
# Handle missing column in WHERE clause or SELECT
if "no such column" in err_str:
# For queries with missing columns, return empty
return
raise
@staticmethod
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
if not where:
return "", []
clauses = []
vals = []
op_map = {"$eq": "=", "$ne": "!=", "$gt": ">", "$gte": ">=", "$lt": "<", "$lte": "<="}
for k, v in where.items():
if isinstance(v, dict):
# Handle operators like {"$lt": 10}
op_key = next(iter(v))
if op_key in op_map:
operator = op_map[op_key]
value = v[op_key]
clauses.append(f"`{k}` {operator} ?")
vals.append(value)
else:
# Fallback for unexpected dict values
clauses.append(f"`{k}` = ?")
vals.append(json.dumps(v))
else:
# Default to equality
clauses.append(f"`{k}` = ?")
vals.append(v)
return " WHERE " + " AND ".join(clauses), vals
async def _server_insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
) -> Union[str, int]:
"""Insert a record. If return_id=True, returns auto-incremented ID instead of UUID."""
uid = str(uuid4())
now = self._utc_iso()
record = {
"uid": uid,
"created_at": now,
"updated_at": now,
"deleted_at": None,
**args,
}
# Ensure table exists with all needed columns
await self._ensure_table(table, record)
# Handle auto-increment ID if requested
if return_id and "id" not in args:
# Ensure id column exists
async with self._db_semaphore:
# Add id column if it doesn't exist
try:
await self._db_connection.execute(
f"ALTER TABLE {table} ADD COLUMN id INTEGER PRIMARY KEY AUTOINCREMENT"
)
await self._db_connection.commit()
except aiosqlite.OperationalError as e:
if "duplicate column name" not in str(e).lower():
# Try without autoincrement constraint
try:
await self._db_connection.execute(
f"ALTER TABLE {table} ADD COLUMN id INTEGER"
)
await self._db_connection.commit()
except Exception:
pass
await self._invalidate_column_cache(table)
# Insert and get lastrowid
cols = "`" + "`, `".join(record.keys()) + "`"
qs = ", ".join(["?"] * len(record))
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
cursor = await self._safe_execute(table, sql, list(record.values()), record)
return cursor.lastrowid
cols = "`" + "`, `".join(record) + "`"
qs = ", ".join(["?"] * len(record))
sql = f"INSERT INTO {table} ({cols}) VALUES ({qs})"
await self._safe_execute(table, sql, list(record.values()), record)
return uid
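# Example (illustrative): by default insert() returns a generated UUID
# string; with return_id=True it returns the integer lastrowid of an
# auto-added `id` column instead:
#
#     uid = await ds.insert("jobs", {"state": "new"})                    # "6f1c..."
#     row_id = await ds.insert("jobs", {"state": "new"}, return_id=True)  # 2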
async def _server_update(
self,
table: str,
args: Dict[str, Any],
where: Optional[Dict[str, Any]] = None,
) -> int:
if not args:
return 0
# Check if table exists
if not await self._table_exists(table):
return 0
args["updated_at"] = self._utc_iso()
# Ensure all columns exist
all_cols = {**args, **(where or {})}
await self._ensure_table(table, all_cols)
for col, val in all_cols.items():
await self._ensure_column(table, col, val)
set_clause = ", ".join(f"`{k}` = ?" for k in args)
where_clause, where_params = self._build_where(where)
sql = f"UPDATE {table} SET {set_clause}{where_clause}"
params = list(args.values()) + where_params
cur = await self._safe_execute(table, sql, params, all_cols)
return cur.rowcount
async def _server_delete(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
# Check if table exists
if not await self._table_exists(table):
return 0
where_clause, where_params = self._build_where(where)
sql = f"DELETE FROM {table}{where_clause}"
cur = await self._safe_execute(table, sql, where_params, where or {})
return cur.rowcount
async def _server_upsert(
self,
table: str,
args: Dict[str, Any],
where: Optional[Dict[str, Any]] = None,
) -> str | None:
if not args:
raise ValueError("Nothing to update. Empty dict given.")
args["updated_at"] = self._utc_iso()
affected = await self._server_update(table, args, where)
if affected:
rec = await self._server_get(table, where)
return rec.get("uid") if rec else None
merged = {**(where or {}), **args}
return await self._server_insert(table, merged)
async def _server_get(
self, table: str, where: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
where_clause, where_params = self._build_where(where)
sql = f"SELECT * FROM {table}{where_clause} LIMIT 1"
async for row in self._safe_query(table, sql, where_params, where or {}):
return row
return None
async def _server_find(
self,
table: str,
where: Optional[Dict[str, Any]] = None,
*,
limit: int = 0,
offset: int = 0,
order_by: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Find records with optional ordering."""
where_clause, where_params = self._build_where(where)
order_clause = f" ORDER BY {order_by}" if order_by else ""
extra = (f" LIMIT {limit}" if limit else "") + (f" OFFSET {offset}" if offset else "")
sql = f"SELECT * FROM {table}{where_clause}{order_clause}{extra}"
return [row async for row in self._safe_query(table, sql, where_params, where or {})]
async def _server_count(self, table: str, where: Optional[Dict[str, Any]] = None) -> int:
# Check if table exists
if not await self._table_exists(table):
return 0
where_clause, where_params = self._build_where(where)
sql = f"SELECT COUNT(*) FROM {table}{where_clause}"
gen = self._safe_query(table, sql, where_params, where or {})
async for row in gen:
return next(iter(row.values()), 0)
return 0
async def _server_exists(self, table: str, where: Dict[str, Any]) -> bool:
return (await self._server_count(table, where)) > 0
async def _server_kv_set(
self,
key: str,
value: Any,
*,
table: str | None = None,
) -> None:
tbl = table or self._KV_TABLE
json_val = json.dumps(value, default=str)
await self._server_upsert(tbl, {"value": json_val}, {"key": key})
async def _server_kv_get(
self,
key: str,
*,
default: Any = None,
table: str | None = None,
) -> Any:
tbl = table or self._KV_TABLE
row = await self._server_get(tbl, {"key": key})
if not row:
return default
try:
return json.loads(row["value"])
except Exception:
return default
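# Example round trip (server side): values are JSON-encoded, so any
# JSON-serializable object survives storage in the __kv_store table:
#
#     await ds._server_kv_set("config", {"retries": 3})
#     cfg = await ds._server_kv_get("config", default={})  # -> {"retries": 3}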
async def _server_execute_raw(self, sql: str, params: Optional[Tuple] = None) -> Any:
"""Execute raw SQL for complex queries like JOINs."""
async with self._db_semaphore:
cursor = await self._db_connection.execute(sql, params or ())
await self._db_connection.commit()
return cursor
async def _server_query_raw(
self, sql: str, params: Optional[Tuple] = None
) -> List[Dict[str, Any]]:
"""Execute raw SQL query and return results as list of dicts."""
try:
async with self._db_semaphore:
self._db_connection.row_factory = aiosqlite.Row
async with self._db_connection.execute(sql, params or ()) as cursor:
return [dict(row) async for row in cursor]
except aiosqlite.OperationalError:
# Return empty list if query fails
return []
async def _server_query_one(
self, sql: str, params: Optional[Tuple] = None
) -> Optional[Dict[str, Any]]:
"""Execute raw SQL query and return single result."""
results = await self._server_query_raw(sql + " LIMIT 1", params)
return results[0] if results else None
async def _server_create_table(
self,
table: str,
schema: Dict[str, str],
constraints: Optional[List[str]] = None,
):
"""Create table with custom schema and constraints. Always includes default columns."""
# Merge default columns with custom schema
full_schema = self._DEFAULT_COLUMNS.copy()
full_schema.update(schema)
columns = [f"`{col}` {dtype}" for col, dtype in full_schema.items()]
if constraints:
columns.extend(constraints)
columns_sql = ", ".join(columns)
async with self._db_semaphore:
await self._db_connection.execute(f"CREATE TABLE IF NOT EXISTS {table} ({columns_sql})")
await self._db_connection.commit()
await self._invalidate_column_cache(table)
async def _server_insert_unique(
self, table: str, args: Dict[str, Any], unique_fields: List[str]
) -> Union[str, None]:
"""Insert with unique constraint handling. Returns uid on success, None if duplicate."""
try:
return await self._server_insert(table, args)
except aiosqlite.IntegrityError as e:
if "UNIQUE" in str(e):
return None
raise
async def _server_aggregate(
self,
table: str,
function: str,
column: str = "*",
where: Optional[Dict[str, Any]] = None,
) -> Any:
"""Perform aggregate functions like SUM, AVG, MAX, MIN."""
# Check if table exists
if not await self._table_exists(table):
return None
where_clause, where_params = self._build_where(where)
sql = f"SELECT {function}({column}) as result FROM {table}{where_clause}"
result = await self._server_query_one(sql, tuple(where_params))
return result["result"] if result else None
async def close(self):
"""Close the connection or server."""
if not self._initialized:
return
if self._client_session:
await self._client_session.close()
if self._runner:
await self._runner.cleanup()
if os.path.exists(self._socket_path) and self._is_server:
os.unlink(self._socket_path)
if self._db_connection:
await self._db_connection.close()
class TransactionContext:
"""Context manager for database transactions."""
def __init__(self, db_connection: aiosqlite.Connection, semaphore: asyncio.Semaphore = None):
self.db_connection = db_connection
self.semaphore = semaphore
async def __aenter__(self):
if self.semaphore:
await self.semaphore.acquire()
try:
await self.db_connection.execute("BEGIN")
return self.db_connection
except Exception:
if self.semaphore:
self.semaphore.release()
raise
async def __aexit__(self, exc_type, exc_val, exc_tb):
try:
if exc_type is None:
await self.db_connection.commit()
else:
await self.db_connection.rollback()
finally:
if self.semaphore:
self.semaphore.release()
# Test cases remain the same but with additional tests for new functionality
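# Illustrative use of the transaction support above (server mode only;
# table and values hypothetical). __aenter__ issues BEGIN, and __aexit__
# commits on a clean exit or rolls back if the block raises:
#
#     ds = AsyncDataSet("agents.db")
#     tx = await ds.transaction()
#     async with tx as conn:
#         await conn.execute("INSERT INTO jobs (uid) VALUES (?)", ("j1",))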

View File

@@ -1,9 +1,13 @@
+import asyncio
 import json
 import logging
 import time
 from pr.autonomous.detection import is_task_complete
+from pr.core.api import call_api
 from pr.core.context import truncate_tool_result
+from pr.tools.base import get_tools_definition
 from pr.ui import Colors, display_tool_call

 logger = logging.getLogger("pr")
@@ -35,18 +39,39 @@ def run_autonomous_mode(assistant, task):
         logger.debug(f"Messages after context management: {len(assistant.messages)}")

-        from pr.core.api import call_api
-        from pr.tools.base import get_tools_definition
-
-        response = call_api(
-            assistant.messages,
-            assistant.model,
-            assistant.api_url,
-            assistant.api_key,
-            assistant.use_tools,
-            get_tools_definition(),
-            verbose=assistant.verbose,
-        )
+        try:
+            # Try to get the current event loop
+            loop = asyncio.get_running_loop()
+            # If we're in an event loop, use run_coroutine_threadsafe
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = asyncio.run_coroutine_threadsafe(
+                    call_api(
+                        assistant.messages,
+                        assistant.model,
+                        assistant.api_url,
+                        assistant.api_key,
+                        assistant.use_tools,
+                        get_tools_definition(),
+                        verbose=assistant.verbose,
+                    ),
+                    loop,
+                )
+            response = future.result()
+        except RuntimeError:
+            # No event loop running, use asyncio.run
+            response = asyncio.run(
+                call_api(
+                    assistant.messages,
+                    assistant.model,
+                    assistant.api_url,
+                    assistant.api_key,
+                    assistant.use_tools,
+                    get_tools_definition(),
+                    verbose=assistant.verbose,
+                )
+            )

         if "error" in response:
             logger.error(f"API error in autonomous mode: {response['error']}")
@@ -118,18 +143,40 @@ def process_response_autonomous(assistant, response):
             for result in tool_results:
                 assistant.messages.append(result)

-            from pr.core.api import call_api
-            from pr.tools.base import get_tools_definition
-
-            follow_up = call_api(
-                assistant.messages,
-                assistant.model,
-                assistant.api_url,
-                assistant.api_key,
-                assistant.use_tools,
-                get_tools_definition(),
-                verbose=assistant.verbose,
-            )
+            try:
+                # Try to get the current event loop
+                loop = asyncio.get_running_loop()
+                # If we're in an event loop, use run_coroutine_threadsafe
+                import concurrent.futures
+
+                with concurrent.futures.ThreadPoolExecutor() as executor:
+                    future = asyncio.run_coroutine_threadsafe(
+                        call_api(
+                            assistant.messages,
+                            assistant.model,
+                            assistant.api_url,
+                            assistant.api_key,
+                            assistant.use_tools,
+                            get_tools_definition(),
+                            verbose=assistant.verbose,
+                        ),
+                        loop,
+                    )
+                follow_up = future.result()
+            except RuntimeError:
+                # No event loop running, use asyncio.run
+                follow_up = asyncio.run(
+                    call_api(
+                        assistant.messages,
+                        assistant.model,
+                        assistant.api_url,
+                        assistant.api_key,
+                        assistant.use_tools,
+                        get_tools_definition(),
+                        verbose=assistant.verbose,
+                    )
+                )

             return process_response_autonomous(assistant, follow_up)

         content = message.get("content", "")
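The try/except blocks above (and the similar ones later in this commit) all implement the same bridge from synchronous code into the now-async call_api. A minimal distillation of that pattern, under the same assumptions the diff makes (helper name hypothetical):

    import asyncio

    def run_coro_blocking(coro):
        # Caveat: if this is called on the very thread that runs the loop,
        # future.result() would block that loop, so the pattern only works
        # when the loop lives in another thread.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)  # no loop running: drive it ourselves
        future = asyncio.run_coroutine_threadsafe(coro, loop)
        return future.result()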

View File

@@ -1,4 +1,6 @@
+import asyncio
 import json
+import logging
 import time
 from pr.commands.multiplexer_commands import MULTIPLEXER_COMMANDS
@@ -40,7 +42,18 @@ def handle_command(assistant, command):
         if prompt_text.strip():
             from pr.core.assistant import process_message

-            process_message(assistant, prompt_text)
+            task = asyncio.create_task(process_message(assistant, prompt_text))
+            assistant.background_tasks.add(task)
+            task.add_done_callback(
+                lambda t: (
+                    assistant.background_tasks.discard(t),
+                    (
+                        logging.error(f"Background task failed: {t.exception()}")
+                        if t.exception()
+                        else assistant.background_tasks.discard(t)
+                    ),
+                )
+            )

     elif cmd == "/auto":
         if len(command_parts) < 2:
@@ -201,7 +214,18 @@ def review_file(assistant, filename):
         message = f"Please review this file and provide feedback:\n\n{result['content']}"
         from pr.core.assistant import process_message

-        process_message(assistant, message)
+        task = asyncio.create_task(process_message(assistant, message))
+        assistant.background_tasks.add(task)
+        task.add_done_callback(
+            lambda t: (
+                assistant.background_tasks.discard(t),
+                (
+                    logging.error(f"Background task failed: {t.exception()}")
+                    if t.exception()
+                    else assistant.background_tasks.discard(t)
+                ),
+            )
+        )
     else:
         print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
@@ -212,7 +236,18 @@ def refactor_file(assistant, filename):
         message = f"Please refactor this code to improve its quality:\n\n{result['content']}"
         from pr.core.assistant import process_message

-        process_message(assistant, message)
+        task = asyncio.create_task(process_message(assistant, message))
+        assistant.background_tasks.add(task)
+        task.add_done_callback(
+            lambda t: (
+                assistant.background_tasks.discard(t),
+                (
+                    logging.error(f"Background task failed: {t.exception()}")
+                    if t.exception()
+                    else assistant.background_tasks.discard(t)
+                ),
+            )
+        )
     else:
         print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
@@ -223,7 +258,18 @@ def obfuscate_file(assistant, filename):
         message = f"Please obfuscate this code:\n\n{result['content']}"
         from pr.core.assistant import process_message

-        process_message(assistant, message)
+        task = asyncio.create_task(process_message(assistant, message))
+        assistant.background_tasks.add(task)
+        task.add_done_callback(
+            lambda t: (
+                assistant.background_tasks.discard(t),
+                (
+                    logging.error(f"Background task failed: {t.exception()}")
+                    if t.exception()
+                    else assistant.background_tasks.discard(t)
+                ),
+            )
+        )
     else:
         print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")

View File

@@ -1,15 +1,14 @@
 import json
 import logging
-import urllib.error
-import urllib.request

 from pr.config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE
 from pr.core.context import auto_slim_messages
+from pr.core.http_client import http_client

 logger = logging.getLogger("pr")

-def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
+async def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
     try:
         messages = auto_slim_messages(messages, verbose=verbose)
@@ -42,63 +41,70 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
             data["tool_choice"] = "auto"
             logger.debug(f"Tool calling enabled with {len(tools_definition)} tools")

-        request_json = json.dumps(data)
+        request_json = data
         logger.debug(f"Request payload size: {len(request_json)} bytes")

-        req = urllib.request.Request(
-            api_url, data=request_json.encode("utf-8"), headers=headers, method="POST"
-        )
-
         logger.debug("Sending HTTP request...")
-        with urllib.request.urlopen(req) as response:
-            response_data = response.read().decode("utf-8")
-            logger.debug(f"Response received: {len(response_data)} bytes")
-            result = json.loads(response_data)
+        response = await http_client.post(api_url, headers=headers, json_data=request_json)

-        if "usage" in result:
-            logger.debug(f"Token usage: {result['usage']}")
-        if "choices" in result and result["choices"]:
-            choice = result["choices"][0]
-            if "message" in choice:
-                msg = choice["message"]
-                logger.debug(f"Response role: {msg.get('role', 'N/A')}")
-                if "content" in msg and msg["content"]:
-                    logger.debug(f"Response content length: {len(msg['content'])} chars")
-                if "tool_calls" in msg:
-                    logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")
+        if response.get("error"):
+            if "status" in response:
+                logger.error(f"API HTTP Error: {response['status']} - {response.get('text', '')}")
+                logger.debug("=== API CALL FAILED ===")
+                return {
+                    "error": f"API Error: {response['status']}",
+                    "message": response.get("text", ""),
+                }
+            else:
+                logger.error(f"API call failed: {response.get('exception', 'Unknown error')}")
+                logger.debug("=== API CALL FAILED ===")
+                return {"error": response.get("exception", "Unknown error")}

-        if verbose and "usage" in result:
-            from pr.core.usage_tracker import UsageTracker
-
-            usage = result["usage"]
-            input_t = usage.get("prompt_tokens", 0)
-            output_t = usage.get("completion_tokens", 0)
-            cost = UsageTracker._calculate_cost(model, input_t, output_t)
-            print(f"API call cost: €{cost:.4f}")
+        response_data = response["text"]
+        logger.debug(f"Response received: {len(response_data)} bytes")
+        result = json.loads(response_data)
+
+        if "usage" in result:
+            logger.debug(f"Token usage: {result['usage']}")
+        if "choices" in result and result["choices"]:
+            choice = result["choices"][0]
+            if "message" in choice:
+                msg = choice["message"]
+                logger.debug(f"Response role: {msg.get('role', 'N/A')}")
+                if "content" in msg and msg["content"]:
+                    logger.debug(f"Response content length: {len(msg['content'])} chars")
+                if "tool_calls" in msg:
+                    logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")
+
+        if verbose and "usage" in result:
+            from pr.core.usage_tracker import UsageTracker
+
+            usage = result["usage"]
+            input_t = usage.get("prompt_tokens", 0)
+            output_t = usage.get("completion_tokens", 0)
+            UsageTracker._calculate_cost(model, input_t, output_t)

         logger.debug("=== API CALL END ===")
         return result
-    except urllib.error.HTTPError as e:
-        error_body = e.read().decode("utf-8")
-        logger.error(f"API HTTP Error: {e.code} - {error_body}")
-        logger.debug("=== API CALL FAILED ===")
-        return {"error": f"API Error: {e.code}", "message": error_body}
     except Exception as e:
         logger.error(f"API call failed: {e}")
         logger.debug("=== API CALL FAILED ===")
         return {"error": str(e)}

-def list_models(model_list_url, api_key):
+async def list_models(model_list_url, api_key):
     try:
-        req = urllib.request.Request(model_list_url)
-        if api_key:
-            req.add_header("Authorization", f"Bearer {api_key}")
-        with urllib.request.urlopen(req) as response:
-            data = json.loads(response.read().decode("utf-8"))
+        headers = {}
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+        response = await http_client.get(model_list_url, headers=headers)
+
+        if response.get("error"):
+            return {"error": response.get("text", "HTTP error")}
+
+        data = json.loads(response["text"])
         return data.get("data", [])
     except Exception as e:
         return {"error": str(e)}

View File

@@ -20,7 +20,7 @@ from pr.config import (
 )
 from pr.core.api import call_api
 from pr.core.autonomous_interactions import (
-    get_global_autonomous,
+    start_global_autonomous,
     stop_global_autonomous,
 )
 from pr.core.background_monitor import (
@@ -29,6 +29,7 @@ from pr.core.background_monitor import (
     stop_global_monitor,
 )
 from pr.core.context import init_system_message, truncate_tool_result
+from pr.core.usage_tracker import UsageTracker
 from pr.tools import get_tools_definition
 from pr.tools.agents import (
     collaborate_agents,
@@ -71,7 +72,7 @@ from pr.tools.memory import (
 from pr.tools.patch import apply_patch, create_diff, display_file_diff
 from pr.tools.python_exec import python_exec
 from pr.tools.web import http_fetch, web_search, web_search_news
-from pr.ui import Colors, render_markdown
+from pr.ui import Colors, Spinner, render_markdown

 logger = logging.getLogger("pr")
 logger.setLevel(logging.DEBUG)
@@ -109,6 +110,8 @@ class Assistant:
         self.autonomous_mode = False
         self.autonomous_iterations = 0
         self.background_monitoring = False
+        self.usage_tracker = UsageTracker()
+        self.background_tasks = set()

         self.init_database()
         self.messages.append(init_system_message(args))
@@ -125,8 +128,7 @@ class Assistant:
         # Initialize background monitoring components
         try:
             start_global_monitor()
-            autonomous = get_global_autonomous()
-            autonomous.start(llm_callback=self._handle_background_updates)
+            start_global_autonomous(llm_callback=self._handle_background_updates)
             self.background_monitoring = True
             if self.debug:
                 logger.debug("Background monitoring initialized")
@@ -236,7 +238,7 @@ class Assistant:
             if self.debug:
                 print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}")

-    def execute_tool_calls(self, tool_calls):
+    async def execute_tool_calls(self, tool_calls):
         results = []
         logger.debug(f"Executing {len(tool_calls)} tool call(s)")
@@ -337,7 +339,7 @@ class Assistant:
         return results

-    def process_response(self, response):
+    async def process_response(self, response):
         if "error" in response:
             return f"Error: {response['error']}"
@@ -348,15 +350,17 @@ class Assistant:
         self.messages.append(message)

         if "tool_calls" in message and message["tool_calls"]:
-            if self.verbose:
-                print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}")
+            tool_count = len(message["tool_calls"])
+            print(f"{Colors.BLUE}🔧 Executing {tool_count} tool call(s)...{Colors.RESET}")

-            tool_results = self.execute_tool_calls(message["tool_calls"])
+            tool_results = await self.execute_tool_calls(message["tool_calls"])
+            print(f"{Colors.GREEN}✅ Tool execution completed.{Colors.RESET}")

             for result in tool_results:
                 self.messages.append(result)

-            follow_up = call_api(
+            follow_up = await call_api(
                 self.messages,
                 self.model,
                 self.api_url,
@@ -365,7 +369,7 @@ class Assistant:
                 get_tools_definition(),
                 verbose=self.verbose,
             )
-            return self.process_response(follow_up)
+            return await self.process_response(follow_up)

         content = message.get("content", "")
         return render_markdown(content, self.syntax_highlighting)
@@ -439,12 +443,23 @@ class Assistant:
             readline.set_completer(completer)
             readline.parse_and_bind("tab: complete")

-    def run_repl(self):
+    async def run_repl(self):
         self.setup_readline()
         signal.signal(signal.SIGINT, self.signal_handler)
-        print(f"{Colors.BOLD}r{Colors.RESET}")
-        print(f"Type 'help' for commands or start chatting")
+        print(
+            f"{Colors.BOLD}{Colors.CYAN}╔══════════════════════════════════════════════╗{Colors.RESET}"
+        )
+        print(
+            f"{Colors.BOLD}{Colors.CYAN}║{Colors.RESET}{Colors.BOLD}  PR Assistant v{__import__('pr').__version__}  {Colors.RESET}{Colors.BOLD}{Colors.CYAN}║{Colors.RESET}"
+        )
+        print(
+            f"{Colors.BOLD}{Colors.CYAN}╚══════════════════════════════════════════════╝{Colors.RESET}"
+        )
+        print(
+            f"{Colors.GRAY}Type 'help' for commands, 'exit' to quit, or start chatting.{Colors.RESET}"
+        )
+        print(f"{Colors.GRAY}AI calls will show costs and progress indicators.{Colors.RESET}\n")

         while True:
             try:
@@ -480,7 +495,7 @@ class Assistant:
                 elif cmd_result is True:
                     continue

-                process_message(self, user_input)
+                await process_message(self, user_input)

             except EOFError:
                 break
@@ -490,7 +505,7 @@ class Assistant:
                 print(f"{Colors.RED}Error: {e}{Colors.RESET}")
                 logging.error(f"REPL error: {e}\n{traceback.format_exc()}")

-    def run_single(self):
+    async def run_single(self):
         if self.args.message:
             message = self.args.message
         else:
@@ -498,7 +513,7 @@ class Assistant:
         from pr.autonomous.mode import run_autonomous_mode

-        run_autonomous_mode(self, message)
+        await run_autonomous_mode(self, message)

     def cleanup(self):
         if hasattr(self, "enhanced") and self.enhanced:
@@ -525,22 +540,22 @@ class Assistant:
         if self.db_conn:
             self.db_conn.close()

-    def run(self):
+    async def run(self):
         try:
             print(
                 f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}"
             )
             if self.args.interactive or (not self.args.message and sys.stdin.isatty()):
                 print("DEBUG: calling run_repl")
-                self.run_repl()
+                await self.run_repl()
             else:
                 print("DEBUG: calling run_single")
-                self.run_single()
+                await self.run_single()
         finally:
             self.cleanup()

-def process_message(assistant, message):
+async def process_message(assistant, message):
     from pr.core.knowledge_context import inject_knowledge_context

     inject_knowledge_context(assistant, message)
@@ -550,10 +565,11 @@ def process_message(assistant, message):
     logger.debug(f"Processing user message: {message[:100]}...")
     logger.debug(f"Current message count: {len(assistant.messages)}")

-    if assistant.verbose:
-        print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}")
+    # Start spinner for AI call
+    spinner = Spinner("Querying AI...")
+    await spinner.start()

-    response = call_api(
+    response = await call_api(
         assistant.messages,
         assistant.model,
         assistant.api_url,
@@ -562,6 +578,19 @@ def process_message(assistant, message):
         get_tools_definition(),
         verbose=assistant.verbose,
     )
+    await spinner.stop()
+
+    # Track usage and display cost
+    if "usage" in response:
+        usage = response["usage"]
+        input_tokens = usage.get("prompt_tokens", 0)
+        output_tokens = usage.get("completion_tokens", 0)
+        assistant.usage_tracker.track_request(assistant.model, input_tokens, output_tokens)
+        cost = UsageTracker._calculate_cost(assistant.model, input_tokens, output_tokens)
+        total_cost = assistant.usage_tracker.session_usage["estimated_cost"]
+        print(f"{Colors.YELLOW}💰 Cost: ${cost:.4f} | Total: ${total_cost:.4f}{Colors.RESET}")

-    result = assistant.process_response(response)
+    result = await assistant.process_response(response)

     print(f"\n{Colors.GREEN}r:{Colors.RESET} {result}\n")

View File

@@ -1,3 +1,4 @@
+import asyncio
 import json
 import logging
 import uuid
@@ -128,15 +129,39 @@ class EnhancedAssistant:
     def _api_caller_for_agent(
         self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int
     ) -> Dict[str, Any]:
-        return call_api(
-            messages,
-            self.base.model,
-            self.base.api_url,
-            self.base.api_key,
-            use_tools=False,
-            tools_definition=[],
-            verbose=self.base.verbose,
-        )
+        try:
+            # Try to get the current event loop
+            loop = asyncio.get_running_loop()
+            # If we're in an event loop, use run_coroutine_threadsafe
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = asyncio.run_coroutine_threadsafe(
+                    call_api(
+                        messages,
+                        self.base.model,
+                        self.base.api_url,
+                        self.base.api_key,
+                        use_tools=False,
+                        tools_definition=[],
+                        verbose=self.base.verbose,
+                    ),
+                    loop,
+                )
+            return future.result()
+        except RuntimeError:
+            # No event loop running, use asyncio.run
+            return asyncio.run(
+                call_api(
+                    messages,
+                    self.base.model,
+                    self.base.api_url,
+                    self.base.api_key,
+                    use_tools=False,
+                    tools_definition=[],
+                    verbose=self.base.verbose,
+                )
+            )

     def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
         if self.api_cache and CACHE_ENABLED:
@@ -145,15 +170,39 @@ class EnhancedAssistant:
             logger.debug("API cache hit")
             return cached_response

-        response = call_api(
-            messages,
-            self.base.model,
-            self.base.api_url,
-            self.base.api_key,
-            self.base.use_tools,
-            get_tools_definition(),
-            verbose=self.base.verbose,
-        )
+        try:
+            # Try to get the current event loop
+            loop = asyncio.get_running_loop()
+            # If we're in an event loop, use run_coroutine_threadsafe
+            import concurrent.futures
+
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = asyncio.run_coroutine_threadsafe(
+                    call_api(
+                        messages,
+                        self.base.model,
+                        self.base.api_url,
+                        self.base.api_key,
+                        self.base.use_tools,
+                        get_tools_definition(),
+                        verbose=self.base.verbose,
+                    ),
+                    loop,
+                )
+            response = future.result()
+        except RuntimeError:
+            # No event loop running, use asyncio.run
+            response = asyncio.run(
+                call_api(
+                    messages,
+                    self.base.model,
+                    self.base.api_url,
+                    self.base.api_key,
+                    self.base.use_tools,
+                    get_tools_definition(),
+                    verbose=self.base.verbose,
+                )
+            )

         if self.api_cache and CACHE_ENABLED and "error" not in response:
             token_count = response.get("usage", {}).get("total_tokens", 0)

pr/core/http_client.py (new file, 136 lines)
View File

@@ -0,0 +1,136 @@
import asyncio
import json
import logging
import socket
import time
import urllib.error
import urllib.request
from typing import Dict, Any, Optional
logger = logging.getLogger("pr")
class AsyncHTTPClient:
def __init__(self):
self.session_headers = {}
async def request(
self,
method: str,
url: str,
headers: Optional[Dict[str, str]] = None,
data: Optional[bytes] = None,
json_data: Optional[Dict[str, Any]] = None,
timeout: float = 30.0,
) -> Dict[str, Any]:
"""Make an async HTTP request using urllib in a thread executor with retry logic."""
loop = asyncio.get_event_loop()
# Prepare headers
request_headers = {**self.session_headers}
if headers:
request_headers.update(headers)
# Prepare data
request_data = data
if json_data is not None:
request_data = json.dumps(json_data).encode("utf-8")
request_headers["Content-Type"] = "application/json"
# Create request object
req = urllib.request.Request(url, data=request_data, headers=request_headers, method=method)
attempt = 0
start_time = time.time()
while True:
attempt += 1
try:
# Execute in thread pool
response = await loop.run_in_executor(
None, lambda: urllib.request.urlopen(req, timeout=timeout)
)
response_data = await loop.run_in_executor(None, response.read)
response_text = response_data.decode("utf-8")
return {
"status": response.status,
"headers": dict(response.headers),
"text": response_text,
"json": lambda: json.loads(response_text) if response_text else None,
}
except urllib.error.HTTPError as e:
error_body = await loop.run_in_executor(None, e.read)
error_text = error_body.decode("utf-8")
return {
"status": e.code,
"error": True,
"text": error_text,
"json": lambda: json.loads(error_text) if error_text else None,
}
except socket.timeout:
# Handle socket timeouts specifically
elapsed = time.time() - start_time
elapsed_minutes = int(elapsed // 60)
elapsed_seconds = elapsed % 60
duration_str = (
f"{elapsed_minutes}m {elapsed_seconds:.1f}s"
if elapsed_minutes > 0
else f"{elapsed_seconds:.1f}s"
)
logger.warning(
f"Request timed out (attempt {attempt}, "
f"duration: {duration_str}). Retrying in {attempt} second(s)..."
)
# Exponential backoff starting at 1 second
await asyncio.sleep(attempt)
except Exception as e:
error_msg = str(e)
# For other exceptions, check if they might be timeout-related
if "timed out" in error_msg.lower() or "timeout" in error_msg.lower():
elapsed = time.time() - start_time
elapsed_minutes = int(elapsed // 60)
elapsed_seconds = elapsed % 60
duration_str = (
f"{elapsed_minutes}m {elapsed_seconds:.1f}s"
if elapsed_minutes > 0
else f"{elapsed_seconds:.1f}s"
)
logger.warning(
f"Request timed out (attempt {attempt}, "
f"duration: {duration_str}). Retrying in {attempt} second(s)..."
)
# Exponential backoff starting at 1 second
await asyncio.sleep(attempt)
else:
# Non-timeout errors should not be retried
return {"error": True, "exception": error_msg}
async def get(
self, url: str, headers: Optional[Dict[str, str]] = None, timeout: float = 30.0
) -> Dict[str, Any]:
return await self.request("GET", url, headers=headers, timeout=timeout)
async def post(
self,
url: str,
headers: Optional[Dict[str, str]] = None,
data: Optional[bytes] = None,
json_data: Optional[Dict[str, Any]] = None,
timeout: float = 30.0,
) -> Dict[str, Any]:
return await self.request(
"POST", url, headers=headers, data=data, json_data=json_data, timeout=timeout
)
def set_default_headers(self, headers: Dict[str, str]):
self.session_headers.update(headers)
# Global client instance
http_client = AsyncHTTPClient()
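A minimal usage sketch of the new client; the URL, the token placeholder, and the demo coroutine are illustrative, not part of the commit:
import asyncio

from pr.core.http_client import http_client

async def demo():
    # Headers set here are merged into every later request
    http_client.set_default_headers({"Authorization": "Bearer <token>"})
    resp = await http_client.get("https://example.com/api/status", timeout=10.0)
    if not resp.get("error"):
        # "json" is stored as a lambda, so call it to decode the body
        print(resp["status"], resp["json"]())

asyncio.run(demo())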

View File

@ -1,585 +0,0 @@
#!/usr/bin/env python3
import asyncio
import logging
import os
import sys
import subprocess
import urllib.parse
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import pathlib
import ast
from types import SimpleNamespace
import time
from aiohttp import web
import aiohttp
# --- Optional external dependency used in your original code ---
from pr.ads import AsyncDataSet # noqa: E402
import sys
from http import cookies
import urllib.parse
import os
import io
import json
import urllib.request
import pathlib
import pickle
import zlib
import asyncio
import time
from functools import wraps
from pathlib import Path
import pickle
import zlib
class Cache:
def __init__(self, base_dir: Path | str = "."):
self.base_dir = Path(base_dir).resolve()
self.base_dir.mkdir(exist_ok=True, parents=True)
def is_cacheable(self, obj):
return isinstance(obj, (int, str, bool, float))
def generate_key(self, *args, **kwargs):
return zlib.crc32(
json.dumps(
{
"args": [arg for arg in args if self.is_cacheable(arg)],
"kwargs": {k: v for k, v in kwargs.items() if self.is_cacheable(v)},
},
sort_keys=True,
default=str,
).encode()
)
def set(self, key, value):
key = self.generate_key(key)
data = {"value": value, "timestamp": time.time()}
serialized_data = pickle.dumps(data)
with open(self.base_dir.joinpath(f"{key}.cache"), "wb") as f:
f.write(serialized_data)
def get(self, key, default=None):
key = self.generate_key(key)
try:
with open(self.base_dir.joinpath(f"{key}.cache"), "rb") as f:
data = pickle.loads(f.read())
return data
except FileNotFoundError:
return default
def is_cacheable(self, obj):
return isinstance(obj, (int, str, bool, float))
def cached(self, expiration=60):
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
# if not all(self.is_cacheable(arg) for arg in args) or not all(self.is_cacheable(v) for v in kwargs.values()):
# return await func(*args, **kwargs)
key = self.generate_key(*args, **kwargs)
print("Cache hit:", key)
cached_data = self.get(key)
if cached_data is not None:
if expiration is None or (time.time() - cached_data["timestamp"]) < expiration:
return cached_data["value"]
result = await func(*args, **kwargs)
self.set(key, result)
return result
return wrapper
return decorator
cache = Cache()
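This module is removed by the commit, but as context for the deletion, a minimal sketch of how the cached decorator behaves; slow_double and demo are hypothetical names:
import asyncio

@cache.cached(expiration=300)
async def slow_double(x: int) -> int:
    await asyncio.sleep(1)  # stand-in for expensive work
    return x * 2

async def demo():
    print(await slow_double(21))  # computed, then pickled to a .cache file
    print(await slow_double(21))  # served from cache until the 300s expiry

asyncio.run(demo())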
class CGI:
def __init__(self):
self.environ = os.environ
if not self.gateway_interface == "CGI/1.1":
return
self.status = 200
self.headers = {}
self.cache = Cache(pathlib.Path(__file__).parent.joinpath("cache/cgi"))
self.headers["Content-Length"] = "0"
self.headers["Cache-Control"] = "no-cache"
self.headers["Access-Control-Allow-Origin"] = "*"
self.headers["Content-Type"] = "application/json; charset=utf-8"
self.cookies = cookies.SimpleCookie()
self.query = urllib.parse.parse_qs(os.environ["QUERY_STRING"])
self.file = io.BytesIO()
def validate(self, val, fn, error):
if fn(val):
return True
self.status = 421
self.write({"type": "error", "message": error})
exit()
def __getitem__(self, key):
return self.get(key)
def get(self, key, default=None):
result = self.query.get(key, [default])
if len(result) == 1:
return result[0]
return result
def get_bool(self, key):
return str(self.get(key)).lower() in ["true", "yes", "1"]
@property
def cache_key(self):
return self.environ["CACHE_KEY"]
@property
def env(self):
env = os.environ.copy()
return env
@property
def gateway_interface(self):
return self.env.get("GATEWAY_INTERFACE", "")
@property
def request_method(self):
return self.env["REQUEST_METHOD"]
@property
def query_string(self):
return self.env["QUERY_STRING"]
@property
def script_name(self):
return self.env["SCRIPT_NAME"]
@property
def path_info(self):
return self.env["PATH_INFO"]
@property
def server_name(self):
return self.env["SERVER_NAME"]
@property
def server_port(self):
return self.env["SERVER_PORT"]
@property
def server_protocol(self):
return self.env["SERVER_PROTOCOL"]
@property
def remote_addr(self):
return self.env["REMOTE_ADDR"]
def validate_get(self, key):
self.validate(self.get(key), lambda x: bool(x), f"Missing {key}")
def print(self, data):
self.write(data)
def write(self, data):
if not isinstance(data, bytes) and not isinstance(data, str):
data = json.dumps(data, default=str, indent=4)
try:
data = data.encode()
except:
pass
self.file.write(data)
self.headers["Content-Length"] = str(len(self.file.getvalue()))
@property
def http_status(self):
return f"HTTP/1.1 {self.status}\r\n".encode("utf-8")
@property
def http_headers(self):
headers = io.BytesIO()
for header in self.headers:
headers.write(f"{header}: {self.headers[header]}\r\n".encode("utf-8"))
headers.write(b"\r\n")
return headers.getvalue()
@property
def http_body(self):
return self.file.getvalue()
@property
def http_response(self):
return self.http_status + self.http_headers + self.http_body
def flush(self, response=None):
if response:
try:
response = response.encode()
except:
pass
sys.stdout.buffer.write(response)
sys.stdout.buffer.flush()
return
sys.stdout.buffer.write(self.http_response)
sys.stdout.buffer.flush()
def __del__(self):
if self.http_body:
self.flush()
exit()
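A sketch of the script side of this class; the greeting endpoint is hypothetical. The module tail below creates cgi at import time when GATEWAY_INTERFACE is CGI/1.1:
# Hypothetical body of a cgi/greeting.bin script:
cgi.validate_get("name")  # responds 421 with an error payload if missing
cgi.write({"type": "greeting", "message": f"Hello, {cgi.get('name')}"})
# __del__ flushes the status line, headers, and body to stdout on exit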
# -------------------------------
# Utilities
# -------------------------------
def get_function_source(name: str, directory: str = ".") -> List[Tuple[str, str]]:
matches: List[Tuple[str, str]] = []
for root, _, files in os.walk(directory):
for file in files:
if not file.endswith(".py"):
continue
path = os.path.join(root, file)
try:
with open(path, "r", encoding="utf-8") as fh:
source = fh.read()
tree = ast.parse(source, filename=path)
for node in ast.walk(tree):
if isinstance(node, ast.AsyncFunctionDef) and node.name == name:
func_src = ast.get_source_segment(source, node)
if func_src:
matches.append((path, func_src))
break
except (SyntaxError, UnicodeDecodeError):
continue
return matches
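For reference, the helper returns (path, source) pairs for each matching async def; a hypothetical lookup:
# Find where an async view named user_info is defined under ./views
for path, src in get_function_source("user_info", directory="views"):
    print(path, src.splitlines()[0])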
# -------------------------------
# Server
# -------------------------------
class Static(SimpleNamespace):
pass
class View(web.View):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.db = self.request.app["db"]
self.server = self.request.app["server"]
@property
def client(self):
return aiohttp.ClientSession()
class RetoorServer:
def __init__(self, base_dir: Path | str = ".", port: int = 8118) -> None:
self.base_dir = Path(base_dir).resolve()
self.port = port
self.static = Static()
self.db = AsyncDataSet(".default.db")
self._logger = logging.getLogger("retoor.server")
self._logger.setLevel(logging.INFO)
if not self._logger.handlers:
h = logging.StreamHandler(sys.stdout)
fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s")
h.setFormatter(fmt)
self._logger.addHandler(h)
self._func_cache: Dict[str, Tuple[str, str]] = {}
self._compiled_cache: Dict[str, object] = {}
self._cgi_cache: Dict[str, Tuple[Optional[Path], Optional[str], Optional[str]]] = {}
self._static_path_cache: Dict[str, Path] = {}
self._base_dir_str = str(self.base_dir)
self.app = web.Application()
app = self.app
app["db"] = self.db
for path in pathlib.Path(__file__).parent.joinpath("web").joinpath("views").iterdir():
if path.suffix == ".py":
exec(open(path).read(), globals(), locals())
self.app["server"] = self
# from rpanel import create_app as create_rpanel_app
# self.app.add_subapp("/api/{tail:.*}", create_rpanel_app())
self.app.router.add_route("*", "/{tail:.*}", self.handle_any)
# ---------------------------
# Simple endpoints
# ---------------------------
async def handle_root(self, request: web.Request) -> web.Response:
return web.Response(text="Static file and cgi server from retoor.")
async def handle_hello(self, request: web.Request) -> web.Response:
return web.Response(text="Welcome to the custom HTTP server!")
# ---------------------------
# Dynamic function dispatch
# ---------------------------
def _path_to_funcname(self, path: str) -> str:
# /foo/bar -> foo_bar ; strip leading/trailing slashes safely
return path.strip("/").replace("/", "_")
async def _maybe_dispatch_dynamic(self, request: web.Request) -> Optional[web.StreamResponse]:
relpath = request.path.strip("/")
relpath = str(pathlib.Path(__file__).joinpath("cgi").joinpath(relpath).resolve()).replace(
"..", ""
)
if not relpath:
return None
last_part = relpath.split("/")[-1]
if "." in last_part:
return None
funcname = self._path_to_funcname(request.path)
if funcname not in self._compiled_cache:
entry = self._func_cache.get(funcname)
if entry is None:
matches = get_function_source(funcname, directory=str(self.base_dir))
if not matches:
return None
entry = matches[0]
self._func_cache[funcname] = entry
filepath, source = entry
code_obj = compile(source, filepath, "exec")
self._compiled_cache[funcname] = (code_obj, filepath)
code_obj, filepath = self._compiled_cache[funcname]
ctx = {
"static": self.static,
"request": request,
"relpath": relpath,
"os": os,
"sys": sys,
"web": web,
"asyncio": asyncio,
"subprocess": subprocess,
"urllib": urllib.parse,
"app": request.app,
"db": self.db,
}
exec(code_obj, ctx, ctx)
coro = ctx.get(funcname)
if not callable(coro):
return web.Response(status=500, text=f"Dynamic handler '{funcname}' is not callable.")
return await coro(request)
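Under this dispatch scheme, a request to /user/info resolves to an async function named user_info found anywhere under base_dir; a hypothetical handler (names like web and db come from the exec context built above):
async def user_info(request):
    # db, static, app, and the aiohttp web module are injected for us
    return web.json_response({"path": request.path})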
# ---------------------------
# Static files
# ---------------------------
async def handle_static(self, request: web.Request) -> web.StreamResponse:
path = request.path
if path in self._static_path_cache:
cached_path = self._static_path_cache[path]
if cached_path is None:
return web.Response(status=404, text="Not found")
try:
return web.FileResponse(path=cached_path)
except Exception as e:
return web.Response(status=500, text=f"Failed to send file: {e!r}")
relpath = path.replace("..", "").lstrip("/").rstrip("/")
abspath = (self.base_dir / relpath).resolve()
if not str(abspath).startswith(self._base_dir_str):
return web.Response(status=403, text="Forbidden")
if not abspath.exists():
self._static_path_cache[path] = None
return web.Response(status=404, text="Not found")
if abspath.is_dir():
return web.Response(status=403, text="Directory listing forbidden")
self._static_path_cache[path] = abspath
try:
return web.FileResponse(path=abspath)
except Exception as e:
return web.Response(status=500, text=f"Failed to send file: {e!r}")
# ---------------------------
# CGI
# ---------------------------
def _find_cgi_script(
self, path_only: str
) -> Tuple[Optional[Path], Optional[str], Optional[str]]:
path_only = "pr/cgi/" + path_only
if path_only in self._cgi_cache:
return self._cgi_cache[path_only]
split_path = [p for p in path_only.split("/") if p]
for i in range(len(split_path), -1, -1):
candidate = "/" + "/".join(split_path[:i])
candidate_fs = (self.base_dir / candidate.lstrip("/")).resolve()
if (
str(candidate_fs).startswith(self._base_dir_str)
and candidate_fs.parent.is_dir()
and candidate_fs.parent.name == "cgi"
and candidate_fs.suffix == ".bin"
and candidate_fs.is_file()
and os.access(candidate_fs, os.X_OK)
):
script_name = candidate if candidate.startswith("/") else "/" + candidate
path_info = path_only[len(script_name) :]
result = (candidate_fs, script_name, path_info)
self._cgi_cache[path_only] = result
return result
result = (None, None, None)
self._cgi_cache[path_only] = result
return result
async def _run_cgi(
self, request: web.Request, script_path: Path, script_name: str, path_info: str
) -> web.Response:
start_time = time.time()
method = request.method
query = request.query_string
env = os.environ.copy()
env["CACHE_KEY"] = json.dumps(
{"method": method, "query": query, "script_name": script_name, "path_info": path_info},
default=str,
)
env["GATEWAY_INTERFACE"] = "CGI/1.1"
env["REQUEST_METHOD"] = method
env["QUERY_STRING"] = query
env["SCRIPT_NAME"] = script_name
env["PATH_INFO"] = path_info
host_parts = request.host.split(":", 1)
env["SERVER_NAME"] = host_parts[0]
env["SERVER_PORT"] = host_parts[1] if len(host_parts) > 1 else "80"
env["SERVER_PROTOCOL"] = "HTTP/1.1"
peername = request.transport.get_extra_info("peername")
env["REMOTE_ADDR"] = request.headers.get("HOST_X_FORWARDED_FOR", peername[0])
for hk, hv in request.headers.items():
hk_lower = hk.lower()
if hk_lower == "content-type":
env["CONTENT_TYPE"] = hv
elif hk_lower == "content-length":
env["CONTENT_LENGTH"] = hv
else:
env["HTTP_" + hk.upper().replace("-", "_")] = hv
post_data = None
if method in {"POST", "PUT", "PATCH"}:
post_data = await request.read()
env.setdefault("CONTENT_LENGTH", str(len(post_data)))
try:
proc = await asyncio.create_subprocess_exec(
str(script_path),
stdin=asyncio.subprocess.PIPE if post_data else None,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
stdout, stderr = await proc.communicate(input=post_data)
if proc.returncode != 0:
msg = stderr.decode(errors="ignore")
return web.Response(status=500, text=f"CGI script error:\n{msg}")
header_end = stdout.find(b"\r\n\r\n")
if header_end != -1:
offset = 4
else:
header_end = stdout.find(b"\n\n")
offset = 2
if header_end != -1:
body = stdout[header_end + offset :]
headers_blob = stdout[:header_end].decode(errors="ignore")
status_code = 200
headers: Dict[str, str] = {}
for line in headers_blob.splitlines():
line = line.strip()
if not line or ":" not in line:
continue
k, v = line.split(":", 1)
k_stripped = k.strip()
if k_stripped.lower() == "status":
try:
status_code = int(v.strip().split()[0])
except Exception:
status_code = 200
else:
headers[k_stripped] = v.strip()
end_time = time.time()
headers["X-Powered-By"] = "retoor"
headers["X-Date"] = time.strftime("%a, %d %b %Y %H:%M:%S %Z", time.localtime())
headers["X-Time"] = str(end_time)
headers["X-Duration"] = str(end_time - start_time)
if "error" in str(body).lower():
print(headers)
print(body)
return web.Response(body=body, headers=headers, status=status_code)
else:
return web.Response(body=stdout)
except Exception as e:
return web.Response(status=500, text=f"Failed to execute CGI script.\n{e!r}")
# ---------------------------
# Main router
# ---------------------------
async def handle_any(self, request: web.Request) -> web.StreamResponse:
path = request.path
if path == "/":
return await self.handle_root(request)
if path == "/hello":
return await self.handle_hello(request)
script_path, script_name, path_info = self._find_cgi_script(path)
if script_path:
return await self._run_cgi(request, script_path, script_name or "", path_info or "")
dyn = await self._maybe_dispatch_dynamic(request)
if dyn is not None:
return dyn
return await self.handle_static(request)
# ---------------------------
# Runner
# ---------------------------
def run(self) -> None:
self._logger.info("Serving at port %d (base_dir=%s)", self.port, self.base_dir)
web.run_app(self.app, port=self.port)
def main() -> None:
server = RetoorServer(base_dir=".", port=8118)
server.run()
if __name__ == "__main__":
main()
else:
cgi = CGI()

View File

@ -43,11 +43,6 @@ from pr.tools.memory import (
from pr.tools.patch import apply_patch, create_diff
from pr.tools.python_exec import python_exec
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.tools.context_modifier import (
    modify_context_add,
    modify_context_replace,
    modify_context_delete,
)

__all__ = [
    "add_knowledge_entry",

View File

@ -1,3 +1,4 @@
import asyncio
import os
from typing import Any, Dict, List
@ -16,15 +17,39 @@ def _create_api_wrapper():
    tools_definition = get_tools_definition() if use_tools else []

    def api_wrapper(messages, temperature=None, max_tokens=None, **kwargs):
        try:
            # Try to get the current event loop
            loop = asyncio.get_running_loop()
            # If we're in an event loop, use run_coroutine_threadsafe
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = asyncio.run_coroutine_threadsafe(
                    call_api(
                        messages=messages,
                        model=model,
                        api_url=api_url,
                        api_key=api_key,
                        use_tools=use_tools,
                        tools_definition=tools_definition,
                        verbose=False,
                    ),
                    loop,
                )
                return future.result()
        except RuntimeError:
            # No event loop running, use asyncio.run
            return asyncio.run(
                call_api(
                    messages=messages,
                    model=model,
                    api_url=api_url,
                    api_key=api_key,
                    use_tools=use_tools,
                    tools_definition=tools_definition,
                    verbose=False,
                )
            )

    return api_wrapper
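The wrapper above bridges synchronous callers into asyncio. Note the assumption that it runs off the loop thread: blocking on future.result() from the loop's own thread would deadlock. A standalone sketch of the same pattern, with hypothetical names:
import asyncio

async def work() -> str:
    await asyncio.sleep(0.1)
    return "done"

def sync_bridge() -> str:
    try:
        # Mirror the commit's logic: detect a running loop first
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(work())  # no loop in this thread, create one
    # Only safe when this code runs on a different thread than the loop
    future = asyncio.run_coroutine_threadsafe(work(), loop)
    return future.result()

print(sync_bridge())  # prints "done" via the asyncio.run branch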

View File

@ -1,18 +1,21 @@
import os
from typing import Optional

CONTEXT_FILE = "/home/retoor/.local/share/rp/.rcontext.txt"

def _read_context() -> str:
    if not os.path.exists(CONTEXT_FILE):
        raise FileNotFoundError(f"Context file {CONTEXT_FILE} not found.")
    with open(CONTEXT_FILE, "r") as f:
        return f.read()

def _write_context(content: str):
    with open(CONTEXT_FILE, "w") as f:
        f.write(content)

def modify_context_add(new_content: str, position: Optional[str] = None) -> str:
    """
    Add new content to the .rcontext.txt file.
@ -25,13 +28,14 @@ def modify_context_add(new_content: str, position: Optional[str] = None) -> str:
    if position and position in current:
        # Insert before the position
        parts = current.split(position, 1)
        updated = parts[0] + new_content + "\n\n" + position + parts[1]
    else:
        # Append at the end
        updated = current + "\n\n" + new_content
    _write_context(updated)
    return f"Added: {new_content[:100]}... (full addition applied). Consequences: Enhances functionality as requested."

def modify_context_replace(old_content: str, new_content: str) -> str:
    """
    Replace old content with new content in .rcontext.txt.
@ -47,6 +51,7 @@ def modify_context_replace(old_content: str, new_content: str) -> str:
    _write_context(updated)
    return f"Replaced: '{old_content[:50]}...' with '{new_content[:50]}...'. Consequences: Changes behavior as specified; verify for unintended effects."

def modify_context_delete(content_to_delete: str, confirmed: bool = False) -> str:
    """
    Delete content from .rcontext.txt, but only if confirmed.
@ -56,10 +61,12 @@ def modify_context_delete(content_to_delete: str, confirmed: bool = False) -> str:
        confirmed: Must be True to proceed with deletion.
    """
    if not confirmed:
        raise PermissionError(
            f"Deletion not confirmed. To delete '{content_to_delete[:50]}...', you must explicitly confirm. Are you sure? This may affect system behavior permanently."
        )
    current = _read_context()
    if content_to_delete not in current:
        raise ValueError(f"Content to delete not found: {content_to_delete[:50]}...")
    updated = current.replace(content_to_delete, "", 1)
    _write_context(updated)
    return f"Deleted: '{content_to_delete[:50]}...'. Consequences: Removed specified content; system may lose referenced rules or guidelines."

View File

@ -1,9 +1,10 @@
from pr.ui.colors import Colors, Spinner
from pr.ui.display import display_tool_call, print_autonomous_header
from pr.ui.rendering import highlight_code, render_markdown

__all__ = [
    "Colors",
    "Spinner",
    "highlight_code",
    "render_markdown",
    "display_tool_call",

View File

@ -1,3 +1,6 @@
import asyncio
class Colors:
    RESET = "\033[0m"
    BOLD = "\033[1m"
@ -12,3 +15,30 @@ class Colors:
    BG_BLUE = "\033[44m"
    BG_GREEN = "\033[42m"
    BG_RED = "\033[41m"

class Spinner:
    def __init__(self, message="Processing...", spinner_chars="|/-\\"):
        self.message = message
        self.spinner_chars = spinner_chars
        self.running = False
        self.task = None

    async def start(self):
        self.running = True
        self.task = asyncio.create_task(self._spin())

    async def stop(self):
        self.running = False
        if self.task:
            await self.task
        # Clear the line
        print("\r" + " " * (len(self.message) + 2) + "\r", end="", flush=True)

    async def _spin(self):
        i = 0
        while self.running:
            char = self.spinner_chars[i % len(self.spinner_chars)]
            print(f"\r{Colors.CYAN}{char}{Colors.RESET} {self.message}", end="", flush=True)
            i += 1
            await asyncio.sleep(0.1)
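A minimal usage sketch for the new Spinner; the sleep stands in for a real awaited request:
async def demo():
    spinner = Spinner("Calling API...")
    await spinner.start()
    try:
        await asyncio.sleep(2)  # the work being indicated
    finally:
        await spinner.stop()

asyncio.run(demo())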

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "rp"
version = "1.5.0"
description = "R python edition. The ultimate autonomous AI CLI."
readme = "README.md"
requires-python = ">=3.12"
@ -14,7 +14,6 @@ authors = [
    {name = "retoor", email = "retoor@molodetz.nl"}
]
dependencies = [
    "aiohttp>=3.13.2",
    "pydantic>=2.12.3",
    "prompt_toolkit>=3.0.0",
]
@ -32,8 +31,6 @@ classifiers = [
dev = [
    "pytest>=8.3.0",
    "pytest-asyncio>=1.2.0",
    "pytest-aiohttp>=1.1.0",
    "aiohttp>=3.13.2",
    "pytest-cov>=7.0.0",
    "black>=25.9.0",
    "flake8>=7.3.0",

1
rp.py
View File

@ -1,7 +1,6 @@
#!/usr/bin/env python3
# Trigger build
import sys
import os

View File

@ -1,11 +1,8 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock

import pytest

from pr.ads import AsyncDataSet
from pr.web.app import create_app

@pytest.fixture
@ -44,77 +41,3 @@ def sample_context_file(temp_dir):
    with open(context_path, "w") as f:
        f.write("Sample context content\n")
    return context_path
@pytest.fixture
async def client(aiohttp_client, monkeypatch):
"""Create a test client for the app."""
with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp:
temp_db_file = tmp.name
# Monkeypatch the db
monkeypatch.setattr("pr.web.views.base.db", AsyncDataSet(temp_db_file, f"{temp_db_file}.sock"))
with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir)
static_dir = tmp_path / "static"
templates_dir = tmp_path / "templates"
repos_dir = tmp_path / "repos"
static_dir.mkdir()
templates_dir.mkdir()
repos_dir.mkdir()
# Monkeypatch the directories
monkeypatch.setattr("pr.web.config.STATIC_DIR", static_dir)
monkeypatch.setattr("pr.web.config.TEMPLATES_DIR", templates_dir)
monkeypatch.setattr("pr.web.config.REPOS_DIR", repos_dir)
monkeypatch.setattr("pr.web.config.REPOS_DIR", repos_dir)
# Create minimal templates
(templates_dir / "index.html").write_text("<html>Index</html>")
(templates_dir / "login.html").write_text("<html>Login</html>")
(templates_dir / "register.html").write_text("<html>Register</html>")
(templates_dir / "dashboard.html").write_text("<html>Dashboard</html>")
(templates_dir / "repos.html").write_text("<html>Repos</html>")
(templates_dir / "api_keys.html").write_text("<html>API Keys</html>")
(templates_dir / "repo.html").write_text("<html>Repo</html>")
(templates_dir / "file.html").write_text("<html>File</html>")
(templates_dir / "edit_file.html").write_text("<html>Edit File</html>")
(templates_dir / "deploy.html").write_text("<html>Deploy</html>")
app_instance = await create_app()
client = await aiohttp_client(app_instance)
yield client
# Cleanup db
os.unlink(temp_db_file)
sock_file = f"{temp_db_file}.sock"
if os.path.exists(sock_file):
os.unlink(sock_file)
@pytest.fixture
async def authenticated_client(client):
"""Create a client with an authenticated user."""
# Register a user
resp = await client.post(
"/register",
data={
"username": "testuser",
"email": "test@example.com",
"password": "password123",
"confirm_password": "password123",
},
allow_redirects=False,
)
assert resp.status == 302 # Redirect to login
# Login
resp = await client.post(
"/login", data={"username": "testuser", "password": "password123"}, allow_redirects=False
)
assert resp.status == 302 # Redirect to dashboard
return client