Update.
Some checks failed
Tests / test (macos-latest, 3.10) (push) Waiting to run
Tests / test (macos-latest, 3.11) (push) Waiting to run
Tests / test (macos-latest, 3.12) (push) Waiting to run
Tests / test (macos-latest, 3.8) (push) Waiting to run
Tests / test (macos-latest, 3.9) (push) Waiting to run
Tests / test (ubuntu-latest, 3.8) (push) Waiting to run
Tests / test (ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (windows-latest, 3.10) (push) Waiting to run
Tests / test (windows-latest, 3.11) (push) Waiting to run
Tests / test (windows-latest, 3.12) (push) Waiting to run
Tests / test (windows-latest, 3.8) (push) Waiting to run
Tests / test (windows-latest, 3.9) (push) Waiting to run
Lint / lint (push) Failing after 39s
Tests / test (ubuntu-latest, 3.10) (push) Successful in 55s
Tests / test (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / test (ubuntu-latest, 3.12) (push) Has been cancelled
Some checks failed
Tests / test (macos-latest, 3.10) (push) Waiting to run
Tests / test (macos-latest, 3.11) (push) Waiting to run
Tests / test (macos-latest, 3.12) (push) Waiting to run
Tests / test (macos-latest, 3.8) (push) Waiting to run
Tests / test (macos-latest, 3.9) (push) Waiting to run
Tests / test (ubuntu-latest, 3.8) (push) Waiting to run
Tests / test (ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (windows-latest, 3.10) (push) Waiting to run
Tests / test (windows-latest, 3.11) (push) Waiting to run
Tests / test (windows-latest, 3.12) (push) Waiting to run
Tests / test (windows-latest, 3.8) (push) Waiting to run
Tests / test (windows-latest, 3.9) (push) Waiting to run
Lint / lint (push) Failing after 39s
Tests / test (ubuntu-latest, 3.10) (push) Successful in 55s
Tests / test (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / test (ubuntu-latest, 3.12) (push) Has been cancelled
This commit is contained in:
parent
5f04811dcc
commit
1a29ee4918
@ -159,6 +159,7 @@ def tool_function(args):
|
||||
"""Implementation"""
|
||||
pass
|
||||
|
||||
|
||||
def register_tools():
|
||||
"""Return list of tool definitions"""
|
||||
return [...]
|
||||
@ -177,8 +178,8 @@ def register_tools():
|
||||
```python
|
||||
def test_read_file_with_valid_path_returns_content(temp_dir):
|
||||
# Arrange
|
||||
filepath = os.path.join(temp_dir, 'test.txt')
|
||||
expected_content = 'Hello, World!'
|
||||
filepath = os.path.join(temp_dir, "test.txt")
|
||||
expected_content = "Hello, World!"
|
||||
write_file(filepath, expected_content)
|
||||
|
||||
# Act
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from pr.core import Assistant
|
||||
|
||||
__version__ = '1.0.0'
|
||||
__all__ = ['Assistant']
|
||||
__version__ = "1.0.0"
|
||||
__all__ = ["Assistant"]
|
||||
|
||||
116
pr/__main__.py
116
pr/__main__.py
@ -1,12 +1,14 @@
|
||||
import argparse
|
||||
import sys
|
||||
from pr.core import Assistant
|
||||
|
||||
from pr import __version__
|
||||
from pr.core import Assistant
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='PR Assistant - Professional CLI AI assistant with autonomous execution',
|
||||
epilog='''
|
||||
description="PR Assistant - Professional CLI AI assistant with autonomous execution",
|
||||
epilog="""
|
||||
Examples:
|
||||
pr "What is Python?" # Single query
|
||||
pr -i # Interactive mode
|
||||
@ -25,43 +27,79 @@ Commands in interactive mode:
|
||||
/usage - Show usage statistics
|
||||
/save <name> - Save current session
|
||||
exit, quit, q - Exit the program
|
||||
''',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
""",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument('message', nargs='?', help='Message to send to assistant')
|
||||
parser.add_argument('--version', action='version', version=f'PR Assistant {__version__}')
|
||||
parser.add_argument('-m', '--model', help='AI model to use')
|
||||
parser.add_argument('-u', '--api-url', help='API endpoint URL')
|
||||
parser.add_argument('--model-list-url', help='Model list endpoint URL')
|
||||
parser.add_argument('-i', '--interactive', action='store_true', help='Interactive mode')
|
||||
parser.add_argument('-v', '--verbose', action='store_true', help='Verbose output')
|
||||
parser.add_argument('--debug', action='store_true', help='Enable debug mode with detailed logging')
|
||||
parser.add_argument('--no-syntax', action='store_true', help='Disable syntax highlighting')
|
||||
parser.add_argument('--include-env', action='store_true', help='Include environment variables in context')
|
||||
parser.add_argument('-c', '--context', action='append', help='Additional context files')
|
||||
parser.add_argument('--api-mode', action='store_true', help='API mode for specialized interaction')
|
||||
parser.add_argument("message", nargs="?", help="Message to send to assistant")
|
||||
parser.add_argument(
|
||||
"--version", action="version", version=f"PR Assistant {__version__}"
|
||||
)
|
||||
parser.add_argument("-m", "--model", help="AI model to use")
|
||||
parser.add_argument("-u", "--api-url", help="API endpoint URL")
|
||||
parser.add_argument("--model-list-url", help="Model list endpoint URL")
|
||||
parser.add_argument(
|
||||
"-i", "--interactive", action="store_true", help="Interactive mode"
|
||||
)
|
||||
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="Enable debug mode with detailed logging"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-syntax", action="store_true", help="Disable syntax highlighting"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-env",
|
||||
action="store_true",
|
||||
help="Include environment variables in context",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--context", action="append", help="Additional context files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-mode", action="store_true", help="API mode for specialized interaction"
|
||||
)
|
||||
|
||||
parser.add_argument('--output', choices=['text', 'json', 'structured'],
|
||||
default='text', help='Output format')
|
||||
parser.add_argument('--quiet', action='store_true', help='Minimal output')
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
choices=["text", "json", "structured"],
|
||||
default="text",
|
||||
help="Output format",
|
||||
)
|
||||
parser.add_argument("--quiet", action="store_true", help="Minimal output")
|
||||
|
||||
parser.add_argument('--save-session', metavar='NAME', help='Save session with given name')
|
||||
parser.add_argument('--load-session', metavar='NAME', help='Load session with given name')
|
||||
parser.add_argument('--list-sessions', action='store_true', help='List all saved sessions')
|
||||
parser.add_argument('--delete-session', metavar='NAME', help='Delete a saved session')
|
||||
parser.add_argument('--export-session', nargs=2, metavar=('NAME', 'FILE'),
|
||||
help='Export session to file')
|
||||
parser.add_argument(
|
||||
"--save-session", metavar="NAME", help="Save session with given name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-session", metavar="NAME", help="Load session with given name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-sessions", action="store_true", help="List all saved sessions"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete-session", metavar="NAME", help="Delete a saved session"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-session",
|
||||
nargs=2,
|
||||
metavar=("NAME", "FILE"),
|
||||
help="Export session to file",
|
||||
)
|
||||
|
||||
parser.add_argument('--usage', action='store_true', help='Show token usage statistics')
|
||||
parser.add_argument('--create-config', action='store_true',
|
||||
help='Create default configuration file')
|
||||
parser.add_argument('--plugins', action='store_true', help='List loaded plugins')
|
||||
parser.add_argument(
|
||||
"--usage", action="store_true", help="Show token usage statistics"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--create-config", action="store_true", help="Create default configuration file"
|
||||
)
|
||||
parser.add_argument("--plugins", action="store_true", help="List loaded plugins")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.create_config:
|
||||
from pr.core.config_loader import create_default_config
|
||||
|
||||
if create_default_config():
|
||||
print("Configuration file created at ~/.prrc")
|
||||
else:
|
||||
@ -70,6 +108,7 @@ Commands in interactive mode:
|
||||
|
||||
if args.list_sessions:
|
||||
from pr.core.session import SessionManager
|
||||
|
||||
sm = SessionManager()
|
||||
sessions = sm.list_sessions()
|
||||
if not sessions:
|
||||
@ -85,6 +124,7 @@ Commands in interactive mode:
|
||||
|
||||
if args.delete_session:
|
||||
from pr.core.session import SessionManager
|
||||
|
||||
sm = SessionManager()
|
||||
if sm.delete_session(args.delete_session):
|
||||
print(f"Session '{args.delete_session}' deleted")
|
||||
@ -94,13 +134,14 @@ Commands in interactive mode:
|
||||
|
||||
if args.export_session:
|
||||
from pr.core.session import SessionManager
|
||||
|
||||
sm = SessionManager()
|
||||
name, output_file = args.export_session
|
||||
format_type = 'json'
|
||||
if output_file.endswith('.md'):
|
||||
format_type = 'markdown'
|
||||
elif output_file.endswith('.txt'):
|
||||
format_type = 'txt'
|
||||
format_type = "json"
|
||||
if output_file.endswith(".md"):
|
||||
format_type = "markdown"
|
||||
elif output_file.endswith(".txt"):
|
||||
format_type = "txt"
|
||||
|
||||
if sm.export_session(name, output_file, format_type):
|
||||
print(f"Session exported to {output_file}")
|
||||
@ -110,6 +151,7 @@ Commands in interactive mode:
|
||||
|
||||
if args.usage:
|
||||
from pr.core.usage_tracker import UsageTracker
|
||||
|
||||
usage = UsageTracker.get_total_usage()
|
||||
print(f"\nTotal Usage Statistics:")
|
||||
print(f" Requests: {usage['total_requests']}")
|
||||
@ -119,6 +161,7 @@ Commands in interactive mode:
|
||||
|
||||
if args.plugins:
|
||||
from pr.plugins.loader import PluginLoader
|
||||
|
||||
loader = PluginLoader()
|
||||
loader.load_plugins()
|
||||
plugins = loader.list_loaded_plugins()
|
||||
@ -133,5 +176,6 @@ Commands in interactive mode:
|
||||
assistant = Assistant(args)
|
||||
assistant.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -1,6 +1,13 @@
|
||||
from .agent_communication import AgentCommunicationBus, AgentMessage
|
||||
from .agent_manager import AgentInstance, AgentManager
|
||||
from .agent_roles import AgentRole, get_agent_role, list_agent_roles
|
||||
from .agent_manager import AgentManager, AgentInstance
|
||||
from .agent_communication import AgentMessage, AgentCommunicationBus
|
||||
|
||||
__all__ = ['AgentRole', 'get_agent_role', 'list_agent_roles', 'AgentManager', 'AgentInstance',
|
||||
'AgentMessage', 'AgentCommunicationBus']
|
||||
__all__ = [
|
||||
"AgentRole",
|
||||
"get_agent_role",
|
||||
"list_agent_roles",
|
||||
"AgentManager",
|
||||
"AgentInstance",
|
||||
"AgentMessage",
|
||||
"AgentCommunicationBus",
|
||||
]
|
||||
|
||||
@ -1,14 +1,16 @@
|
||||
import sqlite3
|
||||
import json
|
||||
from typing import List, Optional
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
REQUEST = "request"
|
||||
RESPONSE = "response"
|
||||
NOTIFICATION = "notification"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentMessage:
|
||||
message_id: str
|
||||
@ -21,27 +23,28 @@ class AgentMessage:
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'message_id': self.message_id,
|
||||
'from_agent': self.from_agent,
|
||||
'to_agent': self.to_agent,
|
||||
'message_type': self.message_type.value,
|
||||
'content': self.content,
|
||||
'metadata': self.metadata,
|
||||
'timestamp': self.timestamp
|
||||
"message_id": self.message_id,
|
||||
"from_agent": self.from_agent,
|
||||
"to_agent": self.to_agent,
|
||||
"message_type": self.message_type.value,
|
||||
"content": self.content,
|
||||
"metadata": self.metadata,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> 'AgentMessage':
|
||||
def from_dict(cls, data: dict) -> "AgentMessage":
|
||||
return cls(
|
||||
message_id=data['message_id'],
|
||||
from_agent=data['from_agent'],
|
||||
to_agent=data['to_agent'],
|
||||
message_type=MessageType(data['message_type']),
|
||||
content=data['content'],
|
||||
metadata=data['metadata'],
|
||||
timestamp=data['timestamp']
|
||||
message_id=data["message_id"],
|
||||
from_agent=data["from_agent"],
|
||||
to_agent=data["to_agent"],
|
||||
message_type=MessageType(data["message_type"]),
|
||||
content=data["content"],
|
||||
metadata=data["metadata"],
|
||||
timestamp=data["timestamp"],
|
||||
)
|
||||
|
||||
|
||||
class AgentCommunicationBus:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
@ -50,7 +53,8 @@ class AgentCommunicationBus:
|
||||
|
||||
def _create_tables(self):
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS agent_messages (
|
||||
message_id TEXT PRIMARY KEY,
|
||||
from_agent TEXT,
|
||||
@ -62,70 +66,88 @@ class AgentCommunicationBus:
|
||||
session_id TEXT,
|
||||
read INTEGER DEFAULT 0
|
||||
)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def send_message(self, message: AgentMessage, session_id: Optional[str] = None):
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO agent_messages
|
||||
(message_id, from_agent, to_agent, message_type, content, metadata, timestamp, session_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
message.message_id,
|
||||
message.from_agent,
|
||||
message.to_agent,
|
||||
message.message_type.value,
|
||||
message.content,
|
||||
json.dumps(message.metadata),
|
||||
message.timestamp,
|
||||
session_id
|
||||
))
|
||||
""",
|
||||
(
|
||||
message.message_id,
|
||||
message.from_agent,
|
||||
message.to_agent,
|
||||
message.message_type.value,
|
||||
message.content,
|
||||
json.dumps(message.metadata),
|
||||
message.timestamp,
|
||||
session_id,
|
||||
),
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
|
||||
def get_messages(
|
||||
self, agent_id: str, unread_only: bool = True
|
||||
) -> List[AgentMessage]:
|
||||
cursor = self.conn.cursor()
|
||||
if unread_only:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
|
||||
FROM agent_messages
|
||||
WHERE to_agent = ? AND read = 0
|
||||
ORDER BY timestamp ASC
|
||||
''', (agent_id,))
|
||||
""",
|
||||
(agent_id,),
|
||||
)
|
||||
else:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
|
||||
FROM agent_messages
|
||||
WHERE to_agent = ?
|
||||
ORDER BY timestamp ASC
|
||||
''', (agent_id,))
|
||||
""",
|
||||
(agent_id,),
|
||||
)
|
||||
|
||||
messages = []
|
||||
for row in cursor.fetchall():
|
||||
messages.append(AgentMessage(
|
||||
message_id=row[0],
|
||||
from_agent=row[1],
|
||||
to_agent=row[2],
|
||||
message_type=MessageType(row[3]),
|
||||
content=row[4],
|
||||
metadata=json.loads(row[5]) if row[5] else {},
|
||||
timestamp=row[6]
|
||||
))
|
||||
messages.append(
|
||||
AgentMessage(
|
||||
message_id=row[0],
|
||||
from_agent=row[1],
|
||||
to_agent=row[2],
|
||||
message_type=MessageType(row[3]),
|
||||
content=row[4],
|
||||
metadata=json.loads(row[5]) if row[5] else {},
|
||||
timestamp=row[6],
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
def mark_as_read(self, message_id: str):
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute('UPDATE agent_messages SET read = 1 WHERE message_id = ?', (message_id,))
|
||||
cursor.execute(
|
||||
"UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,)
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def clear_messages(self, session_id: Optional[str] = None):
|
||||
cursor = self.conn.cursor()
|
||||
if session_id:
|
||||
cursor.execute('DELETE FROM agent_messages WHERE session_id = ?', (session_id,))
|
||||
cursor.execute(
|
||||
"DELETE FROM agent_messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
else:
|
||||
cursor.execute('DELETE FROM agent_messages')
|
||||
cursor.execute("DELETE FROM agent_messages")
|
||||
self.conn.commit()
|
||||
|
||||
def close(self):
|
||||
@ -134,24 +156,31 @@ class AgentCommunicationBus:
|
||||
def receive_messages(self, agent_id: str) -> List[AgentMessage]:
|
||||
return self.get_messages(agent_id, unread_only=True)
|
||||
|
||||
def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]:
|
||||
def get_conversation_history(
|
||||
self, agent_a: str, agent_b: str
|
||||
) -> List[AgentMessage]:
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
|
||||
FROM agent_messages
|
||||
WHERE (from_agent = ? AND to_agent = ?) OR (from_agent = ? AND to_agent = ?)
|
||||
ORDER BY timestamp ASC
|
||||
''', (agent_a, agent_b, agent_b, agent_a))
|
||||
""",
|
||||
(agent_a, agent_b, agent_b, agent_a),
|
||||
)
|
||||
|
||||
messages = []
|
||||
for row in cursor.fetchall():
|
||||
messages.append(AgentMessage(
|
||||
message_id=row[0],
|
||||
from_agent=row[1],
|
||||
to_agent=row[2],
|
||||
message_type=MessageType(row[3]),
|
||||
content=row[4],
|
||||
metadata=json.loads(row[5]) if row[5] else {},
|
||||
timestamp=row[6]
|
||||
))
|
||||
return messages
|
||||
messages.append(
|
||||
AgentMessage(
|
||||
message_id=row[0],
|
||||
from_agent=row[1],
|
||||
to_agent=row[2],
|
||||
message_type=MessageType(row[3]),
|
||||
content=row[4],
|
||||
metadata=json.loads(row[5]) if row[5] else {},
|
||||
timestamp=row[6],
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
import time
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List, Any, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from .agent_roles import AgentRole, get_agent_role
|
||||
from .agent_communication import AgentMessage, AgentCommunicationBus, MessageType
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from ..memory.knowledge_store import KnowledgeStore
|
||||
from .agent_communication import AgentCommunicationBus, AgentMessage, MessageType
|
||||
from .agent_roles import AgentRole, get_agent_role
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInstance:
|
||||
@ -17,21 +19,20 @@ class AgentInstance:
|
||||
task_count: int = 0
|
||||
|
||||
def add_message(self, role: str, content: str):
|
||||
self.message_history.append({
|
||||
'role': role,
|
||||
'content': content,
|
||||
'timestamp': time.time()
|
||||
})
|
||||
self.message_history.append(
|
||||
{"role": role, "content": content, "timestamp": time.time()}
|
||||
)
|
||||
|
||||
def get_system_message(self) -> Dict[str, str]:
|
||||
return {'role': 'system', 'content': self.role.system_prompt}
|
||||
return {"role": "system", "content": self.role.system_prompt}
|
||||
|
||||
def get_messages_for_api(self) -> List[Dict[str, str]]:
|
||||
return [self.get_system_message()] + [
|
||||
{'role': msg['role'], 'content': msg['content']}
|
||||
{"role": msg["role"], "content": msg["content"]}
|
||||
for msg in self.message_history
|
||||
]
|
||||
|
||||
|
||||
class AgentManager:
|
||||
def __init__(self, db_path: str, api_caller: Callable):
|
||||
self.db_path = db_path
|
||||
@ -46,32 +47,31 @@ class AgentManager:
|
||||
agent_id = f"{role_name}_{str(uuid.uuid4())[:8]}"
|
||||
|
||||
role = get_agent_role(role_name)
|
||||
agent = AgentInstance(
|
||||
agent_id=agent_id,
|
||||
role=role
|
||||
)
|
||||
agent = AgentInstance(agent_id=agent_id, role=role)
|
||||
|
||||
self.active_agents[agent_id] = agent
|
||||
return agent_id
|
||||
|
||||
def get_agent(self, agent_id: str) -> Optional[AgentInstance]:
|
||||
return self.active_agents.get(agent_id)
|
||||
|
||||
|
||||
def remove_agent(self, agent_id: str) -> bool:
|
||||
if agent_id in self.active_agents:
|
||||
del self.active_agents[agent_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def execute_agent_task(self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
def execute_agent_task(
|
||||
self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
agent = self.get_agent(agent_id)
|
||||
if not agent:
|
||||
return {'error': f'Agent {agent_id} not found'}
|
||||
return {"error": f"Agent {agent_id} not found"}
|
||||
|
||||
if context:
|
||||
agent.context.update(context)
|
||||
|
||||
agent.add_message('user', task)
|
||||
agent.add_message("user", task)
|
||||
knowledge_matches = self.knowledge_store.search_entries(task, top_k=3)
|
||||
agent.task_count += 1
|
||||
|
||||
@ -81,35 +81,40 @@ class AgentManager:
|
||||
for i, entry in enumerate(knowledge_matches, 1):
|
||||
shortened_content = entry.content[:2000]
|
||||
knowledge_content += f"{i}. {shortened_content}\\n\\n"
|
||||
messages.insert(-1, {'role': 'user', 'content': knowledge_content})
|
||||
messages.insert(-1, {"role": "user", "content": knowledge_content})
|
||||
|
||||
try:
|
||||
response = self.api_caller(
|
||||
messages=messages,
|
||||
temperature=agent.role.temperature,
|
||||
max_tokens=agent.role.max_tokens
|
||||
max_tokens=agent.role.max_tokens,
|
||||
)
|
||||
|
||||
if response and 'choices' in response:
|
||||
assistant_message = response['choices'][0]['message']['content']
|
||||
agent.add_message('assistant', assistant_message)
|
||||
if response and "choices" in response:
|
||||
assistant_message = response["choices"][0]["message"]["content"]
|
||||
agent.add_message("assistant", assistant_message)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'agent_id': agent_id,
|
||||
'response': assistant_message,
|
||||
'role': agent.role.name,
|
||||
'task_count': agent.task_count
|
||||
"success": True,
|
||||
"agent_id": agent_id,
|
||||
"response": assistant_message,
|
||||
"role": agent.role.name,
|
||||
"task_count": agent.task_count,
|
||||
}
|
||||
else:
|
||||
return {'error': 'Invalid API response', 'agent_id': agent_id}
|
||||
return {"error": "Invalid API response", "agent_id": agent_id}
|
||||
|
||||
except Exception as e:
|
||||
return {'error': str(e), 'agent_id': agent_id}
|
||||
return {"error": str(e), "agent_id": agent_id}
|
||||
|
||||
def send_agent_message(self, from_agent_id: str, to_agent_id: str,
|
||||
content: str, message_type: MessageType = MessageType.REQUEST,
|
||||
metadata: Optional[Dict[str, Any]] = None):
|
||||
def send_agent_message(
|
||||
self,
|
||||
from_agent_id: str,
|
||||
to_agent_id: str,
|
||||
content: str,
|
||||
message_type: MessageType = MessageType.REQUEST,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
message = AgentMessage(
|
||||
from_agent=from_agent_id,
|
||||
to_agent=to_agent_id,
|
||||
@ -117,57 +122,57 @@ class AgentManager:
|
||||
content=content,
|
||||
metadata=metadata or {},
|
||||
timestamp=time.time(),
|
||||
message_id=str(uuid.uuid4())[:16]
|
||||
message_id=str(uuid.uuid4())[:16],
|
||||
)
|
||||
|
||||
self.communication_bus.send_message(message, self.session_id)
|
||||
return message.message_id
|
||||
|
||||
def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
|
||||
def get_agent_messages(
|
||||
self, agent_id: str, unread_only: bool = True
|
||||
) -> List[AgentMessage]:
|
||||
return self.communication_bus.get_messages(agent_id, unread_only)
|
||||
|
||||
def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]):
|
||||
def collaborate_agents(
|
||||
self, orchestrator_id: str, task: str, agent_roles: List[str]
|
||||
):
|
||||
orchestrator = self.get_agent(orchestrator_id)
|
||||
if not orchestrator:
|
||||
orchestrator_id = self.create_agent('orchestrator')
|
||||
orchestrator_id = self.create_agent("orchestrator")
|
||||
orchestrator = self.get_agent(orchestrator_id)
|
||||
|
||||
worker_agents = []
|
||||
for role in agent_roles:
|
||||
agent_id = self.create_agent(role)
|
||||
worker_agents.append({
|
||||
'agent_id': agent_id,
|
||||
'role': role
|
||||
})
|
||||
worker_agents.append({"agent_id": agent_id, "role": role})
|
||||
|
||||
orchestration_prompt = f'''Task: {task}
|
||||
orchestration_prompt = f"""Task: {task}
|
||||
|
||||
Available specialized agents:
|
||||
{chr(10).join([f"- {a['agent_id']} ({a['role']})" for a in worker_agents])}
|
||||
|
||||
Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results.'''
|
||||
Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results."""
|
||||
|
||||
orchestrator_result = self.execute_agent_task(orchestrator_id, orchestration_prompt)
|
||||
orchestrator_result = self.execute_agent_task(
|
||||
orchestrator_id, orchestration_prompt
|
||||
)
|
||||
|
||||
results = {
|
||||
'orchestrator': orchestrator_result,
|
||||
'agents': []
|
||||
}
|
||||
results = {"orchestrator": orchestrator_result, "agents": []}
|
||||
|
||||
for agent_info in worker_agents:
|
||||
agent_id = agent_info['agent_id']
|
||||
agent_id = agent_info["agent_id"]
|
||||
messages = self.get_agent_messages(agent_id)
|
||||
|
||||
for msg in messages:
|
||||
subtask = msg.content
|
||||
result = self.execute_agent_task(agent_id, subtask)
|
||||
results['agents'].append(result)
|
||||
results["agents"].append(result)
|
||||
|
||||
self.send_agent_message(
|
||||
from_agent_id=agent_id,
|
||||
to_agent_id=orchestrator_id,
|
||||
content=result.get('response', ''),
|
||||
message_type=MessageType.RESPONSE
|
||||
content=result.get("response", ""),
|
||||
message_type=MessageType.RESPONSE,
|
||||
)
|
||||
self.communication_bus.mark_as_read(msg.message_id)
|
||||
|
||||
@ -175,21 +180,21 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei
|
||||
|
||||
def get_session_summary(self) -> str:
|
||||
summary = {
|
||||
'session_id': self.session_id,
|
||||
'active_agents': len(self.active_agents),
|
||||
'agents': [
|
||||
"session_id": self.session_id,
|
||||
"active_agents": len(self.active_agents),
|
||||
"agents": [
|
||||
{
|
||||
'agent_id': agent_id,
|
||||
'role': agent.role.name,
|
||||
'task_count': agent.task_count,
|
||||
'message_count': len(agent.message_history)
|
||||
"agent_id": agent_id,
|
||||
"role": agent.role.name,
|
||||
"task_count": agent.task_count,
|
||||
"message_count": len(agent.message_history),
|
||||
}
|
||||
for agent_id, agent in self.active_agents.items()
|
||||
]
|
||||
],
|
||||
}
|
||||
return json.dumps(summary)
|
||||
|
||||
def clear_session(self):
|
||||
self.active_agents.clear()
|
||||
self.communication_bus.clear_messages(session_id=self.session_id)
|
||||
self.session_id = str(uuid.uuid4())[:16]
|
||||
self.session_id = str(uuid.uuid4())[:16]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any, Set
|
||||
from typing import Dict, List, Set
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRole:
|
||||
@ -11,182 +12,262 @@ class AgentRole:
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
|
||||
|
||||
AGENT_ROLES = {
|
||||
'coding': AgentRole(
|
||||
name='coding',
|
||||
description='Specialized in writing, reviewing, and debugging code',
|
||||
system_prompt='''You are a coding specialist AI assistant. Your primary responsibilities:
|
||||
"coding": AgentRole(
|
||||
name="coding",
|
||||
description="Specialized in writing, reviewing, and debugging code",
|
||||
system_prompt="""You are a coding specialist AI assistant. Your primary responsibilities:
|
||||
- Write clean, efficient, well-structured code
|
||||
- Review code for bugs, security issues, and best practices
|
||||
- Refactor and optimize existing code
|
||||
- Implement features based on specifications
|
||||
- Follow language-specific conventions and patterns
|
||||
Focus on code quality, maintainability, and performance.''',
|
||||
Focus on code quality, maintainability, and performance.""",
|
||||
allowed_tools={
|
||||
'read_file', 'write_file', 'list_directory', 'create_directory',
|
||||
'change_directory', 'get_current_directory', 'python_exec',
|
||||
'run_command', 'index_directory'
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_directory",
|
||||
"create_directory",
|
||||
"change_directory",
|
||||
"get_current_directory",
|
||||
"python_exec",
|
||||
"run_command",
|
||||
"index_directory",
|
||||
},
|
||||
specialization_areas=['code_writing', 'code_review', 'debugging', 'refactoring'],
|
||||
temperature=0.3
|
||||
specialization_areas=[
|
||||
"code_writing",
|
||||
"code_review",
|
||||
"debugging",
|
||||
"refactoring",
|
||||
],
|
||||
temperature=0.3,
|
||||
),
|
||||
|
||||
'research': AgentRole(
|
||||
name='research',
|
||||
description='Specialized in information gathering and analysis',
|
||||
system_prompt='''You are a research specialist AI assistant. Your primary responsibilities:
|
||||
"research": AgentRole(
|
||||
name="research",
|
||||
description="Specialized in information gathering and analysis",
|
||||
system_prompt="""You are a research specialist AI assistant. Your primary responsibilities:
|
||||
- Search for and gather relevant information
|
||||
- Analyze data and documentation
|
||||
- Synthesize findings into clear summaries
|
||||
- Verify facts and cross-reference sources
|
||||
- Identify trends and patterns in information
|
||||
Focus on accuracy, thoroughness, and clear communication of findings.''',
|
||||
Focus on accuracy, thoroughness, and clear communication of findings.""",
|
||||
allowed_tools={
|
||||
'read_file', 'list_directory', 'index_directory',
|
||||
'http_fetch', 'web_search', 'web_search_news',
|
||||
'db_query', 'db_get'
|
||||
"read_file",
|
||||
"list_directory",
|
||||
"index_directory",
|
||||
"http_fetch",
|
||||
"web_search",
|
||||
"web_search_news",
|
||||
"db_query",
|
||||
"db_get",
|
||||
},
|
||||
specialization_areas=['information_gathering', 'analysis', 'documentation', 'fact_checking'],
|
||||
temperature=0.5
|
||||
specialization_areas=[
|
||||
"information_gathering",
|
||||
"analysis",
|
||||
"documentation",
|
||||
"fact_checking",
|
||||
],
|
||||
temperature=0.5,
|
||||
),
|
||||
|
||||
'data_analysis': AgentRole(
|
||||
name='data_analysis',
|
||||
description='Specialized in data processing and analysis',
|
||||
system_prompt='''You are a data analysis specialist AI assistant. Your primary responsibilities:
|
||||
"data_analysis": AgentRole(
|
||||
name="data_analysis",
|
||||
description="Specialized in data processing and analysis",
|
||||
system_prompt="""You are a data analysis specialist AI assistant. Your primary responsibilities:
|
||||
- Process and analyze structured and unstructured data
|
||||
- Perform statistical analysis and pattern recognition
|
||||
- Query databases and extract insights
|
||||
- Create data summaries and reports
|
||||
- Identify anomalies and trends
|
||||
Focus on accuracy, data integrity, and actionable insights.''',
|
||||
Focus on accuracy, data integrity, and actionable insights.""",
|
||||
allowed_tools={
|
||||
'db_query', 'db_get', 'db_set', 'read_file', 'write_file',
|
||||
'python_exec', 'run_command', 'list_directory'
|
||||
"db_query",
|
||||
"db_get",
|
||||
"db_set",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"python_exec",
|
||||
"run_command",
|
||||
"list_directory",
|
||||
},
|
||||
specialization_areas=['data_processing', 'statistical_analysis', 'database_operations'],
|
||||
temperature=0.3
|
||||
specialization_areas=[
|
||||
"data_processing",
|
||||
"statistical_analysis",
|
||||
"database_operations",
|
||||
],
|
||||
temperature=0.3,
|
||||
),
|
||||
|
||||
'planning': AgentRole(
|
||||
name='planning',
|
||||
description='Specialized in task planning and coordination',
|
||||
system_prompt='''You are a planning specialist AI assistant. Your primary responsibilities:
|
||||
"planning": AgentRole(
|
||||
name="planning",
|
||||
description="Specialized in task planning and coordination",
|
||||
system_prompt="""You are a planning specialist AI assistant. Your primary responsibilities:
|
||||
- Break down complex tasks into manageable steps
|
||||
- Create execution plans and workflows
|
||||
- Identify dependencies and prerequisites
|
||||
- Estimate effort and resource requirements
|
||||
- Coordinate between different components
|
||||
Focus on logical organization, completeness, and feasibility.''',
|
||||
Focus on logical organization, completeness, and feasibility.""",
|
||||
allowed_tools={
|
||||
'read_file', 'write_file', 'list_directory', 'index_directory',
|
||||
'db_set', 'db_get'
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_directory",
|
||||
"index_directory",
|
||||
"db_set",
|
||||
"db_get",
|
||||
},
|
||||
specialization_areas=['task_decomposition', 'workflow_design', 'coordination'],
|
||||
temperature=0.6
|
||||
specialization_areas=["task_decomposition", "workflow_design", "coordination"],
|
||||
temperature=0.6,
|
||||
),
|
||||
|
||||
'testing': AgentRole(
|
||||
name='testing',
|
||||
description='Specialized in testing and quality assurance',
|
||||
system_prompt='''You are a testing specialist AI assistant. Your primary responsibilities:
|
||||
"testing": AgentRole(
|
||||
name="testing",
|
||||
description="Specialized in testing and quality assurance",
|
||||
system_prompt="""You are a testing specialist AI assistant. Your primary responsibilities:
|
||||
- Design and execute test cases
|
||||
- Identify edge cases and potential failures
|
||||
- Verify functionality and correctness
|
||||
- Test error handling and edge conditions
|
||||
- Ensure code meets quality standards
|
||||
Focus on thoroughness, coverage, and issue identification.''',
|
||||
Focus on thoroughness, coverage, and issue identification.""",
|
||||
allowed_tools={
|
||||
'read_file', 'write_file', 'python_exec', 'run_command',
|
||||
'list_directory', 'db_query'
|
||||
"read_file",
|
||||
"write_file",
|
||||
"python_exec",
|
||||
"run_command",
|
||||
"list_directory",
|
||||
"db_query",
|
||||
},
|
||||
specialization_areas=['test_design', 'quality_assurance', 'validation'],
|
||||
temperature=0.4
|
||||
specialization_areas=["test_design", "quality_assurance", "validation"],
|
||||
temperature=0.4,
|
||||
),
|
||||
|
||||
'documentation': AgentRole(
|
||||
name='documentation',
|
||||
description='Specialized in creating and maintaining documentation',
|
||||
system_prompt='''You are a documentation specialist AI assistant. Your primary responsibilities:
|
||||
"documentation": AgentRole(
|
||||
name="documentation",
|
||||
description="Specialized in creating and maintaining documentation",
|
||||
system_prompt="""You are a documentation specialist AI assistant. Your primary responsibilities:
|
||||
- Write clear, comprehensive documentation
|
||||
- Create API references and user guides
|
||||
- Document code with comments and docstrings
|
||||
- Organize and structure information logically
|
||||
- Ensure documentation is up-to-date and accurate
|
||||
Focus on clarity, completeness, and user-friendliness.''',
|
||||
Focus on clarity, completeness, and user-friendliness.""",
|
||||
allowed_tools={
|
||||
'read_file', 'write_file', 'list_directory', 'index_directory',
|
||||
'http_fetch', 'web_search'
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_directory",
|
||||
"index_directory",
|
||||
"http_fetch",
|
||||
"web_search",
|
||||
},
|
||||
specialization_areas=['technical_writing', 'documentation_organization', 'user_guides'],
|
||||
temperature=0.6
|
||||
specialization_areas=[
|
||||
"technical_writing",
|
||||
"documentation_organization",
|
||||
"user_guides",
|
||||
],
|
||||
temperature=0.6,
|
||||
),
|
||||
|
||||
'orchestrator': AgentRole(
|
||||
name='orchestrator',
|
||||
description='Coordinates multiple agents and manages overall execution',
|
||||
system_prompt='''You are an orchestrator AI assistant. Your primary responsibilities:
|
||||
"orchestrator": AgentRole(
|
||||
name="orchestrator",
|
||||
description="Coordinates multiple agents and manages overall execution",
|
||||
system_prompt="""You are an orchestrator AI assistant. Your primary responsibilities:
|
||||
- Coordinate multiple specialized agents
|
||||
- Delegate tasks to appropriate agents
|
||||
- Integrate results from different agents
|
||||
- Manage overall workflow execution
|
||||
- Ensure task completion and quality
|
||||
Focus on effective delegation, integration, and overall success.''',
|
||||
Focus on effective delegation, integration, and overall success.""",
|
||||
allowed_tools={
|
||||
'read_file', 'write_file', 'list_directory', 'db_set', 'db_get', 'db_query'
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_directory",
|
||||
"db_set",
|
||||
"db_get",
|
||||
"db_query",
|
||||
},
|
||||
specialization_areas=['agent_coordination', 'task_delegation', 'result_integration'],
|
||||
temperature=0.5
|
||||
specialization_areas=[
|
||||
"agent_coordination",
|
||||
"task_delegation",
|
||||
"result_integration",
|
||||
],
|
||||
temperature=0.5,
|
||||
),
|
||||
|
||||
'general': AgentRole(
|
||||
name='general',
|
||||
description='General purpose agent for miscellaneous tasks',
|
||||
system_prompt='''You are a general purpose AI assistant. Your responsibilities:
|
||||
"general": AgentRole(
|
||||
name="general",
|
||||
description="General purpose agent for miscellaneous tasks",
|
||||
system_prompt="""You are a general purpose AI assistant. Your responsibilities:
|
||||
- Handle diverse tasks across multiple domains
|
||||
- Provide balanced assistance for various needs
|
||||
- Adapt to different types of requests
|
||||
- Collaborate with specialized agents when needed
|
||||
Focus on versatility, helpfulness, and task completion.''',
|
||||
Focus on versatility, helpfulness, and task completion.""",
|
||||
allowed_tools={
|
||||
'read_file', 'write_file', 'list_directory', 'create_directory',
|
||||
'change_directory', 'get_current_directory', 'python_exec',
|
||||
'run_command', 'run_command_interactive', 'http_fetch',
|
||||
'web_search', 'web_search_news', 'db_set', 'db_get', 'db_query',
|
||||
'index_directory'
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_directory",
|
||||
"create_directory",
|
||||
"change_directory",
|
||||
"get_current_directory",
|
||||
"python_exec",
|
||||
"run_command",
|
||||
"run_command_interactive",
|
||||
"http_fetch",
|
||||
"web_search",
|
||||
"web_search_news",
|
||||
"db_set",
|
||||
"db_get",
|
||||
"db_query",
|
||||
"index_directory",
|
||||
},
|
||||
specialization_areas=['general_assistance'],
|
||||
temperature=0.7
|
||||
)
|
||||
specialization_areas=["general_assistance"],
|
||||
temperature=0.7,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_agent_role(role_name: str) -> AgentRole:
|
||||
return AGENT_ROLES.get(role_name, AGENT_ROLES['general'])
|
||||
return AGENT_ROLES.get(role_name, AGENT_ROLES["general"])
|
||||
|
||||
|
||||
def list_agent_roles() -> Dict[str, AgentRole]:
|
||||
return AGENT_ROLES.copy()
|
||||
|
||||
|
||||
def get_recommended_agent(task_description: str) -> str:
|
||||
task_lower = task_description.lower()
|
||||
|
||||
code_keywords = ['code', 'implement', 'function', 'class', 'bug', 'debug', 'refactor', 'optimize']
|
||||
research_keywords = ['search', 'find', 'research', 'information', 'analyze', 'investigate']
|
||||
data_keywords = ['data', 'database', 'query', 'statistics', 'analyze', 'process']
|
||||
planning_keywords = ['plan', 'organize', 'workflow', 'steps', 'coordinate']
|
||||
testing_keywords = ['test', 'verify', 'validate', 'check', 'quality']
|
||||
doc_keywords = ['document', 'documentation', 'explain', 'guide', 'manual']
|
||||
code_keywords = [
|
||||
"code",
|
||||
"implement",
|
||||
"function",
|
||||
"class",
|
||||
"bug",
|
||||
"debug",
|
||||
"refactor",
|
||||
"optimize",
|
||||
]
|
||||
research_keywords = [
|
||||
"search",
|
||||
"find",
|
||||
"research",
|
||||
"information",
|
||||
"analyze",
|
||||
"investigate",
|
||||
]
|
||||
data_keywords = ["data", "database", "query", "statistics", "analyze", "process"]
|
||||
planning_keywords = ["plan", "organize", "workflow", "steps", "coordinate"]
|
||||
testing_keywords = ["test", "verify", "validate", "check", "quality"]
|
||||
doc_keywords = ["document", "documentation", "explain", "guide", "manual"]
|
||||
|
||||
if any(keyword in task_lower for keyword in code_keywords):
|
||||
return 'coding'
|
||||
return "coding"
|
||||
elif any(keyword in task_lower for keyword in research_keywords):
|
||||
return 'research'
|
||||
return "research"
|
||||
elif any(keyword in task_lower for keyword in data_keywords):
|
||||
return 'data_analysis'
|
||||
return "data_analysis"
|
||||
elif any(keyword in task_lower for keyword in planning_keywords):
|
||||
return 'planning'
|
||||
return "planning"
|
||||
elif any(keyword in task_lower for keyword in testing_keywords):
|
||||
return 'testing'
|
||||
return "testing"
|
||||
elif any(keyword in task_lower for keyword in doc_keywords):
|
||||
return 'documentation'
|
||||
return "documentation"
|
||||
else:
|
||||
return 'general'
|
||||
return "general"
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from pr.autonomous.detection import is_task_complete
|
||||
from pr.autonomous.mode import run_autonomous_mode, process_response_autonomous
|
||||
from pr.autonomous.mode import process_response_autonomous, run_autonomous_mode
|
||||
|
||||
__all__ = ['is_task_complete', 'run_autonomous_mode', 'process_response_autonomous']
|
||||
__all__ = ["is_task_complete", "run_autonomous_mode", "process_response_autonomous"]
|
||||
|
||||
@ -1,28 +1,39 @@
|
||||
from pr.config import MAX_AUTONOMOUS_ITERATIONS
|
||||
from pr.ui import Colors
|
||||
|
||||
|
||||
def is_task_complete(response, iteration):
|
||||
if 'error' in response:
|
||||
if "error" in response:
|
||||
return True
|
||||
|
||||
if 'choices' not in response or not response['choices']:
|
||||
if "choices" not in response or not response["choices"]:
|
||||
return True
|
||||
|
||||
message = response['choices'][0]['message']
|
||||
content = message.get('content', '').lower()
|
||||
message = response["choices"][0]["message"]
|
||||
content = message.get("content", "").lower()
|
||||
|
||||
completion_keywords = [
|
||||
'task complete', 'task is complete', 'finished', 'done',
|
||||
'successfully completed', 'task accomplished', 'all done',
|
||||
'implementation complete', 'setup complete', 'installation complete'
|
||||
"task complete",
|
||||
"task is complete",
|
||||
"finished",
|
||||
"done",
|
||||
"successfully completed",
|
||||
"task accomplished",
|
||||
"all done",
|
||||
"implementation complete",
|
||||
"setup complete",
|
||||
"installation complete",
|
||||
]
|
||||
|
||||
error_keywords = [
|
||||
'cannot proceed', 'unable to continue', 'fatal error',
|
||||
'cannot complete', 'impossible to'
|
||||
"cannot proceed",
|
||||
"unable to continue",
|
||||
"fatal error",
|
||||
"cannot complete",
|
||||
"impossible to",
|
||||
]
|
||||
|
||||
has_tool_calls = 'tool_calls' in message and message['tool_calls']
|
||||
has_tool_calls = "tool_calls" in message and message["tool_calls"]
|
||||
mentions_completion = any(keyword in content for keyword in completion_keywords)
|
||||
mentions_error = any(keyword in content for keyword in error_keywords)
|
||||
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from pr.ui import Colors, display_tool_call, print_autonomous_header
|
||||
import time
|
||||
|
||||
from pr.autonomous.detection import is_task_complete
|
||||
from pr.core.context import truncate_tool_result
|
||||
from pr.ui import Colors, display_tool_call
|
||||
|
||||
logger = logging.getLogger("pr")
|
||||
|
||||
logger = logging.getLogger('pr')
|
||||
|
||||
def run_autonomous_mode(assistant, task):
|
||||
assistant.autonomous_mode = True
|
||||
@ -14,25 +16,32 @@ def run_autonomous_mode(assistant, task):
|
||||
logger.debug(f"=== AUTONOMOUS MODE START ===")
|
||||
logger.debug(f"Task: {task}")
|
||||
|
||||
assistant.messages.append({
|
||||
"role": "user",
|
||||
"content": f"{task}"
|
||||
})
|
||||
assistant.messages.append({"role": "user", "content": f"{task}"})
|
||||
|
||||
try:
|
||||
while True:
|
||||
assistant.autonomous_iterations += 1
|
||||
|
||||
logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---")
|
||||
logger.debug(f"Messages before context management: {len(assistant.messages)}")
|
||||
logger.debug(
|
||||
f"--- Autonomous iteration {assistant.autonomous_iterations} ---"
|
||||
)
|
||||
logger.debug(
|
||||
f"Messages before context management: {len(assistant.messages)}"
|
||||
)
|
||||
|
||||
from pr.core.context import manage_context_window
|
||||
assistant.messages = manage_context_window(assistant.messages, assistant.verbose)
|
||||
|
||||
logger.debug(f"Messages after context management: {len(assistant.messages)}")
|
||||
assistant.messages = manage_context_window(
|
||||
assistant.messages, assistant.verbose
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Messages after context management: {len(assistant.messages)}"
|
||||
)
|
||||
|
||||
from pr.core.api import call_api
|
||||
from pr.tools.base import get_tools_definition
|
||||
|
||||
response = call_api(
|
||||
assistant.messages,
|
||||
assistant.model,
|
||||
@ -40,10 +49,10 @@ def run_autonomous_mode(assistant, task):
|
||||
assistant.api_key,
|
||||
assistant.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=assistant.verbose
|
||||
verbose=assistant.verbose,
|
||||
)
|
||||
|
||||
if 'error' in response:
|
||||
if "error" in response:
|
||||
logger.error(f"API error in autonomous mode: {response['error']}")
|
||||
print(f"{Colors.RED}Error: {response['error']}{Colors.RESET}")
|
||||
break
|
||||
@ -74,22 +83,23 @@ def run_autonomous_mode(assistant, task):
|
||||
assistant.autonomous_mode = False
|
||||
logger.debug("=== AUTONOMOUS MODE END ===")
|
||||
|
||||
|
||||
def process_response_autonomous(assistant, response):
|
||||
if 'error' in response:
|
||||
if "error" in response:
|
||||
return f"Error: {response['error']}"
|
||||
|
||||
if 'choices' not in response or not response['choices']:
|
||||
if "choices" not in response or not response["choices"]:
|
||||
return "No response from API"
|
||||
|
||||
message = response['choices'][0]['message']
|
||||
message = response["choices"][0]["message"]
|
||||
assistant.messages.append(message)
|
||||
|
||||
if 'tool_calls' in message and message['tool_calls']:
|
||||
if "tool_calls" in message and message["tool_calls"]:
|
||||
tool_results = []
|
||||
|
||||
for tool_call in message['tool_calls']:
|
||||
func_name = tool_call['function']['name']
|
||||
arguments = json.loads(tool_call['function']['arguments'])
|
||||
for tool_call in message["tool_calls"]:
|
||||
func_name = tool_call["function"]["name"]
|
||||
arguments = json.loads(tool_call["function"]["arguments"])
|
||||
|
||||
result = execute_single_tool(assistant, func_name, arguments)
|
||||
result = truncate_tool_result(result)
|
||||
@ -97,16 +107,19 @@ def process_response_autonomous(assistant, response):
|
||||
status = "success" if result.get("status") == "success" else "error"
|
||||
display_tool_call(func_name, arguments, status, result)
|
||||
|
||||
tool_results.append({
|
||||
"tool_call_id": tool_call['id'],
|
||||
"role": "tool",
|
||||
"content": json.dumps(result)
|
||||
})
|
||||
tool_results.append(
|
||||
{
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"content": json.dumps(result),
|
||||
}
|
||||
)
|
||||
|
||||
for result in tool_results:
|
||||
assistant.messages.append(result)
|
||||
from pr.core.api import call_api
|
||||
from pr.tools.base import get_tools_definition
|
||||
|
||||
follow_up = call_api(
|
||||
assistant.messages,
|
||||
assistant.model,
|
||||
@ -114,59 +127,88 @@ def process_response_autonomous(assistant, response):
|
||||
assistant.api_key,
|
||||
assistant.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=assistant.verbose
|
||||
verbose=assistant.verbose,
|
||||
)
|
||||
return process_response_autonomous(assistant, follow_up)
|
||||
|
||||
content = message.get('content', '')
|
||||
content = message.get("content", "")
|
||||
from pr.ui import render_markdown
|
||||
|
||||
return render_markdown(content, assistant.syntax_highlighting)
|
||||
|
||||
|
||||
def execute_single_tool(assistant, func_name, arguments):
|
||||
logger.debug(f"Executing tool in autonomous mode: {func_name}")
|
||||
logger.debug(f"Tool arguments: {arguments}")
|
||||
|
||||
from pr.tools import (
|
||||
http_fetch, run_command, run_command_interactive, read_file, write_file,
|
||||
list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query,
|
||||
web_search, web_search_news, python_exec, index_source_directory,
|
||||
search_replace, open_editor, editor_insert_text, editor_replace_text,
|
||||
editor_search, close_editor, create_diff, apply_patch, tail_process, kill_process
|
||||
apply_patch,
|
||||
chdir,
|
||||
close_editor,
|
||||
create_diff,
|
||||
db_get,
|
||||
db_query,
|
||||
db_set,
|
||||
editor_insert_text,
|
||||
editor_replace_text,
|
||||
editor_search,
|
||||
getpwd,
|
||||
http_fetch,
|
||||
index_source_directory,
|
||||
kill_process,
|
||||
list_directory,
|
||||
mkdir,
|
||||
open_editor,
|
||||
python_exec,
|
||||
read_file,
|
||||
run_command,
|
||||
run_command_interactive,
|
||||
search_replace,
|
||||
tail_process,
|
||||
web_search,
|
||||
web_search_news,
|
||||
write_file,
|
||||
)
|
||||
from pr.tools.filesystem import (
|
||||
clear_edit_tracker,
|
||||
display_edit_summary,
|
||||
display_edit_timeline,
|
||||
)
|
||||
from pr.tools.patch import display_file_diff
|
||||
from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker
|
||||
|
||||
func_map = {
|
||||
'http_fetch': lambda **kw: http_fetch(**kw),
|
||||
'run_command': lambda **kw: run_command(**kw),
|
||||
'tail_process': lambda **kw: tail_process(**kw),
|
||||
'kill_process': lambda **kw: kill_process(**kw),
|
||||
'run_command_interactive': lambda **kw: run_command_interactive(**kw),
|
||||
'read_file': lambda **kw: read_file(**kw),
|
||||
'write_file': lambda **kw: write_file(**kw, db_conn=assistant.db_conn),
|
||||
'list_directory': lambda **kw: list_directory(**kw),
|
||||
'mkdir': lambda **kw: mkdir(**kw),
|
||||
'chdir': lambda **kw: chdir(**kw),
|
||||
'getpwd': lambda **kw: getpwd(**kw),
|
||||
'db_set': lambda **kw: db_set(**kw, db_conn=assistant.db_conn),
|
||||
'db_get': lambda **kw: db_get(**kw, db_conn=assistant.db_conn),
|
||||
'db_query': lambda **kw: db_query(**kw, db_conn=assistant.db_conn),
|
||||
'web_search': lambda **kw: web_search(**kw),
|
||||
'web_search_news': lambda **kw: web_search_news(**kw),
|
||||
'python_exec': lambda **kw: python_exec(**kw, python_globals=assistant.python_globals),
|
||||
'index_source_directory': lambda **kw: index_source_directory(**kw),
|
||||
'search_replace': lambda **kw: search_replace(**kw),
|
||||
'open_editor': lambda **kw: open_editor(**kw),
|
||||
'editor_insert_text': lambda **kw: editor_insert_text(**kw),
|
||||
'editor_replace_text': lambda **kw: editor_replace_text(**kw),
|
||||
'editor_search': lambda **kw: editor_search(**kw),
|
||||
'close_editor': lambda **kw: close_editor(**kw),
|
||||
'create_diff': lambda **kw: create_diff(**kw),
|
||||
'apply_patch': lambda **kw: apply_patch(**kw),
|
||||
'display_file_diff': lambda **kw: display_file_diff(**kw),
|
||||
'display_edit_summary': lambda **kw: display_edit_summary(),
|
||||
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
|
||||
'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
|
||||
"http_fetch": lambda **kw: http_fetch(**kw),
|
||||
"run_command": lambda **kw: run_command(**kw),
|
||||
"tail_process": lambda **kw: tail_process(**kw),
|
||||
"kill_process": lambda **kw: kill_process(**kw),
|
||||
"run_command_interactive": lambda **kw: run_command_interactive(**kw),
|
||||
"read_file": lambda **kw: read_file(**kw),
|
||||
"write_file": lambda **kw: write_file(**kw, db_conn=assistant.db_conn),
|
||||
"list_directory": lambda **kw: list_directory(**kw),
|
||||
"mkdir": lambda **kw: mkdir(**kw),
|
||||
"chdir": lambda **kw: chdir(**kw),
|
||||
"getpwd": lambda **kw: getpwd(**kw),
|
||||
"db_set": lambda **kw: db_set(**kw, db_conn=assistant.db_conn),
|
||||
"db_get": lambda **kw: db_get(**kw, db_conn=assistant.db_conn),
|
||||
"db_query": lambda **kw: db_query(**kw, db_conn=assistant.db_conn),
|
||||
"web_search": lambda **kw: web_search(**kw),
|
||||
"web_search_news": lambda **kw: web_search_news(**kw),
|
||||
"python_exec": lambda **kw: python_exec(
|
||||
**kw, python_globals=assistant.python_globals
|
||||
),
|
||||
"index_source_directory": lambda **kw: index_source_directory(**kw),
|
||||
"search_replace": lambda **kw: search_replace(**kw),
|
||||
"open_editor": lambda **kw: open_editor(**kw),
|
||||
"editor_insert_text": lambda **kw: editor_insert_text(**kw),
|
||||
"editor_replace_text": lambda **kw: editor_replace_text(**kw),
|
||||
"editor_search": lambda **kw: editor_search(**kw),
|
||||
"close_editor": lambda **kw: close_editor(**kw),
|
||||
"create_diff": lambda **kw: create_diff(**kw),
|
||||
"apply_patch": lambda **kw: apply_patch(**kw),
|
||||
"display_file_diff": lambda **kw: display_file_diff(**kw),
|
||||
"display_edit_summary": lambda **kw: display_edit_summary(),
|
||||
"display_edit_timeline": lambda **kw: display_edit_timeline(**kw),
|
||||
"clear_edit_tracker": lambda **kw: clear_edit_tracker(),
|
||||
}
|
||||
|
||||
if func_name in func_map:
|
||||
|
||||
2
pr/cache/__init__.py
vendored
2
pr/cache/__init__.py
vendored
@ -1,4 +1,4 @@
|
||||
from .api_cache import APICache
|
||||
from .tool_cache import ToolCache
|
||||
|
||||
__all__ = ['APICache', 'ToolCache']
|
||||
__all__ = ["APICache", "ToolCache"]
|
||||
|
||||
86
pr/cache/api_cache.py
vendored
86
pr/cache/api_cache.py
vendored
@ -2,7 +2,8 @@ import hashlib
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class APICache:
|
||||
def __init__(self, db_path: str, ttl_seconds: int = 3600):
|
||||
@ -13,7 +14,8 @@ class APICache:
|
||||
def _initialize_cache(self):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS api_cache (
|
||||
cache_key TEXT PRIMARY KEY,
|
||||
response_data TEXT NOT NULL,
|
||||
@ -22,34 +24,44 @@ class APICache:
|
||||
model TEXT,
|
||||
token_count INTEGER
|
||||
)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def _generate_cache_key(self, model: str, messages: list, temperature: float, max_tokens: int) -> str:
|
||||
def _generate_cache_key(
|
||||
self, model: str, messages: list, temperature: float, max_tokens: int
|
||||
) -> str:
|
||||
cache_data = {
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'temperature': temperature,
|
||||
'max_tokens': max_tokens
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
serialized = json.dumps(cache_data, sort_keys=True)
|
||||
return hashlib.sha256(serialized.encode()).hexdigest()
|
||||
|
||||
def get(self, model: str, messages: list, temperature: float, max_tokens: int) -> Optional[Dict[str, Any]]:
|
||||
def get(
|
||||
self, model: str, messages: list, temperature: float, max_tokens: int
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
|
||||
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
current_time = int(time.time())
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT response_data FROM api_cache
|
||||
WHERE cache_key = ? AND expires_at > ?
|
||||
''', (cache_key, current_time))
|
||||
""",
|
||||
(cache_key, current_time),
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
@ -58,8 +70,15 @@ class APICache:
|
||||
return json.loads(row[0])
|
||||
return None
|
||||
|
||||
def set(self, model: str, messages: list, temperature: float, max_tokens: int,
|
||||
response: Dict[str, Any], token_count: int = 0):
|
||||
def set(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
response: Dict[str, Any],
|
||||
token_count: int = 0,
|
||||
):
|
||||
cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
|
||||
|
||||
current_time = int(time.time())
|
||||
@ -68,11 +87,21 @@ class APICache:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO api_cache
|
||||
(cache_key, response_data, created_at, expires_at, model, token_count)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (cache_key, json.dumps(response), current_time, expires_at, model, token_count))
|
||||
""",
|
||||
(
|
||||
cache_key,
|
||||
json.dumps(response),
|
||||
current_time,
|
||||
expires_at,
|
||||
model,
|
||||
token_count,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@ -83,7 +112,7 @@ class APICache:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM api_cache WHERE expires_at <= ?', (current_time,))
|
||||
cursor.execute("DELETE FROM api_cache WHERE expires_at <= ?", (current_time,))
|
||||
deleted_count = cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
@ -95,7 +124,7 @@ class APICache:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM api_cache')
|
||||
cursor.execute("DELETE FROM api_cache")
|
||||
deleted_count = cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
@ -107,21 +136,26 @@ class APICache:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('SELECT COUNT(*) FROM api_cache')
|
||||
cursor.execute("SELECT COUNT(*) FROM api_cache")
|
||||
total_entries = cursor.fetchone()[0]
|
||||
|
||||
current_time = int(time.time())
|
||||
cursor.execute('SELECT COUNT(*) FROM api_cache WHERE expires_at > ?', (current_time,))
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,)
|
||||
)
|
||||
valid_entries = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?', (current_time,))
|
||||
cursor.execute(
|
||||
"SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?",
|
||||
(current_time,),
|
||||
)
|
||||
total_tokens = cursor.fetchone()[0] or 0
|
||||
|
||||
conn.close()
|
||||
|
||||
return {
|
||||
'total_entries': total_entries,
|
||||
'valid_entries': valid_entries,
|
||||
'expired_entries': total_entries - valid_entries,
|
||||
'total_cached_tokens': total_tokens
|
||||
"total_entries": total_entries,
|
||||
"valid_entries": valid_entries,
|
||||
"expired_entries": total_entries - valid_entries,
|
||||
"total_cached_tokens": total_tokens,
|
||||
}
|
||||
|
||||
98
pr/cache/tool_cache.py
vendored
98
pr/cache/tool_cache.py
vendored
@ -2,16 +2,17 @@ import hashlib
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Optional, Any, Set
|
||||
from typing import Any, Optional, Set
|
||||
|
||||
|
||||
class ToolCache:
|
||||
DETERMINISTIC_TOOLS: Set[str] = {
|
||||
'read_file',
|
||||
'list_directory',
|
||||
'get_current_directory',
|
||||
'db_get',
|
||||
'db_query',
|
||||
'index_directory'
|
||||
"read_file",
|
||||
"list_directory",
|
||||
"get_current_directory",
|
||||
"db_get",
|
||||
"db_query",
|
||||
"index_directory",
|
||||
}
|
||||
|
||||
def __init__(self, db_path: str, ttl_seconds: int = 300):
|
||||
@ -22,7 +23,8 @@ class ToolCache:
|
||||
def _initialize_cache(self):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tool_cache (
|
||||
cache_key TEXT PRIMARY KEY,
|
||||
tool_name TEXT NOT NULL,
|
||||
@ -31,21 +33,23 @@ class ToolCache:
|
||||
expires_at INTEGER NOT NULL,
|
||||
hit_count INTEGER DEFAULT 0
|
||||
)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_tool_expires ON tool_cache(expires_at)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_cache(tool_name)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def _generate_cache_key(self, tool_name: str, arguments: dict) -> str:
|
||||
cache_data = {
|
||||
'tool': tool_name,
|
||||
'args': arguments
|
||||
}
|
||||
cache_data = {"tool": tool_name, "args": arguments}
|
||||
serialized = json.dumps(cache_data, sort_keys=True)
|
||||
return hashlib.sha256(serialized.encode()).hexdigest()
|
||||
|
||||
@ -62,18 +66,24 @@ class ToolCache:
|
||||
cursor = conn.cursor()
|
||||
|
||||
current_time = int(time.time())
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT result_data, hit_count FROM tool_cache
|
||||
WHERE cache_key = ? AND expires_at > ?
|
||||
''', (cache_key, current_time))
|
||||
""",
|
||||
(cache_key, current_time),
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE tool_cache SET hit_count = hit_count + 1
|
||||
WHERE cache_key = ?
|
||||
''', (cache_key,))
|
||||
""",
|
||||
(cache_key,),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return json.loads(row[0])
|
||||
@ -93,11 +103,14 @@ class ToolCache:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO tool_cache
|
||||
(cache_key, tool_name, result_data, created_at, expires_at, hit_count)
|
||||
VALUES (?, ?, ?, ?, ?, 0)
|
||||
''', (cache_key, tool_name, json.dumps(result), current_time, expires_at))
|
||||
""",
|
||||
(cache_key, tool_name, json.dumps(result), current_time, expires_at),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@ -106,7 +119,7 @@ class ToolCache:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM tool_cache WHERE tool_name = ?', (tool_name,))
|
||||
cursor.execute("DELETE FROM tool_cache WHERE tool_name = ?", (tool_name,))
|
||||
deleted_count = cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
@ -120,7 +133,7 @@ class ToolCache:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM tool_cache WHERE expires_at <= ?', (current_time,))
|
||||
cursor.execute("DELETE FROM tool_cache WHERE expires_at <= ?", (current_time,))
|
||||
deleted_count = cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
@ -132,7 +145,7 @@ class ToolCache:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM tool_cache')
|
||||
cursor.execute("DELETE FROM tool_cache")
|
||||
deleted_count = cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
@ -144,36 +157,41 @@ class ToolCache:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('SELECT COUNT(*) FROM tool_cache')
|
||||
cursor.execute("SELECT COUNT(*) FROM tool_cache")
|
||||
total_entries = cursor.fetchone()[0]
|
||||
|
||||
current_time = int(time.time())
|
||||
cursor.execute('SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?', (current_time,))
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,)
|
||||
)
|
||||
valid_entries = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?', (current_time,))
|
||||
cursor.execute(
|
||||
"SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?",
|
||||
(current_time,),
|
||||
)
|
||||
total_hits = cursor.fetchone()[0] or 0
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT tool_name, COUNT(*), SUM(hit_count)
|
||||
FROM tool_cache
|
||||
WHERE expires_at > ?
|
||||
GROUP BY tool_name
|
||||
''', (current_time,))
|
||||
""",
|
||||
(current_time,),
|
||||
)
|
||||
|
||||
tool_stats = {}
|
||||
for row in cursor.fetchall():
|
||||
tool_stats[row[0]] = {
|
||||
'cached_entries': row[1],
|
||||
'total_hits': row[2] or 0
|
||||
}
|
||||
tool_stats[row[0]] = {"cached_entries": row[1], "total_hits": row[2] or 0}
|
||||
|
||||
conn.close()
|
||||
|
||||
return {
|
||||
'total_entries': total_entries,
|
||||
'valid_entries': valid_entries,
|
||||
'expired_entries': total_entries - valid_entries,
|
||||
'total_cache_hits': total_hits,
|
||||
'by_tool': tool_stats
|
||||
"total_entries": total_entries,
|
||||
"valid_entries": valid_entries,
|
||||
"expired_entries": total_entries - valid_entries,
|
||||
"total_cache_hits": total_hits,
|
||||
"by_tool": tool_stats,
|
||||
}
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from pr.commands.handlers import handle_command
|
||||
|
||||
__all__ = ['handle_command']
|
||||
__all__ = ["handle_command"]
|
||||
|
||||
@ -1,30 +1,35 @@
|
||||
import json
|
||||
import time
|
||||
from pr.ui import Colors
|
||||
|
||||
from pr.autonomous import run_autonomous_mode
|
||||
from pr.core.api import list_models
|
||||
from pr.tools import read_file
|
||||
from pr.tools.base import get_tools_definition
|
||||
from pr.core.api import list_models
|
||||
from pr.autonomous import run_autonomous_mode
|
||||
from pr.ui import Colors
|
||||
|
||||
|
||||
def handle_command(assistant, command):
|
||||
command_parts = command.strip().split(maxsplit=1)
|
||||
cmd = command_parts[0].lower()
|
||||
|
||||
if cmd == '/auto':
|
||||
if cmd == "/auto":
|
||||
if len(command_parts) < 2:
|
||||
print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}")
|
||||
print(f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}"
|
||||
)
|
||||
return True
|
||||
|
||||
task = command_parts[1]
|
||||
run_autonomous_mode(assistant, task)
|
||||
return True
|
||||
|
||||
if cmd in ['exit', 'quit', 'q']:
|
||||
if cmd in ["exit", "quit", "q"]:
|
||||
return False
|
||||
|
||||
elif cmd == 'help':
|
||||
print(f"""
|
||||
elif cmd == "help":
|
||||
print(
|
||||
f"""
|
||||
{Colors.BOLD}Available Commands:{Colors.RESET}
|
||||
|
||||
{Colors.BOLD}Basic:{Colors.RESET}
|
||||
@ -54,18 +59,21 @@ def handle_command(assistant, command):
|
||||
{Colors.CYAN}/cache{Colors.RESET} - Show cache statistics
|
||||
{Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches
|
||||
{Colors.CYAN}/stats{Colors.RESET} - Show system statistics
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
elif cmd == '/reset':
|
||||
elif cmd == "/reset":
|
||||
assistant.messages = assistant.messages[:1]
|
||||
print(f"{Colors.GREEN}Message history cleared{Colors.RESET}")
|
||||
|
||||
elif cmd == '/dump':
|
||||
elif cmd == "/dump":
|
||||
print(json.dumps(assistant.messages, indent=2))
|
||||
|
||||
elif cmd == '/verbose':
|
||||
elif cmd == "/verbose":
|
||||
assistant.verbose = not assistant.verbose
|
||||
print(f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}")
|
||||
print(
|
||||
f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}"
|
||||
)
|
||||
|
||||
elif cmd.startswith("/model"):
|
||||
if len(command_parts) < 2:
|
||||
@ -74,77 +82,81 @@ def handle_command(assistant, command):
|
||||
assistant.model = command_parts[1]
|
||||
print(f"Model set to: {Colors.GREEN}{assistant.model}{Colors.RESET}")
|
||||
|
||||
elif cmd == '/models':
|
||||
elif cmd == "/models":
|
||||
models = list_models(assistant.model_list_url, assistant.api_key)
|
||||
if isinstance(models, dict) and 'error' in models:
|
||||
if isinstance(models, dict) and "error" in models:
|
||||
print(f"{Colors.RED}Error fetching models: {models['error']}{Colors.RESET}")
|
||||
else:
|
||||
print(f"{Colors.BOLD}Available Models:{Colors.RESET}")
|
||||
for model in models:
|
||||
print(f" • {Colors.CYAN}{model['id']}{Colors.RESET}")
|
||||
|
||||
elif cmd == '/tools':
|
||||
elif cmd == "/tools":
|
||||
print(f"{Colors.BOLD}Available Tools:{Colors.RESET}")
|
||||
for tool in get_tools_definition():
|
||||
func = tool['function']
|
||||
print(f" • {Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}")
|
||||
func = tool["function"]
|
||||
print(
|
||||
f" • {Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}"
|
||||
)
|
||||
|
||||
elif cmd == '/review' and len(command_parts) > 1:
|
||||
elif cmd == "/review" and len(command_parts) > 1:
|
||||
filename = command_parts[1]
|
||||
review_file(assistant, filename)
|
||||
|
||||
elif cmd == '/refactor' and len(command_parts) > 1:
|
||||
elif cmd == "/refactor" and len(command_parts) > 1:
|
||||
filename = command_parts[1]
|
||||
refactor_file(assistant, filename)
|
||||
|
||||
elif cmd == '/obfuscate' and len(command_parts) > 1:
|
||||
elif cmd == "/obfuscate" and len(command_parts) > 1:
|
||||
filename = command_parts[1]
|
||||
obfuscate_file(assistant, filename)
|
||||
|
||||
elif cmd == '/workflows':
|
||||
elif cmd == "/workflows":
|
||||
show_workflows(assistant)
|
||||
|
||||
elif cmd == '/workflow' and len(command_parts) > 1:
|
||||
elif cmd == "/workflow" and len(command_parts) > 1:
|
||||
workflow_name = command_parts[1]
|
||||
execute_workflow_command(assistant, workflow_name)
|
||||
|
||||
elif cmd == '/agent' and len(command_parts) > 1:
|
||||
elif cmd == "/agent" and len(command_parts) > 1:
|
||||
args = command_parts[1].split(maxsplit=1)
|
||||
if len(args) < 2:
|
||||
print(f"{Colors.RED}Usage: /agent <role> <task>{Colors.RESET}")
|
||||
print(f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}"
|
||||
)
|
||||
else:
|
||||
role, task = args[0], args[1]
|
||||
execute_agent_task(assistant, role, task)
|
||||
|
||||
elif cmd == '/agents':
|
||||
elif cmd == "/agents":
|
||||
show_agents(assistant)
|
||||
|
||||
elif cmd == '/collaborate' and len(command_parts) > 1:
|
||||
elif cmd == "/collaborate" and len(command_parts) > 1:
|
||||
task = command_parts[1]
|
||||
collaborate_agents_command(assistant, task)
|
||||
|
||||
elif cmd == '/knowledge' and len(command_parts) > 1:
|
||||
elif cmd == "/knowledge" and len(command_parts) > 1:
|
||||
query = command_parts[1]
|
||||
search_knowledge(assistant, query)
|
||||
|
||||
elif cmd == '/remember' and len(command_parts) > 1:
|
||||
elif cmd == "/remember" and len(command_parts) > 1:
|
||||
content = command_parts[1]
|
||||
store_knowledge(assistant, content)
|
||||
|
||||
elif cmd == '/history':
|
||||
elif cmd == "/history":
|
||||
show_conversation_history(assistant)
|
||||
|
||||
elif cmd == '/cache':
|
||||
if len(command_parts) > 1 and command_parts[1].lower() == 'clear':
|
||||
elif cmd == "/cache":
|
||||
if len(command_parts) > 1 and command_parts[1].lower() == "clear":
|
||||
clear_caches(assistant)
|
||||
else:
|
||||
show_cache_stats(assistant)
|
||||
|
||||
elif cmd == '/stats':
|
||||
elif cmd == "/stats":
|
||||
show_system_stats(assistant)
|
||||
|
||||
elif cmd.startswith('/bg'):
|
||||
elif cmd.startswith("/bg"):
|
||||
handle_background_command(assistant, command)
|
||||
|
||||
else:
|
||||
@ -152,35 +164,46 @@ def handle_command(assistant, command):
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def review_file(assistant, filename):
|
||||
result = read_file(filename)
|
||||
if result['status'] == 'success':
|
||||
message = f"Please review this file and provide feedback:\n\n{result['content']}"
|
||||
if result["status"] == "success":
|
||||
message = (
|
||||
f"Please review this file and provide feedback:\n\n{result['content']}"
|
||||
)
|
||||
from pr.core.assistant import process_message
|
||||
|
||||
process_message(assistant, message)
|
||||
else:
|
||||
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
|
||||
|
||||
|
||||
def refactor_file(assistant, filename):
|
||||
result = read_file(filename)
|
||||
if result['status'] == 'success':
|
||||
message = f"Please refactor this code to improve its quality:\n\n{result['content']}"
|
||||
if result["status"] == "success":
|
||||
message = (
|
||||
f"Please refactor this code to improve its quality:\n\n{result['content']}"
|
||||
)
|
||||
from pr.core.assistant import process_message
|
||||
|
||||
process_message(assistant, message)
|
||||
else:
|
||||
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
|
||||
|
||||
|
||||
def obfuscate_file(assistant, filename):
|
||||
result = read_file(filename)
|
||||
if result['status'] == 'success':
|
||||
if result["status"] == "success":
|
||||
message = f"Please obfuscate this code:\n\n{result['content']}"
|
||||
from pr.core.assistant import process_message
|
||||
|
||||
process_message(assistant, message)
|
||||
else:
|
||||
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
|
||||
|
||||
|
||||
def show_workflows(assistant):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
@ -194,23 +217,25 @@ def show_workflows(assistant):
|
||||
print(f" • {Colors.CYAN}{wf['name']}{Colors.RESET}: {wf['description']}")
|
||||
print(f" Executions: {wf['execution_count']}")
|
||||
|
||||
|
||||
def execute_workflow_command(assistant, workflow_name):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
print(f"{Colors.YELLOW}Executing workflow: {workflow_name}...{Colors.RESET}")
|
||||
result = assistant.enhanced.execute_workflow(workflow_name)
|
||||
|
||||
if 'error' in result:
|
||||
if "error" in result:
|
||||
print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
|
||||
else:
|
||||
print(f"{Colors.GREEN}Workflow completed successfully{Colors.RESET}")
|
||||
print(f"Execution ID: {result['execution_id']}")
|
||||
print(f"Results: {json.dumps(result['results'], indent=2)}")
|
||||
|
||||
|
||||
def execute_agent_task(assistant, role, task):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
@ -221,14 +246,15 @@ def execute_agent_task(assistant, role, task):
|
||||
print(f"{Colors.YELLOW}Executing task...{Colors.RESET}")
|
||||
result = assistant.enhanced.agent_task(agent_id, task)
|
||||
|
||||
if 'error' in result:
|
||||
if "error" in result:
|
||||
print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
|
||||
else:
|
||||
print(f"\n{Colors.GREEN}{role.capitalize()} Agent Response:{Colors.RESET}")
|
||||
print(result['response'])
|
||||
print(result["response"])
|
||||
|
||||
|
||||
def show_agents(assistant):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
@ -236,37 +262,39 @@ def show_agents(assistant):
|
||||
print(f"\n{Colors.BOLD}Agent Session Summary:{Colors.RESET}")
|
||||
print(f"Active agents: {summary['active_agents']}")
|
||||
|
||||
if summary['agents']:
|
||||
for agent in summary['agents']:
|
||||
if summary["agents"]:
|
||||
for agent in summary["agents"]:
|
||||
print(f"\n • {Colors.CYAN}{agent['agent_id']}{Colors.RESET}")
|
||||
print(f" Role: {agent['role']}")
|
||||
print(f" Tasks completed: {agent['task_count']}")
|
||||
print(f" Messages: {agent['message_count']}")
|
||||
|
||||
|
||||
def collaborate_agents_command(assistant, task):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
print(f"{Colors.YELLOW}Initiating agent collaboration...{Colors.RESET}")
|
||||
roles = ['coding', 'research', 'planning']
|
||||
roles = ["coding", "research", "planning"]
|
||||
|
||||
result = assistant.enhanced.collaborate_agents(task, roles)
|
||||
|
||||
print(f"\n{Colors.GREEN}Collaboration completed{Colors.RESET}")
|
||||
print(f"\nOrchestrator response:")
|
||||
if 'orchestrator' in result and 'response' in result['orchestrator']:
|
||||
print(result['orchestrator']['response'])
|
||||
if "orchestrator" in result and "response" in result["orchestrator"]:
|
||||
print(result["orchestrator"]["response"])
|
||||
|
||||
if result.get('agents'):
|
||||
if result.get("agents"):
|
||||
print(f"\n{Colors.BOLD}Agent Results:{Colors.RESET}")
|
||||
for agent_result in result['agents']:
|
||||
if 'role' in agent_result:
|
||||
for agent_result in result["agents"]:
|
||||
if "role" in agent_result:
|
||||
print(f"\n{Colors.CYAN}{agent_result['role']}:{Colors.RESET}")
|
||||
print(agent_result.get('response', 'No response'))
|
||||
print(agent_result.get("response", "No response"))
|
||||
|
||||
|
||||
def search_knowledge(assistant, query):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
@ -282,13 +310,15 @@ def search_knowledge(assistant, query):
|
||||
print(f" {entry.content[:200]}...")
|
||||
print(f" Accessed: {entry.access_count} times")
|
||||
|
||||
|
||||
def store_knowledge(assistant, content):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
import uuid
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from pr.memory import KnowledgeEntry
|
||||
|
||||
categories = assistant.enhanced.fact_extractor.categorize_content(content)
|
||||
@ -296,11 +326,11 @@ def store_knowledge(assistant, content):
|
||||
|
||||
entry = KnowledgeEntry(
|
||||
entry_id=entry_id,
|
||||
category=categories[0] if categories else 'general',
|
||||
category=categories[0] if categories else "general",
|
||||
content=content,
|
||||
metadata={'manual_entry': True},
|
||||
metadata={"manual_entry": True},
|
||||
created_at=time.time(),
|
||||
updated_at=time.time()
|
||||
updated_at=time.time(),
|
||||
)
|
||||
|
||||
assistant.enhanced.knowledge_store.add_entry(entry)
|
||||
@ -308,8 +338,9 @@ def store_knowledge(assistant, content):
|
||||
print(f"Entry ID: {entry_id}")
|
||||
print(f"Category: {entry.category}")
|
||||
|
||||
|
||||
def show_conversation_history(assistant):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
@ -322,17 +353,21 @@ def show_conversation_history(assistant):
|
||||
print(f"\n{Colors.BOLD}Recent Conversations:{Colors.RESET}")
|
||||
for conv in history:
|
||||
import datetime
|
||||
started = datetime.datetime.fromtimestamp(conv['started_at']).strftime('%Y-%m-%d %H:%M')
|
||||
|
||||
started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
print(f"\n • {Colors.CYAN}{conv['conversation_id']}{Colors.RESET}")
|
||||
print(f" Started: {started}")
|
||||
print(f" Messages: {conv['message_count']}")
|
||||
if conv.get('summary'):
|
||||
if conv.get("summary"):
|
||||
print(f" Summary: {conv['summary'][:100]}...")
|
||||
if conv.get('topics'):
|
||||
if conv.get("topics"):
|
||||
print(f" Topics: {', '.join(conv['topics'])}")
|
||||
|
||||
|
||||
def show_cache_stats(assistant):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
@ -340,36 +375,40 @@ def show_cache_stats(assistant):
|
||||
|
||||
print(f"\n{Colors.BOLD}Cache Statistics:{Colors.RESET}")
|
||||
|
||||
if 'api_cache' in stats:
|
||||
api_stats = stats['api_cache']
|
||||
if "api_cache" in stats:
|
||||
api_stats = stats["api_cache"]
|
||||
print(f"\n{Colors.CYAN}API Cache:{Colors.RESET}")
|
||||
print(f" Total entries: {api_stats['total_entries']}")
|
||||
print(f" Valid entries: {api_stats['valid_entries']}")
|
||||
print(f" Expired entries: {api_stats['expired_entries']}")
|
||||
print(f" Cached tokens: {api_stats['total_cached_tokens']}")
|
||||
|
||||
if 'tool_cache' in stats:
|
||||
tool_stats = stats['tool_cache']
|
||||
if "tool_cache" in stats:
|
||||
tool_stats = stats["tool_cache"]
|
||||
print(f"\n{Colors.CYAN}Tool Cache:{Colors.RESET}")
|
||||
print(f" Total entries: {tool_stats['total_entries']}")
|
||||
print(f" Valid entries: {tool_stats['valid_entries']}")
|
||||
print(f" Total cache hits: {tool_stats['total_cache_hits']}")
|
||||
|
||||
if tool_stats.get('by_tool'):
|
||||
if tool_stats.get("by_tool"):
|
||||
print(f"\n Per-tool statistics:")
|
||||
for tool_name, tool_stat in tool_stats['by_tool'].items():
|
||||
print(f" {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits")
|
||||
for tool_name, tool_stat in tool_stats["by_tool"].items():
|
||||
print(
|
||||
f" {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits"
|
||||
)
|
||||
|
||||
|
||||
def clear_caches(assistant):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
assistant.enhanced.clear_caches()
|
||||
print(f"{Colors.GREEN}All caches cleared successfully{Colors.RESET}")
|
||||
|
||||
|
||||
def show_system_stats(assistant):
|
||||
if not hasattr(assistant, 'enhanced'):
|
||||
if not hasattr(assistant, "enhanced"):
|
||||
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
|
||||
return
|
||||
|
||||
@ -388,68 +427,81 @@ def show_system_stats(assistant):
|
||||
print(f"\n{Colors.CYAN}Active Agents:{Colors.RESET}")
|
||||
print(f" Count: {agent_summary['active_agents']}")
|
||||
|
||||
if 'api_cache' in cache_stats:
|
||||
if "api_cache" in cache_stats:
|
||||
print(f"\n{Colors.CYAN}Caching:{Colors.RESET}")
|
||||
print(f" API cache entries: {cache_stats['api_cache']['valid_entries']}")
|
||||
if 'tool_cache' in cache_stats:
|
||||
if "tool_cache" in cache_stats:
|
||||
print(f" Tool cache entries: {cache_stats['tool_cache']['valid_entries']}")
|
||||
|
||||
|
||||
def handle_background_command(assistant, command):
|
||||
"""Handle background multiplexer commands."""
|
||||
parts = command.strip().split(maxsplit=2)
|
||||
if len(parts) < 2:
|
||||
print(f"{Colors.RED}Usage: /bg <subcommand> [args]{Colors.RESET}")
|
||||
print(f"{Colors.GRAY}Available subcommands: start, list, status, output, input, kill, events{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GRAY}Available subcommands: start, list, status, output, input, kill, events{Colors.RESET}"
|
||||
)
|
||||
return
|
||||
|
||||
subcmd = parts[1].lower()
|
||||
|
||||
try:
|
||||
if subcmd == 'start' and len(parts) >= 3:
|
||||
if subcmd == "start" and len(parts) >= 3:
|
||||
session_name = f"bg_{len(parts[2].split())}_{int(time.time())}"
|
||||
start_background_session(assistant, session_name, parts[2])
|
||||
elif subcmd == 'list':
|
||||
elif subcmd == "list":
|
||||
list_background_sessions(assistant)
|
||||
elif subcmd == 'status' and len(parts) >= 3:
|
||||
elif subcmd == "status" and len(parts) >= 3:
|
||||
show_session_status(assistant, parts[2])
|
||||
elif subcmd == 'output' and len(parts) >= 3:
|
||||
elif subcmd == "output" and len(parts) >= 3:
|
||||
show_session_output(assistant, parts[2])
|
||||
elif subcmd == 'input' and len(parts) >= 4:
|
||||
elif subcmd == "input" and len(parts) >= 4:
|
||||
send_session_input(assistant, parts[2], parts[3])
|
||||
elif subcmd == 'kill' and len(parts) >= 3:
|
||||
elif subcmd == "kill" and len(parts) >= 3:
|
||||
kill_background_session(assistant, parts[2])
|
||||
elif subcmd == 'events':
|
||||
elif subcmd == "events":
|
||||
show_background_events(assistant)
|
||||
else:
|
||||
print(f"{Colors.RED}Unknown background command: {subcmd}{Colors.RESET}")
|
||||
print(f"{Colors.GRAY}Available: start, list, status, output, input, kill, events{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GRAY}Available: start, list, status, output, input, kill, events{Colors.RESET}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error executing background command: {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def start_background_session(assistant, session_name, command):
|
||||
"""Start a command in background."""
|
||||
try:
|
||||
from pr.multiplexer import start_background_process
|
||||
|
||||
result = start_background_process(session_name, command)
|
||||
|
||||
if result['status'] == 'success':
|
||||
print(f"{Colors.GREEN}Started background session '{session_name}' with PID {result['pid']}{Colors.RESET}")
|
||||
if result["status"] == "success":
|
||||
print(
|
||||
f"{Colors.GREEN}Started background session '{session_name}' with PID {result['pid']}{Colors.RESET}"
|
||||
)
|
||||
else:
|
||||
print(f"{Colors.RED}Failed to start background session: {result.get('error', 'Unknown error')}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.RED}Failed to start background session: {result.get('error', 'Unknown error')}{Colors.RESET}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error starting background session: {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def list_background_sessions(assistant):
|
||||
"""List all background sessions."""
|
||||
try:
|
||||
from pr.ui.display import display_multiplexer_status
|
||||
from pr.multiplexer import get_all_sessions
|
||||
from pr.ui.display import display_multiplexer_status
|
||||
|
||||
sessions = get_all_sessions()
|
||||
display_multiplexer_status(sessions)
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error listing background sessions: {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def show_session_status(assistant, session_name):
|
||||
"""Show status of a specific session."""
|
||||
try:
|
||||
@ -461,15 +513,17 @@ def show_session_status(assistant, session_name):
|
||||
print(f" Status: {info.get('status', 'unknown')}")
|
||||
print(f" PID: {info.get('pid', 'N/A')}")
|
||||
print(f" Command: {info.get('command', 'N/A')}")
|
||||
if 'start_time' in info:
|
||||
if "start_time" in info:
|
||||
import time
|
||||
elapsed = time.time() - info['start_time']
|
||||
|
||||
elapsed = time.time() - info["start_time"]
|
||||
print(f" Running for: {elapsed:.1f}s")
|
||||
else:
|
||||
print(f"{Colors.YELLOW}Session '{session_name}' not found{Colors.RESET}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error getting session status: {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def show_session_output(assistant, session_name):
|
||||
"""Show output of a specific session."""
|
||||
try:
|
||||
@ -482,36 +536,45 @@ def show_session_output(assistant, session_name):
|
||||
for line in output:
|
||||
print(line)
|
||||
else:
|
||||
print(f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error getting session output: {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def send_session_input(assistant, session_name, input_text):
|
||||
"""Send input to a background session."""
|
||||
try:
|
||||
from pr.multiplexer import send_input_to_session
|
||||
|
||||
result = send_input_to_session(session_name, input_text)
|
||||
if result['status'] == 'success':
|
||||
if result["status"] == "success":
|
||||
print(f"{Colors.GREEN}Input sent to session '{session_name}'{Colors.RESET}")
|
||||
else:
|
||||
print(f"{Colors.RED}Failed to send input: {result.get('error', 'Unknown error')}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.RED}Failed to send input: {result.get('error', 'Unknown error')}{Colors.RESET}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error sending input: {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def kill_background_session(assistant, session_name):
|
||||
"""Kill a background session."""
|
||||
try:
|
||||
from pr.multiplexer import kill_session
|
||||
|
||||
result = kill_session(session_name)
|
||||
if result['status'] == 'success':
|
||||
if result["status"] == "success":
|
||||
print(f"{Colors.GREEN}Session '{session_name}' terminated{Colors.RESET}")
|
||||
else:
|
||||
print(f"{Colors.RED}Failed to kill session: {result.get('error', 'Unknown error')}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.RED}Failed to kill session: {result.get('error', 'Unknown error')}{Colors.RESET}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error killing session: {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def show_background_events(assistant):
|
||||
"""Show recent background events."""
|
||||
try:
|
||||
@ -526,6 +589,7 @@ def show_background_events(assistant):
|
||||
|
||||
for event in events[-10:]: # Show last 10 events
|
||||
from pr.ui.display import display_background_event
|
||||
|
||||
display_background_event(event)
|
||||
else:
|
||||
print(f"{Colors.GRAY}No recent background events{Colors.RESET}")
|
||||
|
||||
@ -1,11 +1,15 @@
|
||||
from pr.tools.interactive_control import (
|
||||
list_active_sessions, get_session_status, read_session_output,
|
||||
send_input_to_session, close_interactive_session
|
||||
)
|
||||
from pr.multiplexer import get_multiplexer
|
||||
from pr.tools.interactive_control import (
|
||||
close_interactive_session,
|
||||
get_session_status,
|
||||
list_active_sessions,
|
||||
read_session_output,
|
||||
send_input_to_session,
|
||||
)
|
||||
from pr.tools.prompt_detection import get_global_detector
|
||||
from pr.ui import Colors
|
||||
|
||||
|
||||
def show_sessions(args=None):
|
||||
"""Show all active multiplexer sessions."""
|
||||
sessions = list_active_sessions()
|
||||
@ -18,24 +22,29 @@ def show_sessions(args=None):
|
||||
print("-" * 80)
|
||||
|
||||
for session_name, session_data in sessions.items():
|
||||
metadata = session_data['metadata']
|
||||
output_summary = session_data['output_summary']
|
||||
metadata = session_data["metadata"]
|
||||
output_summary = session_data["output_summary"]
|
||||
|
||||
status = get_session_status(session_name)
|
||||
is_active = status.get('is_active', False) if status else False
|
||||
is_active = status.get("is_active", False) if status else False
|
||||
|
||||
status_color = Colors.GREEN if is_active else Colors.RED
|
||||
print(f"{Colors.CYAN}{session_name}{Colors.RESET}: {status_color}{metadata.get('process_type', 'unknown')}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.CYAN}{session_name}{Colors.RESET}: {status_color}{metadata.get('process_type', 'unknown')}{Colors.RESET}"
|
||||
)
|
||||
|
||||
if status and 'pid' in status:
|
||||
if status and "pid" in status:
|
||||
print(f" PID: {status['pid']}")
|
||||
|
||||
print(f" Age: {metadata.get('start_time', 0):.1f}s")
|
||||
print(f" Output: {output_summary['stdout_lines']} stdout, {output_summary['stderr_lines']} stderr lines")
|
||||
print(
|
||||
f" Output: {output_summary['stdout_lines']} stdout, {output_summary['stderr_lines']} stderr lines"
|
||||
)
|
||||
print(f" Interactions: {metadata.get('interaction_count', 0)}")
|
||||
print(f" State: {metadata.get('state', 'unknown')}")
|
||||
print()
|
||||
|
||||
|
||||
def attach_session(args):
|
||||
"""Attach to a session (show its output and allow interaction)."""
|
||||
if not args or len(args) < 1:
|
||||
@ -56,20 +65,23 @@ def attach_session(args):
|
||||
# Show recent output
|
||||
try:
|
||||
output = read_session_output(session_name, lines=20)
|
||||
if output['stdout']:
|
||||
if output["stdout"]:
|
||||
print(f"{Colors.GRAY}Recent stdout:{Colors.RESET}")
|
||||
for line in output['stdout'].split('\n'):
|
||||
for line in output["stdout"].split("\n"):
|
||||
if line.strip():
|
||||
print(f" {line}")
|
||||
if output['stderr']:
|
||||
if output["stderr"]:
|
||||
print(f"{Colors.YELLOW}Recent stderr:{Colors.RESET}")
|
||||
for line in output['stderr'].split('\n'):
|
||||
for line in output["stderr"].split("\n"):
|
||||
if line.strip():
|
||||
print(f" {line}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error reading output: {e}{Colors.RESET}")
|
||||
|
||||
print(f"\n{Colors.CYAN}Session is {'active' if status.get('is_active') else 'inactive'}{Colors.RESET}")
|
||||
print(
|
||||
f"\n{Colors.CYAN}Session is {'active' if status.get('is_active') else 'inactive'}{Colors.RESET}"
|
||||
)
|
||||
|
||||
|
||||
def detach_session(args):
|
||||
"""Detach from a session (stop showing its output but keep it running)."""
|
||||
@ -87,7 +99,10 @@ def detach_session(args):
|
||||
# In this implementation, detaching just means we stop displaying output
|
||||
# The session continues to run in the background
|
||||
mux.show_output = False
|
||||
print(f"{Colors.GREEN}Detached from session '{session_name}'. It continues running in background.{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GREEN}Detached from session '{session_name}'. It continues running in background.{Colors.RESET}"
|
||||
)
|
||||
|
||||
|
||||
def kill_session(args):
|
||||
"""Kill a session forcefully."""
|
||||
@ -101,7 +116,10 @@ def kill_session(args):
|
||||
close_interactive_session(session_name)
|
||||
print(f"{Colors.GREEN}Session '{session_name}' terminated.{Colors.RESET}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}"
|
||||
)
|
||||
|
||||
|
||||
def send_command(args):
|
||||
"""Send a command to a session."""
|
||||
@ -110,13 +128,18 @@ def send_command(args):
|
||||
return
|
||||
|
||||
session_name = args[0]
|
||||
command = ' '.join(args[1:])
|
||||
command = " ".join(args[1:])
|
||||
|
||||
try:
|
||||
send_input_to_session(session_name, command)
|
||||
print(f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}"
|
||||
)
|
||||
|
||||
|
||||
def show_session_log(args):
|
||||
"""Show the full log/output of a session."""
|
||||
@ -131,19 +154,20 @@ def show_session_log(args):
|
||||
print(f"{Colors.BOLD}Full log for session: {session_name}{Colors.RESET}")
|
||||
print("=" * 80)
|
||||
|
||||
if output['stdout']:
|
||||
if output["stdout"]:
|
||||
print(f"{Colors.GRAY}STDOUT:{Colors.RESET}")
|
||||
print(output['stdout'])
|
||||
print(output["stdout"])
|
||||
print()
|
||||
|
||||
if output['stderr']:
|
||||
if output["stderr"]:
|
||||
print(f"{Colors.YELLOW}STDERR:{Colors.RESET}")
|
||||
print(output['stderr'])
|
||||
print(output["stderr"])
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error reading log for '{session_name}': {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def show_session_status(args):
|
||||
"""Show detailed status of a session."""
|
||||
if not args or len(args) < 1:
|
||||
@ -160,11 +184,11 @@ def show_session_status(args):
|
||||
print(f"{Colors.BOLD}Status for session: {session_name}{Colors.RESET}")
|
||||
print("-" * 50)
|
||||
|
||||
metadata = status.get('metadata', {})
|
||||
metadata = status.get("metadata", {})
|
||||
print(f"Process type: {metadata.get('process_type', 'unknown')}")
|
||||
print(f"Active: {status.get('is_active', False)}")
|
||||
|
||||
if 'pid' in status:
|
||||
if "pid" in status:
|
||||
print(f"PID: {status['pid']}")
|
||||
|
||||
print(f"Start time: {metadata.get('start_time', 0):.1f}")
|
||||
@ -172,8 +196,10 @@ def show_session_status(args):
|
||||
print(f"Interaction count: {metadata.get('interaction_count', 0)}")
|
||||
print(f"State: {metadata.get('state', 'unknown')}")
|
||||
|
||||
output_summary = status.get('output_summary', {})
|
||||
print(f"Output lines: {output_summary.get('stdout_lines', 0)} stdout, {output_summary.get('stderr_lines', 0)} stderr")
|
||||
output_summary = status.get("output_summary", {})
|
||||
print(
|
||||
f"Output lines: {output_summary.get('stdout_lines', 0)} stdout, {output_summary.get('stderr_lines', 0)} stderr"
|
||||
)
|
||||
|
||||
# Show prompt detection info
|
||||
detector = get_global_detector()
|
||||
@ -182,6 +208,7 @@ def show_session_status(args):
|
||||
print(f"Current state: {session_info['current_state']}")
|
||||
print(f"Is waiting for input: {session_info['is_waiting']}")
|
||||
|
||||
|
||||
def list_waiting_sessions(args=None):
|
||||
"""List sessions that appear to be waiting for input."""
|
||||
sessions = list_active_sessions()
|
||||
@ -193,14 +220,16 @@ def list_waiting_sessions(args=None):
|
||||
waiting_sessions.append(session_name)
|
||||
|
||||
if not waiting_sessions:
|
||||
print(f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}"
|
||||
)
|
||||
return
|
||||
|
||||
print(f"{Colors.BOLD}Sessions waiting for input:{Colors.RESET}")
|
||||
for session_name in waiting_sessions:
|
||||
status = get_session_status(session_name)
|
||||
if status:
|
||||
process_type = status.get('metadata', {}).get('process_type', 'unknown')
|
||||
process_type = status.get("metadata", {}).get("process_type", "unknown")
|
||||
print(f" {Colors.CYAN}{session_name}{Colors.RESET} ({process_type})")
|
||||
|
||||
# Show suggestions
|
||||
@ -208,17 +237,20 @@ def list_waiting_sessions(args=None):
|
||||
if session_info:
|
||||
suggestions = detector.get_response_suggestions({}, process_type)
|
||||
if suggestions:
|
||||
print(f" Suggested inputs: {', '.join(suggestions[:3])}") # Show first 3
|
||||
print(
|
||||
f" Suggested inputs: {', '.join(suggestions[:3])}"
|
||||
) # Show first 3
|
||||
print()
|
||||
|
||||
|
||||
# Command registry for the multiplexer commands
|
||||
MULTIPLEXER_COMMANDS = {
|
||||
'show_sessions': show_sessions,
|
||||
'attach_session': attach_session,
|
||||
'detach_session': detach_session,
|
||||
'kill_session': kill_session,
|
||||
'send_command': send_command,
|
||||
'show_session_log': show_session_log,
|
||||
'show_session_status': show_session_status,
|
||||
'list_waiting_sessions': list_waiting_sessions,
|
||||
}
|
||||
"show_sessions": show_sessions,
|
||||
"attach_session": attach_session,
|
||||
"detach_session": detach_session,
|
||||
"kill_session": kill_session,
|
||||
"send_command": send_command,
|
||||
"show_session_log": show_session_log,
|
||||
"show_session_status": show_session_status,
|
||||
"list_waiting_sessions": list_waiting_sessions,
|
||||
}
|
||||
|
||||
97
pr/config.py
97
pr/config.py
@ -27,15 +27,78 @@ CONTENT_TRIM_LENGTH = 30000
|
||||
MAX_TOOL_RESULT_LENGTH = 30000
|
||||
|
||||
LANGUAGE_KEYWORDS = {
|
||||
'python': ['def', 'class', 'import', 'from', 'if', 'else', 'elif', 'for', 'while',
|
||||
'return', 'try', 'except', 'finally', 'with', 'as', 'lambda', 'yield',
|
||||
'None', 'True', 'False', 'and', 'or', 'not', 'in', 'is'],
|
||||
'javascript': ['function', 'var', 'let', 'const', 'if', 'else', 'for', 'while',
|
||||
'return', 'try', 'catch', 'finally', 'class', 'extends', 'new',
|
||||
'this', 'null', 'undefined', 'true', 'false'],
|
||||
'java': ['public', 'private', 'protected', 'class', 'interface', 'extends',
|
||||
'implements', 'static', 'final', 'void', 'int', 'String', 'boolean',
|
||||
'if', 'else', 'for', 'while', 'return', 'try', 'catch', 'finally'],
|
||||
"python": [
|
||||
"def",
|
||||
"class",
|
||||
"import",
|
||||
"from",
|
||||
"if",
|
||||
"else",
|
||||
"elif",
|
||||
"for",
|
||||
"while",
|
||||
"return",
|
||||
"try",
|
||||
"except",
|
||||
"finally",
|
||||
"with",
|
||||
"as",
|
||||
"lambda",
|
||||
"yield",
|
||||
"None",
|
||||
"True",
|
||||
"False",
|
||||
"and",
|
||||
"or",
|
||||
"not",
|
||||
"in",
|
||||
"is",
|
||||
],
|
||||
"javascript": [
|
||||
"function",
|
||||
"var",
|
||||
"let",
|
||||
"const",
|
||||
"if",
|
||||
"else",
|
||||
"for",
|
||||
"while",
|
||||
"return",
|
||||
"try",
|
||||
"catch",
|
||||
"finally",
|
||||
"class",
|
||||
"extends",
|
||||
"new",
|
||||
"this",
|
||||
"null",
|
||||
"undefined",
|
||||
"true",
|
||||
"false",
|
||||
],
|
||||
"java": [
|
||||
"public",
|
||||
"private",
|
||||
"protected",
|
||||
"class",
|
||||
"interface",
|
||||
"extends",
|
||||
"implements",
|
||||
"static",
|
||||
"final",
|
||||
"void",
|
||||
"int",
|
||||
"String",
|
||||
"boolean",
|
||||
"if",
|
||||
"else",
|
||||
"for",
|
||||
"while",
|
||||
"return",
|
||||
"try",
|
||||
"catch",
|
||||
"finally",
|
||||
],
|
||||
}
|
||||
|
||||
CACHE_ENABLED = True
|
||||
@ -70,18 +133,18 @@ MAX_CONCURRENT_SESSIONS = 10
|
||||
|
||||
# Process-specific timeouts (seconds)
|
||||
PROCESS_TIMEOUTS = {
|
||||
'default': 300, # 5 minutes
|
||||
'apt': 600, # 10 minutes
|
||||
'ssh': 60, # 1 minute
|
||||
'vim': 3600, # 1 hour
|
||||
'git': 300, # 5 minutes
|
||||
'npm': 600, # 10 minutes
|
||||
'pip': 300, # 5 minutes
|
||||
"default": 300, # 5 minutes
|
||||
"apt": 600, # 10 minutes
|
||||
"ssh": 60, # 1 minute
|
||||
"vim": 3600, # 1 hour
|
||||
"git": 300, # 5 minutes
|
||||
"npm": 600, # 10 minutes
|
||||
"pip": 300, # 5 minutes
|
||||
}
|
||||
|
||||
# Activity thresholds for LLM notification
|
||||
HIGH_OUTPUT_THRESHOLD = 50 # lines
|
||||
INACTIVE_THRESHOLD = 300 # seconds
|
||||
INACTIVE_THRESHOLD = 300 # seconds
|
||||
SESSION_NOTIFY_INTERVAL = 60 # seconds
|
||||
|
||||
# Autonomous behavior flags
|
||||
|
||||
@ -1,5 +1,11 @@
|
||||
from pr.core.assistant import Assistant
|
||||
from pr.core.api import call_api, list_models
|
||||
from pr.core.assistant import Assistant
|
||||
from pr.core.context import init_system_message, manage_context_window
|
||||
|
||||
__all__ = ['Assistant', 'call_api', 'list_models', 'init_system_message', 'manage_context_window']
|
||||
__all__ = [
|
||||
"Assistant",
|
||||
"call_api",
|
||||
"list_models",
|
||||
"init_system_message",
|
||||
"manage_context_window",
|
||||
]
|
||||
|
||||
@ -1,20 +1,20 @@
|
||||
import re
|
||||
import math
|
||||
from typing import List, Dict, Any
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class AdvancedContextManager:
|
||||
def __init__(self, knowledge_store=None, conversation_memory=None):
|
||||
self.knowledge_store = knowledge_store
|
||||
self.conversation_memory = conversation_memory
|
||||
|
||||
def adaptive_context_window(self, messages: List[Dict[str, Any]],
|
||||
task_complexity: str = 'medium') -> int:
|
||||
def adaptive_context_window(
|
||||
self, messages: List[Dict[str, Any]], task_complexity: str = "medium"
|
||||
) -> int:
|
||||
complexity_thresholds = {
|
||||
'simple': 10,
|
||||
'medium': 20,
|
||||
'complex': 35,
|
||||
'very_complex': 50
|
||||
"simple": 10,
|
||||
"medium": 20,
|
||||
"complex": 35,
|
||||
"very_complex": 50,
|
||||
}
|
||||
|
||||
base_threshold = complexity_thresholds.get(task_complexity, 20)
|
||||
@ -31,17 +31,19 @@ class AdvancedContextManager:
|
||||
return max(base_threshold, adjusted)
|
||||
|
||||
def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float:
|
||||
total_length = sum(len(msg.get('content', '')) for msg in messages)
|
||||
total_length = sum(len(msg.get("content", "")) for msg in messages)
|
||||
avg_length = total_length / len(messages) if messages else 0
|
||||
|
||||
|
||||
unique_words = set()
|
||||
for msg in messages:
|
||||
content = msg.get('content', '')
|
||||
words = re.findall(r'\b\w+\b', content.lower())
|
||||
content = msg.get("content", "")
|
||||
words = re.findall(r"\b\w+\b", content.lower())
|
||||
unique_words.update(words)
|
||||
|
||||
vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0
|
||||
|
||||
|
||||
vocabulary_richness = (
|
||||
len(unique_words) / total_length if total_length > 0 else 0
|
||||
)
|
||||
|
||||
# Simple complexity score based on length and richness
|
||||
complexity = min(1.0, (avg_length / 100) + vocabulary_richness)
|
||||
return complexity
|
||||
@ -49,10 +51,10 @@ class AdvancedContextManager:
|
||||
def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]:
|
||||
if not text.strip():
|
||||
return []
|
||||
sentences = re.split(r'(?<=[.!?])\s+', text)
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text)
|
||||
if not sentences:
|
||||
return []
|
||||
|
||||
|
||||
# Simple scoring based on length and position
|
||||
scored_sentences = []
|
||||
for i, sentence in enumerate(sentences):
|
||||
@ -60,25 +62,25 @@ class AdvancedContextManager:
|
||||
position_score = 1.0 if i == 0 else 0.8 if i < len(sentences) / 2 else 0.6
|
||||
score = (length_score + position_score) / 2
|
||||
scored_sentences.append((sentence, score))
|
||||
|
||||
|
||||
scored_sentences.sort(key=lambda x: x[1], reverse=True)
|
||||
return [s[0] for s in scored_sentences[:top_k]]
|
||||
|
||||
def advanced_summarize_messages(self, messages: List[Dict[str, Any]]) -> str:
|
||||
all_content = ' '.join([msg.get('content', '') for msg in messages])
|
||||
all_content = " ".join([msg.get("content", "") for msg in messages])
|
||||
key_sentences = self.extract_key_sentences(all_content, top_k=3)
|
||||
summary = ' '.join(key_sentences)
|
||||
summary = " ".join(key_sentences)
|
||||
return summary if summary else "No content to summarize."
|
||||
|
||||
def score_message_relevance(self, message: Dict[str, Any], context: str) -> float:
|
||||
content = message.get('content', '')
|
||||
content_words = set(re.findall(r'\b\w+\b', content.lower()))
|
||||
context_words = set(re.findall(r'\b\w+\b', context.lower()))
|
||||
|
||||
content = message.get("content", "")
|
||||
content_words = set(re.findall(r"\b\w+\b", content.lower()))
|
||||
context_words = set(re.findall(r"\b\w+\b", context.lower()))
|
||||
|
||||
intersection = content_words & context_words
|
||||
union = content_words | context_words
|
||||
|
||||
|
||||
if not union:
|
||||
return 0.0
|
||||
|
||||
return len(intersection) / len(union)
|
||||
|
||||
return len(intersection) / len(union)
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
import json
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import logging
|
||||
from pr.config import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
from pr.config import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE
|
||||
from pr.core.context import auto_slim_messages
|
||||
|
||||
logger = logging.getLogger('pr')
|
||||
logger = logging.getLogger("pr")
|
||||
|
||||
def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
|
||||
|
||||
def call_api(
|
||||
messages, model, api_url, api_key, use_tools, tools_definition, verbose=False
|
||||
):
|
||||
try:
|
||||
messages = auto_slim_messages(messages, verbose=verbose)
|
||||
|
||||
@ -17,62 +21,63 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
|
||||
logger.debug(f"Use tools: {use_tools}")
|
||||
logger.debug(f"Message count: {len(messages)}")
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
data = {
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'temperature': DEFAULT_TEMPERATURE,
|
||||
'max_tokens': DEFAULT_MAX_TOKENS
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": DEFAULT_TEMPERATURE,
|
||||
"max_tokens": DEFAULT_MAX_TOKENS,
|
||||
}
|
||||
|
||||
if "gpt-5" in model:
|
||||
del data['temperature']
|
||||
del data['max_tokens']
|
||||
del data["temperature"]
|
||||
del data["max_tokens"]
|
||||
logger.debug("GPT-5 detected: removed temperature and max_tokens")
|
||||
|
||||
if use_tools:
|
||||
data['tools'] = tools_definition
|
||||
data['tool_choice'] = 'auto'
|
||||
data["tools"] = tools_definition
|
||||
data["tool_choice"] = "auto"
|
||||
logger.debug(f"Tool calling enabled with {len(tools_definition)} tools")
|
||||
|
||||
request_json = json.dumps(data)
|
||||
logger.debug(f"Request payload size: {len(request_json)} bytes")
|
||||
|
||||
req = urllib.request.Request(
|
||||
api_url,
|
||||
data=request_json.encode('utf-8'),
|
||||
headers=headers,
|
||||
method='POST'
|
||||
api_url, data=request_json.encode("utf-8"), headers=headers, method="POST"
|
||||
)
|
||||
|
||||
logger.debug("Sending HTTP request...")
|
||||
with urllib.request.urlopen(req) as response:
|
||||
response_data = response.read().decode('utf-8')
|
||||
response_data = response.read().decode("utf-8")
|
||||
logger.debug(f"Response received: {len(response_data)} bytes")
|
||||
result = json.loads(response_data)
|
||||
|
||||
if 'usage' in result:
|
||||
if "usage" in result:
|
||||
logger.debug(f"Token usage: {result['usage']}")
|
||||
if 'choices' in result and result['choices']:
|
||||
choice = result['choices'][0]
|
||||
if 'message' in choice:
|
||||
msg = choice['message']
|
||||
if "choices" in result and result["choices"]:
|
||||
choice = result["choices"][0]
|
||||
if "message" in choice:
|
||||
msg = choice["message"]
|
||||
logger.debug(f"Response role: {msg.get('role', 'N/A')}")
|
||||
if 'content' in msg and msg['content']:
|
||||
logger.debug(f"Response content length: {len(msg['content'])} chars")
|
||||
if 'tool_calls' in msg:
|
||||
logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")
|
||||
if "content" in msg and msg["content"]:
|
||||
logger.debug(
|
||||
f"Response content length: {len(msg['content'])} chars"
|
||||
)
|
||||
if "tool_calls" in msg:
|
||||
logger.debug(
|
||||
f"Response contains {len(msg['tool_calls'])} tool call(s)"
|
||||
)
|
||||
|
||||
logger.debug("=== API CALL END ===")
|
||||
return result
|
||||
|
||||
except urllib.error.HTTPError as e:
|
||||
error_body = e.read().decode('utf-8')
|
||||
error_body = e.read().decode("utf-8")
|
||||
logger.error(f"API HTTP Error: {e.code} - {error_body}")
|
||||
logger.debug("=== API CALL FAILED ===")
|
||||
return {"error": f"API Error: {e.code}", "message": error_body}
|
||||
@ -81,15 +86,16 @@ def call_api(messages, model, api_url, api_key, use_tools, tools_definition, ver
|
||||
logger.debug("=== API CALL FAILED ===")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
def list_models(model_list_url, api_key):
|
||||
try:
|
||||
req = urllib.request.Request(model_list_url)
|
||||
if api_key:
|
||||
req.add_header('Authorization', f'Bearer {api_key}')
|
||||
req.add_header("Authorization", f"Bearer {api_key}")
|
||||
|
||||
with urllib.request.urlopen(req) as response:
|
||||
data = json.loads(response.read().decode('utf-8'))
|
||||
data = json.loads(response.read().decode("utf-8"))
|
||||
|
||||
return data.get('data', [])
|
||||
return data.get("data", [])
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
@ -1,63 +1,111 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import sqlite3
|
||||
import signal
|
||||
import logging
|
||||
import traceback
|
||||
import readline
|
||||
import glob as glob_module
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import readline
|
||||
import signal
|
||||
import sqlite3
|
||||
import sys
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pr.config import DB_PATH, LOG_FILE, DEFAULT_MODEL, DEFAULT_API_URL, MODEL_LIST_URL, HISTORY_FILE
|
||||
from pr.ui import Colors, render_markdown
|
||||
from pr.core.context import init_system_message, truncate_tool_result
|
||||
|
||||
from pr.commands import handle_command
|
||||
from pr.config import (
|
||||
DB_PATH,
|
||||
DEFAULT_API_URL,
|
||||
DEFAULT_MODEL,
|
||||
HISTORY_FILE,
|
||||
LOG_FILE,
|
||||
MODEL_LIST_URL,
|
||||
)
|
||||
from pr.core.api import call_api
|
||||
from pr.core.autonomous_interactions import (
|
||||
get_global_autonomous,
|
||||
stop_global_autonomous,
|
||||
)
|
||||
from pr.core.background_monitor import (
|
||||
get_global_monitor,
|
||||
start_global_monitor,
|
||||
stop_global_monitor,
|
||||
)
|
||||
from pr.core.context import init_system_message, truncate_tool_result
|
||||
from pr.tools import (
|
||||
http_fetch, run_command, run_command_interactive, read_file, write_file,
|
||||
list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query,
|
||||
web_search, web_search_news, python_exec, index_source_directory,
|
||||
open_editor, editor_insert_text, editor_replace_text, editor_search,
|
||||
search_replace,close_editor,create_diff,apply_patch,
|
||||
tail_process, kill_process
|
||||
apply_patch,
|
||||
chdir,
|
||||
close_editor,
|
||||
create_diff,
|
||||
db_get,
|
||||
db_query,
|
||||
db_set,
|
||||
editor_insert_text,
|
||||
editor_replace_text,
|
||||
editor_search,
|
||||
getpwd,
|
||||
http_fetch,
|
||||
index_source_directory,
|
||||
kill_process,
|
||||
list_directory,
|
||||
mkdir,
|
||||
open_editor,
|
||||
python_exec,
|
||||
read_file,
|
||||
run_command,
|
||||
search_replace,
|
||||
tail_process,
|
||||
web_search,
|
||||
web_search_news,
|
||||
write_file,
|
||||
)
|
||||
from pr.tools.base import get_tools_definition
|
||||
from pr.tools.filesystem import (
|
||||
clear_edit_tracker,
|
||||
display_edit_summary,
|
||||
display_edit_timeline,
|
||||
)
|
||||
from pr.tools.interactive_control import (
|
||||
start_interactive_session, send_input_to_session, read_session_output,
|
||||
list_active_sessions, close_interactive_session
|
||||
close_interactive_session,
|
||||
list_active_sessions,
|
||||
read_session_output,
|
||||
send_input_to_session,
|
||||
start_interactive_session,
|
||||
)
|
||||
from pr.tools.patch import display_file_diff
|
||||
from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker
|
||||
from pr.tools.base import get_tools_definition
|
||||
from pr.commands import handle_command
|
||||
from pr.core.background_monitor import start_global_monitor, stop_global_monitor, get_global_monitor
|
||||
from pr.core.autonomous_interactions import start_global_autonomous, stop_global_autonomous, get_global_autonomous
|
||||
from pr.ui import Colors, render_markdown
|
||||
|
||||
logger = logging.getLogger('pr')
|
||||
logger = logging.getLogger("pr")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
file_handler = logging.FileHandler(LOG_FILE)
|
||||
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
|
||||
file_handler.setFormatter(
|
||||
logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
|
||||
class Assistant:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.messages = []
|
||||
self.verbose = args.verbose
|
||||
self.debug = getattr(args, 'debug', False)
|
||||
self.debug = getattr(args, "debug", False)
|
||||
self.syntax_highlighting = not args.no_syntax
|
||||
|
||||
if self.debug:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
console_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
|
||||
console_handler.setFormatter(
|
||||
logging.Formatter("%(levelname)s: %(message)s")
|
||||
)
|
||||
logger.addHandler(console_handler)
|
||||
logger.debug("Debug mode enabled")
|
||||
self.api_key = os.environ.get('OPENROUTER_API_KEY', '')
|
||||
self.model = args.model or os.environ.get('AI_MODEL', DEFAULT_MODEL)
|
||||
self.api_url = args.api_url or os.environ.get('API_URL', DEFAULT_API_URL)
|
||||
self.model_list_url = args.model_list_url or os.environ.get('MODEL_LIST_URL', MODEL_LIST_URL)
|
||||
self.use_tools = os.environ.get('USE_TOOLS', '1') == '1'
|
||||
self.strict_mode = os.environ.get('STRICT_MODE', '0') == '1'
|
||||
self.api_key = os.environ.get("OPENROUTER_API_KEY", "")
|
||||
self.model = args.model or os.environ.get("AI_MODEL", DEFAULT_MODEL)
|
||||
self.api_url = args.api_url or os.environ.get("API_URL", DEFAULT_API_URL)
|
||||
self.model_list_url = args.model_list_url or os.environ.get(
|
||||
"MODEL_LIST_URL", MODEL_LIST_URL
|
||||
)
|
||||
self.use_tools = os.environ.get("USE_TOOLS", "1") == "1"
|
||||
self.strict_mode = os.environ.get("STRICT_MODE", "0") == "1"
|
||||
self.interrupt_count = 0
|
||||
self.python_globals = {}
|
||||
self.db_conn = None
|
||||
@ -69,6 +117,7 @@ class Assistant:
|
||||
|
||||
try:
|
||||
from pr.core.enhanced_assistant import EnhancedAssistant
|
||||
|
||||
self.enhanced = EnhancedAssistant(self)
|
||||
if self.debug:
|
||||
logger.debug("Enhanced assistant features initialized")
|
||||
@ -94,13 +143,17 @@ class Assistant:
|
||||
self.db_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
|
||||
cursor = self.db_conn.cursor()
|
||||
|
||||
cursor.execute('''CREATE TABLE IF NOT EXISTS kv_store
|
||||
(key TEXT PRIMARY KEY, value TEXT, timestamp REAL)''')
|
||||
cursor.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS kv_store
|
||||
(key TEXT PRIMARY KEY, value TEXT, timestamp REAL)"""
|
||||
)
|
||||
|
||||
cursor.execute('''CREATE TABLE IF NOT EXISTS file_versions
|
||||
cursor.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS file_versions
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
filepath TEXT, content TEXT, hash TEXT,
|
||||
timestamp REAL, version INTEGER)''')
|
||||
timestamp REAL, version INTEGER)"""
|
||||
)
|
||||
|
||||
self.db_conn.commit()
|
||||
logger.debug("Database initialized successfully")
|
||||
@ -110,7 +163,7 @@ class Assistant:
|
||||
|
||||
def _handle_background_updates(self, updates):
|
||||
"""Handle background session updates by injecting them into the conversation."""
|
||||
if not updates or not updates.get('sessions'):
|
||||
if not updates or not updates.get("sessions"):
|
||||
return
|
||||
|
||||
# Format the update as a system message
|
||||
@ -118,10 +171,12 @@ class Assistant:
|
||||
|
||||
# Inject into current conversation if we're in an active session
|
||||
if self.messages and len(self.messages) > 0:
|
||||
self.messages.append({
|
||||
"role": "system",
|
||||
"content": f"Background session updates: {update_message}"
|
||||
})
|
||||
self.messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"Background session updates: {update_message}",
|
||||
}
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print(f"{Colors.CYAN}Background update: {update_message}{Colors.RESET}")
|
||||
@ -130,8 +185,8 @@ class Assistant:
|
||||
"""Format background updates for LLM consumption."""
|
||||
session_summaries = []
|
||||
|
||||
for session_name, session_info in updates.get('sessions', {}).items():
|
||||
summary = session_info.get('summary', f'Session {session_name}')
|
||||
for session_name, session_info in updates.get("sessions", {}).items():
|
||||
summary = session_info.get("summary", f"Session {session_name}")
|
||||
session_summaries.append(f"{session_name}: {summary}")
|
||||
|
||||
if session_summaries:
|
||||
@ -151,30 +206,44 @@ class Assistant:
|
||||
if events:
|
||||
print(f"\n{Colors.CYAN}Background Events:{Colors.RESET}")
|
||||
for event in events:
|
||||
event_type = event.get('type', 'unknown')
|
||||
session_name = event.get('session_name', 'unknown')
|
||||
event_type = event.get("type", "unknown")
|
||||
session_name = event.get("session_name", "unknown")
|
||||
|
||||
if event_type == 'session_started':
|
||||
print(f" {Colors.GREEN}✓{Colors.RESET} Session '{session_name}' started")
|
||||
elif event_type == 'session_ended':
|
||||
print(f" {Colors.YELLOW}✗{Colors.RESET} Session '{session_name}' ended")
|
||||
elif event_type == 'output_received':
|
||||
lines = len(event.get('new_output', {}).get('stdout', []))
|
||||
print(f" {Colors.BLUE}📝{Colors.RESET} Session '{session_name}' produced {lines} lines of output")
|
||||
elif event_type == 'possible_input_needed':
|
||||
print(f" {Colors.RED}❓{Colors.RESET} Session '{session_name}' may need input")
|
||||
elif event_type == 'high_output_volume':
|
||||
total = event.get('total_lines', 0)
|
||||
print(f" {Colors.YELLOW}📊{Colors.RESET} Session '{session_name}' has high output volume ({total} lines)")
|
||||
elif event_type == 'inactive_session':
|
||||
inactive_time = event.get('inactive_seconds', 0)
|
||||
print(f" {Colors.GRAY}⏰{Colors.RESET} Session '{session_name}' inactive for {inactive_time:.0f}s")
|
||||
if event_type == "session_started":
|
||||
print(
|
||||
f" {Colors.GREEN}✓{Colors.RESET} Session '{session_name}' started"
|
||||
)
|
||||
elif event_type == "session_ended":
|
||||
print(
|
||||
f" {Colors.YELLOW}✗{Colors.RESET} Session '{session_name}' ended"
|
||||
)
|
||||
elif event_type == "output_received":
|
||||
lines = len(event.get("new_output", {}).get("stdout", []))
|
||||
print(
|
||||
f" {Colors.BLUE}📝{Colors.RESET} Session '{session_name}' produced {lines} lines of output"
|
||||
)
|
||||
elif event_type == "possible_input_needed":
|
||||
print(
|
||||
f" {Colors.RED}❓{Colors.RESET} Session '{session_name}' may need input"
|
||||
)
|
||||
elif event_type == "high_output_volume":
|
||||
total = event.get("total_lines", 0)
|
||||
print(
|
||||
f" {Colors.YELLOW}📊{Colors.RESET} Session '{session_name}' has high output volume ({total} lines)"
|
||||
)
|
||||
elif event_type == "inactive_session":
|
||||
inactive_time = event.get("inactive_seconds", 0)
|
||||
print(
|
||||
f" {Colors.GRAY}⏰{Colors.RESET} Session '{session_name}' inactive for {inactive_time:.0f}s"
|
||||
)
|
||||
|
||||
print() # Add blank line after events
|
||||
|
||||
except Exception as e:
|
||||
if self.debug:
|
||||
print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}"
|
||||
)
|
||||
|
||||
def execute_tool_calls(self, tool_calls):
|
||||
results = []
|
||||
@ -185,114 +254,147 @@ class Assistant:
|
||||
futures = []
|
||||
|
||||
for tool_call in tool_calls:
|
||||
func_name = tool_call['function']['name']
|
||||
arguments = json.loads(tool_call['function']['arguments'])
|
||||
func_name = tool_call["function"]["name"]
|
||||
arguments = json.loads(tool_call["function"]["arguments"])
|
||||
logger.debug(f"Tool call: {func_name} with arguments: {arguments}")
|
||||
|
||||
func_map = {
|
||||
'http_fetch': lambda **kw: http_fetch(**kw),
|
||||
'run_command': lambda **kw: run_command(**kw),
|
||||
'tail_process': lambda **kw: tail_process(**kw),
|
||||
'kill_process': lambda **kw: kill_process(**kw),
|
||||
'start_interactive_session': lambda **kw: start_interactive_session(**kw),
|
||||
'send_input_to_session': lambda **kw: send_input_to_session(**kw),
|
||||
'read_session_output': lambda **kw: read_session_output(**kw),
|
||||
'close_interactive_session': lambda **kw: close_interactive_session(**kw),
|
||||
'read_file': lambda **kw: read_file(**kw, db_conn=self.db_conn),
|
||||
'write_file': lambda **kw: write_file(**kw, db_conn=self.db_conn),
|
||||
'list_directory': lambda **kw: list_directory(**kw),
|
||||
'mkdir': lambda **kw: mkdir(**kw),
|
||||
'chdir': lambda **kw: chdir(**kw),
|
||||
'getpwd': lambda **kw: getpwd(**kw),
|
||||
'db_set': lambda **kw: db_set(**kw, db_conn=self.db_conn),
|
||||
'db_get': lambda **kw: db_get(**kw, db_conn=self.db_conn),
|
||||
'db_query': lambda **kw: db_query(**kw, db_conn=self.db_conn),
|
||||
'web_search': lambda **kw: web_search(**kw),
|
||||
'web_search_news': lambda **kw: web_search_news(**kw),
|
||||
'python_exec': lambda **kw: python_exec(**kw, python_globals=self.python_globals),
|
||||
'index_source_directory': lambda **kw: index_source_directory(**kw),
|
||||
'search_replace': lambda **kw: search_replace(**kw, db_conn=self.db_conn),
|
||||
'open_editor': lambda **kw: open_editor(**kw),
|
||||
'editor_insert_text': lambda **kw: editor_insert_text(**kw, db_conn=self.db_conn),
|
||||
'editor_replace_text': lambda **kw: editor_replace_text(**kw, db_conn=self.db_conn),
|
||||
'editor_search': lambda **kw: editor_search(**kw),
|
||||
'close_editor': lambda **kw: close_editor(**kw),
|
||||
'create_diff': lambda **kw: create_diff(**kw),
|
||||
'apply_patch': lambda **kw: apply_patch(**kw, db_conn=self.db_conn),
|
||||
'display_file_diff': lambda **kw: display_file_diff(**kw),
|
||||
'display_edit_summary': lambda **kw: display_edit_summary(),
|
||||
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
|
||||
'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
|
||||
'start_interactive_session': lambda **kw: start_interactive_session(**kw),
|
||||
'send_input_to_session': lambda **kw: send_input_to_session(**kw),
|
||||
'read_session_output': lambda **kw: read_session_output(**kw),
|
||||
'list_active_sessions': lambda **kw: list_active_sessions(**kw),
|
||||
'close_interactive_session': lambda **kw: close_interactive_session(**kw),
|
||||
'create_agent': lambda **kw: create_agent(**kw),
|
||||
'list_agents': lambda **kw: list_agents(**kw),
|
||||
'execute_agent_task': lambda **kw: execute_agent_task(**kw),
|
||||
'remove_agent': lambda **kw: remove_agent(**kw),
|
||||
'collaborate_agents': lambda **kw: collaborate_agents(**kw),
|
||||
'add_knowledge_entry': lambda **kw: add_knowledge_entry(**kw),
|
||||
'get_knowledge_entry': lambda **kw: get_knowledge_entry(**kw),
|
||||
'search_knowledge': lambda **kw: search_knowledge(**kw),
|
||||
'get_knowledge_by_category': lambda **kw: get_knowledge_by_category(**kw),
|
||||
'update_knowledge_importance': lambda **kw: update_knowledge_importance(**kw),
|
||||
'delete_knowledge_entry': lambda **kw: delete_knowledge_entry(**kw),
|
||||
'get_knowledge_statistics': lambda **kw: get_knowledge_statistics(**kw),
|
||||
"http_fetch": lambda **kw: http_fetch(**kw),
|
||||
"run_command": lambda **kw: run_command(**kw),
|
||||
"tail_process": lambda **kw: tail_process(**kw),
|
||||
"kill_process": lambda **kw: kill_process(**kw),
|
||||
"start_interactive_session": lambda **kw: start_interactive_session(
|
||||
**kw
|
||||
),
|
||||
"send_input_to_session": lambda **kw: send_input_to_session(**kw),
|
||||
"read_session_output": lambda **kw: read_session_output(**kw),
|
||||
"close_interactive_session": lambda **kw: close_interactive_session(
|
||||
**kw
|
||||
),
|
||||
"read_file": lambda **kw: read_file(**kw, db_conn=self.db_conn),
|
||||
"write_file": lambda **kw: write_file(**kw, db_conn=self.db_conn),
|
||||
"list_directory": lambda **kw: list_directory(**kw),
|
||||
"mkdir": lambda **kw: mkdir(**kw),
|
||||
"chdir": lambda **kw: chdir(**kw),
|
||||
"getpwd": lambda **kw: getpwd(**kw),
|
||||
"db_set": lambda **kw: db_set(**kw, db_conn=self.db_conn),
|
||||
"db_get": lambda **kw: db_get(**kw, db_conn=self.db_conn),
|
||||
"db_query": lambda **kw: db_query(**kw, db_conn=self.db_conn),
|
||||
"web_search": lambda **kw: web_search(**kw),
|
||||
"web_search_news": lambda **kw: web_search_news(**kw),
|
||||
"python_exec": lambda **kw: python_exec(
|
||||
**kw, python_globals=self.python_globals
|
||||
),
|
||||
"index_source_directory": lambda **kw: index_source_directory(**kw),
|
||||
"search_replace": lambda **kw: search_replace(
|
||||
**kw, db_conn=self.db_conn
|
||||
),
|
||||
"open_editor": lambda **kw: open_editor(**kw),
|
||||
"editor_insert_text": lambda **kw: editor_insert_text(
|
||||
**kw, db_conn=self.db_conn
|
||||
),
|
||||
"editor_replace_text": lambda **kw: editor_replace_text(
|
||||
**kw, db_conn=self.db_conn
|
||||
),
|
||||
"editor_search": lambda **kw: editor_search(**kw),
|
||||
"close_editor": lambda **kw: close_editor(**kw),
|
||||
"create_diff": lambda **kw: create_diff(**kw),
|
||||
"apply_patch": lambda **kw: apply_patch(**kw, db_conn=self.db_conn),
|
||||
"display_file_diff": lambda **kw: display_file_diff(**kw),
|
||||
"display_edit_summary": lambda **kw: display_edit_summary(),
|
||||
"display_edit_timeline": lambda **kw: display_edit_timeline(**kw),
|
||||
"clear_edit_tracker": lambda **kw: clear_edit_tracker(),
|
||||
"start_interactive_session": lambda **kw: start_interactive_session(
|
||||
**kw
|
||||
),
|
||||
"send_input_to_session": lambda **kw: send_input_to_session(**kw),
|
||||
"read_session_output": lambda **kw: read_session_output(**kw),
|
||||
"list_active_sessions": lambda **kw: list_active_sessions(**kw),
|
||||
"close_interactive_session": lambda **kw: close_interactive_session(
|
||||
**kw
|
||||
),
|
||||
"create_agent": lambda **kw: create_agent(**kw),
|
||||
"list_agents": lambda **kw: list_agents(**kw),
|
||||
"execute_agent_task": lambda **kw: execute_agent_task(**kw),
|
||||
"remove_agent": lambda **kw: remove_agent(**kw),
|
||||
"collaborate_agents": lambda **kw: collaborate_agents(**kw),
|
||||
"add_knowledge_entry": lambda **kw: add_knowledge_entry(**kw),
|
||||
"get_knowledge_entry": lambda **kw: get_knowledge_entry(**kw),
|
||||
"search_knowledge": lambda **kw: search_knowledge(**kw),
|
||||
"get_knowledge_by_category": lambda **kw: get_knowledge_by_category(
|
||||
**kw
|
||||
),
|
||||
"update_knowledge_importance": lambda **kw: update_knowledge_importance(
|
||||
**kw
|
||||
),
|
||||
"delete_knowledge_entry": lambda **kw: delete_knowledge_entry(**kw),
|
||||
"get_knowledge_statistics": lambda **kw: get_knowledge_statistics(
|
||||
**kw
|
||||
),
|
||||
}
|
||||
|
||||
if func_name in func_map:
|
||||
future = executor.submit(func_map[func_name], **arguments)
|
||||
futures.append((tool_call['id'], future))
|
||||
futures.append((tool_call["id"], future))
|
||||
|
||||
for tool_id, future in futures:
|
||||
try:
|
||||
result = future.result(timeout=30)
|
||||
result = truncate_tool_result(result)
|
||||
logger.debug(f"Tool result for {tool_id}: {str(result)[:200]}...")
|
||||
results.append({
|
||||
"tool_call_id": tool_id,
|
||||
"role": "tool",
|
||||
"content": json.dumps(result)
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"tool_call_id": tool_id,
|
||||
"role": "tool",
|
||||
"content": json.dumps(result),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Tool error for {tool_id}: {str(e)}")
|
||||
error_msg = str(e)[:200] if len(str(e)) > 200 else str(e)
|
||||
results.append({
|
||||
"tool_call_id": tool_id,
|
||||
"role": "tool",
|
||||
"content": json.dumps({"status": "error", "error": error_msg})
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"tool_call_id": tool_id,
|
||||
"role": "tool",
|
||||
"content": json.dumps(
|
||||
{"status": "error", "error": error_msg}
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def process_response(self, response):
|
||||
if 'error' in response:
|
||||
if "error" in response:
|
||||
return f"Error: {response['error']}"
|
||||
|
||||
if 'choices' not in response or not response['choices']:
|
||||
if "choices" not in response or not response["choices"]:
|
||||
return "No response from API"
|
||||
|
||||
message = response['choices'][0]['message']
|
||||
message = response["choices"][0]["message"]
|
||||
self.messages.append(message)
|
||||
|
||||
if 'tool_calls' in message and message['tool_calls']:
|
||||
if "tool_calls" in message and message["tool_calls"]:
|
||||
if self.verbose:
|
||||
print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}")
|
||||
|
||||
tool_results = self.execute_tool_calls(message['tool_calls'])
|
||||
tool_results = self.execute_tool_calls(message["tool_calls"])
|
||||
|
||||
for result in tool_results:
|
||||
self.messages.append(result)
|
||||
|
||||
follow_up = call_api(
|
||||
self.messages, self.model, self.api_url, self.api_key,
|
||||
self.use_tools, get_tools_definition(), verbose=self.verbose
|
||||
self.messages,
|
||||
self.model,
|
||||
self.api_url,
|
||||
self.api_key,
|
||||
self.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=self.verbose,
|
||||
)
|
||||
return self.process_response(follow_up)
|
||||
|
||||
content = message.get('content', '')
|
||||
content = message.get("content", "")
|
||||
return render_markdown(content, self.syntax_highlighting)
|
||||
|
||||
def signal_handler(self, signum, frame):
|
||||
@ -303,7 +405,9 @@ class Assistant:
|
||||
self.autonomous_mode = False
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}")
|
||||
print(
|
||||
f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}"
|
||||
)
|
||||
return
|
||||
|
||||
self.interrupt_count += 1
|
||||
@ -323,21 +427,34 @@ class Assistant:
|
||||
readline.set_history_length(1000)
|
||||
|
||||
import atexit
|
||||
|
||||
atexit.register(readline.write_history_file, HISTORY_FILE)
|
||||
|
||||
commands = ['exit', 'quit', 'help', 'reset', 'dump', 'verbose',
|
||||
'models', 'tools', 'review', 'refactor', 'obfuscate', '/auto']
|
||||
commands = [
|
||||
"exit",
|
||||
"quit",
|
||||
"help",
|
||||
"reset",
|
||||
"dump",
|
||||
"verbose",
|
||||
"models",
|
||||
"tools",
|
||||
"review",
|
||||
"refactor",
|
||||
"obfuscate",
|
||||
"/auto",
|
||||
]
|
||||
|
||||
def completer(text, state):
|
||||
options = [cmd for cmd in commands if cmd.startswith(text)]
|
||||
|
||||
glob_pattern = os.path.expanduser(text) + '*'
|
||||
glob_pattern = os.path.expanduser(text) + "*"
|
||||
path_options = glob_module.glob(glob_pattern)
|
||||
|
||||
path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options]
|
||||
|
||||
combined_options = sorted(list(set(options + path_options)))
|
||||
#combined_options.extend(self.commands)
|
||||
# combined_options.extend(self.commands)
|
||||
|
||||
if state < len(combined_options):
|
||||
return combined_options[state]
|
||||
@ -345,10 +462,10 @@ class Assistant:
|
||||
return None
|
||||
|
||||
delims = readline.get_completer_delims()
|
||||
readline.set_completer_delims(delims.replace('/', ''))
|
||||
readline.set_completer_delims(delims.replace("/", ""))
|
||||
|
||||
readline.set_completer(completer)
|
||||
readline.parse_and_bind('tab: complete')
|
||||
readline.parse_and_bind("tab: complete")
|
||||
|
||||
def run_repl(self):
|
||||
self.setup_readline()
|
||||
@ -368,8 +485,11 @@ class Assistant:
|
||||
if self.background_monitoring:
|
||||
try:
|
||||
from pr.multiplexer import get_all_sessions
|
||||
|
||||
sessions = get_all_sessions()
|
||||
active_count = sum(1 for s in sessions.values() if s.get('status') == 'running')
|
||||
active_count = sum(
|
||||
1 for s in sessions.values() if s.get("status") == "running"
|
||||
)
|
||||
if active_count > 0:
|
||||
prompt += f"[{active_count}bg]"
|
||||
except:
|
||||
@ -405,10 +525,11 @@ class Assistant:
|
||||
message = sys.stdin.read()
|
||||
|
||||
from pr.autonomous.mode import run_autonomous_mode
|
||||
|
||||
run_autonomous_mode(self, message)
|
||||
|
||||
def cleanup(self):
|
||||
if hasattr(self, 'enhanced') and self.enhanced:
|
||||
if hasattr(self, "enhanced") and self.enhanced:
|
||||
try:
|
||||
self.enhanced.cleanup()
|
||||
except Exception as e:
|
||||
@ -424,6 +545,7 @@ class Assistant:
|
||||
|
||||
try:
|
||||
from pr.multiplexer import cleanup_all_multiplexers
|
||||
|
||||
cleanup_all_multiplexers()
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up multiplexers: {e}")
|
||||
@ -433,7 +555,9 @@ class Assistant:
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
print(f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}")
|
||||
print(
|
||||
f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}"
|
||||
)
|
||||
if self.args.interactive or (not self.args.message and sys.stdin.isatty()):
|
||||
print("DEBUG: calling run_repl")
|
||||
self.run_repl()
|
||||
@ -443,6 +567,7 @@ class Assistant:
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
|
||||
def process_message(assistant, message):
|
||||
assistant.messages.append({"role": "user", "content": message})
|
||||
|
||||
@ -453,9 +578,13 @@ def process_message(assistant, message):
|
||||
print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}")
|
||||
|
||||
response = call_api(
|
||||
assistant.messages, assistant.model, assistant.api_url,
|
||||
assistant.api_key, assistant.use_tools, get_tools_definition(),
|
||||
verbose=assistant.verbose
|
||||
assistant.messages,
|
||||
assistant.model,
|
||||
assistant.api_url,
|
||||
assistant.api_key,
|
||||
assistant.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=assistant.verbose,
|
||||
)
|
||||
result = assistant.process_response(response)
|
||||
|
||||
|
||||
@ -1,7 +1,12 @@
|
||||
import time
|
||||
import threading
|
||||
from pr.core.background_monitor import get_global_monitor
|
||||
from pr.tools.interactive_control import list_active_sessions, get_session_status, read_session_output
|
||||
import time
|
||||
|
||||
from pr.tools.interactive_control import (
|
||||
get_session_status,
|
||||
list_active_sessions,
|
||||
read_session_output,
|
||||
)
|
||||
|
||||
|
||||
class AutonomousInteractions:
|
||||
def __init__(self, interaction_interval=10.0):
|
||||
@ -16,7 +21,9 @@ class AutonomousInteractions:
|
||||
self.llm_callback = llm_callback
|
||||
if self.interaction_thread is None:
|
||||
self.active = True
|
||||
self.interaction_thread = threading.Thread(target=self._interaction_loop, daemon=True)
|
||||
self.interaction_thread = threading.Thread(
|
||||
target=self._interaction_loop, daemon=True
|
||||
)
|
||||
self.interaction_thread.start()
|
||||
|
||||
def stop(self):
|
||||
@ -48,7 +55,9 @@ class AutonomousInteractions:
|
||||
if not sessions:
|
||||
return # No active sessions
|
||||
|
||||
sessions_needing_attention = self._identify_sessions_needing_attention(sessions)
|
||||
sessions_needing_attention = self._identify_sessions_needing_attention(
|
||||
sessions
|
||||
)
|
||||
|
||||
if sessions_needing_attention and self.llm_callback:
|
||||
# Format session updates for LLM
|
||||
@ -63,26 +72,30 @@ class AutonomousInteractions:
|
||||
needing_attention = []
|
||||
|
||||
for session_name, session_data in sessions.items():
|
||||
metadata = session_data['metadata']
|
||||
output_summary = session_data['output_summary']
|
||||
metadata = session_data["metadata"]
|
||||
output_summary = session_data["output_summary"]
|
||||
|
||||
# Criteria for needing attention:
|
||||
|
||||
# 1. Recent output activity
|
||||
time_since_activity = time.time() - metadata.get('last_activity', 0)
|
||||
time_since_activity = time.time() - metadata.get("last_activity", 0)
|
||||
if time_since_activity < 30: # Activity in last 30 seconds
|
||||
needing_attention.append(session_name)
|
||||
continue
|
||||
|
||||
# 2. High output volume (potential completion or error)
|
||||
total_lines = output_summary['stdout_lines'] + output_summary['stderr_lines']
|
||||
total_lines = (
|
||||
output_summary["stdout_lines"] + output_summary["stderr_lines"]
|
||||
)
|
||||
if total_lines > 50: # Arbitrary threshold
|
||||
needing_attention.append(session_name)
|
||||
continue
|
||||
|
||||
# 3. Long-running sessions that might need intervention
|
||||
session_age = time.time() - metadata.get('start_time', 0)
|
||||
if session_age > 300 and time_since_activity > 60: # 5+ minutes old, inactive for 1+ minute
|
||||
session_age = time.time() - metadata.get("start_time", 0)
|
||||
if (
|
||||
session_age > 300 and time_since_activity > 60
|
||||
): # 5+ minutes old, inactive for 1+ minute
|
||||
needing_attention.append(session_name)
|
||||
continue
|
||||
|
||||
@ -95,18 +108,18 @@ class AutonomousInteractions:
|
||||
|
||||
def _session_looks_stuck(self, session_name, session_data):
|
||||
"""Determine if a session appears to be stuck waiting for input."""
|
||||
metadata = session_data['metadata']
|
||||
metadata = session_data["metadata"]
|
||||
|
||||
# Check if process is still running
|
||||
status = get_session_status(session_name)
|
||||
if not status or not status.get('is_active', False):
|
||||
if not status or not status.get("is_active", False):
|
||||
return False
|
||||
|
||||
time_since_activity = time.time() - metadata.get('last_activity', 0)
|
||||
interaction_count = metadata.get('interaction_count', 0)
|
||||
time_since_activity = time.time() - metadata.get("last_activity", 0)
|
||||
interaction_count = metadata.get("interaction_count", 0)
|
||||
|
||||
# If running for a while but no interactions, might be waiting
|
||||
session_age = time.time() - metadata.get('start_time', 0)
|
||||
session_age = time.time() - metadata.get("start_time", 0)
|
||||
if session_age > 60 and interaction_count == 0 and time_since_activity > 30:
|
||||
return True
|
||||
|
||||
@ -119,9 +132,9 @@ class AutonomousInteractions:
|
||||
def _format_session_updates(self, session_names):
|
||||
"""Format session information for LLM consumption."""
|
||||
updates = {
|
||||
'type': 'background_session_updates',
|
||||
'timestamp': time.time(),
|
||||
'sessions': {}
|
||||
"type": "background_session_updates",
|
||||
"timestamp": time.time(),
|
||||
"sessions": {},
|
||||
}
|
||||
|
||||
for session_name in session_names:
|
||||
@ -131,12 +144,12 @@ class AutonomousInteractions:
|
||||
try:
|
||||
recent_output = read_session_output(session_name, lines=20)
|
||||
except:
|
||||
recent_output = {'stdout': '', 'stderr': ''}
|
||||
recent_output = {"stdout": "", "stderr": ""}
|
||||
|
||||
updates['sessions'][session_name] = {
|
||||
'status': status,
|
||||
'recent_output': recent_output,
|
||||
'summary': self._create_session_summary(status, recent_output)
|
||||
updates["sessions"][session_name] = {
|
||||
"status": status,
|
||||
"recent_output": recent_output,
|
||||
"summary": self._create_session_summary(status, recent_output),
|
||||
}
|
||||
|
||||
return updates
|
||||
@ -145,34 +158,39 @@ class AutonomousInteractions:
|
||||
"""Create a human-readable summary of session status."""
|
||||
summary_parts = []
|
||||
|
||||
process_type = status.get('metadata', {}).get('process_type', 'unknown')
|
||||
process_type = status.get("metadata", {}).get("process_type", "unknown")
|
||||
summary_parts.append(f"Type: {process_type}")
|
||||
|
||||
is_active = status.get('is_active', False)
|
||||
is_active = status.get("is_active", False)
|
||||
summary_parts.append(f"Status: {'Active' if is_active else 'Inactive'}")
|
||||
|
||||
if is_active and 'pid' in status:
|
||||
if is_active and "pid" in status:
|
||||
summary_parts.append(f"PID: {status['pid']}")
|
||||
|
||||
age = time.time() - status.get('metadata', {}).get('start_time', 0)
|
||||
age = time.time() - status.get("metadata", {}).get("start_time", 0)
|
||||
summary_parts.append(f"Age: {age:.1f}s")
|
||||
|
||||
output_lines = len(recent_output.get('stdout', '').split('\n')) + len(recent_output.get('stderr', '').split('\n'))
|
||||
output_lines = len(recent_output.get("stdout", "").split("\n")) + len(
|
||||
recent_output.get("stderr", "").split("\n")
|
||||
)
|
||||
summary_parts.append(f"Recent output: {output_lines} lines")
|
||||
|
||||
interaction_count = status.get('metadata', {}).get('interaction_count', 0)
|
||||
interaction_count = status.get("metadata", {}).get("interaction_count", 0)
|
||||
summary_parts.append(f"Interactions: {interaction_count}")
|
||||
|
||||
return " | ".join(summary_parts)
|
||||
|
||||
|
||||
# Global autonomous interactions instance
|
||||
_global_autonomous = None
|
||||
|
||||
|
||||
def get_global_autonomous():
|
||||
"""Get the global autonomous interactions instance."""
|
||||
global _global_autonomous
|
||||
return _global_autonomous
|
||||
|
||||
|
||||
def start_global_autonomous(llm_callback=None):
|
||||
"""Start global autonomous interactions."""
|
||||
global _global_autonomous
|
||||
@ -181,6 +199,7 @@ def start_global_autonomous(llm_callback=None):
|
||||
_global_autonomous.start(llm_callback)
|
||||
return _global_autonomous
|
||||
|
||||
|
||||
def stop_global_autonomous():
|
||||
"""Stop global autonomous interactions."""
|
||||
global _global_autonomous
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import queue
|
||||
|
||||
from pr.multiplexer import get_all_multiplexer_states, get_multiplexer
|
||||
from pr.tools.interactive_control import get_session_status
|
||||
|
||||
|
||||
class BackgroundMonitor:
|
||||
def __init__(self, check_interval=5.0):
|
||||
@ -17,7 +18,9 @@ class BackgroundMonitor:
|
||||
"""Start the background monitoring thread."""
|
||||
if self.monitor_thread is None:
|
||||
self.active = True
|
||||
self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||
self.monitor_thread = threading.Thread(
|
||||
target=self._monitor_loop, daemon=True
|
||||
)
|
||||
self.monitor_thread.start()
|
||||
|
||||
def stop(self):
|
||||
@ -78,19 +81,18 @@ class BackgroundMonitor:
|
||||
# Check for new sessions
|
||||
for session_name in new_states:
|
||||
if session_name not in old_states:
|
||||
events.append({
|
||||
'type': 'session_started',
|
||||
'session_name': session_name,
|
||||
'metadata': new_states[session_name]['metadata']
|
||||
})
|
||||
events.append(
|
||||
{
|
||||
"type": "session_started",
|
||||
"session_name": session_name,
|
||||
"metadata": new_states[session_name]["metadata"],
|
||||
}
|
||||
)
|
||||
|
||||
# Check for ended sessions
|
||||
for session_name in old_states:
|
||||
if session_name not in new_states:
|
||||
events.append({
|
||||
'type': 'session_ended',
|
||||
'session_name': session_name
|
||||
})
|
||||
events.append({"type": "session_ended", "session_name": session_name})
|
||||
|
||||
# Check for activity in existing sessions
|
||||
for session_name, new_state in new_states.items():
|
||||
@ -98,92 +100,112 @@ class BackgroundMonitor:
|
||||
old_state = old_states[session_name]
|
||||
|
||||
# Check for output changes
|
||||
old_stdout_lines = old_state['output_summary']['stdout_lines']
|
||||
new_stdout_lines = new_state['output_summary']['stdout_lines']
|
||||
old_stderr_lines = old_state['output_summary']['stderr_lines']
|
||||
new_stderr_lines = new_state['output_summary']['stderr_lines']
|
||||
old_stdout_lines = old_state["output_summary"]["stdout_lines"]
|
||||
new_stdout_lines = new_state["output_summary"]["stdout_lines"]
|
||||
old_stderr_lines = old_state["output_summary"]["stderr_lines"]
|
||||
new_stderr_lines = new_state["output_summary"]["stderr_lines"]
|
||||
|
||||
if new_stdout_lines > old_stdout_lines or new_stderr_lines > old_stderr_lines:
|
||||
if (
|
||||
new_stdout_lines > old_stdout_lines
|
||||
or new_stderr_lines > old_stderr_lines
|
||||
):
|
||||
# Get the new output
|
||||
mux = get_multiplexer(session_name)
|
||||
if mux:
|
||||
all_output = mux.get_all_output()
|
||||
new_output = {
|
||||
'stdout': all_output['stdout'].split('\n')[old_stdout_lines:],
|
||||
'stderr': all_output['stderr'].split('\n')[old_stderr_lines:]
|
||||
"stdout": all_output["stdout"].split("\n")[
|
||||
old_stdout_lines:
|
||||
],
|
||||
"stderr": all_output["stderr"].split("\n")[
|
||||
old_stderr_lines:
|
||||
],
|
||||
}
|
||||
|
||||
events.append({
|
||||
'type': 'output_received',
|
||||
'session_name': session_name,
|
||||
'new_output': new_output,
|
||||
'total_lines': {
|
||||
'stdout': new_stdout_lines,
|
||||
'stderr': new_stderr_lines
|
||||
events.append(
|
||||
{
|
||||
"type": "output_received",
|
||||
"session_name": session_name,
|
||||
"new_output": new_output,
|
||||
"total_lines": {
|
||||
"stdout": new_stdout_lines,
|
||||
"stderr": new_stderr_lines,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
# Check for state changes
|
||||
old_metadata = old_state['metadata']
|
||||
new_metadata = new_state['metadata']
|
||||
old_metadata = old_state["metadata"]
|
||||
new_metadata = new_state["metadata"]
|
||||
|
||||
if old_metadata.get('state') != new_metadata.get('state'):
|
||||
events.append({
|
||||
'type': 'state_changed',
|
||||
'session_name': session_name,
|
||||
'old_state': old_metadata.get('state'),
|
||||
'new_state': new_metadata.get('state')
|
||||
})
|
||||
if old_metadata.get("state") != new_metadata.get("state"):
|
||||
events.append(
|
||||
{
|
||||
"type": "state_changed",
|
||||
"session_name": session_name,
|
||||
"old_state": old_metadata.get("state"),
|
||||
"new_state": new_metadata.get("state"),
|
||||
}
|
||||
)
|
||||
|
||||
# Check for process type identification
|
||||
if (old_metadata.get('process_type') == 'unknown' and
|
||||
new_metadata.get('process_type') != 'unknown'):
|
||||
events.append({
|
||||
'type': 'process_identified',
|
||||
'session_name': session_name,
|
||||
'process_type': new_metadata.get('process_type')
|
||||
})
|
||||
if (
|
||||
old_metadata.get("process_type") == "unknown"
|
||||
and new_metadata.get("process_type") != "unknown"
|
||||
):
|
||||
events.append(
|
||||
{
|
||||
"type": "process_identified",
|
||||
"session_name": session_name,
|
||||
"process_type": new_metadata.get("process_type"),
|
||||
}
|
||||
)
|
||||
|
||||
# Check for sessions needing attention (based on heuristics)
|
||||
for session_name, state in new_states.items():
|
||||
metadata = state['metadata']
|
||||
output_summary = state['output_summary']
|
||||
metadata = state["metadata"]
|
||||
output_summary = state["output_summary"]
|
||||
|
||||
# Heuristic: High output volume might indicate completion or error
|
||||
total_lines = output_summary['stdout_lines'] + output_summary['stderr_lines']
|
||||
total_lines = (
|
||||
output_summary["stdout_lines"] + output_summary["stderr_lines"]
|
||||
)
|
||||
if total_lines > 100: # Arbitrary threshold
|
||||
events.append({
|
||||
'type': 'high_output_volume',
|
||||
'session_name': session_name,
|
||||
'total_lines': total_lines
|
||||
})
|
||||
events.append(
|
||||
{
|
||||
"type": "high_output_volume",
|
||||
"session_name": session_name,
|
||||
"total_lines": total_lines,
|
||||
}
|
||||
)
|
||||
|
||||
# Heuristic: Long-running session without recent activity
|
||||
time_since_activity = time.time() - metadata.get('last_activity', 0)
|
||||
time_since_activity = time.time() - metadata.get("last_activity", 0)
|
||||
if time_since_activity > 300: # 5 minutes
|
||||
events.append({
|
||||
'type': 'inactive_session',
|
||||
'session_name': session_name,
|
||||
'inactive_seconds': time_since_activity
|
||||
})
|
||||
events.append(
|
||||
{
|
||||
"type": "inactive_session",
|
||||
"session_name": session_name,
|
||||
"inactive_seconds": time_since_activity,
|
||||
}
|
||||
)
|
||||
|
||||
# Heuristic: Sessions that might be waiting for input
|
||||
# This would be enhanced with prompt detection in later phases
|
||||
if self._might_be_waiting_for_input(session_name, state):
|
||||
events.append({
|
||||
'type': 'possible_input_needed',
|
||||
'session_name': session_name
|
||||
})
|
||||
events.append(
|
||||
{"type": "possible_input_needed", "session_name": session_name}
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
def _might_be_waiting_for_input(self, session_name, state):
|
||||
"""Heuristic to detect if a session might be waiting for input."""
|
||||
metadata = state['metadata']
|
||||
process_type = metadata.get('process_type', 'unknown')
|
||||
metadata = state["metadata"]
|
||||
metadata.get("process_type", "unknown")
|
||||
|
||||
# Simple heuristics based on process type and recent activity
|
||||
time_since_activity = time.time() - metadata.get('last_activity', 0)
|
||||
time_since_activity = time.time() - metadata.get("last_activity", 0)
|
||||
|
||||
# If it's been more than 10 seconds since last activity, might be waiting
|
||||
if time_since_activity > 10:
|
||||
@ -191,9 +213,11 @@ class BackgroundMonitor:
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Global monitor instance
|
||||
_global_monitor = None
|
||||
|
||||
|
||||
def get_global_monitor():
|
||||
"""Get the global background monitor instance."""
|
||||
global _global_monitor
|
||||
@ -201,20 +225,24 @@ def get_global_monitor():
|
||||
_global_monitor = BackgroundMonitor()
|
||||
return _global_monitor
|
||||
|
||||
|
||||
def start_global_monitor():
|
||||
"""Start the global background monitor."""
|
||||
monitor = get_global_monitor()
|
||||
monitor.start()
|
||||
|
||||
|
||||
def stop_global_monitor():
|
||||
"""Stop the global background monitor."""
|
||||
global _global_monitor
|
||||
if _global_monitor:
|
||||
_global_monitor.stop()
|
||||
|
||||
|
||||
# Global monitor instance
|
||||
_global_monitor = None
|
||||
|
||||
|
||||
def start_global_monitor():
|
||||
"""Start the global background monitor."""
|
||||
global _global_monitor
|
||||
@ -223,6 +251,7 @@ def start_global_monitor():
|
||||
_global_monitor.start()
|
||||
return _global_monitor
|
||||
|
||||
|
||||
def stop_global_monitor():
|
||||
"""Stop the global background monitor."""
|
||||
global _global_monitor
|
||||
@ -230,6 +259,7 @@ def stop_global_monitor():
|
||||
_global_monitor.stop()
|
||||
_global_monitor = None
|
||||
|
||||
|
||||
def get_global_monitor():
|
||||
"""Get the global background monitor instance."""
|
||||
global _global_monitor
|
||||
|
||||
@ -1,22 +1,17 @@
|
||||
import os
|
||||
import configparser
|
||||
from typing import Dict, Any
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from pr.core.logging import get_logger
|
||||
|
||||
logger = get_logger('config')
|
||||
logger = get_logger("config")
|
||||
|
||||
CONFIG_FILE = os.path.expanduser("~/.prrc")
|
||||
LOCAL_CONFIG_FILE = ".prrc"
|
||||
|
||||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
config = {
|
||||
'api': {},
|
||||
'autonomous': {},
|
||||
'ui': {},
|
||||
'output': {},
|
||||
'session': {}
|
||||
}
|
||||
config = {"api": {}, "autonomous": {}, "ui": {}, "output": {}, "session": {}}
|
||||
|
||||
global_config = _load_config_file(CONFIG_FILE)
|
||||
local_config = _load_config_file(LOCAL_CONFIG_FILE)
|
||||
@ -55,9 +50,9 @@ def _load_config_file(filepath: str) -> Dict[str, Dict[str, Any]]:
|
||||
def _parse_value(value: str) -> Any:
|
||||
value = value.strip()
|
||||
|
||||
if value.lower() == 'true':
|
||||
if value.lower() == "true":
|
||||
return True
|
||||
if value.lower() == 'false':
|
||||
if value.lower() == "false":
|
||||
return False
|
||||
|
||||
if value.isdigit():
|
||||
@ -99,7 +94,7 @@ max_history = 1000
|
||||
"""
|
||||
|
||||
try:
|
||||
with open(filepath, 'w') as f:
|
||||
with open(filepath, "w") as f:
|
||||
f.write(default_config)
|
||||
logger.info(f"Created default configuration at {filepath}")
|
||||
return True
|
||||
|
||||
@ -1,11 +1,21 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from pr.config import (CONTEXT_FILE, GLOBAL_CONTEXT_FILE, CONTEXT_COMPRESSION_THRESHOLD,
|
||||
RECENT_MESSAGES_TO_KEEP, MAX_TOKENS_LIMIT, CHARS_PER_TOKEN,
|
||||
EMERGENCY_MESSAGES_TO_KEEP, CONTENT_TRIM_LENGTH, MAX_TOOL_RESULT_LENGTH)
|
||||
import os
|
||||
|
||||
from pr.config import (
|
||||
CHARS_PER_TOKEN,
|
||||
CONTENT_TRIM_LENGTH,
|
||||
CONTEXT_COMPRESSION_THRESHOLD,
|
||||
CONTEXT_FILE,
|
||||
EMERGENCY_MESSAGES_TO_KEEP,
|
||||
GLOBAL_CONTEXT_FILE,
|
||||
MAX_TOKENS_LIMIT,
|
||||
MAX_TOOL_RESULT_LENGTH,
|
||||
RECENT_MESSAGES_TO_KEEP,
|
||||
)
|
||||
from pr.ui import Colors
|
||||
|
||||
|
||||
def truncate_tool_result(result, max_length=None):
|
||||
if max_length is None:
|
||||
max_length = MAX_TOOL_RESULT_LENGTH
|
||||
@ -17,24 +27,36 @@ def truncate_tool_result(result, max_length=None):
|
||||
|
||||
if "output" in result_copy and isinstance(result_copy["output"], str):
|
||||
if len(result_copy["output"]) > max_length:
|
||||
result_copy["output"] = result_copy["output"][:max_length] + f"\n... [truncated {len(result_copy['output']) - max_length} chars]"
|
||||
result_copy["output"] = (
|
||||
result_copy["output"][:max_length]
|
||||
+ f"\n... [truncated {len(result_copy['output']) - max_length} chars]"
|
||||
)
|
||||
|
||||
if "content" in result_copy and isinstance(result_copy["content"], str):
|
||||
if len(result_copy["content"]) > max_length:
|
||||
result_copy["content"] = result_copy["content"][:max_length] + f"\n... [truncated {len(result_copy['content']) - max_length} chars]"
|
||||
result_copy["content"] = (
|
||||
result_copy["content"][:max_length]
|
||||
+ f"\n... [truncated {len(result_copy['content']) - max_length} chars]"
|
||||
)
|
||||
|
||||
if "data" in result_copy and isinstance(result_copy["data"], str):
|
||||
if len(result_copy["data"]) > max_length:
|
||||
result_copy["data"] = result_copy["data"][:max_length] + f"\n... [truncated]"
|
||||
result_copy["data"] = (
|
||||
result_copy["data"][:max_length] + f"\n... [truncated]"
|
||||
)
|
||||
|
||||
if "error" in result_copy and isinstance(result_copy["error"], str):
|
||||
if len(result_copy["error"]) > max_length // 2:
|
||||
result_copy["error"] = result_copy["error"][:max_length // 2] + "... [truncated]"
|
||||
result_copy["error"] = (
|
||||
result_copy["error"][: max_length // 2] + "... [truncated]"
|
||||
)
|
||||
|
||||
return result_copy
|
||||
|
||||
|
||||
def init_system_message(args):
|
||||
context_parts = ["""You are a professional AI assistant with access to advanced tools.
|
||||
context_parts = [
|
||||
"""You are a professional AI assistant with access to advanced tools.
|
||||
|
||||
File Operations:
|
||||
- Use RPEditor tools (open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor) for precise file modifications
|
||||
@ -51,14 +73,15 @@ Process Management:
|
||||
Shell Commands:
|
||||
- Be a shell ninja using native OS tools
|
||||
- Prefer standard Unix utilities over complex scripts
|
||||
- Use run_command_interactive for commands requiring user input (vim, nano, etc.)"""]
|
||||
#context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."]
|
||||
- Use run_command_interactive for commands requiring user input (vim, nano, etc.)"""
|
||||
]
|
||||
# context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."]
|
||||
max_context_size = 10000
|
||||
|
||||
if args.include_env:
|
||||
env_context = "Environment Variables:\n"
|
||||
for key, value in os.environ.items():
|
||||
if not key.startswith('_'):
|
||||
if not key.startswith("_"):
|
||||
env_context += f"{key}={value}\n"
|
||||
if len(env_context) > max_context_size:
|
||||
env_context = env_context[:max_context_size] + "\n... [truncated]"
|
||||
@ -67,7 +90,7 @@ Shell Commands:
|
||||
for context_file in [CONTEXT_FILE, GLOBAL_CONTEXT_FILE]:
|
||||
if os.path.exists(context_file):
|
||||
try:
|
||||
with open(context_file, 'r') as f:
|
||||
with open(context_file) as f:
|
||||
content = f.read()
|
||||
if len(content) > max_context_size:
|
||||
content = content[:max_context_size] + "\n... [truncated]"
|
||||
@ -78,7 +101,7 @@ Shell Commands:
|
||||
if args.context:
|
||||
for ctx_file in args.context:
|
||||
try:
|
||||
with open(ctx_file, 'r') as f:
|
||||
with open(ctx_file) as f:
|
||||
content = f.read()
|
||||
if len(content) > max_context_size:
|
||||
content = content[:max_context_size] + "\n... [truncated]"
|
||||
@ -88,22 +111,29 @@ Shell Commands:
|
||||
|
||||
system_message = "\n\n".join(context_parts)
|
||||
if len(system_message) > max_context_size * 3:
|
||||
system_message = system_message[:max_context_size * 3] + "\n... [system message truncated]"
|
||||
system_message = (
|
||||
system_message[: max_context_size * 3] + "\n... [system message truncated]"
|
||||
)
|
||||
|
||||
return {"role": "system", "content": system_message}
|
||||
|
||||
|
||||
def should_compress_context(messages):
|
||||
return len(messages) > CONTEXT_COMPRESSION_THRESHOLD
|
||||
|
||||
|
||||
def compress_context(messages):
|
||||
return manage_context_window(messages, verbose=False)
|
||||
|
||||
|
||||
def manage_context_window(messages, verbose):
|
||||
if len(messages) <= CONTEXT_COMPRESSION_THRESHOLD:
|
||||
return messages
|
||||
|
||||
if verbose:
|
||||
print(f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}"
|
||||
)
|
||||
|
||||
system_message = messages[0]
|
||||
recent_messages = messages[-RECENT_MESSAGES_TO_KEEP:]
|
||||
@ -113,18 +143,21 @@ def manage_context_window(messages, verbose):
|
||||
summary = summarize_messages(middle_messages)
|
||||
summary_message = {
|
||||
"role": "system",
|
||||
"content": f"[Previous conversation summary: {summary}]"
|
||||
"content": f"[Previous conversation summary: {summary}]",
|
||||
}
|
||||
|
||||
new_messages = [system_message, summary_message] + recent_messages
|
||||
|
||||
if verbose:
|
||||
print(f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}"
|
||||
)
|
||||
|
||||
return new_messages
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def summarize_messages(messages):
|
||||
summary_parts = []
|
||||
|
||||
@ -142,6 +175,7 @@ def summarize_messages(messages):
|
||||
|
||||
return " | ".join(summary_parts[:10])
|
||||
|
||||
|
||||
def estimate_tokens(messages):
|
||||
total_chars = 0
|
||||
|
||||
@ -155,6 +189,7 @@ def estimate_tokens(messages):
|
||||
|
||||
return int(estimated_tokens * overhead_multiplier)
|
||||
|
||||
|
||||
def trim_message_content(message, max_length):
|
||||
trimmed_msg = message.copy()
|
||||
|
||||
@ -162,14 +197,22 @@ def trim_message_content(message, max_length):
|
||||
content = trimmed_msg["content"]
|
||||
|
||||
if isinstance(content, str) and len(content) > max_length:
|
||||
trimmed_msg["content"] = content[:max_length] + f"\n... [trimmed {len(content) - max_length} chars]"
|
||||
trimmed_msg["content"] = (
|
||||
content[:max_length]
|
||||
+ f"\n... [trimmed {len(content) - max_length} chars]"
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
trimmed_content = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
trimmed_item = item.copy()
|
||||
if "text" in trimmed_item and len(trimmed_item["text"]) > max_length:
|
||||
trimmed_item["text"] = trimmed_item["text"][:max_length] + f"\n... [trimmed]"
|
||||
if (
|
||||
"text" in trimmed_item
|
||||
and len(trimmed_item["text"]) > max_length
|
||||
):
|
||||
trimmed_item["text"] = (
|
||||
trimmed_item["text"][:max_length] + f"\n... [trimmed]"
|
||||
)
|
||||
trimmed_content.append(trimmed_item)
|
||||
else:
|
||||
trimmed_content.append(item)
|
||||
@ -179,35 +222,61 @@ def trim_message_content(message, max_length):
|
||||
if "content" in trimmed_msg and isinstance(trimmed_msg["content"], str):
|
||||
content = trimmed_msg["content"]
|
||||
if len(content) > MAX_TOOL_RESULT_LENGTH:
|
||||
trimmed_msg["content"] = content[:MAX_TOOL_RESULT_LENGTH] + f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]"
|
||||
trimmed_msg["content"] = (
|
||||
content[:MAX_TOOL_RESULT_LENGTH]
|
||||
+ f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]"
|
||||
)
|
||||
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict):
|
||||
if "output" in parsed and isinstance(parsed["output"], str) and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2:
|
||||
parsed["output"] = parsed["output"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
|
||||
if "content" in parsed and isinstance(parsed["content"], str) and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2:
|
||||
parsed["content"] = parsed["content"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
|
||||
if (
|
||||
"output" in parsed
|
||||
and isinstance(parsed["output"], str)
|
||||
and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2
|
||||
):
|
||||
parsed["output"] = (
|
||||
parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2]
|
||||
+ f"\n... [truncated]"
|
||||
)
|
||||
if (
|
||||
"content" in parsed
|
||||
and isinstance(parsed["content"], str)
|
||||
and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2
|
||||
):
|
||||
parsed["content"] = (
|
||||
parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2]
|
||||
+ f"\n... [truncated]"
|
||||
)
|
||||
trimmed_msg["content"] = json.dumps(parsed)
|
||||
except:
|
||||
pass
|
||||
|
||||
return trimmed_msg
|
||||
|
||||
|
||||
def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
|
||||
if estimate_tokens(messages) <= target_tokens:
|
||||
return messages
|
||||
|
||||
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
|
||||
system_msg = (
|
||||
messages[0] if messages and messages[0].get("role") == "system" else None
|
||||
)
|
||||
start_idx = 1 if system_msg else 0
|
||||
|
||||
recent_messages = messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:]
|
||||
middle_messages = messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
|
||||
recent_messages = (
|
||||
messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:]
|
||||
)
|
||||
middle_messages = (
|
||||
messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
|
||||
)
|
||||
|
||||
trimmed_middle = []
|
||||
for msg in middle_messages:
|
||||
if msg.get("role") == "tool":
|
||||
trimmed_middle.append(trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2))
|
||||
trimmed_middle.append(
|
||||
trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2)
|
||||
)
|
||||
elif msg.get("role") in ["user", "assistant"]:
|
||||
trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH))
|
||||
else:
|
||||
@ -233,6 +302,7 @@ def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
|
||||
|
||||
return ([system_msg] if system_msg else []) + messages[-1:]
|
||||
|
||||
|
||||
def auto_slim_messages(messages, verbose=False):
|
||||
estimated_tokens = estimate_tokens(messages)
|
||||
|
||||
@ -240,29 +310,46 @@ def auto_slim_messages(messages, verbose=False):
|
||||
return messages
|
||||
|
||||
if verbose:
|
||||
print(f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}")
|
||||
print(f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}"
|
||||
)
|
||||
|
||||
result = intelligently_trim_messages(messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP)
|
||||
result = intelligently_trim_messages(
|
||||
messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP
|
||||
)
|
||||
final_tokens = estimate_tokens(result)
|
||||
|
||||
if final_tokens > MAX_TOKENS_LIMIT:
|
||||
if verbose:
|
||||
print(f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}"
|
||||
)
|
||||
result = emergency_reduce_messages(result, MAX_TOKENS_LIMIT, verbose)
|
||||
final_tokens = estimate_tokens(result)
|
||||
|
||||
if verbose:
|
||||
removed_count = len(messages) - len(result)
|
||||
print(f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}")
|
||||
print(f"{Colors.GREEN} Token estimate: {estimated_tokens} → {final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f"{Colors.GREEN} Token estimate: {estimated_tokens} → {final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}"
|
||||
)
|
||||
if removed_count > 0:
|
||||
print(f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def emergency_reduce_messages(messages, target_tokens, verbose=False):
|
||||
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
|
||||
system_msg = (
|
||||
messages[0] if messages and messages[0].get("role") == "system" else None
|
||||
)
|
||||
start_idx = 1 if system_msg else 0
|
||||
|
||||
keep_count = 2
|
||||
|
||||
@ -1,22 +1,29 @@
|
||||
import logging
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pr.config import (
|
||||
DB_PATH, CACHE_ENABLED, API_CACHE_TTL, TOOL_CACHE_TTL,
|
||||
WORKFLOW_EXECUTOR_MAX_WORKERS, AGENT_MAX_WORKERS,
|
||||
KNOWLEDGE_SEARCH_LIMIT, ADVANCED_CONTEXT_ENABLED,
|
||||
MEMORY_AUTO_SUMMARIZE, CONVERSATION_SUMMARY_THRESHOLD
|
||||
)
|
||||
from pr.cache import APICache, ToolCache
|
||||
from pr.workflows import WorkflowEngine, WorkflowStorage
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pr.agents import AgentManager
|
||||
from pr.memory import KnowledgeStore, ConversationMemory, FactExtractor
|
||||
from pr.cache import APICache, ToolCache
|
||||
from pr.config import (
|
||||
ADVANCED_CONTEXT_ENABLED,
|
||||
API_CACHE_TTL,
|
||||
CACHE_ENABLED,
|
||||
CONVERSATION_SUMMARY_THRESHOLD,
|
||||
DB_PATH,
|
||||
KNOWLEDGE_SEARCH_LIMIT,
|
||||
MEMORY_AUTO_SUMMARIZE,
|
||||
TOOL_CACHE_TTL,
|
||||
WORKFLOW_EXECUTOR_MAX_WORKERS,
|
||||
)
|
||||
from pr.core.advanced_context import AdvancedContextManager
|
||||
from pr.core.api import call_api
|
||||
from pr.memory import ConversationMemory, FactExtractor, KnowledgeStore
|
||||
from pr.tools.base import get_tools_definition
|
||||
from pr.workflows import WorkflowEngine, WorkflowStorage
|
||||
|
||||
logger = logging.getLogger("pr")
|
||||
|
||||
logger = logging.getLogger('pr')
|
||||
|
||||
class EnhancedAssistant:
|
||||
def __init__(self, base_assistant):
|
||||
@ -32,7 +39,7 @@ class EnhancedAssistant:
|
||||
self.workflow_storage = WorkflowStorage(DB_PATH)
|
||||
self.workflow_engine = WorkflowEngine(
|
||||
tool_executor=self._execute_tool_for_workflow,
|
||||
max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS
|
||||
max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS,
|
||||
)
|
||||
|
||||
self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent)
|
||||
@ -44,20 +51,21 @@ class EnhancedAssistant:
|
||||
if ADVANCED_CONTEXT_ENABLED:
|
||||
self.context_manager = AdvancedContextManager(
|
||||
knowledge_store=self.knowledge_store,
|
||||
conversation_memory=self.conversation_memory
|
||||
conversation_memory=self.conversation_memory,
|
||||
)
|
||||
else:
|
||||
self.context_manager = None
|
||||
|
||||
self.current_conversation_id = str(uuid.uuid4())[:16]
|
||||
self.conversation_memory.create_conversation(
|
||||
self.current_conversation_id,
|
||||
session_id=str(uuid.uuid4())[:16]
|
||||
self.current_conversation_id, session_id=str(uuid.uuid4())[:16]
|
||||
)
|
||||
|
||||
logger.info("Enhanced Assistant initialized with all features")
|
||||
|
||||
def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
def _execute_tool_for_workflow(
|
||||
self, tool_name: str, arguments: Dict[str, Any]
|
||||
) -> Any:
|
||||
if self.tool_cache:
|
||||
cached_result = self.tool_cache.get(tool_name, arguments)
|
||||
if cached_result is not None:
|
||||
@ -65,41 +73,66 @@ class EnhancedAssistant:
|
||||
return cached_result
|
||||
|
||||
func_map = {
|
||||
'read_file': lambda **kw: self.base.execute_tool_calls([{
|
||||
'id': 'temp',
|
||||
'function': {'name': 'read_file', 'arguments': json.dumps(kw)}
|
||||
}])[0],
|
||||
'write_file': lambda **kw: self.base.execute_tool_calls([{
|
||||
'id': 'temp',
|
||||
'function': {'name': 'write_file', 'arguments': json.dumps(kw)}
|
||||
}])[0],
|
||||
'list_directory': lambda **kw: self.base.execute_tool_calls([{
|
||||
'id': 'temp',
|
||||
'function': {'name': 'list_directory', 'arguments': json.dumps(kw)}
|
||||
}])[0],
|
||||
'run_command': lambda **kw: self.base.execute_tool_calls([{
|
||||
'id': 'temp',
|
||||
'function': {'name': 'run_command', 'arguments': json.dumps(kw)}
|
||||
}])[0],
|
||||
"read_file": lambda **kw: self.base.execute_tool_calls(
|
||||
[
|
||||
{
|
||||
"id": "temp",
|
||||
"function": {"name": "read_file", "arguments": json.dumps(kw)},
|
||||
}
|
||||
]
|
||||
)[0],
|
||||
"write_file": lambda **kw: self.base.execute_tool_calls(
|
||||
[
|
||||
{
|
||||
"id": "temp",
|
||||
"function": {"name": "write_file", "arguments": json.dumps(kw)},
|
||||
}
|
||||
]
|
||||
)[0],
|
||||
"list_directory": lambda **kw: self.base.execute_tool_calls(
|
||||
[
|
||||
{
|
||||
"id": "temp",
|
||||
"function": {
|
||||
"name": "list_directory",
|
||||
"arguments": json.dumps(kw),
|
||||
},
|
||||
}
|
||||
]
|
||||
)[0],
|
||||
"run_command": lambda **kw: self.base.execute_tool_calls(
|
||||
[
|
||||
{
|
||||
"id": "temp",
|
||||
"function": {
|
||||
"name": "run_command",
|
||||
"arguments": json.dumps(kw),
|
||||
},
|
||||
}
|
||||
]
|
||||
)[0],
|
||||
}
|
||||
|
||||
if tool_name in func_map:
|
||||
result = func_map[tool_name](**arguments)
|
||||
|
||||
if self.tool_cache:
|
||||
content = result.get('content', '')
|
||||
content = result.get("content", "")
|
||||
try:
|
||||
parsed_content = json.loads(content) if isinstance(content, str) else content
|
||||
parsed_content = (
|
||||
json.loads(content) if isinstance(content, str) else content
|
||||
)
|
||||
self.tool_cache.set(tool_name, arguments, parsed_content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
return {'error': f'Unknown tool: {tool_name}'}
|
||||
return {"error": f"Unknown tool: {tool_name}"}
|
||||
|
||||
def _api_caller_for_agent(self, messages: List[Dict[str, Any]],
|
||||
temperature: float, max_tokens: int) -> Dict[str, Any]:
|
||||
def _api_caller_for_agent(
|
||||
self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int
|
||||
) -> Dict[str, Any]:
|
||||
return call_api(
|
||||
messages,
|
||||
self.base.model,
|
||||
@ -109,15 +142,12 @@ class EnhancedAssistant:
|
||||
tools=None,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
verbose=self.base.verbose
|
||||
verbose=self.base.verbose,
|
||||
)
|
||||
|
||||
def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
if self.api_cache and CACHE_ENABLED:
|
||||
cached_response = self.api_cache.get(
|
||||
self.base.model, messages,
|
||||
0.7, 4096
|
||||
)
|
||||
cached_response = self.api_cache.get(self.base.model, messages, 0.7, 4096)
|
||||
if cached_response:
|
||||
logger.debug("API cache hit")
|
||||
return cached_response
|
||||
@ -129,15 +159,13 @@ class EnhancedAssistant:
|
||||
self.base.api_key,
|
||||
self.base.use_tools,
|
||||
get_tools_definition(),
|
||||
verbose=self.base.verbose
|
||||
verbose=self.base.verbose,
|
||||
)
|
||||
|
||||
if self.api_cache and CACHE_ENABLED and 'error' not in response:
|
||||
token_count = response.get('usage', {}).get('total_tokens', 0)
|
||||
if self.api_cache and CACHE_ENABLED and "error" not in response:
|
||||
token_count = response.get("usage", {}).get("total_tokens", 0)
|
||||
self.api_cache.set(
|
||||
self.base.model, messages,
|
||||
0.7, 4096,
|
||||
response, token_count
|
||||
self.base.model, messages, 0.7, 4096, response, token_count
|
||||
)
|
||||
|
||||
return response
|
||||
@ -146,35 +174,33 @@ class EnhancedAssistant:
|
||||
self.base.messages.append({"role": "user", "content": user_message})
|
||||
|
||||
self.conversation_memory.add_message(
|
||||
self.current_conversation_id,
|
||||
str(uuid.uuid4())[:16],
|
||||
'user',
|
||||
user_message
|
||||
self.current_conversation_id, str(uuid.uuid4())[:16], "user", user_message
|
||||
)
|
||||
|
||||
if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0:
|
||||
facts = self.fact_extractor.extract_facts(user_message)
|
||||
for fact in facts[:3]:
|
||||
entry_id = str(uuid.uuid4())[:16]
|
||||
from pr.memory import KnowledgeEntry
|
||||
import time
|
||||
|
||||
categories = self.fact_extractor.categorize_content(fact['text'])
|
||||
from pr.memory import KnowledgeEntry
|
||||
|
||||
categories = self.fact_extractor.categorize_content(fact["text"])
|
||||
entry = KnowledgeEntry(
|
||||
entry_id=entry_id,
|
||||
category=categories[0] if categories else 'general',
|
||||
content=fact['text'],
|
||||
metadata={'type': fact['type'], 'confidence': fact['confidence']},
|
||||
category=categories[0] if categories else "general",
|
||||
content=fact["text"],
|
||||
metadata={"type": fact["type"], "confidence": fact["confidence"]},
|
||||
created_at=time.time(),
|
||||
updated_at=time.time()
|
||||
updated_at=time.time(),
|
||||
)
|
||||
self.knowledge_store.add_entry(entry)
|
||||
|
||||
if self.context_manager and ADVANCED_CONTEXT_ENABLED:
|
||||
enhanced_messages, context_info = self.context_manager.create_enhanced_context(
|
||||
self.base.messages,
|
||||
user_message,
|
||||
include_knowledge=True
|
||||
enhanced_messages, context_info = (
|
||||
self.context_manager.create_enhanced_context(
|
||||
self.base.messages, user_message, include_knowledge=True
|
||||
)
|
||||
)
|
||||
|
||||
if self.base.verbose:
|
||||
@ -189,38 +215,40 @@ class EnhancedAssistant:
|
||||
result = self.base.process_response(response)
|
||||
|
||||
if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD:
|
||||
summary = self.context_manager.advanced_summarize_messages(
|
||||
self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
|
||||
) if self.context_manager else "Conversation in progress"
|
||||
summary = (
|
||||
self.context_manager.advanced_summarize_messages(
|
||||
self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
|
||||
)
|
||||
if self.context_manager
|
||||
else "Conversation in progress"
|
||||
)
|
||||
|
||||
topics = self.fact_extractor.categorize_content(summary)
|
||||
self.conversation_memory.update_conversation_summary(
|
||||
self.current_conversation_id,
|
||||
summary,
|
||||
topics
|
||||
self.current_conversation_id, summary, topics
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def execute_workflow(self, workflow_name: str,
|
||||
initial_variables: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
def execute_workflow(
|
||||
self, workflow_name: str, initial_variables: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
workflow = self.workflow_storage.load_workflow_by_name(workflow_name)
|
||||
|
||||
if not workflow:
|
||||
return {'error': f'Workflow "{workflow_name}" not found'}
|
||||
return {"error": f'Workflow "{workflow_name}" not found'}
|
||||
|
||||
context = self.workflow_engine.execute_workflow(workflow, initial_variables)
|
||||
|
||||
execution_id = self.workflow_storage.save_execution(
|
||||
self.workflow_storage.load_workflow_by_name(workflow_name).name,
|
||||
context
|
||||
self.workflow_storage.load_workflow_by_name(workflow_name).name, context
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'execution_id': execution_id,
|
||||
'results': context.step_results,
|
||||
'execution_log': context.execution_log
|
||||
"success": True,
|
||||
"execution_id": execution_id,
|
||||
"results": context.step_results,
|
||||
"execution_log": context.execution_log,
|
||||
}
|
||||
|
||||
def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
|
||||
@ -230,20 +258,22 @@ class EnhancedAssistant:
|
||||
return self.agent_manager.execute_agent_task(agent_id, task)
|
||||
|
||||
def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]:
|
||||
orchestrator_id = self.agent_manager.create_agent('orchestrator')
|
||||
orchestrator_id = self.agent_manager.create_agent("orchestrator")
|
||||
return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)
|
||||
|
||||
def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]:
|
||||
def search_knowledge(
|
||||
self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT
|
||||
) -> List[Any]:
|
||||
return self.knowledge_store.search_entries(query, top_k=limit)
|
||||
|
||||
def get_cache_statistics(self) -> Dict[str, Any]:
|
||||
stats = {}
|
||||
|
||||
if self.api_cache:
|
||||
stats['api_cache'] = self.api_cache.get_statistics()
|
||||
stats["api_cache"] = self.api_cache.get_statistics()
|
||||
|
||||
if self.tool_cache:
|
||||
stats['tool_cache'] = self.tool_cache.get_statistics()
|
||||
stats["tool_cache"] = self.tool_cache.get_statistics()
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
from pr.config import LOG_FILE
|
||||
|
||||
|
||||
@ -9,21 +10,19 @@ def setup_logging(verbose=False):
|
||||
if log_dir and not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
logger = logging.getLogger('pr')
|
||||
logger = logging.getLogger("pr")
|
||||
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
|
||||
|
||||
if logger.handlers:
|
||||
logger.handlers.clear()
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
LOG_FILE,
|
||||
maxBytes=10 * 1024 * 1024,
|
||||
backupCount=5
|
||||
LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5
|
||||
)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
logger.addHandler(file_handler)
|
||||
@ -31,9 +30,7 @@ def setup_logging(verbose=False):
|
||||
if verbose:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_formatter = logging.Formatter(
|
||||
'%(levelname)s: %(message)s'
|
||||
)
|
||||
console_formatter = logging.Formatter("%(levelname)s: %(message)s")
|
||||
console_handler.setFormatter(console_formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
@ -42,5 +39,5 @@ def setup_logging(verbose=False):
|
||||
|
||||
def get_logger(name=None):
|
||||
if name:
|
||||
return logging.getLogger(f'pr.{name}')
|
||||
return logging.getLogger('pr')
|
||||
return logging.getLogger(f"pr.{name}")
|
||||
return logging.getLogger("pr")
|
||||
|
||||
@ -2,9 +2,10 @@ import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pr.core.logging import get_logger
|
||||
|
||||
logger = get_logger('session')
|
||||
logger = get_logger("session")
|
||||
|
||||
SESSIONS_DIR = os.path.expanduser("~/.assistant_sessions")
|
||||
|
||||
@ -14,18 +15,20 @@ class SessionManager:
|
||||
def __init__(self):
|
||||
os.makedirs(SESSIONS_DIR, exist_ok=True)
|
||||
|
||||
def save_session(self, name: str, messages: List[Dict], metadata: Optional[Dict] = None) -> bool:
|
||||
def save_session(
|
||||
self, name: str, messages: List[Dict], metadata: Optional[Dict] = None
|
||||
) -> bool:
|
||||
try:
|
||||
session_file = os.path.join(SESSIONS_DIR, f"{name}.json")
|
||||
|
||||
session_data = {
|
||||
'name': name,
|
||||
'created_at': datetime.now().isoformat(),
|
||||
'messages': messages,
|
||||
'metadata': metadata or {}
|
||||
"name": name,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"messages": messages,
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
|
||||
with open(session_file, 'w') as f:
|
||||
with open(session_file, "w") as f:
|
||||
json.dump(session_data, f, indent=2)
|
||||
|
||||
logger.info(f"Session saved: {name}")
|
||||
@ -43,7 +46,7 @@ class SessionManager:
|
||||
logger.warning(f"Session not found: {name}")
|
||||
return None
|
||||
|
||||
with open(session_file, 'r') as f:
|
||||
with open(session_file) as f:
|
||||
session_data = json.load(f)
|
||||
|
||||
logger.info(f"Session loaded: {name}")
|
||||
@ -58,22 +61,24 @@ class SessionManager:
|
||||
|
||||
try:
|
||||
for filename in os.listdir(SESSIONS_DIR):
|
||||
if filename.endswith('.json'):
|
||||
if filename.endswith(".json"):
|
||||
filepath = os.path.join(SESSIONS_DIR, filename)
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
with open(filepath) as f:
|
||||
data = json.load(f)
|
||||
|
||||
sessions.append({
|
||||
'name': data.get('name', filename[:-5]),
|
||||
'created_at': data.get('created_at', 'unknown'),
|
||||
'message_count': len(data.get('messages', [])),
|
||||
'metadata': data.get('metadata', {})
|
||||
})
|
||||
sessions.append(
|
||||
{
|
||||
"name": data.get("name", filename[:-5]),
|
||||
"created_at": data.get("created_at", "unknown"),
|
||||
"message_count": len(data.get("messages", [])),
|
||||
"metadata": data.get("metadata", {}),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading session file {filename}: {e}")
|
||||
|
||||
sessions.sort(key=lambda x: x['created_at'], reverse=True)
|
||||
sessions.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions: {e}")
|
||||
@ -96,39 +101,39 @@ class SessionManager:
|
||||
logger.error(f"Error deleting session {name}: {e}")
|
||||
return False
|
||||
|
||||
def export_session(self, name: str, output_path: str, format: str = 'json') -> bool:
|
||||
def export_session(self, name: str, output_path: str, format: str = "json") -> bool:
|
||||
session_data = self.load_session(name)
|
||||
if not session_data:
|
||||
return False
|
||||
|
||||
try:
|
||||
if format == 'json':
|
||||
with open(output_path, 'w') as f:
|
||||
if format == "json":
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(session_data, f, indent=2)
|
||||
|
||||
elif format == 'markdown':
|
||||
with open(output_path, 'w') as f:
|
||||
elif format == "markdown":
|
||||
with open(output_path, "w") as f:
|
||||
f.write(f"# Session: {name}\n\n")
|
||||
f.write(f"Created: {session_data['created_at']}\n\n")
|
||||
f.write("---\n\n")
|
||||
|
||||
for msg in session_data['messages']:
|
||||
role = msg.get('role', 'unknown')
|
||||
content = msg.get('content', '')
|
||||
for msg in session_data["messages"]:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
|
||||
f.write(f"## {role.capitalize()}\n\n")
|
||||
f.write(f"{content}\n\n")
|
||||
f.write("---\n\n")
|
||||
|
||||
elif format == 'txt':
|
||||
with open(output_path, 'w') as f:
|
||||
elif format == "txt":
|
||||
with open(output_path, "w") as f:
|
||||
f.write(f"Session: {name}\n")
|
||||
f.write(f"Created: {session_data['created_at']}\n")
|
||||
f.write("=" * 80 + "\n\n")
|
||||
|
||||
for msg in session_data['messages']:
|
||||
role = msg.get('role', 'unknown')
|
||||
content = msg.get('content', '')
|
||||
for msg in session_data["messages"]:
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
|
||||
f.write(f"[{role.upper()}]\n")
|
||||
f.write(f"{content}\n")
|
||||
|
||||
@ -2,20 +2,21 @@ import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pr.core.logging import get_logger
|
||||
|
||||
logger = get_logger('usage')
|
||||
logger = get_logger("usage")
|
||||
|
||||
USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json")
|
||||
|
||||
MODEL_COSTS = {
|
||||
'x-ai/grok-code-fast-1': {'input': 0.0, 'output': 0.0},
|
||||
'gpt-4': {'input': 0.03, 'output': 0.06},
|
||||
'gpt-4-turbo': {'input': 0.01, 'output': 0.03},
|
||||
'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
|
||||
'claude-3-opus': {'input': 0.015, 'output': 0.075},
|
||||
'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
|
||||
'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
|
||||
"x-ai/grok-code-fast-1": {"input": 0.0, "output": 0.0},
|
||||
"gpt-4": {"input": 0.03, "output": 0.06},
|
||||
"gpt-4-turbo": {"input": 0.01, "output": 0.03},
|
||||
"gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
|
||||
"claude-3-opus": {"input": 0.015, "output": 0.075},
|
||||
"claude-3-sonnet": {"input": 0.003, "output": 0.015},
|
||||
"claude-3-haiku": {"input": 0.00025, "output": 0.00125},
|
||||
}
|
||||
|
||||
|
||||
@ -23,12 +24,12 @@ class UsageTracker:
|
||||
|
||||
def __init__(self):
|
||||
self.session_usage = {
|
||||
'requests': 0,
|
||||
'total_tokens': 0,
|
||||
'input_tokens': 0,
|
||||
'output_tokens': 0,
|
||||
'estimated_cost': 0.0,
|
||||
'models_used': {}
|
||||
"requests": 0,
|
||||
"total_tokens": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"estimated_cost": 0.0,
|
||||
"models_used": {},
|
||||
}
|
||||
|
||||
def track_request(
|
||||
@ -36,30 +37,30 @@ class UsageTracker:
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
total_tokens: Optional[int] = None
|
||||
total_tokens: Optional[int] = None,
|
||||
):
|
||||
if total_tokens is None:
|
||||
total_tokens = input_tokens + output_tokens
|
||||
|
||||
self.session_usage['requests'] += 1
|
||||
self.session_usage['total_tokens'] += total_tokens
|
||||
self.session_usage['input_tokens'] += input_tokens
|
||||
self.session_usage['output_tokens'] += output_tokens
|
||||
self.session_usage["requests"] += 1
|
||||
self.session_usage["total_tokens"] += total_tokens
|
||||
self.session_usage["input_tokens"] += input_tokens
|
||||
self.session_usage["output_tokens"] += output_tokens
|
||||
|
||||
if model not in self.session_usage['models_used']:
|
||||
self.session_usage['models_used'][model] = {
|
||||
'requests': 0,
|
||||
'tokens': 0,
|
||||
'cost': 0.0
|
||||
if model not in self.session_usage["models_used"]:
|
||||
self.session_usage["models_used"][model] = {
|
||||
"requests": 0,
|
||||
"tokens": 0,
|
||||
"cost": 0.0,
|
||||
}
|
||||
|
||||
model_usage = self.session_usage['models_used'][model]
|
||||
model_usage['requests'] += 1
|
||||
model_usage['tokens'] += total_tokens
|
||||
model_usage = self.session_usage["models_used"][model]
|
||||
model_usage["requests"] += 1
|
||||
model_usage["tokens"] += total_tokens
|
||||
|
||||
cost = self._calculate_cost(model, input_tokens, output_tokens)
|
||||
model_usage['cost'] += cost
|
||||
self.session_usage['estimated_cost'] += cost
|
||||
model_usage["cost"] += cost
|
||||
self.session_usage["estimated_cost"] += cost
|
||||
|
||||
self._save_to_history(model, input_tokens, output_tokens, cost)
|
||||
|
||||
@ -67,9 +68,11 @@ class UsageTracker:
|
||||
f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}"
|
||||
)
|
||||
|
||||
def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
|
||||
def _calculate_cost(
|
||||
self, model: str, input_tokens: int, output_tokens: int
|
||||
) -> float:
|
||||
if model not in MODEL_COSTS:
|
||||
base_model = model.split('/')[0] if '/' in model else model
|
||||
base_model = model.split("/")[0] if "/" in model else model
|
||||
if base_model not in MODEL_COSTS:
|
||||
logger.warning(f"Unknown model for cost calculation: {model}")
|
||||
return 0.0
|
||||
@ -77,31 +80,35 @@ class UsageTracker:
|
||||
else:
|
||||
costs = MODEL_COSTS[model]
|
||||
|
||||
input_cost = (input_tokens / 1000) * costs['input']
|
||||
output_cost = (output_tokens / 1000) * costs['output']
|
||||
input_cost = (input_tokens / 1000) * costs["input"]
|
||||
output_cost = (output_tokens / 1000) * costs["output"]
|
||||
|
||||
return input_cost + output_cost
|
||||
|
||||
def _save_to_history(self, model: str, input_tokens: int, output_tokens: int, cost: float):
|
||||
def _save_to_history(
|
||||
self, model: str, input_tokens: int, output_tokens: int, cost: float
|
||||
):
|
||||
try:
|
||||
history = []
|
||||
if os.path.exists(USAGE_DB_FILE):
|
||||
with open(USAGE_DB_FILE, 'r') as f:
|
||||
with open(USAGE_DB_FILE) as f:
|
||||
history = json.load(f)
|
||||
|
||||
history.append({
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'model': model,
|
||||
'input_tokens': input_tokens,
|
||||
'output_tokens': output_tokens,
|
||||
'total_tokens': input_tokens + output_tokens,
|
||||
'cost': cost
|
||||
})
|
||||
history.append(
|
||||
{
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": model,
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": input_tokens + output_tokens,
|
||||
"cost": cost,
|
||||
}
|
||||
)
|
||||
|
||||
if len(history) > 10000:
|
||||
history = history[-10000:]
|
||||
|
||||
with open(USAGE_DB_FILE, 'w') as f:
|
||||
with open(USAGE_DB_FILE, "w") as f:
|
||||
json.dump(history, f, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
@ -121,42 +128,34 @@ class UsageTracker:
|
||||
f"Estimated Cost: ${usage['estimated_cost']:.4f}",
|
||||
]
|
||||
|
||||
if usage['models_used']:
|
||||
if usage["models_used"]:
|
||||
lines.append("\nModels Used:")
|
||||
for model, stats in usage['models_used'].items():
|
||||
for model, stats in usage["models_used"].items():
|
||||
lines.append(
|
||||
f" {model}: {stats['requests']} requests, "
|
||||
f"{stats['tokens']:,} tokens, ${stats['cost']:.4f}"
|
||||
)
|
||||
|
||||
return '\n'.join(lines)
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def get_total_usage() -> Dict:
|
||||
if not os.path.exists(USAGE_DB_FILE):
|
||||
return {
|
||||
'total_requests': 0,
|
||||
'total_tokens': 0,
|
||||
'total_cost': 0.0
|
||||
}
|
||||
return {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
|
||||
|
||||
try:
|
||||
with open(USAGE_DB_FILE, 'r') as f:
|
||||
with open(USAGE_DB_FILE) as f:
|
||||
history = json.load(f)
|
||||
|
||||
total_tokens = sum(entry['total_tokens'] for entry in history)
|
||||
total_cost = sum(entry['cost'] for entry in history)
|
||||
total_tokens = sum(entry["total_tokens"] for entry in history)
|
||||
total_cost = sum(entry["cost"] for entry in history)
|
||||
|
||||
return {
|
||||
'total_requests': len(history),
|
||||
'total_tokens': total_tokens,
|
||||
'total_cost': total_cost
|
||||
"total_requests": len(history),
|
||||
"total_tokens": total_tokens,
|
||||
"total_cost": total_cost,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading usage history: {e}")
|
||||
return {
|
||||
'total_requests': 0,
|
||||
'total_tokens': 0,
|
||||
'total_cost': 0.0
|
||||
}
|
||||
return {"total_requests": 0, "total_tokens": 0, "total_cost": 0.0}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from pr.core.exceptions import ValidationError
|
||||
|
||||
|
||||
@ -16,7 +16,9 @@ def validate_file_path(path: str, must_exist: bool = False) -> str:
|
||||
return os.path.abspath(path)
|
||||
|
||||
|
||||
def validate_directory_path(path: str, must_exist: bool = False, create: bool = False) -> str:
|
||||
def validate_directory_path(
|
||||
path: str, must_exist: bool = False, create: bool = False
|
||||
) -> str:
|
||||
if not path:
|
||||
raise ValidationError("Directory path cannot be empty")
|
||||
|
||||
@ -48,7 +50,7 @@ def validate_api_url(url: str) -> str:
|
||||
if not url:
|
||||
raise ValidationError("API URL cannot be empty")
|
||||
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
if not url.startswith(("http://", "https://")):
|
||||
raise ValidationError("API URL must start with http:// or https://")
|
||||
|
||||
return url
|
||||
@ -58,7 +60,7 @@ def validate_session_name(name: str) -> str:
|
||||
if not name:
|
||||
raise ValidationError("Session name cannot be empty")
|
||||
|
||||
invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|']
|
||||
invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]
|
||||
for char in invalid_chars:
|
||||
if char in name:
|
||||
raise ValidationError(f"Session name contains invalid character: {char}")
|
||||
|
||||
385
pr/editor.py
385
pr/editor.py
@ -1,23 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
import atexit
|
||||
import curses
|
||||
import threading
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import pickle
|
||||
import queue
|
||||
import time
|
||||
import atexit
|
||||
import re
|
||||
import signal
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class RPEditor:
|
||||
def __init__(self, filename=None, auto_save=False, timeout=30):
|
||||
"""
|
||||
Initialize RPEditor with enhanced robustness features.
|
||||
|
||||
|
||||
Args:
|
||||
filename: File to edit
|
||||
auto_save: Enable auto-save on exit
|
||||
@ -27,7 +26,7 @@ class RPEditor:
|
||||
self.lines = [""]
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
self.mode = 'normal'
|
||||
self.mode = "normal"
|
||||
self.command = ""
|
||||
self.stdscr = None
|
||||
self.running = False
|
||||
@ -47,7 +46,7 @@ class RPEditor:
|
||||
self._cleanup_registered = False
|
||||
self._original_terminal_state = None
|
||||
self._exception_occurred = False
|
||||
|
||||
|
||||
# Create socket pair with error handling
|
||||
try:
|
||||
self.client_sock, self.server_sock = socket.socketpair()
|
||||
@ -56,10 +55,10 @@ class RPEditor:
|
||||
except Exception as e:
|
||||
self._cleanup()
|
||||
raise RuntimeError(f"Failed to create socket pair: {e}")
|
||||
|
||||
|
||||
# Register cleanup handlers
|
||||
self._register_cleanup()
|
||||
|
||||
|
||||
if filename:
|
||||
self.load_file()
|
||||
|
||||
@ -81,14 +80,14 @@ class RPEditor:
|
||||
try:
|
||||
# Stop the editor
|
||||
self.running = False
|
||||
|
||||
|
||||
# Save if auto-save is enabled
|
||||
if self.auto_save and self.filename and not self._exception_occurred:
|
||||
try:
|
||||
self._save_file()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Clean up curses
|
||||
if self.stdscr:
|
||||
try:
|
||||
@ -103,13 +102,13 @@ class RPEditor:
|
||||
curses.endwin()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Clear screen after curses cleanup
|
||||
try:
|
||||
os.system('clear' if os.name != 'nt' else 'cls')
|
||||
os.system("clear" if os.name != "nt" else "cls")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Close sockets
|
||||
for sock in [self.client_sock, self.server_sock]:
|
||||
if sock:
|
||||
@ -117,12 +116,12 @@ class RPEditor:
|
||||
sock.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Wait for threads to finish
|
||||
for thread in [self.thread, self.socket_thread]:
|
||||
if thread and thread.is_alive():
|
||||
thread.join(timeout=1)
|
||||
|
||||
|
||||
except:
|
||||
pass
|
||||
|
||||
@ -130,12 +129,12 @@ class RPEditor:
|
||||
"""Load file with enhanced error handling."""
|
||||
try:
|
||||
if os.path.exists(self.filename):
|
||||
with open(self.filename, 'r', encoding='utf-8', errors='replace') as f:
|
||||
with open(self.filename, encoding="utf-8", errors="replace") as f:
|
||||
content = f.read()
|
||||
self.lines = content.splitlines() if content else [""]
|
||||
else:
|
||||
self.lines = [""]
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
self.lines = [""]
|
||||
# Don't raise, just use empty content
|
||||
|
||||
@ -144,24 +143,24 @@ class RPEditor:
|
||||
with self.lock:
|
||||
if not self.filename:
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
# Create backup if file exists
|
||||
if os.path.exists(self.filename):
|
||||
backup_name = f"{self.filename}.bak"
|
||||
try:
|
||||
with open(self.filename, 'r', encoding='utf-8') as f:
|
||||
with open(self.filename, encoding="utf-8") as f:
|
||||
backup_content = f.read()
|
||||
with open(backup_name, 'w', encoding='utf-8') as f:
|
||||
with open(backup_name, "w", encoding="utf-8") as f:
|
||||
f.write(backup_content)
|
||||
except:
|
||||
pass # Backup failed, but continue with save
|
||||
|
||||
|
||||
# Save the file
|
||||
with open(self.filename, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(self.lines))
|
||||
with open(self.filename, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(self.lines))
|
||||
return True
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def save_file(self):
|
||||
@ -169,7 +168,7 @@ class RPEditor:
|
||||
if not self.running:
|
||||
return self._save_file()
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "save_file"}))
|
||||
except:
|
||||
return self._save_file() # Fallback to direct save
|
||||
|
||||
@ -177,10 +176,12 @@ class RPEditor:
|
||||
"""Start the editor with enhanced error handling."""
|
||||
if self.running:
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
self.running = True
|
||||
self.socket_thread = threading.Thread(target=self.socket_listener, daemon=True)
|
||||
self.socket_thread = threading.Thread(
|
||||
target=self.socket_listener, daemon=True
|
||||
)
|
||||
self.socket_thread.start()
|
||||
self.thread = threading.Thread(target=self.run, daemon=True)
|
||||
self.thread.start()
|
||||
@ -194,10 +195,10 @@ class RPEditor:
|
||||
"""Stop the editor with proper cleanup."""
|
||||
try:
|
||||
if self.client_sock:
|
||||
self.client_sock.send(pickle.dumps({'command': 'stop'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "stop"}))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
self.running = False
|
||||
time.sleep(0.1) # Give threads time to finish
|
||||
self._cleanup()
|
||||
@ -206,20 +207,20 @@ class RPEditor:
|
||||
"""Run the main editor loop with exception handling."""
|
||||
try:
|
||||
curses.wrapper(self.main_loop)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
self._exception_occurred = True
|
||||
self._cleanup()
|
||||
|
||||
def main_loop(self, stdscr):
|
||||
"""Main editor loop with enhanced error recovery."""
|
||||
self.stdscr = stdscr
|
||||
|
||||
|
||||
try:
|
||||
# Configure curses
|
||||
curses.curs_set(1)
|
||||
self.stdscr.keypad(True)
|
||||
self.stdscr.timeout(100) # Non-blocking with timeout
|
||||
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Process queued commands
|
||||
@ -230,11 +231,11 @@ class RPEditor:
|
||||
self.execute_command(command)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
|
||||
# Draw screen
|
||||
with self.lock:
|
||||
self.draw()
|
||||
|
||||
|
||||
# Handle input
|
||||
try:
|
||||
key = self.stdscr.getch()
|
||||
@ -243,12 +244,12 @@ class RPEditor:
|
||||
self.handle_key(key)
|
||||
except curses.error:
|
||||
pass # Ignore curses errors
|
||||
|
||||
except Exception as e:
|
||||
|
||||
except Exception:
|
||||
# Log error but continue running
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
|
||||
except Exception:
|
||||
self._exception_occurred = True
|
||||
finally:
|
||||
self._cleanup()
|
||||
@ -258,28 +259,28 @@ class RPEditor:
|
||||
try:
|
||||
self.stdscr.clear()
|
||||
height, width = self.stdscr.getmaxyx()
|
||||
|
||||
|
||||
# Draw lines
|
||||
for i, line in enumerate(self.lines):
|
||||
if i >= height - 1:
|
||||
break
|
||||
try:
|
||||
# Handle long lines and special characters
|
||||
display_line = line[:width-1] if len(line) >= width else line
|
||||
display_line = line[: width - 1] if len(line) >= width else line
|
||||
self.stdscr.addstr(i, 0, display_line)
|
||||
except curses.error:
|
||||
pass # Skip lines that can't be displayed
|
||||
|
||||
|
||||
# Draw status line
|
||||
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
|
||||
if self.mode == 'command':
|
||||
status = self.command[:width-1]
|
||||
|
||||
if self.mode == "command":
|
||||
status = self.command[: width - 1]
|
||||
|
||||
try:
|
||||
self.stdscr.addstr(height - 1, 0, status[:width-1])
|
||||
self.stdscr.addstr(height - 1, 0, status[: width - 1])
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
|
||||
# Position cursor
|
||||
cursor_x = min(self.cursor_x, width - 1)
|
||||
cursor_y = min(self.cursor_y, height - 2)
|
||||
@ -287,7 +288,7 @@ class RPEditor:
|
||||
self.stdscr.move(cursor_y, cursor_x)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
|
||||
self.stdscr.refresh()
|
||||
except Exception:
|
||||
pass # Continue even if draw fails
|
||||
@ -295,11 +296,11 @@ class RPEditor:
|
||||
def handle_key(self, key):
|
||||
"""Handle keyboard input with error recovery."""
|
||||
try:
|
||||
if self.mode == 'normal':
|
||||
if self.mode == "normal":
|
||||
self.handle_normal(key)
|
||||
elif self.mode == 'insert':
|
||||
elif self.mode == "insert":
|
||||
self.handle_insert(key)
|
||||
elif self.mode == 'command':
|
||||
elif self.mode == "command":
|
||||
self.handle_command(key)
|
||||
except Exception:
|
||||
pass # Continue on error
|
||||
@ -307,73 +308,73 @@ class RPEditor:
|
||||
def handle_normal(self, key):
|
||||
"""Handle normal mode keys."""
|
||||
try:
|
||||
if key == ord('h') or key == curses.KEY_LEFT:
|
||||
if key == ord("h") or key == curses.KEY_LEFT:
|
||||
self.move_cursor(0, -1)
|
||||
elif key == ord('j') or key == curses.KEY_DOWN:
|
||||
elif key == ord("j") or key == curses.KEY_DOWN:
|
||||
self.move_cursor(1, 0)
|
||||
elif key == ord('k') or key == curses.KEY_UP:
|
||||
elif key == ord("k") or key == curses.KEY_UP:
|
||||
self.move_cursor(-1, 0)
|
||||
elif key == ord('l') or key == curses.KEY_RIGHT:
|
||||
elif key == ord("l") or key == curses.KEY_RIGHT:
|
||||
self.move_cursor(0, 1)
|
||||
elif key == ord('i'):
|
||||
self.mode = 'insert'
|
||||
elif key == ord(':'):
|
||||
self.mode = 'command'
|
||||
elif key == ord("i"):
|
||||
self.mode = "insert"
|
||||
elif key == ord(":"):
|
||||
self.mode = "command"
|
||||
self.command = ":"
|
||||
elif key == ord('x'):
|
||||
elif key == ord("x"):
|
||||
self._delete_char()
|
||||
elif key == ord('a'):
|
||||
elif key == ord("a"):
|
||||
self.cursor_x = min(self.cursor_x + 1, len(self.lines[self.cursor_y]))
|
||||
self.mode = 'insert'
|
||||
elif key == ord('A'):
|
||||
self.mode = "insert"
|
||||
elif key == ord("A"):
|
||||
self.cursor_x = len(self.lines[self.cursor_y])
|
||||
self.mode = 'insert'
|
||||
elif key == ord('o'):
|
||||
self.mode = "insert"
|
||||
elif key == ord("o"):
|
||||
self._insert_line(self.cursor_y + 1, "")
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
self.mode = 'insert'
|
||||
elif key == ord('O'):
|
||||
self.mode = "insert"
|
||||
elif key == ord("O"):
|
||||
self._insert_line(self.cursor_y, "")
|
||||
self.cursor_x = 0
|
||||
self.mode = 'insert'
|
||||
elif key == ord('d') and self.prev_key == ord('d'):
|
||||
self.mode = "insert"
|
||||
elif key == ord("d") and self.prev_key == ord("d"):
|
||||
if self.cursor_y < len(self.lines):
|
||||
self.clipboard = self.lines[self.cursor_y]
|
||||
self._delete_line(self.cursor_y)
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.cursor_y = max(0, len(self.lines) - 1)
|
||||
self.cursor_x = 0
|
||||
elif key == ord('y') and self.prev_key == ord('y'):
|
||||
elif key == ord("y") and self.prev_key == ord("y"):
|
||||
if self.cursor_y < len(self.lines):
|
||||
self.clipboard = self.lines[self.cursor_y]
|
||||
elif key == ord('p'):
|
||||
elif key == ord("p"):
|
||||
self._insert_line(self.cursor_y + 1, self.clipboard)
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
elif key == ord('P'):
|
||||
elif key == ord("P"):
|
||||
self._insert_line(self.cursor_y, self.clipboard)
|
||||
self.cursor_x = 0
|
||||
elif key == ord('w'):
|
||||
elif key == ord("w"):
|
||||
self._move_word_forward()
|
||||
elif key == ord('b'):
|
||||
elif key == ord("b"):
|
||||
self._move_word_backward()
|
||||
elif key == ord('0'):
|
||||
elif key == ord("0"):
|
||||
self.cursor_x = 0
|
||||
elif key == ord('$'):
|
||||
elif key == ord("$"):
|
||||
self.cursor_x = len(self.lines[self.cursor_y])
|
||||
elif key == ord('g'):
|
||||
if self.prev_key == ord('g'):
|
||||
elif key == ord("g"):
|
||||
if self.prev_key == ord("g"):
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
elif key == ord('G'):
|
||||
elif key == ord("G"):
|
||||
self.cursor_y = max(0, len(self.lines) - 1)
|
||||
self.cursor_x = 0
|
||||
elif key == ord('u'):
|
||||
elif key == ord("u"):
|
||||
self.undo()
|
||||
elif key == ord('r') and self.prev_key == 18: # Ctrl-R
|
||||
elif key == ord("r") and self.prev_key == 18: # Ctrl-R
|
||||
self.redo()
|
||||
|
||||
|
||||
self.prev_key = key
|
||||
except Exception:
|
||||
pass
|
||||
@ -410,7 +411,7 @@ class RPEditor:
|
||||
"""Handle insert mode keys."""
|
||||
try:
|
||||
if key == 27: # ESC
|
||||
self.mode = 'normal'
|
||||
self.mode = "normal"
|
||||
if self.cursor_x > 0:
|
||||
self.cursor_x -= 1
|
||||
elif key == 10 or key == 13: # Enter
|
||||
@ -438,10 +439,10 @@ class RPEditor:
|
||||
elif cmd.startswith("w "):
|
||||
self.filename = cmd[2:].strip()
|
||||
self._save_file()
|
||||
self.mode = 'normal'
|
||||
self.mode = "normal"
|
||||
self.command = ""
|
||||
elif key == 27: # ESC
|
||||
self.mode = 'normal'
|
||||
self.mode = "normal"
|
||||
self.command = ""
|
||||
elif key == curses.KEY_BACKSPACE or key == 127 or key == 8:
|
||||
if len(self.command) > 1:
|
||||
@ -449,17 +450,17 @@ class RPEditor:
|
||||
elif 32 <= key <= 126:
|
||||
self.command += chr(key)
|
||||
except Exception:
|
||||
self.mode = 'normal'
|
||||
self.mode = "normal"
|
||||
self.command = ""
|
||||
|
||||
def move_cursor(self, dy, dx):
|
||||
"""Move cursor with bounds checking."""
|
||||
if not self.lines:
|
||||
self.lines = [""]
|
||||
|
||||
|
||||
new_y = self.cursor_y + dy
|
||||
new_x = self.cursor_x + dx
|
||||
|
||||
|
||||
# Ensure valid Y position
|
||||
if 0 <= new_y < len(self.lines):
|
||||
self.cursor_y = new_y
|
||||
@ -477,9 +478,9 @@ class RPEditor:
|
||||
"""Save current state for undo."""
|
||||
with self.lock:
|
||||
state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
"lines": list(self.lines),
|
||||
"cursor_y": self.cursor_y,
|
||||
"cursor_x": self.cursor_x,
|
||||
}
|
||||
self.undo_stack.append(state)
|
||||
if len(self.undo_stack) > self.max_undo:
|
||||
@ -491,69 +492,79 @@ class RPEditor:
|
||||
with self.lock:
|
||||
if self.undo_stack:
|
||||
current_state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
"lines": list(self.lines),
|
||||
"cursor_y": self.cursor_y,
|
||||
"cursor_x": self.cursor_x,
|
||||
}
|
||||
self.redo_stack.append(current_state)
|
||||
state = self.undo_stack.pop()
|
||||
self.lines = state['lines']
|
||||
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1)
|
||||
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0)
|
||||
self.lines = state["lines"]
|
||||
self.cursor_y = min(state["cursor_y"], len(self.lines) - 1)
|
||||
self.cursor_x = min(
|
||||
state["cursor_x"],
|
||||
len(self.lines[self.cursor_y]) if self.lines else 0,
|
||||
)
|
||||
|
||||
def redo(self):
|
||||
"""Redo last undone change."""
|
||||
with self.lock:
|
||||
if self.redo_stack:
|
||||
current_state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
"lines": list(self.lines),
|
||||
"cursor_y": self.cursor_y,
|
||||
"cursor_x": self.cursor_x,
|
||||
}
|
||||
self.undo_stack.append(current_state)
|
||||
state = self.redo_stack.pop()
|
||||
self.lines = state['lines']
|
||||
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1)
|
||||
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0)
|
||||
self.lines = state["lines"]
|
||||
self.cursor_y = min(state["cursor_y"], len(self.lines) - 1)
|
||||
self.cursor_x = min(
|
||||
state["cursor_x"],
|
||||
len(self.lines[self.cursor_y]) if self.lines else 0,
|
||||
)
|
||||
|
||||
def _insert_text(self, text):
|
||||
"""Insert text at cursor position."""
|
||||
if not text:
|
||||
return
|
||||
|
||||
|
||||
self.save_state()
|
||||
lines = text.split('\n')
|
||||
|
||||
lines = text.split("\n")
|
||||
|
||||
if len(lines) == 1:
|
||||
# Single line insert
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.lines.append("")
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
|
||||
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x] + text + line[self.cursor_x:]
|
||||
self.lines[self.cursor_y] = (
|
||||
line[: self.cursor_x] + text + line[self.cursor_x :]
|
||||
)
|
||||
self.cursor_x += len(text)
|
||||
else:
|
||||
# Multi-line insert
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.lines.append("")
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
|
||||
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0]
|
||||
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:]
|
||||
|
||||
|
||||
first = self.lines[self.cursor_y][: self.cursor_x] + lines[0]
|
||||
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x :]
|
||||
|
||||
self.lines[self.cursor_y] = first
|
||||
for i in range(1, len(lines) - 1):
|
||||
self.lines.insert(self.cursor_y + i, lines[i])
|
||||
self.lines.insert(self.cursor_y + len(lines) - 1, last)
|
||||
|
||||
|
||||
self.cursor_y += len(lines) - 1
|
||||
self.cursor_x = len(lines[-1])
|
||||
|
||||
def insert_text(self, text):
|
||||
"""Thread-safe text insertion."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text}))
|
||||
self.client_sock.send(
|
||||
pickle.dumps({"command": "insert_text", "text": text})
|
||||
)
|
||||
except:
|
||||
with self.lock:
|
||||
self._insert_text(text)
|
||||
@ -561,14 +572,18 @@ class RPEditor:
|
||||
def _delete_char(self):
|
||||
"""Delete character at cursor."""
|
||||
self.save_state()
|
||||
if self.cursor_y < len(self.lines) and self.cursor_x < len(self.lines[self.cursor_y]):
|
||||
if self.cursor_y < len(self.lines) and self.cursor_x < len(
|
||||
self.lines[self.cursor_y]
|
||||
):
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x] + line[self.cursor_x+1:]
|
||||
self.lines[self.cursor_y] = (
|
||||
line[: self.cursor_x] + line[self.cursor_x + 1 :]
|
||||
)
|
||||
|
||||
def delete_char(self):
|
||||
"""Thread-safe character deletion."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'delete_char'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "delete_char"}))
|
||||
except:
|
||||
with self.lock:
|
||||
self._delete_char()
|
||||
@ -578,9 +593,9 @@ class RPEditor:
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.lines.append("")
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
|
||||
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x] + char + line[self.cursor_x:]
|
||||
self.lines[self.cursor_y] = line[: self.cursor_x] + char + line[self.cursor_x :]
|
||||
self.cursor_x += 1
|
||||
|
||||
def _split_line(self):
|
||||
@ -588,10 +603,10 @@ class RPEditor:
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.lines.append("")
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
|
||||
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x]
|
||||
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:])
|
||||
self.lines[self.cursor_y] = line[: self.cursor_x]
|
||||
self.lines.insert(self.cursor_y + 1, line[self.cursor_x :])
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
|
||||
@ -599,7 +614,9 @@ class RPEditor:
|
||||
"""Handle backspace key."""
|
||||
if self.cursor_x > 0:
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x-1] + line[self.cursor_x:]
|
||||
self.lines[self.cursor_y] = (
|
||||
line[: self.cursor_x - 1] + line[self.cursor_x :]
|
||||
)
|
||||
self.cursor_x -= 1
|
||||
elif self.cursor_y > 0:
|
||||
prev_len = len(self.lines[self.cursor_y - 1])
|
||||
@ -637,7 +654,7 @@ class RPEditor:
|
||||
self._set_text(text)
|
||||
return
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
|
||||
self.client_sock.send(pickle.dumps({"command": "set_text", "text": text}))
|
||||
except:
|
||||
with self.lock:
|
||||
self._set_text(text)
|
||||
@ -651,7 +668,9 @@ class RPEditor:
|
||||
def goto_line(self, line_num):
|
||||
"""Thread-safe goto line."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num}))
|
||||
self.client_sock.send(
|
||||
pickle.dumps({"command": "goto_line", "line_num": line_num})
|
||||
)
|
||||
except:
|
||||
with self.lock:
|
||||
self._goto_line(line_num)
|
||||
@ -659,17 +678,17 @@ class RPEditor:
|
||||
def get_text(self):
|
||||
"""Get entire text content."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_text'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "get_text"}))
|
||||
data = self.client_sock.recv(65536)
|
||||
return pickle.loads(data)
|
||||
except:
|
||||
with self.lock:
|
||||
return '\n'.join(self.lines)
|
||||
return "\n".join(self.lines)
|
||||
|
||||
def get_cursor(self):
|
||||
"""Get cursor position."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_cursor'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "get_cursor"}))
|
||||
data = self.client_sock.recv(4096)
|
||||
return pickle.loads(data)
|
||||
except:
|
||||
@ -679,16 +698,16 @@ class RPEditor:
|
||||
def get_file_info(self):
|
||||
"""Get file information."""
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_file_info'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "get_file_info"}))
|
||||
data = self.client_sock.recv(4096)
|
||||
return pickle.loads(data)
|
||||
except:
|
||||
with self.lock:
|
||||
return {
|
||||
'filename': self.filename,
|
||||
'lines': len(self.lines),
|
||||
'cursor': (self.cursor_y, self.cursor_x),
|
||||
'mode': self.mode
|
||||
"filename": self.filename,
|
||||
"lines": len(self.lines),
|
||||
"cursor": (self.cursor_y, self.cursor_x),
|
||||
"mode": self.mode,
|
||||
}
|
||||
|
||||
def socket_listener(self):
|
||||
@ -713,39 +732,39 @@ class RPEditor:
|
||||
def execute_command(self, command):
|
||||
"""Execute command with error handling."""
|
||||
try:
|
||||
cmd = command.get('command')
|
||||
|
||||
if cmd == 'insert_text':
|
||||
self._insert_text(command.get('text', ''))
|
||||
elif cmd == 'delete_char':
|
||||
cmd = command.get("command")
|
||||
|
||||
if cmd == "insert_text":
|
||||
self._insert_text(command.get("text", ""))
|
||||
elif cmd == "delete_char":
|
||||
self._delete_char()
|
||||
elif cmd == 'save_file':
|
||||
elif cmd == "save_file":
|
||||
self._save_file()
|
||||
elif cmd == 'set_text':
|
||||
self._set_text(command.get('text', ''))
|
||||
elif cmd == 'goto_line':
|
||||
self._goto_line(command.get('line_num', 1))
|
||||
elif cmd == 'get_text':
|
||||
result = '\n'.join(self.lines)
|
||||
elif cmd == "set_text":
|
||||
self._set_text(command.get("text", ""))
|
||||
elif cmd == "goto_line":
|
||||
self._goto_line(command.get("line_num", 1))
|
||||
elif cmd == "get_text":
|
||||
result = "\n".join(self.lines)
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
elif cmd == 'get_cursor':
|
||||
elif cmd == "get_cursor":
|
||||
result = (self.cursor_y, self.cursor_x)
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
elif cmd == 'get_file_info':
|
||||
elif cmd == "get_file_info":
|
||||
result = {
|
||||
'filename': self.filename,
|
||||
'lines': len(self.lines),
|
||||
'cursor': (self.cursor_y, self.cursor_x),
|
||||
'mode': self.mode
|
||||
"filename": self.filename,
|
||||
"lines": len(self.lines),
|
||||
"cursor": (self.cursor_y, self.cursor_x),
|
||||
"mode": self.mode,
|
||||
}
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
elif cmd == 'stop':
|
||||
elif cmd == "stop":
|
||||
self.running = False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Additional public methods for backwards compatibility
|
||||
|
||||
|
||||
def move_cursor_to(self, y, x):
|
||||
"""Move cursor to specific position."""
|
||||
with self.lock:
|
||||
@ -788,11 +807,11 @@ class RPEditor:
|
||||
"""Replace text in range."""
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
|
||||
|
||||
# Validate bounds
|
||||
start_line = max(0, min(start_line, len(self.lines) - 1))
|
||||
end_line = max(0, min(end_line, len(self.lines) - 1))
|
||||
|
||||
|
||||
if start_line == end_line:
|
||||
line = self.lines[start_line]
|
||||
start_col = max(0, min(start_col, len(line)))
|
||||
@ -801,9 +820,9 @@ class RPEditor:
|
||||
else:
|
||||
first_part = self.lines[start_line][:start_col]
|
||||
last_part = self.lines[end_line][end_col:]
|
||||
new_lines = new_text.split('\n')
|
||||
new_lines = new_text.split("\n")
|
||||
self.lines[start_line] = first_part + new_lines[0]
|
||||
del self.lines[start_line + 1:end_line + 1]
|
||||
del self.lines[start_line + 1 : end_line + 1]
|
||||
for i, new_line in enumerate(new_lines[1:], 1):
|
||||
self.lines.insert(start_line + i, new_line)
|
||||
if len(new_lines) > 1:
|
||||
@ -842,24 +861,24 @@ class RPEditor:
|
||||
with self.lock:
|
||||
if not self.selection_start or not self.selection_end:
|
||||
return ""
|
||||
|
||||
|
||||
sl, sc = self.selection_start
|
||||
el, ec = self.selection_end
|
||||
|
||||
|
||||
# Validate bounds
|
||||
if sl < 0 or sl >= len(self.lines) or el < 0 or el >= len(self.lines):
|
||||
return ""
|
||||
|
||||
|
||||
if sl == el:
|
||||
return self.lines[sl][sc:ec]
|
||||
|
||||
|
||||
result = [self.lines[sl][sc:]]
|
||||
for i in range(sl + 1, el):
|
||||
if i < len(self.lines):
|
||||
result.append(self.lines[i])
|
||||
if el < len(self.lines):
|
||||
result.append(self.lines[el][:ec])
|
||||
return '\n'.join(result)
|
||||
return "\n".join(result)
|
||||
|
||||
def delete_selection(self):
|
||||
"""Delete selected text."""
|
||||
@ -880,7 +899,7 @@ class RPEditor:
|
||||
self.save_state()
|
||||
search_lines = search_block.splitlines()
|
||||
replace_lines = replace_block.splitlines()
|
||||
|
||||
|
||||
for i in range(len(self.lines) - len(search_lines) + 1):
|
||||
match = True
|
||||
for j, search_line in enumerate(search_lines):
|
||||
@ -890,12 +909,12 @@ class RPEditor:
|
||||
if self.lines[i + j].strip() != search_line.strip():
|
||||
match = False
|
||||
break
|
||||
|
||||
|
||||
if match:
|
||||
# Preserve indentation
|
||||
indent = len(self.lines[i]) - len(self.lines[i].lstrip())
|
||||
indented_replace = [' ' * indent + line for line in replace_lines]
|
||||
self.lines[i:i+len(search_lines)] = indented_replace
|
||||
indented_replace = [" " * indent + line for line in replace_lines]
|
||||
self.lines[i : i + len(search_lines)] = indented_replace
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -904,21 +923,21 @@ class RPEditor:
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
try:
|
||||
lines = diff_text.split('\n')
|
||||
lines = diff_text.split("\n")
|
||||
start_line = 0
|
||||
|
||||
|
||||
for line in lines:
|
||||
if line.startswith('@@'):
|
||||
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line)
|
||||
if line.startswith("@@"):
|
||||
match = re.search(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@", line)
|
||||
if match:
|
||||
start_line = int(match.group(1)) - 1
|
||||
elif line.startswith('-'):
|
||||
elif line.startswith("-"):
|
||||
if start_line < len(self.lines):
|
||||
del self.lines[start_line]
|
||||
elif line.startswith('+'):
|
||||
elif line.startswith("+"):
|
||||
self.lines.insert(start_line, line[1:])
|
||||
start_line += 1
|
||||
elif line and not line.startswith('\\'):
|
||||
elif line and not line.startswith("\\"):
|
||||
start_line += 1
|
||||
except Exception:
|
||||
pass
|
||||
@ -972,18 +991,18 @@ def main():
|
||||
editor = None
|
||||
try:
|
||||
filename = sys.argv[1] if len(sys.argv) > 1 else None
|
||||
|
||||
|
||||
# Parse additional arguments
|
||||
auto_save = '--auto-save' in sys.argv
|
||||
|
||||
auto_save = "--auto-save" in sys.argv
|
||||
|
||||
# Create and start editor
|
||||
editor = RPEditor(filename, auto_save=auto_save)
|
||||
editor.start()
|
||||
|
||||
|
||||
# Wait for editor to finish
|
||||
if editor.thread:
|
||||
editor.thread.join()
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as e:
|
||||
@ -992,7 +1011,7 @@ def main():
|
||||
if editor:
|
||||
editor.stop()
|
||||
# Ensure screen is cleared
|
||||
os.system('clear' if os.name != 'nt' else 'cls')
|
||||
os.system("clear" if os.name != "nt" else "cls")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
247
pr/editor2.py
247
pr/editor2.py
@ -1,12 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
import curses
|
||||
import threading
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import pickle
|
||||
import queue
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
|
||||
|
||||
class RPEditor:
|
||||
def __init__(self, filename=None):
|
||||
@ -14,7 +14,7 @@ class RPEditor:
|
||||
self.lines = [""]
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
self.mode = 'normal'
|
||||
self.mode = "normal"
|
||||
self.command = ""
|
||||
self.stdscr = None
|
||||
self.running = False
|
||||
@ -35,7 +35,7 @@ class RPEditor:
|
||||
|
||||
def load_file(self):
|
||||
try:
|
||||
with open(self.filename, 'r') as f:
|
||||
with open(self.filename) as f:
|
||||
self.lines = f.read().splitlines()
|
||||
if not self.lines:
|
||||
self.lines = [""]
|
||||
@ -45,11 +45,11 @@ class RPEditor:
|
||||
def _save_file(self):
|
||||
with self.lock:
|
||||
if self.filename:
|
||||
with open(self.filename, 'w') as f:
|
||||
f.write('\n'.join(self.lines))
|
||||
with open(self.filename, "w") as f:
|
||||
f.write("\n".join(self.lines))
|
||||
|
||||
def save_file(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "save_file"}))
|
||||
|
||||
def start(self):
|
||||
self.running = True
|
||||
@ -59,7 +59,7 @@ class RPEditor:
|
||||
self.thread.start()
|
||||
|
||||
def stop(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'stop'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "stop"}))
|
||||
self.running = False
|
||||
if self.stdscr:
|
||||
curses.endwin()
|
||||
@ -99,66 +99,66 @@ class RPEditor:
|
||||
self.stdscr.addstr(i, 0, line[:width])
|
||||
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
|
||||
self.stdscr.addstr(height - 1, 0, status[:width])
|
||||
if self.mode == 'command':
|
||||
if self.mode == "command":
|
||||
self.stdscr.addstr(height - 1, 0, self.command[:width])
|
||||
self.stdscr.move(self.cursor_y, min(self.cursor_x, width - 1))
|
||||
self.stdscr.refresh()
|
||||
|
||||
def handle_key(self, key):
|
||||
if self.mode == 'normal':
|
||||
if self.mode == "normal":
|
||||
self.handle_normal(key)
|
||||
elif self.mode == 'insert':
|
||||
elif self.mode == "insert":
|
||||
self.handle_insert(key)
|
||||
elif self.mode == 'command':
|
||||
elif self.mode == "command":
|
||||
self.handle_command(key)
|
||||
|
||||
def handle_normal(self, key):
|
||||
if key == ord('h') or key == curses.KEY_LEFT:
|
||||
if key == ord("h") or key == curses.KEY_LEFT:
|
||||
self.move_cursor(0, -1)
|
||||
elif key == ord('j') or key == curses.KEY_DOWN:
|
||||
elif key == ord("j") or key == curses.KEY_DOWN:
|
||||
self.move_cursor(1, 0)
|
||||
elif key == ord('k') or key == curses.KEY_UP:
|
||||
elif key == ord("k") or key == curses.KEY_UP:
|
||||
self.move_cursor(-1, 0)
|
||||
elif key == ord('l') or key == curses.KEY_RIGHT:
|
||||
elif key == ord("l") or key == curses.KEY_RIGHT:
|
||||
self.move_cursor(0, 1)
|
||||
elif key == ord('i'):
|
||||
self.mode = 'insert'
|
||||
elif key == ord(':'):
|
||||
self.mode = 'command'
|
||||
elif key == ord("i"):
|
||||
self.mode = "insert"
|
||||
elif key == ord(":"):
|
||||
self.mode = "command"
|
||||
self.command = ":"
|
||||
elif key == ord('x'):
|
||||
elif key == ord("x"):
|
||||
self._delete_char()
|
||||
elif key == ord('a'):
|
||||
elif key == ord("a"):
|
||||
self.cursor_x += 1
|
||||
self.mode = 'insert'
|
||||
elif key == ord('A'):
|
||||
self.mode = "insert"
|
||||
elif key == ord("A"):
|
||||
self.cursor_x = len(self.lines[self.cursor_y])
|
||||
self.mode = 'insert'
|
||||
elif key == ord('o'):
|
||||
self.mode = "insert"
|
||||
elif key == ord("o"):
|
||||
self._insert_line(self.cursor_y + 1, "")
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
self.mode = 'insert'
|
||||
elif key == ord('O'):
|
||||
self.mode = "insert"
|
||||
elif key == ord("O"):
|
||||
self._insert_line(self.cursor_y, "")
|
||||
self.cursor_x = 0
|
||||
self.mode = 'insert'
|
||||
elif key == ord('d') and self.prev_key == ord('d'):
|
||||
self.mode = "insert"
|
||||
elif key == ord("d") and self.prev_key == ord("d"):
|
||||
self.clipboard = self.lines[self.cursor_y]
|
||||
self._delete_line(self.cursor_y)
|
||||
if self.cursor_y >= len(self.lines):
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
self.cursor_x = 0
|
||||
elif key == ord('y') and self.prev_key == ord('y'):
|
||||
elif key == ord("y") and self.prev_key == ord("y"):
|
||||
self.clipboard = self.lines[self.cursor_y]
|
||||
elif key == ord('p'):
|
||||
elif key == ord("p"):
|
||||
self._insert_line(self.cursor_y + 1, self.clipboard)
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
elif key == ord('P'):
|
||||
elif key == ord("P"):
|
||||
self._insert_line(self.cursor_y, self.clipboard)
|
||||
self.cursor_x = 0
|
||||
elif key == ord('w'):
|
||||
elif key == ord("w"):
|
||||
line = self.lines[self.cursor_y]
|
||||
i = self.cursor_x
|
||||
while i < len(line) and not line[i].isalnum():
|
||||
@ -166,7 +166,7 @@ class RPEditor:
|
||||
while i < len(line) and line[i].isalnum():
|
||||
i += 1
|
||||
self.cursor_x = i
|
||||
elif key == ord('b'):
|
||||
elif key == ord("b"):
|
||||
line = self.lines[self.cursor_y]
|
||||
i = self.cursor_x - 1
|
||||
while i >= 0 and not line[i].isalnum():
|
||||
@ -174,26 +174,26 @@ class RPEditor:
|
||||
while i >= 0 and line[i].isalnum():
|
||||
i -= 1
|
||||
self.cursor_x = i + 1
|
||||
elif key == ord('0'):
|
||||
elif key == ord("0"):
|
||||
self.cursor_x = 0
|
||||
elif key == ord('$'):
|
||||
elif key == ord("$"):
|
||||
self.cursor_x = len(self.lines[self.cursor_y])
|
||||
elif key == ord('g'):
|
||||
if self.prev_key == ord('g'):
|
||||
elif key == ord("g"):
|
||||
if self.prev_key == ord("g"):
|
||||
self.cursor_y = 0
|
||||
self.cursor_x = 0
|
||||
elif key == ord('G'):
|
||||
elif key == ord("G"):
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
self.cursor_x = 0
|
||||
elif key == ord('u'):
|
||||
elif key == ord("u"):
|
||||
self.undo()
|
||||
elif key == ord('r') and self.prev_key == 18:
|
||||
elif key == ord("r") and self.prev_key == 18:
|
||||
self.redo()
|
||||
self.prev_key = key
|
||||
|
||||
def handle_insert(self, key):
|
||||
if key == 27:
|
||||
self.mode = 'normal'
|
||||
self.mode = "normal"
|
||||
if self.cursor_x > 0:
|
||||
self.cursor_x -= 1
|
||||
elif key == 10:
|
||||
@ -207,11 +207,13 @@ class RPEditor:
|
||||
def handle_command(self, key):
|
||||
if key == 10:
|
||||
cmd = self.command[1:]
|
||||
if cmd == "q" or cmd == 'q!':
|
||||
if cmd == "q" or cmd == "q!":
|
||||
self.running = False
|
||||
elif cmd == "w":
|
||||
elif cmd == "w":
|
||||
self._save_file()
|
||||
elif cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!":
|
||||
elif (
|
||||
cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!"
|
||||
):
|
||||
self._save_file()
|
||||
self.running = False
|
||||
elif cmd.startswith("w "):
|
||||
@ -220,10 +222,10 @@ class RPEditor:
|
||||
elif cmd == "wq":
|
||||
self._save_file()
|
||||
self.running = False
|
||||
self.mode = 'normal'
|
||||
self.mode = "normal"
|
||||
self.command = ""
|
||||
elif key == 27:
|
||||
self.mode = 'normal'
|
||||
self.mode = "normal"
|
||||
self.command = ""
|
||||
elif key == curses.KEY_BACKSPACE or key == 127:
|
||||
if len(self.command) > 1:
|
||||
@ -241,9 +243,9 @@ class RPEditor:
|
||||
def save_state(self):
|
||||
with self.lock:
|
||||
state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
"lines": list(self.lines),
|
||||
"cursor_y": self.cursor_y,
|
||||
"cursor_x": self.cursor_x,
|
||||
}
|
||||
self.undo_stack.append(state)
|
||||
if len(self.undo_stack) > self.max_undo:
|
||||
@ -254,71 +256,85 @@ class RPEditor:
|
||||
with self.lock:
|
||||
if self.undo_stack:
|
||||
current_state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
"lines": list(self.lines),
|
||||
"cursor_y": self.cursor_y,
|
||||
"cursor_x": self.cursor_x,
|
||||
}
|
||||
self.redo_stack.append(current_state)
|
||||
state = self.undo_stack.pop()
|
||||
self.lines = state['lines']
|
||||
self.cursor_y = state['cursor_y']
|
||||
self.cursor_x = state['cursor_x']
|
||||
self.lines = state["lines"]
|
||||
self.cursor_y = state["cursor_y"]
|
||||
self.cursor_x = state["cursor_x"]
|
||||
|
||||
def redo(self):
|
||||
with self.lock:
|
||||
if self.redo_stack:
|
||||
current_state = {
|
||||
'lines': [line for line in self.lines],
|
||||
'cursor_y': self.cursor_y,
|
||||
'cursor_x': self.cursor_x
|
||||
"lines": list(self.lines),
|
||||
"cursor_y": self.cursor_y,
|
||||
"cursor_x": self.cursor_x,
|
||||
}
|
||||
self.undo_stack.append(current_state)
|
||||
state = self.redo_stack.pop()
|
||||
self.lines = state['lines']
|
||||
self.cursor_y = state['cursor_y']
|
||||
self.cursor_x = state['cursor_x']
|
||||
self.lines = state["lines"]
|
||||
self.cursor_y = state["cursor_y"]
|
||||
self.cursor_x = state["cursor_x"]
|
||||
|
||||
def _insert_text(self, text):
|
||||
self.save_state()
|
||||
lines = text.split('\n')
|
||||
lines = text.split("\n")
|
||||
if len(lines) == 1:
|
||||
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + text + self.lines[self.cursor_y][self.cursor_x:]
|
||||
self.lines[self.cursor_y] = (
|
||||
self.lines[self.cursor_y][: self.cursor_x]
|
||||
+ text
|
||||
+ self.lines[self.cursor_y][self.cursor_x :]
|
||||
)
|
||||
self.cursor_x += len(text)
|
||||
else:
|
||||
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0]
|
||||
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:]
|
||||
first = self.lines[self.cursor_y][: self.cursor_x] + lines[0]
|
||||
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x :]
|
||||
self.lines[self.cursor_y] = first
|
||||
for i in range(1, len(lines)-1):
|
||||
for i in range(1, len(lines) - 1):
|
||||
self.lines.insert(self.cursor_y + i, lines[i])
|
||||
self.lines.insert(self.cursor_y + len(lines) - 1, last)
|
||||
self.cursor_y += len(lines) - 1
|
||||
self.cursor_x = len(lines[-1])
|
||||
|
||||
def insert_text(self, text):
|
||||
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text}))
|
||||
self.client_sock.send(pickle.dumps({"command": "insert_text", "text": text}))
|
||||
|
||||
def _delete_char(self):
|
||||
self.save_state()
|
||||
if self.cursor_x < len(self.lines[self.cursor_y]):
|
||||
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + self.lines[self.cursor_y][self.cursor_x+1:]
|
||||
self.lines[self.cursor_y] = (
|
||||
self.lines[self.cursor_y][: self.cursor_x]
|
||||
+ self.lines[self.cursor_y][self.cursor_x + 1 :]
|
||||
)
|
||||
|
||||
def delete_char(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'delete_char'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "delete_char"}))
|
||||
|
||||
def _insert_char(self, char):
|
||||
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + char + self.lines[self.cursor_y][self.cursor_x:]
|
||||
self.lines[self.cursor_y] = (
|
||||
self.lines[self.cursor_y][: self.cursor_x]
|
||||
+ char
|
||||
+ self.lines[self.cursor_y][self.cursor_x :]
|
||||
)
|
||||
self.cursor_x += 1
|
||||
|
||||
def _split_line(self):
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = line[:self.cursor_x]
|
||||
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:])
|
||||
self.lines[self.cursor_y] = line[: self.cursor_x]
|
||||
self.lines.insert(self.cursor_y + 1, line[self.cursor_x :])
|
||||
self.cursor_y += 1
|
||||
self.cursor_x = 0
|
||||
|
||||
def _backspace(self):
|
||||
if self.cursor_x > 0:
|
||||
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x-1] + self.lines[self.cursor_y][self.cursor_x:]
|
||||
self.lines[self.cursor_y] = (
|
||||
self.lines[self.cursor_y][: self.cursor_x - 1]
|
||||
+ self.lines[self.cursor_y][self.cursor_x :]
|
||||
)
|
||||
self.cursor_x -= 1
|
||||
elif self.cursor_y > 0:
|
||||
prev_len = len(self.lines[self.cursor_y - 1])
|
||||
@ -347,7 +363,7 @@ class RPEditor:
|
||||
self.cursor_x = 0
|
||||
|
||||
def set_text(self, text):
|
||||
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
|
||||
self.client_sock.send(pickle.dumps({"command": "set_text", "text": text}))
|
||||
|
||||
def _goto_line(self, line_num):
|
||||
line_num = max(0, min(line_num, len(self.lines) - 1))
|
||||
@ -355,24 +371,26 @@ class RPEditor:
|
||||
self.cursor_x = 0
|
||||
|
||||
def goto_line(self, line_num):
|
||||
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num}))
|
||||
self.client_sock.send(
|
||||
pickle.dumps({"command": "goto_line", "line_num": line_num})
|
||||
)
|
||||
|
||||
def get_text(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_text'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "get_text"}))
|
||||
try:
|
||||
return pickle.loads(self.client_sock.recv(4096))
|
||||
except:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
def get_cursor(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_cursor'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "get_cursor"}))
|
||||
try:
|
||||
return pickle.loads(self.client_sock.recv(4096))
|
||||
except:
|
||||
return (0, 0)
|
||||
|
||||
def get_file_info(self):
|
||||
self.client_sock.send(pickle.dumps({'command': 'get_file_info'}))
|
||||
self.client_sock.send(pickle.dumps({"command": "get_file_info"}))
|
||||
try:
|
||||
return pickle.loads(self.client_sock.recv(4096))
|
||||
except:
|
||||
@ -390,46 +408,46 @@ class RPEditor:
|
||||
break
|
||||
|
||||
def execute_command(self, command):
|
||||
cmd = command.get('command')
|
||||
if cmd == 'insert_text':
|
||||
self._insert_text(command['text'])
|
||||
elif cmd == 'delete_char':
|
||||
cmd = command.get("command")
|
||||
if cmd == "insert_text":
|
||||
self._insert_text(command["text"])
|
||||
elif cmd == "delete_char":
|
||||
self._delete_char()
|
||||
elif cmd == 'save_file':
|
||||
elif cmd == "save_file":
|
||||
self._save_file()
|
||||
elif cmd == 'set_text':
|
||||
self._set_text(command['text'])
|
||||
elif cmd == 'goto_line':
|
||||
self._goto_line(command['line_num'])
|
||||
elif cmd == 'get_text':
|
||||
result = '\n'.join(self.lines)
|
||||
elif cmd == "set_text":
|
||||
self._set_text(command["text"])
|
||||
elif cmd == "goto_line":
|
||||
self._goto_line(command["line_num"])
|
||||
elif cmd == "get_text":
|
||||
result = "\n".join(self.lines)
|
||||
try:
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
except:
|
||||
pass
|
||||
elif cmd == 'get_cursor':
|
||||
elif cmd == "get_cursor":
|
||||
result = (self.cursor_y, self.cursor_x)
|
||||
try:
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
except:
|
||||
pass
|
||||
elif cmd == 'get_file_info':
|
||||
elif cmd == "get_file_info":
|
||||
result = {
|
||||
'filename': self.filename,
|
||||
'lines': len(self.lines),
|
||||
'cursor': (self.cursor_y, self.cursor_x),
|
||||
'mode': self.mode
|
||||
"filename": self.filename,
|
||||
"lines": len(self.lines),
|
||||
"cursor": (self.cursor_y, self.cursor_x),
|
||||
"mode": self.mode,
|
||||
}
|
||||
try:
|
||||
self.server_sock.send(pickle.dumps(result))
|
||||
except:
|
||||
pass
|
||||
elif cmd == 'stop':
|
||||
elif cmd == "stop":
|
||||
self.running = False
|
||||
|
||||
def move_cursor_to(self, y, x):
|
||||
with self.lock:
|
||||
self.cursor_y = max(0, min(y, len(self.lines)-1))
|
||||
self.cursor_y = max(0, min(y, len(self.lines) - 1))
|
||||
self.cursor_x = max(0, min(x, len(self.lines[self.cursor_y])))
|
||||
|
||||
def get_line(self, line_num):
|
||||
@ -469,9 +487,9 @@ class RPEditor:
|
||||
else:
|
||||
first_part = self.lines[start_line][:start_col]
|
||||
last_part = self.lines[end_line][end_col:]
|
||||
new_lines = new_text.split('\n')
|
||||
new_lines = new_text.split("\n")
|
||||
self.lines[start_line] = first_part + new_lines[0]
|
||||
del self.lines[start_line + 1:end_line + 1]
|
||||
del self.lines[start_line + 1 : end_line + 1]
|
||||
for i, new_line in enumerate(new_lines[1:], 1):
|
||||
self.lines.insert(start_line + i, new_line)
|
||||
if len(new_lines) > 1:
|
||||
@ -511,7 +529,7 @@ class RPEditor:
|
||||
for i in range(sl + 1, el):
|
||||
result.append(self.lines[i])
|
||||
result.append(self.lines[el][:ec])
|
||||
return '\n'.join(result)
|
||||
return "\n".join(result)
|
||||
|
||||
def delete_selection(self):
|
||||
with self.lock:
|
||||
@ -537,24 +555,24 @@ class RPEditor:
|
||||
break
|
||||
if match:
|
||||
indent = len(self.lines[i]) - len(self.lines[i].lstrip())
|
||||
indented_replace = [' ' * indent + line for line in replace_lines]
|
||||
self.lines[i:i+len(search_lines)] = indented_replace
|
||||
indented_replace = [" " * indent + line for line in replace_lines]
|
||||
self.lines[i : i + len(search_lines)] = indented_replace
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_diff(self, diff_text):
|
||||
with self.lock:
|
||||
self.save_state()
|
||||
lines = diff_text.split('\n')
|
||||
lines = diff_text.split("\n")
|
||||
for line in lines:
|
||||
if line.startswith('@@'):
|
||||
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line)
|
||||
if line.startswith("@@"):
|
||||
match = re.search(r"@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@", line)
|
||||
if match:
|
||||
start_line = int(match.group(1)) - 1
|
||||
elif line.startswith('-'):
|
||||
elif line.startswith("-"):
|
||||
if start_line < len(self.lines):
|
||||
del self.lines[start_line]
|
||||
elif line.startswith('+'):
|
||||
elif line.startswith("+"):
|
||||
self.lines.insert(start_line, line[1:])
|
||||
start_line += 1
|
||||
|
||||
@ -574,6 +592,7 @@ class RPEditor:
|
||||
if self.thread:
|
||||
self.thread.join()
|
||||
|
||||
|
||||
def main():
|
||||
filename = sys.argv[1] if len(sys.argv) > 1 else None
|
||||
editor = RPEditor(filename)
|
||||
@ -583,5 +602,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
@ -3,14 +3,13 @@
|
||||
Advanced input handler for PR Assistant with editor mode, file inclusion, and image support.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import base64
|
||||
import mimetypes
|
||||
import re
|
||||
import readline
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# from pr.ui.colors import Colors # Avoid import issues
|
||||
|
||||
|
||||
@ -29,7 +28,7 @@ class AdvancedInputHandler:
|
||||
return None
|
||||
|
||||
readline.set_completer(completer)
|
||||
readline.parse_and_bind('tab: complete')
|
||||
readline.parse_and_bind("tab: complete")
|
||||
except:
|
||||
pass # Readline not available
|
||||
|
||||
@ -60,7 +59,7 @@ class AdvancedInputHandler:
|
||||
return ""
|
||||
|
||||
# Check for special commands
|
||||
if user_input.lower() == '/editor':
|
||||
if user_input.lower() == "/editor":
|
||||
self.toggle_editor_mode()
|
||||
return self.get_input(prompt) # Recurse to get new input
|
||||
|
||||
@ -74,23 +73,25 @@ class AdvancedInputHandler:
|
||||
def _get_editor_input(self, prompt: str) -> Optional[str]:
|
||||
"""Get multi-line input for editor mode."""
|
||||
try:
|
||||
print("Editor mode: Enter your message. Type 'END' on a new line to finish.")
|
||||
print(
|
||||
"Editor mode: Enter your message. Type 'END' on a new line to finish."
|
||||
)
|
||||
print("Type '/simple' to switch back to simple mode.")
|
||||
|
||||
lines = []
|
||||
while True:
|
||||
try:
|
||||
line = input()
|
||||
if line.strip().lower() == 'end':
|
||||
if line.strip().lower() == "end":
|
||||
break
|
||||
elif line.strip().lower() == '/simple':
|
||||
elif line.strip().lower() == "/simple":
|
||||
self.toggle_editor_mode()
|
||||
return self.get_input(prompt) # Switch back and get input
|
||||
lines.append(line)
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
content = '\n'.join(lines).strip()
|
||||
content = "\n".join(lines).strip()
|
||||
|
||||
if not content:
|
||||
return ""
|
||||
@ -114,12 +115,13 @@ class AdvancedInputHandler:
|
||||
|
||||
def _process_file_inclusions(self, text: str) -> str:
|
||||
"""Replace @[filename] with file contents."""
|
||||
|
||||
def replace_file(match):
|
||||
filename = match.group(1).strip()
|
||||
try:
|
||||
path = Path(filename).expanduser().resolve()
|
||||
if path.exists() and path.is_file():
|
||||
with open(path, 'r', encoding='utf-8', errors='replace') as f:
|
||||
with open(path, encoding="utf-8", errors="replace") as f:
|
||||
content = f.read()
|
||||
return f"\n--- File: {filename} ---\n{content}\n--- End of {filename} ---\n"
|
||||
else:
|
||||
@ -128,7 +130,7 @@ class AdvancedInputHandler:
|
||||
return f"[Error reading file {filename}: {e}]"
|
||||
|
||||
# Replace @[filename] patterns
|
||||
pattern = r'@\[([^\]]+)\]'
|
||||
pattern = r"@\[([^\]]+)\]"
|
||||
return re.sub(pattern, replace_file, text)
|
||||
|
||||
def _process_image_inclusions(self, text: str) -> str:
|
||||
@ -143,20 +145,22 @@ class AdvancedInputHandler:
|
||||
path = Path(word.strip()).expanduser().resolve()
|
||||
if path.exists() and path.is_file():
|
||||
mime_type, _ = mimetypes.guess_type(str(path))
|
||||
if mime_type and mime_type.startswith('image/'):
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
# Encode image
|
||||
with open(path, 'rb') as f:
|
||||
image_data = base64.b64encode(f.read()).decode('utf-8')
|
||||
with open(path, "rb") as f:
|
||||
image_data = base64.b64encode(f.read()).decode("utf-8")
|
||||
|
||||
# Replace with data URL
|
||||
processed_parts.append(f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n")
|
||||
processed_parts.append(
|
||||
f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n"
|
||||
)
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
|
||||
processed_parts.append(word)
|
||||
|
||||
return ' '.join(processed_parts)
|
||||
return " ".join(processed_parts)
|
||||
|
||||
|
||||
# Global instance
|
||||
|
||||
@ -1,7 +1,12 @@
|
||||
from .knowledge_store import KnowledgeStore, KnowledgeEntry
|
||||
from .semantic_index import SemanticIndex
|
||||
from .conversation_memory import ConversationMemory
|
||||
from .fact_extractor import FactExtractor
|
||||
from .knowledge_store import KnowledgeEntry, KnowledgeStore
|
||||
from .semantic_index import SemanticIndex
|
||||
|
||||
__all__ = ['KnowledgeStore', 'KnowledgeEntry', 'SemanticIndex',
|
||||
'ConversationMemory', 'FactExtractor']
|
||||
__all__ = [
|
||||
"KnowledgeStore",
|
||||
"KnowledgeEntry",
|
||||
"SemanticIndex",
|
||||
"ConversationMemory",
|
||||
"FactExtractor",
|
||||
]
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class ConversationMemory:
|
||||
def __init__(self, db_path: str):
|
||||
@ -12,7 +13,8 @@ class ConversationMemory:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS conversation_history (
|
||||
conversation_id TEXT PRIMARY KEY,
|
||||
session_id TEXT,
|
||||
@ -23,9 +25,11 @@ class ConversationMemory:
|
||||
topics TEXT,
|
||||
metadata TEXT
|
||||
)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS conversation_messages (
|
||||
message_id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
@ -36,117 +40,163 @@ class ConversationMemory:
|
||||
metadata TEXT,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversation_history(conversation_id)
|
||||
)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_conv_session ON conversation_history(session_id)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_conv_started ON conversation_history(started_at DESC)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_msg_conversation ON conversation_messages(conversation_id)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_msg_timestamp ON conversation_messages(timestamp)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def create_conversation(self, conversation_id: str, session_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None):
|
||||
def create_conversation(
|
||||
self,
|
||||
conversation_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO conversation_history
|
||||
(conversation_id, session_id, started_at, metadata)
|
||||
VALUES (?, ?, ?, ?)
|
||||
''', (
|
||||
conversation_id,
|
||||
session_id,
|
||||
time.time(),
|
||||
json.dumps(metadata) if metadata else None
|
||||
))
|
||||
""",
|
||||
(
|
||||
conversation_id,
|
||||
session_id,
|
||||
time.time(),
|
||||
json.dumps(metadata) if metadata else None,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def add_message(self, conversation_id: str, message_id: str, role: str,
|
||||
content: str, tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None):
|
||||
def add_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO conversation_messages
|
||||
(message_id, conversation_id, role, content, timestamp, tool_calls, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
message_id,
|
||||
conversation_id,
|
||||
role,
|
||||
content,
|
||||
time.time(),
|
||||
json.dumps(tool_calls) if tool_calls else None,
|
||||
json.dumps(metadata) if metadata else None
|
||||
))
|
||||
""",
|
||||
(
|
||||
message_id,
|
||||
conversation_id,
|
||||
role,
|
||||
content,
|
||||
time.time(),
|
||||
json.dumps(tool_calls) if tool_calls else None,
|
||||
json.dumps(metadata) if metadata else None,
|
||||
),
|
||||
)
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE conversation_history
|
||||
SET message_count = message_count + 1
|
||||
WHERE conversation_id = ?
|
||||
''', (conversation_id,))
|
||||
""",
|
||||
(conversation_id,),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def get_conversation_messages(self, conversation_id: str,
|
||||
limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
def get_conversation_messages(
|
||||
self, conversation_id: str, limit: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
if limit:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT message_id, role, content, timestamp, tool_calls, metadata
|
||||
FROM conversation_messages
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
''', (conversation_id, limit))
|
||||
""",
|
||||
(conversation_id, limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT message_id, role, content, timestamp, tool_calls, metadata
|
||||
FROM conversation_messages
|
||||
WHERE conversation_id = ?
|
||||
ORDER BY timestamp ASC
|
||||
''', (conversation_id,))
|
||||
""",
|
||||
(conversation_id,),
|
||||
)
|
||||
|
||||
messages = []
|
||||
for row in cursor.fetchall():
|
||||
messages.append({
|
||||
'message_id': row[0],
|
||||
'role': row[1],
|
||||
'content': row[2],
|
||||
'timestamp': row[3],
|
||||
'tool_calls': json.loads(row[4]) if row[4] else None,
|
||||
'metadata': json.loads(row[5]) if row[5] else None
|
||||
})
|
||||
messages.append(
|
||||
{
|
||||
"message_id": row[0],
|
||||
"role": row[1],
|
||||
"content": row[2],
|
||||
"timestamp": row[3],
|
||||
"tool_calls": json.loads(row[4]) if row[4] else None,
|
||||
"metadata": json.loads(row[5]) if row[5] else None,
|
||||
}
|
||||
)
|
||||
|
||||
conn.close()
|
||||
return messages
|
||||
|
||||
def update_conversation_summary(self, conversation_id: str, summary: str,
|
||||
topics: Optional[List[str]] = None):
|
||||
def update_conversation_summary(
|
||||
self, conversation_id: str, summary: str, topics: Optional[List[str]] = None
|
||||
):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE conversation_history
|
||||
SET summary = ?, topics = ?, ended_at = ?
|
||||
WHERE conversation_id = ?
|
||||
''', (summary, json.dumps(topics) if topics else None, time.time(), conversation_id))
|
||||
""",
|
||||
(
|
||||
summary,
|
||||
json.dumps(topics) if topics else None,
|
||||
time.time(),
|
||||
conversation_id,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@ -155,7 +205,8 @@ class ConversationMemory:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT DISTINCT h.conversation_id, h.session_id, h.started_at,
|
||||
h.message_count, h.summary, h.topics
|
||||
FROM conversation_history h
|
||||
@ -163,56 +214,69 @@ class ConversationMemory:
|
||||
WHERE h.summary LIKE ? OR h.topics LIKE ? OR m.content LIKE ?
|
||||
ORDER BY h.started_at DESC
|
||||
LIMIT ?
|
||||
''', (f'%{query}%', f'%{query}%', f'%{query}%', limit))
|
||||
""",
|
||||
(f"%{query}%", f"%{query}%", f"%{query}%", limit),
|
||||
)
|
||||
|
||||
conversations = []
|
||||
for row in cursor.fetchall():
|
||||
conversations.append({
|
||||
'conversation_id': row[0],
|
||||
'session_id': row[1],
|
||||
'started_at': row[2],
|
||||
'message_count': row[3],
|
||||
'summary': row[4],
|
||||
'topics': json.loads(row[5]) if row[5] else []
|
||||
})
|
||||
conversations.append(
|
||||
{
|
||||
"conversation_id": row[0],
|
||||
"session_id": row[1],
|
||||
"started_at": row[2],
|
||||
"message_count": row[3],
|
||||
"summary": row[4],
|
||||
"topics": json.loads(row[5]) if row[5] else [],
|
||||
}
|
||||
)
|
||||
|
||||
conn.close()
|
||||
return conversations
|
||||
|
||||
def get_recent_conversations(self, limit: int = 10,
|
||||
session_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
def get_recent_conversations(
|
||||
self, limit: int = 10, session_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
if session_id:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT conversation_id, session_id, started_at, ended_at,
|
||||
message_count, summary, topics
|
||||
FROM conversation_history
|
||||
WHERE session_id = ?
|
||||
ORDER BY started_at DESC
|
||||
LIMIT ?
|
||||
''', (session_id, limit))
|
||||
""",
|
||||
(session_id, limit),
|
||||
)
|
||||
else:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT conversation_id, session_id, started_at, ended_at,
|
||||
message_count, summary, topics
|
||||
FROM conversation_history
|
||||
ORDER BY started_at DESC
|
||||
LIMIT ?
|
||||
''', (limit,))
|
||||
""",
|
||||
(limit,),
|
||||
)
|
||||
|
||||
conversations = []
|
||||
for row in cursor.fetchall():
|
||||
conversations.append({
|
||||
'conversation_id': row[0],
|
||||
'session_id': row[1],
|
||||
'started_at': row[2],
|
||||
'ended_at': row[3],
|
||||
'message_count': row[4],
|
||||
'summary': row[5],
|
||||
'topics': json.loads(row[6]) if row[6] else []
|
||||
})
|
||||
conversations.append(
|
||||
{
|
||||
"conversation_id": row[0],
|
||||
"session_id": row[1],
|
||||
"started_at": row[2],
|
||||
"ended_at": row[3],
|
||||
"message_count": row[4],
|
||||
"summary": row[5],
|
||||
"topics": json.loads(row[6]) if row[6] else [],
|
||||
}
|
||||
)
|
||||
|
||||
conn.close()
|
||||
return conversations
|
||||
@ -221,10 +285,14 @@ class ConversationMemory:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM conversation_messages WHERE conversation_id = ?',
|
||||
(conversation_id,))
|
||||
cursor.execute('DELETE FROM conversation_history WHERE conversation_id = ?',
|
||||
(conversation_id,))
|
||||
cursor.execute(
|
||||
"DELETE FROM conversation_messages WHERE conversation_id = ?",
|
||||
(conversation_id,),
|
||||
)
|
||||
cursor.execute(
|
||||
"DELETE FROM conversation_history WHERE conversation_id = ?",
|
||||
(conversation_id,),
|
||||
)
|
||||
|
||||
deleted = cursor.rowcount > 0
|
||||
conn.commit()
|
||||
@ -236,24 +304,26 @@ class ConversationMemory:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('SELECT COUNT(*) FROM conversation_history')
|
||||
cursor.execute("SELECT COUNT(*) FROM conversation_history")
|
||||
total_conversations = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('SELECT COUNT(*) FROM conversation_messages')
|
||||
cursor.execute("SELECT COUNT(*) FROM conversation_messages")
|
||||
total_messages = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('SELECT SUM(message_count) FROM conversation_history')
|
||||
total_message_count = cursor.fetchone()[0] or 0
|
||||
cursor.execute("SELECT SUM(message_count) FROM conversation_history")
|
||||
cursor.fetchone()[0] or 0
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT AVG(message_count) FROM conversation_history WHERE message_count > 0
|
||||
''')
|
||||
"""
|
||||
)
|
||||
avg_messages = cursor.fetchone()[0] or 0
|
||||
|
||||
conn.close()
|
||||
|
||||
return {
|
||||
'total_conversations': total_conversations,
|
||||
'total_messages': total_messages,
|
||||
'average_messages_per_conversation': round(avg_messages, 2)
|
||||
"total_conversations": total_conversations,
|
||||
"total_messages": total_messages,
|
||||
"average_messages_per_conversation": round(avg_messages, 2),
|
||||
}
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
import re
|
||||
import json
|
||||
from typing import List, Dict, Any, Set
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class FactExtractor:
|
||||
def __init__(self):
|
||||
self.fact_patterns = [
|
||||
(r'([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)', 'definition'),
|
||||
(r'([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})', 'temporal'),
|
||||
(r'([A-Z][a-z]+) (invented|created|developed) ([^.]+)', 'attribution'),
|
||||
(r'([^.]+) (costs?|worth) (\$[\d,]+)', 'numeric'),
|
||||
(r'([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)', 'location'),
|
||||
(r"([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)", "definition"),
|
||||
(r"([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})", "temporal"),
|
||||
(r"([A-Z][a-z]+) (invented|created|developed) ([^.]+)", "attribution"),
|
||||
(r"([^.]+) (costs?|worth) (\$[\d,]+)", "numeric"),
|
||||
(r"([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)", "location"),
|
||||
]
|
||||
|
||||
def extract_facts(self, text: str) -> List[Dict[str, Any]]:
|
||||
@ -19,27 +19,31 @@ class FactExtractor:
|
||||
for pattern, fact_type in self.fact_patterns:
|
||||
matches = re.finditer(pattern, text)
|
||||
for match in matches:
|
||||
facts.append({
|
||||
'type': fact_type,
|
||||
'text': match.group(0),
|
||||
'components': match.groups(),
|
||||
'confidence': 0.7
|
||||
})
|
||||
facts.append(
|
||||
{
|
||||
"type": fact_type,
|
||||
"text": match.group(0),
|
||||
"components": match.groups(),
|
||||
"confidence": 0.7,
|
||||
}
|
||||
)
|
||||
|
||||
noun_phrases = self._extract_noun_phrases(text)
|
||||
for phrase in noun_phrases:
|
||||
if len(phrase.split()) >= 2:
|
||||
facts.append({
|
||||
'type': 'entity',
|
||||
'text': phrase,
|
||||
'components': [phrase],
|
||||
'confidence': 0.5
|
||||
})
|
||||
facts.append(
|
||||
{
|
||||
"type": "entity",
|
||||
"text": phrase,
|
||||
"components": [phrase],
|
||||
"confidence": 0.5,
|
||||
}
|
||||
)
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_noun_phrases(self, text: str) -> List[str]:
|
||||
sentences = re.split(r'[.!?]', text)
|
||||
sentences = re.split(r"[.!?]", text)
|
||||
phrases = []
|
||||
|
||||
for sentence in sentences:
|
||||
@ -51,25 +55,73 @@ class FactExtractor:
|
||||
current_phrase.append(word)
|
||||
else:
|
||||
if len(current_phrase) >= 2:
|
||||
phrases.append(' '.join(current_phrase))
|
||||
phrases.append(" ".join(current_phrase))
|
||||
current_phrase = []
|
||||
|
||||
if len(current_phrase) >= 2:
|
||||
phrases.append(' '.join(current_phrase))
|
||||
phrases.append(" ".join(current_phrase))
|
||||
|
||||
return list(set(phrases))
|
||||
|
||||
def extract_key_terms(self, text: str, top_k: int = 10) -> List[tuple]:
|
||||
words = re.findall(r'\b[a-z]{4,}\b', text.lower())
|
||||
words = re.findall(r"\b[a-z]{4,}\b", text.lower())
|
||||
|
||||
stopwords = {
|
||||
'this', 'that', 'these', 'those', 'what', 'which', 'where', 'when',
|
||||
'with', 'from', 'have', 'been', 'were', 'will', 'would', 'could',
|
||||
'should', 'about', 'their', 'there', 'other', 'than', 'then', 'them',
|
||||
'some', 'more', 'very', 'such', 'into', 'through', 'during', 'before',
|
||||
'after', 'above', 'below', 'between', 'under', 'again', 'further',
|
||||
'once', 'here', 'both', 'each', 'doing', 'only', 'over', 'same',
|
||||
'being', 'does', 'just', 'also', 'make', 'made', 'know', 'like'
|
||||
"this",
|
||||
"that",
|
||||
"these",
|
||||
"those",
|
||||
"what",
|
||||
"which",
|
||||
"where",
|
||||
"when",
|
||||
"with",
|
||||
"from",
|
||||
"have",
|
||||
"been",
|
||||
"were",
|
||||
"will",
|
||||
"would",
|
||||
"could",
|
||||
"should",
|
||||
"about",
|
||||
"their",
|
||||
"there",
|
||||
"other",
|
||||
"than",
|
||||
"then",
|
||||
"them",
|
||||
"some",
|
||||
"more",
|
||||
"very",
|
||||
"such",
|
||||
"into",
|
||||
"through",
|
||||
"during",
|
||||
"before",
|
||||
"after",
|
||||
"above",
|
||||
"below",
|
||||
"between",
|
||||
"under",
|
||||
"again",
|
||||
"further",
|
||||
"once",
|
||||
"here",
|
||||
"both",
|
||||
"each",
|
||||
"doing",
|
||||
"only",
|
||||
"over",
|
||||
"same",
|
||||
"being",
|
||||
"does",
|
||||
"just",
|
||||
"also",
|
||||
"make",
|
||||
"made",
|
||||
"know",
|
||||
"like",
|
||||
}
|
||||
|
||||
filtered_words = [w for w in words if w not in stopwords]
|
||||
@ -85,57 +137,120 @@ class FactExtractor:
|
||||
relationships = []
|
||||
|
||||
relationship_patterns = [
|
||||
(r'([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)', 'employment'),
|
||||
(r'([A-Z][a-z]+) (owns|has|possesses) ([^.]+)', 'ownership'),
|
||||
(r'([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)', 'location'),
|
||||
(r'([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)', 'usage'),
|
||||
(
|
||||
r"([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)",
|
||||
"employment",
|
||||
),
|
||||
(r"([A-Z][a-z]+) (owns|has|possesses) ([^.]+)", "ownership"),
|
||||
(
|
||||
r"([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)",
|
||||
"location",
|
||||
),
|
||||
(r"([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)", "usage"),
|
||||
]
|
||||
|
||||
for pattern, rel_type in relationship_patterns:
|
||||
matches = re.finditer(pattern, text)
|
||||
for match in matches:
|
||||
relationships.append({
|
||||
'type': rel_type,
|
||||
'subject': match.group(1),
|
||||
'predicate': match.group(2),
|
||||
'object': match.group(3),
|
||||
'confidence': 0.6
|
||||
})
|
||||
relationships.append(
|
||||
{
|
||||
"type": rel_type,
|
||||
"subject": match.group(1),
|
||||
"predicate": match.group(2),
|
||||
"object": match.group(3),
|
||||
"confidence": 0.6,
|
||||
}
|
||||
)
|
||||
|
||||
return relationships
|
||||
|
||||
def extract_metadata(self, text: str) -> Dict[str, Any]:
|
||||
word_count = len(text.split())
|
||||
sentence_count = len(re.split(r'[.!?]', text))
|
||||
sentence_count = len(re.split(r"[.!?]", text))
|
||||
|
||||
urls = re.findall(r'https?://[^\s]+', text)
|
||||
email_addresses = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text)
|
||||
dates = re.findall(r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b', text)
|
||||
numbers = re.findall(r'\b\d+(?:,\d{3})*(?:\.\d+)?\b', text)
|
||||
urls = re.findall(r"https?://[^\s]+", text)
|
||||
email_addresses = re.findall(
|
||||
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text
|
||||
)
|
||||
dates = re.findall(
|
||||
r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text
|
||||
)
|
||||
numbers = re.findall(r"\b\d+(?:,\d{3})*(?:\.\d+)?\b", text)
|
||||
|
||||
return {
|
||||
'word_count': word_count,
|
||||
'sentence_count': sentence_count,
|
||||
'avg_words_per_sentence': round(word_count / max(sentence_count, 1), 2),
|
||||
'urls': urls,
|
||||
'email_addresses': email_addresses,
|
||||
'dates': dates,
|
||||
'numeric_values': numbers,
|
||||
'has_code': bool(re.search(r'```|def |class |import |function ', text)),
|
||||
'has_questions': bool(re.search(r'\?', text))
|
||||
"word_count": word_count,
|
||||
"sentence_count": sentence_count,
|
||||
"avg_words_per_sentence": round(word_count / max(sentence_count, 1), 2),
|
||||
"urls": urls,
|
||||
"email_addresses": email_addresses,
|
||||
"dates": dates,
|
||||
"numeric_values": numbers,
|
||||
"has_code": bool(re.search(r"```|def |class |import |function ", text)),
|
||||
"has_questions": bool(re.search(r"\?", text)),
|
||||
}
|
||||
|
||||
def categorize_content(self, text: str) -> List[str]:
|
||||
categories = []
|
||||
|
||||
category_keywords = {
|
||||
'programming': ['code', 'function', 'class', 'variable', 'programming', 'software', 'debug'],
|
||||
'data': ['data', 'database', 'query', 'table', 'record', 'statistics', 'analysis'],
|
||||
'documentation': ['documentation', 'guide', 'tutorial', 'manual', 'readme', 'explain'],
|
||||
'configuration': ['config', 'settings', 'configuration', 'setup', 'install', 'deployment'],
|
||||
'testing': ['test', 'testing', 'validate', 'verification', 'quality', 'assertion'],
|
||||
'research': ['research', 'study', 'analysis', 'investigation', 'findings', 'results'],
|
||||
'planning': ['plan', 'planning', 'schedule', 'roadmap', 'milestone', 'timeline'],
|
||||
"programming": [
|
||||
"code",
|
||||
"function",
|
||||
"class",
|
||||
"variable",
|
||||
"programming",
|
||||
"software",
|
||||
"debug",
|
||||
],
|
||||
"data": [
|
||||
"data",
|
||||
"database",
|
||||
"query",
|
||||
"table",
|
||||
"record",
|
||||
"statistics",
|
||||
"analysis",
|
||||
],
|
||||
"documentation": [
|
||||
"documentation",
|
||||
"guide",
|
||||
"tutorial",
|
||||
"manual",
|
||||
"readme",
|
||||
"explain",
|
||||
],
|
||||
"configuration": [
|
||||
"config",
|
||||
"settings",
|
||||
"configuration",
|
||||
"setup",
|
||||
"install",
|
||||
"deployment",
|
||||
],
|
||||
"testing": [
|
||||
"test",
|
||||
"testing",
|
||||
"validate",
|
||||
"verification",
|
||||
"quality",
|
||||
"assertion",
|
||||
],
|
||||
"research": [
|
||||
"research",
|
||||
"study",
|
||||
"analysis",
|
||||
"investigation",
|
||||
"findings",
|
||||
"results",
|
||||
],
|
||||
"planning": [
|
||||
"plan",
|
||||
"planning",
|
||||
"schedule",
|
||||
"roadmap",
|
||||
"milestone",
|
||||
"timeline",
|
||||
],
|
||||
}
|
||||
|
||||
text_lower = text.lower()
|
||||
@ -143,4 +258,4 @@ class FactExtractor:
|
||||
if any(keyword in text_lower for keyword in keywords):
|
||||
categories.append(category)
|
||||
|
||||
return categories if categories else ['general']
|
||||
return categories if categories else ["general"]
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .semantic_index import SemanticIndex
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeEntry:
|
||||
entry_id: str
|
||||
@ -18,16 +20,17 @@ class KnowledgeEntry:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'entry_id': self.entry_id,
|
||||
'category': self.category,
|
||||
'content': self.content,
|
||||
'metadata': self.metadata,
|
||||
'created_at': self.created_at,
|
||||
'updated_at': self.updated_at,
|
||||
'access_count': self.access_count,
|
||||
'importance_score': self.importance_score
|
||||
"entry_id": self.entry_id,
|
||||
"category": self.category,
|
||||
"content": self.content,
|
||||
"metadata": self.metadata,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
"access_count": self.access_count,
|
||||
"importance_score": self.importance_score,
|
||||
}
|
||||
|
||||
|
||||
class KnowledgeStore:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
@ -39,7 +42,8 @@ class KnowledgeStore:
|
||||
def _initialize_store(self):
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS knowledge_entries (
|
||||
entry_id TEXT PRIMARY KEY,
|
||||
category TEXT NOT NULL,
|
||||
@ -50,44 +54,54 @@ class KnowledgeStore:
|
||||
access_count INTEGER DEFAULT 0,
|
||||
importance_score REAL DEFAULT 1.0
|
||||
)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_category ON knowledge_entries(category)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_importance ON knowledge_entries(importance_score DESC)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def _load_index(self):
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('SELECT entry_id, content FROM knowledge_entries')
|
||||
cursor.execute("SELECT entry_id, content FROM knowledge_entries")
|
||||
for row in cursor.fetchall():
|
||||
self.semantic_index.add_document(row[0], row[1])
|
||||
|
||||
def add_entry(self, entry: KnowledgeEntry):
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO knowledge_entries
|
||||
(entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
entry.entry_id,
|
||||
entry.category,
|
||||
entry.content,
|
||||
json.dumps(entry.metadata),
|
||||
entry.created_at,
|
||||
entry.updated_at,
|
||||
entry.access_count,
|
||||
entry.importance_score
|
||||
))
|
||||
""",
|
||||
(
|
||||
entry.entry_id,
|
||||
entry.category,
|
||||
entry.content,
|
||||
json.dumps(entry.metadata),
|
||||
entry.created_at,
|
||||
entry.updated_at,
|
||||
entry.access_count,
|
||||
entry.importance_score,
|
||||
),
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
@ -96,20 +110,26 @@ class KnowledgeStore:
|
||||
def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
|
||||
FROM knowledge_entries
|
||||
WHERE entry_id = ?
|
||||
''', (entry_id,))
|
||||
""",
|
||||
(entry_id,),
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE knowledge_entries
|
||||
SET access_count = access_count + 1
|
||||
WHERE entry_id = ?
|
||||
''', (entry_id,))
|
||||
""",
|
||||
(entry_id,),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
return KnowledgeEntry(
|
||||
@ -120,13 +140,14 @@ class KnowledgeStore:
|
||||
created_at=row[4],
|
||||
updated_at=row[5],
|
||||
access_count=row[6] + 1,
|
||||
importance_score=row[7]
|
||||
importance_score=row[7],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def search_entries(self, query: str, category: Optional[str] = None,
|
||||
top_k: int = 5) -> List[KnowledgeEntry]:
|
||||
def search_entries(
|
||||
self, query: str, category: Optional[str] = None, top_k: int = 5
|
||||
) -> List[KnowledgeEntry]:
|
||||
search_results = self.semantic_index.search(query, top_k * 2)
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
@ -134,17 +155,23 @@ class KnowledgeStore:
|
||||
entries = []
|
||||
for entry_id, score in search_results:
|
||||
if category:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
|
||||
FROM knowledge_entries
|
||||
WHERE entry_id = ? AND category = ?
|
||||
''', (entry_id, category))
|
||||
""",
|
||||
(entry_id, category),
|
||||
)
|
||||
else:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
|
||||
FROM knowledge_entries
|
||||
WHERE entry_id = ?
|
||||
''', (entry_id,))
|
||||
""",
|
||||
(entry_id,),
|
||||
)
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
@ -156,7 +183,7 @@ class KnowledgeStore:
|
||||
created_at=row[4],
|
||||
updated_at=row[5],
|
||||
access_count=row[6],
|
||||
importance_score=row[7]
|
||||
importance_score=row[7],
|
||||
)
|
||||
entries.append(entry)
|
||||
|
||||
@ -168,44 +195,52 @@ class KnowledgeStore:
|
||||
def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
|
||||
FROM knowledge_entries
|
||||
WHERE category = ?
|
||||
ORDER BY importance_score DESC, created_at DESC
|
||||
LIMIT ?
|
||||
''', (category, limit))
|
||||
""",
|
||||
(category, limit),
|
||||
)
|
||||
|
||||
entries = []
|
||||
for row in cursor.fetchall():
|
||||
entries.append(KnowledgeEntry(
|
||||
entry_id=row[0],
|
||||
category=row[1],
|
||||
content=row[2],
|
||||
metadata=json.loads(row[3]) if row[3] else {},
|
||||
created_at=row[4],
|
||||
updated_at=row[5],
|
||||
access_count=row[6],
|
||||
importance_score=row[7]
|
||||
))
|
||||
entries.append(
|
||||
KnowledgeEntry(
|
||||
entry_id=row[0],
|
||||
category=row[1],
|
||||
content=row[2],
|
||||
metadata=json.loads(row[3]) if row[3] else {},
|
||||
created_at=row[4],
|
||||
updated_at=row[5],
|
||||
access_count=row[6],
|
||||
importance_score=row[7],
|
||||
)
|
||||
)
|
||||
|
||||
return entries
|
||||
|
||||
def update_importance(self, entry_id: str, importance_score: float):
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE knowledge_entries
|
||||
SET importance_score = ?, updated_at = ?
|
||||
WHERE entry_id = ?
|
||||
''', (importance_score, time.time(), entry_id))
|
||||
""",
|
||||
(importance_score, time.time(), entry_id),
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def delete_entry(self, entry_id: str) -> bool:
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,))
|
||||
cursor.execute("DELETE FROM knowledge_entries WHERE entry_id = ?", (entry_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
|
||||
self.conn.commit()
|
||||
@ -218,27 +253,29 @@ class KnowledgeStore:
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('SELECT COUNT(*) FROM knowledge_entries')
|
||||
cursor.execute("SELECT COUNT(*) FROM knowledge_entries")
|
||||
total_entries = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('SELECT COUNT(DISTINCT category) FROM knowledge_entries')
|
||||
cursor.execute("SELECT COUNT(DISTINCT category) FROM knowledge_entries")
|
||||
total_categories = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT category, COUNT(*) as count
|
||||
FROM knowledge_entries
|
||||
GROUP BY category
|
||||
ORDER BY count DESC
|
||||
''')
|
||||
"""
|
||||
)
|
||||
category_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
cursor.execute('SELECT SUM(access_count) FROM knowledge_entries')
|
||||
cursor.execute("SELECT SUM(access_count) FROM knowledge_entries")
|
||||
total_accesses = cursor.fetchone()[0] or 0
|
||||
|
||||
return {
|
||||
'total_entries': total_entries,
|
||||
'total_categories': total_categories,
|
||||
'category_distribution': category_counts,
|
||||
'total_accesses': total_accesses,
|
||||
'vocabulary_size': len(self.semantic_index.vocabulary)
|
||||
"total_entries": total_entries,
|
||||
"total_categories": total_categories,
|
||||
"category_distribution": category_counts,
|
||||
"total_accesses": total_accesses,
|
||||
"vocabulary_size": len(self.semantic_index.vocabulary),
|
||||
}
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import math
|
||||
import re
|
||||
from collections import Counter, defaultdict
|
||||
from typing import List, Dict, Tuple, Set
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
|
||||
class SemanticIndex:
|
||||
def __init__(self):
|
||||
@ -12,7 +13,7 @@ class SemanticIndex:
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
text = text.lower()
|
||||
text = re.sub(r'[^a-z0-9\s]', ' ', text)
|
||||
text = re.sub(r"[^a-z0-9\s]", " ", text)
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
@ -78,8 +79,12 @@ class SemanticIndex:
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
return scores[:top_k]
|
||||
|
||||
def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
|
||||
dot_product = sum(vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2))
|
||||
def _cosine_similarity(
|
||||
self, vec1: Dict[str, float], vec2: Dict[str, float]
|
||||
) -> float:
|
||||
dot_product = sum(
|
||||
vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2)
|
||||
)
|
||||
norm1 = math.sqrt(sum(val**2 for val in vec1.values()))
|
||||
norm2 = math.sqrt(sum(val**2 for val in vec2.values()))
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import sys
|
||||
import subprocess
|
||||
import signal
|
||||
import os
|
||||
from pr.ui import Colors
|
||||
from collections import defaultdict
|
||||
from pr.tools.process_handlers import get_handler_for_process, detect_process_type
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
from pr.tools.process_handlers import detect_process_type, get_handler_for_process
|
||||
from pr.tools.prompt_detection import get_global_detector
|
||||
from pr.ui import Colors
|
||||
|
||||
|
||||
class TerminalMultiplexer:
|
||||
def __init__(self, name, show_output=True):
|
||||
@ -21,17 +20,19 @@ class TerminalMultiplexer:
|
||||
self.active = True
|
||||
self.lock = threading.Lock()
|
||||
self.metadata = {
|
||||
'start_time': time.time(),
|
||||
'last_activity': time.time(),
|
||||
'interaction_count': 0,
|
||||
'process_type': 'unknown',
|
||||
'state': 'active'
|
||||
"start_time": time.time(),
|
||||
"last_activity": time.time(),
|
||||
"interaction_count": 0,
|
||||
"process_type": "unknown",
|
||||
"state": "active",
|
||||
}
|
||||
self.handler = None
|
||||
self.prompt_detector = get_global_detector()
|
||||
|
||||
if self.show_output:
|
||||
self.display_thread = threading.Thread(target=self._display_worker, daemon=True)
|
||||
self.display_thread = threading.Thread(
|
||||
target=self._display_worker, daemon=True
|
||||
)
|
||||
self.display_thread.start()
|
||||
|
||||
def _display_worker(self):
|
||||
@ -47,7 +48,9 @@ class TerminalMultiplexer:
|
||||
try:
|
||||
line = self.stderr_queue.get(timeout=0.1)
|
||||
if line:
|
||||
sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}")
|
||||
sys.stderr.write(
|
||||
f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}"
|
||||
)
|
||||
sys.stderr.flush()
|
||||
except queue.Empty:
|
||||
pass
|
||||
@ -55,40 +58,44 @@ class TerminalMultiplexer:
|
||||
def write_stdout(self, data):
|
||||
with self.lock:
|
||||
self.stdout_buffer.append(data)
|
||||
self.metadata['last_activity'] = time.time()
|
||||
self.metadata["last_activity"] = time.time()
|
||||
# Update handler state if available
|
||||
if self.handler:
|
||||
self.handler.update_state(data)
|
||||
# Update prompt detector
|
||||
self.prompt_detector.update_session_state(self.name, data, self.metadata['process_type'])
|
||||
self.prompt_detector.update_session_state(
|
||||
self.name, data, self.metadata["process_type"]
|
||||
)
|
||||
if self.show_output:
|
||||
self.stdout_queue.put(data)
|
||||
|
||||
def write_stderr(self, data):
|
||||
with self.lock:
|
||||
self.stderr_buffer.append(data)
|
||||
self.metadata['last_activity'] = time.time()
|
||||
self.metadata["last_activity"] = time.time()
|
||||
# Update handler state if available
|
||||
if self.handler:
|
||||
self.handler.update_state(data)
|
||||
# Update prompt detector
|
||||
self.prompt_detector.update_session_state(self.name, data, self.metadata['process_type'])
|
||||
self.prompt_detector.update_session_state(
|
||||
self.name, data, self.metadata["process_type"]
|
||||
)
|
||||
if self.show_output:
|
||||
self.stderr_queue.put(data)
|
||||
|
||||
def get_stdout(self):
|
||||
with self.lock:
|
||||
return ''.join(self.stdout_buffer)
|
||||
return "".join(self.stdout_buffer)
|
||||
|
||||
def get_stderr(self):
|
||||
with self.lock:
|
||||
return ''.join(self.stderr_buffer)
|
||||
return "".join(self.stderr_buffer)
|
||||
|
||||
def get_all_output(self):
|
||||
with self.lock:
|
||||
return {
|
||||
'stdout': ''.join(self.stdout_buffer),
|
||||
'stderr': ''.join(self.stderr_buffer)
|
||||
"stdout": "".join(self.stdout_buffer),
|
||||
"stderr": "".join(self.stderr_buffer),
|
||||
}
|
||||
|
||||
def get_metadata(self):
|
||||
@ -102,31 +109,32 @@ class TerminalMultiplexer:
|
||||
def set_process_type(self, process_type):
|
||||
"""Set the process type and initialize appropriate handler."""
|
||||
with self.lock:
|
||||
self.metadata['process_type'] = process_type
|
||||
self.metadata["process_type"] = process_type
|
||||
self.handler = get_handler_for_process(process_type, self)
|
||||
|
||||
def send_input(self, input_data):
|
||||
if hasattr(self, 'process') and self.process.poll() is None:
|
||||
if hasattr(self, "process") and self.process.poll() is None:
|
||||
try:
|
||||
self.process.stdin.write(input_data + '\n')
|
||||
self.process.stdin.write(input_data + "\n")
|
||||
self.process.stdin.flush()
|
||||
with self.lock:
|
||||
self.metadata['last_activity'] = time.time()
|
||||
self.metadata['interaction_count'] += 1
|
||||
self.metadata["last_activity"] = time.time()
|
||||
self.metadata["interaction_count"] += 1
|
||||
except Exception as e:
|
||||
self.write_stderr(f"Error sending input: {e}")
|
||||
else:
|
||||
# This will be implemented when we have a process attached
|
||||
# For now, just update activity
|
||||
with self.lock:
|
||||
self.metadata['last_activity'] = time.time()
|
||||
self.metadata['interaction_count'] += 1
|
||||
self.metadata["last_activity"] = time.time()
|
||||
self.metadata["interaction_count"] += 1
|
||||
|
||||
def close(self):
|
||||
self.active = False
|
||||
if hasattr(self, 'display_thread'):
|
||||
if hasattr(self, "display_thread"):
|
||||
self.display_thread.join(timeout=1)
|
||||
|
||||
|
||||
_multiplexers = {}
|
||||
_mux_counter = 0
|
||||
_mux_lock = threading.Lock()
|
||||
@ -134,6 +142,7 @@ _background_monitor = None
|
||||
_monitor_active = False
|
||||
_monitor_interval = 0.2 # 200ms
|
||||
|
||||
|
||||
def create_multiplexer(name=None, show_output=True):
|
||||
global _mux_counter
|
||||
with _mux_lock:
|
||||
@ -144,44 +153,50 @@ def create_multiplexer(name=None, show_output=True):
|
||||
_multiplexers[name] = mux
|
||||
return name, mux
|
||||
|
||||
|
||||
def get_multiplexer(name):
|
||||
return _multiplexers.get(name)
|
||||
|
||||
|
||||
def close_multiplexer(name):
|
||||
mux = _multiplexers.get(name)
|
||||
if mux:
|
||||
mux.close()
|
||||
del _multiplexers[name]
|
||||
|
||||
|
||||
def get_all_multiplexer_states():
|
||||
with _mux_lock:
|
||||
states = {}
|
||||
for name, mux in _multiplexers.items():
|
||||
states[name] = {
|
||||
'metadata': mux.get_metadata(),
|
||||
'output_summary': {
|
||||
'stdout_lines': len(mux.stdout_buffer),
|
||||
'stderr_lines': len(mux.stderr_buffer)
|
||||
}
|
||||
"metadata": mux.get_metadata(),
|
||||
"output_summary": {
|
||||
"stdout_lines": len(mux.stdout_buffer),
|
||||
"stderr_lines": len(mux.stderr_buffer),
|
||||
},
|
||||
}
|
||||
return states
|
||||
|
||||
|
||||
def cleanup_all_multiplexers():
|
||||
for mux in list(_multiplexers.values()):
|
||||
mux.close()
|
||||
_multiplexers.clear()
|
||||
|
||||
|
||||
# Background process management
|
||||
_background_processes = {}
|
||||
_process_lock = threading.Lock()
|
||||
|
||||
|
||||
class BackgroundProcess:
|
||||
def __init__(self, name, command):
|
||||
self.name = name
|
||||
self.command = command
|
||||
self.process = None
|
||||
self.multiplexer = None
|
||||
self.status = 'starting'
|
||||
self.status = "starting"
|
||||
self.start_time = time.time()
|
||||
self.end_time = None
|
||||
|
||||
@ -205,27 +220,27 @@ class BackgroundProcess:
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
universal_newlines=True
|
||||
universal_newlines=True,
|
||||
)
|
||||
|
||||
self.status = 'running'
|
||||
self.status = "running"
|
||||
|
||||
# Start output monitoring threads
|
||||
threading.Thread(target=self._monitor_stdout, daemon=True).start()
|
||||
threading.Thread(target=self._monitor_stderr, daemon=True).start()
|
||||
|
||||
return {'status': 'success', 'pid': self.process.pid}
|
||||
return {"status": "success", "pid": self.process.pid}
|
||||
|
||||
except Exception as e:
|
||||
self.status = 'error'
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
self.status = "error"
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def _monitor_stdout(self):
|
||||
"""Monitor stdout from the process."""
|
||||
try:
|
||||
for line in iter(self.process.stdout.readline, ''):
|
||||
for line in iter(self.process.stdout.readline, ""):
|
||||
if line:
|
||||
self.multiplexer.write_stdout(line.rstrip('\n\r'))
|
||||
self.multiplexer.write_stdout(line.rstrip("\n\r"))
|
||||
except Exception as e:
|
||||
self.write_stderr(f"Error reading stdout: {e}")
|
||||
finally:
|
||||
@ -234,29 +249,33 @@ class BackgroundProcess:
|
||||
def _monitor_stderr(self):
|
||||
"""Monitor stderr from the process."""
|
||||
try:
|
||||
for line in iter(self.process.stderr.readline, ''):
|
||||
for line in iter(self.process.stderr.readline, ""):
|
||||
if line:
|
||||
self.multiplexer.write_stderr(line.rstrip('\n\r'))
|
||||
self.multiplexer.write_stderr(line.rstrip("\n\r"))
|
||||
except Exception as e:
|
||||
self.write_stderr(f"Error reading stderr: {e}")
|
||||
|
||||
def _check_completion(self):
|
||||
"""Check if process has completed."""
|
||||
if self.process and self.process.poll() is not None:
|
||||
self.status = 'completed'
|
||||
self.status = "completed"
|
||||
self.end_time = time.time()
|
||||
|
||||
def get_info(self):
|
||||
"""Get process information."""
|
||||
self._check_completion()
|
||||
return {
|
||||
'name': self.name,
|
||||
'command': self.command,
|
||||
'status': self.status,
|
||||
'pid': self.process.pid if self.process else None,
|
||||
'start_time': self.start_time,
|
||||
'end_time': self.end_time,
|
||||
'runtime': time.time() - self.start_time if not self.end_time else self.end_time - self.start_time
|
||||
"name": self.name,
|
||||
"command": self.command,
|
||||
"status": self.status,
|
||||
"pid": self.process.pid if self.process else None,
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"runtime": (
|
||||
time.time() - self.start_time
|
||||
if not self.end_time
|
||||
else self.end_time - self.start_time
|
||||
),
|
||||
}
|
||||
|
||||
def get_output(self, lines=None):
|
||||
@ -265,8 +284,8 @@ class BackgroundProcess:
|
||||
return []
|
||||
|
||||
all_output = self.multiplexer.get_all_output()
|
||||
stdout_lines = all_output['stdout'].split('\n') if all_output['stdout'] else []
|
||||
stderr_lines = all_output['stderr'].split('\n') if all_output['stderr'] else []
|
||||
stdout_lines = all_output["stdout"].split("\n") if all_output["stdout"] else []
|
||||
stderr_lines = all_output["stderr"].split("\n") if all_output["stderr"] else []
|
||||
|
||||
combined = stdout_lines + stderr_lines
|
||||
if lines:
|
||||
@ -276,45 +295,47 @@ class BackgroundProcess:
|
||||
|
||||
def send_input(self, input_text):
|
||||
"""Send input to the process."""
|
||||
if self.process and self.status == 'running':
|
||||
if self.process and self.status == "running":
|
||||
try:
|
||||
self.process.stdin.write(input_text + '\n')
|
||||
self.process.stdin.write(input_text + "\n")
|
||||
self.process.stdin.flush()
|
||||
return {'status': 'success'}
|
||||
return {"status": "success"}
|
||||
except Exception as e:
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
return {'status': 'error', 'error': 'Process not running or no stdin'}
|
||||
return {"status": "error", "error": str(e)}
|
||||
return {"status": "error", "error": "Process not running or no stdin"}
|
||||
|
||||
def kill(self):
|
||||
"""Kill the process."""
|
||||
if self.process and self.status == 'running':
|
||||
if self.process and self.status == "running":
|
||||
try:
|
||||
self.process.terminate()
|
||||
# Wait a bit for graceful termination
|
||||
time.sleep(0.1)
|
||||
if self.process.poll() is None:
|
||||
self.process.kill()
|
||||
self.status = 'killed'
|
||||
self.status = "killed"
|
||||
self.end_time = time.time()
|
||||
return {'status': 'success'}
|
||||
return {"status": "success"}
|
||||
except Exception as e:
|
||||
return {'status': 'error', 'error': str(e)}
|
||||
return {'status': 'error', 'error': 'Process not running'}
|
||||
return {"status": "error", "error": str(e)}
|
||||
return {"status": "error", "error": "Process not running"}
|
||||
|
||||
|
||||
def start_background_process(name, command):
|
||||
"""Start a background process."""
|
||||
with _process_lock:
|
||||
if name in _background_processes:
|
||||
return {'status': 'error', 'error': f'Process {name} already exists'}
|
||||
return {"status": "error", "error": f"Process {name} already exists"}
|
||||
|
||||
process = BackgroundProcess(name, command)
|
||||
result = process.start()
|
||||
|
||||
if result['status'] == 'success':
|
||||
if result["status"] == "success":
|
||||
_background_processes[name] = process
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_all_sessions():
|
||||
"""Get all background process sessions."""
|
||||
with _process_lock:
|
||||
@ -323,23 +344,31 @@ def get_all_sessions():
|
||||
sessions[name] = process.get_info()
|
||||
return sessions
|
||||
|
||||
|
||||
def get_session_info(name):
|
||||
"""Get information about a specific session."""
|
||||
with _process_lock:
|
||||
process = _background_processes.get(name)
|
||||
return process.get_info() if process else None
|
||||
|
||||
|
||||
def get_session_output(name, lines=None):
|
||||
"""Get output from a specific session."""
|
||||
with _process_lock:
|
||||
process = _background_processes.get(name)
|
||||
return process.get_output(lines) if process else None
|
||||
|
||||
|
||||
def send_input_to_session(name, input_text):
|
||||
"""Send input to a background session."""
|
||||
with _process_lock:
|
||||
process = _background_processes.get(name)
|
||||
return process.send_input(input_text) if process else {'status': 'error', 'error': 'Session not found'}
|
||||
return (
|
||||
process.send_input(input_text)
|
||||
if process
|
||||
else {"status": "error", "error": "Session not found"}
|
||||
)
|
||||
|
||||
|
||||
def kill_session(name):
|
||||
"""Kill a background session."""
|
||||
@ -347,7 +376,7 @@ def kill_session(name):
|
||||
process = _background_processes.get(name)
|
||||
if process:
|
||||
result = process.kill()
|
||||
if result['status'] == 'success':
|
||||
if result["status"] == "success":
|
||||
del _background_processes[name]
|
||||
return result
|
||||
return {'status': 'error', 'error': 'Session not found'}
|
||||
return {"status": "error", "error": "Session not found"}
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
import importlib.util
|
||||
from typing import List, Dict, Callable, Any
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from pr.core.logging import get_logger
|
||||
|
||||
logger = get_logger('plugins')
|
||||
logger = get_logger("plugins")
|
||||
|
||||
PLUGINS_DIR = os.path.expanduser("~/.pr/plugins")
|
||||
|
||||
@ -21,7 +22,7 @@ class PluginLoader:
|
||||
logger.info("No plugins directory found")
|
||||
return []
|
||||
|
||||
plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith('.py')]
|
||||
plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith(".py")]
|
||||
|
||||
for plugin_file in plugin_files:
|
||||
try:
|
||||
@ -44,16 +45,20 @@ class PluginLoader:
|
||||
sys.modules[plugin_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
if hasattr(module, 'register_tools'):
|
||||
if hasattr(module, "register_tools"):
|
||||
tools = module.register_tools()
|
||||
if isinstance(tools, list):
|
||||
self.plugin_tools.extend(tools)
|
||||
self.loaded_plugins[plugin_name] = module
|
||||
logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)")
|
||||
else:
|
||||
logger.warning(f"Plugin {plugin_name} register_tools() did not return a list")
|
||||
logger.warning(
|
||||
f"Plugin {plugin_name} register_tools() did not return a list"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Plugin {plugin_name} does not have register_tools() function")
|
||||
logger.warning(
|
||||
f"Plugin {plugin_name} does not have register_tools() function"
|
||||
)
|
||||
|
||||
def get_plugin_function(self, tool_name: str) -> Callable:
|
||||
for plugin_name, module in self.loaded_plugins.items():
|
||||
@ -67,7 +72,7 @@ class PluginLoader:
|
||||
|
||||
|
||||
def create_example_plugin():
|
||||
example_plugin = os.path.join(PLUGINS_DIR, 'example_plugin.py')
|
||||
example_plugin = os.path.join(PLUGINS_DIR, "example_plugin.py")
|
||||
|
||||
if os.path.exists(example_plugin):
|
||||
return
|
||||
@ -121,7 +126,7 @@ def register_tools():
|
||||
|
||||
try:
|
||||
os.makedirs(PLUGINS_DIR, exist_ok=True)
|
||||
with open(example_plugin, 'w') as f:
|
||||
with open(example_plugin, "w") as f:
|
||||
f.write(example_code)
|
||||
logger.info(f"Created example plugin at {example_plugin}")
|
||||
except Exception as e:
|
||||
|
||||
@ -1,25 +1,86 @@
|
||||
from pr.tools.base import get_tools_definition
|
||||
from pr.tools.filesystem import (
|
||||
read_file, write_file, list_directory, mkdir, chdir, getpwd, index_source_directory, search_replace
|
||||
from pr.tools.agents import (
|
||||
collaborate_agents,
|
||||
create_agent,
|
||||
execute_agent_task,
|
||||
list_agents,
|
||||
remove_agent,
|
||||
)
|
||||
from pr.tools.base import get_tools_definition
|
||||
from pr.tools.command import (
|
||||
kill_process,
|
||||
run_command,
|
||||
run_command_interactive,
|
||||
tail_process,
|
||||
)
|
||||
from pr.tools.database import db_get, db_query, db_set
|
||||
from pr.tools.editor import (
|
||||
close_editor,
|
||||
editor_insert_text,
|
||||
editor_replace_text,
|
||||
editor_search,
|
||||
open_editor,
|
||||
)
|
||||
from pr.tools.filesystem import (
|
||||
chdir,
|
||||
getpwd,
|
||||
index_source_directory,
|
||||
list_directory,
|
||||
mkdir,
|
||||
read_file,
|
||||
search_replace,
|
||||
write_file,
|
||||
)
|
||||
from pr.tools.memory import (
|
||||
add_knowledge_entry,
|
||||
delete_knowledge_entry,
|
||||
get_knowledge_by_category,
|
||||
get_knowledge_entry,
|
||||
get_knowledge_statistics,
|
||||
search_knowledge,
|
||||
update_knowledge_importance,
|
||||
)
|
||||
from pr.tools.command import run_command, run_command_interactive, tail_process, kill_process
|
||||
from pr.tools.editor import open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor
|
||||
from pr.tools.database import db_set, db_get, db_query
|
||||
from pr.tools.web import http_fetch, web_search, web_search_news
|
||||
from pr.tools.python_exec import python_exec
|
||||
from pr.tools.patch import apply_patch, create_diff
|
||||
from pr.tools.agents import create_agent, list_agents, execute_agent_task, remove_agent, collaborate_agents
|
||||
from pr.tools.memory import add_knowledge_entry, get_knowledge_entry, search_knowledge, get_knowledge_by_category, update_knowledge_importance, delete_knowledge_entry, get_knowledge_statistics
|
||||
from pr.tools.python_exec import python_exec
|
||||
from pr.tools.web import http_fetch, web_search, web_search_news
|
||||
|
||||
__all__ = [
|
||||
'get_tools_definition',
|
||||
'read_file', 'write_file', 'list_directory', 'mkdir', 'chdir', 'getpwd', 'index_source_directory', 'search_replace',
|
||||
'open_editor', 'editor_insert_text', 'editor_replace_text', 'editor_search','close_editor',
|
||||
'run_command', 'run_command_interactive',
|
||||
'db_set', 'db_get', 'db_query',
|
||||
'http_fetch', 'web_search', 'web_search_news',
|
||||
'python_exec','tail_process', 'kill_process',
|
||||
'apply_patch', 'create_diff',
|
||||
'create_agent', 'list_agents', 'execute_agent_task', 'remove_agent', 'collaborate_agents',
|
||||
'add_knowledge_entry', 'get_knowledge_entry', 'search_knowledge', 'get_knowledge_by_category', 'update_knowledge_importance', 'delete_knowledge_entry', 'get_knowledge_statistics'
|
||||
"get_tools_definition",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"list_directory",
|
||||
"mkdir",
|
||||
"chdir",
|
||||
"getpwd",
|
||||
"index_source_directory",
|
||||
"search_replace",
|
||||
"open_editor",
|
||||
"editor_insert_text",
|
||||
"editor_replace_text",
|
||||
"editor_search",
|
||||
"close_editor",
|
||||
"run_command",
|
||||
"run_command_interactive",
|
||||
"db_set",
|
||||
"db_get",
|
||||
"db_query",
|
||||
"http_fetch",
|
||||
"web_search",
|
||||
"web_search_news",
|
||||
"python_exec",
|
||||
"tail_process",
|
||||
"kill_process",
|
||||
"apply_patch",
|
||||
"create_diff",
|
||||
"create_agent",
|
||||
"list_agents",
|
||||
"execute_agent_task",
|
||||
"remove_agent",
|
||||
"collaborate_agents",
|
||||
"add_knowledge_entry",
|
||||
"get_knowledge_entry",
|
||||
"search_knowledge",
|
||||
"get_knowledge_by_category",
|
||||
"update_knowledge_importance",
|
||||
"delete_knowledge_entry",
|
||||
"get_knowledge_statistics",
|
||||
]
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pr.agents.agent_manager import AgentManager
|
||||
from pr.core.api import call_api
|
||||
|
||||
|
||||
def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]:
|
||||
"""Create a new agent with the specified role."""
|
||||
try:
|
||||
# Get db_path from environment or default
|
||||
db_path = os.environ.get('ASSISTANT_DB_PATH', '~/.assistant_db.sqlite')
|
||||
db_path = os.environ.get("ASSISTANT_DB_PATH", "~/.assistant_db.sqlite")
|
||||
db_path = os.path.expanduser(db_path)
|
||||
|
||||
manager = AgentManager(db_path, call_api)
|
||||
@ -16,47 +18,57 @@ def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def list_agents() -> Dict[str, Any]:
|
||||
"""List all active agents."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
manager = AgentManager(db_path, call_api)
|
||||
agents = []
|
||||
for agent_id, agent in manager.active_agents.items():
|
||||
agents.append({
|
||||
"agent_id": agent_id,
|
||||
"role": agent.role.name,
|
||||
"task_count": agent.task_count,
|
||||
"message_count": len(agent.message_history)
|
||||
})
|
||||
agents.append(
|
||||
{
|
||||
"agent_id": agent_id,
|
||||
"role": agent.role.name,
|
||||
"task_count": agent.task_count,
|
||||
"message_count": len(agent.message_history),
|
||||
}
|
||||
)
|
||||
return {"status": "success", "agents": agents}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
|
||||
def execute_agent_task(
|
||||
agent_id: str, task: str, context: Dict[str, Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute a task with the specified agent."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
manager = AgentManager(db_path, call_api)
|
||||
result = manager.execute_agent_task(agent_id, task, context)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def remove_agent(agent_id: str) -> Dict[str, Any]:
|
||||
"""Remove an agent."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
manager = AgentManager(db_path, call_api)
|
||||
success = manager.remove_agent(agent_id)
|
||||
return {"status": "success" if success else "not_found", "agent_id": agent_id}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]:
|
||||
|
||||
def collaborate_agents(
|
||||
orchestrator_id: str, task: str, agent_roles: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""Collaborate multiple agents on a task."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
manager = AgentManager(db_path, call_api)
|
||||
result = manager.collaborate_agents(orchestrator_id, task, agent_roles)
|
||||
return result
|
||||
|
||||
417
pr/tools/base.py
417
pr/tools/base.py
@ -10,12 +10,12 @@ def get_tools_definition():
|
||||
"properties": {
|
||||
"pid": {
|
||||
"type": "integer",
|
||||
"description": "The process ID returned by run_command when status is 'running'."
|
||||
"description": "The process ID returned by run_command when status is 'running'.",
|
||||
}
|
||||
},
|
||||
"required": ["pid"]
|
||||
}
|
||||
}
|
||||
"required": ["pid"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -27,17 +27,17 @@ def get_tools_definition():
|
||||
"properties": {
|
||||
"pid": {
|
||||
"type": "integer",
|
||||
"description": "The process ID returned by run_command when status is 'running'."
|
||||
"description": "The process ID returned by run_command when status is 'running'.",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Maximum seconds to wait for process completion. Returns partial output if still running.",
|
||||
"default": 30
|
||||
}
|
||||
"default": 30,
|
||||
},
|
||||
},
|
||||
"required": ["pid"]
|
||||
}
|
||||
}
|
||||
"required": ["pid"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -48,11 +48,14 @@ def get_tools_definition():
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "The URL to fetch"},
|
||||
"headers": {"type": "object", "description": "Optional HTTP headers"}
|
||||
"headers": {
|
||||
"type": "object",
|
||||
"description": "Optional HTTP headers",
|
||||
},
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
}
|
||||
"required": ["url"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -62,12 +65,19 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string", "description": "The shell command to execute"},
|
||||
"timeout": {"type": "integer", "description": "Maximum seconds to wait for completion", "default": 30}
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Maximum seconds to wait for completion",
|
||||
"default": 30,
|
||||
},
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
"required": ["command"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -77,11 +87,14 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string", "description": "The interactive command to execute (e.g., vim, nano, top)"}
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The interactive command to execute (e.g., vim, nano, top)",
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
"required": ["command"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -91,12 +104,18 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"session_name": {"type": "string", "description": "The name of the session"},
|
||||
"input_data": {"type": "string", "description": "The input to send to the session"}
|
||||
"session_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the session",
|
||||
},
|
||||
"input_data": {
|
||||
"type": "string",
|
||||
"description": "The input to send to the session",
|
||||
},
|
||||
},
|
||||
"required": ["session_name", "input_data"]
|
||||
}
|
||||
}
|
||||
"required": ["session_name", "input_data"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -106,11 +125,14 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"session_name": {"type": "string", "description": "The name of the session"}
|
||||
"session_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the session",
|
||||
}
|
||||
},
|
||||
"required": ["session_name"]
|
||||
}
|
||||
}
|
||||
"required": ["session_name"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -120,11 +142,14 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"session_name": {"type": "string", "description": "The name of the session"}
|
||||
"session_name": {
|
||||
"type": "string",
|
||||
"description": "The name of the session",
|
||||
}
|
||||
},
|
||||
"required": ["session_name"]
|
||||
}
|
||||
}
|
||||
"required": ["session_name"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -134,11 +159,14 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"}
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file",
|
||||
}
|
||||
},
|
||||
"required": ["filepath"]
|
||||
}
|
||||
}
|
||||
"required": ["filepath"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -148,12 +176,18 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"content": {"type": "string", "description": "Content to write"}
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write",
|
||||
},
|
||||
},
|
||||
"required": ["filepath", "content"]
|
||||
}
|
||||
}
|
||||
"required": ["filepath", "content"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -163,11 +197,19 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Directory path", "default": "."},
|
||||
"recursive": {"type": "boolean", "description": "List recursively", "default": False}
|
||||
}
|
||||
}
|
||||
}
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory path",
|
||||
"default": ".",
|
||||
},
|
||||
"recursive": {
|
||||
"type": "boolean",
|
||||
"description": "List recursively",
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -177,11 +219,14 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path of the directory to create"}
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path of the directory to create",
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -193,17 +238,17 @@ def get_tools_definition():
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to change to"}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "getpwd",
|
||||
"description": "Get the current working directory",
|
||||
"parameters": {"type": "object", "properties": {}}
|
||||
}
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -214,11 +259,11 @@ def get_tools_definition():
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {"type": "string", "description": "The key"},
|
||||
"value": {"type": "string", "description": "The value"}
|
||||
"value": {"type": "string", "description": "The value"},
|
||||
},
|
||||
"required": ["key", "value"]
|
||||
}
|
||||
}
|
||||
"required": ["key", "value"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -227,12 +272,10 @@ def get_tools_definition():
|
||||
"description": "Get a value from the database",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {"type": "string", "description": "The key"}
|
||||
},
|
||||
"required": ["key"]
|
||||
}
|
||||
}
|
||||
"properties": {"key": {"type": "string", "description": "The key"}},
|
||||
"required": ["key"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -244,9 +287,9 @@ def get_tools_definition():
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "SQL query"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -258,9 +301,9 @@ def get_tools_definition():
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -270,11 +313,14 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query for news"}
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query for news",
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -284,11 +330,14 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {"type": "string", "description": "Python code to execute"}
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "Python code to execute",
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}
|
||||
}
|
||||
"required": ["code"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -300,9 +349,9 @@ def get_tools_definition():
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to index"}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -312,13 +361,22 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"old_string": {"type": "string", "description": "String to replace"},
|
||||
"new_string": {"type": "string", "description": "Replacement string"}
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file",
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "String to replace",
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "Replacement string",
|
||||
},
|
||||
},
|
||||
"required": ["filepath", "old_string", "new_string"]
|
||||
}
|
||||
}
|
||||
"required": ["filepath", "old_string", "new_string"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -328,12 +386,18 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file to patch"},
|
||||
"patch_content": {"type": "string", "description": "The patch content as a string"}
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to patch",
|
||||
},
|
||||
"patch_content": {
|
||||
"type": "string",
|
||||
"description": "The patch content as a string",
|
||||
},
|
||||
},
|
||||
"required": ["filepath", "patch_content"]
|
||||
}
|
||||
}
|
||||
"required": ["filepath", "patch_content"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -343,14 +407,28 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file1": {"type": "string", "description": "Path to the first file"},
|
||||
"file2": {"type": "string", "description": "Path to the second file"},
|
||||
"fromfile": {"type": "string", "description": "Label for the first file", "default": "file1"},
|
||||
"tofile": {"type": "string", "description": "Label for the second file", "default": "file2"}
|
||||
"file1": {
|
||||
"type": "string",
|
||||
"description": "Path to the first file",
|
||||
},
|
||||
"file2": {
|
||||
"type": "string",
|
||||
"description": "Path to the second file",
|
||||
},
|
||||
"fromfile": {
|
||||
"type": "string",
|
||||
"description": "Label for the first file",
|
||||
"default": "file1",
|
||||
},
|
||||
"tofile": {
|
||||
"type": "string",
|
||||
"description": "Label for the second file",
|
||||
"default": "file2",
|
||||
},
|
||||
},
|
||||
"required": ["file1", "file2"]
|
||||
}
|
||||
}
|
||||
"required": ["file1", "file2"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -360,11 +438,14 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"}
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file",
|
||||
}
|
||||
},
|
||||
"required": ["filepath"]
|
||||
}
|
||||
}
|
||||
"required": ["filepath"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -374,11 +455,14 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"}
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file",
|
||||
}
|
||||
},
|
||||
"required": ["filepath"]
|
||||
}
|
||||
}
|
||||
"required": ["filepath"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -388,14 +472,23 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file",
|
||||
},
|
||||
"text": {"type": "string", "description": "Text to insert"},
|
||||
"line": {"type": "integer", "description": "Line number (optional)"},
|
||||
"col": {"type": "integer", "description": "Column number (optional)"}
|
||||
"line": {
|
||||
"type": "integer",
|
||||
"description": "Line number (optional)",
|
||||
},
|
||||
"col": {
|
||||
"type": "integer",
|
||||
"description": "Column number (optional)",
|
||||
},
|
||||
},
|
||||
"required": ["filepath", "text"]
|
||||
}
|
||||
}
|
||||
"required": ["filepath", "text"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -405,16 +498,26 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file",
|
||||
},
|
||||
"start_line": {"type": "integer", "description": "Start line"},
|
||||
"start_col": {"type": "integer", "description": "Start column"},
|
||||
"end_line": {"type": "integer", "description": "End line"},
|
||||
"end_col": {"type": "integer", "description": "End column"},
|
||||
"new_text": {"type": "string", "description": "New text"}
|
||||
"new_text": {"type": "string", "description": "New text"},
|
||||
},
|
||||
"required": ["filepath", "start_line", "start_col", "end_line", "end_col", "new_text"]
|
||||
}
|
||||
}
|
||||
"required": [
|
||||
"filepath",
|
||||
"start_line",
|
||||
"start_col",
|
||||
"end_line",
|
||||
"end_col",
|
||||
"new_text",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -424,13 +527,20 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath": {"type": "string", "description": "Path to the file"},
|
||||
"filepath": {
|
||||
"type": "string",
|
||||
"description": "Path to the file",
|
||||
},
|
||||
"pattern": {"type": "string", "description": "Regex pattern"},
|
||||
"start_line": {"type": "integer", "description": "Start line", "default": 0}
|
||||
"start_line": {
|
||||
"type": "integer",
|
||||
"description": "Start line",
|
||||
"default": 0,
|
||||
},
|
||||
},
|
||||
"required": ["filepath", "pattern"]
|
||||
}
|
||||
}
|
||||
"required": ["filepath", "pattern"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -440,24 +550,31 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filepath1": {"type": "string", "description": "Path to the original file"},
|
||||
"filepath2": {"type": "string", "description": "Path to the modified file"},
|
||||
"format_type": {"type": "string", "description": "Display format: 'unified' or 'side-by-side'", "default": "unified"}
|
||||
"filepath1": {
|
||||
"type": "string",
|
||||
"description": "Path to the original file",
|
||||
},
|
||||
"filepath2": {
|
||||
"type": "string",
|
||||
"description": "Path to the modified file",
|
||||
},
|
||||
"format_type": {
|
||||
"type": "string",
|
||||
"description": "Display format: 'unified' or 'side-by-side'",
|
||||
"default": "unified",
|
||||
},
|
||||
},
|
||||
"required": ["filepath1", "filepath2"]
|
||||
}
|
||||
}
|
||||
"required": ["filepath1", "filepath2"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "display_edit_summary",
|
||||
"description": "Display a summary of all edit operations performed during the session",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@ -467,21 +584,21 @@ def get_tools_definition():
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"show_content": {"type": "boolean", "description": "Show content previews", "default": False}
|
||||
}
|
||||
}
|
||||
}
|
||||
"show_content": {
|
||||
"type": "boolean",
|
||||
"description": "Show content previews",
|
||||
"default": False,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "clear_edit_tracker",
|
||||
"description": "Clear the edit tracker to start fresh",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@ -1,21 +1,23 @@
|
||||
import os
|
||||
import select
|
||||
import subprocess
|
||||
import time
|
||||
import select
|
||||
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer
|
||||
from pr.tools.interactive_control import start_interactive_session
|
||||
from pr.config import MAX_CONCURRENT_SESSIONS
|
||||
|
||||
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
|
||||
|
||||
_processes = {}
|
||||
|
||||
def _register_process(pid:int, process):
|
||||
|
||||
def _register_process(pid: int, process):
|
||||
_processes[pid] = process
|
||||
return _processes
|
||||
|
||||
def _get_process(pid:int):
|
||||
|
||||
def _get_process(pid: int):
|
||||
return _processes.get(pid)
|
||||
|
||||
def kill_process(pid:int):
|
||||
|
||||
def kill_process(pid: int):
|
||||
try:
|
||||
process = _get_process(pid)
|
||||
if process:
|
||||
@ -67,7 +69,7 @@ def tail_process(pid: int, timeout: int = 30):
|
||||
"status": "success",
|
||||
"stdout": stdout_content,
|
||||
"stderr": stderr_content,
|
||||
"returncode": process.returncode
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
if time.time() - start_time > timeout_duration:
|
||||
@ -76,10 +78,12 @@ def tail_process(pid: int, timeout: int = 30):
|
||||
"message": "Process is still running. Call tail_process again to continue monitoring.",
|
||||
"stdout_so_far": stdout_content,
|
||||
"stderr_so_far": stderr_content,
|
||||
"pid": pid
|
||||
"pid": pid,
|
||||
}
|
||||
|
||||
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
|
||||
ready, _, _ = select.select(
|
||||
[process.stdout, process.stderr], [], [], 0.1
|
||||
)
|
||||
for pipe in ready:
|
||||
if pipe == process.stdout:
|
||||
line = process.stdout.readline()
|
||||
@ -100,7 +104,13 @@ def tail_process(pid: int, timeout: int = 30):
|
||||
def run_command(command, timeout=30, monitored=False):
|
||||
mux_name = None
|
||||
try:
|
||||
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
_register_process(process.pid, process)
|
||||
|
||||
mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True)
|
||||
@ -129,7 +139,7 @@ def run_command(command, timeout=30, monitored=False):
|
||||
"status": "success",
|
||||
"stdout": stdout_content,
|
||||
"stderr": stderr_content,
|
||||
"returncode": process.returncode
|
||||
"returncode": process.returncode,
|
||||
}
|
||||
|
||||
if time.time() - start_time > timeout_duration:
|
||||
@ -139,7 +149,7 @@ def run_command(command, timeout=30, monitored=False):
|
||||
"stdout_so_far": stdout_content,
|
||||
"stderr_so_far": stderr_content,
|
||||
"pid": process.pid,
|
||||
"mux_name": mux_name
|
||||
"mux_name": mux_name,
|
||||
}
|
||||
|
||||
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
|
||||
@ -158,6 +168,8 @@ def run_command(command, timeout=30, monitored=False):
|
||||
if mux_name:
|
||||
close_multiplexer(mux_name)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def run_command_interactive(command):
|
||||
try:
|
||||
return_code = os.system(command)
|
||||
|
||||
@ -1,18 +1,23 @@
|
||||
import time
|
||||
|
||||
|
||||
def db_set(key, value, db_conn):
|
||||
if not db_conn:
|
||||
return {"status": "error", "error": "Database not initialized"}
|
||||
|
||||
try:
|
||||
cursor = db_conn.cursor()
|
||||
cursor.execute("""INSERT OR REPLACE INTO kv_store (key, value, timestamp)
|
||||
VALUES (?, ?, ?)""", (key, value, time.time()))
|
||||
cursor.execute(
|
||||
"""INSERT OR REPLACE INTO kv_store (key, value, timestamp)
|
||||
VALUES (?, ?, ?)""",
|
||||
(key, value, time.time()),
|
||||
)
|
||||
db_conn.commit()
|
||||
return {"status": "success", "message": f"Set {key}"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def db_get(key, db_conn):
|
||||
if not db_conn:
|
||||
return {"status": "error", "error": "Database not initialized"}
|
||||
@ -28,6 +33,7 @@ def db_get(key, db_conn):
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def db_query(query, db_conn):
|
||||
if not db_conn:
|
||||
return {"status": "error", "error": "Database not initialized"}
|
||||
@ -36,9 +42,11 @@ def db_query(query, db_conn):
|
||||
cursor = db_conn.cursor()
|
||||
cursor.execute(query)
|
||||
|
||||
if query.strip().upper().startswith('SELECT'):
|
||||
if query.strip().upper().startswith("SELECT"):
|
||||
results = cursor.fetchall()
|
||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||
columns = (
|
||||
[desc[0] for desc in cursor.description] if cursor.description else []
|
||||
)
|
||||
return {"status": "success", "columns": columns, "rows": results}
|
||||
else:
|
||||
db_conn.commit()
|
||||
|
||||
@ -1,18 +1,21 @@
|
||||
from pr.editor import RPEditor
|
||||
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer
|
||||
from ..ui.diff_display import display_diff, get_diff_stats
|
||||
from ..ui.edit_feedback import track_edit, tracker
|
||||
from ..tools.patch import display_content_diff
|
||||
import os
|
||||
import os.path
|
||||
|
||||
from pr.editor import RPEditor
|
||||
from pr.multiplexer import close_multiplexer, create_multiplexer, get_multiplexer
|
||||
|
||||
from ..tools.patch import display_content_diff
|
||||
from ..ui.edit_feedback import track_edit, tracker
|
||||
|
||||
_editors = {}
|
||||
|
||||
|
||||
def get_editor(filepath):
|
||||
if filepath not in _editors:
|
||||
_editors[filepath] = RPEditor(filepath)
|
||||
return _editors[filepath]
|
||||
|
||||
|
||||
def close_editor(filepath):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
@ -29,6 +32,7 @@ def close_editor(filepath):
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def open_editor(filepath):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
@ -39,21 +43,28 @@ def open_editor(filepath):
|
||||
mux_name, mux = create_multiplexer(mux_name, show_output=True)
|
||||
mux.write_stdout(f"Opened editor for: {path}\n")
|
||||
|
||||
return {"status": "success", "message": f"Editor opened for {path}", "mux_name": mux_name}
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Editor opened for {path}",
|
||||
"mux_name": mux_name,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
|
||||
old_content = ""
|
||||
if os.path.exists(path):
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
old_content = f.read()
|
||||
|
||||
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
|
||||
operation = track_edit('INSERT', filepath, start_pos=position, content=text)
|
||||
position = (line if line is not None else 0) * 1000 + (
|
||||
col if col is not None else 0
|
||||
)
|
||||
operation = track_edit("INSERT", filepath, start_pos=position, content=text)
|
||||
tracker.mark_in_progress(operation)
|
||||
|
||||
editor = get_editor(path)
|
||||
@ -65,12 +76,16 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
|
||||
mux_name = f"editor-{path}"
|
||||
mux = get_multiplexer(mux_name)
|
||||
if mux:
|
||||
location = f" at line {line}, col {col}" if line is not None and col is not None else ""
|
||||
location = (
|
||||
f" at line {line}, col {col}"
|
||||
if line is not None and col is not None
|
||||
else ""
|
||||
)
|
||||
preview = text[:50] + "..." if len(text) > 50 else text
|
||||
mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n")
|
||||
|
||||
if show_diff and old_content:
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
new_content = f.read()
|
||||
diff_result = display_content_diff(old_content, new_content, filepath)
|
||||
if diff_result["status"] == "success":
|
||||
@ -81,23 +96,32 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
|
||||
close_editor(filepath)
|
||||
return result
|
||||
except Exception as e:
|
||||
if 'operation' in locals():
|
||||
if "operation" in locals():
|
||||
tracker.mark_failed(operation)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True):
|
||||
|
||||
def editor_replace_text(
|
||||
filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True
|
||||
):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
|
||||
old_content = ""
|
||||
if os.path.exists(path):
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
old_content = f.read()
|
||||
|
||||
start_pos = start_line * 1000 + start_col
|
||||
end_pos = end_line * 1000 + end_col
|
||||
operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos,
|
||||
content=new_text, old_content=old_content)
|
||||
operation = track_edit(
|
||||
"REPLACE",
|
||||
filepath,
|
||||
start_pos=start_pos,
|
||||
end_pos=end_pos,
|
||||
content=new_text,
|
||||
old_content=old_content,
|
||||
)
|
||||
tracker.mark_in_progress(operation)
|
||||
|
||||
editor = get_editor(path)
|
||||
@ -108,10 +132,12 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
|
||||
mux = get_multiplexer(mux_name)
|
||||
if mux:
|
||||
preview = new_text[:50] + "..." if len(new_text) > 50 else new_text
|
||||
mux.write_stdout(f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n")
|
||||
mux.write_stdout(
|
||||
f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n"
|
||||
)
|
||||
|
||||
if show_diff and old_content:
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
new_content = f.read()
|
||||
diff_result = display_content_diff(old_content, new_content, filepath)
|
||||
if diff_result["status"] == "success":
|
||||
@ -122,10 +148,11 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
|
||||
close_editor(filepath)
|
||||
return result
|
||||
except Exception as e:
|
||||
if 'operation' in locals():
|
||||
if "operation" in locals():
|
||||
tracker.mark_failed(operation)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def editor_search(filepath, pattern, start_line=0):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
@ -135,7 +162,9 @@ def editor_search(filepath, pattern, start_line=0):
|
||||
mux_name = f"editor-{path}"
|
||||
mux = get_multiplexer(mux_name)
|
||||
if mux:
|
||||
mux.write_stdout(f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n")
|
||||
mux.write_stdout(
|
||||
f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n"
|
||||
)
|
||||
|
||||
result = {"status": "success", "results": results}
|
||||
close_editor(filepath)
|
||||
|
||||
@ -1,31 +1,36 @@
|
||||
import os
|
||||
import hashlib
|
||||
import os
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from pr.editor import RPEditor
|
||||
from ..ui.diff_display import display_diff, get_diff_stats
|
||||
from ..ui.edit_feedback import track_edit, tracker
|
||||
|
||||
from ..tools.patch import display_content_diff
|
||||
from ..ui.diff_display import get_diff_stats
|
||||
from ..ui.edit_feedback import track_edit, tracker
|
||||
|
||||
_id = 0
|
||||
|
||||
|
||||
def get_uid():
|
||||
global _id
|
||||
_id += 3
|
||||
return _id
|
||||
|
||||
|
||||
def read_file(filepath, db_conn=None):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
content = f.read()
|
||||
if db_conn:
|
||||
from pr.tools.database import db_set
|
||||
|
||||
db_set("read:" + path, "true", db_conn)
|
||||
return {"status": "success", "content": content}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def write_file(filepath, content, db_conn=None, show_diff=True):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
@ -34,15 +39,24 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
|
||||
|
||||
if not is_new_file and db_conn:
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
}
|
||||
|
||||
if not is_new_file:
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
old_content = f.read()
|
||||
|
||||
operation = track_edit('WRITE', filepath, content=content, old_content=old_content)
|
||||
operation = track_edit(
|
||||
"WRITE", filepath, content=content, old_content=old_content
|
||||
)
|
||||
tracker.mark_in_progress(operation)
|
||||
|
||||
if show_diff and not is_new_file:
|
||||
@ -59,13 +73,18 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
|
||||
cursor = db_conn.cursor()
|
||||
file_hash = hashlib.md5(old_content.encode()).hexdigest()
|
||||
|
||||
cursor.execute("SELECT MAX(version) FROM file_versions WHERE filepath = ?", (filepath,))
|
||||
cursor.execute(
|
||||
"SELECT MAX(version) FROM file_versions WHERE filepath = ?",
|
||||
(filepath,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
version = (result[0] + 1) if result[0] else 1
|
||||
|
||||
cursor.execute("""INSERT INTO file_versions (filepath, content, hash, timestamp, version)
|
||||
cursor.execute(
|
||||
"""INSERT INTO file_versions (filepath, content, hash, timestamp, version)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(filepath, old_content, file_hash, time.time(), version))
|
||||
(filepath, old_content, file_hash, time.time(), version),
|
||||
)
|
||||
db_conn.commit()
|
||||
except Exception:
|
||||
pass
|
||||
@ -79,10 +98,11 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
|
||||
|
||||
return {"status": "success", "message": message}
|
||||
except Exception as e:
|
||||
if 'operation' in locals():
|
||||
if "operation" in locals():
|
||||
tracker.mark_failed(operation)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def list_directory(path=".", recursive=False):
|
||||
try:
|
||||
path = os.path.expanduser(path)
|
||||
@ -91,21 +111,36 @@ def list_directory(path=".", recursive=False):
|
||||
for root, dirs, files in os.walk(path):
|
||||
for name in files:
|
||||
item_path = os.path.join(root, name)
|
||||
items.append({"path": item_path, "type": "file", "size": os.path.getsize(item_path)})
|
||||
items.append(
|
||||
{
|
||||
"path": item_path,
|
||||
"type": "file",
|
||||
"size": os.path.getsize(item_path),
|
||||
}
|
||||
)
|
||||
for name in dirs:
|
||||
items.append({"path": os.path.join(root, name), "type": "directory"})
|
||||
items.append(
|
||||
{"path": os.path.join(root, name), "type": "directory"}
|
||||
)
|
||||
else:
|
||||
for item in os.listdir(path):
|
||||
item_path = os.path.join(path, item)
|
||||
items.append({
|
||||
"name": item,
|
||||
"type": "directory" if os.path.isdir(item_path) else "file",
|
||||
"size": os.path.getsize(item_path) if os.path.isfile(item_path) else None
|
||||
})
|
||||
items.append(
|
||||
{
|
||||
"name": item,
|
||||
"type": "directory" if os.path.isdir(item_path) else "file",
|
||||
"size": (
|
||||
os.path.getsize(item_path)
|
||||
if os.path.isfile(item_path)
|
||||
else None
|
||||
),
|
||||
}
|
||||
)
|
||||
return {"status": "success", "items": items}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
try:
|
||||
os.makedirs(os.path.expanduser(path), exist_ok=True)
|
||||
@ -113,6 +148,7 @@ def mkdir(path):
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def chdir(path):
|
||||
try:
|
||||
os.chdir(os.path.expanduser(path))
|
||||
@ -120,16 +156,32 @@ def chdir(path):
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def getpwd():
|
||||
try:
|
||||
return {"status": "success", "path": os.getcwd()}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def index_source_directory(path):
|
||||
extensions = [
|
||||
".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp",
|
||||
".html", ".css", ".json", ".xml", ".md", ".sh", ".rb", ".go"
|
||||
".py",
|
||||
".js",
|
||||
".ts",
|
||||
".java",
|
||||
".cpp",
|
||||
".c",
|
||||
".h",
|
||||
".hpp",
|
||||
".html",
|
||||
".css",
|
||||
".json",
|
||||
".xml",
|
||||
".md",
|
||||
".sh",
|
||||
".rb",
|
||||
".go",
|
||||
]
|
||||
source_files = []
|
||||
try:
|
||||
@ -138,18 +190,16 @@ def index_source_directory(path):
|
||||
if any(file.endswith(ext) for ext in extensions):
|
||||
filepath = os.path.join(root, file)
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
source_files.append({
|
||||
"path": filepath,
|
||||
"content": content
|
||||
})
|
||||
source_files.append({"path": filepath, "content": content})
|
||||
except Exception:
|
||||
continue
|
||||
return {"status": "success", "indexed_files": source_files}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def search_replace(filepath, old_string, new_string, db_conn=None):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
@ -157,25 +207,38 @@ def search_replace(filepath, old_string, new_string, db_conn=None):
|
||||
return {"status": "error", "error": "File does not exist"}
|
||||
if db_conn:
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
|
||||
with open(path, 'r') as f:
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
}
|
||||
with open(path) as f:
|
||||
content = f.read()
|
||||
content = content.replace(old_string, new_string)
|
||||
with open(path, 'w') as f:
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
return {"status": "success", "message": f"Replaced '{old_string}' with '{new_string}' in {path}"}
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Replaced '{old_string}' with '{new_string}' in {path}",
|
||||
}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
_editors = {}
|
||||
|
||||
|
||||
def get_editor(filepath):
|
||||
if filepath not in _editors:
|
||||
_editors[filepath] = RPEditor(filepath)
|
||||
return _editors[filepath]
|
||||
|
||||
|
||||
def close_editor(filepath):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
@ -185,6 +248,7 @@ def close_editor(filepath):
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def open_editor(filepath):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
@ -194,22 +258,34 @@ def open_editor(filepath):
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None):
|
||||
|
||||
def editor_insert_text(
|
||||
filepath, text, line=None, col=None, show_diff=True, db_conn=None
|
||||
):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
if db_conn:
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
}
|
||||
|
||||
old_content = ""
|
||||
if os.path.exists(path):
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
old_content = f.read()
|
||||
|
||||
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
|
||||
operation = track_edit('INSERT', filepath, start_pos=position, content=text)
|
||||
position = (line if line is not None else 0) * 1000 + (
|
||||
col if col is not None else 0
|
||||
)
|
||||
operation = track_edit("INSERT", filepath, start_pos=position, content=text)
|
||||
tracker.mark_in_progress(operation)
|
||||
|
||||
editor = get_editor(path)
|
||||
@ -219,7 +295,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c
|
||||
editor.save_file()
|
||||
|
||||
if show_diff and old_content:
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
new_content = f.read()
|
||||
diff_result = display_content_diff(old_content, new_content, filepath)
|
||||
if diff_result["status"] == "success":
|
||||
@ -228,28 +304,51 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_c
|
||||
tracker.mark_completed(operation)
|
||||
return {"status": "success", "message": f"Inserted text in {path}"}
|
||||
except Exception as e:
|
||||
if 'operation' in locals():
|
||||
if "operation" in locals():
|
||||
tracker.mark_failed(operation)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True, db_conn=None):
|
||||
|
||||
def editor_replace_text(
|
||||
filepath,
|
||||
start_line,
|
||||
start_col,
|
||||
end_line,
|
||||
end_col,
|
||||
new_text,
|
||||
show_diff=True,
|
||||
db_conn=None,
|
||||
):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
if db_conn:
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
}
|
||||
|
||||
old_content = ""
|
||||
if os.path.exists(path):
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
old_content = f.read()
|
||||
|
||||
start_pos = start_line * 1000 + start_col
|
||||
end_pos = end_line * 1000 + end_col
|
||||
operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos,
|
||||
content=new_text, old_content=old_content)
|
||||
operation = track_edit(
|
||||
"REPLACE",
|
||||
filepath,
|
||||
start_pos=start_pos,
|
||||
end_pos=end_pos,
|
||||
content=new_text,
|
||||
old_content=old_content,
|
||||
)
|
||||
tracker.mark_in_progress(operation)
|
||||
|
||||
editor = get_editor(path)
|
||||
@ -257,7 +356,7 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
|
||||
editor.save_file()
|
||||
|
||||
if show_diff and old_content:
|
||||
with open(path, 'r') as f:
|
||||
with open(path) as f:
|
||||
new_content = f.read()
|
||||
diff_result = display_content_diff(old_content, new_content, filepath)
|
||||
if diff_result["status"] == "success":
|
||||
@ -266,22 +365,25 @@ def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_
|
||||
tracker.mark_completed(operation)
|
||||
return {"status": "success", "message": f"Replaced text in {path}"}
|
||||
except Exception as e:
|
||||
if 'operation' in locals():
|
||||
if "operation" in locals():
|
||||
tracker.mark_failed(operation)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def display_edit_summary():
|
||||
from ..ui.edit_feedback import display_edit_summary
|
||||
|
||||
return display_edit_summary()
|
||||
|
||||
|
||||
def display_edit_timeline(show_content=False):
|
||||
from ..ui.edit_feedback import display_edit_timeline
|
||||
|
||||
return display_edit_timeline(show_content)
|
||||
|
||||
|
||||
def clear_edit_tracker():
|
||||
from ..ui.edit_feedback import clear_tracker
|
||||
|
||||
clear_tracker()
|
||||
return {"status": "success", "message": "Edit tracker cleared"}
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from pr.multiplexer import create_multiplexer, get_multiplexer, close_multiplexer, get_all_multiplexer_states
|
||||
|
||||
def start_interactive_session(command, session_name=None, process_type='generic'):
|
||||
from pr.multiplexer import (
|
||||
close_multiplexer,
|
||||
create_multiplexer,
|
||||
get_all_multiplexer_states,
|
||||
get_multiplexer,
|
||||
)
|
||||
|
||||
|
||||
def start_interactive_session(command, session_name=None, process_type="generic"):
|
||||
"""
|
||||
Start an interactive session in a dedicated multiplexer.
|
||||
|
||||
@ -16,7 +22,7 @@ def start_interactive_session(command, session_name=None, process_type='generic'
|
||||
session_name: The name of the created session
|
||||
"""
|
||||
name, mux = create_multiplexer(session_name)
|
||||
mux.update_metadata('process_type', process_type)
|
||||
mux.update_metadata("process_type", process_type)
|
||||
|
||||
# Start the process
|
||||
if isinstance(command, str):
|
||||
@ -29,19 +35,23 @@ def start_interactive_session(command, session_name=None, process_type='generic'
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
mux.process = process
|
||||
mux.update_metadata('pid', process.pid)
|
||||
mux.update_metadata("pid", process.pid)
|
||||
|
||||
# Set process type and handler
|
||||
detected_type = detect_process_type(command)
|
||||
mux.set_process_type(detected_type)
|
||||
|
||||
# Start output readers
|
||||
stdout_thread = threading.Thread(target=_read_output, args=(process.stdout, mux.write_stdout), daemon=True)
|
||||
stderr_thread = threading.Thread(target=_read_output, args=(process.stderr, mux.write_stderr), daemon=True)
|
||||
stdout_thread = threading.Thread(
|
||||
target=_read_output, args=(process.stdout, mux.write_stdout), daemon=True
|
||||
)
|
||||
stderr_thread = threading.Thread(
|
||||
target=_read_output, args=(process.stderr, mux.write_stderr), daemon=True
|
||||
)
|
||||
|
||||
stdout_thread.start()
|
||||
stderr_thread.start()
|
||||
@ -54,15 +64,17 @@ def start_interactive_session(command, session_name=None, process_type='generic'
|
||||
close_multiplexer(name)
|
||||
raise e
|
||||
|
||||
|
||||
def _read_output(stream, write_func):
|
||||
"""Read from a stream and write to multiplexer buffer."""
|
||||
try:
|
||||
for line in iter(stream.readline, ''):
|
||||
for line in iter(stream.readline, ""):
|
||||
if line:
|
||||
write_func(line.rstrip('\n'))
|
||||
write_func(line.rstrip("\n"))
|
||||
except Exception as e:
|
||||
print(f"Error reading output: {e}")
|
||||
|
||||
|
||||
def send_input_to_session(session_name, input_data):
|
||||
"""
|
||||
Send input to an interactive session.
|
||||
@ -75,15 +87,16 @@ def send_input_to_session(session_name, input_data):
|
||||
if not mux:
|
||||
raise ValueError(f"Session {session_name} not found")
|
||||
|
||||
if not hasattr(mux, 'process') or mux.process.poll() is not None:
|
||||
if not hasattr(mux, "process") or mux.process.poll() is not None:
|
||||
raise ValueError(f"Session {session_name} is not active")
|
||||
|
||||
try:
|
||||
mux.process.stdin.write(input_data + '\n')
|
||||
mux.process.stdin.write(input_data + "\n")
|
||||
mux.process.stdin.flush()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to send input to session {session_name}: {e}")
|
||||
|
||||
|
||||
def read_session_output(session_name, lines=None):
|
||||
"""
|
||||
Read output from a session.
|
||||
@ -102,14 +115,12 @@ def read_session_output(session_name, lines=None):
|
||||
output = mux.get_all_output()
|
||||
if lines is not None:
|
||||
# Return last N lines
|
||||
stdout_lines = output['stdout'].split('\n')[-lines:] if output['stdout'] else []
|
||||
stderr_lines = output['stderr'].split('\n')[-lines:] if output['stderr'] else []
|
||||
output = {
|
||||
'stdout': '\n'.join(stdout_lines),
|
||||
'stderr': '\n'.join(stderr_lines)
|
||||
}
|
||||
stdout_lines = output["stdout"].split("\n")[-lines:] if output["stdout"] else []
|
||||
stderr_lines = output["stderr"].split("\n")[-lines:] if output["stderr"] else []
|
||||
output = {"stdout": "\n".join(stdout_lines), "stderr": "\n".join(stderr_lines)}
|
||||
return output
|
||||
|
||||
|
||||
def list_active_sessions():
|
||||
"""
|
||||
List all active interactive sessions.
|
||||
@ -119,6 +130,7 @@ def list_active_sessions():
|
||||
"""
|
||||
return get_all_multiplexer_states()
|
||||
|
||||
|
||||
def get_session_status(session_name):
|
||||
"""
|
||||
Get detailed status of a session.
|
||||
@ -134,15 +146,16 @@ def get_session_status(session_name):
|
||||
return None
|
||||
|
||||
status = mux.get_metadata()
|
||||
status['is_active'] = hasattr(mux, 'process') and mux.process.poll() is None
|
||||
if status['is_active']:
|
||||
status['pid'] = mux.process.pid
|
||||
status['output_summary'] = {
|
||||
'stdout_lines': len(mux.stdout_buffer),
|
||||
'stderr_lines': len(mux.stderr_buffer)
|
||||
status["is_active"] = hasattr(mux, "process") and mux.process.poll() is None
|
||||
if status["is_active"]:
|
||||
status["pid"] = mux.process.pid
|
||||
status["output_summary"] = {
|
||||
"stdout_lines": len(mux.stdout_buffer),
|
||||
"stderr_lines": len(mux.stderr_buffer),
|
||||
}
|
||||
return status
|
||||
|
||||
|
||||
def close_interactive_session(session_name):
|
||||
"""
|
||||
Close an interactive session.
|
||||
|
||||
@ -1,38 +1,43 @@
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
def add_knowledge_entry(category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None) -> Dict[str, Any]:
|
||||
from pr.memory.knowledge_store import KnowledgeEntry, KnowledgeStore
|
||||
|
||||
|
||||
def add_knowledge_entry(
|
||||
category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Add a new entry to the knowledge base."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
|
||||
if entry_id is None:
|
||||
entry_id = str(uuid.uuid4())[:16]
|
||||
|
||||
|
||||
entry = KnowledgeEntry(
|
||||
entry_id=entry_id,
|
||||
category=category,
|
||||
content=content,
|
||||
metadata=metadata or {},
|
||||
created_at=time.time(),
|
||||
updated_at=time.time()
|
||||
updated_at=time.time(),
|
||||
)
|
||||
|
||||
|
||||
store.add_entry(entry)
|
||||
return {"status": "success", "entry_id": entry_id}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
|
||||
"""Retrieve a knowledge entry by ID."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
|
||||
entry = store.get_entry(entry_id)
|
||||
if entry:
|
||||
return {"status": "success", "entry": entry.to_dict()}
|
||||
@ -41,58 +46,71 @@ def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]:
|
||||
|
||||
def search_knowledge(
|
||||
query: str, category: str = None, top_k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""Search the knowledge base semantically."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
|
||||
entries = store.search_entries(query, category, top_k)
|
||||
results = [entry.to_dict() for entry in entries]
|
||||
return {"status": "success", "results": results}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]:
|
||||
"""Get knowledge entries by category."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
|
||||
entries = store.get_by_category(category, limit)
|
||||
results = [entry.to_dict() for entry in entries]
|
||||
return {"status": "success", "entries": results}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]:
|
||||
|
||||
def update_knowledge_importance(
|
||||
entry_id: str, importance_score: float
|
||||
) -> Dict[str, Any]:
|
||||
"""Update the importance score of a knowledge entry."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
|
||||
store.update_importance(entry_id, importance_score)
|
||||
return {"status": "success", "entry_id": entry_id, "importance_score": importance_score}
|
||||
return {
|
||||
"status": "success",
|
||||
"entry_id": entry_id,
|
||||
"importance_score": importance_score,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def delete_knowledge_entry(entry_id: str) -> Dict[str, Any]:
|
||||
"""Delete a knowledge entry."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
|
||||
success = store.delete_entry(entry_id)
|
||||
return {"status": "success" if success else "not_found", "entry_id": entry_id}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def get_knowledge_statistics() -> Dict[str, Any]:
|
||||
"""Get statistics about the knowledge base."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
|
||||
stats = store.get_statistics()
|
||||
return {"status": "success", "statistics": stats}
|
||||
except Exception as e:
|
||||
|
||||
@ -1,23 +1,37 @@
|
||||
import os
|
||||
import tempfile
|
||||
import subprocess
|
||||
import difflib
|
||||
from ..ui.diff_display import display_diff, get_diff_stats, DiffDisplay
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from ..ui.diff_display import display_diff, get_diff_stats
|
||||
|
||||
|
||||
def apply_patch(filepath, patch_content, db_conn=None):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
if db_conn:
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
}
|
||||
# Write patch to temp file
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.patch') as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".patch") as f:
|
||||
f.write(patch_content)
|
||||
patch_file = f.name
|
||||
# Run patch command
|
||||
result = subprocess.run(['patch', path, patch_file], capture_output=True, text=True, cwd=os.path.dirname(path))
|
||||
result = subprocess.run(
|
||||
["patch", path, patch_file],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=os.path.dirname(path),
|
||||
)
|
||||
os.unlink(patch_file)
|
||||
if result.returncode == 0:
|
||||
return {"status": "success", "output": result.stdout.strip()}
|
||||
@ -26,11 +40,14 @@ def apply_patch(filepath, patch_content, db_conn=None):
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, format_type='unified'):
|
||||
|
||||
def create_diff(
|
||||
file1, file2, fromfile="file1", tofile="file2", visual=False, format_type="unified"
|
||||
):
|
||||
try:
|
||||
path1 = os.path.expanduser(file1)
|
||||
path2 = os.path.expanduser(file2)
|
||||
with open(path1, 'r') as f1, open(path2, 'r') as f2:
|
||||
with open(path1) as f1, open(path2) as f2:
|
||||
content1 = f1.read()
|
||||
content2 = f2.read()
|
||||
|
||||
@ -39,53 +56,51 @@ def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, fo
|
||||
stats = get_diff_stats(content1, content2)
|
||||
lines1 = content1.splitlines(keepends=True)
|
||||
lines2 = content2.splitlines(keepends=True)
|
||||
plain_diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
|
||||
plain_diff = list(
|
||||
difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"diff": ''.join(plain_diff),
|
||||
"diff": "".join(plain_diff),
|
||||
"visual_diff": visual_diff,
|
||||
"stats": stats
|
||||
"stats": stats,
|
||||
}
|
||||
else:
|
||||
lines1 = content1.splitlines(keepends=True)
|
||||
lines2 = content2.splitlines(keepends=True)
|
||||
diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
|
||||
return {"status": "success", "diff": ''.join(diff)}
|
||||
diff = list(
|
||||
difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)
|
||||
)
|
||||
return {"status": "success", "diff": "".join(diff)}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def display_file_diff(filepath1, filepath2, format_type='unified', context_lines=3):
|
||||
def display_file_diff(filepath1, filepath2, format_type="unified", context_lines=3):
|
||||
try:
|
||||
path1 = os.path.expanduser(filepath1)
|
||||
path2 = os.path.expanduser(filepath2)
|
||||
|
||||
with open(path1, 'r') as f1:
|
||||
with open(path1) as f1:
|
||||
old_content = f1.read()
|
||||
with open(path2, 'r') as f2:
|
||||
with open(path2) as f2:
|
||||
new_content = f2.read()
|
||||
|
||||
visual_diff = display_diff(old_content, new_content, filepath1, format_type)
|
||||
stats = get_diff_stats(old_content, new_content)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"visual_diff": visual_diff,
|
||||
"stats": stats
|
||||
}
|
||||
return {"status": "success", "visual_diff": visual_diff, "stats": stats}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def display_content_diff(old_content, new_content, filename='file', format_type='unified'):
|
||||
def display_content_diff(
|
||||
old_content, new_content, filename="file", format_type="unified"
|
||||
):
|
||||
try:
|
||||
visual_diff = display_diff(old_content, new_content, filename, format_type)
|
||||
stats = get_diff_stats(old_content, new_content)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"visual_diff": visual_diff,
|
||||
"stats": stats
|
||||
}
|
||||
return {"status": "success", "visual_diff": visual_diff, "stats": stats}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ProcessHandler(ABC):
|
||||
"""Base class for process-specific handlers."""
|
||||
|
||||
def __init__(self, multiplexer):
|
||||
self.multiplexer = multiplexer
|
||||
self.state_machine = {}
|
||||
self.current_state = 'initial'
|
||||
self.current_state = "initial"
|
||||
self.prompt_patterns = []
|
||||
self.response_suggestions = {}
|
||||
|
||||
@ -27,7 +26,8 @@ class ProcessHandler(ABC):
|
||||
|
||||
def is_waiting_for_input(self):
|
||||
"""Check if process appears to be waiting for input."""
|
||||
return self.current_state in ['waiting_confirmation', 'waiting_input']
|
||||
return self.current_state in ["waiting_confirmation", "waiting_input"]
|
||||
|
||||
|
||||
class AptHandler(ProcessHandler):
|
||||
"""Handler for apt package manager interactions."""
|
||||
@ -35,230 +35,238 @@ class AptHandler(ProcessHandler):
|
||||
def __init__(self, multiplexer):
|
||||
super().__init__(multiplexer)
|
||||
self.state_machine = {
|
||||
'initial': ['running_command'],
|
||||
'running_command': ['waiting_confirmation', 'completed'],
|
||||
'waiting_confirmation': ['confirmed', 'cancelled'],
|
||||
'confirmed': ['installing', 'completed'],
|
||||
'installing': ['completed', 'error'],
|
||||
'completed': [],
|
||||
'error': [],
|
||||
'cancelled': []
|
||||
"initial": ["running_command"],
|
||||
"running_command": ["waiting_confirmation", "completed"],
|
||||
"waiting_confirmation": ["confirmed", "cancelled"],
|
||||
"confirmed": ["installing", "completed"],
|
||||
"installing": ["completed", "error"],
|
||||
"completed": [],
|
||||
"error": [],
|
||||
"cancelled": [],
|
||||
}
|
||||
self.prompt_patterns = [
|
||||
(r'Do you want to continue\?', 'confirmation'),
|
||||
(r'After this operation.*installed\.', 'size_info'),
|
||||
(r'Need to get.*B of archives\.', 'download_info'),
|
||||
(r'Unpacking.*Configuring', 'configuring'),
|
||||
(r'Setting up', 'setting_up'),
|
||||
(r'E:\s', 'error')
|
||||
(r"Do you want to continue\?", "confirmation"),
|
||||
(r"After this operation.*installed\.", "size_info"),
|
||||
(r"Need to get.*B of archives\.", "download_info"),
|
||||
(r"Unpacking.*Configuring", "configuring"),
|
||||
(r"Setting up", "setting_up"),
|
||||
(r"E:\s", "error"),
|
||||
]
|
||||
|
||||
def get_process_type(self):
|
||||
return 'apt'
|
||||
return "apt"
|
||||
|
||||
def update_state(self, output):
|
||||
"""Update state based on apt output patterns."""
|
||||
output_lower = output.lower()
|
||||
|
||||
# Check for completion
|
||||
if 'processing triggers' in output_lower or 'done' in output_lower:
|
||||
self.current_state = 'completed'
|
||||
if "processing triggers" in output_lower or "done" in output_lower:
|
||||
self.current_state = "completed"
|
||||
# Check for confirmation prompts
|
||||
elif 'do you want to continue' in output_lower:
|
||||
self.current_state = 'waiting_confirmation'
|
||||
elif "do you want to continue" in output_lower:
|
||||
self.current_state = "waiting_confirmation"
|
||||
# Check for installation progress
|
||||
elif 'setting up' in output_lower or 'unpacking' in output_lower:
|
||||
self.current_state = 'installing'
|
||||
elif "setting up" in output_lower or "unpacking" in output_lower:
|
||||
self.current_state = "installing"
|
||||
# Check for errors
|
||||
elif 'e:' in output_lower or 'error' in output_lower:
|
||||
self.current_state = 'error'
|
||||
elif "e:" in output_lower or "error" in output_lower:
|
||||
self.current_state = "error"
|
||||
|
||||
def get_prompt_suggestions(self):
|
||||
"""Return suggested responses for apt prompts."""
|
||||
suggestions = super().get_prompt_suggestions()
|
||||
if self.current_state == 'waiting_confirmation':
|
||||
suggestions.extend(['y', 'yes', 'n', 'no'])
|
||||
if self.current_state == "waiting_confirmation":
|
||||
suggestions.extend(["y", "yes", "n", "no"])
|
||||
return suggestions
|
||||
|
||||
|
||||
class VimHandler(ProcessHandler):
|
||||
"""Handler for vim editor interactions."""
|
||||
|
||||
def __init__(self, multiplexer):
|
||||
super().__init__(multiplexer)
|
||||
self.state_machine = {
|
||||
'initial': ['normal_mode', 'insert_mode'],
|
||||
'normal_mode': ['insert_mode', 'command_mode', 'visual_mode'],
|
||||
'insert_mode': ['normal_mode'],
|
||||
'command_mode': ['normal_mode'],
|
||||
'visual_mode': ['normal_mode'],
|
||||
'exiting': []
|
||||
"initial": ["normal_mode", "insert_mode"],
|
||||
"normal_mode": ["insert_mode", "command_mode", "visual_mode"],
|
||||
"insert_mode": ["normal_mode"],
|
||||
"command_mode": ["normal_mode"],
|
||||
"visual_mode": ["normal_mode"],
|
||||
"exiting": [],
|
||||
}
|
||||
self.prompt_patterns = [
|
||||
(r'-- INSERT --', 'insert_mode'),
|
||||
(r'-- VISUAL --', 'visual_mode'),
|
||||
(r':', 'command_mode'),
|
||||
(r'Press ENTER', 'waiting_enter'),
|
||||
(r'Saved', 'saved')
|
||||
(r"-- INSERT --", "insert_mode"),
|
||||
(r"-- VISUAL --", "visual_mode"),
|
||||
(r":", "command_mode"),
|
||||
(r"Press ENTER", "waiting_enter"),
|
||||
(r"Saved", "saved"),
|
||||
]
|
||||
self.mode_indicators = {
|
||||
'insert': '-- INSERT --',
|
||||
'visual': '-- VISUAL --',
|
||||
'command': ':'
|
||||
"insert": "-- INSERT --",
|
||||
"visual": "-- VISUAL --",
|
||||
"command": ":",
|
||||
}
|
||||
|
||||
def get_process_type(self):
|
||||
return 'vim'
|
||||
return "vim"
|
||||
|
||||
def update_state(self, output):
|
||||
"""Update state based on vim mode indicators."""
|
||||
if '-- INSERT --' in output:
|
||||
self.current_state = 'insert_mode'
|
||||
elif '-- VISUAL --' in output:
|
||||
self.current_state = 'visual_mode'
|
||||
elif output.strip().endswith(':'):
|
||||
self.current_state = 'command_mode'
|
||||
elif 'Press ENTER' in output:
|
||||
self.current_state = 'waiting_enter'
|
||||
if "-- INSERT --" in output:
|
||||
self.current_state = "insert_mode"
|
||||
elif "-- VISUAL --" in output:
|
||||
self.current_state = "visual_mode"
|
||||
elif output.strip().endswith(":"):
|
||||
self.current_state = "command_mode"
|
||||
elif "Press ENTER" in output:
|
||||
self.current_state = "waiting_enter"
|
||||
else:
|
||||
# Default to normal mode if no specific indicators
|
||||
self.current_state = 'normal_mode'
|
||||
self.current_state = "normal_mode"
|
||||
|
||||
def get_prompt_suggestions(self):
|
||||
"""Return suggested commands for vim modes."""
|
||||
suggestions = super().get_prompt_suggestions()
|
||||
if self.current_state == 'command_mode':
|
||||
suggestions.extend(['w', 'q', 'wq', 'q!', 'w!'])
|
||||
elif self.current_state == 'normal_mode':
|
||||
suggestions.extend(['i', 'a', 'o', 'dd', ':w', ':q'])
|
||||
elif self.current_state == 'waiting_enter':
|
||||
suggestions.extend(['\n'])
|
||||
if self.current_state == "command_mode":
|
||||
suggestions.extend(["w", "q", "wq", "q!", "w!"])
|
||||
elif self.current_state == "normal_mode":
|
||||
suggestions.extend(["i", "a", "o", "dd", ":w", ":q"])
|
||||
elif self.current_state == "waiting_enter":
|
||||
suggestions.extend(["\n"])
|
||||
return suggestions
|
||||
|
||||
|
||||
class SSHHandler(ProcessHandler):
|
||||
"""Handler for SSH connection interactions."""
|
||||
|
||||
def __init__(self, multiplexer):
|
||||
super().__init__(multiplexer)
|
||||
self.state_machine = {
|
||||
'initial': ['connecting'],
|
||||
'connecting': ['auth_prompt', 'connected', 'failed'],
|
||||
'auth_prompt': ['connected', 'failed'],
|
||||
'connected': ['shell', 'disconnected'],
|
||||
'shell': ['disconnected'],
|
||||
'failed': [],
|
||||
'disconnected': []
|
||||
"initial": ["connecting"],
|
||||
"connecting": ["auth_prompt", "connected", "failed"],
|
||||
"auth_prompt": ["connected", "failed"],
|
||||
"connected": ["shell", "disconnected"],
|
||||
"shell": ["disconnected"],
|
||||
"failed": [],
|
||||
"disconnected": [],
|
||||
}
|
||||
self.prompt_patterns = [
|
||||
(r'password:', 'password_prompt'),
|
||||
(r'yes/no', 'host_key_prompt'),
|
||||
(r'Permission denied', 'auth_failed'),
|
||||
(r'Welcome to', 'connected'),
|
||||
(r'\$', 'shell_prompt'),
|
||||
(r'\#', 'root_shell_prompt'),
|
||||
(r'Connection closed', 'disconnected')
|
||||
(r"password:", "password_prompt"),
|
||||
(r"yes/no", "host_key_prompt"),
|
||||
(r"Permission denied", "auth_failed"),
|
||||
(r"Welcome to", "connected"),
|
||||
(r"\$", "shell_prompt"),
|
||||
(r"\#", "root_shell_prompt"),
|
||||
(r"Connection closed", "disconnected"),
|
||||
]
|
||||
|
||||
def get_process_type(self):
|
||||
return 'ssh'
|
||||
return "ssh"
|
||||
|
||||
def update_state(self, output):
|
||||
"""Update state based on SSH connection output."""
|
||||
output_lower = output.lower()
|
||||
|
||||
if 'permission denied' in output_lower:
|
||||
self.current_state = 'failed'
|
||||
elif 'password:' in output_lower:
|
||||
self.current_state = 'auth_prompt'
|
||||
elif 'yes/no' in output_lower:
|
||||
self.current_state = 'auth_prompt'
|
||||
elif 'welcome to' in output_lower or 'last login' in output_lower:
|
||||
self.current_state = 'connected'
|
||||
elif output.strip().endswith('$') or output.strip().endswith('#'):
|
||||
self.current_state = 'shell'
|
||||
elif 'connection closed' in output_lower:
|
||||
self.current_state = 'disconnected'
|
||||
if "permission denied" in output_lower:
|
||||
self.current_state = "failed"
|
||||
elif "password:" in output_lower:
|
||||
self.current_state = "auth_prompt"
|
||||
elif "yes/no" in output_lower:
|
||||
self.current_state = "auth_prompt"
|
||||
elif "welcome to" in output_lower or "last login" in output_lower:
|
||||
self.current_state = "connected"
|
||||
elif output.strip().endswith("$") or output.strip().endswith("#"):
|
||||
self.current_state = "shell"
|
||||
elif "connection closed" in output_lower:
|
||||
self.current_state = "disconnected"
|
||||
|
||||
def get_prompt_suggestions(self):
|
||||
"""Return suggested responses for SSH prompts."""
|
||||
suggestions = super().get_prompt_suggestions()
|
||||
if self.current_state == 'auth_prompt':
|
||||
if 'password:' in self.multiplexer.get_all_output()['stdout']:
|
||||
suggestions.extend(['<password>']) # Placeholder for actual password
|
||||
elif 'yes/no' in self.multiplexer.get_all_output()['stdout']:
|
||||
suggestions.extend(['yes', 'no'])
|
||||
if self.current_state == "auth_prompt":
|
||||
if "password:" in self.multiplexer.get_all_output()["stdout"]:
|
||||
suggestions.extend(["<password>"]) # Placeholder for actual password
|
||||
elif "yes/no" in self.multiplexer.get_all_output()["stdout"]:
|
||||
suggestions.extend(["yes", "no"])
|
||||
return suggestions
|
||||
|
||||
|
||||
class GenericProcessHandler(ProcessHandler):
|
||||
"""Fallback handler for unknown process types."""
|
||||
|
||||
def __init__(self, multiplexer):
|
||||
super().__init__(multiplexer)
|
||||
self.state_machine = {
|
||||
'initial': ['running'],
|
||||
'running': ['waiting_input', 'completed'],
|
||||
'waiting_input': ['running'],
|
||||
'completed': []
|
||||
"initial": ["running"],
|
||||
"running": ["waiting_input", "completed"],
|
||||
"waiting_input": ["running"],
|
||||
"completed": [],
|
||||
}
|
||||
self.prompt_patterns = [
|
||||
(r'\?\s*$', 'waiting_input'), # Lines ending with ?
|
||||
(r'>\s*$', 'waiting_input'), # Lines ending with >
|
||||
(r':\s*$', 'waiting_input'), # Lines ending with :
|
||||
(r'done', 'completed'),
|
||||
(r'finished', 'completed'),
|
||||
(r'exit code', 'completed')
|
||||
(r"\?\s*$", "waiting_input"), # Lines ending with ?
|
||||
(r">\s*$", "waiting_input"), # Lines ending with >
|
||||
(r":\s*$", "waiting_input"), # Lines ending with :
|
||||
(r"done", "completed"),
|
||||
(r"finished", "completed"),
|
||||
(r"exit code", "completed"),
|
||||
]
|
||||
|
||||
def get_process_type(self):
|
||||
return 'generic'
|
||||
return "generic"
|
||||
|
||||
def update_state(self, output):
|
||||
"""Basic state detection for generic processes."""
|
||||
output_lower = output.lower()
|
||||
|
||||
if any(pattern in output_lower for pattern in ['done', 'finished', 'complete']):
|
||||
self.current_state = 'completed'
|
||||
elif any(output.strip().endswith(char) for char in ['?', '>', ':']):
|
||||
self.current_state = 'waiting_input'
|
||||
if any(pattern in output_lower for pattern in ["done", "finished", "complete"]):
|
||||
self.current_state = "completed"
|
||||
elif any(output.strip().endswith(char) for char in ["?", ">", ":"]):
|
||||
self.current_state = "waiting_input"
|
||||
else:
|
||||
self.current_state = 'running'
|
||||
self.current_state = "running"
|
||||
|
||||
|
||||
# Handler registry
|
||||
_handler_classes = {
|
||||
'apt': AptHandler,
|
||||
'vim': VimHandler,
|
||||
'ssh': SSHHandler,
|
||||
'generic': GenericProcessHandler
|
||||
"apt": AptHandler,
|
||||
"vim": VimHandler,
|
||||
"ssh": SSHHandler,
|
||||
"generic": GenericProcessHandler,
|
||||
}
|
||||
|
||||
|
||||
def get_handler_for_process(process_type, multiplexer):
|
||||
"""Get appropriate handler for a process type."""
|
||||
handler_class = _handler_classes.get(process_type, GenericProcessHandler)
|
||||
return handler_class(multiplexer)
|
||||
|
||||
|
||||
def detect_process_type(command):
|
||||
"""Detect process type from command."""
|
||||
command_str = ' '.join(command) if isinstance(command, list) else command
|
||||
command_str = " ".join(command) if isinstance(command, list) else command
|
||||
command_lower = command_str.lower()
|
||||
|
||||
if 'apt' in command_lower or 'apt-get' in command_lower:
|
||||
return 'apt'
|
||||
elif 'vim' in command_lower or 'vi ' in command_lower:
|
||||
return 'vim'
|
||||
elif 'ssh' in command_lower:
|
||||
return 'ssh'
|
||||
if "apt" in command_lower or "apt-get" in command_lower:
|
||||
return "apt"
|
||||
elif "vim" in command_lower or "vi " in command_lower:
|
||||
return "vim"
|
||||
elif "ssh" in command_lower:
|
||||
return "ssh"
|
||||
else:
|
||||
return 'generic'
|
||||
return 'ssh'
|
||||
return "generic"
|
||||
return "ssh"
|
||||
|
||||
|
||||
def detect_process_type(command):
|
||||
"""Detect process type from command."""
|
||||
command_str = ' '.join(command) if isinstance(command, list) else command
|
||||
command_str = " ".join(command) if isinstance(command, list) else command
|
||||
command_lower = command_str.lower()
|
||||
|
||||
if 'apt' in command_lower or 'apt-get' in command_lower:
|
||||
return 'apt'
|
||||
elif 'vim' in command_lower or 'vi ' in command_lower:
|
||||
return 'vim'
|
||||
elif 'ssh' in command_lower:
|
||||
return 'ssh'
|
||||
if "apt" in command_lower or "apt-get" in command_lower:
|
||||
return "apt"
|
||||
elif "vim" in command_lower or "vi " in command_lower:
|
||||
return "vim"
|
||||
elif "ssh" in command_lower:
|
||||
return "ssh"
|
||||
else:
|
||||
return 'generic'
|
||||
return "generic"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class PromptDetector:
|
||||
"""Detects various process prompts and manages interaction state."""
|
||||
@ -10,101 +10,119 @@ class PromptDetector:
|
||||
self.state_machines = self._load_state_machines()
|
||||
self.session_states = {}
|
||||
self.timeout_configs = {
|
||||
'default': 30, # 30 seconds default timeout
|
||||
'apt': 300, # 5 minutes for apt operations
|
||||
'ssh': 60, # 1 minute for SSH connections
|
||||
'vim': 3600 # 1 hour for vim sessions
|
||||
"default": 30, # 30 seconds default timeout
|
||||
"apt": 300, # 5 minutes for apt operations
|
||||
"ssh": 60, # 1 minute for SSH connections
|
||||
"vim": 3600, # 1 hour for vim sessions
|
||||
}
|
||||
|
||||
def _load_prompt_patterns(self):
|
||||
"""Load regex patterns for detecting various prompts."""
|
||||
return {
|
||||
'bash_prompt': [
|
||||
re.compile(r'[\w\-\.]+@[\w\-\.]+:.*[\$#]\s*$'),
|
||||
re.compile(r'\$\s*$'),
|
||||
re.compile(r'#\s*$'),
|
||||
re.compile(r'>\s*$') # Continuation prompt
|
||||
"bash_prompt": [
|
||||
re.compile(r"[\w\-\.]+@[\w\-\.]+:.*[\$#]\s*$"),
|
||||
re.compile(r"\$\s*$"),
|
||||
re.compile(r"#\s*$"),
|
||||
re.compile(r">\s*$"), # Continuation prompt
|
||||
],
|
||||
'confirmation': [
|
||||
re.compile(r'[Yy]/[Nn]', re.IGNORECASE),
|
||||
re.compile(r'[Yy]es/[Nn]o', re.IGNORECASE),
|
||||
re.compile(r'continue\?', re.IGNORECASE),
|
||||
re.compile(r'proceed\?', re.IGNORECASE)
|
||||
"confirmation": [
|
||||
re.compile(r"[Yy]/[Nn]", re.IGNORECASE),
|
||||
re.compile(r"[Yy]es/[Nn]o", re.IGNORECASE),
|
||||
re.compile(r"continue\?", re.IGNORECASE),
|
||||
re.compile(r"proceed\?", re.IGNORECASE),
|
||||
],
|
||||
'password': [
|
||||
re.compile(r'password:', re.IGNORECASE),
|
||||
re.compile(r'passphrase:', re.IGNORECASE),
|
||||
re.compile(r'enter password', re.IGNORECASE)
|
||||
"password": [
|
||||
re.compile(r"password:", re.IGNORECASE),
|
||||
re.compile(r"passphrase:", re.IGNORECASE),
|
||||
re.compile(r"enter password", re.IGNORECASE),
|
||||
],
|
||||
'sudo_password': [
|
||||
re.compile(r'\[sudo\].*password', re.IGNORECASE)
|
||||
"sudo_password": [re.compile(r"\[sudo\].*password", re.IGNORECASE)],
|
||||
"apt": [
|
||||
re.compile(r"Do you want to continue\?", re.IGNORECASE),
|
||||
re.compile(r"After this operation", re.IGNORECASE),
|
||||
re.compile(r"Need to get", re.IGNORECASE),
|
||||
],
|
||||
'apt': [
|
||||
re.compile(r'Do you want to continue\?', re.IGNORECASE),
|
||||
re.compile(r'After this operation', re.IGNORECASE),
|
||||
re.compile(r'Need to get', re.IGNORECASE)
|
||||
"vim": [
|
||||
re.compile(r"-- INSERT --"),
|
||||
re.compile(r"-- VISUAL --"),
|
||||
re.compile(r":"),
|
||||
re.compile(r"Press ENTER", re.IGNORECASE),
|
||||
],
|
||||
'vim': [
|
||||
re.compile(r'-- INSERT --'),
|
||||
re.compile(r'-- VISUAL --'),
|
||||
re.compile(r':'),
|
||||
re.compile(r'Press ENTER', re.IGNORECASE)
|
||||
"ssh": [
|
||||
re.compile(r"yes/no", re.IGNORECASE),
|
||||
re.compile(r"password:", re.IGNORECASE),
|
||||
re.compile(r"Permission denied", re.IGNORECASE),
|
||||
],
|
||||
'ssh': [
|
||||
re.compile(r'yes/no', re.IGNORECASE),
|
||||
re.compile(r'password:', re.IGNORECASE),
|
||||
re.compile(r'Permission denied', re.IGNORECASE)
|
||||
"git": [
|
||||
re.compile(r"Username:", re.IGNORECASE),
|
||||
re.compile(r"Email:", re.IGNORECASE),
|
||||
],
|
||||
'git': [
|
||||
re.compile(r'Username:', re.IGNORECASE),
|
||||
re.compile(r'Email:', re.IGNORECASE)
|
||||
"error": [
|
||||
re.compile(r"error:", re.IGNORECASE),
|
||||
re.compile(r"failed", re.IGNORECASE),
|
||||
re.compile(r"exception", re.IGNORECASE),
|
||||
],
|
||||
'error': [
|
||||
re.compile(r'error:', re.IGNORECASE),
|
||||
re.compile(r'failed', re.IGNORECASE),
|
||||
re.compile(r'exception', re.IGNORECASE)
|
||||
]
|
||||
}
|
||||
|
||||
def _load_state_machines(self):
|
||||
"""Load state machines for different process types."""
|
||||
return {
|
||||
'apt': {
|
||||
'states': ['initial', 'running', 'confirming', 'installing', 'completed', 'error'],
|
||||
'transitions': {
|
||||
'initial': ['running'],
|
||||
'running': ['confirming', 'installing', 'completed', 'error'],
|
||||
'confirming': ['installing', 'cancelled'],
|
||||
'installing': ['completed', 'error'],
|
||||
'completed': [],
|
||||
'error': [],
|
||||
'cancelled': []
|
||||
}
|
||||
"apt": {
|
||||
"states": [
|
||||
"initial",
|
||||
"running",
|
||||
"confirming",
|
||||
"installing",
|
||||
"completed",
|
||||
"error",
|
||||
],
|
||||
"transitions": {
|
||||
"initial": ["running"],
|
||||
"running": ["confirming", "installing", "completed", "error"],
|
||||
"confirming": ["installing", "cancelled"],
|
||||
"installing": ["completed", "error"],
|
||||
"completed": [],
|
||||
"error": [],
|
||||
"cancelled": [],
|
||||
},
|
||||
},
|
||||
'ssh': {
|
||||
'states': ['initial', 'connecting', 'authenticating', 'connected', 'error'],
|
||||
'transitions': {
|
||||
'initial': ['connecting'],
|
||||
'connecting': ['authenticating', 'connected', 'error'],
|
||||
'authenticating': ['connected', 'error'],
|
||||
'connected': ['error'],
|
||||
'error': []
|
||||
}
|
||||
"ssh": {
|
||||
"states": [
|
||||
"initial",
|
||||
"connecting",
|
||||
"authenticating",
|
||||
"connected",
|
||||
"error",
|
||||
],
|
||||
"transitions": {
|
||||
"initial": ["connecting"],
|
||||
"connecting": ["authenticating", "connected", "error"],
|
||||
"authenticating": ["connected", "error"],
|
||||
"connected": ["error"],
|
||||
"error": [],
|
||||
},
|
||||
},
|
||||
"vim": {
|
||||
"states": [
|
||||
"initial",
|
||||
"normal",
|
||||
"insert",
|
||||
"visual",
|
||||
"command",
|
||||
"exiting",
|
||||
],
|
||||
"transitions": {
|
||||
"initial": ["normal", "insert"],
|
||||
"normal": ["insert", "visual", "command", "exiting"],
|
||||
"insert": ["normal"],
|
||||
"visual": ["normal"],
|
||||
"command": ["normal", "exiting"],
|
||||
"exiting": [],
|
||||
},
|
||||
},
|
||||
'vim': {
|
||||
'states': ['initial', 'normal', 'insert', 'visual', 'command', 'exiting'],
|
||||
'transitions': {
|
||||
'initial': ['normal', 'insert'],
|
||||
'normal': ['insert', 'visual', 'command', 'exiting'],
|
||||
'insert': ['normal'],
|
||||
'visual': ['normal'],
|
||||
'command': ['normal', 'exiting'],
|
||||
'exiting': []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def detect_prompt(self, output, process_type='generic'):
|
||||
def detect_prompt(self, output, process_type="generic"):
|
||||
"""Detect what type of prompt is present in the output."""
|
||||
detections = {}
|
||||
|
||||
@ -125,93 +143,97 @@ class PromptDetector:
|
||||
|
||||
return detections
|
||||
|
||||
def get_response_suggestions(self, prompt_detections, process_type='generic'):
|
||||
def get_response_suggestions(self, prompt_detections, process_type="generic"):
|
||||
"""Get suggested responses based on detected prompts."""
|
||||
suggestions = []
|
||||
|
||||
for category, patterns in prompt_detections.items():
|
||||
if category == 'confirmation':
|
||||
suggestions.extend(['y', 'yes', 'n', 'no'])
|
||||
elif category == 'password':
|
||||
suggestions.append('<password>')
|
||||
elif category == 'sudo_password':
|
||||
suggestions.append('<sudo_password>')
|
||||
elif category == 'apt':
|
||||
if any('continue' in p for p in patterns):
|
||||
suggestions.extend(['y', 'yes'])
|
||||
elif category == 'vim':
|
||||
if any(':' in p for p in patterns):
|
||||
suggestions.extend(['w', 'q', 'wq', 'q!'])
|
||||
elif any('ENTER' in p for p in patterns):
|
||||
suggestions.append('\n')
|
||||
elif category == 'ssh':
|
||||
if any('yes/no' in p for p in patterns):
|
||||
suggestions.extend(['yes', 'no'])
|
||||
elif any('password' in p for p in patterns):
|
||||
suggestions.append('<password>')
|
||||
elif category == 'bash_prompt':
|
||||
suggestions.extend(['help', 'ls', 'pwd', 'exit'])
|
||||
if category == "confirmation":
|
||||
suggestions.extend(["y", "yes", "n", "no"])
|
||||
elif category == "password":
|
||||
suggestions.append("<password>")
|
||||
elif category == "sudo_password":
|
||||
suggestions.append("<sudo_password>")
|
||||
elif category == "apt":
|
||||
if any("continue" in p for p in patterns):
|
||||
suggestions.extend(["y", "yes"])
|
||||
elif category == "vim":
|
||||
if any(":" in p for p in patterns):
|
||||
suggestions.extend(["w", "q", "wq", "q!"])
|
||||
elif any("ENTER" in p for p in patterns):
|
||||
suggestions.append("\n")
|
||||
elif category == "ssh":
|
||||
if any("yes/no" in p for p in patterns):
|
||||
suggestions.extend(["yes", "no"])
|
||||
elif any("password" in p for p in patterns):
|
||||
suggestions.append("<password>")
|
||||
elif category == "bash_prompt":
|
||||
suggestions.extend(["help", "ls", "pwd", "exit"])
|
||||
|
||||
return list(set(suggestions)) # Remove duplicates
|
||||
|
||||
def update_session_state(self, session_name, output, process_type='generic'):
|
||||
def update_session_state(self, session_name, output, process_type="generic"):
|
||||
"""Update the state machine for a session based on output."""
|
||||
if session_name not in self.session_states:
|
||||
self.session_states[session_name] = {
|
||||
'current_state': 'initial',
|
||||
'process_type': process_type,
|
||||
'last_activity': time.time(),
|
||||
'transitions': []
|
||||
"current_state": "initial",
|
||||
"process_type": process_type,
|
||||
"last_activity": time.time(),
|
||||
"transitions": [],
|
||||
}
|
||||
|
||||
session_state = self.session_states[session_name]
|
||||
old_state = session_state['current_state']
|
||||
old_state = session_state["current_state"]
|
||||
|
||||
# Detect prompts and determine new state
|
||||
detections = self.detect_prompt(output, process_type)
|
||||
new_state = self._determine_state_from_detections(detections, process_type, old_state)
|
||||
new_state = self._determine_state_from_detections(
|
||||
detections, process_type, old_state
|
||||
)
|
||||
|
||||
if new_state != old_state:
|
||||
session_state['transitions'].append({
|
||||
'from': old_state,
|
||||
'to': new_state,
|
||||
'timestamp': time.time(),
|
||||
'trigger': detections
|
||||
})
|
||||
session_state['current_state'] = new_state
|
||||
session_state["transitions"].append(
|
||||
{
|
||||
"from": old_state,
|
||||
"to": new_state,
|
||||
"timestamp": time.time(),
|
||||
"trigger": detections,
|
||||
}
|
||||
)
|
||||
session_state["current_state"] = new_state
|
||||
|
||||
session_state['last_activity'] = time.time()
|
||||
session_state["last_activity"] = time.time()
|
||||
return new_state
|
||||
|
||||
def _determine_state_from_detections(self, detections, process_type, current_state):
|
||||
"""Determine new state based on prompt detections."""
|
||||
if process_type in self.state_machines:
|
||||
state_machine = self.state_machines[process_type]
|
||||
self.state_machines[process_type]
|
||||
|
||||
# State transition logic based on detections
|
||||
if 'confirmation' in detections and current_state in ['running', 'initial']:
|
||||
return 'confirming'
|
||||
elif 'password' in detections or 'sudo_password' in detections:
|
||||
return 'authenticating'
|
||||
elif 'error' in detections:
|
||||
return 'error'
|
||||
elif 'bash_prompt' in detections and current_state != 'initial':
|
||||
return 'connected' if process_type == 'ssh' else 'completed'
|
||||
elif 'vim' in detections:
|
||||
if any('-- INSERT --' in p for p in detections.get('vim', [])):
|
||||
return 'insert'
|
||||
elif any('-- VISUAL --' in p for p in detections.get('vim', [])):
|
||||
return 'visual'
|
||||
elif any(':' in p for p in detections.get('vim', [])):
|
||||
return 'command'
|
||||
if "confirmation" in detections and current_state in ["running", "initial"]:
|
||||
return "confirming"
|
||||
elif "password" in detections or "sudo_password" in detections:
|
||||
return "authenticating"
|
||||
elif "error" in detections:
|
||||
return "error"
|
||||
elif "bash_prompt" in detections and current_state != "initial":
|
||||
return "connected" if process_type == "ssh" else "completed"
|
||||
elif "vim" in detections:
|
||||
if any("-- INSERT --" in p for p in detections.get("vim", [])):
|
||||
return "insert"
|
||||
elif any("-- VISUAL --" in p for p in detections.get("vim", [])):
|
||||
return "visual"
|
||||
elif any(":" in p for p in detections.get("vim", [])):
|
||||
return "command"
|
||||
|
||||
# Default state progression
|
||||
if current_state == 'initial':
|
||||
return 'running'
|
||||
elif current_state == 'running' and detections:
|
||||
return 'waiting_input'
|
||||
elif current_state == 'waiting_input' and not detections:
|
||||
return 'running'
|
||||
if current_state == "initial":
|
||||
return "running"
|
||||
elif current_state == "running" and detections:
|
||||
return "waiting_input"
|
||||
elif current_state == "waiting_input" and not detections:
|
||||
return "running"
|
||||
|
||||
return current_state
|
||||
|
||||
@ -220,15 +242,15 @@ class PromptDetector:
|
||||
if session_name not in self.session_states:
|
||||
return False
|
||||
|
||||
state = self.session_states[session_name]['current_state']
|
||||
process_type = self.session_states[session_name]['process_type']
|
||||
state = self.session_states[session_name]["current_state"]
|
||||
process_type = self.session_states[session_name]["process_type"]
|
||||
|
||||
# States that typically indicate waiting for input
|
||||
waiting_states = {
|
||||
'generic': ['waiting_input'],
|
||||
'apt': ['confirming'],
|
||||
'ssh': ['authenticating'],
|
||||
'vim': ['command', 'insert', 'visual']
|
||||
"generic": ["waiting_input"],
|
||||
"apt": ["confirming"],
|
||||
"ssh": ["authenticating"],
|
||||
"vim": ["command", "insert", "visual"],
|
||||
}
|
||||
|
||||
return state in waiting_states.get(process_type, [])
|
||||
@ -236,10 +258,10 @@ class PromptDetector:
|
||||
def get_session_timeout(self, session_name):
|
||||
"""Get the timeout for a session based on its process type."""
|
||||
if session_name not in self.session_states:
|
||||
return self.timeout_configs['default']
|
||||
return self.timeout_configs["default"]
|
||||
|
||||
process_type = self.session_states[session_name]['process_type']
|
||||
return self.timeout_configs.get(process_type, self.timeout_configs['default'])
|
||||
process_type = self.session_states[session_name]["process_type"]
|
||||
return self.timeout_configs.get(process_type, self.timeout_configs["default"])
|
||||
|
||||
def check_for_timeouts(self):
|
||||
"""Check all sessions for timeouts and return timed out sessions."""
|
||||
@ -248,7 +270,7 @@ class PromptDetector:
|
||||
|
||||
for session_name, state in self.session_states.items():
|
||||
timeout = self.get_session_timeout(session_name)
|
||||
if current_time - state['last_activity'] > timeout:
|
||||
if current_time - state["last_activity"] > timeout:
|
||||
timed_out.append(session_name)
|
||||
|
||||
return timed_out
|
||||
@ -260,16 +282,18 @@ class PromptDetector:
|
||||
|
||||
state = self.session_states[session_name]
|
||||
return {
|
||||
'current_state': state['current_state'],
|
||||
'process_type': state['process_type'],
|
||||
'last_activity': state['last_activity'],
|
||||
'transitions': state['transitions'][-5:], # Last 5 transitions
|
||||
'is_waiting': self.is_waiting_for_input(session_name)
|
||||
"current_state": state["current_state"],
|
||||
"process_type": state["process_type"],
|
||||
"last_activity": state["last_activity"],
|
||||
"transitions": state["transitions"][-5:], # Last 5 transitions
|
||||
"is_waiting": self.is_waiting_for_input(session_name),
|
||||
}
|
||||
|
||||
|
||||
# Global detector instance
|
||||
_detector = None
|
||||
|
||||
|
||||
def get_global_detector():
|
||||
"""Get the global prompt detector instance."""
|
||||
global _detector
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import contextlib
|
||||
import traceback
|
||||
from io import StringIO
|
||||
import contextlib
|
||||
|
||||
|
||||
def python_exec(code, python_globals):
|
||||
try:
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
import urllib.error
|
||||
import json
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
|
||||
def http_fetch(url, headers=None):
|
||||
try:
|
||||
@ -11,26 +12,28 @@ def http_fetch(url, headers=None):
|
||||
req.add_header(key, value)
|
||||
|
||||
with urllib.request.urlopen(req) as response:
|
||||
content = response.read().decode('utf-8')
|
||||
content = response.read().decode("utf-8")
|
||||
return {"status": "success", "content": content[:10000]}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def _perform_search(base_url, query, params=None):
|
||||
try:
|
||||
full_url = f"https://static.molodetz.nl/search.cgi?query={query}"
|
||||
|
||||
with urllib.request.urlopen(full_url) as response:
|
||||
content = response.read().decode('utf-8')
|
||||
content = response.read().decode("utf-8")
|
||||
return {"status": "success", "content": json.loads(content)}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def web_search(query):
|
||||
base_url = "https://search.molodetz.nl/search"
|
||||
return _perform_search(base_url, query)
|
||||
|
||||
|
||||
def web_search_news(query):
|
||||
base_url = "https://search.molodetz.nl/search"
|
||||
return _perform_search(base_url, query)
|
||||
|
||||
|
||||
@ -1,5 +1,11 @@
|
||||
from pr.ui.colors import Colors
|
||||
from pr.ui.rendering import highlight_code, render_markdown
|
||||
from pr.ui.display import display_tool_call, print_autonomous_header
|
||||
from pr.ui.rendering import highlight_code, render_markdown
|
||||
|
||||
__all__ = ['Colors', 'highlight_code', 'render_markdown', 'display_tool_call', 'print_autonomous_header']
|
||||
__all__ = [
|
||||
"Colors",
|
||||
"highlight_code",
|
||||
"render_markdown",
|
||||
"display_tool_call",
|
||||
"print_autonomous_header",
|
||||
]
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
class Colors:
|
||||
RESET = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
RED = '\033[91m'
|
||||
GREEN = '\033[92m'
|
||||
YELLOW = '\033[93m'
|
||||
BLUE = '\033[94m'
|
||||
MAGENTA = '\033[95m'
|
||||
CYAN = '\033[96m'
|
||||
GRAY = '\033[90m'
|
||||
WHITE = '\033[97m'
|
||||
BG_BLUE = '\033[44m'
|
||||
BG_GREEN = '\033[42m'
|
||||
BG_RED = '\033[41m'
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
MAGENTA = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
GRAY = "\033[90m"
|
||||
WHITE = "\033[97m"
|
||||
BG_BLUE = "\033[44m"
|
||||
BG_GREEN = "\033[42m"
|
||||
BG_RED = "\033[41m"
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import difflib
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from .colors import Colors
|
||||
|
||||
|
||||
@ -19,8 +20,13 @@ class DiffStats:
|
||||
|
||||
|
||||
class DiffLine:
|
||||
def __init__(self, line_type: str, content: str, old_line_num: Optional[int] = None,
|
||||
new_line_num: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
line_type: str,
|
||||
content: str,
|
||||
old_line_num: Optional[int] = None,
|
||||
new_line_num: Optional[int] = None,
|
||||
):
|
||||
self.line_type = line_type
|
||||
self.content = content
|
||||
self.old_line_num = old_line_num
|
||||
@ -28,27 +34,27 @@ class DiffLine:
|
||||
|
||||
def format(self, show_line_nums: bool = True) -> str:
|
||||
color = {
|
||||
'add': Colors.GREEN,
|
||||
'delete': Colors.RED,
|
||||
'context': Colors.GRAY,
|
||||
'header': Colors.CYAN,
|
||||
'stats': Colors.BLUE
|
||||
"add": Colors.GREEN,
|
||||
"delete": Colors.RED,
|
||||
"context": Colors.GRAY,
|
||||
"header": Colors.CYAN,
|
||||
"stats": Colors.BLUE,
|
||||
}.get(self.line_type, Colors.RESET)
|
||||
|
||||
prefix = {
|
||||
'add': '+ ',
|
||||
'delete': '- ',
|
||||
'context': ' ',
|
||||
'header': '',
|
||||
'stats': ''
|
||||
}.get(self.line_type, ' ')
|
||||
"add": "+ ",
|
||||
"delete": "- ",
|
||||
"context": " ",
|
||||
"header": "",
|
||||
"stats": "",
|
||||
}.get(self.line_type, " ")
|
||||
|
||||
if show_line_nums and self.line_type in ('add', 'delete', 'context'):
|
||||
old_num = str(self.old_line_num) if self.old_line_num else ' '
|
||||
new_num = str(self.new_line_num) if self.new_line_num else ' '
|
||||
if show_line_nums and self.line_type in ("add", "delete", "context"):
|
||||
old_num = str(self.old_line_num) if self.old_line_num else " "
|
||||
new_num = str(self.new_line_num) if self.new_line_num else " "
|
||||
line_num_str = f"{Colors.YELLOW}{old_num:>4} {new_num:>4}{Colors.RESET} "
|
||||
else:
|
||||
line_num_str = ''
|
||||
line_num_str = ""
|
||||
|
||||
return f"{line_num_str}{color}{prefix}{self.content}{Colors.RESET}"
|
||||
|
||||
@ -57,8 +63,9 @@ class DiffDisplay:
|
||||
def __init__(self, context_lines: int = 3):
|
||||
self.context_lines = context_lines
|
||||
|
||||
def create_diff(self, old_content: str, new_content: str,
|
||||
filename: str = "file") -> Tuple[List[DiffLine], DiffStats]:
|
||||
def create_diff(
|
||||
self, old_content: str, new_content: str, filename: str = "file"
|
||||
) -> Tuple[List[DiffLine], DiffStats]:
|
||||
old_lines = old_content.splitlines(keepends=True)
|
||||
new_lines = new_content.splitlines(keepends=True)
|
||||
|
||||
@ -67,31 +74,38 @@ class DiffDisplay:
|
||||
stats.files_changed = 1
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
old_lines, new_lines,
|
||||
old_lines,
|
||||
new_lines,
|
||||
fromfile=f"a/{filename}",
|
||||
tofile=f"b/{filename}",
|
||||
n=self.context_lines
|
||||
n=self.context_lines,
|
||||
)
|
||||
|
||||
old_line_num = 0
|
||||
new_line_num = 0
|
||||
|
||||
for line in diff:
|
||||
if line.startswith('---') or line.startswith('+++'):
|
||||
diff_lines.append(DiffLine('header', line.rstrip()))
|
||||
elif line.startswith('@@'):
|
||||
diff_lines.append(DiffLine('header', line.rstrip()))
|
||||
if line.startswith("---") or line.startswith("+++"):
|
||||
diff_lines.append(DiffLine("header", line.rstrip()))
|
||||
elif line.startswith("@@"):
|
||||
diff_lines.append(DiffLine("header", line.rstrip()))
|
||||
old_line_num, new_line_num = self._parse_hunk_header(line)
|
||||
elif line.startswith('+'):
|
||||
elif line.startswith("+"):
|
||||
stats.insertions += 1
|
||||
diff_lines.append(DiffLine('add', line[1:].rstrip(), None, new_line_num))
|
||||
diff_lines.append(
|
||||
DiffLine("add", line[1:].rstrip(), None, new_line_num)
|
||||
)
|
||||
new_line_num += 1
|
||||
elif line.startswith('-'):
|
||||
elif line.startswith("-"):
|
||||
stats.deletions += 1
|
||||
diff_lines.append(DiffLine('delete', line[1:].rstrip(), old_line_num, None))
|
||||
diff_lines.append(
|
||||
DiffLine("delete", line[1:].rstrip(), old_line_num, None)
|
||||
)
|
||||
old_line_num += 1
|
||||
elif line.startswith(' '):
|
||||
diff_lines.append(DiffLine('context', line[1:].rstrip(), old_line_num, new_line_num))
|
||||
elif line.startswith(" "):
|
||||
diff_lines.append(
|
||||
DiffLine("context", line[1:].rstrip(), old_line_num, new_line_num)
|
||||
)
|
||||
old_line_num += 1
|
||||
new_line_num += 1
|
||||
|
||||
@ -101,15 +115,20 @@ class DiffDisplay:
|
||||
|
||||
def _parse_hunk_header(self, header: str) -> Tuple[int, int]:
|
||||
try:
|
||||
parts = header.split('@@')[1].strip().split()
|
||||
old_start = int(parts[0].split(',')[0].replace('-', ''))
|
||||
new_start = int(parts[1].split(',')[0].replace('+', ''))
|
||||
parts = header.split("@@")[1].strip().split()
|
||||
old_start = int(parts[0].split(",")[0].replace("-", ""))
|
||||
new_start = int(parts[1].split(",")[0].replace("+", ""))
|
||||
return old_start, new_start
|
||||
except (IndexError, ValueError):
|
||||
return 0, 0
|
||||
|
||||
def render_diff(self, diff_lines: List[DiffLine], stats: DiffStats,
|
||||
show_line_nums: bool = True, show_stats: bool = True) -> str:
|
||||
def render_diff(
|
||||
self,
|
||||
diff_lines: List[DiffLine],
|
||||
stats: DiffStats,
|
||||
show_line_nums: bool = True,
|
||||
show_stats: bool = True,
|
||||
) -> str:
|
||||
output = []
|
||||
|
||||
if show_stats:
|
||||
@ -124,10 +143,15 @@ class DiffDisplay:
|
||||
if show_stats:
|
||||
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
|
||||
|
||||
return '\n'.join(output)
|
||||
return "\n".join(output)
|
||||
|
||||
def display_file_diff(self, old_content: str, new_content: str,
|
||||
filename: str = "file", show_line_nums: bool = True) -> str:
|
||||
def display_file_diff(
|
||||
self,
|
||||
old_content: str,
|
||||
new_content: str,
|
||||
filename: str = "file",
|
||||
show_line_nums: bool = True,
|
||||
) -> str:
|
||||
diff_lines, stats = self.create_diff(old_content, new_content, filename)
|
||||
|
||||
if not diff_lines:
|
||||
@ -135,8 +159,13 @@ class DiffDisplay:
|
||||
|
||||
return self.render_diff(diff_lines, stats, show_line_nums)
|
||||
|
||||
def display_side_by_side(self, old_content: str, new_content: str,
|
||||
filename: str = "file", width: int = 80) -> str:
|
||||
def display_side_by_side(
|
||||
self,
|
||||
old_content: str,
|
||||
new_content: str,
|
||||
filename: str = "file",
|
||||
width: int = 80,
|
||||
) -> str:
|
||||
old_lines = old_content.splitlines()
|
||||
new_lines = new_content.splitlines()
|
||||
|
||||
@ -144,40 +173,57 @@ class DiffDisplay:
|
||||
output = []
|
||||
|
||||
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}")
|
||||
output.append(f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}")
|
||||
output.append(
|
||||
f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}"
|
||||
)
|
||||
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
|
||||
|
||||
half_width = (width - 5) // 2
|
||||
|
||||
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
|
||||
if tag == 'equal':
|
||||
for i, (old_line, new_line) in enumerate(zip(old_lines[i1:i2], new_lines[j1:j2])):
|
||||
if tag == "equal":
|
||||
for i, (old_line, new_line) in enumerate(
|
||||
zip(old_lines[i1:i2], new_lines[j1:j2])
|
||||
):
|
||||
old_display = old_line[:half_width].ljust(half_width)
|
||||
new_display = new_line[:half_width].ljust(half_width)
|
||||
output.append(f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}")
|
||||
elif tag == 'replace':
|
||||
output.append(
|
||||
f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}"
|
||||
)
|
||||
elif tag == "replace":
|
||||
max_lines = max(i2 - i1, j2 - j1)
|
||||
for i in range(max_lines):
|
||||
old_line = old_lines[i1 + i] if i1 + i < i2 else ""
|
||||
new_line = new_lines[j1 + i] if j1 + i < j2 else ""
|
||||
old_display = old_line[:half_width].ljust(half_width)
|
||||
new_display = new_line[:half_width].ljust(half_width)
|
||||
output.append(f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}")
|
||||
elif tag == 'delete':
|
||||
output.append(
|
||||
f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}"
|
||||
)
|
||||
elif tag == "delete":
|
||||
for old_line in old_lines[i1:i2]:
|
||||
old_display = old_line[:half_width].ljust(half_width)
|
||||
output.append(f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}")
|
||||
elif tag == 'insert':
|
||||
output.append(
|
||||
f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}"
|
||||
)
|
||||
elif tag == "insert":
|
||||
for new_line in new_lines[j1:j2]:
|
||||
new_display = new_line[:half_width].ljust(half_width)
|
||||
output.append(f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}")
|
||||
output.append(
|
||||
f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}"
|
||||
)
|
||||
|
||||
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
|
||||
return '\n'.join(output)
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
def display_diff(old_content: str, new_content: str, filename: str = "file",
|
||||
format_type: str = "unified", context_lines: int = 3) -> str:
|
||||
def display_diff(
|
||||
old_content: str,
|
||||
new_content: str,
|
||||
filename: str = "file",
|
||||
format_type: str = "unified",
|
||||
context_lines: int = 3,
|
||||
) -> str:
|
||||
displayer = DiffDisplay(context_lines)
|
||||
|
||||
if format_type == "side-by-side":
|
||||
@ -191,9 +237,9 @@ def get_diff_stats(old_content: str, new_content: str) -> Dict[str, int]:
|
||||
_, stats = displayer.create_diff(old_content, new_content)
|
||||
|
||||
return {
|
||||
'insertions': stats.insertions,
|
||||
'deletions': stats.deletions,
|
||||
'modifications': stats.modifications,
|
||||
'total_changes': stats.total_changes,
|
||||
'files_changed': stats.files_changed
|
||||
"insertions": stats.insertions,
|
||||
"deletions": stats.deletions,
|
||||
"modifications": stats.modifications,
|
||||
"total_changes": stats.total_changes,
|
||||
"files_changed": stats.files_changed,
|
||||
}
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
from pr.ui.colors import Colors
|
||||
|
||||
|
||||
def display_tool_call(tool_name, arguments, status="running", result=None):
|
||||
if status == "running":
|
||||
return
|
||||
@ -15,8 +13,11 @@ def display_tool_call(tool_name, arguments, status="running", result=None):
|
||||
|
||||
print(f"{Colors.GRAY}{line}{Colors.RESET}")
|
||||
|
||||
|
||||
def print_autonomous_header(task):
|
||||
print(f"{Colors.BOLD}Task:{Colors.RESET} {task}")
|
||||
print(f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n")
|
||||
print(f"{Colors.BOLD}{'═' * 80}{Colors.RESET}\n")
|
||||
|
||||
@ -1,12 +1,20 @@
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .colors import Colors
|
||||
from .progress import ProgressBar
|
||||
|
||||
|
||||
class EditOperation:
|
||||
def __init__(self, op_type: str, filepath: str, start_pos: int = 0,
|
||||
end_pos: int = 0, content: str = "", old_content: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
op_type: str,
|
||||
filepath: str,
|
||||
start_pos: int = 0,
|
||||
end_pos: int = 0,
|
||||
content: str = "",
|
||||
old_content: str = "",
|
||||
):
|
||||
self.op_type = op_type
|
||||
self.filepath = filepath
|
||||
self.start_pos = start_pos
|
||||
@ -18,40 +26,46 @@ class EditOperation:
|
||||
|
||||
def format_operation(self) -> str:
|
||||
op_colors = {
|
||||
'INSERT': Colors.GREEN,
|
||||
'REPLACE': Colors.YELLOW,
|
||||
'DELETE': Colors.RED,
|
||||
'WRITE': Colors.BLUE
|
||||
"INSERT": Colors.GREEN,
|
||||
"REPLACE": Colors.YELLOW,
|
||||
"DELETE": Colors.RED,
|
||||
"WRITE": Colors.BLUE,
|
||||
}
|
||||
|
||||
color = op_colors.get(self.op_type, Colors.RESET)
|
||||
status_icon = {
|
||||
'pending': '○',
|
||||
'in_progress': '◐',
|
||||
'completed': '●',
|
||||
'failed': '✗'
|
||||
}.get(self.status, '○')
|
||||
"pending": "○",
|
||||
"in_progress": "◐",
|
||||
"completed": "●",
|
||||
"failed": "✗",
|
||||
}.get(self.status, "○")
|
||||
|
||||
return f"{color}{status_icon} [{self.op_type}]{Colors.RESET} {self.filepath}"
|
||||
|
||||
def format_details(self, show_content: bool = True) -> str:
|
||||
output = [self.format_operation()]
|
||||
|
||||
if self.op_type in ('INSERT', 'REPLACE'):
|
||||
output.append(f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}")
|
||||
if self.op_type in ("INSERT", "REPLACE"):
|
||||
output.append(
|
||||
f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}"
|
||||
)
|
||||
|
||||
if show_content:
|
||||
if self.old_content:
|
||||
lines = self.old_content.split('\n')
|
||||
preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '')
|
||||
lines = self.old_content.split("\n")
|
||||
preview = lines[0][:60] + (
|
||||
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
|
||||
)
|
||||
output.append(f" {Colors.RED}- {preview}{Colors.RESET}")
|
||||
|
||||
if self.content:
|
||||
lines = self.content.split('\n')
|
||||
preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '')
|
||||
lines = self.content.split("\n")
|
||||
preview = lines[0][:60] + (
|
||||
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
|
||||
)
|
||||
output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}")
|
||||
|
||||
return '\n'.join(output)
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
class EditTracker:
|
||||
@ -76,11 +90,13 @@ class EditTracker:
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
stats = {
|
||||
'total': len(self.operations),
|
||||
'completed': sum(1 for op in self.operations if op.status == 'completed'),
|
||||
'pending': sum(1 for op in self.operations if op.status == 'pending'),
|
||||
'in_progress': sum(1 for op in self.operations if op.status == 'in_progress'),
|
||||
'failed': sum(1 for op in self.operations if op.status == 'failed')
|
||||
"total": len(self.operations),
|
||||
"completed": sum(1 for op in self.operations if op.status == "completed"),
|
||||
"pending": sum(1 for op in self.operations if op.status == "pending"),
|
||||
"in_progress": sum(
|
||||
1 for op in self.operations if op.status == "in_progress"
|
||||
),
|
||||
"failed": sum(1 for op in self.operations if op.status == "failed"),
|
||||
}
|
||||
return stats
|
||||
|
||||
@ -88,7 +104,7 @@ class EditTracker:
|
||||
if not self.operations:
|
||||
return 0.0
|
||||
stats = self.get_stats()
|
||||
return (stats['completed'] / stats['total']) * 100
|
||||
return (stats["completed"] / stats["total"]) * 100
|
||||
|
||||
def display_progress(self) -> str:
|
||||
if not self.operations:
|
||||
@ -96,26 +112,30 @@ class EditTracker:
|
||||
|
||||
output = []
|
||||
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
|
||||
output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}")
|
||||
output.append(
|
||||
f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}"
|
||||
)
|
||||
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
|
||||
|
||||
stats = self.get_stats()
|
||||
completion = self.get_completion_percentage()
|
||||
self.get_completion_percentage()
|
||||
|
||||
progress_bar = ProgressBar(total=stats['total'], width=40)
|
||||
progress_bar.current = stats['completed']
|
||||
progress_bar = ProgressBar(total=stats["total"], width=40)
|
||||
progress_bar.current = stats["completed"]
|
||||
bar_display = progress_bar._get_bar_display()
|
||||
|
||||
output.append(f"Progress: {bar_display}")
|
||||
output.append(f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, "
|
||||
f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n")
|
||||
output.append(
|
||||
f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, "
|
||||
f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n"
|
||||
)
|
||||
|
||||
output.append(f"{Colors.BOLD}Recent Operations:{Colors.RESET}")
|
||||
for i, op in enumerate(self.operations[-5:], 1):
|
||||
output.append(f"{i}. {op.format_operation()}")
|
||||
|
||||
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
|
||||
return '\n'.join(output)
|
||||
return "\n".join(output)
|
||||
|
||||
def display_timeline(self, show_content: bool = False) -> str:
|
||||
if not self.operations:
|
||||
@ -134,18 +154,20 @@ class EditTracker:
|
||||
|
||||
stats = self.get_stats()
|
||||
output.append(f"{Colors.BOLD}Summary:{Colors.RESET}")
|
||||
output.append(f"{Colors.BLUE}Total operations: {stats['total']}, "
|
||||
f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}")
|
||||
output.append(
|
||||
f"{Colors.BLUE}Total operations: {stats['total']}, "
|
||||
f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}"
|
||||
)
|
||||
|
||||
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
|
||||
return '\n'.join(output)
|
||||
return "\n".join(output)
|
||||
|
||||
def display_summary(self) -> str:
|
||||
if not self.operations:
|
||||
return f"{Colors.GRAY}No edits to summarize{Colors.RESET}"
|
||||
|
||||
stats = self.get_stats()
|
||||
files_modified = len(set(op.filepath for op in self.operations))
|
||||
files_modified = len({op.filepath for op in self.operations})
|
||||
|
||||
output = []
|
||||
output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}")
|
||||
@ -156,7 +178,7 @@ class EditTracker:
|
||||
output.append(f"{Colors.GREEN}Total Operations: {stats['total']}{Colors.RESET}")
|
||||
output.append(f"{Colors.GREEN}Successful: {stats['completed']}{Colors.RESET}")
|
||||
|
||||
if stats['failed'] > 0:
|
||||
if stats["failed"] > 0:
|
||||
output.append(f"{Colors.RED}Failed: {stats['failed']}{Colors.RESET}")
|
||||
|
||||
output.append(f"\n{Colors.BOLD}Operations by Type:{Colors.RESET}")
|
||||
@ -168,7 +190,7 @@ class EditTracker:
|
||||
output.append(f" {op_type}: {count}")
|
||||
|
||||
output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}\n")
|
||||
return '\n'.join(output)
|
||||
return "\n".join(output)
|
||||
|
||||
def clear(self):
|
||||
self.operations.clear()
|
||||
|
||||
@ -1,31 +1,31 @@
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
class OutputFormatter:
|
||||
|
||||
def __init__(self, format_type: str = 'text', quiet: bool = False):
|
||||
def __init__(self, format_type: str = "text", quiet: bool = False):
|
||||
self.format_type = format_type
|
||||
self.quiet = quiet
|
||||
|
||||
def output(self, data: Any, message_type: str = 'response'):
|
||||
if self.quiet and message_type not in ['error', 'result']:
|
||||
def output(self, data: Any, message_type: str = "response"):
|
||||
if self.quiet and message_type not in ["error", "result"]:
|
||||
return
|
||||
|
||||
if self.format_type == 'json':
|
||||
if self.format_type == "json":
|
||||
self._output_json(data, message_type)
|
||||
elif self.format_type == 'structured':
|
||||
elif self.format_type == "structured":
|
||||
self._output_structured(data, message_type)
|
||||
else:
|
||||
self._output_text(data, message_type)
|
||||
|
||||
def _output_json(self, data: Any, message_type: str):
|
||||
output = {
|
||||
'type': message_type,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'data': data
|
||||
"type": message_type,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": data,
|
||||
}
|
||||
print(json.dumps(output, indent=2))
|
||||
|
||||
@ -46,24 +46,24 @@ class OutputFormatter:
|
||||
print(data)
|
||||
|
||||
def error(self, message: str):
|
||||
if self.format_type == 'json':
|
||||
self._output_json({'error': message}, 'error')
|
||||
if self.format_type == "json":
|
||||
self._output_json({"error": message}, "error")
|
||||
else:
|
||||
print(f"Error: {message}", file=sys.stderr)
|
||||
|
||||
def success(self, message: str):
|
||||
if not self.quiet:
|
||||
if self.format_type == 'json':
|
||||
self._output_json({'success': message}, 'success')
|
||||
if self.format_type == "json":
|
||||
self._output_json({"success": message}, "success")
|
||||
else:
|
||||
print(message)
|
||||
|
||||
def info(self, message: str):
|
||||
if not self.quiet:
|
||||
if self.format_type == 'json':
|
||||
self._output_json({'info': message}, 'info')
|
||||
if self.format_type == "json":
|
||||
self._output_json({"info": message}, "info")
|
||||
else:
|
||||
print(message)
|
||||
|
||||
def result(self, data: Any):
|
||||
self.output(data, 'result')
|
||||
self.output(data, "result")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class ProgressIndicator:
|
||||
@ -30,15 +30,15 @@ class ProgressIndicator:
|
||||
self.running = False
|
||||
if self.thread:
|
||||
self.thread.join(timeout=1.0)
|
||||
sys.stdout.write('\r' + ' ' * (len(self.message) + 10) + '\r')
|
||||
sys.stdout.write("\r" + " " * (len(self.message) + 10) + "\r")
|
||||
sys.stdout.flush()
|
||||
|
||||
def _animate(self):
|
||||
spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
|
||||
spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
||||
idx = 0
|
||||
|
||||
while self.running:
|
||||
sys.stdout.write(f'\r{spinner[idx]} {self.message}...')
|
||||
sys.stdout.write(f"\r{spinner[idx]} {self.message}...")
|
||||
sys.stdout.flush()
|
||||
idx = (idx + 1) % len(spinner)
|
||||
time.sleep(0.1)
|
||||
@ -62,14 +62,20 @@ class ProgressBar:
|
||||
else:
|
||||
percent = int((self.current / self.total) * 100)
|
||||
|
||||
filled = int((self.current / self.total) * self.width) if self.total > 0 else self.width
|
||||
bar = '█' * filled + '░' * (self.width - filled)
|
||||
filled = (
|
||||
int((self.current / self.total) * self.width)
|
||||
if self.total > 0
|
||||
else self.width
|
||||
)
|
||||
bar = "█" * filled + "░" * (self.width - filled)
|
||||
|
||||
sys.stdout.write(f'\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})')
|
||||
sys.stdout.write(
|
||||
f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
if self.current >= self.total:
|
||||
sys.stdout.write('\n')
|
||||
sys.stdout.write("\n")
|
||||
|
||||
def finish(self):
|
||||
self.current = self.total
|
||||
|
||||
@ -1,90 +1,103 @@
|
||||
import re
|
||||
from pr.ui.colors import Colors
|
||||
|
||||
from pr.config import LANGUAGE_KEYWORDS
|
||||
from pr.ui.colors import Colors
|
||||
|
||||
|
||||
def highlight_code(code, language=None, syntax_highlighting=True):
|
||||
if not syntax_highlighting:
|
||||
return code
|
||||
|
||||
if not language:
|
||||
if 'def ' in code or 'import ' in code:
|
||||
language = 'python'
|
||||
elif 'function ' in code or 'const ' in code:
|
||||
language = 'javascript'
|
||||
elif 'public ' in code or 'class ' in code:
|
||||
language = 'java'
|
||||
if "def " in code or "import " in code:
|
||||
language = "python"
|
||||
elif "function " in code or "const " in code:
|
||||
language = "javascript"
|
||||
elif "public " in code or "class " in code:
|
||||
language = "java"
|
||||
|
||||
if language and language in LANGUAGE_KEYWORDS:
|
||||
keywords = LANGUAGE_KEYWORDS[language]
|
||||
for keyword in keywords:
|
||||
pattern = r'\b' + re.escape(keyword) + r'\b'
|
||||
pattern = r"\b" + re.escape(keyword) + r"\b"
|
||||
code = re.sub(pattern, f"{Colors.BLUE}{keyword}{Colors.RESET}", code)
|
||||
|
||||
code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code)
|
||||
code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code)
|
||||
|
||||
code = re.sub(r'#(.*)$', f'{Colors.GRAY}#\\1{Colors.RESET}', code, flags=re.MULTILINE)
|
||||
code = re.sub(r'//(.*)$', f'{Colors.GRAY}//\\1{Colors.RESET}', code, flags=re.MULTILINE)
|
||||
code = re.sub(
|
||||
r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE
|
||||
)
|
||||
code = re.sub(
|
||||
r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
return code
|
||||
|
||||
|
||||
def render_markdown(text, syntax_highlighting=True):
|
||||
if not syntax_highlighting:
|
||||
return text
|
||||
|
||||
code_blocks = []
|
||||
|
||||
def extract_code_block(match):
|
||||
lang = match.group(1) or ''
|
||||
lang = match.group(1) or ""
|
||||
code = match.group(2)
|
||||
highlighted_code = highlight_code(code.strip('\n'), lang, syntax_highlighting)
|
||||
highlighted_code = highlight_code(code.strip("\n"), lang, syntax_highlighting)
|
||||
placeholder = f"%%CODEBLOCK{len(code_blocks)}%%"
|
||||
full_block = f'{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}'
|
||||
full_block = f"{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}"
|
||||
code_blocks.append(full_block)
|
||||
return placeholder
|
||||
|
||||
text = re.sub(r'```(\w*)\n(.*?)\n?```', extract_code_block, text, flags=re.DOTALL)
|
||||
text = re.sub(r"```(\w*)\n(.*?)\n?```", extract_code_block, text, flags=re.DOTALL)
|
||||
|
||||
inline_codes = []
|
||||
|
||||
def extract_inline_code(match):
|
||||
code = match.group(1)
|
||||
placeholder = f"%%INLINECODE{len(inline_codes)}%%"
|
||||
inline_codes.append(f'{Colors.YELLOW}{code}{Colors.RESET}')
|
||||
inline_codes.append(f"{Colors.YELLOW}{code}{Colors.RESET}")
|
||||
return placeholder
|
||||
|
||||
text = re.sub(r'`([^`]+)`', extract_inline_code, text)
|
||||
text = re.sub(r"`([^`]+)`", extract_inline_code, text)
|
||||
|
||||
lines = text.split('\n')
|
||||
lines = text.split("\n")
|
||||
processed_lines = []
|
||||
for line in lines:
|
||||
if line.startswith('### '):
|
||||
line = f'{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}'
|
||||
elif line.startswith('## '):
|
||||
line = f'{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}'
|
||||
elif line.startswith('# '):
|
||||
line = f'{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}'
|
||||
elif line.startswith('> '):
|
||||
line = f'{Colors.CYAN}> {line[2:]}{Colors.RESET}'
|
||||
elif re.match(r'^\s*[\*\-\+]\s', line):
|
||||
match = re.match(r'^(\s*)([\*\-\+])(\s+.*)', line)
|
||||
if line.startswith("### "):
|
||||
line = f"{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}"
|
||||
elif line.startswith("## "):
|
||||
line = f"{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}"
|
||||
elif line.startswith("# "):
|
||||
line = f"{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}"
|
||||
elif line.startswith("> "):
|
||||
line = f"{Colors.CYAN}> {line[2:]}{Colors.RESET}"
|
||||
elif re.match(r"^\s*[\*\-\+]\s", line):
|
||||
match = re.match(r"^(\s*)([\*\-\+])(\s+.*)", line)
|
||||
if match:
|
||||
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
|
||||
elif re.match(r'^\s*\d+\.\s', line):
|
||||
match = re.match(r'^(\s*)(\d+\.)(\s+.*)', line)
|
||||
elif re.match(r"^\s*\d+\.\s", line):
|
||||
match = re.match(r"^(\s*)(\d+\.)(\s+.*)", line)
|
||||
if match:
|
||||
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
|
||||
processed_lines.append(line)
|
||||
text = '\n'.join(processed_lines)
|
||||
text = "\n".join(processed_lines)
|
||||
|
||||
text = re.sub(r'\[(.*?)\]\((.*?)\)', f'{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}', text)
|
||||
text = re.sub(r'~~(.*?)~~', f'{Colors.GRAY}\\1{Colors.RESET}', text)
|
||||
text = re.sub(r'\*\*(.*?)\*\*', f'{Colors.BOLD}\\1{Colors.RESET}', text)
|
||||
text = re.sub(r'__(.*?)__', f'{Colors.BOLD}\\1{Colors.RESET}', text)
|
||||
text = re.sub(r'\*(.*?)\*', f'{Colors.CYAN}\\1{Colors.RESET}', text)
|
||||
text = re.sub(r'_(.*?)_', f'{Colors.CYAN}\\1{Colors.RESET}', text)
|
||||
text = re.sub(
|
||||
r"\[(.*?)\]\((.*?)\)",
|
||||
f"{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}",
|
||||
text,
|
||||
)
|
||||
text = re.sub(r"~~(.*?)~~", f"{Colors.GRAY}\\1{Colors.RESET}", text)
|
||||
text = re.sub(r"\*\*(.*?)\*\*", f"{Colors.BOLD}\\1{Colors.RESET}", text)
|
||||
text = re.sub(r"__(.*?)__", f"{Colors.BOLD}\\1{Colors.RESET}", text)
|
||||
text = re.sub(r"\*(.*?)\*", f"{Colors.CYAN}\\1{Colors.RESET}", text)
|
||||
text = re.sub(r"_(.*?)_", f"{Colors.CYAN}\\1{Colors.RESET}", text)
|
||||
|
||||
for i, code in enumerate(inline_codes):
|
||||
text = text.replace(f'%%INLINECODE{i}%%', code)
|
||||
text = text.replace(f"%%INLINECODE{i}%%", code)
|
||||
for i, block in enumerate(code_blocks):
|
||||
text = text.replace(f'%%CODEBLOCK{i}%%', block)
|
||||
text = text.replace(f"%%CODEBLOCK{i}%%", block)
|
||||
|
||||
return text
|
||||
|
||||
@ -1,5 +1,11 @@
|
||||
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode
|
||||
from .workflow_definition import ExecutionMode, Workflow, WorkflowStep
|
||||
from .workflow_engine import WorkflowEngine
|
||||
from .workflow_storage import WorkflowStorage
|
||||
|
||||
__all__ = ['Workflow', 'WorkflowStep', 'ExecutionMode', 'WorkflowEngine', 'WorkflowStorage']
|
||||
__all__ = [
|
||||
"Workflow",
|
||||
"WorkflowStep",
|
||||
"ExecutionMode",
|
||||
"WorkflowEngine",
|
||||
"WorkflowStorage",
|
||||
]
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class ExecutionMode(Enum):
|
||||
SEQUENTIAL = "sequential"
|
||||
PARALLEL = "parallel"
|
||||
CONDITIONAL = "conditional"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkflowStep:
|
||||
tool_name: str
|
||||
@ -20,29 +22,30 @@ class WorkflowStep:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'tool_name': self.tool_name,
|
||||
'arguments': self.arguments,
|
||||
'step_id': self.step_id,
|
||||
'condition': self.condition,
|
||||
'on_success': self.on_success,
|
||||
'on_failure': self.on_failure,
|
||||
'retry_count': self.retry_count,
|
||||
'timeout_seconds': self.timeout_seconds
|
||||
"tool_name": self.tool_name,
|
||||
"arguments": self.arguments,
|
||||
"step_id": self.step_id,
|
||||
"condition": self.condition,
|
||||
"on_success": self.on_success,
|
||||
"on_failure": self.on_failure,
|
||||
"retry_count": self.retry_count,
|
||||
"timeout_seconds": self.timeout_seconds,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict[str, Any]) -> 'WorkflowStep':
|
||||
def from_dict(data: Dict[str, Any]) -> "WorkflowStep":
|
||||
return WorkflowStep(
|
||||
tool_name=data['tool_name'],
|
||||
arguments=data['arguments'],
|
||||
step_id=data['step_id'],
|
||||
condition=data.get('condition'),
|
||||
on_success=data.get('on_success'),
|
||||
on_failure=data.get('on_failure'),
|
||||
retry_count=data.get('retry_count', 0),
|
||||
timeout_seconds=data.get('timeout_seconds', 300)
|
||||
tool_name=data["tool_name"],
|
||||
arguments=data["arguments"],
|
||||
step_id=data["step_id"],
|
||||
condition=data.get("condition"),
|
||||
on_success=data.get("on_success"),
|
||||
on_failure=data.get("on_failure"),
|
||||
retry_count=data.get("retry_count", 0),
|
||||
timeout_seconds=data.get("timeout_seconds", 300),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Workflow:
|
||||
name: str
|
||||
@ -54,23 +57,23 @@ class Workflow:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'name': self.name,
|
||||
'description': self.description,
|
||||
'steps': [step.to_dict() for step in self.steps],
|
||||
'execution_mode': self.execution_mode.value,
|
||||
'variables': self.variables,
|
||||
'tags': self.tags
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"steps": [step.to_dict() for step in self.steps],
|
||||
"execution_mode": self.execution_mode.value,
|
||||
"variables": self.variables,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict[str, Any]) -> 'Workflow':
|
||||
def from_dict(data: Dict[str, Any]) -> "Workflow":
|
||||
return Workflow(
|
||||
name=data['name'],
|
||||
description=data['description'],
|
||||
steps=[WorkflowStep.from_dict(step) for step in data['steps']],
|
||||
execution_mode=ExecutionMode(data.get('execution_mode', 'sequential')),
|
||||
variables=data.get('variables', {}),
|
||||
tags=data.get('tags', [])
|
||||
name=data["name"],
|
||||
description=data["description"],
|
||||
steps=[WorkflowStep.from_dict(step) for step in data["steps"]],
|
||||
execution_mode=ExecutionMode(data.get("execution_mode", "sequential")),
|
||||
variables=data.get("variables", {}),
|
||||
tags=data.get("tags", []),
|
||||
)
|
||||
|
||||
def add_step(self, step: WorkflowStep):
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
import time
|
||||
import re
|
||||
from typing import Dict, Any, List, Callable, Optional
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from .workflow_definition import ExecutionMode, Workflow, WorkflowStep
|
||||
|
||||
|
||||
class WorkflowExecutionContext:
|
||||
def __init__(self):
|
||||
@ -23,57 +25,66 @@ class WorkflowExecutionContext:
|
||||
return self.step_results.get(step_id)
|
||||
|
||||
def log_event(self, event_type: str, step_id: str, details: Dict[str, Any]):
|
||||
self.execution_log.append({
|
||||
'timestamp': time.time(),
|
||||
'event_type': event_type,
|
||||
'step_id': step_id,
|
||||
'details': details
|
||||
})
|
||||
self.execution_log.append(
|
||||
{
|
||||
"timestamp": time.time(),
|
||||
"event_type": event_type,
|
||||
"step_id": step_id,
|
||||
"details": details,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
def __init__(self, tool_executor: Callable, max_workers: int = 5):
|
||||
self.tool_executor = tool_executor
|
||||
self.max_workers = max_workers
|
||||
|
||||
def _evaluate_condition(self, condition: str, context: WorkflowExecutionContext) -> bool:
|
||||
def _evaluate_condition(
|
||||
self, condition: str, context: WorkflowExecutionContext
|
||||
) -> bool:
|
||||
if not condition:
|
||||
return True
|
||||
|
||||
try:
|
||||
safe_locals = {
|
||||
'variables': context.variables,
|
||||
'results': context.step_results
|
||||
"variables": context.variables,
|
||||
"results": context.step_results,
|
||||
}
|
||||
return eval(condition, {"__builtins__": {}}, safe_locals)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _substitute_variables(self, arguments: Dict[str, Any], context: WorkflowExecutionContext) -> Dict[str, Any]:
|
||||
def _substitute_variables(
|
||||
self, arguments: Dict[str, Any], context: WorkflowExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
substituted = {}
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, str):
|
||||
pattern = r'\$\{([^}]+)\}'
|
||||
pattern = r"\$\{([^}]+)\}"
|
||||
matches = re.findall(pattern, value)
|
||||
for match in matches:
|
||||
if match.startswith('step.'):
|
||||
step_id = match.split('.', 1)[1]
|
||||
if match.startswith("step."):
|
||||
step_id = match.split(".", 1)[1]
|
||||
replacement = context.get_step_result(step_id)
|
||||
if replacement is not None:
|
||||
value = value.replace(f'${{{match}}}', str(replacement))
|
||||
elif match.startswith('var.'):
|
||||
var_name = match.split('.', 1)[1]
|
||||
value = value.replace(f"${{{match}}}", str(replacement))
|
||||
elif match.startswith("var."):
|
||||
var_name = match.split(".", 1)[1]
|
||||
replacement = context.get_variable(var_name)
|
||||
if replacement is not None:
|
||||
value = value.replace(f'${{{match}}}', str(replacement))
|
||||
value = value.replace(f"${{{match}}}", str(replacement))
|
||||
substituted[key] = value
|
||||
else:
|
||||
substituted[key] = value
|
||||
return substituted
|
||||
|
||||
def _execute_step(self, step: WorkflowStep, context: WorkflowExecutionContext) -> Dict[str, Any]:
|
||||
def _execute_step(
|
||||
self, step: WorkflowStep, context: WorkflowExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
if not self._evaluate_condition(step.condition, context):
|
||||
context.log_event('skipped', step.step_id, {'reason': 'condition_not_met'})
|
||||
return {'status': 'skipped', 'step_id': step.step_id}
|
||||
context.log_event("skipped", step.step_id, {"reason": "condition_not_met"})
|
||||
return {"status": "skipped", "step_id": step.step_id}
|
||||
|
||||
arguments = self._substitute_variables(step.arguments, context)
|
||||
|
||||
@ -83,26 +94,34 @@ class WorkflowEngine:
|
||||
|
||||
while retry_attempts <= step.retry_count:
|
||||
try:
|
||||
context.log_event('executing', step.step_id, {
|
||||
'tool': step.tool_name,
|
||||
'arguments': arguments,
|
||||
'attempt': retry_attempts + 1
|
||||
})
|
||||
context.log_event(
|
||||
"executing",
|
||||
step.step_id,
|
||||
{
|
||||
"tool": step.tool_name,
|
||||
"arguments": arguments,
|
||||
"attempt": retry_attempts + 1,
|
||||
},
|
||||
)
|
||||
|
||||
result = self.tool_executor(step.tool_name, arguments)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
context.set_step_result(step.step_id, result)
|
||||
context.log_event('completed', step.step_id, {
|
||||
'execution_time': execution_time,
|
||||
'result_size': len(str(result)) if result else 0
|
||||
})
|
||||
context.log_event(
|
||||
"completed",
|
||||
step.step_id,
|
||||
{
|
||||
"execution_time": execution_time,
|
||||
"result_size": len(str(result)) if result else 0,
|
||||
},
|
||||
)
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'step_id': step.step_id,
|
||||
'result': result,
|
||||
'execution_time': execution_time
|
||||
"status": "success",
|
||||
"step_id": step.step_id,
|
||||
"result": result,
|
||||
"execution_time": execution_time,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -111,25 +130,26 @@ class WorkflowEngine:
|
||||
if retry_attempts <= step.retry_count:
|
||||
time.sleep(1 * retry_attempts)
|
||||
|
||||
context.log_event('failed', step.step_id, {'error': last_error})
|
||||
context.log_event("failed", step.step_id, {"error": last_error})
|
||||
return {
|
||||
'status': 'failed',
|
||||
'step_id': step.step_id,
|
||||
'error': last_error,
|
||||
'execution_time': time.time() - start_time
|
||||
"status": "failed",
|
||||
"step_id": step.step_id,
|
||||
"error": last_error,
|
||||
"execution_time": time.time() - start_time,
|
||||
}
|
||||
|
||||
def _get_next_steps(self, completed_step: WorkflowStep, result: Dict[str, Any],
|
||||
workflow: Workflow) -> List[WorkflowStep]:
|
||||
def _get_next_steps(
|
||||
self, completed_step: WorkflowStep, result: Dict[str, Any], workflow: Workflow
|
||||
) -> List[WorkflowStep]:
|
||||
next_steps = []
|
||||
|
||||
if result['status'] == 'success' and completed_step.on_success:
|
||||
if result["status"] == "success" and completed_step.on_success:
|
||||
for step_id in completed_step.on_success:
|
||||
step = workflow.get_step(step_id)
|
||||
if step:
|
||||
next_steps.append(step)
|
||||
|
||||
elif result['status'] == 'failed' and completed_step.on_failure:
|
||||
elif result["status"] == "failed" and completed_step.on_failure:
|
||||
for step_id in completed_step.on_failure:
|
||||
step = workflow.get_step(step_id)
|
||||
if step:
|
||||
@ -142,7 +162,9 @@ class WorkflowEngine:
|
||||
|
||||
return next_steps
|
||||
|
||||
def execute_workflow(self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None) -> WorkflowExecutionContext:
|
||||
def execute_workflow(
|
||||
self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None
|
||||
) -> WorkflowExecutionContext:
|
||||
context = WorkflowExecutionContext()
|
||||
|
||||
if initial_variables:
|
||||
@ -151,7 +173,7 @@ class WorkflowEngine:
|
||||
if workflow.variables:
|
||||
context.variables.update(workflow.variables)
|
||||
|
||||
context.log_event('workflow_started', 'workflow', {'name': workflow.name})
|
||||
context.log_event("workflow_started", "workflow", {"name": workflow.name})
|
||||
|
||||
if workflow.execution_mode == ExecutionMode.PARALLEL:
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
@ -164,9 +186,11 @@ class WorkflowEngine:
|
||||
step = futures[future]
|
||||
try:
|
||||
result = future.result()
|
||||
context.log_event('step_completed', step.step_id, result)
|
||||
context.log_event("step_completed", step.step_id, result)
|
||||
except Exception as e:
|
||||
context.log_event('step_failed', step.step_id, {'error': str(e)})
|
||||
context.log_event(
|
||||
"step_failed", step.step_id, {"error": str(e)}
|
||||
)
|
||||
|
||||
else:
|
||||
pending_steps = workflow.get_initial_steps()
|
||||
@ -184,9 +208,13 @@ class WorkflowEngine:
|
||||
next_steps = self._get_next_steps(step, result, workflow)
|
||||
pending_steps.extend(next_steps)
|
||||
|
||||
context.log_event('workflow_completed', 'workflow', {
|
||||
'total_steps': len(context.step_results),
|
||||
'executed_steps': list(context.step_results.keys())
|
||||
})
|
||||
context.log_event(
|
||||
"workflow_completed",
|
||||
"workflow",
|
||||
{
|
||||
"total_steps": len(context.step_results),
|
||||
"executed_steps": list(context.step_results.keys()),
|
||||
},
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
@ -2,8 +2,10 @@ import json
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from .workflow_definition import Workflow
|
||||
|
||||
|
||||
class WorkflowStorage:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
@ -13,7 +15,8 @@ class WorkflowStorage:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
workflow_id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
@ -25,9 +28,11 @@ class WorkflowStorage:
|
||||
last_execution_at INTEGER,
|
||||
tags TEXT
|
||||
)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS workflow_executions (
|
||||
execution_id TEXT PRIMARY KEY,
|
||||
workflow_id TEXT NOT NULL,
|
||||
@ -39,17 +44,24 @@ class WorkflowStorage:
|
||||
step_results TEXT,
|
||||
FOREIGN KEY (workflow_id) REFERENCES workflows(workflow_id)
|
||||
)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_name ON workflows(name)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_execution_workflow ON workflow_executions(workflow_id)
|
||||
''')
|
||||
cursor.execute('''
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_execution_started ON workflow_executions(started_at)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@ -66,12 +78,22 @@ class WorkflowStorage:
|
||||
current_time = int(time.time())
|
||||
tags_json = json.dumps(workflow.tags)
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO workflows
|
||||
(workflow_id, name, description, workflow_data, created_at, updated_at, tags)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
''', (workflow_id, workflow.name, workflow.description, workflow_data,
|
||||
current_time, current_time, tags_json))
|
||||
""",
|
||||
(
|
||||
workflow_id,
|
||||
workflow.name,
|
||||
workflow.description,
|
||||
workflow_data,
|
||||
current_time,
|
||||
current_time,
|
||||
tags_json,
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@ -82,7 +104,9 @@ class WorkflowStorage:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('SELECT workflow_data FROM workflows WHERE workflow_id = ?', (workflow_id,))
|
||||
cursor.execute(
|
||||
"SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@ -95,7 +119,7 @@ class WorkflowStorage:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('SELECT workflow_data FROM workflows WHERE name = ?', (name,))
|
||||
cursor.execute("SELECT workflow_data FROM workflows WHERE name = ?", (name,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@ -109,29 +133,36 @@ class WorkflowStorage:
|
||||
cursor = conn.cursor()
|
||||
|
||||
if tag:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT workflow_id, name, description, execution_count, last_execution_at, tags
|
||||
FROM workflows
|
||||
WHERE tags LIKE ?
|
||||
ORDER BY name
|
||||
''', (f'%"{tag}"%',))
|
||||
""",
|
||||
(f'%"{tag}"%',),
|
||||
)
|
||||
else:
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT workflow_id, name, description, execution_count, last_execution_at, tags
|
||||
FROM workflows
|
||||
ORDER BY name
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
workflows = []
|
||||
for row in cursor.fetchall():
|
||||
workflows.append({
|
||||
'workflow_id': row[0],
|
||||
'name': row[1],
|
||||
'description': row[2],
|
||||
'execution_count': row[3],
|
||||
'last_execution_at': row[4],
|
||||
'tags': json.loads(row[5]) if row[5] else []
|
||||
})
|
||||
workflows.append(
|
||||
{
|
||||
"workflow_id": row[0],
|
||||
"name": row[1],
|
||||
"description": row[2],
|
||||
"execution_count": row[3],
|
||||
"last_execution_at": row[4],
|
||||
"tags": json.loads(row[5]) if row[5] else [],
|
||||
}
|
||||
)
|
||||
|
||||
conn.close()
|
||||
return workflows
|
||||
@ -140,18 +171,21 @@ class WorkflowStorage:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM workflows WHERE workflow_id = ?', (workflow_id,))
|
||||
cursor.execute("DELETE FROM workflows WHERE workflow_id = ?", (workflow_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
|
||||
cursor.execute('DELETE FROM workflow_executions WHERE workflow_id = ?', (workflow_id,))
|
||||
cursor.execute(
|
||||
"DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,)
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
return deleted
|
||||
|
||||
def save_execution(self, workflow_id: str, execution_context: 'WorkflowExecutionContext') -> str:
|
||||
import hashlib
|
||||
def save_execution(
|
||||
self, workflow_id: str, execution_context: "WorkflowExecutionContext"
|
||||
) -> str:
|
||||
import uuid
|
||||
|
||||
execution_id = str(uuid.uuid4())[:16]
|
||||
@ -159,30 +193,40 @@ class WorkflowStorage:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
started_at = int(execution_context.execution_log[0]['timestamp']) if execution_context.execution_log else int(time.time())
|
||||
started_at = (
|
||||
int(execution_context.execution_log[0]["timestamp"])
|
||||
if execution_context.execution_log
|
||||
else int(time.time())
|
||||
)
|
||||
completed_at = int(time.time())
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO workflow_executions
|
||||
(execution_id, workflow_id, started_at, completed_at, status, execution_log, variables, step_results)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
execution_id,
|
||||
workflow_id,
|
||||
started_at,
|
||||
completed_at,
|
||||
'completed',
|
||||
json.dumps(execution_context.execution_log),
|
||||
json.dumps(execution_context.variables),
|
||||
json.dumps(execution_context.step_results)
|
||||
))
|
||||
""",
|
||||
(
|
||||
execution_id,
|
||||
workflow_id,
|
||||
started_at,
|
||||
completed_at,
|
||||
"completed",
|
||||
json.dumps(execution_context.execution_log),
|
||||
json.dumps(execution_context.variables),
|
||||
json.dumps(execution_context.step_results),
|
||||
),
|
||||
)
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE workflows
|
||||
SET execution_count = execution_count + 1,
|
||||
last_execution_at = ?
|
||||
WHERE workflow_id = ?
|
||||
''', (completed_at, workflow_id))
|
||||
""",
|
||||
(completed_at, workflow_id),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@ -193,22 +237,27 @@ class WorkflowStorage:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT execution_id, started_at, completed_at, status
|
||||
FROM workflow_executions
|
||||
WHERE workflow_id = ?
|
||||
ORDER BY started_at DESC
|
||||
LIMIT ?
|
||||
''', (workflow_id, limit))
|
||||
""",
|
||||
(workflow_id, limit),
|
||||
)
|
||||
|
||||
executions = []
|
||||
for row in cursor.fetchall():
|
||||
executions.append({
|
||||
'execution_id': row[0],
|
||||
'started_at': row[1],
|
||||
'completed_at': row[2],
|
||||
'status': row[3]
|
||||
})
|
||||
executions.append(
|
||||
{
|
||||
"execution_id": row[0],
|
||||
"started_at": row[1],
|
||||
"completed_at": row[2],
|
||||
"status": row[3],
|
||||
}
|
||||
)
|
||||
|
||||
conn.close()
|
||||
return executions
|
||||
|
||||
3
rp.py
3
rp.py
@ -2,8 +2,7 @@
|
||||
|
||||
# Trigger build
|
||||
|
||||
import sys
|
||||
from pr.__main__ import main
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
@ -13,19 +14,8 @@ def temp_dir():
|
||||
@pytest.fixture
|
||||
def mock_api_response():
|
||||
return {
|
||||
'choices': [
|
||||
{
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': 'Test response'
|
||||
}
|
||||
}
|
||||
],
|
||||
'usage': {
|
||||
'prompt_tokens': 10,
|
||||
'completion_tokens': 5,
|
||||
'total_tokens': 15
|
||||
}
|
||||
"choices": [{"message": {"role": "assistant", "content": "Test response"}}],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
|
||||
|
||||
@ -47,7 +37,7 @@ def mock_args():
|
||||
|
||||
@pytest.fixture
|
||||
def sample_context_file(temp_dir):
|
||||
context_path = os.path.join(temp_dir, '.rcontext.txt')
|
||||
with open(context_path, 'w') as f:
|
||||
f.write('Sample context content\n')
|
||||
context_path = os.path.join(temp_dir, ".rcontext.txt")
|
||||
with open(context_path, "w") as f:
|
||||
f.write("Sample context content\n")
|
||||
return context_path
|
||||
|
||||
@ -1,39 +1,53 @@
|
||||
import pytest
|
||||
from pr.core.advanced_context import AdvancedContextManager
|
||||
|
||||
|
||||
def test_adaptive_context_window_simple():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
|
||||
window = mgr.adaptive_context_window(messages, 'simple')
|
||||
messages = [
|
||||
{"content": "short"},
|
||||
{"content": "this is a longer message with more words"},
|
||||
]
|
||||
window = mgr.adaptive_context_window(messages, "simple")
|
||||
assert isinstance(window, int)
|
||||
assert window >= 10
|
||||
|
||||
|
||||
def test_adaptive_context_window_medium():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
|
||||
window = mgr.adaptive_context_window(messages, 'medium')
|
||||
messages = [
|
||||
{"content": "short"},
|
||||
{"content": "this is a longer message with more words"},
|
||||
]
|
||||
window = mgr.adaptive_context_window(messages, "medium")
|
||||
assert isinstance(window, int)
|
||||
assert window >= 20
|
||||
|
||||
|
||||
def test_adaptive_context_window_complex():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
|
||||
window = mgr.adaptive_context_window(messages, 'complex')
|
||||
messages = [
|
||||
{"content": "short"},
|
||||
{"content": "this is a longer message with more words"},
|
||||
]
|
||||
window = mgr.adaptive_context_window(messages, "complex")
|
||||
assert isinstance(window, int)
|
||||
assert window >= 35
|
||||
|
||||
|
||||
def test_analyze_message_complexity():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'hello world'}, {'content': 'hello again'}]
|
||||
messages = [{"content": "hello world"}, {"content": "hello again"}]
|
||||
score = mgr._analyze_message_complexity(messages)
|
||||
assert 0 <= score <= 1
|
||||
|
||||
|
||||
def test_analyze_message_complexity_empty():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = []
|
||||
score = mgr._analyze_message_complexity(messages)
|
||||
assert score == 0
|
||||
|
||||
|
||||
def test_extract_key_sentences():
|
||||
mgr = AdvancedContextManager()
|
||||
text = "This is the first sentence. This is the second sentence. This is a longer third sentence with more words."
|
||||
@ -41,41 +55,47 @@ def test_extract_key_sentences():
|
||||
assert len(sentences) <= 2
|
||||
assert all(isinstance(s, str) for s in sentences)
|
||||
|
||||
|
||||
def test_extract_key_sentences_empty():
|
||||
mgr = AdvancedContextManager()
|
||||
text = ""
|
||||
sentences = mgr.extract_key_sentences(text, 5)
|
||||
assert sentences == []
|
||||
|
||||
|
||||
def test_advanced_summarize_messages():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'Hello'}, {'content': 'How are you?'}]
|
||||
messages = [{"content": "Hello"}, {"content": "How are you?"}]
|
||||
summary = mgr.advanced_summarize_messages(messages)
|
||||
assert isinstance(summary, str)
|
||||
|
||||
|
||||
def test_advanced_summarize_messages_empty():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = []
|
||||
summary = mgr.advanced_summarize_messages(messages)
|
||||
assert summary == "No content to summarize."
|
||||
|
||||
|
||||
def test_score_message_relevance():
|
||||
mgr = AdvancedContextManager()
|
||||
message = {'content': 'hello world'}
|
||||
context = 'world hello'
|
||||
message = {"content": "hello world"}
|
||||
context = "world hello"
|
||||
score = mgr.score_message_relevance(message, context)
|
||||
assert 0 <= score <= 1
|
||||
|
||||
|
||||
def test_score_message_relevance_no_overlap():
|
||||
mgr = AdvancedContextManager()
|
||||
message = {'content': 'hello'}
|
||||
context = 'world'
|
||||
message = {"content": "hello"}
|
||||
context = "world"
|
||||
score = mgr.score_message_relevance(message, context)
|
||||
assert score == 0
|
||||
|
||||
|
||||
def test_score_message_relevance_empty():
|
||||
mgr = AdvancedContextManager()
|
||||
message = {'content': ''}
|
||||
context = ''
|
||||
message = {"content": ""}
|
||||
context = ""
|
||||
score = mgr.score_message_relevance(message, context)
|
||||
assert score == 0
|
||||
assert score == 0
|
||||
|
||||
@ -1,127 +1,213 @@
|
||||
import pytest
|
||||
import time
|
||||
from pr.agents.agent_communication import (
|
||||
AgentCommunicationBus,
|
||||
AgentMessage,
|
||||
MessageType,
|
||||
)
|
||||
from pr.agents.agent_manager import AgentInstance, AgentManager
|
||||
from pr.agents.agent_roles import AgentRole, get_agent_role, list_agent_roles
|
||||
from pr.agents.agent_manager import AgentManager, AgentInstance
|
||||
from pr.agents.agent_communication import AgentCommunicationBus, AgentMessage, MessageType
|
||||
|
||||
|
||||
def test_get_agent_role():
|
||||
role = get_agent_role('coding')
|
||||
role = get_agent_role("coding")
|
||||
assert isinstance(role, AgentRole)
|
||||
assert role.name == 'coding'
|
||||
assert role.name == "coding"
|
||||
|
||||
|
||||
def test_list_agent_roles():
|
||||
roles = list_agent_roles()
|
||||
assert isinstance(roles, dict)
|
||||
assert len(roles) > 0
|
||||
assert 'coding' in roles
|
||||
assert "coding" in roles
|
||||
|
||||
|
||||
def test_agent_role():
|
||||
role = AgentRole(name='test', description='test', system_prompt='test', allowed_tools=set(), specialization_areas=[])
|
||||
assert role.name == 'test'
|
||||
role = AgentRole(
|
||||
name="test",
|
||||
description="test",
|
||||
system_prompt="test",
|
||||
allowed_tools=set(),
|
||||
specialization_areas=[],
|
||||
)
|
||||
assert role.name == "test"
|
||||
|
||||
|
||||
def test_agent_instance():
|
||||
role = get_agent_role('coding')
|
||||
instance = AgentInstance(agent_id='test', role=role)
|
||||
assert instance.agent_id == 'test'
|
||||
role = get_agent_role("coding")
|
||||
instance = AgentInstance(agent_id="test", role=role)
|
||||
assert instance.agent_id == "test"
|
||||
assert instance.role == role
|
||||
|
||||
|
||||
def test_agent_manager_init():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
mgr = AgentManager(":memory:", None)
|
||||
assert mgr is not None
|
||||
|
||||
|
||||
def test_agent_manager_create_agent():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
agent = mgr.create_agent('coding', 'test_agent')
|
||||
mgr = AgentManager(":memory:", None)
|
||||
agent = mgr.create_agent("coding", "test_agent")
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_agent_manager_get_agent():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
mgr.create_agent('coding', 'test_agent')
|
||||
agent = mgr.get_agent('test_agent')
|
||||
mgr = AgentManager(":memory:", None)
|
||||
mgr.create_agent("coding", "test_agent")
|
||||
agent = mgr.get_agent("test_agent")
|
||||
assert isinstance(agent, AgentInstance)
|
||||
|
||||
|
||||
def test_agent_manager_remove_agent():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
mgr.create_agent('coding', 'test_agent')
|
||||
mgr.remove_agent('test_agent')
|
||||
agent = mgr.get_agent('test_agent')
|
||||
mgr = AgentManager(":memory:", None)
|
||||
mgr.create_agent("coding", "test_agent")
|
||||
mgr.remove_agent("test_agent")
|
||||
agent = mgr.get_agent("test_agent")
|
||||
assert agent is None
|
||||
|
||||
|
||||
def test_agent_manager_send_agent_message():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
mgr.create_agent('coding', 'a')
|
||||
mgr.create_agent('coding', 'b')
|
||||
mgr.send_agent_message('a', 'b', 'test')
|
||||
mgr = AgentManager(":memory:", None)
|
||||
mgr.create_agent("coding", "a")
|
||||
mgr.create_agent("coding", "b")
|
||||
mgr.send_agent_message("a", "b", "test")
|
||||
assert True
|
||||
|
||||
|
||||
def test_agent_manager_get_agent_messages():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
mgr.create_agent('coding', 'test')
|
||||
messages = mgr.get_agent_messages('test')
|
||||
mgr = AgentManager(":memory:", None)
|
||||
mgr.create_agent("coding", "test")
|
||||
messages = mgr.get_agent_messages("test")
|
||||
assert isinstance(messages, list)
|
||||
|
||||
|
||||
def test_agent_manager_get_session_summary():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
mgr = AgentManager(":memory:", None)
|
||||
summary = mgr.get_session_summary()
|
||||
assert isinstance(summary, str)
|
||||
|
||||
|
||||
def test_agent_manager_collaborate_agents():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
result = mgr.collaborate_agents('orchestrator', 'task', ['coding', 'research'])
|
||||
mgr = AgentManager(":memory:", None)
|
||||
result = mgr.collaborate_agents("orchestrator", "task", ["coding", "research"])
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_agent_manager_execute_agent_task():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
mgr.create_agent('coding', 'test')
|
||||
result = mgr.execute_agent_task('test', 'task')
|
||||
mgr = AgentManager(":memory:", None)
|
||||
mgr.create_agent("coding", "test")
|
||||
result = mgr.execute_agent_task("test", "task")
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_agent_manager_clear_session():
|
||||
mgr = AgentManager(':memory:', None)
|
||||
mgr = AgentManager(":memory:", None)
|
||||
mgr.clear_session()
|
||||
assert True
|
||||
|
||||
|
||||
def test_agent_message():
|
||||
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
|
||||
assert msg.from_agent == 'a'
|
||||
msg = AgentMessage(
|
||||
from_agent="a",
|
||||
to_agent="b",
|
||||
message_type=MessageType.REQUEST,
|
||||
content="test",
|
||||
metadata={},
|
||||
timestamp=1.0,
|
||||
message_id="id",
|
||||
)
|
||||
assert msg.from_agent == "a"
|
||||
|
||||
|
||||
def test_agent_message_to_dict():
|
||||
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
|
||||
msg = AgentMessage(
|
||||
from_agent="a",
|
||||
to_agent="b",
|
||||
message_type=MessageType.REQUEST,
|
||||
content="test",
|
||||
metadata={},
|
||||
timestamp=1.0,
|
||||
message_id="id",
|
||||
)
|
||||
d = msg.to_dict()
|
||||
assert isinstance(d, dict)
|
||||
|
||||
|
||||
def test_agent_message_from_dict():
|
||||
d = {'from_agent': 'a', 'to_agent': 'b', 'message_type': 'request', 'content': 'test', 'metadata': {}, 'timestamp': 1.0, 'message_id': 'id'}
|
||||
d = {
|
||||
"from_agent": "a",
|
||||
"to_agent": "b",
|
||||
"message_type": "request",
|
||||
"content": "test",
|
||||
"metadata": {},
|
||||
"timestamp": 1.0,
|
||||
"message_id": "id",
|
||||
}
|
||||
msg = AgentMessage.from_dict(d)
|
||||
assert isinstance(msg, AgentMessage)
|
||||
|
||||
|
||||
def test_agent_communication_bus_init():
|
||||
bus = AgentCommunicationBus(':memory:')
|
||||
bus = AgentCommunicationBus(":memory:")
|
||||
assert bus is not None
|
||||
|
||||
|
||||
def test_agent_communication_bus_send_message():
|
||||
bus = AgentCommunicationBus(':memory:')
|
||||
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
|
||||
bus = AgentCommunicationBus(":memory:")
|
||||
msg = AgentMessage(
|
||||
from_agent="a",
|
||||
to_agent="b",
|
||||
message_type=MessageType.REQUEST,
|
||||
content="test",
|
||||
metadata={},
|
||||
timestamp=1.0,
|
||||
message_id="id",
|
||||
)
|
||||
bus.send_message(msg)
|
||||
assert True
|
||||
|
||||
|
||||
def test_agent_communication_bus_receive_messages():
|
||||
bus = AgentCommunicationBus(':memory:')
|
||||
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
|
||||
bus = AgentCommunicationBus(":memory:")
|
||||
msg = AgentMessage(
|
||||
from_agent="a",
|
||||
to_agent="b",
|
||||
message_type=MessageType.REQUEST,
|
||||
content="test",
|
||||
metadata={},
|
||||
timestamp=1.0,
|
||||
message_id="id",
|
||||
)
|
||||
bus.send_message(msg)
|
||||
messages = bus.receive_messages('b')
|
||||
messages = bus.receive_messages("b")
|
||||
assert len(messages) == 1
|
||||
|
||||
|
||||
def test_agent_communication_bus_get_conversation_history():
|
||||
bus = AgentCommunicationBus(':memory:')
|
||||
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
|
||||
bus = AgentCommunicationBus(":memory:")
|
||||
msg = AgentMessage(
|
||||
from_agent="a",
|
||||
to_agent="b",
|
||||
message_type=MessageType.REQUEST,
|
||||
content="test",
|
||||
metadata={},
|
||||
timestamp=1.0,
|
||||
message_id="id",
|
||||
)
|
||||
bus.send_message(msg)
|
||||
history = bus.get_conversation_history('a', 'b')
|
||||
history = bus.get_conversation_history("a", "b")
|
||||
assert len(history) == 1
|
||||
|
||||
|
||||
def test_agent_communication_bus_mark_as_read():
|
||||
bus = AgentCommunicationBus(':memory:')
|
||||
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
|
||||
bus = AgentCommunicationBus(":memory:")
|
||||
msg = AgentMessage(
|
||||
from_agent="a",
|
||||
to_agent="b",
|
||||
message_type=MessageType.REQUEST,
|
||||
content="test",
|
||||
metadata={},
|
||||
timestamp=1.0,
|
||||
message_id="id",
|
||||
)
|
||||
bus.send_message(msg)
|
||||
bus.mark_as_read(msg.message_id)
|
||||
assert True
|
||||
assert True
|
||||
|
||||
@ -1,63 +1,65 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
import urllib.error
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pr.core.api import call_api, list_models
|
||||
|
||||
|
||||
class TestApi(unittest.TestCase):
|
||||
|
||||
@patch('pr.core.api.urllib.request.urlopen')
|
||||
@patch('pr.core.api.auto_slim_messages')
|
||||
@patch("pr.core.api.urllib.request.urlopen")
|
||||
@patch("pr.core.api.auto_slim_messages")
|
||||
def test_call_api_success(self, mock_slim, mock_urlopen):
|
||||
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
|
||||
mock_slim.return_value = [{"role": "user", "content": "test"}]
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}'
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
|
||||
result = call_api([], 'model', 'http://url', 'key', True, [{'name': 'tool'}])
|
||||
result = call_api([], "model", "http://url", "key", True, [{"name": "tool"}])
|
||||
|
||||
self.assertIn('choices', result)
|
||||
self.assertIn("choices", result)
|
||||
mock_urlopen.assert_called_once()
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
@patch('pr.core.api.auto_slim_messages')
|
||||
@patch("urllib.request.urlopen")
|
||||
@patch("pr.core.api.auto_slim_messages")
|
||||
def test_call_api_http_error(self, mock_slim, mock_urlopen):
|
||||
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
|
||||
mock_urlopen.side_effect = urllib.error.HTTPError('http://url', 500, 'error', None, MagicMock())
|
||||
mock_slim.return_value = [{"role": "user", "content": "test"}]
|
||||
mock_urlopen.side_effect = urllib.error.HTTPError(
|
||||
"http://url", 500, "error", None, MagicMock()
|
||||
)
|
||||
|
||||
result = call_api([], 'model', 'http://url', 'key', False, [])
|
||||
result = call_api([], "model", "http://url", "key", False, [])
|
||||
|
||||
self.assertIn('error', result)
|
||||
self.assertIn("error", result)
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
@patch('pr.core.api.auto_slim_messages')
|
||||
@patch("urllib.request.urlopen")
|
||||
@patch("pr.core.api.auto_slim_messages")
|
||||
def test_call_api_general_error(self, mock_slim, mock_urlopen):
|
||||
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
|
||||
mock_urlopen.side_effect = Exception('test error')
|
||||
mock_slim.return_value = [{"role": "user", "content": "test"}]
|
||||
mock_urlopen.side_effect = Exception("test error")
|
||||
|
||||
result = call_api([], 'model', 'http://url', 'key', False, [])
|
||||
result = call_api([], "model", "http://url", "key", False, [])
|
||||
|
||||
self.assertIn('error', result)
|
||||
self.assertIn("error", result)
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_list_models_success(self, mock_urlopen):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b'{"data": [{"id": "model1"}]}'
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
|
||||
result = list_models('http://url', 'key')
|
||||
result = list_models("http://url", "key")
|
||||
|
||||
self.assertEqual(result, [{'id': 'model1'}])
|
||||
self.assertEqual(result, [{"id": "model1"}])
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
@patch("urllib.request.urlopen")
|
||||
def test_list_models_error(self, mock_urlopen):
|
||||
mock_urlopen.side_effect = Exception('error')
|
||||
mock_urlopen.side_effect = Exception("error")
|
||||
|
||||
result = list_models('http://url', 'key')
|
||||
result = list_models("http://url", "key")
|
||||
|
||||
self.assertIn('error', result)
|
||||
self.assertIn("error", result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import tempfile
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pr.core.assistant import Assistant, process_message
|
||||
|
||||
|
||||
@ -12,83 +11,106 @@ class TestAssistant(unittest.TestCase):
|
||||
self.args.verbose = False
|
||||
self.args.debug = False
|
||||
self.args.no_syntax = False
|
||||
self.args.model = 'test-model'
|
||||
self.args.api_url = 'test-url'
|
||||
self.args.model_list_url = 'test-list-url'
|
||||
self.args.model = "test-model"
|
||||
self.args.api_url = "test-url"
|
||||
self.args.model_list_url = "test-list-url"
|
||||
|
||||
@patch('sqlite3.connect')
|
||||
@patch('os.environ.get')
|
||||
@patch('pr.core.context.init_system_message')
|
||||
@patch('pr.core.enhanced_assistant.EnhancedAssistant')
|
||||
@patch("sqlite3.connect")
|
||||
@patch("os.environ.get")
|
||||
@patch("pr.core.context.init_system_message")
|
||||
@patch("pr.core.enhanced_assistant.EnhancedAssistant")
|
||||
def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite):
|
||||
mock_env.side_effect = lambda key, default: {'OPENROUTER_API_KEY': 'key', 'AI_MODEL': 'model', 'API_URL': 'url', 'MODEL_LIST_URL': 'list', 'USE_TOOLS': '1', 'STRICT_MODE': '0'}.get(key, default)
|
||||
mock_env.side_effect = lambda key, default: {
|
||||
"OPENROUTER_API_KEY": "key",
|
||||
"AI_MODEL": "model",
|
||||
"API_URL": "url",
|
||||
"MODEL_LIST_URL": "list",
|
||||
"USE_TOOLS": "1",
|
||||
"STRICT_MODE": "0",
|
||||
}.get(key, default)
|
||||
mock_conn = MagicMock()
|
||||
mock_sqlite.return_value = mock_conn
|
||||
mock_init_sys.return_value = {'role': 'system', 'content': 'sys'}
|
||||
mock_init_sys.return_value = {"role": "system", "content": "sys"}
|
||||
|
||||
assistant = Assistant(self.args)
|
||||
|
||||
self.assertEqual(assistant.api_key, 'key')
|
||||
self.assertEqual(assistant.model, 'test-model')
|
||||
self.assertEqual(assistant.api_key, "key")
|
||||
self.assertEqual(assistant.model, "test-model")
|
||||
mock_sqlite.assert_called_once()
|
||||
|
||||
@patch('pr.core.assistant.call_api')
|
||||
@patch('pr.core.assistant.render_markdown')
|
||||
@patch("pr.core.assistant.call_api")
|
||||
@patch("pr.core.assistant.render_markdown")
|
||||
def test_process_response_no_tools(self, mock_render, mock_call):
|
||||
assistant = MagicMock()
|
||||
assistant.messages = MagicMock()
|
||||
assistant.verbose = False
|
||||
assistant.syntax_highlighting = True
|
||||
mock_render.return_value = 'rendered'
|
||||
mock_render.return_value = "rendered"
|
||||
|
||||
response = {'choices': [{'message': {'content': 'content'}}]}
|
||||
response = {"choices": [{"message": {"content": "content"}}]}
|
||||
|
||||
result = Assistant.process_response(assistant, response)
|
||||
|
||||
self.assertEqual(result, 'rendered')
|
||||
assistant.messages.append.assert_called_with({'content': 'content'})
|
||||
self.assertEqual(result, "rendered")
|
||||
assistant.messages.append.assert_called_with({"content": "content"})
|
||||
|
||||
@patch('pr.core.assistant.call_api')
|
||||
@patch('pr.core.assistant.render_markdown')
|
||||
@patch('pr.core.assistant.get_tools_definition')
|
||||
@patch("pr.core.assistant.call_api")
|
||||
@patch("pr.core.assistant.render_markdown")
|
||||
@patch("pr.core.assistant.get_tools_definition")
|
||||
def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call):
|
||||
assistant = MagicMock()
|
||||
assistant.messages = MagicMock()
|
||||
assistant.verbose = False
|
||||
assistant.syntax_highlighting = True
|
||||
assistant.use_tools = True
|
||||
assistant.model = 'model'
|
||||
assistant.api_url = 'url'
|
||||
assistant.api_key = 'key'
|
||||
assistant.model = "model"
|
||||
assistant.api_url = "url"
|
||||
assistant.api_key = "key"
|
||||
mock_tools_def.return_value = []
|
||||
mock_call.return_value = {'choices': [{'message': {'content': 'follow'}}]}
|
||||
mock_call.return_value = {"choices": [{"message": {"content": "follow"}}]}
|
||||
|
||||
response = {'choices': [{'message': {'tool_calls': [{'id': '1', 'function': {'name': 'test', 'arguments': '{}'}}]}}]}
|
||||
response = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"tool_calls": [
|
||||
{"id": "1", "function": {"name": "test", "arguments": "{}"}}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(assistant, 'execute_tool_calls', return_value=[{'role': 'tool', 'content': 'result'}]):
|
||||
result = Assistant.process_response(assistant, response)
|
||||
with patch.object(
|
||||
assistant,
|
||||
"execute_tool_calls",
|
||||
return_value=[{"role": "tool", "content": "result"}],
|
||||
):
|
||||
Assistant.process_response(assistant, response)
|
||||
|
||||
mock_call.assert_called()
|
||||
|
||||
@patch('pr.core.assistant.call_api')
|
||||
@patch('pr.core.assistant.get_tools_definition')
|
||||
@patch("pr.core.assistant.call_api")
|
||||
@patch("pr.core.assistant.get_tools_definition")
|
||||
def test_process_message(self, mock_tools, mock_call):
|
||||
assistant = MagicMock()
|
||||
assistant.messages = MagicMock()
|
||||
assistant.verbose = False
|
||||
assistant.use_tools = True
|
||||
assistant.model = 'model'
|
||||
assistant.api_url = 'url'
|
||||
assistant.api_key = 'key'
|
||||
assistant.model = "model"
|
||||
assistant.api_url = "url"
|
||||
assistant.api_key = "key"
|
||||
mock_tools.return_value = []
|
||||
mock_call.return_value = {'choices': [{'message': {'content': 'response'}}]}
|
||||
mock_call.return_value = {"choices": [{"message": {"content": "response"}}]}
|
||||
|
||||
with patch('pr.core.assistant.render_markdown', return_value='rendered'):
|
||||
with patch('builtins.print'):
|
||||
process_message(assistant, 'test message')
|
||||
with patch("pr.core.assistant.render_markdown", return_value="rendered"):
|
||||
with patch("builtins.print"):
|
||||
process_message(assistant, "test message")
|
||||
|
||||
assistant.messages.append.assert_called_with({'role': 'user', 'content': 'test message'})
|
||||
assistant.messages.append.assert_called_with(
|
||||
{"role": "user", "content": "test message"}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -1,31 +1,30 @@
|
||||
import pytest
|
||||
from pr import config
|
||||
|
||||
|
||||
class TestConfig:
|
||||
|
||||
def test_default_model_exists(self):
|
||||
assert hasattr(config, 'DEFAULT_MODEL')
|
||||
assert hasattr(config, "DEFAULT_MODEL")
|
||||
assert isinstance(config.DEFAULT_MODEL, str)
|
||||
assert len(config.DEFAULT_MODEL) > 0
|
||||
|
||||
def test_api_url_exists(self):
|
||||
assert hasattr(config, 'DEFAULT_API_URL')
|
||||
assert config.DEFAULT_API_URL.startswith('http')
|
||||
assert hasattr(config, "DEFAULT_API_URL")
|
||||
assert config.DEFAULT_API_URL.startswith("http")
|
||||
|
||||
def test_file_paths_exist(self):
|
||||
assert hasattr(config, 'DB_PATH')
|
||||
assert hasattr(config, 'LOG_FILE')
|
||||
assert hasattr(config, 'HISTORY_FILE')
|
||||
assert hasattr(config, "DB_PATH")
|
||||
assert hasattr(config, "LOG_FILE")
|
||||
assert hasattr(config, "HISTORY_FILE")
|
||||
|
||||
def test_autonomous_config(self):
|
||||
assert hasattr(config, 'MAX_AUTONOMOUS_ITERATIONS')
|
||||
assert hasattr(config, "MAX_AUTONOMOUS_ITERATIONS")
|
||||
assert config.MAX_AUTONOMOUS_ITERATIONS > 0
|
||||
|
||||
assert hasattr(config, 'CONTEXT_COMPRESSION_THRESHOLD')
|
||||
assert hasattr(config, "CONTEXT_COMPRESSION_THRESHOLD")
|
||||
assert config.CONTEXT_COMPRESSION_THRESHOLD > 0
|
||||
|
||||
def test_language_keywords(self):
|
||||
assert hasattr(config, 'LANGUAGE_KEYWORDS')
|
||||
assert 'python' in config.LANGUAGE_KEYWORDS
|
||||
assert isinstance(config.LANGUAGE_KEYWORDS['python'], list)
|
||||
assert hasattr(config, "LANGUAGE_KEYWORDS")
|
||||
assert "python" in config.LANGUAGE_KEYWORDS
|
||||
assert isinstance(config.LANGUAGE_KEYWORDS["python"], list)
|
||||
|
||||
@ -1,56 +1,71 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, mock_open
|
||||
import os
|
||||
from pr.core.config_loader import load_config, _load_config_file, _parse_value, create_default_config
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
from pr.core.config_loader import (
|
||||
_load_config_file,
|
||||
_parse_value,
|
||||
create_default_config,
|
||||
load_config,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_value_string():
|
||||
assert _parse_value('hello') == 'hello'
|
||||
assert _parse_value("hello") == "hello"
|
||||
|
||||
|
||||
def test_parse_value_int():
|
||||
assert _parse_value('123') == 123
|
||||
assert _parse_value("123") == 123
|
||||
|
||||
|
||||
def test_parse_value_float():
|
||||
assert _parse_value('1.23') == 1.23
|
||||
assert _parse_value("1.23") == 1.23
|
||||
|
||||
|
||||
def test_parse_value_bool_true():
|
||||
assert _parse_value('true') == True
|
||||
assert _parse_value("true") == True
|
||||
|
||||
|
||||
def test_parse_value_bool_false():
|
||||
assert _parse_value('false') == False
|
||||
assert _parse_value("false") == False
|
||||
|
||||
|
||||
def test_parse_value_bool_upper():
|
||||
assert _parse_value('TRUE') == True
|
||||
assert _parse_value("TRUE") == True
|
||||
|
||||
@patch('os.path.exists', return_value=False)
|
||||
|
||||
@patch("os.path.exists", return_value=False)
|
||||
def test_load_config_file_not_exists(mock_exists):
|
||||
config = _load_config_file('test.ini')
|
||||
config = _load_config_file("test.ini")
|
||||
assert config == {}
|
||||
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch('configparser.ConfigParser')
|
||||
|
||||
@patch("os.path.exists", return_value=True)
|
||||
@patch("configparser.ConfigParser")
|
||||
def test_load_config_file_exists(mock_parser_class, mock_exists):
|
||||
mock_parser = mock_parser_class.return_value
|
||||
mock_parser.sections.return_value = ['api']
|
||||
mock_parser.items.return_value = [('key', 'value')]
|
||||
config = _load_config_file('test.ini')
|
||||
assert 'api' in config
|
||||
assert config['api']['key'] == 'value'
|
||||
mock_parser.sections.return_value = ["api"]
|
||||
mock_parser.items.return_value = [("key", "value")]
|
||||
config = _load_config_file("test.ini")
|
||||
assert "api" in config
|
||||
assert config["api"]["key"] == "value"
|
||||
|
||||
@patch('pr.core.config_loader._load_config_file')
|
||||
|
||||
@patch("pr.core.config_loader._load_config_file")
|
||||
def test_load_config(mock_load):
|
||||
mock_load.side_effect = [{'api': {'key': 'global'}}, {'api': {'key': 'local'}}]
|
||||
mock_load.side_effect = [{"api": {"key": "global"}}, {"api": {"key": "local"}}]
|
||||
config = load_config()
|
||||
assert config['api']['key'] == 'local'
|
||||
assert config["api"]["key"] == "local"
|
||||
|
||||
@patch('builtins.open', new_callable=mock_open)
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open)
|
||||
def test_create_default_config(mock_file):
|
||||
result = create_default_config('test.ini')
|
||||
result = create_default_config("test.ini")
|
||||
assert result == True
|
||||
mock_file.assert_called_once_with('test.ini', 'w')
|
||||
mock_file.assert_called_once_with("test.ini", "w")
|
||||
handle = mock_file()
|
||||
handle.write.assert_called_once()
|
||||
|
||||
@patch('builtins.open', side_effect=Exception('error'))
|
||||
|
||||
@patch("builtins.open", side_effect=Exception("error"))
|
||||
def test_create_default_config_error(mock_file):
|
||||
result = create_default_config('test.ini')
|
||||
assert result == False
|
||||
result = create_default_config("test.ini")
|
||||
assert result == False
|
||||
|
||||
@ -1,30 +1,29 @@
|
||||
import pytest
|
||||
from pr.core.context import should_compress_context, compress_context
|
||||
from pr.config import RECENT_MESSAGES_TO_KEEP
|
||||
from pr.core.context import compress_context, should_compress_context
|
||||
|
||||
|
||||
class TestContextManagement:
|
||||
|
||||
def test_should_compress_context_below_threshold(self):
|
||||
messages = [{'role': 'user', 'content': 'test'}] * 10
|
||||
messages = [{"role": "user", "content": "test"}] * 10
|
||||
assert should_compress_context(messages) is False
|
||||
|
||||
def test_should_compress_context_above_threshold(self):
|
||||
messages = [{'role': 'user', 'content': 'test'}] * 35
|
||||
messages = [{"role": "user", "content": "test"}] * 35
|
||||
assert should_compress_context(messages) is True
|
||||
|
||||
def test_compress_context_preserves_system_message(self):
|
||||
messages = [
|
||||
{'role': 'system', 'content': 'System prompt'},
|
||||
{'role': 'user', 'content': 'Hello'},
|
||||
{'role': 'assistant', 'content': 'Hi'},
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
] * 40 # Ensure compression
|
||||
compressed = compress_context(messages)
|
||||
assert compressed[0]['role'] == 'system'
|
||||
assert 'System prompt' in compressed[0]['content']
|
||||
assert compressed[0]["role"] == "system"
|
||||
assert "System prompt" in compressed[0]["content"]
|
||||
|
||||
def test_compress_context_keeps_recent_messages(self):
|
||||
messages = [{'role': 'user', 'content': f'msg{i}'} for i in range(40)]
|
||||
messages = [{"role": "user", "content": f"msg{i}"} for i in range(40)]
|
||||
compressed = compress_context(messages)
|
||||
# Should keep recent messages
|
||||
recent = compressed[-RECENT_MESSAGES_TO_KEEP:]
|
||||
@ -32,4 +31,4 @@ class TestContextManagement:
|
||||
# Check that the messages are the most recent ones
|
||||
for i, msg in enumerate(recent):
|
||||
expected_index = 40 - RECENT_MESSAGES_TO_KEEP + i
|
||||
assert msg['content'] == f'msg{expected_index}'
|
||||
assert msg["content"] == f"msg{expected_index}"
|
||||
|
||||
@ -1,89 +1,97 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from pr.core.enhanced_assistant import EnhancedAssistant
|
||||
|
||||
|
||||
def test_enhanced_assistant_init():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assert assistant.base == mock_base
|
||||
assert assistant.current_conversation_id is not None
|
||||
|
||||
|
||||
def test_enhanced_call_api_with_cache():
|
||||
mock_base = MagicMock()
|
||||
mock_base.model = 'test-model'
|
||||
mock_base.api_url = 'http://test'
|
||||
mock_base.api_key = 'key'
|
||||
mock_base.model = "test-model"
|
||||
mock_base.api_url = "http://test"
|
||||
mock_base.api_key = "key"
|
||||
mock_base.use_tools = False
|
||||
mock_base.verbose = False
|
||||
|
||||
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.api_cache = MagicMock()
|
||||
assistant.api_cache.get.return_value = {'cached': True}
|
||||
|
||||
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
|
||||
assert result == {'cached': True}
|
||||
assistant.api_cache.get.return_value = {"cached": True}
|
||||
|
||||
result = assistant.enhanced_call_api([{"role": "user", "content": "test"}])
|
||||
assert result == {"cached": True}
|
||||
assistant.api_cache.get.assert_called_once()
|
||||
|
||||
|
||||
def test_enhanced_call_api_without_cache():
|
||||
mock_base = MagicMock()
|
||||
mock_base.model = 'test-model'
|
||||
mock_base.api_url = 'http://test'
|
||||
mock_base.api_key = 'key'
|
||||
mock_base.model = "test-model"
|
||||
mock_base.api_url = "http://test"
|
||||
mock_base.api_key = "key"
|
||||
mock_base.use_tools = False
|
||||
mock_base.verbose = False
|
||||
|
||||
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.api_cache = None
|
||||
|
||||
|
||||
# It will try to call API and fail with network error, but that's expected
|
||||
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
|
||||
assert 'error' in result
|
||||
result = assistant.enhanced_call_api([{"role": "user", "content": "test"}])
|
||||
assert "error" in result
|
||||
|
||||
|
||||
def test_execute_workflow_not_found():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.workflow_storage = MagicMock()
|
||||
assistant.workflow_storage.load_workflow_by_name.return_value = None
|
||||
|
||||
result = assistant.execute_workflow('nonexistent')
|
||||
assert 'error' in result
|
||||
|
||||
result = assistant.execute_workflow("nonexistent")
|
||||
assert "error" in result
|
||||
|
||||
|
||||
def test_create_agent():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.agent_manager = MagicMock()
|
||||
assistant.agent_manager.create_agent.return_value = 'agent_id'
|
||||
|
||||
result = assistant.create_agent('role')
|
||||
assert result == 'agent_id'
|
||||
assistant.agent_manager.create_agent.return_value = "agent_id"
|
||||
|
||||
result = assistant.create_agent("role")
|
||||
assert result == "agent_id"
|
||||
|
||||
|
||||
def test_search_knowledge():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.knowledge_store = MagicMock()
|
||||
assistant.knowledge_store.search_entries.return_value = [{'result': True}]
|
||||
|
||||
result = assistant.search_knowledge('query')
|
||||
assert result == [{'result': True}]
|
||||
assistant.knowledge_store.search_entries.return_value = [{"result": True}]
|
||||
|
||||
result = assistant.search_knowledge("query")
|
||||
assert result == [{"result": True}]
|
||||
|
||||
|
||||
def test_get_cache_statistics():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.api_cache = MagicMock()
|
||||
assistant.api_cache.get_statistics.return_value = {'hits': 10}
|
||||
assistant.api_cache.get_statistics.return_value = {"hits": 10}
|
||||
assistant.tool_cache = MagicMock()
|
||||
assistant.tool_cache.get_statistics.return_value = {'misses': 5}
|
||||
|
||||
assistant.tool_cache.get_statistics.return_value = {"misses": 5}
|
||||
|
||||
stats = assistant.get_cache_statistics()
|
||||
assert 'api_cache' in stats
|
||||
assert 'tool_cache' in stats
|
||||
assert "api_cache" in stats
|
||||
assert "tool_cache" in stats
|
||||
|
||||
|
||||
def test_clear_caches():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.api_cache = MagicMock()
|
||||
assistant.tool_cache = MagicMock()
|
||||
|
||||
|
||||
assistant.clear_caches()
|
||||
assistant.api_cache.clear_all.assert_called_once()
|
||||
assistant.tool_cache.clear_all.assert_called_once()
|
||||
|
||||
@ -1,24 +1,25 @@
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from pr.core.logging import setup_logging, get_logger
|
||||
from pr.core.logging import get_logger, setup_logging
|
||||
|
||||
|
||||
def test_setup_logging_basic():
|
||||
logger = setup_logging(verbose=False)
|
||||
assert logger.name == 'pr'
|
||||
assert logger.name == "pr"
|
||||
assert logger.level == 20 # INFO
|
||||
|
||||
|
||||
def test_setup_logging_verbose():
|
||||
logger = setup_logging(verbose=True)
|
||||
assert logger.name == 'pr'
|
||||
assert logger.name == "pr"
|
||||
assert logger.level == 10 # DEBUG
|
||||
# Should have console handler
|
||||
assert len(logger.handlers) >= 2
|
||||
|
||||
|
||||
def test_get_logger_default():
|
||||
logger = get_logger()
|
||||
assert logger.name == 'pr'
|
||||
assert logger.name == "pr"
|
||||
|
||||
|
||||
def test_get_logger_named():
|
||||
logger = get_logger('test')
|
||||
assert logger.name == 'pr.test'
|
||||
logger = get_logger("test")
|
||||
assert logger.name == "pr.test"
|
||||
|
||||
@ -1,118 +1,134 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from pr.__main__ import main
|
||||
|
||||
|
||||
def test_main_version(capsys):
|
||||
with patch('sys.argv', ['pr', '--version']):
|
||||
with patch("sys.argv", ["pr", "--version"]):
|
||||
with pytest.raises(SystemExit):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'PR Assistant' in captured.out
|
||||
assert "PR Assistant" in captured.out
|
||||
|
||||
|
||||
def test_main_create_config_success(capsys):
|
||||
with patch('pr.core.config_loader.create_default_config', return_value=True):
|
||||
with patch('sys.argv', ['pr', '--create-config']):
|
||||
with patch("pr.core.config_loader.create_default_config", return_value=True):
|
||||
with patch("sys.argv", ["pr", "--create-config"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Configuration file created' in captured.out
|
||||
assert "Configuration file created" in captured.out
|
||||
|
||||
|
||||
def test_main_create_config_fail(capsys):
|
||||
with patch('pr.core.config_loader.create_default_config', return_value=False):
|
||||
with patch('sys.argv', ['pr', '--create-config']):
|
||||
with patch("pr.core.config_loader.create_default_config", return_value=False):
|
||||
with patch("sys.argv", ["pr", "--create-config"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Error creating configuration file' in captured.err
|
||||
assert "Error creating configuration file" in captured.err
|
||||
|
||||
|
||||
def test_main_list_sessions_no_sessions(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
with patch("pr.core.session.SessionManager") as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.list_sessions.return_value = []
|
||||
with patch('sys.argv', ['pr', '--list-sessions']):
|
||||
with patch("sys.argv", ["pr", "--list-sessions"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'No saved sessions found' in captured.out
|
||||
assert "No saved sessions found" in captured.out
|
||||
|
||||
|
||||
def test_main_list_sessions_with_sessions(capsys):
|
||||
sessions = [{'name': 'test', 'created_at': '2023-01-01', 'message_count': 5}]
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
sessions = [{"name": "test", "created_at": "2023-01-01", "message_count": 5}]
|
||||
with patch("pr.core.session.SessionManager") as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.list_sessions.return_value = sessions
|
||||
with patch('sys.argv', ['pr', '--list-sessions']):
|
||||
with patch("sys.argv", ["pr", "--list-sessions"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Found 1 saved sessions' in captured.out
|
||||
assert 'test' in captured.out
|
||||
assert "Found 1 saved sessions" in captured.out
|
||||
assert "test" in captured.out
|
||||
|
||||
|
||||
def test_main_delete_session_success(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
with patch("pr.core.session.SessionManager") as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.delete_session.return_value = True
|
||||
with patch('sys.argv', ['pr', '--delete-session', 'test']):
|
||||
with patch("sys.argv", ["pr", "--delete-session", "test"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert "Session 'test' deleted" in captured.out
|
||||
|
||||
|
||||
def test_main_delete_session_fail(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
with patch("pr.core.session.SessionManager") as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.delete_session.return_value = False
|
||||
with patch('sys.argv', ['pr', '--delete-session', 'test']):
|
||||
with patch("sys.argv", ["pr", "--delete-session", "test"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert "Error deleting session 'test'" in captured.err
|
||||
|
||||
|
||||
def test_main_export_session_json(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
with patch("pr.core.session.SessionManager") as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.export_session.return_value = True
|
||||
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.json']):
|
||||
with patch("sys.argv", ["pr", "--export-session", "test", "output.json"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Session exported to output.json' in captured.out
|
||||
assert "Session exported to output.json" in captured.out
|
||||
|
||||
|
||||
def test_main_export_session_md(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
with patch("pr.core.session.SessionManager") as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.export_session.return_value = True
|
||||
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.md']):
|
||||
with patch("sys.argv", ["pr", "--export-session", "test", "output.md"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Session exported to output.md' in captured.out
|
||||
assert "Session exported to output.md" in captured.out
|
||||
|
||||
|
||||
def test_main_usage(capsys):
|
||||
usage = {'total_requests': 10, 'total_tokens': 1000, 'total_cost': 0.01}
|
||||
with patch('pr.core.usage_tracker.UsageTracker.get_total_usage', return_value=usage):
|
||||
with patch('sys.argv', ['pr', '--usage']):
|
||||
usage = {"total_requests": 10, "total_tokens": 1000, "total_cost": 0.01}
|
||||
with patch(
|
||||
"pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage
|
||||
):
|
||||
with patch("sys.argv", ["pr", "--usage"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Total Usage Statistics' in captured.out
|
||||
assert 'Requests: 10' in captured.out
|
||||
assert "Total Usage Statistics" in captured.out
|
||||
assert "Requests: 10" in captured.out
|
||||
|
||||
|
||||
def test_main_plugins_no_plugins(capsys):
|
||||
with patch('pr.plugins.loader.PluginLoader') as mock_loader:
|
||||
with patch("pr.plugins.loader.PluginLoader") as mock_loader:
|
||||
mock_instance = mock_loader.return_value
|
||||
mock_instance.load_plugins.return_value = None
|
||||
mock_instance.list_loaded_plugins.return_value = []
|
||||
with patch('sys.argv', ['pr', '--plugins']):
|
||||
with patch("sys.argv", ["pr", "--plugins"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'No plugins loaded' in captured.out
|
||||
assert "No plugins loaded" in captured.out
|
||||
|
||||
|
||||
def test_main_plugins_with_plugins(capsys):
|
||||
with patch('pr.plugins.loader.PluginLoader') as mock_loader:
|
||||
with patch("pr.plugins.loader.PluginLoader") as mock_loader:
|
||||
mock_instance = mock_loader.return_value
|
||||
mock_instance.load_plugins.return_value = None
|
||||
mock_instance.list_loaded_plugins.return_value = ['plugin1', 'plugin2']
|
||||
with patch('sys.argv', ['pr', '--plugins']):
|
||||
mock_instance.list_loaded_plugins.return_value = ["plugin1", "plugin2"]
|
||||
with patch("sys.argv", ["pr", "--plugins"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Loaded 2 plugins' in captured.out
|
||||
assert "Loaded 2 plugins" in captured.out
|
||||
|
||||
|
||||
def test_main_run_assistant():
|
||||
with patch('pr.__main__.Assistant') as mock_assistant:
|
||||
with patch("pr.__main__.Assistant") as mock_assistant:
|
||||
mock_instance = mock_assistant.return_value
|
||||
with patch('sys.argv', ['pr', 'test message']):
|
||||
with patch("sys.argv", ["pr", "test message"]):
|
||||
main()
|
||||
mock_assistant.assert_called_once()
|
||||
mock_instance.run.assert_called_once()
|
||||
mock_instance.run.assert_called_once()
|
||||
|
||||
@ -1,116 +1,131 @@
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from pr.core.session import SessionManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_sessions_dir(tmp_path, monkeypatch):
|
||||
from pr.core import session
|
||||
|
||||
original_dir = session.SESSIONS_DIR
|
||||
monkeypatch.setattr(session, 'SESSIONS_DIR', str(tmp_path))
|
||||
monkeypatch.setattr(session, "SESSIONS_DIR", str(tmp_path))
|
||||
# Clean any existing files
|
||||
import shutil
|
||||
|
||||
if os.path.exists(str(tmp_path)):
|
||||
shutil.rmtree(str(tmp_path))
|
||||
os.makedirs(str(tmp_path), exist_ok=True)
|
||||
yield tmp_path
|
||||
monkeypatch.setattr(session, 'SESSIONS_DIR', original_dir)
|
||||
monkeypatch.setattr(session, "SESSIONS_DIR", original_dir)
|
||||
|
||||
|
||||
def test_session_manager_init(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
SessionManager()
|
||||
assert os.path.exists(temp_sessions_dir)
|
||||
|
||||
|
||||
def test_save_and_load_session(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
name = "test_session"
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
metadata = {"test": True}
|
||||
|
||||
|
||||
assert manager.save_session(name, messages, metadata)
|
||||
|
||||
|
||||
loaded = manager.load_session(name)
|
||||
assert loaded is not None
|
||||
assert loaded['name'] == name
|
||||
assert loaded['messages'] == messages
|
||||
assert loaded['metadata'] == metadata
|
||||
assert loaded["name"] == name
|
||||
assert loaded["messages"] == messages
|
||||
assert loaded["metadata"] == metadata
|
||||
|
||||
|
||||
def test_load_nonexistent_session(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
loaded = manager.load_session("nonexistent")
|
||||
assert loaded is None
|
||||
|
||||
|
||||
def test_list_sessions(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
# Save a session
|
||||
manager.save_session("session1", [{"role": "user", "content": "Hi"}])
|
||||
manager.save_session("session2", [{"role": "user", "content": "Hello"}])
|
||||
|
||||
|
||||
sessions = manager.list_sessions()
|
||||
assert len(sessions) == 2
|
||||
assert sessions[0]['name'] == "session2" # sorted by created_at desc
|
||||
assert sessions[0]["name"] == "session2" # sorted by created_at desc
|
||||
|
||||
|
||||
def test_delete_session(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
name = "to_delete"
|
||||
manager.save_session(name, [{"role": "user", "content": "Test"}])
|
||||
|
||||
|
||||
assert manager.delete_session(name)
|
||||
assert manager.load_session(name) is None
|
||||
|
||||
|
||||
def test_delete_nonexistent_session(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
assert not manager.delete_session("nonexistent")
|
||||
|
||||
|
||||
def test_export_session_json(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
name = "export_test"
|
||||
messages = [{"role": "user", "content": "Export me"}]
|
||||
manager.save_session(name, messages)
|
||||
|
||||
|
||||
output_path = tmp_path / "exported.json"
|
||||
assert manager.export_session(name, str(output_path), 'json')
|
||||
assert manager.export_session(name, str(output_path), "json")
|
||||
assert output_path.exists()
|
||||
|
||||
|
||||
with open(output_path) as f:
|
||||
data = json.load(f)
|
||||
assert data['name'] == name
|
||||
assert data["name"] == name
|
||||
|
||||
|
||||
def test_export_session_markdown(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
name = "export_md"
|
||||
messages = [{"role": "user", "content": "Markdown export"}]
|
||||
manager.save_session(name, messages)
|
||||
|
||||
|
||||
output_path = tmp_path / "exported.md"
|
||||
assert manager.export_session(name, str(output_path), 'markdown')
|
||||
assert manager.export_session(name, str(output_path), "markdown")
|
||||
assert output_path.exists()
|
||||
|
||||
|
||||
content = output_path.read_text()
|
||||
assert "# Session: export_md" in content
|
||||
|
||||
|
||||
def test_export_session_txt(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
name = "export_txt"
|
||||
messages = [{"role": "user", "content": "Text export"}]
|
||||
manager.save_session(name, messages)
|
||||
|
||||
|
||||
output_path = tmp_path / "exported.txt"
|
||||
assert manager.export_session(name, str(output_path), 'txt')
|
||||
assert manager.export_session(name, str(output_path), "txt")
|
||||
assert output_path.exists()
|
||||
|
||||
|
||||
content = output_path.read_text()
|
||||
assert "Session: export_txt" in content
|
||||
|
||||
|
||||
def test_export_nonexistent_session(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
output_path = tmp_path / "nonexistent.json"
|
||||
assert not manager.export_session("nonexistent", str(output_path), 'json')
|
||||
assert not manager.export_session("nonexistent", str(output_path), "json")
|
||||
|
||||
|
||||
def test_export_unsupported_format(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
name = "test"
|
||||
manager.save_session(name, [{"role": "user", "content": "Test"}])
|
||||
|
||||
|
||||
output_path = tmp_path / "test.unsupported"
|
||||
assert not manager.export_session(name, str(output_path), 'unsupported')
|
||||
assert not manager.export_session(name, str(output_path), "unsupported")
|
||||
|
||||
@ -1,69 +1,68 @@
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
from pr.tools.filesystem import read_file, write_file, list_directory, search_replace
|
||||
from pr.tools.patch import apply_patch, create_diff
|
||||
|
||||
from pr.tools.base import get_tools_definition
|
||||
from pr.tools.filesystem import list_directory, read_file, search_replace, write_file
|
||||
from pr.tools.patch import apply_patch, create_diff
|
||||
|
||||
|
||||
class TestFilesystemTools:
|
||||
|
||||
def test_write_and_read_file(self, temp_dir):
|
||||
filepath = os.path.join(temp_dir, 'test.txt')
|
||||
content = 'Hello, World!'
|
||||
filepath = os.path.join(temp_dir, "test.txt")
|
||||
content = "Hello, World!"
|
||||
|
||||
write_result = write_file(filepath, content)
|
||||
assert write_result['status'] == 'success'
|
||||
assert write_result["status"] == "success"
|
||||
|
||||
read_result = read_file(filepath)
|
||||
assert read_result['status'] == 'success'
|
||||
assert content in read_result['content']
|
||||
assert read_result["status"] == "success"
|
||||
assert content in read_result["content"]
|
||||
|
||||
def test_read_nonexistent_file(self):
|
||||
result = read_file('/nonexistent/path/file.txt')
|
||||
assert result['status'] == 'error'
|
||||
result = read_file("/nonexistent/path/file.txt")
|
||||
assert result["status"] == "error"
|
||||
|
||||
def test_list_directory(self, temp_dir):
|
||||
test_file = os.path.join(temp_dir, 'testfile.txt')
|
||||
with open(test_file, 'w') as f:
|
||||
f.write('test')
|
||||
test_file = os.path.join(temp_dir, "testfile.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("test")
|
||||
|
||||
result = list_directory(temp_dir)
|
||||
assert result['status'] == 'success'
|
||||
assert any(item['name'] == 'testfile.txt' for item in result['items'])
|
||||
assert result["status"] == "success"
|
||||
assert any(item["name"] == "testfile.txt" for item in result["items"])
|
||||
|
||||
def test_search_replace(self, temp_dir):
|
||||
filepath = os.path.join(temp_dir, 'test.txt')
|
||||
content = 'Hello, World!'
|
||||
with open(filepath, 'w') as f:
|
||||
filepath = os.path.join(temp_dir, "test.txt")
|
||||
content = "Hello, World!"
|
||||
with open(filepath, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
result = search_replace(filepath, 'World', 'Universe')
|
||||
assert result['status'] == 'success'
|
||||
result = search_replace(filepath, "World", "Universe")
|
||||
assert result["status"] == "success"
|
||||
|
||||
read_result = read_file(filepath)
|
||||
assert 'Hello, Universe!' in read_result['content']
|
||||
assert "Hello, Universe!" in read_result["content"]
|
||||
|
||||
|
||||
class TestPatchTools:
|
||||
|
||||
def test_create_diff(self, temp_dir):
|
||||
file1 = os.path.join(temp_dir, 'file1.txt')
|
||||
file2 = os.path.join(temp_dir, 'file2.txt')
|
||||
with open(file1, 'w') as f:
|
||||
f.write('line1\nline2\nline3\n')
|
||||
with open(file2, 'w') as f:
|
||||
f.write('line1\nline2 modified\nline3\n')
|
||||
file1 = os.path.join(temp_dir, "file1.txt")
|
||||
file2 = os.path.join(temp_dir, "file2.txt")
|
||||
with open(file1, "w") as f:
|
||||
f.write("line1\nline2\nline3\n")
|
||||
with open(file2, "w") as f:
|
||||
f.write("line1\nline2 modified\nline3\n")
|
||||
|
||||
result = create_diff(file1, file2)
|
||||
assert result['status'] == 'success'
|
||||
assert 'line2' in result['diff']
|
||||
assert 'line2 modified' in result['diff']
|
||||
assert result["status"] == "success"
|
||||
assert "line2" in result["diff"]
|
||||
assert "line2 modified" in result["diff"]
|
||||
|
||||
def test_apply_patch(self, temp_dir):
|
||||
filepath = os.path.join(temp_dir, 'file.txt')
|
||||
with open(filepath, 'w') as f:
|
||||
f.write('line1\nline2\nline3\n')
|
||||
filepath = os.path.join(temp_dir, "file.txt")
|
||||
with open(filepath, "w") as f:
|
||||
f.write("line1\nline2\nline3\n")
|
||||
|
||||
# Create a simple patch
|
||||
patch_content = """--- a/file.txt
|
||||
@ -75,10 +74,10 @@ class TestPatchTools:
|
||||
line3
|
||||
"""
|
||||
result = apply_patch(filepath, patch_content)
|
||||
assert result['status'] == 'success'
|
||||
assert result["status"] == "success"
|
||||
|
||||
read_result = read_file(filepath)
|
||||
assert 'line2 modified' in read_result['content']
|
||||
assert "line2 modified" in read_result["content"]
|
||||
|
||||
|
||||
class TestToolDefinitions:
|
||||
@ -92,27 +91,27 @@ class TestToolDefinitions:
|
||||
tools = get_tools_definition()
|
||||
|
||||
for tool in tools:
|
||||
assert 'type' in tool
|
||||
assert tool['type'] == 'function'
|
||||
assert 'function' in tool
|
||||
assert "type" in tool
|
||||
assert tool["type"] == "function"
|
||||
assert "function" in tool
|
||||
|
||||
func = tool['function']
|
||||
assert 'name' in func
|
||||
assert 'description' in func
|
||||
assert 'parameters' in func
|
||||
func = tool["function"]
|
||||
assert "name" in func
|
||||
assert "description" in func
|
||||
assert "parameters" in func
|
||||
|
||||
def test_filesystem_tools_present(self):
|
||||
tools = get_tools_definition()
|
||||
tool_names = [t['function']['name'] for t in tools]
|
||||
tool_names = [t["function"]["name"] for t in tools]
|
||||
|
||||
assert 'read_file' in tool_names
|
||||
assert 'write_file' in tool_names
|
||||
assert 'list_directory' in tool_names
|
||||
assert 'search_replace' in tool_names
|
||||
assert "read_file" in tool_names
|
||||
assert "write_file" in tool_names
|
||||
assert "list_directory" in tool_names
|
||||
assert "search_replace" in tool_names
|
||||
|
||||
def test_patch_tools_present(self):
|
||||
tools = get_tools_definition()
|
||||
tool_names = [t['function']['name'] for t in tools]
|
||||
tool_names = [t["function"]["name"] for t in tools]
|
||||
|
||||
assert 'apply_patch' in tool_names
|
||||
assert 'create_diff' in tool_names
|
||||
assert "apply_patch" in tool_names
|
||||
assert "create_diff" in tool_names
|
||||
|
||||
@ -1,86 +1,110 @@
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from pr.core.usage_tracker import UsageTracker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_usage_file(tmp_path, monkeypatch):
|
||||
from pr.core import usage_tracker
|
||||
|
||||
original_file = usage_tracker.USAGE_DB_FILE
|
||||
temp_file = str(tmp_path / "usage.json")
|
||||
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', temp_file)
|
||||
monkeypatch.setattr(usage_tracker, "USAGE_DB_FILE", temp_file)
|
||||
yield temp_file
|
||||
if os.path.exists(temp_file):
|
||||
os.remove(temp_file)
|
||||
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', original_file)
|
||||
monkeypatch.setattr(usage_tracker, "USAGE_DB_FILE", original_file)
|
||||
|
||||
|
||||
def test_usage_tracker_init():
|
||||
tracker = UsageTracker()
|
||||
summary = tracker.get_session_summary()
|
||||
assert summary['requests'] == 0
|
||||
assert summary['total_tokens'] == 0
|
||||
assert summary['estimated_cost'] == 0.0
|
||||
assert summary["requests"] == 0
|
||||
assert summary["total_tokens"] == 0
|
||||
assert summary["estimated_cost"] == 0.0
|
||||
|
||||
|
||||
def test_track_request_known_model():
|
||||
tracker = UsageTracker()
|
||||
tracker.track_request('gpt-3.5-turbo', 100, 50)
|
||||
|
||||
tracker.track_request("gpt-3.5-turbo", 100, 50)
|
||||
|
||||
summary = tracker.get_session_summary()
|
||||
assert summary['requests'] == 1
|
||||
assert summary['input_tokens'] == 100
|
||||
assert summary['output_tokens'] == 50
|
||||
assert summary['total_tokens'] == 150
|
||||
assert 'gpt-3.5-turbo' in summary['models_used']
|
||||
assert summary["requests"] == 1
|
||||
assert summary["input_tokens"] == 100
|
||||
assert summary["output_tokens"] == 50
|
||||
assert summary["total_tokens"] == 150
|
||||
assert "gpt-3.5-turbo" in summary["models_used"]
|
||||
# Cost: (100/1000)*0.0005 + (50/1000)*0.0015 = 0.00005 + 0.000075 = 0.000125
|
||||
assert abs(summary['estimated_cost'] - 0.000125) < 1e-6
|
||||
assert abs(summary["estimated_cost"] - 0.000125) < 1e-6
|
||||
|
||||
|
||||
def test_track_request_unknown_model():
|
||||
tracker = UsageTracker()
|
||||
tracker.track_request('unknown-model', 100, 50)
|
||||
|
||||
tracker.track_request("unknown-model", 100, 50)
|
||||
|
||||
summary = tracker.get_session_summary()
|
||||
assert summary['requests'] == 1
|
||||
assert summary['estimated_cost'] == 0.0 # Unknown model, cost 0
|
||||
assert summary["requests"] == 1
|
||||
assert summary["estimated_cost"] == 0.0 # Unknown model, cost 0
|
||||
|
||||
|
||||
def test_track_request_multiple():
|
||||
tracker = UsageTracker()
|
||||
tracker.track_request('gpt-3.5-turbo', 100, 50)
|
||||
tracker.track_request('gpt-4', 200, 100)
|
||||
|
||||
tracker.track_request("gpt-3.5-turbo", 100, 50)
|
||||
tracker.track_request("gpt-4", 200, 100)
|
||||
|
||||
summary = tracker.get_session_summary()
|
||||
assert summary['requests'] == 2
|
||||
assert summary['input_tokens'] == 300
|
||||
assert summary['output_tokens'] == 150
|
||||
assert summary['total_tokens'] == 450
|
||||
assert len(summary['models_used']) == 2
|
||||
assert summary["requests"] == 2
|
||||
assert summary["input_tokens"] == 300
|
||||
assert summary["output_tokens"] == 150
|
||||
assert summary["total_tokens"] == 450
|
||||
assert len(summary["models_used"]) == 2
|
||||
|
||||
|
||||
def test_get_formatted_summary():
|
||||
tracker = UsageTracker()
|
||||
tracker.track_request('gpt-3.5-turbo', 100, 50)
|
||||
|
||||
tracker.track_request("gpt-3.5-turbo", 100, 50)
|
||||
|
||||
formatted = tracker.get_formatted_summary()
|
||||
assert "Total Requests: 1" in formatted
|
||||
assert "Total Tokens: 150" in formatted
|
||||
assert "Estimated Cost: $0.0001" in formatted
|
||||
assert "gpt-3.5-turbo" in formatted
|
||||
|
||||
|
||||
def test_get_total_usage_no_file(temp_usage_file):
|
||||
total = UsageTracker.get_total_usage()
|
||||
assert total['total_requests'] == 0
|
||||
assert total['total_tokens'] == 0
|
||||
assert total['total_cost'] == 0.0
|
||||
assert total["total_requests"] == 0
|
||||
assert total["total_tokens"] == 0
|
||||
assert total["total_cost"] == 0.0
|
||||
|
||||
|
||||
def test_get_total_usage_with_data(temp_usage_file):
|
||||
# Manually create history file
|
||||
history = [
|
||||
{'timestamp': '2023-01-01', 'model': 'gpt-3.5-turbo', 'input_tokens': 100, 'output_tokens': 50, 'total_tokens': 150, 'cost': 0.000125},
|
||||
{'timestamp': '2023-01-02', 'model': 'gpt-4', 'input_tokens': 200, 'output_tokens': 100, 'total_tokens': 300, 'cost': 0.008}
|
||||
{
|
||||
"timestamp": "2023-01-01",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
"cost": 0.000125,
|
||||
},
|
||||
{
|
||||
"timestamp": "2023-01-02",
|
||||
"model": "gpt-4",
|
||||
"input_tokens": 200,
|
||||
"output_tokens": 100,
|
||||
"total_tokens": 300,
|
||||
"cost": 0.008,
|
||||
},
|
||||
]
|
||||
with open(temp_usage_file, 'w') as f:
|
||||
with open(temp_usage_file, "w") as f:
|
||||
json.dump(history, f)
|
||||
|
||||
|
||||
total = UsageTracker.get_total_usage()
|
||||
assert total['total_requests'] == 2
|
||||
assert total['total_tokens'] == 450
|
||||
assert abs(total['total_cost'] - 0.008125) < 1e-6
|
||||
assert total["total_requests"] == 2
|
||||
assert total["total_tokens"] == 450
|
||||
assert abs(total["total_cost"] - 0.008125) < 1e-6
|
||||
|
||||
@ -1,49 +1,59 @@
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from pr.core.exceptions import ValidationError
|
||||
from pr.core.validation import (
|
||||
validate_file_path,
|
||||
validate_directory_path,
|
||||
validate_model_name,
|
||||
validate_api_url,
|
||||
validate_directory_path,
|
||||
validate_file_path,
|
||||
validate_max_tokens,
|
||||
validate_model_name,
|
||||
validate_session_name,
|
||||
validate_temperature,
|
||||
validate_max_tokens,
|
||||
)
|
||||
from pr.core.exceptions import ValidationError
|
||||
|
||||
|
||||
def test_validate_file_path_empty():
|
||||
with pytest.raises(ValidationError, match="File path cannot be empty"):
|
||||
validate_file_path("")
|
||||
|
||||
|
||||
def test_validate_file_path_not_exist():
|
||||
with pytest.raises(ValidationError, match="File does not exist"):
|
||||
validate_file_path("/nonexistent/file.txt", must_exist=True)
|
||||
|
||||
|
||||
def test_validate_file_path_is_dir():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with pytest.raises(ValidationError, match="Path is a directory"):
|
||||
validate_file_path(tmpdir, must_exist=True)
|
||||
|
||||
|
||||
def test_validate_file_path_valid():
|
||||
with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
result = validate_file_path(tmpfile.name, must_exist=True)
|
||||
assert os.path.isabs(result)
|
||||
assert result == os.path.abspath(tmpfile.name)
|
||||
|
||||
|
||||
def test_validate_directory_path_empty():
|
||||
with pytest.raises(ValidationError, match="Directory path cannot be empty"):
|
||||
validate_directory_path("")
|
||||
|
||||
|
||||
def test_validate_directory_path_not_exist():
|
||||
with pytest.raises(ValidationError, match="Directory does not exist"):
|
||||
validate_directory_path("/nonexistent/dir", must_exist=True)
|
||||
|
||||
|
||||
def test_validate_directory_path_not_dir():
|
||||
with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
with pytest.raises(ValidationError, match="Path is not a directory"):
|
||||
validate_directory_path(tmpfile.name, must_exist=True)
|
||||
|
||||
|
||||
def test_validate_directory_path_create():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
new_dir = os.path.join(tmpdir, "new_dir")
|
||||
@ -51,72 +61,89 @@ def test_validate_directory_path_create():
|
||||
assert os.path.isdir(new_dir)
|
||||
assert result == os.path.abspath(new_dir)
|
||||
|
||||
|
||||
def test_validate_directory_path_valid():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
result = validate_directory_path(tmpdir, must_exist=True)
|
||||
assert result == os.path.abspath(tmpdir)
|
||||
|
||||
|
||||
def test_validate_model_name_empty():
|
||||
with pytest.raises(ValidationError, match="Model name cannot be empty"):
|
||||
validate_model_name("")
|
||||
|
||||
|
||||
def test_validate_model_name_too_short():
|
||||
with pytest.raises(ValidationError, match="Model name too short"):
|
||||
validate_model_name("a")
|
||||
|
||||
|
||||
def test_validate_model_name_valid():
|
||||
result = validate_model_name("gpt-3.5-turbo")
|
||||
assert result == "gpt-3.5-turbo"
|
||||
|
||||
|
||||
def test_validate_api_url_empty():
|
||||
with pytest.raises(ValidationError, match="API URL cannot be empty"):
|
||||
validate_api_url("")
|
||||
|
||||
|
||||
def test_validate_api_url_invalid():
|
||||
with pytest.raises(ValidationError, match="API URL must start with"):
|
||||
validate_api_url("invalid-url")
|
||||
|
||||
|
||||
def test_validate_api_url_valid():
|
||||
result = validate_api_url("https://api.example.com")
|
||||
assert result == "https://api.example.com"
|
||||
|
||||
|
||||
def test_validate_session_name_empty():
|
||||
with pytest.raises(ValidationError, match="Session name cannot be empty"):
|
||||
validate_session_name("")
|
||||
|
||||
|
||||
def test_validate_session_name_invalid_char():
|
||||
with pytest.raises(ValidationError, match="contains invalid character"):
|
||||
validate_session_name("test/session")
|
||||
|
||||
|
||||
def test_validate_session_name_too_long():
|
||||
long_name = "a" * 256
|
||||
with pytest.raises(ValidationError, match="Session name too long"):
|
||||
validate_session_name(long_name)
|
||||
|
||||
|
||||
def test_validate_session_name_valid():
|
||||
result = validate_session_name("valid_session_123")
|
||||
assert result == "valid_session_123"
|
||||
|
||||
|
||||
def test_validate_temperature_too_low():
|
||||
with pytest.raises(ValidationError, match="Temperature must be between"):
|
||||
validate_temperature(-0.1)
|
||||
|
||||
|
||||
def test_validate_temperature_too_high():
|
||||
with pytest.raises(ValidationError, match="Temperature must be between"):
|
||||
validate_temperature(2.1)
|
||||
|
||||
|
||||
def test_validate_temperature_valid():
|
||||
result = validate_temperature(0.7)
|
||||
assert result == 0.7
|
||||
|
||||
|
||||
def test_validate_max_tokens_too_low():
|
||||
with pytest.raises(ValidationError, match="Max tokens must be at least 1"):
|
||||
validate_max_tokens(0)
|
||||
|
||||
|
||||
def test_validate_max_tokens_too_high():
|
||||
with pytest.raises(ValidationError, match="Max tokens too high"):
|
||||
validate_max_tokens(100001)
|
||||
|
||||
|
||||
def test_validate_max_tokens_valid():
|
||||
result = validate_max_tokens(1000)
|
||||
assert result == 1000
|
||||
|
||||
Loading…
Reference in New Issue
Block a user