Updates.
parent 0f88a658b1
commit f1c4553038
client.py (12 changed lines)
@@ -29,15 +29,23 @@ async def websocket_client(url: str, ollama_url: str) -> None:
                     api_url = urlunparse(urlparse(ollama_url)._replace(path=data['path']))

                     async with session.post(api_url, json=data['data']) as response:
+                        print(response)
                         if response.status != 200:
                             logging.error(f"Failed to post data: {response.status}")
                             continue

                         logging.info(f"Streaming response.")
                         async for msg in response.content:
-                            msg = json.loads(msg.decode('utf-8'))
+                            #first_index = msg.find(b"{")
+                            #msg = msg[first_index:]
+                            #last_index = msg.rfind(b"}")
+                            #msg = msg[:last_index+1]
+                            #if not msg:
+                            #    continue
+                            #msg = json.loads(msg.decode('utf-8'))
                             await ws.send_json(dict(
                                 request_id=request_id,
-                                data=msg
+                                data=msg.decode()
                             ))
                         logging.info(f"Response complete.")
                 elif msg.type == aiohttp.WSMsgType.ERROR:
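Note: with data=msg.decode(), the client now forwards each raw NDJSON line from the Ollama response instead of a parsed dict, so whatever reads the 'data' field downstream receives a string. A minimal sketch of parsing such a chunk on the receiving side (illustration only; parse_chunk is not part of this codebase):

    import json

    def parse_chunk(chunk: str):
        # One forwarded chunk is one NDJSON line; blank lines act as keep-alives.
        chunk = chunk.strip()
        if not chunk:
            return None
        return json.loads(chunk)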
server.py (72 changed lines)
@@ -3,6 +3,7 @@ import aiohttp
 from aiohttp import web
 import uuid
 import pathlib
+import json

 class OllamaServer:
     def __init__(self, ws, models):
@@ -27,13 +28,22 @@ class OllamaServer:

         while True:
             chunk = await self.queues[request_id].get()
-            yield chunk
+            if chunk:
+                yield chunk
+            if not chunk:
+                yield '\n'
+            print("CHUNK:", chunk)
+            #try:
+            #    yield json.loads(chunk)
+            #except:
+            #    yield chunk
             if not 'done' in chunk:
                 break
+            if 'stop' in chunk:
+                break
             if chunk['done']:
                 break


     async def serve(self):
         async for msg in self.ws:
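Note: since chunks may now arrive as raw strings rather than dicts (see the client change above), 'done' in chunk and the new 'stop' in chunk are substring tests on a str but key tests on a dict, and chunk['done'] only works in the dict case. A standalone illustration (not from this repo):

    chunk_str = '{"response": "hi", "done": false}'
    chunk_dict = {"response": "hi", "done": False}

    print('done' in chunk_str)   # True: substring match on the raw NDJSON line
    print('done' in chunk_dict)  # True: key lookup
    print(chunk_dict['done'])    # False
    # chunk_str['done'] would raise TypeError: string indices must be integers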
@@ -68,6 +78,7 @@ class ServerManager:
             server = self.servers.pop(0)
             self.servers.append(server)
         server = self.get_server_by_model_name(message['model'])
+
         if not server:
             raise NoServerFoundException
         async for msg in server.forward_to_websocket(request_id, message, path):
@@ -83,11 +94,14 @@ class ServerManager:
                 models[model_name] = {}
                 models[model_name]['id'] = model_name
                 models[model_name]['instances'] = 0
-                models[model_name]['owner'] = 'public'
+                models[model_name]['owned_by'] = 'uberlama'
                 models[model_name]['object'] = 'model'
-                models[model_name]['created'] = 0
+                models[model_name]['created'] = 1743724800
             models[model_name]['instances'] += 1
-        return list(models.values())
+        return {
+            'object':"list",
+            'data':list(models.values())
+        }


 server_manager = ServerManager()
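Note: get_models() now returns an OpenAI-style list object instead of a bare list. A sketch of the resulting shape, using the defaults set above (the model name 'llama3' is illustrative, not taken from the diff):

    {
        'object': "list",
        'data': [
            {'id': 'llama3', 'instances': 1, 'owned_by': 'uberlama',
             'object': 'model', 'created': 1743724800},
        ],
    }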
@@ -95,19 +109,31 @@ server_manager = ServerManager()
 async def websocket_handler(request):
     ws = web.WebSocketResponse()
     await ws.prepare(request)
-    models = await ws.receive_json()
+    await asyncio.sleep(1)
+    try:
+        models = await ws.receive_json()
+    except:
+        print("Non JSON, sleeping three seconds.")
+        print(request.headers)
+
+        #await ws.send_json("YOUR DATA IS CORRUPT. EXIT PROCESS!")
+        await ws.close()
+        return ws
     server = OllamaServer(ws, models['models'])
     server_manager.add_server(server)

-    async for msg in ws:
-        if msg.type == web.WSMsgType.TEXT:
-            data = msg.json()
-            await server.forward_to_http(data['request_id'], data['data'])
-        elif msg.type == web.WSMsgType.ERROR:
-            print(f'WebSocket connection closed with exception: {ws.exception()}')
+    try:
+        async for msg in ws:
+            if msg.type == web.WSMsgType.TEXT:
+                data = msg.json()
+                #print(data)
+                await server.forward_to_http(data['request_id'], data['data'])
+            elif msg.type == web.WSMsgType.ERROR:
+                print(f'WebSocket connection closed with exception: {ws.exception()}')
+    except:
+        print("Closing...")
+        await ws.close()
     server_manager.remove_server(server)
     return ws

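Note: aiohttp's receive_json() typically surfaces a TypeError for non-text frames and a ValueError (json.JSONDecodeError) for malformed text, so a narrower guard is possible; the sketch below is an illustration under that assumption, not what this commit does:

    try:
        models = await ws.receive_json()
    except (TypeError, ValueError):
        await ws.close()
        return ws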
@@ -118,13 +144,21 @@ async def http_handler(request):
         data = await request.json()
     except ValueError:
         return web.Response(status=400)
-    resp = web.StreamResponse(headers={'Content-Type': 'application/x-ndjson', 'Transfer-Encoding': 'chunked'})
+    # application/x-ndjson text/event-stream
+    if data['stream']:
+        resp = web.StreamResponse(headers={'Content-Type': 'text/event-stream', 'Transfer-Encoding': 'chunked'})
+    else:
+        resp = web.StreamResponse(headers={'Content-Type': 'application/json', 'Transfer-Encoding': 'chunked'})
     await resp.prepare(request)
     import json
     try:
         async for result in server_manager.forward_to_websocket(request_id, data, path=request.path):
-            await resp.write(json.dumps(result).encode()+b'\r\n')
+            try:
+                a = 3
+                #result = json.dumps(result)
+            except:
+                pass
+            await resp.write(result.encode())
     except NoServerFoundException:
         await resp.write(json.dumps(dict(error="No server with that model found.",available=server_manager.get_models())).encode() + b'\r\n')
     await resp.write_eof()
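Note: the Content-Type of the proxied response now depends on the "stream" flag in the request body: text/event-stream when streaming, application/json otherwise (both still sent chunked). A small client-side check, as a sketch (URL, port, and payload are placeholders, not taken from this commit):

    import asyncio
    import aiohttp

    async def check():
        payload = {"model": "llama3", "prompt": "hi", "stream": True}
        async with aiohttp.ClientSession() as session:
            async with session.post("http://localhost:8080/api/generate", json=payload) as resp:
                print(resp.headers.get("Content-Type"))  # expected: text/event-stream

    asyncio.run(check())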
@@ -141,7 +175,9 @@ async def not_found_handler(request):
     return web.json_response({"error":"not found"})

 async def models_handler(self):
-    return web.json_response(server_manager.get_models())
+    print("Listing models.")
+    response_json = json.dumps(server_manager.get_models(),indent=2)
+    return web.Response(text=response_json,content_type="application/json")

 app = web.Application()
