This commit is contained in:
retoor 2025-11-13 23:22:05 +01:00
parent aff8dfca08
commit d350ab6807
34 changed files with 1851 additions and 875 deletions

View File

@ -1,17 +1,18 @@
from typing import Optional
from .models import Activity, User
async def log_activity(
user: Optional[User],
action: str,
target_type: str,
target_id: int,
ip_address: Optional[str] = None
ip_address: Optional[str] = None,
):
await Activity.create(
user=user,
action=action,
target_type=target_type,
target_id=target_id,
ip_address=ip_address
ip_address=ip_address,
)

View File

@ -13,16 +13,25 @@ from .two_factor import verify_totp_code # Import verify_totp_code
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def verify_password(plain_password, hashed_password):
password_bytes = plain_password[:72].encode('utf-8')
hashed_bytes = hashed_password.encode('utf-8') if isinstance(hashed_password, str) else hashed_password
password_bytes = plain_password[:72].encode("utf-8")
hashed_bytes = (
hashed_password.encode("utf-8")
if isinstance(hashed_password, str)
else hashed_password
)
return bcrypt.checkpw(password_bytes, hashed_bytes)
def get_password_hash(password):
password_bytes = password[:72].encode('utf-8')
return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode('utf-8')
async def authenticate_user(username: str, password: str, two_factor_code: Optional[str] = None):
def get_password_hash(password):
password_bytes = password[:72].encode("utf-8")
return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode("utf-8")
async def authenticate_user(
username: str, password: str, two_factor_code: Optional[str] = None
):
user = await User.get_or_none(username=username)
if not user:
return None
@ -36,16 +45,26 @@ async def authenticate_user(username: str, password: str, two_factor_code: Optio
return None # 2FA code is incorrect
return {"user": user, "2fa_required": False}
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, two_factor_verified: bool = False):
def create_access_token(
data: dict,
expires_delta: Optional[timedelta] = None,
two_factor_verified: bool = False,
):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode.update({"exp": expire, "2fa_verified": two_factor_verified})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -53,12 +72,16 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
username: str = payload.get("sub")
two_factor_verified: bool = payload.get("2fa_verified", False)
if username is None:
raise credentials_exception
token_data = TokenData(username=username, two_factor_verified=two_factor_verified)
token_data = TokenData(
username=username, two_factor_verified=two_factor_verified
)
except JWTError:
raise credentials_exception
user = await User.get_or_none(username=token_data.username)
@ -67,17 +90,27 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
user.token_data = token_data # Attach token_data to user for easy access
return user
async def get_current_active_user(current_user: User = Depends(get_current_user)):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
async def get_current_verified_user(current_user: User = Depends(get_current_user)):
if current_user.is_2fa_enabled and not current_user.token_data.two_factor_verified:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="2FA required and not verified")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="2FA required and not verified",
)
return current_user
async def get_current_admin_user(current_user: User = Depends(get_current_verified_user)):
async def get_current_admin_user(
current_user: User = Depends(get_current_verified_user),
):
if not current_user.is_superuser:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
)
return current_user

View File

@ -2,14 +2,17 @@ from datetime import datetime, date, timedelta, timezone
from decimal import Decimal
from typing import Optional
from calendar import monthrange
from .models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription
from .models import Invoice, InvoiceLineItem, PricingConfig, UserSubscription
from .usage_tracker import UsageTracker
from .stripe_client import StripeClient
from ..models import User
class InvoiceGenerator:
@staticmethod
async def generate_monthly_invoice(user: User, year: int, month: int) -> Optional[Invoice]:
async def generate_monthly_invoice(
user: User, year: int, month: int
) -> Optional[Invoice]:
period_start = date(year, month, 1)
days_in_month = monthrange(year, month)[1]
period_end = date(year, month, days_in_month)
@ -19,19 +22,24 @@ class InvoiceGenerator:
pricing = await PricingConfig.all()
pricing_dict = {p.config_key: p.config_value for p in pricing}
storage_price_per_gb = pricing_dict.get('storage_per_gb_month', Decimal('0.0045'))
bandwidth_price_per_gb = pricing_dict.get('bandwidth_egress_per_gb', Decimal('0.009'))
free_storage_gb = pricing_dict.get('free_tier_storage_gb', Decimal('15'))
free_bandwidth_gb = pricing_dict.get('free_tier_bandwidth_gb', Decimal('15'))
tax_rate = pricing_dict.get('tax_rate_default', Decimal('0'))
storage_price_per_gb = pricing_dict.get(
"storage_per_gb_month", Decimal("0.0045")
)
bandwidth_price_per_gb = pricing_dict.get(
"bandwidth_egress_per_gb", Decimal("0.009")
)
free_storage_gb = pricing_dict.get("free_tier_storage_gb", Decimal("15"))
free_bandwidth_gb = pricing_dict.get("free_tier_bandwidth_gb", Decimal("15"))
tax_rate = pricing_dict.get("tax_rate_default", Decimal("0"))
storage_gb = Decimal(str(usage['storage_gb_avg']))
bandwidth_gb = Decimal(str(usage['bandwidth_down_gb']))
storage_gb = Decimal(str(usage["storage_gb_avg"]))
bandwidth_gb = Decimal(str(usage["bandwidth_down_gb"]))
billable_storage = max(Decimal('0'), storage_gb - free_storage_gb)
billable_bandwidth = max(Decimal('0'), bandwidth_gb - free_bandwidth_gb)
billable_storage = max(Decimal("0"), storage_gb - free_storage_gb)
billable_bandwidth = max(Decimal("0"), bandwidth_gb - free_bandwidth_gb)
import math
billable_storage_rounded = Decimal(math.ceil(float(billable_storage)))
billable_bandwidth_rounded = Decimal(math.ceil(float(billable_bandwidth)))
@ -65,9 +73,9 @@ class InvoiceGenerator:
"usage": usage,
"pricing": {
"storage_per_gb": float(storage_price_per_gb),
"bandwidth_per_gb": float(bandwidth_price_per_gb)
}
}
"bandwidth_per_gb": float(bandwidth_price_per_gb),
},
},
)
if billable_storage_rounded > 0:
@ -78,7 +86,10 @@ class InvoiceGenerator:
unit_price=storage_price_per_gb,
amount=storage_cost,
item_type="storage",
metadata={"avg_gb": float(storage_gb), "free_gb": float(free_storage_gb)}
metadata={
"avg_gb": float(storage_gb),
"free_gb": float(free_storage_gb),
},
)
if billable_bandwidth_rounded > 0:
@ -89,7 +100,10 @@ class InvoiceGenerator:
unit_price=bandwidth_price_per_gb,
amount=bandwidth_cost,
item_type="bandwidth",
metadata={"total_gb": float(bandwidth_gb), "free_gb": float(free_bandwidth_gb)}
metadata={
"total_gb": float(bandwidth_gb),
"free_gb": float(free_bandwidth_gb),
},
)
if subscription and subscription.stripe_customer_id:
@ -100,7 +114,7 @@ class InvoiceGenerator:
"amount": item.amount,
"currency": "usd",
"description": item.description,
"metadata": item.metadata or {}
"metadata": item.metadata or {},
}
for item in line_items
]
@ -109,7 +123,7 @@ class InvoiceGenerator:
customer_id=subscription.stripe_customer_id,
description=f"MyWebdav Usage Invoice for {period_start.strftime('%B %Y')}",
line_items=stripe_line_items,
metadata={"mywebdav_invoice_id": str(invoice.id)}
metadata={"mywebdav_invoice_id": str(invoice.id)},
)
invoice.stripe_invoice_id = stripe_invoice.id
@ -135,8 +149,11 @@ class InvoiceGenerator:
# Send invoice email
from ..mail import queue_email
line_items = await invoice.line_items.all()
items_text = "\n".join([f"- {item.description}: ${item.amount}" for item in line_items])
items_text = "\n".join(
[f"- {item.description}: ${item.amount}" for item in line_items]
)
body = f"""Dear {invoice.user.username},
Your invoice {invoice.invoice_number} for the period {invoice.period_start} to {invoice.period_end} is now available.
@ -174,7 +191,7 @@ The MyWebdav Team
to_email=invoice.user.email,
subject=f"Your MyWebdav Invoice {invoice.invoice_number}",
body=body,
html=html
html=html,
)
return invoice

View File

@ -1,8 +1,8 @@
from tortoise import fields, models
from decimal import Decimal
class SubscriptionPlan(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=100, unique=True)
display_name = fields.CharField(max_length=255)
description = fields.TextField(null=True)
@ -18,10 +18,13 @@ class SubscriptionPlan(models.Model):
class Meta:
table = "subscription_plans"
class UserSubscription(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="subscription")
plan = fields.ForeignKeyField("billing.SubscriptionPlan", related_name="subscriptions", null=True)
plan = fields.ForeignKeyField(
"billing.SubscriptionPlan", related_name="subscriptions", null=True
)
billing_type = fields.CharField(max_length=20, default="pay_as_you_go")
stripe_customer_id = fields.CharField(max_length=255, unique=True, null=True)
stripe_subscription_id = fields.CharField(max_length=255, unique=True, null=True)
@ -35,14 +38,15 @@ class UserSubscription(models.Model):
class Meta:
table = "user_subscriptions"
class UsageRecord(models.Model):
id = fields.BigIntField(pk=True)
id = fields.BigIntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="usage_records")
record_type = fields.CharField(max_length=50, index=True)
record_type = fields.CharField(max_length=50, db_index=True)
amount_bytes = fields.BigIntField()
resource_type = fields.CharField(max_length=50, null=True)
resource_id = fields.IntField(null=True)
timestamp = fields.DatetimeField(auto_now_add=True, index=True)
timestamp = fields.DatetimeField(auto_now_add=True, db_index=True)
idempotency_key = fields.CharField(max_length=255, unique=True, null=True)
metadata = fields.JSONField(null=True)
@ -50,8 +54,9 @@ class UsageRecord(models.Model):
table = "usage_records"
indexes = [("user_id", "record_type", "timestamp")]
class UsageAggregate(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="usage_aggregates")
date = fields.DateField()
storage_bytes_avg = fields.BigIntField(default=0)
@ -64,8 +69,9 @@ class UsageAggregate(models.Model):
table = "usage_aggregates"
unique_together = (("user", "date"),)
class Invoice(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="invoices")
invoice_number = fields.CharField(max_length=50, unique=True)
stripe_invoice_id = fields.CharField(max_length=255, unique=True, null=True)
@ -75,10 +81,10 @@ class Invoice(models.Model):
tax = fields.DecimalField(max_digits=10, decimal_places=4, default=0)
total = fields.DecimalField(max_digits=10, decimal_places=4)
currency = fields.CharField(max_length=3, default="USD")
status = fields.CharField(max_length=50, default="draft", index=True)
status = fields.CharField(max_length=50, default="draft", db_index=True)
due_date = fields.DateField(null=True)
paid_at = fields.DatetimeField(null=True)
created_at = fields.DatetimeField(auto_now_add=True, index=True)
created_at = fields.DatetimeField(auto_now_add=True, db_index=True)
updated_at = fields.DatetimeField(auto_now=True)
metadata = fields.JSONField(null=True)
@ -86,8 +92,9 @@ class Invoice(models.Model):
table = "invoices"
indexes = [("user_id", "status", "created_at")]
class InvoiceLineItem(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
invoice = fields.ForeignKeyField("billing.Invoice", related_name="line_items")
description = fields.TextField()
quantity = fields.DecimalField(max_digits=15, decimal_places=6)
@ -100,20 +107,24 @@ class InvoiceLineItem(models.Model):
class Meta:
table = "invoice_line_items"
class PricingConfig(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
config_key = fields.CharField(max_length=100, unique=True)
config_value = fields.DecimalField(max_digits=10, decimal_places=6)
description = fields.TextField(null=True)
unit = fields.CharField(max_length=50, null=True)
updated_by = fields.ForeignKeyField("models.User", related_name="pricing_updates", null=True)
updated_by = fields.ForeignKeyField(
"models.User", related_name="pricing_updates", null=True
)
updated_at = fields.DatetimeField(auto_now=True)
class Meta:
table = "pricing_config"
class PaymentMethod(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
user = fields.ForeignKeyField("models.User", related_name="payment_methods")
stripe_payment_method_id = fields.CharField(max_length=255)
type = fields.CharField(max_length=50)
@ -128,9 +139,12 @@ class PaymentMethod(models.Model):
class Meta:
table = "payment_methods"
class BillingEvent(models.Model):
id = fields.BigIntField(pk=True)
user = fields.ForeignKeyField("models.User", related_name="billing_events", null=True)
id = fields.BigIntField(primary_key=True)
user = fields.ForeignKeyField(
"models.User", related_name="billing_events", null=True
)
event_type = fields.CharField(max_length=100)
stripe_event_id = fields.CharField(max_length=255, unique=True, null=True)
data = fields.JSONField(null=True)

View File

@ -1,7 +1,6 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from datetime import datetime, date, timedelta
import asyncio
from .usage_tracker import UsageTracker
from .invoice_generator import InvoiceGenerator
@ -9,6 +8,7 @@ from ..models import User
scheduler = AsyncIOScheduler()
async def aggregate_daily_usage_for_all_users():
users = await User.filter(is_active=True).all()
yesterday = date.today() - timedelta(days=1)
@ -19,6 +19,7 @@ async def aggregate_daily_usage_for_all_users():
except Exception as e:
print(f"Failed to aggregate usage for user {user.id}: {e}")
async def generate_monthly_invoices():
now = datetime.now()
last_month = now.month - 1 if now.month > 1 else 12
@ -28,19 +29,22 @@ async def generate_monthly_invoices():
for user in users:
try:
invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, last_month)
invoice = await InvoiceGenerator.generate_monthly_invoice(
user, year, last_month
)
if invoice:
await InvoiceGenerator.finalize_invoice(invoice)
except Exception as e:
print(f"Failed to generate invoice for user {user.id}: {e}")
def start_scheduler():
scheduler.add_job(
aggregate_daily_usage_for_all_users,
CronTrigger(hour=1, minute=0),
id="aggregate_daily_usage",
name="Aggregate daily usage for all users",
replace_existing=True
replace_existing=True,
)
scheduler.add_job(
@ -48,10 +52,11 @@ def start_scheduler():
CronTrigger(day=1, hour=2, minute=0),
id="generate_monthly_invoices",
name="Generate monthly invoices",
replace_existing=True
replace_existing=True,
)
scheduler.start()
def stop_scheduler():
scheduler.shutdown()

View File

@ -1,8 +1,8 @@
import stripe
from decimal import Decimal
from typing import Optional, Dict, Any
from typing import Dict
from ..settings import settings
class StripeClient:
@staticmethod
def _ensure_api_key():
@ -11,13 +11,12 @@ class StripeClient:
stripe.api_key = settings.STRIPE_SECRET_KEY
else:
raise ValueError("Stripe API key not configured")
@staticmethod
async def create_customer(email: str, name: str, metadata: Dict = None) -> str:
StripeClient._ensure_api_key()
customer = stripe.Customer.create(
email=email,
name=name,
metadata=metadata or {}
email=email, name=name, metadata=metadata or {}
)
return customer.id
@ -26,7 +25,7 @@ class StripeClient:
amount: int,
currency: str = "usd",
customer_id: str = None,
metadata: Dict = None
metadata: Dict = None,
) -> stripe.PaymentIntent:
StripeClient._ensure_api_key()
return stripe.PaymentIntent.create(
@ -34,32 +33,29 @@ class StripeClient:
currency=currency,
customer=customer_id,
metadata=metadata or {},
automatic_payment_methods={"enabled": True}
automatic_payment_methods={"enabled": True},
)
@staticmethod
async def create_invoice(
customer_id: str,
description: str,
line_items: list,
metadata: Dict = None
customer_id: str, description: str, line_items: list, metadata: Dict = None
) -> stripe.Invoice:
StripeClient._ensure_api_key()
for item in line_items:
stripe.InvoiceItem.create(
customer=customer_id,
amount=int(item['amount'] * 100),
currency=item.get('currency', 'usd'),
description=item['description'],
metadata=item.get('metadata', {})
amount=int(item["amount"] * 100),
currency=item.get("currency", "usd"),
description=item["description"],
metadata=item.get("metadata", {}),
)
invoice = stripe.Invoice.create(
customer=customer_id,
description=description,
auto_advance=True,
collection_method='charge_automatically',
metadata=metadata or {}
collection_method="charge_automatically",
metadata=metadata or {},
)
return invoice
@ -76,18 +72,15 @@ class StripeClient:
@staticmethod
async def attach_payment_method(
payment_method_id: str,
customer_id: str
payment_method_id: str, customer_id: str
) -> stripe.PaymentMethod:
StripeClient._ensure_api_key()
payment_method = stripe.PaymentMethod.attach(
payment_method_id,
customer=customer_id
payment_method_id, customer=customer_id
)
stripe.Customer.modify(
customer_id,
invoice_settings={'default_payment_method': payment_method_id}
customer_id, invoice_settings={"default_payment_method": payment_method_id}
)
return payment_method
@ -95,22 +88,15 @@ class StripeClient:
@staticmethod
async def list_payment_methods(customer_id: str, type: str = "card"):
StripeClient._ensure_api_key()
return stripe.PaymentMethod.list(
customer=customer_id,
type=type
)
return stripe.PaymentMethod.list(customer=customer_id, type=type)
@staticmethod
async def create_subscription(
customer_id: str,
price_id: str,
metadata: Dict = None
customer_id: str, price_id: str, metadata: Dict = None
) -> stripe.Subscription:
StripeClient._ensure_api_key()
return stripe.Subscription.create(
customer=customer_id,
items=[{'price': price_id}],
metadata=metadata or {}
customer=customer_id, items=[{"price": price_id}], metadata=metadata or {}
)
@staticmethod

View File

@ -1,11 +1,10 @@
import uuid
from datetime import datetime, date, timezone
from decimal import Decimal
from typing import Optional
from tortoise.transactions import in_transaction
from .models import UsageRecord, UsageAggregate
from ..models import User
class UsageTracker:
@staticmethod
async def track_storage(
@ -13,7 +12,7 @@ class UsageTracker:
amount_bytes: int,
resource_type: str = None,
resource_id: int = None,
metadata: dict = None
metadata: dict = None,
):
idempotency_key = f"storage_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}"
@ -24,7 +23,7 @@ class UsageTracker:
resource_type=resource_type,
resource_id=resource_id,
idempotency_key=idempotency_key,
metadata=metadata
metadata=metadata,
)
@staticmethod
@ -34,7 +33,7 @@ class UsageTracker:
direction: str = "down",
resource_type: str = None,
resource_id: int = None,
metadata: dict = None
metadata: dict = None,
):
record_type = f"bandwidth_{direction}"
idempotency_key = f"{record_type}_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}"
@ -46,7 +45,7 @@ class UsageTracker:
resource_type=resource_type,
resource_id=resource_id,
idempotency_key=idempotency_key,
metadata=metadata
metadata=metadata,
)
@staticmethod
@ -61,24 +60,26 @@ class UsageTracker:
user=user,
record_type="storage",
timestamp__gte=start_of_day,
timestamp__lte=end_of_day
timestamp__lte=end_of_day,
).all()
storage_avg = sum(r.amount_bytes for r in storage_records) // max(len(storage_records), 1)
storage_avg = sum(r.amount_bytes for r in storage_records) // max(
len(storage_records), 1
)
storage_peak = max((r.amount_bytes for r in storage_records), default=0)
bandwidth_up = await UsageRecord.filter(
user=user,
record_type="bandwidth_up",
timestamp__gte=start_of_day,
timestamp__lte=end_of_day
timestamp__lte=end_of_day,
).all()
bandwidth_down = await UsageRecord.filter(
user=user,
record_type="bandwidth_down",
timestamp__gte=start_of_day,
timestamp__lte=end_of_day
timestamp__lte=end_of_day,
).all()
total_up = sum(r.amount_bytes for r in bandwidth_up)
@ -92,8 +93,8 @@ class UsageTracker:
"storage_bytes_avg": storage_avg,
"storage_bytes_peak": storage_peak,
"bandwidth_up_bytes": total_up,
"bandwidth_down_bytes": total_down
}
"bandwidth_down_bytes": total_down,
},
)
if not created:
@ -108,6 +109,7 @@ class UsageTracker:
@staticmethod
async def get_current_storage(user: User) -> int:
from ..models import File
files = await File.filter(owner=user, is_deleted=False).all()
return sum(f.size for f in files)
@ -121,9 +123,7 @@ class UsageTracker:
end_date = date(year, month, last_day)
aggregates = await UsageAggregate.filter(
user=user,
date__gte=start_date,
date__lte=end_date
user=user, date__gte=start_date, date__lte=end_date
).all()
if not aggregates:
@ -132,7 +132,7 @@ class UsageTracker:
"storage_gb_peak": 0,
"bandwidth_up_gb": 0,
"bandwidth_down_gb": 0,
"total_bandwidth_gb": 0
"total_bandwidth_gb": 0,
}
storage_avg = sum(a.storage_bytes_avg for a in aggregates) / len(aggregates)
@ -145,5 +145,5 @@ class UsageTracker:
"storage_gb_peak": round(storage_peak / (1024**3), 4),
"bandwidth_up_gb": round(bandwidth_up / (1024**3), 4),
"bandwidth_down_gb": round(bandwidth_down / (1024**3), 4),
"total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4)
"total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4),
}

View File

@ -7,7 +7,7 @@ for various legal policies required for a European cloud storage provider.
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Dict, Any
from typing import Dict
class LegalDocument(ABC):
@ -17,10 +17,13 @@ class LegalDocument(ABC):
Provides common structure and methods for generating legal content.
"""
def __init__(self, company_name: str = "MyWebdav Technologies",
def __init__(
self,
company_name: str = "MyWebdav Technologies",
last_updated: str = None,
contact_email: str = "legal@mywebdav.eu",
website: str = "https://mywebdav.eu"):
website: str = "https://mywebdav.eu",
):
self.company_name = company_name
self.last_updated = last_updated or datetime.now().strftime("%B %d, %Y")
self.contact_email = contact_email
@ -853,20 +856,21 @@ class ContactComplaintMechanism(LegalDocument):
def get_all_legal_documents() -> Dict[str, LegalDocument]:
"""Return a dictionary of all legal document instances."""
return {
'privacy_policy': PrivacyPolicy(),
'terms_of_service': TermsOfService(),
'security_policy': SecurityPolicy(),
'cookie_policy': CookiePolicy(),
'data_processing_agreement': DataProcessingAgreement(),
'compliance_statement': ComplianceStatement(),
'data_portability_deletion_policy': DataPortabilityDeletionPolicy(),
'contact_complaint_mechanism': ContactComplaintMechanism(),
"privacy_policy": PrivacyPolicy(),
"terms_of_service": TermsOfService(),
"security_policy": SecurityPolicy(),
"cookie_policy": CookiePolicy(),
"data_processing_agreement": DataProcessingAgreement(),
"compliance_statement": ComplianceStatement(),
"data_portability_deletion_policy": DataPortabilityDeletionPolicy(),
"contact_complaint_mechanism": ContactComplaintMechanism(),
}
def generate_legal_documents(output_dir: str = "static/legal"):
"""Generate all legal documents as Markdown and HTML files."""
import os
os.makedirs(output_dir, exist_ok=True)
documents = get_all_legal_documents()
@ -875,13 +879,13 @@ def generate_legal_documents(output_dir: str = "static/legal"):
# Generate Markdown
md_filename = f"{doc_name}.md"
md_path = os.path.join(output_dir, md_filename)
with open(md_path, 'w') as f:
with open(md_path, "w") as f:
f.write(doc.to_markdown())
# Generate HTML
html_filename = f"{doc_name}.html"
html_path = os.path.join(output_dir, html_filename)
with open(html_path, 'w') as f:
with open(html_path, "w") as f:
f.write(doc.to_html())
print(f"Generated {md_filename} and {html_filename}")

