From aec2da11f27a0e045ac03a30fb70298d400d30ad Mon Sep 17 00:00:00 2001 From: BordedDev <> Date: Sat, 31 May 2025 19:08:55 +0200 Subject: [PATCH] Implement push notification service and registration --- src/snek/app.py | 8 +- src/snek/mapper/__init__.py | 2 + src/snek/mapper/push.py | 7 ++ src/snek/model/__init__.py | 2 + src/snek/model/push_registration.py | 8 ++ src/snek/service/__init__.py | 3 + src/snek/service/notification.py | 13 +++ .../notification.py => service/push.py} | 81 ++++++++++++++-- src/snek/static/push.js | 2 + src/snek/view/push.py | 97 +++++++------------ 10 files changed, 153 insertions(+), 70 deletions(-) create mode 100644 src/snek/mapper/push.py create mode 100644 src/snek/model/push_registration.py rename src/snek/{system/notification.py => service/push.py} (70%) diff --git a/src/snek/app.py b/src/snek/app.py index cb6ca45..ee82144 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -1,6 +1,7 @@ import asyncio import logging import pathlib +import ssl import time import uuid from datetime import datetime @@ -23,7 +24,6 @@ from jinja2 import FileSystemLoader from snek.sssh import start_ssh_server -from snek.system.notification import get_notifications from snek.docs.app import Application as DocsApplication from snek.mapper import get_mappers from snek.service import get_services @@ -105,7 +105,7 @@ class Application(BaseApplication): self.ssh_host = "0.0.0.0" self.ssh_port = 2242 - get_notifications() + self.setup_router() self.ssh_server = None self.sync_service = None @@ -380,7 +380,9 @@ app = Application(db_path="sqlite:///snek.db") async def main(): - await web._run_app(app, port=8081, host="0.0.0.0") + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain('cert.pem', 'key.pem') + await web._run_app(app, port=8081, host="0.0.0.0", ssl_context=ssl_context) if __name__ == "__main__": diff --git a/src/snek/mapper/__init__.py b/src/snek/mapper/__init__.py index 69901ec..48a8a0e 100644 --- a/src/snek/mapper/__init__.py +++ b/src/snek/mapper/__init__.py @@ -9,6 +9,7 @@ from snek.mapper.notification import NotificationMapper from snek.mapper.user import UserMapper from snek.mapper.user_property import UserPropertyMapper from snek.mapper.repository import RepositoryMapper +from snek.mapper.push import PushMapper from snek.mapper.channel_attachment import ChannelAttachmentMapper from snek.mapper.container import ContainerMapper from snek.system.object import Object @@ -30,6 +31,7 @@ def get_mappers(app=None): "repository": RepositoryMapper(app=app), "channel_attachment": ChannelAttachmentMapper(app=app), "container": ContainerMapper(app=app), + "push": PushMapper(app=app), } ) diff --git a/src/snek/mapper/push.py b/src/snek/mapper/push.py new file mode 100644 index 0000000..2fe7c29 --- /dev/null +++ b/src/snek/mapper/push.py @@ -0,0 +1,7 @@ +from snek.model.push_registration import PushRegistrationModel +from snek.system.mapper import BaseMapper + + +class PushMapper(BaseMapper): + model_class = PushRegistrationModel + table_name = "push_registration" diff --git a/src/snek/model/__init__.py b/src/snek/model/__init__.py index dec05e8..e5c6054 100644 --- a/src/snek/model/__init__.py +++ b/src/snek/model/__init__.py @@ -12,6 +12,7 @@ from snek.model.user import UserModel from snek.model.user_property import UserPropertyModel from snek.model.repository import RepositoryModel from snek.model.channel_attachment import ChannelAttachmentModel +from snek.model.push_registration import PushRegistrationModel from snek.model.container import Container from snek.system.object import Object @@ -31,6 +32,7 @@ def get_models(): "repository": RepositoryModel, "channel_attachment": ChannelAttachmentModel, "container": Container, + "push_registration": PushRegistrationModel, } ) diff --git a/src/snek/model/push_registration.py b/src/snek/model/push_registration.py new file mode 100644 index 0000000..bb13be0 --- /dev/null +++ b/src/snek/model/push_registration.py @@ -0,0 +1,8 @@ +from snek.system.model import BaseModel, ModelField + + +class PushRegistrationModel(BaseModel): + user_uid = ModelField(name="user_uid", required=True) + endpoint = ModelField(name="endpoint", required=True) + key_auth = ModelField(name="key_auth", required=True) + key_p256dh = ModelField(name="key_p256dh", required=True) diff --git a/src/snek/service/__init__.py b/src/snek/service/__init__.py index 5669bdb..f1654fb 100644 --- a/src/snek/service/__init__.py +++ b/src/snek/service/__init__.py @@ -9,6 +9,7 @@ from snek.service.drive_item import DriveItemService from snek.service.notification import NotificationService from snek.service.socket import SocketService from snek.service.user import UserService +from snek.service.push import PushService from snek.service.user_property import UserPropertyService from snek.service.util import UtilService from snek.service.repository import RepositoryService @@ -17,6 +18,7 @@ from snek.service.container import ContainerService from snek.system.object import Object from snek.service.db import DBService + @functools.cache def get_services(app): return Object( @@ -36,6 +38,7 @@ def get_services(app): "db": DBService(app=app), "channel_attachment": ChannelAttachmentService(app=app), "container": ContainerService(app=app), + "push": PushService(app=app), } ) diff --git a/src/snek/service/notification.py b/src/snek/service/notification.py index a22e8ae..65d348f 100644 --- a/src/snek/service/notification.py +++ b/src/snek/service/notification.py @@ -62,4 +62,17 @@ class NotificationService(BaseService): except Exception: raise Exception(f"Failed to create notification: {model.errors}.") + try: + await self.app.services.push.notify_user( + user_uid=channel_member["user_uid"], + payload={ + "title": f"New message in {channel_member['label']}", + "message": f"{user['nick']}: {channel_message['message']}", + "icon": "/image/snek192.png", + "url": f"/channel/{channel_message['channel_uid']}.html", + }, + ) + except Exception as e: + print(f"Failed to send push notification:", e) + self.app.db.commit() diff --git a/src/snek/system/notification.py b/src/snek/service/push.py similarity index 70% rename from src/snek/system/notification.py rename to src/snek/service/push.py index ad578d6..ccba4af 100644 --- a/src/snek/system/notification.py +++ b/src/snek/service/push.py @@ -1,3 +1,7 @@ +import json + +import aiohttp +from snek.system.service import BaseService import random import time import base64 @@ -87,12 +91,15 @@ def _browser_base64(data): return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=") -class Notifications: +class PushService(BaseService): + mapper_name = "push" + private_key_pem = None public_key = None public_key_base64 = None - def __init__(self): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) ensure_certificates() private_key = serialization.load_pem_private_key( @@ -139,7 +146,9 @@ class Notifications: algorithm="ES256", ) - def create_notification_info_with_payload(self, endpoint: str, auth: str, p256dh: str, payload: str): + def create_notification_info_with_payload( + self, endpoint: str, auth: str, p256dh: str, payload: str + ): message_private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) message_public_key_bytes = message_private_key.public_key().public_bytes( @@ -196,7 +205,67 @@ class Notifications: "data": data, } + async def notify_user(self, user_uid: str, payload: dict): + async with aiohttp.ClientSession() as session: + async for subscription in self.find(user_uid=user_uid): + endpoint = subscription["endpoint"] + key_auth = subscription["key_auth"] + key_p256dh = subscription["key_p256dh"] -@cache -def get_notifications(): - return Notifications() + notification_info = self.create_notification_info_with_payload( + endpoint, key_auth, key_p256dh, json.dumps(payload) + ) + + headers = { + **notification_info["headers"], + "TTL": "60", + } + data = notification_info["data"] + + async with session.post( + endpoint, + headers=headers, + data=data, + ) as response: + if response.status == 201 or response.status == 200: + print( + f"Notification sent to user {user_uid} via endpoint {endpoint}" + ) + else: + print( + f"Failed to send notification to user {user_uid} via endpoint {endpoint}: {response.status}" + ) + else: + print(f"No push subscriptions found for user {user_uid}") + + async def register( + self, user_uid: str, endpoint: str, key_auth: str, key_p256dh: str + ): + if await self.exists( + user_uid=user_uid, + endpoint=endpoint, + key_auth=key_auth, + key_p256dh=key_p256dh, + ): + return + + model = await self.new() + model["user_uid"] = user_uid + model["endpoint"] = endpoint + model["key_auth"] = key_auth + model["key_p256dh"] = key_p256dh + + print( + f"Registering push subscription for user {user_uid} with endpoint {endpoint}" + ) + + if await self.save(model=model) and model: + print( + f"Push subscription registered for user {user_uid} with endpoint {endpoint}" + ) + + return model + + raise Exception( + f"Failed to register push subscription for user {user_uid} with endpoint {endpoint}" + ) \ No newline at end of file diff --git a/src/snek/static/push.js b/src/snek/static/push.js index 89aacb6..1652817 100644 --- a/src/snek/static/push.js +++ b/src/snek/static/push.js @@ -5,6 +5,8 @@ export const registerServiceWorker = async (silent = false) => { await serviceWorkerRegistration.update() + await navigator.serviceWorker.ready + const keyResponse = await fetch('/push.json') const keyData = await keyResponse.json() diff --git a/src/snek/view/push.py b/src/snek/view/push.py index 15b082d..9129ba6 100644 --- a/src/snek/view/push.py +++ b/src/snek/view/push.py @@ -30,18 +30,15 @@ import json import requests from cryptography.hazmat.primitives import serialization -from snek.system.notification import get_notifications from snek.system.view import BaseFormView class PushView(BaseFormView): async def get(self): - notifications = get_notifications() - return await self.json_response( { "publicKey": base64.b64encode( - notifications.public_key.public_bytes( + self.app.services.push.public_key.public_bytes( encoding=serialization.Encoding.X962, format=serialization.PublicFormat.UncompressedPoint, ) @@ -76,69 +73,47 @@ class PushView(BaseFormView): {"error": "Invalid request"}, status=400 ) - print(body) - notifications = get_notifications() - - test_payload = { - "title": "Hey retoor", - "message": "Guess what? ;P", - "icon": "/image/snek192.png", - "url": "/web.html", - } - - notification_info = notifications.create_notification_info_with_payload( - body["endpoint"], - body["keys"]["auth"], - body["keys"]["p256dh"], - json.dumps(test_payload), + regist = await self.app.services.push.register( + user_uid=user_id, + endpoint=body["endpoint"], + key_auth=body["keys"]["auth"], + key_p256dh=body["keys"]["p256dh"], ) - headers = { - **notification_info["headers"], - "TTL": "60", - } + if regist: + test_payload = { + "title": f"Welcome {user['nick']}!", + "message": "You'll now receive notifications from Snek :D", + "icon": "/image/snek192.png", + "url": "/web.html", + } - print(headers) - - post_notification = requests.post( - body["endpoint"], headers=headers, data=notification_info["data"] - ) - - print(post_notification.status_code) - print(post_notification.text) - print(post_notification.headers) - - async for model in self.app.services.channel_member.find( - user_uid=user_id, deleted_at=None, is_banned=False - ): - channel = await self.app.services.channel.get(uid=model["channel_uid"]) - memberships.append( - { - "name": channel["label"], - "description": model["description"], - "user_uid": model["user_uid"], - "is_moderator": model["is_moderator"], - "is_read_only": model["is_read_only"], - "is_muted": model["is_muted"], - "is_banned": model["is_banned"], - "channel_uid": model["channel_uid"], - "uid": model["uid"], - } + notification_info = ( + self.app.services.push.create_notification_info_with_payload( + body["endpoint"], + body["keys"]["auth"], + body["keys"]["p256dh"], + json.dumps(test_payload), + ) ) - user = { - "username": user["username"], - "email": user["email"], - "nick": user["nick"], - "uid": user["uid"], - "color": user["color"], - "memberships": memberships, - } + + headers = { + **notification_info["headers"], + "TTL": "60", + } + + print(headers) + + post_notification = requests.post( + body["endpoint"], headers=headers, data=notification_info["data"] + ) + + print(post_notification.status_code) + print(post_notification.text) + print(post_notification.headers) return await self.json_response( { - "user": user, - "cache": await self.app.cache.create_cache_key( - self.app.cache.cache, None - ), + "registered": True, } ) \ No newline at end of file