This commit is contained in:
retoor 2025-05-09 01:33:41 +02:00
parent 3c0fea6812
commit 02a0253c1d
21 changed files with 148 additions and 74 deletions

View File

@ -1,32 +1,37 @@
import argparse import argparse
from aiohttp import web from aiohttp import web
from snek.app import Application from snek.app import Application
def main(): def main():
parser = argparse.ArgumentParser(description="Run the web application.") parser = argparse.ArgumentParser(description="Run the web application.")
parser.add_argument( parser.add_argument(
"--port", "--port",
type=int, type=int,
default=8081, default=8081,
help="Port to run the application on (default: 8081)" help="Port to run the application on (default: 8081)",
) )
parser.add_argument( parser.add_argument(
"--host", "--host",
type=str, type=str,
default="0.0.0.0", default="0.0.0.0",
help="Host to run the application on (default: 0.0.0.0)" help="Host to run the application on (default: 0.0.0.0)",
) )
parser.add_argument( parser.add_argument(
"--db_path", "--db_path",
type=str, type=str,
default="snek.db", default="snek.db",
help="Database path for the application (default: sqlite:///snek.db)" help="Database path for the application (default: sqlite:///snek.db)",
) )
args = parser.parse_args() args = parser.parse_args()
web.run_app(Application(db_path='sqlite:///' + args.db_path), port=args.port, host=args.host) web.run_app(
Application(db_path="sqlite:///" + args.db_path), port=args.port, host=args.host
)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -17,8 +17,9 @@ from aiohttp_session import (
setup as session_setup, setup as session_setup,
) )
from aiohttp_session.cookie_storage import EncryptedCookieStorage from aiohttp_session.cookie_storage import EncryptedCookieStorage
from app.app import Application as BaseApplication from app.app import Application as BaseApplication
from jinja2 import FileSystemLoader
from snek.docs.app import Application as DocsApplication from snek.docs.app import Application as DocsApplication
from snek.mapper import get_mappers from snek.mapper import get_mappers
from snek.service import get_services from snek.service import get_services
@ -40,12 +41,12 @@ from snek.view.rpc import RPCView
from snek.view.search_user import SearchUserView from snek.view.search_user import SearchUserView
from snek.view.settings.index import SettingsIndexView from snek.view.settings.index import SettingsIndexView
from snek.view.settings.profile import SettingsProfileView from snek.view.settings.profile import SettingsProfileView
from snek.view.stats import StatsView
from snek.view.status import StatusView from snek.view.status import StatusView
from snek.view.terminal import TerminalSocketView, TerminalView from snek.view.terminal import TerminalSocketView, TerminalView
from snek.view.upload import UploadView from snek.view.upload import UploadView
from snek.view.web import WebView
from snek.view.stats import StatsView
from snek.view.user import UserView from snek.view.user import UserView
from snek.view.web import WebView
from snek.webdav import WebdavApplication from snek.webdav import WebdavApplication
SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34" SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34"
@ -204,7 +205,6 @@ class Application(BaseApplication):
channels = [] channels = []
if not context: if not context:
context = {} context = {}
context["rid"] = str(uuid.uuid4()) context["rid"] = str(uuid.uuid4())
if request.session.get("uid"): if request.session.get("uid"):
async for subscribed_channel in self.services.channel_member.find( async for subscribed_channel in self.services.channel_member.find(
@ -231,7 +231,6 @@ class Application(BaseApplication):
item["uid"] = subscribed_channel["channel_uid"] item["uid"] = subscribed_channel["channel_uid"]
item["new_count"] = subscribed_channel["new_count"] item["new_count"] = subscribed_channel["new_count"]
print(item)
channels.append(item) channels.append(item)
channels.sort(key=lambda x: x["last_message_on"] or "", reverse=True) channels.sort(key=lambda x: x["last_message_on"] or "", reverse=True)
@ -239,10 +238,37 @@ class Application(BaseApplication):
context["channels"] = channels context["channels"] = channels
if "user" not in context: if "user" not in context:
context["user"] = await self.services.user.get( context["user"] = await self.services.user.get(
uid=request.session.get("uid") request.session.get("uid")
) )
return await super().render_template(template, request, context) self.template_path.joinpath(template)
await self.services.user.get_template_path(request.session.get("uid"))
self.original_loader = self.jinja2_env.loader
self.jinja2_env.loader = await self.get_user_template_loader(
request.session.get("uid")
)
rendered = await super().render_template(template, request, context)
self.jinja2_env.loader = self.original_loader
return rendered
async def get_user_template_loader(self, uid=None):
template_paths = []
for admin_uid in self.services.user.get_admin_uids():
user_template_path = await self.services.user.get_template_path(admin_uid)
template_paths.append(user_template_path)
if uid:
user_template_path = await self.services.user.get_template_path(uid)
template_paths.append(user_template_path)
template_paths.append(self.template_path)
return FileSystemLoader(template_paths)
app = Application(db_path="sqlite:///snek.db") app = Application(db_path="sqlite:///snek.db")

View File

@ -1,8 +1,8 @@
import pathlib import pathlib
from aiohttp import web from aiohttp import web
from app.app import Application as BaseApplication from app.app import Application as BaseApplication
from snek.system.markdown import MarkdownExtension from snek.system.markdown import MarkdownExtension

View File

@ -5,3 +5,16 @@ from snek.system.mapper import BaseMapper
class UserMapper(BaseMapper): class UserMapper(BaseMapper):
table_name = "user" table_name = "user"
model_class = UserModel model_class = UserModel
def get_admin_uids(self):
try:
return [
user["uid"]
for user in self.db.query(
"SELECT uid FROM user WHERE is_admin = :is_admin",
{"is_admin": True},
)
]
except Exception as ex:
print(ex)
return []

View File

@ -29,6 +29,8 @@ class UserModel(BaseModel):
last_ping = ModelField(name="last_ping", required=False, kind=str) last_ping = ModelField(name="last_ping", required=False, kind=str)
is_admin = ModelField(name="is_admin", required=False, kind=bool)
async def get_property(self, name): async def get_property(self, name):
prop = await self.app.services.user_property.find_one( prop = await self.app.services.user_property.find_one(
user_uid=self["uid"], name=name user_uid=self["uid"], name=name

View File

@ -11,9 +11,12 @@ class ChannelMemberService(BaseService):
return await self.save(channel_member) return await self.save(channel_member)
async def get_user_uids(self, channel_uid): async def get_user_uids(self, channel_uid):
async for model in self.mapper.query("SELECT user_uid FROM channel_member WHERE channel_uid=:channel_uid", {"channel_uid": channel_uid}): async for model in self.mapper.query(
"SELECT user_uid FROM channel_member WHERE channel_uid=:channel_uid",
{"channel_uid": channel_uid},
):
yield model["user_uid"] yield model["user_uid"]
async def create( async def create(
self, self,
channel_uid, channel_uid,

View File

@ -15,7 +15,7 @@ class SocketService(BaseService):
return False return False
try: try:
await self.ws.send_json(data) await self.ws.send_json(data)
except Exception as ex: except Exception:
self.is_connected = False self.is_connected = False
return self.is_connected return self.is_connected
@ -56,7 +56,9 @@ class SocketService(BaseService):
async def broadcast(self, channel_uid, message): async def broadcast(self, channel_uid, message):
try: try:
async for user_uid in self.services.channel_member.get_user_uids(channel_uid): async for user_uid in self.services.channel_member.get_user_uids(
channel_uid
):
print(user_uid, flush=True) print(user_uid, flush=True)
await self.send_to_user(user_uid, message) await self.send_to_user(user_uid, message)
except Exception as ex: except Exception as ex:

View File

@ -39,6 +39,15 @@ class UserService(BaseService):
model = await self.get(username=username, deleted_at=None) model = await self.get(username=username, deleted_at=None)
return model return model
def get_admin_uids(self):
return self.mapper.get_admin_uids()
async def get_template_path(self, user_uid):
path = pathlib.Path(f"./drive/{user_uid}/snek/templates")
if not path.exists():
return None
return path
async def get_home_folder(self, user_uid): async def get_home_folder(self, user_uid):
folder = pathlib.Path(f"./drive/{user_uid}") folder = pathlib.Path(f"./drive/{user_uid}")
if not folder.exists(): if not folder.exists():

View File

@ -9,16 +9,18 @@ class UserPropertyService(BaseService):
async def set(self, user_uid, name, value): async def set(self, user_uid, name, value):
self.mapper.db["user_property"].upsert( self.mapper.db["user_property"].upsert(
{ {
"user_uid": user_uid, "user_uid": user_uid,
"name": name, "name": name,
"value": json.dumps(value, default=str) "value": json.dumps(value, default=str),
}, },
["user_uid", "name"] ["user_uid", "name"],
) )
async def get(self, user_uid, name): async def get(self, user_uid, name):
try: try:
return json.loads((await super().get(user_uid=user_uid, name=name))["value"]) return json.loads(
(await super().get(user_uid=user_uid, name=name))["value"]
)
except Exception as ex: except Exception as ex:
print(ex) print(ex)
return None return None

View File

@ -18,7 +18,7 @@ class Cache:
self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4 self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4
async def get(self, args): async def get(self, args):
await self.update_stat(args, 'get') await self.update_stat(args, "get")
try: try:
self.lru.pop(self.lru.index(args)) self.lru.pop(self.lru.index(args))
except: except:
@ -34,20 +34,28 @@ class Cache:
async def get_stats(self): async def get_stats(self):
all_ = [] all_ = []
for key in self.lru: for key in self.lru:
all_.append({'key': key, 'set': self.stats[key]['set'], 'get': self.stats[key]['get'], 'delete': self.stats[key]['delete'],'value': str(self.serialize(self.cache[key].record))}) all_.append(
{
"key": key,
"set": self.stats[key]["set"],
"get": self.stats[key]["get"],
"delete": self.stats[key]["delete"],
"value": str(self.serialize(self.cache[key].record)),
}
)
return all_ return all_
def serialize(self, obj): def serialize(self, obj):
cpy = obj.copy() cpy = obj.copy()
cpy.pop('created_at', None) cpy.pop("created_at", None)
cpy.pop('deleted_at', None) cpy.pop("deleted_at", None)
cpy.pop('email', None) cpy.pop("email", None)
cpy.pop('password', None) cpy.pop("password", None)
return cpy return cpy
async def update_stat(self, key, action): async def update_stat(self, key, action):
if not key in self.stats: if key not in self.stats:
self.stats[key] = {'set':0, 'get':0, 'delete':0} self.stats[key] = {"set": 0, "get": 0, "delete": 0}
self.stats[key][action] = self.stats[key][action] + 1 self.stats[key][action] = self.stats[key][action] + 1
def json_default(self, value): def json_default(self, value):
@ -70,7 +78,7 @@ class Cache:
async def set(self, args, result): async def set(self, args, result):
is_new = args not in self.cache is_new = args not in self.cache
self.cache[args] = result self.cache[args] = result
await self.update_stat(args, 'set') await self.update_stat(args, "set")
try: try:
self.lru.pop(self.lru.index(args)) self.lru.pop(self.lru.index(args))
except (ValueError, IndexError): except (ValueError, IndexError):
@ -86,7 +94,7 @@ class Cache:
# print(f"Cache store! {len(self.lru)} items. New version:", self.version, flush=True) # print(f"Cache store! {len(self.lru)} items. New version:", self.version, flush=True)
async def delete(self, args): async def delete(self, args):
await self.update_stat(args, 'delete') await self.update_stat(args, "delete")
if args in self.cache: if args in self.cache:
try: try:
self.lru.pop(self.lru.index(args)) self.lru.pop(self.lru.index(args))

View File

@ -32,9 +32,8 @@ from urllib.parse import urljoin
import aiohttp import aiohttp
import imgkit import imgkit
from bs4 import BeautifulSoup
from app.cache import time_cache_async from app.cache import time_cache_async
from bs4 import BeautifulSoup
async def crc32(data): async def crc32(data):

View File

@ -2,13 +2,12 @@
from types import SimpleNamespace from types import SimpleNamespace
from app.cache import time_cache_async
from mistune import HTMLRenderer, Markdown from mistune import HTMLRenderer, Markdown
from pygments import highlight from pygments import highlight
from pygments.formatters import html from pygments.formatters import html
from pygments.lexers import get_lexer_by_name from pygments.lexers import get_lexer_by_name
from app.cache import time_cache_async
class MarkdownRenderer(HTMLRenderer): class MarkdownRenderer(HTMLRenderer):

View File

@ -16,12 +16,12 @@ commands = {
class TerminalSession: class TerminalSession:
def __init__(self, command): def __init__(self, command):
self.master, self.slave = None,None self.master, self.slave = None, None
self.process = None self.process = None
self.sockets = [] self.sockets = []
self.history = b"" self.history = b""
self.history_size = 1024 * 20 self.history_size = 1024 * 20
self.command = command self.command = command
self.start_process(self.command) self.start_process(self.command)
def start_process(self, command): def start_process(self, command):
@ -29,7 +29,7 @@ class TerminalSession:
if self.master: if self.master:
os.close(self.master) os.close(self.master)
os.close(self.slave) os.close(self.slave)
self.master = None self.master = None
self.slave = None self.slave = None
self.master, self.slave = pty.openpty() self.master, self.slave = pty.openpty()
@ -45,7 +45,7 @@ class TerminalSession:
def is_running(self): def is_running(self):
if not self.process: if not self.process:
return False return False
loop = asyncio.get_event_loop() asyncio.get_event_loop()
return self.process.poll() is None return self.process.poll() is None
async def add_websocket(self, ws): async def add_websocket(self, ws):
@ -78,7 +78,7 @@ class TerminalSession:
self.sockets.remove(ws) self.sockets.remove(ws)
except Exception: except Exception:
await self.close() await self.close()
break break
async def close(self): async def close(self):
print("Terminating process") print("Terminating process")
@ -88,8 +88,8 @@ class TerminalSession:
if self.master: if self.master:
os.close(self.master) os.close(self.master)
os.close(self.slave) os.close(self.slave)
self.master = None self.master = None
self.slave = None self.slave = None
print("Terminated process") print("Terminated process")
for ws in self.sockets: for ws in self.sockets:

View File

@ -8,7 +8,9 @@ class BaseView(web.View):
login_required = False login_required = False
async def _iter(self): async def _iter(self):
if self.login_required and (not self.session.get("logged_in") or not self.session.get("uid")): if self.login_required and (
not self.session.get("logged_in") or not self.session.get("uid")
):
return web.HTTPFound("/") return web.HTTPFound("/")
return await super()._iter() return await super()._iter()

View File

@ -10,9 +10,11 @@
# MIT License # MIT License
from snek.system.view import BaseView
from aiohttp import web from aiohttp import web
from snek.system.view import BaseView
class IndexView(BaseView): class IndexView(BaseView):
async def get(self): async def get(self):
if self.session.get("uid"): if self.session.get("uid"):

View File

@ -35,6 +35,7 @@ from snek.system.view import BaseFormView
class SearchUserView(BaseFormView): class SearchUserView(BaseFormView):
form = SearchUserForm form = SearchUserForm
login_required = True login_required = True
async def get(self): async def get(self):
users = [] users = []
query = self.request.query.get("query") query = self.request.query.get("query")

View File

@ -17,24 +17,22 @@ class SettingsProfileView(BaseFormView):
return web.json_response(await form.to_json()) return web.json_response(await form.to_json())
profile = await self.services.user_property.get(
self.session.get("uid"), "profile"
)
profile = await self.services.user_property.get(self.session.get("uid"), "profile")
user = await self.services.user.get(uid=self.session.get("uid")) user = await self.services.user.get(uid=self.session.get("uid"))
return await self.render_template( return await self.render_template(
"settings/profile.html", {"form": await form.to_json(), "user": user, "profile": profile or ''} "settings/profile.html",
{"form": await form.to_json(), "user": user, "profile": profile or ""},
) )
async def post(self): async def post(self):
data = await self.request.post() data = await self.request.post()
user = await self.services.user.get(uid=self.session.get("uid")) user = await self.services.user.get(uid=self.session.get("uid"))
user['nick'] = data['nick'] user["nick"] = data["nick"]
await self.services.user.save(user) await self.services.user.save(user)
await self.services.user_property.set(user["uid"],"profile", data['profile']) await self.services.user_property.set(user["uid"], "profile", data["profile"])
return web.HTTPFound("/settings/profile.html") return web.HTTPFound("/settings/profile.html")

View File

@ -1,10 +1,13 @@
from snek.system.view import BaseView import json
import json
from aiohttp import web from aiohttp import web
from snek.system.view import BaseView
class StatsView(BaseView): class StatsView(BaseView):
async def get(self): async def get(self):
data = await self.app.cache.get_stats() data = await self.app.cache.get_stats()
data = json.dumps({"total": len(data), "stats": data}, default=str, indent=1) data = json.dumps({"total": len(data), "stats": data}, default=str, indent=1)
return web.Response(text=data, content_type='application/json') return web.Response(text=data, content_type="application/json")

View File

@ -38,9 +38,9 @@ class UploadView(BaseView):
user_uid = self.request.session.get("uid") user_uid = self.request.session.get("uid")
upload_dir = await self.services.user.get_home_folder(user_uid) upload_dir = await self.services.user.get_home_folder(user_uid)
upload_dir = upload_dir.joinpath("upload") upload_dir = upload_dir.joinpath("upload")
upload_dir.mkdir(parents=True, exist_ok=True) upload_dir.mkdir(parents=True, exist_ok=True)
channel_uid = None channel_uid = None
drive = await self.services.drive.get_or_create( drive = await self.services.drive.get_or_create(

View File

@ -2,13 +2,14 @@ from snek.system.view import BaseView
class UserView(BaseView): class UserView(BaseView):
async def get(self): async def get(self):
user_uid = self.request.match_info.get('user') user_uid = self.request.match_info.get("user")
user = await self.services.user.get(uid=user_uid) user = await self.services.user.get(uid=user_uid)
profile_content = await self.services.user_property.get(user['uid'],'profile') or '' profile_content = (
return await self.render_template('user.html', { await self.services.user_property.get(user["uid"], "profile") or ""
'user_uid': user_uid, )
'user': user.record, return await self.render_template(
'profile': profile_content "user.html",
}) {"user_uid": user_uid, "user": user.record, "profile": profile_content},
)

View File

@ -12,9 +12,8 @@ import uuid
import aiofiles import aiofiles
import aiohttp import aiohttp
import aiohttp.web import aiohttp.web
from lxml import etree
from app.cache import time_cache_async from app.cache import time_cache_async
from lxml import etree
@aiohttp.web.middleware @aiohttp.web.middleware