from datetime import timedelta
from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from ..auth import (
authenticate_user,
create_access_token,
get_password_hash,
get_current_user,
get_current_verified_user,
verify_password,
)
from ..models import User
from ..schemas import Token, UserCreate
from ..two_factor import (
generate_totp_secret,
generate_totp_uri,
generate_qr_code_base64,
verify_totp_code,
generate_recovery_codes,
hash_recovery_codes,
)
# Router collecting the authentication endpoints defined in this module.
# Every route registered on it is served under the "/auth" prefix and is
# grouped under the "auth" tag.
router = APIRouter(prefix="/auth", tags=["auth"])
class LoginRequest(BaseModel):
    """Request body carrying the credentials for a password login."""

    # Account name handed to authenticate_user by the /token endpoint.
    username: str
    # Plain-text password handed to authenticate_user by the /token endpoint.
    password: str
# NOTE(review): none of the endpoints visible in this part of the file use
# this model; confirm whether it is referenced elsewhere before removing it.
class TwoFactorLogin(BaseModel):
    """Login credentials together with an optional two-factor code."""

    username: str
    password: str
    # Optional: the field defaults to None when the client omits it.
    two_factor_code: Optional[str] = None
class TwoFactorSetupResponse(BaseModel):
    """Response body returned when two-factor setup is initiated."""

    # TOTP secret generated for the user.
    secret: str
    # Output of generate_qr_code_base64 for the user's TOTP URI.
    qr_code_base64: str
    # Recovery codes as generated (the stored copies are hashed).
    recovery_codes: List[str]
class TwoFactorCode(BaseModel):
    """Request body carrying a single two-factor code."""

    # Code checked against the user's stored TOTP secret.
    two_factor_code: str
class TwoFactorDisable(BaseModel):
    """Request body for turning two-factor authentication off."""

    # Current account password; checked with verify_password.
    password: str
    # Current TOTP code; checked with verify_totp_code.
    two_factor_code: str
@router.post("/register", response_model=Token)
async def register_user(user_in: UserCreate):
    """Create a new account, queue a welcome email and return an access token.

    Args:
        user_in: Registration payload; its ``username``, ``email`` and
            ``password`` fields are read.

    Returns:
        A dict with ``access_token`` and ``token_type`` (``"bearer"``).

    Raises:
        HTTPException: 400 when the username or the email is already
            registered.
    """
    # Function-scope imports: ``queue_email`` was imported locally in the
    # original code, and ``escape`` is only needed for the welcome email.
    from html import escape

    from ..mail import queue_email

    if await User.get_or_none(username=user_in.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered",
        )
    if await User.get_or_none(email=user_in.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    user = await User.create(
        username=user_in.username,
        email=user_in.email,
        hashed_password=get_password_hash(user_in.password),
    )

    # The username is user-supplied, so it must be escaped before it is
    # embedded in the HTML part of the email; the plain-text part is left as is.
    safe_username = escape(user.username)

    # Send welcome email.
    # NOTE(review): the HTML markup below was reconstructed because the
    # original literal was not a valid string; confirm it matches the
    # intended template.
    queue_email(
        to_email=user.email,
        subject="Welcome to MyWebdav!",
        body=(
            f"Hi {user.username},\n\n"
            "Welcome to MyWebdav! Your account has been created successfully.\n\n"
            "Best regards,\nThe MyWebdav Team"
        ),
        html=(
            "<h1>Welcome to MyWebdav!</h1>\n"
            f"<p>Hi {safe_username},</p>\n"
            "<p>Welcome to MyWebdav! Your account has been created successfully.</p>\n"
            "<p>Best regards,<br>The MyWebdav Team</p>\n"
        ),
    )

    access_token_expires = timedelta(minutes=30)  # Use settings
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}
@router.post("/token", response_model=Token)
async def login_for_access_token(login_data: LoginRequest):
    """Exchange a username and password for a bearer access token.

    The third argument to ``authenticate_user`` is always ``None`` here, so
    no two-factor code is ever forwarded from this endpoint.

    Args:
        login_data: Credentials submitted by the client.

    Returns:
        A dict with ``access_token`` and ``token_type`` (``"bearer"``).

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy result;
            403 (with an ``X-2FA-Required`` header) when the result reports
            that two-factor authentication is required.
    """
    outcome = await authenticate_user(
        login_data.username, login_data.password, None
    )
    if not outcome:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    account = outcome["user"]
    if outcome["2fa_required"]:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Two-factor authentication required",
            headers={"X-2FA-Required": "true"},
        )

    # Token lifetime is fixed at 30 minutes.
    lifetime = timedelta(minutes=30)
    token = create_access_token(
        data={"sub": account.username},
        expires_delta=lifetime,
        two_factor_verified=True,
    )
    return {"access_token": token, "token_type": "bearer"}
