|
import json
|
|
import logging
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pr.agents import AgentManager
|
|
from pr.cache import APICache, ToolCache
|
|
from pr.config import (
|
|
ADVANCED_CONTEXT_ENABLED,
|
|
API_CACHE_TTL,
|
|
CACHE_ENABLED,
|
|
CONVERSATION_SUMMARY_THRESHOLD,
|
|
DB_PATH,
|
|
KNOWLEDGE_SEARCH_LIMIT,
|
|
MEMORY_AUTO_SUMMARIZE,
|
|
TOOL_CACHE_TTL,
|
|
WORKFLOW_EXECUTOR_MAX_WORKERS,
|
|
)
|
|
from pr.core.advanced_context import AdvancedContextManager
|
|
from pr.core.api import call_api
|
|
from pr.memory import ConversationMemory, FactExtractor, KnowledgeStore
|
|
from pr.tools.base import get_tools_definition
|
|
from pr.workflows import WorkflowEngine, WorkflowStorage
|
|
|
|
# Shared logger for this module. It is named "pr" (not after this module), so
# records emitted here go through whatever handlers are configured for "pr".
logger = logging.getLogger(name="pr")
|
|
|
|
|
|
class EnhancedAssistant:
    """Extend a base assistant with caching, workflows, agents and memory.

    The wrapped ``base_assistant`` supplies the model/API settings
    (``model``, ``api_url``, ``api_key``, ``use_tools``, ``verbose``), the
    running ``messages`` list, and tool execution (``execute_tool_calls``,
    ``process_response``). The API/tool caches and the advanced context
    manager are only created when ``CACHE_ENABLED`` /
    ``ADVANCED_CONTEXT_ENABLED`` are set.
    """

    # Temperature / max-tokens values that form part of the API-cache key.
    # NOTE(review): they are not passed to ``call_api`` in
    # ``enhanced_call_api``; presumably they match call_api's defaults --
    # confirm, otherwise the cache key does not describe the real request.
    _API_CACHE_TEMPERATURE = 0.7
    _API_CACHE_MAX_TOKENS = 4096

    # Tools that workflows may invoke through ``_execute_tool_for_workflow``.
    _WORKFLOW_TOOLS = frozenset(
        {"read_file", "write_file", "list_directory", "run_command"}
    )

    # Tools whose results may be served from the tool cache. ``write_file``
    # and ``run_command`` are excluded on purpose: a cache hit would return
    # the old result *without* performing the write / running the command.
    # NOTE(review): cached read_file / list_directory results can still be
    # stale if a workflow writes to the same path within the cache TTL.
    _CACHEABLE_WORKFLOW_TOOLS = frozenset({"read_file", "list_directory"})

    def __init__(self, base_assistant):
        """Create all subsystems and open a new conversation record.

        Args:
            base_assistant: The assistant whose settings, message history
                and tool execution this wrapper builds on.
        """
        self.base = base_assistant

        if CACHE_ENABLED:
            self.api_cache = APICache(DB_PATH, API_CACHE_TTL)
            self.tool_cache = ToolCache(DB_PATH, TOOL_CACHE_TTL)
        else:
            self.api_cache = None
            self.tool_cache = None

        self.workflow_storage = WorkflowStorage(DB_PATH)
        self.workflow_engine = WorkflowEngine(
            tool_executor=self._execute_tool_for_workflow,
            max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS,
        )

        self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent)

        self.knowledge_store = KnowledgeStore(DB_PATH)
        self.conversation_memory = ConversationMemory(DB_PATH)
        self.fact_extractor = FactExtractor()

        if ADVANCED_CONTEXT_ENABLED:
            self.context_manager = AdvancedContextManager(
                knowledge_store=self.knowledge_store,
                conversation_memory=self.conversation_memory,
            )
        else:
            self.context_manager = None

        self.current_conversation_id = self._short_id()
        self.conversation_memory.create_conversation(
            self.current_conversation_id, session_id=self._short_id()
        )

        logger.info("Enhanced Assistant initialized with all features")

    @staticmethod
    def _short_id() -> str:
        """Return a 16-character id taken from the start of a random UUID4."""
        return str(uuid.uuid4())[:16]

    def _execute_tool_for_workflow(
        self, tool_name: str, arguments: Dict[str, Any]
    ) -> Any:
        """Execute a single tool call on behalf of the workflow engine.

        Args:
            tool_name: Tool to run; only names in ``_WORKFLOW_TOOLS`` are
                supported.
            arguments: Tool arguments, JSON-encoded into the tool call.

        Returns:
            The first entry returned by ``base.execute_tool_calls`` for the
            call (a cache hit returns that same object as stored earlier),
            or ``{"error": "Unknown tool: <name>"}`` for an unsupported tool.
        """
        if tool_name not in self._WORKFLOW_TOOLS:
            return {"error": f"Unknown tool: {tool_name}"}

        use_cache = (
            bool(self.tool_cache) and tool_name in self._CACHEABLE_WORKFLOW_TOOLS
        )

        if use_cache:
            cached_result = self.tool_cache.get(tool_name, arguments)
            if cached_result is not None:
                logger.debug(f"Tool cache hit for {tool_name}")
                return cached_result

        tool_call = {
            "id": "temp",
            "function": {"name": tool_name, "arguments": json.dumps(arguments)},
        }
        result = self.base.execute_tool_calls([tool_call])[0]

        if use_cache:
            # Cache the full result (not just its parsed "content") so that a
            # cache hit returns exactly what a fresh execution returns.
            try:
                self.tool_cache.set(tool_name, arguments, result)
            except Exception:
                # Caching is best-effort and must never break the workflow,
                # but leave a trace instead of silently ignoring the failure.
                logger.debug(
                    f"Could not cache result of {tool_name}", exc_info=True
                )

        return result

    def _api_caller_for_agent(
        self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int
    ) -> Dict[str, Any]:
        """Call the model for an agent, without tools.

        Uses the base assistant's model, URL, key and verbosity, with the
        caller-supplied sampling parameters.
        """
        return call_api(
            messages,
            self.base.model,
            self.base.api_url,
            self.base.api_key,
            use_tools=False,
            tools=None,
            temperature=temperature,
            max_tokens=max_tokens,
            verbose=self.base.verbose,
        )

    def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Call the model, consulting and filling the API cache when enabled.

        A truthy cached response is returned directly. Otherwise the API is
        called with the base assistant's tool settings, and responses
        without an ``"error"`` key are stored together with their
        ``usage.total_tokens`` count.
        """
        if self.api_cache and CACHE_ENABLED:
            cached_response = self.api_cache.get(
                self.base.model,
                messages,
                self._API_CACHE_TEMPERATURE,
                self._API_CACHE_MAX_TOKENS,
            )
            if cached_response:
                logger.debug("API cache hit")
                return cached_response

        response = call_api(
            messages,
            self.base.model,
            self.base.api_url,
            self.base.api_key,
            self.base.use_tools,
            get_tools_definition(),
            verbose=self.base.verbose,
        )

        if self.api_cache and CACHE_ENABLED and "error" not in response:
            token_count = response.get("usage", {}).get("total_tokens", 0)
            self.api_cache.set(
                self.base.model,
                messages,
                self._API_CACHE_TEMPERATURE,
                self._API_CACHE_MAX_TOKENS,
                response,
                token_count,
            )

        return response

    def process_with_enhanced_context(self, user_message: str) -> str:
        """Handle one user turn with memory, knowledge and context features.

        Steps:

        1. Append the message to ``base.messages`` and record it in
           conversation memory.
        2. Possibly extract facts into the knowledge store.
        3. Send the (optionally context-enhanced) history through
           ``enhanced_call_api``.
        4. Return the result of ``base.process_response``.

        Once the history reaches ``CONVERSATION_SUMMARY_THRESHOLD`` messages,
        the stored conversation summary is refreshed after each turn.
        """
        self.base.messages.append({"role": "user", "content": user_message})
        self.conversation_memory.add_message(
            self.current_conversation_id, self._short_id(), "user", user_message
        )

        # Fact extraction only runs when the history length is a multiple of 5.
        if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0:
            self._store_extracted_facts(user_message)

        response = self.enhanced_call_api(self._build_working_messages(user_message))
        result = self.base.process_response(response)

        if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD:
            self._update_conversation_summary()

        return result

    def _store_extracted_facts(self, user_message: str, max_facts: int = 3) -> None:
        """Save up to ``max_facts`` facts from ``user_message`` as knowledge entries.

        Each entry's category is the first category reported by the fact
        extractor, or ``"general"`` if none is reported.
        """
        facts = self.fact_extractor.extract_facts(user_message)[:max_facts]
        if not facts:
            return

        # Local imports (kept local as in the original), executed once rather
        # than on every loop iteration.
        import time

        from pr.memory import KnowledgeEntry

        for fact in facts:
            categories = self.fact_extractor.categorize_content(fact["text"])
            # One timestamp so created_at and updated_at agree for a new entry.
            now = time.time()
            self.knowledge_store.add_entry(
                KnowledgeEntry(
                    entry_id=self._short_id(),
                    category=categories[0] if categories else "general",
                    content=fact["text"],
                    metadata={"type": fact["type"], "confidence": fact["confidence"]},
                    created_at=now,
                    updated_at=now,
                )
            )

    def _build_working_messages(self, user_message: str) -> List[Dict[str, Any]]:
        """Return the messages to send: context-enhanced if enabled, else the raw history."""
        if not (self.context_manager and ADVANCED_CONTEXT_ENABLED):
            return self.base.messages

        enhanced_messages, context_info = self.context_manager.create_enhanced_context(
            self.base.messages, user_message, include_knowledge=True
        )
        if self.base.verbose:
            logger.info(f"Enhanced context: {context_info}")
        return enhanced_messages

    def _update_conversation_summary(self) -> None:
        """Summarize the most recent messages and store summary plus topics.

        Without a context manager the summary falls back to the fixed text
        ``"Conversation in progress"``.
        """
        if self.context_manager:
            summary = self.context_manager.advanced_summarize_messages(
                self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
            )
        else:
            summary = "Conversation in progress"

        topics = self.fact_extractor.categorize_content(summary)
        self.conversation_memory.update_conversation_summary(
            self.current_conversation_id, summary, topics
        )

    def execute_workflow(
        self, workflow_name: str, initial_variables: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Run a stored workflow by name and persist the execution.

        Args:
            workflow_name: Name the workflow was stored under.
            initial_variables: Optional variables passed to the engine.

        Returns:
            ``{"success": True, "execution_id": ..., "results": ...,
            "execution_log": ...}``, or ``{"error": ...}`` when no workflow
            with that name exists.
        """
        workflow = self.workflow_storage.load_workflow_by_name(workflow_name)
        if not workflow:
            return {"error": f'Workflow "{workflow_name}" not found'}

        context = self.workflow_engine.execute_workflow(workflow, initial_variables)

        # Reuse the workflow loaded above: a second storage lookup was a
        # redundant round-trip and could return None if the workflow was
        # removed while it was running.
        execution_id = self.workflow_storage.save_execution(workflow.name, context)

        return {
            "success": True,
            "execution_id": execution_id,
            "results": context.step_results,
            "execution_log": context.execution_log,
        }

    def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
        """Create an agent with the given role and return its id."""
        return self.agent_manager.create_agent(role_name, agent_id)

    def agent_task(self, agent_id: str, task: str) -> Dict[str, Any]:
        """Have the agent ``agent_id`` execute ``task``."""
        return self.agent_manager.execute_agent_task(agent_id, task)

    def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]:
        """Create an orchestrator agent and run ``task`` with agents of ``agent_roles``."""
        orchestrator_id = self.agent_manager.create_agent("orchestrator")
        return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)

    def search_knowledge(
        self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT
    ) -> List[Any]:
        """Search the knowledge store, returning at most ``limit`` entries."""
        return self.knowledge_store.search_entries(query, top_k=limit)

    def get_cache_statistics(self) -> Dict[str, Any]:
        """Return statistics for each enabled cache, keyed "api_cache"/"tool_cache"."""
        stats = {}
        if self.api_cache:
            stats["api_cache"] = self.api_cache.get_statistics()
        if self.tool_cache:
            stats["tool_cache"] = self.tool_cache.get_statistics()
        return stats

    def get_workflow_list(self) -> List[Dict[str, Any]]:
        """Return the stored workflows as listed by workflow storage."""
        return self.workflow_storage.list_workflows()

    def get_agent_summary(self) -> Dict[str, Any]:
        """Return the agent manager's summary of the current session."""
        return self.agent_manager.get_session_summary()

    def get_knowledge_statistics(self) -> Dict[str, Any]:
        """Return statistics reported by the knowledge store."""
        return self.knowledge_store.get_statistics()

    def get_conversation_history(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return up to ``limit`` recent conversations from conversation memory."""
        return self.conversation_memory.get_recent_conversations(limit=limit)

    def clear_caches(self):
        """Remove every entry from the enabled API and tool caches."""
        if self.api_cache:
            self.api_cache.clear_all()
        if self.tool_cache:
            self.tool_cache.clear_all()
        logger.info("All caches cleared")

    def cleanup(self):
        """Drop expired cache entries and clear the agent session."""
        if self.api_cache:
            self.api_cache.clear_expired()
        if self.tool_cache:
            self.tool_cache.clear_expired()
        self.agent_manager.clear_session()