import asyncio
import aiohttp
from aiohttp import web
import uuid
import pathlib
import json

class OllamaServer:
    """One remote server reachable through a websocket.

    Requests are pushed to the remote side with ``ws.send_json`` and the
    replies, which arrive asynchronously tagged with a ``request_id``, are
    handed back to the waiting caller through one ``asyncio.Queue`` per
    request.
    """

    # Substrings whose presence in a text chunk ends the stream.
    _FINAL_MARKERS = ('"finish_reason":"stop"', 'data: [DONE]')

    def __init__(self, ws, models):
        """Store the websocket and the model descriptions it advertised.

        Args:
            ws: Websocket object; ``send_json`` and async iteration are used.
            models: Iterable of model descriptions; entries are expected to
                be dicts with a ``'name'`` key.
        """
        self.ws = ws
        # request_id -> asyncio.Queue of reply chunks for an in-flight request.
        self.queues = {}
        self.models = models

    @property
    def model_names(self):
        """Names of the advertised models.

        Entries that are not dicts or lack a ``'name'`` key are skipped so a
        single malformed entry cannot break every model lookup.
        """
        return [
            model['name']
            for model in (self.models or ())
            if isinstance(model, dict) and 'name' in model
        ]

    async def forward_to_http(self, request_id, message):
        """Hand a reply chunk received on the websocket to its waiting request.

        Chunks for a ``request_id`` with no in-flight request are dropped:
        ``forward_to_websocket`` registers its queue before sending the
        request, so an unknown id can only be a stale or bogus reply, and
        creating a queue for it would never be consumed.
        """
        print(message)
        queue = self.queues.get(request_id)
        if queue is None:
            return
        await queue.put(message)

    async def forward_to_websocket(self, request_id, message, path):
        """Send a request over the websocket and yield its reply chunks.

        Args:
            request_id: Identifier used to correlate replies.
            message: Request payload, sent as the ``data`` field.
            path: HTTP path of the originating request, sent as ``path``.

        Yields:
            Each reply chunk as received; falsy chunks are yielded as ``''``.
            Iteration stops after the chunk recognised as final.
        """
        queue = asyncio.Queue()
        self.queues[request_id] = queue
        try:
            print(path, request_id, message)
            await self.ws.send_json(dict(request_id=request_id, data=message, path=path))

            while True:
                chunk = await queue.get()
                yield chunk if chunk else ''
                print("CHUNK:", chunk)
                if self._is_final_chunk(chunk):
                    break
        finally:
            # Always release the queue (normal end, error, or abandoned
            # generator); otherwise self.queues grows with every request.
            if self.queues.get(request_id) is queue:
                del self.queues[request_id]

    @classmethod
    def _is_final_chunk(cls, chunk):
        """Return True when ``chunk`` marks the end of a reply stream."""
        if isinstance(chunk, dict):
            return bool(chunk.get('done'))
        if not isinstance(chunk, str):
            return False
        try:
            parsed = json.loads(chunk)
        except (ValueError, RecursionError):
            # Not a standalone JSON document; fall through to marker search.
            parsed = None
        if isinstance(parsed, dict) and parsed.get('done'):
            return True
        return any(marker in chunk for marker in cls._FINAL_MARKERS)

    async def serve(self):
        """Read the websocket until it ends, dispatching text frames.

        Each text frame is decoded as JSON and its ``data`` field is routed
        to the request named by ``request_id``. An error frame stops the loop.
        """
        async for msg in self.ws:
            if msg.type == web.WSMsgType.TEXT:
                data = msg.json()
                request_id = data['request_id']
                await self.forward_to_http(request_id, data['data'])
            elif msg.type == web.WSMsgType.ERROR:
                break

class NoServerFoundException(Exception):
    """Raised when no connected server offers the requested model.

    Derives from ``Exception`` rather than ``BaseException``: the latter is
    meant for interpreter-level signals such as ``KeyboardInterrupt`` and
    would slip past ordinary ``except Exception`` handlers.
    """

class ServerManager:
    """Registry of connected servers and router of requests to them."""

    def __init__(self):
        # Connected servers, in the order they will be searched.
        self.servers = []

    def add_server(self, server):
        """Register ``server`` at the end of the search order."""
        self.servers.append(server)

    def remove_server(self, server):
        """Unregister ``server``.

        Raises:
            ValueError: If ``server`` is not registered.
        """
        self.servers.remove(server)

    def get_server_by_model_name(self, model_name):
        """Return the first server listing ``model_name``, or ``None``."""
        for server in self.servers:
            if model_name in server.model_names:
                return server
        return None

    async def forward_to_websocket(self, request_id, message, path):
        """Route a request to a server offering ``message['model']``.

        The server list is rotated by one position before the lookup so that
        consecutive requests do not always start the search at the same
        server.

        Args:
            request_id: Identifier passed through to the chosen server.
            message: Request payload; its ``'model'`` entry selects the server.
            path: HTTP path passed through to the chosen server.

        Yields:
            The chunks produced by the chosen server's
            ``forward_to_websocket``.

        Raises:
            NoServerFoundException: If no server is registered, the payload
                names no model, or no server offers that model.
        """
        if self.servers:
            # Guarded: pop(0) on an empty list would raise IndexError and
            # bypass the NoServerFoundException path below.
            self.servers.append(self.servers.pop(0))

        model_name = message.get('model') if isinstance(message, dict) else None
        server = self.get_server_by_model_name(model_name)
        if server is None:
            raise NoServerFoundException(f"no server offers model {model_name!r}")

        async for msg in server.forward_to_websocket(request_id, message, path):
            yield msg

    def get_models(self):
        """Describe every model offered by the registered servers.

        Returns:
            A dict ``{'object': 'list', 'data': [...]}`` with one entry per
            distinct model name; ``instances`` counts how many times the name
            is listed across all servers.
        """
        models = {}
        for server in self.servers:
            for model_name in server.model_names:
                entry = models.setdefault(model_name, {
                    'id': model_name,
                    'instances': 0,
                    'owned_by': 'uberlama',
                    'object': 'model',
                    'created': 1743724800,  # hard-coded value, same for every model
                })
                entry['instances'] += 1
        return {
            'object': "list",
            'data': list(models.values()),
        }


