Implement push notification service and registration

This commit is contained in:
BordedDev 2025-05-31 19:08:55 +02:00
parent 272998f757
commit aec2da11f2
No known key found for this signature in database
GPG Key ID: C5F495EAE56673BF
10 changed files with 153 additions and 70 deletions

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import logging import logging
import pathlib import pathlib
import ssl
import time import time
import uuid import uuid
from datetime import datetime from datetime import datetime
@ -23,7 +24,6 @@ from jinja2 import FileSystemLoader
from snek.sssh import start_ssh_server from snek.sssh import start_ssh_server
from snek.system.notification import get_notifications
from snek.docs.app import Application as DocsApplication from snek.docs.app import Application as DocsApplication
from snek.mapper import get_mappers from snek.mapper import get_mappers
from snek.service import get_services from snek.service import get_services
@ -105,7 +105,7 @@ class Application(BaseApplication):
self.ssh_host = "0.0.0.0" self.ssh_host = "0.0.0.0"
self.ssh_port = 2242 self.ssh_port = 2242
get_notifications()
self.setup_router() self.setup_router()
self.ssh_server = None self.ssh_server = None
self.sync_service = None self.sync_service = None
@ -380,7 +380,9 @@ app = Application(db_path="sqlite:///snek.db")
async def main(): async def main():
await web._run_app(app, port=8081, host="0.0.0.0") ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain('cert.pem', 'key.pem')
await web._run_app(app, port=8081, host="0.0.0.0", ssl_context=ssl_context)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -9,6 +9,7 @@ from snek.mapper.notification import NotificationMapper
from snek.mapper.user import UserMapper from snek.mapper.user import UserMapper
from snek.mapper.user_property import UserPropertyMapper from snek.mapper.user_property import UserPropertyMapper
from snek.mapper.repository import RepositoryMapper from snek.mapper.repository import RepositoryMapper
from snek.mapper.push import PushMapper
from snek.mapper.channel_attachment import ChannelAttachmentMapper from snek.mapper.channel_attachment import ChannelAttachmentMapper
from snek.mapper.container import ContainerMapper from snek.mapper.container import ContainerMapper
from snek.system.object import Object from snek.system.object import Object
@ -30,6 +31,7 @@ def get_mappers(app=None):
"repository": RepositoryMapper(app=app), "repository": RepositoryMapper(app=app),
"channel_attachment": ChannelAttachmentMapper(app=app), "channel_attachment": ChannelAttachmentMapper(app=app),
"container": ContainerMapper(app=app), "container": ContainerMapper(app=app),
"push": PushMapper(app=app),
} }
) )

7
src/snek/mapper/push.py Normal file
View File

@ -0,0 +1,7 @@
from snek.model.push_registration import PushRegistrationModel
from snek.system.mapper import BaseMapper
class PushMapper(BaseMapper):
model_class = PushRegistrationModel
table_name = "push_registration"

View File

@ -12,6 +12,7 @@ from snek.model.user import UserModel
from snek.model.user_property import UserPropertyModel from snek.model.user_property import UserPropertyModel
from snek.model.repository import RepositoryModel from snek.model.repository import RepositoryModel
from snek.model.channel_attachment import ChannelAttachmentModel from snek.model.channel_attachment import ChannelAttachmentModel
from snek.model.push_registration import PushRegistrationModel
from snek.model.container import Container from snek.model.container import Container
from snek.system.object import Object from snek.system.object import Object
@ -31,6 +32,7 @@ def get_models():
"repository": RepositoryModel, "repository": RepositoryModel,
"channel_attachment": ChannelAttachmentModel, "channel_attachment": ChannelAttachmentModel,
"container": Container, "container": Container,
"push_registration": PushRegistrationModel,
} }
) )

View File

@ -0,0 +1,8 @@
from snek.system.model import BaseModel, ModelField
class PushRegistrationModel(BaseModel):
user_uid = ModelField(name="user_uid", required=True)
endpoint = ModelField(name="endpoint", required=True)
key_auth = ModelField(name="key_auth", required=True)
key_p256dh = ModelField(name="key_p256dh", required=True)

