Fixed db related issues

This commit is contained in:
BordedDev 2025-08-08 18:29:52 +02:00
parent 5e99e894e9
commit 40f9646251
6 changed files with 69 additions and 60 deletions

View File

@ -3,13 +3,13 @@ import logging
import pathlib
import ssl
import uuid
import signal
from datetime import datetime
from contextlib import asynccontextmanager
from snek import snode
from snek.view.threads import ThreadsView
from snek.system.ads import AsyncDataSet
logging.basicConfig(level=logging.DEBUG)
from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address
@ -31,7 +31,11 @@ from snek.sgit import GitApplication
from snek.sssh import start_ssh_server
from snek.system import http
from snek.system.cache import Cache
from snek.system.stats import middleware as stats_middleware, create_stats_structure, stats_handler
from snek.system.stats import (
middleware as stats_middleware,
create_stats_structure,
stats_handler,
)
from snek.system.markdown import MarkdownExtension
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
from snek.system.profiler import profiler_handler
@ -43,7 +47,11 @@ from snek.system.template import (
)
from snek.view.about import AboutHTMLView, AboutMDView
from snek.view.avatar import AvatarView
from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView
from snek.view.channel import (
ChannelAttachmentView,
ChannelAttachmentUploadView,
ChannelView,
)
from snek.view.docs import DocsHTMLView, DocsMDView
from snek.view.drive import DriveApiView, DriveView
from snek.view.channel import ChannelDriveApiView
@ -107,7 +115,7 @@ async def ip2location_middleware(request, handler):
user["city"]
if user["city"] != location.city:
user["country_long"] = location.country
user["country_short"] = locaion.country_short
user["country_short"] = location.country_short
user["city"] = location.city
user["region"] = location.region
user["latitude"] = location.latitude
@ -165,7 +173,7 @@ class Application(BaseApplication):
self.mappers = get_mappers(app=self)
self.broadcast_service = None
self.user_availability_service_task = None
self.setup_router()
base_path = pathlib.Path(__file__).parent
self.ip2location = IP2Location.IP2Location(
@ -175,12 +183,11 @@ class Application(BaseApplication):
self.on_startup.append(self.prepare_asyncio)
self.on_startup.append(self.start_user_availability_service)
self.on_startup.append(self.start_ssh_server)
#self.on_startup.append(self.prepare_database)
async def prepare_stats(self, app):
app['stats'] = create_stats_structure()
print("Stats prepared", flush=True)
# self.on_startup.append(self.prepare_database)
async def prepare_stats(self, app):
app["stats"] = create_stats_structure()
print("Stats prepared", flush=True)
@property
def uptime_seconds(self):
@ -225,10 +232,10 @@ class Application(BaseApplication):
# app.loop = asyncio.get_running_loop()
app.executor = ThreadPoolExecutor(max_workers=200)
app.loop.set_default_executor(self.executor)
#for sig in (signal.SIGINT, signal.SIGTERM):
#app.loop.add_signal_handler(
# sig, lambda: asyncio.create_task(self.services.container.shutdown())
#)
# for sig in (signal.SIGINT, signal.SIGTERM):
# app.loop.add_signal_handler(
# sig, lambda: asyncio.create_task(self.services.container.shutdown())
# )
async def create_task(self, task):
await self.tasks.put(task)
@ -243,7 +250,6 @@ class Application(BaseApplication):
except Exception as ex:
print(ex)
self.db.commit()
async def prepare_database(self, app):
await self.db.query_raw("PRAGMA journal_mode=WAL")
@ -291,9 +297,7 @@ class Application(BaseApplication):
self.router.add_view(
"/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
)
self.router.add_view(
"/channel/{channel_uid}/drive.json", ChannelDriveApiView
)
self.router.add_view("/channel/{channel_uid}/drive.json", ChannelDriveApiView)
self.router.add_view(
"/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView
)
@ -412,11 +416,11 @@ class Application(BaseApplication):
self.jinja2_env.loader = await self.get_user_template_loader(
request.session.get("uid")
)
try:
context["nonce"] = request['csp_nonce']
context["nonce"] = request["csp_nonce"]
except:
context['nonce'] = '?'
context["nonce"] = "?"
rendered = await super().render_template(template, request, context)
@ -451,7 +455,7 @@ class Application(BaseApplication):
async def get_user_template_loader(self, uid=None):
template_paths = []
for admin_uid in self.services.user.get_admin_uids():
for admin_uid in await self.services.user.get_admin_uids():
user_template_path = await self.services.user.get_template_path(admin_uid)
if user_template_path:
template_paths.append(user_template_path)
@ -463,22 +467,22 @@ class Application(BaseApplication):
template_paths.append(self.template_path)
return FileSystemLoader(template_paths)
@asynccontextmanager
async def no_save(self):
stats = {
'count': 0
}
stats = {"count": 0}
async def patched_save(*args, **kwargs):
await self.cache.set(args[0]["uid"], args[0])
stats['count'] = stats['count'] + 1
stats["count"] = stats["count"] + 1
print(f"save is ignored {stats['count']} times")
return args[0]
save_original = self.services.channel_message.mapper.save
self.services.channel_message.mapper.save = patched_save
raised_exception = None
try:
yield
yield
except Exception as ex:
raised_exception = ex
finally:
@ -486,6 +490,7 @@ class Application(BaseApplication):
if raised_exception:
raise raised_exception
app = Application(db_path="sqlite:///snek.db")

View File

@ -6,11 +6,11 @@ class UserMapper(BaseMapper):
table_name = "user"
model_class = UserModel
def get_admin_uids(self):
async def get_admin_uids(self):
try:
return [
user["uid"]
for user in self.db.query(
for user in await self.db.query(
"SELECT uid FROM user WHERE is_admin = :is_admin",
{"is_admin": True},
)

View File

@ -2,6 +2,7 @@ from snek.system.service import BaseService
from snek.system.template import sanitize_html
import time
class ChannelMessageService(BaseService):
mapper_name = "channel_message"
@ -29,12 +30,11 @@ class ChannelMessageService(BaseService):
)
if html != message["html"]:
print("Reredefined message", message["uid"])
except Exception as ex:
time.sleep(0.1)
print(ex, flush=True)
while True:
changed = 0
async for message in self.find(is_final=False):
@ -72,7 +72,7 @@ class ChannelMessageService(BaseService):
try:
template = self.app.jinja2_env.get_template("message.html")
model["html"] = template.render(**context)
model['html'] = sanitize_html(model['html'])
model["html"] = sanitize_html(model["html"])
except Exception as ex:
print(ex, flush=True)
@ -99,10 +99,10 @@ class ChannelMessageService(BaseService):
if not user:
return {}
#if not message["html"].startswith("<chat-message"):
#message = await self.get(uid=message["uid"])
#await self.save(message)
# if not message["html"].startswith("<chat-message"):
# message = await self.get(uid=message["uid"])
# await self.save(message)
return {
"uid": message["uid"],
"color": user["color"],
@ -129,7 +129,7 @@ class ChannelMessageService(BaseService):
)
template = self.app.jinja2_env.get_template("message.html")
model["html"] = template.render(**context)
model['html'] = sanitize_html(model['html'])
model["html"] = sanitize_html(model["html"])
return await super().save(model)
async def offset(self, channel_uid, page=0, timestamp=None, page_size=30):
@ -156,22 +156,22 @@ class ChannelMessageService(BaseService):
elif page > 0:
async for model in self.query(
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size",
{
*{
"channel_uid": channel_uid,
"page_size": page_size,
"offset": offset,
"timestamp": timestamp,
},
}.values(),
):
results.append(model)
else:
async for model in self.query(
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset",
{
*{
"channel_uid": channel_uid,
"page_size": page_size,
"offset": offset,
},
}.values(),
):
results.append(model)

View File

@ -16,7 +16,6 @@ from typing import (
from pathlib import Path
import aiosqlite
import unittest
from types import SimpleNamespace
import asyncio
import aiohttp
from aiohttp import web
@ -103,7 +102,7 @@ class AsyncDataSet:
# No socket file, become server
self._is_server = True
await self._setup_server()
except Exception as e:
except Exception:
# Fallback to client mode
self._is_server = False
await self._setup_client()
@ -213,7 +212,7 @@ class AsyncDataSet:
data = await request.json()
try:
try:
_limit = data.pop("_limit",60)
_limit = data.pop("_limit", 60)
except:
pass
@ -745,8 +744,13 @@ class AsyncDataSet:
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
if not where:
return "", []
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()])
return " WHERE " + " AND ".join(clauses), list(vals)
clauses, vals = zip(
*[
(f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v)
for k, v in where.items()
]
)
return " WHERE " + " AND ".join(clauses), list(v for v in vals if v is not None)
async def _server_insert(
self, table: str, args: Dict[str, Any], return_id: bool = False
@ -881,7 +885,6 @@ class AsyncDataSet:
_limit = where.pop("_limit")
except:
pass
where_clause, where_params = self._build_where(where)
order_clause = f" ORDER BY {order_by}" if order_by else ""
@ -1204,4 +1207,3 @@ class TestAsyncDataSet(unittest.IsolatedAsyncioTestCase):
if __name__ == "__main__":
unittest.main()

View File

@ -9,7 +9,7 @@ from aiohttp import web
@web.middleware
async def csp_middleware(request, handler):
nonce = secrets.token_hex(16)
origin = request.headers.get('Origin')
origin = request.headers.get("Origin")
csp_policy = (
"default-src 'self'; "
f"script-src 'self' {origin} 'nonce-{nonce}'; "
@ -25,9 +25,9 @@ async def csp_middleware(request, handler):
"media-src *; "
"manifest-src 'self';"
)
request['csp_nonce'] = nonce
request["csp_nonce"] = nonce
response = await handler(request)
#response.headers['Content-Security-Policy'] = csp_policy
# response.headers['Content-Security-Policy'] = csp_policy
return response
@ -37,6 +37,7 @@ async def no_cors_middleware(request, handler):
response.headers.pop("Access-Control-Allow-Origin", None)
return response
@web.middleware
async def cors_allow_middleware(request, handler):
response = await handler(request)
@ -48,15 +49,17 @@ async def cors_allow_middleware(request, handler):
response.headers["Access-Control-Allow-Credentials"] = "true"
return response
@web.middleware
async def auth_middleware(request, handler):
request["user"] = None
if request.session.get("uid") and request.session.get("logged_in"):
request["user"] = await request.app.services.user.get(
uid=request.app.session.get("uid")
uid=request.session.get("uid")
)
return await handler(request)
@web.middleware
async def cors_middleware(request, handler):
if request.headers.get("Allow"):
@ -69,5 +72,5 @@ async def cors_middleware(request, handler):
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Credentials"] = "true"
return response

View File

@ -1,5 +1,4 @@
from snek.mapper import get_mapper
from snek.model.user import UserModel
from snek.system.mapper import BaseMapper
@ -36,12 +35,12 @@ class BaseService:
return await self.mapper.new()
async def query(self, sql, *args):
for record in self.app.db.query(sql, *args):
for record in await self.app.db.query(sql, *args):
yield record
async def get(self, *args, **kwargs):
if not "deleted_at" in kwargs:
kwargs["deleted_at"] = None
if "deleted_at" not in kwargs:
kwargs["deleted_at"] = None
uid = kwargs.get("uid")
if args:
uid = args[0]
@ -50,7 +49,7 @@ class BaseService:
if result and result.__class__ == self.mapper.model_class:
return result
kwargs["uid"] = uid
print(kwargs,"ZZZZZZZ")
print(kwargs, "ZZZZZZZ")
result = await self.mapper.get(**kwargs)
if result:
await self.cache.set(result["uid"], result)