diff --git a/client.py b/client.py
index d8a1509..c91ebfe 100644
--- a/client.py
+++ b/client.py
@@ -29,15 +29,23 @@ async def websocket_client(url: str, ollama_url: str) -> None:
 
             api_url = urlunparse(urlparse(ollama_url)._replace(path=data['path']))
             async with session.post(api_url, json=data['data']) as response:
+                print(response)
                 if response.status != 200:
                     logging.error(f"Failed to post data: {response.status}")
                     continue
+                logging.info(f"Streaming response.")
 
                 async for msg in response.content:
-                    msg = json.loads(msg.decode('utf-8'))
+                    #first_index = msg.find(b"{")
+                    #msg = msg[first_index:]
+                    #last_index = msg.rfind(b"}")
+                    #msg = msg[:last_index+1]
+                    #if not msg:
+                    #    continue
+                    #msg = json.loads(msg.decode('utf-8'))
                     await ws.send_json(dict(
                         request_id=request_id,
-                        data=msg
+                        data=msg.decode()
                     ))
                 logging.info(f"Response complete.")
             elif msg.type == aiohttp.WSMsgType.ERROR:
diff --git a/server.py b/server.py
index fa7d979..0d912bd 100644
--- a/server.py
+++ b/server.py
@@ -3,6 +3,7 @@ import aiohttp
 from aiohttp import web
 import uuid
 import pathlib
+import json
 
 class OllamaServer:
     def __init__(self, ws, models):
@@ -27,13 +28,22 @@
 
         while True:
             chunk = await self.queues[request_id].get()
-            yield chunk
-
+            if chunk:
+                yield chunk
+            if not chunk:
+                yield '\n'
+            print("CHUNK:", chunk)
+            #try:
+                #yield json.loads(chunk)
+            #except:
+            #    yield chunk
             if not 'done' in chunk:
                 break
-
+            if 'stop' in chunk:
+                break
             if chunk['done']:
                 break
+
 
     async def serve(self):
         async for msg in self.ws:
@@ -68,6 +78,7 @@
         server = self.servers.pop(0)
         self.servers.append(server)
         server = self.get_server_by_model_name(message['model'])
+
         if not server:
             raise NoServerFoundException
         async for msg in server.forward_to_websocket(request_id, message, path):
@@ -83,11 +94,14 @@
             models[model_name] = {}
             models[model_name]['id'] = model_name
             models[model_name]['instances'] = 0
-            models[model_name]['owner'] = 'public'
+            models[model_name]['owned_by'] = 'uberlama'
             models[model_name]['object'] = 'model'
-            models[model_name]['created'] = 0
+            models[model_name]['created'] = 1743724800
             models[model_name]['instances'] += 1
-        return list(models.values())
+        return {
+            'object':"list",
+            'data':list(models.values())
+        }
 
 
 server_manager = ServerManager()
@@ -95,19 +109,31 @@
 async def websocket_handler(request):
     ws = web.WebSocketResponse()
     await ws.prepare(request)
+    await asyncio.sleep(1)
+    try:
+        models = await ws.receive_json()
+    except:
+        print("Non JSON, sleeping three seconds.")
+        print(request.headers)
-    models = await ws.receive_json()
+        #await ws.send_json("YOUR DATA IS CORRUPT. EXIT PROCESS!")
+        await ws.close()
+        return ws
 
     server = OllamaServer(ws, models['models'])
     server_manager.add_server(server)
 
-    async for msg in ws:
-        if msg.type == web.WSMsgType.TEXT:
-            data = msg.json()
-            await server.forward_to_http(data['request_id'], data['data'])
-        elif msg.type == web.WSMsgType.ERROR:
-            print(f'WebSocket connection closed with exception: {ws.exception()}')
-
+    try:
+        async for msg in ws:
+            if msg.type == web.WSMsgType.TEXT:
+                data = msg.json()
+                #print(data)
+                await server.forward_to_http(data['request_id'], data['data'])
+            elif msg.type == web.WSMsgType.ERROR:
+                print(f'WebSocket connection closed with exception: {ws.exception()}')
+    except:
+        print("Closing...")
+        await ws.close()
     server_manager.remove_server(server)
     return ws
@@ -118,13 +144,21 @@ async def http_handler(request):
         data = await request.json()
     except ValueError:
         return web.Response(status=400)
-
-    resp = web.StreamResponse(headers={'Content-Type': 'application/x-ndjson', 'Transfer-Encoding': 'chunked'})
+    # application/x-ndjson text/event-stream
+    if data['stream']:
+        resp = web.StreamResponse(headers={'Content-Type': 'text/event-stream', 'Transfer-Encoding': 'chunked'})
+    else:
+        resp = web.StreamResponse(headers={'Content-Type': 'application/json', 'Transfer-Encoding': 'chunked'})
     await resp.prepare(request)
     import json
     try:
         async for result in server_manager.forward_to_websocket(request_id, data, path=request.path):
-            await resp.write(json.dumps(result).encode()+b'\r\n')
+            try:
+                a = 3
+                #result = json.dumps(result)
+            except:
+                pass
+            await resp.write(result.encode())
     except NoServerFoundException:
         await resp.write(json.dumps(dict(error="No server with that model found.",available=server_manager.get_models())).encode() + b'\r\n')
     await resp.write_eof()
@@ -141,7 +175,9 @@ async def not_found_handler(request):
     return web.json_response({"error":"not found"})
 
 async def models_handler(self):
-    return web.json_response(server_manager.get_models())
+    print("Listing models.")
+    response_json = json.dumps(server_manager.get_models(),indent=2)
+    return web.Response(text=response_json,content_type="application/json")
 
 
 app = web.Application()