Updates.
This commit is contained in:
		
							parent
							
								
									0f88a658b1
								
							
						
					
					
						commit
						f1c4553038
					
				
							
								
								
									
										12
									
								
								client.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								client.py
									
									
									
									
									
								
							| @ -29,15 +29,23 @@ async def websocket_client(url: str, ollama_url: str) -> None: | |||||||
|                         api_url = urlunparse(urlparse(ollama_url)._replace(path=data['path'])) |                         api_url = urlunparse(urlparse(ollama_url)._replace(path=data['path'])) | ||||||
| 
 | 
 | ||||||
|                         async with session.post(api_url, json=data['data']) as response: |                         async with session.post(api_url, json=data['data']) as response: | ||||||
|  |                             print(response) | ||||||
|                             if response.status != 200: |                             if response.status != 200: | ||||||
|                                 logging.error(f"Failed to post data: {response.status}") |                                 logging.error(f"Failed to post data: {response.status}") | ||||||
|                                 continue |                                 continue | ||||||
|  |                              | ||||||
|                             logging.info(f"Streaming response.") |                             logging.info(f"Streaming response.") | ||||||
|                             async for msg in response.content: |                             async for msg in response.content: | ||||||
|                                 msg = json.loads(msg.decode('utf-8')) |                                 #first_index = msg.find(b"{") | ||||||
|  |                                 #msg = msg[first_index:] | ||||||
|  |                                 #last_index = msg.rfind(b"}") | ||||||
|  |                                 #msg = msg[:last_index+1] | ||||||
|  |                                 #if not msg: | ||||||
|  |                                 #    continue | ||||||
|  |                                 #msg = json.loads(msg.decode('utf-8')) | ||||||
|                                 await ws.send_json(dict( |                                 await ws.send_json(dict( | ||||||
|                                     request_id=request_id, |                                     request_id=request_id, | ||||||
|                                     data=msg |                                     data=msg.decode() | ||||||
|                                 )) |                                 )) | ||||||
|                             logging.info(f"Response complete.") |                             logging.info(f"Response complete.") | ||||||
|                     elif msg.type == aiohttp.WSMsgType.ERROR: |                     elif msg.type == aiohttp.WSMsgType.ERROR: | ||||||
|  | |||||||
							
								
								
									
										72
									
								
								server.py
									
									
									
									
									
								
							
							
						
						
									
										72
									
								
								server.py
									
									
									
									
									
								
							| @ -3,6 +3,7 @@ import aiohttp | |||||||
| from aiohttp import web | from aiohttp import web | ||||||
| import uuid | import uuid | ||||||
| import pathlib | import pathlib | ||||||
|  | import json | ||||||
| 
 | 
 | ||||||
| class OllamaServer: | class OllamaServer: | ||||||
|     def __init__(self, ws, models): |     def __init__(self, ws, models): | ||||||
| @ -27,14 +28,23 @@ class OllamaServer: | |||||||
| 
 | 
 | ||||||
|         while True: |         while True: | ||||||
|             chunk = await self.queues[request_id].get() |             chunk = await self.queues[request_id].get() | ||||||
|             yield chunk |             if chunk: | ||||||
| 
 |                 yield chunk | ||||||
|  |             if not chunk: | ||||||
|  |                 yield '\n' | ||||||
|  |             print("CHUNK:", chunk) | ||||||
|  |             #try: | ||||||
|  |                 #yield json.loads(chunk) | ||||||
|  |             #except: | ||||||
|  |             #    yield chunk | ||||||
|             if not 'done' in chunk: |             if not 'done' in chunk: | ||||||
|                 break |                 break | ||||||
| 
 |             if 'stop' in chunk: | ||||||
|  |                 break | ||||||
|             if chunk['done']: |             if chunk['done']: | ||||||
|                 break |                 break | ||||||
|              |              | ||||||
|  | 
 | ||||||
