|
import json
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from ..memory.knowledge_store import KnowledgeStore
|
|
from .agent_communication import AgentCommunicationBus, AgentMessage, MessageType
|
|
from .agent_roles import AgentRole, get_agent_role
|
|
|
|
|
|
@dataclass
class AgentInstance:
    """In-memory state for a single agent.

    Holds the agent's identity and role, its conversation log, a free-form
    context mapping, and simple bookkeeping counters. Mutable fields get a
    fresh container per instance through ``default_factory``.
    """

    agent_id: str
    role: AgentRole
    # Ordered conversation log; each entry has "role", "content", "timestamp".
    message_history: List[Dict[str, Any]] = field(default_factory=list)
    # Arbitrary key/value data merged in by callers.
    context: Dict[str, Any] = field(default_factory=dict)
    # Creation time as returned by time.time().
    created_at: float = field(default_factory=time.time)
    # Task counter; AgentManager.execute_agent_task increments it.
    task_count: int = 0

    def add_message(self, role: str, content: str):
        """Append *content* under *role* to the history, stamped with time.time()."""
        entry = {"role": role, "content": content, "timestamp": time.time()}
        self.message_history.append(entry)

    def get_system_message(self) -> Dict[str, str]:
        """Build the system message from this agent's role prompt."""
        prompt = self.role.system_prompt
        return {"role": "system", "content": prompt}

    def get_messages_for_api(self) -> List[Dict[str, str]]:
        """Return the system message followed by the history, without timestamps."""
        api_messages = [self.get_system_message()]
        api_messages.extend(
            {"role": entry["role"], "content": entry["content"]}
            for entry in self.message_history
        )
        return api_messages
