|
# Written by retoor@molodetz.nl
|
|
|
|
# This source code implements a WebSocket-based RPC (Remote Procedure Call) view that uses asynchronous methods to facilitate real-time communication and services for an authenticated user session in a web application. The class handles WebSocket events, user authentication, and various RPC interactions such as login, message retrieval, and more.
|
|
|
|
# External imports are used from the aiohttp library for the WebSocket response handling and the snek.system view for the BaseView class.
|
|
|
|
# MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions.
|
|
|
|
|
|
import asyncio
import inspect
import json
import logging
import re
import time
import traceback
from datetime import datetime

from aiohttp import web

from snek.system.model import now
from snek.system.profiler import Profiler
from snek.system.view import BaseView
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RPCView(BaseView):
    """WebSocket endpoint that exposes :class:`RPCView.RPCApi` to a client."""

    class RPCApi:
        """RPC surface bound to a single WebSocket connection.

        Every public coroutine defined here can be invoked by the remote
        client by name through :meth:`__call__`. Names that start with an
        underscore are never dispatched, so private helpers stay private.
        """

        # Word-bounded so that column names such as ``deleted_at`` or
        # ``updated_at`` and the SQL ``replace(...)`` function are not
        # mistaken for write statements.
        _WRITE_KEYWORDS = re.compile(
            r"\b(?:drop|alter|update|delete|insert|truncate|replace(?!\s*\())\b"
        )

        # Keys removed from every row returned by :meth:`query`.
        _PRIVATE_QUERY_FIELDS = ("email", "password", "message", "html")

        def __init__(self, view, ws):
            """Bind the API to the owning view and its WebSocket.

            Args:
                view: The view handling the request; supplies ``app``,
                    ``session`` and ``request``.
                ws: The WebSocket used to send responses.
            """
            self.view = view
            self.app = self.view.app
            self.services = self.app.services
            self.ws = ws
            self.user_session = {}
            self._scheduled = []
            self._finalize_task = None

        # ------------------------------------------------------------------
        # Private helpers (never reachable over RPC)
        # ------------------------------------------------------------------

        def _require_login(self):
            """Raise ``Exception("Not logged in")`` for anonymous sessions."""
            if not self.is_logged_in:
                raise Exception("Not logged in")

        async def _require_channel_member(self, channel_uid):
            """Return the caller's membership of ``channel_uid``.

            Raises:
                Exception: "Not logged in" for anonymous sessions, or
                    "Not allowed" when no membership record is found.
            """
            self._require_login()
            member = await self.services.channel_member.get(
                channel_uid=channel_uid, user_uid=self.user_uid
            )
            if not member:
                raise Exception("Not allowed")
            return member

        @staticmethod
        def _strip_secrets(user):
            """Return ``user.record`` without ``password`` and ``deleted_at``."""
            record = user.record
            for key in ("password", "deleted_at"):
                if key in record:
                    del record[key]
            return record

        @staticmethod
        def _public_user(record):
            """Reduce a user record to the fields shown in user listings."""
            return {
                "uid": record["uid"],
                "username": record["username"],
                "nick": record["nick"],
                "last_ping": record["last_ping"],
            }

        async def _session_ensure(self):
            """Create the per-user scratch dict if it does not exist yet.

            Keyed by ``self.user_uid`` so it always matches the key that
            :meth:`session_get` and :meth:`session_set` look up.
            """
            self.user_session.setdefault(self.user_uid, {"said_hello": False})

        async def _send_json(self, obj):
            """Serialise ``obj`` and send it; drop the socket on any failure.

            Values that ``json`` cannot encode are stringified
            (``default=str``).
            """
            try:
                await self.ws.send_str(json.dumps(obj, default=str))
            except Exception:
                logger.debug("Sending to websocket failed", exc_info=True)
                await self.services.socket.delete(self.ws)
                await self.ws.close()

        async def _schedule(self, seconds, call):
            """Send ``call`` to the current user after ``seconds`` seconds."""
            self._scheduled.append(call)
            try:
                await asyncio.sleep(seconds)
                await self.services.socket.send_to_user(self.user_uid, call)
            finally:
                # Always untrack the call, even if sleeping or sending failed.
                self._scheduled.remove(call)

        # ------------------------------------------------------------------
        # Session state
        # ------------------------------------------------------------------

        @property
        def user_uid(self):
            """Uid stored in the view session, or ``None`` when absent."""
            return self.view.session.get("uid")

        @property
        def request(self):
            """The request object of the owning view."""
            return self.view.request

        @property
        def is_logged_in(self):
            """Value of the session's ``logged_in`` flag (default ``False``)."""
            return self.view.session.get("logged_in", False)

        async def session_get(self, key, default):
            """Return ``key`` from the per-user scratch dict, or ``default``."""
            await self._session_ensure()
            return self.user_session[self.user_uid].get(key, default)

        async def session_set(self, key, value):
            """Store ``value`` under ``key`` in the per-user scratch dict."""
            await self._session_ensure()
            self.user_session[self.user_uid][key] = value
            return True

        # ------------------------------------------------------------------
        # Per-user database passthrough
        # ------------------------------------------------------------------

        async def db_insert(self, table_name, record):
            """Forward an insert to the db service on behalf of the caller."""
            self._require_login()
            return await self.services.db.insert(self.user_uid, table_name, record)

        async def db_update(self, table_name, record):
            """Forward an update to the db service on behalf of the caller."""
            self._require_login()
            return await self.services.db.update(self.user_uid, table_name, record)

        async def db_delete(self, table_name, record):
            """Forward a delete to the db service on behalf of the caller."""
            self._require_login()
            return await self.services.db.delete(self.user_uid, table_name, record)

        async def db_get(self, table_name, record):
            """Forward a single-record lookup to the db service."""
            self._require_login()
            return await self.services.db.get(self.user_uid, table_name, record)

        async def db_find(self, table_name, record):
            """Forward a multi-record lookup to the db service."""
            self._require_login()
            return await self.services.db.find(self.user_uid, table_name, record)

        async def db_upsert(self, table_name, record, keys):
            """Forward an upsert (matched on ``keys``) to the db service."""
            self._require_login()
            return await self.services.db.upsert(
                self.user_uid, table_name, record, keys
            )

        async def db_query(self, table_name, sql, args):
            """Forward a query to the db service.

            The previous signature had no ``sql`` parameter but used the
            name anyway, so every call failed with ``NameError``. The
            parameter order now mirrors the service call.
            """
            self._require_login()
            return await self.services.db.query(self.user_uid, table_name, sql, args)

        # ------------------------------------------------------------------
        # Users
        # ------------------------------------------------------------------

        async def login(self, username, password):
            """Authenticate, populate the session and subscribe the socket.

            Returns:
                The user's record without ``password`` and ``deleted_at``.

            Raises:
                Exception: "Invalid username or password".
            """
            if not await self.services.user.validate_login(username, password):
                raise Exception("Invalid username or password")

            user = await self.services.user.get(username=username)
            uid = user["uid"]
            session = self.view.session
            session["uid"] = uid
            session["logged_in"] = True
            session["username"] = user["username"]
            session["user_nick"] = user["nick"]
            record = self._strip_secrets(user)

            await self.services.socket.add(self.ws, uid)
            async for subscription in self.services.channel_member.find(
                user_uid=uid,
                deleted_at=None,
                is_banned=False,
            ):
                await self.services.socket.subscribe(
                    self.ws, subscription["channel_uid"], uid
                )
            return record

        async def search_user(self, query):
            """Return the usernames matching ``query``."""
            self._require_login()
            matches = await self.services.user.search(query)
            return [match["username"] for match in matches]

        async def get_user(self, user_uid):
            """Return a user's record; a falsy ``user_uid`` means the caller.

            ``password`` and ``deleted_at`` are always removed. ``email`` is
            removed unless the caller is looking at their own record.

            Raises:
                Exception: "User not found" when the lookup yields nothing.
            """
            self._require_login()
            user = await self.services.user.get(uid=user_uid or self.user_uid)
            if not user:
                raise Exception("User not found")
            record = self._strip_secrets(user)
            # Compare against the *caller*; comparing the fetched user with
            # the uid it was fetched by is always equal and hid nothing.
            if user["uid"] != self.user_uid and "email" in record:
                del record["email"]
            return record

        async def get_online_users(self, channel_uid):
            """Return the channel's recent users sorted by ``nick``."""
            self._require_login()
            users = [
                record
                async for record in self.services.channel.get_recent_users(channel_uid)
            ]
            users.sort(key=lambda record: record["nick"])
            return users

        async def get_recent_users(self, channel_uid):
            """Return uid/username/nick/last_ping of the channel's recent users."""
            self._require_login()
            return [
                self._public_user(record)
                async for record in self.services.channel.get_recent_users(channel_uid)
            ]

        async def get_users(self, channel_uid):
            """Return uid/username/nick/last_ping of all the channel's users."""
            self._require_login()
            return [
                self._public_user(record)
                async for record in self.services.channel.get_users(channel_uid)
            ]

        async def ping(self, callId, *args):
            """Refresh the caller's ``last_ping`` and echo ``args`` back."""
            if self.user_uid:
                user = await self.services.user.get(uid=self.user_uid)
                if user:
                    user["last_ping"] = now()
                    await self.services.user.save(user)
            return {"pong": args}

        # ------------------------------------------------------------------
        # Channels and messages
        # ------------------------------------------------------------------

        async def set_typing(self, channel_uid, color=None):
            """Broadcast a ``set_typing`` event for the caller to a channel.

            Args:
                channel_uid: Channel to broadcast to.
                color: Colour to announce; defaults to the user's own colour.
            """
            self._require_login()
            user = await self.services.user.get(self.user_uid)
            payload = {
                "event": "set_typing",
                "user_uid": user["uid"],
                "username": user["username"],
                "nick": user["nick"],
                "channel_uid": channel_uid,
                "color": color or user["color"],
            }
            return await self.services.socket.broadcast(
                channel_uid,
                {
                    # The envelope carries the real channel rather than a
                    # hard-coded uid, matching update_message_text.
                    "channel_uid": channel_uid,
                    "event": "set_typing",
                    "data": payload,
                },
            )

        async def mark_as_read(self, channel_uid):
            """Mark ``channel_uid`` as read for the caller."""
            self._require_login()
            await self.services.channel_member.mark_as_read(channel_uid, self.user_uid)
            return True

        async def get_messages(self, channel_uid, offset=0, timestamp=None):
            """Return one page of channel messages as extended dicts.

            NOTE(review): only login is checked here, not channel
            membership - confirm that this is intended.
            """
            self._require_login()
            page = await self.services.channel_message.offset(
                channel_uid, offset or 0, timestamp or None
            )
            messages = []
            for message in page:
                messages.append(
                    await self.services.channel_message.to_extended_dict(message)
                )
            return messages

        async def get_channels(self):
            """Return the caller's non-banned channel subscriptions.

            ``color`` is the colour of the author of the channel's last
            message, or ``None`` when the channel has no last message.
            """
            self._require_login()
            channels = []
            async for subscription in self.services.channel_member.find(
                user_uid=self.user_uid, is_banned=False
            ):
                channel = await self.services.channel.get(
                    uid=subscription["channel_uid"]
                )
                color = None
                last_message = await channel.get_last_message()
                if last_message:
                    author = await last_message.get_user()
                    color = author["color"]
                channels.append(
                    {
                        "name": subscription["label"],
                        "uid": subscription["channel_uid"],
                        "tag": channel["tag"],
                        "new_count": subscription["new_count"],
                        "is_moderator": subscription["is_moderator"],
                        "is_read_only": subscription["is_read_only"],
                        "color": color,
                    }
                )
            return channels

        async def send_message(self, channel_uid, message, is_final=True):
            """Send a chat message and return the uid of the stored message."""
            self._require_login()
            sent = await self.services.chat.send(
                self.user_uid, channel_uid, message, is_final
            )
            return sent["uid"]

        async def finalize_message(self, message_uid):
            """Finalize one of the caller's own messages.

            Returns:
                ``False`` when the message does not exist, ``True`` otherwise.

            Raises:
                Exception: "Not allowed" when the message belongs to
                    somebody else.
            """
            self._require_login()
            message = await self.services.channel_message.get(message_uid)
            if not message:
                return False
            if message["user_uid"] != self.user_uid:
                raise Exception("Not allowed")
            if not message["is_final"]:
                await self.services.chat.finalize(message["uid"])
            return True

        async def update_message_text(self, message_uid, text):
            """Replace the text of one of the caller's recent messages.

            An empty ``text`` sets ``deleted_at``; any other text clears it.
            Messages last updated more than five seconds ago are refused
            with an error dict instead of an exception.

            Raises:
                Exception: "Message not found", or "Not allowed" when the
                    message belongs to somebody else.
            """
            self._require_login()
            async with self.app.no_save():
                message = await self.services.channel_message.get(message_uid)
                if not message:
                    raise Exception("Message not found")
                if message["user_uid"] != self.user_uid:
                    raise Exception("Not allowed")

                age = message.get_seconds_since_last_update()
                if age > 5:
                    return {
                        "error": "Message too old",
                        "seconds_since_last_update": age,
                        "success": False,
                    }

                message["message"] = text
                message["deleted_at"] = None if text else now()
                await self.services.channel_message.save(message)

                data = message.record
                data["text"] = message["message"]
                data["message_uid"] = message_uid

                await self.services.socket.broadcast(
                    message["channel_uid"],
                    {
                        "channel_uid": message["channel_uid"],
                        "event": "update_message_text",
                        # Send the augmented dict; re-reading message.record
                        # could drop the two keys added above.
                        "data": data,
                    },
                )
                return {"success": True}

        async def clear_channel(self, channel_uid):
            """Reset a channel's history start and clear its messages.

            Raises:
                Exception: "Not allowed" for non-admins, "Channel not found"
                    for an unknown channel.
            """
            self._require_login()
            user = await self.services.user.get(uid=self.user_uid)
            if not user["is_admin"]:
                raise Exception("Not allowed")
            channel = await self.services.channel.get(uid=channel_uid)
            if not channel:
                raise Exception("Channel not found")
            channel["history_start"] = datetime.now()
            await self.services.channel.save(channel)
            return await self.services.channel_message.clear(channel_uid)

        async def stars_render(self, channel_uid, message):
            """Send a ``stars_render`` event to every recent user of a channel.

            A failure for one recipient is logged and does not stop delivery
            to the others.
            """
            event = {
                "event": "stars_render",
                "data": {"channel_uid": channel_uid, "message": message},
            }
            for user in await self.get_online_users(channel_uid):
                try:
                    await self.services.socket.send_to_user(user["uid"], event)
                except Exception:
                    logger.exception("stars_render delivery failed")

        # ------------------------------------------------------------------
        # Containers
        # ------------------------------------------------------------------

        async def write_container(self, channel_uid, content, timeout=3):
            """Write ``content`` to the channel container and collect stdout.

            Output is collected until the container has produced nothing for
            ``timeout`` seconds.

            Returns:
                The collected output decoded as UTF-8; undecodable bytes are
                dropped.

            Raises:
                Exception: "Not allowed" when the caller is not a member.
            """
            await self._require_channel_member(channel_uid)
            containers = self.services.container
            container_name = await containers.get_container_name(channel_uid)

            # assumes stdout events deliver bytes - TODO confirm
            output = bytearray()
            last_output = time.monotonic()

            async def on_stdout(data):
                nonlocal last_output
                last_output = time.monotonic()
                output.extend(data)
                return True

            await containers.add_event_listener(container_name, "stdout", on_stdout)
            try:
                await containers.write_stdin(channel_uid, content)
                while time.monotonic() - last_output < timeout:
                    await asyncio.sleep(0.1)
            finally:
                # Detach the listener even when writing fails or the task is
                # cancelled, so it cannot outlive this call.
                await containers.remove_event_listener(
                    container_name, "stdout", on_stdout
                )
            return bytes(output).decode("utf-8", "ignore")

        async def get_container(self, channel_uid):
            """Describe the channel's container, or return ``None`` if absent.

            Raises:
                Exception: "Not allowed" when the caller is not a member.
            """
            await self._require_channel_member(channel_uid)
            container = await self.services.container.get(channel_uid)
            if not container:
                return None
            limits = container["deploy"]["resources"]["limits"]
            return {
                "name": await self.services.container.get_container_name(channel_uid),
                "cpus": limits["cpus"],
                "memory": limits["memory"],
                "image": "ubuntu:latest",
                "volumes": [],
                "status": container["status"],
            }

        async def start_container(self, channel_uid):
            """Start the channel's container (members only)."""
            await self._require_channel_member(channel_uid)
            return await self.services.container.start(channel_uid)

        async def stop_container(self, channel_uid):
            """Stop the channel's container (members only)."""
            await self._require_channel_member(channel_uid)
            return await self.services.container.stop(channel_uid)

        async def get_container_status(self, channel_uid):
            """Return the status of the channel's container (members only)."""
            await self._require_channel_member(channel_uid)
            return await self.services.container.get_status(channel_uid)

        # ------------------------------------------------------------------
        # Miscellaneous
        # ------------------------------------------------------------------

        async def echo(self, obj):
            """Send ``obj`` straight back over the socket.

            Returns the ``"noresponse"`` sentinel so that :meth:`__call__`
            does not send a second, wrapped reply.
            """
            await self.ws.send_json(obj)
            return "noresponse"

        async def query(self, *args):
            """Run a read-only SQL statement and return the sanitized rows.

            Only a single statement starting with ``select`` or ``with`` is
            accepted, and it must not contain a write keyword. The keys in
            ``_PRIVATE_QUERY_FIELDS`` are removed from every returned row.

            Raises:
                Exception: "Not allowed" for anything that is not a plain
                    read statement.
            """
            self._require_login()
            if not args or not isinstance(args[0], str):
                raise Exception("Not allowed")

            statement = args[0].strip().rstrip(";").strip()
            lowered = statement.lower()
            if (
                not lowered.startswith(("select", "with"))
                or ";" in lowered  # stacked statements
                or self._WRITE_KEYWORDS.search(lowered)
            ):
                raise Exception("Not allowed")

            records = []
            async for row in self.services.channel.query(statement):
                record = dict(row)
                for key in self._PRIVATE_QUERY_FIELDS:
                    record.pop(key, None)
                records.append(record)
            # Return the sanitized rows; re-running the query here would
            # hand the private fields straight back to the client.
            return records

        async def __call__(self, data):
            """Dispatch one decoded RPC request and send the reply.

            ``data`` is expected to carry ``callId``, ``method`` and an
            optional ``args`` list. Dots in the method name are treated as
            underscores. Private names, inherited attributes and anything
            that is not a coroutine method are answered with "Not allowed".
            """
            call_id = None  # bound up front so the error reply can use it
            try:
                call_id = data.get("callId")
                method_name = data.get("method")
                if not isinstance(method_name, str):
                    raise Exception("Method not found")
                attr_name = method_name.replace(".", "_")
                if attr_name.startswith("_"):
                    raise Exception("Not allowed")
                args = data.get("args") or []

                method = getattr(self, attr_name, None)
                if hasattr(super(), attr_name) or not inspect.iscoroutinefunction(
                    method
                ):
                    return await self._send_json(
                        {"callId": call_id, "data": "Not allowed"}
                    )

                success = True
                try:
                    result = await method(*args)
                except Exception as ex:
                    # NOTE(review): the traceback is returned to the client;
                    # confirm this is acceptable outside development.
                    result = {"exception": str(ex), "traceback": traceback.format_exc()}
                    success = False
                    logger.exception(ex)

                if result != "noresponse":
                    await self._send_json(
                        {"callId": call_id, "success": success, "data": result}
                    )
            except Exception as ex:
                logger.exception(ex)
                await self._send_json(
                    {"callId": call_id, "success": False, "data": str(ex)}
                )

    async def get(self):
        """Upgrade the request to a WebSocket and serve RPC calls on it.

        Logged-in sessions are registered with the socket service and
        subscribed to their non-banned, non-deleted channel memberships
        before the message loop starts.

        Returns:
            The ``WebSocketResponse`` once the connection has ended.
        """
        scheduled = []

        async def schedule(uid, seconds, call):
            """Send ``call`` to ``uid`` after ``seconds`` seconds."""
            scheduled.append(call)
            await asyncio.sleep(seconds)
            await self.services.socket.send_to_user(uid, call)
            scheduled.remove(call)

        ws = web.WebSocketResponse()
        await ws.prepare(self.request)

        session = self.request.session
        uid = session.get("uid")
        if session.get("logged_in"):
            await self.services.socket.add(ws, uid)
            async for subscription in self.services.channel_member.find(
                user_uid=uid,
                deleted_at=None,
                is_banned=False,
            ):
                await self.services.socket.subscribe(
                    ws, subscription["channel_uid"], uid
                )

        if not scheduled and self.request.app.uptime_seconds < 5:
            # NOTE(review): both calls are awaited inline, so the message
            # loop below does not start until the 15 second delay is over.
            await schedule(
                uid,
                0,
                {"event": "refresh", "data": {"message": "Finishing deployment"}},
            )
            await schedule(
                uid,
                15,
                {"event": "deployed", "data": {"uptime": self.request.app.uptime}},
            )

        rpc = RPCView.RPCApi(self, ws)
        # NOTE(review): the socket is only removed from the socket service
        # on a handler error, not when the loop ends normally - confirm the
        # service cleans up closed sockets itself.
        async for msg in ws:
            if msg.type != web.WSMsgType.TEXT:
                # ERROR and CLOSE frames need no handling here.
                continue
            try:
                await rpc(msg.json())
            except Exception:
                logger.exception("Deleting socket")
                await self.services.socket.delete(ws)
                break
        return ws
|