This commit is contained in:
retoor 2025-11-13 23:22:05 +01:00
parent aff8dfca08
commit d350ab6807
34 changed files with 1851 additions and 875 deletions

View File

@ -1,17 +1,18 @@
from typing import Optional from typing import Optional
from .models import Activity, User from .models import Activity, User
async def log_activity( async def log_activity(
user: Optional[User], user: Optional[User],
action: str, action: str,
target_type: str, target_type: str,
target_id: int, target_id: int,
ip_address: Optional[str] = None ip_address: Optional[str] = None,
): ):
await Activity.create( await Activity.create(
user=user, user=user,
action=action, action=action,
target_type=target_type, target_type=target_type,
target_id=target_id, target_id=target_id,
ip_address=ip_address ip_address=ip_address,
) )

View File

@ -9,20 +9,29 @@ import bcrypt
from .schemas import TokenData from .schemas import TokenData
from .settings import settings from .settings import settings
from .models import User from .models import User
from .two_factor import verify_totp_code # Import verify_totp_code from .two_factor import verify_totp_code # Import verify_totp_code
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def verify_password(plain_password, hashed_password): def verify_password(plain_password, hashed_password):
password_bytes = plain_password[:72].encode('utf-8') password_bytes = plain_password[:72].encode("utf-8")
hashed_bytes = hashed_password.encode('utf-8') if isinstance(hashed_password, str) else hashed_password hashed_bytes = (
hashed_password.encode("utf-8")
if isinstance(hashed_password, str)
else hashed_password
)
return bcrypt.checkpw(password_bytes, hashed_bytes) return bcrypt.checkpw(password_bytes, hashed_bytes)
def get_password_hash(password):
password_bytes = password[:72].encode('utf-8')
return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode('utf-8')
async def authenticate_user(username: str, password: str, two_factor_code: Optional[str] = None): def get_password_hash(password):
password_bytes = password[:72].encode("utf-8")
return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode("utf-8")
async def authenticate_user(
username: str, password: str, two_factor_code: Optional[str] = None
):
user = await User.get_or_none(username=username) user = await User.get_or_none(username=username)
if not user: if not user:
return None return None
@ -33,19 +42,29 @@ async def authenticate_user(username: str, password: str, two_factor_code: Optio
if not two_factor_code: if not two_factor_code:
return {"user": user, "2fa_required": True} return {"user": user, "2fa_required": True}
if not verify_totp_code(user.two_factor_secret, two_factor_code): if not verify_totp_code(user.two_factor_secret, two_factor_code):
return None # 2FA code is incorrect return None # 2FA code is incorrect
return {"user": user, "2fa_required": False} return {"user": user, "2fa_required": False}
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, two_factor_verified: bool = False):
def create_access_token(
data: dict,
expires_delta: Optional[timedelta] = None,
two_factor_verified: bool = False,
):
to_encode = data.copy() to_encode = data.copy()
if expires_delta: if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta expire = datetime.now(timezone.utc) + expires_delta
else: else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode.update({"exp": expire, "2fa_verified": two_factor_verified}) to_encode.update({"exp": expire, "2fa_verified": two_factor_verified})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt return encoded_jwt
async def get_current_user(token: str = Depends(oauth2_scheme)): async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -53,31 +72,45 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub") username: str = payload.get("sub")
two_factor_verified: bool = payload.get("2fa_verified", False) two_factor_verified: bool = payload.get("2fa_verified", False)
if username is None: if username is None:
raise credentials_exception raise credentials_exception
token_data = TokenData(username=username, two_factor_verified=two_factor_verified) token_data = TokenData(
username=username, two_factor_verified=two_factor_verified
)
except JWTError: except JWTError:
raise credentials_exception raise credentials_exception
user = await User.get_or_none(username=token_data.username) user = await User.get_or_none(username=token_data.username)
if user is None: if user is None:
raise credentials_exception raise credentials_exception
user.token_data = token_data # Attach token_data to user for easy access user.token_data = token_data # Attach token_data to user for easy access
return user return user
async def get_current_active_user(current_user: User = Depends(get_current_user)): async def get_current_active_user(current_user: User = Depends(get_current_user)):
if not current_user.is_active: if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user") raise HTTPException(status_code=400, detail="Inactive user")
return current_user return current_user
async def get_current_verified_user(current_user: User = Depends(get_current_user)): async def get_current_verified_user(current_user: User = Depends(get_current_user)):
if current_user.is_2fa_enabled and not current_user.token_data.two_factor_verified: if current_user.is_2fa_enabled and not current_user.token_data.two_factor_verified:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="2FA required and not verified") raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="2FA required and not verified",
)
return current_user return current_user
async def get_current_admin_user(current_user: User = Depends(get_current_verified_user)):
async def get_current_admin_user(
current_user: User = Depends(get_current_verified_user),
):
if not current_user.is_superuser: if not current_user.is_superuser:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions") raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
)
return current_user return current_user

View File

@ -2,14 +2,17 @@ from datetime import datetime, date, timedelta, timezone
from decimal import Decimal from decimal import Decimal
from typing import Optional from typing import Optional
from calendar import monthrange from calendar import monthrange
from .models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription from .models import Invoice, InvoiceLineItem, PricingConfig, UserSubscription
from .usage_tracker import UsageTracker from .usage_tracker import UsageTracker
from .stripe_client import StripeClient from .stripe_client import StripeClient
from ..models import User from ..models import User
class InvoiceGenerator: class InvoiceGenerator:
@staticmethod @staticmethod
async def generate_monthly_invoice(user: User, year: int, month: int) -> Optional[Invoice]: async def generate_monthly_invoice(
user: User, year: int, month: int
) -> Optional[Invoice]:
period_start = date(year, month, 1) period_start = date(year, month, 1)
days_in_month = monthrange(year, month)[1] days_in_month = monthrange(year, month)[1]
period_end = date(year, month, days_in_month) period_end = date(year, month, days_in_month)
@ -19,19 +22,24 @@ class InvoiceGenerator:
pricing = await PricingConfig.all() pricing = await PricingConfig.all()
pricing_dict = {p.config_key: p.config_value for p in pricing} pricing_dict = {p.config_key: p.config_value for p in pricing}
storage_price_per_gb = pricing_dict.get('storage_per_gb_month', Decimal('0.0045')) storage_price_per_gb = pricing_dict.get(
bandwidth_price_per_gb = pricing_dict.get('bandwidth_egress_per_gb', Decimal('0.009')) "storage_per_gb_month", Decimal("0.0045")
free_storage_gb = pricing_dict.get('free_tier_storage_gb', Decimal('15')) )
free_bandwidth_gb = pricing_dict.get('free_tier_bandwidth_gb', Decimal('15')) bandwidth_price_per_gb = pricing_dict.get(
tax_rate = pricing_dict.get('tax_rate_default', Decimal('0')) "bandwidth_egress_per_gb", Decimal("0.009")
)
free_storage_gb = pricing_dict.get("free_tier_storage_gb", Decimal("15"))
free_bandwidth_gb = pricing_dict.get("free_tier_bandwidth_gb", Decimal("15"))
tax_rate = pricing_dict.get("tax_rate_default", Decimal("0"))
storage_gb = Decimal(str(usage['storage_gb_avg'])) storage_gb = Decimal(str(usage["storage_gb_avg"]))
bandwidth_gb = Decimal(str(usage['bandwidth_down_gb'])) bandwidth_gb = Decimal(str(usage["bandwidth_down_gb"]))
billable_storage = max(Decimal('0'), storage_gb - free_storage_gb) billable_storage = max(Decimal("0"), storage_gb - free_storage_gb)
billable_bandwidth = max(Decimal('0'), bandwidth_gb - free_bandwidth_gb) billable_bandwidth = max(Decimal("0"), bandwidth_gb - free_bandwidth_gb)
import math import math
billable_storage_rounded = Decimal(math.ceil(float(billable_storage))) billable_storage_rounded = Decimal(math.ceil(float(billable_storage)))
billable_bandwidth_rounded = Decimal(math.ceil(float(billable_bandwidth))) billable_bandwidth_rounded = Decimal(math.ceil(float(billable_bandwidth)))
@ -65,9 +73,9 @@ class InvoiceGenerator:
"usage": usage, "usage": usage,
"pricing": { "pricing": {
"storage_per_gb": float(storage_price_per_gb), "storage_per_gb": float(storage_price_per_gb),
"bandwidth_per_gb": float(bandwidth_price_per_gb) "bandwidth_per_gb": float(bandwidth_price_per_gb),
} },
} },
) )
if billable_storage_rounded > 0: if billable_storage_rounded > 0:
@ -78,7 +86,10 @@ class InvoiceGenerator:
unit_price=storage_price_per_gb, unit_price=storage_price_per_gb,
amount=storage_cost, amount=storage_cost,
item_type="storage", item_type="storage",
metadata={"avg_gb": float(storage_gb), "free_gb": float(free_storage_gb)} metadata={
"avg_gb": float(storage_gb),
"free_gb": float(free_storage_gb),
},
) )
if billable_bandwidth_rounded > 0: if billable_bandwidth_rounded > 0:
@ -89,7 +100,10 @@ class InvoiceGenerator:
unit_price=bandwidth_price_per_gb, unit_price=bandwidth_price_per_gb,
amount=bandwidth_cost, amount=bandwidth_cost,
item_type="bandwidth", item_type="bandwidth",
metadata={"total_gb": float(bandwidth_gb), "free_gb": float(free_bandwidth_gb)} metadata={
"total_gb": float(bandwidth_gb),
"free_gb": float(free_bandwidth_gb),
},
) )
if subscription and subscription.stripe_customer_id: if subscription and subscription.stripe_customer_id:
@ -100,7 +114,7 @@ class InvoiceGenerator:
"amount": item.amount, "amount": item.amount,
"currency": "usd", "currency": "usd",
"description": item.description, "description": item.description,
"metadata": item.metadata or {} "metadata": item.metadata or {},
} }
for item in line_items for item in line_items
] ]
@ -109,7 +123,7 @@ class InvoiceGenerator:
customer_id=subscription.stripe_customer_id, customer_id=subscription.stripe_customer_id,
description=f"MyWebdav Usage Invoice for {period_start.strftime('%B %Y')}", description=f"MyWebdav Usage Invoice for {period_start.strftime('%B %Y')}",
line_items=stripe_line_items, line_items=stripe_line_items,
metadata={"mywebdav_invoice_id": str(invoice.id)} metadata={"mywebdav_invoice_id": str(invoice.id)},
) )
invoice.stripe_invoice_id = stripe_invoice.id invoice.stripe_invoice_id = stripe_invoice.id
@ -135,8 +149,11 @@ class InvoiceGenerator:
# Send invoice email # Send invoice email
from ..mail import queue_email from ..mail import queue_email
line_items = await invoice.line_items.all() line_items = await invoice.line_items.all()
items_text = "\n".join([f"- {item.description}: ${item.amount}" for item in line_items]) items_text = "\n".join(
[f"- {item.description}: ${item.amount}" for item in line_items]
)
body = f"""Dear {invoice.user.username}, body = f"""Dear {invoice.user.username},
Your invoice {invoice.invoice_number} for the period {invoice.period_start} to {invoice.period_end} is now available. Your invoice {invoice.invoice_number} for the period {invoice.period_start} to {invoice.period_end} is now available.
@ -174,7 +191,7 @@ The MyWebdav Team
to_email=invoice.user.email, to_email=invoice.user.email,
subject=f"Your MyWebdav Invoice {invoice.invoice_number}", subject=f"Your MyWebdav Invoice {invoice.invoice_number}",
body=body, body=body,
html=html html=html,
) )
return invoice return invoice

View File

@ -1,8 +1,8 @@
from tortoise import fields, models from tortoise import fields, models
from decimal import Decimal
class SubscriptionPlan(models.Model): class SubscriptionPlan(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=100, unique=True) name = fields.CharField(max_length=100, unique=True)
display_name = fields.CharField(max_length=255) display_name = fields.CharField(max_length=255)
description = fields.TextField(null=True) description = fields.TextField(null=True)
@ -18,10 +18,13 @@ class SubscriptionPlan(models.Model):
class Meta: class Meta:
table = "subscription_plans" table = "subscription_plans"
class UserSubscription(models.Model): class UserSubscription(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="subscription") user = fields.ForeignKeyField("models.User", related_name="subscription")
plan = fields.ForeignKeyField("billing.SubscriptionPlan", related_name="subscriptions", null=True) plan = fields.ForeignKeyField(
"billing.SubscriptionPlan", related_name="subscriptions", null=True
)
billing_type = fields.CharField(max_length=20, default="pay_as_you_go") billing_type = fields.CharField(max_length=20, default="pay_as_you_go")
stripe_customer_id = fields.CharField(max_length=255, unique=True, null=True) stripe_customer_id = fields.CharField(max_length=255, unique=True, null=True)
stripe_subscription_id = fields.CharField(max_length=255, unique=True, null=True) stripe_subscription_id = fields.CharField(max_length=255, unique=True, null=True)
@ -35,14 +38,15 @@ class UserSubscription(models.Model):
class Meta: class Meta:
table = "user_subscriptions" table = "user_subscriptions"
class UsageRecord(models.Model): class UsageRecord(models.Model):
id = fields.BigIntField(pk=True) id = fields.BigIntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="usage_records") user = fields.ForeignKeyField("models.User", related_name="usage_records")
record_type = fields.CharField(max_length=50, index=True) record_type = fields.CharField(max_length=50, db_index=True)
amount_bytes = fields.BigIntField() amount_bytes = fields.BigIntField()
resource_type = fields.CharField(max_length=50, null=True) resource_type = fields.CharField(max_length=50, null=True)
resource_id = fields.IntField(null=True) resource_id = fields.IntField(null=True)
timestamp = fields.DatetimeField(auto_now_add=True, index=True) timestamp = fields.DatetimeField(auto_now_add=True, db_index=True)
idempotency_key = fields.CharField(max_length=255, unique=True, null=True) idempotency_key = fields.CharField(max_length=255, unique=True, null=True)
metadata = fields.JSONField(null=True) metadata = fields.JSONField(null=True)
@ -50,8 +54,9 @@ class UsageRecord(models.Model):
table = "usage_records" table = "usage_records"
indexes = [("user_id", "record_type", "timestamp")] indexes = [("user_id", "record_type", "timestamp")]
class UsageAggregate(models.Model): class UsageAggregate(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="usage_aggregates") user = fields.ForeignKeyField("models.User", related_name="usage_aggregates")
date = fields.DateField() date = fields.DateField()
storage_bytes_avg = fields.BigIntField(default=0) storage_bytes_avg = fields.BigIntField(default=0)
@ -64,8 +69,9 @@ class UsageAggregate(models.Model):
table = "usage_aggregates" table = "usage_aggregates"
unique_together = (("user", "date"),) unique_together = (("user", "date"),)
class Invoice(models.Model): class Invoice(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="invoices") user = fields.ForeignKeyField("models.User", related_name="invoices")
invoice_number = fields.CharField(max_length=50, unique=True) invoice_number = fields.CharField(max_length=50, unique=True)
stripe_invoice_id = fields.CharField(max_length=255, unique=True, null=True) stripe_invoice_id = fields.CharField(max_length=255, unique=True, null=True)
@ -75,10 +81,10 @@ class Invoice(models.Model):
tax = fields.DecimalField(max_digits=10, decimal_places=4, default=0) tax = fields.DecimalField(max_digits=10, decimal_places=4, default=0)
total = fields.DecimalField(max_digits=10, decimal_places=4) total = fields.DecimalField(max_digits=10, decimal_places=4)
currency = fields.CharField(max_length=3, default="USD") currency = fields.CharField(max_length=3, default="USD")
status = fields.CharField(max_length=50, default="draft", index=True) status = fields.CharField(max_length=50, default="draft", db_index=True)
due_date = fields.DateField(null=True) due_date = fields.DateField(null=True)
paid_at = fields.DatetimeField(null=True) paid_at = fields.DatetimeField(null=True)
created_at = fields.DatetimeField(auto_now_add=True, index=True) created_at = fields.DatetimeField(auto_now_add=True, db_index=True)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
metadata = fields.JSONField(null=True) metadata = fields.JSONField(null=True)
@ -86,8 +92,9 @@ class Invoice(models.Model):
table = "invoices" table = "invoices"
indexes = [("user_id", "status", "created_at")] indexes = [("user_id", "status", "created_at")]
class InvoiceLineItem(models.Model): class InvoiceLineItem(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
invoice = fields.ForeignKeyField("billing.Invoice", related_name="line_items") invoice = fields.ForeignKeyField("billing.Invoice", related_name="line_items")
description = fields.TextField() description = fields.TextField()
quantity = fields.DecimalField(max_digits=15, decimal_places=6) quantity = fields.DecimalField(max_digits=15, decimal_places=6)
@ -100,20 +107,24 @@ class InvoiceLineItem(models.Model):
class Meta: class Meta:
table = "invoice_line_items" table = "invoice_line_items"
class PricingConfig(models.Model): class PricingConfig(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
config_key = fields.CharField(max_length=100, unique=True) config_key = fields.CharField(max_length=100, unique=True)
config_value = fields.DecimalField(max_digits=10, decimal_places=6) config_value = fields.DecimalField(max_digits=10, decimal_places=6)
description = fields.TextField(null=True) description = fields.TextField(null=True)
unit = fields.CharField(max_length=50, null=True) unit = fields.CharField(max_length=50, null=True)
updated_by = fields.ForeignKeyField("models.User", related_name="pricing_updates", null=True) updated_by = fields.ForeignKeyField(
"models.User", related_name="pricing_updates", null=True
)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
class Meta: class Meta:
table = "pricing_config" table = "pricing_config"
class PaymentMethod(models.Model): class PaymentMethod(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="payment_methods") user = fields.ForeignKeyField("models.User", related_name="payment_methods")
stripe_payment_method_id = fields.CharField(max_length=255) stripe_payment_method_id = fields.CharField(max_length=255)
type = fields.CharField(max_length=50) type = fields.CharField(max_length=50)
@ -128,9 +139,12 @@ class PaymentMethod(models.Model):
class Meta: class Meta:
table = "payment_methods" table = "payment_methods"
class BillingEvent(models.Model): class BillingEvent(models.Model):
id = fields.BigIntField(pk=True) id = fields.BigIntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="billing_events", null=True) user = fields.ForeignKeyField(
"models.User", related_name="billing_events", null=True
)
event_type = fields.CharField(max_length=100) event_type = fields.CharField(max_length=100)
stripe_event_id = fields.CharField(max_length=255, unique=True, null=True) stripe_event_id = fields.CharField(max_length=255, unique=True, null=True)
data = fields.JSONField(null=True) data = fields.JSONField(null=True)

View File

@ -1,7 +1,6 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.cron import CronTrigger
from datetime import datetime, date, timedelta from datetime import datetime, date, timedelta
import asyncio
from .usage_tracker import UsageTracker from .usage_tracker import UsageTracker
from .invoice_generator import InvoiceGenerator from .invoice_generator import InvoiceGenerator
@ -9,6 +8,7 @@ from ..models import User
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
async def aggregate_daily_usage_for_all_users(): async def aggregate_daily_usage_for_all_users():
users = await User.filter(is_active=True).all() users = await User.filter(is_active=True).all()
yesterday = date.today() - timedelta(days=1) yesterday = date.today() - timedelta(days=1)
@ -19,6 +19,7 @@ async def aggregate_daily_usage_for_all_users():
except Exception as e: except Exception as e:
print(f"Failed to aggregate usage for user {user.id}: {e}") print(f"Failed to aggregate usage for user {user.id}: {e}")
async def generate_monthly_invoices(): async def generate_monthly_invoices():
now = datetime.now() now = datetime.now()
last_month = now.month - 1 if now.month > 1 else 12 last_month = now.month - 1 if now.month > 1 else 12
@ -28,19 +29,22 @@ async def generate_monthly_invoices():
for user in users: for user in users:
try: try:
invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, last_month) invoice = await InvoiceGenerator.generate_monthly_invoice(
user, year, last_month
)
if invoice: if invoice:
await InvoiceGenerator.finalize_invoice(invoice) await InvoiceGenerator.finalize_invoice(invoice)
except Exception as e: except Exception as e:
print(f"Failed to generate invoice for user {user.id}: {e}") print(f"Failed to generate invoice for user {user.id}: {e}")
def start_scheduler(): def start_scheduler():
scheduler.add_job( scheduler.add_job(
aggregate_daily_usage_for_all_users, aggregate_daily_usage_for_all_users,
CronTrigger(hour=1, minute=0), CronTrigger(hour=1, minute=0),
id="aggregate_daily_usage", id="aggregate_daily_usage",
name="Aggregate daily usage for all users", name="Aggregate daily usage for all users",
replace_existing=True replace_existing=True,
) )
scheduler.add_job( scheduler.add_job(
@ -48,10 +52,11 @@ def start_scheduler():
CronTrigger(day=1, hour=2, minute=0), CronTrigger(day=1, hour=2, minute=0),
id="generate_monthly_invoices", id="generate_monthly_invoices",
name="Generate monthly invoices", name="Generate monthly invoices",
replace_existing=True replace_existing=True,
) )
scheduler.start() scheduler.start()
def stop_scheduler(): def stop_scheduler():
scheduler.shutdown() scheduler.shutdown()

View File

@ -1,8 +1,8 @@
import stripe import stripe
from decimal import Decimal from typing import Dict
from typing import Optional, Dict, Any
from ..settings import settings from ..settings import settings
class StripeClient: class StripeClient:
@staticmethod @staticmethod
def _ensure_api_key(): def _ensure_api_key():
@ -11,13 +11,12 @@ class StripeClient:
stripe.api_key = settings.STRIPE_SECRET_KEY stripe.api_key = settings.STRIPE_SECRET_KEY
else: else:
raise ValueError("Stripe API key not configured") raise ValueError("Stripe API key not configured")
@staticmethod @staticmethod
async def create_customer(email: str, name: str, metadata: Dict = None) -> str: async def create_customer(email: str, name: str, metadata: Dict = None) -> str:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
customer = stripe.Customer.create( customer = stripe.Customer.create(
email=email, email=email, name=name, metadata=metadata or {}
name=name,
metadata=metadata or {}
) )
return customer.id return customer.id
@ -26,7 +25,7 @@ class StripeClient:
amount: int, amount: int,
currency: str = "usd", currency: str = "usd",
customer_id: str = None, customer_id: str = None,
metadata: Dict = None metadata: Dict = None,
) -> stripe.PaymentIntent: ) -> stripe.PaymentIntent:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
return stripe.PaymentIntent.create( return stripe.PaymentIntent.create(
@ -34,32 +33,29 @@ class StripeClient:
currency=currency, currency=currency,
customer=customer_id, customer=customer_id,
metadata=metadata or {}, metadata=metadata or {},
automatic_payment_methods={"enabled": True} automatic_payment_methods={"enabled": True},
) )
@staticmethod @staticmethod
async def create_invoice( async def create_invoice(
customer_id: str, customer_id: str, description: str, line_items: list, metadata: Dict = None
description: str,
line_items: list,
metadata: Dict = None
) -> stripe.Invoice: ) -> stripe.Invoice:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
for item in line_items: for item in line_items:
stripe.InvoiceItem.create( stripe.InvoiceItem.create(
customer=customer_id, customer=customer_id,
amount=int(item['amount'] * 100), amount=int(item["amount"] * 100),
currency=item.get('currency', 'usd'), currency=item.get("currency", "usd"),
description=item['description'], description=item["description"],
metadata=item.get('metadata', {}) metadata=item.get("metadata", {}),
) )
invoice = stripe.Invoice.create( invoice = stripe.Invoice.create(
customer=customer_id, customer=customer_id,
description=description, description=description,
auto_advance=True, auto_advance=True,
collection_method='charge_automatically', collection_method="charge_automatically",
metadata=metadata or {} metadata=metadata or {},
) )
return invoice return invoice
@ -76,18 +72,15 @@ class StripeClient:
@staticmethod @staticmethod
async def attach_payment_method( async def attach_payment_method(
payment_method_id: str, payment_method_id: str, customer_id: str
customer_id: str
) -> stripe.PaymentMethod: ) -> stripe.PaymentMethod:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
payment_method = stripe.PaymentMethod.attach( payment_method = stripe.PaymentMethod.attach(
payment_method_id, payment_method_id, customer=customer_id
customer=customer_id
) )
stripe.Customer.modify( stripe.Customer.modify(
customer_id, customer_id, invoice_settings={"default_payment_method": payment_method_id}
invoice_settings={'default_payment_method': payment_method_id}
) )
return payment_method return payment_method
@ -95,22 +88,15 @@ class StripeClient:
@staticmethod @staticmethod
async def list_payment_methods(customer_id: str, type: str = "card"): async def list_payment_methods(customer_id: str, type: str = "card"):
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
return stripe.PaymentMethod.list( return stripe.PaymentMethod.list(customer=customer_id, type=type)
customer=customer_id,
type=type
)
@staticmethod @staticmethod
async def create_subscription( async def create_subscription(
customer_id: str, customer_id: str, price_id: str, metadata: Dict = None
price_id: str,
metadata: Dict = None
) -> stripe.Subscription: ) -> stripe.Subscription:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
return stripe.Subscription.create( return stripe.Subscription.create(
customer=customer_id, customer=customer_id, items=[{"price": price_id}], metadata=metadata or {}
items=[{'price': price_id}],
metadata=metadata or {}
) )
@staticmethod @staticmethod

View File

