diff --git a/pr/__main__.py b/pr/__main__.py index 8ddbd73..d2eb216 100644 --- a/pr/__main__.py +++ b/pr/__main__.py @@ -32,30 +32,22 @@ Commands in interactive mode: ) parser.add_argument("message", nargs="?", help="Message to send to assistant") - parser.add_argument( - "--version", action="version", version=f"PR Assistant {__version__}" - ) + parser.add_argument("--version", action="version", version=f"PR Assistant {__version__}") parser.add_argument("-m", "--model", help="AI model to use") parser.add_argument("-u", "--api-url", help="API endpoint URL") parser.add_argument("--model-list-url", help="Model list endpoint URL") - parser.add_argument( - "-i", "--interactive", action="store_true", help="Interactive mode" - ) + parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode") parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") parser.add_argument( "--debug", action="store_true", help="Enable debug mode with detailed logging" ) - parser.add_argument( - "--no-syntax", action="store_true", help="Disable syntax highlighting" - ) + parser.add_argument("--no-syntax", action="store_true", help="Disable syntax highlighting") parser.add_argument( "--include-env", action="store_true", help="Include environment variables in context", ) - parser.add_argument( - "-c", "--context", action="append", help="Additional context files" - ) + parser.add_argument("-c", "--context", action="append", help="Additional context files") parser.add_argument( "--api-mode", action="store_true", help="API mode for specialized interaction" ) @@ -68,18 +60,10 @@ Commands in interactive mode: ) parser.add_argument("--quiet", action="store_true", help="Minimal output") - parser.add_argument( - "--save-session", metavar="NAME", help="Save session with given name" - ) - parser.add_argument( - "--load-session", metavar="NAME", help="Load session with given name" - ) - parser.add_argument( - "--list-sessions", action="store_true", help="List all saved sessions" - ) - parser.add_argument( - "--delete-session", metavar="NAME", help="Delete a saved session" - ) + parser.add_argument("--save-session", metavar="NAME", help="Save session with given name") + parser.add_argument("--load-session", metavar="NAME", help="Load session with given name") + parser.add_argument("--list-sessions", action="store_true", help="List all saved sessions") + parser.add_argument("--delete-session", metavar="NAME", help="Delete a saved session") parser.add_argument( "--export-session", nargs=2, @@ -87,9 +71,7 @@ Commands in interactive mode: help="Export session to file", ) - parser.add_argument( - "--usage", action="store_true", help="Show token usage statistics" - ) + parser.add_argument("--usage", action="store_true", help="Show token usage statistics") parser.add_argument( "--create-config", action="store_true", help="Create default configuration file" ) diff --git a/pr/agents/agent_communication.py b/pr/agents/agent_communication.py index 82ac954..0875bd0 100644 --- a/pr/agents/agent_communication.py +++ b/pr/agents/agent_communication.py @@ -93,9 +93,7 @@ class AgentCommunicationBus: self.conn.commit() - def get_messages( - self, agent_id: str, unread_only: bool = True - ) -> List[AgentMessage]: + def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]: cursor = self.conn.cursor() if unread_only: cursor.execute( @@ -135,17 +133,13 @@ class AgentCommunicationBus: def mark_as_read(self, message_id: str): cursor = self.conn.cursor() - cursor.execute( - "UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,) - ) + cursor.execute("UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,)) self.conn.commit() def clear_messages(self, session_id: Optional[str] = None): cursor = self.conn.cursor() if session_id: - cursor.execute( - "DELETE FROM agent_messages WHERE session_id = ?", (session_id,) - ) + cursor.execute("DELETE FROM agent_messages WHERE session_id = ?", (session_id,)) else: cursor.execute("DELETE FROM agent_messages") self.conn.commit() @@ -156,9 +150,7 @@ class AgentCommunicationBus: def receive_messages(self, agent_id: str) -> List[AgentMessage]: return self.get_messages(agent_id, unread_only=True) - def get_conversation_history( - self, agent_a: str, agent_b: str - ) -> List[AgentMessage]: + def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]: cursor = self.conn.cursor() cursor.execute( """ diff --git a/pr/agents/agent_manager.py b/pr/agents/agent_manager.py index cc58b4d..12ec1df 100644 --- a/pr/agents/agent_manager.py +++ b/pr/agents/agent_manager.py @@ -19,17 +19,14 @@ class AgentInstance: task_count: int = 0 def add_message(self, role: str, content: str): - self.message_history.append( - {"role": role, "content": content, "timestamp": time.time()} - ) + self.message_history.append({"role": role, "content": content, "timestamp": time.time()}) def get_system_message(self) -> Dict[str, str]: return {"role": "system", "content": self.role.system_prompt} def get_messages_for_api(self) -> List[Dict[str, str]]: return [self.get_system_message()] + [ - {"role": msg["role"], "content": msg["content"]} - for msg in self.message_history + {"role": msg["role"], "content": msg["content"]} for msg in self.message_history ] @@ -128,14 +125,10 @@ class AgentManager: self.communication_bus.send_message(message, self.session_id) return message.message_id - def get_agent_messages( - self, agent_id: str, unread_only: bool = True - ) -> List[AgentMessage]: + def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]: return self.communication_bus.get_messages(agent_id, unread_only) - def collaborate_agents( - self, orchestrator_id: str, task: str, agent_roles: List[str] - ): + def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]): orchestrator = self.get_agent(orchestrator_id) if not orchestrator: orchestrator_id = self.create_agent("orchestrator") @@ -153,9 +146,7 @@ Available specialized agents: Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results.""" - orchestrator_result = self.execute_agent_task( - orchestrator_id, orchestration_prompt - ) + orchestrator_result = self.execute_agent_task(orchestrator_id, orchestration_prompt) results = {"orchestrator": orchestrator_result, "agents": []} diff --git a/pr/autonomous/mode.py b/pr/autonomous/mode.py index 6200f9b..f05571b 100644 --- a/pr/autonomous/mode.py +++ b/pr/autonomous/mode.py @@ -22,22 +22,14 @@ def run_autonomous_mode(assistant, task): while True: assistant.autonomous_iterations += 1 - logger.debug( - f"--- Autonomous iteration {assistant.autonomous_iterations} ---" - ) - logger.debug( - f"Messages before context management: {len(assistant.messages)}" - ) + logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---") + logger.debug(f"Messages before context management: {len(assistant.messages)}") from pr.core.context import manage_context_window - assistant.messages = manage_context_window( - assistant.messages, assistant.verbose - ) + assistant.messages = manage_context_window(assistant.messages, assistant.verbose) - logger.debug( - f"Messages after context management: {len(assistant.messages)}" - ) + logger.debug(f"Messages after context management: {len(assistant.messages)}") from pr.core.api import call_api from pr.tools.base import get_tools_definition @@ -193,9 +185,7 @@ def execute_single_tool(assistant, func_name, arguments): "db_query": lambda **kw: db_query(**kw, db_conn=assistant.db_conn), "web_search": lambda **kw: web_search(**kw), "web_search_news": lambda **kw: web_search_news(**kw), - "python_exec": lambda **kw: python_exec( - **kw, python_globals=assistant.python_globals - ), + "python_exec": lambda **kw: python_exec(**kw, python_globals=assistant.python_globals), "index_source_directory": lambda **kw: index_source_directory(**kw), "search_replace": lambda **kw: search_replace(**kw), "open_editor": lambda **kw: open_editor(**kw), diff --git a/pr/cache/api_cache.py b/pr/cache/api_cache.py index f224f28..6dc6548 100644 --- a/pr/cache/api_cache.py +++ b/pr/cache/api_cache.py @@ -140,9 +140,7 @@ class APICache: total_entries = cursor.fetchone()[0] current_time = int(time.time()) - cursor.execute( - "SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,) - ) + cursor.execute("SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,)) valid_entries = cursor.fetchone()[0] cursor.execute( diff --git a/pr/cache/tool_cache.py b/pr/cache/tool_cache.py index da90225..383982d 100644 --- a/pr/cache/tool_cache.py +++ b/pr/cache/tool_cache.py @@ -161,9 +161,7 @@ class ToolCache: total_entries = cursor.fetchone()[0] current_time = int(time.time()) - cursor.execute( - "SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,) - ) + cursor.execute("SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,)) valid_entries = cursor.fetchone()[0] cursor.execute( diff --git a/pr/commands/handlers.py b/pr/commands/handlers.py index c532e23..4968684 100644 --- a/pr/commands/handlers.py +++ b/pr/commands/handlers.py @@ -95,9 +95,7 @@ def handle_command(assistant, command): print(f"{Colors.BOLD}Available Tools:{Colors.RESET}") for tool in get_tools_definition(): func = tool["function"] - print( - f" • {Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}" - ) + print(f" • {Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}") elif cmd == "/review" and len(command_parts) > 1: filename = command_parts[1] @@ -168,9 +166,7 @@ def handle_command(assistant, command): def review_file(assistant, filename): result = read_file(filename) if result["status"] == "success": - message = ( - f"Please review this file and provide feedback:\n\n{result['content']}" - ) + message = f"Please review this file and provide feedback:\n\n{result['content']}" from pr.core.assistant import process_message process_message(assistant, message) @@ -181,9 +177,7 @@ def review_file(assistant, filename): def refactor_file(assistant, filename): result = read_file(filename) if result["status"] == "success": - message = ( - f"Please refactor this code to improve its quality:\n\n{result['content']}" - ) + message = f"Please refactor this code to improve its quality:\n\n{result['content']}" from pr.core.assistant import process_message process_message(assistant, message) @@ -354,9 +348,7 @@ def show_conversation_history(assistant): for conv in history: import datetime - started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime( - "%Y-%m-%d %H:%M" - ) + started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime("%Y-%m-%d %H:%M") print(f"\n • {Colors.CYAN}{conv['conversation_id']}{Colors.RESET}") print(f" Started: {started}") print(f" Messages: {conv['message_count']}") @@ -536,9 +528,7 @@ def show_session_output(assistant, session_name): for line in output: print(line) else: - print( - f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}" - ) + print(f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}") except Exception as e: print(f"{Colors.RED}Error getting session output: {e}{Colors.RESET}") diff --git a/pr/commands/multiplexer_commands.py b/pr/commands/multiplexer_commands.py index 27d5333..b96127a 100644 --- a/pr/commands/multiplexer_commands.py +++ b/pr/commands/multiplexer_commands.py @@ -116,9 +116,7 @@ def kill_session(args): close_interactive_session(session_name) print(f"{Colors.GREEN}Session '{session_name}' terminated.{Colors.RESET}") except Exception as e: - print( - f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}" - ) + print(f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}") def send_command(args): @@ -132,13 +130,9 @@ def send_command(args): try: send_input_to_session(session_name, command) - print( - f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}" - ) + print(f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}") except Exception as e: - print( - f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}" - ) + print(f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}") def show_session_log(args): @@ -220,9 +214,7 @@ def list_waiting_sessions(args=None): waiting_sessions.append(session_name) if not waiting_sessions: - print( - f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}" - ) + print(f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}") return print(f"{Colors.BOLD}Sessions waiting for input:{Colors.RESET}") @@ -237,9 +229,7 @@ def list_waiting_sessions(args=None): if session_info: suggestions = detector.get_response_suggestions({}, process_type) if suggestions: - print( - f" Suggested inputs: {', '.join(suggestions[:3])}" - ) # Show first 3 + print(f" Suggested inputs: {', '.join(suggestions[:3])}") # Show first 3 print() diff --git a/pr/core/advanced_context.py b/pr/core/advanced_context.py index 1d10154..cedba7e 100644 --- a/pr/core/advanced_context.py +++ b/pr/core/advanced_context.py @@ -40,9 +40,7 @@ class AdvancedContextManager: words = re.findall(r"\b\w+\b", content.lower()) unique_words.update(words) - vocabulary_richness = ( - len(unique_words) / total_length if total_length > 0 else 0 - ) + vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0 # Simple complexity score based on length and richness complexity = min(1.0, (avg_length / 100) + vocabulary_richness) diff --git a/pr/core/api.py b/pr/core/api.py index b52efa4..f65affd 100644 --- a/pr/core/api.py +++ b/pr/core/api.py @@ -9,9 +9,7 @@ from pr.core.context import auto_slim_messages logger = logging.getLogger("pr") -def call_api( - messages, model, api_url, api_key, use_tools, tools_definition, verbose=False -): +def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False): try: messages = auto_slim_messages(messages, verbose=verbose) @@ -65,13 +63,9 @@ def call_api( msg = choice["message"] logger.debug(f"Response role: {msg.get('role', 'N/A')}") if "content" in msg and msg["content"]: - logger.debug( - f"Response content length: {len(msg['content'])} chars" - ) + logger.debug(f"Response content length: {len(msg['content'])} chars") if "tool_calls" in msg: - logger.debug( - f"Response contains {len(msg['tool_calls'])} tool call(s)" - ) + logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)") logger.debug("=== API CALL END ===") return result diff --git a/pr/core/assistant.py b/pr/core/assistant.py index 92257d6..b962c7d 100644 --- a/pr/core/assistant.py +++ b/pr/core/assistant.py @@ -76,9 +76,7 @@ logger = logging.getLogger("pr") logger.setLevel(logging.DEBUG) file_handler = logging.FileHandler(LOG_FILE) -file_handler.setFormatter( - logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") -) +file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) logger.addHandler(file_handler) @@ -93,9 +91,7 @@ class Assistant: if self.debug: console_handler = logging.StreamHandler() console_handler.setLevel(logging.DEBUG) - console_handler.setFormatter( - logging.Formatter("%(levelname)s: %(message)s") - ) + console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) logger.addHandler(console_handler) logger.debug("Debug mode enabled") self.api_key = os.environ.get("OPENROUTER_API_KEY", "") @@ -210,13 +206,9 @@ class Assistant: session_name = event.get("session_name", "unknown") if event_type == "session_started": - print( - f" {Colors.GREEN}✓{Colors.RESET} Session '{session_name}' started" - ) + print(f" {Colors.GREEN}✓{Colors.RESET} Session '{session_name}' started") elif event_type == "session_ended": - print( - f" {Colors.YELLOW}✗{Colors.RESET} Session '{session_name}' ended" - ) + print(f" {Colors.YELLOW}✗{Colors.RESET} Session '{session_name}' ended") elif event_type == "output_received": lines = len(event.get("new_output", {}).get("stdout", [])) print( @@ -241,9 +233,7 @@ class Assistant: except Exception as e: if self.debug: - print( - f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}" - ) + print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}") def execute_tool_calls(self, tool_calls): results = [] @@ -263,14 +253,10 @@ class Assistant: "run_command": lambda **kw: run_command(**kw), "tail_process": lambda **kw: tail_process(**kw), "kill_process": lambda **kw: kill_process(**kw), - "start_interactive_session": lambda **kw: start_interactive_session( - **kw - ), + "start_interactive_session": lambda **kw: start_interactive_session(**kw), "send_input_to_session": lambda **kw: send_input_to_session(**kw), "read_session_output": lambda **kw: read_session_output(**kw), - "close_interactive_session": lambda **kw: close_interactive_session( - **kw - ), + "close_interactive_session": lambda **kw: close_interactive_session(**kw), "read_file": lambda **kw: read_file(**kw, db_conn=self.db_conn), "write_file": lambda **kw: write_file(**kw, db_conn=self.db_conn), "list_directory": lambda **kw: list_directory(**kw), @@ -286,9 +272,7 @@ class Assistant: **kw, python_globals=self.python_globals ), "index_source_directory": lambda **kw: index_source_directory(**kw), - "search_replace": lambda **kw: search_replace( - **kw, db_conn=self.db_conn - ), + "search_replace": lambda **kw: search_replace(**kw, db_conn=self.db_conn), "open_editor": lambda **kw: open_editor(**kw), "editor_insert_text": lambda **kw: editor_insert_text( **kw, db_conn=self.db_conn @@ -304,15 +288,11 @@ class Assistant: "display_edit_summary": lambda **kw: display_edit_summary(), "display_edit_timeline": lambda **kw: display_edit_timeline(**kw), "clear_edit_tracker": lambda **kw: clear_edit_tracker(), - "start_interactive_session": lambda **kw: start_interactive_session( - **kw - ), + "start_interactive_session": lambda **kw: start_interactive_session(**kw), "send_input_to_session": lambda **kw: send_input_to_session(**kw), "read_session_output": lambda **kw: read_session_output(**kw), "list_active_sessions": lambda **kw: list_active_sessions(**kw), - "close_interactive_session": lambda **kw: close_interactive_session( - **kw - ), + "close_interactive_session": lambda **kw: close_interactive_session(**kw), "create_agent": lambda **kw: create_agent(**kw), "list_agents": lambda **kw: list_agents(**kw), "execute_agent_task": lambda **kw: execute_agent_task(**kw), @@ -321,16 +301,10 @@ class Assistant: "add_knowledge_entry": lambda **kw: add_knowledge_entry(**kw), "get_knowledge_entry": lambda **kw: get_knowledge_entry(**kw), "search_knowledge": lambda **kw: search_knowledge(**kw), - "get_knowledge_by_category": lambda **kw: get_knowledge_by_category( - **kw - ), - "update_knowledge_importance": lambda **kw: update_knowledge_importance( - **kw - ), + "get_knowledge_by_category": lambda **kw: get_knowledge_by_category(**kw), + "update_knowledge_importance": lambda **kw: update_knowledge_importance(**kw), "delete_knowledge_entry": lambda **kw: delete_knowledge_entry(**kw), - "get_knowledge_statistics": lambda **kw: get_knowledge_statistics( - **kw - ), + "get_knowledge_statistics": lambda **kw: get_knowledge_statistics(**kw), } if func_name in func_map: @@ -356,9 +330,7 @@ class Assistant: { "tool_call_id": tool_id, "role": "tool", - "content": json.dumps( - {"status": "error", "error": error_msg} - ), + "content": json.dumps({"status": "error", "error": error_msg}), } ) @@ -405,9 +377,7 @@ class Assistant: self.autonomous_mode = False sys.exit(0) else: - print( - f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}" - ) + print(f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}") return self.interrupt_count += 1 diff --git a/pr/core/autonomous_interactions.py b/pr/core/autonomous_interactions.py index 067a432..6074cd2 100644 --- a/pr/core/autonomous_interactions.py +++ b/pr/core/autonomous_interactions.py @@ -21,9 +21,7 @@ class AutonomousInteractions: self.llm_callback = llm_callback if self.interaction_thread is None: self.active = True - self.interaction_thread = threading.Thread( - target=self._interaction_loop, daemon=True - ) + self.interaction_thread = threading.Thread(target=self._interaction_loop, daemon=True) self.interaction_thread.start() def stop(self): @@ -55,9 +53,7 @@ class AutonomousInteractions: if not sessions: return # No active sessions - sessions_needing_attention = self._identify_sessions_needing_attention( - sessions - ) + sessions_needing_attention = self._identify_sessions_needing_attention(sessions) if sessions_needing_attention and self.llm_callback: # Format session updates for LLM @@ -84,9 +80,7 @@ class AutonomousInteractions: continue # 2. High output volume (potential completion or error) - total_lines = ( - output_summary["stdout_lines"] + output_summary["stderr_lines"] - ) + total_lines = output_summary["stdout_lines"] + output_summary["stderr_lines"] if total_lines > 50: # Arbitrary threshold needing_attention.append(session_name) continue diff --git a/pr/core/background_monitor.py b/pr/core/background_monitor.py index f898b99..b30de63 100644 --- a/pr/core/background_monitor.py +++ b/pr/core/background_monitor.py @@ -18,9 +18,7 @@ class BackgroundMonitor: """Start the background monitoring thread.""" if self.monitor_thread is None: self.active = True - self.monitor_thread = threading.Thread( - target=self._monitor_loop, daemon=True - ) + self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) self.monitor_thread.start() def stop(self): @@ -105,21 +103,14 @@ class BackgroundMonitor: old_stderr_lines = old_state["output_summary"]["stderr_lines"] new_stderr_lines = new_state["output_summary"]["stderr_lines"] - if ( - new_stdout_lines > old_stdout_lines - or new_stderr_lines > old_stderr_lines - ): + if new_stdout_lines > old_stdout_lines or new_stderr_lines > old_stderr_lines: # Get the new output mux = get_multiplexer(session_name) if mux: all_output = mux.get_all_output() new_output = { - "stdout": all_output["stdout"].split("\n")[ - old_stdout_lines: - ], - "stderr": all_output["stderr"].split("\n")[ - old_stderr_lines: - ], + "stdout": all_output["stdout"].split("\n")[old_stdout_lines:], + "stderr": all_output["stderr"].split("\n")[old_stderr_lines:], } events.append( @@ -167,9 +158,7 @@ class BackgroundMonitor: output_summary = state["output_summary"] # Heuristic: High output volume might indicate completion or error - total_lines = ( - output_summary["stdout_lines"] + output_summary["stderr_lines"] - ) + total_lines = output_summary["stdout_lines"] + output_summary["stderr_lines"] if total_lines > 100: # Arbitrary threshold events.append( { @@ -193,9 +182,7 @@ class BackgroundMonitor: # Heuristic: Sessions that might be waiting for input # This would be enhanced with prompt detection in later phases if self._might_be_waiting_for_input(session_name, state): - events.append( - {"type": "possible_input_needed", "session_name": session_name} - ) + events.append({"type": "possible_input_needed", "session_name": session_name}) return events diff --git a/pr/core/context.py b/pr/core/context.py index 4cbc04f..c625c86 100644 --- a/pr/core/context.py +++ b/pr/core/context.py @@ -41,15 +41,11 @@ def truncate_tool_result(result, max_length=None): if "data" in result_copy and isinstance(result_copy["data"], str): if len(result_copy["data"]) > max_length: - result_copy["data"] = ( - result_copy["data"][:max_length] + f"\n... [truncated]" - ) + result_copy["data"] = result_copy["data"][:max_length] + f"\n... [truncated]" if "error" in result_copy and isinstance(result_copy["error"], str): if len(result_copy["error"]) > max_length // 2: - result_copy["error"] = ( - result_copy["error"][: max_length // 2] + "... [truncated]" - ) + result_copy["error"] = result_copy["error"][: max_length // 2] + "... [truncated]" return result_copy @@ -111,9 +107,7 @@ Shell Commands: system_message = "\n\n".join(context_parts) if len(system_message) > max_context_size * 3: - system_message = ( - system_message[: max_context_size * 3] + "\n... [system message truncated]" - ) + system_message = system_message[: max_context_size * 3] + "\n... [system message truncated]" return {"role": "system", "content": system_message} @@ -198,18 +192,14 @@ def trim_message_content(message, max_length): if isinstance(content, str) and len(content) > max_length: trimmed_msg["content"] = ( - content[:max_length] - + f"\n... [trimmed {len(content) - max_length} chars]" + content[:max_length] + f"\n... [trimmed {len(content) - max_length} chars]" ) elif isinstance(content, list): trimmed_content = [] for item in content: if isinstance(item, dict): trimmed_item = item.copy() - if ( - "text" in trimmed_item - and len(trimmed_item["text"]) > max_length - ): + if "text" in trimmed_item and len(trimmed_item["text"]) > max_length: trimmed_item["text"] = ( trimmed_item["text"][:max_length] + f"\n... [trimmed]" ) @@ -236,8 +226,7 @@ def trim_message_content(message, max_length): and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2 ): parsed["output"] = ( - parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2] - + f"\n... [truncated]" + parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]" ) if ( "content" in parsed @@ -245,8 +234,7 @@ def trim_message_content(message, max_length): and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2 ): parsed["content"] = ( - parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2] - + f"\n... [truncated]" + parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]" ) trimmed_msg["content"] = json.dumps(parsed) except: @@ -259,24 +247,18 @@ def intelligently_trim_messages(messages, target_tokens, keep_recent=3): if estimate_tokens(messages) <= target_tokens: return messages - system_msg = ( - messages[0] if messages and messages[0].get("role") == "system" else None - ) + system_msg = messages[0] if messages and messages[0].get("role") == "system" else None start_idx = 1 if system_msg else 0 recent_messages = ( messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:] ) - middle_messages = ( - messages[start_idx:-keep_recent] if len(messages) > keep_recent else [] - ) + middle_messages = messages[start_idx:-keep_recent] if len(messages) > keep_recent else [] trimmed_middle = [] for msg in middle_messages: if msg.get("role") == "tool": - trimmed_middle.append( - trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2) - ) + trimmed_middle.append(trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2)) elif msg.get("role") in ["user", "assistant"]: trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH)) else: @@ -313,9 +295,7 @@ def auto_slim_messages(messages, verbose=False): print( f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}" ) - print( - f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}" - ) + print(f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}") result = intelligently_trim_messages( messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP @@ -339,17 +319,13 @@ def auto_slim_messages(messages, verbose=False): f"{Colors.GREEN} Token estimate: {estimated_tokens} → {final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}" ) if removed_count > 0: - print( - f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}" - ) + print(f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}") return result def emergency_reduce_messages(messages, target_tokens, verbose=False): - system_msg = ( - messages[0] if messages and messages[0].get("role") == "system" else None - ) + system_msg = messages[0] if messages and messages[0].get("role") == "system" else None start_idx = 1 if system_msg else 0 keep_count = 2 diff --git a/pr/core/enhanced_assistant.py b/pr/core/enhanced_assistant.py index c36e732..d679df3 100644 --- a/pr/core/enhanced_assistant.py +++ b/pr/core/enhanced_assistant.py @@ -63,9 +63,7 @@ class EnhancedAssistant: logger.info("Enhanced Assistant initialized with all features") - def _execute_tool_for_workflow( - self, tool_name: str, arguments: Dict[str, Any] - ) -> Any: + def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any: if self.tool_cache: cached_result = self.tool_cache.get(tool_name, arguments) if cached_result is not None: @@ -119,9 +117,7 @@ class EnhancedAssistant: if self.tool_cache: content = result.get("content", "") try: - parsed_content = ( - json.loads(content) if isinstance(content, str) else content - ) + parsed_content = json.loads(content) if isinstance(content, str) else content self.tool_cache.set(tool_name, arguments, parsed_content) except Exception: pass @@ -164,9 +160,7 @@ class EnhancedAssistant: if self.api_cache and CACHE_ENABLED and "error" not in response: token_count = response.get("usage", {}).get("total_tokens", 0) - self.api_cache.set( - self.base.model, messages, 0.7, 4096, response, token_count - ) + self.api_cache.set(self.base.model, messages, 0.7, 4096, response, token_count) return response @@ -197,10 +191,8 @@ class EnhancedAssistant: self.knowledge_store.add_entry(entry) if self.context_manager and ADVANCED_CONTEXT_ENABLED: - enhanced_messages, context_info = ( - self.context_manager.create_enhanced_context( - self.base.messages, user_message, include_knowledge=True - ) + enhanced_messages, context_info = self.context_manager.create_enhanced_context( + self.base.messages, user_message, include_knowledge=True ) if self.base.verbose: @@ -261,9 +253,7 @@ class EnhancedAssistant: orchestrator_id = self.agent_manager.create_agent("orchestrator") return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles) - def search_knowledge( - self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT - ) -> List[Any]: + def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]: return self.knowledge_store.search_entries(query, top_k=limit) def get_cache_statistics(self) -> Dict[str, Any]: diff --git a/pr/core/logging.py b/pr/core/logging.py index 4df2153..d692e60 100644 --- a/pr/core/logging.py +++ b/pr/core/logging.py @@ -16,9 +16,7 @@ def setup_logging(verbose=False): if logger.handlers: logger.handlers.clear() - file_handler = RotatingFileHandler( - LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5 - ) + file_handler = RotatingFileHandler(LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5) file_handler.setLevel(logging.DEBUG) file_formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", diff --git a/pr/core/usage_tracker.py b/pr/core/usage_tracker.py index d47eaae..7223263 100644 --- a/pr/core/usage_tracker.py +++ b/pr/core/usage_tracker.py @@ -64,13 +64,9 @@ class UsageTracker: self._save_to_history(model, input_tokens, output_tokens, cost) - logger.debug( - f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}" - ) + logger.debug(f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}") - def _calculate_cost( - self, model: str, input_tokens: int, output_tokens: int - ) -> float: + def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float: if model not in MODEL_COSTS: base_model = model.split("/")[0] if "/" in model else model if base_model not in MODEL_COSTS: @@ -85,9 +81,7 @@ class UsageTracker: return input_cost + output_cost - def _save_to_history( - self, model: str, input_tokens: int, output_tokens: int, cost: float - ): + def _save_to_history(self, model: str, input_tokens: int, output_tokens: int, cost: float): try: history = [] if os.path.exists(USAGE_DB_FILE): diff --git a/pr/core/validation.py b/pr/core/validation.py index 02c5301..b675765 100644 --- a/pr/core/validation.py +++ b/pr/core/validation.py @@ -16,9 +16,7 @@ def validate_file_path(path: str, must_exist: bool = False) -> str: return os.path.abspath(path) -def validate_directory_path( - path: str, must_exist: bool = False, create: bool = False -) -> str: +def validate_directory_path(path: str, must_exist: bool = False, create: bool = False) -> str: if not path: raise ValidationError("Directory path cannot be empty") diff --git a/pr/editor.py b/pr/editor.py index 5fbb45d..dfd0039 100644 --- a/pr/editor.py +++ b/pr/editor.py @@ -179,9 +179,7 @@ class RPEditor: try: self.running = True - self.socket_thread = threading.Thread( - target=self.socket_listener, daemon=True - ) + self.socket_thread = threading.Thread(target=self.socket_listener, daemon=True) self.socket_thread.start() self.thread = threading.Thread(target=self.run, daemon=True) self.thread.start() @@ -538,9 +536,7 @@ class RPEditor: self.cursor_y = len(self.lines) - 1 line = self.lines[self.cursor_y] - self.lines[self.cursor_y] = ( - line[: self.cursor_x] + text + line[self.cursor_x :] - ) + self.lines[self.cursor_y] = line[: self.cursor_x] + text + line[self.cursor_x :] self.cursor_x += len(text) else: # Multi-line insert @@ -562,9 +558,7 @@ class RPEditor: def insert_text(self, text): """Thread-safe text insertion.""" try: - self.client_sock.send( - pickle.dumps({"command": "insert_text", "text": text}) - ) + self.client_sock.send(pickle.dumps({"command": "insert_text", "text": text})) except: with self.lock: self._insert_text(text) @@ -572,13 +566,9 @@ class RPEditor: def _delete_char(self): """Delete character at cursor.""" self.save_state() - if self.cursor_y < len(self.lines) and self.cursor_x < len( - self.lines[self.cursor_y] - ): + if self.cursor_y < len(self.lines) and self.cursor_x < len(self.lines[self.cursor_y]): line = self.lines[self.cursor_y] - self.lines[self.cursor_y] = ( - line[: self.cursor_x] + line[self.cursor_x + 1 :] - ) + self.lines[self.cursor_y] = line[: self.cursor_x] + line[self.cursor_x + 1 :] def delete_char(self): """Thread-safe character deletion.""" @@ -614,9 +604,7 @@ class RPEditor: """Handle backspace key.""" if self.cursor_x > 0: line = self.lines[self.cursor_y] - self.lines[self.cursor_y] = ( - line[: self.cursor_x - 1] + line[self.cursor_x :] - ) + self.lines[self.cursor_y] = line[: self.cursor_x - 1] + line[self.cursor_x :] self.cursor_x -= 1 elif self.cursor_y > 0: prev_len = len(self.lines[self.cursor_y - 1]) @@ -668,9 +656,7 @@ class RPEditor: def goto_line(self, line_num): """Thread-safe goto line.""" try: - self.client_sock.send( - pickle.dumps({"command": "goto_line", "line_num": line_num}) - ) + self.client_sock.send(pickle.dumps({"command": "goto_line", "line_num": line_num})) except: with self.lock: self._goto_line(line_num) diff --git a/pr/editor2.py b/pr/editor2.py index 4ad4506..d2b2306 100644 --- a/pr/editor2.py +++ b/pr/editor2.py @@ -211,9 +211,7 @@ class RPEditor: self.running = False elif cmd == "w": self._save_file() - elif ( - cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!" - ): + elif cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!": self._save_file() self.running = False elif cmd.startswith("w "): @@ -371,9 +369,7 @@ class RPEditor: self.cursor_x = 0 def goto_line(self, line_num): - self.client_sock.send( - pickle.dumps({"command": "goto_line", "line_num": line_num}) - ) + self.client_sock.send(pickle.dumps({"command": "goto_line", "line_num": line_num})) def get_text(self): self.client_sock.send(pickle.dumps({"command": "get_text"})) diff --git a/pr/input_handler.py b/pr/input_handler.py index 96f2b67..71a8e2a 100644 --- a/pr/input_handler.py +++ b/pr/input_handler.py @@ -73,9 +73,7 @@ class AdvancedInputHandler: def _get_editor_input(self, prompt: str) -> Optional[str]: """Get multi-line input for editor mode.""" try: - print( - "Editor mode: Enter your message. Type 'END' on a new line to finish." - ) + print("Editor mode: Enter your message. Type 'END' on a new line to finish.") print("Type '/simple' to switch back to simple mode.") lines = [] diff --git a/pr/memory/fact_extractor.py b/pr/memory/fact_extractor.py index 7fec5eb..ff28d52 100644 --- a/pr/memory/fact_extractor.py +++ b/pr/memory/fact_extractor.py @@ -169,9 +169,7 @@ class FactExtractor: sentence_count = len(re.split(r"[.!?]", text)) urls = re.findall(r"https?://[^\s]+", text) - email_addresses = re.findall( - r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text - ) + email_addresses = re.findall(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text) dates = re.findall( r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text ) diff --git a/pr/memory/semantic_index.py b/pr/memory/semantic_index.py index db9534c..bceedb5 100644 --- a/pr/memory/semantic_index.py +++ b/pr/memory/semantic_index.py @@ -38,8 +38,7 @@ class SemanticIndex: self.idf_scores = {token: 1.0 for token in token_doc_count} else: self.idf_scores = { - token: math.log(doc_count / count) - for token, count in token_doc_count.items() + token: math.log(doc_count / count) for token, count in token_doc_count.items() } def add_document(self, doc_id: str, text: str): @@ -51,8 +50,7 @@ class SemanticIndex: tf_scores = self._compute_tf(tokens) self.doc_vectors[doc_id] = { - token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0) - for token in tokens + token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0) for token in tokens } def remove_document(self, doc_id: str): @@ -67,8 +65,7 @@ class SemanticIndex: query_tf = self._compute_tf(query_tokens) query_vector = { - token: query_tf.get(token, 0) * self.idf_scores.get(token, 0) - for token in query_tokens + token: query_tf.get(token, 0) * self.idf_scores.get(token, 0) for token in query_tokens } scores = [] @@ -79,9 +76,7 @@ class SemanticIndex: scores.sort(key=lambda x: x[1], reverse=True) return scores[:top_k] - def _cosine_similarity( - self, vec1: Dict[str, float], vec2: Dict[str, float] - ) -> float: + def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float: dot_product = sum( vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2) ) diff --git a/pr/multiplexer.py b/pr/multiplexer.py index e924631..271d365 100644 --- a/pr/multiplexer.py +++ b/pr/multiplexer.py @@ -30,9 +30,7 @@ class TerminalMultiplexer: self.prompt_detector = get_global_detector() if self.show_output: - self.display_thread = threading.Thread( - target=self._display_worker, daemon=True - ) + self.display_thread = threading.Thread(target=self._display_worker, daemon=True) self.display_thread.start() def _display_worker(self): @@ -48,9 +46,7 @@ class TerminalMultiplexer: try: line = self.stderr_queue.get(timeout=0.1) if line: - sys.stderr.write( - f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}" - ) + sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}") sys.stderr.flush() except queue.Empty: pass diff --git a/pr/plugins/loader.py b/pr/plugins/loader.py index 4c5babf..fa70374 100644 --- a/pr/plugins/loader.py +++ b/pr/plugins/loader.py @@ -52,13 +52,9 @@ class PluginLoader: self.loaded_plugins[plugin_name] = module logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)") else: - logger.warning( - f"Plugin {plugin_name} register_tools() did not return a list" - ) + logger.warning(f"Plugin {plugin_name} register_tools() did not return a list") else: - logger.warning( - f"Plugin {plugin_name} does not have register_tools() function" - ) + logger.warning(f"Plugin {plugin_name} does not have register_tools() function") def get_plugin_function(self, tool_name: str) -> Callable: for plugin_name, module in self.loaded_plugins.items(): diff --git a/pr/tools/agents.py b/pr/tools/agents.py index 1bc02c5..cbf39d1 100644 --- a/pr/tools/agents.py +++ b/pr/tools/agents.py @@ -39,9 +39,7 @@ def list_agents() -> Dict[str, Any]: return {"status": "error", "error": str(e)} -def execute_agent_task( - agent_id: str, task: str, context: Dict[str, Any] = None -) -> Dict[str, Any]: +def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]: """Execute a task with the specified agent.""" try: db_path = os.path.expanduser("~/.assistant_db.sqlite") @@ -63,9 +61,7 @@ def remove_agent(agent_id: str) -> Dict[str, Any]: return {"status": "error", "error": str(e)} -def collaborate_agents( - orchestrator_id: str, task: str, agent_roles: List[str] -) -> Dict[str, Any]: +def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]: """Collaborate multiple agents on a task.""" try: db_path = os.path.expanduser("~/.assistant_db.sqlite") diff --git a/pr/tools/base.py b/pr/tools/base.py index 73f613b..2d9d962 100644 --- a/pr/tools/base.py +++ b/pr/tools/base.py @@ -235,9 +235,7 @@ def get_tools_definition(): "description": "Change the current working directory", "parameters": { "type": "object", - "properties": { - "path": {"type": "string", "description": "Path to change to"} - }, + "properties": {"path": {"type": "string", "description": "Path to change to"}}, "required": ["path"], }, }, @@ -284,9 +282,7 @@ def get_tools_definition(): "description": "Execute a database query", "parameters": { "type": "object", - "properties": { - "query": {"type": "string", "description": "SQL query"} - }, + "properties": {"query": {"type": "string", "description": "SQL query"}}, "required": ["query"], }, }, @@ -298,9 +294,7 @@ def get_tools_definition(): "description": "Perform a web search", "parameters": { "type": "object", - "properties": { - "query": {"type": "string", "description": "Search query"} - }, + "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"], }, }, @@ -346,9 +340,7 @@ def get_tools_definition(): "description": "Index directory recursively and read all source files.", "parameters": { "type": "object", - "properties": { - "path": {"type": "string", "description": "Path to index"} - }, + "properties": {"path": {"type": "string", "description": "Path to index"}}, "required": ["path"], }, }, diff --git a/pr/tools/command.py b/pr/tools/command.py index 0e0def1..ed54cde 100644 --- a/pr/tools/command.py +++ b/pr/tools/command.py @@ -81,9 +81,7 @@ def tail_process(pid: int, timeout: int = 30): "pid": pid, } - ready, _, _ = select.select( - [process.stdout, process.stderr], [], [], 0.1 - ) + ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1) for pipe in ready: if pipe == process.stdout: line = process.stdout.readline() diff --git a/pr/tools/database.py b/pr/tools/database.py index 6de4d57..d1f6426 100644 --- a/pr/tools/database.py +++ b/pr/tools/database.py @@ -44,9 +44,7 @@ def db_query(query, db_conn): if query.strip().upper().startswith("SELECT"): results = cursor.fetchall() - columns = ( - [desc[0] for desc in cursor.description] if cursor.description else [] - ) + columns = [desc[0] for desc in cursor.description] if cursor.description else [] return {"status": "success", "columns": columns, "rows": results} else: db_conn.commit() diff --git a/pr/tools/editor.py b/pr/tools/editor.py index dad2112..fc8135b 100644 --- a/pr/tools/editor.py +++ b/pr/tools/editor.py @@ -61,9 +61,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True): with open(path) as f: old_content = f.read() - position = (line if line is not None else 0) * 1000 + ( - col if col is not None else 0 - ) + position = (line if line is not None else 0) * 1000 + (col if col is not None else 0) operation = track_edit("INSERT", filepath, start_pos=position, content=text) tracker.mark_in_progress(operation) @@ -76,11 +74,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True): mux_name = f"editor-{path}" mux = get_multiplexer(mux_name) if mux: - location = ( - f" at line {line}, col {col}" - if line is not None and col is not None - else "" - ) + location = f" at line {line}, col {col}" if line is not None and col is not None else "" preview = text[:50] + "..." if len(text) > 50 else text mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n") diff --git a/pr/tools/filesystem.py b/pr/tools/filesystem.py index a51f58c..fcab959 100644 --- a/pr/tools/filesystem.py +++ b/pr/tools/filesystem.py @@ -41,10 +41,7 @@ def write_file(filepath, content, db_conn=None, show_diff=True): from pr.tools.database import db_get read_status = db_get("read:" + path, db_conn) - if ( - read_status.get("status") != "success" - or read_status.get("value") != "true" - ): + if read_status.get("status") != "success" or read_status.get("value") != "true": return { "status": "error", "error": "File must be read before writing. Please read the file first.", @@ -54,9 +51,7 @@ def write_file(filepath, content, db_conn=None, show_diff=True): with open(path) as f: old_content = f.read() - operation = track_edit( - "WRITE", filepath, content=content, old_content=old_content - ) + operation = track_edit("WRITE", filepath, content=content, old_content=old_content) tracker.mark_in_progress(operation) if show_diff and not is_new_file: @@ -119,9 +114,7 @@ def list_directory(path=".", recursive=False): } ) for name in dirs: - items.append( - {"path": os.path.join(root, name), "type": "directory"} - ) + items.append({"path": os.path.join(root, name), "type": "directory"}) else: for item in os.listdir(path): item_path = os.path.join(path, item) @@ -129,11 +122,7 @@ def list_directory(path=".", recursive=False): { "name": item, "type": "directory" if os.path.isdir(item_path) else "file", - "size": ( - os.path.getsize(item_path) - if os.path.isfile(item_path) - else None - ), + "size": (os.path.getsize(item_path) if os.path.isfile(item_path) else None), } ) return {"status": "success", "items": items} @@ -209,10 +198,7 @@ def search_replace(filepath, old_string, new_string, db_conn=None): from pr.tools.database import db_get read_status = db_get("read:" + path, db_conn) - if ( - read_status.get("status") != "success" - or read_status.get("value") != "true" - ): + if read_status.get("status") != "success" or read_status.get("value") != "true": return { "status": "error", "error": "File must be read before writing. Please read the file first.", @@ -259,19 +245,14 @@ def open_editor(filepath): return {"status": "error", "error": str(e)} -def editor_insert_text( - filepath, text, line=None, col=None, show_diff=True, db_conn=None -): +def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None): try: path = os.path.expanduser(filepath) if db_conn: from pr.tools.database import db_get read_status = db_get("read:" + path, db_conn) - if ( - read_status.get("status") != "success" - or read_status.get("value") != "true" - ): + if read_status.get("status") != "success" or read_status.get("value") != "true": return { "status": "error", "error": "File must be read before writing. Please read the file first.", @@ -282,9 +263,7 @@ def editor_insert_text( with open(path) as f: old_content = f.read() - position = (line if line is not None else 0) * 1000 + ( - col if col is not None else 0 - ) + position = (line if line is not None else 0) * 1000 + (col if col is not None else 0) operation = track_edit("INSERT", filepath, start_pos=position, content=text) tracker.mark_in_progress(operation) @@ -325,10 +304,7 @@ def editor_replace_text( from pr.tools.database import db_get read_status = db_get("read:" + path, db_conn) - if ( - read_status.get("status") != "success" - or read_status.get("value") != "true" - ): + if read_status.get("status") != "success" or read_status.get("value") != "true": return { "status": "error", "error": "File must be read before writing. Please read the file first.", diff --git a/pr/tools/memory.py b/pr/tools/memory.py index d169df1..ddef8c4 100644 --- a/pr/tools/memory.py +++ b/pr/tools/memory.py @@ -47,9 +47,7 @@ def get_knowledge_entry(entry_id: str) -> Dict[str, Any]: return {"status": "error", "error": str(e)} -def search_knowledge( - query: str, category: str = None, top_k: int = 5 -) -> Dict[str, Any]: +def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]: """Search the knowledge base semantically.""" try: db_path = os.path.expanduser("~/.assistant_db.sqlite") @@ -75,9 +73,7 @@ def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]: return {"status": "error", "error": str(e)} -def update_knowledge_importance( - entry_id: str, importance_score: float -) -> Dict[str, Any]: +def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]: """Update the importance score of a knowledge entry.""" try: db_path = os.path.expanduser("~/.assistant_db.sqlite") diff --git a/pr/tools/patch.py b/pr/tools/patch.py index ab40ecc..98a6d22 100644 --- a/pr/tools/patch.py +++ b/pr/tools/patch.py @@ -13,10 +13,7 @@ def apply_patch(filepath, patch_content, db_conn=None): from pr.tools.database import db_get read_status = db_get("read:" + path, db_conn) - if ( - read_status.get("status") != "success" - or read_status.get("value") != "true" - ): + if read_status.get("status") != "success" or read_status.get("value") != "true": return { "status": "error", "error": "File must be read before writing. Please read the file first.", @@ -68,9 +65,7 @@ def create_diff( else: lines1 = content1.splitlines(keepends=True) lines2 = content2.splitlines(keepends=True) - diff = list( - difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile) - ) + diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)) return {"status": "success", "diff": "".join(diff)} except Exception as e: return {"status": "error", "error": str(e)} @@ -94,9 +89,7 @@ def display_file_diff(filepath1, filepath2, format_type="unified", context_lines return {"status": "error", "error": str(e)} -def display_content_diff( - old_content, new_content, filename="file", format_type="unified" -): +def display_content_diff(old_content, new_content, filename="file", format_type="unified"): try: visual_diff = display_diff(old_content, new_content, filename, format_type) stats = get_diff_stats(old_content, new_content) diff --git a/pr/tools/prompt_detection.py b/pr/tools/prompt_detection.py index 022366e..79ee168 100644 --- a/pr/tools/prompt_detection.py +++ b/pr/tools/prompt_detection.py @@ -187,9 +187,7 @@ class PromptDetector: # Detect prompts and determine new state detections = self.detect_prompt(output, process_type) - new_state = self._determine_state_from_detections( - detections, process_type, old_state - ) + new_state = self._determine_state_from_detections(detections, process_type, old_state) if new_state != old_state: session_state["transitions"].append( diff --git a/pr/ui/diff_display.py b/pr/ui/diff_display.py index 2e83072..1aca9af 100644 --- a/pr/ui/diff_display.py +++ b/pr/ui/diff_display.py @@ -92,15 +92,11 @@ class DiffDisplay: old_line_num, new_line_num = self._parse_hunk_header(line) elif line.startswith("+"): stats.insertions += 1 - diff_lines.append( - DiffLine("add", line[1:].rstrip(), None, new_line_num) - ) + diff_lines.append(DiffLine("add", line[1:].rstrip(), None, new_line_num)) new_line_num += 1 elif line.startswith("-"): stats.deletions += 1 - diff_lines.append( - DiffLine("delete", line[1:].rstrip(), old_line_num, None) - ) + diff_lines.append(DiffLine("delete", line[1:].rstrip(), old_line_num, None)) old_line_num += 1 elif line.startswith(" "): diff_lines.append( @@ -182,14 +178,10 @@ class DiffDisplay: for tag, i1, i2, j1, j2 in matcher.get_opcodes(): if tag == "equal": - for i, (old_line, new_line) in enumerate( - zip(old_lines[i1:i2], new_lines[j1:j2]) - ): + for i, (old_line, new_line) in enumerate(zip(old_lines[i1:i2], new_lines[j1:j2])): old_display = old_line[:half_width].ljust(half_width) new_display = new_line[:half_width].ljust(half_width) - output.append( - f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}" - ) + output.append(f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}") elif tag == "replace": max_lines = max(i2 - i1, j2 - j1) for i in range(max_lines): @@ -203,15 +195,11 @@ class DiffDisplay: elif tag == "delete": for old_line in old_lines[i1:i2]: old_display = old_line[:half_width].ljust(half_width) - output.append( - f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}" - ) + output.append(f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}") elif tag == "insert": for new_line in new_lines[j1:j2]: new_display = new_line[:half_width].ljust(half_width) - output.append( - f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}" - ) + output.append(f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n") return "\n".join(output) diff --git a/pr/ui/display.py b/pr/ui/display.py index a3aa80c..9936548 100644 --- a/pr/ui/display.py +++ b/pr/ui/display.py @@ -16,8 +16,6 @@ def display_tool_call(tool_name, arguments, status="running", result=None): def print_autonomous_header(task): print(f"{Colors.BOLD}Task:{Colors.RESET} {task}") - print( - f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}" - ) + print(f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}") print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n") print(f"{Colors.BOLD}{'═' * 80}{Colors.RESET}\n") diff --git a/pr/ui/edit_feedback.py b/pr/ui/edit_feedback.py index d1de847..536ded4 100644 --- a/pr/ui/edit_feedback.py +++ b/pr/ui/edit_feedback.py @@ -46,23 +46,17 @@ class EditOperation: output = [self.format_operation()] if self.op_type in ("INSERT", "REPLACE"): - output.append( - f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}" - ) + output.append(f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}") if show_content: if self.old_content: lines = self.old_content.split("\n") - preview = lines[0][:60] + ( - "..." if len(lines[0]) > 60 or len(lines) > 1 else "" - ) + preview = lines[0][:60] + ("..." if len(lines[0]) > 60 or len(lines) > 1 else "") output.append(f" {Colors.RED}- {preview}{Colors.RESET}") if self.content: lines = self.content.split("\n") - preview = lines[0][:60] + ( - "..." if len(lines[0]) > 60 or len(lines) > 1 else "" - ) + preview = lines[0][:60] + ("..." if len(lines[0]) > 60 or len(lines) > 1 else "") output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}") return "\n".join(output) @@ -93,9 +87,7 @@ class EditTracker: "total": len(self.operations), "completed": sum(1 for op in self.operations if op.status == "completed"), "pending": sum(1 for op in self.operations if op.status == "pending"), - "in_progress": sum( - 1 for op in self.operations if op.status == "in_progress" - ), + "in_progress": sum(1 for op in self.operations if op.status == "in_progress"), "failed": sum(1 for op in self.operations if op.status == "failed"), } return stats @@ -112,9 +104,7 @@ class EditTracker: output = [] output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}") - output.append( - f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}" - ) + output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}") output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") stats = self.get_stats() diff --git a/pr/ui/progress.py b/pr/ui/progress.py index 5b3df23..1e0db9e 100644 --- a/pr/ui/progress.py +++ b/pr/ui/progress.py @@ -62,16 +62,10 @@ class ProgressBar: else: percent = int((self.current / self.total) * 100) - filled = ( - int((self.current / self.total) * self.width) - if self.total > 0 - else self.width - ) + filled = int((self.current / self.total) * self.width) if self.total > 0 else self.width bar = "█" * filled + "░" * (self.width - filled) - sys.stdout.write( - f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})" - ) + sys.stdout.write(f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})") sys.stdout.flush() if self.current >= self.total: diff --git a/pr/ui/rendering.py b/pr/ui/rendering.py index 72f0613..c77071c 100644 --- a/pr/ui/rendering.py +++ b/pr/ui/rendering.py @@ -25,12 +25,8 @@ def highlight_code(code, language=None, syntax_highlighting=True): code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code) code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code) - code = re.sub( - r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE - ) - code = re.sub( - r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE - ) + code = re.sub(r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE) + code = re.sub(r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE) return code @@ -76,11 +72,15 @@ def render_markdown(text, syntax_highlighting=True): elif re.match(r"^\s*[\*\-\+]\s", line): match = re.match(r"^(\s*)([\*\-\+])(\s+.*)", line) if match: - line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" + line = ( + f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" + ) elif re.match(r"^\s*\d+\.\s", line): match = re.match(r"^(\s*)(\d+\.)(\s+.*)", line) if match: - line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" + line = ( + f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" + ) processed_lines.append(line) text = "\n".join(processed_lines) diff --git a/pr/workflows/workflow_engine.py b/pr/workflows/workflow_engine.py index 66251b7..b0795ee 100644 --- a/pr/workflows/workflow_engine.py +++ b/pr/workflows/workflow_engine.py @@ -40,9 +40,7 @@ class WorkflowEngine: self.tool_executor = tool_executor self.max_workers = max_workers - def _evaluate_condition( - self, condition: str, context: WorkflowExecutionContext - ) -> bool: + def _evaluate_condition(self, condition: str, context: WorkflowExecutionContext) -> bool: if not condition: return True @@ -188,9 +186,7 @@ class WorkflowEngine: result = future.result() context.log_event("step_completed", step.step_id, result) except Exception as e: - context.log_event( - "step_failed", step.step_id, {"error": str(e)} - ) + context.log_event("step_failed", step.step_id, {"error": str(e)}) else: pending_steps = workflow.get_initial_steps() diff --git a/pr/workflows/workflow_storage.py b/pr/workflows/workflow_storage.py index 258ce4f..6389572 100644 --- a/pr/workflows/workflow_storage.py +++ b/pr/workflows/workflow_storage.py @@ -104,9 +104,7 @@ class WorkflowStorage: conn = sqlite3.connect(self.db_path, check_same_thread=False) cursor = conn.cursor() - cursor.execute( - "SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,) - ) + cursor.execute("SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,)) row = cursor.fetchone() conn.close() @@ -174,9 +172,7 @@ class WorkflowStorage: cursor.execute("DELETE FROM workflows WHERE workflow_id = ?", (workflow_id,)) deleted = cursor.rowcount > 0 - cursor.execute( - "DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,) - ) + cursor.execute("DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,)) conn.commit() conn.close() diff --git a/tests/test_api.py b/tests/test_api.py index f67c491..333f158 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -12,7 +12,9 @@ class TestApi(unittest.TestCase): def test_call_api_success(self, mock_slim, mock_urlopen): mock_slim.return_value = [{"role": "user", "content": "test"}] mock_response = MagicMock() - mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}' + mock_response.read.return_value = ( + b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}' + ) mock_urlopen.return_value.__enter__.return_value = mock_response result = call_api([], "model", "http://url", "key", True, [{"name": "tool"}]) diff --git a/tests/test_assistant.py b/tests/test_assistant.py index b3395e9..47f6b7e 100644 --- a/tests/test_assistant.py +++ b/tests/test_assistant.py @@ -73,9 +73,7 @@ class TestAssistant(unittest.TestCase): "choices": [ { "message": { - "tool_calls": [ - {"id": "1", "function": {"name": "test", "arguments": "{}"}} - ] + "tool_calls": [{"id": "1", "function": {"name": "test", "arguments": "{}"}}] } } ] @@ -107,9 +105,7 @@ class TestAssistant(unittest.TestCase): with patch("builtins.print"): process_message(assistant, "test message") - assistant.messages.append.assert_called_with( - {"role": "user", "content": "test message"} - ) + assistant.messages.append.assert_called_with({"role": "user", "content": "test message"}) if __name__ == "__main__": diff --git a/tests/test_main.py b/tests/test_main.py index 7340157..fe87f72 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -93,9 +93,7 @@ def test_main_export_session_md(capsys): def test_main_usage(capsys): usage = {"total_requests": 10, "total_tokens": 1000, "total_cost": 0.01} - with patch( - "pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage - ): + with patch("pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage): with patch("sys.argv", ["pr", "--usage"]): main() captured = capsys.readouterr()