|
import sqlite3
|
|
import json
|
|
from typing import List, Optional
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
|
|
class MessageType(Enum):
    """Category of an agent message.

    Each member carries a lowercase string value; that value, not the member
    itself, is what serialized forms of a message store.
    """

    # A message that expects an answer.
    REQUEST = "request"
    # A message that answers an earlier request.
    RESPONSE = "response"
    # A one-way, informational message.
    NOTIFICATION = "notification"
|
|
|
|
@dataclass
class AgentMessage:
    """A single message sent from one agent to another.

    Instances round-trip through plain dicts via ``to_dict`` / ``from_dict``;
    in the dict form ``message_type`` is the enum's string value.
    """

    message_id: str  # identifier of this message
    from_agent: str  # id of the sending agent
    to_agent: str  # id of the receiving agent
    message_type: MessageType  # category of the message
    content: str  # message body
    metadata: dict  # free-form extra data attached by the sender
    timestamp: float  # send time; the unit/epoch is chosen by the caller

    def to_dict(self) -> dict:
        """Return this message as a plain dict.

        ``message_type`` is replaced by its string value. ``metadata`` is the
        very same dict object held by this instance, not a copy.
        """
        serialized = dict(
            message_id=self.message_id,
            from_agent=self.from_agent,
            to_agent=self.to_agent,
            message_type=self.message_type.value,
            content=self.content,
            metadata=self.metadata,
            timestamp=self.timestamp,
        )
        return serialized

    @classmethod
    def from_dict(cls, data: dict) -> 'AgentMessage':
        """Build a message from a dict shaped like the output of ``to_dict``.

        Raises:
            KeyError: if any expected key is missing from ``data``.
            ValueError: if ``data['message_type']`` is not a valid
                ``MessageType`` value.
        """
        # Keys are read in field order so the first failure matches the
        # order in which fields are declared.
        message_id = data['message_id']
        sender = data['from_agent']
        recipient = data['to_agent']
        kind = MessageType(data['message_type'])
        body = data['content']
        extra = data['metadata']
        sent_at = data['timestamp']
        return cls(
            message_id=message_id,
            from_agent=sender,
            to_agent=recipient,
            message_type=kind,
            content=body,
            metadata=extra,
            timestamp=sent_at,
        )
|
|
|
|
class AgentCommunicationBus:
    """SQLite-backed mailbox through which agents exchange messages.

    Every message is stored as one row of the ``agent_messages`` table and
    may be tagged with a session id so that one session's messages can be
    cleared together. The bus holds a single connection for its lifetime;
    call :meth:`close`, or use the bus as a context manager, to release it.
    """

    # Columns read back when rebuilding AgentMessage objects. The order must
    # match the tuple unpacking in _row_to_message.
    _SELECT_COLUMNS = (
        "message_id, from_agent, to_agent, message_type, "
        "content, metadata, timestamp"
    )

    def __init__(self, db_path: str):
        """Open (or create) the database at ``db_path`` and ensure the schema."""
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self._create_tables()

    def __enter__(self) -> "AgentCommunicationBus":
        """Support ``with AgentCommunicationBus(path) as bus:``."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the connection when leaving a ``with`` block."""
        self.close()

    def _create_tables(self):
        """Create the ``agent_messages`` table if it does not exist yet."""
        # ``with conn`` commits on success and rolls back on error.
        with self.conn:
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS agent_messages (
                    message_id TEXT PRIMARY KEY,
                    from_agent TEXT,
                    to_agent TEXT,
                    message_type TEXT,
                    content TEXT,
                    metadata TEXT,
                    timestamp REAL,
                    session_id TEXT,
                    read INTEGER DEFAULT 0
                )
            ''')

    @staticmethod
    def _row_to_message(row) -> AgentMessage:
        """Convert one row selected with _SELECT_COLUMNS into an AgentMessage."""
        (message_id, from_agent, to_agent, message_type,
         content, metadata, timestamp) = row
        return AgentMessage(
            message_id=message_id,
            from_agent=from_agent,
            to_agent=to_agent,
            message_type=MessageType(message_type),
            content=content,
            # A NULL or empty metadata column becomes an empty dict.
            metadata=json.loads(metadata) if metadata else {},
            timestamp=timestamp,
        )

    def send_message(self, message: AgentMessage, session_id: Optional[str] = None):
        """Persist ``message`` (stored unread), optionally tagged with ``session_id``.

        Raises:
            sqlite3.IntegrityError: if a message with the same ``message_id``
                is already stored (``message_id`` is the primary key). The
                failed insert is rolled back.
        """
        with self.conn:
            self.conn.execute('''
                INSERT INTO agent_messages
                (message_id, from_agent, to_agent, message_type, content, metadata, timestamp, session_id)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                message.message_id,
                message.from_agent,
                message.to_agent,
                message.message_type.value,
                message.content,
                json.dumps(message.metadata),
                message.timestamp,
                session_id,
            ))

    def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
        """Return messages addressed to ``agent_id``, oldest first.

        Args:
            agent_id: recipient to filter on.
            unread_only: when True (default), only messages not yet marked
                read are returned. Reading does not mark them as read.
        """
        # Only constant SQL fragments are concatenated here; the agent id is
        # always bound as a parameter.
        query = f"SELECT {self._SELECT_COLUMNS} FROM agent_messages WHERE to_agent = ?"
        if unread_only:
            query += " AND read = 0"
        # rowid breaks timestamp ties so equal-time messages keep insertion order.
        query += " ORDER BY timestamp ASC, rowid ASC"
        rows = self.conn.execute(query, (agent_id,)).fetchall()
        return [self._row_to_message(row) for row in rows]

    def mark_as_read(self, message_id: str):
        """Mark the message with ``message_id`` as read (no-op if it does not exist)."""
        with self.conn:
            self.conn.execute(
                'UPDATE agent_messages SET read = 1 WHERE message_id = ?',
                (message_id,),
            )

    def clear_messages(self, session_id: Optional[str] = None):
        """Delete stored messages.

        Args:
            session_id: if given, only messages tagged with this session id are
                deleted. If None (default), ALL messages are deleted.
        """
        with self.conn:
            # Compare against None explicitly: an empty-string session id is a
            # real tag and must not fall through to deleting every message.
            if session_id is not None:
                self.conn.execute(
                    'DELETE FROM agent_messages WHERE session_id = ?',
                    (session_id,),
                )
            else:
                self.conn.execute('DELETE FROM agent_messages')

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()

    def receive_messages(self, agent_id: str) -> List[AgentMessage]:
        """Return unread messages for ``agent_id``, oldest first.

        The messages are NOT marked as read; call :meth:`mark_as_read` for each
        one that has been handled.
        """
        return self.get_messages(agent_id, unread_only=True)

    def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]:
        """Return all messages between ``agent_a`` and ``agent_b``.

        Messages in both directions are included, regardless of read status,
        ordered oldest first.
        """
        query = (
            f"SELECT {self._SELECT_COLUMNS} FROM agent_messages "
            "WHERE (from_agent = ? AND to_agent = ?) OR (from_agent = ? AND to_agent = ?) "
            "ORDER BY timestamp ASC, rowid ASC"
        )
        rows = self.conn.execute(query, (agent_a, agent_b, agent_b, agent_a)).fetchall()
        return [self._row_to_message(row) for row in rows]