Update.
Some checks failed
Tests / test (macos-latest, 3.10) (push) Waiting to run
Tests / test (macos-latest, 3.11) (push) Waiting to run
Tests / test (macos-latest, 3.12) (push) Waiting to run
Tests / test (macos-latest, 3.8) (push) Waiting to run
Tests / test (macos-latest, 3.9) (push) Waiting to run
Tests / test (windows-latest, 3.10) (push) Waiting to run
Tests / test (windows-latest, 3.11) (push) Waiting to run
Tests / test (windows-latest, 3.12) (push) Waiting to run
Tests / test (windows-latest, 3.8) (push) Waiting to run
Tests / test (windows-latest, 3.9) (push) Waiting to run
Lint / lint (push) Failing after 38s
Tests / test (ubuntu-latest, 3.10) (push) Successful in 51s
Tests / test (ubuntu-latest, 3.11) (push) Successful in 46s
Tests / test (ubuntu-latest, 3.12) (push) Successful in 1m30s
Tests / test (ubuntu-latest, 3.8) (push) Successful in 1m0s
Tests / test (ubuntu-latest, 3.9) (push) Successful in 57s
Some checks failed
Tests / test (macos-latest, 3.10) (push) Waiting to run
Tests / test (macos-latest, 3.11) (push) Waiting to run
Tests / test (macos-latest, 3.12) (push) Waiting to run
Tests / test (macos-latest, 3.8) (push) Waiting to run
Tests / test (macos-latest, 3.9) (push) Waiting to run
Tests / test (windows-latest, 3.10) (push) Waiting to run
Tests / test (windows-latest, 3.11) (push) Waiting to run
Tests / test (windows-latest, 3.12) (push) Waiting to run
Tests / test (windows-latest, 3.8) (push) Waiting to run
Tests / test (windows-latest, 3.9) (push) Waiting to run
Lint / lint (push) Failing after 38s
Tests / test (ubuntu-latest, 3.10) (push) Successful in 51s
Tests / test (ubuntu-latest, 3.11) (push) Successful in 46s
Tests / test (ubuntu-latest, 3.12) (push) Successful in 1m30s
Tests / test (ubuntu-latest, 3.8) (push) Successful in 1m0s
Tests / test (ubuntu-latest, 3.9) (push) Successful in 57s
This commit is contained in:
parent
1a29ee4918
commit
9a5bf46a54
@ -32,30 +32,22 @@ Commands in interactive mode:
|
||||
)
|
||||
|
||||
parser.add_argument("message", nargs="?", help="Message to send to assistant")
|
||||
parser.add_argument(
|
||||
"--version", action="version", version=f"PR Assistant {__version__}"
|
||||
)
|
||||
parser.add_argument("--version", action="version", version=f"PR Assistant {__version__}")
|
||||
parser.add_argument("-m", "--model", help="AI model to use")
|
||||
parser.add_argument("-u", "--api-url", help="API endpoint URL")
|
||||
parser.add_argument("--model-list-url", help="Model list endpoint URL")
|
||||
parser.add_argument(
|
||||
"-i", "--interactive", action="store_true", help="Interactive mode"
|
||||
)
|
||||
parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode")
|
||||
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="Enable debug mode with detailed logging"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-syntax", action="store_true", help="Disable syntax highlighting"
|
||||
)
|
||||
parser.add_argument("--no-syntax", action="store_true", help="Disable syntax highlighting")
|
||||
parser.add_argument(
|
||||
"--include-env",
|
||||
action="store_true",
|
||||
help="Include environment variables in context",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--context", action="append", help="Additional context files"
|
||||
)
|
||||
parser.add_argument("-c", "--context", action="append", help="Additional context files")
|
||||
parser.add_argument(
|
||||
"--api-mode", action="store_true", help="API mode for specialized interaction"
|
||||
)
|
||||
@ -68,18 +60,10 @@ Commands in interactive mode:
|
||||
)
|
||||
parser.add_argument("--quiet", action="store_true", help="Minimal output")
|
||||
|
||||
parser.add_argument(
|
||||
"--save-session", metavar="NAME", help="Save session with given name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-session", metavar="NAME", help="Load session with given name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-sessions", action="store_true", help="List all saved sessions"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delete-session", metavar="NAME", help="Delete a saved session"
|
||||
)
|
||||
parser.add_argument("--save-session", metavar="NAME", help="Save session with given name")
|
||||
parser.add_argument("--load-session", metavar="NAME", help="Load session with given name")
|
||||
parser.add_argument("--list-sessions", action="store_true", help="List all saved sessions")
|
||||
parser.add_argument("--delete-session", metavar="NAME", help="Delete a saved session")
|
||||
parser.add_argument(
|
||||
"--export-session",
|
||||
nargs=2,
|
||||
@ -87,9 +71,7 @@ Commands in interactive mode:
|
||||
help="Export session to file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--usage", action="store_true", help="Show token usage statistics"
|
||||
)
|
||||
parser.add_argument("--usage", action="store_true", help="Show token usage statistics")
|
||||
parser.add_argument(
|
||||
"--create-config", action="store_true", help="Create default configuration file"
|
||||
)
|
||||
|
||||
@ -93,9 +93,7 @@ class AgentCommunicationBus:
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def get_messages(
|
||||
self, agent_id: str, unread_only: bool = True
|
||||
) -> List[AgentMessage]:
|
||||
def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
|
||||
cursor = self.conn.cursor()
|
||||
if unread_only:
|
||||
cursor.execute(
|
||||
@ -135,17 +133,13 @@ class AgentCommunicationBus:
|
||||
|
||||
def mark_as_read(self, message_id: str):
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,)
|
||||
)
|
||||
cursor.execute("UPDATE agent_messages SET read = 1 WHERE message_id = ?", (message_id,))
|
||||
self.conn.commit()
|
||||
|
||||
def clear_messages(self, session_id: Optional[str] = None):
|
||||
cursor = self.conn.cursor()
|
||||
if session_id:
|
||||
cursor.execute(
|
||||
"DELETE FROM agent_messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
cursor.execute("DELETE FROM agent_messages WHERE session_id = ?", (session_id,))
|
||||
else:
|
||||
cursor.execute("DELETE FROM agent_messages")
|
||||
self.conn.commit()
|
||||
@ -156,9 +150,7 @@ class AgentCommunicationBus:
|
||||
def receive_messages(self, agent_id: str) -> List[AgentMessage]:
|
||||
return self.get_messages(agent_id, unread_only=True)
|
||||
|
||||
def get_conversation_history(
|
||||
self, agent_a: str, agent_b: str
|
||||
) -> List[AgentMessage]:
|
||||
def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]:
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
|
||||
@ -19,17 +19,14 @@ class AgentInstance:
|
||||
task_count: int = 0
|
||||
|
||||
def add_message(self, role: str, content: str):
|
||||
self.message_history.append(
|
||||
{"role": role, "content": content, "timestamp": time.time()}
|
||||
)
|
||||
self.message_history.append({"role": role, "content": content, "timestamp": time.time()})
|
||||
|
||||
def get_system_message(self) -> Dict[str, str]:
|
||||
return {"role": "system", "content": self.role.system_prompt}
|
||||
|
||||
def get_messages_for_api(self) -> List[Dict[str, str]]:
|
||||
return [self.get_system_message()] + [
|
||||
{"role": msg["role"], "content": msg["content"]}
|
||||
for msg in self.message_history
|
||||
{"role": msg["role"], "content": msg["content"]} for msg in self.message_history
|
||||
]
|
||||
|
||||
|
||||
@ -128,14 +125,10 @@ class AgentManager:
|
||||
self.communication_bus.send_message(message, self.session_id)
|
||||
return message.message_id
|
||||
|
||||
def get_agent_messages(
|
||||
self, agent_id: str, unread_only: bool = True
|
||||
) -> List[AgentMessage]:
|
||||
def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
|
||||
return self.communication_bus.get_messages(agent_id, unread_only)
|
||||
|
||||
def collaborate_agents(
|
||||
self, orchestrator_id: str, task: str, agent_roles: List[str]
|
||||
):
|
||||
def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]):
|
||||
orchestrator = self.get_agent(orchestrator_id)
|
||||
if not orchestrator:
|
||||
orchestrator_id = self.create_agent("orchestrator")
|
||||
@ -153,9 +146,7 @@ Available specialized agents:
|
||||
|
||||
Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results."""
|
||||
|
||||
orchestrator_result = self.execute_agent_task(
|
||||
orchestrator_id, orchestration_prompt
|
||||
)
|
||||
orchestrator_result = self.execute_agent_task(orchestrator_id, orchestration_prompt)
|
||||
|
||||
results = {"orchestrator": orchestrator_result, "agents": []}
|
||||
|
||||
|
||||
@ -22,22 +22,14 @@ def run_autonomous_mode(assistant, task):
|
||||
while True:
|
||||
assistant.autonomous_iterations += 1
|
||||
|
||||
logger.debug(
|
||||
f"--- Autonomous iteration {assistant.autonomous_iterations} ---"
|
||||
)
|
||||
logger.debug(
|
||||
f"Messages before context management: {len(assistant.messages)}"
|
||||
)
|
||||
logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---")
|
||||
logger.debug(f"Messages before context management: {len(assistant.messages)}")
|
||||
|
||||
from pr.core.context import manage_context_window
|
||||
|
||||
assistant.messages = manage_context_window(
|
||||
assistant.messages, assistant.verbose
|
||||
)
|
||||
assistant.messages = manage_context_window(assistant.messages, assistant.verbose)
|
||||
|
||||
logger.debug(
|
||||
f"Messages after context management: {len(assistant.messages)}"
|
||||
)
|
||||
logger.debug(f"Messages after context management: {len(assistant.messages)}")
|
||||
|
||||
from pr.core.api import call_api
|
||||
from pr.tools.base import get_tools_definition
|
||||
@ -193,9 +185,7 @@ def execute_single_tool(assistant, func_name, arguments):
|
||||
"db_query": lambda **kw: db_query(**kw, db_conn=assistant.db_conn),
|
||||
"web_search": lambda **kw: web_search(**kw),
|
||||
"web_search_news": lambda **kw: web_search_news(**kw),
|
||||
"python_exec": lambda **kw: python_exec(
|
||||
**kw, python_globals=assistant.python_globals
|
||||
),
|
||||
"python_exec": lambda **kw: python_exec(**kw, python_globals=assistant.python_globals),
|
||||
"index_source_directory": lambda **kw: index_source_directory(**kw),
|
||||
"search_replace": lambda **kw: search_replace(**kw),
|
||||
"open_editor": lambda **kw: open_editor(**kw),
|
||||
|
||||
4
pr/cache/api_cache.py
vendored
4
pr/cache/api_cache.py
vendored
@ -140,9 +140,7 @@ class APICache:
|
||||
total_entries = cursor.fetchone()[0]
|
||||
|
||||
current_time = int(time.time())
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,)
|
||||
)
|
||||
cursor.execute("SELECT COUNT(*) FROM api_cache WHERE expires_at > ?", (current_time,))
|
||||
valid_entries = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute(
|
||||
|
||||
4
pr/cache/tool_cache.py
vendored
4
pr/cache/tool_cache.py
vendored
@ -161,9 +161,7 @@ class ToolCache:
|
||||
total_entries = cursor.fetchone()[0]
|
||||
|
||||
current_time = int(time.time())
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,)
|
||||
)
|
||||
cursor.execute("SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?", (current_time,))
|
||||
valid_entries = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute(
|
||||
|
||||
@ -95,9 +95,7 @@ def handle_command(assistant, command):
|
||||
print(f"{Colors.BOLD}Available Tools:{Colors.RESET}")
|
||||
for tool in get_tools_definition():
|
||||
func = tool["function"]
|
||||
print(
|
||||
f" • {Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}"
|
||||
)
|
||||
print(f" • {Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}")
|
||||
|
||||
elif cmd == "/review" and len(command_parts) > 1:
|
||||
filename = command_parts[1]
|
||||
@ -168,9 +166,7 @@ def handle_command(assistant, command):
|
||||
def review_file(assistant, filename):
|
||||
result = read_file(filename)
|
||||
if result["status"] == "success":
|
||||
message = (
|
||||
f"Please review this file and provide feedback:\n\n{result['content']}"
|
||||
)
|
||||
message = f"Please review this file and provide feedback:\n\n{result['content']}"
|
||||
from pr.core.assistant import process_message
|
||||
|
||||
process_message(assistant, message)
|
||||
@ -181,9 +177,7 @@ def review_file(assistant, filename):
|
||||
def refactor_file(assistant, filename):
|
||||
result = read_file(filename)
|
||||
if result["status"] == "success":
|
||||
message = (
|
||||
f"Please refactor this code to improve its quality:\n\n{result['content']}"
|
||||
)
|
||||
message = f"Please refactor this code to improve its quality:\n\n{result['content']}"
|
||||
from pr.core.assistant import process_message
|
||||
|
||||
process_message(assistant, message)
|
||||
@ -354,9 +348,7 @@ def show_conversation_history(assistant):
|
||||
for conv in history:
|
||||
import datetime
|
||||
|
||||
started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
started = datetime.datetime.fromtimestamp(conv["started_at"]).strftime("%Y-%m-%d %H:%M")
|
||||
print(f"\n • {Colors.CYAN}{conv['conversation_id']}{Colors.RESET}")
|
||||
print(f" Started: {started}")
|
||||
print(f" Messages: {conv['message_count']}")
|
||||
@ -536,9 +528,7 @@ def show_session_output(assistant, session_name):
|
||||
for line in output:
|
||||
print(line)
|
||||
else:
|
||||
print(
|
||||
f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.YELLOW}No output available for session '{session_name}'{Colors.RESET}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error getting session output: {e}{Colors.RESET}")
|
||||
|
||||
|
||||
@ -116,9 +116,7 @@ def kill_session(args):
|
||||
close_interactive_session(session_name)
|
||||
print(f"{Colors.GREEN}Session '{session_name}' terminated.{Colors.RESET}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.RED}Error terminating session '{session_name}': {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def send_command(args):
|
||||
@ -132,13 +130,9 @@ def send_command(args):
|
||||
|
||||
try:
|
||||
send_input_to_session(session_name, command)
|
||||
print(
|
||||
f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.GREEN}Sent command to '{session_name}': {command}{Colors.RESET}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.RED}Error sending command to '{session_name}': {e}{Colors.RESET}")
|
||||
|
||||
|
||||
def show_session_log(args):
|
||||
@ -220,9 +214,7 @@ def list_waiting_sessions(args=None):
|
||||
waiting_sessions.append(session_name)
|
||||
|
||||
if not waiting_sessions:
|
||||
print(
|
||||
f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.GREEN}No sessions are currently waiting for input.{Colors.RESET}")
|
||||
return
|
||||
|
||||
print(f"{Colors.BOLD}Sessions waiting for input:{Colors.RESET}")
|
||||
@ -237,9 +229,7 @@ def list_waiting_sessions(args=None):
|
||||
if session_info:
|
||||
suggestions = detector.get_response_suggestions({}, process_type)
|
||||
if suggestions:
|
||||
print(
|
||||
f" Suggested inputs: {', '.join(suggestions[:3])}"
|
||||
) # Show first 3
|
||||
print(f" Suggested inputs: {', '.join(suggestions[:3])}") # Show first 3
|
||||
print()
|
||||
|
||||
|
||||
|
||||
@ -40,9 +40,7 @@ class AdvancedContextManager:
|
||||
words = re.findall(r"\b\w+\b", content.lower())
|
||||
unique_words.update(words)
|
||||
|
||||
vocabulary_richness = (
|
||||
len(unique_words) / total_length if total_length > 0 else 0
|
||||
)
|
||||
vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0
|
||||
|
||||
# Simple complexity score based on length and richness
|
||||
complexity = min(1.0, (avg_length / 100) + vocabulary_richness)
|
||||
|
||||
@ -9,9 +9,7 @@ from pr.core.context import auto_slim_messages
|
||||
logger = logging.getLogger("pr")
|
||||
|
||||
|
||||
def call_api(
|
||||
messages, model, api_url, api_key, use_tools, tools_definition, verbose=False
|
||||
):
|
||||
def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
|
||||
try:
|
||||
messages = auto_slim_messages(messages, verbose=verbose)
|
||||
|
||||
@ -65,13 +63,9 @@ def call_api(
|
||||
msg = choice["message"]
|
||||
logger.debug(f"Response role: {msg.get('role', 'N/A')}")
|
||||
if "content" in msg and msg["content"]:
|
||||
logger.debug(
|
||||
f"Response content length: {len(msg['content'])} chars"
|
||||
)
|
||||
logger.debug(f"Response content length: {len(msg['content'])} chars")
|
||||
if "tool_calls" in msg:
|
||||
logger.debug(
|
||||
f"Response contains {len(msg['tool_calls'])} tool call(s)"
|
||||
)
|
||||
logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")
|
||||
|
||||
logger.debug("=== API CALL END ===")
|
||||
return result
|
||||
|
||||
@ -76,9 +76,7 @@ logger = logging.getLogger("pr")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
file_handler = logging.FileHandler(LOG_FILE)
|
||||
file_handler.setFormatter(
|
||||
logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||
)
|
||||
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
|
||||
@ -93,9 +91,7 @@ class Assistant:
|
||||
if self.debug:
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
console_handler.setFormatter(
|
||||
logging.Formatter("%(levelname)s: %(message)s")
|
||||
)
|
||||
console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
|
||||
logger.addHandler(console_handler)
|
||||
logger.debug("Debug mode enabled")
|
||||
self.api_key = os.environ.get("OPENROUTER_API_KEY", "")
|
||||
@ -210,13 +206,9 @@ class Assistant:
|
||||
session_name = event.get("session_name", "unknown")
|
||||
|
||||
if event_type == "session_started":
|
||||
print(
|
||||
f" {Colors.GREEN}✓{Colors.RESET} Session '{session_name}' started"
|
||||
)
|
||||
print(f" {Colors.GREEN}✓{Colors.RESET} Session '{session_name}' started")
|
||||
elif event_type == "session_ended":
|
||||
print(
|
||||
f" {Colors.YELLOW}✗{Colors.RESET} Session '{session_name}' ended"
|
||||
)
|
||||
print(f" {Colors.YELLOW}✗{Colors.RESET} Session '{session_name}' ended")
|
||||
elif event_type == "output_received":
|
||||
lines = len(event.get("new_output", {}).get("stdout", []))
|
||||
print(
|
||||
@ -241,9 +233,7 @@ class Assistant:
|
||||
|
||||
except Exception as e:
|
||||
if self.debug:
|
||||
print(
|
||||
f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.RED}Error checking background updates: {e}{Colors.RESET}")
|
||||
|
||||
def execute_tool_calls(self, tool_calls):
|
||||
results = []
|
||||
@ -263,14 +253,10 @@ class Assistant:
|
||||
"run_command": lambda **kw: run_command(**kw),
|
||||
"tail_process": lambda **kw: tail_process(**kw),
|
||||
"kill_process": lambda **kw: kill_process(**kw),
|
||||
"start_interactive_session": lambda **kw: start_interactive_session(
|
||||
**kw
|
||||
),
|
||||
"start_interactive_session": lambda **kw: start_interactive_session(**kw),
|
||||
"send_input_to_session": lambda **kw: send_input_to_session(**kw),
|
||||
"read_session_output": lambda **kw: read_session_output(**kw),
|
||||
"close_interactive_session": lambda **kw: close_interactive_session(
|
||||
**kw
|
||||
),
|
||||
"close_interactive_session": lambda **kw: close_interactive_session(**kw),
|
||||
"read_file": lambda **kw: read_file(**kw, db_conn=self.db_conn),
|
||||
"write_file": lambda **kw: write_file(**kw, db_conn=self.db_conn),
|
||||
"list_directory": lambda **kw: list_directory(**kw),
|
||||
@ -286,9 +272,7 @@ class Assistant:
|
||||
**kw, python_globals=self.python_globals
|
||||
),
|
||||
"index_source_directory": lambda **kw: index_source_directory(**kw),
|
||||
"search_replace": lambda **kw: search_replace(
|
||||
**kw, db_conn=self.db_conn
|
||||
),
|
||||
"search_replace": lambda **kw: search_replace(**kw, db_conn=self.db_conn),
|
||||
"open_editor": lambda **kw: open_editor(**kw),
|
||||
"editor_insert_text": lambda **kw: editor_insert_text(
|
||||
**kw, db_conn=self.db_conn
|
||||
@ -304,15 +288,11 @@ class Assistant:
|
||||
"display_edit_summary": lambda **kw: display_edit_summary(),
|
||||
"display_edit_timeline": lambda **kw: display_edit_timeline(**kw),
|
||||
"clear_edit_tracker": lambda **kw: clear_edit_tracker(),
|
||||
"start_interactive_session": lambda **kw: start_interactive_session(
|
||||
**kw
|
||||
),
|
||||
"start_interactive_session": lambda **kw: start_interactive_session(**kw),
|
||||
"send_input_to_session": lambda **kw: send_input_to_session(**kw),
|
||||
"read_session_output": lambda **kw: read_session_output(**kw),
|
||||
"list_active_sessions": lambda **kw: list_active_sessions(**kw),
|
||||
"close_interactive_session": lambda **kw: close_interactive_session(
|
||||
**kw
|
||||
),
|
||||
"close_interactive_session": lambda **kw: close_interactive_session(**kw),
|
||||
"create_agent": lambda **kw: create_agent(**kw),
|
||||
"list_agents": lambda **kw: list_agents(**kw),
|
||||
"execute_agent_task": lambda **kw: execute_agent_task(**kw),
|
||||
@ -321,16 +301,10 @@ class Assistant:
|
||||
"add_knowledge_entry": lambda **kw: add_knowledge_entry(**kw),
|
||||
"get_knowledge_entry": lambda **kw: get_knowledge_entry(**kw),
|
||||
"search_knowledge": lambda **kw: search_knowledge(**kw),
|
||||
"get_knowledge_by_category": lambda **kw: get_knowledge_by_category(
|
||||
**kw
|
||||
),
|
||||
"update_knowledge_importance": lambda **kw: update_knowledge_importance(
|
||||
**kw
|
||||
),
|
||||
"get_knowledge_by_category": lambda **kw: get_knowledge_by_category(**kw),
|
||||
"update_knowledge_importance": lambda **kw: update_knowledge_importance(**kw),
|
||||
"delete_knowledge_entry": lambda **kw: delete_knowledge_entry(**kw),
|
||||
"get_knowledge_statistics": lambda **kw: get_knowledge_statistics(
|
||||
**kw
|
||||
),
|
||||
"get_knowledge_statistics": lambda **kw: get_knowledge_statistics(**kw),
|
||||
}
|
||||
|
||||
if func_name in func_map:
|
||||
@ -356,9 +330,7 @@ class Assistant:
|
||||
{
|
||||
"tool_call_id": tool_id,
|
||||
"role": "tool",
|
||||
"content": json.dumps(
|
||||
{"status": "error", "error": error_msg}
|
||||
),
|
||||
"content": json.dumps({"status": "error", "error": error_msg}),
|
||||
}
|
||||
)
|
||||
|
||||
@ -405,9 +377,7 @@ class Assistant:
|
||||
self.autonomous_mode = False
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(
|
||||
f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}"
|
||||
)
|
||||
print(f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}")
|
||||
return
|
||||
|
||||
self.interrupt_count += 1
|
||||
|
||||
@ -21,9 +21,7 @@ class AutonomousInteractions:
|
||||
self.llm_callback = llm_callback
|
||||
if self.interaction_thread is None:
|
||||
self.active = True
|
||||
self.interaction_thread = threading.Thread(
|
||||
target=self._interaction_loop, daemon=True
|
||||
)
|
||||
self.interaction_thread = threading.Thread(target=self._interaction_loop, daemon=True)
|
||||
self.interaction_thread.start()
|
||||
|
||||
def stop(self):
|
||||
@ -55,9 +53,7 @@ class AutonomousInteractions:
|
||||
if not sessions:
|
||||
return # No active sessions
|
||||
|
||||
sessions_needing_attention = self._identify_sessions_needing_attention(
|
||||
sessions
|
||||
)
|
||||
sessions_needing_attention = self._identify_sessions_needing_attention(sessions)
|
||||
|
||||
if sessions_needing_attention and self.llm_callback:
|
||||
# Format session updates for LLM
|
||||
@ -84,9 +80,7 @@ class AutonomousInteractions:
|
||||
continue
|
||||
|
||||
# 2. High output volume (potential completion or error)
|
||||
total_lines = (
|
||||
output_summary["stdout_lines"] + output_summary["stderr_lines"]
|
||||
)
|
||||
total_lines = output_summary["stdout_lines"] + output_summary["stderr_lines"]
|
||||
if total_lines > 50: # Arbitrary threshold
|
||||
needing_attention.append(session_name)
|
||||
continue
|
||||
|
||||
@ -18,9 +18,7 @@ class BackgroundMonitor:
|
||||
"""Start the background monitoring thread."""
|
||||
if self.monitor_thread is None:
|
||||
self.active = True
|
||||
self.monitor_thread = threading.Thread(
|
||||
target=self._monitor_loop, daemon=True
|
||||
)
|
||||
self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||
self.monitor_thread.start()
|
||||
|
||||
def stop(self):
|
||||
@ -105,21 +103,14 @@ class BackgroundMonitor:
|
||||
old_stderr_lines = old_state["output_summary"]["stderr_lines"]
|
||||
new_stderr_lines = new_state["output_summary"]["stderr_lines"]
|
||||
|
||||
if (
|
||||
new_stdout_lines > old_stdout_lines
|
||||
or new_stderr_lines > old_stderr_lines
|
||||
):
|
||||
if new_stdout_lines > old_stdout_lines or new_stderr_lines > old_stderr_lines:
|
||||
# Get the new output
|
||||
mux = get_multiplexer(session_name)
|
||||
if mux:
|
||||
all_output = mux.get_all_output()
|
||||
new_output = {
|
||||
"stdout": all_output["stdout"].split("\n")[
|
||||
old_stdout_lines:
|
||||
],
|
||||
"stderr": all_output["stderr"].split("\n")[
|
||||
old_stderr_lines:
|
||||
],
|
||||
"stdout": all_output["stdout"].split("\n")[old_stdout_lines:],
|
||||
"stderr": all_output["stderr"].split("\n")[old_stderr_lines:],
|
||||
}
|
||||
|
||||
events.append(
|
||||
@ -167,9 +158,7 @@ class BackgroundMonitor:
|
||||
output_summary = state["output_summary"]
|
||||
|
||||
# Heuristic: High output volume might indicate completion or error
|
||||
total_lines = (
|
||||
output_summary["stdout_lines"] + output_summary["stderr_lines"]
|
||||
)
|
||||
total_lines = output_summary["stdout_lines"] + output_summary["stderr_lines"]
|
||||
if total_lines > 100: # Arbitrary threshold
|
||||
events.append(
|
||||
{
|
||||
@ -193,9 +182,7 @@ class BackgroundMonitor:
|
||||
# Heuristic: Sessions that might be waiting for input
|
||||
# This would be enhanced with prompt detection in later phases
|
||||
if self._might_be_waiting_for_input(session_name, state):
|
||||
events.append(
|
||||
{"type": "possible_input_needed", "session_name": session_name}
|
||||
)
|
||||
events.append({"type": "possible_input_needed", "session_name": session_name})
|
||||
|
||||
return events
|
||||
|
||||
|
||||
@ -41,15 +41,11 @@ def truncate_tool_result(result, max_length=None):
|
||||
|
||||
if "data" in result_copy and isinstance(result_copy["data"], str):
|
||||
if len(result_copy["data"]) > max_length:
|
||||
result_copy["data"] = (
|
||||
result_copy["data"][:max_length] + f"\n... [truncated]"
|
||||
)
|
||||
result_copy["data"] = result_copy["data"][:max_length] + f"\n... [truncated]"
|
||||
|
||||
if "error" in result_copy and isinstance(result_copy["error"], str):
|
||||
if len(result_copy["error"]) > max_length // 2:
|
||||
result_copy["error"] = (
|
||||
result_copy["error"][: max_length // 2] + "... [truncated]"
|
||||
)
|
||||
result_copy["error"] = result_copy["error"][: max_length // 2] + "... [truncated]"
|
||||
|
||||
return result_copy
|
||||
|
||||
@ -111,9 +107,7 @@ Shell Commands:
|
||||
|
||||
system_message = "\n\n".join(context_parts)
|
||||
if len(system_message) > max_context_size * 3:
|
||||
system_message = (
|
||||
system_message[: max_context_size * 3] + "\n... [system message truncated]"
|
||||
)
|
||||
system_message = system_message[: max_context_size * 3] + "\n... [system message truncated]"
|
||||
|
||||
return {"role": "system", "content": system_message}
|
||||
|
||||
@ -198,18 +192,14 @@ def trim_message_content(message, max_length):
|
||||
|
||||
if isinstance(content, str) and len(content) > max_length:
|
||||
trimmed_msg["content"] = (
|
||||
content[:max_length]
|
||||
+ f"\n... [trimmed {len(content) - max_length} chars]"
|
||||
content[:max_length] + f"\n... [trimmed {len(content) - max_length} chars]"
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
trimmed_content = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
trimmed_item = item.copy()
|
||||
if (
|
||||
"text" in trimmed_item
|
||||
and len(trimmed_item["text"]) > max_length
|
||||
):
|
||||
if "text" in trimmed_item and len(trimmed_item["text"]) > max_length:
|
||||
trimmed_item["text"] = (
|
||||
trimmed_item["text"][:max_length] + f"\n... [trimmed]"
|
||||
)
|
||||
@ -236,8 +226,7 @@ def trim_message_content(message, max_length):
|
||||
and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2
|
||||
):
|
||||
parsed["output"] = (
|
||||
parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2]
|
||||
+ f"\n... [truncated]"
|
||||
parsed["output"][: MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
|
||||
)
|
||||
if (
|
||||
"content" in parsed
|
||||
@ -245,8 +234,7 @@ def trim_message_content(message, max_length):
|
||||
and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2
|
||||
):
|
||||
parsed["content"] = (
|
||||
parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2]
|
||||
+ f"\n... [truncated]"
|
||||
parsed["content"][: MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
|
||||
)
|
||||
trimmed_msg["content"] = json.dumps(parsed)
|
||||
except:
|
||||
@ -259,24 +247,18 @@ def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
|
||||
if estimate_tokens(messages) <= target_tokens:
|
||||
return messages
|
||||
|
||||
system_msg = (
|
||||
messages[0] if messages and messages[0].get("role") == "system" else None
|
||||
)
|
||||
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
|
||||
start_idx = 1 if system_msg else 0
|
||||
|
||||
recent_messages = (
|
||||
messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:]
|
||||
)
|
||||
middle_messages = (
|
||||
messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
|
||||
)
|
||||
middle_messages = messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
|
||||
|
||||
trimmed_middle = []
|
||||
for msg in middle_messages:
|
||||
if msg.get("role") == "tool":
|
||||
trimmed_middle.append(
|
||||
trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2)
|
||||
)
|
||||
trimmed_middle.append(trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2))
|
||||
elif msg.get("role") in ["user", "assistant"]:
|
||||
trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH))
|
||||
else:
|
||||
@ -313,9 +295,7 @@ def auto_slim_messages(messages, verbose=False):
|
||||
print(
|
||||
f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}")
|
||||
|
||||
result = intelligently_trim_messages(
|
||||
messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP
|
||||
@ -339,17 +319,13 @@ def auto_slim_messages(messages, verbose=False):
|
||||
f"{Colors.GREEN} Token estimate: {estimated_tokens} → {final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}"
|
||||
)
|
||||
if removed_count > 0:
|
||||
print(
|
||||
f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def emergency_reduce_messages(messages, target_tokens, verbose=False):
|
||||
system_msg = (
|
||||
messages[0] if messages and messages[0].get("role") == "system" else None
|
||||
)
|
||||
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
|
||||
start_idx = 1 if system_msg else 0
|
||||
|
||||
keep_count = 2
|
||||
|
||||
@ -63,9 +63,7 @@ class EnhancedAssistant:
|
||||
|
||||
logger.info("Enhanced Assistant initialized with all features")
|
||||
|
||||
def _execute_tool_for_workflow(
|
||||
self, tool_name: str, arguments: Dict[str, Any]
|
||||
) -> Any:
|
||||
def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
if self.tool_cache:
|
||||
cached_result = self.tool_cache.get(tool_name, arguments)
|
||||
if cached_result is not None:
|
||||
@ -119,9 +117,7 @@ class EnhancedAssistant:
|
||||
if self.tool_cache:
|
||||
content = result.get("content", "")
|
||||
try:
|
||||
parsed_content = (
|
||||
json.loads(content) if isinstance(content, str) else content
|
||||
)
|
||||
parsed_content = json.loads(content) if isinstance(content, str) else content
|
||||
self.tool_cache.set(tool_name, arguments, parsed_content)
|
||||
except Exception:
|
||||
pass
|
||||
@ -164,9 +160,7 @@ class EnhancedAssistant:
|
||||
|
||||
if self.api_cache and CACHE_ENABLED and "error" not in response:
|
||||
token_count = response.get("usage", {}).get("total_tokens", 0)
|
||||
self.api_cache.set(
|
||||
self.base.model, messages, 0.7, 4096, response, token_count
|
||||
)
|
||||
self.api_cache.set(self.base.model, messages, 0.7, 4096, response, token_count)
|
||||
|
||||
return response
|
||||
|
||||
@ -197,11 +191,9 @@ class EnhancedAssistant:
|
||||
self.knowledge_store.add_entry(entry)
|
||||
|
||||
if self.context_manager and ADVANCED_CONTEXT_ENABLED:
|
||||
enhanced_messages, context_info = (
|
||||
self.context_manager.create_enhanced_context(
|
||||
enhanced_messages, context_info = self.context_manager.create_enhanced_context(
|
||||
self.base.messages, user_message, include_knowledge=True
|
||||
)
|
||||
)
|
||||
|
||||
if self.base.verbose:
|
||||
logger.info(f"Enhanced context: {context_info}")
|
||||
@ -261,9 +253,7 @@ class EnhancedAssistant:
|
||||
orchestrator_id = self.agent_manager.create_agent("orchestrator")
|
||||
return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)
|
||||
|
||||
def search_knowledge(
|
||||
self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT
|
||||
) -> List[Any]:
|
||||
def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]:
|
||||
return self.knowledge_store.search_entries(query, top_k=limit)
|
||||
|
||||
def get_cache_statistics(self) -> Dict[str, Any]:
|
||||
|
||||
@ -16,9 +16,7 @@ def setup_logging(verbose=False):
|
||||
if logger.handlers:
|
||||
logger.handlers.clear()
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5
|
||||
)
|
||||
file_handler = RotatingFileHandler(LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
|
||||
|
||||
@ -64,13 +64,9 @@ class UsageTracker:
|
||||
|
||||
self._save_to_history(model, input_tokens, output_tokens, cost)
|
||||
|
||||
logger.debug(
|
||||
f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}"
|
||||
)
|
||||
logger.debug(f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}")
|
||||
|
||||
def _calculate_cost(
|
||||
self, model: str, input_tokens: int, output_tokens: int
|
||||
) -> float:
|
||||
def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
|
||||
if model not in MODEL_COSTS:
|
||||
base_model = model.split("/")[0] if "/" in model else model
|
||||
if base_model not in MODEL_COSTS:
|
||||
@ -85,9 +81,7 @@ class UsageTracker:
|
||||
|
||||
return input_cost + output_cost
|
||||
|
||||
def _save_to_history(
|
||||
self, model: str, input_tokens: int, output_tokens: int, cost: float
|
||||
):
|
||||
def _save_to_history(self, model: str, input_tokens: int, output_tokens: int, cost: float):
|
||||
try:
|
||||
history = []
|
||||
if os.path.exists(USAGE_DB_FILE):
|
||||
|
||||
@ -16,9 +16,7 @@ def validate_file_path(path: str, must_exist: bool = False) -> str:
|
||||
return os.path.abspath(path)
|
||||
|
||||
|
||||
def validate_directory_path(
|
||||
path: str, must_exist: bool = False, create: bool = False
|
||||
) -> str:
|
||||
def validate_directory_path(path: str, must_exist: bool = False, create: bool = False) -> str:
|
||||
if not path:
|
||||
raise ValidationError("Directory path cannot be empty")
|
||||
|
||||
|
||||
28
pr/editor.py
28
pr/editor.py
@ -179,9 +179,7 @@ class RPEditor:
|
||||
|
||||
try:
|
||||
self.running = True
|
||||
self.socket_thread = threading.Thread(
|
||||
target=self.socket_listener, daemon=True
|
||||
)
|
||||
self.socket_thread = threading.Thread(target=self.socket_listener, daemon=True)
|
||||
self.socket_thread.start()
|
||||
self.thread = threading.Thread(target=self.run, daemon=True)
|
||||
self.thread.start()
|
||||
@ -538,9 +536,7 @@ class RPEditor:
|
||||
self.cursor_y = len(self.lines) - 1
|
||||
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = (
|
||||
line[: self.cursor_x] + text + line[self.cursor_x :]
|
||||
)
|
||||
self.lines[self.cursor_y] = line[: self.cursor_x] + text + line[self.cursor_x :]
|
||||
self.cursor_x += len(text)
|
||||
else:
|
||||
# Multi-line insert
|
||||
@ -562,9 +558,7 @@ class RPEditor:
|
||||
def insert_text(self, text):
|
||||
"""Thread-safe text insertion."""
|
||||
try:
|
||||
self.client_sock.send(
|
||||
pickle.dumps({"command": "insert_text", "text": text})
|
||||
)
|
||||
self.client_sock.send(pickle.dumps({"command": "insert_text", "text": text}))
|
||||
except:
|
||||
with self.lock:
|
||||
self._insert_text(text)
|
||||
@ -572,13 +566,9 @@ class RPEditor:
|
||||
def _delete_char(self):
|
||||
"""Delete character at cursor."""
|
||||
self.save_state()
|
||||
if self.cursor_y < len(self.lines) and self.cursor_x < len(
|
||||
self.lines[self.cursor_y]
|
||||
):
|
||||
if self.cursor_y < len(self.lines) and self.cursor_x < len(self.lines[self.cursor_y]):
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = (
|
||||
line[: self.cursor_x] + line[self.cursor_x + 1 :]
|
||||
)
|
||||
self.lines[self.cursor_y] = line[: self.cursor_x] + line[self.cursor_x + 1 :]
|
||||
|
||||
def delete_char(self):
|
||||
"""Thread-safe character deletion."""
|
||||
@ -614,9 +604,7 @@ class RPEditor:
|
||||
"""Handle backspace key."""
|
||||
if self.cursor_x > 0:
|
||||
line = self.lines[self.cursor_y]
|
||||
self.lines[self.cursor_y] = (
|
||||
line[: self.cursor_x - 1] + line[self.cursor_x :]
|
||||
)
|
||||
self.lines[self.cursor_y] = line[: self.cursor_x - 1] + line[self.cursor_x :]
|
||||
self.cursor_x -= 1
|
||||
elif self.cursor_y > 0:
|
||||
prev_len = len(self.lines[self.cursor_y - 1])
|
||||
@ -668,9 +656,7 @@ class RPEditor:
|
||||
def goto_line(self, line_num):
|
||||
"""Thread-safe goto line."""
|
||||
try:
|
||||
self.client_sock.send(
|
||||
pickle.dumps({"command": "goto_line", "line_num": line_num})
|
||||
)
|
||||
self.client_sock.send(pickle.dumps({"command": "goto_line", "line_num": line_num}))
|
||||
except:
|
||||
with self.lock:
|
||||
self._goto_line(line_num)
|
||||
|
||||
@ -211,9 +211,7 @@ class RPEditor:
|
||||
self.running = False
|
||||
elif cmd == "w":
|
||||
self._save_file()
|
||||
elif (
|
||||
cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!"
|
||||
):
|
||||
elif cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!":
|
||||
self._save_file()
|
||||
self.running = False
|
||||
elif cmd.startswith("w "):
|
||||
@ -371,9 +369,7 @@ class RPEditor:
|
||||
self.cursor_x = 0
|
||||
|
||||
def goto_line(self, line_num):
|
||||
self.client_sock.send(
|
||||
pickle.dumps({"command": "goto_line", "line_num": line_num})
|
||||
)
|
||||
self.client_sock.send(pickle.dumps({"command": "goto_line", "line_num": line_num}))
|
||||
|
||||
def get_text(self):
|
||||
self.client_sock.send(pickle.dumps({"command": "get_text"}))
|
||||
|
||||
@ -73,9 +73,7 @@ class AdvancedInputHandler:
|
||||
def _get_editor_input(self, prompt: str) -> Optional[str]:
|
||||
"""Get multi-line input for editor mode."""
|
||||
try:
|
||||
print(
|
||||
"Editor mode: Enter your message. Type 'END' on a new line to finish."
|
||||
)
|
||||
print("Editor mode: Enter your message. Type 'END' on a new line to finish.")
|
||||
print("Type '/simple' to switch back to simple mode.")
|
||||
|
||||
lines = []
|
||||
|
||||
@ -169,9 +169,7 @@ class FactExtractor:
|
||||
sentence_count = len(re.split(r"[.!?]", text))
|
||||
|
||||
urls = re.findall(r"https?://[^\s]+", text)
|
||||
email_addresses = re.findall(
|
||||
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text
|
||||
)
|
||||
email_addresses = re.findall(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", text)
|
||||
dates = re.findall(
|
||||
r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b", text
|
||||
)
|
||||
|
||||
@ -38,8 +38,7 @@ class SemanticIndex:
|
||||
self.idf_scores = {token: 1.0 for token in token_doc_count}
|
||||
else:
|
||||
self.idf_scores = {
|
||||
token: math.log(doc_count / count)
|
||||
for token, count in token_doc_count.items()
|
||||
token: math.log(doc_count / count) for token, count in token_doc_count.items()
|
||||
}
|
||||
|
||||
def add_document(self, doc_id: str, text: str):
|
||||
@ -51,8 +50,7 @@ class SemanticIndex:
|
||||
|
||||
tf_scores = self._compute_tf(tokens)
|
||||
self.doc_vectors[doc_id] = {
|
||||
token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0)
|
||||
for token in tokens
|
||||
token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0) for token in tokens
|
||||
}
|
||||
|
||||
def remove_document(self, doc_id: str):
|
||||
@ -67,8 +65,7 @@ class SemanticIndex:
|
||||
query_tf = self._compute_tf(query_tokens)
|
||||
|
||||
query_vector = {
|
||||
token: query_tf.get(token, 0) * self.idf_scores.get(token, 0)
|
||||
for token in query_tokens
|
||||
token: query_tf.get(token, 0) * self.idf_scores.get(token, 0) for token in query_tokens
|
||||
}
|
||||
|
||||
scores = []
|
||||
@ -79,9 +76,7 @@ class SemanticIndex:
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
return scores[:top_k]
|
||||
|
||||
def _cosine_similarity(
|
||||
self, vec1: Dict[str, float], vec2: Dict[str, float]
|
||||
) -> float:
|
||||
def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
|
||||
dot_product = sum(
|
||||
vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2)
|
||||
)
|
||||
|
||||
@ -30,9 +30,7 @@ class TerminalMultiplexer:
|
||||
self.prompt_detector = get_global_detector()
|
||||
|
||||
if self.show_output:
|
||||
self.display_thread = threading.Thread(
|
||||
target=self._display_worker, daemon=True
|
||||
)
|
||||
self.display_thread = threading.Thread(target=self._display_worker, daemon=True)
|
||||
self.display_thread.start()
|
||||
|
||||
def _display_worker(self):
|
||||
@ -48,9 +46,7 @@ class TerminalMultiplexer:
|
||||
try:
|
||||
line = self.stderr_queue.get(timeout=0.1)
|
||||
if line:
|
||||
sys.stderr.write(
|
||||
f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}"
|
||||
)
|
||||
sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}")
|
||||
sys.stderr.flush()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
@ -52,13 +52,9 @@ class PluginLoader:
|
||||
self.loaded_plugins[plugin_name] = module
|
||||
logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Plugin {plugin_name} register_tools() did not return a list"
|
||||
)
|
||||
logger.warning(f"Plugin {plugin_name} register_tools() did not return a list")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Plugin {plugin_name} does not have register_tools() function"
|
||||
)
|
||||
logger.warning(f"Plugin {plugin_name} does not have register_tools() function")
|
||||
|
||||
def get_plugin_function(self, tool_name: str) -> Callable:
|
||||
for plugin_name, module in self.loaded_plugins.items():
|
||||
|
||||
@ -39,9 +39,7 @@ def list_agents() -> Dict[str, Any]:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def execute_agent_task(
|
||||
agent_id: str, task: str, context: Dict[str, Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
def execute_agent_task(agent_id: str, task: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""Execute a task with the specified agent."""
|
||||
try:
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
@ -63,9 +61,7 @@ def remove_agent(agent_id: str) -> Dict[str, Any]:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def collaborate_agents(
|
||||
orchestrator_id: str, task: str, agent_roles: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
def collaborate_agents(orchestrator_id: str, task: str, agent_roles: List[str]) -> Dict[str, Any]:
|
||||
"""Collaborate multiple agents on a task."""
|
||||
try:
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
|
||||
@ -235,9 +235,7 @@ def get_tools_definition():
|
||||
"description": "Change the current working directory",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to change to"}
|
||||
},
|
||||
"properties": {"path": {"type": "string", "description": "Path to change to"}},
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
@ -284,9 +282,7 @@ def get_tools_definition():
|
||||
"description": "Execute a database query",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "SQL query"}
|
||||
},
|
||||
"properties": {"query": {"type": "string", "description": "SQL query"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
@ -298,9 +294,7 @@ def get_tools_definition():
|
||||
"description": "Perform a web search",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
"properties": {"query": {"type": "string", "description": "Search query"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
@ -346,9 +340,7 @@ def get_tools_definition():
|
||||
"description": "Index directory recursively and read all source files.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to index"}
|
||||
},
|
||||
"properties": {"path": {"type": "string", "description": "Path to index"}},
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
|
||||
@ -81,9 +81,7 @@ def tail_process(pid: int, timeout: int = 30):
|
||||
"pid": pid,
|
||||
}
|
||||
|
||||
ready, _, _ = select.select(
|
||||
[process.stdout, process.stderr], [], [], 0.1
|
||||
)
|
||||
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
|
||||
for pipe in ready:
|
||||
if pipe == process.stdout:
|
||||
line = process.stdout.readline()
|
||||
|
||||
@ -44,9 +44,7 @@ def db_query(query, db_conn):
|
||||
|
||||
if query.strip().upper().startswith("SELECT"):
|
||||
results = cursor.fetchall()
|
||||
columns = (
|
||||
[desc[0] for desc in cursor.description] if cursor.description else []
|
||||
)
|
||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||
return {"status": "success", "columns": columns, "rows": results}
|
||||
else:
|
||||
db_conn.commit()
|
||||
|
||||
@ -61,9 +61,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
|
||||
with open(path) as f:
|
||||
old_content = f.read()
|
||||
|
||||
position = (line if line is not None else 0) * 1000 + (
|
||||
col if col is not None else 0
|
||||
)
|
||||
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
|
||||
operation = track_edit("INSERT", filepath, start_pos=position, content=text)
|
||||
tracker.mark_in_progress(operation)
|
||||
|
||||
@ -76,11 +74,7 @@ def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
|
||||
mux_name = f"editor-{path}"
|
||||
mux = get_multiplexer(mux_name)
|
||||
if mux:
|
||||
location = (
|
||||
f" at line {line}, col {col}"
|
||||
if line is not None and col is not None
|
||||
else ""
|
||||
)
|
||||
location = f" at line {line}, col {col}" if line is not None and col is not None else ""
|
||||
preview = text[:50] + "..." if len(text) > 50 else text
|
||||
mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n")
|
||||
|
||||
|
||||
@ -41,10 +41,7 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
@ -54,9 +51,7 @@ def write_file(filepath, content, db_conn=None, show_diff=True):
|
||||
with open(path) as f:
|
||||
old_content = f.read()
|
||||
|
||||
operation = track_edit(
|
||||
"WRITE", filepath, content=content, old_content=old_content
|
||||
)
|
||||
operation = track_edit("WRITE", filepath, content=content, old_content=old_content)
|
||||
tracker.mark_in_progress(operation)
|
||||
|
||||
if show_diff and not is_new_file:
|
||||
@ -119,9 +114,7 @@ def list_directory(path=".", recursive=False):
|
||||
}
|
||||
)
|
||||
for name in dirs:
|
||||
items.append(
|
||||
{"path": os.path.join(root, name), "type": "directory"}
|
||||
)
|
||||
items.append({"path": os.path.join(root, name), "type": "directory"})
|
||||
else:
|
||||
for item in os.listdir(path):
|
||||
item_path = os.path.join(path, item)
|
||||
@ -129,11 +122,7 @@ def list_directory(path=".", recursive=False):
|
||||
{
|
||||
"name": item,
|
||||
"type": "directory" if os.path.isdir(item_path) else "file",
|
||||
"size": (
|
||||
os.path.getsize(item_path)
|
||||
if os.path.isfile(item_path)
|
||||
else None
|
||||
),
|
||||
"size": (os.path.getsize(item_path) if os.path.isfile(item_path) else None),
|
||||
}
|
||||
)
|
||||
return {"status": "success", "items": items}
|
||||
@ -209,10 +198,7 @@ def search_replace(filepath, old_string, new_string, db_conn=None):
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
@ -259,19 +245,14 @@ def open_editor(filepath):
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def editor_insert_text(
|
||||
filepath, text, line=None, col=None, show_diff=True, db_conn=None
|
||||
):
|
||||
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None):
|
||||
try:
|
||||
path = os.path.expanduser(filepath)
|
||||
if db_conn:
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
@ -282,9 +263,7 @@ def editor_insert_text(
|
||||
with open(path) as f:
|
||||
old_content = f.read()
|
||||
|
||||
position = (line if line is not None else 0) * 1000 + (
|
||||
col if col is not None else 0
|
||||
)
|
||||
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
|
||||
operation = track_edit("INSERT", filepath, start_pos=position, content=text)
|
||||
tracker.mark_in_progress(operation)
|
||||
|
||||
@ -325,10 +304,7 @@ def editor_replace_text(
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
|
||||
@ -47,9 +47,7 @@ def get_knowledge_entry(entry_id: str) -> Dict[str, Any]:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def search_knowledge(
|
||||
query: str, category: str = None, top_k: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
def search_knowledge(query: str, category: str = None, top_k: int = 5) -> Dict[str, Any]:
|
||||
"""Search the knowledge base semantically."""
|
||||
try:
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
@ -75,9 +73,7 @@ def get_knowledge_by_category(category: str, limit: int = 20) -> Dict[str, Any]:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def update_knowledge_importance(
|
||||
entry_id: str, importance_score: float
|
||||
) -> Dict[str, Any]:
|
||||
def update_knowledge_importance(entry_id: str, importance_score: float) -> Dict[str, Any]:
|
||||
"""Update the importance score of a knowledge entry."""
|
||||
try:
|
||||
db_path = os.path.expanduser("~/.assistant_db.sqlite")
|
||||
|
||||
@ -13,10 +13,7 @@ def apply_patch(filepath, patch_content, db_conn=None):
|
||||
from pr.tools.database import db_get
|
||||
|
||||
read_status = db_get("read:" + path, db_conn)
|
||||
if (
|
||||
read_status.get("status") != "success"
|
||||
or read_status.get("value") != "true"
|
||||
):
|
||||
if read_status.get("status") != "success" or read_status.get("value") != "true":
|
||||
return {
|
||||
"status": "error",
|
||||
"error": "File must be read before writing. Please read the file first.",
|
||||
@ -68,9 +65,7 @@ def create_diff(
|
||||
else:
|
||||
lines1 = content1.splitlines(keepends=True)
|
||||
lines2 = content2.splitlines(keepends=True)
|
||||
diff = list(
|
||||
difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile)
|
||||
)
|
||||
diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
|
||||
return {"status": "success", "diff": "".join(diff)}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
@ -94,9 +89,7 @@ def display_file_diff(filepath1, filepath2, format_type="unified", context_lines
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
|
||||
def display_content_diff(
|
||||
old_content, new_content, filename="file", format_type="unified"
|
||||
):
|
||||
def display_content_diff(old_content, new_content, filename="file", format_type="unified"):
|
||||
try:
|
||||
visual_diff = display_diff(old_content, new_content, filename, format_type)
|
||||
stats = get_diff_stats(old_content, new_content)
|
||||
|
||||
@ -187,9 +187,7 @@ class PromptDetector:
|
||||
|
||||
# Detect prompts and determine new state
|
||||
detections = self.detect_prompt(output, process_type)
|
||||
new_state = self._determine_state_from_detections(
|
||||
detections, process_type, old_state
|
||||
)
|
||||
new_state = self._determine_state_from_detections(detections, process_type, old_state)
|
||||
|
||||
if new_state != old_state:
|
||||
session_state["transitions"].append(
|
||||
|
||||
@ -92,15 +92,11 @@ class DiffDisplay:
|
||||
old_line_num, new_line_num = self._parse_hunk_header(line)
|
||||
elif line.startswith("+"):
|
||||
stats.insertions += 1
|
||||
diff_lines.append(
|
||||
DiffLine("add", line[1:].rstrip(), None, new_line_num)
|
||||
)
|
||||
diff_lines.append(DiffLine("add", line[1:].rstrip(), None, new_line_num))
|
||||
new_line_num += 1
|
||||
elif line.startswith("-"):
|
||||
stats.deletions += 1
|
||||
diff_lines.append(
|
||||
DiffLine("delete", line[1:].rstrip(), old_line_num, None)
|
||||
)
|
||||
diff_lines.append(DiffLine("delete", line[1:].rstrip(), old_line_num, None))
|
||||
old_line_num += 1
|
||||
elif line.startswith(" "):
|
||||
diff_lines.append(
|
||||
@ -182,14 +178,10 @@ class DiffDisplay:
|
||||
|
||||
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
|
||||
if tag == "equal":
|
||||
for i, (old_line, new_line) in enumerate(
|
||||
zip(old_lines[i1:i2], new_lines[j1:j2])
|
||||
):
|
||||
for i, (old_line, new_line) in enumerate(zip(old_lines[i1:i2], new_lines[j1:j2])):
|
||||
old_display = old_line[:half_width].ljust(half_width)
|
||||
new_display = new_line[:half_width].ljust(half_width)
|
||||
output.append(
|
||||
f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}"
|
||||
)
|
||||
output.append(f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}")
|
||||
elif tag == "replace":
|
||||
max_lines = max(i2 - i1, j2 - j1)
|
||||
for i in range(max_lines):
|
||||
@ -203,15 +195,11 @@ class DiffDisplay:
|
||||
elif tag == "delete":
|
||||
for old_line in old_lines[i1:i2]:
|
||||
old_display = old_line[:half_width].ljust(half_width)
|
||||
output.append(
|
||||
f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}"
|
||||
)
|
||||
output.append(f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}")
|
||||
elif tag == "insert":
|
||||
for new_line in new_lines[j1:j2]:
|
||||
new_display = new_line[:half_width].ljust(half_width)
|
||||
output.append(
|
||||
f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}"
|
||||
)
|
||||
output.append(f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}")
|
||||
|
||||
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
|
||||
return "\n".join(output)
|
||||
|
||||
@ -16,8 +16,6 @@ def display_tool_call(tool_name, arguments, status="running", result=None):
|
||||
|
||||
def print_autonomous_header(task):
|
||||
print(f"{Colors.BOLD}Task:{Colors.RESET} {task}")
|
||||
print(
|
||||
f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}"
|
||||
)
|
||||
print(f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}")
|
||||
print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n")
|
||||
print(f"{Colors.BOLD}{'═' * 80}{Colors.RESET}\n")
|
||||
|
||||
@ -46,23 +46,17 @@ class EditOperation:
|
||||
output = [self.format_operation()]
|
||||
|
||||
if self.op_type in ("INSERT", "REPLACE"):
|
||||
output.append(
|
||||
f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}"
|
||||
)
|
||||
output.append(f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}")
|
||||
|
||||
if show_content:
|
||||
if self.old_content:
|
||||
lines = self.old_content.split("\n")
|
||||
preview = lines[0][:60] + (
|
||||
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
|
||||
)
|
||||
preview = lines[0][:60] + ("..." if len(lines[0]) > 60 or len(lines) > 1 else "")
|
||||
output.append(f" {Colors.RED}- {preview}{Colors.RESET}")
|
||||
|
||||
if self.content:
|
||||
lines = self.content.split("\n")
|
||||
preview = lines[0][:60] + (
|
||||
"..." if len(lines[0]) > 60 or len(lines) > 1 else ""
|
||||
)
|
||||
preview = lines[0][:60] + ("..." if len(lines[0]) > 60 or len(lines) > 1 else "")
|
||||
output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}")
|
||||
|
||||
return "\n".join(output)
|
||||
@ -93,9 +87,7 @@ class EditTracker:
|
||||
"total": len(self.operations),
|
||||
"completed": sum(1 for op in self.operations if op.status == "completed"),
|
||||
"pending": sum(1 for op in self.operations if op.status == "pending"),
|
||||
"in_progress": sum(
|
||||
1 for op in self.operations if op.status == "in_progress"
|
||||
),
|
||||
"in_progress": sum(1 for op in self.operations if op.status == "in_progress"),
|
||||
"failed": sum(1 for op in self.operations if op.status == "failed"),
|
||||
}
|
||||
return stats
|
||||
@ -112,9 +104,7 @@ class EditTracker:
|
||||
|
||||
output = []
|
||||
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
|
||||
output.append(
|
||||
f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}"
|
||||
)
|
||||
output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}")
|
||||
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
|
||||
|
||||
stats = self.get_stats()
|
||||
|
||||
@ -62,16 +62,10 @@ class ProgressBar:
|
||||
else:
|
||||
percent = int((self.current / self.total) * 100)
|
||||
|
||||
filled = (
|
||||
int((self.current / self.total) * self.width)
|
||||
if self.total > 0
|
||||
else self.width
|
||||
)
|
||||
filled = int((self.current / self.total) * self.width) if self.total > 0 else self.width
|
||||
bar = "█" * filled + "░" * (self.width - filled)
|
||||
|
||||
sys.stdout.write(
|
||||
f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})"
|
||||
)
|
||||
sys.stdout.write(f"\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})")
|
||||
sys.stdout.flush()
|
||||
|
||||
if self.current >= self.total:
|
||||
|
||||
@ -25,12 +25,8 @@ def highlight_code(code, language=None, syntax_highlighting=True):
|
||||
code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code)
|
||||
code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code)
|
||||
|
||||
code = re.sub(
|
||||
r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE
|
||||
)
|
||||
code = re.sub(
|
||||
r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE
|
||||
)
|
||||
code = re.sub(r"#(.*)$", f"{Colors.GRAY}#\\1{Colors.RESET}", code, flags=re.MULTILINE)
|
||||
code = re.sub(r"//(.*)$", f"{Colors.GRAY}//\\1{Colors.RESET}", code, flags=re.MULTILINE)
|
||||
|
||||
return code
|
||||
|
||||
@ -76,11 +72,15 @@ def render_markdown(text, syntax_highlighting=True):
|
||||
elif re.match(r"^\s*[\*\-\+]\s", line):
|
||||
match = re.match(r"^(\s*)([\*\-\+])(\s+.*)", line)
|
||||
if match:
|
||||
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
|
||||
line = (
|
||||
f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
|
||||
)
|
||||
elif re.match(r"^\s*\d+\.\s", line):
|
||||
match = re.match(r"^(\s*)(\d+\.)(\s+.*)", line)
|
||||
if match:
|
||||
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
|
||||
line = (
|
||||
f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
|
||||
)
|
||||
processed_lines.append(line)
|
||||
text = "\n".join(processed_lines)
|
||||
|
||||
|
||||
@ -40,9 +40,7 @@ class WorkflowEngine:
|
||||
self.tool_executor = tool_executor
|
||||
self.max_workers = max_workers
|
||||
|
||||
def _evaluate_condition(
|
||||
self, condition: str, context: WorkflowExecutionContext
|
||||
) -> bool:
|
||||
def _evaluate_condition(self, condition: str, context: WorkflowExecutionContext) -> bool:
|
||||
if not condition:
|
||||
return True
|
||||
|
||||
@ -188,9 +186,7 @@ class WorkflowEngine:
|
||||
result = future.result()
|
||||
context.log_event("step_completed", step.step_id, result)
|
||||
except Exception as e:
|
||||
context.log_event(
|
||||
"step_failed", step.step_id, {"error": str(e)}
|
||||
)
|
||||
context.log_event("step_failed", step.step_id, {"error": str(e)})
|
||||
|
||||
else:
|
||||
pending_steps = workflow.get_initial_steps()
|
||||
|
||||
@ -104,9 +104,7 @@ class WorkflowStorage:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,)
|
||||
)
|
||||
cursor.execute("SELECT workflow_data FROM workflows WHERE workflow_id = ?", (workflow_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@ -174,9 +172,7 @@ class WorkflowStorage:
|
||||
cursor.execute("DELETE FROM workflows WHERE workflow_id = ?", (workflow_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
|
||||
cursor.execute(
|
||||
"DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,)
|
||||
)
|
||||
cursor.execute("DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
@ -12,7 +12,9 @@ class TestApi(unittest.TestCase):
|
||||
def test_call_api_success(self, mock_slim, mock_urlopen):
|
||||
mock_slim.return_value = [{"role": "user", "content": "test"}]
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}'
|
||||
mock_response.read.return_value = (
|
||||
b'{"choices": [{"message": {"content": "response"}}], "usage": {"tokens": 10}}'
|
||||
)
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
|
||||
result = call_api([], "model", "http://url", "key", True, [{"name": "tool"}])
|
||||
|
||||
@ -73,9 +73,7 @@ class TestAssistant(unittest.TestCase):
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"tool_calls": [
|
||||
{"id": "1", "function": {"name": "test", "arguments": "{}"}}
|
||||
]
|
||||
"tool_calls": [{"id": "1", "function": {"name": "test", "arguments": "{}"}}]
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -107,9 +105,7 @@ class TestAssistant(unittest.TestCase):
|
||||
with patch("builtins.print"):
|
||||
process_message(assistant, "test message")
|
||||
|
||||
assistant.messages.append.assert_called_with(
|
||||
{"role": "user", "content": "test message"}
|
||||
)
|
||||
assistant.messages.append.assert_called_with({"role": "user", "content": "test message"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -93,9 +93,7 @@ def test_main_export_session_md(capsys):
|
||||
|
||||
def test_main_usage(capsys):
|
||||
usage = {"total_requests": 10, "total_tokens": 1000, "total_cost": 0.01}
|
||||
with patch(
|
||||
"pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage
|
||||
):
|
||||
with patch("pr.core.usage_tracker.UsageTracker.get_total_usage", return_value=usage):
|
||||
with patch("sys.argv", ["pr", "--usage"]):
|
||||
main()
|
||||
captured = capsys.readouterr()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user