From 685766ef867288e3a62f2b3dc13eda4b6af494cd Mon Sep 17 00:00:00 2001 From: retoor Date: Tue, 4 Nov 2025 05:57:23 +0100 Subject: [PATCH] Update. --- README.md | 3 + pr/agents/agent_manager.py | 4 +- pr/autonomous/mode.py | 20 +--- pr/core/advanced_context.py | 2 + pr/core/assistant.py | 19 +++- pr/editor.py | 6 ++ pr/input_handler.py | 168 +++++++++++++++++++++++++++++++++ pr/tools/__init__.py | 6 +- pr/tools/agents.py | 64 +++++++++++++ pr/tools/memory.py | 99 +++++++++++++++++++ pr/ui/display.py | 39 ++------ tests/test_advanced_context.py | 81 ++++++++++++++++ tests/test_api.py | 63 +++++++++++++ tests/test_assistant.py | 94 ++++++++++++++++++ tests/test_config_loader.py | 56 +++++++++++ tests/test_main.py | 118 +++++++++++++++++++++++ 16 files changed, 787 insertions(+), 55 deletions(-) create mode 100644 pr/input_handler.py create mode 100644 pr/tools/agents.py create mode 100644 pr/tools/memory.py create mode 100644 tests/test_advanced_context.py create mode 100644 tests/test_api.py create mode 100644 tests/test_assistant.py create mode 100644 tests/test_config_loader.py create mode 100644 tests/test_main.py diff --git a/README.md b/README.md index 2b2761a..047cf9f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ # rp Assistant + + + rp [![Tests](https://img.shields.io/badge/tests-passing-brightgreen.svg)](https://github.com/retoor/rp-assistant) [![Python](https://img.shields.io/badge/python-3.8%2B-blue.svg)](https://www.python.org/downloads/) diff --git a/pr/agents/agent_manager.py b/pr/agents/agent_manager.py index 1abc55a..1ea705a 100644 --- a/pr/agents/agent_manager.py +++ b/pr/agents/agent_manager.py @@ -164,7 +164,7 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei return results - def get_session_summary(self) -> Dict[str, Any]: + def get_session_summary(self) -> str: summary = { 'session_id': self.session_id, 'active_agents': len(self.active_agents), @@ -178,7 +178,7 @@ Break down the task and delegate subtasks to appropriate agents. Coordinate thei for agent_id, agent in self.active_agents.items() ] } - return summary + return json.dumps(summary) def clear_session(self): self.active_agents.clear() diff --git a/pr/autonomous/mode.py b/pr/autonomous/mode.py index 55db88a..5019d75 100644 --- a/pr/autonomous/mode.py +++ b/pr/autonomous/mode.py @@ -14,12 +14,9 @@ def run_autonomous_mode(assistant, task): logger.debug(f"=== AUTONOMOUS MODE START ===") logger.debug(f"Task: {task}") - if assistant.verbose: - print_autonomous_header(task) - assistant.messages.append({ "role": "user", - "content": f"AUTONOMOUS TASK: {task}\n\nPlease work on this task step by step. Use tools as needed. When the task is fully complete, clearly state 'Task complete'." + "content": f"{task}" }) try: @@ -29,17 +26,11 @@ def run_autonomous_mode(assistant, task): logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---") logger.debug(f"Messages before context management: {len(assistant.messages)}") - if assistant.verbose: - print(f"\n{Colors.BOLD}{Colors.MAGENTA}{'▶' * 3} Iteration {assistant.autonomous_iterations} {'◀' * 3}{Colors.RESET}\n") - from pr.core.context import manage_context_window assistant.messages = manage_context_window(assistant.messages, assistant.verbose) logger.debug(f"Messages after context management: {len(assistant.messages)}") - if assistant.verbose: - print(f"{Colors.GRAY}Calling API...{Colors.RESET}") - from pr.core.api import call_api from pr.tools.base import get_tools_definition response = call_api( @@ -67,9 +58,6 @@ def run_autonomous_mode(assistant, task): logger.debug(f"=== AUTONOMOUS MODE COMPLETE ===") logger.debug(f"Total iterations: {assistant.autonomous_iterations}") logger.debug(f"Final message count: {len(assistant.messages)}") - - print(f"{Colors.BOLD}Total Iterations:{Colors.RESET} {assistant.autonomous_iterations}") - print(f"{Colors.BOLD}Messages in Context:{Colors.RESET} {len(assistant.messages)}\n") break result = process_response_autonomous(assistant, response) @@ -97,16 +85,12 @@ def process_response_autonomous(assistant, response): assistant.messages.append(message) if 'tool_calls' in message and message['tool_calls']: - print(f"{Colors.BOLD}{Colors.CYAN}🔧 Executing {len(message['tool_calls'])} tool(s)...{Colors.RESET}\n") - tool_results = [] for tool_call in message['tool_calls']: func_name = tool_call['function']['name'] arguments = json.loads(tool_call['function']['arguments']) - display_tool_call(func_name, arguments, "running") - result = execute_single_tool(assistant, func_name, arguments) result = truncate_tool_result(result) @@ -121,8 +105,6 @@ def process_response_autonomous(assistant, response): for result in tool_results: assistant.messages.append(result) - - print(f"{Colors.GRAY}Processing tool results...{Colors.RESET}\n") from pr.core.api import call_api from pr.tools.base import get_tools_definition follow_up = call_api( diff --git a/pr/core/advanced_context.py b/pr/core/advanced_context.py index ebb7e50..e5c25a4 100644 --- a/pr/core/advanced_context.py +++ b/pr/core/advanced_context.py @@ -47,6 +47,8 @@ class AdvancedContextManager: return complexity def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]: + if not text.strip(): + return [] sentences = re.split(r'(?<=[.!?])\s+', text) if not sentences: return [] diff --git a/pr/core/assistant.py b/pr/core/assistant.py index 76c0113..c034d07 100644 --- a/pr/core/assistant.py +++ b/pr/core/assistant.py @@ -133,6 +133,18 @@ class Assistant: 'display_edit_summary': lambda **kw: display_edit_summary(), 'display_edit_timeline': lambda **kw: display_edit_timeline(**kw), 'clear_edit_tracker': lambda **kw: clear_edit_tracker(), + 'create_agent': lambda **kw: create_agent(**kw), + 'list_agents': lambda **kw: list_agents(**kw), + 'execute_agent_task': lambda **kw: execute_agent_task(**kw), + 'remove_agent': lambda **kw: remove_agent(**kw), + 'collaborate_agents': lambda **kw: collaborate_agents(**kw), + 'add_knowledge_entry': lambda **kw: add_knowledge_entry(**kw), + 'get_knowledge_entry': lambda **kw: get_knowledge_entry(**kw), + 'search_knowledge': lambda **kw: search_knowledge(**kw), + 'get_knowledge_by_category': lambda **kw: get_knowledge_by_category(**kw), + 'update_knowledge_importance': lambda **kw: update_knowledge_importance(**kw), + 'delete_knowledge_entry': lambda **kw: delete_knowledge_entry(**kw), + 'get_knowledge_statistics': lambda **kw: get_knowledge_statistics(**kw), } if func_name in func_map: @@ -230,6 +242,7 @@ class Assistant: path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options] combined_options = sorted(list(set(options + path_options))) + #combined_options.extend(self.commands) if state < len(combined_options): return combined_options[state] @@ -279,7 +292,8 @@ class Assistant: else: message = sys.stdin.read() - process_message(self, message) + from pr.autonomous.mode import run_autonomous_mode + run_autonomous_mode(self, message) def cleanup(self): if hasattr(self, 'enhanced') and self.enhanced: @@ -299,9 +313,12 @@ class Assistant: def run(self): try: + print(f"DEBUG: interactive={self.args.interactive}, message={self.args.message}, isatty={sys.stdin.isatty()}") if self.args.interactive or (not self.args.message and sys.stdin.isatty()): + print("DEBUG: calling run_repl") self.run_repl() else: + print("DEBUG: calling run_single") self.run_single() finally: self.cleanup() diff --git a/pr/editor.py b/pr/editor.py index eb9fea8..598db74 100644 --- a/pr/editor.py +++ b/pr/editor.py @@ -166,6 +166,8 @@ class RPEditor: def save_file(self): """Thread-safe save file command.""" + if not self.running: + return self._save_file() try: self.client_sock.send(pickle.dumps({'command': 'save_file'})) except: @@ -630,6 +632,10 @@ class RPEditor: def set_text(self, text): """Thread-safe text setting.""" + if not self.running: + with self.lock: + self._set_text(text) + return try: self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text})) except: diff --git a/pr/input_handler.py b/pr/input_handler.py new file mode 100644 index 0000000..07452f5 --- /dev/null +++ b/pr/input_handler.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +""" +Advanced input handler for PR Assistant with editor mode, file inclusion, and image support. +""" + +import os +import re +import base64 +import mimetypes +import readline +import glob +from pathlib import Path +from typing import Optional +# from pr.ui.colors import Colors # Avoid import issues + + +class AdvancedInputHandler: + """Handles advanced input with editor mode, file inclusion, and image support.""" + + def __init__(self): + self.editor_mode = False + self.setup_readline() + + def setup_readline(self): + """Setup readline with basic completer.""" + try: + # Simple completer that doesn't interfere + def completer(text, state): + return None + + readline.set_completer(completer) + readline.parse_and_bind('tab: complete') + except: + pass # Readline not available + + def toggle_editor_mode(self): + """Toggle between simple and editor input modes.""" + self.editor_mode = not self.editor_mode + mode = "Editor" if self.editor_mode else "Simple" + print(f"\nSwitched to {mode.lower()} input mode.") + + def get_input(self, prompt: str = "You> ") -> Optional[str]: + """Get input from user, handling different modes.""" + try: + if self.editor_mode: + return self._get_editor_input(prompt) + else: + return self._get_simple_input(prompt) + except KeyboardInterrupt: + return None + except EOFError: + return None + + def _get_simple_input(self, prompt: str) -> Optional[str]: + """Get simple input with file completion.""" + try: + user_input = input(prompt).strip() + + if not user_input: + return "" + + # Check for special commands + if user_input.lower() == '/editor': + self.toggle_editor_mode() + return self.get_input(prompt) # Recurse to get new input + + # Process file inclusions and images + processed_input = self._process_input(user_input) + return processed_input + + except KeyboardInterrupt: + return None + + def _get_editor_input(self, prompt: str) -> Optional[str]: + """Get multi-line input for editor mode.""" + try: + print("Editor mode: Enter your message. Type 'END' on a new line to finish.") + print("Type '/simple' to switch back to simple mode.") + + lines = [] + while True: + try: + line = input() + if line.strip().lower() == 'end': + break + elif line.strip().lower() == '/simple': + self.toggle_editor_mode() + return self.get_input(prompt) # Switch back and get input + lines.append(line) + except EOFError: + break + + content = '\n'.join(lines).strip() + + if not content: + return "" + + # Process file inclusions and images + processed_content = self._process_input(content) + return processed_content + + except KeyboardInterrupt: + return None + + def _process_input(self, text: str) -> str: + """Process input text for file inclusions and images.""" + # Process @[filename] inclusions + text = self._process_file_inclusions(text) + + # Process image inclusions (look for image file paths) + text = self._process_image_inclusions(text) + + return text + + def _process_file_inclusions(self, text: str) -> str: + """Replace @[filename] with file contents.""" + def replace_file(match): + filename = match.group(1).strip() + try: + path = Path(filename).expanduser().resolve() + if path.exists() and path.is_file(): + with open(path, 'r', encoding='utf-8', errors='replace') as f: + content = f.read() + return f"\n--- File: {filename} ---\n{content}\n--- End of {filename} ---\n" + else: + return f"[File not found: {filename}]" + except Exception as e: + return f"[Error reading file {filename}: {e}]" + + # Replace @[filename] patterns + pattern = r'@\[([^\]]+)\]' + return re.sub(pattern, replace_file, text) + + def _process_image_inclusions(self, text: str) -> str: + """Process image file references and encode them.""" + # Find potential image file paths + words = text.split() + processed_parts = [] + + for word in words: + # Check if it's a file path that exists and is an image + try: + path = Path(word.strip()).expanduser().resolve() + if path.exists() and path.is_file(): + mime_type, _ = mimetypes.guess_type(str(path)) + if mime_type and mime_type.startswith('image/'): + # Encode image + with open(path, 'rb') as f: + image_data = base64.b64encode(f.read()).decode('utf-8') + + # Replace with data URL + processed_parts.append(f"[Image: {path.name}]\ndata:{mime_type};base64,{image_data}\n") + continue + except: + pass + + processed_parts.append(word) + + return ' '.join(processed_parts) + + +# Global instance +input_handler = AdvancedInputHandler() + + +def get_advanced_input(prompt: str = "You> ") -> Optional[str]: + """Get advanced input from user.""" + return input_handler.get_input(prompt) diff --git a/pr/tools/__init__.py b/pr/tools/__init__.py index 927ab7c..13b11de 100644 --- a/pr/tools/__init__.py +++ b/pr/tools/__init__.py @@ -8,6 +8,8 @@ from pr.tools.database import db_set, db_get, db_query from pr.tools.web import http_fetch, web_search, web_search_news from pr.tools.python_exec import python_exec from pr.tools.patch import apply_patch, create_diff +from pr.tools.agents import create_agent, list_agents, execute_agent_task, remove_agent, collaborate_agents +from pr.tools.memory import add_knowledge_entry, get_knowledge_entry, search_knowledge, get_knowledge_by_category, update_knowledge_importance, delete_knowledge_entry, get_knowledge_statistics __all__ = [ 'get_tools_definition', @@ -17,5 +19,7 @@ __all__ = [ 'db_set', 'db_get', 'db_query', 'http_fetch', 'web_search', 'web_search_news', 'python_exec','tail_process', 'kill_process', - 'apply_patch', 'create_diff' + 'apply_patch', 'create_diff', + 'create_agent', 'list_agents', 'execute_agent_task', 'remove_agent', 'collaborate_agents', + 'add_knowledge_entry', 'get_knowledge_entry', 'search_knowledge', 'get_knowledge_by_category', 'update_knowledge_importance', 'delete_knowledge_entry', 'get_knowledge_statistics' ] diff --git a/pr/tools/agents.py b/pr/tools/agents.py new file mode 100644 index 0000000..770c40e --- /dev/null +++ b/pr/tools/agents.py @@ -0,0 +1,64 @@ +import os +from typing import Dict, Any, List +from pr.agents.agent_manager import AgentManager +from pr.core.api import call_api + +def create_agent(role_name: str, agent_id: str = None) -> Dict[str, Any]: + """Create a new agent with the specified role.""" + try: + # Get db_path from environment or default + db_path = os.environ.get('ASSISTANT_DB_PATH', '~/.assistant_db.sqlite') + db_path = os.path.expanduser(db_path) + + manager = AgentManager(db_path, call_api) + agent_id = manager.create_agent(role_name, agent_id) + return {"status": "success", "agent_id": agent_id, "role": role_name} + except Exception as e: + return {"status": "error", "error": str(e)} + +def list_agents() -> Dict[str, Any]: + """List all active agents.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + manager = AgentManager(db_path, call_api) + agents = [] + for agent_id, agent in manager.active_agents.items(): + agents.append({ + "agent_id": agent_id, + "role": agent.role.name, + "task_count": agent.task_count, + "message_count": len(agent.message_history) + }) + return {"status": "success", "agents": agents} + except Exception as e: + return {"status": "error", "error": str(e)} + +def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]: + """Execute a task with the specified agent.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + manager = AgentManager(db_path, call_api) + result = manager.execute_agent_task(agent_id, task, context) + return result + except Exception as e: + return {"status": "error", "error": str(e)} + +def remove_agent(agent_id: str) -> Dict[str, Any]: + """Remove an agent.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + manager = AgentManager(db_path, call_api) + success = manager.remove_agent(agent_id) + return {"status": "success" if success else "not_found", "agent_id": agent_id} + except Exception as e: + return {"status": "error", "error": str(e)} + +def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]: + """Collaborate multiple agents on a task.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + manager = AgentManager(db_path, call_api) + result = manager.collaborate_agents(orchestrator_id, task, agent_roles) + return result + except Exception as e: + return {"status": "error", "error": str(e)} diff --git a/pr/tools/memory.py b/pr/tools/memory.py new file mode 100644 index 0000000..12b48b4 --- /dev/null +++ b/pr/tools/memory.py @@ -0,0 +1,99 @@ +import os +from typing import Dict, Any, List +from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry +import time +import uuid + +def add_knowledge_entry(category: str, content: str, metadata: Dict[str, Any] = None, entry_id: str = None) -> Dict[str, Any]: + """Add a new entry to the knowledge base.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + store = KnowledgeStore(db_path) + + if entry_id is None: + entry_id = str(uuid.uuid4())[:16] + + entry = KnowledgeEntry( + entry_id=entry_id, + category=category, + content=content, + metadata=metadata or {}, + created_at=time.time(), + updated_at=time.time() + ) + + store.add_entry(entry) + return {"status": "success", "entry_id": entry_id} + except Exception as e: + return {"status": "error", "error": str(e)} + +def get_knowledge_entry(entry_id: str) -> Dict[str, Any]: + """Retrieve a knowledge entry by ID.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + store = KnowledgeStore(db_path) + + entry = store.get_entry(entry_id) + if entry: + return {"status": "success", "entry": entry.to_dict()} + else: + return {"status": "not_found", "entry_id": entry_id} + except Exception as e: + return {"status": "error", "error": str(e)} + +def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]: + """Search the knowledge base semantically.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + store = KnowledgeStore(db_path) + + entries = store.search_entries(query, category, top_k) + results = [entry.to_dict() for entry in entries] + return {"status": "success", "results": results} + except Exception as e: + return {"status": "error", "error": str(e)} + +def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]: + """Get knowledge entries by category.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + store = KnowledgeStore(db_path) + + entries = store.get_by_category(category, limit) + results = [entry.to_dict() for entry in entries] + return {"status": "success", "entries": results} + except Exception as e: + return {"status": "error", "error": str(e)} + +def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]: + """Update the importance score of a knowledge entry.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + store = KnowledgeStore(db_path) + + store.update_importance(entry_id, importance_score) + return {"status": "success", "entry_id": entry_id, "importance_score": importance_score} + except Exception as e: + return {"status": "error", "error": str(e)} + +def delete_knowledge_entry(entry_id: str) -> Dict[str, Any]: + """Delete a knowledge entry.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + store = KnowledgeStore(db_path) + + success = store.delete_entry(entry_id) + return {"status": "success" if success else "not_found", "entry_id": entry_id} + except Exception as e: + return {"status": "error", "error": str(e)} + +def get_knowledge_statistics() -> Dict[str, Any]: + """Get statistics about the knowledge base.""" + try: + db_path = os.path.expanduser('~/.assistant_db.sqlite') + store = KnowledgeStore(db_path) + + stats = store.get_statistics() + return {"status": "success", "statistics": stats} + except Exception as e: + return {"status": "error", "error": str(e)} diff --git a/pr/ui/display.py b/pr/ui/display.py index 3ead830..122d4c7 100644 --- a/pr/ui/display.py +++ b/pr/ui/display.py @@ -3,41 +3,16 @@ from typing import Dict, Any from pr.ui.colors import Colors def display_tool_call(tool_name, arguments, status="running", result=None): - status_icons = { - "running": ("⚙", Colors.YELLOW), - "success": ("✓", Colors.GREEN), - "error": ("✗", Colors.RED) - } + if status == "running": + return - icon, color = status_icons.get(status, ("•", Colors.WHITE)) + args_str = ", ".join([f"{k}={str(v)[:20]}" for k, v in list(arguments.items())[:2]]) + line = f"{tool_name}({args_str})" - print(f"\n{Colors.BOLD}{'═' * 80}{Colors.RESET}") - print(f"{color}{icon} {Colors.BOLD}{Colors.CYAN}TOOL: {tool_name}{Colors.RESET}") - print(f"{Colors.BOLD}{'─' * 80}{Colors.RESET}") + if len(line) > 80: + line = line[:77] + "..." - if arguments: - print(f"{Colors.YELLOW}Parameters:{Colors.RESET}") - for key, value in arguments.items(): - value_str = str(value) - if len(value_str) > 100: - value_str = value_str[:100] + "..." - print(f" {Colors.CYAN}• {key}:{Colors.RESET} {value_str}") - - if result is not None and status != "running": - print(f"\n{Colors.YELLOW}Result:{Colors.RESET}") - result_str = json.dumps(result, indent=2) if isinstance(result, dict) else str(result) - - if len(result_str) > 500: - result_str = result_str[:500] + f"\n{Colors.GRAY}... (truncated){Colors.RESET}" - - if status == "success": - print(f"{Colors.GREEN}{result_str}{Colors.RESET}") - elif status == "error": - print(f"{Colors.RED}{result_str}{Colors.RESET}") - else: - print(result_str) - - print(f"{Colors.BOLD}{'═' * 80}{Colors.RESET}\n") + print(f"{Colors.GRAY}{line}{Colors.RESET}") def print_autonomous_header(task): print(f"{Colors.BOLD}Task:{Colors.RESET} {task}") diff --git a/tests/test_advanced_context.py b/tests/test_advanced_context.py new file mode 100644 index 0000000..218fec1 --- /dev/null +++ b/tests/test_advanced_context.py @@ -0,0 +1,81 @@ +import pytest +from pr.core.advanced_context import AdvancedContextManager + +def test_adaptive_context_window_simple(): + mgr = AdvancedContextManager() + messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}] + window = mgr.adaptive_context_window(messages, 'simple') + assert isinstance(window, int) + assert window >= 10 + +def test_adaptive_context_window_medium(): + mgr = AdvancedContextManager() + messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}] + window = mgr.adaptive_context_window(messages, 'medium') + assert isinstance(window, int) + assert window >= 20 + +def test_adaptive_context_window_complex(): + mgr = AdvancedContextManager() + messages = [{'content': 'short'}, {'content': 'this is a longer message with more words'}] + window = mgr.adaptive_context_window(messages, 'complex') + assert isinstance(window, int) + assert window >= 35 + +def test_analyze_message_complexity(): + mgr = AdvancedContextManager() + messages = [{'content': 'hello world'}, {'content': 'hello again'}] + score = mgr._analyze_message_complexity(messages) + assert 0 <= score <= 1 + +def test_analyze_message_complexity_empty(): + mgr = AdvancedContextManager() + messages = [] + score = mgr._analyze_message_complexity(messages) + assert score == 0 + +def test_extract_key_sentences(): + mgr = AdvancedContextManager() + text = "This is the first sentence. This is the second sentence. This is a longer third sentence with more words." + sentences = mgr.extract_key_sentences(text, 2) + assert len(sentences) <= 2 + assert all(isinstance(s, str) for s in sentences) + +def test_extract_key_sentences_empty(): + mgr = AdvancedContextManager() + text = "" + sentences = mgr.extract_key_sentences(text, 5) + assert sentences == [] + +def test_advanced_summarize_messages(): + mgr = AdvancedContextManager() + messages = [{'content': 'Hello'}, {'content': 'How are you?'}] + summary = mgr.advanced_summarize_messages(messages) + assert isinstance(summary, str) + +def test_advanced_summarize_messages_empty(): + mgr = AdvancedContextManager() + messages = [] + summary = mgr.advanced_summarize_messages(messages) + assert summary == "No content to summarize." + +def test_score_message_relevance(): + mgr = AdvancedContextManager() + message = {'content': 'hello world'} + context = 'world hello' + score = mgr.score_message_relevance(message, context) + assert 0 <= score <= 1 + +def test_score_message_relevance_no_overlap(): + mgr = AdvancedContextManager() + message = {'content': 'hello'} + context = 'world' + score = mgr.score_message_relevance(message, context) + assert score == 0 + +def test_score_message_relevance_empty(): + mgr = AdvancedContextManager() + message = {'content': ''} + context = '' + score = mgr.score_message_relevance(message, context) + assert score == 0 \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..3f01176 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,63 @@ +import unittest +from unittest.mock import patch, MagicMock +import json +import urllib.error +from pr.core.api import call_api, list_models + + +class TestApi(unittest.TestCase): + + @patch('pr.core.api.urllib.request.urlopen') + @patch('pr.core.api.auto_slim_messages') + def test_call_api_success(self, mock_slim, mock_urlopen): + mock_slim.return_value = [{'role': 'user', 'content': 'test'}] + mock_response = MagicMock() + mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}' + mock_urlopen.return_value.__enter__.return_value = mock_response + + result = call_api([], 'model', 'http://url', 'key', True, [{'name': 'tool'}]) + + self.assertIn('choices', result) + mock_urlopen.assert_called_once() + + @patch('urllib.request.urlopen') + @patch('pr.core.api.auto_slim_messages') + def test_call_api_http_error(self, mock_slim, mock_urlopen): + mock_slim.return_value = [{'role': 'user', 'content': 'test'}] + mock_urlopen.side_effect = urllib.error.HTTPError('http://url', 500, 'error', None, MagicMock()) + + result = call_api([], 'model', 'http://url', 'key', False, []) + + self.assertIn('error', result) + + @patch('urllib.request.urlopen') + @patch('pr.core.api.auto_slim_messages') + def test_call_api_general_error(self, mock_slim, mock_urlopen): + mock_slim.return_value = [{'role': 'user', 'content': 'test'}] + mock_urlopen.side_effect = Exception('test error') + + result = call_api([], 'model', 'http://url', 'key', False, []) + + self.assertIn('error', result) + + @patch('urllib.request.urlopen') + def test_list_models_success(self, mock_urlopen): + mock_response = MagicMock() + mock_response.read.return_value = b'{"data": [{"id": "model1"}]}' + mock_urlopen.return_value.__enter__.return_value = mock_response + + result = list_models('http://url', 'key') + + self.assertEqual(result, [{'id': 'model1'}]) + + @patch('urllib.request.urlopen') + def test_list_models_error(self, mock_urlopen): + mock_urlopen.side_effect = Exception('error') + + result = list_models('http://url', 'key') + + self.assertIn('error', result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_assistant.py b/tests/test_assistant.py new file mode 100644 index 0000000..b6f6037 --- /dev/null +++ b/tests/test_assistant.py @@ -0,0 +1,94 @@ +import unittest +from unittest.mock import patch, MagicMock +import tempfile +import os +from pr.core.assistant import Assistant, process_message + + +class TestAssistant(unittest.TestCase): + + def setUp(self): + self.args = MagicMock() + self.args.verbose = False + self.args.debug = False + self.args.no_syntax = False + self.args.model = 'test-model' + self.args.api_url = 'test-url' + self.args.model_list_url = 'test-list-url' + + @patch('sqlite3.connect') + @patch('os.environ.get') + @patch('pr.core.context.init_system_message') + @patch('pr.core.enhanced_assistant.EnhancedAssistant') + def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite): + mock_env.side_effect = lambda key, default: {'OPENROUTER_API_KEY': 'key', 'AI_MODEL': 'model', 'API_URL': 'url', 'MODEL_LIST_URL': 'list', 'USE_TOOLS': '1', 'STRICT_MODE': '0'}.get(key, default) + mock_conn = MagicMock() + mock_sqlite.return_value = mock_conn + mock_init_sys.return_value = {'role': 'system', 'content': 'sys'} + + assistant = Assistant(self.args) + + self.assertEqual(assistant.api_key, 'key') + self.assertEqual(assistant.model, 'test-model') + mock_sqlite.assert_called_once() + + @patch('pr.core.assistant.call_api') + @patch('pr.core.assistant.render_markdown') + def test_process_response_no_tools(self, mock_render, mock_call): + assistant = MagicMock() + assistant.messages = MagicMock() + assistant.verbose = False + assistant.syntax_highlighting = True + mock_render.return_value = 'rendered' + + response = {'choices': [{'message': {'content': 'content'}}]} + + result = Assistant.process_response(assistant, response) + + self.assertEqual(result, 'rendered') + assistant.messages.append.assert_called_with({'content': 'content'}) + + @patch('pr.core.assistant.call_api') + @patch('pr.core.assistant.render_markdown') + @patch('pr.core.assistant.get_tools_definition') + def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call): + assistant = MagicMock() + assistant.messages = MagicMock() + assistant.verbose = False + assistant.syntax_highlighting = True + assistant.use_tools = True + assistant.model = 'model' + assistant.api_url = 'url' + assistant.api_key = 'key' + mock_tools_def.return_value = [] + mock_call.return_value = {'choices': [{'message': {'content': 'follow'}}]} + + response = {'choices': [{'message': {'tool_calls': [{'id': '1', 'function': {'name': 'test', 'arguments': '{}'}}]}}]} + + with patch.object(assistant, 'execute_tool_calls', return_value=[{'role': 'tool', 'content': 'result'}]): + result = Assistant.process_response(assistant, response) + + mock_call.assert_called() + + @patch('pr.core.assistant.call_api') + @patch('pr.core.assistant.get_tools_definition') + def test_process_message(self, mock_tools, mock_call): + assistant = MagicMock() + assistant.messages = MagicMock() + assistant.verbose = False + assistant.use_tools = True + assistant.model = 'model' + assistant.api_url = 'url' + assistant.api_key = 'key' + mock_tools.return_value = [] + mock_call.return_value = {'choices': [{'message': {'content': 'response'}}]} + + with patch('pr.core.assistant.render_markdown', return_value='rendered'): + with patch('builtins.print'): + process_message(assistant, 'test message') + + assistant.messages.append.assert_called_with({'role': 'user', 'content': 'test message'}) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py new file mode 100644 index 0000000..96e4171 --- /dev/null +++ b/tests/test_config_loader.py @@ -0,0 +1,56 @@ +import pytest +from unittest.mock import patch, mock_open +import os +from pr.core.config_loader import load_config, _load_config_file, _parse_value, create_default_config + +def test_parse_value_string(): + assert _parse_value('hello') == 'hello' + +def test_parse_value_int(): + assert _parse_value('123') == 123 + +def test_parse_value_float(): + assert _parse_value('1.23') == 1.23 + +def test_parse_value_bool_true(): + assert _parse_value('true') == True + +def test_parse_value_bool_false(): + assert _parse_value('false') == False + +def test_parse_value_bool_upper(): + assert _parse_value('TRUE') == True + +@patch('os.path.exists', return_value=False) +def test_load_config_file_not_exists(mock_exists): + config = _load_config_file('test.ini') + assert config == {} + +@patch('os.path.exists', return_value=True) +@patch('configparser.ConfigParser') +def test_load_config_file_exists(mock_parser_class, mock_exists): + mock_parser = mock_parser_class.return_value + mock_parser.sections.return_value = ['api'] + mock_parser.items.return_value = [('key', 'value')] + config = _load_config_file('test.ini') + assert 'api' in config + assert config['api']['key'] == 'value' + +@patch('pr.core.config_loader._load_config_file') +def test_load_config(mock_load): + mock_load.side_effect = [{'api': {'key': 'global'}}, {'api': {'key': 'local'}}] + config = load_config() + assert config['api']['key'] == 'local' + +@patch('builtins.open', new_callable=mock_open) +def test_create_default_config(mock_file): + result = create_default_config('test.ini') + assert result == True + mock_file.assert_called_once_with('test.ini', 'w') + handle = mock_file() + handle.write.assert_called_once() + +@patch('builtins.open', side_effect=Exception('error')) +def test_create_default_config_error(mock_file): + result = create_default_config('test.ini') + assert result == False \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..2a483b5 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,118 @@ +import pytest +from unittest.mock import patch +import sys +from pr.__main__ import main + +def test_main_version(capsys): + with patch('sys.argv', ['pr', '--version']): + with pytest.raises(SystemExit): + main() + captured = capsys.readouterr() + assert 'PR Assistant' in captured.out + +def test_main_create_config_success(capsys): + with patch('pr.core.config_loader.create_default_config', return_value=True): + with patch('sys.argv', ['pr', '--create-config']): + main() + captured = capsys.readouterr() + assert 'Configuration file created' in captured.out + +def test_main_create_config_fail(capsys): + with patch('pr.core.config_loader.create_default_config', return_value=False): + with patch('sys.argv', ['pr', '--create-config']): + main() + captured = capsys.readouterr() + assert 'Error creating configuration file' in captured.err + +def test_main_list_sessions_no_sessions(capsys): + with patch('pr.core.session.SessionManager') as mock_sm: + mock_instance = mock_sm.return_value + mock_instance.list_sessions.return_value = [] + with patch('sys.argv', ['pr', '--list-sessions']): + main() + captured = capsys.readouterr() + assert 'No saved sessions found' in captured.out + +def test_main_list_sessions_with_sessions(capsys): + sessions = [{'name': 'test', 'created_at': '2023-01-01', 'message_count': 5}] + with patch('pr.core.session.SessionManager') as mock_sm: + mock_instance = mock_sm.return_value + mock_instance.list_sessions.return_value = sessions + with patch('sys.argv', ['pr', '--list-sessions']): + main() + captured = capsys.readouterr() + assert 'Found 1 saved sessions' in captured.out + assert 'test' in captured.out + +def test_main_delete_session_success(capsys): + with patch('pr.core.session.SessionManager') as mock_sm: + mock_instance = mock_sm.return_value + mock_instance.delete_session.return_value = True + with patch('sys.argv', ['pr', '--delete-session', 'test']): + main() + captured = capsys.readouterr() + assert "Session 'test' deleted" in captured.out + +def test_main_delete_session_fail(capsys): + with patch('pr.core.session.SessionManager') as mock_sm: + mock_instance = mock_sm.return_value + mock_instance.delete_session.return_value = False + with patch('sys.argv', ['pr', '--delete-session', 'test']): + main() + captured = capsys.readouterr() + assert "Error deleting session 'test'" in captured.err + +def test_main_export_session_json(capsys): + with patch('pr.core.session.SessionManager') as mock_sm: + mock_instance = mock_sm.return_value + mock_instance.export_session.return_value = True + with patch('sys.argv', ['pr', '--export-session', 'test', 'output.json']): + main() + captured = capsys.readouterr() + assert 'Session exported to output.json' in captured.out + +def test_main_export_session_md(capsys): + with patch('pr.core.session.SessionManager') as mock_sm: + mock_instance = mock_sm.return_value + mock_instance.export_session.return_value = True + with patch('sys.argv', ['pr', '--export-session', 'test', 'output.md']): + main() + captured = capsys.readouterr() + assert 'Session exported to output.md' in captured.out + +def test_main_usage(capsys): + usage = {'total_requests': 10, 'total_tokens': 1000, 'total_cost': 0.01} + with patch('pr.core.usage_tracker.UsageTracker.get_total_usage', return_value=usage): + with patch('sys.argv', ['pr', '--usage']): + main() + captured = capsys.readouterr() + assert 'Total Usage Statistics' in captured.out + assert 'Requests: 10' in captured.out + +def test_main_plugins_no_plugins(capsys): + with patch('pr.plugins.loader.PluginLoader') as mock_loader: + mock_instance = mock_loader.return_value + mock_instance.load_plugins.return_value = None + mock_instance.list_loaded_plugins.return_value = [] + with patch('sys.argv', ['pr', '--plugins']): + main() + captured = capsys.readouterr() + assert 'No plugins loaded' in captured.out + +def test_main_plugins_with_plugins(capsys): + with patch('pr.plugins.loader.PluginLoader') as mock_loader: + mock_instance = mock_loader.return_value + mock_instance.load_plugins.return_value = None + mock_instance.list_loaded_plugins.return_value = ['plugin1', 'plugin2'] + with patch('sys.argv', ['pr', '--plugins']): + main() + captured = capsys.readouterr() + assert 'Loaded 2 plugins' in captured.out + +def test_main_run_assistant(): + with patch('pr.__main__.Assistant') as mock_assistant: + mock_instance = mock_assistant.return_value + with patch('sys.argv', ['pr', 'test message']): + main() + mock_assistant.assert_called_once() + mock_instance.run.assert_called_once() \ No newline at end of file