diff --git a/src/snek/static/push.js b/src/snek/static/push.js index f1089ae..2e40f2e 100644 --- a/src/snek/static/push.js +++ b/src/snek/static/push.js @@ -31,9 +31,6 @@ export const registerServiceWorker = async () => { /* other app-specific data, such as user identity */ }; console.log(pushSubscription.endpoint, pushSubscription, pushSubscription.toJSON(), subscriptionObject); - // The push subscription details needed by the application - // server are now available, and can be sent to it using, - // for example, the fetch() API. fetch('/push.json', { method: 'POST', diff --git a/src/snek/static/service-worker.js b/src/snek/static/service-worker.js index be89157..038bb5f 100644 --- a/src/snek/static/service-worker.js +++ b/src/snek/static/service-worker.js @@ -33,16 +33,18 @@ async function subscribeUser() { self.addEventListener("push", (event) => { - if (!(self.Notification && self.Notification.permission === "granted")) { + if (!self.Notification || self.Notification.permission !== "granted") { + console.log("Notification permission not granted"); return; } - console.log("Received a push message", event); const data = event.data?.json() ?? {}; + console.log("Received a push message", event, data); + const title = data.title || "Something Has Happened"; const message = data.message || "Here's something you might want to check out."; - const icon = "images/new-notification.png"; + const icon = data.icon || "images/new-notification.png"; console.log("showing message", title, message, icon); const reg = self.registration.showNotification(title, { diff --git a/src/snek/system/notification.py b/src/snek/system/notification.py index c35def3..3a60b7c 100644 --- a/src/snek/system/notification.py +++ b/src/snek/system/notification.py @@ -10,10 +10,14 @@ from cryptography.hazmat.backends import default_backend from urllib.parse import urlparse import os.path +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.hashes import SHA256 +from cryptography.hazmat.primitives.kdf.hkdf import HKDF + +PRIVATE_KEY_FILE = "./notification-private.pem" +PRIVATE_KEY_PKCS8_FILE = "./notification-private.pkcs8.pem" +PUBLIC_KEY_FILE = "./notification-public.pem" -PRIVATE_KEY_FILE = './notification-private.pem' -PRIVATE_KEY_PKCS8_FILE = './notification-private.pkcs8.pem' -PUBLIC_KEY_FILE = './notification-public.pem' def generate_private_key(): if not os.path.isfile(PRIVATE_KEY_FILE): @@ -23,34 +27,40 @@ def generate_private_key(): pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) # Write the private key to a file - with open(PRIVATE_KEY_FILE, 'wb') as pem_out: + with open(PRIVATE_KEY_FILE, "wb") as pem_out: pem_out.write(pem) + def generate_pcks8_private_key(): # openssl pkcs8 -topk8 -nocrypt -in private.pem -out private.pkcs8.pem if not os.path.isfile(PRIVATE_KEY_PKCS8_FILE): - with open(PRIVATE_KEY_FILE, 'rb') as pem_in: - private_key = serialization.load_pem_private_key(pem_in.read(), password=None, backend=default_backend()) + with open(PRIVATE_KEY_FILE, "rb") as pem_in: + private_key = serialization.load_pem_private_key( + pem_in.read(), password=None, backend=default_backend() + ) # Serialize the private key to PKCS8 format pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) # Write the private key to a file - with open(PRIVATE_KEY_PKCS8_FILE, 'wb') as pem_out: + with open(PRIVATE_KEY_PKCS8_FILE, "wb") as pem_out: pem_out.write(pem) + def generate_public_key(): if not os.path.isfile(PUBLIC_KEY_FILE): - with open(PRIVATE_KEY_FILE, 'rb') as pem_in: - private_key = serialization.load_pem_private_key(pem_in.read(), password=None, backend=default_backend()) + with open(PRIVATE_KEY_FILE, "rb") as pem_in: + private_key = serialization.load_pem_private_key( + pem_in.read(), password=None, backend=default_backend() + ) # Get the public key from the private key public_key = private_key.public_key() @@ -58,18 +68,30 @@ def generate_public_key(): # Serialize the public key to PEM format pem = public_key.public_bytes( encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) # Write the public key to a file - with open(PUBLIC_KEY_FILE, 'wb') as pem_out: + with open(PUBLIC_KEY_FILE, "wb") as pem_out: pem_out.write(pem) + def ensure_certificates(): generate_private_key() generate_pcks8_private_key() generate_public_key() + +def hkdf(input_key, salt, info, length): + return HKDF( + algorithm=SHA256(), + length=length, + salt=salt, + info=info, + backend=default_backend(), + ).derive(input_key) + + class Notifications: private_key_pem = None public_key = None @@ -77,30 +99,35 @@ class Notifications: def __init__(self): ensure_certificates() - with open(PRIVATE_KEY_FILE, 'rb') as pem_in: - private_key = serialization.load_pem_private_key(pem_in.read(), password=None, backend=default_backend()) + with open(PRIVATE_KEY_FILE, "rb") as pem_in: + private_key = serialization.load_pem_private_key( + pem_in.read(), password=None, backend=default_backend() + ) self.private_key_pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) - with open(PUBLIC_KEY_FILE, 'rb') as pem_in: - self.public_key = serialization.load_pem_public_key(pem_in.read(), backend=default_backend()) + with open(PUBLIC_KEY_FILE, "rb") as pem_in: + self.public_key = serialization.load_pem_public_key( + pem_in.read(), backend=default_backend() + ) public_numbers = self.public_key.public_numbers() self.public_key_jwk = { "kty": "EC", "crv": "P-256", - "x": base64.urlsafe_b64encode(public_numbers.x.to_bytes(32, byteorder='big')).decode('utf-8'), - "y": base64.urlsafe_b64encode(public_numbers.y.to_bytes(32, byteorder='big')).decode('utf-8') + "x": base64.urlsafe_b64encode( + public_numbers.x.to_bytes(32, byteorder="big") + ).decode("utf-8"), + "y": base64.urlsafe_b64encode( + public_numbers.y.to_bytes(32, byteorder="big") + ).decode("utf-8"), } print(f"Public key JWK: {self.public_key_jwk}") - - - def create_notification_authorization(self, push_url): target = urlparse(push_url) aud = f"{target.scheme}://{target.netloc}" @@ -108,17 +135,91 @@ class Notifications: identifier = str(uuid.uuid4()) - print(f"Creating notification authorization for {aud} with identifier {identifier}") + print( + f"Creating notification authorization for {aud} with identifier {identifier}" + ) - return jwt.encode({ - "sub": sub, - "aud": aud, - "exp": int(time.time()) + 60 * 60, - "nbf": int(time.time()), - "iat": int(time.time()), - "jti": identifier, - }, self.private_key_pem, algorithm='ES256') + return jwt.encode( + { + "sub": sub, + "aud": aud, + "exp": int(time.time()) + 60 * 60, + "nbf": int(time.time()), + "iat": int(time.time()), + "jti": identifier, + }, + self.private_key_pem, + algorithm="ES256", + ) + def create_encrypted_payload(self, auth: str, p256dh:str, payload:str): + # 1. Generate private key + private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) + + # 2. Get public key + public_key = private_key.public_key() + + salt = os.urandom(16) + + subscription_pub_key_bytes = base64.urlsafe_b64decode(p256dh+ '==') + subscription_public_key = ec.EllipticCurvePublicKey.from_encoded_point( + ec.SECP256R1(), subscription_pub_key_bytes + ) + shared_secret = private_key.exchange(ec.ECDH(), subscription_public_key) + + auth_dec = base64.b64decode(auth+ '==') + auth_enc = b"Content-Encoding: auth\x00" + prk = hkdf(shared_secret, auth_dec, auth_enc, 32) + + # 5. Build context + key_label = b"P-256\x00" + subscription_len = len(subscription_pub_key_bytes).to_bytes(2, "big") + local_pub_bytes = public_key.public_bytes( + encoding=serialization.Encoding.X962, format=serialization.PublicFormat.UncompressedPoint + ) + local_len = len(local_pub_bytes).to_bytes(2, "big") + context = ( + key_label + + subscription_len + + subscription_pub_key_bytes + + local_len + + local_pub_bytes + ) + + # 6. Generate nonce and CEK + nonce_enc = b"Content-Encoding: nonce\x00" + nonce_info = nonce_enc + context + cek_enc = b"Content-Encoding: aesgcm\x00" + cek_info = cek_enc + context + + nonce = hkdf(prk, salt, nonce_info, 12) + content_encryption_key = hkdf(prk, salt, cek_info, 16) + + # 7. Encrypt payload with AES-GCM + aesgcm = AESGCM(content_encryption_key) + padding_length = 0 # adjust if needed + padding = padding_length.to_bytes(2, "big") + b"\x00" * padding_length + combined = padding + payload.encode("utf-8") + encrypted = aesgcm.encrypt(nonce, combined, None) + + return { + 'payload': encrypted, + 'salt': salt, + 'public_key': local_pub_bytes + } + + @property + def public_key_base64(self): + return ( + base64.urlsafe_b64encode( + self.public_key.public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint, + ) + ) + .decode("utf-8") + .rstrip("=") + ) @cache def get_notifications(): diff --git a/src/snek/view/push.py b/src/snek/view/push.py index 0dcae89..00581ba 100644 --- a/src/snek/view/push.py +++ b/src/snek/view/push.py @@ -25,6 +25,7 @@ import base64 +import json import requests from cryptography.hazmat.primitives import serialization @@ -35,22 +36,25 @@ from snek.system.view import BaseFormView class PushView(BaseFormView): async def get(self): - notifications =get_notifications() + notifications = get_notifications() - return await self.json_response({ - "publicKey": base64.b64encode( - notifications.public_key.public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.UncompressedPoint + return await self.json_response( + { + "publicKey": base64.b64encode( + notifications.public_key.public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.UncompressedPoint, + ) ) - ).decode('utf-8').rstrip("="), - }) + .decode("utf-8") + .rstrip("="), + } + ) async def post(self): - memberships = [] user = {} - + user_id = self.session.get("uid") if user_id: user = await self.app.services.user.get(uid=user_id) @@ -59,34 +63,57 @@ class PushView(BaseFormView): body = await self.request.json() - if not ("encoding" in body and "endpoint" in body and "keys" in body and "p256dh" in body["keys"] and "auth" in body["keys"]): - return await self.json_response({"error": "Invalid request"}, status=400) + if not ( + "encoding" in body + and "endpoint" in body + and "keys" in body + and "p256dh" in body["keys"] + and "auth" in body["keys"] + ): + return await self.json_response( + {"error": "Invalid request"}, status=400 + ) print(body) - notifications =get_notifications() + notifications = get_notifications() - cert = base64.urlsafe_b64encode( - notifications.public_key.public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.UncompressedPoint - ) - ).decode('utf-8').rstrip("=") + cert = notifications.public_key_base64 + + test_payload = { + "title": "Hey retoor", + "message": "Guess what? ;P", + "icon": "https://molodetz.online/image/snek192.png", + "url": "https://localhost:8081", + } + + encryped = notifications.create_encrypted_payload( + body["keys"]["auth"], + body["keys"]["p256dh"], + json.dumps(test_payload), + ) + + payload = encryped["payload"] + salt = encryped["salt"] + public_key = encryped["public_key"] headers = { - "TTL": "60", - "Authorization": f"WebPush {notifications.create_notification_authorization(body['endpoint'])}", - "Crypto-Key": f"p256ecdsa={cert}", + "TTL": "60", + "Authorization": f"WebPush {notifications.create_notification_authorization(body['endpoint'])}", + "Crypto-Key": f"dh={base64.urlsafe_b64encode(public_key).decode('utf-8').rstrip('=')}; p256ecdsa={cert}", + "Encryption": f"salt={base64.urlsafe_b64encode(salt).decode('utf-8').rstrip('=')}", + + "Content-Encoding": "aesgcm", + "Content-Length": str(len(payload)), + "Content-Type": "application/octet-stream", } print(headers) - post_notification = requests.post( - body["endpoint"], - headers=headers) + post_notification = requests.post(body["endpoint"], headers=headers, data=payload) print(post_notification.status_code) print(post_notification.text) - + print(post_notification.headers) async for model in self.app.services.channel_member.find( user_uid=user_id, deleted_at=None, is_banned=False