Update.
Some checks failed
Tests / test (macos-latest, 3.10) (push) Waiting to run
Tests / test (macos-latest, 3.11) (push) Waiting to run
Tests / test (macos-latest, 3.12) (push) Waiting to run
Tests / test (macos-latest, 3.8) (push) Waiting to run
Tests / test (macos-latest, 3.9) (push) Waiting to run
Tests / test (ubuntu-latest, 3.8) (push) Waiting to run
Tests / test (ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (windows-latest, 3.10) (push) Waiting to run
Tests / test (windows-latest, 3.11) (push) Waiting to run
Tests / test (windows-latest, 3.12) (push) Waiting to run
Tests / test (windows-latest, 3.8) (push) Waiting to run
Tests / test (windows-latest, 3.9) (push) Waiting to run
Lint / lint (push) Failing after 39s
Tests / test (ubuntu-latest, 3.10) (push) Successful in 55s
Tests / test (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / test (ubuntu-latest, 3.12) (push) Has been cancelled

retoor 2025-11-04 08:09:12 +01:00
parent 5f04811dcc
commit 1a29ee4918
82 changed files with 4965 additions and 3096 deletions

View File

@@ -159,6 +159,7 @@ def tool_function(args):
"""Implementation"""
pass
def register_tools():
"""Return list of tool definitions"""
return [...]
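A minimal plugin following this contract might look like the sketch below. All names are illustrative, and the tool-definition shape is assumed to match the OpenAI-style `tool["function"]["name"]` / `["description"]` access used elsewhere in this commit.

```python
def hello_tool(args):
    """Implementation: greet the caller by name."""
    name = args.get("name", "world")
    return {"status": "success", "message": f"Hello, {name}!"}


def register_tools():
    """Return list of tool definitions"""
    # Definition shape assumed from how tools are read in commands/handlers.py.
    return [
        {
            "type": "function",
            "function": {
                "name": "hello_tool",
                "description": "Greet the caller by name.",
                "parameters": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}},
                },
            },
        }
    ]
```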
@@ -177,8 +178,8 @@ def register_tools():
```python
def test_read_file_with_valid_path_returns_content(temp_dir):
# Arrange
filepath = os.path.join(temp_dir, 'test.txt')
expected_content = 'Hello, World!'
filepath = os.path.join(temp_dir, "test.txt")
expected_content = "Hello, World!"
write_file(filepath, expected_content)
    # Act
```
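The hunk ends at the Act step. A plausible completion of the arrange/act/assert pattern, assuming `read_file` returns the `{"status": ..., "content": ...}` dict shape used elsewhere in this commit:

```python
    # Act (continued)
    result = read_file(filepath)

    # Assert
    assert result["status"] == "success"
    assert result["content"] == expected_content
```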
View File

@@ -1,4 +1,4 @@
from pr.core import Assistant
__version__ = '1.0.0'
__all__ = ['Assistant']
__version__ = "1.0.0"
__all__ = ["Assistant"]

View File

@@ -1,12 +1,14 @@
import argparse
import sys
from pr.core import Assistant
from pr import __version__
from pr.core import Assistant
def main():
parser = argparse.ArgumentParser(
description='PR Assistant - Professional CLI AI assistant with autonomous execution',
epilog='''
description="PR Assistant - Professional CLI AI assistant with autonomous execution",
epilog="""
Examples:
pr "What is Python?" # Single query
pr -i # Interactive mode
@@ -25,43 +27,79 @@ Commands in interactive mode:
/usage - Show usage statistics
/save <name> - Save current session
exit, quit, q - Exit the program
''',
formatter_class=argparse.RawDescriptionHelpFormatter
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument('message', nargs='?', help='Message to send to assistant')
parser.add_argument('--version', action='version', version=f'PR Assistant {__version__}')
parser.add_argument('-m', '--model', help='AI model to use')
parser.add_argument('-u', '--api-url', help='API endpoint URL')
parser.add_argument('--model-list-url', help='Model list endpoint URL')
parser.add_argument('-i', '--interactive', action='store_true', help='Interactive mode')
parser.add_argument('-v', '--verbose', action='store_true', help='Verbose output')
parser.add_argument('--debug', action='store_true', help='Enable debug mode with detailed logging')
parser.add_argument('--no-syntax', action='store_true', help='Disable syntax highlighting')
parser.add_argument('--include-env', action='store_true', help='Include environment variables in context')
parser.add_argument('-c', '--context', action='append', help='Additional context files')
parser.add_argument('--api-mode', action='store_true', help='API mode for specialized interaction')
parser.add_argument("message", nargs="?", help="Message to send to assistant")
parser.add_argument(
"--version", action="version", version=f"PR Assistant {__version__}"
)
parser.add_argument("-m", "--model", help="AI model to use")
parser.add_argument("-u", "--api-url", help="API endpoint URL")
parser.add_argument("--model-list-url", help="Model list endpoint URL")
parser.add_argument(
"-i", "--interactive", action="store_true", help="Interactive mode"
)
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
parser.add_argument(
"--debug", action="store_true", help="Enable debug mode with detailed logging"
)
parser.add_argument(
"--no-syntax", action="store_true", help="Disable syntax highlighting"
)
parser.add_argument(
"--include-env",
action="store_true",
help="Include environment variables in context",
)
parser.add_argument(
"-c", "--context", action="append", help="Additional context files"
)
parser.add_argument(
"--api-mode", action="store_true", help="API mode for specialized interaction"
)
parser.add_argument('--output', choices=['text', 'json', 'structured'],
default='text', help='Output format')
parser.add_argument('--quiet', action='store_true', help='Minimal output')
parser.add_argument(
"--output",
choices=["text", "json", "structured"],
default="text",
help="Output format",
)
parser.add_argument("--quiet", action="store_true", help="Minimal output")
parser.add_argument('--save-session', metavar='NAME', help='Save session with given name')
parser.add_argument('--load-session', metavar='NAME', help='Load session with given name')
parser.add_argument('--list-sessions', action='store_true', help='List all saved sessions')
parser.add_argument('--delete-session', metavar='NAME', help='Delete a saved session')
parser.add_argument('--export-session', nargs=2, metavar=('NAME', 'FILE'),
help='Export session to file')
parser.add_argument(
"--save-session", metavar="NAME", help="Save session with given name"
)
parser.add_argument(
"--load-session", metavar="NAME", help="Load session with given name"
)
parser.add_argument(
"--list-sessions", action="store_true", help="List all saved sessions"
)
parser.add_argument(
"--delete-session", metavar="NAME", help="Delete a saved session"
)
parser.add_argument(
"--export-session",
nargs=2,
metavar=("NAME", "FILE"),
help="Export session to file",
)
parser.add_argument('--usage', action='store_true', help='Show token usage statistics')
parser.add_argument('--create-config', action='store_true',
help='Create default configuration file')
parser.add_argument('--plugins', action='store_true', help='List loaded plugins')
parser.add_argument(
"--usage", action="store_true", help="Show token usage statistics"
)
parser.add_argument(
"--create-config", action="store_true", help="Create default configuration file"
)
parser.add_argument("--plugins", action="store_true", help="List loaded plugins")
args = parser.parse_args()
if args.create_config:
from pr.core.config_loader import create_default_config
if create_default_config():
print("Configuration file created at ~/.prrc")
else:
@@ -70,6 +108,7 @@ Commands in interactive mode:
if args.list_sessions:
from pr.core.session import SessionManager
sm = SessionManager()
sessions = sm.list_sessions()
if not sessions:
@@ -85,6 +124,7 @@ Commands in interactive mode:
if args.delete_session:
from pr.core.session import SessionManager
sm = SessionManager()
if sm.delete_session(args.delete_session):
print(f"Session '{args.delete_session}' deleted")
@@ -94,13 +134,14 @@ Commands in interactive mode:
if args.export_session:
from pr.core.session import SessionManager
sm = SessionManager()
name, output_file = args.export_session
format_type = 'json'
if output_file.endswith('.md'):
format_type = 'markdown'
elif output_file.endswith('.txt'):
format_type = 'txt'
format_type = "json"
if output_file.endswith(".md"):
format_type = "markdown"
elif output_file.endswith(".txt"):
format_type = "txt"
if sm.export_session(name, output_file, format_type):
print(f"Session exported to {output_file}")
@@ -110,6 +151,7 @@ Commands in interactive mode:
if args.usage:
from pr.core.usage_tracker import UsageTracker
usage = UsageTracker.get_total_usage()
print(f"\nTotal Usage Statistics:")
print(f" Requests: {usage['total_requests']}")
@@ -119,6 +161,7 @@ Commands in interactive mode:
if args.plugins:
from pr.plugins.loader import PluginLoader
loader = PluginLoader()
loader.load_plugins()
plugins = loader.list_loaded_plugins()
@@ -133,5 +176,6 @@ Commands in interactive mode:
assistant = Assistant(args)
assistant.run()
if __name__ == '__main__':
if __name__ == "__main__":
main()

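The export branch above infers the output format from the file extension. The same rule as a standalone helper (the function name is illustrative):

```python
def infer_export_format(output_file: str) -> str:
    """Map an output filename to an export format, defaulting to JSON."""
    if output_file.endswith(".md"):
        return "markdown"
    if output_file.endswith(".txt"):
        return "txt"
    return "json"


assert infer_export_format("notes.md") == "markdown"
assert infer_export_format("log.txt") == "txt"
assert infer_export_format("session.json") == "json"
```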
View File

@@ -1,6 +1,13 @@
from .agent_communication import AgentCommunicationBus, AgentMessage
from .agent_manager import AgentInstance, AgentManager
from .agent_roles import AgentRole, get_agent_role, list_agent_roles
from .agent_manager import AgentManager, AgentInstance
from .agent_communication import AgentMessage, AgentCommunicationBus
__all__ = ['AgentRole', 'get_agent_role', 'list_agent_roles', 'AgentManager', 'AgentInstance',
'AgentMessage', 'AgentCommunicationBus']
__all__ = [
"AgentRole",
"get_agent_role",
"list_agent_roles",
"AgentManager",
"AgentInstance",
"AgentMessage",
"AgentCommunicationBus",
]

View File

@@ -1,14 +1,16 @@
import sqlite3
import json
from typing import List, Optional
import sqlite3
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
class MessageType(Enum):
REQUEST = "request"
RESPONSE = "response"
NOTIFICATION = "notification"
@dataclass
class AgentMessage:
message_id: str
@@ -21,27 +23,28 @@ class AgentMessage:
def to_dict(self) -> dict:
return {
'message_id': self.message_id,
'from_agent': self.from_agent,
'to_agent': self.to_agent,
'message_type': self.message_type.value,
'content': self.content,
'metadata': self.metadata,
'timestamp': self.timestamp
"message_id": self.message_id,
"from_agent": self.from_agent,
"to_agent": self.to_agent,
"message_type": self.message_type.value,
"content": self.content,
"metadata": self.metadata,
"timestamp": self.timestamp,
}
@classmethod
def from_dict(cls, data: dict) -> 'AgentMessage':
def from_dict(cls, data: dict) -> "AgentMessage":
return cls(
message_id=data['message_id'],
from_agent=data['from_agent'],
to_agent=data['to_agent'],
message_type=MessageType(data['message_type']),
content=data['content'],
metadata=data['metadata'],
timestamp=data['timestamp']
message_id=data["message_id"],
from_agent=data["from_agent"],
to_agent=data["to_agent"],
message_type=MessageType(data["message_type"]),
content=data["content"],
metadata=data["metadata"],
timestamp=data["timestamp"],
)
class AgentCommunicationBus:
def __init__(self, db_path: str):
self.db_path = db_path
@@ -50,7 +53,8 @@ class AgentCommunicationBus:
def _create_tables(self):
cursor = self.conn.cursor()
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS agent_messages (
message_id TEXT PRIMARY KEY,
from_agent TEXT,
@@ -62,70 +66,88 @@ class AgentCommunicationBus:
session_id TEXT,
read INTEGER DEFAULT 0
)
''')
"""
)
self.conn.commit()
def send_message(self, message: AgentMessage, session_id: Optional[str] = None):
cursor = self.conn.cursor()
cursor.execute('''
cursor.execute(
"""
INSERT INTO agent_messages
(message_id, from_agent, to_agent, message_type, content, metadata, timestamp, session_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
message.message_id,
message.from_agent,
message.to_agent,
message.message_type.value,
message.content,
json.dumps(message.metadata),
message.timestamp,
session_id
))
""",
(
message.message_id,
message.from_agent,
message.to_agent,
message.message_type.value,
message.content,
json.dumps(message.metadata),
message.timestamp,
session_id,
),
)
self.conn.commit()
def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
def get_messages(
self, agent_id: str, unread_only: bool = True
) -> List[AgentMessage]:
cursor = self.conn.cursor()
if unread_only:
cursor.execute('''
cursor.execute(
"""
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
FROM agent_messages
WHERE to_agent = ? AND read = 0
ORDER BY timestamp ASC
''', (agent_id,))
""",
(agent_id,),
)
else:
cursor.execute('''
cursor.execute(
"""
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
FROM agent_messages
WHERE to_agent = ?
ORDER BY timestamp ASC
''', (agent_id,))
""",
(agent_id,),
)
messages = []
for row in cursor.fetchall():
messages.append(AgentMessage(
message_id=row[0],
from_agent=row[1],
to_agent=row[2],
message_type=MessageType(row[3]),
content=row[4],
metadata=json.loads(row[5]) if row[5] else {},
timestamp=row[6]
))
messages.append(
AgentMessage(
message_id=row[0],
from_agent=row[1],
to_agent=row[2],
message_type=MessageType(row[3]),
content=row[4],
metadata=json.loads(row[5]) if row[5] else {},
timestamp=row[6],
)
)
return messages
def mark_as_read(self, message_id: str):
cursor = self.conn.cursor()
cursor.execute('UPDATE agent_messages SET read = 1 WHERE message_id = ?', (message_id,))
cursor.execute(
"UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,)
)
self.conn.commit()
def clear_messages(self, session_id: Optional[str] = None):
cursor = self.conn.cursor()
if session_id:
cursor.execute('DELETE FROM agent_messages WHERE session_id = ?', (session_id,))
cursor.execute(
"DELETE FROM agent_messages WHERE session_id = ?", (session_id,)
)
else:
cursor.execute('DELETE FROM agent_messages')
cursor.execute("DELETE FROM agent_messages")
self.conn.commit()
def close(self):
@@ -134,24 +156,31 @@ class AgentCommunicationBus:
def receive_messages(self, agent_id: str) -> List[AgentMessage]:
return self.get_messages(agent_id, unread_only=True)
def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]:
def get_conversation_history(
self, agent_a: str, agent_b: str
) -> List[AgentMessage]:
cursor = self.conn.cursor()
cursor.execute('''
cursor.execute(
"""
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
FROM agent_messages
WHERE (from_agent = ? AND to_agent = ?) OR (from_agent = ? AND to_agent = ?)
ORDER BY timestamp ASC
''', (agent_a, agent_b, agent_b, agent_a))
""",
(agent_a, agent_b, agent_b, agent_a),
)
messages = []
for row in cursor.fetchall():
messages.append(AgentMessage(
message_id=row[0],
from_agent=row[1],
to_agent=row[2],
message_type=MessageType(row[3]),
content=row[4],
metadata=json.loads(row[5]) if row[5] else {},
timestamp=row[6]
))
return messages
messages.append(
AgentMessage(
message_id=row[0],
from_agent=row[1],
to_agent=row[2],
message_type=MessageType(row[3]),
content=row[4],
metadata=json.loads(row[5]) if row[5] else {},
timestamp=row[6],
)
)
return messages

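A minimal round trip through the bus, assuming a writable SQLite path; the agent ids and message content are placeholders:

```python
import time
import uuid

bus = AgentCommunicationBus("agents.db")  # hypothetical db path
message = AgentMessage(
    message_id=str(uuid.uuid4())[:16],
    from_agent="coding_ab12cd34",
    to_agent="testing_ef56gh78",
    message_type=MessageType.REQUEST,
    content="Please write tests for read_file().",
    metadata={},
    timestamp=time.time(),
)
bus.send_message(message, session_id="demo")

# Unread messages for the receiving agent, oldest first.
for pending in bus.receive_messages("testing_ef56gh78"):
    print(pending.from_agent, "->", pending.content)
    bus.mark_as_read(pending.message_id)
bus.close()
```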
View File

@@ -1,11 +1,13 @@
import time
import json
import time
import uuid
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass, field
from .agent_roles import AgentRole, get_agent_role
from .agent_communication import AgentMessage, AgentCommunicationBus, MessageType
from typing import Any, Callable, Dict, List, Optional
from ..memory.knowledge_store import KnowledgeStore
from .agent_communication import AgentCommunicationBus, AgentMessage, MessageType
from .agent_roles import AgentRole, get_agent_role
@dataclass
class AgentInstance:
@@ -17,21 +19,20 @@ class AgentInstance:
task_count: int = 0
def add_message(self, role: str, content: str):
self.message_history.append({
'role': role,
'content': content,
'timestamp': time.time()
})
self.message_history.append(
{"role": role, "content": content, "timestamp": time.time()}
)
def get_system_message(self) -> Dict[str, str]:
return {'role': 'system', 'content': self.role.system_prompt}
return {"role": "system", "content": self.role.system_prompt}
def get_messages_for_api(self) -> List[Dict[str, str]]:
return [self.get_system_message()] + [
{'role': msg['role'], 'content': msg['content']}
{"role": msg["role"], "content": msg["content"]}
for msg in self.message_history
]
class AgentManager:
def __init__(self, db_path: str, api_caller: Callable):
self.db_path = db_path
@@ -46,32 +47,31 @@ class AgentManager:
agent_id = f"{role_name}_{str(uuid.uuid4())[:8]}"
role = get_agent_role(role_name)
agent = AgentInstance(
agent_id=agent_id,
role=role
)
agent = AgentInstance(agent_id=agent_id, role=role)
self.active_agents[agent_id] = agent
return agent_id
def get_agent(self, agent_id: str) -> Optional[AgentInstance]:
return self.active_agents.get(agent_id)
def remove_agent(self, agent_id: str) -> bool:
if agent_id in self.active_agents:
del self.active_agents[agent_id]
return True
return False
def execute_agent_task(self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
def execute_agent_task(
self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
agent = self.get_agent(agent_id)
if not agent:
return {'error': f'Agent {agent_id} not found'}
return {"error": f"Agent {agent_id} not found"}
if context:
agent.context.update(context)
agent.add_message('user', task)
agent.add_message("user", task)
knowledge_matches = self.knowledge_store.search_entries(task, top_k=3)
agent.task_count += 1
@@ -81,35 +81,40 @@ class AgentManager:
for i, entry in enumerate(knowledge_matches, 1):
shortened_content = entry.content[:2000]
knowledge_content += f"{i}. {shortened_content}\\n\\n"
messages.insert(-1, {'role': 'user', 'content': knowledge_content})
messages.insert(-1, {"role": "user", "content": knowledge_content})
try:
response = self.api_caller(
messages=messages,
temperature=agent.role.temperature,
max_tokens=agent.role.max_tokens
max_tokens=agent.role.max_tokens,
)
if response and 'choices' in response:
assistant_message = response['choices'][0]['message']['content']
agent.add_message('assistant', assistant_message)
if response and "choices" in response:
assistant_message = response["choices"][0]["message"]["content"]
agent.add_message("assistant", assistant_message)
return {
'success': True,
'agent_id': agent_id,
'response': assistant_message,
'role': agent.role.name,
'task_count': agent.task_count
"success": True,
"agent_id": agent_id,
"response": assistant_message,
"role": agent.role.name,
"task_count": agent.task_count,
}
else:
return {'error': 'Invalid API response', 'agent_id': agent_id}
return {"error": "Invalid API response", "agent_id": agent_id}
except Exception as e:
return {'error': str(e), 'agent_id': agent_id}
return {"error": str(e), "agent_id": agent_id}
def send_agent_message(self, from_agent_id: str, to_agent_id: str,
content: str, message_type: MessageType = MessageType.REQUEST,
metadata: Optional[Dict[str, Any]] = None):
def send_agent_message(
self,
from_agent_id: str,
to_agent_id: str,
content: str,
message_type: MessageType = MessageType.REQUEST,
metadata: Optional[Dict[str, Any]] = None,
):
message = AgentMessage(
from_agent=from_agent_id,
to_agent=to_agent_id,
@@ -117,57 +122,57 @@ class AgentManager:
content=content,
metadata=metadata or {},
timestamp=time.time(),
message_id=str(uuid.uuid4())[:16]
message_id=str(uuid.uuid4())[:16],
)
self.communication_bus.send_message(message, self.session_id)
return message.message_id
def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
def get_agent_messages(
self, agent_id: str, unread_only: bool = True
) -> List[AgentMessage]:
return self.communication_bus.get_messages(agent_id, unread_only)
def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]):
def collaborate_agents(
self, orchestrator_id: str, task: str, agent_roles: List[str]
):
orchestrator = self.get_agent(orchestrator_id)
if not orchestrator:
orchestrator_id = self.create_agent('orchestrator')
orchestrator_id = self.create_agent("orchestrator")
orchestrator = self.get_agent(orchestrator_id)
worker_agents = []
for role in agent_roles:
agent_id = self.create_agent(role)
worker_agents.append({
'agent_id': agent_id,
'role': role
})
worker_agents.append({"agent_id": agent_id, "role": role})
orchestration_prompt = f'''Task: {task}
orchestration_prompt = f"""Task: {task}
Available specialized agents:
{chr(10).join([f"- {a['agent_id']} ({a['role']})" for a in worker_agents])}
Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results.'''
Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results."""
orchestrator_result = self.execute_agent_task(orchestrator_id, orchestration_prompt)
orchestrator_result = self.execute_agent_task(
orchestrator_id, orchestration_prompt
)
results = {
'orchestrator': orchestrator_result,
'agents': []
}
results = {"orchestrator": orchestrator_result, "agents": []}
for agent_info in worker_agents:
agent_id = agent_info['agent_id']
agent_id = agent_info["agent_id"]
messages = self.get_agent_messages(agent_id)
for msg in messages:
subtask = msg.content
result = self.execute_agent_task(agent_id, subtask)
results['agents'].append(result)
results["agents"].append(result)
self.send_agent_message(
from_agent_id=agent_id,
to_agent_id=orchestrator_id,
content=result.get('response', ''),
message_type=MessageType.RESPONSE
content=result.get("response", ""),
message_type=MessageType.RESPONSE,
)
self.communication_bus.mark_as_read(msg.message_id)
@@ -175,21 +180,21 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei
def get_session_summary(self) -> str:
summary = {
'session_id': self.session_id,
'active_agents': len(self.active_agents),
'agents': [
"session_id": self.session_id,
"active_agents": len(self.active_agents),
"agents": [
{
'agent_id': agent_id,
'role': agent.role.name,
'task_count': agent.task_count,
'message_count': len(agent.message_history)
"agent_id": agent_id,
"role": agent.role.name,
"task_count": agent.task_count,
"message_count": len(agent.message_history),
}
for agent_id, agent in self.active_agents.items()
]
],
}
return json.dumps(summary)
def clear_session(self):
self.active_agents.clear()
self.communication_bus.clear_messages(session_id=self.session_id)
self.session_id = str(uuid.uuid4())[:16]
self.session_id = str(uuid.uuid4())[:16]

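Wiring the manager to a stubbed API caller. Per `execute_agent_task` above, the callable must accept `messages`, `temperature`, and `max_tokens` and return an OpenAI-style payload; the path and task text are placeholders:

```python
def stub_api_caller(messages, temperature, max_tokens):
    # Stand-in for the real completion endpoint.
    return {"choices": [{"message": {"content": "def add(a, b):\n    return a + b"}}]}


manager = AgentManager("agents.db", api_caller=stub_api_caller)
agent_id = manager.create_agent("coding")
result = manager.execute_agent_task(agent_id, "Write an add() function.")
if result.get("success"):
    print(result["response"])
print(manager.get_session_summary())
```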
View File

@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import List, Dict, Any, Set
from typing import Dict, List, Set
@dataclass
class AgentRole:
@@ -11,182 +12,262 @@
temperature: float = 0.7
max_tokens: int = 4096
AGENT_ROLES = {
'coding': AgentRole(
name='coding',
description='Specialized in writing, reviewing, and debugging code',
system_prompt='''You are a coding specialist AI assistant. Your primary responsibilities:
"coding": AgentRole(
name="coding",
description="Specialized in writing, reviewing, and debugging code",
system_prompt="""You are a coding specialist AI assistant. Your primary responsibilities:
- Write clean, efficient, well-structured code
- Review code for bugs, security issues, and best practices
- Refactor and optimize existing code
- Implement features based on specifications
- Follow language-specific conventions and patterns
Focus on code quality, maintainability, and performance.''',
Focus on code quality, maintainability, and performance.""",
allowed_tools={
'read_file', 'write_file', 'list_directory', 'create_directory',
'change_directory', 'get_current_directory', 'python_exec',
'run_command', 'index_directory'
"read_file",
"write_file",
"list_directory",
"create_directory",
"change_directory",
"get_current_directory",
"python_exec",
"run_command",
"index_directory",
},
specialization_areas=['code_writing', 'code_review', 'debugging', 'refactoring'],
temperature=0.3
specialization_areas=[
"code_writing",
"code_review",
"debugging",
"refactoring",
],
temperature=0.3,
),
'research': AgentRole(
name='research',
description='Specialized in information gathering and analysis',
system_prompt='''You are a research specialist AI assistant. Your primary responsibilities:
"research": AgentRole(
name="research",
description="Specialized in information gathering and analysis",
system_prompt="""You are a research specialist AI assistant. Your primary responsibilities:
- Search for and gather relevant information
- Analyze data and documentation
- Synthesize findings into clear summaries
- Verify facts and cross-reference sources
- Identify trends and patterns in information
Focus on accuracy, thoroughness, and clear communication of findings.''',
Focus on accuracy, thoroughness, and clear communication of findings.""",
allowed_tools={
'read_file', 'list_directory', 'index_directory',
'http_fetch', 'web_search', 'web_search_news',
'db_query', 'db_get'
"read_file",
"list_directory",
"index_directory",
"http_fetch",
"web_search",
"web_search_news",
"db_query",
"db_get",
},
specialization_areas=['information_gathering', 'analysis', 'documentation', 'fact_checking'],
temperature=0.5
specialization_areas=[
"information_gathering",
"analysis",
"documentation",
"fact_checking",
],
temperature=0.5,
),
'data_analysis': AgentRole(
name='data_analysis',
description='Specialized in data processing and analysis',
system_prompt='''You are a data analysis specialist AI assistant. Your primary responsibilities:
"data_analysis": AgentRole(
name="data_analysis",
description="Specialized in data processing and analysis",
system_prompt="""You are a data analysis specialist AI assistant. Your primary responsibilities:
- Process and analyze structured and unstructured data
- Perform statistical analysis and pattern recognition
- Query databases and extract insights
- Create data summaries and reports
- Identify anomalies and trends
Focus on accuracy, data integrity, and actionable insights.''',
Focus on accuracy, data integrity, and actionable insights.""",
allowed_tools={
'db_query', 'db_get', 'db_set', 'read_file', 'write_file',
'python_exec', 'run_command', 'list_directory'
"db_query",
"db_get",
"db_set",
"read_file",
"write_file",
"python_exec",
"run_command",
"list_directory",
},
specialization_areas=['data_processing', 'statistical_analysis', 'database_operations'],
temperature=0.3
specialization_areas=[
"data_processing",
"statistical_analysis",
"database_operations",
],
temperature=0.3,
),
'planning': AgentRole(
name='planning',
description='Specialized in task planning and coordination',
system_prompt='''You are a planning specialist AI assistant. Your primary responsibilities:
"planning": AgentRole(
name="planning",
description="Specialized in task planning and coordination",
system_prompt="""You are a planning specialist AI assistant. Your primary responsibilities:
- Break down complex tasks into manageable steps
- Create execution plans and workflows
- Identify dependencies and prerequisites
- Estimate effort and resource requirements
- Coordinate between different components
Focus on logical organization, completeness, and feasibility.''',
Focus on logical organization, completeness, and feasibility.""",
allowed_tools={
'read_file', 'write_file', 'list_directory', 'index_directory',
'db_set', 'db_get'
"read_file",
"write_file",
"list_directory",
"index_directory",
"db_set",
"db_get",
},
specialization_areas=['task_decomposition', 'workflow_design', 'coordination'],
temperature=0.6
specialization_areas=["task_decomposition", "workflow_design", "coordination"],
temperature=0.6,
),
'testing': AgentRole(
name='testing',
description='Specialized in testing and quality assurance',
system_prompt='''You are a testing specialist AI assistant. Your primary responsibilities:
"testing": AgentRole(
name="testing",
description="Specialized in testing and quality assurance",
system_prompt="""You are a testing specialist AI assistant. Your primary responsibilities:
- Design and execute test cases
- Identify edge cases and potential failures
- Verify functionality and correctness
- Test error handling and edge conditions
- Ensure code meets quality standards
Focus on thoroughness, coverage, and issue identification.''',
Focus on thoroughness, coverage, and issue identification.""",
allowed_tools={
'read_file', 'write_file', 'python_exec', 'run_command',
'list_directory', 'db_query'
"read_file",
"write_file",
"python_exec",
"run_command",
"list_directory",
"db_query",
},
specialization_areas=['test_design', 'quality_assurance', 'validation'],
temperature=0.4
specialization_areas=["test_design", "quality_assurance", "validation"],
temperature=0.4,
),
'documentation': AgentRole(
name='documentation',
description='Specialized in creating and maintaining documentation',
system_prompt='''You are a documentation specialist AI assistant. Your primary responsibilities:
"documentation": AgentRole(
name="documentation",
description="Specialized in creating and maintaining documentation",
system_prompt="""You are a documentation specialist AI assistant. Your primary responsibilities:
- Write clear, comprehensive documentation
- Create API references and user guides
- Document code with comments and docstrings
- Organize and structure information logically
- Ensure documentation is up-to-date and accurate
Focus on clarity, completeness, and user-friendliness.''',
Focus on clarity, completeness, and user-friendliness.""",
allowed_tools={
'read_file', 'write_file', 'list_directory', 'index_directory',
'http_fetch', 'web_search'
"read_file",
"write_file",
"list_directory",
"index_directory",
"http_fetch",
"web_search",
},
specialization_areas=['technical_writing', 'documentation_organization', 'user_guides'],
temperature=0.6
specialization_areas=[
"technical_writing",
"documentation_organization",
"user_guides",
],
temperature=0.6,
),
'orchestrator': AgentRole(
name='orchestrator',
description='Coordinates multiple agents and manages overall execution',
system_prompt='''You are an orchestrator AI assistant. Your primary responsibilities:
"orchestrator": AgentRole(
name="orchestrator",
description="Coordinates multiple agents and manages overall execution",
system_prompt="""You are an orchestrator AI assistant. Your primary responsibilities:
- Coordinate multiple specialized agents
- Delegate tasks to appropriate agents
- Integrate results from different agents
- Manage overall workflow execution
- Ensure task completion and quality
Focus on effective delegation, integration, and overall success.''',
Focus on effective delegation, integration, and overall success.""",
allowed_tools={
'read_file', 'write_file', 'list_directory', 'db_set', 'db_get', 'db_query'
"read_file",
"write_file",
"list_directory",
"db_set",
"db_get",
"db_query",
},
specialization_areas=['agent_coordination', 'task_delegation', 'result_integration'],
temperature=0.5
specialization_areas=[
"agent_coordination",
"task_delegation",
"result_integration",
],
temperature=0.5,
),
'general': AgentRole(
name='general',
description='General purpose agent for miscellaneous tasks',
system_prompt='''You are a general purpose AI assistant. Your responsibilities:
"general": AgentRole(
name="general",
description="General purpose agent for miscellaneous tasks",
system_prompt="""You are a general purpose AI assistant. Your responsibilities:
- Handle diverse tasks across multiple domains
- Provide balanced assistance for various needs
- Adapt to different types of requests
- Collaborate with specialized agents when needed
Focus on versatility, helpfulness, and task completion.''',
Focus on versatility, helpfulness, and task completion.""",
allowed_tools={
'read_file', 'write_file', 'list_directory', 'create_directory',
'change_directory', 'get_current_directory', 'python_exec',
'run_command', 'run_command_interactive', 'http_fetch',
'web_search', 'web_search_news', 'db_set', 'db_get', 'db_query',
'index_directory'
"read_file",
"write_file",
"list_directory",
"create_directory",
"change_directory",
"get_current_directory",
"python_exec",
"run_command",
"run_command_interactive",
"http_fetch",
"web_search",
"web_search_news",
"db_set",
"db_get",
"db_query",
"index_directory",
},
specialization_areas=['general_assistance'],
temperature=0.7
)
specialization_areas=["general_assistance"],
temperature=0.7,
),
}
def get_agent_role(role_name: str) -> AgentRole:
return AGENT_ROLES.get(role_name, AGENT_ROLES['general'])
return AGENT_ROLES.get(role_name, AGENT_ROLES["general"])
def list_agent_roles() -> Dict[str, AgentRole]:
return AGENT_ROLES.copy()
def get_recommended_agent(task_description: str) -> str:
task_lower = task_description.lower()
code_keywords = ['code', 'implement', 'function', 'class', 'bug', 'debug', 'refactor', 'optimize']
research_keywords = ['search', 'find', 'research', 'information', 'analyze', 'investigate']
data_keywords = ['data', 'database', 'query', 'statistics', 'analyze', 'process']
planning_keywords = ['plan', 'organize', 'workflow', 'steps', 'coordinate']
testing_keywords = ['test', 'verify', 'validate', 'check', 'quality']
doc_keywords = ['document', 'documentation', 'explain', 'guide', 'manual']
code_keywords = [
"code",
"implement",
"function",
"class",
"bug",
"debug",
"refactor",
"optimize",
]
research_keywords = [
"search",
"find",
"research",
"information",
"analyze",
"investigate",
]
data_keywords = ["data", "database", "query", "statistics", "analyze", "process"]
planning_keywords = ["plan", "organize", "workflow", "steps", "coordinate"]
testing_keywords = ["test", "verify", "validate", "check", "quality"]
doc_keywords = ["document", "documentation", "explain", "guide", "manual"]
if any(keyword in task_lower for keyword in code_keywords):
return 'coding'
return "coding"
elif any(keyword in task_lower for keyword in research_keywords):
return 'research'
return "research"
elif any(keyword in task_lower for keyword in data_keywords):
return 'data_analysis'
return "data_analysis"
elif any(keyword in task_lower for keyword in planning_keywords):
return 'planning'
return "planning"
elif any(keyword in task_lower for keyword in testing_keywords):
return 'testing'
return "testing"
elif any(keyword in task_lower for keyword in doc_keywords):
return 'documentation'
return "documentation"
else:
return 'general'
return "general"

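Putting the keyword routing together. This is grounded in the table above; note that code keywords are checked first, so a task mentioning both "debug" and "test" routes to the coding role:

```python
task = "Debug the failing unit test in parser.py"
role_name = get_recommended_agent(task)  # "debug" matches code_keywords first
role = get_agent_role(role_name)

print(role.name)                             # coding
print(role.temperature)                      # 0.3
print("python_exec" in role.allowed_tools)   # True
```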
View File

@@ -1,4 +1,4 @@
from pr.autonomous.detection import is_task_complete
from pr.autonomous.mode import run_autonomous_mode, process_response_autonomous
from pr.autonomous.mode import process_response_autonomous, run_autonomous_mode
__all__ = ['is_task_complete', 'run_autonomous_mode', 'process_response_autonomous']
__all__ = ["is_task_complete", "run_autonomous_mode", "process_response_autonomous"]

View File

@@ -1,28 +1,39 @@
from pr.config import MAX_AUTONOMOUS_ITERATIONS
from pr.ui import Colors
def is_task_complete(response, iteration):
if 'error' in response:
if "error" in response:
return True
if 'choices' not in response or not response['choices']:
if "choices" not in response or not response["choices"]:
return True
message = response['choices'][0]['message']
content = message.get('content', '').lower()
message = response["choices"][0]["message"]
content = message.get("content", "").lower()
completion_keywords = [
'task complete', 'task is complete', 'finished', 'done',
'successfully completed', 'task accomplished', 'all done',
'implementation complete', 'setup complete', 'installation complete'
"task complete",
"task is complete",
"finished",
"done",
"successfully completed",
"task accomplished",
"all done",
"implementation complete",
"setup complete",
"installation complete",
]
error_keywords = [
'cannot proceed', 'unable to continue', 'fatal error',
'cannot complete', 'impossible to'
"cannot proceed",
"unable to continue",
"fatal error",
"cannot complete",
"impossible to",
]
has_tool_calls = 'tool_calls' in message and message['tool_calls']
has_tool_calls = "tool_calls" in message and message["tool_calls"]
mentions_completion = any(keyword in content for keyword in completion_keywords)
mentions_error = any(keyword in content for keyword in error_keywords)

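The early exits above are the guaranteed behaviour; the final rule combining `has_tool_calls` with the keyword flags falls outside this hunk. A sketch of the visible cases:

```python
# Both short-circuit to True before any keyword matching.
assert is_task_complete({"error": "boom"}, iteration=1)
assert is_task_complete({"choices": []}, iteration=1)

# A normal message reaches the keyword checks; the combining rule is not
# shown in this hunk, so no assertion is made about the result here.
response = {"choices": [{"message": {"content": "Setup complete."}}]}
done = is_task_complete(response, iteration=2)
```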
View File

@@ -1,11 +1,13 @@
import time
import json
import logging
from pr.ui import Colors, display_tool_call, print_autonomous_header
import time
from pr.autonomous.detection import is_task_complete
from pr.core.context import truncate_tool_result
from pr.ui import Colors, display_tool_call
logger = logging.getLogger("pr")
logger = logging.getLogger('pr')
def run_autonomous_mode(assistant, task):
assistant.autonomous_mode = True
@@ -14,25 +16,32 @@ def run_autonomous_mode(assistant, task):
logger.debug(f"=== AUTONOMOUS MODE START ===")
logger.debug(f"Task: {task}")
assistant.messages.append({
"role": "user",
"content": f"{task}"
})
assistant.messages.append({"role": "user", "content": f"{task}"})
try:
while True:
assistant.autonomous_iterations += 1
logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---")
logger.debug(f"Messages before context management: {len(assistant.messages)}")
logger.debug(
f"--- Autonomous iteration {assistant.autonomous_iterations} ---"
)
logger.debug(
f"Messages before context management: {len(assistant.messages)}"
)
from pr.core.context import manage_context_window
assistant.messages = manage_context_window(assistant.messages, assistant.verbose)
logger.debug(f"Messages after context management: {len(assistant.messages)}")
assistant.messages = manage_context_window(
assistant.messages, assistant.verbose
)
logger.debug(
f"Messages after context management: {len(assistant.messages)}"
)
from pr.core.api import call_api
from pr.tools.base import get_tools_definition
response = call_api(
assistant.messages,
assistant.model,
@@ -40,10 +49,10 @@ def run_autonomous_mode(assistant, task):
assistant.api_key,
assistant.use_tools,
get_tools_definition(),
verbose=assistant.verbose
verbose=assistant.verbose,
)
if 'error' in response:
if "error" in response:
logger.error(f"API error in autonomous mode: {response['error']}")
print(f"{Colors.RED}Error: {response['error']}{Colors.RESET}")
break
@@ -74,22 +83,23 @@ def run_autonomous_mode(assistant, task):
assistant.autonomous_mode = False
logger.debug("=== AUTONOMOUS MODE END ===")
def process_response_autonomous(assistant, response):
if 'error' in response:
if "error" in response:
return f"Error: {response['error']}"
if 'choices' not in response or not response['choices']:
if "choices" not in response or not response["choices"]:
return "No response from API"
message = response['choices'][0]['message']
message = response["choices"][0]["message"]
assistant.messages.append(message)
if 'tool_calls' in message and message['tool_calls']:
if "tool_calls" in message and message["tool_calls"]:
tool_results = []
for tool_call in message['tool_calls']:
func_name = tool_call['function']['name']
arguments = json.loads(tool_call['function']['arguments'])
for tool_call in message["tool_calls"]:
func_name = tool_call["function"]["name"]
arguments = json.loads(tool_call["function"]["arguments"])
result = execute_single_tool(assistant, func_name, arguments)
result = truncate_tool_result(result)
@@ -97,16 +107,19 @@ def process_response_autonomous(assistant, response):
status = "success" if result.get("status") == "success" else "error"
display_tool_call(func_name, arguments, status, result)
tool_results.append({
"tool_call_id": tool_call['id'],
"role": "tool",
"content": json.dumps(result)
})
tool_results.append(
{
"tool_call_id": tool_call["id"],
"role": "tool",
"content": json.dumps(result),
}
)
for result in tool_results:
assistant.messages.append(result)
from pr.core.api import call_api
from pr.tools.base import get_tools_definition
follow_up = call_api(
assistant.messages,
assistant.model,
@@ -114,59 +127,88 @@ def process_response_autonomous(assistant, response):
assistant.api_key,
assistant.use_tools,
get_tools_definition(),
verbose=assistant.verbose
verbose=assistant.verbose,
)
return process_response_autonomous(assistant, follow_up)
content = message.get('content', '')
content = message.get("content", "")
from pr.ui import render_markdown
return render_markdown(content, assistant.syntax_highlighting)
def execute_single_tool(assistant, func_name, arguments):
logger.debug(f"Executing tool in autonomous mode: {func_name}")
logger.debug(f"Tool arguments: {arguments}")
from pr.tools import (
http_fetch, run_command, run_command_interactive, read_file, write_file,
list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query,
web_search, web_search_news, python_exec, index_source_directory,
search_replace, open_editor, editor_insert_text, editor_replace_text,
editor_search, close_editor, create_diff, apply_patch, tail_process, kill_process
apply_patch,
chdir,
close_editor,
create_diff,
db_get,
db_query,
db_set,
editor_insert_text,
editor_replace_text,
editor_search,
getpwd,
http_fetch,
index_source_directory,
kill_process,
list_directory,
mkdir,
open_editor,
python_exec,
read_file,
run_command,
run_command_interactive,
search_replace,
tail_process,
web_search,
web_search_news,
write_file,
)
from pr.tools.filesystem import (
clear_edit_tracker,
display_edit_summary,
display_edit_timeline,
)
from pr.tools.patch import display_file_diff
from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker
func_map = {
'http_fetch': lambda **kw: http_fetch(**kw),
'run_command': lambda **kw: run_command(**kw),
'tail_process': lambda **kw: tail_process(**kw),
'kill_process': lambda **kw: kill_process(**kw),
'run_command_interactive': lambda **kw: run_command_interactive(**kw),
'read_file': lambda **kw: read_file(**kw),
'write_file': lambda **kw: write_file(**kw, db_conn=assistant.db_conn),
'list_directory': lambda **kw: list_directory(**kw),
'mkdir': lambda **kw: mkdir(**kw),
'chdir': lambda **kw: chdir(**kw),
'getpwd': lambda **kw: getpwd(**kw),
'db_set': lambda **kw: db_set(**kw, db_conn=assistant.db_conn),
'db_get': lambda **kw: db_get(**kw, db_conn=assistant.db_conn),
'db_query': lambda **kw: db_query(**kw, db_conn=assistant.db_conn),
'web_search': lambda **kw: web_search(**kw),
'web_search_news': lambda **kw: web_search_news(**kw),
'python_exec': lambda **kw: python_exec(**kw, python_globals=assistant.python_globals),
'index_source_directory': lambda **kw: index_source_directory(**kw),
'search_replace': lambda **kw: search_replace(**kw),
'open_editor': lambda **kw: open_editor(**kw),
'editor_insert_text': lambda **kw: editor_insert_text(**kw),
'editor_replace_text': lambda **kw: editor_replace_text(**kw),
'editor_search': lambda **kw: editor_search(**kw),
'close_editor': lambda **kw: close_editor(**kw),
'create_diff': lambda **kw: create_diff(**kw),
'apply_patch': lambda **kw: apply_patch(**kw),
'display_file_diff': lambda **kw: display_file_diff(**kw),
'display_edit_summary': lambda **kw: display_edit_summary(),
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
"http_fetch": lambda **kw: http_fetch(**kw),
"run_command": lambda **kw: run_command(**kw),
"tail_process": lambda **kw: tail_process(**kw),
"kill_process": lambda **kw: kill_process(**kw),
"run_command_interactive": lambda **kw: run_command_interactive(**kw),
"read_file": lambda **kw: read_file(**kw),
"write_file": lambda **kw: write_file(**kw, db_conn=assistant.db_conn),
"list_directory": lambda **kw: list_directory(**kw),
"mkdir": lambda **kw: mkdir(**kw),
"chdir": lambda **kw: chdir(**kw),
"getpwd": lambda **kw: getpwd(**kw),
"db_set": lambda **kw: db_set(**kw, db_conn=assistant.db_conn),
"db_get": lambda **kw: db_get(**kw, db_conn=assistant.db_conn),
"db_query": lambda **kw: db_query(**kw, db_conn=assistant.db_conn),
"web_search": lambda **kw: web_search(**kw),
"web_search_news": lambda **kw: web_search_news(**kw),
"python_exec": lambda **kw: python_exec(
**kw, python_globals=assistant.python_globals
),
"index_source_directory": lambda **kw: index_source_directory(**kw),
"search_replace": lambda **kw: search_replace(**kw),
"open_editor": lambda **kw: open_editor(**kw),
"editor_insert_text": lambda **kw: editor_insert_text(**kw),
"editor_replace_text": lambda **kw: editor_replace_text(**kw),
"editor_search": lambda **kw: editor_search(**kw),
"close_editor": lambda **kw: close_editor(**kw),
"create_diff": lambda **kw: create_diff(**kw),
"apply_patch": lambda **kw: apply_patch(**kw),
"display_file_diff": lambda **kw: display_file_diff(**kw),
"display_edit_summary": lambda **kw: display_edit_summary(),
"display_edit_timeline": lambda **kw: display_edit_timeline(**kw),
"clear_edit_tracker": lambda **kw: clear_edit_tracker(),
}
if func_name in func_map:

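The hunk ends at the dispatch check. A typical continuation would look like the sketch below; this is an assumption, not the commit's actual body, with the `status`/`error` result shape borrowed from the tool-result handling earlier in this file:

```python
    if func_name in func_map:
        try:
            # Each entry is a lambda forwarding keyword arguments, so the
            # parsed tool arguments can be splatted straight through.
            return func_map[func_name](**arguments)
        except Exception as e:
            return {"status": "error", "error": str(e)}
    return {"status": "error", "error": f"Unknown tool: {func_name}"}
```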
View File

@@ -1,4 +1,4 @@
from .api_cache import APICache
from .tool_cache import ToolCache
__all__ = ['APICache', 'ToolCache']
__all__ = ["APICache", "ToolCache"]

pr/cache/api_cache.py (vendored, 86 lines changed)
View File

@@ -2,7 +2,8 @@ import hashlib
import json
import sqlite3
import time
from typing import Optional, Dict, Any
from typing import Any, Dict, Optional
class APICache:
def __init__(self, db_path: str, ttl_seconds: int = 3600):
@@ -13,7 +14,8 @@ class APICache:
def _initialize_cache(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS api_cache (
cache_key TEXT PRIMARY KEY,
response_data TEXT NOT NULL,
@@ -22,34 +24,44 @@ class APICache:
model TEXT,
token_count INTEGER
)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at)
''')
"""
)
conn.commit()
conn.close()
def _generate_cache_key(self, model: str, messages: list, temperature: float, max_tokens: int) -> str:
def _generate_cache_key(
self, model: str, messages: list, temperature: float, max_tokens: int
) -> str:
cache_data = {
'model': model,
'messages': messages,
'temperature': temperature,
'max_tokens': max_tokens
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
}
serialized = json.dumps(cache_data, sort_keys=True)
return hashlib.sha256(serialized.encode()).hexdigest()
def get(self, model: str, messages: list, temperature: float, max_tokens: int) -> Optional[Dict[str, Any]]:
def get(
self, model: str, messages: list, temperature: float, max_tokens: int
) -> Optional[Dict[str, Any]]:
cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
current_time = int(time.time())
cursor.execute('''
cursor.execute(
"""
SELECT response_data FROM api_cache
WHERE cache_key = ? AND expires_at > ?
''', (cache_key, current_time))
""",
(cache_key, current_time),
)
row = cursor.fetchone()
conn.close()
@@ -58,8 +70,15 @@ class APICache:
return json.loads(row[0])
return None
def set(self, model: str, messages: list, temperature: float, max_tokens: int,
response: Dict[str, Any], token_count: int = 0):
def set(
self,
model: str,
messages: list,
temperature: float,
max_tokens: int,
response: Dict[str, Any],
token_count: int = 0,
):
cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
current_time = int(time.time())
@@ -68,11 +87,21 @@ class APICache:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
INSERT OR REPLACE INTO api_cache
(cache_key, response_data, created_at, expires_at, model, token_count)
VALUES (?, ?, ?, ?, ?, ?)
''', (cache_key, json.dumps(response), current_time, expires_at, model, token_count))
""",
(
cache_key,
json.dumps(response),
current_time,
expires_at,
model,
token_count,
),
)
conn.commit()
conn.close()
@@ -83,7 +112,7 @@ class APICache:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM api_cache WHERE expires_at <= ?', (current_time,))
cursor.execute("DELETE FROM api_cache WHERE expires_at <= ?", (current_time,))
deleted_count = cursor.rowcount
conn.commit()
@@ -95,7 +124,7 @@ class APICache:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM api_cache')
cursor.execute("DELETE FROM api_cache")
deleted_count = cursor.rowcount
conn.commit()
@@ -107,21 +136,26 @@ class APICache:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM api_cache')
cursor.execute("SELECT COUNT(*) FROM api_cache")
total_entries = cursor.fetchone()[0]
current_time = int(time.time())
cursor.execute('SELECT COUNT(*) FROM api_cache WHERE expires_at > ?', (current_time,))
cursor.execute(
"SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,)
)
valid_entries = cursor.fetchone()[0]
cursor.execute('SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?', (current_time,))
cursor.execute(
"SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?",
(current_time,),
)
total_tokens = cursor.fetchone()[0] or 0
conn.close()
return {
'total_entries': total_entries,
'valid_entries': valid_entries,
'expired_entries': total_entries - valid_entries,
'total_cached_tokens': total_tokens
"total_entries": total_entries,
"valid_entries": valid_entries,
"expired_entries": total_entries - valid_entries,
"total_cached_tokens": total_tokens,
}

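Exercising the cache round trip. The db path and payload are placeholders; the positional order matches the `get`/`set` signatures above:

```python
cache = APICache("api_cache.db", ttl_seconds=3600)
messages = [{"role": "user", "content": "ping"}]

cached = cache.get("gpt-4", messages, 0.7, 256)
if cached is None:
    response = {"choices": [{"message": {"content": "pong"}}]}  # stub payload
    cache.set("gpt-4", messages, 0.7, 256, response, token_count=12)
    cached = cache.get("gpt-4", messages, 0.7, 256)

print(cached["choices"][0]["message"]["content"])  # pong
```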
View File

@@ -2,16 +2,17 @@ import hashlib
import json
import sqlite3
import time
from typing import Optional, Any, Set
from typing import Any, Optional, Set
class ToolCache:
DETERMINISTIC_TOOLS: Set[str] = {
'read_file',
'list_directory',
'get_current_directory',
'db_get',
'db_query',
'index_directory'
"read_file",
"list_directory",
"get_current_directory",
"db_get",
"db_query",
"index_directory",
}
def __init__(self, db_path: str, ttl_seconds: int = 300):
@@ -22,7 +23,8 @@ class ToolCache:
def _initialize_cache(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS tool_cache (
cache_key TEXT PRIMARY KEY,
tool_name TEXT NOT NULL,
@@ -31,21 +33,23 @@ class ToolCache:
expires_at INTEGER NOT NULL,
hit_count INTEGER DEFAULT 0
)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_tool_expires ON tool_cache(expires_at)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_cache(tool_name)
''')
"""
)
conn.commit()
conn.close()
def _generate_cache_key(self, tool_name: str, arguments: dict) -> str:
cache_data = {
'tool': tool_name,
'args': arguments
}
cache_data = {"tool": tool_name, "args": arguments}
serialized = json.dumps(cache_data, sort_keys=True)
return hashlib.sha256(serialized.encode()).hexdigest()
@@ -62,18 +66,24 @@ class ToolCache:
cursor = conn.cursor()
current_time = int(time.time())
cursor.execute('''
cursor.execute(
"""
SELECT result_data, hit_count FROM tool_cache
WHERE cache_key = ? AND expires_at > ?
''', (cache_key, current_time))
""",
(cache_key, current_time),
)
row = cursor.fetchone()
if row:
cursor.execute('''
cursor.execute(
"""
UPDATE tool_cache SET hit_count = hit_count + 1
WHERE cache_key = ?
''', (cache_key,))
""",
(cache_key,),
)
conn.commit()
conn.close()
return json.loads(row[0])
@@ -93,11 +103,14 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
INSERT OR REPLACE INTO tool_cache
(cache_key, tool_name, result_data, created_at, expires_at, hit_count)
VALUES (?, ?, ?, ?, ?, 0)
''', (cache_key, tool_name, json.dumps(result), current_time, expires_at))
""",
(cache_key, tool_name, json.dumps(result), current_time, expires_at),
)
conn.commit()
conn.close()
@@ -106,7 +119,7 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM tool_cache WHERE tool_name = ?', (tool_name,))
cursor.execute("DELETE FROM tool_cache WHERE tool_name = ?", (tool_name,))
deleted_count = cursor.rowcount
conn.commit()
@@ -120,7 +133,7 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM tool_cache WHERE expires_at <= ?', (current_time,))
cursor.execute("DELETE FROM tool_cache WHERE expires_at <= ?", (current_time,))
deleted_count = cursor.rowcount
conn.commit()
@@ -132,7 +145,7 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM tool_cache')
cursor.execute("DELETE FROM tool_cache")
deleted_count = cursor.rowcount
conn.commit()
@@ -144,36 +157,41 @@ class ToolCache:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM tool_cache')
cursor.execute("SELECT COUNT(*) FROM tool_cache")
total_entries = cursor.fetchone()[0]
current_time = int(time.time())
cursor.execute('SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?', (current_time,))
cursor.execute(
"SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,)
)
valid_entries = cursor.fetchone()[0]
cursor.execute('SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?', (current_time,))
cursor.execute(
"SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?",
(current_time,),
)
total_hits = cursor.fetchone()[0] or 0
cursor.execute('''
cursor.execute(
"""
SELECT tool_name, COUNT(*), SUM(hit_count)
FROM tool_cache
WHERE expires_at > ?
GROUP BY tool_name
''', (current_time,))
""",
(current_time,),
)
tool_stats = {}
for row in cursor.fetchall():
tool_stats[row[0]] = {
'cached_entries': row[1],
'total_hits': row[2] or 0
}
tool_stats[row[0]] = {"cached_entries": row[1], "total_hits": row[2] or 0}
conn.close()
return {
'total_entries': total_entries,
'valid_entries': valid_entries,
'expired_entries': total_entries - valid_entries,
'total_cache_hits': total_hits,
'by_tool': tool_stats
"total_entries": total_entries,
"valid_entries": valid_entries,
"expired_entries": total_entries - valid_entries,
"total_cache_hits": total_hits,
"by_tool": tool_stats,
}

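Only the tools listed in `DETERMINISTIC_TOOLS` are safe to cache. The public get/set entry points are not fully visible in this hunk, so the sketch below sticks to the constructor and the key scheme:

```python
tool_cache = ToolCache("tool_cache.db", ttl_seconds=300)  # hypothetical path

tool_name, arguments = "read_file", {"filepath": "notes.txt"}
if tool_name in ToolCache.DETERMINISTIC_TOOLS:
    # SHA-256 over the canonical {"tool": ..., "args": ...} JSON, as above.
    key = tool_cache._generate_cache_key(tool_name, arguments)
    print(key[:16])
```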
View File

@@ -1,3 +1,3 @@
from pr.commands.handlers import handle_command
__all__ = ['handle_command']
__all__ = ["handle_command"]

View File

@@ -1,30 +1,35 @@
import json
import time
from pr.ui import Colors
from pr.autonomous import run_autonomous_mode
from pr.core.api import list_models
from pr.tools import read_file
from pr.tools.base import get_tools_definition
from pr.core.api import list_models
from pr.autonomous import run_autonomous_mode
from pr.ui import Colors
def handle_command(assistant, command):
command_parts = command.strip().split(maxsplit=1)
cmd = command_parts[0].lower()
if cmd == '/auto':
if cmd == "/auto":
if len(command_parts) < 2:
print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}")
print(f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}")
print(
f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}"
)
return True
task = command_parts[1]
run_autonomous_mode(assistant, task)
return True
if cmd in ['exit', 'quit', 'q']:
if cmd in ["exit", "quit", "q"]:
return False
elif cmd == 'help':
print(f"""
elif cmd == "help":
print(
f"""
{Colors.BOLD}Available Commands:{Colors.RESET}
{Colors.BOLD}Basic:{Colors.RESET}
@@ -54,18 +59,21 @@ def handle_command(assistant, command):
{Colors.CYAN}/cache{Colors.RESET} - Show cache statistics
{Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches
{Colors.CYAN}/stats{Colors.RESET} - Show system statistics
""")
"""
)
elif cmd == '/reset':
elif cmd == "/reset":
assistant.messages = assistant.messages[:1]
print(f"{Colors.GREEN}Message history cleared{Colors.RESET}")
elif cmd == '/dump':
elif cmd == "/dump":
print(json.dumps(assistant.messages, indent=2))
elif cmd == '/verbose':
elif cmd == "/verbose":
assistant.verbose = not assistant.verbose
print(f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}")
print(
f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}"
)
elif cmd.startswith("/model"):
if len(command_parts) < 2:
@@ -74,77 +82,81 @@ def handle_command(assistant, command):
assistant.model = command_parts[1]
print(f"Model set to: {Colors.GREEN}{assistant.model}{Colors.RESET}")
elif cmd == '/models':
elif cmd == "/models":
models = list_models(assistant.model_list_url, assistant.api_key)
if isinstance(models, dict) and 'error' in models:
if isinstance(models, dict) and "error" in models:
print(f"{Colors.RED}Error fetching models: {models['error']}{Colors.RESET}")
else:
print(f"{Colors.BOLD}Available Models:{Colors.RESET}")
for model in models:
print(f"{Colors.CYAN}{model['id']}{Colors.RESET}")
elif cmd == '/tools':
elif cmd == "/tools":
print(f"{Colors.BOLD}Available Tools:{Colors.RESET}")
for tool in get_tools_definition():
func = tool['function']
print(f"{Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}")
func = tool["function"]
print(
f"{Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}"
)
elif cmd == '/review' and len(command_parts) > 1:
elif cmd == "/review" and len(command_parts) > 1:
filename = command_parts[1]
review_file(assistant, filename)
elif cmd == '/refactor' and len(command_parts) > 1:
elif cmd == "/refactor" and len(command_parts) > 1:
filename = command_parts[1]
refactor_file(assistant, filename)
elif cmd == '/obfuscate' and len(command_parts) > 1:
elif cmd == "/obfuscate" and len(command_parts) > 1:
filename = command_parts[1]
obfuscate_file(assistant, filename)
elif cmd == '/workflows':
elif cmd == "/workflows":
show_workflows(assistant)
elif cmd == '/workflow' and len(command_parts) > 1:
elif cmd == "/workflow" and len(command_parts) > 1:
workflow_name = command_parts[1]
execute_workflow_command(assistant, workflow_name)
elif cmd == '/agent' and len(command_parts) > 1:
elif cmd == "/agent" and len(command_parts) > 1:
args = command_parts[1].split(maxsplit=1)
if len(args) < 2:
print(f"{Colors.RED}Usage: /agent <role> <task>{Colors.RESET}")
print(f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}")
print(
f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}"
)
else:
role, task = args[0], args[1]
execute_agent_task(assistant, role, task)
elif cmd == '/agents':
elif cmd == "/agents":
show_agents(assistant)
elif cmd == '/collaborate' and len(command_parts) > 1:
elif cmd == "/collaborate" and len(command_parts) > 1:
task = command_parts[1]
collaborate_agents_command(assistant, task)
elif cmd == '/knowledge' and len(command_parts) > 1:
elif cmd == "/knowledge" and len(command_parts) > 1:
query = command_parts[1]
search_knowledge(assistant, query)
elif cmd == '/remember' and len(command_parts) > 1:
elif cmd == "/remember" and len(command_parts) > 1:
content = command_parts[1]
store_knowledge(assistant, content)
elif cmd == '/history':
elif cmd == "/history":
show_conversation_history(assistant)
elif cmd == '/cache':
if len(command_parts) > 1 and command_parts[1].lower() == 'clear':
elif cmd == "/cache":
if len(command_parts) > 1 and command_parts[1].lower() == "clear":
clear_caches(assistant)
else:
show_cache_stats(assistant)
elif cmd == '/stats':
elif cmd == "/stats":
show_system_stats(assistant)
elif cmd.startswith('/bg'):
elif cmd.startswith("/bg"):
handle_background_command(assistant, command)
else:
@ -152,35 +164,46 @@ def handle_command(assistant, command):
return True
def review_file(assistant, filename):
result = read_file(filename)
if result['status'] == 'success':
message = f"Please review this file and provide feedback:\n\n{result['content']}"
if result["status"] == "success":
message = (
f"Please review this file and provide feedback:\n\n{result['content']}"
)
from pr.core.assistant import process_message
process_message(assistant, message)
else:
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
def refactor_file(assistant, filename):
result = read_file(filename)
if result['status'] == 'success':
message = f"Please refactor this code to improve its quality:\n\n{result['content']}"
if result["status"] == "success":
message = (
f"Please refactor this code to improve its quality:\n\n{result['content']}"
)
from pr.core.assistant import process_message
process_message(assistant, message)
else:
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
def obfuscate_file(assistant, filename):
result = read_file(filename)
if result['status'] == 'success':
if result["status"] == "success":
message = f"Please obfuscate this code:\n\n{result['content']}"
from pr.core.assistant import process_message
process_message(assistant, message)
else:
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
def show_workflows(assistant):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
@ -194,23 +217,25 @@ def show_workflows(assistant):
print(f"{Colors.CYAN}{wf['name']}{Colors.RESET}: {wf['description']}")
print(f" Executions: {wf['execution_count']}")
def execute_workflow_command(assistant, workflow_name):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
print(f"{Colors.YELLOW}Executing workflow: {workflow_name}...{Colors.RESET}")
result = assistant.enhanced.execute_workflow(workflow_name)
if 'error' in result:
if "error" in result:
print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
else:
print(f"{Colors.GREEN}Workflow completed successfully{Colors.RESET}")
print(f"Execution ID: {result['execution_id']}")
print(f"Results: {json.dumps(result['results'], indent=2)}")
def execute_agent_task(assistant, role, task):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
@ -221,14 +246,15 @@ def execute_agent_task(assistant, role, task):
print(f"{Colors.YELLOW}Executing task...{Colors.RESET}")
result = assistant.enhanced.agent_task(agent_id, task)
if 'error' in result:
if "error" in result:
print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
else:
print(f"\n{Colors.GREEN}{role.capitalize()} Agent Response:{Colors.RESET}")
print(result['response'])
print(result["response"])
def show_agents(assistant):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
@ -236,37 +262,39 @@ def show_agents(assistant):
print(f"\n{Colors.BOLD}Agent Session Summary:{Colors.RESET}")
print(f"Active agents: {summary['active_agents']}")
if summary['agents']:
for agent in summary['agents']:
if summary["agents"]:
for agent in summary["agents"]:
print(f"\n{Colors.CYAN}{agent['agent_id']}{Colors.RESET}")
print(f" Role: {agent['role']}")
print(f" Tasks completed: {agent['task_count']}")
print(f" Messages: {agent['message_count']}")
def collaborate_agents_command(assistant, task):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
print(f"{Colors.YELLOW}Initiating agent collaboration...{Colors.RESET}")
roles = ['coding', 'research', 'planning']
roles = ["coding", "research", "planning"]
result = assistant.enhanced.collaborate_agents(task, roles)
print(f"\n{Colors.GREEN}Collaboration completed{Colors.RESET}")
print(f"\nOrchestrator response:")
if 'orchestrator' in result and 'response' in result['orchestrator']:
print(result['orchestrator']['response'])
if "orchestrator" in result and "response" in result["orchestrator"]:
print(result["orchestrator"]["response"])
if result.get('agents'):
if result.get("agents"):
print(f"\n{Colors.BOLD}Agent Results:{Colors.RESET}")
for agent_result in result['agents']:
if 'role' in agent_result:
for agent_result in result["agents"]:
if "role" in agent_result:
print(f"\n{Colors.CYAN}{agent_result['role']}:{Colors.RESET}")
print(agent_result.get('response', 'No response'))
print(agent_result.get("response", "No response"))
def search_knowledge(assistant, query):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
@ -282,13 +310,15 @@ def search_knowledge(assistant, query):
print(f" {entry.content[:200]}...")
print(f" Accessed: {entry.access_count} times")
def store_knowledge(assistant, content):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
import uuid
import time
import uuid
from pr.memory import KnowledgeEntry
categories = assistant.enhanced.fact_extractor.categorize_content(content)
@ -296,11 +326,11 @@ def store_knowledge(assistant, content):
entry = KnowledgeEntry(
entry_id=entry_id,
category=categories[0] if categories else 'general',
category=categories[0] if categories else "general",
content=content,
metadata={'manual_entry': True},
metadata={"manual_entry": True},
created_at=time.time(),
updated_at=time.time()
updated_at=time.time(),
)
assistant.enhanced.knowledge_store.add_entry(entry)
@ -308,8 +338,9 @@ def store_knowledge(assistant, content):
print(f"Entry ID: {entry_id}")
print(f"Category: {entry.category}")
def show_conversation_history(assistant):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
@ -322,17 +353,21 @@ def show_conversation_history(assistant):
print(f"\n{Colors.BOLD}Recent Conversations:{Colors.RESET}")
for conv in history:
import datetime
started = datetime.datetime.fromtimestamp(conv['started_at']).strftime('%Y-%m-%d %H:%M')
started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime(
"%Y-%m-%d %H:%M"
)
print(f"\n{Colors.CYAN}{conv['conversation_id']}{Colors.RESET}")
print(f" Started: {started}")
print(f" Messages: {conv['message_count']}")
if conv.get('summary'):
if conv.get("summary"):
print(f" Summary: {conv['summary'][:100]}...")
if conv.get('topics'):
if conv.get("topics"):
print(f" Topics: {', '.join(conv['topics'])}")
def show_cache_stats(assistant):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
@ -340,36 +375,40 @@ def show_cache_stats(assistant):
print(f"\n{Colors.BOLD}Cache Statistics:{Colors.RESET}")
if 'api_cache' in stats:
api_stats = stats['api_cache']
if "api_cache" in stats:
api_stats = stats["api_cache"]
print(f"\n{Colors.CYAN}API Cache:{Colors.RESET}")
print(f" Total entries: {api_stats['total_entries']}")
print(f" Valid entries: {api_stats['valid_entries']}")
print(f" Expired entries: {api_stats['expired_entries']}")
print(f" Cached tokens: {api_stats['total_cached_tokens']}")
if 'tool_cache' in stats:
tool_stats = stats['tool_cache']
if "tool_cache" in stats:
tool_stats = stats["tool_cache"]
print(f"\n{Colors.CYAN}Tool Cache:{Colors.RESET}")
print(f" Total entries: {tool_stats['total_entries']}")
print(f" Valid entries: {tool_stats['valid_entries']}")
print(f" Total cache hits: {tool_stats['total_cache_hits']}")
if tool_stats.get('by_tool'):
if tool_stats.get("by_tool"):
print(f"\n Per-tool statistics:")
for tool_name, tool_stat in tool_stats['by_tool'].items():
print(f" {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits")
for tool_name, tool_stat in tool_stats["by_tool"].items():
print(
f" {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits"
)
def clear_caches(assistant):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
assistant.enhanced.clear_caches()
print(f"{Colors.GREEN}All caches cleared successfully{Colors.RESET}")
def show_system_stats(assistant):
if not hasattr(assistant, 'enhanced'):
if not hasattr(assistant, "enhanced"):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
@ -388,68 +427,81 @@ def show_system_stats(assistant):
print(f"\n{Colors.CYAN}Active Agents:{Colors.RESET}")
print(f" Count: {agent_summary['active_agents']}")
if 'api_cache' in cache_stats:
if "api_cache" in cache_stats:
print(f"\n{Colors.CYAN}Caching:{Colors.RESET}")
print(f" API cache entries: {cache_stats['api_cache']['valid_entries']}")
if 'tool_cache' in cache_stats:
if "tool_cache" in cache_stats:
print(f" Tool cache entries: {cache_stats['tool_cache']['valid_entries']}")
def handle_background_command(assistant, command):
"""Handle background multiplexer commands."""
parts = command.strip().split(maxsplit=2)
if len(parts) < 2:
print(f"{Colors.RED}Usage: /bg <subcommand> [args]{Colors.RESET}")
print(f"{Colors.GRAY}Available subcommands: start, list, status, output, input, kill, events{Colors.RESET}")
print(
f"{Colors.GRAY}Available subcommands: start, list, status, output, input, kill, events{Colors.RESET}"
)
return
subcmd = parts[1].lower()
try:
if subcmd == 'start' and len(parts) >= 3:
if subcmd == "start" and len(parts) >= 3:
session_name = f"bg_{len(parts[2].split())}_{int(time.time())}"
start_background_session(assistant, session_name, parts[2])
elif subcmd == 'list':
elif subcmd == "list":
list_background_sessions(assistant)
elif subcmd == 'status' and len(parts) >= 3:
elif subcmd == "status" and len(parts) >= 3:
show_session_status(assistant, parts[2])
elif subcmd == 'output' and len(parts) >= 3:
elif subcmd == "output" and len(parts) >= 3:
show_session_output(assistant, parts[2])
elif subcmd == 'input' and len(parts) >= 4:
elif subcmd == "input" and len(parts) >= 4:
send_session_input(assistant, parts[2], parts[3])
elif subcmd == 'kill' and len(parts) >= 3:
elif subcmd == "kill" and len(parts) >= 3:
kill_background_session(assistant, parts[2])
elif subcmd == 'events':
elif subcmd == "events":
show_background_events(assistant)
else:
print(f"{Colors.RED}Unknown background command: {subcmd}{Colors.RESET}")
print(f"{Colors.GRAY}Available: start, list, status, output, input, kill, events{Colors.RESET}")
print(
f"{Colors.GRAY}Available: start, list, status, output, input, kill, events{Colors.RESET}"
)
except Exception as e:
print(f"{Colors.RED}Error executing background command: {e}{Colors.RESET}")
def start_background_session(assistant, session_name, command):
"""Start a command in background."""
try:
from pr.multiplexer import start_background_process
result = start_background_process(session_name, command)
if result['status'] == 'success':
print(f"{Colors.GREEN}Started background session '{session_name}' with PID {result['pid']}{Colors.RESET}")
if result["status"] == "success":
print(
f"{Colors.GREEN}Started background session '{session_name}' with PID {result['pid']}{Colors.RESET}"
)
else:
print(f"{Colors.RED}Failed to start background session: {result.get('error', 'Unknown error')}{Colors.RESET}")
print(
f"{Colors.RED}Failed to start background session: {result.get('error', 'Unknown error')}{Colors.RESET}"
)
except Exception as e:
print(f"{Colors.RED}Error starting background session: {e}{Colors.RESET}")
def list_background_sessions(assistant):
"""List all background sessions."""
try:
from pr.ui.display import display_multiplexer_status
from pr.multiplexer import get_all_sessions
from pr.ui.display import display_multiplexer_status
sessions = get_all_sessions()
display_multiplexer_status(sessions)
except Exception as e:
print(f"{Colors.RED}Error listing background sessions: {e}{Colors.RESET}")
def show_session_status(assistant, session_name):
"""Show status of a specific session."""
try:
@ -461,15 +513,17 @@ def show_session_status(assistant, session_name):
print(f" Status: {info.get('status', 'unknown')}")
print(f" PID: {info.get('pid', 'N/A')}")
print(f" Command: {info.get('command', 'N/A')}")
if 'start_time' in info:
if "start_time" in info:
import time
elapsed = time.time() - info['start_time']
elapsed = time.time() - info["start_time"]
print(f" Running for: {elapsed:.1f}s")
else:
print(f"{Colors.YELLOW}Session '{session_name}' not found{Colors.RESET}")
except Exception as e:
print(f"{Colors.RED}Error getting session status: {e}{Colors.RESET}")
def show_session_output(assistant, session_name):
"""Show output of a specific session."""
try:
@ -482,36 +536,45 @@ def show_session_output(assistant, session_name):
for line in output:
print(line)
else:
print(f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}")
print(
f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}"
)
except Exception as e:
print(f"{Colors.RED}Error getting session output: {e}{Colors.RESET}")
def send_session_input(assistant, session_name, input_text):
"""Send input to a background session."""
try:
from pr.multiplexer import send_input_to_session
result = send_input_to_session(session_name, input_text)
if result['status'] == 'success':
if result["status"] == "success":
print(f"{Colors.GREEN}Input sent to session '{session_name}'{Colors.RESET}")
else:
print(f"{Colors.RED}Failed to send input: {result.get('error', 'Unknown error')}{Colors.RESET}")
print(
f"{Colors.RED}Failed to send input: {result.get('error', 'Unknown error')}{Colors.RESET}"
)
except Exception as e:
print(f"{Colors.RED}Error sending input: {e}{Colors.RESET}")
def kill_background_session(assistant, session_name):
"""Kill a background session."""
try:
from pr.multiplexer import kill_session
result = kill_session(session_name)
if result['status'] == 'success':
if result["status"] == "success":
print(f"{Colors.GREEN}Session '{session_name}' terminated{Colors.RESET}")
else:
print(f"{Colors.RED}Failed to kill session: {result.get('error', 'Unknown error')}{Colors.RESET}")
print(
f"{Colors.RED}Failed to kill session: {result.get('error', 'Unknown error')}{Colors.RESET}"
)
except Exception as e:
print(f"{Colors.RED}Error killing session: {e}{Colors.RESET}")
def show_background_events(assistant):
"""Show recent background events."""
try:
@ -526,6 +589,7 @@ def show_background_events(assistant):
for event in events[-10:]: # Show last 10 events
from pr.ui.display import display_background_event
display_background_event(event)
else:
print(f"{Colors.GRAY}No recent background events{Colors.RESET}")

View File

@ -1,11 +1,15 @@
from pr.tools.interactive_control import (
list_active_sessions, get_session_status, read_session_output,
send_input_to_session, close_interactive_session
)
from pr.multiplexer import get_multiplexer
from pr.tools.interactive_control import (
close_interactive_session,
get_session_status,
list_active_sessions,
read_session_output,
send_input_to_session,
)
from pr.tools.prompt_detection import get_global_detector
from pr.ui import Colors
def show_sessions(args=None):
"""Show all active multiplexer sessions."""
sessions = list_active_sessions()
@ -18,24 +22,29 @@ def show_sessions(args=None):
print("-" * 80)
for session_name, session_data in sessions.items():
metadata = session_data['metadata']
output_summary = session_data['output_summary']
metadata = session_data["metadata"]
output_summary = session_data["output_summary"]
status = get_session_status(session_name)
is_active = status.get('is_active', False) if status else False
is_active = status.get("is_active", False) if status else False
status_color = Colors.GREEN if is_active else Colors.RED
print(f"{Colors.CYAN}{session_name}{Colors.RESET}: {status_color}{metadata.get('process_type', 'unknown')}{Colors.RESET}")
print(
f"{Colors.CYAN}{session_name}{Colors.RESET}: {status_color}{metadata.get('process_type', 'unknown')}{Colors.RESET}"
)
if status and 'pid' in status:
if status and "pid" in status:
print(f" PID: {status['pid']}")
print(f" Age: {metadata.get('start_time', 0):.1f}s")
print(f" Output: {output_summary['stdout_lines']} stdout, {output_summary['stderr_lines']} stderr lines")
print(
f" Output: {output_summary['stdout_lines']} stdout, {output_summary['stderr_lines']} stderr lines"
)
print(f" Interactions: {metadata.get('interaction_count', 0)}")
print(f" State: {metadata.get('state', 'unknown')}")
print()
def attach_session(args):
"""Attach to a session (show its output and allow interaction)."""
if not args or len(args) < 1:
@ -56,20 +65,23 @@ def attach_session(args):
# Show recent output
try:
output = read_session_output(session_name, lines=20)
if output['stdout']:
if output["stdout"]:
print(f"{Colors.GRAY}Recent stdout:{Colors.RESET}")
for line in output['stdout'].split('\n'):
for line in output["stdout"].split("\n"):
if line.strip():
print(f" {line}")
if output['stderr']:
if output["stderr"]:
print(f"{Colors.YELLOW}Recent stderr:{Colors.RESET}")
for line in output['stderr'].split('\n'):
for line in output["stderr"].split("\n"):
if line.strip():
print(f" {line}")
except Exception as e:
print(f"{Colors.RED}Error reading output: {e}{Colors.RESET}")
print(f"\n{Colors.CYAN}Session is {'active' if status.get('is_active') else 'inactive'}{Colors.RESET}")
print(
f"\n{Colors.CYAN}Session is {'active' if status.get('is_active') else 'inactive'}{Colors.RESET}"
)
def detach_session(args):
"""Detach from a session (stop showing its output but keep it running)."""
@ -87,7 +99,10 @@ def detach_session(args):
# In this implementation, detaching just means we stop displaying output
# The session continues to run in the background
mux.show_output = False
print(f"{Colors.GREEN}Detached from session '{session_name}'. It continues running in background.{Colors.RESET}")
print(
f"{Colors.GREEN}Detached from session '{session_name}'. It continues running in background.{Colors.RESET}"
)
def kill_session(args):
"""Kill a session forcefully."""
@ -101,7 +116,10 @@ def kill_session(args):
close_interactive_session(session_name)
print(f"{Colors.GREEN}Session '{session_name}' terminated.{Colors.RESET}")
except Exception as e:
print(f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}")
print(
f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}"
)
def send_command(args):
"""Send a command to a session."""
@ -110,13 +128,18 @@ def send_command(args):
return
session_name = args[0]
command = ' '.join(args[1:])
command = " ".join(args[1:])
try:
send_input_to_session(session_name, command)
print(f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}")
print(
f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}"
)
except Exception as e:
print(f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}")
print(
f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}"
)
def show_session_log(args):
"""Show the full log/output of a session."""
@ -131,19 +154,20 @@ def show_session_log(args):
print(f"{Colors.BOLD}Full log for session: {session_name}{Colors.RESET}")
print("=" * 80)
if output['stdout']:
if output["stdout"]:
print(f"{Colors.GRAY}STDOUT:{Colors.RESET}")
print(output['stdout'])
print(output["stdout"])
print()
if output['stderr']:
if output["stderr"]:
print(f"{Colors.YELLOW}STDERR:{Colors.RESET}")
print(output['stderr'])
print(output["stderr"])
print()
except Exception as e:
print(f"{Colors.RED}Error reading log for '{session_name}': {e}{Colors.RESET}")
def show_session_status(args):
"""Show detailed status of a session."""
if not args or len(args) < 1:
@ -160,11 +184,11 @@ def show_session_status(args):
print(f"{Colors.BOLD}Status for session: {session_name}{Colors.RESET}")
print("-" * 50)
metadata = status.get('metadata', {})
metadata = status.get("metadata", {})
print(f"Process type: {metadata.get('process_type', 'unknown')}")
print(f"Active: {status.get('is_active', False)}")
if 'pid' in status:
if "pid" in status:
print(f"PID: {status['pid']}")
print(f"Start time: {metadata.get('start_time', 0):.1f}")
@ -172,8 +196,10 @@ def show_session_status(args):
print(f"Interaction count: {metadata.get('interaction_count', 0)}")
print(f"State: {metadata.get('state', 'unknown')}")
output_summary = status.get('output_summary', {})
print(f"Output lines: {output_summary.get('stdout_lines', 0)} stdout, {output_summary.get('stderr_lines', 0)} stderr")
output_summary = status.get("output_summary", {})
print(
f"Output lines: {output_summary.get('stdout_lines', 0)} stdout, {output_summary.get('stderr_lines', 0)} stderr"
)
# Show prompt detection info
detector = get_global_detector()
@ -182,6 +208,7 @@ def show_session_status(args):
print(f"Current state: {session_info['current_state']}")
print(f"Is waiting for input: {session_info['is_waiting']}")
def list_waiting_sessions(args=None):
"""List sessions that appear to be waiting for input."""
sessions = list_active_sessions()
@ -193,14 +220,16 @@ def list_waiting_sessions(args=None):
waiting_sessions.append(session_name)
if not waiting_sessions:
print(f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}")
print(
f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}"
)
return
print(f"{Colors.BOLD}Sessions waiting for input:{Colors.RESET}")
for session_name in waiting_sessions:
status = get_session_status(session_name)
if status:
process_type = status.get('metadata', {}).get('process_type', 'unknown')
process_type = status.get("metadata", {}).get("process_type", "unknown")
print(f" {Colors.CYAN}{session_name}{Colors.RESET} ({process_type})")
# Show suggestions
@ -208,17 +237,20 @@ def list_waiting_sessions(args=None):
if session_info:
suggestions = detector.get_response_suggestions({}, process_type)
if suggestions:
print(f" Suggested inputs: {', '.join(suggestions[:3])}") # Show first 3
print(
f" Suggested inputs: {', '.join(suggestions[:3])}"
) # Show first 3
print()
# Command registry for the multiplexer commands
MULTIPLEXER_COMMANDS = {
'show_sessions': show_sessions,
'attach_session': attach_session,
'detach_session': detach_session,
'kill_session': kill_session,
'send_command': send_command,
'show_session_log': show_session_log,
'show_session_status': show_session_status,
'list_waiting_sessions': list_waiting_sessions,
}
"show_sessions": show_sessions,
"attach_session": attach_session,
"detach_session": detach_session,
"kill_session": kill_session,
"send_command": send_command,
"show_session_log": show_session_log,
"show_session_status": show_session_status,
"list_waiting_sessions": list_waiting_sessions,
}
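Since every handler in the registry accepts a single optional `args` list, dispatch collapses to one `dict.get`. A sketch (the `dispatch` helper and the example session name are illustrative, not part of this module):

```python
def dispatch(line):
    """Route 'name arg1 arg2 ...' through MULTIPLEXER_COMMANDS."""
    parts = line.split()
    if not parts:
        return
    handler = MULTIPLEXER_COMMANDS.get(parts[0])
    if handler is None:
        print(f"Unknown multiplexer command: {parts[0]}")
        return
    handler(parts[1:])


# dispatch("show_sessions")
# dispatch("send_command build_worker make -j4")
```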

View File

@ -27,15 +27,78 @@ CONTENT_TRIM_LENGTH = 30000
MAX_TOOL_RESULT_LENGTH = 30000
LANGUAGE_KEYWORDS = {
'python': ['def', 'class', 'import', 'from', 'if', 'else', 'elif', 'for', 'while',
'return', 'try', 'except', 'finally', 'with', 'as', 'lambda', 'yield',
'None', 'True', 'False', 'and', 'or', 'not', 'in', 'is'],
'javascript': ['function', 'var', 'let', 'const', 'if', 'else', 'for', 'while',
'return', 'try', 'catch', 'finally', 'class', 'extends', 'new',
'this', 'null', 'undefined', 'true', 'false'],
'java': ['public', 'private', 'protected', 'class', 'interface', 'extends',
'implements', 'static', 'final', 'void', 'int', 'String', 'boolean',
'if', 'else', 'for', 'while', 'return', 'try', 'catch', 'finally'],
"python": [
"def",
"class",
"import",
"from",
"if",
"else",
"elif",
"for",
"while",
"return",
"try",
"except",
"finally",
"with",
"as",
"lambda",
"yield",
"None",
"True",
"False",
"and",
"or",
"not",
"in",
"is",
],
"javascript": [
"function",
"var",
"let",
"const",
"if",
"else",
"for",
"while",
"return",
"try",
"catch",
"finally",
"class",
"extends",
"new",
"this",
"null",
"undefined",
"true",
"false",
],
"java": [
"public",
"private",
"protected",
"class",
"interface",
"extends",
"implements",
"static",
"final",
"void",
"int",
"String",
"boolean",
"if",
"else",
"for",
"while",
"return",
"try",
"catch",
"finally",
],
}
CACHE_ENABLED = True
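A table like `LANGUAGE_KEYWORDS` supports a naive bag-of-keywords language guess: tokenize, count hits per language, take the best scorer. A hypothetical sketch (`guess_language` does not exist in this module):

```python
import re


def guess_language(source):
    """Return the language whose keyword list matches the most tokens."""
    tokens = re.findall(r"\b\w+\b", source)
    scores = {}
    for lang, keywords in LANGUAGE_KEYWORDS.items():
        vocab = set(keywords)
        scores[lang] = sum(1 for tok in tokens if tok in vocab)
    best = max(scores, key=scores.get, default=None)
    return best if best and scores[best] > 0 else "unknown"


# guess_language("def foo(): return None")  -> "python"
```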
@ -70,18 +133,18 @@ MAX_CONCURRENT_SESSIONS = 10
# Process-specific timeouts (seconds)
PROCESS_TIMEOUTS = {
'default': 300, # 5 minutes
'apt': 600, # 10 minutes
'ssh': 60, # 1 minute
'vim': 3600, # 1 hour
'git': 300, # 5 minutes
'npm': 600, # 10 minutes
'pip': 300, # 5 minutes
"default": 300, # 5 minutes
"apt": 600, # 10 minutes
"ssh": 60, # 1 minute
"vim": 3600, # 1 hour
"git": 300, # 5 minutes
"npm": 600, # 10 minutes
"pip": 300, # 5 minutes
}
# Activity thresholds for LLM notification
HIGH_OUTPUT_THRESHOLD = 50 # lines
INACTIVE_THRESHOLD = 300 # seconds
INACTIVE_THRESHOLD = 300 # seconds
SESSION_NOTIFY_INTERVAL = 60 # seconds
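`PROCESS_TIMEOUTS` is keyed by tool name with a `"default"` fallback, so resolving a timeout for an arbitrary command line reduces to a single `dict.get`. A sketch (`resolve_timeout` is illustrative, not part of the config module):

```python
def resolve_timeout(command_line):
    """Timeout in seconds for a command, e.g. 'apt install vim' -> 600."""
    parts = command_line.split()
    name = parts[0] if parts else "default"
    return PROCESS_TIMEOUTS.get(name, PROCESS_TIMEOUTS["default"])


# resolve_timeout("git clone https://...")  -> 300
# resolve_timeout("make -j8")               -> 300 (falls back to default)
```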
# Autonomous behavior flags

View File

@ -1,5 +1,11 @@
from pr.core.assistant import Assistant
from pr.core.api import call_api, list_models
from pr.core.assistant import Assistant
from pr.core.context import init_system_message, manage_context_window
__all__ = ['Assistant', 'call_api', 'list_models', 'init_system_message', 'manage_context_window']
__all__ = [
"Assistant",
"call_api",
"list_models",
"init_system_message",
"manage_context_window",
]

View File

@ -1,20 +1,20 @@
import re
import math
from typing import List, Dict, Any
from collections import Counter
from typing import Any, Dict, List
class AdvancedContextManager:
def __init__(self, knowledge_store=None, conversation_memory=None):
self.knowledge_store = knowledge_store
self.conversation_memory = conversation_memory
def adaptive_context_window(self, messages: List[Dict[str, Any]],
task_complexity: str = 'medium') -> int:
def adaptive_context_window(
self, messages: List[Dict[str, Any]], task_complexity: str = "medium"
) -> int:
complexity_thresholds = {
'simple': 10,
'medium': 20,
'complex': 35,
'very_complex': 50
"simple": 10,
"medium": 20,
"complex": 35,
"very_complex": 50,
}
base_threshold = complexity_thresholds.get(task_complexity, 20)
@ -31,17 +31,19 @@ class AdvancedContextManager:
return max(base_threshold, adjusted)
def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float:
total_length = sum(len(msg.get('content', '')) for msg in messages)
total_length = sum(len(msg.get("content", "")) for msg in messages)
avg_length = total_length / len(messages) if messages else 0
unique_words = set()
for msg in messages:
content = msg.get('content', '')
words = re.findall(r'\b\w+\b', content.lower())
content = msg.get("content", "")
words = re.findall(r"\b\w+\b", content.lower())
unique_words.update(words)
vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0
vocabulary_richness = (
len(unique_words) / total_length if total_length > 0 else 0
)
# Simple complexity score based on length and richness
complexity = min(1.0, (avg_length / 100) + vocabulary_richness)
return complexity
@ -49,10 +51,10 @@ class AdvancedContextManager:
def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]:
if not text.strip():
return []
sentences = re.split(r'(?<=[.!?])\s+', text)
sentences = re.split(r"(?<=[.!?])\s+", text)
if not sentences:
return []
# Simple scoring based on length and position
scored_sentences = []
for i, sentence in enumerate(sentences):
@ -60,25 +62,25 @@ class AdvancedContextManager:
position_score = 1.0 if i == 0 else 0.8 if i < len(sentences) / 2 else 0.6
score = (length_score + position_score) / 2
scored_sentences.append((sentence, score))
scored_sentences.sort(key=lambda x: x[1], reverse=True)
return [s[0] for s in scored_sentences[:top_k]]
def advanced_summarize_messages(self, messages: List[Dict[str, Any]]) -> str:
all_content = ' '.join([msg.get('content', '') for msg in messages])
all_content = " ".join([msg.get("content", "") for msg in messages])
key_sentences = self.extract_key_sentences(all_content, top_k=3)
summary = ' '.join(key_sentences)
summary = " ".join(key_sentences)
return summary if summary else "No content to summarize."
def score_message_relevance(self, message: Dict[str, Any], context: str) -> float:
content = message.get('content', '')
content_words = set(re.findall(r'\b\w+\b', content.lower()))
context_words = set(re.findall(r'\b\w+\b', context.lower()))
content = message.get("content", "")
content_words = set(re.findall(r"\b\w+\b", content.lower()))
context_words = set(re.findall(r"\b\w+\b", context.lower()))
intersection = content_words & context_words
union = content_words | context_words
if not union:
return 0.0
return len(intersection) / len(union)
return len(intersection) / len(union)
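`score_message_relevance` is plain Jaccard similarity over word sets, |A ∩ B| / |A ∪ B|, so scores run from 0.0 (no shared vocabulary) to 1.0 (identical word sets). A usage sketch with made-up messages:

```python
manager = AdvancedContextManager()

messages = [
    {"role": "user", "content": "How do I parse JSON in Python?"},
    {"role": "assistant", "content": "Use json.loads from the json module."},
]

# Window size starts from the complexity threshold ("complex" -> 35)
# and is adjusted upward by the measured complexity of the messages.
window = manager.adaptive_context_window(messages, task_complexity="complex")

# Jaccard overlap between one message and a context string.
score = manager.score_message_relevance(messages[0], "parsing JSON in Python")
```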

View File

@ -1,13 +1,17 @@
import json
import urllib.request
import urllib.error
import logging
from pr.config import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
import urllib.error
import urllib.request
from pr.config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE
from pr.core.context import auto_slim_messages
logger = logging.getLogger('pr')
logger = logging.getLogger("pr")
def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
def call_api(
messages, model, api_url, api_key, use_tools, tools_definition, verbose=False
):
try:
messages = auto_slim_messages(messages, verbose=verbose)
@ -17,62 +21,63 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
logger.debug(f"Use tools: {use_tools}")
logger.debug(f"Message count: {len(messages)}")
headers = {
'Content-Type': 'application/json',
"Content-Type": "application/json",
}
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
headers["Authorization"] = f"Bearer {api_key}"
data = {
'model': model,
'messages': messages,
'temperature': DEFAULT_TEMPERATURE,
'max_tokens': DEFAULT_MAX_TOKENS
"model": model,
"messages": messages,
"temperature": DEFAULT_TEMPERATURE,
"max_tokens": DEFAULT_MAX_TOKENS,
}
if "gpt-5" in model:
del data['temperature']
del data['max_tokens']
del data["temperature"]
del data["max_tokens"]
logger.debug("GPT-5 detected: removed temperature and max_tokens")
if use_tools:
data['tools'] = tools_definition
data['tool_choice'] = 'auto'
data["tools"] = tools_definition
data["tool_choice"] = "auto"
logger.debug(f"Tool calling enabled with {len(tools_definition)} tools")
request_json = json.dumps(data)
logger.debug(f"Request payload size: {len(request_json)} bytes")
req = urllib.request.Request(
api_url,
data=request_json.encode('utf-8'),
headers=headers,
method='POST'
api_url, data=request_json.encode("utf-8"), headers=headers, method="POST"
)
logger.debug("Sending HTTP request...")
with urllib.request.urlopen(req) as response:
response_data = response.read().decode('utf-8')
response_data = response.read().decode("utf-8")
logger.debug(f"Response received: {len(response_data)} bytes")
result = json.loads(response_data)
if 'usage' in result:
if "usage" in result:
logger.debug(f"Token usage: {result['usage']}")
if 'choices' in result and result['choices']:
choice = result['choices'][0]
if 'message' in choice:
msg = choice['message']
if "choices" in result and result["choices"]:
choice = result["choices"][0]
if "message" in choice:
msg = choice["message"]
logger.debug(f"Response role: {msg.get('role', 'N/A')}")
if 'content' in msg and msg['content']:
logger.debug(f"Response content length: {len(msg['content'])} chars")
if 'tool_calls' in msg:
logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")
if "content" in msg and msg["content"]:
logger.debug(
f"Response content length: {len(msg['content'])} chars"
)
if "tool_calls" in msg:
logger.debug(
f"Response contains {len(msg['tool_calls'])} tool call(s)"
)
logger.debug("=== API CALL END ===")
return result
except urllib.error.HTTPError as e:
error_body = e.read().decode('utf-8')
error_body = e.read().decode("utf-8")
logger.error(f"API HTTP Error: {e.code} - {error_body}")
logger.debug("=== API CALL FAILED ===")
return {"error": f"API Error: {e.code}", "message": error_body}
@ -81,15 +86,16 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
logger.debug("=== API CALL FAILED ===")
return {"error": str(e)}
def list_models(model_list_url, api_key):
try:
req = urllib.request.Request(model_list_url)
if api_key:
req.add_header('Authorization', f'Bearer {api_key}')
req.add_header("Authorization", f"Bearer {api_key}")
with urllib.request.urlopen(req) as response:
data = json.loads(response.read().decode('utf-8'))
data = json.loads(response.read().decode("utf-8"))
return data.get('data', [])
return data.get("data", [])
except Exception as e:
return {"error": str(e)}

View File

@ -1,63 +1,111 @@
import os
import sys
import json
import sqlite3
import signal
import logging
import traceback
import readline
import glob as glob_module
import json
import logging
import os
import readline
import signal
import sqlite3
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
from pr.config import DB_PATH, LOG_FILE, DEFAULT_MODEL, DEFAULT_API_URL, MODEL_LIST_URL, HISTORY_FILE
from pr.ui import Colors, render_markdown
from pr.core.context import init_system_message, truncate_tool_result
from pr.commands import handle_command
from pr.config import (
DB_PATH,
DEFAULT_API_URL,
DEFAULT_MODEL,
HISTORY_FILE,
LOG_FILE,
MODEL_LIST_URL,
)
from pr.core.api import call_api
from pr.core.autonomous_interactions import (
get_global_autonomous,
stop_global_autonomous,
)
from pr.core.background_monitor import (
get_global_monitor,
start_global_monitor,
stop_global_monitor,
)
from pr.core.context import init_system_message, truncate_tool_result
from pr.tools import (
http_fetch, run_command, run_command_interactive, read_file, write_file,
list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query,
web_search, web_search_news, python_exec, index_source_directory,
open_editor, editor_insert_text, editor_replace_text, editor_search,
search_replace,close_editor,create_diff,apply_patch,
tail_process, kill_process
apply_patch,
chdir,
close_editor,
create_diff,
db_get,
db_query,
db_set,
editor_insert_text,
editor_replace_text,
editor_search,
getpwd,
http_fetch,
index_source_directory,
kill_process,
list_directory,
mkdir,
open_editor,
python_exec,
read_file,
run_command,
search_replace,
tail_process,
web_search,
web_search_news,
write_file,
)
from pr.tools.base import get_tools_definition
from pr.tools.filesystem import (
clear_edit_tracker,
display_edit_summary,
display_edit_timeline,
)
from pr.tools.interactive_control import (
start_interactive_session, send_input_to_session, read_session_output,
list_active_sessions, close_interactive_session
close_interactive_session,
list_active_sessions,
read_session_output,
send_input_to_session,
start_interactive_session,
)
from pr.tools.patch import display_file_diff
from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker
from pr.tools.base import get_tools_definition
from pr.commands import handle_command
from pr.core.background_monitor import start_global_monitor, stop_global_monitor, get_global_monitor
from pr.core.autonomous_interactions import start_global_autonomous, stop_global_autonomous, get_global_autonomous
from pr.ui import Colors, render_markdown
logger = logging.getLogger('pr')
logger = logging.getLogger("pr")
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(LOG_FILE)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
file_handler.setFormatter(
logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
)
logger.addHandler(file_handler)
class Assistant:
def __init__(self, args):
self.args = args
self.messages = []
self.verbose = args.verbose
self.debug = getattr(args, 'debug', False)
self.debug = getattr(args, "debug", False)
self.syntax_highlighting = not args.no_syntax
if self.debug:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
console_handler.setFormatter(
logging.Formatter("%(levelname)s: %(message)s")
)
logger.addHandler(console_handler)
logger.debug("Debug mode enabled")
self.api_key = os.environ.get('OPENROUTER_API_KEY', '')
self.model = args.model or os.environ.get('AI_MODEL', DEFAULT_MODEL)
self.api_url = args.api_url or os.environ.get('API_URL', DEFAULT_API_URL)
self.model_list_url = args.model_list_url or os.environ.get('MODEL_LIST_URL', MODEL_LIST_URL)
self.use_tools = os.environ.get('USE_TOOLS', '1') == '1'
self.strict_mode = os.environ.get('STRICT_MODE', '0') == '1'
self.api_key = os.environ.get("OPENROUTER_API_KEY", "")
self.model = args.model or os.environ.get("AI_MODEL", DEFAULT_MODEL)
self.api_url = args.api_url or os.environ.get("API_URL", DEFAULT_API_URL)
self.model_list_url = args.model_list_url or os.environ.get(
"MODEL_LIST_URL", MODEL_LIST_URL
)
self.use_tools = os.environ.get("USE_TOOLS", "1") == "1"
self.strict_mode = os.environ.get("STRICT_MODE", "0") == "1"
self.interrupt_count = 0
self.python_globals = {}
self.db_conn = None
@ -69,6 +117,7 @@ class Assistant:
try:
from pr.core.enhanced_assistant import EnhancedAssistant
self.enhanced = EnhancedAssistant(self)
if self.debug:
logger.debug("Enhanced assistant features initialized")
@ -94,13 +143,17 @@ class Assistant:
self.db_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
cursor = self.db_conn.cursor()
cursor.execute('''CREATE TABLE IF NOT EXISTS kv_store
(key TEXT PRIMARY KEY, value TEXT, timestamp REAL)''')
cursor.execute(
"""CREATE TABLE IF NOT EXISTS kv_store
(key TEXT PRIMARY KEY, value TEXT, timestamp REAL)"""
)
cursor.execute('''CREATE TABLE IF NOT EXISTS file_versions
cursor.execute(
"""CREATE TABLE IF NOT EXISTS file_versions
(id INTEGER PRIMARY KEY AUTOINCREMENT,
filepath TEXT, content TEXT, hash TEXT,
timestamp REAL, version INTEGER)''')
timestamp REAL, version INTEGER)"""
)
self.db_conn.commit()
logger.debug("Database initialized successfully")
@ -110,7 +163,7 @@ class Assistant:
def _handle_background_updates(self, updates):
"""Handle background session updates by injecting them into the conversation."""
if not updates or not updates.get('sessions'):
if not updates or not updates.get("sessions"):
return
# Format the update as a system message
@ -118,10 +171,12 @@ class Assistant:
# Inject into current conversation if we're in an active session
if self.messages and len(self.messages) > 0:
self.messages.append({
"role": "system",
"content": f"Background session updates: {update_message}"
})
self.messages.append(
{
"role": "system",
"content": f"Background session updates: {update_message}",
}
)
if self.verbose:
print(f"{Colors.CYAN}Background update: {update_message}{Colors.RESET}")
@ -130,8 +185,8 @@ class Assistant:
"""Format background updates for LLM consumption."""
session_summaries = []
for session_name, session_info in updates.get('sessions', {}).items():
summary = session_info.get('summary', f'Session {session_name}')
for session_name, session_info in updates.get("sessions", {}).items():
summary = session_info.get("summary", f"Session {session_name}")
session_summaries.append(f"{session_name}: {summary}")
if session_summaries:
@ -151,30 +206,44 @@ class Assistant:
if events:
print(f"\n{Colors.CYAN}Background Events:{Colors.RESET}")
for event in events:
event_type = event.get('type', 'unknown')
session_name = event.get('session_name', 'unknown')
event_type = event.get("type", "unknown")
session_name = event.get("session_name", "unknown")
if event_type == 'session_started':
print(f" {Colors.GREEN}{Colors.RESET} Session '{session_name}' started")
elif event_type == 'session_ended':
print(f" {Colors.YELLOW}{Colors.RESET} Session '{session_name}' ended")
elif event_type == 'output_received':
lines = len(event.get('new_output', {}).get('stdout', []))
print(f" {Colors.BLUE}📝{Colors.RESET} Session '{session_name}' produced {lines} lines of output")
elif event_type == 'possible_input_needed':
print(f" {Colors.RED}{Colors.RESET} Session '{session_name}' may need input")
elif event_type == 'high_output_volume':
total = event.get('total_lines', 0)
print(f" {Colors.YELLOW}📊{Colors.RESET} Session '{session_name}' has high output volume ({total} lines)")
elif event_type == 'inactive_session':
inactive_time = event.get('inactive_seconds', 0)
print(f" {Colors.GRAY}{Colors.RESET} Session '{session_name}' inactive for {inactive_time:.0f}s")
if event_type == "session_started":
print(
f" {Colors.GREEN}{Colors.RESET} Session '{session_name}' started"
)
elif event_type == "session_ended":
print(
f" {Colors.YELLOW}{Colors.RESET} Session '{session_name}' ended"
)
elif event_type == "output_received":
lines = len(event.get("new_output", {}).get("stdout", []))
print(
f" {Colors.BLUE}📝{Colors.RESET} Session '{session_name}' produced {lines} lines of output"
)
elif event_type == "possible_input_needed":
print(
f" {Colors.RED}{Colors.RESET} Session '{session_name}' may need input"
)
elif event_type == "high_output_volume":
total = event.get("total_lines", 0)
print(
f" {Colors.YELLOW}📊{Colors.RESET} Session '{session_name}' has high output volume ({total} lines)"
)
elif event_type == "inactive_session":
inactive_time = event.get("inactive_seconds", 0)
print(
f" {Colors.GRAY}{Colors.RESET} Session '{session_name}' inactive for {inactive_time:.0f}s"
)
print() # Add blank line after events
except Exception as e:
if self.debug:
print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}")
print(
f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}"
)
def execute_tool_calls(self, tool_calls):
results = []
@ -185,114 +254,147 @@ class Assistant:
futures = []
for tool_call in tool_calls:
func_name = tool_call['function']['name']
arguments = json.loads(tool_call['function']['arguments'])
func_name = tool_call["function"]["name"]
arguments = json.loads(tool_call["function"]["arguments"])
logger.debug(f"Tool call: {func_name} with arguments: {arguments}")
func_map = {
'http_fetch': lambda **kw: http_fetch(**kw),
'run_command': lambda **kw: run_command(**kw),
'tail_process': lambda **kw: tail_process(**kw),
'kill_process': lambda **kw: kill_process(**kw),
'start_interactive_session': lambda **kw: start_interactive_session(**kw),
'send_input_to_session': lambda **kw: send_input_to_session(**kw),
'read_session_output': lambda **kw: read_session_output(**kw),
'close_interactive_session': lambda **kw: close_interactive_session(**kw),
'read_file': lambda **kw: read_file(**kw, db_conn=self.db_conn),
'write_file': lambda **kw: write_file(**kw, db_conn=self.db_conn),
'list_directory': lambda **kw: list_directory(**kw),
'mkdir': lambda **kw: mkdir(**kw),
'chdir': lambda **kw: chdir(**kw),
'getpwd': lambda **kw: getpwd(**kw),
'db_set': lambda **kw: db_set(**kw, db_conn=self.db_conn),
'db_get': lambda **kw: db_get(**kw, db_conn=self.db_conn),
'db_query': lambda **kw: db_query(**kw, db_conn=self.db_conn),
'web_search': lambda **kw: web_search(**kw),
'web_search_news': lambda **kw: web_search_news(**kw),
'python_exec': lambda **kw: python_exec(**kw, python_globals=self.python_globals),
'index_source_directory': lambda **kw: index_source_directory(**kw),
'search_replace': lambda **kw: search_replace(**kw, db_conn=self.db_conn),
'open_editor': lambda **kw: open_editor(**kw),
'editor_insert_text': lambda **kw: editor_insert_text(**kw, db_conn=self.db_conn),
'editor_replace_text': lambda **kw: editor_replace_text(**kw, db_conn=self.db_conn),
'editor_search': lambda **kw: editor_search(**kw),
'close_editor': lambda **kw: close_editor(**kw),
'create_diff': lambda **kw: create_diff(**kw),
'apply_patch': lambda **kw: apply_patch(**kw, db_conn=self.db_conn),
'display_file_diff': lambda **kw: display_file_diff(**kw),
'display_edit_summary': lambda **kw: display_edit_summary(),
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
'start_interactive_session': lambda **kw: start_interactive_session(**kw),
'send_input_to_session': lambda **kw: send_input_to_session(**kw),
'read_session_output': lambda **kw: read_session_output(**kw),
'list_active_sessions': lambda **kw: list_active_sessions(**kw),
'close_interactive_session': lambda **kw: close_interactive_session(**kw),
'create_agent': lambda **kw: create_agent(**kw),
'list_agents': lambda **kw: list_agents(**kw),
'execute_agent_task': lambda **kw: execute_agent_task(**kw),
'remove_agent': lambda **kw: remove_agent(**kw),
'collaborate_agents': lambda **kw: collaborate_agents(**kw),
'add_knowledge_entry': lambda **kw: add_knowledge_entry(**kw),
'get_knowledge_entry': lambda **kw: get_knowledge_entry(**kw),
'search_knowledge': lambda **kw: search_knowledge(**kw),
'get_knowledge_by_category': lambda **kw: get_knowledge_by_category(**kw),
'update_knowledge_importance': lambda **kw: update_knowledge_importance(**kw),
'delete_knowledge_entry': lambda **kw: delete_knowledge_entry(**kw),
'get_knowledge_statistics': lambda **kw: get_knowledge_statistics(**kw),
"http_fetch": lambda **kw: http_fetch(**kw),
"run_command": lambda **kw: run_command(**kw),
"tail_process": lambda **kw: tail_process(**kw),
"kill_process": lambda **kw: kill_process(**kw),
"start_interactive_session": lambda **kw: start_interactive_session(
**kw
),
"send_input_to_session": lambda **kw: send_input_to_session(**kw),
"read_session_output": lambda **kw: read_session_output(**kw),
"close_interactive_session": lambda **kw: close_interactive_session(
**kw
),
"read_file": lambda **kw: read_file(**kw, db_conn=self.db_conn),
"write_file": lambda **kw: write_file(**kw, db_conn=self.db_conn),
"list_directory": lambda **kw: list_directory(**kw),
"mkdir": lambda **kw: mkdir(**kw),
"chdir": lambda **kw: chdir(**kw),
"getpwd": lambda **kw: getpwd(**kw),
"db_set": lambda **kw: db_set(**kw, db_conn=self.db_conn),
"db_get": lambda **kw: db_get(**kw, db_conn=self.db_conn),
"db_query": lambda **kw: db_query(**kw, db_conn=self.db_conn),
"web_search": lambda **kw: web_search(**kw),
"web_search_news": lambda **kw: web_search_news(**kw),
"python_exec": lambda **kw: python_exec(
**kw, python_globals=self.python_globals
),
"index_source_directory": lambda **kw: index_source_directory(**kw),
"search_replace": lambda **kw: search_replace(
**kw, db_conn=self.db_conn
),
"open_editor": lambda **kw: open_editor(**kw),
"editor_insert_text": lambda **kw: editor_insert_text(
**kw, db_conn=self.db_conn
),
"editor_replace_text": lambda **kw: editor_replace_text(
**kw, db_conn=self.db_conn
),
"editor_search": lambda **kw: editor_search(**kw),
"close_editor": lambda **kw: close_editor(**kw),
"create_diff": lambda **kw: create_diff(**kw),
"apply_patch": lambda **kw: apply_patch(**kw, db_conn=self.db_conn),
"display_file_diff": lambda **kw: display_file_diff(**kw),
"display_edit_summary": lambda **kw: display_edit_summary(),
"display_edit_timeline": lambda **kw: display_edit_timeline(**kw),
"clear_edit_tracker": lambda **kw: clear_edit_tracker(),
"start_interactive_session": lambda **kw: start_interactive_session(
**kw
),
"send_input_to_session": lambda **kw: send_input_to_session(**kw),
"read_session_output": lambda **kw: read_session_output(**kw),
"list_active_sessions": lambda **kw: list_active_sessions(**kw),
"close_interactive_session": lambda **kw: close_interactive_session(
**kw
),
"create_agent": lambda **kw: create_agent(**kw),
"list_agents": lambda **kw: list_agents(**kw),
"execute_agent_task": lambda **kw: execute_agent_task(**kw),
"remove_agent": lambda **kw: remove_agent(**kw),
"collaborate_agents": lambda **kw: collaborate_agents(**kw),
"add_knowledge_entry": lambda **kw: add_knowledge_entry(**kw),
"get_knowledge_entry": lambda **kw: get_knowledge_entry(**kw),
"search_knowledge": lambda **kw: search_knowledge(**kw),
"get_knowledge_by_category": lambda **kw: get_knowledge_by_category(
**kw
),
"update_knowledge_importance": lambda **kw: update_knowledge_importance(
**kw
),
"delete_knowledge_entry": lambda **kw: delete_knowledge_entry(**kw),
"get_knowledge_statistics": lambda **kw: get_knowledge_statistics(
**kw
),
}
if func_name in func_map:
future = executor.submit(func_map[func_name], **arguments)
futures.append((tool_call['id'], future))
futures.append((tool_call["id"], future))
for tool_id, future in futures:
try:
result = future.result(timeout=30)
result = truncate_tool_result(result)
logger.debug(f"Tool result for {tool_id}: {str(result)[:200]}...")
results.append({
"tool_call_id": tool_id,
"role": "tool",
"content": json.dumps(result)
})
results.append(
{
"tool_call_id": tool_id,
"role": "tool",
"content": json.dumps(result),
}
)
except Exception as e:
logger.debug(f"Tool error for {tool_id}: {str(e)}")
error_msg = str(e)[:200] if len(str(e)) > 200 else str(e)
results.append({
"tool_call_id": tool_id,
"role": "tool",
"content": json.dumps({"status": "error", "error": error_msg})
})
results.append(
{
"tool_call_id": tool_id,
"role": "tool",
"content": json.dumps(
{"status": "error", "error": error_msg}
),
}
)
return results
def process_response(self, response):
if 'error' in response:
if "error" in response:
return f"Error: {response['error']}"
if 'choices' not in response or not response['choices']:
if "choices" not in response or not response["choices"]:
return "No response from API"
message = response['choices'][0]['message']
message = response["choices"][0]["message"]
self.messages.append(message)
if 'tool_calls' in message and message['tool_calls']:
if "tool_calls" in message and message["tool_calls"]:
if self.verbose:
print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}")
tool_results = self.execute_tool_calls(message['tool_calls'])
tool_results = self.execute_tool_calls(message["tool_calls"])
for result in tool_results:
self.messages.append(result)
follow_up = call_api(
self.messages, self.model, self.api_url, self.api_key,
self.use_tools, get_tools_definition(), verbose=self.verbose
self.messages,
self.model,
self.api_url,
self.api_key,
self.use_tools,
get_tools_definition(),
verbose=self.verbose,
)
return self.process_response(follow_up)
content = message.get('content', '')
content = message.get("content", "")
return render_markdown(content, self.syntax_highlighting)
def signal_handler(self, signum, frame):
@ -303,7 +405,9 @@ class Assistant:
self.autonomous_mode = False
sys.exit(0)
else:
print(f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}")
print(
f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}"
)
return
self.interrupt_count += 1
@ -323,21 +427,34 @@ class Assistant:
readline.set_history_length(1000)
import atexit
atexit.register(readline.write_history_file, HISTORY_FILE)
commands = ['exit', 'quit', 'help', 'reset', 'dump', 'verbose',
'models', 'tools', 'review', 'refactor', 'obfuscate', '/auto']
commands = [
"exit",
"quit",
"help",
"reset",
"dump",
"verbose",
"models",
"tools",
"review",
"refactor",
"obfuscate",
"/auto",
]
def completer(text, state):
options = [cmd for cmd in commands if cmd.startswith(text)]
glob_pattern = os.path.expanduser(text) + '*'
glob_pattern = os.path.expanduser(text) + "*"
path_options = glob_module.glob(glob_pattern)
path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options]
combined_options = sorted(list(set(options + path_options)))
#combined_options.extend(self.commands)
# combined_options.extend(self.commands)
if state < len(combined_options):
return combined_options[state]
@ -345,10 +462,10 @@ class Assistant:
return None
delims = readline.get_completer_delims()
readline.set_completer_delims(delims.replace('/', ''))
readline.set_completer_delims(delims.replace("/", ""))
readline.set_completer(completer)
readline.parse_and_bind('tab: complete')
readline.parse_and_bind("tab: complete")
def run_repl(self):
self.setup_readline()
@ -368,8 +485,11 @@ class Assistant:
if self.background_monitoring:
try:
from pr.multiplexer import get_all_sessions
sessions = get_all_sessions()
active_count = sum(1 for s in sessions.values() if s.get('status') == 'running')
active_count = sum(
1 for s in sessions.values() if s.get("status") == "running"
)
if active_count > 0:
prompt += f"[{active_count}bg]"
except:
@ -405,10 +525,11 @@ class Assistant:
message = sys.stdin.read()
from pr.autonomous.mode import run_autonomous_mode
run_autonomous_mode(self, message)
def cleanup(self):
if hasattr(self, 'enhanced') and self.enhanced:
if hasattr(self, "enhanced") and self.enhanced:
try:
self.enhanced.cleanup()
except Exception as e:
@ -424,6 +545,7 @@ class Assistant:
try:
from pr.multiplexer import cleanup_all_multiplexers
cleanup_all_multiplexers()
except Exception as e:
logger.error(f"Error cleaning up multiplexers: {e}")
@ -433,7 +555,9 @@ class Assistant:
def run(self):
try:
print(f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}")
print(
f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}"
)
if self.args.interactive or (not self.args.message and sys.stdin.isatty()):
print("DEBUG: calling run_repl")
self.run_repl()
@ -443,6 +567,7 @@ class Assistant:
finally:
self.cleanup()
def process_message(assistant, message):
assistant.messages.append({"role": "user", "content": message})
@ -453,9 +578,13 @@ def process_message(assistant, message):
print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}")
response = call_api(
assistant.messages, assistant.model, assistant.api_url,
assistant.api_key, assistant.use_tools, get_tools_definition(),
verbose=assistant.verbose
assistant.messages,
assistant.model,
assistant.api_url,
assistant.api_key,
assistant.use_tools,
get_tools_definition(),
verbose=assistant.verbose,
)
result = assistant.process_response(response)
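`process_response` recurses whenever the model answers with `tool_calls`: each round executes the tools in a thread pool, appends the results as `role: "tool"` messages, and calls the API again, terminating on a plain content message. The same control flow as an iterative sketch (error handling and markdown rendering omitted):

```python
# Iterative equivalent of the recursive tool-call loop above.
response = call_api(assistant.messages, assistant.model, assistant.api_url,
                    assistant.api_key, True, get_tools_definition())
while True:
    message = response["choices"][0]["message"]
    assistant.messages.append(message)
    if not message.get("tool_calls"):
        break  # plain answer: done
    for tool_result in assistant.execute_tool_calls(message["tool_calls"]):
        assistant.messages.append(tool_result)  # role == "tool" entries
    response = call_api(assistant.messages, assistant.model, assistant.api_url,
                        assistant.api_key, True, get_tools_definition())
```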

View File

@ -1,7 +1,12 @@
import time
import threading
from pr.core.background_monitor import get_global_monitor
from pr.tools.interactive_control import list_active_sessions, get_session_status, read_session_output
import time
from pr.tools.interactive_control import (
get_session_status,
list_active_sessions,
read_session_output,
)
class AutonomousInteractions:
def __init__(self, interaction_interval=10.0):
@ -16,7 +21,9 @@ class AutonomousInteractions:
self.llm_callback = llm_callback
if self.interaction_thread is None:
self.active = True
self.interaction_thread = threading.Thread(target=self._interaction_loop, daemon=True)
self.interaction_thread = threading.Thread(
target=self._interaction_loop, daemon=True
)
self.interaction_thread.start()
def stop(self):
@ -48,7 +55,9 @@ class AutonomousInteractions:
if not sessions:
return # No active sessions
sessions_needing_attention = self._identify_sessions_needing_attention(sessions)
sessions_needing_attention = self._identify_sessions_needing_attention(
sessions
)
if sessions_needing_attention and self.llm_callback:
# Format session updates for LLM
@ -63,26 +72,30 @@ class AutonomousInteractions:
needing_attention = []
for session_name, session_data in sessions.items():
metadata = session_data['metadata']
output_summary = session_data['output_summary']
metadata = session_data["metadata"]
output_summary = session_data["output_summary"]
# Criteria for needing attention:
# 1. Recent output activity
time_since_activity = time.time() - metadata.get('last_activity', 0)
time_since_activity = time.time() - metadata.get("last_activity", 0)
if time_since_activity < 30: # Activity in last 30 seconds
needing_attention.append(session_name)
continue
# 2. High output volume (potential completion or error)
total_lines = output_summary['stdout_lines'] + output_summary['stderr_lines']
total_lines = (
output_summary["stdout_lines"] + output_summary["stderr_lines"]
)
if total_lines > 50: # Arbitrary threshold
needing_attention.append(session_name)
continue
# 3. Long-running sessions that might need intervention
session_age = time.time() - metadata.get('start_time', 0)
if session_age > 300 and time_since_activity > 60: # 5+ minutes old, inactive for 1+ minute
session_age = time.time() - metadata.get("start_time", 0)
if (
session_age > 300 and time_since_activity > 60
): # 5+ minutes old, inactive for 1+ minute
needing_attention.append(session_name)
continue
@ -95,18 +108,18 @@ class AutonomousInteractions:
def _session_looks_stuck(self, session_name, session_data):
"""Determine if a session appears to be stuck waiting for input."""
metadata = session_data['metadata']
metadata = session_data["metadata"]
# Check if process is still running
status = get_session_status(session_name)
if not status or not status.get('is_active', False):
if not status or not status.get("is_active", False):
return False
time_since_activity = time.time() - metadata.get('last_activity', 0)
interaction_count = metadata.get('interaction_count', 0)
time_since_activity = time.time() - metadata.get("last_activity", 0)
interaction_count = metadata.get("interaction_count", 0)
# If running for a while but no interactions, might be waiting
session_age = time.time() - metadata.get('start_time', 0)
session_age = time.time() - metadata.get("start_time", 0)
if session_age > 60 and interaction_count == 0 and time_since_activity > 30:
return True
@ -119,9 +132,9 @@ class AutonomousInteractions:
def _format_session_updates(self, session_names):
"""Format session information for LLM consumption."""
updates = {
'type': 'background_session_updates',
'timestamp': time.time(),
'sessions': {}
"type": "background_session_updates",
"timestamp": time.time(),
"sessions": {},
}
for session_name in session_names:
@ -131,12 +144,12 @@ class AutonomousInteractions:
try:
recent_output = read_session_output(session_name, lines=20)
except:
recent_output = {'stdout': '', 'stderr': ''}
recent_output = {"stdout": "", "stderr": ""}
updates['sessions'][session_name] = {
'status': status,
'recent_output': recent_output,
'summary': self._create_session_summary(status, recent_output)
updates["sessions"][session_name] = {
"status": status,
"recent_output": recent_output,
"summary": self._create_session_summary(status, recent_output),
}
return updates
@ -145,34 +158,39 @@ class AutonomousInteractions:
"""Create a human-readable summary of session status."""
summary_parts = []
process_type = status.get('metadata', {}).get('process_type', 'unknown')
process_type = status.get("metadata", {}).get("process_type", "unknown")
summary_parts.append(f"Type: {process_type}")
is_active = status.get('is_active', False)
is_active = status.get("is_active", False)
summary_parts.append(f"Status: {'Active' if is_active else 'Inactive'}")
if is_active and 'pid' in status:
if is_active and "pid" in status:
summary_parts.append(f"PID: {status['pid']}")
age = time.time() - status.get('metadata', {}).get('start_time', 0)
age = time.time() - status.get("metadata", {}).get("start_time", 0)
summary_parts.append(f"Age: {age:.1f}s")
output_lines = len(recent_output.get('stdout', '').split('\n')) + len(recent_output.get('stderr', '').split('\n'))
output_lines = len(recent_output.get("stdout", "").split("\n")) + len(
recent_output.get("stderr", "").split("\n")
)
summary_parts.append(f"Recent output: {output_lines} lines")
interaction_count = status.get('metadata', {}).get('interaction_count', 0)
interaction_count = status.get("metadata", {}).get("interaction_count", 0)
summary_parts.append(f"Interactions: {interaction_count}")
return " | ".join(summary_parts)
# Global autonomous interactions instance
_global_autonomous = None
def get_global_autonomous():
"""Get the global autonomous interactions instance."""
global _global_autonomous
return _global_autonomous
def start_global_autonomous(llm_callback=None):
"""Start global autonomous interactions."""
global _global_autonomous
@ -181,6 +199,7 @@ def start_global_autonomous(llm_callback=None):
_global_autonomous.start(llm_callback)
return _global_autonomous
def stop_global_autonomous():
"""Stop global autonomous interactions."""
global _global_autonomous
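The attention logic above boils down to three thresholds. A compact, hypothetical restatement as a standalone predicate (thresholds mirror the diff: under 30 s since output, more than 50 total lines, or a 5-minute-old session idle for over a minute):

```python
import time

def needs_attention(metadata, output_summary, now=None):
    # Sketch only; mirrors the heuristics in _identify_sessions_needing_attention.
    now = time.time() if now is None else now
    idle = now - metadata.get("last_activity", 0)
    if idle < 30:  # recent output activity
        return True
    total = output_summary["stdout_lines"] + output_summary["stderr_lines"]
    if total > 50:  # high output volume
        return True
    age = now - metadata.get("start_time", 0)
    return age > 300 and idle > 60  # long-running but gone quiet
```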

View File

@ -1,8 +1,9 @@
import queue
import threading
import time
import queue
from pr.multiplexer import get_all_multiplexer_states, get_multiplexer
from pr.tools.interactive_control import get_session_status
class BackgroundMonitor:
def __init__(self, check_interval=5.0):
@ -17,7 +18,9 @@ class BackgroundMonitor:
"""Start the background monitoring thread."""
if self.monitor_thread is None:
self.active = True
self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
self.monitor_thread = threading.Thread(
target=self._monitor_loop, daemon=True
)
self.monitor_thread.start()
def stop(self):
@ -78,19 +81,18 @@ class BackgroundMonitor:
# Check for new sessions
for session_name in new_states:
if session_name not in old_states:
events.append({
'type': 'session_started',
'session_name': session_name,
'metadata': new_states[session_name]['metadata']
})
events.append(
{
"type": "session_started",
"session_name": session_name,
"metadata": new_states[session_name]["metadata"],
}
)
# Check for ended sessions
for session_name in old_states:
if session_name not in new_states:
events.append({
'type': 'session_ended',
'session_name': session_name
})
events.append({"type": "session_ended", "session_name": session_name})
# Check for activity in existing sessions
for session_name, new_state in new_states.items():
@ -98,92 +100,112 @@ class BackgroundMonitor:
old_state = old_states[session_name]
# Check for output changes
old_stdout_lines = old_state['output_summary']['stdout_lines']
new_stdout_lines = new_state['output_summary']['stdout_lines']
old_stderr_lines = old_state['output_summary']['stderr_lines']
new_stderr_lines = new_state['output_summary']['stderr_lines']
old_stdout_lines = old_state["output_summary"]["stdout_lines"]
new_stdout_lines = new_state["output_summary"]["stdout_lines"]
old_stderr_lines = old_state["output_summary"]["stderr_lines"]
new_stderr_lines = new_state["output_summary"]["stderr_lines"]
if new_stdout_lines > old_stdout_lines or new_stderr_lines > old_stderr_lines:
if (
new_stdout_lines > old_stdout_lines
or new_stderr_lines > old_stderr_lines
):
# Get the new output
mux = get_multiplexer(session_name)
if mux:
all_output = mux.get_all_output()
new_output = {
'stdout': all_output['stdout'].split('\n')[old_stdout_lines:],
'stderr': all_output['stderr'].split('\n')[old_stderr_lines:]
"stdout": all_output["stdout"].split("\n")[
old_stdout_lines:
],
"stderr": all_output["stderr"].split("\n")[
old_stderr_lines:
],
}
events.append({
'type': 'output_received',
'session_name': session_name,
'new_output': new_output,
'total_lines': {
'stdout': new_stdout_lines,
'stderr': new_stderr_lines
events.append(
{
"type": "output_received",
"session_name": session_name,
"new_output": new_output,
"total_lines": {
"stdout": new_stdout_lines,
"stderr": new_stderr_lines,
},
}
})
)
# Check for state changes
old_metadata = old_state['metadata']
new_metadata = new_state['metadata']
old_metadata = old_state["metadata"]
new_metadata = new_state["metadata"]
if old_metadata.get('state') != new_metadata.get('state'):
events.append({
'type': 'state_changed',
'session_name': session_name,
'old_state': old_metadata.get('state'),
'new_state': new_metadata.get('state')
})
if old_metadata.get("state") != new_metadata.get("state"):
events.append(
{
"type": "state_changed",
"session_name": session_name,
"old_state": old_metadata.get("state"),
"new_state": new_metadata.get("state"),
}
)
# Check for process type identification
if (old_metadata.get('process_type') == 'unknown' and
new_metadata.get('process_type') != 'unknown'):
events.append({
'type': 'process_identified',
'session_name': session_name,
'process_type': new_metadata.get('process_type')
})
if (
old_metadata.get("process_type") == "unknown"
and new_metadata.get("process_type") != "unknown"
):
events.append(
{
"type": "process_identified",
"session_name": session_name,
"process_type": new_metadata.get("process_type"),
}
)
# Check for sessions needing attention (based on heuristics)
for session_name, state in new_states.items():
metadata = state['metadata']
output_summary = state['output_summary']
metadata = state["metadata"]
output_summary = state["output_summary"]
# Heuristic: High output volume might indicate completion or error
total_lines = output_summary['stdout_lines'] + output_summary['stderr_lines']
total_lines = (
output_summary["stdout_lines"] + output_summary["stderr_lines"]
)
if total_lines > 100: # Arbitrary threshold
events.append({
'type': 'high_output_volume',
'session_name': session_name,
'total_lines': total_lines
})
events.append(
{
"type": "high_output_volume",
"session_name": session_name,
"total_lines": total_lines,
}
)
# Heuristic: Long-running session without recent activity
time_since_activity = time.time() - metadata.get('last_activity', 0)
time_since_activity = time.time() - metadata.get("last_activity", 0)
if time_since_activity > 300: # 5 minutes
events.append({
'type': 'inactive_session',
'session_name': session_name,
'inactive_seconds': time_since_activity
})
events.append(
{
"type": "inactive_session",
"session_name": session_name,
"inactive_seconds": time_since_activity,
}
)
# Heuristic: Sessions that might be waiting for input
# This would be enhanced with prompt detection in later phases
if self._might_be_waiting_for_input(session_name, state):
events.append({
'type': 'possible_input_needed',
'session_name': session_name
})
events.append(
{"type": "possible_input_needed", "session_name": session_name}
)
return events
def _might_be_waiting_for_input(self, session_name, state):
"""Heuristic to detect if a session might be waiting for input."""
metadata = state['metadata']
process_type = metadata.get('process_type', 'unknown')
metadata = state["metadata"]
metadata.get("process_type", "unknown")
# Simple heuristics based on process type and recent activity
time_since_activity = time.time() - metadata.get('last_activity', 0)
time_since_activity = time.time() - metadata.get("last_activity", 0)
# If it's been more than 10 seconds since last activity, might be waiting
if time_since_activity > 10:
@ -191,9 +213,11 @@ class BackgroundMonitor:
return False
# Global monitor instance
_global_monitor = None
def get_global_monitor():
"""Get the global background monitor instance."""
global _global_monitor
@ -201,20 +225,24 @@ def get_global_monitor():
_global_monitor = BackgroundMonitor()
return _global_monitor
def start_global_monitor():
"""Start the global background monitor."""
monitor = get_global_monitor()
monitor.start()
def stop_global_monitor():
"""Stop the global background monitor."""
global _global_monitor
if _global_monitor:
_global_monitor.stop()
# Global monitor instance
_global_monitor = None
def start_global_monitor():
"""Start the global background monitor."""
global _global_monitor
@ -223,6 +251,7 @@ def start_global_monitor():
_global_monitor.start()
return _global_monitor
def stop_global_monitor():
"""Stop the global background monitor."""
global _global_monitor
@ -230,6 +259,7 @@ def stop_global_monitor():
_global_monitor.stop()
_global_monitor = None
def get_global_monitor():
"""Get the global background monitor instance."""
global _global_monitor
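The monitor derives its events by diffing successive state snapshots. A minimal sketch of the start/end detection, assuming the dict-shaped snapshots used above:

```python
def diff_session_states(old_states, new_states):
    """Sketch of the snapshot diffing above, reduced to start/end events."""
    events = []
    for name in new_states.keys() - old_states.keys():
        events.append(
            {
                "type": "session_started",
                "session_name": name,
                "metadata": new_states[name]["metadata"],
            }
        )
    for name in old_states.keys() - new_states.keys():
        events.append({"type": "session_ended", "session_name": name})
    return events
```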

View File

@ -1,22 +1,17 @@
import os
import configparser
from typing import Dict, Any
import os
from typing import Any, Dict
from pr.core.logging import get_logger
logger = get_logger('config')
logger = get_logger("config")
CONFIG_FILE = os.path.expanduser("~/.prrc")
LOCAL_CONFIG_FILE = ".prrc"
def load_config() -> Dict[str, Any]:
config = {
'api': {},
'autonomous': {},
'ui': {},
'output': {},
'session': {}
}
config = {"api": {}, "autonomous": {}, "ui": {}, "output": {}, "session": {}}
global_config = _load_config_file(CONFIG_FILE)
local_config = _load_config_file(LOCAL_CONFIG_FILE)
@ -55,9 +50,9 @@ def _load_config_file(filepath: str) -> Dict[str, Dict[str, Any]]:
def _parse_value(value: str) -> Any:
value = value.strip()
if value.lower() == 'true':
if value.lower() == "true":
return True
if value.lower() == 'false':
if value.lower() == "false":
return False
if value.isdigit():
@ -99,7 +94,7 @@ max_history = 1000
"""
try:
with open(filepath, 'w') as f:
with open(filepath, "w") as f:
f.write(default_config)
logger.info(f"Created default configuration at {filepath}")
return True
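`_parse_value` coerces INI strings to booleans and integers. A self-contained sketch of the same rules applied to a hypothetical `.prrc` section:

```python
import configparser

def parse_value(value: str):
    # Same coercion order as _parse_value above: bool, int, then raw string.
    value = value.strip()
    if value.lower() == "true":
        return True
    if value.lower() == "false":
        return False
    if value.isdigit():
        return int(value)
    return value

cp = configparser.ConfigParser()
cp.read_string("[session]\nmax_history = 1000\nautosave = true\n")
typed = {k: parse_value(v) for k, v in cp["session"].items()}
assert typed == {"max_history": 1000, "autosave": True}
```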

View File

@ -1,11 +1,21 @@
import os
import json
import logging
from pr.config import (CONTEXT_FILE, GLOBAL_CONTEXT_FILE, CONTEXT_COMPRESSION_THRESHOLD,
RECENT_MESSAGES_TO_KEEP, MAX_TOKENS_LIMIT, CHARS_PER_TOKEN,
EMERGENCY_MESSAGES_TO_KEEP, CONTENT_TRIM_LENGTH, MAX_TOOL_RESULT_LENGTH)
import os
from pr.config import (
CHARS_PER_TOKEN,
CONTENT_TRIM_LENGTH,
CONTEXT_COMPRESSION_THRESHOLD,
CONTEXT_FILE,
EMERGENCY_MESSAGES_TO_KEEP,
GLOBAL_CONTEXT_FILE,
MAX_TOKENS_LIMIT,
MAX_TOOL_RESULT_LENGTH,
RECENT_MESSAGES_TO_KEEP,
)
from pr.ui import Colors
def truncate_tool_result(result, max_length=None):
if max_length is None:
max_length = MAX_TOOL_RESULT_LENGTH
@ -17,24 +27,36 @@ def truncate_tool_result(result, max_length=None):
if "output" in result_copy and isinstance(result_copy["output"], str):
if len(result_copy["output"]) > max_length:
result_copy["output"] = result_copy["output"][:max_length] + f"\n... [truncated {len(result_copy['output']) - max_length} chars]"
result_copy["output"] = (
result_copy["output"][:max_length]
+ f"\n... [truncated {len(result_copy['output']) - max_length} chars]"
)
if "content" in result_copy and isinstance(result_copy["content"], str):
if len(result_copy["content"]) > max_length:
result_copy["content"] = result_copy["content"][:max_length] + f"\n... [truncated {len(result_copy['content']) - max_length} chars]"
result_copy["content"] = (
result_copy["content"][:max_length]
+ f"\n... [truncated {len(result_copy['content']) - max_length} chars]"
)
if "data" in result_copy and isinstance(result_copy["data"], str):
if len(result_copy["data"]) > max_length:
result_copy["data"] = result_copy["data"][:max_length] + f"\n... [truncated]"
result_copy["data"] = (
result_copy["data"][:max_length] + f"\n... [truncated]"
)
if "error" in result_copy and isinstance(result_copy["error"], str):
if len(result_copy["error"]) > max_length // 2:
result_copy["error"] = result_copy["error"][:max_length // 2] + "... [truncated]"
result_copy["error"] = (
result_copy["error"][: max_length // 2] + "... [truncated]"
)
return result_copy
def init_system_message(args):
context_parts = ["""You are a professional AI assistant with access to advanced tools.
context_parts = [
"""You are a professional AI assistant with access to advanced tools.
File Operations:
- Use RPEditor tools (open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor) for precise file modifications
@ -51,14 +73,15 @@ Process Management:
Shell Commands:
- Be a shell ninja using native OS tools
- Prefer standard Unix utilities over complex scripts
- Use run_command_interactive for commands requiring user input (vim, nano, etc.)"""]
#context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."]
- Use run_command_interactive for commands requiring user input (vim, nano, etc.)"""
]
# context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."]
max_context_size = 10000
if args.include_env:
env_context = "Environment Variables:\n"
for key, value in os.environ.items():
if not key.startswith('_'):
if not key.startswith("_"):
env_context += f"{key}={value}\n"
if len(env_context) > max_context_size:
env_context = env_context[:max_context_size] + "\n... [truncated]"
@ -67,7 +90,7 @@ Shell Commands:
for context_file in [CONTEXT_FILE, GLOBAL_CONTEXT_FILE]:
if os.path.exists(context_file):
try:
with open(context_file, 'r') as f:
with open(context_file) as f:
content = f.read()
if len(content) > max_context_size:
content = content[:max_context_size] + "\n... [truncated]"
@ -78,7 +101,7 @@ Shell Commands:
if args.context:
for ctx_file in args.context:
try:
with open(ctx_file, 'r') as f:
with open(ctx_file) as f:
content = f.read()
if len(content) > max_context_size:
content = content[:max_context_size] + "\n... [truncated]"
@ -88,22 +111,29 @@ Shell Commands:
system_message = "\n\n".join(context_parts)
if len(system_message) > max_context_size * 3:
system_message = system_message[:max_context_size * 3] + "\n... [system message truncated]"
system_message = (
system_message[: max_context_size * 3] + "\n... [system message truncated]"
)
return {"role": "system", "content": system_message}
def should_compress_context(messages):
return len(messages) > CONTEXT_COMPRESSION_THRESHOLD
def compress_context(messages):
return manage_context_window(messages, verbose=False)
def manage_context_window(messages, verbose):
if len(messages) <= CONTEXT_COMPRESSION_THRESHOLD:
return messages
if verbose:
print(f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}")
print(
f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}"
)
system_message = messages[0]
recent_messages = messages[-RECENT_MESSAGES_TO_KEEP:]
@ -113,18 +143,21 @@ def manage_context_window(messages, verbose):
summary = summarize_messages(middle_messages)
summary_message = {
"role": "system",
"content": f"[Previous conversation summary: {summary}]"
"content": f"[Previous conversation summary: {summary}]",
}
new_messages = [system_message, summary_message] + recent_messages
if verbose:
print(f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}")
print(
f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}"
)
return new_messages
return messages
def summarize_messages(messages):
summary_parts = []
@ -142,6 +175,7 @@ def summarize_messages(messages):
return " | ".join(summary_parts[:10])
def estimate_tokens(messages):
total_chars = 0
@ -155,6 +189,7 @@ def estimate_tokens(messages):
return int(estimated_tokens * overhead_multiplier)
def trim_message_content(message, max_length):
trimmed_msg = message.copy()
@ -162,14 +197,22 @@ def trim_message_content(message, max_length):
content = trimmed_msg["content"]
if isinstance(content, str) and len(content) > max_length:
trimmed_msg["content"] = content[:max_length] + f"\n... [trimmed {len(content) - max_length} chars]"
trimmed_msg["content"] = (
content[:max_length]
+ f"\n... [trimmed {len(content) - max_length} chars]"
)
elif isinstance(content, list):
trimmed_content = []
for item in content:
if isinstance(item, dict):
trimmed_item = item.copy()
if "text" in trimmed_item and len(trimmed_item["text"]) > max_length:
trimmed_item["text"] = trimmed_item["text"][:max_length] + f"\n... [trimmed]"
if (
"text" in trimmed_item
and len(trimmed_item["text"]) > max_length
):
trimmed_item["text"] = (
trimmed_item["text"][:max_length] + f"\n... [trimmed]"
)
trimmed_content.append(trimmed_item)
else:
trimmed_content.append(item)
@ -179,35 +222,61 @@ def trim_message_content(message, max_length):
if "content" in trimmed_msg and isinstance(trimmed_msg["content"], str):
content = trimmed_msg["content"]
if len(content) > MAX_TOOL_RESULT_LENGTH:
trimmed_msg["content"] = content[:MAX_TOOL_RESULT_LENGTH] + f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]"
trimmed_msg["content"] = (
content[:MAX_TOOL_RESULT_LENGTH]
+ f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]"
)
try:
parsed = json.loads(content)
if isinstance(parsed, dict):
if "output" in parsed and isinstance(parsed["output"], str) and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2:
parsed["output"] = parsed["output"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
if "content" in parsed and isinstance(parsed["content"], str) and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2:
parsed["content"] = parsed["content"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
if (
"output" in parsed
and isinstance(parsed["output"], str)
and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2
):
parsed["output"] = (
parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2]
+ f"\n... [truncated]"
)
if (
"content" in parsed
and isinstance(parsed["content"], str)
and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2
):
parsed["content"] = (
parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2]
+ f"\n... [truncated]"
)
trimmed_msg["content"] = json.dumps(parsed)
except:
pass
return trimmed_msg
def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
if estimate_tokens(messages) <= target_tokens:
return messages
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
system_msg = (
messages[0] if messages and messages[0].get("role") == "system" else None
)
start_idx = 1 if system_msg else 0
recent_messages = messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:]
middle_messages = messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
recent_messages = (
messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:]
)
middle_messages = (
messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
)
trimmed_middle = []
for msg in middle_messages:
if msg.get("role") == "tool":
trimmed_middle.append(trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2))
trimmed_middle.append(
trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2)
)
elif msg.get("role") in ["user", "assistant"]:
trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH))
else:
@ -233,6 +302,7 @@ def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
return ([system_msg] if system_msg else []) + messages[-1:]
def auto_slim_messages(messages, verbose=False):
estimated_tokens = estimate_tokens(messages)
@ -240,29 +310,46 @@ def auto_slim_messages(messages, verbose=False):
return messages
if verbose:
print(f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}")
print(f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}")
print(
f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}"
)
print(
f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}"
)
result = intelligently_trim_messages(messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP)
result = intelligently_trim_messages(
messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP
)
final_tokens = estimate_tokens(result)
if final_tokens > MAX_TOKENS_LIMIT:
if verbose:
print(f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}")
print(
f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}"
)
result = emergency_reduce_messages(result, MAX_TOKENS_LIMIT, verbose)
final_tokens = estimate_tokens(result)
if verbose:
removed_count = len(messages) - len(result)
print(f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}")
print(f"{Colors.GREEN} Token estimate: {estimated_tokens}{final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}")
print(
f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}"
)
print(
f"{Colors.GREEN} Token estimate: {estimated_tokens}{final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}"
)
if removed_count > 0:
print(f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}")
print(
f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}"
)
return result
def emergency_reduce_messages(messages, target_tokens, verbose=False):
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
system_msg = (
messages[0] if messages and messages[0].get("role") == "system" else None
)
start_idx = 1 if system_msg else 0
keep_count = 2
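For reference, the token estimate driving all of this trimming is character-based. A sketch with illustrative constants (the real `CHARS_PER_TOKEN` and overhead multiplier live in `pr.config` and may differ):

```python
CHARS_PER_TOKEN = 4   # illustrative, not the real constant
OVERHEAD = 1.1        # per-message structural overhead (roles, separators)

def estimate_tokens(messages):
    # Sum content lengths, convert chars -> tokens, pad for message framing.
    total_chars = sum(len(str(m.get("content", ""))) for m in messages)
    return int(total_chars / CHARS_PER_TOKEN * OVERHEAD)
```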

View File

@ -1,22 +1,29 @@
import logging
import json
import logging
import uuid
from typing import Optional, Dict, Any, List
from pr.config import (
DB_PATH, CACHE_ENABLED, API_CACHE_TTL, TOOL_CACHE_TTL,
WORKFLOW_EXECUTOR_MAX_WORKERS, AGENT_MAX_WORKERS,
KNOWLEDGE_SEARCH_LIMIT, ADVANCED_CONTEXT_ENABLED,
MEMORY_AUTO_SUMMARIZE, CONVERSATION_SUMMARY_THRESHOLD
)
from pr.cache import APICache, ToolCache
from pr.workflows import WorkflowEngine, WorkflowStorage
from typing import Any, Dict, List, Optional
from pr.agents import AgentManager
from pr.memory import KnowledgeStore, ConversationMemory, FactExtractor
from pr.cache import APICache, ToolCache
from pr.config import (
ADVANCED_CONTEXT_ENABLED,
API_CACHE_TTL,
CACHE_ENABLED,
CONVERSATION_SUMMARY_THRESHOLD,
DB_PATH,
KNOWLEDGE_SEARCH_LIMIT,
MEMORY_AUTO_SUMMARIZE,
TOOL_CACHE_TTL,
WORKFLOW_EXECUTOR_MAX_WORKERS,
)
from pr.core.advanced_context import AdvancedContextManager
from pr.core.api import call_api
from pr.memory import ConversationMemory, FactExtractor, KnowledgeStore
from pr.tools.base import get_tools_definition
from pr.workflows import WorkflowEngine, WorkflowStorage
logger = logging.getLogger('pr')
logger = logging.getLogger("pr")
class EnhancedAssistant:
def __init__(self, base_assistant):
@ -32,7 +39,7 @@ class EnhancedAssistant:
self.workflow_storage = WorkflowStorage(DB_PATH)
self.workflow_engine = WorkflowEngine(
tool_executor=self._execute_tool_for_workflow,
max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS
max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS,
)
self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent)
@ -44,20 +51,21 @@ class EnhancedAssistant:
if ADVANCED_CONTEXT_ENABLED:
self.context_manager = AdvancedContextManager(
knowledge_store=self.knowledge_store,
conversation_memory=self.conversation_memory
conversation_memory=self.conversation_memory,
)
else:
self.context_manager = None
self.current_conversation_id = str(uuid.uuid4())[:16]
self.conversation_memory.create_conversation(
self.current_conversation_id,
session_id=str(uuid.uuid4())[:16]
self.current_conversation_id, session_id=str(uuid.uuid4())[:16]
)
logger.info("Enhanced Assistant initialized with all features")
def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
def _execute_tool_for_workflow(
self, tool_name: str, arguments: Dict[str, Any]
) -> Any:
if self.tool_cache:
cached_result = self.tool_cache.get(tool_name, arguments)
if cached_result is not None:
@ -65,41 +73,66 @@ class EnhancedAssistant:
return cached_result
func_map = {
'read_file': lambda **kw: self.base.execute_tool_calls([{
'id': 'temp',
'function': {'name': 'read_file', 'arguments': json.dumps(kw)}
}])[0],
'write_file': lambda **kw: self.base.execute_tool_calls([{
'id': 'temp',
'function': {'name': 'write_file', 'arguments': json.dumps(kw)}
}])[0],
'list_directory': lambda **kw: self.base.execute_tool_calls([{
'id': 'temp',
'function': {'name': 'list_directory', 'arguments': json.dumps(kw)}
}])[0],
'run_command': lambda **kw: self.base.execute_tool_calls([{
'id': 'temp',
'function': {'name': 'run_command', 'arguments': json.dumps(kw)}
}])[0],
"read_file": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {"name": "read_file", "arguments": json.dumps(kw)},
}
]
)[0],
"write_file": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {"name": "write_file", "arguments": json.dumps(kw)},
}
]
)[0],
"list_directory": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {
"name": "list_directory",
"arguments": json.dumps(kw),
},
}
]
)[0],
"run_command": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {
"name": "run_command",
"arguments": json.dumps(kw),
},
}
]
)[0],
}
if tool_name in func_map:
result = func_map[tool_name](**arguments)
if self.tool_cache:
content = result.get('content', '')
content = result.get("content", "")
try:
parsed_content = json.loads(content) if isinstance(content, str) else content
parsed_content = (
json.loads(content) if isinstance(content, str) else content
)
self.tool_cache.set(tool_name, arguments, parsed_content)
except Exception:
pass
return result
return {'error': f'Unknown tool: {tool_name}'}
return {"error": f"Unknown tool: {tool_name}"}
def _api_caller_for_agent(self, messages: List[Dict[str, Any]],
temperature: float, max_tokens: int) -> Dict[str, Any]:
def _api_caller_for_agent(
self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int
) -> Dict[str, Any]:
return call_api(
messages,
self.base.model,
@ -109,15 +142,12 @@ class EnhancedAssistant:
tools=None,
temperature=temperature,
max_tokens=max_tokens,
verbose=self.base.verbose
verbose=self.base.verbose,
)
def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
if self.api_cache and CACHE_ENABLED:
cached_response = self.api_cache.get(
self.base.model, messages,
0.7, 4096
)
cached_response = self.api_cache.get(self.base.model, messages, 0.7, 4096)
if cached_response:
logger.debug("API cache hit")
return cached_response
@ -129,15 +159,13 @@ class EnhancedAssistant:
self.base.api_key,
self.base.use_tools,
get_tools_definition(),
verbose=self.base.verbose
verbose=self.base.verbose,
)
if self.api_cache and CACHE_ENABLED and 'error' not in response:
token_count = response.get('usage', {}).get('total_tokens', 0)
if self.api_cache and CACHE_ENABLED and "error" not in response:
token_count = response.get("usage", {}).get("total_tokens", 0)
self.api_cache.set(
self.base.model, messages,
0.7, 4096,
response, token_count
self.base.model, messages, 0.7, 4096, response, token_count
)
return response
@ -146,35 +174,33 @@ class EnhancedAssistant:
self.base.messages.append({"role": "user", "content": user_message})
self.conversation_memory.add_message(
self.current_conversation_id,
str(uuid.uuid4())[:16],
'user',
user_message
self.current_conversation_id, str(uuid.uuid4())[:16], "user", user_message
)
if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0:
facts = self.fact_extractor.extract_facts(user_message)
for fact in facts[:3]:
entry_id = str(uuid.uuid4())[:16]
from pr.memory import KnowledgeEntry
import time
categories = self.fact_extractor.categorize_content(fact['text'])
from pr.memory import KnowledgeEntry
categories = self.fact_extractor.categorize_content(fact["text"])
entry = KnowledgeEntry(
entry_id=entry_id,
category=categories[0] if categories else 'general',
content=fact['text'],
metadata={'type': fact['type'], 'confidence': fact['confidence']},
category=categories[0] if categories else "general",
content=fact["text"],
metadata={"type": fact["type"], "confidence": fact["confidence"]},
created_at=time.time(),
updated_at=time.time()
updated_at=time.time(),
)
self.knowledge_store.add_entry(entry)
if self.context_manager and ADVANCED_CONTEXT_ENABLED:
enhanced_messages, context_info = self.context_manager.create_enhanced_context(
self.base.messages,
user_message,
include_knowledge=True
enhanced_messages, context_info = (
self.context_manager.create_enhanced_context(
self.base.messages, user_message, include_knowledge=True
)
)
if self.base.verbose:
@ -189,38 +215,40 @@ class EnhancedAssistant:
result = self.base.process_response(response)
if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD:
summary = self.context_manager.advanced_summarize_messages(
self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
) if self.context_manager else "Conversation in progress"
summary = (
self.context_manager.advanced_summarize_messages(
self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
)
if self.context_manager
else "Conversation in progress"
)
topics = self.fact_extractor.categorize_content(summary)
self.conversation_memory.update_conversation_summary(
self.current_conversation_id,
summary,
topics
self.current_conversation_id, summary, topics
)
return result
def execute_workflow(self, workflow_name: str,
initial_variables: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
def execute_workflow(
self, workflow_name: str, initial_variables: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
workflow = self.workflow_storage.load_workflow_by_name(workflow_name)
if not workflow:
return {'error': f'Workflow "{workflow_name}" not found'}
return {"error": f'Workflow "{workflow_name}" not found'}
context = self.workflow_engine.execute_workflow(workflow, initial_variables)
execution_id = self.workflow_storage.save_execution(
self.workflow_storage.load_workflow_by_name(workflow_name).name,
context
self.workflow_storage.load_workflow_by_name(workflow_name).name, context
)
return {
'success': True,
'execution_id': execution_id,
'results': context.step_results,
'execution_log': context.execution_log
"success": True,
"execution_id": execution_id,
"results": context.step_results,
"execution_log": context.execution_log,
}
def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
@ -230,20 +258,22 @@ class EnhancedAssistant:
return self.agent_manager.execute_agent_task(agent_id, task)
def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]:
orchestrator_id = self.agent_manager.create_agent('orchestrator')
orchestrator_id = self.agent_manager.create_agent("orchestrator")
return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)
def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]:
def search_knowledge(
self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT
) -> List[Any]:
return self.knowledge_store.search_entries(query, top_k=limit)
def get_cache_statistics(self) -> Dict[str, Any]:
stats = {}
if self.api_cache:
stats['api_cache'] = self.api_cache.get_statistics()
stats["api_cache"] = self.api_cache.get_statistics()
if self.tool_cache:
stats['tool_cache'] = self.tool_cache.get_statistics()
stats["tool_cache"] = self.tool_cache.get_statistics()
return stats
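The API cache above is keyed on `(model, messages, temperature, max_tokens)`. One plausible key derivation, purely hypothetical since `APICache`'s internals are not shown in this diff:

```python
import hashlib
import json

def cache_key(model, messages, temperature, max_tokens):
    # Hypothetical: canonicalize the request and hash it for a stable key.
    payload = json.dumps(
        [model, messages, temperature, max_tokens], sort_keys=True, default=str
    )
    return hashlib.sha256(payload.encode()).hexdigest()
```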

View File

@ -1,6 +1,7 @@
import logging
import os
from logging.handlers import RotatingFileHandler
from pr.config import LOG_FILE
@ -9,21 +10,19 @@ def setup_logging(verbose=False):
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger('pr')
logger = logging.getLogger("pr")
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
if logger.handlers:
logger.handlers.clear()
file_handler = RotatingFileHandler(
LOG_FILE,
maxBytes=10 * 1024 * 1024,
backupCount=5
LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5
)
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
"%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
@ -31,9 +30,7 @@ def setup_logging(verbose=False):
if verbose:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter(
'%(levelname)s: %(message)s'
)
console_formatter = logging.Formatter("%(levelname)s: %(message)s")
console_handler.setFormatter(console_formatter)
logger.addHandler(console_handler)
@ -42,5 +39,5 @@ def setup_logging(verbose=False):
def get_logger(name=None):
if name:
return logging.getLogger(f'pr.{name}')
return logging.getLogger('pr')
return logging.getLogger(f"pr.{name}")
return logging.getLogger("pr")
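Usage sketch: child loggers named `pr.<component>` inherit the rotating file handler that `setup_logging` attaches to the `pr` root (plus the console handler when verbose). Assumes `setup_logging` is importable alongside `get_logger`:

```python
from pr.core.logging import get_logger, setup_logging

setup_logging(verbose=True)
log = get_logger("session")   # -> logging.getLogger("pr.session")
log.info("session restored")  # lands in LOG_FILE and, verbose, on the console
```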

View File

@ -2,9 +2,10 @@ import json
import os
from datetime import datetime
from typing import Dict, List, Optional
from pr.core.logging import get_logger
logger = get_logger('session')
logger = get_logger("session")
SESSIONS_DIR = os.path.expanduser("~/.assistant_sessions")
@ -14,18 +15,20 @@ class SessionManager:
def __init__(self):
os.makedirs(SESSIONS_DIR, exist_ok=True)
def save_session(self, name: str, messages: List[Dict], metadata: Optional[Dict] = None) -> bool:
def save_session(
self, name: str, messages: List[Dict], metadata: Optional[Dict] = None
) -> bool:
try:
session_file = os.path.join(SESSIONS_DIR, f"{name}.json")
session_data = {
'name': name,
'created_at': datetime.now().isoformat(),
'messages': messages,
'metadata': metadata or {}
"name": name,
"created_at": datetime.now().isoformat(),
"messages": messages,
"metadata": metadata or {},
}
with open(session_file, 'w') as f:
with open(session_file, "w") as f:
json.dump(session_data, f, indent=2)
logger.info(f"Session saved: {name}")
@ -43,7 +46,7 @@ class SessionManager:
logger.warning(f"Session not found: {name}")
return None
with open(session_file, 'r') as f:
with open(session_file) as f:
session_data = json.load(f)
logger.info(f"Session loaded: {name}")
@ -58,22 +61,24 @@ class SessionManager:
try:
for filename in os.listdir(SESSIONS_DIR):
if filename.endswith('.json'):
if filename.endswith(".json"):
filepath = os.path.join(SESSIONS_DIR, filename)
try:
with open(filepath, 'r') as f:
with open(filepath) as f:
data = json.load(f)
sessions.append({
'name': data.get('name', filename[:-5]),
'created_at': data.get('created_at', 'unknown'),
'message_count': len(data.get('messages', [])),
'metadata': data.get('metadata', {})
})
sessions.append(
{
"name": data.get("name", filename[:-5]),
"created_at": data.get("created_at", "unknown"),
"message_count": len(data.get("messages", [])),
"metadata": data.get("metadata", {}),
}
)
except Exception as e:
logger.warning(f"Error reading session file {filename}: {e}")
sessions.sort(key=lambda x: x['created_at'], reverse=True)
sessions.sort(key=lambda x: x["created_at"], reverse=True)
except Exception as e:
logger.error(f"Error listing sessions: {e}")
@ -96,39 +101,39 @@ class SessionManager:
logger.error(f"Error deleting session {name}: {e}")
return False
def export_session(self, name: str, output_path: str, format: str = 'json') -> bool:
def export_session(self, name: str, output_path: str, format: str = "json") -> bool:
session_data = self.load_session(name)
if not session_data:
return False
try:
if format == 'json':
with open(output_path, 'w') as f:
if format == "json":
with open(output_path, "w") as f:
json.dump(session_data, f, indent=2)
elif format == 'markdown':
with open(output_path, 'w') as f:
elif format == "markdown":
with open(output_path, "w") as f:
f.write(f"# Session: {name}\n\n")
f.write(f"Created: {session_data['created_at']}\n\n")
f.write("---\n\n")
for msg in session_data['messages']:
role = msg.get('role', 'unknown')
content = msg.get('content', '')
for msg in session_data["messages"]:
role = msg.get("role", "unknown")
content = msg.get("content", "")
f.write(f"## {role.capitalize()}\n\n")
f.write(f"{content}\n\n")
f.write("---\n\n")
elif format == 'txt':
with open(output_path, 'w') as f:
elif format == "txt":
with open(output_path, "w") as f:
f.write(f"Session: {name}\n")
f.write(f"Created: {session_data['created_at']}\n")
f.write("=" * 80 + "\n\n")
for msg in session_data['messages']:
role = msg.get('role', 'unknown')
content = msg.get('content', '')
for msg in session_data["messages"]:
role = msg.get("role", "unknown")
content = msg.get("content", "")
f.write(f"[{role.upper()}]\n")
f.write(f"{content}\n")
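A round-trip usage sketch of the manager above (module path and metadata are illustrative):

```python
sm = SessionManager()
sm.save_session(
    "demo", [{"role": "user", "content": "hi"}], metadata={"model": "gpt-4"}
)
data = sm.load_session("demo")
assert data["messages"][0]["content"] == "hi"
sm.export_session("demo", "/tmp/demo.md", format="markdown")
```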

View File

@ -2,20 +2,21 @@ import json
import os
from datetime import datetime
from typing import Dict, Optional
from pr.core.logging import get_logger
logger = get_logger('usage')
logger = get_logger("usage")
USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json")
MODEL_COSTS = {
'x-ai/grok-code-fast-1': {'input': 0.0, 'output': 0.0},
'gpt-4': {'input': 0.03, 'output': 0.06},
'gpt-4-turbo': {'input': 0.01, 'output': 0.03},
'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
'claude-3-opus': {'input': 0.015, 'output': 0.075},
'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
"x-ai/grok-code-fast-1": {"input": 0.0, "output": 0.0},
"gpt-4": {"input": 0.03, "output": 0.06},
"gpt-4-turbo": {"input": 0.01, "output": 0.03},
"gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
"claude-3-opus": {"input": 0.015, "output": 0.075},
"claude-3-sonnet": {"input": 0.003, "output": 0.015},
"claude-3-haiku": {"input": 0.00025, "output": 0.00125},
}
@ -23,12 +24,12 @@ class UsageTracker:
def __init__(self):
self.session_usage = {
'requests': 0,
'total_tokens': 0,
'input_tokens': 0,
'output_tokens': 0,
'estimated_cost': 0.0,
'models_used': {}
"requests": 0,
"total_tokens": 0,
"input_tokens": 0,
"output_tokens": 0,
"estimated_cost": 0.0,
"models_used": {},
}
def track_request(
@ -36,30 +37,30 @@ class UsageTracker:
model: str,
input_tokens: int,
output_tokens: int,
total_tokens: Optional[int] = None
total_tokens: Optional[int] = None,
):
if total_tokens is None:
total_tokens = input_tokens + output_tokens
self.session_usage['requests'] += 1
self.session_usage['total_tokens'] += total_tokens
self.session_usage['input_tokens'] += input_tokens
self.session_usage['output_tokens'] += output_tokens
self.session_usage["requests"] += 1
self.session_usage["total_tokens"] += total_tokens
self.session_usage["input_tokens"] += input_tokens
self.session_usage["output_tokens"] += output_tokens
if model not in self.session_usage['models_used']:
self.session_usage['models_used'][model] = {
'requests': 0,
'tokens': 0,
'cost': 0.0
if model not in self.session_usage["models_used"]:
self.session_usage["models_used"][model] = {
"requests": 0,
"tokens": 0,
"cost": 0.0,
}
model_usage = self.session_usage['models_used'][model]
model_usage['requests'] += 1
model_usage['tokens'] += total_tokens
model_usage = self.session_usage["models_used"][model]
model_usage["requests"] += 1
model_usage["tokens"] += total_tokens
cost = self._calculate_cost(model, input_tokens, output_tokens)
model_usage['cost'] += cost
self.session_usage['estimated_cost'] += cost
model_usage["cost"] += cost
self.session_usage["estimated_cost"] += cost
self._save_to_history(model, input_tokens, output_tokens, cost)
@ -67,9 +68,11 @@ class UsageTracker:
f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}"
)
def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
def _calculate_cost(
self, model: str, input_tokens: int, output_tokens: int
) -> float:
if model not in MODEL_COSTS:
base_model = model.split('/')[0] if '/' in model else model
base_model = model.split("/")[0] if "/" in model else model
if base_model not in MODEL_COSTS:
logger.warning(f"Unknown model for cost calculation: {model}")
return 0.0
@ -77,31 +80,35 @@ class UsageTracker:
else:
costs = MODEL_COSTS[model]
input_cost = (input_tokens / 1000) * costs['input']
output_cost = (output_tokens / 1000) * costs['output']
input_cost = (input_tokens / 1000) * costs["input"]
output_cost = (output_tokens / 1000) * costs["output"]
return input_cost + output_cost
def _save_to_history(self, model: str, input_tokens: int, output_tokens: int, cost: float):
def _save_to_history(
self, model: str, input_tokens: int, output_tokens: int, cost: float
):
try:
history = []
if os.path.exists(USAGE_DB_FILE):
with open(USAGE_DB_FILE, 'r') as f:
with open(USAGE_DB_FILE) as f:
history = json.load(f)
history.append({
'timestamp': datetime.now().isoformat(),
'model': model,
'input_tokens': input_tokens,
'output_tokens': output_tokens,
'total_tokens': input_tokens + output_tokens,
'cost': cost
})
history.append(
{
"timestamp": datetime.now().isoformat(),
"model": model,
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
"cost": cost,
}
)
if len(history) > 10000:
history = history[-10000:]
with open(USAGE_DB_FILE, 'w') as f:
with open(USAGE_DB_FILE, "w") as f:
json.dump(history, f, indent=2)
except Exception as e:
@ -121,42 +128,34 @@ class UsageTracker:
f"Estimated Cost: ${usage['estimated_cost']:.4f}",
]
if usage['models_used']:
if usage["models_used"]:
lines.append("\nModels Used:")
for model, stats in usage['models_used'].items():
for model, stats in usage["models_used"].items():
lines.append(
f" {model}: {stats['requests']} requests, "
f"{stats['tokens']:,} tokens, ${stats['cost']:.4f}"
)
return '\n'.join(lines)
return "\n".join(lines)
@staticmethod
def get_total_usage() -> Dict:
if not os.path.exists(USAGE_DB_FILE):
return {
'total_requests': 0,
'total_tokens': 0,
'total_cost': 0.0
}
return {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
try:
with open(USAGE_DB_FILE, 'r') as f:
with open(USAGE_DB_FILE) as f:
history = json.load(f)
total_tokens = sum(entry['total_tokens'] for entry in history)
total_cost = sum(entry['cost'] for entry in history)
total_tokens = sum(entry["total_tokens"] for entry in history)
total_cost = sum(entry["cost"] for entry in history)
return {
'total_requests': len(history),
'total_tokens': total_tokens,
'total_cost': total_cost
"total_requests": len(history),
"total_tokens": total_tokens,
"total_cost": total_cost,
}
except Exception as e:
logger.error(f"Error loading usage history: {e}")
return {
'total_requests': 0,
'total_tokens': 0,
'total_cost': 0.0
}
return {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
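Worked example of `_calculate_cost` with the `gpt-4` rates above, 1,500 input tokens and 500 output tokens:

```python
# (1500 / 1000) * 0.03 + (500 / 1000) * 0.06 = 0.045 + 0.030 = $0.075
costs = {"input": 0.03, "output": 0.06}
cost = (1500 / 1000) * costs["input"] + (500 / 1000) * costs["output"]
assert round(cost, 3) == 0.075
```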

View File

@ -1,5 +1,5 @@
import os
from typing import Optional
from pr.core.exceptions import ValidationError
@ -16,7 +16,9 @@ def validate_file_path(path: str, must_exist: bool = False) -> str:
return os.path.abspath(path)
def validate_directory_path(path: str, must_exist: bool = False, create: bool = False) -> str:
def validate_directory_path(
path: str, must_exist: bool = False, create: bool = False
) -> str:
if not path:
raise ValidationError("Directory path cannot be empty")
@ -48,7 +50,7 @@ def validate_api_url(url: str) -> str:
if not url:
raise ValidationError("API URL cannot be empty")
if not url.startswith(('http://', 'https://')):
if not url.startswith(("http://", "https://")):
raise ValidationError("API URL must start with http:// or https://")
return url
@ -58,7 +60,7 @@ def validate_session_name(name: str) -> str:
if not name:
raise ValidationError("Session name cannot be empty")
invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|']
invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]
for char in invalid_chars:
if char in name:
raise ValidationError(f"Session name contains invalid character: {char}")
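Usage sketch, assuming the validators above are importable from this module:

```python
from pr.core.exceptions import ValidationError

print(validate_api_url("https://api.example.com/v1"))  # echoes the URL back
try:
    validate_session_name("bad/name")
except ValidationError as exc:
    print(exc)  # Session name contains invalid character: /
```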

View File

@ -1,23 +1,22 @@
#!/usr/bin/env python3
import atexit
import curses
import threading
import sys
import os
import re
import socket
import pickle
import queue
import time
import atexit
import re
import signal
import traceback
from contextlib import contextmanager
import socket
import sys
import threading
import time
class RPEditor:
def __init__(self, filename=None, auto_save=False, timeout=30):
"""
Initialize RPEditor with enhanced robustness features.
Args:
filename: File to edit
auto_save: Enable auto-save on exit
@ -27,7 +26,7 @@ class RPEditor:
self.lines = [""]
self.cursor_y = 0
self.cursor_x = 0
self.mode = 'normal'
self.mode = "normal"
self.command = ""
self.stdscr = None
self.running = False
@ -47,7 +46,7 @@ class RPEditor:
self._cleanup_registered = False
self._original_terminal_state = None
self._exception_occurred = False
# Create socket pair with error handling
try:
self.client_sock, self.server_sock = socket.socketpair()
@ -56,10 +55,10 @@ class RPEditor:
except Exception as e:
self._cleanup()
raise RuntimeError(f"Failed to create socket pair: {e}")
# Register cleanup handlers
self._register_cleanup()
if filename:
self.load_file()
@ -81,14 +80,14 @@ class RPEditor:
try:
# Stop the editor
self.running = False
# Save if auto-save is enabled
if self.auto_save and self.filename and not self._exception_occurred:
try:
self._save_file()
except:
pass
# Clean up curses
if self.stdscr:
try:
@ -103,13 +102,13 @@ class RPEditor:
curses.endwin()
except:
pass
# Clear screen after curses cleanup
try:
os.system('clear' if os.name != 'nt' else 'cls')
os.system("clear" if os.name != "nt" else "cls")
except:
pass
# Close sockets
for sock in [self.client_sock, self.server_sock]:
if sock:
@ -117,12 +116,12 @@ class RPEditor:
sock.close()
except:
pass
# Wait for threads to finish
for thread in [self.thread, self.socket_thread]:
if thread and thread.is_alive():
thread.join(timeout=1)
except:
pass
@ -130,12 +129,12 @@ class RPEditor:
"""Load file with enhanced error handling."""
try:
if os.path.exists(self.filename):
with open(self.filename, 'r', encoding='utf-8', errors='replace') as f:
with open(self.filename, encoding="utf-8", errors="replace") as f:
content = f.read()
self.lines = content.splitlines() if content else [""]
else:
self.lines = [""]
except Exception as e:
except Exception:
self.lines = [""]
# Don't raise, just use empty content
@ -144,24 +143,24 @@ class RPEditor:
with self.lock:
if not self.filename:
return False
try:
# Create backup if file exists
if os.path.exists(self.filename):
backup_name = f"{self.filename}.bak"
try:
with open(self.filename, 'r', encoding='utf-8') as f:
with open(self.filename, encoding="utf-8") as f:
backup_content = f.read()
with open(backup_name, 'w', encoding='utf-8') as f:
with open(backup_name, "w", encoding="utf-8") as f:
f.write(backup_content)
except:
pass # Backup failed, but continue with save
# Save the file
with open(self.filename, 'w', encoding='utf-8') as f:
f.write('\n'.join(self.lines))
with open(self.filename, "w", encoding="utf-8") as f:
f.write("\n".join(self.lines))
return True
except Exception as e:
except Exception:
return False
def save_file(self):
@ -169,7 +168,7 @@ class RPEditor:
if not self.running:
return self._save_file()
try:
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
self.client_sock.send(pickle.dumps({"command": "save_file"}))
except:
return self._save_file() # Fallback to direct save
@ -177,10 +176,12 @@ class RPEditor:
"""Start the editor with enhanced error handling."""
if self.running:
return False
try:
self.running = True
self.socket_thread = threading.Thread(target=self.socket_listener, daemon=True)
self.socket_thread = threading.Thread(
target=self.socket_listener, daemon=True
)
self.socket_thread.start()
self.thread = threading.Thread(target=self.run, daemon=True)
self.thread.start()
@ -194,10 +195,10 @@ class RPEditor:
"""Stop the editor with proper cleanup."""
try:
if self.client_sock:
self.client_sock.send(pickle.dumps({'command': 'stop'}))
self.client_sock.send(pickle.dumps({"command": "stop"}))
except:
pass
self.running = False
time.sleep(0.1) # Give threads time to finish
self._cleanup()
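# Note on the control channel used above: __init__ creates a socketpair, and
# any thread may pickle a {"command": ...} dict onto client_sock; the
# socket_listener thread drains server_sock into an internal queue that
# main_loop applies between keystrokes. save_file()/stop() fall back to the
# direct path when the loop is not running, so shutdown saves still work.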
@ -206,20 +207,20 @@ class RPEditor:
"""Run the main editor loop with exception handling."""
try:
curses.wrapper(self.main_loop)
except Exception as e:
except Exception:
self._exception_occurred = True
self._cleanup()
def main_loop(self, stdscr):
"""Main editor loop with enhanced error recovery."""
self.stdscr = stdscr
try:
# Configure curses
curses.curs_set(1)
self.stdscr.keypad(True)
self.stdscr.timeout(100) # Non-blocking with timeout
while self.running:
try:
# Process queued commands
@ -230,11 +231,11 @@ class RPEditor:
self.execute_command(command)
except queue.Empty:
break
# Draw screen
with self.lock:
self.draw()
# Handle input
try:
key = self.stdscr.getch()
@ -243,12 +244,12 @@ class RPEditor:
self.handle_key(key)
except curses.error:
pass # Ignore curses errors
except Exception as e:
except Exception:
# Log error but continue running
pass
except Exception as e:
except Exception:
self._exception_occurred = True
finally:
self._cleanup()
@ -258,28 +259,28 @@ class RPEditor:
try:
self.stdscr.clear()
height, width = self.stdscr.getmaxyx()
# Draw lines
for i, line in enumerate(self.lines):
if i >= height - 1:
break
try:
# Handle long lines and special characters
display_line = line[:width-1] if len(line) >= width else line
display_line = line[: width - 1] if len(line) >= width else line
self.stdscr.addstr(i, 0, display_line)
except curses.error:
pass # Skip lines that can't be displayed
# Draw status line
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
if self.mode == 'command':
status = self.command[:width-1]
if self.mode == "command":
status = self.command[: width - 1]
try:
self.stdscr.addstr(height - 1, 0, status[:width-1])
self.stdscr.addstr(height - 1, 0, status[: width - 1])
except curses.error:
pass
# Position cursor
cursor_x = min(self.cursor_x, width - 1)
cursor_y = min(self.cursor_y, height - 2)
@ -287,7 +288,7 @@ class RPEditor:
self.stdscr.move(cursor_y, cursor_x)
except curses.error:
pass
self.stdscr.refresh()
except Exception:
pass # Continue even if draw fails
@ -295,11 +296,11 @@ class RPEditor:
def handle_key(self, key):
"""Handle keyboard input with error recovery."""
try:
if self.mode == 'normal':
if self.mode == "normal":
self.handle_normal(key)
elif self.mode == 'insert':
elif self.mode == "insert":
self.handle_insert(key)
elif self.mode == 'command':
elif self.mode == "command":
self.handle_command(key)
except Exception:
pass # Continue on error
@ -307,73 +308,73 @@ class RPEditor:
def handle_normal(self, key):
"""Handle normal mode keys."""
try:
if key == ord('h') or key == curses.KEY_LEFT:
if key == ord("h") or key == curses.KEY_LEFT:
self.move_cursor(0, -1)
elif key == ord('j') or key == curses.KEY_DOWN:
elif key == ord("j") or key == curses.KEY_DOWN:
self.move_cursor(1, 0)
elif key == ord('k') or key == curses.KEY_UP:
elif key == ord("k") or key == curses.KEY_UP:
self.move_cursor(-1, 0)
elif key == ord('l') or key == curses.KEY_RIGHT:
elif key == ord("l") or key == curses.KEY_RIGHT:
self.move_cursor(0, 1)
elif key == ord('i'):
self.mode = 'insert'
elif key == ord(':'):
self.mode = 'command'
elif key == ord("i"):
self.mode = "insert"
elif key == ord(":"):
self.mode = "command"
self.command = ":"
elif key == ord('x'):
elif key == ord("x"):
self._delete_char()
elif key == ord('a'):
elif key == ord("a"):
self.cursor_x = min(self.cursor_x + 1, len(self.lines[self.cursor_y]))
self.mode = 'insert'
elif key == ord('A'):
self.mode = "insert"
elif key == ord("A"):
self.cursor_x = len(self.lines[self.cursor_y])
self.mode = 'insert'
elif key == ord('o'):
self.mode = "insert"
elif key == ord("o"):
self._insert_line(self.cursor_y + 1, "")
self.cursor_y += 1
self.cursor_x = 0
self.mode = 'insert'
elif key == ord('O'):
self.mode = "insert"
elif key == ord("O"):
self._insert_line(self.cursor_y, "")
self.cursor_x = 0
self.mode = 'insert'
elif key == ord('d') and self.prev_key == ord('d'):
self.mode = "insert"
elif key == ord("d") and self.prev_key == ord("d"):
if self.cursor_y < len(self.lines):
self.clipboard = self.lines[self.cursor_y]
self._delete_line(self.cursor_y)
if self.cursor_y >= len(self.lines):
self.cursor_y = max(0, len(self.lines) - 1)
self.cursor_x = 0
elif key == ord('y') and self.prev_key == ord('y'):
elif key == ord("y") and self.prev_key == ord("y"):
if self.cursor_y < len(self.lines):
self.clipboard = self.lines[self.cursor_y]
elif key == ord('p'):
elif key == ord("p"):
self._insert_line(self.cursor_y + 1, self.clipboard)
self.cursor_y += 1
self.cursor_x = 0
elif key == ord('P'):
elif key == ord("P"):
self._insert_line(self.cursor_y, self.clipboard)
self.cursor_x = 0
elif key == ord('w'):
elif key == ord("w"):
self._move_word_forward()
elif key == ord('b'):
elif key == ord("b"):
self._move_word_backward()
elif key == ord('0'):
elif key == ord("0"):
self.cursor_x = 0
elif key == ord('$'):
elif key == ord("$"):
self.cursor_x = len(self.lines[self.cursor_y])
elif key == ord('g'):
if self.prev_key == ord('g'):
elif key == ord("g"):
if self.prev_key == ord("g"):
self.cursor_y = 0
self.cursor_x = 0
elif key == ord('G'):
elif key == ord("G"):
self.cursor_y = max(0, len(self.lines) - 1)
self.cursor_x = 0
elif key == ord('u'):
elif key == ord("u"):
self.undo()
elif key == ord('r') and self.prev_key == 18: # Ctrl-R
elif key == ord("r") and self.prev_key == 18: # Ctrl-R
self.redo()
self.prev_key = key
except Exception:
pass
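
The two-key commands above ("dd", "yy", "gg") are resolved by remembering the previous keypress in `prev_key`. A minimal, curses-free sketch of that chord technique (class and names invented for illustration):

```python
# Sketch of the prev_key chord trick behind "dd" and "yy" above.
class ChordBuffer:
    def __init__(self, lines):
        self.lines = list(lines)
        self.cursor_y = 0
        self.clipboard = ""
        self.prev_key = None

    def feed(self, key):
        if key == "d" and self.prev_key == "d":  # second "d": cut current line
            self.clipboard = self.lines.pop(self.cursor_y)
            self.cursor_y = min(self.cursor_y, max(0, len(self.lines) - 1))
        elif key == "y" and self.prev_key == "y":  # second "y": yank current line
            self.clipboard = self.lines[self.cursor_y]
        self.prev_key = key


buf = ChordBuffer(["alpha", "beta"])
buf.feed("d")
buf.feed("d")
print(buf.lines, repr(buf.clipboard))  # ['beta'] 'alpha'
```
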
@@ -410,7 +411,7 @@ class RPEditor:
"""Handle insert mode keys."""
try:
if key == 27: # ESC
self.mode = 'normal'
self.mode = "normal"
if self.cursor_x > 0:
self.cursor_x -= 1
elif key == 10 or key == 13: # Enter
@@ -438,10 +439,10 @@ class RPEditor:
elif cmd.startswith("w "):
self.filename = cmd[2:].strip()
self._save_file()
self.mode = 'normal'
self.mode = "normal"
self.command = ""
elif key == 27: # ESC
self.mode = 'normal'
self.mode = "normal"
self.command = ""
elif key == curses.KEY_BACKSPACE or key == 127 or key == 8:
if len(self.command) > 1:
@@ -449,17 +450,17 @@ class RPEditor:
elif 32 <= key <= 126:
self.command += chr(key)
except Exception:
self.mode = 'normal'
self.mode = "normal"
self.command = ""
def move_cursor(self, dy, dx):
"""Move cursor with bounds checking."""
if not self.lines:
self.lines = [""]
new_y = self.cursor_y + dy
new_x = self.cursor_x + dx
# Ensure valid Y position
if 0 <= new_y < len(self.lines):
self.cursor_y = new_y
@@ -477,9 +478,9 @@ class RPEditor:
"""Save current state for undo."""
with self.lock:
state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
"lines": list(self.lines),
"cursor_y": self.cursor_y,
"cursor_x": self.cursor_x,
}
self.undo_stack.append(state)
if len(self.undo_stack) > self.max_undo:
@@ -491,69 +492,79 @@ class RPEditor:
with self.lock:
if self.undo_stack:
current_state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
"lines": list(self.lines),
"cursor_y": self.cursor_y,
"cursor_x": self.cursor_x,
}
self.redo_stack.append(current_state)
state = self.undo_stack.pop()
self.lines = state['lines']
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1)
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0)
self.lines = state["lines"]
self.cursor_y = min(state["cursor_y"], len(self.lines) - 1)
self.cursor_x = min(
state["cursor_x"],
len(self.lines[self.cursor_y]) if self.lines else 0,
)
def redo(self):
"""Redo last undone change."""
with self.lock:
if self.redo_stack:
current_state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
"lines": list(self.lines),
"cursor_y": self.cursor_y,
"cursor_x": self.cursor_x,
}
self.undo_stack.append(current_state)
state = self.redo_stack.pop()
self.lines = state['lines']
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1)
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0)
self.lines = state["lines"]
self.cursor_y = min(state["cursor_y"], len(self.lines) - 1)
self.cursor_x = min(
state["cursor_x"],
len(self.lines[self.cursor_y]) if self.lines else 0,
)
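
Undo and redo here are plain snapshot stacks: `save_state` pushes a full copy of the buffer, and `undo` moves the current state onto the redo stack before popping. A self-contained sketch of the same model (note that the code above keeps the redo stack when a new edit is saved; many editors clear it in `save_state`):

```python
class SnapshotHistory:
    """Bounded full-copy undo/redo, mirroring save_state/undo/redo above."""

    def __init__(self, max_undo=100):
        self.undo_stack = []
        self.redo_stack = []
        self.max_undo = max_undo

    def save(self, lines, y, x):
        self.undo_stack.append({"lines": list(lines), "cursor_y": y, "cursor_x": x})
        if len(self.undo_stack) > self.max_undo:
            self.undo_stack.pop(0)  # drop the oldest snapshot

    def undo(self, current_state):
        if not self.undo_stack:
            return current_state
        self.redo_stack.append(current_state)
        return self.undo_stack.pop()

    def redo(self, current_state):
        if not self.redo_stack:
            return current_state
        self.undo_stack.append(current_state)
        return self.redo_stack.pop()
```
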
def _insert_text(self, text):
"""Insert text at cursor position."""
if not text:
return
self.save_state()
lines = text.split('\n')
lines = text.split("\n")
if len(lines) == 1:
# Single line insert
if self.cursor_y >= len(self.lines):
self.lines.append("")
self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] + text + line[self.cursor_x:]
self.lines[self.cursor_y] = (
line[: self.cursor_x] + text + line[self.cursor_x :]
)
self.cursor_x += len(text)
else:
# Multi-line insert
if self.cursor_y >= len(self.lines):
self.lines.append("")
self.cursor_y = len(self.lines) - 1
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0]
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:]
first = self.lines[self.cursor_y][: self.cursor_x] + lines[0]
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x :]
self.lines[self.cursor_y] = first
for i in range(1, len(lines) - 1):
self.lines.insert(self.cursor_y + i, lines[i])
self.lines.insert(self.cursor_y + len(lines) - 1, last)
self.cursor_y += len(lines) - 1
self.cursor_x = len(lines[-1])
def insert_text(self, text):
"""Thread-safe text insertion."""
try:
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text}))
self.client_sock.send(
pickle.dumps({"command": "insert_text", "text": text})
)
except:
with self.lock:
self._insert_text(text)
@@ -561,14 +572,18 @@ class RPEditor:
def _delete_char(self):
"""Delete character at cursor."""
self.save_state()
if self.cursor_y < len(self.lines) and self.cursor_x < len(self.lines[self.cursor_y]):
if self.cursor_y < len(self.lines) and self.cursor_x < len(
self.lines[self.cursor_y]
):
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] + line[self.cursor_x+1:]
self.lines[self.cursor_y] = (
line[: self.cursor_x] + line[self.cursor_x + 1 :]
)
def delete_char(self):
"""Thread-safe character deletion."""
try:
self.client_sock.send(pickle.dumps({'command': 'delete_char'}))
self.client_sock.send(pickle.dumps({"command": "delete_char"}))
except:
with self.lock:
self._delete_char()
@@ -578,9 +593,9 @@ class RPEditor:
if self.cursor_y >= len(self.lines):
self.lines.append("")
self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] + char + line[self.cursor_x:]
self.lines[self.cursor_y] = line[: self.cursor_x] + char + line[self.cursor_x :]
self.cursor_x += 1
def _split_line(self):
@@ -588,10 +603,10 @@ class RPEditor:
if self.cursor_y >= len(self.lines):
self.lines.append("")
self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x]
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:])
self.lines[self.cursor_y] = line[: self.cursor_x]
self.lines.insert(self.cursor_y + 1, line[self.cursor_x :])
self.cursor_y += 1
self.cursor_x = 0
@@ -599,7 +614,9 @@ class RPEditor:
"""Handle backspace key."""
if self.cursor_x > 0:
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x-1] + line[self.cursor_x:]
self.lines[self.cursor_y] = (
line[: self.cursor_x - 1] + line[self.cursor_x :]
)
self.cursor_x -= 1
elif self.cursor_y > 0:
prev_len = len(self.lines[self.cursor_y - 1])
@@ -637,7 +654,7 @@ class RPEditor:
self._set_text(text)
return
try:
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
self.client_sock.send(pickle.dumps({"command": "set_text", "text": text}))
except:
with self.lock:
self._set_text(text)
@@ -651,7 +668,9 @@ class RPEditor:
def goto_line(self, line_num):
"""Thread-safe goto line."""
try:
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num}))
self.client_sock.send(
pickle.dumps({"command": "goto_line", "line_num": line_num})
)
except:
with self.lock:
self._goto_line(line_num)
@@ -659,17 +678,17 @@ class RPEditor:
def get_text(self):
"""Get entire text content."""
try:
self.client_sock.send(pickle.dumps({'command': 'get_text'}))
self.client_sock.send(pickle.dumps({"command": "get_text"}))
data = self.client_sock.recv(65536)
return pickle.loads(data)
except:
with self.lock:
return '\n'.join(self.lines)
return "\n".join(self.lines)
def get_cursor(self):
"""Get cursor position."""
try:
self.client_sock.send(pickle.dumps({'command': 'get_cursor'}))
self.client_sock.send(pickle.dumps({"command": "get_cursor"}))
data = self.client_sock.recv(4096)
return pickle.loads(data)
except:
@@ -679,16 +698,16 @@ class RPEditor:
def get_file_info(self):
"""Get file information."""
try:
self.client_sock.send(pickle.dumps({'command': 'get_file_info'}))
self.client_sock.send(pickle.dumps({"command": "get_file_info"}))
data = self.client_sock.recv(4096)
return pickle.loads(data)
except:
with self.lock:
return {
'filename': self.filename,
'lines': len(self.lines),
'cursor': (self.cursor_y, self.cursor_x),
'mode': self.mode
"filename": self.filename,
"lines": len(self.lines),
"cursor": (self.cursor_y, self.cursor_x),
"mode": self.mode,
}
def socket_listener(self):
@@ -713,39 +732,39 @@ class RPEditor:
def execute_command(self, command):
"""Execute command with error handling."""
try:
cmd = command.get('command')
if cmd == 'insert_text':
self._insert_text(command.get('text', ''))
elif cmd == 'delete_char':
cmd = command.get("command")
if cmd == "insert_text":
self._insert_text(command.get("text", ""))
elif cmd == "delete_char":
self._delete_char()
elif cmd == 'save_file':
elif cmd == "save_file":
self._save_file()
elif cmd == 'set_text':
self._set_text(command.get('text', ''))
elif cmd == 'goto_line':
self._goto_line(command.get('line_num', 1))
elif cmd == 'get_text':
result = '\n'.join(self.lines)
elif cmd == "set_text":
self._set_text(command.get("text", ""))
elif cmd == "goto_line":
self._goto_line(command.get("line_num", 1))
elif cmd == "get_text":
result = "\n".join(self.lines)
self.server_sock.send(pickle.dumps(result))
elif cmd == 'get_cursor':
elif cmd == "get_cursor":
result = (self.cursor_y, self.cursor_x)
self.server_sock.send(pickle.dumps(result))
elif cmd == 'get_file_info':
elif cmd == "get_file_info":
result = {
'filename': self.filename,
'lines': len(self.lines),
'cursor': (self.cursor_y, self.cursor_x),
'mode': self.mode
"filename": self.filename,
"lines": len(self.lines),
"cursor": (self.cursor_y, self.cursor_x),
"mode": self.mode,
}
self.server_sock.send(pickle.dumps(result))
elif cmd == 'stop':
elif cmd == "stop":
self.running = False
except Exception:
pass
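
`execute_command` is the server half of a small pickle-over-socket RPC; the `client_sock.send(pickle.dumps(...))` calls earlier are the client half. A runnable sketch of one round trip over `socket.socketpair`, using the same message shape but reduced to a single synchronous exchange (pickle is only tolerable here because both ends live in the same trusted process):

```python
import pickle
import socket

server, client = socket.socketpair()


def serve_once(lines):
    # One dispatch step, shaped like execute_command above.
    cmd = pickle.loads(server.recv(4096))
    if cmd["command"] == "get_text":
        server.send(pickle.dumps("\n".join(lines)))


client.send(pickle.dumps({"command": "get_text"}))
serve_once(["first line", "second line"])
print(pickle.loads(client.recv(65536)))  # first line\nsecond line
```
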
# Additional public methods for backwards compatibility
def move_cursor_to(self, y, x):
"""Move cursor to specific position."""
with self.lock:
@@ -788,11 +807,11 @@ class RPEditor:
"""Replace text in range."""
with self.lock:
self.save_state()
# Validate bounds
start_line = max(0, min(start_line, len(self.lines) - 1))
end_line = max(0, min(end_line, len(self.lines) - 1))
if start_line == end_line:
line = self.lines[start_line]
start_col = max(0, min(start_col, len(line)))
@@ -801,9 +820,9 @@ class RPEditor:
else:
first_part = self.lines[start_line][:start_col]
last_part = self.lines[end_line][end_col:]
new_lines = new_text.split('\n')
new_lines = new_text.split("\n")
self.lines[start_line] = first_part + new_lines[0]
del self.lines[start_line + 1:end_line + 1]
del self.lines[start_line + 1 : end_line + 1]
for i, new_line in enumerate(new_lines[1:], 1):
self.lines.insert(start_line + i, new_line)
if len(new_lines) > 1:
@@ -842,24 +861,24 @@ class RPEditor:
with self.lock:
if not self.selection_start or not self.selection_end:
return ""
sl, sc = self.selection_start
el, ec = self.selection_end
# Validate bounds
if sl < 0 or sl >= len(self.lines) or el < 0 or el >= len(self.lines):
return ""
if sl == el:
return self.lines[sl][sc:ec]
result = [self.lines[sl][sc:]]
for i in range(sl + 1, el):
if i < len(self.lines):
result.append(self.lines[i])
if el < len(self.lines):
result.append(self.lines[el][:ec])
return '\n'.join(result)
return "\n".join(result)
def delete_selection(self):
"""Delete selected text."""
@@ -880,7 +899,7 @@ class RPEditor:
self.save_state()
search_lines = search_block.splitlines()
replace_lines = replace_block.splitlines()
for i in range(len(self.lines) - len(search_lines) + 1):
match = True
for j, search_line in enumerate(search_lines):
@@ -890,12 +909,12 @@ class RPEditor:
if self.lines[i + j].strip() != search_line.strip():
match = False
break
if match:
# Preserve indentation
indent = len(self.lines[i]) - len(self.lines[i].lstrip())
indented_replace = [' ' * indent + line for line in replace_lines]
self.lines[i:i+len(search_lines)] = indented_replace
indented_replace = [" " * indent + line for line in replace_lines]
self.lines[i : i + len(search_lines)] = indented_replace
return True
return False
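
`replace_block` compares stripped lines so a search block written without the file's indentation still matches, then re-indents the replacement to the first matched line. A simplified standalone sketch (exact stripped-line match only, without the extra bounds checks above):

```python
def replace_block(lines, search_block, replacement_block):
    """Replace the first whitespace-insensitive match, keeping its indent."""
    search = [s.strip() for s in search_block.splitlines()]
    replacement = replacement_block.splitlines()
    for i in range(len(lines) - len(search) + 1):
        if [ln.strip() for ln in lines[i : i + len(search)]] == search:
            indent = len(lines[i]) - len(lines[i].lstrip())
            lines[i : i + len(search)] = [" " * indent + r for r in replacement]
            return True
    return False


code = ["def f():", "    x = 1", "    return x"]
replace_block(code, "x = 1\nreturn x", "return 1")
print(code)  # ['def f():', '    return 1']
```
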
@@ -904,21 +923,21 @@ class RPEditor:
with self.lock:
self.save_state()
try:
lines = diff_text.split('\n')
lines = diff_text.split("\n")
start_line = 0
for line in lines:
if line.startswith('@@'):
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line)
if line.startswith("@@"):
match = re.search(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@", line)
if match:
start_line = int(match.group(1)) - 1
elif line.startswith('-'):
elif line.startswith("-"):
if start_line < len(self.lines):
del self.lines[start_line]
elif line.startswith('+'):
elif line.startswith("+"):
self.lines.insert(start_line, line[1:])
start_line += 1
elif line and not line.startswith('\\'):
elif line and not line.startswith("\\"):
start_line += 1
except Exception:
pass
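
`apply_diff` walks a unified diff line by line: a `@@` header resets the write position, `-` deletes without advancing, `+` inserts and advances, and context lines just advance. A self-contained sketch of the same walk (it assumes a single well-formed hunk stream, as the method does):

```python
import re

HUNK = re.compile(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@")


def apply_diff(lines, diff_text):
    pos = 0
    for raw in diff_text.split("\n"):
        m = HUNK.match(raw)
        if m:
            pos = int(m.group(1)) - 1  # hunk headers are 1-indexed
        elif raw.startswith("-"):
            if pos < len(lines):
                del lines[pos]  # removal: do not advance
        elif raw.startswith("+"):
            lines.insert(pos, raw[1:])
            pos += 1
        elif raw and not raw.startswith("\\"):
            pos += 1  # context line
    return lines


print(apply_diff(["a", "b", "c"], "@@ -2,1 +2,1 @@\n-b\n+B"))  # ['a', 'B', 'c']
```
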
@@ -972,18 +991,18 @@ def main():
editor = None
try:
filename = sys.argv[1] if len(sys.argv) > 1 else None
# Parse additional arguments
auto_save = '--auto-save' in sys.argv
auto_save = "--auto-save" in sys.argv
# Create and start editor
editor = RPEditor(filename, auto_save=auto_save)
editor.start()
# Wait for editor to finish
if editor.thread:
editor.thread.join()
except KeyboardInterrupt:
pass
except Exception as e:
@@ -992,7 +1011,7 @@ def main():
if editor:
editor.stop()
# Ensure screen is cleared
os.system('clear' if os.name != 'nt' else 'cls')
os.system("clear" if os.name != "nt" else "cls")
if __name__ == "__main__":

View File

@@ -1,12 +1,12 @@
#!/usr/bin/env python3
import curses
import threading
import sys
import os
import re
import socket
import pickle
import queue
import re
import socket
import sys
import threading
class RPEditor:
def __init__(self, filename=None):
@@ -14,7 +14,7 @@ class RPEditor:
self.lines = [""]
self.cursor_y = 0
self.cursor_x = 0
self.mode = 'normal'
self.mode = "normal"
self.command = ""
self.stdscr = None
self.running = False
@@ -35,7 +35,7 @@ class RPEditor:
def load_file(self):
try:
with open(self.filename, 'r') as f:
with open(self.filename) as f:
self.lines = f.read().splitlines()
if not self.lines:
self.lines = [""]
@@ -45,11 +45,11 @@ class RPEditor:
def _save_file(self):
with self.lock:
if self.filename:
with open(self.filename, 'w') as f:
f.write('\n'.join(self.lines))
with open(self.filename, "w") as f:
f.write("\n".join(self.lines))
def save_file(self):
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
self.client_sock.send(pickle.dumps({"command": "save_file"}))
def start(self):
self.running = True
@@ -59,7 +59,7 @@ class RPEditor:
self.thread.start()
def stop(self):
self.client_sock.send(pickle.dumps({'command': 'stop'}))
self.client_sock.send(pickle.dumps({"command": "stop"}))
self.running = False
if self.stdscr:
curses.endwin()
@@ -99,66 +99,66 @@ class RPEditor:
self.stdscr.addstr(i, 0, line[:width])
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
self.stdscr.addstr(height - 1, 0, status[:width])
if self.mode == 'command':
if self.mode == "command":
self.stdscr.addstr(height - 1, 0, self.command[:width])
self.stdscr.move(self.cursor_y, min(self.cursor_x, width - 1))
self.stdscr.refresh()
def handle_key(self, key):
if self.mode == 'normal':
if self.mode == "normal":
self.handle_normal(key)
elif self.mode == 'insert':
elif self.mode == "insert":
self.handle_insert(key)
elif self.mode == 'command':
elif self.mode == "command":
self.handle_command(key)
def handle_normal(self, key):
if key == ord('h') or key == curses.KEY_LEFT:
if key == ord("h") or key == curses.KEY_LEFT:
self.move_cursor(0, -1)
elif key == ord('j') or key == curses.KEY_DOWN:
elif key == ord("j") or key == curses.KEY_DOWN:
self.move_cursor(1, 0)
elif key == ord('k') or key == curses.KEY_UP:
elif key == ord("k") or key == curses.KEY_UP:
self.move_cursor(-1, 0)
elif key == ord('l') or key == curses.KEY_RIGHT:
elif key == ord("l") or key == curses.KEY_RIGHT:
self.move_cursor(0, 1)
elif key == ord('i'):
self.mode = 'insert'
elif key == ord(':'):
self.mode = 'command'
elif key == ord("i"):
self.mode = "insert"
elif key == ord(":"):
self.mode = "command"
self.command = ":"
elif key == ord('x'):
elif key == ord("x"):
self._delete_char()
elif key == ord('a'):
elif key == ord("a"):
self.cursor_x += 1
self.mode = 'insert'
elif key == ord('A'):
self.mode = "insert"
elif key == ord("A"):
self.cursor_x = len(self.lines[self.cursor_y])
self.mode = 'insert'
elif key == ord('o'):
self.mode = "insert"
elif key == ord("o"):
self._insert_line(self.cursor_y + 1, "")
self.cursor_y += 1
self.cursor_x = 0
self.mode = 'insert'
elif key == ord('O'):
self.mode = "insert"
elif key == ord("O"):
self._insert_line(self.cursor_y, "")
self.cursor_x = 0
self.mode = 'insert'
elif key == ord('d') and self.prev_key == ord('d'):
self.mode = "insert"
elif key == ord("d") and self.prev_key == ord("d"):
self.clipboard = self.lines[self.cursor_y]
self._delete_line(self.cursor_y)
if self.cursor_y >= len(self.lines):
self.cursor_y = len(self.lines) - 1
self.cursor_x = 0
elif key == ord('y') and self.prev_key == ord('y'):
elif key == ord("y") and self.prev_key == ord("y"):
self.clipboard = self.lines[self.cursor_y]
elif key == ord('p'):
elif key == ord("p"):
self._insert_line(self.cursor_y + 1, self.clipboard)
self.cursor_y += 1
self.cursor_x = 0
elif key == ord('P'):
elif key == ord("P"):
self._insert_line(self.cursor_y, self.clipboard)
self.cursor_x = 0
elif key == ord('w'):
elif key == ord("w"):
line = self.lines[self.cursor_y]
i = self.cursor_x
while i < len(line) and not line[i].isalnum():
@@ -166,7 +166,7 @@ class RPEditor:
while i < len(line) and line[i].isalnum():
i += 1
self.cursor_x = i
elif key == ord('b'):
elif key == ord("b"):
line = self.lines[self.cursor_y]
i = self.cursor_x - 1
while i >= 0 and not line[i].isalnum():
@@ -174,26 +174,26 @@ class RPEditor:
while i >= 0 and line[i].isalnum():
i -= 1
self.cursor_x = i + 1
elif key == ord('0'):
elif key == ord("0"):
self.cursor_x = 0
elif key == ord('$'):
elif key == ord("$"):
self.cursor_x = len(self.lines[self.cursor_y])
elif key == ord('g'):
if self.prev_key == ord('g'):
elif key == ord("g"):
if self.prev_key == ord("g"):
self.cursor_y = 0
self.cursor_x = 0
elif key == ord('G'):
elif key == ord("G"):
self.cursor_y = len(self.lines) - 1
self.cursor_x = 0
elif key == ord('u'):
elif key == ord("u"):
self.undo()
elif key == ord('r') and self.prev_key == 18:
elif key == ord("r") and self.prev_key == 18:
self.redo()
self.prev_key = key
def handle_insert(self, key):
if key == 27:
self.mode = 'normal'
self.mode = "normal"
if self.cursor_x > 0:
self.cursor_x -= 1
elif key == 10:
@@ -207,11 +207,13 @@ class RPEditor:
def handle_command(self, key):
if key == 10:
cmd = self.command[1:]
if cmd == "q" or cmd == 'q!':
if cmd == "q" or cmd == "q!":
self.running = False
elif cmd == "w":
elif cmd == "w":
self._save_file()
elif cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!":
elif (
cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!"
):
self._save_file()
self.running = False
elif cmd.startswith("w "):
@@ -220,10 +222,10 @@ class RPEditor:
elif cmd == "wq":
self._save_file()
self.running = False
self.mode = 'normal'
self.mode = "normal"
self.command = ""
elif key == 27:
self.mode = 'normal'
self.mode = "normal"
self.command = ""
elif key == curses.KEY_BACKSPACE or key == 127:
if len(self.command) > 1:
@@ -241,9 +243,9 @@ class RPEditor:
def save_state(self):
with self.lock:
state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
"lines": list(self.lines),
"cursor_y": self.cursor_y,
"cursor_x": self.cursor_x,
}
self.undo_stack.append(state)
if len(self.undo_stack) > self.max_undo:
@@ -254,71 +256,85 @@ class RPEditor:
with self.lock:
if self.undo_stack:
current_state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
"lines": list(self.lines),
"cursor_y": self.cursor_y,
"cursor_x": self.cursor_x,
}
self.redo_stack.append(current_state)
state = self.undo_stack.pop()
self.lines = state['lines']
self.cursor_y = state['cursor_y']
self.cursor_x = state['cursor_x']
self.lines = state["lines"]
self.cursor_y = state["cursor_y"]
self.cursor_x = state["cursor_x"]
def redo(self):
with self.lock:
if self.redo_stack:
current_state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
"lines": list(self.lines),
"cursor_y": self.cursor_y,
"cursor_x": self.cursor_x,
}
self.undo_stack.append(current_state)
state = self.redo_stack.pop()
self.lines = state['lines']
self.cursor_y = state['cursor_y']
self.cursor_x = state['cursor_x']
self.lines = state["lines"]
self.cursor_y = state["cursor_y"]
self.cursor_x = state["cursor_x"]
def _insert_text(self, text):
self.save_state()
lines = text.split('\n')
lines = text.split("\n")
if len(lines) == 1:
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + text + self.lines[self.cursor_y][self.cursor_x:]
self.lines[self.cursor_y] = (
self.lines[self.cursor_y][: self.cursor_x]
+ text
+ self.lines[self.cursor_y][self.cursor_x :]
)
self.cursor_x += len(text)
else:
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0]
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:]
first = self.lines[self.cursor_y][: self.cursor_x] + lines[0]
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x :]
self.lines[self.cursor_y] = first
for i in range(1, len(lines)-1):
for i in range(1, len(lines) - 1):
self.lines.insert(self.cursor_y + i, lines[i])
self.lines.insert(self.cursor_y + len(lines) - 1, last)
self.cursor_y += len(lines) - 1
self.cursor_x = len(lines[-1])
def insert_text(self, text):
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text}))
self.client_sock.send(pickle.dumps({"command": "insert_text", "text": text}))
def _delete_char(self):
self.save_state()
if self.cursor_x < len(self.lines[self.cursor_y]):
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + self.lines[self.cursor_y][self.cursor_x+1:]
self.lines[self.cursor_y] = (
self.lines[self.cursor_y][: self.cursor_x]
+ self.lines[self.cursor_y][self.cursor_x + 1 :]
)
def delete_char(self):
self.client_sock.send(pickle.dumps({'command': 'delete_char'}))
self.client_sock.send(pickle.dumps({"command": "delete_char"}))
def _insert_char(self, char):
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + char + self.lines[self.cursor_y][self.cursor_x:]
self.lines[self.cursor_y] = (
self.lines[self.cursor_y][: self.cursor_x]
+ char
+ self.lines[self.cursor_y][self.cursor_x :]
)
self.cursor_x += 1
def _split_line(self):
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x]
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:])
self.lines[self.cursor_y] = line[: self.cursor_x]
self.lines.insert(self.cursor_y + 1, line[self.cursor_x :])
self.cursor_y += 1
self.cursor_x = 0
def _backspace(self):
if self.cursor_x > 0:
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x-1] + self.lines[self.cursor_y][self.cursor_x:]
self.lines[self.cursor_y] = (
self.lines[self.cursor_y][: self.cursor_x - 1]
+ self.lines[self.cursor_y][self.cursor_x :]
)
self.cursor_x -= 1
elif self.cursor_y > 0:
prev_len = len(self.lines[self.cursor_y - 1])
@@ -347,7 +363,7 @@ class RPEditor:
self.cursor_x = 0
def set_text(self, text):
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
self.client_sock.send(pickle.dumps({"command": "set_text", "text": text}))
def _goto_line(self, line_num):
line_num = max(0, min(line_num, len(self.lines) - 1))
@@ -355,24 +371,26 @@ class RPEditor:
self.cursor_x = 0
def goto_line(self, line_num):
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num}))
self.client_sock.send(
pickle.dumps({"command": "goto_line", "line_num": line_num})
)
def get_text(self):
self.client_sock.send(pickle.dumps({'command': 'get_text'}))
self.client_sock.send(pickle.dumps({"command": "get_text"}))
try:
return pickle.loads(self.client_sock.recv(4096))
except:
return ''
return ""
def get_cursor(self):
self.client_sock.send(pickle.dumps({'command': 'get_cursor'}))
self.client_sock.send(pickle.dumps({"command": "get_cursor"}))
try:
return pickle.loads(self.client_sock.recv(4096))
except:
return (0, 0)
def get_file_info(self):
self.client_sock.send(pickle.dumps({'command': 'get_file_info'}))
self.client_sock.send(pickle.dumps({"command": "get_file_info"}))
try:
return pickle.loads(self.client_sock.recv(4096))
except:
@@ -390,46 +408,46 @@ class RPEditor:
break
def execute_command(self, command):
cmd = command.get('command')
if cmd == 'insert_text':
self._insert_text(command['text'])
elif cmd == 'delete_char':
cmd = command.get("command")
if cmd == "insert_text":
self._insert_text(command["text"])
elif cmd == "delete_char":
self._delete_char()
elif cmd == 'save_file':
elif cmd == "save_file":
self._save_file()
elif cmd == 'set_text':
self._set_text(command['text'])
elif cmd == 'goto_line':
self._goto_line(command['line_num'])
elif cmd == 'get_text':
result = '\n'.join(self.lines)
elif cmd == "set_text":
self._set_text(command["text"])
elif cmd == "goto_line":
self._goto_line(command["line_num"])
elif cmd == "get_text":
result = "\n".join(self.lines)
try:
self.server_sock.send(pickle.dumps(result))
except:
pass
elif cmd == 'get_cursor':
elif cmd == "get_cursor":
result = (self.cursor_y, self.cursor_x)
try:
self.server_sock.send(pickle.dumps(result))
except:
pass
elif cmd == 'get_file_info':
elif cmd == "get_file_info":
result = {
'filename': self.filename,
'lines': len(self.lines),
'cursor': (self.cursor_y, self.cursor_x),
'mode': self.mode
"filename": self.filename,
"lines": len(self.lines),
"cursor": (self.cursor_y, self.cursor_x),
"mode": self.mode,
}
try:
self.server_sock.send(pickle.dumps(result))
except:
pass
elif cmd == 'stop':
elif cmd == "stop":
self.running = False
def move_cursor_to(self, y, x):
with self.lock:
self.cursor_y = max(0, min(y, len(self.lines)-1))
self.cursor_y = max(0, min(y, len(self.lines) - 1))
self.cursor_x = max(0, min(x, len(self.lines[self.cursor_y])))
def get_line(self, line_num):
@@ -469,9 +487,9 @@ class RPEditor:
else:
first_part = self.lines[start_line][:start_col]
last_part = self.lines[end_line][end_col:]
new_lines = new_text.split('\n')
new_lines = new_text.split("\n")
self.lines[start_line] = first_part + new_lines[0]
del self.lines[start_line + 1:end_line + 1]
del self.lines[start_line + 1 : end_line + 1]
for i, new_line in enumerate(new_lines[1:], 1):
self.lines.insert(start_line + i, new_line)
if len(new_lines) > 1:
@@ -511,7 +529,7 @@ class RPEditor:
for i in range(sl + 1, el):
result.append(self.lines[i])
result.append(self.lines[el][:ec])
return '\n'.join(result)
return "\n".join(result)
def delete_selection(self):
with self.lock:
@@ -537,24 +555,24 @@ class RPEditor:
break
if match:
indent = len(self.lines[i]) - len(self.lines[i].lstrip())
indented_replace = [' ' * indent + line for line in replace_lines]
self.lines[i:i+len(search_lines)] = indented_replace
indented_replace = [" " * indent + line for line in replace_lines]
self.lines[i : i + len(search_lines)] = indented_replace
return True
return False
def apply_diff(self, diff_text):
with self.lock:
self.save_state()
lines = diff_text.split('\n')
lines = diff_text.split("\n")
for line in lines:
if line.startswith('@@'):
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line)
if line.startswith("@@"):
match = re.search(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@", line)
if match:
start_line = int(match.group(1)) - 1
elif line.startswith('-'):
elif line.startswith("-"):
if start_line < len(self.lines):
del self.lines[start_line]
elif line.startswith('+'):
elif line.startswith("+"):
self.lines.insert(start_line, line[1:])
start_line += 1
@@ -574,6 +592,7 @@ class RPEditor:
if self.thread:
self.thread.join()
def main():
filename = sys.argv[1] if len(sys.argv) > 1 else None
editor = RPEditor(filename)
@@ -583,5 +602,3 @@ def main():
if __name__ == "__main__":
main()

View File

@@ -3,14 +3,13 @@
Advanced input handler for PR Assistant with editor mode, file inclusion, and image support.
"""
import os
import re
import base64
import mimetypes
import re
import readline
import glob
from pathlib import Path
from typing import Optional
# from pr.ui.colors import Colors # Avoid import issues
@@ -29,7 +28,7 @@ class AdvancedInputHandler:
return None
readline.set_completer(completer)
readline.parse_and_bind('tab: complete')
readline.parse_and_bind("tab: complete")
except:
pass # Readline not available
@@ -60,7 +59,7 @@ class AdvancedInputHandler:
return ""
# Check for special commands
if user_input.lower() == '/editor':
if user_input.lower() == "/editor":
self.toggle_editor_mode()
return self.get_input(prompt) # Recurse to get new input
@@ -74,23 +73,25 @@ class AdvancedInputHandler:
def _get_editor_input(self, prompt: str) -> Optional[str]:
"""Get multi-line input for editor mode."""
try:
print("Editor mode: Enter your message. Type 'END' on a new line to finish.")
print(
"Editor mode: Enter your message. Type 'END' on a new line to finish."
)
print("Type '/simple' to switch back to simple mode.")
lines = []
while True:
try:
line = input()
if line.strip().lower() == 'end':
if line.strip().lower() == "end":
break
elif line.strip().lower() == '/simple':
elif line.strip().lower() == "/simple":
self.toggle_editor_mode()
return self.get_input(prompt) # Switch back and get input
lines.append(line)
except EOFError:
break
content = '\n'.join(lines).strip()
content = "\n".join(lines).strip()
if not content:
return ""
@@ -114,12 +115,13 @@ class AdvancedInputHandler:
def _process_file_inclusions(self, text: str) -> str:
"""Replace @[filename] with file contents."""
def replace_file(match):
filename = match.group(1).strip()
try:
path = Path(filename).expanduser().resolve()
if path.exists() and path.is_file():
with open(path, 'r', encoding='utf-8', errors='replace') as f:
with open(path, encoding="utf-8", errors="replace") as f:
content = f.read()
return f"\n--- File: {filename} ---\n{content}\n--- End of {filename} ---\n"
else:
@@ -128,7 +130,7 @@ class AdvancedInputHandler:
return f"[Error reading file {filename}: {e}]"
# Replace @[filename] patterns
pattern = r'@\[([^\]]+)\]'
pattern = r"@\[([^\]]+)\]"
return re.sub(pattern, replace_file, text)
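
The `@[filename]` syntax is a single `re.sub` with a callback. A standalone demo of the same pattern, with the file read stubbed out so it runs anywhere:

```python
import re

PATTERN = r"@\[([^\]]+)\]"  # same pattern as _process_file_inclusions


def expand(text, read=lambda path: f"<contents of {path}>"):
    return re.sub(PATTERN, lambda m: read(m.group(1).strip()), text)


print(expand("Summarize @[README.md] and @[docs/guide.md]"))
# Summarize <contents of README.md> and <contents of docs/guide.md>
```
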
def _process_image_inclusions(self, text: str) -> str:
@@ -143,20 +145,22 @@ class AdvancedInputHandler:
path = Path(word.strip()).expanduser().resolve()
if path.exists() and path.is_file():
mime_type, _ = mimetypes.guess_type(str(path))
if mime_type and mime_type.startswith('image/'):
if mime_type and mime_type.startswith("image/"):
# Encode image
with open(path, 'rb') as f:
image_data = base64.b64encode(f.read()).decode('utf-8')
with open(path, "rb") as f:
image_data = base64.b64encode(f.read()).decode("utf-8")
# Replace with data URL
processed_parts.append(f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n")
processed_parts.append(
f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n"
)
continue
except:
pass
processed_parts.append(word)
return ' '.join(processed_parts)
return " ".join(processed_parts)
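
The image branch boils down to "guess the MIME type, base64-encode, emit a data URL". A minimal sketch of just that step (the path is caller-supplied; non-images return None):

```python
import base64
import mimetypes


def to_data_url(path):
    """Encode an image file as a data: URL, as the inclusion step above does."""
    mime, _ = mimetypes.guess_type(path)
    if not (mime and mime.startswith("image/")):
        return None
    with open(path, "rb") as f:
        payload = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{payload}"
```
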
# Global instance

View File

@@ -1,7 +1,12 @@
from .knowledge_store import KnowledgeStore, KnowledgeEntry
from .semantic_index import SemanticIndex
from .conversation_memory import ConversationMemory
from .fact_extractor import FactExtractor
from .knowledge_store import KnowledgeEntry, KnowledgeStore
from .semantic_index import SemanticIndex
__all__ = ['KnowledgeStore', 'KnowledgeEntry', 'SemanticIndex',
'ConversationMemory', 'FactExtractor']
__all__ = [
"KnowledgeStore",
"KnowledgeEntry",
"SemanticIndex",
"ConversationMemory",
"FactExtractor",
]

View File

@@ -1,7 +1,8 @@
import json
import sqlite3
import time
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
class ConversationMemory:
def __init__(self, db_path: str):
@@ -12,7 +13,8 @@ class ConversationMemory:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS conversation_history (
conversation_id TEXT PRIMARY KEY,
session_id TEXT,
@@ -23,9 +25,11 @@ class ConversationMemory:
topics TEXT,
metadata TEXT
)
''')
"""
)
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS conversation_messages (
message_id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
@@ -36,117 +40,163 @@ class ConversationMemory:
metadata TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversation_history(conversation_id)
)
''')
"""
)
cursor.execute('''
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_conv_session ON conversation_history(session_id)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_conv_started ON conversation_history(started_at DESC)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_msg_conversation ON conversation_messages(conversation_id)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_msg_timestamp ON conversation_messages(timestamp)
''')
"""
)
conn.commit()
conn.close()
def create_conversation(self, conversation_id: str, session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None):
def create_conversation(
self,
conversation_id: str,
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
INSERT INTO conversation_history
(conversation_id, session_id, started_at, metadata)
VALUES (?, ?, ?, ?)
''', (
conversation_id,
session_id,
time.time(),
json.dumps(metadata) if metadata else None
))
""",
(
conversation_id,
session_id,
time.time(),
json.dumps(metadata) if metadata else None,
),
)
conn.commit()
conn.close()
def add_message(self, conversation_id: str, message_id: str, role: str,
content: str, tool_calls: Optional[List[Dict[str, Any]]] = None,
metadata: Optional[Dict[str, Any]] = None):
def add_message(
self,
conversation_id: str,
message_id: str,
role: str,
content: str,
tool_calls: Optional[List[Dict[str, Any]]] = None,
metadata: Optional[Dict[str, Any]] = None,
):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
INSERT INTO conversation_messages
(message_id, conversation_id, role, content, timestamp, tool_calls, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
message_id,
conversation_id,
role,
content,
time.time(),
json.dumps(tool_calls) if tool_calls else None,
json.dumps(metadata) if metadata else None
))
""",
(
message_id,
conversation_id,
role,
content,
time.time(),
json.dumps(tool_calls) if tool_calls else None,
json.dumps(metadata) if metadata else None,
),
)
cursor.execute('''
cursor.execute(
"""
UPDATE conversation_history
SET message_count = message_count + 1
WHERE conversation_id = ?
''', (conversation_id,))
""",
(conversation_id,),
)
conn.commit()
conn.close()
def get_conversation_messages(self, conversation_id: str,
limit: Optional[int] = None) -> List[Dict[str, Any]]:
def get_conversation_messages(
self, conversation_id: str, limit: Optional[int] = None
) -> List[Dict[str, Any]]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
if limit:
cursor.execute('''
cursor.execute(
"""
SELECT message_id, role, content, timestamp, tool_calls, metadata
FROM conversation_messages
WHERE conversation_id = ?
ORDER BY timestamp DESC
LIMIT ?
''', (conversation_id, limit))
""",
(conversation_id, limit),
)
else:
cursor.execute('''
cursor.execute(
"""
SELECT message_id, role, content, timestamp, tool_calls, metadata
FROM conversation_messages
WHERE conversation_id = ?
ORDER BY timestamp ASC
''', (conversation_id,))
""",
(conversation_id,),
)
messages = []
for row in cursor.fetchall():
messages.append({
'message_id': row[0],
'role': row[1],
'content': row[2],
'timestamp': row[3],
'tool_calls': json.loads(row[4]) if row[4] else None,
'metadata': json.loads(row[5]) if row[5] else None
})
messages.append(
{
"message_id": row[0],
"role": row[1],
"content": row[2],
"timestamp": row[3],
"tool_calls": json.loads(row[4]) if row[4] else None,
"metadata": json.loads(row[5]) if row[5] else None,
}
)
conn.close()
return messages
def update_conversation_summary(self, conversation_id: str, summary: str,
topics: Optional[List[str]] = None):
def update_conversation_summary(
self, conversation_id: str, summary: str, topics: Optional[List[str]] = None
):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
UPDATE conversation_history
SET summary = ?, topics = ?, ended_at = ?
WHERE conversation_id = ?
''', (summary, json.dumps(topics) if topics else None, time.time(), conversation_id))
""",
(
summary,
json.dumps(topics) if topics else None,
time.time(),
conversation_id,
),
)
conn.commit()
conn.close()
@@ -155,7 +205,8 @@ class ConversationMemory:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
SELECT DISTINCT h.conversation_id, h.session_id, h.started_at,
h.message_count, h.summary, h.topics
FROM conversation_history h
@@ -163,56 +214,69 @@ class ConversationMemory:
WHERE h.summary LIKE ? OR h.topics LIKE ? OR m.content LIKE ?
ORDER BY h.started_at DESC
LIMIT ?
''', (f'%{query}%', f'%{query}%', f'%{query}%', limit))
""",
(f"%{query}%", f"%{query}%", f"%{query}%", limit),
)
conversations = []
for row in cursor.fetchall():
conversations.append({
'conversation_id': row[0],
'session_id': row[1],
'started_at': row[2],
'message_count': row[3],
'summary': row[4],
'topics': json.loads(row[5]) if row[5] else []
})
conversations.append(
{
"conversation_id": row[0],
"session_id": row[1],
"started_at": row[2],
"message_count": row[3],
"summary": row[4],
"topics": json.loads(row[5]) if row[5] else [],
}
)
conn.close()
return conversations
def get_recent_conversations(self, limit: int = 10,
session_id: Optional[str] = None) -> List[Dict[str, Any]]:
def get_recent_conversations(
self, limit: int = 10, session_id: Optional[str] = None
) -> List[Dict[str, Any]]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
if session_id:
cursor.execute('''
cursor.execute(
"""
SELECT conversation_id, session_id, started_at, ended_at,
message_count, summary, topics
FROM conversation_history
WHERE session_id = ?
ORDER BY started_at DESC
LIMIT ?
''', (session_id, limit))
""",
(session_id, limit),
)
else:
cursor.execute('''
cursor.execute(
"""
SELECT conversation_id, session_id, started_at, ended_at,
message_count, summary, topics
FROM conversation_history
ORDER BY started_at DESC
LIMIT ?
''', (limit,))
""",
(limit,),
)
conversations = []
for row in cursor.fetchall():
conversations.append({
'conversation_id': row[0],
'session_id': row[1],
'started_at': row[2],
'ended_at': row[3],
'message_count': row[4],
'summary': row[5],
'topics': json.loads(row[6]) if row[6] else []
})
conversations.append(
{
"conversation_id": row[0],
"session_id": row[1],
"started_at": row[2],
"ended_at": row[3],
"message_count": row[4],
"summary": row[5],
"topics": json.loads(row[6]) if row[6] else [],
}
)
conn.close()
return conversations
@@ -221,10 +285,14 @@ class ConversationMemory:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM conversation_messages WHERE conversation_id = ?',
(conversation_id,))
cursor.execute('DELETE FROM conversation_history WHERE conversation_id = ?',
(conversation_id,))
cursor.execute(
"DELETE FROM conversation_messages WHERE conversation_id = ?",
(conversation_id,),
)
cursor.execute(
"DELETE FROM conversation_history WHERE conversation_id = ?",
(conversation_id,),
)
deleted = cursor.rowcount > 0
conn.commit()
@@ -236,24 +304,26 @@ class ConversationMemory:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM conversation_history')
cursor.execute("SELECT COUNT(*) FROM conversation_history")
total_conversations = cursor.fetchone()[0]
cursor.execute('SELECT COUNT(*) FROM conversation_messages')
cursor.execute("SELECT COUNT(*) FROM conversation_messages")
total_messages = cursor.fetchone()[0]
cursor.execute('SELECT SUM(message_count) FROM conversation_history')
total_message_count = cursor.fetchone()[0] or 0
cursor.execute("SELECT SUM(message_count) FROM conversation_history")
cursor.fetchone()[0] or 0
cursor.execute('''
cursor.execute(
"""
SELECT AVG(message_count) FROM conversation_history WHERE message_count > 0
''')
"""
)
avg_messages = cursor.fetchone()[0] or 0
conn.close()
return {
'total_conversations': total_conversations,
'total_messages': total_messages,
'average_messages_per_conversation': round(avg_messages, 2)
"total_conversations": total_conversations,
"total_messages": total_messages,
"average_messages_per_conversation": round(avg_messages, 2),
}
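
A short usage sketch for ConversationMemory as defined above (the ids and temp-file path are invented, and it assumes the constructor creates the schema, as the `_initialize_*` call pattern suggests):

```python
import os
import tempfile
import uuid

mem = ConversationMemory(os.path.join(tempfile.gettempdir(), "pr_demo.db"))
cid = str(uuid.uuid4())
mem.create_conversation(cid, session_id="demo-session")
mem.add_message(cid, str(uuid.uuid4()), "user", "What is TF-IDF?")
mem.add_message(cid, str(uuid.uuid4()), "assistant", "A term-weighting scheme.")
print(mem.get_conversation_messages(cid)[0]["content"])  # first message by timestamp
print(mem.get_statistics())
```
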

View File

@@ -1,16 +1,16 @@
import re
import json
from typing import List, Dict, Any, Set
from collections import defaultdict
from typing import Any, Dict, List
class FactExtractor:
def __init__(self):
self.fact_patterns = [
(r'([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)', 'definition'),
(r'([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})', 'temporal'),
(r'([A-Z][a-z]+) (invented|created|developed) ([^.]+)', 'attribution'),
(r'([^.]+) (costs?|worth) (\$[\d,]+)', 'numeric'),
(r'([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)', 'location'),
(r"([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)", "definition"),
(r"([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})", "temporal"),
(r"([A-Z][a-z]+) (invented|created|developed) ([^.]+)", "attribution"),
(r"([^.]+) (costs?|worth) (\$[\d,]+)", "numeric"),
(r"([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)", "location"),
]
def extract_facts(self, text: str) -> List[Dict[str, Any]]:
@@ -19,27 +19,31 @@ class FactExtractor:
for pattern, fact_type in self.fact_patterns:
matches = re.finditer(pattern, text)
for match in matches:
facts.append({
'type': fact_type,
'text': match.group(0),
'components': match.groups(),
'confidence': 0.7
})
facts.append(
{
"type": fact_type,
"text": match.group(0),
"components": match.groups(),
"confidence": 0.7,
}
)
noun_phrases = self._extract_noun_phrases(text)
for phrase in noun_phrases:
if len(phrase.split()) >= 2:
facts.append({
'type': 'entity',
'text': phrase,
'components': [phrase],
'confidence': 0.5
})
facts.append(
{
"type": "entity",
"text": phrase,
"components": [phrase],
"confidence": 0.5,
}
)
return facts
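
A quick demo of the pattern-based extraction above (output shape per `extract_facts`; the sentence is invented):

```python
fx = FactExtractor()
for fact in fx.extract_facts("Ada Lovelace is a pioneer of computing."):
    print(fact["type"], "->", fact["text"], f"({fact['confidence']})")
# definition -> Ada Lovelace is a pioneer of computing (0.7)
# entity -> Ada Lovelace (0.5)
```
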
def _extract_noun_phrases(self, text: str) -> List[str]:
sentences = re.split(r'[.!?]', text)
sentences = re.split(r"[.!?]", text)
phrases = []
for sentence in sentences:
@@ -51,25 +55,73 @@ class FactExtractor:
current_phrase.append(word)
else:
if len(current_phrase) >= 2:
phrases.append(' '.join(current_phrase))
phrases.append(" ".join(current_phrase))
current_phrase = []
if len(current_phrase) >= 2:
phrases.append(' '.join(current_phrase))
phrases.append(" ".join(current_phrase))
return list(set(phrases))
def extract_key_terms(self, text: str, top_k: int = 10) -> List[tuple]:
words = re.findall(r'\b[a-z]{4,}\b', text.lower())
words = re.findall(r"\b[a-z]{4,}\b", text.lower())
stopwords = {
'this', 'that', 'these', 'those', 'what', 'which', 'where', 'when',
'with', 'from', 'have', 'been', 'were', 'will', 'would', 'could',
'should', 'about', 'their', 'there', 'other', 'than', 'then', 'them',
'some', 'more', 'very', 'such', 'into', 'through', 'during', 'before',
'after', 'above', 'below', 'between', 'under', 'again', 'further',
'once', 'here', 'both', 'each', 'doing', 'only', 'over', 'same',
'being', 'does', 'just', 'also', 'make', 'made', 'know', 'like'
"this",
"that",
"these",
"those",
"what",
"which",
"where",
"when",
"with",
"from",
"have",
"been",
"were",
"will",
"would",
"could",
"should",
"about",
"their",
"there",
"other",
"than",
"then",
"them",
"some",
"more",
"very",
"such",
"into",
"through",
"during",
"before",
"after",
"above",
"below",
"between",
"under",
"again",
"further",
"once",
"here",
"both",
"each",
"doing",
"only",
"over",
"same",
"being",
"does",
"just",
"also",
"make",
"made",
"know",
"like",
}
filtered_words = [w for w in words if w not in stopwords]
@@ -85,57 +137,120 @@ class FactExtractor:
relationships = []
relationship_patterns = [
(r'([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)', 'employment'),
(r'([A-Z][a-z]+) (owns|has|possesses) ([^.]+)', 'ownership'),
(r'([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)', 'location'),
(r'([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)', 'usage'),
(
r"([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)",
"employment",
),
(r"([A-Z][a-z]+) (owns|has|possesses) ([^.]+)", "ownership"),
(
r"([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)",
"location",
),
(r"([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)", "usage"),
]
for pattern, rel_type in relationship_patterns:
matches = re.finditer(pattern, text)
for match in matches:
relationships.append({
'type': rel_type,
'subject': match.group(1),
'predicate': match.group(2),
'object': match.group(3),
'confidence': 0.6
})
relationships.append(
{
"type": rel_type,
"subject": match.group(1),
"predicate": match.group(2),
"object": match.group(3),
"confidence": 0.6,
}
)
return relationships
def extract_metadata(self, text: str) -> Dict[str, Any]:
word_count = len(text.split())
sentence_count = len(re.split(r'[.!?]', text))
sentence_count = len(re.split(r"[.!?]", text))
urls = re.findall(r'https?://[^\s]+', text)
email_addresses = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text)
dates = re.findall(r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b', text)
numbers = re.findall(r'\b\d+(?:,\d{3})*(?:\.\d+)?\b', text)
urls = re.findall(r"https?://[^\s]+", text)
email_addresses = re.findall(
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text
)
dates = re.findall(
r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text
)
numbers = re.findall(r"\b\d+(?:,\d{3})*(?:\.\d+)?\b", text)
return {
'word_count': word_count,
'sentence_count': sentence_count,
'avg_words_per_sentence': round(word_count / max(sentence_count, 1), 2),
'urls': urls,
'email_addresses': email_addresses,
'dates': dates,
'numeric_values': numbers,
'has_code': bool(re.search(r'```|def |class |import |function ', text)),
'has_questions': bool(re.search(r'\?', text))
"word_count": word_count,
"sentence_count": sentence_count,
"avg_words_per_sentence": round(word_count / max(sentence_count, 1), 2),
"urls": urls,
"email_addresses": email_addresses,
"dates": dates,
"numeric_values": numbers,
"has_code": bool(re.search(r"```|def |class |import |function ", text)),
"has_questions": bool(re.search(r"\?", text)),
}
def categorize_content(self, text: str) -> List[str]:
categories = []
category_keywords = {
'programming': ['code', 'function', 'class', 'variable', 'programming', 'software', 'debug'],
'data': ['data', 'database', 'query', 'table', 'record', 'statistics', 'analysis'],
'documentation': ['documentation', 'guide', 'tutorial', 'manual', 'readme', 'explain'],
'configuration': ['config', 'settings', 'configuration', 'setup', 'install', 'deployment'],
'testing': ['test', 'testing', 'validate', 'verification', 'quality', 'assertion'],
'research': ['research', 'study', 'analysis', 'investigation', 'findings', 'results'],
'planning': ['plan', 'planning', 'schedule', 'roadmap', 'milestone', 'timeline'],
"programming": [
"code",
"function",
"class",
"variable",
"programming",
"software",
"debug",
],
"data": [
"data",
"database",
"query",
"table",
"record",
"statistics",
"analysis",
],
"documentation": [
"documentation",
"guide",
"tutorial",
"manual",
"readme",
"explain",
],
"configuration": [
"config",
"settings",
"configuration",
"setup",
"install",
"deployment",
],
"testing": [
"test",
"testing",
"validate",
"verification",
"quality",
"assertion",
],
"research": [
"research",
"study",
"analysis",
"investigation",
"findings",
"results",
],
"planning": [
"plan",
"planning",
"schedule",
"roadmap",
"milestone",
"timeline",
],
}
text_lower = text.lower()
@@ -143,4 +258,4 @@ class FactExtractor:
if any(keyword in text_lower for keyword in keywords):
categories.append(category)
return categories if categories else ['general']
return categories if categories else ["general"]

View File

@@ -1,10 +1,12 @@
import json
import sqlite3
import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from .semantic_index import SemanticIndex
@dataclass
class KnowledgeEntry:
entry_id: str
@@ -18,16 +20,17 @@ class KnowledgeEntry:
def to_dict(self) -> Dict[str, Any]:
return {
'entry_id': self.entry_id,
'category': self.category,
'content': self.content,
'metadata': self.metadata,
'created_at': self.created_at,
'updated_at': self.updated_at,
'access_count': self.access_count,
'importance_score': self.importance_score
"entry_id": self.entry_id,
"category": self.category,
"content": self.content,
"metadata": self.metadata,
"created_at": self.created_at,
"updated_at": self.updated_at,
"access_count": self.access_count,
"importance_score": self.importance_score,
}
class KnowledgeStore:
def __init__(self, db_path: str):
self.db_path = db_path
@@ -39,7 +42,8 @@ class KnowledgeStore:
def _initialize_store(self):
cursor = self.conn.cursor()
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS knowledge_entries (
entry_id TEXT PRIMARY KEY,
category TEXT NOT NULL,
@@ -50,44 +54,54 @@ class KnowledgeStore:
access_count INTEGER DEFAULT 0,
importance_score REAL DEFAULT 1.0
)
''')
"""
)
cursor.execute('''
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_category ON knowledge_entries(category)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_importance ON knowledge_entries(importance_score DESC)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC)
''')
"""
)
self.conn.commit()
def _load_index(self):
cursor = self.conn.cursor()
cursor.execute('SELECT entry_id, content FROM knowledge_entries')
cursor.execute("SELECT entry_id, content FROM knowledge_entries")
for row in cursor.fetchall():
self.semantic_index.add_document(row[0], row[1])
def add_entry(self, entry: KnowledgeEntry):
cursor = self.conn.cursor()
cursor.execute('''
cursor.execute(
"""
INSERT OR REPLACE INTO knowledge_entries
(entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
entry.entry_id,
entry.category,
entry.content,
json.dumps(entry.metadata),
entry.created_at,
entry.updated_at,
entry.access_count,
entry.importance_score
))
""",
(
entry.entry_id,
entry.category,
entry.content,
json.dumps(entry.metadata),
entry.created_at,
entry.updated_at,
entry.access_count,
entry.importance_score,
),
)
self.conn.commit()
@@ -96,20 +110,26 @@ class KnowledgeStore:
def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
cursor = self.conn.cursor()
cursor.execute('''
cursor.execute(
"""
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries
WHERE entry_id = ?
''', (entry_id,))
""",
(entry_id,),
)
row = cursor.fetchone()
if row:
cursor.execute('''
cursor.execute(
"""
UPDATE knowledge_entries
SET access_count = access_count + 1
WHERE entry_id = ?
''', (entry_id,))
""",
(entry_id,),
)
self.conn.commit()
return KnowledgeEntry(
@@ -120,13 +140,14 @@ class KnowledgeStore:
created_at=row[4],
updated_at=row[5],
access_count=row[6] + 1,
importance_score=row[7]
importance_score=row[7],
)
return None
def search_entries(self, query: str, category: Optional[str] = None,
top_k: int = 5) -> List[KnowledgeEntry]:
def search_entries(
self, query: str, category: Optional[str] = None, top_k: int = 5
) -> List[KnowledgeEntry]:
search_results = self.semantic_index.search(query, top_k * 2)
cursor = self.conn.cursor()
@@ -134,17 +155,23 @@ class KnowledgeStore:
entries = []
for entry_id, score in search_results:
if category:
cursor.execute('''
cursor.execute(
"""
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries
WHERE entry_id = ? AND category = ?
''', (entry_id, category))
""",
(entry_id, category),
)
else:
cursor.execute('''
cursor.execute(
"""
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries
WHERE entry_id = ?
''', (entry_id,))
""",
(entry_id,),
)
row = cursor.fetchone()
if row:
@@ -156,7 +183,7 @@ class KnowledgeStore:
created_at=row[4],
updated_at=row[5],
access_count=row[6],
importance_score=row[7]
importance_score=row[7],
)
entries.append(entry)
@@ -168,44 +195,52 @@ class KnowledgeStore:
def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
cursor = self.conn.cursor()
cursor.execute('''
cursor.execute(
"""
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries
WHERE category = ?
ORDER BY importance_score DESC, created_at DESC
LIMIT ?
''', (category, limit))
""",
(category, limit),
)
entries = []
for row in cursor.fetchall():
entries.append(KnowledgeEntry(
entry_id=row[0],
category=row[1],
content=row[2],
metadata=json.loads(row[3]) if row[3] else {},
created_at=row[4],
updated_at=row[5],
access_count=row[6],
importance_score=row[7]
))
entries.append(
KnowledgeEntry(
entry_id=row[0],
category=row[1],
content=row[2],
metadata=json.loads(row[3]) if row[3] else {},
created_at=row[4],
updated_at=row[5],
access_count=row[6],
importance_score=row[7],
)
)
return entries
def update_importance(self, entry_id: str, importance_score: float):
cursor = self.conn.cursor()
cursor.execute('''
cursor.execute(
"""
UPDATE knowledge_entries
SET importance_score = ?, updated_at = ?
WHERE entry_id = ?
''', (importance_score, time.time(), entry_id))
""",
(importance_score, time.time(), entry_id),
)
self.conn.commit()
def delete_entry(self, entry_id: str) -> bool:
cursor = self.conn.cursor()
cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,))
cursor.execute("DELETE FROM knowledge_entries WHERE entry_id = ?", (entry_id,))
deleted = cursor.rowcount > 0
self.conn.commit()
@@ -218,27 +253,29 @@ class KnowledgeStore:
def get_statistics(self) -> Dict[str, Any]:
cursor = self.conn.cursor()
cursor.execute('SELECT COUNT(*) FROM knowledge_entries')
cursor.execute("SELECT COUNT(*) FROM knowledge_entries")
total_entries = cursor.fetchone()[0]
cursor.execute('SELECT COUNT(DISTINCT category) FROM knowledge_entries')
cursor.execute("SELECT COUNT(DISTINCT category) FROM knowledge_entries")
total_categories = cursor.fetchone()[0]
cursor.execute('''
cursor.execute(
"""
SELECT category, COUNT(*) as count
FROM knowledge_entries
GROUP BY category
ORDER BY count DESC
''')
"""
)
category_counts = {row[0]: row[1] for row in cursor.fetchall()}
cursor.execute('SELECT SUM(access_count) FROM knowledge_entries')
cursor.execute("SELECT SUM(access_count) FROM knowledge_entries")
total_accesses = cursor.fetchone()[0] or 0
return {
'total_entries': total_entries,
'total_categories': total_categories,
'category_distribution': category_counts,
'total_accesses': total_accesses,
'vocabulary_size': len(self.semantic_index.vocabulary)
"total_entries": total_entries,
"total_categories": total_categories,
"category_distribution": category_counts,
"total_accesses": total_accesses,
"vocabulary_size": len(self.semantic_index.vocabulary),
}
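
A usage sketch for KnowledgeStore plus the dataclass above. Two assumptions to note: `":memory:"` only works if `db_path` is passed straight to `sqlite3.connect`, and search only finds the entry if `add_entry` also feeds the semantic index, as the presence of `_load_index` suggests. Field order follows `KnowledgeEntry`:

```python
import time
import uuid

store = KnowledgeStore(":memory:")
now = time.time()
store.add_entry(
    KnowledgeEntry(
        entry_id=str(uuid.uuid4()),
        category="programming",
        content="TF-IDF weighs rare terms more heavily than common ones",
        metadata={},
        created_at=now,
        updated_at=now,
        access_count=0,
        importance_score=1.0,
    )
)
for entry in store.search_entries("rare terms tf-idf", top_k=3):
    print(entry.category, "|", entry.content)
```
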

View File

@ -1,7 +1,8 @@
import math
import re
from collections import Counter, defaultdict
from typing import List, Dict, Tuple, Set
from typing import Dict, List, Set, Tuple
class SemanticIndex:
def __init__(self):
@@ -12,7 +13,7 @@ class SemanticIndex:
def _tokenize(self, text: str) -> List[str]:
text = text.lower()
text = re.sub(r'[^a-z0-9\s]', ' ', text)
text = re.sub(r"[^a-z0-9\s]", " ", text)
tokens = text.split()
return tokens
@@ -78,8 +79,12 @@ class SemanticIndex:
scores.sort(key=lambda x: x[1], reverse=True)
return scores[:top_k]
def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
dot_product = sum(vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2))
def _cosine_similarity(
self, vec1: Dict[str, float], vec2: Dict[str, float]
) -> float:
dot_product = sum(
vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2)
)
norm1 = math.sqrt(sum(val**2 for val in vec1.values()))
norm2 = math.sqrt(sum(val**2 for val in vec2.values()))
if norm1 == 0 or norm2 == 0:
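
The similarity above is standard sparse cosine over token-to-weight dicts. A self-contained version with a tiny worked example:

```python
import math


def cosine(v1, v2):
    """Sparse cosine similarity over token -> weight dicts, as above."""
    dot = sum(v1.get(t, 0.0) * v2.get(t, 0.0) for t in set(v1) | set(v2))
    n1 = math.sqrt(sum(x * x for x in v1.values()))
    n2 = math.sqrt(sum(x * x for x in v2.values()))
    return 0.0 if n1 == 0.0 or n2 == 0.0 else dot / (n1 * n2)


print(round(cosine({"tf": 0.5, "idf": 0.8}, {"idf": 0.8, "cosine": 0.1}), 3))  # ~0.841
```
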

View File

@@ -1,14 +1,13 @@
import threading
import queue
import time
import sys
import subprocess
import signal
import os
from pr.ui import Colors
from collections import defaultdict
from pr.tools.process_handlers import get_handler_for_process, detect_process_type
import sys
import threading
import time
from pr.tools.process_handlers import detect_process_type, get_handler_for_process
from pr.tools.prompt_detection import get_global_detector
from pr.ui import Colors
class TerminalMultiplexer:
def __init__(self, name, show_output=True):
@ -21,17 +20,19 @@ class TerminalMultiplexer:
self.active = True
self.lock = threading.Lock()
self.metadata = {
'start_time': time.time(),
'last_activity': time.time(),
'interaction_count': 0,
'process_type': 'unknown',
'state': 'active'
"start_time": time.time(),
"last_activity": time.time(),
"interaction_count": 0,
"process_type": "unknown",
"state": "active",
}
self.handler = None
self.prompt_detector = get_global_detector()
if self.show_output:
self.display_thread = threading.Thread(target=self._display_worker, daemon=True)
self.display_thread = threading.Thread(
target=self._display_worker, daemon=True
)
self.display_thread.start()
def _display_worker(self):
@ -47,7 +48,9 @@ class TerminalMultiplexer:
try:
line = self.stderr_queue.get(timeout=0.1)
if line:
sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}")
sys.stderr.write(
f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}"
)
sys.stderr.flush()
except queue.Empty:
pass
@ -55,40 +58,44 @@ class TerminalMultiplexer:
def write_stdout(self, data):
with self.lock:
self.stdout_buffer.append(data)
self.metadata['last_activity'] = time.time()
self.metadata["last_activity"] = time.time()
# Update handler state if available
if self.handler:
self.handler.update_state(data)
# Update prompt detector
self.prompt_detector.update_session_state(self.name, data, self.metadata['process_type'])
self.prompt_detector.update_session_state(
self.name, data, self.metadata["process_type"]
)
if self.show_output:
self.stdout_queue.put(data)
def write_stderr(self, data):
with self.lock:
self.stderr_buffer.append(data)
self.metadata['last_activity'] = time.time()
self.metadata["last_activity"] = time.time()
# Update handler state if available
if self.handler:
self.handler.update_state(data)
# Update prompt detector
self.prompt_detector.update_session_state(self.name, data, self.metadata['process_type'])
self.prompt_detector.update_session_state(
self.name, data, self.metadata["process_type"]
)
if self.show_output:
self.stderr_queue.put(data)
def get_stdout(self):
with self.lock:
return ''.join(self.stdout_buffer)
return "".join(self.stdout_buffer)
def get_stderr(self):
with self.lock:
return ''.join(self.stderr_buffer)
return "".join(self.stderr_buffer)
def get_all_output(self):
with self.lock:
return {
'stdout': ''.join(self.stdout_buffer),
'stderr': ''.join(self.stderr_buffer)
"stdout": "".join(self.stdout_buffer),
"stderr": "".join(self.stderr_buffer),
}
def get_metadata(self):
@ -102,31 +109,32 @@ class TerminalMultiplexer:
def set_process_type(self, process_type):
"""Set the process type and initialize appropriate handler."""
with self.lock:
self.metadata['process_type'] = process_type
self.metadata["process_type"] = process_type
self.handler = get_handler_for_process(process_type, self)
def send_input(self, input_data):
if hasattr(self, 'process') and self.process.poll() is None:
if hasattr(self, "process") and self.process.poll() is None:
try:
self.process.stdin.write(input_data + '\n')
self.process.stdin.write(input_data + "\n")
self.process.stdin.flush()
with self.lock:
self.metadata['last_activity'] = time.time()
self.metadata['interaction_count'] += 1
self.metadata["last_activity"] = time.time()
self.metadata["interaction_count"] += 1
except Exception as e:
self.write_stderr(f"Error sending input: {e}")
else:
# This will be implemented when we have a process attached
# For now, just update activity
with self.lock:
self.metadata['last_activity'] = time.time()
self.metadata['interaction_count'] += 1
self.metadata["last_activity"] = time.time()
self.metadata["interaction_count"] += 1
def close(self):
self.active = False
if hasattr(self, 'display_thread'):
if hasattr(self, "display_thread"):
self.display_thread.join(timeout=1)
_multiplexers = {}
_mux_counter = 0
_mux_lock = threading.Lock()
@ -134,6 +142,7 @@ _background_monitor = None
_monitor_active = False
_monitor_interval = 0.2 # 200ms
def create_multiplexer(name=None, show_output=True):
global _mux_counter
with _mux_lock:
@ -144,44 +153,50 @@ def create_multiplexer(name=None, show_output=True):
_multiplexers[name] = mux
return name, mux
def get_multiplexer(name):
return _multiplexers.get(name)
def close_multiplexer(name):
mux = _multiplexers.get(name)
if mux:
mux.close()
del _multiplexers[name]
def get_all_multiplexer_states():
with _mux_lock:
states = {}
for name, mux in _multiplexers.items():
states[name] = {
'metadata': mux.get_metadata(),
'output_summary': {
'stdout_lines': len(mux.stdout_buffer),
'stderr_lines': len(mux.stderr_buffer)
}
"metadata": mux.get_metadata(),
"output_summary": {
"stdout_lines": len(mux.stdout_buffer),
"stderr_lines": len(mux.stderr_buffer),
},
}
return states
def cleanup_all_multiplexers():
for mux in list(_multiplexers.values()):
mux.close()
_multiplexers.clear()
# Background process management
_background_processes = {}
_process_lock = threading.Lock()
class BackgroundProcess:
def __init__(self, name, command):
self.name = name
self.command = command
self.process = None
self.multiplexer = None
self.status = 'starting'
self.status = "starting"
self.start_time = time.time()
self.end_time = None
@ -205,27 +220,27 @@ class BackgroundProcess:
stderr=subprocess.PIPE,
text=True,
bufsize=1,
universal_newlines=True
universal_newlines=True,
)
self.status = 'running'
self.status = "running"
# Start output monitoring threads
threading.Thread(target=self._monitor_stdout, daemon=True).start()
threading.Thread(target=self._monitor_stderr, daemon=True).start()
return {'status': 'success', 'pid': self.process.pid}
return {"status": "success", "pid": self.process.pid}
except Exception as e:
self.status = 'error'
return {'status': 'error', 'error': str(e)}
self.status = "error"
return {"status": "error", "error": str(e)}
def _monitor_stdout(self):
"""Monitor stdout from the process."""
try:
for line in iter(self.process.stdout.readline, ''):
for line in iter(self.process.stdout.readline, ""):
if line:
self.multiplexer.write_stdout(line.rstrip('\n\r'))
self.multiplexer.write_stdout(line.rstrip("\n\r"))
except Exception as e:
self.write_stderr(f"Error reading stdout: {e}")
finally:
@ -234,29 +249,33 @@ class BackgroundProcess:
def _monitor_stderr(self):
"""Monitor stderr from the process."""
try:
for line in iter(self.process.stderr.readline, ''):
for line in iter(self.process.stderr.readline, ""):
if line:
self.multiplexer.write_stderr(line.rstrip('\n\r'))
self.multiplexer.write_stderr(line.rstrip("\n\r"))
except Exception as e:
self.write_stderr(f"Error reading stderr: {e}")
def _check_completion(self):
"""Check if process has completed."""
if self.process and self.process.poll() is not None:
self.status = 'completed'
self.status = "completed"
self.end_time = time.time()
def get_info(self):
"""Get process information."""
self._check_completion()
return {
'name': self.name,
'command': self.command,
'status': self.status,
'pid': self.process.pid if self.process else None,
'start_time': self.start_time,
'end_time': self.end_time,
'runtime': time.time() - self.start_time if not self.end_time else self.end_time - self.start_time
"name": self.name,
"command": self.command,
"status": self.status,
"pid": self.process.pid if self.process else None,
"start_time": self.start_time,
"end_time": self.end_time,
"runtime": (
time.time() - self.start_time
if not self.end_time
else self.end_time - self.start_time
),
}
def get_output(self, lines=None):
@ -265,8 +284,8 @@ class BackgroundProcess:
return []
all_output = self.multiplexer.get_all_output()
stdout_lines = all_output['stdout'].split('\n') if all_output['stdout'] else []
stderr_lines = all_output['stderr'].split('\n') if all_output['stderr'] else []
stdout_lines = all_output["stdout"].split("\n") if all_output["stdout"] else []
stderr_lines = all_output["stderr"].split("\n") if all_output["stderr"] else []
combined = stdout_lines + stderr_lines
if lines:
@ -276,45 +295,47 @@ class BackgroundProcess:
def send_input(self, input_text):
"""Send input to the process."""
if self.process and self.status == 'running':
if self.process and self.status == "running":
try:
self.process.stdin.write(input_text + '\n')
self.process.stdin.write(input_text + "\n")
self.process.stdin.flush()
return {'status': 'success'}
return {"status": "success"}
except Exception as e:
return {'status': 'error', 'error': str(e)}
return {'status': 'error', 'error': 'Process not running or no stdin'}
return {"status": "error", "error": str(e)}
return {"status": "error", "error": "Process not running or no stdin"}
def kill(self):
"""Kill the process."""
if self.process and self.status == 'running':
if self.process and self.status == "running":
try:
self.process.terminate()
# Wait a bit for graceful termination
time.sleep(0.1)
if self.process.poll() is None:
self.process.kill()
self.status = 'killed'
self.status = "killed"
self.end_time = time.time()
return {'status': 'success'}
return {"status": "success"}
except Exception as e:
return {'status': 'error', 'error': str(e)}
return {'status': 'error', 'error': 'Process not running'}
return {"status": "error", "error": str(e)}
return {"status": "error", "error": "Process not running"}
def start_background_process(name, command):
"""Start a background process."""
with _process_lock:
if name in _background_processes:
return {'status': 'error', 'error': f'Process {name} already exists'}
return {"status": "error", "error": f"Process {name} already exists"}
process = BackgroundProcess(name, command)
result = process.start()
if result['status'] == 'success':
if result["status"] == "success":
_background_processes[name] = process
return result
def get_all_sessions():
"""Get all background process sessions."""
with _process_lock:
@ -323,23 +344,31 @@ def get_all_sessions():
sessions[name] = process.get_info()
return sessions
def get_session_info(name):
"""Get information about a specific session."""
with _process_lock:
process = _background_processes.get(name)
return process.get_info() if process else None
def get_session_output(name, lines=None):
"""Get output from a specific session."""
with _process_lock:
process = _background_processes.get(name)
return process.get_output(lines) if process else None
def send_input_to_session(name, input_text):
"""Send input to a background session."""
with _process_lock:
process = _background_processes.get(name)
return process.send_input(input_text) if process else {'status': 'error', 'error': 'Session not found'}
return (
process.send_input(input_text)
if process
else {"status": "error", "error": "Session not found"}
)
def kill_session(name):
"""Kill a background session."""
@ -347,7 +376,7 @@ def kill_session(name):
process = _background_processes.get(name)
if process:
result = process.kill()
if result['status'] == 'success':
if result["status"] == "success":
del _background_processes[name]
return result
return {'status': 'error', 'error': 'Session not found'}
return {"status": "error", "error": "Session not found"}
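
Taken together, a short sketch of the background-process API this file now exposes (module path assumed to be `pr.multiplexer`, matching the imports elsewhere in this commit):

```python
import time

from pr.multiplexer import (
    get_session_output,
    kill_session,
    start_background_process,
)

result = start_background_process("pinger", "ping -c 3 127.0.0.1")
if result["status"] == "success":
    time.sleep(1.0)  # give the monitor threads a moment to buffer output
    for line in get_session_output("pinger", lines=5) or []:
        print(line)
    # Returns an error dict rather than raising if the process already exited.
    print(kill_session("pinger"))
else:
    print("failed to start:", result["error"])
```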

View File

@ -1,10 +1,11 @@
import importlib.util
import os
import sys
import importlib.util
from typing import List, Dict, Callable, Any
from typing import Callable, Dict, List
from pr.core.logging import get_logger
logger = get_logger('plugins')
logger = get_logger("plugins")
PLUGINS_DIR = os.path.expanduser("~/.pr/plugins")
@ -21,7 +22,7 @@ class PluginLoader:
logger.info("No plugins directory found")
return []
plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith('.py')]
plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith(".py")]
for plugin_file in plugin_files:
try:
@ -44,16 +45,20 @@ class PluginLoader:
sys.modules[plugin_name] = module
spec.loader.exec_module(module)
if hasattr(module, 'register_tools'):
if hasattr(module, "register_tools"):
tools = module.register_tools()
if isinstance(tools, list):
self.plugin_tools.extend(tools)
self.loaded_plugins[plugin_name] = module
logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)")
else:
logger.warning(f"Plugin {plugin_name} register_tools() did not return a list")
logger.warning(
f"Plugin {plugin_name} register_tools() did not return a list"
)
else:
logger.warning(f"Plugin {plugin_name} does not have register_tools() function")
logger.warning(
f"Plugin {plugin_name} does not have register_tools() function"
)
def get_plugin_function(self, tool_name: str) -> Callable:
for plugin_name, module in self.loaded_plugins.items():
@ -67,7 +72,7 @@ class PluginLoader:
def create_example_plugin():
example_plugin = os.path.join(PLUGINS_DIR, 'example_plugin.py')
example_plugin = os.path.join(PLUGINS_DIR, "example_plugin.py")
if os.path.exists(example_plugin):
return
@ -121,7 +126,7 @@ def register_tools():
try:
os.makedirs(PLUGINS_DIR, exist_ok=True)
with open(example_plugin, 'w') as f:
with open(example_plugin, "w") as f:
f.write(example_code)
logger.info(f"Created example plugin at {example_plugin}")
except Exception as e:
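
A loading sketch for the loader above. The discovery method's name falls outside this hunk, so `load_plugins()` is an assumption based on the early `return []`; the printed key assumes the OpenAI-style tool dicts used in `pr/tools/base.py`, and `example_tool` is a hypothetical tool name:

```python
from pr.core.plugins import PluginLoader, create_example_plugin

create_example_plugin()  # writes ~/.pr/plugins/example_plugin.py if absent

loader = PluginLoader()
loader.load_plugins()  # assumed name; scans PLUGINS_DIR for *.py files

for tool in loader.plugin_tools:
    print(tool["function"]["name"])  # assumes OpenAI-style definitions
print(loader.get_plugin_function("example_tool"))  # hypothetical tool name
```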

View File

@ -1,25 +1,86 @@
from pr.tools.base import get_tools_definition
from pr.tools.filesystem import (
read_file, write_file, list_directory, mkdir, chdir, getpwd, index_source_directory, search_replace
from pr.tools.agents import (
collaborate_agents,
create_agent,
execute_agent_task,
list_agents,
remove_agent,
)
from pr.tools.base import get_tools_definition
from pr.tools.command import (
kill_process,
run_command,
run_command_interactive,
tail_process,
)
from pr.tools.database import db_get, db_query, db_set
from pr.tools.editor import (
close_editor,
editor_insert_text,
editor_replace_text,
editor_search,
open_editor,
)
from pr.tools.filesystem import (
chdir,
getpwd,
index_source_directory,
list_directory,
mkdir,
read_file,
search_replace,
write_file,
)
from pr.tools.memory import (
add_knowledge_entry,
delete_knowledge_entry,
get_knowledge_by_category,
get_knowledge_entry,
get_knowledge_statistics,
search_knowledge,
update_knowledge_importance,
)
from pr.tools.command import run_command, run_command_interactive, tail_process, kill_process
from pr.tools.editor import open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor
from pr.tools.database import db_set, db_get, db_query
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.tools.python_exec import python_exec
from pr.tools.patch import apply_patch, create_diff
from pr.tools.agents import create_agent, list_agents, execute_agent_task, remove_agent, collaborate_agents
from pr.tools.memory import add_knowledge_entry, get_knowledge_entry, search_knowledge, get_knowledge_by_category, update_knowledge_importance, delete_knowledge_entry, get_knowledge_statistics
from pr.tools.python_exec import python_exec
from pr.tools.web import http_fetch, web_search, web_search_news
__all__ = [
'get_tools_definition',
'read_file', 'write_file', 'list_directory', 'mkdir', 'chdir', 'getpwd', 'index_source_directory', 'search_replace',
'open_editor', 'editor_insert_text', 'editor_replace_text', 'editor_search','close_editor',
'run_command', 'run_command_interactive',
'db_set', 'db_get', 'db_query',
'http_fetch', 'web_search', 'web_search_news',
'python_exec','tail_process', 'kill_process',
'apply_patch', 'create_diff',
'create_agent', 'list_agents', 'execute_agent_task', 'remove_agent', 'collaborate_agents',
'add_knowledge_entry', 'get_knowledge_entry', 'search_knowledge', 'get_knowledge_by_category', 'update_knowledge_importance', 'delete_knowledge_entry', 'get_knowledge_statistics'
"get_tools_definition",
"read_file",
"write_file",
"list_directory",
"mkdir",
"chdir",
"getpwd",
"index_source_directory",
"search_replace",
"open_editor",
"editor_insert_text",
"editor_replace_text",
"editor_search",
"close_editor",
"run_command",
"run_command_interactive",
"db_set",
"db_get",
"db_query",
"http_fetch",
"web_search",
"web_search_news",
"python_exec",
"tail_process",
"kill_process",
"apply_patch",
"create_diff",
"create_agent",
"list_agents",
"execute_agent_task",
"remove_agent",
"collaborate_agents",
"add_knowledge_entry",
"get_knowledge_entry",
"search_knowledge",
"get_knowledge_by_category",
"update_knowledge_importance",
"delete_knowledge_entry",
"get_knowledge_statistics",
]
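
With the re-export list now explicit, a quick check that the flat package namespace still resolves:

```python
from pr.tools import getpwd, list_directory

print(getpwd())             # {"status": "success", "path": "..."}
print(list_directory("."))  # {"status": "success", "items": [...]}
```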

View File

@ -1,13 +1,15 @@
import os
from typing import Dict, Any, List
from typing import Any, Dict, List
from pr.agents.agent_manager import AgentManager
from pr.core.api import call_api
def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]:
"""Create a new agent with the specified role."""
try:
# Get db_path from environment or default
db_path = os.environ.get('ASSISTANT_DB_PATH', '~/.assistant_db.sqlite')
db_path = os.environ.get("ASSISTANT_DB_PATH", "~/.assistant_db.sqlite")
db_path = os.path.expanduser(db_path)
manager = AgentManager(db_path, call_api)
@ -16,47 +18,57 @@ def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]:
except Exception as e:
return {"status": "error", "error": str(e)}
def list_agents() -> Dict[str, Any]:
"""List all active agents."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
manager = AgentManager(db_path, call_api)
agents = []
for agent_id, agent in manager.active_agents.items():
agents.append({
"agent_id": agent_id,
"role": agent.role.name,
"task_count": agent.task_count,
"message_count": len(agent.message_history)
})
agents.append(
{
"agent_id": agent_id,
"role": agent.role.name,
"task_count": agent.task_count,
"message_count": len(agent.message_history),
}
)
return {"status": "success", "agents": agents}
except Exception as e:
return {"status": "error", "error": str(e)}
def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
def execute_agent_task(
agent_id: str, task: str, context: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Execute a task with the specified agent."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
manager = AgentManager(db_path, call_api)
result = manager.execute_agent_task(agent_id, task, context)
return result
except Exception as e:
return {"status": "error", "error": str(e)}
def remove_agent(agent_id: str) -> Dict[str, Any]:
"""Remove an agent."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
manager = AgentManager(db_path, call_api)
success = manager.remove_agent(agent_id)
return {"status": "success" if success else "not_found", "agent_id": agent_id}
except Exception as e:
return {"status": "error", "error": str(e)}
def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]:
def collaborate_agents(
orchestrator_id: str, task: str, agent_roles: List[str]
) -> Dict[str, Any]:
"""Collaborate multiple agents on a task."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
manager = AgentManager(db_path, call_api)
result = manager.collaborate_agents(orchestrator_id, task, agent_roles)
return result
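
An end-to-end sketch of the agent wrappers. They build an `AgentManager` around `call_api`, so a configured endpoint is required; the success shape of `create_agent` is truncated out of this diff, so the `agent_id` lookup below is an assumption, and the "coding" role name is illustrative:

```python
from pr.tools.agents import (
    create_agent,
    execute_agent_task,
    list_agents,
    remove_agent,
)

created = create_agent("coding")  # role name is illustrative
if created.get("status") != "error":
    agent_id = created.get("agent_id", "coding-1")  # assumed key
    print(list_agents())
    print(execute_agent_task(agent_id, "Summarize the repository layout"))
    print(remove_agent(agent_id))
```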

View File

@ -10,12 +10,12 @@ def get_tools_definition():
"properties": {
"pid": {
"type": "integer",
"description": "The process ID returned by run_command when status is 'running'."
"description": "The process ID returned by run_command when status is 'running'.",
}
},
"required": ["pid"]
}
}
"required": ["pid"],
},
},
},
{
"type": "function",
@ -27,17 +27,17 @@ def get_tools_definition():
"properties": {
"pid": {
"type": "integer",
"description": "The process ID returned by run_command when status is 'running'."
"description": "The process ID returned by run_command when status is 'running'.",
},
"timeout": {
"type": "integer",
"description": "Maximum seconds to wait for process completion. Returns partial output if still running.",
"default": 30
}
"default": 30,
},
},
"required": ["pid"]
}
}
"required": ["pid"],
},
},
},
{
"type": "function",
@ -48,11 +48,14 @@ def get_tools_definition():
"type": "object",
"properties": {
"url": {"type": "string", "description": "The URL to fetch"},
"headers": {"type": "object", "description": "Optional HTTP headers"}
"headers": {
"type": "object",
"description": "Optional HTTP headers",
},
},
"required": ["url"]
}
}
"required": ["url"],
},
},
},
{
"type": "function",
@ -62,12 +65,19 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"command": {"type": "string", "description": "The shell command to execute"},
"timeout": {"type": "integer", "description": "Maximum seconds to wait for completion", "default": 30}
"command": {
"type": "string",
"description": "The shell command to execute",
},
"timeout": {
"type": "integer",
"description": "Maximum seconds to wait for completion",
"default": 30,
},
},
"required": ["command"]
}
}
"required": ["command"],
},
},
},
{
"type": "function",
@ -77,11 +87,14 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"command": {"type": "string", "description": "The interactive command to execute (e.g., vim, nano, top)"}
"command": {
"type": "string",
"description": "The interactive command to execute (e.g., vim, nano, top)",
}
},
"required": ["command"]
}
}
"required": ["command"],
},
},
},
{
"type": "function",
@ -91,12 +104,18 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"session_name": {"type": "string", "description": "The name of the session"},
"input_data": {"type": "string", "description": "The input to send to the session"}
"session_name": {
"type": "string",
"description": "The name of the session",
},
"input_data": {
"type": "string",
"description": "The input to send to the session",
},
},
"required": ["session_name", "input_data"]
}
}
"required": ["session_name", "input_data"],
},
},
},
{
"type": "function",
@ -106,11 +125,14 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"session_name": {"type": "string", "description": "The name of the session"}
"session_name": {
"type": "string",
"description": "The name of the session",
}
},
"required": ["session_name"]
}
}
"required": ["session_name"],
},
},
},
{
"type": "function",
@ -120,11 +142,14 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"session_name": {"type": "string", "description": "The name of the session"}
"session_name": {
"type": "string",
"description": "The name of the session",
}
},
"required": ["session_name"]
}
}
"required": ["session_name"],
},
},
},
{
"type": "function",
@ -134,11 +159,14 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"}
"filepath": {
"type": "string",
"description": "Path to the file",
}
},
"required": ["filepath"]
}
}
"required": ["filepath"],
},
},
},
{
"type": "function",
@ -148,12 +176,18 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"content": {"type": "string", "description": "Content to write"}
"filepath": {
"type": "string",
"description": "Path to the file",
},
"content": {
"type": "string",
"description": "Content to write",
},
},
"required": ["filepath", "content"]
}
}
"required": ["filepath", "content"],
},
},
},
{
"type": "function",
@ -163,11 +197,19 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Directory path", "default": "."},
"recursive": {"type": "boolean", "description": "List recursively", "default": False}
}
}
}
"path": {
"type": "string",
"description": "Directory path",
"default": ".",
},
"recursive": {
"type": "boolean",
"description": "List recursively",
"default": False,
},
},
},
},
},
{
"type": "function",
@ -177,11 +219,14 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path of the directory to create"}
"path": {
"type": "string",
"description": "Path of the directory to create",
}
},
"required": ["path"]
}
}
"required": ["path"],
},
},
},
{
"type": "function",
@ -193,17 +238,17 @@ def get_tools_definition():
"properties": {
"path": {"type": "string", "description": "Path to change to"}
},
"required": ["path"]
}
}
"required": ["path"],
},
},
},
{
"type": "function",
"function": {
"name": "getpwd",
"description": "Get the current working directory",
"parameters": {"type": "object", "properties": {}}
}
"parameters": {"type": "object", "properties": {}},
},
},
{
"type": "function",
@ -214,11 +259,11 @@ def get_tools_definition():
"type": "object",
"properties": {
"key": {"type": "string", "description": "The key"},
"value": {"type": "string", "description": "The value"}
"value": {"type": "string", "description": "The value"},
},
"required": ["key", "value"]
}
}
"required": ["key", "value"],
},
},
},
{
"type": "function",
@ -227,12 +272,10 @@ def get_tools_definition():
"description": "Get a value from the database",
"parameters": {
"type": "object",
"properties": {
"key": {"type": "string", "description": "The key"}
},
"required": ["key"]
}
}
"properties": {"key": {"type": "string", "description": "The key"}},
"required": ["key"],
},
},
},
{
"type": "function",
@ -244,9 +287,9 @@ def get_tools_definition():
"properties": {
"query": {"type": "string", "description": "SQL query"}
},
"required": ["query"]
}
}
"required": ["query"],
},
},
},
{
"type": "function",
@ -258,9 +301,9 @@ def get_tools_definition():
"properties": {
"query": {"type": "string", "description": "Search query"}
},
"required": ["query"]
}
}
"required": ["query"],
},
},
},
{
"type": "function",
@ -270,11 +313,14 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query for news"}
"query": {
"type": "string",
"description": "Search query for news",
}
},
"required": ["query"]
}
}
"required": ["query"],
},
},
},
{
"type": "function",
@ -284,11 +330,14 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"code": {"type": "string", "description": "Python code to execute"}
"code": {
"type": "string",
"description": "Python code to execute",
}
},
"required": ["code"]
}
}
"required": ["code"],
},
},
},
{
"type": "function",
@ -300,9 +349,9 @@ def get_tools_definition():
"properties": {
"path": {"type": "string", "description": "Path to index"}
},
"required": ["path"]
}
}
"required": ["path"],
},
},
},
{
"type": "function",
@ -312,13 +361,22 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"old_string": {"type": "string", "description": "String to replace"},
"new_string": {"type": "string", "description": "Replacement string"}
"filepath": {
"type": "string",
"description": "Path to the file",
},
"old_string": {
"type": "string",
"description": "String to replace",
},
"new_string": {
"type": "string",
"description": "Replacement string",
},
},
"required": ["filepath", "old_string", "new_string"]
}
}
"required": ["filepath", "old_string", "new_string"],
},
},
},
{
"type": "function",
@ -328,12 +386,18 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file to patch"},
"patch_content": {"type": "string", "description": "The patch content as a string"}
"filepath": {
"type": "string",
"description": "Path to the file to patch",
},
"patch_content": {
"type": "string",
"description": "The patch content as a string",
},
},
"required": ["filepath", "patch_content"]
}
}
"required": ["filepath", "patch_content"],
},
},
},
{
"type": "function",
@ -343,14 +407,28 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"file1": {"type": "string", "description": "Path to the first file"},
"file2": {"type": "string", "description": "Path to the second file"},
"fromfile": {"type": "string", "description": "Label for the first file", "default": "file1"},
"tofile": {"type": "string", "description": "Label for the second file", "default": "file2"}
"file1": {
"type": "string",
"description": "Path to the first file",
},
"file2": {
"type": "string",
"description": "Path to the second file",
},
"fromfile": {
"type": "string",
"description": "Label for the first file",
"default": "file1",
},
"tofile": {
"type": "string",
"description": "Label for the second file",
"default": "file2",
},
},
"required": ["file1", "file2"]
}
}
"required": ["file1", "file2"],
},
},
},
{
"type": "function",
@ -360,11 +438,14 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"}
"filepath": {
"type": "string",
"description": "Path to the file",
}
},
"required": ["filepath"]
}
}
"required": ["filepath"],
},
},
},
{
"type": "function",
@ -374,11 +455,14 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"}
"filepath": {
"type": "string",
"description": "Path to the file",
}
},
"required": ["filepath"]
}
}
"required": ["filepath"],
},
},
},
{
"type": "function",
@ -388,14 +472,23 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"filepath": {
"type": "string",
"description": "Path to the file",
},
"text": {"type": "string", "description": "Text to insert"},
"line": {"type": "integer", "description": "Line number (optional)"},
"col": {"type": "integer", "description": "Column number (optional)"}
"line": {
"type": "integer",
"description": "Line number (optional)",
},
"col": {
"type": "integer",
"description": "Column number (optional)",
},
},
"required": ["filepath", "text"]
}
}
"required": ["filepath", "text"],
},
},
},
{
"type": "function",
@ -405,16 +498,26 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"filepath": {
"type": "string",
"description": "Path to the file",
},
"start_line": {"type": "integer", "description": "Start line"},
"start_col": {"type": "integer", "description": "Start column"},
"end_line": {"type": "integer", "description": "End line"},
"end_col": {"type": "integer", "description": "End column"},
"new_text": {"type": "string", "description": "New text"}
"new_text": {"type": "string", "description": "New text"},
},
"required": ["filepath", "start_line", "start_col", "end_line", "end_col", "new_text"]
}
}
"required": [
"filepath",
"start_line",
"start_col",
"end_line",
"end_col",
"new_text",
],
},
},
},
{
"type": "function",
@ -424,13 +527,20 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"filepath": {
"type": "string",
"description": "Path to the file",
},
"pattern": {"type": "string", "description": "Regex pattern"},
"start_line": {"type": "integer", "description": "Start line", "default": 0}
"start_line": {
"type": "integer",
"description": "Start line",
"default": 0,
},
},
"required": ["filepath", "pattern"]
}
}
"required": ["filepath", "pattern"],
},
},
},
{
"type": "function",
@ -440,24 +550,31 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"filepath1": {"type": "string", "description": "Path to the original file"},
"filepath2": {"type": "string", "description": "Path to the modified file"},
"format_type": {"type": "string", "description": "Display format: 'unified' or 'side-by-side'", "default": "unified"}
"filepath1": {
"type": "string",
"description": "Path to the original file",
},
"filepath2": {
"type": "string",
"description": "Path to the modified file",
},
"format_type": {
"type": "string",
"description": "Display format: 'unified' or 'side-by-side'",
"default": "unified",
},
},
"required": ["filepath1", "filepath2"]
}
}
"required": ["filepath1", "filepath2"],
},
},
},
{
"type": "function",
"function": {
"name": "display_edit_summary",
"description": "Display a summary of all edit operations performed during the session",
"parameters": {
"type": "object",
"properties": {}
}
}
"parameters": {"type": "object", "properties": {}},
},
},
{
"type": "function",
@ -467,21 +584,21 @@ def get_tools_definition():
"parameters": {
"type": "object",
"properties": {
"show_content": {"type": "boolean", "description": "Show content previews", "default": False}
}
}
}
"show_content": {
"type": "boolean",
"description": "Show content previews",
"default": False,
}
},
},
},
},
{
"type": "function",
"function": {
"name": "clear_edit_tracker",
"description": "Clear the edit tracker to start fresh",
"parameters": {
"type": "object",
"properties": {}
}
}
}
"parameters": {"type": "object", "properties": {}},
},
},
]
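
Every entry above follows the OpenAI function-calling layout, so a plugin or a new built-in adds one more dict of the same shape. A minimal, purely illustrative example (the tool name and parameter are not part of this commit):

```python
word_count_tool = {
    "type": "function",
    "function": {
        "name": "word_count",
        "description": "Count the words in a text file",
        "parameters": {
            "type": "object",
            "properties": {
                "filepath": {
                    "type": "string",
                    "description": "Path to the file",
                }
            },
            "required": ["filepath"],
        },
    },
}
```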

View File

@ -1,21 +1,23 @@
import os
import select
import subprocess
import time
import select
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer
from pr.tools.interactive_control import start_interactive_session
from pr.config import MAX_CONCURRENT_SESSIONS
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
_processes = {}
def _register_process(pid:int, process):
def _register_process(pid: int, process):
_processes[pid] = process
return _processes
def _get_process(pid:int):
def _get_process(pid: int):
return _processes.get(pid)
def kill_process(pid:int):
def kill_process(pid: int):
try:
process = _get_process(pid)
if process:
@ -67,7 +69,7 @@ def tail_process(pid: int, timeout: int = 30):
"status": "success",
"stdout": stdout_content,
"stderr": stderr_content,
"returncode": process.returncode
"returncode": process.returncode,
}
if time.time() - start_time > timeout_duration:
@ -76,10 +78,12 @@ def tail_process(pid: int, timeout: int = 30):
"message": "Process is still running. Call tail_process again to continue monitoring.",
"stdout_so_far": stdout_content,
"stderr_so_far": stderr_content,
"pid": pid
"pid": pid,
}
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
ready, _, _ = select.select(
[process.stdout, process.stderr], [], [], 0.1
)
for pipe in ready:
if pipe == process.stdout:
line = process.stdout.readline()
@ -100,7 +104,13 @@ def tail_process(pid: int, timeout: int = 30):
def run_command(command, timeout=30, monitored=False):
mux_name = None
try:
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
process = subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
_register_process(process.pid, process)
mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True)
@ -129,7 +139,7 @@ def run_command(command, timeout=30, monitored=False):
"status": "success",
"stdout": stdout_content,
"stderr": stderr_content,
"returncode": process.returncode
"returncode": process.returncode,
}
if time.time() - start_time > timeout_duration:
@ -139,7 +149,7 @@ def run_command(command, timeout=30, monitored=False):
"stdout_so_far": stdout_content,
"stderr_so_far": stderr_content,
"pid": process.pid,
"mux_name": mux_name
"mux_name": mux_name,
}
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
@ -158,6 +168,8 @@ def run_command(command, timeout=30, monitored=False):
if mux_name:
close_multiplexer(mux_name)
return {"status": "error", "error": str(e)}
def run_command_interactive(command):
try:
return_code = os.system(command)
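
A polling sketch for long-running commands. Per the tool descriptions in `pr/tools/base.py`, a command that outlives its timeout comes back with status "running" and a pid; the same status value for `tail_process`'s timeout branch is assumed here, since it falls outside the hunk:

```python
from pr.tools.command import run_command, tail_process

result = run_command("sleep 3; echo done", timeout=1)
while result.get("status") == "running":
    result = tail_process(result["pid"], timeout=2)

print(result.get("returncode"), result.get("stdout") or result.get("error"))
```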

View File

@ -1,18 +1,23 @@
import time
def db_set(key, value, db_conn):
if not db_conn:
return {"status": "error", "error": "Database not initialized"}
try:
cursor = db_conn.cursor()
cursor.execute("""INSERT OR REPLACE INTO kv_store (key, value, timestamp)
VALUES (?, ?, ?)""", (key, value, time.time()))
cursor.execute(
"""INSERT OR REPLACE INTO kv_store (key, value, timestamp)
VALUES (?, ?, ?)""",
(key, value, time.time()),
)
db_conn.commit()
return {"status": "success", "message": f"Set {key}"}
except Exception as e:
return {"status": "error", "error": str(e)}
def db_get(key, db_conn):
if not db_conn:
return {"status": "error", "error": "Database not initialized"}
@ -28,6 +33,7 @@ def db_get(key, db_conn):
except Exception as e:
return {"status": "error", "error": str(e)}
def db_query(query, db_conn):
if not db_conn:
return {"status": "error", "error": "Database not initialized"}
@ -36,9 +42,11 @@ def db_query(query, db_conn):
cursor = db_conn.cursor()
cursor.execute(query)
if query.strip().upper().startswith('SELECT'):
if query.strip().upper().startswith("SELECT"):
results = cursor.fetchall()
columns = [desc[0] for desc in cursor.description] if cursor.description else []
columns = (
[desc[0] for desc in cursor.description] if cursor.description else []
)
return {"status": "success", "columns": columns, "rows": results}
else:
db_conn.commit()
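
A self-contained sketch of the three helpers. The real `kv_store` table is created during assistant startup; it is recreated here only so the snippet runs standalone, with columns matching the INSERT above:

```python
import sqlite3

from pr.tools.database import db_get, db_query, db_set

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE kv_store (key TEXT PRIMARY KEY, value TEXT, timestamp REAL)"
)

print(db_set("greeting", "hello", conn))  # {"status": "success", ...}
print(db_get("greeting", conn))
print(db_query("SELECT key, value FROM kv_store", conn))  # columns + rows
```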

View File

@ -1,18 +1,21 @@
from pr.editor import RPEditor
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer
from ..ui.diff_display import display_diff, get_diff_stats
from ..ui.edit_feedback import track_edit, tracker
from ..tools.patch import display_content_diff
import os
import os.path
from pr.editor import RPEditor
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
from ..tools.patch import display_content_diff
from ..ui.edit_feedback import track_edit, tracker
_editors = {}
def get_editor(filepath):
if filepath not in _editors:
_editors[filepath] = RPEditor(filepath)
return _editors[filepath]
def close_editor(filepath):
try:
path = os.path.expanduser(filepath)
@ -29,6 +32,7 @@ def close_editor(filepath):
except Exception as e:
return {"status": "error", "error": str(e)}
def open_editor(filepath):
try:
path = os.path.expanduser(filepath)
@ -39,21 +43,28 @@ def open_editor(filepath):
mux_name, mux = create_multiplexer(mux_name, show_output=True)
mux.write_stdout(f"Opened editor for: {path}\n")
return {"status": "success", "message": f"Editor opened for {path}", "mux_name": mux_name}
return {
"status": "success",
"message": f"Editor opened for {path}",
"mux_name": mux_name,
}
except Exception as e:
return {"status": "error", "error": str(e)}
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
try:
path = os.path.expanduser(filepath)
old_content = ""
if os.path.exists(path):
with open(path, 'r') as f:
with open(path) as f:
old_content = f.read()
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
operation = track_edit('INSERT', filepath, start_pos=position, content=text)
position = (line if line is not None else 0) * 1000 + (
col if col is not None else 0
)
operation = track_edit("INSERT", filepath, start_pos=position, content=text)
tracker.mark_in_progress(operation)
editor = get_editor(path)
@ -65,12 +76,16 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
mux_name = f"editor-{path}"
mux = get_multiplexer(mux_name)
if mux:
location = f" at line {line}, col {col}" if line is not None and col is not None else ""
location = (
f" at line {line}, col {col}"
if line is not None and col is not None
else ""
)
preview = text[:50] + "..." if len(text) > 50 else text
mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n")
if show_diff and old_content:
with open(path, 'r') as f:
with open(path) as f:
new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success":
@ -81,23 +96,32 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
close_editor(filepath)
return result
except Exception as e:
if 'operation' in locals():
if "operation" in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True):
def editor_replace_text(
filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True
):
try:
path = os.path.expanduser(filepath)
old_content = ""
if os.path.exists(path):
with open(path, 'r') as f:
with open(path) as f:
old_content = f.read()
start_pos = start_line * 1000 + start_col
end_pos = end_line * 1000 + end_col
operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos,
content=new_text, old_content=old_content)
operation = track_edit(
"REPLACE",
filepath,
start_pos=start_pos,
end_pos=end_pos,
content=new_text,
old_content=old_content,
)
tracker.mark_in_progress(operation)
editor = get_editor(path)
@ -108,10 +132,12 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
mux = get_multiplexer(mux_name)
if mux:
preview = new_text[:50] + "..." if len(new_text) > 50 else new_text
mux.write_stdout(f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n")
mux.write_stdout(
f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n"
)
if show_diff and old_content:
with open(path, 'r') as f:
with open(path) as f:
new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success":
@ -122,10 +148,11 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
close_editor(filepath)
return result
except Exception as e:
if 'operation' in locals():
if "operation" in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def editor_search(filepath, pattern, start_line=0):
try:
path = os.path.expanduser(filepath)
@ -135,7 +162,9 @@ def editor_search(filepath, pattern, start_line=0):
mux_name = f"editor-{path}"
mux = get_multiplexer(mux_name)
if mux:
mux.write_stdout(f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n")
mux.write_stdout(
f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n"
)
result = {"status": "success", "results": results}
close_editor(filepath)
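
A usage sketch for the editor helpers. Each call closes its `RPEditor` when done, so they compose as independent steps; the file is pre-created because RPEditor's behaviour on a missing path is not shown in this diff:

```python
from pathlib import Path

from pr.tools.editor import editor_insert_text, editor_search, open_editor

demo = "/tmp/editor_demo.txt"
Path(demo).write_text("placeholder\n")  # pre-create; missing-file handling unshown

print(open_editor(demo))
print(editor_insert_text(demo, "hello world\n", line=0, col=0))
print(editor_search(demo, r"hello"))
```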

View File

@ -1,31 +1,36 @@
import os
import hashlib
import os
import time
from typing import Dict
from pr.editor import RPEditor
from ..ui.diff_display import display_diff, get_diff_stats
from ..ui.edit_feedback import track_edit, tracker
from ..tools.patch import display_content_diff
from ..ui.diff_display import get_diff_stats
from ..ui.edit_feedback import track_edit, tracker
_id = 0
def get_uid():
global _id
_id += 3
return _id
def read_file(filepath, db_conn=None):
try:
path = os.path.expanduser(filepath)
with open(path, 'r') as f:
with open(path) as f:
content = f.read()
if db_conn:
from pr.tools.database import db_set
db_set("read:" + path, "true", db_conn)
return {"status": "success", "content": content}
except Exception as e:
return {"status": "error", "error": str(e)}
def write_file(filepath, content, db_conn=None, show_diff=True):
try:
path = os.path.expanduser(filepath)
@ -34,15 +39,24 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
if not is_new_file and db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
if (
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
if not is_new_file:
with open(path, 'r') as f:
with open(path) as f:
old_content = f.read()
operation = track_edit('WRITE', filepath, content=content, old_content=old_content)
operation = track_edit(
"WRITE", filepath, content=content, old_content=old_content
)
tracker.mark_in_progress(operation)
if show_diff and not is_new_file:
@ -59,13 +73,18 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
cursor = db_conn.cursor()
file_hash = hashlib.md5(old_content.encode()).hexdigest()
cursor.execute("SELECT MAX(version) FROM file_versions WHERE filepath = ?", (filepath,))
cursor.execute(
"SELECT MAX(version) FROM file_versions WHERE filepath = ?",
(filepath,),
)
result = cursor.fetchone()
version = (result[0] + 1) if result[0] else 1
cursor.execute("""INSERT INTO file_versions (filepath, content, hash, timestamp, version)
cursor.execute(
"""INSERT INTO file_versions (filepath, content, hash, timestamp, version)
VALUES (?, ?, ?, ?, ?)""",
(filepath, old_content, file_hash, time.time(), version))
(filepath, old_content, file_hash, time.time(), version),
)
db_conn.commit()
except Exception:
pass
@ -79,10 +98,11 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
return {"status": "success", "message": message}
except Exception as e:
if 'operation' in locals():
if "operation" in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def list_directory(path=".", recursive=False):
try:
path = os.path.expanduser(path)
@ -91,21 +111,36 @@ def list_directory(path=".", recursive=False):
for root, dirs, files in os.walk(path):
for name in files:
item_path = os.path.join(root, name)
items.append({"path": item_path, "type": "file", "size": os.path.getsize(item_path)})
items.append(
{
"path": item_path,
"type": "file",
"size": os.path.getsize(item_path),
}
)
for name in dirs:
items.append({"path": os.path.join(root, name), "type": "directory"})
items.append(
{"path": os.path.join(root, name), "type": "directory"}
)
else:
for item in os.listdir(path):
item_path = os.path.join(path, item)
items.append({
"name": item,
"type": "directory" if os.path.isdir(item_path) else "file",
"size": os.path.getsize(item_path) if os.path.isfile(item_path) else None
})
items.append(
{
"name": item,
"type": "directory" if os.path.isdir(item_path) else "file",
"size": (
os.path.getsize(item_path)
if os.path.isfile(item_path)
else None
),
}
)
return {"status": "success", "items": items}
except Exception as e:
return {"status": "error", "error": str(e)}
def mkdir(path):
try:
os.makedirs(os.path.expanduser(path), exist_ok=True)
@ -113,6 +148,7 @@ def mkdir(path):
except Exception as e:
return {"status": "error", "error": str(e)}
def chdir(path):
try:
os.chdir(os.path.expanduser(path))
@ -120,16 +156,32 @@ def chdir(path):
except Exception as e:
return {"status": "error", "error": str(e)}
def getpwd():
try:
return {"status": "success", "path": os.getcwd()}
except Exception as e:
return {"status": "error", "error": str(e)}
def index_source_directory(path):
extensions = [
".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp",
".html", ".css", ".json", ".xml", ".md", ".sh", ".rb", ".go"
".py",
".js",
".ts",
".java",
".cpp",
".c",
".h",
".hpp",
".html",
".css",
".json",
".xml",
".md",
".sh",
".rb",
".go",
]
source_files = []
try:
@ -138,18 +190,16 @@ def index_source_directory(path):
if any(file.endswith(ext) for ext in extensions):
filepath = os.path.join(root, file)
try:
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, encoding="utf-8") as f:
content = f.read()
source_files.append({
"path": filepath,
"content": content
})
source_files.append({"path": filepath, "content": content})
except Exception:
continue
return {"status": "success", "indexed_files": source_files}
except Exception as e:
return {"status": "error", "error": str(e)}
def search_replace(filepath, old_string, new_string, db_conn=None):
try:
path = os.path.expanduser(filepath)
@ -157,25 +207,38 @@ def search_replace(filepath, old_string, new_string, db_conn=None):
return {"status": "error", "error": "File does not exist"}
if db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
with open(path, 'r') as f:
if (
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
with open(path) as f:
content = f.read()
content = content.replace(old_string, new_string)
with open(path, 'w') as f:
with open(path, "w") as f:
f.write(content)
return {"status": "success", "message": f"Replaced '{old_string}' with '{new_string}' in {path}"}
return {
"status": "success",
"message": f"Replaced '{old_string}' with '{new_string}' in {path}",
}
except Exception as e:
return {"status": "error", "error": str(e)}
_editors = {}
def get_editor(filepath):
if filepath not in _editors:
_editors[filepath] = RPEditor(filepath)
return _editors[filepath]
def close_editor(filepath):
try:
path = os.path.expanduser(filepath)
@ -185,6 +248,7 @@ def close_editor(filepath):
except Exception as e:
return {"status": "error", "error": str(e)}
def open_editor(filepath):
try:
path = os.path.expanduser(filepath)
@ -194,22 +258,34 @@ def open_editor(filepath):
except Exception as e:
return {"status": "error", "error": str(e)}
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None):
def editor_insert_text(
filepath, text, line=None, col=None, show_diff=True, db_conn=None
):
try:
path = os.path.expanduser(filepath)
if db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
if (
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
old_content = ""
if os.path.exists(path):
with open(path, 'r') as f:
with open(path) as f:
old_content = f.read()
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
operation = track_edit('INSERT', filepath, start_pos=position, content=text)
position = (line if line is not None else 0) * 1000 + (
col if col is not None else 0
)
operation = track_edit("INSERT", filepath, start_pos=position, content=text)
tracker.mark_in_progress(operation)
editor = get_editor(path)
@ -219,7 +295,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c
editor.save_file()
if show_diff and old_content:
with open(path, 'r') as f:
with open(path) as f:
new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success":
@ -228,28 +304,51 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c
tracker.mark_completed(operation)
return {"status": "success", "message": f"Inserted text in {path}"}
except Exception as e:
if 'operation' in locals():
if "operation" in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True, db_conn=None):
def editor_replace_text(
filepath,
start_line,
start_col,
end_line,
end_col,
new_text,
show_diff=True,
db_conn=None,
):
try:
path = os.path.expanduser(filepath)
if db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
if (
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
old_content = ""
if os.path.exists(path):
with open(path, 'r') as f:
with open(path) as f:
old_content = f.read()
start_pos = start_line * 1000 + start_col
end_pos = end_line * 1000 + end_col
operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos,
content=new_text, old_content=old_content)
operation = track_edit(
"REPLACE",
filepath,
start_pos=start_pos,
end_pos=end_pos,
content=new_text,
old_content=old_content,
)
tracker.mark_in_progress(operation)
editor = get_editor(path)
@ -257,7 +356,7 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
editor.save_file()
if show_diff and old_content:
with open(path, 'r') as f:
with open(path) as f:
new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success":
@ -266,22 +365,25 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
tracker.mark_completed(operation)
return {"status": "success", "message": f"Replaced text in {path}"}
except Exception as e:
if 'operation' in locals():
if "operation" in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def display_edit_summary():
from ..ui.edit_feedback import display_edit_summary
return display_edit_summary()
def display_edit_timeline(show_content=False):
from ..ui.edit_feedback import display_edit_timeline
return display_edit_timeline(show_content)
def clear_edit_tracker():
from ..ui.edit_feedback import clear_tracker
clear_tracker()
return {"status": "success", "message": "Edit tracker cleared"}
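
The read-before-write guard above is easiest to see end to end. A sketch with an in-memory database; the `file_versions` table is omitted because the version-history insert is wrapped in a bare `except Exception: pass`:

```python
import sqlite3

from pr.tools.filesystem import read_file, write_file

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE kv_store (key TEXT PRIMARY KEY, value TEXT, timestamp REAL)"
)

path = "/tmp/guard_demo.txt"
print(write_file(path, "v1\n", db_conn=conn))  # new file: guard not applied
print(write_file(path, "v2\n", db_conn=conn))  # rejected: not read yet
read_file(path, db_conn=conn)                  # records read:<path> in kv_store
print(write_file(path, "v2\n", db_conn=conn))  # now succeeds
```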

View File

@ -1,9 +1,15 @@
import subprocess
import threading
import time
from pr.multiplexer import create_multiplexer, get_multiplexer, close_multiplexer, get_all_multiplexer_states
def start_interactive_session(command, session_name=None, process_type='generic'):
from pr.multiplexer import (
close_multiplexer,
create_multiplexer,
get_all_multiplexer_states,
get_multiplexer,
)
def start_interactive_session(command, session_name=None, process_type="generic"):
"""
Start an interactive session in a dedicated multiplexer.
@ -16,7 +22,7 @@ def start_interactive_session(command, session_name=None, process_type='generic'
session_name: The name of the created session
"""
name, mux = create_multiplexer(session_name)
mux.update_metadata('process_type', process_type)
mux.update_metadata("process_type", process_type)
# Start the process
if isinstance(command, str):
@ -29,19 +35,23 @@ def start_interactive_session(command, session_name=None, process_type='generic'
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1
bufsize=1,
)
mux.process = process
mux.update_metadata('pid', process.pid)
mux.update_metadata("pid", process.pid)
# Set process type and handler
detected_type = detect_process_type(command)
mux.set_process_type(detected_type)
# Start output readers
stdout_thread = threading.Thread(target=_read_output, args=(process.stdout, mux.write_stdout), daemon=True)
stderr_thread = threading.Thread(target=_read_output, args=(process.stderr, mux.write_stderr), daemon=True)
stdout_thread = threading.Thread(
target=_read_output, args=(process.stdout, mux.write_stdout), daemon=True
)
stderr_thread = threading.Thread(
target=_read_output, args=(process.stderr, mux.write_stderr), daemon=True
)
stdout_thread.start()
stderr_thread.start()
@ -54,15 +64,17 @@ def start_interactive_session(command, session_name=None, process_type='generic'
close_multiplexer(name)
raise e
def _read_output(stream, write_func):
"""Read from a stream and write to multiplexer buffer."""
try:
for line in iter(stream.readline, ''):
for line in iter(stream.readline, ""):
if line:
write_func(line.rstrip('\n'))
write_func(line.rstrip("\n"))
except Exception as e:
print(f"Error reading output: {e}")
def send_input_to_session(session_name, input_data):
"""
Send input to an interactive session.
@ -75,15 +87,16 @@ def send_input_to_session(session_name, input_data):
if not mux:
raise ValueError(f"Session {session_name} not found")
if not hasattr(mux, 'process') or mux.process.poll() is not None:
if not hasattr(mux, "process") or mux.process.poll() is not None:
raise ValueError(f"Session {session_name} is not active")
try:
mux.process.stdin.write(input_data + '\n')
mux.process.stdin.write(input_data + "\n")
mux.process.stdin.flush()
except Exception as e:
raise RuntimeError(f"Failed to send input to session {session_name}: {e}")
def read_session_output(session_name, lines=None):
"""
Read output from a session.
@ -102,14 +115,12 @@ def read_session_output(session_name, lines=None):
output = mux.get_all_output()
if lines is not None:
# Return last N lines
stdout_lines = output['stdout'].split('\n')[-lines:] if output['stdout'] else []
stderr_lines = output['stderr'].split('\n')[-lines:] if output['stderr'] else []
output = {
'stdout': '\n'.join(stdout_lines),
'stderr': '\n'.join(stderr_lines)
}
stdout_lines = output["stdout"].split("\n")[-lines:] if output["stdout"] else []
stderr_lines = output["stderr"].split("\n")[-lines:] if output["stderr"] else []
output = {"stdout": "\n".join(stdout_lines), "stderr": "\n".join(stderr_lines)}
return output
def list_active_sessions():
"""
List all active interactive sessions.
@ -119,6 +130,7 @@ def list_active_sessions():
"""
return get_all_multiplexer_states()
def get_session_status(session_name):
"""
Get detailed status of a session.
@ -134,15 +146,16 @@ def get_session_status(session_name):
return None
status = mux.get_metadata()
status['is_active'] = hasattr(mux, 'process') and mux.process.poll() is None
if status['is_active']:
status['pid'] = mux.process.pid
status['output_summary'] = {
'stdout_lines': len(mux.stdout_buffer),
'stderr_lines': len(mux.stderr_buffer)
status["is_active"] = hasattr(mux, "process") and mux.process.poll() is None
if status["is_active"]:
status["pid"] = mux.process.pid
status["output_summary"] = {
"stdout_lines": len(mux.stdout_buffer),
"stderr_lines": len(mux.stderr_buffer),
}
return status
def close_interactive_session(session_name):
"""
Close an interactive session.
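Taken together, these helpers form a small session lifecycle: start a process in its own multiplexer, feed it input, poll its buffered output, and close it. A minimal sketch of that flow, assuming the functions are imported from the module this file defines (the path is not visible in this diff) and that `start_interactive_session` returns the session name, as its docstring indicates:

```python
import time

# Assumed import path; adjust to wherever this module lives in the package.
from pr.tools.interactive import (
    close_interactive_session,
    get_session_status,
    read_session_output,
    send_input_to_session,
    start_interactive_session,
)

# Start a quiet Python REPL in its own multiplexer session.
session = start_interactive_session(["python3", "-i", "-q"], process_type="generic")

send_input_to_session(session, "print(21 * 2)")
time.sleep(0.5)  # the daemon reader threads fill the buffers asynchronously

print(read_session_output(session, lines=5)["stdout"])  # "42" should appear here
print(get_session_status(session)["is_active"])
close_interactive_session(session)
```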

View File

@ -1,38 +1,43 @@
import os
from typing import Dict, Any, List
from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry
import time
import uuid
from typing import Any, Dict
def add_knowledge_entry(category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None) -> Dict[str, Any]:
from pr.memory.knowledge_store import KnowledgeEntry, KnowledgeStore
def add_knowledge_entry(
category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None
) -> Dict[str, Any]:
"""Add a new entry to the knowledge base."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path)
if entry_id is None:
entry_id = str(uuid.uuid4())[:16]
entry = KnowledgeEntry(
entry_id=entry_id,
category=category,
content=content,
metadata=metadata or {},
created_at=time.time(),
updated_at=time.time()
updated_at=time.time(),
)
store.add_entry(entry)
return {"status": "success", "entry_id": entry_id}
except Exception as e:
return {"status": "error", "error": str(e)}
def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
"""Retrieve a knowledge entry by ID."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path)
entry = store.get_entry(entry_id)
if entry:
return {"status": "success", "entry": entry.to_dict()}
@ -41,58 +46,71 @@ def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
except Exception as e:
return {"status": "error", "error": str(e)}
def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]:
def search_knowledge(
query: str, category: str = None, top_k: int = 5
) -> Dict[str, Any]:
"""Search the knowledge base semantically."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path)
entries = store.search_entries(query, category, top_k)
results = [entry.to_dict() for entry in entries]
return {"status": "success", "results": results}
except Exception as e:
return {"status": "error", "error": str(e)}
def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]:
"""Get knowledge entries by category."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path)
entries = store.get_by_category(category, limit)
results = [entry.to_dict() for entry in entries]
return {"status": "success", "entries": results}
except Exception as e:
return {"status": "error", "error": str(e)}
def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]:
def update_knowledge_importance(
entry_id: str, importance_score: float
) -> Dict[str, Any]:
"""Update the importance score of a knowledge entry."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path)
store.update_importance(entry_id, importance_score)
return {"status": "success", "entry_id": entry_id, "importance_score": importance_score}
return {
"status": "success",
"entry_id": entry_id,
"importance_score": importance_score,
}
except Exception as e:
return {"status": "error", "error": str(e)}
def delete_knowledge_entry(entry_id: str) -> Dict[str, Any]:
"""Delete a knowledge entry."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path)
success = store.delete_entry(entry_id)
return {"status": "success" if success else "not_found", "entry_id": entry_id}
except Exception as e:
return {"status": "error", "error": str(e)}
def get_knowledge_statistics() -> Dict[str, Any]:
"""Get statistics about the knowledge base."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
db_path = os.path.expanduser("~/.assistant_db.sqlite")
store = KnowledgeStore(db_path)
stats = store.get_statistics()
return {"status": "success", "statistics": stats}
    except Exception as e:
        return {"status": "error", "error": str(e)}
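All of these helpers share one contract: open the store at `~/.assistant_db.sqlite`, perform a single operation, and report the outcome in a `status` dict rather than raising. A short usage sketch (the import path is an assumption, since this diff does not show the module name):

```python
# Assumed import path for the knowledge tool functions above.
from pr.tools.knowledge import add_knowledge_entry, search_knowledge

result = add_knowledge_entry(
    category="python",
    content="Prefer pathlib.Path over os.path in new code.",
    metadata={"source": "style-guide"},
)
assert result["status"] == "success"

hits = search_knowledge("pathlib", category="python", top_k=3)
for entry in hits.get("results", []):
    print(entry["entry_id"], entry["content"])
```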

View File

@ -1,23 +1,37 @@
import os
import tempfile
import subprocess
import difflib
from ..ui.diff_display import display_diff, get_diff_stats, DiffDisplay
import os
import subprocess
import tempfile
from ..ui.diff_display import display_diff, get_diff_stats
def apply_patch(filepath, patch_content, db_conn=None):
try:
path = os.path.expanduser(filepath)
if db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
if (
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return {
"status": "error",
"error": "File must be read before writing. Please read the file first.",
}
# Write patch to temp file
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.patch') as f:
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".patch") as f:
f.write(patch_content)
patch_file = f.name
# Run patch command
result = subprocess.run(['patch', path, patch_file], capture_output=True, text=True, cwd=os.path.dirname(path))
result = subprocess.run(
["patch", path, patch_file],
capture_output=True,
text=True,
cwd=os.path.dirname(path),
)
os.unlink(patch_file)
if result.returncode == 0:
return {"status": "success", "output": result.stdout.strip()}
@ -26,11 +40,14 @@ def apply_patch(filepath, patch_content, db_conn=None):
except Exception as e:
return {"status": "error", "error": str(e)}
def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, format_type='unified'):
def create_diff(
file1, file2, fromfile="file1", tofile="file2", visual=False, format_type="unified"
):
try:
path1 = os.path.expanduser(file1)
path2 = os.path.expanduser(file2)
with open(path1, 'r') as f1, open(path2, 'r') as f2:
with open(path1) as f1, open(path2) as f2:
content1 = f1.read()
content2 = f2.read()
@ -39,53 +56,51 @@ def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, fo
stats = get_diff_stats(content1, content2)
lines1 = content1.splitlines(keepends=True)
lines2 = content2.splitlines(keepends=True)
plain_diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
plain_diff = list(
difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)
)
return {
"status": "success",
"diff": ''.join(plain_diff),
"diff": "".join(plain_diff),
"visual_diff": visual_diff,
"stats": stats
"stats": stats,
}
else:
lines1 = content1.splitlines(keepends=True)
lines2 = content2.splitlines(keepends=True)
diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
return {"status": "success", "diff": ''.join(diff)}
diff = list(
difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)
)
return {"status": "success", "diff": "".join(diff)}
except Exception as e:
return {"status": "error", "error": str(e)}
def display_file_diff(filepath1, filepath2, format_type='unified', context_lines=3):
def display_file_diff(filepath1, filepath2, format_type="unified", context_lines=3):
try:
path1 = os.path.expanduser(filepath1)
path2 = os.path.expanduser(filepath2)
with open(path1, 'r') as f1:
with open(path1) as f1:
old_content = f1.read()
with open(path2, 'r') as f2:
with open(path2) as f2:
new_content = f2.read()
visual_diff = display_diff(old_content, new_content, filepath1, format_type)
stats = get_diff_stats(old_content, new_content)
return {
"status": "success",
"visual_diff": visual_diff,
"stats": stats
}
return {"status": "success", "visual_diff": visual_diff, "stats": stats}
except Exception as e:
return {"status": "error", "error": str(e)}
def display_content_diff(old_content, new_content, filename='file', format_type='unified'):
def display_content_diff(
old_content, new_content, filename="file", format_type="unified"
):
try:
visual_diff = display_diff(old_content, new_content, filename, format_type)
stats = get_diff_stats(old_content, new_content)
return {
"status": "success",
"visual_diff": visual_diff,
"stats": stats
}
return {"status": "success", "visual_diff": visual_diff, "stats": stats}
except Exception as e:
return {"status": "error", "error": str(e)}
return {"status": "error", "error": str(e)}

View File

@ -1,14 +1,13 @@
import re
import time
from abc import ABC, abstractmethod
class ProcessHandler(ABC):
"""Base class for process-specific handlers."""
def __init__(self, multiplexer):
self.multiplexer = multiplexer
self.state_machine = {}
self.current_state = 'initial'
self.current_state = "initial"
self.prompt_patterns = []
self.response_suggestions = {}
@ -27,7 +26,8 @@ class ProcessHandler(ABC):
def is_waiting_for_input(self):
"""Check if process appears to be waiting for input."""
return self.current_state in ['waiting_confirmation', 'waiting_input']
return self.current_state in ["waiting_confirmation", "waiting_input"]
class AptHandler(ProcessHandler):
"""Handler for apt package manager interactions."""
@ -35,230 +35,238 @@ class AptHandler(ProcessHandler):
def __init__(self, multiplexer):
super().__init__(multiplexer)
self.state_machine = {
'initial': ['running_command'],
'running_command': ['waiting_confirmation', 'completed'],
'waiting_confirmation': ['confirmed', 'cancelled'],
'confirmed': ['installing', 'completed'],
'installing': ['completed', 'error'],
'completed': [],
'error': [],
'cancelled': []
"initial": ["running_command"],
"running_command": ["waiting_confirmation", "completed"],
"waiting_confirmation": ["confirmed", "cancelled"],
"confirmed": ["installing", "completed"],
"installing": ["completed", "error"],
"completed": [],
"error": [],
"cancelled": [],
}
self.prompt_patterns = [
(r'Do you want to continue\?', 'confirmation'),
(r'After this operation.*installed\.', 'size_info'),
(r'Need to get.*B of archives\.', 'download_info'),
(r'Unpacking.*Configuring', 'configuring'),
(r'Setting up', 'setting_up'),
(r'E:\s', 'error')
(r"Do you want to continue\?", "confirmation"),
(r"After this operation.*installed\.", "size_info"),
(r"Need to get.*B of archives\.", "download_info"),
(r"Unpacking.*Configuring", "configuring"),
(r"Setting up", "setting_up"),
(r"E:\s", "error"),
]
def get_process_type(self):
return 'apt'
return "apt"
def update_state(self, output):
"""Update state based on apt output patterns."""
output_lower = output.lower()
# Check for completion
if 'processing triggers' in output_lower or 'done' in output_lower:
self.current_state = 'completed'
if "processing triggers" in output_lower or "done" in output_lower:
self.current_state = "completed"
# Check for confirmation prompts
elif 'do you want to continue' in output_lower:
self.current_state = 'waiting_confirmation'
elif "do you want to continue" in output_lower:
self.current_state = "waiting_confirmation"
# Check for installation progress
elif 'setting up' in output_lower or 'unpacking' in output_lower:
self.current_state = 'installing'
elif "setting up" in output_lower or "unpacking" in output_lower:
self.current_state = "installing"
# Check for errors
elif 'e:' in output_lower or 'error' in output_lower:
self.current_state = 'error'
elif "e:" in output_lower or "error" in output_lower:
self.current_state = "error"
def get_prompt_suggestions(self):
"""Return suggested responses for apt prompts."""
suggestions = super().get_prompt_suggestions()
if self.current_state == 'waiting_confirmation':
suggestions.extend(['y', 'yes', 'n', 'no'])
if self.current_state == "waiting_confirmation":
suggestions.extend(["y", "yes", "n", "no"])
return suggestions
class VimHandler(ProcessHandler):
"""Handler for vim editor interactions."""
def __init__(self, multiplexer):
super().__init__(multiplexer)
self.state_machine = {
'initial': ['normal_mode', 'insert_mode'],
'normal_mode': ['insert_mode', 'command_mode', 'visual_mode'],
'insert_mode': ['normal_mode'],
'command_mode': ['normal_mode'],
'visual_mode': ['normal_mode'],
'exiting': []
"initial": ["normal_mode", "insert_mode"],
"normal_mode": ["insert_mode", "command_mode", "visual_mode"],
"insert_mode": ["normal_mode"],
"command_mode": ["normal_mode"],
"visual_mode": ["normal_mode"],
"exiting": [],
}
self.prompt_patterns = [
(r'-- INSERT --', 'insert_mode'),
(r'-- VISUAL --', 'visual_mode'),
(r':', 'command_mode'),
(r'Press ENTER', 'waiting_enter'),
(r'Saved', 'saved')
(r"-- INSERT --", "insert_mode"),
(r"-- VISUAL --", "visual_mode"),
(r":", "command_mode"),
(r"Press ENTER", "waiting_enter"),
(r"Saved", "saved"),
]
self.mode_indicators = {
'insert': '-- INSERT --',
'visual': '-- VISUAL --',
'command': ':'
"insert": "-- INSERT --",
"visual": "-- VISUAL --",
"command": ":",
}
def get_process_type(self):
return 'vim'
return "vim"
def update_state(self, output):
"""Update state based on vim mode indicators."""
if '-- INSERT --' in output:
self.current_state = 'insert_mode'
elif '-- VISUAL --' in output:
self.current_state = 'visual_mode'
elif output.strip().endswith(':'):
self.current_state = 'command_mode'
elif 'Press ENTER' in output:
self.current_state = 'waiting_enter'
if "-- INSERT --" in output:
self.current_state = "insert_mode"
elif "-- VISUAL --" in output:
self.current_state = "visual_mode"
elif output.strip().endswith(":"):
self.current_state = "command_mode"
elif "Press ENTER" in output:
self.current_state = "waiting_enter"
else:
# Default to normal mode if no specific indicators
self.current_state = 'normal_mode'
self.current_state = "normal_mode"
def get_prompt_suggestions(self):
"""Return suggested commands for vim modes."""
suggestions = super().get_prompt_suggestions()
if self.current_state == 'command_mode':
suggestions.extend(['w', 'q', 'wq', 'q!', 'w!'])
elif self.current_state == 'normal_mode':
suggestions.extend(['i', 'a', 'o', 'dd', ':w', ':q'])
elif self.current_state == 'waiting_enter':
suggestions.extend(['\n'])
if self.current_state == "command_mode":
suggestions.extend(["w", "q", "wq", "q!", "w!"])
elif self.current_state == "normal_mode":
suggestions.extend(["i", "a", "o", "dd", ":w", ":q"])
elif self.current_state == "waiting_enter":
suggestions.extend(["\n"])
return suggestions
class SSHHandler(ProcessHandler):
"""Handler for SSH connection interactions."""
def __init__(self, multiplexer):
super().__init__(multiplexer)
self.state_machine = {
'initial': ['connecting'],
'connecting': ['auth_prompt', 'connected', 'failed'],
'auth_prompt': ['connected', 'failed'],
'connected': ['shell', 'disconnected'],
'shell': ['disconnected'],
'failed': [],
'disconnected': []
"initial": ["connecting"],
"connecting": ["auth_prompt", "connected", "failed"],
"auth_prompt": ["connected", "failed"],
"connected": ["shell", "disconnected"],
"shell": ["disconnected"],
"failed": [],
"disconnected": [],
}
self.prompt_patterns = [
(r'password:', 'password_prompt'),
(r'yes/no', 'host_key_prompt'),
(r'Permission denied', 'auth_failed'),
(r'Welcome to', 'connected'),
(r'\$', 'shell_prompt'),
(r'\#', 'root_shell_prompt'),
(r'Connection closed', 'disconnected')
(r"password:", "password_prompt"),
(r"yes/no", "host_key_prompt"),
(r"Permission denied", "auth_failed"),
(r"Welcome to", "connected"),
(r"\$", "shell_prompt"),
(r"\#", "root_shell_prompt"),
(r"Connection closed", "disconnected"),
]
def get_process_type(self):
return 'ssh'
return "ssh"
def update_state(self, output):
"""Update state based on SSH connection output."""
output_lower = output.lower()
if 'permission denied' in output_lower:
self.current_state = 'failed'
elif 'password:' in output_lower:
self.current_state = 'auth_prompt'
elif 'yes/no' in output_lower:
self.current_state = 'auth_prompt'
elif 'welcome to' in output_lower or 'last login' in output_lower:
self.current_state = 'connected'
elif output.strip().endswith('$') or output.strip().endswith('#'):
self.current_state = 'shell'
elif 'connection closed' in output_lower:
self.current_state = 'disconnected'
if "permission denied" in output_lower:
self.current_state = "failed"
elif "password:" in output_lower:
self.current_state = "auth_prompt"
elif "yes/no" in output_lower:
self.current_state = "auth_prompt"
elif "welcome to" in output_lower or "last login" in output_lower:
self.current_state = "connected"
elif output.strip().endswith("$") or output.strip().endswith("#"):
self.current_state = "shell"
elif "connection closed" in output_lower:
self.current_state = "disconnected"
def get_prompt_suggestions(self):
"""Return suggested responses for SSH prompts."""
suggestions = super().get_prompt_suggestions()
if self.current_state == 'auth_prompt':
if 'password:' in self.multiplexer.get_all_output()['stdout']:
suggestions.extend(['<password>']) # Placeholder for actual password
elif 'yes/no' in self.multiplexer.get_all_output()['stdout']:
suggestions.extend(['yes', 'no'])
if self.current_state == "auth_prompt":
if "password:" in self.multiplexer.get_all_output()["stdout"]:
suggestions.extend(["<password>"]) # Placeholder for actual password
elif "yes/no" in self.multiplexer.get_all_output()["stdout"]:
suggestions.extend(["yes", "no"])
return suggestions
class GenericProcessHandler(ProcessHandler):
"""Fallback handler for unknown process types."""
def __init__(self, multiplexer):
super().__init__(multiplexer)
self.state_machine = {
'initial': ['running'],
'running': ['waiting_input', 'completed'],
'waiting_input': ['running'],
'completed': []
"initial": ["running"],
"running": ["waiting_input", "completed"],
"waiting_input": ["running"],
"completed": [],
}
self.prompt_patterns = [
(r'\?\s*$', 'waiting_input'), # Lines ending with ?
(r'>\s*$', 'waiting_input'), # Lines ending with >
(r':\s*$', 'waiting_input'), # Lines ending with :
(r'done', 'completed'),
(r'finished', 'completed'),
(r'exit code', 'completed')
(r"\?\s*$", "waiting_input"), # Lines ending with ?
(r">\s*$", "waiting_input"), # Lines ending with >
(r":\s*$", "waiting_input"), # Lines ending with :
(r"done", "completed"),
(r"finished", "completed"),
(r"exit code", "completed"),
]
def get_process_type(self):
return 'generic'
return "generic"
def update_state(self, output):
"""Basic state detection for generic processes."""
output_lower = output.lower()
if any(pattern in output_lower for pattern in ['done', 'finished', 'complete']):
self.current_state = 'completed'
elif any(output.strip().endswith(char) for char in ['?', '>', ':']):
self.current_state = 'waiting_input'
if any(pattern in output_lower for pattern in ["done", "finished", "complete"]):
self.current_state = "completed"
elif any(output.strip().endswith(char) for char in ["?", ">", ":"]):
self.current_state = "waiting_input"
else:
self.current_state = 'running'
self.current_state = "running"
# Handler registry
_handler_classes = {
'apt': AptHandler,
'vim': VimHandler,
'ssh': SSHHandler,
'generic': GenericProcessHandler
"apt": AptHandler,
"vim": VimHandler,
"ssh": SSHHandler,
"generic": GenericProcessHandler,
}
def get_handler_for_process(process_type, multiplexer):
"""Get appropriate handler for a process type."""
handler_class = _handler_classes.get(process_type, GenericProcessHandler)
return handler_class(multiplexer)
def detect_process_type(command):
"""Detect process type from command."""
command_str = ' '.join(command) if isinstance(command, list) else command
command_str = " ".join(command) if isinstance(command, list) else command
command_lower = command_str.lower()
if 'apt' in command_lower or 'apt-get' in command_lower:
return 'apt'
elif 'vim' in command_lower or 'vi ' in command_lower:
return 'vim'
elif 'ssh' in command_lower:
return 'ssh'
if "apt" in command_lower or "apt-get" in command_lower:
return "apt"
elif "vim" in command_lower or "vi " in command_lower:
return "vim"
elif "ssh" in command_lower:
return "ssh"
else:
return 'generic'
        return "generic"

View File

@ -1,6 +1,6 @@
import re
import time
from collections import defaultdict
class PromptDetector:
"""Detects various process prompts and manages interaction state."""
@ -10,101 +10,119 @@ class PromptDetector:
self.state_machines = self._load_state_machines()
self.session_states = {}
self.timeout_configs = {
'default': 30, # 30 seconds default timeout
'apt': 300, # 5 minutes for apt operations
'ssh': 60, # 1 minute for SSH connections
'vim': 3600 # 1 hour for vim sessions
"default": 30, # 30 seconds default timeout
"apt": 300, # 5 minutes for apt operations
"ssh": 60, # 1 minute for SSH connections
"vim": 3600, # 1 hour for vim sessions
}
def _load_prompt_patterns(self):
"""Load regex patterns for detecting various prompts."""
return {
'bash_prompt': [
re.compile(r'[\w\-\.]+@[\w\-\.]+:.*[\$#]\s*$'),
re.compile(r'\$\s*$'),
re.compile(r'#\s*$'),
re.compile(r'>\s*$') # Continuation prompt
"bash_prompt": [
re.compile(r"[\w\-\.]+@[\w\-\.]+:.*[\$#]\s*$"),
re.compile(r"\$\s*$"),
re.compile(r"#\s*$"),
re.compile(r">\s*$"), # Continuation prompt
],
'confirmation': [
re.compile(r'[Yy]/[Nn]', re.IGNORECASE),
re.compile(r'[Yy]es/[Nn]o', re.IGNORECASE),
re.compile(r'continue\?', re.IGNORECASE),
re.compile(r'proceed\?', re.IGNORECASE)
"confirmation": [
re.compile(r"[Yy]/[Nn]", re.IGNORECASE),
re.compile(r"[Yy]es/[Nn]o", re.IGNORECASE),
re.compile(r"continue\?", re.IGNORECASE),
re.compile(r"proceed\?", re.IGNORECASE),
],
'password': [
re.compile(r'password:', re.IGNORECASE),
re.compile(r'passphrase:', re.IGNORECASE),
re.compile(r'enter password', re.IGNORECASE)
"password": [
re.compile(r"password:", re.IGNORECASE),
re.compile(r"passphrase:", re.IGNORECASE),
re.compile(r"enter password", re.IGNORECASE),
],
'sudo_password': [
re.compile(r'\[sudo\].*password', re.IGNORECASE)
"sudo_password": [re.compile(r"\[sudo\].*password", re.IGNORECASE)],
"apt": [
re.compile(r"Do you want to continue\?", re.IGNORECASE),
re.compile(r"After this operation", re.IGNORECASE),
re.compile(r"Need to get", re.IGNORECASE),
],
'apt': [
re.compile(r'Do you want to continue\?', re.IGNORECASE),
re.compile(r'After this operation', re.IGNORECASE),
re.compile(r'Need to get', re.IGNORECASE)
"vim": [
re.compile(r"-- INSERT --"),
re.compile(r"-- VISUAL --"),
re.compile(r":"),
re.compile(r"Press ENTER", re.IGNORECASE),
],
'vim': [
re.compile(r'-- INSERT --'),
re.compile(r'-- VISUAL --'),
re.compile(r':'),
re.compile(r'Press ENTER', re.IGNORECASE)
"ssh": [
re.compile(r"yes/no", re.IGNORECASE),
re.compile(r"password:", re.IGNORECASE),
re.compile(r"Permission denied", re.IGNORECASE),
],
'ssh': [
re.compile(r'yes/no', re.IGNORECASE),
re.compile(r'password:', re.IGNORECASE),
re.compile(r'Permission denied', re.IGNORECASE)
"git": [
re.compile(r"Username:", re.IGNORECASE),
re.compile(r"Email:", re.IGNORECASE),
],
'git': [
re.compile(r'Username:', re.IGNORECASE),
re.compile(r'Email:', re.IGNORECASE)
"error": [
re.compile(r"error:", re.IGNORECASE),
re.compile(r"failed", re.IGNORECASE),
re.compile(r"exception", re.IGNORECASE),
],
'error': [
re.compile(r'error:', re.IGNORECASE),
re.compile(r'failed', re.IGNORECASE),
re.compile(r'exception', re.IGNORECASE)
]
}
def _load_state_machines(self):
"""Load state machines for different process types."""
return {
'apt': {
'states': ['initial', 'running', 'confirming', 'installing', 'completed', 'error'],
'transitions': {
'initial': ['running'],
'running': ['confirming', 'installing', 'completed', 'error'],
'confirming': ['installing', 'cancelled'],
'installing': ['completed', 'error'],
'completed': [],
'error': [],
'cancelled': []
}
"apt": {
"states": [
"initial",
"running",
"confirming",
"installing",
"completed",
"error",
],
"transitions": {
"initial": ["running"],
"running": ["confirming", "installing", "completed", "error"],
"confirming": ["installing", "cancelled"],
"installing": ["completed", "error"],
"completed": [],
"error": [],
"cancelled": [],
},
},
'ssh': {
'states': ['initial', 'connecting', 'authenticating', 'connected', 'error'],
'transitions': {
'initial': ['connecting'],
'connecting': ['authenticating', 'connected', 'error'],
'authenticating': ['connected', 'error'],
'connected': ['error'],
'error': []
}
"ssh": {
"states": [
"initial",
"connecting",
"authenticating",
"connected",
"error",
],
"transitions": {
"initial": ["connecting"],
"connecting": ["authenticating", "connected", "error"],
"authenticating": ["connected", "error"],
"connected": ["error"],
"error": [],
},
},
"vim": {
"states": [
"initial",
"normal",
"insert",
"visual",
"command",
"exiting",
],
"transitions": {
"initial": ["normal", "insert"],
"normal": ["insert", "visual", "command", "exiting"],
"insert": ["normal"],
"visual": ["normal"],
"command": ["normal", "exiting"],
"exiting": [],
},
},
'vim': {
'states': ['initial', 'normal', 'insert', 'visual', 'command', 'exiting'],
'transitions': {
'initial': ['normal', 'insert'],
'normal': ['insert', 'visual', 'command', 'exiting'],
'insert': ['normal'],
'visual': ['normal'],
'command': ['normal', 'exiting'],
'exiting': []
}
}
}
def detect_prompt(self, output, process_type='generic'):
def detect_prompt(self, output, process_type="generic"):
"""Detect what type of prompt is present in the output."""
detections = {}
@ -125,93 +143,97 @@ class PromptDetector:
return detections
def get_response_suggestions(self, prompt_detections, process_type='generic'):
def get_response_suggestions(self, prompt_detections, process_type="generic"):
"""Get suggested responses based on detected prompts."""
suggestions = []
for category, patterns in prompt_detections.items():
if category == 'confirmation':
suggestions.extend(['y', 'yes', 'n', 'no'])
elif category == 'password':
suggestions.append('<password>')
elif category == 'sudo_password':
suggestions.append('<sudo_password>')
elif category == 'apt':
if any('continue' in p for p in patterns):
suggestions.extend(['y', 'yes'])
elif category == 'vim':
if any(':' in p for p in patterns):
suggestions.extend(['w', 'q', 'wq', 'q!'])
elif any('ENTER' in p for p in patterns):
suggestions.append('\n')
elif category == 'ssh':
if any('yes/no' in p for p in patterns):
suggestions.extend(['yes', 'no'])
elif any('password' in p for p in patterns):
suggestions.append('<password>')
elif category == 'bash_prompt':
suggestions.extend(['help', 'ls', 'pwd', 'exit'])
if category == "confirmation":
suggestions.extend(["y", "yes", "n", "no"])
elif category == "password":
suggestions.append("<password>")
elif category == "sudo_password":
suggestions.append("<sudo_password>")
elif category == "apt":
if any("continue" in p for p in patterns):
suggestions.extend(["y", "yes"])
elif category == "vim":
if any(":" in p for p in patterns):
suggestions.extend(["w", "q", "wq", "q!"])
elif any("ENTER" in p for p in patterns):
suggestions.append("\n")
elif category == "ssh":
if any("yes/no" in p for p in patterns):
suggestions.extend(["yes", "no"])
elif any("password" in p for p in patterns):
suggestions.append("<password>")
elif category == "bash_prompt":
suggestions.extend(["help", "ls", "pwd", "exit"])
return list(set(suggestions)) # Remove duplicates
def update_session_state(self, session_name, output, process_type='generic'):
def update_session_state(self, session_name, output, process_type="generic"):
"""Update the state machine for a session based on output."""
if session_name not in self.session_states:
self.session_states[session_name] = {
'current_state': 'initial',
'process_type': process_type,
'last_activity': time.time(),
'transitions': []
"current_state": "initial",
"process_type": process_type,
"last_activity": time.time(),
"transitions": [],
}
session_state = self.session_states[session_name]
old_state = session_state['current_state']
old_state = session_state["current_state"]
# Detect prompts and determine new state
detections = self.detect_prompt(output, process_type)
new_state = self._determine_state_from_detections(detections, process_type, old_state)
new_state = self._determine_state_from_detections(
detections, process_type, old_state
)
if new_state != old_state:
session_state['transitions'].append({
'from': old_state,
'to': new_state,
'timestamp': time.time(),
'trigger': detections
})
session_state['current_state'] = new_state
session_state["transitions"].append(
{
"from": old_state,
"to": new_state,
"timestamp": time.time(),
"trigger": detections,
}
)
session_state["current_state"] = new_state
session_state['last_activity'] = time.time()
session_state["last_activity"] = time.time()
return new_state
def _determine_state_from_detections(self, detections, process_type, current_state):
"""Determine new state based on prompt detections."""
if process_type in self.state_machines:
state_machine = self.state_machines[process_type]
            pass  # per-type state machine is available here but not yet consulted
# State transition logic based on detections
if 'confirmation' in detections and current_state in ['running', 'initial']:
return 'confirming'
elif 'password' in detections or 'sudo_password' in detections:
return 'authenticating'
elif 'error' in detections:
return 'error'
elif 'bash_prompt' in detections and current_state != 'initial':
return 'connected' if process_type == 'ssh' else 'completed'
elif 'vim' in detections:
if any('-- INSERT --' in p for p in detections.get('vim', [])):
return 'insert'
elif any('-- VISUAL --' in p for p in detections.get('vim', [])):
return 'visual'
elif any(':' in p for p in detections.get('vim', [])):
return 'command'
if "confirmation" in detections and current_state in ["running", "initial"]:
return "confirming"
elif "password" in detections or "sudo_password" in detections:
return "authenticating"
elif "error" in detections:
return "error"
elif "bash_prompt" in detections and current_state != "initial":
return "connected" if process_type == "ssh" else "completed"
elif "vim" in detections:
if any("-- INSERT --" in p for p in detections.get("vim", [])):
return "insert"
elif any("-- VISUAL --" in p for p in detections.get("vim", [])):
return "visual"
elif any(":" in p for p in detections.get("vim", [])):
return "command"
# Default state progression
if current_state == 'initial':
return 'running'
elif current_state == 'running' and detections:
return 'waiting_input'
elif current_state == 'waiting_input' and not detections:
return 'running'
if current_state == "initial":
return "running"
elif current_state == "running" and detections:
return "waiting_input"
elif current_state == "waiting_input" and not detections:
return "running"
return current_state
@ -220,15 +242,15 @@ class PromptDetector:
if session_name not in self.session_states:
return False
state = self.session_states[session_name]['current_state']
process_type = self.session_states[session_name]['process_type']
state = self.session_states[session_name]["current_state"]
process_type = self.session_states[session_name]["process_type"]
# States that typically indicate waiting for input
waiting_states = {
'generic': ['waiting_input'],
'apt': ['confirming'],
'ssh': ['authenticating'],
'vim': ['command', 'insert', 'visual']
"generic": ["waiting_input"],
"apt": ["confirming"],
"ssh": ["authenticating"],
"vim": ["command", "insert", "visual"],
}
return state in waiting_states.get(process_type, [])
@ -236,10 +258,10 @@ class PromptDetector:
def get_session_timeout(self, session_name):
"""Get the timeout for a session based on its process type."""
if session_name not in self.session_states:
return self.timeout_configs['default']
return self.timeout_configs["default"]
process_type = self.session_states[session_name]['process_type']
return self.timeout_configs.get(process_type, self.timeout_configs['default'])
process_type = self.session_states[session_name]["process_type"]
return self.timeout_configs.get(process_type, self.timeout_configs["default"])
def check_for_timeouts(self):
"""Check all sessions for timeouts and return timed out sessions."""
@ -248,7 +270,7 @@ class PromptDetector:
for session_name, state in self.session_states.items():
timeout = self.get_session_timeout(session_name)
if current_time - state['last_activity'] > timeout:
if current_time - state["last_activity"] > timeout:
timed_out.append(session_name)
return timed_out
@ -260,16 +282,18 @@ class PromptDetector:
state = self.session_states[session_name]
return {
'current_state': state['current_state'],
'process_type': state['process_type'],
'last_activity': state['last_activity'],
'transitions': state['transitions'][-5:], # Last 5 transitions
'is_waiting': self.is_waiting_for_input(session_name)
"current_state": state["current_state"],
"process_type": state["process_type"],
"last_activity": state["last_activity"],
"transitions": state["transitions"][-5:], # Last 5 transitions
"is_waiting": self.is_waiting_for_input(session_name),
}
# Global detector instance
_detector = None
def get_global_detector():
"""Get the global prompt detector instance."""
global _detector

View File

@ -1,6 +1,7 @@
import contextlib
import traceback
from io import StringIO
import contextlib
def python_exec(code, python_globals):
try:

View File

@ -1,7 +1,8 @@
import urllib.request
import urllib.parse
import urllib.error
import json
import urllib.error
import urllib.parse
import urllib.request
def http_fetch(url, headers=None):
try:
@ -11,26 +12,28 @@ def http_fetch(url, headers=None):
req.add_header(key, value)
with urllib.request.urlopen(req) as response:
content = response.read().decode('utf-8')
content = response.read().decode("utf-8")
return {"status": "success", "content": content[:10000]}
except Exception as e:
return {"status": "error", "error": str(e)}
def _perform_search(base_url, query, params=None):
try:
full_url = f"https://static.molodetz.nl/search.cgi?query={query}"
with urllib.request.urlopen(full_url) as response:
content = response.read().decode('utf-8')
content = response.read().decode("utf-8")
return {"status": "success", "content": json.loads(content)}
except Exception as e:
return {"status": "error", "error": str(e)}
def web_search(query):
base_url = "https://search.molodetz.nl/search"
return _perform_search(base_url, query)
def web_search_news(query):
base_url = "https://search.molodetz.nl/search"
return _perform_search(base_url, query)
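Note that `web_search` and `web_search_news` currently delegate to the same hardcoded endpoint, so they return identical results. A short sketch (module path assumed):

```python
# Assumed import path for the web helpers above.
from pr.tools.web import http_fetch, web_search

page = http_fetch("https://example.com", headers={"User-Agent": "pr-assistant"})
if page["status"] == "success":
    print(page["content"][:200])

results = web_search("python urllib examples")
if results["status"] == "success":
    print(results["content"])  # JSON payload decoded by _perform_search
```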

View File

@ -1,5 +1,11 @@
from pr.ui.colors import Colors
from pr.ui.rendering import highlight_code, render_markdown
from pr.ui.display import display_tool_call, print_autonomous_header
from pr.ui.rendering import highlight_code, render_markdown
__all__ = ['Colors', 'highlight_code', 'render_markdown', 'display_tool_call', 'print_autonomous_header']
__all__ = [
"Colors",
"highlight_code",
"render_markdown",
"display_tool_call",
"print_autonomous_header",
]

View File

@ -1,14 +1,14 @@
class Colors:
RESET = '\033[0m'
BOLD = '\033[1m'
RED = '\033[91m'
GREEN = '\033[92m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
MAGENTA = '\033[95m'
CYAN = '\033[96m'
GRAY = '\033[90m'
WHITE = '\033[97m'
BG_BLUE = '\033[44m'
BG_GREEN = '\033[42m'
BG_RED = '\033[41m'
RESET = "\033[0m"
BOLD = "\033[1m"
RED = "\033[91m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
BLUE = "\033[94m"
MAGENTA = "\033[95m"
CYAN = "\033[96m"
GRAY = "\033[90m"
WHITE = "\033[97m"
BG_BLUE = "\033[44m"
BG_GREEN = "\033[42m"
BG_RED = "\033[41m"
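Since these are raw ANSI escapes, every colored segment should end with `RESET` so later output is unaffected:

```python
from pr.ui.colors import Colors

print(f"{Colors.BOLD}{Colors.GREEN}ok{Colors.RESET} back to plain text")
```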

View File

@ -1,5 +1,6 @@
import difflib
from typing import List, Tuple, Dict, Optional
from typing import Dict, List, Optional, Tuple
from .colors import Colors
@ -19,8 +20,13 @@ class DiffStats:
class DiffLine:
def __init__(self, line_type: str, content: str, old_line_num: Optional[int] = None,
new_line_num: Optional[int] = None):
def __init__(
self,
line_type: str,
content: str,
old_line_num: Optional[int] = None,
new_line_num: Optional[int] = None,
):
self.line_type = line_type
self.content = content
self.old_line_num = old_line_num
@ -28,27 +34,27 @@ class DiffLine:
def format(self, show_line_nums: bool = True) -> str:
color = {
'add': Colors.GREEN,
'delete': Colors.RED,
'context': Colors.GRAY,
'header': Colors.CYAN,
'stats': Colors.BLUE
"add": Colors.GREEN,
"delete": Colors.RED,
"context": Colors.GRAY,
"header": Colors.CYAN,
"stats": Colors.BLUE,
}.get(self.line_type, Colors.RESET)
prefix = {
'add': '+ ',
'delete': '- ',
'context': ' ',
'header': '',
'stats': ''
}.get(self.line_type, ' ')
"add": "+ ",
"delete": "- ",
"context": " ",
"header": "",
"stats": "",
}.get(self.line_type, " ")
if show_line_nums and self.line_type in ('add', 'delete', 'context'):
old_num = str(self.old_line_num) if self.old_line_num else ' '
new_num = str(self.new_line_num) if self.new_line_num else ' '
if show_line_nums and self.line_type in ("add", "delete", "context"):
old_num = str(self.old_line_num) if self.old_line_num else " "
new_num = str(self.new_line_num) if self.new_line_num else " "
line_num_str = f"{Colors.YELLOW}{old_num:>4} {new_num:>4}{Colors.RESET} "
else:
line_num_str = ''
line_num_str = ""
return f"{line_num_str}{color}{prefix}{self.content}{Colors.RESET}"
@ -57,8 +63,9 @@ class DiffDisplay:
def __init__(self, context_lines: int = 3):
self.context_lines = context_lines
def create_diff(self, old_content: str, new_content: str,
filename: str = "file") -> Tuple[List[DiffLine], DiffStats]:
def create_diff(
self, old_content: str, new_content: str, filename: str = "file"
) -> Tuple[List[DiffLine], DiffStats]:
old_lines = old_content.splitlines(keepends=True)
new_lines = new_content.splitlines(keepends=True)
@ -67,31 +74,38 @@ class DiffDisplay:
stats.files_changed = 1
diff = difflib.unified_diff(
old_lines, new_lines,
old_lines,
new_lines,
fromfile=f"a/{filename}",
tofile=f"b/{filename}",
n=self.context_lines
n=self.context_lines,
)
old_line_num = 0
new_line_num = 0
for line in diff:
if line.startswith('---') or line.startswith('+++'):
diff_lines.append(DiffLine('header', line.rstrip()))
elif line.startswith('@@'):
diff_lines.append(DiffLine('header', line.rstrip()))
if line.startswith("---") or line.startswith("+++"):
diff_lines.append(DiffLine("header", line.rstrip()))
elif line.startswith("@@"):
diff_lines.append(DiffLine("header", line.rstrip()))
old_line_num, new_line_num = self._parse_hunk_header(line)
elif line.startswith('+'):
elif line.startswith("+"):
stats.insertions += 1
diff_lines.append(DiffLine('add', line[1:].rstrip(), None, new_line_num))
diff_lines.append(
DiffLine("add", line[1:].rstrip(), None, new_line_num)
)
new_line_num += 1
elif line.startswith('-'):
elif line.startswith("-"):
stats.deletions += 1
diff_lines.append(DiffLine('delete', line[1:].rstrip(), old_line_num, None))
diff_lines.append(
DiffLine("delete", line[1:].rstrip(), old_line_num, None)
)
old_line_num += 1
elif line.startswith(' '):
diff_lines.append(DiffLine('context', line[1:].rstrip(), old_line_num, new_line_num))
elif line.startswith(" "):
diff_lines.append(
DiffLine("context", line[1:].rstrip(), old_line_num, new_line_num)
)
old_line_num += 1
new_line_num += 1
@ -101,15 +115,20 @@ class DiffDisplay:
def _parse_hunk_header(self, header: str) -> Tuple[int, int]:
try:
parts = header.split('@@')[1].strip().split()
old_start = int(parts[0].split(',')[0].replace('-', ''))
new_start = int(parts[1].split(',')[0].replace('+', ''))
parts = header.split("@@")[1].strip().split()
old_start = int(parts[0].split(",")[0].replace("-", ""))
new_start = int(parts[1].split(",")[0].replace("+", ""))
return old_start, new_start
except (IndexError, ValueError):
return 0, 0
def render_diff(self, diff_lines: List[DiffLine], stats: DiffStats,
show_line_nums: bool = True, show_stats: bool = True) -> str:
def render_diff(
self,
diff_lines: List[DiffLine],
stats: DiffStats,
show_line_nums: bool = True,
show_stats: bool = True,
) -> str:
output = []
if show_stats:
@ -124,10 +143,15 @@ class DiffDisplay:
if show_stats:
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output)
return "\n".join(output)
def display_file_diff(self, old_content: str, new_content: str,
filename: str = "file", show_line_nums: bool = True) -> str:
def display_file_diff(
self,
old_content: str,
new_content: str,
filename: str = "file",
show_line_nums: bool = True,
) -> str:
diff_lines, stats = self.create_diff(old_content, new_content, filename)
if not diff_lines:
@ -135,8 +159,13 @@ class DiffDisplay:
return self.render_diff(diff_lines, stats, show_line_nums)
def display_side_by_side(self, old_content: str, new_content: str,
filename: str = "file", width: int = 80) -> str:
def display_side_by_side(
self,
old_content: str,
new_content: str,
filename: str = "file",
width: int = 80,
) -> str:
old_lines = old_content.splitlines()
new_lines = new_content.splitlines()
@ -144,40 +173,57 @@ class DiffDisplay:
output = []
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}")
output.append(
f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}"
)
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
half_width = (width - 5) // 2
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
if tag == 'equal':
for i, (old_line, new_line) in enumerate(zip(old_lines[i1:i2], new_lines[j1:j2])):
if tag == "equal":
for i, (old_line, new_line) in enumerate(
zip(old_lines[i1:i2], new_lines[j1:j2])
):
old_display = old_line[:half_width].ljust(half_width)
new_display = new_line[:half_width].ljust(half_width)
output.append(f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}")
elif tag == 'replace':
output.append(
f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}"
)
elif tag == "replace":
max_lines = max(i2 - i1, j2 - j1)
for i in range(max_lines):
old_line = old_lines[i1 + i] if i1 + i < i2 else ""
new_line = new_lines[j1 + i] if j1 + i < j2 else ""
old_display = old_line[:half_width].ljust(half_width)
new_display = new_line[:half_width].ljust(half_width)
output.append(f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}")
elif tag == 'delete':
output.append(
f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}"
)
elif tag == "delete":
for old_line in old_lines[i1:i2]:
old_display = old_line[:half_width].ljust(half_width)
output.append(f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}")
elif tag == 'insert':
output.append(
f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}"
)
elif tag == "insert":
for new_line in new_lines[j1:j2]:
new_display = new_line[:half_width].ljust(half_width)
output.append(f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}")
output.append(
f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}"
)
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
return '\n'.join(output)
return "\n".join(output)
def display_diff(old_content: str, new_content: str, filename: str = "file",
format_type: str = "unified", context_lines: int = 3) -> str:
def display_diff(
old_content: str,
new_content: str,
filename: str = "file",
format_type: str = "unified",
context_lines: int = 3,
) -> str:
displayer = DiffDisplay(context_lines)
if format_type == "side-by-side":
@ -191,9 +237,9 @@ def get_diff_stats(old_content: str, new_content: str) -> Dict[str, int]:
_, stats = displayer.create_diff(old_content, new_content)
return {
'insertions': stats.insertions,
'deletions': stats.deletions,
'modifications': stats.modifications,
'total_changes': stats.total_changes,
'files_changed': stats.files_changed
"insertions": stats.insertions,
"deletions": stats.deletions,
"modifications": stats.modifications,
"total_changes": stats.total_changes,
"files_changed": stats.files_changed,
}
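`display_diff` and `get_diff_stats` are the public entry points over `DiffDisplay`. A small end-to-end example:

```python
from pr.ui.diff_display import display_diff, get_diff_stats

old = "alpha\nbeta\n"
new = "alpha\ngamma\n"

print(display_diff(old, new, filename="demo.txt", format_type="unified"))
print(display_diff(old, new, filename="demo.txt", format_type="side-by-side"))
print(get_diff_stats(old, new))  # {'insertions': 1, 'deletions': 1, ...}
```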

View File

@ -1,8 +1,6 @@
import json
import time
from typing import Dict, Any
from pr.ui.colors import Colors
def display_tool_call(tool_name, arguments, status="running", result=None):
if status == "running":
return
@ -15,8 +13,11 @@ def display_tool_call(tool_name, arguments, status="running", result=None):
print(f"{Colors.GRAY}{line}{Colors.RESET}")
def print_autonomous_header(task):
print(f"{Colors.BOLD}Task:{Colors.RESET} {task}")
print(f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}")
print(
f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}"
)
print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n")
print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}\n")

View File

@ -1,12 +1,20 @@
from typing import List, Dict, Optional
from datetime import datetime
from typing import Dict, List, Optional
from .colors import Colors
from .progress import ProgressBar
class EditOperation:
def __init__(self, op_type: str, filepath: str, start_pos: int = 0,
end_pos: int = 0, content: str = "", old_content: str = ""):
def __init__(
self,
op_type: str,
filepath: str,
start_pos: int = 0,
end_pos: int = 0,
content: str = "",
old_content: str = "",
):
self.op_type = op_type
self.filepath = filepath
self.start_pos = start_pos
@ -18,40 +26,46 @@ class EditOperation:
def format_operation(self) -> str:
op_colors = {
'INSERT': Colors.GREEN,
'REPLACE': Colors.YELLOW,
'DELETE': Colors.RED,
'WRITE': Colors.BLUE
"INSERT": Colors.GREEN,
"REPLACE": Colors.YELLOW,
"DELETE": Colors.RED,
"WRITE": Colors.BLUE,
}
color = op_colors.get(self.op_type, Colors.RESET)
status_icon = {
            'pending': '○',
            'in_progress': '◐',
            'completed': '✓',
            'failed': '✗'
        }.get(self.status, '')
            "pending": "○",
            "in_progress": "◐",
            "completed": "✓",
            "failed": "✗",
        }.get(self.status, "")
return f"{color}{status_icon} [{self.op_type}]{Colors.RESET} {self.filepath}"
def format_details(self, show_content: bool = True) -> str:
output = [self.format_operation()]
if self.op_type in ('INSERT', 'REPLACE'):
output.append(f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}")
if self.op_type in ("INSERT", "REPLACE"):
output.append(
f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}"
)
if show_content:
if self.old_content:
lines = self.old_content.split('\n')
preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '')
lines = self.old_content.split("\n")
preview = lines[0][:60] + (
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
)
output.append(f" {Colors.RED}- {preview}{Colors.RESET}")
if self.content:
lines = self.content.split('\n')
preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '')
lines = self.content.split("\n")
preview = lines[0][:60] + (
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
)
output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}")
return '\n'.join(output)
return "\n".join(output)
class EditTracker:
@ -76,11 +90,13 @@ class EditTracker:
def get_stats(self) -> Dict[str, int]:
stats = {
'total': len(self.operations),
'completed': sum(1 for op in self.operations if op.status == 'completed'),
'pending': sum(1 for op in self.operations if op.status == 'pending'),
'in_progress': sum(1 for op in self.operations if op.status == 'in_progress'),
'failed': sum(1 for op in self.operations if op.status == 'failed')
"total": len(self.operations),
"completed": sum(1 for op in self.operations if op.status == "completed"),
"pending": sum(1 for op in self.operations if op.status == "pending"),
"in_progress": sum(
1 for op in self.operations if op.status == "in_progress"
),
"failed": sum(1 for op in self.operations if op.status == "failed"),
}
return stats
@ -88,7 +104,7 @@ class EditTracker:
if not self.operations:
return 0.0
stats = self.get_stats()
return (stats['completed'] / stats['total']) * 100
return (stats["completed"] / stats["total"]) * 100
def display_progress(self) -> str:
if not self.operations:
@ -96,26 +112,30 @@ class EditTracker:
output = []
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}")
output.append(
f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}"
)
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
stats = self.get_stats()
completion = self.get_completion_percentage()
progress_bar = ProgressBar(total=stats['total'], width=40)
progress_bar.current = stats['completed']
progress_bar = ProgressBar(total=stats["total"], width=40)
progress_bar.current = stats["completed"]
bar_display = progress_bar._get_bar_display()
output.append(f"Progress: {bar_display}")
output.append(f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, "
f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n")
output.append(
f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, "
f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n"
)
output.append(f"{Colors.BOLD}Recent Operations:{Colors.RESET}")
for i, op in enumerate(self.operations[-5:], 1):
output.append(f"{i}. {op.format_operation()}")
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output)
return "\n".join(output)
def display_timeline(self, show_content: bool = False) -> str:
if not self.operations:
@ -134,18 +154,20 @@ class EditTracker:
stats = self.get_stats()
output.append(f"{Colors.BOLD}Summary:{Colors.RESET}")
output.append(f"{Colors.BLUE}Total operations: {stats['total']}, "
f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}")
output.append(
f"{Colors.BLUE}Total operations: {stats['total']}, "
f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}"
)
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output)
return "\n".join(output)
def display_summary(self) -> str:
if not self.operations:
return f"{Colors.GRAY}No edits to summarize{Colors.RESET}"
stats = self.get_stats()
files_modified = len(set(op.filepath for op in self.operations))
files_modified = len({op.filepath for op in self.operations})
output = []
output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}")
@ -156,7 +178,7 @@ class EditTracker:
output.append(f"{Colors.GREEN}Total Operations: {stats['total']}{Colors.RESET}")
output.append(f"{Colors.GREEN}Successful: {stats['completed']}{Colors.RESET}")
if stats['failed'] > 0:
if stats["failed"] > 0:
output.append(f"{Colors.RED}Failed: {stats['failed']}{Colors.RESET}")
output.append(f"\n{Colors.BOLD}Operations by Type:{Colors.RESET}")
@ -168,7 +190,7 @@ class EditTracker:
output.append(f" {op_type}: {count}")
output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output)
return "\n".join(output)
def clear(self):
self.operations.clear()
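The tracker's mutation API (adding operations, marking them complete) is not shown in this hunk, so the sketch below drives the visible pieces directly; treat the appended-operation shape as an assumption:

```python
# Assumed import path for the edit-tracking classes above.
from pr.ui.edit_feedback import EditOperation, EditTracker

tracker = EditTracker()

# Hypothetical direct manipulation; the real add/update helpers are elided here.
op = EditOperation("REPLACE", "src/app.py", start_pos=10, end_pos=20,
                   content="new text", old_content="old text")
tracker.operations.append(op)
op.status = "completed"

print(tracker.get_completion_percentage())  # 100.0
print(tracker.display_summary())
```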

View File

@ -1,31 +1,31 @@
import json
import sys
from typing import Any, Dict, List
from datetime import datetime
from typing import Any
class OutputFormatter:
def __init__(self, format_type: str = 'text', quiet: bool = False):
def __init__(self, format_type: str = "text", quiet: bool = False):
self.format_type = format_type
self.quiet = quiet
def output(self, data: Any, message_type: str = 'response'):
if self.quiet and message_type not in ['error', 'result']:
def output(self, data: Any, message_type: str = "response"):
if self.quiet and message_type not in ["error", "result"]:
return
if self.format_type == 'json':
if self.format_type == "json":
self._output_json(data, message_type)
elif self.format_type == 'structured':
elif self.format_type == "structured":
self._output_structured(data, message_type)
else:
self._output_text(data, message_type)
def _output_json(self, data: Any, message_type: str):
output = {
'type': message_type,
'timestamp': datetime.now().isoformat(),
'data': data
"type": message_type,
"timestamp": datetime.now().isoformat(),
"data": data,
}
print(json.dumps(output, indent=2))
@ -46,24 +46,24 @@ class OutputFormatter:
print(data)
def error(self, message: str):
if self.format_type == 'json':
self._output_json({'error': message}, 'error')
if self.format_type == "json":
self._output_json({"error": message}, "error")
else:
print(f"Error: {message}", file=sys.stderr)
def success(self, message: str):
if not self.quiet:
if self.format_type == 'json':
self._output_json({'success': message}, 'success')
if self.format_type == "json":
self._output_json({"success": message}, "success")
else:
print(message)
def info(self, message: str):
if not self.quiet:
if self.format_type == 'json':
self._output_json({'info': message}, 'info')
if self.format_type == "json":
self._output_json({"info": message}, "info")
else:
print(message)
def result(self, data: Any):
self.output(data, 'result')
self.output(data, "result")

View File

@ -1,6 +1,6 @@
import sys
import time
import threading
import time
class ProgressIndicator:
@ -30,15 +30,15 @@ class ProgressIndicator:
self.running = False
if self.thread:
self.thread.join(timeout=1.0)
sys.stdout.write('\r' + ' ' * (len(self.message) + 10) + '\r')
sys.stdout.write("\r" + " " * (len(self.message) + 10) + "\r")
sys.stdout.flush()
def _animate(self):
        spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
spinner = ["", "", "", "", "", "", "", "", "", ""]
idx = 0
while self.running:
sys.stdout.write(f'\r{spinner[idx]} {self.message}...')
sys.stdout.write(f"\r{spinner[idx]} {self.message}...")
sys.stdout.flush()
idx = (idx + 1) % len(spinner)
time.sleep(0.1)
@ -62,14 +62,20 @@ class ProgressBar:
else:
percent = int((self.current / self.total) * 100)
filled = int((self.current / self.total) * self.width) if self.total > 0 else self.width
        bar = '█' * filled + '░' * (self.width - filled)
filled = (
int((self.current / self.total) * self.width)
if self.total > 0
else self.width
)
bar = "" * filled + "" * (self.width - filled)
sys.stdout.write(f'\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})')
sys.stdout.write(
f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})"
)
sys.stdout.flush()
if self.current >= self.total:
sys.stdout.write('\n')
sys.stdout.write("\n")
def finish(self):
self.current = self.total
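A sketch of both widgets; `start()` on the indicator and the bar's constructor defaults are assumed from the fields used above (only `stop()`, the render loop, and the `_get_bar_display` call site are visible in this diff):

```python
from pr.ui.progress import ProgressBar, ProgressIndicator  # assumed module path

spinner = ProgressIndicator("Thinking")  # message shown next to the spinner frames
spinner.start()   # assumed counterpart to the stop() shown above
# ... long-running work ...
spinner.stop()

bar = ProgressBar(total=10, width=40)  # same call shape EditTracker uses
bar.current = 7
print(bar._get_bar_display())  # string form of the bar, as in display_progress
```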

View File

@ -1,90 +1,103 @@
import re
from pr.ui.colors import Colors
from pr.config import LANGUAGE_KEYWORDS
from pr.ui.colors import Colors
def highlight_code(code, language=None, syntax_highlighting=True):
if not syntax_highlighting:
return code
if not language:
if 'def ' in code or 'import ' in code:
language = 'python'
elif 'function ' in code or 'const ' in code:
language = 'javascript'
elif 'public ' in code or 'class ' in code:
language = 'java'
if "def " in code or "import " in code:
language = "python"
elif "function " in code or "const " in code:
language = "javascript"
elif "public " in code or "class " in code:
language = "java"
if language and language in LANGUAGE_KEYWORDS:
keywords = LANGUAGE_KEYWORDS[language]
for keyword in keywords:
pattern = r'\b' + re.escape(keyword) + r'\b'
pattern = r"\b" + re.escape(keyword) + r"\b"
code = re.sub(pattern, f"{Colors.BLUE}{keyword}{Colors.RESET}", code)
code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code)
code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code)
code = re.sub(r'#(.*)$', f'{Colors.GRAY}#\\1{Colors.RESET}', code, flags=re.MULTILINE)
code = re.sub(r'//(.*)$', f'{Colors.GRAY}//\\1{Colors.RESET}', code, flags=re.MULTILINE)
code = re.sub(
r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE
)
code = re.sub(
r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE
)
return code
def render_markdown(text, syntax_highlighting=True):
if not syntax_highlighting:
return text
code_blocks = []
def extract_code_block(match):
lang = match.group(1) or ''
lang = match.group(1) or ""
code = match.group(2)
highlighted_code = highlight_code(code.strip('\n'), lang, syntax_highlighting)
highlighted_code = highlight_code(code.strip("\n"), lang, syntax_highlighting)
placeholder = f"%%CODEBLOCK{len(code_blocks)}%%"
full_block = f'{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}'
full_block = f"{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}"
code_blocks.append(full_block)
return placeholder
text = re.sub(r'```(\w*)\n(.*?)\n?```', extract_code_block, text, flags=re.DOTALL)
text = re.sub(r"```(\w*)\n(.*?)\n?```", extract_code_block, text, flags=re.DOTALL)
inline_codes = []
def extract_inline_code(match):
code = match.group(1)
placeholder = f"%%INLINECODE{len(inline_codes)}%%"
inline_codes.append(f'{Colors.YELLOW}{code}{Colors.RESET}')
inline_codes.append(f"{Colors.YELLOW}{code}{Colors.RESET}")
return placeholder
text = re.sub(r'`([^`]+)`', extract_inline_code, text)
text = re.sub(r"`([^`]+)`", extract_inline_code, text)
lines = text.split('\n')
lines = text.split("\n")
processed_lines = []
for line in lines:
if line.startswith('### '):
line = f'{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}'
elif line.startswith('## '):
line = f'{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}'
elif line.startswith('# '):
line = f'{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}'
elif line.startswith('> '):
line = f'{Colors.CYAN}> {line[2:]}{Colors.RESET}'
elif re.match(r'^\s*[\*\-\+]\s', line):
match = re.match(r'^(\s*)([\*\-\+])(\s+.*)', line)
if line.startswith("### "):
line = f"{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}"
elif line.startswith("## "):
line = f"{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}"
elif line.startswith("# "):
line = f"{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}"
elif line.startswith("> "):
line = f"{Colors.CYAN}> {line[2:]}{Colors.RESET}"
elif re.match(r"^\s*[\*\-\+]\s", line):
match = re.match(r"^(\s*)([\*\-\+])(\s+.*)", line)
if match:
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
elif re.match(r'^\s*\d+\.\s', line):
match = re.match(r'^(\s*)(\d+\.)(\s+.*)', line)
elif re.match(r"^\s*\d+\.\s", line):
match = re.match(r"^(\s*)(\d+\.)(\s+.*)", line)
if match:
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
processed_lines.append(line)
text = '\n'.join(processed_lines)
text = "\n".join(processed_lines)
text = re.sub(r'\[(.*?)\]\((.*?)\)', f'{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}', text)
text = re.sub(r'~~(.*?)~~', f'{Colors.GRAY}\\1{Colors.RESET}', text)
text = re.sub(r'\*\*(.*?)\*\*', f'{Colors.BOLD}\\1{Colors.RESET}', text)
text = re.sub(r'__(.*?)__', f'{Colors.BOLD}\\1{Colors.RESET}', text)
text = re.sub(r'\*(.*?)\*', f'{Colors.CYAN}\\1{Colors.RESET}', text)
text = re.sub(r'_(.*?)_', f'{Colors.CYAN}\\1{Colors.RESET}', text)
text = re.sub(
r"\[(.*?)\]\((.*?)\)",
f"{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}",
text,
)
text = re.sub(r"~~(.*?)~~", f"{Colors.GRAY}\\1{Colors.RESET}", text)
text = re.sub(r"\*\*(.*?)\*\*", f"{Colors.BOLD}\\1{Colors.RESET}", text)
text = re.sub(r"__(.*?)__", f"{Colors.BOLD}\\1{Colors.RESET}", text)
text = re.sub(r"\*(.*?)\*", f"{Colors.CYAN}\\1{Colors.RESET}", text)
text = re.sub(r"_(.*?)_", f"{Colors.CYAN}\\1{Colors.RESET}", text)
for i, code in enumerate(inline_codes):
text = text.replace(f'%%INLINECODE{i}%%', code)
text = text.replace(f"%%INLINECODE{i}%%", code)
for i, block in enumerate(code_blocks):
text = text.replace(f'%%CODEBLOCK{i}%%', block)
text = text.replace(f"%%CODEBLOCK{i}%%", block)
return text

View File

@@ -1,5 +1,11 @@
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode
from .workflow_definition import ExecutionMode, Workflow, WorkflowStep
from .workflow_engine import WorkflowEngine
from .workflow_storage import WorkflowStorage
__all__ = ['Workflow', 'WorkflowStep', 'ExecutionMode', 'WorkflowEngine', 'WorkflowStorage']
__all__ = [
"Workflow",
"WorkflowStep",
"ExecutionMode",
"WorkflowEngine",
"WorkflowStorage",
]

View File

@@ -1,12 +1,14 @@
from enum import Enum
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
class ExecutionMode(Enum):
SEQUENTIAL = "sequential"
PARALLEL = "parallel"
CONDITIONAL = "conditional"
@dataclass
class WorkflowStep:
tool_name: str
@@ -20,29 +22,30 @@ class WorkflowStep:
def to_dict(self) -> Dict[str, Any]:
return {
'tool_name': self.tool_name,
'arguments': self.arguments,
'step_id': self.step_id,
'condition': self.condition,
'on_success': self.on_success,
'on_failure': self.on_failure,
'retry_count': self.retry_count,
'timeout_seconds': self.timeout_seconds
"tool_name": self.tool_name,
"arguments": self.arguments,
"step_id": self.step_id,
"condition": self.condition,
"on_success": self.on_success,
"on_failure": self.on_failure,
"retry_count": self.retry_count,
"timeout_seconds": self.timeout_seconds,
}
@staticmethod
def from_dict(data: Dict[str, Any]) -> 'WorkflowStep':
def from_dict(data: Dict[str, Any]) -> "WorkflowStep":
return WorkflowStep(
tool_name=data['tool_name'],
arguments=data['arguments'],
step_id=data['step_id'],
condition=data.get('condition'),
on_success=data.get('on_success'),
on_failure=data.get('on_failure'),
retry_count=data.get('retry_count', 0),
timeout_seconds=data.get('timeout_seconds', 300)
tool_name=data["tool_name"],
arguments=data["arguments"],
step_id=data["step_id"],
condition=data.get("condition"),
on_success=data.get("on_success"),
on_failure=data.get("on_failure"),
retry_count=data.get("retry_count", 0),
timeout_seconds=data.get("timeout_seconds", 300),
)
@dataclass
class Workflow:
name: str
@@ -54,23 +57,23 @@ class Workflow:
def to_dict(self) -> Dict[str, Any]:
return {
'name': self.name,
'description': self.description,
'steps': [step.to_dict() for step in self.steps],
'execution_mode': self.execution_mode.value,
'variables': self.variables,
'tags': self.tags
"name": self.name,
"description": self.description,
"steps": [step.to_dict() for step in self.steps],
"execution_mode": self.execution_mode.value,
"variables": self.variables,
"tags": self.tags,
}
@staticmethod
def from_dict(data: Dict[str, Any]) -> 'Workflow':
def from_dict(data: Dict[str, Any]) -> "Workflow":
return Workflow(
name=data['name'],
description=data['description'],
steps=[WorkflowStep.from_dict(step) for step in data['steps']],
execution_mode=ExecutionMode(data.get('execution_mode', 'sequential')),
variables=data.get('variables', {}),
tags=data.get('tags', [])
name=data["name"],
description=data["description"],
steps=[WorkflowStep.from_dict(step) for step in data["steps"]],
execution_mode=ExecutionMode(data.get("execution_mode", "sequential")),
variables=data.get("variables", {}),
tags=data.get("tags", []),
)
def add_step(self, step: WorkflowStep):
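Since `to_dict`/`from_dict` are symmetric, serialization should round-trip. A sketch, assuming the package exports these names as in the `__init__.py` above (the `pr.workflows` import path is a guess):

```python
from pr.workflows import ExecutionMode, Workflow, WorkflowStep

# Round-trip sketch; field names come straight from to_dict()/from_dict().
step = WorkflowStep(tool_name="read_file", arguments={"path": "README.md"}, step_id="s1")
wf = Workflow(
    name="demo",
    description="read one file",
    steps=[step],
    execution_mode=ExecutionMode.SEQUENTIAL,
)

restored = Workflow.from_dict(wf.to_dict())
assert restored.steps[0].tool_name == "read_file"
assert restored.execution_mode is ExecutionMode.SEQUENTIAL
```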

View File

@@ -1,8 +1,10 @@
import time
import re
from typing import Dict, Any, List, Callable, Optional
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode
from typing import Any, Callable, Dict, List, Optional
from .workflow_definition import ExecutionMode, Workflow, WorkflowStep
class WorkflowExecutionContext:
def __init__(self):
@@ -23,57 +25,66 @@ class WorkflowExecutionContext:
return self.step_results.get(step_id)
def log_event(self, event_type: str, step_id: str, details: Dict[str, Any]):
self.execution_log.append({
'timestamp': time.time(),
'event_type': event_type,
'step_id': step_id,
'details': details
})
self.execution_log.append(
{
"timestamp": time.time(),
"event_type": event_type,
"step_id": step_id,
"details": details,
}
)
class WorkflowEngine:
def __init__(self, tool_executor: Callable, max_workers: int = 5):
self.tool_executor = tool_executor
self.max_workers = max_workers
def _evaluate_condition(self, condition: str, context: WorkflowExecutionContext) -> bool:
def _evaluate_condition(
self, condition: str, context: WorkflowExecutionContext
) -> bool:
if not condition:
return True
try:
safe_locals = {
'variables': context.variables,
'results': context.step_results
"variables": context.variables,
"results": context.step_results,
}
return eval(condition, {"__builtins__": {}}, safe_locals)
except Exception:
return False
def _substitute_variables(self, arguments: Dict[str, Any], context: WorkflowExecutionContext) -> Dict[str, Any]:
def _substitute_variables(
self, arguments: Dict[str, Any], context: WorkflowExecutionContext
) -> Dict[str, Any]:
substituted = {}
for key, value in arguments.items():
if isinstance(value, str):
pattern = r'\$\{([^}]+)\}'
pattern = r"\$\{([^}]+)\}"
matches = re.findall(pattern, value)
for match in matches:
if match.startswith('step.'):
step_id = match.split('.', 1)[1]
if match.startswith("step."):
step_id = match.split(".", 1)[1]
replacement = context.get_step_result(step_id)
if replacement is not None:
value = value.replace(f'${{{match}}}', str(replacement))
elif match.startswith('var.'):
var_name = match.split('.', 1)[1]
value = value.replace(f"${{{match}}}", str(replacement))
elif match.startswith("var."):
var_name = match.split(".", 1)[1]
replacement = context.get_variable(var_name)
if replacement is not None:
value = value.replace(f'${{{match}}}', str(replacement))
value = value.replace(f"${{{match}}}", str(replacement))
substituted[key] = value
else:
substituted[key] = value
return substituted
def _execute_step(self, step: WorkflowStep, context: WorkflowExecutionContext) -> Dict[str, Any]:
def _execute_step(
self, step: WorkflowStep, context: WorkflowExecutionContext
) -> Dict[str, Any]:
if not self._evaluate_condition(step.condition, context):
context.log_event('skipped', step.step_id, {'reason': 'condition_not_met'})
return {'status': 'skipped', 'step_id': step.step_id}
context.log_event("skipped", step.step_id, {"reason": "condition_not_met"})
return {"status": "skipped", "step_id": step.step_id}
arguments = self._substitute_variables(step.arguments, context)
@@ -83,26 +94,34 @@ class WorkflowEngine:
while retry_attempts <= step.retry_count:
try:
context.log_event('executing', step.step_id, {
'tool': step.tool_name,
'arguments': arguments,
'attempt': retry_attempts + 1
})
context.log_event(
"executing",
step.step_id,
{
"tool": step.tool_name,
"arguments": arguments,
"attempt": retry_attempts + 1,
},
)
result = self.tool_executor(step.tool_name, arguments)
execution_time = time.time() - start_time
context.set_step_result(step.step_id, result)
context.log_event('completed', step.step_id, {
'execution_time': execution_time,
'result_size': len(str(result)) if result else 0
})
context.log_event(
"completed",
step.step_id,
{
"execution_time": execution_time,
"result_size": len(str(result)) if result else 0,
},
)
return {
'status': 'success',
'step_id': step.step_id,
'result': result,
'execution_time': execution_time
"status": "success",
"step_id": step.step_id,
"result": result,
"execution_time": execution_time,
}
except Exception as e:
@@ -111,25 +130,26 @@ class WorkflowEngine:
if retry_attempts <= step.retry_count:
time.sleep(1 * retry_attempts)
context.log_event('failed', step.step_id, {'error': last_error})
context.log_event("failed", step.step_id, {"error": last_error})
return {
'status': 'failed',
'step_id': step.step_id,
'error': last_error,
'execution_time': time.time() - start_time
"status": "failed",
"step_id": step.step_id,
"error": last_error,
"execution_time": time.time() - start_time,
}
def _get_next_steps(self, completed_step: WorkflowStep, result: Dict[str, Any],
workflow: Workflow) -> List[WorkflowStep]:
def _get_next_steps(
self, completed_step: WorkflowStep, result: Dict[str, Any], workflow: Workflow
) -> List[WorkflowStep]:
next_steps = []
if result['status'] == 'success' and completed_step.on_success:
if result["status"] == "success" and completed_step.on_success:
for step_id in completed_step.on_success:
step = workflow.get_step(step_id)
if step:
next_steps.append(step)
elif result['status'] == 'failed' and completed_step.on_failure:
elif result["status"] == "failed" and completed_step.on_failure:
for step_id in completed_step.on_failure:
step = workflow.get_step(step_id)
if step:
@@ -142,7 +162,9 @@ class WorkflowEngine:
return next_steps
def execute_workflow(self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None) -> WorkflowExecutionContext:
def execute_workflow(
self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None
) -> WorkflowExecutionContext:
context = WorkflowExecutionContext()
if initial_variables:
@@ -151,7 +173,7 @@ class WorkflowEngine:
if workflow.variables:
context.variables.update(workflow.variables)
context.log_event('workflow_started', 'workflow', {'name': workflow.name})
context.log_event("workflow_started", "workflow", {"name": workflow.name})
if workflow.execution_mode == ExecutionMode.PARALLEL:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
@@ -164,9 +186,11 @@ class WorkflowEngine:
step = futures[future]
try:
result = future.result()
context.log_event('step_completed', step.step_id, result)
context.log_event("step_completed", step.step_id, result)
except Exception as e:
context.log_event('step_failed', step.step_id, {'error': str(e)})
context.log_event(
"step_failed", step.step_id, {"error": str(e)}
)
else:
pending_steps = workflow.get_initial_steps()
@@ -184,9 +208,13 @@ class WorkflowEngine:
next_steps = self._get_next_steps(step, result, workflow)
pending_steps.extend(next_steps)
context.log_event('workflow_completed', 'workflow', {
'total_steps': len(context.step_results),
'executed_steps': list(context.step_results.keys())
})
context.log_event(
"workflow_completed",
"workflow",
{
"total_steps": len(context.step_results),
"executed_steps": list(context.step_results.keys()),
},
)
return context
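Putting `_substitute_variables` and the sequential loop together: `${var.name}` pulls from workflow/initial variables, `${step.id}` from an earlier step's result. A hedged sketch — the import path is assumed, and `get_initial_steps` is assumed to return steps not linked from any `on_success`/`on_failure`:

```python
from pr.workflows import ExecutionMode, Workflow, WorkflowStep, WorkflowEngine


# Stub tool dispatcher standing in for the assistant's real tool executor.
def fake_executor(tool_name, arguments):
    return f"{tool_name}({arguments['path']})"


wf = Workflow(
    name="demo",
    description="two chained steps",
    steps=[
        WorkflowStep(tool_name="read", arguments={"path": "${var.target}"}, step_id="s1"),
        WorkflowStep(tool_name="summarize", arguments={"path": "${step.s1}"}, step_id="s2"),
    ],
    execution_mode=ExecutionMode.SEQUENTIAL,
)

ctx = WorkflowEngine(fake_executor).execute_workflow(wf, {"target": "README.md"})
print(ctx.step_results)  # {'s1': 'read(README.md)', 's2': 'summarize(read(README.md))'}
```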

View File

@@ -2,8 +2,10 @@ import json
import sqlite3
import time
from typing import List, Optional
from .workflow_definition import Workflow
class WorkflowStorage:
def __init__(self, db_path: str):
self.db_path = db_path
@@ -13,7 +15,8 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS workflows (
workflow_id TEXT PRIMARY KEY,
name TEXT NOT NULL,
@@ -25,9 +28,11 @@ class WorkflowStorage:
last_execution_at INTEGER,
tags TEXT
)
''')
"""
)
cursor.execute('''
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS workflow_executions (
execution_id TEXT PRIMARY KEY,
workflow_id TEXT NOT NULL,
@@ -39,17 +44,24 @@ class WorkflowStorage:
step_results TEXT,
FOREIGN KEY (workflow_id) REFERENCES workflows(workflow_id)
)
''')
"""
)
cursor.execute('''
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_workflow_name ON workflows(name)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_execution_workflow ON workflow_executions(workflow_id)
''')
cursor.execute('''
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_execution_started ON workflow_executions(started_at)
''')
"""
)
conn.commit()
conn.close()
@@ -66,12 +78,22 @@ class WorkflowStorage:
current_time = int(time.time())
tags_json = json.dumps(workflow.tags)
cursor.execute('''
cursor.execute(
"""
INSERT OR REPLACE INTO workflows
(workflow_id, name, description, workflow_data, created_at, updated_at, tags)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (workflow_id, workflow.name, workflow.description, workflow_data,
current_time, current_time, tags_json))
""",
(
workflow_id,
workflow.name,
workflow.description,
workflow_data,
current_time,
current_time,
tags_json,
),
)
conn.commit()
conn.close()
@@ -82,7 +104,9 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT workflow_data FROM workflows WHERE workflow_id = ?', (workflow_id,))
cursor.execute(
"SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,)
)
row = cursor.fetchone()
conn.close()
@@ -95,7 +119,7 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT workflow_data FROM workflows WHERE name = ?', (name,))
cursor.execute("SELECT workflow_data FROM workflows WHERE name = ?", (name,))
row = cursor.fetchone()
conn.close()
@@ -109,29 +133,36 @@ class WorkflowStorage:
cursor = conn.cursor()
if tag:
cursor.execute('''
cursor.execute(
"""
SELECT workflow_id, name, description, execution_count, last_execution_at, tags
FROM workflows
WHERE tags LIKE ?
ORDER BY name
''', (f'%"{tag}"%',))
""",
(f'%"{tag}"%',),
)
else:
cursor.execute('''
cursor.execute(
"""
SELECT workflow_id, name, description, execution_count, last_execution_at, tags
FROM workflows
ORDER BY name
''')
"""
)
workflows = []
for row in cursor.fetchall():
workflows.append({
'workflow_id': row[0],
'name': row[1],
'description': row[2],
'execution_count': row[3],
'last_execution_at': row[4],
'tags': json.loads(row[5]) if row[5] else []
})
workflows.append(
{
"workflow_id": row[0],
"name": row[1],
"description": row[2],
"execution_count": row[3],
"last_execution_at": row[4],
"tags": json.loads(row[5]) if row[5] else [],
}
)
conn.close()
return workflows
@@ -140,18 +171,21 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM workflows WHERE workflow_id = ?', (workflow_id,))
cursor.execute("DELETE FROM workflows WHERE workflow_id = ?", (workflow_id,))
deleted = cursor.rowcount > 0
cursor.execute('DELETE FROM workflow_executions WHERE workflow_id = ?', (workflow_id,))
cursor.execute(
"DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,)
)
conn.commit()
conn.close()
return deleted
def save_execution(self, workflow_id: str, execution_context: 'WorkflowExecutionContext') -> str:
import hashlib
def save_execution(
self, workflow_id: str, execution_context: "WorkflowExecutionContext"
) -> str:
import uuid
execution_id = str(uuid.uuid4())[:16]
@@ -159,30 +193,40 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
started_at = int(execution_context.execution_log[0]['timestamp']) if execution_context.execution_log else int(time.time())
started_at = (
int(execution_context.execution_log[0]["timestamp"])
if execution_context.execution_log
else int(time.time())
)
completed_at = int(time.time())
cursor.execute('''
cursor.execute(
"""
INSERT INTO workflow_executions
(execution_id, workflow_id, started_at, completed_at, status, execution_log, variables, step_results)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
execution_id,
workflow_id,
started_at,
completed_at,
'completed',
json.dumps(execution_context.execution_log),
json.dumps(execution_context.variables),
json.dumps(execution_context.step_results)
))
""",
(
execution_id,
workflow_id,
started_at,
completed_at,
"completed",
json.dumps(execution_context.execution_log),
json.dumps(execution_context.variables),
json.dumps(execution_context.step_results),
),
)
cursor.execute('''
cursor.execute(
"""
UPDATE workflows
SET execution_count = execution_count + 1,
last_execution_at = ?
WHERE workflow_id = ?
''', (completed_at, workflow_id))
""",
(completed_at, workflow_id),
)
conn.commit()
conn.close()
@@ -193,22 +237,27 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
SELECT execution_id, started_at, completed_at, status
FROM workflow_executions
WHERE workflow_id = ?
ORDER BY started_at DESC
LIMIT ?
''', (workflow_id, limit))
""",
(workflow_id, limit),
)
executions = []
for row in cursor.fetchall():
executions.append({
'execution_id': row[0],
'started_at': row[1],
'completed_at': row[2],
'status': row[3]
})
executions.append(
{
"execution_id": row[0],
"started_at": row[1],
"completed_at": row[2],
"status": row[3],
}
)
conn.close()
return executions
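For context, how this storage layer is meant to be driven. `WorkflowStorage(":memory:")` and `load_workflow_by_name` match the tests later in this commit; `save_workflow` and `delete_workflow` are assumed names for the INSERT and DELETE helpers whose bodies appear above:

```python
from pr.workflows import ExecutionMode, Workflow, WorkflowStorage

storage = WorkflowStorage(":memory:")  # tables are created up front, per the CREATE TABLE block
wf = Workflow(name="nightly", description="demo", steps=[], execution_mode=ExecutionMode.SEQUENTIAL)

workflow_id = storage.save_workflow(wf)            # assumed name for the INSERT OR REPLACE helper
loaded = storage.load_workflow_by_name("nightly")  # name confirmed by tests in this commit
assert loaded.name == wf.name
storage.delete_workflow(workflow_id)               # assumed name for the DELETE helper
```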

rp.py
View File

@@ -2,8 +2,7 @@
# Trigger build
import sys
from pr.__main__ import main
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -1,8 +1,9 @@
import pytest
import os
import tempfile
from unittest.mock import MagicMock
import pytest
@pytest.fixture
def temp_dir():
@@ -13,19 +14,8 @@ def temp_dir():
@pytest.fixture
def mock_api_response():
return {
'choices': [
{
'message': {
'role': 'assistant',
'content': 'Test response'
}
}
],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 5,
'total_tokens': 15
}
"choices": [{"message": {"role": "assistant", "content": "Test response"}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
}
@@ -47,7 +37,7 @@ def mock_args():
@pytest.fixture
def sample_context_file(temp_dir):
context_path = os.path.join(temp_dir, '.rcontext.txt')
with open(context_path, 'w') as f:
f.write('Sample context content\n')
context_path = os.path.join(temp_dir, ".rcontext.txt")
with open(context_path, "w") as f:
f.write("Sample context content\n")
return context_path

View File

@@ -1,39 +1,53 @@
import pytest
from pr.core.advanced_context import AdvancedContextManager
def test_adaptive_context_window_simple():
mgr = AdvancedContextManager()
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
window = mgr.adaptive_context_window(messages, 'simple')
messages = [
{"content": "short"},
{"content": "this is a longer message with more words"},
]
window = mgr.adaptive_context_window(messages, "simple")
assert isinstance(window, int)
assert window >= 10
def test_adaptive_context_window_medium():
mgr = AdvancedContextManager()
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
window = mgr.adaptive_context_window(messages, 'medium')
messages = [
{"content": "short"},
{"content": "this is a longer message with more words"},
]
window = mgr.adaptive_context_window(messages, "medium")
assert isinstance(window, int)
assert window >= 20
def test_adaptive_context_window_complex():
mgr = AdvancedContextManager()
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
window = mgr.adaptive_context_window(messages, 'complex')
messages = [
{"content": "short"},
{"content": "this is a longer message with more words"},
]
window = mgr.adaptive_context_window(messages, "complex")
assert isinstance(window, int)
assert window >= 35
def test_analyze_message_complexity():
mgr = AdvancedContextManager()
messages = [{'content': 'hello world'}, {'content': 'hello again'}]
messages = [{"content": "hello world"}, {"content": "hello again"}]
score = mgr._analyze_message_complexity(messages)
assert 0 <= score <= 1
def test_analyze_message_complexity_empty():
mgr = AdvancedContextManager()
messages = []
score = mgr._analyze_message_complexity(messages)
assert score == 0
def test_extract_key_sentences():
mgr = AdvancedContextManager()
text = "This is the first sentence. This is the second sentence. This is a longer third sentence with more words."
@@ -41,41 +55,47 @@ def test_extract_key_sentences():
assert len(sentences) <= 2
assert all(isinstance(s, str) for s in sentences)
def test_extract_key_sentences_empty():
mgr = AdvancedContextManager()
text = ""
sentences = mgr.extract_key_sentences(text, 5)
assert sentences == []
def test_advanced_summarize_messages():
mgr = AdvancedContextManager()
messages = [{'content': 'Hello'}, {'content': 'How are you?'}]
messages = [{"content": "Hello"}, {"content": "How are you?"}]
summary = mgr.advanced_summarize_messages(messages)
assert isinstance(summary, str)
def test_advanced_summarize_messages_empty():
mgr = AdvancedContextManager()
messages = []
summary = mgr.advanced_summarize_messages(messages)
assert summary == "No content to summarize."
def test_score_message_relevance():
mgr = AdvancedContextManager()
message = {'content': 'hello world'}
context = 'world hello'
message = {"content": "hello world"}
context = "world hello"
score = mgr.score_message_relevance(message, context)
assert 0 <= score <= 1
def test_score_message_relevance_no_overlap():
mgr = AdvancedContextManager()
message = {'content': 'hello'}
context = 'world'
message = {"content": "hello"}
context = "world"
score = mgr.score_message_relevance(message, context)
assert score == 0
def test_score_message_relevance_empty():
mgr = AdvancedContextManager()
message = {'content': ''}
context = ''
message = {"content": ""}
context = ""
score = mgr.score_message_relevance(message, context)
assert score == 0
assert score == 0

View File

@@ -1,127 +1,213 @@
import pytest
import time
from pr.agents.agent_communication import (
AgentCommunicationBus,
AgentMessage,
MessageType,
)
from pr.agents.agent_manager import AgentInstance, AgentManager
from pr.agents.agent_roles import AgentRole, get_agent_role, list_agent_roles
from pr.agents.agent_manager import AgentManager, AgentInstance
from pr.agents.agent_communication import AgentCommunicationBus, AgentMessage, MessageType
def test_get_agent_role():
role = get_agent_role('coding')
role = get_agent_role("coding")
assert isinstance(role, AgentRole)
assert role.name == 'coding'
assert role.name == "coding"
def test_list_agent_roles():
roles = list_agent_roles()
assert isinstance(roles, dict)
assert len(roles) > 0
assert 'coding' in roles
assert "coding" in roles
def test_agent_role():
role = AgentRole(name='test', description='test', system_prompt='test', allowed_tools=set(), specialization_areas=[])
assert role.name == 'test'
role = AgentRole(
name="test",
description="test",
system_prompt="test",
allowed_tools=set(),
specialization_areas=[],
)
assert role.name == "test"
def test_agent_instance():
role = get_agent_role('coding')
instance = AgentInstance(agent_id='test', role=role)
assert instance.agent_id == 'test'
role = get_agent_role("coding")
instance = AgentInstance(agent_id="test", role=role)
assert instance.agent_id == "test"
assert instance.role == role
def test_agent_manager_init():
mgr = AgentManager(':memory:', None)
mgr = AgentManager(":memory:", None)
assert mgr is not None
def test_agent_manager_create_agent():
mgr = AgentManager(':memory:', None)
agent = mgr.create_agent('coding', 'test_agent')
mgr = AgentManager(":memory:", None)
agent = mgr.create_agent("coding", "test_agent")
assert agent is not None
def test_agent_manager_get_agent():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'test_agent')
agent = mgr.get_agent('test_agent')
mgr = AgentManager(":memory:", None)
mgr.create_agent("coding", "test_agent")
agent = mgr.get_agent("test_agent")
assert isinstance(agent, AgentInstance)
def test_agent_manager_remove_agent():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'test_agent')
mgr.remove_agent('test_agent')
agent = mgr.get_agent('test_agent')
mgr = AgentManager(":memory:", None)
mgr.create_agent("coding", "test_agent")
mgr.remove_agent("test_agent")
agent = mgr.get_agent("test_agent")
assert agent is None
def test_agent_manager_send_agent_message():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'a')
mgr.create_agent('coding', 'b')
mgr.send_agent_message('a', 'b', 'test')
mgr = AgentManager(":memory:", None)
mgr.create_agent("coding", "a")
mgr.create_agent("coding", "b")
mgr.send_agent_message("a", "b", "test")
assert True
def test_agent_manager_get_agent_messages():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'test')
messages = mgr.get_agent_messages('test')
mgr = AgentManager(":memory:", None)
mgr.create_agent("coding", "test")
messages = mgr.get_agent_messages("test")
assert isinstance(messages, list)
def test_agent_manager_get_session_summary():
mgr = AgentManager(':memory:', None)
mgr = AgentManager(":memory:", None)
summary = mgr.get_session_summary()
assert isinstance(summary, str)
def test_agent_manager_collaborate_agents():
mgr = AgentManager(':memory:', None)
result = mgr.collaborate_agents('orchestrator', 'task', ['coding', 'research'])
mgr = AgentManager(":memory:", None)
result = mgr.collaborate_agents("orchestrator", "task", ["coding", "research"])
assert result is not None
def test_agent_manager_execute_agent_task():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'test')
result = mgr.execute_agent_task('test', 'task')
mgr = AgentManager(":memory:", None)
mgr.create_agent("coding", "test")
result = mgr.execute_agent_task("test", "task")
assert result is not None
def test_agent_manager_clear_session():
mgr = AgentManager(':memory:', None)
mgr = AgentManager(":memory:", None)
mgr.clear_session()
assert True
def test_agent_message():
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
assert msg.from_agent == 'a'
msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
assert msg.from_agent == "a"
def test_agent_message_to_dict():
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
d = msg.to_dict()
assert isinstance(d, dict)
def test_agent_message_from_dict():
d = {'from_agent': 'a', 'to_agent': 'b', 'message_type': 'request', 'content': 'test', 'metadata': {}, 'timestamp': 1.0, 'message_id': 'id'}
d = {
"from_agent": "a",
"to_agent": "b",
"message_type": "request",
"content": "test",
"metadata": {},
"timestamp": 1.0,
"message_id": "id",
}
msg = AgentMessage.from_dict(d)
assert isinstance(msg, AgentMessage)
def test_agent_communication_bus_init():
bus = AgentCommunicationBus(':memory:')
bus = AgentCommunicationBus(":memory:")
assert bus is not None
def test_agent_communication_bus_send_message():
bus = AgentCommunicationBus(':memory:')
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
bus = AgentCommunicationBus(":memory:")
msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
bus.send_message(msg)
assert True
def test_agent_communication_bus_receive_messages():
bus = AgentCommunicationBus(':memory:')
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
bus = AgentCommunicationBus(":memory:")
msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
bus.send_message(msg)
messages = bus.receive_messages('b')
messages = bus.receive_messages("b")
assert len(messages) == 1
def test_agent_communication_bus_get_conversation_history():
bus = AgentCommunicationBus(':memory:')
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
bus = AgentCommunicationBus(":memory:")
msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
bus.send_message(msg)
history = bus.get_conversation_history('a', 'b')
history = bus.get_conversation_history("a", "b")
assert len(history) == 1
def test_agent_communication_bus_mark_as_read():
bus = AgentCommunicationBus(':memory:')
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
bus = AgentCommunicationBus(":memory:")
msg = AgentMessage(
from_agent="a",
to_agent="b",
message_type=MessageType.REQUEST,
content="test",
metadata={},
timestamp=1.0,
message_id="id",
)
bus.send_message(msg)
bus.mark_as_read(msg.message_id)
assert True
assert True
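The bus tests above compress to a short end-to-end flow; everything here uses only names those tests exercise:

```python
from pr.agents.agent_communication import AgentCommunicationBus, AgentMessage, MessageType

bus = AgentCommunicationBus(":memory:")
msg = AgentMessage(
    from_agent="planner",
    to_agent="coder",
    message_type=MessageType.REQUEST,
    content="write tests",
    metadata={},
    timestamp=1.0,
    message_id="m1",
)
bus.send_message(msg)

inbox = bus.receive_messages("coder")  # the one message addressed to "coder"
bus.mark_as_read(inbox[0].message_id)  # acknowledged, per the mark_as_read test
```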

View File

@@ -1,63 +1,65 @@
import unittest
from unittest.mock import patch, MagicMock
import json
import urllib.error
from unittest.mock import MagicMock, patch
from pr.core.api import call_api, list_models
class TestApi(unittest.TestCase):
@patch('pr.core.api.urllib.request.urlopen')
@patch('pr.core.api.auto_slim_messages')
@patch("pr.core.api.urllib.request.urlopen")
@patch("pr.core.api.auto_slim_messages")
def test_call_api_success(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
mock_slim.return_value = [{"role": "user", "content": "test"}]
mock_response = MagicMock()
mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}'
mock_urlopen.return_value.__enter__.return_value = mock_response
result = call_api([], 'model', 'http://url', 'key', True, [{'name': 'tool'}])
result = call_api([], "model", "http://url", "key", True, [{"name": "tool"}])
self.assertIn('choices', result)
self.assertIn("choices", result)
mock_urlopen.assert_called_once()
@patch('urllib.request.urlopen')
@patch('pr.core.api.auto_slim_messages')
@patch("urllib.request.urlopen")
@patch("pr.core.api.auto_slim_messages")
def test_call_api_http_error(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
mock_urlopen.side_effect = urllib.error.HTTPError('http://url', 500, 'error', None, MagicMock())
mock_slim.return_value = [{"role": "user", "content": "test"}]
mock_urlopen.side_effect = urllib.error.HTTPError(
"http://url", 500, "error", None, MagicMock()
)
result = call_api([], 'model', 'http://url', 'key', False, [])
result = call_api([], "model", "http://url", "key", False, [])
self.assertIn('error', result)
self.assertIn("error", result)
@patch('urllib.request.urlopen')
@patch('pr.core.api.auto_slim_messages')
@patch("urllib.request.urlopen")
@patch("pr.core.api.auto_slim_messages")
def test_call_api_general_error(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
mock_urlopen.side_effect = Exception('test error')
mock_slim.return_value = [{"role": "user", "content": "test"}]
mock_urlopen.side_effect = Exception("test error")
result = call_api([], 'model', 'http://url', 'key', False, [])
result = call_api([], "model", "http://url", "key", False, [])
self.assertIn('error', result)
self.assertIn("error", result)
@patch('urllib.request.urlopen')
@patch("urllib.request.urlopen")
def test_list_models_success(self, mock_urlopen):
mock_response = MagicMock()
mock_response.read.return_value = b'{"data": [{"id": "model1"}]}'
mock_urlopen.return_value.__enter__.return_value = mock_response
result = list_models('http://url', 'key')
result = list_models("http://url", "key")
self.assertEqual(result, [{'id': 'model1'}])
self.assertEqual(result, [{"id": "model1"}])
@patch('urllib.request.urlopen')
@patch("urllib.request.urlopen")
def test_list_models_error(self, mock_urlopen):
mock_urlopen.side_effect = Exception('error')
mock_urlopen.side_effect = Exception("error")
result = list_models('http://url', 'key')
result = list_models("http://url", "key")
self.assertIn('error', result)
self.assertIn("error", result)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
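One detail in these tests is easy to miss: `urlopen` is used as a context manager, so the mock must route `__enter__` back to the fake response. A minimal, runnable illustration of just that wiring:

```python
import json
import urllib.request
from unittest.mock import MagicMock, patch

# urlopen is entered via "with", so the fake response hangs off __enter__.
with patch("urllib.request.urlopen") as mock_urlopen:
    fake = MagicMock()
    fake.read.return_value = b'{"ok": true}'
    mock_urlopen.return_value.__enter__.return_value = fake

    with urllib.request.urlopen("http://example") as resp:
        assert json.loads(resp.read()) == {"ok": True}
```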

View File

@@ -1,7 +1,6 @@
import unittest
from unittest.mock import patch, MagicMock
import tempfile
import os
from unittest.mock import MagicMock, patch
from pr.core.assistant import Assistant, process_message
@@ -12,83 +11,106 @@ class TestAssistant(unittest.TestCase):
self.args.verbose = False
self.args.debug = False
self.args.no_syntax = False
self.args.model = 'test-model'
self.args.api_url = 'test-url'
self.args.model_list_url = 'test-list-url'
self.args.model = "test-model"
self.args.api_url = "test-url"
self.args.model_list_url = "test-list-url"
@patch('sqlite3.connect')
@patch('os.environ.get')
@patch('pr.core.context.init_system_message')
@patch('pr.core.enhanced_assistant.EnhancedAssistant')
@patch("sqlite3.connect")
@patch("os.environ.get")
@patch("pr.core.context.init_system_message")
@patch("pr.core.enhanced_assistant.EnhancedAssistant")
def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite):
mock_env.side_effect = lambda key, default: {'OPENROUTER_API_KEY': 'key', 'AI_MODEL': 'model', 'API_URL': 'url', 'MODEL_LIST_URL': 'list', 'USE_TOOLS': '1', 'STRICT_MODE': '0'}.get(key, default)
mock_env.side_effect = lambda key, default: {
"OPENROUTER_API_KEY": "key",
"AI_MODEL": "model",
"API_URL": "url",
"MODEL_LIST_URL": "list",
"USE_TOOLS": "1",
"STRICT_MODE": "0",
}.get(key, default)
mock_conn = MagicMock()
mock_sqlite.return_value = mock_conn
mock_init_sys.return_value = {'role': 'system', 'content': 'sys'}
mock_init_sys.return_value = {"role": "system", "content": "sys"}
assistant = Assistant(self.args)
self.assertEqual(assistant.api_key, 'key')
self.assertEqual(assistant.model, 'test-model')
self.assertEqual(assistant.api_key, "key")
self.assertEqual(assistant.model, "test-model")
mock_sqlite.assert_called_once()
@patch('pr.core.assistant.call_api')
@patch('pr.core.assistant.render_markdown')
@patch("pr.core.assistant.call_api")
@patch("pr.core.assistant.render_markdown")
def test_process_response_no_tools(self, mock_render, mock_call):
assistant = MagicMock()
assistant.messages = MagicMock()
assistant.verbose = False
assistant.syntax_highlighting = True
mock_render.return_value = 'rendered'
mock_render.return_value = "rendered"
response = {'choices': [{'message': {'content': 'content'}}]}
response = {"choices": [{"message": {"content": "content"}}]}
result = Assistant.process_response(assistant, response)
self.assertEqual(result, 'rendered')
assistant.messages.append.assert_called_with({'content': 'content'})
self.assertEqual(result, "rendered")
assistant.messages.append.assert_called_with({"content": "content"})
@patch('pr.core.assistant.call_api')
@patch('pr.core.assistant.render_markdown')
@patch('pr.core.assistant.get_tools_definition')
@patch("pr.core.assistant.call_api")
@patch("pr.core.assistant.render_markdown")
@patch("pr.core.assistant.get_tools_definition")
def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call):
assistant = MagicMock()
assistant.messages = MagicMock()
assistant.verbose = False
assistant.syntax_highlighting = True
assistant.use_tools = True
assistant.model = 'model'
assistant.api_url = 'url'
assistant.api_key = 'key'
assistant.model = "model"
assistant.api_url = "url"
assistant.api_key = "key"
mock_tools_def.return_value = []
mock_call.return_value = {'choices': [{'message': {'content': 'follow'}}]}
mock_call.return_value = {"choices": [{"message": {"content": "follow"}}]}
response = {'choices': [{'message': {'tool_calls': [{'id': '1', 'function': {'name': 'test', 'arguments': '{}'}}]}}]}
response = {
"choices": [
{
"message": {
"tool_calls": [
{"id": "1", "function": {"name": "test", "arguments": "{}"}}
]
}
}
]
}
with patch.object(assistant, 'execute_tool_calls', return_value=[{'role': 'tool', 'content': 'result'}]):
result = Assistant.process_response(assistant, response)
with patch.object(
assistant,
"execute_tool_calls",
return_value=[{"role": "tool", "content": "result"}],
):
Assistant.process_response(assistant, response)
mock_call.assert_called()
@patch('pr.core.assistant.call_api')
@patch('pr.core.assistant.get_tools_definition')
@patch("pr.core.assistant.call_api")
@patch("pr.core.assistant.get_tools_definition")
def test_process_message(self, mock_tools, mock_call):
assistant = MagicMock()
assistant.messages = MagicMock()
assistant.verbose = False
assistant.use_tools = True
assistant.model = 'model'
assistant.api_url = 'url'
assistant.api_key = 'key'
assistant.model = "model"
assistant.api_url = "url"
assistant.api_key = "key"
mock_tools.return_value = []
mock_call.return_value = {'choices': [{'message': {'content': 'response'}}]}
mock_call.return_value = {"choices": [{"message": {"content": "response"}}]}
with patch('pr.core.assistant.render_markdown', return_value='rendered'):
with patch('builtins.print'):
process_message(assistant, 'test message')
with patch("pr.core.assistant.render_markdown", return_value="rendered"):
with patch("builtins.print"):
process_message(assistant, "test message")
assistant.messages.append.assert_called_with({'role': 'user', 'content': 'test message'})
assistant.messages.append.assert_called_with(
{"role": "user", "content": "test message"}
)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@@ -1,31 +1,30 @@
import pytest
from pr import config
class TestConfig:
def test_default_model_exists(self):
assert hasattr(config, 'DEFAULT_MODEL')
assert hasattr(config, "DEFAULT_MODEL")
assert isinstance(config.DEFAULT_MODEL, str)
assert len(config.DEFAULT_MODEL) > 0
def test_api_url_exists(self):
assert hasattr(config, 'DEFAULT_API_URL')
assert config.DEFAULT_API_URL.startswith('http')
assert hasattr(config, "DEFAULT_API_URL")
assert config.DEFAULT_API_URL.startswith("http")
def test_file_paths_exist(self):
assert hasattr(config, 'DB_PATH')
assert hasattr(config, 'LOG_FILE')
assert hasattr(config, 'HISTORY_FILE')
assert hasattr(config, "DB_PATH")
assert hasattr(config, "LOG_FILE")
assert hasattr(config, "HISTORY_FILE")
def test_autonomous_config(self):
assert hasattr(config, 'MAX_AUTONOMOUS_ITERATIONS')
assert hasattr(config, "MAX_AUTONOMOUS_ITERATIONS")
assert config.MAX_AUTONOMOUS_ITERATIONS > 0
assert hasattr(config, 'CONTEXT_COMPRESSION_THRESHOLD')
assert hasattr(config, "CONTEXT_COMPRESSION_THRESHOLD")
assert config.CONTEXT_COMPRESSION_THRESHOLD > 0
def test_language_keywords(self):
assert hasattr(config, 'LANGUAGE_KEYWORDS')
assert 'python' in config.LANGUAGE_KEYWORDS
assert isinstance(config.LANGUAGE_KEYWORDS['python'], list)
assert hasattr(config, "LANGUAGE_KEYWORDS")
assert "python" in config.LANGUAGE_KEYWORDS
assert isinstance(config.LANGUAGE_KEYWORDS["python"], list)

View File

@@ -1,56 +1,71 @@
import pytest
from unittest.mock import patch, mock_open
import os
from pr.core.config_loader import load_config, _load_config_file, _parse_value, create_default_config
from unittest.mock import mock_open, patch
from pr.core.config_loader import (
_load_config_file,
_parse_value,
create_default_config,
load_config,
)
def test_parse_value_string():
assert _parse_value('hello') == 'hello'
assert _parse_value("hello") == "hello"
def test_parse_value_int():
assert _parse_value('123') == 123
assert _parse_value("123") == 123
def test_parse_value_float():
assert _parse_value('1.23') == 1.23
assert _parse_value("1.23") == 1.23
def test_parse_value_bool_true():
assert _parse_value('true') == True
assert _parse_value("true") == True
def test_parse_value_bool_false():
assert _parse_value('false') == False
assert _parse_value("false") == False
def test_parse_value_bool_upper():
assert _parse_value('TRUE') == True
assert _parse_value("TRUE") == True
@patch('os.path.exists', return_value=False)
@patch("os.path.exists", return_value=False)
def test_load_config_file_not_exists(mock_exists):
config = _load_config_file('test.ini')
config = _load_config_file("test.ini")
assert config == {}
@patch('os.path.exists', return_value=True)
@patch('configparser.ConfigParser')
@patch("os.path.exists", return_value=True)
@patch("configparser.ConfigParser")
def test_load_config_file_exists(mock_parser_class, mock_exists):
mock_parser = mock_parser_class.return_value
mock_parser.sections.return_value = ['api']
mock_parser.items.return_value = [('key', 'value')]
config = _load_config_file('test.ini')
assert 'api' in config
assert config['api']['key'] == 'value'
mock_parser.sections.return_value = ["api"]
mock_parser.items.return_value = [("key", "value")]
config = _load_config_file("test.ini")
assert "api" in config
assert config["api"]["key"] == "value"
@patch('pr.core.config_loader._load_config_file')
@patch("pr.core.config_loader._load_config_file")
def test_load_config(mock_load):
mock_load.side_effect = [{'api': {'key': 'global'}}, {'api': {'key': 'local'}}]
mock_load.side_effect = [{"api": {"key": "global"}}, {"api": {"key": "local"}}]
config = load_config()
assert config['api']['key'] == 'local'
assert config["api"]["key"] == "local"
@patch('builtins.open', new_callable=mock_open)
@patch("builtins.open", new_callable=mock_open)
def test_create_default_config(mock_file):
result = create_default_config('test.ini')
result = create_default_config("test.ini")
assert result == True
mock_file.assert_called_once_with('test.ini', 'w')
mock_file.assert_called_once_with("test.ini", "w")
handle = mock_file()
handle.write.assert_called_once()
@patch('builtins.open', side_effect=Exception('error'))
@patch("builtins.open", side_effect=Exception("error"))
def test_create_default_config_error(mock_file):
result = create_default_config('test.ini')
assert result == False
result = create_default_config("test.ini")
assert result == False

View File

@@ -1,30 +1,29 @@
import pytest
from pr.core.context import should_compress_context, compress_context
from pr.config import RECENT_MESSAGES_TO_KEEP
from pr.core.context import compress_context, should_compress_context
class TestContextManagement:
def test_should_compress_context_below_threshold(self):
messages = [{'role': 'user', 'content': 'test'}] * 10
messages = [{"role": "user", "content": "test"}] * 10
assert should_compress_context(messages) is False
def test_should_compress_context_above_threshold(self):
messages = [{'role': 'user', 'content': 'test'}] * 35
messages = [{"role": "user", "content": "test"}] * 35
assert should_compress_context(messages) is True
def test_compress_context_preserves_system_message(self):
messages = [
{'role': 'system', 'content': 'System prompt'},
{'role': 'user', 'content': 'Hello'},
{'role': 'assistant', 'content': 'Hi'},
{"role": "system", "content": "System prompt"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
] * 40 # Ensure compression
compressed = compress_context(messages)
assert compressed[0]['role'] == 'system'
assert 'System prompt' in compressed[0]['content']
assert compressed[0]["role"] == "system"
assert "System prompt" in compressed[0]["content"]
def test_compress_context_keeps_recent_messages(self):
messages = [{'role': 'user', 'content': f'msg{i}'} for i in range(40)]
messages = [{"role": "user", "content": f"msg{i}"} for i in range(40)]
compressed = compress_context(messages)
# Should keep recent messages
recent = compressed[-RECENT_MESSAGES_TO_KEEP:]
@@ -32,4 +31,4 @@ class TestContextManagement:
# Check that the messages are the most recent ones
for i, msg in enumerate(recent):
expected_index = 40 - RECENT_MESSAGES_TO_KEEP + i
assert msg['content'] == f'msg{expected_index}'
assert msg["content"] == f"msg{expected_index}"

View File

@@ -1,89 +1,97 @@
import pytest
from unittest.mock import MagicMock
from pr.core.enhanced_assistant import EnhancedAssistant
def test_enhanced_assistant_init():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assert assistant.base == mock_base
assert assistant.current_conversation_id is not None
def test_enhanced_call_api_with_cache():
mock_base = MagicMock()
mock_base.model = 'test-model'
mock_base.api_url = 'http://test'
mock_base.api_key = 'key'
mock_base.model = "test-model"
mock_base.api_url = "http://test"
mock_base.api_key = "key"
mock_base.use_tools = False
mock_base.verbose = False
assistant = EnhancedAssistant(mock_base)
assistant.api_cache = MagicMock()
assistant.api_cache.get.return_value = {'cached': True}
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
assert result == {'cached': True}
assistant.api_cache.get.return_value = {"cached": True}
result = assistant.enhanced_call_api([{"role": "user", "content": "test"}])
assert result == {"cached": True}
assistant.api_cache.get.assert_called_once()
def test_enhanced_call_api_without_cache():
mock_base = MagicMock()
mock_base.model = 'test-model'
mock_base.api_url = 'http://test'
mock_base.api_key = 'key'
mock_base.model = "test-model"
mock_base.api_url = "http://test"
mock_base.api_key = "key"
mock_base.use_tools = False
mock_base.verbose = False
assistant = EnhancedAssistant(mock_base)
assistant.api_cache = None
# It will try to call API and fail with network error, but that's expected
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
assert 'error' in result
result = assistant.enhanced_call_api([{"role": "user", "content": "test"}])
assert "error" in result
def test_execute_workflow_not_found():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.workflow_storage = MagicMock()
assistant.workflow_storage.load_workflow_by_name.return_value = None
result = assistant.execute_workflow('nonexistent')
assert 'error' in result
result = assistant.execute_workflow("nonexistent")
assert "error" in result
def test_create_agent():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.agent_manager = MagicMock()
assistant.agent_manager.create_agent.return_value = 'agent_id'
result = assistant.create_agent('role')
assert result == 'agent_id'
assistant.agent_manager.create_agent.return_value = "agent_id"
result = assistant.create_agent("role")
assert result == "agent_id"
def test_search_knowledge():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.knowledge_store = MagicMock()
assistant.knowledge_store.search_entries.return_value = [{'result': True}]
result = assistant.search_knowledge('query')
assert result == [{'result': True}]
assistant.knowledge_store.search_entries.return_value = [{"result": True}]
result = assistant.search_knowledge("query")
assert result == [{"result": True}]
def test_get_cache_statistics():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.api_cache = MagicMock()
assistant.api_cache.get_statistics.return_value = {'hits': 10}
assistant.api_cache.get_statistics.return_value = {"hits": 10}
assistant.tool_cache = MagicMock()
assistant.tool_cache.get_statistics.return_value = {'misses': 5}
assistant.tool_cache.get_statistics.return_value = {"misses": 5}
stats = assistant.get_cache_statistics()
assert 'api_cache' in stats
assert 'tool_cache' in stats
assert "api_cache" in stats
assert "tool_cache" in stats
def test_clear_caches():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.api_cache = MagicMock()
assistant.tool_cache = MagicMock()
assistant.clear_caches()
assistant.api_cache.clear_all.assert_called_once()
assistant.tool_cache.clear_all.assert_called_once()

View File

@@ -1,24 +1,25 @@
import pytest
import tempfile
import os
from pr.core.logging import setup_logging, get_logger
from pr.core.logging import get_logger, setup_logging
def test_setup_logging_basic():
logger = setup_logging(verbose=False)
assert logger.name == 'pr'
assert logger.name == "pr"
assert logger.level == 20 # INFO
def test_setup_logging_verbose():
logger = setup_logging(verbose=True)
assert logger.name == 'pr'
assert logger.name == "pr"
assert logger.level == 10 # DEBUG
# Should have console handler
assert len(logger.handlers) >= 2
def test_get_logger_default():
logger = get_logger()
assert logger.name == 'pr'
assert logger.name == "pr"
def test_get_logger_named():
logger = get_logger('test')
assert logger.name == 'pr.test'
logger = get_logger("test")
assert logger.name == "pr.test"

View File

@@ -1,118 +1,134 @@
import pytest
from unittest.mock import patch
import sys
import pytest
from pr.__main__ import main
def test_main_version(capsys):
with patch('sys.argv', ['pr', '--version']):
with patch("sys.argv", ["pr", "--version"]):
with pytest.raises(SystemExit):
main()
captured = capsys.readouterr()
assert 'PR Assistant' in captured.out
assert "PR Assistant" in captured.out
def test_main_create_config_success(capsys):
with patch('pr.core.config_loader.create_default_config', return_value=True):
with patch('sys.argv', ['pr', '--create-config']):
with patch("pr.core.config_loader.create_default_config", return_value=True):
with patch("sys.argv", ["pr", "--create-config"]):
main()
captured = capsys.readouterr()
assert 'Configuration file created' in captured.out
assert "Configuration file created" in captured.out
def test_main_create_config_fail(capsys):
with patch('pr.core.config_loader.create_default_config', return_value=False):
with patch('sys.argv', ['pr', '--create-config']):
with patch("pr.core.config_loader.create_default_config", return_value=False):
with patch("sys.argv", ["pr", "--create-config"]):
main()
captured = capsys.readouterr()
assert 'Error creating configuration file' in captured.err
assert "Error creating configuration file" in captured.err
def test_main_list_sessions_no_sessions(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.list_sessions.return_value = []
with patch('sys.argv', ['pr', '--list-sessions']):
with patch("sys.argv", ["pr", "--list-sessions"]):
main()
captured = capsys.readouterr()
assert 'No saved sessions found' in captured.out
assert "No saved sessions found" in captured.out
def test_main_list_sessions_with_sessions(capsys):
sessions = [{'name': 'test', 'created_at': '2023-01-01', 'message_count': 5}]
with patch('pr.core.session.SessionManager') as mock_sm:
sessions = [{"name": "test", "created_at": "2023-01-01", "message_count": 5}]
with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.list_sessions.return_value = sessions
with patch('sys.argv', ['pr', '--list-sessions']):
with patch("sys.argv", ["pr", "--list-sessions"]):
main()
captured = capsys.readouterr()
assert 'Found 1 saved sessions' in captured.out
assert 'test' in captured.out
assert "Found 1 saved sessions" in captured.out
assert "test" in captured.out
def test_main_delete_session_success(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.delete_session.return_value = True
with patch('sys.argv', ['pr', '--delete-session', 'test']):
with patch("sys.argv", ["pr", "--delete-session", "test"]):
main()
captured = capsys.readouterr()
assert "Session 'test' deleted" in captured.out
def test_main_delete_session_fail(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.delete_session.return_value = False
with patch('sys.argv', ['pr', '--delete-session', 'test']):
with patch("sys.argv", ["pr", "--delete-session", "test"]):
main()
captured = capsys.readouterr()
assert "Error deleting session 'test'" in captured.err
def test_main_export_session_json(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.export_session.return_value = True
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.json']):
with patch("sys.argv", ["pr", "--export-session", "test", "output.json"]):
main()
captured = capsys.readouterr()
assert 'Session exported to output.json' in captured.out
assert "Session exported to output.json" in captured.out
def test_main_export_session_md(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
with patch("pr.core.session.SessionManager") as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.export_session.return_value = True
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.md']):
with patch("sys.argv", ["pr", "--export-session", "test", "output.md"]):
main()
captured = capsys.readouterr()
assert 'Session exported to output.md' in captured.out
assert "Session exported to output.md" in captured.out
def test_main_usage(capsys):
usage = {'total_requests': 10, 'total_tokens': 1000, 'total_cost': 0.01}
with patch('pr.core.usage_tracker.UsageTracker.get_total_usage', return_value=usage):
with patch('sys.argv', ['pr', '--usage']):
usage = {"total_requests": 10, "total_tokens": 1000, "total_cost": 0.01}
with patch(
"pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage
):
with patch("sys.argv", ["pr", "--usage"]):
main()
captured = capsys.readouterr()
assert 'Total Usage Statistics' in captured.out
assert 'Requests: 10' in captured.out
assert "Total Usage Statistics" in captured.out
assert "Requests: 10" in captured.out
def test_main_plugins_no_plugins(capsys):
with patch('pr.plugins.loader.PluginLoader') as mock_loader:
with patch("pr.plugins.loader.PluginLoader") as mock_loader:
mock_instance = mock_loader.return_value
mock_instance.load_plugins.return_value = None
mock_instance.list_loaded_plugins.return_value = []
with patch('sys.argv', ['pr', '--plugins']):
with patch("sys.argv", ["pr", "--plugins"]):
main()
captured = capsys.readouterr()
assert 'No plugins loaded' in captured.out
assert "No plugins loaded" in captured.out
def test_main_plugins_with_plugins(capsys):
with patch('pr.plugins.loader.PluginLoader') as mock_loader:
with patch("pr.plugins.loader.PluginLoader") as mock_loader:
mock_instance = mock_loader.return_value
mock_instance.load_plugins.return_value = None
mock_instance.list_loaded_plugins.return_value = ['plugin1', 'plugin2']
with patch('sys.argv', ['pr', '--plugins']):
mock_instance.list_loaded_plugins.return_value = ["plugin1", "plugin2"]
with patch("sys.argv", ["pr", "--plugins"]):
main()
captured = capsys.readouterr()
assert 'Loaded 2 plugins' in captured.out
assert "Loaded 2 plugins" in captured.out
def test_main_run_assistant():
with patch('pr.__main__.Assistant') as mock_assistant:
with patch("pr.__main__.Assistant") as mock_assistant:
mock_instance = mock_assistant.return_value
with patch('sys.argv', ['pr', 'test message']):
with patch("sys.argv", ["pr", "test message"]):
main()
mock_assistant.assert_called_once()
mock_instance.run.assert_called_once()
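Taken together, these `__main__` tests pin down a dispatch order: maintenance flags such as `--list-sessions`, `--usage`, and `--plugins` print and return before an `Assistant` is ever constructed, which is why each test can patch a single collaborator in isolation. A self-contained sketch of that shape (hypothetical names, not pr's actual internals):

```python
import argparse


def main(argv=None):
    # Hypothetical dispatcher illustrating the flag-before-Assistant order
    # the tests above rely on; none of this is pr's real implementation.
    parser = argparse.ArgumentParser(prog="pr")
    parser.add_argument("message", nargs="?")
    parser.add_argument("--list-sessions", action="store_true")
    parser.add_argument("--usage", action="store_true")
    args = parser.parse_args(argv)

    if args.list_sessions:
        print("Found 0 saved sessions")  # maintenance flags short-circuit here
        return 0
    if args.usage:
        print("Total Usage Statistics")
        return 0

    print(f"would construct Assistant and run: {args.message!r}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```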

View File

@@ -1,116 +1,131 @@
import pytest
import tempfile
import os
import json
import os
import pytest
from pr.core.session import SessionManager
@pytest.fixture
def temp_sessions_dir(tmp_path, monkeypatch):
from pr.core import session
original_dir = session.SESSIONS_DIR
monkeypatch.setattr(session, 'SESSIONS_DIR', str(tmp_path))
monkeypatch.setattr(session, "SESSIONS_DIR", str(tmp_path))
# Clean any existing files
import shutil
if os.path.exists(str(tmp_path)):
shutil.rmtree(str(tmp_path))
os.makedirs(str(tmp_path), exist_ok=True)
yield tmp_path
monkeypatch.setattr(session, 'SESSIONS_DIR', original_dir)
monkeypatch.setattr(session, "SESSIONS_DIR", original_dir)
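One detail worth noting in this fixture: `monkeypatch.setattr` records the original value and restores it automatically at teardown, so the explicit restore after the `yield` is belt-and-braces rather than required. The same swap-a-module-constant pattern works for any module, sketched generically below (`mymodule` is hypothetical):

```python
import pytest

import mymodule  # hypothetical module defining DATA_DIR = "/var/lib/app"


@pytest.fixture
def temp_data_dir(tmp_path, monkeypatch):
    # Point the module-level constant at a per-test directory;
    # monkeypatch undoes this automatically when the test ends.
    monkeypatch.setattr(mymodule, "DATA_DIR", str(tmp_path))
    yield tmp_path
```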
def test_session_manager_init(temp_sessions_dir):
manager = SessionManager()
SessionManager()
assert os.path.exists(temp_sessions_dir)
def test_save_and_load_session(temp_sessions_dir):
manager = SessionManager()
name = "test_session"
messages = [{"role": "user", "content": "Hello"}]
metadata = {"test": True}
assert manager.save_session(name, messages, metadata)
loaded = manager.load_session(name)
assert loaded is not None
assert loaded['name'] == name
assert loaded['messages'] == messages
assert loaded['metadata'] == metadata
assert loaded["name"] == name
assert loaded["messages"] == messages
assert loaded["metadata"] == metadata
def test_load_nonexistent_session(temp_sessions_dir):
manager = SessionManager()
loaded = manager.load_session("nonexistent")
assert loaded is None
def test_list_sessions(temp_sessions_dir):
manager = SessionManager()
    # Save two sessions
manager.save_session("session1", [{"role": "user", "content": "Hi"}])
manager.save_session("session2", [{"role": "user", "content": "Hello"}])
sessions = manager.list_sessions()
assert len(sessions) == 2
assert sessions[0]['name'] == "session2" # sorted by created_at desc
assert sessions[0]["name"] == "session2" # sorted by created_at desc
def test_delete_session(temp_sessions_dir):
manager = SessionManager()
name = "to_delete"
manager.save_session(name, [{"role": "user", "content": "Test"}])
assert manager.delete_session(name)
assert manager.load_session(name) is None
def test_delete_nonexistent_session(temp_sessions_dir):
manager = SessionManager()
assert not manager.delete_session("nonexistent")
def test_export_session_json(temp_sessions_dir, tmp_path):
manager = SessionManager()
name = "export_test"
messages = [{"role": "user", "content": "Export me"}]
manager.save_session(name, messages)
output_path = tmp_path / "exported.json"
assert manager.export_session(name, str(output_path), 'json')
assert manager.export_session(name, str(output_path), "json")
assert output_path.exists()
with open(output_path) as f:
data = json.load(f)
assert data['name'] == name
assert data["name"] == name
def test_export_session_markdown(temp_sessions_dir, tmp_path):
manager = SessionManager()
name = "export_md"
messages = [{"role": "user", "content": "Markdown export"}]
manager.save_session(name, messages)
output_path = tmp_path / "exported.md"
assert manager.export_session(name, str(output_path), 'markdown')
assert manager.export_session(name, str(output_path), "markdown")
assert output_path.exists()
content = output_path.read_text()
assert "# Session: export_md" in content
def test_export_session_txt(temp_sessions_dir, tmp_path):
manager = SessionManager()
name = "export_txt"
messages = [{"role": "user", "content": "Text export"}]
manager.save_session(name, messages)
output_path = tmp_path / "exported.txt"
assert manager.export_session(name, str(output_path), 'txt')
assert manager.export_session(name, str(output_path), "txt")
assert output_path.exists()
content = output_path.read_text()
assert "Session: export_txt" in content
def test_export_nonexistent_session(temp_sessions_dir, tmp_path):
manager = SessionManager()
output_path = tmp_path / "nonexistent.json"
assert not manager.export_session("nonexistent", str(output_path), 'json')
assert not manager.export_session("nonexistent", str(output_path), "json")
def test_export_unsupported_format(temp_sessions_dir, tmp_path):
manager = SessionManager()
name = "test"
manager.save_session(name, [{"role": "user", "content": "Test"}])
output_path = tmp_path / "test.unsupported"
assert not manager.export_session(name, str(output_path), 'unsupported')
assert not manager.export_session(name, str(output_path), "unsupported")
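Read end to end, the tests above pin down the whole `SessionManager` surface. A minimal round trip, assuming pr is importable (the session name and output path are illustrative):

```python
from pr.core.session import SessionManager

manager = SessionManager()
manager.save_session("demo", [{"role": "user", "content": "Hello"}], {"tag": "example"})

loaded = manager.load_session("demo")  # dict with "name", "messages", "metadata"; None if missing
newest_first = [s["name"] for s in manager.list_sessions()]  # sorted by created_at, descending
manager.export_session("demo", "/tmp/demo.md", "markdown")  # "json", "markdown", or "txt"
manager.delete_session("demo")  # True on success, False for an unknown name
```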

View File

@@ -1,69 +1,68 @@
import pytest
import os
import tempfile
from pr.tools.filesystem import read_file, write_file, list_directory, search_replace
from pr.tools.patch import apply_patch, create_diff
from pr.tools.base import get_tools_definition
from pr.tools.filesystem import list_directory, read_file, search_replace, write_file
from pr.tools.patch import apply_patch, create_diff
class TestFilesystemTools:
def test_write_and_read_file(self, temp_dir):
filepath = os.path.join(temp_dir, 'test.txt')
content = 'Hello, World!'
filepath = os.path.join(temp_dir, "test.txt")
content = "Hello, World!"
write_result = write_file(filepath, content)
assert write_result['status'] == 'success'
assert write_result["status"] == "success"
read_result = read_file(filepath)
assert read_result['status'] == 'success'
assert content in read_result['content']
assert read_result["status"] == "success"
assert content in read_result["content"]
def test_read_nonexistent_file(self):
result = read_file('/nonexistent/path/file.txt')
assert result['status'] == 'error'
result = read_file("/nonexistent/path/file.txt")
assert result["status"] == "error"
def test_list_directory(self, temp_dir):
test_file = os.path.join(temp_dir, 'testfile.txt')
with open(test_file, 'w') as f:
f.write('test')
test_file = os.path.join(temp_dir, "testfile.txt")
with open(test_file, "w") as f:
f.write("test")
result = list_directory(temp_dir)
assert result['status'] == 'success'
assert any(item['name'] == 'testfile.txt' for item in result['items'])
assert result["status"] == "success"
assert any(item["name"] == "testfile.txt" for item in result["items"])
def test_search_replace(self, temp_dir):
filepath = os.path.join(temp_dir, 'test.txt')
content = 'Hello, World!'
with open(filepath, 'w') as f:
filepath = os.path.join(temp_dir, "test.txt")
content = "Hello, World!"
with open(filepath, "w") as f:
f.write(content)
result = search_replace(filepath, 'World', 'Universe')
assert result['status'] == 'success'
result = search_replace(filepath, "World", "Universe")
assert result["status"] == "success"
read_result = read_file(filepath)
assert 'Hello, Universe!' in read_result['content']
assert "Hello, Universe!" in read_result["content"]
class TestPatchTools:
def test_create_diff(self, temp_dir):
file1 = os.path.join(temp_dir, 'file1.txt')
file2 = os.path.join(temp_dir, 'file2.txt')
with open(file1, 'w') as f:
f.write('line1\nline2\nline3\n')
with open(file2, 'w') as f:
f.write('line1\nline2 modified\nline3\n')
file1 = os.path.join(temp_dir, "file1.txt")
file2 = os.path.join(temp_dir, "file2.txt")
with open(file1, "w") as f:
f.write("line1\nline2\nline3\n")
with open(file2, "w") as f:
f.write("line1\nline2 modified\nline3\n")
result = create_diff(file1, file2)
assert result['status'] == 'success'
assert 'line2' in result['diff']
assert 'line2 modified' in result['diff']
assert result["status"] == "success"
assert "line2" in result["diff"]
assert "line2 modified" in result["diff"]
def test_apply_patch(self, temp_dir):
filepath = os.path.join(temp_dir, 'file.txt')
with open(filepath, 'w') as f:
f.write('line1\nline2\nline3\n')
filepath = os.path.join(temp_dir, "file.txt")
with open(filepath, "w") as f:
f.write("line1\nline2\nline3\n")
# Create a simple patch
patch_content = """--- a/file.txt
@@ -75,10 +74,10 @@ class TestPatchTools:
line3
"""
result = apply_patch(filepath, patch_content)
assert result['status'] == 'success'
assert result["status"] == "success"
read_result = read_file(filepath)
assert 'line2 modified' in read_result['content']
assert "line2 modified" in read_result["content"]
class TestToolDefinitions:
@@ -92,27 +91,27 @@ class TestToolDefinitions:
tools = get_tools_definition()
for tool in tools:
assert 'type' in tool
assert tool['type'] == 'function'
assert 'function' in tool
assert "type" in tool
assert tool["type"] == "function"
assert "function" in tool
func = tool['function']
assert 'name' in func
assert 'description' in func
assert 'parameters' in func
func = tool["function"]
assert "name" in func
assert "description" in func
assert "parameters" in func
def test_filesystem_tools_present(self):
tools = get_tools_definition()
tool_names = [t['function']['name'] for t in tools]
tool_names = [t["function"]["name"] for t in tools]
assert 'read_file' in tool_names
assert 'write_file' in tool_names
assert 'list_directory' in tool_names
assert 'search_replace' in tool_names
assert "read_file" in tool_names
assert "write_file" in tool_names
assert "list_directory" in tool_names
assert "search_replace" in tool_names
def test_patch_tools_present(self):
tools = get_tools_definition()
tool_names = [t['function']['name'] for t in tools]
tool_names = [t["function"]["name"] for t in tools]
assert 'apply_patch' in tool_names
assert 'create_diff' in tool_names
assert "apply_patch" in tool_names
assert "create_diff" in tool_names

View File

@@ -1,86 +1,110 @@
import pytest
import tempfile
import os
import json
import os
import pytest
from pr.core.usage_tracker import UsageTracker
@pytest.fixture
def temp_usage_file(tmp_path, monkeypatch):
from pr.core import usage_tracker
original_file = usage_tracker.USAGE_DB_FILE
temp_file = str(tmp_path / "usage.json")
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', temp_file)
monkeypatch.setattr(usage_tracker, "USAGE_DB_FILE", temp_file)
yield temp_file
if os.path.exists(temp_file):
os.remove(temp_file)
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', original_file)
monkeypatch.setattr(usage_tracker, "USAGE_DB_FILE", original_file)
def test_usage_tracker_init():
tracker = UsageTracker()
summary = tracker.get_session_summary()
assert summary['requests'] == 0
assert summary['total_tokens'] == 0
assert summary['estimated_cost'] == 0.0
assert summary["requests"] == 0
assert summary["total_tokens"] == 0
assert summary["estimated_cost"] == 0.0
def test_track_request_known_model():
tracker = UsageTracker()
tracker.track_request('gpt-3.5-turbo', 100, 50)
tracker.track_request("gpt-3.5-turbo", 100, 50)
summary = tracker.get_session_summary()
assert summary['requests'] == 1
assert summary['input_tokens'] == 100
assert summary['output_tokens'] == 50
assert summary['total_tokens'] == 150
assert 'gpt-3.5-turbo' in summary['models_used']
assert summary["requests"] == 1
assert summary["input_tokens"] == 100
assert summary["output_tokens"] == 50
assert summary["total_tokens"] == 150
assert "gpt-3.5-turbo" in summary["models_used"]
# Cost: (100/1000)*0.0005 + (50/1000)*0.0015 = 0.00005 + 0.000075 = 0.000125
assert abs(summary['estimated_cost'] - 0.000125) < 1e-6
assert abs(summary["estimated_cost"] - 0.000125) < 1e-6
def test_track_request_unknown_model():
tracker = UsageTracker()
tracker.track_request('unknown-model', 100, 50)
tracker.track_request("unknown-model", 100, 50)
summary = tracker.get_session_summary()
assert summary['requests'] == 1
assert summary['estimated_cost'] == 0.0 # Unknown model, cost 0
assert summary["requests"] == 1
assert summary["estimated_cost"] == 0.0 # Unknown model, cost 0
def test_track_request_multiple():
tracker = UsageTracker()
tracker.track_request('gpt-3.5-turbo', 100, 50)
tracker.track_request('gpt-4', 200, 100)
tracker.track_request("gpt-3.5-turbo", 100, 50)
tracker.track_request("gpt-4", 200, 100)
summary = tracker.get_session_summary()
assert summary['requests'] == 2
assert summary['input_tokens'] == 300
assert summary['output_tokens'] == 150
assert summary['total_tokens'] == 450
assert len(summary['models_used']) == 2
assert summary["requests"] == 2
assert summary["input_tokens"] == 300
assert summary["output_tokens"] == 150
assert summary["total_tokens"] == 450
assert len(summary["models_used"]) == 2
def test_get_formatted_summary():
tracker = UsageTracker()
tracker.track_request('gpt-3.5-turbo', 100, 50)
tracker.track_request("gpt-3.5-turbo", 100, 50)
formatted = tracker.get_formatted_summary()
assert "Total Requests: 1" in formatted
assert "Total Tokens: 150" in formatted
assert "Estimated Cost: $0.0001" in formatted
assert "gpt-3.5-turbo" in formatted
def test_get_total_usage_no_file(temp_usage_file):
total = UsageTracker.get_total_usage()
assert total['total_requests'] == 0
assert total['total_tokens'] == 0
assert total['total_cost'] == 0.0
assert total["total_requests"] == 0
assert total["total_tokens"] == 0
assert total["total_cost"] == 0.0
def test_get_total_usage_with_data(temp_usage_file):
# Manually create history file
history = [
{'timestamp': '2023-01-01', 'model': 'gpt-3.5-turbo', 'input_tokens': 100, 'output_tokens': 50, 'total_tokens': 150, 'cost': 0.000125},
{'timestamp': '2023-01-02', 'model': 'gpt-4', 'input_tokens': 200, 'output_tokens': 100, 'total_tokens': 300, 'cost': 0.008}
{
"timestamp": "2023-01-01",
"model": "gpt-3.5-turbo",
"input_tokens": 100,
"output_tokens": 50,
"total_tokens": 150,
"cost": 0.000125,
},
{
"timestamp": "2023-01-02",
"model": "gpt-4",
"input_tokens": 200,
"output_tokens": 100,
"total_tokens": 300,
"cost": 0.008,
},
]
with open(temp_usage_file, 'w') as f:
with open(temp_usage_file, "w") as f:
json.dump(history, f)
total = UsageTracker.get_total_usage()
assert total['total_requests'] == 2
assert total['total_tokens'] == 450
assert abs(total['total_cost'] - 0.008125) < 1e-6
assert total["total_requests"] == 2
assert total["total_tokens"] == 450
assert abs(total["total_cost"] - 0.008125) < 1e-6
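The cost expectations here come from per-1000-token rates: the inline comment above works out a gpt-3.5-turbo request at $0.0005 per 1K input tokens and $0.0015 per 1K output tokens. The arithmetic, as a standalone check:

```python
def estimated_cost(input_tokens, output_tokens, rate_in, rate_out):
    # Per-1000-token pricing, matching the worked comment in the test above.
    return (input_tokens / 1000) * rate_in + (output_tokens / 1000) * rate_out


# 100 input + 50 output tokens on gpt-3.5-turbo -> $0.000125
assert abs(estimated_cost(100, 50, 0.0005, 0.0015) - 0.000125) < 1e-9
```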

View File

@@ -1,49 +1,59 @@
import pytest
import tempfile
import os
import tempfile
import pytest
from pr.core.exceptions import ValidationError
from pr.core.validation import (
validate_file_path,
validate_directory_path,
validate_model_name,
validate_api_url,
validate_directory_path,
validate_file_path,
validate_max_tokens,
validate_model_name,
validate_session_name,
validate_temperature,
validate_max_tokens,
)
from pr.core.exceptions import ValidationError
def test_validate_file_path_empty():
with pytest.raises(ValidationError, match="File path cannot be empty"):
validate_file_path("")
def test_validate_file_path_not_exist():
with pytest.raises(ValidationError, match="File does not exist"):
validate_file_path("/nonexistent/file.txt", must_exist=True)
def test_validate_file_path_is_dir():
with tempfile.TemporaryDirectory() as tmpdir:
with pytest.raises(ValidationError, match="Path is a directory"):
validate_file_path(tmpdir, must_exist=True)
def test_validate_file_path_valid():
with tempfile.NamedTemporaryFile() as tmpfile:
result = validate_file_path(tmpfile.name, must_exist=True)
assert os.path.isabs(result)
assert result == os.path.abspath(tmpfile.name)
def test_validate_directory_path_empty():
with pytest.raises(ValidationError, match="Directory path cannot be empty"):
validate_directory_path("")
def test_validate_directory_path_not_exist():
with pytest.raises(ValidationError, match="Directory does not exist"):
validate_directory_path("/nonexistent/dir", must_exist=True)
def test_validate_directory_path_not_dir():
with tempfile.NamedTemporaryFile() as tmpfile:
with pytest.raises(ValidationError, match="Path is not a directory"):
validate_directory_path(tmpfile.name, must_exist=True)
def test_validate_directory_path_create():
with tempfile.TemporaryDirectory() as tmpdir:
new_dir = os.path.join(tmpdir, "new_dir")
@@ -51,72 +61,89 @@ def test_validate_directory_path_create():
assert os.path.isdir(new_dir)
assert result == os.path.abspath(new_dir)
def test_validate_directory_path_valid():
with tempfile.TemporaryDirectory() as tmpdir:
result = validate_directory_path(tmpdir, must_exist=True)
assert result == os.path.abspath(tmpdir)
def test_validate_model_name_empty():
with pytest.raises(ValidationError, match="Model name cannot be empty"):
validate_model_name("")
def test_validate_model_name_too_short():
with pytest.raises(ValidationError, match="Model name too short"):
validate_model_name("a")
def test_validate_model_name_valid():
result = validate_model_name("gpt-3.5-turbo")
assert result == "gpt-3.5-turbo"
def test_validate_api_url_empty():
with pytest.raises(ValidationError, match="API URL cannot be empty"):
validate_api_url("")
def test_validate_api_url_invalid():
with pytest.raises(ValidationError, match="API URL must start with"):
validate_api_url("invalid-url")
def test_validate_api_url_valid():
result = validate_api_url("https://api.example.com")
assert result == "https://api.example.com"
def test_validate_session_name_empty():
with pytest.raises(ValidationError, match="Session name cannot be empty"):
validate_session_name("")
def test_validate_session_name_invalid_char():
with pytest.raises(ValidationError, match="contains invalid character"):
validate_session_name("test/session")
def test_validate_session_name_too_long():
long_name = "a" * 256
with pytest.raises(ValidationError, match="Session name too long"):
validate_session_name(long_name)
def test_validate_session_name_valid():
result = validate_session_name("valid_session_123")
assert result == "valid_session_123"
def test_validate_temperature_too_low():
with pytest.raises(ValidationError, match="Temperature must be between"):
validate_temperature(-0.1)
def test_validate_temperature_too_high():
with pytest.raises(ValidationError, match="Temperature must be between"):
validate_temperature(2.1)
def test_validate_temperature_valid():
result = validate_temperature(0.7)
assert result == 0.7
def test_validate_max_tokens_too_low():
with pytest.raises(ValidationError, match="Max tokens must be at least 1"):
validate_max_tokens(0)
def test_validate_max_tokens_too_high():
with pytest.raises(ValidationError, match="Max tokens too high"):
validate_max_tokens(100001)
def test_validate_max_tokens_valid():
result = validate_max_tokens(1000)
assert result == 1000
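Taken as a group, the validators share one contract: each returns the (normalized) value on success and raises `ValidationError` with a specific message otherwise, which makes them natural guard clauses. A short sketch assuming only behavior shown in the tests:

```python
from pr.core.exceptions import ValidationError
from pr.core.validation import validate_session_name, validate_temperature


def configure(session_name, temperature):
    try:
        name = validate_session_name(session_name)  # returns the validated name
        temp = validate_temperature(temperature)  # accepts 0.0 through 2.0
    except ValidationError as exc:
        raise SystemExit(f"Invalid settings: {exc}")
    return name, temp


configure("valid_session_123", 0.7)
```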