diff --git a/mywebdav/activity.py b/mywebdav/activity.py
index 4c497c8..8537e56 100644
--- a/mywebdav/activity.py
+++ b/mywebdav/activity.py
@@ -1,17 +1,18 @@
from typing import Optional
from .models import Activity, User
+
async def log_activity(
user: Optional[User],
action: str,
target_type: str,
target_id: int,
- ip_address: Optional[str] = None
+ ip_address: Optional[str] = None,
):
await Activity.create(
user=user,
action=action,
target_type=target_type,
target_id=target_id,
- ip_address=ip_address
+ ip_address=ip_address,
)
diff --git a/mywebdav/auth.py b/mywebdav/auth.py
index fe9bbc3..d6f43cc 100644
--- a/mywebdav/auth.py
+++ b/mywebdav/auth.py
@@ -9,20 +9,29 @@ import bcrypt
from .schemas import TokenData
from .settings import settings
from .models import User
-from .two_factor import verify_totp_code # Import verify_totp_code
+from .two_factor import verify_totp_code # Import verify_totp_code
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
def verify_password(plain_password, hashed_password):
- password_bytes = plain_password[:72].encode('utf-8')
- hashed_bytes = hashed_password.encode('utf-8') if isinstance(hashed_password, str) else hashed_password
+ password_bytes = plain_password[:72].encode("utf-8")
+ hashed_bytes = (
+ hashed_password.encode("utf-8")
+ if isinstance(hashed_password, str)
+ else hashed_password
+ )
return bcrypt.checkpw(password_bytes, hashed_bytes)
-def get_password_hash(password):
- password_bytes = password[:72].encode('utf-8')
- return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode('utf-8')
-async def authenticate_user(username: str, password: str, two_factor_code: Optional[str] = None):
+def get_password_hash(password):
+ password_bytes = password[:72].encode("utf-8")
+ return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode("utf-8")
+
+
+async def authenticate_user(
+ username: str, password: str, two_factor_code: Optional[str] = None
+):
user = await User.get_or_none(username=username)
if not user:
return None
@@ -33,19 +42,29 @@ async def authenticate_user(username: str, password: str, two_factor_code: Optio
if not two_factor_code:
return {"user": user, "2fa_required": True}
if not verify_totp_code(user.two_factor_secret, two_factor_code):
- return None # 2FA code is incorrect
+ return None # 2FA code is incorrect
return {"user": user, "2fa_required": False}
-def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, two_factor_verified: bool = False):
+
+def create_access_token(
+ data: dict,
+ expires_delta: Optional[timedelta] = None,
+ two_factor_verified: bool = False,
+):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
- expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
+ expire = datetime.now(timezone.utc) + timedelta(
+ minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
+ )
to_encode.update({"exp": expire, "2fa_verified": two_factor_verified})
- encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
+ encoded_jwt = jwt.encode(
+ to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
+ )
return encoded_jwt
+
async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -53,31 +72,45 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
headers={"WWW-Authenticate": "Bearer"},
)
try:
- payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
+ payload = jwt.decode(
+ token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
+ )
username: str = payload.get("sub")
two_factor_verified: bool = payload.get("2fa_verified", False)
if username is None:
raise credentials_exception
- token_data = TokenData(username=username, two_factor_verified=two_factor_verified)
+ token_data = TokenData(
+ username=username, two_factor_verified=two_factor_verified
+ )
except JWTError:
raise credentials_exception
user = await User.get_or_none(username=token_data.username)
if user is None:
raise credentials_exception
- user.token_data = token_data # Attach token_data to user for easy access
+ user.token_data = token_data # Attach token_data to user for easy access
return user
+
async def get_current_active_user(current_user: User = Depends(get_current_user)):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
+
async def get_current_verified_user(current_user: User = Depends(get_current_user)):
if current_user.is_2fa_enabled and not current_user.token_data.two_factor_verified:
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="2FA required and not verified")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="2FA required and not verified",
+ )
return current_user
-async def get_current_admin_user(current_user: User = Depends(get_current_verified_user)):
+
+async def get_current_admin_user(
+ current_user: User = Depends(get_current_verified_user),
+):
if not current_user.is_superuser:
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
+ )
return current_user
diff --git a/mywebdav/billing/invoice_generator.py b/mywebdav/billing/invoice_generator.py
index 86a73c0..e23313c 100644
--- a/mywebdav/billing/invoice_generator.py
+++ b/mywebdav/billing/invoice_generator.py
@@ -2,14 +2,17 @@ from datetime import datetime, date, timedelta, timezone
from decimal import Decimal
from typing import Optional
from calendar import monthrange
-from .models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription
+from .models import Invoice, InvoiceLineItem, PricingConfig, UserSubscription
from .usage_tracker import UsageTracker
from .stripe_client import StripeClient
from ..models import User
+
class InvoiceGenerator:
@staticmethod
- async def generate_monthly_invoice(user: User, year: int, month: int) -> Optional[Invoice]:
+ async def generate_monthly_invoice(
+ user: User, year: int, month: int
+ ) -> Optional[Invoice]:
period_start = date(year, month, 1)
days_in_month = monthrange(year, month)[1]
period_end = date(year, month, days_in_month)
@@ -19,19 +22,24 @@ class InvoiceGenerator:
pricing = await PricingConfig.all()
pricing_dict = {p.config_key: p.config_value for p in pricing}
- storage_price_per_gb = pricing_dict.get('storage_per_gb_month', Decimal('0.0045'))
- bandwidth_price_per_gb = pricing_dict.get('bandwidth_egress_per_gb', Decimal('0.009'))
- free_storage_gb = pricing_dict.get('free_tier_storage_gb', Decimal('15'))
- free_bandwidth_gb = pricing_dict.get('free_tier_bandwidth_gb', Decimal('15'))
- tax_rate = pricing_dict.get('tax_rate_default', Decimal('0'))
+ storage_price_per_gb = pricing_dict.get(
+ "storage_per_gb_month", Decimal("0.0045")
+ )
+ bandwidth_price_per_gb = pricing_dict.get(
+ "bandwidth_egress_per_gb", Decimal("0.009")
+ )
+ free_storage_gb = pricing_dict.get("free_tier_storage_gb", Decimal("15"))
+ free_bandwidth_gb = pricing_dict.get("free_tier_bandwidth_gb", Decimal("15"))
+ tax_rate = pricing_dict.get("tax_rate_default", Decimal("0"))
- storage_gb = Decimal(str(usage['storage_gb_avg']))
- bandwidth_gb = Decimal(str(usage['bandwidth_down_gb']))
+ storage_gb = Decimal(str(usage["storage_gb_avg"]))
+ bandwidth_gb = Decimal(str(usage["bandwidth_down_gb"]))
- billable_storage = max(Decimal('0'), storage_gb - free_storage_gb)
- billable_bandwidth = max(Decimal('0'), bandwidth_gb - free_bandwidth_gb)
+ billable_storage = max(Decimal("0"), storage_gb - free_storage_gb)
+ billable_bandwidth = max(Decimal("0"), bandwidth_gb - free_bandwidth_gb)
import math
+
billable_storage_rounded = Decimal(math.ceil(float(billable_storage)))
billable_bandwidth_rounded = Decimal(math.ceil(float(billable_bandwidth)))
@@ -65,9 +73,9 @@ class InvoiceGenerator:
"usage": usage,
"pricing": {
"storage_per_gb": float(storage_price_per_gb),
- "bandwidth_per_gb": float(bandwidth_price_per_gb)
- }
- }
+ "bandwidth_per_gb": float(bandwidth_price_per_gb),
+ },
+ },
)
if billable_storage_rounded > 0:
@@ -78,7 +86,10 @@ class InvoiceGenerator:
unit_price=storage_price_per_gb,
amount=storage_cost,
item_type="storage",
- metadata={"avg_gb": float(storage_gb), "free_gb": float(free_storage_gb)}
+ metadata={
+ "avg_gb": float(storage_gb),
+ "free_gb": float(free_storage_gb),
+ },
)
if billable_bandwidth_rounded > 0:
@@ -89,7 +100,10 @@ class InvoiceGenerator:
unit_price=bandwidth_price_per_gb,
amount=bandwidth_cost,
item_type="bandwidth",
- metadata={"total_gb": float(bandwidth_gb), "free_gb": float(free_bandwidth_gb)}
+ metadata={
+ "total_gb": float(bandwidth_gb),
+ "free_gb": float(free_bandwidth_gb),
+ },
)
if subscription and subscription.stripe_customer_id:
@@ -100,7 +114,7 @@ class InvoiceGenerator:
"amount": item.amount,
"currency": "usd",
"description": item.description,
- "metadata": item.metadata or {}
+ "metadata": item.metadata or {},
}
for item in line_items
]
@@ -109,7 +123,7 @@ class InvoiceGenerator:
customer_id=subscription.stripe_customer_id,
description=f"MyWebdav Usage Invoice for {period_start.strftime('%B %Y')}",
line_items=stripe_line_items,
- metadata={"mywebdav_invoice_id": str(invoice.id)}
+ metadata={"mywebdav_invoice_id": str(invoice.id)},
)
invoice.stripe_invoice_id = stripe_invoice.id
@@ -135,8 +149,11 @@ class InvoiceGenerator:
# Send invoice email
from ..mail import queue_email
+
line_items = await invoice.line_items.all()
- items_text = "\n".join([f"- {item.description}: ${item.amount}" for item in line_items])
+ items_text = "\n".join(
+ [f"- {item.description}: ${item.amount}" for item in line_items]
+ )
body = f"""Dear {invoice.user.username},
Your invoice {invoice.invoice_number} for the period {invoice.period_start} to {invoice.period_end} is now available.
@@ -174,7 +191,7 @@ The MyWebdav Team
to_email=invoice.user.email,
subject=f"Your MyWebdav Invoice {invoice.invoice_number}",
body=body,
- html=html
+ html=html,
)
return invoice
diff --git a/mywebdav/billing/models.py b/mywebdav/billing/models.py
index 40ab5a0..d298fb0 100644
--- a/mywebdav/billing/models.py
+++ b/mywebdav/billing/models.py
@@ -1,8 +1,8 @@
from tortoise import fields, models
-from decimal import Decimal
+
class SubscriptionPlan(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=100, unique=True)
display_name = fields.CharField(max_length=255)
description = fields.TextField(null=True)
@@ -18,10 +18,13 @@ class SubscriptionPlan(models.Model):
class Meta:
table = "subscription_plans"
+
class UserSubscription(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="subscription")
- plan = fields.ForeignKeyField("billing.SubscriptionPlan", related_name="subscriptions", null=True)
+ plan = fields.ForeignKeyField(
+ "billing.SubscriptionPlan", related_name="subscriptions", null=True
+ )
billing_type = fields.CharField(max_length=20, default="pay_as_you_go")
stripe_customer_id = fields.CharField(max_length=255, unique=True, null=True)
stripe_subscription_id = fields.CharField(max_length=255, unique=True, null=True)
@@ -35,14 +38,15 @@ class UserSubscription(models.Model):
class Meta:
table = "user_subscriptions"
+
class UsageRecord(models.Model):
- id = fields.BigIntField(pk=True)
+ id = fields.BigIntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="usage_records")
- record_type = fields.CharField(max_length=50, index=True)
+ record_type = fields.CharField(max_length=50, db_index=True)
amount_bytes = fields.BigIntField()
resource_type = fields.CharField(max_length=50, null=True)
resource_id = fields.IntField(null=True)
- timestamp = fields.DatetimeField(auto_now_add=True, index=True)
+ timestamp = fields.DatetimeField(auto_now_add=True, db_index=True)
idempotency_key = fields.CharField(max_length=255, unique=True, null=True)
metadata = fields.JSONField(null=True)
@@ -50,8 +54,9 @@ class UsageRecord(models.Model):
table = "usage_records"
indexes = [("user_id", "record_type", "timestamp")]
+
class UsageAggregate(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="usage_aggregates")
date = fields.DateField()
storage_bytes_avg = fields.BigIntField(default=0)
@@ -64,8 +69,9 @@ class UsageAggregate(models.Model):
table = "usage_aggregates"
unique_together = (("user", "date"),)
+
class Invoice(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="invoices")
invoice_number = fields.CharField(max_length=50, unique=True)
stripe_invoice_id = fields.CharField(max_length=255, unique=True, null=True)
@@ -75,10 +81,10 @@ class Invoice(models.Model):
tax = fields.DecimalField(max_digits=10, decimal_places=4, default=0)
total = fields.DecimalField(max_digits=10, decimal_places=4)
currency = fields.CharField(max_length=3, default="USD")
- status = fields.CharField(max_length=50, default="draft", index=True)
+ status = fields.CharField(max_length=50, default="draft", db_index=True)
due_date = fields.DateField(null=True)
paid_at = fields.DatetimeField(null=True)
- created_at = fields.DatetimeField(auto_now_add=True, index=True)
+ created_at = fields.DatetimeField(auto_now_add=True, db_index=True)
updated_at = fields.DatetimeField(auto_now=True)
metadata = fields.JSONField(null=True)
@@ -86,8 +92,9 @@ class Invoice(models.Model):
table = "invoices"
indexes = [("user_id", "status", "created_at")]
+
class InvoiceLineItem(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
invoice = fields.ForeignKeyField("billing.Invoice", related_name="line_items")
description = fields.TextField()
quantity = fields.DecimalField(max_digits=15, decimal_places=6)
@@ -100,20 +107,24 @@ class InvoiceLineItem(models.Model):
class Meta:
table = "invoice_line_items"
+
class PricingConfig(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
config_key = fields.CharField(max_length=100, unique=True)
config_value = fields.DecimalField(max_digits=10, decimal_places=6)
description = fields.TextField(null=True)
unit = fields.CharField(max_length=50, null=True)
- updated_by = fields.ForeignKeyField("models.User", related_name="pricing_updates", null=True)
+ updated_by = fields.ForeignKeyField(
+ "models.User", related_name="pricing_updates", null=True
+ )
updated_at = fields.DatetimeField(auto_now=True)
class Meta:
table = "pricing_config"
+
class PaymentMethod(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="payment_methods")
stripe_payment_method_id = fields.CharField(max_length=255)
type = fields.CharField(max_length=50)
@@ -128,9 +139,12 @@ class PaymentMethod(models.Model):
class Meta:
table = "payment_methods"
+
class BillingEvent(models.Model):
- id = fields.BigIntField(pk=True)
- user = fields.ForeignKeyField("models.User", related_name="billing_events", null=True)
+ id = fields.BigIntField(primary_key=True)
+ user = fields.ForeignKeyField(
+ "models.User", related_name="billing_events", null=True
+ )
event_type = fields.CharField(max_length=100)
stripe_event_id = fields.CharField(max_length=255, unique=True, null=True)
data = fields.JSONField(null=True)
diff --git a/mywebdav/billing/scheduler.py b/mywebdav/billing/scheduler.py
index 7e086e0..cab5f28 100644
--- a/mywebdav/billing/scheduler.py
+++ b/mywebdav/billing/scheduler.py
@@ -1,7 +1,6 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from datetime import datetime, date, timedelta
-import asyncio
from .usage_tracker import UsageTracker
from .invoice_generator import InvoiceGenerator
@@ -9,6 +8,7 @@ from ..models import User
scheduler = AsyncIOScheduler()
+
async def aggregate_daily_usage_for_all_users():
users = await User.filter(is_active=True).all()
yesterday = date.today() - timedelta(days=1)
@@ -19,6 +19,7 @@ async def aggregate_daily_usage_for_all_users():
except Exception as e:
print(f"Failed to aggregate usage for user {user.id}: {e}")
+
async def generate_monthly_invoices():
now = datetime.now()
last_month = now.month - 1 if now.month > 1 else 12
@@ -28,19 +29,22 @@ async def generate_monthly_invoices():
for user in users:
try:
- invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, last_month)
+ invoice = await InvoiceGenerator.generate_monthly_invoice(
+ user, year, last_month
+ )
if invoice:
await InvoiceGenerator.finalize_invoice(invoice)
except Exception as e:
print(f"Failed to generate invoice for user {user.id}: {e}")
+
def start_scheduler():
scheduler.add_job(
aggregate_daily_usage_for_all_users,
CronTrigger(hour=1, minute=0),
id="aggregate_daily_usage",
name="Aggregate daily usage for all users",
- replace_existing=True
+ replace_existing=True,
)
scheduler.add_job(
@@ -48,10 +52,11 @@ def start_scheduler():
CronTrigger(day=1, hour=2, minute=0),
id="generate_monthly_invoices",
name="Generate monthly invoices",
- replace_existing=True
+ replace_existing=True,
)
scheduler.start()
+
def stop_scheduler():
scheduler.shutdown()
diff --git a/mywebdav/billing/stripe_client.py b/mywebdav/billing/stripe_client.py
index 6c98f86..58b8d53 100644
--- a/mywebdav/billing/stripe_client.py
+++ b/mywebdav/billing/stripe_client.py
@@ -1,8 +1,8 @@
import stripe
-from decimal import Decimal
-from typing import Optional, Dict, Any
+from typing import Dict
from ..settings import settings
+
class StripeClient:
@staticmethod
def _ensure_api_key():
@@ -11,13 +11,12 @@ class StripeClient:
stripe.api_key = settings.STRIPE_SECRET_KEY
else:
raise ValueError("Stripe API key not configured")
+
@staticmethod
async def create_customer(email: str, name: str, metadata: Dict = None) -> str:
StripeClient._ensure_api_key()
customer = stripe.Customer.create(
- email=email,
- name=name,
- metadata=metadata or {}
+ email=email, name=name, metadata=metadata or {}
)
return customer.id
@@ -26,7 +25,7 @@ class StripeClient:
amount: int,
currency: str = "usd",
customer_id: str = None,
- metadata: Dict = None
+ metadata: Dict = None,
) -> stripe.PaymentIntent:
StripeClient._ensure_api_key()
return stripe.PaymentIntent.create(
@@ -34,32 +33,29 @@ class StripeClient:
currency=currency,
customer=customer_id,
metadata=metadata or {},
- automatic_payment_methods={"enabled": True}
+ automatic_payment_methods={"enabled": True},
)
@staticmethod
async def create_invoice(
- customer_id: str,
- description: str,
- line_items: list,
- metadata: Dict = None
+ customer_id: str, description: str, line_items: list, metadata: Dict = None
) -> stripe.Invoice:
StripeClient._ensure_api_key()
for item in line_items:
stripe.InvoiceItem.create(
customer=customer_id,
- amount=int(item['amount'] * 100),
- currency=item.get('currency', 'usd'),
- description=item['description'],
- metadata=item.get('metadata', {})
+ amount=int(item["amount"] * 100),
+ currency=item.get("currency", "usd"),
+ description=item["description"],
+ metadata=item.get("metadata", {}),
)
invoice = stripe.Invoice.create(
customer=customer_id,
description=description,
auto_advance=True,
- collection_method='charge_automatically',
- metadata=metadata or {}
+ collection_method="charge_automatically",
+ metadata=metadata or {},
)
return invoice
@@ -76,18 +72,15 @@ class StripeClient:
@staticmethod
async def attach_payment_method(
- payment_method_id: str,
- customer_id: str
+ payment_method_id: str, customer_id: str
) -> stripe.PaymentMethod:
StripeClient._ensure_api_key()
payment_method = stripe.PaymentMethod.attach(
- payment_method_id,
- customer=customer_id
+ payment_method_id, customer=customer_id
)
stripe.Customer.modify(
- customer_id,
- invoice_settings={'default_payment_method': payment_method_id}
+ customer_id, invoice_settings={"default_payment_method": payment_method_id}
)
return payment_method
@@ -95,22 +88,15 @@ class StripeClient:
@staticmethod
async def list_payment_methods(customer_id: str, type: str = "card"):
StripeClient._ensure_api_key()
- return stripe.PaymentMethod.list(
- customer=customer_id,
- type=type
- )
+ return stripe.PaymentMethod.list(customer=customer_id, type=type)
@staticmethod
async def create_subscription(
- customer_id: str,
- price_id: str,
- metadata: Dict = None
+ customer_id: str, price_id: str, metadata: Dict = None
) -> stripe.Subscription:
StripeClient._ensure_api_key()
return stripe.Subscription.create(
- customer=customer_id,
- items=[{'price': price_id}],
- metadata=metadata or {}
+ customer=customer_id, items=[{"price": price_id}], metadata=metadata or {}
)
@staticmethod
diff --git a/mywebdav/billing/usage_tracker.py b/mywebdav/billing/usage_tracker.py
index f110d29..22eb2ec 100644
--- a/mywebdav/billing/usage_tracker.py
+++ b/mywebdav/billing/usage_tracker.py
@@ -1,11 +1,10 @@
import uuid
from datetime import datetime, date, timezone
-from decimal import Decimal
-from typing import Optional
from tortoise.transactions import in_transaction
from .models import UsageRecord, UsageAggregate
from ..models import User
+
class UsageTracker:
@staticmethod
async def track_storage(
@@ -13,7 +12,7 @@ class UsageTracker:
amount_bytes: int,
resource_type: str = None,
resource_id: int = None,
- metadata: dict = None
+ metadata: dict = None,
):
idempotency_key = f"storage_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}"
@@ -24,7 +23,7 @@ class UsageTracker:
resource_type=resource_type,
resource_id=resource_id,
idempotency_key=idempotency_key,
- metadata=metadata
+ metadata=metadata,
)
@staticmethod
@@ -34,7 +33,7 @@ class UsageTracker:
direction: str = "down",
resource_type: str = None,
resource_id: int = None,
- metadata: dict = None
+ metadata: dict = None,
):
record_type = f"bandwidth_{direction}"
idempotency_key = f"{record_type}_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}"
@@ -46,7 +45,7 @@ class UsageTracker:
resource_type=resource_type,
resource_id=resource_id,
idempotency_key=idempotency_key,
- metadata=metadata
+ metadata=metadata,
)
@staticmethod
@@ -61,24 +60,26 @@ class UsageTracker:
user=user,
record_type="storage",
timestamp__gte=start_of_day,
- timestamp__lte=end_of_day
+ timestamp__lte=end_of_day,
).all()
- storage_avg = sum(r.amount_bytes for r in storage_records) // max(len(storage_records), 1)
+ storage_avg = sum(r.amount_bytes for r in storage_records) // max(
+ len(storage_records), 1
+ )
storage_peak = max((r.amount_bytes for r in storage_records), default=0)
bandwidth_up = await UsageRecord.filter(
user=user,
record_type="bandwidth_up",
timestamp__gte=start_of_day,
- timestamp__lte=end_of_day
+ timestamp__lte=end_of_day,
).all()
bandwidth_down = await UsageRecord.filter(
user=user,
record_type="bandwidth_down",
timestamp__gte=start_of_day,
- timestamp__lte=end_of_day
+ timestamp__lte=end_of_day,
).all()
total_up = sum(r.amount_bytes for r in bandwidth_up)
@@ -92,8 +93,8 @@ class UsageTracker:
"storage_bytes_avg": storage_avg,
"storage_bytes_peak": storage_peak,
"bandwidth_up_bytes": total_up,
- "bandwidth_down_bytes": total_down
- }
+ "bandwidth_down_bytes": total_down,
+ },
)
if not created:
@@ -108,6 +109,7 @@ class UsageTracker:
@staticmethod
async def get_current_storage(user: User) -> int:
from ..models import File
+
files = await File.filter(owner=user, is_deleted=False).all()
return sum(f.size for f in files)
@@ -121,9 +123,7 @@ class UsageTracker:
end_date = date(year, month, last_day)
aggregates = await UsageAggregate.filter(
- user=user,
- date__gte=start_date,
- date__lte=end_date
+ user=user, date__gte=start_date, date__lte=end_date
).all()
if not aggregates:
@@ -132,7 +132,7 @@ class UsageTracker:
"storage_gb_peak": 0,
"bandwidth_up_gb": 0,
"bandwidth_down_gb": 0,
- "total_bandwidth_gb": 0
+ "total_bandwidth_gb": 0,
}
storage_avg = sum(a.storage_bytes_avg for a in aggregates) / len(aggregates)
@@ -145,5 +145,5 @@ class UsageTracker:
"storage_gb_peak": round(storage_peak / (1024**3), 4),
"bandwidth_up_gb": round(bandwidth_up / (1024**3), 4),
"bandwidth_down_gb": round(bandwidth_down / (1024**3), 4),
- "total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4)
+ "total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4),
}
diff --git a/mywebdav/legal.py b/mywebdav/legal.py
index 90868a0..ea7afad 100644
--- a/mywebdav/legal.py
+++ b/mywebdav/legal.py
@@ -7,7 +7,7 @@ for various legal policies required for a European cloud storage provider.
from abc import ABC, abstractmethod
from datetime import datetime
-from typing import Dict, Any
+from typing import Dict
class LegalDocument(ABC):
@@ -17,10 +17,13 @@ class LegalDocument(ABC):
Provides common structure and methods for generating legal content.
"""
- def __init__(self, company_name: str = "MyWebdav Technologies",
- last_updated: str = None,
- contact_email: str = "legal@mywebdav.eu",
- website: str = "https://mywebdav.eu"):
+ def __init__(
+ self,
+ company_name: str = "MyWebdav Technologies",
+ last_updated: str = None,
+ contact_email: str = "legal@mywebdav.eu",
+ website: str = "https://mywebdav.eu",
+ ):
self.company_name = company_name
self.last_updated = last_updated or datetime.now().strftime("%B %d, %Y")
self.contact_email = contact_email
@@ -853,20 +856,21 @@ class ContactComplaintMechanism(LegalDocument):
def get_all_legal_documents() -> Dict[str, LegalDocument]:
"""Return a dictionary of all legal document instances."""
return {
- 'privacy_policy': PrivacyPolicy(),
- 'terms_of_service': TermsOfService(),
- 'security_policy': SecurityPolicy(),
- 'cookie_policy': CookiePolicy(),
- 'data_processing_agreement': DataProcessingAgreement(),
- 'compliance_statement': ComplianceStatement(),
- 'data_portability_deletion_policy': DataPortabilityDeletionPolicy(),
- 'contact_complaint_mechanism': ContactComplaintMechanism(),
+ "privacy_policy": PrivacyPolicy(),
+ "terms_of_service": TermsOfService(),
+ "security_policy": SecurityPolicy(),
+ "cookie_policy": CookiePolicy(),
+ "data_processing_agreement": DataProcessingAgreement(),
+ "compliance_statement": ComplianceStatement(),
+ "data_portability_deletion_policy": DataPortabilityDeletionPolicy(),
+ "contact_complaint_mechanism": ContactComplaintMechanism(),
}
def generate_legal_documents(output_dir: str = "static/legal"):
"""Generate all legal documents as Markdown and HTML files."""
import os
+
os.makedirs(output_dir, exist_ok=True)
documents = get_all_legal_documents()
@@ -875,17 +879,17 @@ def generate_legal_documents(output_dir: str = "static/legal"):
# Generate Markdown
md_filename = f"{doc_name}.md"
md_path = os.path.join(output_dir, md_filename)
- with open(md_path, 'w') as f:
+ with open(md_path, "w") as f:
f.write(doc.to_markdown())
# Generate HTML
html_filename = f"{doc_name}.html"
html_path = os.path.join(output_dir, html_filename)
- with open(html_path, 'w') as f:
+ with open(html_path, "w") as f:
f.write(doc.to_html())
print(f"Generated {md_filename} and {html_filename}")
if __name__ == "__main__":
- generate_legal_documents()
\ No newline at end of file
+ generate_legal_documents()
diff --git a/mywebdav/mail.py b/mywebdav/mail.py
index ba6aac3..97ff1e8 100644
--- a/mywebdav/mail.py
+++ b/mywebdav/mail.py
@@ -2,17 +2,26 @@ import asyncio
import aiosmtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
-from typing import Optional, Dict, Any
+from typing import Optional
from .settings import settings
+
class EmailTask:
- def __init__(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
+ def __init__(
+ self,
+ to_email: str,
+ subject: str,
+ body: str,
+ html: Optional[str] = None,
+ **kwargs,
+ ):
self.to_email = to_email
self.subject = subject
self.body = body
self.html = html
self.kwargs = kwargs
+
class EmailService:
def __init__(self):
self.queue = asyncio.Queue()
@@ -38,7 +47,14 @@ class EmailService:
except asyncio.CancelledError:
pass
- async def send_email(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
+ async def send_email(
+ self,
+ to_email: str,
+ subject: str,
+ body: str,
+ html: Optional[str] = None,
+ **kwargs,
+ ):
"""Queue an email for sending"""
task = EmailTask(to_email, subject, body, html, **kwargs)
await self.queue.put(task)
@@ -60,22 +76,26 @@ class EmailService:
async def _send_email_task(self, task: EmailTask):
"""Send a single email task"""
- if not settings.SMTP_HOST or not settings.SMTP_USERNAME or not settings.SMTP_PASSWORD:
+ if (
+ not settings.SMTP_HOST
+ or not settings.SMTP_USERNAME
+ or not settings.SMTP_PASSWORD
+ ):
print("SMTP not configured, skipping email send")
return
- msg = MIMEMultipart('alternative')
- msg['From'] = settings.SMTP_SENDER_EMAIL
- msg['To'] = task.to_email
- msg['Subject'] = task.subject
+ msg = MIMEMultipart("alternative")
+ msg["From"] = settings.SMTP_SENDER_EMAIL
+ msg["To"] = task.to_email
+ msg["Subject"] = task.subject
# Add text part
- text_part = MIMEText(task.body, 'plain')
+ text_part = MIMEText(task.body, "plain")
msg.attach(text_part)
# Add HTML part if provided
if task.html:
- html_part = MIMEText(task.html, 'html')
+ html_part = MIMEText(task.html, "html")
msg.attach(html_part)
try:
@@ -86,7 +106,7 @@ class EmailService:
port=settings.SMTP_PORT,
username=settings.SMTP_USERNAME,
password=settings.SMTP_PASSWORD,
- use_tls=True
+ use_tls=True,
) as smtp:
await smtp.send_message(msg)
print(f"Email sent to {task.to_email}")
@@ -99,7 +119,10 @@ class EmailService:
try:
await smtp.starttls()
except Exception as tls_error:
- if "already using" in str(tls_error).lower() or "tls" in str(tls_error).lower():
+ if (
+ "already using" in str(tls_error).lower()
+ or "tls" in str(tls_error).lower()
+ ):
# Connection is already using TLS, proceed without starttls
pass
else:
@@ -111,14 +134,23 @@ class EmailService:
print(f"Failed to send email to {task.to_email}: {e}")
raise # Re-raise to let caller handle
+
# Global email service instance
email_service = EmailService()
+
# Convenience functions
-async def send_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
+async def send_email(
+ to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
+):
"""Send an email asynchronously"""
await email_service.send_email(to_email, subject, body, html, **kwargs)
-def queue_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
+
+def queue_email(
+ to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
+):
"""Queue an email for sending (fire and forget)"""
- asyncio.create_task(email_service.send_email(to_email, subject, body, html, **kwargs))
\ No newline at end of file
+ asyncio.create_task(
+ email_service.send_email(to_email, subject, body, html, **kwargs)
+ )
diff --git a/mywebdav/main.py b/mywebdav/main.py
index 5e7a40d..9f3b211 100644
--- a/mywebdav/main.py
+++ b/mywebdav/main.py
@@ -2,18 +2,33 @@ import argparse
import uvicorn
import logging
from contextlib import asynccontextmanager
-from fastapi import FastAPI, Request, status, HTTPException
+from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from tortoise.contrib.fastapi import register_tortoise
from .settings import settings
-from .routers import auth, users, folders, files, shares, search, admin, starred, billing, admin_billing
+from .routers import (
+ auth,
+ users,
+ folders,
+ files,
+ shares,
+ search,
+ admin,
+ starred,
+ billing,
+ admin_billing,
+)
from . import webdav
from .schemas import ErrorResponse
+from .middleware.usage_tracking import UsageTrackingMiddleware
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
logger = logging.getLogger(__name__)
+
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting up...")
@@ -21,6 +36,7 @@ async def lifespan(app: FastAPI):
from .billing.scheduler import start_scheduler
from .billing.models import PricingConfig
from .mail import email_service
+
start_scheduler()
logger.info("Billing scheduler started")
await email_service.start()
@@ -28,28 +44,61 @@ async def lifespan(app: FastAPI):
pricing_count = await PricingConfig.all().count()
if pricing_count == 0:
from decimal import Decimal
- await PricingConfig.create(config_key='storage_per_gb_month', config_value=Decimal('0.0045'), description='Storage cost per GB per month', unit='per_gb_month')
- await PricingConfig.create(config_key='bandwidth_egress_per_gb', config_value=Decimal('0.009'), description='Bandwidth egress cost per GB', unit='per_gb')
- await PricingConfig.create(config_key='bandwidth_ingress_per_gb', config_value=Decimal('0.0'), description='Bandwidth ingress cost per GB (free)', unit='per_gb')
- await PricingConfig.create(config_key='free_tier_storage_gb', config_value=Decimal('15'), description='Free tier storage in GB', unit='gb')
- await PricingConfig.create(config_key='free_tier_bandwidth_gb', config_value=Decimal('15'), description='Free tier bandwidth in GB per month', unit='gb')
- await PricingConfig.create(config_key='tax_rate_default', config_value=Decimal('0.0'), description='Default tax rate (0 = no tax)', unit='percentage')
+
+ await PricingConfig.create(
+ config_key="storage_per_gb_month",
+ config_value=Decimal("0.0045"),
+ description="Storage cost per GB per month",
+ unit="per_gb_month",
+ )
+ await PricingConfig.create(
+ config_key="bandwidth_egress_per_gb",
+ config_value=Decimal("0.009"),
+ description="Bandwidth egress cost per GB",
+ unit="per_gb",
+ )
+ await PricingConfig.create(
+ config_key="bandwidth_ingress_per_gb",
+ config_value=Decimal("0.0"),
+ description="Bandwidth ingress cost per GB (free)",
+ unit="per_gb",
+ )
+ await PricingConfig.create(
+ config_key="free_tier_storage_gb",
+ config_value=Decimal("15"),
+ description="Free tier storage in GB",
+ unit="gb",
+ )
+ await PricingConfig.create(
+ config_key="free_tier_bandwidth_gb",
+ config_value=Decimal("15"),
+ description="Free tier bandwidth in GB per month",
+ unit="gb",
+ )
+ await PricingConfig.create(
+ config_key="tax_rate_default",
+ config_value=Decimal("0.0"),
+ description="Default tax rate (0 = no tax)",
+ unit="percentage",
+ )
logger.info("Default pricing configuration initialized")
yield
from .billing.scheduler import stop_scheduler
+
stop_scheduler()
logger.info("Billing scheduler stopped")
await email_service.stop()
logger.info("Email service stopped")
print("Shutting down...")
+
app = FastAPI(
title="MyWebdav Cloud Storage",
description="A commercial cloud storage web application",
version="0.1.0",
- lifespan=lifespan
+ lifespan=lifespan,
)
app.include_router(auth.router)
@@ -64,8 +113,6 @@ app.include_router(billing.router)
app.include_router(admin_billing.router)
app.include_router(webdav.router)
-from .middleware.usage_tracking import UsageTrackingMiddleware
-
app.add_middleware(UsageTrackingMiddleware)
app.mount("/static", StaticFiles(directory="static"), name="static")
@@ -73,35 +120,39 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
register_tortoise(
app,
db_url=settings.DATABASE_URL,
- modules={
- "models": ["mywebdav.models"],
- "billing": ["mywebdav.billing.models"]
- },
+ modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
generate_schemas=True,
add_exception_handlers=True,
)
+
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
- logger.error(f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}")
+ logger.error(
+ f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}"
+ )
return JSONResponse(
status_code=exc.status_code,
content=ErrorResponse(code=exc.status_code, message=exc.detail).model_dump(),
)
-@app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse
+@app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse
async def read_root():
with open("static/index.html", "r") as f:
return f.read()
+
def main():
parser = argparse.ArgumentParser(description="Run the MyWebdav application.")
- parser.add_argument("--host", type=str, default="0.0.0.0", help="Host address to bind to")
+ parser.add_argument(
+ "--host", type=str, default="0.0.0.0", help="Host address to bind to"
+ )
parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
args = parser.parse_args()
uvicorn.run(app, host=args.host, port=args.port)
+
if __name__ == "__main__":
main()
diff --git a/mywebdav/middleware/usage_tracking.py b/mywebdav/middleware/usage_tracking.py
index 854e709..274b3ef 100644
--- a/mywebdav/middleware/usage_tracking.py
+++ b/mywebdav/middleware/usage_tracking.py
@@ -2,31 +2,35 @@ from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from ..billing.usage_tracker import UsageTracker
+
class UsageTrackingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
- if hasattr(request.state, 'user') and request.state.user:
+ if hasattr(request.state, "user") and request.state.user:
user = request.state.user
- if request.method in ['POST', 'PUT'] and '/files/upload' in request.url.path:
- content_length = response.headers.get('content-length')
+ if (
+ request.method in ["POST", "PUT"]
+ and "/files/upload" in request.url.path
+ ):
+ content_length = response.headers.get("content-length")
if content_length:
await UsageTracker.track_bandwidth(
user=user,
amount_bytes=int(content_length),
- direction='up',
- metadata={'path': request.url.path}
+ direction="up",
+ metadata={"path": request.url.path},
)
- elif request.method == 'GET' and '/files/download' in request.url.path:
- content_length = response.headers.get('content-length')
+ elif request.method == "GET" and "/files/download" in request.url.path:
+ content_length = response.headers.get("content-length")
if content_length:
await UsageTracker.track_bandwidth(
user=user,
amount_bytes=int(content_length),
- direction='down',
- metadata={'path': request.url.path}
+ direction="down",
+ metadata={"path": request.url.path},
)
return response
diff --git a/mywebdav/models.py b/mywebdav/models.py
index fe6027a..11c9cf0 100644
--- a/mywebdav/models.py
+++ b/mywebdav/models.py
@@ -1,9 +1,9 @@
from tortoise import fields, models
from tortoise.contrib.pydantic import pydantic_model_creator
-from datetime import datetime
+
class User(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
username = fields.CharField(max_length=20, unique=True)
email = fields.CharField(max_length=255, unique=True)
hashed_password = fields.CharField(max_length=255)
@@ -11,7 +11,9 @@ class User(models.Model):
is_superuser = fields.BooleanField(default=False)
created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True)
- storage_quota_bytes = fields.BigIntField(default=10 * 1024 * 1024 * 1024) # 10 GB default
+ storage_quota_bytes = fields.BigIntField(
+ default=10 * 1024 * 1024 * 1024
+ ) # 10 GB default
used_storage_bytes = fields.BigIntField(default=0)
plan_type = fields.CharField(max_length=50, default="free")
two_factor_secret = fields.CharField(max_length=255, null=True)
@@ -24,11 +26,16 @@ class User(models.Model):
def __str__(self):
return self.username
+
class Folder(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=255)
- parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField("models.Folder", related_name="children", null=True)
- owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="folders")
+ parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField(
+ "models.Folder", related_name="children", null=True
+ )
+ owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
+ "models.User", related_name="folders"
+ )
created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True)
is_deleted = fields.BooleanField(default=False)
@@ -36,21 +43,28 @@ class Folder(models.Model):
class Meta:
table = "folders"
- unique_together = (("name", "parent", "owner"),) # Ensure unique folder names within a parent for an owner
+ unique_together = (
+ ("name", "parent", "owner"),
+ ) # Ensure unique folder names within a parent for an owner
def __str__(self):
return self.name
+
class File(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=255)
- path = fields.CharField(max_length=1024) # Internal storage path
+ path = fields.CharField(max_length=1024) # Internal storage path
size = fields.BigIntField()
mime_type = fields.CharField(max_length=255)
- file_hash = fields.CharField(max_length=64, null=True) # SHA-256
+ file_hash = fields.CharField(max_length=64, null=True) # SHA-256
thumbnail_path = fields.CharField(max_length=1024, null=True)
- parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="files", null=True)
- owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="files")
+ parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
+ "models.Folder", related_name="files", null=True
+ )
+ owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
+ "models.User", related_name="files"
+ )
created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True)
is_deleted = fields.BooleanField(default=False)
@@ -60,14 +74,19 @@ class File(models.Model):
class Meta:
table = "files"
- unique_together = (("name", "parent", "owner"),) # Ensure unique file names within a parent for an owner
+ unique_together = (
+ ("name", "parent", "owner"),
+ ) # Ensure unique file names within a parent for an owner
def __str__(self):
return self.name
+
class FileVersion(models.Model):
- id = fields.IntField(pk=True)
- file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="versions")
+ id = fields.IntField(primary_key=True)
+ file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField(
+ "models.File", related_name="versions"
+ )
version_path = fields.CharField(max_length=1024)
size = fields.BigIntField()
created_at = fields.DatetimeField(auto_now_add=True)
@@ -75,47 +94,69 @@ class FileVersion(models.Model):
class Meta:
table = "file_versions"
+
class Share(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
token = fields.CharField(max_length=64, unique=True)
- file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="shares", null=True)
- folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="shares", null=True)
- owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="shares")
+ file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField(
+ "models.File", related_name="shares", null=True
+ )
+ folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
+ "models.Folder", related_name="shares", null=True
+ )
+ owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
+ "models.User", related_name="shares"
+ )
created_at = fields.DatetimeField(auto_now_add=True)
expires_at = fields.DatetimeField(null=True)
password_protected = fields.BooleanField(default=False)
hashed_password = fields.CharField(max_length=255, null=True)
access_count = fields.IntField(default=0)
- permission_level = fields.CharField(max_length=50, default="viewer") # viewer, uploader, editor
+ permission_level = fields.CharField(
+ max_length=50, default="viewer"
+ ) # viewer, uploader, editor
class Meta:
table = "shares"
+
class Team(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=255, unique=True)
- owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="owned_teams")
- members: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User", related_name="teams", through="team_members")
+ owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
+ "models.User", related_name="owned_teams"
+ )
+ members: fields.ManyToManyRelation[User] = fields.ManyToManyField(
+ "models.User", related_name="teams", through="team_members"
+ )
created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
table = "teams"
+
class TeamMember(models.Model):
- id = fields.IntField(pk=True)
- team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField("models.Team", related_name="team_members")
- user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="user_teams")
- role = fields.CharField(max_length=50, default="member") # owner, admin, member
+ id = fields.IntField(primary_key=True)
+ team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField(
+ "models.Team", related_name="team_members"
+ )
+ user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
+ "models.User", related_name="user_teams"
+ )
+ role = fields.CharField(max_length=50, default="member") # owner, admin, member
class Meta:
table = "team_members"
unique_together = (("team", "user"),)
+
class Activity(models.Model):
- id = fields.IntField(pk=True)
- user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="activities", null=True)
+ id = fields.IntField(primary_key=True)
+ user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
+ "models.User", related_name="activities", null=True
+ )
action = fields.CharField(max_length=255)
- target_type = fields.CharField(max_length=50) # file, folder, share, user, team
+ target_type = fields.CharField(max_length=50) # file, folder, share, user, team
target_id = fields.IntField()
ip_address = fields.CharField(max_length=45, null=True)
timestamp = fields.DatetimeField(auto_now_add=True)
@@ -123,13 +164,18 @@ class Activity(models.Model):
class Meta:
table = "activities"
+
class FileRequest(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
title = fields.CharField(max_length=255)
description = fields.TextField(null=True)
token = fields.CharField(max_length=64, unique=True)
- owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="file_requests")
- target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="file_requests")
+ owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
+ "models.User", related_name="file_requests"
+ )
+ target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
+ "models.Folder", related_name="file_requests"
+ )
created_at = fields.DatetimeField(auto_now_add=True)
expires_at = fields.DatetimeField(null=True)
is_active = fields.BooleanField(default=True)
@@ -137,8 +183,9 @@ class FileRequest(models.Model):
class Meta:
table = "file_requests"
+
class WebDAVProperty(models.Model):
- id = fields.IntField(pk=True)
+ id = fields.IntField(primary_key=True)
resource_type = fields.CharField(max_length=10)
resource_id = fields.IntField()
namespace = fields.CharField(max_length=255)
@@ -151,5 +198,8 @@ class WebDAVProperty(models.Model):
table = "webdav_properties"
unique_together = (("resource_type", "resource_id", "namespace", "name"),)
+
User_Pydantic = pydantic_model_creator(User, name="User_Pydantic")
-UserIn_Pydantic = pydantic_model_creator(User, name="UserIn_Pydantic", exclude_readonly=True)
+UserIn_Pydantic = pydantic_model_creator(
+ User, name="UserIn_Pydantic", exclude_readonly=True
+)
diff --git a/mywebdav/routers/admin.py b/mywebdav/routers/admin.py
index 320861c..a6a7558 100644
--- a/mywebdav/routers/admin.py
+++ b/mywebdav/routers/admin.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
@@ -13,18 +13,25 @@ router = APIRouter(
responses={403: {"description": "Not enough permissions"}},
)
+
@router.get("/users", response_model=List[User_Pydantic])
async def get_all_users():
return await User.all()
+
@router.get("/users/{user_id}", response_model=User_Pydantic)
async def get_user(user_id: int):
user = await User.get_or_none(id=user_id)
if not user:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
+ )
return user
-@router.post("/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED)
+
+@router.post(
+ "/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED
+)
async def create_user_by_admin(user_in: UserCreate):
user = await User.get_or_none(username=user_in.username)
if user:
@@ -44,66 +51,81 @@ async def create_user_by_admin(user_in: UserCreate):
username=user_in.username,
email=user_in.email,
hashed_password=hashed_password,
- is_superuser=False, # Admin creates regular users by default
+ is_superuser=False, # Admin creates regular users by default
is_active=True,
)
return await User_Pydantic.from_tortoise_orm(user)
+
@router.put("/users/{user_id}", response_model=User_Pydantic)
async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate):
user = await User.get_or_none(id=user_id)
if not user:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
+ )
if user_update.username is not None and user_update.username != user.username:
if await User.get_or_none(username=user_update.username):
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
+ )
user.username = user_update.username
-
+
if user_update.email is not None and user_update.email != user.email:
if await User.get_or_none(email=user_update.email):
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Email already registered",
+ )
user.email = user_update.email
if user_update.password is not None:
user.hashed_password = get_password_hash(user_update.password)
-
+
if user_update.is_active is not None:
user.is_active = user_update.is_active
-
+
if user_update.is_superuser is not None:
user.is_superuser = user_update.is_superuser
-
+
if user_update.storage_quota_bytes is not None:
user.storage_quota_bytes = user_update.storage_quota_bytes
-
+
if user_update.plan_type is not None:
user.plan_type = user_update.plan_type
-
+
if user_update.is_2fa_enabled is not None:
user.is_2fa_enabled = user_update.is_2fa_enabled
if not user_update.is_2fa_enabled:
user.two_factor_secret = None
user.recovery_codes = None
-
+
await user.save()
return await User_Pydantic.from_tortoise_orm(user)
+
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user_by_admin(user_id: int):
user = await User.get_or_none(id=user_id)
if not user:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
+ )
await user.delete()
return {"message": "User deleted successfully"}
+
@router.post("/test-email")
-async def send_test_email(to_email: str, subject: str = "Test Email", body: str = "This is a test email"):
+async def send_test_email(
+ to_email: str, subject: str = "Test Email", body: str = "This is a test email"
+):
from ..mail import queue_email
+
queue_email(
to_email=to_email,
subject=subject,
body=body,
- html=f"
{subject}
{body}
"
+ html=f"{subject}
{body}
",
)
- return {"message": "Test email queued"}
\ No newline at end of file
+ return {"message": "Test email queued"}
diff --git a/mywebdav/routers/admin_billing.py b/mywebdav/routers/admin_billing.py
index 54049cf..2f77be5 100644
--- a/mywebdav/routers/admin_billing.py
+++ b/mywebdav/routers/admin_billing.py
@@ -1,5 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException
-from typing import List
from decimal import Decimal
from pydantic import BaseModel
@@ -8,21 +7,25 @@ from ..models import User
from ..billing.models import PricingConfig, Invoice, SubscriptionPlan
from ..billing.invoice_generator import InvoiceGenerator
+
def require_superuser(current_user: User = Depends(get_current_user)):
if not current_user.is_superuser:
raise HTTPException(status_code=403, detail="Superuser privileges required")
return current_user
+
router = APIRouter(
prefix="/api/admin/billing",
tags=["admin", "billing"],
- dependencies=[Depends(require_superuser)]
+ dependencies=[Depends(require_superuser)],
)
+
class PricingConfigUpdate(BaseModel):
config_key: str
config_value: float
+
class PlanCreate(BaseModel):
name: str
display_name: str
@@ -32,6 +35,7 @@ class PlanCreate(BaseModel):
price_monthly: float
price_yearly: float = None
+
@router.get("/pricing")
async def get_all_pricing(current_user: User = Depends(require_superuser)):
configs = await PricingConfig.all()
@@ -42,16 +46,17 @@ async def get_all_pricing(current_user: User = Depends(require_superuser)):
"config_value": float(c.config_value),
"description": c.description,
"unit": c.unit,
- "updated_at": c.updated_at
+ "updated_at": c.updated_at,
}
for c in configs
]
+
@router.put("/pricing/{config_id}")
async def update_pricing(
config_id: int,
update: PricingConfigUpdate,
- current_user: User = Depends(require_superuser)
+ current_user: User = Depends(require_superuser),
):
config = await PricingConfig.get_or_none(id=config_id)
if not config:
@@ -63,11 +68,10 @@ async def update_pricing(
return {"message": "Pricing updated successfully"}
+
@router.post("/generate-invoices/{year}/{month}")
async def generate_all_invoices(
- year: int,
- month: int,
- current_user: User = Depends(require_superuser)
+ year: int, month: int, current_user: User = Depends(require_superuser)
):
users = await User.filter(is_active=True).all()
generated = []
@@ -76,35 +80,36 @@ async def generate_all_invoices(
for user in users:
invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, month)
if invoice:
- generated.append({
- "user_id": user.id,
- "invoice_id": invoice.id,
- "total": float(invoice.total)
- })
+ generated.append(
+ {
+ "user_id": user.id,
+ "invoice_id": invoice.id,
+ "total": float(invoice.total),
+ }
+ )
else:
skipped.append(user.id)
- return {
- "generated": len(generated),
- "skipped": len(skipped),
- "invoices": generated
- }
+ return {"generated": len(generated), "skipped": len(skipped), "invoices": generated}
+
@router.post("/plans")
async def create_plan(
- plan_data: PlanCreate,
- current_user: User = Depends(require_superuser)
+ plan_data: PlanCreate, current_user: User = Depends(require_superuser)
):
plan = await SubscriptionPlan.create(**plan_data.dict())
return {"id": plan.id, "message": "Plan created successfully"}
+
@router.get("/stats")
async def get_billing_stats(current_user: User = Depends(require_superuser)):
- from tortoise.functions import Sum, Count
+ from tortoise.functions import Sum
- total_revenue = await Invoice.filter(status="paid").annotate(
- total_sum=Sum("total")
- ).values("total_sum")
+ total_revenue = (
+ await Invoice.filter(status="paid")
+ .annotate(total_sum=Sum("total"))
+ .values("total_sum")
+ )
invoice_count = await Invoice.all().count()
pending_invoices = await Invoice.filter(status="open").count()
@@ -112,5 +117,5 @@ async def get_billing_stats(current_user: User = Depends(require_superuser)):
return {
"total_revenue": float(total_revenue[0]["total_sum"] or 0),
"total_invoices": invoice_count,
- "pending_invoices": pending_invoices
+ "pending_invoices": pending_invoices,
}
diff --git a/mywebdav/routers/auth.py b/mywebdav/routers/auth.py
index ffd8fd1..500c708 100644
--- a/mywebdav/routers/auth.py
+++ b/mywebdav/routers/auth.py
@@ -4,13 +4,23 @@ from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
-from ..auth import authenticate_user, create_access_token, get_password_hash, get_current_user, get_current_verified_user, verify_password
+from ..auth import (
+ authenticate_user,
+ create_access_token,
+ get_password_hash,
+ get_current_user,
+ get_current_verified_user,
+ verify_password,
+)
from ..models import User
-from ..schemas import Token, UserCreate, TokenData, UserLoginWith2FA
+from ..schemas import Token, UserCreate
from ..two_factor import (
- generate_totp_secret, generate_totp_uri, generate_qr_code_base64,
- verify_totp_code, generate_recovery_codes, hash_recovery_codes,
- verify_recovery_codes
+ generate_totp_secret,
+ generate_totp_uri,
+ generate_qr_code_base64,
+ verify_totp_code,
+ generate_recovery_codes,
+ hash_recovery_codes,
)
router = APIRouter(
@@ -18,27 +28,33 @@ router = APIRouter(
tags=["auth"],
)
+
class LoginRequest(BaseModel):
username: str
password: str
+
class TwoFactorLogin(BaseModel):
username: str
password: str
two_factor_code: Optional[str] = None
+
class TwoFactorSetupResponse(BaseModel):
secret: str
qr_code_base64: str
recovery_codes: List[str]
+
class TwoFactorCode(BaseModel):
two_factor_code: str
+
class TwoFactorDisable(BaseModel):
password: str
two_factor_code: str
+
@router.post("/register", response_model=Token)
async def register_user(user_in: UserCreate):
user = await User.get_or_none(username=user_in.username)
@@ -63,22 +79,26 @@ async def register_user(user_in: UserCreate):
# Send welcome email
from ..mail import queue_email
+
queue_email(
to_email=user.email,
subject="Welcome to MyWebdav!",
body=f"Hi {user.username},\n\nWelcome to MyWebdav! Your account has been created successfully.\n\nBest regards,\nThe MyWebdav Team",
- html=f"Welcome to MyWebdav!
Hi {user.username},
Welcome to MyWebdav! Your account has been created successfully.
Best regards,
The MyWebdav Team
"
+ html=f"Welcome to MyWebdav!
Hi {user.username},
Welcome to MyWebdav! Your account has been created successfully.
Best regards,
The MyWebdav Team
",
)
- access_token_expires = timedelta(minutes=30) # Use settings
+ access_token_expires = timedelta(minutes=30) # Use settings
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}
+
@router.post("/token", response_model=Token)
async def login_for_access_token(login_data: LoginRequest):
- auth_result = await authenticate_user(login_data.username, login_data.password, None)
+ auth_result = await authenticate_user(
+ login_data.username, login_data.password, None
+ )
if not auth_result:
raise HTTPException(
@@ -97,16 +117,26 @@ async def login_for_access_token(login_data: LoginRequest):
access_token_expires = timedelta(minutes=30)
access_token = create_access_token(
- data={"sub": user.username}, expires_delta=access_token_expires, two_factor_verified=True
+ data={"sub": user.username},
+ expires_delta=access_token_expires,
+ two_factor_verified=True,
)
return {"access_token": access_token, "token_type": "bearer"}
+
@router.post("/2fa/setup", response_model=TwoFactorSetupResponse)
-async def setup_two_factor_authentication(current_user: User = Depends(get_current_user)):
+async def setup_two_factor_authentication(
+ current_user: User = Depends(get_current_user),
+):
if current_user.is_2fa_enabled:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
+ )
if current_user.two_factor_secret:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup already initiated. Verify or disable first.")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="2FA setup already initiated. Verify or disable first.",
+ )
secret = generate_totp_secret()
current_user.two_factor_secret = secret
@@ -120,39 +150,66 @@ async def setup_two_factor_authentication(current_user: User = Depends(get_curre
current_user.recovery_codes = ",".join(hashed_recovery_codes)
await current_user.save()
- return TwoFactorSetupResponse(secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes)
+ return TwoFactorSetupResponse(
+ secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes
+ )
+
@router.post("/2fa/verify", response_model=Token)
-async def verify_two_factor_authentication(two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)):
+async def verify_two_factor_authentication(
+ two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)
+):
if current_user.is_2fa_enabled:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
+ )
if not current_user.two_factor_secret:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated.")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated."
+ )
- if not verify_totp_code(current_user.two_factor_secret, two_factor_code_data.two_factor_code):
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")
+ if not verify_totp_code(
+ current_user.two_factor_secret, two_factor_code_data.two_factor_code
+ ):
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
+ )
current_user.is_2fa_enabled = True
await current_user.save()
- access_token_expires = timedelta(minutes=30) # Use settings
+ access_token_expires = timedelta(minutes=30) # Use settings
access_token = create_access_token(
- data={"sub": current_user.username}, expires_delta=access_token_expires, two_factor_verified=True
+ data={"sub": current_user.username},
+ expires_delta=access_token_expires,
+ two_factor_verified=True,
)
return {"access_token": access_token, "token_type": "bearer"}
+
@router.post("/2fa/disable", response_model=dict)
-async def disable_two_factor_authentication(disable_data: TwoFactorDisable, current_user: User = Depends(get_current_verified_user)):
+async def disable_two_factor_authentication(
+ disable_data: TwoFactorDisable,
+ current_user: User = Depends(get_current_verified_user),
+):
if not current_user.is_2fa_enabled:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
+ )
# Verify password
if not verify_password(disable_data.password, current_user.hashed_password):
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password.")
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password."
+ )
# Verify 2FA code
- if not verify_totp_code(current_user.two_factor_secret, disable_data.two_factor_code):
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")
+ if not verify_totp_code(
+ current_user.two_factor_secret, disable_data.two_factor_code
+ ):
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
+ )
current_user.two_factor_secret = None
current_user.is_2fa_enabled = False
@@ -161,10 +218,15 @@ async def disable_two_factor_authentication(disable_data: TwoFactorDisable, curr
return {"message": "2FA disabled successfully."}
+
@router.get("/2fa/recovery-codes", response_model=List[str])
-async def get_new_recovery_codes(current_user: User = Depends(get_current_verified_user)):
+async def get_new_recovery_codes(
+ current_user: User = Depends(get_current_verified_user),
+):
if not current_user.is_2fa_enabled:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
+ )
recovery_codes = generate_recovery_codes()
hashed_recovery_codes = hash_recovery_codes(recovery_codes)
diff --git a/mywebdav/routers/billing.py b/mywebdav/routers/billing.py
index 4eb8ddd..8a69b27 100644
--- a/mywebdav/routers/billing.py
+++ b/mywebdav/routers/billing.py
@@ -1,25 +1,25 @@
-from fastapi import APIRouter, Depends, HTTPException, status, Request
+from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from typing import List, Optional
from datetime import datetime, date
-from decimal import Decimal
-import calendar
from ..auth import get_current_user
from ..models import User
from ..billing.models import (
- Invoice, InvoiceLineItem, UserSubscription, PricingConfig,
- PaymentMethod, UsageAggregate, SubscriptionPlan
+ Invoice,
+ UserSubscription,
+ PricingConfig,
+ PaymentMethod,
+ UsageAggregate,
+ SubscriptionPlan,
)
from ..billing.usage_tracker import UsageTracker
from ..billing.invoice_generator import InvoiceGenerator
from ..billing.stripe_client import StripeClient
from pydantic import BaseModel
-router = APIRouter(
- prefix="/api/billing",
- tags=["billing"]
-)
+router = APIRouter(prefix="/api/billing", tags=["billing"])
+
class UsageResponse(BaseModel):
storage_gb_avg: float
@@ -29,6 +29,7 @@ class UsageResponse(BaseModel):
total_bandwidth_gb: float
period: str
+
class InvoiceResponse(BaseModel):
id: int
invoice_number: str
@@ -42,6 +43,7 @@ class InvoiceResponse(BaseModel):
paid_at: Optional[datetime]
line_items: List[dict]
+
class SubscriptionResponse(BaseModel):
id: int
billing_type: str
@@ -50,6 +52,7 @@ class SubscriptionResponse(BaseModel):
current_period_start: Optional[datetime]
current_period_end: Optional[datetime]
+
@router.get("/usage/current")
async def get_current_usage(current_user: User = Depends(get_current_user)):
try:
@@ -61,25 +64,32 @@ async def get_current_usage(current_user: User = Depends(get_current_user)):
if usage_today:
return {
"storage_gb": round(storage_bytes / (1024**3), 4),
- "bandwidth_down_gb_today": round(usage_today.bandwidth_down_bytes / (1024**3), 4),
- "bandwidth_up_gb_today": round(usage_today.bandwidth_up_bytes / (1024**3), 4),
- "as_of": today.isoformat()
+ "bandwidth_down_gb_today": round(
+ usage_today.bandwidth_down_bytes / (1024**3), 4
+ ),
+ "bandwidth_up_gb_today": round(
+ usage_today.bandwidth_up_bytes / (1024**3), 4
+ ),
+ "as_of": today.isoformat(),
}
return {
"storage_gb": round(storage_bytes / (1024**3), 4),
"bandwidth_down_gb_today": 0,
"bandwidth_up_gb_today": 0,
- "as_of": today.isoformat()
+ "as_of": today.isoformat(),
}
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Failed to fetch usage data: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to fetch usage data: {str(e)}"
+ )
+
@router.get("/usage/monthly")
async def get_monthly_usage(
year: Optional[int] = None,
month: Optional[int] = None,
- current_user: User = Depends(get_current_user)
+ current_user: User = Depends(get_current_user),
) -> UsageResponse:
try:
if year is None or month is None:
@@ -88,71 +98,85 @@ async def get_monthly_usage(
month = now.month
if not (1 <= month <= 12):
- raise HTTPException(status_code=400, detail="Month must be between 1 and 12")
+ raise HTTPException(
+ status_code=400, detail="Month must be between 1 and 12"
+ )
if not (2020 <= year <= 2100):
- raise HTTPException(status_code=400, detail="Year must be between 2020 and 2100")
+ raise HTTPException(
+ status_code=400, detail="Year must be between 2020 and 2100"
+ )
usage = await UsageTracker.get_monthly_usage(current_user, year, month)
- return UsageResponse(
- **usage,
- period=f"{year}-{month:02d}"
- )
+ return UsageResponse(**usage, period=f"{year}-{month:02d}")
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}"
+ )
+
@router.get("/invoices")
async def list_invoices(
- limit: int = 50,
- offset: int = 0,
- current_user: User = Depends(get_current_user)
+ limit: int = 50, offset: int = 0, current_user: User = Depends(get_current_user)
) -> List[InvoiceResponse]:
try:
if limit < 1 or limit > 100:
- raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
+ raise HTTPException(
+ status_code=400, detail="Limit must be between 1 and 100"
+ )
if offset < 0:
raise HTTPException(status_code=400, detail="Offset must be non-negative")
- invoices = await Invoice.filter(user=current_user).order_by("-created_at").offset(offset).limit(limit).all()
+ invoices = (
+ await Invoice.filter(user=current_user)
+ .order_by("-created_at")
+ .offset(offset)
+ .limit(limit)
+ .all()
+ )
result = []
for invoice in invoices:
line_items = await invoice.line_items.all()
- result.append(InvoiceResponse(
- id=invoice.id,
- invoice_number=invoice.invoice_number,
- period_start=invoice.period_start,
- period_end=invoice.period_end,
- subtotal=float(invoice.subtotal),
- tax=float(invoice.tax),
- total=float(invoice.total),
- status=invoice.status,
- due_date=invoice.due_date,
- paid_at=invoice.paid_at,
- line_items=[
- {
- "description": item.description,
- "quantity": float(item.quantity),
- "unit_price": float(item.unit_price),
- "amount": float(item.amount),
- "type": item.item_type
- }
- for item in line_items
- ]
- ))
+ result.append(
+ InvoiceResponse(
+ id=invoice.id,
+ invoice_number=invoice.invoice_number,
+ period_start=invoice.period_start,
+ period_end=invoice.period_end,
+ subtotal=float(invoice.subtotal),
+ tax=float(invoice.tax),
+ total=float(invoice.total),
+ status=invoice.status,
+ due_date=invoice.due_date,
+ paid_at=invoice.paid_at,
+ line_items=[
+ {
+ "description": item.description,
+ "quantity": float(item.quantity),
+ "unit_price": float(item.unit_price),
+ "amount": float(item.amount),
+ "type": item.item_type,
+ }
+ for item in line_items
+ ],
+ )
+ )
return result
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Failed to fetch invoices: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to fetch invoices: {str(e)}"
+ )
+
@router.get("/invoices/{invoice_id}")
async def get_invoice(
- invoice_id: int,
- current_user: User = Depends(get_current_user)
+ invoice_id: int, current_user: User = Depends(get_current_user)
) -> InvoiceResponse:
invoice = await Invoice.get_or_none(id=invoice_id, user=current_user)
if not invoice:
@@ -177,21 +201,22 @@ async def get_invoice(
"quantity": float(item.quantity),
"unit_price": float(item.unit_price),
"amount": float(item.amount),
- "type": item.item_type
+ "type": item.item_type,
}
for item in line_items
- ]
+ ],
)
+
@router.get("/subscription")
-async def get_subscription(current_user: User = Depends(get_current_user)) -> SubscriptionResponse:
+async def get_subscription(
+ current_user: User = Depends(get_current_user),
+) -> SubscriptionResponse:
subscription = await UserSubscription.get_or_none(user=current_user)
if not subscription:
subscription = await UserSubscription.create(
- user=current_user,
- billing_type="pay_as_you_go",
- status="active"
+ user=current_user, billing_type="pay_as_you_go", status="active"
)
plan_name = None
@@ -205,15 +230,19 @@ async def get_subscription(current_user: User = Depends(get_current_user)) -> Su
plan_name=plan_name,
status=subscription.status,
current_period_start=subscription.current_period_start,
- current_period_end=subscription.current_period_end
+ current_period_end=subscription.current_period_end,
)
+
@router.post("/payment-methods/setup-intent")
async def create_setup_intent(current_user: User = Depends(get_current_user)):
try:
from ..settings import settings
+
if not settings.STRIPE_SECRET_KEY:
- raise HTTPException(status_code=503, detail="Payment processing not configured")
+ raise HTTPException(
+ status_code=503, detail="Payment processing not configured"
+ )
subscription = await UserSubscription.get_or_none(user=current_user)
@@ -221,7 +250,7 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
customer_id = await StripeClient.create_customer(
email=current_user.email,
name=current_user.username,
- metadata={"user_id": str(current_user.id)}
+ metadata={"user_id": str(current_user.id)},
)
if not subscription:
@@ -229,27 +258,30 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
user=current_user,
billing_type="pay_as_you_go",
stripe_customer_id=customer_id,
- status="active"
+ status="active",
)
else:
subscription.stripe_customer_id = customer_id
await subscription.save()
import stripe
+
StripeClient._ensure_api_key()
setup_intent = stripe.SetupIntent.create(
- customer=subscription.stripe_customer_id,
- payment_method_types=["card"]
+ customer=subscription.stripe_customer_id, payment_method_types=["card"]
)
return {
"client_secret": setup_intent.client_secret,
- "customer_id": subscription.stripe_customer_id
+ "customer_id": subscription.stripe_customer_id,
}
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Failed to create setup intent: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to create setup intent: {str(e)}"
+ )
+
@router.get("/payment-methods")
async def list_payment_methods(current_user: User = Depends(get_current_user)):
@@ -262,11 +294,12 @@ async def list_payment_methods(current_user: User = Depends(get_current_user)):
"brand": m.brand,
"exp_month": m.exp_month,
"exp_year": m.exp_year,
- "is_default": m.is_default
+ "is_default": m.is_default,
}
for m in methods
]
+
@router.post("/webhooks/stripe")
async def stripe_webhook(request: Request):
import stripe
@@ -298,12 +331,14 @@ async def stripe_webhook(request: Request):
event_type=event["type"],
stripe_event_id=event_id,
data=event["data"],
- processed=False
+ processed=False,
)
if event["type"] == "invoice.payment_succeeded":
invoice_data = event["data"]["object"]
- mywebdav_invoice_id = invoice_data.get("metadata", {}).get("mywebdav_invoice_id")
+ mywebdav_invoice_id = invoice_data.get("metadata", {}).get(
+ "mywebdav_invoice_id"
+ )
if mywebdav_invoice_id:
invoice = await Invoice.get_or_none(id=int(mywebdav_invoice_id))
@@ -317,7 +352,9 @@ async def stripe_webhook(request: Request):
payment_method = event["data"]["object"]
customer_id = payment_method["customer"]
- subscription = await UserSubscription.get_or_none(stripe_customer_id=customer_id)
+ subscription = await UserSubscription.get_or_none(
+ stripe_customer_id=customer_id
+ )
if subscription:
await PaymentMethod.create(
user=subscription.user,
@@ -327,7 +364,7 @@ async def stripe_webhook(request: Request):
brand=payment_method.get("card", {}).get("brand"),
exp_month=payment_method.get("card", {}).get("exp_month"),
exp_year=payment_method.get("card", {}).get("exp_year"),
- is_default=True
+ is_default=True,
)
await BillingEvent.filter(stripe_event_id=event_id).update(processed=True)
@@ -335,7 +372,10 @@ async def stripe_webhook(request: Request):
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=500, detail=f"Webhook processing failed: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Webhook processing failed: {str(e)}"
+ )
+
@router.get("/pricing")
async def get_pricing():
@@ -344,11 +384,12 @@ async def get_pricing():
config.config_key: {
"value": float(config.config_value),
"description": config.description,
- "unit": config.unit
+ "unit": config.unit,
}
for config in configs
}
+
@router.get("/plans")
async def list_plans():
plans = await SubscriptionPlan.filter(is_active=True).all()
@@ -361,14 +402,16 @@ async def list_plans():
"storage_gb": plan.storage_gb,
"bandwidth_gb": plan.bandwidth_gb,
"price_monthly": float(plan.price_monthly),
- "price_yearly": float(plan.price_yearly) if plan.price_yearly else None
+ "price_yearly": float(plan.price_yearly) if plan.price_yearly else None,
}
for plan in plans
]
+
@router.get("/stripe-key")
async def get_stripe_key():
from ..settings import settings
+
if not settings.STRIPE_PUBLISHABLE_KEY:
raise HTTPException(status_code=503, detail="Payment processing not configured")
return {"publishable_key": settings.STRIPE_PUBLISHABLE_KEY}
diff --git a/mywebdav/routers/files.py b/mywebdav/routers/files.py
index 626be85..b0c5348 100644
--- a/mywebdav/routers/files.py
+++ b/mywebdav/routers/files.py
@@ -1,4 +1,12 @@
-from fastapi import APIRouter, Depends, UploadFile, File as FastAPIFile, HTTPException, status, Response, Form
+from fastapi import (
+ APIRouter,
+ Depends,
+ UploadFile,
+ File as FastAPIFile,
+ HTTPException,
+ status,
+ Form,
+)
from fastapi.responses import StreamingResponse
from typing import List, Optional
import mimetypes
@@ -11,7 +19,6 @@ from ..auth import get_current_user
from ..models import User, File, Folder
from ..schemas import FileOut
from ..storage import storage_manager
-from ..settings import settings
from ..activity import log_activity
from ..thumbnails import generate_thumbnail, delete_thumbnail
@@ -20,35 +27,46 @@ router = APIRouter(
tags=["files"],
)
+
class FileMove(BaseModel):
target_folder_id: Optional[int] = None
+
class FileRename(BaseModel):
new_name: str
+
class FileCopy(BaseModel):
target_folder_id: Optional[int] = None
+
class BatchFileOperation(BaseModel):
file_ids: List[int]
- operation: str # e.g., "delete", "star", "unstar", "move", "copy"
+ operation: str # e.g., "delete", "star", "unstar", "move", "copy"
+
class BatchMoveCopyPayload(BaseModel):
target_folder_id: Optional[int] = None
+
class FileContentUpdate(BaseModel):
content: str
+
@router.post("/upload", response_model=FileOut, status_code=status.HTTP_201_CREATED)
async def upload_file(
file: UploadFile = FastAPIFile(...),
folder_id: Optional[int] = Form(None),
- current_user: User = Depends(get_current_user)
+ current_user: User = Depends(get_current_user),
):
if folder_id:
- parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
+ parent_folder = await Folder.get_or_none(
+ id=folder_id, owner=current_user, is_deleted=False
+ )
if not parent_folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
+ )
else:
parent_folder = None
@@ -73,7 +91,7 @@ async def upload_file(
# Generate a unique path for storage
file_extension = os.path.splitext(file.filename)[1]
- unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename
+ unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename
storage_path = unique_filename
# Save file to storage
@@ -105,16 +123,20 @@ async def upload_file(
return await FileOut.from_tortoise_orm(db_file)
+
@router.get("/download/{file_id}")
async def download_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
db_file.last_accessed_at = datetime.now()
await db_file.save()
try:
+
async def file_iterator():
async for chunk in storage_manager.get_file(current_user.id, db_file.path):
yield chunk
@@ -122,16 +144,21 @@ async def download_file(file_id: int, current_user: User = Depends(get_current_u
return StreamingResponse(
file_iterator(),
media_type=db_file.mime_type,
- headers={"Content-Disposition": f"attachment; filename=\"{db_file.name}\""}
+ headers={"Content-Disposition": f'attachment; filename="{db_file.name}"'},
)
except FileNotFoundError:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage"
+ )
+
@router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
db_file.is_deleted = True
db_file.deleted_at = datetime.now()
@@ -141,68 +168,110 @@ async def delete_file(file_id: int, current_user: User = Depends(get_current_use
return
+
@router.post("/{file_id}/move", response_model=FileOut)
-async def move_file(file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)):
+async def move_file(
+ file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)
+):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
target_folder = None
if move_data.target_folder_id:
- target_folder = await Folder.get_or_none(id=move_data.target_folder_id, owner=current_user, is_deleted=False)
+ target_folder = await Folder.get_or_none(
+ id=move_data.target_folder_id, owner=current_user, is_deleted=False
+ )
if not target_folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
+ )
existing_file = await File.get_or_none(
name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False
)
if existing_file and existing_file.id != file_id:
- raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in target folder")
+ raise HTTPException(
+ status_code=status.HTTP_409_CONFLICT,
+ detail="File with this name already exists in target folder",
+ )
db_file.parent = target_folder
await db_file.save()
- await log_activity(user=current_user, action="file_moved", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user, action="file_moved", target_type="file", target_id=file_id
+ )
return await FileOut.from_tortoise_orm(db_file)
+
@router.post("/{file_id}/rename", response_model=FileOut)
-async def rename_file(file_id: int, rename_data: FileRename, current_user: User = Depends(get_current_user)):
+async def rename_file(
+ file_id: int,
+ rename_data: FileRename,
+ current_user: User = Depends(get_current_user),
+):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
existing_file = await File.get_or_none(
- name=rename_data.new_name, parent_id=db_file.parent_id, owner=current_user, is_deleted=False
+ name=rename_data.new_name,
+ parent_id=db_file.parent_id,
+ owner=current_user,
+ is_deleted=False,
)
if existing_file and existing_file.id != file_id:
- raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in the same folder")
+ raise HTTPException(
+ status_code=status.HTTP_409_CONFLICT,
+ detail="File with this name already exists in the same folder",
+ )
db_file.name = rename_data.new_name
await db_file.save()
- await log_activity(user=current_user, action="file_renamed", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user, action="file_renamed", target_type="file", target_id=file_id
+ )
return await FileOut.from_tortoise_orm(db_file)
-@router.post("/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED)
-async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)):
+
+@router.post(
+ "/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED
+)
+async def copy_file(
+ file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)
+):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
target_folder = None
if copy_data.target_folder_id:
- target_folder = await Folder.get_or_none(id=copy_data.target_folder_id, owner=current_user, is_deleted=False)
+ target_folder = await Folder.get_or_none(
+ id=copy_data.target_folder_id, owner=current_user, is_deleted=False
+ )
if not target_folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
+ )
base_name = db_file.name
name_parts = os.path.splitext(base_name)
counter = 1
new_name = base_name
- while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False):
+ while await File.get_or_none(
+ name=new_name, parent=target_folder, owner=current_user, is_deleted=False
+ ):
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
counter += 1
@@ -213,139 +282,205 @@ async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depe
mime_type=db_file.mime_type,
file_hash=db_file.file_hash,
owner=current_user,
- parent=target_folder
+ parent=target_folder,
)
- await log_activity(user=current_user, action="file_copied", target_type="file", target_id=new_file.id)
+ await log_activity(
+ user=current_user,
+ action="file_copied",
+ target_type="file",
+ target_id=new_file.id,
+ )
return await FileOut.from_tortoise_orm(new_file)
+
@router.get("/", response_model=List[FileOut])
-async def list_files(folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)):
+async def list_files(
+ folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)
+):
if folder_id:
- parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
+ parent_folder = await Folder.get_or_none(
+ id=folder_id, owner=current_user, is_deleted=False
+ )
if not parent_folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
- files = await File.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
+ )
+ files = await File.filter(
+ parent=parent_folder, owner=current_user, is_deleted=False
+ ).order_by("name")
else:
- files = await File.filter(parent=None, owner=current_user, is_deleted=False).order_by("name")
+ files = await File.filter(
+ parent=None, owner=current_user, is_deleted=False
+ ).order_by("name")
return [await FileOut.from_tortoise_orm(f) for f in files]
+
@router.get("/thumbnail/{file_id}")
async def get_thumbnail(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
db_file.last_accessed_at = datetime.now()
await db_file.save()
- thumbnail_path = getattr(db_file, 'thumbnail_path', None)
+ thumbnail_path = getattr(db_file, "thumbnail_path", None)
if not thumbnail_path:
- thumbnail_path = await generate_thumbnail(db_file.path, db_file.mime_type, current_user.id)
+ thumbnail_path = await generate_thumbnail(
+ db_file.path, db_file.mime_type, current_user.id
+ )
if thumbnail_path:
db_file.thumbnail_path = thumbnail_path
await db_file.save()
else:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available"
+ )
try:
+
async def thumbnail_iterator():
- async for chunk in storage_manager.get_file(current_user.id, thumbnail_path):
+ async for chunk in storage_manager.get_file(
+ current_user.id, thumbnail_path
+ ):
yield chunk
- return StreamingResponse(
- thumbnail_iterator(),
- media_type="image/jpeg"
- )
+ return StreamingResponse(thumbnail_iterator(), media_type="image/jpeg")
except FileNotFoundError:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not found in storage")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Thumbnail not found in storage",
+ )
+
@router.get("/photos", response_model=List[FileOut])
async def list_photos(current_user: User = Depends(get_current_user)):
files = await File.filter(
- owner=current_user,
- is_deleted=False,
- mime_type__istartswith="image/"
+ owner=current_user, is_deleted=False, mime_type__istartswith="image/"
).order_by("-created_at")
return [await FileOut.from_tortoise_orm(f) for f in files]
+
@router.get("/recent", response_model=List[FileOut])
-async def list_recent_files(current_user: User = Depends(get_current_user), limit: int = 10):
- files = await File.filter(
- owner=current_user,
- is_deleted=False,
- last_accessed_at__isnull=False
- ).order_by("-last_accessed_at").limit(limit)
+async def list_recent_files(
+ current_user: User = Depends(get_current_user), limit: int = 10
+):
+ files = (
+ await File.filter(
+ owner=current_user, is_deleted=False, last_accessed_at__isnull=False
+ )
+ .order_by("-last_accessed_at")
+ .limit(limit)
+ )
return [await FileOut.from_tortoise_orm(f) for f in files]
+
@router.post("/{file_id}/star", response_model=FileOut)
async def star_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
db_file.is_starred = True
await db_file.save()
- await log_activity(user=current_user, action="file_starred", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user, action="file_starred", target_type="file", target_id=file_id
+ )
return await FileOut.from_tortoise_orm(db_file)
+
@router.post("/{file_id}/unstar", response_model=FileOut)
async def unstar_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
db_file.is_starred = False
await db_file.save()
- await log_activity(user=current_user, action="file_unstarred", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user,
+ action="file_unstarred",
+ target_type="file",
+ target_id=file_id,
+ )
return await FileOut.from_tortoise_orm(db_file)
+
@router.get("/deleted", response_model=List[FileOut])
async def list_deleted_files(current_user: User = Depends(get_current_user)):
- files = await File.filter(owner=current_user, is_deleted=True).order_by("-deleted_at")
+ files = await File.filter(owner=current_user, is_deleted=True).order_by(
+ "-deleted_at"
+ )
return [await FileOut.from_tortoise_orm(f) for f in files]
+
@router.post("/{file_id}/restore", response_model=FileOut)
async def restore_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=True)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found"
+ )
+
# Check if a file with the same name exists in the parent folder
existing_file = await File.get_or_none(
name=db_file.name, parent=db_file.parent, owner=current_user, is_deleted=False
)
if existing_file:
- raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.")
+ raise HTTPException(
+ status_code=status.HTTP_409_CONFLICT,
+ detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.",
+ )
db_file.is_deleted = False
db_file.deleted_at = None
await db_file.save()
- await log_activity(user=current_user, action="file_restored", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user, action="file_restored", target_type="file", target_id=file_id
+ )
return await FileOut.from_tortoise_orm(db_file)
+
class BatchOperationResult(BaseModel):
succeeded: List[FileOut]
failed: List[dict]
+
@router.post("/batch")
async def batch_file_operations(
batch_operation: BatchFileOperation,
payload: Optional[BatchMoveCopyPayload] = None,
- current_user: User = Depends(get_current_user)
+ current_user: User = Depends(get_current_user),
):
if batch_operation.operation not in ["delete", "star", "unstar", "move", "copy"]:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid operation: {batch_operation.operation}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Invalid operation: {batch_operation.operation}",
+ )
updated_files = []
failed_operations = []
for file_id in batch_operation.file_ids:
try:
- db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
+ db_file = await File.get_or_none(
+ id=file_id, owner=current_user, is_deleted=False
+ )
if not db_file:
- failed_operations.append({"file_id": file_id, "reason": "File not found or not owned by user"})
+ failed_operations.append(
+ {
+ "file_id": file_id,
+ "reason": "File not found or not owned by user",
+ }
+ )
continue
if batch_operation.operation == "delete":
@@ -353,47 +488,87 @@ async def batch_file_operations(
db_file.deleted_at = datetime.now()
await db_file.save()
await delete_thumbnail(db_file.id)
- await log_activity(user=current_user, action="file_deleted_batch", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user,
+ action="file_deleted_batch",
+ target_type="file",
+ target_id=file_id,
+ )
updated_files.append(db_file)
elif batch_operation.operation == "star":
db_file.is_starred = True
await db_file.save()
- await log_activity(user=current_user, action="file_starred_batch", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user,
+ action="file_starred_batch",
+ target_type="file",
+ target_id=file_id,
+ )
updated_files.append(db_file)
elif batch_operation.operation == "unstar":
db_file.is_starred = False
await db_file.save()
- await log_activity(user=current_user, action="file_unstarred_batch", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user,
+ action="file_unstarred_batch",
+ target_type="file",
+ target_id=file_id,
+ )
updated_files.append(db_file)
elif batch_operation.operation == "move":
if not payload or payload.target_folder_id is None:
- failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"})
+ failed_operations.append(
+ {"file_id": file_id, "reason": "Target folder not specified"}
+ )
continue
- target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
+ target_folder = await Folder.get_or_none(
+ id=payload.target_folder_id, owner=current_user, is_deleted=False
+ )
if not target_folder:
- failed_operations.append({"file_id": file_id, "reason": "Target folder not found"})
+ failed_operations.append(
+ {"file_id": file_id, "reason": "Target folder not found"}
+ )
continue
existing_file = await File.get_or_none(
- name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False
+ name=db_file.name,
+ parent=target_folder,
+ owner=current_user,
+ is_deleted=False,
)
if existing_file and existing_file.id != file_id:
- failed_operations.append({"file_id": file_id, "reason": "File with same name exists in target folder"})
+ failed_operations.append(
+ {
+ "file_id": file_id,
+ "reason": "File with same name exists in target folder",
+ }
+ )
continue
db_file.parent = target_folder
await db_file.save()
- await log_activity(user=current_user, action="file_moved_batch", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user,
+ action="file_moved_batch",
+ target_type="file",
+ target_id=file_id,
+ )
updated_files.append(db_file)
elif batch_operation.operation == "copy":
if not payload or payload.target_folder_id is None:
- failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"})
+ failed_operations.append(
+ {"file_id": file_id, "reason": "Target folder not specified"}
+ )
continue
- target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
+ target_folder = await Folder.get_or_none(
+ id=payload.target_folder_id, owner=current_user, is_deleted=False
+ )
if not target_folder:
- failed_operations.append({"file_id": file_id, "reason": "Target folder not found"})
+ failed_operations.append(
+ {"file_id": file_id, "reason": "Target folder not found"}
+ )
continue
base_name = db_file.name
@@ -401,7 +576,12 @@ async def batch_file_operations(
counter = 1
new_name = base_name
- while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False):
+ while await File.get_or_none(
+ name=new_name,
+ parent=target_folder,
+ owner=current_user,
+ is_deleted=False,
+ ):
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
counter += 1
@@ -412,48 +592,70 @@ async def batch_file_operations(
mime_type=db_file.mime_type,
file_hash=db_file.file_hash,
owner=current_user,
- parent=target_folder
+ parent=target_folder,
+ )
+ await log_activity(
+ user=current_user,
+ action="file_copied_batch",
+ target_type="file",
+ target_id=new_file.id,
)
- await log_activity(user=current_user, action="file_copied_batch", target_type="file", target_id=new_file.id)
updated_files.append(new_file)
except Exception as e:
failed_operations.append({"file_id": file_id, "reason": str(e)})
return {
"succeeded": [await FileOut.from_tortoise_orm(f) for f in updated_files],
- "failed": failed_operations
+ "failed": failed_operations,
}
+
@router.put("/{file_id}/content", response_model=FileOut)
async def update_file_content(
file_id: int,
payload: FileContentUpdate,
- current_user: User = Depends(get_current_user)
+ current_user: User = Depends(get_current_user),
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
- if not db_file.mime_type or not db_file.mime_type.startswith('text/'):
+ if not db_file.mime_type or not db_file.mime_type.startswith("text/"):
editableExtensions = [
- 'txt', 'md', 'log', 'json', 'js', 'py', 'html', 'css',
- 'xml', 'yaml', 'yml', 'sh', 'bat', 'ini', 'conf', 'cfg'
+ "txt",
+ "md",
+ "log",
+ "json",
+ "js",
+ "py",
+ "html",
+ "css",
+ "xml",
+ "yaml",
+ "yml",
+ "sh",
+ "bat",
+ "ini",
+ "conf",
+ "cfg",
]
file_extension = os.path.splitext(db_file.name)[1][1:].lower()
if file_extension not in editableExtensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
- detail="File type is not editable"
+ detail="File type is not editable",
)
- content_bytes = payload.content.encode('utf-8')
+ content_bytes = payload.content.encode("utf-8")
new_size = len(content_bytes)
size_diff = new_size - db_file.size
if current_user.used_storage_bytes + size_diff > current_user.storage_quota_bytes:
raise HTTPException(
status_code=status.HTTP_507_INSUFFICIENT_STORAGE,
- detail="Storage quota exceeded"
+ detail="Storage quota exceeded",
)
new_hash = hashlib.sha256(content_bytes).hexdigest()
@@ -465,7 +667,7 @@ async def update_file_content(
if new_storage_path != db_file.path:
try:
await storage_manager.delete_file(current_user.id, db_file.path)
- except:
+ except Exception:
pass
db_file.path = new_storage_path
@@ -477,6 +679,8 @@ async def update_file_content(
current_user.used_storage_bytes += size_diff
await current_user.save()
- await log_activity(user=current_user, action="file_updated", target_type="file", target_id=file_id)
+ await log_activity(
+ user=current_user, action="file_updated", target_type="file", target_id=file_id
+ )
return await FileOut.from_tortoise_orm(db_file)
diff --git a/mywebdav/routers/folders.py b/mywebdav/routers/folders.py
index a74fce2..b2316cf 100644
--- a/mywebdav/routers/folders.py
+++ b/mywebdav/routers/folders.py
@@ -3,7 +3,13 @@ from typing import List, Optional
from ..auth import get_current_user
from ..models import User, Folder
-from ..schemas import FolderCreate, FolderOut, FolderUpdate, BatchFolderOperation, BatchMoveCopyPayload
+from ..schemas import (
+ FolderCreate,
+ FolderOut,
+ FolderUpdate,
+ BatchFolderOperation,
+ BatchMoveCopyPayload,
+)
from ..activity import log_activity
router = APIRouter(
@@ -11,12 +17,17 @@ router = APIRouter(
tags=["folders"],
)
+
@router.post("/", response_model=FolderOut, status_code=status.HTTP_201_CREATED)
-async def create_folder(folder_in: FolderCreate, current_user: User = Depends(get_current_user)):
+async def create_folder(
+ folder_in: FolderCreate, current_user: User = Depends(get_current_user)
+):
# Check if parent folder exists and belongs to the current user
parent_folder = None
if folder_in.parent_id:
- parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user)
+ parent_folder = await Folder.get_or_none(
+ id=folder_in.parent_id, owner=current_user
+ )
if not parent_folder:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -39,50 +50,82 @@ async def create_folder(folder_in: FolderCreate, current_user: User = Depends(ge
await log_activity(current_user, "folder_created", "folder", folder.id)
return await FolderOut.from_tortoise_orm(folder)
+
@router.get("/{folder_id}/path", response_model=List[FolderOut])
-async def get_folder_path(folder_id: int, current_user: User = Depends(get_current_user)):
- folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
+async def get_folder_path(
+ folder_id: int, current_user: User = Depends(get_current_user)
+):
+ folder = await Folder.get_or_none(
+ id=folder_id, owner=current_user, is_deleted=False
+ )
if not folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
+ )
path = []
current = folder
while current:
path.insert(0, await FolderOut.from_tortoise_orm(current))
if current.parent_id:
- current = await Folder.get_or_none(id=current.parent_id, owner=current_user, is_deleted=False)
+ current = await Folder.get_or_none(
+ id=current.parent_id, owner=current_user, is_deleted=False
+ )
else:
current = None
return path
+
@router.get("/{folder_id}", response_model=FolderOut)
async def get_folder(folder_id: int, current_user: User = Depends(get_current_user)):
- folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
+ folder = await Folder.get_or_none(
+ id=folder_id, owner=current_user, is_deleted=False
+ )
if not folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
+ )
return await FolderOut.from_tortoise_orm(folder)
+
@router.get("/", response_model=List[FolderOut])
-async def list_folders(parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)):
+async def list_folders(
+ parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)
+):
if parent_id:
- parent_folder = await Folder.get_or_none(id=parent_id, owner=current_user, is_deleted=False)
+ parent_folder = await Folder.get_or_none(
+ id=parent_id, owner=current_user, is_deleted=False
+ )
if not parent_folder:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Parent folder not found or does not belong to the current user",
)
- folders = await Folder.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name")
+ folders = await Folder.filter(
+ parent=parent_folder, owner=current_user, is_deleted=False
+ ).order_by("name")
else:
# List root folders (folders with no parent)
- folders = await Folder.filter(parent=None, owner=current_user, is_deleted=False).order_by("name")
+ folders = await Folder.filter(
+ parent=None, owner=current_user, is_deleted=False
+ ).order_by("name")
return [await FolderOut.from_tortoise_orm(folder) for folder in folders]
+
@router.put("/{folder_id}", response_model=FolderOut)
-async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: User = Depends(get_current_user)):
- folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
+async def update_folder(
+ folder_id: int,
+ folder_in: FolderUpdate,
+ current_user: User = Depends(get_current_user),
+):
+ folder = await Folder.get_or_none(
+ id=folder_id, owner=current_user, is_deleted=False
+ )
if not folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
+ )
if folder_in.name:
existing_folder = await Folder.get_or_none(
@@ -97,11 +140,16 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U
if folder_in.parent_id is not None:
if folder_in.parent_id == folder_id:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot set folder as its own parent")
-
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Cannot set folder as its own parent",
+ )
+
new_parent_folder = None
- if folder_in.parent_id != 0: # 0 could represent moving to root
- new_parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user, is_deleted=False)
+ if folder_in.parent_id != 0: # 0 could represent moving to root
+ new_parent_folder = await Folder.get_or_none(
+ id=folder_in.parent_id, owner=current_user, is_deleted=False
+ )
if not new_parent_folder:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -113,48 +161,66 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U
await log_activity(current_user, "folder_updated", "folder", folder.id)
return await FolderOut.from_tortoise_orm(folder)
+
@router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_folder(folder_id: int, current_user: User = Depends(get_current_user)):
- folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
+ folder = await Folder.get_or_none(
+ id=folder_id, owner=current_user, is_deleted=False
+ )
if not folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
+ )
+
folder.is_deleted = True
await folder.save()
await log_activity(current_user, "folder_deleted", "folder", folder.id)
return
+
@router.post("/{folder_id}/star", response_model=FolderOut)
async def star_folder(folder_id: int, current_user: User = Depends(get_current_user)):
- db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
+ db_folder = await Folder.get_or_none(
+ id=folder_id, owner=current_user, is_deleted=False
+ )
if not db_folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
+ )
db_folder.is_starred = True
await db_folder.save()
await log_activity(current_user, "folder_starred", "folder", folder_id)
return await FolderOut.from_tortoise_orm(db_folder)
+
@router.post("/{folder_id}/unstar", response_model=FolderOut)
async def unstar_folder(folder_id: int, current_user: User = Depends(get_current_user)):
- db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
+ db_folder = await Folder.get_or_none(
+ id=folder_id, owner=current_user, is_deleted=False
+ )
if not db_folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
+ )
db_folder.is_starred = False
await db_folder.save()
await log_activity(current_user, "folder_unstarred", "folder", folder_id)
return await FolderOut.from_tortoise_orm(db_folder)
+
@router.post("/batch", response_model=List[FolderOut])
async def batch_folder_operations(
batch_operation: BatchFolderOperation,
payload: Optional[BatchMoveCopyPayload] = None,
- current_user: User = Depends(get_current_user)
+ current_user: User = Depends(get_current_user),
):
updated_folders = []
for folder_id in batch_operation.folder_ids:
- db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
+ db_folder = await Folder.get_or_none(
+ id=folder_id, owner=current_user, is_deleted=False
+ )
if not db_folder:
- continue # Skip if folder not found or not owned by user
+ continue # Skip if folder not found or not owned by user
if batch_operation.operation == "delete":
db_folder.is_deleted = True
@@ -171,13 +237,22 @@ async def batch_folder_operations(
await db_folder.save()
await log_activity(current_user, "folder_unstarred", "folder", folder_id)
updated_folders.append(db_folder)
- elif batch_operation.operation == "move" and payload and payload.target_folder_id is not None:
- target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
+ elif (
+ batch_operation.operation == "move"
+ and payload
+ and payload.target_folder_id is not None
+ ):
+ target_folder = await Folder.get_or_none(
+ id=payload.target_folder_id, owner=current_user, is_deleted=False
+ )
if not target_folder:
continue
existing_folder = await Folder.get_or_none(
- name=db_folder.name, parent=target_folder, owner=current_user, is_deleted=False
+ name=db_folder.name,
+ parent=target_folder,
+ owner=current_user,
+ is_deleted=False,
)
if existing_folder and existing_folder.id != folder_id:
continue
@@ -186,5 +261,5 @@ async def batch_folder_operations(
await db_folder.save()
await log_activity(current_user, "folder_moved", "folder", folder_id)
updated_folders.append(db_folder)
-
+
return [await FolderOut.from_tortoise_orm(f) for f in updated_folders]
diff --git a/mywebdav/routers/search.py b/mywebdav/routers/search.py
index 8c68c62..77c4cc0 100644
--- a/mywebdav/routers/search.py
+++ b/mywebdav/routers/search.py
@@ -11,15 +11,22 @@ router = APIRouter(
tags=["search"],
)
+
@router.get("/files", response_model=List[FileOut])
async def search_files(
q: str = Query(..., min_length=1, description="Search query"),
- file_type: Optional[str] = Query(None, description="Filter by MIME type prefix (e.g., 'image', 'video')"),
+ file_type: Optional[str] = Query(
+ None, description="Filter by MIME type prefix (e.g., 'image', 'video')"
+ ),
min_size: Optional[int] = Query(None, description="Minimum file size in bytes"),
max_size: Optional[int] = Query(None, description="Maximum file size in bytes"),
- date_from: Optional[datetime] = Query(None, description="Filter files created after this date"),
- date_to: Optional[datetime] = Query(None, description="Filter files created before this date"),
- current_user: User = Depends(get_current_user)
+ date_from: Optional[datetime] = Query(
+ None, description="Filter files created after this date"
+ ),
+ date_to: Optional[datetime] = Query(
+ None, description="Filter files created before this date"
+ ),
+ current_user: User = Depends(get_current_user),
):
query = File.filter(owner=current_user, is_deleted=False, name__icontains=q)
@@ -41,15 +48,16 @@ async def search_files(
files = await query.order_by("-created_at").limit(100)
return [await FileOut.from_tortoise_orm(f) for f in files]
+
@router.get("/folders", response_model=List[FolderOut])
async def search_folders(
q: str = Query(..., min_length=1, description="Search query"),
- current_user: User = Depends(get_current_user)
+ current_user: User = Depends(get_current_user),
):
- folders = await Folder.filter(
- owner=current_user,
- is_deleted=False,
- name__icontains=q
- ).order_by("-created_at").limit(100)
+ folders = (
+ await Folder.filter(owner=current_user, is_deleted=False, name__icontains=q)
+ .order_by("-created_at")
+ .limit(100)
+ )
return [await FolderOut.from_tortoise_orm(folder) for folder in folders]
diff --git a/mywebdav/routers/shares.py b/mywebdav/routers/shares.py
index 7c37dff..92c2f51 100644
--- a/mywebdav/routers/shares.py
+++ b/mywebdav/routers/shares.py
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from typing import Optional, List
import secrets
-from datetime import datetime, timedelta
+from datetime import datetime
from ..auth import get_current_user
from ..models import User, File, Folder, Share
@@ -17,25 +17,44 @@ router = APIRouter(
tags=["shares"],
)
+
@router.post("/", response_model=ShareOut, status_code=status.HTTP_201_CREATED)
-async def create_share_link(share_in: ShareCreate, current_user: User = Depends(get_current_user)):
+async def create_share_link(
+ share_in: ShareCreate, current_user: User = Depends(get_current_user)
+):
if not share_in.file_id and not share_in.folder_id:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Either file_id or folder_id must be provided")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Either file_id or folder_id must be provided",
+ )
if share_in.file_id and share_in.folder_id:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot share both a file and a folder simultaneously")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Cannot share both a file and a folder simultaneously",
+ )
file = None
folder = None
if share_in.file_id:
- file = await File.get_or_none(id=share_in.file_id, owner=current_user, is_deleted=False)
+ file = await File.get_or_none(
+ id=share_in.file_id, owner=current_user, is_deleted=False
+ )
if not file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found or does not belong to you")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="File not found or does not belong to you",
+ )
+
if share_in.folder_id:
- folder = await Folder.get_or_none(id=share_in.folder_id, owner=current_user, is_deleted=False)
+ folder = await Folder.get_or_none(
+ id=share_in.folder_id, owner=current_user, is_deleted=False
+ )
if not folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found or does not belong to you")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Folder not found or does not belong to you",
+ )
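+    # The share link is addressed by a random URL-safe token rather than a guessable numeric id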
token = secrets.token_urlsafe(16)
hashed_password = None
@@ -60,8 +79,14 @@ async def create_share_link(share_in: ShareCreate, current_user: User = Depends(
item_type = "file" if file else "folder"
item_name = file.name if file else folder.name
- expiry_text = f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}" if share_in.expires_at else ""
- password_text = f"\n\nPassword: {share_in.password}" if share_in.password else ""
+ expiry_text = (
+ f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}"
+ if share_in.expires_at
+ else ""
+ )
+ password_text = (
+ f"\n\nPassword: {share_in.password}" if share_in.password else ""
+ )
email_body = f"""Hello,
@@ -95,80 +120,103 @@ MyWebdav File Sharing Service"""
to_email=share_in.invite_email,
subject=f"{current_user.username} shared {item_name} with you",
body=email_body,
- html=email_html
+ html=email_html,
)
except Exception as e:
print(f"Failed to send invitation email: {e}")
return await ShareOut.from_tortoise_orm(share)
+
@router.get("/my", response_model=List[ShareOut])
async def list_my_shares(current_user: User = Depends(get_current_user)):
shares = await Share.filter(owner=current_user).order_by("-created_at")
return [await ShareOut.from_tortoise_orm(share) for share in shares]
+
@router.get("/{share_token}", response_model=ShareOut)
async def get_share_link_info(share_token: str):
share = await Share.get_or_none(token=share_token)
if not share:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
+ )
+
if share.expires_at and share.expires_at < datetime.utcnow():
- raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
-
+ raise HTTPException(
+ status_code=status.HTTP_410_GONE, detail="Share link has expired"
+ )
+
# Increment access count
share.access_count += 1
await share.save()
return await ShareOut.from_tortoise_orm(share)
+
@router.put("/{share_id}", response_model=ShareOut)
-async def update_share(share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)):
+async def update_share(
+ share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)
+):
share = await Share.get_or_none(id=share_id, owner=current_user)
if not share:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Share link not found or does not belong to you",
+ )
if share_in.expires_at is not None:
share.expires_at = share_in.expires_at
-
+
if share_in.password is not None:
share.hashed_password = get_password_hash(share_in.password)
share.password_protected = True
- elif share_in.password == "": # Allow clearing password
+ elif share_in.password == "": # Allow clearing password
share.hashed_password = None
share.password_protected = False
if share_in.permission_level is not None:
share.permission_level = share_in.permission_level
-
+
await share.save()
return await ShareOut.from_tortoise_orm(share)
+
@router.post("/{share_token}/access")
async def access_shared_content(share_token: str, password: Optional[str] = None):
share = await Share.get_or_none(token=share_token)
if not share:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
+ )
if share.expires_at and share.expires_at < datetime.utcnow():
- raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
+ raise HTTPException(
+ status_code=status.HTTP_410_GONE, detail="Share link has expired"
+ )
if share.password_protected:
if not password or not verify_password(password, share.hashed_password):
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password")
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
+ )
result = {"message": "Access granted", "permission_level": share.permission_level}
if share.file_id:
file = await File.get_or_none(id=share.file_id, is_deleted=False)
if not file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
result["file"] = await FileOut.from_tortoise_orm(file)
result["type"] = "file"
elif share.folder_id:
folder = await Folder.get_or_none(id=share.folder_id, is_deleted=False)
if not folder:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
+ )
result["folder"] = await FolderOut.from_tortoise_orm(folder)
result["type"] = "folder"
@@ -179,29 +227,42 @@ async def access_shared_content(share_token: str, password: Optional[str] = None
return result
+
@router.get("/{share_token}/download")
async def download_shared_file(share_token: str, password: Optional[str] = None):
share = await Share.get_or_none(token=share_token)
if not share:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
+ )
if share.expires_at and share.expires_at < datetime.utcnow():
- raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
+ raise HTTPException(
+ status_code=status.HTTP_410_GONE, detail="Share link has expired"
+ )
if share.password_protected:
if not password or not verify_password(password, share.hashed_password):
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password")
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
+ )
if not share.file_id:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="This share is not for a file")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="This share is not for a file",
+ )
file = await File.get_or_none(id=share.file_id, is_deleted=False)
if not file:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
owner = await User.get(id=file.owner_id)
try:
+
async def file_iterator():
async for chunk in storage_manager.get_file(owner.id, file.path):
yield chunk
@@ -209,18 +270,22 @@ async def download_shared_file(share_token: str, password: Optional[str] = None)
return StreamingResponse(
content=file_iterator(),
media_type=file.mime_type,
- headers={
- "Content-Disposition": f'attachment; filename="{file.name}"'
- }
+ headers={"Content-Disposition": f'attachment; filename="{file.name}"'},
)
except FileNotFoundError:
raise HTTPException(status_code=404, detail="File not found in storage")
+
@router.delete("/{share_id}", status_code=status.HTTP_204_NO_CONTENT)
-async def delete_share_link(share_id: int, current_user: User = Depends(get_current_user)):
+async def delete_share_link(
+ share_id: int, current_user: User = Depends(get_current_user)
+):
share = await Share.get_or_none(id=share_id, owner=current_user)
if not share:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Share link not found or does not belong to you",
+ )
+
await share.delete()
return
diff --git a/mywebdav/routers/starred.py b/mywebdav/routers/starred.py
index d47a166..a44e31a 100644
--- a/mywebdav/routers/starred.py
+++ b/mywebdav/routers/starred.py
@@ -11,18 +11,29 @@ router = APIRouter(
tags=["starred"],
)
+
@router.get("/files", response_model=List[FileOut])
async def list_starred_files(current_user: User = Depends(get_current_user)):
- files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
+ files = await File.filter(
+ owner=current_user, is_starred=True, is_deleted=False
+ ).order_by("name")
return [await FileOut.from_tortoise_orm(f) for f in files]
+
@router.get("/folders", response_model=List[FolderOut])
async def list_starred_folders(current_user: User = Depends(get_current_user)):
- folders = await Folder.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
+ folders = await Folder.filter(
+ owner=current_user, is_starred=True, is_deleted=False
+ ).order_by("name")
return [await FolderOut.from_tortoise_orm(f) for f in folders]
-@router.get("/all", response_model=List[FileOut]) # This will return files and folders as files for now
+
+@router.get(
+ "/all", response_model=List[FileOut]
+)  # For now this returns starred files only (as FileOut); starred folders are not included
async def list_all_starred(current_user: User = Depends(get_current_user)):
- starred_files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
+ starred_files = await File.filter(
+ owner=current_user, is_starred=True, is_deleted=False
+ ).order_by("name")
# For simplicity, we'll return files only for now. A more complex solution would involve a union or a custom schema.
- return [await FileOut.from_tortoise_orm(f) for f in starred_files]
\ No newline at end of file
+ return [await FileOut.from_tortoise_orm(f) for f in starred_files]
diff --git a/mywebdav/routers/users.py b/mywebdav/routers/users.py
index 9bb2010..ef134a0 100644
--- a/mywebdav/routers/users.py
+++ b/mywebdav/routers/users.py
@@ -1,17 +1,19 @@
from fastapi import APIRouter, Depends
from ..auth import get_current_user
from ..models import User_Pydantic, User, File, Folder
-from typing import List, Dict, Any
+from typing import Dict, Any
router = APIRouter(
prefix="/users",
tags=["users"],
)
+
@router.get("/me", response_model=User_Pydantic)
async def read_users_me(current_user: User = Depends(get_current_user)):
return await User_Pydantic.from_tortoise_orm(current_user)
+
@router.get("/me/export", response_model=Dict[str, Any])
async def export_my_data(current_user: User = Depends(get_current_user)):
"""
@@ -35,6 +37,7 @@ async def export_my_data(current_user: User = Depends(get_current_user)):
# share information, etc., would also be included.
}
+
@router.delete("/me", status_code=204)
async def delete_my_account(current_user: User = Depends(get_current_user)):
"""
@@ -48,5 +51,3 @@ async def delete_my_account(current_user: User = Depends(get_current_user)):
# Finally, delete the user account
await current_user.delete()
return {}
-
-
diff --git a/mywebdav/schemas.py b/mywebdav/schemas.py
index 9eb52b9..db9e83f 100644
--- a/mywebdav/schemas.py
+++ b/mywebdav/schemas.py
@@ -2,19 +2,24 @@ from datetime import datetime
from pydantic import BaseModel, EmailStr, ConfigDict
from typing import Optional, List
from tortoise.contrib.pydantic import pydantic_model_creator
+from mywebdav.models import Folder, File, Share, FileVersion
+
class UserCreate(BaseModel):
username: str
email: EmailStr
password: str
+
class UserLogin(BaseModel):
username: str
password: str
+
class UserLoginWith2FA(UserLogin):
two_factor_code: Optional[str] = None
+
class UserAdminUpdate(BaseModel):
username: Optional[str] = None
email: Optional[EmailStr] = None
@@ -25,22 +30,27 @@ class UserAdminUpdate(BaseModel):
plan_type: Optional[str] = None
is_2fa_enabled: Optional[bool] = None
+
class Token(BaseModel):
access_token: str
token_type: str
+
class TokenData(BaseModel):
username: str | None = None
two_factor_verified: bool = False
+
class FolderCreate(BaseModel):
name: str
parent_id: Optional[int] = None
+
class FolderUpdate(BaseModel):
name: Optional[str] = None
parent_id: Optional[int] = None
+
class ShareCreate(BaseModel):
file_id: Optional[int] = None
folder_id: Optional[int] = None
@@ -49,9 +59,11 @@ class ShareCreate(BaseModel):
permission_level: str = "viewer"
invite_email: Optional[EmailStr] = None
+
class TeamCreate(BaseModel):
name: str
+
class TeamOut(BaseModel):
id: int
name: str
@@ -60,6 +72,7 @@ class TeamOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
+
class ActivityOut(BaseModel):
id: int
user_id: Optional[int] = None
@@ -71,12 +84,14 @@ class ActivityOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
+
class FileRequestCreate(BaseModel):
title: str
description: Optional[str] = None
target_folder_id: int
expires_at: Optional[datetime] = None
+
class FileRequestOut(BaseModel):
id: int
title: str
@@ -90,25 +105,28 @@ class FileRequestOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
-from mywebdav.models import Folder, File, Share, FileVersion
FolderOut = pydantic_model_creator(Folder, name="FolderOut")
FileOut = pydantic_model_creator(File, name="FileOut")
ShareOut = pydantic_model_creator(Share, name="ShareOut")
FileVersionOut = pydantic_model_creator(FileVersion, name="FileVersionOut")
+
class ErrorResponse(BaseModel):
code: int
message: str
details: Optional[str] = None
+
class BatchFileOperation(BaseModel):
file_ids: List[int]
- operation: str # e.g., "delete", "move", "copy", "star", "unstar"
+ operation: str # e.g., "delete", "move", "copy", "star", "unstar"
+
class BatchFolderOperation(BaseModel):
folder_ids: List[int]
- operation: str # e.g., "delete", "move", "star", "unstar"
+ operation: str # e.g., "delete", "move", "star", "unstar"
+
class BatchMoveCopyPayload(BaseModel):
target_folder_id: Optional[int] = None
diff --git a/mywebdav/settings.py b/mywebdav/settings.py
index 076b4a3..3186ede 100644
--- a/mywebdav/settings.py
+++ b/mywebdav/settings.py
@@ -2,8 +2,9 @@ import os
import sys
from pydantic_settings import BaseSettings, SettingsConfigDict
+
class Settings(BaseSettings):
- model_config = SettingsConfigDict(env_file='.env', extra='ignore')
+ model_config = SettingsConfigDict(env_file=".env", extra="ignore")
DATABASE_URL: str = "sqlite:///app/mywebdav.db"
REDIS_URL: str = "redis://redis:6379/0"
@@ -30,8 +31,14 @@ class Settings(BaseSettings):
STRIPE_WEBHOOK_SECRET: str = ""
BILLING_ENABLED: bool = False
+
settings = Settings()
-if settings.SECRET_KEY == "super_secret_key" and os.getenv("ENVIRONMENT") == "production":
- print("ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable.")
+if (
+ settings.SECRET_KEY == "super_secret_key"
+ and os.getenv("ENVIRONMENT") == "production"
+):
+ print(
+ "ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable."
+ )
sys.exit(1)
diff --git a/mywebdav/storage.py b/mywebdav/storage.py
index 94b8cc9..a13f40f 100644
--- a/mywebdav/storage.py
+++ b/mywebdav/storage.py
@@ -5,6 +5,7 @@ from typing import AsyncGenerator
from .settings import settings
+
class StorageManager:
def __init__(self, base_path: str = settings.STORAGE_PATH):
self.base_path = Path(base_path)
@@ -12,7 +13,11 @@ class StorageManager:
async def _get_full_path(self, user_id: int, file_path: str) -> Path:
# Ensure file_path is relative and safe
- relative_path = Path(file_path).relative_to('/') if str(file_path).startswith('/') else Path(file_path)
+ relative_path = (
+ Path(file_path).relative_to("/")
+ if str(file_path).startswith("/")
+ else Path(file_path)
+ )
full_path = self.base_path / str(user_id) / relative_path
full_path.parent.mkdir(parents=True, exist_ok=True)
return full_path
@@ -27,7 +32,7 @@ class StorageManager:
full_path = await self._get_full_path(user_id, file_path)
if not full_path.exists():
raise FileNotFoundError(f"File not found: {file_path}")
-
+
async with aiofiles.open(full_path, "rb") as f:
while chunk := await f.read(8192):
yield chunk
@@ -52,4 +57,5 @@ class StorageManager:
full_path = await self._get_full_path(user_id, file_path)
return full_path.exists()
+
storage_manager = StorageManager()
diff --git a/mywebdav/thumbnails.py b/mywebdav/thumbnails.py
index 2fd8b8b..f13405c 100644
--- a/mywebdav/thumbnails.py
+++ b/mywebdav/thumbnails.py
@@ -1,4 +1,3 @@
-import os
import asyncio
from pathlib import Path
from PIL import Image
@@ -9,7 +8,10 @@ from .settings import settings
THUMBNAIL_SIZE = (300, 300)
THUMBNAIL_DIR = "thumbnails"
-async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Optional[str]:
+
+async def generate_thumbnail(
+ file_path: str, mime_type: str, user_id: int
+) -> Optional[str]:
try:
if mime_type.startswith("image/"):
return await generate_image_thumbnail(file_path, user_id)
@@ -20,6 +22,7 @@ async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Op
print(f"Error generating thumbnail for {file_path}: {e}")
return None
+
async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str]:
loop = asyncio.get_event_loop()
@@ -30,12 +33,16 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
file_name = Path(file_path).name
thumbnail_name = f"thumb_{file_name}"
- if not thumbnail_name.lower().endswith(('.jpg', '.jpeg', '.png')):
+ if not thumbnail_name.lower().endswith((".jpg", ".jpeg", ".png")):
thumbnail_name += ".jpg"
thumbnail_path = thumbnail_dir / thumbnail_name
- actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path)
+ actual_file_path = (
+ base_path / str(user_id) / file_path
+ if not Path(file_path).is_absolute()
+ else Path(file_path)
+ )
with Image.open(actual_file_path) as img:
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
@@ -44,7 +51,9 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
background = Image.new("RGB", img.size, (255, 255, 255))
if img.mode == "P":
img = img.convert("RGBA")
- background.paste(img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None)
+ background.paste(
+ img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None
+ )
img = background
img.save(str(thumbnail_path), "JPEG", quality=85, optimize=True)
@@ -53,6 +62,7 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
return await loop.run_in_executor(None, _generate)
+
async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str]:
loop = asyncio.get_event_loop()
@@ -65,17 +75,29 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
thumbnail_name = f"thumb_{file_name}.jpg"
thumbnail_path = thumbnail_dir / thumbnail_name
- actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path)
+ actual_file_path = (
+ base_path / str(user_id) / file_path
+ if not Path(file_path).is_absolute()
+ else Path(file_path)
+ )
- subprocess.run([
- "ffmpeg",
- "-i", str(actual_file_path),
- "-ss", "00:00:01",
- "-vframes", "1",
- "-vf", f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease",
- "-y",
- str(thumbnail_path)
- ], check=True, capture_output=True)
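+            # Grab one frame a second in (-ss 00:00:01, -vframes 1), scale it to fit THUMBNAIL_SIZE, and overwrite any existing thumbnail (-y)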
+ subprocess.run(
+ [
+ "ffmpeg",
+ "-i",
+ str(actual_file_path),
+ "-ss",
+ "00:00:01",
+ "-vframes",
+ "1",
+ "-vf",
+ f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease",
+ "-y",
+ str(thumbnail_path),
+ ],
+ check=True,
+ capture_output=True,
+ )
return str(thumbnail_path.relative_to(base_path / str(user_id)))
@@ -84,6 +106,7 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
except subprocess.CalledProcessError:
return None
+
async def delete_thumbnail(thumbnail_path: str, user_id: int):
try:
base_path = Path(settings.STORAGE_PATH)
diff --git a/mywebdav/two_factor.py b/mywebdav/two_factor.py
index f9c3c5a..b71b58a 100644
--- a/mywebdav/two_factor.py
+++ b/mywebdav/two_factor.py
@@ -5,15 +5,20 @@ import base64
import secrets
import hashlib
-from typing import List, Optional
+from typing import List
+
def generate_totp_secret() -> str:
"""Generates a random base32 TOTP secret."""
return pyotp.random_base32()
+
def generate_totp_uri(secret: str, account_name: str, issuer_name: str) -> str:
"""Generates a Google Authenticator-compatible TOTP URI."""
- return pyotp.totp.TOTP(secret).provisioning_uri(name=account_name, issuer_name=issuer_name)
+ return pyotp.totp.TOTP(secret).provisioning_uri(
+ name=account_name, issuer_name=issuer_name
+ )
+
def generate_qr_code_base64(uri: str) -> str:
"""Generates a base64 encoded QR code image for a given URI."""
@@ -31,27 +36,33 @@ def generate_qr_code_base64(uri: str) -> str:
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
def verify_totp_code(secret: str, code: str) -> bool:
"""Verifies a TOTP code against a secret."""
totp = pyotp.TOTP(secret)
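+    # verify() checks the code against the current 30-second time step only; pass valid_window to tolerate clock drift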
return totp.verify(code)
+
def generate_recovery_codes(num_codes: int = 10) -> List[str]:
"""Generates a list of random recovery codes."""
return [secrets.token_urlsafe(16) for _ in range(num_codes)]
+
def hash_recovery_code(code: str) -> str:
"""Hashes a single recovery code using SHA256."""
- return hashlib.sha256(code.encode('utf-8')).hexdigest()
+ return hashlib.sha256(code.encode("utf-8")).hexdigest()
+
def verify_recovery_code(plain_code: str, hashed_code: str) -> bool:
"""Verifies a plain recovery code against its hashed version."""
return hash_recovery_code(plain_code) == hashed_code
+
def hash_recovery_codes(codes: List[str]) -> List[str]:
"""Hashes a list of recovery codes."""
return [hash_recovery_code(code) for code in codes]
+
def verify_recovery_codes(plain_code: str, hashed_codes: List[str]) -> bool:
"""Verifies if a plain recovery code matches any of the hashed recovery codes."""
for hashed_code in hashed_codes:
diff --git a/mywebdav/webdav.py b/mywebdav/webdav.py
index b236917..f9ef5a9 100644
--- a/mywebdav/webdav.py
+++ b/mywebdav/webdav.py
@@ -19,6 +19,7 @@ router = APIRouter(
tags=["webdav"],
)
+
class WebDAVLock:
locks = {}
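+    # In-memory lock table keyed by resource path; each entry stores the token, owning user id, creation time, and timeout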
@@ -26,10 +27,10 @@ class WebDAVLock:
def create_lock(cls, path: str, user_id: int, timeout: int = 3600):
lock_token = f"opaquelocktoken:{hashlib.md5(f'{path}{user_id}{datetime.now()}'.encode()).hexdigest()}"
cls.locks[path] = {
- 'token': lock_token,
- 'user_id': user_id,
- 'created_at': datetime.now(),
- 'timeout': timeout
+ "token": lock_token,
+ "user_id": user_id,
+ "created_at": datetime.now(),
+ "timeout": timeout,
}
return lock_token
@@ -42,17 +43,18 @@ class WebDAVLock:
if path in cls.locks:
del cls.locks[path]
+
async def basic_auth(authorization: Optional[str] = Header(None)):
if not authorization:
return None
try:
scheme, credentials = authorization.split()
- if scheme.lower() != 'basic':
+ if scheme.lower() != "basic":
return None
- decoded = base64.b64decode(credentials).decode('utf-8')
- username, password = decoded.split(':', 1)
+ decoded = base64.b64decode(credentials).decode("utf-8")
+ username, password = decoded.split(":", 1)
user = await User.get_or_none(username=username)
if user and verify_password(password, user.hashed_password):
@@ -62,6 +64,7 @@ async def basic_auth(authorization: Optional[str] = Header(None)):
return None
+
async def webdav_auth(request: Request, authorization: Optional[str] = Header(None)):
user = await basic_auth(authorization)
if user:
@@ -75,14 +78,15 @@ async def webdav_auth(request: Request, authorization: Optional[str] = Header(No
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
- headers={'WWW-Authenticate': 'Basic realm="MyWebdav WebDAV"'}
+ headers={"WWW-Authenticate": 'Basic realm="MyWebdav WebDAV"'},
)
+
async def resolve_path(path_str: str, user: User):
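+    # Resolve a slash-separated path to (resource, parent_folder, exists): resource is a Folder, a File, or None (root or missing leaf); exists is False when an intermediate folder is missing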
- if not path_str or path_str == '/':
+ if not path_str or path_str == "/":
return None, None, True
- parts = [p for p in path_str.split('/') if p]
+ parts = [p for p in path_str.split("/") if p]
if not parts:
return None, None, True
@@ -90,10 +94,7 @@ async def resolve_path(path_str: str, user: User):
current_folder = None
for i, part in enumerate(parts[:-1]):
folder = await Folder.get_or_none(
- name=part,
- parent=current_folder,
- owner=user,
- is_deleted=False
+ name=part, parent=current_folder, owner=user, is_deleted=False
)
if not folder:
return None, None, False
@@ -102,39 +103,37 @@ async def resolve_path(path_str: str, user: User):
last_part = parts[-1]
folder = await Folder.get_or_none(
- name=last_part,
- parent=current_folder,
- owner=user,
- is_deleted=False
+ name=last_part, parent=current_folder, owner=user, is_deleted=False
)
if folder:
return folder, current_folder, True
file = await File.get_or_none(
- name=last_part,
- parent=current_folder,
- owner=user,
- is_deleted=False
+ name=last_part, parent=current_folder, owner=user, is_deleted=False
)
if file:
return file, current_folder, True
return None, current_folder, True
+
def build_href(base_path: str, name: str, is_collection: bool):
path = f"{base_path.rstrip('/')}/{name}"
if is_collection:
- path += '/'
+ path += "/"
return path
+
async def get_custom_properties(resource_type: str, resource_id: int):
props = await WebDAVProperty.filter(
- resource_type=resource_type,
- resource_id=resource_id
+ resource_type=resource_type, resource_id=resource_id
)
return {(prop.namespace, prop.name): prop.value for prop in props}
-def create_propstat_element(props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"):
+
+def create_propstat_element(
+ props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"
+):
propstat = ET.Element("D:propstat")
prop = ET.SubElement(propstat, "D:prop")
@@ -151,10 +150,10 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str
elem.text = value
elif key == "getlastmodified":
elem = ET.SubElement(prop, f"D:{key}")
- elem.text = value.strftime('%a, %d %b %Y %H:%M:%S GMT')
+ elem.text = value.strftime("%a, %d %b %Y %H:%M:%S GMT")
elif key == "creationdate":
elem = ET.SubElement(prop, f"D:{key}")
- elem.text = value.isoformat() + 'Z'
+ elem.text = value.isoformat() + "Z"
elif key == "displayname":
elem = ET.SubElement(prop, f"D:{key}")
elem.text = value
@@ -174,6 +173,7 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str
return propstat
+
def parse_propfind_body(body: bytes):
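+    # Extract the requested properties from a PROPFIND body: a list of (namespace, name) tuples, the "allprop" marker, or None when the body is empty or unparsable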
if not body:
return None
@@ -193,8 +193,8 @@ def parse_propfind_body(body: bytes):
if prop is not None:
requested_props = []
for child in prop:
- ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
- name = child.tag.split('}')[1] if '}' in child.tag else child.tag
+ ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
+ name = child.tag.split("}")[1] if "}" in child.tag else child.tag
requested_props.append((ns, name))
return requested_props
except ET.ParseError:
@@ -202,6 +202,7 @@ def parse_propfind_body(body: bytes):
return None
+
@router.api_route("/{full_path:path}", methods=["OPTIONS"])
async def webdav_options(full_path: str):
return Response(
@@ -209,14 +210,17 @@ async def webdav_options(full_path: str):
headers={
"DAV": "1, 2",
"Allow": "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK",
- "MS-Author-Via": "DAV"
- }
+ "MS-Author-Via": "DAV",
+ },
)
+
@router.api_route("/{full_path:path}", methods=["PROPFIND"])
-async def handle_propfind(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
+async def handle_propfind(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
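+    # Depth header: "0" describes only the resource itself; "1" and "infinity" also list the immediate children (no deeper recursion)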
depth = request.headers.get("Depth", "1")
- full_path = unquote(full_path).strip('/')
+ full_path = unquote(full_path).strip("/")
body = await request.body()
requested_props = parse_propfind_body(body)
@@ -232,19 +236,23 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
if resource is None:
response = ET.SubElement(multistatus, "D:response")
href = ET.SubElement(response, "D:href")
- href.text = base_href if base_href.endswith('/') else base_href + '/'
+ href.text = base_href if base_href.endswith("/") else base_href + "/"
props = {
"resourcetype": "collection",
- "displayname": full_path.split('/')[-1] if full_path else "Root",
+ "displayname": full_path.split("/")[-1] if full_path else "Root",
"creationdate": datetime.now(),
- "getlastmodified": datetime.now()
+ "getlastmodified": datetime.now(),
}
response.append(create_propstat_element(props))
if depth in ["1", "infinity"]:
- folders = await Folder.filter(owner=current_user, parent=parent, is_deleted=False)
- files = await File.filter(owner=current_user, parent=parent, is_deleted=False)
+ folders = await Folder.filter(
+ owner=current_user, parent=parent, is_deleted=False
+ )
+ files = await File.filter(
+ owner=current_user, parent=parent, is_deleted=False
+ )
for folder in folders:
response = ET.SubElement(multistatus, "D:response")
@@ -255,9 +263,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"resourcetype": "collection",
"displayname": folder.name,
"creationdate": folder.created_at,
- "getlastmodified": folder.updated_at
+ "getlastmodified": folder.updated_at,
}
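+                # Dead (non-DAV) properties are looked up only for allprop requests or when a non-DAV: property is explicitly requested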
- custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
+ custom_props = (
+ await get_custom_properties("folder", folder.id)
+ if requested_props == "allprop"
+ or (
+ isinstance(requested_props, list)
+ and any(ns != "DAV:" for ns, _ in requested_props)
+ )
+ else None
+ )
response.append(create_propstat_element(props, custom_props))
for file in files:
@@ -272,28 +288,48 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"getcontenttype": file.mime_type,
"creationdate": file.created_at,
"getlastmodified": file.updated_at,
- "getetag": f'"{file.file_hash}"'
+ "getetag": f'"{file.file_hash}"',
}
- custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
+ custom_props = (
+ await get_custom_properties("file", file.id)
+ if requested_props == "allprop"
+ or (
+ isinstance(requested_props, list)
+ and any(ns != "DAV:" for ns, _ in requested_props)
+ )
+ else None
+ )
response.append(create_propstat_element(props, custom_props))
elif isinstance(resource, Folder):
response = ET.SubElement(multistatus, "D:response")
href = ET.SubElement(response, "D:href")
- href.text = base_href if base_href.endswith('/') else base_href + '/'
+ href.text = base_href if base_href.endswith("/") else base_href + "/"
props = {
"resourcetype": "collection",
"displayname": resource.name,
"creationdate": resource.created_at,
- "getlastmodified": resource.updated_at
+ "getlastmodified": resource.updated_at,
}
- custom_props = await get_custom_properties("folder", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
+ custom_props = (
+ await get_custom_properties("folder", resource.id)
+ if requested_props == "allprop"
+ or (
+ isinstance(requested_props, list)
+ and any(ns != "DAV:" for ns, _ in requested_props)
+ )
+ else None
+ )
response.append(create_propstat_element(props, custom_props))
if depth in ["1", "infinity"]:
- folders = await Folder.filter(owner=current_user, parent=resource, is_deleted=False)
- files = await File.filter(owner=current_user, parent=resource, is_deleted=False)
+ folders = await Folder.filter(
+ owner=current_user, parent=resource, is_deleted=False
+ )
+ files = await File.filter(
+ owner=current_user, parent=resource, is_deleted=False
+ )
for folder in folders:
response = ET.SubElement(multistatus, "D:response")
@@ -304,9 +340,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"resourcetype": "collection",
"displayname": folder.name,
"creationdate": folder.created_at,
- "getlastmodified": folder.updated_at
+ "getlastmodified": folder.updated_at,
}
- custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
+ custom_props = (
+ await get_custom_properties("folder", folder.id)
+ if requested_props == "allprop"
+ or (
+ isinstance(requested_props, list)
+ and any(ns != "DAV:" for ns, _ in requested_props)
+ )
+ else None
+ )
response.append(create_propstat_element(props, custom_props))
for file in files:
@@ -321,9 +365,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"getcontenttype": file.mime_type,
"creationdate": file.created_at,
"getlastmodified": file.updated_at,
- "getetag": f'"{file.file_hash}"'
+ "getetag": f'"{file.file_hash}"',
}
- custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
+ custom_props = (
+ await get_custom_properties("file", file.id)
+ if requested_props == "allprop"
+ or (
+ isinstance(requested_props, list)
+ and any(ns != "DAV:" for ns, _ in requested_props)
+ )
+ else None
+ )
response.append(create_propstat_element(props, custom_props))
elif isinstance(resource, File):
@@ -338,17 +390,32 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"getcontenttype": resource.mime_type,
"creationdate": resource.created_at,
"getlastmodified": resource.updated_at,
- "getetag": f'"{resource.file_hash}"'
+ "getetag": f'"{resource.file_hash}"',
}
- custom_props = await get_custom_properties("file", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
+ custom_props = (
+ await get_custom_properties("file", resource.id)
+ if requested_props == "allprop"
+ or (
+ isinstance(requested_props, list)
+ and any(ns != "DAV:" for ns, _ in requested_props)
+ )
+ else None
+ )
response.append(create_propstat_element(props, custom_props))
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
- return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207)
+ return Response(
+ content=xml_content,
+ media_type="application/xml; charset=utf-8",
+ status_code=207,
+ )
+
@router.api_route("/{full_path:path}", methods=["GET", "HEAD"])
-async def handle_get(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
- full_path = unquote(full_path).strip('/')
+async def handle_get(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
+ full_path = unquote(full_path).strip("/")
resource, parent, exists = await resolve_path(full_path, current_user)
@@ -363,8 +430,10 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe
"Content-Length": str(resource.size),
"Content-Type": resource.mime_type,
"ETag": f'"{resource.file_hash}"',
- "Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT')
- }
+ "Last-Modified": resource.updated_at.strftime(
+ "%a, %d %b %Y %H:%M:%S GMT"
+ ),
+ },
)
async def file_iterator():
@@ -377,23 +446,28 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe
headers={
"Content-Disposition": f'attachment; filename="{resource.name}"',
"ETag": f'"{resource.file_hash}"',
- "Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT')
- }
+ "Last-Modified": resource.updated_at.strftime(
+ "%a, %d %b %Y %H:%M:%S GMT"
+ ),
+ },
)
except FileNotFoundError:
raise HTTPException(status_code=404, detail="File not found in storage")
+
@router.api_route("/{full_path:path}", methods=["PUT"])
-async def handle_put(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
- full_path = unquote(full_path).strip('/')
+async def handle_put(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
+ full_path = unquote(full_path).strip("/")
if not full_path:
raise HTTPException(status_code=400, detail="Cannot PUT to root")
-    parts = [p for p in full_path.split('/') if p]
+    parts = [p for p in full_path.split("/") if p]
     file_name = parts[-1]
-    parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else ''
+    parent_path = "/".join(parts[:-1]) if len(parts) > 1 else ""
_, parent_folder, exists = await resolve_path(parent_path, current_user)
if not exists:
@@ -417,10 +491,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
mime_type = "application/octet-stream"
existing_file = await File.get_or_none(
- name=file_name,
- parent=parent_folder,
- owner=current_user,
- is_deleted=False
+ name=file_name, parent=parent_folder, owner=current_user, is_deleted=False
)
if existing_file:
@@ -432,7 +503,9 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
existing_file.updated_at = datetime.now()
await existing_file.save()
- current_user.used_storage_bytes = current_user.used_storage_bytes - old_size + file_size
+ current_user.used_storage_bytes = (
+ current_user.used_storage_bytes - old_size + file_size
+ )
await current_user.save()
await log_activity(current_user, "file_updated", "file", existing_file.id)
@@ -445,7 +518,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
mime_type=mime_type,
file_hash=file_hash,
owner=current_user,
- parent=parent_folder
+ parent=parent_folder,
)
current_user.used_storage_bytes += file_size
@@ -454,9 +527,12 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
await log_activity(current_user, "file_created", "file", db_file.id)
return Response(status_code=201)
+
@router.api_route("/{full_path:path}", methods=["DELETE"])
-async def handle_delete(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
- full_path = unquote(full_path).strip('/')
+async def handle_delete(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
+ full_path = unquote(full_path).strip("/")
if not full_path:
raise HTTPException(status_code=400, detail="Cannot DELETE root")
@@ -478,44 +554,45 @@ async def handle_delete(request: Request, full_path: str, current_user: User = D
return Response(status_code=204)
+
@router.api_route("/{full_path:path}", methods=["MKCOL"])
-async def handle_mkcol(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
- full_path = unquote(full_path).strip('/')
+async def handle_mkcol(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
+ full_path = unquote(full_path).strip("/")
if not full_path:
raise HTTPException(status_code=400, detail="Cannot MKCOL at root")
-    parts = [p for p in full_path.split('/') if p]
+    parts = [p for p in full_path.split("/") if p]
     folder_name = parts[-1]
-    parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else ''
+    parent_path = "/".join(parts[:-1]) if len(parts) > 1 else ""
_, parent_folder, exists = await resolve_path(parent_path, current_user)
if not exists:
raise HTTPException(status_code=409, detail="Parent folder does not exist")
existing = await Folder.get_or_none(
- name=folder_name,
- parent=parent_folder,
- owner=current_user,
- is_deleted=False
+ name=folder_name, parent=parent_folder, owner=current_user, is_deleted=False
)
if existing:
raise HTTPException(status_code=405, detail="Folder already exists")
folder = await Folder.create(
- name=folder_name,
- parent=parent_folder,
- owner=current_user
+ name=folder_name, parent=parent_folder, owner=current_user
)
await log_activity(current_user, "folder_created", "folder", folder.id)
return Response(status_code=201)
+
@router.api_route("/{full_path:path}", methods=["COPY"])
-async def handle_copy(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
- full_path = unquote(full_path).strip('/')
+async def handle_copy(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
+ full_path = unquote(full_path).strip("/")
destination = request.headers.get("Destination")
overwrite = request.headers.get("Overwrite", "T")
@@ -523,7 +600,7 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
raise HTTPException(status_code=400, detail="Destination header required")
dest_path = unquote(urlparse(destination).path)
- dest_path = dest_path.replace('/webdav/', '').strip('/')
+ dest_path = dest_path.replace("/webdav/", "").strip("/")
source_resource, _, exists = await resolve_path(full_path, current_user)
@@ -533,9 +610,9 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
if not isinstance(source_resource, File):
raise HTTPException(status_code=501, detail="Only file copy is implemented")
- dest_parts = [p for p in dest_path.split('/') if p]
+ dest_parts = [p for p in dest_path.split("/") if p]
dest_name = dest_parts[-1]
- dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else ''
+ dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else ""
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
@@ -543,14 +620,13 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
raise HTTPException(status_code=409, detail="Destination parent does not exist")
existing_dest = await File.get_or_none(
- name=dest_name,
- parent=dest_parent,
- owner=current_user,
- is_deleted=False
+ name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
)
if existing_dest and overwrite == "F":
- raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
+ raise HTTPException(
+ status_code=412, detail="Destination exists and overwrite is false"
+ )
if existing_dest:
await existing_dest.delete()
@@ -562,15 +638,18 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
mime_type=source_resource.mime_type,
file_hash=source_resource.file_hash,
owner=current_user,
- parent=dest_parent
+ parent=dest_parent,
)
await log_activity(current_user, "file_copied", "file", new_file.id)
return Response(status_code=201 if not existing_dest else 204)
+
@router.api_route("/{full_path:path}", methods=["MOVE"])
-async def handle_move(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
- full_path = unquote(full_path).strip('/')
+async def handle_move(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
+ full_path = unquote(full_path).strip("/")
destination = request.headers.get("Destination")
overwrite = request.headers.get("Overwrite", "T")
@@ -578,16 +657,16 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
raise HTTPException(status_code=400, detail="Destination header required")
dest_path = unquote(urlparse(destination).path)
- dest_path = dest_path.replace('/webdav/', '').strip('/')
+ dest_path = dest_path.replace("/webdav/", "").strip("/")
source_resource, _, exists = await resolve_path(full_path, current_user)
if not source_resource:
raise HTTPException(status_code=404, detail="Source not found")
-    dest_parts = [p for p in dest_path.split('/') if p]
+    dest_parts = [p for p in dest_path.split("/") if p]
     dest_name = dest_parts[-1]
-    dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else ''
+    dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else ""
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
@@ -596,14 +675,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
if isinstance(source_resource, File):
existing_dest = await File.get_or_none(
- name=dest_name,
- parent=dest_parent,
- owner=current_user,
- is_deleted=False
+ name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
)
if existing_dest and overwrite == "F":
- raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
+ raise HTTPException(
+ status_code=412, detail="Destination exists and overwrite is false"
+ )
if existing_dest:
await existing_dest.delete()
@@ -617,14 +695,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
elif isinstance(source_resource, Folder):
existing_dest = await Folder.get_or_none(
- name=dest_name,
- parent=dest_parent,
- owner=current_user,
- is_deleted=False
+ name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
)
if existing_dest and overwrite == "F":
- raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
+ raise HTTPException(
+ status_code=412, detail="Destination exists and overwrite is false"
+ )
if existing_dest:
existing_dest.is_deleted = True
@@ -637,9 +714,12 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
await log_activity(current_user, "folder_moved", "folder", source_resource.id)
return Response(status_code=201 if not existing_dest else 204)
+
@router.api_route("/{full_path:path}", methods=["LOCK"])
-async def handle_lock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
- full_path = unquote(full_path).strip('/')
+async def handle_lock(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
+ full_path = unquote(full_path).strip("/")
timeout_header = request.headers.get("Timeout", "Second-3600")
timeout = 3600
@@ -681,32 +761,38 @@ async def handle_lock(request: Request, full_path: str, current_user: User = Dep
content=xml_content,
media_type="application/xml; charset=utf-8",
status_code=200,
- headers={"Lock-Token": f"<{lock_token}>"}
+ headers={"Lock-Token": f"<{lock_token}>"},
)
+
@router.api_route("/{full_path:path}", methods=["UNLOCK"])
-async def handle_unlock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
- full_path = unquote(full_path).strip('/')
+async def handle_unlock(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
+ full_path = unquote(full_path).strip("/")
lock_token_header = request.headers.get("Lock-Token")
if not lock_token_header:
raise HTTPException(status_code=400, detail="Lock-Token header required")
- lock_token = lock_token_header.strip('<>')
+ lock_token = lock_token_header.strip("<>")
existing_lock = WebDAVLock.get_lock(full_path)
- if not existing_lock or existing_lock['token'] != lock_token:
+ if not existing_lock or existing_lock["token"] != lock_token:
raise HTTPException(status_code=409, detail="Invalid lock token")
- if existing_lock['user_id'] != current_user.id:
+ if existing_lock["user_id"] != current_user.id:
raise HTTPException(status_code=403, detail="Not lock owner")
WebDAVLock.remove_lock(full_path)
return Response(status_code=204)
+
@router.api_route("/{full_path:path}", methods=["PROPPATCH"])
-async def handle_proppatch(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
- full_path = unquote(full_path).strip('/')
+async def handle_proppatch(
+ request: Request, full_path: str, current_user: User = Depends(webdav_auth)
+):
+ full_path = unquote(full_path).strip("/")
resource, parent, exists = await resolve_path(full_path, current_user)
@@ -719,7 +805,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
try:
root = ET.fromstring(body)
- except:
+ except Exception:
raise HTTPException(status_code=400, detail="Invalid XML")
resource_type = "file" if isinstance(resource, File) else "folder"
@@ -734,13 +820,19 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
prop_element = set_element.find(".//{DAV:}prop")
if prop_element is not None:
for child in prop_element:
- ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
- name = child.tag.split('}')[1] if '}' in child.tag else child.tag
+ ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
+ name = child.tag.split("}")[1] if "}" in child.tag else child.tag
value = child.text or ""
if ns == "DAV:":
- live_props = ["creationdate", "getcontentlength", "getcontenttype",
- "getetag", "getlastmodified", "resourcetype"]
+ live_props = [
+ "creationdate",
+ "getcontentlength",
+ "getcontenttype",
+ "getetag",
+ "getlastmodified",
+ "resourcetype",
+ ]
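+                    # Live DAV: properties are computed by the server, so attempts to set them are rejected with 409 Conflict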
if name in live_props:
failed_props.append((ns, name, "409 Conflict"))
continue
@@ -750,7 +842,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
resource_type=resource_type,
resource_id=resource_id,
namespace=ns,
- name=name
+ name=name,
)
if existing_prop:
@@ -762,10 +854,10 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
resource_id=resource_id,
namespace=ns,
name=name,
- value=value
+ value=value,
)
set_props.append((ns, name))
- except Exception as e:
+ except Exception:
failed_props.append((ns, name, "500 Internal Server Error"))
remove_element = root.find(".//{DAV:}remove")
@@ -773,8 +865,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
prop_element = remove_element.find(".//{DAV:}prop")
if prop_element is not None:
for child in prop_element:
- ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
- name = child.tag.split('}')[1] if '}' in child.tag else child.tag
+ ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
+ name = child.tag.split("}")[1] if "}" in child.tag else child.tag
if ns == "DAV:":
failed_props.append((ns, name, "409 Conflict"))
@@ -785,7 +877,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
resource_type=resource_type,
resource_id=resource_id,
namespace=ns,
- name=name
+ name=name,
)
if existing_prop:
@@ -793,7 +885,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
remove_props.append((ns, name))
else:
failed_props.append((ns, name, "404 Not Found"))
- except Exception as e:
+ except Exception:
failed_props.append((ns, name, "500 Internal Server Error"))
multistatus = ET.Element("D:multistatus", {"xmlns:D": "DAV:"})
@@ -837,4 +929,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
await log_activity(current_user, "properties_modified", resource_type, resource_id)
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
- return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207)
+ return Response(
+ content=xml_content,
+ media_type="application/xml; charset=utf-8",
+ status_code=207,
+ )
diff --git a/tests/billing/test_api_endpoints.py b/tests/billing/test_api_endpoints.py
index e5e7405..be2c948 100644
--- a/tests/billing/test_api_endpoints.py
+++ b/tests/billing/test_api_endpoints.py
@@ -6,9 +6,15 @@ from httpx import AsyncClient, ASGITransport
from fastapi import status
from mywebdav.main import app
from mywebdav.models import User
-from mywebdav.billing.models import PricingConfig, Invoice, UsageAggregate, UserSubscription
+from mywebdav.billing.models import (
+ PricingConfig,
+ Invoice,
+ UsageAggregate,
+ UserSubscription,
+)
from mywebdav.auth import create_access_token
+
@pytest_asyncio.fixture
async def test_user():
user = await User.create(
@@ -16,11 +22,12 @@ async def test_user():
email="test@example.com",
hashed_password="hashed_password_here",
is_active=True,
- is_superuser=False
+ is_superuser=False,
)
yield user
await user.delete()
+
@pytest_asyncio.fixture
async def admin_user():
user = await User.create(
@@ -28,58 +35,72 @@ async def admin_user():
email="admin@example.com",
hashed_password="hashed_password_here",
is_active=True,
- is_superuser=True
+ is_superuser=True,
)
yield user
await user.delete()
+
@pytest_asyncio.fixture
async def auth_token(test_user):
token = create_access_token(data={"sub": test_user.username})
return token
+
@pytest_asyncio.fixture
async def admin_token(admin_user):
token = create_access_token(data={"sub": admin_user.username})
return token
+
@pytest_asyncio.fixture
async def pricing_config():
configs = []
- configs.append(await PricingConfig.create(
- config_key="storage_per_gb_month",
- config_value=Decimal("0.0045"),
- description="Storage cost per GB per month",
- unit="per_gb_month"
- ))
- configs.append(await PricingConfig.create(
- config_key="bandwidth_egress_per_gb",
- config_value=Decimal("0.009"),
- description="Bandwidth egress cost per GB",
- unit="per_gb"
- ))
- configs.append(await PricingConfig.create(
- config_key="free_tier_storage_gb",
- config_value=Decimal("15"),
- description="Free tier storage in GB",
- unit="gb"
- ))
- configs.append(await PricingConfig.create(
- config_key="free_tier_bandwidth_gb",
- config_value=Decimal("15"),
- description="Free tier bandwidth in GB per month",
- unit="gb"
- ))
+ configs.append(
+ await PricingConfig.create(
+ config_key="storage_per_gb_month",
+ config_value=Decimal("0.0045"),
+ description="Storage cost per GB per month",
+ unit="per_gb_month",
+ )
+ )
+ configs.append(
+ await PricingConfig.create(
+ config_key="bandwidth_egress_per_gb",
+ config_value=Decimal("0.009"),
+ description="Bandwidth egress cost per GB",
+ unit="per_gb",
+ )
+ )
+ configs.append(
+ await PricingConfig.create(
+ config_key="free_tier_storage_gb",
+ config_value=Decimal("15"),
+ description="Free tier storage in GB",
+ unit="gb",
+ )
+ )
+ configs.append(
+ await PricingConfig.create(
+ config_key="free_tier_bandwidth_gb",
+ config_value=Decimal("15"),
+ description="Free tier bandwidth in GB per month",
+ unit="gb",
+ )
+ )
yield configs
for config in configs:
await config.delete()
+
@pytest.mark.asyncio
async def test_get_current_usage(test_user, auth_token):
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.get(
"/api/billing/usage/current",
- headers={"Authorization": f"Bearer {auth_token}"}
+ headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -87,6 +108,7 @@ async def test_get_current_usage(test_user, auth_token):
assert "storage_gb" in data
assert "bandwidth_down_gb_today" in data
+
@pytest.mark.asyncio
async def test_get_monthly_usage(test_user, auth_token):
today = date.today()
@@ -94,16 +116,18 @@ async def test_get_monthly_usage(test_user, auth_token):
await UsageAggregate.create(
user=test_user,
date=today,
- storage_bytes_avg=1024 ** 3 * 10,
- storage_bytes_peak=1024 ** 3 * 12,
- bandwidth_up_bytes=1024 ** 3 * 2,
- bandwidth_down_bytes=1024 ** 3 * 5
+ storage_bytes_avg=1024**3 * 10,
+ storage_bytes_peak=1024**3 * 12,
+ bandwidth_up_bytes=1024**3 * 2,
+ bandwidth_down_bytes=1024**3 * 5,
)
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.get(
f"/api/billing/usage/monthly?year={today.year}&month={today.month}",
- headers={"Authorization": f"Bearer {auth_token}"}
+ headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -112,12 +136,15 @@ async def test_get_monthly_usage(test_user, auth_token):
await UsageAggregate.filter(user=test_user).delete()
+
@pytest.mark.asyncio
async def test_get_subscription(test_user, auth_token):
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.get(
"/api/billing/subscription",
- headers={"Authorization": f"Bearer {auth_token}"}
+ headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -127,6 +154,7 @@ async def test_get_subscription(test_user, auth_token):
await UserSubscription.filter(user=test_user).delete()
+
@pytest.mark.asyncio
async def test_list_invoices(test_user, auth_token):
invoice = await Invoice.create(
@@ -137,13 +165,14 @@ async def test_list_invoices(test_user, auth_token):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
- status="open"
+ status="open",
)
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.get(
- "/api/billing/invoices",
- headers={"Authorization": f"Bearer {auth_token}"}
+ "/api/billing/invoices", headers={"Authorization": f"Bearer {auth_token}"}
)
assert response.status_code == status.HTTP_200_OK
@@ -153,6 +182,7 @@ async def test_list_invoices(test_user, auth_token):
await invoice.delete()
+
@pytest.mark.asyncio
async def test_get_invoice(test_user, auth_token):
invoice = await Invoice.create(
@@ -163,13 +193,15 @@ async def test_get_invoice(test_user, auth_token):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
- status="open"
+ status="open",
)
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.get(
f"/api/billing/invoices/{invoice.id}",
- headers={"Authorization": f"Bearer {auth_token}"}
+ headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -178,39 +210,45 @@ async def test_get_invoice(test_user, auth_token):
await invoice.delete()
+
@pytest.mark.asyncio
async def test_get_pricing():
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.get("/api/billing/pricing")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, dict)
+
@pytest.mark.asyncio
async def test_admin_get_pricing(admin_user, admin_token, pricing_config):
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.get(
"/api/admin/billing/pricing",
- headers={"Authorization": f"Bearer {admin_token}"}
+ headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) > 0
+
@pytest.mark.asyncio
async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
config_id = pricing_config[0].id
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.put(
f"/api/admin/billing/pricing/{config_id}",
headers={"Authorization": f"Bearer {admin_token}"},
- json={
- "config_key": "storage_per_gb_month",
- "config_value": 0.005
- }
+ json={"config_key": "storage_per_gb_month", "config_value": 0.005},
)
assert response.status_code == status.HTTP_200_OK
@@ -218,12 +256,15 @@ async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
updated = await PricingConfig.get(id=config_id)
assert updated.config_value == Decimal("0.005")
+
@pytest.mark.asyncio
async def test_admin_get_stats(admin_user, admin_token):
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.get(
"/api/admin/billing/stats",
- headers={"Authorization": f"Bearer {admin_token}"}
+ headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == status.HTTP_200_OK
@@ -232,12 +273,15 @@ async def test_admin_get_stats(admin_user, admin_token):
assert "total_invoices" in data
assert "pending_invoices" in data
+
@pytest.mark.asyncio
async def test_non_admin_cannot_access_admin_endpoints(test_user, auth_token):
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
+ async with AsyncClient(
+ transport=ASGITransport(app=app), base_url="http://test"
+ ) as client:
response = await client.get(
"/api/admin/billing/pricing",
- headers={"Authorization": f"Bearer {auth_token}"}
+ headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
diff --git a/tests/billing/test_invoice_generator.py b/tests/billing/test_invoice_generator.py
index 1f52745..72c6cc5 100644
--- a/tests/billing/test_invoice_generator.py
+++ b/tests/billing/test_invoice_generator.py
@@ -3,57 +3,76 @@ import pytest_asyncio
from decimal import Decimal
from datetime import date, datetime
from mywebdav.models import User
-from mywebdav.billing.models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription
+from mywebdav.billing.models import (
+ Invoice,
+ InvoiceLineItem,
+ PricingConfig,
+ UsageAggregate,
+ UserSubscription,
+)
from mywebdav.billing.invoice_generator import InvoiceGenerator
+
@pytest_asyncio.fixture
async def test_user():
user = await User.create(
username="testuser",
email="test@example.com",
hashed_password="hashed_password_here",
- is_active=True
+ is_active=True,
)
yield user
await user.delete()
+
@pytest_asyncio.fixture
async def pricing_config():
configs = []
- configs.append(await PricingConfig.create(
- config_key="storage_per_gb_month",
- config_value=Decimal("0.0045"),
- description="Storage cost per GB per month",
- unit="per_gb_month"
- ))
- configs.append(await PricingConfig.create(
- config_key="bandwidth_egress_per_gb",
- config_value=Decimal("0.009"),
- description="Bandwidth egress cost per GB",
- unit="per_gb"
- ))
- configs.append(await PricingConfig.create(
- config_key="free_tier_storage_gb",
- config_value=Decimal("15"),
- description="Free tier storage in GB",
- unit="gb"
- ))
- configs.append(await PricingConfig.create(
- config_key="free_tier_bandwidth_gb",
- config_value=Decimal("15"),
- description="Free tier bandwidth in GB per month",
- unit="gb"
- ))
- configs.append(await PricingConfig.create(
- config_key="tax_rate_default",
- config_value=Decimal("0.0"),
- description="Default tax rate",
- unit="percentage"
- ))
+ configs.append(
+ await PricingConfig.create(
+ config_key="storage_per_gb_month",
+ config_value=Decimal("0.0045"),
+ description="Storage cost per GB per month",
+ unit="per_gb_month",
+ )
+ )
+ configs.append(
+ await PricingConfig.create(
+ config_key="bandwidth_egress_per_gb",
+ config_value=Decimal("0.009"),
+ description="Bandwidth egress cost per GB",
+ unit="per_gb",
+ )
+ )
+ configs.append(
+ await PricingConfig.create(
+ config_key="free_tier_storage_gb",
+ config_value=Decimal("15"),
+ description="Free tier storage in GB",
+ unit="gb",
+ )
+ )
+ configs.append(
+ await PricingConfig.create(
+ config_key="free_tier_bandwidth_gb",
+ config_value=Decimal("15"),
+ description="Free tier bandwidth in GB per month",
+ unit="gb",
+ )
+ )
+ configs.append(
+ await PricingConfig.create(
+ config_key="tax_rate_default",
+ config_value=Decimal("0.0"),
+ description="Default tax rate",
+ unit="percentage",
+ )
+ )
yield configs
for config in configs:
await config.delete()
+
@pytest.mark.asyncio
async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
today = date.today()
@@ -61,13 +80,15 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
await UsageAggregate.create(
user=test_user,
date=today,
- storage_bytes_avg=1024 ** 3 * 50,
- storage_bytes_peak=1024 ** 3 * 55,
- bandwidth_up_bytes=1024 ** 3 * 10,
- bandwidth_down_bytes=1024 ** 3 * 20
+ storage_bytes_avg=1024**3 * 50,
+ storage_bytes_peak=1024**3 * 55,
+ bandwidth_up_bytes=1024**3 * 10,
+ bandwidth_down_bytes=1024**3 * 20,
)
- invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
+ invoice = await InvoiceGenerator.generate_monthly_invoice(
+ test_user, today.year, today.month
+ )
assert invoice is not None
assert invoice.user_id == test_user.id
@@ -80,6 +101,7 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
await invoice.delete()
await UsageAggregate.filter(user=test_user).delete()
+
@pytest.mark.asyncio
async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_config):
today = date.today()
@@ -87,18 +109,21 @@ async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_confi
await UsageAggregate.create(
user=test_user,
date=today,
- storage_bytes_avg=1024 ** 3 * 10,
- storage_bytes_peak=1024 ** 3 * 12,
- bandwidth_up_bytes=1024 ** 3 * 5,
- bandwidth_down_bytes=1024 ** 3 * 10
+ storage_bytes_avg=1024**3 * 10,
+ storage_bytes_peak=1024**3 * 12,
+ bandwidth_up_bytes=1024**3 * 5,
+ bandwidth_down_bytes=1024**3 * 10,
)
- invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
+ invoice = await InvoiceGenerator.generate_monthly_invoice(
+ test_user, today.year, today.month
+ )
assert invoice is None
await UsageAggregate.filter(user=test_user).delete()
+
@pytest.mark.asyncio
async def test_finalize_invoice(test_user, pricing_config):
invoice = await Invoice.create(
@@ -109,7 +134,7 @@ async def test_finalize_invoice(test_user, pricing_config):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
- status="draft"
+ status="draft",
)
finalized = await InvoiceGenerator.finalize_invoice(invoice)
@@ -118,6 +143,7 @@ async def test_finalize_invoice(test_user, pricing_config):
await finalized.delete()
+
@pytest.mark.asyncio
async def test_finalize_invoice_already_finalized(test_user, pricing_config):
invoice = await Invoice.create(
@@ -128,7 +154,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
- status="open"
+ status="open",
)
with pytest.raises(ValueError):
@@ -136,6 +162,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
await invoice.delete()
+
@pytest.mark.asyncio
async def test_mark_invoice_paid(test_user, pricing_config):
invoice = await Invoice.create(
@@ -146,7 +173,7 @@ async def test_mark_invoice_paid(test_user, pricing_config):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
- status="open"
+ status="open",
)
paid = await InvoiceGenerator.mark_invoice_paid(invoice)
@@ -156,10 +183,13 @@ async def test_mark_invoice_paid(test_user, pricing_config):
await paid.delete()
+
@pytest.mark.asyncio
async def test_invoice_with_tax(test_user, pricing_config):
# Update tax rate
- updated = await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.21"))
+ updated = await PricingConfig.filter(config_key="tax_rate_default").update(
+ config_value=Decimal("0.21")
+ )
assert updated == 1 # Should update 1 row
# Verify the update worked
@@ -171,13 +201,15 @@ async def test_invoice_with_tax(test_user, pricing_config):
await UsageAggregate.create(
user=test_user,
date=today,
- storage_bytes_avg=1024 ** 3 * 50,
- storage_bytes_peak=1024 ** 3 * 55,
- bandwidth_up_bytes=1024 ** 3 * 10,
- bandwidth_down_bytes=1024 ** 3 * 20
+ storage_bytes_avg=1024**3 * 50,
+ storage_bytes_peak=1024**3 * 55,
+ bandwidth_up_bytes=1024**3 * 10,
+ bandwidth_down_bytes=1024**3 * 20,
)
- invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
+ invoice = await InvoiceGenerator.generate_monthly_invoice(
+ test_user, today.year, today.month
+ )
assert invoice is not None
assert invoice.tax > 0
@@ -185,4 +217,6 @@ async def test_invoice_with_tax(test_user, pricing_config):
await invoice.delete()
await UsageAggregate.filter(user=test_user).delete()
- await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.0"))
+ await PricingConfig.filter(config_key="tax_rate_default").update(
+ config_value=Decimal("0.0")
+ )
diff --git a/tests/billing/test_models.py b/tests/billing/test_models.py
index 49bd21a..9ef51d1 100644
--- a/tests/billing/test_models.py
+++ b/tests/billing/test_models.py
@@ -4,21 +4,30 @@ from decimal import Decimal
from datetime import date, datetime
from mywebdav.models import User
from mywebdav.billing.models import (
- SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate,
- Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent
+ SubscriptionPlan,
+ UserSubscription,
+ UsageRecord,
+ UsageAggregate,
+ Invoice,
+ InvoiceLineItem,
+ PricingConfig,
+ PaymentMethod,
+ BillingEvent,
)
+
@pytest_asyncio.fixture
async def test_user():
user = await User.create(
username="testuser",
email="test@example.com",
hashed_password="hashed_password_here",
- is_active=True
+ is_active=True,
)
yield user
await user.delete()
+
@pytest.mark.asyncio
async def test_subscription_plan_creation():
plan = await SubscriptionPlan.create(
@@ -28,7 +37,7 @@ async def test_subscription_plan_creation():
storage_gb=100,
bandwidth_gb=100,
price_monthly=Decimal("5.00"),
- price_yearly=Decimal("50.00")
+ price_yearly=Decimal("50.00"),
)
assert plan.name == "starter"
@@ -36,12 +45,11 @@ async def test_subscription_plan_creation():
assert plan.price_monthly == Decimal("5.00")
await plan.delete()
+
@pytest.mark.asyncio
async def test_user_subscription_creation(test_user):
subscription = await UserSubscription.create(
- user=test_user,
- billing_type="pay_as_you_go",
- status="active"
+ user=test_user, billing_type="pay_as_you_go", status="active"
)
assert subscription.user_id == test_user.id
@@ -49,6 +57,7 @@ async def test_user_subscription_creation(test_user):
assert subscription.status == "active"
await subscription.delete()
+
@pytest.mark.asyncio
async def test_usage_record_creation(test_user):
usage = await UsageRecord.create(
@@ -57,7 +66,7 @@ async def test_usage_record_creation(test_user):
amount_bytes=1024 * 1024 * 100,
resource_type="file",
resource_id=1,
- idempotency_key="test_key_123"
+ idempotency_key="test_key_123",
)
assert usage.user_id == test_user.id
@@ -65,6 +74,7 @@ async def test_usage_record_creation(test_user):
assert usage.amount_bytes == 1024 * 1024 * 100
await usage.delete()
+
@pytest.mark.asyncio
async def test_usage_aggregate_creation(test_user):
aggregate = await UsageAggregate.create(
@@ -73,13 +83,14 @@ async def test_usage_aggregate_creation(test_user):
storage_bytes_avg=1024 * 1024 * 500,
storage_bytes_peak=1024 * 1024 * 600,
bandwidth_up_bytes=1024 * 1024 * 50,
- bandwidth_down_bytes=1024 * 1024 * 100
+ bandwidth_down_bytes=1024 * 1024 * 100,
)
assert aggregate.user_id == test_user.id
assert aggregate.storage_bytes_avg == 1024 * 1024 * 500
await aggregate.delete()
+
@pytest.mark.asyncio
async def test_invoice_creation(test_user):
invoice = await Invoice.create(
@@ -90,7 +101,7 @@ async def test_invoice_creation(test_user):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
- status="draft"
+ status="draft",
)
assert invoice.user_id == test_user.id
@@ -98,6 +109,7 @@ async def test_invoice_creation(test_user):
assert invoice.total == Decimal("10.00")
await invoice.delete()
+
@pytest.mark.asyncio
async def test_invoice_line_item_creation(test_user):
invoice = await Invoice.create(
@@ -108,7 +120,7 @@ async def test_invoice_line_item_creation(test_user):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
- status="draft"
+ status="draft",
)
line_item = await InvoiceLineItem.create(
@@ -117,7 +129,7 @@ async def test_invoice_line_item_creation(test_user):
quantity=Decimal("100.000000"),
unit_price=Decimal("0.100000"),
amount=Decimal("10.0000"),
- item_type="storage"
+ item_type="storage",
)
assert line_item.invoice_id == invoice.id
@@ -126,6 +138,7 @@ async def test_invoice_line_item_creation(test_user):
await line_item.delete()
await invoice.delete()
+
@pytest.mark.asyncio
async def test_pricing_config_creation(test_user):
config = await PricingConfig.create(
@@ -133,13 +146,14 @@ async def test_pricing_config_creation(test_user):
config_value=Decimal("0.0045"),
description="Storage cost per GB per month",
unit="per_gb_month",
- updated_by=test_user
+ updated_by=test_user,
)
assert config.config_key == "storage_per_gb_month"
assert config.config_value == Decimal("0.0045")
await config.delete()
+
@pytest.mark.asyncio
async def test_payment_method_creation(test_user):
payment_method = await PaymentMethod.create(
@@ -150,7 +164,7 @@ async def test_payment_method_creation(test_user):
last4="4242",
brand="visa",
exp_month=12,
- exp_year=2025
+ exp_year=2025,
)
assert payment_method.user_id == test_user.id
@@ -158,6 +172,7 @@ async def test_payment_method_creation(test_user):
assert payment_method.is_default is True
await payment_method.delete()
+
@pytest.mark.asyncio
async def test_billing_event_creation(test_user):
event = await BillingEvent.create(
@@ -165,7 +180,7 @@ async def test_billing_event_creation(test_user):
event_type="invoice_created",
stripe_event_id="evt_test_123",
data={"invoice_id": 1},
- processed=False
+ processed=False,
)
assert event.user_id == test_user.id
diff --git a/tests/billing/test_simple.py b/tests/billing/test_simple.py
index 657eb4e..f3ebbdc 100644
--- a/tests/billing/test_simple.py
+++ b/tests/billing/test_simple.py
@@ -1,18 +1,35 @@
import pytest
+
def test_billing_module_imports():
- from mywebdav.billing import models, stripe_client, usage_tracker, invoice_generator, scheduler
+ from mywebdav.billing import (
+ models,
+ stripe_client,
+ usage_tracker,
+ invoice_generator,
+ scheduler,
+ )
+
assert models is not None
assert stripe_client is not None
assert usage_tracker is not None
assert invoice_generator is not None
assert scheduler is not None
+
def test_billing_models_exist():
from mywebdav.billing.models import (
- SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate,
- Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent
+ SubscriptionPlan,
+ UserSubscription,
+ UsageRecord,
+ UsageAggregate,
+ Invoice,
+ InvoiceLineItem,
+ PricingConfig,
+ PaymentMethod,
+ BillingEvent,
)
+
assert SubscriptionPlan is not None
assert UserSubscription is not None
assert UsageRecord is not None
@@ -23,55 +40,71 @@ def test_billing_models_exist():
assert PaymentMethod is not None
assert BillingEvent is not None
+
def test_stripe_client_exists():
from mywebdav.billing.stripe_client import StripeClient
+
assert StripeClient is not None
- assert hasattr(StripeClient, 'create_customer')
- assert hasattr(StripeClient, 'create_invoice')
- assert hasattr(StripeClient, 'finalize_invoice')
+ assert hasattr(StripeClient, "create_customer")
+ assert hasattr(StripeClient, "create_invoice")
+ assert hasattr(StripeClient, "finalize_invoice")
+
def test_usage_tracker_exists():
from mywebdav.billing.usage_tracker import UsageTracker
+
assert UsageTracker is not None
- assert hasattr(UsageTracker, 'track_storage')
- assert hasattr(UsageTracker, 'track_bandwidth')
- assert hasattr(UsageTracker, 'aggregate_daily_usage')
- assert hasattr(UsageTracker, 'get_current_storage')
- assert hasattr(UsageTracker, 'get_monthly_usage')
+ assert hasattr(UsageTracker, "track_storage")
+ assert hasattr(UsageTracker, "track_bandwidth")
+ assert hasattr(UsageTracker, "aggregate_daily_usage")
+ assert hasattr(UsageTracker, "get_current_storage")
+ assert hasattr(UsageTracker, "get_monthly_usage")
+
def test_invoice_generator_exists():
from mywebdav.billing.invoice_generator import InvoiceGenerator
+
assert InvoiceGenerator is not None
- assert hasattr(InvoiceGenerator, 'generate_monthly_invoice')
- assert hasattr(InvoiceGenerator, 'finalize_invoice')
- assert hasattr(InvoiceGenerator, 'mark_invoice_paid')
+ assert hasattr(InvoiceGenerator, "generate_monthly_invoice")
+ assert hasattr(InvoiceGenerator, "finalize_invoice")
+ assert hasattr(InvoiceGenerator, "mark_invoice_paid")
+
def test_scheduler_exists():
from mywebdav.billing.scheduler import scheduler, start_scheduler, stop_scheduler
+
assert scheduler is not None
assert callable(start_scheduler)
assert callable(stop_scheduler)
+
def test_routers_exist():
from mywebdav.routers import billing, admin_billing
+
assert billing is not None
assert admin_billing is not None
- assert hasattr(billing, 'router')
- assert hasattr(admin_billing, 'router')
+ assert hasattr(billing, "router")
+ assert hasattr(admin_billing, "router")
+
def test_middleware_exists():
from mywebdav.middleware.usage_tracking import UsageTrackingMiddleware
+
assert UsageTrackingMiddleware is not None
+
def test_settings_updated():
from mywebdav.settings import settings
- assert hasattr(settings, 'STRIPE_SECRET_KEY')
- assert hasattr(settings, 'STRIPE_PUBLISHABLE_KEY')
- assert hasattr(settings, 'STRIPE_WEBHOOK_SECRET')
- assert hasattr(settings, 'BILLING_ENABLED')
+
+ assert hasattr(settings, "STRIPE_SECRET_KEY")
+ assert hasattr(settings, "STRIPE_PUBLISHABLE_KEY")
+ assert hasattr(settings, "STRIPE_WEBHOOK_SECRET")
+ assert hasattr(settings, "BILLING_ENABLED")
+
def test_main_includes_billing():
from mywebdav.main import app
+
routes = [route.path for route in app.routes]
- billing_routes = [r for r in routes if '/billing' in r]
+ billing_routes = [r for r in routes if "/billing" in r]
assert len(billing_routes) > 0
diff --git a/tests/billing/test_usage_tracker.py b/tests/billing/test_usage_tracker.py
index 6aac8bf..35f3b71 100644
--- a/tests/billing/test_usage_tracker.py
+++ b/tests/billing/test_usage_tracker.py
@@ -6,24 +6,26 @@ from mywebdav.models import User, File, Folder
from mywebdav.billing.models import UsageRecord, UsageAggregate
from mywebdav.billing.usage_tracker import UsageTracker
+
@pytest_asyncio.fixture
async def test_user():
user = await User.create(
username="testuser",
email="test@example.com",
hashed_password="hashed_password_here",
- is_active=True
+ is_active=True,
)
yield user
await user.delete()
+
@pytest.mark.asyncio
async def test_track_storage(test_user):
await UsageTracker.track_storage(
user=test_user,
amount_bytes=1024 * 1024 * 100,
resource_type="file",
- resource_id=1
+ resource_id=1,
)
records = await UsageRecord.filter(user=test_user, record_type="storage").all()
@@ -32,6 +34,7 @@ async def test_track_storage(test_user):
await records[0].delete()
+
@pytest.mark.asyncio
async def test_track_bandwidth_upload(test_user):
await UsageTracker.track_bandwidth(
@@ -39,7 +42,7 @@ async def test_track_bandwidth_upload(test_user):
amount_bytes=1024 * 1024 * 50,
direction="up",
resource_type="file",
- resource_id=1
+ resource_id=1,
)
records = await UsageRecord.filter(user=test_user, record_type="bandwidth_up").all()
@@ -48,6 +51,7 @@ async def test_track_bandwidth_upload(test_user):
await records[0].delete()
+
@pytest.mark.asyncio
async def test_track_bandwidth_download(test_user):
await UsageTracker.track_bandwidth(
@@ -55,32 +59,28 @@ async def test_track_bandwidth_download(test_user):
amount_bytes=1024 * 1024 * 75,
direction="down",
resource_type="file",
- resource_id=1
+ resource_id=1,
)
- records = await UsageRecord.filter(user=test_user, record_type="bandwidth_down").all()
+ records = await UsageRecord.filter(
+ user=test_user, record_type="bandwidth_down"
+ ).all()
assert len(records) == 1
assert records[0].amount_bytes == 1024 * 1024 * 75
await records[0].delete()
+
@pytest.mark.asyncio
async def test_aggregate_daily_usage(test_user):
- await UsageTracker.track_storage(
- user=test_user,
- amount_bytes=1024 * 1024 * 100
+ await UsageTracker.track_storage(user=test_user, amount_bytes=1024 * 1024 * 100)
+
+ await UsageTracker.track_bandwidth(
+ user=test_user, amount_bytes=1024 * 1024 * 50, direction="up"
)
await UsageTracker.track_bandwidth(
- user=test_user,
- amount_bytes=1024 * 1024 * 50,
- direction="up"
- )
-
- await UsageTracker.track_bandwidth(
- user=test_user,
- amount_bytes=1024 * 1024 * 75,
- direction="down"
+ user=test_user, amount_bytes=1024 * 1024 * 75, direction="down"
)
aggregate = await UsageTracker.aggregate_daily_usage(test_user, date.today())
@@ -93,12 +93,10 @@ async def test_aggregate_daily_usage(test_user):
await aggregate.delete()
await UsageRecord.filter(user=test_user).delete()
+
@pytest.mark.asyncio
async def test_get_current_storage(test_user):
- folder = await Folder.create(
- name="Test Folder",
- owner=test_user
- )
+ folder = await Folder.create(name="Test Folder", owner=test_user)
file1 = await File.create(
name="test1.txt",
@@ -107,7 +105,7 @@ async def test_get_current_storage(test_user):
mime_type="text/plain",
owner=test_user,
parent=folder,
- is_deleted=False
+ is_deleted=False,
)
file2 = await File.create(
@@ -117,7 +115,7 @@ async def test_get_current_storage(test_user):
mime_type="text/plain",
owner=test_user,
parent=folder,
- is_deleted=False
+ is_deleted=False,
)
storage = await UsageTracker.get_current_storage(test_user)
@@ -128,6 +126,7 @@ async def test_get_current_storage(test_user):
await file2.delete()
await folder.delete()
+
@pytest.mark.asyncio
async def test_get_monthly_usage(test_user):
today = date.today()
@@ -138,7 +137,7 @@ async def test_get_monthly_usage(test_user):
storage_bytes_avg=1024 * 1024 * 1024 * 10,
storage_bytes_peak=1024 * 1024 * 1024 * 12,
bandwidth_up_bytes=1024 * 1024 * 1024 * 2,
- bandwidth_down_bytes=1024 * 1024 * 1024 * 5
+ bandwidth_down_bytes=1024 * 1024 * 1024 * 5,
)
usage = await UsageTracker.get_monthly_usage(test_user, today.year, today.month)
@@ -150,11 +149,14 @@ async def test_get_monthly_usage(test_user):
await aggregate.delete()
+
@pytest.mark.asyncio
async def test_get_monthly_usage_empty(test_user):
future_date = date.today() + timedelta(days=365)
- usage = await UsageTracker.get_monthly_usage(test_user, future_date.year, future_date.month)
+ usage = await UsageTracker.get_monthly_usage(
+ test_user, future_date.year, future_date.month
+ )
assert usage["storage_gb_avg"] == 0
assert usage["storage_gb_peak"] == 0
diff --git a/tests/conftest.py b/tests/conftest.py
index 7d6fdc9..a7a59e2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,27 +3,21 @@ import pytest_asyncio
import asyncio
from tortoise import Tortoise
-# Initialize Tortoise at module level
-async def _init_tortoise():
+
+@pytest_asyncio.fixture(scope="session")
+async def init_db():
+ """Initialize the database for testing."""
await Tortoise.init(
db_url="sqlite://:memory:",
- modules={
- "models": ["mywebdav.models"],
- "billing": ["mywebdav.billing.models"]
- }
+ modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
)
await Tortoise.generate_schemas()
-
-# Run initialization
-asyncio.run(_init_tortoise())
-
-@pytest.fixture(scope="session")
-def event_loop():
- loop = asyncio.get_event_loop_policy().new_event_loop()
- yield loop
- loop.close()
-
-@pytest_asyncio.fixture(scope="session", autouse=True)
-async def cleanup_db():
yield
await Tortoise.close_connections()
+
+
+@pytest_asyncio.fixture(autouse=True)
+async def setup_db(init_db):
+ """Ensure database is initialized before each test."""
+ # The init_db fixture handles the setup
+ yield