Update.
Some checks failed
Tests / test (macos-latest, 3.10) (push) Waiting to run
Tests / test (macos-latest, 3.11) (push) Waiting to run
Tests / test (macos-latest, 3.12) (push) Waiting to run
Tests / test (macos-latest, 3.8) (push) Waiting to run
Tests / test (macos-latest, 3.9) (push) Waiting to run
Tests / test (windows-latest, 3.10) (push) Waiting to run
Tests / test (windows-latest, 3.11) (push) Waiting to run
Tests / test (windows-latest, 3.12) (push) Waiting to run
Tests / test (windows-latest, 3.8) (push) Waiting to run
Tests / test (windows-latest, 3.9) (push) Waiting to run
Lint / lint (push) Failing after 38s
Tests / test (ubuntu-latest, 3.10) (push) Successful in 51s
Tests / test (ubuntu-latest, 3.11) (push) Successful in 46s
Tests / test (ubuntu-latest, 3.12) (push) Successful in 1m30s
Tests / test (ubuntu-latest, 3.8) (push) Successful in 1m0s
Tests / test (ubuntu-latest, 3.9) (push) Successful in 57s

This commit is contained in:
retoor 2025-11-04 08:10:37 +01:00
parent 1a29ee4918
commit 9a5bf46a54
44 changed files with 163 additions and 469 deletions

View File

@ -32,30 +32,22 @@ Commands in interactive mode:
) )
parser.add_argument("message", nargs="?", help="Message to send to assistant") parser.add_argument("message", nargs="?", help="Message to send to assistant")
parser.add_argument( parser.add_argument("--version", action="version", version=f"PR Assistant {__version__}")
"--version", action="version", version=f"PR Assistant {__version__}"
)
parser.add_argument("-m", "--model", help="AI model to use") parser.add_argument("-m", "--model", help="AI model to use")
parser.add_argument("-u", "--api-url", help="API endpoint URL") parser.add_argument("-u", "--api-url", help="API endpoint URL")
parser.add_argument("--model-list-url", help="Model list endpoint URL") parser.add_argument("--model-list-url", help="Model list endpoint URL")
parser.add_argument( parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode")
"-i", "--interactive", action="store_true", help="Interactive mode"
)
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
parser.add_argument( parser.add_argument(
"--debug", action="store_true", help="Enable debug mode with detailed logging" "--debug", action="store_true", help="Enable debug mode with detailed logging"
) )
parser.add_argument( parser.add_argument("--no-syntax", action="store_true", help="Disable syntax highlighting")
"--no-syntax", action="store_true", help="Disable syntax highlighting"
)
parser.add_argument( parser.add_argument(
"--include-env", "--include-env",
action="store_true", action="store_true",
help="Include environment variables in context", help="Include environment variables in context",
) )
parser.add_argument( parser.add_argument("-c", "--context", action="append", help="Additional context files")
"-c", "--context", action="append", help="Additional context files"
)
parser.add_argument( parser.add_argument(
"--api-mode", action="store_true", help="API mode for specialized interaction" "--api-mode", action="store_true", help="API mode for specialized interaction"
) )
@ -68,18 +60,10 @@ Commands in interactive mode:
) )
parser.add_argument("--quiet", action="store_true", help="Minimal output") parser.add_argument("--quiet", action="store_true", help="Minimal output")
parser.add_argument( parser.add_argument("--save-session", metavar="NAME", help="Save session with given name")
"--save-session", metavar="NAME", help="Save session with given name" parser.add_argument("--load-session", metavar="NAME", help="Load session with given name")
) parser.add_argument("--list-sessions", action="store_true", help="List all saved sessions")
parser.add_argument( parser.add_argument("--delete-session", metavar="NAME", help="Delete a saved session")
"--load-session", metavar="NAME", help="Load session with given name"
)
parser.add_argument(
"--list-sessions", action="store_true", help="List all saved sessions"
)
parser.add_argument(
"--delete-session", metavar="NAME", help="Delete a saved session"
)
parser.add_argument( parser.add_argument(
"--export-session", "--export-session",
nargs=2, nargs=2,
@ -87,9 +71,7 @@ Commands in interactive mode:
help="Export session to file", help="Export session to file",
) )
parser.add_argument( parser.add_argument("--usage", action="store_true", help="Show token usage statistics")
"--usage", action="store_true", help="Show token usage statistics"
)
parser.add_argument( parser.add_argument(
"--create-config", action="store_true", help="Create default configuration file" "--create-config", action="store_true", help="Create default configuration file"
) )

View File