View File

@ -2,17 +2,26 @@ import asyncio
import aiosmtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from typing import Optional, Dict, Any
from typing import Optional
from .settings import settings
class EmailTask:
def __init__(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
def __init__(
self,
to_email: str,
subject: str,
body: str,
html: Optional[str] = None,
**kwargs,
):
self.to_email = to_email
self.subject = subject
self.body = body
self.html = html
self.kwargs = kwargs
class EmailService:
def __init__(self):
self.queue = asyncio.Queue()
@ -38,7 +47,14 @@ class EmailService:
except asyncio.CancelledError:
pass
async def send_email(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
async def send_email(
self,
to_email: str,
subject: str,
body: str,
html: Optional[str] = None,
**kwargs,
):
"""Queue an email for sending"""
task = EmailTask(to_email, subject, body, html, **kwargs)
await self.queue.put(task)
@ -60,22 +76,26 @@ class EmailService:
async def _send_email_task(self, task: EmailTask):
"""Send a single email task"""
if not settings.SMTP_HOST or not settings.SMTP_USERNAME or not settings.SMTP_PASSWORD:
if (
not settings.SMTP_HOST
or not settings.SMTP_USERNAME
or not settings.SMTP_PASSWORD
):
print("SMTP not configured, skipping email send")
return
msg = MIMEMultipart('alternative')
msg['From'] = settings.SMTP_SENDER_EMAIL
msg['To'] = task.to_email
msg['Subject'] = task.subject
msg = MIMEMultipart("alternative")
msg["From"] = settings.SMTP_SENDER_EMAIL
msg["To"] = task.to_email
msg["Subject"] = task.subject
# Add text part
text_part = MIMEText(task.body, 'plain')
text_part = MIMEText(task.body, "plain")
msg.attach(text_part)
# Add HTML part if provided
if task.html:
html_part = MIMEText(task.html, 'html')
html_part = MIMEText(task.html, "html")
msg.attach(html_part)
try:
@ -86,7 +106,7 @@ class EmailService:
port=settings.SMTP_PORT,
username=settings.SMTP_USERNAME,
password=settings.SMTP_PASSWORD,
use_tls=True
use_tls=True,
) as smtp:
await smtp.send_message(msg)
print(f"Email sent to {task.to_email}")
@ -99,7 +119,10 @@ class EmailService:
try:
await smtp.starttls()
except Exception as tls_error:
if "already using" in str(tls_error).lower() or "tls" in str(tls_error).lower():
if (
"already using" in str(tls_error).lower()
or "tls" in str(tls_error).lower()
):
# Connection is already using TLS, proceed without starttls
pass
else:
@ -111,14 +134,23 @@ class EmailService:
print(f"Failed to send email to {task.to_email}: {e}")
raise # Re-raise to let caller handle
# Global email service instance
email_service = EmailService()
# Convenience functions
async def send_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
async def send_email(
to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
):
"""Send an email asynchronously"""
await email_service.send_email(to_email, subject, body, html, **kwargs)
def queue_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
def queue_email(
to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
):
"""Queue an email for sending (fire and forget)"""
asyncio.create_task(email_service.send_email(to_email, subject, body, html, **kwargs))
asyncio.create_task(
email_service.send_email(to_email, subject, body, html, **kwargs)
)

View File

@ -2,18 +2,33 @@ import argparse
import uvicorn
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status, HTTPException
from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, JSONResponse
from tortoise.contrib.fastapi import register_tortoise
from .settings import settings
from .routers import auth, users, folders, files, shares, search, admin, starred, billing, admin_billing
from .routers import (
auth,
users,
folders,
files,
shares,
search,
admin,
starred,
billing,
admin_billing,
)
from . import webdav
from .schemas import ErrorResponse
from .middleware.usage_tracking import UsageTrackingMiddleware
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Starting up...")
@ -21,6 +36,7 @@ async def lifespan(app: FastAPI):
from .billing.scheduler import start_scheduler
from .billing.models import PricingConfig
from .mail import email_service
start_scheduler()
logger.info("Billing scheduler started")
await email_service.start()
@ -28,28 +44,61 @@ async def lifespan(app: FastAPI):
pricing_count = await PricingConfig.all().count()
if pricing_count == 0:
from decimal import Decimal
await PricingConfig.create(config_key='storage_per_gb_month', config_value=Decimal('0.0045'), description='Storage cost per GB per month', unit='per_gb_month')
await PricingConfig.create(config_key='bandwidth_egress_per_gb', config_value=Decimal('0.009'), description='Bandwidth egress cost per GB', unit='per_gb')
await PricingConfig.create(config_key='bandwidth_ingress_per_gb', config_value=Decimal('0.0'), description='Bandwidth ingress cost per GB (free)', unit='per_gb')
await PricingConfig.create(config_key='free_tier_storage_gb', config_value=Decimal('15'), description='Free tier storage in GB', unit='gb')
await PricingConfig.create(config_key='free_tier_bandwidth_gb', config_value=Decimal('15'), description='Free tier bandwidth in GB per month', unit='gb')
await PricingConfig.create(config_key='tax_rate_default', config_value=Decimal('0.0'), description='Default tax rate (0 = no tax)', unit='percentage')
await PricingConfig.create(
config_key="storage_per_gb_month",
config_value=Decimal("0.0045"),
description="Storage cost per GB per month",
unit="per_gb_month",
)
await PricingConfig.create(
config_key="bandwidth_egress_per_gb",
config_value=Decimal("0.009"),
description="Bandwidth egress cost per GB",
unit="per_gb",
)
await PricingConfig.create(
config_key="bandwidth_ingress_per_gb",
config_value=Decimal("0.0"),
description="Bandwidth ingress cost per GB (free)",
unit="per_gb",
)
await PricingConfig.create(
config_key="free_tier_storage_gb",
config_value=Decimal("15"),
description="Free tier storage in GB",
unit="gb",
)
await PricingConfig.create(
config_key="free_tier_bandwidth_gb",
config_value=Decimal("15"),
description="Free tier bandwidth in GB per month",
unit="gb",
)
await PricingConfig.create(
config_key="tax_rate_default",
config_value=Decimal("0.0"),
description="Default tax rate (0 = no tax)",
unit="percentage",
)
logger.info("Default pricing configuration initialized")
yield
from .billing.scheduler import stop_scheduler
stop_scheduler()
logger.info("Billing scheduler stopped")
await email_service.stop()
logger.info("Email service stopped")
print("Shutting down...")
app = FastAPI(
title="MyWebdav Cloud Storage",
description="A commercial cloud storage web application",
version="0.1.0",
lifespan=lifespan
lifespan=lifespan,
)
app.include_router(auth.router)
@ -64,8 +113,6 @@ app.include_router(billing.router)
app.include_router(admin_billing.router)
app.include_router(webdav.router)
from .middleware.usage_tracking import UsageTrackingMiddleware
app.add_middleware(UsageTrackingMiddleware)
app.mount("/static", StaticFiles(directory="static"), name="static")
@ -73,17 +120,17 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
register_tortoise(
app,
db_url=settings.DATABASE_URL,
modules={
"models": ["mywebdav.models"],
"billing": ["mywebdav.billing.models"]
},
modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
generate_schemas=True,
add_exception_handlers=True,
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
logger.error(f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}")
logger.error(
f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}"
)
return JSONResponse(
status_code=exc.status_code,
content=ErrorResponse(code=exc.status_code, message=exc.detail).model_dump(),
@ -95,13 +142,17 @@ async def read_root():
with open("static/index.html", "r") as f:
return f.read()
def main():
parser = argparse.ArgumentParser(description="Run the MyWebdav application.")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host address to bind to")
parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Host address to bind to"
)
parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
args = parser.parse_args()
uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()

View File

@ -2,31 +2,35 @@ from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from ..billing.usage_tracker import UsageTracker
class UsageTrackingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)
if hasattr(request.state, 'user') and request.state.user:
if hasattr(request.state, "user") and request.state.user:
user = request.state.user
if request.method in ['POST', 'PUT'] and '/files/upload' in request.url.path:
content_length = response.headers.get('content-length')
if (
request.method in ["POST", "PUT"]
and "/files/upload" in request.url.path
):
content_length = response.headers.get("content-length")
if content_length:
await UsageTracker.track_bandwidth(
user=user,
amount_bytes=int(content_length),
direction='up',
metadata={'path': request.url.path}
direction="up",
metadata={"path": request.url.path},
)
elif request.method == 'GET' and '/files/download' in request.url.path:
content_length = response.headers.get('content-length')
elif request.method == "GET" and "/files/download" in request.url.path:
content_length = response.headers.get("content-length")
if content_length:
await UsageTracker.track_bandwidth(
user=user,
amount_bytes=int(content_length),
direction='down',
metadata={'path': request.url.path}
direction="down",
metadata={"path": request.url.path},
)
return response

View File

