# Written by retoor@molodetz.nl

# This code creates a server using asyncio and aiohttp that manages websocket and HTTP connections to forward messages between them.

# Used Imports: asyncio, aiohttp, uuid, pathlib, json

# The MIT License (MIT)
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import asyncio
import aiohttp
from aiohttp import web
import uuid
import pathlib 

class OllamaServer:
    """Bridge between one websocket connection and pending HTTP requests.

    Every HTTP request is identified by a ``request_id``. Its payload is sent
    over the websocket, and the replies received for that id are buffered in a
    per-request ``asyncio.Queue`` until the HTTP side consumes them.
    """

    def __init__(self, ws, models):
        """Store the websocket and the announced model list.

        Args:
            ws: Websocket object; ``send_json`` is awaited on it and it is
                iterated asynchronously by ``serve``.
            models: Iterable of mappings, each providing a ``'name'`` key.
        """
        self.ws = ws
        # request_id -> asyncio.Queue of chunks waiting for the HTTP consumer.
        self.queues = {}
        self.models = models
        print("New OllamaServer created")
        print(self.model_names)

    @property
    def model_names(self):
        """Return the ``'name'`` value of every entry in ``self.models``."""
        return [model['name'] for model in self.models]

    async def forward_to_http(self, request_id, message):
        """Hand a websocket reply to the HTTP request waiting on ``request_id``.

        Replies for an id without an active request are discarded. Queues are
        created only by ``forward_to_websocket`` (which would replace any
        pre-existing queue anyway), so buffering such a reply could never be
        consumed and would only leak memory.
        """
        queue = self.queues.get(request_id)
        if queue is None:
            return
        await queue.put(message)

    async def forward_to_websocket(self, request_id, message, path):
        """Send a request over the websocket and yield the replies for it.

        Args:
            request_id: Identifier used to route replies back to this call.
            message: Request payload, sent under the ``data`` key.
            path: Request path, sent under the ``path`` key.

        Yields:
            Each reply chunk queued for ``request_id``, up to and including
            the first chunk whose ``'done'`` entry is truthy.

        Raises:
            KeyError: If a reply chunk has no ``'done'`` key (raised after
                that chunk has been yielded).
        """
        queue = asyncio.Queue()
        self.queues[request_id] = queue
        try:
            await self.ws.send_json(dict(request_id=request_id, data=message, path=path))

            while True:
                chunk = await queue.get()
                yield chunk
                if chunk['done']:
                    break
        finally:
            # Always release the queue, including on errors or when the
            # consumer abandons the stream; otherwise it would stay in
            # self.queues for the lifetime of the connection.
            self.queues.pop(request_id, None)

    async def serve(self):
        """Read websocket messages and route text frames to their requests.

        Each text frame is parsed as JSON and must provide ``'request_id'``
        and ``'data'``. The loop stops on an error frame or when the
        websocket iteration ends.
        """
        async for msg in self.ws:
            if msg.type == web.WSMsgType.TEXT:
                data = msg.json()
                request_id = data['request_id']
                await self.forward_to_http(request_id, data['data'])
            elif msg.type == web.WSMsgType.ERROR:
                break
   
class ServerManager:
    """Keep the list of registered servers and hand out requests round-robin."""

    def __init__(self):
        # Registered servers; the front element receives the next request.
        self.servers = []

    def add_server(self, server):
        """Append ``server`` to the end of the rotation."""
        self.servers.append(server)

    def remove_server(self, server):
        """Remove ``server`` from the rotation.

        Raises:
            ValueError: If ``server`` is not registered.
        """
        self.servers.remove(server)

    async def forward_to_websocket(self, request_id, message, path):
        """Forward a request to the next server and yield its replies.

        The server at the front of the list is moved to the back before the
        request is forwarded, so consecutive calls rotate through the servers.

        Args:
            request_id: Identifier passed through to the chosen server.
            message: Request payload passed through to the chosen server.
            path: Request path passed through to the chosen server.

        Yields:
            Every item produced by the chosen server's
            ``forward_to_websocket``.

        Raises:
            IndexError: If no server is registered.
        """
        if not self.servers:
            # Same exception type as the bare list.pop(0) would raise, but
            # with a message that names the actual problem.
            raise IndexError("no Ollama servers are connected")

        server = self.servers.pop(0)
        self.servers.append(server)
        async for msg in server.forward_to_websocket(request_id, message, path):
            yield msg

# Module-level manager instance used by the websocket and HTTP handlers below.
server_manager = ServerManager()



async def websocket_handler(request):
    """Accept a websocket connection and register it as a server.

    The first frame received must be JSON with a ``'models'`` entry. After
    that, every text frame must be JSON providing ``'request_id'`` and
    ``'data'``; it is routed to the HTTP request waiting on that id.

    Args:
        request: Incoming request to upgrade to a websocket.

    Returns:
        The ``web.WebSocketResponse`` once the connection has ended.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    models = await ws.receive_json()

    server = OllamaServer(ws, models['models'])
    server_manager.add_server(server)

    try:
        async for msg in ws:
            if msg.type == web.WSMsgType.TEXT:
                data = msg.json()
                await server.forward_to_http(data['request_id'], data['data'])
            elif msg.type == web.WSMsgType.ERROR:
                print(f'WebSocket connection closed with exception: {ws.exception()}')
    finally:
        # Deregister even when the loop raises (malformed frame, missing key,
        # cancellation); otherwise a dead connection would stay in the
        # rotation and keep receiving requests.
        server_manager.remove_server(server)
    return ws


async def http_handler(request):
    """Forward a JSON request to a connected server and stream the replies.

    Each reply is written to the response as one JSON document per line.

    Args:
        request: Incoming HTTP request whose body must be JSON.

    Returns:
        A 400 response if the body is not valid JSON, a 503 JSON response if
        no server is connected, otherwise the streamed response.
    """
    import json  # kept function-local, as in the original

    request_id = str(uuid.uuid4())
    try:
        data = await request.json()
    except ValueError:
        return web.Response(status=400)

    # Fail before the streaming response is prepared: once headers are sent,
    # an empty server pool could only surface as a broken stream.
    # NOTE: a server may still disconnect between this check and dispatch.
    if not server_manager.servers:
        return web.json_response({'error': 'no Ollama server is connected'}, status=503)

    resp = web.StreamResponse(headers={'Content-Type': 'application/x-ndjson','Transfer-Encoding': 'chunked'})
    await resp.prepare(request)
    async for result in server_manager.forward_to_websocket(request_id, data, path=request.path):
        await resp.write(json.dumps(result).encode() + b'\n')
    await resp.write_eof()
    return resp

async def index_handler(request):
    """Serve ``index.html`` with the source of ``client.py`` embedded.

    Both files are read on every request, using paths relative to the current
    working directory. Every occurrence of the ``#client.py`` placeholder in
    the page is replaced with the client script's text.
    """
    page = pathlib.Path("index.html").read_text()
    rendered = page.replace("#client.py", pathlib.Path("client.py").read_text())
    return web.Response(content_type="text/html", text=rendered)

# Application object on which the routes below are registered.
app = web.Application()

# "/" serves the index page.
app.router.add_get("/", index_handler)
# "/publish" (GET) is handled by websocket_handler.
app.router.add_route('GET', '/publish', websocket_handler)
# "/api/chat" (POST) is handled by http_handler.
app.router.add_route('POST', '/api/chat', http_handler)

# When executed as a script, run the application on port 8080.
if __name__ == '__main__':
    web.run_app(app, port=8080)