33 lines
1.3 KiB
Python
33 lines
1.3 KiB
Python
|
|
from fastapi import Request
|
||
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||
|
|
from ..billing.usage_tracker import UsageTracker
|
||
|
|
|
||
|
|
class UsageTrackingMiddleware(BaseHTTPMiddleware):
    """Record bandwidth usage for file uploads and downloads.

    After the downstream handler has produced a response, the middleware
    checks ``request.state.user`` (presumably populated by an upstream
    authentication layer -- confirm against the app's middleware order).
    When a truthy user is present and the request is a file transfer, the
    transferred byte count is reported through
    ``UsageTracker.track_bandwidth``.
    """

    async def dispatch(self, request: Request, call_next):
        """Run the request, then track bandwidth for file transfers.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack and
                returns the response.

        Returns:
            The response produced by ``call_next``, unmodified.
        """
        response = await call_next(request)

        # ``request.state`` raises AttributeError for unset attributes, so
        # ``getattr`` with a default covers both "missing" and "falsy".
        user = getattr(request.state, 'user', None)
        if not user:
            return response

        path = request.url.path
        if request.method in ('POST', 'PUT') and '/files/upload' in path:
            # Uploaded bytes travel in the *request* body; the response to an
            # upload is typically a small status payload, so its
            # Content-Length does not reflect the upload size.
            direction = 'up'
            raw_length = request.headers.get('content-length')
        elif request.method == 'GET' and '/files/download' in path:
            # Downloaded bytes travel in the response body.
            direction = 'down'
            raw_length = response.headers.get('content-length')
        else:
            return response

        amount_bytes = self._parse_content_length(raw_length)
        if amount_bytes is not None:
            await UsageTracker.track_bandwidth(
                user=user,
                amount_bytes=amount_bytes,
                direction=direction,
                metadata={'path': path},
            )

        return response

    @staticmethod
    def _parse_content_length(raw_length):
        """Return a Content-Length header value as a non-negative int.

        Returns ``None`` when the header is absent, empty, not an integer, or
        negative, so that a malformed header skips tracking instead of
        failing a request whose response has already been produced.
        """
        if not raw_length:
            return None
        try:
            amount_bytes = int(raw_length)
        except ValueError:
            return None
        return amount_bytes if amount_bytes >= 0 else None
|