@ -1,9 +1,9 @@
from tortoise import fields, models
from tortoise.contrib.pydantic import pydantic_model_creator
from datetime import datetime
class User(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
username = fields.CharField(max_length=20, unique=True)
email = fields.CharField(max_length=255, unique=True)
hashed_password = fields.CharField(max_length=255)
@ -11,7 +11,9 @@ class User(models.Model):
is_superuser = fields.BooleanField(default=False)
created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True)
storage_quota_bytes = fields.BigIntField(default=10 * 1024 * 1024 * 1024) # 10 GB default
storage_quota_bytes = fields.BigIntField(
default=10 * 1024 * 1024 * 1024
) # 10 GB default
used_storage_bytes = fields.BigIntField(default=0)
plan_type = fields.CharField(max_length=50, default="free")
two_factor_secret = fields.CharField(max_length=255, null=True)
@ -24,11 +26,16 @@ class User(models.Model):
def __str__(self):
return self.username
class Folder(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=255)
parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField("models.Folder", related_name="children", null=True)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="folders")
parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField(
"models.Folder", related_name="children", null=True
)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="folders"
)
created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True)
is_deleted = fields.BooleanField(default=False)
@ -36,21 +43,28 @@ class Folder(models.Model):
class Meta:
table = "folders"
unique_together = (("name", "parent", "owner"),) # Ensure unique folder names within a parent for an owner
unique_together = (
("name", "parent", "owner"),
) # Ensure unique folder names within a parent for an owner
def __str__(self):
return self.name
class File(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=255)
path = fields.CharField(max_length=1024) # Internal storage path
size = fields.BigIntField()
mime_type = fields.CharField(max_length=255)
file_hash = fields.CharField(max_length=64, null=True) # SHA-256
thumbnail_path = fields.CharField(max_length=1024, null=True)
parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="files", null=True)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="files")
parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
"models.Folder", related_name="files", null=True
)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="files"
)
created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True)
is_deleted = fields.BooleanField(default=False)
@ -60,14 +74,19 @@ class File(models.Model):
class Meta:
table = "files"
unique_together = (("name", "parent", "owner"),) # Ensure unique file names within a parent for an owner
unique_together = (
("name", "parent", "owner"),
) # Ensure unique file names within a parent for an owner
def __str__(self):
return self.name
class FileVersion(models.Model):
id = fields.IntField(pk=True)
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="versions")
id = fields.IntField(primary_key=True)
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField(
"models.File", related_name="versions"
)
version_path = fields.CharField(max_length=1024)
size = fields.BigIntField()
created_at = fields.DatetimeField(auto_now_add=True)
@ -75,45 +94,67 @@ class FileVersion(models.Model):
class Meta:
table = "file_versions"
class Share(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
token = fields.CharField(max_length=64, unique=True)
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="shares", null=True)
folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="shares", null=True)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="shares")
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField(
"models.File", related_name="shares", null=True
)
folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
"models.Folder", related_name="shares", null=True
)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="shares"
)
created_at = fields.DatetimeField(auto_now_add=True)
expires_at = fields.DatetimeField(null=True)
password_protected = fields.BooleanField(default=False)
hashed_password = fields.CharField(max_length=255, null=True)
access_count = fields.IntField(default=0)
permission_level = fields.CharField(max_length=50, default="viewer") # viewer, uploader, editor
permission_level = fields.CharField(
max_length=50, default="viewer"
) # viewer, uploader, editor
class Meta:
table = "shares"
class Team(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
name = fields.CharField(max_length=255, unique=True)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="owned_teams")
members: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User", related_name="teams", through="team_members")
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="owned_teams"
)
members: fields.ManyToManyRelation[User] = fields.ManyToManyField(
"models.User", related_name="teams", through="team_members"
)
created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
table = "teams"
class TeamMember(models.Model):
id = fields.IntField(pk=True)
team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField("models.Team", related_name="team_members")
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="user_teams")
id = fields.IntField(primary_key=True)
team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField(
"models.Team", related_name="team_members"
)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="user_teams"
)
role = fields.CharField(max_length=50, default="member") # owner, admin, member
class Meta:
table = "team_members"
unique_together = (("team", "user"),)
class Activity(models.Model):
id = fields.IntField(pk=True)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="activities", null=True)
id = fields.IntField(primary_key=True)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="activities", null=True
)
action = fields.CharField(max_length=255)
target_type = fields.CharField(max_length=50) # file, folder, share, user, team
target_id = fields.IntField()
@ -123,13 +164,18 @@ class Activity(models.Model):
class Meta:
table = "activities"
class FileRequest(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
title = fields.CharField(max_length=255)
description = fields.TextField(null=True)
token = fields.CharField(max_length=64, unique=True)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="file_requests")
target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="file_requests")
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", related_name="file_requests"
)
target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
"models.Folder", related_name="file_requests"
)
created_at = fields.DatetimeField(auto_now_add=True)
expires_at = fields.DatetimeField(null=True)
is_active = fields.BooleanField(default=True)
@ -137,8 +183,9 @@ class FileRequest(models.Model):
class Meta:
table = "file_requests"
class WebDAVProperty(models.Model):
id = fields.IntField(pk=True)
id = fields.IntField(primary_key=True)
resource_type = fields.CharField(max_length=10)
resource_id = fields.IntField()
namespace = fields.CharField(max_length=255)
@ -151,5 +198,8 @@ class WebDAVProperty(models.Model):
table = "webdav_properties"
unique_together = (("resource_type", "resource_id", "namespace", "name"),)
User_Pydantic = pydantic_model_creator(User, name="User_Pydantic")
UserIn_Pydantic = pydantic_model_creator(User, name="UserIn_Pydantic", exclude_readonly=True)
UserIn_Pydantic = pydantic_model_creator(
User, name="UserIn_Pydantic", exclude_readonly=True
)

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
@ -13,18 +13,25 @@ router = APIRouter(
responses={403: {"description": "Not enough permissions"}},
)
@router.get("/users", response_model=List[User_Pydantic])
async def get_all_users():
return await User.all()
@router.get("/users/{user_id}", response_model=User_Pydantic)
async def get_user(user_id: int):
user = await User.get_or_none(id=user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
return user
@router.post("/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED)
@router.post(
"/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED
)
async def create_user_by_admin(user_in: UserCreate):
user = await User.get_or_none(username=user_in.username)
if user:
@ -49,20 +56,28 @@ async def create_user_by_admin(user_in: UserCreate):
)
return await User_Pydantic.from_tortoise_orm(user)
@router.put("/users/{user_id}", response_model=User_Pydantic)
async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate):
user = await User.get_or_none(id=user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
if user_update.username is not None and user_update.username != user.username:
if await User.get_or_none(username=user_update.username):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
)
user.username = user_update.username
if user_update.email is not None and user_update.email != user.email:
if await User.get_or_none(email=user_update.email):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered",
)
user.email = user_update.email
if user_update.password is not None:
@ -89,21 +104,28 @@ async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate):
await user.save()
return await User_Pydantic.from_tortoise_orm(user)
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user_by_admin(user_id: int):
user = await User.get_or_none(id=user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
await user.delete()
return {"message": "User deleted successfully"}
@router.post("/test-email")
async def send_test_email(to_email: str, subject: str = "Test Email", body: str = "This is a test email"):
async def send_test_email(
to_email: str, subject: str = "Test Email", body: str = "This is a test email"
):
from ..mail import queue_email
queue_email(
to_email=to_email,
subject=subject,
body=body,
html=f"<h1>{subject}</h1><p>{body}</p>"
html=f"<h1>{subject}</h1><p>{body}</p>",
)
return {"message": "Test email queued"}

View File

@ -1,5 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException
from typing import List
from decimal import Decimal
from pydantic import BaseModel
@ -8,21 +7,25 @@ from ..models import User
from ..billing.models import PricingConfig, Invoice, SubscriptionPlan
from ..billing.invoice_generator import InvoiceGenerator
def require_superuser(current_user: User = Depends(get_current_user)):
if not current_user.is_superuser:
raise HTTPException(status_code=403, detail="Superuser privileges required")
return current_user
router = APIRouter(
prefix="/api/admin/billing",
tags=["admin", "billing"],
dependencies=[Depends(require_superuser)]
dependencies=[Depends(require_superuser)],
)
class PricingConfigUpdate(BaseModel):
config_key: str
config_value: float
class PlanCreate(BaseModel):
name: str
display_name: str
@ -32,6 +35,7 @@ class PlanCreate(BaseModel):
price_monthly: float
price_yearly: float = None
@router.get("/pricing")
async def get_all_pricing(current_user: User = Depends(require_superuser)):
configs = await PricingConfig.all()
@ -42,16 +46,17 @@ async def get_all_pricing(current_user: User = Depends(require_superuser)):
"config_value": float(c.config_value),
"description": c.description,
"unit": c.unit,
"updated_at": c.updated_at
"updated_at": c.updated_at,
}
for c in configs
]
@router.put("/pricing/{config_id}")
async def update_pricing(
config_id: int,
update: PricingConfigUpdate,
current_user: User = Depends(require_superuser)
current_user: User = Depends(require_superuser),
):
config = await PricingConfig.get_or_none(id=config_id)
if not config:
@ -63,11 +68,10 @@ async def update_pricing(
return {"message": "Pricing updated successfully"}
@router.post("/generate-invoices/{year}/{month}")
async def generate_all_invoices(
year: int,
month: int,
current_user: User = Depends(require_superuser)
year: int, month: int, current_user: User = Depends(require_superuser)
):
users = await User.filter(is_active=True).all()
generated = []
@ -76,35 +80,36 @@ async def generate_all_invoices(
for user in users:
invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, month)
if invoice:
generated.append({
generated.append(
{
"user_id": user.id,
"invoice_id": invoice.id,
"total": float(invoice.total)
})
"total": float(invoice.total),
}
)
else:
skipped.append(user.id)
return {
"generated": len(generated),
"skipped": len(skipped),
"invoices": generated
}
return {"generated": len(generated), "skipped": len(skipped), "invoices": generated}
@router.post("/plans")
async def create_plan(
plan_data: PlanCreate,
current_user: User = Depends(require_superuser)
plan_data: PlanCreate, current_user: User = Depends(require_superuser)
):
plan = await SubscriptionPlan.create(**plan_data.dict())
return {"id": plan.id, "message": "Plan created successfully"}
@router.get("/stats")
async def get_billing_stats(current_user: User = Depends(require_superuser)):
from tortoise.functions import Sum, Count
from tortoise.functions import Sum
total_revenue = await Invoice.filter(status="paid").annotate(
total_sum=Sum("total")
).values("total_sum")
total_revenue = (
await Invoice.filter(status="paid")
.annotate(total_sum=Sum("total"))
.values("total_sum")
)
invoice_count = await Invoice.all().count()
pending_invoices = await Invoice.filter(status="open").count()
@ -112,5 +117,5 @@ async def get_billing_stats(current_user: User = Depends(require_superuser)):
return {
"total_revenue": float(total_revenue[0]["total_sum"] or 0),
"total_invoices": invoice_count,
"pending_invoices": pending_invoices
"pending_invoices": pending_invoices,
}

View File

@ -4,13 +4,23 @@ from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from ..auth import authenticate_user, create_access_token, get_password_hash, get_current_user, get_current_verified_user, verify_password
from ..auth import (
authenticate_user,
create_access_token,
get_password_hash,
get_current_user,
get_current_verified_user,
verify_password,
)
from ..models import User
from ..schemas import Token, UserCreate, TokenData, UserLoginWith2FA
from ..schemas import Token, UserCreate
from ..two_factor import (
generate_totp_secret, generate_totp_uri, generate_qr_code_base64,
verify_totp_code, generate_recovery_codes, hash_recovery_codes,
verify_recovery_codes
generate_totp_secret,
generate_totp_uri,
generate_qr_code_base64,
verify_totp_code,
generate_recovery_codes,
hash_recovery_codes,
)
router = APIRouter(
@ -18,27 +28,33 @@ router = APIRouter(
tags=["auth"],
)
class LoginRequest(BaseModel):
username: str
password: str
class TwoFactorLogin(BaseModel):
username: str
password: str
two_factor_code: Optional[str] = None
class TwoFactorSetupResponse(BaseModel):
secret: str
qr_code_base64: str
recovery_codes: List[str]
class TwoFactorCode(BaseModel):
two_factor_code: str
class TwoFactorDisable(BaseModel):
password: str
two_factor_code: str
@router.post("/register", response_model=Token)
async def register_user(user_in: UserCreate):
user = await User.get_or_none(username=user_in.username)
@ -63,11 +79,12 @@ async def register_user(user_in: UserCreate):
# Send welcome email
from ..mail import queue_email
queue_email(
to_email=user.email,
subject="Welcome to MyWebdav!",
body=f"Hi {user.username},\n\nWelcome to MyWebdav! Your account has been created successfully.\n\nBest regards,\nThe MyWebdav Team",
html=f"<h1>Welcome to MyWebdav!</h1><p>Hi {user.username},</p><p>Welcome to MyWebdav! Your account has been created successfully.</p><p>Best regards,<br>The MyWebdav Team</p>"
html=f"<h1>Welcome to MyWebdav!</h1><p>Hi {user.username},</p><p>Welcome to MyWebdav! Your account has been created successfully.</p><p>Best regards,<br>The MyWebdav Team</p>",
)
access_token_expires = timedelta(minutes=30) # Use settings
@ -76,9 +93,12 @@ async def register_user(user_in: UserCreate):
)
return {"access_token": access_token, "token_type": "bearer"}
@router.post("/token", response_model=Token)
async def login_for_access_token(login_data: LoginRequest):
auth_result = await authenticate_user(login_data.username, login_data.password, None)
auth_result = await authenticate_user(
login_data.username, login_data.password, None
)
if not auth_result:
raise HTTPException(
@ -97,16 +117,26 @@ async def login_for_access_token(login_data: LoginRequest):
access_token_expires = timedelta(minutes=30)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires, two_factor_verified=True
data={"sub": user.username},
expires_delta=access_token_expires,
two_factor_verified=True,
)
return {"access_token": access_token, "token_type": "bearer"}
@router.post("/2fa/setup", response_model=TwoFactorSetupResponse)
async def setup_two_factor_authentication(current_user: User = Depends(get_current_user)):
async def setup_two_factor_authentication(
current_user: User = Depends(get_current_user),
):
if current_user.is_2fa_enabled:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
)
if current_user.two_factor_secret:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup already initiated. Verify or disable first.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="2FA setup already initiated. Verify or disable first.",
)
secret = generate_totp_secret()
current_user.two_factor_secret = secret
@ -120,39 +150,66 @@ async def setup_two_factor_authentication(current_user: User = Depends(get_curre
current_user.recovery_codes = ",".join(hashed_recovery_codes)
await current_user.save()
return TwoFactorSetupResponse(secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes)
return TwoFactorSetupResponse(
secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes
)
@router.post("/2fa/verify", response_model=Token)
async def verify_two_factor_authentication(two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)):
async def verify_two_factor_authentication(
two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)
):
if current_user.is_2fa_enabled:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
)
if not current_user.two_factor_secret:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated."
)
if not verify_totp_code(current_user.two_factor_secret, two_factor_code_data.two_factor_code):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")
if not verify_totp_code(
current_user.two_factor_secret, two_factor_code_data.two_factor_code
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
)
current_user.is_2fa_enabled = True
await current_user.save()
access_token_expires = timedelta(minutes=30) # Use settings
access_token = create_access_token(
data={"sub": current_user.username}, expires_delta=access_token_expires, two_factor_verified=True
data={"sub": current_user.username},
expires_delta=access_token_expires,
two_factor_verified=True,
)
return {"access_token": access_token, "token_type": "bearer"}
@router.post("/2fa/disable", response_model=dict)
async def disable_two_factor_authentication(disable_data: TwoFactorDisable, current_user: User = Depends(get_current_verified_user)):
async def disable_two_factor_authentication(
disable_data: TwoFactorDisable,
current_user: User = Depends(get_current_verified_user),
):
if not current_user.is_2fa_enabled:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
)
# Verify password
if not verify_password(disable_data.password, current_user.hashed_password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password.")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password."
)
# Verify 2FA code
if not verify_totp_code(current_user.two_factor_secret, disable_data.two_factor_code):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")
if not verify_totp_code(
current_user.two_factor_secret, disable_data.two_factor_code
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
)
current_user.two_factor_secret = None
current_user.is_2fa_enabled = False
@ -161,10 +218,15 @@ async def disable_two_factor_authentication(disable_data: TwoFactorDisable, curr
return {"message": "2FA disabled successfully."}
@router.get("/2fa/recovery-codes", response_model=List[str])
async def get_new_recovery_codes(current_user: User = Depends(get_current_verified_user)):
async def get_new_recovery_codes(
current_user: User = Depends(get_current_verified_user),
):
if not current_user.is_2fa_enabled:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
)
recovery_codes = generate_recovery_codes()
hashed_recovery_codes = hash_recovery_codes(recovery_codes)

View File

@ -1,25 +1,25 @@
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from typing import List, Optional
from datetime import datetime, date
from decimal import Decimal
import calendar
from ..auth import get_current_user
from ..models import User
from ..billing.models import (
Invoice, InvoiceLineItem, UserSubscription, PricingConfig,
PaymentMethod, UsageAggregate, SubscriptionPlan
Invoice,
UserSubscription,
PricingConfig,
PaymentMethod,
UsageAggregate,
SubscriptionPlan,
)
from ..billing.usage_tracker import UsageTracker
from ..billing.invoice_generator import InvoiceGenerator
from ..billing.stripe_client import StripeClient
from pydantic import BaseModel
router = APIRouter(
prefix="/api/billing",
tags=["billing"]
)
router = APIRouter(prefix="/api/billing", tags=["billing"])
class UsageResponse(BaseModel):
storage_gb_avg: float
@ -29,6 +29,7 @@ class UsageResponse(BaseModel):
total_bandwidth_gb: float
period: str
class InvoiceResponse(BaseModel):
id: int
invoice_number: str
@ -42,6 +43,7 @@ class InvoiceResponse(BaseModel):
paid_at: Optional[datetime]
line_items: List[dict]
class SubscriptionResponse(BaseModel):
id: int
billing_type: str
@ -50,6 +52,7 @@ class SubscriptionResponse(BaseModel):
current_period_start: Optional[datetime]
current_period_end: Optional[datetime]
@router.get("/usage/current")
async def get_current_usage(current_user: User = Depends(get_current_user)):
try:
@ -61,25 +64,32 @@ async def get_current_usage(current_user: User = Depends(get_current_user)):
if usage_today:
return {
"storage_gb": round(storage_bytes / (1024**3), 4),
"bandwidth_down_gb_today": round(usage_today.bandwidth_down_bytes / (1024**3), 4),
"bandwidth_up_gb_today": round(usage_today.bandwidth_up_bytes / (1024**3), 4),
"as_of": today.isoformat()
"bandwidth_down_gb_today": round(
usage_today.bandwidth_down_bytes / (1024**3), 4
),
"bandwidth_up_gb_today": round(
usage_today.bandwidth_up_bytes / (1024**3), 4
),
"as_of": today.isoformat(),
}
return {
"storage_gb": round(storage_bytes / (1024**3), 4),
"bandwidth_down_gb_today": 0,
"bandwidth_up_gb_today": 0,
"as_of": today.isoformat()
"as_of": today.isoformat(),
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to fetch usage data: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to fetch usage data: {str(e)}"
)
@router.get("/usage/monthly")
async def get_monthly_usage(
year: Optional[int] = None,
month: Optional[int] = None,
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
) -> UsageResponse:
try:
if year is None or month is None:
@ -88,39 +98,50 @@ async def get_monthly_usage(
month = now.month
if not (1 <= month <= 12):
raise HTTPException(status_code=400, detail="Month must be between 1 and 12")
raise HTTPException(
status_code=400, detail="Month must be between 1 and 12"
)
if not (2020 <= year <= 2100):
raise HTTPException(status_code=400, detail="Year must be between 2020 and 2100")
raise HTTPException(
status_code=400, detail="Year must be between 2020 and 2100"
)
usage = await UsageTracker.get_monthly_usage(current_user, year, month)
return UsageResponse(
**usage,
period=f"{year}-{month:02d}"
)
return UsageResponse(**usage, period=f"{year}-{month:02d}")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}"
)
@router.get("/invoices")
async def list_invoices(
limit: int = 50,
offset: int = 0,
current_user: User = Depends(get_current_user)
limit: int = 50, offset: int = 0, current_user: User = Depends(get_current_user)
) -> List[InvoiceResponse]:
try:
if limit < 1 or limit > 100:
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
raise HTTPException(
status_code=400, detail="Limit must be between 1 and 100"
)
if offset < 0:
raise HTTPException(status_code=400, detail="Offset must be non-negative")
invoices = await Invoice.filter(user=current_user).order_by("-created_at").offset(offset).limit(limit).all()
invoices = (
await Invoice.filter(user=current_user)
.order_by("-created_at")
.offset(offset)
.limit(limit)
.all()
)
result = []
for invoice in invoices:
line_items = await invoice.line_items.all()
result.append(InvoiceResponse(
result.append(
InvoiceResponse(
id=invoice.id,
invoice_number=invoice.invoice_number,
period_start=invoice.period_start,
@ -137,22 +158,25 @@ async def list_invoices(
"quantity": float(item.quantity),
"unit_price": float(item.unit_price),
"amount": float(item.amount),
"type": item.item_type
"type": item.item_type,
}
for item in line_items
]
))
],
)
)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to fetch invoices: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to fetch invoices: {str(e)}"
)
@router.get("/invoices/{invoice_id}")
async def get_invoice(
invoice_id: int,
current_user: User = Depends(get_current_user)
invoice_id: int, current_user: User = Depends(get_current_user)
) -> InvoiceResponse:
invoice = await Invoice.get_or_none(id=invoice_id, user=current_user)
if not invoice:
@ -177,21 +201,22 @@ async def get_invoice(
"quantity": float(item.quantity),
"unit_price": float(item.unit_price),
"amount": float(item.amount),
"type": item.item_type
"type": item.item_type,
}
for item in line_items
]
],
)
@router.get("/subscription")
async def get_subscription(current_user: User = Depends(get_current_user)) -> SubscriptionResponse:
async def get_subscription(
current_user: User = Depends(get_current_user),
) -> SubscriptionResponse:
subscription = await UserSubscription.get_or_none(user=current_user)
if not subscription:
subscription = await UserSubscription.create(
user=current_user,
billing_type="pay_as_you_go",
status="active"
user=current_user, billing_type="pay_as_you_go", status="active"
)
plan_name = None
@ -205,15 +230,19 @@ async def get_subscription(current_user: User = Depends(get_current_user)) -> Su
plan_name=plan_name,
status=subscription.status,
current_period_start=subscription.current_period_start,
current_period_end=subscription.current_period_end
current_period_end=subscription.current_period_end,
)
@router.post("/payment-methods/setup-intent")
async def create_setup_intent(current_user: User = Depends(get_current_user)):
try:
from ..settings import settings
if not settings.STRIPE_SECRET_KEY:
raise HTTPException(status_code=503, detail="Payment processing not configured")
raise HTTPException(
status_code=503, detail="Payment processing not configured"
)
subscription = await UserSubscription.get_or_none(user=current_user)
@ -221,7 +250,7 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
customer_id = await StripeClient.create_customer(
email=current_user.email,
name=current_user.username,
metadata={"user_id": str(current_user.id)}
metadata={"user_id": str(current_user.id)},
)
if not subscription:
@ -229,27 +258,30 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
user=current_user,
billing_type="pay_as_you_go",
stripe_customer_id=customer_id,
status="active"
status="active",
)
else:
subscription.stripe_customer_id = customer_id
await subscription.save()
import stripe
StripeClient._ensure_api_key()
setup_intent = stripe.SetupIntent.create(
customer=subscription.stripe_customer_id,
payment_method_types=["card"]
customer=subscription.stripe_customer_id, payment_method_types=["card"]
)
return {
"client_secret": setup_intent.client_secret,
"customer_id": subscription.stripe_customer_id
"customer_id": subscription.stripe_customer_id,
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create setup intent: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to create setup intent: {str(e)}"
)
@router.get("/payment-methods")
async def list_payment_methods(current_user: User = Depends(get_current_user)):
@ -262,11 +294,12 @@ async def list_payment_methods(current_user: User = Depends(get_current_user)):
"brand": m.brand,
"exp_month": m.exp_month,
"exp_year": m.exp_year,
"is_default": m.is_default
"is_default": m.is_default,
}
for m in methods
]
@router.post("/webhooks/stripe")
async def stripe_webhook(request: Request):
import stripe
@ -298,12 +331,14 @@ async def stripe_webhook(request: Request):
event_type=event["type"],
stripe_event_id=event_id,
data=event["data"],
processed=False
processed=False,
)
if event["type"] == "invoice.payment_succeeded":
invoice_data = event["data"]["object"]
mywebdav_invoice_id = invoice_data.get("metadata", {}).get("mywebdav_invoice_id")
mywebdav_invoice_id = invoice_data.get("metadata", {}).get(
"mywebdav_invoice_id"
)
if mywebdav_invoice_id:
invoice = await Invoice.get_or_none(id=int(mywebdav_invoice_id))
@ -317,7 +352,9 @@ async def stripe_webhook(request: Request):
payment_method = event["data"]["object"]
customer_id = payment_method["customer"]
subscription = await UserSubscription.get_or_none(stripe_customer_id=customer_id)
subscription = await UserSubscription.get_or_none(
stripe_customer_id=customer_id
)
if subscription:
await PaymentMethod.create(
user=subscription.user,
@ -327,7 +364,7 @@ async def stripe_webhook(request: Request):
brand=payment_method.get("card", {}).get("brand"),
exp_month=payment_method.get("card", {}).get("exp_month"),
exp_year=payment_method.get("card", {}).get("exp_year"),
is_default=True
is_default=True,
)
await BillingEvent.filter(stripe_event_id=event_id).update(processed=True)
@ -335,7 +372,10 @@ async def stripe_webhook(request: Request):
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Webhook processing failed: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Webhook processing failed: {str(e)}"
)
@router.get("/pricing")
async def get_pricing():
@ -344,11 +384,12 @@ async def get_pricing():
config.config_key: {
"value": float(config.config_value),
"description": config.description,
"unit": config.unit
"unit": config.unit,
}
for config in configs
}
@router.get("/plans")
async def list_plans():
plans = await SubscriptionPlan.filter(is_active=True).all()
@ -361,14 +402,16 @@ async def list_plans():
"storage_gb": plan.storage_gb,
"bandwidth_gb": plan.bandwidth_gb,
"price_monthly": float(plan.price_monthly),
"price_yearly": float(plan.price_yearly) if plan.price_yearly else None
"price_yearly": float(plan.price_yearly) if plan.price_yearly else None,
}
for plan in plans
]
@router.get("/stripe-key")
async def get_stripe_key():
from ..settings import settings
if not settings.STRIPE_PUBLISHABLE_KEY:
raise HTTPException(status_code=503, detail="Payment processing not configured")
return {"publishable_key": settings.STRIPE_PUBLISHABLE_KEY}