@ -93,9 +93,7 @@ class AgentCommunicationBus:
self.conn.commit() self.conn.commit()
def get_messages( def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
self, agent_id: str, unread_only: bool = True
) -> List[AgentMessage]:
cursor = self.conn.cursor() cursor = self.conn.cursor()
if unread_only: if unread_only:
cursor.execute( cursor.execute(
@ -135,17 +133,13 @@ class AgentCommunicationBus:
def mark_as_read(self, message_id: str): def mark_as_read(self, message_id: str):
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute( cursor.execute("UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,))
"UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,)
)
self.conn.commit() self.conn.commit()
def clear_messages(self, session_id: Optional[str] = None): def clear_messages(self, session_id: Optional[str] = None):
cursor = self.conn.cursor() cursor = self.conn.cursor()
if session_id: if session_id:
cursor.execute( cursor.execute("DELETE FROM agent_messages WHERE session_id = ?", (session_id,))
"DELETE FROM agent_messages WHERE session_id = ?", (session_id,)
)
else: else:
cursor.execute("DELETE FROM agent_messages") cursor.execute("DELETE FROM agent_messages")
self.conn.commit() self.conn.commit()
@ -156,9 +150,7 @@ class AgentCommunicationBus:
def receive_messages(self, agent_id: str) -> List[AgentMessage]: def receive_messages(self, agent_id: str) -> List[AgentMessage]:
return self.get_messages(agent_id, unread_only=True) return self.get_messages(agent_id, unread_only=True)
def get_conversation_history( def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]:
self, agent_a: str, agent_b: str
) -> List[AgentMessage]:
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute( cursor.execute(
""" """

View File

@ -19,17 +19,14 @@ class AgentInstance:
task_count: int = 0 task_count: int = 0
def add_message(self, role: str, content: str): def add_message(self, role: str, content: str):
self.message_history.append( self.message_history.append({"role": role, "content": content, "timestamp": time.time()})
{"role": role, "content": content, "timestamp": time.time()}
)
def get_system_message(self) -> Dict[str, str]: def get_system_message(self) -> Dict[str, str]:
return {"role": "system", "content": self.role.system_prompt} return {"role": "system", "content": self.role.system_prompt}
def get_messages_for_api(self) -> List[Dict[str, str]]: def get_messages_for_api(self) -> List[Dict[str, str]]:
return [self.get_system_message()] + [ return [self.get_system_message()] + [
{"role": msg["role"], "content": msg["content"]} {"role": msg["role"], "content": msg["content"]} for msg in self.message_history
for msg in self.message_history
] ]
@ -128,14 +125,10 @@ class AgentManager:
self.communication_bus.send_message(message, self.session_id) self.communication_bus.send_message(message, self.session_id)
return message.message_id return message.message_id
def get_agent_messages( def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
self, agent_id: str, unread_only: bool = True
) -> List[AgentMessage]:
return self.communication_bus.get_messages(agent_id, unread_only) return self.communication_bus.get_messages(agent_id, unread_only)
def collaborate_agents( def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]):
self, orchestrator_id: str, task: str, agent_roles: List[str]
):
orchestrator = self.get_agent(orchestrator_id) orchestrator = self.get_agent(orchestrator_id)
if not orchestrator: if not orchestrator:
orchestrator_id = self.create_agent("orchestrator") orchestrator_id = self.create_agent("orchestrator")
@ -153,9 +146,7 @@ Available specialized agents:
Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results.""" Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results."""
orchestrator_result = self.execute_agent_task( orchestrator_result = self.execute_agent_task(orchestrator_id, orchestration_prompt)
orchestrator_id, orchestration_prompt
)
results = {"orchestrator": orchestrator_result, "agents": []} results = {"orchestrator": orchestrator_result, "agents": []}

View File

@ -22,22 +22,14 @@ def run_autonomous_mode(assistant, task):
while True: while True:
assistant.autonomous_iterations += 1 assistant.autonomous_iterations += 1
logger.debug( logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---")
f"--- Autonomous iteration {assistant.autonomous_iterations} ---" logger.debug(f"Messages before context management: {len(assistant.messages)}")
)
logger.debug(
f"Messages before context management: {len(assistant.messages)}"
)
from pr.core.context import manage_context_window from pr.core.context import manage_context_window
assistant.messages = manage_context_window( assistant.messages = manage_context_window(assistant.messages, assistant.verbose)
assistant.messages, assistant.verbose
)
logger.debug( logger.debug(f"Messages after context management: {len(assistant.messages)}")
f"Messages after context management: {len(assistant.messages)}"
)
from pr.core.api import call_api from pr.core.api import call_api
from pr.tools.base import get_tools_definition from pr.tools.base import get_tools_definition
@ -193,9 +185,7 @@ def execute_single_tool(assistant, func_name, arguments):
"db_query": lambda **kw: db_query(**kw, db_conn=assistant.db_conn), "db_query": lambda **kw: db_query(**kw, db_conn=assistant.db_conn),
"web_search": lambda **kw: web_search(**kw), "web_search": lambda **kw: web_search(**kw),
"web_search_news": lambda **kw: web_search_news(**kw), "web_search_news": lambda **kw: web_search_news(**kw),
"python_exec": lambda **kw: python_exec( "python_exec": lambda **kw: python_exec(**kw, python_globals=assistant.python_globals),
**kw, python_globals=assistant.python_globals
),
"index_source_directory": lambda **kw: index_source_directory(**kw), "index_source_directory": lambda **kw: index_source_directory(**kw),
"search_replace": lambda **kw: search_replace(**kw), "search_replace": lambda **kw: search_replace(**kw),
"open_editor": lambda **kw: open_editor(**kw), "open_editor": lambda **kw: open_editor(**kw),

View File

@ -140,9 +140,7 @@ class APICache:
total_entries = cursor.fetchone()[0] total_entries = cursor.fetchone()[0]
current_time = int(time.time()) current_time = int(time.time())
cursor.execute( cursor.execute("SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,))
"SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,)
)
valid_entries = cursor.fetchone()[0] valid_entries = cursor.fetchone()[0]
cursor.execute( cursor.execute(

View File

@ -161,9 +161,7 @@ class ToolCache:
total_entries = cursor.fetchone()[0] total_entries = cursor.fetchone()[0]
current_time = int(time.time()) current_time = int(time.time())
cursor.execute( cursor.execute("SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,))
"SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,)
)
valid_entries = cursor.fetchone()[0] valid_entries = cursor.fetchone()[0]
cursor.execute( cursor.execute(

View File

@ -95,9 +95,7 @@ def handle_command(assistant, command):
print(f"{Colors.BOLD}Available Tools:{Colors.RESET}") print(f"{Colors.BOLD}Available Tools:{Colors.RESET}")
for tool in get_tools_definition(): for tool in get_tools_definition():
func = tool["function"] func = tool["function"]
print( print(f"{Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}")
f"{Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}"
)
elif cmd == "/review" and len(command_parts) > 1: elif cmd == "/review" and len(command_parts) > 1:
filename = command_parts[1] filename = command_parts[1]
@ -168,9 +166,7 @@ def handle_command(assistant, command):
def review_file(assistant, filename): def review_file(assistant, filename):
result = read_file(filename) result = read_file(filename)
if result["status"] == "success": if result["status"] == "success":
message = ( message = f"Please review this file and provide feedback:\n\n{result['content']}"
f"Please review this file and provide feedback:\n\n{result['content']}"
)
from pr.core.assistant import process_message from pr.core.assistant import process_message
process_message(assistant, message) process_message(assistant, message)
@ -181,9 +177,7 @@ def review_file(assistant, filename):
def refactor_file(assistant, filename): def refactor_file(assistant, filename):
result = read_file(filename) result = read_file(filename)
if result["status"] == "success": if result["status"] == "success":
message = ( message = f"Please refactor this code to improve its quality:\n\n{result['content']}"
f"Please refactor this code to improve its quality:\n\n{result['content']}"
)
from pr.core.assistant import process_message from pr.core.assistant import process_message
process_message(assistant, message) process_message(assistant, message)
@ -354,9 +348,7 @@ def show_conversation_history(assistant):
for conv in history: for conv in history:
import datetime import datetime
started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime( started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime("%Y-%m-%d %H:%M")
"%Y-%m-%d %H:%M"
)
print(f"\n{Colors.CYAN}{conv['conversation_id']}{Colors.RESET}") print(f"\n{Colors.CYAN}{conv['conversation_id']}{Colors.RESET}")
print(f" Started: {started}") print(f" Started: {started}")
print(f" Messages: {conv['message_count']}") print(f" Messages: {conv['message_count']}")
@ -536,9 +528,7 @@ def show_session_output(assistant, session_name):
for line in output: for line in output:
print(line) print(line)
else: else:
print( print(f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}")
f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}"
)
except Exception as e: except Exception as e:
print(f"{Colors.RED}Error getting session output: {e}{Colors.RESET}") print(f"{Colors.RED}Error getting session output: {e}{Colors.RESET}")

View File

@ -116,9 +116,7 @@ def kill_session(args):
close_interactive_session(session_name) close_interactive_session(session_name)
print(f"{Colors.GREEN}Session '{session_name}' terminated.{Colors.RESET}") print(f"{Colors.GREEN}Session '{session_name}' terminated.{Colors.RESET}")
except Exception as e: except Exception as e:
print( print(f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}")
f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}"
)
def send_command(args): def send_command(args):
@ -132,13 +130,9 @@ def send_command(args):
try: try:
send_input_to_session(session_name, command) send_input_to_session(session_name, command)
print( print(f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}")
f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}"
)
except Exception as e: except Exception as e:
print( print(f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}")
f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}"
)
def show_session_log(args): def show_session_log(args):
@ -220,9 +214,7 @@ def list_waiting_sessions(args=None):
waiting_sessions.append(session_name) waiting_sessions.append(session_name)
if not waiting_sessions: if not waiting_sessions:
print( print(f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}")
f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}"
)
return return
print(f"{Colors.BOLD}Sessions waiting for input:{Colors.RESET}") print(f"{Colors.BOLD}Sessions waiting for input:{Colors.RESET}")
@ -237,9 +229,7 @@ def list_waiting_sessions(args=None):
if session_info: if session_info:
suggestions = detector.get_response_suggestions({}, process_type) suggestions = detector.get_response_suggestions({}, process_type)
if suggestions: if suggestions:
print( print(f" Suggested inputs: {', '.join(suggestions[:3])}") # Show first 3
f" Suggested inputs: {', '.join(suggestions[:3])}"
) # Show first 3
print() print()

View File

@ -40,9 +40,7 @@ class AdvancedContextManager:
words = re.findall(r"\b\w+\b", content.lower()) words = re.findall(r"\b\w+\b", content.lower())
unique_words.update(words) unique_words.update(words)
vocabulary_richness = ( vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0
len(unique_words) / total_length if total_length > 0 else 0
)
# Simple complexity score based on length and richness # Simple complexity score based on length and richness
complexity = min(1.0, (avg_length / 100) + vocabulary_richness) complexity = min(1.0, (avg_length / 100) + vocabulary_richness)

View File

@ -9,9 +9,7 @@ from pr.core.context import auto_slim_messages
logger = logging.getLogger("pr") logger = logging.getLogger("pr")
def call_api( def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
messages, model, api_url, api_key, use_tools, tools_definition, verbose=False
):
try: try:
messages = auto_slim_messages(messages, verbose=verbose) messages = auto_slim_messages(messages, verbose=verbose)
@ -65,13 +63,9 @@ def call_api(
msg = choice["message"] msg = choice["message"]
logger.debug(f"Response role: {msg.get('role', 'N/A')}") logger.debug(f"Response role: {msg.get('role', 'N/A')}")
if "content" in msg and msg["content"]: if "content" in msg and msg["content"]:
logger.debug( logger.debug(f"Response content length: {len(msg['content'])} chars")
f"Response content length: {len(msg['content'])} chars"
)
if "tool_calls" in msg: if "tool_calls" in msg:
logger.debug( logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")
f"Response contains {len(msg['tool_calls'])} tool call(s)"
)
logger.debug("=== API CALL END ===") logger.debug("=== API CALL END ===")
return result return result

View File

@ -76,9 +76,7 @@ logger = logging.getLogger("pr")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(LOG_FILE) file_handler = logging.FileHandler(LOG_FILE)
file_handler.setFormatter( file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
)
logger.addHandler(file_handler) logger.addHandler(file_handler)
@ -93,9 +91,7 @@ class Assistant:
if self.debug: if self.debug:
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG) console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter( console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
logging.Formatter("%(levelname)s: %(message)s")
)
logger.addHandler(console_handler) logger.addHandler(console_handler)
logger.debug("Debug mode enabled") logger.debug("Debug mode enabled")
self.api_key = os.environ.get("OPENROUTER_API_KEY", "") self.api_key = os.environ.get("OPENROUTER_API_KEY", "")
@ -210,13 +206,9 @@ class Assistant:
session_name = event.get("session_name", "unknown") session_name = event.get("session_name", "unknown")
if event_type == "session_started": if event_type == "session_started":
print( print(f" {Colors.GREEN}{Colors.RESET} Session '{session_name}' started")
f" {Colors.GREEN}{Colors.RESET} Session '{session_name}' started"
)
elif event_type == "session_ended": elif event_type == "session_ended":
print( print(f" {Colors.YELLOW}{Colors.RESET} Session '{session_name}' ended")
f" {Colors.YELLOW}{Colors.RESET} Session '{session_name}' ended"
)
elif event_type == "output_received": elif event_type == "output_received":
lines = len(event.get("new_output", {}).get("stdout", [])) lines = len(event.get("new_output", {}).get("stdout", []))
print( print(
@ -241,9 +233,7 @@ class Assistant:
except Exception as e: except Exception as e:
if self.debug: if self.debug:
print( print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}")
f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}"
)
def execute_tool_calls(self, tool_calls): def execute_tool_calls(self, tool_calls):
results = [] results = []
@ -263,14 +253,10 @@ class Assistant:
"run_command": lambda **kw: run_command(**kw), "run_command": lambda **kw: run_command(**kw),
"tail_process": lambda **kw: tail_process(**kw), "tail_process": lambda **kw: tail_process(**kw),
"kill_process": lambda **kw: kill_process(**kw), "kill_process": lambda **kw: kill_process(**kw),
"start_interactive_session": lambda **kw: start_interactive_session( "start_interactive_session": lambda **kw: start_interactive_session(**kw),
**kw
),
"send_input_to_session": lambda **kw: send_input_to_session(**kw), "send_input_to_session": lambda **kw: send_input_to_session(**kw),
"read_session_output": lambda **kw: read_session_output(**kw), "read_session_output": lambda **kw: read_session_output(**kw),
"close_interactive_session": lambda **kw: close_interactive_session( "close_interactive_session": lambda **kw: close_interactive_session(**kw),
**kw
),
"read_file": lambda **kw: read_file(**kw, db_conn=self.db_conn), "read_file": lambda **kw: read_file(**kw, db_conn=self.db_conn),
"write_file": lambda **kw: write_file(**kw, db_conn=self.db_conn), "write_file": lambda **kw: write_file(**kw, db_conn=self.db_conn),
"list_directory": lambda **kw: list_directory(**kw), "list_directory": lambda **kw: list_directory(**kw),
@ -286,9 +272,7 @@ class Assistant:
**kw, python_globals=self.python_globals **kw, python_globals=self.python_globals
), ),
"index_source_directory": lambda **kw: index_source_directory(**kw), "index_source_directory": lambda **kw: index_source_directory(**kw),
"search_replace": lambda **kw: search_replace( "search_replace": lambda **kw: search_replace(**kw, db_conn=self.db_conn),
**kw, db_conn=self.db_conn
),
"open_editor": lambda **kw: open_editor(**kw), "open_editor": lambda **kw: open_editor(**kw),
"editor_insert_text": lambda **kw: editor_insert_text( "editor_insert_text": lambda **kw: editor_insert_text(
**kw, db_conn=self.db_conn **kw, db_conn=self.db_conn
@ -304,15 +288,11 @@ class Assistant:
"display_edit_summary": lambda **kw: display_edit_summary(), "display_edit_summary": lambda **kw: display_edit_summary(),
"display_edit_timeline": lambda **kw: display_edit_timeline(**kw), "display_edit_timeline": lambda **kw: display_edit_timeline(**kw),
"clear_edit_tracker": lambda **kw: clear_edit_tracker(), "clear_edit_tracker": lambda **kw: clear_edit_tracker(),
"start_interactive_session": lambda **kw: start_interactive_session( "start_interactive_session": lambda **kw: start_interactive_session(**kw),
**kw
),
"send_input_to_session": lambda **kw: send_input_to_session(**kw), "send_input_to_session": lambda **kw: send_input_to_session(**kw),
"read_session_output": lambda **kw: read_session_output(**kw), "read_session_output": lambda **kw: read_session_output(**kw),
"list_active_sessions": lambda **kw: list_active_sessions(**kw), "list_active_sessions": lambda **kw: list_active_sessions(**kw),
"close_interactive_session": lambda **kw: close_interactive_session( "close_interactive_session": lambda **kw: close_interactive_session(**kw),
**kw
),
"create_agent": lambda **kw: create_agent(**kw), "create_agent": lambda **kw: create_agent(**kw),
"list_agents": lambda **kw: list_agents(**kw), "list_agents": lambda **kw: list_agents(**kw),
"execute_agent_task": lambda **kw: execute_agent_task(**kw), "execute_agent_task": lambda **kw: execute_agent_task(**kw),
@ -321,16 +301,10 @@ class Assistant:
"add_knowledge_entry": lambda **kw: add_knowledge_entry(**kw), "add_knowledge_entry": lambda **kw: add_knowledge_entry(**kw),
"get_knowledge_entry": lambda **kw: get_knowledge_entry(**kw), "get_knowledge_entry": lambda **kw: get_knowledge_entry(**kw),
"search_knowledge": lambda **kw: search_knowledge(**kw), "search_knowledge": lambda **kw: search_knowledge(**kw),
"get_knowledge_by_category": lambda **kw: get_knowledge_by_category( "get_knowledge_by_category": lambda **kw: get_knowledge_by_category(**kw),
**kw "update_knowledge_importance": lambda **kw: update_knowledge_importance(**kw),
),
"update_knowledge_importance": lambda **kw: update_knowledge_importance(
**kw
),
"delete_knowledge_entry": lambda **kw: delete_knowledge_entry(**kw), "delete_knowledge_entry": lambda **kw: delete_knowledge_entry(**kw),
"get_knowledge_statistics": lambda **kw: get_knowledge_statistics( "get_knowledge_statistics": lambda **kw: get_knowledge_statistics(**kw),
**kw
),
} }
if func_name in func_map: if func_name in func_map:
@ -356,9 +330,7 @@ class Assistant:
{ {
"tool_call_id": tool_id, "tool_call_id": tool_id,
"role": "tool", "role": "tool",
"content": json.dumps( "content": json.dumps({"status": "error", "error": error_msg}),
{"status": "error", "error": error_msg}
),
} }
) )
@ -405,9 +377,7 @@ class Assistant:
self.autonomous_mode = False self.autonomous_mode = False
sys.exit(0) sys.exit(0)
else: else:
print( print(f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}")
f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}"
)
return return
self.interrupt_count += 1 self.interrupt_count += 1

