|
import asyncio
|
|
import aiohttp
|
|
import json
|
|
import logging
|
|
import argparse
|
|
from urllib.parse import urlparse, urlunparse
|
|
|
|
# Number of relay WebSocket connections opened when --concurrency is not given.
DEFAULT_CONCURRENCY = 4

# Local Ollama server used when --ollama_url is not given.
DEFAULT_OLLAMA_URL = 'http://localhost:11434'

# Configure the root logger only when this file is run as a script, so that
# importing the module (from tests or another program) does not override the
# importer's own logging setup.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
async def _iter_json_lines(lines):
    """Yield the decoded JSON value of each non-blank line from an async byte-line iterator.

    Blank lines are skipped. Lines that are not valid UTF-8 JSON are logged and
    skipped, so one bad line does not abort the whole stream.
    """
    async for raw_line in lines:
        try:
            text = raw_line.decode('utf-8').strip()
            if not text:
                continue
            value = json.loads(text)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            logging.warning(f"Skipping malformed stream line: {raw_line!r}")
            continue
        yield value


async def _handle_request(session: "aiohttp.ClientSession",
                          ws: "aiohttp.ClientWebSocketResponse",
                          raw: str, ollama_url: str) -> None:
    """Proxy one relayed request to Ollama and stream the reply back over *ws*.

    *raw* is expected to be a JSON object with ``request_id``, ``path`` and
    ``data`` keys. ``data`` is POSTed as JSON to ``path`` on the Ollama server.
    Each JSON line of the response is forwarded as
    ``{"request_id": ..., "data": <decoded line>}``. Malformed requests are
    logged and dropped so they cannot tear down the shared connection.
    """
    try:
        data = json.loads(raw)
    except ValueError as e:
        logging.error(f"Ignoring non-JSON message: {e}")
        return
    logging.info(f"Received data: {data}")

    try:
        request_id = data['request_id']
        path = data['path']
        payload = data['data']
    except (KeyError, TypeError) as e:
        logging.error(f"Ignoring malformed request: {e!r}")
        return
    if not isinstance(path, str):
        logging.error(f"Ignoring request with non-string path: {path!r}")
        return

    # Keep the scheme/host/port of the configured Ollama URL; only the path is
    # taken from the relayed request.
    # NOTE(review): the relay server chooses the path, so it can POST to any
    # Ollama endpoint (e.g. model management). Consider an allowlist if the
    # relay is not fully trusted.
    api_url = urlunparse(urlparse(ollama_url)._replace(path=path))

    async with session.post(api_url, json=payload) as response:
        if response.status != 200:
            logging.error(f"Failed to post data: {response.status}")
            return
        logging.info("Streaming response.")
        async for chunk in _iter_json_lines(response.content):
            await ws.send_json(dict(request_id=request_id, data=chunk))
        logging.info("Response complete.")


async def websocket_client(url: str, ollama_url: str) -> None:
    """Serve Ollama requests relayed over a single WebSocket connection.

    Connects to ``{url}/publish`` and sends the JSON body of
    ``GET {ollama_url}/api/tags`` as the first message. It then handles relayed
    text messages until the socket closes or reports an error.

    Connection-level failures are logged, not raised, so the caller can simply
    call this coroutine again to reconnect.

    Args:
        url: Base URL of the relay server.
        ollama_url: Base URL of the local Ollama API.
    """
    # aiohttp's default session timeout caps each request at 5 minutes in
    # total, which would abort long streamed generations. Only bound the
    # connect phase.
    timeout = aiohttp.ClientTimeout(total=None, sock_connect=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        try:
            async with session.ws_connect(f'{url}/publish') as ws:
                async with session.get(f'{ollama_url}/api/tags') as response:
                    if response.status != 200:
                        logging.error(f"Failed to fetch models: {response.status}")
                        return
                    models = await response.json()
                await ws.send_json(models)

                async for ws_msg in ws:
                    if ws_msg.type == aiohttp.WSMsgType.TEXT:
                        await _handle_request(session, ws, ws_msg.data, ollama_url)
                    elif ws_msg.type == aiohttp.WSMsgType.ERROR:
                        logging.error(f"WebSocket error occurred: {ws.exception()}")
                        break
        except aiohttp.ClientError as e:
            logging.error(f"Client error occurred: {e}")
        except Exception as e:
            # Log with traceback; the caller is expected to reconnect.
            logging.exception(f"An unexpected error occurred: {e}")
|
|
|
|
async def main(concurrency: int, ollama_url: str, url: str = 'https://ollama.molodetz.nl') -> None:
    """Keep *concurrency* relay connections alive, reconnecting each one independently.

    Args:
        concurrency: Number of WebSocket connections to maintain. A value
            below 1 starts no workers, so the coroutine returns immediately.
        ollama_url: Base URL of the local Ollama API.
        url: Base URL of the relay server.
    """
    async def _worker() -> None:
        # Each worker reconnects on its own. A dropped connection is replaced
        # after a short delay instead of waiting for every other one to end.
        while True:
            try:
                await websocket_client(url, ollama_url)
            except Exception as e:
                logging.error(f"Connection error: {e}")
            # Back off briefly so a failing server is not hammered.
            await asyncio.sleep(1)

    await asyncio.gather(*(_worker() for _ in range(concurrency)))
|
|
|
|
def validate_url(url: str) -> bool:
    """Report whether *url* has both a scheme and a network location.

    Only the presence of the two components is checked; the scheme itself is
    not restricted, so ``ftp://host`` is accepted while ``localhost:11434``
    (no ``//netloc`` part) is rejected.
    """
    components = urlparse(url)
    has_scheme = bool(components.scheme)
    has_netloc = bool(components.netloc)
    return has_scheme and has_netloc
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='WebSocket Client for Ollama API')
    # %(default)s keeps the help text in sync with the DEFAULT_* constants.
    parser.add_argument('--concurrency', type=int, default=DEFAULT_CONCURRENCY,
                        help='Number of concurrent WebSocket connections (default: %(default)s)')
    parser.add_argument('--ollama_url', type=str, default=DEFAULT_OLLAMA_URL,
                        help='Ollama API URL (default: %(default)s)')
    args = parser.parse_args()

    # With zero or negative concurrency no connection would ever be opened.
    if args.concurrency < 1:
        parser.error(f"--concurrency must be at least 1, got {args.concurrency}")

    if not validate_url(args.ollama_url):
        logging.error(f"Invalid Ollama URL: {args.ollama_url}")
        # SystemExit instead of the site-provided exit() builtin, which is
        # missing under `python -S` and in frozen executables.
        raise SystemExit(1)

    try:
        asyncio.run(main(args.concurrency, args.ollama_url))
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop this long-running client.
        logging.info("Interrupted, shutting down.")
|