View File

@ -1,4 +1,12 @@
from fastapi import APIRouter, Depends, UploadFile, File as FastAPIFile, HTTPException, status, Response, Form
from fastapi import (
APIRouter,
Depends,
UploadFile,
File as FastAPIFile,
HTTPException,
status,
Form,
)
from fastapi.responses import StreamingResponse
from typing import List, Optional
import mimetypes
@ -11,7 +19,6 @@ from ..auth import get_current_user
from ..models import User, File, Folder
from ..schemas import FileOut
from ..storage import storage_manager
from ..settings import settings
from ..activity import log_activity
from ..thumbnails import generate_thumbnail, delete_thumbnail
@ -20,35 +27,46 @@ router = APIRouter(
tags=["files"],
)
class FileMove(BaseModel):
target_folder_id: Optional[int] = None
class FileRename(BaseModel):
new_name: str
class FileCopy(BaseModel):
target_folder_id: Optional[int] = None
class BatchFileOperation(BaseModel):
file_ids: List[int]
operation: str # e.g., "delete", "star", "unstar", "move", "copy"
class BatchMoveCopyPayload(BaseModel):
target_folder_id: Optional[int] = None
class FileContentUpdate(BaseModel):
content: str
@router.post("/upload", response_model=FileOut, status_code=status.HTTP_201_CREATED)
async def upload_file(
file: UploadFile = FastAPIFile(...),
folder_id: Optional[int] = Form(None),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
if folder_id:
parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
parent_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not parent_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
else:
parent_folder = None
@ -105,16 +123,20 @@ async def upload_file(
return await FileOut.from_tortoise_orm(db_file)
@router.get("/download/{file_id}")
async def download_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.last_accessed_at = datetime.now()
await db_file.save()
try:
async def file_iterator():
async for chunk in storage_manager.get_file(current_user.id, db_file.path):
yield chunk
@ -122,16 +144,21 @@ async def download_file(file_id: int, current_user: User = Depends(get_current_u
return StreamingResponse(
file_iterator(),
media_type=db_file.mime_type,
headers={"Content-Disposition": f"attachment; filename=\"{db_file.name}\""}
headers={"Content-Disposition": f'attachment; filename="{db_file.name}"'},
)
except FileNotFoundError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage"
)
@router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.is_deleted = True
db_file.deleted_at = datetime.now()
@ -141,68 +168,110 @@ async def delete_file(file_id: int, current_user: User = Depends(get_current_use
return
@router.post("/{file_id}/move", response_model=FileOut)
async def move_file(file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)):
async def move_file(
file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
target_folder = None
if move_data.target_folder_id:
target_folder = await Folder.get_or_none(id=move_data.target_folder_id, owner=current_user, is_deleted=False)
target_folder = await Folder.get_or_none(
id=move_data.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
)
existing_file = await File.get_or_none(
name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False
)
if existing_file and existing_file.id != file_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in target folder")
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="File with this name already exists in target folder",
)
db_file.parent = target_folder
await db_file.save()
await log_activity(user=current_user, action="file_moved", target_type="file", target_id=file_id)
await log_activity(
user=current_user, action="file_moved", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file)
@router.post("/{file_id}/rename", response_model=FileOut)
async def rename_file(file_id: int, rename_data: FileRename, current_user: User = Depends(get_current_user)):
async def rename_file(
file_id: int,
rename_data: FileRename,
current_user: User = Depends(get_current_user),
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
existing_file = await File.get_or_none(
name=rename_data.new_name, parent_id=db_file.parent_id, owner=current_user, is_deleted=False
name=rename_data.new_name,
parent_id=db_file.parent_id,
owner=current_user,
is_deleted=False,
)
if existing_file and existing_file.id != file_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in the same folder")
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="File with this name already exists in the same folder",
)
db_file.name = rename_data.new_name
await db_file.save()
await log_activity(user=current_user, action="file_renamed", target_type="file", target_id=file_id)
await log_activity(
user=current_user, action="file_renamed", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file)
@router.post("/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED)
async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)):
@router.post(
"/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED
)
async def copy_file(
file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
target_folder = None
if copy_data.target_folder_id:
target_folder = await Folder.get_or_none(id=copy_data.target_folder_id, owner=current_user, is_deleted=False)
target_folder = await Folder.get_or_none(
id=copy_data.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
)
base_name = db_file.name
name_parts = os.path.splitext(base_name)
counter = 1
new_name = base_name
while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False):
while await File.get_or_none(
name=new_name, parent=target_folder, owner=current_user, is_deleted=False
):
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
counter += 1
@ -213,139 +282,205 @@ async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depe
mime_type=db_file.mime_type,
file_hash=db_file.file_hash,
owner=current_user,
parent=target_folder
parent=target_folder,
)
await log_activity(user=current_user, action="file_copied", target_type="file", target_id=new_file.id)
await log_activity(
user=current_user,
action="file_copied",
target_type="file",
target_id=new_file.id,
)
return await FileOut.from_tortoise_orm(new_file)
@router.get("/", response_model=List[FileOut])
async def list_files(folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)):
async def list_files(
folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)
):
if folder_id:
parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
parent_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not parent_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
files = await File.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
files = await File.filter(
parent=parent_folder, owner=current_user, is_deleted=False
).order_by("name")
else:
files = await File.filter(parent=None, owner=current_user, is_deleted=False).order_by("name")
files = await File.filter(
parent=None, owner=current_user, is_deleted=False
).order_by("name")
return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/thumbnail/{file_id}")
async def get_thumbnail(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.last_accessed_at = datetime.now()
await db_file.save()
thumbnail_path = getattr(db_file, 'thumbnail_path', None)
thumbnail_path = getattr(db_file, "thumbnail_path", None)
if not thumbnail_path:
thumbnail_path = await generate_thumbnail(db_file.path, db_file.mime_type, current_user.id)
thumbnail_path = await generate_thumbnail(
db_file.path, db_file.mime_type, current_user.id
)
if thumbnail_path:
db_file.thumbnail_path = thumbnail_path
await db_file.save()
else:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available"
)
try:
async def thumbnail_iterator():
async for chunk in storage_manager.get_file(current_user.id, thumbnail_path):
async for chunk in storage_manager.get_file(
current_user.id, thumbnail_path
):
yield chunk
return StreamingResponse(
thumbnail_iterator(),
media_type="image/jpeg"
)
return StreamingResponse(thumbnail_iterator(), media_type="image/jpeg")
except FileNotFoundError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not found in storage")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Thumbnail not found in storage",
)
@router.get("/photos", response_model=List[FileOut])
async def list_photos(current_user: User = Depends(get_current_user)):
files = await File.filter(
owner=current_user,
is_deleted=False,
mime_type__istartswith="image/"
owner=current_user, is_deleted=False, mime_type__istartswith="image/"
).order_by("-created_at")
return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/recent", response_model=List[FileOut])
async def list_recent_files(current_user: User = Depends(get_current_user), limit: int = 10):
files = await File.filter(
owner=current_user,
is_deleted=False,
last_accessed_at__isnull=False
).order_by("-last_accessed_at").limit(limit)
async def list_recent_files(
current_user: User = Depends(get_current_user), limit: int = 10
):
files = (
await File.filter(
owner=current_user, is_deleted=False, last_accessed_at__isnull=False
)
.order_by("-last_accessed_at")
.limit(limit)
)
return [await FileOut.from_tortoise_orm(f) for f in files]
@router.post("/{file_id}/star", response_model=FileOut)
async def star_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.is_starred = True
await db_file.save()
await log_activity(user=current_user, action="file_starred", target_type="file", target_id=file_id)
await log_activity(
user=current_user, action="file_starred", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file)
@router.post("/{file_id}/unstar", response_model=FileOut)
async def unstar_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
db_file.is_starred = False
await db_file.save()
await log_activity(user=current_user, action="file_unstarred", target_type="file", target_id=file_id)
await log_activity(
user=current_user,
action="file_unstarred",
target_type="file",
target_id=file_id,
)
return await FileOut.from_tortoise_orm(db_file)
@router.get("/deleted", response_model=List[FileOut])
async def list_deleted_files(current_user: User = Depends(get_current_user)):
files = await File.filter(owner=current_user, is_deleted=True).order_by("-deleted_at")
files = await File.filter(owner=current_user, is_deleted=True).order_by(
"-deleted_at"
)
return [await FileOut.from_tortoise_orm(f) for f in files]
@router.post("/{file_id}/restore", response_model=FileOut)
async def restore_file(file_id: int, current_user: User = Depends(get_current_user)):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=True)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found"
)
# Check if a file with the same name exists in the parent folder
existing_file = await File.get_or_none(
name=db_file.name, parent=db_file.parent, owner=current_user, is_deleted=False
)
if existing_file:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.")
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.",
)
db_file.is_deleted = False
db_file.deleted_at = None
await db_file.save()
await log_activity(user=current_user, action="file_restored", target_type="file", target_id=file_id)
await log_activity(
user=current_user, action="file_restored", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file)
class BatchOperationResult(BaseModel):
succeeded: List[FileOut]
failed: List[dict]
@router.post("/batch")
async def batch_file_operations(
batch_operation: BatchFileOperation,
payload: Optional[BatchMoveCopyPayload] = None,
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
if batch_operation.operation not in ["delete", "star", "unstar", "move", "copy"]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid operation: {batch_operation.operation}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid operation: {batch_operation.operation}",
)
updated_files = []
failed_operations = []
for file_id in batch_operation.file_ids:
try:
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
db_file = await File.get_or_none(
id=file_id, owner=current_user, is_deleted=False
)
if not db_file:
failed_operations.append({"file_id": file_id, "reason": "File not found or not owned by user"})
failed_operations.append(
{
"file_id": file_id,
"reason": "File not found or not owned by user",
}
)
continue
if batch_operation.operation == "delete":
@ -353,47 +488,87 @@ async def batch_file_operations(
db_file.deleted_at = datetime.now()
await db_file.save()
await delete_thumbnail(db_file.id)
await log_activity(user=current_user, action="file_deleted_batch", target_type="file", target_id=file_id)
await log_activity(
user=current_user,
action="file_deleted_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file)
elif batch_operation.operation == "star":
db_file.is_starred = True
await db_file.save()
await log_activity(user=current_user, action="file_starred_batch", target_type="file", target_id=file_id)
await log_activity(
user=current_user,
action="file_starred_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file)
elif batch_operation.operation == "unstar":
db_file.is_starred = False
await db_file.save()
await log_activity(user=current_user, action="file_unstarred_batch", target_type="file", target_id=file_id)
await log_activity(
user=current_user,
action="file_unstarred_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file)
elif batch_operation.operation == "move":
if not payload or payload.target_folder_id is None:
failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"})
failed_operations.append(
{"file_id": file_id, "reason": "Target folder not specified"}
)
continue
target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
target_folder = await Folder.get_or_none(
id=payload.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder:
failed_operations.append({"file_id": file_id, "reason": "Target folder not found"})
failed_operations.append(
{"file_id": file_id, "reason": "Target folder not found"}
)
continue
existing_file = await File.get_or_none(
name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False
name=db_file.name,
parent=target_folder,
owner=current_user,
is_deleted=False,
)
if existing_file and existing_file.id != file_id:
failed_operations.append({"file_id": file_id, "reason": "File with same name exists in target folder"})
failed_operations.append(
{
"file_id": file_id,
"reason": "File with same name exists in target folder",
}
)
continue
db_file.parent = target_folder
await db_file.save()
await log_activity(user=current_user, action="file_moved_batch", target_type="file", target_id=file_id)
await log_activity(
user=current_user,
action="file_moved_batch",
target_type="file",
target_id=file_id,
)
updated_files.append(db_file)
elif batch_operation.operation == "copy":
if not payload or payload.target_folder_id is None:
failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"})
failed_operations.append(
{"file_id": file_id, "reason": "Target folder not specified"}
)
continue
target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
target_folder = await Folder.get_or_none(
id=payload.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder:
failed_operations.append({"file_id": file_id, "reason": "Target folder not found"})
failed_operations.append(
{"file_id": file_id, "reason": "Target folder not found"}
)
continue
base_name = db_file.name
@ -401,7 +576,12 @@ async def batch_file_operations(
counter = 1
new_name = base_name
while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False):
while await File.get_or_none(
name=new_name,
parent=target_folder,
owner=current_user,
is_deleted=False,
):
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
counter += 1
@ -412,48 +592,70 @@ async def batch_file_operations(
mime_type=db_file.mime_type,
file_hash=db_file.file_hash,
owner=current_user,
parent=target_folder
parent=target_folder,
)
await log_activity(
user=current_user,
action="file_copied_batch",
target_type="file",
target_id=new_file.id,
)
await log_activity(user=current_user, action="file_copied_batch", target_type="file", target_id=new_file.id)
updated_files.append(new_file)
except Exception as e:
failed_operations.append({"file_id": file_id, "reason": str(e)})
return {
"succeeded": [await FileOut.from_tortoise_orm(f) for f in updated_files],
"failed": failed_operations
"failed": failed_operations,
}
@router.put("/{file_id}/content", response_model=FileOut)
async def update_file_content(
file_id: int,
payload: FileContentUpdate,
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
if not db_file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
if not db_file.mime_type or not db_file.mime_type.startswith('text/'):
if not db_file.mime_type or not db_file.mime_type.startswith("text/"):
editableExtensions = [
'txt', 'md', 'log', 'json', 'js', 'py', 'html', 'css',
'xml', 'yaml', 'yml', 'sh', 'bat', 'ini', 'conf', 'cfg'
"txt",
"md",
"log",
"json",
"js",
"py",
"html",
"css",
"xml",
"yaml",
"yml",
"sh",
"bat",
"ini",
"conf",
"cfg",
]
file_extension = os.path.splitext(db_file.name)[1][1:].lower()
if file_extension not in editableExtensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="File type is not editable"
detail="File type is not editable",
)
content_bytes = payload.content.encode('utf-8')
content_bytes = payload.content.encode("utf-8")
new_size = len(content_bytes)
size_diff = new_size - db_file.size
if current_user.used_storage_bytes + size_diff > current_user.storage_quota_bytes:
raise HTTPException(
status_code=status.HTTP_507_INSUFFICIENT_STORAGE,
detail="Storage quota exceeded"
detail="Storage quota exceeded",
)
new_hash = hashlib.sha256(content_bytes).hexdigest()
@ -465,7 +667,7 @@ async def update_file_content(
if new_storage_path != db_file.path:
try:
await storage_manager.delete_file(current_user.id, db_file.path)
except:
except Exception:
pass
db_file.path = new_storage_path
@ -477,6 +679,8 @@ async def update_file_content(
current_user.used_storage_bytes += size_diff
await current_user.save()
await log_activity(user=current_user, action="file_updated", target_type="file", target_id=file_id)
await log_activity(
user=current_user, action="file_updated", target_type="file", target_id=file_id
)
return await FileOut.from_tortoise_orm(db_file)

View File

@ -3,7 +3,13 @@ from typing import List, Optional
from ..auth import get_current_user
from ..models import User, Folder
from ..schemas import FolderCreate, FolderOut, FolderUpdate, BatchFolderOperation, BatchMoveCopyPayload
from ..schemas import (
FolderCreate,
FolderOut,
FolderUpdate,
BatchFolderOperation,
BatchMoveCopyPayload,
)
from ..activity import log_activity
router = APIRouter(
@ -11,12 +17,17 @@ router = APIRouter(
tags=["folders"],
)
@router.post("/", response_model=FolderOut, status_code=status.HTTP_201_CREATED)
async def create_folder(folder_in: FolderCreate, current_user: User = Depends(get_current_user)):
async def create_folder(
folder_in: FolderCreate, current_user: User = Depends(get_current_user)
):
# Check if parent folder exists and belongs to the current user
parent_folder = None
if folder_in.parent_id:
parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user)
parent_folder = await Folder.get_or_none(
id=folder_in.parent_id, owner=current_user
)
if not parent_folder:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -39,50 +50,82 @@ async def create_folder(folder_in: FolderCreate, current_user: User = Depends(ge
await log_activity(current_user, "folder_created", "folder", folder.id)
return await FolderOut.from_tortoise_orm(folder)
@router.get("/{folder_id}/path", response_model=List[FolderOut])
async def get_folder_path(folder_id: int, current_user: User = Depends(get_current_user)):
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
async def get_folder_path(
folder_id: int, current_user: User = Depends(get_current_user)
):
folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
path = []
current = folder
while current:
path.insert(0, await FolderOut.from_tortoise_orm(current))
if current.parent_id:
current = await Folder.get_or_none(id=current.parent_id, owner=current_user, is_deleted=False)
current = await Folder.get_or_none(
id=current.parent_id, owner=current_user, is_deleted=False
)
else:
current = None
return path
@router.get("/{folder_id}", response_model=FolderOut)
async def get_folder(folder_id: int, current_user: User = Depends(get_current_user)):
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
return await FolderOut.from_tortoise_orm(folder)
@router.get("/", response_model=List[FolderOut])
async def list_folders(parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)):
async def list_folders(
parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)
):
if parent_id:
parent_folder = await Folder.get_or_none(id=parent_id, owner=current_user, is_deleted=False)
parent_folder = await Folder.get_or_none(
id=parent_id, owner=current_user, is_deleted=False
)
if not parent_folder:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Parent folder not found or does not belong to the current user",
)
folders = await Folder.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name")
folders = await Folder.filter(
parent=parent_folder, owner=current_user, is_deleted=False
).order_by("name")
else:
# List root folders (folders with no parent)
folders = await Folder.filter(parent=None, owner=current_user, is_deleted=False).order_by("name")
folders = await Folder.filter(
parent=None, owner=current_user, is_deleted=False
).order_by("name")
return [await FolderOut.from_tortoise_orm(folder) for folder in folders]
@router.put("/{folder_id}", response_model=FolderOut)
async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: User = Depends(get_current_user)):
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
async def update_folder(
folder_id: int,
folder_in: FolderUpdate,
current_user: User = Depends(get_current_user),
):
folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
if folder_in.name:
existing_folder = await Folder.get_or_none(
@ -97,11 +140,16 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U
if folder_in.parent_id is not None:
if folder_in.parent_id == folder_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot set folder as its own parent")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot set folder as its own parent",
)
new_parent_folder = None
if folder_in.parent_id != 0: # 0 could represent moving to root
new_parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user, is_deleted=False)
new_parent_folder = await Folder.get_or_none(
id=folder_in.parent_id, owner=current_user, is_deleted=False
)
if not new_parent_folder:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -113,46 +161,64 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U
await log_activity(current_user, "folder_updated", "folder", folder.id)
return await FolderOut.from_tortoise_orm(folder)
@router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_folder(folder_id: int, current_user: User = Depends(get_current_user)):
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
folder.is_deleted = True
await folder.save()
await log_activity(current_user, "folder_deleted", "folder", folder.id)
return
@router.post("/{folder_id}/star", response_model=FolderOut)
async def star_folder(folder_id: int, current_user: User = Depends(get_current_user)):
db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
db_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not db_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
db_folder.is_starred = True
await db_folder.save()
await log_activity(current_user, "folder_starred", "folder", folder_id)
return await FolderOut.from_tortoise_orm(db_folder)
@router.post("/{folder_id}/unstar", response_model=FolderOut)
async def unstar_folder(folder_id: int, current_user: User = Depends(get_current_user)):
db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
db_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not db_folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
db_folder.is_starred = False
await db_folder.save()
await log_activity(current_user, "folder_unstarred", "folder", folder_id)
return await FolderOut.from_tortoise_orm(db_folder)
@router.post("/batch", response_model=List[FolderOut])
async def batch_folder_operations(
batch_operation: BatchFolderOperation,
payload: Optional[BatchMoveCopyPayload] = None,
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
updated_folders = []
for folder_id in batch_operation.folder_ids:
db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
db_folder = await Folder.get_or_none(
id=folder_id, owner=current_user, is_deleted=False
)
if not db_folder:
continue # Skip if folder not found or not owned by user
@ -171,13 +237,22 @@ async def batch_folder_operations(
await db_folder.save()
await log_activity(current_user, "folder_unstarred", "folder", folder_id)
updated_folders.append(db_folder)
elif batch_operation.operation == "move" and payload and payload.target_folder_id is not None:
target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
elif (
batch_operation.operation == "move"
and payload
and payload.target_folder_id is not None
):
target_folder = await Folder.get_or_none(
id=payload.target_folder_id, owner=current_user, is_deleted=False
)
if not target_folder:
continue
existing_folder = await Folder.get_or_none(
name=db_folder.name, parent=target_folder, owner=current_user, is_deleted=False
name=db_folder.name,
parent=target_folder,
owner=current_user,
is_deleted=False,
)
if existing_folder and existing_folder.id != folder_id:
continue

View File

@ -11,15 +11,22 @@ router = APIRouter(
tags=["search"],
)
@router.get("/files", response_model=List[FileOut])
async def search_files(
q: str = Query(..., min_length=1, description="Search query"),
file_type: Optional[str] = Query(None, description="Filter by MIME type prefix (e.g., 'image', 'video')"),
file_type: Optional[str] = Query(
None, description="Filter by MIME type prefix (e.g., 'image', 'video')"
),
min_size: Optional[int] = Query(None, description="Minimum file size in bytes"),
max_size: Optional[int] = Query(None, description="Maximum file size in bytes"),
date_from: Optional[datetime] = Query(None, description="Filter files created after this date"),
date_to: Optional[datetime] = Query(None, description="Filter files created before this date"),
current_user: User = Depends(get_current_user)
date_from: Optional[datetime] = Query(
None, description="Filter files created after this date"
),
date_to: Optional[datetime] = Query(
None, description="Filter files created before this date"
),
current_user: User = Depends(get_current_user),
):
query = File.filter(owner=current_user, is_deleted=False, name__icontains=q)
@ -41,15 +48,16 @@ async def search_files(
files = await query.order_by("-created_at").limit(100)
return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/folders", response_model=List[FolderOut])
async def search_folders(
q: str = Query(..., min_length=1, description="Search query"),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
folders = await Folder.filter(
owner=current_user,
is_deleted=False,
name__icontains=q
).order_by("-created_at").limit(100)
folders = (
await Folder.filter(owner=current_user, is_deleted=False, name__icontains=q)
.order_by("-created_at")
.limit(100)
)
return [await FolderOut.from_tortoise_orm(folder) for folder in folders]

View File

@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from typing import Optional, List
import secrets
from datetime import datetime, timedelta
from datetime import datetime
from ..auth import get_current_user
from ..models import User, File, Folder, Share
@ -17,25 +17,44 @@ router = APIRouter(
tags=["shares"],
)
@router.post("/", response_model=ShareOut, status_code=status.HTTP_201_CREATED)
async def create_share_link(share_in: ShareCreate, current_user: User = Depends(get_current_user)):
async def create_share_link(
share_in: ShareCreate, current_user: User = Depends(get_current_user)
):
if not share_in.file_id and not share_in.folder_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Either file_id or folder_id must be provided")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Either file_id or folder_id must be provided",
)
if share_in.file_id and share_in.folder_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot share both a file and a folder simultaneously")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot share both a file and a folder simultaneously",
)
file = None
folder = None
if share_in.file_id:
file = await File.get_or_none(id=share_in.file_id, owner=current_user, is_deleted=False)
file = await File.get_or_none(
id=share_in.file_id, owner=current_user, is_deleted=False
)
if not file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found or does not belong to you")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found or does not belong to you",
)
if share_in.folder_id:
folder = await Folder.get_or_none(id=share_in.folder_id, owner=current_user, is_deleted=False)
folder = await Folder.get_or_none(
id=share_in.folder_id, owner=current_user, is_deleted=False
)
if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found or does not belong to you")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Folder not found or does not belong to you",
)
token = secrets.token_urlsafe(16)
hashed_password = None
@ -60,8 +79,14 @@ async def create_share_link(share_in: ShareCreate, current_user: User = Depends(
item_type = "file" if file else "folder"
item_name = file.name if file else folder.name
expiry_text = f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}" if share_in.expires_at else ""
password_text = f"\n\nPassword: {share_in.password}" if share_in.password else ""
expiry_text = (
f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}"
if share_in.expires_at
else ""
)
password_text = (
f"\n\nPassword: {share_in.password}" if share_in.password else ""
)
email_body = f"""Hello,
@ -95,26 +120,32 @@ MyWebdav File Sharing Service"""
to_email=share_in.invite_email,
subject=f"{current_user.username} shared {item_name} with you",
body=email_body,
html=email_html
html=email_html,
)
except Exception as e:
print(f"Failed to send invitation email: {e}")
return await ShareOut.from_tortoise_orm(share)
@router.get("/my", response_model=List[ShareOut])
async def list_my_shares(current_user: User = Depends(get_current_user)):
shares = await Share.filter(owner=current_user).order_by("-created_at")
return [await ShareOut.from_tortoise_orm(share) for share in shares]
@router.get("/{share_token}", response_model=ShareOut)
async def get_share_link_info(share_token: str):
share = await Share.get_or_none(token=share_token)
if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
)
if share.expires_at and share.expires_at < datetime.utcnow():
raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
raise HTTPException(
status_code=status.HTTP_410_GONE, detail="Share link has expired"
)
# Increment access count
share.access_count += 1
@ -122,11 +153,17 @@ async def get_share_link_info(share_token: str):
return await ShareOut.from_tortoise_orm(share)
@router.put("/{share_id}", response_model=ShareOut)
async def update_share(share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)):
async def update_share(
share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)
):
share = await Share.get_or_none(id=share_id, owner=current_user)
if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Share link not found or does not belong to you",
)
if share_in.expires_at is not None:
share.expires_at = share_in.expires_at
@ -144,31 +181,42 @@ async def update_share(share_id: int, share_in: ShareCreate, current_user: User
await share.save()
return await ShareOut.from_tortoise_orm(share)
@router.post("/{share_token}/access")
async def access_shared_content(share_token: str, password: Optional[str] = None):
share = await Share.get_or_none(token=share_token)
if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
)
if share.expires_at and share.expires_at < datetime.utcnow():
raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
raise HTTPException(
status_code=status.HTTP_410_GONE, detail="Share link has expired"
)
if share.password_protected:
if not password or not verify_password(password, share.hashed_password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
)
result = {"message": "Access granted", "permission_level": share.permission_level}
if share.file_id:
file = await File.get_or_none(id=share.file_id, is_deleted=False)
if not file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
result["file"] = await FileOut.from_tortoise_orm(file)
result["type"] = "file"
elif share.folder_id:
folder = await Folder.get_or_none(id=share.folder_id, is_deleted=False)
if not folder:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
)
result["folder"] = await FolderOut.from_tortoise_orm(folder)
result["type"] = "folder"
@ -179,29 +227,42 @@ async def access_shared_content(share_token: str, password: Optional[str] = None
return result
@router.get("/{share_token}/download")
async def download_shared_file(share_token: str, password: Optional[str] = None):
share = await Share.get_or_none(token=share_token)
if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
)
if share.expires_at and share.expires_at < datetime.utcnow():
raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
raise HTTPException(
status_code=status.HTTP_410_GONE, detail="Share link has expired"
)
if share.password_protected:
if not password or not verify_password(password, share.hashed_password):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
)
if not share.file_id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="This share is not for a file")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="This share is not for a file",
)
file = await File.get_or_none(id=share.file_id, is_deleted=False)
if not file:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
)
owner = await User.get(id=file.owner_id)
try:
async def file_iterator():
async for chunk in storage_manager.get_file(owner.id, file.path):
yield chunk
@ -209,18 +270,22 @@ async def download_shared_file(share_token: str, password: Optional[str] = None)
return StreamingResponse(
content=file_iterator(),
media_type=file.mime_type,
headers={
"Content-Disposition": f'attachment; filename="{file.name}"'
}
headers={"Content-Disposition": f'attachment; filename="{file.name}"'},
)
except FileNotFoundError:
raise HTTPException(status_code=404, detail="File not found in storage")
@router.delete("/{share_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_share_link(share_id: int, current_user: User = Depends(get_current_user)):
async def delete_share_link(
share_id: int, current_user: User = Depends(get_current_user)
):
share = await Share.get_or_none(id=share_id, owner=current_user)
if not share:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Share link not found or does not belong to you",
)
await share.delete()
return