View File

@ -21,9 +21,7 @@ class AutonomousInteractions:
self.llm_callback = llm_callback self.llm_callback = llm_callback
if self.interaction_thread is None: if self.interaction_thread is None:
self.active = True self.active = True
self.interaction_thread = threading.Thread( self.interaction_thread = threading.Thread(target=self._interaction_loop, daemon=True)
target=self._interaction_loop, daemon=True
)
self.interaction_thread.start() self.interaction_thread.start()
def stop(self): def stop(self):
@ -55,9 +53,7 @@ class AutonomousInteractions:
if not sessions: if not sessions:
return # No active sessions return # No active sessions
sessions_needing_attention = self._identify_sessions_needing_attention( sessions_needing_attention = self._identify_sessions_needing_attention(sessions)
sessions
)
if sessions_needing_attention and self.llm_callback: if sessions_needing_attention and self.llm_callback:
# Format session updates for LLM # Format session updates for LLM
@ -84,9 +80,7 @@ class AutonomousInteractions:
continue continue
# 2. High output volume (potential completion or error) # 2. High output volume (potential completion or error)
total_lines = ( total_lines = output_summary["stdout_lines"] + output_summary["stderr_lines"]
output_summary["stdout_lines"] + output_summary["stderr_lines"]
)
if total_lines > 50: # Arbitrary threshold if total_lines > 50: # Arbitrary threshold
needing_attention.append(session_name) needing_attention.append(session_name)
continue continue