@router.post("/2fa/setup", response_model=TwoFactorSetupResponse)
async def setup_two_factor_authentication(
    current_user: User = Depends(get_current_user),
):
    """Start two-factor setup: store a new TOTP secret and recovery codes.

    Args:
        current_user: The user resolved by the ``get_current_user`` dependency.

    Returns:
        A ``TwoFactorSetupResponse`` with the secret, a base64 QR code for the
        TOTP URI and the unhashed recovery codes.

    Raises:
        HTTPException: 400 when 2FA is already enabled, or when a secret is
            already stored for the user.
    """
    if current_user.is_2fa_enabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
        )
    if current_user.two_factor_secret:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="2FA setup already initiated. Verify or disable first.",
        )

    # Build everything that can fail before touching the user record, so an
    # error here cannot leave a stored secret without recovery codes (which
    # would then trip the "already initiated" guard above on retry).
    secret = generate_totp_secret()
    totp_uri = generate_totp_uri(secret, current_user.email, "MyWebdav")
    qr_code_base64 = generate_qr_code_base64(totp_uri)
    recovery_codes = generate_recovery_codes()
    hashed_recovery_codes = hash_recovery_codes(recovery_codes)

    # Persist the secret and the hashed recovery codes in a single save.
    current_user.two_factor_secret = secret
    current_user.recovery_codes = ",".join(hashed_recovery_codes)
    await current_user.save()

    return TwoFactorSetupResponse(
        secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes
    )
@router.post("/2fa/verify", response_model=Token)
async def verify_two_factor_authentication(
    two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)
):
    """Confirm a pending two-factor setup and return a new access token.

    Args:
        two_factor_code_data: Body holding the code to check.
        current_user: The user resolved by the ``get_current_user`` dependency.

    Returns:
        A dict with ``access_token`` and ``token_type`` (``"bearer"``).

    Raises:
        HTTPException: 400 when 2FA is already enabled or no secret is stored;
            401 when the submitted code does not verify.
    """
    if current_user.is_2fa_enabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
        )

    stored_secret = current_user.two_factor_secret
    if not stored_secret:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated."
        )

    code_ok = verify_totp_code(stored_secret, two_factor_code_data.two_factor_code)
    if not code_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
        )

    # The code matched, so switch 2FA on for this user.
    current_user.is_2fa_enabled = True
    await current_user.save()

    lifetime = timedelta(minutes=30)  # Use settings
    token = create_access_token(
        data={"sub": current_user.username},
        expires_delta=lifetime,
        two_factor_verified=True,
    )
    return {"access_token": token, "token_type": "bearer"}
@router.post("/2fa/disable", response_model=dict)
async def disable_two_factor_authentication(
    disable_data: TwoFactorDisable,
    current_user: User = Depends(get_current_verified_user),
):
    """Turn two-factor authentication off after re-checking both factors.

    Args:
        disable_data: Body holding the account password and a 2FA code.
        current_user: The user resolved by the ``get_current_verified_user``
            dependency.

    Returns:
        A dict with a single ``message`` entry.

    Raises:
        HTTPException: 400 when 2FA is not enabled; 401 when the password or
            the 2FA code does not verify.
    """
    if not current_user.is_2fa_enabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
        )

    # The password is checked first, then the TOTP code.
    password_ok = verify_password(disable_data.password, current_user.hashed_password)
    if not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password."
        )

    code_ok = verify_totp_code(
        current_user.two_factor_secret, disable_data.two_factor_code
    )
    if not code_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
        )

    # Clear all stored 2FA state for the user.
    current_user.two_factor_secret = None
    current_user.is_2fa_enabled = False
    current_user.recovery_codes = None
    await current_user.save()

    return {"message": "2FA disabled successfully."}
@router.get("/2fa/recovery-codes", response_model=List[str])
async def get_new_recovery_codes(
    current_user: User = Depends(get_current_verified_user),
):
    """Generate a fresh set of recovery codes and return them unhashed.

    Although this is a GET route, each call overwrites the user's stored
    recovery codes with the hashes of the newly generated ones.

    Args:
        current_user: The user resolved by the ``get_current_verified_user``
            dependency.

    Returns:
        The newly generated recovery codes.

    Raises:
        HTTPException: 400 when 2FA is not enabled for the user.
    """
    if not current_user.is_2fa_enabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
        )

    fresh_codes = generate_recovery_codes()
    # Only the hashed form is stored, joined into one comma-separated string.
    current_user.recovery_codes = ",".join(hash_recovery_codes(fresh_codes))
    await current_user.save()
    return fresh_codes