From 1a29ee49180595510cf2838af033a4fc91155686 Mon Sep 17 00:00:00 2001 From: retoor Date: Tue, 4 Nov 2025 08:09:12 +0100 Subject: [PATCH] Update. --- CONTRIBUTING.md | 5 +- pr/__init__.py | 4 +- pr/__main__.py | 116 +++++--- pr/agents/__init__.py | 15 +- pr/agents/agent_communication.py | 149 ++++++---- pr/agents/agent_manager.py | 129 +++++---- pr/agents/agent_roles.py | 279 +++++++++++------- pr/autonomous/__init__.py | 4 +- pr/autonomous/detection.py | 31 +- pr/autonomous/mode.py | 168 ++++++----- pr/cache/__init__.py | 2 +- pr/cache/api_cache.py | 86 ++++-- pr/cache/tool_cache.py | 98 ++++--- pr/commands/__init__.py | 2 +- pr/commands/handlers.py | 260 ++++++++++------- pr/commands/multiplexer_commands.py | 112 +++++--- pr/config.py | 97 +++++-- pr/core/__init__.py | 10 +- pr/core/advanced_context.py | 58 ++-- pr/core/api.py | 72 ++--- pr/core/assistant.py | 423 ++++++++++++++++++---------- pr/core/autonomous_interactions.py | 79 ++++-- pr/core/background_monitor.py | 158 ++++++----- pr/core/config_loader.py | 21 +- pr/core/context.py | 163 ++++++++--- pr/core/enhanced_assistant.py | 194 +++++++------ pr/core/logging.py | 19 +- pr/core/session.py | 65 +++-- pr/core/usage_tracker.py | 125 ++++---- pr/core/validation.py | 10 +- pr/editor.py | 385 +++++++++++++------------ pr/editor2.py | 247 ++++++++-------- pr/input_handler.py | 36 +-- pr/memory/__init__.py | 13 +- pr/memory/conversation_memory.py | 256 +++++++++++------ pr/memory/fact_extractor.py | 241 +++++++++++----- pr/memory/knowledge_store.py | 169 ++++++----- pr/memory/semantic_index.py | 13 +- pr/multiplexer.py | 171 ++++++----- pr/plugins/loader.py | 23 +- pr/tools/__init__.py | 101 +++++-- pr/tools/agents.py | 40 ++- pr/tools/base.py | 417 +++++++++++++++++---------- pr/tools/command.py | 38 ++- pr/tools/database.py | 16 +- pr/tools/editor.py | 69 +++-- pr/tools/filesystem.py | 198 +++++++++---- pr/tools/interactive_control.py | 61 ++-- pr/tools/memory.py | 64 +++-- pr/tools/patch.py | 75 +++-- pr/tools/process_handlers.py | 264 ++++++++--------- pr/tools/prompt_detection.py | 320 +++++++++++---------- pr/tools/python_exec.py | 3 +- pr/tools/web.py | 15 +- pr/ui/__init__.py | 10 +- pr/ui/colors.py | 26 +- pr/ui/diff_display.py | 166 +++++++---- pr/ui/display.py | 9 +- pr/ui/edit_feedback.py | 98 ++++--- pr/ui/output.py | 32 +-- pr/ui/progress.py | 22 +- pr/ui/rendering.py | 89 +++--- pr/workflows/__init__.py | 10 +- pr/workflows/workflow_definition.py | 67 ++--- pr/workflows/workflow_engine.py | 134 +++++---- pr/workflows/workflow_storage.py | 155 ++++++---- rp.py | 3 +- tests/conftest.py | 24 +- tests/test_advanced_context.py | 52 ++-- tests/test_agents.py | 190 +++++++++---- tests/test_api.py | 56 ++-- tests/test_assistant.py | 104 ++++--- tests/test_config.py | 23 +- tests/test_config_loader.py | 71 +++-- tests/test_context.py | 21 +- tests/test_enhanced_assistant.py | 74 ++--- tests/test_logging.py | 19 +- tests/test_main.py | 102 ++++--- tests/test_session.py | 69 +++-- tests/test_tools.py | 101 ++++--- tests/test_usage_tracker.py | 104 ++++--- tests/test_validation.py | 41 ++- 82 files changed, 4965 insertions(+), 3096 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 20f9253..a790eac 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -159,6 +159,7 @@ def tool_function(args): """Implementation""" pass + def register_tools(): """Return list of tool definitions""" return [...] @@ -177,8 +178,8 @@ def register_tools(): ```python def test_read_file_with_valid_path_returns_content(temp_dir): # Arrange - filepath = os.path.join(temp_dir, 'test.txt') - expected_content = 'Hello, World!' + filepath = os.path.join(temp_dir, "test.txt") + expected_content = "Hello, World!" write_file(filepath, expected_content) # Act diff --git a/pr/__init__.py b/pr/__init__.py index 6e4f56d..f41af4c 100644 --- a/pr/__init__.py +++ b/pr/__init__.py @@ -1,4 +1,4 @@ from pr.core import Assistant -__version__ = '1.0.0' -__all__ = ['Assistant'] +__version__ = "1.0.0" +__all__ = ["Assistant"] diff --git a/pr/__main__.py b/pr/__main__.py index c64d131..8ddbd73 100644 --- a/pr/__main__.py +++ b/pr/__main__.py @@ -1,12 +1,14 @@ import argparse import sys -from pr.core import Assistant + from pr import __version__ +from pr.core import Assistant + def main(): parser = argparse.ArgumentParser( - description='PR Assistant - Professional CLI AI assistant with autonomous execution', - epilog=''' + description="PR Assistant - Professional CLI AI assistant with autonomous execution", + epilog=""" Examples: pr "What is Python?" # Single query pr -i # Interactive mode @@ -25,43 +27,79 @@ Commands in interactive mode: /usage - Show usage statistics /save - Save current session exit, quit, q - Exit the program - ''', - formatter_class=argparse.RawDescriptionHelpFormatter + """, + formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument('message', nargs='?', help='Message to send to assistant') - parser.add_argument('--version', action='version', version=f'PR Assistant {__version__}') - parser.add_argument('-m', '--model', help='AI model to use') - parser.add_argument('-u', '--api-url', help='API endpoint URL') - parser.add_argument('--model-list-url', help='Model list endpoint URL') - parser.add_argument('-i', '--interactive', action='store_true', help='Interactive mode') - parser.add_argument('-v', '--verbose', action='store_true', help='Verbose output') - parser.add_argument('--debug', action='store_true', help='Enable debug mode with detailed logging') - parser.add_argument('--no-syntax', action='store_true', help='Disable syntax highlighting') - parser.add_argument('--include-env', action='store_true', help='Include environment variables in context') - parser.add_argument('-c', '--context', action='append', help='Additional context files') - parser.add_argument('--api-mode', action='store_true', help='API mode for specialized interaction') + parser.add_argument("message", nargs="?", help="Message to send to assistant") + parser.add_argument( + "--version", action="version", version=f"PR Assistant {__version__}" + ) + parser.add_argument("-m", "--model", help="AI model to use") + parser.add_argument("-u", "--api-url", help="API endpoint URL") + parser.add_argument("--model-list-url", help="Model list endpoint URL") + parser.add_argument( + "-i", "--interactive", action="store_true", help="Interactive mode" + ) + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--debug", action="store_true", help="Enable debug mode with detailed logging" + ) + parser.add_argument( + "--no-syntax", action="store_true", help="Disable syntax highlighting" + ) + parser.add_argument( + "--include-env", + action="store_true", + help="Include environment variables in context", + ) + parser.add_argument( + "-c", "--context", action="append", help="Additional context files" + ) + parser.add_argument( + "--api-mode", action="store_true", help="API mode for specialized interaction" + ) - parser.add_argument('--output', choices=['text', 'json', 'structured'], - default='text', help='Output format') - parser.add_argument('--quiet', action='store_true', help='Minimal output') + parser.add_argument( + "--output", + choices=["text", "json", "structured"], + default="text", + help="Output format", + ) + parser.add_argument("--quiet", action="store_true", help="Minimal output") - parser.add_argument('--save-session', metavar='NAME', help='Save session with given name') - parser.add_argument('--load-session', metavar='NAME', help='Load session with given name') - parser.add_argument('--list-sessions', action='store_true', help='List all saved sessions') - parser.add_argument('--delete-session', metavar='NAME', help='Delete a saved session') - parser.add_argument('--export-session', nargs=2, metavar=('NAME', 'FILE'), - help='Export session to file') + parser.add_argument( + "--save-session", metavar="NAME", help="Save session with given name" + ) + parser.add_argument( + "--load-session", metavar="NAME", help="Load session with given name" + ) + parser.add_argument( + "--list-sessions", action="store_true", help="List all saved sessions" + ) + parser.add_argument( + "--delete-session", metavar="NAME", help="Delete a saved session" + ) + parser.add_argument( + "--export-session", + nargs=2, + metavar=("NAME", "FILE"), + help="Export session to file", + ) - parser.add_argument('--usage', action='store_true', help='Show token usage statistics') - parser.add_argument('--create-config', action='store_true', - help='Create default configuration file') - parser.add_argument('--plugins', action='store_true', help='List loaded plugins') + parser.add_argument( + "--usage", action="store_true", help="Show token usage statistics" + ) + parser.add_argument( + "--create-config", action="store_true", help="Create default configuration file" + ) + parser.add_argument("--plugins", action="store_true", help="List loaded plugins") args = parser.parse_args() if args.create_config: from pr.core.config_loader import create_default_config + if create_default_config(): print("Configuration file created at ~/.prrc") else: @@ -70,6 +108,7 @@ Commands in interactive mode: if args.list_sessions: from pr.core.session import SessionManager + sm = SessionManager() sessions = sm.list_sessions() if not sessions: @@ -85,6 +124,7 @@ Commands in interactive mode: if args.delete_session: from pr.core.session import SessionManager + sm = SessionManager() if sm.delete_session(args.delete_session): print(f"Session '{args.delete_session}' deleted") @@ -94,13 +134,14 @@ Commands in interactive mode: if args.export_session: from pr.core.session import SessionManager + sm = SessionManager() name, output_file = args.export_session - format_type = 'json' - if output_file.endswith('.md'): - format_type = 'markdown' - elif output_file.endswith('.txt'): - format_type = 'txt' + format_type = "json" + if output_file.endswith(".md"): + format_type = "markdown" + elif output_file.endswith(".txt"): + format_type = "txt" if sm.export_session(name, output_file, format_type): print(f"Session exported to {output_file}") @@ -110,6 +151,7 @@ Commands in interactive mode: if args.usage: from pr.core.usage_tracker import UsageTracker + usage = UsageTracker.get_total_usage() print(f"\nTotal Usage Statistics:") print(f" Requests: {usage['total_requests']}") @@ -119,6 +161,7 @@ Commands in interactive mode: if args.plugins: from pr.plugins.loader import PluginLoader + loader = PluginLoader() loader.load_plugins() plugins = loader.list_loaded_plugins() @@ -133,5 +176,6 @@ Commands in interactive mode: assistant = Assistant(args) assistant.run() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/pr/agents/__init__.py b/pr/agents/__init__.py index 838b865..6f227dc 100644 --- a/pr/agents/__init__.py +++ b/pr/agents/__init__.py @@ -1,6 +1,13 @@ +from .agent_communication import AgentCommunicationBus, AgentMessage +from .agent_manager import AgentInstance, AgentManager from .agent_roles import AgentRole, get_agent_role, list_agent_roles -from .agent_manager import AgentManager, AgentInstance -from .agent_communication import AgentMessage, AgentCommunicationBus -__all__ = ['AgentRole', 'get_agent_role', 'list_agent_roles', 'AgentManager', 'AgentInstance', - 'AgentMessage', 'AgentCommunicationBus'] \ No newline at end of file +__all__ = [ + "AgentRole", + "get_agent_role", + "list_agent_roles", + "AgentManager", + "AgentInstance", + "AgentMessage", + "AgentCommunicationBus", +] diff --git a/pr/agents/agent_communication.py b/pr/agents/agent_communication.py index 23c9b1d..82ac954 100644 --- a/pr/agents/agent_communication.py +++ b/pr/agents/agent_communication.py @@ -1,14 +1,16 @@ -import sqlite3 import json -from typing import List, Optional +import sqlite3 from dataclasses import dataclass from enum import Enum +from typing import List, Optional + class MessageType(Enum): REQUEST = "request" RESPONSE = "response" NOTIFICATION = "notification" + @dataclass class AgentMessage: message_id: str @@ -21,27 +23,28 @@ class AgentMessage: def to_dict(self) -> dict: return { - 'message_id': self.message_id, - 'from_agent': self.from_agent, - 'to_agent': self.to_agent, - 'message_type': self.message_type.value, - 'content': self.content, - 'metadata': self.metadata, - 'timestamp': self.timestamp + "message_id": self.message_id, + "from_agent": self.from_agent, + "to_agent": self.to_agent, + "message_type": self.message_type.value, + "content": self.content, + "metadata": self.metadata, + "timestamp": self.timestamp, } @classmethod - def from_dict(cls, data: dict) -> 'AgentMessage': + def from_dict(cls, data: dict) -> "AgentMessage": return cls( - message_id=data['message_id'], - from_agent=data['from_agent'], - to_agent=data['to_agent'], - message_type=MessageType(data['message_type']), - content=data['content'], - metadata=data['metadata'], - timestamp=data['timestamp'] + message_id=data["message_id"], + from_agent=data["from_agent"], + to_agent=data["to_agent"], + message_type=MessageType(data["message_type"]), + content=data["content"], + metadata=data["metadata"], + timestamp=data["timestamp"], ) + class AgentCommunicationBus: def __init__(self, db_path: str): self.db_path = db_path @@ -50,7 +53,8 @@ class AgentCommunicationBus: def _create_tables(self): cursor = self.conn.cursor() - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS agent_messages ( message_id TEXT PRIMARY KEY, from_agent TEXT, @@ -62,70 +66,88 @@ class AgentCommunicationBus: session_id TEXT, read INTEGER DEFAULT 0 ) - ''') + """ + ) self.conn.commit() def send_message(self, message: AgentMessage, session_id: Optional[str] = None): cursor = self.conn.cursor() - cursor.execute(''' + cursor.execute( + """ INSERT INTO agent_messages (message_id, from_agent, to_agent, message_type, content, metadata, timestamp, session_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - message.message_id, - message.from_agent, - message.to_agent, - message.message_type.value, - message.content, - json.dumps(message.metadata), - message.timestamp, - session_id - )) + """, + ( + message.message_id, + message.from_agent, + message.to_agent, + message.message_type.value, + message.content, + json.dumps(message.metadata), + message.timestamp, + session_id, + ), + ) self.conn.commit() - def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]: + def get_messages( + self, agent_id: str, unread_only: bool = True + ) -> List[AgentMessage]: cursor = self.conn.cursor() if unread_only: - cursor.execute(''' + cursor.execute( + """ SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp FROM agent_messages WHERE to_agent = ? AND read = 0 ORDER BY timestamp ASC - ''', (agent_id,)) + """, + (agent_id,), + ) else: - cursor.execute(''' + cursor.execute( + """ SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp FROM agent_messages WHERE to_agent = ? ORDER BY timestamp ASC - ''', (agent_id,)) + """, + (agent_id,), + ) messages = [] for row in cursor.fetchall(): - messages.append(AgentMessage( - message_id=row[0], - from_agent=row[1], - to_agent=row[2], - message_type=MessageType(row[3]), - content=row[4], - metadata=json.loads(row[5]) if row[5] else {}, - timestamp=row[6] - )) + messages.append( + AgentMessage( + message_id=row[0], + from_agent=row[1], + to_agent=row[2], + message_type=MessageType(row[3]), + content=row[4], + metadata=json.loads(row[5]) if row[5] else {}, + timestamp=row[6], + ) + ) return messages def mark_as_read(self, message_id: str): cursor = self.conn.cursor() - cursor.execute('UPDATE agent_messages SET read = 1 WHERE message_id = ?', (message_id,)) + cursor.execute( + "UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,) + ) self.conn.commit() def clear_messages(self, session_id: Optional[str] = None): cursor = self.conn.cursor() if session_id: - cursor.execute('DELETE FROM agent_messages WHERE session_id = ?', (session_id,)) + cursor.execute( + "DELETE FROM agent_messages WHERE session_id = ?", (session_id,) + ) else: - cursor.execute('DELETE FROM agent_messages') + cursor.execute("DELETE FROM agent_messages") self.conn.commit() def close(self): @@ -134,24 +156,31 @@ class AgentCommunicationBus: def receive_messages(self, agent_id: str) -> List[AgentMessage]: return self.get_messages(agent_id, unread_only=True) - def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]: + def get_conversation_history( + self, agent_a: str, agent_b: str + ) -> List[AgentMessage]: cursor = self.conn.cursor() - cursor.execute(''' + cursor.execute( + """ SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp FROM agent_messages WHERE (from_agent = ? AND to_agent = ?) OR (from_agent = ? AND to_agent = ?) ORDER BY timestamp ASC - ''', (agent_a, agent_b, agent_b, agent_a)) + """, + (agent_a, agent_b, agent_b, agent_a), + ) messages = [] for row in cursor.fetchall(): - messages.append(AgentMessage( - message_id=row[0], - from_agent=row[1], - to_agent=row[2], - message_type=MessageType(row[3]), - content=row[4], - metadata=json.loads(row[5]) if row[5] else {}, - timestamp=row[6] - )) - return messages \ No newline at end of file + messages.append( + AgentMessage( + message_id=row[0], + from_agent=row[1], + to_agent=row[2], + message_type=MessageType(row[3]), + content=row[4], + metadata=json.loads(row[5]) if row[5] else {}, + timestamp=row[6], + ) + ) + return messages diff --git a/pr/agents/agent_manager.py b/pr/agents/agent_manager.py index f6e9027..cc58b4d 100644 --- a/pr/agents/agent_manager.py +++ b/pr/agents/agent_manager.py @@ -1,11 +1,13 @@ -import time import json +import time import uuid -from typing import Dict, List, Any, Optional, Callable from dataclasses import dataclass, field -from .agent_roles import AgentRole, get_agent_role -from .agent_communication import AgentMessage, AgentCommunicationBus, MessageType +from typing import Any, Callable, Dict, List, Optional + from ..memory.knowledge_store import KnowledgeStore +from .agent_communication import AgentCommunicationBus, AgentMessage, MessageType +from .agent_roles import AgentRole, get_agent_role + @dataclass class AgentInstance: @@ -17,21 +19,20 @@ class AgentInstance: task_count: int = 0 def add_message(self, role: str, content: str): - self.message_history.append({ - 'role': role, - 'content': content, - 'timestamp': time.time() - }) + self.message_history.append( + {"role": role, "content": content, "timestamp": time.time()} + ) def get_system_message(self) -> Dict[str, str]: - return {'role': 'system', 'content': self.role.system_prompt} + return {"role": "system", "content": self.role.system_prompt} def get_messages_for_api(self) -> List[Dict[str, str]]: return [self.get_system_message()] + [ - {'role': msg['role'], 'content': msg['content']} + {"role": msg["role"], "content": msg["content"]} for msg in self.message_history ] + class AgentManager: def __init__(self, db_path: str, api_caller: Callable): self.db_path = db_path @@ -46,32 +47,31 @@ class AgentManager: agent_id = f"{role_name}_{str(uuid.uuid4())[:8]}" role = get_agent_role(role_name) - agent = AgentInstance( - agent_id=agent_id, - role=role - ) + agent = AgentInstance(agent_id=agent_id, role=role) self.active_agents[agent_id] = agent return agent_id def get_agent(self, agent_id: str) -> Optional[AgentInstance]: return self.active_agents.get(agent_id) - + def remove_agent(self, agent_id: str) -> bool: if agent_id in self.active_agents: del self.active_agents[agent_id] return True return False - def execute_agent_task(self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def execute_agent_task( + self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: agent = self.get_agent(agent_id) if not agent: - return {'error': f'Agent {agent_id} not found'} + return {"error": f"Agent {agent_id} not found"} if context: agent.context.update(context) - agent.add_message('user', task) + agent.add_message("user", task) knowledge_matches = self.knowledge_store.search_entries(task, top_k=3) agent.task_count += 1 @@ -81,35 +81,40 @@ class AgentManager: for i, entry in enumerate(knowledge_matches, 1): shortened_content = entry.content[:2000] knowledge_content += f"{i}. {shortened_content}\\n\\n" - messages.insert(-1, {'role': 'user', 'content': knowledge_content}) + messages.insert(-1, {"role": "user", "content": knowledge_content}) try: response = self.api_caller( messages=messages, temperature=agent.role.temperature, - max_tokens=agent.role.max_tokens + max_tokens=agent.role.max_tokens, ) - if response and 'choices' in response: - assistant_message = response['choices'][0]['message']['content'] - agent.add_message('assistant', assistant_message) + if response and "choices" in response: + assistant_message = response["choices"][0]["message"]["content"] + agent.add_message("assistant", assistant_message) return { - 'success': True, - 'agent_id': agent_id, - 'response': assistant_message, - 'role': agent.role.name, - 'task_count': agent.task_count + "success": True, + "agent_id": agent_id, + "response": assistant_message, + "role": agent.role.name, + "task_count": agent.task_count, } else: - return {'error': 'Invalid API response', 'agent_id': agent_id} + return {"error": "Invalid API response", "agent_id": agent_id} except Exception as e: - return {'error': str(e), 'agent_id': agent_id} + return {"error": str(e), "agent_id": agent_id} - def send_agent_message(self, from_agent_id: str, to_agent_id: str, - content: str, message_type: MessageType = MessageType.REQUEST, - metadata: Optional[Dict[str, Any]] = None): + def send_agent_message( + self, + from_agent_id: str, + to_agent_id: str, + content: str, + message_type: MessageType = MessageType.REQUEST, + metadata: Optional[Dict[str, Any]] = None, + ): message = AgentMessage( from_agent=from_agent_id, to_agent=to_agent_id, @@ -117,57 +122,57 @@ class AgentManager: content=content, metadata=metadata or {}, timestamp=time.time(), - message_id=str(uuid.uuid4())[:16] + message_id=str(uuid.uuid4())[:16], ) self.communication_bus.send_message(message, self.session_id) return message.message_id - def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]: + def get_agent_messages( + self, agent_id: str, unread_only: bool = True + ) -> List[AgentMessage]: return self.communication_bus.get_messages(agent_id, unread_only) - def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]): + def collaborate_agents( + self, orchestrator_id: str, task: str, agent_roles: List[str] + ): orchestrator = self.get_agent(orchestrator_id) if not orchestrator: - orchestrator_id = self.create_agent('orchestrator') + orchestrator_id = self.create_agent("orchestrator") orchestrator = self.get_agent(orchestrator_id) worker_agents = [] for role in agent_roles: agent_id = self.create_agent(role) - worker_agents.append({ - 'agent_id': agent_id, - 'role': role - }) + worker_agents.append({"agent_id": agent_id, "role": role}) - orchestration_prompt = f'''Task: {task} + orchestration_prompt = f"""Task: {task} Available specialized agents: {chr(10).join([f"- {a['agent_id']} ({a['role']})" for a in worker_agents])} -Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results.''' +Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results.""" - orchestrator_result = self.execute_agent_task(orchestrator_id, orchestration_prompt) + orchestrator_result = self.execute_agent_task( + orchestrator_id, orchestration_prompt + ) - results = { - 'orchestrator': orchestrator_result, - 'agents': [] - } + results = {"orchestrator": orchestrator_result, "agents": []} for agent_info in worker_agents: - agent_id = agent_info['agent_id'] + agent_id = agent_info["agent_id"] messages = self.get_agent_messages(agent_id) for msg in messages: subtask = msg.content result = self.execute_agent_task(agent_id, subtask) - results['agents'].append(result) + results["agents"].append(result) self.send_agent_message( from_agent_id=agent_id, to_agent_id=orchestrator_id, - content=result.get('response', ''), - message_type=MessageType.RESPONSE + content=result.get("response", ""), + message_type=MessageType.RESPONSE, ) self.communication_bus.mark_as_read(msg.message_id) @@ -175,21 +180,21 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei def get_session_summary(self) -> str: summary = { - 'session_id': self.session_id, - 'active_agents': len(self.active_agents), - 'agents': [ + "session_id": self.session_id, + "active_agents": len(self.active_agents), + "agents": [ { - 'agent_id': agent_id, - 'role': agent.role.name, - 'task_count': agent.task_count, - 'message_count': len(agent.message_history) + "agent_id": agent_id, + "role": agent.role.name, + "task_count": agent.task_count, + "message_count": len(agent.message_history), } for agent_id, agent in self.active_agents.items() - ] + ], } return json.dumps(summary) def clear_session(self): self.active_agents.clear() self.communication_bus.clear_messages(session_id=self.session_id) - self.session_id = str(uuid.uuid4())[:16] \ No newline at end of file + self.session_id = str(uuid.uuid4())[:16] diff --git a/pr/agents/agent_roles.py b/pr/agents/agent_roles.py index 6d6c7bc..1bdddde 100644 --- a/pr/agents/agent_roles.py +++ b/pr/agents/agent_roles.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import List, Dict, Any, Set +from typing import Dict, List, Set + @dataclass class AgentRole: @@ -11,182 +12,262 @@ class AgentRole: temperature: float = 0.7 max_tokens: int = 4096 + AGENT_ROLES = { - 'coding': AgentRole( - name='coding', - description='Specialized in writing, reviewing, and debugging code', - system_prompt='''You are a coding specialist AI assistant. Your primary responsibilities: + "coding": AgentRole( + name="coding", + description="Specialized in writing, reviewing, and debugging code", + system_prompt="""You are a coding specialist AI assistant. Your primary responsibilities: - Write clean, efficient, well-structured code - Review code for bugs, security issues, and best practices - Refactor and optimize existing code - Implement features based on specifications - Follow language-specific conventions and patterns -Focus on code quality, maintainability, and performance.''', +Focus on code quality, maintainability, and performance.""", allowed_tools={ - 'read_file', 'write_file', 'list_directory', 'create_directory', - 'change_directory', 'get_current_directory', 'python_exec', - 'run_command', 'index_directory' + "read_file", + "write_file", + "list_directory", + "create_directory", + "change_directory", + "get_current_directory", + "python_exec", + "run_command", + "index_directory", }, - specialization_areas=['code_writing', 'code_review', 'debugging', 'refactoring'], - temperature=0.3 + specialization_areas=[ + "code_writing", + "code_review", + "debugging", + "refactoring", + ], + temperature=0.3, ), - - 'research': AgentRole( - name='research', - description='Specialized in information gathering and analysis', - system_prompt='''You are a research specialist AI assistant. Your primary responsibilities: + "research": AgentRole( + name="research", + description="Specialized in information gathering and analysis", + system_prompt="""You are a research specialist AI assistant. Your primary responsibilities: - Search for and gather relevant information - Analyze data and documentation - Synthesize findings into clear summaries - Verify facts and cross-reference sources - Identify trends and patterns in information -Focus on accuracy, thoroughness, and clear communication of findings.''', +Focus on accuracy, thoroughness, and clear communication of findings.""", allowed_tools={ - 'read_file', 'list_directory', 'index_directory', - 'http_fetch', 'web_search', 'web_search_news', - 'db_query', 'db_get' + "read_file", + "list_directory", + "index_directory", + "http_fetch", + "web_search", + "web_search_news", + "db_query", + "db_get", }, - specialization_areas=['information_gathering', 'analysis', 'documentation', 'fact_checking'], - temperature=0.5 + specialization_areas=[ + "information_gathering", + "analysis", + "documentation", + "fact_checking", + ], + temperature=0.5, ), - - 'data_analysis': AgentRole( - name='data_analysis', - description='Specialized in data processing and analysis', - system_prompt='''You are a data analysis specialist AI assistant. Your primary responsibilities: + "data_analysis": AgentRole( + name="data_analysis", + description="Specialized in data processing and analysis", + system_prompt="""You are a data analysis specialist AI assistant. Your primary responsibilities: - Process and analyze structured and unstructured data - Perform statistical analysis and pattern recognition - Query databases and extract insights - Create data summaries and reports - Identify anomalies and trends -Focus on accuracy, data integrity, and actionable insights.''', +Focus on accuracy, data integrity, and actionable insights.""", allowed_tools={ - 'db_query', 'db_get', 'db_set', 'read_file', 'write_file', - 'python_exec', 'run_command', 'list_directory' + "db_query", + "db_get", + "db_set", + "read_file", + "write_file", + "python_exec", + "run_command", + "list_directory", }, - specialization_areas=['data_processing', 'statistical_analysis', 'database_operations'], - temperature=0.3 + specialization_areas=[ + "data_processing", + "statistical_analysis", + "database_operations", + ], + temperature=0.3, ), - - 'planning': AgentRole( - name='planning', - description='Specialized in task planning and coordination', - system_prompt='''You are a planning specialist AI assistant. Your primary responsibilities: + "planning": AgentRole( + name="planning", + description="Specialized in task planning and coordination", + system_prompt="""You are a planning specialist AI assistant. Your primary responsibilities: - Break down complex tasks into manageable steps - Create execution plans and workflows - Identify dependencies and prerequisites - Estimate effort and resource requirements - Coordinate between different components -Focus on logical organization, completeness, and feasibility.''', +Focus on logical organization, completeness, and feasibility.""", allowed_tools={ - 'read_file', 'write_file', 'list_directory', 'index_directory', - 'db_set', 'db_get' + "read_file", + "write_file", + "list_directory", + "index_directory", + "db_set", + "db_get", }, - specialization_areas=['task_decomposition', 'workflow_design', 'coordination'], - temperature=0.6 + specialization_areas=["task_decomposition", "workflow_design", "coordination"], + temperature=0.6, ), - - 'testing': AgentRole( - name='testing', - description='Specialized in testing and quality assurance', - system_prompt='''You are a testing specialist AI assistant. Your primary responsibilities: + "testing": AgentRole( + name="testing", + description="Specialized in testing and quality assurance", + system_prompt="""You are a testing specialist AI assistant. Your primary responsibilities: - Design and execute test cases - Identify edge cases and potential failures - Verify functionality and correctness - Test error handling and edge conditions - Ensure code meets quality standards -Focus on thoroughness, coverage, and issue identification.''', +Focus on thoroughness, coverage, and issue identification.""", allowed_tools={ - 'read_file', 'write_file', 'python_exec', 'run_command', - 'list_directory', 'db_query' + "read_file", + "write_file", + "python_exec", + "run_command", + "list_directory", + "db_query", }, - specialization_areas=['test_design', 'quality_assurance', 'validation'], - temperature=0.4 + specialization_areas=["test_design", "quality_assurance", "validation"], + temperature=0.4, ), - - 'documentation': AgentRole( - name='documentation', - description='Specialized in creating and maintaining documentation', - system_prompt='''You are a documentation specialist AI assistant. Your primary responsibilities: + "documentation": AgentRole( + name="documentation", + description="Specialized in creating and maintaining documentation", + system_prompt="""You are a documentation specialist AI assistant. Your primary responsibilities: - Write clear, comprehensive documentation - Create API references and user guides - Document code with comments and docstrings - Organize and structure information logically - Ensure documentation is up-to-date and accurate -Focus on clarity, completeness, and user-friendliness.''', +Focus on clarity, completeness, and user-friendliness.""", allowed_tools={ - 'read_file', 'write_file', 'list_directory', 'index_directory', - 'http_fetch', 'web_search' + "read_file", + "write_file", + "list_directory", + "index_directory", + "http_fetch", + "web_search", }, - specialization_areas=['technical_writing', 'documentation_organization', 'user_guides'], - temperature=0.6 + specialization_areas=[ + "technical_writing", + "documentation_organization", + "user_guides", + ], + temperature=0.6, ), - - 'orchestrator': AgentRole( - name='orchestrator', - description='Coordinates multiple agents and manages overall execution', - system_prompt='''You are an orchestrator AI assistant. Your primary responsibilities: + "orchestrator": AgentRole( + name="orchestrator", + description="Coordinates multiple agents and manages overall execution", + system_prompt="""You are an orchestrator AI assistant. Your primary responsibilities: - Coordinate multiple specialized agents - Delegate tasks to appropriate agents - Integrate results from different agents - Manage overall workflow execution - Ensure task completion and quality -Focus on effective delegation, integration, and overall success.''', +Focus on effective delegation, integration, and overall success.""", allowed_tools={ - 'read_file', 'write_file', 'list_directory', 'db_set', 'db_get', 'db_query' + "read_file", + "write_file", + "list_directory", + "db_set", + "db_get", + "db_query", }, - specialization_areas=['agent_coordination', 'task_delegation', 'result_integration'], - temperature=0.5 + specialization_areas=[ + "agent_coordination", + "task_delegation", + "result_integration", + ], + temperature=0.5, ), - - 'general': AgentRole( - name='general', - description='General purpose agent for miscellaneous tasks', - system_prompt='''You are a general purpose AI assistant. Your responsibilities: + "general": AgentRole( + name="general", + description="General purpose agent for miscellaneous tasks", + system_prompt="""You are a general purpose AI assistant. Your responsibilities: - Handle diverse tasks across multiple domains - Provide balanced assistance for various needs - Adapt to different types of requests - Collaborate with specialized agents when needed -Focus on versatility, helpfulness, and task completion.''', +Focus on versatility, helpfulness, and task completion.""", allowed_tools={ - 'read_file', 'write_file', 'list_directory', 'create_directory', - 'change_directory', 'get_current_directory', 'python_exec', - 'run_command', 'run_command_interactive', 'http_fetch', - 'web_search', 'web_search_news', 'db_set', 'db_get', 'db_query', - 'index_directory' + "read_file", + "write_file", + "list_directory", + "create_directory", + "change_directory", + "get_current_directory", + "python_exec", + "run_command", + "run_command_interactive", + "http_fetch", + "web_search", + "web_search_news", + "db_set", + "db_get", + "db_query", + "index_directory", }, - specialization_areas=['general_assistance'], - temperature=0.7 - ) + specialization_areas=["general_assistance"], + temperature=0.7, + ), } + def get_agent_role(role_name: str) -> AgentRole: - return AGENT_ROLES.get(role_name, AGENT_ROLES['general']) + return AGENT_ROLES.get(role_name, AGENT_ROLES["general"]) + def list_agent_roles() -> Dict[str, AgentRole]: return AGENT_ROLES.copy() + def get_recommended_agent(task_description: str) -> str: task_lower = task_description.lower() - code_keywords = ['code', 'implement', 'function', 'class', 'bug', 'debug', 'refactor', 'optimize'] - research_keywords = ['search', 'find', 'research', 'information', 'analyze', 'investigate'] - data_keywords = ['data', 'database', 'query', 'statistics', 'analyze', 'process'] - planning_keywords = ['plan', 'organize', 'workflow', 'steps', 'coordinate'] - testing_keywords = ['test', 'verify', 'validate', 'check', 'quality'] - doc_keywords = ['document', 'documentation', 'explain', 'guide', 'manual'] + code_keywords = [ + "code", + "implement", + "function", + "class", + "bug", + "debug", + "refactor", + "optimize", + ] + research_keywords = [ + "search", + "find", + "research", + "information", + "analyze", + "investigate", + ] + data_keywords = ["data", "database", "query", "statistics", "analyze", "process"] + planning_keywords = ["plan", "organize", "workflow", "steps", "coordinate"] + testing_keywords = ["test", "verify", "validate", "check", "quality"] + doc_keywords = ["document", "documentation", "explain", "guide", "manual"] if any(keyword in task_lower for keyword in code_keywords): - return 'coding' + return "coding" elif any(keyword in task_lower for keyword in research_keywords): - return 'research' + return "research" elif any(keyword in task_lower for keyword in data_keywords): - return 'data_analysis' + return "data_analysis" elif any(keyword in task_lower for keyword in planning_keywords): - return 'planning' + return "planning" elif any(keyword in task_lower for keyword in testing_keywords): - return 'testing' + return "testing" elif any(keyword in task_lower for keyword in doc_keywords): - return 'documentation' + return "documentation" else: - return 'general' + return "general" diff --git a/pr/autonomous/__init__.py b/pr/autonomous/__init__.py index 5841085..4638c44 100644 --- a/pr/autonomous/__init__.py +++ b/pr/autonomous/__init__.py @@ -1,4 +1,4 @@ from pr.autonomous.detection import is_task_complete -from pr.autonomous.mode import run_autonomous_mode, process_response_autonomous +from pr.autonomous.mode import process_response_autonomous, run_autonomous_mode -__all__ = ['is_task_complete', 'run_autonomous_mode', 'process_response_autonomous'] +__all__ = ["is_task_complete", "run_autonomous_mode", "process_response_autonomous"] diff --git a/pr/autonomous/detection.py b/pr/autonomous/detection.py index 0bcf165..6269fdf 100644 --- a/pr/autonomous/detection.py +++ b/pr/autonomous/detection.py @@ -1,28 +1,39 @@ from pr.config import MAX_AUTONOMOUS_ITERATIONS from pr.ui import Colors + def is_task_complete(response, iteration): - if 'error' in response: + if "error" in response: return True - if 'choices' not in response or not response['choices']: + if "choices" not in response or not response["choices"]: return True - message = response['choices'][0]['message'] - content = message.get('content', '').lower() + message = response["choices"][0]["message"] + content = message.get("content", "").lower() completion_keywords = [ - 'task complete', 'task is complete', 'finished', 'done', - 'successfully completed', 'task accomplished', 'all done', - 'implementation complete', 'setup complete', 'installation complete' + "task complete", + "task is complete", + "finished", + "done", + "successfully completed", + "task accomplished", + "all done", + "implementation complete", + "setup complete", + "installation complete", ] error_keywords = [ - 'cannot proceed', 'unable to continue', 'fatal error', - 'cannot complete', 'impossible to' + "cannot proceed", + "unable to continue", + "fatal error", + "cannot complete", + "impossible to", ] - has_tool_calls = 'tool_calls' in message and message['tool_calls'] + has_tool_calls = "tool_calls" in message and message["tool_calls"] mentions_completion = any(keyword in content for keyword in completion_keywords) mentions_error = any(keyword in content for keyword in error_keywords) diff --git a/pr/autonomous/mode.py b/pr/autonomous/mode.py index 5019d75..6200f9b 100644 --- a/pr/autonomous/mode.py +++ b/pr/autonomous/mode.py @@ -1,11 +1,13 @@ -import time import json import logging -from pr.ui import Colors, display_tool_call, print_autonomous_header +import time + from pr.autonomous.detection import is_task_complete from pr.core.context import truncate_tool_result +from pr.ui import Colors, display_tool_call + +logger = logging.getLogger("pr") -logger = logging.getLogger('pr') def run_autonomous_mode(assistant, task): assistant.autonomous_mode = True @@ -14,25 +16,32 @@ def run_autonomous_mode(assistant, task): logger.debug(f"=== AUTONOMOUS MODE START ===") logger.debug(f"Task: {task}") - assistant.messages.append({ - "role": "user", - "content": f"{task}" - }) + assistant.messages.append({"role": "user", "content": f"{task}"}) try: while True: assistant.autonomous_iterations += 1 - logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---") - logger.debug(f"Messages before context management: {len(assistant.messages)}") + logger.debug( + f"--- Autonomous iteration {assistant.autonomous_iterations} ---" + ) + logger.debug( + f"Messages before context management: {len(assistant.messages)}" + ) from pr.core.context import manage_context_window - assistant.messages = manage_context_window(assistant.messages, assistant.verbose) - logger.debug(f"Messages after context management: {len(assistant.messages)}") + assistant.messages = manage_context_window( + assistant.messages, assistant.verbose + ) + + logger.debug( + f"Messages after context management: {len(assistant.messages)}" + ) from pr.core.api import call_api from pr.tools.base import get_tools_definition + response = call_api( assistant.messages, assistant.model, @@ -40,10 +49,10 @@ def run_autonomous_mode(assistant, task): assistant.api_key, assistant.use_tools, get_tools_definition(), - verbose=assistant.verbose + verbose=assistant.verbose, ) - if 'error' in response: + if "error" in response: logger.error(f"API error in autonomous mode: {response['error']}") print(f"{Colors.RED}Error: {response['error']}{Colors.RESET}") break @@ -74,22 +83,23 @@ def run_autonomous_mode(assistant, task): assistant.autonomous_mode = False logger.debug("=== AUTONOMOUS MODE END ===") + def process_response_autonomous(assistant, response): - if 'error' in response: + if "error" in response: return f"Error: {response['error']}" - if 'choices' not in response or not response['choices']: + if "choices" not in response or not response["choices"]: return "No response from API" - message = response['choices'][0]['message'] + message = response["choices"][0]["message"] assistant.messages.append(message) - if 'tool_calls' in message and message['tool_calls']: + if "tool_calls" in message and message["tool_calls"]: tool_results = [] - for tool_call in message['tool_calls']: - func_name = tool_call['function']['name'] - arguments = json.loads(tool_call['function']['arguments']) + for tool_call in message["tool_calls"]: + func_name = tool_call["function"]["name"] + arguments = json.loads(tool_call["function"]["arguments"]) result = execute_single_tool(assistant, func_name, arguments) result = truncate_tool_result(result) @@ -97,16 +107,19 @@ def process_response_autonomous(assistant, response): status = "success" if result.get("status") == "success" else "error" display_tool_call(func_name, arguments, status, result) - tool_results.append({ - "tool_call_id": tool_call['id'], - "role": "tool", - "content": json.dumps(result) - }) + tool_results.append( + { + "tool_call_id": tool_call["id"], + "role": "tool", + "content": json.dumps(result), + } + ) for result in tool_results: assistant.messages.append(result) from pr.core.api import call_api from pr.tools.base import get_tools_definition + follow_up = call_api( assistant.messages, assistant.model, @@ -114,59 +127,88 @@ def process_response_autonomous(assistant, response): assistant.api_key, assistant.use_tools, get_tools_definition(), - verbose=assistant.verbose + verbose=assistant.verbose, ) return process_response_autonomous(assistant, follow_up) - content = message.get('content', '') + content = message.get("content", "") from pr.ui import render_markdown + return render_markdown(content, assistant.syntax_highlighting) + def execute_single_tool(assistant, func_name, arguments): logger.debug(f"Executing tool in autonomous mode: {func_name}") logger.debug(f"Tool arguments: {arguments}") from pr.tools import ( - http_fetch, run_command, run_command_interactive, read_file, write_file, - list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query, - web_search, web_search_news, python_exec, index_source_directory, - search_replace, open_editor, editor_insert_text, editor_replace_text, - editor_search, close_editor, create_diff, apply_patch, tail_process, kill_process + apply_patch, + chdir, + close_editor, + create_diff, + db_get, + db_query, + db_set, + editor_insert_text, + editor_replace_text, + editor_search, + getpwd, + http_fetch, + index_source_directory, + kill_process, + list_directory, + mkdir, + open_editor, + python_exec, + read_file, + run_command, + run_command_interactive, + search_replace, + tail_process, + web_search, + web_search_news, + write_file, + ) + from pr.tools.filesystem import ( + clear_edit_tracker, + display_edit_summary, + display_edit_timeline, ) from pr.tools.patch import display_file_diff - from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker func_map = { - 'http_fetch': lambda **kw: http_fetch(**kw), - 'run_command': lambda **kw: run_command(**kw), - 'tail_process': lambda **kw: tail_process(**kw), - 'kill_process': lambda **kw: kill_process(**kw), - 'run_command_interactive': lambda **kw: run_command_interactive(**kw), - 'read_file': lambda **kw: read_file(**kw), - 'write_file': lambda **kw: write_file(**kw, db_conn=assistant.db_conn), - 'list_directory': lambda **kw: list_directory(**kw), - 'mkdir': lambda **kw: mkdir(**kw), - 'chdir': lambda **kw: chdir(**kw), - 'getpwd': lambda **kw: getpwd(**kw), - 'db_set': lambda **kw: db_set(**kw, db_conn=assistant.db_conn), - 'db_get': lambda **kw: db_get(**kw, db_conn=assistant.db_conn), - 'db_query': lambda **kw: db_query(**kw, db_conn=assistant.db_conn), - 'web_search': lambda **kw: web_search(**kw), - 'web_search_news': lambda **kw: web_search_news(**kw), - 'python_exec': lambda **kw: python_exec(**kw, python_globals=assistant.python_globals), - 'index_source_directory': lambda **kw: index_source_directory(**kw), - 'search_replace': lambda **kw: search_replace(**kw), - 'open_editor': lambda **kw: open_editor(**kw), - 'editor_insert_text': lambda **kw: editor_insert_text(**kw), - 'editor_replace_text': lambda **kw: editor_replace_text(**kw), - 'editor_search': lambda **kw: editor_search(**kw), - 'close_editor': lambda **kw: close_editor(**kw), - 'create_diff': lambda **kw: create_diff(**kw), - 'apply_patch': lambda **kw: apply_patch(**kw), - 'display_file_diff': lambda **kw: display_file_diff(**kw), - 'display_edit_summary': lambda **kw: display_edit_summary(), - 'display_edit_timeline': lambda **kw: display_edit_timeline(**kw), - 'clear_edit_tracker': lambda **kw: clear_edit_tracker(), + "http_fetch": lambda **kw: http_fetch(**kw), + "run_command": lambda **kw: run_command(**kw), + "tail_process": lambda **kw: tail_process(**kw), + "kill_process": lambda **kw: kill_process(**kw), + "run_command_interactive": lambda **kw: run_command_interactive(**kw), + "read_file": lambda **kw: read_file(**kw), + "write_file": lambda **kw: write_file(**kw, db_conn=assistant.db_conn), + "list_directory": lambda **kw: list_directory(**kw), + "mkdir": lambda **kw: mkdir(**kw), + "chdir": lambda **kw: chdir(**kw), + "getpwd": lambda **kw: getpwd(**kw), + "db_set": lambda **kw: db_set(**kw, db_conn=assistant.db_conn), + "db_get": lambda **kw: db_get(**kw, db_conn=assistant.db_conn), + "db_query": lambda **kw: db_query(**kw, db_conn=assistant.db_conn), + "web_search": lambda **kw: web_search(**kw), + "web_search_news": lambda **kw: web_search_news(**kw), + "python_exec": lambda **kw: python_exec( + **kw, python_globals=assistant.python_globals + ), + "index_source_directory": lambda **kw: index_source_directory(**kw), + "search_replace": lambda **kw: search_replace(**kw), + "open_editor": lambda **kw: open_editor(**kw), + "editor_insert_text": lambda **kw: editor_insert_text(**kw), + "editor_replace_text": lambda **kw: editor_replace_text(**kw), + "editor_search": lambda **kw: editor_search(**kw), + "close_editor": lambda **kw: close_editor(**kw), + "create_diff": lambda **kw: create_diff(**kw), + "apply_patch": lambda **kw: apply_patch(**kw), + "display_file_diff": lambda **kw: display_file_diff(**kw), + "display_edit_summary": lambda **kw: display_edit_summary(), + "display_edit_timeline": lambda **kw: display_edit_timeline(**kw), + "clear_edit_tracker": lambda **kw: clear_edit_tracker(), } if func_name in func_map: diff --git a/pr/cache/__init__.py b/pr/cache/__init__.py index 237d9b6..cd3c486 100644 --- a/pr/cache/__init__.py +++ b/pr/cache/__init__.py @@ -1,4 +1,4 @@ from .api_cache import APICache from .tool_cache import ToolCache -__all__ = ['APICache', 'ToolCache'] +__all__ = ["APICache", "ToolCache"] diff --git a/pr/cache/api_cache.py b/pr/cache/api_cache.py index 52a66cd..f224f28 100644 --- a/pr/cache/api_cache.py +++ b/pr/cache/api_cache.py @@ -2,7 +2,8 @@ import hashlib import json import sqlite3 import time -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional + class APICache: def __init__(self, db_path: str, ttl_seconds: int = 3600): @@ -13,7 +14,8 @@ class APICache: def _initialize_cache(self): conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS api_cache ( cache_key TEXT PRIMARY KEY, response_data TEXT NOT NULL, @@ -22,34 +24,44 @@ class APICache: model TEXT, token_count INTEGER ) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at) - ''') + """ + ) conn.commit() conn.close() - def _generate_cache_key(self, model: str, messages: list, temperature: float, max_tokens: int) -> str: + def _generate_cache_key( + self, model: str, messages: list, temperature: float, max_tokens: int + ) -> str: cache_data = { - 'model': model, - 'messages': messages, - 'temperature': temperature, - 'max_tokens': max_tokens + "model": model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, } serialized = json.dumps(cache_data, sort_keys=True) return hashlib.sha256(serialized.encode()).hexdigest() - def get(self, model: str, messages: list, temperature: float, max_tokens: int) -> Optional[Dict[str, Any]]: + def get( + self, model: str, messages: list, temperature: float, max_tokens: int + ) -> Optional[Dict[str, Any]]: cache_key = self._generate_cache_key(model, messages, temperature, max_tokens) conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() current_time = int(time.time()) - cursor.execute(''' + cursor.execute( + """ SELECT response_data FROM api_cache WHERE cache_key = ? AND expires_at > ? - ''', (cache_key, current_time)) + """, + (cache_key, current_time), + ) row = cursor.fetchone() conn.close() @@ -58,8 +70,15 @@ class APICache: return json.loads(row[0]) return None - def set(self, model: str, messages: list, temperature: float, max_tokens: int, - response: Dict[str, Any], token_count: int = 0): + def set( + self, + model: str, + messages: list, + temperature: float, + max_tokens: int, + response: Dict[str, Any], + token_count: int = 0, + ): cache_key = self._generate_cache_key(model, messages, temperature, max_tokens) current_time = int(time.time()) @@ -68,11 +87,21 @@ class APICache: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ INSERT OR REPLACE INTO api_cache (cache_key, response_data, created_at, expires_at, model, token_count) VALUES (?, ?, ?, ?, ?, ?) - ''', (cache_key, json.dumps(response), current_time, expires_at, model, token_count)) + """, + ( + cache_key, + json.dumps(response), + current_time, + expires_at, + model, + token_count, + ), + ) conn.commit() conn.close() @@ -83,7 +112,7 @@ class APICache: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('DELETE FROM api_cache WHERE expires_at <= ?', (current_time,)) + cursor.execute("DELETE FROM api_cache WHERE expires_at <= ?", (current_time,)) deleted_count = cursor.rowcount conn.commit() @@ -95,7 +124,7 @@ class APICache: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('DELETE FROM api_cache') + cursor.execute("DELETE FROM api_cache") deleted_count = cursor.rowcount conn.commit() @@ -107,21 +136,26 @@ class APICache: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('SELECT COUNT(*) FROM api_cache') + cursor.execute("SELECT COUNT(*) FROM api_cache") total_entries = cursor.fetchone()[0] current_time = int(time.time()) - cursor.execute('SELECT COUNT(*) FROM api_cache WHERE expires_at > ?', (current_time,)) + cursor.execute( + "SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,) + ) valid_entries = cursor.fetchone()[0] - cursor.execute('SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?', (current_time,)) + cursor.execute( + "SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?", + (current_time,), + ) total_tokens = cursor.fetchone()[0] or 0 conn.close() return { - 'total_entries': total_entries, - 'valid_entries': valid_entries, - 'expired_entries': total_entries - valid_entries, - 'total_cached_tokens': total_tokens + "total_entries": total_entries, + "valid_entries": valid_entries, + "expired_entries": total_entries - valid_entries, + "total_cached_tokens": total_tokens, } diff --git a/pr/cache/tool_cache.py b/pr/cache/tool_cache.py index d4f6fa3..da90225 100644 --- a/pr/cache/tool_cache.py +++ b/pr/cache/tool_cache.py @@ -2,16 +2,17 @@ import hashlib import json import sqlite3 import time -from typing import Optional, Any, Set +from typing import Any, Optional, Set + class ToolCache: DETERMINISTIC_TOOLS: Set[str] = { - 'read_file', - 'list_directory', - 'get_current_directory', - 'db_get', - 'db_query', - 'index_directory' + "read_file", + "list_directory", + "get_current_directory", + "db_get", + "db_query", + "index_directory", } def __init__(self, db_path: str, ttl_seconds: int = 300): @@ -22,7 +23,8 @@ class ToolCache: def _initialize_cache(self): conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS tool_cache ( cache_key TEXT PRIMARY KEY, tool_name TEXT NOT NULL, @@ -31,21 +33,23 @@ class ToolCache: expires_at INTEGER NOT NULL, hit_count INTEGER DEFAULT 0 ) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_tool_expires ON tool_cache(expires_at) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_cache(tool_name) - ''') + """ + ) conn.commit() conn.close() def _generate_cache_key(self, tool_name: str, arguments: dict) -> str: - cache_data = { - 'tool': tool_name, - 'args': arguments - } + cache_data = {"tool": tool_name, "args": arguments} serialized = json.dumps(cache_data, sort_keys=True) return hashlib.sha256(serialized.encode()).hexdigest() @@ -62,18 +66,24 @@ class ToolCache: cursor = conn.cursor() current_time = int(time.time()) - cursor.execute(''' + cursor.execute( + """ SELECT result_data, hit_count FROM tool_cache WHERE cache_key = ? AND expires_at > ? - ''', (cache_key, current_time)) + """, + (cache_key, current_time), + ) row = cursor.fetchone() if row: - cursor.execute(''' + cursor.execute( + """ UPDATE tool_cache SET hit_count = hit_count + 1 WHERE cache_key = ? - ''', (cache_key,)) + """, + (cache_key,), + ) conn.commit() conn.close() return json.loads(row[0]) @@ -93,11 +103,14 @@ class ToolCache: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ INSERT OR REPLACE INTO tool_cache (cache_key, tool_name, result_data, created_at, expires_at, hit_count) VALUES (?, ?, ?, ?, ?, 0) - ''', (cache_key, tool_name, json.dumps(result), current_time, expires_at)) + """, + (cache_key, tool_name, json.dumps(result), current_time, expires_at), + ) conn.commit() conn.close() @@ -106,7 +119,7 @@ class ToolCache: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('DELETE FROM tool_cache WHERE tool_name = ?', (tool_name,)) + cursor.execute("DELETE FROM tool_cache WHERE tool_name = ?", (tool_name,)) deleted_count = cursor.rowcount conn.commit() @@ -120,7 +133,7 @@ class ToolCache: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('DELETE FROM tool_cache WHERE expires_at <= ?', (current_time,)) + cursor.execute("DELETE FROM tool_cache WHERE expires_at <= ?", (current_time,)) deleted_count = cursor.rowcount conn.commit() @@ -132,7 +145,7 @@ class ToolCache: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('DELETE FROM tool_cache') + cursor.execute("DELETE FROM tool_cache") deleted_count = cursor.rowcount conn.commit() @@ -144,36 +157,41 @@ class ToolCache: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('SELECT COUNT(*) FROM tool_cache') + cursor.execute("SELECT COUNT(*) FROM tool_cache") total_entries = cursor.fetchone()[0] current_time = int(time.time()) - cursor.execute('SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?', (current_time,)) + cursor.execute( + "SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,) + ) valid_entries = cursor.fetchone()[0] - cursor.execute('SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?', (current_time,)) + cursor.execute( + "SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?", + (current_time,), + ) total_hits = cursor.fetchone()[0] or 0 - cursor.execute(''' + cursor.execute( + """ SELECT tool_name, COUNT(*), SUM(hit_count) FROM tool_cache WHERE expires_at > ? GROUP BY tool_name - ''', (current_time,)) + """, + (current_time,), + ) tool_stats = {} for row in cursor.fetchall(): - tool_stats[row[0]] = { - 'cached_entries': row[1], - 'total_hits': row[2] or 0 - } + tool_stats[row[0]] = {"cached_entries": row[1], "total_hits": row[2] or 0} conn.close() return { - 'total_entries': total_entries, - 'valid_entries': valid_entries, - 'expired_entries': total_entries - valid_entries, - 'total_cache_hits': total_hits, - 'by_tool': tool_stats + "total_entries": total_entries, + "valid_entries": valid_entries, + "expired_entries": total_entries - valid_entries, + "total_cache_hits": total_hits, + "by_tool": tool_stats, } diff --git a/pr/commands/__init__.py b/pr/commands/__init__.py index a79f563..a46d212 100644 --- a/pr/commands/__init__.py +++ b/pr/commands/__init__.py @@ -1,3 +1,3 @@ from pr.commands.handlers import handle_command -__all__ = ['handle_command'] +__all__ = ["handle_command"] diff --git a/pr/commands/handlers.py b/pr/commands/handlers.py index cc11816..c532e23 100644 --- a/pr/commands/handlers.py +++ b/pr/commands/handlers.py @@ -1,30 +1,35 @@ import json import time -from pr.ui import Colors + +from pr.autonomous import run_autonomous_mode +from pr.core.api import list_models from pr.tools import read_file from pr.tools.base import get_tools_definition -from pr.core.api import list_models -from pr.autonomous import run_autonomous_mode +from pr.ui import Colors + def handle_command(assistant, command): command_parts = command.strip().split(maxsplit=1) cmd = command_parts[0].lower() - if cmd == '/auto': + if cmd == "/auto": if len(command_parts) < 2: print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}") - print(f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}") + print( + f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}" + ) return True task = command_parts[1] run_autonomous_mode(assistant, task) return True - if cmd in ['exit', 'quit', 'q']: + if cmd in ["exit", "quit", "q"]: return False - elif cmd == 'help': - print(f""" + elif cmd == "help": + print( + f""" {Colors.BOLD}Available Commands:{Colors.RESET} {Colors.BOLD}Basic:{Colors.RESET} @@ -54,18 +59,21 @@ def handle_command(assistant, command): {Colors.CYAN}/cache{Colors.RESET} - Show cache statistics {Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches {Colors.CYAN}/stats{Colors.RESET} - Show system statistics - """) + """ + ) - elif cmd == '/reset': + elif cmd == "/reset": assistant.messages = assistant.messages[:1] print(f"{Colors.GREEN}Message history cleared{Colors.RESET}") - elif cmd == '/dump': + elif cmd == "/dump": print(json.dumps(assistant.messages, indent=2)) - elif cmd == '/verbose': + elif cmd == "/verbose": assistant.verbose = not assistant.verbose - print(f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}") + print( + f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}" + ) elif cmd.startswith("/model"): if len(command_parts) < 2: @@ -74,77 +82,81 @@ def handle_command(assistant, command): assistant.model = command_parts[1] print(f"Model set to: {Colors.GREEN}{assistant.model}{Colors.RESET}") - elif cmd == '/models': + elif cmd == "/models": models = list_models(assistant.model_list_url, assistant.api_key) - if isinstance(models, dict) and 'error' in models: + if isinstance(models, dict) and "error" in models: print(f"{Colors.RED}Error fetching models: {models['error']}{Colors.RESET}") else: print(f"{Colors.BOLD}Available Models:{Colors.RESET}") for model in models: print(f" • {Colors.CYAN}{model['id']}{Colors.RESET}") - elif cmd == '/tools': + elif cmd == "/tools": print(f"{Colors.BOLD}Available Tools:{Colors.RESET}") for tool in get_tools_definition(): - func = tool['function'] - print(f" • {Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}") + func = tool["function"] + print( + f" • {Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}" + ) - elif cmd == '/review' and len(command_parts) > 1: + elif cmd == "/review" and len(command_parts) > 1: filename = command_parts[1] review_file(assistant, filename) - elif cmd == '/refactor' and len(command_parts) > 1: + elif cmd == "/refactor" and len(command_parts) > 1: filename = command_parts[1] refactor_file(assistant, filename) - elif cmd == '/obfuscate' and len(command_parts) > 1: + elif cmd == "/obfuscate" and len(command_parts) > 1: filename = command_parts[1] obfuscate_file(assistant, filename) - elif cmd == '/workflows': + elif cmd == "/workflows": show_workflows(assistant) - elif cmd == '/workflow' and len(command_parts) > 1: + elif cmd == "/workflow" and len(command_parts) > 1: workflow_name = command_parts[1] execute_workflow_command(assistant, workflow_name) - elif cmd == '/agent' and len(command_parts) > 1: + elif cmd == "/agent" and len(command_parts) > 1: args = command_parts[1].split(maxsplit=1) if len(args) < 2: print(f"{Colors.RED}Usage: /agent {Colors.RESET}") - print(f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}") + print( + f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}" + ) else: role, task = args[0], args[1] execute_agent_task(assistant, role, task) - elif cmd == '/agents': + elif cmd == "/agents": show_agents(assistant) - elif cmd == '/collaborate' and len(command_parts) > 1: + elif cmd == "/collaborate" and len(command_parts) > 1: task = command_parts[1] collaborate_agents_command(assistant, task) - elif cmd == '/knowledge' and len(command_parts) > 1: + elif cmd == "/knowledge" and len(command_parts) > 1: query = command_parts[1] search_knowledge(assistant, query) - elif cmd == '/remember' and len(command_parts) > 1: + elif cmd == "/remember" and len(command_parts) > 1: content = command_parts[1] store_knowledge(assistant, content) - elif cmd == '/history': + elif cmd == "/history": show_conversation_history(assistant) - elif cmd == '/cache': - if len(command_parts) > 1 and command_parts[1].lower() == 'clear': + elif cmd == "/cache": + if len(command_parts) > 1 and command_parts[1].lower() == "clear": clear_caches(assistant) else: show_cache_stats(assistant) - elif cmd == '/stats': + elif cmd == "/stats": show_system_stats(assistant) - elif cmd.startswith('/bg'): + elif cmd.startswith("/bg"): handle_background_command(assistant, command) else: @@ -152,35 +164,46 @@ def handle_command(assistant, command): return True + def review_file(assistant, filename): result = read_file(filename) - if result['status'] == 'success': - message = f"Please review this file and provide feedback:\n\n{result['content']}" + if result["status"] == "success": + message = ( + f"Please review this file and provide feedback:\n\n{result['content']}" + ) from pr.core.assistant import process_message + process_message(assistant, message) else: print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}") + def refactor_file(assistant, filename): result = read_file(filename) - if result['status'] == 'success': - message = f"Please refactor this code to improve its quality:\n\n{result['content']}" + if result["status"] == "success": + message = ( + f"Please refactor this code to improve its quality:\n\n{result['content']}" + ) from pr.core.assistant import process_message + process_message(assistant, message) else: print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}") + def obfuscate_file(assistant, filename): result = read_file(filename) - if result['status'] == 'success': + if result["status"] == "success": message = f"Please obfuscate this code:\n\n{result['content']}" from pr.core.assistant import process_message + process_message(assistant, message) else: print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}") + def show_workflows(assistant): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return @@ -194,23 +217,25 @@ def show_workflows(assistant): print(f" • {Colors.CYAN}{wf['name']}{Colors.RESET}: {wf['description']}") print(f" Executions: {wf['execution_count']}") + def execute_workflow_command(assistant, workflow_name): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return print(f"{Colors.YELLOW}Executing workflow: {workflow_name}...{Colors.RESET}") result = assistant.enhanced.execute_workflow(workflow_name) - if 'error' in result: + if "error" in result: print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}") else: print(f"{Colors.GREEN}Workflow completed successfully{Colors.RESET}") print(f"Execution ID: {result['execution_id']}") print(f"Results: {json.dumps(result['results'], indent=2)}") + def execute_agent_task(assistant, role, task): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return @@ -221,14 +246,15 @@ def execute_agent_task(assistant, role, task): print(f"{Colors.YELLOW}Executing task...{Colors.RESET}") result = assistant.enhanced.agent_task(agent_id, task) - if 'error' in result: + if "error" in result: print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}") else: print(f"\n{Colors.GREEN}{role.capitalize()} Agent Response:{Colors.RESET}") - print(result['response']) + print(result["response"]) + def show_agents(assistant): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return @@ -236,37 +262,39 @@ def show_agents(assistant): print(f"\n{Colors.BOLD}Agent Session Summary:{Colors.RESET}") print(f"Active agents: {summary['active_agents']}") - if summary['agents']: - for agent in summary['agents']: + if summary["agents"]: + for agent in summary["agents"]: print(f"\n • {Colors.CYAN}{agent['agent_id']}{Colors.RESET}") print(f" Role: {agent['role']}") print(f" Tasks completed: {agent['task_count']}") print(f" Messages: {agent['message_count']}") + def collaborate_agents_command(assistant, task): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return print(f"{Colors.YELLOW}Initiating agent collaboration...{Colors.RESET}") - roles = ['coding', 'research', 'planning'] + roles = ["coding", "research", "planning"] result = assistant.enhanced.collaborate_agents(task, roles) print(f"\n{Colors.GREEN}Collaboration completed{Colors.RESET}") print(f"\nOrchestrator response:") - if 'orchestrator' in result and 'response' in result['orchestrator']: - print(result['orchestrator']['response']) + if "orchestrator" in result and "response" in result["orchestrator"]: + print(result["orchestrator"]["response"]) - if result.get('agents'): + if result.get("agents"): print(f"\n{Colors.BOLD}Agent Results:{Colors.RESET}") - for agent_result in result['agents']: - if 'role' in agent_result: + for agent_result in result["agents"]: + if "role" in agent_result: print(f"\n{Colors.CYAN}{agent_result['role']}:{Colors.RESET}") - print(agent_result.get('response', 'No response')) + print(agent_result.get("response", "No response")) + def search_knowledge(assistant, query): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return @@ -282,13 +310,15 @@ def search_knowledge(assistant, query): print(f" {entry.content[:200]}...") print(f" Accessed: {entry.access_count} times") + def store_knowledge(assistant, content): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return - import uuid import time + import uuid + from pr.memory import KnowledgeEntry categories = assistant.enhanced.fact_extractor.categorize_content(content) @@ -296,11 +326,11 @@ def store_knowledge(assistant, content): entry = KnowledgeEntry( entry_id=entry_id, - category=categories[0] if categories else 'general', + category=categories[0] if categories else "general", content=content, - metadata={'manual_entry': True}, + metadata={"manual_entry": True}, created_at=time.time(), - updated_at=time.time() + updated_at=time.time(), ) assistant.enhanced.knowledge_store.add_entry(entry) @@ -308,8 +338,9 @@ def store_knowledge(assistant, content): print(f"Entry ID: {entry_id}") print(f"Category: {entry.category}") + def show_conversation_history(assistant): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return @@ -322,17 +353,21 @@ def show_conversation_history(assistant): print(f"\n{Colors.BOLD}Recent Conversations:{Colors.RESET}") for conv in history: import datetime - started = datetime.datetime.fromtimestamp(conv['started_at']).strftime('%Y-%m-%d %H:%M') + + started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime( + "%Y-%m-%d %H:%M" + ) print(f"\n • {Colors.CYAN}{conv['conversation_id']}{Colors.RESET}") print(f" Started: {started}") print(f" Messages: {conv['message_count']}") - if conv.get('summary'): + if conv.get("summary"): print(f" Summary: {conv['summary'][:100]}...") - if conv.get('topics'): + if conv.get("topics"): print(f" Topics: {', '.join(conv['topics'])}") + def show_cache_stats(assistant): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return @@ -340,36 +375,40 @@ def show_cache_stats(assistant): print(f"\n{Colors.BOLD}Cache Statistics:{Colors.RESET}") - if 'api_cache' in stats: - api_stats = stats['api_cache'] + if "api_cache" in stats: + api_stats = stats["api_cache"] print(f"\n{Colors.CYAN}API Cache:{Colors.RESET}") print(f" Total entries: {api_stats['total_entries']}") print(f" Valid entries: {api_stats['valid_entries']}") print(f" Expired entries: {api_stats['expired_entries']}") print(f" Cached tokens: {api_stats['total_cached_tokens']}") - if 'tool_cache' in stats: - tool_stats = stats['tool_cache'] + if "tool_cache" in stats: + tool_stats = stats["tool_cache"] print(f"\n{Colors.CYAN}Tool Cache:{Colors.RESET}") print(f" Total entries: {tool_stats['total_entries']}") print(f" Valid entries: {tool_stats['valid_entries']}") print(f" Total cache hits: {tool_stats['total_cache_hits']}") - if tool_stats.get('by_tool'): + if tool_stats.get("by_tool"): print(f"\n Per-tool statistics:") - for tool_name, tool_stat in tool_stats['by_tool'].items(): - print(f" {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits") + for tool_name, tool_stat in tool_stats["by_tool"].items(): + print( + f" {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits" + ) + def clear_caches(assistant): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return assistant.enhanced.clear_caches() print(f"{Colors.GREEN}All caches cleared successfully{Colors.RESET}") + def show_system_stats(assistant): - if not hasattr(assistant, 'enhanced'): + if not hasattr(assistant, "enhanced"): print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") return @@ -388,68 +427,81 @@ def show_system_stats(assistant): print(f"\n{Colors.CYAN}Active Agents:{Colors.RESET}") print(f" Count: {agent_summary['active_agents']}") - if 'api_cache' in cache_stats: + if "api_cache" in cache_stats: print(f"\n{Colors.CYAN}Caching:{Colors.RESET}") print(f" API cache entries: {cache_stats['api_cache']['valid_entries']}") - if 'tool_cache' in cache_stats: + if "tool_cache" in cache_stats: print(f" Tool cache entries: {cache_stats['tool_cache']['valid_entries']}") + def handle_background_command(assistant, command): """Handle background multiplexer commands.""" parts = command.strip().split(maxsplit=2) if len(parts) < 2: print(f"{Colors.RED}Usage: /bg [args]{Colors.RESET}") - print(f"{Colors.GRAY}Available subcommands: start, list, status, output, input, kill, events{Colors.RESET}") + print( + f"{Colors.GRAY}Available subcommands: start, list, status, output, input, kill, events{Colors.RESET}" + ) return subcmd = parts[1].lower() try: - if subcmd == 'start' and len(parts) >= 3: + if subcmd == "start" and len(parts) >= 3: session_name = f"bg_{len(parts[2].split())}_{int(time.time())}" start_background_session(assistant, session_name, parts[2]) - elif subcmd == 'list': + elif subcmd == "list": list_background_sessions(assistant) - elif subcmd == 'status' and len(parts) >= 3: + elif subcmd == "status" and len(parts) >= 3: show_session_status(assistant, parts[2]) - elif subcmd == 'output' and len(parts) >= 3: + elif subcmd == "output" and len(parts) >= 3: show_session_output(assistant, parts[2]) - elif subcmd == 'input' and len(parts) >= 4: + elif subcmd == "input" and len(parts) >= 4: send_session_input(assistant, parts[2], parts[3]) - elif subcmd == 'kill' and len(parts) >= 3: + elif subcmd == "kill" and len(parts) >= 3: kill_background_session(assistant, parts[2]) - elif subcmd == 'events': + elif subcmd == "events": show_background_events(assistant) else: print(f"{Colors.RED}Unknown background command: {subcmd}{Colors.RESET}") - print(f"{Colors.GRAY}Available: start, list, status, output, input, kill, events{Colors.RESET}") + print( + f"{Colors.GRAY}Available: start, list, status, output, input, kill, events{Colors.RESET}" + ) except Exception as e: print(f"{Colors.RED}Error executing background command: {e}{Colors.RESET}") + def start_background_session(assistant, session_name, command): """Start a command in background.""" try: from pr.multiplexer import start_background_process + result = start_background_process(session_name, command) - if result['status'] == 'success': - print(f"{Colors.GREEN}Started background session '{session_name}' with PID {result['pid']}{Colors.RESET}") + if result["status"] == "success": + print( + f"{Colors.GREEN}Started background session '{session_name}' with PID {result['pid']}{Colors.RESET}" + ) else: - print(f"{Colors.RED}Failed to start background session: {result.get('error', 'Unknown error')}{Colors.RESET}") + print( + f"{Colors.RED}Failed to start background session: {result.get('error', 'Unknown error')}{Colors.RESET}" + ) except Exception as e: print(f"{Colors.RED}Error starting background session: {e}{Colors.RESET}") + def list_background_sessions(assistant): """List all background sessions.""" try: - from pr.ui.display import display_multiplexer_status from pr.multiplexer import get_all_sessions + from pr.ui.display import display_multiplexer_status sessions = get_all_sessions() display_multiplexer_status(sessions) except Exception as e: print(f"{Colors.RED}Error listing background sessions: {e}{Colors.RESET}") + def show_session_status(assistant, session_name): """Show status of a specific session.""" try: @@ -461,15 +513,17 @@ def show_session_status(assistant, session_name): print(f" Status: {info.get('status', 'unknown')}") print(f" PID: {info.get('pid', 'N/A')}") print(f" Command: {info.get('command', 'N/A')}") - if 'start_time' in info: + if "start_time" in info: import time - elapsed = time.time() - info['start_time'] + + elapsed = time.time() - info["start_time"] print(f" Running for: {elapsed:.1f}s") else: print(f"{Colors.YELLOW}Session '{session_name}' not found{Colors.RESET}") except Exception as e: print(f"{Colors.RED}Error getting session status: {e}{Colors.RESET}") + def show_session_output(assistant, session_name): """Show output of a specific session.""" try: @@ -482,36 +536,45 @@ def show_session_output(assistant, session_name): for line in output: print(line) else: - print(f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}") + print( + f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}" + ) except Exception as e: print(f"{Colors.RED}Error getting session output: {e}{Colors.RESET}") + def send_session_input(assistant, session_name, input_text): """Send input to a background session.""" try: from pr.multiplexer import send_input_to_session result = send_input_to_session(session_name, input_text) - if result['status'] == 'success': + if result["status"] == "success": print(f"{Colors.GREEN}Input sent to session '{session_name}'{Colors.RESET}") else: - print(f"{Colors.RED}Failed to send input: {result.get('error', 'Unknown error')}{Colors.RESET}") + print( + f"{Colors.RED}Failed to send input: {result.get('error', 'Unknown error')}{Colors.RESET}" + ) except Exception as e: print(f"{Colors.RED}Error sending input: {e}{Colors.RESET}") + def kill_background_session(assistant, session_name): """Kill a background session.""" try: from pr.multiplexer import kill_session result = kill_session(session_name) - if result['status'] == 'success': + if result["status"] == "success": print(f"{Colors.GREEN}Session '{session_name}' terminated{Colors.RESET}") else: - print(f"{Colors.RED}Failed to kill session: {result.get('error', 'Unknown error')}{Colors.RESET}") + print( + f"{Colors.RED}Failed to kill session: {result.get('error', 'Unknown error')}{Colors.RESET}" + ) except Exception as e: print(f"{Colors.RED}Error killing session: {e}{Colors.RESET}") + def show_background_events(assistant): """Show recent background events.""" try: @@ -526,6 +589,7 @@ def show_background_events(assistant): for event in events[-10:]: # Show last 10 events from pr.ui.display import display_background_event + display_background_event(event) else: print(f"{Colors.GRAY}No recent background events{Colors.RESET}") diff --git a/pr/commands/multiplexer_commands.py b/pr/commands/multiplexer_commands.py index 10357c1..27d5333 100644 --- a/pr/commands/multiplexer_commands.py +++ b/pr/commands/multiplexer_commands.py @@ -1,11 +1,15 @@ -from pr.tools.interactive_control import ( - list_active_sessions, get_session_status, read_session_output, - send_input_to_session, close_interactive_session -) from pr.multiplexer import get_multiplexer +from pr.tools.interactive_control import ( + close_interactive_session, + get_session_status, + list_active_sessions, + read_session_output, + send_input_to_session, +) from pr.tools.prompt_detection import get_global_detector from pr.ui import Colors + def show_sessions(args=None): """Show all active multiplexer sessions.""" sessions = list_active_sessions() @@ -18,24 +22,29 @@ def show_sessions(args=None): print("-" * 80) for session_name, session_data in sessions.items(): - metadata = session_data['metadata'] - output_summary = session_data['output_summary'] + metadata = session_data["metadata"] + output_summary = session_data["output_summary"] status = get_session_status(session_name) - is_active = status.get('is_active', False) if status else False + is_active = status.get("is_active", False) if status else False status_color = Colors.GREEN if is_active else Colors.RED - print(f"{Colors.CYAN}{session_name}{Colors.RESET}: {status_color}{metadata.get('process_type', 'unknown')}{Colors.RESET}") + print( + f"{Colors.CYAN}{session_name}{Colors.RESET}: {status_color}{metadata.get('process_type', 'unknown')}{Colors.RESET}" + ) - if status and 'pid' in status: + if status and "pid" in status: print(f" PID: {status['pid']}") print(f" Age: {metadata.get('start_time', 0):.1f}s") - print(f" Output: {output_summary['stdout_lines']} stdout, {output_summary['stderr_lines']} stderr lines") + print( + f" Output: {output_summary['stdout_lines']} stdout, {output_summary['stderr_lines']} stderr lines" + ) print(f" Interactions: {metadata.get('interaction_count', 0)}") print(f" State: {metadata.get('state', 'unknown')}") print() + def attach_session(args): """Attach to a session (show its output and allow interaction).""" if not args or len(args) < 1: @@ -56,20 +65,23 @@ def attach_session(args): # Show recent output try: output = read_session_output(session_name, lines=20) - if output['stdout']: + if output["stdout"]: print(f"{Colors.GRAY}Recent stdout:{Colors.RESET}") - for line in output['stdout'].split('\n'): + for line in output["stdout"].split("\n"): if line.strip(): print(f" {line}") - if output['stderr']: + if output["stderr"]: print(f"{Colors.YELLOW}Recent stderr:{Colors.RESET}") - for line in output['stderr'].split('\n'): + for line in output["stderr"].split("\n"): if line.strip(): print(f" {line}") except Exception as e: print(f"{Colors.RED}Error reading output: {e}{Colors.RESET}") - print(f"\n{Colors.CYAN}Session is {'active' if status.get('is_active') else 'inactive'}{Colors.RESET}") + print( + f"\n{Colors.CYAN}Session is {'active' if status.get('is_active') else 'inactive'}{Colors.RESET}" + ) + def detach_session(args): """Detach from a session (stop showing its output but keep it running).""" @@ -87,7 +99,10 @@ def detach_session(args): # In this implementation, detaching just means we stop displaying output # The session continues to run in the background mux.show_output = False - print(f"{Colors.GREEN}Detached from session '{session_name}'. It continues running in background.{Colors.RESET}") + print( + f"{Colors.GREEN}Detached from session '{session_name}'. It continues running in background.{Colors.RESET}" + ) + def kill_session(args): """Kill a session forcefully.""" @@ -101,7 +116,10 @@ def kill_session(args): close_interactive_session(session_name) print(f"{Colors.GREEN}Session '{session_name}' terminated.{Colors.RESET}") except Exception as e: - print(f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}") + print( + f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}" + ) + def send_command(args): """Send a command to a session.""" @@ -110,13 +128,18 @@ def send_command(args): return session_name = args[0] - command = ' '.join(args[1:]) + command = " ".join(args[1:]) try: send_input_to_session(session_name, command) - print(f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}") + print( + f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}" + ) except Exception as e: - print(f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}") + print( + f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}" + ) + def show_session_log(args): """Show the full log/output of a session.""" @@ -131,19 +154,20 @@ def show_session_log(args): print(f"{Colors.BOLD}Full log for session: {session_name}{Colors.RESET}") print("=" * 80) - if output['stdout']: + if output["stdout"]: print(f"{Colors.GRAY}STDOUT:{Colors.RESET}") - print(output['stdout']) + print(output["stdout"]) print() - if output['stderr']: + if output["stderr"]: print(f"{Colors.YELLOW}STDERR:{Colors.RESET}") - print(output['stderr']) + print(output["stderr"]) print() except Exception as e: print(f"{Colors.RED}Error reading log for '{session_name}': {e}{Colors.RESET}") + def show_session_status(args): """Show detailed status of a session.""" if not args or len(args) < 1: @@ -160,11 +184,11 @@ def show_session_status(args): print(f"{Colors.BOLD}Status for session: {session_name}{Colors.RESET}") print("-" * 50) - metadata = status.get('metadata', {}) + metadata = status.get("metadata", {}) print(f"Process type: {metadata.get('process_type', 'unknown')}") print(f"Active: {status.get('is_active', False)}") - if 'pid' in status: + if "pid" in status: print(f"PID: {status['pid']}") print(f"Start time: {metadata.get('start_time', 0):.1f}") @@ -172,8 +196,10 @@ def show_session_status(args): print(f"Interaction count: {metadata.get('interaction_count', 0)}") print(f"State: {metadata.get('state', 'unknown')}") - output_summary = status.get('output_summary', {}) - print(f"Output lines: {output_summary.get('stdout_lines', 0)} stdout, {output_summary.get('stderr_lines', 0)} stderr") + output_summary = status.get("output_summary", {}) + print( + f"Output lines: {output_summary.get('stdout_lines', 0)} stdout, {output_summary.get('stderr_lines', 0)} stderr" + ) # Show prompt detection info detector = get_global_detector() @@ -182,6 +208,7 @@ def show_session_status(args): print(f"Current state: {session_info['current_state']}") print(f"Is waiting for input: {session_info['is_waiting']}") + def list_waiting_sessions(args=None): """List sessions that appear to be waiting for input.""" sessions = list_active_sessions() @@ -193,14 +220,16 @@ def list_waiting_sessions(args=None): waiting_sessions.append(session_name) if not waiting_sessions: - print(f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}") + print( + f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}" + ) return print(f"{Colors.BOLD}Sessions waiting for input:{Colors.RESET}") for session_name in waiting_sessions: status = get_session_status(session_name) if status: - process_type = status.get('metadata', {}).get('process_type', 'unknown') + process_type = status.get("metadata", {}).get("process_type", "unknown") print(f" {Colors.CYAN}{session_name}{Colors.RESET} ({process_type})") # Show suggestions @@ -208,17 +237,20 @@ def list_waiting_sessions(args=None): if session_info: suggestions = detector.get_response_suggestions({}, process_type) if suggestions: - print(f" Suggested inputs: {', '.join(suggestions[:3])}") # Show first 3 + print( + f" Suggested inputs: {', '.join(suggestions[:3])}" + ) # Show first 3 print() + # Command registry for the multiplexer commands MULTIPLEXER_COMMANDS = { - 'show_sessions': show_sessions, - 'attach_session': attach_session, - 'detach_session': detach_session, - 'kill_session': kill_session, - 'send_command': send_command, - 'show_session_log': show_session_log, - 'show_session_status': show_session_status, - 'list_waiting_sessions': list_waiting_sessions, -} \ No newline at end of file + "show_sessions": show_sessions, + "attach_session": attach_session, + "detach_session": detach_session, + "kill_session": kill_session, + "send_command": send_command, + "show_session_log": show_session_log, + "show_session_status": show_session_status, + "list_waiting_sessions": list_waiting_sessions, +} diff --git a/pr/config.py b/pr/config.py index 5fd3096..3e950b9 100644 --- a/pr/config.py +++ b/pr/config.py @@ -27,15 +27,78 @@ CONTENT_TRIM_LENGTH = 30000 MAX_TOOL_RESULT_LENGTH = 30000 LANGUAGE_KEYWORDS = { - 'python': ['def', 'class', 'import', 'from', 'if', 'else', 'elif', 'for', 'while', - 'return', 'try', 'except', 'finally', 'with', 'as', 'lambda', 'yield', - 'None', 'True', 'False', 'and', 'or', 'not', 'in', 'is'], - 'javascript': ['function', 'var', 'let', 'const', 'if', 'else', 'for', 'while', - 'return', 'try', 'catch', 'finally', 'class', 'extends', 'new', - 'this', 'null', 'undefined', 'true', 'false'], - 'java': ['public', 'private', 'protected', 'class', 'interface', 'extends', - 'implements', 'static', 'final', 'void', 'int', 'String', 'boolean', - 'if', 'else', 'for', 'while', 'return', 'try', 'catch', 'finally'], + "python": [ + "def", + "class", + "import", + "from", + "if", + "else", + "elif", + "for", + "while", + "return", + "try", + "except", + "finally", + "with", + "as", + "lambda", + "yield", + "None", + "True", + "False", + "and", + "or", + "not", + "in", + "is", + ], + "javascript": [ + "function", + "var", + "let", + "const", + "if", + "else", + "for", + "while", + "return", + "try", + "catch", + "finally", + "class", + "extends", + "new", + "this", + "null", + "undefined", + "true", + "false", + ], + "java": [ + "public", + "private", + "protected", + "class", + "interface", + "extends", + "implements", + "static", + "final", + "void", + "int", + "String", + "boolean", + "if", + "else", + "for", + "while", + "return", + "try", + "catch", + "finally", + ], } CACHE_ENABLED = True @@ -70,18 +133,18 @@ MAX_CONCURRENT_SESSIONS = 10 # Process-specific timeouts (seconds) PROCESS_TIMEOUTS = { - 'default': 300, # 5 minutes - 'apt': 600, # 10 minutes - 'ssh': 60, # 1 minute - 'vim': 3600, # 1 hour - 'git': 300, # 5 minutes - 'npm': 600, # 10 minutes - 'pip': 300, # 5 minutes + "default": 300, # 5 minutes + "apt": 600, # 10 minutes + "ssh": 60, # 1 minute + "vim": 3600, # 1 hour + "git": 300, # 5 minutes + "npm": 600, # 10 minutes + "pip": 300, # 5 minutes } # Activity thresholds for LLM notification HIGH_OUTPUT_THRESHOLD = 50 # lines -INACTIVE_THRESHOLD = 300 # seconds +INACTIVE_THRESHOLD = 300 # seconds SESSION_NOTIFY_INTERVAL = 60 # seconds # Autonomous behavior flags diff --git a/pr/core/__init__.py b/pr/core/__init__.py index 010d62b..f30cf90 100644 --- a/pr/core/__init__.py +++ b/pr/core/__init__.py @@ -1,5 +1,11 @@ -from pr.core.assistant import Assistant from pr.core.api import call_api, list_models +from pr.core.assistant import Assistant from pr.core.context import init_system_message, manage_context_window -__all__ = ['Assistant', 'call_api', 'list_models', 'init_system_message', 'manage_context_window'] +__all__ = [ + "Assistant", + "call_api", + "list_models", + "init_system_message", + "manage_context_window", +] diff --git a/pr/core/advanced_context.py b/pr/core/advanced_context.py index e5c25a4..1d10154 100644 --- a/pr/core/advanced_context.py +++ b/pr/core/advanced_context.py @@ -1,20 +1,20 @@ import re -import math -from typing import List, Dict, Any -from collections import Counter +from typing import Any, Dict, List + class AdvancedContextManager: def __init__(self, knowledge_store=None, conversation_memory=None): self.knowledge_store = knowledge_store self.conversation_memory = conversation_memory - def adaptive_context_window(self, messages: List[Dict[str, Any]], - task_complexity: str = 'medium') -> int: + def adaptive_context_window( + self, messages: List[Dict[str, Any]], task_complexity: str = "medium" + ) -> int: complexity_thresholds = { - 'simple': 10, - 'medium': 20, - 'complex': 35, - 'very_complex': 50 + "simple": 10, + "medium": 20, + "complex": 35, + "very_complex": 50, } base_threshold = complexity_thresholds.get(task_complexity, 20) @@ -31,17 +31,19 @@ class AdvancedContextManager: return max(base_threshold, adjusted) def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float: - total_length = sum(len(msg.get('content', '')) for msg in messages) + total_length = sum(len(msg.get("content", "")) for msg in messages) avg_length = total_length / len(messages) if messages else 0 - + unique_words = set() for msg in messages: - content = msg.get('content', '') - words = re.findall(r'\b\w+\b', content.lower()) + content = msg.get("content", "") + words = re.findall(r"\b\w+\b", content.lower()) unique_words.update(words) - - vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0 - + + vocabulary_richness = ( + len(unique_words) / total_length if total_length > 0 else 0 + ) + # Simple complexity score based on length and richness complexity = min(1.0, (avg_length / 100) + vocabulary_richness) return complexity @@ -49,10 +51,10 @@ class AdvancedContextManager: def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]: if not text.strip(): return [] - sentences = re.split(r'(?<=[.!?])\s+', text) + sentences = re.split(r"(?<=[.!?])\s+", text) if not sentences: return [] - + # Simple scoring based on length and position scored_sentences = [] for i, sentence in enumerate(sentences): @@ -60,25 +62,25 @@ class AdvancedContextManager: position_score = 1.0 if i == 0 else 0.8 if i < len(sentences) / 2 else 0.6 score = (length_score + position_score) / 2 scored_sentences.append((sentence, score)) - + scored_sentences.sort(key=lambda x: x[1], reverse=True) return [s[0] for s in scored_sentences[:top_k]] def advanced_summarize_messages(self, messages: List[Dict[str, Any]]) -> str: - all_content = ' '.join([msg.get('content', '') for msg in messages]) + all_content = " ".join([msg.get("content", "") for msg in messages]) key_sentences = self.extract_key_sentences(all_content, top_k=3) - summary = ' '.join(key_sentences) + summary = " ".join(key_sentences) return summary if summary else "No content to summarize." def score_message_relevance(self, message: Dict[str, Any], context: str) -> float: - content = message.get('content', '') - content_words = set(re.findall(r'\b\w+\b', content.lower())) - context_words = set(re.findall(r'\b\w+\b', context.lower())) - + content = message.get("content", "") + content_words = set(re.findall(r"\b\w+\b", content.lower())) + context_words = set(re.findall(r"\b\w+\b", context.lower())) + intersection = content_words & context_words union = content_words | context_words - + if not union: return 0.0 - - return len(intersection) / len(union) \ No newline at end of file + + return len(intersection) / len(union) diff --git a/pr/core/api.py b/pr/core/api.py index fb92e52..b52efa4 100644 --- a/pr/core/api.py +++ b/pr/core/api.py @@ -1,13 +1,17 @@ import json -import urllib.request -import urllib.error import logging -from pr.config import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS +import urllib.error +import urllib.request + +from pr.config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE from pr.core.context import auto_slim_messages -logger = logging.getLogger('pr') +logger = logging.getLogger("pr") -def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False): + +def call_api( + messages, model, api_url, api_key, use_tools, tools_definition, verbose=False +): try: messages = auto_slim_messages(messages, verbose=verbose) @@ -17,62 +21,63 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver logger.debug(f"Use tools: {use_tools}") logger.debug(f"Message count: {len(messages)}") headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } if api_key: - headers['Authorization'] = f'Bearer {api_key}' + headers["Authorization"] = f"Bearer {api_key}" data = { - 'model': model, - 'messages': messages, - 'temperature': DEFAULT_TEMPERATURE, - 'max_tokens': DEFAULT_MAX_TOKENS + "model": model, + "messages": messages, + "temperature": DEFAULT_TEMPERATURE, + "max_tokens": DEFAULT_MAX_TOKENS, } if "gpt-5" in model: - del data['temperature'] - del data['max_tokens'] + del data["temperature"] + del data["max_tokens"] logger.debug("GPT-5 detected: removed temperature and max_tokens") if use_tools: - data['tools'] = tools_definition - data['tool_choice'] = 'auto' + data["tools"] = tools_definition + data["tool_choice"] = "auto" logger.debug(f"Tool calling enabled with {len(tools_definition)} tools") request_json = json.dumps(data) logger.debug(f"Request payload size: {len(request_json)} bytes") req = urllib.request.Request( - api_url, - data=request_json.encode('utf-8'), - headers=headers, - method='POST' + api_url, data=request_json.encode("utf-8"), headers=headers, method="POST" ) logger.debug("Sending HTTP request...") with urllib.request.urlopen(req) as response: - response_data = response.read().decode('utf-8') + response_data = response.read().decode("utf-8") logger.debug(f"Response received: {len(response_data)} bytes") result = json.loads(response_data) - if 'usage' in result: + if "usage" in result: logger.debug(f"Token usage: {result['usage']}") - if 'choices' in result and result['choices']: - choice = result['choices'][0] - if 'message' in choice: - msg = choice['message'] + if "choices" in result and result["choices"]: + choice = result["choices"][0] + if "message" in choice: + msg = choice["message"] logger.debug(f"Response role: {msg.get('role', 'N/A')}") - if 'content' in msg and msg['content']: - logger.debug(f"Response content length: {len(msg['content'])} chars") - if 'tool_calls' in msg: - logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)") + if "content" in msg and msg["content"]: + logger.debug( + f"Response content length: {len(msg['content'])} chars" + ) + if "tool_calls" in msg: + logger.debug( + f"Response contains {len(msg['tool_calls'])} tool call(s)" + ) logger.debug("=== API CALL END ===") return result except urllib.error.HTTPError as e: - error_body = e.read().decode('utf-8') + error_body = e.read().decode("utf-8") logger.error(f"API HTTP Error: {e.code} - {error_body}") logger.debug("=== API CALL FAILED ===") return {"error": f"API Error: {e.code}", "message": error_body} @@ -81,15 +86,16 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver logger.debug("=== API CALL FAILED ===") return {"error": str(e)} + def list_models(model_list_url, api_key): try: req = urllib.request.Request(model_list_url) if api_key: - req.add_header('Authorization', f'Bearer {api_key}') + req.add_header("Authorization", f"Bearer {api_key}") with urllib.request.urlopen(req) as response: - data = json.loads(response.read().decode('utf-8')) + data = json.loads(response.read().decode("utf-8")) - return data.get('data', []) + return data.get("data", []) except Exception as e: return {"error": str(e)} diff --git a/pr/core/assistant.py b/pr/core/assistant.py index 1e58f69..92257d6 100644 --- a/pr/core/assistant.py +++ b/pr/core/assistant.py @@ -1,63 +1,111 @@ -import os -import sys -import json -import sqlite3 -import signal -import logging -import traceback -import readline import glob as glob_module +import json +import logging +import os +import readline +import signal +import sqlite3 +import sys +import traceback from concurrent.futures import ThreadPoolExecutor -from pr.config import DB_PATH, LOG_FILE, DEFAULT_MODEL, DEFAULT_API_URL, MODEL_LIST_URL, HISTORY_FILE -from pr.ui import Colors, render_markdown -from pr.core.context import init_system_message, truncate_tool_result + +from pr.commands import handle_command +from pr.config import ( + DB_PATH, + DEFAULT_API_URL, + DEFAULT_MODEL, + HISTORY_FILE, + LOG_FILE, + MODEL_LIST_URL, +) from pr.core.api import call_api +from pr.core.autonomous_interactions import ( + get_global_autonomous, + stop_global_autonomous, +) +from pr.core.background_monitor import ( + get_global_monitor, + start_global_monitor, + stop_global_monitor, +) +from pr.core.context import init_system_message, truncate_tool_result from pr.tools import ( - http_fetch, run_command, run_command_interactive, read_file, write_file, - list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query, - web_search, web_search_news, python_exec, index_source_directory, - open_editor, editor_insert_text, editor_replace_text, editor_search, - search_replace,close_editor,create_diff,apply_patch, - tail_process, kill_process + apply_patch, + chdir, + close_editor, + create_diff, + db_get, + db_query, + db_set, + editor_insert_text, + editor_replace_text, + editor_search, + getpwd, + http_fetch, + index_source_directory, + kill_process, + list_directory, + mkdir, + open_editor, + python_exec, + read_file, + run_command, + search_replace, + tail_process, + web_search, + web_search_news, + write_file, +) +from pr.tools.base import get_tools_definition +from pr.tools.filesystem import ( + clear_edit_tracker, + display_edit_summary, + display_edit_timeline, ) from pr.tools.interactive_control import ( - start_interactive_session, send_input_to_session, read_session_output, - list_active_sessions, close_interactive_session + close_interactive_session, + list_active_sessions, + read_session_output, + send_input_to_session, + start_interactive_session, ) from pr.tools.patch import display_file_diff -from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker -from pr.tools.base import get_tools_definition -from pr.commands import handle_command -from pr.core.background_monitor import start_global_monitor, stop_global_monitor, get_global_monitor -from pr.core.autonomous_interactions import start_global_autonomous, stop_global_autonomous, get_global_autonomous +from pr.ui import Colors, render_markdown -logger = logging.getLogger('pr') +logger = logging.getLogger("pr") logger.setLevel(logging.DEBUG) file_handler = logging.FileHandler(LOG_FILE) -file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) +file_handler.setFormatter( + logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") +) logger.addHandler(file_handler) + class Assistant: def __init__(self, args): self.args = args self.messages = [] self.verbose = args.verbose - self.debug = getattr(args, 'debug', False) + self.debug = getattr(args, "debug", False) self.syntax_highlighting = not args.no_syntax if self.debug: console_handler = logging.StreamHandler() console_handler.setLevel(logging.DEBUG) - console_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s')) + console_handler.setFormatter( + logging.Formatter("%(levelname)s: %(message)s") + ) logger.addHandler(console_handler) logger.debug("Debug mode enabled") - self.api_key = os.environ.get('OPENROUTER_API_KEY', '') - self.model = args.model or os.environ.get('AI_MODEL', DEFAULT_MODEL) - self.api_url = args.api_url or os.environ.get('API_URL', DEFAULT_API_URL) - self.model_list_url = args.model_list_url or os.environ.get('MODEL_LIST_URL', MODEL_LIST_URL) - self.use_tools = os.environ.get('USE_TOOLS', '1') == '1' - self.strict_mode = os.environ.get('STRICT_MODE', '0') == '1' + self.api_key = os.environ.get("OPENROUTER_API_KEY", "") + self.model = args.model or os.environ.get("AI_MODEL", DEFAULT_MODEL) + self.api_url = args.api_url or os.environ.get("API_URL", DEFAULT_API_URL) + self.model_list_url = args.model_list_url or os.environ.get( + "MODEL_LIST_URL", MODEL_LIST_URL + ) + self.use_tools = os.environ.get("USE_TOOLS", "1") == "1" + self.strict_mode = os.environ.get("STRICT_MODE", "0") == "1" self.interrupt_count = 0 self.python_globals = {} self.db_conn = None @@ -69,6 +117,7 @@ class Assistant: try: from pr.core.enhanced_assistant import EnhancedAssistant + self.enhanced = EnhancedAssistant(self) if self.debug: logger.debug("Enhanced assistant features initialized") @@ -94,13 +143,17 @@ class Assistant: self.db_conn = sqlite3.connect(DB_PATH, check_same_thread=False) cursor = self.db_conn.cursor() - cursor.execute('''CREATE TABLE IF NOT EXISTS kv_store - (key TEXT PRIMARY KEY, value TEXT, timestamp REAL)''') + cursor.execute( + """CREATE TABLE IF NOT EXISTS kv_store + (key TEXT PRIMARY KEY, value TEXT, timestamp REAL)""" + ) - cursor.execute('''CREATE TABLE IF NOT EXISTS file_versions + cursor.execute( + """CREATE TABLE IF NOT EXISTS file_versions (id INTEGER PRIMARY KEY AUTOINCREMENT, filepath TEXT, content TEXT, hash TEXT, - timestamp REAL, version INTEGER)''') + timestamp REAL, version INTEGER)""" + ) self.db_conn.commit() logger.debug("Database initialized successfully") @@ -110,7 +163,7 @@ class Assistant: def _handle_background_updates(self, updates): """Handle background session updates by injecting them into the conversation.""" - if not updates or not updates.get('sessions'): + if not updates or not updates.get("sessions"): return # Format the update as a system message @@ -118,10 +171,12 @@ class Assistant: # Inject into current conversation if we're in an active session if self.messages and len(self.messages) > 0: - self.messages.append({ - "role": "system", - "content": f"Background session updates: {update_message}" - }) + self.messages.append( + { + "role": "system", + "content": f"Background session updates: {update_message}", + } + ) if self.verbose: print(f"{Colors.CYAN}Background update: {update_message}{Colors.RESET}") @@ -130,8 +185,8 @@ class Assistant: """Format background updates for LLM consumption.""" session_summaries = [] - for session_name, session_info in updates.get('sessions', {}).items(): - summary = session_info.get('summary', f'Session {session_name}') + for session_name, session_info in updates.get("sessions", {}).items(): + summary = session_info.get("summary", f"Session {session_name}") session_summaries.append(f"{session_name}: {summary}") if session_summaries: @@ -151,30 +206,44 @@ class Assistant: if events: print(f"\n{Colors.CYAN}Background Events:{Colors.RESET}") for event in events: - event_type = event.get('type', 'unknown') - session_name = event.get('session_name', 'unknown') + event_type = event.get("type", "unknown") + session_name = event.get("session_name", "unknown") - if event_type == 'session_started': - print(f" {Colors.GREEN}✓{Colors.RESET} Session '{session_name}' started") - elif event_type == 'session_ended': - print(f" {Colors.YELLOW}✗{Colors.RESET} Session '{session_name}' ended") - elif event_type == 'output_received': - lines = len(event.get('new_output', {}).get('stdout', [])) - print(f" {Colors.BLUE}📝{Colors.RESET} Session '{session_name}' produced {lines} lines of output") - elif event_type == 'possible_input_needed': - print(f" {Colors.RED}❓{Colors.RESET} Session '{session_name}' may need input") - elif event_type == 'high_output_volume': - total = event.get('total_lines', 0) - print(f" {Colors.YELLOW}📊{Colors.RESET} Session '{session_name}' has high output volume ({total} lines)") - elif event_type == 'inactive_session': - inactive_time = event.get('inactive_seconds', 0) - print(f" {Colors.GRAY}⏰{Colors.RESET} Session '{session_name}' inactive for {inactive_time:.0f}s") + if event_type == "session_started": + print( + f" {Colors.GREEN}✓{Colors.RESET} Session '{session_name}' started" + ) + elif event_type == "session_ended": + print( + f" {Colors.YELLOW}✗{Colors.RESET} Session '{session_name}' ended" + ) + elif event_type == "output_received": + lines = len(event.get("new_output", {}).get("stdout", [])) + print( + f" {Colors.BLUE}📝{Colors.RESET} Session '{session_name}' produced {lines} lines of output" + ) + elif event_type == "possible_input_needed": + print( + f" {Colors.RED}❓{Colors.RESET} Session '{session_name}' may need input" + ) + elif event_type == "high_output_volume": + total = event.get("total_lines", 0) + print( + f" {Colors.YELLOW}📊{Colors.RESET} Session '{session_name}' has high output volume ({total} lines)" + ) + elif event_type == "inactive_session": + inactive_time = event.get("inactive_seconds", 0) + print( + f" {Colors.GRAY}⏰{Colors.RESET} Session '{session_name}' inactive for {inactive_time:.0f}s" + ) print() # Add blank line after events except Exception as e: if self.debug: - print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}") + print( + f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}" + ) def execute_tool_calls(self, tool_calls): results = [] @@ -185,114 +254,147 @@ class Assistant: futures = [] for tool_call in tool_calls: - func_name = tool_call['function']['name'] - arguments = json.loads(tool_call['function']['arguments']) + func_name = tool_call["function"]["name"] + arguments = json.loads(tool_call["function"]["arguments"]) logger.debug(f"Tool call: {func_name} with arguments: {arguments}") func_map = { - 'http_fetch': lambda **kw: http_fetch(**kw), - 'run_command': lambda **kw: run_command(**kw), - 'tail_process': lambda **kw: tail_process(**kw), - 'kill_process': lambda **kw: kill_process(**kw), - 'start_interactive_session': lambda **kw: start_interactive_session(**kw), - 'send_input_to_session': lambda **kw: send_input_to_session(**kw), - 'read_session_output': lambda **kw: read_session_output(**kw), - 'close_interactive_session': lambda **kw: close_interactive_session(**kw), - 'read_file': lambda **kw: read_file(**kw, db_conn=self.db_conn), - 'write_file': lambda **kw: write_file(**kw, db_conn=self.db_conn), - 'list_directory': lambda **kw: list_directory(**kw), - 'mkdir': lambda **kw: mkdir(**kw), - 'chdir': lambda **kw: chdir(**kw), - 'getpwd': lambda **kw: getpwd(**kw), - 'db_set': lambda **kw: db_set(**kw, db_conn=self.db_conn), - 'db_get': lambda **kw: db_get(**kw, db_conn=self.db_conn), - 'db_query': lambda **kw: db_query(**kw, db_conn=self.db_conn), - 'web_search': lambda **kw: web_search(**kw), - 'web_search_news': lambda **kw: web_search_news(**kw), - 'python_exec': lambda **kw: python_exec(**kw, python_globals=self.python_globals), - 'index_source_directory': lambda **kw: index_source_directory(**kw), - 'search_replace': lambda **kw: search_replace(**kw, db_conn=self.db_conn), - 'open_editor': lambda **kw: open_editor(**kw), - 'editor_insert_text': lambda **kw: editor_insert_text(**kw, db_conn=self.db_conn), - 'editor_replace_text': lambda **kw: editor_replace_text(**kw, db_conn=self.db_conn), - 'editor_search': lambda **kw: editor_search(**kw), - 'close_editor': lambda **kw: close_editor(**kw), - 'create_diff': lambda **kw: create_diff(**kw), - 'apply_patch': lambda **kw: apply_patch(**kw, db_conn=self.db_conn), - 'display_file_diff': lambda **kw: display_file_diff(**kw), - 'display_edit_summary': lambda **kw: display_edit_summary(), - 'display_edit_timeline': lambda **kw: display_edit_timeline(**kw), - 'clear_edit_tracker': lambda **kw: clear_edit_tracker(), - 'start_interactive_session': lambda **kw: start_interactive_session(**kw), - 'send_input_to_session': lambda **kw: send_input_to_session(**kw), - 'read_session_output': lambda **kw: read_session_output(**kw), - 'list_active_sessions': lambda **kw: list_active_sessions(**kw), - 'close_interactive_session': lambda **kw: close_interactive_session(**kw), - 'create_agent': lambda **kw: create_agent(**kw), - 'list_agents': lambda **kw: list_agents(**kw), - 'execute_agent_task': lambda **kw: execute_agent_task(**kw), - 'remove_agent': lambda **kw: remove_agent(**kw), - 'collaborate_agents': lambda **kw: collaborate_agents(**kw), - 'add_knowledge_entry': lambda **kw: add_knowledge_entry(**kw), - 'get_knowledge_entry': lambda **kw: get_knowledge_entry(**kw), - 'search_knowledge': lambda **kw: search_knowledge(**kw), - 'get_knowledge_by_category': lambda **kw: get_knowledge_by_category(**kw), - 'update_knowledge_importance': lambda **kw: update_knowledge_importance(**kw), - 'delete_knowledge_entry': lambda **kw: delete_knowledge_entry(**kw), - 'get_knowledge_statistics': lambda **kw: get_knowledge_statistics(**kw), + "http_fetch": lambda **kw: http_fetch(**kw), + "run_command": lambda **kw: run_command(**kw), + "tail_process": lambda **kw: tail_process(**kw), + "kill_process": lambda **kw: kill_process(**kw), + "start_interactive_session": lambda **kw: start_interactive_session( + **kw + ), + "send_input_to_session": lambda **kw: send_input_to_session(**kw), + "read_session_output": lambda **kw: read_session_output(**kw), + "close_interactive_session": lambda **kw: close_interactive_session( + **kw + ), + "read_file": lambda **kw: read_file(**kw, db_conn=self.db_conn), + "write_file": lambda **kw: write_file(**kw, db_conn=self.db_conn), + "list_directory": lambda **kw: list_directory(**kw), + "mkdir": lambda **kw: mkdir(**kw), + "chdir": lambda **kw: chdir(**kw), + "getpwd": lambda **kw: getpwd(**kw), + "db_set": lambda **kw: db_set(**kw, db_conn=self.db_conn), + "db_get": lambda **kw: db_get(**kw, db_conn=self.db_conn), + "db_query": lambda **kw: db_query(**kw, db_conn=self.db_conn), + "web_search": lambda **kw: web_search(**kw), + "web_search_news": lambda **kw: web_search_news(**kw), + "python_exec": lambda **kw: python_exec( + **kw, python_globals=self.python_globals + ), + "index_source_directory": lambda **kw: index_source_directory(**kw), + "search_replace": lambda **kw: search_replace( + **kw, db_conn=self.db_conn + ), + "open_editor": lambda **kw: open_editor(**kw), + "editor_insert_text": lambda **kw: editor_insert_text( + **kw, db_conn=self.db_conn + ), + "editor_replace_text": lambda **kw: editor_replace_text( + **kw, db_conn=self.db_conn + ), + "editor_search": lambda **kw: editor_search(**kw), + "close_editor": lambda **kw: close_editor(**kw), + "create_diff": lambda **kw: create_diff(**kw), + "apply_patch": lambda **kw: apply_patch(**kw, db_conn=self.db_conn), + "display_file_diff": lambda **kw: display_file_diff(**kw), + "display_edit_summary": lambda **kw: display_edit_summary(), + "display_edit_timeline": lambda **kw: display_edit_timeline(**kw), + "clear_edit_tracker": lambda **kw: clear_edit_tracker(), + "start_interactive_session": lambda **kw: start_interactive_session( + **kw + ), + "send_input_to_session": lambda **kw: send_input_to_session(**kw), + "read_session_output": lambda **kw: read_session_output(**kw), + "list_active_sessions": lambda **kw: list_active_sessions(**kw), + "close_interactive_session": lambda **kw: close_interactive_session( + **kw + ), + "create_agent": lambda **kw: create_agent(**kw), + "list_agents": lambda **kw: list_agents(**kw), + "execute_agent_task": lambda **kw: execute_agent_task(**kw), + "remove_agent": lambda **kw: remove_agent(**kw), + "collaborate_agents": lambda **kw: collaborate_agents(**kw), + "add_knowledge_entry": lambda **kw: add_knowledge_entry(**kw), + "get_knowledge_entry": lambda **kw: get_knowledge_entry(**kw), + "search_knowledge": lambda **kw: search_knowledge(**kw), + "get_knowledge_by_category": lambda **kw: get_knowledge_by_category( + **kw + ), + "update_knowledge_importance": lambda **kw: update_knowledge_importance( + **kw + ), + "delete_knowledge_entry": lambda **kw: delete_knowledge_entry(**kw), + "get_knowledge_statistics": lambda **kw: get_knowledge_statistics( + **kw + ), } if func_name in func_map: future = executor.submit(func_map[func_name], **arguments) - futures.append((tool_call['id'], future)) + futures.append((tool_call["id"], future)) for tool_id, future in futures: try: result = future.result(timeout=30) result = truncate_tool_result(result) logger.debug(f"Tool result for {tool_id}: {str(result)[:200]}...") - results.append({ - "tool_call_id": tool_id, - "role": "tool", - "content": json.dumps(result) - }) + results.append( + { + "tool_call_id": tool_id, + "role": "tool", + "content": json.dumps(result), + } + ) except Exception as e: logger.debug(f"Tool error for {tool_id}: {str(e)}") error_msg = str(e)[:200] if len(str(e)) > 200 else str(e) - results.append({ - "tool_call_id": tool_id, - "role": "tool", - "content": json.dumps({"status": "error", "error": error_msg}) - }) + results.append( + { + "tool_call_id": tool_id, + "role": "tool", + "content": json.dumps( + {"status": "error", "error": error_msg} + ), + } + ) return results def process_response(self, response): - if 'error' in response: + if "error" in response: return f"Error: {response['error']}" - if 'choices' not in response or not response['choices']: + if "choices" not in response or not response["choices"]: return "No response from API" - message = response['choices'][0]['message'] + message = response["choices"][0]["message"] self.messages.append(message) - if 'tool_calls' in message and message['tool_calls']: + if "tool_calls" in message and message["tool_calls"]: if self.verbose: print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}") - tool_results = self.execute_tool_calls(message['tool_calls']) + tool_results = self.execute_tool_calls(message["tool_calls"]) for result in tool_results: self.messages.append(result) follow_up = call_api( - self.messages, self.model, self.api_url, self.api_key, - self.use_tools, get_tools_definition(), verbose=self.verbose + self.messages, + self.model, + self.api_url, + self.api_key, + self.use_tools, + get_tools_definition(), + verbose=self.verbose, ) return self.process_response(follow_up) - content = message.get('content', '') + content = message.get("content", "") return render_markdown(content, self.syntax_highlighting) def signal_handler(self, signum, frame): @@ -303,7 +405,9 @@ class Assistant: self.autonomous_mode = False sys.exit(0) else: - print(f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}") + print( + f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}" + ) return self.interrupt_count += 1 @@ -323,21 +427,34 @@ class Assistant: readline.set_history_length(1000) import atexit + atexit.register(readline.write_history_file, HISTORY_FILE) - commands = ['exit', 'quit', 'help', 'reset', 'dump', 'verbose', - 'models', 'tools', 'review', 'refactor', 'obfuscate', '/auto'] + commands = [ + "exit", + "quit", + "help", + "reset", + "dump", + "verbose", + "models", + "tools", + "review", + "refactor", + "obfuscate", + "/auto", + ] def completer(text, state): options = [cmd for cmd in commands if cmd.startswith(text)] - glob_pattern = os.path.expanduser(text) + '*' + glob_pattern = os.path.expanduser(text) + "*" path_options = glob_module.glob(glob_pattern) path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options] combined_options = sorted(list(set(options + path_options))) - #combined_options.extend(self.commands) + # combined_options.extend(self.commands) if state < len(combined_options): return combined_options[state] @@ -345,10 +462,10 @@ class Assistant: return None delims = readline.get_completer_delims() - readline.set_completer_delims(delims.replace('/', '')) + readline.set_completer_delims(delims.replace("/", "")) readline.set_completer(completer) - readline.parse_and_bind('tab: complete') + readline.parse_and_bind("tab: complete") def run_repl(self): self.setup_readline() @@ -368,8 +485,11 @@ class Assistant: if self.background_monitoring: try: from pr.multiplexer import get_all_sessions + sessions = get_all_sessions() - active_count = sum(1 for s in sessions.values() if s.get('status') == 'running') + active_count = sum( + 1 for s in sessions.values() if s.get("status") == "running" + ) if active_count > 0: prompt += f"[{active_count}bg]" except: @@ -405,10 +525,11 @@ class Assistant: message = sys.stdin.read() from pr.autonomous.mode import run_autonomous_mode + run_autonomous_mode(self, message) def cleanup(self): - if hasattr(self, 'enhanced') and self.enhanced: + if hasattr(self, "enhanced") and self.enhanced: try: self.enhanced.cleanup() except Exception as e: @@ -424,6 +545,7 @@ class Assistant: try: from pr.multiplexer import cleanup_all_multiplexers + cleanup_all_multiplexers() except Exception as e: logger.error(f"Error cleaning up multiplexers: {e}") @@ -433,7 +555,9 @@ class Assistant: def run(self): try: - print(f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}") + print( + f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}" + ) if self.args.interactive or (not self.args.message and sys.stdin.isatty()): print("DEBUG: calling run_repl") self.run_repl() @@ -443,6 +567,7 @@ class Assistant: finally: self.cleanup() + def process_message(assistant, message): assistant.messages.append({"role": "user", "content": message}) @@ -453,9 +578,13 @@ def process_message(assistant, message): print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}") response = call_api( - assistant.messages, assistant.model, assistant.api_url, - assistant.api_key, assistant.use_tools, get_tools_definition(), - verbose=assistant.verbose + assistant.messages, + assistant.model, + assistant.api_url, + assistant.api_key, + assistant.use_tools, + get_tools_definition(), + verbose=assistant.verbose, ) result = assistant.process_response(response) diff --git a/pr/core/autonomous_interactions.py b/pr/core/autonomous_interactions.py index ff2309d..067a432 100644 --- a/pr/core/autonomous_interactions.py +++ b/pr/core/autonomous_interactions.py @@ -1,7 +1,12 @@ -import time import threading -from pr.core.background_monitor import get_global_monitor -from pr.tools.interactive_control import list_active_sessions, get_session_status, read_session_output +import time + +from pr.tools.interactive_control import ( + get_session_status, + list_active_sessions, + read_session_output, +) + class AutonomousInteractions: def __init__(self, interaction_interval=10.0): @@ -16,7 +21,9 @@ class AutonomousInteractions: self.llm_callback = llm_callback if self.interaction_thread is None: self.active = True - self.interaction_thread = threading.Thread(target=self._interaction_loop, daemon=True) + self.interaction_thread = threading.Thread( + target=self._interaction_loop, daemon=True + ) self.interaction_thread.start() def stop(self): @@ -48,7 +55,9 @@ class AutonomousInteractions: if not sessions: return # No active sessions - sessions_needing_attention = self._identify_sessions_needing_attention(sessions) + sessions_needing_attention = self._identify_sessions_needing_attention( + sessions + ) if sessions_needing_attention and self.llm_callback: # Format session updates for LLM @@ -63,26 +72,30 @@ class AutonomousInteractions: needing_attention = [] for session_name, session_data in sessions.items(): - metadata = session_data['metadata'] - output_summary = session_data['output_summary'] + metadata = session_data["metadata"] + output_summary = session_data["output_summary"] # Criteria for needing attention: # 1. Recent output activity - time_since_activity = time.time() - metadata.get('last_activity', 0) + time_since_activity = time.time() - metadata.get("last_activity", 0) if time_since_activity < 30: # Activity in last 30 seconds needing_attention.append(session_name) continue # 2. High output volume (potential completion or error) - total_lines = output_summary['stdout_lines'] + output_summary['stderr_lines'] + total_lines = ( + output_summary["stdout_lines"] + output_summary["stderr_lines"] + ) if total_lines > 50: # Arbitrary threshold needing_attention.append(session_name) continue # 3. Long-running sessions that might need intervention - session_age = time.time() - metadata.get('start_time', 0) - if session_age > 300 and time_since_activity > 60: # 5+ minutes old, inactive for 1+ minute + session_age = time.time() - metadata.get("start_time", 0) + if ( + session_age > 300 and time_since_activity > 60 + ): # 5+ minutes old, inactive for 1+ minute needing_attention.append(session_name) continue @@ -95,18 +108,18 @@ class AutonomousInteractions: def _session_looks_stuck(self, session_name, session_data): """Determine if a session appears to be stuck waiting for input.""" - metadata = session_data['metadata'] + metadata = session_data["metadata"] # Check if process is still running status = get_session_status(session_name) - if not status or not status.get('is_active', False): + if not status or not status.get("is_active", False): return False - time_since_activity = time.time() - metadata.get('last_activity', 0) - interaction_count = metadata.get('interaction_count', 0) + time_since_activity = time.time() - metadata.get("last_activity", 0) + interaction_count = metadata.get("interaction_count", 0) # If running for a while but no interactions, might be waiting - session_age = time.time() - metadata.get('start_time', 0) + session_age = time.time() - metadata.get("start_time", 0) if session_age > 60 and interaction_count == 0 and time_since_activity > 30: return True @@ -119,9 +132,9 @@ class AutonomousInteractions: def _format_session_updates(self, session_names): """Format session information for LLM consumption.""" updates = { - 'type': 'background_session_updates', - 'timestamp': time.time(), - 'sessions': {} + "type": "background_session_updates", + "timestamp": time.time(), + "sessions": {}, } for session_name in session_names: @@ -131,12 +144,12 @@ class AutonomousInteractions: try: recent_output = read_session_output(session_name, lines=20) except: - recent_output = {'stdout': '', 'stderr': ''} + recent_output = {"stdout": "", "stderr": ""} - updates['sessions'][session_name] = { - 'status': status, - 'recent_output': recent_output, - 'summary': self._create_session_summary(status, recent_output) + updates["sessions"][session_name] = { + "status": status, + "recent_output": recent_output, + "summary": self._create_session_summary(status, recent_output), } return updates @@ -145,34 +158,39 @@ class AutonomousInteractions: """Create a human-readable summary of session status.""" summary_parts = [] - process_type = status.get('metadata', {}).get('process_type', 'unknown') + process_type = status.get("metadata", {}).get("process_type", "unknown") summary_parts.append(f"Type: {process_type}") - is_active = status.get('is_active', False) + is_active = status.get("is_active", False) summary_parts.append(f"Status: {'Active' if is_active else 'Inactive'}") - if is_active and 'pid' in status: + if is_active and "pid" in status: summary_parts.append(f"PID: {status['pid']}") - age = time.time() - status.get('metadata', {}).get('start_time', 0) + age = time.time() - status.get("metadata", {}).get("start_time", 0) summary_parts.append(f"Age: {age:.1f}s") - output_lines = len(recent_output.get('stdout', '').split('\n')) + len(recent_output.get('stderr', '').split('\n')) + output_lines = len(recent_output.get("stdout", "").split("\n")) + len( + recent_output.get("stderr", "").split("\n") + ) summary_parts.append(f"Recent output: {output_lines} lines") - interaction_count = status.get('metadata', {}).get('interaction_count', 0) + interaction_count = status.get("metadata", {}).get("interaction_count", 0) summary_parts.append(f"Interactions: {interaction_count}") return " | ".join(summary_parts) + # Global autonomous interactions instance _global_autonomous = None + def get_global_autonomous(): """Get the global autonomous interactions instance.""" global _global_autonomous return _global_autonomous + def start_global_autonomous(llm_callback=None): """Start global autonomous interactions.""" global _global_autonomous @@ -181,6 +199,7 @@ def start_global_autonomous(llm_callback=None): _global_autonomous.start(llm_callback) return _global_autonomous + def stop_global_autonomous(): """Stop global autonomous interactions.""" global _global_autonomous diff --git a/pr/core/background_monitor.py b/pr/core/background_monitor.py index 12fe812..f898b99 100644 --- a/pr/core/background_monitor.py +++ b/pr/core/background_monitor.py @@ -1,8 +1,9 @@ +import queue import threading import time -import queue + from pr.multiplexer import get_all_multiplexer_states, get_multiplexer -from pr.tools.interactive_control import get_session_status + class BackgroundMonitor: def __init__(self, check_interval=5.0): @@ -17,7 +18,9 @@ class BackgroundMonitor: """Start the background monitoring thread.""" if self.monitor_thread is None: self.active = True - self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) + self.monitor_thread = threading.Thread( + target=self._monitor_loop, daemon=True + ) self.monitor_thread.start() def stop(self): @@ -78,19 +81,18 @@ class BackgroundMonitor: # Check for new sessions for session_name in new_states: if session_name not in old_states: - events.append({ - 'type': 'session_started', - 'session_name': session_name, - 'metadata': new_states[session_name]['metadata'] - }) + events.append( + { + "type": "session_started", + "session_name": session_name, + "metadata": new_states[session_name]["metadata"], + } + ) # Check for ended sessions for session_name in old_states: if session_name not in new_states: - events.append({ - 'type': 'session_ended', - 'session_name': session_name - }) + events.append({"type": "session_ended", "session_name": session_name}) # Check for activity in existing sessions for session_name, new_state in new_states.items(): @@ -98,92 +100,112 @@ class BackgroundMonitor: old_state = old_states[session_name] # Check for output changes - old_stdout_lines = old_state['output_summary']['stdout_lines'] - new_stdout_lines = new_state['output_summary']['stdout_lines'] - old_stderr_lines = old_state['output_summary']['stderr_lines'] - new_stderr_lines = new_state['output_summary']['stderr_lines'] + old_stdout_lines = old_state["output_summary"]["stdout_lines"] + new_stdout_lines = new_state["output_summary"]["stdout_lines"] + old_stderr_lines = old_state["output_summary"]["stderr_lines"] + new_stderr_lines = new_state["output_summary"]["stderr_lines"] - if new_stdout_lines > old_stdout_lines or new_stderr_lines > old_stderr_lines: + if ( + new_stdout_lines > old_stdout_lines + or new_stderr_lines > old_stderr_lines + ): # Get the new output mux = get_multiplexer(session_name) if mux: all_output = mux.get_all_output() new_output = { - 'stdout': all_output['stdout'].split('\n')[old_stdout_lines:], - 'stderr': all_output['stderr'].split('\n')[old_stderr_lines:] + "stdout": all_output["stdout"].split("\n")[ + old_stdout_lines: + ], + "stderr": all_output["stderr"].split("\n")[ + old_stderr_lines: + ], } - events.append({ - 'type': 'output_received', - 'session_name': session_name, - 'new_output': new_output, - 'total_lines': { - 'stdout': new_stdout_lines, - 'stderr': new_stderr_lines + events.append( + { + "type": "output_received", + "session_name": session_name, + "new_output": new_output, + "total_lines": { + "stdout": new_stdout_lines, + "stderr": new_stderr_lines, + }, } - }) + ) # Check for state changes - old_metadata = old_state['metadata'] - new_metadata = new_state['metadata'] + old_metadata = old_state["metadata"] + new_metadata = new_state["metadata"] - if old_metadata.get('state') != new_metadata.get('state'): - events.append({ - 'type': 'state_changed', - 'session_name': session_name, - 'old_state': old_metadata.get('state'), - 'new_state': new_metadata.get('state') - }) + if old_metadata.get("state") != new_metadata.get("state"): + events.append( + { + "type": "state_changed", + "session_name": session_name, + "old_state": old_metadata.get("state"), + "new_state": new_metadata.get("state"), + } + ) # Check for process type identification - if (old_metadata.get('process_type') == 'unknown' and - new_metadata.get('process_type') != 'unknown'): - events.append({ - 'type': 'process_identified', - 'session_name': session_name, - 'process_type': new_metadata.get('process_type') - }) + if ( + old_metadata.get("process_type") == "unknown" + and new_metadata.get("process_type") != "unknown" + ): + events.append( + { + "type": "process_identified", + "session_name": session_name, + "process_type": new_metadata.get("process_type"), + } + ) # Check for sessions needing attention (based on heuristics) for session_name, state in new_states.items(): - metadata = state['metadata'] - output_summary = state['output_summary'] + metadata = state["metadata"] + output_summary = state["output_summary"] # Heuristic: High output volume might indicate completion or error - total_lines = output_summary['stdout_lines'] + output_summary['stderr_lines'] + total_lines = ( + output_summary["stdout_lines"] + output_summary["stderr_lines"] + ) if total_lines > 100: # Arbitrary threshold - events.append({ - 'type': 'high_output_volume', - 'session_name': session_name, - 'total_lines': total_lines - }) + events.append( + { + "type": "high_output_volume", + "session_name": session_name, + "total_lines": total_lines, + } + ) # Heuristic: Long-running session without recent activity - time_since_activity = time.time() - metadata.get('last_activity', 0) + time_since_activity = time.time() - metadata.get("last_activity", 0) if time_since_activity > 300: # 5 minutes - events.append({ - 'type': 'inactive_session', - 'session_name': session_name, - 'inactive_seconds': time_since_activity - }) + events.append( + { + "type": "inactive_session", + "session_name": session_name, + "inactive_seconds": time_since_activity, + } + ) # Heuristic: Sessions that might be waiting for input # This would be enhanced with prompt detection in later phases if self._might_be_waiting_for_input(session_name, state): - events.append({ - 'type': 'possible_input_needed', - 'session_name': session_name - }) + events.append( + {"type": "possible_input_needed", "session_name": session_name} + ) return events def _might_be_waiting_for_input(self, session_name, state): """Heuristic to detect if a session might be waiting for input.""" - metadata = state['metadata'] - process_type = metadata.get('process_type', 'unknown') + metadata = state["metadata"] + metadata.get("process_type", "unknown") # Simple heuristics based on process type and recent activity - time_since_activity = time.time() - metadata.get('last_activity', 0) + time_since_activity = time.time() - metadata.get("last_activity", 0) # If it's been more than 10 seconds since last activity, might be waiting if time_since_activity > 10: @@ -191,9 +213,11 @@ class BackgroundMonitor: return False + # Global monitor instance _global_monitor = None + def get_global_monitor(): """Get the global background monitor instance.""" global _global_monitor @@ -201,20 +225,24 @@ def get_global_monitor(): _global_monitor = BackgroundMonitor() return _global_monitor + def start_global_monitor(): """Start the global background monitor.""" monitor = get_global_monitor() monitor.start() + def stop_global_monitor(): """Stop the global background monitor.""" global _global_monitor if _global_monitor: _global_monitor.stop() + # Global monitor instance _global_monitor = None + def start_global_monitor(): """Start the global background monitor.""" global _global_monitor @@ -223,6 +251,7 @@ def start_global_monitor(): _global_monitor.start() return _global_monitor + def stop_global_monitor(): """Stop the global background monitor.""" global _global_monitor @@ -230,6 +259,7 @@ def stop_global_monitor(): _global_monitor.stop() _global_monitor = None + def get_global_monitor(): """Get the global background monitor instance.""" global _global_monitor diff --git a/pr/core/config_loader.py b/pr/core/config_loader.py index c1d812d..790ba56 100644 --- a/pr/core/config_loader.py +++ b/pr/core/config_loader.py @@ -1,22 +1,17 @@ -import os import configparser -from typing import Dict, Any +import os +from typing import Any, Dict + from pr.core.logging import get_logger -logger = get_logger('config') +logger = get_logger("config") CONFIG_FILE = os.path.expanduser("~/.prrc") LOCAL_CONFIG_FILE = ".prrc" def load_config() -> Dict[str, Any]: - config = { - 'api': {}, - 'autonomous': {}, - 'ui': {}, - 'output': {}, - 'session': {} - } + config = {"api": {}, "autonomous": {}, "ui": {}, "output": {}, "session": {}} global_config = _load_config_file(CONFIG_FILE) local_config = _load_config_file(LOCAL_CONFIG_FILE) @@ -55,9 +50,9 @@ def _load_config_file(filepath: str) -> Dict[str, Dict[str, Any]]: def _parse_value(value: str) -> Any: value = value.strip() - if value.lower() == 'true': + if value.lower() == "true": return True - if value.lower() == 'false': + if value.lower() == "false": return False if value.isdigit(): @@ -99,7 +94,7 @@ max_history = 1000 """ try: - with open(filepath, 'w') as f: + with open(filepath, "w") as f: f.write(default_config) logger.info(f"Created default configuration at {filepath}") return True diff --git a/pr/core/context.py b/pr/core/context.py index 9765339..4cbc04f 100644 --- a/pr/core/context.py +++ b/pr/core/context.py @@ -1,11 +1,21 @@ -import os import json import logging -from pr.config import (CONTEXT_FILE, GLOBAL_CONTEXT_FILE, CONTEXT_COMPRESSION_THRESHOLD, - RECENT_MESSAGES_TO_KEEP, MAX_TOKENS_LIMIT, CHARS_PER_TOKEN, - EMERGENCY_MESSAGES_TO_KEEP, CONTENT_TRIM_LENGTH, MAX_TOOL_RESULT_LENGTH) +import os + +from pr.config import ( + CHARS_PER_TOKEN, + CONTENT_TRIM_LENGTH, + CONTEXT_COMPRESSION_THRESHOLD, + CONTEXT_FILE, + EMERGENCY_MESSAGES_TO_KEEP, + GLOBAL_CONTEXT_FILE, + MAX_TOKENS_LIMIT, + MAX_TOOL_RESULT_LENGTH, + RECENT_MESSAGES_TO_KEEP, +) from pr.ui import Colors + def truncate_tool_result(result, max_length=None): if max_length is None: max_length = MAX_TOOL_RESULT_LENGTH @@ -17,24 +27,36 @@ def truncate_tool_result(result, max_length=None): if "output" in result_copy and isinstance(result_copy["output"], str): if len(result_copy["output"]) > max_length: - result_copy["output"] = result_copy["output"][:max_length] + f"\n... [truncated {len(result_copy['output']) - max_length} chars]" + result_copy["output"] = ( + result_copy["output"][:max_length] + + f"\n... [truncated {len(result_copy['output']) - max_length} chars]" + ) if "content" in result_copy and isinstance(result_copy["content"], str): if len(result_copy["content"]) > max_length: - result_copy["content"] = result_copy["content"][:max_length] + f"\n... [truncated {len(result_copy['content']) - max_length} chars]" + result_copy["content"] = ( + result_copy["content"][:max_length] + + f"\n... [truncated {len(result_copy['content']) - max_length} chars]" + ) if "data" in result_copy and isinstance(result_copy["data"], str): if len(result_copy["data"]) > max_length: - result_copy["data"] = result_copy["data"][:max_length] + f"\n... [truncated]" + result_copy["data"] = ( + result_copy["data"][:max_length] + f"\n... [truncated]" + ) if "error" in result_copy and isinstance(result_copy["error"], str): if len(result_copy["error"]) > max_length // 2: - result_copy["error"] = result_copy["error"][:max_length // 2] + "... [truncated]" + result_copy["error"] = ( + result_copy["error"][: max_length // 2] + "... [truncated]" + ) return result_copy + def init_system_message(args): - context_parts = ["""You are a professional AI assistant with access to advanced tools. + context_parts = [ + """You are a professional AI assistant with access to advanced tools. File Operations: - Use RPEditor tools (open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor) for precise file modifications @@ -51,14 +73,15 @@ Process Management: Shell Commands: - Be a shell ninja using native OS tools - Prefer standard Unix utilities over complex scripts -- Use run_command_interactive for commands requiring user input (vim, nano, etc.)"""] - #context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."] +- Use run_command_interactive for commands requiring user input (vim, nano, etc.)""" + ] + # context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."] max_context_size = 10000 if args.include_env: env_context = "Environment Variables:\n" for key, value in os.environ.items(): - if not key.startswith('_'): + if not key.startswith("_"): env_context += f"{key}={value}\n" if len(env_context) > max_context_size: env_context = env_context[:max_context_size] + "\n... [truncated]" @@ -67,7 +90,7 @@ Shell Commands: for context_file in [CONTEXT_FILE, GLOBAL_CONTEXT_FILE]: if os.path.exists(context_file): try: - with open(context_file, 'r') as f: + with open(context_file) as f: content = f.read() if len(content) > max_context_size: content = content[:max_context_size] + "\n... [truncated]" @@ -78,7 +101,7 @@ Shell Commands: if args.context: for ctx_file in args.context: try: - with open(ctx_file, 'r') as f: + with open(ctx_file) as f: content = f.read() if len(content) > max_context_size: content = content[:max_context_size] + "\n... [truncated]" @@ -88,22 +111,29 @@ Shell Commands: system_message = "\n\n".join(context_parts) if len(system_message) > max_context_size * 3: - system_message = system_message[:max_context_size * 3] + "\n... [system message truncated]" + system_message = ( + system_message[: max_context_size * 3] + "\n... [system message truncated]" + ) return {"role": "system", "content": system_message} + def should_compress_context(messages): return len(messages) > CONTEXT_COMPRESSION_THRESHOLD + def compress_context(messages): return manage_context_window(messages, verbose=False) + def manage_context_window(messages, verbose): if len(messages) <= CONTEXT_COMPRESSION_THRESHOLD: return messages if verbose: - print(f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}") + print( + f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}" + ) system_message = messages[0] recent_messages = messages[-RECENT_MESSAGES_TO_KEEP:] @@ -113,18 +143,21 @@ def manage_context_window(messages, verbose): summary = summarize_messages(middle_messages) summary_message = { "role": "system", - "content": f"[Previous conversation summary: {summary}]" + "content": f"[Previous conversation summary: {summary}]", } new_messages = [system_message, summary_message] + recent_messages if verbose: - print(f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}") + print( + f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}" + ) return new_messages return messages + def summarize_messages(messages): summary_parts = [] @@ -142,6 +175,7 @@ def summarize_messages(messages): return " | ".join(summary_parts[:10]) + def estimate_tokens(messages): total_chars = 0 @@ -155,6 +189,7 @@ def estimate_tokens(messages): return int(estimated_tokens * overhead_multiplier) + def trim_message_content(message, max_length): trimmed_msg = message.copy() @@ -162,14 +197,22 @@ def trim_message_content(message, max_length): content = trimmed_msg["content"] if isinstance(content, str) and len(content) > max_length: - trimmed_msg["content"] = content[:max_length] + f"\n... [trimmed {len(content) - max_length} chars]" + trimmed_msg["content"] = ( + content[:max_length] + + f"\n... [trimmed {len(content) - max_length} chars]" + ) elif isinstance(content, list): trimmed_content = [] for item in content: if isinstance(item, dict): trimmed_item = item.copy() - if "text" in trimmed_item and len(trimmed_item["text"]) > max_length: - trimmed_item["text"] = trimmed_item["text"][:max_length] + f"\n... [trimmed]" + if ( + "text" in trimmed_item + and len(trimmed_item["text"]) > max_length + ): + trimmed_item["text"] = ( + trimmed_item["text"][:max_length] + f"\n... [trimmed]" + ) trimmed_content.append(trimmed_item) else: trimmed_content.append(item) @@ -179,35 +222,61 @@ def trim_message_content(message, max_length): if "content" in trimmed_msg and isinstance(trimmed_msg["content"], str): content = trimmed_msg["content"] if len(content) > MAX_TOOL_RESULT_LENGTH: - trimmed_msg["content"] = content[:MAX_TOOL_RESULT_LENGTH] + f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]" + trimmed_msg["content"] = ( + content[:MAX_TOOL_RESULT_LENGTH] + + f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]" + ) try: parsed = json.loads(content) if isinstance(parsed, dict): - if "output" in parsed and isinstance(parsed["output"], str) and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2: - parsed["output"] = parsed["output"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]" - if "content" in parsed and isinstance(parsed["content"], str) and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2: - parsed["content"] = parsed["content"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]" + if ( + "output" in parsed + and isinstance(parsed["output"], str) + and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2 + ): + parsed["output"] = ( + parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2] + + f"\n... [truncated]" + ) + if ( + "content" in parsed + and isinstance(parsed["content"], str) + and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2 + ): + parsed["content"] = ( + parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2] + + f"\n... [truncated]" + ) trimmed_msg["content"] = json.dumps(parsed) except: pass return trimmed_msg + def intelligently_trim_messages(messages, target_tokens, keep_recent=3): if estimate_tokens(messages) <= target_tokens: return messages - system_msg = messages[0] if messages and messages[0].get("role") == "system" else None + system_msg = ( + messages[0] if messages and messages[0].get("role") == "system" else None + ) start_idx = 1 if system_msg else 0 - recent_messages = messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:] - middle_messages = messages[start_idx:-keep_recent] if len(messages) > keep_recent else [] + recent_messages = ( + messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:] + ) + middle_messages = ( + messages[start_idx:-keep_recent] if len(messages) > keep_recent else [] + ) trimmed_middle = [] for msg in middle_messages: if msg.get("role") == "tool": - trimmed_middle.append(trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2)) + trimmed_middle.append( + trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2) + ) elif msg.get("role") in ["user", "assistant"]: trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH)) else: @@ -233,6 +302,7 @@ def intelligently_trim_messages(messages, target_tokens, keep_recent=3): return ([system_msg] if system_msg else []) + messages[-1:] + def auto_slim_messages(messages, verbose=False): estimated_tokens = estimate_tokens(messages) @@ -240,29 +310,46 @@ def auto_slim_messages(messages, verbose=False): return messages if verbose: - print(f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}") - print(f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}") + print( + f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}" + ) + print( + f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}" + ) - result = intelligently_trim_messages(messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP) + result = intelligently_trim_messages( + messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP + ) final_tokens = estimate_tokens(result) if final_tokens > MAX_TOKENS_LIMIT: if verbose: - print(f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}") + print( + f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}" + ) result = emergency_reduce_messages(result, MAX_TOKENS_LIMIT, verbose) final_tokens = estimate_tokens(result) if verbose: removed_count = len(messages) - len(result) - print(f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}") - print(f"{Colors.GREEN} Token estimate: {estimated_tokens} → {final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}") + print( + f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}" + ) + print( + f"{Colors.GREEN} Token estimate: {estimated_tokens} → {final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}" + ) if removed_count > 0: - print(f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}") + print( + f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}" + ) return result + def emergency_reduce_messages(messages, target_tokens, verbose=False): - system_msg = messages[0] if messages and messages[0].get("role") == "system" else None + system_msg = ( + messages[0] if messages and messages[0].get("role") == "system" else None + ) start_idx = 1 if system_msg else 0 keep_count = 2 diff --git a/pr/core/enhanced_assistant.py b/pr/core/enhanced_assistant.py index 412befb..c36e732 100644 --- a/pr/core/enhanced_assistant.py +++ b/pr/core/enhanced_assistant.py @@ -1,22 +1,29 @@ -import logging import json +import logging import uuid -from typing import Optional, Dict, Any, List -from pr.config import ( - DB_PATH, CACHE_ENABLED, API_CACHE_TTL, TOOL_CACHE_TTL, - WORKFLOW_EXECUTOR_MAX_WORKERS, AGENT_MAX_WORKERS, - KNOWLEDGE_SEARCH_LIMIT, ADVANCED_CONTEXT_ENABLED, - MEMORY_AUTO_SUMMARIZE, CONVERSATION_SUMMARY_THRESHOLD -) -from pr.cache import APICache, ToolCache -from pr.workflows import WorkflowEngine, WorkflowStorage +from typing import Any, Dict, List, Optional + from pr.agents import AgentManager -from pr.memory import KnowledgeStore, ConversationMemory, FactExtractor +from pr.cache import APICache, ToolCache +from pr.config import ( + ADVANCED_CONTEXT_ENABLED, + API_CACHE_TTL, + CACHE_ENABLED, + CONVERSATION_SUMMARY_THRESHOLD, + DB_PATH, + KNOWLEDGE_SEARCH_LIMIT, + MEMORY_AUTO_SUMMARIZE, + TOOL_CACHE_TTL, + WORKFLOW_EXECUTOR_MAX_WORKERS, +) from pr.core.advanced_context import AdvancedContextManager from pr.core.api import call_api +from pr.memory import ConversationMemory, FactExtractor, KnowledgeStore from pr.tools.base import get_tools_definition +from pr.workflows import WorkflowEngine, WorkflowStorage + +logger = logging.getLogger("pr") -logger = logging.getLogger('pr') class EnhancedAssistant: def __init__(self, base_assistant): @@ -32,7 +39,7 @@ class EnhancedAssistant: self.workflow_storage = WorkflowStorage(DB_PATH) self.workflow_engine = WorkflowEngine( tool_executor=self._execute_tool_for_workflow, - max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS + max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS, ) self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent) @@ -44,20 +51,21 @@ class EnhancedAssistant: if ADVANCED_CONTEXT_ENABLED: self.context_manager = AdvancedContextManager( knowledge_store=self.knowledge_store, - conversation_memory=self.conversation_memory + conversation_memory=self.conversation_memory, ) else: self.context_manager = None self.current_conversation_id = str(uuid.uuid4())[:16] self.conversation_memory.create_conversation( - self.current_conversation_id, - session_id=str(uuid.uuid4())[:16] + self.current_conversation_id, session_id=str(uuid.uuid4())[:16] ) logger.info("Enhanced Assistant initialized with all features") - def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any: + def _execute_tool_for_workflow( + self, tool_name: str, arguments: Dict[str, Any] + ) -> Any: if self.tool_cache: cached_result = self.tool_cache.get(tool_name, arguments) if cached_result is not None: @@ -65,41 +73,66 @@ class EnhancedAssistant: return cached_result func_map = { - 'read_file': lambda **kw: self.base.execute_tool_calls([{ - 'id': 'temp', - 'function': {'name': 'read_file', 'arguments': json.dumps(kw)} - }])[0], - 'write_file': lambda **kw: self.base.execute_tool_calls([{ - 'id': 'temp', - 'function': {'name': 'write_file', 'arguments': json.dumps(kw)} - }])[0], - 'list_directory': lambda **kw: self.base.execute_tool_calls([{ - 'id': 'temp', - 'function': {'name': 'list_directory', 'arguments': json.dumps(kw)} - }])[0], - 'run_command': lambda **kw: self.base.execute_tool_calls([{ - 'id': 'temp', - 'function': {'name': 'run_command', 'arguments': json.dumps(kw)} - }])[0], + "read_file": lambda **kw: self.base.execute_tool_calls( + [ + { + "id": "temp", + "function": {"name": "read_file", "arguments": json.dumps(kw)}, + } + ] + )[0], + "write_file": lambda **kw: self.base.execute_tool_calls( + [ + { + "id": "temp", + "function": {"name": "write_file", "arguments": json.dumps(kw)}, + } + ] + )[0], + "list_directory": lambda **kw: self.base.execute_tool_calls( + [ + { + "id": "temp", + "function": { + "name": "list_directory", + "arguments": json.dumps(kw), + }, + } + ] + )[0], + "run_command": lambda **kw: self.base.execute_tool_calls( + [ + { + "id": "temp", + "function": { + "name": "run_command", + "arguments": json.dumps(kw), + }, + } + ] + )[0], } if tool_name in func_map: result = func_map[tool_name](**arguments) if self.tool_cache: - content = result.get('content', '') + content = result.get("content", "") try: - parsed_content = json.loads(content) if isinstance(content, str) else content + parsed_content = ( + json.loads(content) if isinstance(content, str) else content + ) self.tool_cache.set(tool_name, arguments, parsed_content) except Exception: pass return result - return {'error': f'Unknown tool: {tool_name}'} + return {"error": f"Unknown tool: {tool_name}"} - def _api_caller_for_agent(self, messages: List[Dict[str, Any]], - temperature: float, max_tokens: int) -> Dict[str, Any]: + def _api_caller_for_agent( + self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int + ) -> Dict[str, Any]: return call_api( messages, self.base.model, @@ -109,15 +142,12 @@ class EnhancedAssistant: tools=None, temperature=temperature, max_tokens=max_tokens, - verbose=self.base.verbose + verbose=self.base.verbose, ) def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]: if self.api_cache and CACHE_ENABLED: - cached_response = self.api_cache.get( - self.base.model, messages, - 0.7, 4096 - ) + cached_response = self.api_cache.get(self.base.model, messages, 0.7, 4096) if cached_response: logger.debug("API cache hit") return cached_response @@ -129,15 +159,13 @@ class EnhancedAssistant: self.base.api_key, self.base.use_tools, get_tools_definition(), - verbose=self.base.verbose + verbose=self.base.verbose, ) - if self.api_cache and CACHE_ENABLED and 'error' not in response: - token_count = response.get('usage', {}).get('total_tokens', 0) + if self.api_cache and CACHE_ENABLED and "error" not in response: + token_count = response.get("usage", {}).get("total_tokens", 0) self.api_cache.set( - self.base.model, messages, - 0.7, 4096, - response, token_count + self.base.model, messages, 0.7, 4096, response, token_count ) return response @@ -146,35 +174,33 @@ class EnhancedAssistant: self.base.messages.append({"role": "user", "content": user_message}) self.conversation_memory.add_message( - self.current_conversation_id, - str(uuid.uuid4())[:16], - 'user', - user_message + self.current_conversation_id, str(uuid.uuid4())[:16], "user", user_message ) if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0: facts = self.fact_extractor.extract_facts(user_message) for fact in facts[:3]: entry_id = str(uuid.uuid4())[:16] - from pr.memory import KnowledgeEntry import time - categories = self.fact_extractor.categorize_content(fact['text']) + from pr.memory import KnowledgeEntry + + categories = self.fact_extractor.categorize_content(fact["text"]) entry = KnowledgeEntry( entry_id=entry_id, - category=categories[0] if categories else 'general', - content=fact['text'], - metadata={'type': fact['type'], 'confidence': fact['confidence']}, + category=categories[0] if categories else "general", + content=fact["text"], + metadata={"type": fact["type"], "confidence": fact["confidence"]}, created_at=time.time(), - updated_at=time.time() + updated_at=time.time(), ) self.knowledge_store.add_entry(entry) if self.context_manager and ADVANCED_CONTEXT_ENABLED: - enhanced_messages, context_info = self.context_manager.create_enhanced_context( - self.base.messages, - user_message, - include_knowledge=True + enhanced_messages, context_info = ( + self.context_manager.create_enhanced_context( + self.base.messages, user_message, include_knowledge=True + ) ) if self.base.verbose: @@ -189,38 +215,40 @@ class EnhancedAssistant: result = self.base.process_response(response) if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD: - summary = self.context_manager.advanced_summarize_messages( - self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:] - ) if self.context_manager else "Conversation in progress" + summary = ( + self.context_manager.advanced_summarize_messages( + self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:] + ) + if self.context_manager + else "Conversation in progress" + ) topics = self.fact_extractor.categorize_content(summary) self.conversation_memory.update_conversation_summary( - self.current_conversation_id, - summary, - topics + self.current_conversation_id, summary, topics ) return result - def execute_workflow(self, workflow_name: str, - initial_variables: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def execute_workflow( + self, workflow_name: str, initial_variables: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: workflow = self.workflow_storage.load_workflow_by_name(workflow_name) if not workflow: - return {'error': f'Workflow "{workflow_name}" not found'} + return {"error": f'Workflow "{workflow_name}" not found'} context = self.workflow_engine.execute_workflow(workflow, initial_variables) execution_id = self.workflow_storage.save_execution( - self.workflow_storage.load_workflow_by_name(workflow_name).name, - context + self.workflow_storage.load_workflow_by_name(workflow_name).name, context ) return { - 'success': True, - 'execution_id': execution_id, - 'results': context.step_results, - 'execution_log': context.execution_log + "success": True, + "execution_id": execution_id, + "results": context.step_results, + "execution_log": context.execution_log, } def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str: @@ -230,20 +258,22 @@ class EnhancedAssistant: return self.agent_manager.execute_agent_task(agent_id, task) def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]: - orchestrator_id = self.agent_manager.create_agent('orchestrator') + orchestrator_id = self.agent_manager.create_agent("orchestrator") return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles) - def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]: + def search_knowledge( + self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT + ) -> List[Any]: return self.knowledge_store.search_entries(query, top_k=limit) def get_cache_statistics(self) -> Dict[str, Any]: stats = {} if self.api_cache: - stats['api_cache'] = self.api_cache.get_statistics() + stats["api_cache"] = self.api_cache.get_statistics() if self.tool_cache: - stats['tool_cache'] = self.tool_cache.get_statistics() + stats["tool_cache"] = self.tool_cache.get_statistics() return stats diff --git a/pr/core/logging.py b/pr/core/logging.py index 016efb4..4df2153 100644 --- a/pr/core/logging.py +++ b/pr/core/logging.py @@ -1,6 +1,7 @@ import logging import os from logging.handlers import RotatingFileHandler + from pr.config import LOG_FILE @@ -9,21 +10,19 @@ def setup_logging(verbose=False): if log_dir and not os.path.exists(log_dir): os.makedirs(log_dir, exist_ok=True) - logger = logging.getLogger('pr') + logger = logging.getLogger("pr") logger.setLevel(logging.DEBUG if verbose else logging.INFO) if logger.handlers: logger.handlers.clear() file_handler = RotatingFileHandler( - LOG_FILE, - maxBytes=10 * 1024 * 1024, - backupCount=5 + LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5 ) file_handler.setLevel(logging.DEBUG) file_formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' + "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", ) file_handler.setFormatter(file_formatter) logger.addHandler(file_handler) @@ -31,9 +30,7 @@ def setup_logging(verbose=False): if verbose: console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) - console_formatter = logging.Formatter( - '%(levelname)s: %(message)s' - ) + console_formatter = logging.Formatter("%(levelname)s: %(message)s") console_handler.setFormatter(console_formatter) logger.addHandler(console_handler) @@ -42,5 +39,5 @@ def setup_logging(verbose=False): def get_logger(name=None): if name: - return logging.getLogger(f'pr.{name}') - return logging.getLogger('pr') + return logging.getLogger(f"pr.{name}") + return logging.getLogger("pr") diff --git a/pr/core/session.py b/pr/core/session.py index 8172440..367400a 100644 --- a/pr/core/session.py +++ b/pr/core/session.py @@ -2,9 +2,10 @@ import json import os from datetime import datetime from typing import Dict, List, Optional + from pr.core.logging import get_logger -logger = get_logger('session') +logger = get_logger("session") SESSIONS_DIR = os.path.expanduser("~/.assistant_sessions") @@ -14,18 +15,20 @@ class SessionManager: def __init__(self): os.makedirs(SESSIONS_DIR, exist_ok=True) - def save_session(self, name: str, messages: List[Dict], metadata: Optional[Dict] = None) -> bool: + def save_session( + self, name: str, messages: List[Dict], metadata: Optional[Dict] = None + ) -> bool: try: session_file = os.path.join(SESSIONS_DIR, f"{name}.json") session_data = { - 'name': name, - 'created_at': datetime.now().isoformat(), - 'messages': messages, - 'metadata': metadata or {} + "name": name, + "created_at": datetime.now().isoformat(), + "messages": messages, + "metadata": metadata or {}, } - with open(session_file, 'w') as f: + with open(session_file, "w") as f: json.dump(session_data, f, indent=2) logger.info(f"Session saved: {name}") @@ -43,7 +46,7 @@ class SessionManager: logger.warning(f"Session not found: {name}") return None - with open(session_file, 'r') as f: + with open(session_file) as f: session_data = json.load(f) logger.info(f"Session loaded: {name}") @@ -58,22 +61,24 @@ class SessionManager: try: for filename in os.listdir(SESSIONS_DIR): - if filename.endswith('.json'): + if filename.endswith(".json"): filepath = os.path.join(SESSIONS_DIR, filename) try: - with open(filepath, 'r') as f: + with open(filepath) as f: data = json.load(f) - sessions.append({ - 'name': data.get('name', filename[:-5]), - 'created_at': data.get('created_at', 'unknown'), - 'message_count': len(data.get('messages', [])), - 'metadata': data.get('metadata', {}) - }) + sessions.append( + { + "name": data.get("name", filename[:-5]), + "created_at": data.get("created_at", "unknown"), + "message_count": len(data.get("messages", [])), + "metadata": data.get("metadata", {}), + } + ) except Exception as e: logger.warning(f"Error reading session file {filename}: {e}") - sessions.sort(key=lambda x: x['created_at'], reverse=True) + sessions.sort(key=lambda x: x["created_at"], reverse=True) except Exception as e: logger.error(f"Error listing sessions: {e}") @@ -96,39 +101,39 @@ class SessionManager: logger.error(f"Error deleting session {name}: {e}") return False - def export_session(self, name: str, output_path: str, format: str = 'json') -> bool: + def export_session(self, name: str, output_path: str, format: str = "json") -> bool: session_data = self.load_session(name) if not session_data: return False try: - if format == 'json': - with open(output_path, 'w') as f: + if format == "json": + with open(output_path, "w") as f: json.dump(session_data, f, indent=2) - elif format == 'markdown': - with open(output_path, 'w') as f: + elif format == "markdown": + with open(output_path, "w") as f: f.write(f"# Session: {name}\n\n") f.write(f"Created: {session_data['created_at']}\n\n") f.write("---\n\n") - for msg in session_data['messages']: - role = msg.get('role', 'unknown') - content = msg.get('content', '') + for msg in session_data["messages"]: + role = msg.get("role", "unknown") + content = msg.get("content", "") f.write(f"## {role.capitalize()}\n\n") f.write(f"{content}\n\n") f.write("---\n\n") - elif format == 'txt': - with open(output_path, 'w') as f: + elif format == "txt": + with open(output_path, "w") as f: f.write(f"Session: {name}\n") f.write(f"Created: {session_data['created_at']}\n") f.write("=" * 80 + "\n\n") - for msg in session_data['messages']: - role = msg.get('role', 'unknown') - content = msg.get('content', '') + for msg in session_data["messages"]: + role = msg.get("role", "unknown") + content = msg.get("content", "") f.write(f"[{role.upper()}]\n") f.write(f"{content}\n") diff --git a/pr/core/usage_tracker.py b/pr/core/usage_tracker.py index 5159d6d..d47eaae 100644 --- a/pr/core/usage_tracker.py +++ b/pr/core/usage_tracker.py @@ -2,20 +2,21 @@ import json import os from datetime import datetime from typing import Dict, Optional + from pr.core.logging import get_logger -logger = get_logger('usage') +logger = get_logger("usage") USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json") MODEL_COSTS = { - 'x-ai/grok-code-fast-1': {'input': 0.0, 'output': 0.0}, - 'gpt-4': {'input': 0.03, 'output': 0.06}, - 'gpt-4-turbo': {'input': 0.01, 'output': 0.03}, - 'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015}, - 'claude-3-opus': {'input': 0.015, 'output': 0.075}, - 'claude-3-sonnet': {'input': 0.003, 'output': 0.015}, - 'claude-3-haiku': {'input': 0.00025, 'output': 0.00125}, + "x-ai/grok-code-fast-1": {"input": 0.0, "output": 0.0}, + "gpt-4": {"input": 0.03, "output": 0.06}, + "gpt-4-turbo": {"input": 0.01, "output": 0.03}, + "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015}, + "claude-3-opus": {"input": 0.015, "output": 0.075}, + "claude-3-sonnet": {"input": 0.003, "output": 0.015}, + "claude-3-haiku": {"input": 0.00025, "output": 0.00125}, } @@ -23,12 +24,12 @@ class UsageTracker: def __init__(self): self.session_usage = { - 'requests': 0, - 'total_tokens': 0, - 'input_tokens': 0, - 'output_tokens': 0, - 'estimated_cost': 0.0, - 'models_used': {} + "requests": 0, + "total_tokens": 0, + "input_tokens": 0, + "output_tokens": 0, + "estimated_cost": 0.0, + "models_used": {}, } def track_request( @@ -36,30 +37,30 @@ class UsageTracker: model: str, input_tokens: int, output_tokens: int, - total_tokens: Optional[int] = None + total_tokens: Optional[int] = None, ): if total_tokens is None: total_tokens = input_tokens + output_tokens - self.session_usage['requests'] += 1 - self.session_usage['total_tokens'] += total_tokens - self.session_usage['input_tokens'] += input_tokens - self.session_usage['output_tokens'] += output_tokens + self.session_usage["requests"] += 1 + self.session_usage["total_tokens"] += total_tokens + self.session_usage["input_tokens"] += input_tokens + self.session_usage["output_tokens"] += output_tokens - if model not in self.session_usage['models_used']: - self.session_usage['models_used'][model] = { - 'requests': 0, - 'tokens': 0, - 'cost': 0.0 + if model not in self.session_usage["models_used"]: + self.session_usage["models_used"][model] = { + "requests": 0, + "tokens": 0, + "cost": 0.0, } - model_usage = self.session_usage['models_used'][model] - model_usage['requests'] += 1 - model_usage['tokens'] += total_tokens + model_usage = self.session_usage["models_used"][model] + model_usage["requests"] += 1 + model_usage["tokens"] += total_tokens cost = self._calculate_cost(model, input_tokens, output_tokens) - model_usage['cost'] += cost - self.session_usage['estimated_cost'] += cost + model_usage["cost"] += cost + self.session_usage["estimated_cost"] += cost self._save_to_history(model, input_tokens, output_tokens, cost) @@ -67,9 +68,11 @@ class UsageTracker: f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}" ) - def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float: + def _calculate_cost( + self, model: str, input_tokens: int, output_tokens: int + ) -> float: if model not in MODEL_COSTS: - base_model = model.split('/')[0] if '/' in model else model + base_model = model.split("/")[0] if "/" in model else model if base_model not in MODEL_COSTS: logger.warning(f"Unknown model for cost calculation: {model}") return 0.0 @@ -77,31 +80,35 @@ class UsageTracker: else: costs = MODEL_COSTS[model] - input_cost = (input_tokens / 1000) * costs['input'] - output_cost = (output_tokens / 1000) * costs['output'] + input_cost = (input_tokens / 1000) * costs["input"] + output_cost = (output_tokens / 1000) * costs["output"] return input_cost + output_cost - def _save_to_history(self, model: str, input_tokens: int, output_tokens: int, cost: float): + def _save_to_history( + self, model: str, input_tokens: int, output_tokens: int, cost: float + ): try: history = [] if os.path.exists(USAGE_DB_FILE): - with open(USAGE_DB_FILE, 'r') as f: + with open(USAGE_DB_FILE) as f: history = json.load(f) - history.append({ - 'timestamp': datetime.now().isoformat(), - 'model': model, - 'input_tokens': input_tokens, - 'output_tokens': output_tokens, - 'total_tokens': input_tokens + output_tokens, - 'cost': cost - }) + history.append( + { + "timestamp": datetime.now().isoformat(), + "model": model, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "cost": cost, + } + ) if len(history) > 10000: history = history[-10000:] - with open(USAGE_DB_FILE, 'w') as f: + with open(USAGE_DB_FILE, "w") as f: json.dump(history, f, indent=2) except Exception as e: @@ -121,42 +128,34 @@ class UsageTracker: f"Estimated Cost: ${usage['estimated_cost']:.4f}", ] - if usage['models_used']: + if usage["models_used"]: lines.append("\nModels Used:") - for model, stats in usage['models_used'].items(): + for model, stats in usage["models_used"].items(): lines.append( f" {model}: {stats['requests']} requests, " f"{stats['tokens']:,} tokens, ${stats['cost']:.4f}" ) - return '\n'.join(lines) + return "\n".join(lines) @staticmethod def get_total_usage() -> Dict: if not os.path.exists(USAGE_DB_FILE): - return { - 'total_requests': 0, - 'total_tokens': 0, - 'total_cost': 0.0 - } + return {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0} try: - with open(USAGE_DB_FILE, 'r') as f: + with open(USAGE_DB_FILE) as f: history = json.load(f) - total_tokens = sum(entry['total_tokens'] for entry in history) - total_cost = sum(entry['cost'] for entry in history) + total_tokens = sum(entry["total_tokens"] for entry in history) + total_cost = sum(entry["cost"] for entry in history) return { - 'total_requests': len(history), - 'total_tokens': total_tokens, - 'total_cost': total_cost + "total_requests": len(history), + "total_tokens": total_tokens, + "total_cost": total_cost, } except Exception as e: logger.error(f"Error loading usage history: {e}") - return { - 'total_requests': 0, - 'total_tokens': 0, - 'total_cost': 0.0 - } + return {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0} diff --git a/pr/core/validation.py b/pr/core/validation.py index 6bca76e..02c5301 100644 --- a/pr/core/validation.py +++ b/pr/core/validation.py @@ -1,5 +1,5 @@ import os -from typing import Optional + from pr.core.exceptions import ValidationError @@ -16,7 +16,9 @@ def validate_file_path(path: str, must_exist: bool = False) -> str: return os.path.abspath(path) -def validate_directory_path(path: str, must_exist: bool = False, create: bool = False) -> str: +def validate_directory_path( + path: str, must_exist: bool = False, create: bool = False +) -> str: if not path: raise ValidationError("Directory path cannot be empty") @@ -48,7 +50,7 @@ def validate_api_url(url: str) -> str: if not url: raise ValidationError("API URL cannot be empty") - if not url.startswith(('http://', 'https://')): + if not url.startswith(("http://", "https://")): raise ValidationError("API URL must start with http:// or https://") return url @@ -58,7 +60,7 @@ def validate_session_name(name: str) -> str: if not name: raise ValidationError("Session name cannot be empty") - invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|'] + invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] for char in invalid_chars: if char in name: raise ValidationError(f"Session name contains invalid character: {char}") diff --git a/pr/editor.py b/pr/editor.py index 598db74..5fbb45d 100644 --- a/pr/editor.py +++ b/pr/editor.py @@ -1,23 +1,22 @@ #!/usr/bin/env python3 +import atexit import curses -import threading -import sys import os -import re -import socket import pickle import queue -import time -import atexit +import re import signal -import traceback -from contextlib import contextmanager +import socket +import sys +import threading +import time + class RPEditor: def __init__(self, filename=None, auto_save=False, timeout=30): """ Initialize RPEditor with enhanced robustness features. - + Args: filename: File to edit auto_save: Enable auto-save on exit @@ -27,7 +26,7 @@ class RPEditor: self.lines = [""] self.cursor_y = 0 self.cursor_x = 0 - self.mode = 'normal' + self.mode = "normal" self.command = "" self.stdscr = None self.running = False @@ -47,7 +46,7 @@ class RPEditor: self._cleanup_registered = False self._original_terminal_state = None self._exception_occurred = False - + # Create socket pair with error handling try: self.client_sock, self.server_sock = socket.socketpair() @@ -56,10 +55,10 @@ class RPEditor: except Exception as e: self._cleanup() raise RuntimeError(f"Failed to create socket pair: {e}") - + # Register cleanup handlers self._register_cleanup() - + if filename: self.load_file() @@ -81,14 +80,14 @@ class RPEditor: try: # Stop the editor self.running = False - + # Save if auto-save is enabled if self.auto_save and self.filename and not self._exception_occurred: try: self._save_file() except: pass - + # Clean up curses if self.stdscr: try: @@ -103,13 +102,13 @@ class RPEditor: curses.endwin() except: pass - + # Clear screen after curses cleanup try: - os.system('clear' if os.name != 'nt' else 'cls') + os.system("clear" if os.name != "nt" else "cls") except: pass - + # Close sockets for sock in [self.client_sock, self.server_sock]: if sock: @@ -117,12 +116,12 @@ class RPEditor: sock.close() except: pass - + # Wait for threads to finish for thread in [self.thread, self.socket_thread]: if thread and thread.is_alive(): thread.join(timeout=1) - + except: pass @@ -130,12 +129,12 @@ class RPEditor: """Load file with enhanced error handling.""" try: if os.path.exists(self.filename): - with open(self.filename, 'r', encoding='utf-8', errors='replace') as f: + with open(self.filename, encoding="utf-8", errors="replace") as f: content = f.read() self.lines = content.splitlines() if content else [""] else: self.lines = [""] - except Exception as e: + except Exception: self.lines = [""] # Don't raise, just use empty content @@ -144,24 +143,24 @@ class RPEditor: with self.lock: if not self.filename: return False - + try: # Create backup if file exists if os.path.exists(self.filename): backup_name = f"{self.filename}.bak" try: - with open(self.filename, 'r', encoding='utf-8') as f: + with open(self.filename, encoding="utf-8") as f: backup_content = f.read() - with open(backup_name, 'w', encoding='utf-8') as f: + with open(backup_name, "w", encoding="utf-8") as f: f.write(backup_content) except: pass # Backup failed, but continue with save - + # Save the file - with open(self.filename, 'w', encoding='utf-8') as f: - f.write('\n'.join(self.lines)) + with open(self.filename, "w", encoding="utf-8") as f: + f.write("\n".join(self.lines)) return True - except Exception as e: + except Exception: return False def save_file(self): @@ -169,7 +168,7 @@ class RPEditor: if not self.running: return self._save_file() try: - self.client_sock.send(pickle.dumps({'command': 'save_file'})) + self.client_sock.send(pickle.dumps({"command": "save_file"})) except: return self._save_file() # Fallback to direct save @@ -177,10 +176,12 @@ class RPEditor: """Start the editor with enhanced error handling.""" if self.running: return False - + try: self.running = True - self.socket_thread = threading.Thread(target=self.socket_listener, daemon=True) + self.socket_thread = threading.Thread( + target=self.socket_listener, daemon=True + ) self.socket_thread.start() self.thread = threading.Thread(target=self.run, daemon=True) self.thread.start() @@ -194,10 +195,10 @@ class RPEditor: """Stop the editor with proper cleanup.""" try: if self.client_sock: - self.client_sock.send(pickle.dumps({'command': 'stop'})) + self.client_sock.send(pickle.dumps({"command": "stop"})) except: pass - + self.running = False time.sleep(0.1) # Give threads time to finish self._cleanup() @@ -206,20 +207,20 @@ class RPEditor: """Run the main editor loop with exception handling.""" try: curses.wrapper(self.main_loop) - except Exception as e: + except Exception: self._exception_occurred = True self._cleanup() def main_loop(self, stdscr): """Main editor loop with enhanced error recovery.""" self.stdscr = stdscr - + try: # Configure curses curses.curs_set(1) self.stdscr.keypad(True) self.stdscr.timeout(100) # Non-blocking with timeout - + while self.running: try: # Process queued commands @@ -230,11 +231,11 @@ class RPEditor: self.execute_command(command) except queue.Empty: break - + # Draw screen with self.lock: self.draw() - + # Handle input try: key = self.stdscr.getch() @@ -243,12 +244,12 @@ class RPEditor: self.handle_key(key) except curses.error: pass # Ignore curses errors - - except Exception as e: + + except Exception: # Log error but continue running pass - - except Exception as e: + + except Exception: self._exception_occurred = True finally: self._cleanup() @@ -258,28 +259,28 @@ class RPEditor: try: self.stdscr.clear() height, width = self.stdscr.getmaxyx() - + # Draw lines for i, line in enumerate(self.lines): if i >= height - 1: break try: # Handle long lines and special characters - display_line = line[:width-1] if len(line) >= width else line + display_line = line[: width - 1] if len(line) >= width else line self.stdscr.addstr(i, 0, display_line) except curses.error: pass # Skip lines that can't be displayed - + # Draw status line status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}" - if self.mode == 'command': - status = self.command[:width-1] - + if self.mode == "command": + status = self.command[: width - 1] + try: - self.stdscr.addstr(height - 1, 0, status[:width-1]) + self.stdscr.addstr(height - 1, 0, status[: width - 1]) except curses.error: pass - + # Position cursor cursor_x = min(self.cursor_x, width - 1) cursor_y = min(self.cursor_y, height - 2) @@ -287,7 +288,7 @@ class RPEditor: self.stdscr.move(cursor_y, cursor_x) except curses.error: pass - + self.stdscr.refresh() except Exception: pass # Continue even if draw fails @@ -295,11 +296,11 @@ class RPEditor: def handle_key(self, key): """Handle keyboard input with error recovery.""" try: - if self.mode == 'normal': + if self.mode == "normal": self.handle_normal(key) - elif self.mode == 'insert': + elif self.mode == "insert": self.handle_insert(key) - elif self.mode == 'command': + elif self.mode == "command": self.handle_command(key) except Exception: pass # Continue on error @@ -307,73 +308,73 @@ class RPEditor: def handle_normal(self, key): """Handle normal mode keys.""" try: - if key == ord('h') or key == curses.KEY_LEFT: + if key == ord("h") or key == curses.KEY_LEFT: self.move_cursor(0, -1) - elif key == ord('j') or key == curses.KEY_DOWN: + elif key == ord("j") or key == curses.KEY_DOWN: self.move_cursor(1, 0) - elif key == ord('k') or key == curses.KEY_UP: + elif key == ord("k") or key == curses.KEY_UP: self.move_cursor(-1, 0) - elif key == ord('l') or key == curses.KEY_RIGHT: + elif key == ord("l") or key == curses.KEY_RIGHT: self.move_cursor(0, 1) - elif key == ord('i'): - self.mode = 'insert' - elif key == ord(':'): - self.mode = 'command' + elif key == ord("i"): + self.mode = "insert" + elif key == ord(":"): + self.mode = "command" self.command = ":" - elif key == ord('x'): + elif key == ord("x"): self._delete_char() - elif key == ord('a'): + elif key == ord("a"): self.cursor_x = min(self.cursor_x + 1, len(self.lines[self.cursor_y])) - self.mode = 'insert' - elif key == ord('A'): + self.mode = "insert" + elif key == ord("A"): self.cursor_x = len(self.lines[self.cursor_y]) - self.mode = 'insert' - elif key == ord('o'): + self.mode = "insert" + elif key == ord("o"): self._insert_line(self.cursor_y + 1, "") self.cursor_y += 1 self.cursor_x = 0 - self.mode = 'insert' - elif key == ord('O'): + self.mode = "insert" + elif key == ord("O"): self._insert_line(self.cursor_y, "") self.cursor_x = 0 - self.mode = 'insert' - elif key == ord('d') and self.prev_key == ord('d'): + self.mode = "insert" + elif key == ord("d") and self.prev_key == ord("d"): if self.cursor_y < len(self.lines): self.clipboard = self.lines[self.cursor_y] self._delete_line(self.cursor_y) if self.cursor_y >= len(self.lines): self.cursor_y = max(0, len(self.lines) - 1) self.cursor_x = 0 - elif key == ord('y') and self.prev_key == ord('y'): + elif key == ord("y") and self.prev_key == ord("y"): if self.cursor_y < len(self.lines): self.clipboard = self.lines[self.cursor_y] - elif key == ord('p'): + elif key == ord("p"): self._insert_line(self.cursor_y + 1, self.clipboard) self.cursor_y += 1 self.cursor_x = 0 - elif key == ord('P'): + elif key == ord("P"): self._insert_line(self.cursor_y, self.clipboard) self.cursor_x = 0 - elif key == ord('w'): + elif key == ord("w"): self._move_word_forward() - elif key == ord('b'): + elif key == ord("b"): self._move_word_backward() - elif key == ord('0'): + elif key == ord("0"): self.cursor_x = 0 - elif key == ord('$'): + elif key == ord("$"): self.cursor_x = len(self.lines[self.cursor_y]) - elif key == ord('g'): - if self.prev_key == ord('g'): + elif key == ord("g"): + if self.prev_key == ord("g"): self.cursor_y = 0 self.cursor_x = 0 - elif key == ord('G'): + elif key == ord("G"): self.cursor_y = max(0, len(self.lines) - 1) self.cursor_x = 0 - elif key == ord('u'): + elif key == ord("u"): self.undo() - elif key == ord('r') and self.prev_key == 18: # Ctrl-R + elif key == ord("r") and self.prev_key == 18: # Ctrl-R self.redo() - + self.prev_key = key except Exception: pass @@ -410,7 +411,7 @@ class RPEditor: """Handle insert mode keys.""" try: if key == 27: # ESC - self.mode = 'normal' + self.mode = "normal" if self.cursor_x > 0: self.cursor_x -= 1 elif key == 10 or key == 13: # Enter @@ -438,10 +439,10 @@ class RPEditor: elif cmd.startswith("w "): self.filename = cmd[2:].strip() self._save_file() - self.mode = 'normal' + self.mode = "normal" self.command = "" elif key == 27: # ESC - self.mode = 'normal' + self.mode = "normal" self.command = "" elif key == curses.KEY_BACKSPACE or key == 127 or key == 8: if len(self.command) > 1: @@ -449,17 +450,17 @@ class RPEditor: elif 32 <= key <= 126: self.command += chr(key) except Exception: - self.mode = 'normal' + self.mode = "normal" self.command = "" def move_cursor(self, dy, dx): """Move cursor with bounds checking.""" if not self.lines: self.lines = [""] - + new_y = self.cursor_y + dy new_x = self.cursor_x + dx - + # Ensure valid Y position if 0 <= new_y < len(self.lines): self.cursor_y = new_y @@ -477,9 +478,9 @@ class RPEditor: """Save current state for undo.""" with self.lock: state = { - 'lines': [line for line in self.lines], - 'cursor_y': self.cursor_y, - 'cursor_x': self.cursor_x + "lines": list(self.lines), + "cursor_y": self.cursor_y, + "cursor_x": self.cursor_x, } self.undo_stack.append(state) if len(self.undo_stack) > self.max_undo: @@ -491,69 +492,79 @@ class RPEditor: with self.lock: if self.undo_stack: current_state = { - 'lines': [line for line in self.lines], - 'cursor_y': self.cursor_y, - 'cursor_x': self.cursor_x + "lines": list(self.lines), + "cursor_y": self.cursor_y, + "cursor_x": self.cursor_x, } self.redo_stack.append(current_state) state = self.undo_stack.pop() - self.lines = state['lines'] - self.cursor_y = min(state['cursor_y'], len(self.lines) - 1) - self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0) + self.lines = state["lines"] + self.cursor_y = min(state["cursor_y"], len(self.lines) - 1) + self.cursor_x = min( + state["cursor_x"], + len(self.lines[self.cursor_y]) if self.lines else 0, + ) def redo(self): """Redo last undone change.""" with self.lock: if self.redo_stack: current_state = { - 'lines': [line for line in self.lines], - 'cursor_y': self.cursor_y, - 'cursor_x': self.cursor_x + "lines": list(self.lines), + "cursor_y": self.cursor_y, + "cursor_x": self.cursor_x, } self.undo_stack.append(current_state) state = self.redo_stack.pop() - self.lines = state['lines'] - self.cursor_y = min(state['cursor_y'], len(self.lines) - 1) - self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0) + self.lines = state["lines"] + self.cursor_y = min(state["cursor_y"], len(self.lines) - 1) + self.cursor_x = min( + state["cursor_x"], + len(self.lines[self.cursor_y]) if self.lines else 0, + ) def _insert_text(self, text): """Insert text at cursor position.""" if not text: return - + self.save_state() - lines = text.split('\n') - + lines = text.split("\n") + if len(lines) == 1: # Single line insert if self.cursor_y >= len(self.lines): self.lines.append("") self.cursor_y = len(self.lines) - 1 - + line = self.lines[self.cursor_y] - self.lines[self.cursor_y] = line[:self.cursor_x] + text + line[self.cursor_x:] + self.lines[self.cursor_y] = ( + line[: self.cursor_x] + text + line[self.cursor_x :] + ) self.cursor_x += len(text) else: # Multi-line insert if self.cursor_y >= len(self.lines): self.lines.append("") self.cursor_y = len(self.lines) - 1 - - first = self.lines[self.cursor_y][:self.cursor_x] + lines[0] - last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:] - + + first = self.lines[self.cursor_y][: self.cursor_x] + lines[0] + last = lines[-1] + self.lines[self.cursor_y][self.cursor_x :] + self.lines[self.cursor_y] = first for i in range(1, len(lines) - 1): self.lines.insert(self.cursor_y + i, lines[i]) self.lines.insert(self.cursor_y + len(lines) - 1, last) - + self.cursor_y += len(lines) - 1 self.cursor_x = len(lines[-1]) def insert_text(self, text): """Thread-safe text insertion.""" try: - self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text})) + self.client_sock.send( + pickle.dumps({"command": "insert_text", "text": text}) + ) except: with self.lock: self._insert_text(text) @@ -561,14 +572,18 @@ class RPEditor: def _delete_char(self): """Delete character at cursor.""" self.save_state() - if self.cursor_y < len(self.lines) and self.cursor_x < len(self.lines[self.cursor_y]): + if self.cursor_y < len(self.lines) and self.cursor_x < len( + self.lines[self.cursor_y] + ): line = self.lines[self.cursor_y] - self.lines[self.cursor_y] = line[:self.cursor_x] + line[self.cursor_x+1:] + self.lines[self.cursor_y] = ( + line[: self.cursor_x] + line[self.cursor_x + 1 :] + ) def delete_char(self): """Thread-safe character deletion.""" try: - self.client_sock.send(pickle.dumps({'command': 'delete_char'})) + self.client_sock.send(pickle.dumps({"command": "delete_char"})) except: with self.lock: self._delete_char() @@ -578,9 +593,9 @@ class RPEditor: if self.cursor_y >= len(self.lines): self.lines.append("") self.cursor_y = len(self.lines) - 1 - + line = self.lines[self.cursor_y] - self.lines[self.cursor_y] = line[:self.cursor_x] + char + line[self.cursor_x:] + self.lines[self.cursor_y] = line[: self.cursor_x] + char + line[self.cursor_x :] self.cursor_x += 1 def _split_line(self): @@ -588,10 +603,10 @@ class RPEditor: if self.cursor_y >= len(self.lines): self.lines.append("") self.cursor_y = len(self.lines) - 1 - + line = self.lines[self.cursor_y] - self.lines[self.cursor_y] = line[:self.cursor_x] - self.lines.insert(self.cursor_y + 1, line[self.cursor_x:]) + self.lines[self.cursor_y] = line[: self.cursor_x] + self.lines.insert(self.cursor_y + 1, line[self.cursor_x :]) self.cursor_y += 1 self.cursor_x = 0 @@ -599,7 +614,9 @@ class RPEditor: """Handle backspace key.""" if self.cursor_x > 0: line = self.lines[self.cursor_y] - self.lines[self.cursor_y] = line[:self.cursor_x-1] + line[self.cursor_x:] + self.lines[self.cursor_y] = ( + line[: self.cursor_x - 1] + line[self.cursor_x :] + ) self.cursor_x -= 1 elif self.cursor_y > 0: prev_len = len(self.lines[self.cursor_y - 1]) @@ -637,7 +654,7 @@ class RPEditor: self._set_text(text) return try: - self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text})) + self.client_sock.send(pickle.dumps({"command": "set_text", "text": text})) except: with self.lock: self._set_text(text) @@ -651,7 +668,9 @@ class RPEditor: def goto_line(self, line_num): """Thread-safe goto line.""" try: - self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num})) + self.client_sock.send( + pickle.dumps({"command": "goto_line", "line_num": line_num}) + ) except: with self.lock: self._goto_line(line_num) @@ -659,17 +678,17 @@ class RPEditor: def get_text(self): """Get entire text content.""" try: - self.client_sock.send(pickle.dumps({'command': 'get_text'})) + self.client_sock.send(pickle.dumps({"command": "get_text"})) data = self.client_sock.recv(65536) return pickle.loads(data) except: with self.lock: - return '\n'.join(self.lines) + return "\n".join(self.lines) def get_cursor(self): """Get cursor position.""" try: - self.client_sock.send(pickle.dumps({'command': 'get_cursor'})) + self.client_sock.send(pickle.dumps({"command": "get_cursor"})) data = self.client_sock.recv(4096) return pickle.loads(data) except: @@ -679,16 +698,16 @@ class RPEditor: def get_file_info(self): """Get file information.""" try: - self.client_sock.send(pickle.dumps({'command': 'get_file_info'})) + self.client_sock.send(pickle.dumps({"command": "get_file_info"})) data = self.client_sock.recv(4096) return pickle.loads(data) except: with self.lock: return { - 'filename': self.filename, - 'lines': len(self.lines), - 'cursor': (self.cursor_y, self.cursor_x), - 'mode': self.mode + "filename": self.filename, + "lines": len(self.lines), + "cursor": (self.cursor_y, self.cursor_x), + "mode": self.mode, } def socket_listener(self): @@ -713,39 +732,39 @@ class RPEditor: def execute_command(self, command): """Execute command with error handling.""" try: - cmd = command.get('command') - - if cmd == 'insert_text': - self._insert_text(command.get('text', '')) - elif cmd == 'delete_char': + cmd = command.get("command") + + if cmd == "insert_text": + self._insert_text(command.get("text", "")) + elif cmd == "delete_char": self._delete_char() - elif cmd == 'save_file': + elif cmd == "save_file": self._save_file() - elif cmd == 'set_text': - self._set_text(command.get('text', '')) - elif cmd == 'goto_line': - self._goto_line(command.get('line_num', 1)) - elif cmd == 'get_text': - result = '\n'.join(self.lines) + elif cmd == "set_text": + self._set_text(command.get("text", "")) + elif cmd == "goto_line": + self._goto_line(command.get("line_num", 1)) + elif cmd == "get_text": + result = "\n".join(self.lines) self.server_sock.send(pickle.dumps(result)) - elif cmd == 'get_cursor': + elif cmd == "get_cursor": result = (self.cursor_y, self.cursor_x) self.server_sock.send(pickle.dumps(result)) - elif cmd == 'get_file_info': + elif cmd == "get_file_info": result = { - 'filename': self.filename, - 'lines': len(self.lines), - 'cursor': (self.cursor_y, self.cursor_x), - 'mode': self.mode + "filename": self.filename, + "lines": len(self.lines), + "cursor": (self.cursor_y, self.cursor_x), + "mode": self.mode, } self.server_sock.send(pickle.dumps(result)) - elif cmd == 'stop': + elif cmd == "stop": self.running = False except Exception: pass # Additional public methods for backwards compatibility - + def move_cursor_to(self, y, x): """Move cursor to specific position.""" with self.lock: @@ -788,11 +807,11 @@ class RPEditor: """Replace text in range.""" with self.lock: self.save_state() - + # Validate bounds start_line = max(0, min(start_line, len(self.lines) - 1)) end_line = max(0, min(end_line, len(self.lines) - 1)) - + if start_line == end_line: line = self.lines[start_line] start_col = max(0, min(start_col, len(line))) @@ -801,9 +820,9 @@ class RPEditor: else: first_part = self.lines[start_line][:start_col] last_part = self.lines[end_line][end_col:] - new_lines = new_text.split('\n') + new_lines = new_text.split("\n") self.lines[start_line] = first_part + new_lines[0] - del self.lines[start_line + 1:end_line + 1] + del self.lines[start_line + 1 : end_line + 1] for i, new_line in enumerate(new_lines[1:], 1): self.lines.insert(start_line + i, new_line) if len(new_lines) > 1: @@ -842,24 +861,24 @@ class RPEditor: with self.lock: if not self.selection_start or not self.selection_end: return "" - + sl, sc = self.selection_start el, ec = self.selection_end - + # Validate bounds if sl < 0 or sl >= len(self.lines) or el < 0 or el >= len(self.lines): return "" - + if sl == el: return self.lines[sl][sc:ec] - + result = [self.lines[sl][sc:]] for i in range(sl + 1, el): if i < len(self.lines): result.append(self.lines[i]) if el < len(self.lines): result.append(self.lines[el][:ec]) - return '\n'.join(result) + return "\n".join(result) def delete_selection(self): """Delete selected text.""" @@ -880,7 +899,7 @@ class RPEditor: self.save_state() search_lines = search_block.splitlines() replace_lines = replace_block.splitlines() - + for i in range(len(self.lines) - len(search_lines) + 1): match = True for j, search_line in enumerate(search_lines): @@ -890,12 +909,12 @@ class RPEditor: if self.lines[i + j].strip() != search_line.strip(): match = False break - + if match: # Preserve indentation indent = len(self.lines[i]) - len(self.lines[i].lstrip()) - indented_replace = [' ' * indent + line for line in replace_lines] - self.lines[i:i+len(search_lines)] = indented_replace + indented_replace = [" " * indent + line for line in replace_lines] + self.lines[i : i + len(search_lines)] = indented_replace return True return False @@ -904,21 +923,21 @@ class RPEditor: with self.lock: self.save_state() try: - lines = diff_text.split('\n') + lines = diff_text.split("\n") start_line = 0 - + for line in lines: - if line.startswith('@@'): - match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line) + if line.startswith("@@"): + match = re.search(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@", line) if match: start_line = int(match.group(1)) - 1 - elif line.startswith('-'): + elif line.startswith("-"): if start_line < len(self.lines): del self.lines[start_line] - elif line.startswith('+'): + elif line.startswith("+"): self.lines.insert(start_line, line[1:]) start_line += 1 - elif line and not line.startswith('\\'): + elif line and not line.startswith("\\"): start_line += 1 except Exception: pass @@ -972,18 +991,18 @@ def main(): editor = None try: filename = sys.argv[1] if len(sys.argv) > 1 else None - + # Parse additional arguments - auto_save = '--auto-save' in sys.argv - + auto_save = "--auto-save" in sys.argv + # Create and start editor editor = RPEditor(filename, auto_save=auto_save) editor.start() - + # Wait for editor to finish if editor.thread: editor.thread.join() - + except KeyboardInterrupt: pass except Exception as e: @@ -992,7 +1011,7 @@ def main(): if editor: editor.stop() # Ensure screen is cleared - os.system('clear' if os.name != 'nt' else 'cls') + os.system("clear" if os.name != "nt" else "cls") if __name__ == "__main__": diff --git a/pr/editor2.py b/pr/editor2.py index 9bba6d3..4ad4506 100644 --- a/pr/editor2.py +++ b/pr/editor2.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 import curses -import threading -import sys -import os -import re -import socket import pickle import queue +import re +import socket +import sys +import threading + class RPEditor: def __init__(self, filename=None): @@ -14,7 +14,7 @@ class RPEditor: self.lines = [""] self.cursor_y = 0 self.cursor_x = 0 - self.mode = 'normal' + self.mode = "normal" self.command = "" self.stdscr = None self.running = False @@ -35,7 +35,7 @@ class RPEditor: def load_file(self): try: - with open(self.filename, 'r') as f: + with open(self.filename) as f: self.lines = f.read().splitlines() if not self.lines: self.lines = [""] @@ -45,11 +45,11 @@ class RPEditor: def _save_file(self): with self.lock: if self.filename: - with open(self.filename, 'w') as f: - f.write('\n'.join(self.lines)) + with open(self.filename, "w") as f: + f.write("\n".join(self.lines)) def save_file(self): - self.client_sock.send(pickle.dumps({'command': 'save_file'})) + self.client_sock.send(pickle.dumps({"command": "save_file"})) def start(self): self.running = True @@ -59,7 +59,7 @@ class RPEditor: self.thread.start() def stop(self): - self.client_sock.send(pickle.dumps({'command': 'stop'})) + self.client_sock.send(pickle.dumps({"command": "stop"})) self.running = False if self.stdscr: curses.endwin() @@ -99,66 +99,66 @@ class RPEditor: self.stdscr.addstr(i, 0, line[:width]) status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}" self.stdscr.addstr(height - 1, 0, status[:width]) - if self.mode == 'command': + if self.mode == "command": self.stdscr.addstr(height - 1, 0, self.command[:width]) self.stdscr.move(self.cursor_y, min(self.cursor_x, width - 1)) self.stdscr.refresh() def handle_key(self, key): - if self.mode == 'normal': + if self.mode == "normal": self.handle_normal(key) - elif self.mode == 'insert': + elif self.mode == "insert": self.handle_insert(key) - elif self.mode == 'command': + elif self.mode == "command": self.handle_command(key) def handle_normal(self, key): - if key == ord('h') or key == curses.KEY_LEFT: + if key == ord("h") or key == curses.KEY_LEFT: self.move_cursor(0, -1) - elif key == ord('j') or key == curses.KEY_DOWN: + elif key == ord("j") or key == curses.KEY_DOWN: self.move_cursor(1, 0) - elif key == ord('k') or key == curses.KEY_UP: + elif key == ord("k") or key == curses.KEY_UP: self.move_cursor(-1, 0) - elif key == ord('l') or key == curses.KEY_RIGHT: + elif key == ord("l") or key == curses.KEY_RIGHT: self.move_cursor(0, 1) - elif key == ord('i'): - self.mode = 'insert' - elif key == ord(':'): - self.mode = 'command' + elif key == ord("i"): + self.mode = "insert" + elif key == ord(":"): + self.mode = "command" self.command = ":" - elif key == ord('x'): + elif key == ord("x"): self._delete_char() - elif key == ord('a'): + elif key == ord("a"): self.cursor_x += 1 - self.mode = 'insert' - elif key == ord('A'): + self.mode = "insert" + elif key == ord("A"): self.cursor_x = len(self.lines[self.cursor_y]) - self.mode = 'insert' - elif key == ord('o'): + self.mode = "insert" + elif key == ord("o"): self._insert_line(self.cursor_y + 1, "") self.cursor_y += 1 self.cursor_x = 0 - self.mode = 'insert' - elif key == ord('O'): + self.mode = "insert" + elif key == ord("O"): self._insert_line(self.cursor_y, "") self.cursor_x = 0 - self.mode = 'insert' - elif key == ord('d') and self.prev_key == ord('d'): + self.mode = "insert" + elif key == ord("d") and self.prev_key == ord("d"): self.clipboard = self.lines[self.cursor_y] self._delete_line(self.cursor_y) if self.cursor_y >= len(self.lines): self.cursor_y = len(self.lines) - 1 self.cursor_x = 0 - elif key == ord('y') and self.prev_key == ord('y'): + elif key == ord("y") and self.prev_key == ord("y"): self.clipboard = self.lines[self.cursor_y] - elif key == ord('p'): + elif key == ord("p"): self._insert_line(self.cursor_y + 1, self.clipboard) self.cursor_y += 1 self.cursor_x = 0 - elif key == ord('P'): + elif key == ord("P"): self._insert_line(self.cursor_y, self.clipboard) self.cursor_x = 0 - elif key == ord('w'): + elif key == ord("w"): line = self.lines[self.cursor_y] i = self.cursor_x while i < len(line) and not line[i].isalnum(): @@ -166,7 +166,7 @@ class RPEditor: while i < len(line) and line[i].isalnum(): i += 1 self.cursor_x = i - elif key == ord('b'): + elif key == ord("b"): line = self.lines[self.cursor_y] i = self.cursor_x - 1 while i >= 0 and not line[i].isalnum(): @@ -174,26 +174,26 @@ class RPEditor: while i >= 0 and line[i].isalnum(): i -= 1 self.cursor_x = i + 1 - elif key == ord('0'): + elif key == ord("0"): self.cursor_x = 0 - elif key == ord('$'): + elif key == ord("$"): self.cursor_x = len(self.lines[self.cursor_y]) - elif key == ord('g'): - if self.prev_key == ord('g'): + elif key == ord("g"): + if self.prev_key == ord("g"): self.cursor_y = 0 self.cursor_x = 0 - elif key == ord('G'): + elif key == ord("G"): self.cursor_y = len(self.lines) - 1 self.cursor_x = 0 - elif key == ord('u'): + elif key == ord("u"): self.undo() - elif key == ord('r') and self.prev_key == 18: + elif key == ord("r") and self.prev_key == 18: self.redo() self.prev_key = key def handle_insert(self, key): if key == 27: - self.mode = 'normal' + self.mode = "normal" if self.cursor_x > 0: self.cursor_x -= 1 elif key == 10: @@ -207,11 +207,13 @@ class RPEditor: def handle_command(self, key): if key == 10: cmd = self.command[1:] - if cmd == "q" or cmd == 'q!': + if cmd == "q" or cmd == "q!": self.running = False - elif cmd == "w": + elif cmd == "w": self._save_file() - elif cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!": + elif ( + cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!" + ): self._save_file() self.running = False elif cmd.startswith("w "): @@ -220,10 +222,10 @@ class RPEditor: elif cmd == "wq": self._save_file() self.running = False - self.mode = 'normal' + self.mode = "normal" self.command = "" elif key == 27: - self.mode = 'normal' + self.mode = "normal" self.command = "" elif key == curses.KEY_BACKSPACE or key == 127: if len(self.command) > 1: @@ -241,9 +243,9 @@ class RPEditor: def save_state(self): with self.lock: state = { - 'lines': [line for line in self.lines], - 'cursor_y': self.cursor_y, - 'cursor_x': self.cursor_x + "lines": list(self.lines), + "cursor_y": self.cursor_y, + "cursor_x": self.cursor_x, } self.undo_stack.append(state) if len(self.undo_stack) > self.max_undo: @@ -254,71 +256,85 @@ class RPEditor: with self.lock: if self.undo_stack: current_state = { - 'lines': [line for line in self.lines], - 'cursor_y': self.cursor_y, - 'cursor_x': self.cursor_x + "lines": list(self.lines), + "cursor_y": self.cursor_y, + "cursor_x": self.cursor_x, } self.redo_stack.append(current_state) state = self.undo_stack.pop() - self.lines = state['lines'] - self.cursor_y = state['cursor_y'] - self.cursor_x = state['cursor_x'] + self.lines = state["lines"] + self.cursor_y = state["cursor_y"] + self.cursor_x = state["cursor_x"] def redo(self): with self.lock: if self.redo_stack: current_state = { - 'lines': [line for line in self.lines], - 'cursor_y': self.cursor_y, - 'cursor_x': self.cursor_x + "lines": list(self.lines), + "cursor_y": self.cursor_y, + "cursor_x": self.cursor_x, } self.undo_stack.append(current_state) state = self.redo_stack.pop() - self.lines = state['lines'] - self.cursor_y = state['cursor_y'] - self.cursor_x = state['cursor_x'] + self.lines = state["lines"] + self.cursor_y = state["cursor_y"] + self.cursor_x = state["cursor_x"] def _insert_text(self, text): self.save_state() - lines = text.split('\n') + lines = text.split("\n") if len(lines) == 1: - self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + text + self.lines[self.cursor_y][self.cursor_x:] + self.lines[self.cursor_y] = ( + self.lines[self.cursor_y][: self.cursor_x] + + text + + self.lines[self.cursor_y][self.cursor_x :] + ) self.cursor_x += len(text) else: - first = self.lines[self.cursor_y][:self.cursor_x] + lines[0] - last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:] + first = self.lines[self.cursor_y][: self.cursor_x] + lines[0] + last = lines[-1] + self.lines[self.cursor_y][self.cursor_x :] self.lines[self.cursor_y] = first - for i in range(1, len(lines)-1): + for i in range(1, len(lines) - 1): self.lines.insert(self.cursor_y + i, lines[i]) self.lines.insert(self.cursor_y + len(lines) - 1, last) self.cursor_y += len(lines) - 1 self.cursor_x = len(lines[-1]) def insert_text(self, text): - self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text})) + self.client_sock.send(pickle.dumps({"command": "insert_text", "text": text})) def _delete_char(self): self.save_state() if self.cursor_x < len(self.lines[self.cursor_y]): - self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + self.lines[self.cursor_y][self.cursor_x+1:] + self.lines[self.cursor_y] = ( + self.lines[self.cursor_y][: self.cursor_x] + + self.lines[self.cursor_y][self.cursor_x + 1 :] + ) def delete_char(self): - self.client_sock.send(pickle.dumps({'command': 'delete_char'})) + self.client_sock.send(pickle.dumps({"command": "delete_char"})) def _insert_char(self, char): - self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + char + self.lines[self.cursor_y][self.cursor_x:] + self.lines[self.cursor_y] = ( + self.lines[self.cursor_y][: self.cursor_x] + + char + + self.lines[self.cursor_y][self.cursor_x :] + ) self.cursor_x += 1 def _split_line(self): line = self.lines[self.cursor_y] - self.lines[self.cursor_y] = line[:self.cursor_x] - self.lines.insert(self.cursor_y + 1, line[self.cursor_x:]) + self.lines[self.cursor_y] = line[: self.cursor_x] + self.lines.insert(self.cursor_y + 1, line[self.cursor_x :]) self.cursor_y += 1 self.cursor_x = 0 def _backspace(self): if self.cursor_x > 0: - self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x-1] + self.lines[self.cursor_y][self.cursor_x:] + self.lines[self.cursor_y] = ( + self.lines[self.cursor_y][: self.cursor_x - 1] + + self.lines[self.cursor_y][self.cursor_x :] + ) self.cursor_x -= 1 elif self.cursor_y > 0: prev_len = len(self.lines[self.cursor_y - 1]) @@ -347,7 +363,7 @@ class RPEditor: self.cursor_x = 0 def set_text(self, text): - self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text})) + self.client_sock.send(pickle.dumps({"command": "set_text", "text": text})) def _goto_line(self, line_num): line_num = max(0, min(line_num, len(self.lines) - 1)) @@ -355,24 +371,26 @@ class RPEditor: self.cursor_x = 0 def goto_line(self, line_num): - self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num})) + self.client_sock.send( + pickle.dumps({"command": "goto_line", "line_num": line_num}) + ) def get_text(self): - self.client_sock.send(pickle.dumps({'command': 'get_text'})) + self.client_sock.send(pickle.dumps({"command": "get_text"})) try: return pickle.loads(self.client_sock.recv(4096)) except: - return '' + return "" def get_cursor(self): - self.client_sock.send(pickle.dumps({'command': 'get_cursor'})) + self.client_sock.send(pickle.dumps({"command": "get_cursor"})) try: return pickle.loads(self.client_sock.recv(4096)) except: return (0, 0) def get_file_info(self): - self.client_sock.send(pickle.dumps({'command': 'get_file_info'})) + self.client_sock.send(pickle.dumps({"command": "get_file_info"})) try: return pickle.loads(self.client_sock.recv(4096)) except: @@ -390,46 +408,46 @@ class RPEditor: break def execute_command(self, command): - cmd = command.get('command') - if cmd == 'insert_text': - self._insert_text(command['text']) - elif cmd == 'delete_char': + cmd = command.get("command") + if cmd == "insert_text": + self._insert_text(command["text"]) + elif cmd == "delete_char": self._delete_char() - elif cmd == 'save_file': + elif cmd == "save_file": self._save_file() - elif cmd == 'set_text': - self._set_text(command['text']) - elif cmd == 'goto_line': - self._goto_line(command['line_num']) - elif cmd == 'get_text': - result = '\n'.join(self.lines) + elif cmd == "set_text": + self._set_text(command["text"]) + elif cmd == "goto_line": + self._goto_line(command["line_num"]) + elif cmd == "get_text": + result = "\n".join(self.lines) try: self.server_sock.send(pickle.dumps(result)) except: pass - elif cmd == 'get_cursor': + elif cmd == "get_cursor": result = (self.cursor_y, self.cursor_x) try: self.server_sock.send(pickle.dumps(result)) except: pass - elif cmd == 'get_file_info': + elif cmd == "get_file_info": result = { - 'filename': self.filename, - 'lines': len(self.lines), - 'cursor': (self.cursor_y, self.cursor_x), - 'mode': self.mode + "filename": self.filename, + "lines": len(self.lines), + "cursor": (self.cursor_y, self.cursor_x), + "mode": self.mode, } try: self.server_sock.send(pickle.dumps(result)) except: pass - elif cmd == 'stop': + elif cmd == "stop": self.running = False def move_cursor_to(self, y, x): with self.lock: - self.cursor_y = max(0, min(y, len(self.lines)-1)) + self.cursor_y = max(0, min(y, len(self.lines) - 1)) self.cursor_x = max(0, min(x, len(self.lines[self.cursor_y]))) def get_line(self, line_num): @@ -469,9 +487,9 @@ class RPEditor: else: first_part = self.lines[start_line][:start_col] last_part = self.lines[end_line][end_col:] - new_lines = new_text.split('\n') + new_lines = new_text.split("\n") self.lines[start_line] = first_part + new_lines[0] - del self.lines[start_line + 1:end_line + 1] + del self.lines[start_line + 1 : end_line + 1] for i, new_line in enumerate(new_lines[1:], 1): self.lines.insert(start_line + i, new_line) if len(new_lines) > 1: @@ -511,7 +529,7 @@ class RPEditor: for i in range(sl + 1, el): result.append(self.lines[i]) result.append(self.lines[el][:ec]) - return '\n'.join(result) + return "\n".join(result) def delete_selection(self): with self.lock: @@ -537,24 +555,24 @@ class RPEditor: break if match: indent = len(self.lines[i]) - len(self.lines[i].lstrip()) - indented_replace = [' ' * indent + line for line in replace_lines] - self.lines[i:i+len(search_lines)] = indented_replace + indented_replace = [" " * indent + line for line in replace_lines] + self.lines[i : i + len(search_lines)] = indented_replace return True return False def apply_diff(self, diff_text): with self.lock: self.save_state() - lines = diff_text.split('\n') + lines = diff_text.split("\n") for line in lines: - if line.startswith('@@'): - match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line) + if line.startswith("@@"): + match = re.search(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@", line) if match: start_line = int(match.group(1)) - 1 - elif line.startswith('-'): + elif line.startswith("-"): if start_line < len(self.lines): del self.lines[start_line] - elif line.startswith('+'): + elif line.startswith("+"): self.lines.insert(start_line, line[1:]) start_line += 1 @@ -574,6 +592,7 @@ class RPEditor: if self.thread: self.thread.join() + def main(): filename = sys.argv[1] if len(sys.argv) > 1 else None editor = RPEditor(filename) @@ -583,5 +602,3 @@ def main(): if __name__ == "__main__": main() - - diff --git a/pr/input_handler.py b/pr/input_handler.py index 07452f5..96f2b67 100644 --- a/pr/input_handler.py +++ b/pr/input_handler.py @@ -3,14 +3,13 @@ Advanced input handler for PR Assistant with editor mode, file inclusion, and image support. """ -import os -import re import base64 import mimetypes +import re import readline -import glob from pathlib import Path from typing import Optional + # from pr.ui.colors import Colors # Avoid import issues @@ -29,7 +28,7 @@ class AdvancedInputHandler: return None readline.set_completer(completer) - readline.parse_and_bind('tab: complete') + readline.parse_and_bind("tab: complete") except: pass # Readline not available @@ -60,7 +59,7 @@ class AdvancedInputHandler: return "" # Check for special commands - if user_input.lower() == '/editor': + if user_input.lower() == "/editor": self.toggle_editor_mode() return self.get_input(prompt) # Recurse to get new input @@ -74,23 +73,25 @@ class AdvancedInputHandler: def _get_editor_input(self, prompt: str) -> Optional[str]: """Get multi-line input for editor mode.""" try: - print("Editor mode: Enter your message. Type 'END' on a new line to finish.") + print( + "Editor mode: Enter your message. Type 'END' on a new line to finish." + ) print("Type '/simple' to switch back to simple mode.") lines = [] while True: try: line = input() - if line.strip().lower() == 'end': + if line.strip().lower() == "end": break - elif line.strip().lower() == '/simple': + elif line.strip().lower() == "/simple": self.toggle_editor_mode() return self.get_input(prompt) # Switch back and get input lines.append(line) except EOFError: break - content = '\n'.join(lines).strip() + content = "\n".join(lines).strip() if not content: return "" @@ -114,12 +115,13 @@ class AdvancedInputHandler: def _process_file_inclusions(self, text: str) -> str: """Replace @[filename] with file contents.""" + def replace_file(match): filename = match.group(1).strip() try: path = Path(filename).expanduser().resolve() if path.exists() and path.is_file(): - with open(path, 'r', encoding='utf-8', errors='replace') as f: + with open(path, encoding="utf-8", errors="replace") as f: content = f.read() return f"\n--- File: {filename} ---\n{content}\n--- End of {filename} ---\n" else: @@ -128,7 +130,7 @@ class AdvancedInputHandler: return f"[Error reading file {filename}: {e}]" # Replace @[filename] patterns - pattern = r'@\[([^\]]+)\]' + pattern = r"@\[([^\]]+)\]" return re.sub(pattern, replace_file, text) def _process_image_inclusions(self, text: str) -> str: @@ -143,20 +145,22 @@ class AdvancedInputHandler: path = Path(word.strip()).expanduser().resolve() if path.exists() and path.is_file(): mime_type, _ = mimetypes.guess_type(str(path)) - if mime_type and mime_type.startswith('image/'): + if mime_type and mime_type.startswith("image/"): # Encode image - with open(path, 'rb') as f: - image_data = base64.b64encode(f.read()).decode('utf-8') + with open(path, "rb") as f: + image_data = base64.b64encode(f.read()).decode("utf-8") # Replace with data URL - processed_parts.append(f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n") + processed_parts.append( + f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n" + ) continue except: pass processed_parts.append(word) - return ' '.join(processed_parts) + return " ".join(processed_parts) # Global instance diff --git a/pr/memory/__init__.py b/pr/memory/__init__.py index 8489932..ee5c13f 100644 --- a/pr/memory/__init__.py +++ b/pr/memory/__init__.py @@ -1,7 +1,12 @@ -from .knowledge_store import KnowledgeStore, KnowledgeEntry -from .semantic_index import SemanticIndex from .conversation_memory import ConversationMemory from .fact_extractor import FactExtractor +from .knowledge_store import KnowledgeEntry, KnowledgeStore +from .semantic_index import SemanticIndex -__all__ = ['KnowledgeStore', 'KnowledgeEntry', 'SemanticIndex', - 'ConversationMemory', 'FactExtractor'] +__all__ = [ + "KnowledgeStore", + "KnowledgeEntry", + "SemanticIndex", + "ConversationMemory", + "FactExtractor", +] diff --git a/pr/memory/conversation_memory.py b/pr/memory/conversation_memory.py index cb4b701..58e66fa 100644 --- a/pr/memory/conversation_memory.py +++ b/pr/memory/conversation_memory.py @@ -1,7 +1,8 @@ import json import sqlite3 import time -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional + class ConversationMemory: def __init__(self, db_path: str): @@ -12,7 +13,8 @@ class ConversationMemory: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS conversation_history ( conversation_id TEXT PRIMARY KEY, session_id TEXT, @@ -23,9 +25,11 @@ class ConversationMemory: topics TEXT, metadata TEXT ) - ''') + """ + ) - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS conversation_messages ( message_id TEXT PRIMARY KEY, conversation_id TEXT NOT NULL, @@ -36,117 +40,163 @@ class ConversationMemory: metadata TEXT, FOREIGN KEY (conversation_id) REFERENCES conversation_history(conversation_id) ) - ''') + """ + ) - cursor.execute(''' + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_conv_session ON conversation_history(session_id) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_conv_started ON conversation_history(started_at DESC) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_msg_conversation ON conversation_messages(conversation_id) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_msg_timestamp ON conversation_messages(timestamp) - ''') + """ + ) conn.commit() conn.close() - def create_conversation(self, conversation_id: str, session_id: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None): + def create_conversation( + self, + conversation_id: str, + session_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ): conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ INSERT INTO conversation_history (conversation_id, session_id, started_at, metadata) VALUES (?, ?, ?, ?) - ''', ( - conversation_id, - session_id, - time.time(), - json.dumps(metadata) if metadata else None - )) + """, + ( + conversation_id, + session_id, + time.time(), + json.dumps(metadata) if metadata else None, + ), + ) conn.commit() conn.close() - def add_message(self, conversation_id: str, message_id: str, role: str, - content: str, tool_calls: Optional[List[Dict[str, Any]]] = None, - metadata: Optional[Dict[str, Any]] = None): + def add_message( + self, + conversation_id: str, + message_id: str, + role: str, + content: str, + tool_calls: Optional[List[Dict[str, Any]]] = None, + metadata: Optional[Dict[str, Any]] = None, + ): conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ INSERT INTO conversation_messages (message_id, conversation_id, role, content, timestamp, tool_calls, metadata) VALUES (?, ?, ?, ?, ?, ?, ?) - ''', ( - message_id, - conversation_id, - role, - content, - time.time(), - json.dumps(tool_calls) if tool_calls else None, - json.dumps(metadata) if metadata else None - )) + """, + ( + message_id, + conversation_id, + role, + content, + time.time(), + json.dumps(tool_calls) if tool_calls else None, + json.dumps(metadata) if metadata else None, + ), + ) - cursor.execute(''' + cursor.execute( + """ UPDATE conversation_history SET message_count = message_count + 1 WHERE conversation_id = ? - ''', (conversation_id,)) + """, + (conversation_id,), + ) conn.commit() conn.close() - def get_conversation_messages(self, conversation_id: str, - limit: Optional[int] = None) -> List[Dict[str, Any]]: + def get_conversation_messages( + self, conversation_id: str, limit: Optional[int] = None + ) -> List[Dict[str, Any]]: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() if limit: - cursor.execute(''' + cursor.execute( + """ SELECT message_id, role, content, timestamp, tool_calls, metadata FROM conversation_messages WHERE conversation_id = ? ORDER BY timestamp DESC LIMIT ? - ''', (conversation_id, limit)) + """, + (conversation_id, limit), + ) else: - cursor.execute(''' + cursor.execute( + """ SELECT message_id, role, content, timestamp, tool_calls, metadata FROM conversation_messages WHERE conversation_id = ? ORDER BY timestamp ASC - ''', (conversation_id,)) + """, + (conversation_id,), + ) messages = [] for row in cursor.fetchall(): - messages.append({ - 'message_id': row[0], - 'role': row[1], - 'content': row[2], - 'timestamp': row[3], - 'tool_calls': json.loads(row[4]) if row[4] else None, - 'metadata': json.loads(row[5]) if row[5] else None - }) + messages.append( + { + "message_id": row[0], + "role": row[1], + "content": row[2], + "timestamp": row[3], + "tool_calls": json.loads(row[4]) if row[4] else None, + "metadata": json.loads(row[5]) if row[5] else None, + } + ) conn.close() return messages - def update_conversation_summary(self, conversation_id: str, summary: str, - topics: Optional[List[str]] = None): + def update_conversation_summary( + self, conversation_id: str, summary: str, topics: Optional[List[str]] = None + ): conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ UPDATE conversation_history SET summary = ?, topics = ?, ended_at = ? WHERE conversation_id = ? - ''', (summary, json.dumps(topics) if topics else None, time.time(), conversation_id)) + """, + ( + summary, + json.dumps(topics) if topics else None, + time.time(), + conversation_id, + ), + ) conn.commit() conn.close() @@ -155,7 +205,8 @@ class ConversationMemory: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ SELECT DISTINCT h.conversation_id, h.session_id, h.started_at, h.message_count, h.summary, h.topics FROM conversation_history h @@ -163,56 +214,69 @@ class ConversationMemory: WHERE h.summary LIKE ? OR h.topics LIKE ? OR m.content LIKE ? ORDER BY h.started_at DESC LIMIT ? - ''', (f'%{query}%', f'%{query}%', f'%{query}%', limit)) + """, + (f"%{query}%", f"%{query}%", f"%{query}%", limit), + ) conversations = [] for row in cursor.fetchall(): - conversations.append({ - 'conversation_id': row[0], - 'session_id': row[1], - 'started_at': row[2], - 'message_count': row[3], - 'summary': row[4], - 'topics': json.loads(row[5]) if row[5] else [] - }) + conversations.append( + { + "conversation_id": row[0], + "session_id": row[1], + "started_at": row[2], + "message_count": row[3], + "summary": row[4], + "topics": json.loads(row[5]) if row[5] else [], + } + ) conn.close() return conversations - def get_recent_conversations(self, limit: int = 10, - session_id: Optional[str] = None) -> List[Dict[str, Any]]: + def get_recent_conversations( + self, limit: int = 10, session_id: Optional[str] = None + ) -> List[Dict[str, Any]]: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() if session_id: - cursor.execute(''' + cursor.execute( + """ SELECT conversation_id, session_id, started_at, ended_at, message_count, summary, topics FROM conversation_history WHERE session_id = ? ORDER BY started_at DESC LIMIT ? - ''', (session_id, limit)) + """, + (session_id, limit), + ) else: - cursor.execute(''' + cursor.execute( + """ SELECT conversation_id, session_id, started_at, ended_at, message_count, summary, topics FROM conversation_history ORDER BY started_at DESC LIMIT ? - ''', (limit,)) + """, + (limit,), + ) conversations = [] for row in cursor.fetchall(): - conversations.append({ - 'conversation_id': row[0], - 'session_id': row[1], - 'started_at': row[2], - 'ended_at': row[3], - 'message_count': row[4], - 'summary': row[5], - 'topics': json.loads(row[6]) if row[6] else [] - }) + conversations.append( + { + "conversation_id": row[0], + "session_id": row[1], + "started_at": row[2], + "ended_at": row[3], + "message_count": row[4], + "summary": row[5], + "topics": json.loads(row[6]) if row[6] else [], + } + ) conn.close() return conversations @@ -221,10 +285,14 @@ class ConversationMemory: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('DELETE FROM conversation_messages WHERE conversation_id = ?', - (conversation_id,)) - cursor.execute('DELETE FROM conversation_history WHERE conversation_id = ?', - (conversation_id,)) + cursor.execute( + "DELETE FROM conversation_messages WHERE conversation_id = ?", + (conversation_id,), + ) + cursor.execute( + "DELETE FROM conversation_history WHERE conversation_id = ?", + (conversation_id,), + ) deleted = cursor.rowcount > 0 conn.commit() @@ -236,24 +304,26 @@ class ConversationMemory: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('SELECT COUNT(*) FROM conversation_history') + cursor.execute("SELECT COUNT(*) FROM conversation_history") total_conversations = cursor.fetchone()[0] - cursor.execute('SELECT COUNT(*) FROM conversation_messages') + cursor.execute("SELECT COUNT(*) FROM conversation_messages") total_messages = cursor.fetchone()[0] - cursor.execute('SELECT SUM(message_count) FROM conversation_history') - total_message_count = cursor.fetchone()[0] or 0 + cursor.execute("SELECT SUM(message_count) FROM conversation_history") + cursor.fetchone()[0] or 0 - cursor.execute(''' + cursor.execute( + """ SELECT AVG(message_count) FROM conversation_history WHERE message_count > 0 - ''') + """ + ) avg_messages = cursor.fetchone()[0] or 0 conn.close() return { - 'total_conversations': total_conversations, - 'total_messages': total_messages, - 'average_messages_per_conversation': round(avg_messages, 2) + "total_conversations": total_conversations, + "total_messages": total_messages, + "average_messages_per_conversation": round(avg_messages, 2), } diff --git a/pr/memory/fact_extractor.py b/pr/memory/fact_extractor.py index 72a8299..7fec5eb 100644 --- a/pr/memory/fact_extractor.py +++ b/pr/memory/fact_extractor.py @@ -1,16 +1,16 @@ import re -import json -from typing import List, Dict, Any, Set from collections import defaultdict +from typing import Any, Dict, List + class FactExtractor: def __init__(self): self.fact_patterns = [ - (r'([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)', 'definition'), - (r'([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})', 'temporal'), - (r'([A-Z][a-z]+) (invented|created|developed) ([^.]+)', 'attribution'), - (r'([^.]+) (costs?|worth) (\$[\d,]+)', 'numeric'), - (r'([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)', 'location'), + (r"([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)", "definition"), + (r"([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})", "temporal"), + (r"([A-Z][a-z]+) (invented|created|developed) ([^.]+)", "attribution"), + (r"([^.]+) (costs?|worth) (\$[\d,]+)", "numeric"), + (r"([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)", "location"), ] def extract_facts(self, text: str) -> List[Dict[str, Any]]: @@ -19,27 +19,31 @@ class FactExtractor: for pattern, fact_type in self.fact_patterns: matches = re.finditer(pattern, text) for match in matches: - facts.append({ - 'type': fact_type, - 'text': match.group(0), - 'components': match.groups(), - 'confidence': 0.7 - }) + facts.append( + { + "type": fact_type, + "text": match.group(0), + "components": match.groups(), + "confidence": 0.7, + } + ) noun_phrases = self._extract_noun_phrases(text) for phrase in noun_phrases: if len(phrase.split()) >= 2: - facts.append({ - 'type': 'entity', - 'text': phrase, - 'components': [phrase], - 'confidence': 0.5 - }) + facts.append( + { + "type": "entity", + "text": phrase, + "components": [phrase], + "confidence": 0.5, + } + ) return facts def _extract_noun_phrases(self, text: str) -> List[str]: - sentences = re.split(r'[.!?]', text) + sentences = re.split(r"[.!?]", text) phrases = [] for sentence in sentences: @@ -51,25 +55,73 @@ class FactExtractor: current_phrase.append(word) else: if len(current_phrase) >= 2: - phrases.append(' '.join(current_phrase)) + phrases.append(" ".join(current_phrase)) current_phrase = [] if len(current_phrase) >= 2: - phrases.append(' '.join(current_phrase)) + phrases.append(" ".join(current_phrase)) return list(set(phrases)) def extract_key_terms(self, text: str, top_k: int = 10) -> List[tuple]: - words = re.findall(r'\b[a-z]{4,}\b', text.lower()) + words = re.findall(r"\b[a-z]{4,}\b", text.lower()) stopwords = { - 'this', 'that', 'these', 'those', 'what', 'which', 'where', 'when', - 'with', 'from', 'have', 'been', 'were', 'will', 'would', 'could', - 'should', 'about', 'their', 'there', 'other', 'than', 'then', 'them', - 'some', 'more', 'very', 'such', 'into', 'through', 'during', 'before', - 'after', 'above', 'below', 'between', 'under', 'again', 'further', - 'once', 'here', 'both', 'each', 'doing', 'only', 'over', 'same', - 'being', 'does', 'just', 'also', 'make', 'made', 'know', 'like' + "this", + "that", + "these", + "those", + "what", + "which", + "where", + "when", + "with", + "from", + "have", + "been", + "were", + "will", + "would", + "could", + "should", + "about", + "their", + "there", + "other", + "than", + "then", + "them", + "some", + "more", + "very", + "such", + "into", + "through", + "during", + "before", + "after", + "above", + "below", + "between", + "under", + "again", + "further", + "once", + "here", + "both", + "each", + "doing", + "only", + "over", + "same", + "being", + "does", + "just", + "also", + "make", + "made", + "know", + "like", } filtered_words = [w for w in words if w not in stopwords] @@ -85,57 +137,120 @@ class FactExtractor: relationships = [] relationship_patterns = [ - (r'([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)', 'employment'), - (r'([A-Z][a-z]+) (owns|has|possesses) ([^.]+)', 'ownership'), - (r'([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)', 'location'), - (r'([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)', 'usage'), + ( + r"([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)", + "employment", + ), + (r"([A-Z][a-z]+) (owns|has|possesses) ([^.]+)", "ownership"), + ( + r"([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)", + "location", + ), + (r"([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)", "usage"), ] for pattern, rel_type in relationship_patterns: matches = re.finditer(pattern, text) for match in matches: - relationships.append({ - 'type': rel_type, - 'subject': match.group(1), - 'predicate': match.group(2), - 'object': match.group(3), - 'confidence': 0.6 - }) + relationships.append( + { + "type": rel_type, + "subject": match.group(1), + "predicate": match.group(2), + "object": match.group(3), + "confidence": 0.6, + } + ) return relationships def extract_metadata(self, text: str) -> Dict[str, Any]: word_count = len(text.split()) - sentence_count = len(re.split(r'[.!?]', text)) + sentence_count = len(re.split(r"[.!?]", text)) - urls = re.findall(r'https?://[^\s]+', text) - email_addresses = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text) - dates = re.findall(r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b', text) - numbers = re.findall(r'\b\d+(?:,\d{3})*(?:\.\d+)?\b', text) + urls = re.findall(r"https?://[^\s]+", text) + email_addresses = re.findall( + r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text + ) + dates = re.findall( + r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text + ) + numbers = re.findall(r"\b\d+(?:,\d{3})*(?:\.\d+)?\b", text) return { - 'word_count': word_count, - 'sentence_count': sentence_count, - 'avg_words_per_sentence': round(word_count / max(sentence_count, 1), 2), - 'urls': urls, - 'email_addresses': email_addresses, - 'dates': dates, - 'numeric_values': numbers, - 'has_code': bool(re.search(r'```|def |class |import |function ', text)), - 'has_questions': bool(re.search(r'\?', text)) + "word_count": word_count, + "sentence_count": sentence_count, + "avg_words_per_sentence": round(word_count / max(sentence_count, 1), 2), + "urls": urls, + "email_addresses": email_addresses, + "dates": dates, + "numeric_values": numbers, + "has_code": bool(re.search(r"```|def |class |import |function ", text)), + "has_questions": bool(re.search(r"\?", text)), } def categorize_content(self, text: str) -> List[str]: categories = [] category_keywords = { - 'programming': ['code', 'function', 'class', 'variable', 'programming', 'software', 'debug'], - 'data': ['data', 'database', 'query', 'table', 'record', 'statistics', 'analysis'], - 'documentation': ['documentation', 'guide', 'tutorial', 'manual', 'readme', 'explain'], - 'configuration': ['config', 'settings', 'configuration', 'setup', 'install', 'deployment'], - 'testing': ['test', 'testing', 'validate', 'verification', 'quality', 'assertion'], - 'research': ['research', 'study', 'analysis', 'investigation', 'findings', 'results'], - 'planning': ['plan', 'planning', 'schedule', 'roadmap', 'milestone', 'timeline'], + "programming": [ + "code", + "function", + "class", + "variable", + "programming", + "software", + "debug", + ], + "data": [ + "data", + "database", + "query", + "table", + "record", + "statistics", + "analysis", + ], + "documentation": [ + "documentation", + "guide", + "tutorial", + "manual", + "readme", + "explain", + ], + "configuration": [ + "config", + "settings", + "configuration", + "setup", + "install", + "deployment", + ], + "testing": [ + "test", + "testing", + "validate", + "verification", + "quality", + "assertion", + ], + "research": [ + "research", + "study", + "analysis", + "investigation", + "findings", + "results", + ], + "planning": [ + "plan", + "planning", + "schedule", + "roadmap", + "milestone", + "timeline", + ], } text_lower = text.lower() @@ -143,4 +258,4 @@ class FactExtractor: if any(keyword in text_lower for keyword in keywords): categories.append(category) - return categories if categories else ['general'] + return categories if categories else ["general"] diff --git a/pr/memory/knowledge_store.py b/pr/memory/knowledge_store.py index 9625a70..22191d5 100644 --- a/pr/memory/knowledge_store.py +++ b/pr/memory/knowledge_store.py @@ -1,10 +1,12 @@ import json import sqlite3 import time -from typing import List, Dict, Any, Optional from dataclasses import dataclass +from typing import Any, Dict, List, Optional + from .semantic_index import SemanticIndex + @dataclass class KnowledgeEntry: entry_id: str @@ -18,16 +20,17 @@ class KnowledgeEntry: def to_dict(self) -> Dict[str, Any]: return { - 'entry_id': self.entry_id, - 'category': self.category, - 'content': self.content, - 'metadata': self.metadata, - 'created_at': self.created_at, - 'updated_at': self.updated_at, - 'access_count': self.access_count, - 'importance_score': self.importance_score + "entry_id": self.entry_id, + "category": self.category, + "content": self.content, + "metadata": self.metadata, + "created_at": self.created_at, + "updated_at": self.updated_at, + "access_count": self.access_count, + "importance_score": self.importance_score, } + class KnowledgeStore: def __init__(self, db_path: str): self.db_path = db_path @@ -39,7 +42,8 @@ class KnowledgeStore: def _initialize_store(self): cursor = self.conn.cursor() - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS knowledge_entries ( entry_id TEXT PRIMARY KEY, category TEXT NOT NULL, @@ -50,44 +54,54 @@ class KnowledgeStore: access_count INTEGER DEFAULT 0, importance_score REAL DEFAULT 1.0 ) - ''') + """ + ) - cursor.execute(''' + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_category ON knowledge_entries(category) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_importance ON knowledge_entries(importance_score DESC) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC) - ''') + """ + ) self.conn.commit() def _load_index(self): cursor = self.conn.cursor() - cursor.execute('SELECT entry_id, content FROM knowledge_entries') + cursor.execute("SELECT entry_id, content FROM knowledge_entries") for row in cursor.fetchall(): self.semantic_index.add_document(row[0], row[1]) def add_entry(self, entry: KnowledgeEntry): cursor = self.conn.cursor() - cursor.execute(''' + cursor.execute( + """ INSERT OR REPLACE INTO knowledge_entries (entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - entry.entry_id, - entry.category, - entry.content, - json.dumps(entry.metadata), - entry.created_at, - entry.updated_at, - entry.access_count, - entry.importance_score - )) + """, + ( + entry.entry_id, + entry.category, + entry.content, + json.dumps(entry.metadata), + entry.created_at, + entry.updated_at, + entry.access_count, + entry.importance_score, + ), + ) self.conn.commit() @@ -96,20 +110,26 @@ class KnowledgeStore: def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]: cursor = self.conn.cursor() - cursor.execute(''' + cursor.execute( + """ SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score FROM knowledge_entries WHERE entry_id = ? - ''', (entry_id,)) + """, + (entry_id,), + ) row = cursor.fetchone() if row: - cursor.execute(''' + cursor.execute( + """ UPDATE knowledge_entries SET access_count = access_count + 1 WHERE entry_id = ? - ''', (entry_id,)) + """, + (entry_id,), + ) self.conn.commit() return KnowledgeEntry( @@ -120,13 +140,14 @@ class KnowledgeStore: created_at=row[4], updated_at=row[5], access_count=row[6] + 1, - importance_score=row[7] + importance_score=row[7], ) return None - def search_entries(self, query: str, category: Optional[str] = None, - top_k: int = 5) -> List[KnowledgeEntry]: + def search_entries( + self, query: str, category: Optional[str] = None, top_k: int = 5 + ) -> List[KnowledgeEntry]: search_results = self.semantic_index.search(query, top_k * 2) cursor = self.conn.cursor() @@ -134,17 +155,23 @@ class KnowledgeStore: entries = [] for entry_id, score in search_results: if category: - cursor.execute(''' + cursor.execute( + """ SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score FROM knowledge_entries WHERE entry_id = ? AND category = ? - ''', (entry_id, category)) + """, + (entry_id, category), + ) else: - cursor.execute(''' + cursor.execute( + """ SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score FROM knowledge_entries WHERE entry_id = ? - ''', (entry_id,)) + """, + (entry_id,), + ) row = cursor.fetchone() if row: @@ -156,7 +183,7 @@ class KnowledgeStore: created_at=row[4], updated_at=row[5], access_count=row[6], - importance_score=row[7] + importance_score=row[7], ) entries.append(entry) @@ -168,44 +195,52 @@ class KnowledgeStore: def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]: cursor = self.conn.cursor() - cursor.execute(''' + cursor.execute( + """ SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score FROM knowledge_entries WHERE category = ? ORDER BY importance_score DESC, created_at DESC LIMIT ? - ''', (category, limit)) + """, + (category, limit), + ) entries = [] for row in cursor.fetchall(): - entries.append(KnowledgeEntry( - entry_id=row[0], - category=row[1], - content=row[2], - metadata=json.loads(row[3]) if row[3] else {}, - created_at=row[4], - updated_at=row[5], - access_count=row[6], - importance_score=row[7] - )) + entries.append( + KnowledgeEntry( + entry_id=row[0], + category=row[1], + content=row[2], + metadata=json.loads(row[3]) if row[3] else {}, + created_at=row[4], + updated_at=row[5], + access_count=row[6], + importance_score=row[7], + ) + ) return entries def update_importance(self, entry_id: str, importance_score: float): cursor = self.conn.cursor() - cursor.execute(''' + cursor.execute( + """ UPDATE knowledge_entries SET importance_score = ?, updated_at = ? WHERE entry_id = ? - ''', (importance_score, time.time(), entry_id)) + """, + (importance_score, time.time(), entry_id), + ) self.conn.commit() def delete_entry(self, entry_id: str) -> bool: cursor = self.conn.cursor() - cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,)) + cursor.execute("DELETE FROM knowledge_entries WHERE entry_id = ?", (entry_id,)) deleted = cursor.rowcount > 0 self.conn.commit() @@ -218,27 +253,29 @@ class KnowledgeStore: def get_statistics(self) -> Dict[str, Any]: cursor = self.conn.cursor() - cursor.execute('SELECT COUNT(*) FROM knowledge_entries') + cursor.execute("SELECT COUNT(*) FROM knowledge_entries") total_entries = cursor.fetchone()[0] - cursor.execute('SELECT COUNT(DISTINCT category) FROM knowledge_entries') + cursor.execute("SELECT COUNT(DISTINCT category) FROM knowledge_entries") total_categories = cursor.fetchone()[0] - cursor.execute(''' + cursor.execute( + """ SELECT category, COUNT(*) as count FROM knowledge_entries GROUP BY category ORDER BY count DESC - ''') + """ + ) category_counts = {row[0]: row[1] for row in cursor.fetchall()} - cursor.execute('SELECT SUM(access_count) FROM knowledge_entries') + cursor.execute("SELECT SUM(access_count) FROM knowledge_entries") total_accesses = cursor.fetchone()[0] or 0 return { - 'total_entries': total_entries, - 'total_categories': total_categories, - 'category_distribution': category_counts, - 'total_accesses': total_accesses, - 'vocabulary_size': len(self.semantic_index.vocabulary) + "total_entries": total_entries, + "total_categories": total_categories, + "category_distribution": category_counts, + "total_accesses": total_accesses, + "vocabulary_size": len(self.semantic_index.vocabulary), } diff --git a/pr/memory/semantic_index.py b/pr/memory/semantic_index.py index 93d4d97..db9534c 100644 --- a/pr/memory/semantic_index.py +++ b/pr/memory/semantic_index.py @@ -1,7 +1,8 @@ import math import re from collections import Counter, defaultdict -from typing import List, Dict, Tuple, Set +from typing import Dict, List, Set, Tuple + class SemanticIndex: def __init__(self): @@ -12,7 +13,7 @@ class SemanticIndex: def _tokenize(self, text: str) -> List[str]: text = text.lower() - text = re.sub(r'[^a-z0-9\s]', ' ', text) + text = re.sub(r"[^a-z0-9\s]", " ", text) tokens = text.split() return tokens @@ -78,8 +79,12 @@ class SemanticIndex: scores.sort(key=lambda x: x[1], reverse=True) return scores[:top_k] - def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float: - dot_product = sum(vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2)) + def _cosine_similarity( + self, vec1: Dict[str, float], vec2: Dict[str, float] + ) -> float: + dot_product = sum( + vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2) + ) norm1 = math.sqrt(sum(val**2 for val in vec1.values())) norm2 = math.sqrt(sum(val**2 for val in vec2.values())) if norm1 == 0 or norm2 == 0: diff --git a/pr/multiplexer.py b/pr/multiplexer.py index 7f50afa..e924631 100644 --- a/pr/multiplexer.py +++ b/pr/multiplexer.py @@ -1,14 +1,13 @@ -import threading import queue -import time -import sys import subprocess -import signal -import os -from pr.ui import Colors -from collections import defaultdict -from pr.tools.process_handlers import get_handler_for_process, detect_process_type +import sys +import threading +import time + +from pr.tools.process_handlers import detect_process_type, get_handler_for_process from pr.tools.prompt_detection import get_global_detector +from pr.ui import Colors + class TerminalMultiplexer: def __init__(self, name, show_output=True): @@ -21,17 +20,19 @@ class TerminalMultiplexer: self.active = True self.lock = threading.Lock() self.metadata = { - 'start_time': time.time(), - 'last_activity': time.time(), - 'interaction_count': 0, - 'process_type': 'unknown', - 'state': 'active' + "start_time": time.time(), + "last_activity": time.time(), + "interaction_count": 0, + "process_type": "unknown", + "state": "active", } self.handler = None self.prompt_detector = get_global_detector() if self.show_output: - self.display_thread = threading.Thread(target=self._display_worker, daemon=True) + self.display_thread = threading.Thread( + target=self._display_worker, daemon=True + ) self.display_thread.start() def _display_worker(self): @@ -47,7 +48,9 @@ class TerminalMultiplexer: try: line = self.stderr_queue.get(timeout=0.1) if line: - sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}") + sys.stderr.write( + f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}" + ) sys.stderr.flush() except queue.Empty: pass @@ -55,40 +58,44 @@ class TerminalMultiplexer: def write_stdout(self, data): with self.lock: self.stdout_buffer.append(data) - self.metadata['last_activity'] = time.time() + self.metadata["last_activity"] = time.time() # Update handler state if available if self.handler: self.handler.update_state(data) # Update prompt detector - self.prompt_detector.update_session_state(self.name, data, self.metadata['process_type']) + self.prompt_detector.update_session_state( + self.name, data, self.metadata["process_type"] + ) if self.show_output: self.stdout_queue.put(data) def write_stderr(self, data): with self.lock: self.stderr_buffer.append(data) - self.metadata['last_activity'] = time.time() + self.metadata["last_activity"] = time.time() # Update handler state if available if self.handler: self.handler.update_state(data) # Update prompt detector - self.prompt_detector.update_session_state(self.name, data, self.metadata['process_type']) + self.prompt_detector.update_session_state( + self.name, data, self.metadata["process_type"] + ) if self.show_output: self.stderr_queue.put(data) def get_stdout(self): with self.lock: - return ''.join(self.stdout_buffer) + return "".join(self.stdout_buffer) def get_stderr(self): with self.lock: - return ''.join(self.stderr_buffer) + return "".join(self.stderr_buffer) def get_all_output(self): with self.lock: return { - 'stdout': ''.join(self.stdout_buffer), - 'stderr': ''.join(self.stderr_buffer) + "stdout": "".join(self.stdout_buffer), + "stderr": "".join(self.stderr_buffer), } def get_metadata(self): @@ -102,31 +109,32 @@ class TerminalMultiplexer: def set_process_type(self, process_type): """Set the process type and initialize appropriate handler.""" with self.lock: - self.metadata['process_type'] = process_type + self.metadata["process_type"] = process_type self.handler = get_handler_for_process(process_type, self) def send_input(self, input_data): - if hasattr(self, 'process') and self.process.poll() is None: + if hasattr(self, "process") and self.process.poll() is None: try: - self.process.stdin.write(input_data + '\n') + self.process.stdin.write(input_data + "\n") self.process.stdin.flush() with self.lock: - self.metadata['last_activity'] = time.time() - self.metadata['interaction_count'] += 1 + self.metadata["last_activity"] = time.time() + self.metadata["interaction_count"] += 1 except Exception as e: self.write_stderr(f"Error sending input: {e}") else: # This will be implemented when we have a process attached # For now, just update activity with self.lock: - self.metadata['last_activity'] = time.time() - self.metadata['interaction_count'] += 1 + self.metadata["last_activity"] = time.time() + self.metadata["interaction_count"] += 1 def close(self): self.active = False - if hasattr(self, 'display_thread'): + if hasattr(self, "display_thread"): self.display_thread.join(timeout=1) + _multiplexers = {} _mux_counter = 0 _mux_lock = threading.Lock() @@ -134,6 +142,7 @@ _background_monitor = None _monitor_active = False _monitor_interval = 0.2 # 200ms + def create_multiplexer(name=None, show_output=True): global _mux_counter with _mux_lock: @@ -144,44 +153,50 @@ def create_multiplexer(name=None, show_output=True): _multiplexers[name] = mux return name, mux + def get_multiplexer(name): return _multiplexers.get(name) + def close_multiplexer(name): mux = _multiplexers.get(name) if mux: mux.close() del _multiplexers[name] + def get_all_multiplexer_states(): with _mux_lock: states = {} for name, mux in _multiplexers.items(): states[name] = { - 'metadata': mux.get_metadata(), - 'output_summary': { - 'stdout_lines': len(mux.stdout_buffer), - 'stderr_lines': len(mux.stderr_buffer) - } + "metadata": mux.get_metadata(), + "output_summary": { + "stdout_lines": len(mux.stdout_buffer), + "stderr_lines": len(mux.stderr_buffer), + }, } return states + def cleanup_all_multiplexers(): for mux in list(_multiplexers.values()): mux.close() _multiplexers.clear() + # Background process management _background_processes = {} _process_lock = threading.Lock() + class BackgroundProcess: def __init__(self, name, command): self.name = name self.command = command self.process = None self.multiplexer = None - self.status = 'starting' + self.status = "starting" self.start_time = time.time() self.end_time = None @@ -205,27 +220,27 @@ class BackgroundProcess: stderr=subprocess.PIPE, text=True, bufsize=1, - universal_newlines=True + universal_newlines=True, ) - self.status = 'running' + self.status = "running" # Start output monitoring threads threading.Thread(target=self._monitor_stdout, daemon=True).start() threading.Thread(target=self._monitor_stderr, daemon=True).start() - return {'status': 'success', 'pid': self.process.pid} + return {"status": "success", "pid": self.process.pid} except Exception as e: - self.status = 'error' - return {'status': 'error', 'error': str(e)} + self.status = "error" + return {"status": "error", "error": str(e)} def _monitor_stdout(self): """Monitor stdout from the process.""" try: - for line in iter(self.process.stdout.readline, ''): + for line in iter(self.process.stdout.readline, ""): if line: - self.multiplexer.write_stdout(line.rstrip('\n\r')) + self.multiplexer.write_stdout(line.rstrip("\n\r")) except Exception as e: self.write_stderr(f"Error reading stdout: {e}") finally: @@ -234,29 +249,33 @@ class BackgroundProcess: def _monitor_stderr(self): """Monitor stderr from the process.""" try: - for line in iter(self.process.stderr.readline, ''): + for line in iter(self.process.stderr.readline, ""): if line: - self.multiplexer.write_stderr(line.rstrip('\n\r')) + self.multiplexer.write_stderr(line.rstrip("\n\r")) except Exception as e: self.write_stderr(f"Error reading stderr: {e}") def _check_completion(self): """Check if process has completed.""" if self.process and self.process.poll() is not None: - self.status = 'completed' + self.status = "completed" self.end_time = time.time() def get_info(self): """Get process information.""" self._check_completion() return { - 'name': self.name, - 'command': self.command, - 'status': self.status, - 'pid': self.process.pid if self.process else None, - 'start_time': self.start_time, - 'end_time': self.end_time, - 'runtime': time.time() - self.start_time if not self.end_time else self.end_time - self.start_time + "name": self.name, + "command": self.command, + "status": self.status, + "pid": self.process.pid if self.process else None, + "start_time": self.start_time, + "end_time": self.end_time, + "runtime": ( + time.time() - self.start_time + if not self.end_time + else self.end_time - self.start_time + ), } def get_output(self, lines=None): @@ -265,8 +284,8 @@ class BackgroundProcess: return [] all_output = self.multiplexer.get_all_output() - stdout_lines = all_output['stdout'].split('\n') if all_output['stdout'] else [] - stderr_lines = all_output['stderr'].split('\n') if all_output['stderr'] else [] + stdout_lines = all_output["stdout"].split("\n") if all_output["stdout"] else [] + stderr_lines = all_output["stderr"].split("\n") if all_output["stderr"] else [] combined = stdout_lines + stderr_lines if lines: @@ -276,45 +295,47 @@ class BackgroundProcess: def send_input(self, input_text): """Send input to the process.""" - if self.process and self.status == 'running': + if self.process and self.status == "running": try: - self.process.stdin.write(input_text + '\n') + self.process.stdin.write(input_text + "\n") self.process.stdin.flush() - return {'status': 'success'} + return {"status": "success"} except Exception as e: - return {'status': 'error', 'error': str(e)} - return {'status': 'error', 'error': 'Process not running or no stdin'} + return {"status": "error", "error": str(e)} + return {"status": "error", "error": "Process not running or no stdin"} def kill(self): """Kill the process.""" - if self.process and self.status == 'running': + if self.process and self.status == "running": try: self.process.terminate() # Wait a bit for graceful termination time.sleep(0.1) if self.process.poll() is None: self.process.kill() - self.status = 'killed' + self.status = "killed" self.end_time = time.time() - return {'status': 'success'} + return {"status": "success"} except Exception as e: - return {'status': 'error', 'error': str(e)} - return {'status': 'error', 'error': 'Process not running'} + return {"status": "error", "error": str(e)} + return {"status": "error", "error": "Process not running"} + def start_background_process(name, command): """Start a background process.""" with _process_lock: if name in _background_processes: - return {'status': 'error', 'error': f'Process {name} already exists'} + return {"status": "error", "error": f"Process {name} already exists"} process = BackgroundProcess(name, command) result = process.start() - if result['status'] == 'success': + if result["status"] == "success": _background_processes[name] = process return result + def get_all_sessions(): """Get all background process sessions.""" with _process_lock: @@ -323,23 +344,31 @@ def get_all_sessions(): sessions[name] = process.get_info() return sessions + def get_session_info(name): """Get information about a specific session.""" with _process_lock: process = _background_processes.get(name) return process.get_info() if process else None + def get_session_output(name, lines=None): """Get output from a specific session.""" with _process_lock: process = _background_processes.get(name) return process.get_output(lines) if process else None + def send_input_to_session(name, input_text): """Send input to a background session.""" with _process_lock: process = _background_processes.get(name) - return process.send_input(input_text) if process else {'status': 'error', 'error': 'Session not found'} + return ( + process.send_input(input_text) + if process + else {"status": "error", "error": "Session not found"} + ) + def kill_session(name): """Kill a background session.""" @@ -347,7 +376,7 @@ def kill_session(name): process = _background_processes.get(name) if process: result = process.kill() - if result['status'] == 'success': + if result["status"] == "success": del _background_processes[name] return result - return {'status': 'error', 'error': 'Session not found'} + return {"status": "error", "error": "Session not found"} diff --git a/pr/plugins/loader.py b/pr/plugins/loader.py index 282bfd8..4c5babf 100644 --- a/pr/plugins/loader.py +++ b/pr/plugins/loader.py @@ -1,10 +1,11 @@ +import importlib.util import os import sys -import importlib.util -from typing import List, Dict, Callable, Any +from typing import Callable, Dict, List + from pr.core.logging import get_logger -logger = get_logger('plugins') +logger = get_logger("plugins") PLUGINS_DIR = os.path.expanduser("~/.pr/plugins") @@ -21,7 +22,7 @@ class PluginLoader: logger.info("No plugins directory found") return [] - plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith('.py')] + plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith(".py")] for plugin_file in plugin_files: try: @@ -44,16 +45,20 @@ class PluginLoader: sys.modules[plugin_name] = module spec.loader.exec_module(module) - if hasattr(module, 'register_tools'): + if hasattr(module, "register_tools"): tools = module.register_tools() if isinstance(tools, list): self.plugin_tools.extend(tools) self.loaded_plugins[plugin_name] = module logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)") else: - logger.warning(f"Plugin {plugin_name} register_tools() did not return a list") + logger.warning( + f"Plugin {plugin_name} register_tools() did not return a list" + ) else: - logger.warning(f"Plugin {plugin_name} does not have register_tools() function") + logger.warning( + f"Plugin {plugin_name} does not have register_tools() function" + ) def get_plugin_function(self, tool_name: str) -> Callable: for plugin_name, module in self.loaded_plugins.items(): @@ -67,7 +72,7 @@ class PluginLoader: def create_example_plugin(): - example_plugin = os.path.join(PLUGINS_DIR, 'example_plugin.py') + example_plugin = os.path.join(PLUGINS_DIR, "example_plugin.py") if os.path.exists(example_plugin): return @@ -121,7 +126,7 @@ def register_tools(): try: os.makedirs(PLUGINS_DIR, exist_ok=True) - with open(example_plugin, 'w') as f: + with open(example_plugin, "w") as f: f.write(example_code) logger.info(f"Created example plugin at {example_plugin}") except Exception as e: diff --git a/pr/tools/__init__.py b/pr/tools/__init__.py index 13b11de..4182945 100644 --- a/pr/tools/__init__.py +++ b/pr/tools/__init__.py @@ -1,25 +1,86 @@ -from pr.tools.base import get_tools_definition -from pr.tools.filesystem import ( - read_file, write_file, list_directory, mkdir, chdir, getpwd, index_source_directory, search_replace +from pr.tools.agents import ( + collaborate_agents, + create_agent, + execute_agent_task, + list_agents, + remove_agent, +) +from pr.tools.base import get_tools_definition +from pr.tools.command import ( + kill_process, + run_command, + run_command_interactive, + tail_process, +) +from pr.tools.database import db_get, db_query, db_set +from pr.tools.editor import ( + close_editor, + editor_insert_text, + editor_replace_text, + editor_search, + open_editor, +) +from pr.tools.filesystem import ( + chdir, + getpwd, + index_source_directory, + list_directory, + mkdir, + read_file, + search_replace, + write_file, +) +from pr.tools.memory import ( + add_knowledge_entry, + delete_knowledge_entry, + get_knowledge_by_category, + get_knowledge_entry, + get_knowledge_statistics, + search_knowledge, + update_knowledge_importance, ) -from pr.tools.command import run_command, run_command_interactive, tail_process, kill_process -from pr.tools.editor import open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor -from pr.tools.database import db_set, db_get, db_query -from pr.tools.web import http_fetch, web_search, web_search_news -from pr.tools.python_exec import python_exec from pr.tools.patch import apply_patch, create_diff -from pr.tools.agents import create_agent, list_agents, execute_agent_task, remove_agent, collaborate_agents -from pr.tools.memory import add_knowledge_entry, get_knowledge_entry, search_knowledge, get_knowledge_by_category, update_knowledge_importance, delete_knowledge_entry, get_knowledge_statistics +from pr.tools.python_exec import python_exec +from pr.tools.web import http_fetch, web_search, web_search_news __all__ = [ - 'get_tools_definition', - 'read_file', 'write_file', 'list_directory', 'mkdir', 'chdir', 'getpwd', 'index_source_directory', 'search_replace', - 'open_editor', 'editor_insert_text', 'editor_replace_text', 'editor_search','close_editor', - 'run_command', 'run_command_interactive', - 'db_set', 'db_get', 'db_query', - 'http_fetch', 'web_search', 'web_search_news', - 'python_exec','tail_process', 'kill_process', - 'apply_patch', 'create_diff', - 'create_agent', 'list_agents', 'execute_agent_task', 'remove_agent', 'collaborate_agents', - 'add_knowledge_entry', 'get_knowledge_entry', 'search_knowledge', 'get_knowledge_by_category', 'update_knowledge_importance', 'delete_knowledge_entry', 'get_knowledge_statistics' + "get_tools_definition", + "read_file", + "write_file", + "list_directory", + "mkdir", + "chdir", + "getpwd", + "index_source_directory", + "search_replace", + "open_editor", + "editor_insert_text", + "editor_replace_text", + "editor_search", + "close_editor", + "run_command", + "run_command_interactive", + "db_set", + "db_get", + "db_query", + "http_fetch", + "web_search", + "web_search_news", + "python_exec", + "tail_process", + "kill_process", + "apply_patch", + "create_diff", + "create_agent", + "list_agents", + "execute_agent_task", + "remove_agent", + "collaborate_agents", + "add_knowledge_entry", + "get_knowledge_entry", + "search_knowledge", + "get_knowledge_by_category", + "update_knowledge_importance", + "delete_knowledge_entry", + "get_knowledge_statistics", ] diff --git a/pr/tools/agents.py b/pr/tools/agents.py index 770c40e..1bc02c5 100644 --- a/pr/tools/agents.py +++ b/pr/tools/agents.py @@ -1,13 +1,15 @@ import os -from typing import Dict, Any, List +from typing import Any, Dict, List + from pr.agents.agent_manager import AgentManager from pr.core.api import call_api + def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]: """Create a new agent with the specified role.""" try: # Get db_path from environment or default - db_path = os.environ.get('ASSISTANT_DB_PATH', '~/.assistant_db.sqlite') + db_path = os.environ.get("ASSISTANT_DB_PATH", "~/.assistant_db.sqlite") db_path = os.path.expanduser(db_path) manager = AgentManager(db_path, call_api) @@ -16,47 +18,57 @@ def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]: except Exception as e: return {"status": "error", "error": str(e)} + def list_agents() -> Dict[str, Any]: """List all active agents.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") manager = AgentManager(db_path, call_api) agents = [] for agent_id, agent in manager.active_agents.items(): - agents.append({ - "agent_id": agent_id, - "role": agent.role.name, - "task_count": agent.task_count, - "message_count": len(agent.message_history) - }) + agents.append( + { + "agent_id": agent_id, + "role": agent.role.name, + "task_count": agent.task_count, + "message_count": len(agent.message_history), + } + ) return {"status": "success", "agents": agents} except Exception as e: return {"status": "error", "error": str(e)} -def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]: + +def execute_agent_task( + agent_id: str, task: str, context: Dict[str, Any] = None +) -> Dict[str, Any]: """Execute a task with the specified agent.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") manager = AgentManager(db_path, call_api) result = manager.execute_agent_task(agent_id, task, context) return result except Exception as e: return {"status": "error", "error": str(e)} + def remove_agent(agent_id: str) -> Dict[str, Any]: """Remove an agent.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") manager = AgentManager(db_path, call_api) success = manager.remove_agent(agent_id) return {"status": "success" if success else "not_found", "agent_id": agent_id} except Exception as e: return {"status": "error", "error": str(e)} -def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]: + +def collaborate_agents( + orchestrator_id: str, task: str, agent_roles: List[str] +) -> Dict[str, Any]: """Collaborate multiple agents on a task.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") manager = AgentManager(db_path, call_api) result = manager.collaborate_agents(orchestrator_id, task, agent_roles) return result diff --git a/pr/tools/base.py b/pr/tools/base.py index 2c50fab..73f613b 100644 --- a/pr/tools/base.py +++ b/pr/tools/base.py @@ -10,12 +10,12 @@ def get_tools_definition(): "properties": { "pid": { "type": "integer", - "description": "The process ID returned by run_command when status is 'running'." + "description": "The process ID returned by run_command when status is 'running'.", } }, - "required": ["pid"] - } - } + "required": ["pid"], + }, + }, }, { "type": "function", @@ -27,17 +27,17 @@ def get_tools_definition(): "properties": { "pid": { "type": "integer", - "description": "The process ID returned by run_command when status is 'running'." + "description": "The process ID returned by run_command when status is 'running'.", }, "timeout": { "type": "integer", "description": "Maximum seconds to wait for process completion. Returns partial output if still running.", - "default": 30 - } + "default": 30, + }, }, - "required": ["pid"] - } - } + "required": ["pid"], + }, + }, }, { "type": "function", @@ -48,11 +48,14 @@ def get_tools_definition(): "type": "object", "properties": { "url": {"type": "string", "description": "The URL to fetch"}, - "headers": {"type": "object", "description": "Optional HTTP headers"} + "headers": { + "type": "object", + "description": "Optional HTTP headers", + }, }, - "required": ["url"] - } - } + "required": ["url"], + }, + }, }, { "type": "function", @@ -62,12 +65,19 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "command": {"type": "string", "description": "The shell command to execute"}, - "timeout": {"type": "integer", "description": "Maximum seconds to wait for completion", "default": 30} + "command": { + "type": "string", + "description": "The shell command to execute", + }, + "timeout": { + "type": "integer", + "description": "Maximum seconds to wait for completion", + "default": 30, + }, }, - "required": ["command"] - } - } + "required": ["command"], + }, + }, }, { "type": "function", @@ -77,11 +87,14 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "command": {"type": "string", "description": "The interactive command to execute (e.g., vim, nano, top)"} + "command": { + "type": "string", + "description": "The interactive command to execute (e.g., vim, nano, top)", + } }, - "required": ["command"] - } - } + "required": ["command"], + }, + }, }, { "type": "function", @@ -91,12 +104,18 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "session_name": {"type": "string", "description": "The name of the session"}, - "input_data": {"type": "string", "description": "The input to send to the session"} + "session_name": { + "type": "string", + "description": "The name of the session", + }, + "input_data": { + "type": "string", + "description": "The input to send to the session", + }, }, - "required": ["session_name", "input_data"] - } - } + "required": ["session_name", "input_data"], + }, + }, }, { "type": "function", @@ -106,11 +125,14 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "session_name": {"type": "string", "description": "The name of the session"} + "session_name": { + "type": "string", + "description": "The name of the session", + } }, - "required": ["session_name"] - } - } + "required": ["session_name"], + }, + }, }, { "type": "function", @@ -120,11 +142,14 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "session_name": {"type": "string", "description": "The name of the session"} + "session_name": { + "type": "string", + "description": "The name of the session", + } }, - "required": ["session_name"] - } - } + "required": ["session_name"], + }, + }, }, { "type": "function", @@ -134,11 +159,14 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath": {"type": "string", "description": "Path to the file"} + "filepath": { + "type": "string", + "description": "Path to the file", + } }, - "required": ["filepath"] - } - } + "required": ["filepath"], + }, + }, }, { "type": "function", @@ -148,12 +176,18 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath": {"type": "string", "description": "Path to the file"}, - "content": {"type": "string", "description": "Content to write"} + "filepath": { + "type": "string", + "description": "Path to the file", + }, + "content": { + "type": "string", + "description": "Content to write", + }, }, - "required": ["filepath", "content"] - } - } + "required": ["filepath", "content"], + }, + }, }, { "type": "function", @@ -163,11 +197,19 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "path": {"type": "string", "description": "Directory path", "default": "."}, - "recursive": {"type": "boolean", "description": "List recursively", "default": False} - } - } - } + "path": { + "type": "string", + "description": "Directory path", + "default": ".", + }, + "recursive": { + "type": "boolean", + "description": "List recursively", + "default": False, + }, + }, + }, + }, }, { "type": "function", @@ -177,11 +219,14 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "path": {"type": "string", "description": "Path of the directory to create"} + "path": { + "type": "string", + "description": "Path of the directory to create", + } }, - "required": ["path"] - } - } + "required": ["path"], + }, + }, }, { "type": "function", @@ -193,17 +238,17 @@ def get_tools_definition(): "properties": { "path": {"type": "string", "description": "Path to change to"} }, - "required": ["path"] - } - } + "required": ["path"], + }, + }, }, { "type": "function", "function": { "name": "getpwd", "description": "Get the current working directory", - "parameters": {"type": "object", "properties": {}} - } + "parameters": {"type": "object", "properties": {}}, + }, }, { "type": "function", @@ -214,11 +259,11 @@ def get_tools_definition(): "type": "object", "properties": { "key": {"type": "string", "description": "The key"}, - "value": {"type": "string", "description": "The value"} + "value": {"type": "string", "description": "The value"}, }, - "required": ["key", "value"] - } - } + "required": ["key", "value"], + }, + }, }, { "type": "function", @@ -227,12 +272,10 @@ def get_tools_definition(): "description": "Get a value from the database", "parameters": { "type": "object", - "properties": { - "key": {"type": "string", "description": "The key"} - }, - "required": ["key"] - } - } + "properties": {"key": {"type": "string", "description": "The key"}}, + "required": ["key"], + }, + }, }, { "type": "function", @@ -244,9 +287,9 @@ def get_tools_definition(): "properties": { "query": {"type": "string", "description": "SQL query"} }, - "required": ["query"] - } - } + "required": ["query"], + }, + }, }, { "type": "function", @@ -258,9 +301,9 @@ def get_tools_definition(): "properties": { "query": {"type": "string", "description": "Search query"} }, - "required": ["query"] - } - } + "required": ["query"], + }, + }, }, { "type": "function", @@ -270,11 +313,14 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "query": {"type": "string", "description": "Search query for news"} + "query": { + "type": "string", + "description": "Search query for news", + } }, - "required": ["query"] - } - } + "required": ["query"], + }, + }, }, { "type": "function", @@ -284,11 +330,14 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "code": {"type": "string", "description": "Python code to execute"} + "code": { + "type": "string", + "description": "Python code to execute", + } }, - "required": ["code"] - } - } + "required": ["code"], + }, + }, }, { "type": "function", @@ -300,9 +349,9 @@ def get_tools_definition(): "properties": { "path": {"type": "string", "description": "Path to index"} }, - "required": ["path"] - } - } + "required": ["path"], + }, + }, }, { "type": "function", @@ -312,13 +361,22 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath": {"type": "string", "description": "Path to the file"}, - "old_string": {"type": "string", "description": "String to replace"}, - "new_string": {"type": "string", "description": "Replacement string"} + "filepath": { + "type": "string", + "description": "Path to the file", + }, + "old_string": { + "type": "string", + "description": "String to replace", + }, + "new_string": { + "type": "string", + "description": "Replacement string", + }, }, - "required": ["filepath", "old_string", "new_string"] - } - } + "required": ["filepath", "old_string", "new_string"], + }, + }, }, { "type": "function", @@ -328,12 +386,18 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath": {"type": "string", "description": "Path to the file to patch"}, - "patch_content": {"type": "string", "description": "The patch content as a string"} + "filepath": { + "type": "string", + "description": "Path to the file to patch", + }, + "patch_content": { + "type": "string", + "description": "The patch content as a string", + }, }, - "required": ["filepath", "patch_content"] - } - } + "required": ["filepath", "patch_content"], + }, + }, }, { "type": "function", @@ -343,14 +407,28 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "file1": {"type": "string", "description": "Path to the first file"}, - "file2": {"type": "string", "description": "Path to the second file"}, - "fromfile": {"type": "string", "description": "Label for the first file", "default": "file1"}, - "tofile": {"type": "string", "description": "Label for the second file", "default": "file2"} + "file1": { + "type": "string", + "description": "Path to the first file", + }, + "file2": { + "type": "string", + "description": "Path to the second file", + }, + "fromfile": { + "type": "string", + "description": "Label for the first file", + "default": "file1", + }, + "tofile": { + "type": "string", + "description": "Label for the second file", + "default": "file2", + }, }, - "required": ["file1", "file2"] - } - } + "required": ["file1", "file2"], + }, + }, }, { "type": "function", @@ -360,11 +438,14 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath": {"type": "string", "description": "Path to the file"} + "filepath": { + "type": "string", + "description": "Path to the file", + } }, - "required": ["filepath"] - } - } + "required": ["filepath"], + }, + }, }, { "type": "function", @@ -374,11 +455,14 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath": {"type": "string", "description": "Path to the file"} + "filepath": { + "type": "string", + "description": "Path to the file", + } }, - "required": ["filepath"] - } - } + "required": ["filepath"], + }, + }, }, { "type": "function", @@ -388,14 +472,23 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath": {"type": "string", "description": "Path to the file"}, + "filepath": { + "type": "string", + "description": "Path to the file", + }, "text": {"type": "string", "description": "Text to insert"}, - "line": {"type": "integer", "description": "Line number (optional)"}, - "col": {"type": "integer", "description": "Column number (optional)"} + "line": { + "type": "integer", + "description": "Line number (optional)", + }, + "col": { + "type": "integer", + "description": "Column number (optional)", + }, }, - "required": ["filepath", "text"] - } - } + "required": ["filepath", "text"], + }, + }, }, { "type": "function", @@ -405,16 +498,26 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath": {"type": "string", "description": "Path to the file"}, + "filepath": { + "type": "string", + "description": "Path to the file", + }, "start_line": {"type": "integer", "description": "Start line"}, "start_col": {"type": "integer", "description": "Start column"}, "end_line": {"type": "integer", "description": "End line"}, "end_col": {"type": "integer", "description": "End column"}, - "new_text": {"type": "string", "description": "New text"} + "new_text": {"type": "string", "description": "New text"}, }, - "required": ["filepath", "start_line", "start_col", "end_line", "end_col", "new_text"] - } - } + "required": [ + "filepath", + "start_line", + "start_col", + "end_line", + "end_col", + "new_text", + ], + }, + }, }, { "type": "function", @@ -424,13 +527,20 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath": {"type": "string", "description": "Path to the file"}, + "filepath": { + "type": "string", + "description": "Path to the file", + }, "pattern": {"type": "string", "description": "Regex pattern"}, - "start_line": {"type": "integer", "description": "Start line", "default": 0} + "start_line": { + "type": "integer", + "description": "Start line", + "default": 0, + }, }, - "required": ["filepath", "pattern"] - } - } + "required": ["filepath", "pattern"], + }, + }, }, { "type": "function", @@ -440,24 +550,31 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "filepath1": {"type": "string", "description": "Path to the original file"}, - "filepath2": {"type": "string", "description": "Path to the modified file"}, - "format_type": {"type": "string", "description": "Display format: 'unified' or 'side-by-side'", "default": "unified"} + "filepath1": { + "type": "string", + "description": "Path to the original file", + }, + "filepath2": { + "type": "string", + "description": "Path to the modified file", + }, + "format_type": { + "type": "string", + "description": "Display format: 'unified' or 'side-by-side'", + "default": "unified", + }, }, - "required": ["filepath1", "filepath2"] - } - } + "required": ["filepath1", "filepath2"], + }, + }, }, { "type": "function", "function": { "name": "display_edit_summary", "description": "Display a summary of all edit operations performed during the session", - "parameters": { - "type": "object", - "properties": {} - } - } + "parameters": {"type": "object", "properties": {}}, + }, }, { "type": "function", @@ -467,21 +584,21 @@ def get_tools_definition(): "parameters": { "type": "object", "properties": { - "show_content": {"type": "boolean", "description": "Show content previews", "default": False} - } - } - } + "show_content": { + "type": "boolean", + "description": "Show content previews", + "default": False, + } + }, + }, + }, }, { "type": "function", "function": { "name": "clear_edit_tracker", "description": "Clear the edit tracker to start fresh", - "parameters": { - "type": "object", - "properties": {} - } - } - } + "parameters": {"type": "object", "properties": {}}, + }, + }, ] - diff --git a/pr/tools/command.py b/pr/tools/command.py index 237e551..0e0def1 100644 --- a/pr/tools/command.py +++ b/pr/tools/command.py @@ -1,21 +1,23 @@ import os +import select import subprocess import time -import select -from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer -from pr.tools.interactive_control import start_interactive_session -from pr.config import MAX_CONCURRENT_SESSIONS + +from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer _processes = {} -def _register_process(pid:int, process): + +def _register_process(pid: int, process): _processes[pid] = process return _processes -def _get_process(pid:int): + +def _get_process(pid: int): return _processes.get(pid) -def kill_process(pid:int): + +def kill_process(pid: int): try: process = _get_process(pid) if process: @@ -67,7 +69,7 @@ def tail_process(pid: int, timeout: int = 30): "status": "success", "stdout": stdout_content, "stderr": stderr_content, - "returncode": process.returncode + "returncode": process.returncode, } if time.time() - start_time > timeout_duration: @@ -76,10 +78,12 @@ def tail_process(pid: int, timeout: int = 30): "message": "Process is still running. Call tail_process again to continue monitoring.", "stdout_so_far": stdout_content, "stderr_so_far": stderr_content, - "pid": pid + "pid": pid, } - ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1) + ready, _, _ = select.select( + [process.stdout, process.stderr], [], [], 0.1 + ) for pipe in ready: if pipe == process.stdout: line = process.stdout.readline() @@ -100,7 +104,13 @@ def tail_process(pid: int, timeout: int = 30): def run_command(command, timeout=30, monitored=False): mux_name = None try: - process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) _register_process(process.pid, process) mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True) @@ -129,7 +139,7 @@ def run_command(command, timeout=30, monitored=False): "status": "success", "stdout": stdout_content, "stderr": stderr_content, - "returncode": process.returncode + "returncode": process.returncode, } if time.time() - start_time > timeout_duration: @@ -139,7 +149,7 @@ def run_command(command, timeout=30, monitored=False): "stdout_so_far": stdout_content, "stderr_so_far": stderr_content, "pid": process.pid, - "mux_name": mux_name + "mux_name": mux_name, } ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1) @@ -158,6 +168,8 @@ def run_command(command, timeout=30, monitored=False): if mux_name: close_multiplexer(mux_name) return {"status": "error", "error": str(e)} + + def run_command_interactive(command): try: return_code = os.system(command) diff --git a/pr/tools/database.py b/pr/tools/database.py index 0a7c092..6de4d57 100644 --- a/pr/tools/database.py +++ b/pr/tools/database.py @@ -1,18 +1,23 @@ import time + def db_set(key, value, db_conn): if not db_conn: return {"status": "error", "error": "Database not initialized"} try: cursor = db_conn.cursor() - cursor.execute("""INSERT OR REPLACE INTO kv_store (key, value, timestamp) - VALUES (?, ?, ?)""", (key, value, time.time())) + cursor.execute( + """INSERT OR REPLACE INTO kv_store (key, value, timestamp) + VALUES (?, ?, ?)""", + (key, value, time.time()), + ) db_conn.commit() return {"status": "success", "message": f"Set {key}"} except Exception as e: return {"status": "error", "error": str(e)} + def db_get(key, db_conn): if not db_conn: return {"status": "error", "error": "Database not initialized"} @@ -28,6 +33,7 @@ def db_get(key, db_conn): except Exception as e: return {"status": "error", "error": str(e)} + def db_query(query, db_conn): if not db_conn: return {"status": "error", "error": "Database not initialized"} @@ -36,9 +42,11 @@ def db_query(query, db_conn): cursor = db_conn.cursor() cursor.execute(query) - if query.strip().upper().startswith('SELECT'): + if query.strip().upper().startswith("SELECT"): results = cursor.fetchall() - columns = [desc[0] for desc in cursor.description] if cursor.description else [] + columns = ( + [desc[0] for desc in cursor.description] if cursor.description else [] + ) return {"status": "success", "columns": columns, "rows": results} else: db_conn.commit() diff --git a/pr/tools/editor.py b/pr/tools/editor.py index d7387f5..dad2112 100644 --- a/pr/tools/editor.py +++ b/pr/tools/editor.py @@ -1,18 +1,21 @@ -from pr.editor import RPEditor -from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer -from ..ui.diff_display import display_diff, get_diff_stats -from ..ui.edit_feedback import track_edit, tracker -from ..tools.patch import display_content_diff import os import os.path +from pr.editor import RPEditor +from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer + +from ..tools.patch import display_content_diff +from ..ui.edit_feedback import track_edit, tracker + _editors = {} + def get_editor(filepath): if filepath not in _editors: _editors[filepath] = RPEditor(filepath) return _editors[filepath] + def close_editor(filepath): try: path = os.path.expanduser(filepath) @@ -29,6 +32,7 @@ def close_editor(filepath): except Exception as e: return {"status": "error", "error": str(e)} + def open_editor(filepath): try: path = os.path.expanduser(filepath) @@ -39,21 +43,28 @@ def open_editor(filepath): mux_name, mux = create_multiplexer(mux_name, show_output=True) mux.write_stdout(f"Opened editor for: {path}\n") - return {"status": "success", "message": f"Editor opened for {path}", "mux_name": mux_name} + return { + "status": "success", + "message": f"Editor opened for {path}", + "mux_name": mux_name, + } except Exception as e: return {"status": "error", "error": str(e)} + def editor_insert_text(filepath, text, line=None, col=None, show_diff=True): try: path = os.path.expanduser(filepath) old_content = "" if os.path.exists(path): - with open(path, 'r') as f: + with open(path) as f: old_content = f.read() - position = (line if line is not None else 0) * 1000 + (col if col is not None else 0) - operation = track_edit('INSERT', filepath, start_pos=position, content=text) + position = (line if line is not None else 0) * 1000 + ( + col if col is not None else 0 + ) + operation = track_edit("INSERT", filepath, start_pos=position, content=text) tracker.mark_in_progress(operation) editor = get_editor(path) @@ -65,12 +76,16 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True): mux_name = f"editor-{path}" mux = get_multiplexer(mux_name) if mux: - location = f" at line {line}, col {col}" if line is not None and col is not None else "" + location = ( + f" at line {line}, col {col}" + if line is not None and col is not None + else "" + ) preview = text[:50] + "..." if len(text) > 50 else text mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n") if show_diff and old_content: - with open(path, 'r') as f: + with open(path) as f: new_content = f.read() diff_result = display_content_diff(old_content, new_content, filepath) if diff_result["status"] == "success": @@ -81,23 +96,32 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True): close_editor(filepath) return result except Exception as e: - if 'operation' in locals(): + if "operation" in locals(): tracker.mark_failed(operation) return {"status": "error", "error": str(e)} -def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True): + +def editor_replace_text( + filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True +): try: path = os.path.expanduser(filepath) old_content = "" if os.path.exists(path): - with open(path, 'r') as f: + with open(path) as f: old_content = f.read() start_pos = start_line * 1000 + start_col end_pos = end_line * 1000 + end_col - operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos, - content=new_text, old_content=old_content) + operation = track_edit( + "REPLACE", + filepath, + start_pos=start_pos, + end_pos=end_pos, + content=new_text, + old_content=old_content, + ) tracker.mark_in_progress(operation) editor = get_editor(path) @@ -108,10 +132,12 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_ mux = get_multiplexer(mux_name) if mux: preview = new_text[:50] + "..." if len(new_text) > 50 else new_text - mux.write_stdout(f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n") + mux.write_stdout( + f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n" + ) if show_diff and old_content: - with open(path, 'r') as f: + with open(path) as f: new_content = f.read() diff_result = display_content_diff(old_content, new_content, filepath) if diff_result["status"] == "success": @@ -122,10 +148,11 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_ close_editor(filepath) return result except Exception as e: - if 'operation' in locals(): + if "operation" in locals(): tracker.mark_failed(operation) return {"status": "error", "error": str(e)} + def editor_search(filepath, pattern, start_line=0): try: path = os.path.expanduser(filepath) @@ -135,7 +162,9 @@ def editor_search(filepath, pattern, start_line=0): mux_name = f"editor-{path}" mux = get_multiplexer(mux_name) if mux: - mux.write_stdout(f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n") + mux.write_stdout( + f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n" + ) result = {"status": "success", "results": results} close_editor(filepath) diff --git a/pr/tools/filesystem.py b/pr/tools/filesystem.py index 868955b..a51f58c 100644 --- a/pr/tools/filesystem.py +++ b/pr/tools/filesystem.py @@ -1,31 +1,36 @@ -import os import hashlib +import os import time -from typing import Dict + from pr.editor import RPEditor -from ..ui.diff_display import display_diff, get_diff_stats -from ..ui.edit_feedback import track_edit, tracker + from ..tools.patch import display_content_diff +from ..ui.diff_display import get_diff_stats +from ..ui.edit_feedback import track_edit, tracker _id = 0 + def get_uid(): global _id _id += 3 return _id + def read_file(filepath, db_conn=None): try: path = os.path.expanduser(filepath) - with open(path, 'r') as f: + with open(path) as f: content = f.read() if db_conn: from pr.tools.database import db_set + db_set("read:" + path, "true", db_conn) return {"status": "success", "content": content} except Exception as e: return {"status": "error", "error": str(e)} + def write_file(filepath, content, db_conn=None, show_diff=True): try: path = os.path.expanduser(filepath) @@ -34,15 +39,24 @@ def write_file(filepath, content, db_conn=None, show_diff=True): if not is_new_file and db_conn: from pr.tools.database import db_get + read_status = db_get("read:" + path, db_conn) - if read_status.get("status") != "success" or read_status.get("value") != "true": - return {"status": "error", "error": "File must be read before writing. Please read the file first."} + if ( + read_status.get("status") != "success" + or read_status.get("value") != "true" + ): + return { + "status": "error", + "error": "File must be read before writing. Please read the file first.", + } if not is_new_file: - with open(path, 'r') as f: + with open(path) as f: old_content = f.read() - operation = track_edit('WRITE', filepath, content=content, old_content=old_content) + operation = track_edit( + "WRITE", filepath, content=content, old_content=old_content + ) tracker.mark_in_progress(operation) if show_diff and not is_new_file: @@ -59,13 +73,18 @@ def write_file(filepath, content, db_conn=None, show_diff=True): cursor = db_conn.cursor() file_hash = hashlib.md5(old_content.encode()).hexdigest() - cursor.execute("SELECT MAX(version) FROM file_versions WHERE filepath = ?", (filepath,)) + cursor.execute( + "SELECT MAX(version) FROM file_versions WHERE filepath = ?", + (filepath,), + ) result = cursor.fetchone() version = (result[0] + 1) if result[0] else 1 - cursor.execute("""INSERT INTO file_versions (filepath, content, hash, timestamp, version) + cursor.execute( + """INSERT INTO file_versions (filepath, content, hash, timestamp, version) VALUES (?, ?, ?, ?, ?)""", - (filepath, old_content, file_hash, time.time(), version)) + (filepath, old_content, file_hash, time.time(), version), + ) db_conn.commit() except Exception: pass @@ -79,10 +98,11 @@ def write_file(filepath, content, db_conn=None, show_diff=True): return {"status": "success", "message": message} except Exception as e: - if 'operation' in locals(): + if "operation" in locals(): tracker.mark_failed(operation) return {"status": "error", "error": str(e)} + def list_directory(path=".", recursive=False): try: path = os.path.expanduser(path) @@ -91,21 +111,36 @@ def list_directory(path=".", recursive=False): for root, dirs, files in os.walk(path): for name in files: item_path = os.path.join(root, name) - items.append({"path": item_path, "type": "file", "size": os.path.getsize(item_path)}) + items.append( + { + "path": item_path, + "type": "file", + "size": os.path.getsize(item_path), + } + ) for name in dirs: - items.append({"path": os.path.join(root, name), "type": "directory"}) + items.append( + {"path": os.path.join(root, name), "type": "directory"} + ) else: for item in os.listdir(path): item_path = os.path.join(path, item) - items.append({ - "name": item, - "type": "directory" if os.path.isdir(item_path) else "file", - "size": os.path.getsize(item_path) if os.path.isfile(item_path) else None - }) + items.append( + { + "name": item, + "type": "directory" if os.path.isdir(item_path) else "file", + "size": ( + os.path.getsize(item_path) + if os.path.isfile(item_path) + else None + ), + } + ) return {"status": "success", "items": items} except Exception as e: return {"status": "error", "error": str(e)} + def mkdir(path): try: os.makedirs(os.path.expanduser(path), exist_ok=True) @@ -113,6 +148,7 @@ def mkdir(path): except Exception as e: return {"status": "error", "error": str(e)} + def chdir(path): try: os.chdir(os.path.expanduser(path)) @@ -120,16 +156,32 @@ def chdir(path): except Exception as e: return {"status": "error", "error": str(e)} + def getpwd(): try: return {"status": "success", "path": os.getcwd()} except Exception as e: return {"status": "error", "error": str(e)} + def index_source_directory(path): extensions = [ - ".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp", - ".html", ".css", ".json", ".xml", ".md", ".sh", ".rb", ".go" + ".py", + ".js", + ".ts", + ".java", + ".cpp", + ".c", + ".h", + ".hpp", + ".html", + ".css", + ".json", + ".xml", + ".md", + ".sh", + ".rb", + ".go", ] source_files = [] try: @@ -138,18 +190,16 @@ def index_source_directory(path): if any(file.endswith(ext) for ext in extensions): filepath = os.path.join(root, file) try: - with open(filepath, 'r', encoding='utf-8') as f: + with open(filepath, encoding="utf-8") as f: content = f.read() - source_files.append({ - "path": filepath, - "content": content - }) + source_files.append({"path": filepath, "content": content}) except Exception: continue return {"status": "success", "indexed_files": source_files} except Exception as e: return {"status": "error", "error": str(e)} + def search_replace(filepath, old_string, new_string, db_conn=None): try: path = os.path.expanduser(filepath) @@ -157,25 +207,38 @@ def search_replace(filepath, old_string, new_string, db_conn=None): return {"status": "error", "error": "File does not exist"} if db_conn: from pr.tools.database import db_get + read_status = db_get("read:" + path, db_conn) - if read_status.get("status") != "success" or read_status.get("value") != "true": - return {"status": "error", "error": "File must be read before writing. Please read the file first."} - with open(path, 'r') as f: + if ( + read_status.get("status") != "success" + or read_status.get("value") != "true" + ): + return { + "status": "error", + "error": "File must be read before writing. Please read the file first.", + } + with open(path) as f: content = f.read() content = content.replace(old_string, new_string) - with open(path, 'w') as f: + with open(path, "w") as f: f.write(content) - return {"status": "success", "message": f"Replaced '{old_string}' with '{new_string}' in {path}"} + return { + "status": "success", + "message": f"Replaced '{old_string}' with '{new_string}' in {path}", + } except Exception as e: return {"status": "error", "error": str(e)} + _editors = {} + def get_editor(filepath): if filepath not in _editors: _editors[filepath] = RPEditor(filepath) return _editors[filepath] + def close_editor(filepath): try: path = os.path.expanduser(filepath) @@ -185,6 +248,7 @@ def close_editor(filepath): except Exception as e: return {"status": "error", "error": str(e)} + def open_editor(filepath): try: path = os.path.expanduser(filepath) @@ -194,22 +258,34 @@ def open_editor(filepath): except Exception as e: return {"status": "error", "error": str(e)} -def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None): + +def editor_insert_text( + filepath, text, line=None, col=None, show_diff=True, db_conn=None +): try: path = os.path.expanduser(filepath) if db_conn: from pr.tools.database import db_get + read_status = db_get("read:" + path, db_conn) - if read_status.get("status") != "success" or read_status.get("value") != "true": - return {"status": "error", "error": "File must be read before writing. Please read the file first."} + if ( + read_status.get("status") != "success" + or read_status.get("value") != "true" + ): + return { + "status": "error", + "error": "File must be read before writing. Please read the file first.", + } old_content = "" if os.path.exists(path): - with open(path, 'r') as f: + with open(path) as f: old_content = f.read() - position = (line if line is not None else 0) * 1000 + (col if col is not None else 0) - operation = track_edit('INSERT', filepath, start_pos=position, content=text) + position = (line if line is not None else 0) * 1000 + ( + col if col is not None else 0 + ) + operation = track_edit("INSERT", filepath, start_pos=position, content=text) tracker.mark_in_progress(operation) editor = get_editor(path) @@ -219,7 +295,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c editor.save_file() if show_diff and old_content: - with open(path, 'r') as f: + with open(path) as f: new_content = f.read() diff_result = display_content_diff(old_content, new_content, filepath) if diff_result["status"] == "success": @@ -228,28 +304,51 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c tracker.mark_completed(operation) return {"status": "success", "message": f"Inserted text in {path}"} except Exception as e: - if 'operation' in locals(): + if "operation" in locals(): tracker.mark_failed(operation) return {"status": "error", "error": str(e)} -def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True, db_conn=None): + +def editor_replace_text( + filepath, + start_line, + start_col, + end_line, + end_col, + new_text, + show_diff=True, + db_conn=None, +): try: path = os.path.expanduser(filepath) if db_conn: from pr.tools.database import db_get + read_status = db_get("read:" + path, db_conn) - if read_status.get("status") != "success" or read_status.get("value") != "true": - return {"status": "error", "error": "File must be read before writing. Please read the file first."} + if ( + read_status.get("status") != "success" + or read_status.get("value") != "true" + ): + return { + "status": "error", + "error": "File must be read before writing. Please read the file first.", + } old_content = "" if os.path.exists(path): - with open(path, 'r') as f: + with open(path) as f: old_content = f.read() start_pos = start_line * 1000 + start_col end_pos = end_line * 1000 + end_col - operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos, - content=new_text, old_content=old_content) + operation = track_edit( + "REPLACE", + filepath, + start_pos=start_pos, + end_pos=end_pos, + content=new_text, + old_content=old_content, + ) tracker.mark_in_progress(operation) editor = get_editor(path) @@ -257,7 +356,7 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_ editor.save_file() if show_diff and old_content: - with open(path, 'r') as f: + with open(path) as f: new_content = f.read() diff_result = display_content_diff(old_content, new_content, filepath) if diff_result["status"] == "success": @@ -266,22 +365,25 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_ tracker.mark_completed(operation) return {"status": "success", "message": f"Replaced text in {path}"} except Exception as e: - if 'operation' in locals(): + if "operation" in locals(): tracker.mark_failed(operation) return {"status": "error", "error": str(e)} def display_edit_summary(): from ..ui.edit_feedback import display_edit_summary + return display_edit_summary() def display_edit_timeline(show_content=False): from ..ui.edit_feedback import display_edit_timeline + return display_edit_timeline(show_content) def clear_edit_tracker(): from ..ui.edit_feedback import clear_tracker + clear_tracker() return {"status": "success", "message": "Edit tracker cleared"} diff --git a/pr/tools/interactive_control.py b/pr/tools/interactive_control.py index f11bebf..153ebc4 100644 --- a/pr/tools/interactive_control.py +++ b/pr/tools/interactive_control.py @@ -1,9 +1,15 @@ import subprocess import threading -import time -from pr.multiplexer import create_multiplexer, get_multiplexer, close_multiplexer, get_all_multiplexer_states -def start_interactive_session(command, session_name=None, process_type='generic'): +from pr.multiplexer import ( + close_multiplexer, + create_multiplexer, + get_all_multiplexer_states, + get_multiplexer, +) + + +def start_interactive_session(command, session_name=None, process_type="generic"): """ Start an interactive session in a dedicated multiplexer. @@ -16,7 +22,7 @@ def start_interactive_session(command, session_name=None, process_type='generic' session_name: The name of the created session """ name, mux = create_multiplexer(session_name) - mux.update_metadata('process_type', process_type) + mux.update_metadata("process_type", process_type) # Start the process if isinstance(command, str): @@ -29,19 +35,23 @@ def start_interactive_session(command, session_name=None, process_type='generic' stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - bufsize=1 + bufsize=1, ) mux.process = process - mux.update_metadata('pid', process.pid) + mux.update_metadata("pid", process.pid) # Set process type and handler detected_type = detect_process_type(command) mux.set_process_type(detected_type) # Start output readers - stdout_thread = threading.Thread(target=_read_output, args=(process.stdout, mux.write_stdout), daemon=True) - stderr_thread = threading.Thread(target=_read_output, args=(process.stderr, mux.write_stderr), daemon=True) + stdout_thread = threading.Thread( + target=_read_output, args=(process.stdout, mux.write_stdout), daemon=True + ) + stderr_thread = threading.Thread( + target=_read_output, args=(process.stderr, mux.write_stderr), daemon=True + ) stdout_thread.start() stderr_thread.start() @@ -54,15 +64,17 @@ def start_interactive_session(command, session_name=None, process_type='generic' close_multiplexer(name) raise e + def _read_output(stream, write_func): """Read from a stream and write to multiplexer buffer.""" try: - for line in iter(stream.readline, ''): + for line in iter(stream.readline, ""): if line: - write_func(line.rstrip('\n')) + write_func(line.rstrip("\n")) except Exception as e: print(f"Error reading output: {e}") + def send_input_to_session(session_name, input_data): """ Send input to an interactive session. @@ -75,15 +87,16 @@ def send_input_to_session(session_name, input_data): if not mux: raise ValueError(f"Session {session_name} not found") - if not hasattr(mux, 'process') or mux.process.poll() is not None: + if not hasattr(mux, "process") or mux.process.poll() is not None: raise ValueError(f"Session {session_name} is not active") try: - mux.process.stdin.write(input_data + '\n') + mux.process.stdin.write(input_data + "\n") mux.process.stdin.flush() except Exception as e: raise RuntimeError(f"Failed to send input to session {session_name}: {e}") + def read_session_output(session_name, lines=None): """ Read output from a session. @@ -102,14 +115,12 @@ def read_session_output(session_name, lines=None): output = mux.get_all_output() if lines is not None: # Return last N lines - stdout_lines = output['stdout'].split('\n')[-lines:] if output['stdout'] else [] - stderr_lines = output['stderr'].split('\n')[-lines:] if output['stderr'] else [] - output = { - 'stdout': '\n'.join(stdout_lines), - 'stderr': '\n'.join(stderr_lines) - } + stdout_lines = output["stdout"].split("\n")[-lines:] if output["stdout"] else [] + stderr_lines = output["stderr"].split("\n")[-lines:] if output["stderr"] else [] + output = {"stdout": "\n".join(stdout_lines), "stderr": "\n".join(stderr_lines)} return output + def list_active_sessions(): """ List all active interactive sessions. @@ -119,6 +130,7 @@ def list_active_sessions(): """ return get_all_multiplexer_states() + def get_session_status(session_name): """ Get detailed status of a session. @@ -134,15 +146,16 @@ def get_session_status(session_name): return None status = mux.get_metadata() - status['is_active'] = hasattr(mux, 'process') and mux.process.poll() is None - if status['is_active']: - status['pid'] = mux.process.pid - status['output_summary'] = { - 'stdout_lines': len(mux.stdout_buffer), - 'stderr_lines': len(mux.stderr_buffer) + status["is_active"] = hasattr(mux, "process") and mux.process.poll() is None + if status["is_active"]: + status["pid"] = mux.process.pid + status["output_summary"] = { + "stdout_lines": len(mux.stdout_buffer), + "stderr_lines": len(mux.stderr_buffer), } return status + def close_interactive_session(session_name): """ Close an interactive session. diff --git a/pr/tools/memory.py b/pr/tools/memory.py index 12b48b4..d169df1 100644 --- a/pr/tools/memory.py +++ b/pr/tools/memory.py @@ -1,38 +1,43 @@ import os -from typing import Dict, Any, List -from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry import time import uuid +from typing import Any, Dict -def add_knowledge_entry(category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None) -> Dict[str, Any]: +from pr.memory.knowledge_store import KnowledgeEntry, KnowledgeStore + + +def add_knowledge_entry( + category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None +) -> Dict[str, Any]: """Add a new entry to the knowledge base.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") store = KnowledgeStore(db_path) - + if entry_id is None: entry_id = str(uuid.uuid4())[:16] - + entry = KnowledgeEntry( entry_id=entry_id, category=category, content=content, metadata=metadata or {}, created_at=time.time(), - updated_at=time.time() + updated_at=time.time(), ) - + store.add_entry(entry) return {"status": "success", "entry_id": entry_id} except Exception as e: return {"status": "error", "error": str(e)} + def get_knowledge_entry(entry_id: str) -> Dict[str, Any]: """Retrieve a knowledge entry by ID.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") store = KnowledgeStore(db_path) - + entry = store.get_entry(entry_id) if entry: return {"status": "success", "entry": entry.to_dict()} @@ -41,58 +46,71 @@ def get_knowledge_entry(entry_id: str) -> Dict[str, Any]: except Exception as e: return {"status": "error", "error": str(e)} -def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]: + +def search_knowledge( + query: str, category: str = None, top_k: int = 5 +) -> Dict[str, Any]: """Search the knowledge base semantically.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") store = KnowledgeStore(db_path) - + entries = store.search_entries(query, category, top_k) results = [entry.to_dict() for entry in entries] return {"status": "success", "results": results} except Exception as e: return {"status": "error", "error": str(e)} + def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]: """Get knowledge entries by category.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") store = KnowledgeStore(db_path) - + entries = store.get_by_category(category, limit) results = [entry.to_dict() for entry in entries] return {"status": "success", "entries": results} except Exception as e: return {"status": "error", "error": str(e)} -def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]: + +def update_knowledge_importance( + entry_id: str, importance_score: float +) -> Dict[str, Any]: """Update the importance score of a knowledge entry.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") store = KnowledgeStore(db_path) - + store.update_importance(entry_id, importance_score) - return {"status": "success", "entry_id": entry_id, "importance_score": importance_score} + return { + "status": "success", + "entry_id": entry_id, + "importance_score": importance_score, + } except Exception as e: return {"status": "error", "error": str(e)} + def delete_knowledge_entry(entry_id: str) -> Dict[str, Any]: """Delete a knowledge entry.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") store = KnowledgeStore(db_path) - + success = store.delete_entry(entry_id) return {"status": "success" if success else "not_found", "entry_id": entry_id} except Exception as e: return {"status": "error", "error": str(e)} + def get_knowledge_statistics() -> Dict[str, Any]: """Get statistics about the knowledge base.""" try: - db_path = os.path.expanduser('~/.assistant_db.sqlite') + db_path = os.path.expanduser("~/.assistant_db.sqlite") store = KnowledgeStore(db_path) - + stats = store.get_statistics() return {"status": "success", "statistics": stats} except Exception as e: diff --git a/pr/tools/patch.py b/pr/tools/patch.py index 7e11b69..ab40ecc 100644 --- a/pr/tools/patch.py +++ b/pr/tools/patch.py @@ -1,23 +1,37 @@ -import os -import tempfile -import subprocess import difflib -from ..ui.diff_display import display_diff, get_diff_stats, DiffDisplay +import os +import subprocess +import tempfile + +from ..ui.diff_display import display_diff, get_diff_stats + def apply_patch(filepath, patch_content, db_conn=None): try: path = os.path.expanduser(filepath) if db_conn: from pr.tools.database import db_get + read_status = db_get("read:" + path, db_conn) - if read_status.get("status") != "success" or read_status.get("value") != "true": - return {"status": "error", "error": "File must be read before writing. Please read the file first."} + if ( + read_status.get("status") != "success" + or read_status.get("value") != "true" + ): + return { + "status": "error", + "error": "File must be read before writing. Please read the file first.", + } # Write patch to temp file - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.patch') as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".patch") as f: f.write(patch_content) patch_file = f.name # Run patch command - result = subprocess.run(['patch', path, patch_file], capture_output=True, text=True, cwd=os.path.dirname(path)) + result = subprocess.run( + ["patch", path, patch_file], + capture_output=True, + text=True, + cwd=os.path.dirname(path), + ) os.unlink(patch_file) if result.returncode == 0: return {"status": "success", "output": result.stdout.strip()} @@ -26,11 +40,14 @@ def apply_patch(filepath, patch_content, db_conn=None): except Exception as e: return {"status": "error", "error": str(e)} -def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, format_type='unified'): + +def create_diff( + file1, file2, fromfile="file1", tofile="file2", visual=False, format_type="unified" +): try: path1 = os.path.expanduser(file1) path2 = os.path.expanduser(file2) - with open(path1, 'r') as f1, open(path2, 'r') as f2: + with open(path1) as f1, open(path2) as f2: content1 = f1.read() content2 = f2.read() @@ -39,53 +56,51 @@ def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, fo stats = get_diff_stats(content1, content2) lines1 = content1.splitlines(keepends=True) lines2 = content2.splitlines(keepends=True) - plain_diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)) + plain_diff = list( + difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile) + ) return { "status": "success", - "diff": ''.join(plain_diff), + "diff": "".join(plain_diff), "visual_diff": visual_diff, - "stats": stats + "stats": stats, } else: lines1 = content1.splitlines(keepends=True) lines2 = content2.splitlines(keepends=True) - diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)) - return {"status": "success", "diff": ''.join(diff)} + diff = list( + difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile) + ) + return {"status": "success", "diff": "".join(diff)} except Exception as e: return {"status": "error", "error": str(e)} -def display_file_diff(filepath1, filepath2, format_type='unified', context_lines=3): +def display_file_diff(filepath1, filepath2, format_type="unified", context_lines=3): try: path1 = os.path.expanduser(filepath1) path2 = os.path.expanduser(filepath2) - with open(path1, 'r') as f1: + with open(path1) as f1: old_content = f1.read() - with open(path2, 'r') as f2: + with open(path2) as f2: new_content = f2.read() visual_diff = display_diff(old_content, new_content, filepath1, format_type) stats = get_diff_stats(old_content, new_content) - return { - "status": "success", - "visual_diff": visual_diff, - "stats": stats - } + return {"status": "success", "visual_diff": visual_diff, "stats": stats} except Exception as e: return {"status": "error", "error": str(e)} -def display_content_diff(old_content, new_content, filename='file', format_type='unified'): +def display_content_diff( + old_content, new_content, filename="file", format_type="unified" +): try: visual_diff = display_diff(old_content, new_content, filename, format_type) stats = get_diff_stats(old_content, new_content) - return { - "status": "success", - "visual_diff": visual_diff, - "stats": stats - } + return {"status": "success", "visual_diff": visual_diff, "stats": stats} except Exception as e: - return {"status": "error", "error": str(e)} \ No newline at end of file + return {"status": "error", "error": str(e)} diff --git a/pr/tools/process_handlers.py b/pr/tools/process_handlers.py index 639b12a..fa7d278 100644 --- a/pr/tools/process_handlers.py +++ b/pr/tools/process_handlers.py @@ -1,14 +1,13 @@ -import re -import time from abc import ABC, abstractmethod + class ProcessHandler(ABC): """Base class for process-specific handlers.""" def __init__(self, multiplexer): self.multiplexer = multiplexer self.state_machine = {} - self.current_state = 'initial' + self.current_state = "initial" self.prompt_patterns = [] self.response_suggestions = {} @@ -27,7 +26,8 @@ class ProcessHandler(ABC): def is_waiting_for_input(self): """Check if process appears to be waiting for input.""" - return self.current_state in ['waiting_confirmation', 'waiting_input'] + return self.current_state in ["waiting_confirmation", "waiting_input"] + class AptHandler(ProcessHandler): """Handler for apt package manager interactions.""" @@ -35,230 +35,238 @@ class AptHandler(ProcessHandler): def __init__(self, multiplexer): super().__init__(multiplexer) self.state_machine = { - 'initial': ['running_command'], - 'running_command': ['waiting_confirmation', 'completed'], - 'waiting_confirmation': ['confirmed', 'cancelled'], - 'confirmed': ['installing', 'completed'], - 'installing': ['completed', 'error'], - 'completed': [], - 'error': [], - 'cancelled': [] + "initial": ["running_command"], + "running_command": ["waiting_confirmation", "completed"], + "waiting_confirmation": ["confirmed", "cancelled"], + "confirmed": ["installing", "completed"], + "installing": ["completed", "error"], + "completed": [], + "error": [], + "cancelled": [], } self.prompt_patterns = [ - (r'Do you want to continue\?', 'confirmation'), - (r'After this operation.*installed\.', 'size_info'), - (r'Need to get.*B of archives\.', 'download_info'), - (r'Unpacking.*Configuring', 'configuring'), - (r'Setting up', 'setting_up'), - (r'E:\s', 'error') + (r"Do you want to continue\?", "confirmation"), + (r"After this operation.*installed\.", "size_info"), + (r"Need to get.*B of archives\.", "download_info"), + (r"Unpacking.*Configuring", "configuring"), + (r"Setting up", "setting_up"), + (r"E:\s", "error"), ] def get_process_type(self): - return 'apt' + return "apt" def update_state(self, output): """Update state based on apt output patterns.""" output_lower = output.lower() # Check for completion - if 'processing triggers' in output_lower or 'done' in output_lower: - self.current_state = 'completed' + if "processing triggers" in output_lower or "done" in output_lower: + self.current_state = "completed" # Check for confirmation prompts - elif 'do you want to continue' in output_lower: - self.current_state = 'waiting_confirmation' + elif "do you want to continue" in output_lower: + self.current_state = "waiting_confirmation" # Check for installation progress - elif 'setting up' in output_lower or 'unpacking' in output_lower: - self.current_state = 'installing' + elif "setting up" in output_lower or "unpacking" in output_lower: + self.current_state = "installing" # Check for errors - elif 'e:' in output_lower or 'error' in output_lower: - self.current_state = 'error' + elif "e:" in output_lower or "error" in output_lower: + self.current_state = "error" def get_prompt_suggestions(self): """Return suggested responses for apt prompts.""" suggestions = super().get_prompt_suggestions() - if self.current_state == 'waiting_confirmation': - suggestions.extend(['y', 'yes', 'n', 'no']) + if self.current_state == "waiting_confirmation": + suggestions.extend(["y", "yes", "n", "no"]) return suggestions + class VimHandler(ProcessHandler): """Handler for vim editor interactions.""" def __init__(self, multiplexer): super().__init__(multiplexer) self.state_machine = { - 'initial': ['normal_mode', 'insert_mode'], - 'normal_mode': ['insert_mode', 'command_mode', 'visual_mode'], - 'insert_mode': ['normal_mode'], - 'command_mode': ['normal_mode'], - 'visual_mode': ['normal_mode'], - 'exiting': [] + "initial": ["normal_mode", "insert_mode"], + "normal_mode": ["insert_mode", "command_mode", "visual_mode"], + "insert_mode": ["normal_mode"], + "command_mode": ["normal_mode"], + "visual_mode": ["normal_mode"], + "exiting": [], } self.prompt_patterns = [ - (r'-- INSERT --', 'insert_mode'), - (r'-- VISUAL --', 'visual_mode'), - (r':', 'command_mode'), - (r'Press ENTER', 'waiting_enter'), - (r'Saved', 'saved') + (r"-- INSERT --", "insert_mode"), + (r"-- VISUAL --", "visual_mode"), + (r":", "command_mode"), + (r"Press ENTER", "waiting_enter"), + (r"Saved", "saved"), ] self.mode_indicators = { - 'insert': '-- INSERT --', - 'visual': '-- VISUAL --', - 'command': ':' + "insert": "-- INSERT --", + "visual": "-- VISUAL --", + "command": ":", } def get_process_type(self): - return 'vim' + return "vim" def update_state(self, output): """Update state based on vim mode indicators.""" - if '-- INSERT --' in output: - self.current_state = 'insert_mode' - elif '-- VISUAL --' in output: - self.current_state = 'visual_mode' - elif output.strip().endswith(':'): - self.current_state = 'command_mode' - elif 'Press ENTER' in output: - self.current_state = 'waiting_enter' + if "-- INSERT --" in output: + self.current_state = "insert_mode" + elif "-- VISUAL --" in output: + self.current_state = "visual_mode" + elif output.strip().endswith(":"): + self.current_state = "command_mode" + elif "Press ENTER" in output: + self.current_state = "waiting_enter" else: # Default to normal mode if no specific indicators - self.current_state = 'normal_mode' + self.current_state = "normal_mode" def get_prompt_suggestions(self): """Return suggested commands for vim modes.""" suggestions = super().get_prompt_suggestions() - if self.current_state == 'command_mode': - suggestions.extend(['w', 'q', 'wq', 'q!', 'w!']) - elif self.current_state == 'normal_mode': - suggestions.extend(['i', 'a', 'o', 'dd', ':w', ':q']) - elif self.current_state == 'waiting_enter': - suggestions.extend(['\n']) + if self.current_state == "command_mode": + suggestions.extend(["w", "q", "wq", "q!", "w!"]) + elif self.current_state == "normal_mode": + suggestions.extend(["i", "a", "o", "dd", ":w", ":q"]) + elif self.current_state == "waiting_enter": + suggestions.extend(["\n"]) return suggestions + class SSHHandler(ProcessHandler): """Handler for SSH connection interactions.""" def __init__(self, multiplexer): super().__init__(multiplexer) self.state_machine = { - 'initial': ['connecting'], - 'connecting': ['auth_prompt', 'connected', 'failed'], - 'auth_prompt': ['connected', 'failed'], - 'connected': ['shell', 'disconnected'], - 'shell': ['disconnected'], - 'failed': [], - 'disconnected': [] + "initial": ["connecting"], + "connecting": ["auth_prompt", "connected", "failed"], + "auth_prompt": ["connected", "failed"], + "connected": ["shell", "disconnected"], + "shell": ["disconnected"], + "failed": [], + "disconnected": [], } self.prompt_patterns = [ - (r'password:', 'password_prompt'), - (r'yes/no', 'host_key_prompt'), - (r'Permission denied', 'auth_failed'), - (r'Welcome to', 'connected'), - (r'\$', 'shell_prompt'), - (r'\#', 'root_shell_prompt'), - (r'Connection closed', 'disconnected') + (r"password:", "password_prompt"), + (r"yes/no", "host_key_prompt"), + (r"Permission denied", "auth_failed"), + (r"Welcome to", "connected"), + (r"\$", "shell_prompt"), + (r"\#", "root_shell_prompt"), + (r"Connection closed", "disconnected"), ] def get_process_type(self): - return 'ssh' + return "ssh" def update_state(self, output): """Update state based on SSH connection output.""" output_lower = output.lower() - if 'permission denied' in output_lower: - self.current_state = 'failed' - elif 'password:' in output_lower: - self.current_state = 'auth_prompt' - elif 'yes/no' in output_lower: - self.current_state = 'auth_prompt' - elif 'welcome to' in output_lower or 'last login' in output_lower: - self.current_state = 'connected' - elif output.strip().endswith('$') or output.strip().endswith('#'): - self.current_state = 'shell' - elif 'connection closed' in output_lower: - self.current_state = 'disconnected' + if "permission denied" in output_lower: + self.current_state = "failed" + elif "password:" in output_lower: + self.current_state = "auth_prompt" + elif "yes/no" in output_lower: + self.current_state = "auth_prompt" + elif "welcome to" in output_lower or "last login" in output_lower: + self.current_state = "connected" + elif output.strip().endswith("$") or output.strip().endswith("#"): + self.current_state = "shell" + elif "connection closed" in output_lower: + self.current_state = "disconnected" def get_prompt_suggestions(self): """Return suggested responses for SSH prompts.""" suggestions = super().get_prompt_suggestions() - if self.current_state == 'auth_prompt': - if 'password:' in self.multiplexer.get_all_output()['stdout']: - suggestions.extend(['']) # Placeholder for actual password - elif 'yes/no' in self.multiplexer.get_all_output()['stdout']: - suggestions.extend(['yes', 'no']) + if self.current_state == "auth_prompt": + if "password:" in self.multiplexer.get_all_output()["stdout"]: + suggestions.extend([""]) # Placeholder for actual password + elif "yes/no" in self.multiplexer.get_all_output()["stdout"]: + suggestions.extend(["yes", "no"]) return suggestions + class GenericProcessHandler(ProcessHandler): """Fallback handler for unknown process types.""" def __init__(self, multiplexer): super().__init__(multiplexer) self.state_machine = { - 'initial': ['running'], - 'running': ['waiting_input', 'completed'], - 'waiting_input': ['running'], - 'completed': [] + "initial": ["running"], + "running": ["waiting_input", "completed"], + "waiting_input": ["running"], + "completed": [], } self.prompt_patterns = [ - (r'\?\s*$', 'waiting_input'), # Lines ending with ? - (r'>\s*$', 'waiting_input'), # Lines ending with > - (r':\s*$', 'waiting_input'), # Lines ending with : - (r'done', 'completed'), - (r'finished', 'completed'), - (r'exit code', 'completed') + (r"\?\s*$", "waiting_input"), # Lines ending with ? + (r">\s*$", "waiting_input"), # Lines ending with > + (r":\s*$", "waiting_input"), # Lines ending with : + (r"done", "completed"), + (r"finished", "completed"), + (r"exit code", "completed"), ] def get_process_type(self): - return 'generic' + return "generic" def update_state(self, output): """Basic state detection for generic processes.""" output_lower = output.lower() - if any(pattern in output_lower for pattern in ['done', 'finished', 'complete']): - self.current_state = 'completed' - elif any(output.strip().endswith(char) for char in ['?', '>', ':']): - self.current_state = 'waiting_input' + if any(pattern in output_lower for pattern in ["done", "finished", "complete"]): + self.current_state = "completed" + elif any(output.strip().endswith(char) for char in ["?", ">", ":"]): + self.current_state = "waiting_input" else: - self.current_state = 'running' + self.current_state = "running" + # Handler registry _handler_classes = { - 'apt': AptHandler, - 'vim': VimHandler, - 'ssh': SSHHandler, - 'generic': GenericProcessHandler + "apt": AptHandler, + "vim": VimHandler, + "ssh": SSHHandler, + "generic": GenericProcessHandler, } + def get_handler_for_process(process_type, multiplexer): """Get appropriate handler for a process type.""" handler_class = _handler_classes.get(process_type, GenericProcessHandler) return handler_class(multiplexer) + def detect_process_type(command): """Detect process type from command.""" - command_str = ' '.join(command) if isinstance(command, list) else command + command_str = " ".join(command) if isinstance(command, list) else command command_lower = command_str.lower() - if 'apt' in command_lower or 'apt-get' in command_lower: - return 'apt' - elif 'vim' in command_lower or 'vi ' in command_lower: - return 'vim' - elif 'ssh' in command_lower: - return 'ssh' + if "apt" in command_lower or "apt-get" in command_lower: + return "apt" + elif "vim" in command_lower or "vi " in command_lower: + return "vim" + elif "ssh" in command_lower: + return "ssh" else: - return 'generic' - return 'ssh' + return "generic" + return "ssh" + + def detect_process_type(command): """Detect process type from command.""" - command_str = ' '.join(command) if isinstance(command, list) else command + command_str = " ".join(command) if isinstance(command, list) else command command_lower = command_str.lower() - if 'apt' in command_lower or 'apt-get' in command_lower: - return 'apt' - elif 'vim' in command_lower or 'vi ' in command_lower: - return 'vim' - elif 'ssh' in command_lower: - return 'ssh' + if "apt" in command_lower or "apt-get" in command_lower: + return "apt" + elif "vim" in command_lower or "vi " in command_lower: + return "vim" + elif "ssh" in command_lower: + return "ssh" else: - return 'generic' + return "generic" diff --git a/pr/tools/prompt_detection.py b/pr/tools/prompt_detection.py index 7ad9077..022366e 100644 --- a/pr/tools/prompt_detection.py +++ b/pr/tools/prompt_detection.py @@ -1,6 +1,6 @@ import re import time -from collections import defaultdict + class PromptDetector: """Detects various process prompts and manages interaction state.""" @@ -10,101 +10,119 @@ class PromptDetector: self.state_machines = self._load_state_machines() self.session_states = {} self.timeout_configs = { - 'default': 30, # 30 seconds default timeout - 'apt': 300, # 5 minutes for apt operations - 'ssh': 60, # 1 minute for SSH connections - 'vim': 3600 # 1 hour for vim sessions + "default": 30, # 30 seconds default timeout + "apt": 300, # 5 minutes for apt operations + "ssh": 60, # 1 minute for SSH connections + "vim": 3600, # 1 hour for vim sessions } def _load_prompt_patterns(self): """Load regex patterns for detecting various prompts.""" return { - 'bash_prompt': [ - re.compile(r'[\w\-\.]+@[\w\-\.]+:.*[\$#]\s*$'), - re.compile(r'\$\s*$'), - re.compile(r'#\s*$'), - re.compile(r'>\s*$') # Continuation prompt + "bash_prompt": [ + re.compile(r"[\w\-\.]+@[\w\-\.]+:.*[\$#]\s*$"), + re.compile(r"\$\s*$"), + re.compile(r"#\s*$"), + re.compile(r">\s*$"), # Continuation prompt ], - 'confirmation': [ - re.compile(r'[Yy]/[Nn]', re.IGNORECASE), - re.compile(r'[Yy]es/[Nn]o', re.IGNORECASE), - re.compile(r'continue\?', re.IGNORECASE), - re.compile(r'proceed\?', re.IGNORECASE) + "confirmation": [ + re.compile(r"[Yy]/[Nn]", re.IGNORECASE), + re.compile(r"[Yy]es/[Nn]o", re.IGNORECASE), + re.compile(r"continue\?", re.IGNORECASE), + re.compile(r"proceed\?", re.IGNORECASE), ], - 'password': [ - re.compile(r'password:', re.IGNORECASE), - re.compile(r'passphrase:', re.IGNORECASE), - re.compile(r'enter password', re.IGNORECASE) + "password": [ + re.compile(r"password:", re.IGNORECASE), + re.compile(r"passphrase:", re.IGNORECASE), + re.compile(r"enter password", re.IGNORECASE), ], - 'sudo_password': [ - re.compile(r'\[sudo\].*password', re.IGNORECASE) + "sudo_password": [re.compile(r"\[sudo\].*password", re.IGNORECASE)], + "apt": [ + re.compile(r"Do you want to continue\?", re.IGNORECASE), + re.compile(r"After this operation", re.IGNORECASE), + re.compile(r"Need to get", re.IGNORECASE), ], - 'apt': [ - re.compile(r'Do you want to continue\?', re.IGNORECASE), - re.compile(r'After this operation', re.IGNORECASE), - re.compile(r'Need to get', re.IGNORECASE) + "vim": [ + re.compile(r"-- INSERT --"), + re.compile(r"-- VISUAL --"), + re.compile(r":"), + re.compile(r"Press ENTER", re.IGNORECASE), ], - 'vim': [ - re.compile(r'-- INSERT --'), - re.compile(r'-- VISUAL --'), - re.compile(r':'), - re.compile(r'Press ENTER', re.IGNORECASE) + "ssh": [ + re.compile(r"yes/no", re.IGNORECASE), + re.compile(r"password:", re.IGNORECASE), + re.compile(r"Permission denied", re.IGNORECASE), ], - 'ssh': [ - re.compile(r'yes/no', re.IGNORECASE), - re.compile(r'password:', re.IGNORECASE), - re.compile(r'Permission denied', re.IGNORECASE) + "git": [ + re.compile(r"Username:", re.IGNORECASE), + re.compile(r"Email:", re.IGNORECASE), ], - 'git': [ - re.compile(r'Username:', re.IGNORECASE), - re.compile(r'Email:', re.IGNORECASE) + "error": [ + re.compile(r"error:", re.IGNORECASE), + re.compile(r"failed", re.IGNORECASE), + re.compile(r"exception", re.IGNORECASE), ], - 'error': [ - re.compile(r'error:', re.IGNORECASE), - re.compile(r'failed', re.IGNORECASE), - re.compile(r'exception', re.IGNORECASE) - ] } def _load_state_machines(self): """Load state machines for different process types.""" return { - 'apt': { - 'states': ['initial', 'running', 'confirming', 'installing', 'completed', 'error'], - 'transitions': { - 'initial': ['running'], - 'running': ['confirming', 'installing', 'completed', 'error'], - 'confirming': ['installing', 'cancelled'], - 'installing': ['completed', 'error'], - 'completed': [], - 'error': [], - 'cancelled': [] - } + "apt": { + "states": [ + "initial", + "running", + "confirming", + "installing", + "completed", + "error", + ], + "transitions": { + "initial": ["running"], + "running": ["confirming", "installing", "completed", "error"], + "confirming": ["installing", "cancelled"], + "installing": ["completed", "error"], + "completed": [], + "error": [], + "cancelled": [], + }, }, - 'ssh': { - 'states': ['initial', 'connecting', 'authenticating', 'connected', 'error'], - 'transitions': { - 'initial': ['connecting'], - 'connecting': ['authenticating', 'connected', 'error'], - 'authenticating': ['connected', 'error'], - 'connected': ['error'], - 'error': [] - } + "ssh": { + "states": [ + "initial", + "connecting", + "authenticating", + "connected", + "error", + ], + "transitions": { + "initial": ["connecting"], + "connecting": ["authenticating", "connected", "error"], + "authenticating": ["connected", "error"], + "connected": ["error"], + "error": [], + }, + }, + "vim": { + "states": [ + "initial", + "normal", + "insert", + "visual", + "command", + "exiting", + ], + "transitions": { + "initial": ["normal", "insert"], + "normal": ["insert", "visual", "command", "exiting"], + "insert": ["normal"], + "visual": ["normal"], + "command": ["normal", "exiting"], + "exiting": [], + }, }, - 'vim': { - 'states': ['initial', 'normal', 'insert', 'visual', 'command', 'exiting'], - 'transitions': { - 'initial': ['normal', 'insert'], - 'normal': ['insert', 'visual', 'command', 'exiting'], - 'insert': ['normal'], - 'visual': ['normal'], - 'command': ['normal', 'exiting'], - 'exiting': [] - } - } } - def detect_prompt(self, output, process_type='generic'): + def detect_prompt(self, output, process_type="generic"): """Detect what type of prompt is present in the output.""" detections = {} @@ -125,93 +143,97 @@ class PromptDetector: return detections - def get_response_suggestions(self, prompt_detections, process_type='generic'): + def get_response_suggestions(self, prompt_detections, process_type="generic"): """Get suggested responses based on detected prompts.""" suggestions = [] for category, patterns in prompt_detections.items(): - if category == 'confirmation': - suggestions.extend(['y', 'yes', 'n', 'no']) - elif category == 'password': - suggestions.append('') - elif category == 'sudo_password': - suggestions.append('') - elif category == 'apt': - if any('continue' in p for p in patterns): - suggestions.extend(['y', 'yes']) - elif category == 'vim': - if any(':' in p for p in patterns): - suggestions.extend(['w', 'q', 'wq', 'q!']) - elif any('ENTER' in p for p in patterns): - suggestions.append('\n') - elif category == 'ssh': - if any('yes/no' in p for p in patterns): - suggestions.extend(['yes', 'no']) - elif any('password' in p for p in patterns): - suggestions.append('') - elif category == 'bash_prompt': - suggestions.extend(['help', 'ls', 'pwd', 'exit']) + if category == "confirmation": + suggestions.extend(["y", "yes", "n", "no"]) + elif category == "password": + suggestions.append("") + elif category == "sudo_password": + suggestions.append("") + elif category == "apt": + if any("continue" in p for p in patterns): + suggestions.extend(["y", "yes"]) + elif category == "vim": + if any(":" in p for p in patterns): + suggestions.extend(["w", "q", "wq", "q!"]) + elif any("ENTER" in p for p in patterns): + suggestions.append("\n") + elif category == "ssh": + if any("yes/no" in p for p in patterns): + suggestions.extend(["yes", "no"]) + elif any("password" in p for p in patterns): + suggestions.append("") + elif category == "bash_prompt": + suggestions.extend(["help", "ls", "pwd", "exit"]) return list(set(suggestions)) # Remove duplicates - def update_session_state(self, session_name, output, process_type='generic'): + def update_session_state(self, session_name, output, process_type="generic"): """Update the state machine for a session based on output.""" if session_name not in self.session_states: self.session_states[session_name] = { - 'current_state': 'initial', - 'process_type': process_type, - 'last_activity': time.time(), - 'transitions': [] + "current_state": "initial", + "process_type": process_type, + "last_activity": time.time(), + "transitions": [], } session_state = self.session_states[session_name] - old_state = session_state['current_state'] + old_state = session_state["current_state"] # Detect prompts and determine new state detections = self.detect_prompt(output, process_type) - new_state = self._determine_state_from_detections(detections, process_type, old_state) + new_state = self._determine_state_from_detections( + detections, process_type, old_state + ) if new_state != old_state: - session_state['transitions'].append({ - 'from': old_state, - 'to': new_state, - 'timestamp': time.time(), - 'trigger': detections - }) - session_state['current_state'] = new_state + session_state["transitions"].append( + { + "from": old_state, + "to": new_state, + "timestamp": time.time(), + "trigger": detections, + } + ) + session_state["current_state"] = new_state - session_state['last_activity'] = time.time() + session_state["last_activity"] = time.time() return new_state def _determine_state_from_detections(self, detections, process_type, current_state): """Determine new state based on prompt detections.""" if process_type in self.state_machines: - state_machine = self.state_machines[process_type] + self.state_machines[process_type] # State transition logic based on detections - if 'confirmation' in detections and current_state in ['running', 'initial']: - return 'confirming' - elif 'password' in detections or 'sudo_password' in detections: - return 'authenticating' - elif 'error' in detections: - return 'error' - elif 'bash_prompt' in detections and current_state != 'initial': - return 'connected' if process_type == 'ssh' else 'completed' - elif 'vim' in detections: - if any('-- INSERT --' in p for p in detections.get('vim', [])): - return 'insert' - elif any('-- VISUAL --' in p for p in detections.get('vim', [])): - return 'visual' - elif any(':' in p for p in detections.get('vim', [])): - return 'command' + if "confirmation" in detections and current_state in ["running", "initial"]: + return "confirming" + elif "password" in detections or "sudo_password" in detections: + return "authenticating" + elif "error" in detections: + return "error" + elif "bash_prompt" in detections and current_state != "initial": + return "connected" if process_type == "ssh" else "completed" + elif "vim" in detections: + if any("-- INSERT --" in p for p in detections.get("vim", [])): + return "insert" + elif any("-- VISUAL --" in p for p in detections.get("vim", [])): + return "visual" + elif any(":" in p for p in detections.get("vim", [])): + return "command" # Default state progression - if current_state == 'initial': - return 'running' - elif current_state == 'running' and detections: - return 'waiting_input' - elif current_state == 'waiting_input' and not detections: - return 'running' + if current_state == "initial": + return "running" + elif current_state == "running" and detections: + return "waiting_input" + elif current_state == "waiting_input" and not detections: + return "running" return current_state @@ -220,15 +242,15 @@ class PromptDetector: if session_name not in self.session_states: return False - state = self.session_states[session_name]['current_state'] - process_type = self.session_states[session_name]['process_type'] + state = self.session_states[session_name]["current_state"] + process_type = self.session_states[session_name]["process_type"] # States that typically indicate waiting for input waiting_states = { - 'generic': ['waiting_input'], - 'apt': ['confirming'], - 'ssh': ['authenticating'], - 'vim': ['command', 'insert', 'visual'] + "generic": ["waiting_input"], + "apt": ["confirming"], + "ssh": ["authenticating"], + "vim": ["command", "insert", "visual"], } return state in waiting_states.get(process_type, []) @@ -236,10 +258,10 @@ class PromptDetector: def get_session_timeout(self, session_name): """Get the timeout for a session based on its process type.""" if session_name not in self.session_states: - return self.timeout_configs['default'] + return self.timeout_configs["default"] - process_type = self.session_states[session_name]['process_type'] - return self.timeout_configs.get(process_type, self.timeout_configs['default']) + process_type = self.session_states[session_name]["process_type"] + return self.timeout_configs.get(process_type, self.timeout_configs["default"]) def check_for_timeouts(self): """Check all sessions for timeouts and return timed out sessions.""" @@ -248,7 +270,7 @@ class PromptDetector: for session_name, state in self.session_states.items(): timeout = self.get_session_timeout(session_name) - if current_time - state['last_activity'] > timeout: + if current_time - state["last_activity"] > timeout: timed_out.append(session_name) return timed_out @@ -260,16 +282,18 @@ class PromptDetector: state = self.session_states[session_name] return { - 'current_state': state['current_state'], - 'process_type': state['process_type'], - 'last_activity': state['last_activity'], - 'transitions': state['transitions'][-5:], # Last 5 transitions - 'is_waiting': self.is_waiting_for_input(session_name) + "current_state": state["current_state"], + "process_type": state["process_type"], + "last_activity": state["last_activity"], + "transitions": state["transitions"][-5:], # Last 5 transitions + "is_waiting": self.is_waiting_for_input(session_name), } + # Global detector instance _detector = None + def get_global_detector(): """Get the global prompt detector instance.""" global _detector diff --git a/pr/tools/python_exec.py b/pr/tools/python_exec.py index bc4bd91..fad609e 100644 --- a/pr/tools/python_exec.py +++ b/pr/tools/python_exec.py @@ -1,6 +1,7 @@ +import contextlib import traceback from io import StringIO -import contextlib + def python_exec(code, python_globals): try: diff --git a/pr/tools/web.py b/pr/tools/web.py index f7a214a..d3d23a7 100644 --- a/pr/tools/web.py +++ b/pr/tools/web.py @@ -1,7 +1,8 @@ -import urllib.request -import urllib.parse -import urllib.error import json +import urllib.error +import urllib.parse +import urllib.request + def http_fetch(url, headers=None): try: @@ -11,26 +12,28 @@ def http_fetch(url, headers=None): req.add_header(key, value) with urllib.request.urlopen(req) as response: - content = response.read().decode('utf-8') + content = response.read().decode("utf-8") return {"status": "success", "content": content[:10000]} except Exception as e: return {"status": "error", "error": str(e)} + def _perform_search(base_url, query, params=None): try: full_url = f"https://static.molodetz.nl/search.cgi?query={query}" with urllib.request.urlopen(full_url) as response: - content = response.read().decode('utf-8') + content = response.read().decode("utf-8") return {"status": "success", "content": json.loads(content)} except Exception as e: return {"status": "error", "error": str(e)} + def web_search(query): base_url = "https://search.molodetz.nl/search" return _perform_search(base_url, query) + def web_search_news(query): base_url = "https://search.molodetz.nl/search" return _perform_search(base_url, query) - diff --git a/pr/ui/__init__.py b/pr/ui/__init__.py index 054f8c2..f2520da 100644 --- a/pr/ui/__init__.py +++ b/pr/ui/__init__.py @@ -1,5 +1,11 @@ from pr.ui.colors import Colors -from pr.ui.rendering import highlight_code, render_markdown from pr.ui.display import display_tool_call, print_autonomous_header +from pr.ui.rendering import highlight_code, render_markdown -__all__ = ['Colors', 'highlight_code', 'render_markdown', 'display_tool_call', 'print_autonomous_header'] +__all__ = [ + "Colors", + "highlight_code", + "render_markdown", + "display_tool_call", + "print_autonomous_header", +] diff --git a/pr/ui/colors.py b/pr/ui/colors.py index 9527570..2c882c6 100644 --- a/pr/ui/colors.py +++ b/pr/ui/colors.py @@ -1,14 +1,14 @@ class Colors: - RESET = '\033[0m' - BOLD = '\033[1m' - RED = '\033[91m' - GREEN = '\033[92m' - YELLOW = '\033[93m' - BLUE = '\033[94m' - MAGENTA = '\033[95m' - CYAN = '\033[96m' - GRAY = '\033[90m' - WHITE = '\033[97m' - BG_BLUE = '\033[44m' - BG_GREEN = '\033[42m' - BG_RED = '\033[41m' + RESET = "\033[0m" + BOLD = "\033[1m" + RED = "\033[91m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + MAGENTA = "\033[95m" + CYAN = "\033[96m" + GRAY = "\033[90m" + WHITE = "\033[97m" + BG_BLUE = "\033[44m" + BG_GREEN = "\033[42m" + BG_RED = "\033[41m" diff --git a/pr/ui/diff_display.py b/pr/ui/diff_display.py index 6562b8f..2e83072 100644 --- a/pr/ui/diff_display.py +++ b/pr/ui/diff_display.py @@ -1,5 +1,6 @@ import difflib -from typing import List, Tuple, Dict, Optional +from typing import Dict, List, Optional, Tuple + from .colors import Colors @@ -19,8 +20,13 @@ class DiffStats: class DiffLine: - def __init__(self, line_type: str, content: str, old_line_num: Optional[int] = None, - new_line_num: Optional[int] = None): + def __init__( + self, + line_type: str, + content: str, + old_line_num: Optional[int] = None, + new_line_num: Optional[int] = None, + ): self.line_type = line_type self.content = content self.old_line_num = old_line_num @@ -28,27 +34,27 @@ class DiffLine: def format(self, show_line_nums: bool = True) -> str: color = { - 'add': Colors.GREEN, - 'delete': Colors.RED, - 'context': Colors.GRAY, - 'header': Colors.CYAN, - 'stats': Colors.BLUE + "add": Colors.GREEN, + "delete": Colors.RED, + "context": Colors.GRAY, + "header": Colors.CYAN, + "stats": Colors.BLUE, }.get(self.line_type, Colors.RESET) prefix = { - 'add': '+ ', - 'delete': '- ', - 'context': ' ', - 'header': '', - 'stats': '' - }.get(self.line_type, ' ') + "add": "+ ", + "delete": "- ", + "context": " ", + "header": "", + "stats": "", + }.get(self.line_type, " ") - if show_line_nums and self.line_type in ('add', 'delete', 'context'): - old_num = str(self.old_line_num) if self.old_line_num else ' ' - new_num = str(self.new_line_num) if self.new_line_num else ' ' + if show_line_nums and self.line_type in ("add", "delete", "context"): + old_num = str(self.old_line_num) if self.old_line_num else " " + new_num = str(self.new_line_num) if self.new_line_num else " " line_num_str = f"{Colors.YELLOW}{old_num:>4} {new_num:>4}{Colors.RESET} " else: - line_num_str = '' + line_num_str = "" return f"{line_num_str}{color}{prefix}{self.content}{Colors.RESET}" @@ -57,8 +63,9 @@ class DiffDisplay: def __init__(self, context_lines: int = 3): self.context_lines = context_lines - def create_diff(self, old_content: str, new_content: str, - filename: str = "file") -> Tuple[List[DiffLine], DiffStats]: + def create_diff( + self, old_content: str, new_content: str, filename: str = "file" + ) -> Tuple[List[DiffLine], DiffStats]: old_lines = old_content.splitlines(keepends=True) new_lines = new_content.splitlines(keepends=True) @@ -67,31 +74,38 @@ class DiffDisplay: stats.files_changed = 1 diff = difflib.unified_diff( - old_lines, new_lines, + old_lines, + new_lines, fromfile=f"a/{filename}", tofile=f"b/{filename}", - n=self.context_lines + n=self.context_lines, ) old_line_num = 0 new_line_num = 0 for line in diff: - if line.startswith('---') or line.startswith('+++'): - diff_lines.append(DiffLine('header', line.rstrip())) - elif line.startswith('@@'): - diff_lines.append(DiffLine('header', line.rstrip())) + if line.startswith("---") or line.startswith("+++"): + diff_lines.append(DiffLine("header", line.rstrip())) + elif line.startswith("@@"): + diff_lines.append(DiffLine("header", line.rstrip())) old_line_num, new_line_num = self._parse_hunk_header(line) - elif line.startswith('+'): + elif line.startswith("+"): stats.insertions += 1 - diff_lines.append(DiffLine('add', line[1:].rstrip(), None, new_line_num)) + diff_lines.append( + DiffLine("add", line[1:].rstrip(), None, new_line_num) + ) new_line_num += 1 - elif line.startswith('-'): + elif line.startswith("-"): stats.deletions += 1 - diff_lines.append(DiffLine('delete', line[1:].rstrip(), old_line_num, None)) + diff_lines.append( + DiffLine("delete", line[1:].rstrip(), old_line_num, None) + ) old_line_num += 1 - elif line.startswith(' '): - diff_lines.append(DiffLine('context', line[1:].rstrip(), old_line_num, new_line_num)) + elif line.startswith(" "): + diff_lines.append( + DiffLine("context", line[1:].rstrip(), old_line_num, new_line_num) + ) old_line_num += 1 new_line_num += 1 @@ -101,15 +115,20 @@ class DiffDisplay: def _parse_hunk_header(self, header: str) -> Tuple[int, int]: try: - parts = header.split('@@')[1].strip().split() - old_start = int(parts[0].split(',')[0].replace('-', '')) - new_start = int(parts[1].split(',')[0].replace('+', '')) + parts = header.split("@@")[1].strip().split() + old_start = int(parts[0].split(",")[0].replace("-", "")) + new_start = int(parts[1].split(",")[0].replace("+", "")) return old_start, new_start except (IndexError, ValueError): return 0, 0 - def render_diff(self, diff_lines: List[DiffLine], stats: DiffStats, - show_line_nums: bool = True, show_stats: bool = True) -> str: + def render_diff( + self, + diff_lines: List[DiffLine], + stats: DiffStats, + show_line_nums: bool = True, + show_stats: bool = True, + ) -> str: output = [] if show_stats: @@ -124,10 +143,15 @@ class DiffDisplay: if show_stats: output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") - return '\n'.join(output) + return "\n".join(output) - def display_file_diff(self, old_content: str, new_content: str, - filename: str = "file", show_line_nums: bool = True) -> str: + def display_file_diff( + self, + old_content: str, + new_content: str, + filename: str = "file", + show_line_nums: bool = True, + ) -> str: diff_lines, stats = self.create_diff(old_content, new_content, filename) if not diff_lines: @@ -135,8 +159,13 @@ class DiffDisplay: return self.render_diff(diff_lines, stats, show_line_nums) - def display_side_by_side(self, old_content: str, new_content: str, - filename: str = "file", width: int = 80) -> str: + def display_side_by_side( + self, + old_content: str, + new_content: str, + filename: str = "file", + width: int = 80, + ) -> str: old_lines = old_content.splitlines() new_lines = new_content.splitlines() @@ -144,40 +173,57 @@ class DiffDisplay: output = [] output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}") - output.append(f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}") + output.append( + f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}" + ) output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n") half_width = (width - 5) // 2 for tag, i1, i2, j1, j2 in matcher.get_opcodes(): - if tag == 'equal': - for i, (old_line, new_line) in enumerate(zip(old_lines[i1:i2], new_lines[j1:j2])): + if tag == "equal": + for i, (old_line, new_line) in enumerate( + zip(old_lines[i1:i2], new_lines[j1:j2]) + ): old_display = old_line[:half_width].ljust(half_width) new_display = new_line[:half_width].ljust(half_width) - output.append(f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}") - elif tag == 'replace': + output.append( + f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}" + ) + elif tag == "replace": max_lines = max(i2 - i1, j2 - j1) for i in range(max_lines): old_line = old_lines[i1 + i] if i1 + i < i2 else "" new_line = new_lines[j1 + i] if j1 + i < j2 else "" old_display = old_line[:half_width].ljust(half_width) new_display = new_line[:half_width].ljust(half_width) - output.append(f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}") - elif tag == 'delete': + output.append( + f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}" + ) + elif tag == "delete": for old_line in old_lines[i1:i2]: old_display = old_line[:half_width].ljust(half_width) - output.append(f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}") - elif tag == 'insert': + output.append( + f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}" + ) + elif tag == "insert": for new_line in new_lines[j1:j2]: new_display = new_line[:half_width].ljust(half_width) - output.append(f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}") + output.append( + f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}" + ) output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n") - return '\n'.join(output) + return "\n".join(output) -def display_diff(old_content: str, new_content: str, filename: str = "file", - format_type: str = "unified", context_lines: int = 3) -> str: +def display_diff( + old_content: str, + new_content: str, + filename: str = "file", + format_type: str = "unified", + context_lines: int = 3, +) -> str: displayer = DiffDisplay(context_lines) if format_type == "side-by-side": @@ -191,9 +237,9 @@ def get_diff_stats(old_content: str, new_content: str) -> Dict[str, int]: _, stats = displayer.create_diff(old_content, new_content) return { - 'insertions': stats.insertions, - 'deletions': stats.deletions, - 'modifications': stats.modifications, - 'total_changes': stats.total_changes, - 'files_changed': stats.files_changed + "insertions": stats.insertions, + "deletions": stats.deletions, + "modifications": stats.modifications, + "total_changes": stats.total_changes, + "files_changed": stats.files_changed, } diff --git a/pr/ui/display.py b/pr/ui/display.py index c0a6b2e..a3aa80c 100644 --- a/pr/ui/display.py +++ b/pr/ui/display.py @@ -1,8 +1,6 @@ -import json -import time -from typing import Dict, Any from pr.ui.colors import Colors + def display_tool_call(tool_name, arguments, status="running", result=None): if status == "running": return @@ -15,8 +13,11 @@ def display_tool_call(tool_name, arguments, status="running", result=None): print(f"{Colors.GRAY}{line}{Colors.RESET}") + def print_autonomous_header(task): print(f"{Colors.BOLD}Task:{Colors.RESET} {task}") - print(f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}") + print( + f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}" + ) print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n") print(f"{Colors.BOLD}{'═' * 80}{Colors.RESET}\n") diff --git a/pr/ui/edit_feedback.py b/pr/ui/edit_feedback.py index e9b6895..d1de847 100644 --- a/pr/ui/edit_feedback.py +++ b/pr/ui/edit_feedback.py @@ -1,12 +1,20 @@ -from typing import List, Dict, Optional from datetime import datetime +from typing import Dict, List, Optional + from .colors import Colors from .progress import ProgressBar class EditOperation: - def __init__(self, op_type: str, filepath: str, start_pos: int = 0, - end_pos: int = 0, content: str = "", old_content: str = ""): + def __init__( + self, + op_type: str, + filepath: str, + start_pos: int = 0, + end_pos: int = 0, + content: str = "", + old_content: str = "", + ): self.op_type = op_type self.filepath = filepath self.start_pos = start_pos @@ -18,40 +26,46 @@ class EditOperation: def format_operation(self) -> str: op_colors = { - 'INSERT': Colors.GREEN, - 'REPLACE': Colors.YELLOW, - 'DELETE': Colors.RED, - 'WRITE': Colors.BLUE + "INSERT": Colors.GREEN, + "REPLACE": Colors.YELLOW, + "DELETE": Colors.RED, + "WRITE": Colors.BLUE, } color = op_colors.get(self.op_type, Colors.RESET) status_icon = { - 'pending': '○', - 'in_progress': '◐', - 'completed': '●', - 'failed': '✗' - }.get(self.status, '○') + "pending": "○", + "in_progress": "◐", + "completed": "●", + "failed": "✗", + }.get(self.status, "○") return f"{color}{status_icon} [{self.op_type}]{Colors.RESET} {self.filepath}" def format_details(self, show_content: bool = True) -> str: output = [self.format_operation()] - if self.op_type in ('INSERT', 'REPLACE'): - output.append(f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}") + if self.op_type in ("INSERT", "REPLACE"): + output.append( + f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}" + ) if show_content: if self.old_content: - lines = self.old_content.split('\n') - preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '') + lines = self.old_content.split("\n") + preview = lines[0][:60] + ( + "..." if len(lines[0]) > 60 or len(lines) > 1 else "" + ) output.append(f" {Colors.RED}- {preview}{Colors.RESET}") if self.content: - lines = self.content.split('\n') - preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '') + lines = self.content.split("\n") + preview = lines[0][:60] + ( + "..." if len(lines[0]) > 60 or len(lines) > 1 else "" + ) output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}") - return '\n'.join(output) + return "\n".join(output) class EditTracker: @@ -76,11 +90,13 @@ class EditTracker: def get_stats(self) -> Dict[str, int]: stats = { - 'total': len(self.operations), - 'completed': sum(1 for op in self.operations if op.status == 'completed'), - 'pending': sum(1 for op in self.operations if op.status == 'pending'), - 'in_progress': sum(1 for op in self.operations if op.status == 'in_progress'), - 'failed': sum(1 for op in self.operations if op.status == 'failed') + "total": len(self.operations), + "completed": sum(1 for op in self.operations if op.status == "completed"), + "pending": sum(1 for op in self.operations if op.status == "pending"), + "in_progress": sum( + 1 for op in self.operations if op.status == "in_progress" + ), + "failed": sum(1 for op in self.operations if op.status == "failed"), } return stats @@ -88,7 +104,7 @@ class EditTracker: if not self.operations: return 0.0 stats = self.get_stats() - return (stats['completed'] / stats['total']) * 100 + return (stats["completed"] / stats["total"]) * 100 def display_progress(self) -> str: if not self.operations: @@ -96,26 +112,30 @@ class EditTracker: output = [] output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}") - output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}") + output.append( + f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}" + ) output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") stats = self.get_stats() - completion = self.get_completion_percentage() + self.get_completion_percentage() - progress_bar = ProgressBar(total=stats['total'], width=40) - progress_bar.current = stats['completed'] + progress_bar = ProgressBar(total=stats["total"], width=40) + progress_bar.current = stats["completed"] bar_display = progress_bar._get_bar_display() output.append(f"Progress: {bar_display}") - output.append(f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, " - f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n") + output.append( + f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, " + f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n" + ) output.append(f"{Colors.BOLD}Recent Operations:{Colors.RESET}") for i, op in enumerate(self.operations[-5:], 1): output.append(f"{i}. {op.format_operation()}") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") - return '\n'.join(output) + return "\n".join(output) def display_timeline(self, show_content: bool = False) -> str: if not self.operations: @@ -134,18 +154,20 @@ class EditTracker: stats = self.get_stats() output.append(f"{Colors.BOLD}Summary:{Colors.RESET}") - output.append(f"{Colors.BLUE}Total operations: {stats['total']}, " - f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}") + output.append( + f"{Colors.BLUE}Total operations: {stats['total']}, " + f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}" + ) output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") - return '\n'.join(output) + return "\n".join(output) def display_summary(self) -> str: if not self.operations: return f"{Colors.GRAY}No edits to summarize{Colors.RESET}" stats = self.get_stats() - files_modified = len(set(op.filepath for op in self.operations)) + files_modified = len({op.filepath for op in self.operations}) output = [] output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}") @@ -156,7 +178,7 @@ class EditTracker: output.append(f"{Colors.GREEN}Total Operations: {stats['total']}{Colors.RESET}") output.append(f"{Colors.GREEN}Successful: {stats['completed']}{Colors.RESET}") - if stats['failed'] > 0: + if stats["failed"] > 0: output.append(f"{Colors.RED}Failed: {stats['failed']}{Colors.RESET}") output.append(f"\n{Colors.BOLD}Operations by Type:{Colors.RESET}") @@ -168,7 +190,7 @@ class EditTracker: output.append(f" {op_type}: {count}") output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}\n") - return '\n'.join(output) + return "\n".join(output) def clear(self): self.operations.clear() diff --git a/pr/ui/output.py b/pr/ui/output.py index 8ba894a..7106398 100644 --- a/pr/ui/output.py +++ b/pr/ui/output.py @@ -1,31 +1,31 @@ import json import sys -from typing import Any, Dict, List from datetime import datetime +from typing import Any class OutputFormatter: - def __init__(self, format_type: str = 'text', quiet: bool = False): + def __init__(self, format_type: str = "text", quiet: bool = False): self.format_type = format_type self.quiet = quiet - def output(self, data: Any, message_type: str = 'response'): - if self.quiet and message_type not in ['error', 'result']: + def output(self, data: Any, message_type: str = "response"): + if self.quiet and message_type not in ["error", "result"]: return - if self.format_type == 'json': + if self.format_type == "json": self._output_json(data, message_type) - elif self.format_type == 'structured': + elif self.format_type == "structured": self._output_structured(data, message_type) else: self._output_text(data, message_type) def _output_json(self, data: Any, message_type: str): output = { - 'type': message_type, - 'timestamp': datetime.now().isoformat(), - 'data': data + "type": message_type, + "timestamp": datetime.now().isoformat(), + "data": data, } print(json.dumps(output, indent=2)) @@ -46,24 +46,24 @@ class OutputFormatter: print(data) def error(self, message: str): - if self.format_type == 'json': - self._output_json({'error': message}, 'error') + if self.format_type == "json": + self._output_json({"error": message}, "error") else: print(f"Error: {message}", file=sys.stderr) def success(self, message: str): if not self.quiet: - if self.format_type == 'json': - self._output_json({'success': message}, 'success') + if self.format_type == "json": + self._output_json({"success": message}, "success") else: print(message) def info(self, message: str): if not self.quiet: - if self.format_type == 'json': - self._output_json({'info': message}, 'info') + if self.format_type == "json": + self._output_json({"info": message}, "info") else: print(message) def result(self, data: Any): - self.output(data, 'result') + self.output(data, "result") diff --git a/pr/ui/progress.py b/pr/ui/progress.py index 99902ab..5b3df23 100644 --- a/pr/ui/progress.py +++ b/pr/ui/progress.py @@ -1,6 +1,6 @@ import sys -import time import threading +import time class ProgressIndicator: @@ -30,15 +30,15 @@ class ProgressIndicator: self.running = False if self.thread: self.thread.join(timeout=1.0) - sys.stdout.write('\r' + ' ' * (len(self.message) + 10) + '\r') + sys.stdout.write("\r" + " " * (len(self.message) + 10) + "\r") sys.stdout.flush() def _animate(self): - spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'] + spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] idx = 0 while self.running: - sys.stdout.write(f'\r{spinner[idx]} {self.message}...') + sys.stdout.write(f"\r{spinner[idx]} {self.message}...") sys.stdout.flush() idx = (idx + 1) % len(spinner) time.sleep(0.1) @@ -62,14 +62,20 @@ class ProgressBar: else: percent = int((self.current / self.total) * 100) - filled = int((self.current / self.total) * self.width) if self.total > 0 else self.width - bar = '█' * filled + '░' * (self.width - filled) + filled = ( + int((self.current / self.total) * self.width) + if self.total > 0 + else self.width + ) + bar = "█" * filled + "░" * (self.width - filled) - sys.stdout.write(f'\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})') + sys.stdout.write( + f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})" + ) sys.stdout.flush() if self.current >= self.total: - sys.stdout.write('\n') + sys.stdout.write("\n") def finish(self): self.current = self.total diff --git a/pr/ui/rendering.py b/pr/ui/rendering.py index 61e8ef3..72f0613 100644 --- a/pr/ui/rendering.py +++ b/pr/ui/rendering.py @@ -1,90 +1,103 @@ import re -from pr.ui.colors import Colors + from pr.config import LANGUAGE_KEYWORDS +from pr.ui.colors import Colors + def highlight_code(code, language=None, syntax_highlighting=True): if not syntax_highlighting: return code if not language: - if 'def ' in code or 'import ' in code: - language = 'python' - elif 'function ' in code or 'const ' in code: - language = 'javascript' - elif 'public ' in code or 'class ' in code: - language = 'java' + if "def " in code or "import " in code: + language = "python" + elif "function " in code or "const " in code: + language = "javascript" + elif "public " in code or "class " in code: + language = "java" if language and language in LANGUAGE_KEYWORDS: keywords = LANGUAGE_KEYWORDS[language] for keyword in keywords: - pattern = r'\b' + re.escape(keyword) + r'\b' + pattern = r"\b" + re.escape(keyword) + r"\b" code = re.sub(pattern, f"{Colors.BLUE}{keyword}{Colors.RESET}", code) code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code) code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code) - code = re.sub(r'#(.*)$', f'{Colors.GRAY}#\\1{Colors.RESET}', code, flags=re.MULTILINE) - code = re.sub(r'//(.*)$', f'{Colors.GRAY}//\\1{Colors.RESET}', code, flags=re.MULTILINE) + code = re.sub( + r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE + ) + code = re.sub( + r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE + ) return code + def render_markdown(text, syntax_highlighting=True): if not syntax_highlighting: return text code_blocks = [] + def extract_code_block(match): - lang = match.group(1) or '' + lang = match.group(1) or "" code = match.group(2) - highlighted_code = highlight_code(code.strip('\n'), lang, syntax_highlighting) + highlighted_code = highlight_code(code.strip("\n"), lang, syntax_highlighting) placeholder = f"%%CODEBLOCK{len(code_blocks)}%%" - full_block = f'{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}' + full_block = f"{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}" code_blocks.append(full_block) return placeholder - text = re.sub(r'```(\w*)\n(.*?)\n?```', extract_code_block, text, flags=re.DOTALL) + text = re.sub(r"```(\w*)\n(.*?)\n?```", extract_code_block, text, flags=re.DOTALL) inline_codes = [] + def extract_inline_code(match): code = match.group(1) placeholder = f"%%INLINECODE{len(inline_codes)}%%" - inline_codes.append(f'{Colors.YELLOW}{code}{Colors.RESET}') + inline_codes.append(f"{Colors.YELLOW}{code}{Colors.RESET}") return placeholder - text = re.sub(r'`([^`]+)`', extract_inline_code, text) + text = re.sub(r"`([^`]+)`", extract_inline_code, text) - lines = text.split('\n') + lines = text.split("\n") processed_lines = [] for line in lines: - if line.startswith('### '): - line = f'{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}' - elif line.startswith('## '): - line = f'{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}' - elif line.startswith('# '): - line = f'{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}' - elif line.startswith('> '): - line = f'{Colors.CYAN}> {line[2:]}{Colors.RESET}' - elif re.match(r'^\s*[\*\-\+]\s', line): - match = re.match(r'^(\s*)([\*\-\+])(\s+.*)', line) + if line.startswith("### "): + line = f"{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}" + elif line.startswith("## "): + line = f"{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}" + elif line.startswith("# "): + line = f"{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}" + elif line.startswith("> "): + line = f"{Colors.CYAN}> {line[2:]}{Colors.RESET}" + elif re.match(r"^\s*[\*\-\+]\s", line): + match = re.match(r"^(\s*)([\*\-\+])(\s+.*)", line) if match: line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" - elif re.match(r'^\s*\d+\.\s', line): - match = re.match(r'^(\s*)(\d+\.)(\s+.*)', line) + elif re.match(r"^\s*\d+\.\s", line): + match = re.match(r"^(\s*)(\d+\.)(\s+.*)", line) if match: line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" processed_lines.append(line) - text = '\n'.join(processed_lines) + text = "\n".join(processed_lines) - text = re.sub(r'\[(.*?)\]\((.*?)\)', f'{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}', text) - text = re.sub(r'~~(.*?)~~', f'{Colors.GRAY}\\1{Colors.RESET}', text) - text = re.sub(r'\*\*(.*?)\*\*', f'{Colors.BOLD}\\1{Colors.RESET}', text) - text = re.sub(r'__(.*?)__', f'{Colors.BOLD}\\1{Colors.RESET}', text) - text = re.sub(r'\*(.*?)\*', f'{Colors.CYAN}\\1{Colors.RESET}', text) - text = re.sub(r'_(.*?)_', f'{Colors.CYAN}\\1{Colors.RESET}', text) + text = re.sub( + r"\[(.*?)\]\((.*?)\)", + f"{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}", + text, + ) + text = re.sub(r"~~(.*?)~~", f"{Colors.GRAY}\\1{Colors.RESET}", text) + text = re.sub(r"\*\*(.*?)\*\*", f"{Colors.BOLD}\\1{Colors.RESET}", text) + text = re.sub(r"__(.*?)__", f"{Colors.BOLD}\\1{Colors.RESET}", text) + text = re.sub(r"\*(.*?)\*", f"{Colors.CYAN}\\1{Colors.RESET}", text) + text = re.sub(r"_(.*?)_", f"{Colors.CYAN}\\1{Colors.RESET}", text) for i, code in enumerate(inline_codes): - text = text.replace(f'%%INLINECODE{i}%%', code) + text = text.replace(f"%%INLINECODE{i}%%", code) for i, block in enumerate(code_blocks): - text = text.replace(f'%%CODEBLOCK{i}%%', block) + text = text.replace(f"%%CODEBLOCK{i}%%", block) return text diff --git a/pr/workflows/__init__.py b/pr/workflows/__init__.py index df6e64f..cd8fa84 100644 --- a/pr/workflows/__init__.py +++ b/pr/workflows/__init__.py @@ -1,5 +1,11 @@ -from .workflow_definition import Workflow, WorkflowStep, ExecutionMode +from .workflow_definition import ExecutionMode, Workflow, WorkflowStep from .workflow_engine import WorkflowEngine from .workflow_storage import WorkflowStorage -__all__ = ['Workflow', 'WorkflowStep', 'ExecutionMode', 'WorkflowEngine', 'WorkflowStorage'] +__all__ = [ + "Workflow", + "WorkflowStep", + "ExecutionMode", + "WorkflowEngine", + "WorkflowStorage", +] diff --git a/pr/workflows/workflow_definition.py b/pr/workflows/workflow_definition.py index 0770bab..9b34200 100644 --- a/pr/workflows/workflow_definition.py +++ b/pr/workflows/workflow_definition.py @@ -1,12 +1,14 @@ -from enum import Enum -from typing import List, Dict, Any, Optional from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional + class ExecutionMode(Enum): SEQUENTIAL = "sequential" PARALLEL = "parallel" CONDITIONAL = "conditional" + @dataclass class WorkflowStep: tool_name: str @@ -20,29 +22,30 @@ class WorkflowStep: def to_dict(self) -> Dict[str, Any]: return { - 'tool_name': self.tool_name, - 'arguments': self.arguments, - 'step_id': self.step_id, - 'condition': self.condition, - 'on_success': self.on_success, - 'on_failure': self.on_failure, - 'retry_count': self.retry_count, - 'timeout_seconds': self.timeout_seconds + "tool_name": self.tool_name, + "arguments": self.arguments, + "step_id": self.step_id, + "condition": self.condition, + "on_success": self.on_success, + "on_failure": self.on_failure, + "retry_count": self.retry_count, + "timeout_seconds": self.timeout_seconds, } @staticmethod - def from_dict(data: Dict[str, Any]) -> 'WorkflowStep': + def from_dict(data: Dict[str, Any]) -> "WorkflowStep": return WorkflowStep( - tool_name=data['tool_name'], - arguments=data['arguments'], - step_id=data['step_id'], - condition=data.get('condition'), - on_success=data.get('on_success'), - on_failure=data.get('on_failure'), - retry_count=data.get('retry_count', 0), - timeout_seconds=data.get('timeout_seconds', 300) + tool_name=data["tool_name"], + arguments=data["arguments"], + step_id=data["step_id"], + condition=data.get("condition"), + on_success=data.get("on_success"), + on_failure=data.get("on_failure"), + retry_count=data.get("retry_count", 0), + timeout_seconds=data.get("timeout_seconds", 300), ) + @dataclass class Workflow: name: str @@ -54,23 +57,23 @@ class Workflow: def to_dict(self) -> Dict[str, Any]: return { - 'name': self.name, - 'description': self.description, - 'steps': [step.to_dict() for step in self.steps], - 'execution_mode': self.execution_mode.value, - 'variables': self.variables, - 'tags': self.tags + "name": self.name, + "description": self.description, + "steps": [step.to_dict() for step in self.steps], + "execution_mode": self.execution_mode.value, + "variables": self.variables, + "tags": self.tags, } @staticmethod - def from_dict(data: Dict[str, Any]) -> 'Workflow': + def from_dict(data: Dict[str, Any]) -> "Workflow": return Workflow( - name=data['name'], - description=data['description'], - steps=[WorkflowStep.from_dict(step) for step in data['steps']], - execution_mode=ExecutionMode(data.get('execution_mode', 'sequential')), - variables=data.get('variables', {}), - tags=data.get('tags', []) + name=data["name"], + description=data["description"], + steps=[WorkflowStep.from_dict(step) for step in data["steps"]], + execution_mode=ExecutionMode(data.get("execution_mode", "sequential")), + variables=data.get("variables", {}), + tags=data.get("tags", []), ) def add_step(self, step: WorkflowStep): diff --git a/pr/workflows/workflow_engine.py b/pr/workflows/workflow_engine.py index 891ef44..66251b7 100644 --- a/pr/workflows/workflow_engine.py +++ b/pr/workflows/workflow_engine.py @@ -1,8 +1,10 @@ -import time import re -from typing import Dict, Any, List, Callable, Optional +import time from concurrent.futures import ThreadPoolExecutor, as_completed -from .workflow_definition import Workflow, WorkflowStep, ExecutionMode +from typing import Any, Callable, Dict, List, Optional + +from .workflow_definition import ExecutionMode, Workflow, WorkflowStep + class WorkflowExecutionContext: def __init__(self): @@ -23,57 +25,66 @@ class WorkflowExecutionContext: return self.step_results.get(step_id) def log_event(self, event_type: str, step_id: str, details: Dict[str, Any]): - self.execution_log.append({ - 'timestamp': time.time(), - 'event_type': event_type, - 'step_id': step_id, - 'details': details - }) + self.execution_log.append( + { + "timestamp": time.time(), + "event_type": event_type, + "step_id": step_id, + "details": details, + } + ) + class WorkflowEngine: def __init__(self, tool_executor: Callable, max_workers: int = 5): self.tool_executor = tool_executor self.max_workers = max_workers - def _evaluate_condition(self, condition: str, context: WorkflowExecutionContext) -> bool: + def _evaluate_condition( + self, condition: str, context: WorkflowExecutionContext + ) -> bool: if not condition: return True try: safe_locals = { - 'variables': context.variables, - 'results': context.step_results + "variables": context.variables, + "results": context.step_results, } return eval(condition, {"__builtins__": {}}, safe_locals) except Exception: return False - def _substitute_variables(self, arguments: Dict[str, Any], context: WorkflowExecutionContext) -> Dict[str, Any]: + def _substitute_variables( + self, arguments: Dict[str, Any], context: WorkflowExecutionContext + ) -> Dict[str, Any]: substituted = {} for key, value in arguments.items(): if isinstance(value, str): - pattern = r'\$\{([^}]+)\}' + pattern = r"\$\{([^}]+)\}" matches = re.findall(pattern, value) for match in matches: - if match.startswith('step.'): - step_id = match.split('.', 1)[1] + if match.startswith("step."): + step_id = match.split(".", 1)[1] replacement = context.get_step_result(step_id) if replacement is not None: - value = value.replace(f'${{{match}}}', str(replacement)) - elif match.startswith('var.'): - var_name = match.split('.', 1)[1] + value = value.replace(f"${{{match}}}", str(replacement)) + elif match.startswith("var."): + var_name = match.split(".", 1)[1] replacement = context.get_variable(var_name) if replacement is not None: - value = value.replace(f'${{{match}}}', str(replacement)) + value = value.replace(f"${{{match}}}", str(replacement)) substituted[key] = value else: substituted[key] = value return substituted - def _execute_step(self, step: WorkflowStep, context: WorkflowExecutionContext) -> Dict[str, Any]: + def _execute_step( + self, step: WorkflowStep, context: WorkflowExecutionContext + ) -> Dict[str, Any]: if not self._evaluate_condition(step.condition, context): - context.log_event('skipped', step.step_id, {'reason': 'condition_not_met'}) - return {'status': 'skipped', 'step_id': step.step_id} + context.log_event("skipped", step.step_id, {"reason": "condition_not_met"}) + return {"status": "skipped", "step_id": step.step_id} arguments = self._substitute_variables(step.arguments, context) @@ -83,26 +94,34 @@ class WorkflowEngine: while retry_attempts <= step.retry_count: try: - context.log_event('executing', step.step_id, { - 'tool': step.tool_name, - 'arguments': arguments, - 'attempt': retry_attempts + 1 - }) + context.log_event( + "executing", + step.step_id, + { + "tool": step.tool_name, + "arguments": arguments, + "attempt": retry_attempts + 1, + }, + ) result = self.tool_executor(step.tool_name, arguments) execution_time = time.time() - start_time context.set_step_result(step.step_id, result) - context.log_event('completed', step.step_id, { - 'execution_time': execution_time, - 'result_size': len(str(result)) if result else 0 - }) + context.log_event( + "completed", + step.step_id, + { + "execution_time": execution_time, + "result_size": len(str(result)) if result else 0, + }, + ) return { - 'status': 'success', - 'step_id': step.step_id, - 'result': result, - 'execution_time': execution_time + "status": "success", + "step_id": step.step_id, + "result": result, + "execution_time": execution_time, } except Exception as e: @@ -111,25 +130,26 @@ class WorkflowEngine: if retry_attempts <= step.retry_count: time.sleep(1 * retry_attempts) - context.log_event('failed', step.step_id, {'error': last_error}) + context.log_event("failed", step.step_id, {"error": last_error}) return { - 'status': 'failed', - 'step_id': step.step_id, - 'error': last_error, - 'execution_time': time.time() - start_time + "status": "failed", + "step_id": step.step_id, + "error": last_error, + "execution_time": time.time() - start_time, } - def _get_next_steps(self, completed_step: WorkflowStep, result: Dict[str, Any], - workflow: Workflow) -> List[WorkflowStep]: + def _get_next_steps( + self, completed_step: WorkflowStep, result: Dict[str, Any], workflow: Workflow + ) -> List[WorkflowStep]: next_steps = [] - if result['status'] == 'success' and completed_step.on_success: + if result["status"] == "success" and completed_step.on_success: for step_id in completed_step.on_success: step = workflow.get_step(step_id) if step: next_steps.append(step) - elif result['status'] == 'failed' and completed_step.on_failure: + elif result["status"] == "failed" and completed_step.on_failure: for step_id in completed_step.on_failure: step = workflow.get_step(step_id) if step: @@ -142,7 +162,9 @@ class WorkflowEngine: return next_steps - def execute_workflow(self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None) -> WorkflowExecutionContext: + def execute_workflow( + self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None + ) -> WorkflowExecutionContext: context = WorkflowExecutionContext() if initial_variables: @@ -151,7 +173,7 @@ class WorkflowEngine: if workflow.variables: context.variables.update(workflow.variables) - context.log_event('workflow_started', 'workflow', {'name': workflow.name}) + context.log_event("workflow_started", "workflow", {"name": workflow.name}) if workflow.execution_mode == ExecutionMode.PARALLEL: with ThreadPoolExecutor(max_workers=self.max_workers) as executor: @@ -164,9 +186,11 @@ class WorkflowEngine: step = futures[future] try: result = future.result() - context.log_event('step_completed', step.step_id, result) + context.log_event("step_completed", step.step_id, result) except Exception as e: - context.log_event('step_failed', step.step_id, {'error': str(e)}) + context.log_event( + "step_failed", step.step_id, {"error": str(e)} + ) else: pending_steps = workflow.get_initial_steps() @@ -184,9 +208,13 @@ class WorkflowEngine: next_steps = self._get_next_steps(step, result, workflow) pending_steps.extend(next_steps) - context.log_event('workflow_completed', 'workflow', { - 'total_steps': len(context.step_results), - 'executed_steps': list(context.step_results.keys()) - }) + context.log_event( + "workflow_completed", + "workflow", + { + "total_steps": len(context.step_results), + "executed_steps": list(context.step_results.keys()), + }, + ) return context diff --git a/pr/workflows/workflow_storage.py b/pr/workflows/workflow_storage.py index da050d5..258ce4f 100644 --- a/pr/workflows/workflow_storage.py +++ b/pr/workflows/workflow_storage.py @@ -2,8 +2,10 @@ import json import sqlite3 import time from typing import List, Optional + from .workflow_definition import Workflow + class WorkflowStorage: def __init__(self, db_path: str): self.db_path = db_path @@ -13,7 +15,8 @@ class WorkflowStorage: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS workflows ( workflow_id TEXT PRIMARY KEY, name TEXT NOT NULL, @@ -25,9 +28,11 @@ class WorkflowStorage: last_execution_at INTEGER, tags TEXT ) - ''') + """ + ) - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS workflow_executions ( execution_id TEXT PRIMARY KEY, workflow_id TEXT NOT NULL, @@ -39,17 +44,24 @@ class WorkflowStorage: step_results TEXT, FOREIGN KEY (workflow_id) REFERENCES workflows(workflow_id) ) - ''') + """ + ) - cursor.execute(''' + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_workflow_name ON workflows(name) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_execution_workflow ON workflow_executions(workflow_id) - ''') - cursor.execute(''' + """ + ) + cursor.execute( + """ CREATE INDEX IF NOT EXISTS idx_execution_started ON workflow_executions(started_at) - ''') + """ + ) conn.commit() conn.close() @@ -66,12 +78,22 @@ class WorkflowStorage: current_time = int(time.time()) tags_json = json.dumps(workflow.tags) - cursor.execute(''' + cursor.execute( + """ INSERT OR REPLACE INTO workflows (workflow_id, name, description, workflow_data, created_at, updated_at, tags) VALUES (?, ?, ?, ?, ?, ?, ?) - ''', (workflow_id, workflow.name, workflow.description, workflow_data, - current_time, current_time, tags_json)) + """, + ( + workflow_id, + workflow.name, + workflow.description, + workflow_data, + current_time, + current_time, + tags_json, + ), + ) conn.commit() conn.close() @@ -82,7 +104,9 @@ class WorkflowStorage: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('SELECT workflow_data FROM workflows WHERE workflow_id = ?', (workflow_id,)) + cursor.execute( + "SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,) + ) row = cursor.fetchone() conn.close() @@ -95,7 +119,7 @@ class WorkflowStorage: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('SELECT workflow_data FROM workflows WHERE name = ?', (name,)) + cursor.execute("SELECT workflow_data FROM workflows WHERE name = ?", (name,)) row = cursor.fetchone() conn.close() @@ -109,29 +133,36 @@ class WorkflowStorage: cursor = conn.cursor() if tag: - cursor.execute(''' + cursor.execute( + """ SELECT workflow_id, name, description, execution_count, last_execution_at, tags FROM workflows WHERE tags LIKE ? ORDER BY name - ''', (f'%"{tag}"%',)) + """, + (f'%"{tag}"%',), + ) else: - cursor.execute(''' + cursor.execute( + """ SELECT workflow_id, name, description, execution_count, last_execution_at, tags FROM workflows ORDER BY name - ''') + """ + ) workflows = [] for row in cursor.fetchall(): - workflows.append({ - 'workflow_id': row[0], - 'name': row[1], - 'description': row[2], - 'execution_count': row[3], - 'last_execution_at': row[4], - 'tags': json.loads(row[5]) if row[5] else [] - }) + workflows.append( + { + "workflow_id": row[0], + "name": row[1], + "description": row[2], + "execution_count": row[3], + "last_execution_at": row[4], + "tags": json.loads(row[5]) if row[5] else [], + } + ) conn.close() return workflows @@ -140,18 +171,21 @@ class WorkflowStorage: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute('DELETE FROM workflows WHERE workflow_id = ?', (workflow_id,)) + cursor.execute("DELETE FROM workflows WHERE workflow_id = ?", (workflow_id,)) deleted = cursor.rowcount > 0 - cursor.execute('DELETE FROM workflow_executions WHERE workflow_id = ?', (workflow_id,)) + cursor.execute( + "DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,) + ) conn.commit() conn.close() return deleted - def save_execution(self, workflow_id: str, execution_context: 'WorkflowExecutionContext') -> str: - import hashlib + def save_execution( + self, workflow_id: str, execution_context: "WorkflowExecutionContext" + ) -> str: import uuid execution_id = str(uuid.uuid4())[:16] @@ -159,30 +193,40 @@ class WorkflowStorage: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - started_at = int(execution_context.execution_log[0]['timestamp']) if execution_context.execution_log else int(time.time()) + started_at = ( + int(execution_context.execution_log[0]["timestamp"]) + if execution_context.execution_log + else int(time.time()) + ) completed_at = int(time.time()) - cursor.execute(''' + cursor.execute( + """ INSERT INTO workflow_executions (execution_id, workflow_id, started_at, completed_at, status, execution_log, variables, step_results) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - execution_id, - workflow_id, - started_at, - completed_at, - 'completed', - json.dumps(execution_context.execution_log), - json.dumps(execution_context.variables), - json.dumps(execution_context.step_results) - )) + """, + ( + execution_id, + workflow_id, + started_at, + completed_at, + "completed", + json.dumps(execution_context.execution_log), + json.dumps(execution_context.variables), + json.dumps(execution_context.step_results), + ), + ) - cursor.execute(''' + cursor.execute( + """ UPDATE workflows SET execution_count = execution_count + 1, last_execution_at = ? WHERE workflow_id = ? - ''', (completed_at, workflow_id)) + """, + (completed_at, workflow_id), + ) conn.commit() conn.close() @@ -193,22 +237,27 @@ class WorkflowStorage: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute(''' + cursor.execute( + """ SELECT execution_id, started_at, completed_at, status FROM workflow_executions WHERE workflow_id = ? ORDER BY started_at DESC LIMIT ? - ''', (workflow_id, limit)) + """, + (workflow_id, limit), + ) executions = [] for row in cursor.fetchall(): - executions.append({ - 'execution_id': row[0], - 'started_at': row[1], - 'completed_at': row[2], - 'status': row[3] - }) + executions.append( + { + "execution_id": row[0], + "started_at": row[1], + "completed_at": row[2], + "status": row[3], + } + ) conn.close() return executions diff --git a/rp.py b/rp.py index 8700db8..68021e6 100755 --- a/rp.py +++ b/rp.py @@ -2,8 +2,7 @@ # Trigger build -import sys from pr.__main__ import main -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/conftest.py b/tests/conftest.py index 5011af9..6227450 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,9 @@ -import pytest import os import tempfile from unittest.mock import MagicMock +import pytest + @pytest.fixture def temp_dir(): @@ -13,19 +14,8 @@ def temp_dir(): @pytest.fixture def mock_api_response(): return { - 'choices': [ - { - 'message': { - 'role': 'assistant', - 'content': 'Test response' - } - } - ], - 'usage': { - 'prompt_tokens': 10, - 'completion_tokens': 5, - 'total_tokens': 15 - } + "choices": [{"message": {"role": "assistant", "content": "Test response"}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, } @@ -47,7 +37,7 @@ def mock_args(): @pytest.fixture def sample_context_file(temp_dir): - context_path = os.path.join(temp_dir, '.rcontext.txt') - with open(context_path, 'w') as f: - f.write('Sample context content\n') + context_path = os.path.join(temp_dir, ".rcontext.txt") + with open(context_path, "w") as f: + f.write("Sample context content\n") return context_path diff --git a/tests/test_advanced_context.py b/tests/test_advanced_context.py index 218fec1..7a15bea 100644 --- a/tests/test_advanced_context.py +++ b/tests/test_advanced_context.py @@ -1,39 +1,53 @@ -import pytest from pr.core.advanced_context import AdvancedContextManager + def test_adaptive_context_window_simple(): mgr = AdvancedContextManager() - messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}] - window = mgr.adaptive_context_window(messages, 'simple') + messages = [ + {"content": "short"}, + {"content": "this is a longer message with more words"}, + ] + window = mgr.adaptive_context_window(messages, "simple") assert isinstance(window, int) assert window >= 10 + def test_adaptive_context_window_medium(): mgr = AdvancedContextManager() - messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}] - window = mgr.adaptive_context_window(messages, 'medium') + messages = [ + {"content": "short"}, + {"content": "this is a longer message with more words"}, + ] + window = mgr.adaptive_context_window(messages, "medium") assert isinstance(window, int) assert window >= 20 + def test_adaptive_context_window_complex(): mgr = AdvancedContextManager() - messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}] - window = mgr.adaptive_context_window(messages, 'complex') + messages = [ + {"content": "short"}, + {"content": "this is a longer message with more words"}, + ] + window = mgr.adaptive_context_window(messages, "complex") assert isinstance(window, int) assert window >= 35 + def test_analyze_message_complexity(): mgr = AdvancedContextManager() - messages = [{'content': 'hello world'}, {'content': 'hello again'}] + messages = [{"content": "hello world"}, {"content": "hello again"}] score = mgr._analyze_message_complexity(messages) assert 0 <= score <= 1 + def test_analyze_message_complexity_empty(): mgr = AdvancedContextManager() messages = [] score = mgr._analyze_message_complexity(messages) assert score == 0 + def test_extract_key_sentences(): mgr = AdvancedContextManager() text = "This is the first sentence. This is the second sentence. This is a longer third sentence with more words." @@ -41,41 +55,47 @@ def test_extract_key_sentences(): assert len(sentences) <= 2 assert all(isinstance(s, str) for s in sentences) + def test_extract_key_sentences_empty(): mgr = AdvancedContextManager() text = "" sentences = mgr.extract_key_sentences(text, 5) assert sentences == [] + def test_advanced_summarize_messages(): mgr = AdvancedContextManager() - messages = [{'content': 'Hello'}, {'content': 'How are you?'}] + messages = [{"content": "Hello"}, {"content": "How are you?"}] summary = mgr.advanced_summarize_messages(messages) assert isinstance(summary, str) + def test_advanced_summarize_messages_empty(): mgr = AdvancedContextManager() messages = [] summary = mgr.advanced_summarize_messages(messages) assert summary == "No content to summarize." + def test_score_message_relevance(): mgr = AdvancedContextManager() - message = {'content': 'hello world'} - context = 'world hello' + message = {"content": "hello world"} + context = "world hello" score = mgr.score_message_relevance(message, context) assert 0 <= score <= 1 + def test_score_message_relevance_no_overlap(): mgr = AdvancedContextManager() - message = {'content': 'hello'} - context = 'world' + message = {"content": "hello"} + context = "world" score = mgr.score_message_relevance(message, context) assert score == 0 + def test_score_message_relevance_empty(): mgr = AdvancedContextManager() - message = {'content': ''} - context = '' + message = {"content": ""} + context = "" score = mgr.score_message_relevance(message, context) - assert score == 0 \ No newline at end of file + assert score == 0 diff --git a/tests/test_agents.py b/tests/test_agents.py index 59d47cb..3fa63db 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -1,127 +1,213 @@ -import pytest -import time +from pr.agents.agent_communication import ( + AgentCommunicationBus, + AgentMessage, + MessageType, +) +from pr.agents.agent_manager import AgentInstance, AgentManager from pr.agents.agent_roles import AgentRole, get_agent_role, list_agent_roles -from pr.agents.agent_manager import AgentManager, AgentInstance -from pr.agents.agent_communication import AgentCommunicationBus, AgentMessage, MessageType + def test_get_agent_role(): - role = get_agent_role('coding') + role = get_agent_role("coding") assert isinstance(role, AgentRole) - assert role.name == 'coding' + assert role.name == "coding" + def test_list_agent_roles(): roles = list_agent_roles() assert isinstance(roles, dict) assert len(roles) > 0 - assert 'coding' in roles + assert "coding" in roles + def test_agent_role(): - role = AgentRole(name='test', description='test', system_prompt='test', allowed_tools=set(), specialization_areas=[]) - assert role.name == 'test' + role = AgentRole( + name="test", + description="test", + system_prompt="test", + allowed_tools=set(), + specialization_areas=[], + ) + assert role.name == "test" + def test_agent_instance(): - role = get_agent_role('coding') - instance = AgentInstance(agent_id='test', role=role) - assert instance.agent_id == 'test' + role = get_agent_role("coding") + instance = AgentInstance(agent_id="test", role=role) + assert instance.agent_id == "test" assert instance.role == role + def test_agent_manager_init(): - mgr = AgentManager(':memory:', None) + mgr = AgentManager(":memory:", None) assert mgr is not None + def test_agent_manager_create_agent(): - mgr = AgentManager(':memory:', None) - agent = mgr.create_agent('coding', 'test_agent') + mgr = AgentManager(":memory:", None) + agent = mgr.create_agent("coding", "test_agent") assert agent is not None + def test_agent_manager_get_agent(): - mgr = AgentManager(':memory:', None) - mgr.create_agent('coding', 'test_agent') - agent = mgr.get_agent('test_agent') + mgr = AgentManager(":memory:", None) + mgr.create_agent("coding", "test_agent") + agent = mgr.get_agent("test_agent") assert isinstance(agent, AgentInstance) + def test_agent_manager_remove_agent(): - mgr = AgentManager(':memory:', None) - mgr.create_agent('coding', 'test_agent') - mgr.remove_agent('test_agent') - agent = mgr.get_agent('test_agent') + mgr = AgentManager(":memory:", None) + mgr.create_agent("coding", "test_agent") + mgr.remove_agent("test_agent") + agent = mgr.get_agent("test_agent") assert agent is None + def test_agent_manager_send_agent_message(): - mgr = AgentManager(':memory:', None) - mgr.create_agent('coding', 'a') - mgr.create_agent('coding', 'b') - mgr.send_agent_message('a', 'b', 'test') + mgr = AgentManager(":memory:", None) + mgr.create_agent("coding", "a") + mgr.create_agent("coding", "b") + mgr.send_agent_message("a", "b", "test") assert True + def test_agent_manager_get_agent_messages(): - mgr = AgentManager(':memory:', None) - mgr.create_agent('coding', 'test') - messages = mgr.get_agent_messages('test') + mgr = AgentManager(":memory:", None) + mgr.create_agent("coding", "test") + messages = mgr.get_agent_messages("test") assert isinstance(messages, list) + def test_agent_manager_get_session_summary(): - mgr = AgentManager(':memory:', None) + mgr = AgentManager(":memory:", None) summary = mgr.get_session_summary() assert isinstance(summary, str) + def test_agent_manager_collaborate_agents(): - mgr = AgentManager(':memory:', None) - result = mgr.collaborate_agents('orchestrator', 'task', ['coding', 'research']) + mgr = AgentManager(":memory:", None) + result = mgr.collaborate_agents("orchestrator", "task", ["coding", "research"]) assert result is not None + def test_agent_manager_execute_agent_task(): - mgr = AgentManager(':memory:', None) - mgr.create_agent('coding', 'test') - result = mgr.execute_agent_task('test', 'task') + mgr = AgentManager(":memory:", None) + mgr.create_agent("coding", "test") + result = mgr.execute_agent_task("test", "task") assert result is not None + def test_agent_manager_clear_session(): - mgr = AgentManager(':memory:', None) + mgr = AgentManager(":memory:", None) mgr.clear_session() assert True + def test_agent_message(): - msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') - assert msg.from_agent == 'a' + msg = AgentMessage( + from_agent="a", + to_agent="b", + message_type=MessageType.REQUEST, + content="test", + metadata={}, + timestamp=1.0, + message_id="id", + ) + assert msg.from_agent == "a" + def test_agent_message_to_dict(): - msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') + msg = AgentMessage( + from_agent="a", + to_agent="b", + message_type=MessageType.REQUEST, + content="test", + metadata={}, + timestamp=1.0, + message_id="id", + ) d = msg.to_dict() assert isinstance(d, dict) + def test_agent_message_from_dict(): - d = {'from_agent': 'a', 'to_agent': 'b', 'message_type': 'request', 'content': 'test', 'metadata': {}, 'timestamp': 1.0, 'message_id': 'id'} + d = { + "from_agent": "a", + "to_agent": "b", + "message_type": "request", + "content": "test", + "metadata": {}, + "timestamp": 1.0, + "message_id": "id", + } msg = AgentMessage.from_dict(d) assert isinstance(msg, AgentMessage) + def test_agent_communication_bus_init(): - bus = AgentCommunicationBus(':memory:') + bus = AgentCommunicationBus(":memory:") assert bus is not None + def test_agent_communication_bus_send_message(): - bus = AgentCommunicationBus(':memory:') - msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') + bus = AgentCommunicationBus(":memory:") + msg = AgentMessage( + from_agent="a", + to_agent="b", + message_type=MessageType.REQUEST, + content="test", + metadata={}, + timestamp=1.0, + message_id="id", + ) bus.send_message(msg) assert True + def test_agent_communication_bus_receive_messages(): - bus = AgentCommunicationBus(':memory:') - msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') + bus = AgentCommunicationBus(":memory:") + msg = AgentMessage( + from_agent="a", + to_agent="b", + message_type=MessageType.REQUEST, + content="test", + metadata={}, + timestamp=1.0, + message_id="id", + ) bus.send_message(msg) - messages = bus.receive_messages('b') + messages = bus.receive_messages("b") assert len(messages) == 1 + def test_agent_communication_bus_get_conversation_history(): - bus = AgentCommunicationBus(':memory:') - msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') + bus = AgentCommunicationBus(":memory:") + msg = AgentMessage( + from_agent="a", + to_agent="b", + message_type=MessageType.REQUEST, + content="test", + metadata={}, + timestamp=1.0, + message_id="id", + ) bus.send_message(msg) - history = bus.get_conversation_history('a', 'b') + history = bus.get_conversation_history("a", "b") assert len(history) == 1 + def test_agent_communication_bus_mark_as_read(): - bus = AgentCommunicationBus(':memory:') - msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') + bus = AgentCommunicationBus(":memory:") + msg = AgentMessage( + from_agent="a", + to_agent="b", + message_type=MessageType.REQUEST, + content="test", + metadata={}, + timestamp=1.0, + message_id="id", + ) bus.send_message(msg) bus.mark_as_read(msg.message_id) - assert True \ No newline at end of file + assert True diff --git a/tests/test_api.py b/tests/test_api.py index 3f01176..f67c491 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,63 +1,65 @@ import unittest -from unittest.mock import patch, MagicMock -import json import urllib.error +from unittest.mock import MagicMock, patch + from pr.core.api import call_api, list_models class TestApi(unittest.TestCase): - @patch('pr.core.api.urllib.request.urlopen') - @patch('pr.core.api.auto_slim_messages') + @patch("pr.core.api.urllib.request.urlopen") + @patch("pr.core.api.auto_slim_messages") def test_call_api_success(self, mock_slim, mock_urlopen): - mock_slim.return_value = [{'role': 'user', 'content': 'test'}] + mock_slim.return_value = [{"role": "user", "content": "test"}] mock_response = MagicMock() mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}' mock_urlopen.return_value.__enter__.return_value = mock_response - result = call_api([], 'model', 'http://url', 'key', True, [{'name': 'tool'}]) + result = call_api([], "model", "http://url", "key", True, [{"name": "tool"}]) - self.assertIn('choices', result) + self.assertIn("choices", result) mock_urlopen.assert_called_once() - @patch('urllib.request.urlopen') - @patch('pr.core.api.auto_slim_messages') + @patch("urllib.request.urlopen") + @patch("pr.core.api.auto_slim_messages") def test_call_api_http_error(self, mock_slim, mock_urlopen): - mock_slim.return_value = [{'role': 'user', 'content': 'test'}] - mock_urlopen.side_effect = urllib.error.HTTPError('http://url', 500, 'error', None, MagicMock()) + mock_slim.return_value = [{"role": "user", "content": "test"}] + mock_urlopen.side_effect = urllib.error.HTTPError( + "http://url", 500, "error", None, MagicMock() + ) - result = call_api([], 'model', 'http://url', 'key', False, []) + result = call_api([], "model", "http://url", "key", False, []) - self.assertIn('error', result) + self.assertIn("error", result) - @patch('urllib.request.urlopen') - @patch('pr.core.api.auto_slim_messages') + @patch("urllib.request.urlopen") + @patch("pr.core.api.auto_slim_messages") def test_call_api_general_error(self, mock_slim, mock_urlopen): - mock_slim.return_value = [{'role': 'user', 'content': 'test'}] - mock_urlopen.side_effect = Exception('test error') + mock_slim.return_value = [{"role": "user", "content": "test"}] + mock_urlopen.side_effect = Exception("test error") - result = call_api([], 'model', 'http://url', 'key', False, []) + result = call_api([], "model", "http://url", "key", False, []) - self.assertIn('error', result) + self.assertIn("error", result) - @patch('urllib.request.urlopen') + @patch("urllib.request.urlopen") def test_list_models_success(self, mock_urlopen): mock_response = MagicMock() mock_response.read.return_value = b'{"data": [{"id": "model1"}]}' mock_urlopen.return_value.__enter__.return_value = mock_response - result = list_models('http://url', 'key') + result = list_models("http://url", "key") - self.assertEqual(result, [{'id': 'model1'}]) + self.assertEqual(result, [{"id": "model1"}]) - @patch('urllib.request.urlopen') + @patch("urllib.request.urlopen") def test_list_models_error(self, mock_urlopen): - mock_urlopen.side_effect = Exception('error') + mock_urlopen.side_effect = Exception("error") - result = list_models('http://url', 'key') + result = list_models("http://url", "key") - self.assertIn('error', result) + self.assertIn("error", result) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_assistant.py b/tests/test_assistant.py index b6f6037..b3395e9 100644 --- a/tests/test_assistant.py +++ b/tests/test_assistant.py @@ -1,7 +1,6 @@ import unittest -from unittest.mock import patch, MagicMock -import tempfile -import os +from unittest.mock import MagicMock, patch + from pr.core.assistant import Assistant, process_message @@ -12,83 +11,106 @@ class TestAssistant(unittest.TestCase): self.args.verbose = False self.args.debug = False self.args.no_syntax = False - self.args.model = 'test-model' - self.args.api_url = 'test-url' - self.args.model_list_url = 'test-list-url' + self.args.model = "test-model" + self.args.api_url = "test-url" + self.args.model_list_url = "test-list-url" - @patch('sqlite3.connect') - @patch('os.environ.get') - @patch('pr.core.context.init_system_message') - @patch('pr.core.enhanced_assistant.EnhancedAssistant') + @patch("sqlite3.connect") + @patch("os.environ.get") + @patch("pr.core.context.init_system_message") + @patch("pr.core.enhanced_assistant.EnhancedAssistant") def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite): - mock_env.side_effect = lambda key, default: {'OPENROUTER_API_KEY': 'key', 'AI_MODEL': 'model', 'API_URL': 'url', 'MODEL_LIST_URL': 'list', 'USE_TOOLS': '1', 'STRICT_MODE': '0'}.get(key, default) + mock_env.side_effect = lambda key, default: { + "OPENROUTER_API_KEY": "key", + "AI_MODEL": "model", + "API_URL": "url", + "MODEL_LIST_URL": "list", + "USE_TOOLS": "1", + "STRICT_MODE": "0", + }.get(key, default) mock_conn = MagicMock() mock_sqlite.return_value = mock_conn - mock_init_sys.return_value = {'role': 'system', 'content': 'sys'} + mock_init_sys.return_value = {"role": "system", "content": "sys"} assistant = Assistant(self.args) - self.assertEqual(assistant.api_key, 'key') - self.assertEqual(assistant.model, 'test-model') + self.assertEqual(assistant.api_key, "key") + self.assertEqual(assistant.model, "test-model") mock_sqlite.assert_called_once() - @patch('pr.core.assistant.call_api') - @patch('pr.core.assistant.render_markdown') + @patch("pr.core.assistant.call_api") + @patch("pr.core.assistant.render_markdown") def test_process_response_no_tools(self, mock_render, mock_call): assistant = MagicMock() assistant.messages = MagicMock() assistant.verbose = False assistant.syntax_highlighting = True - mock_render.return_value = 'rendered' + mock_render.return_value = "rendered" - response = {'choices': [{'message': {'content': 'content'}}]} + response = {"choices": [{"message": {"content": "content"}}]} result = Assistant.process_response(assistant, response) - self.assertEqual(result, 'rendered') - assistant.messages.append.assert_called_with({'content': 'content'}) + self.assertEqual(result, "rendered") + assistant.messages.append.assert_called_with({"content": "content"}) - @patch('pr.core.assistant.call_api') - @patch('pr.core.assistant.render_markdown') - @patch('pr.core.assistant.get_tools_definition') + @patch("pr.core.assistant.call_api") + @patch("pr.core.assistant.render_markdown") + @patch("pr.core.assistant.get_tools_definition") def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call): assistant = MagicMock() assistant.messages = MagicMock() assistant.verbose = False assistant.syntax_highlighting = True assistant.use_tools = True - assistant.model = 'model' - assistant.api_url = 'url' - assistant.api_key = 'key' + assistant.model = "model" + assistant.api_url = "url" + assistant.api_key = "key" mock_tools_def.return_value = [] - mock_call.return_value = {'choices': [{'message': {'content': 'follow'}}]} + mock_call.return_value = {"choices": [{"message": {"content": "follow"}}]} - response = {'choices': [{'message': {'tool_calls': [{'id': '1', 'function': {'name': 'test', 'arguments': '{}'}}]}}]} + response = { + "choices": [ + { + "message": { + "tool_calls": [ + {"id": "1", "function": {"name": "test", "arguments": "{}"}} + ] + } + } + ] + } - with patch.object(assistant, 'execute_tool_calls', return_value=[{'role': 'tool', 'content': 'result'}]): - result = Assistant.process_response(assistant, response) + with patch.object( + assistant, + "execute_tool_calls", + return_value=[{"role": "tool", "content": "result"}], + ): + Assistant.process_response(assistant, response) mock_call.assert_called() - @patch('pr.core.assistant.call_api') - @patch('pr.core.assistant.get_tools_definition') + @patch("pr.core.assistant.call_api") + @patch("pr.core.assistant.get_tools_definition") def test_process_message(self, mock_tools, mock_call): assistant = MagicMock() assistant.messages = MagicMock() assistant.verbose = False assistant.use_tools = True - assistant.model = 'model' - assistant.api_url = 'url' - assistant.api_key = 'key' + assistant.model = "model" + assistant.api_url = "url" + assistant.api_key = "key" mock_tools.return_value = [] - mock_call.return_value = {'choices': [{'message': {'content': 'response'}}]} + mock_call.return_value = {"choices": [{"message": {"content": "response"}}]} - with patch('pr.core.assistant.render_markdown', return_value='rendered'): - with patch('builtins.print'): - process_message(assistant, 'test message') + with patch("pr.core.assistant.render_markdown", return_value="rendered"): + with patch("builtins.print"): + process_message(assistant, "test message") - assistant.messages.append.assert_called_with({'role': 'user', 'content': 'test message'}) + assistant.messages.append.assert_called_with( + {"role": "user", "content": "test message"} + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_config.py b/tests/test_config.py index 407864b..1cadc20 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,31 +1,30 @@ -import pytest from pr import config class TestConfig: def test_default_model_exists(self): - assert hasattr(config, 'DEFAULT_MODEL') + assert hasattr(config, "DEFAULT_MODEL") assert isinstance(config.DEFAULT_MODEL, str) assert len(config.DEFAULT_MODEL) > 0 def test_api_url_exists(self): - assert hasattr(config, 'DEFAULT_API_URL') - assert config.DEFAULT_API_URL.startswith('http') + assert hasattr(config, "DEFAULT_API_URL") + assert config.DEFAULT_API_URL.startswith("http") def test_file_paths_exist(self): - assert hasattr(config, 'DB_PATH') - assert hasattr(config, 'LOG_FILE') - assert hasattr(config, 'HISTORY_FILE') + assert hasattr(config, "DB_PATH") + assert hasattr(config, "LOG_FILE") + assert hasattr(config, "HISTORY_FILE") def test_autonomous_config(self): - assert hasattr(config, 'MAX_AUTONOMOUS_ITERATIONS') + assert hasattr(config, "MAX_AUTONOMOUS_ITERATIONS") assert config.MAX_AUTONOMOUS_ITERATIONS > 0 - assert hasattr(config, 'CONTEXT_COMPRESSION_THRESHOLD') + assert hasattr(config, "CONTEXT_COMPRESSION_THRESHOLD") assert config.CONTEXT_COMPRESSION_THRESHOLD > 0 def test_language_keywords(self): - assert hasattr(config, 'LANGUAGE_KEYWORDS') - assert 'python' in config.LANGUAGE_KEYWORDS - assert isinstance(config.LANGUAGE_KEYWORDS['python'], list) + assert hasattr(config, "LANGUAGE_KEYWORDS") + assert "python" in config.LANGUAGE_KEYWORDS + assert isinstance(config.LANGUAGE_KEYWORDS["python"], list) diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 96e4171..4e62549 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -1,56 +1,71 @@ -import pytest -from unittest.mock import patch, mock_open -import os -from pr.core.config_loader import load_config, _load_config_file, _parse_value, create_default_config +from unittest.mock import mock_open, patch + +from pr.core.config_loader import ( + _load_config_file, + _parse_value, + create_default_config, + load_config, +) + def test_parse_value_string(): - assert _parse_value('hello') == 'hello' + assert _parse_value("hello") == "hello" + def test_parse_value_int(): - assert _parse_value('123') == 123 + assert _parse_value("123") == 123 + def test_parse_value_float(): - assert _parse_value('1.23') == 1.23 + assert _parse_value("1.23") == 1.23 + def test_parse_value_bool_true(): - assert _parse_value('true') == True + assert _parse_value("true") == True + def test_parse_value_bool_false(): - assert _parse_value('false') == False + assert _parse_value("false") == False + def test_parse_value_bool_upper(): - assert _parse_value('TRUE') == True + assert _parse_value("TRUE") == True -@patch('os.path.exists', return_value=False) + +@patch("os.path.exists", return_value=False) def test_load_config_file_not_exists(mock_exists): - config = _load_config_file('test.ini') + config = _load_config_file("test.ini") assert config == {} -@patch('os.path.exists', return_value=True) -@patch('configparser.ConfigParser') + +@patch("os.path.exists", return_value=True) +@patch("configparser.ConfigParser") def test_load_config_file_exists(mock_parser_class, mock_exists): mock_parser = mock_parser_class.return_value - mock_parser.sections.return_value = ['api'] - mock_parser.items.return_value = [('key', 'value')] - config = _load_config_file('test.ini') - assert 'api' in config - assert config['api']['key'] == 'value' + mock_parser.sections.return_value = ["api"] + mock_parser.items.return_value = [("key", "value")] + config = _load_config_file("test.ini") + assert "api" in config + assert config["api"]["key"] == "value" -@patch('pr.core.config_loader._load_config_file') + +@patch("pr.core.config_loader._load_config_file") def test_load_config(mock_load): - mock_load.side_effect = [{'api': {'key': 'global'}}, {'api': {'key': 'local'}}] + mock_load.side_effect = [{"api": {"key": "global"}}, {"api": {"key": "local"}}] config = load_config() - assert config['api']['key'] == 'local' + assert config["api"]["key"] == "local" -@patch('builtins.open', new_callable=mock_open) + +@patch("builtins.open", new_callable=mock_open) def test_create_default_config(mock_file): - result = create_default_config('test.ini') + result = create_default_config("test.ini") assert result == True - mock_file.assert_called_once_with('test.ini', 'w') + mock_file.assert_called_once_with("test.ini", "w") handle = mock_file() handle.write.assert_called_once() -@patch('builtins.open', side_effect=Exception('error')) + +@patch("builtins.open", side_effect=Exception("error")) def test_create_default_config_error(mock_file): - result = create_default_config('test.ini') - assert result == False \ No newline at end of file + result = create_default_config("test.ini") + assert result == False diff --git a/tests/test_context.py b/tests/test_context.py index bb39686..f7ec596 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,30 +1,29 @@ -import pytest -from pr.core.context import should_compress_context, compress_context from pr.config import RECENT_MESSAGES_TO_KEEP +from pr.core.context import compress_context, should_compress_context class TestContextManagement: def test_should_compress_context_below_threshold(self): - messages = [{'role': 'user', 'content': 'test'}] * 10 + messages = [{"role": "user", "content": "test"}] * 10 assert should_compress_context(messages) is False def test_should_compress_context_above_threshold(self): - messages = [{'role': 'user', 'content': 'test'}] * 35 + messages = [{"role": "user", "content": "test"}] * 35 assert should_compress_context(messages) is True def test_compress_context_preserves_system_message(self): messages = [ - {'role': 'system', 'content': 'System prompt'}, - {'role': 'user', 'content': 'Hello'}, - {'role': 'assistant', 'content': 'Hi'}, + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, ] * 40 # Ensure compression compressed = compress_context(messages) - assert compressed[0]['role'] == 'system' - assert 'System prompt' in compressed[0]['content'] + assert compressed[0]["role"] == "system" + assert "System prompt" in compressed[0]["content"] def test_compress_context_keeps_recent_messages(self): - messages = [{'role': 'user', 'content': f'msg{i}'} for i in range(40)] + messages = [{"role": "user", "content": f"msg{i}"} for i in range(40)] compressed = compress_context(messages) # Should keep recent messages recent = compressed[-RECENT_MESSAGES_TO_KEEP:] @@ -32,4 +31,4 @@ class TestContextManagement: # Check that the messages are the most recent ones for i, msg in enumerate(recent): expected_index = 40 - RECENT_MESSAGES_TO_KEEP + i - assert msg['content'] == f'msg{expected_index}' \ No newline at end of file + assert msg["content"] == f"msg{expected_index}" diff --git a/tests/test_enhanced_assistant.py b/tests/test_enhanced_assistant.py index d700f7e..29f1d5d 100644 --- a/tests/test_enhanced_assistant.py +++ b/tests/test_enhanced_assistant.py @@ -1,89 +1,97 @@ -import pytest from unittest.mock import MagicMock + from pr.core.enhanced_assistant import EnhancedAssistant + def test_enhanced_assistant_init(): mock_base = MagicMock() assistant = EnhancedAssistant(mock_base) assert assistant.base == mock_base assert assistant.current_conversation_id is not None + def test_enhanced_call_api_with_cache(): mock_base = MagicMock() - mock_base.model = 'test-model' - mock_base.api_url = 'http://test' - mock_base.api_key = 'key' + mock_base.model = "test-model" + mock_base.api_url = "http://test" + mock_base.api_key = "key" mock_base.use_tools = False mock_base.verbose = False - + assistant = EnhancedAssistant(mock_base) assistant.api_cache = MagicMock() - assistant.api_cache.get.return_value = {'cached': True} - - result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}]) - assert result == {'cached': True} + assistant.api_cache.get.return_value = {"cached": True} + + result = assistant.enhanced_call_api([{"role": "user", "content": "test"}]) + assert result == {"cached": True} assistant.api_cache.get.assert_called_once() + def test_enhanced_call_api_without_cache(): mock_base = MagicMock() - mock_base.model = 'test-model' - mock_base.api_url = 'http://test' - mock_base.api_key = 'key' + mock_base.model = "test-model" + mock_base.api_url = "http://test" + mock_base.api_key = "key" mock_base.use_tools = False mock_base.verbose = False - + assistant = EnhancedAssistant(mock_base) assistant.api_cache = None - + # It will try to call API and fail with network error, but that's expected - result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}]) - assert 'error' in result + result = assistant.enhanced_call_api([{"role": "user", "content": "test"}]) + assert "error" in result + def test_execute_workflow_not_found(): mock_base = MagicMock() assistant = EnhancedAssistant(mock_base) assistant.workflow_storage = MagicMock() assistant.workflow_storage.load_workflow_by_name.return_value = None - - result = assistant.execute_workflow('nonexistent') - assert 'error' in result + + result = assistant.execute_workflow("nonexistent") + assert "error" in result + def test_create_agent(): mock_base = MagicMock() assistant = EnhancedAssistant(mock_base) assistant.agent_manager = MagicMock() - assistant.agent_manager.create_agent.return_value = 'agent_id' - - result = assistant.create_agent('role') - assert result == 'agent_id' + assistant.agent_manager.create_agent.return_value = "agent_id" + + result = assistant.create_agent("role") + assert result == "agent_id" + def test_search_knowledge(): mock_base = MagicMock() assistant = EnhancedAssistant(mock_base) assistant.knowledge_store = MagicMock() - assistant.knowledge_store.search_entries.return_value = [{'result': True}] - - result = assistant.search_knowledge('query') - assert result == [{'result': True}] + assistant.knowledge_store.search_entries.return_value = [{"result": True}] + + result = assistant.search_knowledge("query") + assert result == [{"result": True}] + def test_get_cache_statistics(): mock_base = MagicMock() assistant = EnhancedAssistant(mock_base) assistant.api_cache = MagicMock() - assistant.api_cache.get_statistics.return_value = {'hits': 10} + assistant.api_cache.get_statistics.return_value = {"hits": 10} assistant.tool_cache = MagicMock() - assistant.tool_cache.get_statistics.return_value = {'misses': 5} - + assistant.tool_cache.get_statistics.return_value = {"misses": 5} + stats = assistant.get_cache_statistics() - assert 'api_cache' in stats - assert 'tool_cache' in stats + assert "api_cache" in stats + assert "tool_cache" in stats + def test_clear_caches(): mock_base = MagicMock() assistant = EnhancedAssistant(mock_base) assistant.api_cache = MagicMock() assistant.tool_cache = MagicMock() - + assistant.clear_caches() assistant.api_cache.clear_all.assert_called_once() assistant.tool_cache.clear_all.assert_called_once() diff --git a/tests/test_logging.py b/tests/test_logging.py index 9f19267..5336aea 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -1,24 +1,25 @@ -import pytest -import tempfile -import os -from pr.core.logging import setup_logging, get_logger +from pr.core.logging import get_logger, setup_logging + def test_setup_logging_basic(): logger = setup_logging(verbose=False) - assert logger.name == 'pr' + assert logger.name == "pr" assert logger.level == 20 # INFO + def test_setup_logging_verbose(): logger = setup_logging(verbose=True) - assert logger.name == 'pr' + assert logger.name == "pr" assert logger.level == 10 # DEBUG # Should have console handler assert len(logger.handlers) >= 2 + def test_get_logger_default(): logger = get_logger() - assert logger.name == 'pr' + assert logger.name == "pr" + def test_get_logger_named(): - logger = get_logger('test') - assert logger.name == 'pr.test' + logger = get_logger("test") + assert logger.name == "pr.test" diff --git a/tests/test_main.py b/tests/test_main.py index 2a483b5..7340157 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,118 +1,134 @@ -import pytest from unittest.mock import patch -import sys + +import pytest + from pr.__main__ import main + def test_main_version(capsys): - with patch('sys.argv', ['pr', '--version']): + with patch("sys.argv", ["pr", "--version"]): with pytest.raises(SystemExit): main() captured = capsys.readouterr() - assert 'PR Assistant' in captured.out + assert "PR Assistant" in captured.out + def test_main_create_config_success(capsys): - with patch('pr.core.config_loader.create_default_config', return_value=True): - with patch('sys.argv', ['pr', '--create-config']): + with patch("pr.core.config_loader.create_default_config", return_value=True): + with patch("sys.argv", ["pr", "--create-config"]): main() captured = capsys.readouterr() - assert 'Configuration file created' in captured.out + assert "Configuration file created" in captured.out + def test_main_create_config_fail(capsys): - with patch('pr.core.config_loader.create_default_config', return_value=False): - with patch('sys.argv', ['pr', '--create-config']): + with patch("pr.core.config_loader.create_default_config", return_value=False): + with patch("sys.argv", ["pr", "--create-config"]): main() captured = capsys.readouterr() - assert 'Error creating configuration file' in captured.err + assert "Error creating configuration file" in captured.err + def test_main_list_sessions_no_sessions(capsys): - with patch('pr.core.session.SessionManager') as mock_sm: + with patch("pr.core.session.SessionManager") as mock_sm: mock_instance = mock_sm.return_value mock_instance.list_sessions.return_value = [] - with patch('sys.argv', ['pr', '--list-sessions']): + with patch("sys.argv", ["pr", "--list-sessions"]): main() captured = capsys.readouterr() - assert 'No saved sessions found' in captured.out + assert "No saved sessions found" in captured.out + def test_main_list_sessions_with_sessions(capsys): - sessions = [{'name': 'test', 'created_at': '2023-01-01', 'message_count': 5}] - with patch('pr.core.session.SessionManager') as mock_sm: + sessions = [{"name": "test", "created_at": "2023-01-01", "message_count": 5}] + with patch("pr.core.session.SessionManager") as mock_sm: mock_instance = mock_sm.return_value mock_instance.list_sessions.return_value = sessions - with patch('sys.argv', ['pr', '--list-sessions']): + with patch("sys.argv", ["pr", "--list-sessions"]): main() captured = capsys.readouterr() - assert 'Found 1 saved sessions' in captured.out - assert 'test' in captured.out + assert "Found 1 saved sessions" in captured.out + assert "test" in captured.out + def test_main_delete_session_success(capsys): - with patch('pr.core.session.SessionManager') as mock_sm: + with patch("pr.core.session.SessionManager") as mock_sm: mock_instance = mock_sm.return_value mock_instance.delete_session.return_value = True - with patch('sys.argv', ['pr', '--delete-session', 'test']): + with patch("sys.argv", ["pr", "--delete-session", "test"]): main() captured = capsys.readouterr() assert "Session 'test' deleted" in captured.out + def test_main_delete_session_fail(capsys): - with patch('pr.core.session.SessionManager') as mock_sm: + with patch("pr.core.session.SessionManager") as mock_sm: mock_instance = mock_sm.return_value mock_instance.delete_session.return_value = False - with patch('sys.argv', ['pr', '--delete-session', 'test']): + with patch("sys.argv", ["pr", "--delete-session", "test"]): main() captured = capsys.readouterr() assert "Error deleting session 'test'" in captured.err + def test_main_export_session_json(capsys): - with patch('pr.core.session.SessionManager') as mock_sm: + with patch("pr.core.session.SessionManager") as mock_sm: mock_instance = mock_sm.return_value mock_instance.export_session.return_value = True - with patch('sys.argv', ['pr', '--export-session', 'test', 'output.json']): + with patch("sys.argv", ["pr", "--export-session", "test", "output.json"]): main() captured = capsys.readouterr() - assert 'Session exported to output.json' in captured.out + assert "Session exported to output.json" in captured.out + def test_main_export_session_md(capsys): - with patch('pr.core.session.SessionManager') as mock_sm: + with patch("pr.core.session.SessionManager") as mock_sm: mock_instance = mock_sm.return_value mock_instance.export_session.return_value = True - with patch('sys.argv', ['pr', '--export-session', 'test', 'output.md']): + with patch("sys.argv", ["pr", "--export-session", "test", "output.md"]): main() captured = capsys.readouterr() - assert 'Session exported to output.md' in captured.out + assert "Session exported to output.md" in captured.out + def test_main_usage(capsys): - usage = {'total_requests': 10, 'total_tokens': 1000, 'total_cost': 0.01} - with patch('pr.core.usage_tracker.UsageTracker.get_total_usage', return_value=usage): - with patch('sys.argv', ['pr', '--usage']): + usage = {"total_requests": 10, "total_tokens": 1000, "total_cost": 0.01} + with patch( + "pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage + ): + with patch("sys.argv", ["pr", "--usage"]): main() captured = capsys.readouterr() - assert 'Total Usage Statistics' in captured.out - assert 'Requests: 10' in captured.out + assert "Total Usage Statistics" in captured.out + assert "Requests: 10" in captured.out + def test_main_plugins_no_plugins(capsys): - with patch('pr.plugins.loader.PluginLoader') as mock_loader: + with patch("pr.plugins.loader.PluginLoader") as mock_loader: mock_instance = mock_loader.return_value mock_instance.load_plugins.return_value = None mock_instance.list_loaded_plugins.return_value = [] - with patch('sys.argv', ['pr', '--plugins']): + with patch("sys.argv", ["pr", "--plugins"]): main() captured = capsys.readouterr() - assert 'No plugins loaded' in captured.out + assert "No plugins loaded" in captured.out + def test_main_plugins_with_plugins(capsys): - with patch('pr.plugins.loader.PluginLoader') as mock_loader: + with patch("pr.plugins.loader.PluginLoader") as mock_loader: mock_instance = mock_loader.return_value mock_instance.load_plugins.return_value = None - mock_instance.list_loaded_plugins.return_value = ['plugin1', 'plugin2'] - with patch('sys.argv', ['pr', '--plugins']): + mock_instance.list_loaded_plugins.return_value = ["plugin1", "plugin2"] + with patch("sys.argv", ["pr", "--plugins"]): main() captured = capsys.readouterr() - assert 'Loaded 2 plugins' in captured.out + assert "Loaded 2 plugins" in captured.out + def test_main_run_assistant(): - with patch('pr.__main__.Assistant') as mock_assistant: + with patch("pr.__main__.Assistant") as mock_assistant: mock_instance = mock_assistant.return_value - with patch('sys.argv', ['pr', 'test message']): + with patch("sys.argv", ["pr", "test message"]): main() mock_assistant.assert_called_once() - mock_instance.run.assert_called_once() \ No newline at end of file + mock_instance.run.assert_called_once() diff --git a/tests/test_session.py b/tests/test_session.py index 833d692..a7b3838 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,116 +1,131 @@ -import pytest -import tempfile -import os import json +import os + +import pytest + from pr.core.session import SessionManager + @pytest.fixture def temp_sessions_dir(tmp_path, monkeypatch): from pr.core import session + original_dir = session.SESSIONS_DIR - monkeypatch.setattr(session, 'SESSIONS_DIR', str(tmp_path)) + monkeypatch.setattr(session, "SESSIONS_DIR", str(tmp_path)) # Clean any existing files import shutil + if os.path.exists(str(tmp_path)): shutil.rmtree(str(tmp_path)) os.makedirs(str(tmp_path), exist_ok=True) yield tmp_path - monkeypatch.setattr(session, 'SESSIONS_DIR', original_dir) + monkeypatch.setattr(session, "SESSIONS_DIR", original_dir) + def test_session_manager_init(temp_sessions_dir): - manager = SessionManager() + SessionManager() assert os.path.exists(temp_sessions_dir) + def test_save_and_load_session(temp_sessions_dir): manager = SessionManager() name = "test_session" messages = [{"role": "user", "content": "Hello"}] metadata = {"test": True} - + assert manager.save_session(name, messages, metadata) - + loaded = manager.load_session(name) assert loaded is not None - assert loaded['name'] == name - assert loaded['messages'] == messages - assert loaded['metadata'] == metadata + assert loaded["name"] == name + assert loaded["messages"] == messages + assert loaded["metadata"] == metadata + def test_load_nonexistent_session(temp_sessions_dir): manager = SessionManager() loaded = manager.load_session("nonexistent") assert loaded is None + def test_list_sessions(temp_sessions_dir): manager = SessionManager() # Save a session manager.save_session("session1", [{"role": "user", "content": "Hi"}]) manager.save_session("session2", [{"role": "user", "content": "Hello"}]) - + sessions = manager.list_sessions() assert len(sessions) == 2 - assert sessions[0]['name'] == "session2" # sorted by created_at desc + assert sessions[0]["name"] == "session2" # sorted by created_at desc + def test_delete_session(temp_sessions_dir): manager = SessionManager() name = "to_delete" manager.save_session(name, [{"role": "user", "content": "Test"}]) - + assert manager.delete_session(name) assert manager.load_session(name) is None + def test_delete_nonexistent_session(temp_sessions_dir): manager = SessionManager() assert not manager.delete_session("nonexistent") + def test_export_session_json(temp_sessions_dir, tmp_path): manager = SessionManager() name = "export_test" messages = [{"role": "user", "content": "Export me"}] manager.save_session(name, messages) - + output_path = tmp_path / "exported.json" - assert manager.export_session(name, str(output_path), 'json') + assert manager.export_session(name, str(output_path), "json") assert output_path.exists() - + with open(output_path) as f: data = json.load(f) - assert data['name'] == name + assert data["name"] == name + def test_export_session_markdown(temp_sessions_dir, tmp_path): manager = SessionManager() name = "export_md" messages = [{"role": "user", "content": "Markdown export"}] manager.save_session(name, messages) - + output_path = tmp_path / "exported.md" - assert manager.export_session(name, str(output_path), 'markdown') + assert manager.export_session(name, str(output_path), "markdown") assert output_path.exists() - + content = output_path.read_text() assert "# Session: export_md" in content + def test_export_session_txt(temp_sessions_dir, tmp_path): manager = SessionManager() name = "export_txt" messages = [{"role": "user", "content": "Text export"}] manager.save_session(name, messages) - + output_path = tmp_path / "exported.txt" - assert manager.export_session(name, str(output_path), 'txt') + assert manager.export_session(name, str(output_path), "txt") assert output_path.exists() - + content = output_path.read_text() assert "Session: export_txt" in content + def test_export_nonexistent_session(temp_sessions_dir, tmp_path): manager = SessionManager() output_path = tmp_path / "nonexistent.json" - assert not manager.export_session("nonexistent", str(output_path), 'json') + assert not manager.export_session("nonexistent", str(output_path), "json") + def test_export_unsupported_format(temp_sessions_dir, tmp_path): manager = SessionManager() name = "test" manager.save_session(name, [{"role": "user", "content": "Test"}]) - + output_path = tmp_path / "test.unsupported" - assert not manager.export_session(name, str(output_path), 'unsupported') + assert not manager.export_session(name, str(output_path), "unsupported") diff --git a/tests/test_tools.py b/tests/test_tools.py index 6a33c01..5854c4b 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,69 +1,68 @@ -import pytest import os -import tempfile -from pr.tools.filesystem import read_file, write_file, list_directory, search_replace -from pr.tools.patch import apply_patch, create_diff + from pr.tools.base import get_tools_definition +from pr.tools.filesystem import list_directory, read_file, search_replace, write_file +from pr.tools.patch import apply_patch, create_diff class TestFilesystemTools: def test_write_and_read_file(self, temp_dir): - filepath = os.path.join(temp_dir, 'test.txt') - content = 'Hello, World!' + filepath = os.path.join(temp_dir, "test.txt") + content = "Hello, World!" write_result = write_file(filepath, content) - assert write_result['status'] == 'success' + assert write_result["status"] == "success" read_result = read_file(filepath) - assert read_result['status'] == 'success' - assert content in read_result['content'] + assert read_result["status"] == "success" + assert content in read_result["content"] def test_read_nonexistent_file(self): - result = read_file('/nonexistent/path/file.txt') - assert result['status'] == 'error' + result = read_file("/nonexistent/path/file.txt") + assert result["status"] == "error" def test_list_directory(self, temp_dir): - test_file = os.path.join(temp_dir, 'testfile.txt') - with open(test_file, 'w') as f: - f.write('test') + test_file = os.path.join(temp_dir, "testfile.txt") + with open(test_file, "w") as f: + f.write("test") result = list_directory(temp_dir) - assert result['status'] == 'success' - assert any(item['name'] == 'testfile.txt' for item in result['items']) + assert result["status"] == "success" + assert any(item["name"] == "testfile.txt" for item in result["items"]) def test_search_replace(self, temp_dir): - filepath = os.path.join(temp_dir, 'test.txt') - content = 'Hello, World!' - with open(filepath, 'w') as f: + filepath = os.path.join(temp_dir, "test.txt") + content = "Hello, World!" + with open(filepath, "w") as f: f.write(content) - result = search_replace(filepath, 'World', 'Universe') - assert result['status'] == 'success' + result = search_replace(filepath, "World", "Universe") + assert result["status"] == "success" read_result = read_file(filepath) - assert 'Hello, Universe!' in read_result['content'] + assert "Hello, Universe!" in read_result["content"] class TestPatchTools: def test_create_diff(self, temp_dir): - file1 = os.path.join(temp_dir, 'file1.txt') - file2 = os.path.join(temp_dir, 'file2.txt') - with open(file1, 'w') as f: - f.write('line1\nline2\nline3\n') - with open(file2, 'w') as f: - f.write('line1\nline2 modified\nline3\n') + file1 = os.path.join(temp_dir, "file1.txt") + file2 = os.path.join(temp_dir, "file2.txt") + with open(file1, "w") as f: + f.write("line1\nline2\nline3\n") + with open(file2, "w") as f: + f.write("line1\nline2 modified\nline3\n") result = create_diff(file1, file2) - assert result['status'] == 'success' - assert 'line2' in result['diff'] - assert 'line2 modified' in result['diff'] + assert result["status"] == "success" + assert "line2" in result["diff"] + assert "line2 modified" in result["diff"] def test_apply_patch(self, temp_dir): - filepath = os.path.join(temp_dir, 'file.txt') - with open(filepath, 'w') as f: - f.write('line1\nline2\nline3\n') + filepath = os.path.join(temp_dir, "file.txt") + with open(filepath, "w") as f: + f.write("line1\nline2\nline3\n") # Create a simple patch patch_content = """--- a/file.txt @@ -75,10 +74,10 @@ class TestPatchTools: line3 """ result = apply_patch(filepath, patch_content) - assert result['status'] == 'success' + assert result["status"] == "success" read_result = read_file(filepath) - assert 'line2 modified' in read_result['content'] + assert "line2 modified" in read_result["content"] class TestToolDefinitions: @@ -92,27 +91,27 @@ class TestToolDefinitions: tools = get_tools_definition() for tool in tools: - assert 'type' in tool - assert tool['type'] == 'function' - assert 'function' in tool + assert "type" in tool + assert tool["type"] == "function" + assert "function" in tool - func = tool['function'] - assert 'name' in func - assert 'description' in func - assert 'parameters' in func + func = tool["function"] + assert "name" in func + assert "description" in func + assert "parameters" in func def test_filesystem_tools_present(self): tools = get_tools_definition() - tool_names = [t['function']['name'] for t in tools] + tool_names = [t["function"]["name"] for t in tools] - assert 'read_file' in tool_names - assert 'write_file' in tool_names - assert 'list_directory' in tool_names - assert 'search_replace' in tool_names + assert "read_file" in tool_names + assert "write_file" in tool_names + assert "list_directory" in tool_names + assert "search_replace" in tool_names def test_patch_tools_present(self): tools = get_tools_definition() - tool_names = [t['function']['name'] for t in tools] + tool_names = [t["function"]["name"] for t in tools] - assert 'apply_patch' in tool_names - assert 'create_diff' in tool_names \ No newline at end of file + assert "apply_patch" in tool_names + assert "create_diff" in tool_names diff --git a/tests/test_usage_tracker.py b/tests/test_usage_tracker.py index 5f6591d..3b7116c 100644 --- a/tests/test_usage_tracker.py +++ b/tests/test_usage_tracker.py @@ -1,86 +1,110 @@ -import pytest -import tempfile -import os import json +import os + +import pytest + from pr.core.usage_tracker import UsageTracker + @pytest.fixture def temp_usage_file(tmp_path, monkeypatch): from pr.core import usage_tracker + original_file = usage_tracker.USAGE_DB_FILE temp_file = str(tmp_path / "usage.json") - monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', temp_file) + monkeypatch.setattr(usage_tracker, "USAGE_DB_FILE", temp_file) yield temp_file if os.path.exists(temp_file): os.remove(temp_file) - monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', original_file) + monkeypatch.setattr(usage_tracker, "USAGE_DB_FILE", original_file) + def test_usage_tracker_init(): tracker = UsageTracker() summary = tracker.get_session_summary() - assert summary['requests'] == 0 - assert summary['total_tokens'] == 0 - assert summary['estimated_cost'] == 0.0 + assert summary["requests"] == 0 + assert summary["total_tokens"] == 0 + assert summary["estimated_cost"] == 0.0 + def test_track_request_known_model(): tracker = UsageTracker() - tracker.track_request('gpt-3.5-turbo', 100, 50) - + tracker.track_request("gpt-3.5-turbo", 100, 50) + summary = tracker.get_session_summary() - assert summary['requests'] == 1 - assert summary['input_tokens'] == 100 - assert summary['output_tokens'] == 50 - assert summary['total_tokens'] == 150 - assert 'gpt-3.5-turbo' in summary['models_used'] + assert summary["requests"] == 1 + assert summary["input_tokens"] == 100 + assert summary["output_tokens"] == 50 + assert summary["total_tokens"] == 150 + assert "gpt-3.5-turbo" in summary["models_used"] # Cost: (100/1000)*0.0005 + (50/1000)*0.0015 = 0.00005 + 0.000075 = 0.000125 - assert abs(summary['estimated_cost'] - 0.000125) < 1e-6 + assert abs(summary["estimated_cost"] - 0.000125) < 1e-6 + def test_track_request_unknown_model(): tracker = UsageTracker() - tracker.track_request('unknown-model', 100, 50) - + tracker.track_request("unknown-model", 100, 50) + summary = tracker.get_session_summary() - assert summary['requests'] == 1 - assert summary['estimated_cost'] == 0.0 # Unknown model, cost 0 + assert summary["requests"] == 1 + assert summary["estimated_cost"] == 0.0 # Unknown model, cost 0 + def test_track_request_multiple(): tracker = UsageTracker() - tracker.track_request('gpt-3.5-turbo', 100, 50) - tracker.track_request('gpt-4', 200, 100) - + tracker.track_request("gpt-3.5-turbo", 100, 50) + tracker.track_request("gpt-4", 200, 100) + summary = tracker.get_session_summary() - assert summary['requests'] == 2 - assert summary['input_tokens'] == 300 - assert summary['output_tokens'] == 150 - assert summary['total_tokens'] == 450 - assert len(summary['models_used']) == 2 + assert summary["requests"] == 2 + assert summary["input_tokens"] == 300 + assert summary["output_tokens"] == 150 + assert summary["total_tokens"] == 450 + assert len(summary["models_used"]) == 2 + def test_get_formatted_summary(): tracker = UsageTracker() - tracker.track_request('gpt-3.5-turbo', 100, 50) - + tracker.track_request("gpt-3.5-turbo", 100, 50) + formatted = tracker.get_formatted_summary() assert "Total Requests: 1" in formatted assert "Total Tokens: 150" in formatted assert "Estimated Cost: $0.0001" in formatted assert "gpt-3.5-turbo" in formatted + def test_get_total_usage_no_file(temp_usage_file): total = UsageTracker.get_total_usage() - assert total['total_requests'] == 0 - assert total['total_tokens'] == 0 - assert total['total_cost'] == 0.0 + assert total["total_requests"] == 0 + assert total["total_tokens"] == 0 + assert total["total_cost"] == 0.0 + def test_get_total_usage_with_data(temp_usage_file): # Manually create history file history = [ - {'timestamp': '2023-01-01', 'model': 'gpt-3.5-turbo', 'input_tokens': 100, 'output_tokens': 50, 'total_tokens': 150, 'cost': 0.000125}, - {'timestamp': '2023-01-02', 'model': 'gpt-4', 'input_tokens': 200, 'output_tokens': 100, 'total_tokens': 300, 'cost': 0.008} + { + "timestamp": "2023-01-01", + "model": "gpt-3.5-turbo", + "input_tokens": 100, + "output_tokens": 50, + "total_tokens": 150, + "cost": 0.000125, + }, + { + "timestamp": "2023-01-02", + "model": "gpt-4", + "input_tokens": 200, + "output_tokens": 100, + "total_tokens": 300, + "cost": 0.008, + }, ] - with open(temp_usage_file, 'w') as f: + with open(temp_usage_file, "w") as f: json.dump(history, f) - + total = UsageTracker.get_total_usage() - assert total['total_requests'] == 2 - assert total['total_tokens'] == 450 - assert abs(total['total_cost'] - 0.008125) < 1e-6 + assert total["total_requests"] == 2 + assert total["total_tokens"] == 450 + assert abs(total["total_cost"] - 0.008125) < 1e-6 diff --git a/tests/test_validation.py b/tests/test_validation.py index 7e1a2b4..d5d12b1 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,49 +1,59 @@ -import pytest -import tempfile import os +import tempfile + +import pytest + +from pr.core.exceptions import ValidationError from pr.core.validation import ( - validate_file_path, - validate_directory_path, - validate_model_name, validate_api_url, + validate_directory_path, + validate_file_path, + validate_max_tokens, + validate_model_name, validate_session_name, validate_temperature, - validate_max_tokens, ) -from pr.core.exceptions import ValidationError + def test_validate_file_path_empty(): with pytest.raises(ValidationError, match="File path cannot be empty"): validate_file_path("") + def test_validate_file_path_not_exist(): with pytest.raises(ValidationError, match="File does not exist"): validate_file_path("/nonexistent/file.txt", must_exist=True) + def test_validate_file_path_is_dir(): with tempfile.TemporaryDirectory() as tmpdir: with pytest.raises(ValidationError, match="Path is a directory"): validate_file_path(tmpdir, must_exist=True) + def test_validate_file_path_valid(): with tempfile.NamedTemporaryFile() as tmpfile: result = validate_file_path(tmpfile.name, must_exist=True) assert os.path.isabs(result) assert result == os.path.abspath(tmpfile.name) + def test_validate_directory_path_empty(): with pytest.raises(ValidationError, match="Directory path cannot be empty"): validate_directory_path("") + def test_validate_directory_path_not_exist(): with pytest.raises(ValidationError, match="Directory does not exist"): validate_directory_path("/nonexistent/dir", must_exist=True) + def test_validate_directory_path_not_dir(): with tempfile.NamedTemporaryFile() as tmpfile: with pytest.raises(ValidationError, match="Path is not a directory"): validate_directory_path(tmpfile.name, must_exist=True) + def test_validate_directory_path_create(): with tempfile.TemporaryDirectory() as tmpdir: new_dir = os.path.join(tmpdir, "new_dir") @@ -51,72 +61,89 @@ def test_validate_directory_path_create(): assert os.path.isdir(new_dir) assert result == os.path.abspath(new_dir) + def test_validate_directory_path_valid(): with tempfile.TemporaryDirectory() as tmpdir: result = validate_directory_path(tmpdir, must_exist=True) assert result == os.path.abspath(tmpdir) + def test_validate_model_name_empty(): with pytest.raises(ValidationError, match="Model name cannot be empty"): validate_model_name("") + def test_validate_model_name_too_short(): with pytest.raises(ValidationError, match="Model name too short"): validate_model_name("a") + def test_validate_model_name_valid(): result = validate_model_name("gpt-3.5-turbo") assert result == "gpt-3.5-turbo" + def test_validate_api_url_empty(): with pytest.raises(ValidationError, match="API URL cannot be empty"): validate_api_url("") + def test_validate_api_url_invalid(): with pytest.raises(ValidationError, match="API URL must start with"): validate_api_url("invalid-url") + def test_validate_api_url_valid(): result = validate_api_url("https://api.example.com") assert result == "https://api.example.com" + def test_validate_session_name_empty(): with pytest.raises(ValidationError, match="Session name cannot be empty"): validate_session_name("") + def test_validate_session_name_invalid_char(): with pytest.raises(ValidationError, match="contains invalid character"): validate_session_name("test/session") + def test_validate_session_name_too_long(): long_name = "a" * 256 with pytest.raises(ValidationError, match="Session name too long"): validate_session_name(long_name) + def test_validate_session_name_valid(): result = validate_session_name("valid_session_123") assert result == "valid_session_123" + def test_validate_temperature_too_low(): with pytest.raises(ValidationError, match="Temperature must be between"): validate_temperature(-0.1) + def test_validate_temperature_too_high(): with pytest.raises(ValidationError, match="Temperature must be between"): validate_temperature(2.1) + def test_validate_temperature_valid(): result = validate_temperature(0.7) assert result == 0.7 + def test_validate_max_tokens_too_low(): with pytest.raises(ValidationError, match="Max tokens must be at least 1"): validate_max_tokens(0) + def test_validate_max_tokens_too_high(): with pytest.raises(ValidationError, match="Max tokens too high"): validate_max_tokens(100001) + def test_validate_max_tokens_valid(): result = validate_max_tokens(1000) assert result == 1000