View File

@ -18,9 +18,7 @@ class BackgroundMonitor:
"""Start the background monitoring thread.""" """Start the background monitoring thread."""
if self.monitor_thread is None: if self.monitor_thread is None:
self.active = True self.active = True
self.monitor_thread = threading.Thread( self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
target=self._monitor_loop, daemon=True
)
self.monitor_thread.start() self.monitor_thread.start()
def stop(self): def stop(self):
@ -105,21 +103,14 @@ class BackgroundMonitor:
old_stderr_lines = old_state["output_summary"]["stderr_lines"] old_stderr_lines = old_state["output_summary"]["stderr_lines"]
new_stderr_lines = new_state["output_summary"]["stderr_lines"] new_stderr_lines = new_state["output_summary"]["stderr_lines"]
if ( if new_stdout_lines > old_stdout_lines or new_stderr_lines > old_stderr_lines:
new_stdout_lines > old_stdout_lines
or new_stderr_lines > old_stderr_lines
):
# Get the new output # Get the new output
mux = get_multiplexer(session_name) mux = get_multiplexer(session_name)
if mux: if mux:
all_output = mux.get_all_output() all_output = mux.get_all_output()
new_output = { new_output = {
"stdout": all_output["stdout"].split("\n")[ "stdout": all_output["stdout"].split("\n")[old_stdout_lines:],
old_stdout_lines: "stderr": all_output["stderr"].split("\n")[old_stderr_lines:],
],
"stderr": all_output["stderr"].split("\n")[
old_stderr_lines:
],
} }
events.append( events.append(
@ -167,9 +158,7 @@ class BackgroundMonitor:
output_summary = state["output_summary"] output_summary = state["output_summary"]
# Heuristic: High output volume might indicate completion or error # Heuristic: High output volume might indicate completion or error
total_lines = ( total_lines = output_summary["stdout_lines"] + output_summary["stderr_lines"]
output_summary["stdout_lines"] + output_summary["stderr_lines"]
)
if total_lines > 100: # Arbitrary threshold if total_lines > 100: # Arbitrary threshold
events.append( events.append(
{ {
@ -193,9 +182,7 @@ class BackgroundMonitor:
# Heuristic: Sessions that might be waiting for input # Heuristic: Sessions that might be waiting for input
# This would be enhanced with prompt detection in later phases # This would be enhanced with prompt detection in later phases
if self._might_be_waiting_for_input(session_name, state): if self._might_be_waiting_for_input(session_name, state):
events.append( events.append({"type": "possible_input_needed", "session_name": session_name})
{"type": "possible_input_needed", "session_name": session_name}
)
return events return events

View File

@ -41,15 +41,11 @@ def truncate_tool_result(result, max_length=None):
if "data" in result_copy and isinstance(result_copy["data"], str): if "data" in result_copy and isinstance(result_copy["data"], str):
if len(result_copy["data"]) > max_length: if len(result_copy["data"]) > max_length:
result_copy["data"] = ( result_copy["data"] = result_copy["data"][:max_length] + f"\n... [truncated]"
result_copy["data"][:max_length] + f"\n... [truncated]"
)
if "error" in result_copy and isinstance(result_copy["error"], str): if "error" in result_copy and isinstance(result_copy["error"], str):
if len(result_copy["error"]) > max_length // 2: if len(result_copy["error"]) > max_length // 2:
result_copy["error"] = ( result_copy["error"] = result_copy["error"][: max_length // 2] + "... [truncated]"
result_copy["error"][: max_length // 2] + "... [truncated]"
)
return result_copy return result_copy
@ -111,9 +107,7 @@ Shell Commands:
system_message = "\n\n".join(context_parts) system_message = "\n\n".join(context_parts)
if len(system_message) > max_context_size * 3: if len(system_message) > max_context_size * 3:
system_message = ( system_message = system_message[: max_context_size * 3] + "\n... [system message truncated]"
system_message[: max_context_size * 3] + "\n... [system message truncated]"
)
return {"role": "system", "content": system_message} return {"role": "system", "content": system_message}
@ -198,18 +192,14 @@ def trim_message_content(message, max_length):
if isinstance(content, str) and len(content) > max_length: if isinstance(content, str) and len(content) > max_length:
trimmed_msg["content"] = ( trimmed_msg["content"] = (
content[:max_length] content[:max_length] + f"\n... [trimmed {len(content) - max_length} chars]"
+ f"\n... [trimmed {len(content) - max_length} chars]"
) )
elif isinstance(content, list): elif isinstance(content, list):
trimmed_content = [] trimmed_content = []
for item in content: for item in content:
if isinstance(item, dict): if isinstance(item, dict):
trimmed_item = item.copy() trimmed_item = item.copy()
if ( if "text" in trimmed_item and len(trimmed_item["text"]) > max_length:
"text" in trimmed_item
and len(trimmed_item["text"]) > max_length
):
trimmed_item["text"] = ( trimmed_item["text"] = (
trimmed_item["text"][:max_length] + f"\n... [trimmed]" trimmed_item["text"][:max_length] + f"\n... [trimmed]"
) )
@ -236,8 +226,7 @@ def trim_message_content(message, max_length):
and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2 and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2
): ):
parsed["output"] = ( parsed["output"] = (
parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2] parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
+ f"\n... [truncated]"
) )
if ( if (
"content" in parsed "content" in parsed
@ -245,8 +234,7 @@ def trim_message_content(message, max_length):
and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2 and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2
): ):
parsed["content"] = ( parsed["content"] = (
parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2] parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
+ f"\n... [truncated]"
) )
trimmed_msg["content"] = json.dumps(parsed) trimmed_msg["content"] = json.dumps(parsed)
except: except:
@ -259,24 +247,18 @@ def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
if estimate_tokens(messages) <= target_tokens: if estimate_tokens(messages) <= target_tokens:
return messages return messages
system_msg = ( system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
messages[0] if messages and messages[0].get("role") == "system" else None
)
start_idx = 1 if system_msg else 0 start_idx = 1 if system_msg else 0
recent_messages = ( recent_messages = (
messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:] messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:]
) )
middle_messages = ( middle_messages = messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
)
trimmed_middle = [] trimmed_middle = []
for msg in middle_messages: for msg in middle_messages:
if msg.get("role") == "tool": if msg.get("role") == "tool":
trimmed_middle.append( trimmed_middle.append(trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2))
trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2)
)
elif msg.get("role") in ["user", "assistant"]: elif msg.get("role") in ["user", "assistant"]:
trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH)) trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH))
else: else:
@ -313,9 +295,7 @@ def auto_slim_messages(messages, verbose=False):
print( print(
f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}" f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}"
) )
print( print(f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}")
f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}"
)
result = intelligently_trim_messages( result = intelligently_trim_messages(
messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP
@ -339,17 +319,13 @@ def auto_slim_messages(messages, verbose=False):
f"{Colors.GREEN} Token estimate: {estimated_tokens}{final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}" f"{Colors.GREEN} Token estimate: {estimated_tokens}{final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}"
) )
if removed_count > 0: if removed_count > 0:
print( print(f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}")
f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}"
)
return result return result
def emergency_reduce_messages(messages, target_tokens, verbose=False): def emergency_reduce_messages(messages, target_tokens, verbose=False):
system_msg = ( system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
messages[0] if messages and messages[0].get("role") == "system" else None
)
start_idx = 1 if system_msg else 0 start_idx = 1 if system_msg else 0
keep_count = 2 keep_count = 2

View File