@ -1,11 +1,10 @@
import uuid import uuid
from datetime import datetime, date, timezone from datetime import datetime, date, timezone
from decimal import Decimal
from typing import Optional
from tortoise.transactions import in_transaction from tortoise.transactions import in_transaction
from .models import UsageRecord, UsageAggregate from .models import UsageRecord, UsageAggregate
from ..models import User from ..models import User
class UsageTracker: class UsageTracker:
@staticmethod @staticmethod
async def track_storage( async def track_storage(
@ -13,7 +12,7 @@ class UsageTracker:
amount_bytes: int, amount_bytes: int,
resource_type: str = None, resource_type: str = None,
resource_id: int = None, resource_id: int = None,
metadata: dict = None metadata: dict = None,
): ):
idempotency_key = f"storage_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}" idempotency_key = f"storage_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}"
@ -24,7 +23,7 @@ class UsageTracker:
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,
idempotency_key=idempotency_key, idempotency_key=idempotency_key,
metadata=metadata metadata=metadata,
) )
@staticmethod @staticmethod
@ -34,7 +33,7 @@ class UsageTracker:
direction: str = "down", direction: str = "down",
resource_type: str = None, resource_type: str = None,
resource_id: int = None, resource_id: int = None,
metadata: dict = None metadata: dict = None,
): ):
record_type = f"bandwidth_{direction}" record_type = f"bandwidth_{direction}"
idempotency_key = f"{record_type}_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}" idempotency_key = f"{record_type}_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}"
@ -46,7 +45,7 @@ class UsageTracker:
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,
idempotency_key=idempotency_key, idempotency_key=idempotency_key,
metadata=metadata metadata=metadata,
) )
@staticmethod @staticmethod
@ -61,24 +60,26 @@ class UsageTracker:
user=user, user=user,
record_type="storage", record_type="storage",
timestamp__gte=start_of_day, timestamp__gte=start_of_day,
timestamp__lte=end_of_day timestamp__lte=end_of_day,
).all() ).all()
storage_avg = sum(r.amount_bytes for r in storage_records) // max(len(storage_records), 1) storage_avg = sum(r.amount_bytes for r in storage_records) // max(
len(storage_records), 1
)
storage_peak = max((r.amount_bytes for r in storage_records), default=0) storage_peak = max((r.amount_bytes for r in storage_records), default=0)
bandwidth_up = await UsageRecord.filter( bandwidth_up = await UsageRecord.filter(
user=user, user=user,
record_type="bandwidth_up", record_type="bandwidth_up",
timestamp__gte=start_of_day, timestamp__gte=start_of_day,
timestamp__lte=end_of_day timestamp__lte=end_of_day,
).all() ).all()
bandwidth_down = await UsageRecord.filter( bandwidth_down = await UsageRecord.filter(
user=user, user=user,
record_type="bandwidth_down", record_type="bandwidth_down",
timestamp__gte=start_of_day, timestamp__gte=start_of_day,
timestamp__lte=end_of_day timestamp__lte=end_of_day,
).all() ).all()
total_up = sum(r.amount_bytes for r in bandwidth_up) total_up = sum(r.amount_bytes for r in bandwidth_up)
@ -92,8 +93,8 @@ class UsageTracker:
"storage_bytes_avg": storage_avg, "storage_bytes_avg": storage_avg,
"storage_bytes_peak": storage_peak, "storage_bytes_peak": storage_peak,
"bandwidth_up_bytes": total_up, "bandwidth_up_bytes": total_up,
"bandwidth_down_bytes": total_down "bandwidth_down_bytes": total_down,
} },
) )
if not created: if not created:
@ -108,6 +109,7 @@ class UsageTracker:
@staticmethod @staticmethod
async def get_current_storage(user: User) -> int: async def get_current_storage(user: User) -> int:
from ..models import File from ..models import File
files = await File.filter(owner=user, is_deleted=False).all() files = await File.filter(owner=user, is_deleted=False).all()
return sum(f.size for f in files) return sum(f.size for f in files)
@ -121,9 +123,7 @@ class UsageTracker:
end_date = date(year, month, last_day) end_date = date(year, month, last_day)
aggregates = await UsageAggregate.filter( aggregates = await UsageAggregate.filter(
user=user, user=user, date__gte=start_date, date__lte=end_date
date__gte=start_date,
date__lte=end_date
).all() ).all()
if not aggregates: if not aggregates:
@ -132,7 +132,7 @@ class UsageTracker:
"storage_gb_peak": 0, "storage_gb_peak": 0,
"bandwidth_up_gb": 0, "bandwidth_up_gb": 0,
"bandwidth_down_gb": 0, "bandwidth_down_gb": 0,
"total_bandwidth_gb": 0 "total_bandwidth_gb": 0,
} }
storage_avg = sum(a.storage_bytes_avg for a in aggregates) / len(aggregates) storage_avg = sum(a.storage_bytes_avg for a in aggregates) / len(aggregates)
@ -145,5 +145,5 @@ class UsageTracker:
"storage_gb_peak": round(storage_peak / (1024**3), 4), "storage_gb_peak": round(storage_peak / (1024**3), 4),
"bandwidth_up_gb": round(bandwidth_up / (1024**3), 4), "bandwidth_up_gb": round(bandwidth_up / (1024**3), 4),
"bandwidth_down_gb": round(bandwidth_down / (1024**3), 4), "bandwidth_down_gb": round(bandwidth_down / (1024**3), 4),
"total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4) "total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4),
} }

View File

@ -7,7 +7,7 @@ for various legal policies required for a European cloud storage provider.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import Dict, Any from typing import Dict
class LegalDocument(ABC): class LegalDocument(ABC):
@ -17,10 +17,13 @@ class LegalDocument(ABC):
Provides common structure and methods for generating legal content. Provides common structure and methods for generating legal content.
""" """
def __init__(self, company_name: str = "MyWebdav Technologies", def __init__(
last_updated: str = None, self,
contact_email: str = "legal@mywebdav.eu", company_name: str = "MyWebdav Technologies",
website: str = "https://mywebdav.eu"): last_updated: str = None,
contact_email: str = "legal@mywebdav.eu",
website: str = "https://mywebdav.eu",
):
self.company_name = company_name self.company_name = company_name
self.last_updated = last_updated or datetime.now().strftime("%B %d, %Y") self.last_updated = last_updated or datetime.now().strftime("%B %d, %Y")
self.contact_email = contact_email self.contact_email = contact_email
@ -853,20 +856,21 @@ class ContactComplaintMechanism(LegalDocument):
def get_all_legal_documents() -> Dict[str, LegalDocument]: def get_all_legal_documents() -> Dict[str, LegalDocument]:
"""Return a dictionary of all legal document instances.""" """Return a dictionary of all legal document instances."""
return { return {
'privacy_policy': PrivacyPolicy(), "privacy_policy": PrivacyPolicy(),
'terms_of_service': TermsOfService(), "terms_of_service": TermsOfService(),
'security_policy': SecurityPolicy(), "security_policy": SecurityPolicy(),
'cookie_policy': CookiePolicy(), "cookie_policy": CookiePolicy(),
'data_processing_agreement': DataProcessingAgreement(), "data_processing_agreement": DataProcessingAgreement(),
'compliance_statement': ComplianceStatement(), "compliance_statement": ComplianceStatement(),
'data_portability_deletion_policy': DataPortabilityDeletionPolicy(), "data_portability_deletion_policy": DataPortabilityDeletionPolicy(),
'contact_complaint_mechanism': ContactComplaintMechanism(), "contact_complaint_mechanism": ContactComplaintMechanism(),
} }
def generate_legal_documents(output_dir: str = "static/legal"): def generate_legal_documents(output_dir: str = "static/legal"):
"""Generate all legal documents as Markdown and HTML files.""" """Generate all legal documents as Markdown and HTML files."""
import os import os
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
documents = get_all_legal_documents() documents = get_all_legal_documents()
@ -875,13 +879,13 @@ def generate_legal_documents(output_dir: str = "static/legal"):
# Generate Markdown # Generate Markdown
md_filename = f"{doc_name}.md" md_filename = f"{doc_name}.md"
md_path = os.path.join(output_dir, md_filename) md_path = os.path.join(output_dir, md_filename)
with open(md_path, 'w') as f: with open(md_path, "w") as f:
f.write(doc.to_markdown()) f.write(doc.to_markdown())
# Generate HTML # Generate HTML
html_filename = f"{doc_name}.html" html_filename = f"{doc_name}.html"
html_path = os.path.join(output_dir, html_filename) html_path = os.path.join(output_dir, html_filename)
with open(html_path, 'w') as f: with open(html_path, "w") as f:
f.write(doc.to_html()) f.write(doc.to_html())
print(f"Generated {md_filename} and {html_filename}") print(f"Generated {md_filename} and {html_filename}")

View File

@ -2,17 +2,26 @@ import asyncio
import aiosmtplib import aiosmtplib
from email.mime.text import MIMEText from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from typing import Optional, Dict, Any from typing import Optional
from .settings import settings from .settings import settings
class EmailTask: class EmailTask:
def __init__(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs): def __init__(
self,
to_email: str,
subject: str,
body: str,
html: Optional[str] = None,
**kwargs,
):
self.to_email = to_email self.to_email = to_email
self.subject = subject self.subject = subject
self.body = body self.body = body
self.html = html self.html = html
self.kwargs = kwargs self.kwargs = kwargs
class EmailService: class EmailService:
def __init__(self): def __init__(self):
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
@ -38,7 +47,14 @@ class EmailService:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
async def send_email(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs): async def send_email(
self,
to_email: str,
subject: str,
body: str,
html: Optional[str] = None,
**kwargs,
):
"""Queue an email for sending""" """Queue an email for sending"""
task = EmailTask(to_email, subject, body, html, **kwargs) task = EmailTask(to_email, subject, body, html, **kwargs)
await self.queue.put(task) await self.queue.put(task)
@ -60,22 +76,26 @@ class EmailService:
async def _send_email_task(self, task: EmailTask): async def _send_email_task(self, task: EmailTask):
"""Send a single email task""" """Send a single email task"""
if not settings.SMTP_HOST or not settings.SMTP_USERNAME or not settings.SMTP_PASSWORD: if (
not settings.SMTP_HOST
or not settings.SMTP_USERNAME
or not settings.SMTP_PASSWORD
):
print("SMTP not configured, skipping email send") print("SMTP not configured, skipping email send")
return return
msg = MIMEMultipart('alternative') msg = MIMEMultipart("alternative")
msg['From'] = settings.SMTP_SENDER_EMAIL msg["From"] = settings.SMTP_SENDER_EMAIL
msg['To'] = task.to_email msg["To"] = task.to_email
msg['Subject'] = task.subject msg["Subject"] = task.subject
# Add text part # Add text part
text_part = MIMEText(task.body, 'plain') text_part = MIMEText(task.body, "plain")
msg.attach(text_part) msg.attach(text_part)
# Add HTML part if provided # Add HTML part if provided
if task.html: if task.html:
html_part = MIMEText(task.html, 'html') html_part = MIMEText(task.html, "html")
msg.attach(html_part) msg.attach(html_part)
try: try:
@ -86,7 +106,7 @@ class EmailService:
port=settings.SMTP_PORT, port=settings.SMTP_PORT,
username=settings.SMTP_USERNAME, username=settings.SMTP_USERNAME,
password=settings.SMTP_PASSWORD, password=settings.SMTP_PASSWORD,
use_tls=True use_tls=True,
) as smtp: ) as smtp:
await smtp.send_message(msg) await smtp.send_message(msg)
print(f"Email sent to {task.to_email}") print(f"Email sent to {task.to_email}")
@ -99,7 +119,10 @@ class EmailService:
try: try:
await smtp.starttls() await smtp.starttls()
except Exception as tls_error: except Exception as tls_error:
if "already using" in str(tls_error).lower() or "tls" in str(tls_error).lower(): if (
"already using" in str(tls_error).lower()
or "tls" in str(tls_error).lower()
):
# Connection is already using TLS, proceed without starttls # Connection is already using TLS, proceed without starttls
pass pass
else: else:
@ -111,14 +134,23 @@ class EmailService:
print(f"Failed to send email to {task.to_email}: {e}") print(f"Failed to send email to {task.to_email}: {e}")
raise # Re-raise to let caller handle raise # Re-raise to let caller handle
# Global email service instance # Global email service instance
email_service = EmailService() email_service = EmailService()
# Convenience functions # Convenience functions
async def send_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs): async def send_email(
to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
):
"""Send an email asynchronously""" """Send an email asynchronously"""
await email_service.send_email(to_email, subject, body, html, **kwargs) await email_service.send_email(to_email, subject, body, html, **kwargs)
def queue_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
def queue_email(
to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
):
"""Queue an email for sending (fire and forget)""" """Queue an email for sending (fire and forget)"""
asyncio.create_task(email_service.send_email(to_email, subject, body, html, **kwargs)) asyncio.create_task(
email_service.send_email(to_email, subject, body, html, **kwargs)
)

View File

@ -2,18 +2,33 @@ import argparse
import uvicorn import uvicorn
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status, HTTPException from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse
from tortoise.contrib.fastapi import register_tortoise from tortoise.contrib.fastapi import register_tortoise
from .settings import settings from .settings import settings
from .routers import auth, users, folders, files, shares, search, admin, starred, billing, admin_billing from .routers import (
auth,
users,
folders,
files,
shares,
search,
admin,
starred,
billing,
admin_billing,
)
from . import webdav from . import webdav
from .schemas import ErrorResponse from .schemas import ErrorResponse
from .middleware.usage_tracking import UsageTrackingMiddleware
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
logger.info("Starting up...") logger.info("Starting up...")
@ -21,6 +36,7 @@ async def lifespan(app: FastAPI):
from .billing.scheduler import start_scheduler from .billing.scheduler import start_scheduler
from .billing.models import PricingConfig from .billing.models import PricingConfig
from .mail import email_service from .mail import email_service
start_scheduler() start_scheduler()
logger.info("Billing scheduler started") logger.info("Billing scheduler started")
await email_service.start() await email_service.start()
@ -28,28 +44,61 @@ async def lifespan(app: FastAPI):
pricing_count = await PricingConfig.all().count() pricing_count = await PricingConfig.all().count()
if pricing_count == 0: if pricing_count == 0:
from decimal import Decimal from decimal import Decimal
await PricingConfig.create(config_key='storage_per_gb_month', config_value=Decimal('0.0045'), description='Storage cost per GB per month', unit='per_gb_month')
await PricingConfig.create(config_key='bandwidth_egress_per_gb', config_value=Decimal('0.009'), description='Bandwidth egress cost per GB', unit='per_gb') await PricingConfig.create(
await PricingConfig.create(config_key='bandwidth_ingress_per_gb', config_value=Decimal('0.0'), description='Bandwidth ingress cost per GB (free)', unit='per_gb') config_key="storage_per_gb_month",
await PricingConfig.create(config_key='free_tier_storage_gb', config_value=Decimal('15'), description='Free tier storage in GB', unit='gb') config_value=Decimal("0.0045"),
await PricingConfig.create(config_key='free_tier_bandwidth_gb', config_value=Decimal('15'), description='Free tier bandwidth in GB per month', unit='gb') description="Storage cost per GB per month",
await PricingConfig.create(config_key='tax_rate_default', config_value=Decimal('0.0'), description='Default tax rate (0 = no tax)', unit='percentage') unit="per_gb_month",
)
await PricingConfig.create(
config_key="bandwidth_egress_per_gb",
config_value=Decimal("0.009"),
description="Bandwidth egress cost per GB",
unit="per_gb",
)
await PricingConfig.create(
config_key="bandwidth_ingress_per_gb",
config_value=Decimal("0.0"),
description="Bandwidth ingress cost per GB (free)",
unit="per_gb",
)
await PricingConfig.create(
config_key="free_tier_storage_gb",
config_value=Decimal("15"),
description="Free tier storage in GB",
unit="gb",
)
await PricingConfig.create(
config_key="free_tier_bandwidth_gb",
config_value=Decimal("15"),
description="Free tier bandwidth in GB per month",
unit="gb",
)
await PricingConfig.create(
config_key="tax_rate_default",
config_value=Decimal("0.0"),
description="Default tax rate (0 = no tax)",
unit="percentage",
)
logger.info("Default pricing configuration initialized") logger.info("Default pricing configuration initialized")
yield yield
from .billing.scheduler import stop_scheduler from .billing.scheduler import stop_scheduler
stop_scheduler() stop_scheduler()
logger.info("Billing scheduler stopped") logger.info("Billing scheduler stopped")
await email_service.stop() await email_service.stop()
logger.info("Email service stopped") logger.info("Email service stopped")
print("Shutting down...") print("Shutting down...")
app = FastAPI( app = FastAPI(
title="MyWebdav Cloud Storage", title="MyWebdav Cloud Storage",
description="A commercial cloud storage web application", description="A commercial cloud storage web application",
version="0.1.0", version="0.1.0",
lifespan=lifespan lifespan=lifespan,
) )
app.include_router(auth.router) app.include_router(auth.router)
@ -64,8 +113,6 @@ app.include_router(billing.router)
app.include_router(admin_billing.router) app.include_router(admin_billing.router)
app.include_router(webdav.router) app.include_router(webdav.router)
from .middleware.usage_tracking import UsageTrackingMiddleware
app.add_middleware(UsageTrackingMiddleware) app.add_middleware(UsageTrackingMiddleware)
app.mount("/static", StaticFiles(directory="static"), name="static") app.mount("/static", StaticFiles(directory="static"), name="static")
@ -73,35 +120,39 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
register_tortoise( register_tortoise(
app, app,
db_url=settings.DATABASE_URL, db_url=settings.DATABASE_URL,
modules={ modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
"models": ["mywebdav.models"],
"billing": ["mywebdav.billing.models"]
},
generate_schemas=True, generate_schemas=True,
add_exception_handlers=True, add_exception_handlers=True,
) )
@app.exception_handler(HTTPException) @app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException): async def http_exception_handler(request: Request, exc: HTTPException):
logger.error(f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}") logger.error(
f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}"
)
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content=ErrorResponse(code=exc.status_code, message=exc.detail).model_dump(), content=ErrorResponse(code=exc.status_code, message=exc.detail).model_dump(),
) )
@app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse @app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse
async def read_root(): async def read_root():
with open("static/index.html", "r") as f: with open("static/index.html", "r") as f:
return f.read() return f.read()
def main(): def main():
parser = argparse.ArgumentParser(description="Run the MyWebdav application.") parser = argparse.ArgumentParser(description="Run the MyWebdav application.")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host address to bind to") parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Host address to bind to"
)
parser.add_argument("--port", type=int, default=8000, help="Port to listen on") parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
args = parser.parse_args() args = parser.parse_args()
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -2,31 +2,35 @@ from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from ..billing.usage_tracker import UsageTracker from ..billing.usage_tracker import UsageTracker
class UsageTrackingMiddleware(BaseHTTPMiddleware): class UsageTrackingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
response = await call_next(request) response = await call_next(request)
if hasattr(request.state, 'user') and request.state.user: if hasattr(request.state, "user") and request.state.user:
user = request.state.user user = request.state.user
if request.method in ['POST', 'PUT'] and '/files/upload' in request.url.path: if (
content_length = response.headers.get('content-length') request.method in ["POST", "PUT"]
and "/files/upload" in request.url.path
):
content_length = response.headers.get("content-length")
if content_length: if content_length:
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
user=user, user=user,
amount_bytes=int(content_length), amount_bytes=int(content_length),
direction='up', direction="up",
metadata={'path': request.url.path} metadata={"path": request.url.path},
) )
elif request.method == 'GET' and '/files/download' in request.url.path: elif request.method == "GET" and "/files/download" in request.url.path:
content_length = response.headers.get('content-length') content_length = response.headers.get("content-length")
if content_length: if content_length:
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
user=user, user=user,
amount_bytes=int(content_length), amount_bytes=int(content_length),
direction='down', direction="down",
metadata={'path': request.url.path} metadata={"path": request.url.path},
) )
return response return response

View File

@ -1,9 +1,9 @@
from tortoise import fields, models from tortoise import fields, models
from tortoise.contrib.pydantic import pydantic_model_creator from tortoise.contrib.pydantic import pydantic_model_creator
from datetime import datetime
class User(models.Model): class User(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
username = fields.CharField(max_length=20, unique=True) username = fields.CharField(max_length=20, unique=True)
email = fields.CharField(max_length=255, unique=True) email = fields.CharField(max_length=255, unique=True)
hashed_password = fields.CharField(max_length=255) hashed_password = fields.CharField(max_length=255)
@ -11,7 +11,9 @@ class User(models.Model):
is_superuser = fields.BooleanField(default=False) is_superuser = fields.BooleanField(default=False)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
storage_quota_bytes = fields.BigIntField(default=10 * 1024 * 1024 * 1024) # 10 GB default storage_quota_bytes = fields.BigIntField(
default=10 * 1024 * 1024 * 1024
) # 10 GB default
used_storage_bytes = fields.BigIntField(default=0) used_storage_bytes = fields.BigIntField(default=0)
plan_type = fields.CharField(max_length=50, default="free") plan_type = fields.CharField(max_length=50, default="free")
two_factor_secret = fields.CharField(max_length=255, null=True) two_factor_secret = fields.CharField(max_length=255, null=True)
@ -24,11 +26,16 @@ class User(models.Model):
def __str__(self): def __str__(self):
return self.username return self.username
class Folder(models.Model): class Folder(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=255) name = fields.CharField(max_length=255)
parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField("models.Folder", related_name="children", null=True) parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField(
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="folders") "models.Folder", related_name="children", null=True
)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="folders"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
is_deleted = fields.BooleanField(default=False) is_deleted = fields.BooleanField(default=False)
@ -36,21 +43,28 @@ class Folder(models.Model):
class Meta: class Meta:
table = "folders" table = "folders"
unique_together = (("name", "parent", "owner"),) # Ensure unique folder names within a parent for an owner unique_together = (
("name", "parent", "owner"),
) # Ensure unique folder names within a parent for an owner
def __str__(self): def __str__(self):
return self.name return self.name
class File(models.Model): class File(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=255) name = fields.CharField(max_length=255)
path = fields.CharField(max_length=1024) # Internal storage path path = fields.CharField(max_length=1024) # Internal storage path
size = fields.BigIntField() size = fields.BigIntField()
mime_type = fields.CharField(max_length=255) mime_type = fields.CharField(max_length=255)
file_hash = fields.CharField(max_length=64, null=True) # SHA-256 file_hash = fields.CharField(max_length=64, null=True) # SHA-256
thumbnail_path = fields.CharField(max_length=1024, null=True) thumbnail_path = fields.CharField(max_length=1024, null=True)
parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="files", null=True) parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="files") "models.Folder", related_name="files", null=True
)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="files"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
is_deleted = fields.BooleanField(default=False) is_deleted = fields.BooleanField(default=False)
@ -60,14 +74,19 @@ class File(models.Model):
class Meta: class Meta:
table = "files" table = "files"
unique_together = (("name", "parent", "owner"),) # Ensure unique file names within a parent for an owner unique_together = (
("name", "parent", "owner"),
) # Ensure unique file names within a parent for an owner
def __str__(self): def __str__(self):
return self.name return self.name
class FileVersion(models.Model): class FileVersion(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="versions") file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField(
"models.File", related_name="versions"
)
version_path = fields.CharField(max_length=1024) version_path = fields.CharField(max_length=1024)
size = fields.BigIntField() size = fields.BigIntField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
@ -75,47 +94,69 @@ class FileVersion(models.Model):
class Meta: class Meta:
table = "file_versions" table = "file_versions"
class Share(models.Model): class Share(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
token = fields.CharField(max_length=64, unique=True) token = fields.CharField(max_length=64, unique=True)
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="shares", null=True) file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField(
folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="shares", null=True) "models.File", related_name="shares", null=True
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="shares") )
folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
"models.Folder", related_name="shares", null=True
)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="shares"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
expires_at = fields.DatetimeField(null=True) expires_at = fields.DatetimeField(null=True)
password_protected = fields.BooleanField(default=False) password_protected = fields.BooleanField(default=False)
hashed_password = fields.CharField(max_length=255, null=True) hashed_password = fields.CharField(max_length=255, null=True)
access_count = fields.IntField(default=0) access_count = fields.IntField(default=0)
permission_level = fields.CharField(max_length=50, default="viewer") # viewer, uploader, editor permission_level = fields.CharField(
max_length=50, default="viewer"
) # viewer, uploader, editor
class Meta: class Meta:
table = "shares" table = "shares"
class Team(models.Model): class Team(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=255, unique=True) name = fields.CharField(max_length=255, unique=True)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="owned_teams") owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
members: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User", related_name="teams", through="team_members") "models.User", related_name="owned_teams"
)
members: fields.ManyToManyRelation[User] = fields.ManyToManyField(
"models.User", related_name="teams", through="team_members"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Meta: class Meta:
table = "teams" table = "teams"
class TeamMember(models.Model): class TeamMember(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField("models.Team", related_name="team_members") team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField(
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="user_teams") "models.Team", related_name="team_members"
role = fields.CharField(max_length=50, default="member") # owner, admin, member )
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="user_teams"
)
role = fields.CharField(max_length=50, default="member") # owner, admin, member
class Meta: class Meta:
table = "team_members" table = "team_members"
unique_together = (("team", "user"),) unique_together = (("team", "user"),)
class Activity(models.Model): class Activity(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="activities", null=True) user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="activities", null=True
)
action = fields.CharField(max_length=255) action = fields.CharField(max_length=255)
target_type = fields.CharField(max_length=50) # file, folder, share, user, team target_type = fields.CharField(max_length=50) # file, folder, share, user, team
target_id = fields.IntField() target_id = fields.IntField()
ip_address = fields.CharField(max_length=45, null=True) ip_address = fields.CharField(max_length=45, null=True)
timestamp = fields.DatetimeField(auto_now_add=True) timestamp = fields.DatetimeField(auto_now_add=True)
@ -123,13 +164,18 @@ class Activity(models.Model):
class Meta: class Meta:
table = "activities" table = "activities"
class FileRequest(models.Model): class FileRequest(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
title = fields.CharField(max_length=255) title = fields.CharField(max_length=255)
description = fields.TextField(null=True) description = fields.TextField(null=True)
token = fields.CharField(max_length=64, unique=True) token = fields.CharField(max_length=64, unique=True)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="file_requests") owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="file_requests") "models.User", related_name="file_requests"
)
target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
"models.Folder", related_name="file_requests"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
expires_at = fields.DatetimeField(null=True) expires_at = fields.DatetimeField(null=True)
is_active = fields.BooleanField(default=True) is_active = fields.BooleanField(default=True)
@ -137,8 +183,9 @@ class FileRequest(models.Model):
class Meta: class Meta:
table = "file_requests" table = "file_requests"
class WebDAVProperty(models.Model): class WebDAVProperty(models.Model):
id = fields.IntField(pk=True) id = fields.IntField(primary_key=True)
resource_type = fields.CharField(max_length=10) resource_type = fields.CharField(max_length=10)
resource_id = fields.IntField() resource_id = fields.IntField()
namespace = fields.CharField(max_length=255) namespace = fields.CharField(max_length=255)
@ -151,5 +198,8 @@ class WebDAVProperty(models.Model):
table = "webdav_properties" table = "webdav_properties"
unique_together = (("resource_type", "resource_id", "namespace", "name"),) unique_together = (("resource_type", "resource_id", "namespace", "name"),)
User_Pydantic = pydantic_model_creator(User, name="User_Pydantic") User_Pydantic = pydantic_model_creator(User, name="User_Pydantic")
UserIn_Pydantic = pydantic_model_creator(User, name="UserIn_Pydantic", exclude_readonly=True) UserIn_Pydantic = pydantic_model_creator(
User, name="UserIn_Pydantic", exclude_readonly=True
)

