This commit is contained in:
retoor 2025-08-10 00:59:11 +02:00
parent 30382b139d
commit a3ab7b0586
35 changed files with 1108 additions and 785 deletions

View File

@ -26,9 +26,9 @@ Description: Utility to load environment variables from a .env file.
"""
import os
from typing import Optional
def load_env(file_path: str = '.env') -> None:
def load_env(file_path: str = ".env") -> None:
"""
Loads environment variables from a specified file into the current process environment.
@ -43,26 +43,26 @@ def load_env(file_path: str = '.env') -> None:
IOError: If an I/O error occurs during file reading.
"""
try:
with open(file_path, 'r') as env_file:
with open(file_path) as env_file:
for line in env_file:
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith('#'):
if not line or line.startswith("#"):
continue
# Skip lines without '='
if '=' not in line:
if "=" not in line:
continue
# Split into key and value at the first '='
key, value = line.split('=', 1)
key, value = line.split("=", 1)
# Set environment variable
os.environ[key.strip()] = value.strip()
except FileNotFoundError:
raise FileNotFoundError(f"Environment file '{file_path}' not found.")
except IOError as e:
raise IOError(f"Error reading environment file '{file_path}': {e}")
except OSError as e:
raise OSError(f"Error reading environment file '{file_path}': {e}")
try:
load_env()
except Exception as e:
except Exception:
pass

View File

@ -1,20 +1,20 @@
import asyncio
import pathlib
import shutil
import sqlite3
import asyncio
import click
from aiohttp import web
from IPython import start_ipython
from snek.shell import Shell
from snek.app import Application
from snek.shell import Shell
@click.group()
def cli():
pass
@cli.command()
def export():
@ -30,6 +30,7 @@ def export():
user = await app.services.user.get(uid=message["user_uid"])
message["user"] = user and user["username"] or None
return (message["user"] or "") + ": " + (message["text"] or "")
async def run():
result = []
@ -49,6 +50,7 @@ def export():
with open("dump.txt", "w") as f:
f.write("\n\n".join(result))
print("Dump written to dump.json")
asyncio.run(run())
@ -57,16 +59,20 @@ def statistics():
async def run():
app = Application(db_path="sqlite:///snek.db")
app.services.statistics.database()
asyncio.run(run())
@cli.command()
def maintenance():
async def run():
app = Application(db_path="sqlite:///snek.db")
await app.services.container.maintenance()
await app.services.channel_message.maintenance()
asyncio.run(run())
@cli.command()
@click.option(
"--db_path", default="snek.db", help="Database to initialize if not exists."
@ -123,9 +129,11 @@ def serve(port, host, db_path):
def shell(db_path):
Shell(db_path).run()
def main():
try:
import sentry_sdk
sentry_sdk.init("https://ab6147c2f3354c819768c7e89455557b@gt.molodetz.nl/1")
except ImportError:
print("Could not import sentry_sdk")

View File

@ -3,13 +3,13 @@ import logging
import pathlib
import ssl
import uuid
import signal
from datetime import datetime
from contextlib import asynccontextmanager
from datetime import datetime
from snek import snode
from snek.view.threads import ThreadsView
from snek.system.ads import AsyncDataSet
from snek.view.threads import ThreadsView
logging.basicConfig(level=logging.DEBUG)
from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address
@ -25,16 +25,21 @@ from aiohttp_session.cookie_storage import EncryptedCookieStorage
from app.app import Application as BaseApplication
from jinja2 import FileSystemLoader
from snek.forum import setup_forum
from snek.mapper import get_mappers
from snek.service import get_services
from snek.sgit import GitApplication
from snek.sssh import start_ssh_server
from snek.system import http
from snek.system.cache import Cache
from snek.system.stats import middleware as stats_middleware, create_stats_structure, stats_handler
from snek.system.markdown import MarkdownExtension
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
from snek.system.profiler import profiler_handler
from snek.system.stats import (
create_stats_structure,
middleware as stats_middleware,
stats_handler,
)
from snek.system.template import (
EmojiExtension,
LinkifyExtension,
@ -43,10 +48,15 @@ from snek.system.template import (
)
from snek.view.about import AboutHTMLView, AboutMDView
from snek.view.avatar import AvatarView
from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView
from snek.view.channel import (
ChannelAttachmentUploadView,
ChannelAttachmentView,
ChannelDriveApiView,
ChannelView,
)
from snek.view.container import ContainerView
from snek.view.docs import DocsHTMLView, DocsMDView
from snek.view.drive import DriveApiView, DriveView
from snek.view.channel import ChannelDriveApiView
from snek.view.index import IndexView
from snek.view.login import LoginView
from snek.view.logout import LogoutView
@ -55,7 +65,6 @@ from snek.view.register import RegisterView
from snek.view.repository import RepositoryView
from snek.view.rpc import RPCView
from snek.view.search_user import SearchUserView
from snek.view.container import ContainerView
from snek.view.settings.containers import (
ContainersCreateView,
ContainersDeleteView,
@ -77,7 +86,6 @@ from snek.view.upload import UploadView
from snek.view.user import UserView
from snek.view.web import WebView
from snek.webdav import WebdavApplication
from snek.forum import setup_forum
SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34"
from snek.system.template import whitelist_attributes
@ -175,13 +183,12 @@ class Application(BaseApplication):
self.on_startup.append(self.prepare_asyncio)
self.on_startup.append(self.start_user_availability_service)
self.on_startup.append(self.start_ssh_server)
#self.on_startup.append(self.prepare_database)
# self.on_startup.append(self.prepare_database)
async def prepare_stats(self, app):
app['stats'] = create_stats_structure()
app["stats"] = create_stats_structure()
print("Stats prepared", flush=True)
@property
def uptime_seconds(self):
return (datetime.now() - self.time_start).total_seconds()
@ -225,10 +232,10 @@ class Application(BaseApplication):
# app.loop = asyncio.get_running_loop()
app.executor = ThreadPoolExecutor(max_workers=200)
app.loop.set_default_executor(self.executor)
#for sig in (signal.SIGINT, signal.SIGTERM):
#app.loop.add_signal_handler(
# sig, lambda: asyncio.create_task(self.services.container.shutdown())
#)
# for sig in (signal.SIGINT, signal.SIGTERM):
# app.loop.add_signal_handler(
# sig, lambda: asyncio.create_task(self.services.container.shutdown())
# )
async def create_task(self, task):
await self.tasks.put(task)
@ -244,7 +251,6 @@ class Application(BaseApplication):
print(ex)
self.db.commit()
async def prepare_database(self, app):
await self.db.query_raw("PRAGMA journal_mode=WAL")
await self.db.query_raw("PRAGMA syncnorm=off")
@ -291,9 +297,7 @@ class Application(BaseApplication):
self.router.add_view(
"/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
)
self.router.add_view(
"/channel/{channel_uid}/drive.json", ChannelDriveApiView
)
self.router.add_view("/channel/{channel_uid}/drive.json", ChannelDriveApiView)
self.router.add_view(
"/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView
)
@ -414,9 +418,9 @@ class Application(BaseApplication):
)
try:
context["nonce"] = request['csp_nonce']
context["nonce"] = request["csp_nonce"]
except:
context['nonce'] = '?'
context["nonce"] = "?"
rendered = await super().render_template(template, request, context)
@ -466,14 +470,14 @@ class Application(BaseApplication):
@asynccontextmanager
async def no_save(self):
stats = {
'count': 0
}
stats = {"count": 0}
async def patched_save(*args, **kwargs):
await self.cache.set(args[0]["uid"], args[0])
stats['count'] = stats['count'] + 1
stats["count"] = stats["count"] + 1
print(f"save is ignored {stats['count']} times")
return args[0]
save_original = self.services.channel_message.mapper.save
self.services.channel_message.mapper.save = patched_save
raised_exception = None
@ -486,6 +490,7 @@ class Application(BaseApplication):
if raised_exception:
raise raised_exception
app = Application(db_path="sqlite:///snek.db")

View File

@ -1,5 +1,6 @@
# forum_app.py
import aiohttp.web
from snek.view.forum import ForumIndexView, ForumView, ForumWebSocketView
@ -7,13 +8,13 @@ class ForumApplication(aiohttp.web.Application):
def __init__(self, parent, *args, **kwargs):
super().__init__(*args, **kwargs)
self.parent = parent
self.render_template = self.parent.render_template
self.render_template = self.parent.render_template
# Set up routes
self.setup_routes()
# Set up notification listeners
self.setup_notifications()
@property
def db(self):
return self.parent.db
@ -21,86 +22,84 @@ class ForumApplication(aiohttp.web.Application):
@property
def services(self):
return self.parent.services
def setup_routes(self):
"""Set up all forum routes"""
# API routes
self.router.add_view("/index.html", ForumIndexView)
self.router.add_route("GET", "/api/forums", ForumView.get_forums)
self.router.add_route("GET", "/api/forums/{slug}", ForumView.get_forum)
self.router.add_route("POST", "/api/forums/{slug}/threads", ForumView.create_thread)
self.router.add_route(
"POST", "/api/forums/{slug}/threads", ForumView.create_thread
)
self.router.add_route("GET", "/api/threads/{thread_slug}", ForumView.get_thread)
self.router.add_route("POST", "/api/threads/{thread_uid}/posts", ForumView.create_post)
self.router.add_route(
"POST", "/api/threads/{thread_uid}/posts", ForumView.create_post
)
self.router.add_route("PUT", "/api/posts/{post_uid}", ForumView.edit_post)
self.router.add_route("DELETE", "/api/posts/{post_uid}", ForumView.delete_post)
self.router.add_route("POST", "/api/posts/{post_uid}/like", ForumView.toggle_like)
self.router.add_route("POST", "/api/threads/{thread_uid}/pin", ForumView.toggle_pin)
self.router.add_route("POST", "/api/threads/{thread_uid}/lock", ForumView.toggle_lock)
self.router.add_route(
"POST", "/api/posts/{post_uid}/like", ForumView.toggle_like
)
self.router.add_route(
"POST", "/api/threads/{thread_uid}/pin", ForumView.toggle_pin
)
self.router.add_route(
"POST", "/api/threads/{thread_uid}/lock", ForumView.toggle_lock
)
# WebSocket route
self.router.add_view("/ws", ForumWebSocketView)
# Static HTML route
self.router.add_route("GET", "/{path:.*}", self.serve_forum_html)
def setup_notifications(self):
"""Set up notification listeners for WebSocket broadcasting"""
# Forum notifications
self.services.forum.add_notification_listener("forum_created", self.on_forum_event)
self.services.forum.add_notification_listener(
"forum_created", self.on_forum_event
)
# Thread notifications
self.services.thread.add_notification_listener("thread_created", self.on_thread_event)
self.services.thread.add_notification_listener(
"thread_created", self.on_thread_event
)
# Post notifications
self.services.post.add_notification_listener("post_created", self.on_post_event)
self.services.post.add_notification_listener("post_edited", self.on_post_event)
self.services.post.add_notification_listener("post_deleted", self.on_post_event)
# Like notifications
self.services.post_like.add_notification_listener("post_liked", self.on_like_event)
self.services.post_like.add_notification_listener("post_unliked", self.on_like_event)
self.services.post_like.add_notification_listener(
"post_liked", self.on_like_event
)
self.services.post_like.add_notification_listener(
"post_unliked", self.on_like_event
)
async def on_forum_event(self, event_type, data):
"""Handle forum events"""
await ForumWebSocketView.broadcast_update(self, event_type, data)
async def on_thread_event(self, event_type, data):
"""Handle thread events"""
await ForumWebSocketView.broadcast_update(self, event_type, data)
async def on_post_event(self, event_type, data):
"""Handle post events"""
await ForumWebSocketView.broadcast_update(self, event_type, data)
async def on_like_event(self, event_type, data):
"""Handle like events"""
await ForumWebSocketView.broadcast_update(self, event_type, data)
async def serve_forum_html(self, request):
"""Serve the forum HTML with the web component"""
html = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Forum</title>
<script type="module" src="/forum/static/snek-forum.js"></script>
<style>
body {
margin: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background-color: #f5f5f5;
}
</style>
</head>
<body>
<snek-forum></snek-forum>
</body>
</html>"""
return await self.parent.render_template("forum.html", request)
#return aiohttp.web.Response(text=html, content_type="text/html")
# return aiohttp.web.Response(text=html, content_type="text/html")
# Integration with main app

View File

@ -7,12 +7,12 @@ from snek.mapper.channel_message import ChannelMessageMapper
from snek.mapper.container import ContainerMapper
from snek.mapper.drive import DriveMapper
from snek.mapper.drive_item import DriveItemMapper
from snek.mapper.forum import ForumMapper, PostLikeMapper, PostMapper, ThreadMapper
from snek.mapper.notification import NotificationMapper
from snek.mapper.push import PushMapper
from snek.mapper.repository import RepositoryMapper
from snek.mapper.user import UserMapper
from snek.mapper.user_property import UserPropertyMapper
from snek.mapper.forum import ForumMapper, ThreadMapper, PostMapper, PostLikeMapper
from snek.system.object import Object
@ -42,4 +42,3 @@ def get_mappers(app=None):
def get_mapper(name, app=None):
return get_mappers(app=app)[name]

View File

@ -1,5 +1,5 @@
# mapper/forum.py
from snek.model.forum import ForumModel, ThreadModel, PostModel, PostLikeModel
from snek.model.forum import ForumModel, PostLikeModel, PostModel, ThreadModel
from snek.system.mapper import BaseMapper

View File

@ -9,12 +9,12 @@ from snek.model.channel_message import ChannelMessageModel
from snek.model.container import Container
from snek.model.drive import DriveModel
from snek.model.drive_item import DriveItemModel
from snek.model.forum import ForumModel, PostLikeModel, PostModel, ThreadModel
from snek.model.notification import NotificationModel
from snek.model.push_registration import PushRegistrationModel
from snek.model.repository import RepositoryModel
from snek.model.user import UserModel
from snek.model.user_property import UserPropertyModel
from snek.model.forum import ForumModel, ThreadModel, PostModel, PostLikeModel
from snek.system.object import Object
@ -44,5 +44,3 @@ def get_models():
def get_model(name):
return get_models()[name]

View File