|
|
|
|
|
|
class AgentManager:
    """Create, run and coordinate role-based agents within one session.

    Agents are kept in memory in ``active_agents``. Inter-agent messages go
    through an ``AgentCommunicationBus`` and task-relevant knowledge is pulled
    from a ``KnowledgeStore``; both are constructed from the same ``db_path``.
    """

    # Maximum number of knowledge-store entries injected into a task prompt.
    KNOWLEDGE_TOP_K = 3
    # Each injected knowledge entry is truncated to this many characters.
    KNOWLEDGE_SNIPPET_CHARS = 2000

    def __init__(self, db_path: str, api_caller: Callable):
        """Initialise the manager and its storage backends.

        Args:
            db_path: Path passed to both the communication bus and the
                knowledge store.
            api_caller: Called as ``api_caller(messages=..., temperature=...,
                max_tokens=...)``. Its return value is read as a dict with a
                ``"choices"`` list whose first item has
                ``["message"]["content"]``.
        """
        self.db_path = db_path
        self.api_caller = api_caller
        self.communication_bus = AgentCommunicationBus(db_path)
        self.knowledge_store = KnowledgeStore(db_path)
        self.active_agents: Dict[str, AgentInstance] = {}
        self.session_id = str(uuid.uuid4())[:16]

    def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
        """Create and register an agent for *role_name*; return its id.

        When *agent_id* is omitted, an id of the form ``<role>_<8 hex-ish
        chars>`` is generated. An existing agent with the same id is replaced.
        """
        if agent_id is None:
            agent_id = f"{role_name}_{str(uuid.uuid4())[:8]}"

        role = get_agent_role(role_name)
        self.active_agents[agent_id] = AgentInstance(agent_id=agent_id, role=role)
        return agent_id

    def get_agent(self, agent_id: str) -> Optional[AgentInstance]:
        """Return the registered agent for *agent_id*, or None."""
        return self.active_agents.get(agent_id)

    def remove_agent(self, agent_id: str) -> bool:
        """Unregister *agent_id*; return True if it was present."""
        if agent_id in self.active_agents:
            del self.active_agents[agent_id]
            return True
        return False

    def _build_knowledge_message(self, query: str) -> Optional[Dict[str, str]]:
        """Return a user message listing knowledge matches for *query*, or None."""
        matches = self.knowledge_store.search_entries(
            query, top_k=self.KNOWLEDGE_TOP_K
        )
        if not matches:
            return None
        # Real newlines: the previous "\\n" sent literal backslash-n to the model.
        snippets = "".join(
            f"{i}. {entry.content[: self.KNOWLEDGE_SNIPPET_CHARS]}\n\n"
            for i, entry in enumerate(matches, 1)
        )
        return {
            "role": "user",
            "content": "Knowledge base matches based on your query:\n" + snippets,
        }

    def execute_agent_task(
        self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Send *task* to the agent's model and record the reply.

        The task is appended to the agent's history and ``task_count`` is
        incremented before the call. Up to ``KNOWLEDGE_TOP_K`` knowledge
        entries are inserted just before the task message for this call only
        (they are not stored in the history).

        Returns:
            On success: ``{"success": True, "agent_id", "response", "role",
            "task_count"}``. On failure: a dict with an ``"error"`` key (and
            ``"agent_id"`` unless the agent was not found). Failures are
            reported this way rather than raised.
        """
        agent = self.get_agent(agent_id)
        if agent is None:
            return {"error": f"Agent {agent_id} not found"}

        if context:
            # NOTE(review): merged into agent.context but not sent to the model
            # by this method — confirm whether it should be.
            agent.context.update(context)

        agent.add_message("user", task)
        agent.task_count += 1

        # Boundary of this API: every failure, including a knowledge-store
        # failure, is converted into an error dict.
        try:
            messages = agent.get_messages_for_api()
            knowledge_message = self._build_knowledge_message(task)
            if knowledge_message is not None:
                # Place the knowledge just before the current task (last item).
                messages.insert(-1, knowledge_message)

            response = self.api_caller(
                messages=messages,
                temperature=agent.role.temperature,
                max_tokens=agent.role.max_tokens,
            )

            if not response or "choices" not in response:
                return {"error": "Invalid API response", "agent_id": agent_id}

            assistant_message = response["choices"][0]["message"]["content"]
            agent.add_message("assistant", assistant_message)
            return {
                "success": True,
                "agent_id": agent_id,
                "response": assistant_message,
                "role": agent.role.name,
                "task_count": agent.task_count,
            }
        except Exception as e:
            return {"error": str(e), "agent_id": agent_id}

    def send_agent_message(
        self,
        from_agent_id: str,
        to_agent_id: str,
        content: str,
        message_type: MessageType = MessageType.REQUEST,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Post a message on the bus under the current session; return its id."""
        message = AgentMessage(
            from_agent=from_agent_id,
            to_agent=to_agent_id,
            message_type=message_type,
            content=content,
            metadata=metadata or {},
            timestamp=time.time(),
            message_id=str(uuid.uuid4())[:16],
        )
        self.communication_bus.send_message(message, self.session_id)
        return message.message_id

    def get_agent_messages(
        self, agent_id: str, unread_only: bool = True
    ) -> List[AgentMessage]:
        """Fetch messages for *agent_id* from the bus (delegates to the bus)."""
        return self.communication_bus.get_messages(agent_id, unread_only)

    def collaborate_agents(
        self, orchestrator_id: str, task: str, agent_roles: List[str]
    ):
        """Run an orchestrator over *task*, then let workers answer their messages.

        A fresh orchestrator is created if *orchestrator_id* is unknown, and
        one new worker agent is created per entry in *agent_roles*. After the
        orchestrator runs, each worker executes every unread bus message
        addressed to it, replies to the orchestrator with the result (or the
        error text on failure), and marks the message read.

        NOTE(review): this method does not itself parse the orchestrator's
        output into subtasks; workers only act on messages already on the bus
        (presumably sent via api_caller/tools) — confirm.

        Returns:
            ``{"orchestrator": <execute_agent_task result>, "agents": [...]}``.
        """
        if self.get_agent(orchestrator_id) is None:
            orchestrator_id = self.create_agent("orchestrator")

        worker_agents = [
            {"agent_id": self.create_agent(role), "role": role} for role in agent_roles
        ]
        agent_list = "\n".join(
            f"- {a['agent_id']} ({a['role']})" for a in worker_agents
        )
        orchestration_prompt = (
            f"Task: {task}\n\n"
            "Available specialized agents:\n"
            f"{agent_list}\n\n"
            "Break down the task and delegate subtasks to appropriate agents. "
            "Coordinate their work and integrate results."
        )

        orchestrator_result = self.execute_agent_task(
            orchestrator_id, orchestration_prompt
        )
        results = {"orchestrator": orchestrator_result, "agents": []}

        for agent_info in worker_agents:
            agent_id = agent_info["agent_id"]
            for msg in self.get_agent_messages(agent_id):
                result = self.execute_agent_task(agent_id, msg.content)
                results["agents"].append(result)

                reply = result.get("response")
                if reply is None:
                    # Forward the failure instead of an indistinguishable "".
                    reply = f"Error: {result.get('error', 'unknown error')}"

                self.send_agent_message(
                    from_agent_id=agent_id,
                    to_agent_id=orchestrator_id,
                    content=reply,
                    message_type=MessageType.RESPONSE,
                )
                self.communication_bus.mark_as_read(msg.message_id)

        return results

    def get_session_summary(self) -> str:
        """Return a JSON string with the session id and per-agent counters."""
        summary = {
            "session_id": self.session_id,
            "active_agents": len(self.active_agents),
            "agents": [
                {
                    "agent_id": agent_id,
                    "role": agent.role.name,
                    "task_count": agent.task_count,
                    "message_count": len(agent.message_history),
                }
                for agent_id, agent in self.active_agents.items()
            ],
        }
        return json.dumps(summary)

    def clear_session(self):
        """Drop all agents and this session's bus messages; start a new session id."""
        self.active_agents.clear()
        self.communication_bus.clear_messages(session_id=self.session_id)
        self.session_id = str(uuid.uuid4())[:16]
|