diff --git a/src/snek/__init__.py b/src/snek/__init__.py
index ba51bb5..9dff4d0 100644
--- a/src/snek/__init__.py
+++ b/src/snek/__init__.py
@@ -26,9 +26,9 @@ Description: Utility to load environment variables from a .env file.
"""
import os
-from typing import Optional
-def load_env(file_path: str = '.env') -> None:
+
+def load_env(file_path: str = ".env") -> None:
"""
Loads environment variables from a specified file into the current process environment.
@@ -43,26 +43,26 @@ def load_env(file_path: str = '.env') -> None:
IOError: If an I/O error occurs during file reading.
"""
try:
- with open(file_path, 'r') as env_file:
+ with open(file_path) as env_file:
for line in env_file:
line = line.strip()
# Skip empty lines and comments
- if not line or line.startswith('#'):
+ if not line or line.startswith("#"):
continue
# Skip lines without '='
- if '=' not in line:
+ if "=" not in line:
continue
# Split into key and value at the first '='
- key, value = line.split('=', 1)
+ key, value = line.split("=", 1)
# Set environment variable
os.environ[key.strip()] = value.strip()
except FileNotFoundError:
raise FileNotFoundError(f"Environment file '{file_path}' not found.")
- except IOError as e:
- raise IOError(f"Error reading environment file '{file_path}': {e}")
+ except OSError as e:
+ raise OSError(f"Error reading environment file '{file_path}': {e}")
try:
load_env()
-except Exception as e:
+except Exception:
pass
diff --git a/src/snek/__main__.py b/src/snek/__main__.py
index 53f9751..a9af4ed 100644
--- a/src/snek/__main__.py
+++ b/src/snek/__main__.py
@@ -1,20 +1,20 @@
+import asyncio
import pathlib
import shutil
import sqlite3
-import asyncio
+
import click
from aiohttp import web
-from IPython import start_ipython
-from snek.shell import Shell
+
from snek.app import Application
-
-
+from snek.shell import Shell
@click.group()
def cli():
pass
+
@cli.command()
def export():
@@ -30,6 +30,7 @@ def export():
user = await app.services.user.get(uid=message["user_uid"])
message["user"] = user and user["username"] or None
return (message["user"] or "") + ": " + (message["text"] or "")
+
async def run():
result = []
@@ -49,6 +50,7 @@ def export():
with open("dump.txt", "w") as f:
f.write("\n\n".join(result))
print("Dump written to dump.json")
+
asyncio.run(run())
@@ -57,16 +59,20 @@ def statistics():
async def run():
app = Application(db_path="sqlite:///snek.db")
app.services.statistics.database()
+
asyncio.run(run())
+
@cli.command()
def maintenance():
async def run():
app = Application(db_path="sqlite:///snek.db")
await app.services.container.maintenance()
await app.services.channel_message.maintenance()
+
asyncio.run(run())
+
@cli.command()
@click.option(
"--db_path", default="snek.db", help="Database to initialize if not exists."
@@ -123,9 +129,11 @@ def serve(port, host, db_path):
def shell(db_path):
Shell(db_path).run()
+
def main():
try:
import sentry_sdk
+
sentry_sdk.init("https://ab6147c2f3354c819768c7e89455557b@gt.molodetz.nl/1")
except ImportError:
print("Could not import sentry_sdk")
diff --git a/src/snek/app.py b/src/snek/app.py
index a551e41..c90db06 100644
--- a/src/snek/app.py
+++ b/src/snek/app.py
@@ -3,13 +3,13 @@ import logging
import pathlib
import ssl
import uuid
-import signal
-from datetime import datetime
from contextlib import asynccontextmanager
+from datetime import datetime
from snek import snode
-from snek.view.threads import ThreadsView
from snek.system.ads import AsyncDataSet
+from snek.view.threads import ThreadsView
+
logging.basicConfig(level=logging.DEBUG)
from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address
@@ -25,16 +25,21 @@ from aiohttp_session.cookie_storage import EncryptedCookieStorage
from app.app import Application as BaseApplication
from jinja2 import FileSystemLoader
+from snek.forum import setup_forum
from snek.mapper import get_mappers
from snek.service import get_services
from snek.sgit import GitApplication
from snek.sssh import start_ssh_server
from snek.system import http
from snek.system.cache import Cache
-from snek.system.stats import middleware as stats_middleware, create_stats_structure, stats_handler
from snek.system.markdown import MarkdownExtension
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
from snek.system.profiler import profiler_handler
+from snek.system.stats import (
+ create_stats_structure,
+ middleware as stats_middleware,
+ stats_handler,
+)
from snek.system.template import (
EmojiExtension,
LinkifyExtension,
@@ -43,10 +48,15 @@ from snek.system.template import (
)
from snek.view.about import AboutHTMLView, AboutMDView
from snek.view.avatar import AvatarView
-from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView
+from snek.view.channel import (
+ ChannelAttachmentUploadView,
+ ChannelAttachmentView,
+ ChannelDriveApiView,
+ ChannelView,
+)
+from snek.view.container import ContainerView
from snek.view.docs import DocsHTMLView, DocsMDView
from snek.view.drive import DriveApiView, DriveView
-from snek.view.channel import ChannelDriveApiView
from snek.view.index import IndexView
from snek.view.login import LoginView
from snek.view.logout import LogoutView
@@ -55,7 +65,6 @@ from snek.view.register import RegisterView
from snek.view.repository import RepositoryView
from snek.view.rpc import RPCView
from snek.view.search_user import SearchUserView
-from snek.view.container import ContainerView
from snek.view.settings.containers import (
ContainersCreateView,
ContainersDeleteView,
@@ -77,7 +86,6 @@ from snek.view.upload import UploadView
from snek.view.user import UserView
from snek.view.web import WebView
from snek.webdav import WebdavApplication
-from snek.forum import setup_forum
SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34"
from snek.system.template import whitelist_attributes
@@ -175,13 +183,12 @@ class Application(BaseApplication):
self.on_startup.append(self.prepare_asyncio)
self.on_startup.append(self.start_user_availability_service)
self.on_startup.append(self.start_ssh_server)
- #self.on_startup.append(self.prepare_database)
+ # self.on_startup.append(self.prepare_database)
async def prepare_stats(self, app):
- app['stats'] = create_stats_structure()
+ app["stats"] = create_stats_structure()
print("Stats prepared", flush=True)
-
@property
def uptime_seconds(self):
return (datetime.now() - self.time_start).total_seconds()
@@ -225,10 +232,10 @@ class Application(BaseApplication):
# app.loop = asyncio.get_running_loop()
app.executor = ThreadPoolExecutor(max_workers=200)
app.loop.set_default_executor(self.executor)
- #for sig in (signal.SIGINT, signal.SIGTERM):
- #app.loop.add_signal_handler(
- # sig, lambda: asyncio.create_task(self.services.container.shutdown())
- #)
+ # for sig in (signal.SIGINT, signal.SIGTERM):
+ # app.loop.add_signal_handler(
+ # sig, lambda: asyncio.create_task(self.services.container.shutdown())
+ # )
async def create_task(self, task):
await self.tasks.put(task)
@@ -244,7 +251,6 @@ class Application(BaseApplication):
print(ex)
self.db.commit()
-
async def prepare_database(self, app):
await self.db.query_raw("PRAGMA journal_mode=WAL")
await self.db.query_raw("PRAGMA syncnorm=off")
@@ -291,9 +297,7 @@ class Application(BaseApplication):
self.router.add_view(
"/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
)
- self.router.add_view(
- "/channel/{channel_uid}/drive.json", ChannelDriveApiView
- )
+ self.router.add_view("/channel/{channel_uid}/drive.json", ChannelDriveApiView)
self.router.add_view(
"/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView
)
@@ -414,9 +418,9 @@ class Application(BaseApplication):
)
try:
- context["nonce"] = request['csp_nonce']
+ context["nonce"] = request["csp_nonce"]
except:
- context['nonce'] = '?'
+ context["nonce"] = "?"
rendered = await super().render_template(template, request, context)
@@ -466,14 +470,14 @@ class Application(BaseApplication):
@asynccontextmanager
async def no_save(self):
- stats = {
- 'count': 0
- }
+ stats = {"count": 0}
+
async def patched_save(*args, **kwargs):
await self.cache.set(args[0]["uid"], args[0])
- stats['count'] = stats['count'] + 1
+ stats["count"] = stats["count"] + 1
print(f"save is ignored {stats['count']} times")
return args[0]
+
save_original = self.services.channel_message.mapper.save
self.services.channel_message.mapper.save = patched_save
raised_exception = None
@@ -486,6 +490,7 @@ class Application(BaseApplication):
if raised_exception:
raise raised_exception
+
app = Application(db_path="sqlite:///snek.db")
diff --git a/src/snek/forum.py b/src/snek/forum.py
index a0b0d2d..f1a0e78 100644
--- a/src/snek/forum.py
+++ b/src/snek/forum.py
@@ -1,5 +1,6 @@
# forum_app.py
import aiohttp.web
+
from snek.view.forum import ForumIndexView, ForumView, ForumWebSocketView
@@ -7,13 +8,13 @@ class ForumApplication(aiohttp.web.Application):
def __init__(self, parent, *args, **kwargs):
super().__init__(*args, **kwargs)
self.parent = parent
- self.render_template = self.parent.render_template
+ self.render_template = self.parent.render_template
# Set up routes
self.setup_routes()
-
+
# Set up notification listeners
self.setup_notifications()
-
+
@property
def db(self):
return self.parent.db
@@ -21,86 +22,84 @@ class ForumApplication(aiohttp.web.Application):
@property
def services(self):
return self.parent.services
-
+
def setup_routes(self):
"""Set up all forum routes"""
# API routes
self.router.add_view("/index.html", ForumIndexView)
self.router.add_route("GET", "/api/forums", ForumView.get_forums)
self.router.add_route("GET", "/api/forums/{slug}", ForumView.get_forum)
- self.router.add_route("POST", "/api/forums/{slug}/threads", ForumView.create_thread)
+ self.router.add_route(
+ "POST", "/api/forums/{slug}/threads", ForumView.create_thread
+ )
self.router.add_route("GET", "/api/threads/{thread_slug}", ForumView.get_thread)
- self.router.add_route("POST", "/api/threads/{thread_uid}/posts", ForumView.create_post)
+ self.router.add_route(
+ "POST", "/api/threads/{thread_uid}/posts", ForumView.create_post
+ )
self.router.add_route("PUT", "/api/posts/{post_uid}", ForumView.edit_post)
self.router.add_route("DELETE", "/api/posts/{post_uid}", ForumView.delete_post)
- self.router.add_route("POST", "/api/posts/{post_uid}/like", ForumView.toggle_like)
- self.router.add_route("POST", "/api/threads/{thread_uid}/pin", ForumView.toggle_pin)
- self.router.add_route("POST", "/api/threads/{thread_uid}/lock", ForumView.toggle_lock)
-
+ self.router.add_route(
+ "POST", "/api/posts/{post_uid}/like", ForumView.toggle_like
+ )
+ self.router.add_route(
+ "POST", "/api/threads/{thread_uid}/pin", ForumView.toggle_pin
+ )
+ self.router.add_route(
+ "POST", "/api/threads/{thread_uid}/lock", ForumView.toggle_lock
+ )
+
# WebSocket route
self.router.add_view("/ws", ForumWebSocketView)
-
+
# Static HTML route
self.router.add_route("GET", "/{path:.*}", self.serve_forum_html)
-
+
def setup_notifications(self):
"""Set up notification listeners for WebSocket broadcasting"""
# Forum notifications
- self.services.forum.add_notification_listener("forum_created", self.on_forum_event)
-
+ self.services.forum.add_notification_listener(
+ "forum_created", self.on_forum_event
+ )
+
# Thread notifications
- self.services.thread.add_notification_listener("thread_created", self.on_thread_event)
-
+ self.services.thread.add_notification_listener(
+ "thread_created", self.on_thread_event
+ )
+
# Post notifications
self.services.post.add_notification_listener("post_created", self.on_post_event)
self.services.post.add_notification_listener("post_edited", self.on_post_event)
self.services.post.add_notification_listener("post_deleted", self.on_post_event)
-
+
# Like notifications
- self.services.post_like.add_notification_listener("post_liked", self.on_like_event)
- self.services.post_like.add_notification_listener("post_unliked", self.on_like_event)
-
+ self.services.post_like.add_notification_listener(
+ "post_liked", self.on_like_event
+ )
+ self.services.post_like.add_notification_listener(
+ "post_unliked", self.on_like_event
+ )
+
async def on_forum_event(self, event_type, data):
"""Handle forum events"""
await ForumWebSocketView.broadcast_update(self, event_type, data)
-
+
async def on_thread_event(self, event_type, data):
"""Handle thread events"""
await ForumWebSocketView.broadcast_update(self, event_type, data)
-
+
async def on_post_event(self, event_type, data):
"""Handle post events"""
await ForumWebSocketView.broadcast_update(self, event_type, data)
-
+
async def on_like_event(self, event_type, data):
"""Handle like events"""
await ForumWebSocketView.broadcast_update(self, event_type, data)
-
+
async def serve_forum_html(self, request):
"""Serve the forum HTML with the web component"""
- html = """
-
-
-
-
- Forum
-
-
-
-
-
-
-"""
return await self.parent.render_template("forum.html", request)
-
-
- #return aiohttp.web.Response(text=html, content_type="text/html")
+
+ # return aiohttp.web.Response(text=html, content_type="text/html")
# Integration with main app
diff --git a/src/snek/mapper/__init__.py b/src/snek/mapper/__init__.py
index be01022..3bce17f 100644
--- a/src/snek/mapper/__init__.py
+++ b/src/snek/mapper/__init__.py
@@ -7,12 +7,12 @@ from snek.mapper.channel_message import ChannelMessageMapper
from snek.mapper.container import ContainerMapper
from snek.mapper.drive import DriveMapper
from snek.mapper.drive_item import DriveItemMapper
+from snek.mapper.forum import ForumMapper, PostLikeMapper, PostMapper, ThreadMapper
from snek.mapper.notification import NotificationMapper
from snek.mapper.push import PushMapper
from snek.mapper.repository import RepositoryMapper
from snek.mapper.user import UserMapper
from snek.mapper.user_property import UserPropertyMapper
-from snek.mapper.forum import ForumMapper, ThreadMapper, PostMapper, PostLikeMapper
from snek.system.object import Object
@@ -42,4 +42,3 @@ def get_mappers(app=None):
def get_mapper(name, app=None):
return get_mappers(app=app)[name]
-
diff --git a/src/snek/mapper/forum.py b/src/snek/mapper/forum.py
index ec258a2..37009f6 100644
--- a/src/snek/mapper/forum.py
+++ b/src/snek/mapper/forum.py
@@ -1,5 +1,5 @@
# mapper/forum.py
-from snek.model.forum import ForumModel, ThreadModel, PostModel, PostLikeModel
+from snek.model.forum import ForumModel, PostLikeModel, PostModel, ThreadModel
from snek.system.mapper import BaseMapper
diff --git a/src/snek/model/__init__.py b/src/snek/model/__init__.py
index 13ba5b2..1ee16db 100644
--- a/src/snek/model/__init__.py
+++ b/src/snek/model/__init__.py
@@ -9,12 +9,12 @@ from snek.model.channel_message import ChannelMessageModel
from snek.model.container import Container
from snek.model.drive import DriveModel
from snek.model.drive_item import DriveItemModel
+from snek.model.forum import ForumModel, PostLikeModel, PostModel, ThreadModel
from snek.model.notification import NotificationModel
from snek.model.push_registration import PushRegistrationModel
from snek.model.repository import RepositoryModel
from snek.model.user import UserModel
from snek.model.user_property import UserPropertyModel
-from snek.model.forum import ForumModel, ThreadModel, PostModel, PostLikeModel
from snek.system.object import Object
@@ -44,5 +44,3 @@ def get_models():
def get_model(name):
return get_models()[name]
-
-
diff --git a/src/snek/model/channel.py b/src/snek/model/channel.py
index 2c17478..b5cb35d 100644
--- a/src/snek/model/channel.py
+++ b/src/snek/model/channel.py
@@ -12,10 +12,10 @@ class ChannelModel(BaseModel):
index = ModelField(name="index", required=True, kind=int, value=1000)
last_message_on = ModelField(name="last_message_on", required=False, kind=str)
history_start = ModelField(name="history_start", required=False, kind=str)
-
+
@property
def is_dm(self):
- return 'dm' in self['tag'].lower()
+ return "dm" in self["tag"].lower()
async def get_last_message(self) -> ChannelMessageModel:
history_start_filter = ""
@@ -23,7 +23,9 @@ class ChannelModel(BaseModel):
history_start_filter = f" AND created_at > '{self['history_start']}' "
try:
async for model in self.app.services.channel_message.query(
- "SELECT uid FROM channel_message WHERE channel_uid=:channel_uid" + history_start_filter + " ORDER BY created_at DESC LIMIT 1",
+ "SELECT uid FROM channel_message WHERE channel_uid=:channel_uid"
+ + history_start_filter
+ + " ORDER BY created_at DESC LIMIT 1",
{"channel_uid": self["uid"]},
):
diff --git a/src/snek/model/forum.py b/src/snek/model/forum.py
index 9c037a1..80e21a3 100644
--- a/src/snek/model/forum.py
+++ b/src/snek/model/forum.py
@@ -4,9 +4,16 @@ from snek.system.model import BaseModel, ModelField
class ForumModel(BaseModel):
"""Forum categories"""
- name = ModelField(name="name", required=True, kind=str, min_length=3, max_length=100)
- description = ModelField(name="description", required=False, kind=str, max_length=500)
- slug = ModelField(name="slug", required=True, kind=str, regex=r"^[a-z0-9-]+$", unique=True)
+
+ name = ModelField(
+ name="name", required=True, kind=str, min_length=3, max_length=100
+ )
+ description = ModelField(
+ name="description", required=False, kind=str, max_length=500
+ )
+ slug = ModelField(
+ name="slug", required=True, kind=str, regex=r"^[a-z0-9-]+$", unique=True
+ )
icon = ModelField(name="icon", required=False, kind=str)
position = ModelField(name="position", required=True, kind=int, value=0)
is_active = ModelField(name="is_active", required=True, kind=bool, value=True)
@@ -18,12 +25,12 @@ class ForumModel(BaseModel):
async def get_threads(self, limit=50, offset=0):
async for thread in self.app.services.thread.find(
- forum_uid=self["uid"],
+ forum_uid=self["uid"],
deleted_at=None,
_limit=limit,
_offset=offset,
- order_by="-last_post_at"
- #order_by="is_pinned DESC, last_post_at DESC"
+ order_by="-last_post_at",
+ # order_by="is_pinned DESC, last_post_at DESC"
):
yield thread
@@ -39,8 +46,11 @@ class ForumModel(BaseModel):
# models/thread.py
class ThreadModel(BaseModel):
"""Forum threads"""
+
forum_uid = ModelField(name="forum_uid", required=True, kind=str)
- title = ModelField(name="title", required=True, kind=str, min_length=5, max_length=200)
+ title = ModelField(
+ name="title", required=True, kind=str, min_length=5, max_length=200
+ )
slug = ModelField(name="slug", required=True, kind=str, regex=r"^[a-z0-9-]+$")
created_by_uid = ModelField(name="created_by_uid", required=True, kind=str)
is_pinned = ModelField(name="is_pinned", required=True, kind=bool, value=False)
@@ -56,7 +66,7 @@ class ThreadModel(BaseModel):
deleted_at=None,
_limit=limit,
_offset=offset,
- order_by="created_at"
+ order_by="created_at",
):
yield post
@@ -70,30 +80,35 @@ class ThreadModel(BaseModel):
await self.save()
-# models/post.py
+# models/post.py
class PostModel(BaseModel):
"""Forum posts"""
+
thread_uid = ModelField(name="thread_uid", required=True, kind=str)
forum_uid = ModelField(name="forum_uid", required=True, kind=str)
- content = ModelField(name="content", required=True, kind=str, min_length=1, max_length=10000)
+ content = ModelField(
+ name="content", required=True, kind=str, min_length=1, max_length=10000
+ )
created_by_uid = ModelField(name="created_by_uid", required=True, kind=str)
edited_at = ModelField(name="edited_at", required=False, kind=str)
edited_by_uid = ModelField(name="edited_by_uid", required=False, kind=str)
- is_first_post = ModelField(name="is_first_post", required=True, kind=bool, value=False)
+ is_first_post = ModelField(
+ name="is_first_post", required=True, kind=bool, value=False
+ )
like_count = ModelField(name="like_count", required=True, kind=int, value=0)
-
+
async def get_author(self):
return await self.app.services.user.get(uid=self["created_by_uid"])
async def is_liked_by(self, user_uid):
return await self.app.services.post_like.exists(
- post_uid=self["uid"],
- user_uid=user_uid
+ post_uid=self["uid"], user_uid=user_uid
)
# models/post_like.py
class PostLikeModel(BaseModel):
"""Post likes"""
+
post_uid = ModelField(name="post_uid", required=True, kind=str)
user_uid = ModelField(name="user_uid", required=True, kind=str)
diff --git a/src/snek/service/__init__.py b/src/snek/service/__init__.py
index 6e9e58f..ae44e15 100644
--- a/src/snek/service/__init__.py
+++ b/src/snek/service/__init__.py
@@ -9,37 +9,43 @@ from snek.service.container import ContainerService
from snek.service.db import DBService
from snek.service.drive import DriveService
from snek.service.drive_item import DriveItemService
+from snek.service.forum import ForumService, PostLikeService, PostService, ThreadService
from snek.service.notification import NotificationService
from snek.service.push import PushService
from snek.service.repository import RepositoryService
from snek.service.socket import SocketService
+from snek.service.statistics import StatisticsService
from snek.service.user import UserService
from snek.service.user_property import UserPropertyService
from snek.service.util import UtilService
from snek.system.object import Object
-from snek.service.statistics import StatisticsService
-from snek.service.forum import ForumService, ThreadService, PostService, PostLikeService
+
_service_registry = {}
+
def register_service(name, service_cls):
_service_registry[name] = service_cls
+
register = register_service
+
@functools.cache
def get_services(app):
- result = Object(
+ result = Object(
**{
name: service_cls(app=app)
for name, service_cls in _service_registry.items()
}
)
result.register = register_service
- return result
+ return result
+
def get_service(name, app=None):
return get_services(app=app)[name]
+
# Registering all services
register_service("user", UserService)
register_service("channel_member", ChannelMemberService)
@@ -62,4 +68,3 @@ register_service("forum", ForumService)
register_service("thread", ThreadService)
register_service("post", PostService)
register_service("post_like", PostLikeService)
-
diff --git a/src/snek/service/channel.py b/src/snek/service/channel.py
index 5d15683..5a52fc7 100644
--- a/src/snek/service/channel.py
+++ b/src/snek/service/channel.py
@@ -115,10 +115,9 @@ class ChannelService(BaseService):
async def clear(self, channel_uid):
model = await self.get(uid=channel_uid)
- model['history_from'] = datetime.now()
+ model["history_from"] = datetime.now()
await self.save(model)
-
async def ensure_public_channel(self, created_by_uid):
model = await self.get(is_listed=True, tag="public")
is_moderator = False
diff --git a/src/snek/service/channel_message.py b/src/snek/service/channel_message.py
index 642f4e0..12ab5f3 100644
--- a/src/snek/service/channel_message.py
+++ b/src/snek/service/channel_message.py
@@ -1,6 +1,8 @@
+import time
+
from snek.system.service import BaseService
from snek.system.template import sanitize_html
-import time
+
class ChannelMessageService(BaseService):
mapper_name = "channel_message"
@@ -10,7 +12,6 @@ class ChannelMessageService(BaseService):
self._configured_indexes = False
async def maintenance(self):
- args = {}
for message in self.mapper.db["channel_message"].find():
print(message)
try:
@@ -34,7 +35,6 @@ class ChannelMessageService(BaseService):
time.sleep(0.1)
print(ex, flush=True)
-
while True:
changed = 0
async for message in self.find(is_final=False):
@@ -72,7 +72,7 @@ class ChannelMessageService(BaseService):
try:
template = self.app.jinja2_env.get_template("message.html")
model["html"] = template.render(**context)
- model['html'] = sanitize_html(model['html'])
+ model["html"] = sanitize_html(model["html"])
except Exception as ex:
print(ex, flush=True)
@@ -99,9 +99,9 @@ class ChannelMessageService(BaseService):
if not user:
return {}
- #if not message["html"].startswith("= datetime('now', '-5 minutes')")
+ cursor.execute(
+ "SELECT COUNT(*) FROM user WHERE last_ping >= datetime('now', '-5 minutes')"
+ )
return cursor.fetchone()[0]
# Example usage of new functions
- messages_per_day = get_time_series_stats('channel_message', 'created_at')
- users_created = get_count_created('user', 'created_at')
- channels_created = get_count_created('channel', 'created_at')
+ messages_per_day = get_time_series_stats("channel_message", "created_at")
+ users_created = get_count_created("user", "created_at")
+ channels_created = get_count_created("channel", "created_at")
channels_per_user = get_channels_per_user()
online_users = get_online_users()
diff --git a/src/snek/service/user_property.py b/src/snek/service/user_property.py
index acaf0bf..41e8dfa 100644
--- a/src/snek/service/user_property.py
+++ b/src/snek/service/user_property.py
@@ -8,7 +8,7 @@ class UserPropertyService(BaseService):
async def set(self, user_uid, name, value):
self.mapper.db.upsert(
- "user_property",
+ "user_property",
{
"user_uid": user_uid,
"name": name,
diff --git a/src/snek/shell.py b/src/snek/shell.py
index 2a1546b..de7885f 100644
--- a/src/snek/shell.py
+++ b/src/snek/shell.py
@@ -1,19 +1,16 @@
-from snek.app import Application
from IPython import start_ipython
+from snek.app import Application
+
+
class Shell:
- def __init__(self,db_path):
+ def __init__(self, db_path):
self.app = Application(db_path=f"sqlite:///{db_path}")
-
+
async def maintenance(self):
await self.app.services.container.maintenance()
await self.app.services.channel_message.maintenance()
-
-
def run(self):
- ns = {
- "app": self.app,
- "maintenance": self.maintenance
- }
+ ns = {"app": self.app, "maintenance": self.maintenance}
start_ipython(argv=[], user_ns=ns)
diff --git a/src/snek/system/ads.py b/src/snek/system/ads.py
index 42b64a0..06b2083 100644
--- a/src/snek/system/ads.py
+++ b/src/snek/system/ads.py
@@ -1,28 +1,27 @@
-import re
+import asyncio
+import atexit
import json
-from uuid import uuid4
+import os
+import re
+import socket
+import unittest
+from collections.abc import AsyncGenerator, Iterable
from datetime import datetime, timezone
+from pathlib import Path
from typing import (
Any,
Dict,
- Iterable,
List,
Optional,
- AsyncGenerator,
- Union,
- Tuple,
Set,
+ Tuple,
+ Union,
)
-from pathlib import Path
-import aiosqlite
-import unittest
-from types import SimpleNamespace
-import asyncio
+from uuid import uuid4
+
import aiohttp
+import aiosqlite
from aiohttp import web
-import socket
-import os
-import atexit
class AsyncDataSet:
@@ -103,7 +102,7 @@ class AsyncDataSet:
# No socket file, become server
self._is_server = True
await self._setup_server()
- except Exception as e:
+ except Exception:
# Fallback to client mode
self._is_server = False
await self._setup_client()
@@ -213,7 +212,7 @@ class AsyncDataSet:
data = await request.json()
try:
try:
- _limit = data.pop("_limit",60)
+ _limit = data.pop("_limit", 60)
except:
pass
@@ -751,7 +750,7 @@ class AsyncDataSet:
for k, v in where.items()
]
)
- return " WHERE " + " AND ".join(clauses), list(v for v in vals if v is not None)
+ return " WHERE " + " AND ".join(clauses), [v for v in vals if v is not None]
async def _server_insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
@@ -887,7 +886,6 @@ class AsyncDataSet:
except:
pass
-
where_clause, where_params = self._build_where(where)
order_clause = f" ORDER BY {order_by}" if order_by else ""
extra = (f" LIMIT {_limit}" if _limit else "") + (
@@ -1209,4 +1207,3 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
if __name__ == "__main__":
unittest.main()
-
diff --git a/src/snek/system/docker.py b/src/snek/system/docker.py
index c72620b..92772b0 100644
--- a/src/snek/system/docker.py
+++ b/src/snek/system/docker.py
@@ -1,8 +1,9 @@
-import copy
-import json
-import yaml
import asyncio
-import subprocess
+import copy
+import json
+
+import yaml
+
try:
import pty
except Exception as ex:
@@ -10,26 +11,28 @@ except Exception as ex:
print(ex)
import os
+
class ComposeFileManager:
- def __init__(self, compose_path="docker-compose.yml",event_handler=None):
+ def __init__(self, compose_path="docker-compose.yml", event_handler=None):
self.compose_path = compose_path
self._load()
self.running_instances = {}
- self.event_handler = event_handler
+ self.event_handler = event_handler
async def shutdown(self):
print("Stopping all sessions")
-
- tasks = []
+
+ tasks = []
for name in self.list_instances():
proc = self.running_instances.get(name)
if not proc:
continue
- if proc['proc'].returncode == None:
- print("Stopping",name)
- tasks.append(asyncio.create_task(proc['proc'].stop()))
- print("Stopped",name,"gracefully")
- return tasks
+ if proc["proc"].returncode is None:
+ print("Stopping", name)
+ tasks.append(asyncio.create_task(proc["proc"].stop()))
+ print("Stopped", name, "gracefully")
+ return tasks
+
def _load(self):
try:
with open(self.compose_path) as f:
@@ -47,22 +50,21 @@ class ComposeFileManager:
async def _create_readers(self, container_name):
instance = await self.get_instance(container_name)
if not instance:
- return False
+ return False
proc = self.running_instances.get(container_name)
if not proc:
return False
-
- async def reader(event_handler,stream):
+
+ async def reader(event_handler, stream):
loop = asyncio.get_event_loop()
while True:
- line = await loop.run_in_executor(None,os.read,stream,1024)
+ line = await loop.run_in_executor(None, os.read, stream, 1024)
if not line:
break
- await event_handler(container_name,"stdout",line)
+ await event_handler(container_name, "stdout", line)
await self.stop(container_name)
- asyncio.create_task(reader(self.event_handler,proc['master']))
-
+ asyncio.create_task(reader(self.event_handler, proc["master"]))
def create_instance(
self,
@@ -79,7 +81,7 @@ class ComposeFileManager:
"context": ".",
"dockerfile": "DockerfileUbuntu",
},
- "user":"root",
+ "user": "root",
"working_dir": "/home/retoor/projects/snek",
"environment": [f"SNEK_UID={name}"],
}
@@ -109,9 +111,9 @@ class ComposeFileManager:
instance = self.compose.get("services", {}).get(name)
if not instance:
return None
- instance = json.loads(json.dumps(instance,default=str))
- instance['status'] = await self.get_instance_status(name)
- return instance
+ instance = json.loads(json.dumps(instance, default=str))
+ instance["status"] = await self.get_instance_status(name)
+ return instance
def duplicate_instance(self, name, new_name):
orig = self.get_instance(name)
@@ -135,7 +137,14 @@ class ComposeFileManager:
if name not in self.list_instances():
return "error"
proc = await asyncio.create_subprocess_exec(
- "docker", "compose", "-f", self.compose_path, "ps", "--services", "--filter", f"status=running",
+ "docker",
+ "compose",
+ "-f",
+ self.compose_path,
+ "ps",
+ "--services",
+ "--filter",
+ f"status=running",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
@@ -148,25 +157,30 @@ class ComposeFileManager:
await self.event_handler(name, "stdin", data)
proc = self.running_instances.get(name)
if not proc:
- return False
+ return False
try:
- os.write(proc['master'], data.encode())
+ os.write(proc["master"], data.encode())
return True
except Exception as ex:
print(ex)
await self.stop(name)
- return False
+ return False
async def stop(self, name):
"""Asynchronously stop a container by doing 'docker compose stop [name]'."""
if name not in self.list_instances():
- return False
+ return False
status = await self.get_instance_status(name)
if status != "running":
- return True
-
+ return True
+
proc = await asyncio.create_subprocess_exec(
- "docker", "compose", "-f", self.compose_path, "stop", name,
+ "docker",
+ "compose",
+ "-f",
+ self.compose_path,
+ "stop",
+ name,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
@@ -177,30 +191,41 @@ class ComposeFileManager:
try:
stdout, stderr = await proc.communicate()
except Exception as ex:
- stdout = b''
+ stdout = b""
stderr = str(ex).encode()
- await self.event_handler(name,"stdout",stdout or stderr)
- print("Return code", proc.returncode)
+ await self.event_handler(name, "stdout", stdout or stderr)
+ print("Return code", proc.returncode)
if stdout:
- await self.event_handler(name,"stdout",stdout or b'')
+ await self.event_handler(name, "stdout", stdout or b"")
return stdout and stdout.decode(errors="ignore") or ""
- await self.event_handler(name,"stdout",stderr or b'')
+ await self.event_handler(name, "stdout", stderr or b"")
return stderr and stderr.decode(errors="ignore") or ""
async def start(self, name):
"""Asynchronously start a container by doing 'docker compose up -d [name]'."""
if name not in self.list_instances():
- return False
-
+ return False
+
status = await self.get_instance_status(name)
- if name in self.running_instances and status == "running" and self.running_instances.get(name) and self.running_instances.get(name).get('proc').returncode == None:
- return True
+ if (
+ name in self.running_instances
+ and status == "running"
+ and self.running_instances.get(name)
+ and self.running_instances.get(name).get("proc").returncode is None
+ ):
+ return True
elif name in self.running_instances:
del self.running_instances[name]
-
+
proc = await asyncio.create_subprocess_exec(
- "docker", "compose", "-f", self.compose_path, "up", name, "-d",
+ "docker",
+ "compose",
+ "-f",
+ self.compose_path,
+ "up",
+ name,
+ "-d",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE,
@@ -210,24 +235,29 @@ class ComposeFileManager:
try:
stdout, stderr = await proc.communicate()
except Exception as ex:
- stdout = b''
+ stdout = b""
stderr = str(ex).encode()
if stdout:
print(stdout.decode(errors="ignore"))
if stderr:
print(stderr.decode(errors="ignore"))
- await self.event_handler(name,"stdout",stdout or stderr)
- print("Return code", proc.returncode)
+ await self.event_handler(name, "stdout", stdout or stderr)
+ print("Return code", proc.returncode)
master, slave = pty.openpty()
proc = await asyncio.create_subprocess_exec(
- "docker", "compose", "-f", self.compose_path, "exec", name, "/usr/local/bin/entry",
+ "docker",
+ "compose",
+ "-f",
+ self.compose_path,
+ "exec",
+ name,
+ "/usr/local/bin/entry",
stdin=slave,
stdout=slave,
stderr=slave,
)
- proc = {'proc': proc, 'master': master, 'slave': slave}
+ proc = {"proc": proc, "master": master, "slave": slave}
self.running_instances[name] = proc
await self._create_readers(name)
return True
-
diff --git a/src/snek/system/mapper.py b/src/snek/system/mapper.py
index c4f765e..6ccc19b 100644
--- a/src/snek/system/mapper.py
+++ b/src/snek/system/mapper.py
@@ -1,7 +1,7 @@
DEFAULT_LIMIT = 30
import asyncio
import typing
-import traceback
+
from snek.system.model import BaseModel
@@ -11,7 +11,7 @@ class BaseMapper:
default_limit: int = DEFAULT_LIMIT
table_name: str = None
semaphore = asyncio.Semaphore(1)
-
+
def __init__(self, app):
self.app = app
@@ -27,17 +27,16 @@ class BaseMapper:
async def run_in_executor(self, func, *args, **kwargs):
use_semaphore = kwargs.pop("use_semaphore", False)
-
- def _execute():
+
+ def _execute():
result = func(*args, **kwargs)
if use_semaphore:
self.db.commit()
- return result
-
- return _execute()
- #async with self.semaphore:
- # return await self.loop.run_in_executor(None, _execute)
+ return result
+ return _execute()
+ # async with self.semaphore:
+ # return await self.loop.run_in_executor(None, _execute)
async def new(self):
return self.model_class(mapper=self, app=self.app)
@@ -51,7 +50,7 @@ class BaseMapper:
kwargs["uid"] = uid
if not kwargs.get("deleted_at"):
kwargs["deleted_at"] = None
- #traceback.print_exc()
+ # traceback.print_exc()
record = await self.db.get(self.table_name, kwargs)
if not record:
@@ -65,20 +64,18 @@ class BaseMapper:
async def exists(self, **kwargs):
return await self.db.count(self.table_name, kwargs)
-
- #return await self.run_in_executor(self.table.exists, **kwargs)
+ # return await self.run_in_executor(self.table.exists, **kwargs)
async def count(self, **kwargs) -> int:
- return await self.db.count(self.table_name,kwargs)
-
+ return await self.db.count(self.table_name, kwargs)
async def save(self, model: BaseModel) -> bool:
if not model.record.get("uid"):
raise Exception(f"Attempt to save without uid: {model.record}.")
- model.updated_at.update()
+ model.updated_at.update()
await self.upsert(model)
return model
- #return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
+ # return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
async def find(self, **kwargs) -> typing.AsyncGenerator:
if not kwargs.get("_limit"):
@@ -100,10 +97,12 @@ class BaseMapper:
yield dict(record)
async def update(self, model):
- if not model["deleted_at"] is None:
+ if model["deleted_at"] is not None:
raise Exception("Can't update deleted record.")
model.updated_at.update()
- return await self.db.update(self.table_name, model.record, {"uid": model["uid"]})
+ return await self.db.update(
+ self.table_name, model.record, {"uid": model["uid"]}
+ )
async def upsert(self, model):
model.updated_at.update()
diff --git a/src/snek/system/markdown.py b/src/snek/system/markdown.py
index 3160ab7..0b70367 100644
--- a/src/snek/system/markdown.py
+++ b/src/snek/system/markdown.py
@@ -2,7 +2,7 @@
import re
from types import SimpleNamespace
-from app.cache import time_cache_async, time_cache
+from app.cache import time_cache, time_cache_async
from mistune import HTMLRenderer, Markdown
from mistune.plugins.formatting import strikethrough
from mistune.plugins.spoiler import spoiler
@@ -12,7 +12,6 @@ from pygments.formatters import html
from pygments.lexers import get_lexer_by_name
-
def strip_markdown(md_text):
# Remove code blocks (
md_text = re.sub(r"[\s\S]?```", "", md_text)
@@ -50,7 +49,7 @@ class MarkdownRenderer(HTMLRenderer):
return get_lexer_by_name(lang, stripall=True)
except:
return get_lexer_by_name(default, stripall=True)
-
+
@time_cache(timeout=60 * 60)
def block_code(self, code, lang=None, info=None):
if not lang:
diff --git a/src/snek/system/middleware.py b/src/snek/system/middleware.py
index f3a13d8..fea92ba 100644
--- a/src/snek/system/middleware.py
+++ b/src/snek/system/middleware.py
@@ -3,31 +3,17 @@
# This code provides middleware functions for an aiohttp server to manage and modify CSP, CORS, and authentication headers.
import secrets
+
from aiohttp import web
@web.middleware
async def csp_middleware(request, handler):
nonce = secrets.token_hex(16)
- origin = request.headers.get('Origin')
- csp_policy = (
- "default-src 'self'; "
- f"script-src 'self' {origin} 'nonce-{nonce}'; "
- f"style-src 'self' 'unsafe-inline' {origin} 'nonce-{nonce}'; "
- "img-src *; "
- "connect-src 'self' https://umami.molodetz.nl; "
- "font-src *; "
- "object-src 'none'; "
- "base-uri 'self'; "
- "form-action 'self'; "
- "frame-src 'self'; "
- "worker-src *; "
- "media-src *; "
- "manifest-src 'self';"
- )
- request['csp_nonce'] = nonce
+ request.headers.get("Origin")
+ request["csp_nonce"] = nonce
response = await handler(request)
- #response.headers['Content-Security-Policy'] = csp_policy
+ # response.headers['Content-Security-Policy'] = csp_policy
return response
@@ -37,6 +23,7 @@ async def no_cors_middleware(request, handler):
response.headers.pop("Access-Control-Allow-Origin", None)
return response
+
@web.middleware
async def cors_allow_middleware(request, handler):
response = await handler(request)
@@ -48,6 +35,7 @@ async def cors_allow_middleware(request, handler):
response.headers["Access-Control-Allow-Credentials"] = "true"
return response
+
@web.middleware
async def auth_middleware(request, handler):
request["user"] = None
@@ -57,6 +45,7 @@ async def auth_middleware(request, handler):
)
return await handler(request)
+
@web.middleware
async def cors_middleware(request, handler):
if request.headers.get("Allow"):
diff --git a/src/snek/system/service.py b/src/snek/system/service.py
index 82c5d1c..5934a6e 100644
--- a/src/snek/system/service.py
+++ b/src/snek/system/service.py
@@ -1,5 +1,4 @@
from snek.mapper import get_mapper
-from snek.model.user import UserModel
from snek.system.mapper import BaseMapper
@@ -40,7 +39,7 @@ class BaseService:
yield record
async def get(self, *args, **kwargs):
- if not "deleted_at" in kwargs:
+ if "deleted_at" not in kwargs:
kwargs["deleted_at"] = None
uid = kwargs.get("uid")
if args:
@@ -50,7 +49,7 @@ class BaseService:
if result and result.__class__ == self.mapper.model_class:
return result
kwargs["uid"] = uid
- print(kwargs,"ZZZZZZZ")
+ print(kwargs, "ZZZZZZZ")
result = await self.mapper.get(**kwargs)
if result:
await self.cache.set(result["uid"], result)
diff --git a/src/snek/system/stats.py b/src/snek/system/stats.py
index c2148ff..dd5c692 100644
--- a/src/snek/system/stats.py
+++ b/src/snek/system/stats.py
@@ -1,34 +1,41 @@
-import asyncio
-from aiohttp import web, WSMsgType
-
-from datetime import datetime, timedelta, timezone
-from collections import defaultdict
import html
+from collections import defaultdict
+from datetime import datetime, timedelta, timezone
+
+from aiohttp import WSMsgType, web
+
def create_stats_structure():
"""Creates the nested dictionary structure for storing statistics."""
+
def nested_dd():
return defaultdict(lambda: defaultdict(int))
+
return defaultdict(nested_dd)
+
def get_time_keys(dt: datetime):
"""Generates dictionary keys for different time granularities."""
return {
- "hour": dt.strftime('%Y-%m-%d-%H'),
- "day": dt.strftime('%Y-%m-%d'),
- "week": dt.strftime('%Y-%W'), # Week number, Monday is first day
- "month": dt.strftime('%Y-%m'),
+ "hour": dt.strftime("%Y-%m-%d-%H"),
+ "day": dt.strftime("%Y-%m-%d"),
+ "week": dt.strftime("%Y-%W"), # Week number, Monday is first day
+ "month": dt.strftime("%Y-%m"),
}
+
def update_stats_counters(stats_dict: defaultdict, now: datetime):
"""Increments the appropriate time-based counters in a stats dictionary."""
keys = get_time_keys(now)
- stats_dict['by_hour'][keys['hour']] += 1
- stats_dict['by_day'][keys['day']] += 1
- stats_dict['by_week'][keys['week']] += 1
- stats_dict['by_month'][keys['month']] += 1
+ stats_dict["by_hour"][keys["hour"]] += 1
+ stats_dict["by_day"][keys["day"]] += 1
+ stats_dict["by_week"][keys["week"]] += 1
+ stats_dict["by_month"][keys["month"]] += 1
-def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: str) -> str:
+
+def generate_time_series_svg(
+ title: str, data: list[tuple[str, int]], y_label: str
+) -> str:
"""Generates a responsive SVG bar chart for time-series data."""
if not data:
return f"{html.escape(title)}
No data yet.
"
@@ -36,14 +43,14 @@ def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: s
svg_height, svg_width = 250, 600
bar_padding = 5
bar_width = (svg_width - 50) / len(data) - bar_padding
-
+
bars = ""
labels = ""
for i, (key, val) in enumerate(data):
bar_height = (val / max_val) * (svg_height - 50) if max_val > 0 else 0
x = i * (bar_width + bar_padding) + 40
y = svg_height - bar_height - 30
-
+
bars += f'{html.escape(key)}: {val}'
labels += f'{html.escape(key)}'
@@ -61,25 +68,32 @@ def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: s
"""
+
@web.middleware
async def middleware(request, handler):
"""Middleware to count all incoming HTTP requests."""
# Avoid counting requests to the stats page itself
- if request.path.startswith('/stats.html'):
+ if request.path.startswith("/stats.html"):
return await handler(request)
-
- update_stats_counters(request.app['stats']['http_requests'], datetime.now(timezone.utc))
+
+ update_stats_counters(
+ request.app["stats"]["http_requests"], datetime.now(timezone.utc)
+ )
return await handler(request)
+
def update_websocket_stats(app):
- update_stats_counters(app['stats']['websocket_requests'], datetime.now(timezone.utc))
+ update_stats_counters(
+ app["stats"]["websocket_requests"], datetime.now(timezone.utc)
+ )
+
async def pipe_and_count_websocket(ws_from, ws_to, stats_dict):
"""This function proxies WebSocket messages AND counts them."""
async for msg in ws_from:
# This is the key part for monitoring WebSockets
update_stats_counters(stats_dict, datetime.now(timezone.utc))
-
+
if msg.type == WSMsgType.TEXT:
await ws_to.send_str(msg.data)
elif msg.type == WSMsgType.BINARY:
@@ -91,27 +105,27 @@ async def pipe_and_count_websocket(ws_from, ws_to, stats_dict):
async def stats_handler(request: web.Request):
"""Handler to display the statistics dashboard."""
- stats = request.app['stats']
+ stats = request.app["stats"]
now = datetime.now(timezone.utc)
-
+
# Helper to prepare data for charts
def get_data(source, period, count):
data = []
for i in range(count - 1, -1, -1):
- if period == 'hour':
+ if period == "hour":
dt = now - timedelta(hours=i)
- key, label = dt.strftime('%Y-%m-%d-%H'), dt.strftime('%H:00')
- data.append((label, source['by_hour'].get(key, 0)))
- elif period == 'day':
+ key, label = dt.strftime("%Y-%m-%d-%H"), dt.strftime("%H:00")
+ data.append((label, source["by_hour"].get(key, 0)))
+ elif period == "day":
dt = now - timedelta(days=i)
- key, label = dt.strftime('%Y-%m-%d'), dt.strftime('%a')
- data.append((label, source['by_day'].get(key, 0)))
+ key, label = dt.strftime("%Y-%m-%d"), dt.strftime("%a")
+ data.append((label, source["by_day"].get(key, 0)))
return data
- http_hourly = get_data(stats['http_requests'], 'hour', 24)
- ws_hourly = get_data(stats['ws_messages'], 'hour', 24)
- http_daily = get_data(stats['http_requests'], 'day', 7)
- ws_daily = get_data(stats['ws_messages'], 'day', 7)
+ http_hourly = get_data(stats["http_requests"], "hour", 24)
+ ws_hourly = get_data(stats["ws_messages"], "hour", 24)
+ http_daily = get_data(stats["http_requests"], "day", 7)
+ ws_daily = get_data(stats["ws_messages"], "day", 7)
body = f"""
App Stats
@@ -125,5 +139,4 @@ async def stats_handler(request: web.Request):
{generate_time_series_svg("WebSocket Messages", ws_daily, "Msgs/Day")}