@ -63,9 +63,7 @@ class EnhancedAssistant:
logger.info("Enhanced Assistant initialized with all features") logger.info("Enhanced Assistant initialized with all features")
def _execute_tool_for_workflow( def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
self, tool_name: str, arguments: Dict[str, Any]
) -> Any:
if self.tool_cache: if self.tool_cache:
cached_result = self.tool_cache.get(tool_name, arguments) cached_result = self.tool_cache.get(tool_name, arguments)
if cached_result is not None: if cached_result is not None:
@ -119,9 +117,7 @@ class EnhancedAssistant:
if self.tool_cache: if self.tool_cache:
content = result.get("content", "") content = result.get("content", "")
try: try:
parsed_content = ( parsed_content = json.loads(content) if isinstance(content, str) else content
json.loads(content) if isinstance(content, str) else content
)
self.tool_cache.set(tool_name, arguments, parsed_content) self.tool_cache.set(tool_name, arguments, parsed_content)
except Exception: except Exception:
pass pass
@ -164,9 +160,7 @@ class EnhancedAssistant:
if self.api_cache and CACHE_ENABLED and "error" not in response: if self.api_cache and CACHE_ENABLED and "error" not in response:
token_count = response.get("usage", {}).get("total_tokens", 0) token_count = response.get("usage", {}).get("total_tokens", 0)
self.api_cache.set( self.api_cache.set(self.base.model, messages, 0.7, 4096, response, token_count)
self.base.model, messages, 0.7, 4096, response, token_count
)
return response return response
@ -197,11 +191,9 @@ class EnhancedAssistant:
self.knowledge_store.add_entry(entry) self.knowledge_store.add_entry(entry)
if self.context_manager and ADVANCED_CONTEXT_ENABLED: if self.context_manager and ADVANCED_CONTEXT_ENABLED:
enhanced_messages, context_info = ( enhanced_messages, context_info = self.context_manager.create_enhanced_context(
self.context_manager.create_enhanced_context(
self.base.messages, user_message, include_knowledge=True self.base.messages, user_message, include_knowledge=True
) )
)
if self.base.verbose: if self.base.verbose:
logger.info(f"Enhanced context: {context_info}") logger.info(f"Enhanced context: {context_info}")
@ -261,9 +253,7 @@ class EnhancedAssistant:
orchestrator_id = self.agent_manager.create_agent("orchestrator") orchestrator_id = self.agent_manager.create_agent("orchestrator")
return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles) return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)
def search_knowledge( def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]:
self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT
) -> List[Any]:
return self.knowledge_store.search_entries(query, top_k=limit) return self.knowledge_store.search_entries(query, top_k=limit)
def get_cache_statistics(self) -> Dict[str, Any]: def get_cache_statistics(self) -> Dict[str, Any]:

View File

@ -16,9 +16,7 @@ def setup_logging(verbose=False):
if logger.handlers: if logger.handlers:
logger.handlers.clear() logger.handlers.clear()
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5)
LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5
)
file_handler.setLevel(logging.DEBUG) file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter( file_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",

View File

@ -64,13 +64,9 @@ class UsageTracker:
self._save_to_history(model, input_tokens, output_tokens, cost) self._save_to_history(model, input_tokens, output_tokens, cost)
logger.debug( logger.debug(f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}")
f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}"
)
def _calculate_cost( def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
self, model: str, input_tokens: int, output_tokens: int
) -> float:
if model not in MODEL_COSTS: if model not in MODEL_COSTS:
base_model = model.split("/")[0] if "/" in model else model base_model = model.split("/")[0] if "/" in model else model
if base_model not in MODEL_COSTS: if base_model not in MODEL_COSTS:
@ -85,9 +81,7 @@ class UsageTracker:
return input_cost + output_cost return input_cost + output_cost
def _save_to_history( def _save_to_history(self, model: str, input_tokens: int, output_tokens: int, cost: float):
self, model: str, input_tokens: int, output_tokens: int, cost: float
):
try: try:
history = [] history = []
if os.path.exists(USAGE_DB_FILE): if os.path.exists(USAGE_DB_FILE):

View File

@ -16,9 +16,7 @@ def validate_file_path(path: str, must_exist: bool = False) -> str:
return os.path.abspath(path) return os.path.abspath(path)
def validate_directory_path( def validate_directory_path(path: str, must_exist: bool = False, create: bool = False) -> str:
path: str, must_exist: bool = False, create: bool = False
) -> str:
if not path: if not path:
raise ValidationError("Directory path cannot be empty") raise ValidationError("Directory path cannot be empty")

View File

@ -179,9 +179,7 @@ class RPEditor:
try: try:
self.running = True self.running = True
self.socket_thread = threading.Thread( self.socket_thread = threading.Thread(target=self.socket_listener, daemon=True)
target=self.socket_listener, daemon=True
)
self.socket_thread.start() self.socket_thread.start()
self.thread = threading.Thread(target=self.run, daemon=True) self.thread = threading.Thread(target=self.run, daemon=True)
self.thread.start() self.thread.start()
@ -538,9 +536,7 @@ class RPEditor:
self.cursor_y = len(self.lines) - 1 self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = ( self.lines[self.cursor_y] = line[: self.cursor_x] + text + line[self.cursor_x :]
line[: self.cursor_x] + text + line[self.cursor_x :]
)
self.cursor_x += len(text) self.cursor_x += len(text)
else: else:
# Multi-line insert # Multi-line insert
@ -562,9 +558,7 @@ class RPEditor:
def insert_text(self, text): def insert_text(self, text):
"""Thread-safe text insertion.""" """Thread-safe text insertion."""
try: try:
self.client_sock.send( self.client_sock.send(pickle.dumps({"command": "insert_text", "text": text}))
pickle.dumps({"command": "insert_text", "text": text})
)
except: except:
with self.lock: with self.lock:
self._insert_text(text) self._insert_text(text)
@ -572,13 +566,9 @@ class RPEditor:
def _delete_char(self): def _delete_char(self):
"""Delete character at cursor.""" """Delete character at cursor."""
self.save_state() self.save_state()
if self.cursor_y < len(self.lines) and self.cursor_x < len( if self.cursor_y < len(self.lines) and self.cursor_x < len(self.lines[self.cursor_y]):
self.lines[self.cursor_y]
):
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = ( self.lines[self.cursor_y] = line[: self.cursor_x] + line[self.cursor_x + 1 :]
line[: self.cursor_x] + line[self.cursor_x + 1 :]
)
def delete_char(self): def delete_char(self):
"""Thread-safe character deletion.""" """Thread-safe character deletion."""
@ -614,9 +604,7 @@ class RPEditor:
"""Handle backspace key.""" """Handle backspace key."""
if self.cursor_x > 0: if self.cursor_x > 0:
line = self.lines[self.cursor_y] line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = ( self.lines[self.cursor_y] = line[: self.cursor_x - 1] + line[self.cursor_x :]
line[: self.cursor_x - 1] + line[self.cursor_x :]
)
self.cursor_x -= 1 self.cursor_x -= 1
elif self.cursor_y > 0: elif self.cursor_y > 0:
prev_len = len(self.lines[self.cursor_y - 1]) prev_len = len(self.lines[self.cursor_y - 1])
@ -668,9 +656,7 @@ class RPEditor:
def goto_line(self, line_num): def goto_line(self, line_num):
"""Thread-safe goto line.""" """Thread-safe goto line."""
try: try:
self.client_sock.send( self.client_sock.send(pickle.dumps({"command": "goto_line", "line_num": line_num}))
pickle.dumps({"command": "goto_line", "line_num": line_num})
)
except: except:
with self.lock: with self.lock:
self._goto_line(line_num) self._goto_line(line_num)

View File

@ -211,9 +211,7 @@ class RPEditor:
self.running = False self.running = False
elif cmd == "w": elif cmd == "w":
self._save_file() self._save_file()
elif ( elif cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!":
cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!"
):
self._save_file() self._save_file()
self.running = False self.running = False
elif cmd.startswith("w "): elif cmd.startswith("w "):
@ -371,9 +369,7 @@ class RPEditor:
self.cursor_x = 0 self.cursor_x = 0
def goto_line(self, line_num): def goto_line(self, line_num):
self.client_sock.send( self.client_sock.send(pickle.dumps({"command": "goto_line", "line_num": line_num}))
pickle.dumps({"command": "goto_line", "line_num": line_num})
)
def get_text(self): def get_text(self):
self.client_sock.send(pickle.dumps({"command": "get_text"})) self.client_sock.send(pickle.dumps({"command": "get_text"}))

