diff --git a/src/snek/service/__init__.py b/src/snek/service/__init__.py index be356dc..a81b9e7 100644 --- a/src/snek/service/__init__.py +++ b/src/snek/service/__init__.py @@ -13,7 +13,7 @@ from snek.service.user_property import UserPropertyService from snek.service.util import UtilService from snek.service.repository import RepositoryService from snek.system.object import Object - +from snek.service.db import DBService @functools.cache def get_services(app): @@ -31,6 +31,7 @@ def get_services(app): "drive_item": DriveItemService(app=app), "user_property": UserPropertyService(app=app), "repository": RepositoryService(app=app), + "db": DBService(app=app), } ) diff --git a/src/snek/service/db.py b/src/snek/service/db.py new file mode 100644 index 0000000..fe9fb8a --- /dev/null +++ b/src/snek/service/db.py @@ -0,0 +1,71 @@ +from snek.system.service import BaseService +import dataset +import uuid + +from datetime import datetime + +class DBService(BaseService): + + async def get_db(self, user_uid): + + home_folder = await self.app.services.user.get_home_folder(user_uid) + home_folder.mkdir(parents=True, exist_ok=True) + db_path = home_folder.joinpath("snek/user.db") + db_path.parent.mkdir(parents=True, exist_ok=True) + return dataset.connect("sqlite:///" + str(db_path)) + + async def insert(self, user_uid, table_name, values): + db = await self.get_db(user_uid) + return db[table_name].insert(values) + + + async def update(self, user_uid, table_name, values, filters): + db = await self.get_db(user_uid) + if not filters: + filters = {} + if not values: + return False + return db[table_name].update(values, filters) + + async def upsert(self, user_uid, table_name, values, keys): + db = await self.get_db(user_uid) + return db[table_name].upsert(values, keys) + + async def find(self, user_uid, table_name, kwargs): + db = await self.get_db(user_uid) + kwargs['_limit'] = kwargs.get('_limit', 30) + return [dict(row) for row in db[table_name].find(**kwargs)] + + async def get(self, user_uid, table_name, filters): + db = await self.get_db(user_uid) + if not filters: + filters = {} + try: + return dict(db[table_name].find_one(**filters)) + except ValueError: + return None + + + async def delete(self, user_uid, table_name, filters): + db = await self.get_db(user_uid) + if not filters: + filters = {} + return db[table_name].delete(**filters) + + async def query(self, sql,values): + db = await self.app.db + return [dict(row) for row in db.query(sql, values or {})] + + async def exists(self, user_uid, table_name, filters): + db = await self.get_db(user_uid) + if not filters: + filters = {} + return bool(db[table_name].find_one(**filters)) + + + + async def count(self, user_uid, table_name, filters): + db = await self.get_db(user_uid) + if not filters: + filters = {} + return db[table_name].count(**filters) diff --git a/src/snek/view/rpc.py b/src/snek/view/rpc.py index 3161f49..4592f21 100644 --- a/src/snek/view/rpc.py +++ b/src/snek/view/rpc.py @@ -26,6 +26,31 @@ class RPCView(BaseView): self.services = self.app.services self.ws = ws + async def db_insert(self, table_name, record): + self._require_login() + + return await self.services.db.insert(self.user_uid, table_name, record) + async def db_update(self, table_name, record): + self._require_login() + return await self.services.db.update(self.user_uid, table_name, record) + async def db_delete(self, table_name, record): + self._require_login() + return await self.services.db.delete(self.user_uid, table_name, record) + async def db_get(self, table_name, record): + self._require_login() + return await self.services.db.get(self.user_uid, table_name, record) + async def db_find(self, table_name, record): + self._require_login() + return await self.services.db.find(self.user_uid, table_name, record) + async def db_upsert(self, table_name, record,keys): + self._require_login() + return await self.services.db.upsert(self.user_uid, table_name, record,keys) + + async def db_query(self, table_name, args): + self._require_login() + return await self.services.db.query(self.user_uid, table_name, sql, args) + + @property def user_uid(self): return self.view.session.get("uid")