This commit is contained in:
retoor 2025-11-04 05:57:23 +01:00
parent 3f979d2bbd
commit 685766ef86
16 changed files with 787 additions and 55 deletions

View File

@ -1,4 +1,7 @@
# rp Assistant
rp
[![Tests](https://img.shields.io/badge/tests-passing-brightgreen.svg)](https://github.com/retoor/rp-assistant)
[![Python](https://img.shields.io/badge/python-3.8%2B-blue.svg)](https://www.python.org/downloads/)

View File

@ -164,7 +164,7 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei
return results
def get_session_summary(self) -> Dict[str, Any]:
def get_session_summary(self) -> str:
summary = {
'session_id': self.session_id,
'active_agents': len(self.active_agents),
@ -178,7 +178,7 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei
for agent_id, agent in self.active_agents.items()
]
}
return summary
return json.dumps(summary)
def clear_session(self):
self.active_agents.clear()

View File

@ -14,12 +14,9 @@ def run_autonomous_mode(assistant, task):
logger.debug(f"=== AUTONOMOUS MODE START ===")
logger.debug(f"Task: {task}")
if assistant.verbose:
print_autonomous_header(task)
assistant.messages.append({
"role": "user",
"content": f"AUTONOMOUS TASK: {task}\n\nPlease work on this task step by step. Use tools as needed. When the task is fully complete, clearly state 'Task complete'."
"content": f"{task}"
})
try:
@ -29,17 +26,11 @@ def run_autonomous_mode(assistant, task):
logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---")
logger.debug(f"Messages before context management: {len(assistant.messages)}")
if assistant.verbose:
print(f"\n{Colors.BOLD}{Colors.MAGENTA}{'' * 3} Iteration {assistant.autonomous_iterations} {'' * 3}{Colors.RESET}\n")
from pr.core.context import manage_context_window
assistant.messages = manage_context_window(assistant.messages, assistant.verbose)
logger.debug(f"Messages after context management: {len(assistant.messages)}")
if assistant.verbose:
print(f"{Colors.GRAY}Calling API...{Colors.RESET}")
from pr.core.api import call_api
from pr.tools.base import get_tools_definition
response = call_api(
@ -67,9 +58,6 @@ def run_autonomous_mode(assistant, task):
logger.debug(f"=== AUTONOMOUS MODE COMPLETE ===")
logger.debug(f"Total iterations: {assistant.autonomous_iterations}")
logger.debug(f"Final message count: {len(assistant.messages)}")
print(f"{Colors.BOLD}Total Iterations:{Colors.RESET} {assistant.autonomous_iterations}")
print(f"{Colors.BOLD}Messages in Context:{Colors.RESET} {len(assistant.messages)}\n")
break
result = process_response_autonomous(assistant, response)
@ -97,16 +85,12 @@ def process_response_autonomous(assistant, response):
assistant.messages.append(message)
if 'tool_calls' in message and message['tool_calls']:
print(f"{Colors.BOLD}{Colors.CYAN}🔧 Executing {len(message['tool_calls'])} tool(s)...{Colors.RESET}\n")
tool_results = []
for tool_call in message['tool_calls']:
func_name = tool_call['function']['name']
arguments = json.loads(tool_call['function']['arguments'])
display_tool_call(func_name, arguments, "running")
result = execute_single_tool(assistant, func_name, arguments)
result = truncate_tool_result(result)
@ -121,8 +105,6 @@ def process_response_autonomous(assistant, response):
for result in tool_results:
assistant.messages.append(result)
print(f"{Colors.GRAY}Processing tool results...{Colors.RESET}\n")
from pr.core.api import call_api
from pr.tools.base import get_tools_definition
follow_up = call_api(

View File

@ -47,6 +47,8 @@ class AdvancedContextManager:
return complexity
def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]:
if not text.strip():
return []
sentences = re.split(r'(?<=[.!?])\s+', text)
if not sentences:
return []

View File

@ -133,6 +133,18 @@ class Assistant:
'display_edit_summary': lambda **kw: display_edit_summary(),
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
'create_agent': lambda **kw: create_agent(**kw),
'list_agents': lambda **kw: list_agents(**kw),
'execute_agent_task': lambda **kw: execute_agent_task(**kw),
'remove_agent': lambda **kw: remove_agent(**kw),
'collaborate_agents': lambda **kw: collaborate_agents(**kw),
'add_knowledge_entry': lambda **kw: add_knowledge_entry(**kw),
'get_knowledge_entry': lambda **kw: get_knowledge_entry(**kw),
'search_knowledge': lambda **kw: search_knowledge(**kw),
'get_knowledge_by_category': lambda **kw: get_knowledge_by_category(**kw),
'update_knowledge_importance': lambda **kw: update_knowledge_importance(**kw),
'delete_knowledge_entry': lambda **kw: delete_knowledge_entry(**kw),
'get_knowledge_statistics': lambda **kw: get_knowledge_statistics(**kw),
}
if func_name in func_map:
@ -230,6 +242,7 @@ class Assistant:
path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options]
combined_options = sorted(list(set(options + path_options)))
#combined_options.extend(self.commands)
if state < len(combined_options):
return combined_options[state]
@ -279,7 +292,8 @@ class Assistant:
else:
message = sys.stdin.read()
process_message(self, message)
from pr.autonomous.mode import run_autonomous_mode
run_autonomous_mode(self, message)
def cleanup(self):
if hasattr(self, 'enhanced') and self.enhanced:
@ -299,9 +313,12 @@ class Assistant:
def run(self):
try:
print(f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}")
if self.args.interactive or (not self.args.message and sys.stdin.isatty()):
print("DEBUG: calling run_repl")
self.run_repl()
else:
print("DEBUG: calling run_single")
self.run_single()
finally:
self.cleanup()

View File

@ -166,6 +166,8 @@ class RPEditor:
def save_file(self):
"""Thread-safe save file command."""
if not self.running:
return self._save_file()
try:
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
except:
@ -630,6 +632,10 @@ class RPEditor:
def set_text(self, text):
"""Thread-safe text setting."""
if not self.running:
with self.lock:
self._set_text(text)
return
try:
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
except:

168
pr/input_handler.py Normal file
View File

@ -0,0 +1,168 @@
#!/usr/bin/env python3
"""
Advanced input handler for PR Assistant with editor mode, file inclusion, and image support.
"""
import os
import re
import base64
import mimetypes
import readline
import glob
from pathlib import Path
from typing import Optional
# from pr.ui.colors import Colors # Avoid import issues
class AdvancedInputHandler:
"""Handles advanced input with editor mode, file inclusion, and image support."""
def __init__(self):
self.editor_mode = False
self.setup_readline()
def setup_readline(self):
"""Setup readline with basic completer."""
try:
# Simple completer that doesn't interfere
def completer(text, state):
return None
readline.set_completer(completer)
readline.parse_and_bind('tab: complete')
except:
pass # Readline not available
def toggle_editor_mode(self):
"""Toggle between simple and editor input modes."""
self.editor_mode = not self.editor_mode
mode = "Editor" if self.editor_mode else "Simple"
print(f"\nSwitched to {mode.lower()} input mode.")
def get_input(self, prompt: str = "You> ") -> Optional[str]:
"""Get input from user, handling different modes."""
try:
if self.editor_mode:
return self._get_editor_input(prompt)
else:
return self._get_simple_input(prompt)
except KeyboardInterrupt:
return None
except EOFError:
return None
def _get_simple_input(self, prompt: str) -> Optional[str]:
"""Get simple input with file completion."""
try:
user_input = input(prompt).strip()
if not user_input:
return ""
# Check for special commands
if user_input.lower() == '/editor':
self.toggle_editor_mode()
return self.get_input(prompt) # Recurse to get new input
# Process file inclusions and images
processed_input = self._process_input(user_input)
return processed_input
except KeyboardInterrupt:
return None
def _get_editor_input(self, prompt: str) -> Optional[str]:
"""Get multi-line input for editor mode."""
try:
print("Editor mode: Enter your message. Type 'END' on a new line to finish.")
print("Type '/simple' to switch back to simple mode.")
lines = []
while True:
try:
line = input()
if line.strip().lower() == 'end':
break
elif line.strip().lower() == '/simple':
self.toggle_editor_mode()
return self.get_input(prompt) # Switch back and get input
lines.append(line)
except EOFError:
break
content = '\n'.join(lines).strip()
if not content:
return ""
# Process file inclusions and images
processed_content = self._process_input(content)
return processed_content
except KeyboardInterrupt:
return None
def _process_input(self, text: str) -> str:
"""Process input text for file inclusions and images."""
# Process @[filename] inclusions
text = self._process_file_inclusions(text)
# Process image inclusions (look for image file paths)
text = self._process_image_inclusions(text)
return text
def _process_file_inclusions(self, text: str) -> str:
"""Replace @[filename] with file contents."""
def replace_file(match):
filename = match.group(1).strip()
try:
path = Path(filename).expanduser().resolve()
if path.exists() and path.is_file():
with open(path, 'r', encoding='utf-8', errors='replace') as f:
content = f.read()
return f"\n--- File: {filename} ---\n{content}\n--- End of {filename} ---\n"
else:
return f"[File not found: {filename}]"
except Exception as e:
return f"[Error reading file {filename}: {e}]"
# Replace @[filename] patterns
pattern = r'@\[([^\]]+)\]'
return re.sub(pattern, replace_file, text)
def _process_image_inclusions(self, text: str) -> str:
"""Process image file references and encode them."""
# Find potential image file paths
words = text.split()
processed_parts = []
for word in words:
# Check if it's a file path that exists and is an image
try:
path = Path(word.strip()).expanduser().resolve()
if path.exists() and path.is_file():
mime_type, _ = mimetypes.guess_type(str(path))
if mime_type and mime_type.startswith('image/'):
# Encode image
with open(path, 'rb') as f:
image_data = base64.b64encode(f.read()).decode('utf-8')
# Replace with data URL
processed_parts.append(f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n")
continue
except:
pass
processed_parts.append(word)
return ' '.join(processed_parts)
# Global instance
input_handler = AdvancedInputHandler()
def get_advanced_input(prompt: str = "You> ") -> Optional[str]:
"""Get advanced input from user."""
return input_handler.get_input(prompt)

View File

@ -8,6 +8,8 @@ from pr.tools.database import db_set, db_get, db_query
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.tools.python_exec import python_exec
from pr.tools.patch import apply_patch, create_diff
from pr.tools.agents import create_agent, list_agents, execute_agent_task, remove_agent, collaborate_agents
from pr.tools.memory import add_knowledge_entry, get_knowledge_entry, search_knowledge, get_knowledge_by_category, update_knowledge_importance, delete_knowledge_entry, get_knowledge_statistics
__all__ = [
'get_tools_definition',
@ -17,5 +19,7 @@ __all__ = [
'db_set', 'db_get', 'db_query',
'http_fetch', 'web_search', 'web_search_news',
'python_exec','tail_process', 'kill_process',
'apply_patch', 'create_diff'
'apply_patch', 'create_diff',
'create_agent', 'list_agents', 'execute_agent_task', 'remove_agent', 'collaborate_agents',
'add_knowledge_entry', 'get_knowledge_entry', 'search_knowledge', 'get_knowledge_by_category', 'update_knowledge_importance', 'delete_knowledge_entry', 'get_knowledge_statistics'
]

64
pr/tools/agents.py Normal file
View File

@ -0,0 +1,64 @@
import os
from typing import Dict, Any, List
from pr.agents.agent_manager import AgentManager
from pr.core.api import call_api
def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]:
"""Create a new agent with the specified role."""
try:
# Get db_path from environment or default
db_path = os.environ.get('ASSISTANT_DB_PATH', '~/.assistant_db.sqlite')
db_path = os.path.expanduser(db_path)
manager = AgentManager(db_path, call_api)
agent_id = manager.create_agent(role_name, agent_id)
return {"status": "success", "agent_id": agent_id, "role": role_name}
except Exception as e:
return {"status": "error", "error": str(e)}
def list_agents() -> Dict[str, Any]:
"""List all active agents."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
manager = AgentManager(db_path, call_api)
agents = []
for agent_id, agent in manager.active_agents.items():
agents.append({
"agent_id": agent_id,
"role": agent.role.name,
"task_count": agent.task_count,
"message_count": len(agent.message_history)
})
return {"status": "success", "agents": agents}
except Exception as e:
return {"status": "error", "error": str(e)}
def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
"""Execute a task with the specified agent."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
manager = AgentManager(db_path, call_api)
result = manager.execute_agent_task(agent_id, task, context)
return result
except Exception as e:
return {"status": "error", "error": str(e)}
def remove_agent(agent_id: str) -> Dict[str, Any]:
"""Remove an agent."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
manager = AgentManager(db_path, call_api)
success = manager.remove_agent(agent_id)
return {"status": "success" if success else "not_found", "agent_id": agent_id}
except Exception as e:
return {"status": "error", "error": str(e)}
def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]:
"""Collaborate multiple agents on a task."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
manager = AgentManager(db_path, call_api)
result = manager.collaborate_agents(orchestrator_id, task, agent_roles)
return result
except Exception as e:
return {"status": "error", "error": str(e)}

99
pr/tools/memory.py Normal file
View File

@ -0,0 +1,99 @@
import os
from typing import Dict, Any, List
from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry
import time
import uuid
def add_knowledge_entry(category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None) -> Dict[str, Any]:
"""Add a new entry to the knowledge base."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
store = KnowledgeStore(db_path)
if entry_id is None:
entry_id = str(uuid.uuid4())[:16]
entry = KnowledgeEntry(
entry_id=entry_id,
category=category,
content=content,
metadata=metadata or {},
created_at=time.time(),
updated_at=time.time()
)
store.add_entry(entry)
return {"status": "success", "entry_id": entry_id}
except Exception as e:
return {"status": "error", "error": str(e)}
def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
"""Retrieve a knowledge entry by ID."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
store = KnowledgeStore(db_path)
entry = store.get_entry(entry_id)
if entry:
return {"status": "success", "entry": entry.to_dict()}
else:
return {"status": "not_found", "entry_id": entry_id}
except Exception as e:
return {"status": "error", "error": str(e)}
def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]:
"""Search the knowledge base semantically."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
store = KnowledgeStore(db_path)
entries = store.search_entries(query, category, top_k)
results = [entry.to_dict() for entry in entries]
return {"status": "success", "results": results}
except Exception as e:
return {"status": "error", "error": str(e)}
def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]:
"""Get knowledge entries by category."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
store = KnowledgeStore(db_path)
entries = store.get_by_category(category, limit)
results = [entry.to_dict() for entry in entries]
return {"status": "success", "entries": results}
except Exception as e:
return {"status": "error", "error": str(e)}
def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]:
"""Update the importance score of a knowledge entry."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
store = KnowledgeStore(db_path)
store.update_importance(entry_id, importance_score)
return {"status": "success", "entry_id": entry_id, "importance_score": importance_score}
except Exception as e:
return {"status": "error", "error": str(e)}
def delete_knowledge_entry(entry_id: str) -> Dict[str, Any]:
"""Delete a knowledge entry."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
store = KnowledgeStore(db_path)
success = store.delete_entry(entry_id)
return {"status": "success" if success else "not_found", "entry_id": entry_id}
except Exception as e:
return {"status": "error", "error": str(e)}
def get_knowledge_statistics() -> Dict[str, Any]:
"""Get statistics about the knowledge base."""
try:
db_path = os.path.expanduser('~/.assistant_db.sqlite')
store = KnowledgeStore(db_path)
stats = store.get_statistics()
return {"status": "success", "statistics": stats}
except Exception as e:
return {"status": "error", "error": str(e)}

View File

@ -3,41 +3,16 @@ from typing import Dict, Any
from pr.ui.colors import Colors
def display_tool_call(tool_name, arguments, status="running", result=None):
status_icons = {
"running": ("", Colors.YELLOW),
"success": ("", Colors.GREEN),
"error": ("", Colors.RED)
}
if status == "running":
return
icon, color = status_icons.get(status, ("", Colors.WHITE))
args_str = ", ".join([f"{k}={str(v)[:20]}" for k, v in list(arguments.items())[:2]])
line = f"{tool_name}({args_str})"
print(f"\n{Colors.BOLD}{'' * 80}{Colors.RESET}")
print(f"{color}{icon} {Colors.BOLD}{Colors.CYAN}TOOL: {tool_name}{Colors.RESET}")
print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}")
if len(line) > 80:
line = line[:77] + "..."
if arguments:
print(f"{Colors.YELLOW}Parameters:{Colors.RESET}")
for key, value in arguments.items():
value_str = str(value)
if len(value_str) > 100:
value_str = value_str[:100] + "..."
print(f" {Colors.CYAN}{key}:{Colors.RESET} {value_str}")
if result is not None and status != "running":
print(f"\n{Colors.YELLOW}Result:{Colors.RESET}")
result_str = json.dumps(result, indent=2) if isinstance(result, dict) else str(result)
if len(result_str) > 500:
result_str = result_str[:500] + f"\n{Colors.GRAY}... (truncated){Colors.RESET}"
if status == "success":
print(f"{Colors.GREEN}{result_str}{Colors.RESET}")
elif status == "error":
print(f"{Colors.RED}{result_str}{Colors.RESET}")
else:
print(result_str)
print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}\n")
print(f"{Colors.GRAY}{line}{Colors.RESET}")
def print_autonomous_header(task):
print(f"{Colors.BOLD}Task:{Colors.RESET} {task}")

