Update.
This commit is contained in:
parent
3f979d2bbd
commit
685766ef86
@ -1,4 +1,7 @@
|
||||
# rp Assistant
|
||||
|
||||
|
||||
|
||||
rp
|
||||
[](https://github.com/retoor/rp-assistant)
|
||||
[](https://www.python.org/downloads/)
|
||||
|
||||
@ -164,7 +164,7 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei
|
||||
|
||||
return results
|
||||
|
||||
def get_session_summary(self) -> Dict[str, Any]:
|
||||
def get_session_summary(self) -> str:
|
||||
summary = {
|
||||
'session_id': self.session_id,
|
||||
'active_agents': len(self.active_agents),
|
||||
@ -178,7 +178,7 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei
|
||||
for agent_id, agent in self.active_agents.items()
|
||||
]
|
||||
}
|
||||
return summary
|
||||
return json.dumps(summary)
|
||||
|
||||
def clear_session(self):
|
||||
self.active_agents.clear()
|
||||
|
||||
@ -14,12 +14,9 @@ def run_autonomous_mode(assistant, task):
|
||||
logger.debug(f"=== AUTONOMOUS MODE START ===")
|
||||
logger.debug(f"Task: {task}")
|
||||
|
||||
if assistant.verbose:
|
||||
print_autonomous_header(task)
|
||||
|
||||
assistant.messages.append({
|
||||
"role": "user",
|
||||
"content": f"AUTONOMOUS TASK: {task}\n\nPlease work on this task step by step. Use tools as needed. When the task is fully complete, clearly state 'Task complete'."
|
||||
"content": f"{task}"
|
||||
})
|
||||
|
||||
try:
|
||||
@ -29,17 +26,11 @@ def run_autonomous_mode(assistant, task):
|
||||
logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---")
|
||||
logger.debug(f"Messages before context management: {len(assistant.messages)}")
|
||||
|
||||
if assistant.verbose:
|
||||
print(f"\n{Colors.BOLD}{Colors.MAGENTA}{'▶' * 3} Iteration {assistant.autonomous_iterations} {'◀' * 3}{Colors.RESET}\n")
|
||||
|
||||
from pr.core.context import manage_context_window
|
||||
assistant.messages = manage_context_window(assistant.messages, assistant.verbose)
|
||||
|
||||
logger.debug(f"Messages after context management: {len(assistant.messages)}")
|
||||
|
||||
if assistant.verbose:
|
||||
print(f"{Colors.GRAY}Calling API...{Colors.RESET}")
|
||||
|
||||
from pr.core.api import call_api
|
||||
from pr.tools.base import get_tools_definition
|
||||
response = call_api(
|
||||
@ -67,9 +58,6 @@ def run_autonomous_mode(assistant, task):
|
||||
logger.debug(f"=== AUTONOMOUS MODE COMPLETE ===")
|
||||
logger.debug(f"Total iterations: {assistant.autonomous_iterations}")
|
||||
logger.debug(f"Final message count: {len(assistant.messages)}")
|
||||
|
||||
print(f"{Colors.BOLD}Total Iterations:{Colors.RESET} {assistant.autonomous_iterations}")
|
||||
print(f"{Colors.BOLD}Messages in Context:{Colors.RESET} {len(assistant.messages)}\n")
|
||||
break
|
||||
|
||||
result = process_response_autonomous(assistant, response)
|
||||
@ -97,16 +85,12 @@ def process_response_autonomous(assistant, response):
|
||||
assistant.messages.append(message)
|
||||
|
||||
if 'tool_calls' in message and message['tool_calls']:
|
||||
print(f"{Colors.BOLD}{Colors.CYAN}🔧 Executing {len(message['tool_calls'])} tool(s)...{Colors.RESET}\n")
|
||||
|
||||
tool_results = []
|
||||
|
||||
for tool_call in message['tool_calls']:
|
||||
func_name = tool_call['function']['name']
|
||||
arguments = json.loads(tool_call['function']['arguments'])
|
||||
|
||||
display_tool_call(func_name, arguments, "running")
|
||||
|
||||
result = execute_single_tool(assistant, func_name, arguments)
|
||||
result = truncate_tool_result(result)
|
||||
|
||||
@ -121,8 +105,6 @@ def process_response_autonomous(assistant, response):
|
||||
|
||||
for result in tool_results:
|
||||
assistant.messages.append(result)
|
||||
|
||||
print(f"{Colors.GRAY}Processing tool results...{Colors.RESET}\n")
|
||||
from pr.core.api import call_api
|
||||
from pr.tools.base import get_tools_definition
|
||||
follow_up = call_api(
|
||||
|
||||
@ -47,6 +47,8 @@ class AdvancedContextManager:
|
||||
return complexity
|
||||
|
||||
def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]:
|
||||
if not text.strip():
|
||||
return []
|
||||
sentences = re.split(r'(?<=[.!?])\s+', text)
|
||||
if not sentences:
|
||||
return []
|
||||
|
||||
@ -133,6 +133,18 @@ class Assistant:
|
||||
'display_edit_summary': lambda **kw: display_edit_summary(),
|
||||
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
|
||||
'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
|
||||
'create_agent': lambda **kw: create_agent(**kw),
|
||||
'list_agents': lambda **kw: list_agents(**kw),
|
||||
'execute_agent_task': lambda **kw: execute_agent_task(**kw),
|
||||
'remove_agent': lambda **kw: remove_agent(**kw),
|
||||
'collaborate_agents': lambda **kw: collaborate_agents(**kw),
|
||||
'add_knowledge_entry': lambda **kw: add_knowledge_entry(**kw),
|
||||
'get_knowledge_entry': lambda **kw: get_knowledge_entry(**kw),
|
||||
'search_knowledge': lambda **kw: search_knowledge(**kw),
|
||||
'get_knowledge_by_category': lambda **kw: get_knowledge_by_category(**kw),
|
||||
'update_knowledge_importance': lambda **kw: update_knowledge_importance(**kw),
|
||||
'delete_knowledge_entry': lambda **kw: delete_knowledge_entry(**kw),
|
||||
'get_knowledge_statistics': lambda **kw: get_knowledge_statistics(**kw),
|
||||
}
|
||||
|
||||
if func_name in func_map:
|
||||
@ -230,6 +242,7 @@ class Assistant:
|
||||
path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options]
|
||||
|
||||
combined_options = sorted(list(set(options + path_options)))
|
||||
#combined_options.extend(self.commands)
|
||||
|
||||
if state < len(combined_options):
|
||||
return combined_options[state]
|
||||
@ -279,7 +292,8 @@ class Assistant:
|
||||
else:
|
||||
message = sys.stdin.read()
|
||||
|
||||
process_message(self, message)
|
||||
from pr.autonomous.mode import run_autonomous_mode
|
||||
run_autonomous_mode(self, message)
|
||||
|
||||
def cleanup(self):
|
||||
if hasattr(self, 'enhanced') and self.enhanced:
|
||||
@ -299,9 +313,12 @@ class Assistant:
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
print(f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}")
|
||||
if self.args.interactive or (not self.args.message and sys.stdin.isatty()):
|
||||
print("DEBUG: calling run_repl")
|
||||
self.run_repl()
|
||||
else:
|
||||
print("DEBUG: calling run_single")
|
||||
self.run_single()
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
@ -166,6 +166,8 @@ class RPEditor:
|
||||
|
||||
def save_file(self):
|
||||
"""Thread-safe save file command."""
|
||||
if not self.running:
|
||||
return self._save_file()
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
|
||||
except:
|
||||
@ -630,6 +632,10 @@ class RPEditor:
|
||||
|
||||
def set_text(self, text):
|
||||
"""Thread-safe text setting."""
|
||||
if not self.running:
|
||||
with self.lock:
|
||||
self._set_text(text)
|
||||
return
|
||||
try:
|
||||
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
|
||||
except:
|
||||
|
||||
168
pr/input_handler.py
Normal file
168
pr/input_handler.py
Normal file
@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Advanced input handler for PR Assistant with editor mode, file inclusion, and image support.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import base64
|
||||
import mimetypes
|
||||
import readline
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
# from pr.ui.colors import Colors # Avoid import issues
|
||||
|
||||
|
||||
class AdvancedInputHandler:
|
||||
"""Handles advanced input with editor mode, file inclusion, and image support."""
|
||||
|
||||
def __init__(self):
|
||||
self.editor_mode = False
|
||||
self.setup_readline()
|
||||
|
||||
def setup_readline(self):
|
||||
"""Setup readline with basic completer."""
|
||||
try:
|
||||
# Simple completer that doesn't interfere
|
||||
def completer(text, state):
|
||||
return None
|
||||
|
||||
readline.set_completer(completer)
|
||||
readline.parse_and_bind('tab: complete')
|
||||
except:
|
||||
pass # Readline not available
|
||||
|
||||
def toggle_editor_mode(self):
|
||||
"""Toggle between simple and editor input modes."""
|
||||
self.editor_mode = not self.editor_mode
|
||||
mode = "Editor" if self.editor_mode else "Simple"
|
||||
print(f"\nSwitched to {mode.lower()} input mode.")
|
||||
|
||||
def get_input(self, prompt: str = "You> ") -> Optional[str]:
|
||||
"""Get input from user, handling different modes."""
|
||||
try:
|
||||
if self.editor_mode:
|
||||
return self._get_editor_input(prompt)
|
||||
else:
|
||||
return self._get_simple_input(prompt)
|
||||
except KeyboardInterrupt:
|
||||
return None
|
||||
except EOFError:
|
||||
return None
|
||||
|
||||
def _get_simple_input(self, prompt: str) -> Optional[str]:
|
||||
"""Get simple input with file completion."""
|
||||
try:
|
||||
user_input = input(prompt).strip()
|
||||
|
||||
if not user_input:
|
||||
return ""
|
||||
|
||||
# Check for special commands
|
||||
if user_input.lower() == '/editor':
|
||||
self.toggle_editor_mode()
|
||||
return self.get_input(prompt) # Recurse to get new input
|
||||
|
||||
# Process file inclusions and images
|
||||
processed_input = self._process_input(user_input)
|
||||
return processed_input
|
||||
|
||||
except KeyboardInterrupt:
|
||||
return None
|
||||
|
||||
def _get_editor_input(self, prompt: str) -> Optional[str]:
|
||||
"""Get multi-line input for editor mode."""
|
||||
try:
|
||||
print("Editor mode: Enter your message. Type 'END' on a new line to finish.")
|
||||
print("Type '/simple' to switch back to simple mode.")
|
||||
|
||||
lines = []
|
||||
while True:
|
||||
try:
|
||||
line = input()
|
||||
if line.strip().lower() == 'end':
|
||||
break
|
||||
elif line.strip().lower() == '/simple':
|
||||
self.toggle_editor_mode()
|
||||
return self.get_input(prompt) # Switch back and get input
|
||||
lines.append(line)
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
content = '\n'.join(lines).strip()
|
||||
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# Process file inclusions and images
|
||||
processed_content = self._process_input(content)
|
||||
return processed_content
|
||||
|
||||
except KeyboardInterrupt:
|
||||
return None
|
||||
|
||||
def _process_input(self, text: str) -> str:
|
||||
"""Process input text for file inclusions and images."""
|
||||
# Process @[filename] inclusions
|
||||
text = self._process_file_inclusions(text)
|
||||
|
||||
# Process image inclusions (look for image file paths)
|
||||
text = self._process_image_inclusions(text)
|
||||
|
||||
return text
|
||||
|
||||
def _process_file_inclusions(self, text: str) -> str:
|
||||
"""Replace @[filename] with file contents."""
|
||||
def replace_file(match):
|
||||
filename = match.group(1).strip()
|
||||
try:
|
||||
path = Path(filename).expanduser().resolve()
|
||||
if path.exists() and path.is_file():
|
||||
with open(path, 'r', encoding='utf-8', errors='replace') as f:
|
||||
content = f.read()
|
||||
return f"\n--- File: {filename} ---\n{content}\n--- End of {filename} ---\n"
|
||||
else:
|
||||
return f"[File not found: {filename}]"
|
||||
except Exception as e:
|
||||
return f"[Error reading file {filename}: {e}]"
|
||||
|
||||
# Replace @[filename] patterns
|
||||
pattern = r'@\[([^\]]+)\]'
|
||||
return re.sub(pattern, replace_file, text)
|
||||
|
||||
def _process_image_inclusions(self, text: str) -> str:
|
||||
"""Process image file references and encode them."""
|
||||
# Find potential image file paths
|
||||
words = text.split()
|
||||
processed_parts = []
|
||||
|
||||
for word in words:
|
||||
# Check if it's a file path that exists and is an image
|
||||
try:
|
||||
path = Path(word.strip()).expanduser().resolve()
|
||||
if path.exists() and path.is_file():
|
||||
mime_type, _ = mimetypes.guess_type(str(path))
|
||||
if mime_type and mime_type.startswith('image/'):
|
||||
# Encode image
|
||||
with open(path, 'rb') as f:
|
||||
image_data = base64.b64encode(f.read()).decode('utf-8')
|
||||
|
||||
# Replace with data URL
|
||||
processed_parts.append(f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n")
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
|
||||
processed_parts.append(word)
|
||||
|
||||
return ' '.join(processed_parts)
|
||||
|
||||
|
||||
# Global instance
|
||||
input_handler = AdvancedInputHandler()
|
||||
|
||||
|
||||
def get_advanced_input(prompt: str = "You> ") -> Optional[str]:
|
||||
"""Get advanced input from user."""
|
||||
return input_handler.get_input(prompt)
|
||||
@ -8,6 +8,8 @@ from pr.tools.database import db_set, db_get, db_query
|
||||
from pr.tools.web import http_fetch, web_search, web_search_news
|
||||
from pr.tools.python_exec import python_exec
|
||||
from pr.tools.patch import apply_patch, create_diff
|
||||
from pr.tools.agents import create_agent, list_agents, execute_agent_task, remove_agent, collaborate_agents
|
||||
from pr.tools.memory import add_knowledge_entry, get_knowledge_entry, search_knowledge, get_knowledge_by_category, update_knowledge_importance, delete_knowledge_entry, get_knowledge_statistics
|
||||
|
||||
__all__ = [
|
||||
'get_tools_definition',
|
||||
@ -17,5 +19,7 @@ __all__ = [
|
||||
'db_set', 'db_get', 'db_query',
|
||||
'http_fetch', 'web_search', 'web_search_news',
|
||||
'python_exec','tail_process', 'kill_process',
|
||||
'apply_patch', 'create_diff'
|
||||
'apply_patch', 'create_diff',
|
||||
'create_agent', 'list_agents', 'execute_agent_task', 'remove_agent', 'collaborate_agents',
|
||||
'add_knowledge_entry', 'get_knowledge_entry', 'search_knowledge', 'get_knowledge_by_category', 'update_knowledge_importance', 'delete_knowledge_entry', 'get_knowledge_statistics'
|
||||
]
|
||||
|
||||
64
pr/tools/agents.py
Normal file
64
pr/tools/agents.py
Normal file
@ -0,0 +1,64 @@
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
from pr.agents.agent_manager import AgentManager
|
||||
from pr.core.api import call_api
|
||||
|
||||
def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]:
|
||||
"""Create a new agent with the specified role."""
|
||||
try:
|
||||
# Get db_path from environment or default
|
||||
db_path = os.environ.get('ASSISTANT_DB_PATH', '~/.assistant_db.sqlite')
|
||||
db_path = os.path.expanduser(db_path)
|
||||
|
||||
manager = AgentManager(db_path, call_api)
|
||||
agent_id = manager.create_agent(role_name, agent_id)
|
||||
return {"status": "success", "agent_id": agent_id, "role": role_name}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def list_agents() -> Dict[str, Any]:
|
||||
"""List all active agents."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
manager = AgentManager(db_path, call_api)
|
||||
agents = []
|
||||
for agent_id, agent in manager.active_agents.items():
|
||||
agents.append({
|
||||
"agent_id": agent_id,
|
||||
"role": agent.role.name,
|
||||
"task_count": agent.task_count,
|
||||
"message_count": len(agent.message_history)
|
||||
})
|
||||
return {"status": "success", "agents": agents}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""Execute a task with the specified agent."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
manager = AgentManager(db_path, call_api)
|
||||
result = manager.execute_agent_task(agent_id, task, context)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def remove_agent(agent_id: str) -> Dict[str, Any]:
|
||||
"""Remove an agent."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
manager = AgentManager(db_path, call_api)
|
||||
success = manager.remove_agent(agent_id)
|
||||
return {"status": "success" if success else "not_found", "agent_id": agent_id}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]:
|
||||
"""Collaborate multiple agents on a task."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
manager = AgentManager(db_path, call_api)
|
||||
result = manager.collaborate_agents(orchestrator_id, task, agent_roles)
|
||||
return result
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
99
pr/tools/memory.py
Normal file
99
pr/tools/memory.py
Normal file
@ -0,0 +1,99 @@
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry
|
||||
import time
|
||||
import uuid
|
||||
|
||||
def add_knowledge_entry(category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None) -> Dict[str, Any]:
|
||||
"""Add a new entry to the knowledge base."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
if entry_id is None:
|
||||
entry_id = str(uuid.uuid4())[:16]
|
||||
|
||||
entry = KnowledgeEntry(
|
||||
entry_id=entry_id,
|
||||
category=category,
|
||||
content=content,
|
||||
metadata=metadata or {},
|
||||
created_at=time.time(),
|
||||
updated_at=time.time()
|
||||
)
|
||||
|
||||
store.add_entry(entry)
|
||||
return {"status": "success", "entry_id": entry_id}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
|
||||
"""Retrieve a knowledge entry by ID."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
entry = store.get_entry(entry_id)
|
||||
if entry:
|
||||
return {"status": "success", "entry": entry.to_dict()}
|
||||
else:
|
||||
return {"status": "not_found", "entry_id": entry_id}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]:
|
||||
"""Search the knowledge base semantically."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
entries = store.search_entries(query, category, top_k)
|
||||
results = [entry.to_dict() for entry in entries]
|
||||
return {"status": "success", "results": results}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]:
|
||||
"""Get knowledge entries by category."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
entries = store.get_by_category(category, limit)
|
||||
results = [entry.to_dict() for entry in entries]
|
||||
return {"status": "success", "entries": results}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]:
|
||||
"""Update the importance score of a knowledge entry."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
store.update_importance(entry_id, importance_score)
|
||||
return {"status": "success", "entry_id": entry_id, "importance_score": importance_score}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def delete_knowledge_entry(entry_id: str) -> Dict[str, Any]:
|
||||
"""Delete a knowledge entry."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
success = store.delete_entry(entry_id)
|
||||
return {"status": "success" if success else "not_found", "entry_id": entry_id}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def get_knowledge_statistics() -> Dict[str, Any]:
|
||||
"""Get statistics about the knowledge base."""
|
||||
try:
|
||||
db_path = os.path.expanduser('~/.assistant_db.sqlite')
|
||||
store = KnowledgeStore(db_path)
|
||||
|
||||
stats = store.get_statistics()
|
||||
return {"status": "success", "statistics": stats}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
@ -3,41 +3,16 @@ from typing import Dict, Any
|
||||
from pr.ui.colors import Colors
|
||||
|
||||
def display_tool_call(tool_name, arguments, status="running", result=None):
|
||||
status_icons = {
|
||||
"running": ("⚙", Colors.YELLOW),
|
||||
"success": ("✓", Colors.GREEN),
|
||||
"error": ("✗", Colors.RED)
|
||||
}
|
||||
if status == "running":
|
||||
return
|
||||
|
||||
icon, color = status_icons.get(status, ("•", Colors.WHITE))
|
||||
args_str = ", ".join([f"{k}={str(v)[:20]}" for k, v in list(arguments.items())[:2]])
|
||||
line = f"{tool_name}({args_str})"
|
||||
|
||||
print(f"\n{Colors.BOLD}{'═' * 80}{Colors.RESET}")
|
||||
print(f"{color}{icon} {Colors.BOLD}{Colors.CYAN}TOOL: {tool_name}{Colors.RESET}")
|
||||
print(f"{Colors.BOLD}{'─' * 80}{Colors.RESET}")
|
||||
if len(line) > 80:
|
||||
line = line[:77] + "..."
|
||||
|
||||
if arguments:
|
||||
print(f"{Colors.YELLOW}Parameters:{Colors.RESET}")
|
||||
for key, value in arguments.items():
|
||||
value_str = str(value)
|
||||
if len(value_str) > 100:
|
||||
value_str = value_str[:100] + "..."
|
||||
print(f" {Colors.CYAN}• {key}:{Colors.RESET} {value_str}")
|
||||
|
||||
if result is not None and status != "running":
|
||||
print(f"\n{Colors.YELLOW}Result:{Colors.RESET}")
|
||||
result_str = json.dumps(result, indent=2) if isinstance(result, dict) else str(result)
|
||||
|
||||
if len(result_str) > 500:
|
||||
result_str = result_str[:500] + f"\n{Colors.GRAY}... (truncated){Colors.RESET}"
|
||||
|
||||
if status == "success":
|
||||
print(f"{Colors.GREEN}{result_str}{Colors.RESET}")
|
||||
elif status == "error":
|
||||
print(f"{Colors.RED}{result_str}{Colors.RESET}")
|
||||
else:
|
||||
print(result_str)
|
||||
|
||||
print(f"{Colors.BOLD}{'═' * 80}{Colors.RESET}\n")
|
||||
print(f"{Colors.GRAY}{line}{Colors.RESET}")
|
||||
|
||||
def print_autonomous_header(task):
|
||||
print(f"{Colors.BOLD}Task:{Colors.RESET} {task}")
|
||||
|
||||
81
tests/test_advanced_context.py
Normal file
81
tests/test_advanced_context.py
Normal file
@ -0,0 +1,81 @@
|
||||
import pytest
|
||||
from pr.core.advanced_context import AdvancedContextManager
|
||||
|
||||
def test_adaptive_context_window_simple():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
|
||||
window = mgr.adaptive_context_window(messages, 'simple')
|
||||
assert isinstance(window, int)
|
||||
assert window >= 10
|
||||
|
||||
def test_adaptive_context_window_medium():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
|
||||
window = mgr.adaptive_context_window(messages, 'medium')
|
||||
assert isinstance(window, int)
|
||||
assert window >= 20
|
||||
|
||||
def test_adaptive_context_window_complex():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
|
||||
window = mgr.adaptive_context_window(messages, 'complex')
|
||||
assert isinstance(window, int)
|
||||
assert window >= 35
|
||||
|
||||
def test_analyze_message_complexity():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'hello world'}, {'content': 'hello again'}]
|
||||
score = mgr._analyze_message_complexity(messages)
|
||||
assert 0 <= score <= 1
|
||||
|
||||
def test_analyze_message_complexity_empty():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = []
|
||||
score = mgr._analyze_message_complexity(messages)
|
||||
assert score == 0
|
||||
|
||||
def test_extract_key_sentences():
|
||||
mgr = AdvancedContextManager()
|
||||
text = "This is the first sentence. This is the second sentence. This is a longer third sentence with more words."
|
||||
sentences = mgr.extract_key_sentences(text, 2)
|
||||
assert len(sentences) <= 2
|
||||
assert all(isinstance(s, str) for s in sentences)
|
||||
|
||||
def test_extract_key_sentences_empty():
|
||||
mgr = AdvancedContextManager()
|
||||
text = ""
|
||||
sentences = mgr.extract_key_sentences(text, 5)
|
||||
assert sentences == []
|
||||
|
||||
def test_advanced_summarize_messages():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = [{'content': 'Hello'}, {'content': 'How are you?'}]
|
||||
summary = mgr.advanced_summarize_messages(messages)
|
||||
assert isinstance(summary, str)
|
||||
|
||||
def test_advanced_summarize_messages_empty():
|
||||
mgr = AdvancedContextManager()
|
||||
messages = []
|
||||
summary = mgr.advanced_summarize_messages(messages)
|
||||
assert summary == "No content to summarize."
|
||||
|
||||
def test_score_message_relevance():
|
||||
mgr = AdvancedContextManager()
|
||||
message = {'content': 'hello world'}
|
||||
context = 'world hello'
|
||||
score = mgr.score_message_relevance(message, context)
|
||||
assert 0 <= score <= 1
|
||||
|
||||
def test_score_message_relevance_no_overlap():
|
||||
mgr = AdvancedContextManager()
|
||||
message = {'content': 'hello'}
|
||||
context = 'world'
|
||||
score = mgr.score_message_relevance(message, context)
|
||||
assert score == 0
|
||||
|
||||
def test_score_message_relevance_empty():
|
||||
mgr = AdvancedContextManager()
|
||||
message = {'content': ''}
|
||||
context = ''
|
||||
score = mgr.score_message_relevance(message, context)
|
||||
assert score == 0
|
||||
63
tests/test_api.py
Normal file
63
tests/test_api.py
Normal file
@ -0,0 +1,63 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import json
|
||||
import urllib.error
|
||||
from pr.core.api import call_api, list_models
|
||||
|
||||
|
||||
class TestApi(unittest.TestCase):
|
||||
|
||||
@patch('pr.core.api.urllib.request.urlopen')
|
||||
@patch('pr.core.api.auto_slim_messages')
|
||||
def test_call_api_success(self, mock_slim, mock_urlopen):
|
||||
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}'
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
|
||||
result = call_api([], 'model', 'http://url', 'key', True, [{'name': 'tool'}])
|
||||
|
||||
self.assertIn('choices', result)
|
||||
mock_urlopen.assert_called_once()
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
@patch('pr.core.api.auto_slim_messages')
|
||||
def test_call_api_http_error(self, mock_slim, mock_urlopen):
|
||||
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
|
||||
mock_urlopen.side_effect = urllib.error.HTTPError('http://url', 500, 'error', None, MagicMock())
|
||||
|
||||
result = call_api([], 'model', 'http://url', 'key', False, [])
|
||||
|
||||
self.assertIn('error', result)
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
@patch('pr.core.api.auto_slim_messages')
|
||||
def test_call_api_general_error(self, mock_slim, mock_urlopen):
|
||||
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
|
||||
mock_urlopen.side_effect = Exception('test error')
|
||||
|
||||
result = call_api([], 'model', 'http://url', 'key', False, [])
|
||||
|
||||
self.assertIn('error', result)
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
def test_list_models_success(self, mock_urlopen):
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b'{"data": [{"id": "model1"}]}'
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
|
||||
result = list_models('http://url', 'key')
|
||||
|
||||
self.assertEqual(result, [{'id': 'model1'}])
|
||||
|
||||
@patch('urllib.request.urlopen')
|
||||
def test_list_models_error(self, mock_urlopen):
|
||||
mock_urlopen.side_effect = Exception('error')
|
||||
|
||||
result = list_models('http://url', 'key')
|
||||
|
||||
self.assertIn('error', result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
94
tests/test_assistant.py
Normal file
94
tests/test_assistant.py
Normal file
@ -0,0 +1,94 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import tempfile
|
||||
import os
|
||||
from pr.core.assistant import Assistant, process_message
|
||||
|
||||
|
||||
class TestAssistant(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.args = MagicMock()
|
||||
self.args.verbose = False
|
||||
self.args.debug = False
|
||||
self.args.no_syntax = False
|
||||
self.args.model = 'test-model'
|
||||
self.args.api_url = 'test-url'
|
||||
self.args.model_list_url = 'test-list-url'
|
||||
|
||||
@patch('sqlite3.connect')
|
||||
@patch('os.environ.get')
|
||||
@patch('pr.core.context.init_system_message')
|
||||
@patch('pr.core.enhanced_assistant.EnhancedAssistant')
|
||||
def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite):
|
||||
mock_env.side_effect = lambda key, default: {'OPENROUTER_API_KEY': 'key', 'AI_MODEL': 'model', 'API_URL': 'url', 'MODEL_LIST_URL': 'list', 'USE_TOOLS': '1', 'STRICT_MODE': '0'}.get(key, default)
|
||||
mock_conn = MagicMock()
|
||||
mock_sqlite.return_value = mock_conn
|
||||
mock_init_sys.return_value = {'role': 'system', 'content': 'sys'}
|
||||
|
||||
assistant = Assistant(self.args)
|
||||
|
||||
self.assertEqual(assistant.api_key, 'key')
|
||||
self.assertEqual(assistant.model, 'test-model')
|
||||
mock_sqlite.assert_called_once()
|
||||
|
||||
@patch('pr.core.assistant.call_api')
|
||||
@patch('pr.core.assistant.render_markdown')
|
||||
def test_process_response_no_tools(self, mock_render, mock_call):
|
||||
assistant = MagicMock()
|
||||
assistant.messages = MagicMock()
|
||||
assistant.verbose = False
|
||||
assistant.syntax_highlighting = True
|
||||
mock_render.return_value = 'rendered'
|
||||
|
||||
response = {'choices': [{'message': {'content': 'content'}}]}
|
||||
|
||||
result = Assistant.process_response(assistant, response)
|
||||
|
||||
self.assertEqual(result, 'rendered')
|
||||
assistant.messages.append.assert_called_with({'content': 'content'})
|
||||
|
||||
@patch('pr.core.assistant.call_api')
|
||||
@patch('pr.core.assistant.render_markdown')
|
||||
@patch('pr.core.assistant.get_tools_definition')
|
||||
def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call):
|
||||
assistant = MagicMock()
|
||||
assistant.messages = MagicMock()
|
||||
assistant.verbose = False
|
||||
assistant.syntax_highlighting = True
|
||||
assistant.use_tools = True
|
||||
assistant.model = 'model'
|
||||
assistant.api_url = 'url'
|
||||
assistant.api_key = 'key'
|
||||
mock_tools_def.return_value = []
|
||||
mock_call.return_value = {'choices': [{'message': {'content': 'follow'}}]}
|
||||
|
||||
response = {'choices': [{'message': {'tool_calls': [{'id': '1', 'function': {'name': 'test', 'arguments': '{}'}}]}}]}
|
||||
|
||||
with patch.object(assistant, 'execute_tool_calls', return_value=[{'role': 'tool', 'content': 'result'}]):
|
||||
result = Assistant.process_response(assistant, response)
|
||||
|
||||
mock_call.assert_called()
|
||||
|
||||
@patch('pr.core.assistant.call_api')
|
||||
@patch('pr.core.assistant.get_tools_definition')
|
||||
def test_process_message(self, mock_tools, mock_call):
|
||||
assistant = MagicMock()
|
||||
assistant.messages = MagicMock()
|
||||
assistant.verbose = False
|
||||
assistant.use_tools = True
|
||||
assistant.model = 'model'
|
||||
assistant.api_url = 'url'
|
||||
assistant.api_key = 'key'
|
||||
mock_tools.return_value = []
|
||||
mock_call.return_value = {'choices': [{'message': {'content': 'response'}}]}
|
||||
|
||||
with patch('pr.core.assistant.render_markdown', return_value='rendered'):
|
||||
with patch('builtins.print'):
|
||||
process_message(assistant, 'test message')
|
||||
|
||||
assistant.messages.append.assert_called_with({'role': 'user', 'content': 'test message'})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
56
tests/test_config_loader.py
Normal file
56
tests/test_config_loader.py
Normal file
@ -0,0 +1,56 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, mock_open
|
||||
import os
|
||||
from pr.core.config_loader import load_config, _load_config_file, _parse_value, create_default_config
|
||||
|
||||
def test_parse_value_string():
|
||||
assert _parse_value('hello') == 'hello'
|
||||
|
||||
def test_parse_value_int():
|
||||
assert _parse_value('123') == 123
|
||||
|
||||
def test_parse_value_float():
|
||||
assert _parse_value('1.23') == 1.23
|
||||
|
||||
def test_parse_value_bool_true():
|
||||
assert _parse_value('true') == True
|
||||
|
||||
def test_parse_value_bool_false():
|
||||
assert _parse_value('false') == False
|
||||
|
||||
def test_parse_value_bool_upper():
|
||||
assert _parse_value('TRUE') == True
|
||||
|
||||
@patch('os.path.exists', return_value=False)
|
||||
def test_load_config_file_not_exists(mock_exists):
|
||||
config = _load_config_file('test.ini')
|
||||
assert config == {}
|
||||
|
||||
@patch('os.path.exists', return_value=True)
|
||||
@patch('configparser.ConfigParser')
|
||||
def test_load_config_file_exists(mock_parser_class, mock_exists):
|
||||
mock_parser = mock_parser_class.return_value
|
||||
mock_parser.sections.return_value = ['api']
|
||||
mock_parser.items.return_value = [('key', 'value')]
|
||||
config = _load_config_file('test.ini')
|
||||
assert 'api' in config
|
||||
assert config['api']['key'] == 'value'
|
||||
|
||||
@patch('pr.core.config_loader._load_config_file')
|
||||
def test_load_config(mock_load):
|
||||
mock_load.side_effect = [{'api': {'key': 'global'}}, {'api': {'key': 'local'}}]
|
||||
config = load_config()
|
||||
assert config['api']['key'] == 'local'
|
||||
|
||||
@patch('builtins.open', new_callable=mock_open)
|
||||
def test_create_default_config(mock_file):
|
||||
result = create_default_config('test.ini')
|
||||
assert result == True
|
||||
mock_file.assert_called_once_with('test.ini', 'w')
|
||||
handle = mock_file()
|
||||
handle.write.assert_called_once()
|
||||
|
||||
@patch('builtins.open', side_effect=Exception('error'))
|
||||
def test_create_default_config_error(mock_file):
|
||||
result = create_default_config('test.ini')
|
||||
assert result == False
|
||||
118
tests/test_main.py
Normal file
118
tests/test_main.py
Normal file
@ -0,0 +1,118 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import sys
|
||||
from pr.__main__ import main
|
||||
|
||||
def test_main_version(capsys):
|
||||
with patch('sys.argv', ['pr', '--version']):
|
||||
with pytest.raises(SystemExit):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'PR Assistant' in captured.out
|
||||
|
||||
def test_main_create_config_success(capsys):
|
||||
with patch('pr.core.config_loader.create_default_config', return_value=True):
|
||||
with patch('sys.argv', ['pr', '--create-config']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Configuration file created' in captured.out
|
||||
|
||||
def test_main_create_config_fail(capsys):
|
||||
with patch('pr.core.config_loader.create_default_config', return_value=False):
|
||||
with patch('sys.argv', ['pr', '--create-config']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Error creating configuration file' in captured.err
|
||||
|
||||
def test_main_list_sessions_no_sessions(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.list_sessions.return_value = []
|
||||
with patch('sys.argv', ['pr', '--list-sessions']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'No saved sessions found' in captured.out
|
||||
|
||||
def test_main_list_sessions_with_sessions(capsys):
|
||||
sessions = [{'name': 'test', 'created_at': '2023-01-01', 'message_count': 5}]
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.list_sessions.return_value = sessions
|
||||
with patch('sys.argv', ['pr', '--list-sessions']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Found 1 saved sessions' in captured.out
|
||||
assert 'test' in captured.out
|
||||
|
||||
def test_main_delete_session_success(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.delete_session.return_value = True
|
||||
with patch('sys.argv', ['pr', '--delete-session', 'test']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert "Session 'test' deleted" in captured.out
|
||||
|
||||
def test_main_delete_session_fail(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.delete_session.return_value = False
|
||||
with patch('sys.argv', ['pr', '--delete-session', 'test']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert "Error deleting session 'test'" in captured.err
|
||||
|
||||
def test_main_export_session_json(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.export_session.return_value = True
|
||||
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.json']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Session exported to output.json' in captured.out
|
||||
|
||||
def test_main_export_session_md(capsys):
|
||||
with patch('pr.core.session.SessionManager') as mock_sm:
|
||||
mock_instance = mock_sm.return_value
|
||||
mock_instance.export_session.return_value = True
|
||||
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.md']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Session exported to output.md' in captured.out
|
||||
|
||||
def test_main_usage(capsys):
|
||||
usage = {'total_requests': 10, 'total_tokens': 1000, 'total_cost': 0.01}
|
||||
with patch('pr.core.usage_tracker.UsageTracker.get_total_usage', return_value=usage):
|
||||
with patch('sys.argv', ['pr', '--usage']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Total Usage Statistics' in captured.out
|
||||
assert 'Requests: 10' in captured.out
|
||||
|
||||
def test_main_plugins_no_plugins(capsys):
|
||||
with patch('pr.plugins.loader.PluginLoader') as mock_loader:
|
||||
mock_instance = mock_loader.return_value
|
||||
mock_instance.load_plugins.return_value = None
|
||||
mock_instance.list_loaded_plugins.return_value = []
|
||||
with patch('sys.argv', ['pr', '--plugins']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'No plugins loaded' in captured.out
|
||||
|
||||
def test_main_plugins_with_plugins(capsys):
|
||||
with patch('pr.plugins.loader.PluginLoader') as mock_loader:
|
||||
mock_instance = mock_loader.return_value
|
||||
mock_instance.load_plugins.return_value = None
|
||||
mock_instance.list_loaded_plugins.return_value = ['plugin1', 'plugin2']
|
||||
with patch('sys.argv', ['pr', '--plugins']):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
assert 'Loaded 2 plugins' in captured.out
|
||||
|
||||
def test_main_run_assistant():
|
||||
with patch('pr.__main__.Assistant') as mock_assistant:
|
||||
mock_instance = mock_assistant.return_value
|
||||
with patch('sys.argv', ['pr', 'test message']):
|
||||
main()
|
||||
mock_assistant.assert_called_once()
|
||||
mock_instance.run.assert_called_once()
|
||||
Loading…
Reference in New Issue
Block a user