Unformatted code

This commit is contained in:
BordedDev 2025-08-08 18:45:42 +02:00
parent 40f9646251
commit 51b4898078
5 changed files with 37 additions and 41 deletions

View File

@ -3,13 +3,13 @@ import logging
import pathlib import pathlib
import ssl import ssl
import uuid import uuid
import signal
from datetime import datetime from datetime import datetime
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from snek import snode from snek import snode
from snek.view.threads import ThreadsView from snek.view.threads import ThreadsView
from snek.system.ads import AsyncDataSet from snek.system.ads import AsyncDataSet
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address from ipaddress import ip_address
@ -31,11 +31,7 @@ from snek.sgit import GitApplication
from snek.sssh import start_ssh_server from snek.sssh import start_ssh_server
from snek.system import http from snek.system import http
from snek.system.cache import Cache from snek.system.cache import Cache
from snek.system.stats import ( from snek.system.stats import middleware as stats_middleware, create_stats_structure, stats_handler
middleware as stats_middleware,
create_stats_structure,
stats_handler,
)
from snek.system.markdown import MarkdownExtension from snek.system.markdown import MarkdownExtension
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
from snek.system.profiler import profiler_handler from snek.system.profiler import profiler_handler
@ -47,11 +43,7 @@ from snek.system.template import (
) )
from snek.view.about import AboutHTMLView, AboutMDView from snek.view.about import AboutHTMLView, AboutMDView
from snek.view.avatar import AvatarView from snek.view.avatar import AvatarView
from snek.view.channel import ( from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView
ChannelAttachmentView,
ChannelAttachmentUploadView,
ChannelView,
)
from snek.view.docs import DocsHTMLView, DocsMDView from snek.view.docs import DocsHTMLView, DocsMDView
from snek.view.drive import DriveApiView, DriveView from snek.view.drive import DriveApiView, DriveView
from snek.view.channel import ChannelDriveApiView from snek.view.channel import ChannelDriveApiView
@ -183,12 +175,13 @@ class Application(BaseApplication):
self.on_startup.append(self.prepare_asyncio) self.on_startup.append(self.prepare_asyncio)
self.on_startup.append(self.start_user_availability_service) self.on_startup.append(self.start_user_availability_service)
self.on_startup.append(self.start_ssh_server) self.on_startup.append(self.start_ssh_server)
# self.on_startup.append(self.prepare_database) #self.on_startup.append(self.prepare_database)
async def prepare_stats(self, app): async def prepare_stats(self, app):
app["stats"] = create_stats_structure() app['stats'] = create_stats_structure()
print("Stats prepared", flush=True) print("Stats prepared", flush=True)
@property @property
def uptime_seconds(self): def uptime_seconds(self):
return (datetime.now() - self.time_start).total_seconds() return (datetime.now() - self.time_start).total_seconds()
@ -232,10 +225,10 @@ class Application(BaseApplication):
# app.loop = asyncio.get_running_loop() # app.loop = asyncio.get_running_loop()
app.executor = ThreadPoolExecutor(max_workers=200) app.executor = ThreadPoolExecutor(max_workers=200)
app.loop.set_default_executor(self.executor) app.loop.set_default_executor(self.executor)
# for sig in (signal.SIGINT, signal.SIGTERM): #for sig in (signal.SIGINT, signal.SIGTERM):
# app.loop.add_signal_handler( #app.loop.add_signal_handler(
# sig, lambda: asyncio.create_task(self.services.container.shutdown()) # sig, lambda: asyncio.create_task(self.services.container.shutdown())
# ) #)
async def create_task(self, task): async def create_task(self, task):
await self.tasks.put(task) await self.tasks.put(task)
@ -251,6 +244,7 @@ class Application(BaseApplication):
print(ex) print(ex)
self.db.commit() self.db.commit()
async def prepare_database(self, app): async def prepare_database(self, app):
await self.db.query_raw("PRAGMA journal_mode=WAL") await self.db.query_raw("PRAGMA journal_mode=WAL")
await self.db.query_raw("PRAGMA syncnorm=off") await self.db.query_raw("PRAGMA syncnorm=off")
@ -297,7 +291,9 @@ class Application(BaseApplication):
self.router.add_view( self.router.add_view(
"/channel/{channel_uid}/attachment.bin", ChannelAttachmentView "/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
) )
self.router.add_view("/channel/{channel_uid}/drive.json", ChannelDriveApiView) self.router.add_view(
"/channel/{channel_uid}/drive.json", ChannelDriveApiView
)
self.router.add_view( self.router.add_view(
"/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView "/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView
) )
@ -418,9 +414,9 @@ class Application(BaseApplication):
) )
try: try:
context["nonce"] = request["csp_nonce"] context["nonce"] = request['csp_nonce']
except: except:
context["nonce"] = "?" context['nonce'] = '?'
rendered = await super().render_template(template, request, context) rendered = await super().render_template(template, request, context)
@ -470,14 +466,14 @@ class Application(BaseApplication):
@asynccontextmanager @asynccontextmanager
async def no_save(self): async def no_save(self):
stats = {"count": 0} stats = {
'count': 0
}
async def patched_save(*args, **kwargs): async def patched_save(*args, **kwargs):
await self.cache.set(args[0]["uid"], args[0]) await self.cache.set(args[0]["uid"], args[0])
stats["count"] = stats["count"] + 1 stats['count'] = stats['count'] + 1
print(f"save is ignored {stats['count']} times") print(f"save is ignored {stats['count']} times")
return args[0] return args[0]
save_original = self.services.channel_message.mapper.save save_original = self.services.channel_message.mapper.save
self.services.channel_message.mapper.save = patched_save self.services.channel_message.mapper.save = patched_save
raised_exception = None raised_exception = None
@ -490,7 +486,6 @@ class Application(BaseApplication):
if raised_exception: if raised_exception:
raise raised_exception raise raised_exception
app = Application(db_path="sqlite:///snek.db") app = Application(db_path="sqlite:///snek.db")

