from datetime import timedelta
from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from ..auth import authenticate_user, create_access_token, get_password_hash, get_current_user, get_current_verified_user, verify_password
from ..models import User
from ..schemas import Token, UserCreate, TokenData, UserLoginWith2FA
from ..two_factor import (
generate_totp_secret, generate_totp_uri, generate_qr_code_base64,
verify_totp_code, generate_recovery_codes, hash_recovery_codes,
verify_recovery_codes
)
# Authentication and two-factor endpoints. Every route registered on this
# router is served under the "/auth" prefix and grouped under the "auth"
# OpenAPI tag.
router = APIRouter(tags=["auth"], prefix="/auth")
class LoginRequest(BaseModel):
    """Plain username/password credentials (no second factor)."""

    username: str
    password: str
class TwoFactorLogin(BaseModel):
    """Username/password credentials plus an optional two-factor code."""

    username: str
    password: str
    # Optional: omitted (None) when the client does not send a code.
    two_factor_code: Optional[str] = None
class TwoFactorSetupResponse(BaseModel):
    """Data returned to the client when 2FA setup is started."""

    # TOTP secret as generated by generate_totp_secret().
    secret: str
    # Base64 output of generate_qr_code_base64() for the TOTP provisioning URI.
    qr_code_base64: str
    # Plaintext recovery codes; the setup endpoint stores only their hashes,
    # so this is the only time the client receives them.
    recovery_codes: List[str]
class TwoFactorCode(BaseModel):
    """Request body carrying a single two-factor (TOTP) code."""

    two_factor_code: str
class TwoFactorDisable(BaseModel):
    """Re-authentication data required to turn two-factor authentication off."""

    # Current account password; checked against the stored hash.
    password: str
    # Current TOTP code; checked against the stored 2FA secret.
    two_factor_code: str
@router.post("/register", response_model=Token)
async def register_user(user_in: UserCreate):
    """Create a new account, queue a welcome email and return an access token.

    Args:
        user_in: Registration payload; ``username``, ``email`` and
            ``password`` are read from it.

    Returns:
        ``{"access_token": ..., "token_type": "bearer"}`` matching the
        ``Token`` response model.

    Raises:
        HTTPException: 400 if the username or the email is already registered.
    """
    # Function-scope import, matching the local import style used for the
    # mail helper below.
    from html import escape

    # NOTE(review): these existence checks and the create() below are not
    # atomic; two concurrent registrations could both pass them. Presumably
    # unique constraints on the User table are the final guard - confirm.
    user = await User.get_or_none(username=user_in.username)
    if user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered",
        )
    user = await User.get_or_none(email=user_in.email)
    if user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    hashed_password = get_password_hash(user_in.password)
    user = await User.create(
        username=user_in.username,
        email=user_in.email,
        hashed_password=hashed_password,
    )

    # Send welcome email
    from ..mail import queue_email

    # The username is user-controlled; escape it before embedding it in HTML.
    safe_username = escape(user.username)
    queue_email(
        to_email=user.email,
        subject="Welcome to MyWebdav!",
        body=(
            f"Hi {user.username},\n\n"
            "Welcome to MyWebdav! Your account has been created successfully.\n\n"
            "Best regards,\nThe MyWebdav Team"
        ),
        # Triple-quoted so the multi-line HTML is a valid literal (the previous
        # single-quoted string spanning several lines did not compile).
        html=f"""
        <h1>Welcome to MyWebdav!</h1>
        <p>Hi {safe_username},</p>
        <p>Welcome to MyWebdav! Your account has been created successfully.</p>
        <p>Best regards,<br>The MyWebdav Team</p>
        """,
    )

    access_token_expires = timedelta(minutes=30)  # Use settings
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}
@router.post("/token", response_model=Token)
async def login_for_access_token(login_data: TwoFactorLogin):
    """Exchange credentials (plus a TOTP code for 2FA accounts) for a token.

    Accounts without 2FA only need ``username`` and ``password``. When
    ``authenticate_user`` reports ``2fa_required``, the request must also
    carry a ``two_factor_code`` that matches the stored TOTP secret.

    Args:
        login_data: Username, password and an optional 2FA code. The code
            field is optional, so bodies containing only username/password
            are accepted exactly as before.

    Returns:
        ``{"access_token": ..., "token_type": "bearer"}``; the token is
        created with ``two_factor_verified=True``.

    Raises:
        HTTPException: 401 for wrong credentials or an invalid 2FA code;
            403 with header ``X-2FA-Required: true`` when 2FA is required
            but no code was supplied.
    """
    # NOTE(review): the third argument stays None as before. If it is
    # authenticate_user's 2FA-code parameter, forwarding the code there may
    # be preferable (e.g. to share any lockout logic) - confirm in ..auth.
    auth_result = await authenticate_user(login_data.username, login_data.password, None)
    if not auth_result:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user = auth_result["user"]
    if auth_result["2fa_required"]:
        code = login_data.two_factor_code
        if not code:
            # Same signal as before, so clients know to prompt for a code.
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Two-factor authentication required",
                headers={"X-2FA-Required": "true"},
            )
        # The code used to be ignored entirely, so 2FA accounts could never
        # obtain a token here. Verify it the same way /2fa/verify and
        # /2fa/disable do.
        if not verify_totp_code(user.two_factor_secret, code):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid 2FA code.",
            )

    access_token_expires = timedelta(minutes=30)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires, two_factor_verified=True
    )
    return {"access_token": access_token, "token_type": "bearer"}
