From f24569f4981a04e5f4cb8763d88d9c5691420c5d Mon Sep 17 00:00:00 2001 From: retoor Date: Sun, 22 Dec 2024 12:31:45 +0100 Subject: [PATCH] Update agent. --- src/app/agent.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/app/agent.py b/src/app/agent.py index 8a5fa09..0b3e2ef 100644 --- a/src/app/agent.py +++ b/src/app/agent.py @@ -72,8 +72,12 @@ class Agent: self.api_key = api_key self.client = OpenAI(api_key=self.api_key) self.messages = messages or [] + self.tool_handlers = {} self.thread = self.client.beta.threads.create(messages=self.messages) + async def register_tool_handler(self, name, method): + self.tool_handlers[name] = method + async def dalle2( self, prompt: str, width: Optional[int] = 512, height: Optional[int] = 512 ) -> dict: @@ -144,9 +148,24 @@ class Agent: ) while run.status != "completed": + outputs = [] + for tool in run.required_action.submit_tool_outputs.tool_calls: + tool_handler = self.tool_handlers[tool.name] + output = await tool_handler(tool.arguments) + outputs.append(dict( + tool_call_id=tool.id, + output=output + )) + if outputs: + run = client.beta.threads.runs.submit_tool_outputs_and_poll( + thread_id=self.thread.id, + run_id=run.id, + tool_outputs=outputs + ) run = self.client.beta.threads.runs.retrieve( thread_id=self.thread.id, run_id=run.id ) + yield None await asyncio.sleep(interval)