import json
import logging
import time
import uuid
from typing import Any, Dict, List, Optional

from pr.agents import AgentManager
from pr.cache import APICache, ToolCache
from pr.config import (
    ADVANCED_CONTEXT_ENABLED,
    API_CACHE_TTL,
    CACHE_ENABLED,
    CONVERSATION_SUMMARY_THRESHOLD,
    DB_PATH,
    KNOWLEDGE_SEARCH_LIMIT,
    MEMORY_AUTO_SUMMARIZE,
    TOOL_CACHE_TTL,
    WORKFLOW_EXECUTOR_MAX_WORKERS,
)
from pr.core.advanced_context import AdvancedContextManager
from pr.core.api import call_api
from pr.memory import (
    ConversationMemory,
    FactExtractor,
    KnowledgeEntry,
    KnowledgeStore,
)
from pr.tools.base import get_tools_definition
from pr.workflows import WorkflowEngine, WorkflowStorage

logger = logging.getLogger("pr")


class EnhancedAssistant:
    """Wrap a base assistant with caching, workflows, agents, and memory.

    The wrapped ``base_assistant`` supplies the model/API settings
    (``model``, ``api_url``, ``api_key``, ``use_tools``, ``verbose``), the
    running ``messages`` list, and the ``execute_tool_calls`` /
    ``process_response`` hooks that this class delegates to.
    """

    # Tools the workflow engine may run through the base assistant.
    _WORKFLOW_TOOLS = frozenset(
        {"read_file", "write_file", "list_directory", "run_command"}
    )

    # Temperature / max-token values that form part of the API cache key.
    # NOTE(review): these are NOT passed to ``call_api`` in
    # ``enhanced_call_api``; the key only describes the real request if they
    # equal ``call_api``'s own defaults — confirm.
    _CACHE_KEY_TEMPERATURE = 0.7
    _CACHE_KEY_MAX_TOKENS = 4096

    # Facts are extracted when the message count is a multiple of this value.
    _FACT_EXTRACTION_INTERVAL = 5
    # At most this many extracted facts are stored per extraction pass.
    _MAX_FACTS_PER_EXTRACTION = 3

    def __init__(self, base_assistant):
        """Create every subsystem backed by ``DB_PATH`` and open a conversation.

        Args:
            base_assistant: the assistant whose model settings, message list,
                tool execution and response processing are reused.
        """
        self.base = base_assistant

        # Caches exist only when caching is enabled; every use below checks
        # them for truthiness first.
        if CACHE_ENABLED:
            self.api_cache = APICache(DB_PATH, API_CACHE_TTL)
            self.tool_cache = ToolCache(DB_PATH, TOOL_CACHE_TTL)
        else:
            self.api_cache = None
            self.tool_cache = None

        self.workflow_storage = WorkflowStorage(DB_PATH)
        self.workflow_engine = WorkflowEngine(
            tool_executor=self._execute_tool_for_workflow,
            max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS,
        )

        self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent)

        self.knowledge_store = KnowledgeStore(DB_PATH)
        self.conversation_memory = ConversationMemory(DB_PATH)
        self.fact_extractor = FactExtractor()

        if ADVANCED_CONTEXT_ENABLED:
            self.context_manager = AdvancedContextManager(
                knowledge_store=self.knowledge_store,
                conversation_memory=self.conversation_memory,
            )
        else:
            self.context_manager = None

        # Ids are the first 16 characters of a UUID4 string.
        self.current_conversation_id = str(uuid.uuid4())[:16]
        self.conversation_memory.create_conversation(
            self.current_conversation_id, session_id=str(uuid.uuid4())[:16]
        )

        logger.info("Enhanced Assistant initialized with all features")

    def _execute_tool_for_workflow(
        self, tool_name: str, arguments: Dict[str, Any]
    ) -> Any:
        """Run one workflow tool step, consulting the tool cache first.

        Args:
            tool_name: name of the tool; only names in ``_WORKFLOW_TOOLS``
                are executed.
            arguments: the tool's arguments; sent to the base assistant as
                a JSON string.

        Returns:
            The cached value on a cache hit; otherwise the first result of
            ``base.execute_tool_calls`` for a known tool, or an
            ``{"error": ...}`` dict for an unknown tool.
        """
        if self.tool_cache:
            cached_result = self.tool_cache.get(tool_name, arguments)
            if cached_result is not None:
                logger.debug("Tool cache hit for %s", tool_name)
                # NOTE(review): the cache stores the *parsed content* of a
                # result (see _cache_tool_result), so a hit returns a
                # different shape than a miss, which returns the full tool
                # result — confirm the workflow engine handles both.
                return cached_result

        if tool_name not in self._WORKFLOW_TOOLS:
            return {"error": f"Unknown tool: {tool_name}"}

        result = self._run_base_tool(tool_name, arguments)
        if self.tool_cache:
            self._cache_tool_result(tool_name, arguments, result)
        return result

    def _run_base_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
        """Dispatch one tool call through the base assistant; return its result."""
        tool_call = {
            "id": "temp",
            "function": {"name": tool_name, "arguments": json.dumps(arguments)},
        }
        return self.base.execute_tool_calls([tool_call])[0]

    def _cache_tool_result(
        self, tool_name: str, arguments: Dict[str, Any], result: Any
    ) -> None:
        """Best-effort: store the JSON-decoded ``content`` of *result* in the tool cache.

        String content that is not valid JSON is not cached. Cache write
        failures are logged and otherwise ignored so a caching problem never
        fails the workflow step.
        """
        content = result.get("content", "")
        try:
            parsed_content = (
                json.loads(content) if isinstance(content, str) else content
            )
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            logger.debug("Not caching %s result: content is not JSON", tool_name)
            return
        try:
            self.tool_cache.set(tool_name, arguments, parsed_content)
        except Exception:
            logger.warning(
                "Failed to cache result for tool %s", tool_name, exc_info=True
            )

    def _api_caller_for_agent(
        self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int
    ) -> Dict[str, Any]:
        """Call the model API for an agent, with tools disabled.

        Args:
            messages: chat messages to send.
            temperature: sampling temperature forwarded to ``call_api``.
            max_tokens: token limit forwarded to ``call_api``.

        Returns:
            The response returned by ``call_api``.
        """
        return call_api(
            messages,
            self.base.model,
            self.base.api_url,
            self.base.api_key,
            use_tools=False,
            tools=None,
            temperature=temperature,
            max_tokens=max_tokens,
            verbose=self.base.verbose,
        )

    def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Call the model API, serving and storing responses via the API cache.

        Args:
            messages: chat messages to send.

        Returns:
            A truthy cached response on a hit; otherwise the ``call_api``
            response. Responses containing an ``"error"`` key are not cached.
        """
        use_cache = bool(self.api_cache) and CACHE_ENABLED
        temperature = self._CACHE_KEY_TEMPERATURE
        max_tokens = self._CACHE_KEY_MAX_TOKENS

        # NOTE(review): the cache key does not include ``base.use_tools``;
        # responses requested with and without tools share entries — confirm.
        if use_cache:
            cached_response = self.api_cache.get(
                self.base.model, messages, temperature, max_tokens
            )
            if cached_response:
                logger.debug("API cache hit")
                return cached_response

        response = call_api(
            messages,
            self.base.model,
            self.base.api_url,
            self.base.api_key,
            self.base.use_tools,
            get_tools_definition(),
            verbose=self.base.verbose,
        )

        if use_cache and "error" not in response:
            token_count = response.get("usage", {}).get("total_tokens", 0)
            self.api_cache.set(
                self.base.model,
                messages,
                temperature,
                max_tokens,
                response,
                token_count,
            )
        return response

    def process_with_enhanced_context(self, user_message: str) -> str:
        """Handle one user turn with memory, optional context enrichment and caching.

        Records the message in ``base.messages`` and in conversation memory,
        periodically extracts facts into the knowledge store, sends the
        (possibly enriched) messages through ``enhanced_call_api``, hands the
        response to ``base.process_response``, and refreshes the stored
        conversation summary once the history reaches
        ``CONVERSATION_SUMMARY_THRESHOLD`` messages.

        Args:
            user_message: the user's input text.

        Returns:
            Whatever ``base.process_response`` returns for the API response.
        """
        self.base.messages.append({"role": "user", "content": user_message})
        self.conversation_memory.add_message(
            self.current_conversation_id, str(uuid.uuid4())[:16], "user", user_message
        )

        if (
            MEMORY_AUTO_SUMMARIZE
            and len(self.base.messages) % self._FACT_EXTRACTION_INTERVAL == 0
        ):
            self._store_extracted_facts(user_message)

        working_messages = self._build_working_messages(user_message)
        response = self.enhanced_call_api(working_messages)
        result = self.base.process_response(response)

        if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD:
            self._update_conversation_summary()

        return result

    def _store_extracted_facts(self, text: str) -> None:
        """Extract facts from *text* and add the first few to the knowledge store."""
        facts = self.fact_extractor.extract_facts(text)
        for fact in facts[: self._MAX_FACTS_PER_EXTRACTION]:
            entry_id = str(uuid.uuid4())[:16]
            categories = self.fact_extractor.categorize_content(fact["text"])
            # One timestamp so a freshly created entry has created_at == updated_at.
            now = time.time()
            entry = KnowledgeEntry(
                entry_id=entry_id,
                category=categories[0] if categories else "general",
                content=fact["text"],
                metadata={"type": fact["type"], "confidence": fact["confidence"]},
                created_at=now,
                updated_at=now,
            )
            self.knowledge_store.add_entry(entry)

    def _build_working_messages(self, user_message: str) -> List[Dict[str, Any]]:
        """Return the messages to send: enriched when advanced context is on, else ``base.messages``."""
        if not (self.context_manager and ADVANCED_CONTEXT_ENABLED):
            return self.base.messages

        enhanced_messages, context_info = (
            self.context_manager.create_enhanced_context(
                self.base.messages, user_message, include_knowledge=True
            )
        )
        if self.base.verbose:
            logger.info("Enhanced context: %s", context_info)
        return enhanced_messages

    def _update_conversation_summary(self) -> None:
        """Summarize recent messages and store the summary with its categorized topics."""
        if self.context_manager:
            summary = self.context_manager.advanced_summarize_messages(
                self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
            )
        else:
            summary = "Conversation in progress"
        topics = self.fact_extractor.categorize_content(summary)
        self.conversation_memory.update_conversation_summary(
            self.current_conversation_id, summary, topics
        )

    def execute_workflow(
        self, workflow_name: str, initial_variables: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Load a stored workflow by name, run it, and persist the execution.

        Args:
            workflow_name: name of the stored workflow.
            initial_variables: optional variables passed to the engine.

        Returns:
            ``{"error": ...}`` when the workflow is not found; otherwise a dict
            with ``success``, ``execution_id``, ``results`` (the context's
            ``step_results``) and ``execution_log``.
        """
        workflow = self.workflow_storage.load_workflow_by_name(workflow_name)
        if not workflow:
            return {"error": f'Workflow "{workflow_name}" not found'}

        context = self.workflow_engine.execute_workflow(workflow, initial_variables)

        # Reuse the workflow loaded above instead of querying storage again
        # (a second lookup could also return None and crash on ``.name``).
        execution_id = self.workflow_storage.save_execution(workflow.name, context)

        return {
            "success": True,
            "execution_id": execution_id,
            "results": context.step_results,
            "execution_log": context.execution_log,
        }

    def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
        """Create an agent with *role_name* via the agent manager and return its id."""
        return self.agent_manager.create_agent(role_name, agent_id)

    def agent_task(self, agent_id: str, task: str) -> Dict[str, Any]:
        """Run *task* on the agent identified by *agent_id*."""
        return self.agent_manager.execute_agent_task(agent_id, task)

    def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]:
        """Create a new "orchestrator" agent and have it run *task* with *agent_roles*."""
        orchestrator_id = self.agent_manager.create_agent("orchestrator")
        return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)

    def search_knowledge(
        self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT
    ) -> List[Any]:
        """Search the knowledge store for *query*, returning at most *limit* entries."""
        return self.knowledge_store.search_entries(query, top_k=limit)

    def get_cache_statistics(self) -> Dict[str, Any]:
        """Return statistics for each enabled cache (empty dict when caching is off)."""
        stats = {}
        if self.api_cache:
            stats["api_cache"] = self.api_cache.get_statistics()
        if self.tool_cache:
            stats["tool_cache"] = self.tool_cache.get_statistics()
        return stats

    def get_workflow_list(self) -> List[Dict[str, Any]]:
        """Return the stored workflows as listed by workflow storage."""
        return self.workflow_storage.list_workflows()

    def get_agent_summary(self) -> Dict[str, Any]:
        """Return the agent manager's session summary."""
        return self.agent_manager.get_session_summary()

    def get_knowledge_statistics(self) -> Dict[str, Any]:
        """Return the knowledge store's statistics."""
        return self.knowledge_store.get_statistics()

    def get_conversation_history(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return up to *limit* recent conversations from conversation memory."""
        return self.conversation_memory.get_recent_conversations(limit=limit)

    def clear_caches(self):
        """Remove all entries from every enabled cache."""
        if self.api_cache:
            self.api_cache.clear_all()
        if self.tool_cache:
            self.tool_cache.clear_all()
        logger.info("All caches cleared")

    def cleanup(self):
        """Drop expired cache entries and clear the agent manager's session."""
        if self.api_cache:
            self.api_cache.clear_expired()
        if self.tool_cache:
            self.tool_cache.clear_expired()
        self.agent_manager.clear_session()