View File

@ -9,6 +9,7 @@ from snek.service.drive_item import DriveItemService
from snek.service.notification import NotificationService from snek.service.notification import NotificationService
from snek.service.socket import SocketService from snek.service.socket import SocketService
from snek.service.user import UserService from snek.service.user import UserService
from snek.service.push import PushService
from snek.service.user_property import UserPropertyService from snek.service.user_property import UserPropertyService
from snek.service.util import UtilService from snek.service.util import UtilService
from snek.service.repository import RepositoryService from snek.service.repository import RepositoryService
@ -17,6 +18,7 @@ from snek.service.container import ContainerService
from snek.system.object import Object from snek.system.object import Object
from snek.service.db import DBService from snek.service.db import DBService
@functools.cache @functools.cache
def get_services(app): def get_services(app):
return Object( return Object(
@ -36,6 +38,7 @@ def get_services(app):
"db": DBService(app=app), "db": DBService(app=app),
"channel_attachment": ChannelAttachmentService(app=app), "channel_attachment": ChannelAttachmentService(app=app),
"container": ContainerService(app=app), "container": ContainerService(app=app),
"push": PushService(app=app),
} }
) )

View File

@ -62,4 +62,17 @@ class NotificationService(BaseService):
except Exception: except Exception:
raise Exception(f"Failed to create notification: {model.errors}.") raise Exception(f"Failed to create notification: {model.errors}.")
try:
await self.app.services.push.notify_user(
user_uid=channel_member["user_uid"],
payload={
"title": f"New message in {channel_member['label']}",
"message": f"{user['nick']}: {channel_message['message']}",
"icon": "/image/snek192.png",
"url": f"/channel/{channel_message['channel_uid']}.html",
},
)
except Exception as e:
print(f"Failed to send push notification:", e)
self.app.db.commit() self.app.db.commit()

View File

@ -1,3 +1,7 @@
import json
import aiohttp
from snek.system.service import BaseService
import random import random
import time import time
import base64 import base64
@ -87,12 +91,15 @@ def _browser_base64(data):
return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=") return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")
class Notifications: class PushService(BaseService):
mapper_name = "push"
private_key_pem = None private_key_pem = None
public_key = None public_key = None
public_key_base64 = None public_key_base64 = None
def __init__(self): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
ensure_certificates() ensure_certificates()
private_key = serialization.load_pem_private_key( private_key = serialization.load_pem_private_key(
@ -139,7 +146,9 @@ class Notifications:
algorithm="ES256", algorithm="ES256",
) )
def create_notification_info_with_payload(self, endpoint: str, auth: str, p256dh: str, payload: str): def create_notification_info_with_payload(
self, endpoint: str, auth: str, p256dh: str, payload: str
):
message_private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) message_private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
message_public_key_bytes = message_private_key.public_key().public_bytes( message_public_key_bytes = message_private_key.public_key().public_bytes(
@ -196,7 +205,67 @@ class Notifications:
"data": data, "data": data,
} }
async def notify_user(self, user_uid: str, payload: dict):
async with aiohttp.ClientSession() as session:
async for subscription in self.find(user_uid=user_uid):
endpoint = subscription["endpoint"]
key_auth = subscription["key_auth"]
key_p256dh = subscription["key_p256dh"]
@cache notification_info = self.create_notification_info_with_payload(
def get_notifications(): endpoint, key_auth, key_p256dh, json.dumps(payload)
return Notifications() )
headers = {
**notification_info["headers"],
"TTL": "60",
}
data = notification_info["data"]
async with session.post(
endpoint,
headers=headers,
data=data,
) as response:
if response.status == 201 or response.status == 200:
print(
f"Notification sent to user {user_uid} via endpoint {endpoint}"
)
else:
print(
f"Failed to send notification to user {user_uid} via endpoint {endpoint}: {response.status}"
)
else:
print(f"No push subscriptions found for user {user_uid}")
async def register(
self, user_uid: str, endpoint: str, key_auth: str, key_p256dh: str
):
if await self.exists(
user_uid=user_uid,
endpoint=endpoint,
key_auth=key_auth,
key_p256dh=key_p256dh,
):
return
model = await self.new()
model["user_uid"] = user_uid
model["endpoint"] = endpoint
model["key_auth"] = key_auth
model["key_p256dh"] = key_p256dh
print(
f"Registering push subscription for user {user_uid} with endpoint {endpoint}"
)
if await self.save(model=model) and model:
print(
f"Push subscription registered for user {user_uid} with endpoint {endpoint}"
)
return model
raise Exception(
f"Failed to register push subscription for user {user_uid} with endpoint {endpoint}"
)

