This commit is contained in:
parent
d1fb0e351a
commit
f24569f498
@ -72,8 +72,12 @@ class Agent:
|
||||
self.api_key = api_key
|
||||
self.client = OpenAI(api_key=self.api_key)
|
||||
self.messages = messages or []
|
||||
self.tool_handlers = {}
|
||||
self.thread = self.client.beta.threads.create(messages=self.messages)
|
||||
|
||||
async def register_tool_handler(self, name, method):
|
||||
self.tool_handlers[name] = method
|
||||
|
||||
async def dalle2(
|
||||
self, prompt: str, width: Optional[int] = 512, height: Optional[int] = 512
|
||||
) -> dict:
|
||||
@ -144,9 +148,24 @@ class Agent:
|
||||
)
|
||||
|
||||
while run.status != "completed":
|
||||
outputs = []
|
||||
for tool in run.required_action.submit_tool_outputs.tool_calls:
|
||||
tool_handler = self.tool_handlers[tool.name]
|
||||
output = await tool_handler(tool.arguments)
|
||||
outputs.append(dict(
|
||||
tool_call_id=tool.id,
|
||||
output=output
|
||||
))
|
||||
if outputs:
|
||||
run = client.beta.threads.runs.submit_tool_outputs_and_poll(
|
||||
thread_id=self.thread.id,
|
||||
run_id=run.id,
|
||||
tool_outputs=outputs
|
||||
)
|
||||
run = self.client.beta.threads.runs.retrieve(
|
||||
thread_id=self.thread.id, run_id=run.id
|
||||
)
|
||||
|
||||
yield None
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user