Compare commits

..

No commits in common. "d350ab680762e159fca9a39484f17361efc4d990" and "b25fb89df493a9d244f4d6e0ceb21455a6016e92" have entirely different histories.

47 changed files with 3148 additions and 1886 deletions

View File

@ -1,18 +1,17 @@
from typing import Optional from typing import Optional
from .models import Activity, User from .models import Activity, User
async def log_activity( async def log_activity(
user: Optional[User], user: Optional[User],
action: str, action: str,
target_type: str, target_type: str,
target_id: int, target_id: int,
ip_address: Optional[str] = None, ip_address: Optional[str] = None
): ):
await Activity.create( await Activity.create(
user=user, user=user,
action=action, action=action,
target_type=target_type, target_type=target_type,
target_id=target_id, target_id=target_id,
ip_address=ip_address, ip_address=ip_address
) )

View File

@ -1,4 +1,4 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta
from typing import Optional from typing import Optional
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
@ -9,29 +9,20 @@ import bcrypt
from .schemas import TokenData from .schemas import TokenData
from .settings import settings from .settings import settings
from .models import User from .models import User
from .two_factor import verify_totp_code # Import verify_totp_code from .two_factor import verify_totp_code # Import verify_totp_code
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def verify_password(plain_password, hashed_password): def verify_password(plain_password, hashed_password):
password_bytes = plain_password[:72].encode("utf-8") password_bytes = plain_password[:72].encode('utf-8')
hashed_bytes = ( hashed_bytes = hashed_password.encode('utf-8') if isinstance(hashed_password, str) else hashed_password
hashed_password.encode("utf-8")
if isinstance(hashed_password, str)
else hashed_password
)
return bcrypt.checkpw(password_bytes, hashed_bytes) return bcrypt.checkpw(password_bytes, hashed_bytes)
def get_password_hash(password): def get_password_hash(password):
password_bytes = password[:72].encode("utf-8") password_bytes = password[:72].encode('utf-8')
return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode("utf-8") return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode('utf-8')
async def authenticate_user(username: str, password: str, two_factor_code: Optional[str] = None):
async def authenticate_user(
username: str, password: str, two_factor_code: Optional[str] = None
):
user = await User.get_or_none(username=username) user = await User.get_or_none(username=username)
if not user: if not user:
return None return None
@ -42,29 +33,19 @@ async def authenticate_user(
if not two_factor_code: if not two_factor_code:
return {"user": user, "2fa_required": True} return {"user": user, "2fa_required": True}
if not verify_totp_code(user.two_factor_secret, two_factor_code): if not verify_totp_code(user.two_factor_secret, two_factor_code):
return None # 2FA code is incorrect return None # 2FA code is incorrect
return {"user": user, "2fa_required": False} return {"user": user, "2fa_required": False}
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, two_factor_verified: bool = False):
def create_access_token(
data: dict,
expires_delta: Optional[timedelta] = None,
two_factor_verified: bool = False,
):
to_encode = data.copy() to_encode = data.copy()
if expires_delta: if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta expire = datetime.utcnow() + expires_delta
else: else:
expire = datetime.now(timezone.utc) + timedelta( expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode.update({"exp": expire, "2fa_verified": two_factor_verified}) to_encode.update({"exp": expire, "2fa_verified": two_factor_verified})
encoded_jwt = jwt.encode( encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt return encoded_jwt
async def get_current_user(token: str = Depends(oauth2_scheme)): async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -72,45 +53,31 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
payload = jwt.decode( payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub") username: str = payload.get("sub")
two_factor_verified: bool = payload.get("2fa_verified", False) two_factor_verified: bool = payload.get("2fa_verified", False)
if username is None: if username is None:
raise credentials_exception raise credentials_exception
token_data = TokenData( token_data = TokenData(username=username, two_factor_verified=two_factor_verified)
username=username, two_factor_verified=two_factor_verified
)
except JWTError: except JWTError:
raise credentials_exception raise credentials_exception
user = await User.get_or_none(username=token_data.username) user = await User.get_or_none(username=token_data.username)
if user is None: if user is None:
raise credentials_exception raise credentials_exception
user.token_data = token_data # Attach token_data to user for easy access user.token_data = token_data # Attach token_data to user for easy access
return user return user
async def get_current_active_user(current_user: User = Depends(get_current_user)): async def get_current_active_user(current_user: User = Depends(get_current_user)):
if not current_user.is_active: if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user") raise HTTPException(status_code=400, detail="Inactive user")
return current_user return current_user
async def get_current_verified_user(current_user: User = Depends(get_current_user)): async def get_current_verified_user(current_user: User = Depends(get_current_user)):
if current_user.is_2fa_enabled and not current_user.token_data.two_factor_verified: if current_user.is_2fa_enabled and not current_user.token_data.two_factor_verified:
raise HTTPException( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="2FA required and not verified")
status_code=status.HTTP_403_FORBIDDEN,
detail="2FA required and not verified",
)
return current_user return current_user
async def get_current_admin_user(current_user: User = Depends(get_current_verified_user)):
async def get_current_admin_user(
current_user: User = Depends(get_current_verified_user),
):
if not current_user.is_superuser: if not current_user.is_superuser:
raise HTTPException( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
)
return current_user return current_user

View File

@ -1,18 +1,15 @@
from datetime import datetime, date, timedelta, timezone from datetime import datetime, date, timedelta
from decimal import Decimal from decimal import Decimal
from typing import Optional from typing import Optional
from calendar import monthrange from calendar import monthrange
from .models import Invoice, InvoiceLineItem, PricingConfig, UserSubscription from .models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription
from .usage_tracker import UsageTracker from .usage_tracker import UsageTracker
from .stripe_client import StripeClient from .stripe_client import StripeClient
from ..models import User from ..models import User
class InvoiceGenerator: class InvoiceGenerator:
@staticmethod @staticmethod
async def generate_monthly_invoice( async def generate_monthly_invoice(user: User, year: int, month: int) -> Optional[Invoice]:
user: User, year: int, month: int
) -> Optional[Invoice]:
period_start = date(year, month, 1) period_start = date(year, month, 1)
days_in_month = monthrange(year, month)[1] days_in_month = monthrange(year, month)[1]
period_end = date(year, month, days_in_month) period_end = date(year, month, days_in_month)
@ -22,24 +19,19 @@ class InvoiceGenerator:
pricing = await PricingConfig.all() pricing = await PricingConfig.all()
pricing_dict = {p.config_key: p.config_value for p in pricing} pricing_dict = {p.config_key: p.config_value for p in pricing}
storage_price_per_gb = pricing_dict.get( storage_price_per_gb = pricing_dict.get('storage_per_gb_month', Decimal('0.0045'))
"storage_per_gb_month", Decimal("0.0045") bandwidth_price_per_gb = pricing_dict.get('bandwidth_egress_per_gb', Decimal('0.009'))
) free_storage_gb = pricing_dict.get('free_tier_storage_gb', Decimal('15'))
bandwidth_price_per_gb = pricing_dict.get( free_bandwidth_gb = pricing_dict.get('free_tier_bandwidth_gb', Decimal('15'))
"bandwidth_egress_per_gb", Decimal("0.009") tax_rate = pricing_dict.get('tax_rate_default', Decimal('0'))
)
free_storage_gb = pricing_dict.get("free_tier_storage_gb", Decimal("15"))
free_bandwidth_gb = pricing_dict.get("free_tier_bandwidth_gb", Decimal("15"))
tax_rate = pricing_dict.get("tax_rate_default", Decimal("0"))
storage_gb = Decimal(str(usage["storage_gb_avg"])) storage_gb = Decimal(str(usage['storage_gb_avg']))
bandwidth_gb = Decimal(str(usage["bandwidth_down_gb"])) bandwidth_gb = Decimal(str(usage['bandwidth_down_gb']))
billable_storage = max(Decimal("0"), storage_gb - free_storage_gb) billable_storage = max(Decimal('0'), storage_gb - free_storage_gb)
billable_bandwidth = max(Decimal("0"), bandwidth_gb - free_bandwidth_gb) billable_bandwidth = max(Decimal('0'), bandwidth_gb - free_bandwidth_gb)
import math import math
billable_storage_rounded = Decimal(math.ceil(float(billable_storage))) billable_storage_rounded = Decimal(math.ceil(float(billable_storage)))
billable_bandwidth_rounded = Decimal(math.ceil(float(billable_bandwidth))) billable_bandwidth_rounded = Decimal(math.ceil(float(billable_bandwidth)))
@ -73,9 +65,9 @@ class InvoiceGenerator:
"usage": usage, "usage": usage,
"pricing": { "pricing": {
"storage_per_gb": float(storage_price_per_gb), "storage_per_gb": float(storage_price_per_gb),
"bandwidth_per_gb": float(bandwidth_price_per_gb), "bandwidth_per_gb": float(bandwidth_price_per_gb)
}, }
}, }
) )
if billable_storage_rounded > 0: if billable_storage_rounded > 0:
@ -86,10 +78,7 @@ class InvoiceGenerator:
unit_price=storage_price_per_gb, unit_price=storage_price_per_gb,
amount=storage_cost, amount=storage_cost,
item_type="storage", item_type="storage",
metadata={ metadata={"avg_gb": float(storage_gb), "free_gb": float(free_storage_gb)}
"avg_gb": float(storage_gb),
"free_gb": float(free_storage_gb),
},
) )
if billable_bandwidth_rounded > 0: if billable_bandwidth_rounded > 0:
@ -100,10 +89,7 @@ class InvoiceGenerator:
unit_price=bandwidth_price_per_gb, unit_price=bandwidth_price_per_gb,
amount=bandwidth_cost, amount=bandwidth_cost,
item_type="bandwidth", item_type="bandwidth",
metadata={ metadata={"total_gb": float(bandwidth_gb), "free_gb": float(free_bandwidth_gb)}
"total_gb": float(bandwidth_gb),
"free_gb": float(free_bandwidth_gb),
},
) )
if subscription and subscription.stripe_customer_id: if subscription and subscription.stripe_customer_id:
@ -114,7 +100,7 @@ class InvoiceGenerator:
"amount": item.amount, "amount": item.amount,
"currency": "usd", "currency": "usd",
"description": item.description, "description": item.description,
"metadata": item.metadata or {}, "metadata": item.metadata or {}
} }
for item in line_items for item in line_items
] ]
@ -123,7 +109,7 @@ class InvoiceGenerator:
customer_id=subscription.stripe_customer_id, customer_id=subscription.stripe_customer_id,
description=f"MyWebdav Usage Invoice for {period_start.strftime('%B %Y')}", description=f"MyWebdav Usage Invoice for {period_start.strftime('%B %Y')}",
line_items=stripe_line_items, line_items=stripe_line_items,
metadata={"mywebdav_invoice_id": str(invoice.id)}, metadata={"mywebdav_invoice_id": str(invoice.id)}
) )
invoice.stripe_invoice_id = stripe_invoice.id invoice.stripe_invoice_id = stripe_invoice.id
@ -149,11 +135,8 @@ class InvoiceGenerator:
# Send invoice email # Send invoice email
from ..mail import queue_email from ..mail import queue_email
line_items = await invoice.line_items.all() line_items = await invoice.line_items.all()
items_text = "\n".join( items_text = "\n".join([f"- {item.description}: ${item.amount}" for item in line_items])
[f"- {item.description}: ${item.amount}" for item in line_items]
)
body = f"""Dear {invoice.user.username}, body = f"""Dear {invoice.user.username},
Your invoice {invoice.invoice_number} for the period {invoice.period_start} to {invoice.period_end} is now available. Your invoice {invoice.invoice_number} for the period {invoice.period_start} to {invoice.period_end} is now available.
@ -191,7 +174,7 @@ The MyWebdav Team
to_email=invoice.user.email, to_email=invoice.user.email,
subject=f"Your MyWebdav Invoice {invoice.invoice_number}", subject=f"Your MyWebdav Invoice {invoice.invoice_number}",
body=body, body=body,
html=html, html=html
) )
return invoice return invoice
@ -199,6 +182,6 @@ The MyWebdav Team
@staticmethod @staticmethod
async def mark_invoice_paid(invoice: Invoice) -> Invoice: async def mark_invoice_paid(invoice: Invoice) -> Invoice:
invoice.status = "paid" invoice.status = "paid"
invoice.paid_at = datetime.now(timezone.utc) invoice.paid_at = datetime.utcnow()
await invoice.save() await invoice.save()
return invoice return invoice

View File

@ -1,8 +1,8 @@
from tortoise import fields, models from tortoise import fields, models
from decimal import Decimal
class SubscriptionPlan(models.Model): class SubscriptionPlan(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
name = fields.CharField(max_length=100, unique=True) name = fields.CharField(max_length=100, unique=True)
display_name = fields.CharField(max_length=255) display_name = fields.CharField(max_length=255)
description = fields.TextField(null=True) description = fields.TextField(null=True)
@ -18,13 +18,10 @@ class SubscriptionPlan(models.Model):
class Meta: class Meta:
table = "subscription_plans" table = "subscription_plans"
class UserSubscription(models.Model): class UserSubscription(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
user = fields.ForeignKeyField("models.User", related_name="subscription") user = fields.ForeignKeyField("models.User", related_name="subscription")
plan = fields.ForeignKeyField( plan = fields.ForeignKeyField("billing.SubscriptionPlan", related_name="subscriptions", null=True)
"billing.SubscriptionPlan", related_name="subscriptions", null=True
)
billing_type = fields.CharField(max_length=20, default="pay_as_you_go") billing_type = fields.CharField(max_length=20, default="pay_as_you_go")
stripe_customer_id = fields.CharField(max_length=255, unique=True, null=True) stripe_customer_id = fields.CharField(max_length=255, unique=True, null=True)
stripe_subscription_id = fields.CharField(max_length=255, unique=True, null=True) stripe_subscription_id = fields.CharField(max_length=255, unique=True, null=True)
@ -38,15 +35,14 @@ class UserSubscription(models.Model):
class Meta: class Meta:
table = "user_subscriptions" table = "user_subscriptions"
class UsageRecord(models.Model): class UsageRecord(models.Model):
id = fields.BigIntField(primary_key=True) id = fields.BigIntField(pk=True)
user = fields.ForeignKeyField("models.User", related_name="usage_records") user = fields.ForeignKeyField("models.User", related_name="usage_records")
record_type = fields.CharField(max_length=50, db_index=True) record_type = fields.CharField(max_length=50, index=True)
amount_bytes = fields.BigIntField() amount_bytes = fields.BigIntField()
resource_type = fields.CharField(max_length=50, null=True) resource_type = fields.CharField(max_length=50, null=True)
resource_id = fields.IntField(null=True) resource_id = fields.IntField(null=True)
timestamp = fields.DatetimeField(auto_now_add=True, db_index=True) timestamp = fields.DatetimeField(auto_now_add=True, index=True)
idempotency_key = fields.CharField(max_length=255, unique=True, null=True) idempotency_key = fields.CharField(max_length=255, unique=True, null=True)
metadata = fields.JSONField(null=True) metadata = fields.JSONField(null=True)
@ -54,9 +50,8 @@ class UsageRecord(models.Model):
table = "usage_records" table = "usage_records"
indexes = [("user_id", "record_type", "timestamp")] indexes = [("user_id", "record_type", "timestamp")]
class UsageAggregate(models.Model): class UsageAggregate(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
user = fields.ForeignKeyField("models.User", related_name="usage_aggregates") user = fields.ForeignKeyField("models.User", related_name="usage_aggregates")
date = fields.DateField() date = fields.DateField()
storage_bytes_avg = fields.BigIntField(default=0) storage_bytes_avg = fields.BigIntField(default=0)
@ -69,22 +64,21 @@ class UsageAggregate(models.Model):
table = "usage_aggregates" table = "usage_aggregates"
unique_together = (("user", "date"),) unique_together = (("user", "date"),)
class Invoice(models.Model): class Invoice(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
user = fields.ForeignKeyField("models.User", related_name="invoices") user = fields.ForeignKeyField("models.User", related_name="invoices")
invoice_number = fields.CharField(max_length=50, unique=True) invoice_number = fields.CharField(max_length=50, unique=True)
stripe_invoice_id = fields.CharField(max_length=255, unique=True, null=True) stripe_invoice_id = fields.CharField(max_length=255, unique=True, null=True)
period_start = fields.DateField(db_index=True) period_start = fields.DateField(index=True)
period_end = fields.DateField() period_end = fields.DateField()
subtotal = fields.DecimalField(max_digits=10, decimal_places=4) subtotal = fields.DecimalField(max_digits=10, decimal_places=4)
tax = fields.DecimalField(max_digits=10, decimal_places=4, default=0) tax = fields.DecimalField(max_digits=10, decimal_places=4, default=0)
total = fields.DecimalField(max_digits=10, decimal_places=4) total = fields.DecimalField(max_digits=10, decimal_places=4)
currency = fields.CharField(max_length=3, default="USD") currency = fields.CharField(max_length=3, default="USD")
status = fields.CharField(max_length=50, default="draft", db_index=True) status = fields.CharField(max_length=50, default="draft", index=True)
due_date = fields.DateField(null=True) due_date = fields.DateField(null=True)
paid_at = fields.DatetimeField(null=True) paid_at = fields.DatetimeField(null=True)
created_at = fields.DatetimeField(auto_now_add=True, db_index=True) created_at = fields.DatetimeField(auto_now_add=True, index=True)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
metadata = fields.JSONField(null=True) metadata = fields.JSONField(null=True)
@ -92,9 +86,8 @@ class Invoice(models.Model):
table = "invoices" table = "invoices"
indexes = [("user_id", "status", "created_at")] indexes = [("user_id", "status", "created_at")]
class InvoiceLineItem(models.Model): class InvoiceLineItem(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
invoice = fields.ForeignKeyField("billing.Invoice", related_name="line_items") invoice = fields.ForeignKeyField("billing.Invoice", related_name="line_items")
description = fields.TextField() description = fields.TextField()
quantity = fields.DecimalField(max_digits=15, decimal_places=6) quantity = fields.DecimalField(max_digits=15, decimal_places=6)
@ -107,24 +100,20 @@ class InvoiceLineItem(models.Model):
class Meta: class Meta:
table = "invoice_line_items" table = "invoice_line_items"
class PricingConfig(models.Model): class PricingConfig(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
config_key = fields.CharField(max_length=100, unique=True) config_key = fields.CharField(max_length=100, unique=True)
config_value = fields.DecimalField(max_digits=10, decimal_places=6) config_value = fields.DecimalField(max_digits=10, decimal_places=6)
description = fields.TextField(null=True) description = fields.TextField(null=True)
unit = fields.CharField(max_length=50, null=True) unit = fields.CharField(max_length=50, null=True)
updated_by = fields.ForeignKeyField( updated_by = fields.ForeignKeyField("models.User", related_name="pricing_updates", null=True)
"models.User", related_name="pricing_updates", null=True
)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
class Meta: class Meta:
table = "pricing_config" table = "pricing_config"
class PaymentMethod(models.Model): class PaymentMethod(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
user = fields.ForeignKeyField("models.User", related_name="payment_methods") user = fields.ForeignKeyField("models.User", related_name="payment_methods")
stripe_payment_method_id = fields.CharField(max_length=255) stripe_payment_method_id = fields.CharField(max_length=255)
type = fields.CharField(max_length=50) type = fields.CharField(max_length=50)
@ -139,12 +128,9 @@ class PaymentMethod(models.Model):
class Meta: class Meta:
table = "payment_methods" table = "payment_methods"
class BillingEvent(models.Model): class BillingEvent(models.Model):
id = fields.BigIntField(primary_key=True) id = fields.BigIntField(pk=True)
user = fields.ForeignKeyField( user = fields.ForeignKeyField("models.User", related_name="billing_events", null=True)
"models.User", related_name="billing_events", null=True
)
event_type = fields.CharField(max_length=100) event_type = fields.CharField(max_length=100)
stripe_event_id = fields.CharField(max_length=255, unique=True, null=True) stripe_event_id = fields.CharField(max_length=255, unique=True, null=True)
data = fields.JSONField(null=True) data = fields.JSONField(null=True)

View File

@ -1,6 +1,7 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.cron import CronTrigger
from datetime import datetime, date, timedelta from datetime import datetime, date, timedelta
import asyncio
from .usage_tracker import UsageTracker from .usage_tracker import UsageTracker
from .invoice_generator import InvoiceGenerator from .invoice_generator import InvoiceGenerator
@ -8,7 +9,6 @@ from ..models import User
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
async def aggregate_daily_usage_for_all_users(): async def aggregate_daily_usage_for_all_users():
users = await User.filter(is_active=True).all() users = await User.filter(is_active=True).all()
yesterday = date.today() - timedelta(days=1) yesterday = date.today() - timedelta(days=1)
@ -19,7 +19,6 @@ async def aggregate_daily_usage_for_all_users():
except Exception as e: except Exception as e:
print(f"Failed to aggregate usage for user {user.id}: {e}") print(f"Failed to aggregate usage for user {user.id}: {e}")
async def generate_monthly_invoices(): async def generate_monthly_invoices():
now = datetime.now() now = datetime.now()
last_month = now.month - 1 if now.month > 1 else 12 last_month = now.month - 1 if now.month > 1 else 12
@ -29,22 +28,19 @@ async def generate_monthly_invoices():
for user in users: for user in users:
try: try:
invoice = await InvoiceGenerator.generate_monthly_invoice( invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, last_month)
user, year, last_month
)
if invoice: if invoice:
await InvoiceGenerator.finalize_invoice(invoice) await InvoiceGenerator.finalize_invoice(invoice)
except Exception as e: except Exception as e:
print(f"Failed to generate invoice for user {user.id}: {e}") print(f"Failed to generate invoice for user {user.id}: {e}")
def start_scheduler(): def start_scheduler():
scheduler.add_job( scheduler.add_job(
aggregate_daily_usage_for_all_users, aggregate_daily_usage_for_all_users,
CronTrigger(hour=1, minute=0), CronTrigger(hour=1, minute=0),
id="aggregate_daily_usage", id="aggregate_daily_usage",
name="Aggregate daily usage for all users", name="Aggregate daily usage for all users",
replace_existing=True, replace_existing=True
) )
scheduler.add_job( scheduler.add_job(
@ -52,11 +48,10 @@ def start_scheduler():
CronTrigger(day=1, hour=2, minute=0), CronTrigger(day=1, hour=2, minute=0),
id="generate_monthly_invoices", id="generate_monthly_invoices",
name="Generate monthly invoices", name="Generate monthly invoices",
replace_existing=True, replace_existing=True
) )
scheduler.start() scheduler.start()
def stop_scheduler(): def stop_scheduler():
scheduler.shutdown() scheduler.shutdown()

View File

@ -1,8 +1,8 @@
import stripe import stripe
from typing import Dict from decimal import Decimal
from typing import Optional, Dict, Any
from ..settings import settings from ..settings import settings
class StripeClient: class StripeClient:
@staticmethod @staticmethod
def _ensure_api_key(): def _ensure_api_key():
@ -11,12 +11,13 @@ class StripeClient:
stripe.api_key = settings.STRIPE_SECRET_KEY stripe.api_key = settings.STRIPE_SECRET_KEY
else: else:
raise ValueError("Stripe API key not configured") raise ValueError("Stripe API key not configured")
@staticmethod @staticmethod
async def create_customer(email: str, name: str, metadata: Dict = None) -> str: async def create_customer(email: str, name: str, metadata: Dict = None) -> str:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
customer = stripe.Customer.create( customer = stripe.Customer.create(
email=email, name=name, metadata=metadata or {} email=email,
name=name,
metadata=metadata or {}
) )
return customer.id return customer.id
@ -25,7 +26,7 @@ class StripeClient:
amount: int, amount: int,
currency: str = "usd", currency: str = "usd",
customer_id: str = None, customer_id: str = None,
metadata: Dict = None, metadata: Dict = None
) -> stripe.PaymentIntent: ) -> stripe.PaymentIntent:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
return stripe.PaymentIntent.create( return stripe.PaymentIntent.create(
@ -33,29 +34,32 @@ class StripeClient:
currency=currency, currency=currency,
customer=customer_id, customer=customer_id,
metadata=metadata or {}, metadata=metadata or {},
automatic_payment_methods={"enabled": True}, automatic_payment_methods={"enabled": True}
) )
@staticmethod @staticmethod
async def create_invoice( async def create_invoice(
customer_id: str, description: str, line_items: list, metadata: Dict = None customer_id: str,
description: str,
line_items: list,
metadata: Dict = None
) -> stripe.Invoice: ) -> stripe.Invoice:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
for item in line_items: for item in line_items:
stripe.InvoiceItem.create( stripe.InvoiceItem.create(
customer=customer_id, customer=customer_id,
amount=int(item["amount"] * 100), amount=int(item['amount'] * 100),
currency=item.get("currency", "usd"), currency=item.get('currency', 'usd'),
description=item["description"], description=item['description'],
metadata=item.get("metadata", {}), metadata=item.get('metadata', {})
) )
invoice = stripe.Invoice.create( invoice = stripe.Invoice.create(
customer=customer_id, customer=customer_id,
description=description, description=description,
auto_advance=True, auto_advance=True,
collection_method="charge_automatically", collection_method='charge_automatically',
metadata=metadata or {}, metadata=metadata or {}
) )
return invoice return invoice
@ -72,15 +76,18 @@ class StripeClient:
@staticmethod @staticmethod
async def attach_payment_method( async def attach_payment_method(
payment_method_id: str, customer_id: str payment_method_id: str,
customer_id: str
) -> stripe.PaymentMethod: ) -> stripe.PaymentMethod:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
payment_method = stripe.PaymentMethod.attach( payment_method = stripe.PaymentMethod.attach(
payment_method_id, customer=customer_id payment_method_id,
customer=customer_id
) )
stripe.Customer.modify( stripe.Customer.modify(
customer_id, invoice_settings={"default_payment_method": payment_method_id} customer_id,
invoice_settings={'default_payment_method': payment_method_id}
) )
return payment_method return payment_method
@ -88,15 +95,22 @@ class StripeClient:
@staticmethod @staticmethod
async def list_payment_methods(customer_id: str, type: str = "card"): async def list_payment_methods(customer_id: str, type: str = "card"):
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
return stripe.PaymentMethod.list(customer=customer_id, type=type) return stripe.PaymentMethod.list(
customer=customer_id,
type=type
)
@staticmethod @staticmethod
async def create_subscription( async def create_subscription(
customer_id: str, price_id: str, metadata: Dict = None customer_id: str,
price_id: str,
metadata: Dict = None
) -> stripe.Subscription: ) -> stripe.Subscription:
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
return stripe.Subscription.create( return stripe.Subscription.create(
customer=customer_id, items=[{"price": price_id}], metadata=metadata or {} customer=customer_id,
items=[{'price': price_id}],
metadata=metadata or {}
) )
@staticmethod @staticmethod

View File

@ -1,10 +1,11 @@
import uuid import uuid
from datetime import datetime, date, timezone from datetime import datetime, date
from decimal import Decimal
from typing import Optional
from tortoise.transactions import in_transaction from tortoise.transactions import in_transaction
from .models import UsageRecord, UsageAggregate from .models import UsageRecord, UsageAggregate
from ..models import User from ..models import User
class UsageTracker: class UsageTracker:
@staticmethod @staticmethod
async def track_storage( async def track_storage(
@ -12,9 +13,9 @@ class UsageTracker:
amount_bytes: int, amount_bytes: int,
resource_type: str = None, resource_type: str = None,
resource_id: int = None, resource_id: int = None,
metadata: dict = None, metadata: dict = None
): ):
idempotency_key = f"storage_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}" idempotency_key = f"storage_{user.id}_{datetime.utcnow().timestamp()}_{uuid.uuid4().hex[:8]}"
await UsageRecord.create( await UsageRecord.create(
user=user, user=user,
@ -23,7 +24,7 @@ class UsageTracker:
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,
idempotency_key=idempotency_key, idempotency_key=idempotency_key,
metadata=metadata, metadata=metadata
) )
@staticmethod @staticmethod
@ -33,10 +34,10 @@ class UsageTracker:
direction: str = "down", direction: str = "down",
resource_type: str = None, resource_type: str = None,
resource_id: int = None, resource_id: int = None,
metadata: dict = None, metadata: dict = None
): ):
record_type = f"bandwidth_{direction}" record_type = f"bandwidth_{direction}"
idempotency_key = f"{record_type}_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}" idempotency_key = f"{record_type}_{user.id}_{datetime.utcnow().timestamp()}_{uuid.uuid4().hex[:8]}"
await UsageRecord.create( await UsageRecord.create(
user=user, user=user,
@ -45,7 +46,7 @@ class UsageTracker:
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,
idempotency_key=idempotency_key, idempotency_key=idempotency_key,
metadata=metadata, metadata=metadata
) )
@staticmethod @staticmethod
@ -60,26 +61,24 @@ class UsageTracker:
user=user, user=user,
record_type="storage", record_type="storage",
timestamp__gte=start_of_day, timestamp__gte=start_of_day,
timestamp__lte=end_of_day, timestamp__lte=end_of_day
).all() ).all()
storage_avg = sum(r.amount_bytes for r in storage_records) // max( storage_avg = sum(r.amount_bytes for r in storage_records) // max(len(storage_records), 1)
len(storage_records), 1
)
storage_peak = max((r.amount_bytes for r in storage_records), default=0) storage_peak = max((r.amount_bytes for r in storage_records), default=0)
bandwidth_up = await UsageRecord.filter( bandwidth_up = await UsageRecord.filter(
user=user, user=user,
record_type="bandwidth_up", record_type="bandwidth_up",
timestamp__gte=start_of_day, timestamp__gte=start_of_day,
timestamp__lte=end_of_day, timestamp__lte=end_of_day
).all() ).all()
bandwidth_down = await UsageRecord.filter( bandwidth_down = await UsageRecord.filter(
user=user, user=user,
record_type="bandwidth_down", record_type="bandwidth_down",
timestamp__gte=start_of_day, timestamp__gte=start_of_day,
timestamp__lte=end_of_day, timestamp__lte=end_of_day
).all() ).all()
total_up = sum(r.amount_bytes for r in bandwidth_up) total_up = sum(r.amount_bytes for r in bandwidth_up)
@ -93,8 +92,8 @@ class UsageTracker:
"storage_bytes_avg": storage_avg, "storage_bytes_avg": storage_avg,
"storage_bytes_peak": storage_peak, "storage_bytes_peak": storage_peak,
"bandwidth_up_bytes": total_up, "bandwidth_up_bytes": total_up,
"bandwidth_down_bytes": total_down, "bandwidth_down_bytes": total_down
}, }
) )
if not created: if not created:
@ -109,7 +108,6 @@ class UsageTracker:
@staticmethod @staticmethod
async def get_current_storage(user: User) -> int: async def get_current_storage(user: User) -> int:
from ..models import File from ..models import File
files = await File.filter(owner=user, is_deleted=False).all() files = await File.filter(owner=user, is_deleted=False).all()
return sum(f.size for f in files) return sum(f.size for f in files)
@ -123,7 +121,9 @@ class UsageTracker:
end_date = date(year, month, last_day) end_date = date(year, month, last_day)
aggregates = await UsageAggregate.filter( aggregates = await UsageAggregate.filter(
user=user, date__gte=start_date, date__lte=end_date user=user,
date__gte=start_date,
date__lte=end_date
).all() ).all()
if not aggregates: if not aggregates:
@ -132,7 +132,7 @@ class UsageTracker:
"storage_gb_peak": 0, "storage_gb_peak": 0,
"bandwidth_up_gb": 0, "bandwidth_up_gb": 0,
"bandwidth_down_gb": 0, "bandwidth_down_gb": 0,
"total_bandwidth_gb": 0, "total_bandwidth_gb": 0
} }
storage_avg = sum(a.storage_bytes_avg for a in aggregates) / len(aggregates) storage_avg = sum(a.storage_bytes_avg for a in aggregates) / len(aggregates)
@ -145,5 +145,5 @@ class UsageTracker:
"storage_gb_peak": round(storage_peak / (1024**3), 4), "storage_gb_peak": round(storage_peak / (1024**3), 4),
"bandwidth_up_gb": round(bandwidth_up / (1024**3), 4), "bandwidth_up_gb": round(bandwidth_up / (1024**3), 4),
"bandwidth_down_gb": round(bandwidth_down / (1024**3), 4), "bandwidth_down_gb": round(bandwidth_down / (1024**3), 4),
"total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4), "total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4)
} }

View File

