Update.
This commit is contained in:
parent
d6b45d662d
commit
820e95ac0c
@ -15,12 +15,14 @@ logger = logging.getLogger(__name__)
|
||||
class Message(BaseModel):
    """One turn of the conversation history, mirroring the OpenAI chat shape.

    ``tool_calls`` is set on assistant turns that request tool invocations;
    ``tool_call_id`` links a ``role="tool"`` result back to the call that
    produced it.
    """

    role: str  # "user" / "assistant" / "tool" at the call sites in this file
    content: str
    # Raw OpenAI tool-call payloads carried on assistant messages.
    tool_calls: Optional[List[Dict[str, Any]]] = None
    # Identifier echoed on tool-result messages (role="tool").
    tool_call_id: Optional[str] = None
||||
class ToolCall(BaseModel):
    """A tool invocation requested by the model.

    ``function`` holds the raw OpenAI function payload; callers read its
    ``name`` and ``arguments`` keys (see the tool-execution loop below).
    """

    id: str
    type: str
    # Annotated Optional to match the None default: a bare
    # `Dict[str, Any] = None` is an inconsistent annotation that pydantic
    # either coerces implicitly (v1) or rejects (v2).
    function: Optional[Dict[str, Any]] = None
||||
|
||||
class AIResponse(BaseModel):
|
||||
@ -47,13 +49,15 @@ class BaseAIClient(ABC):
|
||||
async def add_user_message(self, content: str) -> None:
    """Append a user-role message to the conversation history."""
    user_turn = Message(role="user", content=content)
    self.messages.append(user_turn)
||||
async def add_assistant_message(self, content: str, tool_calls: Optional[List[Dict[str, Any]]] = None) -> None:
    """Append an assistant-role message to the conversation history.

    Args:
        content: The assistant's text (may be "" when only tools were called).
        tool_calls: Raw OpenAI tool-call payloads to carry on this turn, so a
            follow-up request can reference them. Defaults to None, keeping
            the original single-argument call sites working unchanged.
    """
    self.messages.append(
        Message(role="assistant", content=content, tool_calls=tool_calls)
    )
||||
async def add_tool_result(self, tool_call_id: str, result: str) -> None:
    """Append a tool-result message in the format the OpenAI API expects.

    OpenAI requires tool results as a ``role="tool"`` message whose content
    is the plain result string and whose ``tool_call_id`` names the
    originating call — not a JSON blob packed into ``content``.

    Args:
        tool_call_id: Id of the tool call this result answers.
        result: The tool's output, passed through verbatim.
    """
    self.messages.append(Message(
        role="tool",
        content=result,
        tool_call_id=tool_call_id,
    ))
||||
@abstractmethod
|
||||
@ -78,13 +82,27 @@ class OpenAIClient(BaseAIClient):
|
||||
self.base_url = config.get_completions_url()
|
||||
self.headers = config.get_auth_headers()
|
||||
|
||||
def _prepare_message_for_api(self, msg: Message) -> Dict[str, Any]:
    """Convert an internal Message into the OpenAI chat API dict shape.

    The optional fields — ``tool_calls`` on assistant turns and
    ``tool_call_id`` on tool turns — are emitted only when set, so ordinary
    messages stay minimal ``{"role", "content"}`` dicts.
    """
    payload: Dict[str, Any] = {"role": msg.role, "content": msg.content}

    if msg.tool_calls:
        payload["tool_calls"] = msg.tool_calls
    if msg.tool_call_id:
        payload["tool_call_id"] = msg.tool_call_id

    return payload
||||
async def chat(self, role: str, message: str) -> str:
|
||||
if message:
|
||||
await self.add_user_message(message)
|
||||
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": [msg.dict() for msg in self.messages],
|
||||
"messages": [self._prepare_message_for_api(msg) for msg in self.messages],
|
||||
"temperature": self.config.temperature,
|
||||
}
|
||||
|
||||
@ -110,21 +128,27 @@ class OpenAIClient(BaseAIClient):
|
||||
raise
|
||||
|
||||
async def chat_with_tools(self, role: str, message: str, tools: List[Dict[str, Any]]) -> AIResponse:
    """Send the conversation to OpenAI with tool definitions and parse the reply.

    Args:
        role: Kept for interface compatibility; not used in this method.
        message: Optional user message to append before sending. Empty when
            we are continuing the exchange after a tool result.
        tools: OpenAI tool specs; when non-empty, ``tool_choice="auto"``
            lets the model decide whether to invoke one.

    Returns:
        AIResponse with the assistant text (normalized to "" — never None),
        parsed ToolCall objects when the model requested tools, and the raw
        usage stats from the API response.

    Raises:
        Re-raises any HTTP/parsing error after logging it.
    """
    # Add the message only if it is non-empty.
    if message:
        await self.add_user_message(message)

    # Nothing to send — short-circuit with an empty response.
    if not self.messages:
        return AIResponse(content="", tool_calls=None, usage=None)

    payload = {
        "model": self.config.model,
        # Serialize through the helper so optional tool fields are included
        # only when present (msg.dict() would emit null fields).
        "messages": [self._prepare_message_for_api(msg) for msg in self.messages],
        "temperature": self.config.temperature,
    }

    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"  # let the model decide when to use tools

    if self.config.max_tokens:
        payload["max_tokens"] = self.config.max_tokens

    try:
        response = await self.client.post(
            self.base_url,
            headers=self.headers,
            json=payload,
        )
        response.raise_for_status()

        data = response.json()
        message_data = data["choices"][0]["message"]

        tool_calls = None
        tool_calls_raw = message_data.get("tool_calls") or None
        if tool_calls_raw:
            tool_calls = [
                ToolCall(
                    id=tc["id"],
                    type=tc["type"],
                    function=tc["function"],
                )
                for tc in tool_calls_raw
            ]

        # Content is None when the model returns only tool calls; normalize
        # with `or ""` (a plain .get(..., "") would pass the None through).
        content = message_data.get("content") or ""

        # Record the assistant turn unconditionally, including the raw tool
        # calls, so the follow-up request after tool execution has context.
        await self.add_assistant_message(content, tool_calls_raw)

        return AIResponse(
            content=content,
            tool_calls=tool_calls,
            usage=data.get("usage"),
        )

    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise
||||
@ -168,7 +195,7 @@ class OpenAIClient(BaseAIClient):
|
||||
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": [msg.dict() for msg in self.messages],
|
||||
"messages": [self._prepare_message_for_api(msg) for msg in self.messages],
|
||||
"temperature": self.config.temperature,
|
||||
"stream": True,
|
||||
}
|
||||
|
@ -277,8 +277,8 @@ class PyrApp:
|
||||
for tool_call in response.tool_calls:
|
||||
try:
|
||||
result = await self.tool_registry.execute_tool(
|
||||
tool_call.function.name,
|
||||
tool_call.function.arguments
|
||||
tool_call.function['name'],
|
||||
tool_call.function['arguments']
|
||||
)
|
||||
# Send tool result back to AI
|
||||
await self.ai_client.add_tool_result(tool_call.id, result)
|
||||
|
Loading…
Reference in New Issue
Block a user