View File

@ -73,9 +73,7 @@ class AdvancedInputHandler:
def _get_editor_input(self, prompt: str) -> Optional[str]: def _get_editor_input(self, prompt: str) -> Optional[str]:
"""Get multi-line input for editor mode.""" """Get multi-line input for editor mode."""
try: try:
print( print("Editor mode: Enter your message. Type 'END' on a new line to finish.")
"Editor mode: Enter your message. Type 'END' on a new line to finish."
)
print("Type '/simple' to switch back to simple mode.") print("Type '/simple' to switch back to simple mode.")
lines = [] lines = []

View File

@ -169,9 +169,7 @@ class FactExtractor:
sentence_count = len(re.split(r"[.!?]", text)) sentence_count = len(re.split(r"[.!?]", text))
urls = re.findall(r"https?://[^\s]+", text) urls = re.findall(r"https?://[^\s]+", text)
email_addresses = re.findall( email_addresses = re.findall(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text)
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text
)
dates = re.findall( dates = re.findall(
r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text
) )

View File

@ -38,8 +38,7 @@ class SemanticIndex:
self.idf_scores = {token: 1.0 for token in token_doc_count} self.idf_scores = {token: 1.0 for token in token_doc_count}
else: else:
self.idf_scores = { self.idf_scores = {
token: math.log(doc_count / count) token: math.log(doc_count / count) for token, count in token_doc_count.items()
for token, count in token_doc_count.items()
} }
def add_document(self, doc_id: str, text: str): def add_document(self, doc_id: str, text: str):
@ -51,8 +50,7 @@ class SemanticIndex:
tf_scores = self._compute_tf(tokens) tf_scores = self._compute_tf(tokens)
self.doc_vectors[doc_id] = { self.doc_vectors[doc_id] = {
token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0) token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0) for token in tokens
for token in tokens
} }
def remove_document(self, doc_id: str): def remove_document(self, doc_id: str):
@ -67,8 +65,7 @@ class SemanticIndex:
query_tf = self._compute_tf(query_tokens) query_tf = self._compute_tf(query_tokens)
query_vector = { query_vector = {
token: query_tf.get(token, 0) * self.idf_scores.get(token, 0) token: query_tf.get(token, 0) * self.idf_scores.get(token, 0) for token in query_tokens
for token in query_tokens
} }
scores = [] scores = []
@ -79,9 +76,7 @@ class SemanticIndex:
scores.sort(key=lambda x: x[1], reverse=True) scores.sort(key=lambda x: x[1], reverse=True)
return scores[:top_k] return scores[:top_k]
def _cosine_similarity( def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
self, vec1: Dict[str, float], vec2: Dict[str, float]
) -> float:
dot_product = sum( dot_product = sum(
vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2) vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2)
) )

View File

@ -30,9 +30,7 @@ class TerminalMultiplexer:
self.prompt_detector = get_global_detector() self.prompt_detector = get_global_detector()
if self.show_output: if self.show_output:
self.display_thread = threading.Thread( self.display_thread = threading.Thread(target=self._display_worker, daemon=True)
target=self._display_worker, daemon=True
)
self.display_thread.start() self.display_thread.start()
def _display_worker(self): def _display_worker(self):
@ -48,9 +46,7 @@ class TerminalMultiplexer:
try: try:
line = self.stderr_queue.get(timeout=0.1) line = self.stderr_queue.get(timeout=0.1)
if line: if line:
sys.stderr.write( sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}")
f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}"
)
sys.stderr.flush() sys.stderr.flush()
except queue.Empty: except queue.Empty:
pass pass

View File

@ -52,13 +52,9 @@ class PluginLoader:
self.loaded_plugins[plugin_name] = module self.loaded_plugins[plugin_name] = module
logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)") logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)")
else: else:
logger.warning( logger.warning(f"Plugin {plugin_name} register_tools() did not return a list")
f"Plugin {plugin_name} register_tools() did not return a list"
)
else: else:
logger.warning( logger.warning(f"Plugin {plugin_name} does not have register_tools() function")
f"Plugin {plugin_name} does not have register_tools() function"
)
def get_plugin_function(self, tool_name: str) -> Callable: def get_plugin_function(self, tool_name: str) -> Callable:
for plugin_name, module in self.loaded_plugins.items(): for plugin_name, module in self.loaded_plugins.items():

View File

@ -39,9 +39,7 @@ def list_agents() -> Dict[str, Any]:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def execute_agent_task( def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
agent_id: str, task: str, context: Dict[str, Any] = None
) -> Dict[str, Any]:
"""Execute a task with the specified agent.""" """Execute a task with the specified agent."""
try: try:
db_path = os.path.expanduser("~/.assistant_db.sqlite") db_path = os.path.expanduser("~/.assistant_db.sqlite")
@ -63,9 +61,7 @@ def remove_agent(agent_id: str) -> Dict[str, Any]:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def collaborate_agents( def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]:
orchestrator_id: str, task: str, agent_roles: List[str]
) -> Dict[str, Any]:
"""Collaborate multiple agents on a task.""" """Collaborate multiple agents on a task."""
try: try:
db_path = os.path.expanduser("~/.assistant_db.sqlite") db_path = os.path.expanduser("~/.assistant_db.sqlite")

View File

@ -235,9 +235,7 @@ def get_tools_definition():
"description": "Change the current working directory", "description": "Change the current working directory",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {"path": {"type": "string", "description": "Path to change to"}},
"path": {"type": "string", "description": "Path to change to"}
},
"required": ["path"], "required": ["path"],
}, },
}, },
@ -284,9 +282,7 @@ def get_tools_definition():
"description": "Execute a database query", "description": "Execute a database query",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {"query": {"type": "string", "description": "SQL query"}},
"query": {"type": "string", "description": "SQL query"}
},
"required": ["query"], "required": ["query"],
}, },
}, },
@ -298,9 +294,7 @@ def get_tools_definition():
"description": "Perform a web search", "description": "Perform a web search",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {"query": {"type": "string", "description": "Search query"}},
"query": {"type": "string", "description": "Search query"}
},
"required": ["query"], "required": ["query"],
}, },
}, },
@ -346,9 +340,7 @@ def get_tools_definition():
"description": "Index directory recursively and read all source files.", "description": "Index directory recursively and read all source files.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {"path": {"type": "string", "description": "Path to index"}},
"path": {"type": "string", "description": "Path to index"}
},
"required": ["path"], "required": ["path"],
}, },
}, },

View File

@ -81,9 +81,7 @@ def tail_process(pid: int, timeout: int = 30):
"pid": pid, "pid": pid,
} }
ready, _, _ = select.select( ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
[process.stdout, process.stderr], [], [], 0.1
)
for pipe in ready: for pipe in ready:
if pipe == process.stdout: if pipe == process.stdout:
line = process.stdout.readline() line = process.stdout.readline()

View File

@ -44,9 +44,7 @@ def db_query(query, db_conn):
if query.strip().upper().startswith("SELECT"): if query.strip().upper().startswith("SELECT"):
results = cursor.fetchall() results = cursor.fetchall()
columns = ( columns = [desc[0] for desc in cursor.description] if cursor.description else []
[desc[0] for desc in cursor.description] if cursor.description else []
)
return {"status": "success", "columns": columns, "rows": results} return {"status": "success", "columns": columns, "rows": results}
else: else:
db_conn.commit() db_conn.commit()

View File

@ -61,9 +61,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
with open(path) as f: with open(path) as f:
old_content = f.read() old_content = f.read()
position = (line if line is not None else 0) * 1000 + ( position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
col if col is not None else 0
)
operation = track_edit("INSERT", filepath, start_pos=position, content=text) operation = track_edit("INSERT", filepath, start_pos=position, content=text)
tracker.mark_in_progress(operation) tracker.mark_in_progress(operation)
@ -76,11 +74,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
mux_name = f"editor-{path}" mux_name = f"editor-{path}"
mux = get_multiplexer(mux_name) mux = get_multiplexer(mux_name)
if mux: if mux:
location = ( location = f" at line {line}, col {col}" if line is not None and col is not None else ""
f" at line {line}, col {col}"
if line is not None and col is not None
else ""
)
preview = text[:50] + "..." if len(text) > 50 else text preview = text[:50] + "..." if len(text) > 50 else text
mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n") mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n")

