diff --git a/cert.pem b/cert.pem new file mode 100644 index 0000000..677cf5f --- /dev/null +++ b/cert.pem @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUB7PQvHZD6v8hfxeaDbU3hC0nGQQwDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCTkwxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNTA0MDYxOTUzMDhaFw0yNjA0 +MDYxOTUzMDhaMEUxCzAJBgNVBAYTAk5MMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCtYf8PP7QjRJOfK6zmfAZhSKwMowCSYijKeChxsgyn +hDDE8A/OuOuluJh6M/X+ZH0Q4HWTAaTwrXesBBPhie+4KmtsykiI7QEHXVVrWHba +6t5ymKiFiu+rWMwJVznS7T8K+DPGLRO2bF71Fme4ofJ2Plb7PnF53R4Tc3aTMdIW +HrUsU1JMNmbCibSVlkfPXSg/HY3XLysCrtrldPHYbTGvBcDUil7qZ8hZ8ZxLMzu3 +GPo6awPc0RBqw3tZu6SCECwQJEM0gX2n5nSyVz+fVgvLozNL9kV89hbZo7H/M37O +zmxVNwsAwoHpAGmnYs3ZYt4Q8duYjF1AtgZyXgXgdMdLAgMBAAGjUzBRMB0GA1Ud +DgQWBBQtGeiVTYjzWb2hTqJwipRVXU1LnzAfBgNVHSMEGDAWgBQtGeiVTYjzWb2h +TqJwipRVXU1LnzAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAc +1BacrGMlCd5nfYuvQfv0DdTVGc2FSqxPMRGrZKfjvjemgPMs0+DqUwCJiR6oEOGb +atOYoIBX9KGXSUKRYYc/N75bslwfV1CclNqd2mPxULfks/D8cAzf2mgw4kYSaDHs +tJkywBe9L6eIK4cQ5YJvutVNVKMYPi+9w+wKog/FafkamFfX/3SLCkGmV0Vv4g0q +Ro9KmTTQpJUvd63X8bONLs1t8p+HQfWmKlhuVn5+mncNdGREe8dbciXE5FKu8luN +dr/twoTZTPhmIHPmVEeNxS8hFSiu0iUPTO0HcCAODILGbtXbClA+1Z0ukiRfUya6 +tgVuEk0c64L86qGP7Ply +-----END CERTIFICATE----- diff --git a/key.pem b/key.pem new file mode 100644 index 0000000..01682ba --- /dev/null +++ b/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCtYf8PP7QjRJOf +K6zmfAZhSKwMowCSYijKeChxsgynhDDE8A/OuOuluJh6M/X+ZH0Q4HWTAaTwrXes +BBPhie+4KmtsykiI7QEHXVVrWHba6t5ymKiFiu+rWMwJVznS7T8K+DPGLRO2bF71 +Fme4ofJ2Plb7PnF53R4Tc3aTMdIWHrUsU1JMNmbCibSVlkfPXSg/HY3XLysCrtrl +dPHYbTGvBcDUil7qZ8hZ8ZxLMzu3GPo6awPc0RBqw3tZu6SCECwQJEM0gX2n5nSy +Vz+fVgvLozNL9kV89hbZo7H/M37OzmxVNwsAwoHpAGmnYs3ZYt4Q8duYjF1AtgZy +XgXgdMdLAgMBAAECggEAFnbkqz8fweoNY8mEOiDGWth695rZuh20bKIA63+cRXV1 +NC8T0pRXGT5qUyW5sQpSwgWzINGiY09hJWJ/M5vBpDpVd4pbYj0DAxyZXV01mSER +TVvGNKH5x65WUWeB0Hh40J0JaEXy5edIrmIGx6oEAO9hfxAUzStUeES05QFxgk1Q +RI4rKgvVt4W4wEGqSqX7OMwSU1EHJkX+IKYUdXvFA4Gi192mHHhX9MMDK/RSaDOC +1ZzHzHeKoTlf4jaUcwATlibo8ExGu4wsY+y3+NKE15o6D36AZD7ObqDOF1RsyfGG +eyljXzcglZAJN9Ctrz0xj5Xt22HqwsPO0o0mJ7URYQKBgQDcUWiu2acJJyjEJ89F +aiw3z5RvyO9LksHXwkf6gAV+dro/JeUf7u9Qgz3bwnoqwL16u+vjZxrtcpzkjc2C ++DIr6spCf8XkneJ2FovrFDe6oJSFxbgeexkQEBgw0TskRKILN8PGS6FAOfe8Zkwz +OHAJOYjxoVVoSeDPnxdu6uwJSQKBgQDJdpwZrtjGKSxkcMJzUlmp3XAPdlI1hZkl +v56Sdj6+Wz9bNTFlgiPHS+4Z7M+LyotShOEqwMfe+MDqVxTIB9TWfnmvnFDxI1VB +orHogWVWMHOqPJAzGrrWgbG2CSIiwQ3WFxU1nXqAeNk9aIFidGco87l3lVb4XEZs +eoUOUic/8wKBgQCK6r3x+gULjWhz/pH/t8l3y2hR78WKxld5XuQZvB06t0wKQy+s +qfC1uHsJlR+I04zl1ZYQBdQBwlHQ/uSFX0/rRxkPQxeZZkADq4W/zTiycUwU6S2F +8qJD8ZH/Pf5niOsP3bKQ1uEu6R4e6fXEGiLyfheuG8cJggPBhhO1eWUpGQKBgQDC +L+OzFce46gLyJYYopl3qz5iuLrx6/nVp31O3lOZRkZ52CcW9ND3MYjH1Jz++XNMC +DTcEgKGnGFrLBnjvfiz3Ox2L2b5jUE1jYLDfjanh8/3pP0s3FzK0hHqJHjCbEz6E +9+bnsQ1dPB8Zg9wCzHSLErHYxEf6SOdQtJ//98wBZQKBgQDLON5QPUAJ21uZRvwv +9LsjKMpd5f/L6/q5j6YYXNpys5MREUgryDpR/uqcmyBuxCU3vBeK8tpYJzfXqO45 +5jFoiKhtEFXjb1+d18ACKg1gXQF0Ljry59HGiZOw7IubRPHh9CDdT5tzynylipr3 +xhhX7RsDOYMFKmn59DS1CQCZAA== +-----END PRIVATE KEY----- diff --git a/src/snek/app.py b/src/snek/app.py index 6a79592..03ef491 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -1,14 +1,14 @@ import asyncio import logging import pathlib -import time +import ssl import uuid from datetime import datetime -from snek import snode +from snek import snode from snek.view.threads import ThreadsView -import json + logging.basicConfig(level=logging.DEBUG) -from ipaddress import ip_address +from ipaddress import ip_address from concurrent.futures import ThreadPoolExecutor from aiohttp import web @@ -22,6 +22,7 @@ from app.app import Application as BaseApplication from jinja2 import FileSystemLoader import IP2Location from snek.sssh import start_ssh_server + from snek.docs.app import Application as DocsApplication from snek.mapper import get_mappers from snek.service import get_services @@ -38,6 +39,7 @@ from snek.view.drive import DriveView from snek.view.drive import DriveApiView from snek.view.index import IndexView from snek.view.login import LoginView +from snek.view.push import PushView from snek.view.logout import LogoutView from snek.view.register import RegisterView from snek.view.rpc import RPCView @@ -57,10 +59,16 @@ from snek.view.user import UserView from snek.view.web import WebView from snek.view.channel import ChannelAttachmentView from snek.view.channel import ChannelView -from snek.view.settings.containers import ContainersIndexView, ContainersCreateView, ContainersUpdateView, ContainersDeleteView +from snek.view.settings.containers import ( + ContainersIndexView, + ContainersCreateView, + ContainersUpdateView, + ContainersDeleteView, +) from snek.webdav import WebdavApplication from snek.system.template import sanitize_html from snek.sgit import GitApplication + SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34" from snek.system.template import whitelist_attributes @@ -71,32 +79,34 @@ async def session_middleware(request, handler): response = await handler(request) return response -@web.middleware + +@web.middleware async def ip2location_middleware(request, handler): response = await handler(request) return response ip = request.headers.get("X-Forwarded-For", request.remote) ipaddress = ip_address(ip) if ipaddress.is_private: - return response + return response if not request.app.session.get("uid"): - return response + return response user = await request.app.services.user.get(uid=request.app.session.get("uid")) if not user: return response location = request.app.ip2location.get(ip) - original_city = user['city'] - if user['city'] != location.city: - user['country_long'] = location.country - user['country_short'] = locaion.country_short - user['city'] = location.city - user['region'] = location.region - user['latitude'] = location.latitude - user['longitude'] = location.longitude - user['ip'] = ip + original_city = user["city"] + if user["city"] != location.city: + user["country_long"] = location.country + user["country_short"] = locaion.country_short + user["city"] = location.city + user["region"] = location.region + user["latitude"] = location.latitude + user["longitude"] = location.longitude + user["ip"] = ip await request.app.services.user.update(user) return response + @web.middleware async def trailing_slash_middleware(request, handler): if request.path and not request.path.endswith("/"): @@ -106,7 +116,6 @@ async def trailing_slash_middleware(request, handler): class Application(BaseApplication): - def __init__(self, *args, **kwargs): middlewares = [ cors_middleware, @@ -117,7 +126,10 @@ class Application(BaseApplication): self.template_path = pathlib.Path(__file__).parent.joinpath("templates") self.static_path = pathlib.Path(__file__).parent.joinpath("static") super().__init__( - middlewares=middlewares, template_path=self.template_path, client_max_size=1024*1024*1024*5 *args, **kwargs + middlewares=middlewares, + template_path=self.template_path, + client_max_size=1024 * 1024 * 1024 * 5 * args, + **kwargs, ) session_setup(self, EncryptedCookieStorage(SESSION_KEY)) self.tasks = asyncio.Queue() @@ -131,6 +143,7 @@ class Application(BaseApplication): self.time_start = datetime.now() self.ssh_host = "0.0.0.0" self.ssh_port = 2242 + self.setup_router() self.ssh_server = None self.sync_service = None @@ -140,25 +153,24 @@ class Application(BaseApplication): self.mappers = get_mappers(app=self) self.broadcast_service = None self.user_availability_service_task = None - - base_path = pathlib.Path(__file__).parent - self.ip2location = IP2Location.IP2Location(base_path.joinpath("IP2LOCATION-LITE-DB11.BIN")) - + base_path = pathlib.Path(__file__).parent + self.ip2location = IP2Location.IP2Location( + base_path.joinpath("IP2LOCATION-LITE-DB11.BIN") + ) self.on_startup.append(self.prepare_asyncio) self.on_startup.append(self.start_user_availability_service) self.on_startup.append(self.start_ssh_server) self.on_startup.append(self.prepare_database) - @property def uptime_seconds(self): return (datetime.now() - self.time_start).total_seconds() - @property + @property def uptime(self): return self._format_uptime(self.uptime_seconds) - - def _format_uptime(self,seconds): + + def _format_uptime(self, seconds): seconds = int(seconds) days, seconds = divmod(seconds, 86400) hours, seconds = divmod(seconds, 3600) @@ -176,14 +188,16 @@ class Application(BaseApplication): return ", ".join(parts) - async def start_user_availability_service(self, app): - app.user_availability_service_task = asyncio.create_task(app.services.socket.user_availability_service()) + app.user_availability_service_task = asyncio.create_task( + app.services.socket.user_availability_service() + ) + async def snode_sync(self, app): self.sync_service = asyncio.create_task(snode.sync_service(app)) - + async def start_ssh_server(self, app): - app.ssh_server = await start_ssh_server(app,app.ssh_host,app.ssh_port) + app.ssh_server = await start_ssh_server(app, app.ssh_host, app.ssh_port) if app.ssh_server: asyncio.create_task(app.ssh_server.wait_closed()) @@ -242,6 +256,7 @@ class Application(BaseApplication): self.router.add_view("/settings/index.html", SettingsIndexView) self.router.add_view("/settings/profile.html", SettingsProfileView) self.router.add_view("/settings/profile.json", SettingsProfileView) + self.router.add_view("/push.json", PushView) self.router.add_view("/web.html", WebView) self.router.add_view("/login.html", LoginView) self.router.add_view("/login.json", LoginView) @@ -281,10 +296,10 @@ class Application(BaseApplication): self.webdav = WebdavApplication(self) self.git = GitApplication(self) self.add_subapp("/webdav", self.webdav) - self.add_subapp("/git",self.git) - - #self.router.add_get("/{file_path:.*}", self.static_handler) - + self.add_subapp("/git", self.git) + + # self.router.add_get("/{file_path:.*}", self.static_handler) + async def handle_test(self, request): return await whitelist_attributes(self.render_template( @@ -313,9 +328,8 @@ class Application(BaseApplication): async for subscribed_channel in self.services.channel_member.find( user_uid=request.session.get("uid"), deleted_at=None, is_banned=False ): - parent_object = await subscribed_channel.get_channel() - + item = {} other_user = await self.services.channel_member.get_other_dm_user( subscribed_channel["channel_uid"], request.session.get("uid") @@ -360,14 +374,13 @@ class Application(BaseApplication): rendered = await super().render_template(template, request, context) self.jinja2_env.loader = self.original_loader - + #rendered.text = whitelist_attributes(rendered.text) #rendered.headers['Content-Lenght'] = len(rendered.text) return rendered - async def static_handler(self, request): - file_name = request.match_info.get('filename', '') + file_name = request.match_info.get("filename", "") paths = [] @@ -376,12 +389,12 @@ class Application(BaseApplication): user_static_path = await self.services.user.get_static_path(uid) if user_static_path: paths.append(user_static_path) - + for admin_uid in self.services.user.get_admin_uids(): user_static_path = await self.services.user.get_static_path(admin_uid) if user_static_path: paths.append(user_static_path) - + paths.append(self.static_path) for path in paths: @@ -401,7 +414,6 @@ class Application(BaseApplication): if user_template_path: template_paths.append(user_template_path) - template_paths.append(self.template_path) return FileSystemLoader(template_paths) @@ -410,7 +422,9 @@ app = Application(db_path="sqlite:///snek.db") async def main(): - await web._run_app(app, port=8081, host="0.0.0.0") + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain('cert.pem', 'key.pem') + await web._run_app(app, port=8081, host="0.0.0.0", ssl_context=ssl_context) if __name__ == "__main__": diff --git a/src/snek/mapper/__init__.py b/src/snek/mapper/__init__.py index 69901ec..48a8a0e 100644 --- a/src/snek/mapper/__init__.py +++ b/src/snek/mapper/__init__.py @@ -9,6 +9,7 @@ from snek.mapper.notification import NotificationMapper from snek.mapper.user import UserMapper from snek.mapper.user_property import UserPropertyMapper from snek.mapper.repository import RepositoryMapper +from snek.mapper.push import PushMapper from snek.mapper.channel_attachment import ChannelAttachmentMapper from snek.mapper.container import ContainerMapper from snek.system.object import Object @@ -30,6 +31,7 @@ def get_mappers(app=None): "repository": RepositoryMapper(app=app), "channel_attachment": ChannelAttachmentMapper(app=app), "container": ContainerMapper(app=app), + "push": PushMapper(app=app), } ) diff --git a/src/snek/mapper/push.py b/src/snek/mapper/push.py new file mode 100644 index 0000000..2fe7c29 --- /dev/null +++ b/src/snek/mapper/push.py @@ -0,0 +1,7 @@ +from snek.model.push_registration import PushRegistrationModel +from snek.system.mapper import BaseMapper + + +class PushMapper(BaseMapper): + model_class = PushRegistrationModel + table_name = "push_registration" diff --git a/src/snek/model/__init__.py b/src/snek/model/__init__.py index dec05e8..e5c6054 100644 --- a/src/snek/model/__init__.py +++ b/src/snek/model/__init__.py @@ -12,6 +12,7 @@ from snek.model.user import UserModel from snek.model.user_property import UserPropertyModel from snek.model.repository import RepositoryModel from snek.model.channel_attachment import ChannelAttachmentModel +from snek.model.push_registration import PushRegistrationModel from snek.model.container import Container from snek.system.object import Object @@ -31,6 +32,7 @@ def get_models(): "repository": RepositoryModel, "channel_attachment": ChannelAttachmentModel, "container": Container, + "push_registration": PushRegistrationModel, } ) diff --git a/src/snek/model/push_registration.py b/src/snek/model/push_registration.py new file mode 100644 index 0000000..bb13be0 --- /dev/null +++ b/src/snek/model/push_registration.py @@ -0,0 +1,8 @@ +from snek.system.model import BaseModel, ModelField + + +class PushRegistrationModel(BaseModel): + user_uid = ModelField(name="user_uid", required=True) + endpoint = ModelField(name="endpoint", required=True) + key_auth = ModelField(name="key_auth", required=True) + key_p256dh = ModelField(name="key_p256dh", required=True) diff --git a/src/snek/service/__init__.py b/src/snek/service/__init__.py index 5669bdb..f1654fb 100644 --- a/src/snek/service/__init__.py +++ b/src/snek/service/__init__.py @@ -9,6 +9,7 @@ from snek.service.drive_item import DriveItemService from snek.service.notification import NotificationService from snek.service.socket import SocketService from snek.service.user import UserService +from snek.service.push import PushService from snek.service.user_property import UserPropertyService from snek.service.util import UtilService from snek.service.repository import RepositoryService @@ -17,6 +18,7 @@ from snek.service.container import ContainerService from snek.system.object import Object from snek.service.db import DBService + @functools.cache def get_services(app): return Object( @@ -36,6 +38,7 @@ def get_services(app): "db": DBService(app=app), "channel_attachment": ChannelAttachmentService(app=app), "container": ContainerService(app=app), + "push": PushService(app=app), } ) diff --git a/src/snek/service/notification.py b/src/snek/service/notification.py index a22e8ae..1e34bae 100644 --- a/src/snek/service/notification.py +++ b/src/snek/service/notification.py @@ -62,4 +62,18 @@ class NotificationService(BaseService): except Exception: raise Exception(f"Failed to create notification: {model.errors}.") + if channel_member["user_uid"] != user["uid"]: + try: + await self.app.services.push.notify_user( + user_uid=channel_member["user_uid"], + payload={ + "title": f"New message in {channel_member['label']}", + "message": f"{user['nick']}: {channel_message['message']}", + "icon": "/image/snek192.png", + "url": f"/channel/{channel_message['channel_uid']}.html", + }, + ) + except Exception as e: + print(f"Failed to send push notification:", e) + self.app.db.commit() diff --git a/src/snek/service/push.py b/src/snek/service/push.py new file mode 100644 index 0000000..ccba4af --- /dev/null +++ b/src/snek/service/push.py @@ -0,0 +1,271 @@ +import json + +import aiohttp +from snek.system.service import BaseService +import random +import time +import base64 +import uuid +from functools import cache +from pathlib import Path + +import jwt +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.backends import default_backend +from urllib.parse import urlparse +import os.path + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.hashes import SHA256 +from cryptography.hazmat.primitives.kdf.hkdf import HKDF + +# The only reason to persist the keys is to be able to use them in the web push + +PRIVATE_KEY_FILE = Path("./notification-private.pem") +PRIVATE_KEY_PKCS8_FILE = Path("./notification-private.pkcs8.pem") +PUBLIC_KEY_FILE = Path("./notification-public.pem") + + +def generate_private_key(): + if not PRIVATE_KEY_FILE.exists(): + private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) + + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + PRIVATE_KEY_FILE.write_bytes(pem) + + +def generate_pcks8_private_key(): + if not PRIVATE_KEY_PKCS8_FILE.exists(): + private_key = serialization.load_pem_private_key( + PRIVATE_KEY_FILE.read_bytes(), password=None, backend=default_backend() + ) + + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + PRIVATE_KEY_PKCS8_FILE.write_bytes(pem) + + +def generate_public_key(): + if not PUBLIC_KEY_FILE.exists(): + private_key = serialization.load_pem_private_key( + PRIVATE_KEY_FILE.read_bytes(), password=None, backend=default_backend() + ) + + public_key = private_key.public_key() + + pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + PUBLIC_KEY_FILE.write_bytes(pem) + + +def ensure_certificates(): + generate_private_key() + generate_pcks8_private_key() + generate_public_key() + + +def hkdf(input_key, salt, info, length): + return HKDF( + algorithm=SHA256(), + length=length, + salt=salt, + info=info, + backend=default_backend(), + ).derive(input_key) + + +def _browser_base64(data): + return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=") + + +class PushService(BaseService): + mapper_name = "push" + + private_key_pem = None + public_key = None + public_key_base64 = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + ensure_certificates() + + private_key = serialization.load_pem_private_key( + PRIVATE_KEY_FILE.read_bytes(), password=None, backend=default_backend() + ) + self.private_key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + self.public_key = serialization.load_pem_public_key( + PUBLIC_KEY_FILE.read_bytes(), backend=default_backend() + ) + + self.public_key_base64 = _browser_base64( + self.public_key.public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint, + ) + ) + + def create_notification_authorization(self, push_url): + target = urlparse(push_url) + aud = f"{target.scheme}://{target.netloc}" + sub = "mailto:admin@molodetz.nl" + + identifier = str(uuid.uuid4()) + + print( + f"Creating notification authorization for {aud} with identifier {identifier}" + ) + + return jwt.encode( + { + "sub": sub, + "aud": aud, + "exp": int(time.time()) + 60 * 60, + "nbf": int(time.time()), + "iat": int(time.time()), + "jti": identifier, + }, + self.private_key_pem, + algorithm="ES256", + ) + + def create_notification_info_with_payload( + self, endpoint: str, auth: str, p256dh: str, payload: str + ): + message_private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) + + message_public_key_bytes = message_private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint, + ) + + salt = os.urandom(16) + + user_key_bytes = base64.urlsafe_b64decode(p256dh + "==") + shared_secret = message_private_key.exchange( + ec.ECDH(), + ec.EllipticCurvePublicKey.from_encoded_point( + ec.SECP256R1(), user_key_bytes + ), + ) + + encryption_key = hkdf( + shared_secret, + base64.urlsafe_b64decode(auth + "=="), + b"Content-Encoding: auth\x00", + 32, + ) + + context = ( + b"P-256\x00" + + len(user_key_bytes).to_bytes(2, "big") + + user_key_bytes + + len(message_public_key_bytes).to_bytes(2, "big") + + message_public_key_bytes + ) + + nonce = hkdf(encryption_key, salt, b"Content-Encoding: nonce\x00" + context, 12) + content_encryption_key = hkdf( + encryption_key, salt, b"Content-Encoding: aesgcm\x00" + context, 16 + ) + + padding_length = random.randint(0, 16) + padding = padding_length.to_bytes(2, "big") + b"\x00" * padding_length + + data = AESGCM(content_encryption_key).encrypt( + nonce, padding + payload.encode("utf-8"), None + ) + + return { + "headers": { + "Authorization": f"WebPush {self.create_notification_authorization(endpoint)}", + "Crypto-Key": f"dh={_browser_base64(message_public_key_bytes)}; p256ecdsa={self.public_key_base64}", + "Encryption": f"salt={_browser_base64(salt)}", + "Content-Encoding": "aesgcm", + "Content-Length": str(len(data)), + "Content-Type": "application/octet-stream", + }, + "data": data, + } + + async def notify_user(self, user_uid: str, payload: dict): + async with aiohttp.ClientSession() as session: + async for subscription in self.find(user_uid=user_uid): + endpoint = subscription["endpoint"] + key_auth = subscription["key_auth"] + key_p256dh = subscription["key_p256dh"] + + notification_info = self.create_notification_info_with_payload( + endpoint, key_auth, key_p256dh, json.dumps(payload) + ) + + headers = { + **notification_info["headers"], + "TTL": "60", + } + data = notification_info["data"] + + async with session.post( + endpoint, + headers=headers, + data=data, + ) as response: + if response.status == 201 or response.status == 200: + print( + f"Notification sent to user {user_uid} via endpoint {endpoint}" + ) + else: + print( + f"Failed to send notification to user {user_uid} via endpoint {endpoint}: {response.status}" + ) + else: + print(f"No push subscriptions found for user {user_uid}") + + async def register( + self, user_uid: str, endpoint: str, key_auth: str, key_p256dh: str + ): + if await self.exists( + user_uid=user_uid, + endpoint=endpoint, + key_auth=key_auth, + key_p256dh=key_p256dh, + ): + return + + model = await self.new() + model["user_uid"] = user_uid + model["endpoint"] = endpoint + model["key_auth"] = key_auth + model["key_p256dh"] = key_p256dh + + print( + f"Registering push subscription for user {user_uid} with endpoint {endpoint}" + ) + + if await self.save(model=model) and model: + print( + f"Push subscription registered for user {user_uid} with endpoint {endpoint}" + ) + + return model + + raise Exception( + f"Failed to register push subscription for user {user_uid} with endpoint {endpoint}" + ) \ No newline at end of file diff --git a/src/snek/static/push.js b/src/snek/static/push.js index f9917f4..1652817 100644 --- a/src/snek/static/push.js +++ b/src/snek/static/push.js @@ -1,34 +1,57 @@ -this.onpush = (event) => { - console.log(event.data); - // From here we can write the data to IndexedDB, send it to any open - // windows, display a notification, etc. -}; +export const registerServiceWorker = async (silent = false) => { + try { + const serviceWorkerRegistration = await navigator.serviceWorker + .register("/service-worker.js") + + await serviceWorkerRegistration.update() + + await navigator.serviceWorker.ready + + const keyResponse = await fetch('/push.json') + const keyData = await keyResponse.json() + + const publicKey = Uint8Array.from(atob(keyData.publicKey), c => c.charCodeAt(0)) + + const pushSubscription = await serviceWorkerRegistration.pushManager.subscribe({ + userVisibleOnly: true, applicationServerKey: publicKey, + }) -navigator.serviceWorker - .register("/service-worker.js") - .then((serviceWorkerRegistration) => { - serviceWorkerRegistration.pushManager.subscribe().then( - (pushSubscription) => { const subscriptionObject = { - endpoint: pushSubscription.endpoint, - keys: { - p256dh: pushSubscription.getKey("p256dh"), - auth: pushSubscription.getKey("auth"), - }, - encoding: PushManager.supportedContentEncodings, - /* other app-specific data, such as user identity */ + ...pushSubscription.toJSON(), encoding: PushManager.supportedContentEncodings, }; - console.log( - pushSubscription.endpoint, - pushSubscription, - subscriptionObject, - ); - // The push subscription details needed by the application - // server are now available, and can be sent to it using, - // for example, the fetch() API. - }, - (error) => { - console.error(error); - }, - ); - }); + console.log(pushSubscription.endpoint, pushSubscription, pushSubscription.toJSON(), subscriptionObject); + + const response = await fetch('/push.json', { + method: 'POST', headers: { + 'Content-Type': 'application/json', + }, body: JSON.stringify(subscriptionObject), + }) + + if (!response.ok) { + throw new Error('Bad status code from server.'); + } + + const responseData = await response.json(); + console.log('Registration response', responseData); + } catch (error) { + console.error("Error registering service worker:", error); + if (!silent) { + alert("Registering push notifications failed. Please check your browser settings and try again.\n\n" + error); + } + } +} + + +window.registerNotificationsServiceWorker = () => { + return Notification.requestPermission().then((permission) => { + if (permission === "granted") { + console.log("Permission was granted"); + return registerServiceWorker(); + } else if (permission === "denied") { + console.log("Permission was denied"); + } else { + console.log("Permission was dismissed"); + } + }); +}; +registerServiceWorker(true).catch(console.error); diff --git a/src/snek/static/service-worker.js b/src/snek/static/service-worker.js index c89b668..310fa96 100644 --- a/src/snek/static/service-worker.js +++ b/src/snek/static/service-worker.js @@ -1,65 +1,66 @@ -async function requestNotificationPermission() { - const permission = await Notification.requestPermission(); - return permission === "granted"; -} - -// Subscribe to Push Notifications -async function subscribeUser() { - const registration = - await navigator.serviceWorker.register("/service-worker.js"); - - const subscription = await registration.pushManager.subscribe({ - userVisibleOnly: true, - applicationServerKey: urlBase64ToUint8Array(PUBLIC_VAPID_KEY), - }); - - // Send subscription to your backend - await fetch("/subscribe", { - method: "POST", - body: JSON.stringify(subscription), - headers: { - "Content-Type": "application/json", - }, - }); -} - -// Service Worker (service-worker.js) -self.addEventListener("push", (event) => { - const data = event.data.json(); - self.registration.showNotification(data.title, { - body: data.message, - icon: data.icon, - }); -}); - -/* self.addEventListener("install", (event) => { - console.log("Service worker installed"); + console.log("Service worker installing..."); + event.waitUntil( + caches.open("snek-cache").then((cache) => { + return cache.addAll([]); + }) + ); +}) + +self.addEventListener("activate", (event) => { + event.waitUntil(self.registration?.navigationPreload.enable()); }); -self.addEventListener("push", (event) => { - if (!(self.Notification && self.Notification.permission === "granted")) { - return; - } - console.log("Received a push message", event); +self.addEventListener("push", (event) => { + if (!self.Notification || self.Notification.permission !== "granted") { + console.log("Notification permission not granted"); + return; + } - const data = event.data?.json() ?? {}; - const title = data.title || "Something Has Happened"; - const message = - data.message || "Here's something you might want to check out."; - const icon = "images/new-notification.png"; + const data = event.data?.json() ?? {}; + console.log("Received a push message", event, data); - event.waitUntil(self.registration.showNotification(title, { - body: message, - tag: "simple-push-demo-notification", - icon, - })); + const title = data.title || "Something Has Happened"; + const message = + data.message || "Here's something you might want to check out."; + const icon = data.icon || "/image/snek512.png"; + const notificationSettings = data.notificationSettings || {}; + + console.log("Showing message", title, message, icon); + + const reg = self.registration.showNotification(title, { + body: message, + tag: "message-received", + icon, + badge: icon, + ...notificationSettings, + data, + }).then(e => console.log("Showing notification", e)).catch(console.error); + + event.waitUntil(reg); }); self.addEventListener("notificationclick", (event) => { console.log("Notification click Received.", event); event.notification.close(); - event.waitUntil(clients.openWindow( - "https://snek.molodetz.nl",)); -});*/ + event.waitUntil(clients.openWindow(`${event.notification.data.url || event.notification.data.link || `/web.html`}`)); +}); + +self.addEventListener("notificationclose", (event) => { + console.log("Notification closed", event); +}) + +self.addEventListener("fetch", (event) => { + // console.log("Fetch event for ", event.request.url); + event.respondWith( + caches.match(event.request).then((response) => { + if (response) { + // console.log("Found response in cache: ", response); + return response; + } + // console.log("No response found in cache. About to fetch from network..."); + return fetch(event.request); + }) + ); +}) \ No newline at end of file diff --git a/src/snek/templates/app.html b/src/snek/templates/app.html index c92ae76..0119606 100644 --- a/src/snek/templates/app.html +++ b/src/snek/templates/app.html @@ -7,9 +7,7 @@