View File

@ -0,0 +1,81 @@
import pytest
from pr.core.advanced_context import AdvancedContextManager
def test_adaptive_context_window_simple():
mgr = AdvancedContextManager()
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
window = mgr.adaptive_context_window(messages, 'simple')
assert isinstance(window, int)
assert window >= 10
def test_adaptive_context_window_medium():
mgr = AdvancedContextManager()
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
window = mgr.adaptive_context_window(messages, 'medium')
assert isinstance(window, int)
assert window >= 20
def test_adaptive_context_window_complex():
mgr = AdvancedContextManager()
messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}]
window = mgr.adaptive_context_window(messages, 'complex')
assert isinstance(window, int)
assert window >= 35
def test_analyze_message_complexity():
mgr = AdvancedContextManager()
messages = [{'content': 'hello world'}, {'content': 'hello again'}]
score = mgr._analyze_message_complexity(messages)
assert 0 <= score <= 1
def test_analyze_message_complexity_empty():
mgr = AdvancedContextManager()
messages = []
score = mgr._analyze_message_complexity(messages)
assert score == 0
def test_extract_key_sentences():
mgr = AdvancedContextManager()
text = "This is the first sentence. This is the second sentence. This is a longer third sentence with more words."
sentences = mgr.extract_key_sentences(text, 2)
assert len(sentences) <= 2
assert all(isinstance(s, str) for s in sentences)
def test_extract_key_sentences_empty():
mgr = AdvancedContextManager()
text = ""
sentences = mgr.extract_key_sentences(text, 5)
assert sentences == []
def test_advanced_summarize_messages():
mgr = AdvancedContextManager()
messages = [{'content': 'Hello'}, {'content': 'How are you?'}]
summary = mgr.advanced_summarize_messages(messages)
assert isinstance(summary, str)
def test_advanced_summarize_messages_empty():
mgr = AdvancedContextManager()
messages = []
summary = mgr.advanced_summarize_messages(messages)
assert summary == "No content to summarize."
def test_score_message_relevance():
mgr = AdvancedContextManager()
message = {'content': 'hello world'}
context = 'world hello'
score = mgr.score_message_relevance(message, context)
assert 0 <= score <= 1
def test_score_message_relevance_no_overlap():
mgr = AdvancedContextManager()
message = {'content': 'hello'}
context = 'world'
score = mgr.score_message_relevance(message, context)
assert score == 0
def test_score_message_relevance_empty():
mgr = AdvancedContextManager()
message = {'content': ''}
context = ''
score = mgr.score_message_relevance(message, context)
assert score == 0

