From d350ab680762e159fca9a39484f17361efc4d990 Mon Sep 17 00:00:00 2001 From: retoor Date: Thu, 13 Nov 2025 23:22:05 +0100 Subject: [PATCH] Update. --- mywebdav/activity.py | 5 +- mywebdav/auth.py | 67 ++-- mywebdav/billing/invoice_generator.py | 57 ++-- mywebdav/billing/models.py | 48 +-- mywebdav/billing/scheduler.py | 13 +- mywebdav/billing/stripe_client.py | 52 ++-- mywebdav/billing/usage_tracker.py | 34 +-- mywebdav/legal.py | 36 ++- mywebdav/mail.py | 62 +++- mywebdav/main.py | 89 ++++-- mywebdav/middleware/usage_tracking.py | 22 +- mywebdav/models.py | 120 +++++--- mywebdav/routers/admin.py | 58 ++-- mywebdav/routers/admin_billing.py | 53 ++-- mywebdav/routers/auth.py | 116 +++++-- mywebdav/routers/billing.py | 191 +++++++----- mywebdav/routers/files.py | 390 ++++++++++++++++++------ mywebdav/routers/folders.py | 143 ++++++--- mywebdav/routers/search.py | 28 +- mywebdav/routers/shares.py | 139 ++++++--- mywebdav/routers/starred.py | 21 +- mywebdav/routers/users.py | 7 +- mywebdav/schemas.py | 24 +- mywebdav/settings.py | 13 +- mywebdav/storage.py | 10 +- mywebdav/thumbnails.py | 53 +++- mywebdav/two_factor.py | 17 +- mywebdav/webdav.py | 362 ++++++++++++++-------- tests/billing/test_api_endpoints.py | 156 ++++++---- tests/billing/test_invoice_generator.py | 138 +++++---- tests/billing/test_models.py | 45 ++- tests/billing/test_simple.py | 75 +++-- tests/billing/test_usage_tracker.py | 52 ++-- tests/conftest.py | 30 +- 34 files changed, 1851 insertions(+), 875 deletions(-) diff --git a/mywebdav/activity.py b/mywebdav/activity.py index 4c497c8..8537e56 100644 --- a/mywebdav/activity.py +++ b/mywebdav/activity.py @@ -1,17 +1,18 @@ from typing import Optional from .models import Activity, User + async def log_activity( user: Optional[User], action: str, target_type: str, target_id: int, - ip_address: Optional[str] = None + ip_address: Optional[str] = None, ): await Activity.create( user=user, action=action, target_type=target_type, target_id=target_id, - ip_address=ip_address + ip_address=ip_address, ) diff --git a/mywebdav/auth.py b/mywebdav/auth.py index fe9bbc3..d6f43cc 100644 --- a/mywebdav/auth.py +++ b/mywebdav/auth.py @@ -9,20 +9,29 @@ import bcrypt from .schemas import TokenData from .settings import settings from .models import User -from .two_factor import verify_totp_code # Import verify_totp_code +from .two_factor import verify_totp_code # Import verify_totp_code oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + def verify_password(plain_password, hashed_password): - password_bytes = plain_password[:72].encode('utf-8') - hashed_bytes = hashed_password.encode('utf-8') if isinstance(hashed_password, str) else hashed_password + password_bytes = plain_password[:72].encode("utf-8") + hashed_bytes = ( + hashed_password.encode("utf-8") + if isinstance(hashed_password, str) + else hashed_password + ) return bcrypt.checkpw(password_bytes, hashed_bytes) -def get_password_hash(password): - password_bytes = password[:72].encode('utf-8') - return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode('utf-8') -async def authenticate_user(username: str, password: str, two_factor_code: Optional[str] = None): +def get_password_hash(password): + password_bytes = password[:72].encode("utf-8") + return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode("utf-8") + + +async def authenticate_user( + username: str, password: str, two_factor_code: Optional[str] = None +): user = await User.get_or_none(username=username) if not user: return None @@ -33,19 +42,29 @@ async def authenticate_user(username: str, password: str, two_factor_code: Optio if not two_factor_code: return {"user": user, "2fa_required": True} if not verify_totp_code(user.two_factor_secret, two_factor_code): - return None # 2FA code is incorrect + return None # 2FA code is incorrect return {"user": user, "2fa_required": False} -def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, two_factor_verified: bool = False): + +def create_access_token( + data: dict, + expires_delta: Optional[timedelta] = None, + two_factor_verified: bool = False, +): to_encode = data.copy() if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.now(timezone.utc) + timedelta( + minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + ) to_encode.update({"exp": expire, "2fa_verified": two_factor_verified}) - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + encoded_jwt = jwt.encode( + to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM + ) return encoded_jwt + async def get_current_user(token: str = Depends(oauth2_scheme)): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -53,31 +72,45 @@ async def get_current_user(token: str = Depends(oauth2_scheme)): headers={"WWW-Authenticate": "Bearer"}, ) try: - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] + ) username: str = payload.get("sub") two_factor_verified: bool = payload.get("2fa_verified", False) if username is None: raise credentials_exception - token_data = TokenData(username=username, two_factor_verified=two_factor_verified) + token_data = TokenData( + username=username, two_factor_verified=two_factor_verified + ) except JWTError: raise credentials_exception user = await User.get_or_none(username=token_data.username) if user is None: raise credentials_exception - user.token_data = token_data # Attach token_data to user for easy access + user.token_data = token_data # Attach token_data to user for easy access return user + async def get_current_active_user(current_user: User = Depends(get_current_user)): if not current_user.is_active: raise HTTPException(status_code=400, detail="Inactive user") return current_user + async def get_current_verified_user(current_user: User = Depends(get_current_user)): if current_user.is_2fa_enabled and not current_user.token_data.two_factor_verified: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="2FA required and not verified") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="2FA required and not verified", + ) return current_user -async def get_current_admin_user(current_user: User = Depends(get_current_verified_user)): + +async def get_current_admin_user( + current_user: User = Depends(get_current_verified_user), +): if not current_user.is_superuser: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions" + ) return current_user diff --git a/mywebdav/billing/invoice_generator.py b/mywebdav/billing/invoice_generator.py index 86a73c0..e23313c 100644 --- a/mywebdav/billing/invoice_generator.py +++ b/mywebdav/billing/invoice_generator.py @@ -2,14 +2,17 @@ from datetime import datetime, date, timedelta, timezone from decimal import Decimal from typing import Optional from calendar import monthrange -from .models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription +from .models import Invoice, InvoiceLineItem, PricingConfig, UserSubscription from .usage_tracker import UsageTracker from .stripe_client import StripeClient from ..models import User + class InvoiceGenerator: @staticmethod - async def generate_monthly_invoice(user: User, year: int, month: int) -> Optional[Invoice]: + async def generate_monthly_invoice( + user: User, year: int, month: int + ) -> Optional[Invoice]: period_start = date(year, month, 1) days_in_month = monthrange(year, month)[1] period_end = date(year, month, days_in_month) @@ -19,19 +22,24 @@ class InvoiceGenerator: pricing = await PricingConfig.all() pricing_dict = {p.config_key: p.config_value for p in pricing} - storage_price_per_gb = pricing_dict.get('storage_per_gb_month', Decimal('0.0045')) - bandwidth_price_per_gb = pricing_dict.get('bandwidth_egress_per_gb', Decimal('0.009')) - free_storage_gb = pricing_dict.get('free_tier_storage_gb', Decimal('15')) - free_bandwidth_gb = pricing_dict.get('free_tier_bandwidth_gb', Decimal('15')) - tax_rate = pricing_dict.get('tax_rate_default', Decimal('0')) + storage_price_per_gb = pricing_dict.get( + "storage_per_gb_month", Decimal("0.0045") + ) + bandwidth_price_per_gb = pricing_dict.get( + "bandwidth_egress_per_gb", Decimal("0.009") + ) + free_storage_gb = pricing_dict.get("free_tier_storage_gb", Decimal("15")) + free_bandwidth_gb = pricing_dict.get("free_tier_bandwidth_gb", Decimal("15")) + tax_rate = pricing_dict.get("tax_rate_default", Decimal("0")) - storage_gb = Decimal(str(usage['storage_gb_avg'])) - bandwidth_gb = Decimal(str(usage['bandwidth_down_gb'])) + storage_gb = Decimal(str(usage["storage_gb_avg"])) + bandwidth_gb = Decimal(str(usage["bandwidth_down_gb"])) - billable_storage = max(Decimal('0'), storage_gb - free_storage_gb) - billable_bandwidth = max(Decimal('0'), bandwidth_gb - free_bandwidth_gb) + billable_storage = max(Decimal("0"), storage_gb - free_storage_gb) + billable_bandwidth = max(Decimal("0"), bandwidth_gb - free_bandwidth_gb) import math + billable_storage_rounded = Decimal(math.ceil(float(billable_storage))) billable_bandwidth_rounded = Decimal(math.ceil(float(billable_bandwidth))) @@ -65,9 +73,9 @@ class InvoiceGenerator: "usage": usage, "pricing": { "storage_per_gb": float(storage_price_per_gb), - "bandwidth_per_gb": float(bandwidth_price_per_gb) - } - } + "bandwidth_per_gb": float(bandwidth_price_per_gb), + }, + }, ) if billable_storage_rounded > 0: @@ -78,7 +86,10 @@ class InvoiceGenerator: unit_price=storage_price_per_gb, amount=storage_cost, item_type="storage", - metadata={"avg_gb": float(storage_gb), "free_gb": float(free_storage_gb)} + metadata={ + "avg_gb": float(storage_gb), + "free_gb": float(free_storage_gb), + }, ) if billable_bandwidth_rounded > 0: @@ -89,7 +100,10 @@ class InvoiceGenerator: unit_price=bandwidth_price_per_gb, amount=bandwidth_cost, item_type="bandwidth", - metadata={"total_gb": float(bandwidth_gb), "free_gb": float(free_bandwidth_gb)} + metadata={ + "total_gb": float(bandwidth_gb), + "free_gb": float(free_bandwidth_gb), + }, ) if subscription and subscription.stripe_customer_id: @@ -100,7 +114,7 @@ class InvoiceGenerator: "amount": item.amount, "currency": "usd", "description": item.description, - "metadata": item.metadata or {} + "metadata": item.metadata or {}, } for item in line_items ] @@ -109,7 +123,7 @@ class InvoiceGenerator: customer_id=subscription.stripe_customer_id, description=f"MyWebdav Usage Invoice for {period_start.strftime('%B %Y')}", line_items=stripe_line_items, - metadata={"mywebdav_invoice_id": str(invoice.id)} + metadata={"mywebdav_invoice_id": str(invoice.id)}, ) invoice.stripe_invoice_id = stripe_invoice.id @@ -135,8 +149,11 @@ class InvoiceGenerator: # Send invoice email from ..mail import queue_email + line_items = await invoice.line_items.all() - items_text = "\n".join([f"- {item.description}: ${item.amount}" for item in line_items]) + items_text = "\n".join( + [f"- {item.description}: ${item.amount}" for item in line_items] + ) body = f"""Dear {invoice.user.username}, Your invoice {invoice.invoice_number} for the period {invoice.period_start} to {invoice.period_end} is now available. @@ -174,7 +191,7 @@ The MyWebdav Team to_email=invoice.user.email, subject=f"Your MyWebdav Invoice {invoice.invoice_number}", body=body, - html=html + html=html, ) return invoice diff --git a/mywebdav/billing/models.py b/mywebdav/billing/models.py index 40ab5a0..d298fb0 100644 --- a/mywebdav/billing/models.py +++ b/mywebdav/billing/models.py @@ -1,8 +1,8 @@ from tortoise import fields, models -from decimal import Decimal + class SubscriptionPlan(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) name = fields.CharField(max_length=100, unique=True) display_name = fields.CharField(max_length=255) description = fields.TextField(null=True) @@ -18,10 +18,13 @@ class SubscriptionPlan(models.Model): class Meta: table = "subscription_plans" + class UserSubscription(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) user = fields.ForeignKeyField("models.User", related_name="subscription") - plan = fields.ForeignKeyField("billing.SubscriptionPlan", related_name="subscriptions", null=True) + plan = fields.ForeignKeyField( + "billing.SubscriptionPlan", related_name="subscriptions", null=True + ) billing_type = fields.CharField(max_length=20, default="pay_as_you_go") stripe_customer_id = fields.CharField(max_length=255, unique=True, null=True) stripe_subscription_id = fields.CharField(max_length=255, unique=True, null=True) @@ -35,14 +38,15 @@ class UserSubscription(models.Model): class Meta: table = "user_subscriptions" + class UsageRecord(models.Model): - id = fields.BigIntField(pk=True) + id = fields.BigIntField(primary_key=True) user = fields.ForeignKeyField("models.User", related_name="usage_records") - record_type = fields.CharField(max_length=50, index=True) + record_type = fields.CharField(max_length=50, db_index=True) amount_bytes = fields.BigIntField() resource_type = fields.CharField(max_length=50, null=True) resource_id = fields.IntField(null=True) - timestamp = fields.DatetimeField(auto_now_add=True, index=True) + timestamp = fields.DatetimeField(auto_now_add=True, db_index=True) idempotency_key = fields.CharField(max_length=255, unique=True, null=True) metadata = fields.JSONField(null=True) @@ -50,8 +54,9 @@ class UsageRecord(models.Model): table = "usage_records" indexes = [("user_id", "record_type", "timestamp")] + class UsageAggregate(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) user = fields.ForeignKeyField("models.User", related_name="usage_aggregates") date = fields.DateField() storage_bytes_avg = fields.BigIntField(default=0) @@ -64,8 +69,9 @@ class UsageAggregate(models.Model): table = "usage_aggregates" unique_together = (("user", "date"),) + class Invoice(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) user = fields.ForeignKeyField("models.User", related_name="invoices") invoice_number = fields.CharField(max_length=50, unique=True) stripe_invoice_id = fields.CharField(max_length=255, unique=True, null=True) @@ -75,10 +81,10 @@ class Invoice(models.Model): tax = fields.DecimalField(max_digits=10, decimal_places=4, default=0) total = fields.DecimalField(max_digits=10, decimal_places=4) currency = fields.CharField(max_length=3, default="USD") - status = fields.CharField(max_length=50, default="draft", index=True) + status = fields.CharField(max_length=50, default="draft", db_index=True) due_date = fields.DateField(null=True) paid_at = fields.DatetimeField(null=True) - created_at = fields.DatetimeField(auto_now_add=True, index=True) + created_at = fields.DatetimeField(auto_now_add=True, db_index=True) updated_at = fields.DatetimeField(auto_now=True) metadata = fields.JSONField(null=True) @@ -86,8 +92,9 @@ class Invoice(models.Model): table = "invoices" indexes = [("user_id", "status", "created_at")] + class InvoiceLineItem(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) invoice = fields.ForeignKeyField("billing.Invoice", related_name="line_items") description = fields.TextField() quantity = fields.DecimalField(max_digits=15, decimal_places=6) @@ -100,20 +107,24 @@ class InvoiceLineItem(models.Model): class Meta: table = "invoice_line_items" + class PricingConfig(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) config_key = fields.CharField(max_length=100, unique=True) config_value = fields.DecimalField(max_digits=10, decimal_places=6) description = fields.TextField(null=True) unit = fields.CharField(max_length=50, null=True) - updated_by = fields.ForeignKeyField("models.User", related_name="pricing_updates", null=True) + updated_by = fields.ForeignKeyField( + "models.User", related_name="pricing_updates", null=True + ) updated_at = fields.DatetimeField(auto_now=True) class Meta: table = "pricing_config" + class PaymentMethod(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) user = fields.ForeignKeyField("models.User", related_name="payment_methods") stripe_payment_method_id = fields.CharField(max_length=255) type = fields.CharField(max_length=50) @@ -128,9 +139,12 @@ class PaymentMethod(models.Model): class Meta: table = "payment_methods" + class BillingEvent(models.Model): - id = fields.BigIntField(pk=True) - user = fields.ForeignKeyField("models.User", related_name="billing_events", null=True) + id = fields.BigIntField(primary_key=True) + user = fields.ForeignKeyField( + "models.User", related_name="billing_events", null=True + ) event_type = fields.CharField(max_length=100) stripe_event_id = fields.CharField(max_length=255, unique=True, null=True) data = fields.JSONField(null=True) diff --git a/mywebdav/billing/scheduler.py b/mywebdav/billing/scheduler.py index 7e086e0..cab5f28 100644 --- a/mywebdav/billing/scheduler.py +++ b/mywebdav/billing/scheduler.py @@ -1,7 +1,6 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger from datetime import datetime, date, timedelta -import asyncio from .usage_tracker import UsageTracker from .invoice_generator import InvoiceGenerator @@ -9,6 +8,7 @@ from ..models import User scheduler = AsyncIOScheduler() + async def aggregate_daily_usage_for_all_users(): users = await User.filter(is_active=True).all() yesterday = date.today() - timedelta(days=1) @@ -19,6 +19,7 @@ async def aggregate_daily_usage_for_all_users(): except Exception as e: print(f"Failed to aggregate usage for user {user.id}: {e}") + async def generate_monthly_invoices(): now = datetime.now() last_month = now.month - 1 if now.month > 1 else 12 @@ -28,19 +29,22 @@ async def generate_monthly_invoices(): for user in users: try: - invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, last_month) + invoice = await InvoiceGenerator.generate_monthly_invoice( + user, year, last_month + ) if invoice: await InvoiceGenerator.finalize_invoice(invoice) except Exception as e: print(f"Failed to generate invoice for user {user.id}: {e}") + def start_scheduler(): scheduler.add_job( aggregate_daily_usage_for_all_users, CronTrigger(hour=1, minute=0), id="aggregate_daily_usage", name="Aggregate daily usage for all users", - replace_existing=True + replace_existing=True, ) scheduler.add_job( @@ -48,10 +52,11 @@ def start_scheduler(): CronTrigger(day=1, hour=2, minute=0), id="generate_monthly_invoices", name="Generate monthly invoices", - replace_existing=True + replace_existing=True, ) scheduler.start() + def stop_scheduler(): scheduler.shutdown() diff --git a/mywebdav/billing/stripe_client.py b/mywebdav/billing/stripe_client.py index 6c98f86..58b8d53 100644 --- a/mywebdav/billing/stripe_client.py +++ b/mywebdav/billing/stripe_client.py @@ -1,8 +1,8 @@ import stripe -from decimal import Decimal -from typing import Optional, Dict, Any +from typing import Dict from ..settings import settings + class StripeClient: @staticmethod def _ensure_api_key(): @@ -11,13 +11,12 @@ class StripeClient: stripe.api_key = settings.STRIPE_SECRET_KEY else: raise ValueError("Stripe API key not configured") + @staticmethod async def create_customer(email: str, name: str, metadata: Dict = None) -> str: StripeClient._ensure_api_key() customer = stripe.Customer.create( - email=email, - name=name, - metadata=metadata or {} + email=email, name=name, metadata=metadata or {} ) return customer.id @@ -26,7 +25,7 @@ class StripeClient: amount: int, currency: str = "usd", customer_id: str = None, - metadata: Dict = None + metadata: Dict = None, ) -> stripe.PaymentIntent: StripeClient._ensure_api_key() return stripe.PaymentIntent.create( @@ -34,32 +33,29 @@ class StripeClient: currency=currency, customer=customer_id, metadata=metadata or {}, - automatic_payment_methods={"enabled": True} + automatic_payment_methods={"enabled": True}, ) @staticmethod async def create_invoice( - customer_id: str, - description: str, - line_items: list, - metadata: Dict = None + customer_id: str, description: str, line_items: list, metadata: Dict = None ) -> stripe.Invoice: StripeClient._ensure_api_key() for item in line_items: stripe.InvoiceItem.create( customer=customer_id, - amount=int(item['amount'] * 100), - currency=item.get('currency', 'usd'), - description=item['description'], - metadata=item.get('metadata', {}) + amount=int(item["amount"] * 100), + currency=item.get("currency", "usd"), + description=item["description"], + metadata=item.get("metadata", {}), ) invoice = stripe.Invoice.create( customer=customer_id, description=description, auto_advance=True, - collection_method='charge_automatically', - metadata=metadata or {} + collection_method="charge_automatically", + metadata=metadata or {}, ) return invoice @@ -76,18 +72,15 @@ class StripeClient: @staticmethod async def attach_payment_method( - payment_method_id: str, - customer_id: str + payment_method_id: str, customer_id: str ) -> stripe.PaymentMethod: StripeClient._ensure_api_key() payment_method = stripe.PaymentMethod.attach( - payment_method_id, - customer=customer_id + payment_method_id, customer=customer_id ) stripe.Customer.modify( - customer_id, - invoice_settings={'default_payment_method': payment_method_id} + customer_id, invoice_settings={"default_payment_method": payment_method_id} ) return payment_method @@ -95,22 +88,15 @@ class StripeClient: @staticmethod async def list_payment_methods(customer_id: str, type: str = "card"): StripeClient._ensure_api_key() - return stripe.PaymentMethod.list( - customer=customer_id, - type=type - ) + return stripe.PaymentMethod.list(customer=customer_id, type=type) @staticmethod async def create_subscription( - customer_id: str, - price_id: str, - metadata: Dict = None + customer_id: str, price_id: str, metadata: Dict = None ) -> stripe.Subscription: StripeClient._ensure_api_key() return stripe.Subscription.create( - customer=customer_id, - items=[{'price': price_id}], - metadata=metadata or {} + customer=customer_id, items=[{"price": price_id}], metadata=metadata or {} ) @staticmethod diff --git a/mywebdav/billing/usage_tracker.py b/mywebdav/billing/usage_tracker.py index f110d29..22eb2ec 100644 --- a/mywebdav/billing/usage_tracker.py +++ b/mywebdav/billing/usage_tracker.py @@ -1,11 +1,10 @@ import uuid from datetime import datetime, date, timezone -from decimal import Decimal -from typing import Optional from tortoise.transactions import in_transaction from .models import UsageRecord, UsageAggregate from ..models import User + class UsageTracker: @staticmethod async def track_storage( @@ -13,7 +12,7 @@ class UsageTracker: amount_bytes: int, resource_type: str = None, resource_id: int = None, - metadata: dict = None + metadata: dict = None, ): idempotency_key = f"storage_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}" @@ -24,7 +23,7 @@ class UsageTracker: resource_type=resource_type, resource_id=resource_id, idempotency_key=idempotency_key, - metadata=metadata + metadata=metadata, ) @staticmethod @@ -34,7 +33,7 @@ class UsageTracker: direction: str = "down", resource_type: str = None, resource_id: int = None, - metadata: dict = None + metadata: dict = None, ): record_type = f"bandwidth_{direction}" idempotency_key = f"{record_type}_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}" @@ -46,7 +45,7 @@ class UsageTracker: resource_type=resource_type, resource_id=resource_id, idempotency_key=idempotency_key, - metadata=metadata + metadata=metadata, ) @staticmethod @@ -61,24 +60,26 @@ class UsageTracker: user=user, record_type="storage", timestamp__gte=start_of_day, - timestamp__lte=end_of_day + timestamp__lte=end_of_day, ).all() - storage_avg = sum(r.amount_bytes for r in storage_records) // max(len(storage_records), 1) + storage_avg = sum(r.amount_bytes for r in storage_records) // max( + len(storage_records), 1 + ) storage_peak = max((r.amount_bytes for r in storage_records), default=0) bandwidth_up = await UsageRecord.filter( user=user, record_type="bandwidth_up", timestamp__gte=start_of_day, - timestamp__lte=end_of_day + timestamp__lte=end_of_day, ).all() bandwidth_down = await UsageRecord.filter( user=user, record_type="bandwidth_down", timestamp__gte=start_of_day, - timestamp__lte=end_of_day + timestamp__lte=end_of_day, ).all() total_up = sum(r.amount_bytes for r in bandwidth_up) @@ -92,8 +93,8 @@ class UsageTracker: "storage_bytes_avg": storage_avg, "storage_bytes_peak": storage_peak, "bandwidth_up_bytes": total_up, - "bandwidth_down_bytes": total_down - } + "bandwidth_down_bytes": total_down, + }, ) if not created: @@ -108,6 +109,7 @@ class UsageTracker: @staticmethod async def get_current_storage(user: User) -> int: from ..models import File + files = await File.filter(owner=user, is_deleted=False).all() return sum(f.size for f in files) @@ -121,9 +123,7 @@ class UsageTracker: end_date = date(year, month, last_day) aggregates = await UsageAggregate.filter( - user=user, - date__gte=start_date, - date__lte=end_date + user=user, date__gte=start_date, date__lte=end_date ).all() if not aggregates: @@ -132,7 +132,7 @@ class UsageTracker: "storage_gb_peak": 0, "bandwidth_up_gb": 0, "bandwidth_down_gb": 0, - "total_bandwidth_gb": 0 + "total_bandwidth_gb": 0, } storage_avg = sum(a.storage_bytes_avg for a in aggregates) / len(aggregates) @@ -145,5 +145,5 @@ class UsageTracker: "storage_gb_peak": round(storage_peak / (1024**3), 4), "bandwidth_up_gb": round(bandwidth_up / (1024**3), 4), "bandwidth_down_gb": round(bandwidth_down / (1024**3), 4), - "total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4) + "total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4), } diff --git a/mywebdav/legal.py b/mywebdav/legal.py index 90868a0..ea7afad 100644 --- a/mywebdav/legal.py +++ b/mywebdav/legal.py @@ -7,7 +7,7 @@ for various legal policies required for a European cloud storage provider. from abc import ABC, abstractmethod from datetime import datetime -from typing import Dict, Any +from typing import Dict class LegalDocument(ABC): @@ -17,10 +17,13 @@ class LegalDocument(ABC): Provides common structure and methods for generating legal content. """ - def __init__(self, company_name: str = "MyWebdav Technologies", - last_updated: str = None, - contact_email: str = "legal@mywebdav.eu", - website: str = "https://mywebdav.eu"): + def __init__( + self, + company_name: str = "MyWebdav Technologies", + last_updated: str = None, + contact_email: str = "legal@mywebdav.eu", + website: str = "https://mywebdav.eu", + ): self.company_name = company_name self.last_updated = last_updated or datetime.now().strftime("%B %d, %Y") self.contact_email = contact_email @@ -853,20 +856,21 @@ class ContactComplaintMechanism(LegalDocument): def get_all_legal_documents() -> Dict[str, LegalDocument]: """Return a dictionary of all legal document instances.""" return { - 'privacy_policy': PrivacyPolicy(), - 'terms_of_service': TermsOfService(), - 'security_policy': SecurityPolicy(), - 'cookie_policy': CookiePolicy(), - 'data_processing_agreement': DataProcessingAgreement(), - 'compliance_statement': ComplianceStatement(), - 'data_portability_deletion_policy': DataPortabilityDeletionPolicy(), - 'contact_complaint_mechanism': ContactComplaintMechanism(), + "privacy_policy": PrivacyPolicy(), + "terms_of_service": TermsOfService(), + "security_policy": SecurityPolicy(), + "cookie_policy": CookiePolicy(), + "data_processing_agreement": DataProcessingAgreement(), + "compliance_statement": ComplianceStatement(), + "data_portability_deletion_policy": DataPortabilityDeletionPolicy(), + "contact_complaint_mechanism": ContactComplaintMechanism(), } def generate_legal_documents(output_dir: str = "static/legal"): """Generate all legal documents as Markdown and HTML files.""" import os + os.makedirs(output_dir, exist_ok=True) documents = get_all_legal_documents() @@ -875,17 +879,17 @@ def generate_legal_documents(output_dir: str = "static/legal"): # Generate Markdown md_filename = f"{doc_name}.md" md_path = os.path.join(output_dir, md_filename) - with open(md_path, 'w') as f: + with open(md_path, "w") as f: f.write(doc.to_markdown()) # Generate HTML html_filename = f"{doc_name}.html" html_path = os.path.join(output_dir, html_filename) - with open(html_path, 'w') as f: + with open(html_path, "w") as f: f.write(doc.to_html()) print(f"Generated {md_filename} and {html_filename}") if __name__ == "__main__": - generate_legal_documents() \ No newline at end of file + generate_legal_documents() diff --git a/mywebdav/mail.py b/mywebdav/mail.py index ba6aac3..97ff1e8 100644 --- a/mywebdav/mail.py +++ b/mywebdav/mail.py @@ -2,17 +2,26 @@ import asyncio import aiosmtplib from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart -from typing import Optional, Dict, Any +from typing import Optional from .settings import settings + class EmailTask: - def __init__(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs): + def __init__( + self, + to_email: str, + subject: str, + body: str, + html: Optional[str] = None, + **kwargs, + ): self.to_email = to_email self.subject = subject self.body = body self.html = html self.kwargs = kwargs + class EmailService: def __init__(self): self.queue = asyncio.Queue() @@ -38,7 +47,14 @@ class EmailService: except asyncio.CancelledError: pass - async def send_email(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs): + async def send_email( + self, + to_email: str, + subject: str, + body: str, + html: Optional[str] = None, + **kwargs, + ): """Queue an email for sending""" task = EmailTask(to_email, subject, body, html, **kwargs) await self.queue.put(task) @@ -60,22 +76,26 @@ class EmailService: async def _send_email_task(self, task: EmailTask): """Send a single email task""" - if not settings.SMTP_HOST or not settings.SMTP_USERNAME or not settings.SMTP_PASSWORD: + if ( + not settings.SMTP_HOST + or not settings.SMTP_USERNAME + or not settings.SMTP_PASSWORD + ): print("SMTP not configured, skipping email send") return - msg = MIMEMultipart('alternative') - msg['From'] = settings.SMTP_SENDER_EMAIL - msg['To'] = task.to_email - msg['Subject'] = task.subject + msg = MIMEMultipart("alternative") + msg["From"] = settings.SMTP_SENDER_EMAIL + msg["To"] = task.to_email + msg["Subject"] = task.subject # Add text part - text_part = MIMEText(task.body, 'plain') + text_part = MIMEText(task.body, "plain") msg.attach(text_part) # Add HTML part if provided if task.html: - html_part = MIMEText(task.html, 'html') + html_part = MIMEText(task.html, "html") msg.attach(html_part) try: @@ -86,7 +106,7 @@ class EmailService: port=settings.SMTP_PORT, username=settings.SMTP_USERNAME, password=settings.SMTP_PASSWORD, - use_tls=True + use_tls=True, ) as smtp: await smtp.send_message(msg) print(f"Email sent to {task.to_email}") @@ -99,7 +119,10 @@ class EmailService: try: await smtp.starttls() except Exception as tls_error: - if "already using" in str(tls_error).lower() or "tls" in str(tls_error).lower(): + if ( + "already using" in str(tls_error).lower() + or "tls" in str(tls_error).lower() + ): # Connection is already using TLS, proceed without starttls pass else: @@ -111,14 +134,23 @@ class EmailService: print(f"Failed to send email to {task.to_email}: {e}") raise # Re-raise to let caller handle + # Global email service instance email_service = EmailService() + # Convenience functions -async def send_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs): +async def send_email( + to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs +): """Send an email asynchronously""" await email_service.send_email(to_email, subject, body, html, **kwargs) -def queue_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs): + +def queue_email( + to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs +): """Queue an email for sending (fire and forget)""" - asyncio.create_task(email_service.send_email(to_email, subject, body, html, **kwargs)) \ No newline at end of file + asyncio.create_task( + email_service.send_email(to_email, subject, body, html, **kwargs) + ) diff --git a/mywebdav/main.py b/mywebdav/main.py index 5e7a40d..9f3b211 100644 --- a/mywebdav/main.py +++ b/mywebdav/main.py @@ -2,18 +2,33 @@ import argparse import uvicorn import logging from contextlib import asynccontextmanager -from fastapi import FastAPI, Request, status, HTTPException +from fastapi import FastAPI, Request, HTTPException from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, JSONResponse from tortoise.contrib.fastapi import register_tortoise from .settings import settings -from .routers import auth, users, folders, files, shares, search, admin, starred, billing, admin_billing +from .routers import ( + auth, + users, + folders, + files, + shares, + search, + admin, + starred, + billing, + admin_billing, +) from . import webdav from .schemas import ErrorResponse +from .middleware.usage_tracking import UsageTrackingMiddleware -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) + @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Starting up...") @@ -21,6 +36,7 @@ async def lifespan(app: FastAPI): from .billing.scheduler import start_scheduler from .billing.models import PricingConfig from .mail import email_service + start_scheduler() logger.info("Billing scheduler started") await email_service.start() @@ -28,28 +44,61 @@ async def lifespan(app: FastAPI): pricing_count = await PricingConfig.all().count() if pricing_count == 0: from decimal import Decimal - await PricingConfig.create(config_key='storage_per_gb_month', config_value=Decimal('0.0045'), description='Storage cost per GB per month', unit='per_gb_month') - await PricingConfig.create(config_key='bandwidth_egress_per_gb', config_value=Decimal('0.009'), description='Bandwidth egress cost per GB', unit='per_gb') - await PricingConfig.create(config_key='bandwidth_ingress_per_gb', config_value=Decimal('0.0'), description='Bandwidth ingress cost per GB (free)', unit='per_gb') - await PricingConfig.create(config_key='free_tier_storage_gb', config_value=Decimal('15'), description='Free tier storage in GB', unit='gb') - await PricingConfig.create(config_key='free_tier_bandwidth_gb', config_value=Decimal('15'), description='Free tier bandwidth in GB per month', unit='gb') - await PricingConfig.create(config_key='tax_rate_default', config_value=Decimal('0.0'), description='Default tax rate (0 = no tax)', unit='percentage') + + await PricingConfig.create( + config_key="storage_per_gb_month", + config_value=Decimal("0.0045"), + description="Storage cost per GB per month", + unit="per_gb_month", + ) + await PricingConfig.create( + config_key="bandwidth_egress_per_gb", + config_value=Decimal("0.009"), + description="Bandwidth egress cost per GB", + unit="per_gb", + ) + await PricingConfig.create( + config_key="bandwidth_ingress_per_gb", + config_value=Decimal("0.0"), + description="Bandwidth ingress cost per GB (free)", + unit="per_gb", + ) + await PricingConfig.create( + config_key="free_tier_storage_gb", + config_value=Decimal("15"), + description="Free tier storage in GB", + unit="gb", + ) + await PricingConfig.create( + config_key="free_tier_bandwidth_gb", + config_value=Decimal("15"), + description="Free tier bandwidth in GB per month", + unit="gb", + ) + await PricingConfig.create( + config_key="tax_rate_default", + config_value=Decimal("0.0"), + description="Default tax rate (0 = no tax)", + unit="percentage", + ) logger.info("Default pricing configuration initialized") yield from .billing.scheduler import stop_scheduler + stop_scheduler() logger.info("Billing scheduler stopped") await email_service.stop() logger.info("Email service stopped") print("Shutting down...") + app = FastAPI( title="MyWebdav Cloud Storage", description="A commercial cloud storage web application", version="0.1.0", - lifespan=lifespan + lifespan=lifespan, ) app.include_router(auth.router) @@ -64,8 +113,6 @@ app.include_router(billing.router) app.include_router(admin_billing.router) app.include_router(webdav.router) -from .middleware.usage_tracking import UsageTrackingMiddleware - app.add_middleware(UsageTrackingMiddleware) app.mount("/static", StaticFiles(directory="static"), name="static") @@ -73,35 +120,39 @@ app.mount("/static", StaticFiles(directory="static"), name="static") register_tortoise( app, db_url=settings.DATABASE_URL, - modules={ - "models": ["mywebdav.models"], - "billing": ["mywebdav.billing.models"] - }, + modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]}, generate_schemas=True, add_exception_handlers=True, ) + @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): - logger.error(f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}") + logger.error( + f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}" + ) return JSONResponse( status_code=exc.status_code, content=ErrorResponse(code=exc.status_code, message=exc.detail).model_dump(), ) -@app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse +@app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse async def read_root(): with open("static/index.html", "r") as f: return f.read() + def main(): parser = argparse.ArgumentParser(description="Run the MyWebdav application.") - parser.add_argument("--host", type=str, default="0.0.0.0", help="Host address to bind to") + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Host address to bind to" + ) parser.add_argument("--port", type=int, default=8000, help="Port to listen on") args = parser.parse_args() uvicorn.run(app, host=args.host, port=args.port) + if __name__ == "__main__": main() diff --git a/mywebdav/middleware/usage_tracking.py b/mywebdav/middleware/usage_tracking.py index 854e709..274b3ef 100644 --- a/mywebdav/middleware/usage_tracking.py +++ b/mywebdav/middleware/usage_tracking.py @@ -2,31 +2,35 @@ from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware from ..billing.usage_tracker import UsageTracker + class UsageTrackingMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): response = await call_next(request) - if hasattr(request.state, 'user') and request.state.user: + if hasattr(request.state, "user") and request.state.user: user = request.state.user - if request.method in ['POST', 'PUT'] and '/files/upload' in request.url.path: - content_length = response.headers.get('content-length') + if ( + request.method in ["POST", "PUT"] + and "/files/upload" in request.url.path + ): + content_length = response.headers.get("content-length") if content_length: await UsageTracker.track_bandwidth( user=user, amount_bytes=int(content_length), - direction='up', - metadata={'path': request.url.path} + direction="up", + metadata={"path": request.url.path}, ) - elif request.method == 'GET' and '/files/download' in request.url.path: - content_length = response.headers.get('content-length') + elif request.method == "GET" and "/files/download" in request.url.path: + content_length = response.headers.get("content-length") if content_length: await UsageTracker.track_bandwidth( user=user, amount_bytes=int(content_length), - direction='down', - metadata={'path': request.url.path} + direction="down", + metadata={"path": request.url.path}, ) return response diff --git a/mywebdav/models.py b/mywebdav/models.py index fe6027a..11c9cf0 100644 --- a/mywebdav/models.py +++ b/mywebdav/models.py @@ -1,9 +1,9 @@ from tortoise import fields, models from tortoise.contrib.pydantic import pydantic_model_creator -from datetime import datetime + class User(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) username = fields.CharField(max_length=20, unique=True) email = fields.CharField(max_length=255, unique=True) hashed_password = fields.CharField(max_length=255) @@ -11,7 +11,9 @@ class User(models.Model): is_superuser = fields.BooleanField(default=False) created_at = fields.DatetimeField(auto_now_add=True) updated_at = fields.DatetimeField(auto_now=True) - storage_quota_bytes = fields.BigIntField(default=10 * 1024 * 1024 * 1024) # 10 GB default + storage_quota_bytes = fields.BigIntField( + default=10 * 1024 * 1024 * 1024 + ) # 10 GB default used_storage_bytes = fields.BigIntField(default=0) plan_type = fields.CharField(max_length=50, default="free") two_factor_secret = fields.CharField(max_length=255, null=True) @@ -24,11 +26,16 @@ class User(models.Model): def __str__(self): return self.username + class Folder(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) name = fields.CharField(max_length=255) - parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField("models.Folder", related_name="children", null=True) - owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="folders") + parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField( + "models.Folder", related_name="children", null=True + ) + owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", related_name="folders" + ) created_at = fields.DatetimeField(auto_now_add=True) updated_at = fields.DatetimeField(auto_now=True) is_deleted = fields.BooleanField(default=False) @@ -36,21 +43,28 @@ class Folder(models.Model): class Meta: table = "folders" - unique_together = (("name", "parent", "owner"),) # Ensure unique folder names within a parent for an owner + unique_together = ( + ("name", "parent", "owner"), + ) # Ensure unique folder names within a parent for an owner def __str__(self): return self.name + class File(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) name = fields.CharField(max_length=255) - path = fields.CharField(max_length=1024) # Internal storage path + path = fields.CharField(max_length=1024) # Internal storage path size = fields.BigIntField() mime_type = fields.CharField(max_length=255) - file_hash = fields.CharField(max_length=64, null=True) # SHA-256 + file_hash = fields.CharField(max_length=64, null=True) # SHA-256 thumbnail_path = fields.CharField(max_length=1024, null=True) - parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="files", null=True) - owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="files") + parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField( + "models.Folder", related_name="files", null=True + ) + owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", related_name="files" + ) created_at = fields.DatetimeField(auto_now_add=True) updated_at = fields.DatetimeField(auto_now=True) is_deleted = fields.BooleanField(default=False) @@ -60,14 +74,19 @@ class File(models.Model): class Meta: table = "files" - unique_together = (("name", "parent", "owner"),) # Ensure unique file names within a parent for an owner + unique_together = ( + ("name", "parent", "owner"), + ) # Ensure unique file names within a parent for an owner def __str__(self): return self.name + class FileVersion(models.Model): - id = fields.IntField(pk=True) - file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="versions") + id = fields.IntField(primary_key=True) + file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField( + "models.File", related_name="versions" + ) version_path = fields.CharField(max_length=1024) size = fields.BigIntField() created_at = fields.DatetimeField(auto_now_add=True) @@ -75,47 +94,69 @@ class FileVersion(models.Model): class Meta: table = "file_versions" + class Share(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) token = fields.CharField(max_length=64, unique=True) - file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="shares", null=True) - folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="shares", null=True) - owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="shares") + file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField( + "models.File", related_name="shares", null=True + ) + folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField( + "models.Folder", related_name="shares", null=True + ) + owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", related_name="shares" + ) created_at = fields.DatetimeField(auto_now_add=True) expires_at = fields.DatetimeField(null=True) password_protected = fields.BooleanField(default=False) hashed_password = fields.CharField(max_length=255, null=True) access_count = fields.IntField(default=0) - permission_level = fields.CharField(max_length=50, default="viewer") # viewer, uploader, editor + permission_level = fields.CharField( + max_length=50, default="viewer" + ) # viewer, uploader, editor class Meta: table = "shares" + class Team(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) name = fields.CharField(max_length=255, unique=True) - owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="owned_teams") - members: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User", related_name="teams", through="team_members") + owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", related_name="owned_teams" + ) + members: fields.ManyToManyRelation[User] = fields.ManyToManyField( + "models.User", related_name="teams", through="team_members" + ) created_at = fields.DatetimeField(auto_now_add=True) class Meta: table = "teams" + class TeamMember(models.Model): - id = fields.IntField(pk=True) - team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField("models.Team", related_name="team_members") - user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="user_teams") - role = fields.CharField(max_length=50, default="member") # owner, admin, member + id = fields.IntField(primary_key=True) + team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField( + "models.Team", related_name="team_members" + ) + user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", related_name="user_teams" + ) + role = fields.CharField(max_length=50, default="member") # owner, admin, member class Meta: table = "team_members" unique_together = (("team", "user"),) + class Activity(models.Model): - id = fields.IntField(pk=True) - user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="activities", null=True) + id = fields.IntField(primary_key=True) + user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", related_name="activities", null=True + ) action = fields.CharField(max_length=255) - target_type = fields.CharField(max_length=50) # file, folder, share, user, team + target_type = fields.CharField(max_length=50) # file, folder, share, user, team target_id = fields.IntField() ip_address = fields.CharField(max_length=45, null=True) timestamp = fields.DatetimeField(auto_now_add=True) @@ -123,13 +164,18 @@ class Activity(models.Model): class Meta: table = "activities" + class FileRequest(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) title = fields.CharField(max_length=255) description = fields.TextField(null=True) token = fields.CharField(max_length=64, unique=True) - owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="file_requests") - target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="file_requests") + owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( + "models.User", related_name="file_requests" + ) + target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField( + "models.Folder", related_name="file_requests" + ) created_at = fields.DatetimeField(auto_now_add=True) expires_at = fields.DatetimeField(null=True) is_active = fields.BooleanField(default=True) @@ -137,8 +183,9 @@ class FileRequest(models.Model): class Meta: table = "file_requests" + class WebDAVProperty(models.Model): - id = fields.IntField(pk=True) + id = fields.IntField(primary_key=True) resource_type = fields.CharField(max_length=10) resource_id = fields.IntField() namespace = fields.CharField(max_length=255) @@ -151,5 +198,8 @@ class WebDAVProperty(models.Model): table = "webdav_properties" unique_together = (("resource_type", "resource_id", "namespace", "name"),) + User_Pydantic = pydantic_model_creator(User, name="User_Pydantic") -UserIn_Pydantic = pydantic_model_creator(User, name="UserIn_Pydantic", exclude_readonly=True) +UserIn_Pydantic = pydantic_model_creator( + User, name="UserIn_Pydantic", exclude_readonly=True +) diff --git a/mywebdav/routers/admin.py b/mywebdav/routers/admin.py index 320861c..a6a7558 100644 --- a/mywebdav/routers/admin.py +++ b/mywebdav/routers/admin.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List from fastapi import APIRouter, Depends, HTTPException, status @@ -13,18 +13,25 @@ router = APIRouter( responses={403: {"description": "Not enough permissions"}}, ) + @router.get("/users", response_model=List[User_Pydantic]) async def get_all_users(): return await User.all() + @router.get("/users/{user_id}", response_model=User_Pydantic) async def get_user(user_id: int): user = await User.get_or_none(id=user_id) if not user: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" + ) return user -@router.post("/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED) + +@router.post( + "/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED +) async def create_user_by_admin(user_in: UserCreate): user = await User.get_or_none(username=user_in.username) if user: @@ -44,66 +51,81 @@ async def create_user_by_admin(user_in: UserCreate): username=user_in.username, email=user_in.email, hashed_password=hashed_password, - is_superuser=False, # Admin creates regular users by default + is_superuser=False, # Admin creates regular users by default is_active=True, ) return await User_Pydantic.from_tortoise_orm(user) + @router.put("/users/{user_id}", response_model=User_Pydantic) async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate): user = await User.get_or_none(id=user_id) if not user: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" + ) if user_update.username is not None and user_update.username != user.username: if await User.get_or_none(username=user_update.username): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken" + ) user.username = user_update.username - + if user_update.email is not None and user_update.email != user.email: if await User.get_or_none(email=user_update.email): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Email already registered", + ) user.email = user_update.email if user_update.password is not None: user.hashed_password = get_password_hash(user_update.password) - + if user_update.is_active is not None: user.is_active = user_update.is_active - + if user_update.is_superuser is not None: user.is_superuser = user_update.is_superuser - + if user_update.storage_quota_bytes is not None: user.storage_quota_bytes = user_update.storage_quota_bytes - + if user_update.plan_type is not None: user.plan_type = user_update.plan_type - + if user_update.is_2fa_enabled is not None: user.is_2fa_enabled = user_update.is_2fa_enabled if not user_update.is_2fa_enabled: user.two_factor_secret = None user.recovery_codes = None - + await user.save() return await User_Pydantic.from_tortoise_orm(user) + @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_user_by_admin(user_id: int): user = await User.get_or_none(id=user_id) if not user: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="User not found" + ) await user.delete() return {"message": "User deleted successfully"} + @router.post("/test-email") -async def send_test_email(to_email: str, subject: str = "Test Email", body: str = "This is a test email"): +async def send_test_email( + to_email: str, subject: str = "Test Email", body: str = "This is a test email" +): from ..mail import queue_email + queue_email( to_email=to_email, subject=subject, body=body, - html=f"