View File

@ -2,7 +2,6 @@ from snek.system.service import BaseService
from snek.system.template import sanitize_html from snek.system.template import sanitize_html
import time import time
class ChannelMessageService(BaseService): class ChannelMessageService(BaseService):
mapper_name = "channel_message" mapper_name = "channel_message"
@ -35,6 +34,7 @@ class ChannelMessageService(BaseService):
time.sleep(0.1) time.sleep(0.1)
print(ex, flush=True) print(ex, flush=True)
while True: while True:
changed = 0 changed = 0
async for message in self.find(is_final=False): async for message in self.find(is_final=False):
@ -72,7 +72,7 @@ class ChannelMessageService(BaseService):
try: try:
template = self.app.jinja2_env.get_template("message.html") template = self.app.jinja2_env.get_template("message.html")
model["html"] = template.render(**context) model["html"] = template.render(**context)
model["html"] = sanitize_html(model["html"]) model['html'] = sanitize_html(model['html'])
except Exception as ex: except Exception as ex:
print(ex, flush=True) print(ex, flush=True)
@ -99,9 +99,9 @@ class ChannelMessageService(BaseService):
if not user: if not user:
return {} return {}
# if not message["html"].startswith("<chat-message"): #if not message["html"].startswith("<chat-message"):
# message = await self.get(uid=message["uid"]) #message = await self.get(uid=message["uid"])
# await self.save(message) #await self.save(message)
return { return {
"uid": message["uid"], "uid": message["uid"],
@ -129,7 +129,7 @@ class ChannelMessageService(BaseService):
) )
template = self.app.jinja2_env.get_template("message.html") template = self.app.jinja2_env.get_template("message.html")
model["html"] = template.render(**context) model["html"] = template.render(**context)
model["html"] = sanitize_html(model["html"]) model['html'] = sanitize_html(model['html'])
return await super().save(model) return await super().save(model)
async def offset(self, channel_uid, page=0, timestamp=None, page_size=30): async def offset(self, channel_uid, page=0, timestamp=None, page_size=30):

View File