@router.post("/2fa/setup", response_model=TwoFactorSetupResponse)
async def setup_two_factor_authentication(current_user: User = Depends(get_current_user)):
    """Start (or restart) 2FA enrollment for the current user.

    Generates a new TOTP secret and a fresh set of recovery codes. The secret
    and the hashed recovery codes are stored on the user. 2FA is NOT enabled
    here; that happens once the user confirms a code via ``/2fa/verify``.

    Returns:
        ``TwoFactorSetupResponse`` with the secret, a base64 QR code of the
        TOTP provisioning URI and the plaintext recovery codes. Only hashes
        of the recovery codes are persisted, so this is the only place the
        plaintext codes are exposed.

    Raises:
        HTTPException: 400 if 2FA is already enabled.
    """
    if current_user.is_2fa_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")

    # A pending (unverified) secret used to block any new setup attempt, while
    # /2fa/disable rejects accounts whose 2FA is not enabled. A user who lost
    # the QR code was therefore stuck for good. Restarting setup now just
    # replaces the pending secret and recovery codes.
    secret = generate_totp_secret()
    totp_uri = generate_totp_uri(secret, current_user.email, "MyWebdav")
    qr_code_base64 = generate_qr_code_base64(totp_uri)
    recovery_codes = generate_recovery_codes()

    current_user.two_factor_secret = secret
    current_user.recovery_codes = ",".join(hash_recovery_codes(recovery_codes))
    # Single save: a failure part-way through can no longer leave a stored
    # secret without its recovery codes.
    await current_user.save()

    return TwoFactorSetupResponse(secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes)
@router.post("/2fa/verify", response_model=Token)
async def verify_two_factor_authentication(
    two_factor_code_data: TwoFactorCode,
    current_user: User = Depends(get_current_user),
):
    """Finish 2FA enrollment by checking a TOTP code against the pending secret.

    On success, 2FA is enabled for the user and a new access token is issued
    with ``two_factor_verified=True``.

    Raises:
        HTTPException: 400 if 2FA is already enabled or setup was never
            started; 401 if the supplied code does not match.
    """
    user = current_user

    # Guard clauses: enrollment must be pending and not yet completed.
    if user.is_2fa_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
    pending_secret = user.two_factor_secret
    if not pending_secret:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated.")

    code_matches = verify_totp_code(pending_secret, two_factor_code_data.two_factor_code)
    if not code_matches:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")

    user.is_2fa_enabled = True
    await user.save()

    # TODO: take the token lifetime from settings instead of hard-coding it.
    lifetime = timedelta(minutes=30)
    token = create_access_token(
        data={"sub": user.username},
        expires_delta=lifetime,
        two_factor_verified=True,
    )
    return {"access_token": token, "token_type": "bearer"}
@router.post("/2fa/disable", response_model=dict)
async def disable_two_factor_authentication(
    disable_data: TwoFactorDisable,
    current_user: User = Depends(get_current_verified_user),
):
    """Turn 2FA off after re-checking both the password and a current TOTP code.

    Clears the stored secret and recovery codes and marks 2FA as disabled.

    Raises:
        HTTPException: 400 if 2FA is not enabled; 401 if the password or the
            2FA code is wrong.
    """
    if not current_user.is_2fa_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")

    # Both factors must be re-presented before protection is removed.
    password_ok = verify_password(disable_data.password, current_user.hashed_password)
    if not password_ok:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password.")
    code_ok = verify_totp_code(current_user.two_factor_secret, disable_data.two_factor_code)
    if not code_ok:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")

    # Wipe every piece of 2FA state, then persist once.
    current_user.is_2fa_enabled = False
    current_user.two_factor_secret = None
    current_user.recovery_codes = None
    await current_user.save()

    return {"message": "2FA disabled successfully."}
@router.get("/2fa/recovery-codes", response_model=List[str])
async def get_new_recovery_codes(current_user: User = Depends(get_current_verified_user)):
    """Issue a brand-new set of recovery codes, replacing the stored ones.

    Only the hashes are saved on the user; the plaintext codes are returned
    once to the caller.

    Raises:
        HTTPException: 400 if 2FA is not enabled.
    """
    # NOTE(review): this is a GET that changes server state (previous codes are
    # invalidated); a POST would be more conventional.
    if not current_user.is_2fa_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")

    fresh_codes = generate_recovery_codes()
    current_user.recovery_codes = ",".join(hash_recovery_codes(fresh_codes))
    await current_user.save()
    return fresh_codes