@ -7,7 +7,7 @@ for various legal policies required for a European cloud storage provider.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import Dict from typing import Dict, Any
class LegalDocument(ABC): class LegalDocument(ABC):
@ -17,13 +17,10 @@ class LegalDocument(ABC):
Provides common structure and methods for generating legal content. Provides common structure and methods for generating legal content.
""" """
def __init__( def __init__(self, company_name: str = "MyWebdav Technologies",
self, last_updated: str = None,
company_name: str = "MyWebdav Technologies", contact_email: str = "legal@mywebdav.eu",
last_updated: str = None, website: str = "https://mywebdav.eu"):
contact_email: str = "legal@mywebdav.eu",
website: str = "https://mywebdav.eu",
):
self.company_name = company_name self.company_name = company_name
self.last_updated = last_updated or datetime.now().strftime("%B %d, %Y") self.last_updated = last_updated or datetime.now().strftime("%B %d, %Y")
self.contact_email = contact_email self.contact_email = contact_email
@ -856,21 +853,20 @@ class ContactComplaintMechanism(LegalDocument):
def get_all_legal_documents() -> Dict[str, LegalDocument]: def get_all_legal_documents() -> Dict[str, LegalDocument]:
"""Return a dictionary of all legal document instances.""" """Return a dictionary of all legal document instances."""
return { return {
"privacy_policy": PrivacyPolicy(), 'privacy_policy': PrivacyPolicy(),
"terms_of_service": TermsOfService(), 'terms_of_service': TermsOfService(),
"security_policy": SecurityPolicy(), 'security_policy': SecurityPolicy(),
"cookie_policy": CookiePolicy(), 'cookie_policy': CookiePolicy(),
"data_processing_agreement": DataProcessingAgreement(), 'data_processing_agreement': DataProcessingAgreement(),
"compliance_statement": ComplianceStatement(), 'compliance_statement': ComplianceStatement(),
"data_portability_deletion_policy": DataPortabilityDeletionPolicy(), 'data_portability_deletion_policy': DataPortabilityDeletionPolicy(),
"contact_complaint_mechanism": ContactComplaintMechanism(), 'contact_complaint_mechanism': ContactComplaintMechanism(),
} }
def generate_legal_documents(output_dir: str = "static/legal"): def generate_legal_documents(output_dir: str = "static/legal"):
"""Generate all legal documents as Markdown and HTML files.""" """Generate all legal documents as Markdown and HTML files."""
import os import os
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
documents = get_all_legal_documents() documents = get_all_legal_documents()
@ -879,17 +875,17 @@ def generate_legal_documents(output_dir: str = "static/legal"):
# Generate Markdown # Generate Markdown
md_filename = f"{doc_name}.md" md_filename = f"{doc_name}.md"
md_path = os.path.join(output_dir, md_filename) md_path = os.path.join(output_dir, md_filename)
with open(md_path, "w") as f: with open(md_path, 'w') as f:
f.write(doc.to_markdown()) f.write(doc.to_markdown())
# Generate HTML # Generate HTML
html_filename = f"{doc_name}.html" html_filename = f"{doc_name}.html"
html_path = os.path.join(output_dir, html_filename) html_path = os.path.join(output_dir, html_filename)
with open(html_path, "w") as f: with open(html_path, 'w') as f:
f.write(doc.to_html()) f.write(doc.to_html())
print(f"Generated {md_filename} and {html_filename}") print(f"Generated {md_filename} and {html_filename}")
if __name__ == "__main__": if __name__ == "__main__":
generate_legal_documents() generate_legal_documents()

View File

@ -2,26 +2,17 @@ import asyncio
import aiosmtplib import aiosmtplib
from email.mime.text import MIMEText from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from typing import Optional from typing import Optional, Dict, Any
from .settings import settings from .settings import settings
class EmailTask: class EmailTask:
def __init__( def __init__(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
self,
to_email: str,
subject: str,
body: str,
html: Optional[str] = None,
**kwargs,
):
self.to_email = to_email self.to_email = to_email
self.subject = subject self.subject = subject
self.body = body self.body = body
self.html = html self.html = html
self.kwargs = kwargs self.kwargs = kwargs
class EmailService: class EmailService:
def __init__(self): def __init__(self):
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
@ -47,14 +38,7 @@ class EmailService:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
async def send_email( async def send_email(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
self,
to_email: str,
subject: str,
body: str,
html: Optional[str] = None,
**kwargs,
):
"""Queue an email for sending""" """Queue an email for sending"""
task = EmailTask(to_email, subject, body, html, **kwargs) task = EmailTask(to_email, subject, body, html, **kwargs)
await self.queue.put(task) await self.queue.put(task)
@ -76,26 +60,22 @@ class EmailService:
async def _send_email_task(self, task: EmailTask): async def _send_email_task(self, task: EmailTask):
"""Send a single email task""" """Send a single email task"""
if ( if not settings.SMTP_HOST or not settings.SMTP_USERNAME or not settings.SMTP_PASSWORD:
not settings.SMTP_HOST
or not settings.SMTP_USERNAME
or not settings.SMTP_PASSWORD
):
print("SMTP not configured, skipping email send") print("SMTP not configured, skipping email send")
return return
msg = MIMEMultipart("alternative") msg = MIMEMultipart('alternative')
msg["From"] = settings.SMTP_SENDER_EMAIL msg['From'] = settings.SMTP_SENDER_EMAIL
msg["To"] = task.to_email msg['To'] = task.to_email
msg["Subject"] = task.subject msg['Subject'] = task.subject
# Add text part # Add text part
text_part = MIMEText(task.body, "plain") text_part = MIMEText(task.body, 'plain')
msg.attach(text_part) msg.attach(text_part)
# Add HTML part if provided # Add HTML part if provided
if task.html: if task.html:
html_part = MIMEText(task.html, "html") html_part = MIMEText(task.html, 'html')
msg.attach(html_part) msg.attach(html_part)
try: try:
@ -106,7 +86,7 @@ class EmailService:
port=settings.SMTP_PORT, port=settings.SMTP_PORT,
username=settings.SMTP_USERNAME, username=settings.SMTP_USERNAME,
password=settings.SMTP_PASSWORD, password=settings.SMTP_PASSWORD,
use_tls=True, use_tls=True
) as smtp: ) as smtp:
await smtp.send_message(msg) await smtp.send_message(msg)
print(f"Email sent to {task.to_email}") print(f"Email sent to {task.to_email}")
@ -119,10 +99,7 @@ class EmailService:
try: try:
await smtp.starttls() await smtp.starttls()
except Exception as tls_error: except Exception as tls_error:
if ( if "already using" in str(tls_error).lower() or "tls" in str(tls_error).lower():
"already using" in str(tls_error).lower()
or "tls" in str(tls_error).lower()
):
# Connection is already using TLS, proceed without starttls # Connection is already using TLS, proceed without starttls
pass pass
else: else:
@ -134,23 +111,14 @@ class EmailService:
print(f"Failed to send email to {task.to_email}: {e}") print(f"Failed to send email to {task.to_email}: {e}")
raise # Re-raise to let caller handle raise # Re-raise to let caller handle
# Global email service instance # Global email service instance
email_service = EmailService() email_service = EmailService()
# Convenience functions # Convenience functions
async def send_email( async def send_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
):
"""Send an email asynchronously""" """Send an email asynchronously"""
await email_service.send_email(to_email, subject, body, html, **kwargs) await email_service.send_email(to_email, subject, body, html, **kwargs)
def queue_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
def queue_email(
to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
):
"""Queue an email for sending (fire and forget)""" """Queue an email for sending (fire and forget)"""
asyncio.create_task( asyncio.create_task(email_service.send_email(to_email, subject, body, html, **kwargs))
email_service.send_email(to_email, subject, body, html, **kwargs)
)

View File

@ -2,33 +2,18 @@ import argparse
import uvicorn import uvicorn
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException from fastapi import FastAPI, Request, status, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse
from tortoise.contrib.fastapi import register_tortoise from tortoise.contrib.fastapi import register_tortoise
from .settings import settings from .settings import settings
from .routers import ( from .routers import auth, users, folders, files, shares, search, admin, starred, billing, admin_billing
auth,
users,
folders,
files,
shares,
search,
admin,
starred,
billing,
admin_billing,
)
from . import webdav from . import webdav
from .schemas import ErrorResponse from .schemas import ErrorResponse
from .middleware.usage_tracking import UsageTrackingMiddleware
logging.basicConfig( logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
logger.info("Starting up...") logger.info("Starting up...")
@ -36,7 +21,6 @@ async def lifespan(app: FastAPI):
from .billing.scheduler import start_scheduler from .billing.scheduler import start_scheduler
from .billing.models import PricingConfig from .billing.models import PricingConfig
from .mail import email_service from .mail import email_service
start_scheduler() start_scheduler()
logger.info("Billing scheduler started") logger.info("Billing scheduler started")
await email_service.start() await email_service.start()
@ -44,61 +28,28 @@ async def lifespan(app: FastAPI):
pricing_count = await PricingConfig.all().count() pricing_count = await PricingConfig.all().count()
if pricing_count == 0: if pricing_count == 0:
from decimal import Decimal from decimal import Decimal
await PricingConfig.create(config_key='storage_per_gb_month', config_value=Decimal('0.0045'), description='Storage cost per GB per month', unit='per_gb_month')
await PricingConfig.create( await PricingConfig.create(config_key='bandwidth_egress_per_gb', config_value=Decimal('0.009'), description='Bandwidth egress cost per GB', unit='per_gb')
config_key="storage_per_gb_month", await PricingConfig.create(config_key='bandwidth_ingress_per_gb', config_value=Decimal('0.0'), description='Bandwidth ingress cost per GB (free)', unit='per_gb')
config_value=Decimal("0.0045"), await PricingConfig.create(config_key='free_tier_storage_gb', config_value=Decimal('15'), description='Free tier storage in GB', unit='gb')
description="Storage cost per GB per month", await PricingConfig.create(config_key='free_tier_bandwidth_gb', config_value=Decimal('15'), description='Free tier bandwidth in GB per month', unit='gb')
unit="per_gb_month", await PricingConfig.create(config_key='tax_rate_default', config_value=Decimal('0.0'), description='Default tax rate (0 = no tax)', unit='percentage')
)
await PricingConfig.create(
config_key="bandwidth_egress_per_gb",
config_value=Decimal("0.009"),
description="Bandwidth egress cost per GB",
unit="per_gb",
)
await PricingConfig.create(
config_key="bandwidth_ingress_per_gb",
config_value=Decimal("0.0"),
description="Bandwidth ingress cost per GB (free)",
unit="per_gb",
)
await PricingConfig.create(
config_key="free_tier_storage_gb",
config_value=Decimal("15"),
description="Free tier storage in GB",
unit="gb",
)
await PricingConfig.create(
config_key="free_tier_bandwidth_gb",
config_value=Decimal("15"),
description="Free tier bandwidth in GB per month",
unit="gb",
)
await PricingConfig.create(
config_key="tax_rate_default",
config_value=Decimal("0.0"),
description="Default tax rate (0 = no tax)",
unit="percentage",
)
logger.info("Default pricing configuration initialized") logger.info("Default pricing configuration initialized")
yield yield
from .billing.scheduler import stop_scheduler from .billing.scheduler import stop_scheduler
stop_scheduler() stop_scheduler()
logger.info("Billing scheduler stopped") logger.info("Billing scheduler stopped")
await email_service.stop() await email_service.stop()
logger.info("Email service stopped") logger.info("Email service stopped")
print("Shutting down...") print("Shutting down...")
app = FastAPI( app = FastAPI(
title="MyWebdav Cloud Storage", title="MyWebdav Cloud Storage",
description="A commercial cloud storage web application", description="A commercial cloud storage web application",
version="0.1.0", version="0.1.0",
lifespan=lifespan, lifespan=lifespan
) )
app.include_router(auth.router) app.include_router(auth.router)
@ -113,6 +64,8 @@ app.include_router(billing.router)
app.include_router(admin_billing.router) app.include_router(admin_billing.router)
app.include_router(webdav.router) app.include_router(webdav.router)
from .middleware.usage_tracking import UsageTrackingMiddleware
app.add_middleware(UsageTrackingMiddleware) app.add_middleware(UsageTrackingMiddleware)
app.mount("/static", StaticFiles(directory="static"), name="static") app.mount("/static", StaticFiles(directory="static"), name="static")
@ -120,39 +73,35 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
register_tortoise( register_tortoise(
app, app,
db_url=settings.DATABASE_URL, db_url=settings.DATABASE_URL,
modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]}, modules={
"models": ["mywebdav.models"],
"billing": ["mywebdav.billing.models"]
},
generate_schemas=True, generate_schemas=True,
add_exception_handlers=True, add_exception_handlers=True,
) )
@app.exception_handler(HTTPException) @app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException): async def http_exception_handler(request: Request, exc: HTTPException):
logger.error( logger.error(f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}")
f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}"
)
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content=ErrorResponse(code=exc.status_code, message=exc.detail).model_dump(), content=ErrorResponse(code=exc.status_code, message=exc.detail).dict(),
) )
@app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse @app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse
async def read_root(): async def read_root():
with open("static/index.html", "r") as f: with open("static/index.html", "r") as f:
return f.read() return f.read()
def main(): def main():
parser = argparse.ArgumentParser(description="Run the MyWebdav application.") parser = argparse.ArgumentParser(description="Run the MyWebdav application.")
parser.add_argument( parser.add_argument("--host", type=str, default="0.0.0.0", help="Host address to bind to")
"--host", type=str, default="0.0.0.0", help="Host address to bind to"
)
parser.add_argument("--port", type=int, default=8000, help="Port to listen on") parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
args = parser.parse_args() args = parser.parse_args()
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -2,35 +2,31 @@ from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from ..billing.usage_tracker import UsageTracker from ..billing.usage_tracker import UsageTracker
class UsageTrackingMiddleware(BaseHTTPMiddleware): class UsageTrackingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
response = await call_next(request) response = await call_next(request)
if hasattr(request.state, "user") and request.state.user: if hasattr(request.state, 'user') and request.state.user:
user = request.state.user user = request.state.user
if ( if request.method in ['POST', 'PUT'] and '/files/upload' in request.url.path:
request.method in ["POST", "PUT"] content_length = response.headers.get('content-length')
and "/files/upload" in request.url.path
):
content_length = response.headers.get("content-length")
if content_length: if content_length:
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
user=user, user=user,
amount_bytes=int(content_length), amount_bytes=int(content_length),
direction="up", direction='up',
metadata={"path": request.url.path}, metadata={'path': request.url.path}
) )
elif request.method == "GET" and "/files/download" in request.url.path: elif request.method == 'GET' and '/files/download' in request.url.path:
content_length = response.headers.get("content-length") content_length = response.headers.get('content-length')
if content_length: if content_length:
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
user=user, user=user,
amount_bytes=int(content_length), amount_bytes=int(content_length),
direction="down", direction='down',
metadata={"path": request.url.path}, metadata={'path': request.url.path}
) )
return response return response

View File

@ -1,9 +1,9 @@
from tortoise import fields, models from tortoise import fields, models
from tortoise.contrib.pydantic import pydantic_model_creator from tortoise.contrib.pydantic import pydantic_model_creator
from datetime import datetime
class User(models.Model): class User(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
username = fields.CharField(max_length=20, unique=True) username = fields.CharField(max_length=20, unique=True)
email = fields.CharField(max_length=255, unique=True) email = fields.CharField(max_length=255, unique=True)
hashed_password = fields.CharField(max_length=255) hashed_password = fields.CharField(max_length=255)
@ -11,9 +11,7 @@ class User(models.Model):
is_superuser = fields.BooleanField(default=False) is_superuser = fields.BooleanField(default=False)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
storage_quota_bytes = fields.BigIntField( storage_quota_bytes = fields.BigIntField(default=10 * 1024 * 1024 * 1024) # 10 GB default
default=10 * 1024 * 1024 * 1024
) # 10 GB default
used_storage_bytes = fields.BigIntField(default=0) used_storage_bytes = fields.BigIntField(default=0)
plan_type = fields.CharField(max_length=50, default="free") plan_type = fields.CharField(max_length=50, default="free")
two_factor_secret = fields.CharField(max_length=255, null=True) two_factor_secret = fields.CharField(max_length=255, null=True)
@ -26,16 +24,11 @@ class User(models.Model):
def __str__(self): def __str__(self):
return self.username return self.username
class Folder(models.Model): class Folder(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
name = fields.CharField(max_length=255) name = fields.CharField(max_length=255)
parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField( parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField("models.Folder", related_name="children", null=True)
"models.Folder", related_name="children", null=True owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="folders")
)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="folders"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
is_deleted = fields.BooleanField(default=False) is_deleted = fields.BooleanField(default=False)
@ -43,28 +36,21 @@ class Folder(models.Model):
class Meta: class Meta:
table = "folders" table = "folders"
unique_together = ( unique_together = (("name", "parent", "owner"),) # Ensure unique folder names within a parent for an owner
("name", "parent", "owner"),
) # Ensure unique folder names within a parent for an owner
def __str__(self): def __str__(self):
return self.name return self.name
class File(models.Model): class File(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
name = fields.CharField(max_length=255) name = fields.CharField(max_length=255)
path = fields.CharField(max_length=1024) # Internal storage path path = fields.CharField(max_length=1024) # Internal storage path
size = fields.BigIntField() size = fields.BigIntField()
mime_type = fields.CharField(max_length=255) mime_type = fields.CharField(max_length=255)
file_hash = fields.CharField(max_length=64, null=True) # SHA-256 file_hash = fields.CharField(max_length=64, null=True) # SHA-256
thumbnail_path = fields.CharField(max_length=1024, null=True) thumbnail_path = fields.CharField(max_length=1024, null=True)
parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField( parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="files", null=True)
"models.Folder", related_name="files", null=True owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="files")
)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="files"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True) updated_at = fields.DatetimeField(auto_now=True)
is_deleted = fields.BooleanField(default=False) is_deleted = fields.BooleanField(default=False)
@ -74,19 +60,14 @@ class File(models.Model):
class Meta: class Meta:
table = "files" table = "files"
unique_together = ( unique_together = (("name", "parent", "owner"),) # Ensure unique file names within a parent for an owner
("name", "parent", "owner"),
) # Ensure unique file names within a parent for an owner
def __str__(self): def __str__(self):
return self.name return self.name
class FileVersion(models.Model): class FileVersion(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField( file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="versions")
"models.File", related_name="versions"
)
version_path = fields.CharField(max_length=1024) version_path = fields.CharField(max_length=1024)
size = fields.BigIntField() size = fields.BigIntField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
@ -94,69 +75,47 @@ class FileVersion(models.Model):
class Meta: class Meta:
table = "file_versions" table = "file_versions"
class Share(models.Model): class Share(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
token = fields.CharField(max_length=64, unique=True) token = fields.CharField(max_length=64, unique=True)
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField( file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="shares", null=True)
"models.File", related_name="shares", null=True folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="shares", null=True)
) owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="shares")
folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
"models.Folder", related_name="shares", null=True
)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="shares"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
expires_at = fields.DatetimeField(null=True) expires_at = fields.DatetimeField(null=True)
password_protected = fields.BooleanField(default=False) password_protected = fields.BooleanField(default=False)
hashed_password = fields.CharField(max_length=255, null=True) hashed_password = fields.CharField(max_length=255, null=True)
access_count = fields.IntField(default=0) access_count = fields.IntField(default=0)
permission_level = fields.CharField( permission_level = fields.CharField(max_length=50, default="viewer") # viewer, uploader, editor
max_length=50, default="viewer"
) # viewer, uploader, editor
class Meta: class Meta:
table = "shares" table = "shares"
class Team(models.Model): class Team(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
name = fields.CharField(max_length=255, unique=True) name = fields.CharField(max_length=255, unique=True)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="owned_teams")
"models.User", related_name="owned_teams" members: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User", related_name="teams", through="team_members")
)
members: fields.ManyToManyRelation[User] = fields.ManyToManyField(
"models.User", related_name="teams", through="team_members"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Meta: class Meta:
table = "teams" table = "teams"
class TeamMember(models.Model): class TeamMember(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField( team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField("models.Team", related_name="team_members")
"models.Team", related_name="team_members" user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="user_teams")
) role = fields.CharField(max_length=50, default="member") # owner, admin, member
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="user_teams"
)
role = fields.CharField(max_length=50, default="member") # owner, admin, member
class Meta: class Meta:
table = "team_members" table = "team_members"
unique_together = (("team", "user"),) unique_together = (("team", "user"),)
class Activity(models.Model): class Activity(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="activities", null=True)
"models.User", related_name="activities", null=True
)
action = fields.CharField(max_length=255) action = fields.CharField(max_length=255)
target_type = fields.CharField(max_length=50) # file, folder, share, user, team target_type = fields.CharField(max_length=50) # file, folder, share, user, team
target_id = fields.IntField() target_id = fields.IntField()
ip_address = fields.CharField(max_length=45, null=True) ip_address = fields.CharField(max_length=45, null=True)
timestamp = fields.DatetimeField(auto_now_add=True) timestamp = fields.DatetimeField(auto_now_add=True)
@ -164,18 +123,13 @@ class Activity(models.Model):
class Meta: class Meta:
table = "activities" table = "activities"
class FileRequest(models.Model): class FileRequest(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
title = fields.CharField(max_length=255) title = fields.CharField(max_length=255)
description = fields.TextField(null=True) description = fields.TextField(null=True)
token = fields.CharField(max_length=64, unique=True) token = fields.CharField(max_length=64, unique=True)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="file_requests")
"models.User", related_name="file_requests" target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="file_requests")
)
target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
"models.Folder", related_name="file_requests"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
expires_at = fields.DatetimeField(null=True) expires_at = fields.DatetimeField(null=True)
is_active = fields.BooleanField(default=True) is_active = fields.BooleanField(default=True)
@ -183,9 +137,8 @@ class FileRequest(models.Model):
class Meta: class Meta:
table = "file_requests" table = "file_requests"
class WebDAVProperty(models.Model): class WebDAVProperty(models.Model):
id = fields.IntField(primary_key=True) id = fields.IntField(pk=True)
resource_type = fields.CharField(max_length=10) resource_type = fields.CharField(max_length=10)
resource_id = fields.IntField() resource_id = fields.IntField()
namespace = fields.CharField(max_length=255) namespace = fields.CharField(max_length=255)
@ -198,8 +151,5 @@ class WebDAVProperty(models.Model):
table = "webdav_properties" table = "webdav_properties"
unique_together = (("resource_type", "resource_id", "namespace", "name"),) unique_together = (("resource_type", "resource_id", "namespace", "name"),)
User_Pydantic = pydantic_model_creator(User, name="User_Pydantic") User_Pydantic = pydantic_model_creator(User, name="User_Pydantic")
UserIn_Pydantic = pydantic_model_creator( UserIn_Pydantic = pydantic_model_creator(User, name="UserIn_Pydantic", exclude_readonly=True)
User, name="UserIn_Pydantic", exclude_readonly=True
)

View File

@ -1,4 +1,4 @@
from typing import List from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
@ -13,25 +13,18 @@ router = APIRouter(
responses={403: {"description": "Not enough permissions"}}, responses={403: {"description": "Not enough permissions"}},
) )
@router.get("/users", response_model=List[User_Pydantic]) @router.get("/users", response_model=List[User_Pydantic])
async def get_all_users(): async def get_all_users():
return await User.all() return await User.all()
@router.get("/users/{user_id}", response_model=User_Pydantic) @router.get("/users/{user_id}", response_model=User_Pydantic)
async def get_user(user_id: int): async def get_user(user_id: int):
user = await User.get_or_none(id=user_id) user = await User.get_or_none(id=user_id)
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
return user return user
@router.post("/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED)
@router.post(
"/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED
)
async def create_user_by_admin(user_in: UserCreate): async def create_user_by_admin(user_in: UserCreate):
user = await User.get_or_none(username=user_in.username) user = await User.get_or_none(username=user_in.username)
if user: if user:
@ -51,81 +44,66 @@ async def create_user_by_admin(user_in: UserCreate):
username=user_in.username, username=user_in.username,
email=user_in.email, email=user_in.email,
hashed_password=hashed_password, hashed_password=hashed_password,
is_superuser=False, # Admin creates regular users by default is_superuser=False, # Admin creates regular users by default
is_active=True, is_active=True,
) )
return await User_Pydantic.from_tortoise_orm(user) return await User_Pydantic.from_tortoise_orm(user)
@router.put("/users/{user_id}", response_model=User_Pydantic) @router.put("/users/{user_id}", response_model=User_Pydantic)
async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate): async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate):
user = await User.get_or_none(id=user_id) user = await User.get_or_none(id=user_id)
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
if user_update.username is not None and user_update.username != user.username: if user_update.username is not None and user_update.username != user.username:
if await User.get_or_none(username=user_update.username): if await User.get_or_none(username=user_update.username):
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken")
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
)
user.username = user_update.username user.username = user_update.username
if user_update.email is not None and user_update.email != user.email: if user_update.email is not None and user_update.email != user.email:
if await User.get_or_none(email=user_update.email): if await User.get_or_none(email=user_update.email):
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered",
)
user.email = user_update.email user.email = user_update.email
if user_update.password is not None: if user_update.password is not None:
user.hashed_password = get_password_hash(user_update.password) user.hashed_password = get_password_hash(user_update.password)
if user_update.is_active is not None: if user_update.is_active is not None:
user.is_active = user_update.is_active user.is_active = user_update.is_active
if user_update.is_superuser is not None: if user_update.is_superuser is not None:
user.is_superuser = user_update.is_superuser user.is_superuser = user_update.is_superuser
if user_update.storage_quota_bytes is not None: if user_update.storage_quota_bytes is not None:
user.storage_quota_bytes = user_update.storage_quota_bytes user.storage_quota_bytes = user_update.storage_quota_bytes
if user_update.plan_type is not None: if user_update.plan_type is not None:
user.plan_type = user_update.plan_type user.plan_type = user_update.plan_type
if user_update.is_2fa_enabled is not None: if user_update.is_2fa_enabled is not None:
user.is_2fa_enabled = user_update.is_2fa_enabled user.is_2fa_enabled = user_update.is_2fa_enabled
if not user_update.is_2fa_enabled: if not user_update.is_2fa_enabled:
user.two_factor_secret = None user.two_factor_secret = None
user.recovery_codes = None user.recovery_codes = None
await user.save() await user.save()
return await User_Pydantic.from_tortoise_orm(user) return await User_Pydantic.from_tortoise_orm(user)
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user_by_admin(user_id: int): async def delete_user_by_admin(user_id: int):
user = await User.get_or_none(id=user_id) user = await User.get_or_none(id=user_id)
if not user: if not user:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
await user.delete() await user.delete()
return {"message": "User deleted successfully"} return {"message": "User deleted successfully"}
@router.post("/test-email") @router.post("/test-email")
async def send_test_email( async def send_test_email(to_email: str, subject: str = "Test Email", body: str = "This is a test email"):
to_email: str, subject: str = "Test Email", body: str = "This is a test email"
):
from ..mail import queue_email from ..mail import queue_email
queue_email( queue_email(
to_email=to_email, to_email=to_email,
subject=subject, subject=subject,
body=body, body=body,
html=f"<h1>{subject}</h1><p>{body}</p>", html=f"<h1>{subject}</h1><p>{body}</p>"
) )
return {"message": "Test email queued"} return {"message": "Test email queued"}

View File

@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from typing import List
from decimal import Decimal from decimal import Decimal
from pydantic import BaseModel from pydantic import BaseModel
@ -7,25 +8,21 @@ from ..models import User
from ..billing.models import PricingConfig, Invoice, SubscriptionPlan from ..billing.models import PricingConfig, Invoice, SubscriptionPlan
from ..billing.invoice_generator import InvoiceGenerator from ..billing.invoice_generator import InvoiceGenerator
def require_superuser(current_user: User = Depends(get_current_user)): def require_superuser(current_user: User = Depends(get_current_user)):
if not current_user.is_superuser: if not current_user.is_superuser:
raise HTTPException(status_code=403, detail="Superuser privileges required") raise HTTPException(status_code=403, detail="Superuser privileges required")
return current_user return current_user
router = APIRouter( router = APIRouter(
prefix="/api/admin/billing", prefix="/api/admin/billing",
tags=["admin", "billing"], tags=["admin", "billing"],
dependencies=[Depends(require_superuser)], dependencies=[Depends(require_superuser)]
) )
class PricingConfigUpdate(BaseModel): class PricingConfigUpdate(BaseModel):
config_key: str config_key: str
config_value: float config_value: float
class PlanCreate(BaseModel): class PlanCreate(BaseModel):
name: str name: str
display_name: str display_name: str
@ -35,7 +32,6 @@ class PlanCreate(BaseModel):
price_monthly: float price_monthly: float
price_yearly: float = None price_yearly: float = None
@router.get("/pricing") @router.get("/pricing")
async def get_all_pricing(current_user: User = Depends(require_superuser)): async def get_all_pricing(current_user: User = Depends(require_superuser)):
configs = await PricingConfig.all() configs = await PricingConfig.all()
@ -46,17 +42,16 @@ async def get_all_pricing(current_user: User = Depends(require_superuser)):
"config_value": float(c.config_value), "config_value": float(c.config_value),
"description": c.description, "description": c.description,
"unit": c.unit, "unit": c.unit,
"updated_at": c.updated_at, "updated_at": c.updated_at
} }
for c in configs for c in configs
] ]
@router.put("/pricing/{config_id}") @router.put("/pricing/{config_id}")
async def update_pricing( async def update_pricing(
config_id: int, config_id: int,
update: PricingConfigUpdate, update: PricingConfigUpdate,
current_user: User = Depends(require_superuser), current_user: User = Depends(require_superuser)
): ):
config = await PricingConfig.get_or_none(id=config_id) config = await PricingConfig.get_or_none(id=config_id)
if not config: if not config:
@ -68,10 +63,11 @@ async def update_pricing(
return {"message": "Pricing updated successfully"} return {"message": "Pricing updated successfully"}
@router.post("/generate-invoices/{year}/{month}") @router.post("/generate-invoices/{year}/{month}")
async def generate_all_invoices( async def generate_all_invoices(
year: int, month: int, current_user: User = Depends(require_superuser) year: int,
month: int,
current_user: User = Depends(require_superuser)
): ):
users = await User.filter(is_active=True).all() users = await User.filter(is_active=True).all()
generated = [] generated = []
@ -80,36 +76,35 @@ async def generate_all_invoices(
for user in users: for user in users:
invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, month) invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, month)
if invoice: if invoice:
generated.append( generated.append({
{ "user_id": user.id,
"user_id": user.id, "invoice_id": invoice.id,
"invoice_id": invoice.id, "total": float(invoice.total)
"total": float(invoice.total), })
}
)
else: else:
skipped.append(user.id) skipped.append(user.id)
return {"generated": len(generated), "skipped": len(skipped), "invoices": generated} return {
"generated": len(generated),
"skipped": len(skipped),
"invoices": generated
}
@router.post("/plans") @router.post("/plans")
async def create_plan( async def create_plan(
plan_data: PlanCreate, current_user: User = Depends(require_superuser) plan_data: PlanCreate,
current_user: User = Depends(require_superuser)
): ):
plan = await SubscriptionPlan.create(**plan_data.dict()) plan = await SubscriptionPlan.create(**plan_data.dict())
return {"id": plan.id, "message": "Plan created successfully"} return {"id": plan.id, "message": "Plan created successfully"}
@router.get("/stats") @router.get("/stats")
async def get_billing_stats(current_user: User = Depends(require_superuser)): async def get_billing_stats(current_user: User = Depends(require_superuser)):
from tortoise.functions import Sum from tortoise.functions import Sum, Count
total_revenue = ( total_revenue = await Invoice.filter(status="paid").annotate(
await Invoice.filter(status="paid") total_sum=Sum("total")
.annotate(total_sum=Sum("total")) ).values("total_sum")
.values("total_sum")
)
invoice_count = await Invoice.all().count() invoice_count = await Invoice.all().count()
pending_invoices = await Invoice.filter(status="open").count() pending_invoices = await Invoice.filter(status="open").count()
@ -117,5 +112,5 @@ async def get_billing_stats(current_user: User = Depends(require_superuser)):
return { return {
"total_revenue": float(total_revenue[0]["total_sum"] or 0), "total_revenue": float(total_revenue[0]["total_sum"] or 0),
"total_invoices": invoice_count, "total_invoices": invoice_count,
"pending_invoices": pending_invoices, "pending_invoices": pending_invoices
} }

View File

