Update.
Some checks failed
Tests / test (macos-latest, 3.10) (push) Waiting to run
Tests / test (macos-latest, 3.11) (push) Waiting to run
Tests / test (macos-latest, 3.12) (push) Waiting to run
Tests / test (macos-latest, 3.8) (push) Waiting to run
Tests / test (macos-latest, 3.9) (push) Waiting to run
Tests / test (ubuntu-latest, 3.8) (push) Waiting to run
Tests / test (ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (windows-latest, 3.10) (push) Waiting to run
Tests / test (windows-latest, 3.11) (push) Waiting to run
Tests / test (windows-latest, 3.12) (push) Waiting to run
Tests / test (windows-latest, 3.8) (push) Waiting to run
Tests / test (windows-latest, 3.9) (push) Waiting to run
Lint / lint (push) Failing after 39s
Tests / test (ubuntu-latest, 3.10) (push) Successful in 55s
Tests / test (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / test (ubuntu-latest, 3.12) (push) Has been cancelled

This commit is contained in:
retoor 2025-11-04 08:09:12 +01:00
parent 5f04811dcc
commit 1a29ee4918
82 changed files with 4965 additions and 3096 deletions

View File

@ -159,6 +159,7 @@ def tool_function(args):
"""Implementation""" """Implementation"""
pass pass
def register_tools(): def register_tools():
"""Return list of tool definitions""" """Return list of tool definitions"""
return [...] return [...]
@ -177,8 +178,8 @@ def register_tools():
```python ```python
def test_read_file_with_valid_path_returns_content(temp_dir): def test_read_file_with_valid_path_returns_content(temp_dir):
# Arrange # Arrange
filepath = os.path.join(temp_dir, 'test.txt') filepath = os.path.join(temp_dir, "test.txt")
expected_content = 'Hello, World!' expected_content = "Hello, World!"
write_file(filepath, expected_content) write_file(filepath, expected_content)
# Act # Act

View File

@ -1,4 +1,4 @@
from pr.core import Assistant from pr.core import Assistant
__version__ = '1.0.0' __version__ = "1.0.0"
__all__ = ['Assistant'] __all__ = ["Assistant"]

View File

@ -1,12 +1,14 @@
import argparse import argparse
import sys import sys
from pr.core import Assistant
from pr import __version__ from pr import __version__
from pr.core import Assistant
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='PR Assistant - Professional CLI AI assistant with autonomous execution', description="PR Assistant - Professional CLI AI assistant with autonomous execution",
epilog=''' epilog="""
Examples: Examples:
pr "What is Python?" # Single query pr "What is Python?" # Single query
pr -i # Interactive mode pr -i # Interactive mode
@ -25,43 +27,79 @@ Commands in interactive mode:
/usage - Show usage statistics /usage - Show usage statistics
/save <name> - Save current session /save <name> - Save current session
exit, quit, q - Exit the program exit, quit, q - Exit the program
''', """,
formatter_class=argparse.RawDescriptionHelpFormatter formatter_class=argparse.RawDescriptionHelpFormatter,
) )
parser.add_argument('message', nargs='?', help='Message to send to assistant') parser.add_argument("message", nargs="?", help="Message to send to assistant")
parser.add_argument('--version', action='version', version=f'PR Assistant {__version__}') parser.add_argument(
parser.add_argument('-m', '--model', help='AI model to use') "--version", action="version", version=f"PR Assistant {__version__}"
parser.add_argument('-u', '--api-url', help='API endpoint URL') )
parser.add_argument('--model-list-url', help='Model list endpoint URL') parser.add_argument("-m", "--model", help="AI model to use")
parser.add_argument('-i', '--interactive', action='store_true', help='Interactive mode') parser.add_argument("-u", "--api-url", help="API endpoint URL")
parser.add_argument('-v', '--verbose', action='store_true', help='Verbose output') parser.add_argument("--model-list-url", help="Model list endpoint URL")
parser.add_argument('--debug', action='store_true', help='Enable debug mode with detailed logging') parser.add_argument(
parser.add_argument('--no-syntax', action='store_true', help='Disable syntax highlighting') "-i", "--interactive", action="store_true", help="Interactive mode"
parser.add_argument('--include-env', action='store_true', help='Include environment variables in context') )
parser.add_argument('-c', '--context', action='append', help='Additional context files') parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
parser.add_argument('--api-mode', action='store_true', help='API mode for specialized interaction') parser.add_argument(
"--debug", action="store_true", help="Enable debug mode with detailed logging"
)
parser.add_argument(
"--no-syntax", action="store_true", help="Disable syntax highlighting"
)
parser.add_argument(
"--include-env",
action="store_true",
help="Include environment variables in context",
)
parser.add_argument(
"-c", "--context", action="append", help="Additional context files"
)
parser.add_argument(
"--api-mode", action="store_true", help="API mode for specialized interaction"
)
parser.add_argument('--output', choices=['text', 'json', 'structured'], parser.add_argument(
default='text', help='Output format') "--output",
parser.add_argument('--quiet', action='store_true', help='Minimal output') choices=["text", "json", "structured"],
default="text",
help="Output format",
)
parser.add_argument("--quiet", action="store_true", help="Minimal output")
parser.add_argument('--save-session', metavar='NAME', help='Save session with given name') parser.add_argument(
parser.add_argument('--load-session', metavar='NAME', help='Load session with given name') "--save-session", metavar="NAME", help="Save session with given name"
parser.add_argument('--list-sessions', action='store_true', help='List all saved sessions') )
parser.add_argument('--delete-session', metavar='NAME', help='Delete a saved session') parser.add_argument(
parser.add_argument('--export-session', nargs=2, metavar=('NAME', 'FILE'), "--load-session", metavar="NAME", help="Load session with given name"
help='Export session to file') )
parser.add_argument(
"--list-sessions", action="store_true", help="List all saved sessions"
)
parser.add_argument(
"--delete-session", metavar="NAME", help="Delete a saved session"
)
parser.add_argument(
"--export-session",
nargs=2,
metavar=("NAME", "FILE"),
help="Export session to file",
)
parser.add_argument('--usage', action='store_true', help='Show token usage statistics') parser.add_argument(
parser.add_argument('--create-config', action='store_true', "--usage", action="store_true", help="Show token usage statistics"
help='Create default configuration file') )
parser.add_argument('--plugins', action='store_true', help='List loaded plugins') parser.add_argument(
"--create-config", action="store_true", help="Create default configuration file"
)
parser.add_argument("--plugins", action="store_true", help="List loaded plugins")
args = parser.parse_args() args = parser.parse_args()
if args.create_config: if args.create_config:
from pr.core.config_loader import create_default_config from pr.core.config_loader import create_default_config
if create_default_config(): if create_default_config():
print("Configuration file created at ~/.prrc") print("Configuration file created at ~/.prrc")
else: else:
@ -70,6 +108,7 @@ Commands in interactive mode:
if args.list_sessions: if args.list_sessions:
from pr.core.session import SessionManager from pr.core.session import SessionManager
sm = SessionManager() sm = SessionManager()
sessions = sm.list_sessions() sessions = sm.list_sessions()
if not sessions: if not sessions:
@ -85,6 +124,7 @@ Commands in interactive mode:
if args.delete_session: if args.delete_session:
from pr.core.session import SessionManager from pr.core.session import SessionManager
sm = SessionManager() sm = SessionManager()
if sm.delete_session(args.delete_session): if sm.delete_session(args.delete_session):
print(f"Session '{args.delete_session}' deleted") print(f"Session '{args.delete_session}' deleted")
@ -94,13 +134,14 @@ Commands in interactive mode:
if args.export_session: if args.export_session:
from pr.core.session import SessionManager from pr.core.session import SessionManager
sm = SessionManager() sm = SessionManager()
name, output_file = args.export_session name, output_file = args.export_session
format_type = 'json' format_type = "json"
if output_file.endswith('.md'): if output_file.endswith(".md"):
format_type = 'markdown' format_type = "markdown"
elif output_file.endswith('.txt'): elif output_file.endswith(".txt"):
format_type = 'txt' format_type = "txt"
if sm.export_session(name, output_file, format_type): if sm.export_session(name, output_file, format_type):
print(f"Session exported to {output_file}") print(f"Session exported to {output_file}")
@ -110,6 +151,7 @@ Commands in interactive mode:
if args.usage: if args.usage:
from pr.core.usage_tracker import UsageTracker from pr.core.usage_tracker import UsageTracker
usage = UsageTracker.get_total_usage() usage = UsageTracker.get_total_usage()
print(f"\nTotal Usage Statistics:") print(f"\nTotal Usage Statistics:")
print(f" Requests: {usage['total_requests']}") print(f" Requests: {usage['total_requests']}")
@ -119,6 +161,7 @@ Commands in interactive mode:
if args.plugins: if args.plugins:
from pr.plugins.loader import PluginLoader from pr.plugins.loader import PluginLoader
loader = PluginLoader() loader = PluginLoader()
loader.load_plugins() loader.load_plugins()
plugins = loader.list_loaded_plugins() plugins = loader.list_loaded_plugins()
@ -133,5 +176,6 @@ Commands in interactive mode:
assistant = Assistant(args) assistant = Assistant(args)
assistant.run() assistant.run()
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -1,6 +1,13 @@
from .agent_communication import AgentCommunicationBus, AgentMessage
from .agent_manager import AgentInstance, AgentManager
from .agent_roles import AgentRole, get_agent_role, list_agent_roles from .agent_roles import AgentRole, get_agent_role, list_agent_roles
from .agent_manager import AgentManager, AgentInstance
from .agent_communication import AgentMessage, AgentCommunicationBus
__all__ = ['AgentRole', 'get_agent_role', 'list_agent_roles', 'AgentManager', 'AgentInstance', __all__ = [
'AgentMessage', 'AgentCommunicationBus'] "AgentRole",
"get_agent_role",
"list_agent_roles",
"AgentManager",
"AgentInstance",
"AgentMessage",
"AgentCommunicationBus",
]

View File

@ -1,14 +1,16 @@
import sqlite3
import json import json
from typing import List, Optional import sqlite3
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import List, Optional
class MessageType(Enum): class MessageType(Enum):
REQUEST = "request" REQUEST = "request"
RESPONSE = "response" RESPONSE = "response"
NOTIFICATION = "notification" NOTIFICATION = "notification"
@dataclass @dataclass
class AgentMessage: class AgentMessage:
message_id: str message_id: str
@ -21,27 +23,28 @@ class AgentMessage:
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
'message_id': self.message_id, "message_id": self.message_id,
'from_agent': self.from_agent, "from_agent": self.from_agent,
'to_agent': self.to_agent, "to_agent": self.to_agent,
'message_type': self.message_type.value, "message_type": self.message_type.value,
'content': self.content, "content": self.content,
'metadata': self.metadata, "metadata": self.metadata,
'timestamp': self.timestamp "timestamp": self.timestamp,
} }
@classmethod @classmethod
def from_dict(cls, data: dict) -> 'AgentMessage': def from_dict(cls, data: dict) -> "AgentMessage":
return cls( return cls(
message_id=data['message_id'], message_id=data["message_id"],
from_agent=data['from_agent'], from_agent=data["from_agent"],
to_agent=data['to_agent'], to_agent=data["to_agent"],
message_type=MessageType(data['message_type']), message_type=MessageType(data["message_type"]),
content=data['content'], content=data["content"],
metadata=data['metadata'], metadata=data["metadata"],
timestamp=data['timestamp'] timestamp=data["timestamp"],
) )
class AgentCommunicationBus: class AgentCommunicationBus:
def __init__(self, db_path: str): def __init__(self, db_path: str):
self.db_path = db_path self.db_path = db_path
@ -50,7 +53,8 @@ class AgentCommunicationBus:
def _create_tables(self): def _create_tables(self):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS agent_messages ( CREATE TABLE IF NOT EXISTS agent_messages (
message_id TEXT PRIMARY KEY, message_id TEXT PRIMARY KEY,
from_agent TEXT, from_agent TEXT,
@ -62,70 +66,88 @@ class AgentCommunicationBus:
session_id TEXT, session_id TEXT,
read INTEGER DEFAULT 0 read INTEGER DEFAULT 0
) )
''') """
)
self.conn.commit() self.conn.commit()
def send_message(self, message: AgentMessage, session_id: Optional[str] = None): def send_message(self, message: AgentMessage, session_id: Optional[str] = None):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute(
"""
INSERT INTO agent_messages INSERT INTO agent_messages
(message_id, from_agent, to_agent, message_type, content, metadata, timestamp, session_id) (message_id, from_agent, to_agent, message_type, content, metadata, timestamp, session_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', ( """,
message.message_id, (
message.from_agent, message.message_id,
message.to_agent, message.from_agent,
message.message_type.value, message.to_agent,
message.content, message.message_type.value,
json.dumps(message.metadata), message.content,
message.timestamp, json.dumps(message.metadata),
session_id message.timestamp,
)) session_id,
),
)
self.conn.commit() self.conn.commit()
def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]: def get_messages(
self, agent_id: str, unread_only: bool = True
) -> List[AgentMessage]:
cursor = self.conn.cursor() cursor = self.conn.cursor()
if unread_only: if unread_only:
cursor.execute(''' cursor.execute(
"""
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
FROM agent_messages FROM agent_messages
WHERE to_agent = ? AND read = 0 WHERE to_agent = ? AND read = 0
ORDER BY timestamp ASC ORDER BY timestamp ASC
''', (agent_id,)) """,
(agent_id,),
)
else: else:
cursor.execute(''' cursor.execute(
"""
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
FROM agent_messages FROM agent_messages
WHERE to_agent = ? WHERE to_agent = ?
ORDER BY timestamp ASC ORDER BY timestamp ASC
''', (agent_id,)) """,
(agent_id,),
)
messages = [] messages = []
for row in cursor.fetchall(): for row in cursor.fetchall():
messages.append(AgentMessage( messages.append(
message_id=row[0], AgentMessage(
from_agent=row[1], message_id=row[0],
to_agent=row[2], from_agent=row[1],
message_type=MessageType(row[3]), to_agent=row[2],
content=row[4], message_type=MessageType(row[3]),
metadata=json.loads(row[5]) if row[5] else {}, content=row[4],
timestamp=row[6] metadata=json.loads(row[5]) if row[5] else {},
)) timestamp=row[6],
)
)
return messages return messages
def mark_as_read(self, message_id: str): def mark_as_read(self, message_id: str):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('UPDATE agent_messages SET read = 1 WHERE message_id = ?', (message_id,)) cursor.execute(
"UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,)
)
self.conn.commit() self.conn.commit()
def clear_messages(self, session_id: Optional[str] = None): def clear_messages(self, session_id: Optional[str] = None):
cursor = self.conn.cursor() cursor = self.conn.cursor()
if session_id: if session_id:
cursor.execute('DELETE FROM agent_messages WHERE session_id = ?', (session_id,)) cursor.execute(
"DELETE FROM agent_messages WHERE session_id = ?", (session_id,)
)
else: else:
cursor.execute('DELETE FROM agent_messages') cursor.execute("DELETE FROM agent_messages")
self.conn.commit() self.conn.commit()
def close(self): def close(self):
@ -134,24 +156,31 @@ class AgentCommunicationBus:
def receive_messages(self, agent_id: str) -> List[AgentMessage]: def receive_messages(self, agent_id: str) -> List[AgentMessage]:
return self.get_messages(agent_id, unread_only=True) return self.get_messages(agent_id, unread_only=True)
def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]: def get_conversation_history(
self, agent_a: str, agent_b: str
) -> List[AgentMessage]:
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute(
"""
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
FROM agent_messages FROM agent_messages
WHERE (from_agent = ? AND to_agent = ?) OR (from_agent = ? AND to_agent = ?) WHERE (from_agent = ? AND to_agent = ?) OR (from_agent = ? AND to_agent = ?)
ORDER BY timestamp ASC ORDER BY timestamp ASC
''', (agent_a, agent_b, agent_b, agent_a)) """,
(agent_a, agent_b, agent_b, agent_a),
)
messages = [] messages = []
for row in cursor.fetchall(): for row in cursor.fetchall():
messages.append(AgentMessage( messages.append(
message_id=row[0], AgentMessage(
from_agent=row[1], message_id=row[0],
to_agent=row[2], from_agent=row[1],
message_type=MessageType(row[3]), to_agent=row[2],
content=row[4], message_type=MessageType(row[3]),
metadata=json.loads(row[5]) if row[5] else {}, content=row[4],
timestamp=row[6] metadata=json.loads(row[5]) if row[5] else {},
)) timestamp=row[6],
)
)
return messages return messages

View File

@ -1,11 +1,13 @@
import time
import json import json
import time
import uuid import uuid
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from .agent_roles import AgentRole, get_agent_role from typing import Any, Callable, Dict, List, Optional
from .agent_communication import AgentMessage, AgentCommunicationBus, MessageType
from ..memory.knowledge_store import KnowledgeStore from ..memory.knowledge_store import KnowledgeStore
from .agent_communication import AgentCommunicationBus, AgentMessage, MessageType
from .agent_roles import AgentRole, get_agent_role
@dataclass @dataclass
class AgentInstance: class AgentInstance:
@ -17,21 +19,20 @@ class AgentInstance:
task_count: int = 0 task_count: int = 0
def add_message(self, role: str, content: str): def add_message(self, role: str, content: str):
self.message_history.append({ self.message_history.append(
'role': role, {"role": role, "content": content, "timestamp": time.time()}
'content': content, )
'timestamp': time.time()
})
def get_system_message(self) -> Dict[str, str]: def get_system_message(self) -> Dict[str, str]:
return {'role': 'system', 'content': self.role.system_prompt} return {"role": "system", "content": self.role.system_prompt}
def get_messages_for_api(self) -> List[Dict[str, str]]: def get_messages_for_api(self) -> List[Dict[str, str]]:
return [self.get_system_message()] + [ return [self.get_system_message()] + [
{'role': msg['role'], 'content': msg['content']} {"role": msg["role"], "content": msg["content"]}
for msg in self.message_history for msg in self.message_history
] ]
class AgentManager: class AgentManager:
def __init__(self, db_path: str, api_caller: Callable): def __init__(self, db_path: str, api_caller: Callable):
self.db_path = db_path self.db_path = db_path
@ -46,10 +47,7 @@ class AgentManager:
agent_id = f"{role_name}_{str(uuid.uuid4())[:8]}" agent_id = f"{role_name}_{str(uuid.uuid4())[:8]}"
role = get_agent_role(role_name) role = get_agent_role(role_name)
agent = AgentInstance( agent = AgentInstance(agent_id=agent_id, role=role)
agent_id=agent_id,
role=role
)
self.active_agents[agent_id] = agent self.active_agents[agent_id] = agent
return agent_id return agent_id
@ -63,15 +61,17 @@ class AgentManager:
return True return True
return False return False
def execute_agent_task(self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: def execute_agent_task(
self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
agent = self.get_agent(agent_id) agent = self.get_agent(agent_id)
if not agent: if not agent:
return {'error': f'Agent {agent_id} not found'} return {"error": f"Agent {agent_id} not found"}
if context: if context:
agent.context.update(context) agent.context.update(context)
agent.add_message('user', task) agent.add_message("user", task)
knowledge_matches = self.knowledge_store.search_entries(task, top_k=3) knowledge_matches = self.knowledge_store.search_entries(task, top_k=3)
agent.task_count += 1 agent.task_count += 1
@ -81,35 +81,40 @@ class AgentManager:
for i, entry in enumerate(knowledge_matches, 1): for i, entry in enumerate(knowledge_matches, 1):
shortened_content = entry.content[:2000] shortened_content = entry.content[:2000]
knowledge_content += f"{i}. {shortened_content}\\n\\n" knowledge_content += f"{i}. {shortened_content}\\n\\n"
messages.insert(-1, {'role': 'user', 'content': knowledge_content}) messages.insert(-1, {"role": "user", "content": knowledge_content})
try: try:
response = self.api_caller( response = self.api_caller(
messages=messages, messages=messages,
temperature=agent.role.temperature, temperature=agent.role.temperature,
max_tokens=agent.role.max_tokens max_tokens=agent.role.max_tokens,
) )
if response and 'choices' in response: if response and "choices" in response:
assistant_message = response['choices'][0]['message']['content'] assistant_message = response["choices"][0]["message"]["content"]
agent.add_message('assistant', assistant_message) agent.add_message("assistant", assistant_message)
return { return {
'success': True, "success": True,
'agent_id': agent_id, "agent_id": agent_id,
'response': assistant_message, "response": assistant_message,
'role': agent.role.name, "role": agent.role.name,
'task_count': agent.task_count "task_count": agent.task_count,
} }
else: else:
return {'error': 'Invalid API response', 'agent_id': agent_id} return {"error": "Invalid API response", "agent_id": agent_id}
except Exception as e: except Exception as e:
return {'error': str(e), 'agent_id': agent_id} return {"error": str(e), "agent_id": agent_id}
def send_agent_message(self, from_agent_id: str, to_agent_id: str, def send_agent_message(
content: str, message_type: MessageType = MessageType.REQUEST, self,
metadata: Optional[Dict[str, Any]] = None): from_agent_id: str,
to_agent_id: str,
content: str,
message_type: MessageType = MessageType.REQUEST,
metadata: Optional[Dict[str, Any]] = None,
):
message = AgentMessage( message = AgentMessage(
from_agent=from_agent_id, from_agent=from_agent_id,
to_agent=to_agent_id, to_agent=to_agent_id,
@ -117,57 +122,57 @@ class AgentManager:
content=content, content=content,
metadata=metadata or {}, metadata=metadata or {},
timestamp=time.time(), timestamp=time.time(),
message_id=str(uuid.uuid4())[:16] message_id=str(uuid.uuid4())[:16],
) )
self.communication_bus.send_message(message, self.session_id) self.communication_bus.send_message(message, self.session_id)
return message.message_id return message.message_id
def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]: def get_agent_messages(
self, agent_id: str, unread_only: bool = True
) -> List[AgentMessage]:
return self.communication_bus.get_messages(agent_id, unread_only) return self.communication_bus.get_messages(agent_id, unread_only)
def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]): def collaborate_agents(
self, orchestrator_id: str, task: str, agent_roles: List[str]
):
orchestrator = self.get_agent(orchestrator_id) orchestrator = self.get_agent(orchestrator_id)
if not orchestrator: if not orchestrator:
orchestrator_id = self.create_agent('orchestrator') orchestrator_id = self.create_agent("orchestrator")
orchestrator = self.get_agent(orchestrator_id) orchestrator = self.get_agent(orchestrator_id)
worker_agents = [] worker_agents = []
for role in agent_roles: for role in agent_roles:
agent_id = self.create_agent(role) agent_id = self.create_agent(role)
worker_agents.append({ worker_agents.append({"agent_id": agent_id, "role": role})
'agent_id': agent_id,
'role': role
})
orchestration_prompt = f'''Task: {task} orchestration_prompt = f"""Task: {task}
Available specialized agents: Available specialized agents:
{chr(10).join([f"- {a['agent_id']} ({a['role']})" for a in worker_agents])} {chr(10).join([f"- {a['agent_id']} ({a['role']})" for a in worker_agents])}
Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results.''' Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results."""
orchestrator_result = self.execute_agent_task(orchestrator_id, orchestration_prompt) orchestrator_result = self.execute_agent_task(
orchestrator_id, orchestration_prompt
)
results = { results = {"orchestrator": orchestrator_result, "agents": []}
'orchestrator': orchestrator_result,
'agents': []
}
for agent_info in worker_agents: for agent_info in worker_agents:
agent_id = agent_info['agent_id'] agent_id = agent_info["agent_id"]
messages = self.get_agent_messages(agent_id) messages = self.get_agent_messages(agent_id)
for msg in messages: for msg in messages:
subtask = msg.content subtask = msg.content
result = self.execute_agent_task(agent_id, subtask) result = self.execute_agent_task(agent_id, subtask)
results['agents'].append(result) results["agents"].append(result)
self.send_agent_message( self.send_agent_message(
from_agent_id=agent_id, from_agent_id=agent_id,
to_agent_id=orchestrator_id, to_agent_id=orchestrator_id,
content=result.get('response', ''), content=result.get("response", ""),
message_type=MessageType.RESPONSE message_type=MessageType.RESPONSE,
) )
self.communication_bus.mark_as_read(msg.message_id) self.communication_bus.mark_as_read(msg.message_id)
@ -175,17 +180,17 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei
def get_session_summary(self) -> str: def get_session_summary(self) -> str:
summary = { summary = {
'session_id': self.session_id, "session_id": self.session_id,
'active_agents': len(self.active_agents), "active_agents": len(self.active_agents),
'agents': [ "agents": [
{ {
'agent_id': agent_id, "agent_id": agent_id,
'role': agent.role.name, "role": agent.role.name,
'task_count': agent.task_count, "task_count": agent.task_count,
'message_count': len(agent.message_history) "message_count": len(agent.message_history),
} }
for agent_id, agent in self.active_agents.items() for agent_id, agent in self.active_agents.items()
] ],
} }
return json.dumps(summary) return json.dumps(summary)

View File

@ -1,5 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Dict, Any, Set from typing import Dict, List, Set
@dataclass @dataclass
class AgentRole: class AgentRole:
@ -11,182 +12,262 @@ class AgentRole:
temperature: float = 0.7 temperature: float = 0.7
max_tokens: int = 4096 max_tokens: int = 4096
AGENT_ROLES = { AGENT_ROLES = {
'coding': AgentRole( "coding": AgentRole(
name='coding', name="coding",
description='Specialized in writing, reviewing, and debugging code', description="Specialized in writing, reviewing, and debugging code",
system_prompt='''You are a coding specialist AI assistant. Your primary responsibilities: system_prompt="""You are a coding specialist AI assistant. Your primary responsibilities:
- Write clean, efficient, well-structured code - Write clean, efficient, well-structured code
- Review code for bugs, security issues, and best practices - Review code for bugs, security issues, and best practices
- Refactor and optimize existing code - Refactor and optimize existing code
- Implement features based on specifications - Implement features based on specifications
- Follow language-specific conventions and patterns - Follow language-specific conventions and patterns
Focus on code quality, maintainability, and performance.''', Focus on code quality, maintainability, and performance.""",
allowed_tools={ allowed_tools={
'read_file', 'write_file', 'list_directory', 'create_directory', "read_file",
'change_directory', 'get_current_directory', 'python_exec', "write_file",
'run_command', 'index_directory' "list_directory",
"create_directory",
"change_directory",
"get_current_directory",
"python_exec",
"run_command",
"index_directory",
}, },
specialization_areas=['code_writing', 'code_review', 'debugging', 'refactoring'], specialization_areas=[
temperature=0.3 "code_writing",
"code_review",
"debugging",
"refactoring",
],
temperature=0.3,
), ),
"research": AgentRole(
'research': AgentRole( name="research",
name='research', description="Specialized in information gathering and analysis",
description='Specialized in information gathering and analysis', system_prompt="""You are a research specialist AI assistant. Your primary responsibilities:
system_prompt='''You are a research specialist AI assistant. Your primary responsibilities:
- Search for and gather relevant information - Search for and gather relevant information
- Analyze data and documentation - Analyze data and documentation
- Synthesize findings into clear summaries - Synthesize findings into clear summaries
- Verify facts and cross-reference sources - Verify facts and cross-reference sources
- Identify trends and patterns in information - Identify trends and patterns in information
Focus on accuracy, thoroughness, and clear communication of findings.''', Focus on accuracy, thoroughness, and clear communication of findings.""",
allowed_tools={ allowed_tools={
'read_file', 'list_directory', 'index_directory', "read_file",
'http_fetch', 'web_search', 'web_search_news', "list_directory",
'db_query', 'db_get' "index_directory",
"http_fetch",
"web_search",
"web_search_news",
"db_query",
"db_get",
}, },
specialization_areas=['information_gathering', 'analysis', 'documentation', 'fact_checking'], specialization_areas=[
temperature=0.5 "information_gathering",
"analysis",
"documentation",
"fact_checking",
],
temperature=0.5,
), ),
"data_analysis": AgentRole(
'data_analysis': AgentRole( name="data_analysis",
name='data_analysis', description="Specialized in data processing and analysis",
description='Specialized in data processing and analysis', system_prompt="""You are a data analysis specialist AI assistant. Your primary responsibilities:
system_prompt='''You are a data analysis specialist AI assistant. Your primary responsibilities:
- Process and analyze structured and unstructured data - Process and analyze structured and unstructured data
- Perform statistical analysis and pattern recognition - Perform statistical analysis and pattern recognition
- Query databases and extract insights - Query databases and extract insights
- Create data summaries and reports - Create data summaries and reports
- Identify anomalies and trends - Identify anomalies and trends
Focus on accuracy, data integrity, and actionable insights.''', Focus on accuracy, data integrity, and actionable insights.""",
allowed_tools={ allowed_tools={
'db_query', 'db_get', 'db_set', 'read_file', 'write_file', "db_query",
'python_exec', 'run_command', 'list_directory' "db_get",
"db_set",
"read_file",
"write_file",
"python_exec",
"run_command",
"list_directory",
}, },
specialization_areas=['data_processing', 'statistical_analysis', 'database_operations'], specialization_areas=[
temperature=0.3 "data_processing",
"statistical_analysis",
"database_operations",
],
temperature=0.3,
), ),
"planning": AgentRole(
'planning': AgentRole( name="planning",
name='planning', description="Specialized in task planning and coordination",
description='Specialized in task planning and coordination', system_prompt="""You are a planning specialist AI assistant. Your primary responsibilities:
system_prompt='''You are a planning specialist AI assistant. Your primary responsibilities:
- Break down complex tasks into manageable steps - Break down complex tasks into manageable steps
- Create execution plans and workflows - Create execution plans and workflows
- Identify dependencies and prerequisites - Identify dependencies and prerequisites
- Estimate effort and resource requirements - Estimate effort and resource requirements
- Coordinate between different components - Coordinate between different components
Focus on logical organization, completeness, and feasibility.''', Focus on logical organization, completeness, and feasibility.""",
allowed_tools={ allowed_tools={
'read_file', 'write_file', 'list_directory', 'index_directory', "read_file",
'db_set', 'db_get' "write_file",
"list_directory",
"index_directory",
"db_set",
"db_get",
}, },
specialization_areas=['task_decomposition', 'workflow_design', 'coordination'], specialization_areas=["task_decomposition", "workflow_design", "coordination"],
temperature=0.6 temperature=0.6,
), ),
"testing": AgentRole(
'testing': AgentRole( name="testing",
name='testing', description="Specialized in testing and quality assurance",
description='Specialized in testing and quality assurance', system_prompt="""You are a testing specialist AI assistant. Your primary responsibilities:
system_prompt='''You are a testing specialist AI assistant. Your primary responsibilities:
- Design and execute test cases - Design and execute test cases
- Identify edge cases and potential failures - Identify edge cases and potential failures
- Verify functionality and correctness - Verify functionality and correctness
- Test error handling and edge conditions - Test error handling and edge conditions
- Ensure code meets quality standards - Ensure code meets quality standards
Focus on thoroughness, coverage, and issue identification.''', Focus on thoroughness, coverage, and issue identification.""",
allowed_tools={ allowed_tools={
'read_file', 'write_file', 'python_exec', 'run_command', "read_file",
'list_directory', 'db_query' "write_file",
"python_exec",
"run_command",
"list_directory",
"db_query",
}, },
specialization_areas=['test_design', 'quality_assurance', 'validation'], specialization_areas=["test_design", "quality_assurance", "validation"],
temperature=0.4 temperature=0.4,
), ),
"documentation": AgentRole(
'documentation': AgentRole( name="documentation",
name='documentation', description="Specialized in creating and maintaining documentation",
description='Specialized in creating and maintaining documentation', system_prompt="""You are a documentation specialist AI assistant. Your primary responsibilities:
system_prompt='''You are a documentation specialist AI assistant. Your primary responsibilities:
- Write clear, comprehensive documentation - Write clear, comprehensive documentation
- Create API references and user guides - Create API references and user guides
- Document code with comments and docstrings - Document code with comments and docstrings
- Organize and structure information logically - Organize and structure information logically
- Ensure documentation is up-to-date and accurate - Ensure documentation is up-to-date and accurate
Focus on clarity, completeness, and user-friendliness.''', Focus on clarity, completeness, and user-friendliness.""",
allowed_tools={ allowed_tools={
'read_file', 'write_file', 'list_directory', 'index_directory', "read_file",
'http_fetch', 'web_search' "write_file",
"list_directory",
"index_directory",
"http_fetch",
"web_search",
}, },
specialization_areas=['technical_writing', 'documentation_organization', 'user_guides'], specialization_areas=[
temperature=0.6 "technical_writing",
"documentation_organization",
"user_guides",
],
temperature=0.6,
), ),
"orchestrator": AgentRole(
'orchestrator': AgentRole( name="orchestrator",
name='orchestrator', description="Coordinates multiple agents and manages overall execution",
description='Coordinates multiple agents and manages overall execution', system_prompt="""You are an orchestrator AI assistant. Your primary responsibilities:
system_prompt='''You are an orchestrator AI assistant. Your primary responsibilities:
- Coordinate multiple specialized agents - Coordinate multiple specialized agents
- Delegate tasks to appropriate agents - Delegate tasks to appropriate agents
- Integrate results from different agents - Integrate results from different agents
- Manage overall workflow execution - Manage overall workflow execution
- Ensure task completion and quality - Ensure task completion and quality
Focus on effective delegation, integration, and overall success.''', Focus on effective delegation, integration, and overall success.""",
allowed_tools={ allowed_tools={
'read_file', 'write_file', 'list_directory', 'db_set', 'db_get', 'db_query' "read_file",
"write_file",
"list_directory",
"db_set",
"db_get",
"db_query",
}, },
specialization_areas=['agent_coordination', 'task_delegation', 'result_integration'], specialization_areas=[
temperature=0.5 "agent_coordination",
"task_delegation",
"result_integration",
],
temperature=0.5,
), ),
"general": AgentRole(
'general': AgentRole( name="general",
name='general', description="General purpose agent for miscellaneous tasks",
description='General purpose agent for miscellaneous tasks', system_prompt="""You are a general purpose AI assistant. Your responsibilities:
system_prompt='''You are a general purpose AI assistant. Your responsibilities:
- Handle diverse tasks across multiple domains - Handle diverse tasks across multiple domains
- Provide balanced assistance for various needs - Provide balanced assistance for various needs
- Adapt to different types of requests - Adapt to different types of requests
- Collaborate with specialized agents when needed - Collaborate with specialized agents when needed
Focus on versatility, helpfulness, and task completion.''', Focus on versatility, helpfulness, and task completion.""",
allowed_tools={ allowed_tools={
'read_file', 'write_file', 'list_directory', 'create_directory', "read_file",
'change_directory', 'get_current_directory', 'python_exec', "write_file",
'run_command', 'run_command_interactive', 'http_fetch', "list_directory",
'web_search', 'web_search_news', 'db_set', 'db_get', 'db_query', "create_directory",
'index_directory' "change_directory",
"get_current_directory",
"python_exec",
"run_command",
"run_command_interactive",
"http_fetch",
"web_search",
"web_search_news",
"db_set",
"db_get",
"db_query",
"index_directory",
}, },
specialization_areas=['general_assistance'], specialization_areas=["general_assistance"],
temperature=0.7 temperature=0.7,
) ),
} }
def get_agent_role(role_name: str) -> AgentRole: def get_agent_role(role_name: str) -> AgentRole:
return AGENT_ROLES.get(role_name, AGENT_ROLES['general']) return AGENT_ROLES.get(role_name, AGENT_ROLES["general"])
def list_agent_roles() -> Dict[str, AgentRole]: def list_agent_roles() -> Dict[str, AgentRole]:
return AGENT_ROLES.copy() return AGENT_ROLES.copy()
def get_recommended_agent(task_description: str) -> str: def get_recommended_agent(task_description: str) -> str:
task_lower = task_description.lower() task_lower = task_description.lower()
code_keywords = ['code', 'implement', 'function', 'class', 'bug', 'debug', 'refactor', 'optimize'] code_keywords = [
research_keywords = ['search', 'find', 'research', 'information', 'analyze', 'investigate'] "code",
data_keywords = ['data', 'database', 'query', 'statistics', 'analyze', 'process'] "implement",
planning_keywords = ['plan', 'organize', 'workflow', 'steps', 'coordinate'] "function",
testing_keywords = ['test', 'verify', 'validate', 'check', 'quality'] "class",
doc_keywords = ['document', 'documentation', 'explain', 'guide', 'manual'] "bug",
"debug",
"refactor",
"optimize",
]
research_keywords = [
"search",
"find",
"research",
"information",
"analyze",
"investigate",
]
data_keywords = ["data", "database", "query", "statistics", "analyze", "process"]
planning_keywords = ["plan", "organize", "workflow", "steps", "coordinate"]
testing_keywords = ["test", "verify", "validate", "check", "quality"]
doc_keywords = ["document", "documentation", "explain", "guide", "manual"]
if any(keyword in task_lower for keyword in code_keywords): if any(keyword in task_lower for keyword in code_keywords):
return 'coding' return "coding"
elif any(keyword in task_lower for keyword in research_keywords): elif any(keyword in task_lower for keyword in research_keywords):
return 'research' return "research"
elif any(keyword in task_lower for keyword in data_keywords): elif any(keyword in task_lower for keyword in data_keywords):
return 'data_analysis' return "data_analysis"
elif any(keyword in task_lower for keyword in planning_keywords): elif any(keyword in task_lower for keyword in planning_keywords):
return 'planning' return "planning"
elif any(keyword in task_lower for keyword in testing_keywords): elif any(keyword in task_lower for keyword in testing_keywords):
return 'testing' return "testing"
elif any(keyword in task_lower for keyword in doc_keywords): elif any(keyword in task_lower for keyword in doc_keywords):
return 'documentation' return "documentation"
else: else:
return 'general' return "general"

View File

@ -1,4 +1,4 @@
from pr.autonomous.detection import is_task_complete from pr.autonomous.detection import is_task_complete
from pr.autonomous.mode import run_autonomous_mode, process_response_autonomous from pr.autonomous.mode import process_response_autonomous, run_autonomous_mode
__all__ = ['is_task_complete', 'run_autonomous_mode', 'process_response_autonomous'] __all__ = ["is_task_complete", "run_autonomous_mode", "process_response_autonomous"]

View File

@ -1,28 +1,39 @@
from pr.config import MAX_AUTONOMOUS_ITERATIONS from pr.config import MAX_AUTONOMOUS_ITERATIONS
from pr.ui import Colors from pr.ui import Colors
def is_task_complete(response, iteration): def is_task_complete(response, iteration):
if 'error' in response: if "error" in response:
return True return True
if 'choices' not in response or not response['choices']: if "choices" not in response or not response["choices"]:
return True return True
message = response['choices'][0]['message'] message = response["choices"][0]["message"]
content = message.get('content', '').lower() content = message.get("content", "").lower()
completion_keywords = [ completion_keywords = [
'task complete', 'task is complete', 'finished', 'done', "task complete",
'successfully completed', 'task accomplished', 'all done', "task is complete",
'implementation complete', 'setup complete', 'installation complete' "finished",
"done",
"successfully completed",
"task accomplished",
"all done",
"implementation complete",
"setup complete",
"installation complete",
] ]
error_keywords = [ error_keywords = [
'cannot proceed', 'unable to continue', 'fatal error', "cannot proceed",
'cannot complete', 'impossible to' "unable to continue",
"fatal error",
"cannot complete",
"impossible to",
] ]
has_tool_calls = 'tool_calls' in message and message['tool_calls'] has_tool_calls = "tool_calls" in message and message["tool_calls"]
mentions_completion = any(keyword in content for keyword in completion_keywords) mentions_completion = any(keyword in content for keyword in completion_keywords)
mentions_error = any(keyword in content for keyword in error_keywords) mentions_error = any(keyword in content for keyword in error_keywords)

View File

@ -1,11 +1,13 @@
import time
import json import json
import logging import logging
from pr.ui import Colors, display_tool_call, print_autonomous_header import time
from pr.autonomous.detection import is_task_complete from pr.autonomous.detection import is_task_complete
from pr.core.context import truncate_tool_result from pr.core.context import truncate_tool_result
from pr.ui import Colors, display_tool_call
logger = logging.getLogger("pr")
logger = logging.getLogger('pr')
def run_autonomous_mode(assistant, task): def run_autonomous_mode(assistant, task):
assistant.autonomous_mode = True assistant.autonomous_mode = True
@ -14,25 +16,32 @@ def run_autonomous_mode(assistant, task):
logger.debug(f"=== AUTONOMOUS MODE START ===") logger.debug(f"=== AUTONOMOUS MODE START ===")
logger.debug(f"Task: {task}") logger.debug(f"Task: {task}")
assistant.messages.append({ assistant.messages.append({"role": "user", "content": f"{task}"})
"role": "user",
"content": f"{task}"
})
try: try:
while True: while True:
assistant.autonomous_iterations += 1 assistant.autonomous_iterations += 1
logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---") logger.debug(
logger.debug(f"Messages before context management: {len(assistant.messages)}") f"--- Autonomous iteration {assistant.autonomous_iterations} ---"
)
logger.debug(
f"Messages before context management: {len(assistant.messages)}"
)
from pr.core.context import manage_context_window from pr.core.context import manage_context_window
assistant.messages = manage_context_window(assistant.messages, assistant.verbose)
logger.debug(f"Messages after context management: {len(assistant.messages)}") assistant.messages = manage_context_window(
assistant.messages, assistant.verbose
)
logger.debug(
f"Messages after context management: {len(assistant.messages)}"
)
from pr.core.api import call_api from pr.core.api import call_api
from pr.tools.base import get_tools_definition from pr.tools.base import get_tools_definition
response = call_api( response = call_api(
assistant.messages, assistant.messages,
assistant.model, assistant.model,
@ -40,10 +49,10 @@ def run_autonomous_mode(assistant, task):
assistant.api_key, assistant.api_key,
assistant.use_tools, assistant.use_tools,
get_tools_definition(), get_tools_definition(),
verbose=assistant.verbose verbose=assistant.verbose,
) )
if 'error' in response: if "error" in response:
logger.error(f"API error in autonomous mode: {response['error']}") logger.error(f"API error in autonomous mode: {response['error']}")
print(f"{Colors.RED}Error: {response['error']}{Colors.RESET}") print(f"{Colors.RED}Error: {response['error']}{Colors.RESET}")
break break
@ -74,22 +83,23 @@ def run_autonomous_mode(assistant, task):
assistant.autonomous_mode = False assistant.autonomous_mode = False
logger.debug("=== AUTONOMOUS MODE END ===") logger.debug("=== AUTONOMOUS MODE END ===")
def process_response_autonomous(assistant, response): def process_response_autonomous(assistant, response):
if 'error' in response: if "error" in response:
return f"Error: {response['error']}" return f"Error: {response['error']}"
if 'choices' not in response or not response['choices']: if "choices" not in response or not response["choices"]:
return "No response from API" return "No response from API"
message = response['choices'][0]['message'] message = response["choices"][0]["message"]
assistant.messages.append(message) assistant.messages.append(message)
if 'tool_calls' in message and message['tool_calls']: if "tool_calls" in message and message["tool_calls"]:
tool_results = [] tool_results = []
for tool_call in message['tool_calls']: for tool_call in message["tool_calls"]:
func_name = tool_call['function']['name'] func_name = tool_call["function"]["name"]
arguments = json.loads(tool_call['function']['arguments']) arguments = json.loads(tool_call["function"]["arguments"])
result = execute_single_tool(assistant, func_name, arguments) result = execute_single_tool(assistant, func_name, arguments)
result = truncate_tool_result(result) result = truncate_tool_result(result)
@ -97,16 +107,19 @@ def process_response_autonomous(assistant, response):
status = "success" if result.get("status") == "success" else "error" status = "success" if result.get("status") == "success" else "error"
display_tool_call(func_name, arguments, status, result) display_tool_call(func_name, arguments, status, result)
tool_results.append({ tool_results.append(
"tool_call_id": tool_call['id'], {
"role": "tool", "tool_call_id": tool_call["id"],
"content": json.dumps(result) "role": "tool",
}) "content": json.dumps(result),
}
)
for result in tool_results: for result in tool_results:
assistant.messages.append(result) assistant.messages.append(result)
from pr.core.api import call_api from pr.core.api import call_api
from pr.tools.base import get_tools_definition from pr.tools.base import get_tools_definition
follow_up = call_api( follow_up = call_api(
assistant.messages, assistant.messages,
assistant.model, assistant.model,
@ -114,59 +127,88 @@ def process_response_autonomous(assistant, response):
assistant.api_key, assistant.api_key,
assistant.use_tools, assistant.use_tools,
get_tools_definition(), get_tools_definition(),
verbose=assistant.verbose verbose=assistant.verbose,
) )
return process_response_autonomous(assistant, follow_up) return process_response_autonomous(assistant, follow_up)
content = message.get('content', '') content = message.get("content", "")
from pr.ui import render_markdown from pr.ui import render_markdown
return render_markdown(content, assistant.syntax_highlighting) return render_markdown(content, assistant.syntax_highlighting)
def execute_single_tool(assistant, func_name, arguments): def execute_single_tool(assistant, func_name, arguments):
logger.debug(f"Executing tool in autonomous mode: {func_name}") logger.debug(f"Executing tool in autonomous mode: {func_name}")
logger.debug(f"Tool arguments: {arguments}") logger.debug(f"Tool arguments: {arguments}")
from pr.tools import ( from pr.tools import (
http_fetch, run_command, run_command_interactive, read_file, write_file, apply_patch,
list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query, chdir,
web_search, web_search_news, python_exec, index_source_directory, close_editor,
search_replace, open_editor, editor_insert_text, editor_replace_text, create_diff,
editor_search, close_editor, create_diff, apply_patch, tail_process, kill_process db_get,
db_query,
db_set,
editor_insert_text,
editor_replace_text,
editor_search,
getpwd,
http_fetch,
index_source_directory,
kill_process,
list_directory,
mkdir,
open_editor,
python_exec,
read_file,
run_command,
run_command_interactive,
search_replace,
tail_process,
web_search,
web_search_news,
write_file,
)
from pr.tools.filesystem import (
clear_edit_tracker,
display_edit_summary,
display_edit_timeline,
) )
from pr.tools.patch import display_file_diff from pr.tools.patch import display_file_diff
from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker
func_map = { func_map = {
'http_fetch': lambda **kw: http_fetch(**kw), "http_fetch": lambda **kw: http_fetch(**kw),
'run_command': lambda **kw: run_command(**kw), "run_command": lambda **kw: run_command(**kw),
'tail_process': lambda **kw: tail_process(**kw), "tail_process": lambda **kw: tail_process(**kw),
'kill_process': lambda **kw: kill_process(**kw), "kill_process": lambda **kw: kill_process(**kw),
'run_command_interactive': lambda **kw: run_command_interactive(**kw), "run_command_interactive": lambda **kw: run_command_interactive(**kw),
'read_file': lambda **kw: read_file(**kw), "read_file": lambda **kw: read_file(**kw),
'write_file': lambda **kw: write_file(**kw, db_conn=assistant.db_conn), "write_file": lambda **kw: write_file(**kw, db_conn=assistant.db_conn),
'list_directory': lambda **kw: list_directory(**kw), "list_directory": lambda **kw: list_directory(**kw),
'mkdir': lambda **kw: mkdir(**kw), "mkdir": lambda **kw: mkdir(**kw),
'chdir': lambda **kw: chdir(**kw), "chdir": lambda **kw: chdir(**kw),
'getpwd': lambda **kw: getpwd(**kw), "getpwd": lambda **kw: getpwd(**kw),
'db_set': lambda **kw: db_set(**kw, db_conn=assistant.db_conn), "db_set": lambda **kw: db_set(**kw, db_conn=assistant.db_conn),
'db_get': lambda **kw: db_get(**kw, db_conn=assistant.db_conn), "db_get": lambda **kw: db_get(**kw, db_conn=assistant.db_conn),
'db_query': lambda **kw: db_query(**kw, db_conn=assistant.db_conn), "db_query": lambda **kw: db_query(**kw, db_conn=assistant.db_conn),
'web_search': lambda **kw: web_search(**kw), "web_search": lambda **kw: web_search(**kw),
'web_search_news': lambda **kw: web_search_news(**kw), "web_search_news": lambda **kw: web_search_news(**kw),
'python_exec': lambda **kw: python_exec(**kw, python_globals=assistant.python_globals), "python_exec": lambda **kw: python_exec(
'index_source_directory': lambda **kw: index_source_directory(**kw), **kw, python_globals=assistant.python_globals
'search_replace': lambda **kw: search_replace(**kw), ),
'open_editor': lambda **kw: open_editor(**kw), "index_source_directory": lambda **kw: index_source_directory(**kw),
'editor_insert_text': lambda **kw: editor_insert_text(**kw), "search_replace": lambda **kw: search_replace(**kw),
'editor_replace_text': lambda **kw: editor_replace_text(**kw), "open_editor": lambda **kw: open_editor(**kw),
'editor_search': lambda **kw: editor_search(**kw), "editor_insert_text": lambda **kw: editor_insert_text(**kw),
'close_editor': lambda **kw: close_editor(**kw), "editor_replace_text": lambda **kw: editor_replace_text(**kw),
'create_diff': lambda **kw: create_diff(**kw), "editor_search": lambda **kw: editor_search(**kw),
'apply_patch': lambda **kw: apply_patch(**kw), "close_editor": lambda **kw: close_editor(**kw),
'display_file_diff': lambda **kw: display_file_diff(**kw), "create_diff": lambda **kw: create_diff(**kw),
'display_edit_summary': lambda **kw: display_edit_summary(), "apply_patch": lambda **kw: apply_patch(**kw),
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw), "display_file_diff": lambda **kw: display_file_diff(**kw),
'clear_edit_tracker': lambda **kw: clear_edit_tracker(), "display_edit_summary": lambda **kw: display_edit_summary(),
"display_edit_timeline": lambda **kw: display_edit_timeline(**kw),
"clear_edit_tracker": lambda **kw: clear_edit_tracker(),
} }
if func_name in func_map: if func_name in func_map:

View File

@ -1,4 +1,4 @@
from .api_cache import APICache from .api_cache import APICache
from .tool_cache import ToolCache from .tool_cache import ToolCache
__all__ = ['APICache', 'ToolCache'] __all__ = ["APICache", "ToolCache"]

86
pr/cache/api_cache.py vendored
View File

@ -2,7 +2,8 @@ import hashlib
import json import json
import sqlite3 import sqlite3
import time import time
from typing import Optional, Dict, Any from typing import Any, Dict, Optional
class APICache: class APICache:
def __init__(self, db_path: str, ttl_seconds: int = 3600): def __init__(self, db_path: str, ttl_seconds: int = 3600):
@ -13,7 +14,8 @@ class APICache:
def _initialize_cache(self): def _initialize_cache(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS api_cache ( CREATE TABLE IF NOT EXISTS api_cache (
cache_key TEXT PRIMARY KEY, cache_key TEXT PRIMARY KEY,
response_data TEXT NOT NULL, response_data TEXT NOT NULL,
@ -22,34 +24,44 @@ class APICache:
model TEXT, model TEXT,
token_count INTEGER token_count INTEGER
) )
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at) CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at)
''') """
)
conn.commit() conn.commit()
conn.close() conn.close()
def _generate_cache_key(self, model: str, messages: list, temperature: float, max_tokens: int) -> str: def _generate_cache_key(
self, model: str, messages: list, temperature: float, max_tokens: int
) -> str:
cache_data = { cache_data = {
'model': model, "model": model,
'messages': messages, "messages": messages,
'temperature': temperature, "temperature": temperature,
'max_tokens': max_tokens "max_tokens": max_tokens,
} }
serialized = json.dumps(cache_data, sort_keys=True) serialized = json.dumps(cache_data, sort_keys=True)
return hashlib.sha256(serialized.encode()).hexdigest() return hashlib.sha256(serialized.encode()).hexdigest()
def get(self, model: str, messages: list, temperature: float, max_tokens: int) -> Optional[Dict[str, Any]]: def get(
self, model: str, messages: list, temperature: float, max_tokens: int
) -> Optional[Dict[str, Any]]:
cache_key = self._generate_cache_key(model, messages, temperature, max_tokens) cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
current_time = int(time.time()) current_time = int(time.time())
cursor.execute(''' cursor.execute(
"""
SELECT response_data FROM api_cache SELECT response_data FROM api_cache
WHERE cache_key = ? AND expires_at > ? WHERE cache_key = ? AND expires_at > ?
''', (cache_key, current_time)) """,
(cache_key, current_time),
)
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@ -58,8 +70,15 @@ class APICache:
return json.loads(row[0]) return json.loads(row[0])
return None return None
def set(self, model: str, messages: list, temperature: float, max_tokens: int, def set(
response: Dict[str, Any], token_count: int = 0): self,
model: str,
messages: list,
temperature: float,
max_tokens: int,
response: Dict[str, Any],
token_count: int = 0,
):
cache_key = self._generate_cache_key(model, messages, temperature, max_tokens) cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
current_time = int(time.time()) current_time = int(time.time())
@ -68,11 +87,21 @@ class APICache:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
INSERT OR REPLACE INTO api_cache INSERT OR REPLACE INTO api_cache
(cache_key, response_data, created_at, expires_at, model, token_count) (cache_key, response_data, created_at, expires_at, model, token_count)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
''', (cache_key, json.dumps(response), current_time, expires_at, model, token_count)) """,
(
cache_key,
json.dumps(response),
current_time,
expires_at,
model,
token_count,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -83,7 +112,7 @@ class APICache:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('DELETE FROM api_cache WHERE expires_at <= ?', (current_time,)) cursor.execute("DELETE FROM api_cache WHERE expires_at <= ?", (current_time,))
deleted_count = cursor.rowcount deleted_count = cursor.rowcount
conn.commit() conn.commit()
@ -95,7 +124,7 @@ class APICache:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('DELETE FROM api_cache') cursor.execute("DELETE FROM api_cache")
deleted_count = cursor.rowcount deleted_count = cursor.rowcount
conn.commit() conn.commit()
@ -107,21 +136,26 @@ class APICache:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM api_cache') cursor.execute("SELECT COUNT(*) FROM api_cache")
total_entries = cursor.fetchone()[0] total_entries = cursor.fetchone()[0]
current_time = int(time.time()) current_time = int(time.time())
cursor.execute('SELECT COUNT(*) FROM api_cache WHERE expires_at > ?', (current_time,)) cursor.execute(
"SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,)
)
valid_entries = cursor.fetchone()[0] valid_entries = cursor.fetchone()[0]
cursor.execute('SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?', (current_time,)) cursor.execute(
"SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?",
(current_time,),
)
total_tokens = cursor.fetchone()[0] or 0 total_tokens = cursor.fetchone()[0] or 0
conn.close() conn.close()
return { return {
'total_entries': total_entries, "total_entries": total_entries,
'valid_entries': valid_entries, "valid_entries": valid_entries,
'expired_entries': total_entries - valid_entries, "expired_entries": total_entries - valid_entries,
'total_cached_tokens': total_tokens "total_cached_tokens": total_tokens,
} }

View File

@ -2,16 +2,17 @@ import hashlib
import json import json
import sqlite3 import sqlite3
import time import time
from typing import Optional, Any, Set from typing import Any, Optional, Set
class ToolCache: class ToolCache:
DETERMINISTIC_TOOLS: Set[str] = { DETERMINISTIC_TOOLS: Set[str] = {
'read_file', "read_file",
'list_directory', "list_directory",
'get_current_directory', "get_current_directory",
'db_get', "db_get",
'db_query', "db_query",
'index_directory' "index_directory",
} }
def __init__(self, db_path: str, ttl_seconds: int = 300): def __init__(self, db_path: str, ttl_seconds: int = 300):
@ -22,7 +23,8 @@ class ToolCache:
def _initialize_cache(self): def _initialize_cache(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS tool_cache ( CREATE TABLE IF NOT EXISTS tool_cache (
cache_key TEXT PRIMARY KEY, cache_key TEXT PRIMARY KEY,
tool_name TEXT NOT NULL, tool_name TEXT NOT NULL,
@ -31,21 +33,23 @@ class ToolCache:
expires_at INTEGER NOT NULL, expires_at INTEGER NOT NULL,
hit_count INTEGER DEFAULT 0 hit_count INTEGER DEFAULT 0
) )
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_tool_expires ON tool_cache(expires_at) CREATE INDEX IF NOT EXISTS idx_tool_expires ON tool_cache(expires_at)
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_cache(tool_name) CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_cache(tool_name)
''') """
)
conn.commit() conn.commit()
conn.close() conn.close()
def _generate_cache_key(self, tool_name: str, arguments: dict) -> str: def _generate_cache_key(self, tool_name: str, arguments: dict) -> str:
cache_data = { cache_data = {"tool": tool_name, "args": arguments}
'tool': tool_name,
'args': arguments
}
serialized = json.dumps(cache_data, sort_keys=True) serialized = json.dumps(cache_data, sort_keys=True)
return hashlib.sha256(serialized.encode()).hexdigest() return hashlib.sha256(serialized.encode()).hexdigest()
@ -62,18 +66,24 @@ class ToolCache:
cursor = conn.cursor() cursor = conn.cursor()
current_time = int(time.time()) current_time = int(time.time())
cursor.execute(''' cursor.execute(
"""
SELECT result_data, hit_count FROM tool_cache SELECT result_data, hit_count FROM tool_cache
WHERE cache_key = ? AND expires_at > ? WHERE cache_key = ? AND expires_at > ?
''', (cache_key, current_time)) """,
(cache_key, current_time),
)
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
cursor.execute(''' cursor.execute(
"""
UPDATE tool_cache SET hit_count = hit_count + 1 UPDATE tool_cache SET hit_count = hit_count + 1
WHERE cache_key = ? WHERE cache_key = ?
''', (cache_key,)) """,
(cache_key,),
)
conn.commit() conn.commit()
conn.close() conn.close()
return json.loads(row[0]) return json.loads(row[0])
@ -93,11 +103,14 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
INSERT OR REPLACE INTO tool_cache INSERT OR REPLACE INTO tool_cache
(cache_key, tool_name, result_data, created_at, expires_at, hit_count) (cache_key, tool_name, result_data, created_at, expires_at, hit_count)
VALUES (?, ?, ?, ?, ?, 0) VALUES (?, ?, ?, ?, ?, 0)
''', (cache_key, tool_name, json.dumps(result), current_time, expires_at)) """,
(cache_key, tool_name, json.dumps(result), current_time, expires_at),
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -106,7 +119,7 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('DELETE FROM tool_cache WHERE tool_name = ?', (tool_name,)) cursor.execute("DELETE FROM tool_cache WHERE tool_name = ?", (tool_name,))
deleted_count = cursor.rowcount deleted_count = cursor.rowcount
conn.commit() conn.commit()
@ -120,7 +133,7 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('DELETE FROM tool_cache WHERE expires_at <= ?', (current_time,)) cursor.execute("DELETE FROM tool_cache WHERE expires_at <= ?", (current_time,))
deleted_count = cursor.rowcount deleted_count = cursor.rowcount
conn.commit() conn.commit()
@ -132,7 +145,7 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('DELETE FROM tool_cache') cursor.execute("DELETE FROM tool_cache")
deleted_count = cursor.rowcount deleted_count = cursor.rowcount
conn.commit() conn.commit()
@ -144,36 +157,41 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM tool_cache') cursor.execute("SELECT COUNT(*) FROM tool_cache")
total_entries = cursor.fetchone()[0] total_entries = cursor.fetchone()[0]
current_time = int(time.time()) current_time = int(time.time())
cursor.execute('SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?', (current_time,)) cursor.execute(
"SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,)
)
valid_entries = cursor.fetchone()[0] valid_entries = cursor.fetchone()[0]
cursor.execute('SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?', (current_time,)) cursor.execute(
"SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?",
(current_time,),
)
total_hits = cursor.fetchone()[0] or 0 total_hits = cursor.fetchone()[0] or 0
cursor.execute(''' cursor.execute(
"""
SELECT tool_name, COUNT(*), SUM(hit_count) SELECT tool_name, COUNT(*), SUM(hit_count)
FROM tool_cache FROM tool_cache
WHERE expires_at > ? WHERE expires_at > ?
GROUP BY tool_name GROUP BY tool_name
''', (current_time,)) """,
(current_time,),
)
tool_stats = {} tool_stats = {}
for row in cursor.fetchall(): for row in cursor.fetchall():
tool_stats[row[0]] = { tool_stats[row[0]] = {"cached_entries": row[1], "total_hits": row[2] or 0}
'cached_entries': row[1],
'total_hits': row[2] or 0
}
conn.close() conn.close()
return { return {
'total_entries': total_entries, "total_entries": total_entries,
'valid_entries': valid_entries, "valid_entries": valid_entries,
'expired_entries': total_entries - valid_entries, "expired_entries": total_entries - valid_entries,
'total_cache_hits': total_hits, "total_cache_hits": total_hits,
'by_tool': tool_stats "by_tool": tool_stats,
} }

View File

@ -1,3 +1,3 @@
from pr.commands.handlers import handle_command from pr.commands.handlers import handle_command
__all__ = ['handle_command'] __all__ = ["handle_command"]

View File

@ -1,30 +1,35 @@
import json import json
import time import time
from pr.ui import Colors
from pr.autonomous import run_autonomous_mode
from pr.core.api import list_models
from pr.tools import read_file from pr.tools import read_file
from pr.tools.base import get_tools_definition from pr.tools.base import get_tools_definition
from pr.core.api import list_models from pr.ui import Colors
from pr.autonomous import run_autonomous_mode
def handle_command(assistant, command): def handle_command(assistant, command):
command_parts = command.strip().split(maxsplit=1) command_parts = command.strip().split(maxsplit=1)
cmd = command_parts[0].lower() cmd = command_parts[0].lower()
if cmd == '/auto': if cmd == "/auto":
if len(command_parts) < 2: if len(command_parts) < 2:
print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}") print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}")
print(f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}") print(
f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}"
)
return True return True
task = command_parts[1] task = command_parts[1]
run_autonomous_mode(assistant, task) run_autonomous_mode(assistant, task)
return True return True
if cmd in ['exit', 'quit', 'q']: if cmd in ["exit", "quit", "q"]:
return False return False
elif cmd == 'help': elif cmd == "help":
print(f""" print(
f"""
{Colors.BOLD}Available Commands:{Colors.RESET} {Colors.BOLD}Available Commands:{Colors.RESET}
{Colors.BOLD}Basic:{Colors.RESET} {Colors.BOLD}Basic:{Colors.RESET}
@ -54,18 +59,21 @@ def handle_command(assistant, command):
{Colors.CYAN}/cache{Colors.RESET} - Show cache statistics {Colors.CYAN}/cache{Colors.RESET} - Show cache statistics
{Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches {Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches
{Colors.CYAN}/stats{Colors.RESET} - Show system statistics {Colors.CYAN}/stats{Colors.RESET} - Show system statistics
""") """
)
elif cmd == '/reset': elif cmd == "/reset":
assistant.messages = assistant.messages[:1] assistant.messages = assistant.messages[:1]
print(f"{Colors.GREEN}Message history cleared{Colors.RESET}") print(f"{Colors.GREEN}Message history cleared{Colors.RESET}")
elif cmd == '/dump': elif cmd == "/dump":
print(json.dumps(assistant.messages, indent=2)) print(json.dumps(assistant.messages, indent=2))
elif cmd == '/verbose': elif cmd == "/verbose":
assistant.verbose = not assistant.verbose assistant.verbose = not assistant.verbose
print(f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}") print(
f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}"
)
elif cmd.startswith("/model"): elif cmd.startswith("/model"):
if len(command_parts) < 2: if len(command_parts) < 2:
@ -74,77 +82,81 @@ def handle_command(assistant, command):
assistant.model = command_parts[1] assistant.model = command_parts[1]
print(f"Model set to: {Colors.GREEN}{assistant.model}{Colors.RESET}") print(f"Model set to: {Colors.GREEN}{assistant.model}{Colors.RESET}")
elif cmd == '/models': elif cmd == "/models":
models = list_models(assistant.model_list_url, assistant.api_key) models = list_models(assistant.model_list_url, assistant.api_key)
if isinstance(models, dict) and 'error' in models: if isinstance(models, dict) and "error" in models:
print(f"{Colors.RED}Error fetching models: {models['error']}{Colors.RESET}") print(f"{Colors.RED}Error fetching models: {models['error']}{Colors.RESET}")
else: else:
print(f"{Colors.BOLD}Available Models:{Colors.RESET}") print(f"{Colors.BOLD}Available Models:{Colors.RESET}")
for model in models: for model in models:
print(f"{Colors.CYAN}{model['id']}{Colors.RESET}") print(f"{Colors.CYAN}{model['id']}{Colors.RESET}")
elif cmd == '/tools': elif cmd == "/tools":
print(f"{Colors.BOLD}Available Tools:{Colors.RESET}") print(f"{Colors.BOLD}Available Tools:{Colors.RESET}")
for tool in get_tools_definition(): for tool in get_tools_definition():
func = tool['function'] func = tool["function"]
print(f"{Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}") print(
f"{Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}"
)
elif cmd == '/review' and len(command_parts) > 1: elif cmd == "/review" and len(command_parts) > 1:
filename = command_parts[1] filename = command_parts[1]
review_file(assistant, filename) review_file(assistant, filename)
elif cmd == '/refactor' and len(command_parts) > 1: elif cmd == "/refactor" and len(command_parts) > 1:
filename = command_parts[1] filename = command_parts[1]
refactor_file(assistant, filename) refactor_file(assistant, filename)
elif cmd == '/obfuscate' and len(command_parts) > 1: elif cmd == "/obfuscate" and len(command_parts) > 1:
filename = command_parts[1] filename = command_parts[1]
obfuscate_file(assistant, filename) obfuscate_file(assistant, filename)
elif cmd == '/workflows': elif cmd == "/workflows":
show_workflows(assistant) show_workflows(assistant)
elif cmd == '/workflow' and len(command_parts) > 1: elif cmd == "/workflow" and len(command_parts) > 1:
workflow_name = command_parts[1] workflow_name = command_parts[1]
execute_workflow_command(assistant, workflow_name) execute_workflow_command(assistant, workflow_name)
elif cmd == '/agent' and len(command_parts) > 1: elif cmd == "/agent" and len(command_parts) > 1:
args = command_parts[1].split(maxsplit=1) args = command_parts[1].split(maxsplit=1)
if len(args) < 2: if len(args) < 2:
print(f"{Colors.RED}Usage: /agent <role> <task>{Colors.RESET}") print(f"{Colors.RED}Usage: /agent <role> <task>{Colors.RESET}")
print(f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}") print(
f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}"
)
else: else:
role, task = args[0], args[1] role, task = args[0], args[1]
execute_agent_task(assistant, role, task) execute_agent_task(assistant, role, task)
elif cmd == '/agents': elif cmd == "/agents":
show_agents(assistant) show_agents(assistant)
elif cmd == '/collaborate' and len(command_parts) > 1: elif cmd == "/collaborate" and len(command_parts) > 1:
task = command_parts[1] task = command_parts[1]
collaborate_agents_command(assistant, task) collaborate_agents_command(assistant, task)
elif cmd == '/knowledge' and len(command_parts) > 1: elif cmd == "/knowledge" and len(command_parts) > 1:
query = command_parts[1] query = command_parts[1]
search_knowledge(assistant, query) search_knowledge(assistant, query)
elif cmd == '/remember' and len(command_parts) > 1: elif cmd == "/remember" and len(command_parts) > 1:
content = command_parts[1] content = command_parts[1]
store_knowledge(assistant, content) store_knowledge(assistant, content)
elif cmd == '/history': elif cmd == "/history":
show_conversation_history(assistant) show_conversation_history(assistant)
elif cmd == '/cache': elif cmd == "/cache":
if len(command_parts) > 1 and command_parts[1].lower() == 'clear': if len(command_parts) > 1 and command_parts[1].lower() == "clear":
clear_caches(assistant) clear_caches(assistant)
else: else:
show_cache_stats(assistant) show_cache_stats(assistant)
elif cmd == '/stats': elif cmd == "/stats":
show_system_stats(assistant) show_system_stats(assistant)
elif cmd.startswith('/bg'): elif cmd.startswith("/bg"):
handle_background_command(assistant, command) handle_background_command(assistant, command)
else: else:
@ -152,35 +164,46 @@ def handle_command(assistant, command):
return True return True
def review_file(assistant, filename): def review_file(assistant, filename):
result = read_file(filename) result = read_file(filename)
if result['status'] == 'success': if result["status"] == "success":
message = f"Please review this file and provide feedback:\n\n{result['content']}" message = (
f"Please review this file and provide feedback:\n\n{result['content']}"
)
from pr.core.assistant import process_message from pr.core.assistant import process_message
process_message(assistant, message) process_message(assistant, message)
else: else:
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}") print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
def refactor_file(assistant, filename): def refactor_file(assistant, filename):
result = read_file(filename) result = read_file(filename)
if result['status'] == 'success': if result["status"] == "success":
message = f"Please refactor this code to improve its quality:\n\n{result['content']}" message = (
f"Please refactor this code to improve its quality:\n\n{result['content']}"
)
from pr.core.assistant import process_message from pr.core.assistant import process_message
process_message(assistant, message) process_message(assistant, message)
else: else:
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}") print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
def obfuscate_file(assistant, filename): def obfuscate_file(assistant, filename):
result = read_file(filename) result = read_file(filename)
if result['status'] == 'success': if result["status"] == "success":
message = f"Please obfuscate this code:\n\n{result['content']}" message = f"Please obfuscate this code:\n\n{result['content']}"
from pr.core.assistant import process_message from pr.core.assistant import process_message
process_message(assistant, message) process_message(assistant, message)
else: else:
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}") print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
def show_workflows(assistant): def show_workflows(assistant):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
@ -194,23 +217,25 @@ def show_workflows(assistant):
print(f"{Colors.CYAN}{wf['name']}{Colors.RESET}: {wf['description']}") print(f"{Colors.CYAN}{wf['name']}{Colors.RESET}: {wf['description']}")
print(f" Executions: {wf['execution_count']}") print(f" Executions: {wf['execution_count']}")
def execute_workflow_command(assistant, workflow_name): def execute_workflow_command(assistant, workflow_name):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
print(f"{Colors.YELLOW}Executing workflow: {workflow_name}...{Colors.RESET}") print(f"{Colors.YELLOW}Executing workflow: {workflow_name}...{Colors.RESET}")
result = assistant.enhanced.execute_workflow(workflow_name) result = assistant.enhanced.execute_workflow(workflow_name)
if 'error' in result: if "error" in result:
print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}") print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
else: else:
print(f"{Colors.GREEN}Workflow completed successfully{Colors.RESET}") print(f"{Colors.GREEN}Workflow completed successfully{Colors.RESET}")
print(f"Execution ID: {result['execution_id']}") print(f"Execution ID: {result['execution_id']}")
print(f"Results: {json.dumps(result['results'], indent=2)}") print(f"Results: {json.dumps(result['results'], indent=2)}")
def execute_agent_task(assistant, role, task): def execute_agent_task(assistant, role, task):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
@ -221,14 +246,15 @@ def execute_agent_task(assistant, role, task):
print(f"{Colors.YELLOW}Executing task...{Colors.RESET}") print(f"{Colors.YELLOW}Executing task...{Colors.RESET}")
result = assistant.enhanced.agent_task(agent_id, task) result = assistant.enhanced.agent_task(agent_id, task)
if 'error' in result: if "error" in result:
print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}") print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
else: else:
print(f"\n{Colors.GREEN}{role.capitalize()} Agent Response:{Colors.RESET}") print(f"\n{Colors.GREEN}{role.capitalize()} Agent Response:{Colors.RESET}")
print(result['response']) print(result["response"])
def show_agents(assistant): def show_agents(assistant):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
@ -236,37 +262,39 @@ def show_agents(assistant):
print(f"\n{Colors.BOLD}Agent Session Summary:{Colors.RESET}") print(f"\n{Colors.BOLD}Agent Session Summary:{Colors.RESET}")
print(f"Active agents: {summary['active_agents']}") print(f"Active agents: {summary['active_agents']}")
if summary['agents']: if summary["agents"]:
for agent in summary['agents']: for agent in summary["agents"]:
print(f"\n{Colors.CYAN}{agent['agent_id']}{Colors.RESET}") print(f"\n{Colors.CYAN}{agent['agent_id']}{Colors.RESET}")
print(f" Role: {agent['role']}") print(f" Role: {agent['role']}")
print(f" Tasks completed: {agent['task_count']}") print(f" Tasks completed: {agent['task_count']}")
print(f" Messages: {agent['message_count']}") print(f" Messages: {agent['message_count']}")
def collaborate_agents_command(assistant, task): def collaborate_agents_command(assistant, task):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
print(f"{Colors.YELLOW}Initiating agent collaboration...{Colors.RESET}") print(f"{Colors.YELLOW}Initiating agent collaboration...{Colors.RESET}")
roles = ['coding', 'research', 'planning'] roles = ["coding", "research", "planning"]
result = assistant.enhanced.collaborate_agents(task, roles) result = assistant.enhanced.collaborate_agents(task, roles)
print(f"\n{Colors.GREEN}Collaboration completed{Colors.RESET}") print(f"\n{Colors.GREEN}Collaboration completed{Colors.RESET}")
print(f"\nOrchestrator response:") print(f"\nOrchestrator response:")
if 'orchestrator' in result and 'response' in result['orchestrator']: if "orchestrator" in result and "response" in result["orchestrator"]:
print(result['orchestrator']['response']) print(result["orchestrator"]["response"])
if result.get('agents'): if result.get("agents"):
print(f"\n{Colors.BOLD}Agent Results:{Colors.RESET}") print(f"\n{Colors.BOLD}Agent Results:{Colors.RESET}")
for agent_result in result['agents']: for agent_result in result["agents"]:
if 'role' in agent_result: if "role" in agent_result:
print(f"\n{Colors.CYAN}{agent_result['role']}:{Colors.RESET}") print(f"\n{Colors.CYAN}{agent_result['role']}:{Colors.RESET}")
print(agent_result.get('response', 'No response')) print(agent_result.get("response", "No response"))
def search_knowledge(assistant, query): def search_knowledge(assistant, query):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
@ -282,13 +310,15 @@ def search_knowledge(assistant, query):
print(f" {entry.content[:200]}...") print(f" {entry.content[:200]}...")
print(f" Accessed: {entry.access_count} times") print(f" Accessed: {entry.access_count} times")
def store_knowledge(assistant, content): def store_knowledge(assistant, content):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
import uuid
import time import time
import uuid
from pr.memory import KnowledgeEntry from pr.memory import KnowledgeEntry
categories = assistant.enhanced.fact_extractor.categorize_content(content) categories = assistant.enhanced.fact_extractor.categorize_content(content)
@ -296,11 +326,11 @@ def store_knowledge(assistant, content):
entry = KnowledgeEntry( entry = KnowledgeEntry(
entry_id=entry_id, entry_id=entry_id,
category=categories[0] if categories else 'general', category=categories[0] if categories else "general",
content=content, content=content,
metadata={'manual_entry': True}, metadata={"manual_entry": True},
created_at=time.time(), created_at=time.time(),
updated_at=time.time() updated_at=time.time(),
) )
assistant.enhanced.knowledge_store.add_entry(entry) assistant.enhanced.knowledge_store.add_entry(entry)
@ -308,8 +338,9 @@ def store_knowledge(assistant, content):
print(f"Entry ID: {entry_id}") print(f"Entry ID: {entry_id}")
print(f"Category: {entry.category}") print(f"Category: {entry.category}")
def show_conversation_history(assistant): def show_conversation_history(assistant):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
@ -322,17 +353,21 @@ def show_conversation_history(assistant):
print(f"\n{Colors.BOLD}Recent Conversations:{Colors.RESET}") print(f"\n{Colors.BOLD}Recent Conversations:{Colors.RESET}")
for conv in history: for conv in history:
import datetime import datetime
started = datetime.datetime.fromtimestamp(conv['started_at']).strftime('%Y-%m-%d %H:%M')
started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime(
"%Y-%m-%d %H:%M"
)
print(f"\n{Colors.CYAN}{conv['conversation_id']}{Colors.RESET}") print(f"\n{Colors.CYAN}{conv['conversation_id']}{Colors.RESET}")
print(f" Started: {started}") print(f" Started: {started}")
print(f" Messages: {conv['message_count']}") print(f" Messages: {conv['message_count']}")
if conv.get('summary'): if conv.get("summary"):
print(f" Summary: {conv['summary'][:100]}...") print(f" Summary: {conv['summary'][:100]}...")
if conv.get('topics'): if conv.get("topics"):
print(f" Topics: {', '.join(conv['topics'])}") print(f" Topics: {', '.join(conv['topics'])}")
def show_cache_stats(assistant): def show_cache_stats(assistant):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
@ -340,36 +375,40 @@ def show_cache_stats(assistant):
print(f"\n{Colors.BOLD}Cache Statistics:{Colors.RESET}") print(f"\n{Colors.BOLD}Cache Statistics:{Colors.RESET}")
if 'api_cache' in stats: if "api_cache" in stats:
api_stats = stats['api_cache'] api_stats = stats["api_cache"]
print(f"\n{Colors.CYAN}API Cache:{Colors.RESET}") print(f"\n{Colors.CYAN}API Cache:{Colors.RESET}")
print(f" Total entries: {api_stats['total_entries']}") print(f" Total entries: {api_stats['total_entries']}")
print(f" Valid entries: {api_stats['valid_entries']}") print(f" Valid entries: {api_stats['valid_entries']}")
print(f" Expired entries: {api_stats['expired_entries']}") print(f" Expired entries: {api_stats['expired_entries']}")
print(f" Cached tokens: {api_stats['total_cached_tokens']}") print(f" Cached tokens: {api_stats['total_cached_tokens']}")
if 'tool_cache' in stats: if "tool_cache" in stats:
tool_stats = stats['tool_cache'] tool_stats = stats["tool_cache"]
print(f"\n{Colors.CYAN}Tool Cache:{Colors.RESET}") print(f"\n{Colors.CYAN}Tool Cache:{Colors.RESET}")
print(f" Total entries: {tool_stats['total_entries']}") print(f" Total entries: {tool_stats['total_entries']}")
print(f" Valid entries: {tool_stats['valid_entries']}") print(f" Valid entries: {tool_stats['valid_entries']}")
print(f" Total cache hits: {tool_stats['total_cache_hits']}") print(f" Total cache hits: {tool_stats['total_cache_hits']}")
if tool_stats.get('by_tool'): if tool_stats.get("by_tool"):
print(f"\n Per-tool statistics:") print(f"\n Per-tool statistics:")
for tool_name, tool_stat in tool_stats['by_tool'].items(): for tool_name, tool_stat in tool_stats["by_tool"].items():
print(f" {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits") print(
f" {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits"
)
def clear_caches(assistant): def clear_caches(assistant):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
assistant.enhanced.clear_caches() assistant.enhanced.clear_caches()
print(f"{Colors.GREEN}All caches cleared successfully{Colors.RESET}") print(f"{Colors.GREEN}All caches cleared successfully{Colors.RESET}")
def show_system_stats(assistant): def show_system_stats(assistant):
if not hasattr(assistant, 'enhanced'): if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}") print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return return
@ -388,68 +427,81 @@ def show_system_stats(assistant):
print(f"\n{Colors.CYAN}Active Agents:{Colors.RESET}") print(f"\n{Colors.CYAN}Active Agents:{Colors.RESET}")
print(f" Count: {agent_summary['active_agents']}") print(f" Count: {agent_summary['active_agents']}")
if 'api_cache' in cache_stats: if "api_cache" in cache_stats:
print(f"\n{Colors.CYAN}Caching:{Colors.RESET}") print(f"\n{Colors.CYAN}Caching:{Colors.RESET}")
print(f" API cache entries: {cache_stats['api_cache']['valid_entries']}") print(f" API cache entries: {cache_stats['api_cache']['valid_entries']}")
if 'tool_cache' in cache_stats: if "tool_cache" in cache_stats:
print(f" Tool cache entries: {cache_stats['tool_cache']['valid_entries']}") print(f" Tool cache entries: {cache_stats['tool_cache']['valid_entries']}")
def handle_background_command(assistant, command): def handle_background_command(assistant, command):
"""Handle background multiplexer commands.""" """Handle background multiplexer commands."""
parts = command.strip().split(maxsplit=2) parts = command.strip().split(maxsplit=2)
if len(parts) < 2: if len(parts) < 2:
print(f"{Colors.RED}Usage: /bg <subcommand> [args]{Colors.RESET}") print(f"{Colors.RED}Usage: /bg <subcommand> [args]{Colors.RESET}")
print(f"{Colors.GRAY}Available subcommands: start, list, status, output, input, kill, events{Colors.RESET}") print(
f"{Colors.GRAY}Available subcommands: start, list, status, output, input, kill, events{Colors.RESET}"
)
return return
subcmd = parts[1].lower() subcmd = parts[1].lower()
try: try:
if subcmd == 'start' and len(parts) >= 3: if subcmd == "start" and len(parts) >= 3:
session_name = f"bg_{len(parts[2].split())}_{int(time.time())}" session_name = f"bg_{len(parts[2].split())}_{int(time.time())}"
start_background_session(assistant, session_name, parts[2]) start_background_session(assistant, session_name, parts[2])
elif subcmd == 'list': elif subcmd == "list":
list_background_sessions(assistant) list_background_sessions(assistant)
elif subcmd == 'status' and len(parts) >= 3: elif subcmd == "status" and len(parts) >= 3:
show_session_status(assistant, parts[2]) show_session_status(assistant, parts[2])
elif subcmd == 'output' and len(parts) >= 3: elif subcmd == "output" and len(parts) >= 3:
show_session_output(assistant, parts[2]) show_session_output(assistant, parts[2])
elif subcmd == 'input' and len(parts) >= 4: elif subcmd == "input" and len(parts) >= 4:
send_session_input(assistant, parts[2], parts[3]) send_session_input(assistant, parts[2], parts[3])
elif subcmd == 'kill' and len(parts) >= 3: elif subcmd == "kill" and len(parts) >= 3:
kill_background_session(assistant, parts[2]) kill_background_session(assistant, parts[2])
elif subcmd == 'events': elif subcmd == "events":
show_background_events(assistant) show_background_events(assistant)
else: else:
print(f"{Colors.RED}Unknown background command: {subcmd}{Colors.RESET}") print(f"{Colors.RED}Unknown background command: {subcmd}{Colors.RESET}")
print(f"{Colors.GRAY}Available: start, list, status, output, input, kill, events{Colors.RESET}") print(
f"{Colors.GRAY}Available: start, list, status, output, input, kill, events{Colors.RESET}"
)
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error executing background command: {e}{Colors.RESET}") print(f"{Colors.RED}Error executing background command: {e}{Colors.RESET}")
def start_background_session(assistant, session_name, command): def start_background_session(assistant, session_name, command):
"""Start a command in background.""" """Start a command in background."""
try: try:
from pr.multiplexer import start_background_process from pr.multiplexer import start_background_process
result = start_background_process(session_name, command) result = start_background_process(session_name, command)
if result['status'] == 'success': if result["status"] == "success":
print(f"{Colors.GREEN}Started background session '{session_name}' with PID {result['pid']}{Colors.RESET}") print(
f"{Colors.GREEN}Started background session '{session_name}' with PID {result['pid']}{Colors.RESET}"
)
else: else:
print(f"{Colors.RED}Failed to start background session: {result.get('error', 'Unknown error')}{Colors.RESET}") print(
f"{Colors.RED}Failed to start background session: {result.get('error', 'Unknown error')}{Colors.RESET}"
)
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error starting background session: {e}{Colors.RESET}") print(f"{Colors.RED}Error starting background session: {e}{Colors.RESET}")
def list_background_sessions(assistant): def list_background_sessions(assistant):
"""List all background sessions.""" """List all background sessions."""
try: try:
from pr.ui.display import display_multiplexer_status
from pr.multiplexer import get_all_sessions from pr.multiplexer import get_all_sessions
from pr.ui.display import display_multiplexer_status
sessions = get_all_sessions() sessions = get_all_sessions()
display_multiplexer_status(sessions) display_multiplexer_status(sessions)
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error listing background sessions: {e}{Colors.RESET}") print(f"{Colors.RED}Error listing background sessions: {e}{Colors.RESET}")
def show_session_status(assistant, session_name): def show_session_status(assistant, session_name):
"""Show status of a specific session.""" """Show status of a specific session."""
try: try:
@ -461,15 +513,17 @@ def show_session_status(assistant, session_name):
print(f" Status: {info.get('status', 'unknown')}") print(f" Status: {info.get('status', 'unknown')}")
print(f" PID: {info.get('pid', 'N/A')}") print(f" PID: {info.get('pid', 'N/A')}")
print(f" Command: {info.get('command', 'N/A')}") print(f" Command: {info.get('command', 'N/A')}")
if 'start_time' in info: if "start_time" in info:
import time import time
elapsed = time.time() - info['start_time']
elapsed = time.time() - info["start_time"]
print(f" Running for: {elapsed:.1f}s") print(f" Running for: {elapsed:.1f}s")
else: else:
print(f"{Colors.YELLOW}Session '{session_name}' not found{Colors.RESET}") print(f"{Colors.YELLOW}Session '{session_name}' not found{Colors.RESET}")
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error getting session status: {e}{Colors.RESET}") print(f"{Colors.RED}Error getting session status: {e}{Colors.RESET}")
def show_session_output(assistant, session_name): def show_session_output(assistant, session_name):
"""Show output of a specific session.""" """Show output of a specific session."""
try: try:
@ -482,36 +536,45 @@ def show_session_output(assistant, session_name):
for line in output: for line in output:
print(line) print(line)
else: else:
print(f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}") print(
f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}"
)
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error getting session output: {e}{Colors.RESET}") print(f"{Colors.RED}Error getting session output: {e}{Colors.RESET}")
def send_session_input(assistant, session_name, input_text): def send_session_input(assistant, session_name, input_text):
"""Send input to a background session.""" """Send input to a background session."""
try: try:
from pr.multiplexer import send_input_to_session from pr.multiplexer import send_input_to_session
result = send_input_to_session(session_name, input_text) result = send_input_to_session(session_name, input_text)
if result['status'] == 'success': if result["status"] == "success":
print(f"{Colors.GREEN}Input sent to session '{session_name}'{Colors.RESET}") print(f"{Colors.GREEN}Input sent to session '{session_name}'{Colors.RESET}")
else: else:
print(f"{Colors.RED}Failed to send input: {result.get('error', 'Unknown error')}{Colors.RESET}") print(
f"{Colors.RED}Failed to send input: {result.get('error', 'Unknown error')}{Colors.RESET}"
)
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error sending input: {e}{Colors.RESET}") print(f"{Colors.RED}Error sending input: {e}{Colors.RESET}")
def kill_background_session(assistant, session_name): def kill_background_session(assistant, session_name):
"""Kill a background session.""" """Kill a background session."""
try: try:
from pr.multiplexer import kill_session from pr.multiplexer import kill_session
result = kill_session(session_name) result = kill_session(session_name)
if result['status'] == 'success': if result["status"] == "success":
print(f"{Colors.GREEN}Session '{session_name}' terminated{Colors.RESET}") print(f"{Colors.GREEN}Session '{session_name}' terminated{Colors.RESET}")
else: else:
print(f"{Colors.RED}Failed to kill session: {result.get('error', 'Unknown error')}{Colors.RESET}") print(
f"{Colors.RED}Failed to kill session: {result.get('error', 'Unknown error')}{Colors.RESET}"
)
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error killing session: {e}{Colors.RESET}") print(f"{Colors.RED}Error killing session: {e}{Colors.RESET}")
def show_background_events(assistant): def show_background_events(assistant):
"""Show recent background events.""" """Show recent background events."""
try: try:
@ -526,6 +589,7 @@ def show_background_events(assistant):
for event in events[-10:]: # Show last 10 events for event in events[-10:]: # Show last 10 events
from pr.ui.display import display_background_event from pr.ui.display import display_background_event
display_background_event(event) display_background_event(event)
else: else:
print(f"{Colors.GRAY}No recent background events{Colors.RESET}") print(f"{Colors.GRAY}No recent background events{Colors.RESET}")

View File

@ -1,11 +1,15 @@
from pr.tools.interactive_control import (
list_active_sessions, get_session_status, read_session_output,
send_input_to_session, close_interactive_session
)
from pr.multiplexer import get_multiplexer from pr.multiplexer import get_multiplexer
from pr.tools.interactive_control import (
close_interactive_session,
get_session_status,
list_active_sessions,
read_session_output,
send_input_to_session,
)
from pr.tools.prompt_detection import get_global_detector from pr.tools.prompt_detection import get_global_detector
from pr.ui import Colors from pr.ui import Colors
def show_sessions(args=None): def show_sessions(args=None):
"""Show all active multiplexer sessions.""" """Show all active multiplexer sessions."""
sessions = list_active_sessions() sessions = list_active_sessions()
@ -18,24 +22,29 @@ def show_sessions(args=None):
print("-" * 80) print("-" * 80)
for session_name, session_data in sessions.items(): for session_name, session_data in sessions.items():
metadata = session_data['metadata'] metadata = session_data["metadata"]
output_summary = session_data['output_summary'] output_summary = session_data["output_summary"]
status = get_session_status(session_name) status = get_session_status(session_name)
is_active = status.get('is_active', False) if status else False is_active = status.get("is_active", False) if status else False
status_color = Colors.GREEN if is_active else Colors.RED status_color = Colors.GREEN if is_active else Colors.RED
print(f"{Colors.CYAN}{session_name}{Colors.RESET}: {status_color}{metadata.get('process_type', 'unknown')}{Colors.RESET}") print(
f"{Colors.CYAN}{session_name}{Colors.RESET}: {status_color}{metadata.get('process_type', 'unknown')}{Colors.RESET}"
)
if status and 'pid' in status: if status and "pid" in status:
print(f" PID: {status['pid']}") print(f" PID: {status['pid']}")
print(f" Age: {metadata.get('start_time', 0):.1f}s") print(f" Age: {metadata.get('start_time', 0):.1f}s")
print(f" Output: {output_summary['stdout_lines']} stdout, {output_summary['stderr_lines']} stderr lines") print(
f" Output: {output_summary['stdout_lines']} stdout, {output_summary['stderr_lines']} stderr lines"
)
print(f" Interactions: {metadata.get('interaction_count', 0)}") print(f" Interactions: {metadata.get('interaction_count', 0)}")
print(f" State: {metadata.get('state', 'unknown')}") print(f" State: {metadata.get('state', 'unknown')}")
print() print()
def attach_session(args): def attach_session(args):
"""Attach to a session (show its output and allow interaction).""" """Attach to a session (show its output and allow interaction)."""
if not args or len(args) < 1: if not args or len(args) < 1:
@ -56,20 +65,23 @@ def attach_session(args):
# Show recent output # Show recent output
try: try:
output = read_session_output(session_name, lines=20) output = read_session_output(session_name, lines=20)
if output['stdout']: if output["stdout"]:
print(f"{Colors.GRAY}Recent stdout:{Colors.RESET}") print(f"{Colors.GRAY}Recent stdout:{Colors.RESET}")
for line in output['stdout'].split('\n'): for line in output["stdout"].split("\n"):
if line.strip(): if line.strip():
print(f" {line}") print(f" {line}")
if output['stderr']: if output["stderr"]:
print(f"{Colors.YELLOW}Recent stderr:{Colors.RESET}") print(f"{Colors.YELLOW}Recent stderr:{Colors.RESET}")
for line in output['stderr'].split('\n'): for line in output["stderr"].split("\n"):
if line.strip(): if line.strip():
print(f" {line}") print(f" {line}")
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error reading output: {e}{Colors.RESET}") print(f"{Colors.RED}Error reading output: {e}{Colors.RESET}")
print(f"\n{Colors.CYAN}Session is {'active' if status.get('is_active') else 'inactive'}{Colors.RESET}") print(
f"\n{Colors.CYAN}Session is {'active' if status.get('is_active') else 'inactive'}{Colors.RESET}"
)
def detach_session(args): def detach_session(args):
"""Detach from a session (stop showing its output but keep it running).""" """Detach from a session (stop showing its output but keep it running)."""
@ -87,7 +99,10 @@ def detach_session(args):
# In this implementation, detaching just means we stop displaying output # In this implementation, detaching just means we stop displaying output
# The session continues to run in the background # The session continues to run in the background
mux.show_output = False mux.show_output = False
print(f"{Colors.GREEN}Detached from session '{session_name}'. It continues running in background.{Colors.RESET}") print(
f"{Colors.GREEN}Detached from session '{session_name}'. It continues running in background.{Colors.RESET}"
)
def kill_session(args): def kill_session(args):
"""Kill a session forcefully.""" """Kill a session forcefully."""
@ -101,7 +116,10 @@ def kill_session(args):
close_interactive_session(session_name) close_interactive_session(session_name)
print(f"{Colors.GREEN}Session '{session_name}' terminated.{Colors.RESET}") print(f"{Colors.GREEN}Session '{session_name}' terminated.{Colors.RESET}")
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}") print(
f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}"
)
def send_command(args): def send_command(args):
"""Send a command to a session.""" """Send a command to a session."""
@ -110,13 +128,18 @@ def send_command(args):
return return
session_name = args[0] session_name = args[0]
command = ' '.join(args[1:]) command = " ".join(args[1:])
try: try:
send_input_to_session(session_name, command) send_input_to_session(session_name, command)
print(f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}") print(
f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}"
)
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}") print(
f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}"
)
def show_session_log(args): def show_session_log(args):
"""Show the full log/output of a session.""" """Show the full log/output of a session."""
@ -131,19 +154,20 @@ def show_session_log(args):
print(f"{Colors.BOLD}Full log for session: {session_name}{Colors.RESET}") print(f"{Colors.BOLD}Full log for session: {session_name}{Colors.RESET}")
print("=" * 80) print("=" * 80)
if output['stdout']: if output["stdout"]:
print(f"{Colors.GRAY}STDOUT:{Colors.RESET}") print(f"{Colors.GRAY}STDOUT:{Colors.RESET}")
print(output['stdout']) print(output["stdout"])
print() print()
if output['stderr']: if output["stderr"]:
print(f"{Colors.YELLOW}STDERR:{Colors.RESET}") print(f"{Colors.YELLOW}STDERR:{Colors.RESET}")
print(output['stderr']) print(output["stderr"])
print() print()
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error reading log for '{session_name}': {e}{Colors.RESET}") print(f"{Colors.RED}Error reading log for '{session_name}': {e}{Colors.RESET}")
def show_session_status(args): def show_session_status(args):
"""Show detailed status of a session.""" """Show detailed status of a session."""
if not args or len(args) < 1: if not args or len(args) < 1:
@ -160,11 +184,11 @@ def show_session_status(args):
print(f"{Colors.BOLD}Status for session: {session_name}{Colors.RESET}") print(f"{Colors.BOLD}Status for session: {session_name}{Colors.RESET}")
print("-" * 50) print("-" * 50)
metadata = status.get('metadata', {}) metadata = status.get("metadata", {})
print(f"Process type: {metadata.get('process_type', 'unknown')}") print(f"Process type: {metadata.get('process_type', 'unknown')}")
print(f"Active: {status.get('is_active', False)}") print(f"Active: {status.get('is_active', False)}")
if 'pid' in status: if "pid" in status:
print(f"PID: {status['pid']}") print(f"PID: {status['pid']}")
print(f"Start time: {metadata.get('start_time', 0):.1f}") print(f"Start time: {metadata.get('start_time', 0):.1f}")
@ -172,8 +196,10 @@ def show_session_status(args):
print(f"Interaction count: {metadata.get('interaction_count', 0)}") print(f"Interaction count: {metadata.get('interaction_count', 0)}")
print(f"State: {metadata.get('state', 'unknown')}") print(f"State: {metadata.get('state', 'unknown')}")
output_summary = status.get('output_summary', {}) output_summary = status.get("output_summary", {})
print(f"Output lines: {output_summary.get('stdout_lines', 0)} stdout, {output_summary.get('stderr_lines', 0)} stderr") print(
f"Output lines: {output_summary.get('stdout_lines', 0)} stdout, {output_summary.get('stderr_lines', 0)} stderr"
)
# Show prompt detection info # Show prompt detection info
detector = get_global_detector() detector = get_global_detector()
@ -182,6 +208,7 @@ def show_session_status(args):
print(f"Current state: {session_info['current_state']}") print(f"Current state: {session_info['current_state']}")
print(f"Is waiting for input: {session_info['is_waiting']}") print(f"Is waiting for input: {session_info['is_waiting']}")
def list_waiting_sessions(args=None): def list_waiting_sessions(args=None):
"""List sessions that appear to be waiting for input.""" """List sessions that appear to be waiting for input."""
sessions = list_active_sessions() sessions = list_active_sessions()
@ -193,14 +220,16 @@ def list_waiting_sessions(args=None):
waiting_sessions.append(session_name) waiting_sessions.append(session_name)
if not waiting_sessions: if not waiting_sessions:
print(f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}") print(
f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}"
)
return return
print(f"{Colors.BOLD}Sessions waiting for input:{Colors.RESET}") print(f"{Colors.BOLD}Sessions waiting for input:{Colors.RESET}")
for session_name in waiting_sessions: for session_name in waiting_sessions:
status = get_session_status(session_name) status = get_session_status(session_name)
if status: if status:
process_type = status.get('metadata', {}).get('process_type', 'unknown') process_type = status.get("metadata", {}).get("process_type", "unknown")
print(f" {Colors.CYAN}{session_name}{Colors.RESET} ({process_type})") print(f" {Colors.CYAN}{session_name}{Colors.RESET} ({process_type})")
# Show suggestions # Show suggestions
@ -208,17 +237,20 @@ def list_waiting_sessions(args=None):
if session_info: if session_info:
suggestions = detector.get_response_suggestions({}, process_type) suggestions = detector.get_response_suggestions({}, process_type)
if suggestions: if suggestions:
print(f" Suggested inputs: {', '.join(suggestions[:3])}") # Show first 3 print(
f" Suggested inputs: {', '.join(suggestions[:3])}"
) # Show first 3
print() print()
# Command registry for the multiplexer commands # Command registry for the multiplexer commands
MULTIPLEXER_COMMANDS = { MULTIPLEXER_COMMANDS = {
'show_sessions': show_sessions, "show_sessions": show_sessions,
'attach_session': attach_session, "attach_session": attach_session,
'detach_session': detach_session, "detach_session": detach_session,
'kill_session': kill_session, "kill_session": kill_session,
'send_command': send_command, "send_command": send_command,
'show_session_log': show_session_log, "show_session_log": show_session_log,
'show_session_status': show_session_status, "show_session_status": show_session_status,
'list_waiting_sessions': list_waiting_sessions, "list_waiting_sessions": list_waiting_sessions,
} }

View File

@ -27,15 +27,78 @@ CONTENT_TRIM_LENGTH = 30000
MAX_TOOL_RESULT_LENGTH = 30000 MAX_TOOL_RESULT_LENGTH = 30000
LANGUAGE_KEYWORDS = { LANGUAGE_KEYWORDS = {
'python': ['def', 'class', 'import', 'from', 'if', 'else', 'elif', 'for', 'while', "python": [
'return', 'try', 'except', 'finally', 'with', 'as', 'lambda', 'yield', "def",
'None', 'True', 'False', 'and', 'or', 'not', 'in', 'is'], "class",
'javascript': ['function', 'var', 'let', 'const', 'if', 'else', 'for', 'while', "import",
'return', 'try', 'catch', 'finally', 'class', 'extends', 'new', "from",
'this', 'null', 'undefined', 'true', 'false'], "if",
'java': ['public', 'private', 'protected', 'class', 'interface', 'extends', "else",
'implements', 'static', 'final', 'void', 'int', 'String', 'boolean', "elif",
'if', 'else', 'for', 'while', 'return', 'try', 'catch', 'finally'], "for",
"while",
"return",
"try",
"except",
"finally",
"with",
"as",
"lambda",
"yield",
"None",
"True",
"False",
"and",
"or",
"not",
"in",
"is",
],
"javascript": [
"function",
"var",
"let",
"const",
"if",
"else",
"for",
"while",
"return",
"try",
"catch",
"finally",
"class",
"extends",
"new",
"this",
"null",
"undefined",
"true",
"false",
],
"java": [
"public",
"private",
"protected",
"class",
"interface",
"extends",
"implements",
"static",
"final",
"void",
"int",
"String",
"boolean",
"if",
"else",
"for",
"while",
"return",
"try",
"catch",
"finally",
],
} }
CACHE_ENABLED = True CACHE_ENABLED = True
@ -70,18 +133,18 @@ MAX_CONCURRENT_SESSIONS = 10
# Process-specific timeouts (seconds) # Process-specific timeouts (seconds)
PROCESS_TIMEOUTS = { PROCESS_TIMEOUTS = {
'default': 300, # 5 minutes "default": 300, # 5 minutes
'apt': 600, # 10 minutes "apt": 600, # 10 minutes
'ssh': 60, # 1 minute "ssh": 60, # 1 minute
'vim': 3600, # 1 hour "vim": 3600, # 1 hour
'git': 300, # 5 minutes "git": 300, # 5 minutes
'npm': 600, # 10 minutes "npm": 600, # 10 minutes
'pip': 300, # 5 minutes "pip": 300, # 5 minutes
} }
# Activity thresholds for LLM notification # Activity thresholds for LLM notification
HIGH_OUTPUT_THRESHOLD = 50 # lines HIGH_OUTPUT_THRESHOLD = 50 # lines
INACTIVE_THRESHOLD = 300 # seconds INACTIVE_THRESHOLD = 300 # seconds
SESSION_NOTIFY_INTERVAL = 60 # seconds SESSION_NOTIFY_INTERVAL = 60 # seconds
# Autonomous behavior flags # Autonomous behavior flags

View File

@ -1,5 +1,11 @@
from pr.core.assistant import Assistant
from pr.core.api import call_api, list_models from pr.core.api import call_api, list_models
from pr.core.assistant import Assistant
from pr.core.context import init_system_message, manage_context_window from pr.core.context import init_system_message, manage_context_window
__all__ = ['Assistant', 'call_api', 'list_models', 'init_system_message', 'manage_context_window'] __all__ = [
"Assistant",
"call_api",
"list_models",
"init_system_message",
"manage_context_window",
]

View File

@ -1,20 +1,20 @@
import re import re
import math from typing import Any, Dict, List
from typing import List, Dict, Any
from collections import Counter
class AdvancedContextManager: class AdvancedContextManager:
def __init__(self, knowledge_store=None, conversation_memory=None): def __init__(self, knowledge_store=None, conversation_memory=None):
self.knowledge_store = knowledge_store self.knowledge_store = knowledge_store
self.conversation_memory = conversation_memory self.conversation_memory = conversation_memory
def adaptive_context_window(self, messages: List[Dict[str, Any]], def adaptive_context_window(
task_complexity: str = 'medium') -> int: self, messages: List[Dict[str, Any]], task_complexity: str = "medium"
) -> int:
complexity_thresholds = { complexity_thresholds = {
'simple': 10, "simple": 10,
'medium': 20, "medium": 20,
'complex': 35, "complex": 35,
'very_complex': 50 "very_complex": 50,
} }
base_threshold = complexity_thresholds.get(task_complexity, 20) base_threshold = complexity_thresholds.get(task_complexity, 20)
@ -31,16 +31,18 @@ class AdvancedContextManager:
return max(base_threshold, adjusted) return max(base_threshold, adjusted)
def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float: def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float:
total_length = sum(len(msg.get('content', '')) for msg in messages) total_length = sum(len(msg.get("content", "")) for msg in messages)
avg_length = total_length / len(messages) if messages else 0 avg_length = total_length / len(messages) if messages else 0
unique_words = set() unique_words = set()
for msg in messages: for msg in messages:
content = msg.get('content', '') content = msg.get("content", "")
words = re.findall(r'\b\w+\b', content.lower()) words = re.findall(r"\b\w+\b", content.lower())
unique_words.update(words) unique_words.update(words)
vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0 vocabulary_richness = (
len(unique_words) / total_length if total_length > 0 else 0
)
# Simple complexity score based on length and richness # Simple complexity score based on length and richness
complexity = min(1.0, (avg_length / 100) + vocabulary_richness) complexity = min(1.0, (avg_length / 100) + vocabulary_richness)
@ -49,7 +51,7 @@ class AdvancedContextManager:
def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]: def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]:
if not text.strip(): if not text.strip():
return [] return []
sentences = re.split(r'(?<=[.!?])\s+', text) sentences = re.split(r"(?<=[.!?])\s+", text)
if not sentences: if not sentences:
return [] return []
@ -65,15 +67,15 @@ class AdvancedContextManager:
return [s[0] for s in scored_sentences[:top_k]] return [s[0] for s in scored_sentences[:top_k]]
def advanced_summarize_messages(self, messages: List[Dict[str, Any]]) -> str: def advanced_summarize_messages(self, messages: List[Dict[str, Any]]) -> str:
all_content = ' '.join([msg.get('content', '') for msg in messages]) all_content = " ".join([msg.get("content", "") for msg in messages])
key_sentences = self.extract_key_sentences(all_content, top_k=3) key_sentences = self.extract_key_sentences(all_content, top_k=3)
summary = ' '.join(key_sentences) summary = " ".join(key_sentences)
return summary if summary else "No content to summarize." return summary if summary else "No content to summarize."
def score_message_relevance(self, message: Dict[str, Any], context: str) -> float: def score_message_relevance(self, message: Dict[str, Any], context: str) -> float:
content = message.get('content', '') content = message.get("content", "")
content_words = set(re.findall(r'\b\w+\b', content.lower())) content_words = set(re.findall(r"\b\w+\b", content.lower()))
context_words = set(re.findall(r'\b\w+\b', context.lower())) context_words = set(re.findall(r"\b\w+\b", context.lower()))
intersection = content_words & context_words intersection = content_words & context_words
union = content_words | context_words union = content_words | context_words

View File

@ -1,13 +1,17 @@
import json import json
import urllib.request
import urllib.error
import logging import logging
from pr.config import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS import urllib.error
import urllib.request
from pr.config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE
from pr.core.context import auto_slim_messages from pr.core.context import auto_slim_messages
logger = logging.getLogger('pr') logger = logging.getLogger("pr")
def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
def call_api(
messages, model, api_url, api_key, use_tools, tools_definition, verbose=False
):
try: try:
messages = auto_slim_messages(messages, verbose=verbose) messages = auto_slim_messages(messages, verbose=verbose)
@ -17,62 +21,63 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
logger.debug(f"Use tools: {use_tools}") logger.debug(f"Use tools: {use_tools}")
logger.debug(f"Message count: {len(messages)}") logger.debug(f"Message count: {len(messages)}")
headers = { headers = {
'Content-Type': 'application/json', "Content-Type": "application/json",
} }
if api_key: if api_key:
headers['Authorization'] = f'Bearer {api_key}' headers["Authorization"] = f"Bearer {api_key}"
data = { data = {
'model': model, "model": model,
'messages': messages, "messages": messages,
'temperature': DEFAULT_TEMPERATURE, "temperature": DEFAULT_TEMPERATURE,
'max_tokens': DEFAULT_MAX_TOKENS "max_tokens": DEFAULT_MAX_TOKENS,
} }
if "gpt-5" in model: if "gpt-5" in model:
del data['temperature'] del data["temperature"]
del data['max_tokens'] del data["max_tokens"]
logger.debug("GPT-5 detected: removed temperature and max_tokens") logger.debug("GPT-5 detected: removed temperature and max_tokens")
if use_tools: if use_tools:
data['tools'] = tools_definition data["tools"] = tools_definition
data['tool_choice'] = 'auto' data["tool_choice"] = "auto"
logger.debug(f"Tool calling enabled with {len(tools_definition)} tools") logger.debug(f"Tool calling enabled with {len(tools_definition)} tools")
request_json = json.dumps(data) request_json = json.dumps(data)
logger.debug(f"Request payload size: {len(request_json)} bytes") logger.debug(f"Request payload size: {len(request_json)} bytes")
req = urllib.request.Request( req = urllib.request.Request(
api_url, api_url, data=request_json.encode("utf-8"), headers=headers, method="POST"
data=request_json.encode('utf-8'),
headers=headers,
method='POST'
) )
logger.debug("Sending HTTP request...") logger.debug("Sending HTTP request...")
with urllib.request.urlopen(req) as response: with urllib.request.urlopen(req) as response:
response_data = response.read().decode('utf-8') response_data = response.read().decode("utf-8")
logger.debug(f"Response received: {len(response_data)} bytes") logger.debug(f"Response received: {len(response_data)} bytes")
result = json.loads(response_data) result = json.loads(response_data)
if 'usage' in result: if "usage" in result:
logger.debug(f"Token usage: {result['usage']}") logger.debug(f"Token usage: {result['usage']}")
if 'choices' in result and result['choices']: if "choices" in result and result["choices"]:
choice = result['choices'][0] choice = result["choices"][0]
if 'message' in choice: if "message" in choice:
msg = choice['message'] msg = choice["message"]
logger.debug(f"Response role: {msg.get('role', 'N/A')}") logger.debug(f"Response role: {msg.get('role', 'N/A')}")
if 'content' in msg and msg['content']: if "content" in msg and msg["content"]:
logger.debug(f"Response content length: {len(msg['content'])} chars") logger.debug(
if 'tool_calls' in msg: f"Response content length: {len(msg['content'])} chars"
logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)") )
if "tool_calls" in msg:
logger.debug(
f"Response contains {len(msg['tool_calls'])} tool call(s)"
)
logger.debug("=== API CALL END ===") logger.debug("=== API CALL END ===")
return result return result
except urllib.error.HTTPError as e: except urllib.error.HTTPError as e:
error_body = e.read().decode('utf-8') error_body = e.read().decode("utf-8")
logger.error(f"API HTTP Error: {e.code} - {error_body}") logger.error(f"API HTTP Error: {e.code} - {error_body}")
logger.debug("=== API CALL FAILED ===") logger.debug("=== API CALL FAILED ===")
return {"error": f"API Error: {e.code}", "message": error_body} return {"error": f"API Error: {e.code}", "message": error_body}
@ -81,15 +86,16 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
logger.debug("=== API CALL FAILED ===") logger.debug("=== API CALL FAILED ===")
return {"error": str(e)} return {"error": str(e)}
def list_models(model_list_url, api_key): def list_models(model_list_url, api_key):
try: try:
req = urllib.request.Request(model_list_url) req = urllib.request.Request(model_list_url)
if api_key: if api_key:
req.add_header('Authorization', f'Bearer {api_key}') req.add_header("Authorization", f"Bearer {api_key}")
with urllib.request.urlopen(req) as response: with urllib.request.urlopen(req) as response:
data = json.loads(response.read().decode('utf-8')) data = json.loads(response.read().decode("utf-8"))
return data.get('data', []) return data.get("data", [])
except Exception as e: except Exception as e:
return {"error": str(e)} return {"error": str(e)}

View File

@ -1,63 +1,111 @@
import os
import sys
import json
import sqlite3
import signal
import logging
import traceback
import readline
import glob as glob_module import glob as glob_module
import json
import logging
import os
import readline
import signal
import sqlite3
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pr.config import DB_PATH, LOG_FILE, DEFAULT_MODEL, DEFAULT_API_URL, MODEL_LIST_URL, HISTORY_FILE
from pr.ui import Colors, render_markdown from pr.commands import handle_command
from pr.core.context import init_system_message, truncate_tool_result from pr.config import (
DB_PATH,
DEFAULT_API_URL,
DEFAULT_MODEL,
HISTORY_FILE,
LOG_FILE,
MODEL_LIST_URL,
)
from pr.core.api import call_api from pr.core.api import call_api
from pr.core.autonomous_interactions import (
get_global_autonomous,
stop_global_autonomous,
)
from pr.core.background_monitor import (
get_global_monitor,
start_global_monitor,
stop_global_monitor,
)
from pr.core.context import init_system_message, truncate_tool_result
from pr.tools import ( from pr.tools import (
http_fetch, run_command, run_command_interactive, read_file, write_file, apply_patch,
list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query, chdir,
web_search, web_search_news, python_exec, index_source_directory, close_editor,
open_editor, editor_insert_text, editor_replace_text, editor_search, create_diff,
search_replace,close_editor,create_diff,apply_patch, db_get,
tail_process, kill_process db_query,
db_set,
editor_insert_text,
editor_replace_text,
editor_search,
getpwd,
http_fetch,
index_source_directory,
kill_process,
list_directory,
mkdir,
open_editor,
python_exec,
read_file,
run_command,
search_replace,
tail_process,
web_search,
web_search_news,
write_file,
)
from pr.tools.base import get_tools_definition
from pr.tools.filesystem import (
clear_edit_tracker,
display_edit_summary,
display_edit_timeline,
) )
from pr.tools.interactive_control import ( from pr.tools.interactive_control import (
start_interactive_session, send_input_to_session, read_session_output, close_interactive_session,
list_active_sessions, close_interactive_session list_active_sessions,
read_session_output,
send_input_to_session,
start_interactive_session,
) )
from pr.tools.patch import display_file_diff from pr.tools.patch import display_file_diff
from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker from pr.ui import Colors, render_markdown
from pr.tools.base import get_tools_definition
from pr.commands import handle_command
from pr.core.background_monitor import start_global_monitor, stop_global_monitor, get_global_monitor
from pr.core.autonomous_interactions import start_global_autonomous, stop_global_autonomous, get_global_autonomous
logger = logging.getLogger('pr') logger = logging.getLogger("pr")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(LOG_FILE) file_handler = logging.FileHandler(LOG_FILE)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) file_handler.setFormatter(
logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
)
logger.addHandler(file_handler) logger.addHandler(file_handler)
class Assistant: class Assistant:
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
self.messages = [] self.messages = []
self.verbose = args.verbose self.verbose = args.verbose
self.debug = getattr(args, 'debug', False) self.debug = getattr(args, "debug", False)
self.syntax_highlighting = not args.no_syntax self.syntax_highlighting = not args.no_syntax
if self.debug: if self.debug:
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG) console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s')) console_handler.setFormatter(
logging.Formatter("%(levelname)s: %(message)s")
)
logger.addHandler(console_handler) logger.addHandler(console_handler)
logger.debug("Debug mode enabled") logger.debug("Debug mode enabled")
self.api_key = os.environ.get('OPENROUTER_API_KEY', '') self.api_key = os.environ.get("OPENROUTER_API_KEY", "")
self.model = args.model or os.environ.get('AI_MODEL', DEFAULT_MODEL) self.model = args.model or os.environ.get("AI_MODEL", DEFAULT_MODEL)
self.api_url = args.api_url or os.environ.get('API_URL', DEFAULT_API_URL) self.api_url = args.api_url or os.environ.get("API_URL", DEFAULT_API_URL)
self.model_list_url = args.model_list_url or os.environ.get('MODEL_LIST_URL', MODEL_LIST_URL) self.model_list_url = args.model_list_url or os.environ.get(
self.use_tools = os.environ.get('USE_TOOLS', '1') == '1' "MODEL_LIST_URL", MODEL_LIST_URL
self.strict_mode = os.environ.get('STRICT_MODE', '0') == '1' )
self.use_tools = os.environ.get("USE_TOOLS", "1") == "1"
self.strict_mode = os.environ.get("STRICT_MODE", "0") == "1"
self.interrupt_count = 0 self.interrupt_count = 0
self.python_globals = {} self.python_globals = {}
self.db_conn = None self.db_conn = None
@ -69,6 +117,7 @@ class Assistant:
try: try:
from pr.core.enhanced_assistant import EnhancedAssistant from pr.core.enhanced_assistant import EnhancedAssistant
self.enhanced = EnhancedAssistant(self) self.enhanced = EnhancedAssistant(self)
if self.debug: if self.debug:
logger.debug("Enhanced assistant features initialized") logger.debug("Enhanced assistant features initialized")
@ -94,13 +143,17 @@ class Assistant:
self.db_conn = sqlite3.connect(DB_PATH, check_same_thread=False) self.db_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
cursor = self.db_conn.cursor() cursor = self.db_conn.cursor()
cursor.execute('''CREATE TABLE IF NOT EXISTS kv_store cursor.execute(
(key TEXT PRIMARY KEY, value TEXT, timestamp REAL)''') """CREATE TABLE IF NOT EXISTS kv_store
(key TEXT PRIMARY KEY, value TEXT, timestamp REAL)"""
)
cursor.execute('''CREATE TABLE IF NOT EXISTS file_versions cursor.execute(
"""CREATE TABLE IF NOT EXISTS file_versions
(id INTEGER PRIMARY KEY AUTOINCREMENT, (id INTEGER PRIMARY KEY AUTOINCREMENT,
filepath TEXT, content TEXT, hash TEXT, filepath TEXT, content TEXT, hash TEXT,
timestamp REAL, version INTEGER)''') timestamp REAL, version INTEGER)"""
)
self.db_conn.commit() self.db_conn.commit()
logger.debug("Database initialized successfully") logger.debug("Database initialized successfully")
@ -110,7 +163,7 @@ class Assistant:
def _handle_background_updates(self, updates): def _handle_background_updates(self, updates):
"""Handle background session updates by injecting them into the conversation.""" """Handle background session updates by injecting them into the conversation."""
if not updates or not updates.get('sessions'): if not updates or not updates.get("sessions"):
return return
# Format the update as a system message # Format the update as a system message
@ -118,10 +171,12 @@ class Assistant:
# Inject into current conversation if we're in an active session # Inject into current conversation if we're in an active session
if self.messages and len(self.messages) > 0: if self.messages and len(self.messages) > 0:
self.messages.append({ self.messages.append(
"role": "system", {
"content": f"Background session updates: {update_message}" "role": "system",
}) "content": f"Background session updates: {update_message}",
}
)
if self.verbose: if self.verbose:
print(f"{Colors.CYAN}Background update: {update_message}{Colors.RESET}") print(f"{Colors.CYAN}Background update: {update_message}{Colors.RESET}")
@ -130,8 +185,8 @@ class Assistant:
"""Format background updates for LLM consumption.""" """Format background updates for LLM consumption."""
session_summaries = [] session_summaries = []
for session_name, session_info in updates.get('sessions', {}).items(): for session_name, session_info in updates.get("sessions", {}).items():
summary = session_info.get('summary', f'Session {session_name}') summary = session_info.get("summary", f"Session {session_name}")
session_summaries.append(f"{session_name}: {summary}") session_summaries.append(f"{session_name}: {summary}")
if session_summaries: if session_summaries:
@ -151,30 +206,44 @@ class Assistant:
if events: if events:
print(f"\n{Colors.CYAN}Background Events:{Colors.RESET}") print(f"\n{Colors.CYAN}Background Events:{Colors.RESET}")
for event in events: for event in events:
event_type = event.get('type', 'unknown') event_type = event.get("type", "unknown")
session_name = event.get('session_name', 'unknown') session_name = event.get("session_name", "unknown")
if event_type == 'session_started': if event_type == "session_started":
print(f" {Colors.GREEN}{Colors.RESET} Session '{session_name}' started") print(
elif event_type == 'session_ended': f" {Colors.GREEN}{Colors.RESET} Session '{session_name}' started"
print(f" {Colors.YELLOW}{Colors.RESET} Session '{session_name}' ended") )
elif event_type == 'output_received': elif event_type == "session_ended":
lines = len(event.get('new_output', {}).get('stdout', [])) print(
print(f" {Colors.BLUE}📝{Colors.RESET} Session '{session_name}' produced {lines} lines of output") f" {Colors.YELLOW}{Colors.RESET} Session '{session_name}' ended"
elif event_type == 'possible_input_needed': )
print(f" {Colors.RED}{Colors.RESET} Session '{session_name}' may need input") elif event_type == "output_received":
elif event_type == 'high_output_volume': lines = len(event.get("new_output", {}).get("stdout", []))
total = event.get('total_lines', 0) print(
print(f" {Colors.YELLOW}📊{Colors.RESET} Session '{session_name}' has high output volume ({total} lines)") f" {Colors.BLUE}📝{Colors.RESET} Session '{session_name}' produced {lines} lines of output"
elif event_type == 'inactive_session': )
inactive_time = event.get('inactive_seconds', 0) elif event_type == "possible_input_needed":
print(f" {Colors.GRAY}{Colors.RESET} Session '{session_name}' inactive for {inactive_time:.0f}s") print(
f" {Colors.RED}{Colors.RESET} Session '{session_name}' may need input"
)
elif event_type == "high_output_volume":
total = event.get("total_lines", 0)
print(
f" {Colors.YELLOW}📊{Colors.RESET} Session '{session_name}' has high output volume ({total} lines)"
)
elif event_type == "inactive_session":
inactive_time = event.get("inactive_seconds", 0)
print(
f" {Colors.GRAY}{Colors.RESET} Session '{session_name}' inactive for {inactive_time:.0f}s"
)
print() # Add blank line after events print() # Add blank line after events
except Exception as e: except Exception as e:
if self.debug: if self.debug:
print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}") print(
f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}"
)
def execute_tool_calls(self, tool_calls): def execute_tool_calls(self, tool_calls):
results = [] results = []
@ -185,114 +254,147 @@ class Assistant:
futures = [] futures = []
for tool_call in tool_calls: for tool_call in tool_calls:
func_name = tool_call['function']['name'] func_name = tool_call["function"]["name"]
arguments = json.loads(tool_call['function']['arguments']) arguments = json.loads(tool_call["function"]["arguments"])
logger.debug(f"Tool call: {func_name} with arguments: {arguments}") logger.debug(f"Tool call: {func_name} with arguments: {arguments}")
func_map = { func_map = {
'http_fetch': lambda **kw: http_fetch(**kw), "http_fetch": lambda **kw: http_fetch(**kw),
'run_command': lambda **kw: run_command(**kw), "run_command": lambda **kw: run_command(**kw),
'tail_process': lambda **kw: tail_process(**kw), "tail_process": lambda **kw: tail_process(**kw),
'kill_process': lambda **kw: kill_process(**kw), "kill_process": lambda **kw: kill_process(**kw),
'start_interactive_session': lambda **kw: start_interactive_session(**kw), "start_interactive_session": lambda **kw: start_interactive_session(
'send_input_to_session': lambda **kw: send_input_to_session(**kw), **kw
'read_session_output': lambda **kw: read_session_output(**kw), ),
'close_interactive_session': lambda **kw: close_interactive_session(**kw), "send_input_to_session": lambda **kw: send_input_to_session(**kw),
'read_file': lambda **kw: read_file(**kw, db_conn=self.db_conn), "read_session_output": lambda **kw: read_session_output(**kw),
'write_file': lambda **kw: write_file(**kw, db_conn=self.db_conn), "close_interactive_session": lambda **kw: close_interactive_session(
'list_directory': lambda **kw: list_directory(**kw), **kw
'mkdir': lambda **kw: mkdir(**kw), ),
'chdir': lambda **kw: chdir(**kw), "read_file": lambda **kw: read_file(**kw, db_conn=self.db_conn),
'getpwd': lambda **kw: getpwd(**kw), "write_file": lambda **kw: write_file(**kw, db_conn=self.db_conn),
'db_set': lambda **kw: db_set(**kw, db_conn=self.db_conn), "list_directory": lambda **kw: list_directory(**kw),
'db_get': lambda **kw: db_get(**kw, db_conn=self.db_conn), "mkdir": lambda **kw: mkdir(**kw),
'db_query': lambda **kw: db_query(**kw, db_conn=self.db_conn), "chdir": lambda **kw: chdir(**kw),
'web_search': lambda **kw: web_search(**kw), "getpwd": lambda **kw: getpwd(**kw),
'web_search_news': lambda **kw: web_search_news(**kw), "db_set": lambda **kw: db_set(**kw, db_conn=self.db_conn),
'python_exec': lambda **kw: python_exec(**kw, python_globals=self.python_globals), "db_get": lambda **kw: db_get(**kw, db_conn=self.db_conn),
'index_source_directory': lambda **kw: index_source_directory(**kw), "db_query": lambda **kw: db_query(**kw, db_conn=self.db_conn),
'search_replace': lambda **kw: search_replace(**kw, db_conn=self.db_conn), "web_search": lambda **kw: web_search(**kw),
'open_editor': lambda **kw: open_editor(**kw), "web_search_news": lambda **kw: web_search_news(**kw),
'editor_insert_text': lambda **kw: editor_insert_text(**kw, db_conn=self.db_conn), "python_exec": lambda **kw: python_exec(
'editor_replace_text': lambda **kw: editor_replace_text(**kw, db_conn=self.db_conn), **kw, python_globals=self.python_globals
'editor_search': lambda **kw: editor_search(**kw), ),
'close_editor': lambda **kw: close_editor(**kw), "index_source_directory": lambda **kw: index_source_directory(**kw),
'create_diff': lambda **kw: create_diff(**kw), "search_replace": lambda **kw: search_replace(
'apply_patch': lambda **kw: apply_patch(**kw, db_conn=self.db_conn), **kw, db_conn=self.db_conn
'display_file_diff': lambda **kw: display_file_diff(**kw), ),
'display_edit_summary': lambda **kw: display_edit_summary(), "open_editor": lambda **kw: open_editor(**kw),
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw), "editor_insert_text": lambda **kw: editor_insert_text(
'clear_edit_tracker': lambda **kw: clear_edit_tracker(), **kw, db_conn=self.db_conn
'start_interactive_session': lambda **kw: start_interactive_session(**kw), ),
'send_input_to_session': lambda **kw: send_input_to_session(**kw), "editor_replace_text": lambda **kw: editor_replace_text(
'read_session_output': lambda **kw: read_session_output(**kw), **kw, db_conn=self.db_conn
'list_active_sessions': lambda **kw: list_active_sessions(**kw), ),
'close_interactive_session': lambda **kw: close_interactive_session(**kw), "editor_search": lambda **kw: editor_search(**kw),
'create_agent': lambda **kw: create_agent(**kw), "close_editor": lambda **kw: close_editor(**kw),
'list_agents': lambda **kw: list_agents(**kw), "create_diff": lambda **kw: create_diff(**kw),
'execute_agent_task': lambda **kw: execute_agent_task(**kw), "apply_patch": lambda **kw: apply_patch(**kw, db_conn=self.db_conn),
'remove_agent': lambda **kw: remove_agent(**kw), "display_file_diff": lambda **kw: display_file_diff(**kw),
'collaborate_agents': lambda **kw: collaborate_agents(**kw), "display_edit_summary": lambda **kw: display_edit_summary(),
'add_knowledge_entry': lambda **kw: add_knowledge_entry(**kw), "display_edit_timeline": lambda **kw: display_edit_timeline(**kw),
'get_knowledge_entry': lambda **kw: get_knowledge_entry(**kw), "clear_edit_tracker": lambda **kw: clear_edit_tracker(),
'search_knowledge': lambda **kw: search_knowledge(**kw), "start_interactive_session": lambda **kw: start_interactive_session(
'get_knowledge_by_category': lambda **kw: get_knowledge_by_category(**kw), **kw
'update_knowledge_importance': lambda **kw: update_knowledge_importance(**kw), ),
'delete_knowledge_entry': lambda **kw: delete_knowledge_entry(**kw), "send_input_to_session": lambda **kw: send_input_to_session(**kw),
'get_knowledge_statistics': lambda **kw: get_knowledge_statistics(**kw), "read_session_output": lambda **kw: read_session_output(**kw),
"list_active_sessions": lambda **kw: list_active_sessions(**kw),
"close_interactive_session": lambda **kw: close_interactive_session(
**kw
),
"create_agent": lambda **kw: create_agent(**kw),
"list_agents": lambda **kw: list_agents(**kw),
"execute_agent_task": lambda **kw: execute_agent_task(**kw),
"remove_agent": lambda **kw: remove_agent(**kw),
"collaborate_agents": lambda **kw: collaborate_agents(**kw),
"add_knowledge_entry": lambda **kw: add_knowledge_entry(**kw),
"get_knowledge_entry": lambda **kw: get_knowledge_entry(**kw),
"search_knowledge": lambda **kw: search_knowledge(**kw),
"get_knowledge_by_category": lambda **kw: get_knowledge_by_category(
**kw
),
"update_knowledge_importance": lambda **kw: update_knowledge_importance(
**kw
),
"delete_knowledge_entry": lambda **kw: delete_knowledge_entry(**kw),
"get_knowledge_statistics": lambda **kw: get_knowledge_statistics(
**kw
),
} }
if func_name in func_map: if func_name in func_map:
future = executor.submit(func_map[func_name], **arguments) future = executor.submit(func_map[func_name], **arguments)
futures.append((tool_call['id'], future)) futures.append((tool_call["id"], future))
for tool_id, future in futures: for tool_id, future in futures:
try: try:
result = future.result(timeout=30) result = future.result(timeout=30)
result = truncate_tool_result(result) result = truncate_tool_result(result)
logger.debug(f"Tool result for {tool_id}: {str(result)[:200]}...") logger.debug(f"Tool result for {tool_id}: {str(result)[:200]}...")
results.append({ results.append(
"tool_call_id": tool_id, {
"role": "tool", "tool_call_id": tool_id,
"content": json.dumps(result) "role": "tool",
}) "content": json.dumps(result),
}
)
except Exception as e: except Exception as e:
logger.debug(f"Tool error for {tool_id}: {str(e)}") logger.debug(f"Tool error for {tool_id}: {str(e)}")
error_msg = str(e)[:200] if len(str(e)) > 200 else str(e) error_msg = str(e)[:200] if len(str(e)) > 200 else str(e)
results.append({ results.append(
"tool_call_id": tool_id, {
"role": "tool", "tool_call_id": tool_id,
"content": json.dumps({"status": "error", "error": error_msg}) "role": "tool",
}) "content": json.dumps(
{"status": "error", "error": error_msg}
),
}
)
return results return results
def process_response(self, response): def process_response(self, response):
if 'error' in response: if "error" in response:
return f"Error: {response['error']}" return f"Error: {response['error']}"
if 'choices' not in response or not response['choices']: if "choices" not in response or not response["choices"]:
return "No response from API" return "No response from API"
message = response['choices'][0]['message'] message = response["choices"][0]["message"]
self.messages.append(message) self.messages.append(message)
if 'tool_calls' in message and message['tool_calls']: if "tool_calls" in message and message["tool_calls"]:
if self.verbose: if self.verbose:
print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}") print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}")
tool_results = self.execute_tool_calls(message['tool_calls']) tool_results = self.execute_tool_calls(message["tool_calls"])
for result in tool_results: for result in tool_results:
self.messages.append(result) self.messages.append(result)
follow_up = call_api( follow_up = call_api(
self.messages, self.model, self.api_url, self.api_key, self.messages,
self.use_tools, get_tools_definition(), verbose=self.verbose self.model,
self.api_url,
self.api_key,
self.use_tools,
get_tools_definition(),
verbose=self.verbose,
) )
return self.process_response(follow_up) return self.process_response(follow_up)
content = message.get('content', '') content = message.get("content", "")
return render_markdown(content, self.syntax_highlighting) return render_markdown(content, self.syntax_highlighting)
def signal_handler(self, signum, frame): def signal_handler(self, signum, frame):
@ -303,7 +405,9 @@ class Assistant:
self.autonomous_mode = False self.autonomous_mode = False
sys.exit(0) sys.exit(0)
else: else:
print(f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}") print(
f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}"
)
return return
self.interrupt_count += 1 self.interrupt_count += 1
@ -323,21 +427,34 @@ class Assistant:
readline.set_history_length(1000) readline.set_history_length(1000)
import atexit import atexit
atexit.register(readline.write_history_file, HISTORY_FILE) atexit.register(readline.write_history_file, HISTORY_FILE)
commands = ['exit', 'quit', 'help', 'reset', 'dump', 'verbose', commands = [
'models', 'tools', 'review', 'refactor', 'obfuscate', '/auto'] "exit",
"quit",
"help",
"reset",
"dump",
"verbose",
"models",
"tools",
"review",
"refactor",
"obfuscate",
"/auto",
]
def completer(text, state): def completer(text, state):
options = [cmd for cmd in commands if cmd.startswith(text)] options = [cmd for cmd in commands if cmd.startswith(text)]
glob_pattern = os.path.expanduser(text) + '*' glob_pattern = os.path.expanduser(text) + "*"
path_options = glob_module.glob(glob_pattern) path_options = glob_module.glob(glob_pattern)
path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options] path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options]
combined_options = sorted(list(set(options + path_options))) combined_options = sorted(list(set(options + path_options)))
#combined_options.extend(self.commands) # combined_options.extend(self.commands)
if state < len(combined_options): if state < len(combined_options):
return combined_options[state] return combined_options[state]
@ -345,10 +462,10 @@ class Assistant:
return None return None
delims = readline.get_completer_delims() delims = readline.get_completer_delims()
readline.set_completer_delims(delims.replace('/', '')) readline.set_completer_delims(delims.replace("/", ""))
readline.set_completer(completer) readline.set_completer(completer)
readline.parse_and_bind('tab: complete') readline.parse_and_bind("tab: complete")
def run_repl(self): def run_repl(self):
self.setup_readline() self.setup_readline()
@ -368,8 +485,11 @@ class Assistant:
if self.background_monitoring: if self.background_monitoring:
try: try:
from pr.multiplexer import get_all_sessions from pr.multiplexer import get_all_sessions
sessions = get_all_sessions() sessions = get_all_sessions()
active_count = sum(1 for s in sessions.values() if s.get('status') == 'running') active_count = sum(
1 for s in sessions.values() if s.get("status") == "running"
)
if active_count > 0: if active_count > 0:
prompt += f"[{active_count}bg]" prompt += f"[{active_count}bg]"
except: except:
@ -405,10 +525,11 @@ class Assistant:
message = sys.stdin.read() message = sys.stdin.read()
from pr.autonomous.mode import run_autonomous_mode from pr.autonomous.mode import run_autonomous_mode
run_autonomous_mode(self, message) run_autonomous_mode(self, message)
def cleanup(self): def cleanup(self):
if hasattr(self, 'enhanced') and self.enhanced: if hasattr(self, "enhanced") and self.enhanced:
try: try:
self.enhanced.cleanup() self.enhanced.cleanup()
except Exception as e: except Exception as e:
@ -424,6 +545,7 @@ class Assistant:
try: try:
from pr.multiplexer import cleanup_all_multiplexers from pr.multiplexer import cleanup_all_multiplexers
cleanup_all_multiplexers() cleanup_all_multiplexers()
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up multiplexers: {e}") logger.error(f"Error cleaning up multiplexers: {e}")
@ -433,7 +555,9 @@ class Assistant:
def run(self): def run(self):
try: try:
print(f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}") print(
f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}"
)
if self.args.interactive or (not self.args.message and sys.stdin.isatty()): if self.args.interactive or (not self.args.message and sys.stdin.isatty()):
print("DEBUG: calling run_repl") print("DEBUG: calling run_repl")
self.run_repl() self.run_repl()
@ -443,6 +567,7 @@ class Assistant:
finally: finally:
self.cleanup() self.cleanup()
def process_message(assistant, message): def process_message(assistant, message):
assistant.messages.append({"role": "user", "content": message}) assistant.messages.append({"role": "user", "content": message})
@ -453,9 +578,13 @@ def process_message(assistant, message):
print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}") print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}")
response = call_api( response = call_api(
assistant.messages, assistant.model, assistant.api_url, assistant.messages,
assistant.api_key, assistant.use_tools, get_tools_definition(), assistant.model,
verbose=assistant.verbose assistant.api_url,
assistant.api_key,
assistant.use_tools,
get_tools_definition(),
verbose=assistant.verbose,
) )
result = assistant.process_response(response) result = assistant.process_response(response)

View File

@ -1,7 +1,12 @@
import time
import threading import threading
from pr.core.background_monitor import get_global_monitor import time
from pr.tools.interactive_control import list_active_sessions, get_session_status, read_session_output
from pr.tools.interactive_control import (
get_session_status,
list_active_sessions,
read_session_output,
)
class AutonomousInteractions: class AutonomousInteractions:
def __init__(self, interaction_interval=10.0): def __init__(self, interaction_interval=10.0):
@ -16,7 +21,9 @@ class AutonomousInteractions:
self.llm_callback = llm_callback self.llm_callback = llm_callback
if self.interaction_thread is None: if self.interaction_thread is None:
self.active = True self.active = True
self.interaction_thread = threading.Thread(target=self._interaction_loop, daemon=True) self.interaction_thread = threading.Thread(
target=self._interaction_loop, daemon=True
)
self.interaction_thread.start() self.interaction_thread.start()
def stop(self): def stop(self):
@ -48,7 +55,9 @@ class AutonomousInteractions:
if not sessions: if not sessions:
return # No active sessions return # No active sessions
sessions_needing_attention = self._identify_sessions_needing_attention(sessions) sessions_needing_attention = self._identify_sessions_needing_attention(
sessions
)
if sessions_needing_attention and self.llm_callback: if sessions_needing_attention and self.llm_callback:
# Format session updates for LLM # Format session updates for LLM
@ -63,26 +72,30 @@ class AutonomousInteractions:
needing_attention = [] needing_attention = []
for session_name, session_data in sessions.items(): for session_name, session_data in sessions.items():
metadata = session_data['metadata'] metadata = session_data["metadata"]
output_summary = session_data['output_summary'] output_summary = session_data["output_summary"]
# Criteria for needing attention: # Criteria for needing attention:
# 1. Recent output activity # 1. Recent output activity
time_since_activity = time.time() - metadata.get('last_activity', 0) time_since_activity = time.time() - metadata.get("last_activity", 0)
if time_since_activity < 30: # Activity in last 30 seconds if time_since_activity < 30: # Activity in last 30 seconds
needing_attention.append(session_name) needing_attention.append(session_name)
continue continue
# 2. High output volume (potential completion or error) # 2. High output volume (potential completion or error)
total_lines = output_summary['stdout_lines'] + output_summary['stderr_lines'] total_lines = (
output_summary["stdout_lines"] + output_summary["stderr_lines"]
)
if total_lines > 50: # Arbitrary threshold if total_lines > 50: # Arbitrary threshold
needing_attention.append(session_name) needing_attention.append(session_name)
continue continue
# 3. Long-running sessions that might need intervention # 3. Long-running sessions that might need intervention
session_age = time.time() - metadata.get('start_time', 0) session_age = time.time() - metadata.get("start_time", 0)
if session_age > 300 and time_since_activity > 60: # 5+ minutes old, inactive for 1+ minute if (
session_age > 300 and time_since_activity > 60
): # 5+ minutes old, inactive for 1+ minute
needing_attention.append(session_name) needing_attention.append(session_name)
continue continue
@ -95,18 +108,18 @@ class AutonomousInteractions:
def _session_looks_stuck(self, session_name, session_data): def _session_looks_stuck(self, session_name, session_data):
"""Determine if a session appears to be stuck waiting for input.""" """Determine if a session appears to be stuck waiting for input."""
metadata = session_data['metadata'] metadata = session_data["metadata"]
# Check if process is still running # Check if process is still running
status = get_session_status(session_name) status = get_session_status(session_name)
if not status or not status.get('is_active', False): if not status or not status.get("is_active", False):
return False return False
time_since_activity = time.time() - metadata.get('last_activity', 0) time_since_activity = time.time() - metadata.get("last_activity", 0)
interaction_count = metadata.get('interaction_count', 0) interaction_count = metadata.get("interaction_count", 0)
# If running for a while but no interactions, might be waiting # If running for a while but no interactions, might be waiting
session_age = time.time() - metadata.get('start_time', 0) session_age = time.time() - metadata.get("start_time", 0)
if session_age > 60 and interaction_count == 0 and time_since_activity > 30: if session_age > 60 and interaction_count == 0 and time_since_activity > 30:
return True return True
@ -119,9 +132,9 @@ class AutonomousInteractions:
def _format_session_updates(self, session_names): def _format_session_updates(self, session_names):
"""Format session information for LLM consumption.""" """Format session information for LLM consumption."""
updates = { updates = {
'type': 'background_session_updates', "type": "background_session_updates",
'timestamp': time.time(), "timestamp": time.time(),
'sessions': {} "sessions": {},
} }
for session_name in session_names: for session_name in session_names:
@ -131,12 +144,12 @@ class AutonomousInteractions:
try: try:
recent_output = read_session_output(session_name, lines=20) recent_output = read_session_output(session_name, lines=20)
except: except:
recent_output = {'stdout': '', 'stderr': ''} recent_output = {"stdout": "", "stderr": ""}
updates['sessions'][session_name] = { updates["sessions"][session_name] = {
'status': status, "status": status,
'recent_output': recent_output, "recent_output": recent_output,
'summary': self._create_session_summary(status, recent_output) "summary": self._create_session_summary(status, recent_output),
} }
return updates return updates
@ -145,34 +158,39 @@ class AutonomousInteractions:
"""Create a human-readable summary of session status.""" """Create a human-readable summary of session status."""
summary_parts = [] summary_parts = []
process_type = status.get('metadata', {}).get('process_type', 'unknown') process_type = status.get("metadata", {}).get("process_type", "unknown")
summary_parts.append(f"Type: {process_type}") summary_parts.append(f"Type: {process_type}")
is_active = status.get('is_active', False) is_active = status.get("is_active", False)
summary_parts.append(f"Status: {'Active' if is_active else 'Inactive'}") summary_parts.append(f"Status: {'Active' if is_active else 'Inactive'}")
if is_active and 'pid' in status: if is_active and "pid" in status:
summary_parts.append(f"PID: {status['pid']}") summary_parts.append(f"PID: {status['pid']}")
age = time.time() - status.get('metadata', {}).get('start_time', 0) age = time.time() - status.get("metadata", {}).get("start_time", 0)
summary_parts.append(f"Age: {age:.1f}s") summary_parts.append(f"Age: {age:.1f}s")
output_lines = len(recent_output.get('stdout', '').split('\n')) + len(recent_output.get('stderr', '').split('\n')) output_lines = len(recent_output.get("stdout", "").split("\n")) + len(
recent_output.get("stderr", "").split("\n")
)
summary_parts.append(f"Recent output: {output_lines} lines") summary_parts.append(f"Recent output: {output_lines} lines")
interaction_count = status.get('metadata', {}).get('interaction_count', 0) interaction_count = status.get("metadata", {}).get("interaction_count", 0)
summary_parts.append(f"Interactions: {interaction_count}") summary_parts.append(f"Interactions: {interaction_count}")
return " | ".join(summary_parts) return " | ".join(summary_parts)
# Global autonomous interactions instance # Global autonomous interactions instance
_global_autonomous = None _global_autonomous = None
def get_global_autonomous(): def get_global_autonomous():
"""Get the global autonomous interactions instance.""" """Get the global autonomous interactions instance."""
global _global_autonomous global _global_autonomous
return _global_autonomous return _global_autonomous
def start_global_autonomous(llm_callback=None): def start_global_autonomous(llm_callback=None):
"""Start global autonomous interactions.""" """Start global autonomous interactions."""
global _global_autonomous global _global_autonomous
@ -181,6 +199,7 @@ def start_global_autonomous(llm_callback=None):
_global_autonomous.start(llm_callback) _global_autonomous.start(llm_callback)
return _global_autonomous return _global_autonomous
def stop_global_autonomous(): def stop_global_autonomous():
"""Stop global autonomous interactions.""" """Stop global autonomous interactions."""
global _global_autonomous global _global_autonomous

View File

@ -1,8 +1,9 @@
import queue
import threading import threading
import time import time
import queue
from pr.multiplexer import get_all_multiplexer_states, get_multiplexer from pr.multiplexer import get_all_multiplexer_states, get_multiplexer
from pr.tools.interactive_control import get_session_status
class BackgroundMonitor: class BackgroundMonitor:
def __init__(self, check_interval=5.0): def __init__(self, check_interval=5.0):
@ -17,7 +18,9 @@ class BackgroundMonitor:
"""Start the background monitoring thread.""" """Start the background monitoring thread."""
if self.monitor_thread is None: if self.monitor_thread is None:
self.active = True self.active = True
self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) self.monitor_thread = threading.Thread(
target=self._monitor_loop, daemon=True
)
self.monitor_thread.start() self.monitor_thread.start()
def stop(self): def stop(self):
@ -78,19 +81,18 @@ class BackgroundMonitor:
# Check for new sessions # Check for new sessions
for session_name in new_states: for session_name in new_states:
if session_name not in old_states: if session_name not in old_states:
events.append({ events.append(
'type': 'session_started', {
'session_name': session_name, "type": "session_started",
'metadata': new_states[session_name]['metadata'] "session_name": session_name,
}) "metadata": new_states[session_name]["metadata"],
}
)
# Check for ended sessions # Check for ended sessions
for session_name in old_states: for session_name in old_states:
if session_name not in new_states: if session_name not in new_states:
events.append({ events.append({"type": "session_ended", "session_name": session_name})
'type': 'session_ended',
'session_name': session_name
})
# Check for activity in existing sessions # Check for activity in existing sessions
for session_name, new_state in new_states.items(): for session_name, new_state in new_states.items():
@ -98,92 +100,112 @@ class BackgroundMonitor:
old_state = old_states[session_name] old_state = old_states[session_name]
# Check for output changes # Check for output changes
old_stdout_lines = old_state['output_summary']['stdout_lines'] old_stdout_lines = old_state["output_summary"]["stdout_lines"]
new_stdout_lines = new_state['output_summary']['stdout_lines'] new_stdout_lines = new_state["output_summary"]["stdout_lines"]
old_stderr_lines = old_state['output_summary']['stderr_lines'] old_stderr_lines = old_state["output_summary"]["stderr_lines"]
new_stderr_lines = new_state['output_summary']['stderr_lines'] new_stderr_lines = new_state["output_summary"]["stderr_lines"]
if new_stdout_lines > old_stdout_lines or new_stderr_lines > old_stderr_lines: if (
new_stdout_lines > old_stdout_lines
or new_stderr_lines > old_stderr_lines
):
# Get the new output # Get the new output
mux = get_multiplexer(session_name) mux = get_multiplexer(session_name)
if mux: if mux:
all_output = mux.get_all_output() all_output = mux.get_all_output()
new_output = { new_output = {
'stdout': all_output['stdout'].split('\n')[old_stdout_lines:], "stdout": all_output["stdout"].split("\n")[
'stderr': all_output['stderr'].split('\n')[old_stderr_lines:] old_stdout_lines:
],
"stderr": all_output["stderr"].split("\n")[
old_stderr_lines:
],
} }
events.append({ events.append(
'type': 'output_received', {
'session_name': session_name, "type": "output_received",
'new_output': new_output, "session_name": session_name,
'total_lines': { "new_output": new_output,
'stdout': new_stdout_lines, "total_lines": {
'stderr': new_stderr_lines "stdout": new_stdout_lines,
"stderr": new_stderr_lines,
},
} }
}) )
# Check for state changes # Check for state changes
old_metadata = old_state['metadata'] old_metadata = old_state["metadata"]
new_metadata = new_state['metadata'] new_metadata = new_state["metadata"]
if old_metadata.get('state') != new_metadata.get('state'): if old_metadata.get("state") != new_metadata.get("state"):
events.append({ events.append(
'type': 'state_changed', {
'session_name': session_name, "type": "state_changed",
'old_state': old_metadata.get('state'), "session_name": session_name,
'new_state': new_metadata.get('state') "old_state": old_metadata.get("state"),
}) "new_state": new_metadata.get("state"),
}
)
# Check for process type identification # Check for process type identification
if (old_metadata.get('process_type') == 'unknown' and if (
new_metadata.get('process_type') != 'unknown'): old_metadata.get("process_type") == "unknown"
events.append({ and new_metadata.get("process_type") != "unknown"
'type': 'process_identified', ):
'session_name': session_name, events.append(
'process_type': new_metadata.get('process_type') {
}) "type": "process_identified",
"session_name": session_name,
"process_type": new_metadata.get("process_type"),
}
)
# Check for sessions needing attention (based on heuristics) # Check for sessions needing attention (based on heuristics)
for session_name, state in new_states.items(): for session_name, state in new_states.items():
metadata = state['metadata'] metadata = state["metadata"]
output_summary = state['output_summary'] output_summary = state["output_summary"]
# Heuristic: High output volume might indicate completion or error # Heuristic: High output volume might indicate completion or error
total_lines = output_summary['stdout_lines'] + output_summary['stderr_lines'] total_lines = (
output_summary["stdout_lines"] + output_summary["stderr_lines"]
)
if total_lines > 100: # Arbitrary threshold if total_lines > 100: # Arbitrary threshold
events.append({ events.append(
'type': 'high_output_volume', {
'session_name': session_name, "type": "high_output_volume",
'total_lines': total_lines "session_name": session_name,
}) "total_lines": total_lines,
}
)
# Heuristic: Long-running session without recent activity # Heuristic: Long-running session without recent activity
time_since_activity = time.time() - metadata.get('last_activity', 0) time_since_activity = time.time() - metadata.get("last_activity", 0)
if time_since_activity > 300: # 5 minutes if time_since_activity > 300: # 5 minutes
events.append({ events.append(
'type': 'inactive_session', {
'session_name': session_name, "type": "inactive_session",
'inactive_seconds': time_since_activity "session_name": session_name,
}) "inactive_seconds": time_since_activity,
}
)
# Heuristic: Sessions that might be waiting for input # Heuristic: Sessions that might be waiting for input
# This would be enhanced with prompt detection in later phases # This would be enhanced with prompt detection in later phases
if self._might_be_waiting_for_input(session_name, state): if self._might_be_waiting_for_input(session_name, state):
events.append({ events.append(
'type': 'possible_input_needed', {"type": "possible_input_needed", "session_name": session_name}
'session_name': session_name )
})
return events return events
def _might_be_waiting_for_input(self, session_name, state): def _might_be_waiting_for_input(self, session_name, state):
"""Heuristic to detect if a session might be waiting for input.""" """Heuristic to detect if a session might be waiting for input."""
metadata = state['metadata'] metadata = state["metadata"]
process_type = metadata.get('process_type', 'unknown') metadata.get("process_type", "unknown")
# Simple heuristics based on process type and recent activity # Simple heuristics based on process type and recent activity
time_since_activity = time.time() - metadata.get('last_activity', 0) time_since_activity = time.time() - metadata.get("last_activity", 0)
# If it's been more than 10 seconds since last activity, might be waiting # If it's been more than 10 seconds since last activity, might be waiting
if time_since_activity > 10: if time_since_activity > 10:
@ -191,9 +213,11 @@ class BackgroundMonitor:
return False return False
# Global monitor instance # Global monitor instance
_global_monitor = None _global_monitor = None
def get_global_monitor(): def get_global_monitor():
"""Get the global background monitor instance.""" """Get the global background monitor instance."""
global _global_monitor global _global_monitor
@ -201,20 +225,24 @@ def get_global_monitor():
_global_monitor = BackgroundMonitor() _global_monitor = BackgroundMonitor()
return _global_monitor return _global_monitor
def start_global_monitor(): def start_global_monitor():
"""Start the global background monitor.""" """Start the global background monitor."""
monitor = get_global_monitor() monitor = get_global_monitor()
monitor.start() monitor.start()
def stop_global_monitor(): def stop_global_monitor():
"""Stop the global background monitor.""" """Stop the global background monitor."""
global _global_monitor global _global_monitor
if _global_monitor: if _global_monitor:
_global_monitor.stop() _global_monitor.stop()
# Global monitor instance # Global monitor instance
_global_monitor = None _global_monitor = None
def start_global_monitor(): def start_global_monitor():
"""Start the global background monitor.""" """Start the global background monitor."""
global _global_monitor global _global_monitor
@ -223,6 +251,7 @@ def start_global_monitor():
_global_monitor.start() _global_monitor.start()
return _global_monitor return _global_monitor
def stop_global_monitor(): def stop_global_monitor():
"""Stop the global background monitor.""" """Stop the global background monitor."""
global _global_monitor global _global_monitor
@ -230,6 +259,7 @@ def stop_global_monitor():
_global_monitor.stop() _global_monitor.stop()
_global_monitor = None _global_monitor = None
def get_global_monitor(): def get_global_monitor():
"""Get the global background monitor instance.""" """Get the global background monitor instance."""
global _global_monitor global _global_monitor

View File

@ -1,22 +1,17 @@
import os
import configparser import configparser
from typing import Dict, Any import os
from typing import Any, Dict
from pr.core.logging import get_logger from pr.core.logging import get_logger
logger = get_logger('config') logger = get_logger("config")
CONFIG_FILE = os.path.expanduser("~/.prrc") CONFIG_FILE = os.path.expanduser("~/.prrc")
LOCAL_CONFIG_FILE = ".prrc" LOCAL_CONFIG_FILE = ".prrc"
def load_config() -> Dict[str, Any]: def load_config() -> Dict[str, Any]:
config = { config = {"api": {}, "autonomous": {}, "ui": {}, "output": {}, "session": {}}
'api': {},
'autonomous': {},
'ui': {},
'output': {},
'session': {}
}
global_config = _load_config_file(CONFIG_FILE) global_config = _load_config_file(CONFIG_FILE)
local_config = _load_config_file(LOCAL_CONFIG_FILE) local_config = _load_config_file(LOCAL_CONFIG_FILE)
@ -55,9 +50,9 @@ def _load_config_file(filepath: str) -> Dict[str, Dict[str, Any]]:
def _parse_value(value: str) -> Any: def _parse_value(value: str) -> Any:
value = value.strip() value = value.strip()
if value.lower() == 'true': if value.lower() == "true":
return True return True
if value.lower() == 'false': if value.lower() == "false":
return False return False
if value.isdigit(): if value.isdigit():
@ -99,7 +94,7 @@ max_history = 1000
""" """
try: try:
with open(filepath, 'w') as f: with open(filepath, "w") as f:
f.write(default_config) f.write(default_config)
logger.info(f"Created default configuration at {filepath}") logger.info(f"Created default configuration at {filepath}")
return True return True

View File

@ -1,11 +1,21 @@
import os
import json import json
import logging import logging
from pr.config import (CONTEXT_FILE, GLOBAL_CONTEXT_FILE, CONTEXT_COMPRESSION_THRESHOLD, import os
RECENT_MESSAGES_TO_KEEP, MAX_TOKENS_LIMIT, CHARS_PER_TOKEN,
EMERGENCY_MESSAGES_TO_KEEP, CONTENT_TRIM_LENGTH, MAX_TOOL_RESULT_LENGTH) from pr.config import (
CHARS_PER_TOKEN,
CONTENT_TRIM_LENGTH,
CONTEXT_COMPRESSION_THRESHOLD,
CONTEXT_FILE,
EMERGENCY_MESSAGES_TO_KEEP,
GLOBAL_CONTEXT_FILE,
MAX_TOKENS_LIMIT,
MAX_TOOL_RESULT_LENGTH,
RECENT_MESSAGES_TO_KEEP,
)
from pr.ui import Colors from pr.ui import Colors
def truncate_tool_result(result, max_length=None): def truncate_tool_result(result, max_length=None):
if max_length is None: if max_length is None:
max_length = MAX_TOOL_RESULT_LENGTH max_length = MAX_TOOL_RESULT_LENGTH
@ -17,24 +27,36 @@ def truncate_tool_result(result, max_length=None):
if "output" in result_copy and isinstance(result_copy["output"], str): if "output" in result_copy and isinstance(result_copy["output"], str):
if len(result_copy["output"]) > max_length: if len(result_copy["output"]) > max_length:
result_copy["output"] = result_copy["output"][:max_length] + f"\n... [truncated {len(result_copy['output']) - max_length} chars]" result_copy["output"] = (
result_copy["output"][:max_length]
+ f"\n... [truncated {len(result_copy['output']) - max_length} chars]"
)
if "content" in result_copy and isinstance(result_copy["content"], str): if "content" in result_copy and isinstance(result_copy["content"], str):
if len(result_copy["content"]) > max_length: if len(result_copy["content"]) > max_length:
result_copy["content"] = result_copy["content"][:max_length] + f"\n... [truncated {len(result_copy['content']) - max_length} chars]" result_copy["content"] = (
result_copy["content"][:max_length]
+ f"\n... [truncated {len(result_copy['content']) - max_length} chars]"
)
if "data" in result_copy and isinstance(result_copy["data"], str): if "data" in result_copy and isinstance(result_copy["data"], str):
if len(result_copy["data"]) > max_length: if len(result_copy["data"]) > max_length:
result_copy["data"] = result_copy["data"][:max_length] + f"\n... [truncated]" result_copy["data"] = (
result_copy["data"][:max_length] + f"\n... [truncated]"
)
if "error" in result_copy and isinstance(result_copy["error"], str): if "error" in result_copy and isinstance(result_copy["error"], str):
if len(result_copy["error"]) > max_length // 2: if len(result_copy["error"]) > max_length // 2:
result_copy["error"] = result_copy["error"][:max_length // 2] + "... [truncated]" result_copy["error"] = (
result_copy["error"][: max_length // 2] + "... [truncated]"
)
return result_copy return result_copy
def init_system_message(args): def init_system_message(args):
context_parts = ["""You are a professional AI assistant with access to advanced tools. context_parts = [
"""You are a professional AI assistant with access to advanced tools.
File Operations: File Operations:
- Use RPEditor tools (open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor) for precise file modifications - Use RPEditor tools (open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor) for precise file modifications
@ -51,14 +73,15 @@ Process Management:
Shell Commands: Shell Commands:
- Be a shell ninja using native OS tools - Be a shell ninja using native OS tools
- Prefer standard Unix utilities over complex scripts - Prefer standard Unix utilities over complex scripts
- Use run_command_interactive for commands requiring user input (vim, nano, etc.)"""] - Use run_command_interactive for commands requiring user input (vim, nano, etc.)"""
#context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."] ]
# context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."]
max_context_size = 10000 max_context_size = 10000
if args.include_env: if args.include_env:
env_context = "Environment Variables:\n" env_context = "Environment Variables:\n"
for key, value in os.environ.items(): for key, value in os.environ.items():
if not key.startswith('_'): if not key.startswith("_"):
env_context += f"{key}={value}\n" env_context += f"{key}={value}\n"
if len(env_context) > max_context_size: if len(env_context) > max_context_size:
env_context = env_context[:max_context_size] + "\n... [truncated]" env_context = env_context[:max_context_size] + "\n... [truncated]"
@ -67,7 +90,7 @@ Shell Commands:
for context_file in [CONTEXT_FILE, GLOBAL_CONTEXT_FILE]: for context_file in [CONTEXT_FILE, GLOBAL_CONTEXT_FILE]:
if os.path.exists(context_file): if os.path.exists(context_file):
try: try:
with open(context_file, 'r') as f: with open(context_file) as f:
content = f.read() content = f.read()
if len(content) > max_context_size: if len(content) > max_context_size:
content = content[:max_context_size] + "\n... [truncated]" content = content[:max_context_size] + "\n... [truncated]"
@ -78,7 +101,7 @@ Shell Commands:
if args.context: if args.context:
for ctx_file in args.context: for ctx_file in args.context:
try: try:
with open(ctx_file, 'r') as f: with open(ctx_file) as f:
content = f.read() content = f.read()
if len(content) > max_context_size: if len(content) > max_context_size:
content = content[:max_context_size] + "\n... [truncated]" content = content[:max_context_size] + "\n... [truncated]"
@ -88,22 +111,29 @@ Shell Commands:
system_message = "\n\n".join(context_parts) system_message = "\n\n".join(context_parts)
if len(system_message) > max_context_size * 3: if len(system_message) > max_context_size * 3:
system_message = system_message[:max_context_size * 3] + "\n... [system message truncated]" system_message = (
system_message[: max_context_size * 3] + "\n... [system message truncated]"
)
return {"role": "system", "content": system_message} return {"role": "system", "content": system_message}
def should_compress_context(messages): def should_compress_context(messages):
return len(messages) > CONTEXT_COMPRESSION_THRESHOLD return len(messages) > CONTEXT_COMPRESSION_THRESHOLD
def compress_context(messages): def compress_context(messages):
return manage_context_window(messages, verbose=False) return manage_context_window(messages, verbose=False)
def manage_context_window(messages, verbose): def manage_context_window(messages, verbose):
if len(messages) <= CONTEXT_COMPRESSION_THRESHOLD: if len(messages) <= CONTEXT_COMPRESSION_THRESHOLD:
return messages return messages
if verbose: if verbose:
print(f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}") print(
f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}"
)
system_message = messages[0] system_message = messages[0]
recent_messages = messages[-RECENT_MESSAGES_TO_KEEP:] recent_messages = messages[-RECENT_MESSAGES_TO_KEEP:]
@ -113,18 +143,21 @@ def manage_context_window(messages, verbose):
summary = summarize_messages(middle_messages) summary = summarize_messages(middle_messages)
summary_message = { summary_message = {
"role": "system", "role": "system",
"content": f"[Previous conversation summary: {summary}]" "content": f"[Previous conversation summary: {summary}]",
} }
new_messages = [system_message, summary_message] + recent_messages new_messages = [system_message, summary_message] + recent_messages
if verbose: if verbose:
print(f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}") print(
f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}"
)
return new_messages return new_messages
return messages return messages
def summarize_messages(messages): def summarize_messages(messages):
summary_parts = [] summary_parts = []
@ -142,6 +175,7 @@ def summarize_messages(messages):
return " | ".join(summary_parts[:10]) return " | ".join(summary_parts[:10])
def estimate_tokens(messages): def estimate_tokens(messages):
total_chars = 0 total_chars = 0
@ -155,6 +189,7 @@ def estimate_tokens(messages):
return int(estimated_tokens * overhead_multiplier) return int(estimated_tokens * overhead_multiplier)
def trim_message_content(message, max_length): def trim_message_content(message, max_length):
trimmed_msg = message.copy() trimmed_msg = message.copy()
@ -162,14 +197,22 @@ def trim_message_content(message, max_length):
content = trimmed_msg["content"] content = trimmed_msg["content"]
if isinstance(content, str) and len(content) > max_length: if isinstance(content, str) and len(content) > max_length:
trimmed_msg["content"] = content[:max_length] + f"\n... [trimmed {len(content) - max_length} chars]" trimmed_msg["content"] = (
content[:max_length]
+ f"\n... [trimmed {len(content) - max_length} chars]"
)
elif isinstance(content, list): elif isinstance(content, list):
trimmed_content = [] trimmed_content = []
for item in content: for item in content:
if isinstance(item, dict): if isinstance(item, dict):
trimmed_item = item.copy() trimmed_item = item.copy()
if "text" in trimmed_item and len(trimmed_item["text"]) > max_length: if (
trimmed_item["text"] = trimmed_item["text"][:max_length] + f"\n... [trimmed]" "text" in trimmed_item
and len(trimmed_item["text"]) > max_length
):
trimmed_item["text"] = (
trimmed_item["text"][:max_length] + f"\n... [trimmed]"
)
trimmed_content.append(trimmed_item) trimmed_content.append(trimmed_item)
else: else:
trimmed_content.append(item) trimmed_content.append(item)
@ -179,35 +222,61 @@ def trim_message_content(message, max_length):
if "content" in trimmed_msg and isinstance(trimmed_msg["content"], str): if "content" in trimmed_msg and isinstance(trimmed_msg["content"], str):
content = trimmed_msg["content"] content = trimmed_msg["content"]
if len(content) > MAX_TOOL_RESULT_LENGTH: if len(content) > MAX_TOOL_RESULT_LENGTH:
trimmed_msg["content"] = content[:MAX_TOOL_RESULT_LENGTH] + f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]" trimmed_msg["content"] = (
content[:MAX_TOOL_RESULT_LENGTH]
+ f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]"
)
try: try:
parsed = json.loads(content) parsed = json.loads(content)
if isinstance(parsed, dict): if isinstance(parsed, dict):
if "output" in parsed and isinstance(parsed["output"], str) and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2: if (
parsed["output"] = parsed["output"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]" "output" in parsed
if "content" in parsed and isinstance(parsed["content"], str) and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2: and isinstance(parsed["output"], str)
parsed["content"] = parsed["content"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]" and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2
):
parsed["output"] = (
parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2]
+ f"\n... [truncated]"
)
if (
"content" in parsed
and isinstance(parsed["content"], str)
and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2
):
parsed["content"] = (
parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2]
+ f"\n... [truncated]"
)
trimmed_msg["content"] = json.dumps(parsed) trimmed_msg["content"] = json.dumps(parsed)
except: except:
pass pass
return trimmed_msg return trimmed_msg
def intelligently_trim_messages(messages, target_tokens, keep_recent=3): def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
if estimate_tokens(messages) <= target_tokens: if estimate_tokens(messages) <= target_tokens:
return messages return messages
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None system_msg = (
messages[0] if messages and messages[0].get("role") == "system" else None
)
start_idx = 1 if system_msg else 0 start_idx = 1 if system_msg else 0
recent_messages = messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:] recent_messages = (
middle_messages = messages[start_idx:-keep_recent] if len(messages) > keep_recent else [] messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:]
)
middle_messages = (
messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
)
trimmed_middle = [] trimmed_middle = []
for msg in middle_messages: for msg in middle_messages:
if msg.get("role") == "tool": if msg.get("role") == "tool":
trimmed_middle.append(trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2)) trimmed_middle.append(
trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2)
)
elif msg.get("role") in ["user", "assistant"]: elif msg.get("role") in ["user", "assistant"]:
trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH)) trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH))
else: else:
@ -233,6 +302,7 @@ def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
return ([system_msg] if system_msg else []) + messages[-1:] return ([system_msg] if system_msg else []) + messages[-1:]
def auto_slim_messages(messages, verbose=False): def auto_slim_messages(messages, verbose=False):
estimated_tokens = estimate_tokens(messages) estimated_tokens = estimate_tokens(messages)
@ -240,29 +310,46 @@ def auto_slim_messages(messages, verbose=False):
return messages return messages
if verbose: if verbose:
print(f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}") print(
print(f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}") f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}"
)
print(
f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}"
)
result = intelligently_trim_messages(messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP) result = intelligently_trim_messages(
messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP
)
final_tokens = estimate_tokens(result) final_tokens = estimate_tokens(result)
if final_tokens > MAX_TOKENS_LIMIT: if final_tokens > MAX_TOKENS_LIMIT:
if verbose: if verbose:
print(f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}") print(
f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}"
)
result = emergency_reduce_messages(result, MAX_TOKENS_LIMIT, verbose) result = emergency_reduce_messages(result, MAX_TOKENS_LIMIT, verbose)
final_tokens = estimate_tokens(result) final_tokens = estimate_tokens(result)
if verbose: if verbose:
removed_count = len(messages) - len(result) removed_count = len(messages) - len(result)
print(f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}") print(
print(f"{Colors.GREEN} Token estimate: {estimated_tokens}{final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}") f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}"
)
print(
f"{Colors.GREEN} Token estimate: {estimated_tokens}{final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}"
)
if removed_count > 0: if removed_count > 0:
print(f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}") print(
f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}"
)
return result return result
def emergency_reduce_messages(messages, target_tokens, verbose=False): def emergency_reduce_messages(messages, target_tokens, verbose=False):
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None system_msg = (
messages[0] if messages and messages[0].get("role") == "system" else None
)
start_idx = 1 if system_msg else 0 start_idx = 1 if system_msg else 0
keep_count = 2 keep_count = 2

View File

@ -1,22 +1,29 @@
import logging
import json import json
import logging
import uuid import uuid
from typing import Optional, Dict, Any, List from typing import Any, Dict, List, Optional
from pr.config import (
DB_PATH, CACHE_ENABLED, API_CACHE_TTL, TOOL_CACHE_TTL,
WORKFLOW_EXECUTOR_MAX_WORKERS, AGENT_MAX_WORKERS,
KNOWLEDGE_SEARCH_LIMIT, ADVANCED_CONTEXT_ENABLED,
MEMORY_AUTO_SUMMARIZE, CONVERSATION_SUMMARY_THRESHOLD
)
from pr.cache import APICache, ToolCache
from pr.workflows import WorkflowEngine, WorkflowStorage
from pr.agents import AgentManager from pr.agents import AgentManager
from pr.memory import KnowledgeStore, ConversationMemory, FactExtractor from pr.cache import APICache, ToolCache
from pr.config import (
ADVANCED_CONTEXT_ENABLED,
API_CACHE_TTL,
CACHE_ENABLED,
CONVERSATION_SUMMARY_THRESHOLD,
DB_PATH,
KNOWLEDGE_SEARCH_LIMIT,
MEMORY_AUTO_SUMMARIZE,
TOOL_CACHE_TTL,
WORKFLOW_EXECUTOR_MAX_WORKERS,
)
from pr.core.advanced_context import AdvancedContextManager from pr.core.advanced_context import AdvancedContextManager
from pr.core.api import call_api from pr.core.api import call_api
from pr.memory import ConversationMemory, FactExtractor, KnowledgeStore
from pr.tools.base import get_tools_definition from pr.tools.base import get_tools_definition
from pr.workflows import WorkflowEngine, WorkflowStorage
logger = logging.getLogger("pr")
logger = logging.getLogger('pr')
class EnhancedAssistant: class EnhancedAssistant:
def __init__(self, base_assistant): def __init__(self, base_assistant):
@ -32,7 +39,7 @@ class EnhancedAssistant:
self.workflow_storage = WorkflowStorage(DB_PATH) self.workflow_storage = WorkflowStorage(DB_PATH)
self.workflow_engine = WorkflowEngine( self.workflow_engine = WorkflowEngine(
tool_executor=self._execute_tool_for_workflow, tool_executor=self._execute_tool_for_workflow,
max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS,
) )
self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent) self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent)
@ -44,20 +51,21 @@ class EnhancedAssistant:
if ADVANCED_CONTEXT_ENABLED: if ADVANCED_CONTEXT_ENABLED:
self.context_manager = AdvancedContextManager( self.context_manager = AdvancedContextManager(
knowledge_store=self.knowledge_store, knowledge_store=self.knowledge_store,
conversation_memory=self.conversation_memory conversation_memory=self.conversation_memory,
) )
else: else:
self.context_manager = None self.context_manager = None
self.current_conversation_id = str(uuid.uuid4())[:16] self.current_conversation_id = str(uuid.uuid4())[:16]
self.conversation_memory.create_conversation( self.conversation_memory.create_conversation(
self.current_conversation_id, self.current_conversation_id, session_id=str(uuid.uuid4())[:16]
session_id=str(uuid.uuid4())[:16]
) )
logger.info("Enhanced Assistant initialized with all features") logger.info("Enhanced Assistant initialized with all features")
def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any: def _execute_tool_for_workflow(
self, tool_name: str, arguments: Dict[str, Any]
) -> Any:
if self.tool_cache: if self.tool_cache:
cached_result = self.tool_cache.get(tool_name, arguments) cached_result = self.tool_cache.get(tool_name, arguments)
if cached_result is not None: if cached_result is not None:
@ -65,41 +73,66 @@ class EnhancedAssistant:
return cached_result return cached_result
func_map = { func_map = {
'read_file': lambda **kw: self.base.execute_tool_calls([{ "read_file": lambda **kw: self.base.execute_tool_calls(
'id': 'temp', [
'function': {'name': 'read_file', 'arguments': json.dumps(kw)} {
}])[0], "id": "temp",
'write_file': lambda **kw: self.base.execute_tool_calls([{ "function": {"name": "read_file", "arguments": json.dumps(kw)},
'id': 'temp', }
'function': {'name': 'write_file', 'arguments': json.dumps(kw)} ]
}])[0], )[0],
'list_directory': lambda **kw: self.base.execute_tool_calls([{ "write_file": lambda **kw: self.base.execute_tool_calls(
'id': 'temp', [
'function': {'name': 'list_directory', 'arguments': json.dumps(kw)} {
}])[0], "id": "temp",
'run_command': lambda **kw: self.base.execute_tool_calls([{ "function": {"name": "write_file", "arguments": json.dumps(kw)},
'id': 'temp', }
'function': {'name': 'run_command', 'arguments': json.dumps(kw)} ]
}])[0], )[0],
"list_directory": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {
"name": "list_directory",
"arguments": json.dumps(kw),
},
}
]
)[0],
"run_command": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {
"name": "run_command",
"arguments": json.dumps(kw),
},
}
]
)[0],
} }
if tool_name in func_map: if tool_name in func_map:
result = func_map[tool_name](**arguments) result = func_map[tool_name](**arguments)
if self.tool_cache: if self.tool_cache:
content = result.get('content', '') content = result.get("content", "")
try: try:
parsed_content = json.loads(content) if isinstance(content, str) else content parsed_content = (
json.loads(content) if isinstance(content, str) else content
)
self.tool_cache.set(tool_name, arguments, parsed_content) self.tool_cache.set(tool_name, arguments, parsed_content)
except Exception: except Exception:
pass pass
return result return result
return {'error': f'Unknown tool: {tool_name}'} return {"error": f"Unknown tool: {tool_name}"}
def _api_caller_for_agent(self, messages: List[Dict[str, Any]], def _api_caller_for_agent(
temperature: float, max_tokens: int) -> Dict[str, Any]: self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int
) -> Dict[str, Any]:
return call_api( return call_api(
messages, messages,
self.base.model, self.base.model,
@ -109,15 +142,12 @@ class EnhancedAssistant:
tools=None, tools=None,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
verbose=self.base.verbose verbose=self.base.verbose,
) )
def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]: def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
if self.api_cache and CACHE_ENABLED: if self.api_cache and CACHE_ENABLED:
cached_response = self.api_cache.get( cached_response = self.api_cache.get(self.base.model, messages, 0.7, 4096)
self.base.model, messages,
0.7, 4096
)
if cached_response: if cached_response:
logger.debug("API cache hit") logger.debug("API cache hit")
return cached_response return cached_response
@ -129,15 +159,13 @@ class EnhancedAssistant:
self.base.api_key, self.base.api_key,
self.base.use_tools, self.base.use_tools,
get_tools_definition(), get_tools_definition(),
verbose=self.base.verbose verbose=self.base.verbose,
) )
if self.api_cache and CACHE_ENABLED and 'error' not in response: if self.api_cache and CACHE_ENABLED and "error" not in response:
token_count = response.get('usage', {}).get('total_tokens', 0) token_count = response.get("usage", {}).get("total_tokens", 0)
self.api_cache.set( self.api_cache.set(
self.base.model, messages, self.base.model, messages, 0.7, 4096, response, token_count
0.7, 4096,
response, token_count
) )
return response return response
@ -146,35 +174,33 @@ class EnhancedAssistant:
self.base.messages.append({"role": "user", "content": user_message}) self.base.messages.append({"role": "user", "content": user_message})
self.conversation_memory.add_message( self.conversation_memory.add_message(
self.current_conversation_id, self.current_conversation_id, str(uuid.uuid4())[:16], "user", user_message
str(uuid.uuid4())[:16],
'user',
user_message
) )
if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0: if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0:
facts = self.fact_extractor.extract_facts(user_message) facts = self.fact_extractor.extract_facts(user_message)
for fact in facts[:3]: for fact in facts[:3]:
entry_id = str(uuid.uuid4())[:16] entry_id = str(uuid.uuid4())[:16]
from pr.memory import KnowledgeEntry
import time import time
categories = self.fact_extractor.categorize_content(fact['text']) from pr.memory import KnowledgeEntry
categories = self.fact_extractor.categorize_content(fact["text"])
entry = KnowledgeEntry( entry = KnowledgeEntry(
entry_id=entry_id, entry_id=entry_id,
category=categories[0] if categories else 'general', category=categories[0] if categories else "general",
content=fact['text'], content=fact["text"],
metadata={'type': fact['type'], 'confidence': fact['confidence']}, metadata={"type": fact["type"], "confidence": fact["confidence"]},
created_at=time.time(), created_at=time.time(),
updated_at=time.time() updated_at=time.time(),
) )
self.knowledge_store.add_entry(entry) self.knowledge_store.add_entry(entry)
if self.context_manager and ADVANCED_CONTEXT_ENABLED: if self.context_manager and ADVANCED_CONTEXT_ENABLED:
enhanced_messages, context_info = self.context_manager.create_enhanced_context( enhanced_messages, context_info = (
self.base.messages, self.context_manager.create_enhanced_context(
user_message, self.base.messages, user_message, include_knowledge=True
include_knowledge=True )
) )
if self.base.verbose: if self.base.verbose:
@ -189,38 +215,40 @@ class EnhancedAssistant:
result = self.base.process_response(response) result = self.base.process_response(response)
if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD: if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD:
summary = self.context_manager.advanced_summarize_messages( summary = (
self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:] self.context_manager.advanced_summarize_messages(
) if self.context_manager else "Conversation in progress" self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
)
if self.context_manager
else "Conversation in progress"
)
topics = self.fact_extractor.categorize_content(summary) topics = self.fact_extractor.categorize_content(summary)
self.conversation_memory.update_conversation_summary( self.conversation_memory.update_conversation_summary(
self.current_conversation_id, self.current_conversation_id, summary, topics
summary,
topics
) )
return result return result
def execute_workflow(self, workflow_name: str, def execute_workflow(
initial_variables: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: self, workflow_name: str, initial_variables: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
workflow = self.workflow_storage.load_workflow_by_name(workflow_name) workflow = self.workflow_storage.load_workflow_by_name(workflow_name)
if not workflow: if not workflow:
return {'error': f'Workflow "{workflow_name}" not found'} return {"error": f'Workflow "{workflow_name}" not found'}
context = self.workflow_engine.execute_workflow(workflow, initial_variables) context = self.workflow_engine.execute_workflow(workflow, initial_variables)
execution_id = self.workflow_storage.save_execution( execution_id = self.workflow_storage.save_execution(
self.workflow_storage.load_workflow_by_name(workflow_name).name, self.workflow_storage.load_workflow_by_name(workflow_name).name, context
context
) )
return { return {
'success': True, "success": True,
'execution_id': execution_id, "execution_id": execution_id,
'results': context.step_results, "results": context.step_results,
'execution_log': context.execution_log "execution_log": context.execution_log,
} }
def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str: def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
@ -230,20 +258,22 @@ class EnhancedAssistant:
return self.agent_manager.execute_agent_task(agent_id, task) return self.agent_manager.execute_agent_task(agent_id, task)
def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]: def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]:
orchestrator_id = self.agent_manager.create_agent('orchestrator') orchestrator_id = self.agent_manager.create_agent("orchestrator")
return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles) return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)
def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]: def search_knowledge(
self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT
) -> List[Any]:
return self.knowledge_store.search_entries(query, top_k=limit) return self.knowledge_store.search_entries(query, top_k=limit)
def get_cache_statistics(self) -> Dict[str, Any]: def get_cache_statistics(self) -> Dict[str, Any]:
stats = {} stats = {}
if self.api_cache: if self.api_cache:
stats['api_cache'] = self.api_cache.get_statistics() stats["api_cache"] = self.api_cache.get_statistics()
if self.tool_cache: if self.tool_cache:
stats['tool_cache'] = self.tool_cache.get_statistics() stats["tool_cache"] = self.tool_cache.get_statistics()
return stats return stats

View File

@ -1,6 +1,7 @@
import logging import logging
import os import os
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
from pr.config import LOG_FILE from pr.config import LOG_FILE
@ -9,21 +10,19 @@ def setup_logging(verbose=False):
if log_dir and not os.path.exists(log_dir): if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger('pr') logger = logging.getLogger("pr")
logger.setLevel(logging.DEBUG if verbose else logging.INFO) logger.setLevel(logging.DEBUG if verbose else logging.INFO)
if logger.handlers: if logger.handlers:
logger.handlers.clear() logger.handlers.clear()
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(
LOG_FILE, LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5
maxBytes=10 * 1024 * 1024,
backupCount=5
) )
file_handler.setLevel(logging.DEBUG) file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter( file_formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s', "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
datefmt='%Y-%m-%d %H:%M:%S' datefmt="%Y-%m-%d %H:%M:%S",
) )
file_handler.setFormatter(file_formatter) file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler) logger.addHandler(file_handler)
@ -31,9 +30,7 @@ def setup_logging(verbose=False):
if verbose: if verbose:
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter( console_formatter = logging.Formatter("%(levelname)s: %(message)s")
'%(levelname)s: %(message)s'
)
console_handler.setFormatter(console_formatter) console_handler.setFormatter(console_formatter)
logger.addHandler(console_handler) logger.addHandler(console_handler)
@ -42,5 +39,5 @@ def setup_logging(verbose=False):
def get_logger(name=None): def get_logger(name=None):
if name: if name:
return logging.getLogger(f'pr.{name}') return logging.getLogger(f"pr.{name}")
return logging.getLogger('pr') return logging.getLogger("pr")

View File

@ -2,9 +2,10 @@ import json
import os import os
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pr.core.logging import get_logger from pr.core.logging import get_logger
logger = get_logger('session') logger = get_logger("session")
SESSIONS_DIR = os.path.expanduser("~/.assistant_sessions") SESSIONS_DIR = os.path.expanduser("~/.assistant_sessions")
@ -14,18 +15,20 @@ class SessionManager:
def __init__(self): def __init__(self):
os.makedirs(SESSIONS_DIR, exist_ok=True) os.makedirs(SESSIONS_DIR, exist_ok=True)
def save_session(self, name: str, messages: List[Dict], metadata: Optional[Dict] = None) -> bool: def save_session(
self, name: str, messages: List[Dict], metadata: Optional[Dict] = None
) -> bool:
try: try:
session_file = os.path.join(SESSIONS_DIR, f"{name}.json") session_file = os.path.join(SESSIONS_DIR, f"{name}.json")
session_data = { session_data = {
'name': name, "name": name,
'created_at': datetime.now().isoformat(), "created_at": datetime.now().isoformat(),
'messages': messages, "messages": messages,
'metadata': metadata or {} "metadata": metadata or {},
} }
with open(session_file, 'w') as f: with open(session_file, "w") as f:
json.dump(session_data, f, indent=2) json.dump(session_data, f, indent=2)
logger.info(f"Session saved: {name}") logger.info(f"Session saved: {name}")
@ -43,7 +46,7 @@ class SessionManager:
logger.warning(f"Session not found: {name}") logger.warning(f"Session not found: {name}")
return None return None
with open(session_file, 'r') as f: with open(session_file) as f:
session_data = json.load(f) session_data = json.load(f)
logger.info(f"Session loaded: {name}") logger.info(f"Session loaded: {name}")
@ -58,22 +61,24 @@ class SessionManager:
try: try:
for filename in os.listdir(SESSIONS_DIR): for filename in os.listdir(SESSIONS_DIR):
if filename.endswith('.json'): if filename.endswith(".json"):
filepath = os.path.join(SESSIONS_DIR, filename) filepath = os.path.join(SESSIONS_DIR, filename)
try: try:
with open(filepath, 'r') as f: with open(filepath) as f:
data = json.load(f) data = json.load(f)
sessions.append({ sessions.append(
'name': data.get('name', filename[:-5]), {
'created_at': data.get('created_at', 'unknown'), "name": data.get("name", filename[:-5]),
'message_count': len(data.get('messages', [])), "created_at": data.get("created_at", "unknown"),
'metadata': data.get('metadata', {}) "message_count": len(data.get("messages", [])),
}) "metadata": data.get("metadata", {}),
}
)
except Exception as e: except Exception as e:
logger.warning(f"Error reading session file {filename}: {e}") logger.warning(f"Error reading session file {filename}: {e}")
sessions.sort(key=lambda x: x['created_at'], reverse=True) sessions.sort(key=lambda x: x["created_at"], reverse=True)
except Exception as e: except Exception as e:
logger.error(f"Error listing sessions: {e}") logger.error(f"Error listing sessions: {e}")
@ -96,39 +101,39 @@ class SessionManager:
logger.error(f"Error deleting session {name}: {e}") logger.error(f"Error deleting session {name}: {e}")
return False return False
def export_session(self, name: str, output_path: str, format: str = 'json') -> bool: def export_session(self, name: str, output_path: str, format: str = "json") -> bool:
session_data = self.load_session(name) session_data = self.load_session(name)
if not session_data: if not session_data:
return False return False
try: try:
if format == 'json': if format == "json":
with open(output_path, 'w') as f: with open(output_path, "w") as f:
json.dump(session_data, f, indent=2) json.dump(session_data, f, indent=2)
elif format == 'markdown': elif format == "markdown":
with open(output_path, 'w') as f: with open(output_path, "w") as f:
f.write(f"# Session: {name}\n\n") f.write(f"# Session: {name}\n\n")
f.write(f"Created: {session_data['created_at']}\n\n") f.write(f"Created: {session_data['created_at']}\n\n")
f.write("---\n\n") f.write("---\n\n")
for msg in session_data['messages']: for msg in session_data["messages"]:
role = msg.get('role', 'unknown') role = msg.get("role", "unknown")
content = msg.get('content', '') content = msg.get("content", "")
f.write(f"## {role.capitalize()}\n\n") f.write(f"## {role.capitalize()}\n\n")
f.write(f"{content}\n\n") f.write(f"{content}\n\n")
f.write("---\n\n") f.write("---\n\n")
elif format == 'txt': elif format == "txt":
with open(output_path, 'w') as f: with open(output_path, "w") as f:
f.write(f"Session: {name}\n") f.write(f"Session: {name}\n")
f.write(f"Created: {session_data['created_at']}\n") f.write(f"Created: {session_data['created_at']}\n")
f.write("=" * 80 + "\n\n") f.write("=" * 80 + "\n\n")
for msg in session_data['messages']: for msg in session_data["messages"]:
role = msg.get('role', 'unknown') role = msg.get("role", "unknown")
content = msg.get('content', '') content = msg.get("content", "")
f.write(f"[{role.upper()}]\n") f.write(f"[{role.upper()}]\n")
f.write(f"{content}\n") f.write(f"{content}\n")

View File

@ -2,20 +2,21 @@ import json
import os import os
from datetime import datetime from datetime import datetime
from typing import Dict, Optional from typing import Dict, Optional
from pr.core.logging import get_logger from pr.core.logging import get_logger
logger = get_logger('usage') logger = get_logger("usage")
USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json") USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json")
MODEL_COSTS = { MODEL_COSTS = {
'x-ai/grok-code-fast-1': {'input': 0.0, 'output': 0.0}, "x-ai/grok-code-fast-1": {"input": 0.0, "output": 0.0},
'gpt-4': {'input': 0.03, 'output': 0.06}, "gpt-4": {"input": 0.03, "output": 0.06},
'gpt-4-turbo': {'input': 0.01, 'output': 0.03}, "gpt-4-turbo": {"input": 0.01, "output": 0.03},
'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015}, "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
'claude-3-opus': {'input': 0.015, 'output': 0.075}, "claude-3-opus": {"input": 0.015, "output": 0.075},
'claude-3-sonnet': {'input': 0.003, 'output': 0.015}, "claude-3-sonnet": {"input": 0.003, "output": 0.015},
'claude-3-haiku': {'input': 0.00025, 'output': 0.00125}, "claude-3-haiku": {"input": 0.00025, "output": 0.00125},
} }
@ -23,12 +24,12 @@ class UsageTracker:
def __init__(self): def __init__(self):
self.session_usage = { self.session_usage = {
'requests': 0, "requests": 0,
'total_tokens': 0, "total_tokens": 0,
'input_tokens': 0, "input_tokens": 0,
'output_tokens': 0, "output_tokens": 0,
'estimated_cost': 0.0, "estimated_cost": 0.0,
'models_used': {} "models_used": {},
} }
def track_request( def track_request(
@ -36,30 +37,30 @@ class UsageTracker:
model: str, model: str,
input_tokens: int, input_tokens: int,
output_tokens: int, output_tokens: int,
total_tokens: Optional[int] = None total_tokens: Optional[int] = None,
): ):
if total_tokens is None: if total_tokens is None:
total_tokens = input_tokens + output_tokens total_tokens = input_tokens + output_tokens
self.session_usage['requests'] += 1 self.session_usage["requests"] += 1
self.session_usage['total_tokens'] += total_tokens self.session_usage["total_tokens"] += total_tokens
self.session_usage['input_tokens'] += input_tokens self.session_usage["input_tokens"] += input_tokens
self.session_usage['output_tokens'] += output_tokens self.session_usage["output_tokens"] += output_tokens
if model not in self.session_usage['models_used']: if model not in self.session_usage["models_used"]:
self.session_usage['models_used'][model] = { self.session_usage["models_used"][model] = {
'requests': 0, "requests": 0,
'tokens': 0, "tokens": 0,
'cost': 0.0 "cost": 0.0,
} }
model_usage = self.session_usage['models_used'][model] model_usage = self.session_usage["models_used"][model]
model_usage['requests'] += 1 model_usage["requests"] += 1
model_usage['tokens'] += total_tokens model_usage["tokens"] += total_tokens
cost = self._calculate_cost(model, input_tokens, output_tokens) cost = self._calculate_cost(model, input_tokens, output_tokens)
model_usage['cost'] += cost model_usage["cost"] += cost
self.session_usage['estimated_cost'] += cost self.session_usage["estimated_cost"] += cost
self._save_to_history(model, input_tokens, output_tokens, cost) self._save_to_history(model, input_tokens, output_tokens, cost)
@ -67,9 +68,11 @@ class UsageTracker:
f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}" f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}"
) )
def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float: def _calculate_cost(
self, model: str, input_tokens: int, output_tokens: int
) -> float:
if model not in MODEL_COSTS: if model not in MODEL_COSTS:
base_model = model.split('/')[0] if '/' in model else model base_model = model.split("/")[0] if "/" in model else model
if base_model not in MODEL_COSTS: if base_model not in MODEL_COSTS:
logger.warning(f"Unknown model for cost calculation: {model}") logger.warning(f"Unknown model for cost calculation: {model}")
return 0.0 return 0.0
@ -77,31 +80,35 @@ class UsageTracker:
else: else:
costs = MODEL_COSTS[model] costs = MODEL_COSTS[model]
input_cost = (input_tokens / 1000) * costs['input'] input_cost = (input_tokens / 1000) * costs["input"]
output_cost = (output_tokens / 1000) * costs['output'] output_cost = (output_tokens / 1000) * costs["output"]
return input_cost + output_cost return input_cost + output_cost
def _save_to_history(self, model: str, input_tokens: int, output_tokens: int, cost: float): def _save_to_history(
self, model: str, input_tokens: int, output_tokens: int, cost: float
):
try: try:
history = [] history = []
if os.path.exists(USAGE_DB_FILE): if os.path.exists(USAGE_DB_FILE):
with open(USAGE_DB_FILE, 'r') as f: with open(USAGE_DB_FILE) as f:
history = json.load(f) history = json.load(f)
history.append({ history.append(
'timestamp': datetime.now().isoformat(), {
'model': model, "timestamp": datetime.now().isoformat(),
'input_tokens': input_tokens, "model": model,
'output_tokens': output_tokens, "input_tokens": input_tokens,
'total_tokens': input_tokens + output_tokens, "output_tokens": output_tokens,
'cost': cost "total_tokens": input_tokens + output_tokens,
}) "cost": cost,
}
)
if len(history) > 10000: if len(history) > 10000:
history = history[-10000:] history = history[-10000:]
with open(USAGE_DB_FILE, 'w') as f: with open(USAGE_DB_FILE, "w") as f:
json.dump(history, f, indent=2) json.dump(history, f, indent=2)
except Exception as e: except Exception as e:
@ -121,42 +128,34 @@ class UsageTracker:
f"Estimated Cost: ${usage['estimated_cost']:.4f}", f"Estimated Cost: ${usage['estimated_cost']:.4f}",
] ]
if usage['models_used']: if usage["models_used"]:
lines.append("\nModels Used:") lines.append("\nModels Used:")
for model, stats in usage['models_used'].items(): for model, stats in usage["models_used"].items():
lines.append( lines.append(
f" {model}: {stats['requests']} requests, " f" {model}: {stats['requests']} requests, "
f"{stats['tokens']:,} tokens, ${stats['cost']:.4f}" f"{stats['tokens']:,} tokens, ${stats['cost']:.4f}"
) )
return '\n'.join(lines) return "\n".join(lines)
@staticmethod @staticmethod
def get_total_usage() -> Dict: def get_total_usage() -> Dict:
if not os.path.exists(USAGE_DB_FILE): if not os.path.exists(USAGE_DB_FILE):
return { return {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
'total_requests': 0,
'total_tokens': 0,
'total_cost': 0.0
}
try: try:
with open(USAGE_DB_FILE, 'r') as f: with open(USAGE_DB_FILE) as f:
history = json.load(f) history = json.load(f)
total_tokens = sum(entry['total_tokens'] for entry in history) total_tokens = sum(entry["total_tokens"] for entry in history)
total_cost = sum(entry['cost'] for entry in history) total_cost = sum(entry["cost"] for entry in history)
return { return {
'total_requests': len(history), "total_requests": len(history),
'total_tokens': total_tokens, "total_tokens": total_tokens,
'total_cost': total_cost "total_cost": total_cost,
} }
except Exception as e: except Exception as e:
logger.error(f"Error loading usage history: {e}") logger.error(f"Error loading usage history: {e}")
return { return {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
'total_requests': 0,
'total_tokens': 0,
'total_cost': 0.0
}

View File

@ -1,5 +1,5 @@
import os import os
from typing import Optional
from pr.core.exceptions import ValidationError from pr.core.exceptions import ValidationError
@ -16,7 +16,9 @@ def validate_file_path(path: str, must_exist: bool = False) -> str:
return os.path.abspath(path) return os.path.abspath(path)
def validate_directory_path(path: str, must_exist: bool = False, create: bool = False) -> str: def validate_directory_path(
path: str, must_exist: bool = False, create: bool = False
) -> str:
if not path: if not path:
raise ValidationError("Directory path cannot be empty") raise ValidationError("Directory path cannot be empty")
@ -48,7 +50,7 @@ def validate_api_url(url: str) -> str:
if not url: if not url:
raise ValidationError("API URL cannot be empty") raise ValidationError("API URL cannot be empty")
if not url.startswith(('http://', 'https://')): if not url.startswith(("http://", "https://")):
raise ValidationError("API URL must start with http:// or https://") raise ValidationError("API URL must start with http:// or https://")
return url return url
@ -58,7 +60,7 @@ def validate_session_name(name: str) -> str:
if not name: if not name:
raise ValidationError("Session name cannot be empty") raise ValidationError("Session name cannot be empty")
invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|'] invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]
for char in invalid_chars: for char in invalid_chars:
if char in name: if char in name:
raise ValidationError(f"Session name contains invalid character: {char}") raise ValidationError(f"Session name contains invalid character: {char}")

View File

@ -1,17 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import atexit
import curses import curses
import threading
import sys
import os import os
import re
import socket
import pickle import pickle
import queue import queue
import time import re
import atexit
import signal import signal
import traceback import socket
from contextlib import contextmanager import sys
import threading
import time
class RPEditor: class RPEditor:
def __init__(self, filename=None, auto_save=False, timeout=30): def __init__(self, filename=None, auto_save=False, timeout=30):
@ -27,7 +26,7 @@ class RPEditor:
self.lines = [""] self.lines = [""]
self.cursor_y = 0 self.cursor_y = 0
self.cursor_x = 0 self.cursor_x = 0
self.mode = 'normal' self.mode = "normal"
self.command = "" self.command = ""
self.stdscr = None self.stdscr = None
self.running = False self.running = False
@ -106,7 +105,7 @@ class RPEditor:
# Clear screen after curses cleanup # Clear screen after curses cleanup
try: try:
os.system('clear' if os.name != 'nt' else 'cls') os.system("clear" if os.name != "nt" else "cls")
except: except:
pass pass
@ -130,12 +129,12 @@ class RPEditor:
"""Load file with enhanced error handling.""" """Load file with enhanced error handling."""
try: try:
if os.path.exists(self.filename): if os.path.exists(self.filename):
with open(self.filename, 'r', encoding='utf-8', errors='replace') as f: with open(self.filename, encoding="utf-8", errors="replace") as f:
content = f.read() content = f.read()
self.lines = content.splitlines() if content else [""] self.lines = content.splitlines() if content else [""]
else: else:
self.lines = [""] self.lines = [""]
except Exception as e: except Exception:
self.lines = [""] self.lines = [""]
# Don't raise, just use empty content # Don't raise, just use empty content
@ -150,18 +149,18 @@ class RPEditor:
if os.path.exists(self.filename): if os.path.exists(self.filename):
backup_name = f"{self.filename}.bak" backup_name = f"{self.filename}.bak"
try: try:
with open(self.filename, 'r', encoding='utf-8') as f: with open(self.filename, encoding="utf-8") as f:
backup_content = f.read() backup_content = f.read()
with open(backup_name, 'w', encoding='utf-8') as f: with open(backup_name, "w", encoding="utf-8") as f:
f.write(backup_content) f.write(backup_content)
except: except:
pass # Backup failed, but continue with save pass # Backup failed, but continue with save
# Save the file # Save the file
with open(self.filename, 'w', encoding='utf-8') as f: with open(self.filename, "w", encoding="utf-8") as f:
f.write('\n'.join(self.lines)) f.write("\n".join(self.lines))
return True return True
except Exception as e: except Exception:
return False return False
def save_file(self): def save_file(self):
@ -169,7 +168,7 @@ class RPEditor:
if not self.running: if not self.running:
return self._save_file() return self._save_file()
try: try:
self.client_sock.send(pickle.dumps({'command': 'save_file'})) self.client_sock.send(pickle.dumps({"command": "save_file"}))
except: except:
return self._save_file() # Fallback to direct save return self._save_file() # Fallback to direct save
@ -180,7 +179,9 @@ class RPEditor:
try: try:
self.running = True self.running = True
self.socket_thread = threading.Thread(target=self.socket_listener, daemon=True) self.socket_thread = threading.Thread(
target=self.socket_listener, daemon=True
)
self.socket_thread.start() self.socket_thread.start()
self.thread = threading.Thread(target=self.run, daemon=True) self.thread = threading.Thread(target=self.run, daemon=True)
self.thread.start() self.thread.start()
@ -194,7 +195,7 @@ class RPEditor:
"""Stop the editor with proper cleanup.""" """Stop the editor with proper cleanup."""
try: try:
if self.client_sock: if self.client_sock:
self.client_sock.send(pickle.dumps({'command': 'stop'})) self.client_sock.send(pickle.dumps({"command": "stop"}))
except: except:
pass pass
@ -206,7 +207,7 @@ class RPEditor:
"""Run the main editor loop with exception handling.""" """Run the main editor loop with exception handling."""
try: try:
curses.wrapper(self.main_loop) curses.wrapper(self.main_loop)
except Exception as e: except Exception:
self._exception_occurred = True self._exception_occurred = True
self._cleanup() self._cleanup()
@ -244,11 +245,11 @@ class RPEditor:
except curses.error: except curses.error:
pass # Ignore curses errors pass # Ignore curses errors
except Exception as e: except Exception:
# Log error but continue running # Log error but continue running
pass pass
except Exception as e: except Exception:
self._exception_occurred = True self._exception_occurred = True
finally: finally:
self._cleanup() self._cleanup()
@ -265,18 +266,18 @@ class RPEditor:
break break
try: try:
# Handle long lines and special characters # Handle long lines and special characters
display_line = line[:width-1] if len(line) >= width else line display_line = line[: width - 1] if len(line) >= width else line
self.stdscr.addstr(i, 0, display_line) self.stdscr.addstr(i, 0, display_line)
except curses.error: except curses.error:
pass # Skip lines that can't be displayed pass # Skip lines that can't be displayed
# Draw status line # Draw status line
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}" status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
if self.mode == 'command': if self.mode == "command":
status = self.command[:width-1] status = self.command[: width - 1]
try: try:
self.stdscr.addstr(height - 1, 0, status[:width-1]) self.stdscr.addstr(height - 1, 0, status[: width - 1])
except curses.error: except curses.error:
pass pass
@ -295,11 +296,11 @@ class RPEditor:
def handle_key(self, key): def handle_key(self, key):
"""Handle keyboard input with error recovery.""" """Handle keyboard input with error recovery."""
try: try:
if self.mode == 'normal': if self.mode == "normal":
self.handle_normal(key) self.handle_normal(key)
elif self.mode == 'insert': elif self.mode == "insert":
self.handle_insert(key) self.handle_insert(key)
elif self.mode == 'command': elif self.mode == "command":
self.handle_command(key) self.handle_command(key)
except Exception: except Exception:
pass # Continue on error pass # Continue on error
@ -307,71 +308,71 @@ class RPEditor:
def handle_normal(self, key): def handle_normal(self, key):
"""Handle normal mode keys.""" """Handle normal mode keys."""
try: try:
if key == ord('h') or key == curses.KEY_LEFT: if key == ord("h") or key == curses.KEY_LEFT:
self.move_cursor(0, -1) self.move_cursor(0, -1)
elif key == ord('j') or key == curses.KEY_DOWN: elif key == ord("j") or key == curses.KEY_DOWN:
self.move_cursor(1, 0) self.move_cursor(1, 0)
elif key == ord('k') or key == curses.KEY_UP: elif key == ord("k") or key == curses.KEY_UP:
self.move_cursor(-1, 0) self.move_cursor(-1, 0)
elif key == ord('l') or key == curses.KEY_RIGHT: elif key == ord("l") or key == curses.KEY_RIGHT:
self.move_cursor(0, 1) self.move_cursor(0, 1)
elif key == ord('i'): elif key == ord("i"):
self.mode = 'insert' self.mode = "insert"
elif key == ord(':'): elif key == ord(":"):
self.mode = 'command' self.mode = "command"
self.command = ":" self.command = ":"
elif key == ord('x'): elif key == ord("x"):
self._delete_char() self._delete_char()
elif key == ord('a'): elif key == ord("a"):
self.cursor_x = min(self.cursor_x + 1, len(self.lines[self.cursor_y])) self.cursor_x = min(self.cursor_x + 1, len(self.lines[self.cursor_y]))
self.mode = 'insert' self.mode = "insert"
elif key == ord('A'): elif key == ord("A"):
self.cursor_x = len(self.lines[self.cursor_y]) self.cursor_x = len(self.lines[self.cursor_y])
self.mode = 'insert' self.mode = "insert"
elif key == ord('o'): elif key == ord("o"):
self._insert_line(self.cursor_y + 1, "") self._insert_line(self.cursor_y + 1, "")
self.cursor_y += 1 self.cursor_y += 1
self.cursor_x = 0 self.cursor_x = 0
self.mode = 'insert' self.mode = "insert"
elif key == ord('O'): elif key == ord("O"):
self._insert_line(self.cursor_y, "") self._insert_line(self.cursor_y, "")
self.cursor_x = 0 self.cursor_x = 0
self.mode = 'insert' self.mode = "insert"
elif key == ord('d') and self.prev_key == ord('d'): elif key == ord("d") and self.prev_key == ord("d"):
if self.cursor_y < len(self.lines): if self.cursor_y < len(self.lines):
self.clipboard = self.lines[self.cursor_y] self.clipboard = self.lines[self.cursor_y]
self._delete_line(self.cursor_y) self._delete_line(self.cursor_y)
if self.cursor_y >= len(self.lines): if self.cursor_y >= len(self.lines):
self.cursor_y = max(0, len(self.lines) - 1) self.cursor_y = max(0, len(self.lines) - 1)
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('y') and self.prev_key == ord('y'): elif key == ord("y") and self.prev_key == ord("y"):
if self.cursor_y < len(self.lines): if self.cursor_y < len(self.lines):
self.clipboard = self.lines[self.cursor_y] self.clipboard = self.lines[self.cursor_y]
elif key == ord('p'): elif key == ord("p"):
self._insert_line(self.cursor_y + 1, self.clipboard) self._insert_line(self.cursor_y + 1, self.clipboard)
self.cursor_y += 1 self.cursor_y += 1
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('P'): elif key == ord("P"):
self._insert_line(self.cursor_y, self.clipboard) self._insert_line(self.cursor_y, self.clipboard)
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('w'): elif key == ord("w"):
self._move_word_forward() self._move_word_forward()
elif key == ord('b'): elif key == ord("b"):
self._move_word_backward() self._move_word_backward()
elif key == ord('0'): elif key == ord("0"):
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('$'): elif key == ord("$"):
self.cursor_x = len(self.lines[self.cursor_y]) self.cursor_x = len(self.lines[self.cursor_y])
elif key == ord('g'): elif key == ord("g"):
if self.prev_key == ord('g'): if self.prev_key == ord("g"):
self.cursor_y = 0 self.cursor_y = 0
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('G'): elif key == ord("G"):
self.cursor_y = max(0, len(self.lines) - 1) self.cursor_y = max(0, len(self.lines) - 1)
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('u'): elif key == ord("u"):
self.undo() self.undo()
elif key == ord('r') and self.prev_key == 18: # Ctrl-R elif key == ord("r") and self.prev_key == 18: # Ctrl-R
self.redo() self.redo()
self.prev_key = key self.prev_key = key
@ -410,7 +411,7 @@ class RPEditor:
"""Handle insert mode keys.""" """Handle insert mode keys."""
try: try:
if key == 27: # ESC if key == 27: # ESC
self.mode = 'normal' self.mode = "normal"
if self.cursor_x > 0: if self.cursor_x > 0:
self.cursor_x -= 1 self.cursor_x -= 1
elif key == 10 or key == 13: # Enter elif key == 10 or key == 13: # Enter
@ -438,10 +439,10 @@ class RPEditor:
elif cmd.startswith("w "): elif cmd.startswith("w "):
self.filename = cmd[2:].strip() self.filename = cmd[2:].strip()
self._save_file() self._save_file()
self.mode = 'normal' self.mode = "normal"
self.command = "" self.command = ""
elif key == 27: # ESC elif key == 27: # ESC
self.mode = 'normal' self.mode = "normal"
self.command = "" self.command = ""
elif key == curses.KEY_BACKSPACE or key == 127 or key == 8: elif key == curses.KEY_BACKSPACE or key == 127 or key == 8:
if len(self.command) > 1: if len(self.command) > 1:
@ -449,7 +450,7 @@ class RPEditor:
elif 32 <= key <= 126: elif 32 <= key <= 126:
self.command += chr(key) self.command += chr(key)
except Exception: except Exception:
self.mode = 'normal' self.mode = "normal"
self.command = "" self.command = ""
def move_cursor(self, dy, dx): def move_cursor(self, dy, dx):
@ -477,9 +478,9 @@ class RPEditor:
"""Save current state for undo.""" """Save current state for undo."""
with self.lock: with self.lock:
state = { state = {
'lines': [line for line in self.lines], "lines": list(self.lines),
'cursor_y': self.cursor_y, "cursor_y": self.cursor_y,
'cursor_x': self.cursor_x "cursor_x": self.cursor_x,
} }
self.undo_stack.append(state) self.undo_stack.append(state)
if len(self.undo_stack) > self.max_undo: if len(self.undo_stack) > self.max_undo:
@ -491,30 +492,36 @@ class RPEditor:
with self.lock: with self.lock:
if self.undo_stack: if self.undo_stack:
current_state = { current_state = {
'lines': [line for line in self.lines], "lines": list(self.lines),
'cursor_y': self.cursor_y, "cursor_y": self.cursor_y,
'cursor_x': self.cursor_x "cursor_x": self.cursor_x,
} }
self.redo_stack.append(current_state) self.redo_stack.append(current_state)
state = self.undo_stack.pop() state = self.undo_stack.pop()
self.lines = state['lines'] self.lines = state["lines"]
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1) self.cursor_y = min(state["cursor_y"], len(self.lines) - 1)
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0) self.cursor_x = min(
state["cursor_x"],
len(self.lines[self.cursor_y]) if self.lines else 0,
)
def redo(self): def redo(self):
"""Redo last undone change.""" """Redo last undone change."""
with self.lock: with self.lock:
if self.redo_stack: if self.redo_stack:
current_state = { current_state = {
'lines': [line for line in self.lines], "lines": list(self.lines),
'cursor_y': self.cursor_y, "cursor_y": self.cursor_y,
'cursor_x': self.cursor_x "cursor_x": self.cursor_x,
} }
self.undo_stack.append(current_state) self.undo_stack.append(current_state)
state = self.redo_stack.pop() state = self.redo_stack.pop()
self.lines = state['lines'] self.lines = state["lines"]
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1) self.cursor_y = min(state["cursor_y"], len(self.lines) - 1)
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0) self.cursor_x = min(
state["cursor_x"],
len(self.lines[self.cursor_y]) if self.lines else 0,
)
def _insert_text(self, text): def _insert_text(self, text):
"""Insert text at cursor position.""" """Insert text at cursor position."""
@ -522,7 +529,7 @@ class RPEditor:
return return
self.save_state() self.save_state()
lines = text.split('\n') lines = text.split("\n")
if len(lines) == 1: if len(lines) == 1:
# Single line insert # Single line insert
@ -531,7 +538,9 @@ class RPEditor:
self.cursor_y = len(self.lines) - 1 self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] + text + line[self.cursor_x:] self.lines[self.cursor_y] = (
line[: self.cursor_x] + text + line[self.cursor_x :]
)
self.cursor_x += len(text) self.cursor_x += len(text)
else: else:
# Multi-line insert # Multi-line insert
@ -539,8 +548,8 @@ class RPEditor:
self.lines.append("") self.lines.append("")
self.cursor_y = len(self.lines) - 1 self.cursor_y = len(self.lines) - 1
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0] first = self.lines[self.cursor_y][: self.cursor_x] + lines[0]
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:] last = lines[-1] + self.lines[self.cursor_y][self.cursor_x :]
self.lines[self.cursor_y] = first self.lines[self.cursor_y] = first
for i in range(1, len(lines) - 1): for i in range(1, len(lines) - 1):
@ -553,7 +562,9 @@ class RPEditor:
def insert_text(self, text): def insert_text(self, text):
"""Thread-safe text insertion.""" """Thread-safe text insertion."""
try: try:
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text})) self.client_sock.send(
pickle.dumps({"command": "insert_text", "text": text})
)
except: except:
with self.lock: with self.lock:
self._insert_text(text) self._insert_text(text)
@ -561,14 +572,18 @@ class RPEditor:
def _delete_char(self): def _delete_char(self):
"""Delete character at cursor.""" """Delete character at cursor."""
self.save_state() self.save_state()
if self.cursor_y < len(self.lines) and self.cursor_x < len(self.lines[self.cursor_y]): if self.cursor_y < len(self.lines) and self.cursor_x < len(
self.lines[self.cursor_y]
):
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] + line[self.cursor_x+1:] self.lines[self.cursor_y] = (
line[: self.cursor_x] + line[self.cursor_x + 1 :]
)
def delete_char(self): def delete_char(self):
"""Thread-safe character deletion.""" """Thread-safe character deletion."""
try: try:
self.client_sock.send(pickle.dumps({'command': 'delete_char'})) self.client_sock.send(pickle.dumps({"command": "delete_char"}))
except: except:
with self.lock: with self.lock:
self._delete_char() self._delete_char()
@ -580,7 +595,7 @@ class RPEditor:
self.cursor_y = len(self.lines) - 1 self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] + char + line[self.cursor_x:] self.lines[self.cursor_y] = line[: self.cursor_x] + char + line[self.cursor_x :]
self.cursor_x += 1 self.cursor_x += 1
def _split_line(self): def _split_line(self):
@ -590,8 +605,8 @@ class RPEditor:
self.cursor_y = len(self.lines) - 1 self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] self.lines[self.cursor_y] = line[: self.cursor_x]
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:]) self.lines.insert(self.cursor_y + 1, line[self.cursor_x :])
self.cursor_y += 1 self.cursor_y += 1
self.cursor_x = 0 self.cursor_x = 0
@ -599,7 +614,9 @@ class RPEditor:
"""Handle backspace key.""" """Handle backspace key."""
if self.cursor_x > 0: if self.cursor_x > 0:
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x-1] + line[self.cursor_x:] self.lines[self.cursor_y] = (
line[: self.cursor_x - 1] + line[self.cursor_x :]
)
self.cursor_x -= 1 self.cursor_x -= 1
elif self.cursor_y > 0: elif self.cursor_y > 0:
prev_len = len(self.lines[self.cursor_y - 1]) prev_len = len(self.lines[self.cursor_y - 1])
@ -637,7 +654,7 @@ class RPEditor:
self._set_text(text) self._set_text(text)
return return
try: try:
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text})) self.client_sock.send(pickle.dumps({"command": "set_text", "text": text}))
except: except:
with self.lock: with self.lock:
self._set_text(text) self._set_text(text)
@ -651,7 +668,9 @@ class RPEditor:
def goto_line(self, line_num): def goto_line(self, line_num):
"""Thread-safe goto line.""" """Thread-safe goto line."""
try: try:
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num})) self.client_sock.send(
pickle.dumps({"command": "goto_line", "line_num": line_num})
)
except: except:
with self.lock: with self.lock:
self._goto_line(line_num) self._goto_line(line_num)
@ -659,17 +678,17 @@ class RPEditor:
def get_text(self): def get_text(self):
"""Get entire text content.""" """Get entire text content."""
try: try:
self.client_sock.send(pickle.dumps({'command': 'get_text'})) self.client_sock.send(pickle.dumps({"command": "get_text"}))
data = self.client_sock.recv(65536) data = self.client_sock.recv(65536)
return pickle.loads(data) return pickle.loads(data)
except: except:
with self.lock: with self.lock:
return '\n'.join(self.lines) return "\n".join(self.lines)
def get_cursor(self): def get_cursor(self):
"""Get cursor position.""" """Get cursor position."""
try: try:
self.client_sock.send(pickle.dumps({'command': 'get_cursor'})) self.client_sock.send(pickle.dumps({"command": "get_cursor"}))
data = self.client_sock.recv(4096) data = self.client_sock.recv(4096)
return pickle.loads(data) return pickle.loads(data)
except: except:
@ -679,16 +698,16 @@ class RPEditor:
def get_file_info(self): def get_file_info(self):
"""Get file information.""" """Get file information."""
try: try:
self.client_sock.send(pickle.dumps({'command': 'get_file_info'})) self.client_sock.send(pickle.dumps({"command": "get_file_info"}))
data = self.client_sock.recv(4096) data = self.client_sock.recv(4096)
return pickle.loads(data) return pickle.loads(data)
except: except:
with self.lock: with self.lock:
return { return {
'filename': self.filename, "filename": self.filename,
'lines': len(self.lines), "lines": len(self.lines),
'cursor': (self.cursor_y, self.cursor_x), "cursor": (self.cursor_y, self.cursor_x),
'mode': self.mode "mode": self.mode,
} }
def socket_listener(self): def socket_listener(self):
@ -713,33 +732,33 @@ class RPEditor:
def execute_command(self, command): def execute_command(self, command):
"""Execute command with error handling.""" """Execute command with error handling."""
try: try:
cmd = command.get('command') cmd = command.get("command")
if cmd == 'insert_text': if cmd == "insert_text":
self._insert_text(command.get('text', '')) self._insert_text(command.get("text", ""))
elif cmd == 'delete_char': elif cmd == "delete_char":
self._delete_char() self._delete_char()
elif cmd == 'save_file': elif cmd == "save_file":
self._save_file() self._save_file()
elif cmd == 'set_text': elif cmd == "set_text":
self._set_text(command.get('text', '')) self._set_text(command.get("text", ""))
elif cmd == 'goto_line': elif cmd == "goto_line":
self._goto_line(command.get('line_num', 1)) self._goto_line(command.get("line_num", 1))
elif cmd == 'get_text': elif cmd == "get_text":
result = '\n'.join(self.lines) result = "\n".join(self.lines)
self.server_sock.send(pickle.dumps(result)) self.server_sock.send(pickle.dumps(result))
elif cmd == 'get_cursor': elif cmd == "get_cursor":
result = (self.cursor_y, self.cursor_x) result = (self.cursor_y, self.cursor_x)
self.server_sock.send(pickle.dumps(result)) self.server_sock.send(pickle.dumps(result))
elif cmd == 'get_file_info': elif cmd == "get_file_info":
result = { result = {
'filename': self.filename, "filename": self.filename,
'lines': len(self.lines), "lines": len(self.lines),
'cursor': (self.cursor_y, self.cursor_x), "cursor": (self.cursor_y, self.cursor_x),
'mode': self.mode "mode": self.mode,
} }
self.server_sock.send(pickle.dumps(result)) self.server_sock.send(pickle.dumps(result))
elif cmd == 'stop': elif cmd == "stop":
self.running = False self.running = False
except Exception: except Exception:
pass pass
@ -801,9 +820,9 @@ class RPEditor:
else: else:
first_part = self.lines[start_line][:start_col] first_part = self.lines[start_line][:start_col]
last_part = self.lines[end_line][end_col:] last_part = self.lines[end_line][end_col:]
new_lines = new_text.split('\n') new_lines = new_text.split("\n")
self.lines[start_line] = first_part + new_lines[0] self.lines[start_line] = first_part + new_lines[0]
del self.lines[start_line + 1:end_line + 1] del self.lines[start_line + 1 : end_line + 1]
for i, new_line in enumerate(new_lines[1:], 1): for i, new_line in enumerate(new_lines[1:], 1):
self.lines.insert(start_line + i, new_line) self.lines.insert(start_line + i, new_line)
if len(new_lines) > 1: if len(new_lines) > 1:
@ -859,7 +878,7 @@ class RPEditor:
result.append(self.lines[i]) result.append(self.lines[i])
if el < len(self.lines): if el < len(self.lines):
result.append(self.lines[el][:ec]) result.append(self.lines[el][:ec])
return '\n'.join(result) return "\n".join(result)
def delete_selection(self): def delete_selection(self):
"""Delete selected text.""" """Delete selected text."""
@ -894,8 +913,8 @@ class RPEditor:
if match: if match:
# Preserve indentation # Preserve indentation
indent = len(self.lines[i]) - len(self.lines[i].lstrip()) indent = len(self.lines[i]) - len(self.lines[i].lstrip())
indented_replace = [' ' * indent + line for line in replace_lines] indented_replace = [" " * indent + line for line in replace_lines]
self.lines[i:i+len(search_lines)] = indented_replace self.lines[i : i + len(search_lines)] = indented_replace
return True return True
return False return False
@ -904,21 +923,21 @@ class RPEditor:
with self.lock: with self.lock:
self.save_state() self.save_state()
try: try:
lines = diff_text.split('\n') lines = diff_text.split("\n")
start_line = 0 start_line = 0
for line in lines: for line in lines:
if line.startswith('@@'): if line.startswith("@@"):
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line) match = re.search(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@", line)
if match: if match:
start_line = int(match.group(1)) - 1 start_line = int(match.group(1)) - 1
elif line.startswith('-'): elif line.startswith("-"):
if start_line < len(self.lines): if start_line < len(self.lines):
del self.lines[start_line] del self.lines[start_line]
elif line.startswith('+'): elif line.startswith("+"):
self.lines.insert(start_line, line[1:]) self.lines.insert(start_line, line[1:])
start_line += 1 start_line += 1
elif line and not line.startswith('\\'): elif line and not line.startswith("\\"):
start_line += 1 start_line += 1
except Exception: except Exception:
pass pass
@ -974,7 +993,7 @@ def main():
filename = sys.argv[1] if len(sys.argv) > 1 else None filename = sys.argv[1] if len(sys.argv) > 1 else None
# Parse additional arguments # Parse additional arguments
auto_save = '--auto-save' in sys.argv auto_save = "--auto-save" in sys.argv
# Create and start editor # Create and start editor
editor = RPEditor(filename, auto_save=auto_save) editor = RPEditor(filename, auto_save=auto_save)
@ -992,7 +1011,7 @@ def main():
if editor: if editor:
editor.stop() editor.stop()
# Ensure screen is cleared # Ensure screen is cleared
os.system('clear' if os.name != 'nt' else 'cls') os.system("clear" if os.name != "nt" else "cls")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,12 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import curses import curses
import threading
import sys
import os
import re
import socket
import pickle import pickle
import queue import queue
import re
import socket
import sys
import threading
class RPEditor: class RPEditor:
def __init__(self, filename=None): def __init__(self, filename=None):
@ -14,7 +14,7 @@ class RPEditor:
self.lines = [""] self.lines = [""]
self.cursor_y = 0 self.cursor_y = 0
self.cursor_x = 0 self.cursor_x = 0
self.mode = 'normal' self.mode = "normal"
self.command = "" self.command = ""
self.stdscr = None self.stdscr = None
self.running = False self.running = False
@ -35,7 +35,7 @@ class RPEditor:
def load_file(self): def load_file(self):
try: try:
with open(self.filename, 'r') as f: with open(self.filename) as f:
self.lines = f.read().splitlines() self.lines = f.read().splitlines()
if not self.lines: if not self.lines:
self.lines = [""] self.lines = [""]
@ -45,11 +45,11 @@ class RPEditor:
def _save_file(self): def _save_file(self):
with self.lock: with self.lock:
if self.filename: if self.filename:
with open(self.filename, 'w') as f: with open(self.filename, "w") as f:
f.write('\n'.join(self.lines)) f.write("\n".join(self.lines))
def save_file(self): def save_file(self):
self.client_sock.send(pickle.dumps({'command': 'save_file'})) self.client_sock.send(pickle.dumps({"command": "save_file"}))
def start(self): def start(self):
self.running = True self.running = True
@ -59,7 +59,7 @@ class RPEditor:
self.thread.start() self.thread.start()
def stop(self): def stop(self):
self.client_sock.send(pickle.dumps({'command': 'stop'})) self.client_sock.send(pickle.dumps({"command": "stop"}))
self.running = False self.running = False
if self.stdscr: if self.stdscr:
curses.endwin() curses.endwin()
@ -99,66 +99,66 @@ class RPEditor:
self.stdscr.addstr(i, 0, line[:width]) self.stdscr.addstr(i, 0, line[:width])
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}" status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
self.stdscr.addstr(height - 1, 0, status[:width]) self.stdscr.addstr(height - 1, 0, status[:width])
if self.mode == 'command': if self.mode == "command":
self.stdscr.addstr(height - 1, 0, self.command[:width]) self.stdscr.addstr(height - 1, 0, self.command[:width])
self.stdscr.move(self.cursor_y, min(self.cursor_x, width - 1)) self.stdscr.move(self.cursor_y, min(self.cursor_x, width - 1))
self.stdscr.refresh() self.stdscr.refresh()
def handle_key(self, key): def handle_key(self, key):
if self.mode == 'normal': if self.mode == "normal":
self.handle_normal(key) self.handle_normal(key)
elif self.mode == 'insert': elif self.mode == "insert":
self.handle_insert(key) self.handle_insert(key)
elif self.mode == 'command': elif self.mode == "command":
self.handle_command(key) self.handle_command(key)
def handle_normal(self, key): def handle_normal(self, key):
if key == ord('h') or key == curses.KEY_LEFT: if key == ord("h") or key == curses.KEY_LEFT:
self.move_cursor(0, -1) self.move_cursor(0, -1)
elif key == ord('j') or key == curses.KEY_DOWN: elif key == ord("j") or key == curses.KEY_DOWN:
self.move_cursor(1, 0) self.move_cursor(1, 0)
elif key == ord('k') or key == curses.KEY_UP: elif key == ord("k") or key == curses.KEY_UP:
self.move_cursor(-1, 0) self.move_cursor(-1, 0)
elif key == ord('l') or key == curses.KEY_RIGHT: elif key == ord("l") or key == curses.KEY_RIGHT:
self.move_cursor(0, 1) self.move_cursor(0, 1)
elif key == ord('i'): elif key == ord("i"):
self.mode = 'insert' self.mode = "insert"
elif key == ord(':'): elif key == ord(":"):
self.mode = 'command' self.mode = "command"
self.command = ":" self.command = ":"
elif key == ord('x'): elif key == ord("x"):
self._delete_char() self._delete_char()
elif key == ord('a'): elif key == ord("a"):
self.cursor_x += 1 self.cursor_x += 1
self.mode = 'insert' self.mode = "insert"
elif key == ord('A'): elif key == ord("A"):
self.cursor_x = len(self.lines[self.cursor_y]) self.cursor_x = len(self.lines[self.cursor_y])
self.mode = 'insert' self.mode = "insert"
elif key == ord('o'): elif key == ord("o"):
self._insert_line(self.cursor_y + 1, "") self._insert_line(self.cursor_y + 1, "")
self.cursor_y += 1 self.cursor_y += 1
self.cursor_x = 0 self.cursor_x = 0
self.mode = 'insert' self.mode = "insert"
elif key == ord('O'): elif key == ord("O"):
self._insert_line(self.cursor_y, "") self._insert_line(self.cursor_y, "")
self.cursor_x = 0 self.cursor_x = 0
self.mode = 'insert' self.mode = "insert"
elif key == ord('d') and self.prev_key == ord('d'): elif key == ord("d") and self.prev_key == ord("d"):
self.clipboard = self.lines[self.cursor_y] self.clipboard = self.lines[self.cursor_y]
self._delete_line(self.cursor_y) self._delete_line(self.cursor_y)
if self.cursor_y >= len(self.lines): if self.cursor_y >= len(self.lines):
self.cursor_y = len(self.lines) - 1 self.cursor_y = len(self.lines) - 1
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('y') and self.prev_key == ord('y'): elif key == ord("y") and self.prev_key == ord("y"):
self.clipboard = self.lines[self.cursor_y] self.clipboard = self.lines[self.cursor_y]
elif key == ord('p'): elif key == ord("p"):
self._insert_line(self.cursor_y + 1, self.clipboard) self._insert_line(self.cursor_y + 1, self.clipboard)
self.cursor_y += 1 self.cursor_y += 1
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('P'): elif key == ord("P"):
self._insert_line(self.cursor_y, self.clipboard) self._insert_line(self.cursor_y, self.clipboard)
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('w'): elif key == ord("w"):
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
i = self.cursor_x i = self.cursor_x
while i < len(line) and not line[i].isalnum(): while i < len(line) and not line[i].isalnum():
@ -166,7 +166,7 @@ class RPEditor:
while i < len(line) and line[i].isalnum(): while i < len(line) and line[i].isalnum():
i += 1 i += 1
self.cursor_x = i self.cursor_x = i
elif key == ord('b'): elif key == ord("b"):
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
i = self.cursor_x - 1 i = self.cursor_x - 1
while i >= 0 and not line[i].isalnum(): while i >= 0 and not line[i].isalnum():
@ -174,26 +174,26 @@ class RPEditor:
while i >= 0 and line[i].isalnum(): while i >= 0 and line[i].isalnum():
i -= 1 i -= 1
self.cursor_x = i + 1 self.cursor_x = i + 1
elif key == ord('0'): elif key == ord("0"):
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('$'): elif key == ord("$"):
self.cursor_x = len(self.lines[self.cursor_y]) self.cursor_x = len(self.lines[self.cursor_y])
elif key == ord('g'): elif key == ord("g"):
if self.prev_key == ord('g'): if self.prev_key == ord("g"):
self.cursor_y = 0 self.cursor_y = 0
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('G'): elif key == ord("G"):
self.cursor_y = len(self.lines) - 1 self.cursor_y = len(self.lines) - 1
self.cursor_x = 0 self.cursor_x = 0
elif key == ord('u'): elif key == ord("u"):
self.undo() self.undo()
elif key == ord('r') and self.prev_key == 18: elif key == ord("r") and self.prev_key == 18:
self.redo() self.redo()
self.prev_key = key self.prev_key = key
def handle_insert(self, key): def handle_insert(self, key):
if key == 27: if key == 27:
self.mode = 'normal' self.mode = "normal"
if self.cursor_x > 0: if self.cursor_x > 0:
self.cursor_x -= 1 self.cursor_x -= 1
elif key == 10: elif key == 10:
@ -207,11 +207,13 @@ class RPEditor:
def handle_command(self, key): def handle_command(self, key):
if key == 10: if key == 10:
cmd = self.command[1:] cmd = self.command[1:]
if cmd == "q" or cmd == 'q!': if cmd == "q" or cmd == "q!":
self.running = False self.running = False
elif cmd == "w": elif cmd == "w":
self._save_file() self._save_file()
elif cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!": elif (
cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!"
):
self._save_file() self._save_file()
self.running = False self.running = False
elif cmd.startswith("w "): elif cmd.startswith("w "):
@ -220,10 +222,10 @@ class RPEditor:
elif cmd == "wq": elif cmd == "wq":
self._save_file() self._save_file()
self.running = False self.running = False
self.mode = 'normal' self.mode = "normal"
self.command = "" self.command = ""
elif key == 27: elif key == 27:
self.mode = 'normal' self.mode = "normal"
self.command = "" self.command = ""
elif key == curses.KEY_BACKSPACE or key == 127: elif key == curses.KEY_BACKSPACE or key == 127:
if len(self.command) > 1: if len(self.command) > 1:
@ -241,9 +243,9 @@ class RPEditor:
def save_state(self): def save_state(self):
with self.lock: with self.lock:
state = { state = {
'lines': [line for line in self.lines], "lines": list(self.lines),
'cursor_y': self.cursor_y, "cursor_y": self.cursor_y,
'cursor_x': self.cursor_x "cursor_x": self.cursor_x,
} }
self.undo_stack.append(state) self.undo_stack.append(state)
if len(self.undo_stack) > self.max_undo: if len(self.undo_stack) > self.max_undo:
@ -254,71 +256,85 @@ class RPEditor:
with self.lock: with self.lock:
if self.undo_stack: if self.undo_stack:
current_state = { current_state = {
'lines': [line for line in self.lines], "lines": list(self.lines),
'cursor_y': self.cursor_y, "cursor_y": self.cursor_y,
'cursor_x': self.cursor_x "cursor_x": self.cursor_x,
} }
self.redo_stack.append(current_state) self.redo_stack.append(current_state)
state = self.undo_stack.pop() state = self.undo_stack.pop()
self.lines = state['lines'] self.lines = state["lines"]
self.cursor_y = state['cursor_y'] self.cursor_y = state["cursor_y"]
self.cursor_x = state['cursor_x'] self.cursor_x = state["cursor_x"]
def redo(self): def redo(self):
with self.lock: with self.lock:
if self.redo_stack: if self.redo_stack:
current_state = { current_state = {
'lines': [line for line in self.lines], "lines": list(self.lines),
'cursor_y': self.cursor_y, "cursor_y": self.cursor_y,
'cursor_x': self.cursor_x "cursor_x": self.cursor_x,
} }
self.undo_stack.append(current_state) self.undo_stack.append(current_state)
state = self.redo_stack.pop() state = self.redo_stack.pop()
self.lines = state['lines'] self.lines = state["lines"]
self.cursor_y = state['cursor_y'] self.cursor_y = state["cursor_y"]
self.cursor_x = state['cursor_x'] self.cursor_x = state["cursor_x"]
def _insert_text(self, text): def _insert_text(self, text):
self.save_state() self.save_state()
lines = text.split('\n') lines = text.split("\n")
if len(lines) == 1: if len(lines) == 1:
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + text + self.lines[self.cursor_y][self.cursor_x:] self.lines[self.cursor_y] = (
self.lines[self.cursor_y][: self.cursor_x]
+ text
+ self.lines[self.cursor_y][self.cursor_x :]
)
self.cursor_x += len(text) self.cursor_x += len(text)
else: else:
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0] first = self.lines[self.cursor_y][: self.cursor_x] + lines[0]
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:] last = lines[-1] + self.lines[self.cursor_y][self.cursor_x :]
self.lines[self.cursor_y] = first self.lines[self.cursor_y] = first
for i in range(1, len(lines)-1): for i in range(1, len(lines) - 1):
self.lines.insert(self.cursor_y + i, lines[i]) self.lines.insert(self.cursor_y + i, lines[i])
self.lines.insert(self.cursor_y + len(lines) - 1, last) self.lines.insert(self.cursor_y + len(lines) - 1, last)
self.cursor_y += len(lines) - 1 self.cursor_y += len(lines) - 1
self.cursor_x = len(lines[-1]) self.cursor_x = len(lines[-1])
def insert_text(self, text): def insert_text(self, text):
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text})) self.client_sock.send(pickle.dumps({"command": "insert_text", "text": text}))
def _delete_char(self): def _delete_char(self):
self.save_state() self.save_state()
if self.cursor_x < len(self.lines[self.cursor_y]): if self.cursor_x < len(self.lines[self.cursor_y]):
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + self.lines[self.cursor_y][self.cursor_x+1:] self.lines[self.cursor_y] = (
self.lines[self.cursor_y][: self.cursor_x]
+ self.lines[self.cursor_y][self.cursor_x + 1 :]
)
def delete_char(self): def delete_char(self):
self.client_sock.send(pickle.dumps({'command': 'delete_char'})) self.client_sock.send(pickle.dumps({"command": "delete_char"}))
def _insert_char(self, char): def _insert_char(self, char):
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + char + self.lines[self.cursor_y][self.cursor_x:] self.lines[self.cursor_y] = (
self.lines[self.cursor_y][: self.cursor_x]
+ char
+ self.lines[self.cursor_y][self.cursor_x :]
)
self.cursor_x += 1 self.cursor_x += 1
def _split_line(self): def _split_line(self):
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] self.lines[self.cursor_y] = line[: self.cursor_x]
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:]) self.lines.insert(self.cursor_y + 1, line[self.cursor_x :])
self.cursor_y += 1 self.cursor_y += 1
self.cursor_x = 0 self.cursor_x = 0
def _backspace(self): def _backspace(self):
if self.cursor_x > 0: if self.cursor_x > 0:
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x-1] + self.lines[self.cursor_y][self.cursor_x:] self.lines[self.cursor_y] = (
self.lines[self.cursor_y][: self.cursor_x - 1]
+ self.lines[self.cursor_y][self.cursor_x :]
)
self.cursor_x -= 1 self.cursor_x -= 1
elif self.cursor_y > 0: elif self.cursor_y > 0:
prev_len = len(self.lines[self.cursor_y - 1]) prev_len = len(self.lines[self.cursor_y - 1])
@ -347,7 +363,7 @@ class RPEditor:
self.cursor_x = 0 self.cursor_x = 0
def set_text(self, text): def set_text(self, text):
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text})) self.client_sock.send(pickle.dumps({"command": "set_text", "text": text}))
def _goto_line(self, line_num): def _goto_line(self, line_num):
line_num = max(0, min(line_num, len(self.lines) - 1)) line_num = max(0, min(line_num, len(self.lines) - 1))
@ -355,24 +371,26 @@ class RPEditor:
self.cursor_x = 0 self.cursor_x = 0
def goto_line(self, line_num): def goto_line(self, line_num):
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num})) self.client_sock.send(
pickle.dumps({"command": "goto_line", "line_num": line_num})
)
def get_text(self): def get_text(self):
self.client_sock.send(pickle.dumps({'command': 'get_text'})) self.client_sock.send(pickle.dumps({"command": "get_text"}))
try: try:
return pickle.loads(self.client_sock.recv(4096)) return pickle.loads(self.client_sock.recv(4096))
except: except:
return '' return ""
def get_cursor(self): def get_cursor(self):
self.client_sock.send(pickle.dumps({'command': 'get_cursor'})) self.client_sock.send(pickle.dumps({"command": "get_cursor"}))
try: try:
return pickle.loads(self.client_sock.recv(4096)) return pickle.loads(self.client_sock.recv(4096))
except: except:
return (0, 0) return (0, 0)
def get_file_info(self): def get_file_info(self):
self.client_sock.send(pickle.dumps({'command': 'get_file_info'})) self.client_sock.send(pickle.dumps({"command": "get_file_info"}))
try: try:
return pickle.loads(self.client_sock.recv(4096)) return pickle.loads(self.client_sock.recv(4096))
except: except:
@ -390,46 +408,46 @@ class RPEditor:
break break
def execute_command(self, command): def execute_command(self, command):
cmd = command.get('command') cmd = command.get("command")
if cmd == 'insert_text': if cmd == "insert_text":
self._insert_text(command['text']) self._insert_text(command["text"])
elif cmd == 'delete_char': elif cmd == "delete_char":
self._delete_char() self._delete_char()
elif cmd == 'save_file': elif cmd == "save_file":
self._save_file() self._save_file()
elif cmd == 'set_text': elif cmd == "set_text":
self._set_text(command['text']) self._set_text(command["text"])
elif cmd == 'goto_line': elif cmd == "goto_line":
self._goto_line(command['line_num']) self._goto_line(command["line_num"])
elif cmd == 'get_text': elif cmd == "get_text":
result = '\n'.join(self.lines) result = "\n".join(self.lines)
try: try:
self.server_sock.send(pickle.dumps(result)) self.server_sock.send(pickle.dumps(result))
except: except:
pass pass
elif cmd == 'get_cursor': elif cmd == "get_cursor":
result = (self.cursor_y, self.cursor_x) result = (self.cursor_y, self.cursor_x)
try: try:
self.server_sock.send(pickle.dumps(result)) self.server_sock.send(pickle.dumps(result))
except: except:
pass pass
elif cmd == 'get_file_info': elif cmd == "get_file_info":
result = { result = {
'filename': self.filename, "filename": self.filename,
'lines': len(self.lines), "lines": len(self.lines),
'cursor': (self.cursor_y, self.cursor_x), "cursor": (self.cursor_y, self.cursor_x),
'mode': self.mode "mode": self.mode,
} }
try: try:
self.server_sock.send(pickle.dumps(result)) self.server_sock.send(pickle.dumps(result))
except: except:
pass pass
elif cmd == 'stop': elif cmd == "stop":
self.running = False self.running = False
def move_cursor_to(self, y, x): def move_cursor_to(self, y, x):
with self.lock: with self.lock:
self.cursor_y = max(0, min(y, len(self.lines)-1)) self.cursor_y = max(0, min(y, len(self.lines) - 1))
self.cursor_x = max(0, min(x, len(self.lines[self.cursor_y]))) self.cursor_x = max(0, min(x, len(self.lines[self.cursor_y])))
def get_line(self, line_num): def get_line(self, line_num):
@ -469,9 +487,9 @@ class RPEditor:
else: else:
first_part = self.lines[start_line][:start_col] first_part = self.lines[start_line][:start_col]
last_part = self.lines[end_line][end_col:] last_part = self.lines[end_line][end_col:]
new_lines = new_text.split('\n') new_lines = new_text.split("\n")
self.lines[start_line] = first_part + new_lines[0] self.lines[start_line] = first_part + new_lines[0]
del self.lines[start_line + 1:end_line + 1] del self.lines[start_line + 1 : end_line + 1]
for i, new_line in enumerate(new_lines[1:], 1): for i, new_line in enumerate(new_lines[1:], 1):
self.lines.insert(start_line + i, new_line) self.lines.insert(start_line + i, new_line)
if len(new_lines) > 1: if len(new_lines) > 1:
@ -511,7 +529,7 @@ class RPEditor:
for i in range(sl + 1, el): for i in range(sl + 1, el):
result.append(self.lines[i]) result.append(self.lines[i])
result.append(self.lines[el][:ec]) result.append(self.lines[el][:ec])
return '\n'.join(result) return "\n".join(result)
def delete_selection(self): def delete_selection(self):
with self.lock: with self.lock:
@ -537,24 +555,24 @@ class RPEditor:
break break
if match: if match:
indent = len(self.lines[i]) - len(self.lines[i].lstrip()) indent = len(self.lines[i]) - len(self.lines[i].lstrip())
indented_replace = [' ' * indent + line for line in replace_lines] indented_replace = [" " * indent + line for line in replace_lines]
self.lines[i:i+len(search_lines)] = indented_replace self.lines[i : i + len(search_lines)] = indented_replace
return True return True
return False return False
def apply_diff(self, diff_text): def apply_diff(self, diff_text):
with self.lock: with self.lock:
self.save_state() self.save_state()
lines = diff_text.split('\n') lines = diff_text.split("\n")
for line in lines: for line in lines:
if line.startswith('@@'): if line.startswith("@@"):
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line) match = re.search(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@", line)
if match: if match:
start_line = int(match.group(1)) - 1 start_line = int(match.group(1)) - 1
elif line.startswith('-'): elif line.startswith("-"):
if start_line < len(self.lines): if start_line < len(self.lines):
del self.lines[start_line] del self.lines[start_line]
elif line.startswith('+'): elif line.startswith("+"):
self.lines.insert(start_line, line[1:]) self.lines.insert(start_line, line[1:])
start_line += 1 start_line += 1
@ -574,6 +592,7 @@ class RPEditor:
if self.thread: if self.thread:
self.thread.join() self.thread.join()
def main(): def main():
filename = sys.argv[1] if len(sys.argv) > 1 else None filename = sys.argv[1] if len(sys.argv) > 1 else None
editor = RPEditor(filename) editor = RPEditor(filename)
@ -583,5 +602,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -3,14 +3,13 @@
Advanced input handler for PR Assistant with editor mode, file inclusion, and image support. Advanced input handler for PR Assistant with editor mode, file inclusion, and image support.
""" """
import os
import re
import base64 import base64
import mimetypes import mimetypes
import re
import readline import readline
import glob
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
# from pr.ui.colors import Colors # Avoid import issues # from pr.ui.colors import Colors # Avoid import issues
@ -29,7 +28,7 @@ class AdvancedInputHandler:
return None return None
readline.set_completer(completer) readline.set_completer(completer)
readline.parse_and_bind('tab: complete') readline.parse_and_bind("tab: complete")
except: except:
pass # Readline not available pass # Readline not available
@ -60,7 +59,7 @@ class AdvancedInputHandler:
return "" return ""
# Check for special commands # Check for special commands
if user_input.lower() == '/editor': if user_input.lower() == "/editor":
self.toggle_editor_mode() self.toggle_editor_mode()
return self.get_input(prompt) # Recurse to get new input return self.get_input(prompt) # Recurse to get new input
@ -74,23 +73,25 @@ class AdvancedInputHandler:
def _get_editor_input(self, prompt: str) -> Optional[str]: def _get_editor_input(self, prompt: str) -> Optional[str]:
"""Get multi-line input for editor mode.""" """Get multi-line input for editor mode."""
try: try:
print("Editor mode: Enter your message. Type 'END' on a new line to finish.") print(
"Editor mode: Enter your message. Type 'END' on a new line to finish."
)
print("Type '/simple' to switch back to simple mode.") print("Type '/simple' to switch back to simple mode.")
lines = [] lines = []
while True: while True:
try: try:
line = input() line = input()
if line.strip().lower() == 'end': if line.strip().lower() == "end":
break break
elif line.strip().lower() == '/simple': elif line.strip().lower() == "/simple":
self.toggle_editor_mode() self.toggle_editor_mode()
return self.get_input(prompt) # Switch back and get input return self.get_input(prompt) # Switch back and get input
lines.append(line) lines.append(line)
except EOFError: except EOFError:
break break
content = '\n'.join(lines).strip() content = "\n".join(lines).strip()
if not content: if not content:
return "" return ""
@ -114,12 +115,13 @@ class AdvancedInputHandler:
def _process_file_inclusions(self, text: str) -> str: def _process_file_inclusions(self, text: str) -> str:
"""Replace @[filename] with file contents.""" """Replace @[filename] with file contents."""
def replace_file(match): def replace_file(match):
filename = match.group(1).strip() filename = match.group(1).strip()
try: try:
path = Path(filename).expanduser().resolve() path = Path(filename).expanduser().resolve()
if path.exists() and path.is_file(): if path.exists() and path.is_file():
with open(path, 'r', encoding='utf-8', errors='replace') as f: with open(path, encoding="utf-8", errors="replace") as f:
content = f.read() content = f.read()
return f"\n--- File: {filename} ---\n{content}\n--- End of {filename} ---\n" return f"\n--- File: {filename} ---\n{content}\n--- End of {filename} ---\n"
else: else:
@ -128,7 +130,7 @@ class AdvancedInputHandler:
return f"[Error reading file {filename}: {e}]" return f"[Error reading file {filename}: {e}]"
# Replace @[filename] patterns # Replace @[filename] patterns
pattern = r'@\[([^\]]+)\]' pattern = r"@\[([^\]]+)\]"
return re.sub(pattern, replace_file, text) return re.sub(pattern, replace_file, text)
def _process_image_inclusions(self, text: str) -> str: def _process_image_inclusions(self, text: str) -> str:
@ -143,20 +145,22 @@ class AdvancedInputHandler:
path = Path(word.strip()).expanduser().resolve() path = Path(word.strip()).expanduser().resolve()
if path.exists() and path.is_file(): if path.exists() and path.is_file():
mime_type, _ = mimetypes.guess_type(str(path)) mime_type, _ = mimetypes.guess_type(str(path))
if mime_type and mime_type.startswith('image/'): if mime_type and mime_type.startswith("image/"):
# Encode image # Encode image
with open(path, 'rb') as f: with open(path, "rb") as f:
image_data = base64.b64encode(f.read()).decode('utf-8') image_data = base64.b64encode(f.read()).decode("utf-8")
# Replace with data URL # Replace with data URL
processed_parts.append(f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n") processed_parts.append(
f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n"
)
continue continue
except: except:
pass pass
processed_parts.append(word) processed_parts.append(word)
return ' '.join(processed_parts) return " ".join(processed_parts)
# Global instance # Global instance

View File

@ -1,7 +1,12 @@
from .knowledge_store import KnowledgeStore, KnowledgeEntry
from .semantic_index import SemanticIndex
from .conversation_memory import ConversationMemory from .conversation_memory import ConversationMemory
from .fact_extractor import FactExtractor from .fact_extractor import FactExtractor
from .knowledge_store import KnowledgeEntry, KnowledgeStore
from .semantic_index import SemanticIndex
__all__ = ['KnowledgeStore', 'KnowledgeEntry', 'SemanticIndex', __all__ = [
'ConversationMemory', 'FactExtractor'] "KnowledgeStore",
"KnowledgeEntry",
"SemanticIndex",
"ConversationMemory",
"FactExtractor",
]

View File

@ -1,7 +1,8 @@
import json import json
import sqlite3 import sqlite3
import time import time
from typing import List, Dict, Any, Optional from typing import Any, Dict, List, Optional
class ConversationMemory: class ConversationMemory:
def __init__(self, db_path: str): def __init__(self, db_path: str):
@ -12,7 +13,8 @@ class ConversationMemory:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS conversation_history ( CREATE TABLE IF NOT EXISTS conversation_history (
conversation_id TEXT PRIMARY KEY, conversation_id TEXT PRIMARY KEY,
session_id TEXT, session_id TEXT,
@ -23,9 +25,11 @@ class ConversationMemory:
topics TEXT, topics TEXT,
metadata TEXT metadata TEXT
) )
''') """
)
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS conversation_messages ( CREATE TABLE IF NOT EXISTS conversation_messages (
message_id TEXT PRIMARY KEY, message_id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL, conversation_id TEXT NOT NULL,
@ -36,117 +40,163 @@ class ConversationMemory:
metadata TEXT, metadata TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversation_history(conversation_id) FOREIGN KEY (conversation_id) REFERENCES conversation_history(conversation_id)
) )
''') """
)
cursor.execute(''' cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_conv_session ON conversation_history(session_id) CREATE INDEX IF NOT EXISTS idx_conv_session ON conversation_history(session_id)
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_conv_started ON conversation_history(started_at DESC) CREATE INDEX IF NOT EXISTS idx_conv_started ON conversation_history(started_at DESC)
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_msg_conversation ON conversation_messages(conversation_id) CREATE INDEX IF NOT EXISTS idx_msg_conversation ON conversation_messages(conversation_id)
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_msg_timestamp ON conversation_messages(timestamp) CREATE INDEX IF NOT EXISTS idx_msg_timestamp ON conversation_messages(timestamp)
''') """
)
conn.commit() conn.commit()
conn.close() conn.close()
def create_conversation(self, conversation_id: str, session_id: Optional[str] = None, def create_conversation(
metadata: Optional[Dict[str, Any]] = None): self,
conversation_id: str,
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
):
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
INSERT INTO conversation_history INSERT INTO conversation_history
(conversation_id, session_id, started_at, metadata) (conversation_id, session_id, started_at, metadata)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
''', ( """,
conversation_id, (
session_id, conversation_id,
time.time(), session_id,
json.dumps(metadata) if metadata else None time.time(),
)) json.dumps(metadata) if metadata else None,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
def add_message(self, conversation_id: str, message_id: str, role: str, def add_message(
content: str, tool_calls: Optional[List[Dict[str, Any]]] = None, self,
metadata: Optional[Dict[str, Any]] = None): conversation_id: str,
message_id: str,
role: str,
content: str,
tool_calls: Optional[List[Dict[str, Any]]] = None,
metadata: Optional[Dict[str, Any]] = None,
):
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
INSERT INTO conversation_messages INSERT INTO conversation_messages
(message_id, conversation_id, role, content, timestamp, tool_calls, metadata) (message_id, conversation_id, role, content, timestamp, tool_calls, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
''', ( """,
message_id, (
conversation_id, message_id,
role, conversation_id,
content, role,
time.time(), content,
json.dumps(tool_calls) if tool_calls else None, time.time(),
json.dumps(metadata) if metadata else None json.dumps(tool_calls) if tool_calls else None,
)) json.dumps(metadata) if metadata else None,
),
)
cursor.execute(''' cursor.execute(
"""
UPDATE conversation_history UPDATE conversation_history
SET message_count = message_count + 1 SET message_count = message_count + 1
WHERE conversation_id = ? WHERE conversation_id = ?
''', (conversation_id,)) """,
(conversation_id,),
)
conn.commit() conn.commit()
conn.close() conn.close()
def get_conversation_messages(self, conversation_id: str, def get_conversation_messages(
limit: Optional[int] = None) -> List[Dict[str, Any]]: self, conversation_id: str, limit: Optional[int] = None
) -> List[Dict[str, Any]]:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
if limit: if limit:
cursor.execute(''' cursor.execute(
"""
SELECT message_id, role, content, timestamp, tool_calls, metadata SELECT message_id, role, content, timestamp, tool_calls, metadata
FROM conversation_messages FROM conversation_messages
WHERE conversation_id = ? WHERE conversation_id = ?
ORDER BY timestamp DESC ORDER BY timestamp DESC
LIMIT ? LIMIT ?
''', (conversation_id, limit)) """,
(conversation_id, limit),
)
else: else:
cursor.execute(''' cursor.execute(
"""
SELECT message_id, role, content, timestamp, tool_calls, metadata SELECT message_id, role, content, timestamp, tool_calls, metadata
FROM conversation_messages FROM conversation_messages
WHERE conversation_id = ? WHERE conversation_id = ?
ORDER BY timestamp ASC ORDER BY timestamp ASC
''', (conversation_id,)) """,
(conversation_id,),
)
messages = [] messages = []
for row in cursor.fetchall(): for row in cursor.fetchall():
messages.append({ messages.append(
'message_id': row[0], {
'role': row[1], "message_id": row[0],
'content': row[2], "role": row[1],
'timestamp': row[3], "content": row[2],
'tool_calls': json.loads(row[4]) if row[4] else None, "timestamp": row[3],
'metadata': json.loads(row[5]) if row[5] else None "tool_calls": json.loads(row[4]) if row[4] else None,
}) "metadata": json.loads(row[5]) if row[5] else None,
}
)
conn.close() conn.close()
return messages return messages
def update_conversation_summary(self, conversation_id: str, summary: str, def update_conversation_summary(
topics: Optional[List[str]] = None): self, conversation_id: str, summary: str, topics: Optional[List[str]] = None
):
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
UPDATE conversation_history UPDATE conversation_history
SET summary = ?, topics = ?, ended_at = ? SET summary = ?, topics = ?, ended_at = ?
WHERE conversation_id = ? WHERE conversation_id = ?
''', (summary, json.dumps(topics) if topics else None, time.time(), conversation_id)) """,
(
summary,
json.dumps(topics) if topics else None,
time.time(),
conversation_id,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -155,7 +205,8 @@ class ConversationMemory:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
SELECT DISTINCT h.conversation_id, h.session_id, h.started_at, SELECT DISTINCT h.conversation_id, h.session_id, h.started_at,
h.message_count, h.summary, h.topics h.message_count, h.summary, h.topics
FROM conversation_history h FROM conversation_history h
@ -163,56 +214,69 @@ class ConversationMemory:
WHERE h.summary LIKE ? OR h.topics LIKE ? OR m.content LIKE ? WHERE h.summary LIKE ? OR h.topics LIKE ? OR m.content LIKE ?
ORDER BY h.started_at DESC ORDER BY h.started_at DESC
LIMIT ? LIMIT ?
''', (f'%{query}%', f'%{query}%', f'%{query}%', limit)) """,
(f"%{query}%", f"%{query}%", f"%{query}%", limit),
)
conversations = [] conversations = []
for row in cursor.fetchall(): for row in cursor.fetchall():
conversations.append({ conversations.append(
'conversation_id': row[0], {
'session_id': row[1], "conversation_id": row[0],
'started_at': row[2], "session_id": row[1],
'message_count': row[3], "started_at": row[2],
'summary': row[4], "message_count": row[3],
'topics': json.loads(row[5]) if row[5] else [] "summary": row[4],
}) "topics": json.loads(row[5]) if row[5] else [],
}
)
conn.close() conn.close()
return conversations return conversations
def get_recent_conversations(self, limit: int = 10, def get_recent_conversations(
session_id: Optional[str] = None) -> List[Dict[str, Any]]: self, limit: int = 10, session_id: Optional[str] = None
) -> List[Dict[str, Any]]:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
if session_id: if session_id:
cursor.execute(''' cursor.execute(
"""
SELECT conversation_id, session_id, started_at, ended_at, SELECT conversation_id, session_id, started_at, ended_at,
message_count, summary, topics message_count, summary, topics
FROM conversation_history FROM conversation_history
WHERE session_id = ? WHERE session_id = ?
ORDER BY started_at DESC ORDER BY started_at DESC
LIMIT ? LIMIT ?
''', (session_id, limit)) """,
(session_id, limit),
)
else: else:
cursor.execute(''' cursor.execute(
"""
SELECT conversation_id, session_id, started_at, ended_at, SELECT conversation_id, session_id, started_at, ended_at,
message_count, summary, topics message_count, summary, topics
FROM conversation_history FROM conversation_history
ORDER BY started_at DESC ORDER BY started_at DESC
LIMIT ? LIMIT ?
''', (limit,)) """,
(limit,),
)
conversations = [] conversations = []
for row in cursor.fetchall(): for row in cursor.fetchall():
conversations.append({ conversations.append(
'conversation_id': row[0], {
'session_id': row[1], "conversation_id": row[0],
'started_at': row[2], "session_id": row[1],
'ended_at': row[3], "started_at": row[2],
'message_count': row[4], "ended_at": row[3],
'summary': row[5], "message_count": row[4],
'topics': json.loads(row[6]) if row[6] else [] "summary": row[5],
}) "topics": json.loads(row[6]) if row[6] else [],
}
)
conn.close() conn.close()
return conversations return conversations
@ -221,10 +285,14 @@ class ConversationMemory:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('DELETE FROM conversation_messages WHERE conversation_id = ?', cursor.execute(
(conversation_id,)) "DELETE FROM conversation_messages WHERE conversation_id = ?",
cursor.execute('DELETE FROM conversation_history WHERE conversation_id = ?', (conversation_id,),
(conversation_id,)) )
cursor.execute(
"DELETE FROM conversation_history WHERE conversation_id = ?",
(conversation_id,),
)
deleted = cursor.rowcount > 0 deleted = cursor.rowcount > 0
conn.commit() conn.commit()
@ -236,24 +304,26 @@ class ConversationMemory:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM conversation_history') cursor.execute("SELECT COUNT(*) FROM conversation_history")
total_conversations = cursor.fetchone()[0] total_conversations = cursor.fetchone()[0]
cursor.execute('SELECT COUNT(*) FROM conversation_messages') cursor.execute("SELECT COUNT(*) FROM conversation_messages")
total_messages = cursor.fetchone()[0] total_messages = cursor.fetchone()[0]
cursor.execute('SELECT SUM(message_count) FROM conversation_history') cursor.execute("SELECT SUM(message_count) FROM conversation_history")
total_message_count = cursor.fetchone()[0] or 0 cursor.fetchone()[0] or 0
cursor.execute(''' cursor.execute(
"""
SELECT AVG(message_count) FROM conversation_history WHERE message_count > 0 SELECT AVG(message_count) FROM conversation_history WHERE message_count > 0
''') """
)
avg_messages = cursor.fetchone()[0] or 0 avg_messages = cursor.fetchone()[0] or 0
conn.close() conn.close()
return { return {
'total_conversations': total_conversations, "total_conversations": total_conversations,
'total_messages': total_messages, "total_messages": total_messages,
'average_messages_per_conversation': round(avg_messages, 2) "average_messages_per_conversation": round(avg_messages, 2),
} }

View File

@ -1,16 +1,16 @@
import re import re
import json
from typing import List, Dict, Any, Set
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List
class FactExtractor: class FactExtractor:
def __init__(self): def __init__(self):
self.fact_patterns = [ self.fact_patterns = [
(r'([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)', 'definition'), (r"([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)", "definition"),
(r'([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})', 'temporal'), (r"([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})", "temporal"),
(r'([A-Z][a-z]+) (invented|created|developed) ([^.]+)', 'attribution'), (r"([A-Z][a-z]+) (invented|created|developed) ([^.]+)", "attribution"),
(r'([^.]+) (costs?|worth) (\$[\d,]+)', 'numeric'), (r"([^.]+) (costs?|worth) (\$[\d,]+)", "numeric"),
(r'([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)', 'location'), (r"([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)", "location"),
] ]
def extract_facts(self, text: str) -> List[Dict[str, Any]]: def extract_facts(self, text: str) -> List[Dict[str, Any]]:
@ -19,27 +19,31 @@ class FactExtractor:
for pattern, fact_type in self.fact_patterns: for pattern, fact_type in self.fact_patterns:
matches = re.finditer(pattern, text) matches = re.finditer(pattern, text)
for match in matches: for match in matches:
facts.append({ facts.append(
'type': fact_type, {
'text': match.group(0), "type": fact_type,
'components': match.groups(), "text": match.group(0),
'confidence': 0.7 "components": match.groups(),
}) "confidence": 0.7,
}
)
noun_phrases = self._extract_noun_phrases(text) noun_phrases = self._extract_noun_phrases(text)
for phrase in noun_phrases: for phrase in noun_phrases:
if len(phrase.split()) >= 2: if len(phrase.split()) >= 2:
facts.append({ facts.append(
'type': 'entity', {
'text': phrase, "type": "entity",
'components': [phrase], "text": phrase,
'confidence': 0.5 "components": [phrase],
}) "confidence": 0.5,
}
)
return facts return facts
def _extract_noun_phrases(self, text: str) -> List[str]: def _extract_noun_phrases(self, text: str) -> List[str]:
sentences = re.split(r'[.!?]', text) sentences = re.split(r"[.!?]", text)
phrases = [] phrases = []
for sentence in sentences: for sentence in sentences:
@ -51,25 +55,73 @@ class FactExtractor:
current_phrase.append(word) current_phrase.append(word)
else: else:
if len(current_phrase) >= 2: if len(current_phrase) >= 2:
phrases.append(' '.join(current_phrase)) phrases.append(" ".join(current_phrase))
current_phrase = [] current_phrase = []
if len(current_phrase) >= 2: if len(current_phrase) >= 2:
phrases.append(' '.join(current_phrase)) phrases.append(" ".join(current_phrase))
return list(set(phrases)) return list(set(phrases))
def extract_key_terms(self, text: str, top_k: int = 10) -> List[tuple]: def extract_key_terms(self, text: str, top_k: int = 10) -> List[tuple]:
words = re.findall(r'\b[a-z]{4,}\b', text.lower()) words = re.findall(r"\b[a-z]{4,}\b", text.lower())
stopwords = { stopwords = {
'this', 'that', 'these', 'those', 'what', 'which', 'where', 'when', "this",
'with', 'from', 'have', 'been', 'were', 'will', 'would', 'could', "that",
'should', 'about', 'their', 'there', 'other', 'than', 'then', 'them', "these",
'some', 'more', 'very', 'such', 'into', 'through', 'during', 'before', "those",
'after', 'above', 'below', 'between', 'under', 'again', 'further', "what",
'once', 'here', 'both', 'each', 'doing', 'only', 'over', 'same', "which",
'being', 'does', 'just', 'also', 'make', 'made', 'know', 'like' "where",
"when",
"with",
"from",
"have",
"been",
"were",
"will",
"would",
"could",
"should",
"about",
"their",
"there",
"other",
"than",
"then",
"them",
"some",
"more",
"very",
"such",
"into",
"through",
"during",
"before",
"after",
"above",
"below",
"between",
"under",
"again",
"further",
"once",
"here",
"both",
"each",
"doing",
"only",
"over",
"same",
"being",
"does",
"just",
"also",
"make",
"made",
"know",
"like",
} }
filtered_words = [w for w in words if w not in stopwords] filtered_words = [w for w in words if w not in stopwords]
@ -85,57 +137,120 @@ class FactExtractor:
relationships = [] relationships = []
relationship_patterns = [ relationship_patterns = [
(r'([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)', 'employment'), (
(r'([A-Z][a-z]+) (owns|has|possesses) ([^.]+)', 'ownership'), r"([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)",
(r'([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)', 'location'), "employment",
(r'([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)', 'usage'), ),
(r"([A-Z][a-z]+) (owns|has|possesses) ([^.]+)", "ownership"),
(
r"([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)",
"location",
),
(r"([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)", "usage"),
] ]
for pattern, rel_type in relationship_patterns: for pattern, rel_type in relationship_patterns:
matches = re.finditer(pattern, text) matches = re.finditer(pattern, text)
for match in matches: for match in matches:
relationships.append({ relationships.append(
'type': rel_type, {
'subject': match.group(1), "type": rel_type,
'predicate': match.group(2), "subject": match.group(1),
'object': match.group(3), "predicate": match.group(2),
'confidence': 0.6 "object": match.group(3),
}) "confidence": 0.6,
}
)
return relationships return relationships
def extract_metadata(self, text: str) -> Dict[str, Any]: def extract_metadata(self, text: str) -> Dict[str, Any]:
word_count = len(text.split()) word_count = len(text.split())
sentence_count = len(re.split(r'[.!?]', text)) sentence_count = len(re.split(r"[.!?]", text))
urls = re.findall(r'https?://[^\s]+', text) urls = re.findall(r"https?://[^\s]+", text)
email_addresses = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text) email_addresses = re.findall(
dates = re.findall(r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b', text) r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text
numbers = re.findall(r'\b\d+(?:,\d{3})*(?:\.\d+)?\b', text) )
dates = re.findall(
r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text
)
numbers = re.findall(r"\b\d+(?:,\d{3})*(?:\.\d+)?\b", text)
return { return {
'word_count': word_count, "word_count": word_count,
'sentence_count': sentence_count, "sentence_count": sentence_count,
'avg_words_per_sentence': round(word_count / max(sentence_count, 1), 2), "avg_words_per_sentence": round(word_count / max(sentence_count, 1), 2),
'urls': urls, "urls": urls,
'email_addresses': email_addresses, "email_addresses": email_addresses,
'dates': dates, "dates": dates,
'numeric_values': numbers, "numeric_values": numbers,
'has_code': bool(re.search(r'```|def |class |import |function ', text)), "has_code": bool(re.search(r"```|def |class |import |function ", text)),
'has_questions': bool(re.search(r'\?', text)) "has_questions": bool(re.search(r"\?", text)),
} }
def categorize_content(self, text: str) -> List[str]: def categorize_content(self, text: str) -> List[str]:
categories = [] categories = []
category_keywords = { category_keywords = {
'programming': ['code', 'function', 'class', 'variable', 'programming', 'software', 'debug'], "programming": [
'data': ['data', 'database', 'query', 'table', 'record', 'statistics', 'analysis'], "code",
'documentation': ['documentation', 'guide', 'tutorial', 'manual', 'readme', 'explain'], "function",
'configuration': ['config', 'settings', 'configuration', 'setup', 'install', 'deployment'], "class",
'testing': ['test', 'testing', 'validate', 'verification', 'quality', 'assertion'], "variable",
'research': ['research', 'study', 'analysis', 'investigation', 'findings', 'results'], "programming",
'planning': ['plan', 'planning', 'schedule', 'roadmap', 'milestone', 'timeline'], "software",
"debug",
],
"data": [
"data",
"database",
"query",
"table",
"record",
"statistics",
"analysis",
],
"documentation": [
"documentation",
"guide",
"tutorial",
"manual",
"readme",
"explain",
],
"configuration": [
"config",
"settings",
"configuration",
"setup",
"install",
"deployment",
],
"testing": [
"test",
"testing",
"validate",
"verification",
"quality",
"assertion",
],
"research": [
"research",
"study",
"analysis",
"investigation",
"findings",
"results",
],
"planning": [
"plan",
"planning",
"schedule",
"roadmap",
"milestone",
"timeline",
],
} }
text_lower = text.lower() text_lower = text.lower()
@ -143,4 +258,4 @@ class FactExtractor:
if any(keyword in text_lower for keyword in keywords): if any(keyword in text_lower for keyword in keywords):
categories.append(category) categories.append(category)
return categories if categories else ['general'] return categories if categories else ["general"]

View File

@ -1,10 +1,12 @@
import json import json
import sqlite3 import sqlite3
import time import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from .semantic_index import SemanticIndex from .semantic_index import SemanticIndex
@dataclass @dataclass
class KnowledgeEntry: class KnowledgeEntry:
entry_id: str entry_id: str
@ -18,16 +20,17 @@ class KnowledgeEntry:
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
'entry_id': self.entry_id, "entry_id": self.entry_id,
'category': self.category, "category": self.category,
'content': self.content, "content": self.content,
'metadata': self.metadata, "metadata": self.metadata,
'created_at': self.created_at, "created_at": self.created_at,
'updated_at': self.updated_at, "updated_at": self.updated_at,
'access_count': self.access_count, "access_count": self.access_count,
'importance_score': self.importance_score "importance_score": self.importance_score,
} }
class KnowledgeStore: class KnowledgeStore:
def __init__(self, db_path: str): def __init__(self, db_path: str):
self.db_path = db_path self.db_path = db_path
@ -39,7 +42,8 @@ class KnowledgeStore:
def _initialize_store(self): def _initialize_store(self):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS knowledge_entries ( CREATE TABLE IF NOT EXISTS knowledge_entries (
entry_id TEXT PRIMARY KEY, entry_id TEXT PRIMARY KEY,
category TEXT NOT NULL, category TEXT NOT NULL,
@ -50,44 +54,54 @@ class KnowledgeStore:
access_count INTEGER DEFAULT 0, access_count INTEGER DEFAULT 0,
importance_score REAL DEFAULT 1.0 importance_score REAL DEFAULT 1.0
) )
''') """
)
cursor.execute(''' cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_category ON knowledge_entries(category) CREATE INDEX IF NOT EXISTS idx_category ON knowledge_entries(category)
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_importance ON knowledge_entries(importance_score DESC) CREATE INDEX IF NOT EXISTS idx_importance ON knowledge_entries(importance_score DESC)
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC) CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC)
''') """
)
self.conn.commit() self.conn.commit()
def _load_index(self): def _load_index(self):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('SELECT entry_id, content FROM knowledge_entries') cursor.execute("SELECT entry_id, content FROM knowledge_entries")
for row in cursor.fetchall(): for row in cursor.fetchall():
self.semantic_index.add_document(row[0], row[1]) self.semantic_index.add_document(row[0], row[1])
def add_entry(self, entry: KnowledgeEntry): def add_entry(self, entry: KnowledgeEntry):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute(
"""
INSERT OR REPLACE INTO knowledge_entries INSERT OR REPLACE INTO knowledge_entries
(entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score) (entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', ( """,
entry.entry_id, (
entry.category, entry.entry_id,
entry.content, entry.category,
json.dumps(entry.metadata), entry.content,
entry.created_at, json.dumps(entry.metadata),
entry.updated_at, entry.created_at,
entry.access_count, entry.updated_at,
entry.importance_score entry.access_count,
)) entry.importance_score,
),
)
self.conn.commit() self.conn.commit()
@ -96,20 +110,26 @@ class KnowledgeStore:
def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]: def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute(
"""
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries FROM knowledge_entries
WHERE entry_id = ? WHERE entry_id = ?
''', (entry_id,)) """,
(entry_id,),
)
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
cursor.execute(''' cursor.execute(
"""
UPDATE knowledge_entries UPDATE knowledge_entries
SET access_count = access_count + 1 SET access_count = access_count + 1
WHERE entry_id = ? WHERE entry_id = ?
''', (entry_id,)) """,
(entry_id,),
)
self.conn.commit() self.conn.commit()
return KnowledgeEntry( return KnowledgeEntry(
@ -120,13 +140,14 @@ class KnowledgeStore:
created_at=row[4], created_at=row[4],
updated_at=row[5], updated_at=row[5],
access_count=row[6] + 1, access_count=row[6] + 1,
importance_score=row[7] importance_score=row[7],
) )
return None return None
def search_entries(self, query: str, category: Optional[str] = None, def search_entries(
top_k: int = 5) -> List[KnowledgeEntry]: self, query: str, category: Optional[str] = None, top_k: int = 5
) -> List[KnowledgeEntry]:
search_results = self.semantic_index.search(query, top_k * 2) search_results = self.semantic_index.search(query, top_k * 2)
cursor = self.conn.cursor() cursor = self.conn.cursor()
@ -134,17 +155,23 @@ class KnowledgeStore:
entries = [] entries = []
for entry_id, score in search_results: for entry_id, score in search_results:
if category: if category:
cursor.execute(''' cursor.execute(
"""
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries FROM knowledge_entries
WHERE entry_id = ? AND category = ? WHERE entry_id = ? AND category = ?
''', (entry_id, category)) """,
(entry_id, category),
)
else: else:
cursor.execute(''' cursor.execute(
"""
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries FROM knowledge_entries
WHERE entry_id = ? WHERE entry_id = ?
''', (entry_id,)) """,
(entry_id,),
)
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
@ -156,7 +183,7 @@ class KnowledgeStore:
created_at=row[4], created_at=row[4],
updated_at=row[5], updated_at=row[5],
access_count=row[6], access_count=row[6],
importance_score=row[7] importance_score=row[7],
) )
entries.append(entry) entries.append(entry)
@ -168,44 +195,52 @@ class KnowledgeStore:
def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]: def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute(
"""
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries FROM knowledge_entries
WHERE category = ? WHERE category = ?
ORDER BY importance_score DESC, created_at DESC ORDER BY importance_score DESC, created_at DESC
LIMIT ? LIMIT ?
''', (category, limit)) """,
(category, limit),
)
entries = [] entries = []
for row in cursor.fetchall(): for row in cursor.fetchall():
entries.append(KnowledgeEntry( entries.append(
entry_id=row[0], KnowledgeEntry(
category=row[1], entry_id=row[0],
content=row[2], category=row[1],
metadata=json.loads(row[3]) if row[3] else {}, content=row[2],
created_at=row[4], metadata=json.loads(row[3]) if row[3] else {},
updated_at=row[5], created_at=row[4],
access_count=row[6], updated_at=row[5],
importance_score=row[7] access_count=row[6],
)) importance_score=row[7],
)
)
return entries return entries
def update_importance(self, entry_id: str, importance_score: float): def update_importance(self, entry_id: str, importance_score: float):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute(''' cursor.execute(
"""
UPDATE knowledge_entries UPDATE knowledge_entries
SET importance_score = ?, updated_at = ? SET importance_score = ?, updated_at = ?
WHERE entry_id = ? WHERE entry_id = ?
''', (importance_score, time.time(), entry_id)) """,
(importance_score, time.time(), entry_id),
)
self.conn.commit() self.conn.commit()
def delete_entry(self, entry_id: str) -> bool: def delete_entry(self, entry_id: str) -> bool:
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,)) cursor.execute("DELETE FROM knowledge_entries WHERE entry_id = ?", (entry_id,))
deleted = cursor.rowcount > 0 deleted = cursor.rowcount > 0
self.conn.commit() self.conn.commit()
@ -218,27 +253,29 @@ class KnowledgeStore:
def get_statistics(self) -> Dict[str, Any]: def get_statistics(self) -> Dict[str, Any]:
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('SELECT COUNT(*) FROM knowledge_entries') cursor.execute("SELECT COUNT(*) FROM knowledge_entries")
total_entries = cursor.fetchone()[0] total_entries = cursor.fetchone()[0]
cursor.execute('SELECT COUNT(DISTINCT category) FROM knowledge_entries') cursor.execute("SELECT COUNT(DISTINCT category) FROM knowledge_entries")
total_categories = cursor.fetchone()[0] total_categories = cursor.fetchone()[0]
cursor.execute(''' cursor.execute(
"""
SELECT category, COUNT(*) as count SELECT category, COUNT(*) as count
FROM knowledge_entries FROM knowledge_entries
GROUP BY category GROUP BY category
ORDER BY count DESC ORDER BY count DESC
''') """
)
category_counts = {row[0]: row[1] for row in cursor.fetchall()} category_counts = {row[0]: row[1] for row in cursor.fetchall()}
cursor.execute('SELECT SUM(access_count) FROM knowledge_entries') cursor.execute("SELECT SUM(access_count) FROM knowledge_entries")
total_accesses = cursor.fetchone()[0] or 0 total_accesses = cursor.fetchone()[0] or 0
return { return {
'total_entries': total_entries, "total_entries": total_entries,
'total_categories': total_categories, "total_categories": total_categories,
'category_distribution': category_counts, "category_distribution": category_counts,
'total_accesses': total_accesses, "total_accesses": total_accesses,
'vocabulary_size': len(self.semantic_index.vocabulary) "vocabulary_size": len(self.semantic_index.vocabulary),
} }

View File

@ -1,7 +1,8 @@
import math import math
import re import re
from collections import Counter, defaultdict from collections import Counter, defaultdict
from typing import List, Dict, Tuple, Set from typing import Dict, List, Set, Tuple
class SemanticIndex: class SemanticIndex:
def __init__(self): def __init__(self):
@ -12,7 +13,7 @@ class SemanticIndex:
def _tokenize(self, text: str) -> List[str]: def _tokenize(self, text: str) -> List[str]:
text = text.lower() text = text.lower()
text = re.sub(r'[^a-z0-9\s]', ' ', text) text = re.sub(r"[^a-z0-9\s]", " ", text)
tokens = text.split() tokens = text.split()
return tokens return tokens
@ -78,8 +79,12 @@ class SemanticIndex:
scores.sort(key=lambda x: x[1], reverse=True) scores.sort(key=lambda x: x[1], reverse=True)
return scores[:top_k] return scores[:top_k]
def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float: def _cosine_similarity(
dot_product = sum(vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2)) self, vec1: Dict[str, float], vec2: Dict[str, float]
) -> float:
dot_product = sum(
vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2)
)
norm1 = math.sqrt(sum(val**2 for val in vec1.values())) norm1 = math.sqrt(sum(val**2 for val in vec1.values()))
norm2 = math.sqrt(sum(val**2 for val in vec2.values())) norm2 = math.sqrt(sum(val**2 for val in vec2.values()))
if norm1 == 0 or norm2 == 0: if norm1 == 0 or norm2 == 0:

View File

@ -1,14 +1,13 @@
import threading
import queue import queue
import time
import sys
import subprocess import subprocess
import signal import sys
import os import threading
from pr.ui import Colors import time
from collections import defaultdict
from pr.tools.process_handlers import get_handler_for_process, detect_process_type from pr.tools.process_handlers import detect_process_type, get_handler_for_process
from pr.tools.prompt_detection import get_global_detector from pr.tools.prompt_detection import get_global_detector
from pr.ui import Colors
class TerminalMultiplexer: class TerminalMultiplexer:
def __init__(self, name, show_output=True): def __init__(self, name, show_output=True):
@ -21,17 +20,19 @@ class TerminalMultiplexer:
self.active = True self.active = True
self.lock = threading.Lock() self.lock = threading.Lock()
self.metadata = { self.metadata = {
'start_time': time.time(), "start_time": time.time(),
'last_activity': time.time(), "last_activity": time.time(),
'interaction_count': 0, "interaction_count": 0,
'process_type': 'unknown', "process_type": "unknown",
'state': 'active' "state": "active",
} }
self.handler = None self.handler = None
self.prompt_detector = get_global_detector() self.prompt_detector = get_global_detector()
if self.show_output: if self.show_output:
self.display_thread = threading.Thread(target=self._display_worker, daemon=True) self.display_thread = threading.Thread(
target=self._display_worker, daemon=True
)
self.display_thread.start() self.display_thread.start()
def _display_worker(self): def _display_worker(self):
@ -47,7 +48,9 @@ class TerminalMultiplexer:
try: try:
line = self.stderr_queue.get(timeout=0.1) line = self.stderr_queue.get(timeout=0.1)
if line: if line:
sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}") sys.stderr.write(
f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}"
)
sys.stderr.flush() sys.stderr.flush()
except queue.Empty: except queue.Empty:
pass pass
@ -55,40 +58,44 @@ class TerminalMultiplexer:
def write_stdout(self, data): def write_stdout(self, data):
with self.lock: with self.lock:
self.stdout_buffer.append(data) self.stdout_buffer.append(data)
self.metadata['last_activity'] = time.time() self.metadata["last_activity"] = time.time()
# Update handler state if available # Update handler state if available
if self.handler: if self.handler:
self.handler.update_state(data) self.handler.update_state(data)
# Update prompt detector # Update prompt detector
self.prompt_detector.update_session_state(self.name, data, self.metadata['process_type']) self.prompt_detector.update_session_state(
self.name, data, self.metadata["process_type"]
)
if self.show_output: if self.show_output:
self.stdout_queue.put(data) self.stdout_queue.put(data)
def write_stderr(self, data): def write_stderr(self, data):
with self.lock: with self.lock:
self.stderr_buffer.append(data) self.stderr_buffer.append(data)
self.metadata['last_activity'] = time.time() self.metadata["last_activity"] = time.time()
# Update handler state if available # Update handler state if available
if self.handler: if self.handler:
self.handler.update_state(data) self.handler.update_state(data)
# Update prompt detector # Update prompt detector
self.prompt_detector.update_session_state(self.name, data, self.metadata['process_type']) self.prompt_detector.update_session_state(
self.name, data, self.metadata["process_type"]
)
if self.show_output: if self.show_output:
self.stderr_queue.put(data) self.stderr_queue.put(data)
def get_stdout(self): def get_stdout(self):
with self.lock: with self.lock:
return ''.join(self.stdout_buffer) return "".join(self.stdout_buffer)
def get_stderr(self): def get_stderr(self):
with self.lock: with self.lock:
return ''.join(self.stderr_buffer) return "".join(self.stderr_buffer)
def get_all_output(self): def get_all_output(self):
with self.lock: with self.lock:
return { return {
'stdout': ''.join(self.stdout_buffer), "stdout": "".join(self.stdout_buffer),
'stderr': ''.join(self.stderr_buffer) "stderr": "".join(self.stderr_buffer),
} }
def get_metadata(self): def get_metadata(self):
@ -102,31 +109,32 @@ class TerminalMultiplexer:
def set_process_type(self, process_type): def set_process_type(self, process_type):
"""Set the process type and initialize appropriate handler.""" """Set the process type and initialize appropriate handler."""
with self.lock: with self.lock:
self.metadata['process_type'] = process_type self.metadata["process_type"] = process_type
self.handler = get_handler_for_process(process_type, self) self.handler = get_handler_for_process(process_type, self)
def send_input(self, input_data): def send_input(self, input_data):
if hasattr(self, 'process') and self.process.poll() is None: if hasattr(self, "process") and self.process.poll() is None:
try: try:
self.process.stdin.write(input_data + '\n') self.process.stdin.write(input_data + "\n")
self.process.stdin.flush() self.process.stdin.flush()
with self.lock: with self.lock:
self.metadata['last_activity'] = time.time() self.metadata["last_activity"] = time.time()
self.metadata['interaction_count'] += 1 self.metadata["interaction_count"] += 1
except Exception as e: except Exception as e:
self.write_stderr(f"Error sending input: {e}") self.write_stderr(f"Error sending input: {e}")
else: else:
# This will be implemented when we have a process attached # This will be implemented when we have a process attached
# For now, just update activity # For now, just update activity
with self.lock: with self.lock:
self.metadata['last_activity'] = time.time() self.metadata["last_activity"] = time.time()
self.metadata['interaction_count'] += 1 self.metadata["interaction_count"] += 1
def close(self): def close(self):
self.active = False self.active = False
if hasattr(self, 'display_thread'): if hasattr(self, "display_thread"):
self.display_thread.join(timeout=1) self.display_thread.join(timeout=1)
_multiplexers = {} _multiplexers = {}
_mux_counter = 0 _mux_counter = 0
_mux_lock = threading.Lock() _mux_lock = threading.Lock()
@ -134,6 +142,7 @@ _background_monitor = None
_monitor_active = False _monitor_active = False
_monitor_interval = 0.2 # 200ms _monitor_interval = 0.2 # 200ms
def create_multiplexer(name=None, show_output=True): def create_multiplexer(name=None, show_output=True):
global _mux_counter global _mux_counter
with _mux_lock: with _mux_lock:
@ -144,44 +153,50 @@ def create_multiplexer(name=None, show_output=True):
_multiplexers[name] = mux _multiplexers[name] = mux
return name, mux return name, mux
def get_multiplexer(name): def get_multiplexer(name):
return _multiplexers.get(name) return _multiplexers.get(name)
def close_multiplexer(name): def close_multiplexer(name):
mux = _multiplexers.get(name) mux = _multiplexers.get(name)
if mux: if mux:
mux.close() mux.close()
del _multiplexers[name] del _multiplexers[name]
def get_all_multiplexer_states(): def get_all_multiplexer_states():
with _mux_lock: with _mux_lock:
states = {} states = {}
for name, mux in _multiplexers.items(): for name, mux in _multiplexers.items():
states[name] = { states[name] = {
'metadata': mux.get_metadata(), "metadata": mux.get_metadata(),
'output_summary': { "output_summary": {
'stdout_lines': len(mux.stdout_buffer), "stdout_lines": len(mux.stdout_buffer),
'stderr_lines': len(mux.stderr_buffer) "stderr_lines": len(mux.stderr_buffer),
} },
} }
return states return states
def cleanup_all_multiplexers(): def cleanup_all_multiplexers():
for mux in list(_multiplexers.values()): for mux in list(_multiplexers.values()):
mux.close() mux.close()
_multiplexers.clear() _multiplexers.clear()
# Background process management # Background process management
_background_processes = {} _background_processes = {}
_process_lock = threading.Lock() _process_lock = threading.Lock()
class BackgroundProcess: class BackgroundProcess:
def __init__(self, name, command): def __init__(self, name, command):
self.name = name self.name = name
self.command = command self.command = command
self.process = None self.process = None
self.multiplexer = None self.multiplexer = None
self.status = 'starting' self.status = "starting"
self.start_time = time.time() self.start_time = time.time()
self.end_time = None self.end_time = None
@ -205,27 +220,27 @@ class BackgroundProcess:
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
bufsize=1, bufsize=1,
universal_newlines=True universal_newlines=True,
) )
self.status = 'running' self.status = "running"
# Start output monitoring threads # Start output monitoring threads
threading.Thread(target=self._monitor_stdout, daemon=True).start() threading.Thread(target=self._monitor_stdout, daemon=True).start()
threading.Thread(target=self._monitor_stderr, daemon=True).start() threading.Thread(target=self._monitor_stderr, daemon=True).start()
return {'status': 'success', 'pid': self.process.pid} return {"status": "success", "pid": self.process.pid}
except Exception as e: except Exception as e:
self.status = 'error' self.status = "error"
return {'status': 'error', 'error': str(e)} return {"status": "error", "error": str(e)}
def _monitor_stdout(self): def _monitor_stdout(self):
"""Monitor stdout from the process.""" """Monitor stdout from the process."""
try: try:
for line in iter(self.process.stdout.readline, ''): for line in iter(self.process.stdout.readline, ""):
if line: if line:
self.multiplexer.write_stdout(line.rstrip('\n\r')) self.multiplexer.write_stdout(line.rstrip("\n\r"))
except Exception as e: except Exception as e:
self.write_stderr(f"Error reading stdout: {e}") self.write_stderr(f"Error reading stdout: {e}")
finally: finally:
@ -234,29 +249,33 @@ class BackgroundProcess:
def _monitor_stderr(self): def _monitor_stderr(self):
"""Monitor stderr from the process.""" """Monitor stderr from the process."""
try: try:
for line in iter(self.process.stderr.readline, ''): for line in iter(self.process.stderr.readline, ""):
if line: if line:
self.multiplexer.write_stderr(line.rstrip('\n\r')) self.multiplexer.write_stderr(line.rstrip("\n\r"))
except Exception as e: except Exception as e:
self.write_stderr(f"Error reading stderr: {e}") self.write_stderr(f"Error reading stderr: {e}")
def _check_completion(self): def _check_completion(self):
"""Check if process has completed.""" """Check if process has completed."""
if self.process and self.process.poll() is not None: if self.process and self.process.poll() is not None:
self.status = 'completed' self.status = "completed"
self.end_time = time.time() self.end_time = time.time()
def get_info(self): def get_info(self):
"""Get process information.""" """Get process information."""
self._check_completion() self._check_completion()
return { return {
'name': self.name, "name": self.name,
'command': self.command, "command": self.command,
'status': self.status, "status": self.status,
'pid': self.process.pid if self.process else None, "pid": self.process.pid if self.process else None,
'start_time': self.start_time, "start_time": self.start_time,
'end_time': self.end_time, "end_time": self.end_time,
'runtime': time.time() - self.start_time if not self.end_time else self.end_time - self.start_time "runtime": (
time.time() - self.start_time
if not self.end_time
else self.end_time - self.start_time
),
} }
def get_output(self, lines=None): def get_output(self, lines=None):
@ -265,8 +284,8 @@ class BackgroundProcess:
return [] return []
all_output = self.multiplexer.get_all_output() all_output = self.multiplexer.get_all_output()
stdout_lines = all_output['stdout'].split('\n') if all_output['stdout'] else [] stdout_lines = all_output["stdout"].split("\n") if all_output["stdout"] else []
stderr_lines = all_output['stderr'].split('\n') if all_output['stderr'] else [] stderr_lines = all_output["stderr"].split("\n") if all_output["stderr"] else []
combined = stdout_lines + stderr_lines combined = stdout_lines + stderr_lines
if lines: if lines:
@ -276,45 +295,47 @@ class BackgroundProcess:
def send_input(self, input_text): def send_input(self, input_text):
"""Send input to the process.""" """Send input to the process."""
if self.process and self.status == 'running': if self.process and self.status == "running":
try: try:
self.process.stdin.write(input_text + '\n') self.process.stdin.write(input_text + "\n")
self.process.stdin.flush() self.process.stdin.flush()
return {'status': 'success'} return {"status": "success"}
except Exception as e: except Exception as e:
return {'status': 'error', 'error': str(e)} return {"status": "error", "error": str(e)}
return {'status': 'error', 'error': 'Process not running or no stdin'} return {"status": "error", "error": "Process not running or no stdin"}
def kill(self): def kill(self):
"""Kill the process.""" """Kill the process."""
if self.process and self.status == 'running': if self.process and self.status == "running":
try: try:
self.process.terminate() self.process.terminate()
# Wait a bit for graceful termination # Wait a bit for graceful termination
time.sleep(0.1) time.sleep(0.1)
if self.process.poll() is None: if self.process.poll() is None:
self.process.kill() self.process.kill()
self.status = 'killed' self.status = "killed"
self.end_time = time.time() self.end_time = time.time()
return {'status': 'success'} return {"status": "success"}
except Exception as e: except Exception as e:
return {'status': 'error', 'error': str(e)} return {"status": "error", "error": str(e)}
return {'status': 'error', 'error': 'Process not running'} return {"status": "error", "error": "Process not running"}
def start_background_process(name, command): def start_background_process(name, command):
"""Start a background process.""" """Start a background process."""
with _process_lock: with _process_lock:
if name in _background_processes: if name in _background_processes:
return {'status': 'error', 'error': f'Process {name} already exists'} return {"status": "error", "error": f"Process {name} already exists"}
process = BackgroundProcess(name, command) process = BackgroundProcess(name, command)
result = process.start() result = process.start()
if result['status'] == 'success': if result["status"] == "success":
_background_processes[name] = process _background_processes[name] = process
return result return result
def get_all_sessions(): def get_all_sessions():
"""Get all background process sessions.""" """Get all background process sessions."""
with _process_lock: with _process_lock:
@ -323,23 +344,31 @@ def get_all_sessions():
sessions[name] = process.get_info() sessions[name] = process.get_info()
return sessions return sessions
def get_session_info(name): def get_session_info(name):
"""Get information about a specific session.""" """Get information about a specific session."""
with _process_lock: with _process_lock:
process = _background_processes.get(name) process = _background_processes.get(name)
return process.get_info() if process else None return process.get_info() if process else None
def get_session_output(name, lines=None): def get_session_output(name, lines=None):
"""Get output from a specific session.""" """Get output from a specific session."""
with _process_lock: with _process_lock:
process = _background_processes.get(name) process = _background_processes.get(name)
return process.get_output(lines) if process else None return process.get_output(lines) if process else None
def send_input_to_session(name, input_text): def send_input_to_session(name, input_text):
"""Send input to a background session.""" """Send input to a background session."""
with _process_lock: with _process_lock:
process = _background_processes.get(name) process = _background_processes.get(name)
return process.send_input(input_text) if process else {'status': 'error', 'error': 'Session not found'} return (
process.send_input(input_text)
if process
else {"status": "error", "error": "Session not found"}
)
def kill_session(name): def kill_session(name):
"""Kill a background session.""" """Kill a background session."""
@ -347,7 +376,7 @@ def kill_session(name):
process = _background_processes.get(name) process = _background_processes.get(name)
if process: if process:
result = process.kill() result = process.kill()
if result['status'] == 'success': if result["status"] == "success":
del _background_processes[name] del _background_processes[name]
return result return result
return {'status': 'error', 'error': 'Session not found'} return {"status": "error", "error": "Session not found"}

View File

@ -1,10 +1,11 @@
import importlib.util
import os import os
import sys import sys
import importlib.util from typing import Callable, Dict, List
from typing import List, Dict, Callable, Any
from pr.core.logging import get_logger from pr.core.logging import get_logger
logger = get_logger('plugins') logger = get_logger("plugins")
PLUGINS_DIR = os.path.expanduser("~/.pr/plugins") PLUGINS_DIR = os.path.expanduser("~/.pr/plugins")
@ -21,7 +22,7 @@ class PluginLoader:
logger.info("No plugins directory found") logger.info("No plugins directory found")
return [] return []
plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith('.py')] plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith(".py")]
for plugin_file in plugin_files: for plugin_file in plugin_files:
try: try:
@ -44,16 +45,20 @@ class PluginLoader:
sys.modules[plugin_name] = module sys.modules[plugin_name] = module
spec.loader.exec_module(module) spec.loader.exec_module(module)
if hasattr(module, 'register_tools'): if hasattr(module, "register_tools"):
tools = module.register_tools() tools = module.register_tools()
if isinstance(tools, list): if isinstance(tools, list):
self.plugin_tools.extend(tools) self.plugin_tools.extend(tools)
self.loaded_plugins[plugin_name] = module self.loaded_plugins[plugin_name] = module
logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)") logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)")
else: else:
logger.warning(f"Plugin {plugin_name} register_tools() did not return a list") logger.warning(
f"Plugin {plugin_name} register_tools() did not return a list"
)
else: else:
logger.warning(f"Plugin {plugin_name} does not have register_tools() function") logger.warning(
f"Plugin {plugin_name} does not have register_tools() function"
)
def get_plugin_function(self, tool_name: str) -> Callable: def get_plugin_function(self, tool_name: str) -> Callable:
for plugin_name, module in self.loaded_plugins.items(): for plugin_name, module in self.loaded_plugins.items():
@ -67,7 +72,7 @@ class PluginLoader:
def create_example_plugin(): def create_example_plugin():
example_plugin = os.path.join(PLUGINS_DIR, 'example_plugin.py') example_plugin = os.path.join(PLUGINS_DIR, "example_plugin.py")
if os.path.exists(example_plugin): if os.path.exists(example_plugin):
return return
@ -121,7 +126,7 @@ def register_tools():
try: try:
os.makedirs(PLUGINS_DIR, exist_ok=True) os.makedirs(PLUGINS_DIR, exist_ok=True)
with open(example_plugin, 'w') as f: with open(example_plugin, "w") as f:
f.write(example_code) f.write(example_code)
logger.info(f"Created example plugin at {example_plugin}") logger.info(f"Created example plugin at {example_plugin}")
except Exception as e: except Exception as e:

View File

@ -1,25 +1,86 @@
from pr.tools.base import get_tools_definition from pr.tools.agents import (
from pr.tools.filesystem import ( collaborate_agents,
read_file, write_file, list_directory, mkdir, chdir, getpwd, index_source_directory, search_replace create_agent,
execute_agent_task,
list_agents,
remove_agent,
)
from pr.tools.base import get_tools_definition
from pr.tools.command import (
kill_process,
run_command,
run_command_interactive,
tail_process,
)
from pr.tools.database import db_get, db_query, db_set
from pr.tools.editor import (
close_editor,
editor_insert_text,
editor_replace_text,
editor_search,
open_editor,
)
from pr.tools.filesystem import (
chdir,
getpwd,
index_source_directory,
list_directory,
mkdir,
read_file,
search_replace,
write_file,
)
from pr.tools.memory import (
add_knowledge_entry,
delete_knowledge_entry,
get_knowledge_by_category,
get_knowledge_entry,
get_knowledge_statistics,
search_knowledge,
update_knowledge_importance,
) )
from pr.tools.command import run_command, run_command_interactive, tail_process, kill_process
from pr.tools.editor import open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor
from pr.tools.database import db_set, db_get, db_query
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.tools.python_exec import python_exec
from pr.tools.patch import apply_patch, create_diff from pr.tools.patch import apply_patch, create_diff
from pr.tools.agents import create_agent, list_agents, execute_agent_task, remove_agent, collaborate_agents from pr.tools.python_exec import python_exec
from pr.tools.memory import add_knowledge_entry, get_knowledge_entry, search_knowledge, get_knowledge_by_category, update_knowledge_importance, delete_knowledge_entry, get_knowledge_statistics from pr.tools.web import http_fetch, web_search, web_search_news
__all__ = [ __all__ = [
'get_tools_definition', "get_tools_definition",
'read_file', 'write_file', 'list_directory', 'mkdir', 'chdir', 'getpwd', 'index_source_directory', 'search_replace', "read_file",
'open_editor', 'editor_insert_text', 'editor_replace_text', 'editor_search','close_editor', "write_file",
'run_command', 'run_command_interactive', "list_directory",
'db_set', 'db_get', 'db_query', "mkdir",
'http_fetch', 'web_search', 'web_search_news', "chdir",
'python_exec','tail_process', 'kill_process', "getpwd",
'apply_patch', 'create_diff', "index_source_directory",
'create_agent', 'list_agents', 'execute_agent_task', 'remove_agent', 'collaborate_agents', "search_replace",
'add_knowledge_entry', 'get_knowledge_entry', 'search_knowledge', 'get_knowledge_by_category', 'update_knowledge_importance', 'delete_knowledge_entry', 'get_knowledge_statistics' "open_editor",
"editor_insert_text",
"editor_replace_text",
"editor_search",
"close_editor",
"run_command",
"run_command_interactive",
"db_set",
"db_get",
"db_query",
"http_fetch",
"web_search",
"web_search_news",
"python_exec",
"tail_process",
"kill_process",
"apply_patch",
"create_diff",
"create_agent",
"list_agents",
"execute_agent_task",
"remove_agent",
"collaborate_agents",
"add_knowledge_entry",
"get_knowledge_entry",
"search_knowledge",
"get_knowledge_by_category",
"update_knowledge_importance",
"delete_knowledge_entry",
"get_knowledge_statistics",
] ]

View File

@ -1,13 +1,15 @@
import os import os
from typing import Dict, Any, List from typing import Any, Dict, List
from pr.agents.agent_manager import AgentManager from pr.agents.agent_manager import AgentManager
from pr.core.api import call_api from pr.core.api import call_api
def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]: def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]:
"""Create a new agent with the specified role.""" """Create a new agent with the specified role."""
try: try:
# Get db_path from environment or default # Get db_path from environment or default
db_path = os.environ.get('ASSISTANT_DB_PATH', '~/.assistant_db.sqlite') db_path = os.environ.get("ASSISTANT_DB_PATH", "~/.assistant_db.sqlite")
db_path = os.path.expanduser(db_path) db_path = os.path.expanduser(db_path)
manager = AgentManager(db_path, call_api) manager = AgentManager(db_path, call_api)
@ -16,47 +18,57 @@ def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]:
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def list_agents() -> Dict[str, Any]: def list_agents() -> Dict[str, Any]:
"""List all active agents.""" """List all active agents."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
manager = AgentManager(db_path, call_api) manager = AgentManager(db_path, call_api)
agents = [] agents = []
for agent_id, agent in manager.active_agents.items(): for agent_id, agent in manager.active_agents.items():
agents.append({ agents.append(
"agent_id": agent_id, {
"role": agent.role.name, "agent_id": agent_id,
"task_count": agent.task_count, "role": agent.role.name,
"message_count": len(agent.message_history) "task_count": agent.task_count,
}) "message_count": len(agent.message_history),
}
)
return {"status": "success", "agents": agents} return {"status": "success", "agents": agents}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
def execute_agent_task(
agent_id: str, task: str, context: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Execute a task with the specified agent.""" """Execute a task with the specified agent."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
manager = AgentManager(db_path, call_api) manager = AgentManager(db_path, call_api)
result = manager.execute_agent_task(agent_id, task, context) result = manager.execute_agent_task(agent_id, task, context)
return result return result
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def remove_agent(agent_id: str) -> Dict[str, Any]: def remove_agent(agent_id: str) -> Dict[str, Any]:
"""Remove an agent.""" """Remove an agent."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
manager = AgentManager(db_path, call_api) manager = AgentManager(db_path, call_api)
success = manager.remove_agent(agent_id) success = manager.remove_agent(agent_id)
return {"status": "success" if success else "not_found", "agent_id": agent_id} return {"status": "success" if success else "not_found", "agent_id": agent_id}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]:
def collaborate_agents(
orchestrator_id: str, task: str, agent_roles: List[str]
) -> Dict[str, Any]:
"""Collaborate multiple agents on a task.""" """Collaborate multiple agents on a task."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
manager = AgentManager(db_path, call_api) manager = AgentManager(db_path, call_api)
result = manager.collaborate_agents(orchestrator_id, task, agent_roles) result = manager.collaborate_agents(orchestrator_id, task, agent_roles)
return result return result

View File

@ -10,12 +10,12 @@ def get_tools_definition():
"properties": { "properties": {
"pid": { "pid": {
"type": "integer", "type": "integer",
"description": "The process ID returned by run_command when status is 'running'." "description": "The process ID returned by run_command when status is 'running'.",
} }
}, },
"required": ["pid"] "required": ["pid"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -27,17 +27,17 @@ def get_tools_definition():
"properties": { "properties": {
"pid": { "pid": {
"type": "integer", "type": "integer",
"description": "The process ID returned by run_command when status is 'running'." "description": "The process ID returned by run_command when status is 'running'.",
}, },
"timeout": { "timeout": {
"type": "integer", "type": "integer",
"description": "Maximum seconds to wait for process completion. Returns partial output if still running.", "description": "Maximum seconds to wait for process completion. Returns partial output if still running.",
"default": 30 "default": 30,
} },
}, },
"required": ["pid"] "required": ["pid"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -48,11 +48,14 @@ def get_tools_definition():
"type": "object", "type": "object",
"properties": { "properties": {
"url": {"type": "string", "description": "The URL to fetch"}, "url": {"type": "string", "description": "The URL to fetch"},
"headers": {"type": "object", "description": "Optional HTTP headers"} "headers": {
"type": "object",
"description": "Optional HTTP headers",
},
}, },
"required": ["url"] "required": ["url"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -62,12 +65,19 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"command": {"type": "string", "description": "The shell command to execute"}, "command": {
"timeout": {"type": "integer", "description": "Maximum seconds to wait for completion", "default": 30} "type": "string",
"description": "The shell command to execute",
},
"timeout": {
"type": "integer",
"description": "Maximum seconds to wait for completion",
"default": 30,
},
}, },
"required": ["command"] "required": ["command"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -77,11 +87,14 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"command": {"type": "string", "description": "The interactive command to execute (e.g., vim, nano, top)"} "command": {
"type": "string",
"description": "The interactive command to execute (e.g., vim, nano, top)",
}
}, },
"required": ["command"] "required": ["command"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -91,12 +104,18 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"session_name": {"type": "string", "description": "The name of the session"}, "session_name": {
"input_data": {"type": "string", "description": "The input to send to the session"} "type": "string",
"description": "The name of the session",
},
"input_data": {
"type": "string",
"description": "The input to send to the session",
},
}, },
"required": ["session_name", "input_data"] "required": ["session_name", "input_data"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -106,11 +125,14 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"session_name": {"type": "string", "description": "The name of the session"} "session_name": {
"type": "string",
"description": "The name of the session",
}
}, },
"required": ["session_name"] "required": ["session_name"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -120,11 +142,14 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"session_name": {"type": "string", "description": "The name of the session"} "session_name": {
"type": "string",
"description": "The name of the session",
}
}, },
"required": ["session_name"] "required": ["session_name"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -134,11 +159,14 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath": {"type": "string", "description": "Path to the file"} "filepath": {
"type": "string",
"description": "Path to the file",
}
}, },
"required": ["filepath"] "required": ["filepath"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -148,12 +176,18 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath": {"type": "string", "description": "Path to the file"}, "filepath": {
"content": {"type": "string", "description": "Content to write"} "type": "string",
"description": "Path to the file",
},
"content": {
"type": "string",
"description": "Content to write",
},
}, },
"required": ["filepath", "content"] "required": ["filepath", "content"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -163,11 +197,19 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"path": {"type": "string", "description": "Directory path", "default": "."}, "path": {
"recursive": {"type": "boolean", "description": "List recursively", "default": False} "type": "string",
} "description": "Directory path",
} "default": ".",
} },
"recursive": {
"type": "boolean",
"description": "List recursively",
"default": False,
},
},
},
},
}, },
{ {
"type": "function", "type": "function",
@ -177,11 +219,14 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"path": {"type": "string", "description": "Path of the directory to create"} "path": {
"type": "string",
"description": "Path of the directory to create",
}
}, },
"required": ["path"] "required": ["path"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -193,17 +238,17 @@ def get_tools_definition():
"properties": { "properties": {
"path": {"type": "string", "description": "Path to change to"} "path": {"type": "string", "description": "Path to change to"}
}, },
"required": ["path"] "required": ["path"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": "getpwd", "name": "getpwd",
"description": "Get the current working directory", "description": "Get the current working directory",
"parameters": {"type": "object", "properties": {}} "parameters": {"type": "object", "properties": {}},
} },
}, },
{ {
"type": "function", "type": "function",
@ -214,11 +259,11 @@ def get_tools_definition():
"type": "object", "type": "object",
"properties": { "properties": {
"key": {"type": "string", "description": "The key"}, "key": {"type": "string", "description": "The key"},
"value": {"type": "string", "description": "The value"} "value": {"type": "string", "description": "The value"},
}, },
"required": ["key", "value"] "required": ["key", "value"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -227,12 +272,10 @@ def get_tools_definition():
"description": "Get a value from the database", "description": "Get a value from the database",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {"key": {"type": "string", "description": "The key"}},
"key": {"type": "string", "description": "The key"} "required": ["key"],
}, },
"required": ["key"] },
}
}
}, },
{ {
"type": "function", "type": "function",
@ -244,9 +287,9 @@ def get_tools_definition():
"properties": { "properties": {
"query": {"type": "string", "description": "SQL query"} "query": {"type": "string", "description": "SQL query"}
}, },
"required": ["query"] "required": ["query"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -258,9 +301,9 @@ def get_tools_definition():
"properties": { "properties": {
"query": {"type": "string", "description": "Search query"} "query": {"type": "string", "description": "Search query"}
}, },
"required": ["query"] "required": ["query"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -270,11 +313,14 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"query": {"type": "string", "description": "Search query for news"} "query": {
"type": "string",
"description": "Search query for news",
}
}, },
"required": ["query"] "required": ["query"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -284,11 +330,14 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"code": {"type": "string", "description": "Python code to execute"} "code": {
"type": "string",
"description": "Python code to execute",
}
}, },
"required": ["code"] "required": ["code"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -300,9 +349,9 @@ def get_tools_definition():
"properties": { "properties": {
"path": {"type": "string", "description": "Path to index"} "path": {"type": "string", "description": "Path to index"}
}, },
"required": ["path"] "required": ["path"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -312,13 +361,22 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath": {"type": "string", "description": "Path to the file"}, "filepath": {
"old_string": {"type": "string", "description": "String to replace"}, "type": "string",
"new_string": {"type": "string", "description": "Replacement string"} "description": "Path to the file",
},
"old_string": {
"type": "string",
"description": "String to replace",
},
"new_string": {
"type": "string",
"description": "Replacement string",
},
}, },
"required": ["filepath", "old_string", "new_string"] "required": ["filepath", "old_string", "new_string"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -328,12 +386,18 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath": {"type": "string", "description": "Path to the file to patch"}, "filepath": {
"patch_content": {"type": "string", "description": "The patch content as a string"} "type": "string",
"description": "Path to the file to patch",
},
"patch_content": {
"type": "string",
"description": "The patch content as a string",
},
}, },
"required": ["filepath", "patch_content"] "required": ["filepath", "patch_content"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -343,14 +407,28 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"file1": {"type": "string", "description": "Path to the first file"}, "file1": {
"file2": {"type": "string", "description": "Path to the second file"}, "type": "string",
"fromfile": {"type": "string", "description": "Label for the first file", "default": "file1"}, "description": "Path to the first file",
"tofile": {"type": "string", "description": "Label for the second file", "default": "file2"} },
"file2": {
"type": "string",
"description": "Path to the second file",
},
"fromfile": {
"type": "string",
"description": "Label for the first file",
"default": "file1",
},
"tofile": {
"type": "string",
"description": "Label for the second file",
"default": "file2",
},
}, },
"required": ["file1", "file2"] "required": ["file1", "file2"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -360,11 +438,14 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath": {"type": "string", "description": "Path to the file"} "filepath": {
"type": "string",
"description": "Path to the file",
}
}, },
"required": ["filepath"] "required": ["filepath"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -374,11 +455,14 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath": {"type": "string", "description": "Path to the file"} "filepath": {
"type": "string",
"description": "Path to the file",
}
}, },
"required": ["filepath"] "required": ["filepath"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -388,14 +472,23 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath": {"type": "string", "description": "Path to the file"}, "filepath": {
"type": "string",
"description": "Path to the file",
},
"text": {"type": "string", "description": "Text to insert"}, "text": {"type": "string", "description": "Text to insert"},
"line": {"type": "integer", "description": "Line number (optional)"}, "line": {
"col": {"type": "integer", "description": "Column number (optional)"} "type": "integer",
"description": "Line number (optional)",
},
"col": {
"type": "integer",
"description": "Column number (optional)",
},
}, },
"required": ["filepath", "text"] "required": ["filepath", "text"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -405,16 +498,26 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath": {"type": "string", "description": "Path to the file"}, "filepath": {
"type": "string",
"description": "Path to the file",
},
"start_line": {"type": "integer", "description": "Start line"}, "start_line": {"type": "integer", "description": "Start line"},
"start_col": {"type": "integer", "description": "Start column"}, "start_col": {"type": "integer", "description": "Start column"},
"end_line": {"type": "integer", "description": "End line"}, "end_line": {"type": "integer", "description": "End line"},
"end_col": {"type": "integer", "description": "End column"}, "end_col": {"type": "integer", "description": "End column"},
"new_text": {"type": "string", "description": "New text"} "new_text": {"type": "string", "description": "New text"},
}, },
"required": ["filepath", "start_line", "start_col", "end_line", "end_col", "new_text"] "required": [
} "filepath",
} "start_line",
"start_col",
"end_line",
"end_col",
"new_text",
],
},
},
}, },
{ {
"type": "function", "type": "function",
@ -424,13 +527,20 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath": {"type": "string", "description": "Path to the file"}, "filepath": {
"type": "string",
"description": "Path to the file",
},
"pattern": {"type": "string", "description": "Regex pattern"}, "pattern": {"type": "string", "description": "Regex pattern"},
"start_line": {"type": "integer", "description": "Start line", "default": 0} "start_line": {
"type": "integer",
"description": "Start line",
"default": 0,
},
}, },
"required": ["filepath", "pattern"] "required": ["filepath", "pattern"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
@ -440,24 +550,31 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"filepath1": {"type": "string", "description": "Path to the original file"}, "filepath1": {
"filepath2": {"type": "string", "description": "Path to the modified file"}, "type": "string",
"format_type": {"type": "string", "description": "Display format: 'unified' or 'side-by-side'", "default": "unified"} "description": "Path to the original file",
},
"filepath2": {
"type": "string",
"description": "Path to the modified file",
},
"format_type": {
"type": "string",
"description": "Display format: 'unified' or 'side-by-side'",
"default": "unified",
},
}, },
"required": ["filepath1", "filepath2"] "required": ["filepath1", "filepath2"],
} },
} },
}, },
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": "display_edit_summary", "name": "display_edit_summary",
"description": "Display a summary of all edit operations performed during the session", "description": "Display a summary of all edit operations performed during the session",
"parameters": { "parameters": {"type": "object", "properties": {}},
"type": "object", },
"properties": {}
}
}
}, },
{ {
"type": "function", "type": "function",
@ -467,21 +584,21 @@ def get_tools_definition():
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"show_content": {"type": "boolean", "description": "Show content previews", "default": False} "show_content": {
} "type": "boolean",
} "description": "Show content previews",
} "default": False,
}
},
},
},
}, },
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": "clear_edit_tracker", "name": "clear_edit_tracker",
"description": "Clear the edit tracker to start fresh", "description": "Clear the edit tracker to start fresh",
"parameters": { "parameters": {"type": "object", "properties": {}},
"type": "object", },
"properties": {} },
}
}
}
] ]

View File

@ -1,21 +1,23 @@
import os import os
import select
import subprocess import subprocess
import time import time
import select
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
from pr.tools.interactive_control import start_interactive_session
from pr.config import MAX_CONCURRENT_SESSIONS
_processes = {} _processes = {}
def _register_process(pid:int, process):
def _register_process(pid: int, process):
_processes[pid] = process _processes[pid] = process
return _processes return _processes
def _get_process(pid:int):
def _get_process(pid: int):
return _processes.get(pid) return _processes.get(pid)
def kill_process(pid:int):
def kill_process(pid: int):
try: try:
process = _get_process(pid) process = _get_process(pid)
if process: if process:
@ -67,7 +69,7 @@ def tail_process(pid: int, timeout: int = 30):
"status": "success", "status": "success",
"stdout": stdout_content, "stdout": stdout_content,
"stderr": stderr_content, "stderr": stderr_content,
"returncode": process.returncode "returncode": process.returncode,
} }
if time.time() - start_time > timeout_duration: if time.time() - start_time > timeout_duration:
@ -76,10 +78,12 @@ def tail_process(pid: int, timeout: int = 30):
"message": "Process is still running. Call tail_process again to continue monitoring.", "message": "Process is still running. Call tail_process again to continue monitoring.",
"stdout_so_far": stdout_content, "stdout_so_far": stdout_content,
"stderr_so_far": stderr_content, "stderr_so_far": stderr_content,
"pid": pid "pid": pid,
} }
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1) ready, _, _ = select.select(
[process.stdout, process.stderr], [], [], 0.1
)
for pipe in ready: for pipe in ready:
if pipe == process.stdout: if pipe == process.stdout:
line = process.stdout.readline() line = process.stdout.readline()
@ -100,7 +104,13 @@ def tail_process(pid: int, timeout: int = 30):
def run_command(command, timeout=30, monitored=False): def run_command(command, timeout=30, monitored=False):
mux_name = None mux_name = None
try: try:
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) process = subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
_register_process(process.pid, process) _register_process(process.pid, process)
mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True) mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True)
@ -129,7 +139,7 @@ def run_command(command, timeout=30, monitored=False):
"status": "success", "status": "success",
"stdout": stdout_content, "stdout": stdout_content,
"stderr": stderr_content, "stderr": stderr_content,
"returncode": process.returncode "returncode": process.returncode,
} }
if time.time() - start_time > timeout_duration: if time.time() - start_time > timeout_duration:
@ -139,7 +149,7 @@ def run_command(command, timeout=30, monitored=False):
"stdout_so_far": stdout_content, "stdout_so_far": stdout_content,
"stderr_so_far": stderr_content, "stderr_so_far": stderr_content,
"pid": process.pid, "pid": process.pid,
"mux_name": mux_name "mux_name": mux_name,
} }
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1) ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
@ -158,6 +168,8 @@ def run_command(command, timeout=30, monitored=False):
if mux_name: if mux_name:
close_multiplexer(mux_name) close_multiplexer(mux_name)
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def run_command_interactive(command): def run_command_interactive(command):
try: try:
return_code = os.system(command) return_code = os.system(command)

View File

@ -1,18 +1,23 @@
import time import time
def db_set(key, value, db_conn): def db_set(key, value, db_conn):
if not db_conn: if not db_conn:
return {"status": "error", "error": "Database not initialized"} return {"status": "error", "error": "Database not initialized"}
try: try:
cursor = db_conn.cursor() cursor = db_conn.cursor()
cursor.execute("""INSERT OR REPLACE INTO kv_store (key, value, timestamp) cursor.execute(
VALUES (?, ?, ?)""", (key, value, time.time())) """INSERT OR REPLACE INTO kv_store (key, value, timestamp)
VALUES (?, ?, ?)""",
(key, value, time.time()),
)
db_conn.commit() db_conn.commit()
return {"status": "success", "message": f"Set {key}"} return {"status": "success", "message": f"Set {key}"}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def db_get(key, db_conn): def db_get(key, db_conn):
if not db_conn: if not db_conn:
return {"status": "error", "error": "Database not initialized"} return {"status": "error", "error": "Database not initialized"}
@ -28,6 +33,7 @@ def db_get(key, db_conn):
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def db_query(query, db_conn): def db_query(query, db_conn):
if not db_conn: if not db_conn:
return {"status": "error", "error": "Database not initialized"} return {"status": "error", "error": "Database not initialized"}
@ -36,9 +42,11 @@ def db_query(query, db_conn):
cursor = db_conn.cursor() cursor = db_conn.cursor()
cursor.execute(query) cursor.execute(query)
if query.strip().upper().startswith('SELECT'): if query.strip().upper().startswith("SELECT"):
results = cursor.fetchall() results = cursor.fetchall()
columns = [desc[0] for desc in cursor.description] if cursor.description else [] columns = (
[desc[0] for desc in cursor.description] if cursor.description else []
)
return {"status": "success", "columns": columns, "rows": results} return {"status": "success", "columns": columns, "rows": results}
else: else:
db_conn.commit() db_conn.commit()

View File

@ -1,18 +1,21 @@
from pr.editor import RPEditor
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer
from ..ui.diff_display import display_diff, get_diff_stats
from ..ui.edit_feedback import track_edit, tracker
from ..tools.patch import display_content_diff
import os import os
import os.path import os.path
from pr.editor import RPEditor
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
from ..tools.patch import display_content_diff
from ..ui.edit_feedback import track_edit, tracker
_editors = {} _editors = {}
def get_editor(filepath): def get_editor(filepath):
if filepath not in _editors: if filepath not in _editors:
_editors[filepath] = RPEditor(filepath) _editors[filepath] = RPEditor(filepath)
return _editors[filepath] return _editors[filepath]
def close_editor(filepath): def close_editor(filepath):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
@ -29,6 +32,7 @@ def close_editor(filepath):
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def open_editor(filepath): def open_editor(filepath):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
@ -39,21 +43,28 @@ def open_editor(filepath):
mux_name, mux = create_multiplexer(mux_name, show_output=True) mux_name, mux = create_multiplexer(mux_name, show_output=True)
mux.write_stdout(f"Opened editor for: {path}\n") mux.write_stdout(f"Opened editor for: {path}\n")
return {"status": "success", "message": f"Editor opened for {path}", "mux_name": mux_name} return {
"status": "success",
"message": f"Editor opened for {path}",
"mux_name": mux_name,
}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True): def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
old_content = "" old_content = ""
if os.path.exists(path): if os.path.exists(path):
with open(path, 'r') as f: with open(path) as f:
old_content = f.read() old_content = f.read()
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0) position = (line if line is not None else 0) * 1000 + (
operation = track_edit('INSERT', filepath, start_pos=position, content=text) col if col is not None else 0
)
operation = track_edit("INSERT", filepath, start_pos=position, content=text)
tracker.mark_in_progress(operation) tracker.mark_in_progress(operation)
editor = get_editor(path) editor = get_editor(path)
@ -65,12 +76,16 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
mux_name = f"editor-{path}" mux_name = f"editor-{path}"
mux = get_multiplexer(mux_name) mux = get_multiplexer(mux_name)
if mux: if mux:
location = f" at line {line}, col {col}" if line is not None and col is not None else "" location = (
f" at line {line}, col {col}"
if line is not None and col is not None
else ""
)
preview = text[:50] + "..." if len(text) > 50 else text preview = text[:50] + "..." if len(text) > 50 else text
mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n") mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n")
if show_diff and old_content: if show_diff and old_content:
with open(path, 'r') as f: with open(path) as f:
new_content = f.read() new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath) diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success": if diff_result["status"] == "success":
@ -81,23 +96,32 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
close_editor(filepath) close_editor(filepath)
return result return result
except Exception as e: except Exception as e:
if 'operation' in locals(): if "operation" in locals():
tracker.mark_failed(operation) tracker.mark_failed(operation)
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True):
def editor_replace_text(
filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True
):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
old_content = "" old_content = ""
if os.path.exists(path): if os.path.exists(path):
with open(path, 'r') as f: with open(path) as f:
old_content = f.read() old_content = f.read()
start_pos = start_line * 1000 + start_col start_pos = start_line * 1000 + start_col
end_pos = end_line * 1000 + end_col end_pos = end_line * 1000 + end_col
operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos, operation = track_edit(
content=new_text, old_content=old_content) "REPLACE",
filepath,
start_pos=start_pos,
end_pos=end_pos,
content=new_text,
old_content=old_content,
)
tracker.mark_in_progress(operation) tracker.mark_in_progress(operation)
editor = get_editor(path) editor = get_editor(path)
@ -108,10 +132,12 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
mux = get_multiplexer(mux_name) mux = get_multiplexer(mux_name)
if mux: if mux:
preview = new_text[:50] + "..." if len(new_text) > 50 else new_text preview = new_text[:50] + "..." if len(new_text) > 50 else new_text
mux.write_stdout(f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n") mux.write_stdout(
f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n"
)
if show_diff and old_content: if show_diff and old_content:
with open(path, 'r') as f: with open(path) as f:
new_content = f.read() new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath) diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success": if diff_result["status"] == "success":
@ -122,10 +148,11 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
close_editor(filepath) close_editor(filepath)
return result return result
except Exception as e: except Exception as e:
if 'operation' in locals(): if "operation" in locals():
tracker.mark_failed(operation) tracker.mark_failed(operation)
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def editor_search(filepath, pattern, start_line=0): def editor_search(filepath, pattern, start_line=0):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
@ -135,7 +162,9 @@ def editor_search(filepath, pattern, start_line=0):
mux_name = f"editor-{path}" mux_name = f"editor-{path}"
mux = get_multiplexer(mux_name) mux = get_multiplexer(mux_name)
if mux: if mux:
mux.write_stdout(f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n") mux.write_stdout(
f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n"
)
result = {"status": "success", "results": results} result = {"status": "success", "results": results}
close_editor(filepath) close_editor(filepath)

View File

@ -1,31 +1,36 @@
import os
import hashlib import hashlib
import os
import time import time
from typing import Dict
from pr.editor import RPEditor from pr.editor import RPEditor
from ..ui.diff_display import display_diff, get_diff_stats
from ..ui.edit_feedback import track_edit, tracker
from ..tools.patch import display_content_diff from ..tools.patch import display_content_diff
from ..ui.diff_display import get_diff_stats
from ..ui.edit_feedback import track_edit, tracker
_id = 0 _id = 0
def get_uid(): def get_uid():
global _id global _id
_id += 3 _id += 3
return _id return _id
def read_file(filepath, db_conn=None): def read_file(filepath, db_conn=None):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
with open(path, 'r') as f: with open(path) as f:
content = f.read() content = f.read()
if db_conn: if db_conn:
from pr.tools.database import db_set from pr.tools.database import db_set
db_set("read:" + path, "true", db_conn) db_set("read:" + path, "true", db_conn)
return {"status": "success", "content": content} return {"status": "success", "content": content}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def write_file(filepath, content, db_conn=None, show_diff=True): def write_file(filepath, content, db_conn=None, show_diff=True):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
@ -34,15 +39,24 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
if not is_new_file and db_conn: if not is_new_file and db_conn:
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true": if (
return {"status": "error", "error": "File must be read before writing. Please read the file first."} read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
if not is_new_file: if not is_new_file:
with open(path, 'r') as f: with open(path) as f:
old_content = f.read() old_content = f.read()
operation = track_edit('WRITE', filepath, content=content, old_content=old_content) operation = track_edit(
"WRITE", filepath, content=content, old_content=old_content
)
tracker.mark_in_progress(operation) tracker.mark_in_progress(operation)
if show_diff and not is_new_file: if show_diff and not is_new_file:
@ -59,13 +73,18 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
cursor = db_conn.cursor() cursor = db_conn.cursor()
file_hash = hashlib.md5(old_content.encode()).hexdigest() file_hash = hashlib.md5(old_content.encode()).hexdigest()
cursor.execute("SELECT MAX(version) FROM file_versions WHERE filepath = ?", (filepath,)) cursor.execute(
"SELECT MAX(version) FROM file_versions WHERE filepath = ?",
(filepath,),
)
result = cursor.fetchone() result = cursor.fetchone()
version = (result[0] + 1) if result[0] else 1 version = (result[0] + 1) if result[0] else 1
cursor.execute("""INSERT INTO file_versions (filepath, content, hash, timestamp, version) cursor.execute(
"""INSERT INTO file_versions (filepath, content, hash, timestamp, version)
VALUES (?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?)""",
(filepath, old_content, file_hash, time.time(), version)) (filepath, old_content, file_hash, time.time(), version),
)
db_conn.commit() db_conn.commit()
except Exception: except Exception:
pass pass
@ -79,10 +98,11 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
return {"status": "success", "message": message} return {"status": "success", "message": message}
except Exception as e: except Exception as e:
if 'operation' in locals(): if "operation" in locals():
tracker.mark_failed(operation) tracker.mark_failed(operation)
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def list_directory(path=".", recursive=False): def list_directory(path=".", recursive=False):
try: try:
path = os.path.expanduser(path) path = os.path.expanduser(path)
@ -91,21 +111,36 @@ def list_directory(path=".", recursive=False):
for root, dirs, files in os.walk(path): for root, dirs, files in os.walk(path):
for name in files: for name in files:
item_path = os.path.join(root, name) item_path = os.path.join(root, name)
items.append({"path": item_path, "type": "file", "size": os.path.getsize(item_path)}) items.append(
{
"path": item_path,
"type": "file",
"size": os.path.getsize(item_path),
}
)
for name in dirs: for name in dirs:
items.append({"path": os.path.join(root, name), "type": "directory"}) items.append(
{"path": os.path.join(root, name), "type": "directory"}
)
else: else:
for item in os.listdir(path): for item in os.listdir(path):
item_path = os.path.join(path, item) item_path = os.path.join(path, item)
items.append({ items.append(
"name": item, {
"type": "directory" if os.path.isdir(item_path) else "file", "name": item,
"size": os.path.getsize(item_path) if os.path.isfile(item_path) else None "type": "directory" if os.path.isdir(item_path) else "file",
}) "size": (
os.path.getsize(item_path)
if os.path.isfile(item_path)
else None
),
}
)
return {"status": "success", "items": items} return {"status": "success", "items": items}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def mkdir(path): def mkdir(path):
try: try:
os.makedirs(os.path.expanduser(path), exist_ok=True) os.makedirs(os.path.expanduser(path), exist_ok=True)
@ -113,6 +148,7 @@ def mkdir(path):
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def chdir(path): def chdir(path):
try: try:
os.chdir(os.path.expanduser(path)) os.chdir(os.path.expanduser(path))
@ -120,16 +156,32 @@ def chdir(path):
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def getpwd(): def getpwd():
try: try:
return {"status": "success", "path": os.getcwd()} return {"status": "success", "path": os.getcwd()}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def index_source_directory(path): def index_source_directory(path):
extensions = [ extensions = [
".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp", ".py",
".html", ".css", ".json", ".xml", ".md", ".sh", ".rb", ".go" ".js",
".ts",
".java",
".cpp",
".c",
".h",
".hpp",
".html",
".css",
".json",
".xml",
".md",
".sh",
".rb",
".go",
] ]
source_files = [] source_files = []
try: try:
@ -138,18 +190,16 @@ def index_source_directory(path):
if any(file.endswith(ext) for ext in extensions): if any(file.endswith(ext) for ext in extensions):
filepath = os.path.join(root, file) filepath = os.path.join(root, file)
try: try:
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, encoding="utf-8") as f:
content = f.read() content = f.read()
source_files.append({ source_files.append({"path": filepath, "content": content})
"path": filepath,
"content": content
})
except Exception: except Exception:
continue continue
return {"status": "success", "indexed_files": source_files} return {"status": "success", "indexed_files": source_files}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def search_replace(filepath, old_string, new_string, db_conn=None): def search_replace(filepath, old_string, new_string, db_conn=None):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
@ -157,25 +207,38 @@ def search_replace(filepath, old_string, new_string, db_conn=None):
return {"status": "error", "error": "File does not exist"} return {"status": "error", "error": "File does not exist"}
if db_conn: if db_conn:
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true": if (
return {"status": "error", "error": "File must be read before writing. Please read the file first."} read_status.get("status") != "success"
with open(path, 'r') as f: or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
with open(path) as f:
content = f.read() content = f.read()
content = content.replace(old_string, new_string) content = content.replace(old_string, new_string)
with open(path, 'w') as f: with open(path, "w") as f:
f.write(content) f.write(content)
return {"status": "success", "message": f"Replaced '{old_string}' with '{new_string}' in {path}"} return {
"status": "success",
"message": f"Replaced '{old_string}' with '{new_string}' in {path}",
}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
_editors = {} _editors = {}
def get_editor(filepath): def get_editor(filepath):
if filepath not in _editors: if filepath not in _editors:
_editors[filepath] = RPEditor(filepath) _editors[filepath] = RPEditor(filepath)
return _editors[filepath] return _editors[filepath]
def close_editor(filepath): def close_editor(filepath):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
@ -185,6 +248,7 @@ def close_editor(filepath):
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def open_editor(filepath): def open_editor(filepath):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
@ -194,22 +258,34 @@ def open_editor(filepath):
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None):
def editor_insert_text(
filepath, text, line=None, col=None, show_diff=True, db_conn=None
):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
if db_conn: if db_conn:
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true": if (
return {"status": "error", "error": "File must be read before writing. Please read the file first."} read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
old_content = "" old_content = ""
if os.path.exists(path): if os.path.exists(path):
with open(path, 'r') as f: with open(path) as f:
old_content = f.read() old_content = f.read()
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0) position = (line if line is not None else 0) * 1000 + (
operation = track_edit('INSERT', filepath, start_pos=position, content=text) col if col is not None else 0
)
operation = track_edit("INSERT", filepath, start_pos=position, content=text)
tracker.mark_in_progress(operation) tracker.mark_in_progress(operation)
editor = get_editor(path) editor = get_editor(path)
@ -219,7 +295,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c
editor.save_file() editor.save_file()
if show_diff and old_content: if show_diff and old_content:
with open(path, 'r') as f: with open(path) as f:
new_content = f.read() new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath) diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success": if diff_result["status"] == "success":
@ -228,28 +304,51 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c
tracker.mark_completed(operation) tracker.mark_completed(operation)
return {"status": "success", "message": f"Inserted text in {path}"} return {"status": "success", "message": f"Inserted text in {path}"}
except Exception as e: except Exception as e:
if 'operation' in locals(): if "operation" in locals():
tracker.mark_failed(operation) tracker.mark_failed(operation)
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True, db_conn=None):
def editor_replace_text(
filepath,
start_line,
start_col,
end_line,
end_col,
new_text,
show_diff=True,
db_conn=None,
):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
if db_conn: if db_conn:
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true": if (
return {"status": "error", "error": "File must be read before writing. Please read the file first."} read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
old_content = "" old_content = ""
if os.path.exists(path): if os.path.exists(path):
with open(path, 'r') as f: with open(path) as f:
old_content = f.read() old_content = f.read()
start_pos = start_line * 1000 + start_col start_pos = start_line * 1000 + start_col
end_pos = end_line * 1000 + end_col end_pos = end_line * 1000 + end_col
operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos, operation = track_edit(
content=new_text, old_content=old_content) "REPLACE",
filepath,
start_pos=start_pos,
end_pos=end_pos,
content=new_text,
old_content=old_content,
)
tracker.mark_in_progress(operation) tracker.mark_in_progress(operation)
editor = get_editor(path) editor = get_editor(path)
@ -257,7 +356,7 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
editor.save_file() editor.save_file()
if show_diff and old_content: if show_diff and old_content:
with open(path, 'r') as f: with open(path) as f:
new_content = f.read() new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath) diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success": if diff_result["status"] == "success":
@ -266,22 +365,25 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
tracker.mark_completed(operation) tracker.mark_completed(operation)
return {"status": "success", "message": f"Replaced text in {path}"} return {"status": "success", "message": f"Replaced text in {path}"}
except Exception as e: except Exception as e:
if 'operation' in locals(): if "operation" in locals():
tracker.mark_failed(operation) tracker.mark_failed(operation)
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def display_edit_summary(): def display_edit_summary():
from ..ui.edit_feedback import display_edit_summary from ..ui.edit_feedback import display_edit_summary
return display_edit_summary() return display_edit_summary()
def display_edit_timeline(show_content=False): def display_edit_timeline(show_content=False):
from ..ui.edit_feedback import display_edit_timeline from ..ui.edit_feedback import display_edit_timeline
return display_edit_timeline(show_content) return display_edit_timeline(show_content)
def clear_edit_tracker(): def clear_edit_tracker():
from ..ui.edit_feedback import clear_tracker from ..ui.edit_feedback import clear_tracker
clear_tracker() clear_tracker()
return {"status": "success", "message": "Edit tracker cleared"} return {"status": "success", "message": "Edit tracker cleared"}

View File

@ -1,9 +1,15 @@
import subprocess import subprocess
import threading import threading
import time
from pr.multiplexer import create_multiplexer, get_multiplexer, close_multiplexer, get_all_multiplexer_states
def start_interactive_session(command, session_name=None, process_type='generic'): from pr.multiplexer import (
close_multiplexer,
create_multiplexer,
get_all_multiplexer_states,
get_multiplexer,
)
def start_interactive_session(command, session_name=None, process_type="generic"):
""" """
Start an interactive session in a dedicated multiplexer. Start an interactive session in a dedicated multiplexer.
@ -16,7 +22,7 @@ def start_interactive_session(command, session_name=None, process_type='generic'
session_name: The name of the created session session_name: The name of the created session
""" """
name, mux = create_multiplexer(session_name) name, mux = create_multiplexer(session_name)
mux.update_metadata('process_type', process_type) mux.update_metadata("process_type", process_type)
# Start the process # Start the process
if isinstance(command, str): if isinstance(command, str):
@ -29,19 +35,23 @@ def start_interactive_session(command, session_name=None, process_type='generic'
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
bufsize=1 bufsize=1,
) )
mux.process = process mux.process = process
mux.update_metadata('pid', process.pid) mux.update_metadata("pid", process.pid)
# Set process type and handler # Set process type and handler
detected_type = detect_process_type(command) detected_type = detect_process_type(command)
mux.set_process_type(detected_type) mux.set_process_type(detected_type)
# Start output readers # Start output readers
stdout_thread = threading.Thread(target=_read_output, args=(process.stdout, mux.write_stdout), daemon=True) stdout_thread = threading.Thread(
stderr_thread = threading.Thread(target=_read_output, args=(process.stderr, mux.write_stderr), daemon=True) target=_read_output, args=(process.stdout, mux.write_stdout), daemon=True
)
stderr_thread = threading.Thread(
target=_read_output, args=(process.stderr, mux.write_stderr), daemon=True
)
stdout_thread.start() stdout_thread.start()
stderr_thread.start() stderr_thread.start()
@ -54,15 +64,17 @@ def start_interactive_session(command, session_name=None, process_type='generic'
close_multiplexer(name) close_multiplexer(name)
raise e raise e
def _read_output(stream, write_func): def _read_output(stream, write_func):
"""Read from a stream and write to multiplexer buffer.""" """Read from a stream and write to multiplexer buffer."""
try: try:
for line in iter(stream.readline, ''): for line in iter(stream.readline, ""):
if line: if line:
write_func(line.rstrip('\n')) write_func(line.rstrip("\n"))
except Exception as e: except Exception as e:
print(f"Error reading output: {e}") print(f"Error reading output: {e}")
def send_input_to_session(session_name, input_data): def send_input_to_session(session_name, input_data):
""" """
Send input to an interactive session. Send input to an interactive session.
@ -75,15 +87,16 @@ def send_input_to_session(session_name, input_data):
if not mux: if not mux:
raise ValueError(f"Session {session_name} not found") raise ValueError(f"Session {session_name} not found")
if not hasattr(mux, 'process') or mux.process.poll() is not None: if not hasattr(mux, "process") or mux.process.poll() is not None:
raise ValueError(f"Session {session_name} is not active") raise ValueError(f"Session {session_name} is not active")
try: try:
mux.process.stdin.write(input_data + '\n') mux.process.stdin.write(input_data + "\n")
mux.process.stdin.flush() mux.process.stdin.flush()
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to send input to session {session_name}: {e}") raise RuntimeError(f"Failed to send input to session {session_name}: {e}")
def read_session_output(session_name, lines=None): def read_session_output(session_name, lines=None):
""" """
Read output from a session. Read output from a session.
@ -102,14 +115,12 @@ def read_session_output(session_name, lines=None):
output = mux.get_all_output() output = mux.get_all_output()
if lines is not None: if lines is not None:
# Return last N lines # Return last N lines
stdout_lines = output['stdout'].split('\n')[-lines:] if output['stdout'] else [] stdout_lines = output["stdout"].split("\n")[-lines:] if output["stdout"] else []
stderr_lines = output['stderr'].split('\n')[-lines:] if output['stderr'] else [] stderr_lines = output["stderr"].split("\n")[-lines:] if output["stderr"] else []
output = { output = {"stdout": "\n".join(stdout_lines), "stderr": "\n".join(stderr_lines)}
'stdout': '\n'.join(stdout_lines),
'stderr': '\n'.join(stderr_lines)
}
return output return output
def list_active_sessions(): def list_active_sessions():
""" """
List all active interactive sessions. List all active interactive sessions.
@ -119,6 +130,7 @@ def list_active_sessions():
""" """
return get_all_multiplexer_states() return get_all_multiplexer_states()
def get_session_status(session_name): def get_session_status(session_name):
""" """
Get detailed status of a session. Get detailed status of a session.
@ -134,15 +146,16 @@ def get_session_status(session_name):
return None return None
status = mux.get_metadata() status = mux.get_metadata()
status['is_active'] = hasattr(mux, 'process') and mux.process.poll() is None status["is_active"] = hasattr(mux, "process") and mux.process.poll() is None
if status['is_active']: if status["is_active"]:
status['pid'] = mux.process.pid status["pid"] = mux.process.pid
status['output_summary'] = { status["output_summary"] = {
'stdout_lines': len(mux.stdout_buffer), "stdout_lines": len(mux.stdout_buffer),
'stderr_lines': len(mux.stderr_buffer) "stderr_lines": len(mux.stderr_buffer),
} }
return status return status
def close_interactive_session(session_name): def close_interactive_session(session_name):
""" """
Close an interactive session. Close an interactive session.

View File

@ -1,13 +1,17 @@
import os import os
from typing import Dict, Any, List
from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry
import time import time
import uuid import uuid
from typing import Any, Dict
def add_knowledge_entry(category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None) -> Dict[str, Any]: from pr.memory.knowledge_store import KnowledgeEntry, KnowledgeStore
def add_knowledge_entry(
category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None
) -> Dict[str, Any]:
"""Add a new entry to the knowledge base.""" """Add a new entry to the knowledge base."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path) store = KnowledgeStore(db_path)
if entry_id is None: if entry_id is None:
@ -19,7 +23,7 @@ def add_knowledge_entry(category: str, content: str, metadata: Dict[str, Any] =
content=content, content=content,
metadata=metadata or {}, metadata=metadata or {},
created_at=time.time(), created_at=time.time(),
updated_at=time.time() updated_at=time.time(),
) )
store.add_entry(entry) store.add_entry(entry)
@ -27,10 +31,11 @@ def add_knowledge_entry(category: str, content: str, metadata: Dict[str, Any] =
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def get_knowledge_entry(entry_id: str) -> Dict[str, Any]: def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
"""Retrieve a knowledge entry by ID.""" """Retrieve a knowledge entry by ID."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path) store = KnowledgeStore(db_path)
entry = store.get_entry(entry_id) entry = store.get_entry(entry_id)
@ -41,10 +46,13 @@ def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]:
def search_knowledge(
query: str, category: str = None, top_k: int = 5
) -> Dict[str, Any]:
"""Search the knowledge base semantically.""" """Search the knowledge base semantically."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path) store = KnowledgeStore(db_path)
entries = store.search_entries(query, category, top_k) entries = store.search_entries(query, category, top_k)
@ -53,10 +61,11 @@ def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[s
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]: def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]:
"""Get knowledge entries by category.""" """Get knowledge entries by category."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path) store = KnowledgeStore(db_path)
entries = store.get_by_category(category, limit) entries = store.get_by_category(category, limit)
@ -65,21 +74,29 @@ def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]:
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]:
def update_knowledge_importance(
entry_id: str, importance_score: float
) -> Dict[str, Any]:
"""Update the importance score of a knowledge entry.""" """Update the importance score of a knowledge entry."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path) store = KnowledgeStore(db_path)
store.update_importance(entry_id, importance_score) store.update_importance(entry_id, importance_score)
return {"status": "success", "entry_id": entry_id, "importance_score": importance_score} return {
"status": "success",
"entry_id": entry_id,
"importance_score": importance_score,
}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def delete_knowledge_entry(entry_id: str) -> Dict[str, Any]: def delete_knowledge_entry(entry_id: str) -> Dict[str, Any]:
"""Delete a knowledge entry.""" """Delete a knowledge entry."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path) store = KnowledgeStore(db_path)
success = store.delete_entry(entry_id) success = store.delete_entry(entry_id)
@ -87,10 +104,11 @@ def delete_knowledge_entry(entry_id: str) -> Dict[str, Any]:
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def get_knowledge_statistics() -> Dict[str, Any]: def get_knowledge_statistics() -> Dict[str, Any]:
"""Get statistics about the knowledge base.""" """Get statistics about the knowledge base."""
try: try:
db_path = os.path.expanduser('~/.assistant_db.sqlite') db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path) store = KnowledgeStore(db_path)
stats = store.get_statistics() stats = store.get_statistics()

View File

@ -1,23 +1,37 @@
import os
import tempfile
import subprocess
import difflib import difflib
from ..ui.diff_display import display_diff, get_diff_stats, DiffDisplay import os
import subprocess
import tempfile
from ..ui.diff_display import display_diff, get_diff_stats
def apply_patch(filepath, patch_content, db_conn=None): def apply_patch(filepath, patch_content, db_conn=None):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
if db_conn: if db_conn:
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true": if (
return {"status": "error", "error": "File must be read before writing. Please read the file first."} read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
# Write patch to temp file # Write patch to temp file
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.patch') as f: with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".patch") as f:
f.write(patch_content) f.write(patch_content)
patch_file = f.name patch_file = f.name
# Run patch command # Run patch command
result = subprocess.run(['patch', path, patch_file], capture_output=True, text=True, cwd=os.path.dirname(path)) result = subprocess.run(
["patch", path, patch_file],
capture_output=True,
text=True,
cwd=os.path.dirname(path),
)
os.unlink(patch_file) os.unlink(patch_file)
if result.returncode == 0: if result.returncode == 0:
return {"status": "success", "output": result.stdout.strip()} return {"status": "success", "output": result.stdout.strip()}
@ -26,11 +40,14 @@ def apply_patch(filepath, patch_content, db_conn=None):
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, format_type='unified'):
def create_diff(
file1, file2, fromfile="file1", tofile="file2", visual=False, format_type="unified"
):
try: try:
path1 = os.path.expanduser(file1) path1 = os.path.expanduser(file1)
path2 = os.path.expanduser(file2) path2 = os.path.expanduser(file2)
with open(path1, 'r') as f1, open(path2, 'r') as f2: with open(path1) as f1, open(path2) as f2:
content1 = f1.read() content1 = f1.read()
content2 = f2.read() content2 = f2.read()
@ -39,53 +56,51 @@ def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, fo
stats = get_diff_stats(content1, content2) stats = get_diff_stats(content1, content2)
lines1 = content1.splitlines(keepends=True) lines1 = content1.splitlines(keepends=True)
lines2 = content2.splitlines(keepends=True) lines2 = content2.splitlines(keepends=True)
plain_diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)) plain_diff = list(
difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)
)
return { return {
"status": "success", "status": "success",
"diff": ''.join(plain_diff), "diff": "".join(plain_diff),
"visual_diff": visual_diff, "visual_diff": visual_diff,
"stats": stats "stats": stats,
} }
else: else:
lines1 = content1.splitlines(keepends=True) lines1 = content1.splitlines(keepends=True)
lines2 = content2.splitlines(keepends=True) lines2 = content2.splitlines(keepends=True)
diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)) diff = list(
return {"status": "success", "diff": ''.join(diff)} difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)
)
return {"status": "success", "diff": "".join(diff)}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def display_file_diff(filepath1, filepath2, format_type='unified', context_lines=3): def display_file_diff(filepath1, filepath2, format_type="unified", context_lines=3):
try: try:
path1 = os.path.expanduser(filepath1) path1 = os.path.expanduser(filepath1)
path2 = os.path.expanduser(filepath2) path2 = os.path.expanduser(filepath2)
with open(path1, 'r') as f1: with open(path1) as f1:
old_content = f1.read() old_content = f1.read()
with open(path2, 'r') as f2: with open(path2) as f2:
new_content = f2.read() new_content = f2.read()
visual_diff = display_diff(old_content, new_content, filepath1, format_type) visual_diff = display_diff(old_content, new_content, filepath1, format_type)
stats = get_diff_stats(old_content, new_content) stats = get_diff_stats(old_content, new_content)
return { return {"status": "success", "visual_diff": visual_diff, "stats": stats}
"status": "success",
"visual_diff": visual_diff,
"stats": stats
}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def display_content_diff(old_content, new_content, filename='file', format_type='unified'): def display_content_diff(
old_content, new_content, filename="file", format_type="unified"
):
try: try:
visual_diff = display_diff(old_content, new_content, filename, format_type) visual_diff = display_diff(old_content, new_content, filename, format_type)
stats = get_diff_stats(old_content, new_content) stats = get_diff_stats(old_content, new_content)
return { return {"status": "success", "visual_diff": visual_diff, "stats": stats}
"status": "success",
"visual_diff": visual_diff,
"stats": stats
}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}

View File

@ -1,14 +1,13 @@
import re
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
class ProcessHandler(ABC): class ProcessHandler(ABC):
"""Base class for process-specific handlers.""" """Base class for process-specific handlers."""
def __init__(self, multiplexer): def __init__(self, multiplexer):
self.multiplexer = multiplexer self.multiplexer = multiplexer
self.state_machine = {} self.state_machine = {}
self.current_state = 'initial' self.current_state = "initial"
self.prompt_patterns = [] self.prompt_patterns = []
self.response_suggestions = {} self.response_suggestions = {}
@ -27,7 +26,8 @@ class ProcessHandler(ABC):
def is_waiting_for_input(self): def is_waiting_for_input(self):
"""Check if process appears to be waiting for input.""" """Check if process appears to be waiting for input."""
return self.current_state in ['waiting_confirmation', 'waiting_input'] return self.current_state in ["waiting_confirmation", "waiting_input"]
class AptHandler(ProcessHandler): class AptHandler(ProcessHandler):
"""Handler for apt package manager interactions.""" """Handler for apt package manager interactions."""
@ -35,230 +35,238 @@ class AptHandler(ProcessHandler):
def __init__(self, multiplexer): def __init__(self, multiplexer):
super().__init__(multiplexer) super().__init__(multiplexer)
self.state_machine = { self.state_machine = {
'initial': ['running_command'], "initial": ["running_command"],
'running_command': ['waiting_confirmation', 'completed'], "running_command": ["waiting_confirmation", "completed"],
'waiting_confirmation': ['confirmed', 'cancelled'], "waiting_confirmation": ["confirmed", "cancelled"],
'confirmed': ['installing', 'completed'], "confirmed": ["installing", "completed"],
'installing': ['completed', 'error'], "installing": ["completed", "error"],
'completed': [], "completed": [],
'error': [], "error": [],
'cancelled': [] "cancelled": [],
} }
self.prompt_patterns = [ self.prompt_patterns = [
(r'Do you want to continue\?', 'confirmation'), (r"Do you want to continue\?", "confirmation"),
(r'After this operation.*installed\.', 'size_info'), (r"After this operation.*installed\.", "size_info"),
(r'Need to get.*B of archives\.', 'download_info'), (r"Need to get.*B of archives\.", "download_info"),
(r'Unpacking.*Configuring', 'configuring'), (r"Unpacking.*Configuring", "configuring"),
(r'Setting up', 'setting_up'), (r"Setting up", "setting_up"),
(r'E:\s', 'error') (r"E:\s", "error"),
] ]
def get_process_type(self): def get_process_type(self):
return 'apt' return "apt"
def update_state(self, output): def update_state(self, output):
"""Update state based on apt output patterns.""" """Update state based on apt output patterns."""
output_lower = output.lower() output_lower = output.lower()
# Check for completion # Check for completion
if 'processing triggers' in output_lower or 'done' in output_lower: if "processing triggers" in output_lower or "done" in output_lower:
self.current_state = 'completed' self.current_state = "completed"
# Check for confirmation prompts # Check for confirmation prompts
elif 'do you want to continue' in output_lower: elif "do you want to continue" in output_lower:
self.current_state = 'waiting_confirmation' self.current_state = "waiting_confirmation"
# Check for installation progress # Check for installation progress
elif 'setting up' in output_lower or 'unpacking' in output_lower: elif "setting up" in output_lower or "unpacking" in output_lower:
self.current_state = 'installing' self.current_state = "installing"
# Check for errors # Check for errors
elif 'e:' in output_lower or 'error' in output_lower: elif "e:" in output_lower or "error" in output_lower:
self.current_state = 'error' self.current_state = "error"
def get_prompt_suggestions(self): def get_prompt_suggestions(self):
"""Return suggested responses for apt prompts.""" """Return suggested responses for apt prompts."""
suggestions = super().get_prompt_suggestions() suggestions = super().get_prompt_suggestions()
if self.current_state == 'waiting_confirmation': if self.current_state == "waiting_confirmation":
suggestions.extend(['y', 'yes', 'n', 'no']) suggestions.extend(["y", "yes", "n", "no"])
return suggestions return suggestions
class VimHandler(ProcessHandler): class VimHandler(ProcessHandler):
"""Handler for vim editor interactions.""" """Handler for vim editor interactions."""
def __init__(self, multiplexer): def __init__(self, multiplexer):
super().__init__(multiplexer) super().__init__(multiplexer)
self.state_machine = { self.state_machine = {
'initial': ['normal_mode', 'insert_mode'], "initial": ["normal_mode", "insert_mode"],
'normal_mode': ['insert_mode', 'command_mode', 'visual_mode'], "normal_mode": ["insert_mode", "command_mode", "visual_mode"],
'insert_mode': ['normal_mode'], "insert_mode": ["normal_mode"],
'command_mode': ['normal_mode'], "command_mode": ["normal_mode"],
'visual_mode': ['normal_mode'], "visual_mode": ["normal_mode"],
'exiting': [] "exiting": [],
} }
self.prompt_patterns = [ self.prompt_patterns = [
(r'-- INSERT --', 'insert_mode'), (r"-- INSERT --", "insert_mode"),
(r'-- VISUAL --', 'visual_mode'), (r"-- VISUAL --", "visual_mode"),
(r':', 'command_mode'), (r":", "command_mode"),
(r'Press ENTER', 'waiting_enter'), (r"Press ENTER", "waiting_enter"),
(r'Saved', 'saved') (r"Saved", "saved"),
] ]
self.mode_indicators = { self.mode_indicators = {
'insert': '-- INSERT --', "insert": "-- INSERT --",
'visual': '-- VISUAL --', "visual": "-- VISUAL --",
'command': ':' "command": ":",
} }
def get_process_type(self): def get_process_type(self):
return 'vim' return "vim"
def update_state(self, output): def update_state(self, output):
"""Update state based on vim mode indicators.""" """Update state based on vim mode indicators."""
if '-- INSERT --' in output: if "-- INSERT --" in output:
self.current_state = 'insert_mode' self.current_state = "insert_mode"
elif '-- VISUAL --' in output: elif "-- VISUAL --" in output:
self.current_state = 'visual_mode' self.current_state = "visual_mode"
elif output.strip().endswith(':'): elif output.strip().endswith(":"):
self.current_state = 'command_mode' self.current_state = "command_mode"
elif 'Press ENTER' in output: elif "Press ENTER" in output:
self.current_state = 'waiting_enter' self.current_state = "waiting_enter"
else: else:
# Default to normal mode if no specific indicators # Default to normal mode if no specific indicators
self.current_state = 'normal_mode' self.current_state = "normal_mode"
def get_prompt_suggestions(self): def get_prompt_suggestions(self):
"""Return suggested commands for vim modes.""" """Return suggested commands for vim modes."""
suggestions = super().get_prompt_suggestions() suggestions = super().get_prompt_suggestions()
if self.current_state == 'command_mode': if self.current_state == "command_mode":
suggestions.extend(['w', 'q', 'wq', 'q!', 'w!']) suggestions.extend(["w", "q", "wq", "q!", "w!"])
elif self.current_state == 'normal_mode': elif self.current_state == "normal_mode":
suggestions.extend(['i', 'a', 'o', 'dd', ':w', ':q']) suggestions.extend(["i", "a", "o", "dd", ":w", ":q"])
elif self.current_state == 'waiting_enter': elif self.current_state == "waiting_enter":
suggestions.extend(['\n']) suggestions.extend(["\n"])
return suggestions return suggestions
class SSHHandler(ProcessHandler): class SSHHandler(ProcessHandler):
"""Handler for SSH connection interactions.""" """Handler for SSH connection interactions."""
def __init__(self, multiplexer): def __init__(self, multiplexer):
super().__init__(multiplexer) super().__init__(multiplexer)
self.state_machine = { self.state_machine = {
'initial': ['connecting'], "initial": ["connecting"],
'connecting': ['auth_prompt', 'connected', 'failed'], "connecting": ["auth_prompt", "connected", "failed"],
'auth_prompt': ['connected', 'failed'], "auth_prompt": ["connected", "failed"],
'connected': ['shell', 'disconnected'], "connected": ["shell", "disconnected"],
'shell': ['disconnected'], "shell": ["disconnected"],
'failed': [], "failed": [],
'disconnected': [] "disconnected": [],
} }
self.prompt_patterns = [ self.prompt_patterns = [
(r'password:', 'password_prompt'), (r"password:", "password_prompt"),
(r'yes/no', 'host_key_prompt'), (r"yes/no", "host_key_prompt"),
(r'Permission denied', 'auth_failed'), (r"Permission denied", "auth_failed"),
(r'Welcome to', 'connected'), (r"Welcome to", "connected"),
(r'\$', 'shell_prompt'), (r"\$", "shell_prompt"),
(r'\#', 'root_shell_prompt'), (r"\#", "root_shell_prompt"),
(r'Connection closed', 'disconnected') (r"Connection closed", "disconnected"),
] ]
def get_process_type(self): def get_process_type(self):
return 'ssh' return "ssh"
def update_state(self, output): def update_state(self, output):
"""Update state based on SSH connection output.""" """Update state based on SSH connection output."""
output_lower = output.lower() output_lower = output.lower()
if 'permission denied' in output_lower: if "permission denied" in output_lower:
self.current_state = 'failed' self.current_state = "failed"
elif 'password:' in output_lower: elif "password:" in output_lower:
self.current_state = 'auth_prompt' self.current_state = "auth_prompt"
elif 'yes/no' in output_lower: elif "yes/no" in output_lower:
self.current_state = 'auth_prompt' self.current_state = "auth_prompt"
elif 'welcome to' in output_lower or 'last login' in output_lower: elif "welcome to" in output_lower or "last login" in output_lower:
self.current_state = 'connected' self.current_state = "connected"
elif output.strip().endswith('$') or output.strip().endswith('#'): elif output.strip().endswith("$") or output.strip().endswith("#"):
self.current_state = 'shell' self.current_state = "shell"
elif 'connection closed' in output_lower: elif "connection closed" in output_lower:
self.current_state = 'disconnected' self.current_state = "disconnected"
def get_prompt_suggestions(self): def get_prompt_suggestions(self):
"""Return suggested responses for SSH prompts.""" """Return suggested responses for SSH prompts."""
suggestions = super().get_prompt_suggestions() suggestions = super().get_prompt_suggestions()
if self.current_state == 'auth_prompt': if self.current_state == "auth_prompt":
if 'password:' in self.multiplexer.get_all_output()['stdout']: if "password:" in self.multiplexer.get_all_output()["stdout"]:
suggestions.extend(['<password>']) # Placeholder for actual password suggestions.extend(["<password>"]) # Placeholder for actual password
elif 'yes/no' in self.multiplexer.get_all_output()['stdout']: elif "yes/no" in self.multiplexer.get_all_output()["stdout"]:
suggestions.extend(['yes', 'no']) suggestions.extend(["yes", "no"])
return suggestions return suggestions
class GenericProcessHandler(ProcessHandler): class GenericProcessHandler(ProcessHandler):
"""Fallback handler for unknown process types.""" """Fallback handler for unknown process types."""
def __init__(self, multiplexer): def __init__(self, multiplexer):
super().__init__(multiplexer) super().__init__(multiplexer)
self.state_machine = { self.state_machine = {
'initial': ['running'], "initial": ["running"],
'running': ['waiting_input', 'completed'], "running": ["waiting_input", "completed"],
'waiting_input': ['running'], "waiting_input": ["running"],
'completed': [] "completed": [],
} }
self.prompt_patterns = [ self.prompt_patterns = [
(r'\?\s*$', 'waiting_input'), # Lines ending with ? (r"\?\s*$", "waiting_input"), # Lines ending with ?
(r'>\s*$', 'waiting_input'), # Lines ending with > (r">\s*$", "waiting_input"), # Lines ending with >
(r':\s*$', 'waiting_input'), # Lines ending with : (r":\s*$", "waiting_input"), # Lines ending with :
(r'done', 'completed'), (r"done", "completed"),
(r'finished', 'completed'), (r"finished", "completed"),
(r'exit code', 'completed') (r"exit code", "completed"),
] ]
def get_process_type(self): def get_process_type(self):
return 'generic' return "generic"
def update_state(self, output): def update_state(self, output):
"""Basic state detection for generic processes.""" """Basic state detection for generic processes."""
output_lower = output.lower() output_lower = output.lower()
if any(pattern in output_lower for pattern in ['done', 'finished', 'complete']): if any(pattern in output_lower for pattern in ["done", "finished", "complete"]):
self.current_state = 'completed' self.current_state = "completed"
elif any(output.strip().endswith(char) for char in ['?', '>', ':']): elif any(output.strip().endswith(char) for char in ["?", ">", ":"]):
self.current_state = 'waiting_input' self.current_state = "waiting_input"
else: else:
self.current_state = 'running' self.current_state = "running"
# Handler registry # Handler registry
_handler_classes = { _handler_classes = {
'apt': AptHandler, "apt": AptHandler,
'vim': VimHandler, "vim": VimHandler,
'ssh': SSHHandler, "ssh": SSHHandler,
'generic': GenericProcessHandler "generic": GenericProcessHandler,
} }
def get_handler_for_process(process_type, multiplexer): def get_handler_for_process(process_type, multiplexer):
"""Get appropriate handler for a process type.""" """Get appropriate handler for a process type."""
handler_class = _handler_classes.get(process_type, GenericProcessHandler) handler_class = _handler_classes.get(process_type, GenericProcessHandler)
return handler_class(multiplexer) return handler_class(multiplexer)
def detect_process_type(command): def detect_process_type(command):
"""Detect process type from command.""" """Detect process type from command."""
command_str = ' '.join(command) if isinstance(command, list) else command command_str = " ".join(command) if isinstance(command, list) else command
command_lower = command_str.lower() command_lower = command_str.lower()
if 'apt' in command_lower or 'apt-get' in command_lower: if "apt" in command_lower or "apt-get" in command_lower:
return 'apt' return "apt"
elif 'vim' in command_lower or 'vi ' in command_lower: elif "vim" in command_lower or "vi " in command_lower:
return 'vim' return "vim"
elif 'ssh' in command_lower: elif "ssh" in command_lower:
return 'ssh' return "ssh"
else: else:
return 'generic' return "generic"
return 'ssh' return "ssh"
def detect_process_type(command): def detect_process_type(command):
"""Detect process type from command.""" """Detect process type from command."""
command_str = ' '.join(command) if isinstance(command, list) else command command_str = " ".join(command) if isinstance(command, list) else command
command_lower = command_str.lower() command_lower = command_str.lower()
if 'apt' in command_lower or 'apt-get' in command_lower: if "apt" in command_lower or "apt-get" in command_lower:
return 'apt' return "apt"
elif 'vim' in command_lower or 'vi ' in command_lower: elif "vim" in command_lower or "vi " in command_lower:
return 'vim' return "vim"
elif 'ssh' in command_lower: elif "ssh" in command_lower:
return 'ssh' return "ssh"
else: else:
return 'generic' return "generic"

View File

@ -1,6 +1,6 @@
import re import re
import time import time
from collections import defaultdict
class PromptDetector: class PromptDetector:
"""Detects various process prompts and manages interaction state.""" """Detects various process prompts and manages interaction state."""
@ -10,101 +10,119 @@ class PromptDetector:
self.state_machines = self._load_state_machines() self.state_machines = self._load_state_machines()
self.session_states = {} self.session_states = {}
self.timeout_configs = { self.timeout_configs = {
'default': 30, # 30 seconds default timeout "default": 30, # 30 seconds default timeout
'apt': 300, # 5 minutes for apt operations "apt": 300, # 5 minutes for apt operations
'ssh': 60, # 1 minute for SSH connections "ssh": 60, # 1 minute for SSH connections
'vim': 3600 # 1 hour for vim sessions "vim": 3600, # 1 hour for vim sessions
} }
def _load_prompt_patterns(self): def _load_prompt_patterns(self):
"""Load regex patterns for detecting various prompts.""" """Load regex patterns for detecting various prompts."""
return { return {
'bash_prompt': [ "bash_prompt": [
re.compile(r'[\w\-\.]+@[\w\-\.]+:.*[\$#]\s*$'), re.compile(r"[\w\-\.]+@[\w\-\.]+:.*[\$#]\s*$"),
re.compile(r'\$\s*$'), re.compile(r"\$\s*$"),
re.compile(r'#\s*$'), re.compile(r"#\s*$"),
re.compile(r'>\s*$') # Continuation prompt re.compile(r">\s*$"), # Continuation prompt
], ],
'confirmation': [ "confirmation": [
re.compile(r'[Yy]/[Nn]', re.IGNORECASE), re.compile(r"[Yy]/[Nn]", re.IGNORECASE),
re.compile(r'[Yy]es/[Nn]o', re.IGNORECASE), re.compile(r"[Yy]es/[Nn]o", re.IGNORECASE),
re.compile(r'continue\?', re.IGNORECASE), re.compile(r"continue\?", re.IGNORECASE),
re.compile(r'proceed\?', re.IGNORECASE) re.compile(r"proceed\?", re.IGNORECASE),
], ],
'password': [ "password": [
re.compile(r'password:', re.IGNORECASE), re.compile(r"password:", re.IGNORECASE),
re.compile(r'passphrase:', re.IGNORECASE), re.compile(r"passphrase:", re.IGNORECASE),
re.compile(r'enter password', re.IGNORECASE) re.compile(r"enter password", re.IGNORECASE),
], ],
'sudo_password': [ "sudo_password": [re.compile(r"\[sudo\].*password", re.IGNORECASE)],
re.compile(r'\[sudo\].*password', re.IGNORECASE) "apt": [
re.compile(r"Do you want to continue\?", re.IGNORECASE),
re.compile(r"After this operation", re.IGNORECASE),
re.compile(r"Need to get", re.IGNORECASE),
], ],
'apt': [ "vim": [
re.compile(r'Do you want to continue\?', re.IGNORECASE), re.compile(r"-- INSERT --"),
re.compile(r'After this operation', re.IGNORECASE), re.compile(r"-- VISUAL --"),
re.compile(r'Need to get', re.IGNORECASE) re.compile(r":"),
re.compile(r"Press ENTER", re.IGNORECASE),
], ],
'vim': [ "ssh": [
re.compile(r'-- INSERT --'), re.compile(r"yes/no", re.IGNORECASE),
re.compile(r'-- VISUAL --'), re.compile(r"password:", re.IGNORECASE),
re.compile(r':'), re.compile(r"Permission denied", re.IGNORECASE),
re.compile(r'Press ENTER', re.IGNORECASE)
], ],
'ssh': [ "git": [
re.compile(r'yes/no', re.IGNORECASE), re.compile(r"Username:", re.IGNORECASE),
re.compile(r'password:', re.IGNORECASE), re.compile(r"Email:", re.IGNORECASE),
re.compile(r'Permission denied', re.IGNORECASE)
], ],
'git': [ "error": [
re.compile(r'Username:', re.IGNORECASE), re.compile(r"error:", re.IGNORECASE),
re.compile(r'Email:', re.IGNORECASE) re.compile(r"failed", re.IGNORECASE),
re.compile(r"exception", re.IGNORECASE),
], ],
'error': [
re.compile(r'error:', re.IGNORECASE),
re.compile(r'failed', re.IGNORECASE),
re.compile(r'exception', re.IGNORECASE)
]
} }
def _load_state_machines(self): def _load_state_machines(self):
"""Load state machines for different process types.""" """Load state machines for different process types."""
return { return {
'apt': { "apt": {
'states': ['initial', 'running', 'confirming', 'installing', 'completed', 'error'], "states": [
'transitions': { "initial",
'initial': ['running'], "running",
'running': ['confirming', 'installing', 'completed', 'error'], "confirming",
'confirming': ['installing', 'cancelled'], "installing",
'installing': ['completed', 'error'], "completed",
'completed': [], "error",
'error': [], ],
'cancelled': [] "transitions": {
} "initial": ["running"],
"running": ["confirming", "installing", "completed", "error"],
"confirming": ["installing", "cancelled"],
"installing": ["completed", "error"],
"completed": [],
"error": [],
"cancelled": [],
},
}, },
'ssh': { "ssh": {
'states': ['initial', 'connecting', 'authenticating', 'connected', 'error'], "states": [
'transitions': { "initial",
'initial': ['connecting'], "connecting",
'connecting': ['authenticating', 'connected', 'error'], "authenticating",
'authenticating': ['connected', 'error'], "connected",
'connected': ['error'], "error",
'error': [] ],
} "transitions": {
"initial": ["connecting"],
"connecting": ["authenticating", "connected", "error"],
"authenticating": ["connected", "error"],
"connected": ["error"],
"error": [],
},
},
"vim": {
"states": [
"initial",
"normal",
"insert",
"visual",
"command",
"exiting",
],
"transitions": {
"initial": ["normal", "insert"],
"normal": ["insert", "visual", "command", "exiting"],
"insert": ["normal"],
"visual": ["normal"],
"command": ["normal", "exiting"],
"exiting": [],
},
}, },
'vim': {
'states': ['initial', 'normal', 'insert', 'visual', 'command', 'exiting'],
'transitions': {
'initial': ['normal', 'insert'],
'normal': ['insert', 'visual', 'command', 'exiting'],
'insert': ['normal'],
'visual': ['normal'],
'command': ['normal', 'exiting'],
'exiting': []
}
}
} }
def detect_prompt(self, output, process_type='generic'): def detect_prompt(self, output, process_type="generic"):
"""Detect what type of prompt is present in the output.""" """Detect what type of prompt is present in the output."""
detections = {} detections = {}
@ -125,93 +143,97 @@ class PromptDetector:
return detections return detections
def get_response_suggestions(self, prompt_detections, process_type='generic'): def get_response_suggestions(self, prompt_detections, process_type="generic"):
"""Get suggested responses based on detected prompts.""" """Get suggested responses based on detected prompts."""
suggestions = [] suggestions = []
for category, patterns in prompt_detections.items(): for category, patterns in prompt_detections.items():
if category == 'confirmation': if category == "confirmation":
suggestions.extend(['y', 'yes', 'n', 'no']) suggestions.extend(["y", "yes", "n", "no"])
elif category == 'password': elif category == "password":
suggestions.append('<password>') suggestions.append("<password>")
elif category == 'sudo_password': elif category == "sudo_password":
suggestions.append('<sudo_password>') suggestions.append("<sudo_password>")
elif category == 'apt': elif category == "apt":
if any('continue' in p for p in patterns): if any("continue" in p for p in patterns):
suggestions.extend(['y', 'yes']) suggestions.extend(["y", "yes"])
elif category == 'vim': elif category == "vim":
if any(':' in p for p in patterns): if any(":" in p for p in patterns):
suggestions.extend(['w', 'q', 'wq', 'q!']) suggestions.extend(["w", "q", "wq", "q!"])
elif any('ENTER' in p for p in patterns): elif any("ENTER" in p for p in patterns):
suggestions.append('\n') suggestions.append("\n")
elif category == 'ssh': elif category == "ssh":
if any('yes/no' in p for p in patterns): if any("yes/no" in p for p in patterns):
suggestions.extend(['yes', 'no']) suggestions.extend(["yes", "no"])
elif any('password' in p for p in patterns): elif any("password" in p for p in patterns):
suggestions.append('<password>') suggestions.append("<password>")
elif category == 'bash_prompt': elif category == "bash_prompt":
suggestions.extend(['help', 'ls', 'pwd', 'exit']) suggestions.extend(["help", "ls", "pwd", "exit"])
return list(set(suggestions)) # Remove duplicates return list(set(suggestions)) # Remove duplicates
def update_session_state(self, session_name, output, process_type='generic'): def update_session_state(self, session_name, output, process_type="generic"):
"""Update the state machine for a session based on output.""" """Update the state machine for a session based on output."""
if session_name not in self.session_states: if session_name not in self.session_states:
self.session_states[session_name] = { self.session_states[session_name] = {
'current_state': 'initial', "current_state": "initial",
'process_type': process_type, "process_type": process_type,
'last_activity': time.time(), "last_activity": time.time(),
'transitions': [] "transitions": [],
} }
session_state = self.session_states[session_name] session_state = self.session_states[session_name]
old_state = session_state['current_state'] old_state = session_state["current_state"]
# Detect prompts and determine new state # Detect prompts and determine new state
detections = self.detect_prompt(output, process_type) detections = self.detect_prompt(output, process_type)
new_state = self._determine_state_from_detections(detections, process_type, old_state) new_state = self._determine_state_from_detections(
detections, process_type, old_state
)
if new_state != old_state: if new_state != old_state:
session_state['transitions'].append({ session_state["transitions"].append(
'from': old_state, {
'to': new_state, "from": old_state,
'timestamp': time.time(), "to": new_state,
'trigger': detections "timestamp": time.time(),
}) "trigger": detections,
session_state['current_state'] = new_state }
)
session_state["current_state"] = new_state
session_state['last_activity'] = time.time() session_state["last_activity"] = time.time()
return new_state return new_state
def _determine_state_from_detections(self, detections, process_type, current_state): def _determine_state_from_detections(self, detections, process_type, current_state):
"""Determine new state based on prompt detections.""" """Determine new state based on prompt detections."""
if process_type in self.state_machines: if process_type in self.state_machines:
state_machine = self.state_machines[process_type] self.state_machines[process_type]
# State transition logic based on detections # State transition logic based on detections
if 'confirmation' in detections and current_state in ['running', 'initial']: if "confirmation" in detections and current_state in ["running", "initial"]:
return 'confirming' return "confirming"
elif 'password' in detections or 'sudo_password' in detections: elif "password" in detections or "sudo_password" in detections:
return 'authenticating' return "authenticating"
elif 'error' in detections: elif "error" in detections:
return 'error' return "error"
elif 'bash_prompt' in detections and current_state != 'initial': elif "bash_prompt" in detections and current_state != "initial":
return 'connected' if process_type == 'ssh' else 'completed' return "connected" if process_type == "ssh" else "completed"
elif 'vim' in detections: elif "vim" in detections:
if any('-- INSERT --' in p for p in detections.get('vim', [])): if any("-- INSERT --" in p for p in detections.get("vim", [])):
return 'insert' return "insert"
elif any('-- VISUAL --' in p for p in detections.get('vim', [])): elif any("-- VISUAL --" in p for p in detections.get("vim", [])):
return 'visual' return "visual"
elif any(':' in p for p in detections.get('vim', [])): elif any(":" in p for p in detections.get("vim", [])):
return 'command' return "command"
# Default state progression # Default state progression
if current_state == 'initial': if current_state == "initial":
return 'running' return "running"
elif current_state == 'running' and detections: elif current_state == "running" and detections:
return 'waiting_input' return "waiting_input"
elif current_state == 'waiting_input' and not detections: elif current_state == "waiting_input" and not detections:
return 'running' return "running"
return current_state return current_state
@ -220,15 +242,15 @@ class PromptDetector:
if session_name not in self.session_states: if session_name not in self.session_states:
return False return False
state = self.session_states[session_name]['current_state'] state = self.session_states[session_name]["current_state"]
process_type = self.session_states[session_name]['process_type'] process_type = self.session_states[session_name]["process_type"]
# States that typically indicate waiting for input # States that typically indicate waiting for input
waiting_states = { waiting_states = {
'generic': ['waiting_input'], "generic": ["waiting_input"],
'apt': ['confirming'], "apt": ["confirming"],
'ssh': ['authenticating'], "ssh": ["authenticating"],
'vim': ['command', 'insert', 'visual'] "vim": ["command", "insert", "visual"],
} }
return state in waiting_states.get(process_type, []) return state in waiting_states.get(process_type, [])
@ -236,10 +258,10 @@ class PromptDetector:
def get_session_timeout(self, session_name): def get_session_timeout(self, session_name):
"""Get the timeout for a session based on its process type.""" """Get the timeout for a session based on its process type."""
if session_name not in self.session_states: if session_name not in self.session_states:
return self.timeout_configs['default'] return self.timeout_configs["default"]
process_type = self.session_states[session_name]['process_type'] process_type = self.session_states[session_name]["process_type"]
return self.timeout_configs.get(process_type, self.timeout_configs['default']) return self.timeout_configs.get(process_type, self.timeout_configs["default"])
def check_for_timeouts(self): def check_for_timeouts(self):
"""Check all sessions for timeouts and return timed out sessions.""" """Check all sessions for timeouts and return timed out sessions."""
@ -248,7 +270,7 @@ class PromptDetector:
for session_name, state in self.session_states.items(): for session_name, state in self.session_states.items():
timeout = self.get_session_timeout(session_name) timeout = self.get_session_timeout(session_name)
if current_time - state['last_activity'] > timeout: if current_time - state["last_activity"] > timeout:
timed_out.append(session_name) timed_out.append(session_name)
return timed_out return timed_out
@ -260,16 +282,18 @@ class PromptDetector:
state = self.session_states[session_name] state = self.session_states[session_name]
return { return {
'current_state': state['current_state'], "current_state": state["current_state"],
'process_type': state['process_type'], "process_type": state["process_type"],
'last_activity': state['last_activity'], "last_activity": state["last_activity"],
'transitions': state['transitions'][-5:], # Last 5 transitions "transitions": state["transitions"][-5:], # Last 5 transitions
'is_waiting': self.is_waiting_for_input(session_name) "is_waiting": self.is_waiting_for_input(session_name),
} }
# Global detector instance # Global detector instance
_detector = None _detector = None
def get_global_detector(): def get_global_detector():
"""Get the global prompt detector instance.""" """Get the global prompt detector instance."""
global _detector global _detector

View File

@ -1,6 +1,7 @@
import contextlib
import traceback import traceback
from io import StringIO from io import StringIO
import contextlib
def python_exec(code, python_globals): def python_exec(code, python_globals):
try: try:

View File

@ -1,7 +1,8 @@
import urllib.request
import urllib.parse
import urllib.error
import json import json
import urllib.error
import urllib.parse
import urllib.request
def http_fetch(url, headers=None): def http_fetch(url, headers=None):
try: try:
@ -11,26 +12,28 @@ def http_fetch(url, headers=None):
req.add_header(key, value) req.add_header(key, value)
with urllib.request.urlopen(req) as response: with urllib.request.urlopen(req) as response:
content = response.read().decode('utf-8') content = response.read().decode("utf-8")
return {"status": "success", "content": content[:10000]} return {"status": "success", "content": content[:10000]}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def _perform_search(base_url, query, params=None): def _perform_search(base_url, query, params=None):
try: try:
full_url = f"https://static.molodetz.nl/search.cgi?query={query}" full_url = f"https://static.molodetz.nl/search.cgi?query={query}"
with urllib.request.urlopen(full_url) as response: with urllib.request.urlopen(full_url) as response:
content = response.read().decode('utf-8') content = response.read().decode("utf-8")
return {"status": "success", "content": json.loads(content)} return {"status": "success", "content": json.loads(content)}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def web_search(query): def web_search(query):
base_url = "https://search.molodetz.nl/search" base_url = "https://search.molodetz.nl/search"
return _perform_search(base_url, query) return _perform_search(base_url, query)
def web_search_news(query): def web_search_news(query):
base_url = "https://search.molodetz.nl/search" base_url = "https://search.molodetz.nl/search"
return _perform_search(base_url, query) return _perform_search(base_url, query)

View File

@ -1,5 +1,11 @@
from pr.ui.colors import Colors from pr.ui.colors import Colors
from pr.ui.rendering import highlight_code, render_markdown
from pr.ui.display import display_tool_call, print_autonomous_header from pr.ui.display import display_tool_call, print_autonomous_header
from pr.ui.rendering import highlight_code, render_markdown
__all__ = ['Colors', 'highlight_code', 'render_markdown', 'display_tool_call', 'print_autonomous_header'] __all__ = [
"Colors",
"highlight_code",
"render_markdown",
"display_tool_call",
"print_autonomous_header",
]

View File

@ -1,14 +1,14 @@
class Colors: class Colors:
RESET = '\033[0m' RESET = "\033[0m"
BOLD = '\033[1m' BOLD = "\033[1m"
RED = '\033[91m' RED = "\033[91m"
GREEN = '\033[92m' GREEN = "\033[92m"
YELLOW = '\033[93m' YELLOW = "\033[93m"
BLUE = '\033[94m' BLUE = "\033[94m"
MAGENTA = '\033[95m' MAGENTA = "\033[95m"
CYAN = '\033[96m' CYAN = "\033[96m"
GRAY = '\033[90m' GRAY = "\033[90m"
WHITE = '\033[97m' WHITE = "\033[97m"
BG_BLUE = '\033[44m' BG_BLUE = "\033[44m"
BG_GREEN = '\033[42m' BG_GREEN = "\033[42m"
BG_RED = '\033[41m' BG_RED = "\033[41m"

View File

@ -1,5 +1,6 @@
import difflib import difflib
from typing import List, Tuple, Dict, Optional from typing import Dict, List, Optional, Tuple
from .colors import Colors from .colors import Colors
@ -19,8 +20,13 @@ class DiffStats:
class DiffLine: class DiffLine:
def __init__(self, line_type: str, content: str, old_line_num: Optional[int] = None, def __init__(
new_line_num: Optional[int] = None): self,
line_type: str,
content: str,
old_line_num: Optional[int] = None,
new_line_num: Optional[int] = None,
):
self.line_type = line_type self.line_type = line_type
self.content = content self.content = content
self.old_line_num = old_line_num self.old_line_num = old_line_num
@ -28,27 +34,27 @@ class DiffLine:
def format(self, show_line_nums: bool = True) -> str: def format(self, show_line_nums: bool = True) -> str:
color = { color = {
'add': Colors.GREEN, "add": Colors.GREEN,
'delete': Colors.RED, "delete": Colors.RED,
'context': Colors.GRAY, "context": Colors.GRAY,
'header': Colors.CYAN, "header": Colors.CYAN,
'stats': Colors.BLUE "stats": Colors.BLUE,
}.get(self.line_type, Colors.RESET) }.get(self.line_type, Colors.RESET)
prefix = { prefix = {
'add': '+ ', "add": "+ ",
'delete': '- ', "delete": "- ",
'context': ' ', "context": " ",
'header': '', "header": "",
'stats': '' "stats": "",
}.get(self.line_type, ' ') }.get(self.line_type, " ")
if show_line_nums and self.line_type in ('add', 'delete', 'context'): if show_line_nums and self.line_type in ("add", "delete", "context"):
old_num = str(self.old_line_num) if self.old_line_num else ' ' old_num = str(self.old_line_num) if self.old_line_num else " "
new_num = str(self.new_line_num) if self.new_line_num else ' ' new_num = str(self.new_line_num) if self.new_line_num else " "
line_num_str = f"{Colors.YELLOW}{old_num:>4} {new_num:>4}{Colors.RESET} " line_num_str = f"{Colors.YELLOW}{old_num:>4} {new_num:>4}{Colors.RESET} "
else: else:
line_num_str = '' line_num_str = ""
return f"{line_num_str}{color}{prefix}{self.content}{Colors.RESET}" return f"{line_num_str}{color}{prefix}{self.content}{Colors.RESET}"
@ -57,8 +63,9 @@ class DiffDisplay:
def __init__(self, context_lines: int = 3): def __init__(self, context_lines: int = 3):
self.context_lines = context_lines self.context_lines = context_lines
def create_diff(self, old_content: str, new_content: str, def create_diff(
filename: str = "file") -> Tuple[List[DiffLine], DiffStats]: self, old_content: str, new_content: str, filename: str = "file"
) -> Tuple[List[DiffLine], DiffStats]:
old_lines = old_content.splitlines(keepends=True) old_lines = old_content.splitlines(keepends=True)
new_lines = new_content.splitlines(keepends=True) new_lines = new_content.splitlines(keepends=True)
@ -67,31 +74,38 @@ class DiffDisplay:
stats.files_changed = 1 stats.files_changed = 1
diff = difflib.unified_diff( diff = difflib.unified_diff(
old_lines, new_lines, old_lines,
new_lines,
fromfile=f"a/{filename}", fromfile=f"a/{filename}",
tofile=f"b/{filename}", tofile=f"b/{filename}",
n=self.context_lines n=self.context_lines,
) )
old_line_num = 0 old_line_num = 0
new_line_num = 0 new_line_num = 0
for line in diff: for line in diff:
if line.startswith('---') or line.startswith('+++'): if line.startswith("---") or line.startswith("+++"):
diff_lines.append(DiffLine('header', line.rstrip())) diff_lines.append(DiffLine("header", line.rstrip()))
elif line.startswith('@@'): elif line.startswith("@@"):
diff_lines.append(DiffLine('header', line.rstrip())) diff_lines.append(DiffLine("header", line.rstrip()))
old_line_num, new_line_num = self._parse_hunk_header(line) old_line_num, new_line_num = self._parse_hunk_header(line)
elif line.startswith('+'): elif line.startswith("+"):
stats.insertions += 1 stats.insertions += 1
diff_lines.append(DiffLine('add', line[1:].rstrip(), None, new_line_num)) diff_lines.append(
DiffLine("add", line[1:].rstrip(), None, new_line_num)
)
new_line_num += 1 new_line_num += 1
elif line.startswith('-'): elif line.startswith("-"):
stats.deletions += 1 stats.deletions += 1
diff_lines.append(DiffLine('delete', line[1:].rstrip(), old_line_num, None)) diff_lines.append(
DiffLine("delete", line[1:].rstrip(), old_line_num, None)
)
old_line_num += 1 old_line_num += 1
elif line.startswith(' '): elif line.startswith(" "):
diff_lines.append(DiffLine('context', line[1:].rstrip(), old_line_num, new_line_num)) diff_lines.append(
DiffLine("context", line[1:].rstrip(), old_line_num, new_line_num)
)
old_line_num += 1 old_line_num += 1
new_line_num += 1 new_line_num += 1
@ -101,15 +115,20 @@ class DiffDisplay:
def _parse_hunk_header(self, header: str) -> Tuple[int, int]: def _parse_hunk_header(self, header: str) -> Tuple[int, int]:
try: try:
parts = header.split('@@')[1].strip().split() parts = header.split("@@")[1].strip().split()
old_start = int(parts[0].split(',')[0].replace('-', '')) old_start = int(parts[0].split(",")[0].replace("-", ""))
new_start = int(parts[1].split(',')[0].replace('+', '')) new_start = int(parts[1].split(",")[0].replace("+", ""))
return old_start, new_start return old_start, new_start
except (IndexError, ValueError): except (IndexError, ValueError):
return 0, 0 return 0, 0
def render_diff(self, diff_lines: List[DiffLine], stats: DiffStats, def render_diff(
show_line_nums: bool = True, show_stats: bool = True) -> str: self,
diff_lines: List[DiffLine],
stats: DiffStats,
show_line_nums: bool = True,
show_stats: bool = True,
) -> str:
output = [] output = []
if show_stats: if show_stats:
@ -124,10 +143,15 @@ class DiffDisplay:
if show_stats: if show_stats:
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output) return "\n".join(output)
def display_file_diff(self, old_content: str, new_content: str, def display_file_diff(
filename: str = "file", show_line_nums: bool = True) -> str: self,
old_content: str,
new_content: str,
filename: str = "file",
show_line_nums: bool = True,
) -> str:
diff_lines, stats = self.create_diff(old_content, new_content, filename) diff_lines, stats = self.create_diff(old_content, new_content, filename)
if not diff_lines: if not diff_lines:
@ -135,8 +159,13 @@ class DiffDisplay:
return self.render_diff(diff_lines, stats, show_line_nums) return self.render_diff(diff_lines, stats, show_line_nums)
def display_side_by_side(self, old_content: str, new_content: str, def display_side_by_side(
filename: str = "file", width: int = 80) -> str: self,
old_content: str,
new_content: str,
filename: str = "file",
width: int = 80,
) -> str:
old_lines = old_content.splitlines() old_lines = old_content.splitlines()
new_lines = new_content.splitlines() new_lines = new_content.splitlines()
@ -144,40 +173,57 @@ class DiffDisplay:
output = [] output = []
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}") output.append(
f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}"
)
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n") output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
half_width = (width - 5) // 2 half_width = (width - 5) // 2
for tag, i1, i2, j1, j2 in matcher.get_opcodes(): for tag, i1, i2, j1, j2 in matcher.get_opcodes():
if tag == 'equal': if tag == "equal":
for i, (old_line, new_line) in enumerate(zip(old_lines[i1:i2], new_lines[j1:j2])): for i, (old_line, new_line) in enumerate(
zip(old_lines[i1:i2], new_lines[j1:j2])
):
old_display = old_line[:half_width].ljust(half_width) old_display = old_line[:half_width].ljust(half_width)
new_display = new_line[:half_width].ljust(half_width) new_display = new_line[:half_width].ljust(half_width)
output.append(f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}") output.append(
elif tag == 'replace': f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}"
)
elif tag == "replace":
max_lines = max(i2 - i1, j2 - j1) max_lines = max(i2 - i1, j2 - j1)
for i in range(max_lines): for i in range(max_lines):
old_line = old_lines[i1 + i] if i1 + i < i2 else "" old_line = old_lines[i1 + i] if i1 + i < i2 else ""
new_line = new_lines[j1 + i] if j1 + i < j2 else "" new_line = new_lines[j1 + i] if j1 + i < j2 else ""
old_display = old_line[:half_width].ljust(half_width) old_display = old_line[:half_width].ljust(half_width)
new_display = new_line[:half_width].ljust(half_width) new_display = new_line[:half_width].ljust(half_width)
output.append(f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}") output.append(
elif tag == 'delete': f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}"
)
elif tag == "delete":
for old_line in old_lines[i1:i2]: for old_line in old_lines[i1:i2]:
old_display = old_line[:half_width].ljust(half_width) old_display = old_line[:half_width].ljust(half_width)
output.append(f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}") output.append(
elif tag == 'insert': f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}"
)
elif tag == "insert":
for new_line in new_lines[j1:j2]: for new_line in new_lines[j1:j2]:
new_display = new_line[:half_width].ljust(half_width) new_display = new_line[:half_width].ljust(half_width)
output.append(f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}") output.append(
f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}"
)
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
return '\n'.join(output) return "\n".join(output)
def display_diff(old_content: str, new_content: str, filename: str = "file", def display_diff(
format_type: str = "unified", context_lines: int = 3) -> str: old_content: str,
new_content: str,
filename: str = "file",
format_type: str = "unified",
context_lines: int = 3,
) -> str:
displayer = DiffDisplay(context_lines) displayer = DiffDisplay(context_lines)
if format_type == "side-by-side": if format_type == "side-by-side":
@ -191,9 +237,9 @@ def get_diff_stats(old_content: str, new_content: str) -> Dict[str, int]:
_, stats = displayer.create_diff(old_content, new_content) _, stats = displayer.create_diff(old_content, new_content)
return { return {
'insertions': stats.insertions, "insertions": stats.insertions,
'deletions': stats.deletions, "deletions": stats.deletions,
'modifications': stats.modifications, "modifications": stats.modifications,
'total_changes': stats.total_changes, "total_changes": stats.total_changes,
'files_changed': stats.files_changed "files_changed": stats.files_changed,
} }

View File

@ -1,8 +1,6 @@
import json
import time
from typing import Dict, Any
from pr.ui.colors import Colors from pr.ui.colors import Colors
def display_tool_call(tool_name, arguments, status="running", result=None): def display_tool_call(tool_name, arguments, status="running", result=None):
if status == "running": if status == "running":
return return
@ -15,8 +13,11 @@ def display_tool_call(tool_name, arguments, status="running", result=None):
print(f"{Colors.GRAY}{line}{Colors.RESET}") print(f"{Colors.GRAY}{line}{Colors.RESET}")
def print_autonomous_header(task): def print_autonomous_header(task):
print(f"{Colors.BOLD}Task:{Colors.RESET} {task}") print(f"{Colors.BOLD}Task:{Colors.RESET} {task}")
print(f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}") print(
f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}"
)
print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n") print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n")
print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}\n") print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}\n")

View File

@ -1,12 +1,20 @@
from typing import List, Dict, Optional
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional
from .colors import Colors from .colors import Colors
from .progress import ProgressBar from .progress import ProgressBar
class EditOperation: class EditOperation:
def __init__(self, op_type: str, filepath: str, start_pos: int = 0, def __init__(
end_pos: int = 0, content: str = "", old_content: str = ""): self,
op_type: str,
filepath: str,
start_pos: int = 0,
end_pos: int = 0,
content: str = "",
old_content: str = "",
):
self.op_type = op_type self.op_type = op_type
self.filepath = filepath self.filepath = filepath
self.start_pos = start_pos self.start_pos = start_pos
@ -18,40 +26,46 @@ class EditOperation:
def format_operation(self) -> str: def format_operation(self) -> str:
op_colors = { op_colors = {
'INSERT': Colors.GREEN, "INSERT": Colors.GREEN,
'REPLACE': Colors.YELLOW, "REPLACE": Colors.YELLOW,
'DELETE': Colors.RED, "DELETE": Colors.RED,
'WRITE': Colors.BLUE "WRITE": Colors.BLUE,
} }
color = op_colors.get(self.op_type, Colors.RESET) color = op_colors.get(self.op_type, Colors.RESET)
status_icon = { status_icon = {
'pending': '', "pending": "",
'in_progress': '', "in_progress": "",
'completed': '', "completed": "",
'failed': '' "failed": "",
}.get(self.status, '') }.get(self.status, "")
return f"{color}{status_icon} [{self.op_type}]{Colors.RESET} {self.filepath}" return f"{color}{status_icon} [{self.op_type}]{Colors.RESET} {self.filepath}"
def format_details(self, show_content: bool = True) -> str: def format_details(self, show_content: bool = True) -> str:
output = [self.format_operation()] output = [self.format_operation()]
if self.op_type in ('INSERT', 'REPLACE'): if self.op_type in ("INSERT", "REPLACE"):
output.append(f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}") output.append(
f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}"
)
if show_content: if show_content:
if self.old_content: if self.old_content:
lines = self.old_content.split('\n') lines = self.old_content.split("\n")
preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '') preview = lines[0][:60] + (
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
)
output.append(f" {Colors.RED}- {preview}{Colors.RESET}") output.append(f" {Colors.RED}- {preview}{Colors.RESET}")
if self.content: if self.content:
lines = self.content.split('\n') lines = self.content.split("\n")
preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '') preview = lines[0][:60] + (
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
)
output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}") output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}")
return '\n'.join(output) return "\n".join(output)
class EditTracker: class EditTracker:
@ -76,11 +90,13 @@ class EditTracker:
def get_stats(self) -> Dict[str, int]: def get_stats(self) -> Dict[str, int]:
stats = { stats = {
'total': len(self.operations), "total": len(self.operations),
'completed': sum(1 for op in self.operations if op.status == 'completed'), "completed": sum(1 for op in self.operations if op.status == "completed"),
'pending': sum(1 for op in self.operations if op.status == 'pending'), "pending": sum(1 for op in self.operations if op.status == "pending"),
'in_progress': sum(1 for op in self.operations if op.status == 'in_progress'), "in_progress": sum(
'failed': sum(1 for op in self.operations if op.status == 'failed') 1 for op in self.operations if op.status == "in_progress"
),
"failed": sum(1 for op in self.operations if op.status == "failed"),
} }
return stats return stats
@ -88,7 +104,7 @@ class EditTracker:
if not self.operations: if not self.operations:
return 0.0 return 0.0
stats = self.get_stats() stats = self.get_stats()
return (stats['completed'] / stats['total']) * 100 return (stats["completed"] / stats["total"]) * 100
def display_progress(self) -> str: def display_progress(self) -> str:
if not self.operations: if not self.operations:
@ -96,26 +112,30 @@ class EditTracker:
output = [] output = []
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}") output.append(
f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}"
)
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
stats = self.get_stats() stats = self.get_stats()
completion = self.get_completion_percentage() self.get_completion_percentage()
progress_bar = ProgressBar(total=stats['total'], width=40) progress_bar = ProgressBar(total=stats["total"], width=40)
progress_bar.current = stats['completed'] progress_bar.current = stats["completed"]
bar_display = progress_bar._get_bar_display() bar_display = progress_bar._get_bar_display()
output.append(f"Progress: {bar_display}") output.append(f"Progress: {bar_display}")
output.append(f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, " output.append(
f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n") f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, "
f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n"
)
output.append(f"{Colors.BOLD}Recent Operations:{Colors.RESET}") output.append(f"{Colors.BOLD}Recent Operations:{Colors.RESET}")
for i, op in enumerate(self.operations[-5:], 1): for i, op in enumerate(self.operations[-5:], 1):
output.append(f"{i}. {op.format_operation()}") output.append(f"{i}. {op.format_operation()}")
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output) return "\n".join(output)
def display_timeline(self, show_content: bool = False) -> str: def display_timeline(self, show_content: bool = False) -> str:
if not self.operations: if not self.operations:
@ -134,18 +154,20 @@ class EditTracker:
stats = self.get_stats() stats = self.get_stats()
output.append(f"{Colors.BOLD}Summary:{Colors.RESET}") output.append(f"{Colors.BOLD}Summary:{Colors.RESET}")
output.append(f"{Colors.BLUE}Total operations: {stats['total']}, " output.append(
f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}") f"{Colors.BLUE}Total operations: {stats['total']}, "
f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}"
)
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output) return "\n".join(output)
def display_summary(self) -> str: def display_summary(self) -> str:
if not self.operations: if not self.operations:
return f"{Colors.GRAY}No edits to summarize{Colors.RESET}" return f"{Colors.GRAY}No edits to summarize{Colors.RESET}"
stats = self.get_stats() stats = self.get_stats()
files_modified = len(set(op.filepath for op in self.operations)) files_modified = len({op.filepath for op in self.operations})
output = [] output = []
output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}") output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}")
@ -156,7 +178,7 @@ class EditTracker:
output.append(f"{Colors.GREEN}Total Operations: {stats['total']}{Colors.RESET}") output.append(f"{Colors.GREEN}Total Operations: {stats['total']}{Colors.RESET}")
output.append(f"{Colors.GREEN}Successful: {stats['completed']}{Colors.RESET}") output.append(f"{Colors.GREEN}Successful: {stats['completed']}{Colors.RESET}")
if stats['failed'] > 0: if stats["failed"] > 0:
output.append(f"{Colors.RED}Failed: {stats['failed']}{Colors.RESET}") output.append(f"{Colors.RED}Failed: {stats['failed']}{Colors.RESET}")
output.append(f"\n{Colors.BOLD}Operations by Type:{Colors.RESET}") output.append(f"\n{Colors.BOLD}Operations by Type:{Colors.RESET}")
@ -168,7 +190,7 @@ class EditTracker:
output.append(f" {op_type}: {count}") output.append(f" {op_type}: {count}")
output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}\n") output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output) return "\n".join(output)
def clear(self): def clear(self):
self.operations.clear() self.operations.clear()

View File

@ -1,31 +1,31 @@
import json import json
import sys import sys
from typing import Any, Dict, List
from datetime import datetime from datetime import datetime
from typing import Any
class OutputFormatter: class OutputFormatter:
def __init__(self, format_type: str = 'text', quiet: bool = False): def __init__(self, format_type: str = "text", quiet: bool = False):
self.format_type = format_type self.format_type = format_type
self.quiet = quiet self.quiet = quiet
def output(self, data: Any, message_type: str = 'response'): def output(self, data: Any, message_type: str = "response"):
if self.quiet and message_type not in ['error', 'result']: if self.quiet and message_type not in ["error", "result"]:
return return
if self.format_type == 'json': if self.format_type == "json":
self._output_json(data, message_type) self._output_json(data, message_type)
elif self.format_type == 'structured': elif self.format_type == "structured":
self._output_structured(data, message_type) self._output_structured(data, message_type)
else: else:
self._output_text(data, message_type) self._output_text(data, message_type)
def _output_json(self, data: Any, message_type: str): def _output_json(self, data: Any, message_type: str):
output = { output = {
'type': message_type, "type": message_type,
'timestamp': datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
'data': data "data": data,
} }
print(json.dumps(output, indent=2)) print(json.dumps(output, indent=2))
@ -46,24 +46,24 @@ class OutputFormatter:
print(data) print(data)
def error(self, message: str): def error(self, message: str):
if self.format_type == 'json': if self.format_type == "json":
self._output_json({'error': message}, 'error') self._output_json({"error": message}, "error")
else: else:
print(f"Error: {message}", file=sys.stderr) print(f"Error: {message}", file=sys.stderr)
def success(self, message: str): def success(self, message: str):
if not self.quiet: if not self.quiet:
if self.format_type == 'json': if self.format_type == "json":
self._output_json({'success': message}, 'success') self._output_json({"success": message}, "success")
else: else:
print(message) print(message)
def info(self, message: str): def info(self, message: str):
if not self.quiet: if not self.quiet:
if self.format_type == 'json': if self.format_type == "json":
self._output_json({'info': message}, 'info') self._output_json({"info": message}, "info")
else: else:
print(message) print(message)
def result(self, data: Any): def result(self, data: Any):
self.output(data, 'result') self.output(data, "result")

View File

@ -1,6 +1,6 @@
import sys import sys
import time
import threading import threading
import time
class ProgressIndicator: class ProgressIndicator:
@ -30,15 +30,15 @@ class ProgressIndicator:
self.running = False self.running = False
if self.thread: if self.thread:
self.thread.join(timeout=1.0) self.thread.join(timeout=1.0)
sys.stdout.write('\r' + ' ' * (len(self.message) + 10) + '\r') sys.stdout.write("\r" + " " * (len(self.message) + 10) + "\r")
sys.stdout.flush() sys.stdout.flush()
def _animate(self): def _animate(self):
spinner = ['', '', '', '', '', '', '', '', '', ''] spinner = ["", "", "", "", "", "", "", "", "", ""]
idx = 0 idx = 0
while self.running: while self.running:
sys.stdout.write(f'\r{spinner[idx]} {self.message}...') sys.stdout.write(f"\r{spinner[idx]} {self.message}...")
sys.stdout.flush() sys.stdout.flush()
idx = (idx + 1) % len(spinner) idx = (idx + 1) % len(spinner)
time.sleep(0.1) time.sleep(0.1)
@ -62,14 +62,20 @@ class ProgressBar:
else: else:
percent = int((self.current / self.total) * 100) percent = int((self.current / self.total) * 100)
filled = int((self.current / self.total) * self.width) if self.total > 0 else self.width filled = (
bar = '' * filled + '' * (self.width - filled) int((self.current / self.total) * self.width)
if self.total > 0
else self.width
)
bar = "" * filled + "" * (self.width - filled)
sys.stdout.write(f'\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})') sys.stdout.write(
f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})"
)
sys.stdout.flush() sys.stdout.flush()
if self.current >= self.total: if self.current >= self.total:
sys.stdout.write('\n') sys.stdout.write("\n")
def finish(self): def finish(self):
self.current = self.total self.current = self.total

View File

@ -1,90 +1,103 @@
import re import re
from pr.ui.colors import Colors
from pr.config import LANGUAGE_KEYWORDS from pr.config import LANGUAGE_KEYWORDS
from pr.ui.colors import Colors
def highlight_code(code, language=None, syntax_highlighting=True): def highlight_code(code, language=None, syntax_highlighting=True):
if not syntax_highlighting: if not syntax_highlighting:
return code return code
if not language: if not language:
if 'def ' in code or 'import ' in code: if "def " in code or "import " in code:
language = 'python' language = "python"
elif 'function ' in code or 'const ' in code: elif "function " in code or "const " in code:
language = 'javascript' language = "javascript"
elif 'public ' in code or 'class ' in code: elif "public " in code or "class " in code:
language = 'java' language = "java"
if language and language in LANGUAGE_KEYWORDS: if language and language in LANGUAGE_KEYWORDS:
keywords = LANGUAGE_KEYWORDS[language] keywords = LANGUAGE_KEYWORDS[language]
for keyword in keywords: for keyword in keywords:
pattern = r'\b' + re.escape(keyword) + r'\b' pattern = r"\b" + re.escape(keyword) + r"\b"
code = re.sub(pattern, f"{Colors.BLUE}{keyword}{Colors.RESET}", code) code = re.sub(pattern, f"{Colors.BLUE}{keyword}{Colors.RESET}", code)
code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code) code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code)
code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code) code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code)
code = re.sub(r'#(.*)$', f'{Colors.GRAY}#\\1{Colors.RESET}', code, flags=re.MULTILINE) code = re.sub(
code = re.sub(r'//(.*)$', f'{Colors.GRAY}//\\1{Colors.RESET}', code, flags=re.MULTILINE) r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE
)
code = re.sub(
r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE
)
return code return code
def render_markdown(text, syntax_highlighting=True): def render_markdown(text, syntax_highlighting=True):
if not syntax_highlighting: if not syntax_highlighting:
return text return text
code_blocks = [] code_blocks = []
def extract_code_block(match): def extract_code_block(match):
lang = match.group(1) or '' lang = match.group(1) or ""
code = match.group(2) code = match.group(2)
highlighted_code = highlight_code(code.strip('\n'), lang, syntax_highlighting) highlighted_code = highlight_code(code.strip("\n"), lang, syntax_highlighting)
placeholder = f"%%CODEBLOCK{len(code_blocks)}%%" placeholder = f"%%CODEBLOCK{len(code_blocks)}%%"
full_block = f'{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}' full_block = f"{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}"
code_blocks.append(full_block) code_blocks.append(full_block)
return placeholder return placeholder
text = re.sub(r'```(\w*)\n(.*?)\n?```', extract_code_block, text, flags=re.DOTALL) text = re.sub(r"```(\w*)\n(.*?)\n?```", extract_code_block, text, flags=re.DOTALL)
inline_codes = [] inline_codes = []
def extract_inline_code(match): def extract_inline_code(match):
code = match.group(1) code = match.group(1)
placeholder = f"%%INLINECODE{len(inline_codes)}%%" placeholder = f"%%INLINECODE{len(inline_codes)}%%"
inline_codes.append(f'{Colors.YELLOW}{code}{Colors.RESET}') inline_codes.append(f"{Colors.YELLOW}{code}{Colors.RESET}")
return placeholder return placeholder
text = re.sub(r'`([^`]+)`', extract_inline_code, text) text = re.sub(r"`([^`]+)`", extract_inline_code, text)
lines = text.split('\n') lines = text.split("\n")
processed_lines = [] processed_lines = []
for line in lines: for line in lines:
if line.startswith('### '): if line.startswith("### "):
line = f'{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}' line = f"{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}"
elif line.startswith('## '): elif line.startswith("## "):
line = f'{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}' line = f"{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}"
elif line.startswith('# '): elif line.startswith("# "):
line = f'{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}' line = f"{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}"
elif line.startswith('> '): elif line.startswith("> "):
line = f'{Colors.CYAN}> {line[2:]}{Colors.RESET}' line = f"{Colors.CYAN}> {line[2:]}{Colors.RESET}"
elif re.match(r'^\s*[\*\-\+]\s', line): elif re.match(r"^\s*[\*\-\+]\s", line):
match = re.match(r'^(\s*)([\*\-\+])(\s+.*)', line) match = re.match(r"^(\s*)([\*\-\+])(\s+.*)", line)
if match: if match:
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
elif re.match(r'^\s*\d+\.\s', line): elif re.match(r"^\s*\d+\.\s", line):
match = re.match(r'^(\s*)(\d+\.)(\s+.*)', line) match = re.match(r"^(\s*)(\d+\.)(\s+.*)", line)
if match: if match:
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
processed_lines.append(line) processed_lines.append(line)
text = '\n'.join(processed_lines) text = "\n".join(processed_lines)
text = re.sub(r'\[(.*?)\]\((.*?)\)', f'{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}', text) text = re.sub(
text = re.sub(r'~~(.*?)~~', f'{Colors.GRAY}\\1{Colors.RESET}', text) r"\[(.*?)\]\((.*?)\)",
text = re.sub(r'\*\*(.*?)\*\*', f'{Colors.BOLD}\\1{Colors.RESET}', text) f"{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}",
text = re.sub(r'__(.*?)__', f'{Colors.BOLD}\\1{Colors.RESET}', text) text,
text = re.sub(r'\*(.*?)\*', f'{Colors.CYAN}\\1{Colors.RESET}', text) )
text = re.sub(r'_(.*?)_', f'{Colors.CYAN}\\1{Colors.RESET}', text) text = re.sub(r"~~(.*?)~~", f"{Colors.GRAY}\\1{Colors.RESET}", text)
text = re.sub(r"\*\*(.*?)\*\*", f"{Colors.BOLD}\\1{Colors.RESET}", text)
text = re.sub(r"__(.*?)__", f"{Colors.BOLD}\\1{Colors.RESET}", text)
text = re.sub(r"\*(.*?)\*", f"{Colors.CYAN}\\1{Colors.RESET}", text)
text = re.sub(r"_(.*?)_", f"{Colors.CYAN}\\1{Colors.RESET}", text)
for i, code in enumerate(inline_codes): for i, code in enumerate(inline_codes):
text = text.replace(f'%%INLINECODE{i}%%', code) text = text.replace(f"%%INLINECODE{i}%%", code)
for i, block in enumerate(code_blocks): for i, block in enumerate(code_blocks):
text = text.replace(f'%%CODEBLOCK{i}%%', block) text = text.replace(f"%%CODEBLOCK{i}%%", block)
return text return text

View File

@ -1,5 +1,11 @@
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode from .workflow_definition import ExecutionMode, Workflow, WorkflowStep
from .workflow_engine import WorkflowEngine from .workflow_engine import WorkflowEngine
from .workflow_storage import WorkflowStorage from .workflow_storage import WorkflowStorage
__all__ = ['Workflow', 'WorkflowStep', 'ExecutionMode', 'WorkflowEngine', 'WorkflowStorage'] __all__ = [
"Workflow",
"WorkflowStep",
"ExecutionMode",
"WorkflowEngine",
"WorkflowStorage",
]

View File

@ -1,12 +1,14 @@
from enum import Enum
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
class ExecutionMode(Enum): class ExecutionMode(Enum):
SEQUENTIAL = "sequential" SEQUENTIAL = "sequential"
PARALLEL = "parallel" PARALLEL = "parallel"
CONDITIONAL = "conditional" CONDITIONAL = "conditional"
@dataclass @dataclass
class WorkflowStep: class WorkflowStep:
tool_name: str tool_name: str
@ -20,29 +22,30 @@ class WorkflowStep:
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
'tool_name': self.tool_name, "tool_name": self.tool_name,
'arguments': self.arguments, "arguments": self.arguments,
'step_id': self.step_id, "step_id": self.step_id,
'condition': self.condition, "condition": self.condition,
'on_success': self.on_success, "on_success": self.on_success,
'on_failure': self.on_failure, "on_failure": self.on_failure,
'retry_count': self.retry_count, "retry_count": self.retry_count,
'timeout_seconds': self.timeout_seconds "timeout_seconds": self.timeout_seconds,
} }
@staticmethod @staticmethod
def from_dict(data: Dict[str, Any]) -> 'WorkflowStep': def from_dict(data: Dict[str, Any]) -> "WorkflowStep":
return WorkflowStep( return WorkflowStep(
tool_name=data['tool_name'], tool_name=data["tool_name"],
arguments=data['arguments'], arguments=data["arguments"],
step_id=data['step_id'], step_id=data["step_id"],
condition=data.get('condition'), condition=data.get("condition"),
on_success=data.get('on_success'), on_success=data.get("on_success"),
on_failure=data.get('on_failure'), on_failure=data.get("on_failure"),
retry_count=data.get('retry_count', 0), retry_count=data.get("retry_count", 0),
timeout_seconds=data.get('timeout_seconds', 300) timeout_seconds=data.get("timeout_seconds", 300),
) )
@dataclass @dataclass
class Workflow: class Workflow:
name: str name: str
@ -54,23 +57,23 @@ class Workflow:
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
'name': self.name, "name": self.name,
'description': self.description, "description": self.description,
'steps': [step.to_dict() for step in self.steps], "steps": [step.to_dict() for step in self.steps],
'execution_mode': self.execution_mode.value, "execution_mode": self.execution_mode.value,
'variables': self.variables, "variables": self.variables,
'tags': self.tags "tags": self.tags,
} }
@staticmethod @staticmethod
def from_dict(data: Dict[str, Any]) -> 'Workflow': def from_dict(data: Dict[str, Any]) -> "Workflow":
return Workflow( return Workflow(
name=data['name'], name=data["name"],
description=data['description'], description=data["description"],
steps=[WorkflowStep.from_dict(step) for step in data['steps']], steps=[WorkflowStep.from_dict(step) for step in data["steps"]],
execution_mode=ExecutionMode(data.get('execution_mode', 'sequential')), execution_mode=ExecutionMode(data.get("execution_mode", "sequential")),
variables=data.get('variables', {}), variables=data.get("variables", {}),
tags=data.get('tags', []) tags=data.get("tags", []),
) )
def add_step(self, step: WorkflowStep): def add_step(self, step: WorkflowStep):

View File

@ -1,8 +1,10 @@
import time
import re import re
from typing import Dict, Any, List, Callable, Optional import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode from typing import Any, Callable, Dict, List, Optional
from .workflow_definition import ExecutionMode, Workflow, WorkflowStep
class WorkflowExecutionContext: class WorkflowExecutionContext:
def __init__(self): def __init__(self):
@ -23,57 +25,66 @@ class WorkflowExecutionContext:
return self.step_results.get(step_id) return self.step_results.get(step_id)
def log_event(self, event_type: str, step_id: str, details: Dict[str, Any]): def log_event(self, event_type: str, step_id: str, details: Dict[str, Any]):
self.execution_log.append({ self.execution_log.append(
'timestamp': time.time(), {
'event_type': event_type, "timestamp": time.time(),
'step_id': step_id, "event_type": event_type,
'details': details "step_id": step_id,
}) "details": details,
}
)
class WorkflowEngine: class WorkflowEngine:
def __init__(self, tool_executor: Callable, max_workers: int = 5): def __init__(self, tool_executor: Callable, max_workers: int = 5):
self.tool_executor = tool_executor self.tool_executor = tool_executor
self.max_workers = max_workers self.max_workers = max_workers
def _evaluate_condition(self, condition: str, context: WorkflowExecutionContext) -> bool: def _evaluate_condition(
self, condition: str, context: WorkflowExecutionContext
) -> bool:
if not condition: if not condition:
return True return True
try: try:
safe_locals = { safe_locals = {
'variables': context.variables, "variables": context.variables,
'results': context.step_results "results": context.step_results,
} }
return eval(condition, {"__builtins__": {}}, safe_locals) return eval(condition, {"__builtins__": {}}, safe_locals)
except Exception: except Exception:
return False return False
def _substitute_variables(self, arguments: Dict[str, Any], context: WorkflowExecutionContext) -> Dict[str, Any]: def _substitute_variables(
self, arguments: Dict[str, Any], context: WorkflowExecutionContext
) -> Dict[str, Any]:
substituted = {} substituted = {}
for key, value in arguments.items(): for key, value in arguments.items():
if isinstance(value, str): if isinstance(value, str):
pattern = r'\$\{([^}]+)\}' pattern = r"\$\{([^}]+)\}"
matches = re.findall(pattern, value) matches = re.findall(pattern, value)
for match in matches: for match in matches:
if match.startswith('step.'): if match.startswith("step."):
step_id = match.split('.', 1)[1] step_id = match.split(".", 1)[1]
replacement = context.get_step_result(step_id) replacement = context.get_step_result(step_id)
if replacement is not None: if replacement is not None:
value = value.replace(f'${{{match}}}', str(replacement)) value = value.replace(f"${{{match}}}", str(replacement))
elif match.startswith('var.'): elif match.startswith("var."):
var_name = match.split('.', 1)[1] var_name = match.split(".", 1)[1]
replacement = context.get_variable(var_name) replacement = context.get_variable(var_name)
if replacement is not None: if replacement is not None:
value = value.replace(f'${{{match}}}', str(replacement)) value = value.replace(f"${{{match}}}", str(replacement))
substituted[key] = value substituted[key] = value
else: else:
substituted[key] = value substituted[key] = value
return substituted return substituted
def _execute_step(self, step: WorkflowStep, context: WorkflowExecutionContext) -> Dict[str, Any]: def _execute_step(
self, step: WorkflowStep, context: WorkflowExecutionContext
) -> Dict[str, Any]:
if not self._evaluate_condition(step.condition, context): if not self._evaluate_condition(step.condition, context):
context.log_event('skipped', step.step_id, {'reason': 'condition_not_met'}) context.log_event("skipped", step.step_id, {"reason": "condition_not_met"})
return {'status': 'skipped', 'step_id': step.step_id} return {"status": "skipped", "step_id": step.step_id}
arguments = self._substitute_variables(step.arguments, context) arguments = self._substitute_variables(step.arguments, context)
@ -83,26 +94,34 @@ class WorkflowEngine:
while retry_attempts <= step.retry_count: while retry_attempts <= step.retry_count:
try: try:
context.log_event('executing', step.step_id, { context.log_event(
'tool': step.tool_name, "executing",
'arguments': arguments, step.step_id,
'attempt': retry_attempts + 1 {
}) "tool": step.tool_name,
"arguments": arguments,
"attempt": retry_attempts + 1,
},
)
result = self.tool_executor(step.tool_name, arguments) result = self.tool_executor(step.tool_name, arguments)
execution_time = time.time() - start_time execution_time = time.time() - start_time
context.set_step_result(step.step_id, result) context.set_step_result(step.step_id, result)
context.log_event('completed', step.step_id, { context.log_event(
'execution_time': execution_time, "completed",
'result_size': len(str(result)) if result else 0 step.step_id,
}) {
"execution_time": execution_time,
"result_size": len(str(result)) if result else 0,
},
)
return { return {
'status': 'success', "status": "success",
'step_id': step.step_id, "step_id": step.step_id,
'result': result, "result": result,
'execution_time': execution_time "execution_time": execution_time,
} }
except Exception as e: except Exception as e:
@ -111,25 +130,26 @@ class WorkflowEngine:
if retry_attempts <= step.retry_count: if retry_attempts <= step.retry_count:
time.sleep(1 * retry_attempts) time.sleep(1 * retry_attempts)
context.log_event('failed', step.step_id, {'error': last_error}) context.log_event("failed", step.step_id, {"error": last_error})
return { return {
'status': 'failed', "status": "failed",
'step_id': step.step_id, "step_id": step.step_id,
'error': last_error, "error": last_error,
'execution_time': time.time() - start_time "execution_time": time.time() - start_time,
} }
def _get_next_steps(self, completed_step: WorkflowStep, result: Dict[str, Any], def _get_next_steps(
workflow: Workflow) -> List[WorkflowStep]: self, completed_step: WorkflowStep, result: Dict[str, Any], workflow: Workflow
) -> List[WorkflowStep]:
next_steps = [] next_steps = []
if result['status'] == 'success' and completed_step.on_success: if result["status"] == "success" and completed_step.on_success:
for step_id in completed_step.on_success: for step_id in completed_step.on_success:
step = workflow.get_step(step_id) step = workflow.get_step(step_id)
if step: if step:
next_steps.append(step) next_steps.append(step)
elif result['status'] == 'failed' and completed_step.on_failure: elif result["status"] == "failed" and completed_step.on_failure:
for step_id in completed_step.on_failure: for step_id in completed_step.on_failure:
step = workflow.get_step(step_id) step = workflow.get_step(step_id)
if step: if step:
@ -142,7 +162,9 @@ class WorkflowEngine:
return next_steps return next_steps
def execute_workflow(self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None) -> WorkflowExecutionContext: def execute_workflow(
self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None
) -> WorkflowExecutionContext:
context = WorkflowExecutionContext() context = WorkflowExecutionContext()
if initial_variables: if initial_variables:
@ -151,7 +173,7 @@ class WorkflowEngine:
if workflow.variables: if workflow.variables:
context.variables.update(workflow.variables) context.variables.update(workflow.variables)
context.log_event('workflow_started', 'workflow', {'name': workflow.name}) context.log_event("workflow_started", "workflow", {"name": workflow.name})
if workflow.execution_mode == ExecutionMode.PARALLEL: if workflow.execution_mode == ExecutionMode.PARALLEL:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor: with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
@ -164,9 +186,11 @@ class WorkflowEngine:
step = futures[future] step = futures[future]
try: try:
result = future.result() result = future.result()
context.log_event('step_completed', step.step_id, result) context.log_event("step_completed", step.step_id, result)
except Exception as e: except Exception as e:
context.log_event('step_failed', step.step_id, {'error': str(e)}) context.log_event(
"step_failed", step.step_id, {"error": str(e)}
)
else: else:
pending_steps = workflow.get_initial_steps() pending_steps = workflow.get_initial_steps()
@ -184,9 +208,13 @@ class WorkflowEngine:
next_steps = self._get_next_steps(step, result, workflow) next_steps = self._get_next_steps(step, result, workflow)
pending_steps.extend(next_steps) pending_steps.extend(next_steps)
context.log_event('workflow_completed', 'workflow', { context.log_event(
'total_steps': len(context.step_results), "workflow_completed",
'executed_steps': list(context.step_results.keys()) "workflow",
}) {
"total_steps": len(context.step_results),
"executed_steps": list(context.step_results.keys()),
},
)
return context return context

View File

@ -2,8 +2,10 @@ import json
import sqlite3 import sqlite3
import time import time
from typing import List, Optional from typing import List, Optional
from .workflow_definition import Workflow from .workflow_definition import Workflow
class WorkflowStorage: class WorkflowStorage:
def __init__(self, db_path: str): def __init__(self, db_path: str):
self.db_path = db_path self.db_path = db_path
@ -13,7 +15,8 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS workflows ( CREATE TABLE IF NOT EXISTS workflows (
workflow_id TEXT PRIMARY KEY, workflow_id TEXT PRIMARY KEY,
name TEXT NOT NULL, name TEXT NOT NULL,
@ -25,9 +28,11 @@ class WorkflowStorage:
last_execution_at INTEGER, last_execution_at INTEGER,
tags TEXT tags TEXT
) )
''') """
)
cursor.execute(''' cursor.execute(
"""
CREATE TABLE IF NOT EXISTS workflow_executions ( CREATE TABLE IF NOT EXISTS workflow_executions (
execution_id TEXT PRIMARY KEY, execution_id TEXT PRIMARY KEY,
workflow_id TEXT NOT NULL, workflow_id TEXT NOT NULL,
@ -39,17 +44,24 @@ class WorkflowStorage:
step_results TEXT, step_results TEXT,
FOREIGN KEY (workflow_id) REFERENCES workflows(workflow_id) FOREIGN KEY (workflow_id) REFERENCES workflows(workflow_id)
) )
''') """
)
cursor.execute(''' cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_workflow_name ON workflows(name) CREATE INDEX IF NOT EXISTS idx_workflow_name ON workflows(name)
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_execution_workflow ON workflow_executions(workflow_id) CREATE INDEX IF NOT EXISTS idx_execution_workflow ON workflow_executions(workflow_id)
''') """
cursor.execute(''' )
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_execution_started ON workflow_executions(started_at) CREATE INDEX IF NOT EXISTS idx_execution_started ON workflow_executions(started_at)
''') """
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -66,12 +78,22 @@ class WorkflowStorage:
current_time = int(time.time()) current_time = int(time.time())
tags_json = json.dumps(workflow.tags) tags_json = json.dumps(workflow.tags)
cursor.execute(''' cursor.execute(
"""
INSERT OR REPLACE INTO workflows INSERT OR REPLACE INTO workflows
(workflow_id, name, description, workflow_data, created_at, updated_at, tags) (workflow_id, name, description, workflow_data, created_at, updated_at, tags)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
''', (workflow_id, workflow.name, workflow.description, workflow_data, """,
current_time, current_time, tags_json)) (
workflow_id,
workflow.name,
workflow.description,
workflow_data,
current_time,
current_time,
tags_json,
),
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -82,7 +104,9 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT workflow_data FROM workflows WHERE workflow_id = ?', (workflow_id,)) cursor.execute(
"SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,)
)
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@ -95,7 +119,7 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT workflow_data FROM workflows WHERE name = ?', (name,)) cursor.execute("SELECT workflow_data FROM workflows WHERE name = ?", (name,))
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@ -109,29 +133,36 @@ class WorkflowStorage:
cursor = conn.cursor() cursor = conn.cursor()
if tag: if tag:
cursor.execute(''' cursor.execute(
"""
SELECT workflow_id, name, description, execution_count, last_execution_at, tags SELECT workflow_id, name, description, execution_count, last_execution_at, tags
FROM workflows FROM workflows
WHERE tags LIKE ? WHERE tags LIKE ?
ORDER BY name ORDER BY name
''', (f'%"{tag}"%',)) """,
(f'%"{tag}"%',),
)
else: else:
cursor.execute(''' cursor.execute(
"""
SELECT workflow_id, name, description, execution_count, last_execution_at, tags SELECT workflow_id, name, description, execution_count, last_execution_at, tags
FROM workflows FROM workflows
ORDER BY name ORDER BY name
''') """
)
workflows = [] workflows = []
for row in cursor.fetchall(): for row in cursor.fetchall():
workflows.append({ workflows.append(
'workflow_id': row[0], {
'name': row[1], "workflow_id": row[0],
'description': row[2], "name": row[1],
'execution_count': row[3], "description": row[2],
'last_execution_at': row[4], "execution_count": row[3],
'tags': json.loads(row[5]) if row[5] else [] "last_execution_at": row[4],
}) "tags": json.loads(row[5]) if row[5] else [],
}
)
conn.close() conn.close()
return workflows return workflows
@ -140,18 +171,21 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('DELETE FROM workflows WHERE workflow_id = ?', (workflow_id,)) cursor.execute("DELETE FROM workflows WHERE workflow_id = ?", (workflow_id,))
deleted = cursor.rowcount > 0 deleted = cursor.rowcount > 0
cursor.execute('DELETE FROM workflow_executions WHERE workflow_id = ?', (workflow_id,)) cursor.execute(
"DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,)
)
conn.commit() conn.commit()
conn.close() conn.close()
return deleted return deleted
def save_execution(self, workflow_id: str, execution_context: 'WorkflowExecutionContext') -> str: def save_execution(
import hashlib self, workflow_id: str, execution_context: "WorkflowExecutionContext"
) -> str:
import uuid import uuid
execution_id = str(uuid.uuid4())[:16] execution_id = str(uuid.uuid4())[:16]
@ -159,30 +193,40 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
started_at = int(execution_context.execution_log[0]['timestamp']) if execution_context.execution_log else int(time.time()) started_at = (
int(execution_context.execution_log[0]["timestamp"])
if execution_context.execution_log
else int(time.time())
)
completed_at = int(time.time()) completed_at = int(time.time())
cursor.execute(''' cursor.execute(
"""
INSERT INTO workflow_executions INSERT INTO workflow_executions
(execution_id, workflow_id, started_at, completed_at, status, execution_log, variables, step_results) (execution_id, workflow_id, started_at, completed_at, status, execution_log, variables, step_results)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', ( """,
execution_id, (
workflow_id, execution_id,
started_at, workflow_id,
completed_at, started_at,
'completed', completed_at,
json.dumps(execution_context.execution_log), "completed",
json.dumps(execution_context.variables), json.dumps(execution_context.execution_log),
json.dumps(execution_context.step_results) json.dumps(execution_context.variables),
)) json.dumps(execution_context.step_results),
),
)
cursor.execute(''' cursor.execute(
"""
UPDATE workflows UPDATE workflows
SET execution_count = execution_count + 1, SET execution_count = execution_count + 1,
last_execution_at = ? last_execution_at = ?
WHERE workflow_id = ? WHERE workflow_id = ?
''', (completed_at, workflow_id)) """,
(completed_at, workflow_id),
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -193,22 +237,27 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute(
"""
SELECT execution_id, started_at, completed_at, status SELECT execution_id, started_at, completed_at, status
FROM workflow_executions FROM workflow_executions
WHERE workflow_id = ? WHERE workflow_id = ?
ORDER BY started_at DESC ORDER BY started_at DESC
LIMIT ? LIMIT ?
''', (workflow_id, limit)) """,
(workflow_id, limit),
)
executions = [] executions = []
for row in cursor.fetchall(): for row in cursor.fetchall():
executions.append({ executions.append(
'execution_id': row[0], {
'started_at': row[1], "execution_id": row[0],
'completed_at': row[2], "started_at": row[1],
'status': row[3] "completed_at": row[2],
}) "status": row[3],
}
)
conn.close() conn.close()
return executions return executions

3
rp.py
View File

@ -2,8 +2,7 @@
# Trigger build # Trigger build
import sys
from pr.__main__ import main from pr.__main__ import main
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -1,8 +1,9 @@
import pytest
import os import os
import tempfile import tempfile
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
@pytest.fixture @pytest.fixture
def temp_dir(): def temp_dir():
@ -13,19 +14,8 @@ def temp_dir():
@pytest.fixture @pytest.fixture
def mock_api_response(): def mock_api_response():
return { return {
'choices': [ "choices": [{"message": {"role": "assistant", "content": "Test response"}}],
{ "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
'message': {
'role': 'assistant',
'content': 'Test response'
}
}
],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 5,
'total_tokens': 15
}
} }
@ -47,7 +37,7 @@ def mock_args():
@pytest.fixture @pytest.fixture
def sample_context_file(temp_dir): def sample_context_file(temp_dir):
context_path = os.path.join(temp_dir, '.rcontext.txt') context_path = os.path.join(temp_dir, ".rcontext.txt")
with open(context_path, 'w') as f: with open(context_path, "w") as f:
f.write('Sample context content\n') f.write("Sample context content\n")
return context_path return context_path

View File

@ -1,39 +1,53 @@
import pytest
from pr.core.advanced_context import AdvancedContextManager from pr.core.advanced_context import AdvancedContextManager
def test_adaptive_context_window_simple(): def test_adaptive_context_window_simple():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}] messages = [
window = mgr.adaptive_context_window(messages, 'simple') {"content": "short"},
{"content": "this is a longer message with more words"},
]
window = mgr.adaptive_context_window(messages, "simple")
assert isinstance(window, int) assert isinstance(window, int)
assert window >= 10 assert window >= 10
def test_adaptive_context_window_medium(): def test_adaptive_context_window_medium():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}] messages = [
window = mgr.adaptive_context_window(messages, 'medium') {"content": "short"},
{"content": "this is a longer message with more words"},
]
window = mgr.adaptive_context_window(messages, "medium")
assert isinstance(window, int) assert isinstance(window, int)
assert window >= 20 assert window >= 20
def test_adaptive_context_window_complex(): def test_adaptive_context_window_complex():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}] messages = [
window = mgr.adaptive_context_window(messages, 'complex') {"content": "short"},
{"content": "this is a longer message with more words"},
]
window = mgr.adaptive_context_window(messages, "complex")
assert isinstance(window, int) assert isinstance(window, int)
assert window >= 35 assert window >= 35
def test_analyze_message_complexity(): def test_analyze_message_complexity():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
messages = [{'content': 'hello world'}, {'content': 'hello again'}] messages = [{"content": "hello world"}, {"content": "hello again"}]
score = mgr._analyze_message_complexity(messages) score = mgr._analyze_message_complexity(messages)
assert 0 <= score <= 1 assert 0 <= score <= 1
def test_analyze_message_complexity_empty(): def test_analyze_message_complexity_empty():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
messages = [] messages = []
score = mgr._analyze_message_complexity(messages) score = mgr._analyze_message_complexity(messages)
assert score == 0 assert score == 0
def test_extract_key_sentences(): def test_extract_key_sentences():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
text = "This is the first sentence. This is the second sentence. This is a longer third sentence with more words." text = "This is the first sentence. This is the second sentence. This is a longer third sentence with more words."
@ -41,41 +55,47 @@ def test_extract_key_sentences():
assert len(sentences) <= 2 assert len(sentences) <= 2
assert all(isinstance(s, str) for s in sentences) assert all(isinstance(s, str) for s in sentences)
def test_extract_key_sentences_empty(): def test_extract_key_sentences_empty():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
text = "" text = ""
sentences = mgr.extract_key_sentences(text, 5) sentences = mgr.extract_key_sentences(text, 5)
assert sentences == [] assert sentences == []
def test_advanced_summarize_messages(): def test_advanced_summarize_messages():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
messages = [{'content': 'Hello'}, {'content': 'How are you?'}] messages = [{"content": "Hello"}, {"content": "How are you?"}]
summary = mgr.advanced_summarize_messages(messages) summary = mgr.advanced_summarize_messages(messages)
assert isinstance(summary, str) assert isinstance(summary, str)
def test_advanced_summarize_messages_empty(): def test_advanced_summarize_messages_empty():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
messages = [] messages = []
summary = mgr.advanced_summarize_messages(messages) summary = mgr.advanced_summarize_messages(messages)
assert summary == "No content to summarize." assert summary == "No content to summarize."
def test_score_message_relevance(): def test_score_message_relevance():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
message = {'content': 'hello world'} message = {"content": "hello world"}
context = 'world hello' context = "world hello"
score = mgr.score_message_relevance(message, context) score = mgr.score_message_relevance(message, context)
assert 0 <= score <= 1 assert 0 <= score <= 1
def test_score_message_relevance_no_overlap(): def test_score_message_relevance_no_overlap():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
message = {'content': 'hello'} message = {"content": "hello"}
context = 'world' context = "world"
score = mgr.score_message_relevance(message, context) score = mgr.score_message_relevance(message, context)
assert score == 0 assert score == 0
def test_score_message_relevance_empty(): def test_score_message_relevance_empty():
mgr = AdvancedContextManager() mgr = AdvancedContextManager()
message = {'content': ''} message = {"content": ""}
context = '' context = ""
score = mgr.score_message_relevance(message, context) score = mgr.score_message_relevance(message, context)
assert score == 0 assert score == 0

View File

@ -1,127 +1,213 @@
import pytest from pr.agents.agent_communication import (
import time AgentCommunicationBus,
AgentMessage,
MessageType,
)
from pr.agents.agent_manager import AgentInstance, AgentManager
from pr.agents.agent_roles import AgentRole, get_agent_role, list_agent_roles from pr.agents.agent_roles import AgentRole, get_agent_role, list_agent_roles
from pr.agents.agent_manager import AgentManager, AgentInstance
from pr.agents.agent_communication import AgentCommunicationBus, AgentMessage, MessageType
def test_get_agent_role(): def test_get_agent_role():
role = get_agent_role('coding') role = get_agent_role("coding")
assert isinstance(role, AgentRole) assert isinstance(role, AgentRole)
assert role.name == 'coding' assert role.name == "coding"
def test_list_agent_roles(): def test_list_agent_roles():
roles = list_agent_roles() roles = list_agent_roles()
assert isinstance(roles, dict) assert isinstance(roles, dict)
assert len(roles) > 0 assert len(roles) > 0
assert 'coding' in roles assert "coding" in roles
def test_agent_role(): def test_agent_role():
role = AgentRole(name='test', description='test', system_prompt='test', allowed_tools=set(), specialization_areas=[]) role = AgentRole(
assert role.name == 'test' name="test",
description="test",
system_prompt="test",
allowed_tools=set(),
specialization_areas=[],
)
assert role.name == "test"
def test_agent_instance(): def test_agent_instance():
role = get_agent_role('coding') role = get_agent_role("coding")
instance = AgentInstance(agent_id='test', role=role) instance = AgentInstance(agent_id="test", role=role)
assert instance.agent_id == 'test' assert instance.agent_id == "test"
assert instance.role == role assert instance.role == role
def test_agent_manager_init(): def test_agent_manager_init():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
assert mgr is not None assert mgr is not None
def test_agent_manager_create_agent(): def test_agent_manager_create_agent():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
agent = mgr.create_agent('coding', 'test_agent') agent = mgr.create_agent("coding", "test_agent")
assert agent is not None assert agent is not None
def test_agent_manager_get_agent(): def test_agent_manager_get_agent():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
mgr.create_agent('coding', 'test_agent') mgr.create_agent("coding", "test_agent")
agent = mgr.get_agent('test_agent') agent = mgr.get_agent("test_agent")
assert isinstance(agent, AgentInstance) assert isinstance(agent, AgentInstance)
def test_agent_manager_remove_agent(): def test_agent_manager_remove_agent():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
mgr.create_agent('coding', 'test_agent') mgr.create_agent("coding", "test_agent")
mgr.remove_agent('test_agent') mgr.remove_agent("test_agent")
agent = mgr.get_agent('test_agent') agent = mgr.get_agent("test_agent")
assert agent is None assert agent is None
def test_agent_manager_send_agent_message(): def test_agent_manager_send_agent_message():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
mgr.create_agent('coding', 'a') mgr.create_agent("coding", "a")
mgr.create_agent('coding', 'b') mgr.create_agent("coding", "b")
mgr.send_agent_message('a', 'b', 'test') mgr.send_agent_message("a", "b", "test")
assert True assert True
def test_agent_manager_get_agent_messages(): def test_agent_manager_get_agent_messages():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
mgr.create_agent('coding', 'test') mgr.create_agent("coding", "test")
messages = mgr.get_agent_messages('test') messages = mgr.get_agent_messages("test")
assert isinstance(messages, list) assert isinstance(messages, list)
def test_agent_manager_get_session_summary(): def test_agent_manager_get_session_summary():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
summary = mgr.get_session_summary() summary = mgr.get_session_summary()
assert isinstance(summary, str) assert isinstance(summary, str)
def test_agent_manager_collaborate_agents(): def test_agent_manager_collaborate_agents():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
result = mgr.collaborate_agents('orchestrator', 'task', ['coding', 'research']) result = mgr.collaborate_agents("orchestrator", "task", ["coding", "research"])
assert result is not None assert result is not None
def test_agent_manager_execute_agent_task(): def test_agent_manager_execute_agent_task():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
mgr.create_agent('coding', 'test') mgr.create_agent("coding", "test")
result = mgr.execute_agent_task('test', 'task') result = mgr.execute_agent_task("test", "task")
assert result is not None assert result is not None
def test_agent_manager_clear_session(): def test_agent_manager_clear_session():
mgr = AgentManager(':memory:', None) mgr = AgentManager(":memory:", None)
mgr.clear_session() mgr.clear_session()
assert True assert True
def test_agent_message(): def test_agent_message():
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') msg = AgentMessage(
assert msg.from_agent == 'a' from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
assert msg.from_agent == "a"
def test_agent_message_to_dict(): def test_agent_message_to_dict():
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
d = msg.to_dict() d = msg.to_dict()
assert isinstance(d, dict) assert isinstance(d, dict)
def test_agent_message_from_dict(): def test_agent_message_from_dict():
d = {'from_agent': 'a', 'to_agent': 'b', 'message_type': 'request', 'content': 'test', 'metadata': {}, 'timestamp': 1.0, 'message_id': 'id'} d = {
"from_agent": "a",
"to_agent": "b",
"message_type": "request",
"content": "test",
"metadata": {},
"timestamp": 1.0,
"message_id": "id",
}
msg = AgentMessage.from_dict(d) msg = AgentMessage.from_dict(d)
assert isinstance(msg, AgentMessage) assert isinstance(msg, AgentMessage)
def test_agent_communication_bus_init(): def test_agent_communication_bus_init():
bus = AgentCommunicationBus(':memory:') bus = AgentCommunicationBus(":memory:")
assert bus is not None assert bus is not None
def test_agent_communication_bus_send_message(): def test_agent_communication_bus_send_message():
bus = AgentCommunicationBus(':memory:') bus = AgentCommunicationBus(":memory:")
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
bus.send_message(msg) bus.send_message(msg)
assert True assert True
def test_agent_communication_bus_receive_messages(): def test_agent_communication_bus_receive_messages():
bus = AgentCommunicationBus(':memory:') bus = AgentCommunicationBus(":memory:")
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
bus.send_message(msg) bus.send_message(msg)
messages = bus.receive_messages('b') messages = bus.receive_messages("b")
assert len(messages) == 1 assert len(messages) == 1
def test_agent_communication_bus_get_conversation_history(): def test_agent_communication_bus_get_conversation_history():
bus = AgentCommunicationBus(':memory:') bus = AgentCommunicationBus(":memory:")
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
bus.send_message(msg) bus.send_message(msg)
history = bus.get_conversation_history('a', 'b') history = bus.get_conversation_history("a", "b")
assert len(history) == 1 assert len(history) == 1
def test_agent_communication_bus_mark_as_read(): def test_agent_communication_bus_mark_as_read():
bus = AgentCommunicationBus(':memory:') bus = AgentCommunicationBus(":memory:")
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id') msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
bus.send_message(msg) bus.send_message(msg)
bus.mark_as_read(msg.message_id) bus.mark_as_read(msg.message_id)
assert True assert True

View File

@ -1,63 +1,65 @@
import unittest import unittest
from unittest.mock import patch, MagicMock
import json
import urllib.error import urllib.error
from unittest.mock import MagicMock, patch
from pr.core.api import call_api, list_models from pr.core.api import call_api, list_models
class TestApi(unittest.TestCase): class TestApi(unittest.TestCase):
@patch('pr.core.api.urllib.request.urlopen') @patch("pr.core.api.urllib.request.urlopen")
@patch('pr.core.api.auto_slim_messages') @patch("pr.core.api.auto_slim_messages")
def test_call_api_success(self, mock_slim, mock_urlopen): def test_call_api_success(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{'role': 'user', 'content': 'test'}] mock_slim.return_value = [{"role": "user", "content": "test"}]
mock_response = MagicMock() mock_response = MagicMock()
mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}' mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}'
mock_urlopen.return_value.__enter__.return_value = mock_response mock_urlopen.return_value.__enter__.return_value = mock_response
result = call_api([], 'model', 'http://url', 'key', True, [{'name': 'tool'}]) result = call_api([], "model", "http://url", "key", True, [{"name": "tool"}])
self.assertIn('choices', result) self.assertIn("choices", result)
mock_urlopen.assert_called_once() mock_urlopen.assert_called_once()
@patch('urllib.request.urlopen') @patch("urllib.request.urlopen")
@patch('pr.core.api.auto_slim_messages') @patch("pr.core.api.auto_slim_messages")
def test_call_api_http_error(self, mock_slim, mock_urlopen): def test_call_api_http_error(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{'role': 'user', 'content': 'test'}] mock_slim.return_value = [{"role": "user", "content": "test"}]
mock_urlopen.side_effect = urllib.error.HTTPError('http://url', 500, 'error', None, MagicMock()) mock_urlopen.side_effect = urllib.error.HTTPError(
"http://url", 500, "error", None, MagicMock()
)
result = call_api([], 'model', 'http://url', 'key', False, []) result = call_api([], "model", "http://url", "key", False, [])
self.assertIn('error', result) self.assertIn("error", result)
@patch('urllib.request.urlopen') @patch("urllib.request.urlopen")
@patch('pr.core.api.auto_slim_messages') @patch("pr.core.api.auto_slim_messages")
def test_call_api_general_error(self, mock_slim, mock_urlopen): def test_call_api_general_error(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{'role': 'user', 'content': 'test'}] mock_slim.return_value = [{"role": "user", "content": "test"}]
mock_urlopen.side_effect = Exception('test error') mock_urlopen.side_effect = Exception("test error")
result = call_api([], 'model', 'http://url', 'key', False, []) result = call_api([], "model", "http://url", "key", False, [])
self.assertIn('error', result) self.assertIn("error", result)
@patch('urllib.request.urlopen') @patch("urllib.request.urlopen")
def test_list_models_success(self, mock_urlopen): def test_list_models_success(self, mock_urlopen):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.read.return_value = b'{"data": [{"id": "model1"}]}' mock_response.read.return_value = b'{"data": [{"id": "model1"}]}'
mock_urlopen.return_value.__enter__.return_value = mock_response mock_urlopen.return_value.__enter__.return_value = mock_response
result = list_models('http://url', 'key') result = list_models("http://url", "key")
self.assertEqual(result, [{'id': 'model1'}]) self.assertEqual(result, [{"id": "model1"}])
@patch('urllib.request.urlopen') @patch("urllib.request.urlopen")
def test_list_models_error(self, mock_urlopen): def test_list_models_error(self, mock_urlopen):
mock_urlopen.side_effect = Exception('error') mock_urlopen.side_effect = Exception("error")
result = list_models('http://url', 'key') result = list_models("http://url", "key")
self.assertIn('error', result) self.assertIn("error", result)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,7 +1,6 @@
import unittest import unittest
from unittest.mock import patch, MagicMock from unittest.mock import MagicMock, patch
import tempfile
import os
from pr.core.assistant import Assistant, process_message from pr.core.assistant import Assistant, process_message
@ -12,83 +11,106 @@ class TestAssistant(unittest.TestCase):
self.args.verbose = False self.args.verbose = False
self.args.debug = False self.args.debug = False
self.args.no_syntax = False self.args.no_syntax = False
self.args.model = 'test-model' self.args.model = "test-model"
self.args.api_url = 'test-url' self.args.api_url = "test-url"
self.args.model_list_url = 'test-list-url' self.args.model_list_url = "test-list-url"
@patch('sqlite3.connect') @patch("sqlite3.connect")
@patch('os.environ.get') @patch("os.environ.get")
@patch('pr.core.context.init_system_message') @patch("pr.core.context.init_system_message")
@patch('pr.core.enhanced_assistant.EnhancedAssistant') @patch("pr.core.enhanced_assistant.EnhancedAssistant")
def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite): def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite):
mock_env.side_effect = lambda key, default: {'OPENROUTER_API_KEY': 'key', 'AI_MODEL': 'model', 'API_URL': 'url', 'MODEL_LIST_URL': 'list', 'USE_TOOLS': '1', 'STRICT_MODE': '0'}.get(key, default) mock_env.side_effect = lambda key, default: {
"OPENROUTER_API_KEY": "key",
"AI_MODEL": "model",
"API_URL": "url",
"MODEL_LIST_URL": "list",
"USE_TOOLS": "1",
"STRICT_MODE": "0",
}.get(key, default)
mock_conn = MagicMock() mock_conn = MagicMock()
mock_sqlite.return_value = mock_conn mock_sqlite.return_value = mock_conn
mock_init_sys.return_value = {'role': 'system', 'content': 'sys'} mock_init_sys.return_value = {"role": "system", "content": "sys"}
assistant = Assistant(self.args) assistant = Assistant(self.args)
self.assertEqual(assistant.api_key, 'key') self.assertEqual(assistant.api_key, "key")
self.assertEqual(assistant.model, 'test-model') self.assertEqual(assistant.model, "test-model")
mock_sqlite.assert_called_once() mock_sqlite.assert_called_once()
@patch('pr.core.assistant.call_api') @patch("pr.core.assistant.call_api")
@patch('pr.core.assistant.render_markdown') @patch("pr.core.assistant.render_markdown")
def test_process_response_no_tools(self, mock_render, mock_call): def test_process_response_no_tools(self, mock_render, mock_call):
assistant = MagicMock() assistant = MagicMock()
assistant.messages = MagicMock() assistant.messages = MagicMock()
assistant.verbose = False assistant.verbose = False
assistant.syntax_highlighting = True assistant.syntax_highlighting = True
mock_render.return_value = 'rendered' mock_render.return_value = "rendered"
response = {'choices': [{'message': {'content': 'content'}}]} response = {"choices": [{"message": {"content": "content"}}]}
result = Assistant.process_response(assistant, response) result = Assistant.process_response(assistant, response)
self.assertEqual(result, 'rendered') self.assertEqual(result, "rendered")
assistant.messages.append.assert_called_with({'content': 'content'}) assistant.messages.append.assert_called_with({"content": "content"})
@patch('pr.core.assistant.call_api') @patch("pr.core.assistant.call_api")
@patch('pr.core.assistant.render_markdown') @patch("pr.core.assistant.render_markdown")
@patch('pr.core.assistant.get_tools_definition') @patch("pr.core.assistant.get_tools_definition")
def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call): def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call):
assistant = MagicMock() assistant = MagicMock()
assistant.messages = MagicMock() assistant.messages = MagicMock()
assistant.verbose = False assistant.verbose = False
assistant.syntax_highlighting = True assistant.syntax_highlighting = True
assistant.use_tools = True assistant.use_tools = True
assistant.model = 'model' assistant.model = "model"
assistant.api_url = 'url' assistant.api_url = "url"
assistant.api_key = 'key' assistant.api_key = "key"
mock_tools_def.return_value = [] mock_tools_def.return_value = []
mock_call.return_value = {'choices': [{'message': {'content': 'follow'}}]} mock_call.return_value = {"choices": [{"message": {"content": "follow"}}]}
response = {'choices': [{'message': {'tool_calls': [{'id': '1', 'function': {'name': 'test', 'arguments': '{}'}}]}}]} response = {
"choices": [
{
"message": {
"tool_calls": [
{"id": "1", "function": {"name": "test", "arguments": "{}"}}
]
}
}
]
}
with patch.object(assistant, 'execute_tool_calls', return_value=[{'role': 'tool', 'content': 'result'}]): with patch.object(
result = Assistant.process_response(assistant, response) assistant,
"execute_tool_calls",
return_value=[{"role": "tool", "content": "result"}],
):
Assistant.process_response(assistant, response)
mock_call.assert_called() mock_call.assert_called()
@patch('pr.core.assistant.call_api') @patch("pr.core.assistant.call_api")
@patch('pr.core.assistant.get_tools_definition') @patch("pr.core.assistant.get_tools_definition")
def test_process_message(self, mock_tools, mock_call): def test_process_message(self, mock_tools, mock_call):
assistant = MagicMock() assistant = MagicMock()
assistant.messages = MagicMock() assistant.messages = MagicMock()
assistant.verbose = False assistant.verbose = False
assistant.use_tools = True assistant.use_tools = True
assistant.model = 'model' assistant.model = "model"
assistant.api_url = 'url' assistant.api_url = "url"
assistant.api_key = 'key' assistant.api_key = "key"
mock_tools.return_value = [] mock_tools.return_value = []
mock_call.return_value = {'choices': [{'message': {'content': 'response'}}]} mock_call.return_value = {"choices": [{"message": {"content": "response"}}]}
with patch('pr.core.assistant.render_markdown', return_value='rendered'): with patch("pr.core.assistant.render_markdown", return_value="rendered"):
with patch('builtins.print'): with patch("builtins.print"):
process_message(assistant, 'test message') process_message(assistant, "test message")
assistant.messages.append.assert_called_with({'role': 'user', 'content': 'test message'}) assistant.messages.append.assert_called_with(
{"role": "user", "content": "test message"}
)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,31 +1,30 @@
import pytest
from pr import config from pr import config
class TestConfig: class TestConfig:
def test_default_model_exists(self): def test_default_model_exists(self):
assert hasattr(config, 'DEFAULT_MODEL') assert hasattr(config, "DEFAULT_MODEL")
assert isinstance(config.DEFAULT_MODEL, str) assert isinstance(config.DEFAULT_MODEL, str)
assert len(config.DEFAULT_MODEL) > 0 assert len(config.DEFAULT_MODEL) > 0
def test_api_url_exists(self): def test_api_url_exists(self):
assert hasattr(config, 'DEFAULT_API_URL') assert hasattr(config, "DEFAULT_API_URL")
assert config.DEFAULT_API_URL.startswith('http') assert config.DEFAULT_API_URL.startswith("http")
def test_file_paths_exist(self): def test_file_paths_exist(self):
assert hasattr(config, 'DB_PATH') assert hasattr(config, "DB_PATH")
assert hasattr(config, 'LOG_FILE') assert hasattr(config, "LOG_FILE")
assert hasattr(config, 'HISTORY_FILE') assert hasattr(config, "HISTORY_FILE")
def test_autonomous_config(self): def test_autonomous_config(self):
assert hasattr(config, 'MAX_AUTONOMOUS_ITERATIONS') assert hasattr(config, "MAX_AUTONOMOUS_ITERATIONS")
assert config.MAX_AUTONOMOUS_ITERATIONS > 0 assert config.MAX_AUTONOMOUS_ITERATIONS > 0
assert hasattr(config, 'CONTEXT_COMPRESSION_THRESHOLD') assert hasattr(config, "CONTEXT_COMPRESSION_THRESHOLD")
assert config.CONTEXT_COMPRESSION_THRESHOLD > 0 assert config.CONTEXT_COMPRESSION_THRESHOLD > 0
def test_language_keywords(self): def test_language_keywords(self):
assert hasattr(config, 'LANGUAGE_KEYWORDS') assert hasattr(config, "LANGUAGE_KEYWORDS")
assert 'python' in config.LANGUAGE_KEYWORDS assert "python" in config.LANGUAGE_KEYWORDS
assert isinstance(config.LANGUAGE_KEYWORDS['python'], list) assert isinstance(config.LANGUAGE_KEYWORDS["python"], list)

View File

@ -1,56 +1,71 @@
import pytest from unittest.mock import mock_open, patch
from unittest.mock import patch, mock_open
import os from pr.core.config_loader import (
from pr.core.config_loader import load_config, _load_config_file, _parse_value, create_default_config _load_config_file,
_parse_value,
create_default_config,
load_config,
)
def test_parse_value_string(): def test_parse_value_string():
assert _parse_value('hello') == 'hello' assert _parse_value("hello") == "hello"
def test_parse_value_int(): def test_parse_value_int():
assert _parse_value('123') == 123 assert _parse_value("123") == 123
def test_parse_value_float(): def test_parse_value_float():
assert _parse_value('1.23') == 1.23 assert _parse_value("1.23") == 1.23
def test_parse_value_bool_true(): def test_parse_value_bool_true():
assert _parse_value('true') == True assert _parse_value("true") == True
def test_parse_value_bool_false(): def test_parse_value_bool_false():
assert _parse_value('false') == False assert _parse_value("false") == False
def test_parse_value_bool_upper(): def test_parse_value_bool_upper():
assert _parse_value('TRUE') == True assert _parse_value("TRUE") == True
@patch('os.path.exists', return_value=False)
@patch("os.path.exists", return_value=False)
def test_load_config_file_not_exists(mock_exists): def test_load_config_file_not_exists(mock_exists):
config = _load_config_file('test.ini') config = _load_config_file("test.ini")
assert config == {} assert config == {}
@patch('os.path.exists', return_value=True)
@patch('configparser.ConfigParser') @patch("os.path.exists", return_value=True)
@patch("configparser.ConfigParser")
def test_load_config_file_exists(mock_parser_class, mock_exists): def test_load_config_file_exists(mock_parser_class, mock_exists):
mock_parser = mock_parser_class.return_value mock_parser = mock_parser_class.return_value
mock_parser.sections.return_value = ['api'] mock_parser.sections.return_value = ["api"]
mock_parser.items.return_value = [('key', 'value')] mock_parser.items.return_value = [("key", "value")]
config = _load_config_file('test.ini') config = _load_config_file("test.ini")
assert 'api' in config assert "api" in config
assert config['api']['key'] == 'value' assert config["api"]["key"] == "value"
@patch('pr.core.config_loader._load_config_file')
@patch("pr.core.config_loader._load_config_file")
def test_load_config(mock_load): def test_load_config(mock_load):
mock_load.side_effect = [{'api': {'key': 'global'}}, {'api': {'key': 'local'}}] mock_load.side_effect = [{"api": {"key": "global"}}, {"api": {"key": "local"}}]
config = load_config() config = load_config()
assert config['api']['key'] == 'local' assert config["api"]["key"] == "local"
@patch('builtins.open', new_callable=mock_open)
@patch("builtins.open", new_callable=mock_open)
def test_create_default_config(mock_file): def test_create_default_config(mock_file):
result = create_default_config('test.ini') result = create_default_config("test.ini")
assert result == True assert result == True
mock_file.assert_called_once_with('test.ini', 'w') mock_file.assert_called_once_with("test.ini", "w")
handle = mock_file() handle = mock_file()
handle.write.assert_called_once() handle.write.assert_called_once()
@patch('builtins.open', side_effect=Exception('error'))
@patch("builtins.open", side_effect=Exception("error"))
def test_create_default_config_error(mock_file): def test_create_default_config_error(mock_file):
result = create_default_config('test.ini') result = create_default_config("test.ini")
assert result == False assert result == False

View File

@ -1,30 +1,29 @@
import pytest
from pr.core.context import should_compress_context, compress_context
from pr.config import RECENT_MESSAGES_TO_KEEP from pr.config import RECENT_MESSAGES_TO_KEEP
from pr.core.context import compress_context, should_compress_context
class TestContextManagement: class TestContextManagement:
def test_should_compress_context_below_threshold(self): def test_should_compress_context_below_threshold(self):
messages = [{'role': 'user', 'content': 'test'}] * 10 messages = [{"role": "user", "content": "test"}] * 10
assert should_compress_context(messages) is False assert should_compress_context(messages) is False
def test_should_compress_context_above_threshold(self): def test_should_compress_context_above_threshold(self):
messages = [{'role': 'user', 'content': 'test'}] * 35 messages = [{"role": "user", "content": "test"}] * 35
assert should_compress_context(messages) is True assert should_compress_context(messages) is True
def test_compress_context_preserves_system_message(self): def test_compress_context_preserves_system_message(self):
messages = [ messages = [
{'role': 'system', 'content': 'System prompt'}, {"role": "system", "content": "System prompt"},
{'role': 'user', 'content': 'Hello'}, {"role": "user", "content": "Hello"},
{'role': 'assistant', 'content': 'Hi'}, {"role": "assistant", "content": "Hi"},
] * 40 # Ensure compression ] * 40 # Ensure compression
compressed = compress_context(messages) compressed = compress_context(messages)
assert compressed[0]['role'] == 'system' assert compressed[0]["role"] == "system"
assert 'System prompt' in compressed[0]['content'] assert "System prompt" in compressed[0]["content"]
def test_compress_context_keeps_recent_messages(self): def test_compress_context_keeps_recent_messages(self):
messages = [{'role': 'user', 'content': f'msg{i}'} for i in range(40)] messages = [{"role": "user", "content": f"msg{i}"} for i in range(40)]
compressed = compress_context(messages) compressed = compress_context(messages)
# Should keep recent messages # Should keep recent messages
recent = compressed[-RECENT_MESSAGES_TO_KEEP:] recent = compressed[-RECENT_MESSAGES_TO_KEEP:]
@ -32,4 +31,4 @@ class TestContextManagement:
# Check that the messages are the most recent ones # Check that the messages are the most recent ones
for i, msg in enumerate(recent): for i, msg in enumerate(recent):
expected_index = 40 - RECENT_MESSAGES_TO_KEEP + i expected_index = 40 - RECENT_MESSAGES_TO_KEEP + i
assert msg['content'] == f'msg{expected_index}' assert msg["content"] == f"msg{expected_index}"

View File

@ -1,34 +1,37 @@
import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
from pr.core.enhanced_assistant import EnhancedAssistant from pr.core.enhanced_assistant import EnhancedAssistant
def test_enhanced_assistant_init(): def test_enhanced_assistant_init():
mock_base = MagicMock() mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base) assistant = EnhancedAssistant(mock_base)
assert assistant.base == mock_base assert assistant.base == mock_base
assert assistant.current_conversation_id is not None assert assistant.current_conversation_id is not None
def test_enhanced_call_api_with_cache(): def test_enhanced_call_api_with_cache():
mock_base = MagicMock() mock_base = MagicMock()
mock_base.model = 'test-model' mock_base.model = "test-model"
mock_base.api_url = 'http://test' mock_base.api_url = "http://test"
mock_base.api_key = 'key' mock_base.api_key = "key"
mock_base.use_tools = False mock_base.use_tools = False
mock_base.verbose = False mock_base.verbose = False
assistant = EnhancedAssistant(mock_base) assistant = EnhancedAssistant(mock_base)
assistant.api_cache = MagicMock() assistant.api_cache = MagicMock()
assistant.api_cache.get.return_value = {'cached': True} assistant.api_cache.get.return_value = {"cached": True}
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}]) result = assistant.enhanced_call_api([{"role": "user", "content": "test"}])
assert result == {'cached': True} assert result == {"cached": True}
assistant.api_cache.get.assert_called_once() assistant.api_cache.get.assert_called_once()
def test_enhanced_call_api_without_cache(): def test_enhanced_call_api_without_cache():
mock_base = MagicMock() mock_base = MagicMock()
mock_base.model = 'test-model' mock_base.model = "test-model"
mock_base.api_url = 'http://test' mock_base.api_url = "http://test"
mock_base.api_key = 'key' mock_base.api_key = "key"
mock_base.use_tools = False mock_base.use_tools = False
mock_base.verbose = False mock_base.verbose = False
@ -36,8 +39,9 @@ def test_enhanced_call_api_without_cache():
assistant.api_cache = None assistant.api_cache = None
# It will try to call API and fail with network error, but that's expected # It will try to call API and fail with network error, but that's expected
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}]) result = assistant.enhanced_call_api([{"role": "user", "content": "test"}])
assert 'error' in result assert "error" in result
def test_execute_workflow_not_found(): def test_execute_workflow_not_found():
mock_base = MagicMock() mock_base = MagicMock()
@ -45,38 +49,42 @@ def test_execute_workflow_not_found():
assistant.workflow_storage = MagicMock() assistant.workflow_storage = MagicMock()
assistant.workflow_storage.load_workflow_by_name.return_value = None assistant.workflow_storage.load_workflow_by_name.return_value = None
result = assistant.execute_workflow('nonexistent') result = assistant.execute_workflow("nonexistent")
assert 'error' in result assert "error" in result
def test_create_agent(): def test_create_agent():
mock_base = MagicMock() mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base) assistant = EnhancedAssistant(mock_base)
assistant.agent_manager = MagicMock() assistant.agent_manager = MagicMock()
assistant.agent_manager.create_agent.return_value = 'agent_id' assistant.agent_manager.create_agent.return_value = "agent_id"
result = assistant.create_agent("role")
assert result == "agent_id"
result = assistant.create_agent('role')
assert result == 'agent_id'
def test_search_knowledge(): def test_search_knowledge():
mock_base = MagicMock() mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base) assistant = EnhancedAssistant(mock_base)
assistant.knowledge_store = MagicMock() assistant.knowledge_store = MagicMock()
assistant.knowledge_store.search_entries.return_value = [{'result': True}] assistant.knowledge_store.search_entries.return_value = [{"result": True}]
result = assistant.search_knowledge("query")
assert result == [{"result": True}]
result = assistant.search_knowledge('query')
assert result == [{'result': True}]
def test_get_cache_statistics(): def test_get_cache_statistics():
mock_base = MagicMock() mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base) assistant = EnhancedAssistant(mock_base)
assistant.api_cache = MagicMock() assistant.api_cache = MagicMock()
assistant.api_cache.get_statistics.return_value = {'hits': 10} assistant.api_cache.get_statistics.return_value = {"hits": 10}
assistant.tool_cache = MagicMock() assistant.tool_cache = MagicMock()
assistant.tool_cache.get_statistics.return_value = {'misses': 5} assistant.tool_cache.get_statistics.return_value = {"misses": 5}
stats = assistant.get_cache_statistics() stats = assistant.get_cache_statistics()
assert 'api_cache' in stats assert "api_cache" in stats
assert 'tool_cache' in stats assert "tool_cache" in stats
def test_clear_caches(): def test_clear_caches():
mock_base = MagicMock() mock_base = MagicMock()

View File

@ -1,24 +1,25 @@
import pytest from pr.core.logging import get_logger, setup_logging
import tempfile
import os
from pr.core.logging import setup_logging, get_logger
def test_setup_logging_basic(): def test_setup_logging_basic():
logger = setup_logging(verbose=False) logger = setup_logging(verbose=False)
assert logger.name == 'pr' assert logger.name == "pr"
assert logger.level == 20 # INFO assert logger.level == 20 # INFO
def test_setup_logging_verbose(): def test_setup_logging_verbose():
logger = setup_logging(verbose=True) logger = setup_logging(verbose=True)
assert logger.name == 'pr' assert logger.name == "pr"
assert logger.level == 10 # DEBUG assert logger.level == 10 # DEBUG
# Should have console handler # Should have console handler
assert len(logger.handlers) >= 2 assert len(logger.handlers) >= 2
def test_get_logger_default(): def test_get_logger_default():
logger = get_logger() logger = get_logger()
assert logger.name == 'pr' assert logger.name == "pr"
def test_get_logger_named(): def test_get_logger_named():
logger = get_logger('test') logger = get_logger("test")
assert logger.name == 'pr.test' assert logger.name == "pr.test"

View File

@ -1,118 +1,134 @@
import pytest
from unittest.mock import patch from unittest.mock import patch
import sys
import pytest
from pr.__main__ import main from pr.__main__ import main
def test_main_version(capsys): def test_main_version(capsys):
with patch('sys.argv', ['pr', '--version']): with patch("sys.argv", ["pr", "--version"]):
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'PR Assistant' in captured.out assert "PR Assistant" in captured.out
def test_main_create_config_success(capsys): def test_main_create_config_success(capsys):
with patch('pr.core.config_loader.create_default_config', return_value=True): with patch("pr.core.config_loader.create_default_config", return_value=True):
with patch('sys.argv', ['pr', '--create-config']): with patch("sys.argv", ["pr", "--create-config"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'Configuration file created' in captured.out assert "Configuration file created" in captured.out
def test_main_create_config_fail(capsys): def test_main_create_config_fail(capsys):
with patch('pr.core.config_loader.create_default_config', return_value=False): with patch("pr.core.config_loader.create_default_config", return_value=False):
with patch('sys.argv', ['pr', '--create-config']): with patch("sys.argv", ["pr", "--create-config"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'Error creating configuration file' in captured.err assert "Error creating configuration file" in captured.err
def test_main_list_sessions_no_sessions(capsys): def test_main_list_sessions_no_sessions(capsys):
with patch('pr.core.session.SessionManager') as mock_sm: with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value mock_instance = mock_sm.return_value
mock_instance.list_sessions.return_value = [] mock_instance.list_sessions.return_value = []
with patch('sys.argv', ['pr', '--list-sessions']): with patch("sys.argv", ["pr", "--list-sessions"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'No saved sessions found' in captured.out assert "No saved sessions found" in captured.out
def test_main_list_sessions_with_sessions(capsys): def test_main_list_sessions_with_sessions(capsys):
sessions = [{'name': 'test', 'created_at': '2023-01-01', 'message_count': 5}] sessions = [{"name": "test", "created_at": "2023-01-01", "message_count": 5}]
with patch('pr.core.session.SessionManager') as mock_sm: with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value mock_instance = mock_sm.return_value
mock_instance.list_sessions.return_value = sessions mock_instance.list_sessions.return_value = sessions
with patch('sys.argv', ['pr', '--list-sessions']): with patch("sys.argv", ["pr", "--list-sessions"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'Found 1 saved sessions' in captured.out assert "Found 1 saved sessions" in captured.out
assert 'test' in captured.out assert "test" in captured.out
def test_main_delete_session_success(capsys): def test_main_delete_session_success(capsys):
with patch('pr.core.session.SessionManager') as mock_sm: with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value mock_instance = mock_sm.return_value
mock_instance.delete_session.return_value = True mock_instance.delete_session.return_value = True
with patch('sys.argv', ['pr', '--delete-session', 'test']): with patch("sys.argv", ["pr", "--delete-session", "test"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert "Session 'test' deleted" in captured.out assert "Session 'test' deleted" in captured.out
def test_main_delete_session_fail(capsys): def test_main_delete_session_fail(capsys):
with patch('pr.core.session.SessionManager') as mock_sm: with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value mock_instance = mock_sm.return_value
mock_instance.delete_session.return_value = False mock_instance.delete_session.return_value = False
with patch('sys.argv', ['pr', '--delete-session', 'test']): with patch("sys.argv", ["pr", "--delete-session", "test"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert "Error deleting session 'test'" in captured.err assert "Error deleting session 'test'" in captured.err
def test_main_export_session_json(capsys): def test_main_export_session_json(capsys):
with patch('pr.core.session.SessionManager') as mock_sm: with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value mock_instance = mock_sm.return_value
mock_instance.export_session.return_value = True mock_instance.export_session.return_value = True
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.json']): with patch("sys.argv", ["pr", "--export-session", "test", "output.json"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'Session exported to output.json' in captured.out assert "Session exported to output.json" in captured.out
def test_main_export_session_md(capsys): def test_main_export_session_md(capsys):
with patch('pr.core.session.SessionManager') as mock_sm: with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value mock_instance = mock_sm.return_value
mock_instance.export_session.return_value = True mock_instance.export_session.return_value = True
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.md']): with patch("sys.argv", ["pr", "--export-session", "test", "output.md"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'Session exported to output.md' in captured.out assert "Session exported to output.md" in captured.out
def test_main_usage(capsys): def test_main_usage(capsys):
usage = {'total_requests': 10, 'total_tokens': 1000, 'total_cost': 0.01} usage = {"total_requests": 10, "total_tokens": 1000, "total_cost": 0.01}
with patch('pr.core.usage_tracker.UsageTracker.get_total_usage', return_value=usage): with patch(
with patch('sys.argv', ['pr', '--usage']): "pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage
):
with patch("sys.argv", ["pr", "--usage"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'Total Usage Statistics' in captured.out assert "Total Usage Statistics" in captured.out
assert 'Requests: 10' in captured.out assert "Requests: 10" in captured.out
def test_main_plugins_no_plugins(capsys): def test_main_plugins_no_plugins(capsys):
with patch('pr.plugins.loader.PluginLoader') as mock_loader: with patch("pr.plugins.loader.PluginLoader") as mock_loader:
mock_instance = mock_loader.return_value mock_instance = mock_loader.return_value
mock_instance.load_plugins.return_value = None mock_instance.load_plugins.return_value = None
mock_instance.list_loaded_plugins.return_value = [] mock_instance.list_loaded_plugins.return_value = []
with patch('sys.argv', ['pr', '--plugins']): with patch("sys.argv", ["pr", "--plugins"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'No plugins loaded' in captured.out assert "No plugins loaded" in captured.out
def test_main_plugins_with_plugins(capsys): def test_main_plugins_with_plugins(capsys):
with patch('pr.plugins.loader.PluginLoader') as mock_loader: with patch("pr.plugins.loader.PluginLoader") as mock_loader:
mock_instance = mock_loader.return_value mock_instance = mock_loader.return_value
mock_instance.load_plugins.return_value = None mock_instance.load_plugins.return_value = None
mock_instance.list_loaded_plugins.return_value = ['plugin1', 'plugin2'] mock_instance.list_loaded_plugins.return_value = ["plugin1", "plugin2"]
with patch('sys.argv', ['pr', '--plugins']): with patch("sys.argv", ["pr", "--plugins"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()
assert 'Loaded 2 plugins' in captured.out assert "Loaded 2 plugins" in captured.out
def test_main_run_assistant(): def test_main_run_assistant():
with patch('pr.__main__.Assistant') as mock_assistant: with patch("pr.__main__.Assistant") as mock_assistant:
mock_instance = mock_assistant.return_value mock_instance = mock_assistant.return_value
with patch('sys.argv', ['pr', 'test message']): with patch("sys.argv", ["pr", "test message"]):
main() main()
mock_assistant.assert_called_once() mock_assistant.assert_called_once()
mock_instance.run.assert_called_once() mock_instance.run.assert_called_once()

View File

@ -1,26 +1,32 @@
import pytest
import tempfile
import os
import json import json
import os
import pytest
from pr.core.session import SessionManager from pr.core.session import SessionManager
@pytest.fixture @pytest.fixture
def temp_sessions_dir(tmp_path, monkeypatch): def temp_sessions_dir(tmp_path, monkeypatch):
from pr.core import session from pr.core import session
original_dir = session.SESSIONS_DIR original_dir = session.SESSIONS_DIR
monkeypatch.setattr(session, 'SESSIONS_DIR', str(tmp_path)) monkeypatch.setattr(session, "SESSIONS_DIR", str(tmp_path))
# Clean any existing files # Clean any existing files
import shutil import shutil
if os.path.exists(str(tmp_path)): if os.path.exists(str(tmp_path)):
shutil.rmtree(str(tmp_path)) shutil.rmtree(str(tmp_path))
os.makedirs(str(tmp_path), exist_ok=True) os.makedirs(str(tmp_path), exist_ok=True)
yield tmp_path yield tmp_path
monkeypatch.setattr(session, 'SESSIONS_DIR', original_dir) monkeypatch.setattr(session, "SESSIONS_DIR", original_dir)
def test_session_manager_init(temp_sessions_dir): def test_session_manager_init(temp_sessions_dir):
manager = SessionManager() SessionManager()
assert os.path.exists(temp_sessions_dir) assert os.path.exists(temp_sessions_dir)
def test_save_and_load_session(temp_sessions_dir): def test_save_and_load_session(temp_sessions_dir):
manager = SessionManager() manager = SessionManager()
name = "test_session" name = "test_session"
@ -31,15 +37,17 @@ def test_save_and_load_session(temp_sessions_dir):
loaded = manager.load_session(name) loaded = manager.load_session(name)
assert loaded is not None assert loaded is not None
assert loaded['name'] == name assert loaded["name"] == name
assert loaded['messages'] == messages assert loaded["messages"] == messages
assert loaded['metadata'] == metadata assert loaded["metadata"] == metadata
def test_load_nonexistent_session(temp_sessions_dir): def test_load_nonexistent_session(temp_sessions_dir):
manager = SessionManager() manager = SessionManager()
loaded = manager.load_session("nonexistent") loaded = manager.load_session("nonexistent")
assert loaded is None assert loaded is None
def test_list_sessions(temp_sessions_dir): def test_list_sessions(temp_sessions_dir):
manager = SessionManager() manager = SessionManager()
# Save a session # Save a session
@ -48,7 +56,8 @@ def test_list_sessions(temp_sessions_dir):
sessions = manager.list_sessions() sessions = manager.list_sessions()
assert len(sessions) == 2 assert len(sessions) == 2
assert sessions[0]['name'] == "session2" # sorted by created_at desc assert sessions[0]["name"] == "session2" # sorted by created_at desc
def test_delete_session(temp_sessions_dir): def test_delete_session(temp_sessions_dir):
manager = SessionManager() manager = SessionManager()
@ -58,10 +67,12 @@ def test_delete_session(temp_sessions_dir):
assert manager.delete_session(name) assert manager.delete_session(name)
assert manager.load_session(name) is None assert manager.load_session(name) is None
def test_delete_nonexistent_session(temp_sessions_dir): def test_delete_nonexistent_session(temp_sessions_dir):
manager = SessionManager() manager = SessionManager()
assert not manager.delete_session("nonexistent") assert not manager.delete_session("nonexistent")
def test_export_session_json(temp_sessions_dir, tmp_path): def test_export_session_json(temp_sessions_dir, tmp_path):
manager = SessionManager() manager = SessionManager()
name = "export_test" name = "export_test"
@ -69,12 +80,13 @@ def test_export_session_json(temp_sessions_dir, tmp_path):
manager.save_session(name, messages) manager.save_session(name, messages)
output_path = tmp_path / "exported.json" output_path = tmp_path / "exported.json"
assert manager.export_session(name, str(output_path), 'json') assert manager.export_session(name, str(output_path), "json")
assert output_path.exists() assert output_path.exists()
with open(output_path) as f: with open(output_path) as f:
data = json.load(f) data = json.load(f)
assert data['name'] == name assert data["name"] == name
def test_export_session_markdown(temp_sessions_dir, tmp_path): def test_export_session_markdown(temp_sessions_dir, tmp_path):
manager = SessionManager() manager = SessionManager()
@ -83,12 +95,13 @@ def test_export_session_markdown(temp_sessions_dir, tmp_path):
manager.save_session(name, messages) manager.save_session(name, messages)
output_path = tmp_path / "exported.md" output_path = tmp_path / "exported.md"
assert manager.export_session(name, str(output_path), 'markdown') assert manager.export_session(name, str(output_path), "markdown")
assert output_path.exists() assert output_path.exists()
content = output_path.read_text() content = output_path.read_text()
assert "# Session: export_md" in content assert "# Session: export_md" in content
def test_export_session_txt(temp_sessions_dir, tmp_path): def test_export_session_txt(temp_sessions_dir, tmp_path):
manager = SessionManager() manager = SessionManager()
name = "export_txt" name = "export_txt"
@ -96,16 +109,18 @@ def test_export_session_txt(temp_sessions_dir, tmp_path):
manager.save_session(name, messages) manager.save_session(name, messages)
output_path = tmp_path / "exported.txt" output_path = tmp_path / "exported.txt"
assert manager.export_session(name, str(output_path), 'txt') assert manager.export_session(name, str(output_path), "txt")
assert output_path.exists() assert output_path.exists()
content = output_path.read_text() content = output_path.read_text()
assert "Session: export_txt" in content assert "Session: export_txt" in content
def test_export_nonexistent_session(temp_sessions_dir, tmp_path): def test_export_nonexistent_session(temp_sessions_dir, tmp_path):
manager = SessionManager() manager = SessionManager()
output_path = tmp_path / "nonexistent.json" output_path = tmp_path / "nonexistent.json"
assert not manager.export_session("nonexistent", str(output_path), 'json') assert not manager.export_session("nonexistent", str(output_path), "json")
def test_export_unsupported_format(temp_sessions_dir, tmp_path): def test_export_unsupported_format(temp_sessions_dir, tmp_path):
manager = SessionManager() manager = SessionManager()
@ -113,4 +128,4 @@ def test_export_unsupported_format(temp_sessions_dir, tmp_path):
manager.save_session(name, [{"role": "user", "content": "Test"}]) manager.save_session(name, [{"role": "user", "content": "Test"}])
output_path = tmp_path / "test.unsupported" output_path = tmp_path / "test.unsupported"
assert not manager.export_session(name, str(output_path), 'unsupported') assert not manager.export_session(name, str(output_path), "unsupported")

View File

@ -1,69 +1,68 @@
import pytest
import os import os
import tempfile
from pr.tools.filesystem import read_file, write_file, list_directory, search_replace
from pr.tools.patch import apply_patch, create_diff
from pr.tools.base import get_tools_definition from pr.tools.base import get_tools_definition
from pr.tools.filesystem import list_directory, read_file, search_replace, write_file
from pr.tools.patch import apply_patch, create_diff
class TestFilesystemTools: class TestFilesystemTools:
def test_write_and_read_file(self, temp_dir): def test_write_and_read_file(self, temp_dir):
filepath = os.path.join(temp_dir, 'test.txt') filepath = os.path.join(temp_dir, "test.txt")
content = 'Hello, World!' content = "Hello, World!"
write_result = write_file(filepath, content) write_result = write_file(filepath, content)
assert write_result['status'] == 'success' assert write_result["status"] == "success"
read_result = read_file(filepath) read_result = read_file(filepath)
assert read_result['status'] == 'success' assert read_result["status"] == "success"
assert content in read_result['content'] assert content in read_result["content"]
def test_read_nonexistent_file(self): def test_read_nonexistent_file(self):
result = read_file('/nonexistent/path/file.txt') result = read_file("/nonexistent/path/file.txt")
assert result['status'] == 'error' assert result["status"] == "error"
def test_list_directory(self, temp_dir): def test_list_directory(self, temp_dir):
test_file = os.path.join(temp_dir, 'testfile.txt') test_file = os.path.join(temp_dir, "testfile.txt")
with open(test_file, 'w') as f: with open(test_file, "w") as f:
f.write('test') f.write("test")
result = list_directory(temp_dir) result = list_directory(temp_dir)
assert result['status'] == 'success' assert result["status"] == "success"
assert any(item['name'] == 'testfile.txt' for item in result['items']) assert any(item["name"] == "testfile.txt" for item in result["items"])
def test_search_replace(self, temp_dir): def test_search_replace(self, temp_dir):
filepath = os.path.join(temp_dir, 'test.txt') filepath = os.path.join(temp_dir, "test.txt")
content = 'Hello, World!' content = "Hello, World!"
with open(filepath, 'w') as f: with open(filepath, "w") as f:
f.write(content) f.write(content)
result = search_replace(filepath, 'World', 'Universe') result = search_replace(filepath, "World", "Universe")
assert result['status'] == 'success' assert result["status"] == "success"
read_result = read_file(filepath) read_result = read_file(filepath)
assert 'Hello, Universe!' in read_result['content'] assert "Hello, Universe!" in read_result["content"]
class TestPatchTools: class TestPatchTools:
def test_create_diff(self, temp_dir): def test_create_diff(self, temp_dir):
file1 = os.path.join(temp_dir, 'file1.txt') file1 = os.path.join(temp_dir, "file1.txt")
file2 = os.path.join(temp_dir, 'file2.txt') file2 = os.path.join(temp_dir, "file2.txt")
with open(file1, 'w') as f: with open(file1, "w") as f:
f.write('line1\nline2\nline3\n') f.write("line1\nline2\nline3\n")
with open(file2, 'w') as f: with open(file2, "w") as f:
f.write('line1\nline2 modified\nline3\n') f.write("line1\nline2 modified\nline3\n")
result = create_diff(file1, file2) result = create_diff(file1, file2)
assert result['status'] == 'success' assert result["status"] == "success"
assert 'line2' in result['diff'] assert "line2" in result["diff"]
assert 'line2 modified' in result['diff'] assert "line2 modified" in result["diff"]
def test_apply_patch(self, temp_dir): def test_apply_patch(self, temp_dir):
filepath = os.path.join(temp_dir, 'file.txt') filepath = os.path.join(temp_dir, "file.txt")
with open(filepath, 'w') as f: with open(filepath, "w") as f:
f.write('line1\nline2\nline3\n') f.write("line1\nline2\nline3\n")
# Create a simple patch # Create a simple patch
patch_content = """--- a/file.txt patch_content = """--- a/file.txt
@ -75,10 +74,10 @@ class TestPatchTools:
line3 line3
""" """
result = apply_patch(filepath, patch_content) result = apply_patch(filepath, patch_content)
assert result['status'] == 'success' assert result["status"] == "success"
read_result = read_file(filepath) read_result = read_file(filepath)
assert 'line2 modified' in read_result['content'] assert "line2 modified" in read_result["content"]
class TestToolDefinitions: class TestToolDefinitions:
@ -92,27 +91,27 @@ class TestToolDefinitions:
tools = get_tools_definition() tools = get_tools_definition()
for tool in tools: for tool in tools:
assert 'type' in tool assert "type" in tool
assert tool['type'] == 'function' assert tool["type"] == "function"
assert 'function' in tool assert "function" in tool
func = tool['function'] func = tool["function"]
assert 'name' in func assert "name" in func
assert 'description' in func assert "description" in func
assert 'parameters' in func assert "parameters" in func
def test_filesystem_tools_present(self): def test_filesystem_tools_present(self):
tools = get_tools_definition() tools = get_tools_definition()
tool_names = [t['function']['name'] for t in tools] tool_names = [t["function"]["name"] for t in tools]
assert 'read_file' in tool_names assert "read_file" in tool_names
assert 'write_file' in tool_names assert "write_file" in tool_names
assert 'list_directory' in tool_names assert "list_directory" in tool_names
assert 'search_replace' in tool_names assert "search_replace" in tool_names
def test_patch_tools_present(self): def test_patch_tools_present(self):
tools = get_tools_definition() tools = get_tools_definition()
tool_names = [t['function']['name'] for t in tools] tool_names = [t["function"]["name"] for t in tools]
assert 'apply_patch' in tool_names assert "apply_patch" in tool_names
assert 'create_diff' in tool_names assert "create_diff" in tool_names

View File

@ -1,63 +1,71 @@
import pytest
import tempfile
import os
import json import json
import os
import pytest
from pr.core.usage_tracker import UsageTracker from pr.core.usage_tracker import UsageTracker
@pytest.fixture @pytest.fixture
def temp_usage_file(tmp_path, monkeypatch): def temp_usage_file(tmp_path, monkeypatch):
from pr.core import usage_tracker from pr.core import usage_tracker
original_file = usage_tracker.USAGE_DB_FILE original_file = usage_tracker.USAGE_DB_FILE
temp_file = str(tmp_path / "usage.json") temp_file = str(tmp_path / "usage.json")
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', temp_file) monkeypatch.setattr(usage_tracker, "USAGE_DB_FILE", temp_file)
yield temp_file yield temp_file
if os.path.exists(temp_file): if os.path.exists(temp_file):
os.remove(temp_file) os.remove(temp_file)
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', original_file) monkeypatch.setattr(usage_tracker, "USAGE_DB_FILE", original_file)
def test_usage_tracker_init(): def test_usage_tracker_init():
tracker = UsageTracker() tracker = UsageTracker()
summary = tracker.get_session_summary() summary = tracker.get_session_summary()
assert summary['requests'] == 0 assert summary["requests"] == 0
assert summary['total_tokens'] == 0 assert summary["total_tokens"] == 0
assert summary['estimated_cost'] == 0.0 assert summary["estimated_cost"] == 0.0
def test_track_request_known_model(): def test_track_request_known_model():
tracker = UsageTracker() tracker = UsageTracker()
tracker.track_request('gpt-3.5-turbo', 100, 50) tracker.track_request("gpt-3.5-turbo", 100, 50)
summary = tracker.get_session_summary() summary = tracker.get_session_summary()
assert summary['requests'] == 1 assert summary["requests"] == 1
assert summary['input_tokens'] == 100 assert summary["input_tokens"] == 100
assert summary['output_tokens'] == 50 assert summary["output_tokens"] == 50
assert summary['total_tokens'] == 150 assert summary["total_tokens"] == 150
assert 'gpt-3.5-turbo' in summary['models_used'] assert "gpt-3.5-turbo" in summary["models_used"]
# Cost: (100/1000)*0.0005 + (50/1000)*0.0015 = 0.00005 + 0.000075 = 0.000125 # Cost: (100/1000)*0.0005 + (50/1000)*0.0015 = 0.00005 + 0.000075 = 0.000125
assert abs(summary['estimated_cost'] - 0.000125) < 1e-6 assert abs(summary["estimated_cost"] - 0.000125) < 1e-6
def test_track_request_unknown_model(): def test_track_request_unknown_model():
tracker = UsageTracker() tracker = UsageTracker()
tracker.track_request('unknown-model', 100, 50) tracker.track_request("unknown-model", 100, 50)
summary = tracker.get_session_summary() summary = tracker.get_session_summary()
assert summary['requests'] == 1 assert summary["requests"] == 1
assert summary['estimated_cost'] == 0.0 # Unknown model, cost 0 assert summary["estimated_cost"] == 0.0 # Unknown model, cost 0
def test_track_request_multiple(): def test_track_request_multiple():
tracker = UsageTracker() tracker = UsageTracker()
tracker.track_request('gpt-3.5-turbo', 100, 50) tracker.track_request("gpt-3.5-turbo", 100, 50)
tracker.track_request('gpt-4', 200, 100) tracker.track_request("gpt-4", 200, 100)
summary = tracker.get_session_summary() summary = tracker.get_session_summary()
assert summary['requests'] == 2 assert summary["requests"] == 2
assert summary['input_tokens'] == 300 assert summary["input_tokens"] == 300
assert summary['output_tokens'] == 150 assert summary["output_tokens"] == 150
assert summary['total_tokens'] == 450 assert summary["total_tokens"] == 450
assert len(summary['models_used']) == 2 assert len(summary["models_used"]) == 2
def test_get_formatted_summary(): def test_get_formatted_summary():
tracker = UsageTracker() tracker = UsageTracker()
tracker.track_request('gpt-3.5-turbo', 100, 50) tracker.track_request("gpt-3.5-turbo", 100, 50)
formatted = tracker.get_formatted_summary() formatted = tracker.get_formatted_summary()
assert "Total Requests: 1" in formatted assert "Total Requests: 1" in formatted
@ -65,22 +73,38 @@ def test_get_formatted_summary():
assert "Estimated Cost: $0.0001" in formatted assert "Estimated Cost: $0.0001" in formatted
assert "gpt-3.5-turbo" in formatted assert "gpt-3.5-turbo" in formatted
def test_get_total_usage_no_file(temp_usage_file): def test_get_total_usage_no_file(temp_usage_file):
total = UsageTracker.get_total_usage() total = UsageTracker.get_total_usage()
assert total['total_requests'] == 0 assert total["total_requests"] == 0
assert total['total_tokens'] == 0 assert total["total_tokens"] == 0
assert total['total_cost'] == 0.0 assert total["total_cost"] == 0.0
def test_get_total_usage_with_data(temp_usage_file): def test_get_total_usage_with_data(temp_usage_file):
# Manually create history file # Manually create history file
history = [ history = [
{'timestamp': '2023-01-01', 'model': 'gpt-3.5-turbo', 'input_tokens': 100, 'output_tokens': 50, 'total_tokens': 150, 'cost': 0.000125}, {
{'timestamp': '2023-01-02', 'model': 'gpt-4', 'input_tokens': 200, 'output_tokens': 100, 'total_tokens': 300, 'cost': 0.008} "timestamp": "2023-01-01",
"model": "gpt-3.5-turbo",
"input_tokens": 100,
"output_tokens": 50,
"total_tokens": 150,
"cost": 0.000125,
},
{
"timestamp": "2023-01-02",
"model": "gpt-4",
"input_tokens": 200,
"output_tokens": 100,
"total_tokens": 300,
"cost": 0.008,
},
] ]
with open(temp_usage_file, 'w') as f: with open(temp_usage_file, "w") as f:
json.dump(history, f) json.dump(history, f)
total = UsageTracker.get_total_usage() total = UsageTracker.get_total_usage()
assert total['total_requests'] == 2 assert total["total_requests"] == 2
assert total['total_tokens'] == 450 assert total["total_tokens"] == 450
assert abs(total['total_cost'] - 0.008125) < 1e-6 assert abs(total["total_cost"] - 0.008125) < 1e-6

View File

@ -1,49 +1,59 @@
import pytest
import tempfile
import os import os
import tempfile
import pytest
from pr.core.exceptions import ValidationError
from pr.core.validation import ( from pr.core.validation import (
validate_file_path,
validate_directory_path,
validate_model_name,
validate_api_url, validate_api_url,
validate_directory_path,
validate_file_path,
validate_max_tokens,
validate_model_name,
validate_session_name, validate_session_name,
validate_temperature, validate_temperature,
validate_max_tokens,
) )
from pr.core.exceptions import ValidationError
def test_validate_file_path_empty(): def test_validate_file_path_empty():
with pytest.raises(ValidationError, match="File path cannot be empty"): with pytest.raises(ValidationError, match="File path cannot be empty"):
validate_file_path("") validate_file_path("")
def test_validate_file_path_not_exist(): def test_validate_file_path_not_exist():
with pytest.raises(ValidationError, match="File does not exist"): with pytest.raises(ValidationError, match="File does not exist"):
validate_file_path("/nonexistent/file.txt", must_exist=True) validate_file_path("/nonexistent/file.txt", must_exist=True)
def test_validate_file_path_is_dir(): def test_validate_file_path_is_dir():
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
with pytest.raises(ValidationError, match="Path is a directory"): with pytest.raises(ValidationError, match="Path is a directory"):
validate_file_path(tmpdir, must_exist=True) validate_file_path(tmpdir, must_exist=True)
def test_validate_file_path_valid(): def test_validate_file_path_valid():
with tempfile.NamedTemporaryFile() as tmpfile: with tempfile.NamedTemporaryFile() as tmpfile:
result = validate_file_path(tmpfile.name, must_exist=True) result = validate_file_path(tmpfile.name, must_exist=True)
assert os.path.isabs(result) assert os.path.isabs(result)
assert result == os.path.abspath(tmpfile.name) assert result == os.path.abspath(tmpfile.name)
def test_validate_directory_path_empty(): def test_validate_directory_path_empty():
with pytest.raises(ValidationError, match="Directory path cannot be empty"): with pytest.raises(ValidationError, match="Directory path cannot be empty"):
validate_directory_path("") validate_directory_path("")
def test_validate_directory_path_not_exist(): def test_validate_directory_path_not_exist():
with pytest.raises(ValidationError, match="Directory does not exist"): with pytest.raises(ValidationError, match="Directory does not exist"):
validate_directory_path("/nonexistent/dir", must_exist=True) validate_directory_path("/nonexistent/dir", must_exist=True)
def test_validate_directory_path_not_dir(): def test_validate_directory_path_not_dir():
with tempfile.NamedTemporaryFile() as tmpfile: with tempfile.NamedTemporaryFile() as tmpfile:
with pytest.raises(ValidationError, match="Path is not a directory"): with pytest.raises(ValidationError, match="Path is not a directory"):
validate_directory_path(tmpfile.name, must_exist=True) validate_directory_path(tmpfile.name, must_exist=True)
def test_validate_directory_path_create(): def test_validate_directory_path_create():
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
new_dir = os.path.join(tmpdir, "new_dir") new_dir = os.path.join(tmpdir, "new_dir")
@ -51,72 +61,89 @@ def test_validate_directory_path_create():
assert os.path.isdir(new_dir) assert os.path.isdir(new_dir)
assert result == os.path.abspath(new_dir) assert result == os.path.abspath(new_dir)
def test_validate_directory_path_valid(): def test_validate_directory_path_valid():
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
result = validate_directory_path(tmpdir, must_exist=True) result = validate_directory_path(tmpdir, must_exist=True)
assert result == os.path.abspath(tmpdir) assert result == os.path.abspath(tmpdir)
def test_validate_model_name_empty(): def test_validate_model_name_empty():
with pytest.raises(ValidationError, match="Model name cannot be empty"): with pytest.raises(ValidationError, match="Model name cannot be empty"):
validate_model_name("") validate_model_name("")
def test_validate_model_name_too_short(): def test_validate_model_name_too_short():
with pytest.raises(ValidationError, match="Model name too short"): with pytest.raises(ValidationError, match="Model name too short"):
validate_model_name("a") validate_model_name("a")
def test_validate_model_name_valid(): def test_validate_model_name_valid():
result = validate_model_name("gpt-3.5-turbo") result = validate_model_name("gpt-3.5-turbo")
assert result == "gpt-3.5-turbo" assert result == "gpt-3.5-turbo"
def test_validate_api_url_empty(): def test_validate_api_url_empty():
with pytest.raises(ValidationError, match="API URL cannot be empty"): with pytest.raises(ValidationError, match="API URL cannot be empty"):
validate_api_url("") validate_api_url("")
def test_validate_api_url_invalid(): def test_validate_api_url_invalid():
with pytest.raises(ValidationError, match="API URL must start with"): with pytest.raises(ValidationError, match="API URL must start with"):
validate_api_url("invalid-url") validate_api_url("invalid-url")
def test_validate_api_url_valid(): def test_validate_api_url_valid():
result = validate_api_url("https://api.example.com") result = validate_api_url("https://api.example.com")
assert result == "https://api.example.com" assert result == "https://api.example.com"
def test_validate_session_name_empty(): def test_validate_session_name_empty():
with pytest.raises(ValidationError, match="Session name cannot be empty"): with pytest.raises(ValidationError, match="Session name cannot be empty"):
validate_session_name("") validate_session_name("")
def test_validate_session_name_invalid_char(): def test_validate_session_name_invalid_char():
with pytest.raises(ValidationError, match="contains invalid character"): with pytest.raises(ValidationError, match="contains invalid character"):
validate_session_name("test/session") validate_session_name("test/session")
def test_validate_session_name_too_long(): def test_validate_session_name_too_long():
long_name = "a" * 256 long_name = "a" * 256
with pytest.raises(ValidationError, match="Session name too long"): with pytest.raises(ValidationError, match="Session name too long"):
validate_session_name(long_name) validate_session_name(long_name)
def test_validate_session_name_valid(): def test_validate_session_name_valid():
result = validate_session_name("valid_session_123") result = validate_session_name("valid_session_123")
assert result == "valid_session_123" assert result == "valid_session_123"
def test_validate_temperature_too_low(): def test_validate_temperature_too_low():
with pytest.raises(ValidationError, match="Temperature must be between"): with pytest.raises(ValidationError, match="Temperature must be between"):
validate_temperature(-0.1) validate_temperature(-0.1)
def test_validate_temperature_too_high(): def test_validate_temperature_too_high():
with pytest.raises(ValidationError, match="Temperature must be between"): with pytest.raises(ValidationError, match="Temperature must be between"):
validate_temperature(2.1) validate_temperature(2.1)
def test_validate_temperature_valid(): def test_validate_temperature_valid():
result = validate_temperature(0.7) result = validate_temperature(0.7)
assert result == 0.7 assert result == 0.7
def test_validate_max_tokens_too_low(): def test_validate_max_tokens_too_low():
with pytest.raises(ValidationError, match="Max tokens must be at least 1"): with pytest.raises(ValidationError, match="Max tokens must be at least 1"):
validate_max_tokens(0) validate_max_tokens(0)
def test_validate_max_tokens_too_high(): def test_validate_max_tokens_too_high():
with pytest.raises(ValidationError, match="Max tokens too high"): with pytest.raises(ValidationError, match="Max tokens too high"):
validate_max_tokens(100001) validate_max_tokens(100001)
def test_validate_max_tokens_valid(): def test_validate_max_tokens_valid():
result = validate_max_tokens(1000) result = validate_max_tokens(1000)
assert result == 1000 assert result == 1000