@ -12,10 +12,10 @@ class ChannelModel(BaseModel):
index = ModelField(name="index", required=True, kind=int, value=1000)
last_message_on = ModelField(name="last_message_on", required=False, kind=str)
history_start = ModelField(name="history_start", required=False, kind=str)
@property
def is_dm(self):
return 'dm' in self['tag'].lower()
return "dm" in self["tag"].lower()
async def get_last_message(self) -> ChannelMessageModel:
history_start_filter = ""
@ -23,7 +23,9 @@ class ChannelModel(BaseModel):
history_start_filter = f" AND created_at > '{self['history_start']}' "
try:
async for model in self.app.services.channel_message.query(
"SELECT uid FROM channel_message WHERE channel_uid=:channel_uid" + history_start_filter + " ORDER BY created_at DESC LIMIT 1",
"SELECT uid FROM channel_message WHERE channel_uid=:channel_uid"
+ history_start_filter
+ " ORDER BY created_at DESC LIMIT 1",
{"channel_uid": self["uid"]},
):

View File

@ -4,9 +4,16 @@ from snek.system.model import BaseModel, ModelField
class ForumModel(BaseModel):
"""Forum categories"""
name = ModelField(name="name", required=True, kind=str, min_length=3, max_length=100)
description = ModelField(name="description", required=False, kind=str, max_length=500)
slug = ModelField(name="slug", required=True, kind=str, regex=r"^[a-z0-9-]+$", unique=True)
name = ModelField(
name="name", required=True, kind=str, min_length=3, max_length=100
)
description = ModelField(
name="description", required=False, kind=str, max_length=500
)
slug = ModelField(
name="slug", required=True, kind=str, regex=r"^[a-z0-9-]+$", unique=True
)
icon = ModelField(name="icon", required=False, kind=str)
position = ModelField(name="position", required=True, kind=int, value=0)
is_active = ModelField(name="is_active", required=True, kind=bool, value=True)
@ -18,12 +25,12 @@ class ForumModel(BaseModel):
async def get_threads(self, limit=50, offset=0):
async for thread in self.app.services.thread.find(
forum_uid=self["uid"],
forum_uid=self["uid"],
deleted_at=None,
_limit=limit,
_offset=offset,
order_by="-last_post_at"
#order_by="is_pinned DESC, last_post_at DESC"
order_by="-last_post_at",
# order_by="is_pinned DESC, last_post_at DESC"
):
yield thread
@ -39,8 +46,11 @@ class ForumModel(BaseModel):
# models/thread.py
class ThreadModel(BaseModel):
"""Forum threads"""
forum_uid = ModelField(name="forum_uid", required=True, kind=str)
title = ModelField(name="title", required=True, kind=str, min_length=5, max_length=200)
title = ModelField(
name="title", required=True, kind=str, min_length=5, max_length=200
)
slug = ModelField(name="slug", required=True, kind=str, regex=r"^[a-z0-9-]+$")
created_by_uid = ModelField(name="created_by_uid", required=True, kind=str)
is_pinned = ModelField(name="is_pinned", required=True, kind=bool, value=False)
@ -56,7 +66,7 @@ class ThreadModel(BaseModel):
deleted_at=None,
_limit=limit,
_offset=offset,
order_by="created_at"
order_by="created_at",
):
yield post
@ -70,30 +80,35 @@ class ThreadModel(BaseModel):
await self.save()
# models/post.py
# models/post.py
class PostModel(BaseModel):
"""Forum posts"""
thread_uid = ModelField(name="thread_uid", required=True, kind=str)
forum_uid = ModelField(name="forum_uid", required=True, kind=str)
content = ModelField(name="content", required=True, kind=str, min_length=1, max_length=10000)
content = ModelField(
name="content", required=True, kind=str, min_length=1, max_length=10000
)
created_by_uid = ModelField(name="created_by_uid", required=True, kind=str)
edited_at = ModelField(name="edited_at", required=False, kind=str)
edited_by_uid = ModelField(name="edited_by_uid", required=False, kind=str)
is_first_post = ModelField(name="is_first_post", required=True, kind=bool, value=False)
is_first_post = ModelField(
name="is_first_post", required=True, kind=bool, value=False
)
like_count = ModelField(name="like_count", required=True, kind=int, value=0)
async def get_author(self):
return await self.app.services.user.get(uid=self["created_by_uid"])
async def is_liked_by(self, user_uid):
return await self.app.services.post_like.exists(
post_uid=self["uid"],
user_uid=user_uid
post_uid=self["uid"], user_uid=user_uid
)
# models/post_like.py
class PostLikeModel(BaseModel):
"""Post likes"""
post_uid = ModelField(name="post_uid", required=True, kind=str)
user_uid = ModelField(name="user_uid", required=True, kind=str)

View File

