from datetime import timedelta
from html import escape as html_escape
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel

from ..auth import (
    authenticate_user,
    create_access_token,
    get_password_hash,
    get_current_user,
    get_current_verified_user,
    verify_password,
)
from ..models import User
from ..schemas import Token, UserCreate, TokenData, UserLoginWith2FA
from ..two_factor import (
    generate_totp_secret,
    generate_totp_uri,
    generate_qr_code_base64,
    verify_totp_code,
    generate_recovery_codes,
    hash_recovery_codes,
    verify_recovery_codes,
)

router = APIRouter(
    prefix="/auth",
    tags=["auth"],
)

# Lifetime of every access token issued by this router.
# TODO: read this from application settings instead of hard-coding it.
ACCESS_TOKEN_EXPIRE_MINUTES = 30


# Plain username/password credentials accepted by POST /auth/token.
class LoginRequest(BaseModel):
    username: str
    password: str


# Credentials plus an optional TOTP code.
# NOTE(review): not referenced by any endpoint in this module.
class TwoFactorLogin(BaseModel):
    username: str
    password: str
    two_factor_code: Optional[str] = None


# Response of POST /auth/2fa/setup: the TOTP secret, a base64 QR image of the
# provisioning URI, and the plaintext recovery codes (only their hashes are stored).
class TwoFactorSetupResponse(BaseModel):
    secret: str
    qr_code_base64: str
    recovery_codes: List[str]


# Body carrying a single TOTP code, used by POST /auth/2fa/verify.
class TwoFactorCode(BaseModel):
    two_factor_code: str


# Body for POST /auth/2fa/disable: both the password and a current TOTP code are required.
class TwoFactorDisable(BaseModel):
    password: str
    two_factor_code: str


def _issue_access_token(username: str, **token_options) -> dict:
    """Create a bearer token for *username* in the ``Token`` response shape.

    Extra keyword arguments (e.g. ``two_factor_verified=True``) are forwarded
    unchanged to ``create_access_token``.
    """
    access_token = create_access_token(
        data={"sub": username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        **token_options,
    )
    return {"access_token": access_token, "token_type": "bearer"}


@router.post("/register", response_model=Token)
async def register_user(user_in: UserCreate):
    """Create an account, queue a welcome email and return an access token.

    Raises HTTP 400 if the username or the email is already registered.
    """
    # NOTE(review): these checks and the create below are not atomic; two
    # concurrent registrations may both pass them. Confirm the DB enforces
    # uniqueness on username/email.
    if await User.get_or_none(username=user_in.username):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered",
        )
    if await User.get_or_none(email=user_in.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    user = await User.create(
        username=user_in.username,
        email=user_in.email,
        hashed_password=get_password_hash(user_in.password),
    )

    # Imported at call time, presumably to avoid an import cycle; confirm
    # before hoisting to module level.
    from ..mail import queue_email

    # The username is user-supplied: escape it before embedding it in HTML so
    # it cannot inject markup into the email.
    safe_username = html_escape(user.username)
    queue_email(
        to_email=user.email,
        subject="Welcome to MyWebdav!",
        body=f"Hi {user.username},\n\nWelcome to MyWebdav! Your account has been created successfully.\n\nBest regards,\nThe MyWebdav Team",
        # NOTE(review): markup reconstructed; confirm against the intended email template.
        html=(
            "<h1>Welcome to MyWebdav!</h1>"
            f"<p>Hi {safe_username},</p>"
            "<p>Welcome to MyWebdav! Your account has been created successfully.</p>"
            "<p>Best regards,<br>The MyWebdav Team</p>"
        ),
    )

    return _issue_access_token(user.username)


@router.post("/token", response_model=Token)
async def login_for_access_token(login_data: LoginRequest):
    """Exchange a username and password for an access token.

    Raises HTTP 401 for bad credentials, and HTTP 403 with the header
    ``X-2FA-Required: true`` when ``authenticate_user`` reports that the
    account needs a second factor.
    """
    # No second-factor code is collected here, so None is passed as the third
    # argument (presumably the 2FA code slot of authenticate_user — confirm).
    auth_result = await authenticate_user(login_data.username, login_data.password, None)
    if not auth_result:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user = auth_result["user"]
    if auth_result["2fa_required"]:
        # NOTE(review): no endpoint in this module accepts a 2FA code at login
        # (TwoFactorLogin is unused), so accounts that require 2FA cannot get a
        # token here. Confirm another route covers this.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Two-factor authentication required",
            headers={"X-2FA-Required": "true"},
        )

    return _issue_access_token(user.username, two_factor_verified=True)


