Update agent.
All checks were successful
Build Base Application / Build (push) Successful in 1m44s

This commit is contained in:
retoor 2024-12-22 12:31:45 +01:00
parent d1fb0e351a
commit f24569f498

View File

@ -72,8 +72,12 @@ class Agent:
self.api_key = api_key self.api_key = api_key
self.client = OpenAI(api_key=self.api_key) self.client = OpenAI(api_key=self.api_key)
self.messages = messages or [] self.messages = messages or []
self.tool_handlers = {}
self.thread = self.client.beta.threads.create(messages=self.messages) self.thread = self.client.beta.threads.create(messages=self.messages)
async def register_tool_handler(self, name, method):
self.tool_handlers[name] = method
async def dalle2( async def dalle2(
self, prompt: str, width: Optional[int] = 512, height: Optional[int] = 512 self, prompt: str, width: Optional[int] = 512, height: Optional[int] = 512
) -> dict: ) -> dict:
@ -144,9 +148,24 @@ class Agent:
) )
while run.status != "completed": while run.status != "completed":
outputs = []
for tool in run.required_action.submit_tool_outputs.tool_calls:
tool_handler = self.tool_handlers[tool.name]
output = await tool_handler(tool.arguments)
outputs.append(dict(
tool_call_id=tool.id,
output=output
))
if outputs:
run = client.beta.threads.runs.submit_tool_outputs_and_poll(
thread_id=self.thread.id,
run_id=run.id,
tool_outputs=outputs
)
run = self.client.beta.threads.runs.retrieve( run = self.client.beta.threads.runs.retrieve(
thread_id=self.thread.id, run_id=run.id thread_id=self.thread.id, run_id=run.id
) )
yield None yield None
await asyncio.sleep(interval) await asyncio.sleep(interval)