Update.
This commit is contained in:
parent
d6b45d662d
commit
820e95ac0c
@ -15,12 +15,14 @@ logger = logging.getLogger(__name__)
|
|||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: str
|
||||||
|
tool_calls: Optional[List[Dict[str, Any]]] = None # Added for OpenAI tool calls
|
||||||
|
tool_call_id: Optional[str] = None # Added for tool responses
|
||||||
|
|
||||||
|
|
||||||
class ToolCall(BaseModel):
|
class ToolCall(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
type: str
|
type: str
|
||||||
function: Dict[str, Any]
|
function: Dict[str, Any] = None
|
||||||
|
|
||||||
|
|
||||||
class AIResponse(BaseModel):
|
class AIResponse(BaseModel):
|
||||||
@ -47,13 +49,15 @@ class BaseAIClient(ABC):
|
|||||||
async def add_user_message(self, content: str) -> None:
|
async def add_user_message(self, content: str) -> None:
|
||||||
self.messages.append(Message(role="user", content=content))
|
self.messages.append(Message(role="user", content=content))
|
||||||
|
|
||||||
async def add_assistant_message(self, content: str) -> None:
|
async def add_assistant_message(self, content: str, tool_calls: Optional[List[Dict[str, Any]]] = None) -> None:
|
||||||
self.messages.append(Message(role="assistant", content=content))
|
self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls))
|
||||||
|
|
||||||
async def add_tool_result(self, tool_call_id: str, result: str) -> None:
|
async def add_tool_result(self, tool_call_id: str, result: str) -> None:
|
||||||
|
# OpenAI expects tool results in a specific format
|
||||||
self.messages.append(Message(
|
self.messages.append(Message(
|
||||||
role="tool",
|
role="tool",
|
||||||
content=json.dumps({"tool_call_id": tool_call_id, "result": result})
|
content=result,
|
||||||
|
tool_call_id=tool_call_id
|
||||||
))
|
))
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -78,13 +82,27 @@ class OpenAIClient(BaseAIClient):
|
|||||||
self.base_url = config.get_completions_url()
|
self.base_url = config.get_completions_url()
|
||||||
self.headers = config.get_auth_headers()
|
self.headers = config.get_auth_headers()
|
||||||
|
|
||||||
|
def _prepare_message_for_api(self, msg: Message) -> Dict[str, Any]:
|
||||||
|
"""Convert internal Message to OpenAI API format"""
|
||||||
|
api_msg = {"role": msg.role, "content": msg.content}
|
||||||
|
|
||||||
|
# Add tool_calls if present (for assistant messages)
|
||||||
|
if msg.tool_calls:
|
||||||
|
api_msg["tool_calls"] = msg.tool_calls
|
||||||
|
|
||||||
|
# Add tool_call_id if present (for tool messages)
|
||||||
|
if msg.tool_call_id:
|
||||||
|
api_msg["tool_call_id"] = msg.tool_call_id
|
||||||
|
|
||||||
|
return api_msg
|
||||||
|
|
||||||
async def chat(self, role: str, message: str) -> str:
|
async def chat(self, role: str, message: str) -> str:
|
||||||
if message:
|
if message:
|
||||||
await self.add_user_message(message)
|
await self.add_user_message(message)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.config.model,
|
"model": self.config.model,
|
||||||
"messages": [msg.dict() for msg in self.messages],
|
"messages": [self._prepare_message_for_api(msg) for msg in self.messages],
|
||||||
"temperature": self.config.temperature,
|
"temperature": self.config.temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -110,21 +128,27 @@ class OpenAIClient(BaseAIClient):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def chat_with_tools(self, role: str, message: str, tools: List[Dict[str, Any]]) -> AIResponse:
|
async def chat_with_tools(self, role: str, message: str, tools: List[Dict[str, Any]]) -> AIResponse:
|
||||||
|
# Add message only if it's not empty
|
||||||
if message:
|
if message:
|
||||||
await self.add_user_message(message)
|
await self.add_user_message(message)
|
||||||
|
|
||||||
|
# If no messages, return empty response
|
||||||
|
if not self.messages:
|
||||||
|
return AIResponse(content="", tool_calls=None, usage=None)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.config.model,
|
"model": self.config.model,
|
||||||
"messages": [msg.dict() for msg in self.messages],
|
"messages": [self._prepare_message_for_api(msg) for msg in self.messages],
|
||||||
"temperature": self.config.temperature,
|
"temperature": self.config.temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
payload["tools"] = tools
|
payload["tools"] = tools
|
||||||
|
payload["tool_choice"] = "auto" # Let the model decide when to use tools
|
||||||
|
|
||||||
if self.config.max_tokens:
|
if self.config.max_tokens:
|
||||||
payload["max_tokens"] = self.config.max_tokens
|
payload["max_tokens"] = self.config.max_tokens
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.client.post(
|
response = await self.client.post(
|
||||||
self.base_url,
|
self.base_url,
|
||||||
@ -132,32 +156,35 @@ class OpenAIClient(BaseAIClient):
|
|||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
choice = data["choices"][0]
|
choice = data["choices"][0]
|
||||||
message_data = choice["message"]
|
message_data = choice["message"]
|
||||||
|
|
||||||
tool_calls = None
|
tool_calls = None
|
||||||
|
tool_calls_raw = None
|
||||||
if "tool_calls" in message_data and message_data["tool_calls"]:
|
if "tool_calls" in message_data and message_data["tool_calls"]:
|
||||||
|
tool_calls_raw = message_data["tool_calls"]
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=tc["id"],
|
id=tc["id"],
|
||||||
type=tc["type"],
|
type=tc["type"],
|
||||||
function=tc["function"]
|
function=tc["function"]
|
||||||
) for tc in message_data["tool_calls"]
|
) for tc in tool_calls_raw
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Content can be None when only tool calls are returned
|
||||||
|
content = message_data.get("content") or ""
|
||||||
|
|
||||||
content = message_data.get("content", "")
|
# Add the assistant's response to history, including tool calls if present
|
||||||
|
await self.add_assistant_message(content, tool_calls_raw)
|
||||||
if content:
|
|
||||||
await self.add_assistant_message(content)
|
|
||||||
|
|
||||||
return AIResponse(
|
return AIResponse(
|
||||||
content=content,
|
content=content,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
usage=data.get("usage")
|
usage=data.get("usage")
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"OpenAI API error: {e}")
|
logger.error(f"OpenAI API error: {e}")
|
||||||
raise
|
raise
|
||||||
@ -168,7 +195,7 @@ class OpenAIClient(BaseAIClient):
|
|||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.config.model,
|
"model": self.config.model,
|
||||||
"messages": [msg.dict() for msg in self.messages],
|
"messages": [self._prepare_message_for_api(msg) for msg in self.messages],
|
||||||
"temperature": self.config.temperature,
|
"temperature": self.config.temperature,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
@ -277,8 +277,8 @@ class PyrApp:
|
|||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
try:
|
try:
|
||||||
result = await self.tool_registry.execute_tool(
|
result = await self.tool_registry.execute_tool(
|
||||||
tool_call.function.name,
|
tool_call.function['name'],
|
||||||
tool_call.function.arguments
|
tool_call.function['arguments']
|
||||||
)
|
)
|
||||||
# Send tool result back to AI
|
# Send tool result back to AI
|
||||||
await self.ai_client.add_tool_result(tool_call.id, result)
|
await self.ai_client.add_tool_result(tool_call.id, result)
|
||||||
|
Loading…
Reference in New Issue
Block a user