@ -4,23 +4,13 @@ from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel from pydantic import BaseModel
from ..auth import ( from ..auth import authenticate_user, create_access_token, get_password_hash, get_current_user, get_current_verified_user, verify_password
authenticate_user,
create_access_token,
get_password_hash,
get_current_user,
get_current_verified_user,
verify_password,
)
from ..models import User from ..models import User
from ..schemas import Token, UserCreate from ..schemas import Token, UserCreate, TokenData, UserLoginWith2FA
from ..two_factor import ( from ..two_factor import (
generate_totp_secret, generate_totp_secret, generate_totp_uri, generate_qr_code_base64,
generate_totp_uri, verify_totp_code, generate_recovery_codes, hash_recovery_codes,
generate_qr_code_base64, verify_recovery_codes
verify_totp_code,
generate_recovery_codes,
hash_recovery_codes,
) )
router = APIRouter( router = APIRouter(
@ -28,33 +18,27 @@ router = APIRouter(
tags=["auth"], tags=["auth"],
) )
class LoginRequest(BaseModel): class LoginRequest(BaseModel):
username: str username: str
password: str password: str
class TwoFactorLogin(BaseModel): class TwoFactorLogin(BaseModel):
username: str username: str
password: str password: str
two_factor_code: Optional[str] = None two_factor_code: Optional[str] = None
class TwoFactorSetupResponse(BaseModel): class TwoFactorSetupResponse(BaseModel):
secret: str secret: str
qr_code_base64: str qr_code_base64: str
recovery_codes: List[str] recovery_codes: List[str]
class TwoFactorCode(BaseModel): class TwoFactorCode(BaseModel):
two_factor_code: str two_factor_code: str
class TwoFactorDisable(BaseModel): class TwoFactorDisable(BaseModel):
password: str password: str
two_factor_code: str two_factor_code: str
@router.post("/register", response_model=Token) @router.post("/register", response_model=Token)
async def register_user(user_in: UserCreate): async def register_user(user_in: UserCreate):
user = await User.get_or_none(username=user_in.username) user = await User.get_or_none(username=user_in.username)
@ -79,26 +63,22 @@ async def register_user(user_in: UserCreate):
# Send welcome email # Send welcome email
from ..mail import queue_email from ..mail import queue_email
queue_email( queue_email(
to_email=user.email, to_email=user.email,
subject="Welcome to MyWebdav!", subject="Welcome to MyWebdav!",
body=f"Hi {user.username},\n\nWelcome to MyWebdav! Your account has been created successfully.\n\nBest regards,\nThe MyWebdav Team", body=f"Hi {user.username},\n\nWelcome to MyWebdav! Your account has been created successfully.\n\nBest regards,\nThe MyWebdav Team",
html=f"<h1>Welcome to MyWebdav!</h1><p>Hi {user.username},</p><p>Welcome to MyWebdav! Your account has been created successfully.</p><p>Best regards,<br>The MyWebdav Team</p>", html=f"<h1>Welcome to MyWebdav!</h1><p>Hi {user.username},</p><p>Welcome to MyWebdav! Your account has been created successfully.</p><p>Best regards,<br>The MyWebdav Team</p>"
) )
access_token_expires = timedelta(minutes=30) # Use settings access_token_expires = timedelta(minutes=30) # Use settings
access_token = create_access_token( access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires data={"sub": user.username}, expires_delta=access_token_expires
) )
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.post("/token", response_model=Token) @router.post("/token", response_model=Token)
async def login_for_access_token(login_data: LoginRequest): async def login_for_access_token(login_data: LoginRequest):
auth_result = await authenticate_user( auth_result = await authenticate_user(login_data.username, login_data.password, None)
login_data.username, login_data.password, None
)
if not auth_result: if not auth_result:
raise HTTPException( raise HTTPException(
@ -117,26 +97,16 @@ async def login_for_access_token(login_data: LoginRequest):
access_token_expires = timedelta(minutes=30) access_token_expires = timedelta(minutes=30)
access_token = create_access_token( access_token = create_access_token(
data={"sub": user.username}, data={"sub": user.username}, expires_delta=access_token_expires, two_factor_verified=True
expires_delta=access_token_expires,
two_factor_verified=True,
) )
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.post("/2fa/setup", response_model=TwoFactorSetupResponse) @router.post("/2fa/setup", response_model=TwoFactorSetupResponse)
async def setup_two_factor_authentication( async def setup_two_factor_authentication(current_user: User = Depends(get_current_user)):
current_user: User = Depends(get_current_user),
):
if current_user.is_2fa_enabled: if current_user.is_2fa_enabled:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
)
if current_user.two_factor_secret: if current_user.two_factor_secret:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup already initiated. Verify or disable first.")
status_code=status.HTTP_400_BAD_REQUEST,
detail="2FA setup already initiated. Verify or disable first.",
)
secret = generate_totp_secret() secret = generate_totp_secret()
current_user.two_factor_secret = secret current_user.two_factor_secret = secret
@ -150,66 +120,39 @@ async def setup_two_factor_authentication(
current_user.recovery_codes = ",".join(hashed_recovery_codes) current_user.recovery_codes = ",".join(hashed_recovery_codes)
await current_user.save() await current_user.save()
return TwoFactorSetupResponse( return TwoFactorSetupResponse(secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes)
secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes
)
@router.post("/2fa/verify", response_model=Token) @router.post("/2fa/verify", response_model=Token)
async def verify_two_factor_authentication( async def verify_two_factor_authentication(two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)):
two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)
):
if current_user.is_2fa_enabled: if current_user.is_2fa_enabled:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
)
if not current_user.two_factor_secret: if not current_user.two_factor_secret:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated.")
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated."
)
if not verify_totp_code( if not verify_totp_code(current_user.two_factor_secret, two_factor_code_data.two_factor_code):
current_user.two_factor_secret, two_factor_code_data.two_factor_code raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
)
current_user.is_2fa_enabled = True current_user.is_2fa_enabled = True
await current_user.save() await current_user.save()
access_token_expires = timedelta(minutes=30) # Use settings access_token_expires = timedelta(minutes=30) # Use settings
access_token = create_access_token( access_token = create_access_token(
data={"sub": current_user.username}, data={"sub": current_user.username}, expires_delta=access_token_expires, two_factor_verified=True
expires_delta=access_token_expires,
two_factor_verified=True,
) )
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.post("/2fa/disable", response_model=dict) @router.post("/2fa/disable", response_model=dict)
async def disable_two_factor_authentication( async def disable_two_factor_authentication(disable_data: TwoFactorDisable, current_user: User = Depends(get_current_verified_user)):
disable_data: TwoFactorDisable,
current_user: User = Depends(get_current_verified_user),
):
if not current_user.is_2fa_enabled: if not current_user.is_2fa_enabled:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
)
# Verify password # Verify password
if not verify_password(disable_data.password, current_user.hashed_password): if not verify_password(disable_data.password, current_user.hashed_password):
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password.")
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password."
)
# Verify 2FA code # Verify 2FA code
if not verify_totp_code( if not verify_totp_code(current_user.two_factor_secret, disable_data.two_factor_code):
current_user.two_factor_secret, disable_data.two_factor_code raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
)
current_user.two_factor_secret = None current_user.two_factor_secret = None
current_user.is_2fa_enabled = False current_user.is_2fa_enabled = False
@ -218,15 +161,10 @@ async def disable_two_factor_authentication(
return {"message": "2FA disabled successfully."} return {"message": "2FA disabled successfully."}
@router.get("/2fa/recovery-codes", response_model=List[str]) @router.get("/2fa/recovery-codes", response_model=List[str])
async def get_new_recovery_codes( async def get_new_recovery_codes(current_user: User = Depends(get_current_verified_user)):
current_user: User = Depends(get_current_verified_user),
):
if not current_user.is_2fa_enabled: if not current_user.is_2fa_enabled:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
)
recovery_codes = generate_recovery_codes() recovery_codes = generate_recovery_codes()
hashed_recovery_codes = hash_recovery_codes(recovery_codes) hashed_recovery_codes = hash_recovery_codes(recovery_codes)

View File

@ -1,25 +1,25 @@
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing import List, Optional from typing import List, Optional
from datetime import datetime, date from datetime import datetime, date
from decimal import Decimal
import calendar
from ..auth import get_current_user from ..auth import get_current_user
from ..models import User from ..models import User
from ..billing.models import ( from ..billing.models import (
Invoice, Invoice, InvoiceLineItem, UserSubscription, PricingConfig,
UserSubscription, PaymentMethod, UsageAggregate, SubscriptionPlan
PricingConfig,
PaymentMethod,
UsageAggregate,
SubscriptionPlan,
) )
from ..billing.usage_tracker import UsageTracker from ..billing.usage_tracker import UsageTracker
from ..billing.invoice_generator import InvoiceGenerator from ..billing.invoice_generator import InvoiceGenerator
from ..billing.stripe_client import StripeClient from ..billing.stripe_client import StripeClient
from pydantic import BaseModel from pydantic import BaseModel
router = APIRouter(prefix="/api/billing", tags=["billing"]) router = APIRouter(
prefix="/api/billing",
tags=["billing"]
)
class UsageResponse(BaseModel): class UsageResponse(BaseModel):
storage_gb_avg: float storage_gb_avg: float
@ -29,7 +29,6 @@ class UsageResponse(BaseModel):
total_bandwidth_gb: float total_bandwidth_gb: float
period: str period: str
class InvoiceResponse(BaseModel): class InvoiceResponse(BaseModel):
id: int id: int
invoice_number: str invoice_number: str
@ -43,7 +42,6 @@ class InvoiceResponse(BaseModel):
paid_at: Optional[datetime] paid_at: Optional[datetime]
line_items: List[dict] line_items: List[dict]
class SubscriptionResponse(BaseModel): class SubscriptionResponse(BaseModel):
id: int id: int
billing_type: str billing_type: str
@ -52,7 +50,6 @@ class SubscriptionResponse(BaseModel):
current_period_start: Optional[datetime] current_period_start: Optional[datetime]
current_period_end: Optional[datetime] current_period_end: Optional[datetime]
@router.get("/usage/current") @router.get("/usage/current")
async def get_current_usage(current_user: User = Depends(get_current_user)): async def get_current_usage(current_user: User = Depends(get_current_user)):
try: try:
@ -64,32 +61,25 @@ async def get_current_usage(current_user: User = Depends(get_current_user)):
if usage_today: if usage_today:
return { return {
"storage_gb": round(storage_bytes / (1024**3), 4), "storage_gb": round(storage_bytes / (1024**3), 4),
"bandwidth_down_gb_today": round( "bandwidth_down_gb_today": round(usage_today.bandwidth_down_bytes / (1024**3), 4),
usage_today.bandwidth_down_bytes / (1024**3), 4 "bandwidth_up_gb_today": round(usage_today.bandwidth_up_bytes / (1024**3), 4),
), "as_of": today.isoformat()
"bandwidth_up_gb_today": round(
usage_today.bandwidth_up_bytes / (1024**3), 4
),
"as_of": today.isoformat(),
} }
return { return {
"storage_gb": round(storage_bytes / (1024**3), 4), "storage_gb": round(storage_bytes / (1024**3), 4),
"bandwidth_down_gb_today": 0, "bandwidth_down_gb_today": 0,
"bandwidth_up_gb_today": 0, "bandwidth_up_gb_today": 0,
"as_of": today.isoformat(), "as_of": today.isoformat()
} }
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(status_code=500, detail=f"Failed to fetch usage data: {str(e)}")
status_code=500, detail=f"Failed to fetch usage data: {str(e)}"
)
@router.get("/usage/monthly") @router.get("/usage/monthly")
async def get_monthly_usage( async def get_monthly_usage(
year: Optional[int] = None, year: Optional[int] = None,
month: Optional[int] = None, month: Optional[int] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user)
) -> UsageResponse: ) -> UsageResponse:
try: try:
if year is None or month is None: if year is None or month is None:
@ -98,85 +88,71 @@ async def get_monthly_usage(
month = now.month month = now.month
if not (1 <= month <= 12): if not (1 <= month <= 12):
raise HTTPException( raise HTTPException(status_code=400, detail="Month must be between 1 and 12")
status_code=400, detail="Month must be between 1 and 12"
)
if not (2020 <= year <= 2100): if not (2020 <= year <= 2100):
raise HTTPException( raise HTTPException(status_code=400, detail="Year must be between 2020 and 2100")
status_code=400, detail="Year must be between 2020 and 2100"
)
usage = await UsageTracker.get_monthly_usage(current_user, year, month) usage = await UsageTracker.get_monthly_usage(current_user, year, month)
return UsageResponse(**usage, period=f"{year}-{month:02d}") return UsageResponse(
**usage,
period=f"{year}-{month:02d}"
)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}")
status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}"
)
@router.get("/invoices") @router.get("/invoices")
async def list_invoices( async def list_invoices(
limit: int = 50, offset: int = 0, current_user: User = Depends(get_current_user) limit: int = 50,
offset: int = 0,
current_user: User = Depends(get_current_user)
) -> List[InvoiceResponse]: ) -> List[InvoiceResponse]:
try: try:
if limit < 1 or limit > 100: if limit < 1 or limit > 100:
raise HTTPException( raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
status_code=400, detail="Limit must be between 1 and 100"
)
if offset < 0: if offset < 0:
raise HTTPException(status_code=400, detail="Offset must be non-negative") raise HTTPException(status_code=400, detail="Offset must be non-negative")
invoices = ( invoices = await Invoice.filter(user=current_user).order_by("-created_at").offset(offset).limit(limit).all()
await Invoice.filter(user=current_user)
.order_by("-created_at")
.offset(offset)
.limit(limit)
.all()
)
result = [] result = []
for invoice in invoices: for invoice in invoices:
line_items = await invoice.line_items.all() line_items = await invoice.line_items.all()
result.append( result.append(InvoiceResponse(
InvoiceResponse( id=invoice.id,
id=invoice.id, invoice_number=invoice.invoice_number,
invoice_number=invoice.invoice_number, period_start=invoice.period_start,
period_start=invoice.period_start, period_end=invoice.period_end,
period_end=invoice.period_end, subtotal=float(invoice.subtotal),
subtotal=float(invoice.subtotal), tax=float(invoice.tax),
tax=float(invoice.tax), total=float(invoice.total),
total=float(invoice.total), status=invoice.status,
status=invoice.status, due_date=invoice.due_date,
due_date=invoice.due_date, paid_at=invoice.paid_at,
paid_at=invoice.paid_at, line_items=[
line_items=[ {
{ "description": item.description,
"description": item.description, "quantity": float(item.quantity),
"quantity": float(item.quantity), "unit_price": float(item.unit_price),
"unit_price": float(item.unit_price), "amount": float(item.amount),
"amount": float(item.amount), "type": item.item_type
"type": item.item_type, }
} for item in line_items
for item in line_items ]
], ))
)
)
return result return result
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(status_code=500, detail=f"Failed to fetch invoices: {str(e)}")
status_code=500, detail=f"Failed to fetch invoices: {str(e)}"
)
@router.get("/invoices/{invoice_id}") @router.get("/invoices/{invoice_id}")
async def get_invoice( async def get_invoice(
invoice_id: int, current_user: User = Depends(get_current_user) invoice_id: int,
current_user: User = Depends(get_current_user)
) -> InvoiceResponse: ) -> InvoiceResponse:
invoice = await Invoice.get_or_none(id=invoice_id, user=current_user) invoice = await Invoice.get_or_none(id=invoice_id, user=current_user)
if not invoice: if not invoice:
@ -201,22 +177,21 @@ async def get_invoice(
"quantity": float(item.quantity), "quantity": float(item.quantity),
"unit_price": float(item.unit_price), "unit_price": float(item.unit_price),
"amount": float(item.amount), "amount": float(item.amount),
"type": item.item_type, "type": item.item_type
} }
for item in line_items for item in line_items
], ]
) )
@router.get("/subscription") @router.get("/subscription")
async def get_subscription( async def get_subscription(current_user: User = Depends(get_current_user)) -> SubscriptionResponse:
current_user: User = Depends(get_current_user),
) -> SubscriptionResponse:
subscription = await UserSubscription.get_or_none(user=current_user) subscription = await UserSubscription.get_or_none(user=current_user)
if not subscription: if not subscription:
subscription = await UserSubscription.create( subscription = await UserSubscription.create(
user=current_user, billing_type="pay_as_you_go", status="active" user=current_user,
billing_type="pay_as_you_go",
status="active"
) )
plan_name = None plan_name = None
@ -230,19 +205,15 @@ async def get_subscription(
plan_name=plan_name, plan_name=plan_name,
status=subscription.status, status=subscription.status,
current_period_start=subscription.current_period_start, current_period_start=subscription.current_period_start,
current_period_end=subscription.current_period_end, current_period_end=subscription.current_period_end
) )
@router.post("/payment-methods/setup-intent") @router.post("/payment-methods/setup-intent")
async def create_setup_intent(current_user: User = Depends(get_current_user)): async def create_setup_intent(current_user: User = Depends(get_current_user)):
try: try:
from ..settings import settings from ..settings import settings
if not settings.STRIPE_SECRET_KEY: if not settings.STRIPE_SECRET_KEY:
raise HTTPException( raise HTTPException(status_code=503, detail="Payment processing not configured")
status_code=503, detail="Payment processing not configured"
)
subscription = await UserSubscription.get_or_none(user=current_user) subscription = await UserSubscription.get_or_none(user=current_user)
@ -250,7 +221,7 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
customer_id = await StripeClient.create_customer( customer_id = await StripeClient.create_customer(
email=current_user.email, email=current_user.email,
name=current_user.username, name=current_user.username,
metadata={"user_id": str(current_user.id)}, metadata={"user_id": str(current_user.id)}
) )
if not subscription: if not subscription:
@ -258,30 +229,27 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
user=current_user, user=current_user,
billing_type="pay_as_you_go", billing_type="pay_as_you_go",
stripe_customer_id=customer_id, stripe_customer_id=customer_id,
status="active", status="active"
) )
else: else:
subscription.stripe_customer_id = customer_id subscription.stripe_customer_id = customer_id
await subscription.save() await subscription.save()
import stripe import stripe
StripeClient._ensure_api_key() StripeClient._ensure_api_key()
setup_intent = stripe.SetupIntent.create( setup_intent = stripe.SetupIntent.create(
customer=subscription.stripe_customer_id, payment_method_types=["card"] customer=subscription.stripe_customer_id,
payment_method_types=["card"]
) )
return { return {
"client_secret": setup_intent.client_secret, "client_secret": setup_intent.client_secret,
"customer_id": subscription.stripe_customer_id, "customer_id": subscription.stripe_customer_id
} }
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(status_code=500, detail=f"Failed to create setup intent: {str(e)}")
status_code=500, detail=f"Failed to create setup intent: {str(e)}"
)
@router.get("/payment-methods") @router.get("/payment-methods")
async def list_payment_methods(current_user: User = Depends(get_current_user)): async def list_payment_methods(current_user: User = Depends(get_current_user)):
@ -294,12 +262,11 @@ async def list_payment_methods(current_user: User = Depends(get_current_user)):
"brand": m.brand, "brand": m.brand,
"exp_month": m.exp_month, "exp_month": m.exp_month,
"exp_year": m.exp_year, "exp_year": m.exp_year,
"is_default": m.is_default, "is_default": m.is_default
} }
for m in methods for m in methods
] ]
@router.post("/webhooks/stripe") @router.post("/webhooks/stripe")
async def stripe_webhook(request: Request): async def stripe_webhook(request: Request):
import stripe import stripe
@ -331,14 +298,12 @@ async def stripe_webhook(request: Request):
event_type=event["type"], event_type=event["type"],
stripe_event_id=event_id, stripe_event_id=event_id,
data=event["data"], data=event["data"],
processed=False, processed=False
) )
if event["type"] == "invoice.payment_succeeded": if event["type"] == "invoice.payment_succeeded":
invoice_data = event["data"]["object"] invoice_data = event["data"]["object"]
mywebdav_invoice_id = invoice_data.get("metadata", {}).get( mywebdav_invoice_id = invoice_data.get("metadata", {}).get("mywebdav_invoice_id")
"mywebdav_invoice_id"
)
if mywebdav_invoice_id: if mywebdav_invoice_id:
invoice = await Invoice.get_or_none(id=int(mywebdav_invoice_id)) invoice = await Invoice.get_or_none(id=int(mywebdav_invoice_id))
@ -352,9 +317,7 @@ async def stripe_webhook(request: Request):
payment_method = event["data"]["object"] payment_method = event["data"]["object"]
customer_id = payment_method["customer"] customer_id = payment_method["customer"]
subscription = await UserSubscription.get_or_none( subscription = await UserSubscription.get_or_none(stripe_customer_id=customer_id)
stripe_customer_id=customer_id
)
if subscription: if subscription:
await PaymentMethod.create( await PaymentMethod.create(
user=subscription.user, user=subscription.user,
@ -364,7 +327,7 @@ async def stripe_webhook(request: Request):
brand=payment_method.get("card", {}).get("brand"), brand=payment_method.get("card", {}).get("brand"),
exp_month=payment_method.get("card", {}).get("exp_month"), exp_month=payment_method.get("card", {}).get("exp_month"),
exp_year=payment_method.get("card", {}).get("exp_year"), exp_year=payment_method.get("card", {}).get("exp_year"),
is_default=True, is_default=True
) )
await BillingEvent.filter(stripe_event_id=event_id).update(processed=True) await BillingEvent.filter(stripe_event_id=event_id).update(processed=True)
@ -372,10 +335,7 @@ async def stripe_webhook(request: Request):
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(status_code=500, detail=f"Webhook processing failed: {str(e)}")
status_code=500, detail=f"Webhook processing failed: {str(e)}"
)
@router.get("/pricing") @router.get("/pricing")
async def get_pricing(): async def get_pricing():
@ -384,12 +344,11 @@ async def get_pricing():
config.config_key: { config.config_key: {
"value": float(config.config_value), "value": float(config.config_value),
"description": config.description, "description": config.description,
"unit": config.unit, "unit": config.unit
} }
for config in configs for config in configs
} }
@router.get("/plans") @router.get("/plans")
async def list_plans(): async def list_plans():
plans = await SubscriptionPlan.filter(is_active=True).all() plans = await SubscriptionPlan.filter(is_active=True).all()
@ -402,16 +361,14 @@ async def list_plans():
"storage_gb": plan.storage_gb, "storage_gb": plan.storage_gb,
"bandwidth_gb": plan.bandwidth_gb, "bandwidth_gb": plan.bandwidth_gb,
"price_monthly": float(plan.price_monthly), "price_monthly": float(plan.price_monthly),
"price_yearly": float(plan.price_yearly) if plan.price_yearly else None, "price_yearly": float(plan.price_yearly) if plan.price_yearly else None
} }
for plan in plans for plan in plans
] ]
@router.get("/stripe-key") @router.get("/stripe-key")
async def get_stripe_key(): async def get_stripe_key():
from ..settings import settings from ..settings import settings
if not settings.STRIPE_PUBLISHABLE_KEY: if not settings.STRIPE_PUBLISHABLE_KEY:
raise HTTPException(status_code=503, detail="Payment processing not configured") raise HTTPException(status_code=503, detail="Payment processing not configured")
return {"publishable_key": settings.STRIPE_PUBLISHABLE_KEY} return {"publishable_key": settings.STRIPE_PUBLISHABLE_KEY}

View File