View File

@ -1,4 +1,4 @@
from typing import List, Optional from typing import List
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
@ -13,18 +13,25 @@ router = APIRouter(
responses={403: {"description": "Not enough permissions"}}, responses={403: {"description": "Not enough permissions"}},
) )
@router.get("/users", response_model=List[User_Pydantic]) @router.get("/users", response_model=List[User_Pydantic])
async def get_all_users(): async def get_all_users():
return await User.all() return await User.all()
@router.get("/users/{user_id}", response_model=User_Pydantic) @router.get("/users/{user_id}", response_model=User_Pydantic)
async def get_user(user_id: int): async def get_user(user_id: int):
user = await User.get_or_none(id=user_id) user = await User.get_or_none(id=user_id)
if not user: if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
return user return user
@router.post("/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED)
@router.post(
"/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED
)
async def create_user_by_admin(user_in: UserCreate): async def create_user_by_admin(user_in: UserCreate):
user = await User.get_or_none(username=user_in.username) user = await User.get_or_none(username=user_in.username)
if user: if user:
@ -44,25 +51,33 @@ async def create_user_by_admin(user_in: UserCreate):
username=user_in.username, username=user_in.username,
email=user_in.email, email=user_in.email,
hashed_password=hashed_password, hashed_password=hashed_password,
is_superuser=False, # Admin creates regular users by default is_superuser=False, # Admin creates regular users by default
is_active=True, is_active=True,
) )
return await User_Pydantic.from_tortoise_orm(user) return await User_Pydantic.from_tortoise_orm(user)
@router.put("/users/{user_id}", response_model=User_Pydantic) @router.put("/users/{user_id}", response_model=User_Pydantic)
async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate): async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate):
user = await User.get_or_none(id=user_id) user = await User.get_or_none(id=user_id)
if not user: if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
if user_update.username is not None and user_update.username != user.username: if user_update.username is not None and user_update.username != user.username:
if await User.get_or_none(username=user_update.username): if await User.get_or_none(username=user_update.username):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
)
user.username = user_update.username user.username = user_update.username
if user_update.email is not None and user_update.email != user.email: if user_update.email is not None and user_update.email != user.email:
if await User.get_or_none(email=user_update.email): if await User.get_or_none(email=user_update.email):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered",
)
user.email = user_update.email user.email = user_update.email
if user_update.password is not None: if user_update.password is not None:
@ -89,21 +104,28 @@ async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate):
await user.save() await user.save()
return await User_Pydantic.from_tortoise_orm(user) return await User_Pydantic.from_tortoise_orm(user)
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user_by_admin(user_id: int): async def delete_user_by_admin(user_id: int):
user = await User.get_or_none(id=user_id) user = await User.get_or_none(id=user_id)
if not user: if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
await user.delete() await user.delete()
return {"message": "User deleted successfully"} return {"message": "User deleted successfully"}
@router.post("/test-email") @router.post("/test-email")
async def send_test_email(to_email: str, subject: str = "Test Email", body: str = "This is a test email"): async def send_test_email(
to_email: str, subject: str = "Test Email", body: str = "This is a test email"
):
from ..mail import queue_email from ..mail import queue_email
queue_email( queue_email(
to_email=to_email, to_email=to_email,
subject=subject, subject=subject,
body=body, body=body,
html=f"<h1>{subject}</h1><p>{body}</p>" html=f"<h1>{subject}</h1><p>{body}</p>",
) )
return {"message": "Test email queued"} return {"message": "Test email queued"}

View File

@ -1,5 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from typing import List
from decimal import Decimal from decimal import Decimal
from pydantic import BaseModel from pydantic import BaseModel
@ -8,21 +7,25 @@ from ..models import User
from ..billing.models import PricingConfig, Invoice, SubscriptionPlan from ..billing.models import PricingConfig, Invoice, SubscriptionPlan
from ..billing.invoice_generator import InvoiceGenerator from ..billing.invoice_generator import InvoiceGenerator
def require_superuser(current_user: User = Depends(get_current_user)): def require_superuser(current_user: User = Depends(get_current_user)):
if not current_user.is_superuser: if not current_user.is_superuser:
raise HTTPException(status_code=403, detail="Superuser privileges required") raise HTTPException(status_code=403, detail="Superuser privileges required")
return current_user return current_user
router = APIRouter( router = APIRouter(
prefix="/api/admin/billing", prefix="/api/admin/billing",
tags=["admin", "billing"], tags=["admin", "billing"],
dependencies=[Depends(require_superuser)] dependencies=[Depends(require_superuser)],
) )
class PricingConfigUpdate(BaseModel): class PricingConfigUpdate(BaseModel):
config_key: str config_key: str
config_value: float config_value: float
class PlanCreate(BaseModel): class PlanCreate(BaseModel):
name: str name: str
display_name: str display_name: str
@ -32,6 +35,7 @@ class PlanCreate(BaseModel):
price_monthly: float price_monthly: float
price_yearly: float = None price_yearly: float = None
@router.get("/pricing") @router.get("/pricing")
async def get_all_pricing(current_user: User = Depends(require_superuser)): async def get_all_pricing(current_user: User = Depends(require_superuser)):
configs = await PricingConfig.all() configs = await PricingConfig.all()
@ -42,16 +46,17 @@ async def get_all_pricing(current_user: User = Depends(require_superuser)):
"config_value": float(c.config_value), "config_value": float(c.config_value),
"description": c.description, "description": c.description,
"unit": c.unit, "unit": c.unit,
"updated_at": c.updated_at "updated_at": c.updated_at,
} }
for c in configs for c in configs
] ]
@router.put("/pricing/{config_id}") @router.put("/pricing/{config_id}")
async def update_pricing( async def update_pricing(
config_id: int, config_id: int,
update: PricingConfigUpdate, update: PricingConfigUpdate,
current_user: User = Depends(require_superuser) current_user: User = Depends(require_superuser),
): ):
config = await PricingConfig.get_or_none(id=config_id) config = await PricingConfig.get_or_none(id=config_id)
if not config: if not config:
@ -63,11 +68,10 @@ async def update_pricing(
return {"message": "Pricing updated successfully"} return {"message": "Pricing updated successfully"}
@router.post("/generate-invoices/{year}/{month}") @router.post("/generate-invoices/{year}/{month}")
async def generate_all_invoices( async def generate_all_invoices(
year: int, year: int, month: int, current_user: User = Depends(require_superuser)
month: int,
current_user: User = Depends(require_superuser)
): ):
users = await User.filter(is_active=True).all() users = await User.filter(is_active=True).all()
generated = [] generated = []
@ -76,35 +80,36 @@ async def generate_all_invoices(
for user in users: for user in users:
invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, month) invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, month)
if invoice: if invoice:
generated.append({ generated.append(
"user_id": user.id, {
"invoice_id": invoice.id, "user_id": user.id,
"total": float(invoice.total) "invoice_id": invoice.id,
}) "total": float(invoice.total),
}
)
else: else:
skipped.append(user.id) skipped.append(user.id)
return { return {"generated": len(generated), "skipped": len(skipped), "invoices": generated}
"generated": len(generated),
"skipped": len(skipped),
"invoices": generated
}
@router.post("/plans") @router.post("/plans")
async def create_plan( async def create_plan(
plan_data: PlanCreate, plan_data: PlanCreate, current_user: User = Depends(require_superuser)
current_user: User = Depends(require_superuser)
): ):
plan = await SubscriptionPlan.create(**plan_data.dict()) plan = await SubscriptionPlan.create(**plan_data.dict())
return {"id": plan.id, "message": "Plan created successfully"} return {"id": plan.id, "message": "Plan created successfully"}
@router.get("/stats") @router.get("/stats")
async def get_billing_stats(current_user: User = Depends(require_superuser)): async def get_billing_stats(current_user: User = Depends(require_superuser)):
from tortoise.functions import Sum, Count from tortoise.functions import Sum
total_revenue = await Invoice.filter(status="paid").annotate( total_revenue = (
total_sum=Sum("total") await Invoice.filter(status="paid")
).values("total_sum") .annotate(total_sum=Sum("total"))
.values("total_sum")
)
invoice_count = await Invoice.all().count() invoice_count = await Invoice.all().count()
pending_invoices = await Invoice.filter(status="open").count() pending_invoices = await Invoice.filter(status="open").count()
@ -112,5 +117,5 @@ async def get_billing_stats(current_user: User = Depends(require_superuser)):
return { return {
"total_revenue": float(total_revenue[0]["total_sum"] or 0), "total_revenue": float(total_revenue[0]["total_sum"] or 0),
"total_invoices": invoice_count, "total_invoices": invoice_count,
"pending_invoices": pending_invoices "pending_invoices": pending_invoices,
} }

View File

@ -4,13 +4,23 @@ from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel from pydantic import BaseModel
from ..auth import authenticate_user, create_access_token, get_password_hash, get_current_user, get_current_verified_user, verify_password from ..auth import (
authenticate_user,
create_access_token,
get_password_hash,
get_current_user,
get_current_verified_user,
verify_password,
)
from ..models import User from ..models import User
from ..schemas import Token, UserCreate, TokenData, UserLoginWith2FA from ..schemas import Token, UserCreate
from ..two_factor import ( from ..two_factor import (
generate_totp_secret, generate_totp_uri, generate_qr_code_base64, generate_totp_secret,
verify_totp_code, generate_recovery_codes, hash_recovery_codes, generate_totp_uri,
verify_recovery_codes generate_qr_code_base64,
verify_totp_code,
generate_recovery_codes,
hash_recovery_codes,
) )
router = APIRouter( router = APIRouter(
@ -18,27 +28,33 @@ router = APIRouter(
tags=["auth"], tags=["auth"],
) )
class LoginRequest(BaseModel): class LoginRequest(BaseModel):
username: str username: str
password: str password: str
class TwoFactorLogin(BaseModel): class TwoFactorLogin(BaseModel):
username: str username: str
password: str password: str
two_factor_code: Optional[str] = None two_factor_code: Optional[str] = None
class TwoFactorSetupResponse(BaseModel): class TwoFactorSetupResponse(BaseModel):
secret: str secret: str
qr_code_base64: str qr_code_base64: str
recovery_codes: List[str] recovery_codes: List[str]
class TwoFactorCode(BaseModel): class TwoFactorCode(BaseModel):
two_factor_code: str two_factor_code: str
class TwoFactorDisable(BaseModel): class TwoFactorDisable(BaseModel):
password: str password: str
two_factor_code: str two_factor_code: str
@router.post("/register", response_model=Token) @router.post("/register", response_model=Token)
async def register_user(user_in: UserCreate): async def register_user(user_in: UserCreate):
user = await User.get_or_none(username=user_in.username) user = await User.get_or_none(username=user_in.username)
@ -63,22 +79,26 @@ async def register_user(user_in: UserCreate):
# Send welcome email # Send welcome email
from ..mail import queue_email from ..mail import queue_email
queue_email( queue_email(
to_email=user.email, to_email=user.email,
subject="Welcome to MyWebdav!", subject="Welcome to MyWebdav!",
body=f"Hi {user.username},\n\nWelcome to MyWebdav! Your account has been created successfully.\n\nBest regards,\nThe MyWebdav Team", body=f"Hi {user.username},\n\nWelcome to MyWebdav! Your account has been created successfully.\n\nBest regards,\nThe MyWebdav Team",
html=f"<h1>Welcome to MyWebdav!</h1><p>Hi {user.username},</p><p>Welcome to MyWebdav! Your account has been created successfully.</p><p>Best regards,<br>The MyWebdav Team</p>" html=f"<h1>Welcome to MyWebdav!</h1><p>Hi {user.username},</p><p>Welcome to MyWebdav! Your account has been created successfully.</p><p>Best regards,<br>The MyWebdav Team</p>",
) )
access_token_expires = timedelta(minutes=30) # Use settings access_token_expires = timedelta(minutes=30) # Use settings
access_token = create_access_token( access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires data={"sub": user.username}, expires_delta=access_token_expires
) )
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.post("/token", response_model=Token) @router.post("/token", response_model=Token)
async def login_for_access_token(login_data: LoginRequest): async def login_for_access_token(login_data: LoginRequest):
auth_result = await authenticate_user(login_data.username, login_data.password, None) auth_result = await authenticate_user(
login_data.username, login_data.password, None
)
if not auth_result: if not auth_result:
raise HTTPException( raise HTTPException(
@ -97,16 +117,26 @@ async def login_for_access_token(login_data: LoginRequest):
access_token_expires = timedelta(minutes=30) access_token_expires = timedelta(minutes=30)
access_token = create_access_token( access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires, two_factor_verified=True data={"sub": user.username},
expires_delta=access_token_expires,
two_factor_verified=True,
) )
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.post("/2fa/setup", response_model=TwoFactorSetupResponse) @router.post("/2fa/setup", response_model=TwoFactorSetupResponse)
async def setup_two_factor_authentication(current_user: User = Depends(get_current_user)): async def setup_two_factor_authentication(
current_user: User = Depends(get_current_user),
):
if current_user.is_2fa_enabled: if current_user.is_2fa_enabled:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
)
if current_user.two_factor_secret: if current_user.two_factor_secret:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup already initiated. Verify or disable first.") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="2FA setup already initiated. Verify or disable first.",
)
secret = generate_totp_secret() secret = generate_totp_secret()
current_user.two_factor_secret = secret current_user.two_factor_secret = secret
@ -120,39 +150,66 @@ async def setup_two_factor_authentication(current_user: User = Depends(get_curre
current_user.recovery_codes = ",".join(hashed_recovery_codes) current_user.recovery_codes = ",".join(hashed_recovery_codes)
await current_user.save() await current_user.save()
return TwoFactorSetupResponse(secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes) return TwoFactorSetupResponse(
secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes
)
@router.post("/2fa/verify", response_model=Token) @router.post("/2fa/verify", response_model=Token)
async def verify_two_factor_authentication(two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)): async def verify_two_factor_authentication(
two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)
):
if current_user.is_2fa_enabled: if current_user.is_2fa_enabled:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
)
if not current_user.two_factor_secret: if not current_user.two_factor_secret:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated.") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated."
)
if not verify_totp_code(current_user.two_factor_secret, two_factor_code_data.two_factor_code): if not verify_totp_code(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.") current_user.two_factor_secret, two_factor_code_data.two_factor_code
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
)
current_user.is_2fa_enabled = True current_user.is_2fa_enabled = True
await current_user.save() await current_user.save()
access_token_expires = timedelta(minutes=30) # Use settings access_token_expires = timedelta(minutes=30) # Use settings
access_token = create_access_token( access_token = create_access_token(
data={"sub": current_user.username}, expires_delta=access_token_expires, two_factor_verified=True data={"sub": current_user.username},
expires_delta=access_token_expires,
two_factor_verified=True,
) )
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.post("/2fa/disable", response_model=dict) @router.post("/2fa/disable", response_model=dict)
async def disable_two_factor_authentication(disable_data: TwoFactorDisable, current_user: User = Depends(get_current_verified_user)): async def disable_two_factor_authentication(
disable_data: TwoFactorDisable,
current_user: User = Depends(get_current_verified_user),
):
if not current_user.is_2fa_enabled: if not current_user.is_2fa_enabled:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
)
# Verify password # Verify password
if not verify_password(disable_data.password, current_user.hashed_password): if not verify_password(disable_data.password, current_user.hashed_password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password.") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password."
)
# Verify 2FA code # Verify 2FA code
if not verify_totp_code(current_user.two_factor_secret, disable_data.two_factor_code): if not verify_totp_code(
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.") current_user.two_factor_secret, disable_data.two_factor_code
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
)
current_user.two_factor_secret = None current_user.two_factor_secret = None
current_user.is_2fa_enabled = False current_user.is_2fa_enabled = False
@ -161,10 +218,15 @@ async def disable_two_factor_authentication(disable_data: TwoFactorDisable, curr
return {"message": "2FA disabled successfully."} return {"message": "2FA disabled successfully."}
@router.get("/2fa/recovery-codes", response_model=List[str]) @router.get("/2fa/recovery-codes", response_model=List[str])
async def get_new_recovery_codes(current_user: User = Depends(get_current_verified_user)): async def get_new_recovery_codes(
current_user: User = Depends(get_current_verified_user),
):
if not current_user.is_2fa_enabled: if not current_user.is_2fa_enabled:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
)
recovery_codes = generate_recovery_codes() recovery_codes = generate_recovery_codes()
hashed_recovery_codes = hash_recovery_codes(recovery_codes) hashed_recovery_codes = hash_recovery_codes(recovery_codes)

View File

