Merge pull request 'Fixed db related issues' (#69) from BordedDev/snek:bugfix/retoors-new-db into main
Reviewed-on: #69
This commit is contained in:
commit
30382b139d
@ -107,7 +107,7 @@ async def ip2location_middleware(request, handler):
|
||||
user["city"]
|
||||
if user["city"] != location.city:
|
||||
user["country_long"] = location.country
|
||||
user["country_short"] = locaion.country_short
|
||||
user["country_short"] = location.country_short
|
||||
user["city"] = location.city
|
||||
user["region"] = location.region
|
||||
user["latitude"] = location.latitude
|
||||
@ -165,7 +165,7 @@ class Application(BaseApplication):
|
||||
self.mappers = get_mappers(app=self)
|
||||
self.broadcast_service = None
|
||||
self.user_availability_service_task = None
|
||||
|
||||
|
||||
self.setup_router()
|
||||
base_path = pathlib.Path(__file__).parent
|
||||
self.ip2location = IP2Location.IP2Location(
|
||||
@ -176,7 +176,7 @@ class Application(BaseApplication):
|
||||
self.on_startup.append(self.start_user_availability_service)
|
||||
self.on_startup.append(self.start_ssh_server)
|
||||
#self.on_startup.append(self.prepare_database)
|
||||
|
||||
|
||||
async def prepare_stats(self, app):
|
||||
app['stats'] = create_stats_structure()
|
||||
print("Stats prepared", flush=True)
|
||||
@ -243,7 +243,7 @@ class Application(BaseApplication):
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
self.db.commit()
|
||||
|
||||
|
||||
|
||||
async def prepare_database(self, app):
|
||||
await self.db.query_raw("PRAGMA journal_mode=WAL")
|
||||
@ -412,11 +412,11 @@ class Application(BaseApplication):
|
||||
self.jinja2_env.loader = await self.get_user_template_loader(
|
||||
request.session.get("uid")
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
context["nonce"] = request['csp_nonce']
|
||||
except:
|
||||
context['nonce'] = '?'
|
||||
context['nonce'] = '?'
|
||||
|
||||
rendered = await super().render_template(template, request, context)
|
||||
|
||||
@ -451,7 +451,7 @@ class Application(BaseApplication):
|
||||
|
||||
async def get_user_template_loader(self, uid=None):
|
||||
template_paths = []
|
||||
for admin_uid in self.services.user.get_admin_uids():
|
||||
for admin_uid in await self.services.user.get_admin_uids():
|
||||
user_template_path = await self.services.user.get_template_path(admin_uid)
|
||||
if user_template_path:
|
||||
template_paths.append(user_template_path)
|
||||
@ -463,7 +463,7 @@ class Application(BaseApplication):
|
||||
|
||||
template_paths.append(self.template_path)
|
||||
return FileSystemLoader(template_paths)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def no_save(self):
|
||||
stats = {
|
||||
@ -478,7 +478,7 @@ class Application(BaseApplication):
|
||||
self.services.channel_message.mapper.save = patched_save
|
||||
raised_exception = None
|
||||
try:
|
||||
yield
|
||||
yield
|
||||
except Exception as ex:
|
||||
raised_exception = ex
|
||||
finally:
|
||||
|
@ -6,11 +6,11 @@ class UserMapper(BaseMapper):
|
||||
table_name = "user"
|
||||
model_class = UserModel
|
||||
|
||||
def get_admin_uids(self):
|
||||
async def get_admin_uids(self):
|
||||
try:
|
||||
return [
|
||||
user["uid"]
|
||||
for user in self.db.query(
|
||||
for user in await self.db.query(
|
||||
"SELECT uid FROM user WHERE is_admin = :is_admin",
|
||||
{"is_admin": True},
|
||||
)
|
||||
|
@ -29,12 +29,12 @@ class ChannelMessageService(BaseService):
|
||||
)
|
||||
if html != message["html"]:
|
||||
print("Reredefined message", message["uid"])
|
||||
|
||||
|
||||
except Exception as ex:
|
||||
time.sleep(0.1)
|
||||
print(ex, flush=True)
|
||||
|
||||
|
||||
|
||||
|
||||
while True:
|
||||
changed = 0
|
||||
async for message in self.find(is_final=False):
|
||||
@ -102,7 +102,7 @@ class ChannelMessageService(BaseService):
|
||||
#if not message["html"].startswith("<chat-message"):
|
||||
#message = await self.get(uid=message["uid"])
|
||||
#await self.save(message)
|
||||
|
||||
|
||||
return {
|
||||
"uid": message["uid"],
|
||||
"color": user["color"],
|
||||
@ -156,22 +156,22 @@ class ChannelMessageService(BaseService):
|
||||
elif page > 0:
|
||||
async for model in self.query(
|
||||
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size",
|
||||
{
|
||||
*{
|
||||
"channel_uid": channel_uid,
|
||||
"page_size": page_size,
|
||||
"offset": offset,
|
||||
"timestamp": timestamp,
|
||||
},
|
||||
}.values(),
|
||||
):
|
||||
results.append(model)
|
||||
else:
|
||||
async for model in self.query(
|
||||
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset",
|
||||
{
|
||||
*{
|
||||
"channel_uid": channel_uid,
|
||||
"page_size": page_size,
|
||||
"offset": offset,
|
||||
},
|
||||
}.values(),
|
||||
):
|
||||
results.append(model)
|
||||
|
||||
|
@ -745,8 +745,13 @@ class AsyncDataSet:
|
||||
def _build_where(where: Optional[Dict[str, Any]]) -> tuple[str, List[Any]]:
|
||||
if not where:
|
||||
return "", []
|
||||
clauses, vals = zip(*[(f"`{k}` = ?", v) for k, v in where.items()])
|
||||
return " WHERE " + " AND ".join(clauses), list(vals)
|
||||
clauses, vals = zip(
|
||||
*[
|
||||
(f"`{k}` = ?" if v is not None else f"`{k}` IS NULL", v)
|
||||
for k, v in where.items()
|
||||
]
|
||||
)
|
||||
return " WHERE " + " AND ".join(clauses), list(v for v in vals if v is not None)
|
||||
|
||||
async def _server_insert(
|
||||
self, table: str, args: Dict[str, Any], return_id: bool = False
|
||||
@ -881,7 +886,7 @@ class AsyncDataSet:
|
||||
_limit = where.pop("_limit")
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
where_clause, where_params = self._build_where(where)
|
||||
order_clause = f" ORDER BY {order_by}" if order_by else ""
|
||||
|
@ -53,7 +53,7 @@ async def auth_middleware(request, handler):
|
||||
request["user"] = None
|
||||
if request.session.get("uid") and request.session.get("logged_in"):
|
||||
request["user"] = await request.app.services.user.get(
|
||||
uid=request.app.session.get("uid")
|
||||
uid=request.session.get("uid")
|
||||
)
|
||||
return await handler(request)
|
||||
|
||||
@ -69,5 +69,5 @@ async def cors_middleware(request, handler):
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
|
||||
return response
|
||||
|
@ -36,12 +36,12 @@ class BaseService:
|
||||
return await self.mapper.new()
|
||||
|
||||
async def query(self, sql, *args):
|
||||
for record in self.app.db.query(sql, *args):
|
||||
for record in await self.app.db.query(sql, *args):
|
||||
yield record
|
||||
|
||||
async def get(self, *args, **kwargs):
|
||||
if not "deleted_at" in kwargs:
|
||||
kwargs["deleted_at"] = None
|
||||
kwargs["deleted_at"] = None
|
||||
uid = kwargs.get("uid")
|
||||
if args:
|
||||
uid = args[0]
|
||||
@ -50,7 +50,7 @@ class BaseService:
|
||||
if result and result.__class__ == self.mapper.model_class:
|
||||
return result
|
||||
kwargs["uid"] = uid
|
||||
print(kwargs,"ZZZZZZZ")
|
||||
print(kwargs,"ZZZZZZZ")
|
||||
result = await self.mapper.get(**kwargs)
|
||||
if result:
|
||||
await self.cache.set(result["uid"], result)
|
||||
|
Loading…
Reference in New Issue
Block a user