@ -16,6 +16,7 @@ from typing import (
from pathlib import Path from pathlib import Path
import aiosqlite import aiosqlite
import unittest import unittest
from types import SimpleNamespace
import asyncio import asyncio
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
@ -102,7 +103,7 @@ class AsyncDataSet:
# No socket file, become server # No socket file, become server
self._is_server = True self._is_server = True
await self._setup_server() await self._setup_server()
except Exception: except Exception as e:
# Fallback to client mode # Fallback to client mode
self._is_server = False self._is_server = False
await self._setup_client() await self._setup_client()
@ -212,7 +213,7 @@ class AsyncDataSet:
data = await request.json() data = await request.json()
try: try:
try: try:
_limit = data.pop("_limit", 60) _limit = data.pop("_limit",60)
except: except:
pass pass
@ -886,6 +887,7 @@ class AsyncDataSet:
except: except:
pass pass
where_clause, where_params = self._build_where(where) where_clause, where_params = self._build_where(where)
order_clause = f" ORDER BY {order_by}" if order_by else "" order_clause = f" ORDER BY {order_by}" if order_by else ""
extra = (f" LIMIT {_limit}" if _limit else "") + ( extra = (f" LIMIT {_limit}" if _limit else "") + (
@ -1207,3 +1209,4 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -9,7 +9,7 @@ from aiohttp import web
@web.middleware @web.middleware
async def csp_middleware(request, handler): async def csp_middleware(request, handler):
nonce = secrets.token_hex(16) nonce = secrets.token_hex(16)
origin = request.headers.get("Origin") origin = request.headers.get('Origin')
csp_policy = ( csp_policy = (
"default-src 'self'; " "default-src 'self'; "
f"script-src 'self' {origin} 'nonce-{nonce}'; " f"script-src 'self' {origin} 'nonce-{nonce}'; "
@ -25,9 +25,9 @@ async def csp_middleware(request, handler):
"media-src *; " "media-src *; "
"manifest-src 'self';" "manifest-src 'self';"
) )
request["csp_nonce"] = nonce request['csp_nonce'] = nonce
response = await handler(request) response = await handler(request)
# response.headers['Content-Security-Policy'] = csp_policy #response.headers['Content-Security-Policy'] = csp_policy
return response return response
@ -37,7 +37,6 @@ async def no_cors_middleware(request, handler):
response.headers.pop("Access-Control-Allow-Origin", None) response.headers.pop("Access-Control-Allow-Origin", None)
return response return response
@web.middleware @web.middleware
async def cors_allow_middleware(request, handler): async def cors_allow_middleware(request, handler):
response = await handler(request) response = await handler(request)
@ -49,7 +48,6 @@ async def cors_allow_middleware(request, handler):
response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Allow-Credentials"] = "true"
return response return response
@web.middleware @web.middleware
async def auth_middleware(request, handler): async def auth_middleware(request, handler):
request["user"] = None request["user"] = None
@ -59,7 +57,6 @@ async def auth_middleware(request, handler):
) )
return await handler(request) return await handler(request)
@web.middleware @web.middleware
async def cors_middleware(request, handler): async def cors_middleware(request, handler):
if request.headers.get("Allow"): if request.headers.get("Allow"):

View File

@ -1,4 +1,5 @@
from snek.mapper import get_mapper from snek.mapper import get_mapper
from snek.model.user import UserModel
from snek.system.mapper import BaseMapper from snek.system.mapper import BaseMapper
@ -39,7 +40,7 @@ class BaseService:
yield record yield record
async def get(self, *args, **kwargs): async def get(self, *args, **kwargs):
if "deleted_at" not in kwargs: if not "deleted_at" in kwargs:
kwargs["deleted_at"] = None kwargs["deleted_at"] = None
uid = kwargs.get("uid") uid = kwargs.get("uid")
if args: if args:
@ -49,7 +50,7 @@ class BaseService:
if result and result.__class__ == self.mapper.model_class: if result and result.__class__ == self.mapper.model_class:
return result return result
kwargs["uid"] = uid kwargs["uid"] = uid
print(kwargs, "ZZZZZZZ") print(kwargs,"ZZZZZZZ")
result = await self.mapper.get(**kwargs) result = await self.mapper.get(**kwargs)
if result: if result:
await self.cache.set(result["uid"], result) await self.cache.set(result["uid"], result)