import logging
import json
import time
import uuid
from typing import Optional, Dict, Any, List

from pr.config import (
    DB_PATH, CACHE_ENABLED, API_CACHE_TTL, TOOL_CACHE_TTL,
    WORKFLOW_EXECUTOR_MAX_WORKERS, AGENT_MAX_WORKERS,
    KNOWLEDGE_SEARCH_LIMIT, ADVANCED_CONTEXT_ENABLED,
    MEMORY_AUTO_SUMMARIZE, CONVERSATION_SUMMARY_THRESHOLD
)
from pr.cache import APICache, ToolCache
from pr.workflows import WorkflowEngine, WorkflowStorage
from pr.agents import AgentManager
from pr.memory import KnowledgeStore, ConversationMemory, FactExtractor, KnowledgeEntry
from pr.core.advanced_context import AdvancedContextManager
from pr.core.api import call_api
from pr.tools.base import get_tools_definition

logger = logging.getLogger('pr')


class EnhancedAssistant:
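    """Wraps a base assistant with API/tool caching, workflow execution,
    multi-agent support, and persistent knowledge/conversation memory."""
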
    def __init__(self, base_assistant):
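        """Attach the enhancement subsystems to the wrapped assistant and open a new conversation record."""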
        self.base = base_assistant

        if CACHE_ENABLED:
            self.api_cache = APICache(DB_PATH, API_CACHE_TTL)
            self.tool_cache = ToolCache(DB_PATH, TOOL_CACHE_TTL)
        else:
            self.api_cache = None
            self.tool_cache = None

        self.workflow_storage = WorkflowStorage(DB_PATH)
        self.workflow_engine = WorkflowEngine(
            tool_executor=self._execute_tool_for_workflow,
            max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS
        )

        self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent)

        self.knowledge_store = KnowledgeStore(DB_PATH)
        self.conversation_memory = ConversationMemory(DB_PATH)
        self.fact_extractor = FactExtractor()

        if ADVANCED_CONTEXT_ENABLED:
            self.context_manager = AdvancedContextManager(
                knowledge_store=self.knowledge_store,
                conversation_memory=self.conversation_memory
            )
        else:
            self.context_manager = None

        self.current_conversation_id = str(uuid.uuid4())[:16]
        self.conversation_memory.create_conversation(
            self.current_conversation_id,
            session_id=str(uuid.uuid4())[:16]
        )

        logger.info("Enhanced Assistant initialized with all features")

    def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
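        """Run a single tool on behalf of the workflow engine, consulting the
        tool cache first and caching fresh results; unknown tools return an error dict."""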
        if self.tool_cache:
            cached_result = self.tool_cache.get(tool_name, arguments)
            if cached_result is not None:
                logger.debug(f"Tool cache hit for {tool_name}")
                return cached_result

        def _make_executor(name: str):
            # Route the call through the base assistant's tool-call executor
            # using a synthetic single-call payload.
            return lambda **kw: self.base.execute_tool_calls([{
                'id': 'temp',
                'function': {'name': name, 'arguments': json.dumps(kw)}
            }])[0]

        func_map = {
            name: _make_executor(name)
            for name in ('read_file', 'write_file', 'list_directory', 'run_command')
        }

        if tool_name in func_map:
            result = func_map[tool_name](**arguments)
            if self.tool_cache:
                content = result.get('content', '')
                try:
                    parsed_content = json.loads(content) if isinstance(content, str) else content
                    self.tool_cache.set(tool_name, arguments, parsed_content)
                except Exception:
                    pass
            return result

        return {'error': f'Unknown tool: {tool_name}'}

    def _api_caller_for_agent(self, messages: List[Dict[str, Any]],
                              temperature: float, max_tokens: int) -> Dict[str, Any]:
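        """Model-call hook handed to AgentManager; agents call the API without tool definitions."""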
        return call_api(
            messages,
            self.base.model,
            self.base.api_url,
            self.base.api_key,
            use_tools=False,
            tools=None,
            temperature=temperature,
            max_tokens=max_tokens,
            verbose=self.base.verbose
        )

    def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
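        """Call the model API, returning a cached response when available and
        caching successful responses with their token counts."""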
        if self.api_cache and CACHE_ENABLED:
            cached_response = self.api_cache.get(
                self.base.model, messages,
                0.7, 4096
            )
            if cached_response:
                logger.debug("API cache hit")
                return cached_response

        response = call_api(
            messages,
            self.base.model,
            self.base.api_url,
            self.base.api_key,
            self.base.use_tools,
            get_tools_definition(),
            verbose=self.base.verbose
        )

        if self.api_cache and CACHE_ENABLED and 'error' not in response:
            token_count = response.get('usage', {}).get('total_tokens', 0)
            self.api_cache.set(
                self.base.model, messages,
                0.7, 4096,
                response, token_count
            )

        return response

    def process_with_enhanced_context(self, user_message: str) -> str:
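        """Handle a user message end to end: record it in conversation memory,
        periodically extract facts into the knowledge store, build an enriched
        context, call the API, and summarize long conversations."""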
        self.base.messages.append({"role": "user", "content": user_message})
        self.conversation_memory.add_message(
            self.current_conversation_id,
            str(uuid.uuid4())[:16],
            'user',
            user_message
        )

        if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0:
            facts = self.fact_extractor.extract_facts(user_message)
            for fact in facts[:3]:
                entry_id = str(uuid.uuid4())[:16]
                categories = self.fact_extractor.categorize_content(fact['text'])
                entry = KnowledgeEntry(
                    entry_id=entry_id,
                    category=categories[0] if categories else 'general',
                    content=fact['text'],
                    metadata={'type': fact['type'], 'confidence': fact['confidence']},
                    created_at=time.time(),
                    updated_at=time.time()
                )
                self.knowledge_store.add_entry(entry)

        if self.context_manager and ADVANCED_CONTEXT_ENABLED:
            enhanced_messages, context_info = self.context_manager.create_enhanced_context(
                self.base.messages,
                user_message,
                include_knowledge=True
            )
            if self.base.verbose:
                logger.info(f"Enhanced context: {context_info}")
            working_messages = enhanced_messages
        else:
            working_messages = self.base.messages

        response = self.enhanced_call_api(working_messages)
        result = self.base.process_response(response)

        if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD:
            summary = self.context_manager.advanced_summarize_messages(
                self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
            ) if self.context_manager else "Conversation in progress"
            topics = self.fact_extractor.categorize_content(summary)
            self.conversation_memory.update_conversation_summary(
                self.current_conversation_id,
                summary,
                topics
            )

        return result

    def execute_workflow(self, workflow_name: str,
                         initial_variables: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
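        """Load a stored workflow by name, execute it, and persist the execution record."""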
        workflow = self.workflow_storage.load_workflow_by_name(workflow_name)
        if not workflow:
            return {'error': f'Workflow "{workflow_name}" not found'}

        context = self.workflow_engine.execute_workflow(workflow, initial_variables)
        execution_id = self.workflow_storage.save_execution(workflow.name, context)

        return {
            'success': True,
            'execution_id': execution_id,
            'results': context.step_results,
            'execution_log': context.execution_log
        }

    def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
        return self.agent_manager.create_agent(role_name, agent_id)

    def agent_task(self, agent_id: str, task: str) -> Dict[str, Any]:
        return self.agent_manager.execute_agent_task(agent_id, task)

    def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]:
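        """Create an orchestrator agent and have it coordinate the task across the given roles."""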
        orchestrator_id = self.agent_manager.create_agent('orchestrator')
        return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)

    def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]:
        return self.knowledge_store.search_entries(query, top_k=limit)

    def get_cache_statistics(self) -> Dict[str, Any]:
        stats = {}
        if self.api_cache:
            stats['api_cache'] = self.api_cache.get_statistics()
        if self.tool_cache:
            stats['tool_cache'] = self.tool_cache.get_statistics()
        return stats

    def get_workflow_list(self) -> List[Dict[str, Any]]:
        return self.workflow_storage.list_workflows()

    def get_agent_summary(self) -> Dict[str, Any]:
        return self.agent_manager.get_session_summary()

    def get_knowledge_statistics(self) -> Dict[str, Any]:
        return self.knowledge_store.get_statistics()

    def get_conversation_history(self, limit: int = 10) -> List[Dict[str, Any]]:
        return self.conversation_memory.get_recent_conversations(limit=limit)

    def clear_caches(self):
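        """Remove all cached API responses and tool results."""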
        if self.api_cache:
            self.api_cache.clear_all()
        if self.tool_cache:
            self.tool_cache.clear_all()
        logger.info("All caches cleared")

    def cleanup(self):
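        """Evict expired cache entries and clear the current agent session."""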
        if self.api_cache:
            self.api_cache.clear_expired()
        if self.tool_cache:
            self.tool_cache.clear_expired()
        self.agent_manager.clear_session()