{subject}

{body}

" + html=f"

{subject}

{body}

", ) - return {"message": "Test email queued"} \ No newline at end of file + return {"message": "Test email queued"} diff --git a/mywebdav/routers/admin_billing.py b/mywebdav/routers/admin_billing.py index 54049cf..2f77be5 100644 --- a/mywebdav/routers/admin_billing.py +++ b/mywebdav/routers/admin_billing.py @@ -1,5 +1,4 @@ from fastapi import APIRouter, Depends, HTTPException -from typing import List from decimal import Decimal from pydantic import BaseModel @@ -8,21 +7,25 @@ from ..models import User from ..billing.models import PricingConfig, Invoice, SubscriptionPlan from ..billing.invoice_generator import InvoiceGenerator + def require_superuser(current_user: User = Depends(get_current_user)): if not current_user.is_superuser: raise HTTPException(status_code=403, detail="Superuser privileges required") return current_user + router = APIRouter( prefix="/api/admin/billing", tags=["admin", "billing"], - dependencies=[Depends(require_superuser)] + dependencies=[Depends(require_superuser)], ) + class PricingConfigUpdate(BaseModel): config_key: str config_value: float + class PlanCreate(BaseModel): name: str display_name: str @@ -32,6 +35,7 @@ class PlanCreate(BaseModel): price_monthly: float price_yearly: float = None + @router.get("/pricing") async def get_all_pricing(current_user: User = Depends(require_superuser)): configs = await PricingConfig.all() @@ -42,16 +46,17 @@ async def get_all_pricing(current_user: User = Depends(require_superuser)): "config_value": float(c.config_value), "description": c.description, "unit": c.unit, - "updated_at": c.updated_at + "updated_at": c.updated_at, } for c in configs ] + @router.put("/pricing/{config_id}") async def update_pricing( config_id: int, update: PricingConfigUpdate, - current_user: User = Depends(require_superuser) + current_user: User = Depends(require_superuser), ): config = await PricingConfig.get_or_none(id=config_id) if not config: @@ -63,11 +68,10 @@ async def update_pricing( return {"message": "Pricing updated successfully"} + @router.post("/generate-invoices/{year}/{month}") async def generate_all_invoices( - year: int, - month: int, - current_user: User = Depends(require_superuser) + year: int, month: int, current_user: User = Depends(require_superuser) ): users = await User.filter(is_active=True).all() generated = [] @@ -76,35 +80,36 @@ async def generate_all_invoices( for user in users: invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, month) if invoice: - generated.append({ - "user_id": user.id, - "invoice_id": invoice.id, - "total": float(invoice.total) - }) + generated.append( + { + "user_id": user.id, + "invoice_id": invoice.id, + "total": float(invoice.total), + } + ) else: skipped.append(user.id) - return { - "generated": len(generated), - "skipped": len(skipped), - "invoices": generated - } + return {"generated": len(generated), "skipped": len(skipped), "invoices": generated} + @router.post("/plans") async def create_plan( - plan_data: PlanCreate, - current_user: User = Depends(require_superuser) + plan_data: PlanCreate, current_user: User = Depends(require_superuser) ): plan = await SubscriptionPlan.create(**plan_data.dict()) return {"id": plan.id, "message": "Plan created successfully"} + @router.get("/stats") async def get_billing_stats(current_user: User = Depends(require_superuser)): - from tortoise.functions import Sum, Count + from tortoise.functions import Sum - total_revenue = await Invoice.filter(status="paid").annotate( - total_sum=Sum("total") - ).values("total_sum") + total_revenue = ( + await Invoice.filter(status="paid") + .annotate(total_sum=Sum("total")) + .values("total_sum") + ) invoice_count = await Invoice.all().count() pending_invoices = await Invoice.filter(status="open").count() @@ -112,5 +117,5 @@ async def get_billing_stats(current_user: User = Depends(require_superuser)): return { "total_revenue": float(total_revenue[0]["total_sum"] or 0), "total_invoices": invoice_count, - "pending_invoices": pending_invoices + "pending_invoices": pending_invoices, } diff --git a/mywebdav/routers/auth.py b/mywebdav/routers/auth.py index ffd8fd1..500c708 100644 --- a/mywebdav/routers/auth.py +++ b/mywebdav/routers/auth.py @@ -4,13 +4,23 @@ from typing import Optional, List from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel -from ..auth import authenticate_user, create_access_token, get_password_hash, get_current_user, get_current_verified_user, verify_password +from ..auth import ( + authenticate_user, + create_access_token, + get_password_hash, + get_current_user, + get_current_verified_user, + verify_password, +) from ..models import User -from ..schemas import Token, UserCreate, TokenData, UserLoginWith2FA +from ..schemas import Token, UserCreate from ..two_factor import ( - generate_totp_secret, generate_totp_uri, generate_qr_code_base64, - verify_totp_code, generate_recovery_codes, hash_recovery_codes, - verify_recovery_codes + generate_totp_secret, + generate_totp_uri, + generate_qr_code_base64, + verify_totp_code, + generate_recovery_codes, + hash_recovery_codes, ) router = APIRouter( @@ -18,27 +28,33 @@ router = APIRouter( tags=["auth"], ) + class LoginRequest(BaseModel): username: str password: str + class TwoFactorLogin(BaseModel): username: str password: str two_factor_code: Optional[str] = None + class TwoFactorSetupResponse(BaseModel): secret: str qr_code_base64: str recovery_codes: List[str] + class TwoFactorCode(BaseModel): two_factor_code: str + class TwoFactorDisable(BaseModel): password: str two_factor_code: str + @router.post("/register", response_model=Token) async def register_user(user_in: UserCreate): user = await User.get_or_none(username=user_in.username) @@ -63,22 +79,26 @@ async def register_user(user_in: UserCreate): # Send welcome email from ..mail import queue_email + queue_email( to_email=user.email, subject="Welcome to MyWebdav!", body=f"Hi {user.username},\n\nWelcome to MyWebdav! Your account has been created successfully.\n\nBest regards,\nThe MyWebdav Team", - html=f"