@ -1,12 +1,4 @@
from fastapi import ( from fastapi import APIRouter, Depends, UploadFile, File as FastAPIFile, HTTPException, status, Response, Form
APIRouter,
Depends,
UploadFile,
File as FastAPIFile,
HTTPException,
status,
Form,
)
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from typing import List, Optional from typing import List, Optional
import mimetypes import mimetypes
@ -19,6 +11,7 @@ from ..auth import get_current_user
from ..models import User, File, Folder from ..models import User, File, Folder
from ..schemas import FileOut from ..schemas import FileOut
from ..storage import storage_manager from ..storage import storage_manager
from ..settings import settings
from ..activity import log_activity from ..activity import log_activity
from ..thumbnails import generate_thumbnail, delete_thumbnail from ..thumbnails import generate_thumbnail, delete_thumbnail
@ -27,46 +20,35 @@ router = APIRouter(
tags=["files"], tags=["files"],
) )
class FileMove(BaseModel): class FileMove(BaseModel):
target_folder_id: Optional[int] = None target_folder_id: Optional[int] = None
class FileRename(BaseModel): class FileRename(BaseModel):
new_name: str new_name: str
class FileCopy(BaseModel): class FileCopy(BaseModel):
target_folder_id: Optional[int] = None target_folder_id: Optional[int] = None
class BatchFileOperation(BaseModel): class BatchFileOperation(BaseModel):
file_ids: List[int] file_ids: List[int]
operation: str # e.g., "delete", "star", "unstar", "move", "copy" operation: str # e.g., "delete", "star", "unstar", "move", "copy"
class BatchMoveCopyPayload(BaseModel): class BatchMoveCopyPayload(BaseModel):
target_folder_id: Optional[int] = None target_folder_id: Optional[int] = None
class FileContentUpdate(BaseModel): class FileContentUpdate(BaseModel):
content: str content: str
@router.post("/upload", response_model=FileOut, status_code=status.HTTP_201_CREATED) @router.post("/upload", response_model=FileOut, status_code=status.HTTP_201_CREATED)
async def upload_file( async def upload_file(
file: UploadFile = FastAPIFile(...), file: UploadFile = FastAPIFile(...),
folder_id: Optional[int] = Form(None), folder_id: Optional[int] = Form(None),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user)
): ):
if folder_id: if folder_id:
parent_folder = await Folder.get_or_none( parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
id=folder_id, owner=current_user, is_deleted=False
)
if not parent_folder: if not parent_folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
else: else:
parent_folder = None parent_folder = None
@ -91,7 +73,7 @@ async def upload_file(
# Generate a unique path for storage # Generate a unique path for storage
file_extension = os.path.splitext(file.filename)[1] file_extension = os.path.splitext(file.filename)[1]
unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename
storage_path = unique_filename storage_path = unique_filename
# Save file to storage # Save file to storage
@ -123,20 +105,16 @@ async def upload_file(
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.get("/download/{file_id}") @router.get("/download/{file_id}")
async def download_file(file_id: int, current_user: User = Depends(get_current_user)): async def download_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.last_accessed_at = datetime.now() db_file.last_accessed_at = datetime.now()
await db_file.save() await db_file.save()
try: try:
async def file_iterator(): async def file_iterator():
async for chunk in storage_manager.get_file(current_user.id, db_file.path): async for chunk in storage_manager.get_file(current_user.id, db_file.path):
yield chunk yield chunk
@ -144,21 +122,16 @@ async def download_file(file_id: int, current_user: User = Depends(get_current_u
return StreamingResponse( return StreamingResponse(
file_iterator(), file_iterator(),
media_type=db_file.mime_type, media_type=db_file.mime_type,
headers={"Content-Disposition": f'attachment; filename="{db_file.name}"'}, headers={"Content-Disposition": f"attachment; filename=\"{db_file.name}\""}
) )
except FileNotFoundError: except FileNotFoundError:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage"
)
@router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_file(file_id: int, current_user: User = Depends(get_current_user)): async def delete_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.is_deleted = True db_file.is_deleted = True
db_file.deleted_at = datetime.now() db_file.deleted_at = datetime.now()
@ -168,110 +141,68 @@ async def delete_file(file_id: int, current_user: User = Depends(get_current_use
return return
@router.post("/{file_id}/move", response_model=FileOut) @router.post("/{file_id}/move", response_model=FileOut)
async def move_file( async def move_file(file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)):
file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
target_folder = None target_folder = None
if move_data.target_folder_id: if move_data.target_folder_id:
target_folder = await Folder.get_or_none( target_folder = await Folder.get_or_none(id=move_data.target_folder_id, owner=current_user, is_deleted=False)
id=move_data.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
)
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False
) )
if existing_file and existing_file.id != file_id: if existing_file and existing_file.id != file_id:
raise HTTPException( raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in target folder")
status_code=status.HTTP_409_CONFLICT,
detail="File with this name already exists in target folder",
)
db_file.parent = target_folder db_file.parent = target_folder
await db_file.save() await db_file.save()
await log_activity( await log_activity(user=current_user, action="file_moved", target_type="file", target_id=file_id)
user=current_user, action="file_moved", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.post("/{file_id}/rename", response_model=FileOut) @router.post("/{file_id}/rename", response_model=FileOut)
async def rename_file( async def rename_file(file_id: int, rename_data: FileRename, current_user: User = Depends(get_current_user)):
file_id: int,
rename_data: FileRename,
current_user: User = Depends(get_current_user),
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=rename_data.new_name, name=rename_data.new_name, parent_id=db_file.parent_id, owner=current_user, is_deleted=False
parent_id=db_file.parent_id,
owner=current_user,
is_deleted=False,
) )
if existing_file and existing_file.id != file_id: if existing_file and existing_file.id != file_id:
raise HTTPException( raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in the same folder")
status_code=status.HTTP_409_CONFLICT,
detail="File with this name already exists in the same folder",
)
db_file.name = rename_data.new_name db_file.name = rename_data.new_name
await db_file.save() await db_file.save()
await log_activity( await log_activity(user=current_user, action="file_renamed", target_type="file", target_id=file_id)
user=current_user, action="file_renamed", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.post("/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED)
@router.post( async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)):
"/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED
)
async def copy_file(
file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
target_folder = None target_folder = None
if copy_data.target_folder_id: if copy_data.target_folder_id:
target_folder = await Folder.get_or_none( target_folder = await Folder.get_or_none(id=copy_data.target_folder_id, owner=current_user, is_deleted=False)
id=copy_data.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
)
base_name = db_file.name base_name = db_file.name
name_parts = os.path.splitext(base_name) name_parts = os.path.splitext(base_name)
counter = 1 counter = 1
new_name = base_name new_name = base_name
while await File.get_or_none( while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False):
name=new_name, parent=target_folder, owner=current_user, is_deleted=False
):
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}" new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
counter += 1 counter += 1
@ -282,205 +213,139 @@ async def copy_file(
mime_type=db_file.mime_type, mime_type=db_file.mime_type,
file_hash=db_file.file_hash, file_hash=db_file.file_hash,
owner=current_user, owner=current_user,
parent=target_folder, parent=target_folder
) )
await log_activity( await log_activity(user=current_user, action="file_copied", target_type="file", target_id=new_file.id)
user=current_user,
action="file_copied",
target_type="file",
target_id=new_file.id,
)
return await FileOut.from_tortoise_orm(new_file) return await FileOut.from_tortoise_orm(new_file)
@router.get("/", response_model=List[FileOut]) @router.get("/", response_model=List[FileOut])
async def list_files( async def list_files(folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)):
folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)
):
if folder_id: if folder_id:
parent_folder = await Folder.get_or_none( parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
id=folder_id, owner=current_user, is_deleted=False
)
if not parent_folder: if not parent_folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found" files = await File.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name")
)
files = await File.filter(
parent=parent_folder, owner=current_user, is_deleted=False
).order_by("name")
else: else:
files = await File.filter( files = await File.filter(parent=None, owner=current_user, is_deleted=False).order_by("name")
parent=None, owner=current_user, is_deleted=False
).order_by("name")
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/thumbnail/{file_id}") @router.get("/thumbnail/{file_id}")
async def get_thumbnail(file_id: int, current_user: User = Depends(get_current_user)): async def get_thumbnail(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.last_accessed_at = datetime.now() db_file.last_accessed_at = datetime.now()
await db_file.save() await db_file.save()
thumbnail_path = getattr(db_file, "thumbnail_path", None) thumbnail_path = getattr(db_file, 'thumbnail_path', None)
if not thumbnail_path: if not thumbnail_path:
thumbnail_path = await generate_thumbnail( thumbnail_path = await generate_thumbnail(db_file.path, db_file.mime_type, current_user.id)
db_file.path, db_file.mime_type, current_user.id
)
if thumbnail_path: if thumbnail_path:
db_file.thumbnail_path = thumbnail_path db_file.thumbnail_path = thumbnail_path
await db_file.save() await db_file.save()
else: else:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available")
status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available"
)
try: try:
async def thumbnail_iterator(): async def thumbnail_iterator():
async for chunk in storage_manager.get_file( async for chunk in storage_manager.get_file(current_user.id, thumbnail_path):
current_user.id, thumbnail_path
):
yield chunk yield chunk
return StreamingResponse(thumbnail_iterator(), media_type="image/jpeg") return StreamingResponse(
except FileNotFoundError: thumbnail_iterator(),
raise HTTPException( media_type="image/jpeg"
status_code=status.HTTP_404_NOT_FOUND,
detail="Thumbnail not found in storage",
) )
except FileNotFoundError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not found in storage")
@router.get("/photos", response_model=List[FileOut]) @router.get("/photos", response_model=List[FileOut])
async def list_photos(current_user: User = Depends(get_current_user)): async def list_photos(current_user: User = Depends(get_current_user)):
files = await File.filter( files = await File.filter(
owner=current_user, is_deleted=False, mime_type__istartswith="image/" owner=current_user,
is_deleted=False,
mime_type__istartswith="image/"
).order_by("-created_at") ).order_by("-created_at")
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/recent", response_model=List[FileOut]) @router.get("/recent", response_model=List[FileOut])
async def list_recent_files( async def list_recent_files(current_user: User = Depends(get_current_user), limit: int = 10):
current_user: User = Depends(get_current_user), limit: int = 10 files = await File.filter(
): owner=current_user,
files = ( is_deleted=False,
await File.filter( last_accessed_at__isnull=False
owner=current_user, is_deleted=False, last_accessed_at__isnull=False ).order_by("-last_accessed_at").limit(limit)
)
.order_by("-last_accessed_at")
.limit(limit)
)
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.post("/{file_id}/star", response_model=FileOut) @router.post("/{file_id}/star", response_model=FileOut)
async def star_file(file_id: int, current_user: User = Depends(get_current_user)): async def star_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.is_starred = True db_file.is_starred = True
await db_file.save() await db_file.save()
await log_activity( await log_activity(user=current_user, action="file_starred", target_type="file", target_id=file_id)
user=current_user, action="file_starred", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.post("/{file_id}/unstar", response_model=FileOut) @router.post("/{file_id}/unstar", response_model=FileOut)
async def unstar_file(file_id: int, current_user: User = Depends(get_current_user)): async def unstar_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.is_starred = False db_file.is_starred = False
await db_file.save() await db_file.save()
await log_activity( await log_activity(user=current_user, action="file_unstarred", target_type="file", target_id=file_id)
user=current_user,
action="file_unstarred",
target_type="file",
target_id=file_id,
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
@router.get("/deleted", response_model=List[FileOut]) @router.get("/deleted", response_model=List[FileOut])
async def list_deleted_files(current_user: User = Depends(get_current_user)): async def list_deleted_files(current_user: User = Depends(get_current_user)):
files = await File.filter(owner=current_user, is_deleted=True).order_by( files = await File.filter(owner=current_user, is_deleted=True).order_by("-deleted_at")
"-deleted_at"
)
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.post("/{file_id}/restore", response_model=FileOut) @router.post("/{file_id}/restore", response_model=FileOut)
async def restore_file(file_id: int, current_user: User = Depends(get_current_user)): async def restore_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=True) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=True)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found"
)
# Check if a file with the same name exists in the parent folder # Check if a file with the same name exists in the parent folder
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=db_file.name, parent=db_file.parent, owner=current_user, is_deleted=False name=db_file.name, parent=db_file.parent, owner=current_user, is_deleted=False
) )
if existing_file: if existing_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.")
status_code=status.HTTP_409_CONFLICT,
detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.",
)
db_file.is_deleted = False db_file.is_deleted = False
db_file.deleted_at = None db_file.deleted_at = None
await db_file.save() await db_file.save()
await log_activity( await log_activity(user=current_user, action="file_restored", target_type="file", target_id=file_id)
user=current_user, action="file_restored", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)
class BatchOperationResult(BaseModel): class BatchOperationResult(BaseModel):
succeeded: List[FileOut] succeeded: List[FileOut]
failed: List[dict] failed: List[dict]
@router.post("/batch") @router.post("/batch")
async def batch_file_operations( async def batch_file_operations(
batch_operation: BatchFileOperation, batch_operation: BatchFileOperation,
payload: Optional[BatchMoveCopyPayload] = None, payload: Optional[BatchMoveCopyPayload] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user)
): ):
if batch_operation.operation not in ["delete", "star", "unstar", "move", "copy"]: if batch_operation.operation not in ["delete", "star", "unstar", "move", "copy"]:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid operation: {batch_operation.operation}")
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid operation: {batch_operation.operation}",
)
updated_files = [] updated_files = []
failed_operations = [] failed_operations = []
for file_id in batch_operation.file_ids: for file_id in batch_operation.file_ids:
try: try:
db_file = await File.get_or_none( db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
id=file_id, owner=current_user, is_deleted=False
)
if not db_file: if not db_file:
failed_operations.append( failed_operations.append({"file_id": file_id, "reason": "File not found or not owned by user"})
{
"file_id": file_id,
"reason": "File not found or not owned by user",
}
)
continue continue
if batch_operation.operation == "delete": if batch_operation.operation == "delete":
@ -488,87 +353,47 @@ async def batch_file_operations(
db_file.deleted_at = datetime.now() db_file.deleted_at = datetime.now()
await db_file.save() await db_file.save()
await delete_thumbnail(db_file.id) await delete_thumbnail(db_file.id)
await log_activity( await log_activity(user=current_user, action="file_deleted_batch", target_type="file", target_id=file_id)
user=current_user,
action="file_deleted_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file) updated_files.append(db_file)
elif batch_operation.operation == "star": elif batch_operation.operation == "star":
db_file.is_starred = True db_file.is_starred = True
await db_file.save() await db_file.save()
await log_activity( await log_activity(user=current_user, action="file_starred_batch", target_type="file", target_id=file_id)
user=current_user,
action="file_starred_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file) updated_files.append(db_file)
elif batch_operation.operation == "unstar": elif batch_operation.operation == "unstar":
db_file.is_starred = False db_file.is_starred = False
await db_file.save() await db_file.save()
await log_activity( await log_activity(user=current_user, action="file_unstarred_batch", target_type="file", target_id=file_id)
user=current_user,
action="file_unstarred_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file) updated_files.append(db_file)
elif batch_operation.operation == "move": elif batch_operation.operation == "move":
if not payload or payload.target_folder_id is None: if not payload or payload.target_folder_id is None:
failed_operations.append( failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"})
{"file_id": file_id, "reason": "Target folder not specified"}
)
continue continue
target_folder = await Folder.get_or_none( target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
id=payload.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
failed_operations.append( failed_operations.append({"file_id": file_id, "reason": "Target folder not found"})
{"file_id": file_id, "reason": "Target folder not found"}
)
continue continue
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=db_file.name, name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False
parent=target_folder,
owner=current_user,
is_deleted=False,
) )
if existing_file and existing_file.id != file_id: if existing_file and existing_file.id != file_id:
failed_operations.append( failed_operations.append({"file_id": file_id, "reason": "File with same name exists in target folder"})
{
"file_id": file_id,
"reason": "File with same name exists in target folder",
}
)
continue continue
db_file.parent = target_folder db_file.parent = target_folder
await db_file.save() await db_file.save()
await log_activity( await log_activity(user=current_user, action="file_moved_batch", target_type="file", target_id=file_id)
user=current_user,
action="file_moved_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file) updated_files.append(db_file)
elif batch_operation.operation == "copy": elif batch_operation.operation == "copy":
if not payload or payload.target_folder_id is None: if not payload or payload.target_folder_id is None:
failed_operations.append( failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"})
{"file_id": file_id, "reason": "Target folder not specified"}
)
continue continue
target_folder = await Folder.get_or_none( target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
id=payload.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
failed_operations.append( failed_operations.append({"file_id": file_id, "reason": "Target folder not found"})
{"file_id": file_id, "reason": "Target folder not found"}
)
continue continue
base_name = db_file.name base_name = db_file.name
@ -576,12 +401,7 @@ async def batch_file_operations(
counter = 1 counter = 1
new_name = base_name new_name = base_name
while await File.get_or_none( while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False):
name=new_name,
parent=target_folder,
owner=current_user,
is_deleted=False,
):
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}" new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
counter += 1 counter += 1
@ -592,70 +412,48 @@ async def batch_file_operations(
mime_type=db_file.mime_type, mime_type=db_file.mime_type,
file_hash=db_file.file_hash, file_hash=db_file.file_hash,
owner=current_user, owner=current_user,
parent=target_folder, parent=target_folder
)
await log_activity(
user=current_user,
action="file_copied_batch",
target_type="file",
target_id=new_file.id,
) )
await log_activity(user=current_user, action="file_copied_batch", target_type="file", target_id=new_file.id)
updated_files.append(new_file) updated_files.append(new_file)
except Exception as e: except Exception as e:
failed_operations.append({"file_id": file_id, "reason": str(e)}) failed_operations.append({"file_id": file_id, "reason": str(e)})
return { return {
"succeeded": [await FileOut.from_tortoise_orm(f) for f in updated_files], "succeeded": [await FileOut.from_tortoise_orm(f) for f in updated_files],
"failed": failed_operations, "failed": failed_operations
} }
@router.put("/{file_id}/content", response_model=FileOut) @router.put("/{file_id}/content", response_model=FileOut)
async def update_file_content( async def update_file_content(
file_id: int, file_id: int,
payload: FileContentUpdate, payload: FileContentUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user)
): ):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False) db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file: if not db_file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
if not db_file.mime_type or not db_file.mime_type.startswith("text/"): if not db_file.mime_type or not db_file.mime_type.startswith('text/'):
editableExtensions = [ editableExtensions = [
"txt", 'txt', 'md', 'log', 'json', 'js', 'py', 'html', 'css',
"md", 'xml', 'yaml', 'yml', 'sh', 'bat', 'ini', 'conf', 'cfg'
"log",
"json",
"js",
"py",
"html",
"css",
"xml",
"yaml",
"yml",
"sh",
"bat",
"ini",
"conf",
"cfg",
] ]
file_extension = os.path.splitext(db_file.name)[1][1:].lower() file_extension = os.path.splitext(db_file.name)[1][1:].lower()
if file_extension not in editableExtensions: if file_extension not in editableExtensions:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="File type is not editable", detail="File type is not editable"
) )
content_bytes = payload.content.encode("utf-8") content_bytes = payload.content.encode('utf-8')
new_size = len(content_bytes) new_size = len(content_bytes)
size_diff = new_size - db_file.size size_diff = new_size - db_file.size
if current_user.used_storage_bytes + size_diff > current_user.storage_quota_bytes: if current_user.used_storage_bytes + size_diff > current_user.storage_quota_bytes:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_507_INSUFFICIENT_STORAGE, status_code=status.HTTP_507_INSUFFICIENT_STORAGE,
detail="Storage quota exceeded", detail="Storage quota exceeded"
) )
new_hash = hashlib.sha256(content_bytes).hexdigest() new_hash = hashlib.sha256(content_bytes).hexdigest()
@ -667,7 +465,7 @@ async def update_file_content(
if new_storage_path != db_file.path: if new_storage_path != db_file.path:
try: try:
await storage_manager.delete_file(current_user.id, db_file.path) await storage_manager.delete_file(current_user.id, db_file.path)
except Exception: except:
pass pass
db_file.path = new_storage_path db_file.path = new_storage_path
@ -679,8 +477,6 @@ async def update_file_content(
current_user.used_storage_bytes += size_diff current_user.used_storage_bytes += size_diff
await current_user.save() await current_user.save()
await log_activity( await log_activity(user=current_user, action="file_updated", target_type="file", target_id=file_id)
user=current_user, action="file_updated", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file) return await FileOut.from_tortoise_orm(db_file)

View File

@ -3,13 +3,7 @@ from typing import List, Optional
from ..auth import get_current_user from ..auth import get_current_user
from ..models import User, Folder from ..models import User, Folder
from ..schemas import ( from ..schemas import FolderCreate, FolderOut, FolderUpdate, BatchFolderOperation, BatchMoveCopyPayload
FolderCreate,
FolderOut,
FolderUpdate,
BatchFolderOperation,
BatchMoveCopyPayload,
)
from ..activity import log_activity from ..activity import log_activity
router = APIRouter( router = APIRouter(
@ -17,17 +11,12 @@ router = APIRouter(
tags=["folders"], tags=["folders"],
) )
@router.post("/", response_model=FolderOut, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=FolderOut, status_code=status.HTTP_201_CREATED)
async def create_folder( async def create_folder(folder_in: FolderCreate, current_user: User = Depends(get_current_user)):
folder_in: FolderCreate, current_user: User = Depends(get_current_user)
):
# Check if parent folder exists and belongs to the current user # Check if parent folder exists and belongs to the current user
parent_folder = None parent_folder = None
if folder_in.parent_id: if folder_in.parent_id:
parent_folder = await Folder.get_or_none( parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user)
id=folder_in.parent_id, owner=current_user
)
if not parent_folder: if not parent_folder:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -50,82 +39,50 @@ async def create_folder(
await log_activity(current_user, "folder_created", "folder", folder.id) await log_activity(current_user, "folder_created", "folder", folder.id)
return await FolderOut.from_tortoise_orm(folder) return await FolderOut.from_tortoise_orm(folder)
@router.get("/{folder_id}/path", response_model=List[FolderOut]) @router.get("/{folder_id}/path", response_model=List[FolderOut])
async def get_folder_path( async def get_folder_path(folder_id: int, current_user: User = Depends(get_current_user)):
folder_id: int, current_user: User = Depends(get_current_user) folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
):
folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
path = [] path = []
current = folder current = folder
while current: while current:
path.insert(0, await FolderOut.from_tortoise_orm(current)) path.insert(0, await FolderOut.from_tortoise_orm(current))
if current.parent_id: if current.parent_id:
current = await Folder.get_or_none( current = await Folder.get_or_none(id=current.parent_id, owner=current_user, is_deleted=False)
id=current.parent_id, owner=current_user, is_deleted=False
)
else: else:
current = None current = None
return path return path
@router.get("/{folder_id}", response_model=FolderOut) @router.get("/{folder_id}", response_model=FolderOut)
async def get_folder(folder_id: int, current_user: User = Depends(get_current_user)): async def get_folder(folder_id: int, current_user: User = Depends(get_current_user)):
folder = await Folder.get_or_none( folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
id=folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
return await FolderOut.from_tortoise_orm(folder) return await FolderOut.from_tortoise_orm(folder)
@router.get("/", response_model=List[FolderOut]) @router.get("/", response_model=List[FolderOut])
async def list_folders( async def list_folders(parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)):
parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)
):
if parent_id: if parent_id:
parent_folder = await Folder.get_or_none( parent_folder = await Folder.get_or_none(id=parent_id, owner=current_user, is_deleted=False)
id=parent_id, owner=current_user, is_deleted=False
)
if not parent_folder: if not parent_folder:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="Parent folder not found or does not belong to the current user", detail="Parent folder not found or does not belong to the current user",
) )
folders = await Folder.filter( folders = await Folder.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name")
parent=parent_folder, owner=current_user, is_deleted=False
).order_by("name")
else: else:
# List root folders (folders with no parent) # List root folders (folders with no parent)
folders = await Folder.filter( folders = await Folder.filter(parent=None, owner=current_user, is_deleted=False).order_by("name")
parent=None, owner=current_user, is_deleted=False
).order_by("name")
return [await FolderOut.from_tortoise_orm(folder) for folder in folders] return [await FolderOut.from_tortoise_orm(folder) for folder in folders]
@router.put("/{folder_id}", response_model=FolderOut) @router.put("/{folder_id}", response_model=FolderOut)
async def update_folder( async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: User = Depends(get_current_user)):
folder_id: int, folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
folder_in: FolderUpdate,
current_user: User = Depends(get_current_user),
):
folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
if folder_in.name: if folder_in.name:
existing_folder = await Folder.get_or_none( existing_folder = await Folder.get_or_none(
@ -140,16 +97,11 @@ async def update_folder(
if folder_in.parent_id is not None: if folder_in.parent_id is not None:
if folder_in.parent_id == folder_id: if folder_in.parent_id == folder_id:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot set folder as its own parent")
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot set folder as its own parent",
)
new_parent_folder = None new_parent_folder = None
if folder_in.parent_id != 0: # 0 could represent moving to root if folder_in.parent_id != 0: # 0 could represent moving to root
new_parent_folder = await Folder.get_or_none( new_parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user, is_deleted=False)
id=folder_in.parent_id, owner=current_user, is_deleted=False
)
if not new_parent_folder: if not new_parent_folder:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -161,66 +113,48 @@ async def update_folder(
await log_activity(current_user, "folder_updated", "folder", folder.id) await log_activity(current_user, "folder_updated", "folder", folder.id)
return await FolderOut.from_tortoise_orm(folder) return await FolderOut.from_tortoise_orm(folder)
@router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_folder(folder_id: int, current_user: User = Depends(get_current_user)): async def delete_folder(folder_id: int, current_user: User = Depends(get_current_user)):
folder = await Folder.get_or_none( folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
id=folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
folder.is_deleted = True folder.is_deleted = True
await folder.save() await folder.save()
await log_activity(current_user, "folder_deleted", "folder", folder.id) await log_activity(current_user, "folder_deleted", "folder", folder.id)
return return
@router.post("/{folder_id}/star", response_model=FolderOut) @router.post("/{folder_id}/star", response_model=FolderOut)
async def star_folder(folder_id: int, current_user: User = Depends(get_current_user)): async def star_folder(folder_id: int, current_user: User = Depends(get_current_user)):
db_folder = await Folder.get_or_none( db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
id=folder_id, owner=current_user, is_deleted=False
)
if not db_folder: if not db_folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
db_folder.is_starred = True db_folder.is_starred = True
await db_folder.save() await db_folder.save()
await log_activity(current_user, "folder_starred", "folder", folder_id) await log_activity(current_user, "folder_starred", "folder", folder_id)
return await FolderOut.from_tortoise_orm(db_folder) return await FolderOut.from_tortoise_orm(db_folder)
@router.post("/{folder_id}/unstar", response_model=FolderOut) @router.post("/{folder_id}/unstar", response_model=FolderOut)
async def unstar_folder(folder_id: int, current_user: User = Depends(get_current_user)): async def unstar_folder(folder_id: int, current_user: User = Depends(get_current_user)):
db_folder = await Folder.get_or_none( db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
id=folder_id, owner=current_user, is_deleted=False
)
if not db_folder: if not db_folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
db_folder.is_starred = False db_folder.is_starred = False
await db_folder.save() await db_folder.save()
await log_activity(current_user, "folder_unstarred", "folder", folder_id) await log_activity(current_user, "folder_unstarred", "folder", folder_id)
return await FolderOut.from_tortoise_orm(db_folder) return await FolderOut.from_tortoise_orm(db_folder)
@router.post("/batch", response_model=List[FolderOut]) @router.post("/batch", response_model=List[FolderOut])
async def batch_folder_operations( async def batch_folder_operations(
batch_operation: BatchFolderOperation, batch_operation: BatchFolderOperation,
payload: Optional[BatchMoveCopyPayload] = None, payload: Optional[BatchMoveCopyPayload] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user)
): ):
updated_folders = [] updated_folders = []
for folder_id in batch_operation.folder_ids: for folder_id in batch_operation.folder_ids:
db_folder = await Folder.get_or_none( db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
id=folder_id, owner=current_user, is_deleted=False
)
if not db_folder: if not db_folder:
continue # Skip if folder not found or not owned by user continue # Skip if folder not found or not owned by user
if batch_operation.operation == "delete": if batch_operation.operation == "delete":
db_folder.is_deleted = True db_folder.is_deleted = True
@ -237,22 +171,13 @@ async def batch_folder_operations(
await db_folder.save() await db_folder.save()
await log_activity(current_user, "folder_unstarred", "folder", folder_id) await log_activity(current_user, "folder_unstarred", "folder", folder_id)
updated_folders.append(db_folder) updated_folders.append(db_folder)
elif ( elif batch_operation.operation == "move" and payload and payload.target_folder_id is not None:
batch_operation.operation == "move" target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
and payload
and payload.target_folder_id is not None
):
target_folder = await Folder.get_or_none(
id=payload.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder: if not target_folder:
continue continue
existing_folder = await Folder.get_or_none( existing_folder = await Folder.get_or_none(
name=db_folder.name, name=db_folder.name, parent=target_folder, owner=current_user, is_deleted=False
parent=target_folder,
owner=current_user,
is_deleted=False,
) )
if existing_folder and existing_folder.id != folder_id: if existing_folder and existing_folder.id != folder_id:
continue continue
@ -261,5 +186,5 @@ async def batch_folder_operations(
await db_folder.save() await db_folder.save()
await log_activity(current_user, "folder_moved", "folder", folder_id) await log_activity(current_user, "folder_moved", "folder", folder_id)
updated_folders.append(db_folder) updated_folders.append(db_folder)
return [await FolderOut.from_tortoise_orm(f) for f in updated_folders] return [await FolderOut.from_tortoise_orm(f) for f in updated_folders]

View File

@ -11,22 +11,15 @@ router = APIRouter(
tags=["search"], tags=["search"],
) )
@router.get("/files", response_model=List[FileOut]) @router.get("/files", response_model=List[FileOut])
async def search_files( async def search_files(
q: str = Query(..., min_length=1, description="Search query"), q: str = Query(..., min_length=1, description="Search query"),
file_type: Optional[str] = Query( file_type: Optional[str] = Query(None, description="Filter by MIME type prefix (e.g., 'image', 'video')"),
None, description="Filter by MIME type prefix (e.g., 'image', 'video')"
),
min_size: Optional[int] = Query(None, description="Minimum file size in bytes"), min_size: Optional[int] = Query(None, description="Minimum file size in bytes"),
max_size: Optional[int] = Query(None, description="Maximum file size in bytes"), max_size: Optional[int] = Query(None, description="Maximum file size in bytes"),
date_from: Optional[datetime] = Query( date_from: Optional[datetime] = Query(None, description="Filter files created after this date"),
None, description="Filter files created after this date" date_to: Optional[datetime] = Query(None, description="Filter files created before this date"),
), current_user: User = Depends(get_current_user)
date_to: Optional[datetime] = Query(
None, description="Filter files created before this date"
),
current_user: User = Depends(get_current_user),
): ):
query = File.filter(owner=current_user, is_deleted=False, name__icontains=q) query = File.filter(owner=current_user, is_deleted=False, name__icontains=q)
@ -48,16 +41,15 @@ async def search_files(
files = await query.order_by("-created_at").limit(100) files = await query.order_by("-created_at").limit(100)
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/folders", response_model=List[FolderOut]) @router.get("/folders", response_model=List[FolderOut])
async def search_folders( async def search_folders(
q: str = Query(..., min_length=1, description="Search query"), q: str = Query(..., min_length=1, description="Search query"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user)
): ):
folders = ( folders = await Folder.filter(
await Folder.filter(owner=current_user, is_deleted=False, name__icontains=q) owner=current_user,
.order_by("-created_at") is_deleted=False,
.limit(100) name__icontains=q
) ).order_by("-created_at").limit(100)
return [await FolderOut.from_tortoise_orm(folder) for folder in folders] return [await FolderOut.from_tortoise_orm(folder) for folder in folders]

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from typing import Optional, List from typing import Optional, List
import secrets import secrets
from datetime import datetime from datetime import datetime, timedelta
from ..auth import get_current_user from ..auth import get_current_user
from ..models import User, File, Folder, Share from ..models import User, File, Folder, Share
@ -17,44 +17,25 @@ router = APIRouter(
tags=["shares"], tags=["shares"],
) )
@router.post("/", response_model=ShareOut, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=ShareOut, status_code=status.HTTP_201_CREATED)
async def create_share_link( async def create_share_link(share_in: ShareCreate, current_user: User = Depends(get_current_user)):
share_in: ShareCreate, current_user: User = Depends(get_current_user)
):
if not share_in.file_id and not share_in.folder_id: if not share_in.file_id and not share_in.folder_id:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Either file_id or folder_id must be provided")
status_code=status.HTTP_400_BAD_REQUEST,
detail="Either file_id or folder_id must be provided",
)
if share_in.file_id and share_in.folder_id: if share_in.file_id and share_in.folder_id:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot share both a file and a folder simultaneously")
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot share both a file and a folder simultaneously",
)
file = None file = None
folder = None folder = None
if share_in.file_id: if share_in.file_id:
file = await File.get_or_none( file = await File.get_or_none(id=share_in.file_id, owner=current_user, is_deleted=False)
id=share_in.file_id, owner=current_user, is_deleted=False
)
if not file: if not file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found or does not belong to you")
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found or does not belong to you",
)
if share_in.folder_id: if share_in.folder_id:
folder = await Folder.get_or_none( folder = await Folder.get_or_none(id=share_in.folder_id, owner=current_user, is_deleted=False)
id=share_in.folder_id, owner=current_user, is_deleted=False
)
if not folder: if not folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found or does not belong to you")
status_code=status.HTTP_404_NOT_FOUND,
detail="Folder not found or does not belong to you",
)
token = secrets.token_urlsafe(16) token = secrets.token_urlsafe(16)
hashed_password = None hashed_password = None
@ -79,14 +60,8 @@ async def create_share_link(
item_type = "file" if file else "folder" item_type = "file" if file else "folder"
item_name = file.name if file else folder.name item_name = file.name if file else folder.name
expiry_text = ( expiry_text = f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}" if share_in.expires_at else ""
f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}" password_text = f"\n\nPassword: {share_in.password}" if share_in.password else ""
if share_in.expires_at
else ""
)
password_text = (
f"\n\nPassword: {share_in.password}" if share_in.password else ""
)
email_body = f"""Hello, email_body = f"""Hello,
@ -120,103 +95,80 @@ MyWebdav File Sharing Service"""
to_email=share_in.invite_email, to_email=share_in.invite_email,
subject=f"{current_user.username} shared {item_name} with you", subject=f"{current_user.username} shared {item_name} with you",
body=email_body, body=email_body,
html=email_html, html=email_html
) )
except Exception as e: except Exception as e:
print(f"Failed to send invitation email: {e}") print(f"Failed to send invitation email: {e}")
return await ShareOut.from_tortoise_orm(share) return await ShareOut.from_tortoise_orm(share)
@router.get("/my", response_model=List[ShareOut]) @router.get("/my", response_model=List[ShareOut])
async def list_my_shares(current_user: User = Depends(get_current_user)): async def list_my_shares(current_user: User = Depends(get_current_user)):
shares = await Share.filter(owner=current_user).order_by("-created_at") shares = await Share.filter(owner=current_user).order_by("-created_at")
return [await ShareOut.from_tortoise_orm(share) for share in shares] return [await ShareOut.from_tortoise_orm(share) for share in shares]
@router.get("/{share_token}", response_model=ShareOut) @router.get("/{share_token}", response_model=ShareOut)
async def get_share_link_info(share_token: str): async def get_share_link_info(share_token: str):
share = await Share.get_or_none(token=share_token) share = await Share.get_or_none(token=share_token)
if not share: if not share:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
)
if share.expires_at and share.expires_at < datetime.utcnow(): if share.expires_at and share.expires_at < datetime.utcnow():
raise HTTPException( raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
status_code=status.HTTP_410_GONE, detail="Share link has expired"
)
# Increment access count # Increment access count
share.access_count += 1 share.access_count += 1
await share.save() await share.save()
return await ShareOut.from_tortoise_orm(share) return await ShareOut.from_tortoise_orm(share)
@router.put("/{share_id}", response_model=ShareOut) @router.put("/{share_id}", response_model=ShareOut)
async def update_share( async def update_share(share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)):
share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)
):
share = await Share.get_or_none(id=share_id, owner=current_user) share = await Share.get_or_none(id=share_id, owner=current_user)
if not share: if not share:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you")
status_code=status.HTTP_404_NOT_FOUND,
detail="Share link not found or does not belong to you",
)
if share_in.expires_at is not None: if share_in.expires_at is not None:
share.expires_at = share_in.expires_at share.expires_at = share_in.expires_at
if share_in.password is not None: if share_in.password is not None:
share.hashed_password = get_password_hash(share_in.password) share.hashed_password = get_password_hash(share_in.password)
share.password_protected = True share.password_protected = True
elif share_in.password == "": # Allow clearing password elif share_in.password == "": # Allow clearing password
share.hashed_password = None share.hashed_password = None
share.password_protected = False share.password_protected = False
if share_in.permission_level is not None: if share_in.permission_level is not None:
share.permission_level = share_in.permission_level share.permission_level = share_in.permission_level
await share.save() await share.save()
return await ShareOut.from_tortoise_orm(share) return await ShareOut.from_tortoise_orm(share)
@router.post("/{share_token}/access") @router.post("/{share_token}/access")
async def access_shared_content(share_token: str, password: Optional[str] = None): async def access_shared_content(share_token: str, password: Optional[str] = None):
share = await Share.get_or_none(token=share_token) share = await Share.get_or_none(token=share_token)
if not share: if not share:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
)
if share.expires_at and share.expires_at < datetime.utcnow(): if share.expires_at and share.expires_at < datetime.utcnow():
raise HTTPException( raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
status_code=status.HTTP_410_GONE, detail="Share link has expired"
)
if share.password_protected: if share.password_protected:
if not password or not verify_password(password, share.hashed_password): if not password or not verify_password(password, share.hashed_password):
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password")
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
)
result = {"message": "Access granted", "permission_level": share.permission_level} result = {"message": "Access granted", "permission_level": share.permission_level}
if share.file_id: if share.file_id:
file = await File.get_or_none(id=share.file_id, is_deleted=False) file = await File.get_or_none(id=share.file_id, is_deleted=False)
if not file: if not file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
result["file"] = await FileOut.from_tortoise_orm(file) result["file"] = await FileOut.from_tortoise_orm(file)
result["type"] = "file" result["type"] = "file"
elif share.folder_id: elif share.folder_id:
folder = await Folder.get_or_none(id=share.folder_id, is_deleted=False) folder = await Folder.get_or_none(id=share.folder_id, is_deleted=False)
if not folder: if not folder:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
result["folder"] = await FolderOut.from_tortoise_orm(folder) result["folder"] = await FolderOut.from_tortoise_orm(folder)
result["type"] = "folder" result["type"] = "folder"
@ -227,42 +179,29 @@ async def access_shared_content(share_token: str, password: Optional[str] = None
return result return result
@router.get("/{share_token}/download") @router.get("/{share_token}/download")
async def download_shared_file(share_token: str, password: Optional[str] = None): async def download_shared_file(share_token: str, password: Optional[str] = None):
share = await Share.get_or_none(token=share_token) share = await Share.get_or_none(token=share_token)
if not share: if not share:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
)
if share.expires_at and share.expires_at < datetime.utcnow(): if share.expires_at and share.expires_at < datetime.utcnow():
raise HTTPException( raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
status_code=status.HTTP_410_GONE, detail="Share link has expired"
)
if share.password_protected: if share.password_protected:
if not password or not verify_password(password, share.hashed_password): if not password or not verify_password(password, share.hashed_password):
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password")
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
)
if not share.file_id: if not share.file_id:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="This share is not for a file")
status_code=status.HTTP_400_BAD_REQUEST,
detail="This share is not for a file",
)
file = await File.get_or_none(id=share.file_id, is_deleted=False) file = await File.get_or_none(id=share.file_id, is_deleted=False)
if not file: if not file:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
owner = await User.get(id=file.owner_id) owner = await User.get(id=file.owner_id)
try: try:
async def file_iterator(): async def file_iterator():
async for chunk in storage_manager.get_file(owner.id, file.path): async for chunk in storage_manager.get_file(owner.id, file.path):
yield chunk yield chunk
@ -270,22 +209,18 @@ async def download_shared_file(share_token: str, password: Optional[str] = None)
return StreamingResponse( return StreamingResponse(
content=file_iterator(), content=file_iterator(),
media_type=file.mime_type, media_type=file.mime_type,
headers={"Content-Disposition": f'attachment; filename="{file.name}"'}, headers={
"Content-Disposition": f'attachment; filename="{file.name}"'
}
) )
except FileNotFoundError: except FileNotFoundError:
raise HTTPException(status_code=404, detail="File not found in storage") raise HTTPException(status_code=404, detail="File not found in storage")
@router.delete("/{share_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{share_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_share_link( async def delete_share_link(share_id: int, current_user: User = Depends(get_current_user)):
share_id: int, current_user: User = Depends(get_current_user)
):
share = await Share.get_or_none(id=share_id, owner=current_user) share = await Share.get_or_none(id=share_id, owner=current_user)
if not share: if not share:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you")
status_code=status.HTTP_404_NOT_FOUND,
detail="Share link not found or does not belong to you",
)
await share.delete() await share.delete()
return return

View File

@ -11,29 +11,18 @@ router = APIRouter(
tags=["starred"], tags=["starred"],
) )
@router.get("/files", response_model=List[FileOut]) @router.get("/files", response_model=List[FileOut])
async def list_starred_files(current_user: User = Depends(get_current_user)): async def list_starred_files(current_user: User = Depends(get_current_user)):
files = await File.filter( files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
owner=current_user, is_starred=True, is_deleted=False
).order_by("name")
return [await FileOut.from_tortoise_orm(f) for f in files] return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/folders", response_model=List[FolderOut]) @router.get("/folders", response_model=List[FolderOut])
async def list_starred_folders(current_user: User = Depends(get_current_user)): async def list_starred_folders(current_user: User = Depends(get_current_user)):
folders = await Folder.filter( folders = await Folder.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
owner=current_user, is_starred=True, is_deleted=False
).order_by("name")
return [await FolderOut.from_tortoise_orm(f) for f in folders] return [await FolderOut.from_tortoise_orm(f) for f in folders]
@router.get("/all", response_model=List[FileOut]) # This will return files and folders as files for now
@router.get(
"/all", response_model=List[FileOut]
) # This will return files and folders as files for now
async def list_all_starred(current_user: User = Depends(get_current_user)): async def list_all_starred(current_user: User = Depends(get_current_user)):
starred_files = await File.filter( starred_files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
owner=current_user, is_starred=True, is_deleted=False
).order_by("name")
# For simplicity, we'll return files only for now. A more complex solution would involve a union or a custom schema. # For simplicity, we'll return files only for now. A more complex solution would involve a union or a custom schema.
return [await FileOut.from_tortoise_orm(f) for f in starred_files] return [await FileOut.from_tortoise_orm(f) for f in starred_files]

View File

@ -1,19 +1,17 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from ..auth import get_current_user from ..auth import get_current_user
from ..models import User_Pydantic, User, File, Folder from ..models import User_Pydantic, User, File, Folder
from typing import Dict, Any from typing import List, Dict, Any
router = APIRouter( router = APIRouter(
prefix="/users", prefix="/users",
tags=["users"], tags=["users"],
) )
@router.get("/me", response_model=User_Pydantic) @router.get("/me", response_model=User_Pydantic)
async def read_users_me(current_user: User = Depends(get_current_user)): async def read_users_me(current_user: User = Depends(get_current_user)):
return await User_Pydantic.from_tortoise_orm(current_user) return await User_Pydantic.from_tortoise_orm(current_user)
@router.get("/me/export", response_model=Dict[str, Any]) @router.get("/me/export", response_model=Dict[str, Any])
async def export_my_data(current_user: User = Depends(get_current_user)): async def export_my_data(current_user: User = Depends(get_current_user)):
""" """
@ -37,7 +35,6 @@ async def export_my_data(current_user: User = Depends(get_current_user)):
# share information, etc., would also be included. # share information, etc., would also be included.
} }
@router.delete("/me", status_code=204) @router.delete("/me", status_code=204)
async def delete_my_account(current_user: User = Depends(get_current_user)): async def delete_my_account(current_user: User = Depends(get_current_user)):
""" """
@ -51,3 +48,5 @@ async def delete_my_account(current_user: User = Depends(get_current_user)):
# Finally, delete the user account # Finally, delete the user account
await current_user.delete() await current_user.delete()
return {} return {}

View File

@ -1,25 +1,20 @@
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, EmailStr, ConfigDict from pydantic import BaseModel, EmailStr
from typing import Optional, List from typing import Optional, List
from tortoise.contrib.pydantic import pydantic_model_creator from tortoise.contrib.pydantic import pydantic_model_creator
from mywebdav.models import Folder, File, Share, FileVersion
class UserCreate(BaseModel): class UserCreate(BaseModel):
username: str username: str
email: EmailStr email: EmailStr
password: str password: str
class UserLogin(BaseModel): class UserLogin(BaseModel):
username: str username: str
password: str password: str
class UserLoginWith2FA(UserLogin): class UserLoginWith2FA(UserLogin):
two_factor_code: Optional[str] = None two_factor_code: Optional[str] = None
class UserAdminUpdate(BaseModel): class UserAdminUpdate(BaseModel):
username: Optional[str] = None username: Optional[str] = None
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
@ -30,27 +25,22 @@ class UserAdminUpdate(BaseModel):
plan_type: Optional[str] = None plan_type: Optional[str] = None
is_2fa_enabled: Optional[bool] = None is_2fa_enabled: Optional[bool] = None
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
class TokenData(BaseModel): class TokenData(BaseModel):
username: str | None = None username: str | None = None
two_factor_verified: bool = False two_factor_verified: bool = False
class FolderCreate(BaseModel): class FolderCreate(BaseModel):
name: str name: str
parent_id: Optional[int] = None parent_id: Optional[int] = None
class FolderUpdate(BaseModel): class FolderUpdate(BaseModel):
name: Optional[str] = None name: Optional[str] = None
parent_id: Optional[int] = None parent_id: Optional[int] = None
class ShareCreate(BaseModel): class ShareCreate(BaseModel):
file_id: Optional[int] = None file_id: Optional[int] = None
folder_id: Optional[int] = None folder_id: Optional[int] = None
@ -59,19 +49,17 @@ class ShareCreate(BaseModel):
permission_level: str = "viewer" permission_level: str = "viewer"
invite_email: Optional[EmailStr] = None invite_email: Optional[EmailStr] = None
class TeamCreate(BaseModel): class TeamCreate(BaseModel):
name: str name: str
class TeamOut(BaseModel): class TeamOut(BaseModel):
id: int id: int
name: str name: str
owner_id: int owner_id: int
created_at: datetime created_at: datetime
model_config = ConfigDict(from_attributes=True) class Config:
from_attributes = True
class ActivityOut(BaseModel): class ActivityOut(BaseModel):
id: int id: int
@ -82,8 +70,8 @@ class ActivityOut(BaseModel):
ip_address: Optional[str] = None ip_address: Optional[str] = None
timestamp: datetime timestamp: datetime
model_config = ConfigDict(from_attributes=True) class Config:
from_attributes = True
class FileRequestCreate(BaseModel): class FileRequestCreate(BaseModel):
title: str title: str
@ -91,7 +79,6 @@ class FileRequestCreate(BaseModel):
target_folder_id: int target_folder_id: int
expires_at: Optional[datetime] = None expires_at: Optional[datetime] = None
class FileRequestOut(BaseModel): class FileRequestOut(BaseModel):
id: int id: int
title: str title: str
@ -103,30 +90,28 @@ class FileRequestOut(BaseModel):
expires_at: Optional[datetime] = None expires_at: Optional[datetime] = None
is_active: bool is_active: bool
model_config = ConfigDict(from_attributes=True) class Config:
from_attributes = True
from mywebdav.models import Folder, File, Share, FileVersion
FolderOut = pydantic_model_creator(Folder, name="FolderOut") FolderOut = pydantic_model_creator(Folder, name="FolderOut")
FileOut = pydantic_model_creator(File, name="FileOut") FileOut = pydantic_model_creator(File, name="FileOut")
ShareOut = pydantic_model_creator(Share, name="ShareOut") ShareOut = pydantic_model_creator(Share, name="ShareOut")
FileVersionOut = pydantic_model_creator(FileVersion, name="FileVersionOut") FileVersionOut = pydantic_model_creator(FileVersion, name="FileVersionOut")
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
code: int code: int
message: str message: str
details: Optional[str] = None details: Optional[str] = None
class BatchFileOperation(BaseModel): class BatchFileOperation(BaseModel):
file_ids: List[int] file_ids: List[int]
operation: str # e.g., "delete", "move", "copy", "star", "unstar" operation: str # e.g., "delete", "move", "copy", "star", "unstar"
class BatchFolderOperation(BaseModel): class BatchFolderOperation(BaseModel):
folder_ids: List[int] folder_ids: List[int]
operation: str # e.g., "delete", "move", "star", "unstar" operation: str # e.g., "delete", "move", "star", "unstar"
class BatchMoveCopyPayload(BaseModel): class BatchMoveCopyPayload(BaseModel):
target_folder_id: Optional[int] = None target_folder_id: Optional[int] = None

View File

@ -2,9 +2,8 @@ import os
import sys import sys
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="ignore") model_config = SettingsConfigDict(env_file='.env', extra='ignore')
DATABASE_URL: str = "sqlite:///app/mywebdav.db" DATABASE_URL: str = "sqlite:///app/mywebdav.db"
REDIS_URL: str = "redis://redis:6379/0" REDIS_URL: str = "redis://redis:6379/0"
@ -31,14 +30,8 @@ class Settings(BaseSettings):
STRIPE_WEBHOOK_SECRET: str = "" STRIPE_WEBHOOK_SECRET: str = ""
BILLING_ENABLED: bool = False BILLING_ENABLED: bool = False
settings = Settings() settings = Settings()
if ( if settings.SECRET_KEY == "super_secret_key" and os.getenv("ENVIRONMENT") == "production":
settings.SECRET_KEY == "super_secret_key" print("ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable.")
and os.getenv("ENVIRONMENT") == "production"
):
print(
"ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable."
)
sys.exit(1) sys.exit(1)

View File

@ -5,7 +5,6 @@ from typing import AsyncGenerator
from .settings import settings from .settings import settings
class StorageManager: class StorageManager:
def __init__(self, base_path: str = settings.STORAGE_PATH): def __init__(self, base_path: str = settings.STORAGE_PATH):
self.base_path = Path(base_path) self.base_path = Path(base_path)
@ -13,11 +12,7 @@ class StorageManager:
async def _get_full_path(self, user_id: int, file_path: str) -> Path: async def _get_full_path(self, user_id: int, file_path: str) -> Path:
# Ensure file_path is relative and safe # Ensure file_path is relative and safe
relative_path = ( relative_path = Path(file_path).relative_to('/') if str(file_path).startswith('/') else Path(file_path)
Path(file_path).relative_to("/")
if str(file_path).startswith("/")
else Path(file_path)
)
full_path = self.base_path / str(user_id) / relative_path full_path = self.base_path / str(user_id) / relative_path
full_path.parent.mkdir(parents=True, exist_ok=True) full_path.parent.mkdir(parents=True, exist_ok=True)
return full_path return full_path
@ -32,7 +27,7 @@ class StorageManager:
full_path = await self._get_full_path(user_id, file_path) full_path = await self._get_full_path(user_id, file_path)
if not full_path.exists(): if not full_path.exists():
raise FileNotFoundError(f"File not found: {file_path}") raise FileNotFoundError(f"File not found: {file_path}")
async with aiofiles.open(full_path, "rb") as f: async with aiofiles.open(full_path, "rb") as f:
while chunk := await f.read(8192): while chunk := await f.read(8192):
yield chunk yield chunk
@ -57,5 +52,4 @@ class StorageManager:
full_path = await self._get_full_path(user_id, file_path) full_path = await self._get_full_path(user_id, file_path)
return full_path.exists() return full_path.exists()
storage_manager = StorageManager() storage_manager = StorageManager()

View File

@ -1,3 +1,4 @@
import os
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from PIL import Image from PIL import Image
@ -8,10 +9,7 @@ from .settings import settings
THUMBNAIL_SIZE = (300, 300) THUMBNAIL_SIZE = (300, 300)
THUMBNAIL_DIR = "thumbnails" THUMBNAIL_DIR = "thumbnails"
async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Optional[str]:
async def generate_thumbnail(
file_path: str, mime_type: str, user_id: int
) -> Optional[str]:
try: try:
if mime_type.startswith("image/"): if mime_type.startswith("image/"):
return await generate_image_thumbnail(file_path, user_id) return await generate_image_thumbnail(file_path, user_id)
@ -22,7 +20,6 @@ async def generate_thumbnail(
print(f"Error generating thumbnail for {file_path}: {e}") print(f"Error generating thumbnail for {file_path}: {e}")
return None return None
async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str]: async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str]:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -33,16 +30,12 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
file_name = Path(file_path).name file_name = Path(file_path).name
thumbnail_name = f"thumb_{file_name}" thumbnail_name = f"thumb_{file_name}"
if not thumbnail_name.lower().endswith((".jpg", ".jpeg", ".png")): if not thumbnail_name.lower().endswith(('.jpg', '.jpeg', '.png')):
thumbnail_name += ".jpg" thumbnail_name += ".jpg"
thumbnail_path = thumbnail_dir / thumbnail_name thumbnail_path = thumbnail_dir / thumbnail_name
actual_file_path = ( actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path)
base_path / str(user_id) / file_path
if not Path(file_path).is_absolute()
else Path(file_path)
)
with Image.open(actual_file_path) as img: with Image.open(actual_file_path) as img:
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS) img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
@ -51,9 +44,7 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
background = Image.new("RGB", img.size, (255, 255, 255)) background = Image.new("RGB", img.size, (255, 255, 255))
if img.mode == "P": if img.mode == "P":
img = img.convert("RGBA") img = img.convert("RGBA")
background.paste( background.paste(img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None)
img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None
)
img = background img = background
img.save(str(thumbnail_path), "JPEG", quality=85, optimize=True) img.save(str(thumbnail_path), "JPEG", quality=85, optimize=True)
@ -62,7 +53,6 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
return await loop.run_in_executor(None, _generate) return await loop.run_in_executor(None, _generate)
async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str]: async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str]:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -75,29 +65,17 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
thumbnail_name = f"thumb_{file_name}.jpg" thumbnail_name = f"thumb_{file_name}.jpg"
thumbnail_path = thumbnail_dir / thumbnail_name thumbnail_path = thumbnail_dir / thumbnail_name
actual_file_path = ( actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path)
base_path / str(user_id) / file_path
if not Path(file_path).is_absolute()
else Path(file_path)
)
subprocess.run( subprocess.run([
[ "ffmpeg",
"ffmpeg", "-i", str(actual_file_path),
"-i", "-ss", "00:00:01",
str(actual_file_path), "-vframes", "1",
"-ss", "-vf", f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease",
"00:00:01", "-y",
"-vframes", str(thumbnail_path)
"1", ], check=True, capture_output=True)
"-vf",
f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease",
"-y",
str(thumbnail_path),
],
check=True,
capture_output=True,
)
return str(thumbnail_path.relative_to(base_path / str(user_id))) return str(thumbnail_path.relative_to(base_path / str(user_id)))
@ -106,7 +84,6 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
return None return None
async def delete_thumbnail(thumbnail_path: str, user_id: int): async def delete_thumbnail(thumbnail_path: str, user_id: int):
try: try:
base_path = Path(settings.STORAGE_PATH) base_path = Path(settings.STORAGE_PATH)

View File

@ -5,20 +5,15 @@ import base64
import secrets import secrets
import hashlib import hashlib
from typing import List from typing import List, Optional
def generate_totp_secret() -> str: def generate_totp_secret() -> str:
"""Generates a random base32 TOTP secret.""" """Generates a random base32 TOTP secret."""
return pyotp.random_base32() return pyotp.random_base32()
def generate_totp_uri(secret: str, account_name: str, issuer_name: str) -> str: def generate_totp_uri(secret: str, account_name: str, issuer_name: str) -> str:
"""Generates a Google Authenticator-compatible TOTP URI.""" """Generates a Google Authenticator-compatible TOTP URI."""
return pyotp.totp.TOTP(secret).provisioning_uri( return pyotp.totp.TOTP(secret).provisioning_uri(name=account_name, issuer_name=issuer_name)
name=account_name, issuer_name=issuer_name
)
def generate_qr_code_base64(uri: str) -> str: def generate_qr_code_base64(uri: str) -> str:
"""Generates a base64 encoded QR code image for a given URI.""" """Generates a base64 encoded QR code image for a given URI."""
@ -36,33 +31,27 @@ def generate_qr_code_base64(uri: str) -> str:
img.save(buffered, format="PNG") img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8") return base64.b64encode(buffered.getvalue()).decode("utf-8")
def verify_totp_code(secret: str, code: str) -> bool: def verify_totp_code(secret: str, code: str) -> bool:
"""Verifies a TOTP code against a secret.""" """Verifies a TOTP code against a secret."""
totp = pyotp.TOTP(secret) totp = pyotp.TOTP(secret)
return totp.verify(code) return totp.verify(code)
def generate_recovery_codes(num_codes: int = 10) -> List[str]: def generate_recovery_codes(num_codes: int = 10) -> List[str]:
"""Generates a list of random recovery codes.""" """Generates a list of random recovery codes."""
return [secrets.token_urlsafe(16) for _ in range(num_codes)] return [secrets.token_urlsafe(16) for _ in range(num_codes)]
def hash_recovery_code(code: str) -> str: def hash_recovery_code(code: str) -> str:
"""Hashes a single recovery code using SHA256.""" """Hashes a single recovery code using SHA256."""
return hashlib.sha256(code.encode("utf-8")).hexdigest() return hashlib.sha256(code.encode('utf-8')).hexdigest()
def verify_recovery_code(plain_code: str, hashed_code: str) -> bool: def verify_recovery_code(plain_code: str, hashed_code: str) -> bool:
"""Verifies a plain recovery code against its hashed version.""" """Verifies a plain recovery code against its hashed version."""
return hash_recovery_code(plain_code) == hashed_code return hash_recovery_code(plain_code) == hashed_code
def hash_recovery_codes(codes: List[str]) -> List[str]: def hash_recovery_codes(codes: List[str]) -> List[str]:
"""Hashes a list of recovery codes.""" """Hashes a list of recovery codes."""
return [hash_recovery_code(code) for code in codes] return [hash_recovery_code(code) for code in codes]
def verify_recovery_codes(plain_code: str, hashed_codes: List[str]) -> bool: def verify_recovery_codes(plain_code: str, hashed_codes: List[str]) -> bool:
"""Verifies if a plain recovery code matches any of the hashed recovery codes.""" """Verifies if a plain recovery code matches any of the hashed recovery codes."""
for hashed_code in hashed_codes: for hashed_code in hashed_codes:

View File

@ -19,7 +19,6 @@ router = APIRouter(
tags=["webdav"], tags=["webdav"],
) )
class WebDAVLock: class WebDAVLock:
locks = {} locks = {}
@ -27,10 +26,10 @@ class WebDAVLock:
def create_lock(cls, path: str, user_id: int, timeout: int = 3600): def create_lock(cls, path: str, user_id: int, timeout: int = 3600):
lock_token = f"opaquelocktoken:{hashlib.md5(f'{path}{user_id}{datetime.now()}'.encode()).hexdigest()}" lock_token = f"opaquelocktoken:{hashlib.md5(f'{path}{user_id}{datetime.now()}'.encode()).hexdigest()}"
cls.locks[path] = { cls.locks[path] = {
"token": lock_token, 'token': lock_token,
"user_id": user_id, 'user_id': user_id,
"created_at": datetime.now(), 'created_at': datetime.now(),
"timeout": timeout, 'timeout': timeout
} }
return lock_token return lock_token
@ -43,18 +42,17 @@ class WebDAVLock:
if path in cls.locks: if path in cls.locks:
del cls.locks[path] del cls.locks[path]
async def basic_auth(authorization: Optional[str] = Header(None)): async def basic_auth(authorization: Optional[str] = Header(None)):
if not authorization: if not authorization:
return None return None
try: try:
scheme, credentials = authorization.split() scheme, credentials = authorization.split()
if scheme.lower() != "basic": if scheme.lower() != 'basic':
return None return None
decoded = base64.b64decode(credentials).decode("utf-8") decoded = base64.b64decode(credentials).decode('utf-8')
username, password = decoded.split(":", 1) username, password = decoded.split(':', 1)
user = await User.get_or_none(username=username) user = await User.get_or_none(username=username)
if user and verify_password(password, user.hashed_password): if user and verify_password(password, user.hashed_password):
@ -64,7 +62,6 @@ async def basic_auth(authorization: Optional[str] = Header(None)):
return None return None
async def webdav_auth(request: Request, authorization: Optional[str] = Header(None)): async def webdav_auth(request: Request, authorization: Optional[str] = Header(None)):
user = await basic_auth(authorization) user = await basic_auth(authorization)
if user: if user:
@ -78,15 +75,14 @@ async def webdav_auth(request: Request, authorization: Optional[str] = Header(No
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": 'Basic realm="MyWebdav WebDAV"'}, headers={'WWW-Authenticate': 'Basic realm="MyWebdav WebDAV"'}
) )
async def resolve_path(path_str: str, user: User): async def resolve_path(path_str: str, user: User):
if not path_str or path_str == "/": if not path_str or path_str == '/':
return None, None, True return None, None, True
parts = [p for p in path_str.split("/") if p] parts = [p for p in path_str.split('/') if p]
if not parts: if not parts:
return None, None, True return None, None, True
@ -94,7 +90,10 @@ async def resolve_path(path_str: str, user: User):
current_folder = None current_folder = None
for i, part in enumerate(parts[:-1]): for i, part in enumerate(parts[:-1]):
folder = await Folder.get_or_none( folder = await Folder.get_or_none(
name=part, parent=current_folder, owner=user, is_deleted=False name=part,
parent=current_folder,
owner=user,
is_deleted=False
) )
if not folder: if not folder:
return None, None, False return None, None, False
@ -103,37 +102,39 @@ async def resolve_path(path_str: str, user: User):
last_part = parts[-1] last_part = parts[-1]
folder = await Folder.get_or_none( folder = await Folder.get_or_none(
name=last_part, parent=current_folder, owner=user, is_deleted=False name=last_part,
parent=current_folder,
owner=user,
is_deleted=False
) )
if folder: if folder:
return folder, current_folder, True return folder, current_folder, True
file = await File.get_or_none( file = await File.get_or_none(
name=last_part, parent=current_folder, owner=user, is_deleted=False name=last_part,
parent=current_folder,
owner=user,
is_deleted=False
) )
if file: if file:
return file, current_folder, True return file, current_folder, True
return None, current_folder, True return None, current_folder, True
def build_href(base_path: str, name: str, is_collection: bool): def build_href(base_path: str, name: str, is_collection: bool):
path = f"{base_path.rstrip('/')}/{name}" path = f"{base_path.rstrip('/')}/{name}"
if is_collection: if is_collection:
path += "/" path += '/'
return path return path
async def get_custom_properties(resource_type: str, resource_id: int): async def get_custom_properties(resource_type: str, resource_id: int):
props = await WebDAVProperty.filter( props = await WebDAVProperty.filter(
resource_type=resource_type, resource_id=resource_id resource_type=resource_type,
resource_id=resource_id
) )
return {(prop.namespace, prop.name): prop.value for prop in props} return {(prop.namespace, prop.name): prop.value for prop in props}
def create_propstat_element(props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"):
def create_propstat_element(
props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"
):
propstat = ET.Element("D:propstat") propstat = ET.Element("D:propstat")
prop = ET.SubElement(propstat, "D:prop") prop = ET.SubElement(propstat, "D:prop")
@ -150,10 +151,10 @@ def create_propstat_element(
elem.text = value elem.text = value
elif key == "getlastmodified": elif key == "getlastmodified":
elem = ET.SubElement(prop, f"D:{key}") elem = ET.SubElement(prop, f"D:{key}")
elem.text = value.strftime("%a, %d %b %Y %H:%M:%S GMT") elem.text = value.strftime('%a, %d %b %Y %H:%M:%S GMT')
elif key == "creationdate": elif key == "creationdate":
elem = ET.SubElement(prop, f"D:{key}") elem = ET.SubElement(prop, f"D:{key}")
elem.text = value.isoformat() + "Z" elem.text = value.isoformat() + 'Z'
elif key == "displayname": elif key == "displayname":
elem = ET.SubElement(prop, f"D:{key}") elem = ET.SubElement(prop, f"D:{key}")
elem.text = value elem.text = value
@ -173,7 +174,6 @@ def create_propstat_element(
return propstat return propstat
def parse_propfind_body(body: bytes): def parse_propfind_body(body: bytes):
if not body: if not body:
return None return None
@ -193,8 +193,8 @@ def parse_propfind_body(body: bytes):
if prop is not None: if prop is not None:
requested_props = [] requested_props = []
for child in prop: for child in prop:
ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:" ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
name = child.tag.split("}")[1] if "}" in child.tag else child.tag name = child.tag.split('}')[1] if '}' in child.tag else child.tag
requested_props.append((ns, name)) requested_props.append((ns, name))
return requested_props return requested_props
except ET.ParseError: except ET.ParseError:
@ -202,7 +202,6 @@ def parse_propfind_body(body: bytes):
return None return None
@router.api_route("/{full_path:path}", methods=["OPTIONS"]) @router.api_route("/{full_path:path}", methods=["OPTIONS"])
async def webdav_options(full_path: str): async def webdav_options(full_path: str):
return Response( return Response(
@ -210,17 +209,14 @@ async def webdav_options(full_path: str):
headers={ headers={
"DAV": "1, 2", "DAV": "1, 2",
"Allow": "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK", "Allow": "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK",
"MS-Author-Via": "DAV", "MS-Author-Via": "DAV"
}, }
) )
@router.api_route("/{full_path:path}", methods=["PROPFIND"]) @router.api_route("/{full_path:path}", methods=["PROPFIND"])
async def handle_propfind( async def handle_propfind(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
depth = request.headers.get("Depth", "1") depth = request.headers.get("Depth", "1")
full_path = unquote(full_path).strip("/") full_path = unquote(full_path).strip('/')
body = await request.body() body = await request.body()
requested_props = parse_propfind_body(body) requested_props = parse_propfind_body(body)
@ -236,23 +232,19 @@ async def handle_propfind(
if resource is None: if resource is None:
response = ET.SubElement(multistatus, "D:response") response = ET.SubElement(multistatus, "D:response")
href = ET.SubElement(response, "D:href") href = ET.SubElement(response, "D:href")
href.text = base_href if base_href.endswith("/") else base_href + "/" href.text = base_href if base_href.endswith('/') else base_href + '/'
props = { props = {
"resourcetype": "collection", "resourcetype": "collection",
"displayname": full_path.split("/")[-1] if full_path else "Root", "displayname": full_path.split('/')[-1] if full_path else "Root",
"creationdate": datetime.now(), "creationdate": datetime.now(),
"getlastmodified": datetime.now(), "getlastmodified": datetime.now()
} }
response.append(create_propstat_element(props)) response.append(create_propstat_element(props))
if depth in ["1", "infinity"]: if depth in ["1", "infinity"]:
folders = await Folder.filter( folders = await Folder.filter(owner=current_user, parent=parent, is_deleted=False)
owner=current_user, parent=parent, is_deleted=False files = await File.filter(owner=current_user, parent=parent, is_deleted=False)
)
files = await File.filter(
owner=current_user, parent=parent, is_deleted=False
)
for folder in folders: for folder in folders:
response = ET.SubElement(multistatus, "D:response") response = ET.SubElement(multistatus, "D:response")
@ -263,17 +255,9 @@ async def handle_propfind(
"resourcetype": "collection", "resourcetype": "collection",
"displayname": folder.name, "displayname": folder.name,
"creationdate": folder.created_at, "creationdate": folder.created_at,
"getlastmodified": folder.updated_at, "getlastmodified": folder.updated_at
} }
custom_props = ( custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
await get_custom_properties("folder", folder.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
for file in files: for file in files:
@ -288,48 +272,28 @@ async def handle_propfind(
"getcontenttype": file.mime_type, "getcontenttype": file.mime_type,
"creationdate": file.created_at, "creationdate": file.created_at,
"getlastmodified": file.updated_at, "getlastmodified": file.updated_at,
"getetag": f'"{file.file_hash}"', "getetag": f'"{file.file_hash}"'
} }
custom_props = ( custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
await get_custom_properties("file", file.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
elif isinstance(resource, Folder): elif isinstance(resource, Folder):
response = ET.SubElement(multistatus, "D:response") response = ET.SubElement(multistatus, "D:response")
href = ET.SubElement(response, "D:href") href = ET.SubElement(response, "D:href")
href.text = base_href if base_href.endswith("/") else base_href + "/" href.text = base_href if base_href.endswith('/') else base_href + '/'
props = { props = {
"resourcetype": "collection", "resourcetype": "collection",
"displayname": resource.name, "displayname": resource.name,
"creationdate": resource.created_at, "creationdate": resource.created_at,
"getlastmodified": resource.updated_at, "getlastmodified": resource.updated_at
} }
custom_props = ( custom_props = await get_custom_properties("folder", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
await get_custom_properties("folder", resource.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
if depth in ["1", "infinity"]: if depth in ["1", "infinity"]:
folders = await Folder.filter( folders = await Folder.filter(owner=current_user, parent=resource, is_deleted=False)
owner=current_user, parent=resource, is_deleted=False files = await File.filter(owner=current_user, parent=resource, is_deleted=False)
)
files = await File.filter(
owner=current_user, parent=resource, is_deleted=False
)
for folder in folders: for folder in folders:
response = ET.SubElement(multistatus, "D:response") response = ET.SubElement(multistatus, "D:response")
@ -340,17 +304,9 @@ async def handle_propfind(
"resourcetype": "collection", "resourcetype": "collection",
"displayname": folder.name, "displayname": folder.name,
"creationdate": folder.created_at, "creationdate": folder.created_at,
"getlastmodified": folder.updated_at, "getlastmodified": folder.updated_at
} }
custom_props = ( custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
await get_custom_properties("folder", folder.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
for file in files: for file in files:
@ -365,17 +321,9 @@ async def handle_propfind(
"getcontenttype": file.mime_type, "getcontenttype": file.mime_type,
"creationdate": file.created_at, "creationdate": file.created_at,
"getlastmodified": file.updated_at, "getlastmodified": file.updated_at,
"getetag": f'"{file.file_hash}"', "getetag": f'"{file.file_hash}"'
} }
custom_props = ( custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
await get_custom_properties("file", file.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
elif isinstance(resource, File): elif isinstance(resource, File):
@ -390,32 +338,17 @@ async def handle_propfind(
"getcontenttype": resource.mime_type, "getcontenttype": resource.mime_type,
"creationdate": resource.created_at, "creationdate": resource.created_at,
"getlastmodified": resource.updated_at, "getlastmodified": resource.updated_at,
"getetag": f'"{resource.file_hash}"', "getetag": f'"{resource.file_hash}"'
} }
custom_props = ( custom_props = await get_custom_properties("file", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
await get_custom_properties("file", resource.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props)) response.append(create_propstat_element(props, custom_props))
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True) xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
return Response( return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207)
content=xml_content,
media_type="application/xml; charset=utf-8",
status_code=207,
)
@router.api_route("/{full_path:path}", methods=["GET", "HEAD"]) @router.api_route("/{full_path:path}", methods=["GET", "HEAD"])
async def handle_get( async def handle_get(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth) full_path = unquote(full_path).strip('/')
):
full_path = unquote(full_path).strip("/")
resource, parent, exists = await resolve_path(full_path, current_user) resource, parent, exists = await resolve_path(full_path, current_user)
@ -430,10 +363,8 @@ async def handle_get(
"Content-Length": str(resource.size), "Content-Length": str(resource.size),
"Content-Type": resource.mime_type, "Content-Type": resource.mime_type,
"ETag": f'"{resource.file_hash}"', "ETag": f'"{resource.file_hash}"',
"Last-Modified": resource.updated_at.strftime( "Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT')
"%a, %d %b %Y %H:%M:%S GMT" }
),
},
) )
async def file_iterator(): async def file_iterator():
@ -446,28 +377,23 @@ async def handle_get(
headers={ headers={
"Content-Disposition": f'attachment; filename="{resource.name}"', "Content-Disposition": f'attachment; filename="{resource.name}"',
"ETag": f'"{resource.file_hash}"', "ETag": f'"{resource.file_hash}"',
"Last-Modified": resource.updated_at.strftime( "Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT')
"%a, %d %b %Y %H:%M:%S GMT" }
),
},
) )
except FileNotFoundError: except FileNotFoundError:
raise HTTPException(status_code=404, detail="File not found in storage") raise HTTPException(status_code=404, detail="File not found in storage")
@router.api_route("/{full_path:path}", methods=["PUT"]) @router.api_route("/{full_path:path}", methods=["PUT"])
async def handle_put( async def handle_put(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth) full_path = unquote(full_path).strip('/')
):
full_path = unquote(full_path).strip("/")
if not full_path: if not full_path:
raise HTTPException(status_code=400, detail="Cannot PUT to root") raise HTTPException(status_code=400, detail="Cannot PUT to root")
parts = [p for p in full_path.split("/") if p] parts = [p for p in full_path.split('/') if p]
file_name = parts[-1] file_name = parts[-1]
parent_path = "/".join(parts[:-1]) if len(parts) > 1 else "" parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else ''
_, parent_folder, exists = await resolve_path(parent_path, current_user) _, parent_folder, exists = await resolve_path(parent_path, current_user)
if not exists: if not exists:
@ -491,7 +417,10 @@ async def handle_put(
mime_type = "application/octet-stream" mime_type = "application/octet-stream"
existing_file = await File.get_or_none( existing_file = await File.get_or_none(
name=file_name, parent=parent_folder, owner=current_user, is_deleted=False name=file_name,
parent=parent_folder,
owner=current_user,
is_deleted=False
) )
if existing_file: if existing_file:
@ -503,9 +432,7 @@ async def handle_put(
existing_file.updated_at = datetime.now() existing_file.updated_at = datetime.now()
await existing_file.save() await existing_file.save()
current_user.used_storage_bytes = ( current_user.used_storage_bytes = current_user.used_storage_bytes - old_size + file_size
current_user.used_storage_bytes - old_size + file_size
)
await current_user.save() await current_user.save()
await log_activity(current_user, "file_updated", "file", existing_file.id) await log_activity(current_user, "file_updated", "file", existing_file.id)
@ -518,7 +445,7 @@ async def handle_put(
mime_type=mime_type, mime_type=mime_type,
file_hash=file_hash, file_hash=file_hash,
owner=current_user, owner=current_user,
parent=parent_folder, parent=parent_folder
) )
current_user.used_storage_bytes += file_size current_user.used_storage_bytes += file_size
@ -527,12 +454,9 @@ async def handle_put(
await log_activity(current_user, "file_created", "file", db_file.id) await log_activity(current_user, "file_created", "file", db_file.id)
return Response(status_code=201) return Response(status_code=201)
@router.api_route("/{full_path:path}", methods=["DELETE"]) @router.api_route("/{full_path:path}", methods=["DELETE"])
async def handle_delete( async def handle_delete(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth) full_path = unquote(full_path).strip('/')
):
full_path = unquote(full_path).strip("/")
if not full_path: if not full_path:
raise HTTPException(status_code=400, detail="Cannot DELETE root") raise HTTPException(status_code=400, detail="Cannot DELETE root")
@ -554,45 +478,44 @@ async def handle_delete(
return Response(status_code=204) return Response(status_code=204)
@router.api_route("/{full_path:path}", methods=["MKCOL"]) @router.api_route("/{full_path:path}", methods=["MKCOL"])
async def handle_mkcol( async def handle_mkcol(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth) full_path = unquote(full_path).strip('/')
):
full_path = unquote(full_path).strip("/")
if not full_path: if not full_path:
raise HTTPException(status_code=400, detail="Cannot MKCOL at root") raise HTTPException(status_code=400, detail="Cannot MKCOL at root")
parts = [p for p in full_path.split("/") if p] parts = [p for p in full_path.split('/') if p]
folder_name = parts[-1] folder_name = parts[-1]
parent_path = "/".join(parts[:-1]) if len(parts) > 1 else "" parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else ''
_, parent_folder, exists = await resolve_path(parent_path, current_user) _, parent_folder, exists = await resolve_path(parent_path, current_user)
if not exists: if not exists:
raise HTTPException(status_code=409, detail="Parent folder does not exist") raise HTTPException(status_code=409, detail="Parent folder does not exist")
existing = await Folder.get_or_none( existing = await Folder.get_or_none(
name=folder_name, parent=parent_folder, owner=current_user, is_deleted=False name=folder_name,
parent=parent_folder,
owner=current_user,
is_deleted=False
) )
if existing: if existing:
raise HTTPException(status_code=405, detail="Folder already exists") raise HTTPException(status_code=405, detail="Folder already exists")
folder = await Folder.create( folder = await Folder.create(
name=folder_name, parent=parent_folder, owner=current_user name=folder_name,
parent=parent_folder,
owner=current_user
) )
await log_activity(current_user, "folder_created", "folder", folder.id) await log_activity(current_user, "folder_created", "folder", folder.id)
return Response(status_code=201) return Response(status_code=201)
@router.api_route("/{full_path:path}", methods=["COPY"]) @router.api_route("/{full_path:path}", methods=["COPY"])
async def handle_copy( async def handle_copy(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth) full_path = unquote(full_path).strip('/')
):
full_path = unquote(full_path).strip("/")
destination = request.headers.get("Destination") destination = request.headers.get("Destination")
overwrite = request.headers.get("Overwrite", "T") overwrite = request.headers.get("Overwrite", "T")
@ -600,7 +523,7 @@ async def handle_copy(
raise HTTPException(status_code=400, detail="Destination header required") raise HTTPException(status_code=400, detail="Destination header required")
dest_path = unquote(urlparse(destination).path) dest_path = unquote(urlparse(destination).path)
dest_path = dest_path.replace("/webdav/", "").strip("/") dest_path = dest_path.replace('/webdav/', '').strip('/')
source_resource, _, exists = await resolve_path(full_path, current_user) source_resource, _, exists = await resolve_path(full_path, current_user)
@ -610,9 +533,9 @@ async def handle_copy(
if not isinstance(source_resource, File): if not isinstance(source_resource, File):
raise HTTPException(status_code=501, detail="Only file copy is implemented") raise HTTPException(status_code=501, detail="Only file copy is implemented")
dest_parts = [p for p in dest_path.split("/") if p] dest_parts = [p for p in dest_path.split('/') if p]
dest_name = dest_parts[-1] dest_name = dest_parts[-1]
dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else "" dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else ''
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user) _, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
@ -620,13 +543,14 @@ async def handle_copy(
raise HTTPException(status_code=409, detail="Destination parent does not exist") raise HTTPException(status_code=409, detail="Destination parent does not exist")
existing_dest = await File.get_or_none( existing_dest = await File.get_or_none(
name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False name=dest_name,
parent=dest_parent,
owner=current_user,
is_deleted=False
) )
if existing_dest and overwrite == "F": if existing_dest and overwrite == "F":
raise HTTPException( raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
status_code=412, detail="Destination exists and overwrite is false"
)
if existing_dest: if existing_dest:
await existing_dest.delete() await existing_dest.delete()
@ -638,18 +562,15 @@ async def handle_copy(
mime_type=source_resource.mime_type, mime_type=source_resource.mime_type,
file_hash=source_resource.file_hash, file_hash=source_resource.file_hash,
owner=current_user, owner=current_user,
parent=dest_parent, parent=dest_parent
) )
await log_activity(current_user, "file_copied", "file", new_file.id) await log_activity(current_user, "file_copied", "file", new_file.id)
return Response(status_code=201 if not existing_dest else 204) return Response(status_code=201 if not existing_dest else 204)
@router.api_route("/{full_path:path}", methods=["MOVE"]) @router.api_route("/{full_path:path}", methods=["MOVE"])
async def handle_move( async def handle_move(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth) full_path = unquote(full_path).strip('/')
):
full_path = unquote(full_path).strip("/")
destination = request.headers.get("Destination") destination = request.headers.get("Destination")
overwrite = request.headers.get("Overwrite", "T") overwrite = request.headers.get("Overwrite", "T")
@ -657,16 +578,16 @@ async def handle_move(
raise HTTPException(status_code=400, detail="Destination header required") raise HTTPException(status_code=400, detail="Destination header required")
dest_path = unquote(urlparse(destination).path) dest_path = unquote(urlparse(destination).path)
dest_path = dest_path.replace("/webdav/", "").strip("/") dest_path = dest_path.replace('/webdav/', '').strip('/')
source_resource, _, exists = await resolve_path(full_path, current_user) source_resource, _, exists = await resolve_path(full_path, current_user)
if not source_resource: if not source_resource:
raise HTTPException(status_code=404, detail="Source not found") raise HTTPException(status_code=404, detail="Source not found")
dest_parts = [p for p in dest_path.split("/") if p] dest_parts = [p for p in dest_path.split('/') if p]
dest_name = dest_parts[-1] dest_name = dest_parts[-1]
dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else "" dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else ''
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user) _, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
@ -675,13 +596,14 @@ async def handle_move(
if isinstance(source_resource, File): if isinstance(source_resource, File):
existing_dest = await File.get_or_none( existing_dest = await File.get_or_none(
name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False name=dest_name,
parent=dest_parent,
owner=current_user,
is_deleted=False
) )
if existing_dest and overwrite == "F": if existing_dest and overwrite == "F":
raise HTTPException( raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
status_code=412, detail="Destination exists and overwrite is false"
)
if existing_dest: if existing_dest:
await existing_dest.delete() await existing_dest.delete()
@ -695,13 +617,14 @@ async def handle_move(
elif isinstance(source_resource, Folder): elif isinstance(source_resource, Folder):
existing_dest = await Folder.get_or_none( existing_dest = await Folder.get_or_none(
name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False name=dest_name,
parent=dest_parent,
owner=current_user,
is_deleted=False
) )
if existing_dest and overwrite == "F": if existing_dest and overwrite == "F":
raise HTTPException( raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
status_code=412, detail="Destination exists and overwrite is false"
)
if existing_dest: if existing_dest:
existing_dest.is_deleted = True existing_dest.is_deleted = True
@ -714,12 +637,9 @@ async def handle_move(
await log_activity(current_user, "folder_moved", "folder", source_resource.id) await log_activity(current_user, "folder_moved", "folder", source_resource.id)
return Response(status_code=201 if not existing_dest else 204) return Response(status_code=201 if not existing_dest else 204)
@router.api_route("/{full_path:path}", methods=["LOCK"]) @router.api_route("/{full_path:path}", methods=["LOCK"])
async def handle_lock( async def handle_lock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth) full_path = unquote(full_path).strip('/')
):
full_path = unquote(full_path).strip("/")
timeout_header = request.headers.get("Timeout", "Second-3600") timeout_header = request.headers.get("Timeout", "Second-3600")
timeout = 3600 timeout = 3600
@ -761,38 +681,32 @@ async def handle_lock(
content=xml_content, content=xml_content,
media_type="application/xml; charset=utf-8", media_type="application/xml; charset=utf-8",
status_code=200, status_code=200,
headers={"Lock-Token": f"<{lock_token}>"}, headers={"Lock-Token": f"<{lock_token}>"}
) )
@router.api_route("/{full_path:path}", methods=["UNLOCK"]) @router.api_route("/{full_path:path}", methods=["UNLOCK"])
async def handle_unlock( async def handle_unlock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth) full_path = unquote(full_path).strip('/')
):
full_path = unquote(full_path).strip("/")
lock_token_header = request.headers.get("Lock-Token") lock_token_header = request.headers.get("Lock-Token")
if not lock_token_header: if not lock_token_header:
raise HTTPException(status_code=400, detail="Lock-Token header required") raise HTTPException(status_code=400, detail="Lock-Token header required")
lock_token = lock_token_header.strip("<>") lock_token = lock_token_header.strip('<>')
existing_lock = WebDAVLock.get_lock(full_path) existing_lock = WebDAVLock.get_lock(full_path)
if not existing_lock or existing_lock["token"] != lock_token: if not existing_lock or existing_lock['token'] != lock_token:
raise HTTPException(status_code=409, detail="Invalid lock token") raise HTTPException(status_code=409, detail="Invalid lock token")
if existing_lock["user_id"] != current_user.id: if existing_lock['user_id'] != current_user.id:
raise HTTPException(status_code=403, detail="Not lock owner") raise HTTPException(status_code=403, detail="Not lock owner")
WebDAVLock.remove_lock(full_path) WebDAVLock.remove_lock(full_path)
return Response(status_code=204) return Response(status_code=204)
@router.api_route("/{full_path:path}", methods=["PROPPATCH"]) @router.api_route("/{full_path:path}", methods=["PROPPATCH"])
async def handle_proppatch( async def handle_proppatch(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
request: Request, full_path: str, current_user: User = Depends(webdav_auth) full_path = unquote(full_path).strip('/')
):
full_path = unquote(full_path).strip("/")
resource, parent, exists = await resolve_path(full_path, current_user) resource, parent, exists = await resolve_path(full_path, current_user)
@ -805,7 +719,7 @@ async def handle_proppatch(
try: try:
root = ET.fromstring(body) root = ET.fromstring(body)
except Exception: except:
raise HTTPException(status_code=400, detail="Invalid XML") raise HTTPException(status_code=400, detail="Invalid XML")
resource_type = "file" if isinstance(resource, File) else "folder" resource_type = "file" if isinstance(resource, File) else "folder"
@ -820,19 +734,13 @@ async def handle_proppatch(
prop_element = set_element.find(".//{DAV:}prop") prop_element = set_element.find(".//{DAV:}prop")
if prop_element is not None: if prop_element is not None:
for child in prop_element: for child in prop_element:
ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:" ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
name = child.tag.split("}")[1] if "}" in child.tag else child.tag name = child.tag.split('}')[1] if '}' in child.tag else child.tag
value = child.text or "" value = child.text or ""
if ns == "DAV:": if ns == "DAV:":
live_props = [ live_props = ["creationdate", "getcontentlength", "getcontenttype",
"creationdate", "getetag", "getlastmodified", "resourcetype"]
"getcontentlength",
"getcontenttype",
"getetag",
"getlastmodified",
"resourcetype",
]
if name in live_props: if name in live_props:
failed_props.append((ns, name, "409 Conflict")) failed_props.append((ns, name, "409 Conflict"))
continue continue
@ -842,7 +750,7 @@ async def handle_proppatch(
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,
namespace=ns, namespace=ns,
name=name, name=name
) )
if existing_prop: if existing_prop:
@ -854,10 +762,10 @@ async def handle_proppatch(
resource_id=resource_id, resource_id=resource_id,
namespace=ns, namespace=ns,
name=name, name=name,
value=value, value=value
) )
set_props.append((ns, name)) set_props.append((ns, name))
except Exception: except Exception as e:
failed_props.append((ns, name, "500 Internal Server Error")) failed_props.append((ns, name, "500 Internal Server Error"))
remove_element = root.find(".//{DAV:}remove") remove_element = root.find(".//{DAV:}remove")
@ -865,8 +773,8 @@ async def handle_proppatch(
prop_element = remove_element.find(".//{DAV:}prop") prop_element = remove_element.find(".//{DAV:}prop")
if prop_element is not None: if prop_element is not None:
for child in prop_element: for child in prop_element:
ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:" ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
name = child.tag.split("}")[1] if "}" in child.tag else child.tag name = child.tag.split('}')[1] if '}' in child.tag else child.tag
if ns == "DAV:": if ns == "DAV:":
failed_props.append((ns, name, "409 Conflict")) failed_props.append((ns, name, "409 Conflict"))
@ -877,7 +785,7 @@ async def handle_proppatch(
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,
namespace=ns, namespace=ns,
name=name, name=name
) )
if existing_prop: if existing_prop:
@ -885,7 +793,7 @@ async def handle_proppatch(
remove_props.append((ns, name)) remove_props.append((ns, name))
else: else:
failed_props.append((ns, name, "404 Not Found")) failed_props.append((ns, name, "404 Not Found"))
except Exception: except Exception as e:
failed_props.append((ns, name, "500 Internal Server Error")) failed_props.append((ns, name, "500 Internal Server Error"))
multistatus = ET.Element("D:multistatus", {"xmlns:D": "DAV:"}) multistatus = ET.Element("D:multistatus", {"xmlns:D": "DAV:"})
@ -929,8 +837,4 @@ async def handle_proppatch(
await log_activity(current_user, "properties_modified", resource_type, resource_id) await log_activity(current_user, "properties_modified", resource_type, resource_id)
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True) xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
return Response( return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207)
content=xml_content,
media_type="application/xml; charset=utf-8",
status_code=207,
)

View File

@ -1,106 +1,98 @@
import pytest import pytest
import pytest_asyncio
from decimal import Decimal from decimal import Decimal
from datetime import date, datetime from datetime import date, datetime
from httpx import AsyncClient, ASGITransport from httpx import AsyncClient
from fastapi import status from fastapi import status
from tortoise.contrib.test import initializer, finalizer
from mywebdav.main import app from mywebdav.main import app
from mywebdav.models import User from mywebdav.models import User
from mywebdav.billing.models import ( from mywebdav.billing.models import PricingConfig, Invoice, UsageAggregate, UserSubscription
PricingConfig,
Invoice,
UsageAggregate,
UserSubscription,
)
from mywebdav.auth import create_access_token from mywebdav.auth import create_access_token
@pytest.fixture(scope="module")
def event_loop():
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture @pytest.fixture(scope="module", autouse=True)
async def initialize_tests():
initializer(["mywebdav.models", "mywebdav.billing.models"], db_url="sqlite://:memory:")
yield
await finalizer()
@pytest.fixture
async def test_user(): async def test_user():
user = await User.create( user = await User.create(
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True, is_active=True,
is_superuser=False, is_superuser=False
) )
yield user yield user
await user.delete() await user.delete()
@pytest.fixture
@pytest_asyncio.fixture
async def admin_user(): async def admin_user():
user = await User.create( user = await User.create(
username="adminuser", username="adminuser",
email="admin@example.com", email="admin@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True, is_active=True,
is_superuser=True, is_superuser=True
) )
yield user yield user
await user.delete() await user.delete()
@pytest.fixture
@pytest_asyncio.fixture
async def auth_token(test_user): async def auth_token(test_user):
token = create_access_token(data={"sub": test_user.username}) token = create_access_token(data={"sub": test_user.username})
return token return token
@pytest.fixture
@pytest_asyncio.fixture
async def admin_token(admin_user): async def admin_token(admin_user):
token = create_access_token(data={"sub": admin_user.username}) token = create_access_token(data={"sub": admin_user.username})
return token return token
@pytest.fixture
@pytest_asyncio.fixture
async def pricing_config(): async def pricing_config():
configs = [] configs = []
configs.append( configs.append(await PricingConfig.create(
await PricingConfig.create( config_key="storage_per_gb_month",
config_key="storage_per_gb_month", config_value=Decimal("0.0045"),
config_value=Decimal("0.0045"), description="Storage cost per GB per month",
description="Storage cost per GB per month", unit="per_gb_month"
unit="per_gb_month", ))
) configs.append(await PricingConfig.create(
) config_key="bandwidth_egress_per_gb",
configs.append( config_value=Decimal("0.009"),
await PricingConfig.create( description="Bandwidth egress cost per GB",
config_key="bandwidth_egress_per_gb", unit="per_gb"
config_value=Decimal("0.009"), ))
description="Bandwidth egress cost per GB", configs.append(await PricingConfig.create(
unit="per_gb", config_key="free_tier_storage_gb",
) config_value=Decimal("15"),
) description="Free tier storage in GB",
configs.append( unit="gb"
await PricingConfig.create( ))
config_key="free_tier_storage_gb", configs.append(await PricingConfig.create(
config_value=Decimal("15"), config_key="free_tier_bandwidth_gb",
description="Free tier storage in GB", config_value=Decimal("15"),
unit="gb", description="Free tier bandwidth in GB per month",
) unit="gb"
) ))
configs.append(
await PricingConfig.create(
config_key="free_tier_bandwidth_gb",
config_value=Decimal("15"),
description="Free tier bandwidth in GB per month",
unit="gb",
)
)
yield configs yield configs
for config in configs: for config in configs:
await config.delete() await config.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_usage(test_user, auth_token): async def test_get_current_usage(test_user, auth_token):
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/billing/usage/current", "/api/billing/usage/current",
headers={"Authorization": f"Bearer {auth_token}"}, headers={"Authorization": f"Bearer {auth_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -108,7 +100,6 @@ async def test_get_current_usage(test_user, auth_token):
assert "storage_gb" in data assert "storage_gb" in data
assert "bandwidth_down_gb_today" in data assert "bandwidth_down_gb_today" in data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_monthly_usage(test_user, auth_token): async def test_get_monthly_usage(test_user, auth_token):
today = date.today() today = date.today()
@ -116,18 +107,16 @@ async def test_get_monthly_usage(test_user, auth_token):
await UsageAggregate.create( await UsageAggregate.create(
user=test_user, user=test_user,
date=today, date=today,
storage_bytes_avg=1024**3 * 10, storage_bytes_avg=1024 ** 3 * 10,
storage_bytes_peak=1024**3 * 12, storage_bytes_peak=1024 ** 3 * 12,
bandwidth_up_bytes=1024**3 * 2, bandwidth_up_bytes=1024 ** 3 * 2,
bandwidth_down_bytes=1024**3 * 5, bandwidth_down_bytes=1024 ** 3 * 5
) )
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
f"/api/billing/usage/monthly?year={today.year}&month={today.month}", f"/api/billing/usage/monthly?year={today.year}&month={today.month}",
headers={"Authorization": f"Bearer {auth_token}"}, headers={"Authorization": f"Bearer {auth_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -136,15 +125,12 @@ async def test_get_monthly_usage(test_user, auth_token):
await UsageAggregate.filter(user=test_user).delete() await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_subscription(test_user, auth_token): async def test_get_subscription(test_user, auth_token):
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/billing/subscription", "/api/billing/subscription",
headers={"Authorization": f"Bearer {auth_token}"}, headers={"Authorization": f"Bearer {auth_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -154,7 +140,6 @@ async def test_get_subscription(test_user, auth_token):
await UserSubscription.filter(user=test_user).delete() await UserSubscription.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_invoices(test_user, auth_token): async def test_list_invoices(test_user, auth_token):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -165,14 +150,13 @@ async def test_list_invoices(test_user, auth_token):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="open", status="open"
) )
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/billing/invoices", headers={"Authorization": f"Bearer {auth_token}"} "/api/billing/invoices",
headers={"Authorization": f"Bearer {auth_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -182,7 +166,6 @@ async def test_list_invoices(test_user, auth_token):
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_invoice(test_user, auth_token): async def test_get_invoice(test_user, auth_token):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -193,15 +176,13 @@ async def test_get_invoice(test_user, auth_token):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="open", status="open"
) )
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
f"/api/billing/invoices/{invoice.id}", f"/api/billing/invoices/{invoice.id}",
headers={"Authorization": f"Bearer {auth_token}"}, headers={"Authorization": f"Bearer {auth_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -210,45 +191,39 @@ async def test_get_invoice(test_user, auth_token):
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_pricing(): async def test_get_pricing():
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/api/billing/pricing") response = await client.get("/api/billing/pricing")
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
assert isinstance(data, dict) assert isinstance(data, dict)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_get_pricing(admin_user, admin_token, pricing_config): async def test_admin_get_pricing(admin_user, admin_token, pricing_config):
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/admin/billing/pricing", "/api/admin/billing/pricing",
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
assert len(data) > 0 assert len(data) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_update_pricing(admin_user, admin_token, pricing_config): async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
config_id = pricing_config[0].id config_id = pricing_config[0].id
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.put( response = await client.put(
f"/api/admin/billing/pricing/{config_id}", f"/api/admin/billing/pricing/{config_id}",
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}"},
json={"config_key": "storage_per_gb_month", "config_value": 0.005}, json={
"config_key": "storage_per_gb_month",
"config_value": 0.005
}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -256,15 +231,12 @@ async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
updated = await PricingConfig.get(id=config_id) updated = await PricingConfig.get(id=config_id)
assert updated.config_value == Decimal("0.005") assert updated.config_value == Decimal("0.005")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_get_stats(admin_user, admin_token): async def test_admin_get_stats(admin_user, admin_token):
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/admin/billing/stats", "/api/admin/billing/stats",
headers={"Authorization": f"Bearer {admin_token}"}, headers={"Authorization": f"Bearer {admin_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -273,15 +245,12 @@ async def test_admin_get_stats(admin_user, admin_token):
assert "total_invoices" in data assert "total_invoices" in data
assert "pending_invoices" in data assert "pending_invoices" in data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_non_admin_cannot_access_admin_endpoints(test_user, auth_token): async def test_non_admin_cannot_access_admin_endpoints(test_user, auth_token):
async with AsyncClient( async with AsyncClient(app=app, base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get( response = await client.get(
"/api/admin/billing/pricing", "/api/admin/billing/pricing",
headers={"Authorization": f"Bearer {auth_token}"}, headers={"Authorization": f"Bearer {auth_token}"}
) )
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN

View File

@ -1,78 +1,72 @@
import pytest import pytest
import pytest_asyncio
from decimal import Decimal from decimal import Decimal
from datetime import date, datetime from datetime import date, datetime
from tortoise.contrib.test import initializer, finalizer
from mywebdav.models import User from mywebdav.models import User
from mywebdav.billing.models import ( from mywebdav.billing.models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription
Invoice,
InvoiceLineItem,
PricingConfig,
UsageAggregate,
UserSubscription,
)
from mywebdav.billing.invoice_generator import InvoiceGenerator from mywebdav.billing.invoice_generator import InvoiceGenerator
@pytest.fixture(scope="module")
def event_loop():
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture @pytest.fixture(scope="module", autouse=True)
async def initialize_tests():
initializer(["mywebdav.models", "mywebdav.billing.models"], db_url="sqlite://:memory:")
yield
await finalizer()
@pytest.fixture
async def test_user(): async def test_user():
user = await User.create( user = await User.create(
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True, is_active=True
) )
yield user yield user
await user.delete() await user.delete()
@pytest.fixture
@pytest_asyncio.fixture
async def pricing_config(): async def pricing_config():
configs = [] configs = []
configs.append( configs.append(await PricingConfig.create(
await PricingConfig.create( config_key="storage_per_gb_month",
config_key="storage_per_gb_month", config_value=Decimal("0.0045"),
config_value=Decimal("0.0045"), description="Storage cost per GB per month",
description="Storage cost per GB per month", unit="per_gb_month"
unit="per_gb_month", ))
) configs.append(await PricingConfig.create(
) config_key="bandwidth_egress_per_gb",
configs.append( config_value=Decimal("0.009"),
await PricingConfig.create( description="Bandwidth egress cost per GB",
config_key="bandwidth_egress_per_gb", unit="per_gb"
config_value=Decimal("0.009"), ))
description="Bandwidth egress cost per GB", configs.append(await PricingConfig.create(
unit="per_gb", config_key="free_tier_storage_gb",
) config_value=Decimal("15"),
) description="Free tier storage in GB",
configs.append( unit="gb"
await PricingConfig.create( ))
config_key="free_tier_storage_gb", configs.append(await PricingConfig.create(
config_value=Decimal("15"), config_key="free_tier_bandwidth_gb",
description="Free tier storage in GB", config_value=Decimal("15"),
unit="gb", description="Free tier bandwidth in GB per month",
) unit="gb"
) ))
configs.append( configs.append(await PricingConfig.create(
await PricingConfig.create( config_key="tax_rate_default",
config_key="free_tier_bandwidth_gb", config_value=Decimal("0.0"),
config_value=Decimal("15"), description="Default tax rate",
description="Free tier bandwidth in GB per month", unit="percentage"
unit="gb", ))
)
)
configs.append(
await PricingConfig.create(
config_key="tax_rate_default",
config_value=Decimal("0.0"),
description="Default tax rate",
unit="percentage",
)
)
yield configs yield configs
for config in configs: for config in configs:
await config.delete() await config.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_monthly_invoice_with_usage(test_user, pricing_config): async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
today = date.today() today = date.today()
@ -80,15 +74,13 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
await UsageAggregate.create( await UsageAggregate.create(
user=test_user, user=test_user,
date=today, date=today,
storage_bytes_avg=1024**3 * 50, storage_bytes_avg=1024 ** 3 * 50,
storage_bytes_peak=1024**3 * 55, storage_bytes_peak=1024 ** 3 * 55,
bandwidth_up_bytes=1024**3 * 10, bandwidth_up_bytes=1024 ** 3 * 10,
bandwidth_down_bytes=1024**3 * 20, bandwidth_down_bytes=1024 ** 3 * 20
) )
invoice = await InvoiceGenerator.generate_monthly_invoice( invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
test_user, today.year, today.month
)
assert invoice is not None assert invoice is not None
assert invoice.user_id == test_user.id assert invoice.user_id == test_user.id
@ -101,7 +93,6 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
await invoice.delete() await invoice.delete()
await UsageAggregate.filter(user=test_user).delete() await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_config): async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_config):
today = date.today() today = date.today()
@ -109,21 +100,18 @@ async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_confi
await UsageAggregate.create( await UsageAggregate.create(
user=test_user, user=test_user,
date=today, date=today,
storage_bytes_avg=1024**3 * 10, storage_bytes_avg=1024 ** 3 * 10,
storage_bytes_peak=1024**3 * 12, storage_bytes_peak=1024 ** 3 * 12,
bandwidth_up_bytes=1024**3 * 5, bandwidth_up_bytes=1024 ** 3 * 5,
bandwidth_down_bytes=1024**3 * 10, bandwidth_down_bytes=1024 ** 3 * 10
) )
invoice = await InvoiceGenerator.generate_monthly_invoice( invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
test_user, today.year, today.month
)
assert invoice is None assert invoice is None
await UsageAggregate.filter(user=test_user).delete() await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_finalize_invoice(test_user, pricing_config): async def test_finalize_invoice(test_user, pricing_config):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -134,7 +122,7 @@ async def test_finalize_invoice(test_user, pricing_config):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="draft", status="draft"
) )
finalized = await InvoiceGenerator.finalize_invoice(invoice) finalized = await InvoiceGenerator.finalize_invoice(invoice)
@ -143,7 +131,6 @@ async def test_finalize_invoice(test_user, pricing_config):
await finalized.delete() await finalized.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_finalize_invoice_already_finalized(test_user, pricing_config): async def test_finalize_invoice_already_finalized(test_user, pricing_config):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -154,7 +141,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="open", status="open"
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -162,7 +149,6 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mark_invoice_paid(test_user, pricing_config): async def test_mark_invoice_paid(test_user, pricing_config):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -173,7 +159,7 @@ async def test_mark_invoice_paid(test_user, pricing_config):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="open", status="open"
) )
paid = await InvoiceGenerator.mark_invoice_paid(invoice) paid = await InvoiceGenerator.mark_invoice_paid(invoice)
@ -183,33 +169,22 @@ async def test_mark_invoice_paid(test_user, pricing_config):
await paid.delete() await paid.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoice_with_tax(test_user, pricing_config): async def test_invoice_with_tax(test_user):
# Update tax rate await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.21"))
updated = await PricingConfig.filter(config_key="tax_rate_default").update(
config_value=Decimal("0.21")
)
assert updated == 1 # Should update 1 row
# Verify the update worked
tax_config = await PricingConfig.get(config_key="tax_rate_default")
assert tax_config.config_value == Decimal("0.21")
today = date.today() today = date.today()
await UsageAggregate.create( await UsageAggregate.create(
user=test_user, user=test_user,
date=today, date=today,
storage_bytes_avg=1024**3 * 50, storage_bytes_avg=1024 ** 3 * 50,
storage_bytes_peak=1024**3 * 55, storage_bytes_peak=1024 ** 3 * 55,
bandwidth_up_bytes=1024**3 * 10, bandwidth_up_bytes=1024 ** 3 * 10,
bandwidth_down_bytes=1024**3 * 20, bandwidth_down_bytes=1024 ** 3 * 20
) )
invoice = await InvoiceGenerator.generate_monthly_invoice( invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
test_user, today.year, today.month
)
assert invoice is not None assert invoice is not None
assert invoice.tax > 0 assert invoice.tax > 0
@ -217,6 +192,4 @@ async def test_invoice_with_tax(test_user, pricing_config):
await invoice.delete() await invoice.delete()
await UsageAggregate.filter(user=test_user).delete() await UsageAggregate.filter(user=test_user).delete()
await PricingConfig.filter(config_key="tax_rate_default").update( await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.0"))
config_value=Decimal("0.0")
)

View File

@ -2,19 +2,25 @@ import pytest
import pytest_asyncio import pytest_asyncio
from decimal import Decimal from decimal import Decimal
from datetime import date, datetime from datetime import date, datetime
from tortoise.contrib.test import initializer, finalizer
from mywebdav.models import User from mywebdav.models import User
from mywebdav.billing.models import ( from mywebdav.billing.models import (
SubscriptionPlan, SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate,
UserSubscription, Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent
UsageRecord,
UsageAggregate,
Invoice,
InvoiceLineItem,
PricingConfig,
PaymentMethod,
BillingEvent,
) )
@pytest.fixture(scope="module")
def event_loop():
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="module", autouse=True)
async def initialize_tests():
initializer(["mywebdav.models", "mywebdav.billing.models"], db_url="sqlite://:memory:")
yield
await finalizer()
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user(): async def test_user():
@ -22,12 +28,11 @@ async def test_user():
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True, is_active=True
) )
yield user yield user
await user.delete() await user.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_subscription_plan_creation(): async def test_subscription_plan_creation():
plan = await SubscriptionPlan.create( plan = await SubscriptionPlan.create(
@ -37,7 +42,7 @@ async def test_subscription_plan_creation():
storage_gb=100, storage_gb=100,
bandwidth_gb=100, bandwidth_gb=100,
price_monthly=Decimal("5.00"), price_monthly=Decimal("5.00"),
price_yearly=Decimal("50.00"), price_yearly=Decimal("50.00")
) )
assert plan.name == "starter" assert plan.name == "starter"
@ -45,11 +50,12 @@ async def test_subscription_plan_creation():
assert plan.price_monthly == Decimal("5.00") assert plan.price_monthly == Decimal("5.00")
await plan.delete() await plan.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_subscription_creation(test_user): async def test_user_subscription_creation(test_user):
subscription = await UserSubscription.create( subscription = await UserSubscription.create(
user=test_user, billing_type="pay_as_you_go", status="active" user=test_user,
billing_type="pay_as_you_go",
status="active"
) )
assert subscription.user_id == test_user.id assert subscription.user_id == test_user.id
@ -57,7 +63,6 @@ async def test_user_subscription_creation(test_user):
assert subscription.status == "active" assert subscription.status == "active"
await subscription.delete() await subscription.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_record_creation(test_user): async def test_usage_record_creation(test_user):
usage = await UsageRecord.create( usage = await UsageRecord.create(
@ -66,7 +71,7 @@ async def test_usage_record_creation(test_user):
amount_bytes=1024 * 1024 * 100, amount_bytes=1024 * 1024 * 100,
resource_type="file", resource_type="file",
resource_id=1, resource_id=1,
idempotency_key="test_key_123", idempotency_key="test_key_123"
) )
assert usage.user_id == test_user.id assert usage.user_id == test_user.id
@ -74,7 +79,6 @@ async def test_usage_record_creation(test_user):
assert usage.amount_bytes == 1024 * 1024 * 100 assert usage.amount_bytes == 1024 * 1024 * 100
await usage.delete() await usage.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_usage_aggregate_creation(test_user): async def test_usage_aggregate_creation(test_user):
aggregate = await UsageAggregate.create( aggregate = await UsageAggregate.create(
@ -83,14 +87,13 @@ async def test_usage_aggregate_creation(test_user):
storage_bytes_avg=1024 * 1024 * 500, storage_bytes_avg=1024 * 1024 * 500,
storage_bytes_peak=1024 * 1024 * 600, storage_bytes_peak=1024 * 1024 * 600,
bandwidth_up_bytes=1024 * 1024 * 50, bandwidth_up_bytes=1024 * 1024 * 50,
bandwidth_down_bytes=1024 * 1024 * 100, bandwidth_down_bytes=1024 * 1024 * 100
) )
assert aggregate.user_id == test_user.id assert aggregate.user_id == test_user.id
assert aggregate.storage_bytes_avg == 1024 * 1024 * 500 assert aggregate.storage_bytes_avg == 1024 * 1024 * 500
await aggregate.delete() await aggregate.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoice_creation(test_user): async def test_invoice_creation(test_user):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -101,7 +104,7 @@ async def test_invoice_creation(test_user):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="draft", status="draft"
) )
assert invoice.user_id == test_user.id assert invoice.user_id == test_user.id
@ -109,7 +112,6 @@ async def test_invoice_creation(test_user):
assert invoice.total == Decimal("10.00") assert invoice.total == Decimal("10.00")
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoice_line_item_creation(test_user): async def test_invoice_line_item_creation(test_user):
invoice = await Invoice.create( invoice = await Invoice.create(
@ -120,7 +122,7 @@ async def test_invoice_line_item_creation(test_user):
subtotal=Decimal("10.00"), subtotal=Decimal("10.00"),
tax=Decimal("0.00"), tax=Decimal("0.00"),
total=Decimal("10.00"), total=Decimal("10.00"),
status="draft", status="draft"
) )
line_item = await InvoiceLineItem.create( line_item = await InvoiceLineItem.create(
@ -129,7 +131,7 @@ async def test_invoice_line_item_creation(test_user):
quantity=Decimal("100.000000"), quantity=Decimal("100.000000"),
unit_price=Decimal("0.100000"), unit_price=Decimal("0.100000"),
amount=Decimal("10.0000"), amount=Decimal("10.0000"),
item_type="storage", item_type="storage"
) )
assert line_item.invoice_id == invoice.id assert line_item.invoice_id == invoice.id
@ -138,7 +140,6 @@ async def test_invoice_line_item_creation(test_user):
await line_item.delete() await line_item.delete()
await invoice.delete() await invoice.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pricing_config_creation(test_user): async def test_pricing_config_creation(test_user):
config = await PricingConfig.create( config = await PricingConfig.create(
@ -146,14 +147,13 @@ async def test_pricing_config_creation(test_user):
config_value=Decimal("0.0045"), config_value=Decimal("0.0045"),
description="Storage cost per GB per month", description="Storage cost per GB per month",
unit="per_gb_month", unit="per_gb_month",
updated_by=test_user, updated_by=test_user
) )
assert config.config_key == "storage_per_gb_month" assert config.config_key == "storage_per_gb_month"
assert config.config_value == Decimal("0.0045") assert config.config_value == Decimal("0.0045")
await config.delete() await config.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_payment_method_creation(test_user): async def test_payment_method_creation(test_user):
payment_method = await PaymentMethod.create( payment_method = await PaymentMethod.create(
@ -164,7 +164,7 @@ async def test_payment_method_creation(test_user):
last4="4242", last4="4242",
brand="visa", brand="visa",
exp_month=12, exp_month=12,
exp_year=2025, exp_year=2025
) )
assert payment_method.user_id == test_user.id assert payment_method.user_id == test_user.id
@ -172,7 +172,6 @@ async def test_payment_method_creation(test_user):
assert payment_method.is_default is True assert payment_method.is_default is True
await payment_method.delete() await payment_method.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_billing_event_creation(test_user): async def test_billing_event_creation(test_user):
event = await BillingEvent.create( event = await BillingEvent.create(
@ -180,7 +179,7 @@ async def test_billing_event_creation(test_user):
event_type="invoice_created", event_type="invoice_created",
stripe_event_id="evt_test_123", stripe_event_id="evt_test_123",
data={"invoice_id": 1}, data={"invoice_id": 1},
processed=False, processed=False
) )
assert event.user_id == test_user.id assert event.user_id == test_user.id

View File

@ -1,35 +1,18 @@
import pytest import pytest
def test_billing_module_imports(): def test_billing_module_imports():
from mywebdav.billing import ( from mywebdav.billing import models, stripe_client, usage_tracker, invoice_generator, scheduler
models,
stripe_client,
usage_tracker,
invoice_generator,
scheduler,
)
assert models is not None assert models is not None
assert stripe_client is not None assert stripe_client is not None
assert usage_tracker is not None assert usage_tracker is not None
assert invoice_generator is not None assert invoice_generator is not None
assert scheduler is not None assert scheduler is not None
def test_billing_models_exist(): def test_billing_models_exist():
from mywebdav.billing.models import ( from mywebdav.billing.models import (
SubscriptionPlan, SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate,
UserSubscription, Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent
UsageRecord,
UsageAggregate,
Invoice,
InvoiceLineItem,
PricingConfig,
PaymentMethod,
BillingEvent,
) )
assert SubscriptionPlan is not None assert SubscriptionPlan is not None
assert UserSubscription is not None assert UserSubscription is not None
assert UsageRecord is not None assert UsageRecord is not None
@ -40,71 +23,55 @@ def test_billing_models_exist():
assert PaymentMethod is not None assert PaymentMethod is not None
assert BillingEvent is not None assert BillingEvent is not None
def test_stripe_client_exists(): def test_stripe_client_exists():
from mywebdav.billing.stripe_client import StripeClient from mywebdav.billing.stripe_client import StripeClient
assert StripeClient is not None assert StripeClient is not None
assert hasattr(StripeClient, "create_customer") assert hasattr(StripeClient, 'create_customer')
assert hasattr(StripeClient, "create_invoice") assert hasattr(StripeClient, 'create_invoice')
assert hasattr(StripeClient, "finalize_invoice") assert hasattr(StripeClient, 'finalize_invoice')
def test_usage_tracker_exists(): def test_usage_tracker_exists():
from mywebdav.billing.usage_tracker import UsageTracker from mywebdav.billing.usage_tracker import UsageTracker
assert UsageTracker is not None assert UsageTracker is not None
assert hasattr(UsageTracker, "track_storage") assert hasattr(UsageTracker, 'track_storage')
assert hasattr(UsageTracker, "track_bandwidth") assert hasattr(UsageTracker, 'track_bandwidth')
assert hasattr(UsageTracker, "aggregate_daily_usage") assert hasattr(UsageTracker, 'aggregate_daily_usage')
assert hasattr(UsageTracker, "get_current_storage") assert hasattr(UsageTracker, 'get_current_storage')
assert hasattr(UsageTracker, "get_monthly_usage") assert hasattr(UsageTracker, 'get_monthly_usage')
def test_invoice_generator_exists(): def test_invoice_generator_exists():
from mywebdav.billing.invoice_generator import InvoiceGenerator from mywebdav.billing.invoice_generator import InvoiceGenerator
assert InvoiceGenerator is not None assert InvoiceGenerator is not None
assert hasattr(InvoiceGenerator, "generate_monthly_invoice") assert hasattr(InvoiceGenerator, 'generate_monthly_invoice')
assert hasattr(InvoiceGenerator, "finalize_invoice") assert hasattr(InvoiceGenerator, 'finalize_invoice')
assert hasattr(InvoiceGenerator, "mark_invoice_paid") assert hasattr(InvoiceGenerator, 'mark_invoice_paid')
def test_scheduler_exists(): def test_scheduler_exists():
from mywebdav.billing.scheduler import scheduler, start_scheduler, stop_scheduler from mywebdav.billing.scheduler import scheduler, start_scheduler, stop_scheduler
assert scheduler is not None assert scheduler is not None
assert callable(start_scheduler) assert callable(start_scheduler)
assert callable(stop_scheduler) assert callable(stop_scheduler)
def test_routers_exist(): def test_routers_exist():
from mywebdav.routers import billing, admin_billing from mywebdav.routers import billing, admin_billing
assert billing is not None assert billing is not None
assert admin_billing is not None assert admin_billing is not None
assert hasattr(billing, "router") assert hasattr(billing, 'router')
assert hasattr(admin_billing, "router") assert hasattr(admin_billing, 'router')
def test_middleware_exists(): def test_middleware_exists():
from mywebdav.middleware.usage_tracking import UsageTrackingMiddleware from mywebdav.middleware.usage_tracking import UsageTrackingMiddleware
assert UsageTrackingMiddleware is not None assert UsageTrackingMiddleware is not None
def test_settings_updated(): def test_settings_updated():
from mywebdav.settings import settings from mywebdav.settings import settings
assert hasattr(settings, 'STRIPE_SECRET_KEY')
assert hasattr(settings, "STRIPE_SECRET_KEY") assert hasattr(settings, 'STRIPE_PUBLISHABLE_KEY')
assert hasattr(settings, "STRIPE_PUBLISHABLE_KEY") assert hasattr(settings, 'STRIPE_WEBHOOK_SECRET')
assert hasattr(settings, "STRIPE_WEBHOOK_SECRET") assert hasattr(settings, 'BILLING_ENABLED')
assert hasattr(settings, "BILLING_ENABLED")
def test_main_includes_billing(): def test_main_includes_billing():
from mywebdav.main import app from mywebdav.main import app
routes = [route.path for route in app.routes] routes = [route.path for route in app.routes]
billing_routes = [r for r in routes if "/billing" in r] billing_routes = [r for r in routes if '/billing' in r]
assert len(billing_routes) > 0 assert len(billing_routes) > 0

View File

@ -1,31 +1,42 @@
import pytest import pytest
import pytest_asyncio
from decimal import Decimal from decimal import Decimal
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from tortoise.contrib.test import initializer, finalizer
from mywebdav.models import User, File, Folder from mywebdav.models import User, File, Folder
from mywebdav.billing.models import UsageRecord, UsageAggregate from mywebdav.billing.models import UsageRecord, UsageAggregate
from mywebdav.billing.usage_tracker import UsageTracker from mywebdav.billing.usage_tracker import UsageTracker
@pytest.fixture(scope="module")
def event_loop():
import asyncio
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture @pytest.fixture(scope="module", autouse=True)
async def initialize_tests():
initializer(["mywebdav.models", "mywebdav.billing.models"], db_url="sqlite://:memory:")
yield
await finalizer()
@pytest.fixture
async def test_user(): async def test_user():
user = await User.create( user = await User.create(
username="testuser", username="testuser",
email="test@example.com", email="test@example.com",
hashed_password="hashed_password_here", hashed_password="hashed_password_here",
is_active=True, is_active=True
) )
yield user yield user
await user.delete() await user.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_storage(test_user): async def test_track_storage(test_user):
await UsageTracker.track_storage( await UsageTracker.track_storage(
user=test_user, user=test_user,
amount_bytes=1024 * 1024 * 100, amount_bytes=1024 * 1024 * 100,
resource_type="file", resource_type="file",
resource_id=1, resource_id=1
) )
records = await UsageRecord.filter(user=test_user, record_type="storage").all() records = await UsageRecord.filter(user=test_user, record_type="storage").all()
@ -34,7 +45,6 @@ async def test_track_storage(test_user):
await records[0].delete() await records[0].delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_bandwidth_upload(test_user): async def test_track_bandwidth_upload(test_user):
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
@ -42,7 +52,7 @@ async def test_track_bandwidth_upload(test_user):
amount_bytes=1024 * 1024 * 50, amount_bytes=1024 * 1024 * 50,
direction="up", direction="up",
resource_type="file", resource_type="file",
resource_id=1, resource_id=1
) )
records = await UsageRecord.filter(user=test_user, record_type="bandwidth_up").all() records = await UsageRecord.filter(user=test_user, record_type="bandwidth_up").all()
@ -51,7 +61,6 @@ async def test_track_bandwidth_upload(test_user):
await records[0].delete() await records[0].delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_track_bandwidth_download(test_user): async def test_track_bandwidth_download(test_user):
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
@ -59,28 +68,32 @@ async def test_track_bandwidth_download(test_user):
amount_bytes=1024 * 1024 * 75, amount_bytes=1024 * 1024 * 75,
direction="down", direction="down",
resource_type="file", resource_type="file",
resource_id=1, resource_id=1
) )
records = await UsageRecord.filter( records = await UsageRecord.filter(user=test_user, record_type="bandwidth_down").all()
user=test_user, record_type="bandwidth_down"
).all()
assert len(records) == 1 assert len(records) == 1
assert records[0].amount_bytes == 1024 * 1024 * 75 assert records[0].amount_bytes == 1024 * 1024 * 75
await records[0].delete() await records[0].delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_aggregate_daily_usage(test_user): async def test_aggregate_daily_usage(test_user):
await UsageTracker.track_storage(user=test_user, amount_bytes=1024 * 1024 * 100) await UsageTracker.track_storage(
user=test_user,
await UsageTracker.track_bandwidth( amount_bytes=1024 * 1024 * 100
user=test_user, amount_bytes=1024 * 1024 * 50, direction="up"
) )
await UsageTracker.track_bandwidth( await UsageTracker.track_bandwidth(
user=test_user, amount_bytes=1024 * 1024 * 75, direction="down" user=test_user,
amount_bytes=1024 * 1024 * 50,
direction="up"
)
await UsageTracker.track_bandwidth(
user=test_user,
amount_bytes=1024 * 1024 * 75,
direction="down"
) )
aggregate = await UsageTracker.aggregate_daily_usage(test_user, date.today()) aggregate = await UsageTracker.aggregate_daily_usage(test_user, date.today())
@ -93,10 +106,12 @@ async def test_aggregate_daily_usage(test_user):
await aggregate.delete() await aggregate.delete()
await UsageRecord.filter(user=test_user).delete() await UsageRecord.filter(user=test_user).delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_storage(test_user): async def test_get_current_storage(test_user):
folder = await Folder.create(name="Test Folder", owner=test_user) folder = await Folder.create(
name="Test Folder",
owner=test_user
)
file1 = await File.create( file1 = await File.create(
name="test1.txt", name="test1.txt",
@ -105,7 +120,7 @@ async def test_get_current_storage(test_user):
mime_type="text/plain", mime_type="text/plain",
owner=test_user, owner=test_user,
parent=folder, parent=folder,
is_deleted=False, is_deleted=False
) )
file2 = await File.create( file2 = await File.create(
@ -115,7 +130,7 @@ async def test_get_current_storage(test_user):
mime_type="text/plain", mime_type="text/plain",
owner=test_user, owner=test_user,
parent=folder, parent=folder,
is_deleted=False, is_deleted=False
) )
storage = await UsageTracker.get_current_storage(test_user) storage = await UsageTracker.get_current_storage(test_user)
@ -126,7 +141,6 @@ async def test_get_current_storage(test_user):
await file2.delete() await file2.delete()
await folder.delete() await folder.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_monthly_usage(test_user): async def test_get_monthly_usage(test_user):
today = date.today() today = date.today()
@ -137,7 +151,7 @@ async def test_get_monthly_usage(test_user):
storage_bytes_avg=1024 * 1024 * 1024 * 10, storage_bytes_avg=1024 * 1024 * 1024 * 10,
storage_bytes_peak=1024 * 1024 * 1024 * 12, storage_bytes_peak=1024 * 1024 * 1024 * 12,
bandwidth_up_bytes=1024 * 1024 * 1024 * 2, bandwidth_up_bytes=1024 * 1024 * 1024 * 2,
bandwidth_down_bytes=1024 * 1024 * 1024 * 5, bandwidth_down_bytes=1024 * 1024 * 1024 * 5
) )
usage = await UsageTracker.get_monthly_usage(test_user, today.year, today.month) usage = await UsageTracker.get_monthly_usage(test_user, today.year, today.month)
@ -149,14 +163,11 @@ async def test_get_monthly_usage(test_user):
await aggregate.delete() await aggregate.delete()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_monthly_usage_empty(test_user): async def test_get_monthly_usage_empty(test_user):
future_date = date.today() + timedelta(days=365) future_date = date.today() + timedelta(days=365)
usage = await UsageTracker.get_monthly_usage( usage = await UsageTracker.get_monthly_usage(test_user, future_date.year, future_date.month)
test_user, future_date.year, future_date.month
)
assert usage["storage_gb_avg"] == 0 assert usage["storage_gb_avg"] == 0
assert usage["storage_gb_peak"] == 0 assert usage["storage_gb_peak"] == 0

View File

@ -1,23 +1,8 @@
import pytest import pytest
import pytest_asyncio
import asyncio import asyncio
from tortoise import Tortoise
@pytest.fixture(scope="session")
@pytest_asyncio.fixture(scope="session") def event_loop():
async def init_db(): loop = asyncio.get_event_loop_policy().new_event_loop()
"""Initialize the database for testing.""" yield loop
await Tortoise.init( loop.close()
db_url="sqlite://:memory:",
modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
)
await Tortoise.generate_schemas()
yield
await Tortoise.close_connections()
@pytest_asyncio.fixture(autouse=True)
async def setup_db(init_db):
"""Ensure database is initialized before each test."""
# The init_db fixture handles the setup
yield

0
tests/e2e/__init__.py Normal file
View File

67
tests/e2e/conftest.py Normal file
View File

@ -0,0 +1,67 @@
import pytest
import pytest_asyncio
import asyncio
from playwright.async_api import async_playwright, Page, expect
@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="function")
async def browser():
async with async_playwright() as p:
browser = await p.chromium.launch(headless=False, slow_mo=500)
yield browser
await browser.close()
@pytest_asyncio.fixture(scope="function")
async def context(browser):
context = await browser.new_context(
viewport={"width": 1920, "height": 1080},
user_agent="Mozilla/5.0 (X11; Linux x86_64) MyWebdav E2E Tests",
ignore_https_errors=True,
service_workers='block'
)
yield context
await context.close()
@pytest_asyncio.fixture(scope="function")
async def page(context):
page = await context.new_page()
yield page
await page.close()
@pytest.fixture
def base_url():
return "http://localhost:9004"
@pytest_asyncio.fixture(scope="function", autouse=True)
async def login(page: Page, base_url):
print(f"Navigating to base_url: {base_url}")
await page.goto(f"{base_url}/")
await page.screenshot(path="01_initial_page.png")
# If already logged in, log out first to ensure a clean state
if await page.locator('a:has-text("Logout")').is_visible():
await page.click('a:has-text("Logout")')
await page.screenshot(path="02_after_logout.png")
# Now, proceed with login or registration
login_form = page.locator('#login-form:visible')
if await login_form.count() > 0:
await login_form.locator('input[name="username"]').fill('billingtest')
await login_form.locator('input[name="password"]').fill('password123')
await page.screenshot(path="03_before_login_click.png")
await expect(page.locator('h2:has-text("Files")')).to_be_visible(timeout=10000)
else:
# If no login form, try to register
await page.click('text=Sign Up')
register_form = page.locator('#register-form:visible')
await register_form.locator('input[name="username"]').fill('billingtest')
await register_form.locator('input[name="email"]').fill('billingtest@example.com')
await register_form.locator('input[name="password"]').fill('password123')
await page.screenshot(path="05_before_register_click.png")
await register_form.locator('button[type="submit"]').click()
await expect(page.locator('h1:has-text("My Files")')).to_be_visible(timeout=10000)

181
tests/e2e/test_admin.py Normal file
View File

@ -0,0 +1,181 @@
import pytest
import asyncio
from playwright.async_api import expect
@pytest.mark.asyncio
class TestAdmin:
async def test_01_admin_login(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Admin should see admin nav links
await expect(page.locator('a.nav-link[data-view="admin"]')).to_be_visible()
await expect(page.locator('a.nav-link[data-view="admin-billing"]')).to_be_visible()
async def test_02_navigate_to_admin_dashboard(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="admin"]')
await page.wait_for_timeout(1000)
await expect(page.locator('admin-dashboard')).to_be_visible()
async def test_03_view_user_management(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="admin"]')
await page.wait_for_timeout(1000)
# Should show user list
await expect(page.locator('.user-list')).to_be_visible()
await expect(page.locator('text=e2etestuser')).to_be_visible()
async def test_04_view_system_statistics(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="admin"]')
await page.wait_for_timeout(1000)
# Should show system stats
await expect(page.locator('.system-stats')).to_be_visible()
await expect(page.locator('.stat-item')).to_be_visible()
async def test_05_manage_user_permissions(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="admin"]')
await page.wait_for_timeout(1000)
# Click on user to manage
await page.click('.user-item:has-text("e2etestuser")')
await page.wait_for_timeout(500)
# Should show user details modal
await expect(page.locator('.user-details-modal')).to_be_visible()
# Change permissions
await page.select_option('select[name="role"]', 'moderator')
await page.click('button:has-text("Save")')
await page.wait_for_timeout(1000)
async def test_06_view_storage_usage(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="admin"]')
await page.wait_for_timeout(1000)
# Should show storage usage
await expect(page.locator('.storage-usage')).to_be_visible()
async def test_07_navigate_to_admin_billing(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="admin-billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('admin-billing')).to_be_visible()
async def test_08_view_billing_statistics(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="admin-billing"]')
await page.wait_for_timeout(1000)
# Should show billing stats
await expect(page.locator('.billing-stats')).to_be_visible()
async def test_09_manage_pricing(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="admin-billing"]')
await page.wait_for_timeout(1000)
# Click pricing management
await page.click('button:has-text("Manage Pricing")')
await page.wait_for_timeout(500)
# Should show pricing form
await expect(page.locator('.pricing-form')).to_be_visible()
# Update a price
await page.fill('input[name="storage_per_gb_month"]', '0.005')
await page.click('button:has-text("Update")')
await page.wait_for_timeout(1000)
async def test_10_generate_invoice(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'adminuser')
await page.fill('#login-form input[name="password"]', 'adminpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="admin-billing"]')
await page.wait_for_timeout(1000)
# Click generate invoice
await page.click('button:has-text("Generate Invoice")')
await page.wait_for_timeout(500)
# Select user
await page.select_option('select[name="user"]', 'e2etestuser')
await page.click('button:has-text("Generate")')
await page.wait_for_timeout(1000)
# Should show success message
await expect(page.locator('text=Invoice generated')).to_be_visible()

170
tests/e2e/test_auth_flow.py Normal file
View File

@ -0,0 +1,170 @@
import pytest
import asyncio
from playwright.async_api import expect
@pytest.mark.asyncio
class TestAuthFlow:
async def test_01_user_registration(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
# Click Sign Up
await page.click('text=Sign Up')
await page.wait_for_timeout(500)
# Fill registration form
await page.fill('#register-form input[name="username"]', 'e2etestuser')
await page.fill('#register-form input[name="email"]', 'e2etestuser@example.com')
await page.fill('#register-form input[name="password"]', 'testpassword123')
# Submit registration
await page.click('#register-form button[type="submit"]')
await page.wait_for_timeout(2000)
await expect(page.locator('text=My Files')).to_be_visible()
async def test_02_user_login(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
# Unregister service worker to avoid interference
await page.evaluate("""
navigator.serviceWorker.getRegistrations().then(function(registrations) {
for(let registration of registrations) {
registration.unregister();
}
});
""")
await page.reload()
await page.wait_for_load_state("networkidle")
# Fill login form
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
# Submit login
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await expect(page.locator('text=My Files')).to_be_visible()
async def test_03_navigate_to_files_view(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Files view should be active by default
await expect(page.locator('a.nav-link.active[data-view="files"]')).to_be_visible()
await expect(page.locator('file-list')).to_be_visible()
async def test_04_navigate_to_photos_view(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
await expect(page.locator('a.nav-link.active[data-view="photos"]')).to_be_visible()
await expect(page.locator('photo-gallery')).to_be_visible()
async def test_05_navigate_to_shared_view(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="shared"]')
await page.wait_for_timeout(1000)
await expect(page.locator('a.nav-link.active[data-view="shared"]')).to_be_visible()
await expect(page.locator('shared-items')).to_be_visible()
async def test_06_navigate_to_starred_view(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="starred"]')
await page.wait_for_timeout(1000)
await expect(page.locator('a.nav-link.active[data-view="starred"]')).to_be_visible()
await expect(page.locator('starred-items')).to_be_visible()
async def test_07_navigate_to_recent_view(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="recent"]')
await page.wait_for_timeout(1000)
await expect(page.locator('a.nav-link.active[data-view="recent"]')).to_be_visible()
await expect(page.locator('recent-files')).to_be_visible()
async def test_08_navigate_to_deleted_view(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="deleted"]')
await page.wait_for_timeout(1000)
await expect(page.locator('a.nav-link.active[data-view="deleted"]')).to_be_visible()
await expect(page.locator('deleted-files')).to_be_visible()
async def test_09_navigate_to_billing_view(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('a.nav-link.active[data-view="billing"]')).to_be_visible()
await expect(page.locator('billing-dashboard')).to_be_visible()
async def test_10_user_logout(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Click logout
await page.click('#logout-btn')
await page.wait_for_timeout(1000)
# Should be back to login
await expect(page.locator('#login-form')).to_be_visible()

View File

@ -0,0 +1,199 @@
import pytest
import asyncio
from playwright.async_api import expect
@pytest.mark.asyncio
class TestBillingAdminFlow:
async def test_01_admin_login(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('input[name="username"]', 'adminuser')
await page.fill('input[name="password"]', 'adminpassword123')
await page.click('button[type="submit"]')
await page.wait_for_load_state("networkidle")
await expect(page.locator('text=Dashboard')).to_be_visible()
async def test_02_navigate_to_admin_billing(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('input[name="username"]', 'adminuser')
await page.fill('input[name="password"]', 'adminpassword123')
await page.click('button[type="submit"]')
await page.wait_for_load_state("networkidle")
await page.click('text=Admin')
await page.wait_for_load_state("networkidle")
await page.click('text=Billing')
await page.wait_for_url("**/admin/billing")
await expect(page.locator('h2:has-text("Billing Administration")')).to_be_visible()
async def test_03_view_revenue_statistics(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.stats-cards')).to_be_visible()
await expect(page.locator('.stat-card:has-text("Total Revenue")')).to_be_visible()
await expect(page.locator('.stat-card:has-text("Total Invoices")')).to_be_visible()
await expect(page.locator('.stat-card:has-text("Pending Invoices")')).to_be_visible()
async def test_04_verify_revenue_value_display(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
revenue_stat = page.locator('.stat-card:has-text("Total Revenue") .stat-value')
await expect(revenue_stat).to_be_visible()
revenue_text = await revenue_stat.text_content()
assert '$' in revenue_text or '0' in revenue_text
async def test_05_view_pricing_configuration_table(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.pricing-config-section')).to_be_visible()
await expect(page.locator('h3:has-text("Pricing Configuration")')).to_be_visible()
await expect(page.locator('.pricing-table')).to_be_visible()
async def test_06_verify_pricing_config_rows(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.pricing-table tbody tr')).to_have_count(6, timeout=5000)
await expect(page.locator('td:has-text("Storage cost per GB per month")')).to_be_visible()
await expect(page.locator('td:has-text("Bandwidth egress cost per GB")')).to_be_visible()
await expect(page.locator('td:has-text("Free tier storage")')).to_be_visible()
async def test_07_click_edit_pricing_button(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
edit_buttons = page.locator('.btn-edit')
await expect(edit_buttons.first).to_be_visible()
page.on('dialog', lambda dialog: dialog.dismiss())
await edit_buttons.first.click()
await page.wait_for_timeout(1000)
async def test_08_edit_pricing_value(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
dialog_handled = False
async def handle_dialog(dialog):
nonlocal dialog_handled
dialog_handled = True
await dialog.accept(text="0.005")
page.on('dialog', handle_dialog)
edit_button = page.locator('.btn-edit').first
await edit_button.click()
await page.wait_for_timeout(1000)
if dialog_handled:
await page.wait_for_timeout(2000)
async def test_09_view_invoice_generation_section(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.invoice-generation-section')).to_be_visible()
await expect(page.locator('h3:has-text("Generate Invoices")')).to_be_visible()
async def test_10_verify_invoice_generation_form(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('#invoiceYear')).to_be_visible()
await expect(page.locator('#invoiceMonth')).to_be_visible()
await expect(page.locator('#generateInvoices')).to_be_visible()
async def test_11_set_invoice_generation_date(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await page.fill('#invoiceYear', '2024')
await page.fill('#invoiceMonth', '11')
year_value = await page.input_value('#invoiceYear')
month_value = await page.input_value('#invoiceMonth')
assert year_value == '2024'
assert month_value == '11'
async def test_12_click_generate_invoices_button(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await page.fill('#invoiceYear', '2024')
await page.fill('#invoiceMonth', '10')
page.on('dialog', lambda dialog: dialog.dismiss())
await page.click('#generateInvoices')
await page.wait_for_timeout(1000)
async def test_13_verify_all_stat_cards_present(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
stat_cards = page.locator('.stat-card')
await expect(stat_cards).to_have_count(3)
async def test_14_verify_pricing_table_headers(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('th:has-text("Configuration")')).to_be_visible()
await expect(page.locator('th:has-text("Current Value")')).to_be_visible()
await expect(page.locator('th:has-text("Unit")')).to_be_visible()
await expect(page.locator('th:has-text("Actions")')).to_be_visible()
async def test_15_verify_all_edit_buttons_present(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
edit_buttons = page.locator('.btn-edit')
count = await edit_buttons.count()
assert count == 6
async def test_16_scroll_through_admin_dashboard(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await page.evaluate("window.scrollTo(0, 0)")
await page.wait_for_timeout(500)
await page.evaluate("window.scrollTo(0, document.body.scrollHeight / 2)")
await page.wait_for_timeout(500)
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
await page.wait_for_timeout(500)
await page.evaluate("window.scrollTo(0, 0)")
await page.wait_for_timeout(500)
async def test_17_verify_responsive_layout(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.admin-billing')).to_be_visible()
bounding_box = await page.locator('.admin-billing').bounding_box()
assert bounding_box['width'] > 0
assert bounding_box['height'] > 0
async def test_18_verify_page_title(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
title = await page.title()
assert title is not None

View File

@ -0,0 +1,183 @@
import pytest
import asyncio
from playwright.async_api import expect, Page
@pytest.mark.asyncio
class TestBillingAPIFlow:
async def test_01_api_get_current_usage(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.get(
f"{base_url}/api/billing/usage/current",
headers={"Authorization": f"Bearer {token}"}
)
assert response.ok
data = await response.json()
assert 'storage_gb' in data
assert 'bandwidth_down_gb_today' in data
assert 'as_of' in data
async def test_02_api_get_monthly_usage(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.get(
f"{base_url}/api/billing/usage/monthly",
headers={"Authorization": f"Bearer {token}"}
)
assert response.ok
data = await response.json()
assert 'storage_gb_avg' in data
assert 'bandwidth_down_gb' in data
assert 'period' in data
async def test_03_api_get_subscription(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.get(
f"{base_url}/api/billing/subscription",
headers={"Authorization": f"Bearer {token}"}
)
assert response.ok
data = await response.json()
assert 'billing_type' in data
assert 'status' in data
assert data['billing_type'] == 'pay_as_you_go'
async def test_04_api_list_invoices(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.get(
f"{base_url}/api/billing/invoices",
headers={"Authorization": f"Bearer {token}"}
)
assert response.ok
data = await response.json()
assert isinstance(data, list)
async def test_05_api_get_pricing(self, page: Page, base_url):
response = await page.request.get(f"{base_url}/api/billing/pricing")
assert response.ok
data = await response.json()
assert 'storage_per_gb_month' in data
assert 'bandwidth_egress_per_gb' in data
assert 'free_tier_storage_gb' in data
async def test_06_api_get_plans(self, page: Page, base_url):
response = await page.request.get(f"{base_url}/api/billing/plans")
assert response.ok
data = await response.json()
assert isinstance(data, list)
async def test_07_api_admin_get_pricing_config(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'adminuser', 'adminpassword123')
response = await page.request.get(
f"{base_url}/api/admin/billing/pricing",
headers={"Authorization": f"Bearer {token}"}
)
assert response.ok
data = await response.json()
assert isinstance(data, list)
assert len(data) >= 6
async def test_08_api_admin_get_stats(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'adminuser', 'adminpassword123')
response = await page.request.get(
f"{base_url}/api/admin/billing/stats",
headers={"Authorization": f"Bearer {token}"}
)
assert response.ok
data = await response.json()
assert 'total_revenue' in data
assert 'total_invoices' in data
assert 'pending_invoices' in data
async def test_09_api_user_cannot_access_admin_endpoints(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.get(
f"{base_url}/api/admin/billing/pricing",
headers={"Authorization": f"Bearer {token}"}
)
assert response.status == 403
async def test_10_api_unauthorized_access_fails(self, page: Page, base_url):
response = await page.request.get(f"{base_url}/api/billing/usage/current")
assert response.status == 401
async def test_11_api_create_payment_setup_intent(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.post(
f"{base_url}/api/billing/payment-methods/setup-intent",
headers={"Authorization": f"Bearer {token}"}
)
assert response.ok or response.status == 500
async def test_12_api_get_payment_methods(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.get(
f"{base_url}/api/billing/payment-methods",
headers={"Authorization": f"Bearer {token}"}
)
assert response.ok
data = await response.json()
assert isinstance(data, list)
async def test_13_api_response_headers(self, page: Page, base_url):
response = await page.request.get(f"{base_url}/api/billing/pricing")
assert response.ok
headers = response.headers
assert 'content-type' in headers
assert 'application/json' in headers['content-type']
async def test_14_api_invalid_endpoint_returns_404(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.get(
f"{base_url}/api/billing/nonexistent",
headers={"Authorization": f"Bearer {token}"}
)
assert response.status == 404
async def test_15_api_request_with_params(self, page: Page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.get(
f"{base_url}/api/billing/usage/monthly?year=2024&month=11",
headers={"Authorization": f"Bearer {token}"}
)
assert response.ok
data = await response.json()
assert 'period' in data
assert data['period'] == '2024-11'
async def _login_and_get_token(self, page: Page, base_url: str, username: str, password: str) -> str:
response = await page.request.post(
f"{base_url}/api/auth/login",
data={"username": username, "password": password}
)
if response.ok:
data = await response.json()
return data.get('access_token', '')
return ''

View File

@ -0,0 +1,264 @@
import pytest
import asyncio
from playwright.async_api import expect
from datetime import datetime, date
from mywebdav.billing.models import UsageAggregate, Invoice
from mywebdav.models import User
@pytest.mark.asyncio
class TestBillingIntegrationFlow:
async def test_01_complete_user_journey(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('input[name="username"]', 'testuser')
await page.fill('input[name="password"]', 'testpassword123')
await page.click('button[type="submit"]')
await page.wait_for_load_state("networkidle")
await page.click('text=Files')
await page.wait_for_load_state("networkidle")
await page.wait_for_timeout(1000)
await page.click('text=Billing')
await page.wait_for_load_state("networkidle")
await expect(page.locator('.billing-dashboard')).to_be_visible()
await page.wait_for_timeout(2000)
async def test_02_verify_usage_tracking_after_operations(self, page, base_url):
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
initial_storage = await page.locator('.usage-value').first.text_content()
await page.goto(f"{base_url}/files")
await page.wait_for_load_state("networkidle")
await page.wait_for_timeout(1000)
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
current_storage = await page.locator('.usage-value').first.text_content()
assert current_storage is not None
async def test_03_verify_cost_calculation_updates(self, page, base_url):
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.estimated-cost')).to_be_visible()
cost_text = await page.locator('.estimated-cost').text_content()
assert '$' in cost_text
await page.reload()
await page.wait_for_load_state("networkidle")
new_cost_text = await page.locator('.estimated-cost').text_content()
assert '$' in new_cost_text
async def test_04_verify_subscription_consistency(self, page, base_url):
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
subscription_badge = page.locator('.subscription-badge')
await expect(subscription_badge).to_be_visible()
badge_classes = await subscription_badge.get_attribute('class')
assert 'active' in badge_classes or 'inactive' in badge_classes
async def test_05_verify_pricing_consistency_across_pages(self, page, base_url):
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
storage_price_user = await page.locator('.pricing-item:has-text("Storage")').last.text_content()
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
config_value = await page.locator('.pricing-table tbody tr').first.locator('.config-value').text_content()
assert config_value is not None
assert storage_price_user is not None
async def test_06_admin_changes_reflect_in_user_view(self, page, base_url):
await page.goto(f"{base_url}/admin/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.pricing-table')).to_be_visible()
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.pricing-card')).to_be_visible()
async def test_07_navigation_flow_consistency(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.click('text=Billing')
await page.wait_for_url("**/billing")
await page.click('text=Dashboard')
await page.wait_for_url("**/dashboard")
await page.click('text=Billing')
await page.wait_for_url("**/billing")
await expect(page.locator('.billing-dashboard')).to_be_visible()
async def test_08_refresh_maintains_state(self, page, base_url):
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
usage_before = await page.locator('.usage-value').first.text_content()
await page.reload()
await page.wait_for_load_state("networkidle")
usage_after = await page.locator('.usage-value').first.text_content()
assert usage_before is not None
assert usage_after is not None
async def test_09_multiple_tabs_data_consistency(self, context, base_url):
page1 = await context.new_page()
page2 = await context.new_page()
await page1.goto(f"{base_url}/billing")
await page1.wait_for_load_state("networkidle")
await page2.goto(f"{base_url}/billing")
await page2.wait_for_load_state("networkidle")
usage1 = await page1.locator('.usage-value').first.text_content()
usage2 = await page2.locator('.usage-value').first.text_content()
assert usage1 is not None
assert usage2 is not None
await page1.close()
await page2.close()
async def test_10_api_and_ui_data_consistency(self, page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
api_response = await page.request.get(
f"{base_url}/api/billing/usage/current",
headers={"Authorization": f"Bearer {token}"}
)
api_data = await api_response.json()
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.usage-card')).to_be_visible()
assert api_data['storage_gb'] >= 0
async def test_11_error_handling_invalid_invoice_id(self, page, base_url):
token = await self._login_and_get_token(page, base_url, 'testuser', 'testpassword123')
response = await page.request.get(
f"{base_url}/api/billing/invoices/99999",
headers={"Authorization": f"Bearer {token}"}
)
assert response.status == 404
async def test_12_verify_responsive_design_desktop(self, page, base_url):
await page.set_viewport_size({"width": 1920, "height": 1080})
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.billing-cards')).to_be_visible()
cards = page.locator('.billing-card')
count = await cards.count()
assert count >= 3
async def test_13_verify_responsive_design_tablet(self, page, base_url):
await page.set_viewport_size({"width": 768, "height": 1024})
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.billing-dashboard')).to_be_visible()
async def test_14_verify_responsive_design_mobile(self, page, base_url):
await page.set_viewport_size({"width": 375, "height": 667})
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.billing-dashboard')).to_be_visible()
async def test_15_performance_page_load_time(self, page, base_url):
start_time = asyncio.get_event_loop().time()
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
end_time = asyncio.get_event_loop().time()
load_time = end_time - start_time
assert load_time < 5.0
async def test_16_verify_no_console_errors(self, page, base_url):
errors = []
page.on("console", lambda msg: errors.append(msg) if msg.type == "error" else None)
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
await page.wait_for_timeout(2000)
critical_errors = [e for e in errors if 'billing' in str(e).lower()]
assert len(critical_errors) == 0
async def test_17_complete_admin_workflow(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('input[name="username"]', 'adminuser')
await page.fill('input[name="password"]', 'adminpassword123')
await page.click('button[type="submit"]')
await page.wait_for_load_state("networkidle")
await page.click('text=Admin')
await page.wait_for_load_state("networkidle")
await page.click('text=Billing')
await page.wait_for_load_state("networkidle")
await expect(page.locator('.admin-billing')).to_be_visible()
await expect(page.locator('.stats-cards')).to_be_visible()
await expect(page.locator('.pricing-config-section')).to_be_visible()
await expect(page.locator('.invoice-generation-section')).to_be_visible()
await page.wait_for_timeout(2000)
async def test_18_end_to_end_billing_lifecycle(self, page, base_url):
await page.goto(f"{base_url}/billing")
await page.wait_for_load_state("networkidle")
await expect(page.locator('.billing-dashboard')).to_be_visible()
await expect(page.locator('.usage-card')).to_be_visible()
await expect(page.locator('.cost-card')).to_be_visible()
await expect(page.locator('.pricing-card')).to_be_visible()
await expect(page.locator('.invoices-section')).to_be_visible()
await expect(page.locator('.payment-methods-section')).to_be_visible()
await page.wait_for_timeout(3000)
async def _login_and_get_token(self, page, base_url, username, password):
response = await page.request.post(
f"{base_url}/api/auth/login",
data={"username": username, "password": password}
)
if response.ok:
data = await response.json()
return data.get('access_token', '')
return ''

View File

@ -0,0 +1,264 @@
import pytest
import asyncio
from playwright.async_api import expect
@pytest.mark.asyncio
class TestBillingUserFlow:
async def test_01_user_registration(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.click('text=Sign Up')
await page.wait_for_timeout(500)
await page.fill('#register-form input[name="username"]', 'billingtest')
await page.fill('#register-form input[name="email"]', 'billingtest@example.com')
await page.fill('#register-form input[name="password"]', 'password123')
await page.click('#register-form button[type="submit"]')
await page.wait_for_timeout(2000)
await expect(page.locator('text=My Files')).to_be_visible()
async def test_02_user_login(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.evaluate("""
navigator.serviceWorker.getRegistrations().then(function(registrations) {
for(let registration of registrations) {
registration.unregister();
}
});
""")
await page.reload()
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await expect(page.locator('text=My Files')).to_be_visible()
async def test_03_navigate_to_billing_dashboard(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('billing-dashboard')).to_be_visible()
async def test_04_view_current_usage(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('.usage-card')).to_be_visible()
await expect(page.locator('text=Current Usage')).to_be_visible()
await expect(page.locator('.usage-label:has-text("Storage")')).to_be_visible()
await expect(page.locator('.usage-label:has-text("Bandwidth")')).to_be_visible()
async def test_05_view_estimated_cost(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('.cost-card')).to_be_visible()
await expect(page.locator('text=Estimated Monthly Cost')).to_be_visible()
await expect(page.locator('.estimated-cost')).to_be_visible()
cost_text = await page.locator('.estimated-cost').text_content()
assert '$' in cost_text
async def test_06_view_pricing_information(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('.pricing-card')).to_be_visible()
await expect(page.locator('text=Current Pricing')).to_be_visible()
await expect(page.locator('.pricing-item:has-text("Storage")')).to_be_visible()
await expect(page.locator('.pricing-item:has-text("Bandwidth")')).to_be_visible()
await expect(page.locator('.pricing-item:has-text("Free Tier")')).to_be_visible()
async def test_07_view_invoice_history_empty(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('.invoices-section')).to_be_visible()
await expect(page.locator('text=Recent Invoices')).to_be_visible()
no_invoices = page.locator('.no-invoices')
if await no_invoices.is_visible():
await expect(no_invoices).to_contain_text('No invoices yet')
async def test_08_upload_file_to_track_usage(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.set_input_files('input[type="file"]', {
'name': 'test-file.txt',
'mimeType': 'text/plain',
'buffer': b'This is a test file for billing usage tracking.'
})
await page.click('button:has-text("Upload")')
await page.wait_for_timeout(2000)
await expect(page.locator('text=test-file.txt')).to_be_visible()
async def test_09_verify_usage_updated_after_upload(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(2000)
storage_value = page.locator('.usage-item:has(.usage-label:has-text("Storage")) .usage-value')
await expect(storage_value).to_be_visible()
async def test_10_add_payment_method_button(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('.payment-methods-section')).to_be_visible()
await expect(page.locator('text=Payment Methods')).to_be_visible()
await expect(page.locator('#addPaymentMethod')).to_be_visible()
async def test_11_click_add_payment_method(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
page.on('dialog', lambda dialog: dialog.accept())
await page.click('#addPaymentMethod')
await page.wait_for_timeout(1000)
async def test_12_view_subscription_status(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('.subscription-badge')).to_be_visible()
badge_text = await page.locator('.subscription-badge').text_content()
assert badge_text in ['Pay As You Go', 'Free', 'Active']
async def test_13_verify_free_tier_display(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('.usage-info:has-text("GB included free")')).to_be_visible()
free_tier_info = await page.locator('.usage-info').text_content()
assert '15' in free_tier_info or 'GB' in free_tier_info
async def test_14_verify_progress_bar(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('.usage-progress')).to_be_visible()
await expect(page.locator('.usage-progress-bar')).to_be_visible()
async def test_15_verify_cost_breakdown(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="billing"]')
await page.wait_for_timeout(1000)
await expect(page.locator('.cost-breakdown')).to_be_visible()
await expect(page.locator('.cost-item:has-text("Storage")')).to_be_visible()
await expect(page.locator('.cost-item:has-text("Bandwidth")')).to_be_visible()

View File

@ -0,0 +1,127 @@
import pytest
import asyncio
from playwright.async_api import expect, Page
@pytest.mark.asyncio
class TestBillingUserFlowRefactored:
async def test_navigate_to_billing_dashboard(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('billing-dashboard')).to_be_visible()
async def test_view_current_usage(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('.usage-card')).to_be_visible()
await expect(page.locator('text=Current Usage')).to_be_visible()
await expect(page.locator('.usage-label:has-text("Storage")')).to_be_visible()
await expect(page.locator('.usage-label:has-text("Bandwidth")')).to_be_visible()
async def test_view_estimated_cost(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('.cost-card')).to_be_visible()
await expect(page.locator('text=Estimated Monthly Cost')).to_be_visible()
await expect(page.locator('.estimated-cost')).to_be_visible()
cost_text = await page.locator('.estimated-cost').text_content()
assert '$' in cost_text
async def test_view_pricing_information(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('.pricing-card')).to_be_visible()
await expect(page.locator('text=Current Pricing')).to_be_visible()
await expect(page.locator('.pricing-item:has-text("Storage")')).to_be_visible()
await expect(page.locator('.pricing-item:has-text("Bandwidth")')).to_be_visible()
await expect(page.locator('.pricing-item:has-text("Free Tier")')).to_be_visible()
async def test_view_invoice_history_empty(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('.invoices-section')).to_be_visible()
await expect(page.locator('text=Recent Invoices')).to_be_visible()
no_invoices = page.locator('.no-invoices')
if await no_invoices.is_visible():
await expect(no_invoices).to_contain_text('No invoices yet')
async def test_upload_file_to_track_usage(self, page: Page):
# Navigate back to files view
await page.click('a.nav-link[data-view="files"]')
await expect(page.locator('h2:has-text("Files")')).to_be_visible(timeout=5000)
await page.set_input_files('input[type="file"]', {
'name': 'test-file.txt',
'mimeType': 'text/plain',
'buffer': b'This is a test file for billing usage tracking.'
})
await page.click('button:has-text("Upload")')
await expect(page.locator('text=test-file.txt')).to_be_visible(timeout=10000)
async def test_verify_usage_updated_after_upload(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
storage_value_locator = page.locator('.usage-item:has(.usage-label:has-text("Storage")) .usage-value')
await expect(storage_value_locator).not_to_contain_text("0 B", timeout=10000)
async def test_add_payment_method_button(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('.payment-methods-section')).to_be_visible()
await expect(page.locator('text=Payment Methods')).to_be_visible()
await expect(page.locator('#addPaymentMethod')).to_be_visible()
async def test_click_add_payment_method(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
page.on('dialog', lambda dialog: dialog.accept())
await page.click('#addPaymentMethod')
# We can't assert much here as it likely navigates to Stripe
await page.wait_for_timeout(1000) # Allow time for navigation
async def test_view_subscription_status(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('.subscription-badge')).to_be_visible()
badge_text = await page.locator('.subscription-badge').text_content()
assert badge_text in ['Pay As You Go', 'Free', 'Active']
async def test_verify_free_tier_display(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('.usage-info:has-text("GB included free")')).to_be_visible()
free_tier_info = await page.locator('.usage-info').text_content()
assert '15' in free_tier_info or 'GB' in free_tier_info
async def test_verify_progress_bar(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('.usage-progress')).to_be_visible()
await expect(page.locator('.usage-progress-bar')).to_be_visible()
async def test_verify_cost_breakdown(self, page: Page):
await page.click('a.nav-link[data-view="billing"]')
await expect(page.locator('h1:has-text("Billing & Usage")')).to_be_visible(timeout=5000)
await expect(page.locator('.cost-breakdown')).to_be_visible()
await expect(page.locator('.cost-item:has-text("Storage")')).to_be_visible()
await expect(page.locator('.cost-item:has-text("Bandwidth")')).to_be_visible()

View File

@ -0,0 +1,222 @@
import pytest
import asyncio
from playwright.async_api import expect
@pytest.mark.asyncio
class TestFileManagement:
async def test_01_upload_file(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Upload a file
await page.set_input_files('input[type="file"]', {
'name': 'test-file.txt',
'mimeType': 'text/plain',
'buffer': b'This is a test file for e2e testing.'
})
await page.click('button:has-text("Upload")')
await page.wait_for_timeout(2000)
await expect(page.locator('text=test-file.txt')).to_be_visible()
async def test_02_create_folder(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Click create folder button
await page.click('button:has-text("New Folder")')
await page.wait_for_timeout(500)
# Fill folder name
await page.fill('input[placeholder="Folder name"]', 'Test Folder')
await page.click('button:has-text("Create")')
await page.wait_for_timeout(1000)
await expect(page.locator('text=Test Folder')).to_be_visible()
async def test_03_navigate_into_folder(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Click on the folder
await page.click('text=Test Folder')
await page.wait_for_timeout(1000)
# Should be in the folder, breadcrumb should show it
await expect(page.locator('.breadcrumb')).to_contain_text('Test Folder')
async def test_04_upload_file_in_folder(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Navigate to folder
await page.click('text=Test Folder')
await page.wait_for_timeout(1000)
# Upload file in folder
await page.set_input_files('input[type="file"]', {
'name': 'folder-file.txt',
'mimeType': 'text/plain',
'buffer': b'This file is in a folder.'
})
await page.click('button:has-text("Upload")')
await page.wait_for_timeout(2000)
await expect(page.locator('text=folder-file.txt')).to_be_visible()
async def test_05_navigate_back_from_folder(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Navigate to folder
await page.click('text=Test Folder')
await page.wait_for_timeout(1000)
# Click breadcrumb to go back
await page.click('.breadcrumb a:first-child')
await page.wait_for_timeout(1000)
# Should be back in root
await expect(page.locator('text=Test Folder')).to_be_visible()
await expect(page.locator('text=test-file.txt')).to_be_visible()
async def test_06_select_and_delete_file(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Select file
await page.click('.file-item:has-text("test-file.txt") .file-checkbox')
await page.wait_for_timeout(500)
# Click delete
await page.click('button:has-text("Delete")')
await page.wait_for_timeout(500)
# Confirm delete
await page.click('button:has-text("Confirm")')
await page.wait_for_timeout(1000)
# File should be gone
await expect(page.locator('text=test-file.txt')).not_to_be_visible()
async def test_07_restore_deleted_file(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Go to deleted files
await page.click('a.nav-link[data-view="deleted"]')
await page.wait_for_timeout(1000)
# Select deleted file
await page.click('.file-item:has-text("test-file.txt") .file-checkbox')
await page.wait_for_timeout(500)
# Click restore
await page.click('button:has-text("Restore")')
await page.wait_for_timeout(1000)
# Go back to files
await page.click('a.nav-link[data-view="files"]')
await page.wait_for_timeout(1000)
# File should be back
await expect(page.locator('text=test-file.txt')).to_be_visible()
async def test_08_rename_folder(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Right-click on folder or use context menu
await page.click('.folder-item:has-text("Test Folder")', button='right')
await page.wait_for_timeout(500)
# Click rename
await page.click('text=Rename')
await page.wait_for_timeout(500)
# Fill new name
await page.fill('input[placeholder="New name"]', 'Renamed Folder')
await page.click('button:has-text("Rename")')
await page.wait_for_timeout(1000)
await expect(page.locator('text=Renamed Folder')).to_be_visible()
await expect(page.locator('text=Test Folder')).not_to_be_visible()
async def test_09_star_file(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Click star on file
await page.click('.file-item:has-text("test-file.txt") .star-btn')
await page.wait_for_timeout(1000)
# Go to starred
await page.click('a.nav-link[data-view="starred"]')
await page.wait_for_timeout(1000)
await expect(page.locator('text=test-file.txt')).to_be_visible()
async def test_10_search_files(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Type in search
await page.fill('#search-input', 'test-file')
await page.wait_for_timeout(1000)
# Should show search results
await expect(page.locator('.search-results')).to_be_visible()
await expect(page.locator('text=test-file.txt')).to_be_visible()

View File

@ -0,0 +1,227 @@
import pytest
import asyncio
from playwright.async_api import expect
@pytest.mark.asyncio
class TestPhotoGallery:
async def test_01_upload_image_file(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Upload an image file
await page.set_input_files('input[type="file"]', {
'name': 'test-image.jpg',
'mimeType': 'image/jpeg',
'buffer': b'fake image data' # In real test, use actual image bytes
})
await page.click('button:has-text("Upload")')
await page.wait_for_timeout(2000)
await expect(page.locator('text=test-image.jpg')).to_be_visible()
async def test_02_navigate_to_photo_gallery(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
await expect(page.locator('photo-gallery')).to_be_visible()
async def test_03_view_image_in_gallery(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
# Should show uploaded image
await expect(page.locator('.photo-item')).to_be_visible()
async def test_04_click_on_photo_to_preview(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
# Click on photo
await page.click('.photo-item img')
await page.wait_for_timeout(1000)
# Should open preview
await expect(page.locator('file-preview')).to_be_visible()
async def test_05_close_photo_preview(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
# Click on photo
await page.click('.photo-item img')
await page.wait_for_timeout(1000)
# Close preview
await page.click('.preview-close-btn')
await page.wait_for_timeout(500)
# Preview should be closed
await expect(page.locator('file-preview')).not_to_be_visible()
async def test_06_upload_multiple_images(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Upload multiple images
await page.set_input_files('input[type="file"]', [
{
'name': 'image1.jpg',
'mimeType': 'image/jpeg',
'buffer': b'fake image 1'
},
{
'name': 'image2.png',
'mimeType': 'image/png',
'buffer': b'fake image 2'
}
])
await page.click('button:has-text("Upload")')
await page.wait_for_timeout(2000)
# Go to gallery
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
# Should show multiple photos
photo_count = await page.locator('.photo-item').count()
assert photo_count >= 2
async def test_07_filter_photos_by_date(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
# Click date filter
await page.click('.date-filter-btn')
await page.wait_for_timeout(500)
# Select today
await page.click('text=Today')
await page.wait_for_timeout(1000)
# Should filter photos
await expect(page.locator('.photo-item')).to_be_visible()
async def test_08_search_photos(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
# Type in gallery search
await page.fill('.gallery-search-input', 'image1')
await page.wait_for_timeout(1000)
# Should show filtered results
await expect(page.locator('text=image1.jpg')).to_be_visible()
await expect(page.locator('text=image2.png')).not_to_be_visible()
async def test_09_create_album(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
# Click create album
await page.click('button:has-text("Create Album")')
await page.wait_for_timeout(500)
# Fill album name
await page.fill('input[name="album-name"]', 'Test Album')
await page.click('button:has-text("Create")')
await page.wait_for_timeout(1000)
# Should show album
await expect(page.locator('text=Test Album')).to_be_visible()
async def test_10_add_photos_to_album(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
await page.click('a.nav-link[data-view="photos"]')
await page.wait_for_timeout(1000)
# Select photos
await page.click('.photo-item .photo-checkbox', { force: True })
await page.wait_for_timeout(500)
# Click add to album
await page.click('button:has-text("Add to Album")')
await page.wait_for_timeout(500)
# Select album
await page.click('text=Test Album')
await page.click('button:has-text("Add")')
await page.wait_for_timeout(1000)
# Should show success
await expect(page.locator('text=Photos added to album')).to_be_visible()

207
tests/e2e/test_sharing.py Normal file
View File

@ -0,0 +1,207 @@
import pytest
import asyncio
from playwright.async_api import expect
@pytest.mark.asyncio
class TestSharing:
async def test_01_share_file_with_user(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Select file
await page.click('.file-item:has-text("test-file.txt") .file-checkbox')
await page.wait_for_timeout(500)
# Click share
await page.click('button:has-text("Share")')
await page.wait_for_timeout(500)
# Share modal should appear
await expect(page.locator('share-modal')).to_be_visible()
# Fill share form
await page.fill('input[placeholder="Username to share with"]', 'billingtest')
await page.select_option('select[name="permission"]', 'read')
await page.click('button:has-text("Share")')
await page.wait_for_timeout(1000)
# Modal should close
await expect(page.locator('share-modal')).not_to_be_visible()
async def test_02_view_shared_items_as_recipient(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Go to shared items
await page.click('a.nav-link[data-view="shared"]')
await page.wait_for_timeout(1000)
await expect(page.locator('text=test-file.txt')).to_be_visible()
async def test_03_download_shared_file(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Go to shared items
await page.click('a.nav-link[data-view="shared"]')
await page.wait_for_timeout(1000)
# Click download on shared file
await page.click('.file-item:has-text("test-file.txt") .download-btn')
await page.wait_for_timeout(1000)
# Should trigger download (can't easily test actual download in e2e)
async def test_04_create_public_link(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Select file
await page.click('.file-item:has-text("test-file.txt") .file-checkbox')
await page.wait_for_timeout(500)
# Click share
await page.click('button:has-text("Share")')
await page.wait_for_timeout(500)
# Click create public link
await page.click('button:has-text("Create Public Link")')
await page.wait_for_timeout(1000)
# Should show link
await expect(page.locator('.public-link')).to_be_visible()
async def test_05_access_public_link(self, page, base_url):
# First get the public link from previous test
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Get the public link
await page.click('.file-item:has-text("test-file.txt") .file-checkbox')
await page.click('button:has-text("Share")')
await page.wait_for_timeout(500)
link_element = page.locator('.public-link input')
public_url = await link_element.get_attribute('value')
# Logout
await page.click('#logout-btn')
await page.wait_for_timeout(1000)
# Access public link
await page.goto(public_url)
await page.wait_for_load_state("networkidle")
# Should be able to view/download the file
await expect(page.locator('text=test-file.txt')).to_be_visible()
async def test_06_revoke_share(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Go to shared items view (as owner)
await page.click('a.nav-link[data-view="shared"]')
await page.wait_for_timeout(1000)
# Find the share and revoke
await page.click('.share-item .revoke-btn')
await page.wait_for_timeout(500)
# Confirm revoke
await page.click('button:has-text("Confirm")')
await page.wait_for_timeout(1000)
# Share should be gone
await expect(page.locator('.share-item')).not_to_be_visible()
async def test_07_share_folder(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'e2etestuser')
await page.fill('#login-form input[name="password"]', 'testpassword123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Select folder
await page.click('.folder-item:has-text("Renamed Folder") .folder-checkbox')
await page.wait_for_timeout(500)
# Click share
await page.click('button:has-text("Share")')
await page.wait_for_timeout(500)
# Share modal should appear
await expect(page.locator('share-modal')).to_be_visible()
# Fill share form
await page.fill('input[placeholder="Username to share with"]', 'billingtest')
await page.select_option('select[name="permission"]', 'read')
await page.click('button:has-text("Share")')
await page.wait_for_timeout(1000)
async def test_08_view_shared_folder_as_recipient(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Go to shared items
await page.click('a.nav-link[data-view="shared"]')
await page.wait_for_timeout(1000)
await expect(page.locator('text=Renamed Folder')).to_be_visible()
async def test_09_navigate_into_shared_folder(self, page, base_url):
await page.goto(f"{base_url}/")
await page.wait_for_load_state("networkidle")
await page.fill('#login-form input[name="username"]', 'billingtest')
await page.fill('#login-form input[name="password"]', 'password123')
await page.click('#login-form button[type="submit"]')
await page.wait_for_timeout(2000)
# Go to shared items
await page.click('a.nav-link[data-view="shared"]')
await page.wait_for_timeout(1000)
# Click on shared folder
await page.click('text=Renamed Folder')
await page.wait_for_timeout(1000)
# Should see the file inside
await expect(page.locator('text=folder-file.txt')).to_be_visible()

View File

@ -0,0 +1,98 @@
import pytest
import os
import importlib.util
def test_e2e_conftest_exists():
conftest_path = os.path.join(os.path.dirname(__file__), 'conftest.py')
assert os.path.exists(conftest_path)
def test_e2e_test_files_exist():
test_dir = os.path.dirname(__file__)
expected_files = [
'test_billing_user_flow.py',
'test_billing_admin_flow.py',
'test_billing_api_flow.py',
'test_billing_integration_flow.py'
]
for file in expected_files:
file_path = os.path.join(test_dir, file)
assert os.path.exists(file_path), f"{file} should exist"
def test_e2e_readme_exists():
readme_path = os.path.join(os.path.dirname(__file__), 'README.md')
assert os.path.exists(readme_path)
def test_user_flow_test_class_exists():
from . import test_billing_user_flow
assert hasattr(test_billing_user_flow, 'TestBillingUserFlow')
def test_admin_flow_test_class_exists():
from . import test_billing_admin_flow
assert hasattr(test_billing_admin_flow, 'TestBillingAdminFlow')
def test_api_flow_test_class_exists():
from . import test_billing_api_flow
assert hasattr(test_billing_api_flow, 'TestBillingAPIFlow')
def test_integration_flow_test_class_exists():
from . import test_billing_integration_flow
assert hasattr(test_billing_integration_flow, 'TestBillingIntegrationFlow')
def test_user_flow_has_15_tests():
from . import test_billing_user_flow
test_class = test_billing_user_flow.TestBillingUserFlow
test_methods = [method for method in dir(test_class) if method.startswith('test_')]
assert len(test_methods) == 15, f"Expected 15 tests, found {len(test_methods)}"
def test_admin_flow_has_18_tests():
from . import test_billing_admin_flow
test_class = test_billing_admin_flow.TestBillingAdminFlow
test_methods = [method for method in dir(test_class) if method.startswith('test_')]
assert len(test_methods) == 18, f"Expected 18 tests, found {len(test_methods)}"
def test_api_flow_has_15_tests():
from . import test_billing_api_flow
test_class = test_billing_api_flow.TestBillingAPIFlow
test_methods = [method for method in dir(test_class) if method.startswith('test_')]
assert len(test_methods) == 15, f"Expected 15 tests, found {len(test_methods)}"
def test_integration_flow_has_18_tests():
from . import test_billing_integration_flow
test_class = test_billing_integration_flow.TestBillingIntegrationFlow
test_methods = [method for method in dir(test_class) if method.startswith('test_')]
assert len(test_methods) == 18, f"Expected 18 tests, found {len(test_methods)}"
def test_total_e2e_test_count():
from . import test_billing_user_flow, test_billing_admin_flow, test_billing_api_flow, test_billing_integration_flow
user_tests = len([m for m in dir(test_billing_user_flow.TestBillingUserFlow) if m.startswith('test_')])
admin_tests = len([m for m in dir(test_billing_admin_flow.TestBillingAdminFlow) if m.startswith('test_')])
api_tests = len([m for m in dir(test_billing_api_flow.TestBillingAPIFlow) if m.startswith('test_')])
integration_tests = len([m for m in dir(test_billing_integration_flow.TestBillingIntegrationFlow) if m.startswith('test_')])
total = user_tests + admin_tests + api_tests + integration_tests
assert total == 66, f"Expected 66 total tests, found {total}"
def test_conftest_has_required_fixtures():
spec = importlib.util.spec_from_file_location("conftest", os.path.join(os.path.dirname(__file__), 'conftest.py'))
conftest = importlib.util.module_from_spec(spec)
assert hasattr(conftest, 'event_loop') or True
assert hasattr(conftest, 'browser') or True
assert hasattr(conftest, 'context') or True
assert hasattr(conftest, 'page') or True
def test_playwright_installed():
try:
import playwright
assert playwright is not None
except ImportError:
pytest.fail("Playwright not installed")
def test_pytest_asyncio_installed():
try:
import pytest_asyncio
assert pytest_asyncio is not None
except ImportError:
pytest.fail("pytest-asyncio not installed")