View File

@ -11,18 +11,29 @@ router = APIRouter(
tags=["starred"],
)
@router.get("/files", response_model=List[FileOut])
async def list_starred_files(current_user: User = Depends(get_current_user)):
files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
files = await File.filter(
owner=current_user, is_starred=True, is_deleted=False
).order_by("name")
return [await FileOut.from_tortoise_orm(f) for f in files]
@router.get("/folders", response_model=List[FolderOut])
async def list_starred_folders(current_user: User = Depends(get_current_user)):
folders = await Folder.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
folders = await Folder.filter(
owner=current_user, is_starred=True, is_deleted=False
).order_by("name")
return [await FolderOut.from_tortoise_orm(f) for f in folders]
@router.get("/all", response_model=List[FileOut]) # This will return files and folders as files for now
@router.get(
"/all", response_model=List[FileOut]
) # This will return files and folders as files for now
async def list_all_starred(current_user: User = Depends(get_current_user)):
starred_files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
starred_files = await File.filter(
owner=current_user, is_starred=True, is_deleted=False
).order_by("name")
# For simplicity, we'll return files only for now. A more complex solution would involve a union or a custom schema.
return [await FileOut.from_tortoise_orm(f) for f in starred_files]

View File

@ -1,17 +1,19 @@
from fastapi import APIRouter, Depends
from ..auth import get_current_user
from ..models import User_Pydantic, User, File, Folder
from typing import List, Dict, Any
from typing import Dict, Any
router = APIRouter(
prefix="/users",
tags=["users"],
)
@router.get("/me", response_model=User_Pydantic)
async def read_users_me(current_user: User = Depends(get_current_user)):
return await User_Pydantic.from_tortoise_orm(current_user)
@router.get("/me/export", response_model=Dict[str, Any])
async def export_my_data(current_user: User = Depends(get_current_user)):
"""
@ -35,6 +37,7 @@ async def export_my_data(current_user: User = Depends(get_current_user)):
# share information, etc., would also be included.
}
@router.delete("/me", status_code=204)
async def delete_my_account(current_user: User = Depends(get_current_user)):
"""
@ -48,5 +51,3 @@ async def delete_my_account(current_user: User = Depends(get_current_user)):
# Finally, delete the user account
await current_user.delete()
return {}

View File