Welcome to MyWebdav!

Hi {user.username},

Welcome to MyWebdav! Your account has been created successfully.

Best regards,
The MyWebdav Team

" + html=f"

Welcome to MyWebdav!

Hi {user.username},

Welcome to MyWebdav! Your account has been created successfully.

Best regards,
The MyWebdav Team

", ) - access_token_expires = timedelta(minutes=30) # Use settings + access_token_expires = timedelta(minutes=30) # Use settings access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return {"access_token": access_token, "token_type": "bearer"} + @router.post("/token", response_model=Token) async def login_for_access_token(login_data: LoginRequest): - auth_result = await authenticate_user(login_data.username, login_data.password, None) + auth_result = await authenticate_user( + login_data.username, login_data.password, None + ) if not auth_result: raise HTTPException( @@ -97,16 +117,26 @@ async def login_for_access_token(login_data: LoginRequest): access_token_expires = timedelta(minutes=30) access_token = create_access_token( - data={"sub": user.username}, expires_delta=access_token_expires, two_factor_verified=True + data={"sub": user.username}, + expires_delta=access_token_expires, + two_factor_verified=True, ) return {"access_token": access_token, "token_type": "bearer"} + @router.post("/2fa/setup", response_model=TwoFactorSetupResponse) -async def setup_two_factor_authentication(current_user: User = Depends(get_current_user)): +async def setup_two_factor_authentication( + current_user: User = Depends(get_current_user), +): if current_user.is_2fa_enabled: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled." + ) if current_user.two_factor_secret: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup already initiated. Verify or disable first.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="2FA setup already initiated. Verify or disable first.", + ) secret = generate_totp_secret() current_user.two_factor_secret = secret @@ -120,39 +150,66 @@ async def setup_two_factor_authentication(current_user: User = Depends(get_curre current_user.recovery_codes = ",".join(hashed_recovery_codes) await current_user.save() - return TwoFactorSetupResponse(secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes) + return TwoFactorSetupResponse( + secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes + ) + @router.post("/2fa/verify", response_model=Token) -async def verify_two_factor_authentication(two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)): +async def verify_two_factor_authentication( + two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user) +): if current_user.is_2fa_enabled: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled." + ) if not current_user.two_factor_secret: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated." + ) - if not verify_totp_code(current_user.two_factor_secret, two_factor_code_data.two_factor_code): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.") + if not verify_totp_code( + current_user.two_factor_secret, two_factor_code_data.two_factor_code + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code." + ) current_user.is_2fa_enabled = True await current_user.save() - access_token_expires = timedelta(minutes=30) # Use settings + access_token_expires = timedelta(minutes=30) # Use settings access_token = create_access_token( - data={"sub": current_user.username}, expires_delta=access_token_expires, two_factor_verified=True + data={"sub": current_user.username}, + expires_delta=access_token_expires, + two_factor_verified=True, ) return {"access_token": access_token, "token_type": "bearer"} + @router.post("/2fa/disable", response_model=dict) -async def disable_two_factor_authentication(disable_data: TwoFactorDisable, current_user: User = Depends(get_current_verified_user)): +async def disable_two_factor_authentication( + disable_data: TwoFactorDisable, + current_user: User = Depends(get_current_verified_user), +): if not current_user.is_2fa_enabled: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled." + ) # Verify password if not verify_password(disable_data.password, current_user.hashed_password): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password.") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password." + ) # Verify 2FA code - if not verify_totp_code(current_user.two_factor_secret, disable_data.two_factor_code): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.") + if not verify_totp_code( + current_user.two_factor_secret, disable_data.two_factor_code + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code." + ) current_user.two_factor_secret = None current_user.is_2fa_enabled = False @@ -161,10 +218,15 @@ async def disable_two_factor_authentication(disable_data: TwoFactorDisable, curr return {"message": "2FA disabled successfully."} + @router.get("/2fa/recovery-codes", response_model=List[str]) -async def get_new_recovery_codes(current_user: User = Depends(get_current_verified_user)): +async def get_new_recovery_codes( + current_user: User = Depends(get_current_verified_user), +): if not current_user.is_2fa_enabled: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled." + ) recovery_codes = generate_recovery_codes() hashed_recovery_codes = hash_recovery_codes(recovery_codes) diff --git a/mywebdav/routers/billing.py b/mywebdav/routers/billing.py index 4eb8ddd..8a69b27 100644 --- a/mywebdav/routers/billing.py +++ b/mywebdav/routers/billing.py @@ -1,25 +1,25 @@ -from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import JSONResponse from typing import List, Optional from datetime import datetime, date -from decimal import Decimal -import calendar from ..auth import get_current_user from ..models import User from ..billing.models import ( - Invoice, InvoiceLineItem, UserSubscription, PricingConfig, - PaymentMethod, UsageAggregate, SubscriptionPlan + Invoice, + UserSubscription, + PricingConfig, + PaymentMethod, + UsageAggregate, + SubscriptionPlan, ) from ..billing.usage_tracker import UsageTracker from ..billing.invoice_generator import InvoiceGenerator from ..billing.stripe_client import StripeClient from pydantic import BaseModel -router = APIRouter( - prefix="/api/billing", - tags=["billing"] -) +router = APIRouter(prefix="/api/billing", tags=["billing"]) + class UsageResponse(BaseModel): storage_gb_avg: float @@ -29,6 +29,7 @@ class UsageResponse(BaseModel): total_bandwidth_gb: float period: str + class InvoiceResponse(BaseModel): id: int invoice_number: str @@ -42,6 +43,7 @@ class InvoiceResponse(BaseModel): paid_at: Optional[datetime] line_items: List[dict] + class SubscriptionResponse(BaseModel): id: int billing_type: str @@ -50,6 +52,7 @@ class SubscriptionResponse(BaseModel): current_period_start: Optional[datetime] current_period_end: Optional[datetime] + @router.get("/usage/current") async def get_current_usage(current_user: User = Depends(get_current_user)): try: @@ -61,25 +64,32 @@ async def get_current_usage(current_user: User = Depends(get_current_user)): if usage_today: return { "storage_gb": round(storage_bytes / (1024**3), 4), - "bandwidth_down_gb_today": round(usage_today.bandwidth_down_bytes / (1024**3), 4), - "bandwidth_up_gb_today": round(usage_today.bandwidth_up_bytes / (1024**3), 4), - "as_of": today.isoformat() + "bandwidth_down_gb_today": round( + usage_today.bandwidth_down_bytes / (1024**3), 4 + ), + "bandwidth_up_gb_today": round( + usage_today.bandwidth_up_bytes / (1024**3), 4 + ), + "as_of": today.isoformat(), } return { "storage_gb": round(storage_bytes / (1024**3), 4), "bandwidth_down_gb_today": 0, "bandwidth_up_gb_today": 0, - "as_of": today.isoformat() + "as_of": today.isoformat(), } except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to fetch usage data: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to fetch usage data: {str(e)}" + ) + @router.get("/usage/monthly") async def get_monthly_usage( year: Optional[int] = None, month: Optional[int] = None, - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), ) -> UsageResponse: try: if year is None or month is None: @@ -88,71 +98,85 @@ async def get_monthly_usage( month = now.month if not (1 <= month <= 12): - raise HTTPException(status_code=400, detail="Month must be between 1 and 12") + raise HTTPException( + status_code=400, detail="Month must be between 1 and 12" + ) if not (2020 <= year <= 2100): - raise HTTPException(status_code=400, detail="Year must be between 2020 and 2100") + raise HTTPException( + status_code=400, detail="Year must be between 2020 and 2100" + ) usage = await UsageTracker.get_monthly_usage(current_user, year, month) - return UsageResponse( - **usage, - period=f"{year}-{month:02d}" - ) + return UsageResponse(**usage, period=f"{year}-{month:02d}") except HTTPException: raise except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}" + ) + @router.get("/invoices") async def list_invoices( - limit: int = 50, - offset: int = 0, - current_user: User = Depends(get_current_user) + limit: int = 50, offset: int = 0, current_user: User = Depends(get_current_user) ) -> List[InvoiceResponse]: try: if limit < 1 or limit > 100: - raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") + raise HTTPException( + status_code=400, detail="Limit must be between 1 and 100" + ) if offset < 0: raise HTTPException(status_code=400, detail="Offset must be non-negative") - invoices = await Invoice.filter(user=current_user).order_by("-created_at").offset(offset).limit(limit).all() + invoices = ( + await Invoice.filter(user=current_user) + .order_by("-created_at") + .offset(offset) + .limit(limit) + .all() + ) result = [] for invoice in invoices: line_items = await invoice.line_items.all() - result.append(InvoiceResponse( - id=invoice.id, - invoice_number=invoice.invoice_number, - period_start=invoice.period_start, - period_end=invoice.period_end, - subtotal=float(invoice.subtotal), - tax=float(invoice.tax), - total=float(invoice.total), - status=invoice.status, - due_date=invoice.due_date, - paid_at=invoice.paid_at, - line_items=[ - { - "description": item.description, - "quantity": float(item.quantity), - "unit_price": float(item.unit_price), - "amount": float(item.amount), - "type": item.item_type - } - for item in line_items - ] - )) + result.append( + InvoiceResponse( + id=invoice.id, + invoice_number=invoice.invoice_number, + period_start=invoice.period_start, + period_end=invoice.period_end, + subtotal=float(invoice.subtotal), + tax=float(invoice.tax), + total=float(invoice.total), + status=invoice.status, + due_date=invoice.due_date, + paid_at=invoice.paid_at, + line_items=[ + { + "description": item.description, + "quantity": float(item.quantity), + "unit_price": float(item.unit_price), + "amount": float(item.amount), + "type": item.item_type, + } + for item in line_items + ], + ) + ) return result except HTTPException: raise except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to fetch invoices: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to fetch invoices: {str(e)}" + ) + @router.get("/invoices/{invoice_id}") async def get_invoice( - invoice_id: int, - current_user: User = Depends(get_current_user) + invoice_id: int, current_user: User = Depends(get_current_user) ) -> InvoiceResponse: invoice = await Invoice.get_or_none(id=invoice_id, user=current_user) if not invoice: @@ -177,21 +201,22 @@ async def get_invoice( "quantity": float(item.quantity), "unit_price": float(item.unit_price), "amount": float(item.amount), - "type": item.item_type + "type": item.item_type, } for item in line_items - ] + ], ) + @router.get("/subscription") -async def get_subscription(current_user: User = Depends(get_current_user)) -> SubscriptionResponse: +async def get_subscription( + current_user: User = Depends(get_current_user), +) -> SubscriptionResponse: subscription = await UserSubscription.get_or_none(user=current_user) if not subscription: subscription = await UserSubscription.create( - user=current_user, - billing_type="pay_as_you_go", - status="active" + user=current_user, billing_type="pay_as_you_go", status="active" ) plan_name = None @@ -205,15 +230,19 @@ async def get_subscription(current_user: User = Depends(get_current_user)) -> Su plan_name=plan_name, status=subscription.status, current_period_start=subscription.current_period_start, - current_period_end=subscription.current_period_end + current_period_end=subscription.current_period_end, ) + @router.post("/payment-methods/setup-intent") async def create_setup_intent(current_user: User = Depends(get_current_user)): try: from ..settings import settings + if not settings.STRIPE_SECRET_KEY: - raise HTTPException(status_code=503, detail="Payment processing not configured") + raise HTTPException( + status_code=503, detail="Payment processing not configured" + ) subscription = await UserSubscription.get_or_none(user=current_user) @@ -221,7 +250,7 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)): customer_id = await StripeClient.create_customer( email=current_user.email, name=current_user.username, - metadata={"user_id": str(current_user.id)} + metadata={"user_id": str(current_user.id)}, ) if not subscription: @@ -229,27 +258,30 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)): user=current_user, billing_type="pay_as_you_go", stripe_customer_id=customer_id, - status="active" + status="active", ) else: subscription.stripe_customer_id = customer_id await subscription.save() import stripe + StripeClient._ensure_api_key() setup_intent = stripe.SetupIntent.create( - customer=subscription.stripe_customer_id, - payment_method_types=["card"] + customer=subscription.stripe_customer_id, payment_method_types=["card"] ) return { "client_secret": setup_intent.client_secret, - "customer_id": subscription.stripe_customer_id + "customer_id": subscription.stripe_customer_id, } except HTTPException: raise except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to create setup intent: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to create setup intent: {str(e)}" + ) + @router.get("/payment-methods") async def list_payment_methods(current_user: User = Depends(get_current_user)): @@ -262,11 +294,12 @@ async def list_payment_methods(current_user: User = Depends(get_current_user)): "brand": m.brand, "exp_month": m.exp_month, "exp_year": m.exp_year, - "is_default": m.is_default + "is_default": m.is_default, } for m in methods ] + @router.post("/webhooks/stripe") async def stripe_webhook(request: Request): import stripe @@ -298,12 +331,14 @@ async def stripe_webhook(request: Request): event_type=event["type"], stripe_event_id=event_id, data=event["data"], - processed=False + processed=False, ) if event["type"] == "invoice.payment_succeeded": invoice_data = event["data"]["object"] - mywebdav_invoice_id = invoice_data.get("metadata", {}).get("mywebdav_invoice_id") + mywebdav_invoice_id = invoice_data.get("metadata", {}).get( + "mywebdav_invoice_id" + ) if mywebdav_invoice_id: invoice = await Invoice.get_or_none(id=int(mywebdav_invoice_id)) @@ -317,7 +352,9 @@ async def stripe_webhook(request: Request): payment_method = event["data"]["object"] customer_id = payment_method["customer"] - subscription = await UserSubscription.get_or_none(stripe_customer_id=customer_id) + subscription = await UserSubscription.get_or_none( + stripe_customer_id=customer_id + ) if subscription: await PaymentMethod.create( user=subscription.user, @@ -327,7 +364,7 @@ async def stripe_webhook(request: Request): brand=payment_method.get("card", {}).get("brand"), exp_month=payment_method.get("card", {}).get("exp_month"), exp_year=payment_method.get("card", {}).get("exp_year"), - is_default=True + is_default=True, ) await BillingEvent.filter(stripe_event_id=event_id).update(processed=True) @@ -335,7 +372,10 @@ async def stripe_webhook(request: Request): except HTTPException: raise except Exception as e: - raise HTTPException(status_code=500, detail=f"Webhook processing failed: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Webhook processing failed: {str(e)}" + ) + @router.get("/pricing") async def get_pricing(): @@ -344,11 +384,12 @@ async def get_pricing(): config.config_key: { "value": float(config.config_value), "description": config.description, - "unit": config.unit + "unit": config.unit, } for config in configs } + @router.get("/plans") async def list_plans(): plans = await SubscriptionPlan.filter(is_active=True).all() @@ -361,14 +402,16 @@ async def list_plans(): "storage_gb": plan.storage_gb, "bandwidth_gb": plan.bandwidth_gb, "price_monthly": float(plan.price_monthly), - "price_yearly": float(plan.price_yearly) if plan.price_yearly else None + "price_yearly": float(plan.price_yearly) if plan.price_yearly else None, } for plan in plans ] + @router.get("/stripe-key") async def get_stripe_key(): from ..settings import settings + if not settings.STRIPE_PUBLISHABLE_KEY: raise HTTPException(status_code=503, detail="Payment processing not configured") return {"publishable_key": settings.STRIPE_PUBLISHABLE_KEY} diff --git a/mywebdav/routers/files.py b/mywebdav/routers/files.py index 626be85..b0c5348 100644 --- a/mywebdav/routers/files.py +++ b/mywebdav/routers/files.py @@ -1,4 +1,12 @@ -from fastapi import APIRouter, Depends, UploadFile, File as FastAPIFile, HTTPException, status, Response, Form +from fastapi import ( + APIRouter, + Depends, + UploadFile, + File as FastAPIFile, + HTTPException, + status, + Form, +) from fastapi.responses import StreamingResponse from typing import List, Optional import mimetypes @@ -11,7 +19,6 @@ from ..auth import get_current_user from ..models import User, File, Folder from ..schemas import FileOut from ..storage import storage_manager -from ..settings import settings from ..activity import log_activity from ..thumbnails import generate_thumbnail, delete_thumbnail @@ -20,35 +27,46 @@ router = APIRouter( tags=["files"], ) + class FileMove(BaseModel): target_folder_id: Optional[int] = None + class FileRename(BaseModel): new_name: str + class FileCopy(BaseModel): target_folder_id: Optional[int] = None + class BatchFileOperation(BaseModel): file_ids: List[int] - operation: str # e.g., "delete", "star", "unstar", "move", "copy" + operation: str # e.g., "delete", "star", "unstar", "move", "copy" + class BatchMoveCopyPayload(BaseModel): target_folder_id: Optional[int] = None + class FileContentUpdate(BaseModel): content: str + @router.post("/upload", response_model=FileOut, status_code=status.HTTP_201_CREATED) async def upload_file( file: UploadFile = FastAPIFile(...), folder_id: Optional[int] = Form(None), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), ): if folder_id: - parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) + parent_folder = await Folder.get_or_none( + id=folder_id, owner=current_user, is_deleted=False + ) if not parent_folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" + ) else: parent_folder = None @@ -73,7 +91,7 @@ async def upload_file( # Generate a unique path for storage file_extension = os.path.splitext(file.filename)[1] - unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename + unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename storage_path = unique_filename # Save file to storage @@ -105,16 +123,20 @@ async def upload_file( return await FileOut.from_tortoise_orm(db_file) + @router.get("/download/{file_id}") async def download_file(file_id: int, current_user: User = Depends(get_current_user)): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) db_file.last_accessed_at = datetime.now() await db_file.save() try: + async def file_iterator(): async for chunk in storage_manager.get_file(current_user.id, db_file.path): yield chunk @@ -122,16 +144,21 @@ async def download_file(file_id: int, current_user: User = Depends(get_current_u return StreamingResponse( file_iterator(), media_type=db_file.mime_type, - headers={"Content-Disposition": f"attachment; filename=\"{db_file.name}\""} + headers={"Content-Disposition": f'attachment; filename="{db_file.name}"'}, ) except FileNotFoundError: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage" + ) + @router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_file(file_id: int, current_user: User = Depends(get_current_user)): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) db_file.is_deleted = True db_file.deleted_at = datetime.now() @@ -141,68 +168,110 @@ async def delete_file(file_id: int, current_user: User = Depends(get_current_use return + @router.post("/{file_id}/move", response_model=FileOut) -async def move_file(file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)): +async def move_file( + file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user) +): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) target_folder = None if move_data.target_folder_id: - target_folder = await Folder.get_or_none(id=move_data.target_folder_id, owner=current_user, is_deleted=False) + target_folder = await Folder.get_or_none( + id=move_data.target_folder_id, owner=current_user, is_deleted=False + ) if not target_folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found" + ) existing_file = await File.get_or_none( name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False ) if existing_file and existing_file.id != file_id: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in target folder") + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="File with this name already exists in target folder", + ) db_file.parent = target_folder await db_file.save() - await log_activity(user=current_user, action="file_moved", target_type="file", target_id=file_id) + await log_activity( + user=current_user, action="file_moved", target_type="file", target_id=file_id + ) return await FileOut.from_tortoise_orm(db_file) + @router.post("/{file_id}/rename", response_model=FileOut) -async def rename_file(file_id: int, rename_data: FileRename, current_user: User = Depends(get_current_user)): +async def rename_file( + file_id: int, + rename_data: FileRename, + current_user: User = Depends(get_current_user), +): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) existing_file = await File.get_or_none( - name=rename_data.new_name, parent_id=db_file.parent_id, owner=current_user, is_deleted=False + name=rename_data.new_name, + parent_id=db_file.parent_id, + owner=current_user, + is_deleted=False, ) if existing_file and existing_file.id != file_id: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in the same folder") + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="File with this name already exists in the same folder", + ) db_file.name = rename_data.new_name await db_file.save() - await log_activity(user=current_user, action="file_renamed", target_type="file", target_id=file_id) + await log_activity( + user=current_user, action="file_renamed", target_type="file", target_id=file_id + ) return await FileOut.from_tortoise_orm(db_file) -@router.post("/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED) -async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)): + +@router.post( + "/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED +) +async def copy_file( + file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user) +): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) target_folder = None if copy_data.target_folder_id: - target_folder = await Folder.get_or_none(id=copy_data.target_folder_id, owner=current_user, is_deleted=False) + target_folder = await Folder.get_or_none( + id=copy_data.target_folder_id, owner=current_user, is_deleted=False + ) if not target_folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found" + ) base_name = db_file.name name_parts = os.path.splitext(base_name) counter = 1 new_name = base_name - while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False): + while await File.get_or_none( + name=new_name, parent=target_folder, owner=current_user, is_deleted=False + ): new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}" counter += 1 @@ -213,139 +282,205 @@ async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depe mime_type=db_file.mime_type, file_hash=db_file.file_hash, owner=current_user, - parent=target_folder + parent=target_folder, ) - await log_activity(user=current_user, action="file_copied", target_type="file", target_id=new_file.id) + await log_activity( + user=current_user, + action="file_copied", + target_type="file", + target_id=new_file.id, + ) return await FileOut.from_tortoise_orm(new_file) + @router.get("/", response_model=List[FileOut]) -async def list_files(folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)): +async def list_files( + folder_id: Optional[int] = None, current_user: User = Depends(get_current_user) +): if folder_id: - parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) + parent_folder = await Folder.get_or_none( + id=folder_id, owner=current_user, is_deleted=False + ) if not parent_folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") - files = await File.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" + ) + files = await File.filter( + parent=parent_folder, owner=current_user, is_deleted=False + ).order_by("name") else: - files = await File.filter(parent=None, owner=current_user, is_deleted=False).order_by("name") + files = await File.filter( + parent=None, owner=current_user, is_deleted=False + ).order_by("name") return [await FileOut.from_tortoise_orm(f) for f in files] + @router.get("/thumbnail/{file_id}") async def get_thumbnail(file_id: int, current_user: User = Depends(get_current_user)): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) db_file.last_accessed_at = datetime.now() await db_file.save() - thumbnail_path = getattr(db_file, 'thumbnail_path', None) + thumbnail_path = getattr(db_file, "thumbnail_path", None) if not thumbnail_path: - thumbnail_path = await generate_thumbnail(db_file.path, db_file.mime_type, current_user.id) + thumbnail_path = await generate_thumbnail( + db_file.path, db_file.mime_type, current_user.id + ) if thumbnail_path: db_file.thumbnail_path = thumbnail_path await db_file.save() else: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available" + ) try: + async def thumbnail_iterator(): - async for chunk in storage_manager.get_file(current_user.id, thumbnail_path): + async for chunk in storage_manager.get_file( + current_user.id, thumbnail_path + ): yield chunk - return StreamingResponse( - thumbnail_iterator(), - media_type="image/jpeg" - ) + return StreamingResponse(thumbnail_iterator(), media_type="image/jpeg") except FileNotFoundError: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not found in storage") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Thumbnail not found in storage", + ) + @router.get("/photos", response_model=List[FileOut]) async def list_photos(current_user: User = Depends(get_current_user)): files = await File.filter( - owner=current_user, - is_deleted=False, - mime_type__istartswith="image/" + owner=current_user, is_deleted=False, mime_type__istartswith="image/" ).order_by("-created_at") return [await FileOut.from_tortoise_orm(f) for f in files] + @router.get("/recent", response_model=List[FileOut]) -async def list_recent_files(current_user: User = Depends(get_current_user), limit: int = 10): - files = await File.filter( - owner=current_user, - is_deleted=False, - last_accessed_at__isnull=False - ).order_by("-last_accessed_at").limit(limit) +async def list_recent_files( + current_user: User = Depends(get_current_user), limit: int = 10 +): + files = ( + await File.filter( + owner=current_user, is_deleted=False, last_accessed_at__isnull=False + ) + .order_by("-last_accessed_at") + .limit(limit) + ) return [await FileOut.from_tortoise_orm(f) for f in files] + @router.post("/{file_id}/star", response_model=FileOut) async def star_file(file_id: int, current_user: User = Depends(get_current_user)): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) db_file.is_starred = True await db_file.save() - await log_activity(user=current_user, action="file_starred", target_type="file", target_id=file_id) + await log_activity( + user=current_user, action="file_starred", target_type="file", target_id=file_id + ) return await FileOut.from_tortoise_orm(db_file) + @router.post("/{file_id}/unstar", response_model=FileOut) async def unstar_file(file_id: int, current_user: User = Depends(get_current_user)): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) db_file.is_starred = False await db_file.save() - await log_activity(user=current_user, action="file_unstarred", target_type="file", target_id=file_id) + await log_activity( + user=current_user, + action="file_unstarred", + target_type="file", + target_id=file_id, + ) return await FileOut.from_tortoise_orm(db_file) + @router.get("/deleted", response_model=List[FileOut]) async def list_deleted_files(current_user: User = Depends(get_current_user)): - files = await File.filter(owner=current_user, is_deleted=True).order_by("-deleted_at") + files = await File.filter(owner=current_user, is_deleted=True).order_by( + "-deleted_at" + ) return [await FileOut.from_tortoise_orm(f) for f in files] + @router.post("/{file_id}/restore", response_model=FileOut) async def restore_file(file_id: int, current_user: User = Depends(get_current_user)): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=True) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found" + ) + # Check if a file with the same name exists in the parent folder existing_file = await File.get_or_none( name=db_file.name, parent=db_file.parent, owner=current_user, is_deleted=False ) if existing_file: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.") + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.", + ) db_file.is_deleted = False db_file.deleted_at = None await db_file.save() - await log_activity(user=current_user, action="file_restored", target_type="file", target_id=file_id) + await log_activity( + user=current_user, action="file_restored", target_type="file", target_id=file_id + ) return await FileOut.from_tortoise_orm(db_file) + class BatchOperationResult(BaseModel): succeeded: List[FileOut] failed: List[dict] + @router.post("/batch") async def batch_file_operations( batch_operation: BatchFileOperation, payload: Optional[BatchMoveCopyPayload] = None, - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), ): if batch_operation.operation not in ["delete", "star", "unstar", "move", "copy"]: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid operation: {batch_operation.operation}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid operation: {batch_operation.operation}", + ) updated_files = [] failed_operations = [] for file_id in batch_operation.file_ids: try: - db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) + db_file = await File.get_or_none( + id=file_id, owner=current_user, is_deleted=False + ) if not db_file: - failed_operations.append({"file_id": file_id, "reason": "File not found or not owned by user"}) + failed_operations.append( + { + "file_id": file_id, + "reason": "File not found or not owned by user", + } + ) continue if batch_operation.operation == "delete": @@ -353,47 +488,87 @@ async def batch_file_operations( db_file.deleted_at = datetime.now() await db_file.save() await delete_thumbnail(db_file.id) - await log_activity(user=current_user, action="file_deleted_batch", target_type="file", target_id=file_id) + await log_activity( + user=current_user, + action="file_deleted_batch", + target_type="file", + target_id=file_id, + ) updated_files.append(db_file) elif batch_operation.operation == "star": db_file.is_starred = True await db_file.save() - await log_activity(user=current_user, action="file_starred_batch", target_type="file", target_id=file_id) + await log_activity( + user=current_user, + action="file_starred_batch", + target_type="file", + target_id=file_id, + ) updated_files.append(db_file) elif batch_operation.operation == "unstar": db_file.is_starred = False await db_file.save() - await log_activity(user=current_user, action="file_unstarred_batch", target_type="file", target_id=file_id) + await log_activity( + user=current_user, + action="file_unstarred_batch", + target_type="file", + target_id=file_id, + ) updated_files.append(db_file) elif batch_operation.operation == "move": if not payload or payload.target_folder_id is None: - failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"}) + failed_operations.append( + {"file_id": file_id, "reason": "Target folder not specified"} + ) continue - target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False) + target_folder = await Folder.get_or_none( + id=payload.target_folder_id, owner=current_user, is_deleted=False + ) if not target_folder: - failed_operations.append({"file_id": file_id, "reason": "Target folder not found"}) + failed_operations.append( + {"file_id": file_id, "reason": "Target folder not found"} + ) continue existing_file = await File.get_or_none( - name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False + name=db_file.name, + parent=target_folder, + owner=current_user, + is_deleted=False, ) if existing_file and existing_file.id != file_id: - failed_operations.append({"file_id": file_id, "reason": "File with same name exists in target folder"}) + failed_operations.append( + { + "file_id": file_id, + "reason": "File with same name exists in target folder", + } + ) continue db_file.parent = target_folder await db_file.save() - await log_activity(user=current_user, action="file_moved_batch", target_type="file", target_id=file_id) + await log_activity( + user=current_user, + action="file_moved_batch", + target_type="file", + target_id=file_id, + ) updated_files.append(db_file) elif batch_operation.operation == "copy": if not payload or payload.target_folder_id is None: - failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"}) + failed_operations.append( + {"file_id": file_id, "reason": "Target folder not specified"} + ) continue - target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False) + target_folder = await Folder.get_or_none( + id=payload.target_folder_id, owner=current_user, is_deleted=False + ) if not target_folder: - failed_operations.append({"file_id": file_id, "reason": "Target folder not found"}) + failed_operations.append( + {"file_id": file_id, "reason": "Target folder not found"} + ) continue base_name = db_file.name @@ -401,7 +576,12 @@ async def batch_file_operations( counter = 1 new_name = base_name - while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False): + while await File.get_or_none( + name=new_name, + parent=target_folder, + owner=current_user, + is_deleted=False, + ): new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}" counter += 1 @@ -412,48 +592,70 @@ async def batch_file_operations( mime_type=db_file.mime_type, file_hash=db_file.file_hash, owner=current_user, - parent=target_folder + parent=target_folder, + ) + await log_activity( + user=current_user, + action="file_copied_batch", + target_type="file", + target_id=new_file.id, ) - await log_activity(user=current_user, action="file_copied_batch", target_type="file", target_id=new_file.id) updated_files.append(new_file) except Exception as e: failed_operations.append({"file_id": file_id, "reason": str(e)}) return { "succeeded": [await FileOut.from_tortoise_orm(f) for f in updated_files], - "failed": failed_operations + "failed": failed_operations, } + @router.put("/{file_id}/content", response_model=FileOut) async def update_file_content( file_id: int, payload: FileContentUpdate, - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), ): db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) if not db_file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) - if not db_file.mime_type or not db_file.mime_type.startswith('text/'): + if not db_file.mime_type or not db_file.mime_type.startswith("text/"): editableExtensions = [ - 'txt', 'md', 'log', 'json', 'js', 'py', 'html', 'css', - 'xml', 'yaml', 'yml', 'sh', 'bat', 'ini', 'conf', 'cfg' + "txt", + "md", + "log", + "json", + "js", + "py", + "html", + "css", + "xml", + "yaml", + "yml", + "sh", + "bat", + "ini", + "conf", + "cfg", ] file_extension = os.path.splitext(db_file.name)[1][1:].lower() if file_extension not in editableExtensions: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="File type is not editable" + detail="File type is not editable", ) - content_bytes = payload.content.encode('utf-8') + content_bytes = payload.content.encode("utf-8") new_size = len(content_bytes) size_diff = new_size - db_file.size if current_user.used_storage_bytes + size_diff > current_user.storage_quota_bytes: raise HTTPException( status_code=status.HTTP_507_INSUFFICIENT_STORAGE, - detail="Storage quota exceeded" + detail="Storage quota exceeded", ) new_hash = hashlib.sha256(content_bytes).hexdigest() @@ -465,7 +667,7 @@ async def update_file_content( if new_storage_path != db_file.path: try: await storage_manager.delete_file(current_user.id, db_file.path) - except: + except Exception: pass db_file.path = new_storage_path @@ -477,6 +679,8 @@ async def update_file_content( current_user.used_storage_bytes += size_diff await current_user.save() - await log_activity(user=current_user, action="file_updated", target_type="file", target_id=file_id) + await log_activity( + user=current_user, action="file_updated", target_type="file", target_id=file_id + ) return await FileOut.from_tortoise_orm(db_file) diff --git a/mywebdav/routers/folders.py b/mywebdav/routers/folders.py index a74fce2..b2316cf 100644 --- a/mywebdav/routers/folders.py +++ b/mywebdav/routers/folders.py @@ -3,7 +3,13 @@ from typing import List, Optional from ..auth import get_current_user from ..models import User, Folder -from ..schemas import FolderCreate, FolderOut, FolderUpdate, BatchFolderOperation, BatchMoveCopyPayload +from ..schemas import ( + FolderCreate, + FolderOut, + FolderUpdate, + BatchFolderOperation, + BatchMoveCopyPayload, +) from ..activity import log_activity router = APIRouter( @@ -11,12 +17,17 @@ router = APIRouter( tags=["folders"], ) + @router.post("/", response_model=FolderOut, status_code=status.HTTP_201_CREATED) -async def create_folder(folder_in: FolderCreate, current_user: User = Depends(get_current_user)): +async def create_folder( + folder_in: FolderCreate, current_user: User = Depends(get_current_user) +): # Check if parent folder exists and belongs to the current user parent_folder = None if folder_in.parent_id: - parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user) + parent_folder = await Folder.get_or_none( + id=folder_in.parent_id, owner=current_user + ) if not parent_folder: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -39,50 +50,82 @@ async def create_folder(folder_in: FolderCreate, current_user: User = Depends(ge await log_activity(current_user, "folder_created", "folder", folder.id) return await FolderOut.from_tortoise_orm(folder) + @router.get("/{folder_id}/path", response_model=List[FolderOut]) -async def get_folder_path(folder_id: int, current_user: User = Depends(get_current_user)): - folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) +async def get_folder_path( + folder_id: int, current_user: User = Depends(get_current_user) +): + folder = await Folder.get_or_none( + id=folder_id, owner=current_user, is_deleted=False + ) if not folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" + ) path = [] current = folder while current: path.insert(0, await FolderOut.from_tortoise_orm(current)) if current.parent_id: - current = await Folder.get_or_none(id=current.parent_id, owner=current_user, is_deleted=False) + current = await Folder.get_or_none( + id=current.parent_id, owner=current_user, is_deleted=False + ) else: current = None return path + @router.get("/{folder_id}", response_model=FolderOut) async def get_folder(folder_id: int, current_user: User = Depends(get_current_user)): - folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) + folder = await Folder.get_or_none( + id=folder_id, owner=current_user, is_deleted=False + ) if not folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" + ) return await FolderOut.from_tortoise_orm(folder) + @router.get("/", response_model=List[FolderOut]) -async def list_folders(parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)): +async def list_folders( + parent_id: Optional[int] = None, current_user: User = Depends(get_current_user) +): if parent_id: - parent_folder = await Folder.get_or_none(id=parent_id, owner=current_user, is_deleted=False) + parent_folder = await Folder.get_or_none( + id=parent_id, owner=current_user, is_deleted=False + ) if not parent_folder: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Parent folder not found or does not belong to the current user", ) - folders = await Folder.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name") + folders = await Folder.filter( + parent=parent_folder, owner=current_user, is_deleted=False + ).order_by("name") else: # List root folders (folders with no parent) - folders = await Folder.filter(parent=None, owner=current_user, is_deleted=False).order_by("name") + folders = await Folder.filter( + parent=None, owner=current_user, is_deleted=False + ).order_by("name") return [await FolderOut.from_tortoise_orm(folder) for folder in folders] + @router.put("/{folder_id}", response_model=FolderOut) -async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: User = Depends(get_current_user)): - folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) +async def update_folder( + folder_id: int, + folder_in: FolderUpdate, + current_user: User = Depends(get_current_user), +): + folder = await Folder.get_or_none( + id=folder_id, owner=current_user, is_deleted=False + ) if not folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" + ) if folder_in.name: existing_folder = await Folder.get_or_none( @@ -97,11 +140,16 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U if folder_in.parent_id is not None: if folder_in.parent_id == folder_id: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot set folder as its own parent") - + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot set folder as its own parent", + ) + new_parent_folder = None - if folder_in.parent_id != 0: # 0 could represent moving to root - new_parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user, is_deleted=False) + if folder_in.parent_id != 0: # 0 could represent moving to root + new_parent_folder = await Folder.get_or_none( + id=folder_in.parent_id, owner=current_user, is_deleted=False + ) if not new_parent_folder: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -113,48 +161,66 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U await log_activity(current_user, "folder_updated", "folder", folder.id) return await FolderOut.from_tortoise_orm(folder) + @router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_folder(folder_id: int, current_user: User = Depends(get_current_user)): - folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) + folder = await Folder.get_or_none( + id=folder_id, owner=current_user, is_deleted=False + ) if not folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" + ) + folder.is_deleted = True await folder.save() await log_activity(current_user, "folder_deleted", "folder", folder.id) return + @router.post("/{folder_id}/star", response_model=FolderOut) async def star_folder(folder_id: int, current_user: User = Depends(get_current_user)): - db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) + db_folder = await Folder.get_or_none( + id=folder_id, owner=current_user, is_deleted=False + ) if not db_folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" + ) db_folder.is_starred = True await db_folder.save() await log_activity(current_user, "folder_starred", "folder", folder_id) return await FolderOut.from_tortoise_orm(db_folder) + @router.post("/{folder_id}/unstar", response_model=FolderOut) async def unstar_folder(folder_id: int, current_user: User = Depends(get_current_user)): - db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) + db_folder = await Folder.get_or_none( + id=folder_id, owner=current_user, is_deleted=False + ) if not db_folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" + ) db_folder.is_starred = False await db_folder.save() await log_activity(current_user, "folder_unstarred", "folder", folder_id) return await FolderOut.from_tortoise_orm(db_folder) + @router.post("/batch", response_model=List[FolderOut]) async def batch_folder_operations( batch_operation: BatchFolderOperation, payload: Optional[BatchMoveCopyPayload] = None, - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), ): updated_folders = [] for folder_id in batch_operation.folder_ids: - db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False) + db_folder = await Folder.get_or_none( + id=folder_id, owner=current_user, is_deleted=False + ) if not db_folder: - continue # Skip if folder not found or not owned by user + continue # Skip if folder not found or not owned by user if batch_operation.operation == "delete": db_folder.is_deleted = True @@ -171,13 +237,22 @@ async def batch_folder_operations( await db_folder.save() await log_activity(current_user, "folder_unstarred", "folder", folder_id) updated_folders.append(db_folder) - elif batch_operation.operation == "move" and payload and payload.target_folder_id is not None: - target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False) + elif ( + batch_operation.operation == "move" + and payload + and payload.target_folder_id is not None + ): + target_folder = await Folder.get_or_none( + id=payload.target_folder_id, owner=current_user, is_deleted=False + ) if not target_folder: continue existing_folder = await Folder.get_or_none( - name=db_folder.name, parent=target_folder, owner=current_user, is_deleted=False + name=db_folder.name, + parent=target_folder, + owner=current_user, + is_deleted=False, ) if existing_folder and existing_folder.id != folder_id: continue @@ -186,5 +261,5 @@ async def batch_folder_operations( await db_folder.save() await log_activity(current_user, "folder_moved", "folder", folder_id) updated_folders.append(db_folder) - + return [await FolderOut.from_tortoise_orm(f) for f in updated_folders] diff --git a/mywebdav/routers/search.py b/mywebdav/routers/search.py index 8c68c62..77c4cc0 100644 --- a/mywebdav/routers/search.py +++ b/mywebdav/routers/search.py @@ -11,15 +11,22 @@ router = APIRouter( tags=["search"], ) + @router.get("/files", response_model=List[FileOut]) async def search_files( q: str = Query(..., min_length=1, description="Search query"), - file_type: Optional[str] = Query(None, description="Filter by MIME type prefix (e.g., 'image', 'video')"), + file_type: Optional[str] = Query( + None, description="Filter by MIME type prefix (e.g., 'image', 'video')" + ), min_size: Optional[int] = Query(None, description="Minimum file size in bytes"), max_size: Optional[int] = Query(None, description="Maximum file size in bytes"), - date_from: Optional[datetime] = Query(None, description="Filter files created after this date"), - date_to: Optional[datetime] = Query(None, description="Filter files created before this date"), - current_user: User = Depends(get_current_user) + date_from: Optional[datetime] = Query( + None, description="Filter files created after this date" + ), + date_to: Optional[datetime] = Query( + None, description="Filter files created before this date" + ), + current_user: User = Depends(get_current_user), ): query = File.filter(owner=current_user, is_deleted=False, name__icontains=q) @@ -41,15 +48,16 @@ async def search_files( files = await query.order_by("-created_at").limit(100) return [await FileOut.from_tortoise_orm(f) for f in files] + @router.get("/folders", response_model=List[FolderOut]) async def search_folders( q: str = Query(..., min_length=1, description="Search query"), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), ): - folders = await Folder.filter( - owner=current_user, - is_deleted=False, - name__icontains=q - ).order_by("-created_at").limit(100) + folders = ( + await Folder.filter(owner=current_user, is_deleted=False, name__icontains=q) + .order_by("-created_at") + .limit(100) + ) return [await FolderOut.from_tortoise_orm(folder) for folder in folders] diff --git a/mywebdav/routers/shares.py b/mywebdav/routers/shares.py index 7c37dff..92c2f51 100644 --- a/mywebdav/routers/shares.py +++ b/mywebdav/routers/shares.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import StreamingResponse from typing import Optional, List import secrets -from datetime import datetime, timedelta +from datetime import datetime from ..auth import get_current_user from ..models import User, File, Folder, Share @@ -17,25 +17,44 @@ router = APIRouter( tags=["shares"], ) + @router.post("/", response_model=ShareOut, status_code=status.HTTP_201_CREATED) -async def create_share_link(share_in: ShareCreate, current_user: User = Depends(get_current_user)): +async def create_share_link( + share_in: ShareCreate, current_user: User = Depends(get_current_user) +): if not share_in.file_id and not share_in.folder_id: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Either file_id or folder_id must be provided") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Either file_id or folder_id must be provided", + ) if share_in.file_id and share_in.folder_id: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot share both a file and a folder simultaneously") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot share both a file and a folder simultaneously", + ) file = None folder = None if share_in.file_id: - file = await File.get_or_none(id=share_in.file_id, owner=current_user, is_deleted=False) + file = await File.get_or_none( + id=share_in.file_id, owner=current_user, is_deleted=False + ) if not file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found or does not belong to you") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="File not found or does not belong to you", + ) + if share_in.folder_id: - folder = await Folder.get_or_none(id=share_in.folder_id, owner=current_user, is_deleted=False) + folder = await Folder.get_or_none( + id=share_in.folder_id, owner=current_user, is_deleted=False + ) if not folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found or does not belong to you") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Folder not found or does not belong to you", + ) token = secrets.token_urlsafe(16) hashed_password = None @@ -60,8 +79,14 @@ async def create_share_link(share_in: ShareCreate, current_user: User = Depends( item_type = "file" if file else "folder" item_name = file.name if file else folder.name - expiry_text = f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}" if share_in.expires_at else "" - password_text = f"\n\nPassword: {share_in.password}" if share_in.password else "" + expiry_text = ( + f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}" + if share_in.expires_at + else "" + ) + password_text = ( + f"\n\nPassword: {share_in.password}" if share_in.password else "" + ) email_body = f"""Hello, @@ -95,80 +120,103 @@ MyWebdav File Sharing Service""" to_email=share_in.invite_email, subject=f"{current_user.username} shared {item_name} with you", body=email_body, - html=email_html + html=email_html, ) except Exception as e: print(f"Failed to send invitation email: {e}") return await ShareOut.from_tortoise_orm(share) + @router.get("/my", response_model=List[ShareOut]) async def list_my_shares(current_user: User = Depends(get_current_user)): shares = await Share.filter(owner=current_user).order_by("-created_at") return [await ShareOut.from_tortoise_orm(share) for share in shares] + @router.get("/{share_token}", response_model=ShareOut) async def get_share_link_info(share_token: str): share = await Share.get_or_none(token=share_token) if not share: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found" + ) + if share.expires_at and share.expires_at < datetime.utcnow(): - raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired") - + raise HTTPException( + status_code=status.HTTP_410_GONE, detail="Share link has expired" + ) + # Increment access count share.access_count += 1 await share.save() return await ShareOut.from_tortoise_orm(share) + @router.put("/{share_id}", response_model=ShareOut) -async def update_share(share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)): +async def update_share( + share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user) +): share = await Share.get_or_none(id=share_id, owner=current_user) if not share: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Share link not found or does not belong to you", + ) if share_in.expires_at is not None: share.expires_at = share_in.expires_at - + if share_in.password is not None: share.hashed_password = get_password_hash(share_in.password) share.password_protected = True - elif share_in.password == "": # Allow clearing password + elif share_in.password == "": # Allow clearing password share.hashed_password = None share.password_protected = False if share_in.permission_level is not None: share.permission_level = share_in.permission_level - + await share.save() return await ShareOut.from_tortoise_orm(share) + @router.post("/{share_token}/access") async def access_shared_content(share_token: str, password: Optional[str] = None): share = await Share.get_or_none(token=share_token) if not share: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found" + ) if share.expires_at and share.expires_at < datetime.utcnow(): - raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired") + raise HTTPException( + status_code=status.HTTP_410_GONE, detail="Share link has expired" + ) if share.password_protected: if not password or not verify_password(password, share.hashed_password): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password" + ) result = {"message": "Access granted", "permission_level": share.permission_level} if share.file_id: file = await File.get_or_none(id=share.file_id, is_deleted=False) if not file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) result["file"] = await FileOut.from_tortoise_orm(file) result["type"] = "file" elif share.folder_id: folder = await Folder.get_or_none(id=share.folder_id, is_deleted=False) if not folder: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" + ) result["folder"] = await FolderOut.from_tortoise_orm(folder) result["type"] = "folder" @@ -179,29 +227,42 @@ async def access_shared_content(share_token: str, password: Optional[str] = None return result + @router.get("/{share_token}/download") async def download_shared_file(share_token: str, password: Optional[str] = None): share = await Share.get_or_none(token=share_token) if not share: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found" + ) if share.expires_at and share.expires_at < datetime.utcnow(): - raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired") + raise HTTPException( + status_code=status.HTTP_410_GONE, detail="Share link has expired" + ) if share.password_protected: if not password or not verify_password(password, share.hashed_password): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password" + ) if not share.file_id: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="This share is not for a file") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="This share is not for a file", + ) file = await File.get_or_none(id=share.file_id, is_deleted=False) if not file: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) owner = await User.get(id=file.owner_id) try: + async def file_iterator(): async for chunk in storage_manager.get_file(owner.id, file.path): yield chunk @@ -209,18 +270,22 @@ async def download_shared_file(share_token: str, password: Optional[str] = None) return StreamingResponse( content=file_iterator(), media_type=file.mime_type, - headers={ - "Content-Disposition": f'attachment; filename="{file.name}"' - } + headers={"Content-Disposition": f'attachment; filename="{file.name}"'}, ) except FileNotFoundError: raise HTTPException(status_code=404, detail="File not found in storage") + @router.delete("/{share_id}", status_code=status.HTTP_204_NO_CONTENT) -async def delete_share_link(share_id: int, current_user: User = Depends(get_current_user)): +async def delete_share_link( + share_id: int, current_user: User = Depends(get_current_user) +): share = await Share.get_or_none(id=share_id, owner=current_user) if not share: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Share link not found or does not belong to you", + ) + await share.delete() return diff --git a/mywebdav/routers/starred.py b/mywebdav/routers/starred.py index d47a166..a44e31a 100644 --- a/mywebdav/routers/starred.py +++ b/mywebdav/routers/starred.py @@ -11,18 +11,29 @@ router = APIRouter( tags=["starred"], ) + @router.get("/files", response_model=List[FileOut]) async def list_starred_files(current_user: User = Depends(get_current_user)): - files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name") + files = await File.filter( + owner=current_user, is_starred=True, is_deleted=False + ).order_by("name") return [await FileOut.from_tortoise_orm(f) for f in files] + @router.get("/folders", response_model=List[FolderOut]) async def list_starred_folders(current_user: User = Depends(get_current_user)): - folders = await Folder.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name") + folders = await Folder.filter( + owner=current_user, is_starred=True, is_deleted=False + ).order_by("name") return [await FolderOut.from_tortoise_orm(f) for f in folders] -@router.get("/all", response_model=List[FileOut]) # This will return files and folders as files for now + +@router.get( + "/all", response_model=List[FileOut] +) # This will return files and folders as files for now async def list_all_starred(current_user: User = Depends(get_current_user)): - starred_files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name") + starred_files = await File.filter( + owner=current_user, is_starred=True, is_deleted=False + ).order_by("name") # For simplicity, we'll return files only for now. A more complex solution would involve a union or a custom schema. - return [await FileOut.from_tortoise_orm(f) for f in starred_files] \ No newline at end of file + return [await FileOut.from_tortoise_orm(f) for f in starred_files] diff --git a/mywebdav/routers/users.py b/mywebdav/routers/users.py index 9bb2010..ef134a0 100644 --- a/mywebdav/routers/users.py +++ b/mywebdav/routers/users.py @@ -1,17 +1,19 @@ from fastapi import APIRouter, Depends from ..auth import get_current_user from ..models import User_Pydantic, User, File, Folder -from typing import List, Dict, Any +from typing import Dict, Any router = APIRouter( prefix="/users", tags=["users"], ) + @router.get("/me", response_model=User_Pydantic) async def read_users_me(current_user: User = Depends(get_current_user)): return await User_Pydantic.from_tortoise_orm(current_user) + @router.get("/me/export", response_model=Dict[str, Any]) async def export_my_data(current_user: User = Depends(get_current_user)): """ @@ -35,6 +37,7 @@ async def export_my_data(current_user: User = Depends(get_current_user)): # share information, etc., would also be included. } + @router.delete("/me", status_code=204) async def delete_my_account(current_user: User = Depends(get_current_user)): """ @@ -48,5 +51,3 @@ async def delete_my_account(current_user: User = Depends(get_current_user)): # Finally, delete the user account await current_user.delete() return {} - - diff --git a/mywebdav/schemas.py b/mywebdav/schemas.py index 9eb52b9..db9e83f 100644 --- a/mywebdav/schemas.py +++ b/mywebdav/schemas.py @@ -2,19 +2,24 @@ from datetime import datetime from pydantic import BaseModel, EmailStr, ConfigDict from typing import Optional, List from tortoise.contrib.pydantic import pydantic_model_creator +from mywebdav.models import Folder, File, Share, FileVersion + class UserCreate(BaseModel): username: str email: EmailStr password: str + class UserLogin(BaseModel): username: str password: str + class UserLoginWith2FA(UserLogin): two_factor_code: Optional[str] = None + class UserAdminUpdate(BaseModel): username: Optional[str] = None email: Optional[EmailStr] = None @@ -25,22 +30,27 @@ class UserAdminUpdate(BaseModel): plan_type: Optional[str] = None is_2fa_enabled: Optional[bool] = None + class Token(BaseModel): access_token: str token_type: str + class TokenData(BaseModel): username: str | None = None two_factor_verified: bool = False + class FolderCreate(BaseModel): name: str parent_id: Optional[int] = None + class FolderUpdate(BaseModel): name: Optional[str] = None parent_id: Optional[int] = None + class ShareCreate(BaseModel): file_id: Optional[int] = None folder_id: Optional[int] = None @@ -49,9 +59,11 @@ class ShareCreate(BaseModel): permission_level: str = "viewer" invite_email: Optional[EmailStr] = None + class TeamCreate(BaseModel): name: str + class TeamOut(BaseModel): id: int name: str @@ -60,6 +72,7 @@ class TeamOut(BaseModel): model_config = ConfigDict(from_attributes=True) + class ActivityOut(BaseModel): id: int user_id: Optional[int] = None @@ -71,12 +84,14 @@ class ActivityOut(BaseModel): model_config = ConfigDict(from_attributes=True) + class FileRequestCreate(BaseModel): title: str description: Optional[str] = None target_folder_id: int expires_at: Optional[datetime] = None + class FileRequestOut(BaseModel): id: int title: str @@ -90,25 +105,28 @@ class FileRequestOut(BaseModel): model_config = ConfigDict(from_attributes=True) -from mywebdav.models import Folder, File, Share, FileVersion FolderOut = pydantic_model_creator(Folder, name="FolderOut") FileOut = pydantic_model_creator(File, name="FileOut") ShareOut = pydantic_model_creator(Share, name="ShareOut") FileVersionOut = pydantic_model_creator(FileVersion, name="FileVersionOut") + class ErrorResponse(BaseModel): code: int message: str details: Optional[str] = None + class BatchFileOperation(BaseModel): file_ids: List[int] - operation: str # e.g., "delete", "move", "copy", "star", "unstar" + operation: str # e.g., "delete", "move", "copy", "star", "unstar" + class BatchFolderOperation(BaseModel): folder_ids: List[int] - operation: str # e.g., "delete", "move", "star", "unstar" + operation: str # e.g., "delete", "move", "star", "unstar" + class BatchMoveCopyPayload(BaseModel): target_folder_id: Optional[int] = None diff --git a/mywebdav/settings.py b/mywebdav/settings.py index 076b4a3..3186ede 100644 --- a/mywebdav/settings.py +++ b/mywebdav/settings.py @@ -2,8 +2,9 @@ import os import sys from pydantic_settings import BaseSettings, SettingsConfigDict + class Settings(BaseSettings): - model_config = SettingsConfigDict(env_file='.env', extra='ignore') + model_config = SettingsConfigDict(env_file=".env", extra="ignore") DATABASE_URL: str = "sqlite:///app/mywebdav.db" REDIS_URL: str = "redis://redis:6379/0" @@ -30,8 +31,14 @@ class Settings(BaseSettings): STRIPE_WEBHOOK_SECRET: str = "" BILLING_ENABLED: bool = False + settings = Settings() -if settings.SECRET_KEY == "super_secret_key" and os.getenv("ENVIRONMENT") == "production": - print("ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable.") +if ( + settings.SECRET_KEY == "super_secret_key" + and os.getenv("ENVIRONMENT") == "production" +): + print( + "ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable." + ) sys.exit(1) diff --git a/mywebdav/storage.py b/mywebdav/storage.py index 94b8cc9..a13f40f 100644 --- a/mywebdav/storage.py +++ b/mywebdav/storage.py @@ -5,6 +5,7 @@ from typing import AsyncGenerator from .settings import settings + class StorageManager: def __init__(self, base_path: str = settings.STORAGE_PATH): self.base_path = Path(base_path) @@ -12,7 +13,11 @@ class StorageManager: async def _get_full_path(self, user_id: int, file_path: str) -> Path: # Ensure file_path is relative and safe - relative_path = Path(file_path).relative_to('/') if str(file_path).startswith('/') else Path(file_path) + relative_path = ( + Path(file_path).relative_to("/") + if str(file_path).startswith("/") + else Path(file_path) + ) full_path = self.base_path / str(user_id) / relative_path full_path.parent.mkdir(parents=True, exist_ok=True) return full_path @@ -27,7 +32,7 @@ class StorageManager: full_path = await self._get_full_path(user_id, file_path) if not full_path.exists(): raise FileNotFoundError(f"File not found: {file_path}") - + async with aiofiles.open(full_path, "rb") as f: while chunk := await f.read(8192): yield chunk @@ -52,4 +57,5 @@ class StorageManager: full_path = await self._get_full_path(user_id, file_path) return full_path.exists() + storage_manager = StorageManager() diff --git a/mywebdav/thumbnails.py b/mywebdav/thumbnails.py index 2fd8b8b..f13405c 100644 --- a/mywebdav/thumbnails.py +++ b/mywebdav/thumbnails.py @@ -1,4 +1,3 @@ -import os import asyncio from pathlib import Path from PIL import Image @@ -9,7 +8,10 @@ from .settings import settings THUMBNAIL_SIZE = (300, 300) THUMBNAIL_DIR = "thumbnails" -async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Optional[str]: + +async def generate_thumbnail( + file_path: str, mime_type: str, user_id: int +) -> Optional[str]: try: if mime_type.startswith("image/"): return await generate_image_thumbnail(file_path, user_id) @@ -20,6 +22,7 @@ async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Op print(f"Error generating thumbnail for {file_path}: {e}") return None + async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str]: loop = asyncio.get_event_loop() @@ -30,12 +33,16 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str file_name = Path(file_path).name thumbnail_name = f"thumb_{file_name}" - if not thumbnail_name.lower().endswith(('.jpg', '.jpeg', '.png')): + if not thumbnail_name.lower().endswith((".jpg", ".jpeg", ".png")): thumbnail_name += ".jpg" thumbnail_path = thumbnail_dir / thumbnail_name - actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path) + actual_file_path = ( + base_path / str(user_id) / file_path + if not Path(file_path).is_absolute() + else Path(file_path) + ) with Image.open(actual_file_path) as img: img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS) @@ -44,7 +51,9 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str background = Image.new("RGB", img.size, (255, 255, 255)) if img.mode == "P": img = img.convert("RGBA") - background.paste(img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None) + background.paste( + img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None + ) img = background img.save(str(thumbnail_path), "JPEG", quality=85, optimize=True) @@ -53,6 +62,7 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str return await loop.run_in_executor(None, _generate) + async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str]: loop = asyncio.get_event_loop() @@ -65,17 +75,29 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str thumbnail_name = f"thumb_{file_name}.jpg" thumbnail_path = thumbnail_dir / thumbnail_name - actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path) + actual_file_path = ( + base_path / str(user_id) / file_path + if not Path(file_path).is_absolute() + else Path(file_path) + ) - subprocess.run([ - "ffmpeg", - "-i", str(actual_file_path), - "-ss", "00:00:01", - "-vframes", "1", - "-vf", f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease", - "-y", - str(thumbnail_path) - ], check=True, capture_output=True) + subprocess.run( + [ + "ffmpeg", + "-i", + str(actual_file_path), + "-ss", + "00:00:01", + "-vframes", + "1", + "-vf", + f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease", + "-y", + str(thumbnail_path), + ], + check=True, + capture_output=True, + ) return str(thumbnail_path.relative_to(base_path / str(user_id))) @@ -84,6 +106,7 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str except subprocess.CalledProcessError: return None + async def delete_thumbnail(thumbnail_path: str, user_id: int): try: base_path = Path(settings.STORAGE_PATH) diff --git a/mywebdav/two_factor.py b/mywebdav/two_factor.py index f9c3c5a..b71b58a 100644 --- a/mywebdav/two_factor.py +++ b/mywebdav/two_factor.py @@ -5,15 +5,20 @@ import base64 import secrets import hashlib -from typing import List, Optional +from typing import List + def generate_totp_secret() -> str: """Generates a random base32 TOTP secret.""" return pyotp.random_base32() + def generate_totp_uri(secret: str, account_name: str, issuer_name: str) -> str: """Generates a Google Authenticator-compatible TOTP URI.""" - return pyotp.totp.TOTP(secret).provisioning_uri(name=account_name, issuer_name=issuer_name) + return pyotp.totp.TOTP(secret).provisioning_uri( + name=account_name, issuer_name=issuer_name + ) + def generate_qr_code_base64(uri: str) -> str: """Generates a base64 encoded QR code image for a given URI.""" @@ -31,27 +36,33 @@ def generate_qr_code_base64(uri: str) -> str: img.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode("utf-8") + def verify_totp_code(secret: str, code: str) -> bool: """Verifies a TOTP code against a secret.""" totp = pyotp.TOTP(secret) return totp.verify(code) + def generate_recovery_codes(num_codes: int = 10) -> List[str]: """Generates a list of random recovery codes.""" return [secrets.token_urlsafe(16) for _ in range(num_codes)] + def hash_recovery_code(code: str) -> str: """Hashes a single recovery code using SHA256.""" - return hashlib.sha256(code.encode('utf-8')).hexdigest() + return hashlib.sha256(code.encode("utf-8")).hexdigest() + def verify_recovery_code(plain_code: str, hashed_code: str) -> bool: """Verifies a plain recovery code against its hashed version.""" return hash_recovery_code(plain_code) == hashed_code + def hash_recovery_codes(codes: List[str]) -> List[str]: """Hashes a list of recovery codes.""" return [hash_recovery_code(code) for code in codes] + def verify_recovery_codes(plain_code: str, hashed_codes: List[str]) -> bool: """Verifies if a plain recovery code matches any of the hashed recovery codes.""" for hashed_code in hashed_codes: diff --git a/mywebdav/webdav.py b/mywebdav/webdav.py index b236917..f9ef5a9 100644 --- a/mywebdav/webdav.py +++ b/mywebdav/webdav.py @@ -19,6 +19,7 @@ router = APIRouter( tags=["webdav"], ) + class WebDAVLock: locks = {} @@ -26,10 +27,10 @@ class WebDAVLock: def create_lock(cls, path: str, user_id: int, timeout: int = 3600): lock_token = f"opaquelocktoken:{hashlib.md5(f'{path}{user_id}{datetime.now()}'.encode()).hexdigest()}" cls.locks[path] = { - 'token': lock_token, - 'user_id': user_id, - 'created_at': datetime.now(), - 'timeout': timeout + "token": lock_token, + "user_id": user_id, + "created_at": datetime.now(), + "timeout": timeout, } return lock_token @@ -42,17 +43,18 @@ class WebDAVLock: if path in cls.locks: del cls.locks[path] + async def basic_auth(authorization: Optional[str] = Header(None)): if not authorization: return None try: scheme, credentials = authorization.split() - if scheme.lower() != 'basic': + if scheme.lower() != "basic": return None - decoded = base64.b64decode(credentials).decode('utf-8') - username, password = decoded.split(':', 1) + decoded = base64.b64decode(credentials).decode("utf-8") + username, password = decoded.split(":", 1) user = await User.get_or_none(username=username) if user and verify_password(password, user.hashed_password): @@ -62,6 +64,7 @@ async def basic_auth(authorization: Optional[str] = Header(None)): return None + async def webdav_auth(request: Request, authorization: Optional[str] = Header(None)): user = await basic_auth(authorization) if user: @@ -75,14 +78,15 @@ async def webdav_auth(request: Request, authorization: Optional[str] = Header(No raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - headers={'WWW-Authenticate': 'Basic realm="MyWebdav WebDAV"'} + headers={"WWW-Authenticate": 'Basic realm="MyWebdav WebDAV"'}, ) + async def resolve_path(path_str: str, user: User): - if not path_str or path_str == '/': + if not path_str or path_str == "/": return None, None, True - parts = [p for p in path_str.split('/') if p] + parts = [p for p in path_str.split("/") if p] if not parts: return None, None, True @@ -90,10 +94,7 @@ async def resolve_path(path_str: str, user: User): current_folder = None for i, part in enumerate(parts[:-1]): folder = await Folder.get_or_none( - name=part, - parent=current_folder, - owner=user, - is_deleted=False + name=part, parent=current_folder, owner=user, is_deleted=False ) if not folder: return None, None, False @@ -102,39 +103,37 @@ async def resolve_path(path_str: str, user: User): last_part = parts[-1] folder = await Folder.get_or_none( - name=last_part, - parent=current_folder, - owner=user, - is_deleted=False + name=last_part, parent=current_folder, owner=user, is_deleted=False ) if folder: return folder, current_folder, True file = await File.get_or_none( - name=last_part, - parent=current_folder, - owner=user, - is_deleted=False + name=last_part, parent=current_folder, owner=user, is_deleted=False ) if file: return file, current_folder, True return None, current_folder, True + def build_href(base_path: str, name: str, is_collection: bool): path = f"{base_path.rstrip('/')}/{name}" if is_collection: - path += '/' + path += "/" return path + async def get_custom_properties(resource_type: str, resource_id: int): props = await WebDAVProperty.filter( - resource_type=resource_type, - resource_id=resource_id + resource_type=resource_type, resource_id=resource_id ) return {(prop.namespace, prop.name): prop.value for prop in props} -def create_propstat_element(props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"): + +def create_propstat_element( + props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK" +): propstat = ET.Element("D:propstat") prop = ET.SubElement(propstat, "D:prop") @@ -151,10 +150,10 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str elem.text = value elif key == "getlastmodified": elem = ET.SubElement(prop, f"D:{key}") - elem.text = value.strftime('%a, %d %b %Y %H:%M:%S GMT') + elem.text = value.strftime("%a, %d %b %Y %H:%M:%S GMT") elif key == "creationdate": elem = ET.SubElement(prop, f"D:{key}") - elem.text = value.isoformat() + 'Z' + elem.text = value.isoformat() + "Z" elif key == "displayname": elem = ET.SubElement(prop, f"D:{key}") elem.text = value @@ -174,6 +173,7 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str return propstat + def parse_propfind_body(body: bytes): if not body: return None @@ -193,8 +193,8 @@ def parse_propfind_body(body: bytes): if prop is not None: requested_props = [] for child in prop: - ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:" - name = child.tag.split('}')[1] if '}' in child.tag else child.tag + ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:" + name = child.tag.split("}")[1] if "}" in child.tag else child.tag requested_props.append((ns, name)) return requested_props except ET.ParseError: @@ -202,6 +202,7 @@ def parse_propfind_body(body: bytes): return None + @router.api_route("/{full_path:path}", methods=["OPTIONS"]) async def webdav_options(full_path: str): return Response( @@ -209,14 +210,17 @@ async def webdav_options(full_path: str): headers={ "DAV": "1, 2", "Allow": "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK", - "MS-Author-Via": "DAV" - } + "MS-Author-Via": "DAV", + }, ) + @router.api_route("/{full_path:path}", methods=["PROPFIND"]) -async def handle_propfind(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): +async def handle_propfind( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): depth = request.headers.get("Depth", "1") - full_path = unquote(full_path).strip('/') + full_path = unquote(full_path).strip("/") body = await request.body() requested_props = parse_propfind_body(body) @@ -232,19 +236,23 @@ async def handle_propfind(request: Request, full_path: str, current_user: User = if resource is None: response = ET.SubElement(multistatus, "D:response") href = ET.SubElement(response, "D:href") - href.text = base_href if base_href.endswith('/') else base_href + '/' + href.text = base_href if base_href.endswith("/") else base_href + "/" props = { "resourcetype": "collection", - "displayname": full_path.split('/')[-1] if full_path else "Root", + "displayname": full_path.split("/")[-1] if full_path else "Root", "creationdate": datetime.now(), - "getlastmodified": datetime.now() + "getlastmodified": datetime.now(), } response.append(create_propstat_element(props)) if depth in ["1", "infinity"]: - folders = await Folder.filter(owner=current_user, parent=parent, is_deleted=False) - files = await File.filter(owner=current_user, parent=parent, is_deleted=False) + folders = await Folder.filter( + owner=current_user, parent=parent, is_deleted=False + ) + files = await File.filter( + owner=current_user, parent=parent, is_deleted=False + ) for folder in folders: response = ET.SubElement(multistatus, "D:response") @@ -255,9 +263,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User = "resourcetype": "collection", "displayname": folder.name, "creationdate": folder.created_at, - "getlastmodified": folder.updated_at + "getlastmodified": folder.updated_at, } - custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None + custom_props = ( + await get_custom_properties("folder", folder.id) + if requested_props == "allprop" + or ( + isinstance(requested_props, list) + and any(ns != "DAV:" for ns, _ in requested_props) + ) + else None + ) response.append(create_propstat_element(props, custom_props)) for file in files: @@ -272,28 +288,48 @@ async def handle_propfind(request: Request, full_path: str, current_user: User = "getcontenttype": file.mime_type, "creationdate": file.created_at, "getlastmodified": file.updated_at, - "getetag": f'"{file.file_hash}"' + "getetag": f'"{file.file_hash}"', } - custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None + custom_props = ( + await get_custom_properties("file", file.id) + if requested_props == "allprop" + or ( + isinstance(requested_props, list) + and any(ns != "DAV:" for ns, _ in requested_props) + ) + else None + ) response.append(create_propstat_element(props, custom_props)) elif isinstance(resource, Folder): response = ET.SubElement(multistatus, "D:response") href = ET.SubElement(response, "D:href") - href.text = base_href if base_href.endswith('/') else base_href + '/' + href.text = base_href if base_href.endswith("/") else base_href + "/" props = { "resourcetype": "collection", "displayname": resource.name, "creationdate": resource.created_at, - "getlastmodified": resource.updated_at + "getlastmodified": resource.updated_at, } - custom_props = await get_custom_properties("folder", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None + custom_props = ( + await get_custom_properties("folder", resource.id) + if requested_props == "allprop" + or ( + isinstance(requested_props, list) + and any(ns != "DAV:" for ns, _ in requested_props) + ) + else None + ) response.append(create_propstat_element(props, custom_props)) if depth in ["1", "infinity"]: - folders = await Folder.filter(owner=current_user, parent=resource, is_deleted=False) - files = await File.filter(owner=current_user, parent=resource, is_deleted=False) + folders = await Folder.filter( + owner=current_user, parent=resource, is_deleted=False + ) + files = await File.filter( + owner=current_user, parent=resource, is_deleted=False + ) for folder in folders: response = ET.SubElement(multistatus, "D:response") @@ -304,9 +340,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User = "resourcetype": "collection", "displayname": folder.name, "creationdate": folder.created_at, - "getlastmodified": folder.updated_at + "getlastmodified": folder.updated_at, } - custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None + custom_props = ( + await get_custom_properties("folder", folder.id) + if requested_props == "allprop" + or ( + isinstance(requested_props, list) + and any(ns != "DAV:" for ns, _ in requested_props) + ) + else None + ) response.append(create_propstat_element(props, custom_props)) for file in files: @@ -321,9 +365,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User = "getcontenttype": file.mime_type, "creationdate": file.created_at, "getlastmodified": file.updated_at, - "getetag": f'"{file.file_hash}"' + "getetag": f'"{file.file_hash}"', } - custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None + custom_props = ( + await get_custom_properties("file", file.id) + if requested_props == "allprop" + or ( + isinstance(requested_props, list) + and any(ns != "DAV:" for ns, _ in requested_props) + ) + else None + ) response.append(create_propstat_element(props, custom_props)) elif isinstance(resource, File): @@ -338,17 +390,32 @@ async def handle_propfind(request: Request, full_path: str, current_user: User = "getcontenttype": resource.mime_type, "creationdate": resource.created_at, "getlastmodified": resource.updated_at, - "getetag": f'"{resource.file_hash}"' + "getetag": f'"{resource.file_hash}"', } - custom_props = await get_custom_properties("file", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None + custom_props = ( + await get_custom_properties("file", resource.id) + if requested_props == "allprop" + or ( + isinstance(requested_props, list) + and any(ns != "DAV:" for ns, _ in requested_props) + ) + else None + ) response.append(create_propstat_element(props, custom_props)) xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True) - return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207) + return Response( + content=xml_content, + media_type="application/xml; charset=utf-8", + status_code=207, + ) + @router.api_route("/{full_path:path}", methods=["GET", "HEAD"]) -async def handle_get(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): - full_path = unquote(full_path).strip('/') +async def handle_get( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): + full_path = unquote(full_path).strip("/") resource, parent, exists = await resolve_path(full_path, current_user) @@ -363,8 +430,10 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe "Content-Length": str(resource.size), "Content-Type": resource.mime_type, "ETag": f'"{resource.file_hash}"', - "Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT') - } + "Last-Modified": resource.updated_at.strftime( + "%a, %d %b %Y %H:%M:%S GMT" + ), + }, ) async def file_iterator(): @@ -377,23 +446,28 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe headers={ "Content-Disposition": f'attachment; filename="{resource.name}"', "ETag": f'"{resource.file_hash}"', - "Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT') - } + "Last-Modified": resource.updated_at.strftime( + "%a, %d %b %Y %H:%M:%S GMT" + ), + }, ) except FileNotFoundError: raise HTTPException(status_code=404, detail="File not found in storage") + @router.api_route("/{full_path:path}", methods=["PUT"]) -async def handle_put(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): - full_path = unquote(full_path).strip('/') +async def handle_put( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): + full_path = unquote(full_path).strip("/") if not full_path: raise HTTPException(status_code=400, detail="Cannot PUT to root") - parts = [p for p in full_path.split('/') if p] + parts = [p for p in full_path.split("/") if p] file_name = parts[-1] - parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else '' + parent_path = "/".join(parts[:-1]) if len(parts) > 1 else "" _, parent_folder, exists = await resolve_path(parent_path, current_user) if not exists: @@ -417,10 +491,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe mime_type = "application/octet-stream" existing_file = await File.get_or_none( - name=file_name, - parent=parent_folder, - owner=current_user, - is_deleted=False + name=file_name, parent=parent_folder, owner=current_user, is_deleted=False ) if existing_file: @@ -432,7 +503,9 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe existing_file.updated_at = datetime.now() await existing_file.save() - current_user.used_storage_bytes = current_user.used_storage_bytes - old_size + file_size + current_user.used_storage_bytes = ( + current_user.used_storage_bytes - old_size + file_size + ) await current_user.save() await log_activity(current_user, "file_updated", "file", existing_file.id) @@ -445,7 +518,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe mime_type=mime_type, file_hash=file_hash, owner=current_user, - parent=parent_folder + parent=parent_folder, ) current_user.used_storage_bytes += file_size @@ -454,9 +527,12 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe await log_activity(current_user, "file_created", "file", db_file.id) return Response(status_code=201) + @router.api_route("/{full_path:path}", methods=["DELETE"]) -async def handle_delete(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): - full_path = unquote(full_path).strip('/') +async def handle_delete( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): + full_path = unquote(full_path).strip("/") if not full_path: raise HTTPException(status_code=400, detail="Cannot DELETE root") @@ -478,44 +554,45 @@ async def handle_delete(request: Request, full_path: str, current_user: User = D return Response(status_code=204) + @router.api_route("/{full_path:path}", methods=["MKCOL"]) -async def handle_mkcol(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): - full_path = unquote(full_path).strip('/') +async def handle_mkcol( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): + full_path = unquote(full_path).strip("/") if not full_path: raise HTTPException(status_code=400, detail="Cannot MKCOL at root") - parts = [p for p in full_path.split('/') if p] + parts = [p for p in full_path.split("/") if p] folder_name = parts[-1] - parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else '' + parent_path = "/".join(parts[:-1]) if len(parts) > 1 else "" _, parent_folder, exists = await resolve_path(parent_path, current_user) if not exists: raise HTTPException(status_code=409, detail="Parent folder does not exist") existing = await Folder.get_or_none( - name=folder_name, - parent=parent_folder, - owner=current_user, - is_deleted=False + name=folder_name, parent=parent_folder, owner=current_user, is_deleted=False ) if existing: raise HTTPException(status_code=405, detail="Folder already exists") folder = await Folder.create( - name=folder_name, - parent=parent_folder, - owner=current_user + name=folder_name, parent=parent_folder, owner=current_user ) await log_activity(current_user, "folder_created", "folder", folder.id) return Response(status_code=201) + @router.api_route("/{full_path:path}", methods=["COPY"]) -async def handle_copy(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): - full_path = unquote(full_path).strip('/') +async def handle_copy( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): + full_path = unquote(full_path).strip("/") destination = request.headers.get("Destination") overwrite = request.headers.get("Overwrite", "T") @@ -523,7 +600,7 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep raise HTTPException(status_code=400, detail="Destination header required") dest_path = unquote(urlparse(destination).path) - dest_path = dest_path.replace('/webdav/', '').strip('/') + dest_path = dest_path.replace("/webdav/", "").strip("/") source_resource, _, exists = await resolve_path(full_path, current_user) @@ -533,9 +610,9 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep if not isinstance(source_resource, File): raise HTTPException(status_code=501, detail="Only file copy is implemented") - dest_parts = [p for p in dest_path.split('/') if p] + dest_parts = [p for p in dest_path.split("/") if p] dest_name = dest_parts[-1] - dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else '' + dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else "" _, dest_parent, exists = await resolve_path(dest_parent_path, current_user) @@ -543,14 +620,13 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep raise HTTPException(status_code=409, detail="Destination parent does not exist") existing_dest = await File.get_or_none( - name=dest_name, - parent=dest_parent, - owner=current_user, - is_deleted=False + name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False ) if existing_dest and overwrite == "F": - raise HTTPException(status_code=412, detail="Destination exists and overwrite is false") + raise HTTPException( + status_code=412, detail="Destination exists and overwrite is false" + ) if existing_dest: await existing_dest.delete() @@ -562,15 +638,18 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep mime_type=source_resource.mime_type, file_hash=source_resource.file_hash, owner=current_user, - parent=dest_parent + parent=dest_parent, ) await log_activity(current_user, "file_copied", "file", new_file.id) return Response(status_code=201 if not existing_dest else 204) + @router.api_route("/{full_path:path}", methods=["MOVE"]) -async def handle_move(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): - full_path = unquote(full_path).strip('/') +async def handle_move( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): + full_path = unquote(full_path).strip("/") destination = request.headers.get("Destination") overwrite = request.headers.get("Overwrite", "T") @@ -578,16 +657,16 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep raise HTTPException(status_code=400, detail="Destination header required") dest_path = unquote(urlparse(destination).path) - dest_path = dest_path.replace('/webdav/', '').strip('/') + dest_path = dest_path.replace("/webdav/", "").strip("/") source_resource, _, exists = await resolve_path(full_path, current_user) if not source_resource: raise HTTPException(status_code=404, detail="Source not found") - dest_parts = [p for p in dest_path.split('/') if p] + dest_parts = [p for p in dest_path.split("/") if p] dest_name = dest_parts[-1] - dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else '' + dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else "" _, dest_parent, exists = await resolve_path(dest_parent_path, current_user) @@ -596,14 +675,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep if isinstance(source_resource, File): existing_dest = await File.get_or_none( - name=dest_name, - parent=dest_parent, - owner=current_user, - is_deleted=False + name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False ) if existing_dest and overwrite == "F": - raise HTTPException(status_code=412, detail="Destination exists and overwrite is false") + raise HTTPException( + status_code=412, detail="Destination exists and overwrite is false" + ) if existing_dest: await existing_dest.delete() @@ -617,14 +695,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep elif isinstance(source_resource, Folder): existing_dest = await Folder.get_or_none( - name=dest_name, - parent=dest_parent, - owner=current_user, - is_deleted=False + name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False ) if existing_dest and overwrite == "F": - raise HTTPException(status_code=412, detail="Destination exists and overwrite is false") + raise HTTPException( + status_code=412, detail="Destination exists and overwrite is false" + ) if existing_dest: existing_dest.is_deleted = True @@ -637,9 +714,12 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep await log_activity(current_user, "folder_moved", "folder", source_resource.id) return Response(status_code=201 if not existing_dest else 204) + @router.api_route("/{full_path:path}", methods=["LOCK"]) -async def handle_lock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): - full_path = unquote(full_path).strip('/') +async def handle_lock( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): + full_path = unquote(full_path).strip("/") timeout_header = request.headers.get("Timeout", "Second-3600") timeout = 3600 @@ -681,32 +761,38 @@ async def handle_lock(request: Request, full_path: str, current_user: User = Dep content=xml_content, media_type="application/xml; charset=utf-8", status_code=200, - headers={"Lock-Token": f"<{lock_token}>"} + headers={"Lock-Token": f"<{lock_token}>"}, ) + @router.api_route("/{full_path:path}", methods=["UNLOCK"]) -async def handle_unlock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): - full_path = unquote(full_path).strip('/') +async def handle_unlock( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): + full_path = unquote(full_path).strip("/") lock_token_header = request.headers.get("Lock-Token") if not lock_token_header: raise HTTPException(status_code=400, detail="Lock-Token header required") - lock_token = lock_token_header.strip('<>') + lock_token = lock_token_header.strip("<>") existing_lock = WebDAVLock.get_lock(full_path) - if not existing_lock or existing_lock['token'] != lock_token: + if not existing_lock or existing_lock["token"] != lock_token: raise HTTPException(status_code=409, detail="Invalid lock token") - if existing_lock['user_id'] != current_user.id: + if existing_lock["user_id"] != current_user.id: raise HTTPException(status_code=403, detail="Not lock owner") WebDAVLock.remove_lock(full_path) return Response(status_code=204) + @router.api_route("/{full_path:path}", methods=["PROPPATCH"]) -async def handle_proppatch(request: Request, full_path: str, current_user: User = Depends(webdav_auth)): - full_path = unquote(full_path).strip('/') +async def handle_proppatch( + request: Request, full_path: str, current_user: User = Depends(webdav_auth) +): + full_path = unquote(full_path).strip("/") resource, parent, exists = await resolve_path(full_path, current_user) @@ -719,7 +805,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User try: root = ET.fromstring(body) - except: + except Exception: raise HTTPException(status_code=400, detail="Invalid XML") resource_type = "file" if isinstance(resource, File) else "folder" @@ -734,13 +820,19 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User prop_element = set_element.find(".//{DAV:}prop") if prop_element is not None: for child in prop_element: - ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:" - name = child.tag.split('}')[1] if '}' in child.tag else child.tag + ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:" + name = child.tag.split("}")[1] if "}" in child.tag else child.tag value = child.text or "" if ns == "DAV:": - live_props = ["creationdate", "getcontentlength", "getcontenttype", - "getetag", "getlastmodified", "resourcetype"] + live_props = [ + "creationdate", + "getcontentlength", + "getcontenttype", + "getetag", + "getlastmodified", + "resourcetype", + ] if name in live_props: failed_props.append((ns, name, "409 Conflict")) continue @@ -750,7 +842,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User resource_type=resource_type, resource_id=resource_id, namespace=ns, - name=name + name=name, ) if existing_prop: @@ -762,10 +854,10 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User resource_id=resource_id, namespace=ns, name=name, - value=value + value=value, ) set_props.append((ns, name)) - except Exception as e: + except Exception: failed_props.append((ns, name, "500 Internal Server Error")) remove_element = root.find(".//{DAV:}remove") @@ -773,8 +865,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User prop_element = remove_element.find(".//{DAV:}prop") if prop_element is not None: for child in prop_element: - ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:" - name = child.tag.split('}')[1] if '}' in child.tag else child.tag + ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:" + name = child.tag.split("}")[1] if "}" in child.tag else child.tag if ns == "DAV:": failed_props.append((ns, name, "409 Conflict")) @@ -785,7 +877,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User resource_type=resource_type, resource_id=resource_id, namespace=ns, - name=name + name=name, ) if existing_prop: @@ -793,7 +885,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User remove_props.append((ns, name)) else: failed_props.append((ns, name, "404 Not Found")) - except Exception as e: + except Exception: failed_props.append((ns, name, "500 Internal Server Error")) multistatus = ET.Element("D:multistatus", {"xmlns:D": "DAV:"}) @@ -837,4 +929,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User await log_activity(current_user, "properties_modified", resource_type, resource_id) xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True) - return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207) + return Response( + content=xml_content, + media_type="application/xml; charset=utf-8", + status_code=207, + ) diff --git a/tests/billing/test_api_endpoints.py b/tests/billing/test_api_endpoints.py index e5e7405..be2c948 100644 --- a/tests/billing/test_api_endpoints.py +++ b/tests/billing/test_api_endpoints.py @@ -6,9 +6,15 @@ from httpx import AsyncClient, ASGITransport from fastapi import status from mywebdav.main import app from mywebdav.models import User -from mywebdav.billing.models import PricingConfig, Invoice, UsageAggregate, UserSubscription +from mywebdav.billing.models import ( + PricingConfig, + Invoice, + UsageAggregate, + UserSubscription, +) from mywebdav.auth import create_access_token + @pytest_asyncio.fixture async def test_user(): user = await User.create( @@ -16,11 +22,12 @@ async def test_user(): email="test@example.com", hashed_password="hashed_password_here", is_active=True, - is_superuser=False + is_superuser=False, ) yield user await user.delete() + @pytest_asyncio.fixture async def admin_user(): user = await User.create( @@ -28,58 +35,72 @@ async def admin_user(): email="admin@example.com", hashed_password="hashed_password_here", is_active=True, - is_superuser=True + is_superuser=True, ) yield user await user.delete() + @pytest_asyncio.fixture async def auth_token(test_user): token = create_access_token(data={"sub": test_user.username}) return token + @pytest_asyncio.fixture async def admin_token(admin_user): token = create_access_token(data={"sub": admin_user.username}) return token + @pytest_asyncio.fixture async def pricing_config(): configs = [] - configs.append(await PricingConfig.create( - config_key="storage_per_gb_month", - config_value=Decimal("0.0045"), - description="Storage cost per GB per month", - unit="per_gb_month" - )) - configs.append(await PricingConfig.create( - config_key="bandwidth_egress_per_gb", - config_value=Decimal("0.009"), - description="Bandwidth egress cost per GB", - unit="per_gb" - )) - configs.append(await PricingConfig.create( - config_key="free_tier_storage_gb", - config_value=Decimal("15"), - description="Free tier storage in GB", - unit="gb" - )) - configs.append(await PricingConfig.create( - config_key="free_tier_bandwidth_gb", - config_value=Decimal("15"), - description="Free tier bandwidth in GB per month", - unit="gb" - )) + configs.append( + await PricingConfig.create( + config_key="storage_per_gb_month", + config_value=Decimal("0.0045"), + description="Storage cost per GB per month", + unit="per_gb_month", + ) + ) + configs.append( + await PricingConfig.create( + config_key="bandwidth_egress_per_gb", + config_value=Decimal("0.009"), + description="Bandwidth egress cost per GB", + unit="per_gb", + ) + ) + configs.append( + await PricingConfig.create( + config_key="free_tier_storage_gb", + config_value=Decimal("15"), + description="Free tier storage in GB", + unit="gb", + ) + ) + configs.append( + await PricingConfig.create( + config_key="free_tier_bandwidth_gb", + config_value=Decimal("15"), + description="Free tier bandwidth in GB per month", + unit="gb", + ) + ) yield configs for config in configs: await config.delete() + @pytest.mark.asyncio async def test_get_current_usage(test_user, auth_token): - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.get( "/api/billing/usage/current", - headers={"Authorization": f"Bearer {auth_token}"} + headers={"Authorization": f"Bearer {auth_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -87,6 +108,7 @@ async def test_get_current_usage(test_user, auth_token): assert "storage_gb" in data assert "bandwidth_down_gb_today" in data + @pytest.mark.asyncio async def test_get_monthly_usage(test_user, auth_token): today = date.today() @@ -94,16 +116,18 @@ async def test_get_monthly_usage(test_user, auth_token): await UsageAggregate.create( user=test_user, date=today, - storage_bytes_avg=1024 ** 3 * 10, - storage_bytes_peak=1024 ** 3 * 12, - bandwidth_up_bytes=1024 ** 3 * 2, - bandwidth_down_bytes=1024 ** 3 * 5 + storage_bytes_avg=1024**3 * 10, + storage_bytes_peak=1024**3 * 12, + bandwidth_up_bytes=1024**3 * 2, + bandwidth_down_bytes=1024**3 * 5, ) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.get( f"/api/billing/usage/monthly?year={today.year}&month={today.month}", - headers={"Authorization": f"Bearer {auth_token}"} + headers={"Authorization": f"Bearer {auth_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -112,12 +136,15 @@ async def test_get_monthly_usage(test_user, auth_token): await UsageAggregate.filter(user=test_user).delete() + @pytest.mark.asyncio async def test_get_subscription(test_user, auth_token): - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.get( "/api/billing/subscription", - headers={"Authorization": f"Bearer {auth_token}"} + headers={"Authorization": f"Bearer {auth_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -127,6 +154,7 @@ async def test_get_subscription(test_user, auth_token): await UserSubscription.filter(user=test_user).delete() + @pytest.mark.asyncio async def test_list_invoices(test_user, auth_token): invoice = await Invoice.create( @@ -137,13 +165,14 @@ async def test_list_invoices(test_user, auth_token): subtotal=Decimal("10.00"), tax=Decimal("0.00"), total=Decimal("10.00"), - status="open" + status="open", ) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.get( - "/api/billing/invoices", - headers={"Authorization": f"Bearer {auth_token}"} + "/api/billing/invoices", headers={"Authorization": f"Bearer {auth_token}"} ) assert response.status_code == status.HTTP_200_OK @@ -153,6 +182,7 @@ async def test_list_invoices(test_user, auth_token): await invoice.delete() + @pytest.mark.asyncio async def test_get_invoice(test_user, auth_token): invoice = await Invoice.create( @@ -163,13 +193,15 @@ async def test_get_invoice(test_user, auth_token): subtotal=Decimal("10.00"), tax=Decimal("0.00"), total=Decimal("10.00"), - status="open" + status="open", ) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.get( f"/api/billing/invoices/{invoice.id}", - headers={"Authorization": f"Bearer {auth_token}"} + headers={"Authorization": f"Bearer {auth_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -178,39 +210,45 @@ async def test_get_invoice(test_user, auth_token): await invoice.delete() + @pytest.mark.asyncio async def test_get_pricing(): - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.get("/api/billing/pricing") assert response.status_code == status.HTTP_200_OK data = response.json() assert isinstance(data, dict) + @pytest.mark.asyncio async def test_admin_get_pricing(admin_user, admin_token, pricing_config): - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.get( "/api/admin/billing/pricing", - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}"}, ) assert response.status_code == status.HTTP_200_OK data = response.json() assert len(data) > 0 + @pytest.mark.asyncio async def test_admin_update_pricing(admin_user, admin_token, pricing_config): config_id = pricing_config[0].id - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.put( f"/api/admin/billing/pricing/{config_id}", headers={"Authorization": f"Bearer {admin_token}"}, - json={ - "config_key": "storage_per_gb_month", - "config_value": 0.005 - } + json={"config_key": "storage_per_gb_month", "config_value": 0.005}, ) assert response.status_code == status.HTTP_200_OK @@ -218,12 +256,15 @@ async def test_admin_update_pricing(admin_user, admin_token, pricing_config): updated = await PricingConfig.get(id=config_id) assert updated.config_value == Decimal("0.005") + @pytest.mark.asyncio async def test_admin_get_stats(admin_user, admin_token): - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.get( "/api/admin/billing/stats", - headers={"Authorization": f"Bearer {admin_token}"} + headers={"Authorization": f"Bearer {admin_token}"}, ) assert response.status_code == status.HTTP_200_OK @@ -232,12 +273,15 @@ async def test_admin_get_stats(admin_user, admin_token): assert "total_invoices" in data assert "pending_invoices" in data + @pytest.mark.asyncio async def test_non_admin_cannot_access_admin_endpoints(test_user, auth_token): - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: response = await client.get( "/api/admin/billing/pricing", - headers={"Authorization": f"Bearer {auth_token}"} + headers={"Authorization": f"Bearer {auth_token}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/tests/billing/test_invoice_generator.py b/tests/billing/test_invoice_generator.py index 1f52745..72c6cc5 100644 --- a/tests/billing/test_invoice_generator.py +++ b/tests/billing/test_invoice_generator.py @@ -3,57 +3,76 @@ import pytest_asyncio from decimal import Decimal from datetime import date, datetime from mywebdav.models import User -from mywebdav.billing.models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription +from mywebdav.billing.models import ( + Invoice, + InvoiceLineItem, + PricingConfig, + UsageAggregate, + UserSubscription, +) from mywebdav.billing.invoice_generator import InvoiceGenerator + @pytest_asyncio.fixture async def test_user(): user = await User.create( username="testuser", email="test@example.com", hashed_password="hashed_password_here", - is_active=True + is_active=True, ) yield user await user.delete() + @pytest_asyncio.fixture async def pricing_config(): configs = [] - configs.append(await PricingConfig.create( - config_key="storage_per_gb_month", - config_value=Decimal("0.0045"), - description="Storage cost per GB per month", - unit="per_gb_month" - )) - configs.append(await PricingConfig.create( - config_key="bandwidth_egress_per_gb", - config_value=Decimal("0.009"), - description="Bandwidth egress cost per GB", - unit="per_gb" - )) - configs.append(await PricingConfig.create( - config_key="free_tier_storage_gb", - config_value=Decimal("15"), - description="Free tier storage in GB", - unit="gb" - )) - configs.append(await PricingConfig.create( - config_key="free_tier_bandwidth_gb", - config_value=Decimal("15"), - description="Free tier bandwidth in GB per month", - unit="gb" - )) - configs.append(await PricingConfig.create( - config_key="tax_rate_default", - config_value=Decimal("0.0"), - description="Default tax rate", - unit="percentage" - )) + configs.append( + await PricingConfig.create( + config_key="storage_per_gb_month", + config_value=Decimal("0.0045"), + description="Storage cost per GB per month", + unit="per_gb_month", + ) + ) + configs.append( + await PricingConfig.create( + config_key="bandwidth_egress_per_gb", + config_value=Decimal("0.009"), + description="Bandwidth egress cost per GB", + unit="per_gb", + ) + ) + configs.append( + await PricingConfig.create( + config_key="free_tier_storage_gb", + config_value=Decimal("15"), + description="Free tier storage in GB", + unit="gb", + ) + ) + configs.append( + await PricingConfig.create( + config_key="free_tier_bandwidth_gb", + config_value=Decimal("15"), + description="Free tier bandwidth in GB per month", + unit="gb", + ) + ) + configs.append( + await PricingConfig.create( + config_key="tax_rate_default", + config_value=Decimal("0.0"), + description="Default tax rate", + unit="percentage", + ) + ) yield configs for config in configs: await config.delete() + @pytest.mark.asyncio async def test_generate_monthly_invoice_with_usage(test_user, pricing_config): today = date.today() @@ -61,13 +80,15 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config): await UsageAggregate.create( user=test_user, date=today, - storage_bytes_avg=1024 ** 3 * 50, - storage_bytes_peak=1024 ** 3 * 55, - bandwidth_up_bytes=1024 ** 3 * 10, - bandwidth_down_bytes=1024 ** 3 * 20 + storage_bytes_avg=1024**3 * 50, + storage_bytes_peak=1024**3 * 55, + bandwidth_up_bytes=1024**3 * 10, + bandwidth_down_bytes=1024**3 * 20, ) - invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month) + invoice = await InvoiceGenerator.generate_monthly_invoice( + test_user, today.year, today.month + ) assert invoice is not None assert invoice.user_id == test_user.id @@ -80,6 +101,7 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config): await invoice.delete() await UsageAggregate.filter(user=test_user).delete() + @pytest.mark.asyncio async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_config): today = date.today() @@ -87,18 +109,21 @@ async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_confi await UsageAggregate.create( user=test_user, date=today, - storage_bytes_avg=1024 ** 3 * 10, - storage_bytes_peak=1024 ** 3 * 12, - bandwidth_up_bytes=1024 ** 3 * 5, - bandwidth_down_bytes=1024 ** 3 * 10 + storage_bytes_avg=1024**3 * 10, + storage_bytes_peak=1024**3 * 12, + bandwidth_up_bytes=1024**3 * 5, + bandwidth_down_bytes=1024**3 * 10, ) - invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month) + invoice = await InvoiceGenerator.generate_monthly_invoice( + test_user, today.year, today.month + ) assert invoice is None await UsageAggregate.filter(user=test_user).delete() + @pytest.mark.asyncio async def test_finalize_invoice(test_user, pricing_config): invoice = await Invoice.create( @@ -109,7 +134,7 @@ async def test_finalize_invoice(test_user, pricing_config): subtotal=Decimal("10.00"), tax=Decimal("0.00"), total=Decimal("10.00"), - status="draft" + status="draft", ) finalized = await InvoiceGenerator.finalize_invoice(invoice) @@ -118,6 +143,7 @@ async def test_finalize_invoice(test_user, pricing_config): await finalized.delete() + @pytest.mark.asyncio async def test_finalize_invoice_already_finalized(test_user, pricing_config): invoice = await Invoice.create( @@ -128,7 +154,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config): subtotal=Decimal("10.00"), tax=Decimal("0.00"), total=Decimal("10.00"), - status="open" + status="open", ) with pytest.raises(ValueError): @@ -136,6 +162,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config): await invoice.delete() + @pytest.mark.asyncio async def test_mark_invoice_paid(test_user, pricing_config): invoice = await Invoice.create( @@ -146,7 +173,7 @@ async def test_mark_invoice_paid(test_user, pricing_config): subtotal=Decimal("10.00"), tax=Decimal("0.00"), total=Decimal("10.00"), - status="open" + status="open", ) paid = await InvoiceGenerator.mark_invoice_paid(invoice) @@ -156,10 +183,13 @@ async def test_mark_invoice_paid(test_user, pricing_config): await paid.delete() + @pytest.mark.asyncio async def test_invoice_with_tax(test_user, pricing_config): # Update tax rate - updated = await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.21")) + updated = await PricingConfig.filter(config_key="tax_rate_default").update( + config_value=Decimal("0.21") + ) assert updated == 1 # Should update 1 row # Verify the update worked @@ -171,13 +201,15 @@ async def test_invoice_with_tax(test_user, pricing_config): await UsageAggregate.create( user=test_user, date=today, - storage_bytes_avg=1024 ** 3 * 50, - storage_bytes_peak=1024 ** 3 * 55, - bandwidth_up_bytes=1024 ** 3 * 10, - bandwidth_down_bytes=1024 ** 3 * 20 + storage_bytes_avg=1024**3 * 50, + storage_bytes_peak=1024**3 * 55, + bandwidth_up_bytes=1024**3 * 10, + bandwidth_down_bytes=1024**3 * 20, ) - invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month) + invoice = await InvoiceGenerator.generate_monthly_invoice( + test_user, today.year, today.month + ) assert invoice is not None assert invoice.tax > 0 @@ -185,4 +217,6 @@ async def test_invoice_with_tax(test_user, pricing_config): await invoice.delete() await UsageAggregate.filter(user=test_user).delete() - await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.0")) + await PricingConfig.filter(config_key="tax_rate_default").update( + config_value=Decimal("0.0") + ) diff --git a/tests/billing/test_models.py b/tests/billing/test_models.py index 49bd21a..9ef51d1 100644 --- a/tests/billing/test_models.py +++ b/tests/billing/test_models.py @@ -4,21 +4,30 @@ from decimal import Decimal from datetime import date, datetime from mywebdav.models import User from mywebdav.billing.models import ( - SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate, - Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent + SubscriptionPlan, + UserSubscription, + UsageRecord, + UsageAggregate, + Invoice, + InvoiceLineItem, + PricingConfig, + PaymentMethod, + BillingEvent, ) + @pytest_asyncio.fixture async def test_user(): user = await User.create( username="testuser", email="test@example.com", hashed_password="hashed_password_here", - is_active=True + is_active=True, ) yield user await user.delete() + @pytest.mark.asyncio async def test_subscription_plan_creation(): plan = await SubscriptionPlan.create( @@ -28,7 +37,7 @@ async def test_subscription_plan_creation(): storage_gb=100, bandwidth_gb=100, price_monthly=Decimal("5.00"), - price_yearly=Decimal("50.00") + price_yearly=Decimal("50.00"), ) assert plan.name == "starter" @@ -36,12 +45,11 @@ async def test_subscription_plan_creation(): assert plan.price_monthly == Decimal("5.00") await plan.delete() + @pytest.mark.asyncio async def test_user_subscription_creation(test_user): subscription = await UserSubscription.create( - user=test_user, - billing_type="pay_as_you_go", - status="active" + user=test_user, billing_type="pay_as_you_go", status="active" ) assert subscription.user_id == test_user.id @@ -49,6 +57,7 @@ async def test_user_subscription_creation(test_user): assert subscription.status == "active" await subscription.delete() + @pytest.mark.asyncio async def test_usage_record_creation(test_user): usage = await UsageRecord.create( @@ -57,7 +66,7 @@ async def test_usage_record_creation(test_user): amount_bytes=1024 * 1024 * 100, resource_type="file", resource_id=1, - idempotency_key="test_key_123" + idempotency_key="test_key_123", ) assert usage.user_id == test_user.id @@ -65,6 +74,7 @@ async def test_usage_record_creation(test_user): assert usage.amount_bytes == 1024 * 1024 * 100 await usage.delete() + @pytest.mark.asyncio async def test_usage_aggregate_creation(test_user): aggregate = await UsageAggregate.create( @@ -73,13 +83,14 @@ async def test_usage_aggregate_creation(test_user): storage_bytes_avg=1024 * 1024 * 500, storage_bytes_peak=1024 * 1024 * 600, bandwidth_up_bytes=1024 * 1024 * 50, - bandwidth_down_bytes=1024 * 1024 * 100 + bandwidth_down_bytes=1024 * 1024 * 100, ) assert aggregate.user_id == test_user.id assert aggregate.storage_bytes_avg == 1024 * 1024 * 500 await aggregate.delete() + @pytest.mark.asyncio async def test_invoice_creation(test_user): invoice = await Invoice.create( @@ -90,7 +101,7 @@ async def test_invoice_creation(test_user): subtotal=Decimal("10.00"), tax=Decimal("0.00"), total=Decimal("10.00"), - status="draft" + status="draft", ) assert invoice.user_id == test_user.id @@ -98,6 +109,7 @@ async def test_invoice_creation(test_user): assert invoice.total == Decimal("10.00") await invoice.delete() + @pytest.mark.asyncio async def test_invoice_line_item_creation(test_user): invoice = await Invoice.create( @@ -108,7 +120,7 @@ async def test_invoice_line_item_creation(test_user): subtotal=Decimal("10.00"), tax=Decimal("0.00"), total=Decimal("10.00"), - status="draft" + status="draft", ) line_item = await InvoiceLineItem.create( @@ -117,7 +129,7 @@ async def test_invoice_line_item_creation(test_user): quantity=Decimal("100.000000"), unit_price=Decimal("0.100000"), amount=Decimal("10.0000"), - item_type="storage" + item_type="storage", ) assert line_item.invoice_id == invoice.id @@ -126,6 +138,7 @@ async def test_invoice_line_item_creation(test_user): await line_item.delete() await invoice.delete() + @pytest.mark.asyncio async def test_pricing_config_creation(test_user): config = await PricingConfig.create( @@ -133,13 +146,14 @@ async def test_pricing_config_creation(test_user): config_value=Decimal("0.0045"), description="Storage cost per GB per month", unit="per_gb_month", - updated_by=test_user + updated_by=test_user, ) assert config.config_key == "storage_per_gb_month" assert config.config_value == Decimal("0.0045") await config.delete() + @pytest.mark.asyncio async def test_payment_method_creation(test_user): payment_method = await PaymentMethod.create( @@ -150,7 +164,7 @@ async def test_payment_method_creation(test_user): last4="4242", brand="visa", exp_month=12, - exp_year=2025 + exp_year=2025, ) assert payment_method.user_id == test_user.id @@ -158,6 +172,7 @@ async def test_payment_method_creation(test_user): assert payment_method.is_default is True await payment_method.delete() + @pytest.mark.asyncio async def test_billing_event_creation(test_user): event = await BillingEvent.create( @@ -165,7 +180,7 @@ async def test_billing_event_creation(test_user): event_type="invoice_created", stripe_event_id="evt_test_123", data={"invoice_id": 1}, - processed=False + processed=False, ) assert event.user_id == test_user.id diff --git a/tests/billing/test_simple.py b/tests/billing/test_simple.py index 657eb4e..f3ebbdc 100644 --- a/tests/billing/test_simple.py +++ b/tests/billing/test_simple.py @@ -1,18 +1,35 @@ import pytest + def test_billing_module_imports(): - from mywebdav.billing import models, stripe_client, usage_tracker, invoice_generator, scheduler + from mywebdav.billing import ( + models, + stripe_client, + usage_tracker, + invoice_generator, + scheduler, + ) + assert models is not None assert stripe_client is not None assert usage_tracker is not None assert invoice_generator is not None assert scheduler is not None + def test_billing_models_exist(): from mywebdav.billing.models import ( - SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate, - Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent + SubscriptionPlan, + UserSubscription, + UsageRecord, + UsageAggregate, + Invoice, + InvoiceLineItem, + PricingConfig, + PaymentMethod, + BillingEvent, ) + assert SubscriptionPlan is not None assert UserSubscription is not None assert UsageRecord is not None @@ -23,55 +40,71 @@ def test_billing_models_exist(): assert PaymentMethod is not None assert BillingEvent is not None + def test_stripe_client_exists(): from mywebdav.billing.stripe_client import StripeClient + assert StripeClient is not None - assert hasattr(StripeClient, 'create_customer') - assert hasattr(StripeClient, 'create_invoice') - assert hasattr(StripeClient, 'finalize_invoice') + assert hasattr(StripeClient, "create_customer") + assert hasattr(StripeClient, "create_invoice") + assert hasattr(StripeClient, "finalize_invoice") + def test_usage_tracker_exists(): from mywebdav.billing.usage_tracker import UsageTracker + assert UsageTracker is not None - assert hasattr(UsageTracker, 'track_storage') - assert hasattr(UsageTracker, 'track_bandwidth') - assert hasattr(UsageTracker, 'aggregate_daily_usage') - assert hasattr(UsageTracker, 'get_current_storage') - assert hasattr(UsageTracker, 'get_monthly_usage') + assert hasattr(UsageTracker, "track_storage") + assert hasattr(UsageTracker, "track_bandwidth") + assert hasattr(UsageTracker, "aggregate_daily_usage") + assert hasattr(UsageTracker, "get_current_storage") + assert hasattr(UsageTracker, "get_monthly_usage") + def test_invoice_generator_exists(): from mywebdav.billing.invoice_generator import InvoiceGenerator + assert InvoiceGenerator is not None - assert hasattr(InvoiceGenerator, 'generate_monthly_invoice') - assert hasattr(InvoiceGenerator, 'finalize_invoice') - assert hasattr(InvoiceGenerator, 'mark_invoice_paid') + assert hasattr(InvoiceGenerator, "generate_monthly_invoice") + assert hasattr(InvoiceGenerator, "finalize_invoice") + assert hasattr(InvoiceGenerator, "mark_invoice_paid") + def test_scheduler_exists(): from mywebdav.billing.scheduler import scheduler, start_scheduler, stop_scheduler + assert scheduler is not None assert callable(start_scheduler) assert callable(stop_scheduler) + def test_routers_exist(): from mywebdav.routers import billing, admin_billing + assert billing is not None assert admin_billing is not None - assert hasattr(billing, 'router') - assert hasattr(admin_billing, 'router') + assert hasattr(billing, "router") + assert hasattr(admin_billing, "router") + def test_middleware_exists(): from mywebdav.middleware.usage_tracking import UsageTrackingMiddleware + assert UsageTrackingMiddleware is not None + def test_settings_updated(): from mywebdav.settings import settings - assert hasattr(settings, 'STRIPE_SECRET_KEY') - assert hasattr(settings, 'STRIPE_PUBLISHABLE_KEY') - assert hasattr(settings, 'STRIPE_WEBHOOK_SECRET') - assert hasattr(settings, 'BILLING_ENABLED') + + assert hasattr(settings, "STRIPE_SECRET_KEY") + assert hasattr(settings, "STRIPE_PUBLISHABLE_KEY") + assert hasattr(settings, "STRIPE_WEBHOOK_SECRET") + assert hasattr(settings, "BILLING_ENABLED") + def test_main_includes_billing(): from mywebdav.main import app + routes = [route.path for route in app.routes] - billing_routes = [r for r in routes if '/billing' in r] + billing_routes = [r for r in routes if "/billing" in r] assert len(billing_routes) > 0 diff --git a/tests/billing/test_usage_tracker.py b/tests/billing/test_usage_tracker.py index 6aac8bf..35f3b71 100644 --- a/tests/billing/test_usage_tracker.py +++ b/tests/billing/test_usage_tracker.py @@ -6,24 +6,26 @@ from mywebdav.models import User, File, Folder from mywebdav.billing.models import UsageRecord, UsageAggregate from mywebdav.billing.usage_tracker import UsageTracker + @pytest_asyncio.fixture async def test_user(): user = await User.create( username="testuser", email="test@example.com", hashed_password="hashed_password_here", - is_active=True + is_active=True, ) yield user await user.delete() + @pytest.mark.asyncio async def test_track_storage(test_user): await UsageTracker.track_storage( user=test_user, amount_bytes=1024 * 1024 * 100, resource_type="file", - resource_id=1 + resource_id=1, ) records = await UsageRecord.filter(user=test_user, record_type="storage").all() @@ -32,6 +34,7 @@ async def test_track_storage(test_user): await records[0].delete() + @pytest.mark.asyncio async def test_track_bandwidth_upload(test_user): await UsageTracker.track_bandwidth( @@ -39,7 +42,7 @@ async def test_track_bandwidth_upload(test_user): amount_bytes=1024 * 1024 * 50, direction="up", resource_type="file", - resource_id=1 + resource_id=1, ) records = await UsageRecord.filter(user=test_user, record_type="bandwidth_up").all() @@ -48,6 +51,7 @@ async def test_track_bandwidth_upload(test_user): await records[0].delete() + @pytest.mark.asyncio async def test_track_bandwidth_download(test_user): await UsageTracker.track_bandwidth( @@ -55,32 +59,28 @@ async def test_track_bandwidth_download(test_user): amount_bytes=1024 * 1024 * 75, direction="down", resource_type="file", - resource_id=1 + resource_id=1, ) - records = await UsageRecord.filter(user=test_user, record_type="bandwidth_down").all() + records = await UsageRecord.filter( + user=test_user, record_type="bandwidth_down" + ).all() assert len(records) == 1 assert records[0].amount_bytes == 1024 * 1024 * 75 await records[0].delete() + @pytest.mark.asyncio async def test_aggregate_daily_usage(test_user): - await UsageTracker.track_storage( - user=test_user, - amount_bytes=1024 * 1024 * 100 + await UsageTracker.track_storage(user=test_user, amount_bytes=1024 * 1024 * 100) + + await UsageTracker.track_bandwidth( + user=test_user, amount_bytes=1024 * 1024 * 50, direction="up" ) await UsageTracker.track_bandwidth( - user=test_user, - amount_bytes=1024 * 1024 * 50, - direction="up" - ) - - await UsageTracker.track_bandwidth( - user=test_user, - amount_bytes=1024 * 1024 * 75, - direction="down" + user=test_user, amount_bytes=1024 * 1024 * 75, direction="down" ) aggregate = await UsageTracker.aggregate_daily_usage(test_user, date.today()) @@ -93,12 +93,10 @@ async def test_aggregate_daily_usage(test_user): await aggregate.delete() await UsageRecord.filter(user=test_user).delete() + @pytest.mark.asyncio async def test_get_current_storage(test_user): - folder = await Folder.create( - name="Test Folder", - owner=test_user - ) + folder = await Folder.create(name="Test Folder", owner=test_user) file1 = await File.create( name="test1.txt", @@ -107,7 +105,7 @@ async def test_get_current_storage(test_user): mime_type="text/plain", owner=test_user, parent=folder, - is_deleted=False + is_deleted=False, ) file2 = await File.create( @@ -117,7 +115,7 @@ async def test_get_current_storage(test_user): mime_type="text/plain", owner=test_user, parent=folder, - is_deleted=False + is_deleted=False, ) storage = await UsageTracker.get_current_storage(test_user) @@ -128,6 +126,7 @@ async def test_get_current_storage(test_user): await file2.delete() await folder.delete() + @pytest.mark.asyncio async def test_get_monthly_usage(test_user): today = date.today() @@ -138,7 +137,7 @@ async def test_get_monthly_usage(test_user): storage_bytes_avg=1024 * 1024 * 1024 * 10, storage_bytes_peak=1024 * 1024 * 1024 * 12, bandwidth_up_bytes=1024 * 1024 * 1024 * 2, - bandwidth_down_bytes=1024 * 1024 * 1024 * 5 + bandwidth_down_bytes=1024 * 1024 * 1024 * 5, ) usage = await UsageTracker.get_monthly_usage(test_user, today.year, today.month) @@ -150,11 +149,14 @@ async def test_get_monthly_usage(test_user): await aggregate.delete() + @pytest.mark.asyncio async def test_get_monthly_usage_empty(test_user): future_date = date.today() + timedelta(days=365) - usage = await UsageTracker.get_monthly_usage(test_user, future_date.year, future_date.month) + usage = await UsageTracker.get_monthly_usage( + test_user, future_date.year, future_date.month + ) assert usage["storage_gb_avg"] == 0 assert usage["storage_gb_peak"] == 0 diff --git a/tests/conftest.py b/tests/conftest.py index 7d6fdc9..a7a59e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,27 +3,21 @@ import pytest_asyncio import asyncio from tortoise import Tortoise -# Initialize Tortoise at module level -async def _init_tortoise(): + +@pytest_asyncio.fixture(scope="session") +async def init_db(): + """Initialize the database for testing.""" await Tortoise.init( db_url="sqlite://:memory:", - modules={ - "models": ["mywebdav.models"], - "billing": ["mywebdav.billing.models"] - } + modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]}, ) await Tortoise.generate_schemas() - -# Run initialization -asyncio.run(_init_tortoise()) - -@pytest.fixture(scope="session") -def event_loop(): - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - -@pytest_asyncio.fixture(scope="session", autouse=True) -async def cleanup_db(): yield await Tortoise.close_connections() + + +@pytest_asyncio.fixture(autouse=True) +async def setup_db(init_db): + """Ensure database is initialized before each test.""" + # The init_db fixture handles the setup + yield