Fixed db related issues #69

Merged
retoor merged 2 commits from BordedDev/snek:bugfix/retoors-new-db into main 2025-08-10 00:58:13 +02:00
6 changed files with 32 additions and 27 deletions

View File

@ -107,7 +107,7 @@ async def ip2location_middleware(request, handler):
user["city"] user["city"]
if user["city"] != location.city: if user["city"] != location.city:
user["country_long"] = location.country user["country_long"] = location.country
user["country_short"] = locaion.country_short user["country_short"] = location.country_short
user["city"] = location.city user["city"] = location.city
user["region"] = location.region user["region"] = location.region
user["latitude"] = location.latitude user["latitude"] = location.latitude
@ -165,7 +165,7 @@ class Application(BaseApplication):
self.mappers = get_mappers(app=self) self.mappers = get_mappers(app=self)
self.broadcast_service = None self.broadcast_service = None
self.user_availability_service_task = None self.user_availability_service_task = None
self.setup_router() self.setup_router()
base_path = pathlib.Path(__file__).parent base_path = pathlib.Path(__file__).parent
self.ip2location = IP2Location.IP2Location( self.ip2location = IP2Location.IP2Location(
@ -176,7 +176,7 @@ class Application(BaseApplication):
self.on_startup.append(self.start_user_availability_service) self.on_startup.append(self.start_user_availability_service)
self.on_startup.append(self.start_ssh_server) self.on_startup.append(self.start_ssh_server)
#self.on_startup.append(self.prepare_database) #self.on_startup.append(self.prepare_database)
async def prepare_stats(self, app): async def prepare_stats(self, app):
app['stats'] = create_stats_structure() app['stats'] = create_stats_structure()
print("Stats prepared", flush=True) print("Stats prepared", flush=True)
@ -243,7 +243,7 @@ class Application(BaseApplication):
except Exception as ex: except Exception as ex:
print(ex) print(ex)
self.db.commit() self.db.commit()
async def prepare_database(self, app): async def prepare_database(self, app):
await self.db.query_raw("PRAGMA journal_mode=WAL") await self.db.query_raw("PRAGMA journal_mode=WAL")
@ -412,11 +412,11 @@ class Application(BaseApplication):
self.jinja2_env.loader = await self.get_user_template_loader( self.jinja2_env.loader = await self.get_user_template_loader(
request.session.get("uid") request.session.get("uid")
) )
try: try:
context["nonce"] = request['csp_nonce'] context["nonce"] = request['csp_nonce']
except: except:
context['nonce'] = '?' context['nonce'] = '?'
rendered = await super().render_template(template, request, context) rendered = await super().render_template(template, request, context)
@ -451,7 +451,7 @@ class Application(BaseApplication):
async def get_user_template_loader(self, uid=None): async def get_user_template_loader(self, uid=None):
template_paths = [] template_paths = []
for admin_uid in self.services.user.get_admin_uids(): for admin_uid in await self.services.user.get_admin_uids():
user_template_path = await self.services.user.get_template_path(admin_uid) user_template_path = await self.services.user.get_template_path(admin_uid)
if user_template_path: if user_template_path:
template_paths.append(user_template_path) template_paths.append(user_template_path)
@ -463,7 +463,7 @@ class Application(BaseApplication):
template_paths.append(self.template_path) template_paths.append(self.template_path)
return FileSystemLoader(template_paths) return FileSystemLoader(template_paths)
@asynccontextmanager @asynccontextmanager
async def no_save(self): async def no_save(self):
stats = { stats = {
@ -478,7 +478,7 @@ class Application(BaseApplication):
self.services.channel_message.mapper.save = patched_save self.services.channel_message.mapper.save = patched_save
raised_exception = None raised_exception = None
try: try:
yield yield
except Exception as ex: except Exception as ex:
raised_exception = ex raised_exception = ex
finally: finally:

View File

@ -6,11 +6,11 @@ class UserMapper(BaseMapper):
table_name = "user" table_name = "user"
model_class = UserModel model_class = UserModel
def get_admin_uids(self): async def get_admin_uids(self):
try: try:
return [ return [
user["uid"] user["uid"]
for user in self.db.query( for user in await self.db.query(
"SELECT uid FROM user WHERE is_admin = :is_admin", "SELECT uid FROM user WHERE is_admin = :is_admin",
{"is_admin": True}, {"is_admin": True},
) )

View File

@ -29,12 +29,12 @@ class ChannelMessageService(BaseService):
) )
if html != message["html"]: if html != message["html"]:
print("Reredefined message", message["uid"]) print("Reredefined message", message["uid"])
except Exception as ex: except Exception as ex:
time.sleep(0.1) time.sleep(0.1)
print(ex, flush=True) print(ex, flush=True)
while True: while True:
changed = 0 changed = 0
async for message in self.find(is_final=False): async for message in self.find(is_final=False):
@ -102,7 +102,7 @@ class ChannelMessageService(BaseService):
#if not message["html"].startswith("<chat-message"): #if not message["html"].startswith("<chat-message"):
#message = await self.get(uid=message["uid"]) #message = await self.get(uid=message["uid"])
#await self.save(message) #await self.save(message)
return { return {
"uid": message["uid"], "uid": message["uid"],
"color": user["color"], "color": user["color"],
@ -156,22 +156,22 @@ class ChannelMessageService(BaseService):
elif page > 0: elif page > 0:
async for model in self.query( async for model in self.query(
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size", f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size",
{ *{
"channel_uid": channel_uid, "channel_uid": channel_uid,
"page_size": page_size, "page_size": page_size,
"offset": offset, "offset": offset,
"timestamp": timestamp, "timestamp": timestamp,
}, }.values(),
): ):
results.append(model) results.append(model)
else: else:
async for model in self.query( async for model in self.query(
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset", f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset",
{ *{
"channel_uid": channel_uid, "channel_uid": channel_uid,
"page_size": page_size, "page_size": page_size,
"offset": offset, "offset": offset,
}, }.values(),
): ):
results.append(model) results.append(model)

View File

@ -745,8 +745,13 @@ class AsyncDataSet:
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
if not where: if not where:
return "", [] return "", []
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) clauses, vals = zip(
return " WHERE " + " AND ".join(clauses), list(vals) *[
(f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v)
for k, v in where.items()
]
)
return " WHERE " + " AND ".join(clauses), list(v for v in vals if v is not None)
async def _server_insert( async def _server_insert(
self, table: str, args: Dict[str, Any], return_id: bool = False self, table: str, args: Dict[str, Any], return_id: bool = False
@ -881,7 +886,7 @@ class AsyncDataSet:
_limit = where.pop("_limit") _limit = where.pop("_limit")
except: except:
pass pass
where_clause, where_params = self._build_where(where) where_clause, where_params = self._build_where(where)
order_clause = f" ORDER BY {order_by}" if order_by else "" order_clause = f" ORDER BY {order_by}" if order_by else ""

View File

@ -53,7 +53,7 @@ async def auth_middleware(request, handler):
request["user"] = None request["user"] = None
if request.session.get("uid") and request.session.get("logged_in"): if request.session.get("uid") and request.session.get("logged_in"):
request["user"] = await request.app.services.user.get( request["user"] = await request.app.services.user.get(
uid=request.app.session.get("uid") uid=request.session.get("uid")
) )
return await handler(request) return await handler(request)
@ -69,5 +69,5 @@ async def cors_middleware(request, handler):
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "*" response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Allow-Credentials"] = "true"
return response return response

View File

@ -36,12 +36,12 @@ class BaseService:
return await self.mapper.new() return await self.mapper.new()
async def query(self, sql, *args): async def query(self, sql, *args):
for record in self.app.db.query(sql, *args): for record in await self.app.db.query(sql, *args):
yield record yield record
async def get(self, *args, **kwargs): async def get(self, *args, **kwargs):
if not "deleted_at" in kwargs: if not "deleted_at" in kwargs:
kwargs["deleted_at"] = None kwargs["deleted_at"] = None
uid = kwargs.get("uid") uid = kwargs.get("uid")
if args: if args:
uid = args[0] uid = args[0]
@ -50,7 +50,7 @@ class BaseService:
if result and result.__class__ == self.mapper.model_class: if result and result.__class__ == self.mapper.model_class:
return result return result
kwargs["uid"] = uid kwargs["uid"] = uid
print(kwargs,"ZZZZZZZ") print(kwargs,"ZZZZZZZ")
result = await self.mapper.get(**kwargs) result = await self.mapper.get(**kwargs)
if result: if result:
await self.cache.set(result["uid"], result) await self.cache.set(result["uid"], result)