Compare commits

...

3 Commits

Author SHA1 Message Date
BordedDev
51b4898078 Unformatted code 2025-08-08 18:45:42 +02:00
BordedDev
40f9646251 Fixed db related issues 2025-08-08 18:29:52 +02:00
5e99e894e9 Not working version, issues with get 2025-08-08 17:28:35 +02:00
11 changed files with 1271 additions and 52 deletions

View File

@ -40,7 +40,8 @@ dependencies = [
"pillow-heif",
"IP2Location",
"bleach",
"sentry-sdk"
"sentry-sdk",
"aiosqlite"
]
[tool.setuptools.packages.find]

View File

@ -9,7 +9,7 @@ from contextlib import asynccontextmanager
from snek import snode
from snek.view.threads import ThreadsView
from snek.system.ads import AsyncDataSet
logging.basicConfig(level=logging.DEBUG)
from concurrent.futures import ThreadPoolExecutor
from ipaddress import ip_address
@ -107,7 +107,7 @@ async def ip2location_middleware(request, handler):
user["city"]
if user["city"] != location.city:
user["country_long"] = location.country
user["country_short"] = locaion.country_short
user["country_short"] = location.country_short
user["city"] = location.city
user["region"] = location.region
user["latitude"] = location.latitude
@ -142,6 +142,7 @@ class Application(BaseApplication):
client_max_size=1024 * 1024 * 1024 * 5 * args,
**kwargs,
)
self.db = AsyncDataSet(kwargs["db_path"].replace("sqlite:///", ""))
session_setup(self, EncryptedCookieStorage(SESSION_KEY))
self.tasks = asyncio.Queue()
self._middlewares.append(session_middleware)
@ -164,7 +165,7 @@ class Application(BaseApplication):
self.mappers = get_mappers(app=self)
self.broadcast_service = None
self.user_availability_service_task = None
self.setup_router()
base_path = pathlib.Path(__file__).parent
self.ip2location = IP2Location.IP2Location(
@ -174,8 +175,8 @@ class Application(BaseApplication):
self.on_startup.append(self.prepare_asyncio)
self.on_startup.append(self.start_user_availability_service)
self.on_startup.append(self.start_ssh_server)
self.on_startup.append(self.prepare_database)
#self.on_startup.append(self.prepare_database)
async def prepare_stats(self, app):
app['stats'] = create_stats_structure()
print("Stats prepared", flush=True)
@ -242,21 +243,11 @@ class Application(BaseApplication):
except Exception as ex:
print(ex)
self.db.commit()
async def prepare_database(self, app):
self.db.query("PRAGMA journal_mode=WAL")
self.db.query("PRAGMA syncnorm=off")
try:
if not self.db["user"].has_index("username"):
self.db["user"].create_index("username", unique=True)
if not self.db["channel_member"].has_index(["channel_uid", "user_uid"]):
self.db["channel_member"].create_index(["channel_uid", "user_uid"])
if not self.db["channel_message"].has_index(["channel_uid", "user_uid"]):
self.db["channel_message"].create_index(["channel_uid", "user_uid"])
except:
pass
await self.db.query_raw("PRAGMA journal_mode=WAL")
await self.db.query_raw("PRAGMA syncnorm=off")
await self.services.drive.prepare_all()
self.loop.create_task(self.task_runner())
@ -421,11 +412,11 @@ class Application(BaseApplication):
self.jinja2_env.loader = await self.get_user_template_loader(
request.session.get("uid")
)
try:
context["nonce"] = request['csp_nonce']
except:
context['nonce'] = '?'
context['nonce'] = '?'
rendered = await super().render_template(template, request, context)
@ -460,7 +451,7 @@ class Application(BaseApplication):
async def get_user_template_loader(self, uid=None):
template_paths = []
for admin_uid in self.services.user.get_admin_uids():
for admin_uid in await self.services.user.get_admin_uids():
user_template_path = await self.services.user.get_template_path(admin_uid)
if user_template_path:
template_paths.append(user_template_path)
@ -472,7 +463,7 @@ class Application(BaseApplication):
template_paths.append(self.template_path)
return FileSystemLoader(template_paths)
@asynccontextmanager
async def no_save(self):
stats = {
@ -487,7 +478,7 @@ class Application(BaseApplication):
self.services.channel_message.mapper.save = patched_save
raised_exception = None
try:
yield
yield
except Exception as ex:
raised_exception = ex
finally:

View File

@ -6,11 +6,11 @@ class UserMapper(BaseMapper):
table_name = "user"
model_class = UserModel
def get_admin_uids(self):
async def get_admin_uids(self):
try:
return [
user["uid"]
for user in self.db.query(
for user in await self.db.query(
"SELECT uid FROM user WHERE is_admin = :is_admin",
{"is_admin": True},
)

View File

@ -29,12 +29,12 @@ class ChannelMessageService(BaseService):
)
if html != message["html"]:
print("Reredefined message", message["uid"])
except Exception as ex:
time.sleep(0.1)
print(ex, flush=True)
while True:
changed = 0
async for message in self.find(is_final=False):
@ -102,7 +102,7 @@ class ChannelMessageService(BaseService):
#if not message["html"].startswith("<chat-message"):
#message = await self.get(uid=message["uid"])
#await self.save(message)
return {
"uid": message["uid"],
"color": user["color"],
@ -156,22 +156,22 @@ class ChannelMessageService(BaseService):
elif page > 0:
async for model in self.query(
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size",
{
*{
"channel_uid": channel_uid,
"page_size": page_size,
"offset": offset,
"timestamp": timestamp,
},
}.values(),
):
results.append(model)
else:
async for model in self.query(
f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset",
{
*{
"channel_uid": channel_uid,
"page_size": page_size,
"offset": offset,
},
}.values(),
):
results.append(model)

View File

@ -101,6 +101,8 @@ class UserService(BaseService):
model.username.value = username
model.password.value = await security.hash(password)
if await self.save(model):
for x in range(10):
print("Jazeker!!!")
if model:
channel = await self.services.channel.ensure_public_channel(
model["uid"]

View File

@ -7,7 +7,8 @@ class UserPropertyService(BaseService):
mapper_name = "user_property"
async def set(self, user_uid, name, value):
self.mapper.db["user_property"].upsert(
self.mapper.db.upsert(
"user_property",
{
"user_uid": user_uid,
"name": name,

1212
src/snek/system/ads.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,7 @@
DEFAULT_LIMIT = 30
import asyncio
import typing
import traceback
from snek.system.model import BaseModel
@ -51,7 +51,9 @@ class BaseMapper:
kwargs["uid"] = uid
if not kwargs.get("deleted_at"):
kwargs["deleted_at"] = None
record = await self.run_in_executor(self.table.find_one, **kwargs)
#traceback.print_exc()
record = await self.db.get(self.table_name, kwargs)
if not record:
return None
record = dict(record)
@ -61,23 +63,29 @@ class BaseMapper:
return model
async def exists(self, **kwargs):
return await self.run_in_executor(self.table.exists, **kwargs)
return await self.db.count(self.table_name, kwargs)
#return await self.run_in_executor(self.table.exists, **kwargs)
async def count(self, **kwargs) -> int:
return await self.run_in_executor(self.table.count, **kwargs)
return await self.db.count(self.table_name,kwargs)
async def save(self, model: BaseModel) -> bool:
if not model.record.get("uid"):
raise Exception(f"Attempt to save without uid: {model.record}.")
model.updated_at.update()
return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
model.updated_at.update()
await self.upsert(model)
return model
#return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
async def find(self, **kwargs) -> typing.AsyncGenerator:
if not kwargs.get("_limit"):
kwargs["_limit"] = self.default_limit
if not kwargs.get("deleted_at"):
kwargs["deleted_at"] = None
for record in await self.run_in_executor(self.table.find, **kwargs):
for record in await self.db.find(self.table_name, kwargs):
model = await self.new()
for key, value in record.items():
model[key] = value
@ -88,21 +96,21 @@ class BaseMapper:
return "insert" in sql or "update" in sql or "delete" in sql
async def query(self, sql, *args):
for record in await self.run_in_executor(self.db.query, sql, *args, use_semaphore=await self._use_semaphore(sql)):
for record in await self.db.query(sql, *args):
yield dict(record)
async def update(self, model):
if not model["deleted_at"] is None:
raise Exception("Can't update deleted record.")
model.updated_at.update()
return await self.run_in_executor(self.table.update, model.record, ["uid"],use_semaphore=True)
return await self.db.update(self.table_name, model.record, {"uid": model["uid"]})
async def upsert(self, model):
model.updated_at.update()
return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
await self.db.upsert(self.table_name, model.record, {"uid": model["uid"]})
return model
async def delete(self, **kwargs) -> int:
if not kwargs or not isinstance(kwargs, dict):
raise Exception("Can't execute delete with no filter.")
kwargs["use_semaphore"] = True
return await self.run_in_executor(self.table.delete, **kwargs)
return await self.db.delete(self.table_name, kwargs)

View File

@ -53,7 +53,7 @@ async def auth_middleware(request, handler):
request["user"] = None
if request.session.get("uid") and request.session.get("logged_in"):
request["user"] = await request.app.services.user.get(
uid=request.app.session.get("uid")
uid=request.session.get("uid")
)
return await handler(request)
@ -69,5 +69,5 @@ async def cors_middleware(request, handler):
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Credentials"] = "true"
return response

View File

@ -36,12 +36,12 @@ class BaseService:
return await self.mapper.new()
async def query(self, sql, *args):
for record in self.app.db.query(sql, *args):
for record in await self.app.db.query(sql, *args):
yield record
async def get(self, *args, **kwargs):
if not "deleted_at" in kwargs:
kwargs["deleted_at"] = None
kwargs["deleted_at"] = None
uid = kwargs.get("uid")
if args:
uid = args[0]
@ -50,7 +50,7 @@ class BaseService:
if result and result.__class__ == self.mapper.model_class:
return result
kwargs["uid"] = uid
print(kwargs,"ZZZZZZZ")
result = await self.mapper.get(**kwargs)
if result:
await self.cache.set(result["uid"], result)

View File

@ -38,6 +38,10 @@ class WebView(BaseView):
channel = await self.services.channel.get(
uid=self.request.match_info.get("channel")
)
print(self.session.get("uid"),"ZZZZZZZZZZ")
qq = await self.services.user.get(uid=self.session.get("uid"))
print("GGGGGGGGGG",qq)
if not channel:
user = await self.services.user.get(
uid=self.request.match_info.get("channel")