diff --git a/src/snek/app.py b/src/snek/app.py index f5c0f27..e14aa5f 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -3,13 +3,13 @@ import logging import pathlib import ssl import uuid -import signal from datetime import datetime from contextlib import asynccontextmanager from snek import snode from snek.view.threads import ThreadsView from snek.system.ads import AsyncDataSet + logging.basicConfig(level=logging.DEBUG) from concurrent.futures import ThreadPoolExecutor from ipaddress import ip_address @@ -31,7 +31,11 @@ from snek.sgit import GitApplication from snek.sssh import start_ssh_server from snek.system import http from snek.system.cache import Cache -from snek.system.stats import middleware as stats_middleware, create_stats_structure, stats_handler +from snek.system.stats import ( + middleware as stats_middleware, + create_stats_structure, + stats_handler, +) from snek.system.markdown import MarkdownExtension from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware from snek.system.profiler import profiler_handler @@ -43,7 +47,11 @@ from snek.system.template import ( ) from snek.view.about import AboutHTMLView, AboutMDView from snek.view.avatar import AvatarView -from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView +from snek.view.channel import ( + ChannelAttachmentView, + ChannelAttachmentUploadView, + ChannelView, +) from snek.view.docs import DocsHTMLView, DocsMDView from snek.view.drive import DriveApiView, DriveView from snek.view.channel import ChannelDriveApiView @@ -107,7 +115,7 @@ async def ip2location_middleware(request, handler): user["city"] if user["city"] != location.city: user["country_long"] = location.country - user["country_short"] = locaion.country_short + user["country_short"] = location.country_short user["city"] = location.city user["region"] = location.region user["latitude"] = location.latitude @@ -165,7 +173,7 @@ class Application(BaseApplication): self.mappers = get_mappers(app=self) self.broadcast_service = None self.user_availability_service_task = None - + self.setup_router() base_path = pathlib.Path(__file__).parent self.ip2location = IP2Location.IP2Location( @@ -175,12 +183,11 @@ class Application(BaseApplication): self.on_startup.append(self.prepare_asyncio) self.on_startup.append(self.start_user_availability_service) self.on_startup.append(self.start_ssh_server) - #self.on_startup.append(self.prepare_database) - - async def prepare_stats(self, app): - app['stats'] = create_stats_structure() - print("Stats prepared", flush=True) + # self.on_startup.append(self.prepare_database) + async def prepare_stats(self, app): + app["stats"] = create_stats_structure() + print("Stats prepared", flush=True) @property def uptime_seconds(self): @@ -225,10 +232,10 @@ class Application(BaseApplication): # app.loop = asyncio.get_running_loop() app.executor = ThreadPoolExecutor(max_workers=200) app.loop.set_default_executor(self.executor) - #for sig in (signal.SIGINT, signal.SIGTERM): - #app.loop.add_signal_handler( - # sig, lambda: asyncio.create_task(self.services.container.shutdown()) - #) + # for sig in (signal.SIGINT, signal.SIGTERM): + # app.loop.add_signal_handler( + # sig, lambda: asyncio.create_task(self.services.container.shutdown()) + # ) async def create_task(self, task): await self.tasks.put(task) @@ -243,7 +250,6 @@ class Application(BaseApplication): except Exception as ex: print(ex) self.db.commit() - async def prepare_database(self, app): await self.db.query_raw("PRAGMA journal_mode=WAL") @@ -291,9 +297,7 @@ class Application(BaseApplication): self.router.add_view( "/channel/{channel_uid}/attachment.bin", ChannelAttachmentView ) - self.router.add_view( - "/channel/{channel_uid}/drive.json", ChannelDriveApiView - ) + self.router.add_view("/channel/{channel_uid}/drive.json", ChannelDriveApiView) self.router.add_view( "/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView ) @@ -412,11 +416,11 @@ class Application(BaseApplication): self.jinja2_env.loader = await self.get_user_template_loader( request.session.get("uid") ) - + try: - context["nonce"] = request['csp_nonce'] + context["nonce"] = request["csp_nonce"] except: - context['nonce'] = '?' + context["nonce"] = "?" rendered = await super().render_template(template, request, context) @@ -451,7 +455,7 @@ class Application(BaseApplication): async def get_user_template_loader(self, uid=None): template_paths = [] - for admin_uid in self.services.user.get_admin_uids(): + for admin_uid in await self.services.user.get_admin_uids(): user_template_path = await self.services.user.get_template_path(admin_uid) if user_template_path: template_paths.append(user_template_path) @@ -463,22 +467,22 @@ class Application(BaseApplication): template_paths.append(self.template_path) return FileSystemLoader(template_paths) - + @asynccontextmanager async def no_save(self): - stats = { - 'count': 0 - } + stats = {"count": 0} + async def patched_save(*args, **kwargs): await self.cache.set(args[0]["uid"], args[0]) - stats['count'] = stats['count'] + 1 + stats["count"] = stats["count"] + 1 print(f"save is ignored {stats['count']} times") return args[0] + save_original = self.services.channel_message.mapper.save self.services.channel_message.mapper.save = patched_save raised_exception = None try: - yield + yield except Exception as ex: raised_exception = ex finally: @@ -486,6 +490,7 @@ class Application(BaseApplication): if raised_exception: raise raised_exception + app = Application(db_path="sqlite:///snek.db") diff --git a/src/snek/mapper/user.py b/src/snek/mapper/user.py index e0df494..dc3387f 100644 --- a/src/snek/mapper/user.py +++ b/src/snek/mapper/user.py @@ -6,11 +6,11 @@ class UserMapper(BaseMapper): table_name = "user" model_class = UserModel - def get_admin_uids(self): + async def get_admin_uids(self): try: return [ user["uid"] - for user in self.db.query( + for user in await self.db.query( "SELECT uid FROM user WHERE is_admin = :is_admin", {"is_admin": True}, ) diff --git a/src/snek/service/channel_message.py b/src/snek/service/channel_message.py index aae68b9..bbb870c 100644 --- a/src/snek/service/channel_message.py +++ b/src/snek/service/channel_message.py @@ -2,6 +2,7 @@ from snek.system.service import BaseService from snek.system.template import sanitize_html import time + class ChannelMessageService(BaseService): mapper_name = "channel_message" @@ -29,12 +30,11 @@ class ChannelMessageService(BaseService): ) if html != message["html"]: print("Reredefined message", message["uid"]) - + except Exception as ex: time.sleep(0.1) print(ex, flush=True) - - + while True: changed = 0 async for message in self.find(is_final=False): @@ -72,7 +72,7 @@ class ChannelMessageService(BaseService): try: template = self.app.jinja2_env.get_template("message.html") model["html"] = template.render(**context) - model['html'] = sanitize_html(model['html']) + model["html"] = sanitize_html(model["html"]) except Exception as ex: print(ex, flush=True) @@ -99,10 +99,10 @@ class ChannelMessageService(BaseService): if not user: return {} - #if not message["html"].startswith(" 0: async for model in self.query( f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size", - { + *{ "channel_uid": channel_uid, "page_size": page_size, "offset": offset, "timestamp": timestamp, - }, + }.values(), ): results.append(model) else: async for model in self.query( f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset", - { + *{ "channel_uid": channel_uid, "page_size": page_size, "offset": offset, - }, + }.values(), ): results.append(model) diff --git a/src/snek/system/ads.py b/src/snek/system/ads.py index 34ed414..80995d3 100644 --- a/src/snek/system/ads.py +++ b/src/snek/system/ads.py @@ -16,7 +16,6 @@ from typing import ( from pathlib import Path import aiosqlite import unittest -from types import SimpleNamespace import asyncio import aiohttp from aiohttp import web @@ -103,7 +102,7 @@ class AsyncDataSet: # No socket file, become server self._is_server = True await self._setup_server() - except Exception as e: + except Exception: # Fallback to client mode self._is_server = False await self._setup_client() @@ -213,7 +212,7 @@ class AsyncDataSet: data = await request.json() try: try: - _limit = data.pop("_limit",60) + _limit = data.pop("_limit", 60) except: pass @@ -745,8 +744,13 @@ class AsyncDataSet: def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: if not where: return "", [] - clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) - return " WHERE " + " AND ".join(clauses), list(vals) + clauses, vals = zip( + *[ + (f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v) + for k, v in where.items() + ] + ) + return " WHERE " + " AND ".join(clauses), list(v for v in vals if v is not None) async def _server_insert( self, table: str, args: Dict[str, Any], return_id: bool = False @@ -881,7 +885,6 @@ class AsyncDataSet: _limit = where.pop("_limit") except: pass - where_clause, where_params = self._build_where(where) order_clause = f" ORDER BY {order_by}" if order_by else "" @@ -1204,4 +1207,3 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase): if __name__ == "__main__": unittest.main() - diff --git a/src/snek/system/middleware.py b/src/snek/system/middleware.py index db25327..558cd3d 100644 --- a/src/snek/system/middleware.py +++ b/src/snek/system/middleware.py @@ -9,7 +9,7 @@ from aiohttp import web @web.middleware async def csp_middleware(request, handler): nonce = secrets.token_hex(16) - origin = request.headers.get('Origin') + origin = request.headers.get("Origin") csp_policy = ( "default-src 'self'; " f"script-src 'self' {origin} 'nonce-{nonce}'; " @@ -25,9 +25,9 @@ async def csp_middleware(request, handler): "media-src *; " "manifest-src 'self';" ) - request['csp_nonce'] = nonce + request["csp_nonce"] = nonce response = await handler(request) - #response.headers['Content-Security-Policy'] = csp_policy + # response.headers['Content-Security-Policy'] = csp_policy return response @@ -37,6 +37,7 @@ async def no_cors_middleware(request, handler): response.headers.pop("Access-Control-Allow-Origin", None) return response + @web.middleware async def cors_allow_middleware(request, handler): response = await handler(request) @@ -48,15 +49,17 @@ async def cors_allow_middleware(request, handler): response.headers["Access-Control-Allow-Credentials"] = "true" return response + @web.middleware async def auth_middleware(request, handler): request["user"] = None if request.session.get("uid") and request.session.get("logged_in"): request["user"] = await request.app.services.user.get( - uid=request.app.session.get("uid") + uid=request.session.get("uid") ) return await handler(request) + @web.middleware async def cors_middleware(request, handler): if request.headers.get("Allow"): @@ -69,5 +72,5 @@ async def cors_middleware(request, handler): response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" response.headers["Access-Control-Allow-Headers"] = "*" response.headers["Access-Control-Allow-Credentials"] = "true" - + return response diff --git a/src/snek/system/service.py b/src/snek/system/service.py index b47dbc4..5934a6e 100644 --- a/src/snek/system/service.py +++ b/src/snek/system/service.py @@ -1,5 +1,4 @@ from snek.mapper import get_mapper -from snek.model.user import UserModel from snek.system.mapper import BaseMapper @@ -36,12 +35,12 @@ class BaseService: return await self.mapper.new() async def query(self, sql, *args): - for record in self.app.db.query(sql, *args): + for record in await self.app.db.query(sql, *args): yield record async def get(self, *args, **kwargs): - if not "deleted_at" in kwargs: - kwargs["deleted_at"] = None + if "deleted_at" not in kwargs: + kwargs["deleted_at"] = None uid = kwargs.get("uid") if args: uid = args[0] @@ -50,7 +49,7 @@ class BaseService: if result and result.__class__ == self.mapper.model_class: return result kwargs["uid"] = uid - print(kwargs,"ZZZZZZZ") + print(kwargs, "ZZZZZZZ") result = await self.mapper.get(**kwargs) if result: await self.cache.set(result["uid"], result)