import pytest
import pytest_asyncio
from decimal import Decimal
from datetime import date, datetime
from httpx import AsyncClient, ASGITransport
from fastapi import status
from mywebdav.main import app
from mywebdav.models import User
from mywebdav.billing.models import (
PricingConfig,
Invoice,
UsageAggregate,
UserSubscription,
)
from mywebdav.auth import create_access_token
@pytest_asyncio.fixture
async def test_user():
    """Provide a regular (non-superuser), active account for one test.

    The row is inserted with ``User.create`` before the test body runs and is
    deleted again after the test has finished.
    """
    account_fields = {
        "username": "testuser",
        "email": "test@example.com",
        "hashed_password": "hashed_password_here",
        "is_active": True,
        "is_superuser": False,
    }
    account = await User.create(**account_fields)
    yield account
    # Teardown: remove the row created above.
    await account.delete()
@pytest_asyncio.fixture
async def admin_user():
    """Provide an active superuser account for one test.

    The row is inserted with ``User.create`` before the test body runs and is
    deleted again after the test has finished.
    """
    account_fields = {
        "username": "adminuser",
        "email": "admin@example.com",
        "hashed_password": "hashed_password_here",
        "is_active": True,
        "is_superuser": True,
    }
    account = await User.create(**account_fields)
    yield account
    # Teardown: remove the row created above.
    await account.delete()
@pytest_asyncio.fixture
async def auth_token(test_user):
    """Return an access token whose ``sub`` claim is the regular user's username."""
    claims = {"sub": test_user.username}
    return create_access_token(data=claims)
@pytest_asyncio.fixture
async def admin_token(admin_user):
    """Return an access token whose ``sub`` claim is the admin user's username."""
    claims = {"sub": admin_user.username}
    return create_access_token(data=claims)
@pytest_asyncio.fixture
async def pricing_config():
    """Create the four pricing rows used by the admin pricing tests.

    Yields:
        list: The created ``PricingConfig`` instances, in creation order:
        ``storage_per_gb_month``, ``bandwidth_egress_per_gb``,
        ``free_tier_storage_gb``, ``free_tier_bandwidth_gb``.

    Every row that was successfully created is deleted afterwards, including
    when one of the later ``create`` calls fails part-way through setup.
    """
    # (config_key, config_value, description, unit).
    # Order matters: tests index into the yielded list.
    rows = (
        (
            "storage_per_gb_month",
            Decimal("0.0045"),
            "Storage cost per GB per month",
            "per_gb_month",
        ),
        (
            "bandwidth_egress_per_gb",
            Decimal("0.009"),
            "Bandwidth egress cost per GB",
            "per_gb",
        ),
        (
            "free_tier_storage_gb",
            Decimal("15"),
            "Free tier storage in GB",
            "gb",
        ),
        (
            "free_tier_bandwidth_gb",
            Decimal("15"),
            "Free tier bandwidth in GB per month",
            "gb",
        ),
    )
    configs = []
    # pytest does not run the code after ``yield`` when setup raises, so the
    # creation loop lives inside the ``try`` to guarantee cleanup of any rows
    # that were already inserted.
    try:
        for config_key, config_value, description, unit in rows:
            configs.append(
                await PricingConfig.create(
                    config_key=config_key,
                    config_value=config_value,
                    description=description,
                    unit=unit,
                )
            )
        yield configs
    finally:
        for config in configs:
            await config.delete()
@pytest.mark.asyncio
async def test_get_current_usage(test_user, auth_token):
    """An authenticated GET on the current-usage endpoint returns 200 and usage keys."""
    bearer = {"Authorization": f"Bearer {auth_token}"}
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        resp = await http.get("/api/billing/usage/current", headers=bearer)

    assert resp.status_code == status.HTTP_200_OK
    payload = resp.json()
    for expected_key in ("storage_gb", "bandwidth_down_gb_today"):
        assert expected_key in payload
@pytest.mark.asyncio
async def test_get_monthly_usage(test_user, auth_token):
    """Monthly usage for the current month reflects a stored daily aggregate.

    One ``UsageAggregate`` row with an average of 10 * 1024**3 storage bytes is
    inserted for today; the endpoint is then expected to report a
    ``storage_gb_avg`` of approximately 10.
    """
    gib = 1024**3
    today = date.today()
    # Cleanup sits in ``finally`` so a failing request or assertion does not
    # leave aggregate rows behind for later tests.
    try:
        await UsageAggregate.create(
            user=test_user,
            date=today,
            storage_bytes_avg=gib * 10,
            storage_bytes_peak=gib * 12,
            bandwidth_up_bytes=gib * 2,
            bandwidth_down_bytes=gib * 5,
        )
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.get(
                f"/api/billing/usage/monthly?year={today.year}&month={today.month}",
                headers={"Authorization": f"Bearer {auth_token}"},
            )
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["storage_gb_avg"] == pytest.approx(10.0, rel=0.01)
    finally:
        await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio
async def test_get_subscription(test_user, auth_token):
    """The subscription endpoint reports an active pay-as-you-go plan.

    The test itself creates no ``UserSubscription`` row; it only checks the
    response and then removes any subscription rows belonging to the user.
    """
    # NOTE(review): presumably the endpoint creates a default subscription on
    # first access, which is why rows are deleted here -- confirm against the
    # route implementation.
    try:
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.get(
                "/api/billing/subscription",
                headers={"Authorization": f"Bearer {auth_token}"},
            )
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["billing_type"] == "pay_as_you_go"
        assert data["status"] == "active"
    finally:
        # Runs even when an assertion fails, so no subscription row leaks
        # into other tests.
        await UserSubscription.filter(user=test_user).delete()
@pytest.mark.asyncio
async def test_list_invoices(test_user, auth_token):
    """Listing invoices returns the invoice created for the authenticated user.

    An open invoice numbered ``INV-000001-202311`` is inserted; the list
    endpoint must return a non-empty result whose first entry carries that
    number.
    """
    invoice = await Invoice.create(
        user=test_user,
        invoice_number="INV-000001-202311",
        period_start=date(2023, 11, 1),
        period_end=date(2023, 11, 30),
        subtotal=Decimal("10.00"),
        tax=Decimal("0.00"),
        total=Decimal("10.00"),
        status="open",
    )
    # Delete the invoice in ``finally`` so a failed assertion does not leave
    # the row behind.
    try:
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.get(
                "/api/billing/invoices",
                headers={"Authorization": f"Bearer {auth_token}"},
            )
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert len(data) > 0
        assert data[0]["invoice_number"] == "INV-000001-202311"
    finally:
        await invoice.delete()
@pytest.mark.asyncio
async def test_get_invoice(test_user, auth_token):
    """Fetching a single invoice by id returns that invoice's number.

    An open invoice numbered ``INV-000002-202311`` is inserted and then
    requested through ``/api/billing/invoices/{id}``.
    """
    invoice = await Invoice.create(
        user=test_user,
        invoice_number="INV-000002-202311",
        period_start=date(2023, 11, 1),
        period_end=date(2023, 11, 30),
        subtotal=Decimal("10.00"),
        tax=Decimal("0.00"),
        total=Decimal("10.00"),
        status="open",
    )
    # Delete the invoice in ``finally`` so a failed assertion does not leave
    # the row behind.
    try:
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.get(
                f"/api/billing/invoices/{invoice.id}",
                headers={"Authorization": f"Bearer {auth_token}"},
            )
        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["invoice_number"] == "INV-000002-202311"
    finally:
        await invoice.delete()
@pytest.mark.asyncio
async def test_get_pricing():
    """The pricing endpoint answers 200 with a JSON object, without any auth header."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        resp = await http.get("/api/billing/pricing")

    assert resp.status_code == status.HTTP_200_OK
    payload = resp.json()
    assert isinstance(payload, dict)
@pytest.mark.asyncio
async def test_admin_get_pricing(admin_user, admin_token, pricing_config):
    """An admin token can read the admin pricing list, which must be non-empty."""
    bearer = {"Authorization": f"Bearer {admin_token}"}
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        resp = await http.get("/api/admin/billing/pricing", headers=bearer)

    assert resp.status_code == status.HTTP_200_OK
    entries = resp.json()
    assert len(entries) > 0
@pytest.mark.asyncio
async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
    """An admin PUT on a pricing row persists the new ``config_value``.

    The first row from ``pricing_config`` is updated to 0.005 through the API,
    then re-read from the database to confirm the stored value.
    """
    config_id = pricing_config[0].id
    bearer = {"Authorization": f"Bearer {admin_token}"}
    body = {"config_key": "storage_per_gb_month", "config_value": 0.005}
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        resp = await http.put(
            f"/api/admin/billing/pricing/{config_id}",
            headers=bearer,
            json=body,
        )

    assert resp.status_code == status.HTTP_200_OK
    # Re-fetch instead of trusting the fixture's in-memory instance.
    refreshed = await PricingConfig.get(id=config_id)
    assert refreshed.config_value == Decimal("0.005")
@pytest.mark.asyncio
async def test_admin_get_stats(admin_user, admin_token):
    """The admin stats endpoint returns 200 and exposes the three summary keys."""
    bearer = {"Authorization": f"Bearer {admin_token}"}
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        resp = await http.get("/api/admin/billing/stats", headers=bearer)

    assert resp.status_code == status.HTTP_200_OK
    stats = resp.json()
    for expected_key in ("total_revenue", "total_invoices", "pending_invoices"):
        assert expected_key in stats
@pytest.mark.asyncio
async def test_non_admin_cannot_access_admin_endpoints(test_user, auth_token):
    """A regular user's token is rejected with 403 on the admin pricing endpoint."""
    bearer = {"Authorization": f"Bearer {auth_token}"}
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        resp = await http.get("/api/admin/billing/pricing", headers=bearer)

    assert resp.status_code == status.HTTP_403_FORBIDDEN