@ -1,25 +1,25 @@
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import List, Optional from typing import List, Optional
from datetime import datetime, date from datetime import datetime, date
from decimal import Decimal
import calendar
from ..auth import get_current_user from ..auth import get_current_user
from ..models import User from ..models import User
from ..billing.models import ( from ..billing.models import (
Invoice, InvoiceLineItem, UserSubscription, PricingConfig, Invoice,
PaymentMethod, UsageAggregate, SubscriptionPlan UserSubscription,
PricingConfig,
PaymentMethod,
UsageAggregate,
SubscriptionPlan,
) )
from ..billing.usage_tracker import UsageTracker from ..billing.usage_tracker import UsageTracker
from ..billing.invoice_generator import InvoiceGenerator from ..billing.invoice_generator import InvoiceGenerator
from ..billing.stripe_client import StripeClient from ..billing.stripe_client import StripeClient
from pydantic import BaseModel from pydantic import BaseModel
router = APIRouter( router = APIRouter(prefix="/api/billing", tags=["billing"])
prefix="/api/billing",
tags=["billing"]
)
class UsageResponse(BaseModel): class UsageResponse(BaseModel):
storage_gb_avg: float storage_gb_avg: float
@ -29,6 +29,7 @@ class UsageResponse(BaseModel):
total_bandwidth_gb: float total_bandwidth_gb: float
period: str period: str
class InvoiceResponse(BaseModel): class InvoiceResponse(BaseModel):
id: int id: int
invoice_number: str invoice_number: str
@ -42,6 +43,7 @@ class InvoiceResponse(BaseModel):
paid_at: Optional[datetime] paid_at: Optional[datetime]
line_items: List[dict] line_items: List[dict]
class SubscriptionResponse(BaseModel): class SubscriptionResponse(BaseModel):
id: int id: int
billing_type: str billing_type: str
@ -50,6 +52,7 @@ class SubscriptionResponse(BaseModel):
current_period_start: Optional[datetime] current_period_start: Optional[datetime]
current_period_end: Optional[datetime] current_period_end: Optional[datetime]
@router.get("/usage/current") @router.get("/usage/current")
async def get_current_usage(current_user: User = Depends(get_current_user)): async def get_current_usage(current_user: User = Depends(get_current_user)):
try: try:
@ -61,25 +64,32 @@ async def get_current_usage(current_user: User = Depends(get_current_user)):
if usage_today: if usage_today:
return { return {
"storage_gb": round(storage_bytes / (1024**3), 4), "storage_gb": round(storage_bytes / (1024**3), 4),
"bandwidth_down_gb_today": round(usage_today.bandwidth_down_bytes / (1024**3), 4), "bandwidth_down_gb_today": round(
"bandwidth_up_gb_today": round(usage_today.bandwidth_up_bytes / (1024**3), 4), usage_today.bandwidth_down_bytes / (1024**3), 4
"as_of": today.isoformat() ),
"bandwidth_up_gb_today": round(
usage_today.bandwidth_up_bytes / (1024**3), 4
),
"as_of": today.isoformat(),
} }
return { return {
"storage_gb": round(storage_bytes / (1024**3), 4), "storage_gb": round(storage_bytes / (1024**3), 4),
"bandwidth_down_gb_today": 0, "bandwidth_down_gb_today": 0,
"bandwidth_up_gb_today": 0, "bandwidth_up_gb_today": 0,
"as_of": today.isoformat() "as_of": today.isoformat(),
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to fetch usage data: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to fetch usage data: {str(e)}"
)
@router.get("/usage/monthly") @router.get("/usage/monthly")
async def get_monthly_usage( async def get_monthly_usage(
year: Optional[int] = None, year: Optional[int] = None,
month: Optional[int] = None, month: Optional[int] = None,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
) -> UsageResponse: ) -> UsageResponse:
try: try:
if year is None or month is None: if year is None or month is None:
@ -88,71 +98,85 @@ async def get_monthly_usage(
month = now.month month = now.month
if not (1 <= month <= 12): if not (1 <= month <= 12):
raise HTTPException(status_code=400, detail="Month must be between 1 and 12") raise HTTPException(
status_code=400, detail="Month must be between 1 and 12"
)
if not (2020 <= year <= 2100): if not (2020 <= year <= 2100):
raise HTTPException(status_code=400, detail="Year must be between 2020 and 2100") raise HTTPException(
status_code=400, detail="Year must be between 2020 and 2100"
)
usage = await UsageTracker.get_monthly_usage(current_user, year, month) usage = await UsageTracker.get_monthly_usage(current_user, year, month)
return UsageResponse( return UsageResponse(**usage, period=f"{year}-{month:02d}")
**usage,
period=f"{year}-{month:02d}"
)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}"
)
@router.get("/invoices") @router.get("/invoices")
async def list_invoices( async def list_invoices(
limit: int = 50, limit: int = 50, offset: int = 0, current_user: User = Depends(get_current_user)
offset: int = 0,
current_user: User = Depends(get_current_user)
) -> List[InvoiceResponse]: ) -> List[InvoiceResponse]:
try: try:
if limit < 1 or limit > 100: if limit < 1 or limit > 100:
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") raise HTTPException(
status_code=400, detail="Limit must be between 1 and 100"
)
if offset < 0: if offset < 0:
raise HTTPException(status_code=400, detail="Offset must be non-negative") raise HTTPException(status_code=400, detail="Offset must be non-negative")
invoices = await Invoice.filter(user=current_user).order_by("-created_at").offset(offset).limit(limit).all() invoices = (
await Invoice.filter(user=current_user)
.order_by("-created_at")
.offset(offset)
.limit(limit)
.all()
)
result = [] result = []
for invoice in invoices: for invoice in invoices:
line_items = await invoice.line_items.all() line_items = await invoice.line_items.all()
result.append(InvoiceResponse( result.append(
id=invoice.id, InvoiceResponse(
invoice_number=invoice.invoice_number, id=invoice.id,
period_start=invoice.period_start, invoice_number=invoice.invoice_number,
period_end=invoice.period_end, period_start=invoice.period_start,
subtotal=float(invoice.subtotal), period_end=invoice.period_end,
tax=float(invoice.tax), subtotal=float(invoice.subtotal),
total=float(invoice.total), tax=float(invoice.tax),
status=invoice.status, total=float(invoice.total),
due_date=invoice.due_date, status=invoice.status,
paid_at=invoice.paid_at, due_date=invoice.due_date,
line_items=[ paid_at=invoice.paid_at,
{ line_items=[
"description": item.description, {
"quantity": float(item.quantity), "description": item.description,
"unit_price": float(item.unit_price), "quantity": float(item.quantity),
"amount": float(item.amount), "unit_price": float(item.unit_price),
"type": item.item_type "amount": float(item.amount),
} "type": item.item_type,
for item in line_items }
] for item in line_items
)) ],
)
)
return result return result
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to fetch invoices: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to fetch invoices: {str(e)}"
)
@router.get("/invoices/{invoice_id}") @router.get("/invoices/{invoice_id}")
async def get_invoice( async def get_invoice(
invoice_id: int, invoice_id: int, current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user)
) -> InvoiceResponse: ) -> InvoiceResponse:
invoice = await Invoice.get_or_none(id=invoice_id, user=current_user) invoice = await Invoice.get_or_none(id=invoice_id, user=current_user)
if not invoice: if not invoice:
@ -177,21 +201,22 @@ async def get_invoice(
"quantity": float(item.quantity), "quantity": float(item.quantity),
"unit_price": float(item.unit_price), "unit_price": float(item.unit_price),
"amount": float(item.amount), "amount": float(item.amount),
"type": item.item_type "type": item.item_type,
} }
for item in line_items for item in line_items
] ],
) )
@router.get("/subscription") @router.get("/subscription")
async def get_subscription(current_user: User = Depends(get_current_user)) -> SubscriptionResponse: async def get_subscription(
current_user: User = Depends(get_current_user),
) -> SubscriptionResponse:
subscription = await UserSubscription.get_or_none(user=current_user) subscription = await UserSubscription.get_or_none(user=current_user)
if not subscription: if not subscription:
subscription = await UserSubscription.create( subscription = await UserSubscription.create(
user=current_user, user=current_user, billing_type="pay_as_you_go", status="active"
billing_type="pay_as_you_go",
status="active"
) )
plan_name = None plan_name = None
@ -205,15 +230,19 @@ async def get_subscription(current_user: User = Depends(get_current_user)) -> Su
plan_name=plan_name, plan_name=plan_name,
status=subscription.status, status=subscription.status,
current_period_start=subscription.current_period_start, current_period_start=subscription.current_period_start,
current_period_end=subscription.current_period_end current_period_end=subscription.current_period_end,
) )
@router.post("/payment-methods/setup-intent") @router.post("/payment-methods/setup-intent")
async def create_setup_intent(current_user: User = Depends(get_current_user)): async def create_setup_intent(current_user: User = Depends(get_current_user)):
try: try:
from ..settings import settings from ..settings import settings
if not settings.STRIPE_SECRET_KEY: if not settings.STRIPE_SECRET_KEY:
raise HTTPException(status_code=503, detail="Payment processing not configured") raise HTTPException(
status_code=503, detail="Payment processing not configured"
)
subscription = await UserSubscription.get_or_none(user=current_user) subscription = await UserSubscription.get_or_none(user=current_user)
@ -221,7 +250,7 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
customer_id = await StripeClient.create_customer( customer_id = await StripeClient.create_customer(
email=current_user.email, email=current_user.email,
name=current_user.username, name=current_user.username,
metadata={"user_id": str(current_user.id)} metadata={"user_id": str(current_user.id)},
) )
if not subscription: if not subscription:
@ -229,27 +258,30 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
user=current_user, user=current_user,
billing_type="pay_as_you_go", billing_type="pay_as_you_go",
stripe_customer_id=customer_id, stripe_customer_id=customer_id,
status="active" status="active",
) )
else: else:
subscription.stripe_customer_id = customer_id subscription.stripe_customer_id = customer_id
await subscription.save() await subscription.save()
import stripe import stripe
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
setup_intent = stripe.SetupIntent.create( setup_intent = stripe.SetupIntent.create(
customer=subscription.stripe_customer_id, customer=subscription.stripe_customer_id, payment_method_types=["card"]
payment_method_types=["card"]
) )
return { return {
"client_secret": setup_intent.client_secret, "client_secret": setup_intent.client_secret,
"customer_id": subscription.stripe_customer_id "customer_id": subscription.stripe_customer_id,
} }
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create setup intent: {str(e)}") raise HTTPException(
status_code=500, detail=f"Failed to create setup intent: {str(e)}"
)
@router.get("/payment-methods") @router.get("/payment-methods")
async def list_payment_methods(current_user: User = Depends(get_current_user)): async def list_payment_methods(current_user: User = Depends(get_current_user)):
@ -262,11 +294,12 @@ async def list_payment_methods(current_user: User = Depends(get_current_user)):
"brand": m.brand, "brand": m.brand,
"exp_month": m.exp_month, "exp_month": m.exp_month,
"exp_year": m.exp_year, "exp_year": m.exp_year,
"is_default": m.is_default "is_default": m.is_default,
} }
for m in methods for m in methods
] ]
@router.post("/webhooks/stripe") @router.post("/webhooks/stripe")
async def stripe_webhook(request: Request): async def stripe_webhook(request: Request):
import stripe import stripe
@ -298,12 +331,14 @@ async def stripe_webhook(request: Request):
event_type=event["type"], event_type=event["type"],
stripe_event_id=event_id, stripe_event_id=event_id,
data=event["data"], data=event["data"],
processed=False processed=False,
) )
if event["type"] == "invoice.payment_succeeded": if event["type"] == "invoice.payment_succeeded":
invoice_data = event["data"]["object"] invoice_data = event["data"]["object"]
mywebdav_invoice_id = invoice_data.get("metadata", {}).get("mywebdav_invoice_id") mywebdav_invoice_id = invoice_data.get("metadata", {}).get(
"mywebdav_invoice_id"
)
if mywebdav_invoice_id: if mywebdav_invoice_id:
invoice = await Invoice.get_or_none(id=int(mywebdav_invoice_id)) invoice = await Invoice.get_or_none(id=int(mywebdav_invoice_id))
@ -317,7 +352,9 @@ async def stripe_webhook(request: Request):
payment_method = event["data"]["object"] payment_method = event["data"]["object"]
customer_id = payment_method["customer"] customer_id = payment_method["customer"]
subscription = await UserSubscription.get_or_none(stripe_customer_id=customer_id) subscription = await UserSubscription.get_or_none(
stripe_customer_id=customer_id
)
if subscription: if subscription:
await PaymentMethod.create( await PaymentMethod.create(
user=subscription.user, user=subscription.user,
@ -327,7 +364,7 @@ async def stripe_webhook(request: Request):
brand=payment_method.get("card", {}).get("brand"), brand=payment_method.get("card", {}).get("brand"),
exp_month=payment_method.get("card", {}).get("exp_month"), exp_month=payment_method.get("card", {}).get("exp_month"),
exp_year=payment_method.get("card", {}).get("exp_year"), exp_year=payment_method.get("card", {}).get("exp_year"),
is_default=True is_default=True,
) )
await BillingEvent.filter(stripe_event_id=event_id).update(processed=True) await BillingEvent.filter(stripe_event_id=event_id).update(processed=True)
@ -335,7 +372,10 @@ async def stripe_webhook(request: Request):
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Webhook processing failed: {str(e)}") raise HTTPException(
status_code=500, detail=f"Webhook processing failed: {str(e)}"
)
@router.get("/pricing") @router.get("/pricing")
async def get_pricing(): async def get_pricing():
@ -344,11 +384,12 @@ async def get_pricing():
config.config_key: { config.config_key: {
"value": float(config.config_value), "value": float(config.config_value),
"description": config.description, "description": config.description,
"unit": config.unit "unit": config.unit,
} }
for config in configs for config in configs
} }
@router.get("/plans") @router.get("/plans")
async def list_plans(): async def list_plans():
plans = await SubscriptionPlan.filter(is_active=True).all() plans = await SubscriptionPlan.filter(is_active=True).all()
@ -361,14 +402,16 @@ async def list_plans():
"storage_gb": plan.storage_gb, "storage_gb": plan.storage_gb,
"bandwidth_gb": plan.bandwidth_gb, "bandwidth_gb": plan.bandwidth_gb,
"price_monthly": float(plan.price_monthly), "price_monthly": float(plan.price_monthly),
"price_yearly": float(plan.price_yearly) if plan.price_yearly else None "price_yearly": float(plan.price_yearly) if plan.price_yearly else None,
} }
for plan in plans for plan in plans
] ]
@router.get("/stripe-key") @router.get("/stripe-key")
async def get_stripe_key(): async def get_stripe_key():
from ..settings import settings from ..settings import settings
if not settings.STRIPE_PUBLISHABLE_KEY: if not settings.STRIPE_PUBLISHABLE_KEY:
raise HTTPException(status_code=503, detail="Payment processing not configured") raise HTTPException(status_code=503, detail="Payment processing not configured")
return {"publishable_key": settings.STRIPE_PUBLISHABLE_KEY} return {"publishable_key": settings.STRIPE_PUBLISHABLE_KEY}

View File

