from typing import Mapping, Optional

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from ..billing.usage_tracker import UsageTracker


def _parse_content_length(headers: Mapping[str, str]) -> Optional[int]:
    """Return the ``content-length`` value in *headers* as a positive int.

    Returns None when the header is missing, not an integer, or not positive.
    This keeps a malformed header from raising after the response has
    already been produced, and skips recording zero-byte usage.
    """
    raw = headers.get("content-length")
    if raw is None:
        return None
    try:
        value = int(raw)
    except ValueError:
        return None
    return value if value > 0 else None


class UsageTrackingMiddleware(BaseHTTPMiddleware):
    """Record bandwidth usage for file uploads and downloads.

    After the downstream handler has produced a response, this middleware
    calls ``UsageTracker.track_bandwidth`` when a truthy ``request.state.user``
    is present and the request matches one of these cases:

    * ``POST``/``PUT`` on a path containing ``/files/upload`` is tracked as
      ``direction="up"``. The size comes from the *request's* Content-Length.
    * ``GET`` on a path containing ``/files/download`` is tracked as
      ``direction="down"``. The size comes from the *response's* Content-Length.

    Transfers whose relevant message carries no usable Content-Length header
    are not tracked. For example, a chunked request or a streamed response
    without a declared length is skipped.
    """

    async def dispatch(self, request: Request, call_next):
        """Forward the request, then record upload/download bandwidth.

        Returns the downstream response unchanged.
        """
        response = await call_next(request)

        # Missing attribute and falsy user are both treated as unauthenticated.
        user = getattr(request.state, "user", None)
        if not user:
            return response

        path = request.url.path
        if request.method in ("POST", "PUT") and "/files/upload" in path:
            # The uploaded bytes are the request body, so the request's
            # Content-Length is the relevant size. The response's header
            # only describes the (small) reply sent back to the client.
            amount = _parse_content_length(request.headers)
            direction = "up"
        elif request.method == "GET" and "/files/download" in path:
            amount = _parse_content_length(response.headers)
            direction = "down"
        else:
            return response

        if amount is not None:
            await UsageTracker.track_bandwidth(
                user=user,
                amount_bytes=amount,
                direction=direction,
                metadata={"path": path},
            )
        return response