@ -2,19 +2,24 @@ from datetime import datetime
from pydantic import BaseModel, EmailStr, ConfigDict
from typing import Optional, List
from tortoise.contrib.pydantic import pydantic_model_creator
from mywebdav.models import Folder, File, Share, FileVersion
class UserCreate(BaseModel):
username: str
email: EmailStr
password: str
class UserLogin(BaseModel):
username: str
password: str
class UserLoginWith2FA(UserLogin):
two_factor_code: Optional[str] = None
class UserAdminUpdate(BaseModel):
username: Optional[str] = None
email: Optional[EmailStr] = None
@ -25,22 +30,27 @@ class UserAdminUpdate(BaseModel):
plan_type: Optional[str] = None
is_2fa_enabled: Optional[bool] = None
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: str | None = None
two_factor_verified: bool = False
class FolderCreate(BaseModel):
name: str
parent_id: Optional[int] = None
class FolderUpdate(BaseModel):
name: Optional[str] = None
parent_id: Optional[int] = None
class ShareCreate(BaseModel):
file_id: Optional[int] = None
folder_id: Optional[int] = None
@ -49,9 +59,11 @@ class ShareCreate(BaseModel):
permission_level: str = "viewer"
invite_email: Optional[EmailStr] = None
class TeamCreate(BaseModel):
name: str
class TeamOut(BaseModel):
id: int
name: str
@ -60,6 +72,7 @@ class TeamOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
class ActivityOut(BaseModel):
id: int
user_id: Optional[int] = None
@ -71,12 +84,14 @@ class ActivityOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
class FileRequestCreate(BaseModel):
title: str
description: Optional[str] = None
target_folder_id: int
expires_at: Optional[datetime] = None
class FileRequestOut(BaseModel):
id: int
title: str
@ -90,25 +105,28 @@ class FileRequestOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
from mywebdav.models import Folder, File, Share, FileVersion
FolderOut = pydantic_model_creator(Folder, name="FolderOut")
FileOut = pydantic_model_creator(File, name="FileOut")
ShareOut = pydantic_model_creator(Share, name="ShareOut")
FileVersionOut = pydantic_model_creator(FileVersion, name="FileVersionOut")
class ErrorResponse(BaseModel):
code: int
message: str
details: Optional[str] = None
class BatchFileOperation(BaseModel):
file_ids: List[int]
operation: str # e.g., "delete", "move", "copy", "star", "unstar"
class BatchFolderOperation(BaseModel):
folder_ids: List[int]
operation: str # e.g., "delete", "move", "star", "unstar"
class BatchMoveCopyPayload(BaseModel):
target_folder_id: Optional[int] = None

View File

@ -2,8 +2,9 @@ import os
import sys
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file='.env', extra='ignore')
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
DATABASE_URL: str = "sqlite:///app/mywebdav.db"
REDIS_URL: str = "redis://redis:6379/0"
@ -30,8 +31,14 @@ class Settings(BaseSettings):
STRIPE_WEBHOOK_SECRET: str = ""
BILLING_ENABLED: bool = False
settings = Settings()
if settings.SECRET_KEY == "super_secret_key" and os.getenv("ENVIRONMENT") == "production":
print("ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable.")
if (
settings.SECRET_KEY == "super_secret_key"
and os.getenv("ENVIRONMENT") == "production"
):
print(
"ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable."
)
sys.exit(1)

View File

@ -5,6 +5,7 @@ from typing import AsyncGenerator
from .settings import settings
class StorageManager:
def __init__(self, base_path: str = settings.STORAGE_PATH):
self.base_path = Path(base_path)
@ -12,7 +13,11 @@ class StorageManager:
async def _get_full_path(self, user_id: int, file_path: str) -> Path:
# Ensure file_path is relative and safe
relative_path = Path(file_path).relative_to('/') if str(file_path).startswith('/') else Path(file_path)
relative_path = (
Path(file_path).relative_to("/")
if str(file_path).startswith("/")
else Path(file_path)
)
full_path = self.base_path / str(user_id) / relative_path
full_path.parent.mkdir(parents=True, exist_ok=True)
return full_path
@ -52,4 +57,5 @@ class StorageManager:
full_path = await self._get_full_path(user_id, file_path)
return full_path.exists()
storage_manager = StorageManager()

View File

@ -1,4 +1,3 @@
import os
import asyncio
from pathlib import Path
from PIL import Image
@ -9,7 +8,10 @@ from .settings import settings
THUMBNAIL_SIZE = (300, 300)
THUMBNAIL_DIR = "thumbnails"
async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Optional[str]:
async def generate_thumbnail(
file_path: str, mime_type: str, user_id: int
) -> Optional[str]:
try:
if mime_type.startswith("image/"):
return await generate_image_thumbnail(file_path, user_id)
@ -20,6 +22,7 @@ async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Op
print(f"Error generating thumbnail for {file_path}: {e}")
return None
async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str]:
loop = asyncio.get_event_loop()
@ -30,12 +33,16 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
file_name = Path(file_path).name
thumbnail_name = f"thumb_{file_name}"
if not thumbnail_name.lower().endswith(('.jpg', '.jpeg', '.png')):
if not thumbnail_name.lower().endswith((".jpg", ".jpeg", ".png")):
thumbnail_name += ".jpg"
thumbnail_path = thumbnail_dir / thumbnail_name
actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path)
actual_file_path = (
base_path / str(user_id) / file_path
if not Path(file_path).is_absolute()
else Path(file_path)
)
with Image.open(actual_file_path) as img:
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
@ -44,7 +51,9 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
background = Image.new("RGB", img.size, (255, 255, 255))
if img.mode == "P":
img = img.convert("RGBA")
background.paste(img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None)
background.paste(
img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None
)
img = background
img.save(str(thumbnail_path), "JPEG", quality=85, optimize=True)
@ -53,6 +62,7 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
return await loop.run_in_executor(None, _generate)
async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str]:
loop = asyncio.get_event_loop()
@ -65,17 +75,29 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
thumbnail_name = f"thumb_{file_name}.jpg"
thumbnail_path = thumbnail_dir / thumbnail_name
actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path)
actual_file_path = (
base_path / str(user_id) / file_path
if not Path(file_path).is_absolute()
else Path(file_path)
)
subprocess.run([
subprocess.run(
[
"ffmpeg",
"-i", str(actual_file_path),
"-ss", "00:00:01",
"-vframes", "1",
"-vf", f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease",
"-i",
str(actual_file_path),
"-ss",
"00:00:01",
"-vframes",
"1",
"-vf",
f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease",
"-y",
str(thumbnail_path)
], check=True, capture_output=True)
str(thumbnail_path),
],
check=True,
capture_output=True,
)
return str(thumbnail_path.relative_to(base_path / str(user_id)))
@ -84,6 +106,7 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
except subprocess.CalledProcessError:
return None
async def delete_thumbnail(thumbnail_path: str, user_id: int):
try:
base_path = Path(settings.STORAGE_PATH)

View File

@ -5,15 +5,20 @@ import base64
import secrets
import hashlib
from typing import List, Optional
from typing import List
def generate_totp_secret() -> str:
"""Generates a random base32 TOTP secret."""
return pyotp.random_base32()
def generate_totp_uri(secret: str, account_name: str, issuer_name: str) -> str:
"""Generates a Google Authenticator-compatible TOTP URI."""
return pyotp.totp.TOTP(secret).provisioning_uri(name=account_name, issuer_name=issuer_name)
return pyotp.totp.TOTP(secret).provisioning_uri(
name=account_name, issuer_name=issuer_name
)
def generate_qr_code_base64(uri: str) -> str:
"""Generates a base64 encoded QR code image for a given URI."""
@ -31,27 +36,33 @@ def generate_qr_code_base64(uri: str) -> str:
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def verify_totp_code(secret: str, code: str) -> bool:
"""Verifies a TOTP code against a secret."""
totp = pyotp.TOTP(secret)
return totp.verify(code)
def generate_recovery_codes(num_codes: int = 10) -> List[str]:
"""Generates a list of random recovery codes."""
return [secrets.token_urlsafe(16) for _ in range(num_codes)]
def hash_recovery_code(code: str) -> str:
"""Hashes a single recovery code using SHA256."""
return hashlib.sha256(code.encode('utf-8')).hexdigest()
return hashlib.sha256(code.encode("utf-8")).hexdigest()
def verify_recovery_code(plain_code: str, hashed_code: str) -> bool:
"""Verifies a plain recovery code against its hashed version."""
return hash_recovery_code(plain_code) == hashed_code
def hash_recovery_codes(codes: List[str]) -> List[str]:
"""Hashes a list of recovery codes."""
return [hash_recovery_code(code) for code in codes]
def verify_recovery_codes(plain_code: str, hashed_codes: List[str]) -> bool:
"""Verifies if a plain recovery code matches any of the hashed recovery codes."""
for hashed_code in hashed_codes:

View File

