retoor 2025-08-21 11:29:20 +02:00
parent d6b45d662d
commit 820e95ac0c
2 changed files with 48 additions and 21 deletions

View File

@@ -15,12 +15,14 @@ logger = logging.getLogger(__name__)
 class Message(BaseModel):
     role: str
     content: str
+    tool_calls: Optional[List[Dict[str, Any]]] = None  # Added for OpenAI tool calls
+    tool_call_id: Optional[str] = None  # Added for tool responses
 
 
 class ToolCall(BaseModel):
     id: str
     type: str
-    function: Dict[str, Any]
+    function: Dict[str, Any] = None
 
 
 class AIResponse(BaseModel):
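For reference, the two new optional fields on Message mirror the shapes the OpenAI Chat Completions API uses for tool calling: an assistant message may carry a tool_calls list, and the tool's output comes back as a separate role "tool" message that points at the call via tool_call_id. A minimal sketch of those two payloads (ids, tool name and values are illustrative only):

# Illustrative only: the two message shapes the new fields map onto
assistant_msg = {
    "role": "assistant",
    "content": None,  # content may be empty when the model only calls tools
    "tool_calls": [{
        "id": "call_abc123",  # made-up id
        "type": "function",
        "function": {"name": "get_time", "arguments": "{}"},  # made-up tool
    }],
}
tool_result_msg = {
    "role": "tool",
    "tool_call_id": "call_abc123",  # must match the id above
    "content": "12:00",
}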
@@ -47,13 +49,15 @@ class BaseAIClient(ABC):
     async def add_user_message(self, content: str) -> None:
         self.messages.append(Message(role="user", content=content))
 
-    async def add_assistant_message(self, content: str) -> None:
-        self.messages.append(Message(role="assistant", content=content))
+    async def add_assistant_message(self, content: str, tool_calls: Optional[List[Dict[str, Any]]] = None) -> None:
+        self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls))
 
     async def add_tool_result(self, tool_call_id: str, result: str) -> None:
+        # OpenAI expects tool results in a specific format
         self.messages.append(Message(
             role="tool",
-            content=json.dumps({"tool_call_id": tool_call_id, "result": result})
+            content=result,
+            tool_call_id=tool_call_id
         ))
 
     @abstractmethod
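With this change a tool result is no longer serialized into content as a JSON blob; it is stored against the dedicated tool_call_id field, which is the format the OpenAI API expects for role "tool" messages. A rough usage sketch, with a made-up client, id and result:

# Hypothetical usage: record the assistant's tool call first, then the matching result
async def record_tool_round(client, raw_tool_calls):
    await client.add_assistant_message("", raw_tool_calls)  # raw_tool_calls: list of dicts from the API
    await client.add_tool_result("call_abc123", "42")       # id must match an entry in raw_tool_calls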
@@ -78,13 +82,27 @@ class OpenAIClient(BaseAIClient):
         self.base_url = config.get_completions_url()
         self.headers = config.get_auth_headers()
 
+    def _prepare_message_for_api(self, msg: Message) -> Dict[str, Any]:
+        """Convert internal Message to OpenAI API format"""
+        api_msg = {"role": msg.role, "content": msg.content}
+
+        # Add tool_calls if present (for assistant messages)
+        if msg.tool_calls:
+            api_msg["tool_calls"] = msg.tool_calls
+
+        # Add tool_call_id if present (for tool messages)
+        if msg.tool_call_id:
+            api_msg["tool_call_id"] = msg.tool_call_id
+
+        return api_msg
+
     async def chat(self, role: str, message: str) -> str:
         if message:
             await self.add_user_message(message)
 
         payload = {
             "model": self.config.model,
-            "messages": [msg.dict() for msg in self.messages],
+            "messages": [self._prepare_message_for_api(msg) for msg in self.messages],
             "temperature": self.config.temperature,
         }
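The new _prepare_message_for_api helper replaces the bare msg.dict() serialization, so the tool fields are only sent when they are actually set. Its expected behaviour, sketched with illustrative values:

# Expected mapping (illustrative values), given the helper above:
#   Message(role="user", content="hi")
#       -> {"role": "user", "content": "hi"}
#   Message(role="assistant", content="", tool_calls=[...])
#       -> {"role": "assistant", "content": "", "tool_calls": [...]}
#   Message(role="tool", content="42", tool_call_id="call_abc123")
#       -> {"role": "tool", "content": "42", "tool_call_id": "call_abc123"}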
@@ -110,21 +128,27 @@ class OpenAIClient(BaseAIClient):
             raise
 
     async def chat_with_tools(self, role: str, message: str, tools: List[Dict[str, Any]]) -> AIResponse:
+        # Add message only if it's not empty
         if message:
             await self.add_user_message(message)
 
+        # If no messages, return empty response
+        if not self.messages:
+            return AIResponse(content="", tool_calls=None, usage=None)
+
         payload = {
             "model": self.config.model,
-            "messages": [msg.dict() for msg in self.messages],
+            "messages": [self._prepare_message_for_api(msg) for msg in self.messages],
             "temperature": self.config.temperature,
         }
 
         if tools:
             payload["tools"] = tools
+            payload["tool_choice"] = "auto"  # Let the model decide when to use tools
 
         if self.config.max_tokens:
             payload["max_tokens"] = self.config.max_tokens
 
         try:
             response = await self.client.post(
                 self.base_url,
@@ -132,32 +156,35 @@ class OpenAIClient(BaseAIClient):
                 json=payload
             )
             response.raise_for_status()
 
             data = response.json()
             choice = data["choices"][0]
             message_data = choice["message"]
 
             tool_calls = None
+            tool_calls_raw = None
             if "tool_calls" in message_data and message_data["tool_calls"]:
+                tool_calls_raw = message_data["tool_calls"]
                 tool_calls = [
                     ToolCall(
                         id=tc["id"],
                         type=tc["type"],
                         function=tc["function"]
-                    ) for tc in message_data["tool_calls"]
+                    ) for tc in tool_calls_raw
                 ]
 
-            content = message_data.get("content", "")
-
-            if content:
-                await self.add_assistant_message(content)
+            # Content can be None when only tool calls are returned
+            content = message_data.get("content") or ""
+
+            # Add the assistant's response to history, including tool calls if present
+            await self.add_assistant_message(content, tool_calls_raw)
 
             return AIResponse(
                 content=content,
                 tool_calls=tool_calls,
                 usage=data.get("usage")
             )
 
         except Exception as e:
             logger.error(f"OpenAI API error: {e}")
             raise
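Note that the assistant message is now appended to the history even when its content is empty, so a later role "tool" message always has a preceding assistant tool_calls entry to reference; the OpenAI API rejects a tool message whose tool_call_id does not match one. A hedged sketch of driving one full round trip from the caller's side (client, tools and run_tool are assumptions, not part of this codebase):

# Hypothetical driver: one tool round trip against the client above
async def ask_with_tools(client, tools, run_tool):
    response = await client.chat_with_tools("user", "What time is it?", tools)
    if response.tool_calls:
        for tc in response.tool_calls:
            result = await run_tool(tc.function["name"], tc.function["arguments"])
            await client.add_tool_result(tc.id, result)
        # Second call with an empty message: the history already holds the tool results
        response = await client.chat_with_tools("user", "", tools)
    return response.content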
@@ -168,7 +195,7 @@ class OpenAIClient(BaseAIClient):
         payload = {
             "model": self.config.model,
-            "messages": [msg.dict() for msg in self.messages],
+            "messages": [self._prepare_message_for_api(msg) for msg in self.messages],
             "temperature": self.config.temperature,
             "stream": True,
         }

View File

@@ -277,8 +277,8 @@ class PyrApp:
             for tool_call in response.tool_calls:
                 try:
                     result = await self.tool_registry.execute_tool(
-                        tool_call.function.name,
-                        tool_call.function.arguments
+                        tool_call.function['name'],
+                        tool_call.function['arguments']
                     )
                     # Send tool result back to AI
                     await self.ai_client.add_tool_result(tool_call.id, result)
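Because ToolCall.function is now kept as the raw dict from the API response, the app indexes it by key instead of attribute. Worth keeping in mind: in the OpenAI format, function["arguments"] is a JSON-encoded string, so if execute_tool expects structured arguments rather than the raw string, it would need to be decoded first. A small self-contained sketch (tool name and payload are made up):

import json

# Illustrative: the arguments field arrives as a JSON-encoded string
function = {"name": "get_weather", "arguments": '{"city": "Paris"}'}  # hypothetical tool call
args = json.loads(function.get("arguments") or "{}")                  # -> {"city": "Paris"}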