diff --git a/src/snek/app.py b/src/snek/app.py index f5c0f27..a551e41 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -107,7 +107,7 @@ async def ip2location_middleware(request, handler): user["city"] if user["city"] != location.city: user["country_long"] = location.country - user["country_short"] = locaion.country_short + user["country_short"] = location.country_short user["city"] = location.city user["region"] = location.region user["latitude"] = location.latitude @@ -165,7 +165,7 @@ class Application(BaseApplication): self.mappers = get_mappers(app=self) self.broadcast_service = None self.user_availability_service_task = None - + self.setup_router() base_path = pathlib.Path(__file__).parent self.ip2location = IP2Location.IP2Location( @@ -176,7 +176,7 @@ class Application(BaseApplication): self.on_startup.append(self.start_user_availability_service) self.on_startup.append(self.start_ssh_server) #self.on_startup.append(self.prepare_database) - + async def prepare_stats(self, app): app['stats'] = create_stats_structure() print("Stats prepared", flush=True) @@ -243,7 +243,7 @@ class Application(BaseApplication): except Exception as ex: print(ex) self.db.commit() - + async def prepare_database(self, app): await self.db.query_raw("PRAGMA journal_mode=WAL") @@ -412,11 +412,11 @@ class Application(BaseApplication): self.jinja2_env.loader = await self.get_user_template_loader( request.session.get("uid") ) - + try: context["nonce"] = request['csp_nonce'] except: - context['nonce'] = '?' + context['nonce'] = '?' rendered = await super().render_template(template, request, context) @@ -451,7 +451,7 @@ class Application(BaseApplication): async def get_user_template_loader(self, uid=None): template_paths = [] - for admin_uid in self.services.user.get_admin_uids(): + for admin_uid in await self.services.user.get_admin_uids(): user_template_path = await self.services.user.get_template_path(admin_uid) if user_template_path: template_paths.append(user_template_path) @@ -463,7 +463,7 @@ class Application(BaseApplication): template_paths.append(self.template_path) return FileSystemLoader(template_paths) - + @asynccontextmanager async def no_save(self): stats = { @@ -478,7 +478,7 @@ class Application(BaseApplication): self.services.channel_message.mapper.save = patched_save raised_exception = None try: - yield + yield except Exception as ex: raised_exception = ex finally: diff --git a/src/snek/mapper/user.py b/src/snek/mapper/user.py index e0df494..dc3387f 100644 --- a/src/snek/mapper/user.py +++ b/src/snek/mapper/user.py @@ -6,11 +6,11 @@ class UserMapper(BaseMapper): table_name = "user" model_class = UserModel - def get_admin_uids(self): + async def get_admin_uids(self): try: return [ user["uid"] - for user in self.db.query( + for user in await self.db.query( "SELECT uid FROM user WHERE is_admin = :is_admin", {"is_admin": True}, ) diff --git a/src/snek/service/channel_message.py b/src/snek/service/channel_message.py index aae68b9..642f4e0 100644 --- a/src/snek/service/channel_message.py +++ b/src/snek/service/channel_message.py @@ -29,12 +29,12 @@ class ChannelMessageService(BaseService): ) if html != message["html"]: print("Reredefined message", message["uid"]) - + except Exception as ex: time.sleep(0.1) print(ex, flush=True) - - + + while True: changed = 0 async for message in self.find(is_final=False): @@ -102,7 +102,7 @@ class ChannelMessageService(BaseService): #if not message["html"].startswith(" 0: async for model in self.query( f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size", - { + *{ "channel_uid": channel_uid, "page_size": page_size, "offset": offset, "timestamp": timestamp, - }, + }.values(), ): results.append(model) else: async for model in self.query( f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset", - { + *{ "channel_uid": channel_uid, "page_size": page_size, "offset": offset, - }, + }.values(), ): results.append(model) diff --git a/src/snek/system/ads.py b/src/snek/system/ads.py index 34ed414..42b64a0 100644 --- a/src/snek/system/ads.py +++ b/src/snek/system/ads.py @@ -745,8 +745,13 @@ class AsyncDataSet: def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]: if not where: return "", [] - clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()]) - return " WHERE " + " AND ".join(clauses), list(vals) + clauses, vals = zip( + *[ + (f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v) + for k, v in where.items() + ] + ) + return " WHERE " + " AND ".join(clauses), list(v for v in vals if v is not None) async def _server_insert( self, table: str, args: Dict[str, Any], return_id: bool = False @@ -881,7 +886,7 @@ class AsyncDataSet: _limit = where.pop("_limit") except: pass - + where_clause, where_params = self._build_where(where) order_clause = f" ORDER BY {order_by}" if order_by else "" diff --git a/src/snek/system/middleware.py b/src/snek/system/middleware.py index db25327..f3a13d8 100644 --- a/src/snek/system/middleware.py +++ b/src/snek/system/middleware.py @@ -53,7 +53,7 @@ async def auth_middleware(request, handler): request["user"] = None if request.session.get("uid") and request.session.get("logged_in"): request["user"] = await request.app.services.user.get( - uid=request.app.session.get("uid") + uid=request.session.get("uid") ) return await handler(request) @@ -69,5 +69,5 @@ async def cors_middleware(request, handler): response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" response.headers["Access-Control-Allow-Headers"] = "*" response.headers["Access-Control-Allow-Credentials"] = "true" - + return response diff --git a/src/snek/system/service.py b/src/snek/system/service.py index b47dbc4..82c5d1c 100644 --- a/src/snek/system/service.py +++ b/src/snek/system/service.py @@ -36,12 +36,12 @@ class BaseService: return await self.mapper.new() async def query(self, sql, *args): - for record in self.app.db.query(sql, *args): + for record in await self.app.db.query(sql, *args): yield record async def get(self, *args, **kwargs): if not "deleted_at" in kwargs: - kwargs["deleted_at"] = None + kwargs["deleted_at"] = None uid = kwargs.get("uid") if args: uid = args[0] @@ -50,7 +50,7 @@ class BaseService: if result and result.__class__ == self.mapper.model_class: return result kwargs["uid"] = uid - print(kwargs,"ZZZZZZZ") + print(kwargs,"ZZZZZZZ") result = await self.mapper.get(**kwargs) if result: await self.cache.set(result["uid"], result)