# Single module-level registry shared by the websocket and HTTP handlers below.
server_manager = ServerManager()

async def websocket_handler(request):
    """Accept a server's websocket, register it, and relay its replies.

    The first frame must be a JSON object with a ``'models'`` list whose
    entries are dicts carrying a ``'name'``; anything else closes the
    connection without registering. Afterwards every text frame is decoded
    as JSON and its ``data`` is routed to the request named by
    ``request_id``. The server is removed from the registry when the
    connection ends for any reason.

    Returns:
        The ``web.WebSocketResponse`` for this connection.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    # NOTE(review): the reason for this one-second delay is not evident from
    # this file; kept unchanged.
    await asyncio.sleep(1)
    try:
        models = await ws.receive_json()
    except Exception:
        print("Non-JSON registration message, closing connection.")
        print(request.headers)
        await ws.close()
        return ws

    advertised = models.get('models') if isinstance(models, dict) else None
    if not isinstance(advertised, list) or not all(
        isinstance(model, dict) and 'name' in model for model in advertised
    ):
        # Registering a malformed model list would break later name lookups.
        print("Registration message has no valid 'models' list, closing connection.")
        await ws.close()
        return ws

    server = OllamaServer(ws, advertised)
    server_manager.add_server(server)

    try:
        async for msg in ws:
            if msg.type == web.WSMsgType.TEXT:
                data = msg.json()
                await server.forward_to_http(data['request_id'], data['data'])
            elif msg.type == web.WSMsgType.ERROR:
                print(f'WebSocket connection closed with exception: {ws.exception()}')
    except Exception:
        # Malformed frame or transport failure: drop the connection.
        # Cancellation is deliberately not caught so it can propagate.
        print("Closing...")
        await ws.close()
    finally:
        server_manager.remove_server(server)
    return ws

async def http_handler(request):
    """Relay a JSON POST body to a connected server and stream back the reply.

    The body must be a JSON object; its ``'stream'`` entry selects the
    response content type. Chunks yielded by
    ``server_manager.forward_to_websocket`` are written to the response as
    they arrive. When no server offers the requested model, a JSON error
    object listing the available models is written instead.

    Returns:
        A 400 ``web.Response`` for a body that is not a JSON object,
        otherwise the streamed ``web.StreamResponse``.
    """
    request_id = str(uuid.uuid4())
    try:
        data = await request.json()
    except ValueError:
        return web.Response(status=400)
    if not isinstance(data, dict):
        # Valid JSON but not an object: there are no fields to read.
        return web.Response(status=400)

    # application/x-ndjson text/event-stream
    content_type = 'text/event-stream' if data.get('stream') else 'application/json'
    resp = web.StreamResponse(headers={'Content-Type': content_type, 'Transfer-Encoding': 'chunked'})
    await resp.prepare(request)
    try:
        async for result in server_manager.forward_to_websocket(request_id, data, path=request.path):
            if isinstance(result, bytes):
                payload = result
            elif isinstance(result, str):
                payload = result.encode()
            else:
                # Chunk arrived as a decoded JSON value; serialise it again.
                payload = json.dumps(result).encode()
            await resp.write(payload)
    except NoServerFoundException:
        error_body = dict(error="No server with that model found.", available=server_manager.get_models())
        await resp.write(json.dumps(error_body).encode() + b'\r\n')
    await resp.write_eof()
    return resp

async def index_handler(request):
    """Serve ``index.html`` with the ``#client.py`` placeholder filled in.

    Both ``index.html`` and ``client.py`` are read on every request from
    paths relative to the current working directory; every occurrence of
    the placeholder is replaced by the text of ``client.py``.
    """
    page, client_source = (
        pathlib.Path(name).read_text() for name in ("index.html", "client.py")
    )
    rendered = page.replace("#client.py", client_source)
    return web.Response(content_type="text/html", text=rendered)

async def not_found_handler(request):
    """Catch-all handler: log the path and answer with a JSON 404.

    The status is set explicitly because ``web.json_response`` defaults to
    200, which would report an unknown path as a success.
    """
    print("not found:", request.path)
    return web.json_response({"error": "not found"}, status=404)

async def models_handler(self):
    """Return the model list from ``server_manager`` as indented JSON.

    The single argument is whatever the router passes to the handler; it is
    not used here.
    """
    print("Listing models.")
    catalogue = server_manager.get_models()
    return web.Response(
        content_type="application/json",
        text=json.dumps(catalogue, indent=2),
    )

# Application wiring. Routes are registered in order; the wildcard entry is
# last so it only handles paths nothing else matched.
app = web.Application()

app.router.add_get("/", index_handler)

_ROUTES = (
    ('GET', '/publish', websocket_handler),
    ('POST', '/api/chat', http_handler),
    ('POST', '/v1/chat', http_handler),
    ('POST', '/v1/completions', http_handler),
    ('POST', '/v1/chat/completions', http_handler),
    ('GET', '/models', models_handler),
    ('GET', '/v1/models', models_handler),
    ('*', '/{tail:.*}', not_found_handler),
)
for _method, _path, _handler in _ROUTES:
    app.router.add_route(_method, _path, _handler)

if __name__ == '__main__':
    # Run the server on port 1984 when executed as a script.
    web.run_app(app, port=1984)