349 lines
12 KiB
Python
Raw Normal View History

import asyncio
2025-11-04 05:17:27 +01:00
import json
2025-11-04 08:09:12 +01:00
import logging
2025-11-04 05:17:27 +01:00
import uuid
2025-11-04 08:09:12 +01:00
from typing import Any, Dict, List, Optional
from pr.agents import AgentManager
from pr.cache import APICache, ToolCache
2025-11-04 05:17:27 +01:00
from pr.config import (
2025-11-04 08:09:12 +01:00
ADVANCED_CONTEXT_ENABLED,
API_CACHE_TTL,
CACHE_ENABLED,
CONVERSATION_SUMMARY_THRESHOLD,
DB_PATH,
KNOWLEDGE_SEARCH_LIMIT,
TOOL_CACHE_TTL,
WORKFLOW_EXECUTOR_MAX_WORKERS,
2025-11-04 05:17:27 +01:00
)
from pr.core.advanced_context import AdvancedContextManager
from pr.core.api import call_api
2025-11-04 08:09:12 +01:00
from pr.memory import ConversationMemory, FactExtractor, KnowledgeStore
2025-11-04 05:17:27 +01:00
from pr.tools.base import get_tools_definition
2025-11-04 08:09:12 +01:00
from pr.workflows import WorkflowEngine, WorkflowStorage
logger = logging.getLogger("pr")
2025-11-04 05:17:27 +01:00
class EnhancedAssistant:
def __init__(self, base_assistant):
    """Wrap *base_assistant* with caching, workflows, agents and memory.

    All persistent state (caches, workflow storage, knowledge store,
    conversation memory) lives in the shared database at DB_PATH.
    """
    self.base = base_assistant

    # Caches are optional; both are either present or absent together.
    self.api_cache = APICache(DB_PATH, API_CACHE_TTL) if CACHE_ENABLED else None
    self.tool_cache = ToolCache(DB_PATH, TOOL_CACHE_TTL) if CACHE_ENABLED else None

    # Workflow subsystem: storage plus an engine that executes tools
    # through this wrapper (so workflow tool calls hit the tool cache).
    self.workflow_storage = WorkflowStorage(DB_PATH)
    self.workflow_engine = WorkflowEngine(
        tool_executor=self._execute_tool_for_workflow,
        max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS,
    )

    # Agents call back into the API through our synchronous bridge.
    self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent)

    # Long-term memory: knowledge entries, conversations, fact mining.
    self.knowledge_store = KnowledgeStore(DB_PATH)
    self.conversation_memory = ConversationMemory(DB_PATH)
    self.fact_extractor = FactExtractor()

    if ADVANCED_CONTEXT_ENABLED:
        self.context_manager = AdvancedContextManager(
            knowledge_store=self.knowledge_store,
            conversation_memory=self.conversation_memory,
        )
    else:
        self.context_manager = None

    # Start a fresh conversation record for this session.
    self.current_conversation_id = str(uuid.uuid4())[:16]
    self.conversation_memory.create_conversation(
        self.current_conversation_id, session_id=str(uuid.uuid4())[:16]
    )
    logger.info("Enhanced Assistant initialized with all features")