View File

@ -5,6 +5,8 @@ export const registerServiceWorker = async (silent = false) => {
await serviceWorkerRegistration.update() await serviceWorkerRegistration.update()
await navigator.serviceWorker.ready
const keyResponse = await fetch('/push.json') const keyResponse = await fetch('/push.json')
const keyData = await keyResponse.json() const keyData = await keyResponse.json()

View File

@ -30,18 +30,15 @@ import json
import requests import requests
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from snek.system.notification import get_notifications
from snek.system.view import BaseFormView from snek.system.view import BaseFormView
class PushView(BaseFormView): class PushView(BaseFormView):
async def get(self): async def get(self):
notifications = get_notifications()
return await self.json_response( return await self.json_response(
{ {
"publicKey": base64.b64encode( "publicKey": base64.b64encode(
notifications.public_key.public_bytes( self.app.services.push.public_key.public_bytes(
encoding=serialization.Encoding.X962, encoding=serialization.Encoding.X962,
format=serialization.PublicFormat.UncompressedPoint, format=serialization.PublicFormat.UncompressedPoint,
) )
@ -76,69 +73,47 @@ class PushView(BaseFormView):
{"error": "Invalid request"}, status=400 {"error": "Invalid request"}, status=400
) )
print(body) regist = await self.app.services.push.register(
notifications = get_notifications() user_uid=user_id,
endpoint=body["endpoint"],
test_payload = { key_auth=body["keys"]["auth"],
"title": "Hey retoor", key_p256dh=body["keys"]["p256dh"],
"message": "Guess what? ;P",
"icon": "/image/snek192.png",
"url": "/web.html",
}
notification_info = notifications.create_notification_info_with_payload(
body["endpoint"],
body["keys"]["auth"],
body["keys"]["p256dh"],
json.dumps(test_payload),
) )
headers = { if regist:
**notification_info["headers"], test_payload = {
"TTL": "60", "title": f"Welcome {user['nick']}!",
} "message": "You'll now receive notifications from Snek :D",
"icon": "/image/snek192.png",
"url": "/web.html",
}
print(headers) notification_info = (
self.app.services.push.create_notification_info_with_payload(
post_notification = requests.post( body["endpoint"],
body["endpoint"], headers=headers, data=notification_info["data"] body["keys"]["auth"],
) body["keys"]["p256dh"],
json.dumps(test_payload),
print(post_notification.status_code) )
print(post_notification.text)
print(post_notification.headers)
async for model in self.app.services.channel_member.find(
user_uid=user_id, deleted_at=None, is_banned=False
):
channel = await self.app.services.channel.get(uid=model["channel_uid"])
memberships.append(
{
"name": channel["label"],
"description": model["description"],
"user_uid": model["user_uid"],
"is_moderator": model["is_moderator"],
"is_read_only": model["is_read_only"],
"is_muted": model["is_muted"],
"is_banned": model["is_banned"],
"channel_uid": model["channel_uid"],
"uid": model["uid"],
}
) )
user = {
"username": user["username"], headers = {
"email": user["email"], **notification_info["headers"],
"nick": user["nick"], "TTL": "60",
"uid": user["uid"], }
"color": user["color"],
"memberships": memberships, print(headers)
}
post_notification = requests.post(
body["endpoint"], headers=headers, data=notification_info["data"]
)
print(post_notification.status_code)
print(post_notification.text)
print(post_notification.headers)
return await self.json_response( return await self.json_response(
{ {
"user": user, "registered": True,
"cache": await self.app.cache.create_cache_key(
self.app.cache.cache, None
),
} }
) )