@ -1,4 +1,12 @@
from fastapi import APIRouter, Depends, UploadFile, File as FastAPIFile, HTTPException, status, Response, Form from fastapi import (
APIRouter,
Depends,
UploadFile,
File as FastAPIFile,
HTTPException,
status,
Form,
)
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from typing import List, Optional from typing import List, Optional
import mimetypes import mimetypes
@ -11,7 +19,6 @@ from ..auth import get_current_user
from ..models import User, File, Folder from ..models import User, File, Folder
from ..schemas import FileOut from ..schemas import FileOut
from ..storage import storage_manager from ..storage import storage_manager
from ..settings import settings
from ..activity import log_activity from ..activity import log_activity
from ..thumbnails import generate_thumbnail, delete_thumbnail from ..thumbnails import generate_thumbnail, delete_thumbnail
@ -20,35 +27,46 @@ router = APIRouter(
tags=["files"], tags=["files"],
) )
class FileMove(BaseModel): class FileMove(BaseModel):
target_folder_id: Optional[int] = None target_folder_id: Optional[int] = None
class FileRename(BaseModel): class FileRename(BaseModel):
new_name: str new_name: str
class FileCopy(BaseModel): class FileCopy(BaseModel):
target_folder_id: Optional[int] = None target_folder_id: Optional[int] = None
class BatchFileOperation(BaseModel): class BatchFileOperation(BaseModel):
file_ids: List[int] file_ids: List[int]
operation: str # e.g., "delete", "star", "unstar", "move", "copy" operation: str # e.g., "delete", "star", "unstar", "move", "copy"
class BatchMoveCopyPayload(BaseModel): class BatchMoveCopyPayload(BaseModel):
target_folder_id: Optional[int] = None target_folder_id: Optional[int] = None
class FileContentUpdate(BaseModel): class FileContentUpdate(BaseModel):
content: str content: str
@router.post("/upload", response_model=FileOut, status_code=status.HTTP_201_CREATED) @router.post("/upload", response_model=FileOut, status_code=status.HTTP_201_CREATED)
async def upload_file( async def upload_file(
file: UploadFile = FastAPIFile(...), file: UploadFile = FastAPIFile(...),
folder_id: Optional[int] = Form(None), folder_id: Optional[int] = Form(None),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
if folder_id: if folder_id:
parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) parent_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not parent_folder: if not parent_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
else: else:
parent_folder = None parent_folder = None
@ -73,7 +91,7 @@ async def upload_file(
# Generate a unique path for storage # Generate a unique path for storage
file_extension = os.path.splitext(file.filename)[1] file_extension = os.path.splitext(file.filename)[1]
unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename
storage_path = unique_filename storage_path = unique_filename
# Save file to storage # Save file to storage
@ -105,16 +123,20 @@ async def upload_file(
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.get("/download/{file_id}") @router.get("/download/{file_id}")
async def download_file(file_id: int, current_user: User = Depends(get_current_user)): async def download_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.last_accessed_at = datetime.now() db_file.last_accessed_at = datetime.now()
await db_file.save() await db_file.save()
try: try:
async def file_iterator(): async def file_iterator():
async for chunk in storage_manager.get_file(current_user.id, db_file.path): async for chunk in storage_manager.get_file(current_user.id, db_file.path):
yield chunk yield chunk
@ -122,16 +144,21 @@ async def download_file(file_id: int, current_user: User = Depends(get_current_u
return StreamingResponse( return StreamingResponse(
file_iterator(), file_iterator(),
media_type=db_file.mime_type, media_type=db_file.mime_type,
headers={"Content-Disposition": f"attachment; filename=\"{db_file.name}\""} headers={"Content-Disposition": f'attachment; filename="{db_file.name}"'},
) )
except FileNotFoundError: except FileNotFoundError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage"
)
@router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_file(file_id: int, current_user: User = Depends(get_current_user)): async def delete_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.is_deleted = True db_file.is_deleted = True
db_file.deleted_at = datetime.now() db_file.deleted_at = datetime.now()
@ -141,68 +168,110 @@ async def delete_file(file_id: int, current_user: User = Depends(get_current_use
return return
@router.post("/{file_id}/move", response_model=FileOut) @router.post("/{file_id}/move", response_model=FileOut)
async def move_file(file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)): async def move_file(
file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
target_folder = None target_folder = None
if move_data.target_folder_id: if move_data.target_folder_id:
target_folder = await Folder.get_or_none(id=move_data.target_folder_id, owner=current_user, is_deleted=False) target_folder = await Folder.get_or_none(
id=move_data.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
)
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False
) )
if existing_file and existing_file.id != file_id: if existing_file and existing_file.id != file_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in target folder") raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="File with this name already exists in target folder",
)
db_file.parent = target_folder db_file.parent = target_folder
await db_file.save() await db_file.save()
await log_activity(user=current_user, action="file_moved", target_type="file", target_id=file_id) await log_activity(
user=current_user, action="file_moved", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.post("/{file_id}/rename", response_model=FileOut) @router.post("/{file_id}/rename", response_model=FileOut)
async def rename_file(file_id: int, rename_data: FileRename, current_user: User = Depends(get_current_user)): async def rename_file(
file_id: int,
rename_data: FileRename,
current_user: User = Depends(get_current_user),
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=rename_data.new_name, parent_id=db_file.parent_id, owner=current_user, is_deleted=False name=rename_data.new_name,
parent_id=db_file.parent_id,
owner=current_user,
is_deleted=False,
) )
if existing_file and existing_file.id != file_id: if existing_file and existing_file.id != file_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in the same folder") raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="File with this name already exists in the same folder",
)
db_file.name = rename_data.new_name db_file.name = rename_data.new_name
await db_file.save() await db_file.save()
await log_activity(user=current_user, action="file_renamed", target_type="file", target_id=file_id) await log_activity(
user=current_user, action="file_renamed", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.post("/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED)
async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)): @router.post(
"/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED
)
async def copy_file(
file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
target_folder = None target_folder = None
if copy_data.target_folder_id: if copy_data.target_folder_id:
target_folder = await Folder.get_or_none(id=copy_data.target_folder_id, owner=current_user, is_deleted=False) target_folder = await Folder.get_or_none(
id=copy_data.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
)
base_name = db_file.name base_name = db_file.name
name_parts = os.path.splitext(base_name) name_parts = os.path.splitext(base_name)
counter = 1 counter = 1
new_name = base_name new_name = base_name
while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False): while await File.get_or_none(
name=new_name, parent=target_folder, owner=current_user, is_deleted=False
):
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}" new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
counter += 1 counter += 1
@ -213,139 +282,205 @@ async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depe
mime_type=db_file.mime_type, mime_type=db_file.mime_type,
file_hash=db_file.file_hash, file_hash=db_file.file_hash,
owner=current_user, owner=current_user,
parent=target_folder parent=target_folder,
) )
await log_activity(user=current_user, action="file_copied", target_type="file", target_id=new_file.id) await log_activity(
user=current_user,
action="file_copied",
target_type="file",
target_id=new_file.id,
)
return await FileOut.from_tortoise_orm(new_file) return await FileOut.from_tortoise_orm(new_file)
@router.get("/", response_model=List[FileOut]) @router.get("/", response_model=List[FileOut])
async def list_files(folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)): async def list_files(
folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)
):
if folder_id: if folder_id:
parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) parent_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not parent_folder: if not parent_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") raise HTTPException(
files = await File.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name") status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
files = await File.filter(
parent=parent_folder, owner=current_user, is_deleted=False
).order_by("name")
else: else:
files = await File.filter(parent=None, owner=current_user, is_deleted=False).order_by("name") files = await File.filter(
parent=None, owner=current_user, is_deleted=False
).order_by("name")
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/thumbnail/{file_id}") @router.get("/thumbnail/{file_id}")
async def get_thumbnail(file_id: int, current_user: User = Depends(get_current_user)): async def get_thumbnail(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.last_accessed_at = datetime.now() db_file.last_accessed_at = datetime.now()
await db_file.save() await db_file.save()
thumbnail_path = getattr(db_file, 'thumbnail_path', None) thumbnail_path = getattr(db_file, "thumbnail_path", None)
if not thumbnail_path: if not thumbnail_path:
thumbnail_path = await generate_thumbnail(db_file.path, db_file.mime_type, current_user.id) thumbnail_path = await generate_thumbnail(
db_file.path, db_file.mime_type, current_user.id
)
if thumbnail_path: if thumbnail_path:
db_file.thumbnail_path = thumbnail_path db_file.thumbnail_path = thumbnail_path
await db_file.save() await db_file.save()
else: else:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available"
)
try: try:
async def thumbnail_iterator(): async def thumbnail_iterator():
async for chunk in storage_manager.get_file(current_user.id, thumbnail_path): async for chunk in storage_manager.get_file(
current_user.id, thumbnail_path
):
yield chunk yield chunk
return StreamingResponse( return StreamingResponse(thumbnail_iterator(), media_type="image/jpeg")
thumbnail_iterator(),
media_type="image/jpeg"
)
except FileNotFoundError: except FileNotFoundError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not found in storage") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Thumbnail not found in storage",
)
@router.get("/photos", response_model=List[FileOut]) @router.get("/photos", response_model=List[FileOut])
async def list_photos(current_user: User = Depends(get_current_user)): async def list_photos(current_user: User = Depends(get_current_user)):
files = await File.filter( files = await File.filter(
owner=current_user, owner=current_user, is_deleted=False, mime_type__istartswith="image/"
is_deleted=False,
mime_type__istartswith="image/"
).order_by("-created_at") ).order_by("-created_at")
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/recent", response_model=List[FileOut]) @router.get("/recent", response_model=List[FileOut])
async def list_recent_files(current_user: User = Depends(get_current_user), limit: int = 10): async def list_recent_files(
files = await File.filter( current_user: User = Depends(get_current_user), limit: int = 10
owner=current_user, ):
is_deleted=False, files = (
last_accessed_at__isnull=False await File.filter(
).order_by("-last_accessed_at").limit(limit) owner=current_user, is_deleted=False, last_accessed_at__isnull=False
)
.order_by("-last_accessed_at")
.limit(limit)
)
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.post("/{file_id}/star", response_model=FileOut) @router.post("/{file_id}/star", response_model=FileOut)
async def star_file(file_id: int, current_user: User = Depends(get_current_user)): async def star_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.is_starred = True db_file.is_starred = True
await db_file.save() await db_file.save()
await log_activity(user=current_user, action="file_starred", target_type="file", target_id=file_id) await log_activity(
user=current_user, action="file_starred", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.post("/{file_id}/unstar", response_model=FileOut) @router.post("/{file_id}/unstar", response_model=FileOut)
async def unstar_file(file_id: int, current_user: User = Depends(get_current_user)): async def unstar_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.is_starred = False db_file.is_starred = False
await db_file.save() await db_file.save()
await log_activity(user=current_user, action="file_unstarred", target_type="file", target_id=file_id) await log_activity(
user=current_user,
action="file_unstarred",
target_type="file",
target_id=file_id,
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.get("/deleted", response_model=List[FileOut]) @router.get("/deleted", response_model=List[FileOut])
async def list_deleted_files(current_user: User = Depends(get_current_user)): async def list_deleted_files(current_user: User = Depends(get_current_user)):
files = await File.filter(owner=current_user, is_deleted=True).order_by("-deleted_at") files = await File.filter(owner=current_user, is_deleted=True).order_by(
"-deleted_at"
)
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.post("/{file_id}/restore", response_model=FileOut) @router.post("/{file_id}/restore", response_model=FileOut)
async def restore_file(file_id: int, current_user: User = Depends(get_current_user)): async def restore_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=True) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=True)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found"
)
# Check if a file with the same name exists in the parent folder # Check if a file with the same name exists in the parent folder
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=db_file.name, parent=db_file.parent, owner=current_user, is_deleted=False name=db_file.name, parent=db_file.parent, owner=current_user, is_deleted=False
) )
if existing_file: if existing_file:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.") raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.",
)
db_file.is_deleted = False db_file.is_deleted = False
db_file.deleted_at = None db_file.deleted_at = None
await db_file.save() await db_file.save()
await log_activity(user=current_user, action="file_restored", target_type="file", target_id=file_id) await log_activity(
user=current_user, action="file_restored", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
class BatchOperationResult(BaseModel): class BatchOperationResult(BaseModel):
succeeded: List[FileOut] succeeded: List[FileOut]
failed: List[dict] failed: List[dict]
@router.post("/batch") @router.post("/batch")
async def batch_file_operations( async def batch_file_operations(
batch_operation: BatchFileOperation, batch_operation: BatchFileOperation,
payload: Optional[BatchMoveCopyPayload] = None, payload: Optional[BatchMoveCopyPayload] = None,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
if batch_operation.operation not in ["delete", "star", "unstar", "move", "copy"]: if batch_operation.operation not in ["delete", "star", "unstar", "move", "copy"]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid operation: {batch_operation.operation}") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid operation: {batch_operation.operation}",
)
updated_files = [] updated_files = []
failed_operations = [] failed_operations = []
for file_id in batch_operation.file_ids: for file_id in batch_operation.file_ids:
try: try:
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(
id=file_id, owner=current_user, is_deleted=False
)
if not db_file: if not db_file:
failed_operations.append({"file_id": file_id, "reason": "File not found or not owned by user"}) failed_operations.append(
{
"file_id": file_id,
"reason": "File not found or not owned by user",
}
)
continue continue
if batch_operation.operation == "delete": if batch_operation.operation == "delete":
@ -353,47 +488,87 @@ async def batch_file_operations(
db_file.deleted_at = datetime.now() db_file.deleted_at = datetime.now()
await db_file.save() await db_file.save()
await delete_thumbnail(db_file.id) await delete_thumbnail(db_file.id)
await log_activity(user=current_user, action="file_deleted_batch", target_type="file", target_id=file_id) await log_activity(
user=current_user,
action="file_deleted_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file) updated_files.append(db_file)
elif batch_operation.operation == "star": elif batch_operation.operation == "star":
db_file.is_starred = True db_file.is_starred = True
await db_file.save() await db_file.save()
await log_activity(user=current_user, action="file_starred_batch", target_type="file", target_id=file_id) await log_activity(
user=current_user,
action="file_starred_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file) updated_files.append(db_file)
elif batch_operation.operation == "unstar": elif batch_operation.operation == "unstar":
db_file.is_starred = False db_file.is_starred = False
await db_file.save() await db_file.save()
await log_activity(user=current_user, action="file_unstarred_batch", target_type="file", target_id=file_id) await log_activity(
user=current_user,
action="file_unstarred_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file) updated_files.append(db_file)
elif batch_operation.operation == "move": elif batch_operation.operation == "move":
if not payload or payload.target_folder_id is None: if not payload or payload.target_folder_id is None:
failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"}) failed_operations.append(
{"file_id": file_id, "reason": "Target folder not specified"}
)
continue continue
target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False) target_folder = await Folder.get_or_none(
id=payload.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
failed_operations.append({"file_id": file_id, "reason": "Target folder not found"}) failed_operations.append(
{"file_id": file_id, "reason": "Target folder not found"}
)
continue continue
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False name=db_file.name,
parent=target_folder,
owner=current_user,
is_deleted=False,
) )
if existing_file and existing_file.id != file_id: if existing_file and existing_file.id != file_id:
failed_operations.append({"file_id": file_id, "reason": "File with same name exists in target folder"}) failed_operations.append(
{
"file_id": file_id,
"reason": "File with same name exists in target folder",
}
)
continue continue
db_file.parent = target_folder db_file.parent = target_folder
await db_file.save() await db_file.save()
await log_activity(user=current_user, action="file_moved_batch", target_type="file", target_id=file_id) await log_activity(
user=current_user,
action="file_moved_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file) updated_files.append(db_file)
elif batch_operation.operation == "copy": elif batch_operation.operation == "copy":
if not payload or payload.target_folder_id is None: if not payload or payload.target_folder_id is None:
failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"}) failed_operations.append(
{"file_id": file_id, "reason": "Target folder not specified"}
)
continue continue
target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False) target_folder = await Folder.get_or_none(
id=payload.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
failed_operations.append({"file_id": file_id, "reason": "Target folder not found"}) failed_operations.append(
{"file_id": file_id, "reason": "Target folder not found"}
)
continue continue
base_name = db_file.name base_name = db_file.name
@ -401,7 +576,12 @@ async def batch_file_operations(
counter = 1 counter = 1
new_name = base_name new_name = base_name
while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False): while await File.get_or_none(
name=new_name,
parent=target_folder,
owner=current_user,
is_deleted=False,
):
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}" new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
counter += 1 counter += 1
@ -412,48 +592,70 @@ async def batch_file_operations(
mime_type=db_file.mime_type, mime_type=db_file.mime_type,
file_hash=db_file.file_hash, file_hash=db_file.file_hash,
owner=current_user, owner=current_user,
parent=target_folder parent=target_folder,
)
await log_activity(
user=current_user,
action="file_copied_batch",
target_type="file",
target_id=new_file.id,
) )
await log_activity(user=current_user, action="file_copied_batch", target_type="file", target_id=new_file.id)
updated_files.append(new_file) updated_files.append(new_file)
except Exception as e: except Exception as e:
failed_operations.append({"file_id": file_id, "reason": str(e)}) failed_operations.append({"file_id": file_id, "reason": str(e)})
return { return {
"succeeded": [await FileOut.from_tortoise_orm(f) for f in updated_files], "succeeded": [await FileOut.from_tortoise_orm(f) for f in updated_files],
"failed": failed_operations "failed": failed_operations,
} }
@router.put("/{file_id}/content", response_model=FileOut) @router.put("/{file_id}/content", response_model=FileOut)
async def update_file_content( async def update_file_content(
file_id: int, file_id: int,
payload: FileContentUpdate, payload: FileContentUpdate,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
if not db_file.mime_type or not db_file.mime_type.startswith('text/'): if not db_file.mime_type or not db_file.mime_type.startswith("text/"):
editableExtensions = [ editableExtensions = [
'txt', 'md', 'log', 'json', 'js', 'py', 'html', 'css', "txt",
'xml', 'yaml', 'yml', 'sh', 'bat', 'ini', 'conf', 'cfg' "md",
"log",
"json",
"js",
"py",
"html",
"css",
"xml",
"yaml",
"yml",
"sh",
"bat",
"ini",
"conf",
"cfg",
] ]
file_extension = os.path.splitext(db_file.name)[1][1:].lower() file_extension = os.path.splitext(db_file.name)[1][1:].lower()
if file_extension not in editableExtensions: if file_extension not in editableExtensions:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="File type is not editable" detail="File type is not editable",
) )
content_bytes = payload.content.encode('utf-8') content_bytes = payload.content.encode("utf-8")
new_size = len(content_bytes) new_size = len(content_bytes)
size_diff = new_size - db_file.size size_diff = new_size - db_file.size
if current_user.used_storage_bytes + size_diff > current_user.storage_quota_bytes: if current_user.used_storage_bytes + size_diff > current_user.storage_quota_bytes:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_507_INSUFFICIENT_STORAGE, status_code=status.HTTP_507_INSUFFICIENT_STORAGE,
detail="Storage quota exceeded" detail="Storage quota exceeded",
) )
new_hash = hashlib.sha256(content_bytes).hexdigest() new_hash = hashlib.sha256(content_bytes).hexdigest()
@ -465,7 +667,7 @@ async def update_file_content(
if new_storage_path != db_file.path: if new_storage_path != db_file.path:
try: try:
await storage_manager.delete_file(current_user.id, db_file.path) await storage_manager.delete_file(current_user.id, db_file.path)
except: except Exception:
pass pass
db_file.path = new_storage_path db_file.path = new_storage_path
@ -477,6 +679,8 @@ async def update_file_content(
current_user.used_storage_bytes += size_diff current_user.used_storage_bytes += size_diff
await current_user.save() await current_user.save()
await log_activity(user=current_user, action="file_updated", target_type="file", target_id=file_id) await log_activity(
user=current_user, action="file_updated", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)

View File

@ -3,7 +3,13 @@ from typing import List, Optional
from ..auth import get_current_user from ..auth import get_current_user
from ..models import User, Folder from ..models import User, Folder
from ..schemas import FolderCreate, FolderOut, FolderUpdate, BatchFolderOperation, BatchMoveCopyPayload from ..schemas import (
FolderCreate,
FolderOut,
FolderUpdate,
BatchFolderOperation,
BatchMoveCopyPayload,
)
from ..activity import log_activity from ..activity import log_activity
router = APIRouter( router = APIRouter(
@ -11,12 +17,17 @@ router = APIRouter(
tags=["folders"], tags=["folders"],
) )
@router.post("/", response_model=FolderOut, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=FolderOut, status_code=status.HTTP_201_CREATED)
async def create_folder(folder_in: FolderCreate, current_user: User = Depends(get_current_user)): async def create_folder(
folder_in: FolderCreate, current_user: User = Depends(get_current_user)
):
# Check if parent folder exists and belongs to the current user # Check if parent folder exists and belongs to the current user
parent_folder = None parent_folder = None
if folder_in.parent_id: if folder_in.parent_id:
parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user) parent_folder = await Folder.get_or_none(
id=folder_in.parent_id, owner=current_user
)
if not parent_folder: if not parent_folder:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -39,50 +50,82 @@ async def create_folder(folder_in: FolderCreate, current_user: User = Depends(ge
await log_activity(current_user, "folder_created", "folder", folder.id) await log_activity(current_user, "folder_created", "folder", folder.id)
return await FolderOut.from_tortoise_orm(folder) return await FolderOut.from_tortoise_orm(folder)
@router.get("/{folder_id}/path", response_model=List[FolderOut]) @router.get("/{folder_id}/path", response_model=List[FolderOut])
async def get_folder_path(folder_id: int, current_user: User = Depends(get_current_user)): async def get_folder_path(
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) folder_id: int, current_user: User = Depends(get_current_user)
):
folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
path = [] path = []
current = folder current = folder
while current: while current:
path.insert(0, await FolderOut.from_tortoise_orm(current)) path.insert(0, await FolderOut.from_tortoise_orm(current))
if current.parent_id: if current.parent_id:
current = await Folder.get_or_none(id=current.parent_id, owner=current_user, is_deleted=False) current = await Folder.get_or_none(
id=current.parent_id, owner=current_user, is_deleted=False
)
else: else:
current = None current = None
return path return path
@router.get("/{folder_id}", response_model=FolderOut) @router.get("/{folder_id}", response_model=FolderOut)
async def get_folder(folder_id: int, current_user: User = Depends(get_current_user)): async def get_folder(folder_id: int, current_user: User = Depends(get_current_user)):
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
return await FolderOut.from_tortoise_orm(folder) return await FolderOut.from_tortoise_orm(folder)
@router.get("/", response_model=List[FolderOut]) @router.get("/", response_model=List[FolderOut])
async def list_folders(parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)): async def list_folders(
parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)
):
if parent_id: if parent_id:
parent_folder = await Folder.get_or_none(id=parent_id, owner=current_user, is_deleted=False) parent_folder = await Folder.get_or_none(
id=parent_id, owner=current_user, is_deleted=False
)
if not parent_folder: if not parent_folder:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="Parent folder not found or does not belong to the current user", detail="Parent folder not found or does not belong to the current user",
) )
folders = await Folder.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name") folders = await Folder.filter(
parent=parent_folder, owner=current_user, is_deleted=False
).order_by("name")
else: else:
# List root folders (folders with no parent) # List root folders (folders with no parent)
folders = await Folder.filter(parent=None, owner=current_user, is_deleted=False).order_by("name") folders = await Folder.filter(
parent=None, owner=current_user, is_deleted=False
).order_by("name")
return [await FolderOut.from_tortoise_orm(folder) for folder in folders] return [await FolderOut.from_tortoise_orm(folder) for folder in folders]
@router.put("/{folder_id}", response_model=FolderOut) @router.put("/{folder_id}", response_model=FolderOut)
async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: User = Depends(get_current_user)): async def update_folder(
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) folder_id: int,
folder_in: FolderUpdate,
current_user: User = Depends(get_current_user),
):
folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
if folder_in.name: if folder_in.name:
existing_folder = await Folder.get_or_none( existing_folder = await Folder.get_or_none(
@ -97,11 +140,16 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U
if folder_in.parent_id is not None: if folder_in.parent_id is not None:
if folder_in.parent_id == folder_id: if folder_in.parent_id == folder_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot set folder as its own parent") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot set folder as its own parent",
)
new_parent_folder = None new_parent_folder = None
if folder_in.parent_id != 0: # 0 could represent moving to root if folder_in.parent_id != 0: # 0 could represent moving to root
new_parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user, is_deleted=False) new_parent_folder = await Folder.get_or_none(
id=folder_in.parent_id, owner=current_user, is_deleted=False
)
if not new_parent_folder: if not new_parent_folder:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -113,48 +161,66 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U
await log_activity(current_user, "folder_updated", "folder", folder.id) await log_activity(current_user, "folder_updated", "folder", folder.id)
return await FolderOut.from_tortoise_orm(folder) return await FolderOut.from_tortoise_orm(folder)
@router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_folder(folder_id: int, current_user: User = Depends(get_current_user)): async def delete_folder(folder_id: int, current_user: User = Depends(get_current_user)):
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
folder.is_deleted = True folder.is_deleted = True
await folder.save() await folder.save()
await log_activity(current_user, "folder_deleted", "folder", folder.id) await log_activity(current_user, "folder_deleted", "folder", folder.id)
return return
@router.post("/{folder_id}/star", response_model=FolderOut) @router.post("/{folder_id}/star", response_model=FolderOut)
async def star_folder(folder_id: int, current_user: User = Depends(get_current_user)): async def star_folder(folder_id: int, current_user: User = Depends(get_current_user)):
db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) db_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not db_folder: if not db_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
db_folder.is_starred = True db_folder.is_starred = True
await db_folder.save() await db_folder.save()
await log_activity(current_user, "folder_starred", "folder", folder_id) await log_activity(current_user, "folder_starred", "folder", folder_id)
return await FolderOut.from_tortoise_orm(db_folder) return await FolderOut.from_tortoise_orm(db_folder)
@router.post("/{folder_id}/unstar", response_model=FolderOut) @router.post("/{folder_id}/unstar", response_model=FolderOut)
async def unstar_folder(folder_id: int, current_user: User = Depends(get_current_user)): async def unstar_folder(folder_id: int, current_user: User = Depends(get_current_user)):
db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) db_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not db_folder: if not db_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
db_folder.is_starred = False db_folder.is_starred = False
await db_folder.save() await db_folder.save()
await log_activity(current_user, "folder_unstarred", "folder", folder_id) await log_activity(current_user, "folder_unstarred", "folder", folder_id)
return await FolderOut.from_tortoise_orm(db_folder) return await FolderOut.from_tortoise_orm(db_folder)
@router.post("/batch", response_model=List[FolderOut]) @router.post("/batch", response_model=List[FolderOut])
async def batch_folder_operations( async def batch_folder_operations(
batch_operation: BatchFolderOperation, batch_operation: BatchFolderOperation,
payload: Optional[BatchMoveCopyPayload] = None, payload: Optional[BatchMoveCopyPayload] = None,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
updated_folders = [] updated_folders = []
for folder_id in batch_operation.folder_ids: for folder_id in batch_operation.folder_ids:
db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) db_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not db_folder: if not db_folder:
continue # Skip if folder not found or not owned by user continue # Skip if folder not found or not owned by user
if batch_operation.operation == "delete": if batch_operation.operation == "delete":
db_folder.is_deleted = True db_folder.is_deleted = True
@ -171,13 +237,22 @@ async def batch_folder_operations(
await db_folder.save() await db_folder.save()
await log_activity(current_user, "folder_unstarred", "folder", folder_id) await log_activity(current_user, "folder_unstarred", "folder", folder_id)
updated_folders.append(db_folder) updated_folders.append(db_folder)
elif batch_operation.operation == "move" and payload and payload.target_folder_id is not None: elif (
target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False) batch_operation.operation == "move"
and payload
and payload.target_folder_id is not None
):
target_folder = await Folder.get_or_none(
id=payload.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
continue continue
existing_folder = await Folder.get_or_none( existing_folder = await Folder.get_or_none(
name=db_folder.name, parent=target_folder, owner=current_user, is_deleted=False name=db_folder.name,
parent=target_folder,
owner=current_user,
is_deleted=False,
) )
if existing_folder and existing_folder.id != folder_id: if existing_folder and existing_folder.id != folder_id:
continue continue

View File

@ -11,15 +11,22 @@ router = APIRouter(
tags=["search"], tags=["search"],
) )
@router.get("/files", response_model=List[FileOut]) @router.get("/files", response_model=List[FileOut])
async def search_files( async def search_files(
q: str = Query(..., min_length=1, description="Search query"), q: str = Query(..., min_length=1, description="Search query"),
file_type: Optional[str] = Query(None, description="Filter by MIME type prefix (e.g., 'image', 'video')"), file_type: Optional[str] = Query(
None, description="Filter by MIME type prefix (e.g., 'image', 'video')"
),
min_size: Optional[int] = Query(None, description="Minimum file size in bytes"), min_size: Optional[int] = Query(None, description="Minimum file size in bytes"),
max_size: Optional[int] = Query(None, description="Maximum file size in bytes"), max_size: Optional[int] = Query(None, description="Maximum file size in bytes"),
date_from: Optional[datetime] = Query(None, description="Filter files created after this date"), date_from: Optional[datetime] = Query(
date_to: Optional[datetime] = Query(None, description="Filter files created before this date"), None, description="Filter files created after this date"
current_user: User = Depends(get_current_user) ),
date_to: Optional[datetime] = Query(
None, description="Filter files created before this date"
),
current_user: User = Depends(get_current_user),
): ):
query = File.filter(owner=current_user, is_deleted=False, name__icontains=q) query = File.filter(owner=current_user, is_deleted=False, name__icontains=q)
@ -41,15 +48,16 @@ async def search_files(
files = await query.order_by("-created_at").limit(100) files = await query.order_by("-created_at").limit(100)
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/folders", response_model=List[FolderOut]) @router.get("/folders", response_model=List[FolderOut])
async def search_folders( async def search_folders(
q: str = Query(..., min_length=1, description="Search query"), q: str = Query(..., min_length=1, description="Search query"),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
folders = await Folder.filter( folders = (
owner=current_user, await Folder.filter(owner=current_user, is_deleted=False, name__icontains=q)
is_deleted=False, .order_by("-created_at")
name__icontains=q .limit(100)
).order_by("-created_at").limit(100) )
return [await FolderOut.from_tortoise_orm(folder) for folder in folders] return [await FolderOut.from_tortoise_orm(folder) for folder in folders]

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from typing import Optional, List from typing import Optional, List
import secrets import secrets
from datetime import datetime, timedelta from datetime import datetime
from ..auth import get_current_user from ..auth import get_current_user
from ..models import User, File, Folder, Share from ..models import User, File, Folder, Share
@ -17,25 +17,44 @@ router = APIRouter(
tags=["shares"], tags=["shares"],
) )
@router.post("/", response_model=ShareOut, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=ShareOut, status_code=status.HTTP_201_CREATED)
async def create_share_link(share_in: ShareCreate, current_user: User = Depends(get_current_user)): async def create_share_link(
share_in: ShareCreate, current_user: User = Depends(get_current_user)
):
if not share_in.file_id and not share_in.folder_id: if not share_in.file_id and not share_in.folder_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Either file_id or folder_id must be provided") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Either file_id or folder_id must be provided",
)
if share_in.file_id and share_in.folder_id: if share_in.file_id and share_in.folder_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot share both a file and a folder simultaneously") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot share both a file and a folder simultaneously",
)
file = None file = None
folder = None folder = None
if share_in.file_id: if share_in.file_id:
file = await File.get_or_none(id=share_in.file_id, owner=current_user, is_deleted=False) file = await File.get_or_none(
id=share_in.file_id, owner=current_user, is_deleted=False
)
if not file: if not file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found or does not belong to you") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found or does not belong to you",
)
if share_in.folder_id: if share_in.folder_id:
folder = await Folder.get_or_none(id=share_in.folder_id, owner=current_user, is_deleted=False) folder = await Folder.get_or_none(
id=share_in.folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found or does not belong to you") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Folder not found or does not belong to you",
)
token = secrets.token_urlsafe(16) token = secrets.token_urlsafe(16)
hashed_password = None hashed_password = None
@ -60,8 +79,14 @@ async def create_share_link(share_in: ShareCreate, current_user: User = Depends(
item_type = "file" if file else "folder" item_type = "file" if file else "folder"
item_name = file.name if file else folder.name item_name = file.name if file else folder.name
expiry_text = f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}" if share_in.expires_at else "" expiry_text = (
password_text = f"\n\nPassword: {share_in.password}" if share_in.password else "" f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}"
if share_in.expires_at
else ""
)
password_text = (
f"\n\nPassword: {share_in.password}" if share_in.password else ""
)
email_body = f"""Hello, email_body = f"""Hello,
@ -95,26 +120,32 @@ MyWebdav File Sharing Service"""
to_email=share_in.invite_email, to_email=share_in.invite_email,
subject=f"{current_user.username} shared {item_name} with you", subject=f"{current_user.username} shared {item_name} with you",
body=email_body, body=email_body,
html=email_html html=email_html,
) )
except Exception as e: except Exception as e:
print(f"Failed to send invitation email: {e}") print(f"Failed to send invitation email: {e}")
return await ShareOut.from_tortoise_orm(share) return await ShareOut.from_tortoise_orm(share)
@router.get("/my", response_model=List[ShareOut]) @router.get("/my", response_model=List[ShareOut])
async def list_my_shares(current_user: User = Depends(get_current_user)): async def list_my_shares(current_user: User = Depends(get_current_user)):
shares = await Share.filter(owner=current_user).order_by("-created_at") shares = await Share.filter(owner=current_user).order_by("-created_at")
return [await ShareOut.from_tortoise_orm(share) for share in shares] return [await ShareOut.from_tortoise_orm(share) for share in shares]
@router.get("/{share_token}", response_model=ShareOut) @router.get("/{share_token}", response_model=ShareOut)
async def get_share_link_info(share_token: str): async def get_share_link_info(share_token: str):
share = await Share.get_or_none(token=share_token) share = await Share.get_or_none(token=share_token)
if not share: if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
)
if share.expires_at and share.expires_at < datetime.utcnow(): if share.expires_at and share.expires_at < datetime.utcnow():
raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired") raise HTTPException(
status_code=status.HTTP_410_GONE, detail="Share link has expired"
)
# Increment access count # Increment access count
share.access_count += 1 share.access_count += 1
@ -122,11 +153,17 @@ async def get_share_link_info(share_token: str):
return await ShareOut.from_tortoise_orm(share) return await ShareOut.from_tortoise_orm(share)
@router.put("/{share_id}", response_model=ShareOut) @router.put("/{share_id}", response_model=ShareOut)
async def update_share(share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)): async def update_share(
share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)
):
share = await Share.get_or_none(id=share_id, owner=current_user) share = await Share.get_or_none(id=share_id, owner=current_user)
if not share: if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Share link not found or does not belong to you",
)
if share_in.expires_at is not None: if share_in.expires_at is not None:
share.expires_at = share_in.expires_at share.expires_at = share_in.expires_at
@ -134,7 +171,7 @@ async def update_share(share_id: int, share_in: ShareCreate, current_user: User
if share_in.password is not None: if share_in.password is not None:
share.hashed_password = get_password_hash(share_in.password) share.hashed_password = get_password_hash(share_in.password)
share.password_protected = True share.password_protected = True
elif share_in.password == "": # Allow clearing password elif share_in.password == "": # Allow clearing password
share.hashed_password = None share.hashed_password = None
share.password_protected = False share.password_protected = False
@ -144,31 +181,42 @@ async def update_share(share_id: int, share_in: ShareCreate, current_user: User
await share.save() await share.save()
return await ShareOut.from_tortoise_orm(share) return await ShareOut.from_tortoise_orm(share)
@router.post("/{share_token}/access") @router.post("/{share_token}/access")
async def access_shared_content(share_token: str, password: Optional[str] = None): async def access_shared_content(share_token: str, password: Optional[str] = None):
share = await Share.get_or_none(token=share_token) share = await Share.get_or_none(token=share_token)
if not share: if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
)
if share.expires_at and share.expires_at < datetime.utcnow(): if share.expires_at and share.expires_at < datetime.utcnow():
raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired") raise HTTPException(
status_code=status.HTTP_410_GONE, detail="Share link has expired"
)
if share.password_protected: if share.password_protected:
if not password or not verify_password(password, share.hashed_password): if not password or not verify_password(password, share.hashed_password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
)
result = {"message": "Access granted", "permission_level": share.permission_level} result = {"message": "Access granted", "permission_level": share.permission_level}
if share.file_id: if share.file_id:
file = await File.get_or_none(id=share.file_id, is_deleted=False) file = await File.get_or_none(id=share.file_id, is_deleted=False)
if not file: if not file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
result["file"] = await FileOut.from_tortoise_orm(file) result["file"] = await FileOut.from_tortoise_orm(file)
result["type"] = "file" result["type"] = "file"
elif share.folder_id: elif share.folder_id:
folder = await Folder.get_or_none(id=share.folder_id, is_deleted=False) folder = await Folder.get_or_none(id=share.folder_id, is_deleted=False)
if not folder: if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
result["folder"] = await FolderOut.from_tortoise_orm(folder) result["folder"] = await FolderOut.from_tortoise_orm(folder)
result["type"] = "folder" result["type"] = "folder"
@ -179,29 +227,42 @@ async def access_shared_content(share_token: str, password: Optional[str] = None
return result return result
@router.get("/{share_token}/download") @router.get("/{share_token}/download")
async def download_shared_file(share_token: str, password: Optional[str] = None): async def download_shared_file(share_token: str, password: Optional[str] = None):
share = await Share.get_or_none(token=share_token) share = await Share.get_or_none(token=share_token)
if not share: if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
)
if share.expires_at and share.expires_at < datetime.utcnow(): if share.expires_at and share.expires_at < datetime.utcnow():
raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired") raise HTTPException(
status_code=status.HTTP_410_GONE, detail="Share link has expired"
)
if share.password_protected: if share.password_protected:
if not password or not verify_password(password, share.hashed_password): if not password or not verify_password(password, share.hashed_password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
)
if not share.file_id: if not share.file_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="This share is not for a file") raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="This share is not for a file",
)
file = await File.get_or_none(id=share.file_id, is_deleted=False) file = await File.get_or_none(id=share.file_id, is_deleted=False)
if not file: if not file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
owner = await User.get(id=file.owner_id) owner = await User.get(id=file.owner_id)
try: try:
async def file_iterator(): async def file_iterator():
async for chunk in storage_manager.get_file(owner.id, file.path): async for chunk in storage_manager.get_file(owner.id, file.path):
yield chunk yield chunk
@ -209,18 +270,22 @@ async def download_shared_file(share_token: str, password: Optional[str] = None)
return StreamingResponse( return StreamingResponse(
content=file_iterator(), content=file_iterator(),
media_type=file.mime_type, media_type=file.mime_type,
headers={ headers={"Content-Disposition": f'attachment; filename="{file.name}"'},
"Content-Disposition": f'attachment; filename="{file.name}"'
}
) )
except FileNotFoundError: except FileNotFoundError:
raise HTTPException(status_code=404, detail="File not found in storage") raise HTTPException(status_code=404, detail="File not found in storage")
@router.delete("/{share_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{share_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_share_link(share_id: int, current_user: User = Depends(get_current_user)): async def delete_share_link(
share_id: int, current_user: User = Depends(get_current_user)
):
share = await Share.get_or_none(id=share_id, owner=current_user) share = await Share.get_or_none(id=share_id, owner=current_user)
if not share: if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Share link not found or does not belong to you",
)
await share.delete() await share.delete()
return return

View File

@ -11,18 +11,29 @@ router = APIRouter(
tags=["starred"], tags=["starred"],
) )
@router.get("/files", response_model=List[FileOut]) @router.get("/files", response_model=List[FileOut])
async def list_starred_files(current_user: User = Depends(get_current_user)): async def list_starred_files(current_user: User = Depends(get_current_user)):
files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name") files = await File.filter(
owner=current_user, is_starred=True, is_deleted=False
).order_by("name")
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/folders", response_model=List[FolderOut]) @router.get("/folders", response_model=List[FolderOut])
async def list_starred_folders(current_user: User = Depends(get_current_user)): async def list_starred_folders(current_user: User = Depends(get_current_user)):
folders = await Folder.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name") folders = await Folder.filter(
owner=current_user, is_starred=True, is_deleted=False
).order_by("name")
return [await FolderOut.from_tortoise_orm(f) for f in folders] return [await FolderOut.from_tortoise_orm(f) for f in folders]
@router.get("/all", response_model=List[FileOut]) # This will return files and folders as files for now
@router.get(
"/all", response_model=List[FileOut]
) # This will return files and folders as files for now
async def list_all_starred(current_user: User = Depends(get_current_user)): async def list_all_starred(current_user: User = Depends(get_current_user)):
starred_files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name") starred_files = await File.filter(
owner=current_user, is_starred=True, is_deleted=False
).order_by("name")
# For simplicity, we'll return files only for now. A more complex solution would involve a union or a custom schema. # For simplicity, we'll return files only for now. A more complex solution would involve a union or a custom schema.
return [await FileOut.from_tortoise_orm(f) for f in starred_files] return [await FileOut.from_tortoise_orm(f) for f in starred_files]

View File

@ -1,17 +1,19 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from ..auth import get_current_user from ..auth import get_current_user
from ..models import User_Pydantic, User, File, Folder from ..models import User_Pydantic, User, File, Folder
from typing import List, Dict, Any from typing import Dict, Any
router = APIRouter( router = APIRouter(
prefix="/users", prefix="/users",
tags=["users"], tags=["users"],
) )
@router.get("/me", response_model=User_Pydantic) @router.get("/me", response_model=User_Pydantic)
async def read_users_me(current_user: User = Depends(get_current_user)): async def read_users_me(current_user: User = Depends(get_current_user)):
return await User_Pydantic.from_tortoise_orm(current_user) return await User_Pydantic.from_tortoise_orm(current_user)
@router.get("/me/export", response_model=Dict[str, Any]) @router.get("/me/export", response_model=Dict[str, Any])
async def export_my_data(current_user: User = Depends(get_current_user)): async def export_my_data(current_user: User = Depends(get_current_user)):
""" """
@ -35,6 +37,7 @@ async def export_my_data(current_user: User = Depends(get_current_user)):
# share information, etc., would also be included. # share information, etc., would also be included.
} }
@router.delete("/me", status_code=204) @router.delete("/me", status_code=204)
async def delete_my_account(current_user: User = Depends(get_current_user)): async def delete_my_account(current_user: User = Depends(get_current_user)):
""" """
@ -48,5 +51,3 @@ async def delete_my_account(current_user: User = Depends(get_current_user)):
# Finally, delete the user account # Finally, delete the user account
await current_user.delete() await current_user.delete()
return {} return {}

View File

@ -2,19 +2,24 @@ from datetime import datetime
from pydantic import BaseModel, EmailStr, ConfigDict from pydantic import BaseModel, EmailStr, ConfigDict
from typing import Optional, List from typing import Optional, List
from tortoise.contrib.pydantic import pydantic_model_creator from tortoise.contrib.pydantic import pydantic_model_creator
from mywebdav.models import Folder, File, Share, FileVersion
class UserCreate(BaseModel): class UserCreate(BaseModel):
username: str username: str
email: EmailStr email: EmailStr
password: str password: str
class UserLogin(BaseModel): class UserLogin(BaseModel):
username: str username: str
password: str password: str
class UserLoginWith2FA(UserLogin): class UserLoginWith2FA(UserLogin):
two_factor_code: Optional[str] = None two_factor_code: Optional[str] = None
class UserAdminUpdate(BaseModel): class UserAdminUpdate(BaseModel):
username: Optional[str] = None username: Optional[str] = None
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
@ -25,22 +30,27 @@ class UserAdminUpdate(BaseModel):
plan_type: Optional[str] = None plan_type: Optional[str] = None
is_2fa_enabled: Optional[bool] = None is_2fa_enabled: Optional[bool] = None
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
class TokenData(BaseModel): class TokenData(BaseModel):
username: str | None = None username: str | None = None
two_factor_verified: bool = False two_factor_verified: bool = False
class FolderCreate(BaseModel): class FolderCreate(BaseModel):
name: str name: str
parent_id: Optional[int] = None parent_id: Optional[int] = None
class FolderUpdate(BaseModel): class FolderUpdate(BaseModel):
name: Optional[str] = None name: Optional[str] = None
parent_id: Optional[int] = None parent_id: Optional[int] = None
class ShareCreate(BaseModel): class ShareCreate(BaseModel):
file_id: Optional[int] = None file_id: Optional[int] = None
folder_id: Optional[int] = None folder_id: Optional[int] = None
@ -49,9 +59,11 @@ class ShareCreate(BaseModel):
permission_level: str = "viewer" permission_level: str = "viewer"
invite_email: Optional[EmailStr] = None invite_email: Optional[EmailStr] = None
class TeamCreate(BaseModel): class TeamCreate(BaseModel):
name: str name: str
class TeamOut(BaseModel): class TeamOut(BaseModel):
id: int id: int
name: str name: str
@ -60,6 +72,7 @@ class TeamOut(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class ActivityOut(BaseModel): class ActivityOut(BaseModel):
id: int id: int
user_id: Optional[int] = None user_id: Optional[int] = None
@ -71,12 +84,14 @@ class ActivityOut(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class FileRequestCreate(BaseModel): class FileRequestCreate(BaseModel):
title: str title: str
description: Optional[str] = None description: Optional[str] = None
target_folder_id: int target_folder_id: int
expires_at: Optional[datetime] = None expires_at: Optional[datetime] = None
class FileRequestOut(BaseModel): class FileRequestOut(BaseModel):
id: int id: int
title: str title: str
@ -90,25 +105,28 @@ class FileRequestOut(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
from mywebdav.models import Folder, File, Share, FileVersion
FolderOut = pydantic_model_creator(Folder, name="FolderOut") FolderOut = pydantic_model_creator(Folder, name="FolderOut")
FileOut = pydantic_model_creator(File, name="FileOut") FileOut = pydantic_model_creator(File, name="FileOut")
ShareOut = pydantic_model_creator(Share, name="ShareOut") ShareOut = pydantic_model_creator(Share, name="ShareOut")
FileVersionOut = pydantic_model_creator(FileVersion, name="FileVersionOut") FileVersionOut = pydantic_model_creator(FileVersion, name="FileVersionOut")
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
code: int code: int
message: str message: str
details: Optional[str] = None details: Optional[str] = None
class BatchFileOperation(BaseModel): class BatchFileOperation(BaseModel):
file_ids: List[int] file_ids: List[int]
operation: str # e.g., "delete", "move", "copy", "star", "unstar" operation: str # e.g., "delete", "move", "copy", "star", "unstar"
class BatchFolderOperation(BaseModel): class BatchFolderOperation(BaseModel):
folder_ids: List[int] folder_ids: List[int]
operation: str # e.g., "delete", "move", "star", "unstar" operation: str # e.g., "delete", "move", "star", "unstar"
class BatchMoveCopyPayload(BaseModel): class BatchMoveCopyPayload(BaseModel):
target_folder_id: Optional[int] = None target_folder_id: Optional[int] = None

View File

@ -2,8 +2,9 @@ import os
import sys import sys
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file='.env', extra='ignore') model_config = SettingsConfigDict(env_file=".env", extra="ignore")
DATABASE_URL: str = "sqlite:///app/mywebdav.db" DATABASE_URL: str = "sqlite:///app/mywebdav.db"
REDIS_URL: str = "redis://redis:6379/0" REDIS_URL: str = "redis://redis:6379/0"
@ -30,8 +31,14 @@ class Settings(BaseSettings):
STRIPE_WEBHOOK_SECRET: str = "" STRIPE_WEBHOOK_SECRET: str = ""
BILLING_ENABLED: bool = False BILLING_ENABLED: bool = False
settings = Settings() settings = Settings()
if settings.SECRET_KEY == "super_secret_key" and os.getenv("ENVIRONMENT") == "production": if (
print("ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable.") settings.SECRET_KEY == "super_secret_key"
and os.getenv("ENVIRONMENT") == "production"
):
print(
"ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable."
)
sys.exit(1) sys.exit(1)

View File

@ -5,6 +5,7 @@ from typing import AsyncGenerator
from .settings import settings from .settings import settings
class StorageManager: class StorageManager:
def __init__(self, base_path: str = settings.STORAGE_PATH): def __init__(self, base_path: str = settings.STORAGE_PATH):
self.base_path = Path(base_path) self.base_path = Path(base_path)
@ -12,7 +13,11 @@ class StorageManager:
async def _get_full_path(self, user_id: int, file_path: str) -> Path: async def _get_full_path(self, user_id: int, file_path: str) -> Path:
# Ensure file_path is relative and safe # Ensure file_path is relative and safe
relative_path = Path(file_path).relative_to('/') if str(file_path).startswith('/') else Path(file_path) relative_path = (
Path(file_path).relative_to("/")
if str(file_path).startswith("/")
else Path(file_path)
)
full_path = self.base_path / str(user_id) / relative_path full_path = self.base_path / str(user_id) / relative_path
full_path.parent.mkdir(parents=True, exist_ok=True) full_path.parent.mkdir(parents=True, exist_ok=True)
return full_path return full_path
@ -52,4 +57,5 @@ class StorageManager:
full_path = await self._get_full_path(user_id, file_path) full_path = await self._get_full_path(user_id, file_path)
return full_path.exists() return full_path.exists()
storage_manager = StorageManager() storage_manager = StorageManager()

View File

@ -1,4 +1,3 @@
import os
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from PIL import Image from PIL import Image
@ -9,7 +8,10 @@ from .settings import settings
THUMBNAIL_SIZE = (300, 300) THUMBNAIL_SIZE = (300, 300)
THUMBNAIL_DIR = "thumbnails" THUMBNAIL_DIR = "thumbnails"
async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Optional[str]:
async def generate_thumbnail(
file_path: str, mime_type: str, user_id: int
) -> Optional[str]:
try: try:
if mime_type.startswith("image/"): if mime_type.startswith("image/"):
return await generate_image_thumbnail(file_path, user_id) return await generate_image_thumbnail(file_path, user_id)
@ -20,6 +22,7 @@ async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Op
print(f"Error generating thumbnail for {file_path}: {e}") print(f"Error generating thumbnail for {file_path}: {e}")
return None return None
async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str]: async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str]:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -30,12 +33,16 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
file_name = Path(file_path).name file_name = Path(file_path).name
thumbnail_name = f"thumb_{file_name}" thumbnail_name = f"thumb_{file_name}"
if not thumbnail_name.lower().endswith(('.jpg', '.jpeg', '.png')): if not thumbnail_name.lower().endswith((".jpg", ".jpeg", ".png")):
thumbnail_name += ".jpg" thumbnail_name += ".jpg"
thumbnail_path = thumbnail_dir / thumbnail_name thumbnail_path = thumbnail_dir / thumbnail_name
actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path) actual_file_path = (
base_path / str(user_id) / file_path
if not Path(file_path).is_absolute()
else Path(file_path)
)
with Image.open(actual_file_path) as img: with Image.open(actual_file_path) as img:
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS) img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
@ -44,7 +51,9 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
background = Image.new("RGB", img.size, (255, 255, 255)) background = Image.new("RGB", img.size, (255, 255, 255))
if img.mode == "P": if img.mode == "P":
img = img.convert("RGBA") img = img.convert("RGBA")
background.paste(img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None) background.paste(
img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None
)
img = background img = background
img.save(str(thumbnail_path), "JPEG", quality=85, optimize=True) img.save(str(thumbnail_path), "JPEG", quality=85, optimize=True)
@ -53,6 +62,7 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
return await loop.run_in_executor(None, _generate) return await loop.run_in_executor(None, _generate)
async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str]: async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str]:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -65,17 +75,29 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
thumbnail_name = f"thumb_{file_name}.jpg" thumbnail_name = f"thumb_{file_name}.jpg"
thumbnail_path = thumbnail_dir / thumbnail_name thumbnail_path = thumbnail_dir / thumbnail_name
actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path) actual_file_path = (
base_path / str(user_id) / file_path
if not Path(file_path).is_absolute()
else Path(file_path)
)
subprocess.run([ subprocess.run(
"ffmpeg", [
"-i", str(actual_file_path), "ffmpeg",
"-ss", "00:00:01", "-i",
"-vframes", "1", str(actual_file_path),
"-vf", f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease", "-ss",
"-y", "00:00:01",
str(thumbnail_path) "-vframes",
], check=True, capture_output=True) "1",
"-vf",
f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease",
"-y",
str(thumbnail_path),
],
check=True,
capture_output=True,
)
return str(thumbnail_path.relative_to(base_path / str(user_id))) return str(thumbnail_path.relative_to(base_path / str(user_id)))
@ -84,6 +106,7 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
return None return None
async def delete_thumbnail(thumbnail_path: str, user_id: int): async def delete_thumbnail(thumbnail_path: str, user_id: int):
try: try:
base_path = Path(settings.STORAGE_PATH) base_path = Path(settings.STORAGE_PATH)

View File

@ -5,15 +5,20 @@ import base64
import secrets import secrets
import hashlib import hashlib
from typing import List, Optional from typing import List
def generate_totp_secret() -> str: def generate_totp_secret() -> str:
"""Generates a random base32 TOTP secret.""" """Generates a random base32 TOTP secret."""
return pyotp.random_base32() return pyotp.random_base32()
def generate_totp_uri(secret: str, account_name: str, issuer_name: str) -> str: def generate_totp_uri(secret: str, account_name: str, issuer_name: str) -> str:
"""Generates a Google Authenticator-compatible TOTP URI.""" """Generates a Google Authenticator-compatible TOTP URI."""
return pyotp.totp.TOTP(secret).provisioning_uri(name=account_name, issuer_name=issuer_name) return pyotp.totp.TOTP(secret).provisioning_uri(
name=account_name, issuer_name=issuer_name
)
def generate_qr_code_base64(uri: str) -> str: def generate_qr_code_base64(uri: str) -> str:
"""Generates a base64 encoded QR code image for a given URI.""" """Generates a base64 encoded QR code image for a given URI."""
@ -31,27 +36,33 @@ def generate_qr_code_base64(uri: str) -> str:
img.save(buffered, format="PNG") img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8") return base64.b64encode(buffered.getvalue()).decode("utf-8")
def verify_totp_code(secret: str, code: str) -> bool: def verify_totp_code(secret: str, code: str) -> bool:
"""Verifies a TOTP code against a secret.""" """Verifies a TOTP code against a secret."""
totp = pyotp.TOTP(secret) totp = pyotp.TOTP(secret)
return totp.verify(code) return totp.verify(code)
def generate_recovery_codes(num_codes: int = 10) -> List[str]: def generate_recovery_codes(num_codes: int = 10) -> List[str]:
"""Generates a list of random recovery codes.""" """Generates a list of random recovery codes."""
return [secrets.token_urlsafe(16) for _ in range(num_codes)] return [secrets.token_urlsafe(16) for _ in range(num_codes)]
def hash_recovery_code(code: str) -> str: def hash_recovery_code(code: str) -> str:
"""Hashes a single recovery code using SHA256.""" """Hashes a single recovery code using SHA256."""
return hashlib.sha256(code.encode('utf-8')).hexdigest() return hashlib.sha256(code.encode("utf-8")).hexdigest()
def verify_recovery_code(plain_code: str, hashed_code: str) -> bool: def verify_recovery_code(plain_code: str, hashed_code: str) -> bool:
"""Verifies a plain recovery code against its hashed version.""" """Verifies a plain recovery code against its hashed version."""
return hash_recovery_code(plain_code) == hashed_code return hash_recovery_code(plain_code) == hashed_code
def hash_recovery_codes(codes: List[str]) -> List[str]: def hash_recovery_codes(codes: List[str]) -> List[str]:
"""Hashes a list of recovery codes.""" """Hashes a list of recovery codes."""
return [hash_recovery_code(code) for code in codes] return [hash_recovery_code(code) for code in codes]
def verify_recovery_codes(plain_code: str, hashed_codes: List[str]) -> bool: def verify_recovery_codes(plain_code: str, hashed_codes: List[str]) -> bool:
"""Verifies if a plain recovery code matches any of the hashed recovery codes.""" """Verifies if a plain recovery code matches any of the hashed recovery codes."""
for hashed_code in hashed_codes: for hashed_code in hashed_codes:

View File

@ -19,6 +19,7 @@ router = APIRouter(
tags=["webdav"], tags=["webdav"],
) )
class WebDAVLock: class WebDAVLock:
locks = {} locks = {}
@ -26,10 +27,10 @@ class WebDAVLock:
def create_lock(cls, path: str, user_id: int, timeout: int = 3600): def create_lock(cls, path: str, user_id: int, timeout: int = 3600):
lock_token = f"opaquelocktoken:{hashlib.md5(f'{path}{user_id}{datetime.now()}'.encode()).hexdigest()}" lock_token = f"opaquelocktoken:{hashlib.md5(f'{path}{user_id}{datetime.now()}'.encode()).hexdigest()}"
cls.locks[path] = { cls.locks[path] = {
'token': lock_token, "token": lock_token,
'user_id': user_id, "user_id": user_id,
'created_at': datetime.now(), "created_at": datetime.now(),
'timeout': timeout "timeout": timeout,
} }
return lock_token return lock_token
@ -42,17 +43,18 @@ class WebDAVLock:
if path in cls.locks: if path in cls.locks:
del cls.locks[path] del cls.locks[path]
async def basic_auth(authorization: Optional[str] = Header(None)): async def basic_auth(authorization: Optional[str] = Header(None)):
if not authorization: if not authorization:
return None return None
try: try:
scheme, credentials = authorization.split() scheme, credentials = authorization.split()
if scheme.lower() != 'basic': if scheme.lower() != "basic":
return None return None
decoded = base64.b64decode(credentials).decode('utf-8') decoded = base64.b64decode(credentials).decode("utf-8")
username, password = decoded.split(':', 1) username, password = decoded.split(":", 1)
user = await User.get_or_none(username=username) user = await User.get_or_none(username=username)
if user and verify_password(password, user.hashed_password): if user and verify_password(password, user.hashed_password):
@ -62,6 +64,7 @@ async def basic_auth(authorization: Optional[str] = Header(None)):
return None return None
async def webdav_auth(request: Request, authorization: Optional[str] = Header(None)): async def webdav_auth(request: Request, authorization: Optional[str] = Header(None)):
user = await basic_auth(authorization) user = await basic_auth(authorization)
if user: if user:
@ -75,14 +78,15 @@ async def webdav_auth(request: Request, authorization: Optional[str] = Header(No
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
headers={'WWW-Authenticate': 'Basic realm="MyWebdav WebDAV"'} headers={"WWW-Authenticate": 'Basic realm="MyWebdav WebDAV"'},
) )
async def resolve_path(path_str: str, user: User): async def resolve_path(path_str: str, user: User):
if not path_str or path_str == '/': if not path_str or path_str == "/":
return None, None, True return None, None, True
parts = [p for p in path_str.split('/') if p] parts = [p for p in path_str.split("/") if p]
if not parts: if not parts:
return None, None, True return None, None, True
@ -90,10 +94,7 @@ async def resolve_path(path_str: str, user: User):
current_folder = None current_folder = None
for i, part in enumerate(parts[:-1]): for i, part in enumerate(parts[:-1]):
folder = await Folder.get_or_none( folder = await Folder.get_or_none(
name=part, name=part, parent=current_folder, owner=user, is_deleted=False
parent=current_folder,
owner=user,
is_deleted=False
) )
if not folder: if not folder:
return None, None, False return None, None, False
@ -102,39 +103,37 @@ async def resolve_path(path_str: str, user: User):
last_part = parts[-1] last_part = parts[-1]
folder = await Folder.get_or_none( folder = await Folder.get_or_none(
name=last_part, name=last_part, parent=current_folder, owner=user, is_deleted=False
parent=current_folder,
owner=user,
is_deleted=False
) )
if folder: if folder:
return folder, current_folder, True return folder, current_folder, True
file = await File.get_or_none( file = await File.get_or_none(
name=last_part, name=last_part, parent=current_folder, owner=user, is_deleted=False
parent=current_folder,
owner=user,
is_deleted=False
) )
if file: if file:
return file, current_folder, True return file, current_folder, True
return None, current_folder, True return None, current_folder, True
def build_href(base_path: str, name: str, is_collection: bool): def build_href(base_path: str, name: str, is_collection: bool):
path = f"{base_path.rstrip('/')}/{name}" path = f"{base_path.rstrip('/')}/{name}"
if is_collection: if is_collection:
path += '/' path += "/"
return path return path
async def get_custom_properties(resource_type: str, resource_id: int): async def get_custom_properties(resource_type: str, resource_id: int):
props = await WebDAVProperty.filter( props = await WebDAVProperty.filter(
resource_type=resource_type, resource_type=resource_type, resource_id=resource_id
resource_id=resource_id
) )
return {(prop.namespace, prop.name): prop.value for prop in props} return {(prop.namespace, prop.name): prop.value for prop in props}
def create_propstat_element(props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"):
def create_propstat_element(
props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"
):
propstat = ET.Element("D:propstat") propstat = ET.Element("D:propstat")
prop = ET.SubElement(propstat, "D:prop") prop = ET.SubElement(propstat, "D:prop")
@ -151,10 +150,10 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str
elem.text = value elem.text = value
elif key == "getlastmodified": elif key == "getlastmodified":
elem = ET.SubElement(prop, f"D:{key}") elem = ET.SubElement(prop, f"D:{key}")
elem.text = value.strftime('%a, %d %b %Y %H:%M:%S GMT') elem.text = value.strftime("%a, %d %b %Y %H:%M:%S GMT")
elif key == "creationdate": elif key == "creationdate":
elem = ET.SubElement(prop, f"D:{key}") elem = ET.SubElement(prop, f"D:{key}")
elem.text = value.isoformat() + 'Z' elem.text = value.isoformat() + "Z"
elif key == "displayname": elif key == "displayname":
elem = ET.SubElement(prop, f"D:{key}") elem = ET.SubElement(prop, f"D:{key}")
elem.text = value elem.text = value
@ -174,6 +173,7 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str
return propstat return propstat
def parse_propfind_body(body: bytes): def parse_propfind_body(body: bytes):
if not body: if not body:
return None return None
@ -193,8 +193,8 @@ def parse_propfind_body(body: bytes):
if prop is not None: if prop is not None:
requested_props = [] requested_props = []
for child in prop: for child in prop:
ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:" ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
name = child.tag.split('}')[1] if '}' in child.tag else child.tag name = child.tag.split("}")[1] if "}" in child.tag else child.tag
requested_props.append((ns, name)) requested_props.append((ns, name))
return requested_props return requested_props
except ET.ParseError: except ET.ParseError:
@ -202,6 +202,7 @@ def parse_propfind_body(body: bytes):
return None return None
@router.api_route("/{full_path:path}", methods=["OPTIONS"]) @router.api_route("/{full_path:path}", methods=["OPTIONS"])
async def webdav_options(full_path: str): async def webdav_options(full_path: str):
return Response( return Response(
@ -209,14 +210,17 @@ async def webdav_options(full_path: str):
headers={ headers={
"DAV": "1, 2", "DAV": "1, 2",
"Allow": "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK", "Allow": "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK",
"MS-Author-Via": "DAV" "MS-Author-Via": "DAV",
} },
) )
@router.api_route("/{full_path:path}", methods=["PROPFIND"]) @router.api_route("/{full_path:path}", methods=["PROPFIND"])
async def handle_propfind(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_propfind(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
depth = request.headers.get("Depth", "1") depth = request.headers.get("Depth", "1")
full_path = unquote(full_path).strip('/') full_path = unquote(full_path).strip("/")
body = await request.body() body = await request.body()
requested_props = parse_propfind_body(body) requested_props = parse_propfind_body(body)
@ -232,19 +236,23 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
if resource is None: if resource is None:
response = ET.SubElement(multistatus, "D:response") response = ET.SubElement(multistatus, "D:response")
href = ET.SubElement(response, "D:href") href = ET.SubElement(response, "D:href")
href.text = base_href if base_href.endswith('/') else base_href + '/' href.text = base_href if base_href.endswith("/") else base_href + "/"
props = { props = {
"resourcetype": "collection", "resourcetype": "collection",
"displayname": full_path.split('/')[-1] if full_path else "Root", "displayname": full_path.split("/")[-1] if full_path else "Root",
"creationdate": datetime.now(), "creationdate": datetime.now(),
"getlastmodified": datetime.now() "getlastmodified": datetime.now(),
} }
response.append(create_propstat_element(props)) response.append(create_propstat_element(props))
if depth in ["1", "infinity"]: if depth in ["1", "infinity"]:
folders = await Folder.filter(owner=current_user, parent=parent, is_deleted=False) folders = await Folder.filter(
files = await File.filter(owner=current_user, parent=parent, is_deleted=False) owner=current_user, parent=parent, is_deleted=False
)
files = await File.filter(
owner=current_user, parent=parent, is_deleted=False
)
for folder in folders: for folder in folders:
response = ET.SubElement(multistatus, "D:response") response = ET.SubElement(multistatus, "D:response")
@ -255,9 +263,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"resourcetype": "collection", "resourcetype": "collection",
"displayname": folder.name, "displayname": folder.name,
"creationdate": folder.created_at, "creationdate": folder.created_at,
"getlastmodified": folder.updated_at "getlastmodified": folder.updated_at,
} }
custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None custom_props = (
await get_custom_properties("folder", folder.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
for file in files: for file in files:
@ -272,28 +288,48 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"getcontenttype": file.mime_type, "getcontenttype": file.mime_type,
"creationdate": file.created_at, "creationdate": file.created_at,
"getlastmodified": file.updated_at, "getlastmodified": file.updated_at,
"getetag": f'"{file.file_hash}"' "getetag": f'"{file.file_hash}"',
} }
custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None custom_props = (
await get_custom_properties("file", file.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
elif isinstance(resource, Folder): elif isinstance(resource, Folder):
response = ET.SubElement(multistatus, "D:response") response = ET.SubElement(multistatus, "D:response")
href = ET.SubElement(response, "D:href") href = ET.SubElement(response, "D:href")
href.text = base_href if base_href.endswith('/') else base_href + '/' href.text = base_href if base_href.endswith("/") else base_href + "/"
props = { props = {
"resourcetype": "collection", "resourcetype": "collection",
"displayname": resource.name, "displayname": resource.name,
"creationdate": resource.created_at, "creationdate": resource.created_at,
"getlastmodified": resource.updated_at "getlastmodified": resource.updated_at,
} }
custom_props = await get_custom_properties("folder", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None custom_props = (
await get_custom_properties("folder", resource.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
if depth in ["1", "infinity"]: if depth in ["1", "infinity"]:
folders = await Folder.filter(owner=current_user, parent=resource, is_deleted=False) folders = await Folder.filter(
files = await File.filter(owner=current_user, parent=resource, is_deleted=False) owner=current_user, parent=resource, is_deleted=False
)
files = await File.filter(
owner=current_user, parent=resource, is_deleted=False
)
for folder in folders: for folder in folders:
response = ET.SubElement(multistatus, "D:response") response = ET.SubElement(multistatus, "D:response")
@ -304,9 +340,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"resourcetype": "collection", "resourcetype": "collection",
"displayname": folder.name, "displayname": folder.name,
"creationdate": folder.created_at, "creationdate": folder.created_at,
"getlastmodified": folder.updated_at "getlastmodified": folder.updated_at,
} }
custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None custom_props = (
await get_custom_properties("folder", folder.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
for file in files: for file in files:
@ -321,9 +365,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"getcontenttype": file.mime_type, "getcontenttype": file.mime_type,
"creationdate": file.created_at, "creationdate": file.created_at,
"getlastmodified": file.updated_at, "getlastmodified": file.updated_at,
"getetag": f'"{file.file_hash}"' "getetag": f'"{file.file_hash}"',
} }
custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None custom_props = (
await get_custom_properties("file", file.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
elif isinstance(resource, File): elif isinstance(resource, File):
@ -338,17 +390,32 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"getcontenttype": resource.mime_type, "getcontenttype": resource.mime_type,
"creationdate": resource.created_at, "creationdate": resource.created_at,
"getlastmodified": resource.updated_at, "getlastmodified": resource.updated_at,
"getetag": f'"{resource.file_hash}"' "getetag": f'"{resource.file_hash}"',
} }
custom_props = await get_custom_properties("file", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None custom_props = (
await get_custom_properties("file", resource.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True) xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207) return Response(
content=xml_content,
media_type="application/xml; charset=utf-8",
status_code=207,
)
@router.api_route("/{full_path:path}", methods=["GET", "HEAD"]) @router.api_route("/{full_path:path}", methods=["GET", "HEAD"])
async def handle_get(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_get(
full_path = unquote(full_path).strip('/') request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
resource, parent, exists = await resolve_path(full_path, current_user) resource, parent, exists = await resolve_path(full_path, current_user)
@ -363,8 +430,10 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe
"Content-Length": str(resource.size), "Content-Length": str(resource.size),
"Content-Type": resource.mime_type, "Content-Type": resource.mime_type,
"ETag": f'"{resource.file_hash}"', "ETag": f'"{resource.file_hash}"',
"Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT') "Last-Modified": resource.updated_at.strftime(
} "%a, %d %b %Y %H:%M:%S GMT"
),
},
) )
async def file_iterator(): async def file_iterator():
@ -377,23 +446,28 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe
headers={ headers={
"Content-Disposition": f'attachment; filename="{resource.name}"', "Content-Disposition": f'attachment; filename="{resource.name}"',
"ETag": f'"{resource.file_hash}"', "ETag": f'"{resource.file_hash}"',
"Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT') "Last-Modified": resource.updated_at.strftime(
} "%a, %d %b %Y %H:%M:%S GMT"
),
},
) )
except FileNotFoundError: except FileNotFoundError:
raise HTTPException(status_code=404, detail="File not found in storage") raise HTTPException(status_code=404, detail="File not found in storage")
@router.api_route("/{full_path:path}", methods=["PUT"]) @router.api_route("/{full_path:path}", methods=["PUT"])
async def handle_put(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_put(
full_path = unquote(full_path).strip('/') request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
if not full_path: if not full_path:
raise HTTPException(status_code=400, detail="Cannot PUT to root") raise HTTPException(status_code=400, detail="Cannot PUT to root")
parts = [p for p in full_path.split('/') if p] parts = [p for p in full_path.split("/") if p]
file_name = parts[-1] file_name = parts[-1]
parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else '' parent_path = "/".join(parts[:-1]) if len(parts) > 1 else ""
_, parent_folder, exists = await resolve_path(parent_path, current_user) _, parent_folder, exists = await resolve_path(parent_path, current_user)
if not exists: if not exists:
@ -417,10 +491,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
mime_type = "application/octet-stream" mime_type = "application/octet-stream"
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=file_name, name=file_name, parent=parent_folder, owner=current_user, is_deleted=False
parent=parent_folder,
owner=current_user,
is_deleted=False
) )
if existing_file: if existing_file:
@ -432,7 +503,9 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
existing_file.updated_at = datetime.now() existing_file.updated_at = datetime.now()
await existing_file.save() await existing_file.save()
current_user.used_storage_bytes = current_user.used_storage_bytes - old_size + file_size current_user.used_storage_bytes = (
current_user.used_storage_bytes - old_size + file_size
)
await current_user.save() await current_user.save()
await log_activity(current_user, "file_updated", "file", existing_file.id) await log_activity(current_user, "file_updated", "file", existing_file.id)
@ -445,7 +518,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
mime_type=mime_type, mime_type=mime_type,
file_hash=file_hash, file_hash=file_hash,
owner=current_user, owner=current_user,
parent=parent_folder parent=parent_folder,
) )
current_user.used_storage_bytes += file_size current_user.used_storage_bytes += file_size
@ -454,9 +527,12 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
await log_activity(current_user, "file_created", "file", db_file.id) await log_activity(current_user, "file_created", "file", db_file.id)
return Response(status_code=201) return Response(status_code=201)
@router.api_route("/{full_path:path}", methods=["DELETE"]) @router.api_route("/{full_path:path}", methods=["DELETE"])
async def handle_delete(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_delete(
full_path = unquote(full_path).strip('/') request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
if not full_path: if not full_path:
raise HTTPException(status_code=400, detail="Cannot DELETE root") raise HTTPException(status_code=400, detail="Cannot DELETE root")
@ -478,44 +554,45 @@ async def handle_delete(request: Request, full_path: str, current_user: User = D
return Response(status_code=204) return Response(status_code=204)
@router.api_route("/{full_path:path}", methods=["MKCOL"]) @router.api_route("/{full_path:path}", methods=["MKCOL"])
async def handle_mkcol(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_mkcol(
full_path = unquote(full_path).strip('/') request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
if not full_path: if not full_path:
raise HTTPException(status_code=400, detail="Cannot MKCOL at root") raise HTTPException(status_code=400, detail="Cannot MKCOL at root")
parts = [p for p in full_path.split('/') if p] parts = [p for p in full_path.split("/") if p]
folder_name = parts[-1] folder_name = parts[-1]
parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else '' parent_path = "/".join(parts[:-1]) if len(parts) > 1 else ""
_, parent_folder, exists = await resolve_path(parent_path, current_user) _, parent_folder, exists = await resolve_path(parent_path, current_user)
if not exists: if not exists:
raise HTTPException(status_code=409, detail="Parent folder does not exist") raise HTTPException(status_code=409, detail="Parent folder does not exist")
existing = await Folder.get_or_none( existing = await Folder.get_or_none(
name=folder_name, name=folder_name, parent=parent_folder, owner=current_user, is_deleted=False
parent=parent_folder,
owner=current_user,
is_deleted=False
) )
if existing: if existing:
raise HTTPException(status_code=405, detail="Folder already exists") raise HTTPException(status_code=405, detail="Folder already exists")
folder = await Folder.create( folder = await Folder.create(
name=folder_name, name=folder_name, parent=parent_folder, owner=current_user
parent=parent_folder,
owner=current_user
) )
await log_activity(current_user, "folder_created", "folder", folder.id) await log_activity(current_user, "folder_created", "folder", folder.id)
return Response(status_code=201) return Response(status_code=201)
@router.api_route("/{full_path:path}", methods=["COPY"]) @router.api_route("/{full_path:path}", methods=["COPY"])
async def handle_copy(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_copy(
full_path = unquote(full_path).strip('/') request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
destination = request.headers.get("Destination") destination = request.headers.get("Destination")
overwrite = request.headers.get("Overwrite", "T") overwrite = request.headers.get("Overwrite", "T")
@ -523,7 +600,7 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
raise HTTPException(status_code=400, detail="Destination header required") raise HTTPException(status_code=400, detail="Destination header required")
dest_path = unquote(urlparse(destination).path) dest_path = unquote(urlparse(destination).path)
dest_path = dest_path.replace('/webdav/', '').strip('/') dest_path = dest_path.replace("/webdav/", "").strip("/")
source_resource, _, exists = await resolve_path(full_path, current_user) source_resource, _, exists = await resolve_path(full_path, current_user)
@ -533,9 +610,9 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
if not isinstance(source_resource, File): if not isinstance(source_resource, File):
raise HTTPException(status_code=501, detail="Only file copy is implemented") raise HTTPException(status_code=501, detail="Only file copy is implemented")
dest_parts = [p for p in dest_path.split('/') if p] dest_parts = [p for p in dest_path.split("/") if p]
dest_name = dest_parts[-1] dest_name = dest_parts[-1]
dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else '' dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else ""
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user) _, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
@ -543,14 +620,13 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
raise HTTPException(status_code=409, detail="Destination parent does not exist") raise HTTPException(status_code=409, detail="Destination parent does not exist")
existing_dest = await File.get_or_none( existing_dest = await File.get_or_none(
name=dest_name, name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
parent=dest_parent,
owner=current_user,
is_deleted=False
) )
if existing_dest and overwrite == "F": if existing_dest and overwrite == "F":
raise HTTPException(status_code=412, detail="Destination exists and overwrite is false") raise HTTPException(
status_code=412, detail="Destination exists and overwrite is false"
)
if existing_dest: if existing_dest:
await existing_dest.delete() await existing_dest.delete()
@ -562,15 +638,18 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
mime_type=source_resource.mime_type, mime_type=source_resource.mime_type,
file_hash=source_resource.file_hash, file_hash=source_resource.file_hash,
owner=current_user, owner=current_user,
parent=dest_parent parent=dest_parent,
) )
await log_activity(current_user, "file_copied", "file", new_file.id) await log_activity(current_user, "file_copied", "file", new_file.id)
return Response(status_code=201 if not existing_dest else 204) return Response(status_code=201 if not existing_dest else 204)
@router.api_route("/{full_path:path}", methods=["MOVE"]) @router.api_route("/{full_path:path}", methods=["MOVE"])
async def handle_move(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_move(
full_path = unquote(full_path).strip('/') request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
destination = request.headers.get("Destination") destination = request.headers.get("Destination")
overwrite = request.headers.get("Overwrite", "T") overwrite = request.headers.get("Overwrite", "T")
@ -578,16 +657,16 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
raise HTTPException(status_code=400, detail="Destination header required") raise HTTPException(status_code=400, detail="Destination header required")
dest_path = unquote(urlparse(destination).path) dest_path = unquote(urlparse(destination).path)
dest_path = dest_path.replace('/webdav/', '').strip('/') dest_path = dest_path.replace("/webdav/", "").strip("/")
source_resource, _, exists = await resolve_path(full_path, current_user) source_resource, _, exists = await resolve_path(full_path, current_user)
if not source_resource: if not source_resource:
raise HTTPException(status_code=404, detail="Source not found") raise HTTPException(status_code=404, detail="Source not found")
dest_parts = [p for p in dest_path.split('/') if p] dest_parts = [p for p in dest_path.split("/") if p]
dest_name = dest_parts[-1] dest_name = dest_parts[-1]
dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else '' dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else ""
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user) _, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
@ -596,14 +675,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
if isinstance(source_resource, File): if isinstance(source_resource, File):
existing_dest = await File.get_or_none( existing_dest = await File.get_or_none(
name=dest_name, name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
parent=dest_parent,
owner=current_user,
is_deleted=False
) )
if existing_dest and overwrite == "F": if existing_dest and overwrite == "F":
raise HTTPException(status_code=412, detail="Destination exists and overwrite is false") raise HTTPException(
status_code=412, detail="Destination exists and overwrite is false"
)
if existing_dest: if existing_dest:
await existing_dest.delete() await existing_dest.delete()
@ -617,14 +695,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
elif isinstance(source_resource, Folder): elif isinstance(source_resource, Folder):
existing_dest = await Folder.get_or_none( existing_dest = await Folder.get_or_none(
name=dest_name, name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
parent=dest_parent,
owner=current_user,
is_deleted=False
) )
if existing_dest and overwrite == "F": if existing_dest and overwrite == "F":
raise HTTPException(status_code=412, detail="Destination exists and overwrite is false") raise HTTPException(
status_code=412, detail="Destination exists and overwrite is false"
)
if existing_dest: if existing_dest:
existing_dest.is_deleted = True existing_dest.is_deleted = True
@ -637,9 +714,12 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
await log_activity(current_user, "folder_moved", "folder", source_resource.id) await log_activity(current_user, "folder_moved", "folder", source_resource.id)
return Response(status_code=201 if not existing_dest else 204) return Response(status_code=201 if not existing_dest else 204)
@router.api_route("/{full_path:path}", methods=["LOCK"]) @router.api_route("/{full_path:path}", methods=["LOCK"])
async def handle_lock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_lock(
full_path = unquote(full_path).strip('/') request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
timeout_header = request.headers.get("Timeout", "Second-3600") timeout_header = request.headers.get("Timeout", "Second-3600")
timeout = 3600 timeout = 3600
@ -681,32 +761,38 @@ async def handle_lock(request: Request, full_path: str, current_user: User = Dep
content=xml_content, content=xml_content,
media_type="application/xml; charset=utf-8", media_type="application/xml; charset=utf-8",
status_code=200, status_code=200,
headers={"Lock-Token": f"<{lock_token}>"} headers={"Lock-Token": f"<{lock_token}>"},
) )
@router.api_route("/{full_path:path}", methods=["UNLOCK"]) @router.api_route("/{full_path:path}", methods=["UNLOCK"])
async def handle_unlock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_unlock(
full_path = unquote(full_path).strip('/') request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
lock_token_header = request.headers.get("Lock-Token") lock_token_header = request.headers.get("Lock-Token")
if not lock_token_header: if not lock_token_header:
raise HTTPException(status_code=400, detail="Lock-Token header required") raise HTTPException(status_code=400, detail="Lock-Token header required")
lock_token = lock_token_header.strip('<>') lock_token = lock_token_header.strip("<>")
existing_lock = WebDAVLock.get_lock(full_path) existing_lock = WebDAVLock.get_lock(full_path)
if not existing_lock or existing_lock['token'] != lock_token: if not existing_lock or existing_lock["token"] != lock_token:
raise HTTPException(status_code=409, detail="Invalid lock token") raise HTTPException(status_code=409, detail="Invalid lock token")
if existing_lock['user_id'] != current_user.id: if existing_lock["user_id"] != current_user.id:
raise HTTPException(status_code=403, detail="Not lock owner") raise HTTPException(status_code=403, detail="Not lock owner")
WebDAVLock.remove_lock(full_path) WebDAVLock.remove_lock(full_path)
return Response(status_code=204) return Response(status_code=204)
@router.api_route("/{full_path:path}", methods=["PROPPATCH"]) @router.api_route("/{full_path:path}", methods=["PROPPATCH"])
async def handle_proppatch(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): async def handle_proppatch(
full_path = unquote(full_path).strip('/') request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
resource, parent, exists = await resolve_path(full_path, current_user) resource, parent, exists = await resolve_path(full_path, current_user)
@ -719,7 +805,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
try: try:
root = ET.fromstring(body) root = ET.fromstring(body)
except: except Exception:
raise HTTPException(status_code=400, detail="Invalid XML") raise HTTPException(status_code=400, detail="Invalid XML")
resource_type = "file" if isinstance(resource, File) else "folder" resource_type = "file" if isinstance(resource, File) else "folder"
@ -734,13 +820,19 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
prop_element = set_element.find(".//{DAV:}prop") prop_element = set_element.find(".//{DAV:}prop")
if prop_element is not None: if prop_element is not None:
for child in prop_element: for child in prop_element:
ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:" ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
name = child.tag.split('}')[1] if '}' in child.tag else child.tag name = child.tag.split("}")[1] if "}" in child.tag else child.tag
value = child.text or "" value = child.text or ""
if ns == "DAV:": if ns == "DAV:":
live_props = ["creationdate", "getcontentlength", "getcontenttype", live_props = [
"getetag", "getlastmodified", "resourcetype"] "creationdate",
"getcontentlength",
"getcontenttype",
"getetag",
"getlastmodified",
"resourcetype",
]
if name in live_props: if name in live_props:
failed_props.append((ns, name, "409 Conflict")) failed_props.append((ns, name, "409 Conflict"))
continue continue
@ -750,7 +842,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,
namespace=ns, namespace=ns,
name=name name=name,
) )
if existing_prop: if existing_prop:
@ -762,10 +854,10 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
resource_id=resource_id, resource_id=resource_id,
namespace=ns, namespace=ns,
name=name, name=name,
value=value value=value,
) )
set_props.append((ns, name)) set_props.append((ns, name))
except Exception as e: except Exception:
failed_props.append((ns, name, "500 Internal Server Error")) failed_props.append((ns, name, "500 Internal Server Error"))
remove_element = root.find(".//{DAV:}remove") remove_element = root.find(".//{DAV:}remove")
@ -773,8 +865,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
prop_element = remove_element.find(".//{DAV:}prop") prop_element = remove_element.find(".//{DAV:}prop")
if prop_element is not None: if prop_element is not None:
for child in prop_element: for child in prop_element:
ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:" ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
name = child.tag.split('}')[1] if '}' in child.tag else child.tag name = child.tag.split("}")[1] if "}" in child.tag else child.tag
if ns == "DAV:": if ns == "DAV:":
failed_props.append((ns, name, "409 Conflict")) failed_props.append((ns, name, "409 Conflict"))
@ -785,7 +877,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,
namespace=ns, namespace=ns,
name=name name=name,
) )
if existing_prop: if existing_prop:
@ -793,7 +885,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
remove_props.append((ns, name)) remove_props.append((ns, name))
else: else:
failed_props.append((ns, name, "404 Not Found")) failed_props.append((ns, name, "404 Not Found"))
except Exception as e: except Exception:
failed_props.append((ns, name, "500 Internal Server Error")) failed_props.append((ns, name, "500 Internal Server Error"))
multistatus = ET.Element("D:multistatus", {"xmlns:D": "DAV:"}) multistatus = ET.Element("D:multistatus", {"xmlns:D": "DAV:"})
@ -837,4 +929,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
await log_activity(current_user, "properties_modified", resource_type, resource_id) await log_activity(current_user, "properties_modified", resource_type, resource_id)
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True) xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207) return Response(
content=xml_content,
media_type="application/xml; charset=utf-8",
status_code=207,
)

View File

@ -6,9 +6,15 @@ from httpx import AsyncClient, ASGITransport
from fastapi import status from fastapi import status
from mywebdav.main import app from mywebdav.main import app
from mywebdav.models import User from mywebdav.models import User
from mywebdav.billing.models import PricingConfig, Invoice, UsageAggregate, UserSubscription from mywebdav.billing.models import (
PricingConfig,
Invoice,
UsageAggregate,
UserSubscription,
)
from mywebdav.auth import create_access_token from mywebdav.auth import create_access_token
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(): async def test_user():
user = await User.create( user = await User.create(
@ -16,11 +22,12 @@ async def test_user():
email="test@example.com", email="test@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True, is_active=True,
is_superuser=False is_superuser=False,
) )
yield user yield user
await user.delete() await user.delete()
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def admin_user(): async def admin_user():
user = await User.create( user = await User.create(
@ -28,58 +35,72 @@ async def admin_user():
email="admin@example.com", email="admin@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True, is_active=True,
is_superuser=True is_superuser=True,
) )
yield user yield user
await user.delete() await user.delete()
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def auth_token(test_user): async def auth_token(test_user):
token = create_access_token(data={"sub": test_user.username}) token = create_access_token(data={"sub": test_user.username})
return token return token
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def admin_token(admin_user): async def admin_token(admin_user):
token = create_access_token(data={"sub": admin_user.username}) token = create_access_token(data={"sub": admin_user.username})
return token return token
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def pricing_config(): async def pricing_config():
configs = [] configs = []
configs.append(await PricingConfig.create( configs.append(
config_key="storage_per_gb_month", await PricingConfig.create(
config_value=Decimal("0.0045"), config_key="storage_per_gb_month",
description="Storage cost per GB per month", config_value=Decimal("0.0045"),
unit="per_gb_month" description="Storage cost per GB per month",
)) unit="per_gb_month",
configs.append(await PricingConfig.create( )
config_key="bandwidth_egress_per_gb", )
config_value=Decimal("0.009"), configs.append(
description="Bandwidth egress cost per GB", await PricingConfig.create(
unit="per_gb" config_key="bandwidth_egress_per_gb",
)) config_value=Decimal("0.009"),
configs.append(await PricingConfig.create( description="Bandwidth egress cost per GB",
config_key="free_tier_storage_gb", unit="per_gb",
config_value=Decimal("15"), )
description="Free tier storage in GB", )
unit="gb" configs.append(
)) await PricingConfig.create(
configs.append(await PricingConfig.create( config_key="free_tier_storage_gb",
config_key="free_tier_bandwidth_gb", config_value=Decimal("15"),
config_value=Decimal("15"), description="Free tier storage in GB",
description="Free tier bandwidth in GB per month", unit="gb",
unit="gb" )
)) )
configs.append(
await PricingConfig.create(
config_key="free_tier_bandwidth_gb",
config_value=Decimal("15"),
description="Free tier bandwidth in GB per month",
unit="gb",
)
)
yield configs yield configs
for config in configs: for config in configs:
await config.delete() await config.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_usage(test_user, auth_token): async def test_get_current_usage(test_user, auth_token):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/billing/usage/current", "/api/billing/usage/current",
headers={"Authorization": f"Bearer {auth_token}"} headers={"Authorization": f"Bearer {auth_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -87,6 +108,7 @@ async def test_get_current_usage(test_user, auth_token):
assert "storage_gb" in data assert "storage_gb" in data
assert "bandwidth_down_gb_today" in data assert "bandwidth_down_gb_today" in data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_monthly_usage(test_user, auth_token): async def test_get_monthly_usage(test_user, auth_token):
today = date.today() today = date.today()
@ -94,16 +116,18 @@ async def test_get_monthly_usage(test_user, auth_token):
await UsageAggregate.create( await UsageAggregate.create(
user=test_user, user=test_user,
date=today, date=today,
storage_bytes_avg=1024 ** 3 * 10, storage_bytes_avg=1024**3 * 10,
storage_bytes_peak=1024 ** 3 * 12, storage_bytes_peak=1024**3 * 12,
bandwidth_up_bytes=1024 ** 3 * 2, bandwidth_up_bytes=1024**3 * 2,
bandwidth_down_bytes=1024 ** 3 * 5 bandwidth_down_bytes=1024**3 * 5,
) )
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
f"/api/billing/usage/monthly?year={today.year}&month={today.month}", f"/api/billing/usage/monthly?year={today.year}&month={today.month}",
headers={"Authorization": f"Bearer {auth_token}"} headers={"Authorization": f"Bearer {auth_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -112,12 +136,15 @@ async def test_get_monthly_usage(test_user, auth_token):
await UsageAggregate.filter(user=test_user).delete() await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_subscription(test_user, auth_token): async def test_get_subscription(test_user, auth_token):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/billing/subscription", "/api/billing/subscription",
headers={"Authorization": f"Bearer {auth_token}"} headers={"Authorization": f"Bearer {auth_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -127,6 +154,7 @@ async def test_get_subscription(test_user, auth_token):
await UserSubscription.filter(user=test_user).delete() await UserSubscription.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_invoices(test_user, auth_token): async def test_list_invoices(test_user, auth_token):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -137,13 +165,14 @@ async def test_list_invoices(test_user, auth_token):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="open" status="open",
) )
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/billing/invoices", "/api/billing/invoices", headers={"Authorization": f"Bearer {auth_token}"}
headers={"Authorization": f"Bearer {auth_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -153,6 +182,7 @@ async def test_list_invoices(test_user, auth_token):
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_invoice(test_user, auth_token): async def test_get_invoice(test_user, auth_token):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -163,13 +193,15 @@ async def test_get_invoice(test_user, auth_token):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="open" status="open",
) )
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
f"/api/billing/invoices/{invoice.id}", f"/api/billing/invoices/{invoice.id}",
headers={"Authorization": f"Bearer {auth_token}"} headers={"Authorization": f"Bearer {auth_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -178,39 +210,45 @@ async def test_get_invoice(test_user, auth_token):
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_pricing(): async def test_get_pricing():
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/api/billing/pricing") response = await client.get("/api/billing/pricing")
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
assert isinstance(data, dict) assert isinstance(data, dict)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_get_pricing(admin_user, admin_token, pricing_config): async def test_admin_get_pricing(admin_user, admin_token, pricing_config):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/admin/billing/pricing", "/api/admin/billing/pricing",
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
assert len(data) > 0 assert len(data) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_update_pricing(admin_user, admin_token, pricing_config): async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
config_id = pricing_config[0].id config_id = pricing_config[0].id
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.put( response = await client.put(
f"/api/admin/billing/pricing/{config_id}", f"/api/admin/billing/pricing/{config_id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}"},
json={ json={"config_key": "storage_per_gb_month", "config_value": 0.005},
"config_key": "storage_per_gb_month",
"config_value": 0.005
}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -218,12 +256,15 @@ async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
updated = await PricingConfig.get(id=config_id) updated = await PricingConfig.get(id=config_id)
assert updated.config_value == Decimal("0.005") assert updated.config_value == Decimal("0.005")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_get_stats(admin_user, admin_token): async def test_admin_get_stats(admin_user, admin_token):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/admin/billing/stats", "/api/admin/billing/stats",
headers={"Authorization": f"Bearer {admin_token}"} headers={"Authorization": f"Bearer {admin_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -232,12 +273,15 @@ async def test_admin_get_stats(admin_user, admin_token):
assert "total_invoices" in data assert "total_invoices" in data
assert "pending_invoices" in data assert "pending_invoices" in data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_non_admin_cannot_access_admin_endpoints(test_user, auth_token): async def test_non_admin_cannot_access_admin_endpoints(test_user, auth_token):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/admin/billing/pricing", "/api/admin/billing/pricing",
headers={"Authorization": f"Bearer {auth_token}"} headers={"Authorization": f"Bearer {auth_token}"},
) )
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN

View File

@ -3,57 +3,76 @@ import pytest_asyncio
from decimal import Decimal from decimal import Decimal
from datetime import date, datetime from datetime import date, datetime
from mywebdav.models import User from mywebdav.models import User
from mywebdav.billing.models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription from mywebdav.billing.models import (
Invoice,
InvoiceLineItem,
PricingConfig,
UsageAggregate,
UserSubscription,
)
from mywebdav.billing.invoice_generator import InvoiceGenerator from mywebdav.billing.invoice_generator import InvoiceGenerator
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(): async def test_user():
user = await User.create( user = await User.create(
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True is_active=True,
) )
yield user yield user
await user.delete() await user.delete()
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def pricing_config(): async def pricing_config():
configs = [] configs = []
configs.append(await PricingConfig.create( configs.append(
config_key="storage_per_gb_month", await PricingConfig.create(
config_value=Decimal("0.0045"), config_key="storage_per_gb_month",
description="Storage cost per GB per month", config_value=Decimal("0.0045"),
unit="per_gb_month" description="Storage cost per GB per month",
)) unit="per_gb_month",
configs.append(await PricingConfig.create( )
config_key="bandwidth_egress_per_gb", )
config_value=Decimal("0.009"), configs.append(
description="Bandwidth egress cost per GB", await PricingConfig.create(
unit="per_gb" config_key="bandwidth_egress_per_gb",
)) config_value=Decimal("0.009"),
configs.append(await PricingConfig.create( description="Bandwidth egress cost per GB",
config_key="free_tier_storage_gb", unit="per_gb",
config_value=Decimal("15"), )
description="Free tier storage in GB", )
unit="gb" configs.append(
)) await PricingConfig.create(
configs.append(await PricingConfig.create( config_key="free_tier_storage_gb",
config_key="free_tier_bandwidth_gb", config_value=Decimal("15"),
config_value=Decimal("15"), description="Free tier storage in GB",
description="Free tier bandwidth in GB per month", unit="gb",
unit="gb" )
)) )
configs.append(await PricingConfig.create( configs.append(
config_key="tax_rate_default", await PricingConfig.create(
config_value=Decimal("0.0"), config_key="free_tier_bandwidth_gb",
description="Default tax rate", config_value=Decimal("15"),
unit="percentage" description="Free tier bandwidth in GB per month",
)) unit="gb",
)
)
configs.append(
await PricingConfig.create(
config_key="tax_rate_default",
config_value=Decimal("0.0"),
description="Default tax rate",
unit="percentage",
)
)
yield configs yield configs
for config in configs: for config in configs:
await config.delete() await config.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_monthly_invoice_with_usage(test_user, pricing_config): async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
today = date.today() today = date.today()
@ -61,13 +80,15 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
await UsageAggregate.create( await UsageAggregate.create(
user=test_user, user=test_user,
date=today, date=today,
storage_bytes_avg=1024 ** 3 * 50, storage_bytes_avg=1024**3 * 50,
storage_bytes_peak=1024 ** 3 * 55, storage_bytes_peak=1024**3 * 55,
bandwidth_up_bytes=1024 ** 3 * 10, bandwidth_up_bytes=1024**3 * 10,
bandwidth_down_bytes=1024 ** 3 * 20 bandwidth_down_bytes=1024**3 * 20,
) )
invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month) invoice = await InvoiceGenerator.generate_monthly_invoice(
test_user, today.year, today.month
)
assert invoice is not None assert invoice is not None
assert invoice.user_id == test_user.id assert invoice.user_id == test_user.id
@ -80,6 +101,7 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
await invoice.delete() await invoice.delete()
await UsageAggregate.filter(user=test_user).delete() await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_config): async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_config):
today = date.today() today = date.today()
@ -87,18 +109,21 @@ async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_confi
await UsageAggregate.create( await UsageAggregate.create(
user=test_user, user=test_user,
date=today, date=today,
storage_bytes_avg=1024 ** 3 * 10, storage_bytes_avg=1024**3 * 10,
storage_bytes_peak=1024 ** 3 * 12, storage_bytes_peak=1024**3 * 12,
bandwidth_up_bytes=1024 ** 3 * 5, bandwidth_up_bytes=1024**3 * 5,
bandwidth_down_bytes=1024 ** 3 * 10 bandwidth_down_bytes=1024**3 * 10,
) )
invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month) invoice = await InvoiceGenerator.generate_monthly_invoice(
test_user, today.year, today.month
)
assert invoice is None assert invoice is None
await UsageAggregate.filter(user=test_user).delete() await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_finalize_invoice(test_user, pricing_config): async def test_finalize_invoice(test_user, pricing_config):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -109,7 +134,7 @@ async def test_finalize_invoice(test_user, pricing_config):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="draft" status="draft",
) )
finalized = await InvoiceGenerator.finalize_invoice(invoice) finalized = await InvoiceGenerator.finalize_invoice(invoice)
@ -118,6 +143,7 @@ async def test_finalize_invoice(test_user, pricing_config):
await finalized.delete() await finalized.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_finalize_invoice_already_finalized(test_user, pricing_config): async def test_finalize_invoice_already_finalized(test_user, pricing_config):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -128,7 +154,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="open" status="open",
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -136,6 +162,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mark_invoice_paid(test_user, pricing_config): async def test_mark_invoice_paid(test_user, pricing_config):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -146,7 +173,7 @@ async def test_mark_invoice_paid(test_user, pricing_config):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="open" status="open",
) )
paid = await InvoiceGenerator.mark_invoice_paid(invoice) paid = await InvoiceGenerator.mark_invoice_paid(invoice)
@ -156,10 +183,13 @@ async def test_mark_invoice_paid(test_user, pricing_config):
await paid.delete() await paid.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoice_with_tax(test_user, pricing_config): async def test_invoice_with_tax(test_user, pricing_config):
# Update tax rate # Update tax rate
updated = await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.21")) updated = await PricingConfig.filter(config_key="tax_rate_default").update(
config_value=Decimal("0.21")
)
assert updated == 1 # Should update 1 row assert updated == 1 # Should update 1 row
# Verify the update worked # Verify the update worked
@ -171,13 +201,15 @@ async def test_invoice_with_tax(test_user, pricing_config):
await UsageAggregate.create( await UsageAggregate.create(
user=test_user, user=test_user,
date=today, date=today,
storage_bytes_avg=1024 ** 3 * 50, storage_bytes_avg=1024**3 * 50,
storage_bytes_peak=1024 ** 3 * 55, storage_bytes_peak=1024**3 * 55,
bandwidth_up_bytes=1024 ** 3 * 10, bandwidth_up_bytes=1024**3 * 10,
bandwidth_down_bytes=1024 ** 3 * 20 bandwidth_down_bytes=1024**3 * 20,
) )
invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month) invoice = await InvoiceGenerator.generate_monthly_invoice(
test_user, today.year, today.month
)
assert invoice is not None assert invoice is not None
assert invoice.tax > 0 assert invoice.tax > 0
@ -185,4 +217,6 @@ async def test_invoice_with_tax(test_user, pricing_config):
await invoice.delete() await invoice.delete()
await UsageAggregate.filter(user=test_user).delete() await UsageAggregate.filter(user=test_user).delete()
await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.0")) await PricingConfig.filter(config_key="tax_rate_default").update(
config_value=Decimal("0.0")
)

View File

@ -4,21 +4,30 @@ from decimal import Decimal
from datetime import date, datetime from datetime import date, datetime
from mywebdav.models import User from mywebdav.models import User
from mywebdav.billing.models import ( from mywebdav.billing.models import (
SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate, SubscriptionPlan,
Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent UserSubscription,
UsageRecord,
UsageAggregate,
Invoice,
InvoiceLineItem,
PricingConfig,
PaymentMethod,
BillingEvent,
) )
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(): async def test_user():
user = await User.create( user = await User.create(
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True is_active=True,
) )
yield user yield user
await user.delete() await user.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_subscription_plan_creation(): async def test_subscription_plan_creation():
plan = await SubscriptionPlan.create( plan = await SubscriptionPlan.create(
@ -28,7 +37,7 @@ async def test_subscription_plan_creation():
storage_gb=100, storage_gb=100,
bandwidth_gb=100, bandwidth_gb=100,
price_monthly=Decimal("5.00"), price_monthly=Decimal("5.00"),
price_yearly=Decimal("50.00") price_yearly=Decimal("50.00"),
) )
assert plan.name == "starter" assert plan.name == "starter"
@ -36,12 +45,11 @@ async def test_subscription_plan_creation():
assert plan.price_monthly == Decimal("5.00") assert plan.price_monthly == Decimal("5.00")
await plan.delete() await plan.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_subscription_creation(test_user): async def test_user_subscription_creation(test_user):
subscription = await UserSubscription.create( subscription = await UserSubscription.create(
user=test_user, user=test_user, billing_type="pay_as_you_go", status="active"
billing_type="pay_as_you_go",
status="active"
) )
assert subscription.user_id == test_user.id assert subscription.user_id == test_user.id
@ -49,6 +57,7 @@ async def test_user_subscription_creation(test_user):
assert subscription.status == "active" assert subscription.status == "active"
await subscription.delete() await subscription.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_record_creation(test_user): async def test_usage_record_creation(test_user):
usage = await UsageRecord.create( usage = await UsageRecord.create(
@ -57,7 +66,7 @@ async def test_usage_record_creation(test_user):
amount_bytes=1024 * 1024 * 100, amount_bytes=1024 * 1024 * 100,
resource_type="file", resource_type="file",
resource_id=1, resource_id=1,
idempotency_key="test_key_123" idempotency_key="test_key_123",
) )
assert usage.user_id == test_user.id assert usage.user_id == test_user.id
@ -65,6 +74,7 @@ async def test_usage_record_creation(test_user):
assert usage.amount_bytes == 1024 * 1024 * 100 assert usage.amount_bytes == 1024 * 1024 * 100
await usage.delete() await usage.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_aggregate_creation(test_user): async def test_usage_aggregate_creation(test_user):
aggregate = await UsageAggregate.create( aggregate = await UsageAggregate.create(
@ -73,13 +83,14 @@ async def test_usage_aggregate_creation(test_user):
storage_bytes_avg=1024 * 1024 * 500, storage_bytes_avg=1024 * 1024 * 500,
storage_bytes_peak=1024 * 1024 * 600, storage_bytes_peak=1024 * 1024 * 600,
bandwidth_up_bytes=1024 * 1024 * 50, bandwidth_up_bytes=1024 * 1024 * 50,
bandwidth_down_bytes=1024 * 1024 * 100 bandwidth_down_bytes=1024 * 1024 * 100,
) )
assert aggregate.user_id == test_user.id assert aggregate.user_id == test_user.id
assert aggregate.storage_bytes_avg == 1024 * 1024 * 500 assert aggregate.storage_bytes_avg == 1024 * 1024 * 500
await aggregate.delete() await aggregate.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoice_creation(test_user): async def test_invoice_creation(test_user):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -90,7 +101,7 @@ async def test_invoice_creation(test_user):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="draft" status="draft",
) )
assert invoice.user_id == test_user.id assert invoice.user_id == test_user.id
@ -98,6 +109,7 @@ async def test_invoice_creation(test_user):
assert invoice.total == Decimal("10.00") assert invoice.total == Decimal("10.00")
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoice_line_item_creation(test_user): async def test_invoice_line_item_creation(test_user):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -108,7 +120,7 @@ async def test_invoice_line_item_creation(test_user):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="draft" status="draft",
) )
line_item = await InvoiceLineItem.create( line_item = await InvoiceLineItem.create(
@ -117,7 +129,7 @@ async def test_invoice_line_item_creation(test_user):
quantity=Decimal("100.000000"), quantity=Decimal("100.000000"),
unit_price=Decimal("0.100000"), unit_price=Decimal("0.100000"),
amount=Decimal("10.0000"), amount=Decimal("10.0000"),
item_type="storage" item_type="storage",
) )
assert line_item.invoice_id == invoice.id assert line_item.invoice_id == invoice.id
@ -126,6 +138,7 @@ async def test_invoice_line_item_creation(test_user):
await line_item.delete() await line_item.delete()
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pricing_config_creation(test_user): async def test_pricing_config_creation(test_user):
config = await PricingConfig.create( config = await PricingConfig.create(
@ -133,13 +146,14 @@ async def test_pricing_config_creation(test_user):
config_value=Decimal("0.0045"), config_value=Decimal("0.0045"),
description="Storage cost per GB per month", description="Storage cost per GB per month",
unit="per_gb_month", unit="per_gb_month",
updated_by=test_user updated_by=test_user,
) )
assert config.config_key == "storage_per_gb_month" assert config.config_key == "storage_per_gb_month"
assert config.config_value == Decimal("0.0045") assert config.config_value == Decimal("0.0045")
await config.delete() await config.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_payment_method_creation(test_user): async def test_payment_method_creation(test_user):
payment_method = await PaymentMethod.create( payment_method = await PaymentMethod.create(
@ -150,7 +164,7 @@ async def test_payment_method_creation(test_user):
last4="4242", last4="4242",
brand="visa", brand="visa",
exp_month=12, exp_month=12,
exp_year=2025 exp_year=2025,
) )
assert payment_method.user_id == test_user.id assert payment_method.user_id == test_user.id
@ -158,6 +172,7 @@ async def test_payment_method_creation(test_user):
assert payment_method.is_default is True assert payment_method.is_default is True
await payment_method.delete() await payment_method.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_billing_event_creation(test_user): async def test_billing_event_creation(test_user):
event = await BillingEvent.create( event = await BillingEvent.create(
@ -165,7 +180,7 @@ async def test_billing_event_creation(test_user):
event_type="invoice_created", event_type="invoice_created",
stripe_event_id="evt_test_123", stripe_event_id="evt_test_123",
data={"invoice_id": 1}, data={"invoice_id": 1},
processed=False processed=False,
) )
assert event.user_id == test_user.id assert event.user_id == test_user.id

View File

@ -1,18 +1,35 @@
import pytest import pytest
def test_billing_module_imports(): def test_billing_module_imports():
from mywebdav.billing import models, stripe_client, usage_tracker, invoice_generator, scheduler from mywebdav.billing import (
models,
stripe_client,
usage_tracker,
invoice_generator,
scheduler,
)
assert models is not None assert models is not None
assert stripe_client is not None assert stripe_client is not None
assert usage_tracker is not None assert usage_tracker is not None
assert invoice_generator is not None assert invoice_generator is not None
assert scheduler is not None assert scheduler is not None
def test_billing_models_exist(): def test_billing_models_exist():
from mywebdav.billing.models import ( from mywebdav.billing.models import (
SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate, SubscriptionPlan,
Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent UserSubscription,
UsageRecord,
UsageAggregate,
Invoice,
InvoiceLineItem,
PricingConfig,
PaymentMethod,
BillingEvent,
) )
assert SubscriptionPlan is not None assert SubscriptionPlan is not None
assert UserSubscription is not None assert UserSubscription is not None
assert UsageRecord is not None assert UsageRecord is not None
@ -23,55 +40,71 @@ def test_billing_models_exist():
assert PaymentMethod is not None assert PaymentMethod is not None
assert BillingEvent is not None assert BillingEvent is not None
def test_stripe_client_exists(): def test_stripe_client_exists():
from mywebdav.billing.stripe_client import StripeClient from mywebdav.billing.stripe_client import StripeClient
assert StripeClient is not None assert StripeClient is not None
assert hasattr(StripeClient, 'create_customer') assert hasattr(StripeClient, "create_customer")
assert hasattr(StripeClient, 'create_invoice') assert hasattr(StripeClient, "create_invoice")
assert hasattr(StripeClient, 'finalize_invoice') assert hasattr(StripeClient, "finalize_invoice")
def test_usage_tracker_exists(): def test_usage_tracker_exists():
from mywebdav.billing.usage_tracker import UsageTracker from mywebdav.billing.usage_tracker import UsageTracker
assert UsageTracker is not None assert UsageTracker is not None
assert hasattr(UsageTracker, 'track_storage') assert hasattr(UsageTracker, "track_storage")
assert hasattr(UsageTracker, 'track_bandwidth') assert hasattr(UsageTracker, "track_bandwidth")
assert hasattr(UsageTracker, 'aggregate_daily_usage') assert hasattr(UsageTracker, "aggregate_daily_usage")
assert hasattr(UsageTracker, 'get_current_storage') assert hasattr(UsageTracker, "get_current_storage")
assert hasattr(UsageTracker, 'get_monthly_usage') assert hasattr(UsageTracker, "get_monthly_usage")
def test_invoice_generator_exists(): def test_invoice_generator_exists():
from mywebdav.billing.invoice_generator import InvoiceGenerator from mywebdav.billing.invoice_generator import InvoiceGenerator
assert InvoiceGenerator is not None assert InvoiceGenerator is not None
assert hasattr(InvoiceGenerator, 'generate_monthly_invoice') assert hasattr(InvoiceGenerator, "generate_monthly_invoice")
assert hasattr(InvoiceGenerator, 'finalize_invoice') assert hasattr(InvoiceGenerator, "finalize_invoice")
assert hasattr(InvoiceGenerator, 'mark_invoice_paid') assert hasattr(InvoiceGenerator, "mark_invoice_paid")
def test_scheduler_exists(): def test_scheduler_exists():
from mywebdav.billing.scheduler import scheduler, start_scheduler, stop_scheduler from mywebdav.billing.scheduler import scheduler, start_scheduler, stop_scheduler
assert scheduler is not None assert scheduler is not None
assert callable(start_scheduler) assert callable(start_scheduler)
assert callable(stop_scheduler) assert callable(stop_scheduler)
def test_routers_exist(): def test_routers_exist():
from mywebdav.routers import billing, admin_billing from mywebdav.routers import billing, admin_billing
assert billing is not None assert billing is not None
assert admin_billing is not None assert admin_billing is not None
assert hasattr(billing, 'router') assert hasattr(billing, "router")
assert hasattr(admin_billing, 'router') assert hasattr(admin_billing, "router")
def test_middleware_exists(): def test_middleware_exists():
from mywebdav.middleware.usage_tracking import UsageTrackingMiddleware from mywebdav.middleware.usage_tracking import UsageTrackingMiddleware
assert UsageTrackingMiddleware is not None assert UsageTrackingMiddleware is not None
def test_settings_updated(): def test_settings_updated():
from mywebdav.settings import settings from mywebdav.settings import settings
assert hasattr(settings, 'STRIPE_SECRET_KEY')
assert hasattr(settings, 'STRIPE_PUBLISHABLE_KEY') assert hasattr(settings, "STRIPE_SECRET_KEY")
assert hasattr(settings, 'STRIPE_WEBHOOK_SECRET') assert hasattr(settings, "STRIPE_PUBLISHABLE_KEY")
assert hasattr(settings, 'BILLING_ENABLED') assert hasattr(settings, "STRIPE_WEBHOOK_SECRET")
assert hasattr(settings, "BILLING_ENABLED")
def test_main_includes_billing(): def test_main_includes_billing():
from mywebdav.main import app from mywebdav.main import app
routes = [route.path for route in app.routes] routes = [route.path for route in app.routes]
billing_routes = [r for r in routes if '/billing' in r] billing_routes = [r for r in routes if "/billing" in r]
assert len(billing_routes) > 0 assert len(billing_routes) > 0

View File

@ -6,24 +6,26 @@ from mywebdav.models import User, File, Folder
from mywebdav.billing.models import UsageRecord, UsageAggregate from mywebdav.billing.models import UsageRecord, UsageAggregate
from mywebdav.billing.usage_tracker import UsageTracker from mywebdav.billing.usage_tracker import UsageTracker
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(): async def test_user():
user = await User.create( user = await User.create(
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True is_active=True,
) )
yield user yield user
await user.delete() await user.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_storage(test_user): async def test_track_storage(test_user):
await UsageTracker.track_storage( await UsageTracker.track_storage(
user=test_user, user=test_user,
amount_bytes=1024 * 1024 * 100, amount_bytes=1024 * 1024 * 100,
resource_type="file", resource_type="file",
resource_id=1 resource_id=1,
) )
records = await UsageRecord.filter(user=test_user, record_type="storage").all() records = await UsageRecord.filter(user=test_user, record_type="storage").all()
@ -32,6 +34,7 @@ async def test_track_storage(test_user):
await records[0].delete() await records[0].delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_bandwidth_upload(test_user): async def test_track_bandwidth_upload(test_user):
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
@ -39,7 +42,7 @@ async def test_track_bandwidth_upload(test_user):
amount_bytes=1024 * 1024 * 50, amount_bytes=1024 * 1024 * 50,
direction="up", direction="up",
resource_type="file", resource_type="file",
resource_id=1 resource_id=1,
) )
records = await UsageRecord.filter(user=test_user, record_type="bandwidth_up").all() records = await UsageRecord.filter(user=test_user, record_type="bandwidth_up").all()
@ -48,6 +51,7 @@ async def test_track_bandwidth_upload(test_user):
await records[0].delete() await records[0].delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_bandwidth_download(test_user): async def test_track_bandwidth_download(test_user):
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
@ -55,32 +59,28 @@ async def test_track_bandwidth_download(test_user):
amount_bytes=1024 * 1024 * 75, amount_bytes=1024 * 1024 * 75,
direction="down", direction="down",
resource_type="file", resource_type="file",
resource_id=1 resource_id=1,
) )
records = await UsageRecord.filter(user=test_user, record_type="bandwidth_down").all() records = await UsageRecord.filter(
user=test_user, record_type="bandwidth_down"
).all()
assert len(records) == 1 assert len(records) == 1
assert records[0].amount_bytes == 1024 * 1024 * 75 assert records[0].amount_bytes == 1024 * 1024 * 75
await records[0].delete() await records[0].delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_aggregate_daily_usage(test_user): async def test_aggregate_daily_usage(test_user):
await UsageTracker.track_storage( await UsageTracker.track_storage(user=test_user, amount_bytes=1024 * 1024 * 100)
user=test_user,
amount_bytes=1024 * 1024 * 100 await UsageTracker.track_bandwidth(
user=test_user, amount_bytes=1024 * 1024 * 50, direction="up"
) )
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
user=test_user, user=test_user, amount_bytes=1024 * 1024 * 75, direction="down"
amount_bytes=1024 * 1024 * 50,
direction="up"
)
await UsageTracker.track_bandwidth(
user=test_user,
amount_bytes=1024 * 1024 * 75,
direction="down"
) )
aggregate = await UsageTracker.aggregate_daily_usage(test_user, date.today()) aggregate = await UsageTracker.aggregate_daily_usage(test_user, date.today())
@ -93,12 +93,10 @@ async def test_aggregate_daily_usage(test_user):
await aggregate.delete() await aggregate.delete()
await UsageRecord.filter(user=test_user).delete() await UsageRecord.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_storage(test_user): async def test_get_current_storage(test_user):
folder = await Folder.create( folder = await Folder.create(name="Test Folder", owner=test_user)
name="Test Folder",
owner=test_user
)
file1 = await File.create( file1 = await File.create(
name="test1.txt", name="test1.txt",
@ -107,7 +105,7 @@ async def test_get_current_storage(test_user):
mime_type="text/plain", mime_type="text/plain",
owner=test_user, owner=test_user,
parent=folder, parent=folder,
is_deleted=False is_deleted=False,
) )
file2 = await File.create( file2 = await File.create(
@ -117,7 +115,7 @@ async def test_get_current_storage(test_user):
mime_type="text/plain", mime_type="text/plain",
owner=test_user, owner=test_user,
parent=folder, parent=folder,
is_deleted=False is_deleted=False,
) )
storage = await UsageTracker.get_current_storage(test_user) storage = await UsageTracker.get_current_storage(test_user)
@ -128,6 +126,7 @@ async def test_get_current_storage(test_user):
await file2.delete() await file2.delete()
await folder.delete() await folder.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_monthly_usage(test_user): async def test_get_monthly_usage(test_user):
today = date.today() today = date.today()
@ -138,7 +137,7 @@ async def test_get_monthly_usage(test_user):
storage_bytes_avg=1024 * 1024 * 1024 * 10, storage_bytes_avg=1024 * 1024 * 1024 * 10,
storage_bytes_peak=1024 * 1024 * 1024 * 12, storage_bytes_peak=1024 * 1024 * 1024 * 12,
bandwidth_up_bytes=1024 * 1024 * 1024 * 2, bandwidth_up_bytes=1024 * 1024 * 1024 * 2,
bandwidth_down_bytes=1024 * 1024 * 1024 * 5 bandwidth_down_bytes=1024 * 1024 * 1024 * 5,
) )
usage = await UsageTracker.get_monthly_usage(test_user, today.year, today.month) usage = await UsageTracker.get_monthly_usage(test_user, today.year, today.month)
@ -150,11 +149,14 @@ async def test_get_monthly_usage(test_user):
await aggregate.delete() await aggregate.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_monthly_usage_empty(test_user): async def test_get_monthly_usage_empty(test_user):
future_date = date.today() + timedelta(days=365) future_date = date.today() + timedelta(days=365)
usage = await UsageTracker.get_monthly_usage(test_user, future_date.year, future_date.month) usage = await UsageTracker.get_monthly_usage(
test_user, future_date.year, future_date.month
)
assert usage["storage_gb_avg"] == 0 assert usage["storage_gb_avg"] == 0
assert usage["storage_gb_peak"] == 0 assert usage["storage_gb_peak"] == 0

View File

@ -3,27 +3,21 @@ import pytest_asyncio
import asyncio import asyncio
from tortoise import Tortoise from tortoise import Tortoise
# Initialize Tortoise at module level
async def _init_tortoise(): @pytest_asyncio.fixture(scope="session")
async def init_db():
"""Initialize the database for testing."""
await Tortoise.init( await Tortoise.init(
db_url="sqlite://:memory:", db_url="sqlite://:memory:",
modules={ modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
"models": ["mywebdav.models"],
"billing": ["mywebdav.billing.models"]
}
) )
await Tortoise.generate_schemas() await Tortoise.generate_schemas()
# Run initialization
asyncio.run(_init_tortoise())
@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def cleanup_db():
yield yield
await Tortoise.close_connections() await Tortoise.close_connections()
@pytest_asyncio.fixture(autouse=True)
async def setup_db(init_db):
"""Ensure database is initialized before each test."""
# The init_db fixture handles the setup
yield