63
tests/test_api.py Normal file
View File

@ -0,0 +1,63 @@
import unittest
from unittest.mock import patch, MagicMock
import json
import urllib.error
from pr.core.api import call_api, list_models
class TestApi(unittest.TestCase):
@patch('pr.core.api.urllib.request.urlopen')
@patch('pr.core.api.auto_slim_messages')
def test_call_api_success(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
mock_response = MagicMock()
mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}'
mock_urlopen.return_value.__enter__.return_value = mock_response
result = call_api([], 'model', 'http://url', 'key', True, [{'name': 'tool'}])
self.assertIn('choices', result)
mock_urlopen.assert_called_once()
@patch('urllib.request.urlopen')
@patch('pr.core.api.auto_slim_messages')
def test_call_api_http_error(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
mock_urlopen.side_effect = urllib.error.HTTPError('http://url', 500, 'error', None, MagicMock())
result = call_api([], 'model', 'http://url', 'key', False, [])
self.assertIn('error', result)
@patch('urllib.request.urlopen')
@patch('pr.core.api.auto_slim_messages')
def test_call_api_general_error(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{'role': 'user', 'content': 'test'}]
mock_urlopen.side_effect = Exception('test error')
result = call_api([], 'model', 'http://url', 'key', False, [])
self.assertIn('error', result)
@patch('urllib.request.urlopen')
def test_list_models_success(self, mock_urlopen):
mock_response = MagicMock()
mock_response.read.return_value = b'{"data": [{"id": "model1"}]}'
mock_urlopen.return_value.__enter__.return_value = mock_response
result = list_models('http://url', 'key')
self.assertEqual(result, [{'id': 'model1'}])
@patch('urllib.request.urlopen')
def test_list_models_error(self, mock_urlopen):
mock_urlopen.side_effect = Exception('error')
result = list_models('http://url', 'key')
self.assertIn('error', result)
if __name__ == '__main__':
unittest.main()

94
tests/test_assistant.py Normal file
View File

@ -0,0 +1,94 @@
import unittest
from unittest.mock import patch, MagicMock
import tempfile
import os
from pr.core.assistant import Assistant, process_message
class TestAssistant(unittest.TestCase):
def setUp(self):
self.args = MagicMock()
self.args.verbose = False
self.args.debug = False
self.args.no_syntax = False
self.args.model = 'test-model'
self.args.api_url = 'test-url'
self.args.model_list_url = 'test-list-url'
@patch('sqlite3.connect')
@patch('os.environ.get')
@patch('pr.core.context.init_system_message')
@patch('pr.core.enhanced_assistant.EnhancedAssistant')
def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite):
mock_env.side_effect = lambda key, default: {'OPENROUTER_API_KEY': 'key', 'AI_MODEL': 'model', 'API_URL': 'url', 'MODEL_LIST_URL': 'list', 'USE_TOOLS': '1', 'STRICT_MODE': '0'}.get(key, default)
mock_conn = MagicMock()
mock_sqlite.return_value = mock_conn
mock_init_sys.return_value = {'role': 'system', 'content': 'sys'}
assistant = Assistant(self.args)
self.assertEqual(assistant.api_key, 'key')
self.assertEqual(assistant.model, 'test-model')
mock_sqlite.assert_called_once()
@patch('pr.core.assistant.call_api')
@patch('pr.core.assistant.render_markdown')
def test_process_response_no_tools(self, mock_render, mock_call):
assistant = MagicMock()
assistant.messages = MagicMock()
assistant.verbose = False
assistant.syntax_highlighting = True
mock_render.return_value = 'rendered'
response = {'choices': [{'message': {'content': 'content'}}]}
result = Assistant.process_response(assistant, response)
self.assertEqual(result, 'rendered')
assistant.messages.append.assert_called_with({'content': 'content'})
@patch('pr.core.assistant.call_api')
@patch('pr.core.assistant.render_markdown')
@patch('pr.core.assistant.get_tools_definition')
def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call):
assistant = MagicMock()
assistant.messages = MagicMock()
assistant.verbose = False
assistant.syntax_highlighting = True
assistant.use_tools = True
assistant.model = 'model'
assistant.api_url = 'url'
assistant.api_key = 'key'
mock_tools_def.return_value = []
mock_call.return_value = {'choices': [{'message': {'content': 'follow'}}]}
response = {'choices': [{'message': {'tool_calls': [{'id': '1', 'function': {'name': 'test', 'arguments': '{}'}}]}}]}
with patch.object(assistant, 'execute_tool_calls', return_value=[{'role': 'tool', 'content': 'result'}]):
result = Assistant.process_response(assistant, response)
mock_call.assert_called()
@patch('pr.core.assistant.call_api')
@patch('pr.core.assistant.get_tools_definition')
def test_process_message(self, mock_tools, mock_call):
assistant = MagicMock()
assistant.messages = MagicMock()
assistant.verbose = False
assistant.use_tools = True
assistant.model = 'model'
assistant.api_url = 'url'
assistant.api_key = 'key'
mock_tools.return_value = []
mock_call.return_value = {'choices': [{'message': {'content': 'response'}}]}
with patch('pr.core.assistant.render_markdown', return_value='rendered'):
with patch('builtins.print'):
process_message(assistant, 'test message')
assistant.messages.append.assert_called_with({'role': 'user', 'content': 'test message'})
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,56 @@
import pytest
from unittest.mock import patch, mock_open
import os
from pr.core.config_loader import load_config, _load_config_file, _parse_value, create_default_config
def test_parse_value_string():
assert _parse_value('hello') == 'hello'
def test_parse_value_int():
assert _parse_value('123') == 123
def test_parse_value_float():
assert _parse_value('1.23') == 1.23
def test_parse_value_bool_true():
assert _parse_value('true') == True
def test_parse_value_bool_false():
assert _parse_value('false') == False
def test_parse_value_bool_upper():
assert _parse_value('TRUE') == True
@patch('os.path.exists', return_value=False)
def test_load_config_file_not_exists(mock_exists):
config = _load_config_file('test.ini')
assert config == {}
@patch('os.path.exists', return_value=True)
@patch('configparser.ConfigParser')
def test_load_config_file_exists(mock_parser_class, mock_exists):
mock_parser = mock_parser_class.return_value
mock_parser.sections.return_value = ['api']
mock_parser.items.return_value = [('key', 'value')]
config = _load_config_file('test.ini')
assert 'api' in config
assert config['api']['key'] == 'value'
@patch('pr.core.config_loader._load_config_file')
def test_load_config(mock_load):
mock_load.side_effect = [{'api': {'key': 'global'}}, {'api': {'key': 'local'}}]
config = load_config()
assert config['api']['key'] == 'local'
@patch('builtins.open', new_callable=mock_open)
def test_create_default_config(mock_file):
result = create_default_config('test.ini')
assert result == True
mock_file.assert_called_once_with('test.ini', 'w')
handle = mock_file()
handle.write.assert_called_once()
@patch('builtins.open', side_effect=Exception('error'))
def test_create_default_config_error(mock_file):
result = create_default_config('test.ini')
assert result == False