def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
2025-11-04 05:17:27 +01:00
if self.tool_cache:
cached_result = self.tool_cache.get(tool_name, arguments)
if cached_result is not None:
logger.debug(f"Tool cache hit for {tool_name}")
return cached_result
func_map = {
2025-11-04 08:09:12 +01:00
"read_file": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {"name": "read_file", "arguments": json.dumps(kw)},
}
]
)[0],
"write_file": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {"name": "write_file", "arguments": json.dumps(kw)},
}
]
)[0],
"list_directory": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {
"name": "list_directory",
"arguments": json.dumps(kw),
},
}
]
)[0],
"run_command": lambda **kw: self.base.execute_tool_calls(
[
{
"id": "temp",
"function": {
"name": "run_command",
"arguments": json.dumps(kw),
},
}
]
)[0],
2025-11-04 05:17:27 +01:00
}
if tool_name in func_map:
result = func_map[tool_name](**arguments)
if self.tool_cache:
2025-11-04 08:09:12 +01:00
content = result.get("content", "")
2025-11-04 05:17:27 +01:00
try:
2025-11-04 08:10:37 +01:00
parsed_content = json.loads(content) if isinstance(content, str) else content
2025-11-04 05:17:27 +01:00
self.tool_cache.set(tool_name, arguments, parsed_content)
except Exception:
pass
return result
2025-11-04 08:09:12 +01:00
return {"error": f"Unknown tool: {tool_name}"}
2025-11-04 05:17:27 +01:00
2025-11-04 08:09:12 +01:00
def _api_caller_for_agent(
    self, messages: List[Dict[str, Any]], temperature: float, max_tokens: int
) -> Dict[str, Any]:
    """Synchronous API bridge handed to the AgentManager.

    Agent calls are made without tools.  *temperature* and *max_tokens*
    are part of the callback signature the AgentManager expects but are
    not forwarded, because ``call_api`` does not accept them.

    Returns:
        The response dict from ``call_api``.
    """
    coro = call_api(
        messages,
        self.base.model,
        self.base.api_url,
        self.base.api_key,
        use_tools=False,
        tools_definition=[],
        verbose=self.base.verbose,
    )
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread: drive the coroutine directly.
        return asyncio.run(coro)
    # A loop is already running in this thread.  The previous approach
    # (run_coroutine_threadsafe onto the *current* loop, then blocking on
    # future.result() from the loop's own thread) deadlocks, because the
    # blocked thread is the one that must run the coroutine.  Instead,
    # run the coroutine on a private loop in a worker thread and block
    # on that.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Call the LLM API with response caching.

    Checks the API cache first; on a miss, performs the call with the
    base assistant's tool configuration and stores successful responses
    back into the cache, keyed on (model, messages, 0.7, 4096).
    NOTE(review): the 0.7/4096 cache-key components are hard-coded here —
    confirm they match the sampling parameters call_api actually uses.

    Returns:
        The response dict (possibly served from cache).
    """
    if self.api_cache and CACHE_ENABLED:
        cached_response = self.api_cache.get(self.base.model, messages, 0.7, 4096)
        if cached_response:
            logger.debug("API cache hit")
            return cached_response

    coro = call_api(
        messages,
        self.base.model,
        self.base.api_url,
        self.base.api_key,
        self.base.use_tools,
        get_tools_definition(),
        verbose=self.base.verbose,
    )
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread: drive the coroutine directly.
        response = asyncio.run(coro)
    else:
        # A loop is already running here.  Scheduling onto that loop and
        # then blocking on future.result() from the loop's own thread
        # (the previous implementation) deadlocks — the blocked thread is
        # the one that must run the coroutine.  Run the coroutine on a
        # private loop in a worker thread instead.
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            response = pool.submit(asyncio.run, coro).result()

    # Only cache successful responses.
    if self.api_cache and CACHE_ENABLED and "error" not in response:
        token_count = response.get("usage", {}).get("total_tokens", 0)
        self.api_cache.set(self.base.model, messages, 0.7, 4096, response, token_count)
    return response
def process_with_enhanced_context(self, user_message: str) -> str:
    """Handle one user turn end-to-end.

    Records the message, mines facts into the knowledge store, builds an
    enhanced context (when enabled), calls the API, and periodically
    refreshes the conversation summary.

    Returns:
        The processed assistant response from ``base.process_response``.
    """
    self.base.messages.append({"role": "user", "content": user_message})
    self.conversation_memory.add_message(
        self.current_conversation_id, str(uuid.uuid4())[:16], "user", user_message
    )

    # Automatically extract and store facts from every user message.
    facts = self.fact_extractor.extract_facts(user_message)
    if facts:
        # Hoisted out of the loop: the original re-executed these import
        # statements on every iteration.
        import time
        from pr.memory import KnowledgeEntry

        for fact in facts[:5]:  # store at most 5 facts per message
            categories = self.fact_extractor.categorize_content(fact["text"])
            now = time.time()  # single timestamp so created_at == updated_at
            entry = KnowledgeEntry(
                entry_id=str(uuid.uuid4())[:16],
                category=categories[0] if categories else "general",
                content=fact["text"],
                metadata={
                    "type": fact["type"],
                    "confidence": fact["confidence"],
                    "source": "user_message",
                },
                created_at=now,
                updated_at=now,
            )
            self.knowledge_store.add_entry(entry)

    if self.context_manager and ADVANCED_CONTEXT_ENABLED:
        working_messages, context_info = self.context_manager.create_enhanced_context(
            self.base.messages, user_message, include_knowledge=True
        )
        if self.base.verbose:
            logger.info(f"Enhanced context: {context_info}")
    else:
        working_messages = self.base.messages

    response = self.enhanced_call_api(working_messages)
    result = self.base.process_response(response)

    # Periodically roll the recent window up into a stored summary.
    if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD:
        summary = (
            self.context_manager.advanced_summarize_messages(
                self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
            )
            if self.context_manager
            else "Conversation in progress"
        )
        topics = self.fact_extractor.categorize_content(summary)
        self.conversation_memory.update_conversation_summary(
            self.current_conversation_id, summary, topics
        )
    return result
def execute_workflow(
    self, workflow_name: str, initial_variables: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Load a workflow by name, run it, and persist the execution record.

    Args:
        workflow_name: name of a stored workflow.
        initial_variables: optional starting variables for the engine.

    Returns:
        ``{"error": ...}`` when the workflow is unknown, otherwise a dict
        with ``success``, ``execution_id``, ``results`` and
        ``execution_log``.
    """
    workflow = self.workflow_storage.load_workflow_by_name(workflow_name)
    if not workflow:
        return {"error": f'Workflow "{workflow_name}" not found'}
    context = self.workflow_engine.execute_workflow(workflow, initial_variables)
    # Reuse the already-loaded workflow; the original re-queried storage
    # just to read .name (extra I/O, and an AttributeError if the
    # workflow vanished between the two loads).
    execution_id = self.workflow_storage.save_execution(workflow.name, context)
    return {
        "success": True,
        "execution_id": execution_id,
        "results": context.step_results,
        "execution_log": context.execution_log,
    }
