Updated rpc.
This commit is contained in:
parent
04cd9489ac
commit
328534e5c9
@ -9,53 +9,72 @@
|
|||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import pathlib
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import pathlib
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("snekbot.rpc")
|
||||||
|
|
||||||
class RPC:
|
class RPC:
|
||||||
class Response:
|
class Response:
|
||||||
def __init__(self, msg):
|
def __init__(self, msg):
|
||||||
|
logger.debug("Initializing response.")
|
||||||
if isinstance(msg, list):
|
if isinstance(msg, list):
|
||||||
self.list = msg
|
self.list = msg
|
||||||
self.__dict__.update(msg)
|
self.__dict__.update(msg)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
for item in self.data:
|
logger.debug("Iterating sync trough result data.")
|
||||||
yield item
|
for k in self.__dict__.get("data", []):
|
||||||
|
yield k
|
||||||
|
|
||||||
async def __aiter__(self):
|
async def __aiter__(self):
|
||||||
for item in self.data:
|
logger.debug("Iterating async trough result data.")
|
||||||
yield item
|
for k in self.__dict__.get("data", []):
|
||||||
|
for k in self.__dict__.get("data", []):
|
||||||
|
yield k
|
||||||
|
|
||||||
def __getitem__(self, name):
|
def __getitem__(self, name):
|
||||||
|
logger.debug("Getting result data: " + name + ".")
|
||||||
return self.__dict__[name]
|
return self.__dict__[name]
|
||||||
|
|
||||||
def __setitem__(self, name, value):
|
def __setitem__(self, name, value):
|
||||||
|
logger.debug("Setting result data: " + name + ".")
|
||||||
self.__dict__[name] = value
|
self.__dict__[name] = value
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return json.dumps(self.__dict__, default=str, indent=2)
|
return json.dumps(self.__dict__, default=str, indent=2)
|
||||||
|
|
||||||
def __init__(self, ws):
|
def __init__(self, ws):
|
||||||
|
logger.debug("Initializing RPC.")
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
self.current_call_id = None
|
||||||
|
|
||||||
|
async def echo(self, data):
|
||||||
|
logger.debug("Sending echo to server: " + str(data))
|
||||||
|
await self.ws.send_json(dict(method="echo",args=[data]))
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
async def method(*args, **kwargs):
|
async def method(*args, **kwargs):
|
||||||
payload = {"method": name, "args": args}
|
self.current_call_id = str(uuid.uuid4())
|
||||||
|
payload = dict(method=name, args=args, kwargs=kwargs, callId=self.current_call_id)
|
||||||
try:
|
try:
|
||||||
await self.ws.send_json(payload)
|
await self.ws.send_json(payload)
|
||||||
except Exception:
|
except Exception as ex:
|
||||||
return None
|
print(ex)
|
||||||
|
|
||||||
async def returner():
|
async def returner():
|
||||||
|
while True:
|
||||||
response = await self.ws.receive()
|
response = await self.ws.receive()
|
||||||
return self.Response(response.json())
|
data = response.json()
|
||||||
|
if not data.get("callId") == self.current_call_id:
|
||||||
|
await self.echo(data)
|
||||||
|
continue
|
||||||
|
return self.Response(data)
|
||||||
return returner
|
return returner
|
||||||
|
|
||||||
return method
|
return method
|
||||||
|
|
||||||
async def system(self, command):
|
async def system(self, command):
|
||||||
@ -78,8 +97,8 @@ class RPC:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
path.unlink()
|
path.unlink()
|
||||||
except Exception:
|
except Exception as ex:
|
||||||
pass
|
print("Error deleting temporary file:", ex)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -87,15 +106,27 @@ class RPC:
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
msg = await self.ws.receive()
|
msg = await self.ws.receive()
|
||||||
except Exception:
|
except Exception as ex:
|
||||||
|
print("Error while receiving:", ex)
|
||||||
break
|
break
|
||||||
if msg.type == aiohttp.WSMsgType.CLOSED:
|
if msg.type == aiohttp.WSMsgType.CLOSED:
|
||||||
break
|
break
|
||||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||||
|
print("WebSocket error:", msg)
|
||||||
break
|
break
|
||||||
elif msg.type == aiohttp.WSMsgType.TEXT:
|
elif msg.type == aiohttp.WSMsgType.TEXT:
|
||||||
|
|
||||||
|
if self.current_call_id and not msg.json().get('callId') != self.current_call_id:
|
||||||
|
await self.echo(msg.json())
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
return self.Response(msg.json())
|
response = self.Response(msg.json())
|
||||||
except Exception:
|
self.current_call_id = None
|
||||||
|
return response
|
||||||
|
except Exception as ex:
|
||||||
|
print("Error while parsing message:", msg, ex)
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
raise Exception("Unexpected message type.")
|
||||||
|
print("huh")
|
||||||
return None
|
return None
|
||||||
|
Loading…
Reference in New Issue
Block a user