@ -9,37 +9,43 @@ from snek.service.container import ContainerService
from snek.service.db import DBService
from snek.service.drive import DriveService
from snek.service.drive_item import DriveItemService
from snek.service.forum import ForumService, PostLikeService, PostService, ThreadService
from snek.service.notification import NotificationService
from snek.service.push import PushService
from snek.service.repository import RepositoryService
from snek.service.socket import SocketService
from snek.service.statistics import StatisticsService
from snek.service.user import UserService
from snek.service.user_property import UserPropertyService
from snek.service.util import UtilService
from snek.system.object import Object
from snek.service.statistics import StatisticsService
from snek.service.forum import ForumService, ThreadService, PostService, PostLikeService
_service_registry = {}
def register_service(name, service_cls):
_service_registry[name] = service_cls
register = register_service
@functools.cache
def get_services(app):
result = Object(
result = Object(
**{
name: service_cls(app=app)
for name, service_cls in _service_registry.items()
}
)
result.register = register_service
return result
return result
def get_service(name, app=None):
return get_services(app=app)[name]
# Registering all services
register_service("user", UserService)
register_service("channel_member", ChannelMemberService)
@ -62,4 +68,3 @@ register_service("forum", ForumService)
register_service("thread", ThreadService)
register_service("post", PostService)
register_service("post_like", PostLikeService)

View File

@ -115,10 +115,9 @@ class ChannelService(BaseService):
async def clear(self, channel_uid):
model = await self.get(uid=channel_uid)
model['history_from'] = datetime.now()
model["history_from"] = datetime.now()
await self.save(model)
async def ensure_public_channel(self, created_by_uid):
model = await self.get(is_listed=True, tag="public")
is_moderator = False

View File

@ -1,6 +1,8 @@
import time
from snek.system.service import BaseService
from snek.system.template import sanitize_html
import time
class ChannelMessageService(BaseService):
mapper_name = "channel_message"
@ -10,7 +12,6 @@ class ChannelMessageService(BaseService):
self._configured_indexes = False
async def maintenance(self):
args = {}
for message in self.mapper.db["channel_message"].find():
print(message)
try:
@ -34,7 +35,6 @@ class ChannelMessageService(BaseService):
time.sleep(0.1)
print(ex, flush=True)
while True:
changed = 0
async for message in self.find(is_final=False):
@ -72,7 +72,7 @@ class ChannelMessageService(BaseService):
try:
template = self.app.jinja2_env.get_template("message.html")
model["html"] = template.render(**context)
model['html'] = sanitize_html(model['html'])
model["html"] = sanitize_html(model["html"])
except Exception as ex:
print(ex, flush=True)
@ -99,9 +99,9 @@ class ChannelMessageService(BaseService):
if not user:
return {}
#if not message["html"].startswith("<chat-message"):
#message = await self.get(uid=message["uid"])
#await self.save(message)
# if not message["html"].startswith("<chat-message"):
# message = await self.get(uid=message["uid"])
# await self.save(message)
return {
"uid": message["uid"],
@ -129,7 +129,7 @@ class ChannelMessageService(BaseService):
)
template = self.app.jinja2_env.get_template("message.html")
model["html"] = template.render(**context)
model['html'] = sanitize_html(model['html'])
model["html"] = sanitize_html(model["html"])
return await super().save(model)
async def offset(self, channel_uid, page=0, timestamp=None, page_size=30):

View File

@ -37,7 +37,7 @@ class ChatService(BaseService):
if not channel:
raise Exception("Channel not found.")
channel_message = await self.services.channel_message.get(
channel_uid=channel_uid,user_uid=user_uid, is_final=False
channel_uid=channel_uid, user_uid=user_uid, is_final=False
)
if channel_message:
channel_message["message"] = message

View File

@ -9,16 +9,18 @@ class ContainerService(BaseService):
super().__init__(*args, **kwargs)
self.compose_path = "snek-container-compose.yml"
self.compose = ComposeFileManager(self.compose_path,self.container_event_handler)
self.compose = ComposeFileManager(
self.compose_path, self.container_event_handler
)
self.event_listeners = {}
async def shutdown(self):
return await self.compose.shutdown()
async def add_event_listener(self, name, event,event_handler):
if not name in self.event_listeners:
async def add_event_listener(self, name, event, event_handler):
if name not in self.event_listeners:
self.event_listeners[name] = {}
if not event in self.event_listeners[name]:
if event not in self.event_listeners[name]:
self.event_listeners[name][event] = []
self.event_listeners[name][event].append(event_handler)
@ -27,7 +29,7 @@ class ContainerService(BaseService):
try:
self.event_listeners[name][event].remove(event_handler)
except ValueError:
pass
pass
async def container_event_handler(self, name, event, data):
event_listeners = self.event_listeners.get(name, {})
@ -44,8 +46,10 @@ class ContainerService(BaseService):
return channel_uid
return f"channel-{channel_uid}"
async def get(self,channel_uid):
return await self.compose.get_instance(await self.get_container_name(channel_uid))
async def get(self, channel_uid):
return await self.compose.get_instance(
await self.get_container_name(channel_uid)
)
async def stop(self, channel_uid):
return await self.compose.stop(await self.get_container_name(channel_uid))
@ -60,12 +64,15 @@ class ContainerService(BaseService):
result = await self.create(channel_uid=channel["uid"])
print(result)
async def get_status(self, channel_uid):
return await self.compose.get_instance_status(await self.get_container_name(channel_uid))
return await self.compose.get_instance_status(
await self.get_container_name(channel_uid)
)
async def write_stdin(self, channel_uid, data):
return await self.compose.write_stdin(await self.get_container_name(channel_uid), data)
return await self.compose.write_stdin(
await self.get_container_name(channel_uid), data
)
async def create(
self,
@ -78,7 +85,7 @@ class ContainerService(BaseService):
volumes=None,
):
name = await self.get_container_name(channel_uid)
test = await self.compose.get_instance(name)
if test:
return test
@ -116,4 +123,3 @@ class ContainerService(BaseService):
if await super().save(model):
return model
raise Exception(f"Failed to create container: {model.errors}")

View File

@ -1,18 +1,25 @@
# services/forum.py
from snek.system.service import BaseService
import re
import uuid
from collections import defaultdict
from typing import Any, Awaitable, Callable, Dict, List
from collections.abc import Awaitable
from typing import Any, Callable, Dict, List
from snek.system.model import now
from snek.system.service import BaseService
EventListener = Callable[[str, Any], Awaitable[None]] | Callable[[str, Any], None]
class BaseForumService(BaseService):
"""
Base mix-in that gives a service `add_notification_listener`,
an internal `_dispatch_event` helper, and a public `notify` method.
"""
def get_timestamp(self):
return now()
def generate_uid(self):
return str(uuid.uuid4())
@ -64,15 +71,17 @@ class BaseForumService(BaseService):
class ForumService(BaseForumService):
mapper_name = "forum"
async def create_forum(self, name, description, created_by_uid, slug=None, icon=None):
async def create_forum(
self, name, description, created_by_uid, slug=None, icon=None
):
if not slug:
slug = self.generate_slug(name)
# Check if slug exists
existing = await self.get(slug=slug)
if existing:
slug = f"{slug}-{self.generate_uid()[:8]}"
model = await self.new()
model["name"] = name
model["description"] = description
@ -80,7 +89,7 @@ class ForumService(BaseForumService):
model["created_by_uid"] = created_by_uid
if icon:
model["icon"] = icon
if await self.save(model):
await self.notify("forum_created", model)
return model
@ -89,8 +98,8 @@ class ForumService(BaseForumService):
def generate_slug(self, text):
# Convert to lowercase and replace spaces with hyphens
slug = text.lower().strip()
slug = re.sub(r'[^\w\s-]', '', slug)
slug = re.sub(r'[-\s]+', '-', slug)
slug = re.sub(r"[^\w\s-]", "", slug)
slug = re.sub(r"[-\s]+", "-", slug)
return slug
async def get_active_forums(self):
@ -105,21 +114,19 @@ class ForumService(BaseForumService):
await self.save(forum)
# services/thread.py
# services/thread.py
class ThreadService(BaseForumService):
mapper_name = "thread"
async def create_thread(self, forum_uid, title, content, created_by_uid):
# Generate slug
slug = self.services.forum.generate_slug(title)
# Check if slug exists in this forum
existing = await self.get(forum_uid=forum_uid, slug=slug)
if existing:
slug = f"{slug}-{self.generate_uid()[:8]}"
# Create thread
thread = await self.new()
thread["forum_uid"] = forum_uid
@ -128,7 +135,7 @@ class ThreadService(BaseForumService):
thread["created_by_uid"] = created_by_uid
thread["last_post_at"] = self.get_timestamp()
thread["last_post_by_uid"] = created_by_uid
if await self.save(thread):
# Create first post
post = await self.services.post.create_post(
@ -136,19 +143,18 @@ class ThreadService(BaseForumService):
forum_uid=forum_uid,
content=content,
created_by_uid=created_by_uid,
is_first_post=True
is_first_post=True,
)
# Update forum counters
forum = await self.services.forum.get(uid=forum_uid)
await forum.increment_thread_count()
await self.services.forum.update_last_post(forum_uid, thread["uid"])
await self.notify("thread_created", {
"thread": thread,
"forum_uid": forum_uid
})
await self.notify(
"thread_created", {"thread": thread, "forum_uid": forum_uid}
)
return thread, post
raise Exception(f"Failed to create thread: {thread.errors}")
@ -156,12 +162,12 @@ class ThreadService(BaseForumService):
thread = await self.get(uid=thread_uid)
if not thread:
return None
# Check if user is admin
user = await self.services.user.get(uid=user_uid)
if not user.get("is_admin"):
return None
thread["is_pinned"] = not thread["is_pinned"]
await self.save(thread)
return thread
@ -170,12 +176,12 @@ class ThreadService(BaseForumService):
thread = await self.get(uid=thread_uid)
if not thread:
return None
# Check if user is admin or thread creator
user = await self.services.user.get(uid=user_uid)
if not user.get("is_admin") and thread["created_by_uid"] != user_uid:
return None
thread["is_locked"] = not thread["is_locked"]
await self.save(thread)
return thread
@ -185,19 +191,21 @@ class ThreadService(BaseForumService):
class PostService(BaseForumService):
mapper_name = "post"
async def create_post(self, thread_uid, forum_uid, content, created_by_uid, is_first_post=False):
async def create_post(
self, thread_uid, forum_uid, content, created_by_uid, is_first_post=False
):
# Check if thread is locked
thread = await self.services.thread.get(uid=thread_uid)
if thread["is_locked"] and not is_first_post:
raise Exception("Thread is locked")
post = await self.new()
post["thread_uid"] = thread_uid
post["forum_uid"] = forum_uid
post["content"] = content
post["created_by_uid"] = created_by_uid
post["is_first_post"] = is_first_post
if await self.save(post):
# Update thread counters
if not is_first_post:
@ -205,18 +213,17 @@ class PostService(BaseForumService):
thread["last_post_at"] = self.get_timestamp()
thread["last_post_by_uid"] = created_by_uid
await self.services.thread.save(thread)
# Update forum counters
forum = await self.services.forum.get(uid=forum_uid)
await forum.increment_post_count()
await self.services.forum.update_last_post(forum_uid, thread_uid)
await self.notify("post_created", {
"post": post,
"thread_uid": thread_uid,
"forum_uid": forum_uid
})
await self.notify(
"post_created",
{"post": post, "thread_uid": thread_uid, "forum_uid": forum_uid},
)
return post
raise Exception(f"Failed to create post: {post.errors}")
@ -224,16 +231,16 @@ class PostService(BaseForumService):
post = await self.get(uid=post_uid)
if not post:
return None
# Check permissions
user = await self.services.user.get(uid=user_uid)
if post["created_by_uid"] != user_uid and not user.get("is_admin"):
return None
post["content"] = content
post["edited_at"] = self.get_timestamp()
post["edited_by_uid"] = user_uid
if await self.save(post):
await self.notify("post_edited", post)
return post
@ -243,16 +250,16 @@ class PostService(BaseForumService):
post = await self.get(uid=post_uid)
if not post:
return False
# Check permissions
user = await self.services.user.get(uid=user_uid)
if post["created_by_uid"] != user_uid and not user.get("is_admin"):
return False
# Don't allow deleting first post
if post["is_first_post"]:
return False
post["deleted_at"] = self.get_timestamp()
if await self.save(post):
await self.notify("post_deleted", post)
@ -267,38 +274,36 @@ class PostLikeService(BaseForumService):
async def toggle_like(self, post_uid, user_uid):
# Check if already liked
existing = await self.get(post_uid=post_uid, user_uid=user_uid)
if existing:
# Unlike
await self.delete(uid=existing["uid"])
# Update post like count
post = await self.services.post.get(uid=post_uid)
post["like_count"] = max(0, post["like_count"] - 1)
await self.services.post.save(post)
await self.notify("post_unliked", {
"post_uid": post_uid,
"user_uid": user_uid
})
await self.notify(
"post_unliked", {"post_uid": post_uid, "user_uid": user_uid}
)
return False
else:
# Like
like = await self.new()
like["post_uid"] = post_uid
like["user_uid"] = user_uid
if await self.save(like):
# Update post like count
post = await self.services.post.get(uid=post_uid)
post["like_count"] += 1
await self.services.post.save(post)
await self.notify("post_liked", {
"post_uid": post_uid,
"user_uid": user_uid
})
await self.notify(
"post_liked", {"post_uid": post_uid, "user_uid": user_uid}
)
return True
return None

View File

@ -1,8 +1,10 @@
from snek.system.service import BaseService
import sqlite3
import sqlite3
from snek.system.service import BaseService
class StatisticsService(BaseService):
def database(self):
db_path = self.app.db_path.split("///")[-1]
print(db_path)
@ -15,9 +17,20 @@ class StatisticsService(BaseService):
return cursor.fetchall()
tables = [
'http_access', 'user', 'channel', 'channel_member', 'broadcast',
'channel_message', 'notification', 'repository', 'test', 'drive',
'user_property', 'a', 'channel_attachment', 'push_registration'
"http_access",
"user",
"channel",
"channel_member",
"broadcast",
"channel_message",
"notification",
"repository",
"test",
"drive",
"user_property",
"a",
"channel_attachment",
"push_registration",
]
for table in tables:
@ -36,30 +49,50 @@ class StatisticsService(BaseService):
print(f"\nColumn: {name} ({col_type})")
print(f"Distinct values: {distinct_count}")
if 'INT' in col_type_upper or 'BIGINT' in col_type_upper or 'FLOAT' in col_type_upper:
cursor.execute(f"SELECT MIN('{name}'), MAX('{name}'), AVG('{name}') FROM {table} WHERE '{name}' IS NOT NULL")
if (
"INT" in col_type_upper
or "BIGINT" in col_type_upper
or "FLOAT" in col_type_upper
):
cursor.execute(
f"SELECT MIN('{name}'), MAX('{name}'), AVG('{name}') FROM {table} WHERE '{name}' IS NOT NULL"
)
min_val, max_val, avg_val = cursor.fetchone()
print(f"Min: {min_val}, Max: {max_val}, Avg: {avg_val}")
elif 'TEXT' in col_type_upper and ('date' in name.lower() or 'time' in name.lower() or 'created' in name.lower() or 'updated' in name.lower() or 'on' in name.lower()):
cursor.execute(f"SELECT MIN({name}), MAX({name}) FROM {table} WHERE {name} IS NOT NULL")
elif "TEXT" in col_type_upper and (
"date" in name.lower()
or "time" in name.lower()
or "created" in name.lower()
or "updated" in name.lower()
or "on" in name.lower()
):
cursor.execute(
f"SELECT MIN({name}), MAX({name}) FROM {table} WHERE {name} IS NOT NULL"
)
min_date, max_date = cursor.fetchone()
print(f"Earliest: {min_date}, Latest: {max_date}")
elif 'TEXT' in col_type_upper:
cursor.execute(f"SELECT LENGTH({name}) FROM {table} WHERE {name} IS NOT NULL")
elif "TEXT" in col_type_upper:
cursor.execute(
f"SELECT LENGTH({name}) FROM {table} WHERE {name} IS NOT NULL"
)
lengths = [len_row[0] for len_row in cursor.fetchall()]
if lengths:
avg_length = sum(lengths) / len(lengths)
max_length = max(lengths)
min_length = min(lengths)
print(f"Avg length: {avg_length:.2f}, Max length: {max_length}, Min length: {min_length}")
print(
f"Avg length: {avg_length:.2f}, Max length: {max_length}, Min length: {min_length}"
)
else:
print("No data to compute length statistics.")
# New statistics functions
def get_time_series_stats(table_name, date_column):
cursor.execute(f"SELECT strftime('%Y-%m-%d', {date_column}) AS day, COUNT(*) FROM {table_name} GROUP BY day")
cursor.execute(
f"SELECT strftime('%Y-%m-%d', {date_column}) AS day, COUNT(*) FROM {table_name} GROUP BY day"
)
return cursor.fetchall()
def get_count_created(table_name, date_column):
@ -67,17 +100,21 @@ class StatisticsService(BaseService):
return cursor.fetchone()[0]
def get_channels_per_user():
cursor.execute("SELECT user_uid, COUNT(*) AS channel_count FROM channel_member GROUP BY user_uid ORDER BY channel_count DESC")
cursor.execute(
"SELECT user_uid, COUNT(*) AS channel_count FROM channel_member GROUP BY user_uid ORDER BY channel_count DESC"
)
return cursor.fetchall()
def get_online_users():
cursor.execute("SELECT COUNT(*) FROM user WHERE last_ping >= datetime('now', '-5 minutes')")
cursor.execute(
"SELECT COUNT(*) FROM user WHERE last_ping >= datetime('now', '-5 minutes')"
)
return cursor.fetchone()[0]
# Example usage of new functions
messages_per_day = get_time_series_stats('channel_message', 'created_at')
users_created = get_count_created('user', 'created_at')
channels_created = get_count_created('channel', 'created_at')
messages_per_day = get_time_series_stats("channel_message", "created_at")
users_created = get_count_created("user", "created_at")
channels_created = get_count_created("channel", "created_at")
channels_per_user = get_channels_per_user()
online_users = get_online_users()

View File

@ -8,7 +8,7 @@ class UserPropertyService(BaseService):
async def set(self, user_uid, name, value):
self.mapper.db.upsert(
"user_property",
"user_property",
{
"user_uid": user_uid,
"name": name,

View File

@ -1,19 +1,16 @@
from snek.app import Application
from IPython import start_ipython
from snek.app import Application
class Shell:
def __init__(self,db_path):
def __init__(self, db_path):
self.app = Application(db_path=f"sqlite:///{db_path}")
async def maintenance(self):
await self.app.services.container.maintenance()
await self.app.services.channel_message.maintenance()
def run(self):
ns = {
"app": self.app,
"maintenance": self.maintenance
}
ns = {"app": self.app, "maintenance": self.maintenance}
start_ipython(argv=[], user_ns=ns)

View File

@ -1,28 +1,27 @@
import re
import asyncio
import atexit
import json
from uuid import uuid4
import os
import re
import socket
import unittest
from collections.abc import AsyncGenerator, Iterable
from datetime import datetime, timezone
from pathlib import Path
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
AsyncGenerator,
Union,
Tuple,
Set,
Tuple,
Union,
)
from pathlib import Path
import aiosqlite
import unittest
from types import SimpleNamespace
import asyncio
from uuid import uuid4
import aiohttp
import aiosqlite
from aiohttp import web
import socket
import os
import atexit
class AsyncDataSet:
@ -103,7 +102,7 @@ class AsyncDataSet:
# No socket file, become server
self._is_server = True
await self._setup_server()
except Exception as e:
except Exception:
# Fallback to client mode
self._is_server = False
await self._setup_client()
@ -213,7 +212,7 @@ class AsyncDataSet:
data = await request.json()
try:
try:
_limit = data.pop("_limit",60)
_limit = data.pop("_limit", 60)
except:
pass
@ -751,7 +750,7 @@ class AsyncDataSet:
for k, v in where.items()
]
)
return " WHERE " + " AND ".join(clauses), list(v for v in vals if v is not None)
return " WHERE " + " AND ".join(clauses), [v for v in vals if v is not None]
async def _server_insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
@ -887,7 +886,6 @@ class AsyncDataSet:
except:
pass
where_clause, where_params = self._build_where(where)
order_clause = f" ORDER BY {order_by}" if order_by else ""
extra = (f" LIMIT {_limit}" if _limit else "") + (
@ -1209,4 +1207,3 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
if __name__ == "__main__":
unittest.main()

View File

@ -1,8 +1,9 @@
import copy
import json
import yaml
import asyncio
import subprocess
import copy
import json
import yaml
try:
import pty
except Exception as ex:
@ -10,26 +11,28 @@ except Exception as ex:
print(ex)
import os
class ComposeFileManager:
def __init__(self, compose_path="docker-compose.yml",event_handler=None):
def __init__(self, compose_path="docker-compose.yml", event_handler=None):
self.compose_path = compose_path
self._load()
self.running_instances = {}
self.event_handler = event_handler
self.event_handler = event_handler
async def shutdown(self):
print("Stopping all sessions")
tasks = []
tasks = []
for name in self.list_instances():
proc = self.running_instances.get(name)
if not proc:
continue
if proc['proc'].returncode == None:
print("Stopping",name)
tasks.append(asyncio.create_task(proc['proc'].stop()))
print("Stopped",name,"gracefully")
return tasks
if proc["proc"].returncode is None:
print("Stopping", name)
tasks.append(asyncio.create_task(proc["proc"].stop()))
print("Stopped", name, "gracefully")
return tasks
def _load(self):
try:
with open(self.compose_path) as f:
@ -47,22 +50,21 @@ class ComposeFileManager:
async def _create_readers(self, container_name):
instance = await self.get_instance(container_name)
if not instance:
return False
return False
proc = self.running_instances.get(container_name)
if not proc:
return False
async def reader(event_handler,stream):
async def reader(event_handler, stream):
loop = asyncio.get_event_loop()
while True:
line = await loop.run_in_executor(None,os.read,stream,1024)
line = await loop.run_in_executor(None, os.read, stream, 1024)
if not line:
break
await event_handler(container_name,"stdout",line)
await event_handler(container_name, "stdout", line)
await self.stop(container_name)
asyncio.create_task(reader(self.event_handler,proc['master']))
asyncio.create_task(reader(self.event_handler, proc["master"]))
def create_instance(
self,
@ -79,7 +81,7 @@ class ComposeFileManager:
"context": ".",
"dockerfile": "DockerfileUbuntu",
},
"user":"root",
"user": "root",
"working_dir": "/home/retoor/projects/snek",
"environment": [f"SNEK_UID={name}"],
}
@ -109,9 +111,9 @@ class ComposeFileManager:
instance = self.compose.get("services", {}).get(name)
if not instance:
return None
instance = json.loads(json.dumps(instance,default=str))
instance['status'] = await self.get_instance_status(name)
return instance
instance = json.loads(json.dumps(instance, default=str))
instance["status"] = await self.get_instance_status(name)
return instance
def duplicate_instance(self, name, new_name):
orig = self.get_instance(name)
@ -135,7 +137,14 @@ class ComposeFileManager:
if name not in self.list_instances():
return "error"
proc = await asyncio.create_subprocess_exec(
"docker", "compose", "-f", self.compose_path, "ps", "--services", "--filter", f"status=running",
"docker",
"compose",
"-f",
self.compose_path,
"ps",
"--services",
"--filter",
f"status=running",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
@ -148,25 +157,30 @@ class ComposeFileManager:
await self.event_handler(name, "stdin", data)
proc = self.running_instances.get(name)
if not proc:
return False
return False
try:
os.write(proc['master'], data.encode())
os.write(proc["master"], data.encode())
return True
except Exception as ex:
print(ex)
await self.stop(name)
return False
return False
async def stop(self, name):
"""Asynchronously stop a container by doing 'docker compose stop [name]'."""
if name not in self.list_instances():
return False
return False
status = await self.get_instance_status(name)
if status != "running":
return True
return True
proc = await asyncio.create_subprocess_exec(
"docker", "compose", "-f", self.compose_path, "stop", name,
"docker",
"compose",
"-f",
self.compose_path,
"stop",
name,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
@ -177,30 +191,41 @@ class ComposeFileManager:
try:
stdout, stderr = await proc.communicate()
except Exception as ex:
stdout = b''
stdout = b""
stderr = str(ex).encode()
await self.event_handler(name,"stdout",stdout or stderr)
print("Return code", proc.returncode)
await self.event_handler(name, "stdout", stdout or stderr)
print("Return code", proc.returncode)
if stdout:
await self.event_handler(name,"stdout",stdout or b'')
await self.event_handler(name, "stdout", stdout or b"")
return stdout and stdout.decode(errors="ignore") or ""
await self.event_handler(name,"stdout",stderr or b'')
await self.event_handler(name, "stdout", stderr or b"")
return stderr and stderr.decode(errors="ignore") or ""
async def start(self, name):
"""Asynchronously start a container by doing 'docker compose up -d [name]'."""
if name not in self.list_instances():
return False
return False
status = await self.get_instance_status(name)
if name in self.running_instances and status == "running" and self.running_instances.get(name) and self.running_instances.get(name).get('proc').returncode == None:
return True
if (
name in self.running_instances
and status == "running"
and self.running_instances.get(name)
and self.running_instances.get(name).get("proc").returncode is None
):
return True
elif name in self.running_instances:
del self.running_instances[name]
proc = await asyncio.create_subprocess_exec(
"docker", "compose", "-f", self.compose_path, "up", name, "-d",
"docker",
"compose",
"-f",
self.compose_path,
"up",
name,
"-d",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
@ -210,24 +235,29 @@ class ComposeFileManager:
try:
stdout, stderr = await proc.communicate()
except Exception as ex:
stdout = b''
stdout = b""
stderr = str(ex).encode()
if stdout:
print(stdout.decode(errors="ignore"))
if stderr:
print(stderr.decode(errors="ignore"))
await self.event_handler(name,"stdout",stdout or stderr)
print("Return code", proc.returncode)
await self.event_handler(name, "stdout", stdout or stderr)
print("Return code", proc.returncode)
master, slave = pty.openpty()
proc = await asyncio.create_subprocess_exec(
"docker", "compose", "-f", self.compose_path, "exec", name, "/usr/local/bin/entry",
"docker",
"compose",
"-f",
self.compose_path,
"exec",
name,
"/usr/local/bin/entry",
stdin=slave,
stdout=slave,
stderr=slave,
)
proc = {'proc': proc, 'master': master, 'slave': slave}
proc = {"proc": proc, "master": master, "slave": slave}
self.running_instances[name] = proc
await self._create_readers(name)
return True

View File

@ -1,7 +1,7 @@
DEFAULT_LIMIT = 30
import asyncio
import typing
import traceback
from snek.system.model import BaseModel
@ -11,7 +11,7 @@ class BaseMapper:
default_limit: int = DEFAULT_LIMIT
table_name: str = None
semaphore = asyncio.Semaphore(1)
def __init__(self, app):
self.app = app
@ -27,17 +27,16 @@ class BaseMapper:
async def run_in_executor(self, func, *args, **kwargs):
use_semaphore = kwargs.pop("use_semaphore", False)
def _execute():
def _execute():
result = func(*args, **kwargs)
if use_semaphore:
self.db.commit()
return result
return _execute()
#async with self.semaphore:
# return await self.loop.run_in_executor(None, _execute)
return result
return _execute()
# async with self.semaphore:
# return await self.loop.run_in_executor(None, _execute)
async def new(self):
return self.model_class(mapper=self, app=self.app)
@ -51,7 +50,7 @@ class BaseMapper:
kwargs["uid"] = uid
if not kwargs.get("deleted_at"):
kwargs["deleted_at"] = None
#traceback.print_exc()
# traceback.print_exc()
record = await self.db.get(self.table_name, kwargs)
if not record:
@ -65,20 +64,18 @@ class BaseMapper:
async def exists(self, **kwargs):
return await self.db.count(self.table_name, kwargs)
#return await self.run_in_executor(self.table.exists, **kwargs)
# return await self.run_in_executor(self.table.exists, **kwargs)
async def count(self, **kwargs) -> int:
return await self.db.count(self.table_name,kwargs)
return await self.db.count(self.table_name, kwargs)
async def save(self, model: BaseModel) -> bool:
if not model.record.get("uid"):
raise Exception(f"Attempt to save without uid: {model.record}.")
model.updated_at.update()
model.updated_at.update()
await self.upsert(model)
return model
#return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
# return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
async def find(self, **kwargs) -> typing.AsyncGenerator:
if not kwargs.get("_limit"):
@ -100,10 +97,12 @@ class BaseMapper:
yield dict(record)
async def update(self, model):
if not model["deleted_at"] is None:
if model["deleted_at"] is not None:
raise Exception("Can't update deleted record.")
model.updated_at.update()
return await self.db.update(self.table_name, model.record, {"uid": model["uid"]})
return await self.db.update(
self.table_name, model.record, {"uid": model["uid"]}
)
async def upsert(self, model):
model.updated_at.update()

View File

@ -2,7 +2,7 @@
import re
from types import SimpleNamespace
from app.cache import time_cache_async, time_cache
from app.cache import time_cache, time_cache_async
from mistune import HTMLRenderer, Markdown
from mistune.plugins.formatting import strikethrough
from mistune.plugins.spoiler import spoiler
@ -12,7 +12,6 @@ from pygments.formatters import html
from pygments.lexers import get_lexer_by_name
def strip_markdown(md_text):
# Remove code blocks (
md_text = re.sub(r"[\s\S]?```", "", md_text)
@ -50,7 +49,7 @@ class MarkdownRenderer(HTMLRenderer):
return get_lexer_by_name(lang, stripall=True)
except:
return get_lexer_by_name(default, stripall=True)
@time_cache(timeout=60 * 60)
def block_code(self, code, lang=None, info=None):
if not lang:

View File

@ -3,31 +3,17 @@
# This code provides middleware functions for an aiohttp server to manage and modify CSP, CORS, and authentication headers.
import secrets
from aiohttp import web
@web.middleware
async def csp_middleware(request, handler):
nonce = secrets.token_hex(16)
origin = request.headers.get('Origin')
csp_policy = (
"default-src 'self'; "
f"script-src 'self' {origin} 'nonce-{nonce}'; "
f"style-src 'self' 'unsafe-inline' {origin} 'nonce-{nonce}'; "
"img-src *; "
"connect-src 'self' https://umami.molodetz.nl; "
"font-src *; "
"object-src 'none'; "
"base-uri 'self'; "
"form-action 'self'; "
"frame-src 'self'; "
"worker-src *; "
"media-src *; "
"manifest-src 'self';"
)
request['csp_nonce'] = nonce
request.headers.get("Origin")
request["csp_nonce"] = nonce
response = await handler(request)
#response.headers['Content-Security-Policy'] = csp_policy
# response.headers['Content-Security-Policy'] = csp_policy
return response
@ -37,6 +23,7 @@ async def no_cors_middleware(request, handler):
response.headers.pop("Access-Control-Allow-Origin", None)
return response
@web.middleware
async def cors_allow_middleware(request, handler):
response = await handler(request)
@ -48,6 +35,7 @@ async def cors_allow_middleware(request, handler):
response.headers["Access-Control-Allow-Credentials"] = "true"
return response
@web.middleware
async def auth_middleware(request, handler):
request["user"] = None
@ -57,6 +45,7 @@ async def auth_middleware(request, handler):
)
return await handler(request)
@web.middleware
async def cors_middleware(request, handler):
if request.headers.get("Allow"):

View File

@ -1,5 +1,4 @@
from snek.mapper import get_mapper
from snek.model.user import UserModel
from snek.system.mapper import BaseMapper
@ -40,7 +39,7 @@ class BaseService:
yield record
async def get(self, *args, **kwargs):
if not "deleted_at" in kwargs:
if "deleted_at" not in kwargs:
kwargs["deleted_at"] = None
uid = kwargs.get("uid")
if args:
@ -50,7 +49,7 @@ class BaseService:
if result and result.__class__ == self.mapper.model_class:
return result
kwargs["uid"] = uid
print(kwargs,"ZZZZZZZ")
print(kwargs, "ZZZZZZZ")
result = await self.mapper.get(**kwargs)
if result:
await self.cache.set(result["uid"], result)

View File

@ -1,34 +1,41 @@
import asyncio
from aiohttp import web, WSMsgType
from datetime import datetime, timedelta, timezone
from collections import defaultdict
import html
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from aiohttp import WSMsgType, web
def create_stats_structure():
"""Creates the nested dictionary structure for storing statistics."""
def nested_dd():
return defaultdict(lambda: defaultdict(int))
return defaultdict(nested_dd)
def get_time_keys(dt: datetime):
"""Generates dictionary keys for different time granularities."""
return {
"hour": dt.strftime('%Y-%m-%d-%H'),
"day": dt.strftime('%Y-%m-%d'),
"week": dt.strftime('%Y-%W'), # Week number, Monday is first day
"month": dt.strftime('%Y-%m'),
"hour": dt.strftime("%Y-%m-%d-%H"),
"day": dt.strftime("%Y-%m-%d"),
"week": dt.strftime("%Y-%W"), # Week number, Monday is first day
"month": dt.strftime("%Y-%m"),
}
def update_stats_counters(stats_dict: defaultdict, now: datetime):
"""Increments the appropriate time-based counters in a stats dictionary."""
keys = get_time_keys(now)
stats_dict['by_hour'][keys['hour']] += 1
stats_dict['by_day'][keys['day']] += 1
stats_dict['by_week'][keys['week']] += 1
stats_dict['by_month'][keys['month']] += 1
stats_dict["by_hour"][keys["hour"]] += 1
stats_dict["by_day"][keys["day"]] += 1
stats_dict["by_week"][keys["week"]] += 1
stats_dict["by_month"][keys["month"]] += 1
def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: str) -> str:
def generate_time_series_svg(
title: str, data: list[tuple[str, int]], y_label: str
) -> str:
"""Generates a responsive SVG bar chart for time-series data."""
if not data:
return f"<h3>{html.escape(title)}</h3><p>No data yet.</p>"
@ -36,14 +43,14 @@ def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: s
svg_height, svg_width = 250, 600
bar_padding = 5
bar_width = (svg_width - 50) / len(data) - bar_padding
bars = ""
labels = ""
for i, (key, val) in enumerate(data):
bar_height = (val / max_val) * (svg_height - 50) if max_val > 0 else 0
x = i * (bar_width + bar_padding) + 40
y = svg_height - bar_height - 30
bars += f'<rect x="{x}" y="{y}" width="{bar_width}" height="{bar_height}" fill="#007BFF"><title>{html.escape(key)}: {val}</title></rect>'
labels += f'<text x="{x + bar_width / 2}" y="{svg_height - 15}" font-size="11" text-anchor="middle">{html.escape(key)}</text>'
@ -61,25 +68,32 @@ def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: s
</div>
"""
@web.middleware
async def middleware(request, handler):
"""Middleware to count all incoming HTTP requests."""
# Avoid counting requests to the stats page itself
if request.path.startswith('/stats.html'):
if request.path.startswith("/stats.html"):
return await handler(request)
update_stats_counters(request.app['stats']['http_requests'], datetime.now(timezone.utc))
update_stats_counters(
request.app["stats"]["http_requests"], datetime.now(timezone.utc)
)
return await handler(request)
def update_websocket_stats(app):
update_stats_counters(app['stats']['websocket_requests'], datetime.now(timezone.utc))
update_stats_counters(
app["stats"]["websocket_requests"], datetime.now(timezone.utc)
)
async def pipe_and_count_websocket(ws_from, ws_to, stats_dict):
"""This function proxies WebSocket messages AND counts them."""
async for msg in ws_from:
# This is the key part for monitoring WebSockets
update_stats_counters(stats_dict, datetime.now(timezone.utc))
if msg.type == WSMsgType.TEXT:
await ws_to.send_str(msg.data)
elif msg.type == WSMsgType.BINARY:
@ -91,27 +105,27 @@ async def pipe_and_count_websocket(ws_from, ws_to, stats_dict):
async def stats_handler(request: web.Request):
"""Handler to display the statistics dashboard."""
stats = request.app['stats']
stats = request.app["stats"]
now = datetime.now(timezone.utc)
# Helper to prepare data for charts
def get_data(source, period, count):
data = []
for i in range(count - 1, -1, -1):
if period == 'hour':
if period == "hour":
dt = now - timedelta(hours=i)
key, label = dt.strftime('%Y-%m-%d-%H'), dt.strftime('%H:00')
data.append((label, source['by_hour'].get(key, 0)))
elif period == 'day':
key, label = dt.strftime("%Y-%m-%d-%H"), dt.strftime("%H:00")
data.append((label, source["by_hour"].get(key, 0)))
elif period == "day":
dt = now - timedelta(days=i)
key, label = dt.strftime('%Y-%m-%d'), dt.strftime('%a')
data.append((label, source['by_day'].get(key, 0)))
key, label = dt.strftime("%Y-%m-%d"), dt.strftime("%a")
data.append((label, source["by_day"].get(key, 0)))
return data
http_hourly = get_data(stats['http_requests'], 'hour', 24)
ws_hourly = get_data(stats['ws_messages'], 'hour', 24)
http_daily = get_data(stats['http_requests'], 'day', 7)
ws_daily = get_data(stats['ws_messages'], 'day', 7)
http_hourly = get_data(stats["http_requests"], "hour", 24)
ws_hourly = get_data(stats["ws_messages"], "hour", 24)
http_daily = get_data(stats["http_requests"], "day", 7)
ws_daily = get_data(stats["ws_messages"], "day", 7)
body = f"""
<html><head><title>App Stats</title><meta http-equiv="refresh" content="30"></head>
@ -125,5 +139,4 @@ async def stats_handler(request: web.Request):
{generate_time_series_svg("WebSocket Messages", ws_daily, "Msgs/Day")}
</body></html>
"""
return web.Response(text=body, content_type='text/html')
return web.Response(text=body, content_type="text/html")

View File

@ -81,27 +81,28 @@ emoji.EMOJI_DATA[
ALLOWED_TAGS = list(bleach.sanitizer.ALLOWED_TAGS) + ["picture"]
def sanitize_html(value):
soup = BeautifulSoup(value, 'html.parser')
soup = BeautifulSoup(value, "html.parser")
for script in soup.find_all('script'):
for script in soup.find_all("script"):
script.decompose()
#for iframe in soup.find_all('iframe'):
#iframe.decompose()
# for iframe in soup.find_all('iframe'):
# iframe.decompose()
for tag in soup.find_all(['object', 'embed']):
for tag in soup.find_all(["object", "embed"]):
tag.decompose()
for tag in soup.find_all():
event_attributes = ['onclick', 'onerror', 'onload', 'onmouseover', 'onfocus']
event_attributes = ["onclick", "onerror", "onload", "onmouseover", "onfocus"]
for attr in event_attributes:
if attr in tag.attrs:
del tag[attr]
for img in soup.find_all('img'):
if 'onerror' in img.attrs:
for img in soup.find_all("img"):
if "onerror" in img.attrs:
img.decompose()
return soup.prettify()
@ -126,6 +127,7 @@ def set_link_target_blank(text):
return str(soup)
def whitelist_attributes(html):
return sanitize_html(html)

View File

@ -13,23 +13,27 @@ from snek.system.view import BaseView
register_heif_opener()
from snek.view.drive import DriveApiView
from snek.view.drive import DriveApiView
class ChannelDriveApiView(DriveApiView):
login_required = True
async def get_target(self):
target = await self.services.channel.get_home_folder(self.request.match_info.get("channel_uid"))
target = await self.services.channel.get_home_folder(
self.request.match_info.get("channel_uid")
)
target.mkdir(parents=True, exist_ok=True)
return target
async def get_download_url(self, rel):
return f"/channel/{self.request.match_info.get('channel_uid')}/drive/{urllib.parse.quote(rel)}"
class ChannelAttachmentView(BaseView):
login_required=False
login_required = False
async def get(self):
relative_path = self.request.match_info.get("relative_url")
@ -114,7 +118,9 @@ class ChannelAttachmentView(BaseView):
response.headers["Content-Disposition"] = (
f'attachment; filename="{channel_attachment["name"]}"'
)
response.headers["Content-Type"] = original_format or "application/octet-stream"
response.headers["Content-Type"] = (
original_format or "application/octet-stream"
)
return response
async def post(self):
@ -164,11 +170,11 @@ class ChannelAttachmentView(BaseView):
class ChannelAttachmentUploadView(BaseView):
login_required = True
async def get(self):
channel_uid = self.request.match_info.get("channel_uid")
user_uid = self.request.session.get("uid")
@ -184,37 +190,44 @@ class ChannelAttachmentUploadView(BaseView):
file = None
filename = None
msg = await ws.receive()
if not msg.type == web.WSMsgType.TEXT:
return web.HTTPBadRequest()
data = msg.json()
if not data.get('type') == 'start':
if not data.get("type") == "start":
return web.HTTPBadRequest()
filename = data['filename']
filename = data["filename"]
attachment = await self.services.channel_attachment.create_file(
channel_uid=channel_uid, name=filename, user_uid=user_uid
)
pathlib.Path(attachment["path"]).parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(attachment["path"], "wb") as file:
print("File openend.", filename)
async for msg in ws:
if msg.type == web.WSMsgType.BINARY:
print("Binary",filename)
print("Binary", filename)
if file is not None:
await file.write(msg.data)
await ws.send_json({"type": "progress", "filename": filename, "bytes": await file.tell()})
await ws.send_json(
{
"type": "progress",
"filename": filename,
"bytes": await file.tell(),
}
)
elif msg.type == web.WSMsgType.TEXT:
print("TExt",filename)
print("TExt", filename)
print(msg.json())
data = msg.json()
if data.get('type') == 'end':
data = msg.json()
if data.get("type") == "end":
relative_url = urllib.parse.quote(attachment["relative_url"])
await ws.send_json({"type": "done", "file": relative_url, "filename": filename})
await ws.send_json(
{"type": "done", "file": relative_url, "filename": filename}
)
elif msg.type == web.WSMsgType.ERROR:
break
return ws

View File

@ -1,10 +1,13 @@
from snek.system.view import BaseView
import functools
from aiohttp import web
import functools
from aiohttp import web
from snek.system.view import BaseView
class ContainerView(BaseView):
login_required= True
login_required = True
async def stdout_event_handler(self, ws, data):
try:
@ -12,8 +15,8 @@ class ContainerView(BaseView):
except Exception as ex:
print(ex)
await ws.close()
return False
return True
return False
return True
async def create_stdout_event_handler(self, ws):
return functools.partial(self.stdout_event_handler, ws)
@ -26,23 +29,27 @@ class ContainerView(BaseView):
return self.HTTPUnauthorized()
channel_uid = self.request.match_info.get("channel_uid")
channel_member = await self.services.channel_member.get(channel_uid=channel_uid, user_uid=self.request.session.get("uid"))
channel_member = await self.services.channel_member.get(
channel_uid=channel_uid, user_uid=self.request.session.get("uid")
)
if not channel_member:
return web.HTTPUnauthorized()
container = await self.services.container.get(channel_uid)
if not container:
return web.HTTPNotFound()
if not container['status'] == 'running':
resp = await self.services.container.start(channel_uid)
await ws.send_bytes(b'Container is starting\n\n')
if not container["status"] == "running":
await self.services.container.start(channel_uid)
await ws.send_bytes(b"Container is starting\n\n")
container_name = await self.services.container.get_container_name(channel_uid)
event_handler = await self.create_stdout_event_handler(ws)
await self.services.container.add_event_listener(container_name, "stdout", event_handler)
await self.services.container.add_event_listener(
container_name, "stdout", event_handler
)
while True:
data = await ws.receive()
@ -54,8 +61,8 @@ class ContainerView(BaseView):
elif data.type == web.WSMsgType.ERROR:
break
await self.services.container.remove_event_listener(container_name, channel_uid, "stdout")
return ws
await self.services.container.remove_event_listener(
container_name, channel_uid, "stdout"
)
return ws

View File

@ -1,9 +1,6 @@
import mimetypes
import os
import urllib.parse
from datetime import datetime
from pathlib import Path
from urllib.parse import quote, unquote
from aiohttp import web
@ -42,7 +39,7 @@ class DriveApiView(BaseView):
async def get(self):
target = await self.get_target()
original_target = target
original_target = target
rel = self.request.query.get("path", "")
offset = int(self.request.query.get("offset", 0))
limit = int(self.request.query.get("limit", 20))

View File

@ -1,19 +1,15 @@
# views/forum.py
from snek.system.view import BaseView
from aiohttp import web
import json
from aiohttp import web
from snek.system.view import BaseView
class ForumIndexView(BaseView):
login_required = True
async def get(self):
if self.login_required and not self.session.get("logged_in"):
return web.HTTPFound("/")
@ -71,407 +67,421 @@ class ForumIndexView(BaseView):
)
class ForumView(BaseView):
"""REST API endpoints for forum"""
login_required = True
async def get_forums(self):
request = self
request = self
self = request.app
"""GET /forum/api/forums - Get all active forums"""
forums = []
async for forum in self.services.forum.get_active_forums():
forums.append({
"uid": forum["uid"],
"name": forum["name"],
"description": forum["description"],
"slug": forum["slug"],
"icon": forum["icon"],
"thread_count": forum["thread_count"],
"post_count": forum["post_count"],
"last_post_at": forum["last_post_at"],
"last_thread_uid": forum["last_thread_uid"]
})
forums.append(
{
"uid": forum["uid"],
"name": forum["name"],
"description": forum["description"],
"slug": forum["slug"],
"icon": forum["icon"],
"thread_count": forum["thread_count"],
"post_count": forum["post_count"],
"last_post_at": forum["last_post_at"],
"last_thread_uid": forum["last_thread_uid"],
}
)
return web.json_response({"forums": forums})
async def get_forum(self):
request = self
request = self
self = request.app
setattr(self, "request", request)
"""GET /forum/api/forums/:slug - Get forum by slug"""
slug = self.request.match_info["slug"]
forum = await self.services.forum.get(slug=slug, is_active=True)
if not forum:
return web.json_response({"error": "Forum not found"}, status=404)
# Get threads
threads = []
page = int(self.request.query.get("page", 1))
limit = 50
offset = (page - 1) * limit
async for thread in forum.get_threads(limit=limit, offset=offset):
# Get author info
author = await self.services.user.get(uid=thread["created_by_uid"])
last_post_author = None
if thread["last_post_by_uid"]:
last_post_author = await self.services.user.get(uid=thread["last_post_by_uid"])
threads.append({
"uid": thread["uid"],
"title": thread["title"],
"slug": thread["slug"],
"is_pinned": thread["is_pinned"],
"is_locked": thread["is_locked"],
"view_count": thread["view_count"],
"post_count": thread["post_count"],
"created_at": thread["created_at"],
"last_post_at": thread["last_post_at"],
"author": {
"uid": author["uid"],
"username": author["username"],
"nick": author["nick"],
"color": author["color"]
last_post_author = await self.services.user.get(
uid=thread["last_post_by_uid"]
)
threads.append(
{
"uid": thread["uid"],
"title": thread["title"],
"slug": thread["slug"],
"is_pinned": thread["is_pinned"],
"is_locked": thread["is_locked"],
"view_count": thread["view_count"],
"post_count": thread["post_count"],
"created_at": thread["created_at"],
"last_post_at": thread["last_post_at"],
"author": {
"uid": author["uid"],
"username": author["username"],
"nick": author["nick"],
"color": author["color"],
},
"last_post_author": (
{
"uid": last_post_author["uid"],
"username": last_post_author["username"],
"nick": last_post_author["nick"],
"color": last_post_author["color"],
}
if last_post_author
else None
),
}
)
return web.json_response(
{
"forum": {
"uid": forum["uid"],
"name": forum["name"],
"description": forum["description"],
"slug": forum["slug"],
"icon": forum["icon"],
"thread_count": forum["thread_count"],
"post_count": forum["post_count"],
},
"last_post_author": {
"uid": last_post_author["uid"],
"username": last_post_author["username"],
"nick": last_post_author["nick"],
"color": last_post_author["color"]
} if last_post_author else None
})
return web.json_response({
"forum": {
"uid": forum["uid"],
"name": forum["name"],
"description": forum["description"],
"slug": forum["slug"],
"icon": forum["icon"],
"thread_count": forum["thread_count"],
"post_count": forum["post_count"]
},
"threads": threads,
"page": page,
"hasMore": len(threads) == limit
})
"threads": threads,
"page": page,
"hasMore": len(threads) == limit,
}
)
async def create_thread(self):
request = self
request = self
self = request.app
setattr(self, "request", request)
"""POST /forum/api/forums/:slug/threads - Create new thread"""
if not self.request.session.get("logged_in"):
return web.json_response({"error": "Unauthorized"}, status=401)
slug = self.request.match_info["slug"]
forum = await self.services.forum.get(slug=slug, is_active=True)
if not forum:
return web.json_response({"error": "Forum not found"}, status=404)
data = await self.request.json()
title = data.get("title", "").strip()
content = data.get("content", "").strip()
if not title or not content:
return web.json_response({"error": "Title and content required"}, status=400)
return web.json_response(
{"error": "Title and content required"}, status=400
)
try:
thread, post = await self.services.thread.create_thread(
forum_uid=forum["uid"],
title=title,
content=content,
created_by_uid=self.request.session["uid"]
created_by_uid=self.request.session["uid"],
)
return web.json_response({
"thread": {
"uid": thread["uid"],
"slug": thread["slug"],
"forum_slug": forum["slug"]
return web.json_response(
{
"thread": {
"uid": thread["uid"],
"slug": thread["slug"],
"forum_slug": forum["slug"],
}
}
})
)
except Exception as e:
return web.json_response({"error": str(e)}, status=400)
async def get_thread(self):
request = self
request = self
self = request.app
setattr(self, "request", request)
"""GET /forum/api/threads/:thread_slug - Get thread with posts"""
thread_slug = self.request.match_info["thread_slug"]
thread = await self.services.thread.get(slug=thread_slug)
if not thread:
return web.json_response({"error": "Thread not found"}, status=404)
# Increment view count
await thread.increment_view_count()
# Get forum
forum = await self.services.forum.get(uid=thread["forum_uid"])
# Get posts
posts = []
page = int(self.request.query.get("page", 1))
limit = 50
offset = (page - 1) * limit
current_user_uid = self.request.session.get("uid")
async for post in thread.get_posts(limit=limit, offset=offset):
author = await post.get_author()
is_liked = False
if current_user_uid:
is_liked = await post.is_liked_by(current_user_uid)
posts.append({
"uid": post["uid"],
"content": post["content"],
"created_at": post["created_at"],
"edited_at": post["edited_at"],
"is_first_post": post["is_first_post"],
"like_count": post["like_count"],
"is_liked": is_liked,
"author": {
"uid": author["uid"],
"username": author["username"],
"nick": author["nick"],
"color": author["color"]
posts.append(
{
"uid": post["uid"],
"content": post["content"],
"created_at": post["created_at"],
"edited_at": post["edited_at"],
"is_first_post": post["is_first_post"],
"like_count": post["like_count"],
"is_liked": is_liked,
"author": {
"uid": author["uid"],
"username": author["username"],
"nick": author["nick"],
"color": author["color"],
},
}
})
)
# Get thread author
thread_author = await self.services.user.get(uid=thread["created_by_uid"])
return web.json_response({
"thread": {
"uid": thread["uid"],
"title": thread["title"],
"slug": thread["slug"],
"is_pinned": thread["is_pinned"],
"is_locked": thread["is_locked"],
"view_count": thread["view_count"],
"post_count": thread["post_count"],
"created_at": thread["created_at"],
"author": {
"uid": thread_author["uid"],
"username": thread_author["username"],
"nick": thread_author["nick"],
"color": thread_author["color"]
}
},
"forum": {
"uid": forum["uid"],
"name": forum["name"],
"slug": forum["slug"]
},
"posts": posts,
"page": page,
"hasMore": len(posts) == limit
})
return web.json_response(
{
"thread": {
"uid": thread["uid"],
"title": thread["title"],
"slug": thread["slug"],
"is_pinned": thread["is_pinned"],
"is_locked": thread["is_locked"],
"view_count": thread["view_count"],
"post_count": thread["post_count"],
"created_at": thread["created_at"],
"author": {
"uid": thread_author["uid"],
"username": thread_author["username"],
"nick": thread_author["nick"],
"color": thread_author["color"],
},
},
"forum": {
"uid": forum["uid"],
"name": forum["name"],
"slug": forum["slug"],
},
"posts": posts,
"page": page,
"hasMore": len(posts) == limit,
}
)
async def create_post(self):
request = self
request = self
self = request.app
setattr(self, "request", request)
"""POST /forum/api/threads/:thread_uid/posts - Create new post"""
if not self.request.session.get("logged_in"):
return web.json_response({"error": "Unauthorized"}, status=401)
thread_uid = self.request.match_info["thread_uid"]
thread = await self.services.thread.get(uid=thread_uid)
if not thread:
return web.json_response({"error": "Thread not found"}, status=404)
data = await self.request.json()
content = data.get("content", "").strip()
if not content:
return web.json_response({"error": "Content required"}, status=400)
try:
post = await self.services.post.create_post(
thread_uid=thread["uid"],
forum_uid=thread["forum_uid"],
content=content,
created_by_uid=self.request.session["uid"]
created_by_uid=self.request.session["uid"],
)
author = await post.get_author()
return web.json_response({
"post": {
"uid": post["uid"],
"content": post["content"],
"created_at": post["created_at"],
"like_count": post["like_count"],
"is_liked": False,
"author": {
"uid": author["uid"],
"username": author["username"],
"nick": author["nick"],
"color": author["color"]
return web.json_response(
{
"post": {
"uid": post["uid"],
"content": post["content"],
"created_at": post["created_at"],
"like_count": post["like_count"],
"is_liked": False,
"author": {
"uid": author["uid"],
"username": author["username"],
"nick": author["nick"],
"color": author["color"],
},
}
}
})
)
except Exception as e:
return web.json_response({"error": str(e)}, status=400)
async def edit_post(self):
request = self
request = self
self = request.app
setattr(self, "request", request)
"""PUT /forum/api/posts/:post_uid - Edit post"""
if not self.request.session.get("logged_in"):
return web.json_response({"error": "Unauthorized"}, status=401)
post_uid = self.request.match_info["post_uid"]
data = await self.request.json()
content = data.get("content", "").strip()
if not content:
return web.json_response({"error": "Content required"}, status=400)
post = await self.services.post.edit_post(
post_uid=post_uid,
content=content,
user_uid=self.request.session["uid"]
post_uid=post_uid, content=content, user_uid=self.request.session["uid"]
)
if not post:
return web.json_response({"error": "Cannot edit post"}, status=403)
return web.json_response({
"post": {
"uid": post["uid"],
"content": post["content"],
"edited_at": post["edited_at"]
return web.json_response(
{
"post": {
"uid": post["uid"],
"content": post["content"],
"edited_at": post["edited_at"],
}
}
})
)
async def delete_post(self):
request = self
request = self
self = request.app
setattr(self, "request", request)
"""DELETE /forum/api/posts/:post_uid - Delete post"""
if not self.request.session.get("logged_in"):
return web.json_response({"error": "Unauthorized"}, status=401)
post_uid = self.request.match_info["post_uid"]
success = await self.services.post.delete_post(
post_uid=post_uid,
user_uid=self.request.session["uid"]
post_uid=post_uid, user_uid=self.request.session["uid"]
)
if not success:
return web.json_response({"error": "Cannot delete post"}, status=403)
return web.json_response({"success": True})
async def toggle_like(self):
request = self
request = self
self = request.app
setattr(self, "request", request)
"""POST /forum/api/posts/:post_uid/like - Toggle post like"""
if not self.request.session.get("logged_in"):
return web.json_response({"error": "Unauthorized"}, status=401)
post_uid = self.request.match_info["post_uid"]
is_liked = await self.services.post_like.toggle_like(
post_uid=post_uid,
user_uid=self.request.session["uid"]
post_uid=post_uid, user_uid=self.request.session["uid"]
)
if is_liked is None:
return web.json_response({"error": "Failed to toggle like"}, status=400)
# Get updated post
post = await self.services.post.get(uid=post_uid)
return web.json_response({
"is_liked": is_liked,
"like_count": post["like_count"]
})
return web.json_response(
{"is_liked": is_liked, "like_count": post["like_count"]}
)
async def toggle_pin(self):
request = self
request = self
self = request.app
setattr(self, "request", request)
"""POST /forum/api/threads/:thread_uid/pin - Toggle thread pin"""
if not self.request.session.get("logged_in"):
return web.json_response({"error": "Unauthorized"}, status=401)
thread_uid = self.request.match_info["thread_uid"]
thread = await self.services.thread.toggle_pin(
thread_uid=thread_uid,
user_uid=self.request.session["uid"]
thread_uid=thread_uid, user_uid=self.request.session["uid"]
)
if not thread:
return web.json_response({"error": "Cannot toggle pin"}, status=403)
return web.json_response({"is_pinned": thread["is_pinned"]})
async def toggle_lock(self):
request = self
request = self
self = request.app
setattr(self, "request", request)
"""POST /forum/api/threads/:thread_uid/lock - Toggle thread lock"""
if not self.request.session.get("logged_in"):
return web.json_response({"error": "Unauthorized"}, status=401)
thread_uid = self.request.match_info["thread_uid"]
thread = await self.services.thread.toggle_lock(
thread_uid=thread_uid,
user_uid=self.request.session["uid"]
thread_uid=thread_uid, user_uid=self.request.session["uid"]
)
if not thread:
return web.json_response({"error": "Cannot toggle lock"}, status=403)
return web.json_response({"is_locked": thread["is_locked"]})
# views/forum_websocket.py
class ForumWebSocketView(BaseView):
"""WebSocket view for real-time forum updates"""
async def get(self):
ws = web.WebSocketResponse()
await ws.prepare(self.request)
# Store WebSocket connection
ws_id = self.services.forum.generate_uid()
if not hasattr(self.app, 'forum_websockets'):
if not hasattr(self.app, "forum_websockets"):
self.app.forum_websockets = {}
user_uid = self.request.session.get("uid")
self.app.forum_websockets[ws_id] = {
"ws": ws,
"user_uid": user_uid,
"subscriptions": set()
"subscriptions": set(),
}
try:
async for msg in ws:
if msg.type == web.WSMsgType.TEXT:
@ -483,68 +493,76 @@ class ForumWebSocketView(BaseView):
# Clean up
if ws_id in self.app.forum_websockets:
del self.app.forum_websockets[ws_id]
return ws
async def handle_ws_message(self, ws_id, data):
"""Handle incoming WebSocket messages"""
action = data.get("action")
if action == "subscribe":
# Subscribe to forum/thread updates
target_type = data.get("type") # "forum" or "thread"
target_id = data.get("id")
if target_type and target_id:
subscription = f"{target_type}:{target_id}"
self.app.forum_websockets[ws_id]["subscriptions"].add(subscription)
# Send confirmation
ws = self.app.forum_websockets[ws_id]["ws"]
await ws.send_str(json.dumps({
"type": "subscribed",
"subscription": subscription
}))
await ws.send_str(
json.dumps({"type": "subscribed", "subscription": subscription})
)
elif action == "unsubscribe":
target_type = data.get("type")
target_id = data.get("id")
if target_type and target_id:
subscription = f"{target_type}:{target_id}"
self.app.forum_websockets[ws_id]["subscriptions"].discard(subscription)
# Send confirmation
ws = self.app.forum_websockets[ws_id]["ws"]
await ws.send_str(json.dumps({
"type": "unsubscribed",
"subscription": subscription
}))
await ws.send_str(
json.dumps({"type": "unsubscribed", "subscription": subscription})
)
@staticmethod
async def broadcast_update(app, event_type, data):
"""Broadcast updates to subscribed WebSocket clients"""
if not hasattr(app, 'forum_websockets'):
if not hasattr(app, "forum_websockets"):
return
# Determine subscription targets based on event
targets = set()
if event_type in ["thread_created", "post_created", "post_edited", "post_deleted"]:
if event_type in [
"thread_created",
"post_created",
"post_edited",
"post_deleted",
]:
if "forum_uid" in data:
targets.add(f"forum:{data['forum_uid']}")
if event_type in ["post_created", "post_edited", "post_deleted", "post_liked", "post_unliked"]:
if event_type in [
"post_created",
"post_edited",
"post_deleted",
"post_liked",
"post_unliked",
]:
if "thread_uid" in data:
targets.add(f"thread:{data['thread_uid']}")
# Send to subscribed clients
for ws_id, ws_data in app.forum_websockets.items():
if ws_data["subscriptions"] & targets:
try:
await ws_data["ws"].send_str(json.dumps({
"type": event_type,
"data": data
}))
await ws_data["ws"].send_str(
json.dumps({"type": event_type, "data": data})
)
except Exception:
pass

View File

@ -30,7 +30,7 @@ from snek.system.view import BaseFormView
class RegisterFormView(BaseFormView):
login_required = False
form = RegisterForm

View File

@ -6,18 +6,18 @@
# MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions.
from snek.system.stats import update_websocket_stats
import asyncio
import json
import logging
import time
import traceback
import random
from aiohttp import web
from snek.system.model import now
from snek.system.profiler import Profiler
from snek.system.stats import update_websocket_stats
from snek.system.view import BaseView
import time
logger = logging.getLogger(__name__)
@ -203,7 +203,7 @@ class RPCView(BaseView):
}
)
return channels
async def clear_channel(self, channel_uid):
self._require_login()
user = await self.services.user.get(uid=self.user_uid)
@ -211,41 +211,49 @@ class RPCView(BaseView):
raise Exception("Not allowed")
return await self.services.channel_message.clear(channel_uid)
async def write_container(self, channel_uid, content,timeout=3):
async def write_container(self, channel_uid, content, timeout=3):
self._require_login()
channel_member = await self.services.channel_member.get(
channel_uid=channel_uid, user_uid=self.user_uid
)
if not channel_member:
raise Exception("Not allowed")
container_name = await self.services.container.get_container_name(channel_uid)
container_name = await self.services.container.get_container_name(
channel_uid
)
class SessionCall:
def __init__(self, app,channel_uid_uid, container_name):
self.app = app
self.channel_uid = channel_uid
self.container_name = container_name
def __init__(self, app, channel_uid_uid, container_name):
self.app = app
self.channel_uid = channel_uid
self.container_name = container_name
self.time_last_output = time.time()
self.output = b''
self.output = b""
async def stdout_event_handler(self, data):
self.time_last_output = time.time()
self.output += data
return True
return True
async def communicate(self,content, timeout=3):
await self.app.services.container.add_event_listener(self.container_name, "stdout", self.stdout_event_handler)
await self.app.services.container.write_stdin(self.channel_uid, content)
async def communicate(self, content, timeout=3):
await self.app.services.container.add_event_listener(
self.container_name, "stdout", self.stdout_event_handler
)
await self.app.services.container.write_stdin(
self.channel_uid, content
)
while time.time() - self.time_last_output < timeout:
await asyncio.sleep(0.1)
await self.app.services.container.remove_event_listener(self.container_name, "stdout", self.stdout_event_handler)
await self.app.services.container.remove_event_listener(
self.container_name, "stdout", self.stdout_event_handler
)
return self.output
sc = SessionCall(self, channel_uid,container_name)
return (await sc.communicate(content)).decode("utf-8","ignore")
sc = SessionCall(self, channel_uid, container_name)
return (await sc.communicate(content)).decode("utf-8", "ignore")
async def get_container(self, channel_uid):
self._require_login()
@ -258,44 +266,44 @@ class RPCView(BaseView):
result = None
if container:
result = {
"name": await self.services.container.get_container_name(channel_uid),
"name": await self.services.container.get_container_name(
channel_uid
),
"cpus": container["deploy"]["resources"]["limits"]["cpus"],
"memory": container["deploy"]["resources"]["limits"]["memory"],
"image": "ubuntu:latest",
"volumes": [],
"status": container["status"]
"status": container["status"],
}
return result
async def _finalize_message_task(self,message_uid):
async def _finalize_message_task(self, message_uid):
try:
for _ in range(7):
await asyncio.sleep(1)
if not self._finalize_task:
return
return
except asyncio.exceptions.CancelledError:
print("Finalization cancelled.")
return
self._finalize_task = None
message = await self.services.channel_message.get(message_uid, is_final=False)
return
self._finalize_task = None
message = await self.services.channel_message.get(
message_uid, is_final=False
)
if not message:
return False
if message["user_uid"] != self.user_uid:
raise Exception("Not allowed")
if message["is_final"]:
return True
return True
if not message["is_final"]:
await self.services.chat.finalize(message["uid"])
return True
async def _queue_finalize_message(self, message_uid):
if self._finalize_task:
self._finalize_task.cancel()
@ -305,18 +313,22 @@ class RPCView(BaseView):
async def send_message(self, channel_uid, message, is_final=True):
self._require_login()
message = message.strip()
if not is_final:
if_final = False
check_message = await self.services.channel_message.get(channel_uid=channel_uid, user_uid=self.user_uid,is_final=False,deleted_at=None)
check_message = await self.services.channel_message.get(
channel_uid=channel_uid,
user_uid=self.user_uid,
is_final=False,
deleted_at=None,
)
if check_message:
await self._queue_finalize_message(check_message["uid"])
return await self.update_message_text(check_message["uid"], message)
if not message:
return
return
message = await self.services.chat.send(
self.user_uid, channel_uid, message, is_final
@ -328,7 +340,6 @@ class RPCView(BaseView):
return message["uid"]
async def start_container(self, channel_uid):
self._require_login()
channel_member = await self.services.channel_member.get(
@ -355,8 +366,6 @@ class RPCView(BaseView):
if not channel_member:
raise Exception("Not allowed")
return await self.services.container.get_status(channel_uid)
async def finalize_message(self, message_uid):
self._require_login()
@ -373,8 +382,8 @@ class RPCView(BaseView):
return True
async def update_message_text(self, message_uid, text):
#async with self.app.no_save():
# async with self.app.no_save():
self._require_login()
message = await self.services.channel_message.get(message_uid)
@ -403,7 +412,6 @@ class RPCView(BaseView):
"success": False,
}
if not text:
message["deleted_at"] = now()
else:
@ -427,9 +435,6 @@ class RPCView(BaseView):
return {"success": True}
async def clear_channel(self, channel_uid):
self._require_login()
user = await self.services.user.get(uid=self.user_uid)
@ -438,11 +443,10 @@ class RPCView(BaseView):
channel = await self.services.channel.get(uid=channel_uid)
if not channel:
raise Exception("Channel not found")
channel['history_start'] = datetime.now()
channel["history_start"] = datetime.now()
await self.services.channel.save(channel)
return await self.services.channel_message.clear(channel_uid)
async def echo(self, *args):
self._require_login()
return args
@ -528,10 +532,10 @@ class RPCView(BaseView):
async def _send_json(self, obj):
try:
await self.ws.send_str(json.dumps(obj, default=str))
except Exception as ex:
except Exception:
await self.services.socket.delete(self.ws)
await self.ws.close()
async def get_online_users(self, channel_uid):
self._require_login()

View File

@ -38,10 +38,10 @@ class WebView(BaseView):
channel = await self.services.channel.get(
uid=self.request.match_info.get("channel")
)
print(self.session.get("uid"),"ZZZZZZZZZZ")
print(self.session.get("uid"), "ZZZZZZZZZZ")
qq = await self.services.user.get(uid=self.session.get("uid"))
print("GGGGGGGGGG",qq)
print("GGGGGGGGGG", qq)
if not channel:
user = await self.services.user.get(
uid=self.request.match_info.get("channel")

View File

@ -1,22 +1,24 @@
import logging
import pathlib
import base64
import datetime
import hashlib
import json
import logging
import mimetypes
import os
import pathlib
import shutil
import uuid
from urllib.parse import urlparse
import aiofiles
import aiohttp
import aiohttp.web
import hashlib
from app.cache import time_cache_async
from lxml import etree
from urllib.parse import urlparse
logging.basicConfig(level=logging.DEBUG)
@aiohttp.web.middleware
async def debug_middleware(request, handler):
print(request.method, request.path, request.headers)
@ -28,6 +30,7 @@ async def debug_middleware(request, handler):
pass
return result
class WebdavApplication(aiohttp.web.Application):
def __init__(self, parent, *args, **kwargs):
middlewares = [debug_middleware]
@ -78,7 +81,7 @@ class WebdavApplication(aiohttp.web.Application):
return request["user"]
async def is_locked(self, abs_path, request):
rel_path = abs_path.relative_to(request["home"]).as_posix()
abs_path.relative_to(request["home"]).as_posix()
path = abs_path
while path != request["home"].parent:
rel = path.relative_to(request["home"]).as_posix()
@ -122,14 +125,20 @@ class WebdavApplication(aiohttp.web.Application):
return aiohttp.web.Response(status=403, text="Cannot download a directory")
lock = await self.is_locked(abs_path, request)
if lock:
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=423, text="Locked")
content_type, _ = mimetypes.guess_type(str(abs_path))
content_type = content_type or "application/octet-stream"
etag = f'"{hashlib.sha1(str(abs_path.stat().st_mtime).encode()).hexdigest()}"'
headers = {"Content-Type": content_type, "ETag": etag}
return aiohttp.web.FileResponse(path=str(abs_path), headers=headers, chunk_size=8192)
return aiohttp.web.FileResponse(
path=str(abs_path), headers=headers, chunk_size=8192
)
async def handle_head(self, request):
if not await self.authenticate(request):
@ -141,10 +150,16 @@ class WebdavApplication(aiohttp.web.Application):
if not abs_path.exists():
return aiohttp.web.Response(status=404, text="File not found")
if abs_path.is_dir():
return aiohttp.web.Response(status=403, text="Cannot get metadata for a directory")
return aiohttp.web.Response(
status=403, text="Cannot get metadata for a directory"
)
lock = await self.is_locked(abs_path, request)
if lock:
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=423, text="Locked")
content_type, _ = mimetypes.guess_type(str(abs_path))
@ -168,7 +183,11 @@ class WebdavApplication(aiohttp.web.Application):
abs_path = request["home"] / requested_path
lock = await self.is_locked(abs_path, request)
if lock:
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=423, text="Locked")
abs_path.parent.mkdir(parents=True, exist_ok=True)
@ -186,7 +205,11 @@ class WebdavApplication(aiohttp.web.Application):
abs_path = request["home"] / requested_path
lock = await self.is_locked(abs_path, request)
if lock:
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=423, text="Locked")
if abs_path.is_file():
@ -199,7 +222,9 @@ class WebdavApplication(aiohttp.web.Application):
del self.locks[rel_path]
return aiohttp.web.Response(status=204)
elif abs_path.is_dir():
if await self.has_descendant_locks(abs_path.relative_to(request["home"]).as_posix(), True):
if await self.has_descendant_locks(
abs_path.relative_to(request["home"]).as_posix(), True
):
return aiohttp.web.Response(status=423, text="Locked")
shutil.rmtree(abs_path)
rel_path = abs_path.relative_to(request["home"]).as_posix()
@ -217,7 +242,11 @@ class WebdavApplication(aiohttp.web.Application):
abs_path = request["home"] / requested_path
lock = await self.is_locked(abs_path, request)
if lock:
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=423, text="Locked")
if abs_path.exists():
@ -235,15 +264,25 @@ class WebdavApplication(aiohttp.web.Application):
if not destination:
return aiohttp.web.Response(status=400, text="Destination header missing")
parsed = urlparse(destination)
if parsed.scheme and (parsed.scheme != request.scheme or parsed.netloc != request.host):
if parsed.scheme and (
parsed.scheme != request.scheme or parsed.netloc != request.host
):
return aiohttp.web.Response(status=502, text="Bad Gateway")
dest_rel = parsed.path[len(self.relative_url)+1:] if parsed.path.startswith(self.relative_url) else ""
dest_rel = (
parsed.path[len(self.relative_url) + 1 :]
if parsed.path.startswith(self.relative_url)
else ""
)
dest_path = request["home"] / dest_rel
if not src_path.exists():
return aiohttp.web.Response(status=404, text="Source not found")
src_lock = await self.is_locked(src_path, request)
dest_lock = await self.is_locked(dest_path, request)
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if src_lock and src_lock["token"] != submitted:
return aiohttp.web.Response(status=423, text="Source Locked")
if dest_lock and dest_lock["token"] != submitted:
@ -277,15 +316,25 @@ class WebdavApplication(aiohttp.web.Application):
if not destination:
return aiohttp.web.Response(status=400, text="Destination header missing")
parsed = urlparse(destination)
if parsed.scheme and (parsed.scheme != request.scheme or parsed.netloc != request.host):
if parsed.scheme and (
parsed.scheme != request.scheme or parsed.netloc != request.host
):
return aiohttp.web.Response(status=502, text="Bad Gateway")
dest_rel = parsed.path[len(self.relative_url)+1:] if parsed.path.startswith(self.relative_url) else ""
dest_rel = (
parsed.path[len(self.relative_url) + 1 :]
if parsed.path.startswith(self.relative_url)
else ""
)
dest_path = request["home"] / dest_rel
if not src_path.exists():
return aiohttp.web.Response(status=404, text="Source not found")
src_lock = await self.is_locked(src_path, request)
dest_lock = await self.is_locked(dest_path, request)
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if src_lock and src_lock["token"] != submitted:
return aiohttp.web.Response(status=423, text="Source Locked")
if dest_lock and dest_lock["token"] != submitted:
@ -347,7 +396,16 @@ class WebdavApplication(aiohttp.web.Application):
statvfs = await loop.run_in_executor(None, os.statvfs, path)
return statvfs.f_bavail * statvfs.f_frsize
async def create_propfind_node(self, request, response_xml, full_path, depth, requested_props, propname_only, all_props):
async def create_propfind_node(
self,
request,
response_xml,
full_path,
depth,
requested_props,
propname_only,
all_props,
):
abs_path = pathlib.Path(full_path)
relative_path = abs_path.relative_to(request["home"]).as_posix()
href_path = f"{self.relative_url}/{relative_path}".strip(".")
@ -379,7 +437,9 @@ class WebdavApplication(aiohttp.web.Application):
prop_notfound = None
for prop_tag, prop_nsmap in requested_props:
if prop_tag in standard_props:
await self.add_standard_property(prop_ok, prop_tag, full_path, request)
await self.add_standard_property(
prop_ok, prop_tag, full_path, request
)
found_props.append(prop_tag)
elif prop_tag in custom_properties:
if propname_only:
@ -391,23 +451,31 @@ class WebdavApplication(aiohttp.web.Application):
else:
if propstat_notfound is None:
propstat_notfound = etree.SubElement(response, "{DAV:}propstat")
prop_notfound = etree.SubElement(propstat_notfound, "{DAV:}prop")
prop_notfound = etree.SubElement(
propstat_notfound, "{DAV:}prop"
)
etree.SubElement(prop_notfound, prop_tag, nsmap=prop_nsmap)
if propstat_notfound is not None:
etree.SubElement(propstat_notfound, "{DAV:}status").text = "HTTP/1.1 404 Not Found"
etree.SubElement(propstat_notfound, "{DAV:}status").text = (
"HTTP/1.1 404 Not Found"
)
else:
for prop_name in standard_props:
if propname_only:
etree.SubElement(prop_ok, prop_name)
else:
await self.add_standard_property(prop_ok, prop_name, full_path, request)
await self.add_standard_property(
prop_ok, prop_name, full_path, request
)
for prop_name, prop_value in custom_properties.items():
if propname_only:
if prop_name.startswith("{"):
ns_end = prop_name.find("}")
ns = prop_name[1:ns_end]
local_name = prop_name[ns_end+1:]
prop_nsmap = {"D": "DAV:", None: ns} if ns != "DAV:" else {"D": "DAV:"}
local_name = prop_name[ns_end + 1 :]
prop_nsmap = (
{"D": "DAV:", None: ns} if ns != "DAV:" else {"D": "DAV:"}
)
etree.SubElement(prop_ok, local_name, nsmap=prop_nsmap)
else:
etree.SubElement(prop_ok, prop_name)
@ -417,9 +485,19 @@ class WebdavApplication(aiohttp.web.Application):
etree.SubElement(propstat_ok, "{DAV:}status").text = "HTTP/1.1 200 OK"
if abs_path.is_dir() and depth > 0:
for item in abs_path.iterdir():
if item.name.startswith(".") and item.name.endswith(".webdav_props.json"):
if item.name.startswith(".") and item.name.endswith(
".webdav_props.json"
):
continue
await self.create_propfind_node(request, response_xml, item, depth - 1, requested_props, propname_only, all_props)
await self.create_propfind_node(
request,
response_xml,
item,
depth - 1,
requested_props,
propname_only,
all_props,
)
async def add_standard_property(self, prop_elem, prop_name, full_path, request):
if prop_name == "{DAV:}resourcetype":
@ -445,14 +523,22 @@ class WebdavApplication(aiohttp.web.Application):
etree.SubElement(locktype, f"{{DAV:}}{lock['type']}")
lockscope = etree.SubElement(activelock, "{DAV:}lockscope")
etree.SubElement(lockscope, f"{{DAV:}}{lock['scope']}")
etree.SubElement(activelock, "{DAV:}depth").text = lock['depth'].capitalize()
if lock['owner']:
owner = etree.fromstring(lock['owner'])
etree.SubElement(activelock, "{DAV:}depth").text = lock[
"depth"
].capitalize()
if lock["owner"]:
owner = etree.fromstring(lock["owner"])
activelock.append(owner)
timeout_str = "Infinite" if lock['timeout'] is None else f"Second-{lock['timeout']}"
timeout_str = (
"Infinite"
if lock["timeout"] is None
else f"Second-{lock['timeout']}"
)
etree.SubElement(activelock, "{DAV:}timeout").text = timeout_str
locktoken = etree.SubElement(activelock, "{DAV:}locktoken")
etree.SubElement(locktoken, "{DAV:}href").text = f"opaquelocktoken:{lock['token']}"
etree.SubElement(locktoken, "{DAV:}href").text = (
f"opaquelocktoken:{lock['token']}"
)
elif prop_name == "{DAV:}supportedlock":
supported_lock = etree.SubElement(prop_elem, "{DAV:}supportedlock")
lock_entry_1 = etree.SubElement(supported_lock, "{DAV:}lockentry")
@ -466,11 +552,17 @@ class WebdavApplication(aiohttp.web.Application):
lock_type_2 = etree.SubElement(lock_entry_2, "{DAV:}locktype")
etree.SubElement(lock_type_2, "{DAV:}write")
elif prop_name == "{DAV:}quota-used-bytes":
size = await self.get_file_size(full_path) if full_path.is_file() else await self.get_directory_size(full_path)
size = (
await self.get_file_size(full_path)
if full_path.is_file()
else await self.get_directory_size(full_path)
)
etree.SubElement(prop_elem, "{DAV:}quota-used-bytes").text = str(size)
elif prop_name == "{DAV:}quota-available-bytes":
free_space = await self.get_disk_free_space(str(full_path))
etree.SubElement(prop_elem, "{DAV:}quota-available-bytes").text = str(free_space)
etree.SubElement(prop_elem, "{DAV:}quota-available-bytes").text = str(
free_space
)
elif prop_name == "{DAV:}getcontentlength" and full_path.is_file():
size = await self.get_file_size(full_path)
etree.SubElement(prop_elem, "{DAV:}getcontentlength").text = str(size)
@ -479,7 +571,9 @@ class WebdavApplication(aiohttp.web.Application):
if mimetype:
etree.SubElement(prop_elem, "{DAV:}getcontenttype").text = mimetype
elif prop_name == "{DAV:}getetag" and full_path.is_file():
etag = f'"{hashlib.sha1(str(full_path.stat().st_mtime).encode()).hexdigest()}"'
etag = (
f'"{hashlib.sha1(str(full_path.stat().st_mtime).encode()).hexdigest()}"'
)
etree.SubElement(prop_elem, "{DAV:}getetag").text = etag
async def handle_propfind(self, request):
@ -520,9 +614,21 @@ class WebdavApplication(aiohttp.web.Application):
pass
nsmap = {"D": "DAV:"}
response_xml = etree.Element("{DAV:}multistatus", nsmap=nsmap)
await self.create_propfind_node(request, response_xml, abs_path, depth, requested_props, propname_only, all_props)
xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode()
return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml")
await self.create_propfind_node(
request,
response_xml,
abs_path,
depth,
requested_props,
propname_only,
all_props,
)
xml_output = etree.tostring(
response_xml, encoding="utf-8", xml_declaration=True
).decode()
return aiohttp.web.Response(
status=207, text=xml_output, content_type="application/xml"
)
async def handle_proppatch(self, request):
if not await self.authenticate(request):
@ -535,7 +641,11 @@ class WebdavApplication(aiohttp.web.Application):
return aiohttp.web.Response(status=404, text="Resource not found")
lock = await self.is_locked(abs_path, request)
if lock:
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=423, text="Locked")
body = await request.read()
@ -549,7 +659,11 @@ class WebdavApplication(aiohttp.web.Application):
response = etree.SubElement(response_xml, "{DAV:}response")
href = etree.SubElement(response, "{DAV:}href")
relative_path = abs_path.relative_to(request["home"]).as_posix()
href_path = f"{self.relative_url}/{relative_path}".strip(".").replace("./", "/").replace("//", "/")
href_path = (
f"{self.relative_url}/{relative_path}".strip(".")
.replace("./", "/")
.replace("//", "/")
)
href.text = href_path
set_elem = root.find(".//{DAV:}set")
if set_elem is not None:
@ -561,14 +675,14 @@ class WebdavApplication(aiohttp.web.Application):
if child.tag.startswith("{"):
ns_end = child.tag.find("}")
ns = child.tag[1:ns_end]
local_name = child.tag[ns_end+1:]
local_name = child.tag[ns_end + 1 :]
else:
ns = "DAV:"
local_name = child.tag
prop_name = f"{{{ns}}}{local_name}"
prop_value = child.text or ""
properties[prop_name] = prop_value
elem = etree.SubElement(prop, child.tag, nsmap=child.nsmap)
etree.SubElement(prop, child.tag, nsmap=child.nsmap)
etree.SubElement(propstat, "{DAV:}status").text = "HTTP/1.1 200 OK"
remove_elem = root.find(".//{DAV:}remove")
if remove_elem is not None:
@ -580,7 +694,7 @@ class WebdavApplication(aiohttp.web.Application):
if child.tag.startswith("{"):
ns_end = child.tag.find("}")
ns = child.tag[1:ns_end]
local_name = child.tag[ns_end+1:]
local_name = child.tag[ns_end + 1 :]
else:
ns = "DAV:"
local_name = child.tag
@ -590,14 +704,18 @@ class WebdavApplication(aiohttp.web.Application):
status = "HTTP/1.1 200 OK"
else:
status = "HTTP/1.1 404 Not Found"
elem = etree.SubElement(prop, child.tag, nsmap=child.nsmap)
etree.SubElement(prop, child.tag, nsmap=child.nsmap)
etree.SubElement(propstat, "{DAV:}status").text = status
if properties:
await self.save_properties(abs_path, properties)
else:
await self.delete_properties_file(abs_path)
xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode()
return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml")
xml_output = etree.tostring(
response_xml, encoding="utf-8", xml_declaration=True
).decode()
return aiohttp.web.Response(
status=207, text=xml_output, content_type="application/xml"
)
except Exception as e:
logging.error(f"PROPPATCH error: {e}")
return aiohttp.web.Response(status=400, text="Bad Request")
@ -639,11 +757,19 @@ class WebdavApplication(aiohttp.web.Application):
scope = "shared"
locktype = "write"
owner_elem = lockinfo.find("{DAV:}owner")
owner_xml = etree.tostring(owner_elem, encoding="unicode") if owner_elem is not None else ""
owner_xml = (
etree.tostring(owner_elem, encoding="unicode")
if owner_elem is not None
else ""
)
covering_lock = await self.is_locked(abs_path, request)
if covering_lock:
return aiohttp.web.Response(status=423, text="Locked")
if abs_path.is_dir() and depth == "infinity" and await self.has_descendant_locks(resource, True):
if (
abs_path.is_dir()
and depth == "infinity"
and await self.has_descendant_locks(resource, True)
):
return aiohttp.web.Response(status=423, text="Locked")
token = str(uuid.uuid4())
lock_dict = {
@ -653,12 +779,17 @@ class WebdavApplication(aiohttp.web.Application):
"owner": owner_xml,
"depth": depth,
"timeout": timeout,
"created": datetime.datetime.utcnow()
"created": datetime.datetime.utcnow(),
}
self.locks[resource] = lock_dict
xml = await self.generate_lock_response(lock_dict)
headers = {"Lock-Token": f"<opaquelocktoken:{token}>"}
return aiohttp.web.Response(status=200, text=xml, content_type="application/xml", headers=headers)
return aiohttp.web.Response(
status=200,
text=xml,
content_type="application/xml",
headers=headers,
)
except Exception as e:
logging.error(f"LOCK error: {e}")
return aiohttp.web.Response(status=400, text="Bad Request")
@ -669,14 +800,20 @@ class WebdavApplication(aiohttp.web.Application):
if self.is_lock_expired(lock):
del self.locks[resource]
return aiohttp.web.Response(status=412, text="Precondition Failed")
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=409, text="Conflict")
lock["created"] = datetime.datetime.utcnow()
lock["timeout"] = timeout
xml = await self.generate_lock_response(lock)
headers = {"Lock-Token": f"<opaquelocktoken:{lock['token']}>"}
return aiohttp.web.Response(status=200, text=xml, content_type="application/xml", headers=headers)
return aiohttp.web.Response(
status=200, text=xml, content_type="application/xml", headers=headers
)
async def handle_unlock(self, request):
if not await self.authenticate(request):
@ -690,7 +827,11 @@ class WebdavApplication(aiohttp.web.Application):
if self.is_lock_expired(lock):
del self.locks[resource]
return aiohttp.web.Response(status=409, text="Conflict")
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=400, text="Invalid Lock Token")
del self.locks[resource]
@ -705,14 +846,22 @@ class WebdavApplication(aiohttp.web.Application):
etree.SubElement(locktype, f"{{DAV:}}{lock_dict['type']}")
lockscope = etree.SubElement(activelock, "{DAV:}lockscope")
etree.SubElement(lockscope, f"{{DAV:}}{lock_dict['scope']}")
etree.SubElement(activelock, "{DAV:}depth").text = lock_dict['depth'].capitalize()
if lock_dict['owner']:
owner = etree.fromstring(lock_dict['owner'])
etree.SubElement(activelock, "{DAV:}depth").text = lock_dict[
"depth"
].capitalize()
if lock_dict["owner"]:
owner = etree.fromstring(lock_dict["owner"])
activelock.append(owner)
timeout_str = "Infinite" if lock_dict['timeout'] is None else f"Second-{lock_dict['timeout']}"
timeout_str = (
"Infinite"
if lock_dict["timeout"] is None
else f"Second-{lock_dict['timeout']}"
)
etree.SubElement(activelock, "{DAV:}timeout").text = timeout_str
locktoken = etree.SubElement(activelock, "{DAV:}locktoken")
etree.SubElement(locktoken, "{DAV:}href").text = f"opaquelocktoken:{lock_dict['token']}"
etree.SubElement(locktoken, "{DAV:}href").text = (
f"opaquelocktoken:{lock_dict['token']}"
)
return etree.tostring(prop, encoding="utf-8", xml_declaration=True).decode()
def get_last_modified(self, path):
@ -731,7 +880,7 @@ class WebdavApplication(aiohttp.web.Application):
async def load_properties(self, resource_path):
props_file = self.get_props_file_path(resource_path)
if props_file.exists():
async with aiofiles.open(props_file, "r") as f:
async with aiofiles.open(props_file) as f:
content = await f.read()
return json.loads(content)
return {}
@ -773,7 +922,11 @@ class WebdavApplication(aiohttp.web.Application):
response = etree.SubElement(response_xml, "{DAV:}response")
href = etree.SubElement(response, "{DAV:}href")
relative_path = abs_path.relative_to(request["home"]).as_posix()
href_path = f"{self.relative_url}/{relative_path}".strip(".").replace("./", "/").replace("//", "/")
href_path = (
f"{self.relative_url}/{relative_path}".strip(".")
.replace("./", "/")
.replace("//", "/")
)
href.text = href_path
propstat = etree.SubElement(response, "{DAV:}propstat")
prop = etree.SubElement(propstat, "{DAV:}prop")
@ -787,8 +940,12 @@ class WebdavApplication(aiohttp.web.Application):
elem = etree.SubElement(prop, prop_name)
elem.text = str(prop_value)
etree.SubElement(propstat, "{DAV:}status").text = "HTTP/1.1 200 OK"
xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode()
return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml")
xml_output = etree.tostring(
response_xml, encoding="utf-8", xml_declaration=True
).decode()
return aiohttp.web.Response(
status=207, text=xml_output, content_type="application/xml"
)
async def handle_propset(self, request):
if not await self.authenticate(request):
@ -801,7 +958,11 @@ class WebdavApplication(aiohttp.web.Application):
return aiohttp.web.Response(status=404, text="Resource not found")
lock = await self.is_locked(abs_path, request)
if lock:
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=423, text="Locked")
body = await request.read()
@ -824,17 +985,25 @@ class WebdavApplication(aiohttp.web.Application):
response = etree.SubElement(response_xml, "{DAV:}response")
href = etree.SubElement(response, "{DAV:}href")
relative_path = abs_path.relative_to(request["home"]).as_posix()
href_path = f"{self.relative_url}/{relative_path}".strip(".").replace("./", "/").replace("//", "/")
href_path = (
f"{self.relative_url}/{relative_path}".strip(".")
.replace("./", "/")
.replace("//", "/")
)
href.text = href_path
propstat = etree.SubElement(response, "{DAV:}propstat")
prop = etree.SubElement(propstat, "{DAV:}prop")
for child in prop_elem:
ns = child.nsmap.get(None, "DAV:")
prop_name = f"{{{ns}}}{child.tag.split('}')[-1]}"
elem = etree.SubElement(prop, prop_name)
etree.SubElement(prop, prop_name)
etree.SubElement(propstat, "{DAV:}status").text = "HTTP/1.1 200 OK"
xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode()
return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml")
xml_output = etree.tostring(
response_xml, encoding="utf-8", xml_declaration=True
).decode()
return aiohttp.web.Response(
status=207, text=xml_output, content_type="application/xml"
)
except Exception as e:
logging.error(f"PROPSET error: {e}")
return aiohttp.web.Response(status=400, text="Bad Request")
@ -850,7 +1019,11 @@ class WebdavApplication(aiohttp.web.Application):
return aiohttp.web.Response(status=404, text="Resource not found")
lock = await self.is_locked(abs_path, request)
if lock:
submitted = request.headers.get("Lock-Token", "").strip(" <>").replace("opaquelocktoken:", "")
submitted = (
request.headers.get("Lock-Token", "")
.strip(" <>")
.replace("opaquelocktoken:", "")
)
if submitted != lock["token"]:
return aiohttp.web.Response(status=423, text="Locked")
body = await request.read()
@ -877,17 +1050,25 @@ class WebdavApplication(aiohttp.web.Application):
response = etree.SubElement(response_xml, "{DAV:}response")
href = etree.SubElement(response, "{DAV:}href")
relative_path = abs_path.relative_to(request["home"]).as_posix()
href_path = f"{self.relative_url}/{relative_path}".strip(".").replace("./", "/").replace("//", "/")
href_path = (
f"{self.relative_url}/{relative_path}".strip(".")
.replace("./", "/")
.replace("//", "/")
)
href.text = href_path
propstat = etree.SubElement(response, "{DAV:}propstat")
prop = etree.SubElement(propstat, "{DAV:}prop")
for child in prop_elem:
ns = child.nsmap.get(None, "DAV:")
prop_name = f"{{{ns}}}{child.tag.split('}')[-1]}"
elem = etree.SubElement(prop, prop_name)
etree.SubElement(prop, prop_name)
etree.SubElement(propstat, "{DAV:}status").text = "HTTP/1.1 200 OK"
xml_output = etree.tostring(response_xml, encoding="utf-8", xml_declaration=True).decode()
return aiohttp.web.Response(status=207, text=xml_output, content_type="application/xml")
xml_output = etree.tostring(
response_xml, encoding="utf-8", xml_declaration=True
).decode()
return aiohttp.web.Response(
status=207, text=xml_output, content_type="application/xml"
)
except Exception as e:
logging.error(f"PROPDEL error: {e}")
return aiohttp.web.Response(status=400, text="Bad Request")

View File

@ -1,10 +1,8 @@
#!/usr/bin/env python3
import os
import pathlib
import sys
import pathlib
import os
os.chdir("/root")
@ -16,7 +14,9 @@ if not pathlib.Path(".welcome.txt").exists():
os.system("python3 -m venv --prompt '' .venv")
os.system("cp -r /opt/bootstrap/root/.* /root")
os.system("cp /opt/bootstrap/.welcome.txt /root/.welcome.txt")
pathlib.Path(".bashrc").write_text(pathlib.Path(".bashrc").read_text() + "\n" + "source .venv/bin/activate")
pathlib.Path(".bashrc").write_text(
pathlib.Path(".bashrc").read_text() + "\n" + "source .venv/bin/activate"
)
os.environ["SNEK"] = "1"
if pathlib.Path(".welcome.txt").exists():