def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
    """Create an agent with the given role; returns its identifier."""
    manager = self.agent_manager
    return manager.create_agent(role_name, agent_id)
def agent_task(self, agent_id: str, task: str) -> Dict[str, Any]:
    """Run *task* on the agent identified by *agent_id*."""
    manager = self.agent_manager
    return manager.execute_agent_task(agent_id, task)
def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]:
    """Create an orchestrator agent and have it coordinate *agent_roles* on *task*."""
    orchestrator = self.agent_manager.create_agent("orchestrator")
    return self.agent_manager.collaborate_agents(orchestrator, task, agent_roles)
def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]:
    """Search the knowledge store, returning up to *limit* matching entries."""
    store = self.knowledge_store
    return store.search_entries(query, top_k=limit)
def get_cache_statistics(self) -> Dict[str, Any]:
    """Return statistics for whichever caches are enabled (empty dict if none)."""
    caches = {"api_cache": self.api_cache, "tool_cache": self.tool_cache}
    return {name: cache.get_statistics() for name, cache in caches.items() if cache}
def get_workflow_list(self) -> List[Dict[str, Any]]:
    """List all workflows known to persistent storage."""
    storage = self.workflow_storage
    return storage.list_workflows()
def get_agent_summary(self) -> Dict[str, Any]:
    """Summarize the current agent session."""
    manager = self.agent_manager
    return manager.get_session_summary()
def get_knowledge_statistics(self) -> Dict[str, Any]:
    """Return statistics from the knowledge store."""
    store = self.knowledge_store
    return store.get_statistics()
def get_conversation_history(self, limit: int = 10) -> List[Dict[str, Any]]:
    """Return metadata for up to *limit* recent conversations."""
    memory = self.conversation_memory
    return memory.get_recent_conversations(limit=limit)
def clear_caches(self):
    """Empty both caches entirely (skips caches that are disabled)."""
    for cache in (self.api_cache, self.tool_cache):
        if cache:
            cache.clear_all()
    logger.info("All caches cleared")
def cleanup(self):
    """Drop expired cache entries and end the current agent session."""
    for cache in (self.api_cache, self.tool_cache):
        if cache:
            cache.clear_expired()
    self.agent_manager.clear_session()