from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from ..billing.usage_tracker import UsageTracker


def _parse_content_length(raw_value):
    """Return a Content-Length header value as a non-negative int.

    Returns ``None`` when the header is missing, empty, not an integer, or
    negative, so the caller can skip tracking instead of raising.
    """
    if not raw_value:
        return None
    try:
        size = int(raw_value)
    except (TypeError, ValueError):
        return None
    return size if size >= 0 else None


class UsageTrackingMiddleware(BaseHTTPMiddleware):
    """Record per-user bandwidth for file upload and download requests.

    After the downstream handler has run, the middleware reports the number
    of transferred bytes to ``UsageTracker.track_bandwidth``:

    * ``POST``/``PUT`` requests whose path contains ``/files/upload`` are
      reported with ``direction='up'``, sized by the *request*
      ``Content-Length`` header (the bytes the client sent).
    * ``GET`` requests whose path contains ``/files/download`` are reported
      with ``direction='down'``, sized by the *response* ``Content-Length``
      header (the bytes sent back to the client).

    Nothing is tracked when ``request.state.user`` is missing or falsy, or
    when the relevant ``Content-Length`` header is absent or malformed.
    """

    async def dispatch(self, request: Request, call_next):
        """Run the request, then track bandwidth for upload/download routes.

        Args:
            request: The incoming request. ``request.state.user`` is read
                after the handler runs; presumably it is populated by an
                authentication layer further down the stack (not visible
                in this file).
            call_next: Callable that forwards the request downstream and
                returns the response.

        Returns:
            The response produced by ``call_next``, unmodified.
        """
        response = await call_next(request)

        # Read the user only after call_next: it is looked up on
        # request.state once the downstream stack has processed the request.
        user = getattr(request.state, 'user', None)
        if not user:
            return response

        path = request.url.path
        if request.method in ('POST', 'PUT') and '/files/upload' in path:
            direction = 'up'
            # Upload size is what the client sent, so it must come from the
            # request headers; the response body here is only the reply.
            raw_length = request.headers.get('content-length')
        elif request.method == 'GET' and '/files/download' in path:
            direction = 'down'
            raw_length = response.headers.get('content-length')
        else:
            return response

        # A malformed header must not turn an already-produced response
        # into a server error, so unparsable values are skipped.
        amount_bytes = _parse_content_length(raw_length)
        if amount_bytes is not None:
            await UsageTracker.track_bandwidth(
                user=user,
                amount_bytes=amount_bytes,
                direction=direction,
                metadata={'path': path},
            )

        return response