diff --git a/src/snek/__init__.py b/src/snek/__init__.py index ba51bb5..9dff4d0 100644 --- a/src/snek/__init__.py +++ b/src/snek/__init__.py @@ -26,9 +26,9 @@ Description: Utility to load environment variables from a .env file. """ import os -from typing import Optional -def load_env(file_path: str = '.env') -> None: + +def load_env(file_path: str = ".env") -> None: """ Loads environment variables from a specified file into the current process environment. @@ -43,26 +43,26 @@ def load_env(file_path: str = '.env') -> None: IOError: If an I/O error occurs during file reading. """ try: - with open(file_path, 'r') as env_file: + with open(file_path) as env_file: for line in env_file: line = line.strip() # Skip empty lines and comments - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue # Skip lines without '=' - if '=' not in line: + if "=" not in line: continue # Split into key and value at the first '=' - key, value = line.split('=', 1) + key, value = line.split("=", 1) # Set environment variable os.environ[key.strip()] = value.strip() except FileNotFoundError: raise FileNotFoundError(f"Environment file '{file_path}' not found.") - except IOError as e: - raise IOError(f"Error reading environment file '{file_path}': {e}") + except OSError as e: + raise OSError(f"Error reading environment file '{file_path}': {e}") try: load_env() -except Exception as e: +except Exception: pass diff --git a/src/snek/__main__.py b/src/snek/__main__.py index 53f9751..a9af4ed 100644 --- a/src/snek/__main__.py +++ b/src/snek/__main__.py @@ -1,20 +1,20 @@ +import asyncio import pathlib import shutil import sqlite3 -import asyncio + import click from aiohttp import web -from IPython import start_ipython -from snek.shell import Shell + from snek.app import Application - - +from snek.shell import Shell @click.group() def cli(): pass + @cli.command() def export(): @@ -30,6 +30,7 @@ def export(): user = await app.services.user.get(uid=message["user_uid"]) message["user"] = user and user["username"] or None return (message["user"] or "") + ": " + (message["text"] or "") + async def run(): result = [] @@ -49,6 +50,7 @@ def export(): with open("dump.txt", "w") as f: f.write("\n\n".join(result)) print("Dump written to dump.json") + asyncio.run(run()) @@ -57,16 +59,20 @@ def statistics(): async def run(): app = Application(db_path="sqlite:///snek.db") app.services.statistics.database() + asyncio.run(run()) + @cli.command() def maintenance(): async def run(): app = Application(db_path="sqlite:///snek.db") await app.services.container.maintenance() await app.services.channel_message.maintenance() + asyncio.run(run()) + @cli.command() @click.option( "--db_path", default="snek.db", help="Database to initialize if not exists." @@ -123,9 +129,11 @@ def serve(port, host, db_path): def shell(db_path): Shell(db_path).run() + def main(): try: import sentry_sdk + sentry_sdk.init("https://ab6147c2f3354c819768c7e89455557b@gt.molodetz.nl/1") except ImportError: print("Could not import sentry_sdk") diff --git a/src/snek/app.py b/src/snek/app.py index a551e41..c90db06 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -3,13 +3,13 @@ import logging import pathlib import ssl import uuid -import signal -from datetime import datetime from contextlib import asynccontextmanager +from datetime import datetime from snek import snode -from snek.view.threads import ThreadsView from snek.system.ads import AsyncDataSet +from snek.view.threads import ThreadsView + logging.basicConfig(level=logging.DEBUG) from concurrent.futures import ThreadPoolExecutor from ipaddress import ip_address @@ -25,16 +25,21 @@ from aiohttp_session.cookie_storage import EncryptedCookieStorage from app.app import Application as BaseApplication from jinja2 import FileSystemLoader +from snek.forum import setup_forum from snek.mapper import get_mappers from snek.service import get_services from snek.sgit import GitApplication from snek.sssh import start_ssh_server from snek.system import http from snek.system.cache import Cache -from snek.system.stats import middleware as stats_middleware, create_stats_structure, stats_handler from snek.system.markdown import MarkdownExtension from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware from snek.system.profiler import profiler_handler +from snek.system.stats import ( + create_stats_structure, + middleware as stats_middleware, + stats_handler, +) from snek.system.template import ( EmojiExtension, LinkifyExtension, @@ -43,10 +48,15 @@ from snek.system.template import ( ) from snek.view.about import AboutHTMLView, AboutMDView from snek.view.avatar import AvatarView -from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView +from snek.view.channel import ( + ChannelAttachmentUploadView, + ChannelAttachmentView, + ChannelDriveApiView, + ChannelView, +) +from snek.view.container import ContainerView from snek.view.docs import DocsHTMLView, DocsMDView from snek.view.drive import DriveApiView, DriveView -from snek.view.channel import ChannelDriveApiView from snek.view.index import IndexView from snek.view.login import LoginView from snek.view.logout import LogoutView @@ -55,7 +65,6 @@ from snek.view.register import RegisterView from snek.view.repository import RepositoryView from snek.view.rpc import RPCView from snek.view.search_user import SearchUserView -from snek.view.container import ContainerView from snek.view.settings.containers import ( ContainersCreateView, ContainersDeleteView, @@ -77,7 +86,6 @@ from snek.view.upload import UploadView from snek.view.user import UserView from snek.view.web import WebView from snek.webdav import WebdavApplication -from snek.forum import setup_forum SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34" from snek.system.template import whitelist_attributes @@ -175,13 +183,12 @@ class Application(BaseApplication): self.on_startup.append(self.prepare_asyncio) self.on_startup.append(self.start_user_availability_service) self.on_startup.append(self.start_ssh_server) - #self.on_startup.append(self.prepare_database) + # self.on_startup.append(self.prepare_database) async def prepare_stats(self, app): - app['stats'] = create_stats_structure() + app["stats"] = create_stats_structure() print("Stats prepared", flush=True) - @property def uptime_seconds(self): return (datetime.now() - self.time_start).total_seconds() @@ -225,10 +232,10 @@ class Application(BaseApplication): # app.loop = asyncio.get_running_loop() app.executor = ThreadPoolExecutor(max_workers=200) app.loop.set_default_executor(self.executor) - #for sig in (signal.SIGINT, signal.SIGTERM): - #app.loop.add_signal_handler( - # sig, lambda: asyncio.create_task(self.services.container.shutdown()) - #) + # for sig in (signal.SIGINT, signal.SIGTERM): + # app.loop.add_signal_handler( + # sig, lambda: asyncio.create_task(self.services.container.shutdown()) + # ) async def create_task(self, task): await self.tasks.put(task) @@ -244,7 +251,6 @@ class Application(BaseApplication): print(ex) self.db.commit() - async def prepare_database(self, app): await self.db.query_raw("PRAGMA journal_mode=WAL") await self.db.query_raw("PRAGMA syncnorm=off") @@ -291,9 +297,7 @@ class Application(BaseApplication): self.router.add_view( "/channel/{channel_uid}/attachment.bin", ChannelAttachmentView ) - self.router.add_view( - "/channel/{channel_uid}/drive.json", ChannelDriveApiView - ) + self.router.add_view("/channel/{channel_uid}/drive.json", ChannelDriveApiView) self.router.add_view( "/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView ) @@ -414,9 +418,9 @@ class Application(BaseApplication): ) try: - context["nonce"] = request['csp_nonce'] + context["nonce"] = request["csp_nonce"] except: - context['nonce'] = '?' + context["nonce"] = "?" rendered = await super().render_template(template, request, context) @@ -466,14 +470,14 @@ class Application(BaseApplication): @asynccontextmanager async def no_save(self): - stats = { - 'count': 0 - } + stats = {"count": 0} + async def patched_save(*args, **kwargs): await self.cache.set(args[0]["uid"], args[0]) - stats['count'] = stats['count'] + 1 + stats["count"] = stats["count"] + 1 print(f"save is ignored {stats['count']} times") return args[0] + save_original = self.services.channel_message.mapper.save self.services.channel_message.mapper.save = patched_save raised_exception = None @@ -486,6 +490,7 @@ class Application(BaseApplication): if raised_exception: raise raised_exception + app = Application(db_path="sqlite:///snek.db") diff --git a/src/snek/forum.py b/src/snek/forum.py index a0b0d2d..f1a0e78 100644 --- a/src/snek/forum.py +++ b/src/snek/forum.py @@ -1,5 +1,6 @@ # forum_app.py import aiohttp.web + from snek.view.forum import ForumIndexView, ForumView, ForumWebSocketView @@ -7,13 +8,13 @@ class ForumApplication(aiohttp.web.Application): def __init__(self, parent, *args, **kwargs): super().__init__(*args, **kwargs) self.parent = parent - self.render_template = self.parent.render_template + self.render_template = self.parent.render_template # Set up routes self.setup_routes() - + # Set up notification listeners self.setup_notifications() - + @property def db(self): return self.parent.db @@ -21,86 +22,84 @@ class ForumApplication(aiohttp.web.Application): @property def services(self): return self.parent.services - + def setup_routes(self): """Set up all forum routes""" # API routes self.router.add_view("/index.html", ForumIndexView) self.router.add_route("GET", "/api/forums", ForumView.get_forums) self.router.add_route("GET", "/api/forums/{slug}", ForumView.get_forum) - self.router.add_route("POST", "/api/forums/{slug}/threads", ForumView.create_thread) + self.router.add_route( + "POST", "/api/forums/{slug}/threads", ForumView.create_thread + ) self.router.add_route("GET", "/api/threads/{thread_slug}", ForumView.get_thread) - self.router.add_route("POST", "/api/threads/{thread_uid}/posts", ForumView.create_post) + self.router.add_route( + "POST", "/api/threads/{thread_uid}/posts", ForumView.create_post + ) self.router.add_route("PUT", "/api/posts/{post_uid}", ForumView.edit_post) self.router.add_route("DELETE", "/api/posts/{post_uid}", ForumView.delete_post) - self.router.add_route("POST", "/api/posts/{post_uid}/like", ForumView.toggle_like) - self.router.add_route("POST", "/api/threads/{thread_uid}/pin", ForumView.toggle_pin) - self.router.add_route("POST", "/api/threads/{thread_uid}/lock", ForumView.toggle_lock) - + self.router.add_route( + "POST", "/api/posts/{post_uid}/like", ForumView.toggle_like + ) + self.router.add_route( + "POST", "/api/threads/{thread_uid}/pin", ForumView.toggle_pin + ) + self.router.add_route( + "POST", "/api/threads/{thread_uid}/lock", ForumView.toggle_lock + ) + # WebSocket route self.router.add_view("/ws", ForumWebSocketView) - + # Static HTML route self.router.add_route("GET", "/{path:.*}", self.serve_forum_html) - + def setup_notifications(self): """Set up notification listeners for WebSocket broadcasting""" # Forum notifications - self.services.forum.add_notification_listener("forum_created", self.on_forum_event) - + self.services.forum.add_notification_listener( + "forum_created", self.on_forum_event + ) + # Thread notifications - self.services.thread.add_notification_listener("thread_created", self.on_thread_event) - + self.services.thread.add_notification_listener( + "thread_created", self.on_thread_event + ) + # Post notifications self.services.post.add_notification_listener("post_created", self.on_post_event) self.services.post.add_notification_listener("post_edited", self.on_post_event) self.services.post.add_notification_listener("post_deleted", self.on_post_event) - + # Like notifications - self.services.post_like.add_notification_listener("post_liked", self.on_like_event) - self.services.post_like.add_notification_listener("post_unliked", self.on_like_event) - + self.services.post_like.add_notification_listener( + "post_liked", self.on_like_event + ) + self.services.post_like.add_notification_listener( + "post_unliked", self.on_like_event + ) + async def on_forum_event(self, event_type, data): """Handle forum events""" await ForumWebSocketView.broadcast_update(self, event_type, data) - + async def on_thread_event(self, event_type, data): """Handle thread events""" await ForumWebSocketView.broadcast_update(self, event_type, data) - + async def on_post_event(self, event_type, data): """Handle post events""" await ForumWebSocketView.broadcast_update(self, event_type, data) - + async def on_like_event(self, event_type, data): """Handle like events""" await ForumWebSocketView.broadcast_update(self, event_type, data) - + async def serve_forum_html(self, request): """Serve the forum HTML with the web component""" - html = """ - - - - - Forum - - - - - - -""" return await self.parent.render_template("forum.html", request) - - - #return aiohttp.web.Response(text=html, content_type="text/html") + + # return aiohttp.web.Response(text=html, content_type="text/html") # Integration with main app diff --git a/src/snek/mapper/__init__.py b/src/snek/mapper/__init__.py index be01022..3bce17f 100644 --- a/src/snek/mapper/__init__.py +++ b/src/snek/mapper/__init__.py @@ -7,12 +7,12 @@ from snek.mapper.channel_message import ChannelMessageMapper from snek.mapper.container import ContainerMapper from snek.mapper.drive import DriveMapper from snek.mapper.drive_item import DriveItemMapper +from snek.mapper.forum import ForumMapper, PostLikeMapper, PostMapper, ThreadMapper from snek.mapper.notification import NotificationMapper from snek.mapper.push import PushMapper from snek.mapper.repository import RepositoryMapper from snek.mapper.user import UserMapper from snek.mapper.user_property import UserPropertyMapper -from snek.mapper.forum import ForumMapper, ThreadMapper, PostMapper, PostLikeMapper from snek.system.object import Object @@ -42,4 +42,3 @@ def get_mappers(app=None): def get_mapper(name, app=None): return get_mappers(app=app)[name] - diff --git a/src/snek/mapper/forum.py b/src/snek/mapper/forum.py index ec258a2..37009f6 100644 --- a/src/snek/mapper/forum.py +++ b/src/snek/mapper/forum.py @@ -1,5 +1,5 @@ # mapper/forum.py -from snek.model.forum import ForumModel, ThreadModel, PostModel, PostLikeModel +from snek.model.forum import ForumModel, PostLikeModel, PostModel, ThreadModel from snek.system.mapper import BaseMapper diff --git a/src/snek/model/__init__.py b/src/snek/model/__init__.py index 13ba5b2..1ee16db 100644 --- a/src/snek/model/__init__.py +++ b/src/snek/model/__init__.py @@ -9,12 +9,12 @@ from snek.model.channel_message import ChannelMessageModel from snek.model.container import Container from snek.model.drive import DriveModel from snek.model.drive_item import DriveItemModel +from snek.model.forum import ForumModel, PostLikeModel, PostModel, ThreadModel from snek.model.notification import NotificationModel from snek.model.push_registration import PushRegistrationModel from snek.model.repository import RepositoryModel from snek.model.user import UserModel from snek.model.user_property import UserPropertyModel -from snek.model.forum import ForumModel, ThreadModel, PostModel, PostLikeModel from snek.system.object import Object @@ -44,5 +44,3 @@ def get_models(): def get_model(name): return get_models()[name] - - diff --git a/src/snek/model/channel.py b/src/snek/model/channel.py index 2c17478..b5cb35d 100644 --- a/src/snek/model/channel.py +++ b/src/snek/model/channel.py @@ -12,10 +12,10 @@ class ChannelModel(BaseModel): index = ModelField(name="index", required=True, kind=int, value=1000) last_message_on = ModelField(name="last_message_on", required=False, kind=str) history_start = ModelField(name="history_start", required=False, kind=str) - + @property def is_dm(self): - return 'dm' in self['tag'].lower() + return "dm" in self["tag"].lower() async def get_last_message(self) -> ChannelMessageModel: history_start_filter = "" @@ -23,7 +23,9 @@ class ChannelModel(BaseModel): history_start_filter = f" AND created_at > '{self['history_start']}' " try: async for model in self.app.services.channel_message.query( - "SELECT uid FROM channel_message WHERE channel_uid=:channel_uid" + history_start_filter + " ORDER BY created_at DESC LIMIT 1", + "SELECT uid FROM channel_message WHERE channel_uid=:channel_uid" + + history_start_filter + + " ORDER BY created_at DESC LIMIT 1", {"channel_uid": self["uid"]}, ): diff --git a/src/snek/model/forum.py b/src/snek/model/forum.py index 9c037a1..80e21a3 100644 --- a/src/snek/model/forum.py +++ b/src/snek/model/forum.py @@ -4,9 +4,16 @@ from snek.system.model import BaseModel, ModelField class ForumModel(BaseModel): """Forum categories""" - name = ModelField(name="name", required=True, kind=str, min_length=3, max_length=100) - description = ModelField(name="description", required=False, kind=str, max_length=500) - slug = ModelField(name="slug", required=True, kind=str, regex=r"^[a-z0-9-]+$", unique=True) + + name = ModelField( + name="name", required=True, kind=str, min_length=3, max_length=100 + ) + description = ModelField( + name="description", required=False, kind=str, max_length=500 + ) + slug = ModelField( + name="slug", required=True, kind=str, regex=r"^[a-z0-9-]+$", unique=True + ) icon = ModelField(name="icon", required=False, kind=str) position = ModelField(name="position", required=True, kind=int, value=0) is_active = ModelField(name="is_active", required=True, kind=bool, value=True) @@ -18,12 +25,12 @@ class ForumModel(BaseModel): async def get_threads(self, limit=50, offset=0): async for thread in self.app.services.thread.find( - forum_uid=self["uid"], + forum_uid=self["uid"], deleted_at=None, _limit=limit, _offset=offset, - order_by="-last_post_at" - #order_by="is_pinned DESC, last_post_at DESC" + order_by="-last_post_at", + # order_by="is_pinned DESC, last_post_at DESC" ): yield thread @@ -39,8 +46,11 @@ class ForumModel(BaseModel): # models/thread.py class ThreadModel(BaseModel): """Forum threads""" + forum_uid = ModelField(name="forum_uid", required=True, kind=str) - title = ModelField(name="title", required=True, kind=str, min_length=5, max_length=200) + title = ModelField( + name="title", required=True, kind=str, min_length=5, max_length=200 + ) slug = ModelField(name="slug", required=True, kind=str, regex=r"^[a-z0-9-]+$") created_by_uid = ModelField(name="created_by_uid", required=True, kind=str) is_pinned = ModelField(name="is_pinned", required=True, kind=bool, value=False) @@ -56,7 +66,7 @@ class ThreadModel(BaseModel): deleted_at=None, _limit=limit, _offset=offset, - order_by="created_at" + order_by="created_at", ): yield post @@ -70,30 +80,35 @@ class ThreadModel(BaseModel): await self.save() -# models/post.py +# models/post.py class PostModel(BaseModel): """Forum posts""" + thread_uid = ModelField(name="thread_uid", required=True, kind=str) forum_uid = ModelField(name="forum_uid", required=True, kind=str) - content = ModelField(name="content", required=True, kind=str, min_length=1, max_length=10000) + content = ModelField( + name="content", required=True, kind=str, min_length=1, max_length=10000 + ) created_by_uid = ModelField(name="created_by_uid", required=True, kind=str) edited_at = ModelField(name="edited_at", required=False, kind=str) edited_by_uid = ModelField(name="edited_by_uid", required=False, kind=str) - is_first_post = ModelField(name="is_first_post", required=True, kind=bool, value=False) + is_first_post = ModelField( + name="is_first_post", required=True, kind=bool, value=False + ) like_count = ModelField(name="like_count", required=True, kind=int, value=0) - + async def get_author(self): return await self.app.services.user.get(uid=self["created_by_uid"]) async def is_liked_by(self, user_uid): return await self.app.services.post_like.exists( - post_uid=self["uid"], - user_uid=user_uid + post_uid=self["uid"], user_uid=user_uid ) # models/post_like.py class PostLikeModel(BaseModel): """Post likes""" + post_uid = ModelField(name="post_uid", required=True, kind=str) user_uid = ModelField(name="user_uid", required=True, kind=str) diff --git a/src/snek/service/__init__.py b/src/snek/service/__init__.py index 6e9e58f..ae44e15 100644 --- a/src/snek/service/__init__.py +++ b/src/snek/service/__init__.py @@ -9,37 +9,43 @@ from snek.service.container import ContainerService from snek.service.db import DBService from snek.service.drive import DriveService from snek.service.drive_item import DriveItemService +from snek.service.forum import ForumService, PostLikeService, PostService, ThreadService from snek.service.notification import NotificationService from snek.service.push import PushService from snek.service.repository import RepositoryService from snek.service.socket import SocketService +from snek.service.statistics import StatisticsService from snek.service.user import UserService from snek.service.user_property import UserPropertyService from snek.service.util import UtilService from snek.system.object import Object -from snek.service.statistics import StatisticsService -from snek.service.forum import ForumService, ThreadService, PostService, PostLikeService + _service_registry = {} + def register_service(name, service_cls): _service_registry[name] = service_cls + register = register_service + @functools.cache def get_services(app): - result = Object( + result = Object( **{ name: service_cls(app=app) for name, service_cls in _service_registry.items() } ) result.register = register_service - return result + return result + def get_service(name, app=None): return get_services(app=app)[name] + # Registering all services register_service("user", UserService) register_service("channel_member", ChannelMemberService) @@ -62,4 +68,3 @@ register_service("forum", ForumService) register_service("thread", ThreadService) register_service("post", PostService) register_service("post_like", PostLikeService) - diff --git a/src/snek/service/channel.py b/src/snek/service/channel.py index 5d15683..5a52fc7 100644 --- a/src/snek/service/channel.py +++ b/src/snek/service/channel.py @@ -115,10 +115,9 @@ class ChannelService(BaseService): async def clear(self, channel_uid): model = await self.get(uid=channel_uid) - model['history_from'] = datetime.now() + model["history_from"] = datetime.now() await self.save(model) - async def ensure_public_channel(self, created_by_uid): model = await self.get(is_listed=True, tag="public") is_moderator = False diff --git a/src/snek/service/channel_message.py b/src/snek/service/channel_message.py index 642f4e0..12ab5f3 100644 --- a/src/snek/service/channel_message.py +++ b/src/snek/service/channel_message.py @@ -1,6 +1,8 @@ +import time + from snek.system.service import BaseService from snek.system.template import sanitize_html -import time + class ChannelMessageService(BaseService): mapper_name = "channel_message" @@ -10,7 +12,6 @@ class ChannelMessageService(BaseService): self._configured_indexes = False async def maintenance(self): - args = {} for message in self.mapper.db["channel_message"].find(): print(message) try: @@ -34,7 +35,6 @@ class ChannelMessageService(BaseService): time.sleep(0.1) print(ex, flush=True) - while True: changed = 0 async for message in self.find(is_final=False): @@ -72,7 +72,7 @@ class ChannelMessageService(BaseService): try: template = self.app.jinja2_env.get_template("message.html") model["html"] = template.render(**context) - model['html'] = sanitize_html(model['html']) + model["html"] = sanitize_html(model["html"]) except Exception as ex: print(ex, flush=True) @@ -99,9 +99,9 @@ class ChannelMessageService(BaseService): if not user: return {} - #if not message["html"].startswith("= datetime('now', '-5 minutes')") + cursor.execute( + "SELECT COUNT(*) FROM user WHERE last_ping >= datetime('now', '-5 minutes')" + ) return cursor.fetchone()[0] # Example usage of new functions - messages_per_day = get_time_series_stats('channel_message', 'created_at') - users_created = get_count_created('user', 'created_at') - channels_created = get_count_created('channel', 'created_at') + messages_per_day = get_time_series_stats("channel_message", "created_at") + users_created = get_count_created("user", "created_at") + channels_created = get_count_created("channel", "created_at") channels_per_user = get_channels_per_user() online_users = get_online_users() diff --git a/src/snek/service/user_property.py b/src/snek/service/user_property.py index acaf0bf..41e8dfa 100644 --- a/src/snek/service/user_property.py +++ b/src/snek/service/user_property.py @@ -8,7 +8,7 @@ class UserPropertyService(BaseService): async def set(self, user_uid, name, value): self.mapper.db.upsert( - "user_property", + "user_property", { "user_uid": user_uid, "name": name, diff --git a/src/snek/shell.py b/src/snek/shell.py index 2a1546b..de7885f 100644 --- a/src/snek/shell.py +++ b/src/snek/shell.py @@ -1,19 +1,16 @@ -from snek.app import Application from IPython import start_ipython +from snek.app import Application + + class Shell: - def __init__(self,db_path): + def __init__(self, db_path): self.app = Application(db_path=f"sqlite:///{db_path}") - + async def maintenance(self): await self.app.services.container.maintenance() await self.app.services.channel_message.maintenance() - - def run(self): - ns = { - "app": self.app, - "maintenance": self.maintenance - } + ns = {"app": self.app, "maintenance": self.maintenance} start_ipython(argv=[], user_ns=ns) diff --git a/src/snek/system/ads.py b/src/snek/system/ads.py index 42b64a0..06b2083 100644 --- a/src/snek/system/ads.py +++ b/src/snek/system/ads.py @@ -1,28 +1,27 @@ -import re +import asyncio +import atexit import json -from uuid import uuid4 +import os +import re +import socket +import unittest +from collections.abc import AsyncGenerator, Iterable from datetime import datetime, timezone +from pathlib import Path from typing import ( Any, Dict, - Iterable, List, Optional, - AsyncGenerator, - Union, - Tuple, Set, + Tuple, + Union, ) -from pathlib import Path -import aiosqlite -import unittest -from types import SimpleNamespace -import asyncio +from uuid import uuid4 + import aiohttp +import aiosqlite from aiohttp import web -import socket -import os -import atexit class AsyncDataSet: @@ -103,7 +102,7 @@ class AsyncDataSet: # No socket file, become server self._is_server = True await self._setup_server() - except Exception as e: + except Exception: # Fallback to client mode self._is_server = False await self._setup_client() @@ -213,7 +212,7 @@ class AsyncDataSet: data = await request.json() try: try: - _limit = data.pop("_limit",60) + _limit = data.pop("_limit", 60) except: pass @@ -751,7 +750,7 @@ class AsyncDataSet: for k, v in where.items() ] ) - return " WHERE " + " AND ".join(clauses), list(v for v in vals if v is not None) + return " WHERE " + " AND ".join(clauses), [v for v in vals if v is not None] async def _server_insert( self, table: str, args: Dict[str, Any], return_id: bool = False @@ -887,7 +886,6 @@ class AsyncDataSet: except: pass - where_clause, where_params = self._build_where(where) order_clause = f" ORDER BY {order_by}" if order_by else "" extra = (f" LIMIT {_limit}" if _limit else "") + ( @@ -1209,4 +1207,3 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): if __name__ == "__main__": unittest.main() - diff --git a/src/snek/system/docker.py b/src/snek/system/docker.py index c72620b..92772b0 100644 --- a/src/snek/system/docker.py +++ b/src/snek/system/docker.py @@ -1,8 +1,9 @@ -import copy -import json -import yaml import asyncio -import subprocess +import copy +import json + +import yaml + try: import pty except Exception as ex: @@ -10,26 +11,28 @@ except Exception as ex: print(ex) import os + class ComposeFileManager: - def __init__(self, compose_path="docker-compose.yml",event_handler=None): + def __init__(self, compose_path="docker-compose.yml", event_handler=None): self.compose_path = compose_path self._load() self.running_instances = {} - self.event_handler = event_handler + self.event_handler = event_handler async def shutdown(self): print("Stopping all sessions") - - tasks = [] + + tasks = [] for name in self.list_instances(): proc = self.running_instances.get(name) if not proc: continue - if proc['proc'].returncode == None: - print("Stopping",name) - tasks.append(asyncio.create_task(proc['proc'].stop())) - print("Stopped",name,"gracefully") - return tasks + if proc["proc"].returncode is None: + print("Stopping", name) + tasks.append(asyncio.create_task(proc["proc"].stop())) + print("Stopped", name, "gracefully") + return tasks + def _load(self): try: with open(self.compose_path) as f: @@ -47,22 +50,21 @@ class ComposeFileManager: async def _create_readers(self, container_name): instance = await self.get_instance(container_name) if not instance: - return False + return False proc = self.running_instances.get(container_name) if not proc: return False - - async def reader(event_handler,stream): + + async def reader(event_handler, stream): loop = asyncio.get_event_loop() while True: - line = await loop.run_in_executor(None,os.read,stream,1024) + line = await loop.run_in_executor(None, os.read, stream, 1024) if not line: break - await event_handler(container_name,"stdout",line) + await event_handler(container_name, "stdout", line) await self.stop(container_name) - asyncio.create_task(reader(self.event_handler,proc['master'])) - + asyncio.create_task(reader(self.event_handler, proc["master"])) def create_instance( self, @@ -79,7 +81,7 @@ class ComposeFileManager: "context": ".", "dockerfile": "DockerfileUbuntu", }, - "user":"root", + "user": "root", "working_dir": "/home/retoor/projects/snek", "environment": [f"SNEK_UID={name}"], } @@ -109,9 +111,9 @@ class ComposeFileManager: instance = self.compose.get("services", {}).get(name) if not instance: return None - instance = json.loads(json.dumps(instance,default=str)) - instance['status'] = await self.get_instance_status(name) - return instance + instance = json.loads(json.dumps(instance, default=str)) + instance["status"] = await self.get_instance_status(name) + return instance def duplicate_instance(self, name, new_name): orig = self.get_instance(name) @@ -135,7 +137,14 @@ class ComposeFileManager: if name not in self.list_instances(): return "error" proc = await asyncio.create_subprocess_exec( - "docker", "compose", "-f", self.compose_path, "ps", "--services", "--filter", f"status=running", + "docker", + "compose", + "-f", + self.compose_path, + "ps", + "--services", + "--filter", + f"status=running", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) @@ -148,25 +157,30 @@ class ComposeFileManager: await self.event_handler(name, "stdin", data) proc = self.running_instances.get(name) if not proc: - return False + return False try: - os.write(proc['master'], data.encode()) + os.write(proc["master"], data.encode()) return True except Exception as ex: print(ex) await self.stop(name) - return False + return False async def stop(self, name): """Asynchronously stop a container by doing 'docker compose stop [name]'.""" if name not in self.list_instances(): - return False + return False status = await self.get_instance_status(name) if status != "running": - return True - + return True + proc = await asyncio.create_subprocess_exec( - "docker", "compose", "-f", self.compose_path, "stop", name, + "docker", + "compose", + "-f", + self.compose_path, + "stop", + name, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) @@ -177,30 +191,41 @@ class ComposeFileManager: try: stdout, stderr = await proc.communicate() except Exception as ex: - stdout = b'' + stdout = b"" stderr = str(ex).encode() - await self.event_handler(name,"stdout",stdout or stderr) - print("Return code", proc.returncode) + await self.event_handler(name, "stdout", stdout or stderr) + print("Return code", proc.returncode) if stdout: - await self.event_handler(name,"stdout",stdout or b'') + await self.event_handler(name, "stdout", stdout or b"") return stdout and stdout.decode(errors="ignore") or "" - await self.event_handler(name,"stdout",stderr or b'') + await self.event_handler(name, "stdout", stderr or b"") return stderr and stderr.decode(errors="ignore") or "" async def start(self, name): """Asynchronously start a container by doing 'docker compose up -d [name]'.""" if name not in self.list_instances(): - return False - + return False + status = await self.get_instance_status(name) - if name in self.running_instances and status == "running" and self.running_instances.get(name) and self.running_instances.get(name).get('proc').returncode == None: - return True + if ( + name in self.running_instances + and status == "running" + and self.running_instances.get(name) + and self.running_instances.get(name).get("proc").returncode is None + ): + return True elif name in self.running_instances: del self.running_instances[name] - + proc = await asyncio.create_subprocess_exec( - "docker", "compose", "-f", self.compose_path, "up", name, "-d", + "docker", + "compose", + "-f", + self.compose_path, + "up", + name, + "-d", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE, @@ -210,24 +235,29 @@ class ComposeFileManager: try: stdout, stderr = await proc.communicate() except Exception as ex: - stdout = b'' + stdout = b"" stderr = str(ex).encode() if stdout: print(stdout.decode(errors="ignore")) if stderr: print(stderr.decode(errors="ignore")) - await self.event_handler(name,"stdout",stdout or stderr) - print("Return code", proc.returncode) + await self.event_handler(name, "stdout", stdout or stderr) + print("Return code", proc.returncode) master, slave = pty.openpty() proc = await asyncio.create_subprocess_exec( - "docker", "compose", "-f", self.compose_path, "exec", name, "/usr/local/bin/entry", + "docker", + "compose", + "-f", + self.compose_path, + "exec", + name, + "/usr/local/bin/entry", stdin=slave, stdout=slave, stderr=slave, ) - proc = {'proc': proc, 'master': master, 'slave': slave} + proc = {"proc": proc, "master": master, "slave": slave} self.running_instances[name] = proc await self._create_readers(name) return True - diff --git a/src/snek/system/mapper.py b/src/snek/system/mapper.py index c4f765e..6ccc19b 100644 --- a/src/snek/system/mapper.py +++ b/src/snek/system/mapper.py @@ -1,7 +1,7 @@ DEFAULT_LIMIT = 30 import asyncio import typing -import traceback + from snek.system.model import BaseModel @@ -11,7 +11,7 @@ class BaseMapper: default_limit: int = DEFAULT_LIMIT table_name: str = None semaphore = asyncio.Semaphore(1) - + def __init__(self, app): self.app = app @@ -27,17 +27,16 @@ class BaseMapper: async def run_in_executor(self, func, *args, **kwargs): use_semaphore = kwargs.pop("use_semaphore", False) - - def _execute(): + + def _execute(): result = func(*args, **kwargs) if use_semaphore: self.db.commit() - return result - - return _execute() - #async with self.semaphore: - # return await self.loop.run_in_executor(None, _execute) + return result + return _execute() + # async with self.semaphore: + # return await self.loop.run_in_executor(None, _execute) async def new(self): return self.model_class(mapper=self, app=self.app) @@ -51,7 +50,7 @@ class BaseMapper: kwargs["uid"] = uid if not kwargs.get("deleted_at"): kwargs["deleted_at"] = None - #traceback.print_exc() + # traceback.print_exc() record = await self.db.get(self.table_name, kwargs) if not record: @@ -65,20 +64,18 @@ class BaseMapper: async def exists(self, **kwargs): return await self.db.count(self.table_name, kwargs) - - #return await self.run_in_executor(self.table.exists, **kwargs) + # return await self.run_in_executor(self.table.exists, **kwargs) async def count(self, **kwargs) -> int: - return await self.db.count(self.table_name,kwargs) - + return await self.db.count(self.table_name, kwargs) async def save(self, model: BaseModel) -> bool: if not model.record.get("uid"): raise Exception(f"Attempt to save without uid: {model.record}.") - model.updated_at.update() + model.updated_at.update() await self.upsert(model) return model - #return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True) + # return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True) async def find(self, **kwargs) -> typing.AsyncGenerator: if not kwargs.get("_limit"): @@ -100,10 +97,12 @@ class BaseMapper: yield dict(record) async def update(self, model): - if not model["deleted_at"] is None: + if model["deleted_at"] is not None: raise Exception("Can't update deleted record.") model.updated_at.update() - return await self.db.update(self.table_name, model.record, {"uid": model["uid"]}) + return await self.db.update( + self.table_name, model.record, {"uid": model["uid"]} + ) async def upsert(self, model): model.updated_at.update() diff --git a/src/snek/system/markdown.py b/src/snek/system/markdown.py index 3160ab7..0b70367 100644 --- a/src/snek/system/markdown.py +++ b/src/snek/system/markdown.py @@ -2,7 +2,7 @@ import re from types import SimpleNamespace -from app.cache import time_cache_async, time_cache +from app.cache import time_cache, time_cache_async from mistune import HTMLRenderer, Markdown from mistune.plugins.formatting import strikethrough from mistune.plugins.spoiler import spoiler @@ -12,7 +12,6 @@ from pygments.formatters import html from pygments.lexers import get_lexer_by_name - def strip_markdown(md_text): # Remove code blocks ( md_text = re.sub(r"[\s\S]?```", "", md_text) @@ -50,7 +49,7 @@ class MarkdownRenderer(HTMLRenderer): return get_lexer_by_name(lang, stripall=True) except: return get_lexer_by_name(default, stripall=True) - + @time_cache(timeout=60 * 60) def block_code(self, code, lang=None, info=None): if not lang: diff --git a/src/snek/system/middleware.py b/src/snek/system/middleware.py index f3a13d8..fea92ba 100644 --- a/src/snek/system/middleware.py +++ b/src/snek/system/middleware.py @@ -3,31 +3,17 @@ # This code provides middleware functions for an aiohttp server to manage and modify CSP, CORS, and authentication headers. import secrets + from aiohttp import web @web.middleware async def csp_middleware(request, handler): nonce = secrets.token_hex(16) - origin = request.headers.get('Origin') - csp_policy = ( - "default-src 'self'; " - f"script-src 'self' {origin} 'nonce-{nonce}'; " - f"style-src 'self' 'unsafe-inline' {origin} 'nonce-{nonce}'; " - "img-src *; " - "connect-src 'self' https://umami.molodetz.nl; " - "font-src *; " - "object-src 'none'; " - "base-uri 'self'; " - "form-action 'self'; " - "frame-src 'self'; " - "worker-src *; " - "media-src *; " - "manifest-src 'self';" - ) - request['csp_nonce'] = nonce + request.headers.get("Origin") + request["csp_nonce"] = nonce response = await handler(request) - #response.headers['Content-Security-Policy'] = csp_policy + # response.headers['Content-Security-Policy'] = csp_policy return response @@ -37,6 +23,7 @@ async def no_cors_middleware(request, handler): response.headers.pop("Access-Control-Allow-Origin", None) return response + @web.middleware async def cors_allow_middleware(request, handler): response = await handler(request) @@ -48,6 +35,7 @@ async def cors_allow_middleware(request, handler): response.headers["Access-Control-Allow-Credentials"] = "true" return response + @web.middleware async def auth_middleware(request, handler): request["user"] = None @@ -57,6 +45,7 @@ async def auth_middleware(request, handler): ) return await handler(request) + @web.middleware async def cors_middleware(request, handler): if request.headers.get("Allow"): diff --git a/src/snek/system/service.py b/src/snek/system/service.py index 82c5d1c..5934a6e 100644 --- a/src/snek/system/service.py +++ b/src/snek/system/service.py @@ -1,5 +1,4 @@ from snek.mapper import get_mapper -from snek.model.user import UserModel from snek.system.mapper import BaseMapper @@ -40,7 +39,7 @@ class BaseService: yield record async def get(self, *args, **kwargs): - if not "deleted_at" in kwargs: + if "deleted_at" not in kwargs: kwargs["deleted_at"] = None uid = kwargs.get("uid") if args: @@ -50,7 +49,7 @@ class BaseService: if result and result.__class__ == self.mapper.model_class: return result kwargs["uid"] = uid - print(kwargs,"ZZZZZZZ") + print(kwargs, "ZZZZZZZ") result = await self.mapper.get(**kwargs) if result: await self.cache.set(result["uid"], result) diff --git a/src/snek/system/stats.py b/src/snek/system/stats.py index c2148ff..dd5c692 100644 --- a/src/snek/system/stats.py +++ b/src/snek/system/stats.py @@ -1,34 +1,41 @@ -import asyncio -from aiohttp import web, WSMsgType - -from datetime import datetime, timedelta, timezone -from collections import defaultdict import html +from collections import defaultdict +from datetime import datetime, timedelta, timezone + +from aiohttp import WSMsgType, web + def create_stats_structure(): """Creates the nested dictionary structure for storing statistics.""" + def nested_dd(): return defaultdict(lambda: defaultdict(int)) + return defaultdict(nested_dd) + def get_time_keys(dt: datetime): """Generates dictionary keys for different time granularities.""" return { - "hour": dt.strftime('%Y-%m-%d-%H'), - "day": dt.strftime('%Y-%m-%d'), - "week": dt.strftime('%Y-%W'), # Week number, Monday is first day - "month": dt.strftime('%Y-%m'), + "hour": dt.strftime("%Y-%m-%d-%H"), + "day": dt.strftime("%Y-%m-%d"), + "week": dt.strftime("%Y-%W"), # Week number, Monday is first day + "month": dt.strftime("%Y-%m"), } + def update_stats_counters(stats_dict: defaultdict, now: datetime): """Increments the appropriate time-based counters in a stats dictionary.""" keys = get_time_keys(now) - stats_dict['by_hour'][keys['hour']] += 1 - stats_dict['by_day'][keys['day']] += 1 - stats_dict['by_week'][keys['week']] += 1 - stats_dict['by_month'][keys['month']] += 1 + stats_dict["by_hour"][keys["hour"]] += 1 + stats_dict["by_day"][keys["day"]] += 1 + stats_dict["by_week"][keys["week"]] += 1 + stats_dict["by_month"][keys["month"]] += 1 -def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: str) -> str: + +def generate_time_series_svg( + title: str, data: list[tuple[str, int]], y_label: str +) -> str: """Generates a responsive SVG bar chart for time-series data.""" if not data: return f"

{html.escape(title)}

No data yet.

" @@ -36,14 +43,14 @@ def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: s svg_height, svg_width = 250, 600 bar_padding = 5 bar_width = (svg_width - 50) / len(data) - bar_padding - + bars = "" labels = "" for i, (key, val) in enumerate(data): bar_height = (val / max_val) * (svg_height - 50) if max_val > 0 else 0 x = i * (bar_width + bar_padding) + 40 y = svg_height - bar_height - 30 - + bars += f'{html.escape(key)}: {val}' labels += f'{html.escape(key)}' @@ -61,25 +68,32 @@ def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: s """ + @web.middleware async def middleware(request, handler): """Middleware to count all incoming HTTP requests.""" # Avoid counting requests to the stats page itself - if request.path.startswith('/stats.html'): + if request.path.startswith("/stats.html"): return await handler(request) - - update_stats_counters(request.app['stats']['http_requests'], datetime.now(timezone.utc)) + + update_stats_counters( + request.app["stats"]["http_requests"], datetime.now(timezone.utc) + ) return await handler(request) + def update_websocket_stats(app): - update_stats_counters(app['stats']['websocket_requests'], datetime.now(timezone.utc)) + update_stats_counters( + app["stats"]["websocket_requests"], datetime.now(timezone.utc) + ) + async def pipe_and_count_websocket(ws_from, ws_to, stats_dict): """This function proxies WebSocket messages AND counts them.""" async for msg in ws_from: # This is the key part for monitoring WebSockets update_stats_counters(stats_dict, datetime.now(timezone.utc)) - + if msg.type == WSMsgType.TEXT: await ws_to.send_str(msg.data) elif msg.type == WSMsgType.BINARY: @@ -91,27 +105,27 @@ async def pipe_and_count_websocket(ws_from, ws_to, stats_dict): async def stats_handler(request: web.Request): """Handler to display the statistics dashboard.""" - stats = request.app['stats'] + stats = request.app["stats"] now = datetime.now(timezone.utc) - + # Helper to prepare data for charts def get_data(source, period, count): data = [] for i in range(count - 1, -1, -1): - if period == 'hour': + if period == "hour": dt = now - timedelta(hours=i) - key, label = dt.strftime('%Y-%m-%d-%H'), dt.strftime('%H:00') - data.append((label, source['by_hour'].get(key, 0))) - elif period == 'day': + key, label = dt.strftime("%Y-%m-%d-%H"), dt.strftime("%H:00") + data.append((label, source["by_hour"].get(key, 0))) + elif period == "day": dt = now - timedelta(days=i) - key, label = dt.strftime('%Y-%m-%d'), dt.strftime('%a') - data.append((label, source['by_day'].get(key, 0))) + key, label = dt.strftime("%Y-%m-%d"), dt.strftime("%a") + data.append((label, source["by_day"].get(key, 0))) return data - http_hourly = get_data(stats['http_requests'], 'hour', 24) - ws_hourly = get_data(stats['ws_messages'], 'hour', 24) - http_daily = get_data(stats['http_requests'], 'day', 7) - ws_daily = get_data(stats['ws_messages'], 'day', 7) + http_hourly = get_data(stats["http_requests"], "hour", 24) + ws_hourly = get_data(stats["ws_messages"], "hour", 24) + http_daily = get_data(stats["http_requests"], "day", 7) + ws_daily = get_data(stats["ws_messages"], "day", 7) body = f""" App Stats @@ -125,5 +139,4 @@ async def stats_handler(request: web.Request): {generate_time_series_svg("WebSocket Messages", ws_daily, "Msgs/Day")} """ - return web.Response(text=body, content_type='text/html') - + return web.Response(text=body, content_type="text/html") diff --git a/src/snek/system/template.py b/src/snek/system/template.py index 5320af9..ee4d743 100644 --- a/src/snek/system/template.py +++ b/src/snek/system/template.py @@ -81,27 +81,28 @@ emoji.EMOJI_DATA[ ALLOWED_TAGS = list(bleach.sanitizer.ALLOWED_TAGS) + ["picture"] + def sanitize_html(value): - soup = BeautifulSoup(value, 'html.parser') + soup = BeautifulSoup(value, "html.parser") - for script in soup.find_all('script'): + for script in soup.find_all("script"): script.decompose() - #for iframe in soup.find_all('iframe'): - #iframe.decompose() + # for iframe in soup.find_all('iframe'): + # iframe.decompose() - for tag in soup.find_all(['object', 'embed']): + for tag in soup.find_all(["object", "embed"]): tag.decompose() for tag in soup.find_all(): - event_attributes = ['onclick', 'onerror', 'onload', 'onmouseover', 'onfocus'] + event_attributes = ["onclick", "onerror", "onload", "onmouseover", "onfocus"] for attr in event_attributes: if attr in tag.attrs: del tag[attr] - for img in soup.find_all('img'): - if 'onerror' in img.attrs: + for img in soup.find_all("img"): + if "onerror" in img.attrs: img.decompose() return soup.prettify() @@ -126,6 +127,7 @@ def set_link_target_blank(text): return str(soup) + def whitelist_attributes(html): return sanitize_html(html) diff --git a/src/snek/view/channel.py b/src/snek/view/channel.py index 2fcdfc1..8512068 100644 --- a/src/snek/view/channel.py +++ b/src/snek/view/channel.py @@ -13,23 +13,27 @@ from snek.system.view import BaseView register_heif_opener() -from snek.view.drive import DriveApiView +from snek.view.drive import DriveApiView + class ChannelDriveApiView(DriveApiView): - + login_required = True async def get_target(self): - target = await self.services.channel.get_home_folder(self.request.match_info.get("channel_uid")) + target = await self.services.channel.get_home_folder( + self.request.match_info.get("channel_uid") + ) target.mkdir(parents=True, exist_ok=True) return target async def get_download_url(self, rel): return f"/channel/{self.request.match_info.get('channel_uid')}/drive/{urllib.parse.quote(rel)}" + class ChannelAttachmentView(BaseView): - login_required=False + login_required = False async def get(self): relative_path = self.request.match_info.get("relative_url") @@ -114,7 +118,9 @@ class ChannelAttachmentView(BaseView): response.headers["Content-Disposition"] = ( f'attachment; filename="{channel_attachment["name"]}"' ) - response.headers["Content-Type"] = original_format or "application/octet-stream" + response.headers["Content-Type"] = ( + original_format or "application/octet-stream" + ) return response async def post(self): @@ -164,11 +170,11 @@ class ChannelAttachmentView(BaseView): class ChannelAttachmentUploadView(BaseView): - + login_required = True async def get(self): - + channel_uid = self.request.match_info.get("channel_uid") user_uid = self.request.session.get("uid") @@ -184,37 +190,44 @@ class ChannelAttachmentUploadView(BaseView): file = None filename = None - - + msg = await ws.receive() if not msg.type == web.WSMsgType.TEXT: return web.HTTPBadRequest() data = msg.json() - if not data.get('type') == 'start': + if not data.get("type") == "start": return web.HTTPBadRequest() - filename = data['filename'] + filename = data["filename"] attachment = await self.services.channel_attachment.create_file( channel_uid=channel_uid, name=filename, user_uid=user_uid ) pathlib.Path(attachment["path"]).parent.mkdir(parents=True, exist_ok=True) - + async with aiofiles.open(attachment["path"], "wb") as file: print("File openend.", filename) async for msg in ws: if msg.type == web.WSMsgType.BINARY: - print("Binary",filename) + print("Binary", filename) if file is not None: await file.write(msg.data) - await ws.send_json({"type": "progress", "filename": filename, "bytes": await file.tell()}) + await ws.send_json( + { + "type": "progress", + "filename": filename, + "bytes": await file.tell(), + } + ) elif msg.type == web.WSMsgType.TEXT: - print("TExt",filename) + print("TExt", filename) print(msg.json()) - data = msg.json() - if data.get('type') == 'end': + data = msg.json() + if data.get("type") == "end": relative_url = urllib.parse.quote(attachment["relative_url"]) - await ws.send_json({"type": "done", "file": relative_url, "filename": filename}) + await ws.send_json( + {"type": "done", "file": relative_url, "filename": filename} + ) elif msg.type == web.WSMsgType.ERROR: break return ws diff --git a/src/snek/view/container.py b/src/snek/view/container.py index 38d3a36..15008fa 100644 --- a/src/snek/view/container.py +++ b/src/snek/view/container.py @@ -1,10 +1,13 @@ -from snek.system.view import BaseView -import functools -from aiohttp import web +import functools + +from aiohttp import web + +from snek.system.view import BaseView + class ContainerView(BaseView): - login_required= True + login_required = True async def stdout_event_handler(self, ws, data): try: @@ -12,8 +15,8 @@ class ContainerView(BaseView): except Exception as ex: print(ex) await ws.close() - return False - return True + return False + return True async def create_stdout_event_handler(self, ws): return functools.partial(self.stdout_event_handler, ws) @@ -26,23 +29,27 @@ class ContainerView(BaseView): return self.HTTPUnauthorized() channel_uid = self.request.match_info.get("channel_uid") - channel_member = await self.services.channel_member.get(channel_uid=channel_uid, user_uid=self.request.session.get("uid")) + channel_member = await self.services.channel_member.get( + channel_uid=channel_uid, user_uid=self.request.session.get("uid") + ) if not channel_member: return web.HTTPUnauthorized() - + container = await self.services.container.get(channel_uid) if not container: return web.HTTPNotFound() - if not container['status'] == 'running': - resp = await self.services.container.start(channel_uid) - await ws.send_bytes(b'Container is starting\n\n') + if not container["status"] == "running": + await self.services.container.start(channel_uid) + await ws.send_bytes(b"Container is starting\n\n") container_name = await self.services.container.get_container_name(channel_uid) - + event_handler = await self.create_stdout_event_handler(ws) - await self.services.container.add_event_listener(container_name, "stdout", event_handler) + await self.services.container.add_event_listener( + container_name, "stdout", event_handler + ) while True: data = await ws.receive() @@ -54,8 +61,8 @@ class ContainerView(BaseView): elif data.type == web.WSMsgType.ERROR: break - await self.services.container.remove_event_listener(container_name, channel_uid, "stdout") - - return ws - + await self.services.container.remove_event_listener( + container_name, channel_uid, "stdout" + ) + return ws diff --git a/src/snek/view/drive.py b/src/snek/view/drive.py index 733c441..def2b8b 100644 --- a/src/snek/view/drive.py +++ b/src/snek/view/drive.py @@ -1,9 +1,6 @@ import mimetypes -import os import urllib.parse -from datetime import datetime from pathlib import Path -from urllib.parse import quote, unquote from aiohttp import web @@ -42,7 +39,7 @@ class DriveApiView(BaseView): async def get(self): target = await self.get_target() - original_target = target + original_target = target rel = self.request.query.get("path", "") offset = int(self.request.query.get("offset", 0)) limit = int(self.request.query.get("limit", 20)) diff --git a/src/snek/view/forum.py b/src/snek/view/forum.py index c72a248..4396784 100644 --- a/src/snek/view/forum.py +++ b/src/snek/view/forum.py @@ -1,19 +1,15 @@ # views/forum.py -from snek.system.view import BaseView -from aiohttp import web import json - - from aiohttp import web from snek.system.view import BaseView class ForumIndexView(BaseView): - + login_required = True - + async def get(self): if self.login_required and not self.session.get("logged_in"): return web.HTTPFound("/") @@ -71,407 +67,421 @@ class ForumIndexView(BaseView): ) - - - class ForumView(BaseView): """REST API endpoints for forum""" login_required = True async def get_forums(self): - request = self + request = self self = request.app """GET /forum/api/forums - Get all active forums""" forums = [] async for forum in self.services.forum.get_active_forums(): - forums.append({ - "uid": forum["uid"], - "name": forum["name"], - "description": forum["description"], - "slug": forum["slug"], - "icon": forum["icon"], - "thread_count": forum["thread_count"], - "post_count": forum["post_count"], - "last_post_at": forum["last_post_at"], - "last_thread_uid": forum["last_thread_uid"] - }) + forums.append( + { + "uid": forum["uid"], + "name": forum["name"], + "description": forum["description"], + "slug": forum["slug"], + "icon": forum["icon"], + "thread_count": forum["thread_count"], + "post_count": forum["post_count"], + "last_post_at": forum["last_post_at"], + "last_thread_uid": forum["last_thread_uid"], + } + ) return web.json_response({"forums": forums}) async def get_forum(self): - request = self + request = self self = request.app setattr(self, "request", request) """GET /forum/api/forums/:slug - Get forum by slug""" slug = self.request.match_info["slug"] forum = await self.services.forum.get(slug=slug, is_active=True) - + if not forum: return web.json_response({"error": "Forum not found"}, status=404) - + # Get threads threads = [] page = int(self.request.query.get("page", 1)) limit = 50 offset = (page - 1) * limit - + async for thread in forum.get_threads(limit=limit, offset=offset): # Get author info author = await self.services.user.get(uid=thread["created_by_uid"]) last_post_author = None if thread["last_post_by_uid"]: - last_post_author = await self.services.user.get(uid=thread["last_post_by_uid"]) - - threads.append({ - "uid": thread["uid"], - "title": thread["title"], - "slug": thread["slug"], - "is_pinned": thread["is_pinned"], - "is_locked": thread["is_locked"], - "view_count": thread["view_count"], - "post_count": thread["post_count"], - "created_at": thread["created_at"], - "last_post_at": thread["last_post_at"], - "author": { - "uid": author["uid"], - "username": author["username"], - "nick": author["nick"], - "color": author["color"] + last_post_author = await self.services.user.get( + uid=thread["last_post_by_uid"] + ) + + threads.append( + { + "uid": thread["uid"], + "title": thread["title"], + "slug": thread["slug"], + "is_pinned": thread["is_pinned"], + "is_locked": thread["is_locked"], + "view_count": thread["view_count"], + "post_count": thread["post_count"], + "created_at": thread["created_at"], + "last_post_at": thread["last_post_at"], + "author": { + "uid": author["uid"], + "username": author["username"], + "nick": author["nick"], + "color": author["color"], + }, + "last_post_author": ( + { + "uid": last_post_author["uid"], + "username": last_post_author["username"], + "nick": last_post_author["nick"], + "color": last_post_author["color"], + } + if last_post_author + else None + ), + } + ) + + return web.json_response( + { + "forum": { + "uid": forum["uid"], + "name": forum["name"], + "description": forum["description"], + "slug": forum["slug"], + "icon": forum["icon"], + "thread_count": forum["thread_count"], + "post_count": forum["post_count"], }, - "last_post_author": { - "uid": last_post_author["uid"], - "username": last_post_author["username"], - "nick": last_post_author["nick"], - "color": last_post_author["color"] - } if last_post_author else None - }) - - return web.json_response({ - "forum": { - "uid": forum["uid"], - "name": forum["name"], - "description": forum["description"], - "slug": forum["slug"], - "icon": forum["icon"], - "thread_count": forum["thread_count"], - "post_count": forum["post_count"] - }, - "threads": threads, - "page": page, - "hasMore": len(threads) == limit - }) + "threads": threads, + "page": page, + "hasMore": len(threads) == limit, + } + ) async def create_thread(self): - request = self + request = self self = request.app - + setattr(self, "request", request) """POST /forum/api/forums/:slug/threads - Create new thread""" if not self.request.session.get("logged_in"): return web.json_response({"error": "Unauthorized"}, status=401) - + slug = self.request.match_info["slug"] forum = await self.services.forum.get(slug=slug, is_active=True) - + if not forum: return web.json_response({"error": "Forum not found"}, status=404) - + data = await self.request.json() title = data.get("title", "").strip() content = data.get("content", "").strip() - + if not title or not content: - return web.json_response({"error": "Title and content required"}, status=400) - + return web.json_response( + {"error": "Title and content required"}, status=400 + ) + try: thread, post = await self.services.thread.create_thread( forum_uid=forum["uid"], title=title, content=content, - created_by_uid=self.request.session["uid"] + created_by_uid=self.request.session["uid"], ) - - return web.json_response({ - "thread": { - "uid": thread["uid"], - "slug": thread["slug"], - "forum_slug": forum["slug"] + + return web.json_response( + { + "thread": { + "uid": thread["uid"], + "slug": thread["slug"], + "forum_slug": forum["slug"], + } } - }) + ) except Exception as e: return web.json_response({"error": str(e)}, status=400) async def get_thread(self): - request = self + request = self self = request.app setattr(self, "request", request) """GET /forum/api/threads/:thread_slug - Get thread with posts""" thread_slug = self.request.match_info["thread_slug"] thread = await self.services.thread.get(slug=thread_slug) - + if not thread: return web.json_response({"error": "Thread not found"}, status=404) - + # Increment view count await thread.increment_view_count() - + # Get forum forum = await self.services.forum.get(uid=thread["forum_uid"]) - + # Get posts posts = [] page = int(self.request.query.get("page", 1)) limit = 50 offset = (page - 1) * limit - + current_user_uid = self.request.session.get("uid") - + async for post in thread.get_posts(limit=limit, offset=offset): author = await post.get_author() is_liked = False if current_user_uid: is_liked = await post.is_liked_by(current_user_uid) - - posts.append({ - "uid": post["uid"], - "content": post["content"], - "created_at": post["created_at"], - "edited_at": post["edited_at"], - "is_first_post": post["is_first_post"], - "like_count": post["like_count"], - "is_liked": is_liked, - "author": { - "uid": author["uid"], - "username": author["username"], - "nick": author["nick"], - "color": author["color"] + + posts.append( + { + "uid": post["uid"], + "content": post["content"], + "created_at": post["created_at"], + "edited_at": post["edited_at"], + "is_first_post": post["is_first_post"], + "like_count": post["like_count"], + "is_liked": is_liked, + "author": { + "uid": author["uid"], + "username": author["username"], + "nick": author["nick"], + "color": author["color"], + }, } - }) - + ) + # Get thread author thread_author = await self.services.user.get(uid=thread["created_by_uid"]) - - return web.json_response({ - "thread": { - "uid": thread["uid"], - "title": thread["title"], - "slug": thread["slug"], - "is_pinned": thread["is_pinned"], - "is_locked": thread["is_locked"], - "view_count": thread["view_count"], - "post_count": thread["post_count"], - "created_at": thread["created_at"], - "author": { - "uid": thread_author["uid"], - "username": thread_author["username"], - "nick": thread_author["nick"], - "color": thread_author["color"] - } - }, - "forum": { - "uid": forum["uid"], - "name": forum["name"], - "slug": forum["slug"] - }, - "posts": posts, - "page": page, - "hasMore": len(posts) == limit - }) + + return web.json_response( + { + "thread": { + "uid": thread["uid"], + "title": thread["title"], + "slug": thread["slug"], + "is_pinned": thread["is_pinned"], + "is_locked": thread["is_locked"], + "view_count": thread["view_count"], + "post_count": thread["post_count"], + "created_at": thread["created_at"], + "author": { + "uid": thread_author["uid"], + "username": thread_author["username"], + "nick": thread_author["nick"], + "color": thread_author["color"], + }, + }, + "forum": { + "uid": forum["uid"], + "name": forum["name"], + "slug": forum["slug"], + }, + "posts": posts, + "page": page, + "hasMore": len(posts) == limit, + } + ) async def create_post(self): - request = self + request = self self = request.app setattr(self, "request", request) """POST /forum/api/threads/:thread_uid/posts - Create new post""" if not self.request.session.get("logged_in"): return web.json_response({"error": "Unauthorized"}, status=401) - + thread_uid = self.request.match_info["thread_uid"] thread = await self.services.thread.get(uid=thread_uid) - + if not thread: return web.json_response({"error": "Thread not found"}, status=404) - + data = await self.request.json() content = data.get("content", "").strip() - + if not content: return web.json_response({"error": "Content required"}, status=400) - + try: post = await self.services.post.create_post( thread_uid=thread["uid"], forum_uid=thread["forum_uid"], content=content, - created_by_uid=self.request.session["uid"] + created_by_uid=self.request.session["uid"], ) - + author = await post.get_author() - - return web.json_response({ - "post": { - "uid": post["uid"], - "content": post["content"], - "created_at": post["created_at"], - "like_count": post["like_count"], - "is_liked": False, - "author": { - "uid": author["uid"], - "username": author["username"], - "nick": author["nick"], - "color": author["color"] + + return web.json_response( + { + "post": { + "uid": post["uid"], + "content": post["content"], + "created_at": post["created_at"], + "like_count": post["like_count"], + "is_liked": False, + "author": { + "uid": author["uid"], + "username": author["username"], + "nick": author["nick"], + "color": author["color"], + }, } } - }) + ) except Exception as e: return web.json_response({"error": str(e)}, status=400) async def edit_post(self): - request = self + request = self self = request.app setattr(self, "request", request) """PUT /forum/api/posts/:post_uid - Edit post""" if not self.request.session.get("logged_in"): return web.json_response({"error": "Unauthorized"}, status=401) - + post_uid = self.request.match_info["post_uid"] data = await self.request.json() content = data.get("content", "").strip() - + if not content: return web.json_response({"error": "Content required"}, status=400) - + post = await self.services.post.edit_post( - post_uid=post_uid, - content=content, - user_uid=self.request.session["uid"] + post_uid=post_uid, content=content, user_uid=self.request.session["uid"] ) - + if not post: return web.json_response({"error": "Cannot edit post"}, status=403) - - return web.json_response({ - "post": { - "uid": post["uid"], - "content": post["content"], - "edited_at": post["edited_at"] + + return web.json_response( + { + "post": { + "uid": post["uid"], + "content": post["content"], + "edited_at": post["edited_at"], + } } - }) + ) async def delete_post(self): - request = self + request = self self = request.app setattr(self, "request", request) """DELETE /forum/api/posts/:post_uid - Delete post""" if not self.request.session.get("logged_in"): return web.json_response({"error": "Unauthorized"}, status=401) - + post_uid = self.request.match_info["post_uid"] - + success = await self.services.post.delete_post( - post_uid=post_uid, - user_uid=self.request.session["uid"] + post_uid=post_uid, user_uid=self.request.session["uid"] ) - + if not success: return web.json_response({"error": "Cannot delete post"}, status=403) - + return web.json_response({"success": True}) async def toggle_like(self): - request = self + request = self self = request.app setattr(self, "request", request) """POST /forum/api/posts/:post_uid/like - Toggle post like""" if not self.request.session.get("logged_in"): return web.json_response({"error": "Unauthorized"}, status=401) - + post_uid = self.request.match_info["post_uid"] - + is_liked = await self.services.post_like.toggle_like( - post_uid=post_uid, - user_uid=self.request.session["uid"] + post_uid=post_uid, user_uid=self.request.session["uid"] ) - + if is_liked is None: return web.json_response({"error": "Failed to toggle like"}, status=400) - + # Get updated post post = await self.services.post.get(uid=post_uid) - - return web.json_response({ - "is_liked": is_liked, - "like_count": post["like_count"] - }) + + return web.json_response( + {"is_liked": is_liked, "like_count": post["like_count"]} + ) async def toggle_pin(self): - request = self + request = self self = request.app setattr(self, "request", request) """POST /forum/api/threads/:thread_uid/pin - Toggle thread pin""" if not self.request.session.get("logged_in"): return web.json_response({"error": "Unauthorized"}, status=401) - + thread_uid = self.request.match_info["thread_uid"] - + thread = await self.services.thread.toggle_pin( - thread_uid=thread_uid, - user_uid=self.request.session["uid"] + thread_uid=thread_uid, user_uid=self.request.session["uid"] ) - + if not thread: return web.json_response({"error": "Cannot toggle pin"}, status=403) - + return web.json_response({"is_pinned": thread["is_pinned"]}) async def toggle_lock(self): - request = self + request = self self = request.app - + setattr(self, "request", request) """POST /forum/api/threads/:thread_uid/lock - Toggle thread lock""" if not self.request.session.get("logged_in"): return web.json_response({"error": "Unauthorized"}, status=401) - + thread_uid = self.request.match_info["thread_uid"] - + thread = await self.services.thread.toggle_lock( - thread_uid=thread_uid, - user_uid=self.request.session["uid"] + thread_uid=thread_uid, user_uid=self.request.session["uid"] ) - + if not thread: return web.json_response({"error": "Cannot toggle lock"}, status=403) - + return web.json_response({"is_locked": thread["is_locked"]}) # views/forum_websocket.py class ForumWebSocketView(BaseView): """WebSocket view for real-time forum updates""" - + async def get(self): ws = web.WebSocketResponse() await ws.prepare(self.request) - + # Store WebSocket connection ws_id = self.services.forum.generate_uid() - if not hasattr(self.app, 'forum_websockets'): + if not hasattr(self.app, "forum_websockets"): self.app.forum_websockets = {} - + user_uid = self.request.session.get("uid") self.app.forum_websockets[ws_id] = { "ws": ws, "user_uid": user_uid, - "subscriptions": set() + "subscriptions": set(), } - + try: async for msg in ws: if msg.type == web.WSMsgType.TEXT: @@ -483,68 +493,76 @@ class ForumWebSocketView(BaseView): # Clean up if ws_id in self.app.forum_websockets: del self.app.forum_websockets[ws_id] - + return ws - + async def handle_ws_message(self, ws_id, data): """Handle incoming WebSocket messages""" action = data.get("action") - + if action == "subscribe": # Subscribe to forum/thread updates target_type = data.get("type") # "forum" or "thread" target_id = data.get("id") - + if target_type and target_id: subscription = f"{target_type}:{target_id}" self.app.forum_websockets[ws_id]["subscriptions"].add(subscription) - + # Send confirmation ws = self.app.forum_websockets[ws_id]["ws"] - await ws.send_str(json.dumps({ - "type": "subscribed", - "subscription": subscription - })) - + await ws.send_str( + json.dumps({"type": "subscribed", "subscription": subscription}) + ) + elif action == "unsubscribe": target_type = data.get("type") target_id = data.get("id") - + if target_type and target_id: subscription = f"{target_type}:{target_id}" self.app.forum_websockets[ws_id]["subscriptions"].discard(subscription) - + # Send confirmation ws = self.app.forum_websockets[ws_id]["ws"] - await ws.send_str(json.dumps({ - "type": "unsubscribed", - "subscription": subscription - })) - + await ws.send_str( + json.dumps({"type": "unsubscribed", "subscription": subscription}) + ) + @staticmethod async def broadcast_update(app, event_type, data): """Broadcast updates to subscribed WebSocket clients""" - if not hasattr(app, 'forum_websockets'): + if not hasattr(app, "forum_websockets"): return - + # Determine subscription targets based on event targets = set() - - if event_type in ["thread_created", "post_created", "post_edited", "post_deleted"]: + + if event_type in [ + "thread_created", + "post_created", + "post_edited", + "post_deleted", + ]: if "forum_uid" in data: targets.add(f"forum:{data['forum_uid']}") - - if event_type in ["post_created", "post_edited", "post_deleted", "post_liked", "post_unliked"]: + + if event_type in [ + "post_created", + "post_edited", + "post_deleted", + "post_liked", + "post_unliked", + ]: if "thread_uid" in data: targets.add(f"thread:{data['thread_uid']}") - + # Send to subscribed clients for ws_id, ws_data in app.forum_websockets.items(): if ws_data["subscriptions"] & targets: try: - await ws_data["ws"].send_str(json.dumps({ - "type": event_type, - "data": data - })) + await ws_data["ws"].send_str( + json.dumps({"type": event_type, "data": data}) + ) except Exception: pass diff --git a/src/snek/view/register_form.py b/src/snek/view/register_form.py index 850e672..60a9624 100644 --- a/src/snek/view/register_form.py +++ b/src/snek/view/register_form.py @@ -30,7 +30,7 @@ from snek.system.view import BaseFormView class RegisterFormView(BaseFormView): - + login_required = False form = RegisterForm diff --git a/src/snek/view/rpc.py b/src/snek/view/rpc.py index 88751ff..71feb86 100644 --- a/src/snek/view/rpc.py +++ b/src/snek/view/rpc.py @@ -6,18 +6,18 @@ # MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions. -from snek.system.stats import update_websocket_stats import asyncio import json import logging +import time import traceback -import random + from aiohttp import web from snek.system.model import now -from snek.system.profiler import Profiler +from snek.system.stats import update_websocket_stats from snek.system.view import BaseView -import time + logger = logging.getLogger(__name__) @@ -203,7 +203,7 @@ class RPCView(BaseView): } ) return channels - + async def clear_channel(self, channel_uid): self._require_login() user = await self.services.user.get(uid=self.user_uid) @@ -211,41 +211,49 @@ class RPCView(BaseView): raise Exception("Not allowed") return await self.services.channel_message.clear(channel_uid) - async def write_container(self, channel_uid, content,timeout=3): + async def write_container(self, channel_uid, content, timeout=3): self._require_login() channel_member = await self.services.channel_member.get( channel_uid=channel_uid, user_uid=self.user_uid ) if not channel_member: raise Exception("Not allowed") - - container_name = await self.services.container.get_container_name(channel_uid) + + container_name = await self.services.container.get_container_name( + channel_uid + ) class SessionCall: - def __init__(self, app,channel_uid_uid, container_name): - self.app = app - self.channel_uid = channel_uid - self.container_name = container_name + def __init__(self, app, channel_uid_uid, container_name): + self.app = app + self.channel_uid = channel_uid + self.container_name = container_name self.time_last_output = time.time() - self.output = b'' - + self.output = b"" + async def stdout_event_handler(self, data): self.time_last_output = time.time() self.output += data - return True + return True - async def communicate(self,content, timeout=3): - await self.app.services.container.add_event_listener(self.container_name, "stdout", self.stdout_event_handler) - await self.app.services.container.write_stdin(self.channel_uid, content) + async def communicate(self, content, timeout=3): + await self.app.services.container.add_event_listener( + self.container_name, "stdout", self.stdout_event_handler + ) + await self.app.services.container.write_stdin( + self.channel_uid, content + ) while time.time() - self.time_last_output < timeout: await asyncio.sleep(0.1) - await self.app.services.container.remove_event_listener(self.container_name, "stdout", self.stdout_event_handler) + await self.app.services.container.remove_event_listener( + self.container_name, "stdout", self.stdout_event_handler + ) return self.output - - sc = SessionCall(self, channel_uid,container_name) - return (await sc.communicate(content)).decode("utf-8","ignore") + + sc = SessionCall(self, channel_uid, container_name) + return (await sc.communicate(content)).decode("utf-8", "ignore") async def get_container(self, channel_uid): self._require_login() @@ -258,44 +266,44 @@ class RPCView(BaseView): result = None if container: result = { - "name": await self.services.container.get_container_name(channel_uid), + "name": await self.services.container.get_container_name( + channel_uid + ), "cpus": container["deploy"]["resources"]["limits"]["cpus"], "memory": container["deploy"]["resources"]["limits"]["memory"], "image": "ubuntu:latest", "volumes": [], - "status": container["status"] + "status": container["status"], } return result - - async def _finalize_message_task(self,message_uid): + async def _finalize_message_task(self, message_uid): try: for _ in range(7): await asyncio.sleep(1) if not self._finalize_task: - return + return except asyncio.exceptions.CancelledError: print("Finalization cancelled.") - return - - self._finalize_task = None - message = await self.services.channel_message.get(message_uid, is_final=False) + return + + self._finalize_task = None + message = await self.services.channel_message.get( + message_uid, is_final=False + ) if not message: return False if message["user_uid"] != self.user_uid: raise Exception("Not allowed") if message["is_final"]: - return True + return True if not message["is_final"]: await self.services.chat.finalize(message["uid"]) - - return True - async def _queue_finalize_message(self, message_uid): if self._finalize_task: self._finalize_task.cancel() @@ -305,18 +313,22 @@ class RPCView(BaseView): async def send_message(self, channel_uid, message, is_final=True): self._require_login() - + message = message.strip() - + if not is_final: - if_final = False - check_message = await self.services.channel_message.get(channel_uid=channel_uid, user_uid=self.user_uid,is_final=False,deleted_at=None) + check_message = await self.services.channel_message.get( + channel_uid=channel_uid, + user_uid=self.user_uid, + is_final=False, + deleted_at=None, + ) if check_message: await self._queue_finalize_message(check_message["uid"]) return await self.update_message_text(check_message["uid"], message) - + if not message: - return + return message = await self.services.chat.send( self.user_uid, channel_uid, message, is_final @@ -328,7 +340,6 @@ class RPCView(BaseView): return message["uid"] - async def start_container(self, channel_uid): self._require_login() channel_member = await self.services.channel_member.get( @@ -355,8 +366,6 @@ class RPCView(BaseView): if not channel_member: raise Exception("Not allowed") return await self.services.container.get_status(channel_uid) - - async def finalize_message(self, message_uid): self._require_login() @@ -373,8 +382,8 @@ class RPCView(BaseView): return True async def update_message_text(self, message_uid, text): - - #async with self.app.no_save(): + + # async with self.app.no_save(): self._require_login() message = await self.services.channel_message.get(message_uid) @@ -403,7 +412,6 @@ class RPCView(BaseView): "success": False, } - if not text: message["deleted_at"] = now() else: @@ -427,9 +435,6 @@ class RPCView(BaseView): return {"success": True} - - - async def clear_channel(self, channel_uid): self._require_login() user = await self.services.user.get(uid=self.user_uid) @@ -438,11 +443,10 @@ class RPCView(BaseView): channel = await self.services.channel.get(uid=channel_uid) if not channel: raise Exception("Channel not found") - channel['history_start'] = datetime.now() + channel["history_start"] = datetime.now() await self.services.channel.save(channel) return await self.services.channel_message.clear(channel_uid) - async def echo(self, *args): self._require_login() return args @@ -528,10 +532,10 @@ class RPCView(BaseView): async def _send_json(self, obj): try: await self.ws.send_str(json.dumps(obj, default=str)) - except Exception as ex: + except Exception: await self.services.socket.delete(self.ws) await self.ws.close() - + async def get_online_users(self, channel_uid): self._require_login() diff --git a/src/snek/view/web.py b/src/snek/view/web.py index 1152d4c..0773526 100644 --- a/src/snek/view/web.py +++ b/src/snek/view/web.py @@ -38,10 +38,10 @@ class WebView(BaseView): channel = await self.services.channel.get( uid=self.request.match_info.get("channel") ) - print(self.session.get("uid"),"ZZZZZZZZZZ") + print(self.session.get("uid"), "ZZZZZZZZZZ") qq = await self.services.user.get(uid=self.session.get("uid")) - - print("GGGGGGGGGG",qq) + + print("GGGGGGGGGG", qq) if not channel: user = await self.services.user.get( uid=self.request.match_info.get("channel") diff --git a/src/snek/webdav.py b/src/snek/webdav.py index eabb1e5..2b2d58d 100644 --- a/src/snek/webdav.py +++ b/src/snek/webdav.py @@ -1,22 +1,24 @@ -import logging -import pathlib import base64 import datetime +import hashlib import json +import logging import mimetypes import os +import pathlib import shutil import uuid +from urllib.parse import urlparse + import aiofiles import aiohttp import aiohttp.web -import hashlib from app.cache import time_cache_async from lxml import etree -from urllib.parse import urlparse logging.basicConfig(level=logging.DEBUG) + @aiohttp.web.middleware async def debug_middleware(request, handler): print(request.method, request.path, request.headers) @@ -28,6 +30,7 @@ async def debug_middleware(request, handler): pass return result + class WebdavApplication(aiohttp.web.Application): def __init__(self, parent, *args, **kwargs): middlewares = [debug_middleware] @@ -78,7 +81,7 @@ class WebdavApplication(aiohttp.web.Application): return request["user"] async def is_locked(self, abs_path, request): - rel_path = abs_path.relative_to(request["home"]).as_posix() + abs_path.relative_to(request["home"]).as_posix() path = abs_path while path != request["home"].parent: rel = path.relative_to(request["home"]).as_posix() @@ -122,14 +125,20 @@ class WebdavApplication(aiohttp.web.Application): return aiohttp.web.Response(status=403, text="Cannot download a directory") lock = await self.is_locked(abs_path, request) if lock: - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=423, text="Locked") content_type, _ = mimetypes.guess_type(str(abs_path)) content_type = content_type or "application/octet-stream" etag = f'"{hashlib.sha1(str(abs_path.stat().st_mtime).encode()).hexdigest()}"' headers = {"Content-Type": content_type, "ETag": etag} - return aiohttp.web.FileResponse(path=str(abs_path), headers=headers, chunk_size=8192) + return aiohttp.web.FileResponse( + path=str(abs_path), headers=headers, chunk_size=8192 + ) async def handle_head(self, request): if not await self.authenticate(request): @@ -141,10 +150,16 @@ class WebdavApplication(aiohttp.web.Application): if not abs_path.exists(): return aiohttp.web.Response(status=404, text="File not found") if abs_path.is_dir(): - return aiohttp.web.Response(status=403, text="Cannot get metadata for a directory") + return aiohttp.web.Response( + status=403, text="Cannot get metadata for a directory" + ) lock = await self.is_locked(abs_path, request) if lock: - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=423, text="Locked") content_type, _ = mimetypes.guess_type(str(abs_path)) @@ -168,7 +183,11 @@ class WebdavApplication(aiohttp.web.Application): abs_path = request["home"] / requested_path lock = await self.is_locked(abs_path, request) if lock: - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=423, text="Locked") abs_path.parent.mkdir(parents=True, exist_ok=True) @@ -186,7 +205,11 @@ class WebdavApplication(aiohttp.web.Application): abs_path = request["home"] / requested_path lock = await self.is_locked(abs_path, request) if lock: - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=423, text="Locked") if abs_path.is_file(): @@ -199,7 +222,9 @@ class WebdavApplication(aiohttp.web.Application): del self.locks[rel_path] return aiohttp.web.Response(status=204) elif abs_path.is_dir(): - if await self.has_descendant_locks(abs_path.relative_to(request["home"]).as_posix(), True): + if await self.has_descendant_locks( + abs_path.relative_to(request["home"]).as_posix(), True + ): return aiohttp.web.Response(status=423, text="Locked") shutil.rmtree(abs_path) rel_path = abs_path.relative_to(request["home"]).as_posix() @@ -217,7 +242,11 @@ class WebdavApplication(aiohttp.web.Application): abs_path = request["home"] / requested_path lock = await self.is_locked(abs_path, request) if lock: - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=423, text="Locked") if abs_path.exists(): @@ -235,15 +264,25 @@ class WebdavApplication(aiohttp.web.Application): if not destination: return aiohttp.web.Response(status=400, text="Destination header missing") parsed = urlparse(destination) - if parsed.scheme and (parsed.scheme != request.scheme or parsed.netloc != request.host): + if parsed.scheme and ( + parsed.scheme != request.scheme or parsed.netloc != request.host + ): return aiohttp.web.Response(status=502, text="Bad Gateway") - dest_rel = parsed.path[len(self.relative_url)+1:] if parsed.path.startswith(self.relative_url) else "" + dest_rel = ( + parsed.path[len(self.relative_url) + 1 :] + if parsed.path.startswith(self.relative_url) + else "" + ) dest_path = request["home"] / dest_rel if not src_path.exists(): return aiohttp.web.Response(status=404, text="Source not found") src_lock = await self.is_locked(src_path, request) dest_lock = await self.is_locked(dest_path, request) - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if src_lock and src_lock["token"] != submitted: return aiohttp.web.Response(status=423, text="Source Locked") if dest_lock and dest_lock["token"] != submitted: @@ -277,15 +316,25 @@ class WebdavApplication(aiohttp.web.Application): if not destination: return aiohttp.web.Response(status=400, text="Destination header missing") parsed = urlparse(destination) - if parsed.scheme and (parsed.scheme != request.scheme or parsed.netloc != request.host): + if parsed.scheme and ( + parsed.scheme != request.scheme or parsed.netloc != request.host + ): return aiohttp.web.Response(status=502, text="Bad Gateway") - dest_rel = parsed.path[len(self.relative_url)+1:] if parsed.path.startswith(self.relative_url) else "" + dest_rel = ( + parsed.path[len(self.relative_url) + 1 :] + if parsed.path.startswith(self.relative_url) + else "" + ) dest_path = request["home"] / dest_rel if not src_path.exists(): return aiohttp.web.Response(status=404, text="Source not found") src_lock = await self.is_locked(src_path, request) dest_lock = await self.is_locked(dest_path, request) - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if src_lock and src_lock["token"] != submitted: return aiohttp.web.Response(status=423, text="Source Locked") if dest_lock and dest_lock["token"] != submitted: @@ -347,7 +396,16 @@ class WebdavApplication(aiohttp.web.Application): statvfs = await loop.run_in_executor(None, os.statvfs, path) return statvfs.f_bavail * statvfs.f_frsize - async def create_propfind_node(self, request, response_xml, full_path, depth, requested_props, propname_only, all_props): + async def create_propfind_node( + self, + request, + response_xml, + full_path, + depth, + requested_props, + propname_only, + all_props, + ): abs_path = pathlib.Path(full_path) relative_path = abs_path.relative_to(request["home"]).as_posix() href_path = f"{self.relative_url}/{relative_path}".strip(".") @@ -379,7 +437,9 @@ class WebdavApplication(aiohttp.web.Application): prop_notfound = None for prop_tag, prop_nsmap in requested_props: if prop_tag in standard_props: - await self.add_standard_property(prop_ok, prop_tag, full_path, request) + await self.add_standard_property( + prop_ok, prop_tag, full_path, request + ) found_props.append(prop_tag) elif prop_tag in custom_properties: if propname_only: @@ -391,23 +451,31 @@ class WebdavApplication(aiohttp.web.Application): else: if propstat_notfound is None: propstat_notfound = etree.SubElement(response, "{DAV:}propstat") - prop_notfound = etree.SubElement(propstat_notfound, "{DAV:}prop") + prop_notfound = etree.SubElement( + propstat_notfound, "{DAV:}prop" + ) etree.SubElement(prop_notfound, prop_tag, nsmap=prop_nsmap) if propstat_notfound is not None: - etree.SubElement(propstat_notfound, "{DAV:}status").text = "HTTP/1.1 404 Not Found" + etree.SubElement(propstat_notfound, "{DAV:}status").text = ( + "HTTP/1.1 404 Not Found" + ) else: for prop_name in standard_props: if propname_only: etree.SubElement(prop_ok, prop_name) else: - await self.add_standard_property(prop_ok, prop_name, full_path, request) + await self.add_standard_property( + prop_ok, prop_name, full_path, request + ) for prop_name, prop_value in custom_properties.items(): if propname_only: if prop_name.startswith("{"): ns_end = prop_name.find("}") ns = prop_name[1:ns_end] - local_name = prop_name[ns_end+1:] - prop_nsmap = {"D": "DAV:", None: ns} if ns != "DAV:" else {"D": "DAV:"} + local_name = prop_name[ns_end + 1 :] + prop_nsmap = ( + {"D": "DAV:", None: ns} if ns != "DAV:" else {"D": "DAV:"} + ) etree.SubElement(prop_ok, local_name, nsmap=prop_nsmap) else: etree.SubElement(prop_ok, prop_name) @@ -417,9 +485,19 @@ class WebdavApplication(aiohttp.web.Application): etree.SubElement(propstat_ok, "{DAV:}status").text = "HTTP/1.1 200 OK" if abs_path.is_dir() and depth > 0: for item in abs_path.iterdir(): - if item.name.startswith(".") and item.name.endswith(".webdav_props.json"): + if item.name.startswith(".") and item.name.endswith( + ".webdav_props.json" + ): continue - await self.create_propfind_node(request, response_xml, item, depth - 1, requested_props, propname_only, all_props) + await self.create_propfind_node( + request, + response_xml, + item, + depth - 1, + requested_props, + propname_only, + all_props, + ) async def add_standard_property(self, prop_elem, prop_name, full_path, request): if prop_name == "{DAV:}resourcetype": @@ -445,14 +523,22 @@ class WebdavApplication(aiohttp.web.Application): etree.SubElement(locktype, f"{{DAV:}}{lock['type']}") lockscope = etree.SubElement(activelock, "{DAV:}lockscope") etree.SubElement(lockscope, f"{{DAV:}}{lock['scope']}") - etree.SubElement(activelock, "{DAV:}depth").text = lock['depth'].capitalize() - if lock['owner']: - owner = etree.fromstring(lock['owner']) + etree.SubElement(activelock, "{DAV:}depth").text = lock[ + "depth" + ].capitalize() + if lock["owner"]: + owner = etree.fromstring(lock["owner"]) activelock.append(owner) - timeout_str = "Infinite" if lock['timeout'] is None else f"Second-{lock['timeout']}" + timeout_str = ( + "Infinite" + if lock["timeout"] is None + else f"Second-{lock['timeout']}" + ) etree.SubElement(activelock, "{DAV:}timeout").text = timeout_str locktoken = etree.SubElement(activelock, "{DAV:}locktoken") - etree.SubElement(locktoken, "{DAV:}href").text = f"opaquelocktoken:{lock['token']}" + etree.SubElement(locktoken, "{DAV:}href").text = ( + f"opaquelocktoken:{lock['token']}" + ) elif prop_name == "{DAV:}supportedlock": supported_lock = etree.SubElement(prop_elem, "{DAV:}supportedlock") lock_entry_1 = etree.SubElement(supported_lock, "{DAV:}lockentry") @@ -466,11 +552,17 @@ class WebdavApplication(aiohttp.web.Application): lock_type_2 = etree.SubElement(lock_entry_2, "{DAV:}locktype") etree.SubElement(lock_type_2, "{DAV:}write") elif prop_name == "{DAV:}quota-used-bytes": - size = await self.get_file_size(full_path) if full_path.is_file() else await self.get_directory_size(full_path) + size = ( + await self.get_file_size(full_path) + if full_path.is_file() + else await self.get_directory_size(full_path) + ) etree.SubElement(prop_elem, "{DAV:}quota-used-bytes").text = str(size) elif prop_name == "{DAV:}quota-available-bytes": free_space = await self.get_disk_free_space(str(full_path)) - etree.SubElement(prop_elem, "{DAV:}quota-available-bytes").text = str(free_space) + etree.SubElement(prop_elem, "{DAV:}quota-available-bytes").text = str( + free_space + ) elif prop_name == "{DAV:}getcontentlength" and full_path.is_file(): size = await self.get_file_size(full_path) etree.SubElement(prop_elem, "{DAV:}getcontentlength").text = str(size) @@ -479,7 +571,9 @@ class WebdavApplication(aiohttp.web.Application): if mimetype: etree.SubElement(prop_elem, "{DAV:}getcontenttype").text = mimetype elif prop_name == "{DAV:}getetag" and full_path.is_file(): - etag = f'"{hashlib.sha1(str(full_path.stat().st_mtime).encode()).hexdigest()}"' + etag = ( + f'"{hashlib.sha1(str(full_path.stat().st_mtime).encode()).hexdigest()}"' + ) etree.SubElement(prop_elem, "{DAV:}getetag").text = etag async def handle_propfind(self, request): @@ -520,9 +614,21 @@ class WebdavApplication(aiohttp.web.Application): pass nsmap = {"D": "DAV:"} response_xml = etree.Element("{DAV:}multistatus", nsmap=nsmap) - await self.create_propfind_node(request, response_xml, abs_path, depth, requested_props, propname_only, all_props) - xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode() - return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml") + await self.create_propfind_node( + request, + response_xml, + abs_path, + depth, + requested_props, + propname_only, + all_props, + ) + xml_output = etree.tostring( + response_xml, encoding="utf-8", xml_declaration=True + ).decode() + return aiohttp.web.Response( + status=207, text=xml_output, content_type="application/xml" + ) async def handle_proppatch(self, request): if not await self.authenticate(request): @@ -535,7 +641,11 @@ class WebdavApplication(aiohttp.web.Application): return aiohttp.web.Response(status=404, text="Resource not found") lock = await self.is_locked(abs_path, request) if lock: - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=423, text="Locked") body = await request.read() @@ -549,7 +659,11 @@ class WebdavApplication(aiohttp.web.Application): response = etree.SubElement(response_xml, "{DAV:}response") href = etree.SubElement(response, "{DAV:}href") relative_path = abs_path.relative_to(request["home"]).as_posix() - href_path = f"{self.relative_url}/{relative_path}".strip(".").replace("./", "/").replace("//", "/") + href_path = ( + f"{self.relative_url}/{relative_path}".strip(".") + .replace("./", "/") + .replace("//", "/") + ) href.text = href_path set_elem = root.find(".//{DAV:}set") if set_elem is not None: @@ -561,14 +675,14 @@ class WebdavApplication(aiohttp.web.Application): if child.tag.startswith("{"): ns_end = child.tag.find("}") ns = child.tag[1:ns_end] - local_name = child.tag[ns_end+1:] + local_name = child.tag[ns_end + 1 :] else: ns = "DAV:" local_name = child.tag prop_name = f"{{{ns}}}{local_name}" prop_value = child.text or "" properties[prop_name] = prop_value - elem = etree.SubElement(prop, child.tag, nsmap=child.nsmap) + etree.SubElement(prop, child.tag, nsmap=child.nsmap) etree.SubElement(propstat, "{DAV:}status").text = "HTTP/1.1 200 OK" remove_elem = root.find(".//{DAV:}remove") if remove_elem is not None: @@ -580,7 +694,7 @@ class WebdavApplication(aiohttp.web.Application): if child.tag.startswith("{"): ns_end = child.tag.find("}") ns = child.tag[1:ns_end] - local_name = child.tag[ns_end+1:] + local_name = child.tag[ns_end + 1 :] else: ns = "DAV:" local_name = child.tag @@ -590,14 +704,18 @@ class WebdavApplication(aiohttp.web.Application): status = "HTTP/1.1 200 OK" else: status = "HTTP/1.1 404 Not Found" - elem = etree.SubElement(prop, child.tag, nsmap=child.nsmap) + etree.SubElement(prop, child.tag, nsmap=child.nsmap) etree.SubElement(propstat, "{DAV:}status").text = status if properties: await self.save_properties(abs_path, properties) else: await self.delete_properties_file(abs_path) - xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode() - return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml") + xml_output = etree.tostring( + response_xml, encoding="utf-8", xml_declaration=True + ).decode() + return aiohttp.web.Response( + status=207, text=xml_output, content_type="application/xml" + ) except Exception as e: logging.error(f"PROPPATCH error: {e}") return aiohttp.web.Response(status=400, text="Bad Request") @@ -639,11 +757,19 @@ class WebdavApplication(aiohttp.web.Application): scope = "shared" locktype = "write" owner_elem = lockinfo.find("{DAV:}owner") - owner_xml = etree.tostring(owner_elem, encoding="unicode") if owner_elem is not None else "" + owner_xml = ( + etree.tostring(owner_elem, encoding="unicode") + if owner_elem is not None + else "" + ) covering_lock = await self.is_locked(abs_path, request) if covering_lock: return aiohttp.web.Response(status=423, text="Locked") - if abs_path.is_dir() and depth == "infinity" and await self.has_descendant_locks(resource, True): + if ( + abs_path.is_dir() + and depth == "infinity" + and await self.has_descendant_locks(resource, True) + ): return aiohttp.web.Response(status=423, text="Locked") token = str(uuid.uuid4()) lock_dict = { @@ -653,12 +779,17 @@ class WebdavApplication(aiohttp.web.Application): "owner": owner_xml, "depth": depth, "timeout": timeout, - "created": datetime.datetime.utcnow() + "created": datetime.datetime.utcnow(), } self.locks[resource] = lock_dict xml = await self.generate_lock_response(lock_dict) headers = {"Lock-Token": f""} - return aiohttp.web.Response(status=200, text=xml, content_type="application/xml", headers=headers) + return aiohttp.web.Response( + status=200, + text=xml, + content_type="application/xml", + headers=headers, + ) except Exception as e: logging.error(f"LOCK error: {e}") return aiohttp.web.Response(status=400, text="Bad Request") @@ -669,14 +800,20 @@ class WebdavApplication(aiohttp.web.Application): if self.is_lock_expired(lock): del self.locks[resource] return aiohttp.web.Response(status=412, text="Precondition Failed") - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=409, text="Conflict") lock["created"] = datetime.datetime.utcnow() lock["timeout"] = timeout xml = await self.generate_lock_response(lock) headers = {"Lock-Token": f""} - return aiohttp.web.Response(status=200, text=xml, content_type="application/xml", headers=headers) + return aiohttp.web.Response( + status=200, text=xml, content_type="application/xml", headers=headers + ) async def handle_unlock(self, request): if not await self.authenticate(request): @@ -690,7 +827,11 @@ class WebdavApplication(aiohttp.web.Application): if self.is_lock_expired(lock): del self.locks[resource] return aiohttp.web.Response(status=409, text="Conflict") - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=400, text="Invalid Lock Token") del self.locks[resource] @@ -705,14 +846,22 @@ class WebdavApplication(aiohttp.web.Application): etree.SubElement(locktype, f"{{DAV:}}{lock_dict['type']}") lockscope = etree.SubElement(activelock, "{DAV:}lockscope") etree.SubElement(lockscope, f"{{DAV:}}{lock_dict['scope']}") - etree.SubElement(activelock, "{DAV:}depth").text = lock_dict['depth'].capitalize() - if lock_dict['owner']: - owner = etree.fromstring(lock_dict['owner']) + etree.SubElement(activelock, "{DAV:}depth").text = lock_dict[ + "depth" + ].capitalize() + if lock_dict["owner"]: + owner = etree.fromstring(lock_dict["owner"]) activelock.append(owner) - timeout_str = "Infinite" if lock_dict['timeout'] is None else f"Second-{lock_dict['timeout']}" + timeout_str = ( + "Infinite" + if lock_dict["timeout"] is None + else f"Second-{lock_dict['timeout']}" + ) etree.SubElement(activelock, "{DAV:}timeout").text = timeout_str locktoken = etree.SubElement(activelock, "{DAV:}locktoken") - etree.SubElement(locktoken, "{DAV:}href").text = f"opaquelocktoken:{lock_dict['token']}" + etree.SubElement(locktoken, "{DAV:}href").text = ( + f"opaquelocktoken:{lock_dict['token']}" + ) return etree.tostring(prop, encoding="utf-8", xml_declaration=True).decode() def get_last_modified(self, path): @@ -731,7 +880,7 @@ class WebdavApplication(aiohttp.web.Application): async def load_properties(self, resource_path): props_file = self.get_props_file_path(resource_path) if props_file.exists(): - async with aiofiles.open(props_file, "r") as f: + async with aiofiles.open(props_file) as f: content = await f.read() return json.loads(content) return {} @@ -773,7 +922,11 @@ class WebdavApplication(aiohttp.web.Application): response = etree.SubElement(response_xml, "{DAV:}response") href = etree.SubElement(response, "{DAV:}href") relative_path = abs_path.relative_to(request["home"]).as_posix() - href_path = f"{self.relative_url}/{relative_path}".strip(".").replace("./", "/").replace("//", "/") + href_path = ( + f"{self.relative_url}/{relative_path}".strip(".") + .replace("./", "/") + .replace("//", "/") + ) href.text = href_path propstat = etree.SubElement(response, "{DAV:}propstat") prop = etree.SubElement(propstat, "{DAV:}prop") @@ -787,8 +940,12 @@ class WebdavApplication(aiohttp.web.Application): elem = etree.SubElement(prop, prop_name) elem.text = str(prop_value) etree.SubElement(propstat, "{DAV:}status").text = "HTTP/1.1 200 OK" - xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode() - return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml") + xml_output = etree.tostring( + response_xml, encoding="utf-8", xml_declaration=True + ).decode() + return aiohttp.web.Response( + status=207, text=xml_output, content_type="application/xml" + ) async def handle_propset(self, request): if not await self.authenticate(request): @@ -801,7 +958,11 @@ class WebdavApplication(aiohttp.web.Application): return aiohttp.web.Response(status=404, text="Resource not found") lock = await self.is_locked(abs_path, request) if lock: - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=423, text="Locked") body = await request.read() @@ -824,17 +985,25 @@ class WebdavApplication(aiohttp.web.Application): response = etree.SubElement(response_xml, "{DAV:}response") href = etree.SubElement(response, "{DAV:}href") relative_path = abs_path.relative_to(request["home"]).as_posix() - href_path = f"{self.relative_url}/{relative_path}".strip(".").replace("./", "/").replace("//", "/") + href_path = ( + f"{self.relative_url}/{relative_path}".strip(".") + .replace("./", "/") + .replace("//", "/") + ) href.text = href_path propstat = etree.SubElement(response, "{DAV:}propstat") prop = etree.SubElement(propstat, "{DAV:}prop") for child in prop_elem: ns = child.nsmap.get(None, "DAV:") prop_name = f"{{{ns}}}{child.tag.split('}')[-1]}" - elem = etree.SubElement(prop, prop_name) + etree.SubElement(prop, prop_name) etree.SubElement(propstat, "{DAV:}status").text = "HTTP/1.1 200 OK" - xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode() - return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml") + xml_output = etree.tostring( + response_xml, encoding="utf-8", xml_declaration=True + ).decode() + return aiohttp.web.Response( + status=207, text=xml_output, content_type="application/xml" + ) except Exception as e: logging.error(f"PROPSET error: {e}") return aiohttp.web.Response(status=400, text="Bad Request") @@ -850,7 +1019,11 @@ class WebdavApplication(aiohttp.web.Application): return aiohttp.web.Response(status=404, text="Resource not found") lock = await self.is_locked(abs_path, request) if lock: - submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "") + submitted = ( + request.headers.get("Lock-Token", "") + .strip(" <>") + .replace("opaquelocktoken:", "") + ) if submitted != lock["token"]: return aiohttp.web.Response(status=423, text="Locked") body = await request.read() @@ -877,17 +1050,25 @@ class WebdavApplication(aiohttp.web.Application): response = etree.SubElement(response_xml, "{DAV:}response") href = etree.SubElement(response, "{DAV:}href") relative_path = abs_path.relative_to(request["home"]).as_posix() - href_path = f"{self.relative_url}/{relative_path}".strip(".").replace("./", "/").replace("//", "/") + href_path = ( + f"{self.relative_url}/{relative_path}".strip(".") + .replace("./", "/") + .replace("//", "/") + ) href.text = href_path propstat = etree.SubElement(response, "{DAV:}propstat") prop = etree.SubElement(propstat, "{DAV:}prop") for child in prop_elem: ns = child.nsmap.get(None, "DAV:") prop_name = f"{{{ns}}}{child.tag.split('}')[-1]}" - elem = etree.SubElement(prop, prop_name) + etree.SubElement(prop, prop_name) etree.SubElement(propstat, "{DAV:}status").text = "HTTP/1.1 200 OK" - xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode() - return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml") + xml_output = etree.tostring( + response_xml, encoding="utf-8", xml_declaration=True + ).decode() + return aiohttp.web.Response( + status=207, text=xml_output, content_type="application/xml" + ) except Exception as e: logging.error(f"PROPDEL error: {e}") return aiohttp.web.Response(status=400, text="Bad Request") diff --git a/terminal/entry b/terminal/entry index da922e0..332a935 100755 --- a/terminal/entry +++ b/terminal/entry @@ -1,10 +1,8 @@ #!/usr/bin/env python3 import os +import pathlib import sys -import pathlib - -import os os.chdir("/root") @@ -16,7 +14,9 @@ if not pathlib.Path(".welcome.txt").exists(): os.system("python3 -m venv --prompt '' .venv") os.system("cp -r /opt/bootstrap/root/.* /root") os.system("cp /opt/bootstrap/.welcome.txt /root/.welcome.txt") - pathlib.Path(".bashrc").write_text(pathlib.Path(".bashrc").read_text() + "\n" + "source .venv/bin/activate") + pathlib.Path(".bashrc").write_text( + pathlib.Path(".bashrc").read_text() + "\n" + "source .venv/bin/activate" + ) os.environ["SNEK"] = "1" if pathlib.Path(".welcome.txt").exists():