View File

@ -41,10 +41,7 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if ( if read_status.get("status") != "success" or read_status.get("value") != "true":
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return { return {
"status": "error", "status": "error",
"error": "File must be read before writing. Please read the file first.", "error": "File must be read before writing. Please read the file first.",
@ -54,9 +51,7 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
with open(path) as f: with open(path) as f:
old_content = f.read() old_content = f.read()
operation = track_edit( operation = track_edit("WRITE", filepath, content=content, old_content=old_content)
"WRITE", filepath, content=content, old_content=old_content
)
tracker.mark_in_progress(operation) tracker.mark_in_progress(operation)
if show_diff and not is_new_file: if show_diff and not is_new_file:
@ -119,9 +114,7 @@ def list_directory(path=".", recursive=False):
} }
) )
for name in dirs: for name in dirs:
items.append( items.append({"path": os.path.join(root, name), "type": "directory"})
{"path": os.path.join(root, name), "type": "directory"}
)
else: else:
for item in os.listdir(path): for item in os.listdir(path):
item_path = os.path.join(path, item) item_path = os.path.join(path, item)
@ -129,11 +122,7 @@ def list_directory(path=".", recursive=False):
{ {
"name": item, "name": item,
"type": "directory" if os.path.isdir(item_path) else "file", "type": "directory" if os.path.isdir(item_path) else "file",
"size": ( "size": (os.path.getsize(item_path) if os.path.isfile(item_path) else None),
os.path.getsize(item_path)
if os.path.isfile(item_path)
else None
),
} }
) )
return {"status": "success", "items": items} return {"status": "success", "items": items}
@ -209,10 +198,7 @@ def search_replace(filepath, old_string, new_string, db_conn=None):
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if ( if read_status.get("status") != "success" or read_status.get("value") != "true":
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return { return {
"status": "error", "status": "error",
"error": "File must be read before writing. Please read the file first.", "error": "File must be read before writing. Please read the file first.",
@ -259,19 +245,14 @@ def open_editor(filepath):
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def editor_insert_text( def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None):
filepath, text, line=None, col=None, show_diff=True, db_conn=None
):
try: try:
path = os.path.expanduser(filepath) path = os.path.expanduser(filepath)
if db_conn: if db_conn:
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if ( if read_status.get("status") != "success" or read_status.get("value") != "true":
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return { return {
"status": "error", "status": "error",
"error": "File must be read before writing. Please read the file first.", "error": "File must be read before writing. Please read the file first.",
@ -282,9 +263,7 @@ def editor_insert_text(
with open(path) as f: with open(path) as f:
old_content = f.read() old_content = f.read()
position = (line if line is not None else 0) * 1000 + ( position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
col if col is not None else 0
)
operation = track_edit("INSERT", filepath, start_pos=position, content=text) operation = track_edit("INSERT", filepath, start_pos=position, content=text)
tracker.mark_in_progress(operation) tracker.mark_in_progress(operation)
@ -325,10 +304,7 @@ def editor_replace_text(
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if ( if read_status.get("status") != "success" or read_status.get("value") != "true":
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return { return {
"status": "error", "status": "error",
"error": "File must be read before writing. Please read the file first.", "error": "File must be read before writing. Please read the file first.",

View File

@ -47,9 +47,7 @@ def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def search_knowledge( def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]:
query: str, category: str = None, top_k: int = 5
) -> Dict[str, Any]:
"""Search the knowledge base semantically.""" """Search the knowledge base semantically."""
try: try:
db_path = os.path.expanduser("~/.assistant_db.sqlite") db_path = os.path.expanduser("~/.assistant_db.sqlite")
@ -75,9 +73,7 @@ def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def update_knowledge_importance( def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]:
entry_id: str, importance_score: float
) -> Dict[str, Any]:
"""Update the importance score of a knowledge entry.""" """Update the importance score of a knowledge entry."""
try: try:
db_path = os.path.expanduser("~/.assistant_db.sqlite") db_path = os.path.expanduser("~/.assistant_db.sqlite")

View File

@ -13,10 +13,7 @@ def apply_patch(filepath, patch_content, db_conn=None):
from pr.tools.database import db_get from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn) read_status = db_get("read:" + path, db_conn)
if ( if read_status.get("status") != "success" or read_status.get("value") != "true":
read_status.get("status") != "success"
or read_status.get("value") != "true"
):
return { return {
"status": "error", "status": "error",
"error": "File must be read before writing. Please read the file first.", "error": "File must be read before writing. Please read the file first.",
@ -68,9 +65,7 @@ def create_diff(
else: else:
lines1 = content1.splitlines(keepends=True) lines1 = content1.splitlines(keepends=True)
lines2 = content2.splitlines(keepends=True) lines2 = content2.splitlines(keepends=True)
diff = list( diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)
)
return {"status": "success", "diff": "".join(diff)} return {"status": "success", "diff": "".join(diff)}
except Exception as e: except Exception as e:
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
@ -94,9 +89,7 @@ def display_file_diff(filepath1, filepath2, format_type="unified", context_lines
return {"status": "error", "error": str(e)} return {"status": "error", "error": str(e)}
def display_content_diff( def display_content_diff(old_content, new_content, filename="file", format_type="unified"):
old_content, new_content, filename="file", format_type="unified"
):
try: try:
visual_diff = display_diff(old_content, new_content, filename, format_type) visual_diff = display_diff(old_content, new_content, filename, format_type)
stats = get_diff_stats(old_content, new_content) stats = get_diff_stats(old_content, new_content)

View File

@ -187,9 +187,7 @@ class PromptDetector:
# Detect prompts and determine new state # Detect prompts and determine new state
detections = self.detect_prompt(output, process_type) detections = self.detect_prompt(output, process_type)
new_state = self._determine_state_from_detections( new_state = self._determine_state_from_detections(detections, process_type, old_state)
detections, process_type, old_state
)
if new_state != old_state: if new_state != old_state:
session_state["transitions"].append( session_state["transitions"].append(

View File

@ -92,15 +92,11 @@ class DiffDisplay:
old_line_num, new_line_num = self._parse_hunk_header(line) old_line_num, new_line_num = self._parse_hunk_header(line)
elif line.startswith("+"): elif line.startswith("+"):
stats.insertions += 1 stats.insertions += 1
diff_lines.append( diff_lines.append(DiffLine("add", line[1:].rstrip(), None, new_line_num))
DiffLine("add", line[1:].rstrip(), None, new_line_num)
)
new_line_num += 1 new_line_num += 1
elif line.startswith("-"): elif line.startswith("-"):
stats.deletions += 1 stats.deletions += 1
diff_lines.append( diff_lines.append(DiffLine("delete", line[1:].rstrip(), old_line_num, None))
DiffLine("delete", line[1:].rstrip(), old_line_num, None)
)
old_line_num += 1 old_line_num += 1
elif line.startswith(" "): elif line.startswith(" "):
diff_lines.append( diff_lines.append(
@ -182,14 +178,10 @@ class DiffDisplay:
for tag, i1, i2, j1, j2 in matcher.get_opcodes(): for tag, i1, i2, j1, j2 in matcher.get_opcodes():
if tag == "equal": if tag == "equal":
for i, (old_line, new_line) in enumerate( for i, (old_line, new_line) in enumerate(zip(old_lines[i1:i2], new_lines[j1:j2])):
zip(old_lines[i1:i2], new_lines[j1:j2])
):
old_display = old_line[:half_width].ljust(half_width) old_display = old_line[:half_width].ljust(half_width)
new_display = new_line[:half_width].ljust(half_width) new_display = new_line[:half_width].ljust(half_width)
output.append( output.append(f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}")
f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}"
)
elif tag == "replace": elif tag == "replace":
max_lines = max(i2 - i1, j2 - j1) max_lines = max(i2 - i1, j2 - j1)
for i in range(max_lines): for i in range(max_lines):
@ -203,15 +195,11 @@ class DiffDisplay:
elif tag == "delete": elif tag == "delete":
for old_line in old_lines[i1:i2]: for old_line in old_lines[i1:i2]:
old_display = old_line[:half_width].ljust(half_width) old_display = old_line[:half_width].ljust(half_width)
output.append( output.append(f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}")
f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}"
)
elif tag == "insert": elif tag == "insert":
for new_line in new_lines[j1:j2]: for new_line in new_lines[j1:j2]:
new_display = new_line[:half_width].ljust(half_width) new_display = new_line[:half_width].ljust(half_width)
output.append( output.append(f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}")
f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}"
)
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
return "\n".join(output) return "\n".join(output)

View File

@ -16,8 +16,6 @@ def display_tool_call(tool_name, arguments, status="running", result=None):
def print_autonomous_header(task): def print_autonomous_header(task):
print(f"{Colors.BOLD}Task:{Colors.RESET} {task}") print(f"{Colors.BOLD}Task:{Colors.RESET} {task}")
print( print(f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}")
f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}"
)
print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n") print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n")
print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}\n") print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}\n")

View File

@ -46,23 +46,17 @@ class EditOperation:
output = [self.format_operation()] output = [self.format_operation()]
if self.op_type in ("INSERT", "REPLACE"): if self.op_type in ("INSERT", "REPLACE"):
output.append( output.append(f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}")
f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}"
)
if show_content: if show_content:
if self.old_content: if self.old_content:
lines = self.old_content.split("\n") lines = self.old_content.split("\n")
preview = lines[0][:60] + ( preview = lines[0][:60] + ("..." if len(lines[0]) > 60 or len(lines) > 1 else "")
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
)
output.append(f" {Colors.RED}- {preview}{Colors.RESET}") output.append(f" {Colors.RED}- {preview}{Colors.RESET}")
if self.content: if self.content:
lines = self.content.split("\n") lines = self.content.split("\n")
preview = lines[0][:60] + ( preview = lines[0][:60] + ("..." if len(lines[0]) > 60 or len(lines) > 1 else "")
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
)
output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}") output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}")
return "\n".join(output) return "\n".join(output)
@ -93,9 +87,7 @@ class EditTracker:
"total": len(self.operations), "total": len(self.operations),
"completed": sum(1 for op in self.operations if op.status == "completed"), "completed": sum(1 for op in self.operations if op.status == "completed"),
"pending": sum(1 for op in self.operations if op.status == "pending"), "pending": sum(1 for op in self.operations if op.status == "pending"),
"in_progress": sum( "in_progress": sum(1 for op in self.operations if op.status == "in_progress"),
1 for op in self.operations if op.status == "in_progress"
),
"failed": sum(1 for op in self.operations if op.status == "failed"), "failed": sum(1 for op in self.operations if op.status == "failed"),
} }
return stats return stats
@ -112,9 +104,7 @@ class EditTracker:
output = [] output = []
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}") output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
output.append( output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}")
f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}"
)
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n") output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
stats = self.get_stats() stats = self.get_stats()

