import argparse
import uvicorn
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException, status
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse, Response, RedirectResponse
from fastapi.templating import Jinja2Templates
from tortoise.contrib.fastapi import register_tortoise
from .settings import settings
from .routers import (
auth,
users,
folders,
files,
shares,
search,
admin,
starred,
billing,
admin_billing,
)
from . import webdav
from .schemas import ErrorResponse
from .middleware.usage_tracking import UsageTrackingMiddleware
# Single log line layout shared by every module that logs through the root logger.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the startup hooks and exception handler below.
logger = logging.getLogger(__name__)
async def _seed_default_pricing():
    """Insert the default ``PricingConfig`` rows, but only when the table is empty.

    Seeding is skipped as soon as any pricing row exists, so values edited
    after the first start are never overwritten on restart.
    """
    from decimal import Decimal

    from .billing.models import PricingConfig

    if await PricingConfig.all().count() != 0:
        return

    # (config_key, config_value as a Decimal-safe string, description, unit)
    defaults = (
        ("storage_per_gb_month", "0.0045", "Storage cost per GB per month", "per_gb_month"),
        ("bandwidth_egress_per_gb", "0.009", "Bandwidth egress cost per GB", "per_gb"),
        ("bandwidth_ingress_per_gb", "0.0", "Bandwidth ingress cost per GB (free)", "per_gb"),
        ("free_tier_storage_gb", "15", "Free tier storage in GB", "gb"),
        ("free_tier_bandwidth_gb", "15", "Free tier bandwidth in GB per month", "gb"),
        ("tax_rate_default", "0.0", "Default tax rate (0 = no tax)", "percentage"),
    )
    for config_key, value, description, unit in defaults:
        # Values are built from strings so Decimal keeps them exact (no float rounding).
        await PricingConfig.create(
            config_key=config_key,
            config_value=Decimal(value),
            description=description,
            unit=unit,
        )
    logger.info("Default pricing configuration initialized")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup/shutdown hook passed to ``FastAPI(lifespan=...)``.

    Startup: starts the billing scheduler and the email service, then seeds
    default pricing rows if none exist. Shutdown: stops the scheduler and the
    email service. Shutdown runs in a ``finally`` so both services are also
    stopped if seeding fails or an exception is thrown into the lifespan.
    """
    from .billing.scheduler import start_scheduler, stop_scheduler
    from .mail import email_service

    logger.info("Starting up...")
    # NOTE(review): nothing in this function opens the database; this message
    # assumes register_tortoise has initialised Tortoise before this hook runs.
    logger.info("Database connected.")

    start_scheduler()
    logger.info("Billing scheduler started")
    await email_service.start()
    logger.info("Email service started")

    try:
        await _seed_default_pricing()
        yield
    finally:
        logger.info("Shutting down...")
        # Same order as startup-side dependencies: scheduler first, then email.
        stop_scheduler()
        logger.info("Billing scheduler stopped")
        await email_service.stop()
        logger.info("Email service stopped")
app = FastAPI(
    title="MyWebdav Cloud Storage",
    description="A commercial cloud storage web application",
    version="0.1.0",
    lifespan=lifespan,
)

# Jinja templates for the server-rendered marketing/legal pages
# (path is relative to the process working directory).
templates = Jinja2Templates(directory="mywebdav/templates")

# Routers are included in this exact order; FastAPI matches routes in
# registration order, so keep the sequence stable.
for _router_module in (
    auth,
    users,
    folders,
    files,
    shares,
    search,
    admin,
    starred,
    billing,
    admin_billing,
    webdav,
):
    app.include_router(_router_module.router)
del _router_module

app.add_middleware(UsageTrackingMiddleware)

# Static assets served from ./static (relative to the working directory).
app.mount("/static", StaticFiles(directory="static"), name="static")

# Wire Tortoise ORM to the app: core models and billing models live in
# separate app labels but share the same database URL.
register_tortoise(
    app,
    db_url=settings.DATABASE_URL,
    modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
    generate_schemas=True,
    add_exception_handlers=True,
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` in a format suited to the requesting client.

    - ``/webdav`` paths, 401: empty body with the exception's headers (a JSON
      body confuses WebDAV clients during authentication challenges).
    - Other ``/webdav`` errors: plain-text body containing ``exc.detail``.
    - Everything else: JSON ``ErrorResponse`` with the status code and detail.
    """
    logger.error(
        "HTTPException: %s - %s for URL: %s", exc.status_code, exc.detail, request.url
    )
    headers = exc.headers

    if request.url.path.startswith("/webdav"):
        # For WebDAV authentication challenges, we must return the headers
        # from the exception and an empty body.
        if exc.status_code == status.HTTP_401_UNAUTHORIZED:
            return Response(status_code=exc.status_code, headers=headers)
        # HTTPException.detail may be any object (dict, list, ...), but a raw
        # Response body must be str/bytes; stringify anything else instead of
        # failing with a 500 while rendering the error.
        detail = exc.detail
        if not isinstance(detail, (str, bytes)):
            detail = str(detail)
        return Response(content=detail, status_code=exc.status_code, headers=headers)

    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(code=exc.status_code, message=exc.detail).model_dump(),
        headers=headers,
    )