@router.post("/2fa/setup", response_model=TwoFactorSetupResponse)
async def setup_two_factor_authentication(current_user: User = Depends(get_current_user)):
    """Begin TOTP enrolment for the current user.

    Generates and stores a new TOTP secret and hashed recovery codes. Returns
    the secret, a base64 QR code of the provisioning URI, and the plaintext
    recovery codes. 2FA is not enabled until ``/2fa/verify`` succeeds.
    Raises HTTP 400 if 2FA is already enabled or a setup is already pending.
    """
    if current_user.is_2fa_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
    if current_user.two_factor_secret:
        # NOTE(review): /2fa/disable rejects users whose 2FA is not enabled, so
        # a pending setup can currently only be cleared by verifying it.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="2FA setup already initiated. Verify or disable first.",
        )

    secret = generate_totp_secret()
    # NOTE(review): issuer "RBox" differs from the "MyWebdav" branding used in
    # emails. Confirm which name authenticator apps should show.
    totp_uri = generate_totp_uri(secret, current_user.email, "RBox")
    qr_code_base64 = generate_qr_code_base64(totp_uri)
    recovery_codes = generate_recovery_codes()

    # Build everything first, then persist the secret and hashed codes in a
    # single save. A failure above therefore cannot leave a stored secret with
    # no recovery codes, which would block every retry of this endpoint.
    current_user.two_factor_secret = secret
    current_user.recovery_codes = ",".join(hash_recovery_codes(recovery_codes))
    await current_user.save()

    return TwoFactorSetupResponse(
        secret=secret,
        qr_code_base64=qr_code_base64,
        recovery_codes=recovery_codes,
    )


@router.post("/2fa/verify", response_model=Token)
async def verify_two_factor_authentication(
    two_factor_code_data: TwoFactorCode,
    current_user: User = Depends(get_current_user),
):
    """Confirm a pending TOTP setup with a valid code, enabling 2FA.

    Returns a fresh access token issued with ``two_factor_verified=True``.
    Raises HTTP 400 if 2FA is already enabled or no setup is pending, and
    HTTP 401 if the code is invalid.
    """
    if current_user.is_2fa_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
    if not current_user.two_factor_secret:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated.")
    if not verify_totp_code(current_user.two_factor_secret, two_factor_code_data.two_factor_code):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")

    current_user.is_2fa_enabled = True
    await current_user.save()

    return _issue_access_token(current_user.username, two_factor_verified=True)


@router.post("/2fa/disable", response_model=dict)
async def disable_two_factor_authentication(
    disable_data: TwoFactorDisable,
    current_user: User = Depends(get_current_verified_user),
):
    """Turn off 2FA after re-checking both the password and a TOTP code.

    Clears the TOTP secret, the enabled flag and the stored recovery codes.
    Raises HTTP 400 if 2FA is not enabled, and HTTP 401 for a wrong password
    or an invalid code.
    """
    if not current_user.is_2fa_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")

    # Require both factors again, so a hijacked session alone cannot remove 2FA.
    if not verify_password(disable_data.password, current_user.hashed_password):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password.")
    if not verify_totp_code(current_user.two_factor_secret, disable_data.two_factor_code):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")

    current_user.two_factor_secret = None
    current_user.is_2fa_enabled = False
    current_user.recovery_codes = None
    await current_user.save()
    return {"message": "2FA disabled successfully."}


@router.get("/2fa/recovery-codes", response_model=List[str])
async def get_new_recovery_codes(current_user: User = Depends(get_current_verified_user)):
    """Replace the user's recovery codes with a new set and return them in plaintext.

    Only the hashes are stored; any previously issued codes are overwritten.
    Raises HTTP 400 if 2FA is not enabled.
    """
    # NOTE(review): this GET mutates state (it invalidates the old codes).
    # Consider POST so prefetching or caching cannot silently rotate the codes.
    if not current_user.is_2fa_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")

    recovery_codes = generate_recovery_codes()
    current_user.recovery_codes = ",".join(hash_recovery_codes(recovery_codes))
    await current_user.save()
    return recovery_codes