|
from fastapi import Request
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from ..billing.usage_tracker import UsageTracker
|
|
|
|
|
|
class UsageTrackingMiddleware(BaseHTTPMiddleware):
    """Report file-transfer bandwidth to ``UsageTracker`` after each request.

    Once the downstream handler has produced a response, a request is
    reported via ``UsageTracker.track_bandwidth`` when ``request.state.user``
    is set to a truthy value and either:

    * the method is POST or PUT and the path contains ``/files/upload``
      (direction ``"up"``, sized by the request's ``Content-Length``), or
    * the method is GET and the path contains ``/files/download``
      (direction ``"down"``, sized by the response's ``Content-Length``).

    Transfers whose relevant ``Content-Length`` header is missing or not a
    non-negative integer are not reported.
    """

    _UPLOAD_METHODS = ("POST", "PUT")

    @staticmethod
    def _parse_content_length(raw_value):
        """Return a ``Content-Length`` header value as a non-negative int.

        Args:
            raw_value: The header value as a string, or ``None`` when the
                header is absent.

        Returns:
            The parsed byte count, or ``None`` when the header is absent,
            empty, not an integer, or negative.
        """
        if not raw_value:
            return None
        try:
            size = int(raw_value)
        except ValueError:
            # A malformed header must not turn an already-served request
            # into an error; skip tracking instead.
            return None
        return size if size >= 0 else None

    async def dispatch(self, request: Request, call_next):
        """Run the downstream handler, then record upload/download bandwidth.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream and
                returns the response.

        Returns:
            The response produced by ``call_next``, unmodified.
        """
        response = await call_next(request)

        # ``request.state.user`` is read only after the downstream call, as in
        # the original flow; requests without a truthy user are not tracked.
        user = getattr(request.state, "user", None)
        if not user:
            return response

        path = request.url.path
        if request.method in self._UPLOAD_METHODS and "/files/upload" in path:
            # The uploaded payload travels in the request body, so its size
            # comes from the request headers, not from the response.
            direction = "up"
            raw_length = request.headers.get("content-length")
        elif request.method == "GET" and "/files/download" in path:
            direction = "down"
            raw_length = response.headers.get("content-length")
        else:
            return response

        amount_bytes = self._parse_content_length(raw_length)
        if amount_bytes is not None:
            await UsageTracker.track_bandwidth(
                user=user,
                amount_bytes=amount_bytes,
                direction=direction,
                metadata={"path": path},
            )

        return response
|