118
tests/test_main.py Normal file
View File

@ -0,0 +1,118 @@
import pytest
from unittest.mock import patch
import sys
from pr.__main__ import main
def test_main_version(capsys):
with patch('sys.argv', ['pr', '--version']):
with pytest.raises(SystemExit):
main()
captured = capsys.readouterr()
assert 'PR Assistant' in captured.out
def test_main_create_config_success(capsys):
with patch('pr.core.config_loader.create_default_config', return_value=True):
with patch('sys.argv', ['pr', '--create-config']):
main()
captured = capsys.readouterr()
assert 'Configuration file created' in captured.out
def test_main_create_config_fail(capsys):
with patch('pr.core.config_loader.create_default_config', return_value=False):
with patch('sys.argv', ['pr', '--create-config']):
main()
captured = capsys.readouterr()
assert 'Error creating configuration file' in captured.err
def test_main_list_sessions_no_sessions(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.list_sessions.return_value = []
with patch('sys.argv', ['pr', '--list-sessions']):
main()
captured = capsys.readouterr()
assert 'No saved sessions found' in captured.out
def test_main_list_sessions_with_sessions(capsys):
sessions = [{'name': 'test', 'created_at': '2023-01-01', 'message_count': 5}]
with patch('pr.core.session.SessionManager') as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.list_sessions.return_value = sessions
with patch('sys.argv', ['pr', '--list-sessions']):
main()
captured = capsys.readouterr()
assert 'Found 1 saved sessions' in captured.out
assert 'test' in captured.out
def test_main_delete_session_success(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.delete_session.return_value = True
with patch('sys.argv', ['pr', '--delete-session', 'test']):
main()
captured = capsys.readouterr()
assert "Session 'test' deleted" in captured.out
def test_main_delete_session_fail(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.delete_session.return_value = False
with patch('sys.argv', ['pr', '--delete-session', 'test']):
main()
captured = capsys.readouterr()
assert "Error deleting session 'test'" in captured.err
def test_main_export_session_json(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.export_session.return_value = True
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.json']):
main()
captured = capsys.readouterr()
assert 'Session exported to output.json' in captured.out
def test_main_export_session_md(capsys):
with patch('pr.core.session.SessionManager') as mock_sm:
mock_instance = mock_sm.return_value
mock_instance.export_session.return_value = True
with patch('sys.argv', ['pr', '--export-session', 'test', 'output.md']):
main()
captured = capsys.readouterr()
assert 'Session exported to output.md' in captured.out
def test_main_usage(capsys):
usage = {'total_requests': 10, 'total_tokens': 1000, 'total_cost': 0.01}
with patch('pr.core.usage_tracker.UsageTracker.get_total_usage', return_value=usage):
with patch('sys.argv', ['pr', '--usage']):
main()
captured = capsys.readouterr()
assert 'Total Usage Statistics' in captured.out
assert 'Requests: 10' in captured.out
def test_main_plugins_no_plugins(capsys):
with patch('pr.plugins.loader.PluginLoader') as mock_loader:
mock_instance = mock_loader.return_value
mock_instance.load_plugins.return_value = None
mock_instance.list_loaded_plugins.return_value = []
with patch('sys.argv', ['pr', '--plugins']):
main()
captured = capsys.readouterr()
assert 'No plugins loaded' in captured.out
def test_main_plugins_with_plugins(capsys):
with patch('pr.plugins.loader.PluginLoader') as mock_loader:
mock_instance = mock_loader.return_value
mock_instance.load_plugins.return_value = None
mock_instance.list_loaded_plugins.return_value = ['plugin1', 'plugin2']
with patch('sys.argv', ['pr', '--plugins']):
main()
captured = capsys.readouterr()
assert 'Loaded 2 plugins' in captured.out
def test_main_run_assistant():
with patch('pr.__main__.Assistant') as mock_assistant:
mock_instance = mock_assistant.return_value
with patch('sys.argv', ['pr', 'test message']):
main()
mock_assistant.assert_called_once()
mock_instance.run.assert_called_once()