diff --git a/src/pyr/ai/client.py b/src/pyr/ai/client.py index 1757e69..1a261a3 100644 --- a/src/pyr/ai/client.py +++ b/src/pyr/ai/client.py @@ -15,12 +15,14 @@ logger = logging.getLogger(__name__) class Message(BaseModel): role: str content: str + tool_calls: Optional[List[Dict[str, Any]]] = None # Added for OpenAI tool calls + tool_call_id: Optional[str] = None # Added for tool responses class ToolCall(BaseModel): id: str type: str - function: Dict[str, Any] + function: Dict[str, Any] = None class AIResponse(BaseModel): @@ -47,13 +49,15 @@ class BaseAIClient(ABC): async def add_user_message(self, content: str) -> None: self.messages.append(Message(role="user", content=content)) - async def add_assistant_message(self, content: str) -> None: - self.messages.append(Message(role="assistant", content=content)) + async def add_assistant_message(self, content: str, tool_calls: Optional[List[Dict[str, Any]]] = None) -> None: + self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls)) async def add_tool_result(self, tool_call_id: str, result: str) -> None: + # OpenAI expects tool results in a specific format self.messages.append(Message( role="tool", - content=json.dumps({"tool_call_id": tool_call_id, "result": result}) + content=result, + tool_call_id=tool_call_id )) @abstractmethod @@ -78,13 +82,27 @@ class OpenAIClient(BaseAIClient): self.base_url = config.get_completions_url() self.headers = config.get_auth_headers() + def _prepare_message_for_api(self, msg: Message) -> Dict[str, Any]: + """Convert internal Message to OpenAI API format""" + api_msg = {"role": msg.role, "content": msg.content} + + # Add tool_calls if present (for assistant messages) + if msg.tool_calls: + api_msg["tool_calls"] = msg.tool_calls + + # Add tool_call_id if present (for tool messages) + if msg.tool_call_id: + api_msg["tool_call_id"] = msg.tool_call_id + + return api_msg + async def chat(self, role: str, message: str) -> str: if message: await self.add_user_message(message) payload = { "model": self.config.model, - "messages": [msg.dict() for msg in self.messages], + "messages": [self._prepare_message_for_api(msg) for msg in self.messages], "temperature": self.config.temperature, } @@ -110,21 +128,27 @@ class OpenAIClient(BaseAIClient): raise async def chat_with_tools(self, role: str, message: str, tools: List[Dict[str, Any]]) -> AIResponse: + # Add message only if it's not empty if message: await self.add_user_message(message) + # If no messages, return empty response + if not self.messages: + return AIResponse(content="", tool_calls=None, usage=None) + payload = { "model": self.config.model, - "messages": [msg.dict() for msg in self.messages], + "messages": [self._prepare_message_for_api(msg) for msg in self.messages], "temperature": self.config.temperature, } - + if tools: payload["tools"] = tools - + payload["tool_choice"] = "auto" # Let the model decide when to use tools + if self.config.max_tokens: payload["max_tokens"] = self.config.max_tokens - + try: response = await self.client.post( self.base_url, @@ -132,32 +156,35 @@ class OpenAIClient(BaseAIClient): json=payload ) response.raise_for_status() - + data = response.json() choice = data["choices"][0] message_data = choice["message"] - + tool_calls = None + tool_calls_raw = None if "tool_calls" in message_data and message_data["tool_calls"]: + tool_calls_raw = message_data["tool_calls"] tool_calls = [ ToolCall( id=tc["id"], type=tc["type"], function=tc["function"] - ) for tc in message_data["tool_calls"] + ) for tc in tool_calls_raw ] + + # Content can be None when only tool calls are returned + content = message_data.get("content") or "" - content = message_data.get("content", "") - - if content: - await self.add_assistant_message(content) - + # Add the assistant's response to history, including tool calls if present + await self.add_assistant_message(content, tool_calls_raw) + return AIResponse( content=content, tool_calls=tool_calls, usage=data.get("usage") ) - + except Exception as e: logger.error(f"OpenAI API error: {e}") raise @@ -168,7 +195,7 @@ class OpenAIClient(BaseAIClient): payload = { "model": self.config.model, - "messages": [msg.dict() for msg in self.messages], + "messages": [self._prepare_message_for_api(msg) for msg in self.messages], "temperature": self.config.temperature, "stream": True, } diff --git a/src/pyr/core/app.py b/src/pyr/core/app.py index 422df76..b0493ef 100644 --- a/src/pyr/core/app.py +++ b/src/pyr/core/app.py @@ -277,8 +277,8 @@ class PyrApp: for tool_call in response.tool_calls: try: result = await self.tool_registry.execute_tool( - tool_call.function.name, - tool_call.function.arguments + tool_call.function['name'], + tool_call.function['arguments'] ) # Send tool result back to AI await self.ai_client.add_tool_result(tool_call.id, result)