|     async def serve(self): |     async def serve(self): | ||||||
|         async for msg in self.ws: |         async for msg in self.ws: | ||||||
|             if msg.type == web.WSMsgType.TEXT: |             if msg.type == web.WSMsgType.TEXT: | ||||||
| @ -68,6 +78,7 @@ class ServerManager: | |||||||
|             server = self.servers.pop(0) |             server = self.servers.pop(0) | ||||||
|             self.servers.append(server) |             self.servers.append(server) | ||||||
|             server = self.get_server_by_model_name(message['model']) |             server = self.get_server_by_model_name(message['model']) | ||||||
|  |              | ||||||
|             if not server: |             if not server: | ||||||
|                 raise NoServerFoundException |                 raise NoServerFoundException | ||||||
|             async for msg in server.forward_to_websocket(request_id, message, path): |             async for msg in server.forward_to_websocket(request_id, message, path): | ||||||
| @ -83,11 +94,14 @@ class ServerManager: | |||||||
|                     models[model_name] = {} |                     models[model_name] = {} | ||||||
|                     models[model_name]['id'] = model_name |                     models[model_name]['id'] = model_name | ||||||
|                     models[model_name]['instances'] = 0 |                     models[model_name]['instances'] = 0 | ||||||
|                     models[model_name]['owner'] = 'public' |                     models[model_name]['owned_by'] = 'uberlama' | ||||||
|                     models[model_name]['object'] = 'model' |                     models[model_name]['object'] = 'model' | ||||||
|                     models[model_name]['created'] = 0 |                     models[model_name]['created'] = 1743724800 | ||||||
|                 models[model_name]['instances'] += 1 |                 models[model_name]['instances'] += 1 | ||||||
|         return list(models.values()) |         return { | ||||||
|  |             'object':"list", | ||||||
|  |             'data':list(models.values()) | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| server_manager = ServerManager() | server_manager = ServerManager() | ||||||
| @ -95,19 +109,31 @@ server_manager = ServerManager() | |||||||
| async def websocket_handler(request): | async def websocket_handler(request): | ||||||
|     ws = web.WebSocketResponse() |     ws = web.WebSocketResponse() | ||||||
|     await ws.prepare(request) |     await ws.prepare(request) | ||||||
|  |     await asyncio.sleep(1) | ||||||
|  |     try: | ||||||
|  |         models = await ws.receive_json() | ||||||
|  |     except: | ||||||
|  |         print("Non JSON, sleeping three seconds.") | ||||||
|  |         print(request.headers) | ||||||
| 
 | 
 | ||||||
|     models = await ws.receive_json() |  | ||||||
| 
 | 
 | ||||||
|  |         #await ws.send_json("YOUR DATA IS CORRUPT. EXIT PROCESS!") | ||||||
|  |         await ws.close() | ||||||
|  |         return ws | ||||||
|     server = OllamaServer(ws, models['models']) |     server = OllamaServer(ws, models['models']) | ||||||
|     server_manager.add_server(server) |     server_manager.add_server(server) | ||||||
| 
 | 
 | ||||||
|     async for msg in ws: |     try: | ||||||
|         if msg.type == web.WSMsgType.TEXT: |         async for msg in ws: | ||||||
|             data = msg.json() |             if msg.type == web.WSMsgType.TEXT: | ||||||
|             await server.forward_to_http(data['request_id'], data['data']) |                 data = msg.json() | ||||||
|         elif msg.type == web.WSMsgType.ERROR: |                 #print(data) | ||||||
|             print(f'WebSocket connection closed with exception: {ws.exception()}') |                 await server.forward_to_http(data['request_id'], data['data']) | ||||||
| 
 |             elif msg.type == web.WSMsgType.ERROR: | ||||||
|  |                 print(f'WebSocket connection closed with exception: {ws.exception()}') | ||||||
|  |     except: | ||||||
|  |         print("Closing...") | ||||||
|  |         await ws.close() | ||||||
|     server_manager.remove_server(server) |     server_manager.remove_server(server) | ||||||
|     return ws |     return ws | ||||||
| 
 | 
 | ||||||
| @ -118,13 +144,21 @@ async def http_handler(request): | |||||||
|         data = await request.json() |         data = await request.json() | ||||||
|     except ValueError: |     except ValueError: | ||||||
|         return web.Response(status=400) |         return web.Response(status=400) | ||||||
| 
 |     # application/x-ndjson text/event-stream | ||||||
|     resp = web.StreamResponse(headers={'Content-Type': 'application/x-ndjson', 'Transfer-Encoding': 'chunked'}) |     if data['stream']: | ||||||
|  |         resp = web.StreamResponse(headers={'Content-Type': 'text/event-stream', 'Transfer-Encoding': 'chunked'}) | ||||||
|  |     else: | ||||||
|  |         resp = web.StreamResponse(headers={'Content-Type': 'application/json', 'Transfer-Encoding': 'chunked'}) | ||||||
|     await resp.prepare(request) |     await resp.prepare(request) | ||||||
|     import json |     import json | ||||||
|     try: |     try: | ||||||
|         async for result in server_manager.forward_to_websocket(request_id, data, path=request.path): |         async for result in server_manager.forward_to_websocket(request_id, data, path=request.path): | ||||||
|             await resp.write(json.dumps(result).encode()+b'\r\n') |             try: | ||||||
|  |                 a = 3 | ||||||
|  |                 #result = json.dumps(result) | ||||||
|  |             except: | ||||||
|  |                 pass | ||||||
|  |             await resp.write(result.encode()) | ||||||
|     except NoServerFoundException: |     except NoServerFoundException: | ||||||
|         await resp.write(json.dumps(dict(error="No server with that model found.",available=server_manager.get_models())).encode() + b'\r\n') |         await resp.write(json.dumps(dict(error="No server with that model found.",available=server_manager.get_models())).encode() + b'\r\n') | ||||||
|     await resp.write_eof() |     await resp.write_eof() | ||||||
| @ -141,7 +175,9 @@ async def not_found_handler(request): | |||||||
|     return web.json_response({"error":"not found"}) |     return web.json_response({"error":"not found"}) | ||||||
| 
 | 
 | ||||||
| async def models_handler(self): | async def models_handler(self): | ||||||
|     return web.json_response(server_manager.get_models()) |     print("Listing models.") | ||||||
|  |     response_json = json.dumps(server_manager.get_models(),indent=2) | ||||||
|  |     return web.Response(text=response_json,content_type="application/json") | ||||||
| 
 | 
 | ||||||
| app = web.Application() | app = web.Application() | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user