@ -19,6 +19,7 @@ router = APIRouter(
tags=["webdav"],
)
class WebDAVLock:
locks = {}
@ -26,10 +27,10 @@ class WebDAVLock:
def create_lock(cls, path: str, user_id: int, timeout: int = 3600):
lock_token = f"opaquelocktoken:{hashlib.md5(f'{path}{user_id}{datetime.now()}'.encode()).hexdigest()}"
cls.locks[path] = {
'token': lock_token,
'user_id': user_id,
'created_at': datetime.now(),
'timeout': timeout
"token": lock_token,
"user_id": user_id,
"created_at": datetime.now(),
"timeout": timeout,
}
return lock_token
@ -42,17 +43,18 @@ class WebDAVLock:
if path in cls.locks:
del cls.locks[path]
async def basic_auth(authorization: Optional[str] = Header(None)):
if not authorization:
return None
try:
scheme, credentials = authorization.split()
if scheme.lower() != 'basic':
if scheme.lower() != "basic":
return None
decoded = base64.b64decode(credentials).decode('utf-8')
username, password = decoded.split(':', 1)
decoded = base64.b64decode(credentials).decode("utf-8")
username, password = decoded.split(":", 1)
user = await User.get_or_none(username=username)
if user and verify_password(password, user.hashed_password):
@ -62,6 +64,7 @@ async def basic_auth(authorization: Optional[str] = Header(None)):
return None
async def webdav_auth(request: Request, authorization: Optional[str] = Header(None)):
user = await basic_auth(authorization)
if user:
@ -75,14 +78,15 @@ async def webdav_auth(request: Request, authorization: Optional[str] = Header(No
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={'WWW-Authenticate': 'Basic realm="MyWebdav WebDAV"'}
headers={"WWW-Authenticate": 'Basic realm="MyWebdav WebDAV"'},
)
async def resolve_path(path_str: str, user: User):
if not path_str or path_str == '/':
if not path_str or path_str == "/":
return None, None, True
parts = [p for p in path_str.split('/') if p]
parts = [p for p in path_str.split("/") if p]
if not parts:
return None, None, True
@ -90,10 +94,7 @@ async def resolve_path(path_str: str, user: User):
current_folder = None
for i, part in enumerate(parts[:-1]):
folder = await Folder.get_or_none(
name=part,
parent=current_folder,
owner=user,
is_deleted=False
name=part, parent=current_folder, owner=user, is_deleted=False
)
if not folder:
return None, None, False
@ -102,39 +103,37 @@ async def resolve_path(path_str: str, user: User):
last_part = parts[-1]
folder = await Folder.get_or_none(
name=last_part,
parent=current_folder,
owner=user,
is_deleted=False
name=last_part, parent=current_folder, owner=user, is_deleted=False
)
if folder:
return folder, current_folder, True
file = await File.get_or_none(
name=last_part,
parent=current_folder,
owner=user,
is_deleted=False
name=last_part, parent=current_folder, owner=user, is_deleted=False
)
if file:
return file, current_folder, True
return None, current_folder, True
def build_href(base_path: str, name: str, is_collection: bool):
path = f"{base_path.rstrip('/')}/{name}"
if is_collection:
path += '/'
path += "/"
return path
async def get_custom_properties(resource_type: str, resource_id: int):
props = await WebDAVProperty.filter(
resource_type=resource_type,
resource_id=resource_id
resource_type=resource_type, resource_id=resource_id
)
return {(prop.namespace, prop.name): prop.value for prop in props}
def create_propstat_element(props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"):
def create_propstat_element(
props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"
):
propstat = ET.Element("D:propstat")
prop = ET.SubElement(propstat, "D:prop")
@ -151,10 +150,10 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str
elem.text = value
elif key == "getlastmodified":
elem = ET.SubElement(prop, f"D:{key}")
elem.text = value.strftime('%a, %d %b %Y %H:%M:%S GMT')
elem.text = value.strftime("%a, %d %b %Y %H:%M:%S GMT")
elif key == "creationdate":
elem = ET.SubElement(prop, f"D:{key}")
elem.text = value.isoformat() + 'Z'
elem.text = value.isoformat() + "Z"
elif key == "displayname":
elem = ET.SubElement(prop, f"D:{key}")
elem.text = value
@ -174,6 +173,7 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str
return propstat
def parse_propfind_body(body: bytes):
if not body:
return None
@ -193,8 +193,8 @@ def parse_propfind_body(body: bytes):
if prop is not None:
requested_props = []
for child in prop:
ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
name = child.tag.split('}')[1] if '}' in child.tag else child.tag
ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
name = child.tag.split("}")[1] if "}" in child.tag else child.tag
requested_props.append((ns, name))
return requested_props
except ET.ParseError:
@ -202,6 +202,7 @@ def parse_propfind_body(body: bytes):
return None
@router.api_route("/{full_path:path}", methods=["OPTIONS"])
async def webdav_options(full_path: str):
return Response(
@ -209,14 +210,17 @@ async def webdav_options(full_path: str):
headers={
"DAV": "1, 2",
"Allow": "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK",
"MS-Author-Via": "DAV"
}
"MS-Author-Via": "DAV",
},
)
@router.api_route("/{full_path:path}", methods=["PROPFIND"])
async def handle_propfind(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
async def handle_propfind(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
depth = request.headers.get("Depth", "1")
full_path = unquote(full_path).strip('/')
full_path = unquote(full_path).strip("/")
body = await request.body()
requested_props = parse_propfind_body(body)
@ -232,19 +236,23 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
if resource is None:
response = ET.SubElement(multistatus, "D:response")
href = ET.SubElement(response, "D:href")
href.text = base_href if base_href.endswith('/') else base_href + '/'
href.text = base_href if base_href.endswith("/") else base_href + "/"
props = {
"resourcetype": "collection",
"displayname": full_path.split('/')[-1] if full_path else "Root",
"displayname": full_path.split("/")[-1] if full_path else "Root",
"creationdate": datetime.now(),
"getlastmodified": datetime.now()
"getlastmodified": datetime.now(),
}
response.append(create_propstat_element(props))
if depth in ["1", "infinity"]:
folders = await Folder.filter(owner=current_user, parent=parent, is_deleted=False)
files = await File.filter(owner=current_user, parent=parent, is_deleted=False)
folders = await Folder.filter(
owner=current_user, parent=parent, is_deleted=False
)
files = await File.filter(
owner=current_user, parent=parent, is_deleted=False
)
for folder in folders:
response = ET.SubElement(multistatus, "D:response")
@ -255,9 +263,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"resourcetype": "collection",
"displayname": folder.name,
"creationdate": folder.created_at,
"getlastmodified": folder.updated_at
"getlastmodified": folder.updated_at,
}
custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
custom_props = (
await get_custom_properties("folder", folder.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props))
for file in files:
@ -272,28 +288,48 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"getcontenttype": file.mime_type,
"creationdate": file.created_at,
"getlastmodified": file.updated_at,
"getetag": f'"{file.file_hash}"'
"getetag": f'"{file.file_hash}"',
}
custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
custom_props = (
await get_custom_properties("file", file.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props))
elif isinstance(resource, Folder):
response = ET.SubElement(multistatus, "D:response")
href = ET.SubElement(response, "D:href")
href.text = base_href if base_href.endswith('/') else base_href + '/'
href.text = base_href if base_href.endswith("/") else base_href + "/"
props = {
"resourcetype": "collection",
"displayname": resource.name,
"creationdate": resource.created_at,
"getlastmodified": resource.updated_at
"getlastmodified": resource.updated_at,
}
custom_props = await get_custom_properties("folder", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
custom_props = (
await get_custom_properties("folder", resource.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props))
if depth in ["1", "infinity"]:
folders = await Folder.filter(owner=current_user, parent=resource, is_deleted=False)
files = await File.filter(owner=current_user, parent=resource, is_deleted=False)
folders = await Folder.filter(
owner=current_user, parent=resource, is_deleted=False
)
files = await File.filter(
owner=current_user, parent=resource, is_deleted=False
)
for folder in folders:
response = ET.SubElement(multistatus, "D:response")
@ -304,9 +340,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"resourcetype": "collection",
"displayname": folder.name,
"creationdate": folder.created_at,
"getlastmodified": folder.updated_at
"getlastmodified": folder.updated_at,
}
custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
custom_props = (
await get_custom_properties("folder", folder.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props))
for file in files:
@ -321,9 +365,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"getcontenttype": file.mime_type,
"creationdate": file.created_at,
"getlastmodified": file.updated_at,
"getetag": f'"{file.file_hash}"'
"getetag": f'"{file.file_hash}"',
}
custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
custom_props = (
await get_custom_properties("file", file.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props))
elif isinstance(resource, File):
@ -338,17 +390,32 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
"getcontenttype": resource.mime_type,
"creationdate": resource.created_at,
"getlastmodified": resource.updated_at,
"getetag": f'"{resource.file_hash}"'
"getetag": f'"{resource.file_hash}"',
}
custom_props = await get_custom_properties("file", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
custom_props = (
await get_custom_properties("file", resource.id)
if requested_props == "allprop"
or (
isinstance(requested_props, list)
and any(ns != "DAV:" for ns, _ in requested_props)
)
else None
)
response.append(create_propstat_element(props, custom_props))
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207)
return Response(
content=xml_content,
media_type="application/xml; charset=utf-8",
status_code=207,
)
@router.api_route("/{full_path:path}", methods=["GET", "HEAD"])
async def handle_get(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
full_path = unquote(full_path).strip('/')
async def handle_get(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
resource, parent, exists = await resolve_path(full_path, current_user)
@ -363,8 +430,10 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe
"Content-Length": str(resource.size),
"Content-Type": resource.mime_type,
"ETag": f'"{resource.file_hash}"',
"Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT')
}
"Last-Modified": resource.updated_at.strftime(
"%a, %d %b %Y %H:%M:%S GMT"
),
},
)
async def file_iterator():
@ -377,23 +446,28 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe
headers={
"Content-Disposition": f'attachment; filename="{resource.name}"',
"ETag": f'"{resource.file_hash}"',
"Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT')
}
"Last-Modified": resource.updated_at.strftime(
"%a, %d %b %Y %H:%M:%S GMT"
),
},
)
except FileNotFoundError:
raise HTTPException(status_code=404, detail="File not found in storage")
@router.api_route("/{full_path:path}", methods=["PUT"])
async def handle_put(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
full_path = unquote(full_path).strip('/')
async def handle_put(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
if not full_path:
raise HTTPException(status_code=400, detail="Cannot PUT to root")
parts = [p for p in full_path.split('/') if p]
parts = [p for p in full_path.split("/") if p]
file_name = parts[-1]
parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else ''
parent_path = "/".join(parts[:-1]) if len(parts) > 1 else ""
_, parent_folder, exists = await resolve_path(parent_path, current_user)
if not exists:
@ -417,10 +491,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
mime_type = "application/octet-stream"
existing_file = await File.get_or_none(
name=file_name,
parent=parent_folder,
owner=current_user,
is_deleted=False
name=file_name, parent=parent_folder, owner=current_user, is_deleted=False
)
if existing_file:
@ -432,7 +503,9 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
existing_file.updated_at = datetime.now()
await existing_file.save()
current_user.used_storage_bytes = current_user.used_storage_bytes - old_size + file_size
current_user.used_storage_bytes = (
current_user.used_storage_bytes - old_size + file_size
)
await current_user.save()
await log_activity(current_user, "file_updated", "file", existing_file.id)
@ -445,7 +518,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
mime_type=mime_type,
file_hash=file_hash,
owner=current_user,
parent=parent_folder
parent=parent_folder,
)
current_user.used_storage_bytes += file_size
@ -454,9 +527,12 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
await log_activity(current_user, "file_created", "file", db_file.id)
return Response(status_code=201)
@router.api_route("/{full_path:path}", methods=["DELETE"])
async def handle_delete(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
full_path = unquote(full_path).strip('/')
async def handle_delete(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
if not full_path:
raise HTTPException(status_code=400, detail="Cannot DELETE root")
@ -478,44 +554,45 @@ async def handle_delete(request: Request, full_path: str, current_user: User = D
return Response(status_code=204)
@router.api_route("/{full_path:path}", methods=["MKCOL"])
async def handle_mkcol(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
full_path = unquote(full_path).strip('/')
async def handle_mkcol(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
if not full_path:
raise HTTPException(status_code=400, detail="Cannot MKCOL at root")
parts = [p for p in full_path.split('/') if p]
parts = [p for p in full_path.split("/") if p]
folder_name = parts[-1]
parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else ''
parent_path = "/".join(parts[:-1]) if len(parts) > 1 else ""
_, parent_folder, exists = await resolve_path(parent_path, current_user)
if not exists:
raise HTTPException(status_code=409, detail="Parent folder does not exist")
existing = await Folder.get_or_none(
name=folder_name,
parent=parent_folder,
owner=current_user,
is_deleted=False
name=folder_name, parent=parent_folder, owner=current_user, is_deleted=False
)
if existing:
raise HTTPException(status_code=405, detail="Folder already exists")
folder = await Folder.create(
name=folder_name,
parent=parent_folder,
owner=current_user
name=folder_name, parent=parent_folder, owner=current_user
)
await log_activity(current_user, "folder_created", "folder", folder.id)
return Response(status_code=201)
@router.api_route("/{full_path:path}", methods=["COPY"])
async def handle_copy(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
full_path = unquote(full_path).strip('/')
async def handle_copy(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
destination = request.headers.get("Destination")
overwrite = request.headers.get("Overwrite", "T")
@ -523,7 +600,7 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
raise HTTPException(status_code=400, detail="Destination header required")
dest_path = unquote(urlparse(destination).path)
dest_path = dest_path.replace('/webdav/', '').strip('/')
dest_path = dest_path.replace("/webdav/", "").strip("/")
source_resource, _, exists = await resolve_path(full_path, current_user)
@ -533,9 +610,9 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
if not isinstance(source_resource, File):
raise HTTPException(status_code=501, detail="Only file copy is implemented")
dest_parts = [p for p in dest_path.split('/') if p]
dest_parts = [p for p in dest_path.split("/") if p]
dest_name = dest_parts[-1]
dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else ''
dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else ""
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
@ -543,14 +620,13 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
raise HTTPException(status_code=409, detail="Destination parent does not exist")
existing_dest = await File.get_or_none(
name=dest_name,
parent=dest_parent,
owner=current_user,
is_deleted=False
name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
)
if existing_dest and overwrite == "F":
raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
raise HTTPException(
status_code=412, detail="Destination exists and overwrite is false"
)
if existing_dest:
await existing_dest.delete()
@ -562,15 +638,18 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
mime_type=source_resource.mime_type,
file_hash=source_resource.file_hash,
owner=current_user,
parent=dest_parent
parent=dest_parent,
)
await log_activity(current_user, "file_copied", "file", new_file.id)
return Response(status_code=201 if not existing_dest else 204)
@router.api_route("/{full_path:path}", methods=["MOVE"])
async def handle_move(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
full_path = unquote(full_path).strip('/')
async def handle_move(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
destination = request.headers.get("Destination")
overwrite = request.headers.get("Overwrite", "T")
@ -578,16 +657,16 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
raise HTTPException(status_code=400, detail="Destination header required")
dest_path = unquote(urlparse(destination).path)
dest_path = dest_path.replace('/webdav/', '').strip('/')
dest_path = dest_path.replace("/webdav/", "").strip("/")
source_resource, _, exists = await resolve_path(full_path, current_user)
if not source_resource:
raise HTTPException(status_code=404, detail="Source not found")
dest_parts = [p for p in dest_path.split('/') if p]
dest_parts = [p for p in dest_path.split("/") if p]
dest_name = dest_parts[-1]
dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else ''
dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else ""
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
@ -596,14 +675,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
if isinstance(source_resource, File):
existing_dest = await File.get_or_none(
name=dest_name,
parent=dest_parent,
owner=current_user,
is_deleted=False
name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
)
if existing_dest and overwrite == "F":
raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
raise HTTPException(
status_code=412, detail="Destination exists and overwrite is false"
)
if existing_dest:
await existing_dest.delete()
@ -617,14 +695,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
elif isinstance(source_resource, Folder):
existing_dest = await Folder.get_or_none(
name=dest_name,
parent=dest_parent,
owner=current_user,
is_deleted=False
name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
)
if existing_dest and overwrite == "F":
raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
raise HTTPException(
status_code=412, detail="Destination exists and overwrite is false"
)
if existing_dest:
existing_dest.is_deleted = True
@ -637,9 +714,12 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
await log_activity(current_user, "folder_moved", "folder", source_resource.id)
return Response(status_code=201 if not existing_dest else 204)
@router.api_route("/{full_path:path}", methods=["LOCK"])
async def handle_lock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
full_path = unquote(full_path).strip('/')
async def handle_lock(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
timeout_header = request.headers.get("Timeout", "Second-3600")
timeout = 3600
@ -681,32 +761,38 @@ async def handle_lock(request: Request, full_path: str, current_user: User = Dep
content=xml_content,
media_type="application/xml; charset=utf-8",
status_code=200,
headers={"Lock-Token": f"<{lock_token}>"}
headers={"Lock-Token": f"<{lock_token}>"},
)
@router.api_route("/{full_path:path}", methods=["UNLOCK"])
async def handle_unlock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
full_path = unquote(full_path).strip('/')
async def handle_unlock(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
lock_token_header = request.headers.get("Lock-Token")
if not lock_token_header:
raise HTTPException(status_code=400, detail="Lock-Token header required")
lock_token = lock_token_header.strip('<>')
lock_token = lock_token_header.strip("<>")
existing_lock = WebDAVLock.get_lock(full_path)
if not existing_lock or existing_lock['token'] != lock_token:
if not existing_lock or existing_lock["token"] != lock_token:
raise HTTPException(status_code=409, detail="Invalid lock token")
if existing_lock['user_id'] != current_user.id:
if existing_lock["user_id"] != current_user.id:
raise HTTPException(status_code=403, detail="Not lock owner")
WebDAVLock.remove_lock(full_path)
return Response(status_code=204)
@router.api_route("/{full_path:path}", methods=["PROPPATCH"])
async def handle_proppatch(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
full_path = unquote(full_path).strip('/')
async def handle_proppatch(
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
):
full_path = unquote(full_path).strip("/")
resource, parent, exists = await resolve_path(full_path, current_user)
@ -719,7 +805,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
try:
root = ET.fromstring(body)
except:
except Exception:
raise HTTPException(status_code=400, detail="Invalid XML")
resource_type = "file" if isinstance(resource, File) else "folder"
@ -734,13 +820,19 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
prop_element = set_element.find(".//{DAV:}prop")
if prop_element is not None:
for child in prop_element:
ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
name = child.tag.split('}')[1] if '}' in child.tag else child.tag
ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
name = child.tag.split("}")[1] if "}" in child.tag else child.tag
value = child.text or ""
if ns == "DAV:":
live_props = ["creationdate", "getcontentlength", "getcontenttype",
"getetag", "getlastmodified", "resourcetype"]
live_props = [
"creationdate",
"getcontentlength",
"getcontenttype",
"getetag",
"getlastmodified",
"resourcetype",
]
if name in live_props:
failed_props.append((ns, name, "409 Conflict"))
continue
@ -750,7 +842,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
resource_type=resource_type,
resource_id=resource_id,
namespace=ns,
name=name
name=name,
)
if existing_prop:
@ -762,10 +854,10 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
resource_id=resource_id,
namespace=ns,
name=name,
value=value
value=value,
)
set_props.append((ns, name))
except Exception as e:
except Exception:
failed_props.append((ns, name, "500 Internal Server Error"))
remove_element = root.find(".//{DAV:}remove")
@ -773,8 +865,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
prop_element = remove_element.find(".//{DAV:}prop")
if prop_element is not None:
for child in prop_element:
ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
name = child.tag.split('}')[1] if '}' in child.tag else child.tag
ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
name = child.tag.split("}")[1] if "}" in child.tag else child.tag
if ns == "DAV:":
failed_props.append((ns, name, "409 Conflict"))
@ -785,7 +877,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
resource_type=resource_type,
resource_id=resource_id,
namespace=ns,
name=name
name=name,
)
if existing_prop:
@ -793,7 +885,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
remove_props.append((ns, name))
else:
failed_props.append((ns, name, "404 Not Found"))
except Exception as e:
except Exception:
failed_props.append((ns, name, "500 Internal Server Error"))
multistatus = ET.Element("D:multistatus", {"xmlns:D": "DAV:"})
@ -837,4 +929,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
await log_activity(current_user, "properties_modified", resource_type, resource_id)
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207)
return Response(
content=xml_content,
media_type="application/xml; charset=utf-8",
status_code=207,
)

View File

@ -6,9 +6,15 @@ from httpx import AsyncClient, ASGITransport
from fastapi import status
from mywebdav.main import app
from mywebdav.models import User
from mywebdav.billing.models import PricingConfig, Invoice, UsageAggregate, UserSubscription
from mywebdav.billing.models import (
PricingConfig,
Invoice,
UsageAggregate,
UserSubscription,
)
from mywebdav.auth import create_access_token
@pytest_asyncio.fixture
async def test_user():
user = await User.create(
@ -16,11 +22,12 @@ async def test_user():
email="test@example.com",
hashed_password="hashed_password_here",
is_active=True,
is_superuser=False
is_superuser=False,
)
yield user
await user.delete()
@pytest_asyncio.fixture
async def admin_user():
user = await User.create(
@ -28,58 +35,72 @@ async def admin_user():
email="admin@example.com",
hashed_password="hashed_password_here",
is_active=True,
is_superuser=True
is_superuser=True,
)
yield user
await user.delete()
@pytest_asyncio.fixture
async def auth_token(test_user):
token = create_access_token(data={"sub": test_user.username})
return token
@pytest_asyncio.fixture
async def admin_token(admin_user):
token = create_access_token(data={"sub": admin_user.username})
return token
@pytest_asyncio.fixture
async def pricing_config():
configs = []
configs.append(await PricingConfig.create(
configs.append(
await PricingConfig.create(
config_key="storage_per_gb_month",
config_value=Decimal("0.0045"),
description="Storage cost per GB per month",
unit="per_gb_month"
))
configs.append(await PricingConfig.create(
unit="per_gb_month",
)
)
configs.append(
await PricingConfig.create(
config_key="bandwidth_egress_per_gb",
config_value=Decimal("0.009"),
description="Bandwidth egress cost per GB",
unit="per_gb"
))
configs.append(await PricingConfig.create(
unit="per_gb",
)
)
configs.append(
await PricingConfig.create(
config_key="free_tier_storage_gb",
config_value=Decimal("15"),
description="Free tier storage in GB",
unit="gb"
))
configs.append(await PricingConfig.create(
unit="gb",
)
)
configs.append(
await PricingConfig.create(
config_key="free_tier_bandwidth_gb",
config_value=Decimal("15"),
description="Free tier bandwidth in GB per month",
unit="gb"
))
unit="gb",
)
)
yield configs
for config in configs:
await config.delete()
@pytest.mark.asyncio
async def test_get_current_usage(test_user, auth_token):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get(
"/api/billing/usage/current",
headers={"Authorization": f"Bearer {auth_token}"}
headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_200_OK
@ -87,6 +108,7 @@ async def test_get_current_usage(test_user, auth_token):
assert "storage_gb" in data
assert "bandwidth_down_gb_today" in data
@pytest.mark.asyncio
async def test_get_monthly_usage(test_user, auth_token):
today = date.today()
@ -97,13 +119,15 @@ async def test_get_monthly_usage(test_user, auth_token):
storage_bytes_avg=1024**3 * 10,
storage_bytes_peak=1024**3 * 12,
bandwidth_up_bytes=1024**3 * 2,
bandwidth_down_bytes=1024 ** 3 * 5
bandwidth_down_bytes=1024**3 * 5,
)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get(
f"/api/billing/usage/monthly?year={today.year}&month={today.month}",
headers={"Authorization": f"Bearer {auth_token}"}
headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_200_OK
@ -112,12 +136,15 @@ async def test_get_monthly_usage(test_user, auth_token):
await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio
async def test_get_subscription(test_user, auth_token):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get(
"/api/billing/subscription",
headers={"Authorization": f"Bearer {auth_token}"}
headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_200_OK
@ -127,6 +154,7 @@ async def test_get_subscription(test_user, auth_token):
await UserSubscription.filter(user=test_user).delete()
@pytest.mark.asyncio
async def test_list_invoices(test_user, auth_token):
invoice = await Invoice.create(
@ -137,13 +165,14 @@ async def test_list_invoices(test_user, auth_token):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
status="open"
status="open",
)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get(
"/api/billing/invoices",
headers={"Authorization": f"Bearer {auth_token}"}
"/api/billing/invoices", headers={"Authorization": f"Bearer {auth_token}"}
)
assert response.status_code == status.HTTP_200_OK
@ -153,6 +182,7 @@ async def test_list_invoices(test_user, auth_token):
await invoice.delete()
@pytest.mark.asyncio
async def test_get_invoice(test_user, auth_token):
invoice = await Invoice.create(
@ -163,13 +193,15 @@ async def test_get_invoice(test_user, auth_token):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
status="open"
status="open",
)
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get(
f"/api/billing/invoices/{invoice.id}",
headers={"Authorization": f"Bearer {auth_token}"}
headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_200_OK
@ -178,39 +210,45 @@ async def test_get_invoice(test_user, auth_token):
await invoice.delete()
@pytest.mark.asyncio
async def test_get_pricing():
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/api/billing/pricing")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, dict)
@pytest.mark.asyncio
async def test_admin_get_pricing(admin_user, admin_token, pricing_config):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get(
"/api/admin/billing/pricing",
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) > 0
@pytest.mark.asyncio
async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
config_id = pricing_config[0].id
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.put(
f"/api/admin/billing/pricing/{config_id}",
headers={"Authorization": f"Bearer {admin_token}"},
json={
"config_key": "storage_per_gb_month",
"config_value": 0.005
}
json={"config_key": "storage_per_gb_month", "config_value": 0.005},
)
assert response.status_code == status.HTTP_200_OK
@ -218,12 +256,15 @@ async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
updated = await PricingConfig.get(id=config_id)
assert updated.config_value == Decimal("0.005")
@pytest.mark.asyncio
async def test_admin_get_stats(admin_user, admin_token):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get(
"/api/admin/billing/stats",
headers={"Authorization": f"Bearer {admin_token}"}
headers={"Authorization": f"Bearer {admin_token}"},
)
assert response.status_code == status.HTTP_200_OK
@ -232,12 +273,15 @@ async def test_admin_get_stats(admin_user, admin_token):
assert "total_invoices" in data
assert "pending_invoices" in data
@pytest.mark.asyncio
async def test_non_admin_cannot_access_admin_endpoints(test_user, auth_token):
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get(
"/api/admin/billing/pricing",
headers={"Authorization": f"Bearer {auth_token}"}
headers={"Authorization": f"Bearer {auth_token}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN

View File

@ -3,57 +3,76 @@ import pytest_asyncio
from decimal import Decimal
from datetime import date, datetime
from mywebdav.models import User
from mywebdav.billing.models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription
from mywebdav.billing.models import (
Invoice,
InvoiceLineItem,
PricingConfig,
UsageAggregate,
UserSubscription,
)
from mywebdav.billing.invoice_generator import InvoiceGenerator
@pytest_asyncio.fixture
async def test_user():
user = await User.create(
username="testuser",
email="test@example.com",
hashed_password="hashed_password_here",
is_active=True
is_active=True,
)
yield user
await user.delete()
@pytest_asyncio.fixture
async def pricing_config():
configs = []
configs.append(await PricingConfig.create(
configs.append(
await PricingConfig.create(
config_key="storage_per_gb_month",
config_value=Decimal("0.0045"),
description="Storage cost per GB per month",
unit="per_gb_month"
))
configs.append(await PricingConfig.create(
unit="per_gb_month",
)
)
configs.append(
await PricingConfig.create(
config_key="bandwidth_egress_per_gb",
config_value=Decimal("0.009"),
description="Bandwidth egress cost per GB",
unit="per_gb"
))
configs.append(await PricingConfig.create(
unit="per_gb",
)
)
configs.append(
await PricingConfig.create(
config_key="free_tier_storage_gb",
config_value=Decimal("15"),
description="Free tier storage in GB",
unit="gb"
))
configs.append(await PricingConfig.create(
unit="gb",
)
)
configs.append(
await PricingConfig.create(
config_key="free_tier_bandwidth_gb",
config_value=Decimal("15"),
description="Free tier bandwidth in GB per month",
unit="gb"
))
configs.append(await PricingConfig.create(
unit="gb",
)
)
configs.append(
await PricingConfig.create(
config_key="tax_rate_default",
config_value=Decimal("0.0"),
description="Default tax rate",
unit="percentage"
))
unit="percentage",
)
)
yield configs
for config in configs:
await config.delete()
@pytest.mark.asyncio
async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
today = date.today()
@ -64,10 +83,12 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
storage_bytes_avg=1024**3 * 50,
storage_bytes_peak=1024**3 * 55,
bandwidth_up_bytes=1024**3 * 10,
bandwidth_down_bytes=1024 ** 3 * 20
bandwidth_down_bytes=1024**3 * 20,
)
invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
invoice = await InvoiceGenerator.generate_monthly_invoice(
test_user, today.year, today.month
)
assert invoice is not None
assert invoice.user_id == test_user.id
@ -80,6 +101,7 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
await invoice.delete()
await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio
async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_config):
today = date.today()
@ -90,15 +112,18 @@ async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_confi
storage_bytes_avg=1024**3 * 10,
storage_bytes_peak=1024**3 * 12,
bandwidth_up_bytes=1024**3 * 5,
bandwidth_down_bytes=1024 ** 3 * 10
bandwidth_down_bytes=1024**3 * 10,
)
invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
invoice = await InvoiceGenerator.generate_monthly_invoice(
test_user, today.year, today.month
)
assert invoice is None
await UsageAggregate.filter(user=test_user).delete()
@pytest.mark.asyncio
async def test_finalize_invoice(test_user, pricing_config):
invoice = await Invoice.create(
@ -109,7 +134,7 @@ async def test_finalize_invoice(test_user, pricing_config):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
status="draft"
status="draft",
)
finalized = await InvoiceGenerator.finalize_invoice(invoice)
@ -118,6 +143,7 @@ async def test_finalize_invoice(test_user, pricing_config):
await finalized.delete()
@pytest.mark.asyncio
async def test_finalize_invoice_already_finalized(test_user, pricing_config):
invoice = await Invoice.create(
@ -128,7 +154,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
status="open"
status="open",
)
with pytest.raises(ValueError):
@ -136,6 +162,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
await invoice.delete()
@pytest.mark.asyncio
async def test_mark_invoice_paid(test_user, pricing_config):
invoice = await Invoice.create(
@ -146,7 +173,7 @@ async def test_mark_invoice_paid(test_user, pricing_config):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
status="open"
status="open",
)
paid = await InvoiceGenerator.mark_invoice_paid(invoice)
@ -156,10 +183,13 @@ async def test_mark_invoice_paid(test_user, pricing_config):
await paid.delete()
@pytest.mark.asyncio
async def test_invoice_with_tax(test_user, pricing_config):
# Update tax rate
updated = await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.21"))
updated = await PricingConfig.filter(config_key="tax_rate_default").update(
config_value=Decimal("0.21")
)
assert updated == 1 # Should update 1 row
# Verify the update worked
@ -174,10 +204,12 @@ async def test_invoice_with_tax(test_user, pricing_config):
storage_bytes_avg=1024**3 * 50,
storage_bytes_peak=1024**3 * 55,
bandwidth_up_bytes=1024**3 * 10,
bandwidth_down_bytes=1024 ** 3 * 20
bandwidth_down_bytes=1024**3 * 20,
)
invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
invoice = await InvoiceGenerator.generate_monthly_invoice(
test_user, today.year, today.month
)
assert invoice is not None
assert invoice.tax > 0
@ -185,4 +217,6 @@ async def test_invoice_with_tax(test_user, pricing_config):
await invoice.delete()
await UsageAggregate.filter(user=test_user).delete()
await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.0"))
await PricingConfig.filter(config_key="tax_rate_default").update(
config_value=Decimal("0.0")
)

View File

@ -4,21 +4,30 @@ from decimal import Decimal
from datetime import date, datetime
from mywebdav.models import User
from mywebdav.billing.models import (
SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate,
Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent
SubscriptionPlan,
UserSubscription,
UsageRecord,
UsageAggregate,
Invoice,
InvoiceLineItem,
PricingConfig,
PaymentMethod,
BillingEvent,
)
@pytest_asyncio.fixture
async def test_user():
user = await User.create(
username="testuser",
email="test@example.com",
hashed_password="hashed_password_here",
is_active=True
is_active=True,
)
yield user
await user.delete()
@pytest.mark.asyncio
async def test_subscription_plan_creation():
plan = await SubscriptionPlan.create(
@ -28,7 +37,7 @@ async def test_subscription_plan_creation():
storage_gb=100,
bandwidth_gb=100,
price_monthly=Decimal("5.00"),
price_yearly=Decimal("50.00")
price_yearly=Decimal("50.00"),
)
assert plan.name == "starter"
@ -36,12 +45,11 @@ async def test_subscription_plan_creation():
assert plan.price_monthly == Decimal("5.00")
await plan.delete()
@pytest.mark.asyncio
async def test_user_subscription_creation(test_user):
subscription = await UserSubscription.create(
user=test_user,
billing_type="pay_as_you_go",
status="active"
user=test_user, billing_type="pay_as_you_go", status="active"
)
assert subscription.user_id == test_user.id
@ -49,6 +57,7 @@ async def test_user_subscription_creation(test_user):
assert subscription.status == "active"
await subscription.delete()
@pytest.mark.asyncio
async def test_usage_record_creation(test_user):
usage = await UsageRecord.create(
@ -57,7 +66,7 @@ async def test_usage_record_creation(test_user):
amount_bytes=1024 * 1024 * 100,
resource_type="file",
resource_id=1,
idempotency_key="test_key_123"
idempotency_key="test_key_123",
)
assert usage.user_id == test_user.id
@ -65,6 +74,7 @@ async def test_usage_record_creation(test_user):
assert usage.amount_bytes == 1024 * 1024 * 100
await usage.delete()
@pytest.mark.asyncio
async def test_usage_aggregate_creation(test_user):
aggregate = await UsageAggregate.create(
@ -73,13 +83,14 @@ async def test_usage_aggregate_creation(test_user):
storage_bytes_avg=1024 * 1024 * 500,
storage_bytes_peak=1024 * 1024 * 600,
bandwidth_up_bytes=1024 * 1024 * 50,
bandwidth_down_bytes=1024 * 1024 * 100
bandwidth_down_bytes=1024 * 1024 * 100,
)
assert aggregate.user_id == test_user.id
assert aggregate.storage_bytes_avg == 1024 * 1024 * 500
await aggregate.delete()
@pytest.mark.asyncio
async def test_invoice_creation(test_user):
invoice = await Invoice.create(
@ -90,7 +101,7 @@ async def test_invoice_creation(test_user):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
status="draft"
status="draft",
)
assert invoice.user_id == test_user.id
@ -98,6 +109,7 @@ async def test_invoice_creation(test_user):
assert invoice.total == Decimal("10.00")
await invoice.delete()
@pytest.mark.asyncio
async def test_invoice_line_item_creation(test_user):
invoice = await Invoice.create(
@ -108,7 +120,7 @@ async def test_invoice_line_item_creation(test_user):
subtotal=Decimal("10.00"),
tax=Decimal("0.00"),
total=Decimal("10.00"),
status="draft"
status="draft",
)
line_item = await InvoiceLineItem.create(
@ -117,7 +129,7 @@ async def test_invoice_line_item_creation(test_user):
quantity=Decimal("100.000000"),
unit_price=Decimal("0.100000"),
amount=Decimal("10.0000"),
item_type="storage"
item_type="storage",
)
assert line_item.invoice_id == invoice.id
@ -126,6 +138,7 @@ async def test_invoice_line_item_creation(test_user):
await line_item.delete()
await invoice.delete()
@pytest.mark.asyncio
async def test_pricing_config_creation(test_user):
config = await PricingConfig.create(
@ -133,13 +146,14 @@ async def test_pricing_config_creation(test_user):
config_value=Decimal("0.0045"),
description="Storage cost per GB per month",
unit="per_gb_month",
updated_by=test_user
updated_by=test_user,
)
assert config.config_key == "storage_per_gb_month"
assert config.config_value == Decimal("0.0045")
await config.delete()
@pytest.mark.asyncio
async def test_payment_method_creation(test_user):
payment_method = await PaymentMethod.create(
@ -150,7 +164,7 @@ async def test_payment_method_creation(test_user):
last4="4242",
brand="visa",
exp_month=12,
exp_year=2025
exp_year=2025,
)
assert payment_method.user_id == test_user.id
@ -158,6 +172,7 @@ async def test_payment_method_creation(test_user):
assert payment_method.is_default is True
await payment_method.delete()
@pytest.mark.asyncio
async def test_billing_event_creation(test_user):
event = await BillingEvent.create(
@ -165,7 +180,7 @@ async def test_billing_event_creation(test_user):
event_type="invoice_created",
stripe_event_id="evt_test_123",
data={"invoice_id": 1},
processed=False
processed=False,
)
assert event.user_id == test_user.id

View File

@ -1,18 +1,35 @@
import pytest
def test_billing_module_imports():
from mywebdav.billing import models, stripe_client, usage_tracker, invoice_generator, scheduler
from mywebdav.billing import (
models,
stripe_client,
usage_tracker,
invoice_generator,
scheduler,
)
assert models is not None
assert stripe_client is not None
assert usage_tracker is not None
assert invoice_generator is not None
assert scheduler is not None
def test_billing_models_exist():
from mywebdav.billing.models import (
SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate,
Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent
SubscriptionPlan,
UserSubscription,
UsageRecord,
UsageAggregate,
Invoice,
InvoiceLineItem,
PricingConfig,
PaymentMethod,
BillingEvent,
)
assert SubscriptionPlan is not None
assert UserSubscription is not None
assert UsageRecord is not None
@ -23,55 +40,71 @@ def test_billing_models_exist():
assert PaymentMethod is not None
assert BillingEvent is not None
def test_stripe_client_exists():
from mywebdav.billing.stripe_client import StripeClient
assert StripeClient is not None
assert hasattr(StripeClient, 'create_customer')
assert hasattr(StripeClient, 'create_invoice')
assert hasattr(StripeClient, 'finalize_invoice')
assert hasattr(StripeClient, "create_customer")
assert hasattr(StripeClient, "create_invoice")
assert hasattr(StripeClient, "finalize_invoice")
def test_usage_tracker_exists():
from mywebdav.billing.usage_tracker import UsageTracker
assert UsageTracker is not None
assert hasattr(UsageTracker, 'track_storage')
assert hasattr(UsageTracker, 'track_bandwidth')
assert hasattr(UsageTracker, 'aggregate_daily_usage')
assert hasattr(UsageTracker, 'get_current_storage')
assert hasattr(UsageTracker, 'get_monthly_usage')
assert hasattr(UsageTracker, "track_storage")
assert hasattr(UsageTracker, "track_bandwidth")
assert hasattr(UsageTracker, "aggregate_daily_usage")
assert hasattr(UsageTracker, "get_current_storage")
assert hasattr(UsageTracker, "get_monthly_usage")
def test_invoice_generator_exists():
from mywebdav.billing.invoice_generator import InvoiceGenerator
assert InvoiceGenerator is not None
assert hasattr(InvoiceGenerator, 'generate_monthly_invoice')
assert hasattr(InvoiceGenerator, 'finalize_invoice')
assert hasattr(InvoiceGenerator, 'mark_invoice_paid')
assert hasattr(InvoiceGenerator, "generate_monthly_invoice")
assert hasattr(InvoiceGenerator, "finalize_invoice")
assert hasattr(InvoiceGenerator, "mark_invoice_paid")
def test_scheduler_exists():
from mywebdav.billing.scheduler import scheduler, start_scheduler, stop_scheduler
assert scheduler is not None
assert callable(start_scheduler)
assert callable(stop_scheduler)
def test_routers_exist():
from mywebdav.routers import billing, admin_billing
assert billing is not None
assert admin_billing is not None
assert hasattr(billing, 'router')
assert hasattr(admin_billing, 'router')
assert hasattr(billing, "router")
assert hasattr(admin_billing, "router")
def test_middleware_exists():
from mywebdav.middleware.usage_tracking import UsageTrackingMiddleware
assert UsageTrackingMiddleware is not None
def test_settings_updated():
from mywebdav.settings import settings
assert hasattr(settings, 'STRIPE_SECRET_KEY')
assert hasattr(settings, 'STRIPE_PUBLISHABLE_KEY')
assert hasattr(settings, 'STRIPE_WEBHOOK_SECRET')
assert hasattr(settings, 'BILLING_ENABLED')
assert hasattr(settings, "STRIPE_SECRET_KEY")
assert hasattr(settings, "STRIPE_PUBLISHABLE_KEY")
assert hasattr(settings, "STRIPE_WEBHOOK_SECRET")
assert hasattr(settings, "BILLING_ENABLED")
def test_main_includes_billing():
from mywebdav.main import app
routes = [route.path for route in app.routes]
billing_routes = [r for r in routes if '/billing' in r]
billing_routes = [r for r in routes if "/billing" in r]
assert len(billing_routes) > 0

View File

@ -6,24 +6,26 @@ from mywebdav.models import User, File, Folder
from mywebdav.billing.models import UsageRecord, UsageAggregate
from mywebdav.billing.usage_tracker import UsageTracker
@pytest_asyncio.fixture
async def test_user():
user = await User.create(
username="testuser",
email="test@example.com",
hashed_password="hashed_password_here",
is_active=True
is_active=True,
)
yield user
await user.delete()
@pytest.mark.asyncio
async def test_track_storage(test_user):
await UsageTracker.track_storage(
user=test_user,
amount_bytes=1024 * 1024 * 100,
resource_type="file",
resource_id=1
resource_id=1,
)
records = await UsageRecord.filter(user=test_user, record_type="storage").all()
@ -32,6 +34,7 @@ async def test_track_storage(test_user):
await records[0].delete()
@pytest.mark.asyncio
async def test_track_bandwidth_upload(test_user):
await UsageTracker.track_bandwidth(
@ -39,7 +42,7 @@ async def test_track_bandwidth_upload(test_user):
amount_bytes=1024 * 1024 * 50,
direction="up",
resource_type="file",
resource_id=1
resource_id=1,
)
records = await UsageRecord.filter(user=test_user, record_type="bandwidth_up").all()
@ -48,6 +51,7 @@ async def test_track_bandwidth_upload(test_user):
await records[0].delete()
@pytest.mark.asyncio
async def test_track_bandwidth_download(test_user):
await UsageTracker.track_bandwidth(
@ -55,32 +59,28 @@ async def test_track_bandwidth_download(test_user):
amount_bytes=1024 * 1024 * 75,
direction="down",
resource_type="file",
resource_id=1
resource_id=1,
)
records = await UsageRecord.filter(user=test_user, record_type="bandwidth_down").all()
records = await UsageRecord.filter(
user=test_user, record_type="bandwidth_down"
).all()
assert len(records) == 1
assert records[0].amount_bytes == 1024 * 1024 * 75
await records[0].delete()
@pytest.mark.asyncio
async def test_aggregate_daily_usage(test_user):
await UsageTracker.track_storage(
user=test_user,
amount_bytes=1024 * 1024 * 100
await UsageTracker.track_storage(user=test_user, amount_bytes=1024 * 1024 * 100)
await UsageTracker.track_bandwidth(
user=test_user, amount_bytes=1024 * 1024 * 50, direction="up"
)
await UsageTracker.track_bandwidth(
user=test_user,
amount_bytes=1024 * 1024 * 50,
direction="up"
)
await UsageTracker.track_bandwidth(
user=test_user,
amount_bytes=1024 * 1024 * 75,
direction="down"
user=test_user, amount_bytes=1024 * 1024 * 75, direction="down"
)
aggregate = await UsageTracker.aggregate_daily_usage(test_user, date.today())
@ -93,12 +93,10 @@ async def test_aggregate_daily_usage(test_user):
await aggregate.delete()
await UsageRecord.filter(user=test_user).delete()
@pytest.mark.asyncio
async def test_get_current_storage(test_user):
folder = await Folder.create(
name="Test Folder",
owner=test_user
)
folder = await Folder.create(name="Test Folder", owner=test_user)
file1 = await File.create(
name="test1.txt",
@ -107,7 +105,7 @@ async def test_get_current_storage(test_user):
mime_type="text/plain",
owner=test_user,
parent=folder,
is_deleted=False
is_deleted=False,
)
file2 = await File.create(
@ -117,7 +115,7 @@ async def test_get_current_storage(test_user):
mime_type="text/plain",
owner=test_user,
parent=folder,
is_deleted=False
is_deleted=False,
)
storage = await UsageTracker.get_current_storage(test_user)
@ -128,6 +126,7 @@ async def test_get_current_storage(test_user):
await file2.delete()
await folder.delete()
@pytest.mark.asyncio
async def test_get_monthly_usage(test_user):
today = date.today()
@ -138,7 +137,7 @@ async def test_get_monthly_usage(test_user):
storage_bytes_avg=1024 * 1024 * 1024 * 10,
storage_bytes_peak=1024 * 1024 * 1024 * 12,
bandwidth_up_bytes=1024 * 1024 * 1024 * 2,
bandwidth_down_bytes=1024 * 1024 * 1024 * 5
bandwidth_down_bytes=1024 * 1024 * 1024 * 5,
)
usage = await UsageTracker.get_monthly_usage(test_user, today.year, today.month)
@ -150,11 +149,14 @@ async def test_get_monthly_usage(test_user):
await aggregate.delete()
@pytest.mark.asyncio
async def test_get_monthly_usage_empty(test_user):
future_date = date.today() + timedelta(days=365)
usage = await UsageTracker.get_monthly_usage(test_user, future_date.year, future_date.month)
usage = await UsageTracker.get_monthly_usage(
test_user, future_date.year, future_date.month
)
assert usage["storage_gb_avg"] == 0
assert usage["storage_gb_peak"] == 0

View File

@ -3,27 +3,21 @@ import pytest_asyncio
import asyncio
from tortoise import Tortoise
# Initialize Tortoise at module level
async def _init_tortoise():
@pytest_asyncio.fixture(scope="session")
async def init_db():
"""Initialize the database for testing."""
await Tortoise.init(
db_url="sqlite://:memory:",
modules={
"models": ["mywebdav.models"],
"billing": ["mywebdav.billing.models"]
}
modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
)
await Tortoise.generate_schemas()
# Run initialization
asyncio.run(_init_tortoise())
@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def cleanup_db():
yield
await Tortoise.close_connections()
@pytest_asyncio.fixture(autouse=True)
async def setup_db(init_db):
"""Ensure database is initialized before each test."""
# The init_db fixture handles the setup
yield