View File

@ -62,16 +62,10 @@ class ProgressBar:
else: else:
percent = int((self.current / self.total) * 100) percent = int((self.current / self.total) * 100)
filled = ( filled = int((self.current / self.total) * self.width) if self.total > 0 else self.width
int((self.current / self.total) * self.width)
if self.total > 0
else self.width
)
bar = "" * filled + "" * (self.width - filled) bar = "" * filled + "" * (self.width - filled)
sys.stdout.write( sys.stdout.write(f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})")
f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})"
)
sys.stdout.flush() sys.stdout.flush()
if self.current >= self.total: if self.current >= self.total:

View File

@ -25,12 +25,8 @@ def highlight_code(code, language=None, syntax_highlighting=True):
code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code) code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code)
code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code) code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code)
code = re.sub( code = re.sub(r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE)
r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE code = re.sub(r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE)
)
code = re.sub(
r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE
)
return code return code
@ -76,11 +72,15 @@ def render_markdown(text, syntax_highlighting=True):
elif re.match(r"^\s*[\*\-\+]\s", line): elif re.match(r"^\s*[\*\-\+]\s", line):
match = re.match(r"^(\s*)([\*\-\+])(\s+.*)", line) match = re.match(r"^(\s*)([\*\-\+])(\s+.*)", line)
if match: if match:
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" line = (
f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
)
elif re.match(r"^\s*\d+\.\s", line): elif re.match(r"^\s*\d+\.\s", line):
match = re.match(r"^(\s*)(\d+\.)(\s+.*)", line) match = re.match(r"^(\s*)(\d+\.)(\s+.*)", line)
if match: if match:
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}" line = (
f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
)
processed_lines.append(line) processed_lines.append(line)
text = "\n".join(processed_lines) text = "\n".join(processed_lines)

View File

@ -40,9 +40,7 @@ class WorkflowEngine:
self.tool_executor = tool_executor self.tool_executor = tool_executor
self.max_workers = max_workers self.max_workers = max_workers
def _evaluate_condition( def _evaluate_condition(self, condition: str, context: WorkflowExecutionContext) -> bool:
self, condition: str, context: WorkflowExecutionContext
) -> bool:
if not condition: if not condition:
return True return True
@ -188,9 +186,7 @@ class WorkflowEngine:
result = future.result() result = future.result()
context.log_event("step_completed", step.step_id, result) context.log_event("step_completed", step.step_id, result)
except Exception as e: except Exception as e:
context.log_event( context.log_event("step_failed", step.step_id, {"error": str(e)})
"step_failed", step.step_id, {"error": str(e)}
)
else: else:
pending_steps = workflow.get_initial_steps() pending_steps = workflow.get_initial_steps()

View File

@ -104,9 +104,7 @@ class WorkflowStorage:
conn = sqlite3.connect(self.db_path, check_same_thread=False) conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute("SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,))
"SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,)
)
row = cursor.fetchone() row = cursor.fetchone()
conn.close() conn.close()
@ -174,9 +172,7 @@ class WorkflowStorage:
cursor.execute("DELETE FROM workflows WHERE workflow_id = ?", (workflow_id,)) cursor.execute("DELETE FROM workflows WHERE workflow_id = ?", (workflow_id,))
deleted = cursor.rowcount > 0 deleted = cursor.rowcount > 0
cursor.execute( cursor.execute("DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,))
"DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,)
)
conn.commit() conn.commit()
conn.close() conn.close()

View File

@ -12,7 +12,9 @@ class TestApi(unittest.TestCase):
def test_call_api_success(self, mock_slim, mock_urlopen): def test_call_api_success(self, mock_slim, mock_urlopen):
mock_slim.return_value = [{"role": "user", "content": "test"}] mock_slim.return_value = [{"role": "user", "content": "test"}]
mock_response = MagicMock() mock_response = MagicMock()
mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}' mock_response.read.return_value = (
b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}'
)
mock_urlopen.return_value.__enter__.return_value = mock_response mock_urlopen.return_value.__enter__.return_value = mock_response
result = call_api([], "model", "http://url", "key", True, [{"name": "tool"}]) result = call_api([], "model", "http://url", "key", True, [{"name": "tool"}])

View File

@ -73,9 +73,7 @@ class TestAssistant(unittest.TestCase):
"choices": [ "choices": [
{ {
"message": { "message": {
"tool_calls": [ "tool_calls": [{"id": "1", "function": {"name": "test", "arguments": "{}"}}]
{"id": "1", "function": {"name": "test", "arguments": "{}"}}
]
} }
} }
] ]
@ -107,9 +105,7 @@ class TestAssistant(unittest.TestCase):
with patch("builtins.print"): with patch("builtins.print"):
process_message(assistant, "test message") process_message(assistant, "test message")
assistant.messages.append.assert_called_with( assistant.messages.append.assert_called_with({"role": "user", "content": "test message"})
{"role": "user", "content": "test message"}
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -93,9 +93,7 @@ def test_main_export_session_md(capsys):
def test_main_usage(capsys): def test_main_usage(capsys):
usage = {"total_requests": 10, "total_tokens": 1000, "total_cost": 0.01} usage = {"total_requests": 10, "total_tokens": 1000, "total_cost": 0.01}
with patch( with patch("pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage):
"pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage
):
with patch("sys.argv", ["pr", "--usage"]): with patch("sys.argv", ["pr", "--usage"]):
main() main()
captured = capsys.readouterr() captured = capsys.readouterr()