@app.get("/", response_class=HTMLResponse)
async def splash_page(request: Request):
    """Render the landing page from ``splash.html``."""
    context = {"request": request}
    return templates.TemplateResponse("splash.html", context)
@app.get("/features", response_class=HTMLResponse)
async def features_page(request: Request):
    """Render the product features page from ``features.html``."""
    context = {"request": request}
    return templates.TemplateResponse("features.html", context)
@app.get("/pricing", response_class=HTMLResponse)
async def pricing_page(request: Request):
    """Render the public pricing page from ``pricing.html``."""
    context = {"request": request}
    return templates.TemplateResponse("pricing.html", context)
@app.get("/support", response_class=HTMLResponse)
async def support_page(request: Request):
    """Render the support page from ``support.html``."""
    context = {"request": request}
    return templates.TemplateResponse("support.html", context)
@app.get("/legal/{document_name}", response_class=HTMLResponse)
async def legal_document(request: Request, document_name: str):
    """Render ``legal/<document_name>.html``, or respond 404 if it does not exist.

    Only a missing template is mapped to 404. Errors raised while loading or
    rendering an existing template propagate (surfacing as a server error)
    instead of being disguised as "not found".
    """
    template_name = f"legal/{document_name}.html"
    try:
        return templates.TemplateResponse(template_name, {"request": request})
    except LookupError:
        # jinja2.TemplateNotFound derives from LookupError; catching it (rather
        # than Exception) keeps template bugs visible.
        raise HTTPException(status_code=404, detail="Legal document not found") from None
@app.get("/login")
async def login_redirect():
    """Redirect ``/login`` to the web client at ``/app`` (302 Found)."""
    destination = "/app"
    return RedirectResponse(destination, status_code=status.HTTP_302_FOUND)
@app.get("/app", response_class=HTMLResponse)
async def web_app():
    """Return the contents of ``static/index.html`` as the web client page.

    The file is read on every request, relative to the working directory.
    """
    # HTMLResponse encodes the returned text as UTF-8, so decode the file as
    # UTF-8 too. Relying on the platform default encoding can mangle
    # non-ASCII markup on hosts whose locale is not UTF-8.
    # NOTE(review): this is a blocking read inside an async handler; it is
    # small, but consider caching or offloading if it shows up in profiles.
    with open("static/index.html", "r", encoding="utf-8") as f:
        return f.read()
def main():
    """Command-line entry point: parse ``--host``/``--port`` and serve ``app``."""
    cli = argparse.ArgumentParser(description="Run the MyWebdav application.")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host address to bind to")
    cli.add_argument("--port", type=int, default=8000, help="Port to listen on")
    options = cli.parse_args()

    bind_host, bind_port = options.host, options.port
    uvicorn.run(app, host=bind_host, port=bind_port)


# Allow `python -m <package>.<module>` / direct execution to start the server.
if __name__ == "__main__":
    main()