Fixed db related issues

This commit is contained in:
BordedDev 2025-08-08 18:29:52 +02:00
parent 5e99e894e9
commit 40f9646251
6 changed files with 69 additions and 60 deletions

View File

@ -3,13 +3,13 @@ import logging
import pathlib import pathlib
import ssl import ssl
import uuid import uuid
import signal
from datetime import datetime from datetime import datetime
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from snek import snode from snek import snode
from snek.view.threads import ThreadsView from snek.view.threads import ThreadsView
from snek.system.ads import AsyncDataSet from snek.system.ads import AsyncDataSet
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address from ipaddress import ip_address
@ -31,7 +31,11 @@ from snek.sgit import GitApplication
from snek.sssh import start_ssh_server from snek.sssh import start_ssh_server
from snek.system import http from snek.system import http
from snek.system.cache import Cache from snek.system.cache import Cache
from snek.system.stats import middleware as stats_middleware, create_stats_structure, stats_handler from snek.system.stats import (
middleware as stats_middleware,
create_stats_structure,
stats_handler,
)
from snek.system.markdown import MarkdownExtension from snek.system.markdown import MarkdownExtension
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
from snek.system.profiler import profiler_handler from snek.system.profiler import profiler_handler
@ -43,7 +47,11 @@ from snek.system.template import (
) )
from snek.view.about import AboutHTMLView, AboutMDView from snek.view.about import AboutHTMLView, AboutMDView
from snek.view.avatar import AvatarView from snek.view.avatar import AvatarView
from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView from snek.view.channel import (
ChannelAttachmentView,
ChannelAttachmentUploadView,
ChannelView,
)
from snek.view.docs import DocsHTMLView, DocsMDView from snek.view.docs import DocsHTMLView, DocsMDView
from snek.view.drive import DriveApiView, DriveView from snek.view.drive import DriveApiView, DriveView
from snek.view.channel import ChannelDriveApiView from snek.view.channel import ChannelDriveApiView
@ -107,7 +115,7 @@ async def ip2location_middleware(request, handler):
user["city"] user["city"]
if user["city"] != location.city: if user["city"] != location.city:
user["country_long"] = location.country user["country_long"] = location.country
user["country_short"] = locaion.country_short user["country_short"] = location.country_short
user["city"] = location.city user["city"] = location.city
user["region"] = location.region user["region"] = location.region
user["latitude"] = location.latitude user["latitude"] = location.latitude
@ -165,7 +173,7 @@ class Application(BaseApplication):
self.mappers = get_mappers(app=self) self.mappers = get_mappers(app=self)
self.broadcast_service = None self.broadcast_service = None
self.user_availability_service_task = None self.user_availability_service_task = None
self.setup_router() self.setup_router()
base_path = pathlib.Path(__file__).parent base_path = pathlib.Path(__file__).parent
self.ip2location = IP2Location.IP2Location( self.ip2location = IP2Location.IP2Location(
@ -175,12 +183,11 @@ class Application(BaseApplication):
self.on_startup.append(self.prepare_asyncio) self.on_startup.append(self.prepare_asyncio)
self.on_startup.append(self.start_user_availability_service) self.on_startup.append(self.start_user_availability_service)
self.on_startup.append(self.start_ssh_server) self.on_startup.append(self.start_ssh_server)
#self.on_startup.append(self.prepare_database) # self.on_startup.append(self.prepare_database)
async def prepare_stats(self, app):
app['stats'] = create_stats_structure()
print("Stats prepared", flush=True)
async def prepare_stats(self, app):
app["stats"] = create_stats_structure()
print("Stats prepared", flush=True)
@property @property
def uptime_seconds(self): def uptime_seconds(self):
@ -225,10 +232,10 @@ class Application(BaseApplication):
# app.loop = asyncio.get_running_loop() # app.loop = asyncio.get_running_loop()
app.executor = ThreadPoolExecutor(max_workers=200) app.executor = ThreadPoolExecutor(max_workers=200)
app.loop.set_default_executor(self.executor) app.loop.set_default_executor(self.executor)
#for sig in (signal.SIGINT, signal.SIGTERM): # for sig in (signal.SIGINT, signal.SIGTERM):
#app.loop.add_signal_handler( # app.loop.add_signal_handler(
# sig, lambda: asyncio.create_task(self.services.container.shutdown()) # sig, lambda: asyncio.create_task(self.services.container.shutdown())
#) # )
async def create_task(self, task): async def create_task(self, task):
await self.tasks.put(task) await self.tasks.put(task)
@ -243,7 +250,6 @@ class Application(BaseApplication):
except Exception as ex: except Exception as ex:
print(ex) print(ex)
self.db.commit() self.db.commit()
async def prepare_database(self, app): async def prepare_database(self, app):
await self.db.query_raw("PRAGMA journal_mode=WAL") await self.db.query_raw("PRAGMA journal_mode=WAL")
@ -291,9 +297,7 @@ class Application(BaseApplication):
self.router.add_view( self.router.add_view(
"/channel/{channel_uid}/attachment.bin", ChannelAttachmentView "/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
) )
self.router.add_view( self.router.add_view("/channel/{channel_uid}/drive.json", ChannelDriveApiView)
"/channel/{channel_uid}/drive.json", ChannelDriveApiView
)
self.router.add_view( self.router.add_view(
"/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView "/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView
) )
@ -412,11 +416,11 @@ class Application(BaseApplication):
self.jinja2_env.loader = await self.get_user_template_loader( self.jinja2_env.loader = await self.get_user_template_loader(
request.session.get("uid") request.session.get("uid")
) )
try: try:
context["nonce"] = request['csp_nonce'] context["nonce"] = request["csp_nonce"]
except: except:
context['nonce'] = '?' context["nonce"] = "?"
rendered = await super().render_template(template, request, context) rendered = await super().render_template(template, request, context)
@ -451,7 +455,7 @@ class Application(BaseApplication):
async def get_user_template_loader(self, uid=None): async def get_user_template_loader(self, uid=None):
template_paths = [] template_paths = []
for admin_uid in self.services.user.get_admin_uids(): for admin_uid in await self.services.user.get_admin_uids():
user_template_path = await self.services.user.get_template_path(admin_uid) user_template_path = await self.services.user.get_template_path(admin_uid)
if user_template_path: if user_template_path:
template_paths.append(user_template_path) template_paths.append(user_template_path)
@ -463,22 +467,22 @@ class Application(BaseApplication):
template_paths.append(self.template_path) template_paths.append(self.template_path)
return FileSystemLoader(template_paths) return FileSystemLoader(template_paths)
@asynccontextmanager @asynccontextmanager
async def no_save(self): async def no_save(self):
stats = { stats = {"count": 0}
'count': 0
}
async def patched_save(*args, **kwargs): async def patched_save(*args, **kwargs):
await self.cache.set(args[0]["uid"], args[0]) await self.cache.set(args[0]["uid"], args[0])
stats['count'] = stats['count'] + 1 stats["count"] = stats["count"] + 1
print(f"save is ignored {stats['count']} times") print(f"save is ignored {stats['count']} times")
return args[0] return args[0]
save_original = self.services.channel_message.mapper.save save_original = self.services.channel_message.mapper.save
self.services.channel_message.mapper.save = patched_save self.services.channel_message.mapper.save = patched_save
raised_exception = None raised_exception = None
try: try:
yield yield
except Exception as ex: except Exception as ex:
raised_exception = ex raised_exception = ex
finally: finally:
@ -486,6 +490,7 @@ class Application(BaseApplication):
if raised_exception: if raised_exception:
raise raised_exception raise raised_exception
app = Application(db_path="sqlite:///snek.db") app = Application(db_path="sqlite:///snek.db")

View File

@ -6,11 +6,11 @@ class UserMapper(BaseMapper):
table_name = "user" table_name = "user"
model_class = UserModel model_class = UserModel
def get_admin_uids(self): async def get_admin_uids(self):
try: try:
return [ return [
user["uid"] user["uid"]
for user in self.db.query( for user in await self.db.query(
"SELECT uid FROM user WHERE is_admin = :is_admin", "SELECT uid FROM user WHERE is_admin = :is_admin",
{"is_admin": True}, {"is_admin": True},
) )

View File

@ -2,6 +2,7 @@ from snek.system.service import BaseService
from snek.system.template import sanitize_html from snek.system.template import sanitize_html
import time import time
class ChannelMessageService(BaseService): class ChannelMessageService(BaseService):
mapper_name = "channel_message" mapper_name = "channel_message"
@ -29,12 +30,11 @@ class ChannelMessageService(BaseService):
) )
if html != message["html"]: if html != message["html"]:
print("Reredefined message", message["uid"]) print("Reredefined message", message["uid"])
except Exception as ex: except Exception as ex:
time.sleep(0.1) time.sleep(0.1)
print(ex, flush=True) print(ex, flush=True)
while True: while True:
changed = 0 changed = 0
async for message in self.find(is_final=False): async for message in self.find(is_final=False):
@ -72,7 +72,7 @@ class ChannelMessageService(BaseService):
try: try:
template = self.app.jinja2_env.get_template("message.html") template = self.app.jinja2_env.get_template("message.html")
model["html"] = template.render(**context) model["html"] = template.render(**context)
model['html'] = sanitize_html(model['html']) model["html"] = sanitize_html(model["html"])
except Exception as ex: except Exception as ex:
print(ex, flush=True) print(ex, flush=True)
@ -99,10 +99,10 @@ class ChannelMessageService(BaseService):
if not user: if not user:
return {} return {}
#if not message["html"].startswith("<chat-message"): # if not message["html"].startswith("<chat-message"):
#message = await self.get(uid=message["uid"]) # message = await self.get(uid=message["uid"])
#await self.save(message) # await self.save(message)
return { return {
"uid": message["uid"], "uid": message["uid"],
"color": user["color"], "color": user["color"],
@ -129,7 +129,7 @@ class ChannelMessageService(BaseService):
) )
template = self.app.jinja2_env.get_template("message.html") template = self.app.jinja2_env.get_template("message.html")
model["html"] = template.render(**context) model["html"] = template.render(**context)
model['html'] = sanitize_html(model['html']) model["html"] = sanitize_html(model["html"])
return await super().save(model) return await super().save(model)
async def offset(self, channel_uid, page=0, timestamp=None, page_size=30): async def offset(self, channel_uid, page=0, timestamp=None, page_size=30):
@ -156,22 +156,22 @@ class ChannelMessageService(BaseService):
elif page > 0: elif page > 0:
async for model in self.query( async for model in self.query(
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size", f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size",
{ *{
"channel_uid": channel_uid, "channel_uid": channel_uid,
"page_size": page_size, "page_size": page_size,
"offset": offset, "offset": offset,
"timestamp": timestamp, "timestamp": timestamp,
}, }.values(),
): ):
results.append(model) results.append(model)
else: else:
async for model in self.query( async for model in self.query(
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset", f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset",
{ *{
"channel_uid": channel_uid, "channel_uid": channel_uid,
"page_size": page_size, "page_size": page_size,
"offset": offset, "offset": offset,
}, }.values(),
): ):
results.append(model) results.append(model)

View File

@ -16,7 +16,6 @@ from typing import (
from pathlib import Path from pathlib import Path
import aiosqlite import aiosqlite
import unittest import unittest
from types import SimpleNamespace
import asyncio import asyncio
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
@ -103,7 +102,7 @@ class AsyncDataSet:
# No socket file, become server # No socket file, become server
self._is_server = True self._is_server = True
await self._setup_server() await self._setup_server()
except Exception as e: except Exception:
# Fallback to client mode # Fallback to client mode
self._is_server = False self._is_server = False
await self._setup_client() await self._setup_client()
@ -213,7 +212,7 @@ class AsyncDataSet:
data = await request.json() data = await request.json()
try: try:
try: try:
_limit = data.pop("_limit",60) _limit = data.pop("_limit", 60)
except: except:
pass pass
@ -745,8 +744,13 @@ class AsyncDataSet:
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
if not where: if not where:
return "", [] return "", []
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) clauses, vals = zip(
return " WHERE " + " AND ".join(clauses), list(vals) *[
(f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v)
for k, v in where.items()
]
)
return " WHERE " + " AND ".join(clauses), list(v for v in vals if v is not None)
async def _server_insert( async def _server_insert(
self, table: str, args: Dict[str, Any], return_id: bool = False self, table: str, args: Dict[str, Any], return_id: bool = False
@ -881,7 +885,6 @@ class AsyncDataSet:
_limit = where.pop("_limit") _limit = where.pop("_limit")
except: except:
pass pass
where_clause, where_params = self._build_where(where) where_clause, where_params = self._build_where(where)
order_clause = f" ORDER BY {order_by}" if order_by else "" order_clause = f" ORDER BY {order_by}" if order_by else ""
@ -1204,4 +1207,3 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -9,7 +9,7 @@ from aiohttp import web
@web.middleware @web.middleware
async def csp_middleware(request, handler): async def csp_middleware(request, handler):
nonce = secrets.token_hex(16) nonce = secrets.token_hex(16)
origin = request.headers.get('Origin') origin = request.headers.get("Origin")
csp_policy = ( csp_policy = (
"default-src 'self'; " "default-src 'self'; "
f"script-src 'self' {origin} 'nonce-{nonce}'; " f"script-src 'self' {origin} 'nonce-{nonce}'; "
@ -25,9 +25,9 @@ async def csp_middleware(request, handler):
"media-src *; " "media-src *; "
"manifest-src 'self';" "manifest-src 'self';"
) )
request['csp_nonce'] = nonce request["csp_nonce"] = nonce
response = await handler(request) response = await handler(request)
#response.headers['Content-Security-Policy'] = csp_policy # response.headers['Content-Security-Policy'] = csp_policy
return response return response
@ -37,6 +37,7 @@ async def no_cors_middleware(request, handler):
response.headers.pop("Access-Control-Allow-Origin", None) response.headers.pop("Access-Control-Allow-Origin", None)
return response return response
@web.middleware @web.middleware
async def cors_allow_middleware(request, handler): async def cors_allow_middleware(request, handler):
response = await handler(request) response = await handler(request)
@ -48,15 +49,17 @@ async def cors_allow_middleware(request, handler):
response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Allow-Credentials"] = "true"
return response return response
@web.middleware @web.middleware
async def auth_middleware(request, handler): async def auth_middleware(request, handler):
request["user"] = None request["user"] = None
if request.session.get("uid") and request.session.get("logged_in"): if request.session.get("uid") and request.session.get("logged_in"):
request["user"] = await request.app.services.user.get( request["user"] = await request.app.services.user.get(
uid=request.app.session.get("uid") uid=request.session.get("uid")
) )
return await handler(request) return await handler(request)
@web.middleware @web.middleware
async def cors_middleware(request, handler): async def cors_middleware(request, handler):
if request.headers.get("Allow"): if request.headers.get("Allow"):
@ -69,5 +72,5 @@ async def cors_middleware(request, handler):
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "*" response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Allow-Credentials"] = "true"
return response return response

View File

@ -1,5 +1,4 @@
from snek.mapper import get_mapper from snek.mapper import get_mapper
from snek.model.user import UserModel
from snek.system.mapper import BaseMapper from snek.system.mapper import BaseMapper
@ -36,12 +35,12 @@ class BaseService:
return await self.mapper.new() return await self.mapper.new()
async def query(self, sql, *args): async def query(self, sql, *args):
for record in self.app.db.query(sql, *args): for record in await self.app.db.query(sql, *args):
yield record yield record
async def get(self, *args, **kwargs): async def get(self, *args, **kwargs):
if not "deleted_at" in kwargs: if "deleted_at" not in kwargs:
kwargs["deleted_at"] = None kwargs["deleted_at"] = None
uid = kwargs.get("uid") uid = kwargs.get("uid")
if args: if args:
uid = args[0] uid = args[0]
@ -50,7 +49,7 @@ class BaseService:
if result and result.__class__ == self.mapper.model_class: if result and result.__class__ == self.mapper.model_class:
return result return result
kwargs["uid"] = uid kwargs["uid"] = uid
print(kwargs,"ZZZZZZZ") print(kwargs, "ZZZZZZZ")
result = await self.mapper.get(**kwargs) result = await self.mapper.get(**kwargs)
if result: if result:
await self.cache.set(result["uid"], result) await self.cache.set(result["uid"], result)