diff --git a/src/snekbot/rpc.py b/src/snekbot/rpc.py index 6c65db1..53bd27b 100644 --- a/src/snekbot/rpc.py +++ b/src/snekbot/rpc.py @@ -1,6 +1,6 @@ # Written by retoor@molodetz.nl -# This code defines an RPC class that allows asynchronous communication over websockets, +# This code defines an RPC class that allows asynchronous communication over websockets, # including command execution and handling of asynchronous responses using Python's asyncio and websockets library. # Uses aiohttp for asynchronous HTTP network communication. @@ -9,53 +9,72 @@ import json -import pathlib import subprocess - +import pathlib import aiohttp +import uuid +import logging + +logger = logging.getLogger("snekbot.rpc") class RPC: class Response: def __init__(self, msg): + logger.debug("Initializing response.") if isinstance(msg, list): - self.list = msg + self.list = msg self.__dict__.update(msg) - + def __iter__(self): - for item in self.data: - yield item + logger.debug("Iterating sync trough result data.") + for k in self.__dict__.get("data", []): + yield k async def __aiter__(self): - for item in self.data: - yield item + logger.debug("Iterating async trough result data.") + for k in self.__dict__.get("data", []): + for k in self.__dict__.get("data", []): + yield k def __getitem__(self, name): + logger.debug("Getting result data: " + name + ".") return self.__dict__[name] - + def __setitem__(self, name, value): + logger.debug("Setting result data: " + name + ".") self.__dict__[name] = value def __str__(self): return json.dumps(self.__dict__, default=str, indent=2) def __init__(self, ws): + logger.debug("Initializing RPC.") self.ws = ws + self.current_call_id = None + + async def echo(self, data): + logger.debug("Sending echo to server: " + str(data)) + await self.ws.send_json(dict(method="echo",args=[data])) def __getattr__(self, name): async def method(*args, **kwargs): - payload = {"method": name, "args": args} + self.current_call_id = str(uuid.uuid4()) + payload = dict(method=name, args=args, kwargs=kwargs, callId=self.current_call_id) try: await self.ws.send_json(payload) - except Exception: - return None - + except Exception as ex: + print(ex) + async def returner(): - response = await self.ws.receive() - return self.Response(response.json()) - - return returner - + while True: + response = await self.ws.receive() + data = response.json() + if not data.get("callId") == self.current_call_id: + await self.echo(data) + continue + return self.Response(data) + return returner return method async def system(self, command): @@ -78,8 +97,8 @@ class RPC: try: path.unlink() - except Exception: - pass + except Exception as ex: + print("Error deleting temporary file:", ex) return response @@ -87,15 +106,27 @@ class RPC: while True: try: msg = await self.ws.receive() - except Exception: + except Exception as ex: + print("Error while receiving:", ex) break if msg.type == aiohttp.WSMsgType.CLOSED: break elif msg.type == aiohttp.WSMsgType.ERROR: + print("WebSocket error:", msg) break elif msg.type == aiohttp.WSMsgType.TEXT: + + if self.current_call_id and not msg.json().get('callId') != self.current_call_id: + await self.echo(msg.json()) + continue try: - return self.Response(msg.json()) - except Exception: + response = self.Response(msg.json()) + self.current_call_id = None + return response + except Exception as ex: + print("Error while parsing message:", msg, ex) break + else: + raise Exception("Unexpected message type.") + print("huh") return None