Update.
This commit is contained in:
parent
aff8dfca08
commit
d350ab6807
@ -1,17 +1,18 @@
|
||||
from typing import Optional
|
||||
from .models import Activity, User
|
||||
|
||||
|
||||
async def log_activity(
|
||||
user: Optional[User],
|
||||
action: str,
|
||||
target_type: str,
|
||||
target_id: int,
|
||||
ip_address: Optional[str] = None
|
||||
ip_address: Optional[str] = None,
|
||||
):
|
||||
await Activity.create(
|
||||
user=user,
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=target_id,
|
||||
ip_address=ip_address
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
@ -9,20 +9,29 @@ import bcrypt
|
||||
from .schemas import TokenData
|
||||
from .settings import settings
|
||||
from .models import User
|
||||
from .two_factor import verify_totp_code # Import verify_totp_code
|
||||
from .two_factor import verify_totp_code # Import verify_totp_code
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
password_bytes = plain_password[:72].encode('utf-8')
|
||||
hashed_bytes = hashed_password.encode('utf-8') if isinstance(hashed_password, str) else hashed_password
|
||||
password_bytes = plain_password[:72].encode("utf-8")
|
||||
hashed_bytes = (
|
||||
hashed_password.encode("utf-8")
|
||||
if isinstance(hashed_password, str)
|
||||
else hashed_password
|
||||
)
|
||||
return bcrypt.checkpw(password_bytes, hashed_bytes)
|
||||
|
||||
def get_password_hash(password):
|
||||
password_bytes = password[:72].encode('utf-8')
|
||||
return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode('utf-8')
|
||||
|
||||
async def authenticate_user(username: str, password: str, two_factor_code: Optional[str] = None):
|
||||
def get_password_hash(password):
|
||||
password_bytes = password[:72].encode("utf-8")
|
||||
return bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
async def authenticate_user(
|
||||
username: str, password: str, two_factor_code: Optional[str] = None
|
||||
):
|
||||
user = await User.get_or_none(username=username)
|
||||
if not user:
|
||||
return None
|
||||
@ -33,19 +42,29 @@ async def authenticate_user(username: str, password: str, two_factor_code: Optio
|
||||
if not two_factor_code:
|
||||
return {"user": user, "2fa_required": True}
|
||||
if not verify_totp_code(user.two_factor_secret, two_factor_code):
|
||||
return None # 2FA code is incorrect
|
||||
return None # 2FA code is incorrect
|
||||
return {"user": user, "2fa_required": False}
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None, two_factor_verified: bool = False):
|
||||
|
||||
def create_access_token(
|
||||
data: dict,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
two_factor_verified: bool = False,
|
||||
):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
to_encode.update({"exp": expire, "2fa_verified": two_factor_verified})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@ -53,31 +72,45 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
username: str = payload.get("sub")
|
||||
two_factor_verified: bool = payload.get("2fa_verified", False)
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username, two_factor_verified=two_factor_verified)
|
||||
token_data = TokenData(
|
||||
username=username, two_factor_verified=two_factor_verified
|
||||
)
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
user = await User.get_or_none(username=token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
user.token_data = token_data # Attach token_data to user for easy access
|
||||
user.token_data = token_data # Attach token_data to user for easy access
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(current_user: User = Depends(get_current_user)):
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_current_verified_user(current_user: User = Depends(get_current_user)):
|
||||
if current_user.is_2fa_enabled and not current_user.token_data.two_factor_verified:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="2FA required and not verified")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="2FA required and not verified",
|
||||
)
|
||||
return current_user
|
||||
|
||||
async def get_current_admin_user(current_user: User = Depends(get_current_verified_user)):
|
||||
|
||||
async def get_current_admin_user(
|
||||
current_user: User = Depends(get_current_verified_user),
|
||||
):
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
|
||||
)
|
||||
return current_user
|
||||
|
||||
@ -2,14 +2,17 @@ from datetime import datetime, date, timedelta, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from calendar import monthrange
|
||||
from .models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription
|
||||
from .models import Invoice, InvoiceLineItem, PricingConfig, UserSubscription
|
||||
from .usage_tracker import UsageTracker
|
||||
from .stripe_client import StripeClient
|
||||
from ..models import User
|
||||
|
||||
|
||||
class InvoiceGenerator:
|
||||
@staticmethod
|
||||
async def generate_monthly_invoice(user: User, year: int, month: int) -> Optional[Invoice]:
|
||||
async def generate_monthly_invoice(
|
||||
user: User, year: int, month: int
|
||||
) -> Optional[Invoice]:
|
||||
period_start = date(year, month, 1)
|
||||
days_in_month = monthrange(year, month)[1]
|
||||
period_end = date(year, month, days_in_month)
|
||||
@ -19,19 +22,24 @@ class InvoiceGenerator:
|
||||
pricing = await PricingConfig.all()
|
||||
pricing_dict = {p.config_key: p.config_value for p in pricing}
|
||||
|
||||
storage_price_per_gb = pricing_dict.get('storage_per_gb_month', Decimal('0.0045'))
|
||||
bandwidth_price_per_gb = pricing_dict.get('bandwidth_egress_per_gb', Decimal('0.009'))
|
||||
free_storage_gb = pricing_dict.get('free_tier_storage_gb', Decimal('15'))
|
||||
free_bandwidth_gb = pricing_dict.get('free_tier_bandwidth_gb', Decimal('15'))
|
||||
tax_rate = pricing_dict.get('tax_rate_default', Decimal('0'))
|
||||
storage_price_per_gb = pricing_dict.get(
|
||||
"storage_per_gb_month", Decimal("0.0045")
|
||||
)
|
||||
bandwidth_price_per_gb = pricing_dict.get(
|
||||
"bandwidth_egress_per_gb", Decimal("0.009")
|
||||
)
|
||||
free_storage_gb = pricing_dict.get("free_tier_storage_gb", Decimal("15"))
|
||||
free_bandwidth_gb = pricing_dict.get("free_tier_bandwidth_gb", Decimal("15"))
|
||||
tax_rate = pricing_dict.get("tax_rate_default", Decimal("0"))
|
||||
|
||||
storage_gb = Decimal(str(usage['storage_gb_avg']))
|
||||
bandwidth_gb = Decimal(str(usage['bandwidth_down_gb']))
|
||||
storage_gb = Decimal(str(usage["storage_gb_avg"]))
|
||||
bandwidth_gb = Decimal(str(usage["bandwidth_down_gb"]))
|
||||
|
||||
billable_storage = max(Decimal('0'), storage_gb - free_storage_gb)
|
||||
billable_bandwidth = max(Decimal('0'), bandwidth_gb - free_bandwidth_gb)
|
||||
billable_storage = max(Decimal("0"), storage_gb - free_storage_gb)
|
||||
billable_bandwidth = max(Decimal("0"), bandwidth_gb - free_bandwidth_gb)
|
||||
|
||||
import math
|
||||
|
||||
billable_storage_rounded = Decimal(math.ceil(float(billable_storage)))
|
||||
billable_bandwidth_rounded = Decimal(math.ceil(float(billable_bandwidth)))
|
||||
|
||||
@ -65,9 +73,9 @@ class InvoiceGenerator:
|
||||
"usage": usage,
|
||||
"pricing": {
|
||||
"storage_per_gb": float(storage_price_per_gb),
|
||||
"bandwidth_per_gb": float(bandwidth_price_per_gb)
|
||||
}
|
||||
}
|
||||
"bandwidth_per_gb": float(bandwidth_price_per_gb),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if billable_storage_rounded > 0:
|
||||
@ -78,7 +86,10 @@ class InvoiceGenerator:
|
||||
unit_price=storage_price_per_gb,
|
||||
amount=storage_cost,
|
||||
item_type="storage",
|
||||
metadata={"avg_gb": float(storage_gb), "free_gb": float(free_storage_gb)}
|
||||
metadata={
|
||||
"avg_gb": float(storage_gb),
|
||||
"free_gb": float(free_storage_gb),
|
||||
},
|
||||
)
|
||||
|
||||
if billable_bandwidth_rounded > 0:
|
||||
@ -89,7 +100,10 @@ class InvoiceGenerator:
|
||||
unit_price=bandwidth_price_per_gb,
|
||||
amount=bandwidth_cost,
|
||||
item_type="bandwidth",
|
||||
metadata={"total_gb": float(bandwidth_gb), "free_gb": float(free_bandwidth_gb)}
|
||||
metadata={
|
||||
"total_gb": float(bandwidth_gb),
|
||||
"free_gb": float(free_bandwidth_gb),
|
||||
},
|
||||
)
|
||||
|
||||
if subscription and subscription.stripe_customer_id:
|
||||
@ -100,7 +114,7 @@ class InvoiceGenerator:
|
||||
"amount": item.amount,
|
||||
"currency": "usd",
|
||||
"description": item.description,
|
||||
"metadata": item.metadata or {}
|
||||
"metadata": item.metadata or {},
|
||||
}
|
||||
for item in line_items
|
||||
]
|
||||
@ -109,7 +123,7 @@ class InvoiceGenerator:
|
||||
customer_id=subscription.stripe_customer_id,
|
||||
description=f"MyWebdav Usage Invoice for {period_start.strftime('%B %Y')}",
|
||||
line_items=stripe_line_items,
|
||||
metadata={"mywebdav_invoice_id": str(invoice.id)}
|
||||
metadata={"mywebdav_invoice_id": str(invoice.id)},
|
||||
)
|
||||
|
||||
invoice.stripe_invoice_id = stripe_invoice.id
|
||||
@ -135,8 +149,11 @@ class InvoiceGenerator:
|
||||
|
||||
# Send invoice email
|
||||
from ..mail import queue_email
|
||||
|
||||
line_items = await invoice.line_items.all()
|
||||
items_text = "\n".join([f"- {item.description}: ${item.amount}" for item in line_items])
|
||||
items_text = "\n".join(
|
||||
[f"- {item.description}: ${item.amount}" for item in line_items]
|
||||
)
|
||||
body = f"""Dear {invoice.user.username},
|
||||
|
||||
Your invoice {invoice.invoice_number} for the period {invoice.period_start} to {invoice.period_end} is now available.
|
||||
@ -174,7 +191,7 @@ The MyWebdav Team
|
||||
to_email=invoice.user.email,
|
||||
subject=f"Your MyWebdav Invoice {invoice.invoice_number}",
|
||||
body=body,
|
||||
html=html
|
||||
html=html,
|
||||
)
|
||||
|
||||
return invoice
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from tortoise import fields, models
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
class SubscriptionPlan(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
name = fields.CharField(max_length=100, unique=True)
|
||||
display_name = fields.CharField(max_length=255)
|
||||
description = fields.TextField(null=True)
|
||||
@ -18,10 +18,13 @@ class SubscriptionPlan(models.Model):
|
||||
class Meta:
|
||||
table = "subscription_plans"
|
||||
|
||||
|
||||
class UserSubscription(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
user = fields.ForeignKeyField("models.User", related_name="subscription")
|
||||
plan = fields.ForeignKeyField("billing.SubscriptionPlan", related_name="subscriptions", null=True)
|
||||
plan = fields.ForeignKeyField(
|
||||
"billing.SubscriptionPlan", related_name="subscriptions", null=True
|
||||
)
|
||||
billing_type = fields.CharField(max_length=20, default="pay_as_you_go")
|
||||
stripe_customer_id = fields.CharField(max_length=255, unique=True, null=True)
|
||||
stripe_subscription_id = fields.CharField(max_length=255, unique=True, null=True)
|
||||
@ -35,14 +38,15 @@ class UserSubscription(models.Model):
|
||||
class Meta:
|
||||
table = "user_subscriptions"
|
||||
|
||||
|
||||
class UsageRecord(models.Model):
|
||||
id = fields.BigIntField(pk=True)
|
||||
id = fields.BigIntField(primary_key=True)
|
||||
user = fields.ForeignKeyField("models.User", related_name="usage_records")
|
||||
record_type = fields.CharField(max_length=50, index=True)
|
||||
record_type = fields.CharField(max_length=50, db_index=True)
|
||||
amount_bytes = fields.BigIntField()
|
||||
resource_type = fields.CharField(max_length=50, null=True)
|
||||
resource_id = fields.IntField(null=True)
|
||||
timestamp = fields.DatetimeField(auto_now_add=True, index=True)
|
||||
timestamp = fields.DatetimeField(auto_now_add=True, db_index=True)
|
||||
idempotency_key = fields.CharField(max_length=255, unique=True, null=True)
|
||||
metadata = fields.JSONField(null=True)
|
||||
|
||||
@ -50,8 +54,9 @@ class UsageRecord(models.Model):
|
||||
table = "usage_records"
|
||||
indexes = [("user_id", "record_type", "timestamp")]
|
||||
|
||||
|
||||
class UsageAggregate(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
user = fields.ForeignKeyField("models.User", related_name="usage_aggregates")
|
||||
date = fields.DateField()
|
||||
storage_bytes_avg = fields.BigIntField(default=0)
|
||||
@ -64,8 +69,9 @@ class UsageAggregate(models.Model):
|
||||
table = "usage_aggregates"
|
||||
unique_together = (("user", "date"),)
|
||||
|
||||
|
||||
class Invoice(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
user = fields.ForeignKeyField("models.User", related_name="invoices")
|
||||
invoice_number = fields.CharField(max_length=50, unique=True)
|
||||
stripe_invoice_id = fields.CharField(max_length=255, unique=True, null=True)
|
||||
@ -75,10 +81,10 @@ class Invoice(models.Model):
|
||||
tax = fields.DecimalField(max_digits=10, decimal_places=4, default=0)
|
||||
total = fields.DecimalField(max_digits=10, decimal_places=4)
|
||||
currency = fields.CharField(max_length=3, default="USD")
|
||||
status = fields.CharField(max_length=50, default="draft", index=True)
|
||||
status = fields.CharField(max_length=50, default="draft", db_index=True)
|
||||
due_date = fields.DateField(null=True)
|
||||
paid_at = fields.DatetimeField(null=True)
|
||||
created_at = fields.DatetimeField(auto_now_add=True, index=True)
|
||||
created_at = fields.DatetimeField(auto_now_add=True, db_index=True)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
metadata = fields.JSONField(null=True)
|
||||
|
||||
@ -86,8 +92,9 @@ class Invoice(models.Model):
|
||||
table = "invoices"
|
||||
indexes = [("user_id", "status", "created_at")]
|
||||
|
||||
|
||||
class InvoiceLineItem(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
invoice = fields.ForeignKeyField("billing.Invoice", related_name="line_items")
|
||||
description = fields.TextField()
|
||||
quantity = fields.DecimalField(max_digits=15, decimal_places=6)
|
||||
@ -100,20 +107,24 @@ class InvoiceLineItem(models.Model):
|
||||
class Meta:
|
||||
table = "invoice_line_items"
|
||||
|
||||
|
||||
class PricingConfig(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
config_key = fields.CharField(max_length=100, unique=True)
|
||||
config_value = fields.DecimalField(max_digits=10, decimal_places=6)
|
||||
description = fields.TextField(null=True)
|
||||
unit = fields.CharField(max_length=50, null=True)
|
||||
updated_by = fields.ForeignKeyField("models.User", related_name="pricing_updates", null=True)
|
||||
updated_by = fields.ForeignKeyField(
|
||||
"models.User", related_name="pricing_updates", null=True
|
||||
)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
|
||||
class Meta:
|
||||
table = "pricing_config"
|
||||
|
||||
|
||||
class PaymentMethod(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
user = fields.ForeignKeyField("models.User", related_name="payment_methods")
|
||||
stripe_payment_method_id = fields.CharField(max_length=255)
|
||||
type = fields.CharField(max_length=50)
|
||||
@ -128,9 +139,12 @@ class PaymentMethod(models.Model):
|
||||
class Meta:
|
||||
table = "payment_methods"
|
||||
|
||||
|
||||
class BillingEvent(models.Model):
|
||||
id = fields.BigIntField(pk=True)
|
||||
user = fields.ForeignKeyField("models.User", related_name="billing_events", null=True)
|
||||
id = fields.BigIntField(primary_key=True)
|
||||
user = fields.ForeignKeyField(
|
||||
"models.User", related_name="billing_events", null=True
|
||||
)
|
||||
event_type = fields.CharField(max_length=100)
|
||||
stripe_event_id = fields.CharField(max_length=255, unique=True, null=True)
|
||||
data = fields.JSONField(null=True)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from datetime import datetime, date, timedelta
|
||||
import asyncio
|
||||
|
||||
from .usage_tracker import UsageTracker
|
||||
from .invoice_generator import InvoiceGenerator
|
||||
@ -9,6 +8,7 @@ from ..models import User
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
|
||||
async def aggregate_daily_usage_for_all_users():
|
||||
users = await User.filter(is_active=True).all()
|
||||
yesterday = date.today() - timedelta(days=1)
|
||||
@ -19,6 +19,7 @@ async def aggregate_daily_usage_for_all_users():
|
||||
except Exception as e:
|
||||
print(f"Failed to aggregate usage for user {user.id}: {e}")
|
||||
|
||||
|
||||
async def generate_monthly_invoices():
|
||||
now = datetime.now()
|
||||
last_month = now.month - 1 if now.month > 1 else 12
|
||||
@ -28,19 +29,22 @@ async def generate_monthly_invoices():
|
||||
|
||||
for user in users:
|
||||
try:
|
||||
invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, last_month)
|
||||
invoice = await InvoiceGenerator.generate_monthly_invoice(
|
||||
user, year, last_month
|
||||
)
|
||||
if invoice:
|
||||
await InvoiceGenerator.finalize_invoice(invoice)
|
||||
except Exception as e:
|
||||
print(f"Failed to generate invoice for user {user.id}: {e}")
|
||||
|
||||
|
||||
def start_scheduler():
|
||||
scheduler.add_job(
|
||||
aggregate_daily_usage_for_all_users,
|
||||
CronTrigger(hour=1, minute=0),
|
||||
id="aggregate_daily_usage",
|
||||
name="Aggregate daily usage for all users",
|
||||
replace_existing=True
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
scheduler.add_job(
|
||||
@ -48,10 +52,11 @@ def start_scheduler():
|
||||
CronTrigger(day=1, hour=2, minute=0),
|
||||
id="generate_monthly_invoices",
|
||||
name="Generate monthly invoices",
|
||||
replace_existing=True
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
|
||||
|
||||
def stop_scheduler():
|
||||
scheduler.shutdown()
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import stripe
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Dict
|
||||
from ..settings import settings
|
||||
|
||||
|
||||
class StripeClient:
|
||||
@staticmethod
|
||||
def _ensure_api_key():
|
||||
@ -11,13 +11,12 @@ class StripeClient:
|
||||
stripe.api_key = settings.STRIPE_SECRET_KEY
|
||||
else:
|
||||
raise ValueError("Stripe API key not configured")
|
||||
|
||||
@staticmethod
|
||||
async def create_customer(email: str, name: str, metadata: Dict = None) -> str:
|
||||
StripeClient._ensure_api_key()
|
||||
customer = stripe.Customer.create(
|
||||
email=email,
|
||||
name=name,
|
||||
metadata=metadata or {}
|
||||
email=email, name=name, metadata=metadata or {}
|
||||
)
|
||||
return customer.id
|
||||
|
||||
@ -26,7 +25,7 @@ class StripeClient:
|
||||
amount: int,
|
||||
currency: str = "usd",
|
||||
customer_id: str = None,
|
||||
metadata: Dict = None
|
||||
metadata: Dict = None,
|
||||
) -> stripe.PaymentIntent:
|
||||
StripeClient._ensure_api_key()
|
||||
return stripe.PaymentIntent.create(
|
||||
@ -34,32 +33,29 @@ class StripeClient:
|
||||
currency=currency,
|
||||
customer=customer_id,
|
||||
metadata=metadata or {},
|
||||
automatic_payment_methods={"enabled": True}
|
||||
automatic_payment_methods={"enabled": True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_invoice(
|
||||
customer_id: str,
|
||||
description: str,
|
||||
line_items: list,
|
||||
metadata: Dict = None
|
||||
customer_id: str, description: str, line_items: list, metadata: Dict = None
|
||||
) -> stripe.Invoice:
|
||||
StripeClient._ensure_api_key()
|
||||
for item in line_items:
|
||||
stripe.InvoiceItem.create(
|
||||
customer=customer_id,
|
||||
amount=int(item['amount'] * 100),
|
||||
currency=item.get('currency', 'usd'),
|
||||
description=item['description'],
|
||||
metadata=item.get('metadata', {})
|
||||
amount=int(item["amount"] * 100),
|
||||
currency=item.get("currency", "usd"),
|
||||
description=item["description"],
|
||||
metadata=item.get("metadata", {}),
|
||||
)
|
||||
|
||||
invoice = stripe.Invoice.create(
|
||||
customer=customer_id,
|
||||
description=description,
|
||||
auto_advance=True,
|
||||
collection_method='charge_automatically',
|
||||
metadata=metadata or {}
|
||||
collection_method="charge_automatically",
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
return invoice
|
||||
@ -76,18 +72,15 @@ class StripeClient:
|
||||
|
||||
@staticmethod
|
||||
async def attach_payment_method(
|
||||
payment_method_id: str,
|
||||
customer_id: str
|
||||
payment_method_id: str, customer_id: str
|
||||
) -> stripe.PaymentMethod:
|
||||
StripeClient._ensure_api_key()
|
||||
payment_method = stripe.PaymentMethod.attach(
|
||||
payment_method_id,
|
||||
customer=customer_id
|
||||
payment_method_id, customer=customer_id
|
||||
)
|
||||
|
||||
stripe.Customer.modify(
|
||||
customer_id,
|
||||
invoice_settings={'default_payment_method': payment_method_id}
|
||||
customer_id, invoice_settings={"default_payment_method": payment_method_id}
|
||||
)
|
||||
|
||||
return payment_method
|
||||
@ -95,22 +88,15 @@ class StripeClient:
|
||||
@staticmethod
|
||||
async def list_payment_methods(customer_id: str, type: str = "card"):
|
||||
StripeClient._ensure_api_key()
|
||||
return stripe.PaymentMethod.list(
|
||||
customer=customer_id,
|
||||
type=type
|
||||
)
|
||||
return stripe.PaymentMethod.list(customer=customer_id, type=type)
|
||||
|
||||
@staticmethod
|
||||
async def create_subscription(
|
||||
customer_id: str,
|
||||
price_id: str,
|
||||
metadata: Dict = None
|
||||
customer_id: str, price_id: str, metadata: Dict = None
|
||||
) -> stripe.Subscription:
|
||||
StripeClient._ensure_api_key()
|
||||
return stripe.Subscription.create(
|
||||
customer=customer_id,
|
||||
items=[{'price': price_id}],
|
||||
metadata=metadata or {}
|
||||
customer=customer_id, items=[{"price": price_id}], metadata=metadata or {}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import uuid
|
||||
from datetime import datetime, date, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
from tortoise.transactions import in_transaction
|
||||
from .models import UsageRecord, UsageAggregate
|
||||
from ..models import User
|
||||
|
||||
|
||||
class UsageTracker:
|
||||
@staticmethod
|
||||
async def track_storage(
|
||||
@ -13,7 +12,7 @@ class UsageTracker:
|
||||
amount_bytes: int,
|
||||
resource_type: str = None,
|
||||
resource_id: int = None,
|
||||
metadata: dict = None
|
||||
metadata: dict = None,
|
||||
):
|
||||
idempotency_key = f"storage_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
@ -24,7 +23,7 @@ class UsageTracker:
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
idempotency_key=idempotency_key,
|
||||
metadata=metadata
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -34,7 +33,7 @@ class UsageTracker:
|
||||
direction: str = "down",
|
||||
resource_type: str = None,
|
||||
resource_id: int = None,
|
||||
metadata: dict = None
|
||||
metadata: dict = None,
|
||||
):
|
||||
record_type = f"bandwidth_{direction}"
|
||||
idempotency_key = f"{record_type}_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}"
|
||||
@ -46,7 +45,7 @@ class UsageTracker:
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
idempotency_key=idempotency_key,
|
||||
metadata=metadata
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -61,24 +60,26 @@ class UsageTracker:
|
||||
user=user,
|
||||
record_type="storage",
|
||||
timestamp__gte=start_of_day,
|
||||
timestamp__lte=end_of_day
|
||||
timestamp__lte=end_of_day,
|
||||
).all()
|
||||
|
||||
storage_avg = sum(r.amount_bytes for r in storage_records) // max(len(storage_records), 1)
|
||||
storage_avg = sum(r.amount_bytes for r in storage_records) // max(
|
||||
len(storage_records), 1
|
||||
)
|
||||
storage_peak = max((r.amount_bytes for r in storage_records), default=0)
|
||||
|
||||
bandwidth_up = await UsageRecord.filter(
|
||||
user=user,
|
||||
record_type="bandwidth_up",
|
||||
timestamp__gte=start_of_day,
|
||||
timestamp__lte=end_of_day
|
||||
timestamp__lte=end_of_day,
|
||||
).all()
|
||||
|
||||
bandwidth_down = await UsageRecord.filter(
|
||||
user=user,
|
||||
record_type="bandwidth_down",
|
||||
timestamp__gte=start_of_day,
|
||||
timestamp__lte=end_of_day
|
||||
timestamp__lte=end_of_day,
|
||||
).all()
|
||||
|
||||
total_up = sum(r.amount_bytes for r in bandwidth_up)
|
||||
@ -92,8 +93,8 @@ class UsageTracker:
|
||||
"storage_bytes_avg": storage_avg,
|
||||
"storage_bytes_peak": storage_peak,
|
||||
"bandwidth_up_bytes": total_up,
|
||||
"bandwidth_down_bytes": total_down
|
||||
}
|
||||
"bandwidth_down_bytes": total_down,
|
||||
},
|
||||
)
|
||||
|
||||
if not created:
|
||||
@ -108,6 +109,7 @@ class UsageTracker:
|
||||
@staticmethod
|
||||
async def get_current_storage(user: User) -> int:
|
||||
from ..models import File
|
||||
|
||||
files = await File.filter(owner=user, is_deleted=False).all()
|
||||
return sum(f.size for f in files)
|
||||
|
||||
@ -121,9 +123,7 @@ class UsageTracker:
|
||||
end_date = date(year, month, last_day)
|
||||
|
||||
aggregates = await UsageAggregate.filter(
|
||||
user=user,
|
||||
date__gte=start_date,
|
||||
date__lte=end_date
|
||||
user=user, date__gte=start_date, date__lte=end_date
|
||||
).all()
|
||||
|
||||
if not aggregates:
|
||||
@ -132,7 +132,7 @@ class UsageTracker:
|
||||
"storage_gb_peak": 0,
|
||||
"bandwidth_up_gb": 0,
|
||||
"bandwidth_down_gb": 0,
|
||||
"total_bandwidth_gb": 0
|
||||
"total_bandwidth_gb": 0,
|
||||
}
|
||||
|
||||
storage_avg = sum(a.storage_bytes_avg for a in aggregates) / len(aggregates)
|
||||
@ -145,5 +145,5 @@ class UsageTracker:
|
||||
"storage_gb_peak": round(storage_peak / (1024**3), 4),
|
||||
"bandwidth_up_gb": round(bandwidth_up / (1024**3), 4),
|
||||
"bandwidth_down_gb": round(bandwidth_down / (1024**3), 4),
|
||||
"total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4)
|
||||
"total_bandwidth_gb": round((bandwidth_up + bandwidth_down) / (1024**3), 4),
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@ for various legal policies required for a European cloud storage provider.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class LegalDocument(ABC):
|
||||
@ -17,10 +17,13 @@ class LegalDocument(ABC):
|
||||
Provides common structure and methods for generating legal content.
|
||||
"""
|
||||
|
||||
def __init__(self, company_name: str = "MyWebdav Technologies",
|
||||
last_updated: str = None,
|
||||
contact_email: str = "legal@mywebdav.eu",
|
||||
website: str = "https://mywebdav.eu"):
|
||||
def __init__(
|
||||
self,
|
||||
company_name: str = "MyWebdav Technologies",
|
||||
last_updated: str = None,
|
||||
contact_email: str = "legal@mywebdav.eu",
|
||||
website: str = "https://mywebdav.eu",
|
||||
):
|
||||
self.company_name = company_name
|
||||
self.last_updated = last_updated or datetime.now().strftime("%B %d, %Y")
|
||||
self.contact_email = contact_email
|
||||
@ -853,20 +856,21 @@ class ContactComplaintMechanism(LegalDocument):
|
||||
def get_all_legal_documents() -> Dict[str, LegalDocument]:
|
||||
"""Return a dictionary of all legal document instances."""
|
||||
return {
|
||||
'privacy_policy': PrivacyPolicy(),
|
||||
'terms_of_service': TermsOfService(),
|
||||
'security_policy': SecurityPolicy(),
|
||||
'cookie_policy': CookiePolicy(),
|
||||
'data_processing_agreement': DataProcessingAgreement(),
|
||||
'compliance_statement': ComplianceStatement(),
|
||||
'data_portability_deletion_policy': DataPortabilityDeletionPolicy(),
|
||||
'contact_complaint_mechanism': ContactComplaintMechanism(),
|
||||
"privacy_policy": PrivacyPolicy(),
|
||||
"terms_of_service": TermsOfService(),
|
||||
"security_policy": SecurityPolicy(),
|
||||
"cookie_policy": CookiePolicy(),
|
||||
"data_processing_agreement": DataProcessingAgreement(),
|
||||
"compliance_statement": ComplianceStatement(),
|
||||
"data_portability_deletion_policy": DataPortabilityDeletionPolicy(),
|
||||
"contact_complaint_mechanism": ContactComplaintMechanism(),
|
||||
}
|
||||
|
||||
|
||||
def generate_legal_documents(output_dir: str = "static/legal"):
|
||||
"""Generate all legal documents as Markdown and HTML files."""
|
||||
import os
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
documents = get_all_legal_documents()
|
||||
@ -875,13 +879,13 @@ def generate_legal_documents(output_dir: str = "static/legal"):
|
||||
# Generate Markdown
|
||||
md_filename = f"{doc_name}.md"
|
||||
md_path = os.path.join(output_dir, md_filename)
|
||||
with open(md_path, 'w') as f:
|
||||
with open(md_path, "w") as f:
|
||||
f.write(doc.to_markdown())
|
||||
|
||||
# Generate HTML
|
||||
html_filename = f"{doc_name}.html"
|
||||
html_path = os.path.join(output_dir, html_filename)
|
||||
with open(html_path, 'w') as f:
|
||||
with open(html_path, "w") as f:
|
||||
f.write(doc.to_html())
|
||||
|
||||
print(f"Generated {md_filename} and {html_filename}")
|
||||
|
||||
@ -2,17 +2,26 @@ import asyncio
|
||||
import aiosmtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional
|
||||
from .settings import settings
|
||||
|
||||
|
||||
class EmailTask:
|
||||
def __init__(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
to_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
html: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.to_email = to_email
|
||||
self.subject = subject
|
||||
self.body = body
|
||||
self.html = html
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class EmailService:
|
||||
def __init__(self):
|
||||
self.queue = asyncio.Queue()
|
||||
@ -38,7 +47,14 @@ class EmailService:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def send_email(self, to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
|
||||
async def send_email(
|
||||
self,
|
||||
to_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
html: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Queue an email for sending"""
|
||||
task = EmailTask(to_email, subject, body, html, **kwargs)
|
||||
await self.queue.put(task)
|
||||
@ -60,22 +76,26 @@ class EmailService:
|
||||
|
||||
async def _send_email_task(self, task: EmailTask):
|
||||
"""Send a single email task"""
|
||||
if not settings.SMTP_HOST or not settings.SMTP_USERNAME or not settings.SMTP_PASSWORD:
|
||||
if (
|
||||
not settings.SMTP_HOST
|
||||
or not settings.SMTP_USERNAME
|
||||
or not settings.SMTP_PASSWORD
|
||||
):
|
||||
print("SMTP not configured, skipping email send")
|
||||
return
|
||||
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['From'] = settings.SMTP_SENDER_EMAIL
|
||||
msg['To'] = task.to_email
|
||||
msg['Subject'] = task.subject
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["From"] = settings.SMTP_SENDER_EMAIL
|
||||
msg["To"] = task.to_email
|
||||
msg["Subject"] = task.subject
|
||||
|
||||
# Add text part
|
||||
text_part = MIMEText(task.body, 'plain')
|
||||
text_part = MIMEText(task.body, "plain")
|
||||
msg.attach(text_part)
|
||||
|
||||
# Add HTML part if provided
|
||||
if task.html:
|
||||
html_part = MIMEText(task.html, 'html')
|
||||
html_part = MIMEText(task.html, "html")
|
||||
msg.attach(html_part)
|
||||
|
||||
try:
|
||||
@ -86,7 +106,7 @@ class EmailService:
|
||||
port=settings.SMTP_PORT,
|
||||
username=settings.SMTP_USERNAME,
|
||||
password=settings.SMTP_PASSWORD,
|
||||
use_tls=True
|
||||
use_tls=True,
|
||||
) as smtp:
|
||||
await smtp.send_message(msg)
|
||||
print(f"Email sent to {task.to_email}")
|
||||
@ -99,7 +119,10 @@ class EmailService:
|
||||
try:
|
||||
await smtp.starttls()
|
||||
except Exception as tls_error:
|
||||
if "already using" in str(tls_error).lower() or "tls" in str(tls_error).lower():
|
||||
if (
|
||||
"already using" in str(tls_error).lower()
|
||||
or "tls" in str(tls_error).lower()
|
||||
):
|
||||
# Connection is already using TLS, proceed without starttls
|
||||
pass
|
||||
else:
|
||||
@ -111,14 +134,23 @@ class EmailService:
|
||||
print(f"Failed to send email to {task.to_email}: {e}")
|
||||
raise # Re-raise to let caller handle
|
||||
|
||||
|
||||
# Global email service instance
|
||||
email_service = EmailService()
|
||||
|
||||
|
||||
# Convenience functions
|
||||
async def send_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
|
||||
async def send_email(
|
||||
to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
|
||||
):
|
||||
"""Send an email asynchronously"""
|
||||
await email_service.send_email(to_email, subject, body, html, **kwargs)
|
||||
|
||||
def queue_email(to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs):
|
||||
|
||||
def queue_email(
|
||||
to_email: str, subject: str, body: str, html: Optional[str] = None, **kwargs
|
||||
):
|
||||
"""Queue an email for sending (fire and forget)"""
|
||||
asyncio.create_task(email_service.send_email(to_email, subject, body, html, **kwargs))
|
||||
asyncio.create_task(
|
||||
email_service.send_email(to_email, subject, body, html, **kwargs)
|
||||
)
|
||||
|
||||
@ -2,18 +2,33 @@ import argparse
|
||||
import uvicorn
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request, status, HTTPException
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from tortoise.contrib.fastapi import register_tortoise
|
||||
from .settings import settings
|
||||
from .routers import auth, users, folders, files, shares, search, admin, starred, billing, admin_billing
|
||||
from .routers import (
|
||||
auth,
|
||||
users,
|
||||
folders,
|
||||
files,
|
||||
shares,
|
||||
search,
|
||||
admin,
|
||||
starred,
|
||||
billing,
|
||||
admin_billing,
|
||||
)
|
||||
from . import webdav
|
||||
from .schemas import ErrorResponse
|
||||
from .middleware.usage_tracking import UsageTrackingMiddleware
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Starting up...")
|
||||
@ -21,6 +36,7 @@ async def lifespan(app: FastAPI):
|
||||
from .billing.scheduler import start_scheduler
|
||||
from .billing.models import PricingConfig
|
||||
from .mail import email_service
|
||||
|
||||
start_scheduler()
|
||||
logger.info("Billing scheduler started")
|
||||
await email_service.start()
|
||||
@ -28,28 +44,61 @@ async def lifespan(app: FastAPI):
|
||||
pricing_count = await PricingConfig.all().count()
|
||||
if pricing_count == 0:
|
||||
from decimal import Decimal
|
||||
await PricingConfig.create(config_key='storage_per_gb_month', config_value=Decimal('0.0045'), description='Storage cost per GB per month', unit='per_gb_month')
|
||||
await PricingConfig.create(config_key='bandwidth_egress_per_gb', config_value=Decimal('0.009'), description='Bandwidth egress cost per GB', unit='per_gb')
|
||||
await PricingConfig.create(config_key='bandwidth_ingress_per_gb', config_value=Decimal('0.0'), description='Bandwidth ingress cost per GB (free)', unit='per_gb')
|
||||
await PricingConfig.create(config_key='free_tier_storage_gb', config_value=Decimal('15'), description='Free tier storage in GB', unit='gb')
|
||||
await PricingConfig.create(config_key='free_tier_bandwidth_gb', config_value=Decimal('15'), description='Free tier bandwidth in GB per month', unit='gb')
|
||||
await PricingConfig.create(config_key='tax_rate_default', config_value=Decimal('0.0'), description='Default tax rate (0 = no tax)', unit='percentage')
|
||||
|
||||
await PricingConfig.create(
|
||||
config_key="storage_per_gb_month",
|
||||
config_value=Decimal("0.0045"),
|
||||
description="Storage cost per GB per month",
|
||||
unit="per_gb_month",
|
||||
)
|
||||
await PricingConfig.create(
|
||||
config_key="bandwidth_egress_per_gb",
|
||||
config_value=Decimal("0.009"),
|
||||
description="Bandwidth egress cost per GB",
|
||||
unit="per_gb",
|
||||
)
|
||||
await PricingConfig.create(
|
||||
config_key="bandwidth_ingress_per_gb",
|
||||
config_value=Decimal("0.0"),
|
||||
description="Bandwidth ingress cost per GB (free)",
|
||||
unit="per_gb",
|
||||
)
|
||||
await PricingConfig.create(
|
||||
config_key="free_tier_storage_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier storage in GB",
|
||||
unit="gb",
|
||||
)
|
||||
await PricingConfig.create(
|
||||
config_key="free_tier_bandwidth_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier bandwidth in GB per month",
|
||||
unit="gb",
|
||||
)
|
||||
await PricingConfig.create(
|
||||
config_key="tax_rate_default",
|
||||
config_value=Decimal("0.0"),
|
||||
description="Default tax rate (0 = no tax)",
|
||||
unit="percentage",
|
||||
)
|
||||
logger.info("Default pricing configuration initialized")
|
||||
|
||||
yield
|
||||
|
||||
from .billing.scheduler import stop_scheduler
|
||||
|
||||
stop_scheduler()
|
||||
logger.info("Billing scheduler stopped")
|
||||
await email_service.stop()
|
||||
logger.info("Email service stopped")
|
||||
print("Shutting down...")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="MyWebdav Cloud Storage",
|
||||
description="A commercial cloud storage web application",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.include_router(auth.router)
|
||||
@ -64,8 +113,6 @@ app.include_router(billing.router)
|
||||
app.include_router(admin_billing.router)
|
||||
app.include_router(webdav.router)
|
||||
|
||||
from .middleware.usage_tracking import UsageTrackingMiddleware
|
||||
|
||||
app.add_middleware(UsageTrackingMiddleware)
|
||||
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
@ -73,35 +120,39 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
register_tortoise(
|
||||
app,
|
||||
db_url=settings.DATABASE_URL,
|
||||
modules={
|
||||
"models": ["mywebdav.models"],
|
||||
"billing": ["mywebdav.billing.models"]
|
||||
},
|
||||
modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
|
||||
generate_schemas=True,
|
||||
add_exception_handlers=True,
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
logger.error(f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}")
|
||||
logger.error(
|
||||
f"HTTPException: {exc.status_code} - {exc.detail} for URL: {request.url}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=ErrorResponse(code=exc.status_code, message=exc.detail).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse
|
||||
@app.get("/", response_class=HTMLResponse) # Change response_class to HTMLResponse
|
||||
async def read_root():
|
||||
with open("static/index.html", "r") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run the MyWebdav application.")
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host address to bind to")
|
||||
parser.add_argument(
|
||||
"--host", type=str, default="0.0.0.0", help="Host address to bind to"
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port to listen on")
|
||||
args = parser.parse_args()
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -2,31 +2,35 @@ from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from ..billing.usage_tracker import UsageTracker
|
||||
|
||||
|
||||
class UsageTrackingMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
|
||||
if hasattr(request.state, 'user') and request.state.user:
|
||||
if hasattr(request.state, "user") and request.state.user:
|
||||
user = request.state.user
|
||||
|
||||
if request.method in ['POST', 'PUT'] and '/files/upload' in request.url.path:
|
||||
content_length = response.headers.get('content-length')
|
||||
if (
|
||||
request.method in ["POST", "PUT"]
|
||||
and "/files/upload" in request.url.path
|
||||
):
|
||||
content_length = response.headers.get("content-length")
|
||||
if content_length:
|
||||
await UsageTracker.track_bandwidth(
|
||||
user=user,
|
||||
amount_bytes=int(content_length),
|
||||
direction='up',
|
||||
metadata={'path': request.url.path}
|
||||
direction="up",
|
||||
metadata={"path": request.url.path},
|
||||
)
|
||||
|
||||
elif request.method == 'GET' and '/files/download' in request.url.path:
|
||||
content_length = response.headers.get('content-length')
|
||||
elif request.method == "GET" and "/files/download" in request.url.path:
|
||||
content_length = response.headers.get("content-length")
|
||||
if content_length:
|
||||
await UsageTracker.track_bandwidth(
|
||||
user=user,
|
||||
amount_bytes=int(content_length),
|
||||
direction='down',
|
||||
metadata={'path': request.url.path}
|
||||
direction="down",
|
||||
metadata={"path": request.url.path},
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from tortoise import fields, models
|
||||
from tortoise.contrib.pydantic import pydantic_model_creator
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class User(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
username = fields.CharField(max_length=20, unique=True)
|
||||
email = fields.CharField(max_length=255, unique=True)
|
||||
hashed_password = fields.CharField(max_length=255)
|
||||
@ -11,7 +11,9 @@ class User(models.Model):
|
||||
is_superuser = fields.BooleanField(default=False)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
storage_quota_bytes = fields.BigIntField(default=10 * 1024 * 1024 * 1024) # 10 GB default
|
||||
storage_quota_bytes = fields.BigIntField(
|
||||
default=10 * 1024 * 1024 * 1024
|
||||
) # 10 GB default
|
||||
used_storage_bytes = fields.BigIntField(default=0)
|
||||
plan_type = fields.CharField(max_length=50, default="free")
|
||||
two_factor_secret = fields.CharField(max_length=255, null=True)
|
||||
@ -24,11 +26,16 @@ class User(models.Model):
|
||||
def __str__(self):
|
||||
return self.username
|
||||
|
||||
|
||||
class Folder(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
name = fields.CharField(max_length=255)
|
||||
parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField("models.Folder", related_name="children", null=True)
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="folders")
|
||||
parent: fields.ForeignKeyRelation["Folder"] = fields.ForeignKeyField(
|
||||
"models.Folder", related_name="children", null=True
|
||||
)
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
|
||||
"models.User", related_name="folders"
|
||||
)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
is_deleted = fields.BooleanField(default=False)
|
||||
@ -36,21 +43,28 @@ class Folder(models.Model):
|
||||
|
||||
class Meta:
|
||||
table = "folders"
|
||||
unique_together = (("name", "parent", "owner"),) # Ensure unique folder names within a parent for an owner
|
||||
unique_together = (
|
||||
("name", "parent", "owner"),
|
||||
) # Ensure unique folder names within a parent for an owner
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class File(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
name = fields.CharField(max_length=255)
|
||||
path = fields.CharField(max_length=1024) # Internal storage path
|
||||
path = fields.CharField(max_length=1024) # Internal storage path
|
||||
size = fields.BigIntField()
|
||||
mime_type = fields.CharField(max_length=255)
|
||||
file_hash = fields.CharField(max_length=64, null=True) # SHA-256
|
||||
file_hash = fields.CharField(max_length=64, null=True) # SHA-256
|
||||
thumbnail_path = fields.CharField(max_length=1024, null=True)
|
||||
parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="files", null=True)
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="files")
|
||||
parent: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
|
||||
"models.Folder", related_name="files", null=True
|
||||
)
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
|
||||
"models.User", related_name="files"
|
||||
)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
is_deleted = fields.BooleanField(default=False)
|
||||
@ -60,14 +74,19 @@ class File(models.Model):
|
||||
|
||||
class Meta:
|
||||
table = "files"
|
||||
unique_together = (("name", "parent", "owner"),) # Ensure unique file names within a parent for an owner
|
||||
unique_together = (
|
||||
("name", "parent", "owner"),
|
||||
) # Ensure unique file names within a parent for an owner
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class FileVersion(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="versions")
|
||||
id = fields.IntField(primary_key=True)
|
||||
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField(
|
||||
"models.File", related_name="versions"
|
||||
)
|
||||
version_path = fields.CharField(max_length=1024)
|
||||
size = fields.BigIntField()
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
@ -75,47 +94,69 @@ class FileVersion(models.Model):
|
||||
class Meta:
|
||||
table = "file_versions"
|
||||
|
||||
|
||||
class Share(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
token = fields.CharField(max_length=64, unique=True)
|
||||
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField("models.File", related_name="shares", null=True)
|
||||
folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="shares", null=True)
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="shares")
|
||||
file: fields.ForeignKeyRelation[File] = fields.ForeignKeyField(
|
||||
"models.File", related_name="shares", null=True
|
||||
)
|
||||
folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
|
||||
"models.Folder", related_name="shares", null=True
|
||||
)
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
|
||||
"models.User", related_name="shares"
|
||||
)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
expires_at = fields.DatetimeField(null=True)
|
||||
password_protected = fields.BooleanField(default=False)
|
||||
hashed_password = fields.CharField(max_length=255, null=True)
|
||||
access_count = fields.IntField(default=0)
|
||||
permission_level = fields.CharField(max_length=50, default="viewer") # viewer, uploader, editor
|
||||
permission_level = fields.CharField(
|
||||
max_length=50, default="viewer"
|
||||
) # viewer, uploader, editor
|
||||
|
||||
class Meta:
|
||||
table = "shares"
|
||||
|
||||
|
||||
class Team(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
name = fields.CharField(max_length=255, unique=True)
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="owned_teams")
|
||||
members: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User", related_name="teams", through="team_members")
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
|
||||
"models.User", related_name="owned_teams"
|
||||
)
|
||||
members: fields.ManyToManyRelation[User] = fields.ManyToManyField(
|
||||
"models.User", related_name="teams", through="team_members"
|
||||
)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
table = "teams"
|
||||
|
||||
|
||||
class TeamMember(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField("models.Team", related_name="team_members")
|
||||
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="user_teams")
|
||||
role = fields.CharField(max_length=50, default="member") # owner, admin, member
|
||||
id = fields.IntField(primary_key=True)
|
||||
team: fields.ForeignKeyRelation[Team] = fields.ForeignKeyField(
|
||||
"models.Team", related_name="team_members"
|
||||
)
|
||||
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
|
||||
"models.User", related_name="user_teams"
|
||||
)
|
||||
role = fields.CharField(max_length=50, default="member") # owner, admin, member
|
||||
|
||||
class Meta:
|
||||
table = "team_members"
|
||||
unique_together = (("team", "user"),)
|
||||
|
||||
|
||||
class Activity(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="activities", null=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
|
||||
"models.User", related_name="activities", null=True
|
||||
)
|
||||
action = fields.CharField(max_length=255)
|
||||
target_type = fields.CharField(max_length=50) # file, folder, share, user, team
|
||||
target_type = fields.CharField(max_length=50) # file, folder, share, user, team
|
||||
target_id = fields.IntField()
|
||||
ip_address = fields.CharField(max_length=45, null=True)
|
||||
timestamp = fields.DatetimeField(auto_now_add=True)
|
||||
@ -123,13 +164,18 @@ class Activity(models.Model):
|
||||
class Meta:
|
||||
table = "activities"
|
||||
|
||||
|
||||
class FileRequest(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
title = fields.CharField(max_length=255)
|
||||
description = fields.TextField(null=True)
|
||||
token = fields.CharField(max_length=64, unique=True)
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField("models.User", related_name="file_requests")
|
||||
target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField("models.Folder", related_name="file_requests")
|
||||
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
|
||||
"models.User", related_name="file_requests"
|
||||
)
|
||||
target_folder: fields.ForeignKeyRelation[Folder] = fields.ForeignKeyField(
|
||||
"models.Folder", related_name="file_requests"
|
||||
)
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
expires_at = fields.DatetimeField(null=True)
|
||||
is_active = fields.BooleanField(default=True)
|
||||
@ -137,8 +183,9 @@ class FileRequest(models.Model):
|
||||
class Meta:
|
||||
table = "file_requests"
|
||||
|
||||
|
||||
class WebDAVProperty(models.Model):
|
||||
id = fields.IntField(pk=True)
|
||||
id = fields.IntField(primary_key=True)
|
||||
resource_type = fields.CharField(max_length=10)
|
||||
resource_id = fields.IntField()
|
||||
namespace = fields.CharField(max_length=255)
|
||||
@ -151,5 +198,8 @@ class WebDAVProperty(models.Model):
|
||||
table = "webdav_properties"
|
||||
unique_together = (("resource_type", "resource_id", "namespace", "name"),)
|
||||
|
||||
|
||||
User_Pydantic = pydantic_model_creator(User, name="User_Pydantic")
|
||||
UserIn_Pydantic = pydantic_model_creator(User, name="UserIn_Pydantic", exclude_readonly=True)
|
||||
UserIn_Pydantic = pydantic_model_creator(
|
||||
User, name="UserIn_Pydantic", exclude_readonly=True
|
||||
)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
@ -13,18 +13,25 @@ router = APIRouter(
|
||||
responses={403: {"description": "Not enough permissions"}},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[User_Pydantic])
|
||||
async def get_all_users():
|
||||
return await User.all()
|
||||
|
||||
|
||||
@router.get("/users/{user_id}", response_model=User_Pydantic)
|
||||
async def get_user(user_id: int):
|
||||
user = await User.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
return user
|
||||
|
||||
@router.post("/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
@router.post(
|
||||
"/users", response_model=User_Pydantic, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def create_user_by_admin(user_in: UserCreate):
|
||||
user = await User.get_or_none(username=user_in.username)
|
||||
if user:
|
||||
@ -44,25 +51,33 @@ async def create_user_by_admin(user_in: UserCreate):
|
||||
username=user_in.username,
|
||||
email=user_in.email,
|
||||
hashed_password=hashed_password,
|
||||
is_superuser=False, # Admin creates regular users by default
|
||||
is_superuser=False, # Admin creates regular users by default
|
||||
is_active=True,
|
||||
)
|
||||
return await User_Pydantic.from_tortoise_orm(user)
|
||||
|
||||
|
||||
@router.put("/users/{user_id}", response_model=User_Pydantic)
|
||||
async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate):
|
||||
user = await User.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
if user_update.username is not None and user_update.username != user.username:
|
||||
if await User.get_or_none(username=user_update.username):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Username already taken"
|
||||
)
|
||||
user.username = user_update.username
|
||||
|
||||
if user_update.email is not None and user_update.email != user.email:
|
||||
if await User.get_or_none(email=user_update.email):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email already registered",
|
||||
)
|
||||
user.email = user_update.email
|
||||
|
||||
if user_update.password is not None:
|
||||
@ -89,21 +104,28 @@ async def update_user_by_admin(user_id: int, user_update: UserAdminUpdate):
|
||||
await user.save()
|
||||
return await User_Pydantic.from_tortoise_orm(user)
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_user_by_admin(user_id: int):
|
||||
user = await User.get_or_none(id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
await user.delete()
|
||||
return {"message": "User deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/test-email")
|
||||
async def send_test_email(to_email: str, subject: str = "Test Email", body: str = "This is a test email"):
|
||||
async def send_test_email(
|
||||
to_email: str, subject: str = "Test Email", body: str = "This is a test email"
|
||||
):
|
||||
from ..mail import queue_email
|
||||
|
||||
queue_email(
|
||||
to_email=to_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
html=f"<h1>{subject}</h1><p>{body}</p>"
|
||||
html=f"<h1>{subject}</h1><p>{body}</p>",
|
||||
)
|
||||
return {"message": "Test email queued"}
|
||||
@ -1,5 +1,4 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import List
|
||||
from decimal import Decimal
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -8,21 +7,25 @@ from ..models import User
|
||||
from ..billing.models import PricingConfig, Invoice, SubscriptionPlan
|
||||
from ..billing.invoice_generator import InvoiceGenerator
|
||||
|
||||
|
||||
def require_superuser(current_user: User = Depends(get_current_user)):
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(status_code=403, detail="Superuser privileges required")
|
||||
return current_user
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/admin/billing",
|
||||
tags=["admin", "billing"],
|
||||
dependencies=[Depends(require_superuser)]
|
||||
dependencies=[Depends(require_superuser)],
|
||||
)
|
||||
|
||||
|
||||
class PricingConfigUpdate(BaseModel):
|
||||
config_key: str
|
||||
config_value: float
|
||||
|
||||
|
||||
class PlanCreate(BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
@ -32,6 +35,7 @@ class PlanCreate(BaseModel):
|
||||
price_monthly: float
|
||||
price_yearly: float = None
|
||||
|
||||
|
||||
@router.get("/pricing")
|
||||
async def get_all_pricing(current_user: User = Depends(require_superuser)):
|
||||
configs = await PricingConfig.all()
|
||||
@ -42,16 +46,17 @@ async def get_all_pricing(current_user: User = Depends(require_superuser)):
|
||||
"config_value": float(c.config_value),
|
||||
"description": c.description,
|
||||
"unit": c.unit,
|
||||
"updated_at": c.updated_at
|
||||
"updated_at": c.updated_at,
|
||||
}
|
||||
for c in configs
|
||||
]
|
||||
|
||||
|
||||
@router.put("/pricing/{config_id}")
|
||||
async def update_pricing(
|
||||
config_id: int,
|
||||
update: PricingConfigUpdate,
|
||||
current_user: User = Depends(require_superuser)
|
||||
current_user: User = Depends(require_superuser),
|
||||
):
|
||||
config = await PricingConfig.get_or_none(id=config_id)
|
||||
if not config:
|
||||
@ -63,11 +68,10 @@ async def update_pricing(
|
||||
|
||||
return {"message": "Pricing updated successfully"}
|
||||
|
||||
|
||||
@router.post("/generate-invoices/{year}/{month}")
|
||||
async def generate_all_invoices(
|
||||
year: int,
|
||||
month: int,
|
||||
current_user: User = Depends(require_superuser)
|
||||
year: int, month: int, current_user: User = Depends(require_superuser)
|
||||
):
|
||||
users = await User.filter(is_active=True).all()
|
||||
generated = []
|
||||
@ -76,35 +80,36 @@ async def generate_all_invoices(
|
||||
for user in users:
|
||||
invoice = await InvoiceGenerator.generate_monthly_invoice(user, year, month)
|
||||
if invoice:
|
||||
generated.append({
|
||||
"user_id": user.id,
|
||||
"invoice_id": invoice.id,
|
||||
"total": float(invoice.total)
|
||||
})
|
||||
generated.append(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"invoice_id": invoice.id,
|
||||
"total": float(invoice.total),
|
||||
}
|
||||
)
|
||||
else:
|
||||
skipped.append(user.id)
|
||||
|
||||
return {
|
||||
"generated": len(generated),
|
||||
"skipped": len(skipped),
|
||||
"invoices": generated
|
||||
}
|
||||
return {"generated": len(generated), "skipped": len(skipped), "invoices": generated}
|
||||
|
||||
|
||||
@router.post("/plans")
|
||||
async def create_plan(
|
||||
plan_data: PlanCreate,
|
||||
current_user: User = Depends(require_superuser)
|
||||
plan_data: PlanCreate, current_user: User = Depends(require_superuser)
|
||||
):
|
||||
plan = await SubscriptionPlan.create(**plan_data.dict())
|
||||
return {"id": plan.id, "message": "Plan created successfully"}
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_billing_stats(current_user: User = Depends(require_superuser)):
|
||||
from tortoise.functions import Sum, Count
|
||||
from tortoise.functions import Sum
|
||||
|
||||
total_revenue = await Invoice.filter(status="paid").annotate(
|
||||
total_sum=Sum("total")
|
||||
).values("total_sum")
|
||||
total_revenue = (
|
||||
await Invoice.filter(status="paid")
|
||||
.annotate(total_sum=Sum("total"))
|
||||
.values("total_sum")
|
||||
)
|
||||
|
||||
invoice_count = await Invoice.all().count()
|
||||
pending_invoices = await Invoice.filter(status="open").count()
|
||||
@ -112,5 +117,5 @@ async def get_billing_stats(current_user: User = Depends(require_superuser)):
|
||||
return {
|
||||
"total_revenue": float(total_revenue[0]["total_sum"] or 0),
|
||||
"total_invoices": invoice_count,
|
||||
"pending_invoices": pending_invoices
|
||||
"pending_invoices": pending_invoices,
|
||||
}
|
||||
|
||||
@ -4,13 +4,23 @@ from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..auth import authenticate_user, create_access_token, get_password_hash, get_current_user, get_current_verified_user, verify_password
|
||||
from ..auth import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
get_password_hash,
|
||||
get_current_user,
|
||||
get_current_verified_user,
|
||||
verify_password,
|
||||
)
|
||||
from ..models import User
|
||||
from ..schemas import Token, UserCreate, TokenData, UserLoginWith2FA
|
||||
from ..schemas import Token, UserCreate
|
||||
from ..two_factor import (
|
||||
generate_totp_secret, generate_totp_uri, generate_qr_code_base64,
|
||||
verify_totp_code, generate_recovery_codes, hash_recovery_codes,
|
||||
verify_recovery_codes
|
||||
generate_totp_secret,
|
||||
generate_totp_uri,
|
||||
generate_qr_code_base64,
|
||||
verify_totp_code,
|
||||
generate_recovery_codes,
|
||||
hash_recovery_codes,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
@ -18,27 +28,33 @@ router = APIRouter(
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class TwoFactorLogin(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
two_factor_code: Optional[str] = None
|
||||
|
||||
|
||||
class TwoFactorSetupResponse(BaseModel):
|
||||
secret: str
|
||||
qr_code_base64: str
|
||||
recovery_codes: List[str]
|
||||
|
||||
|
||||
class TwoFactorCode(BaseModel):
|
||||
two_factor_code: str
|
||||
|
||||
|
||||
class TwoFactorDisable(BaseModel):
|
||||
password: str
|
||||
two_factor_code: str
|
||||
|
||||
|
||||
@router.post("/register", response_model=Token)
|
||||
async def register_user(user_in: UserCreate):
|
||||
user = await User.get_or_none(username=user_in.username)
|
||||
@ -63,22 +79,26 @@ async def register_user(user_in: UserCreate):
|
||||
|
||||
# Send welcome email
|
||||
from ..mail import queue_email
|
||||
|
||||
queue_email(
|
||||
to_email=user.email,
|
||||
subject="Welcome to MyWebdav!",
|
||||
body=f"Hi {user.username},\n\nWelcome to MyWebdav! Your account has been created successfully.\n\nBest regards,\nThe MyWebdav Team",
|
||||
html=f"<h1>Welcome to MyWebdav!</h1><p>Hi {user.username},</p><p>Welcome to MyWebdav! Your account has been created successfully.</p><p>Best regards,<br>The MyWebdav Team</p>"
|
||||
html=f"<h1>Welcome to MyWebdav!</h1><p>Hi {user.username},</p><p>Welcome to MyWebdav! Your account has been created successfully.</p><p>Best regards,<br>The MyWebdav Team</p>",
|
||||
)
|
||||
|
||||
access_token_expires = timedelta(minutes=30) # Use settings
|
||||
access_token_expires = timedelta(minutes=30) # Use settings
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/token", response_model=Token)
|
||||
async def login_for_access_token(login_data: LoginRequest):
|
||||
auth_result = await authenticate_user(login_data.username, login_data.password, None)
|
||||
auth_result = await authenticate_user(
|
||||
login_data.username, login_data.password, None
|
||||
)
|
||||
|
||||
if not auth_result:
|
||||
raise HTTPException(
|
||||
@ -97,16 +117,26 @@ async def login_for_access_token(login_data: LoginRequest):
|
||||
|
||||
access_token_expires = timedelta(minutes=30)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires, two_factor_verified=True
|
||||
data={"sub": user.username},
|
||||
expires_delta=access_token_expires,
|
||||
two_factor_verified=True,
|
||||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/2fa/setup", response_model=TwoFactorSetupResponse)
|
||||
async def setup_two_factor_authentication(current_user: User = Depends(get_current_user)):
|
||||
async def setup_two_factor_authentication(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if current_user.is_2fa_enabled:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
|
||||
)
|
||||
if current_user.two_factor_secret:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup already initiated. Verify or disable first.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="2FA setup already initiated. Verify or disable first.",
|
||||
)
|
||||
|
||||
secret = generate_totp_secret()
|
||||
current_user.two_factor_secret = secret
|
||||
@ -120,39 +150,66 @@ async def setup_two_factor_authentication(current_user: User = Depends(get_curre
|
||||
current_user.recovery_codes = ",".join(hashed_recovery_codes)
|
||||
await current_user.save()
|
||||
|
||||
return TwoFactorSetupResponse(secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes)
|
||||
return TwoFactorSetupResponse(
|
||||
secret=secret, qr_code_base64=qr_code_base64, recovery_codes=recovery_codes
|
||||
)
|
||||
|
||||
|
||||
@router.post("/2fa/verify", response_model=Token)
|
||||
async def verify_two_factor_authentication(two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)):
|
||||
async def verify_two_factor_authentication(
|
||||
two_factor_code_data: TwoFactorCode, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
if current_user.is_2fa_enabled:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is already enabled."
|
||||
)
|
||||
if not current_user.two_factor_secret:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA setup not initiated."
|
||||
)
|
||||
|
||||
if not verify_totp_code(current_user.two_factor_secret, two_factor_code_data.two_factor_code):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")
|
||||
if not verify_totp_code(
|
||||
current_user.two_factor_secret, two_factor_code_data.two_factor_code
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
|
||||
)
|
||||
|
||||
current_user.is_2fa_enabled = True
|
||||
await current_user.save()
|
||||
|
||||
access_token_expires = timedelta(minutes=30) # Use settings
|
||||
access_token_expires = timedelta(minutes=30) # Use settings
|
||||
access_token = create_access_token(
|
||||
data={"sub": current_user.username}, expires_delta=access_token_expires, two_factor_verified=True
|
||||
data={"sub": current_user.username},
|
||||
expires_delta=access_token_expires,
|
||||
two_factor_verified=True,
|
||||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/2fa/disable", response_model=dict)
|
||||
async def disable_two_factor_authentication(disable_data: TwoFactorDisable, current_user: User = Depends(get_current_verified_user)):
|
||||
async def disable_two_factor_authentication(
|
||||
disable_data: TwoFactorDisable,
|
||||
current_user: User = Depends(get_current_verified_user),
|
||||
):
|
||||
if not current_user.is_2fa_enabled:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
|
||||
)
|
||||
|
||||
# Verify password
|
||||
if not verify_password(disable_data.password, current_user.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password."
|
||||
)
|
||||
|
||||
# Verify 2FA code
|
||||
if not verify_totp_code(current_user.two_factor_secret, disable_data.two_factor_code):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code.")
|
||||
if not verify_totp_code(
|
||||
current_user.two_factor_secret, disable_data.two_factor_code
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid 2FA code."
|
||||
)
|
||||
|
||||
current_user.two_factor_secret = None
|
||||
current_user.is_2fa_enabled = False
|
||||
@ -161,10 +218,15 @@ async def disable_two_factor_authentication(disable_data: TwoFactorDisable, curr
|
||||
|
||||
return {"message": "2FA disabled successfully."}
|
||||
|
||||
|
||||
@router.get("/2fa/recovery-codes", response_model=List[str])
|
||||
async def get_new_recovery_codes(current_user: User = Depends(get_current_verified_user)):
|
||||
async def get_new_recovery_codes(
|
||||
current_user: User = Depends(get_current_verified_user),
|
||||
):
|
||||
if not current_user.is_2fa_enabled:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="2FA is not enabled."
|
||||
)
|
||||
|
||||
recovery_codes = generate_recovery_codes()
|
||||
hashed_recovery_codes = hash_recovery_codes(recovery_codes)
|
||||
|
||||
@ -1,25 +1,25 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
import calendar
|
||||
|
||||
from ..auth import get_current_user
|
||||
from ..models import User
|
||||
from ..billing.models import (
|
||||
Invoice, InvoiceLineItem, UserSubscription, PricingConfig,
|
||||
PaymentMethod, UsageAggregate, SubscriptionPlan
|
||||
Invoice,
|
||||
UserSubscription,
|
||||
PricingConfig,
|
||||
PaymentMethod,
|
||||
UsageAggregate,
|
||||
SubscriptionPlan,
|
||||
)
|
||||
from ..billing.usage_tracker import UsageTracker
|
||||
from ..billing.invoice_generator import InvoiceGenerator
|
||||
from ..billing.stripe_client import StripeClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/billing",
|
||||
tags=["billing"]
|
||||
)
|
||||
router = APIRouter(prefix="/api/billing", tags=["billing"])
|
||||
|
||||
|
||||
class UsageResponse(BaseModel):
|
||||
storage_gb_avg: float
|
||||
@ -29,6 +29,7 @@ class UsageResponse(BaseModel):
|
||||
total_bandwidth_gb: float
|
||||
period: str
|
||||
|
||||
|
||||
class InvoiceResponse(BaseModel):
|
||||
id: int
|
||||
invoice_number: str
|
||||
@ -42,6 +43,7 @@ class InvoiceResponse(BaseModel):
|
||||
paid_at: Optional[datetime]
|
||||
line_items: List[dict]
|
||||
|
||||
|
||||
class SubscriptionResponse(BaseModel):
|
||||
id: int
|
||||
billing_type: str
|
||||
@ -50,6 +52,7 @@ class SubscriptionResponse(BaseModel):
|
||||
current_period_start: Optional[datetime]
|
||||
current_period_end: Optional[datetime]
|
||||
|
||||
|
||||
@router.get("/usage/current")
|
||||
async def get_current_usage(current_user: User = Depends(get_current_user)):
|
||||
try:
|
||||
@ -61,25 +64,32 @@ async def get_current_usage(current_user: User = Depends(get_current_user)):
|
||||
if usage_today:
|
||||
return {
|
||||
"storage_gb": round(storage_bytes / (1024**3), 4),
|
||||
"bandwidth_down_gb_today": round(usage_today.bandwidth_down_bytes / (1024**3), 4),
|
||||
"bandwidth_up_gb_today": round(usage_today.bandwidth_up_bytes / (1024**3), 4),
|
||||
"as_of": today.isoformat()
|
||||
"bandwidth_down_gb_today": round(
|
||||
usage_today.bandwidth_down_bytes / (1024**3), 4
|
||||
),
|
||||
"bandwidth_up_gb_today": round(
|
||||
usage_today.bandwidth_up_bytes / (1024**3), 4
|
||||
),
|
||||
"as_of": today.isoformat(),
|
||||
}
|
||||
|
||||
return {
|
||||
"storage_gb": round(storage_bytes / (1024**3), 4),
|
||||
"bandwidth_down_gb_today": 0,
|
||||
"bandwidth_up_gb_today": 0,
|
||||
"as_of": today.isoformat()
|
||||
"as_of": today.isoformat(),
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch usage data: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch usage data: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/usage/monthly")
|
||||
async def get_monthly_usage(
|
||||
year: Optional[int] = None,
|
||||
month: Optional[int] = None,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> UsageResponse:
|
||||
try:
|
||||
if year is None or month is None:
|
||||
@ -88,71 +98,85 @@ async def get_monthly_usage(
|
||||
month = now.month
|
||||
|
||||
if not (1 <= month <= 12):
|
||||
raise HTTPException(status_code=400, detail="Month must be between 1 and 12")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Month must be between 1 and 12"
|
||||
)
|
||||
if not (2020 <= year <= 2100):
|
||||
raise HTTPException(status_code=400, detail="Year must be between 2020 and 2100")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Year must be between 2020 and 2100"
|
||||
)
|
||||
|
||||
usage = await UsageTracker.get_monthly_usage(current_user, year, month)
|
||||
|
||||
return UsageResponse(
|
||||
**usage,
|
||||
period=f"{year}-{month:02d}"
|
||||
)
|
||||
return UsageResponse(**usage, period=f"{year}-{month:02d}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch monthly usage: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/invoices")
|
||||
async def list_invoices(
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
current_user: User = Depends(get_current_user)
|
||||
limit: int = 50, offset: int = 0, current_user: User = Depends(get_current_user)
|
||||
) -> List[InvoiceResponse]:
|
||||
try:
|
||||
if limit < 1 or limit > 100:
|
||||
raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Limit must be between 1 and 100"
|
||||
)
|
||||
if offset < 0:
|
||||
raise HTTPException(status_code=400, detail="Offset must be non-negative")
|
||||
|
||||
invoices = await Invoice.filter(user=current_user).order_by("-created_at").offset(offset).limit(limit).all()
|
||||
invoices = (
|
||||
await Invoice.filter(user=current_user)
|
||||
.order_by("-created_at")
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
result = []
|
||||
for invoice in invoices:
|
||||
line_items = await invoice.line_items.all()
|
||||
result.append(InvoiceResponse(
|
||||
id=invoice.id,
|
||||
invoice_number=invoice.invoice_number,
|
||||
period_start=invoice.period_start,
|
||||
period_end=invoice.period_end,
|
||||
subtotal=float(invoice.subtotal),
|
||||
tax=float(invoice.tax),
|
||||
total=float(invoice.total),
|
||||
status=invoice.status,
|
||||
due_date=invoice.due_date,
|
||||
paid_at=invoice.paid_at,
|
||||
line_items=[
|
||||
{
|
||||
"description": item.description,
|
||||
"quantity": float(item.quantity),
|
||||
"unit_price": float(item.unit_price),
|
||||
"amount": float(item.amount),
|
||||
"type": item.item_type
|
||||
}
|
||||
for item in line_items
|
||||
]
|
||||
))
|
||||
result.append(
|
||||
InvoiceResponse(
|
||||
id=invoice.id,
|
||||
invoice_number=invoice.invoice_number,
|
||||
period_start=invoice.period_start,
|
||||
period_end=invoice.period_end,
|
||||
subtotal=float(invoice.subtotal),
|
||||
tax=float(invoice.tax),
|
||||
total=float(invoice.total),
|
||||
status=invoice.status,
|
||||
due_date=invoice.due_date,
|
||||
paid_at=invoice.paid_at,
|
||||
line_items=[
|
||||
{
|
||||
"description": item.description,
|
||||
"quantity": float(item.quantity),
|
||||
"unit_price": float(item.unit_price),
|
||||
"amount": float(item.amount),
|
||||
"type": item.item_type,
|
||||
}
|
||||
for item in line_items
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to fetch invoices: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch invoices: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/invoices/{invoice_id}")
|
||||
async def get_invoice(
|
||||
invoice_id: int,
|
||||
current_user: User = Depends(get_current_user)
|
||||
invoice_id: int, current_user: User = Depends(get_current_user)
|
||||
) -> InvoiceResponse:
|
||||
invoice = await Invoice.get_or_none(id=invoice_id, user=current_user)
|
||||
if not invoice:
|
||||
@ -177,21 +201,22 @@ async def get_invoice(
|
||||
"quantity": float(item.quantity),
|
||||
"unit_price": float(item.unit_price),
|
||||
"amount": float(item.amount),
|
||||
"type": item.item_type
|
||||
"type": item.item_type,
|
||||
}
|
||||
for item in line_items
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/subscription")
|
||||
async def get_subscription(current_user: User = Depends(get_current_user)) -> SubscriptionResponse:
|
||||
async def get_subscription(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> SubscriptionResponse:
|
||||
subscription = await UserSubscription.get_or_none(user=current_user)
|
||||
|
||||
if not subscription:
|
||||
subscription = await UserSubscription.create(
|
||||
user=current_user,
|
||||
billing_type="pay_as_you_go",
|
||||
status="active"
|
||||
user=current_user, billing_type="pay_as_you_go", status="active"
|
||||
)
|
||||
|
||||
plan_name = None
|
||||
@ -205,15 +230,19 @@ async def get_subscription(current_user: User = Depends(get_current_user)) -> Su
|
||||
plan_name=plan_name,
|
||||
status=subscription.status,
|
||||
current_period_start=subscription.current_period_start,
|
||||
current_period_end=subscription.current_period_end
|
||||
current_period_end=subscription.current_period_end,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/payment-methods/setup-intent")
|
||||
async def create_setup_intent(current_user: User = Depends(get_current_user)):
|
||||
try:
|
||||
from ..settings import settings
|
||||
|
||||
if not settings.STRIPE_SECRET_KEY:
|
||||
raise HTTPException(status_code=503, detail="Payment processing not configured")
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Payment processing not configured"
|
||||
)
|
||||
|
||||
subscription = await UserSubscription.get_or_none(user=current_user)
|
||||
|
||||
@ -221,7 +250,7 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
|
||||
customer_id = await StripeClient.create_customer(
|
||||
email=current_user.email,
|
||||
name=current_user.username,
|
||||
metadata={"user_id": str(current_user.id)}
|
||||
metadata={"user_id": str(current_user.id)},
|
||||
)
|
||||
|
||||
if not subscription:
|
||||
@ -229,27 +258,30 @@ async def create_setup_intent(current_user: User = Depends(get_current_user)):
|
||||
user=current_user,
|
||||
billing_type="pay_as_you_go",
|
||||
stripe_customer_id=customer_id,
|
||||
status="active"
|
||||
status="active",
|
||||
)
|
||||
else:
|
||||
subscription.stripe_customer_id = customer_id
|
||||
await subscription.save()
|
||||
|
||||
import stripe
|
||||
|
||||
StripeClient._ensure_api_key()
|
||||
setup_intent = stripe.SetupIntent.create(
|
||||
customer=subscription.stripe_customer_id,
|
||||
payment_method_types=["card"]
|
||||
customer=subscription.stripe_customer_id, payment_method_types=["card"]
|
||||
)
|
||||
|
||||
return {
|
||||
"client_secret": setup_intent.client_secret,
|
||||
"customer_id": subscription.stripe_customer_id
|
||||
"customer_id": subscription.stripe_customer_id,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create setup intent: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create setup intent: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/payment-methods")
|
||||
async def list_payment_methods(current_user: User = Depends(get_current_user)):
|
||||
@ -262,11 +294,12 @@ async def list_payment_methods(current_user: User = Depends(get_current_user)):
|
||||
"brand": m.brand,
|
||||
"exp_month": m.exp_month,
|
||||
"exp_year": m.exp_year,
|
||||
"is_default": m.is_default
|
||||
"is_default": m.is_default,
|
||||
}
|
||||
for m in methods
|
||||
]
|
||||
|
||||
|
||||
@router.post("/webhooks/stripe")
|
||||
async def stripe_webhook(request: Request):
|
||||
import stripe
|
||||
@ -298,12 +331,14 @@ async def stripe_webhook(request: Request):
|
||||
event_type=event["type"],
|
||||
stripe_event_id=event_id,
|
||||
data=event["data"],
|
||||
processed=False
|
||||
processed=False,
|
||||
)
|
||||
|
||||
if event["type"] == "invoice.payment_succeeded":
|
||||
invoice_data = event["data"]["object"]
|
||||
mywebdav_invoice_id = invoice_data.get("metadata", {}).get("mywebdav_invoice_id")
|
||||
mywebdav_invoice_id = invoice_data.get("metadata", {}).get(
|
||||
"mywebdav_invoice_id"
|
||||
)
|
||||
|
||||
if mywebdav_invoice_id:
|
||||
invoice = await Invoice.get_or_none(id=int(mywebdav_invoice_id))
|
||||
@ -317,7 +352,9 @@ async def stripe_webhook(request: Request):
|
||||
payment_method = event["data"]["object"]
|
||||
customer_id = payment_method["customer"]
|
||||
|
||||
subscription = await UserSubscription.get_or_none(stripe_customer_id=customer_id)
|
||||
subscription = await UserSubscription.get_or_none(
|
||||
stripe_customer_id=customer_id
|
||||
)
|
||||
if subscription:
|
||||
await PaymentMethod.create(
|
||||
user=subscription.user,
|
||||
@ -327,7 +364,7 @@ async def stripe_webhook(request: Request):
|
||||
brand=payment_method.get("card", {}).get("brand"),
|
||||
exp_month=payment_method.get("card", {}).get("exp_month"),
|
||||
exp_year=payment_method.get("card", {}).get("exp_year"),
|
||||
is_default=True
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
await BillingEvent.filter(stripe_event_id=event_id).update(processed=True)
|
||||
@ -335,7 +372,10 @@ async def stripe_webhook(request: Request):
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Webhook processing failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Webhook processing failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/pricing")
|
||||
async def get_pricing():
|
||||
@ -344,11 +384,12 @@ async def get_pricing():
|
||||
config.config_key: {
|
||||
"value": float(config.config_value),
|
||||
"description": config.description,
|
||||
"unit": config.unit
|
||||
"unit": config.unit,
|
||||
}
|
||||
for config in configs
|
||||
}
|
||||
|
||||
|
||||
@router.get("/plans")
|
||||
async def list_plans():
|
||||
plans = await SubscriptionPlan.filter(is_active=True).all()
|
||||
@ -361,14 +402,16 @@ async def list_plans():
|
||||
"storage_gb": plan.storage_gb,
|
||||
"bandwidth_gb": plan.bandwidth_gb,
|
||||
"price_monthly": float(plan.price_monthly),
|
||||
"price_yearly": float(plan.price_yearly) if plan.price_yearly else None
|
||||
"price_yearly": float(plan.price_yearly) if plan.price_yearly else None,
|
||||
}
|
||||
for plan in plans
|
||||
]
|
||||
|
||||
|
||||
@router.get("/stripe-key")
|
||||
async def get_stripe_key():
|
||||
from ..settings import settings
|
||||
|
||||
if not settings.STRIPE_PUBLISHABLE_KEY:
|
||||
raise HTTPException(status_code=503, detail="Payment processing not configured")
|
||||
return {"publishable_key": settings.STRIPE_PUBLISHABLE_KEY}
|
||||
|
||||
@ -1,4 +1,12 @@
|
||||
from fastapi import APIRouter, Depends, UploadFile, File as FastAPIFile, HTTPException, status, Response, Form
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
UploadFile,
|
||||
File as FastAPIFile,
|
||||
HTTPException,
|
||||
status,
|
||||
Form,
|
||||
)
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import List, Optional
|
||||
import mimetypes
|
||||
@ -11,7 +19,6 @@ from ..auth import get_current_user
|
||||
from ..models import User, File, Folder
|
||||
from ..schemas import FileOut
|
||||
from ..storage import storage_manager
|
||||
from ..settings import settings
|
||||
from ..activity import log_activity
|
||||
from ..thumbnails import generate_thumbnail, delete_thumbnail
|
||||
|
||||
@ -20,35 +27,46 @@ router = APIRouter(
|
||||
tags=["files"],
|
||||
)
|
||||
|
||||
|
||||
class FileMove(BaseModel):
|
||||
target_folder_id: Optional[int] = None
|
||||
|
||||
|
||||
class FileRename(BaseModel):
|
||||
new_name: str
|
||||
|
||||
|
||||
class FileCopy(BaseModel):
|
||||
target_folder_id: Optional[int] = None
|
||||
|
||||
|
||||
class BatchFileOperation(BaseModel):
|
||||
file_ids: List[int]
|
||||
operation: str # e.g., "delete", "star", "unstar", "move", "copy"
|
||||
operation: str # e.g., "delete", "star", "unstar", "move", "copy"
|
||||
|
||||
|
||||
class BatchMoveCopyPayload(BaseModel):
|
||||
target_folder_id: Optional[int] = None
|
||||
|
||||
|
||||
class FileContentUpdate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
@router.post("/upload", response_model=FileOut, status_code=status.HTTP_201_CREATED)
|
||||
async def upload_file(
|
||||
file: UploadFile = FastAPIFile(...),
|
||||
folder_id: Optional[int] = Form(None),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if folder_id:
|
||||
parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
|
||||
parent_folder = await Folder.get_or_none(
|
||||
id=folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not parent_folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
|
||||
)
|
||||
else:
|
||||
parent_folder = None
|
||||
|
||||
@ -73,7 +91,7 @@ async def upload_file(
|
||||
|
||||
# Generate a unique path for storage
|
||||
file_extension = os.path.splitext(file.filename)[1]
|
||||
unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename
|
||||
unique_filename = f"{file_hash}{file_extension}" # Use hash for unique filename
|
||||
storage_path = unique_filename
|
||||
|
||||
# Save file to storage
|
||||
@ -105,16 +123,20 @@ async def upload_file(
|
||||
|
||||
return await FileOut.from_tortoise_orm(db_file)
|
||||
|
||||
|
||||
@router.get("/download/{file_id}")
|
||||
async def download_file(file_id: int, current_user: User = Depends(get_current_user)):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
|
||||
db_file.last_accessed_at = datetime.now()
|
||||
await db_file.save()
|
||||
|
||||
try:
|
||||
|
||||
async def file_iterator():
|
||||
async for chunk in storage_manager.get_file(current_user.id, db_file.path):
|
||||
yield chunk
|
||||
@ -122,16 +144,21 @@ async def download_file(file_id: int, current_user: User = Depends(get_current_u
|
||||
return StreamingResponse(
|
||||
file_iterator(),
|
||||
media_type=db_file.mime_type,
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{db_file.name}\""}
|
||||
headers={"Content-Disposition": f'attachment; filename="{db_file.name}"'},
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{file_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_file(file_id: int, current_user: User = Depends(get_current_user)):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
|
||||
db_file.is_deleted = True
|
||||
db_file.deleted_at = datetime.now()
|
||||
@ -141,68 +168,110 @@ async def delete_file(file_id: int, current_user: User = Depends(get_current_use
|
||||
|
||||
return
|
||||
|
||||
|
||||
@router.post("/{file_id}/move", response_model=FileOut)
|
||||
async def move_file(file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)):
|
||||
async def move_file(
|
||||
file_id: int, move_data: FileMove, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
|
||||
target_folder = None
|
||||
if move_data.target_folder_id:
|
||||
target_folder = await Folder.get_or_none(id=move_data.target_folder_id, owner=current_user, is_deleted=False)
|
||||
target_folder = await Folder.get_or_none(
|
||||
id=move_data.target_folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not target_folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
|
||||
)
|
||||
|
||||
existing_file = await File.get_or_none(
|
||||
name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False
|
||||
)
|
||||
if existing_file and existing_file.id != file_id:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in target folder")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="File with this name already exists in target folder",
|
||||
)
|
||||
|
||||
db_file.parent = target_folder
|
||||
await db_file.save()
|
||||
|
||||
await log_activity(user=current_user, action="file_moved", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user, action="file_moved", target_type="file", target_id=file_id
|
||||
)
|
||||
|
||||
return await FileOut.from_tortoise_orm(db_file)
|
||||
|
||||
|
||||
@router.post("/{file_id}/rename", response_model=FileOut)
|
||||
async def rename_file(file_id: int, rename_data: FileRename, current_user: User = Depends(get_current_user)):
|
||||
async def rename_file(
|
||||
file_id: int,
|
||||
rename_data: FileRename,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
|
||||
existing_file = await File.get_or_none(
|
||||
name=rename_data.new_name, parent_id=db_file.parent_id, owner=current_user, is_deleted=False
|
||||
name=rename_data.new_name,
|
||||
parent_id=db_file.parent_id,
|
||||
owner=current_user,
|
||||
is_deleted=False,
|
||||
)
|
||||
if existing_file and existing_file.id != file_id:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="File with this name already exists in the same folder")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="File with this name already exists in the same folder",
|
||||
)
|
||||
|
||||
db_file.name = rename_data.new_name
|
||||
await db_file.save()
|
||||
|
||||
await log_activity(user=current_user, action="file_renamed", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user, action="file_renamed", target_type="file", target_id=file_id
|
||||
)
|
||||
|
||||
return await FileOut.from_tortoise_orm(db_file)
|
||||
|
||||
@router.post("/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED)
|
||||
async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)):
|
||||
|
||||
@router.post(
|
||||
"/{file_id}/copy", response_model=FileOut, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def copy_file(
|
||||
file_id: int, copy_data: FileCopy, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
|
||||
target_folder = None
|
||||
if copy_data.target_folder_id:
|
||||
target_folder = await Folder.get_or_none(id=copy_data.target_folder_id, owner=current_user, is_deleted=False)
|
||||
target_folder = await Folder.get_or_none(
|
||||
id=copy_data.target_folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not target_folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Target folder not found"
|
||||
)
|
||||
|
||||
base_name = db_file.name
|
||||
name_parts = os.path.splitext(base_name)
|
||||
counter = 1
|
||||
new_name = base_name
|
||||
|
||||
while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False):
|
||||
while await File.get_or_none(
|
||||
name=new_name, parent=target_folder, owner=current_user, is_deleted=False
|
||||
):
|
||||
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
|
||||
counter += 1
|
||||
|
||||
@ -213,139 +282,205 @@ async def copy_file(file_id: int, copy_data: FileCopy, current_user: User = Depe
|
||||
mime_type=db_file.mime_type,
|
||||
file_hash=db_file.file_hash,
|
||||
owner=current_user,
|
||||
parent=target_folder
|
||||
parent=target_folder,
|
||||
)
|
||||
|
||||
await log_activity(user=current_user, action="file_copied", target_type="file", target_id=new_file.id)
|
||||
await log_activity(
|
||||
user=current_user,
|
||||
action="file_copied",
|
||||
target_type="file",
|
||||
target_id=new_file.id,
|
||||
)
|
||||
|
||||
return await FileOut.from_tortoise_orm(new_file)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[FileOut])
|
||||
async def list_files(folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)):
|
||||
async def list_files(
|
||||
folder_id: Optional[int] = None, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
if folder_id:
|
||||
parent_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
|
||||
parent_folder = await Folder.get_or_none(
|
||||
id=folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not parent_folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
|
||||
files = await File.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
|
||||
)
|
||||
files = await File.filter(
|
||||
parent=parent_folder, owner=current_user, is_deleted=False
|
||||
).order_by("name")
|
||||
else:
|
||||
files = await File.filter(parent=None, owner=current_user, is_deleted=False).order_by("name")
|
||||
files = await File.filter(
|
||||
parent=None, owner=current_user, is_deleted=False
|
||||
).order_by("name")
|
||||
return [await FileOut.from_tortoise_orm(f) for f in files]
|
||||
|
||||
|
||||
@router.get("/thumbnail/{file_id}")
|
||||
async def get_thumbnail(file_id: int, current_user: User = Depends(get_current_user)):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
|
||||
db_file.last_accessed_at = datetime.now()
|
||||
await db_file.save()
|
||||
|
||||
thumbnail_path = getattr(db_file, 'thumbnail_path', None)
|
||||
thumbnail_path = getattr(db_file, "thumbnail_path", None)
|
||||
|
||||
if not thumbnail_path:
|
||||
thumbnail_path = await generate_thumbnail(db_file.path, db_file.mime_type, current_user.id)
|
||||
thumbnail_path = await generate_thumbnail(
|
||||
db_file.path, db_file.mime_type, current_user.id
|
||||
)
|
||||
|
||||
if thumbnail_path:
|
||||
db_file.thumbnail_path = thumbnail_path
|
||||
await db_file.save()
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not available"
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
async def thumbnail_iterator():
|
||||
async for chunk in storage_manager.get_file(current_user.id, thumbnail_path):
|
||||
async for chunk in storage_manager.get_file(
|
||||
current_user.id, thumbnail_path
|
||||
):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(
|
||||
thumbnail_iterator(),
|
||||
media_type="image/jpeg"
|
||||
)
|
||||
return StreamingResponse(thumbnail_iterator(), media_type="image/jpeg")
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Thumbnail not found in storage")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Thumbnail not found in storage",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/photos", response_model=List[FileOut])
|
||||
async def list_photos(current_user: User = Depends(get_current_user)):
|
||||
files = await File.filter(
|
||||
owner=current_user,
|
||||
is_deleted=False,
|
||||
mime_type__istartswith="image/"
|
||||
owner=current_user, is_deleted=False, mime_type__istartswith="image/"
|
||||
).order_by("-created_at")
|
||||
return [await FileOut.from_tortoise_orm(f) for f in files]
|
||||
|
||||
|
||||
@router.get("/recent", response_model=List[FileOut])
|
||||
async def list_recent_files(current_user: User = Depends(get_current_user), limit: int = 10):
|
||||
files = await File.filter(
|
||||
owner=current_user,
|
||||
is_deleted=False,
|
||||
last_accessed_at__isnull=False
|
||||
).order_by("-last_accessed_at").limit(limit)
|
||||
async def list_recent_files(
|
||||
current_user: User = Depends(get_current_user), limit: int = 10
|
||||
):
|
||||
files = (
|
||||
await File.filter(
|
||||
owner=current_user, is_deleted=False, last_accessed_at__isnull=False
|
||||
)
|
||||
.order_by("-last_accessed_at")
|
||||
.limit(limit)
|
||||
)
|
||||
return [await FileOut.from_tortoise_orm(f) for f in files]
|
||||
|
||||
|
||||
@router.post("/{file_id}/star", response_model=FileOut)
|
||||
async def star_file(file_id: int, current_user: User = Depends(get_current_user)):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
db_file.is_starred = True
|
||||
await db_file.save()
|
||||
await log_activity(user=current_user, action="file_starred", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user, action="file_starred", target_type="file", target_id=file_id
|
||||
)
|
||||
return await FileOut.from_tortoise_orm(db_file)
|
||||
|
||||
|
||||
@router.post("/{file_id}/unstar", response_model=FileOut)
|
||||
async def unstar_file(file_id: int, current_user: User = Depends(get_current_user)):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
db_file.is_starred = False
|
||||
await db_file.save()
|
||||
await log_activity(user=current_user, action="file_unstarred", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user,
|
||||
action="file_unstarred",
|
||||
target_type="file",
|
||||
target_id=file_id,
|
||||
)
|
||||
return await FileOut.from_tortoise_orm(db_file)
|
||||
|
||||
|
||||
@router.get("/deleted", response_model=List[FileOut])
|
||||
async def list_deleted_files(current_user: User = Depends(get_current_user)):
|
||||
files = await File.filter(owner=current_user, is_deleted=True).order_by("-deleted_at")
|
||||
files = await File.filter(owner=current_user, is_deleted=True).order_by(
|
||||
"-deleted_at"
|
||||
)
|
||||
return [await FileOut.from_tortoise_orm(f) for f in files]
|
||||
|
||||
|
||||
@router.post("/{file_id}/restore", response_model=FileOut)
|
||||
async def restore_file(file_id: int, current_user: User = Depends(get_current_user)):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=True)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Deleted file not found"
|
||||
)
|
||||
|
||||
# Check if a file with the same name exists in the parent folder
|
||||
existing_file = await File.get_or_none(
|
||||
name=db_file.name, parent=db_file.parent, owner=current_user, is_deleted=False
|
||||
)
|
||||
if existing_file:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="A file with the same name already exists in this location. Please rename the existing file or restore to a different location.",
|
||||
)
|
||||
|
||||
db_file.is_deleted = False
|
||||
db_file.deleted_at = None
|
||||
await db_file.save()
|
||||
await log_activity(user=current_user, action="file_restored", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user, action="file_restored", target_type="file", target_id=file_id
|
||||
)
|
||||
return await FileOut.from_tortoise_orm(db_file)
|
||||
|
||||
|
||||
class BatchOperationResult(BaseModel):
|
||||
succeeded: List[FileOut]
|
||||
failed: List[dict]
|
||||
|
||||
|
||||
@router.post("/batch")
|
||||
async def batch_file_operations(
|
||||
batch_operation: BatchFileOperation,
|
||||
payload: Optional[BatchMoveCopyPayload] = None,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if batch_operation.operation not in ["delete", "star", "unstar", "move", "copy"]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid operation: {batch_operation.operation}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid operation: {batch_operation.operation}",
|
||||
)
|
||||
|
||||
updated_files = []
|
||||
failed_operations = []
|
||||
|
||||
for file_id in batch_operation.file_ids:
|
||||
try:
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
db_file = await File.get_or_none(
|
||||
id=file_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not db_file:
|
||||
failed_operations.append({"file_id": file_id, "reason": "File not found or not owned by user"})
|
||||
failed_operations.append(
|
||||
{
|
||||
"file_id": file_id,
|
||||
"reason": "File not found or not owned by user",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if batch_operation.operation == "delete":
|
||||
@ -353,47 +488,87 @@ async def batch_file_operations(
|
||||
db_file.deleted_at = datetime.now()
|
||||
await db_file.save()
|
||||
await delete_thumbnail(db_file.id)
|
||||
await log_activity(user=current_user, action="file_deleted_batch", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user,
|
||||
action="file_deleted_batch",
|
||||
target_type="file",
|
||||
target_id=file_id,
|
||||
)
|
||||
updated_files.append(db_file)
|
||||
elif batch_operation.operation == "star":
|
||||
db_file.is_starred = True
|
||||
await db_file.save()
|
||||
await log_activity(user=current_user, action="file_starred_batch", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user,
|
||||
action="file_starred_batch",
|
||||
target_type="file",
|
||||
target_id=file_id,
|
||||
)
|
||||
updated_files.append(db_file)
|
||||
elif batch_operation.operation == "unstar":
|
||||
db_file.is_starred = False
|
||||
await db_file.save()
|
||||
await log_activity(user=current_user, action="file_unstarred_batch", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user,
|
||||
action="file_unstarred_batch",
|
||||
target_type="file",
|
||||
target_id=file_id,
|
||||
)
|
||||
updated_files.append(db_file)
|
||||
elif batch_operation.operation == "move":
|
||||
if not payload or payload.target_folder_id is None:
|
||||
failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"})
|
||||
failed_operations.append(
|
||||
{"file_id": file_id, "reason": "Target folder not specified"}
|
||||
)
|
||||
continue
|
||||
|
||||
target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
|
||||
target_folder = await Folder.get_or_none(
|
||||
id=payload.target_folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not target_folder:
|
||||
failed_operations.append({"file_id": file_id, "reason": "Target folder not found"})
|
||||
failed_operations.append(
|
||||
{"file_id": file_id, "reason": "Target folder not found"}
|
||||
)
|
||||
continue
|
||||
|
||||
existing_file = await File.get_or_none(
|
||||
name=db_file.name, parent=target_folder, owner=current_user, is_deleted=False
|
||||
name=db_file.name,
|
||||
parent=target_folder,
|
||||
owner=current_user,
|
||||
is_deleted=False,
|
||||
)
|
||||
if existing_file and existing_file.id != file_id:
|
||||
failed_operations.append({"file_id": file_id, "reason": "File with same name exists in target folder"})
|
||||
failed_operations.append(
|
||||
{
|
||||
"file_id": file_id,
|
||||
"reason": "File with same name exists in target folder",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
db_file.parent = target_folder
|
||||
await db_file.save()
|
||||
await log_activity(user=current_user, action="file_moved_batch", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user,
|
||||
action="file_moved_batch",
|
||||
target_type="file",
|
||||
target_id=file_id,
|
||||
)
|
||||
updated_files.append(db_file)
|
||||
elif batch_operation.operation == "copy":
|
||||
if not payload or payload.target_folder_id is None:
|
||||
failed_operations.append({"file_id": file_id, "reason": "Target folder not specified"})
|
||||
failed_operations.append(
|
||||
{"file_id": file_id, "reason": "Target folder not specified"}
|
||||
)
|
||||
continue
|
||||
|
||||
target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
|
||||
target_folder = await Folder.get_or_none(
|
||||
id=payload.target_folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not target_folder:
|
||||
failed_operations.append({"file_id": file_id, "reason": "Target folder not found"})
|
||||
failed_operations.append(
|
||||
{"file_id": file_id, "reason": "Target folder not found"}
|
||||
)
|
||||
continue
|
||||
|
||||
base_name = db_file.name
|
||||
@ -401,7 +576,12 @@ async def batch_file_operations(
|
||||
counter = 1
|
||||
new_name = base_name
|
||||
|
||||
while await File.get_or_none(name=new_name, parent=target_folder, owner=current_user, is_deleted=False):
|
||||
while await File.get_or_none(
|
||||
name=new_name,
|
||||
parent=target_folder,
|
||||
owner=current_user,
|
||||
is_deleted=False,
|
||||
):
|
||||
new_name = f"{name_parts[0]} (copy {counter}){name_parts[1]}"
|
||||
counter += 1
|
||||
|
||||
@ -412,48 +592,70 @@ async def batch_file_operations(
|
||||
mime_type=db_file.mime_type,
|
||||
file_hash=db_file.file_hash,
|
||||
owner=current_user,
|
||||
parent=target_folder
|
||||
parent=target_folder,
|
||||
)
|
||||
await log_activity(
|
||||
user=current_user,
|
||||
action="file_copied_batch",
|
||||
target_type="file",
|
||||
target_id=new_file.id,
|
||||
)
|
||||
await log_activity(user=current_user, action="file_copied_batch", target_type="file", target_id=new_file.id)
|
||||
updated_files.append(new_file)
|
||||
except Exception as e:
|
||||
failed_operations.append({"file_id": file_id, "reason": str(e)})
|
||||
|
||||
return {
|
||||
"succeeded": [await FileOut.from_tortoise_orm(f) for f in updated_files],
|
||||
"failed": failed_operations
|
||||
"failed": failed_operations,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{file_id}/content", response_model=FileOut)
|
||||
async def update_file_content(
|
||||
file_id: int,
|
||||
payload: FileContentUpdate,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
db_file = await File.get_or_none(id=file_id, owner=current_user, is_deleted=False)
|
||||
if not db_file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
|
||||
if not db_file.mime_type or not db_file.mime_type.startswith('text/'):
|
||||
if not db_file.mime_type or not db_file.mime_type.startswith("text/"):
|
||||
editableExtensions = [
|
||||
'txt', 'md', 'log', 'json', 'js', 'py', 'html', 'css',
|
||||
'xml', 'yaml', 'yml', 'sh', 'bat', 'ini', 'conf', 'cfg'
|
||||
"txt",
|
||||
"md",
|
||||
"log",
|
||||
"json",
|
||||
"js",
|
||||
"py",
|
||||
"html",
|
||||
"css",
|
||||
"xml",
|
||||
"yaml",
|
||||
"yml",
|
||||
"sh",
|
||||
"bat",
|
||||
"ini",
|
||||
"conf",
|
||||
"cfg",
|
||||
]
|
||||
file_extension = os.path.splitext(db_file.name)[1][1:].lower()
|
||||
if file_extension not in editableExtensions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="File type is not editable"
|
||||
detail="File type is not editable",
|
||||
)
|
||||
|
||||
content_bytes = payload.content.encode('utf-8')
|
||||
content_bytes = payload.content.encode("utf-8")
|
||||
new_size = len(content_bytes)
|
||||
size_diff = new_size - db_file.size
|
||||
|
||||
if current_user.used_storage_bytes + size_diff > current_user.storage_quota_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_507_INSUFFICIENT_STORAGE,
|
||||
detail="Storage quota exceeded"
|
||||
detail="Storage quota exceeded",
|
||||
)
|
||||
|
||||
new_hash = hashlib.sha256(content_bytes).hexdigest()
|
||||
@ -465,7 +667,7 @@ async def update_file_content(
|
||||
if new_storage_path != db_file.path:
|
||||
try:
|
||||
await storage_manager.delete_file(current_user.id, db_file.path)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
db_file.path = new_storage_path
|
||||
@ -477,6 +679,8 @@ async def update_file_content(
|
||||
current_user.used_storage_bytes += size_diff
|
||||
await current_user.save()
|
||||
|
||||
await log_activity(user=current_user, action="file_updated", target_type="file", target_id=file_id)
|
||||
await log_activity(
|
||||
user=current_user, action="file_updated", target_type="file", target_id=file_id
|
||||
)
|
||||
|
||||
return await FileOut.from_tortoise_orm(db_file)
|
||||
|
||||
@ -3,7 +3,13 @@ from typing import List, Optional
|
||||
|
||||
from ..auth import get_current_user
|
||||
from ..models import User, Folder
|
||||
from ..schemas import FolderCreate, FolderOut, FolderUpdate, BatchFolderOperation, BatchMoveCopyPayload
|
||||
from ..schemas import (
|
||||
FolderCreate,
|
||||
FolderOut,
|
||||
FolderUpdate,
|
||||
BatchFolderOperation,
|
||||
BatchMoveCopyPayload,
|
||||
)
|
||||
from ..activity import log_activity
|
||||
|
||||
router = APIRouter(
|
||||
@ -11,12 +17,17 @@ router = APIRouter(
|
||||
tags=["folders"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=FolderOut, status_code=status.HTTP_201_CREATED)
|
||||
async def create_folder(folder_in: FolderCreate, current_user: User = Depends(get_current_user)):
|
||||
async def create_folder(
|
||||
folder_in: FolderCreate, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# Check if parent folder exists and belongs to the current user
|
||||
parent_folder = None
|
||||
if folder_in.parent_id:
|
||||
parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user)
|
||||
parent_folder = await Folder.get_or_none(
|
||||
id=folder_in.parent_id, owner=current_user
|
||||
)
|
||||
if not parent_folder:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@ -39,50 +50,82 @@ async def create_folder(folder_in: FolderCreate, current_user: User = Depends(ge
|
||||
await log_activity(current_user, "folder_created", "folder", folder.id)
|
||||
return await FolderOut.from_tortoise_orm(folder)
|
||||
|
||||
|
||||
@router.get("/{folder_id}/path", response_model=List[FolderOut])
|
||||
async def get_folder_path(folder_id: int, current_user: User = Depends(get_current_user)):
|
||||
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
|
||||
async def get_folder_path(
|
||||
folder_id: int, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
folder = await Folder.get_or_none(
|
||||
id=folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
|
||||
)
|
||||
|
||||
path = []
|
||||
current = folder
|
||||
while current:
|
||||
path.insert(0, await FolderOut.from_tortoise_orm(current))
|
||||
if current.parent_id:
|
||||
current = await Folder.get_or_none(id=current.parent_id, owner=current_user, is_deleted=False)
|
||||
current = await Folder.get_or_none(
|
||||
id=current.parent_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
else:
|
||||
current = None
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@router.get("/{folder_id}", response_model=FolderOut)
|
||||
async def get_folder(folder_id: int, current_user: User = Depends(get_current_user)):
|
||||
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
|
||||
folder = await Folder.get_or_none(
|
||||
id=folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
|
||||
)
|
||||
return await FolderOut.from_tortoise_orm(folder)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[FolderOut])
|
||||
async def list_folders(parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)):
|
||||
async def list_folders(
|
||||
parent_id: Optional[int] = None, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
if parent_id:
|
||||
parent_folder = await Folder.get_or_none(id=parent_id, owner=current_user, is_deleted=False)
|
||||
parent_folder = await Folder.get_or_none(
|
||||
id=parent_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not parent_folder:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Parent folder not found or does not belong to the current user",
|
||||
)
|
||||
folders = await Folder.filter(parent=parent_folder, owner=current_user, is_deleted=False).order_by("name")
|
||||
folders = await Folder.filter(
|
||||
parent=parent_folder, owner=current_user, is_deleted=False
|
||||
).order_by("name")
|
||||
else:
|
||||
# List root folders (folders with no parent)
|
||||
folders = await Folder.filter(parent=None, owner=current_user, is_deleted=False).order_by("name")
|
||||
folders = await Folder.filter(
|
||||
parent=None, owner=current_user, is_deleted=False
|
||||
).order_by("name")
|
||||
return [await FolderOut.from_tortoise_orm(folder) for folder in folders]
|
||||
|
||||
|
||||
@router.put("/{folder_id}", response_model=FolderOut)
|
||||
async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: User = Depends(get_current_user)):
|
||||
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
|
||||
async def update_folder(
|
||||
folder_id: int,
|
||||
folder_in: FolderUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
folder = await Folder.get_or_none(
|
||||
id=folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
|
||||
)
|
||||
|
||||
if folder_in.name:
|
||||
existing_folder = await Folder.get_or_none(
|
||||
@ -97,11 +140,16 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U
|
||||
|
||||
if folder_in.parent_id is not None:
|
||||
if folder_in.parent_id == folder_id:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot set folder as its own parent")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot set folder as its own parent",
|
||||
)
|
||||
|
||||
new_parent_folder = None
|
||||
if folder_in.parent_id != 0: # 0 could represent moving to root
|
||||
new_parent_folder = await Folder.get_or_none(id=folder_in.parent_id, owner=current_user, is_deleted=False)
|
||||
if folder_in.parent_id != 0: # 0 could represent moving to root
|
||||
new_parent_folder = await Folder.get_or_none(
|
||||
id=folder_in.parent_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not new_parent_folder:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@ -113,48 +161,66 @@ async def update_folder(folder_id: int, folder_in: FolderUpdate, current_user: U
|
||||
await log_activity(current_user, "folder_updated", "folder", folder.id)
|
||||
return await FolderOut.from_tortoise_orm(folder)
|
||||
|
||||
|
||||
@router.delete("/{folder_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_folder(folder_id: int, current_user: User = Depends(get_current_user)):
|
||||
folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
|
||||
folder = await Folder.get_or_none(
|
||||
id=folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
|
||||
)
|
||||
|
||||
folder.is_deleted = True
|
||||
await folder.save()
|
||||
await log_activity(current_user, "folder_deleted", "folder", folder.id)
|
||||
return
|
||||
|
||||
|
||||
@router.post("/{folder_id}/star", response_model=FolderOut)
|
||||
async def star_folder(folder_id: int, current_user: User = Depends(get_current_user)):
|
||||
db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
|
||||
db_folder = await Folder.get_or_none(
|
||||
id=folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not db_folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
|
||||
)
|
||||
db_folder.is_starred = True
|
||||
await db_folder.save()
|
||||
await log_activity(current_user, "folder_starred", "folder", folder_id)
|
||||
return await FolderOut.from_tortoise_orm(db_folder)
|
||||
|
||||
|
||||
@router.post("/{folder_id}/unstar", response_model=FolderOut)
|
||||
async def unstar_folder(folder_id: int, current_user: User = Depends(get_current_user)):
|
||||
db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
|
||||
db_folder = await Folder.get_or_none(
|
||||
id=folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not db_folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
|
||||
)
|
||||
db_folder.is_starred = False
|
||||
await db_folder.save()
|
||||
await log_activity(current_user, "folder_unstarred", "folder", folder_id)
|
||||
return await FolderOut.from_tortoise_orm(db_folder)
|
||||
|
||||
|
||||
@router.post("/batch", response_model=List[FolderOut])
|
||||
async def batch_folder_operations(
|
||||
batch_operation: BatchFolderOperation,
|
||||
payload: Optional[BatchMoveCopyPayload] = None,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
updated_folders = []
|
||||
for folder_id in batch_operation.folder_ids:
|
||||
db_folder = await Folder.get_or_none(id=folder_id, owner=current_user, is_deleted=False)
|
||||
db_folder = await Folder.get_or_none(
|
||||
id=folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not db_folder:
|
||||
continue # Skip if folder not found or not owned by user
|
||||
continue # Skip if folder not found or not owned by user
|
||||
|
||||
if batch_operation.operation == "delete":
|
||||
db_folder.is_deleted = True
|
||||
@ -171,13 +237,22 @@ async def batch_folder_operations(
|
||||
await db_folder.save()
|
||||
await log_activity(current_user, "folder_unstarred", "folder", folder_id)
|
||||
updated_folders.append(db_folder)
|
||||
elif batch_operation.operation == "move" and payload and payload.target_folder_id is not None:
|
||||
target_folder = await Folder.get_or_none(id=payload.target_folder_id, owner=current_user, is_deleted=False)
|
||||
elif (
|
||||
batch_operation.operation == "move"
|
||||
and payload
|
||||
and payload.target_folder_id is not None
|
||||
):
|
||||
target_folder = await Folder.get_or_none(
|
||||
id=payload.target_folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not target_folder:
|
||||
continue
|
||||
|
||||
existing_folder = await Folder.get_or_none(
|
||||
name=db_folder.name, parent=target_folder, owner=current_user, is_deleted=False
|
||||
name=db_folder.name,
|
||||
parent=target_folder,
|
||||
owner=current_user,
|
||||
is_deleted=False,
|
||||
)
|
||||
if existing_folder and existing_folder.id != folder_id:
|
||||
continue
|
||||
|
||||
@ -11,15 +11,22 @@ router = APIRouter(
|
||||
tags=["search"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files", response_model=List[FileOut])
|
||||
async def search_files(
|
||||
q: str = Query(..., min_length=1, description="Search query"),
|
||||
file_type: Optional[str] = Query(None, description="Filter by MIME type prefix (e.g., 'image', 'video')"),
|
||||
file_type: Optional[str] = Query(
|
||||
None, description="Filter by MIME type prefix (e.g., 'image', 'video')"
|
||||
),
|
||||
min_size: Optional[int] = Query(None, description="Minimum file size in bytes"),
|
||||
max_size: Optional[int] = Query(None, description="Maximum file size in bytes"),
|
||||
date_from: Optional[datetime] = Query(None, description="Filter files created after this date"),
|
||||
date_to: Optional[datetime] = Query(None, description="Filter files created before this date"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
date_from: Optional[datetime] = Query(
|
||||
None, description="Filter files created after this date"
|
||||
),
|
||||
date_to: Optional[datetime] = Query(
|
||||
None, description="Filter files created before this date"
|
||||
),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
query = File.filter(owner=current_user, is_deleted=False, name__icontains=q)
|
||||
|
||||
@ -41,15 +48,16 @@ async def search_files(
|
||||
files = await query.order_by("-created_at").limit(100)
|
||||
return [await FileOut.from_tortoise_orm(f) for f in files]
|
||||
|
||||
|
||||
@router.get("/folders", response_model=List[FolderOut])
|
||||
async def search_folders(
|
||||
q: str = Query(..., min_length=1, description="Search query"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
folders = await Folder.filter(
|
||||
owner=current_user,
|
||||
is_deleted=False,
|
||||
name__icontains=q
|
||||
).order_by("-created_at").limit(100)
|
||||
folders = (
|
||||
await Folder.filter(owner=current_user, is_deleted=False, name__icontains=q)
|
||||
.order_by("-created_at")
|
||||
.limit(100)
|
||||
)
|
||||
|
||||
return [await FolderOut.from_tortoise_orm(folder) for folder in folders]
|
||||
|
||||
@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Optional, List
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
|
||||
from ..auth import get_current_user
|
||||
from ..models import User, File, Folder, Share
|
||||
@ -17,25 +17,44 @@ router = APIRouter(
|
||||
tags=["shares"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=ShareOut, status_code=status.HTTP_201_CREATED)
|
||||
async def create_share_link(share_in: ShareCreate, current_user: User = Depends(get_current_user)):
|
||||
async def create_share_link(
|
||||
share_in: ShareCreate, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
if not share_in.file_id and not share_in.folder_id:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Either file_id or folder_id must be provided")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Either file_id or folder_id must be provided",
|
||||
)
|
||||
if share_in.file_id and share_in.folder_id:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot share both a file and a folder simultaneously")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot share both a file and a folder simultaneously",
|
||||
)
|
||||
|
||||
file = None
|
||||
folder = None
|
||||
|
||||
if share_in.file_id:
|
||||
file = await File.get_or_none(id=share_in.file_id, owner=current_user, is_deleted=False)
|
||||
file = await File.get_or_none(
|
||||
id=share_in.file_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found or does not belong to you")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found or does not belong to you",
|
||||
)
|
||||
|
||||
if share_in.folder_id:
|
||||
folder = await Folder.get_or_none(id=share_in.folder_id, owner=current_user, is_deleted=False)
|
||||
folder = await Folder.get_or_none(
|
||||
id=share_in.folder_id, owner=current_user, is_deleted=False
|
||||
)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found or does not belong to you")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Folder not found or does not belong to you",
|
||||
)
|
||||
|
||||
token = secrets.token_urlsafe(16)
|
||||
hashed_password = None
|
||||
@ -60,8 +79,14 @@ async def create_share_link(share_in: ShareCreate, current_user: User = Depends(
|
||||
item_type = "file" if file else "folder"
|
||||
item_name = file.name if file else folder.name
|
||||
|
||||
expiry_text = f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}" if share_in.expires_at else ""
|
||||
password_text = f"\n\nPassword: {share_in.password}" if share_in.password else ""
|
||||
expiry_text = (
|
||||
f" until {share_in.expires_at.strftime('%Y-%m-%d %H:%M')}"
|
||||
if share_in.expires_at
|
||||
else ""
|
||||
)
|
||||
password_text = (
|
||||
f"\n\nPassword: {share_in.password}" if share_in.password else ""
|
||||
)
|
||||
|
||||
email_body = f"""Hello,
|
||||
|
||||
@ -95,26 +120,32 @@ MyWebdav File Sharing Service"""
|
||||
to_email=share_in.invite_email,
|
||||
subject=f"{current_user.username} shared {item_name} with you",
|
||||
body=email_body,
|
||||
html=email_html
|
||||
html=email_html,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to send invitation email: {e}")
|
||||
|
||||
return await ShareOut.from_tortoise_orm(share)
|
||||
|
||||
|
||||
@router.get("/my", response_model=List[ShareOut])
|
||||
async def list_my_shares(current_user: User = Depends(get_current_user)):
|
||||
shares = await Share.filter(owner=current_user).order_by("-created_at")
|
||||
return [await ShareOut.from_tortoise_orm(share) for share in shares]
|
||||
|
||||
|
||||
@router.get("/{share_token}", response_model=ShareOut)
|
||||
async def get_share_link_info(share_token: str):
|
||||
share = await Share.get_or_none(token=share_token)
|
||||
if not share:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
|
||||
)
|
||||
|
||||
if share.expires_at and share.expires_at < datetime.utcnow():
|
||||
raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_410_GONE, detail="Share link has expired"
|
||||
)
|
||||
|
||||
# Increment access count
|
||||
share.access_count += 1
|
||||
@ -122,11 +153,17 @@ async def get_share_link_info(share_token: str):
|
||||
|
||||
return await ShareOut.from_tortoise_orm(share)
|
||||
|
||||
|
||||
@router.put("/{share_id}", response_model=ShareOut)
|
||||
async def update_share(share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)):
|
||||
async def update_share(
|
||||
share_id: int, share_in: ShareCreate, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
share = await Share.get_or_none(id=share_id, owner=current_user)
|
||||
if not share:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Share link not found or does not belong to you",
|
||||
)
|
||||
|
||||
if share_in.expires_at is not None:
|
||||
share.expires_at = share_in.expires_at
|
||||
@ -134,7 +171,7 @@ async def update_share(share_id: int, share_in: ShareCreate, current_user: User
|
||||
if share_in.password is not None:
|
||||
share.hashed_password = get_password_hash(share_in.password)
|
||||
share.password_protected = True
|
||||
elif share_in.password == "": # Allow clearing password
|
||||
elif share_in.password == "": # Allow clearing password
|
||||
share.hashed_password = None
|
||||
share.password_protected = False
|
||||
|
||||
@ -144,31 +181,42 @@ async def update_share(share_id: int, share_in: ShareCreate, current_user: User
|
||||
await share.save()
|
||||
return await ShareOut.from_tortoise_orm(share)
|
||||
|
||||
|
||||
@router.post("/{share_token}/access")
|
||||
async def access_shared_content(share_token: str, password: Optional[str] = None):
|
||||
share = await Share.get_or_none(token=share_token)
|
||||
if not share:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
|
||||
)
|
||||
|
||||
if share.expires_at and share.expires_at < datetime.utcnow():
|
||||
raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_410_GONE, detail="Share link has expired"
|
||||
)
|
||||
|
||||
if share.password_protected:
|
||||
if not password or not verify_password(password, share.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
|
||||
)
|
||||
|
||||
result = {"message": "Access granted", "permission_level": share.permission_level}
|
||||
|
||||
if share.file_id:
|
||||
file = await File.get_or_none(id=share.file_id, is_deleted=False)
|
||||
if not file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
result["file"] = await FileOut.from_tortoise_orm(file)
|
||||
result["type"] = "file"
|
||||
elif share.folder_id:
|
||||
folder = await Folder.get_or_none(id=share.folder_id, is_deleted=False)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Folder not found"
|
||||
)
|
||||
result["folder"] = await FolderOut.from_tortoise_orm(folder)
|
||||
result["type"] = "folder"
|
||||
|
||||
@ -179,29 +227,42 @@ async def access_shared_content(share_token: str, password: Optional[str] = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{share_token}/download")
|
||||
async def download_shared_file(share_token: str, password: Optional[str] = None):
|
||||
share = await Share.get_or_none(token=share_token)
|
||||
if not share:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found"
|
||||
)
|
||||
|
||||
if share.expires_at and share.expires_at < datetime.utcnow():
|
||||
raise HTTPException(status_code=status.HTTP_410_GONE, detail="Share link has expired")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_410_GONE, detail="Share link has expired"
|
||||
)
|
||||
|
||||
if share.password_protected:
|
||||
if not password or not verify_password(password, share.hashed_password):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect password"
|
||||
)
|
||||
|
||||
if not share.file_id:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="This share is not for a file")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="This share is not for a file",
|
||||
)
|
||||
|
||||
file = await File.get_or_none(id=share.file_id, is_deleted=False)
|
||||
if not file:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
|
||||
)
|
||||
|
||||
owner = await User.get(id=file.owner_id)
|
||||
|
||||
try:
|
||||
|
||||
async def file_iterator():
|
||||
async for chunk in storage_manager.get_file(owner.id, file.path):
|
||||
yield chunk
|
||||
@ -209,18 +270,22 @@ async def download_shared_file(share_token: str, password: Optional[str] = None)
|
||||
return StreamingResponse(
|
||||
content=file_iterator(),
|
||||
media_type=file.mime_type,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{file.name}"'
|
||||
}
|
||||
headers={"Content-Disposition": f'attachment; filename="{file.name}"'},
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="File not found in storage")
|
||||
|
||||
|
||||
@router.delete("/{share_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_share_link(share_id: int, current_user: User = Depends(get_current_user)):
|
||||
async def delete_share_link(
|
||||
share_id: int, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
share = await Share.get_or_none(id=share_id, owner=current_user)
|
||||
if not share:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Share link not found or does not belong to you")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Share link not found or does not belong to you",
|
||||
)
|
||||
|
||||
await share.delete()
|
||||
return
|
||||
|
||||
@ -11,18 +11,29 @@ router = APIRouter(
|
||||
tags=["starred"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files", response_model=List[FileOut])
|
||||
async def list_starred_files(current_user: User = Depends(get_current_user)):
|
||||
files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
|
||||
files = await File.filter(
|
||||
owner=current_user, is_starred=True, is_deleted=False
|
||||
).order_by("name")
|
||||
return [await FileOut.from_tortoise_orm(f) for f in files]
|
||||
|
||||
|
||||
@router.get("/folders", response_model=List[FolderOut])
|
||||
async def list_starred_folders(current_user: User = Depends(get_current_user)):
|
||||
folders = await Folder.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
|
||||
folders = await Folder.filter(
|
||||
owner=current_user, is_starred=True, is_deleted=False
|
||||
).order_by("name")
|
||||
return [await FolderOut.from_tortoise_orm(f) for f in folders]
|
||||
|
||||
@router.get("/all", response_model=List[FileOut]) # This will return files and folders as files for now
|
||||
|
||||
@router.get(
|
||||
"/all", response_model=List[FileOut]
|
||||
) # This will return files and folders as files for now
|
||||
async def list_all_starred(current_user: User = Depends(get_current_user)):
|
||||
starred_files = await File.filter(owner=current_user, is_starred=True, is_deleted=False).order_by("name")
|
||||
starred_files = await File.filter(
|
||||
owner=current_user, is_starred=True, is_deleted=False
|
||||
).order_by("name")
|
||||
# For simplicity, we'll return files only for now. A more complex solution would involve a union or a custom schema.
|
||||
return [await FileOut.from_tortoise_orm(f) for f in starred_files]
|
||||
@ -1,17 +1,19 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from ..auth import get_current_user
|
||||
from ..models import User_Pydantic, User, File, Folder
|
||||
from typing import List, Dict, Any
|
||||
from typing import Dict, Any
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=User_Pydantic)
|
||||
async def read_users_me(current_user: User = Depends(get_current_user)):
|
||||
return await User_Pydantic.from_tortoise_orm(current_user)
|
||||
|
||||
|
||||
@router.get("/me/export", response_model=Dict[str, Any])
|
||||
async def export_my_data(current_user: User = Depends(get_current_user)):
|
||||
"""
|
||||
@ -35,6 +37,7 @@ async def export_my_data(current_user: User = Depends(get_current_user)):
|
||||
# share information, etc., would also be included.
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/me", status_code=204)
|
||||
async def delete_my_account(current_user: User = Depends(get_current_user)):
|
||||
"""
|
||||
@ -48,5 +51,3 @@ async def delete_my_account(current_user: User = Depends(get_current_user)):
|
||||
# Finally, delete the user account
|
||||
await current_user.delete()
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@ -2,19 +2,24 @@ from datetime import datetime
|
||||
from pydantic import BaseModel, EmailStr, ConfigDict
|
||||
from typing import Optional, List
|
||||
from tortoise.contrib.pydantic import pydantic_model_creator
|
||||
from mywebdav.models import Folder, File, Share, FileVersion
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class UserLoginWith2FA(UserLogin):
|
||||
two_factor_code: Optional[str] = None
|
||||
|
||||
|
||||
class UserAdminUpdate(BaseModel):
|
||||
username: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
@ -25,22 +30,27 @@ class UserAdminUpdate(BaseModel):
|
||||
plan_type: Optional[str] = None
|
||||
is_2fa_enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
two_factor_verified: bool = False
|
||||
|
||||
|
||||
class FolderCreate(BaseModel):
|
||||
name: str
|
||||
parent_id: Optional[int] = None
|
||||
|
||||
|
||||
class FolderUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
parent_id: Optional[int] = None
|
||||
|
||||
|
||||
class ShareCreate(BaseModel):
|
||||
file_id: Optional[int] = None
|
||||
folder_id: Optional[int] = None
|
||||
@ -49,9 +59,11 @@ class ShareCreate(BaseModel):
|
||||
permission_level: str = "viewer"
|
||||
invite_email: Optional[EmailStr] = None
|
||||
|
||||
|
||||
class TeamCreate(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class TeamOut(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
@ -60,6 +72,7 @@ class TeamOut(BaseModel):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ActivityOut(BaseModel):
|
||||
id: int
|
||||
user_id: Optional[int] = None
|
||||
@ -71,12 +84,14 @@ class ActivityOut(BaseModel):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class FileRequestCreate(BaseModel):
|
||||
title: str
|
||||
description: Optional[str] = None
|
||||
target_folder_id: int
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class FileRequestOut(BaseModel):
|
||||
id: int
|
||||
title: str
|
||||
@ -90,25 +105,28 @@ class FileRequestOut(BaseModel):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
from mywebdav.models import Folder, File, Share, FileVersion
|
||||
|
||||
FolderOut = pydantic_model_creator(Folder, name="FolderOut")
|
||||
FileOut = pydantic_model_creator(File, name="FileOut")
|
||||
ShareOut = pydantic_model_creator(Share, name="ShareOut")
|
||||
FileVersionOut = pydantic_model_creator(FileVersion, name="FileVersionOut")
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
code: int
|
||||
message: str
|
||||
details: Optional[str] = None
|
||||
|
||||
|
||||
class BatchFileOperation(BaseModel):
|
||||
file_ids: List[int]
|
||||
operation: str # e.g., "delete", "move", "copy", "star", "unstar"
|
||||
operation: str # e.g., "delete", "move", "copy", "star", "unstar"
|
||||
|
||||
|
||||
class BatchFolderOperation(BaseModel):
|
||||
folder_ids: List[int]
|
||||
operation: str # e.g., "delete", "move", "star", "unstar"
|
||||
operation: str # e.g., "delete", "move", "star", "unstar"
|
||||
|
||||
|
||||
class BatchMoveCopyPayload(BaseModel):
|
||||
target_folder_id: Optional[int] = None
|
||||
|
||||
@ -2,8 +2,9 @@ import os
|
||||
import sys
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file='.env', extra='ignore')
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
||||
|
||||
DATABASE_URL: str = "sqlite:///app/mywebdav.db"
|
||||
REDIS_URL: str = "redis://redis:6379/0"
|
||||
@ -30,8 +31,14 @@ class Settings(BaseSettings):
|
||||
STRIPE_WEBHOOK_SECRET: str = ""
|
||||
BILLING_ENABLED: bool = False
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
if settings.SECRET_KEY == "super_secret_key" and os.getenv("ENVIRONMENT") == "production":
|
||||
print("ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable.")
|
||||
if (
|
||||
settings.SECRET_KEY == "super_secret_key"
|
||||
and os.getenv("ENVIRONMENT") == "production"
|
||||
):
|
||||
print(
|
||||
"ERROR: Secret key must be changed in production. Set SECRET_KEY environment variable."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import AsyncGenerator
|
||||
|
||||
from .settings import settings
|
||||
|
||||
|
||||
class StorageManager:
|
||||
def __init__(self, base_path: str = settings.STORAGE_PATH):
|
||||
self.base_path = Path(base_path)
|
||||
@ -12,7 +13,11 @@ class StorageManager:
|
||||
|
||||
async def _get_full_path(self, user_id: int, file_path: str) -> Path:
|
||||
# Ensure file_path is relative and safe
|
||||
relative_path = Path(file_path).relative_to('/') if str(file_path).startswith('/') else Path(file_path)
|
||||
relative_path = (
|
||||
Path(file_path).relative_to("/")
|
||||
if str(file_path).startswith("/")
|
||||
else Path(file_path)
|
||||
)
|
||||
full_path = self.base_path / str(user_id) / relative_path
|
||||
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return full_path
|
||||
@ -52,4 +57,5 @@ class StorageManager:
|
||||
full_path = await self._get_full_path(user_id, file_path)
|
||||
return full_path.exists()
|
||||
|
||||
|
||||
storage_manager = StorageManager()
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
@ -9,7 +8,10 @@ from .settings import settings
|
||||
THUMBNAIL_SIZE = (300, 300)
|
||||
THUMBNAIL_DIR = "thumbnails"
|
||||
|
||||
async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Optional[str]:
|
||||
|
||||
async def generate_thumbnail(
|
||||
file_path: str, mime_type: str, user_id: int
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
if mime_type.startswith("image/"):
|
||||
return await generate_image_thumbnail(file_path, user_id)
|
||||
@ -20,6 +22,7 @@ async def generate_thumbnail(file_path: str, mime_type: str, user_id: int) -> Op
|
||||
print(f"Error generating thumbnail for {file_path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str]:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@ -30,12 +33,16 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
|
||||
|
||||
file_name = Path(file_path).name
|
||||
thumbnail_name = f"thumb_{file_name}"
|
||||
if not thumbnail_name.lower().endswith(('.jpg', '.jpeg', '.png')):
|
||||
if not thumbnail_name.lower().endswith((".jpg", ".jpeg", ".png")):
|
||||
thumbnail_name += ".jpg"
|
||||
|
||||
thumbnail_path = thumbnail_dir / thumbnail_name
|
||||
|
||||
actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path)
|
||||
actual_file_path = (
|
||||
base_path / str(user_id) / file_path
|
||||
if not Path(file_path).is_absolute()
|
||||
else Path(file_path)
|
||||
)
|
||||
|
||||
with Image.open(actual_file_path) as img:
|
||||
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
||||
@ -44,7 +51,9 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
|
||||
background = Image.new("RGB", img.size, (255, 255, 255))
|
||||
if img.mode == "P":
|
||||
img = img.convert("RGBA")
|
||||
background.paste(img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None)
|
||||
background.paste(
|
||||
img, mask=img.split()[-1] if img.mode in ("RGBA", "LA") else None
|
||||
)
|
||||
img = background
|
||||
|
||||
img.save(str(thumbnail_path), "JPEG", quality=85, optimize=True)
|
||||
@ -53,6 +62,7 @@ async def generate_image_thumbnail(file_path: str, user_id: int) -> Optional[str
|
||||
|
||||
return await loop.run_in_executor(None, _generate)
|
||||
|
||||
|
||||
async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str]:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@ -65,17 +75,29 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
|
||||
thumbnail_name = f"thumb_{file_name}.jpg"
|
||||
thumbnail_path = thumbnail_dir / thumbnail_name
|
||||
|
||||
actual_file_path = base_path / str(user_id) / file_path if not Path(file_path).is_absolute() else Path(file_path)
|
||||
actual_file_path = (
|
||||
base_path / str(user_id) / file_path
|
||||
if not Path(file_path).is_absolute()
|
||||
else Path(file_path)
|
||||
)
|
||||
|
||||
subprocess.run([
|
||||
"ffmpeg",
|
||||
"-i", str(actual_file_path),
|
||||
"-ss", "00:00:01",
|
||||
"-vframes", "1",
|
||||
"-vf", f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease",
|
||||
"-y",
|
||||
str(thumbnail_path)
|
||||
], check=True, capture_output=True)
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-i",
|
||||
str(actual_file_path),
|
||||
"-ss",
|
||||
"00:00:01",
|
||||
"-vframes",
|
||||
"1",
|
||||
"-vf",
|
||||
f"scale={THUMBNAIL_SIZE[0]}:{THUMBNAIL_SIZE[1]}:force_original_aspect_ratio=decrease",
|
||||
"-y",
|
||||
str(thumbnail_path),
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
return str(thumbnail_path.relative_to(base_path / str(user_id)))
|
||||
|
||||
@ -84,6 +106,7 @@ async def generate_video_thumbnail(file_path: str, user_id: int) -> Optional[str
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
|
||||
|
||||
async def delete_thumbnail(thumbnail_path: str, user_id: int):
|
||||
try:
|
||||
base_path = Path(settings.STORAGE_PATH)
|
||||
|
||||
@ -5,15 +5,20 @@ import base64
|
||||
import secrets
|
||||
import hashlib
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
|
||||
def generate_totp_secret() -> str:
|
||||
"""Generates a random base32 TOTP secret."""
|
||||
return pyotp.random_base32()
|
||||
|
||||
|
||||
def generate_totp_uri(secret: str, account_name: str, issuer_name: str) -> str:
|
||||
"""Generates a Google Authenticator-compatible TOTP URI."""
|
||||
return pyotp.totp.TOTP(secret).provisioning_uri(name=account_name, issuer_name=issuer_name)
|
||||
return pyotp.totp.TOTP(secret).provisioning_uri(
|
||||
name=account_name, issuer_name=issuer_name
|
||||
)
|
||||
|
||||
|
||||
def generate_qr_code_base64(uri: str) -> str:
|
||||
"""Generates a base64 encoded QR code image for a given URI."""
|
||||
@ -31,27 +36,33 @@ def generate_qr_code_base64(uri: str) -> str:
|
||||
img.save(buffered, format="PNG")
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def verify_totp_code(secret: str, code: str) -> bool:
|
||||
"""Verifies a TOTP code against a secret."""
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.verify(code)
|
||||
|
||||
|
||||
def generate_recovery_codes(num_codes: int = 10) -> List[str]:
|
||||
"""Generates a list of random recovery codes."""
|
||||
return [secrets.token_urlsafe(16) for _ in range(num_codes)]
|
||||
|
||||
|
||||
def hash_recovery_code(code: str) -> str:
|
||||
"""Hashes a single recovery code using SHA256."""
|
||||
return hashlib.sha256(code.encode('utf-8')).hexdigest()
|
||||
return hashlib.sha256(code.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def verify_recovery_code(plain_code: str, hashed_code: str) -> bool:
|
||||
"""Verifies a plain recovery code against its hashed version."""
|
||||
return hash_recovery_code(plain_code) == hashed_code
|
||||
|
||||
|
||||
def hash_recovery_codes(codes: List[str]) -> List[str]:
|
||||
"""Hashes a list of recovery codes."""
|
||||
return [hash_recovery_code(code) for code in codes]
|
||||
|
||||
|
||||
def verify_recovery_codes(plain_code: str, hashed_codes: List[str]) -> bool:
|
||||
"""Verifies if a plain recovery code matches any of the hashed recovery codes."""
|
||||
for hashed_code in hashed_codes:
|
||||
|
||||
@ -19,6 +19,7 @@ router = APIRouter(
|
||||
tags=["webdav"],
|
||||
)
|
||||
|
||||
|
||||
class WebDAVLock:
|
||||
locks = {}
|
||||
|
||||
@ -26,10 +27,10 @@ class WebDAVLock:
|
||||
def create_lock(cls, path: str, user_id: int, timeout: int = 3600):
|
||||
lock_token = f"opaquelocktoken:{hashlib.md5(f'{path}{user_id}{datetime.now()}'.encode()).hexdigest()}"
|
||||
cls.locks[path] = {
|
||||
'token': lock_token,
|
||||
'user_id': user_id,
|
||||
'created_at': datetime.now(),
|
||||
'timeout': timeout
|
||||
"token": lock_token,
|
||||
"user_id": user_id,
|
||||
"created_at": datetime.now(),
|
||||
"timeout": timeout,
|
||||
}
|
||||
return lock_token
|
||||
|
||||
@ -42,17 +43,18 @@ class WebDAVLock:
|
||||
if path in cls.locks:
|
||||
del cls.locks[path]
|
||||
|
||||
|
||||
async def basic_auth(authorization: Optional[str] = Header(None)):
|
||||
if not authorization:
|
||||
return None
|
||||
|
||||
try:
|
||||
scheme, credentials = authorization.split()
|
||||
if scheme.lower() != 'basic':
|
||||
if scheme.lower() != "basic":
|
||||
return None
|
||||
|
||||
decoded = base64.b64decode(credentials).decode('utf-8')
|
||||
username, password = decoded.split(':', 1)
|
||||
decoded = base64.b64decode(credentials).decode("utf-8")
|
||||
username, password = decoded.split(":", 1)
|
||||
|
||||
user = await User.get_or_none(username=username)
|
||||
if user and verify_password(password, user.hashed_password):
|
||||
@ -62,6 +64,7 @@ async def basic_auth(authorization: Optional[str] = Header(None)):
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def webdav_auth(request: Request, authorization: Optional[str] = Header(None)):
|
||||
user = await basic_auth(authorization)
|
||||
if user:
|
||||
@ -75,14 +78,15 @@ async def webdav_auth(request: Request, authorization: Optional[str] = Header(No
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
headers={'WWW-Authenticate': 'Basic realm="MyWebdav WebDAV"'}
|
||||
headers={"WWW-Authenticate": 'Basic realm="MyWebdav WebDAV"'},
|
||||
)
|
||||
|
||||
|
||||
async def resolve_path(path_str: str, user: User):
|
||||
if not path_str or path_str == '/':
|
||||
if not path_str or path_str == "/":
|
||||
return None, None, True
|
||||
|
||||
parts = [p for p in path_str.split('/') if p]
|
||||
parts = [p for p in path_str.split("/") if p]
|
||||
|
||||
if not parts:
|
||||
return None, None, True
|
||||
@ -90,10 +94,7 @@ async def resolve_path(path_str: str, user: User):
|
||||
current_folder = None
|
||||
for i, part in enumerate(parts[:-1]):
|
||||
folder = await Folder.get_or_none(
|
||||
name=part,
|
||||
parent=current_folder,
|
||||
owner=user,
|
||||
is_deleted=False
|
||||
name=part, parent=current_folder, owner=user, is_deleted=False
|
||||
)
|
||||
if not folder:
|
||||
return None, None, False
|
||||
@ -102,39 +103,37 @@ async def resolve_path(path_str: str, user: User):
|
||||
last_part = parts[-1]
|
||||
|
||||
folder = await Folder.get_or_none(
|
||||
name=last_part,
|
||||
parent=current_folder,
|
||||
owner=user,
|
||||
is_deleted=False
|
||||
name=last_part, parent=current_folder, owner=user, is_deleted=False
|
||||
)
|
||||
if folder:
|
||||
return folder, current_folder, True
|
||||
|
||||
file = await File.get_or_none(
|
||||
name=last_part,
|
||||
parent=current_folder,
|
||||
owner=user,
|
||||
is_deleted=False
|
||||
name=last_part, parent=current_folder, owner=user, is_deleted=False
|
||||
)
|
||||
if file:
|
||||
return file, current_folder, True
|
||||
|
||||
return None, current_folder, True
|
||||
|
||||
|
||||
def build_href(base_path: str, name: str, is_collection: bool):
|
||||
path = f"{base_path.rstrip('/')}/{name}"
|
||||
if is_collection:
|
||||
path += '/'
|
||||
path += "/"
|
||||
return path
|
||||
|
||||
|
||||
async def get_custom_properties(resource_type: str, resource_id: int):
|
||||
props = await WebDAVProperty.filter(
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id
|
||||
resource_type=resource_type, resource_id=resource_id
|
||||
)
|
||||
return {(prop.namespace, prop.name): prop.value for prop in props}
|
||||
|
||||
def create_propstat_element(props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"):
|
||||
|
||||
def create_propstat_element(
|
||||
props: dict, custom_props: dict = None, status: str = "HTTP/1.1 200 OK"
|
||||
):
|
||||
propstat = ET.Element("D:propstat")
|
||||
prop = ET.SubElement(propstat, "D:prop")
|
||||
|
||||
@ -151,10 +150,10 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str
|
||||
elem.text = value
|
||||
elif key == "getlastmodified":
|
||||
elem = ET.SubElement(prop, f"D:{key}")
|
||||
elem.text = value.strftime('%a, %d %b %Y %H:%M:%S GMT')
|
||||
elem.text = value.strftime("%a, %d %b %Y %H:%M:%S GMT")
|
||||
elif key == "creationdate":
|
||||
elem = ET.SubElement(prop, f"D:{key}")
|
||||
elem.text = value.isoformat() + 'Z'
|
||||
elem.text = value.isoformat() + "Z"
|
||||
elif key == "displayname":
|
||||
elem = ET.SubElement(prop, f"D:{key}")
|
||||
elem.text = value
|
||||
@ -174,6 +173,7 @@ def create_propstat_element(props: dict, custom_props: dict = None, status: str
|
||||
|
||||
return propstat
|
||||
|
||||
|
||||
def parse_propfind_body(body: bytes):
|
||||
if not body:
|
||||
return None
|
||||
@ -193,8 +193,8 @@ def parse_propfind_body(body: bytes):
|
||||
if prop is not None:
|
||||
requested_props = []
|
||||
for child in prop:
|
||||
ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
|
||||
name = child.tag.split('}')[1] if '}' in child.tag else child.tag
|
||||
ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
|
||||
name = child.tag.split("}")[1] if "}" in child.tag else child.tag
|
||||
requested_props.append((ns, name))
|
||||
return requested_props
|
||||
except ET.ParseError:
|
||||
@ -202,6 +202,7 @@ def parse_propfind_body(body: bytes):
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["OPTIONS"])
|
||||
async def webdav_options(full_path: str):
|
||||
return Response(
|
||||
@ -209,14 +210,17 @@ async def webdav_options(full_path: str):
|
||||
headers={
|
||||
"DAV": "1, 2",
|
||||
"Allow": "OPTIONS, GET, HEAD, POST, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK",
|
||||
"MS-Author-Via": "DAV"
|
||||
}
|
||||
"MS-Author-Via": "DAV",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["PROPFIND"])
|
||||
async def handle_propfind(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
async def handle_propfind(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
depth = request.headers.get("Depth", "1")
|
||||
full_path = unquote(full_path).strip('/')
|
||||
full_path = unquote(full_path).strip("/")
|
||||
body = await request.body()
|
||||
requested_props = parse_propfind_body(body)
|
||||
|
||||
@ -232,19 +236,23 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
|
||||
if resource is None:
|
||||
response = ET.SubElement(multistatus, "D:response")
|
||||
href = ET.SubElement(response, "D:href")
|
||||
href.text = base_href if base_href.endswith('/') else base_href + '/'
|
||||
href.text = base_href if base_href.endswith("/") else base_href + "/"
|
||||
|
||||
props = {
|
||||
"resourcetype": "collection",
|
||||
"displayname": full_path.split('/')[-1] if full_path else "Root",
|
||||
"displayname": full_path.split("/")[-1] if full_path else "Root",
|
||||
"creationdate": datetime.now(),
|
||||
"getlastmodified": datetime.now()
|
||||
"getlastmodified": datetime.now(),
|
||||
}
|
||||
response.append(create_propstat_element(props))
|
||||
|
||||
if depth in ["1", "infinity"]:
|
||||
folders = await Folder.filter(owner=current_user, parent=parent, is_deleted=False)
|
||||
files = await File.filter(owner=current_user, parent=parent, is_deleted=False)
|
||||
folders = await Folder.filter(
|
||||
owner=current_user, parent=parent, is_deleted=False
|
||||
)
|
||||
files = await File.filter(
|
||||
owner=current_user, parent=parent, is_deleted=False
|
||||
)
|
||||
|
||||
for folder in folders:
|
||||
response = ET.SubElement(multistatus, "D:response")
|
||||
@ -255,9 +263,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
|
||||
"resourcetype": "collection",
|
||||
"displayname": folder.name,
|
||||
"creationdate": folder.created_at,
|
||||
"getlastmodified": folder.updated_at
|
||||
"getlastmodified": folder.updated_at,
|
||||
}
|
||||
custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
|
||||
custom_props = (
|
||||
await get_custom_properties("folder", folder.id)
|
||||
if requested_props == "allprop"
|
||||
or (
|
||||
isinstance(requested_props, list)
|
||||
and any(ns != "DAV:" for ns, _ in requested_props)
|
||||
)
|
||||
else None
|
||||
)
|
||||
response.append(create_propstat_element(props, custom_props))
|
||||
|
||||
for file in files:
|
||||
@ -272,28 +288,48 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
|
||||
"getcontenttype": file.mime_type,
|
||||
"creationdate": file.created_at,
|
||||
"getlastmodified": file.updated_at,
|
||||
"getetag": f'"{file.file_hash}"'
|
||||
"getetag": f'"{file.file_hash}"',
|
||||
}
|
||||
custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
|
||||
custom_props = (
|
||||
await get_custom_properties("file", file.id)
|
||||
if requested_props == "allprop"
|
||||
or (
|
||||
isinstance(requested_props, list)
|
||||
and any(ns != "DAV:" for ns, _ in requested_props)
|
||||
)
|
||||
else None
|
||||
)
|
||||
response.append(create_propstat_element(props, custom_props))
|
||||
|
||||
elif isinstance(resource, Folder):
|
||||
response = ET.SubElement(multistatus, "D:response")
|
||||
href = ET.SubElement(response, "D:href")
|
||||
href.text = base_href if base_href.endswith('/') else base_href + '/'
|
||||
href.text = base_href if base_href.endswith("/") else base_href + "/"
|
||||
|
||||
props = {
|
||||
"resourcetype": "collection",
|
||||
"displayname": resource.name,
|
||||
"creationdate": resource.created_at,
|
||||
"getlastmodified": resource.updated_at
|
||||
"getlastmodified": resource.updated_at,
|
||||
}
|
||||
custom_props = await get_custom_properties("folder", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
|
||||
custom_props = (
|
||||
await get_custom_properties("folder", resource.id)
|
||||
if requested_props == "allprop"
|
||||
or (
|
||||
isinstance(requested_props, list)
|
||||
and any(ns != "DAV:" for ns, _ in requested_props)
|
||||
)
|
||||
else None
|
||||
)
|
||||
response.append(create_propstat_element(props, custom_props))
|
||||
|
||||
if depth in ["1", "infinity"]:
|
||||
folders = await Folder.filter(owner=current_user, parent=resource, is_deleted=False)
|
||||
files = await File.filter(owner=current_user, parent=resource, is_deleted=False)
|
||||
folders = await Folder.filter(
|
||||
owner=current_user, parent=resource, is_deleted=False
|
||||
)
|
||||
files = await File.filter(
|
||||
owner=current_user, parent=resource, is_deleted=False
|
||||
)
|
||||
|
||||
for folder in folders:
|
||||
response = ET.SubElement(multistatus, "D:response")
|
||||
@ -304,9 +340,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
|
||||
"resourcetype": "collection",
|
||||
"displayname": folder.name,
|
||||
"creationdate": folder.created_at,
|
||||
"getlastmodified": folder.updated_at
|
||||
"getlastmodified": folder.updated_at,
|
||||
}
|
||||
custom_props = await get_custom_properties("folder", folder.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
|
||||
custom_props = (
|
||||
await get_custom_properties("folder", folder.id)
|
||||
if requested_props == "allprop"
|
||||
or (
|
||||
isinstance(requested_props, list)
|
||||
and any(ns != "DAV:" for ns, _ in requested_props)
|
||||
)
|
||||
else None
|
||||
)
|
||||
response.append(create_propstat_element(props, custom_props))
|
||||
|
||||
for file in files:
|
||||
@ -321,9 +365,17 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
|
||||
"getcontenttype": file.mime_type,
|
||||
"creationdate": file.created_at,
|
||||
"getlastmodified": file.updated_at,
|
||||
"getetag": f'"{file.file_hash}"'
|
||||
"getetag": f'"{file.file_hash}"',
|
||||
}
|
||||
custom_props = await get_custom_properties("file", file.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
|
||||
custom_props = (
|
||||
await get_custom_properties("file", file.id)
|
||||
if requested_props == "allprop"
|
||||
or (
|
||||
isinstance(requested_props, list)
|
||||
and any(ns != "DAV:" for ns, _ in requested_props)
|
||||
)
|
||||
else None
|
||||
)
|
||||
response.append(create_propstat_element(props, custom_props))
|
||||
|
||||
elif isinstance(resource, File):
|
||||
@ -338,17 +390,32 @@ async def handle_propfind(request: Request, full_path: str, current_user: User =
|
||||
"getcontenttype": resource.mime_type,
|
||||
"creationdate": resource.created_at,
|
||||
"getlastmodified": resource.updated_at,
|
||||
"getetag": f'"{resource.file_hash}"'
|
||||
"getetag": f'"{resource.file_hash}"',
|
||||
}
|
||||
custom_props = await get_custom_properties("file", resource.id) if requested_props == "allprop" or (isinstance(requested_props, list) and any(ns != "DAV:" for ns, _ in requested_props)) else None
|
||||
custom_props = (
|
||||
await get_custom_properties("file", resource.id)
|
||||
if requested_props == "allprop"
|
||||
or (
|
||||
isinstance(requested_props, list)
|
||||
and any(ns != "DAV:" for ns, _ in requested_props)
|
||||
)
|
||||
else None
|
||||
)
|
||||
response.append(create_propstat_element(props, custom_props))
|
||||
|
||||
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
|
||||
return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207)
|
||||
return Response(
|
||||
content=xml_content,
|
||||
media_type="application/xml; charset=utf-8",
|
||||
status_code=207,
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["GET", "HEAD"])
|
||||
async def handle_get(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
full_path = unquote(full_path).strip('/')
|
||||
async def handle_get(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
full_path = unquote(full_path).strip("/")
|
||||
|
||||
resource, parent, exists = await resolve_path(full_path, current_user)
|
||||
|
||||
@ -363,8 +430,10 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe
|
||||
"Content-Length": str(resource.size),
|
||||
"Content-Type": resource.mime_type,
|
||||
"ETag": f'"{resource.file_hash}"',
|
||||
"Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT')
|
||||
}
|
||||
"Last-Modified": resource.updated_at.strftime(
|
||||
"%a, %d %b %Y %H:%M:%S GMT"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
async def file_iterator():
|
||||
@ -377,23 +446,28 @@ async def handle_get(request: Request, full_path: str, current_user: User = Depe
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{resource.name}"',
|
||||
"ETag": f'"{resource.file_hash}"',
|
||||
"Last-Modified": resource.updated_at.strftime('%a, %d %b %Y %H:%M:%S GMT')
|
||||
}
|
||||
"Last-Modified": resource.updated_at.strftime(
|
||||
"%a, %d %b %Y %H:%M:%S GMT"
|
||||
),
|
||||
},
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="File not found in storage")
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["PUT"])
|
||||
async def handle_put(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
full_path = unquote(full_path).strip('/')
|
||||
async def handle_put(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
full_path = unquote(full_path).strip("/")
|
||||
|
||||
if not full_path:
|
||||
raise HTTPException(status_code=400, detail="Cannot PUT to root")
|
||||
|
||||
parts = [p for p in full_path.split('/') if p]
|
||||
parts = [p for p in full_path.split("/") if p]
|
||||
file_name = parts[-1]
|
||||
|
||||
parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else ''
|
||||
parent_path = "/".join(parts[:-1]) if len(parts) > 1 else ""
|
||||
_, parent_folder, exists = await resolve_path(parent_path, current_user)
|
||||
|
||||
if not exists:
|
||||
@ -417,10 +491,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
existing_file = await File.get_or_none(
|
||||
name=file_name,
|
||||
parent=parent_folder,
|
||||
owner=current_user,
|
||||
is_deleted=False
|
||||
name=file_name, parent=parent_folder, owner=current_user, is_deleted=False
|
||||
)
|
||||
|
||||
if existing_file:
|
||||
@ -432,7 +503,9 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
|
||||
existing_file.updated_at = datetime.now()
|
||||
await existing_file.save()
|
||||
|
||||
current_user.used_storage_bytes = current_user.used_storage_bytes - old_size + file_size
|
||||
current_user.used_storage_bytes = (
|
||||
current_user.used_storage_bytes - old_size + file_size
|
||||
)
|
||||
await current_user.save()
|
||||
|
||||
await log_activity(current_user, "file_updated", "file", existing_file.id)
|
||||
@ -445,7 +518,7 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
|
||||
mime_type=mime_type,
|
||||
file_hash=file_hash,
|
||||
owner=current_user,
|
||||
parent=parent_folder
|
||||
parent=parent_folder,
|
||||
)
|
||||
|
||||
current_user.used_storage_bytes += file_size
|
||||
@ -454,9 +527,12 @@ async def handle_put(request: Request, full_path: str, current_user: User = Depe
|
||||
await log_activity(current_user, "file_created", "file", db_file.id)
|
||||
return Response(status_code=201)
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["DELETE"])
|
||||
async def handle_delete(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
full_path = unquote(full_path).strip('/')
|
||||
async def handle_delete(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
full_path = unquote(full_path).strip("/")
|
||||
|
||||
if not full_path:
|
||||
raise HTTPException(status_code=400, detail="Cannot DELETE root")
|
||||
@ -478,44 +554,45 @@ async def handle_delete(request: Request, full_path: str, current_user: User = D
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["MKCOL"])
|
||||
async def handle_mkcol(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
full_path = unquote(full_path).strip('/')
|
||||
async def handle_mkcol(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
full_path = unquote(full_path).strip("/")
|
||||
|
||||
if not full_path:
|
||||
raise HTTPException(status_code=400, detail="Cannot MKCOL at root")
|
||||
|
||||
parts = [p for p in full_path.split('/') if p]
|
||||
parts = [p for p in full_path.split("/") if p]
|
||||
folder_name = parts[-1]
|
||||
|
||||
parent_path = '/'.join(parts[:-1]) if len(parts) > 1 else ''
|
||||
parent_path = "/".join(parts[:-1]) if len(parts) > 1 else ""
|
||||
_, parent_folder, exists = await resolve_path(parent_path, current_user)
|
||||
|
||||
if not exists:
|
||||
raise HTTPException(status_code=409, detail="Parent folder does not exist")
|
||||
|
||||
existing = await Folder.get_or_none(
|
||||
name=folder_name,
|
||||
parent=parent_folder,
|
||||
owner=current_user,
|
||||
is_deleted=False
|
||||
name=folder_name, parent=parent_folder, owner=current_user, is_deleted=False
|
||||
)
|
||||
|
||||
if existing:
|
||||
raise HTTPException(status_code=405, detail="Folder already exists")
|
||||
|
||||
folder = await Folder.create(
|
||||
name=folder_name,
|
||||
parent=parent_folder,
|
||||
owner=current_user
|
||||
name=folder_name, parent=parent_folder, owner=current_user
|
||||
)
|
||||
|
||||
await log_activity(current_user, "folder_created", "folder", folder.id)
|
||||
return Response(status_code=201)
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["COPY"])
|
||||
async def handle_copy(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
full_path = unquote(full_path).strip('/')
|
||||
async def handle_copy(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
full_path = unquote(full_path).strip("/")
|
||||
destination = request.headers.get("Destination")
|
||||
overwrite = request.headers.get("Overwrite", "T")
|
||||
|
||||
@ -523,7 +600,7 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
|
||||
raise HTTPException(status_code=400, detail="Destination header required")
|
||||
|
||||
dest_path = unquote(urlparse(destination).path)
|
||||
dest_path = dest_path.replace('/webdav/', '').strip('/')
|
||||
dest_path = dest_path.replace("/webdav/", "").strip("/")
|
||||
|
||||
source_resource, _, exists = await resolve_path(full_path, current_user)
|
||||
|
||||
@ -533,9 +610,9 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
|
||||
if not isinstance(source_resource, File):
|
||||
raise HTTPException(status_code=501, detail="Only file copy is implemented")
|
||||
|
||||
dest_parts = [p for p in dest_path.split('/') if p]
|
||||
dest_parts = [p for p in dest_path.split("/") if p]
|
||||
dest_name = dest_parts[-1]
|
||||
dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else ''
|
||||
dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else ""
|
||||
|
||||
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
|
||||
|
||||
@ -543,14 +620,13 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
|
||||
raise HTTPException(status_code=409, detail="Destination parent does not exist")
|
||||
|
||||
existing_dest = await File.get_or_none(
|
||||
name=dest_name,
|
||||
parent=dest_parent,
|
||||
owner=current_user,
|
||||
is_deleted=False
|
||||
name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
|
||||
)
|
||||
|
||||
if existing_dest and overwrite == "F":
|
||||
raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
|
||||
raise HTTPException(
|
||||
status_code=412, detail="Destination exists and overwrite is false"
|
||||
)
|
||||
|
||||
if existing_dest:
|
||||
await existing_dest.delete()
|
||||
@ -562,15 +638,18 @@ async def handle_copy(request: Request, full_path: str, current_user: User = Dep
|
||||
mime_type=source_resource.mime_type,
|
||||
file_hash=source_resource.file_hash,
|
||||
owner=current_user,
|
||||
parent=dest_parent
|
||||
parent=dest_parent,
|
||||
)
|
||||
|
||||
await log_activity(current_user, "file_copied", "file", new_file.id)
|
||||
return Response(status_code=201 if not existing_dest else 204)
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["MOVE"])
|
||||
async def handle_move(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
full_path = unquote(full_path).strip('/')
|
||||
async def handle_move(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
full_path = unquote(full_path).strip("/")
|
||||
destination = request.headers.get("Destination")
|
||||
overwrite = request.headers.get("Overwrite", "T")
|
||||
|
||||
@ -578,16 +657,16 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
|
||||
raise HTTPException(status_code=400, detail="Destination header required")
|
||||
|
||||
dest_path = unquote(urlparse(destination).path)
|
||||
dest_path = dest_path.replace('/webdav/', '').strip('/')
|
||||
dest_path = dest_path.replace("/webdav/", "").strip("/")
|
||||
|
||||
source_resource, _, exists = await resolve_path(full_path, current_user)
|
||||
|
||||
if not source_resource:
|
||||
raise HTTPException(status_code=404, detail="Source not found")
|
||||
|
||||
dest_parts = [p for p in dest_path.split('/') if p]
|
||||
dest_parts = [p for p in dest_path.split("/") if p]
|
||||
dest_name = dest_parts[-1]
|
||||
dest_parent_path = '/'.join(dest_parts[:-1]) if len(dest_parts) > 1 else ''
|
||||
dest_parent_path = "/".join(dest_parts[:-1]) if len(dest_parts) > 1 else ""
|
||||
|
||||
_, dest_parent, exists = await resolve_path(dest_parent_path, current_user)
|
||||
|
||||
@ -596,14 +675,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
|
||||
|
||||
if isinstance(source_resource, File):
|
||||
existing_dest = await File.get_or_none(
|
||||
name=dest_name,
|
||||
parent=dest_parent,
|
||||
owner=current_user,
|
||||
is_deleted=False
|
||||
name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
|
||||
)
|
||||
|
||||
if existing_dest and overwrite == "F":
|
||||
raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
|
||||
raise HTTPException(
|
||||
status_code=412, detail="Destination exists and overwrite is false"
|
||||
)
|
||||
|
||||
if existing_dest:
|
||||
await existing_dest.delete()
|
||||
@ -617,14 +695,13 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
|
||||
|
||||
elif isinstance(source_resource, Folder):
|
||||
existing_dest = await Folder.get_or_none(
|
||||
name=dest_name,
|
||||
parent=dest_parent,
|
||||
owner=current_user,
|
||||
is_deleted=False
|
||||
name=dest_name, parent=dest_parent, owner=current_user, is_deleted=False
|
||||
)
|
||||
|
||||
if existing_dest and overwrite == "F":
|
||||
raise HTTPException(status_code=412, detail="Destination exists and overwrite is false")
|
||||
raise HTTPException(
|
||||
status_code=412, detail="Destination exists and overwrite is false"
|
||||
)
|
||||
|
||||
if existing_dest:
|
||||
existing_dest.is_deleted = True
|
||||
@ -637,9 +714,12 @@ async def handle_move(request: Request, full_path: str, current_user: User = Dep
|
||||
await log_activity(current_user, "folder_moved", "folder", source_resource.id)
|
||||
return Response(status_code=201 if not existing_dest else 204)
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["LOCK"])
|
||||
async def handle_lock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
full_path = unquote(full_path).strip('/')
|
||||
async def handle_lock(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
full_path = unquote(full_path).strip("/")
|
||||
|
||||
timeout_header = request.headers.get("Timeout", "Second-3600")
|
||||
timeout = 3600
|
||||
@ -681,32 +761,38 @@ async def handle_lock(request: Request, full_path: str, current_user: User = Dep
|
||||
content=xml_content,
|
||||
media_type="application/xml; charset=utf-8",
|
||||
status_code=200,
|
||||
headers={"Lock-Token": f"<{lock_token}>"}
|
||||
headers={"Lock-Token": f"<{lock_token}>"},
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["UNLOCK"])
|
||||
async def handle_unlock(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
full_path = unquote(full_path).strip('/')
|
||||
async def handle_unlock(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
full_path = unquote(full_path).strip("/")
|
||||
lock_token_header = request.headers.get("Lock-Token")
|
||||
|
||||
if not lock_token_header:
|
||||
raise HTTPException(status_code=400, detail="Lock-Token header required")
|
||||
|
||||
lock_token = lock_token_header.strip('<>')
|
||||
lock_token = lock_token_header.strip("<>")
|
||||
existing_lock = WebDAVLock.get_lock(full_path)
|
||||
|
||||
if not existing_lock or existing_lock['token'] != lock_token:
|
||||
if not existing_lock or existing_lock["token"] != lock_token:
|
||||
raise HTTPException(status_code=409, detail="Invalid lock token")
|
||||
|
||||
if existing_lock['user_id'] != current_user.id:
|
||||
if existing_lock["user_id"] != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not lock owner")
|
||||
|
||||
WebDAVLock.remove_lock(full_path)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.api_route("/{full_path:path}", methods=["PROPPATCH"])
|
||||
async def handle_proppatch(request: Request, full_path: str, current_user: User = Depends(webdav_auth)):
|
||||
full_path = unquote(full_path).strip('/')
|
||||
async def handle_proppatch(
|
||||
request: Request, full_path: str, current_user: User = Depends(webdav_auth)
|
||||
):
|
||||
full_path = unquote(full_path).strip("/")
|
||||
|
||||
resource, parent, exists = await resolve_path(full_path, current_user)
|
||||
|
||||
@ -719,7 +805,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
|
||||
|
||||
try:
|
||||
root = ET.fromstring(body)
|
||||
except:
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid XML")
|
||||
|
||||
resource_type = "file" if isinstance(resource, File) else "folder"
|
||||
@ -734,13 +820,19 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
|
||||
prop_element = set_element.find(".//{DAV:}prop")
|
||||
if prop_element is not None:
|
||||
for child in prop_element:
|
||||
ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
|
||||
name = child.tag.split('}')[1] if '}' in child.tag else child.tag
|
||||
ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
|
||||
name = child.tag.split("}")[1] if "}" in child.tag else child.tag
|
||||
value = child.text or ""
|
||||
|
||||
if ns == "DAV:":
|
||||
live_props = ["creationdate", "getcontentlength", "getcontenttype",
|
||||
"getetag", "getlastmodified", "resourcetype"]
|
||||
live_props = [
|
||||
"creationdate",
|
||||
"getcontentlength",
|
||||
"getcontenttype",
|
||||
"getetag",
|
||||
"getlastmodified",
|
||||
"resourcetype",
|
||||
]
|
||||
if name in live_props:
|
||||
failed_props.append((ns, name, "409 Conflict"))
|
||||
continue
|
||||
@ -750,7 +842,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
namespace=ns,
|
||||
name=name
|
||||
name=name,
|
||||
)
|
||||
|
||||
if existing_prop:
|
||||
@ -762,10 +854,10 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
|
||||
resource_id=resource_id,
|
||||
namespace=ns,
|
||||
name=name,
|
||||
value=value
|
||||
value=value,
|
||||
)
|
||||
set_props.append((ns, name))
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
failed_props.append((ns, name, "500 Internal Server Error"))
|
||||
|
||||
remove_element = root.find(".//{DAV:}remove")
|
||||
@ -773,8 +865,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
|
||||
prop_element = remove_element.find(".//{DAV:}prop")
|
||||
if prop_element is not None:
|
||||
for child in prop_element:
|
||||
ns = child.tag.split('}')[0][1:] if '}' in child.tag else "DAV:"
|
||||
name = child.tag.split('}')[1] if '}' in child.tag else child.tag
|
||||
ns = child.tag.split("}")[0][1:] if "}" in child.tag else "DAV:"
|
||||
name = child.tag.split("}")[1] if "}" in child.tag else child.tag
|
||||
|
||||
if ns == "DAV:":
|
||||
failed_props.append((ns, name, "409 Conflict"))
|
||||
@ -785,7 +877,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
namespace=ns,
|
||||
name=name
|
||||
name=name,
|
||||
)
|
||||
|
||||
if existing_prop:
|
||||
@ -793,7 +885,7 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
|
||||
remove_props.append((ns, name))
|
||||
else:
|
||||
failed_props.append((ns, name, "404 Not Found"))
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
failed_props.append((ns, name, "500 Internal Server Error"))
|
||||
|
||||
multistatus = ET.Element("D:multistatus", {"xmlns:D": "DAV:"})
|
||||
@ -837,4 +929,8 @@ async def handle_proppatch(request: Request, full_path: str, current_user: User
|
||||
await log_activity(current_user, "properties_modified", resource_type, resource_id)
|
||||
|
||||
xml_content = ET.tostring(multistatus, encoding="utf-8", xml_declaration=True)
|
||||
return Response(content=xml_content, media_type="application/xml; charset=utf-8", status_code=207)
|
||||
return Response(
|
||||
content=xml_content,
|
||||
media_type="application/xml; charset=utf-8",
|
||||
status_code=207,
|
||||
)
|
||||
|
||||
@ -6,9 +6,15 @@ from httpx import AsyncClient, ASGITransport
|
||||
from fastapi import status
|
||||
from mywebdav.main import app
|
||||
from mywebdav.models import User
|
||||
from mywebdav.billing.models import PricingConfig, Invoice, UsageAggregate, UserSubscription
|
||||
from mywebdav.billing.models import (
|
||||
PricingConfig,
|
||||
Invoice,
|
||||
UsageAggregate,
|
||||
UserSubscription,
|
||||
)
|
||||
from mywebdav.auth import create_access_token
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user():
|
||||
user = await User.create(
|
||||
@ -16,11 +22,12 @@ async def test_user():
|
||||
email="test@example.com",
|
||||
hashed_password="hashed_password_here",
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
is_superuser=False,
|
||||
)
|
||||
yield user
|
||||
await user.delete()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def admin_user():
|
||||
user = await User.create(
|
||||
@ -28,58 +35,72 @@ async def admin_user():
|
||||
email="admin@example.com",
|
||||
hashed_password="hashed_password_here",
|
||||
is_active=True,
|
||||
is_superuser=True
|
||||
is_superuser=True,
|
||||
)
|
||||
yield user
|
||||
await user.delete()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def auth_token(test_user):
|
||||
token = create_access_token(data={"sub": test_user.username})
|
||||
return token
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def admin_token(admin_user):
|
||||
token = create_access_token(data={"sub": admin_user.username})
|
||||
return token
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pricing_config():
|
||||
configs = []
|
||||
configs.append(await PricingConfig.create(
|
||||
config_key="storage_per_gb_month",
|
||||
config_value=Decimal("0.0045"),
|
||||
description="Storage cost per GB per month",
|
||||
unit="per_gb_month"
|
||||
))
|
||||
configs.append(await PricingConfig.create(
|
||||
config_key="bandwidth_egress_per_gb",
|
||||
config_value=Decimal("0.009"),
|
||||
description="Bandwidth egress cost per GB",
|
||||
unit="per_gb"
|
||||
))
|
||||
configs.append(await PricingConfig.create(
|
||||
config_key="free_tier_storage_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier storage in GB",
|
||||
unit="gb"
|
||||
))
|
||||
configs.append(await PricingConfig.create(
|
||||
config_key="free_tier_bandwidth_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier bandwidth in GB per month",
|
||||
unit="gb"
|
||||
))
|
||||
configs.append(
|
||||
await PricingConfig.create(
|
||||
config_key="storage_per_gb_month",
|
||||
config_value=Decimal("0.0045"),
|
||||
description="Storage cost per GB per month",
|
||||
unit="per_gb_month",
|
||||
)
|
||||
)
|
||||
configs.append(
|
||||
await PricingConfig.create(
|
||||
config_key="bandwidth_egress_per_gb",
|
||||
config_value=Decimal("0.009"),
|
||||
description="Bandwidth egress cost per GB",
|
||||
unit="per_gb",
|
||||
)
|
||||
)
|
||||
configs.append(
|
||||
await PricingConfig.create(
|
||||
config_key="free_tier_storage_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier storage in GB",
|
||||
unit="gb",
|
||||
)
|
||||
)
|
||||
configs.append(
|
||||
await PricingConfig.create(
|
||||
config_key="free_tier_bandwidth_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier bandwidth in GB per month",
|
||||
unit="gb",
|
||||
)
|
||||
)
|
||||
yield configs
|
||||
for config in configs:
|
||||
await config.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_usage(test_user, auth_token):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get(
|
||||
"/api/billing/usage/current",
|
||||
headers={"Authorization": f"Bearer {auth_token}"}
|
||||
headers={"Authorization": f"Bearer {auth_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@ -87,6 +108,7 @@ async def test_get_current_usage(test_user, auth_token):
|
||||
assert "storage_gb" in data
|
||||
assert "bandwidth_down_gb_today" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_monthly_usage(test_user, auth_token):
|
||||
today = date.today()
|
||||
@ -94,16 +116,18 @@ async def test_get_monthly_usage(test_user, auth_token):
|
||||
await UsageAggregate.create(
|
||||
user=test_user,
|
||||
date=today,
|
||||
storage_bytes_avg=1024 ** 3 * 10,
|
||||
storage_bytes_peak=1024 ** 3 * 12,
|
||||
bandwidth_up_bytes=1024 ** 3 * 2,
|
||||
bandwidth_down_bytes=1024 ** 3 * 5
|
||||
storage_bytes_avg=1024**3 * 10,
|
||||
storage_bytes_peak=1024**3 * 12,
|
||||
bandwidth_up_bytes=1024**3 * 2,
|
||||
bandwidth_down_bytes=1024**3 * 5,
|
||||
)
|
||||
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get(
|
||||
f"/api/billing/usage/monthly?year={today.year}&month={today.month}",
|
||||
headers={"Authorization": f"Bearer {auth_token}"}
|
||||
headers={"Authorization": f"Bearer {auth_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@ -112,12 +136,15 @@ async def test_get_monthly_usage(test_user, auth_token):
|
||||
|
||||
await UsageAggregate.filter(user=test_user).delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_subscription(test_user, auth_token):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get(
|
||||
"/api/billing/subscription",
|
||||
headers={"Authorization": f"Bearer {auth_token}"}
|
||||
headers={"Authorization": f"Bearer {auth_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@ -127,6 +154,7 @@ async def test_get_subscription(test_user, auth_token):
|
||||
|
||||
await UserSubscription.filter(user=test_user).delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_invoices(test_user, auth_token):
|
||||
invoice = await Invoice.create(
|
||||
@ -137,13 +165,14 @@ async def test_list_invoices(test_user, auth_token):
|
||||
subtotal=Decimal("10.00"),
|
||||
tax=Decimal("0.00"),
|
||||
total=Decimal("10.00"),
|
||||
status="open"
|
||||
status="open",
|
||||
)
|
||||
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get(
|
||||
"/api/billing/invoices",
|
||||
headers={"Authorization": f"Bearer {auth_token}"}
|
||||
"/api/billing/invoices", headers={"Authorization": f"Bearer {auth_token}"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@ -153,6 +182,7 @@ async def test_list_invoices(test_user, auth_token):
|
||||
|
||||
await invoice.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invoice(test_user, auth_token):
|
||||
invoice = await Invoice.create(
|
||||
@ -163,13 +193,15 @@ async def test_get_invoice(test_user, auth_token):
|
||||
subtotal=Decimal("10.00"),
|
||||
tax=Decimal("0.00"),
|
||||
total=Decimal("10.00"),
|
||||
status="open"
|
||||
status="open",
|
||||
)
|
||||
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get(
|
||||
f"/api/billing/invoices/{invoice.id}",
|
||||
headers={"Authorization": f"Bearer {auth_token}"}
|
||||
headers={"Authorization": f"Bearer {auth_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@ -178,39 +210,45 @@ async def test_get_invoice(test_user, auth_token):
|
||||
|
||||
await invoice.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pricing():
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get("/api/billing/pricing")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_get_pricing(admin_user, admin_token, pricing_config):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get(
|
||||
"/api/admin/billing/pricing",
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
|
||||
config_id = pricing_config[0].id
|
||||
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.put(
|
||||
f"/api/admin/billing/pricing/{config_id}",
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
json={
|
||||
"config_key": "storage_per_gb_month",
|
||||
"config_value": 0.005
|
||||
}
|
||||
json={"config_key": "storage_per_gb_month", "config_value": 0.005},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@ -218,12 +256,15 @@ async def test_admin_update_pricing(admin_user, admin_token, pricing_config):
|
||||
updated = await PricingConfig.get(id=config_id)
|
||||
assert updated.config_value == Decimal("0.005")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_get_stats(admin_user, admin_token):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get(
|
||||
"/api/admin/billing/stats",
|
||||
headers={"Authorization": f"Bearer {admin_token}"}
|
||||
headers={"Authorization": f"Bearer {admin_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@ -232,12 +273,15 @@ async def test_admin_get_stats(admin_user, admin_token):
|
||||
assert "total_invoices" in data
|
||||
assert "pending_invoices" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_admin_cannot_access_admin_endpoints(test_user, auth_token):
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get(
|
||||
"/api/admin/billing/pricing",
|
||||
headers={"Authorization": f"Bearer {auth_token}"}
|
||||
headers={"Authorization": f"Bearer {auth_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@ -3,57 +3,76 @@ import pytest_asyncio
|
||||
from decimal import Decimal
|
||||
from datetime import date, datetime
|
||||
from mywebdav.models import User
|
||||
from mywebdav.billing.models import Invoice, InvoiceLineItem, PricingConfig, UsageAggregate, UserSubscription
|
||||
from mywebdav.billing.models import (
|
||||
Invoice,
|
||||
InvoiceLineItem,
|
||||
PricingConfig,
|
||||
UsageAggregate,
|
||||
UserSubscription,
|
||||
)
|
||||
from mywebdav.billing.invoice_generator import InvoiceGenerator
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user():
|
||||
user = await User.create(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password="hashed_password_here",
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
yield user
|
||||
await user.delete()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pricing_config():
|
||||
configs = []
|
||||
configs.append(await PricingConfig.create(
|
||||
config_key="storage_per_gb_month",
|
||||
config_value=Decimal("0.0045"),
|
||||
description="Storage cost per GB per month",
|
||||
unit="per_gb_month"
|
||||
))
|
||||
configs.append(await PricingConfig.create(
|
||||
config_key="bandwidth_egress_per_gb",
|
||||
config_value=Decimal("0.009"),
|
||||
description="Bandwidth egress cost per GB",
|
||||
unit="per_gb"
|
||||
))
|
||||
configs.append(await PricingConfig.create(
|
||||
config_key="free_tier_storage_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier storage in GB",
|
||||
unit="gb"
|
||||
))
|
||||
configs.append(await PricingConfig.create(
|
||||
config_key="free_tier_bandwidth_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier bandwidth in GB per month",
|
||||
unit="gb"
|
||||
))
|
||||
configs.append(await PricingConfig.create(
|
||||
config_key="tax_rate_default",
|
||||
config_value=Decimal("0.0"),
|
||||
description="Default tax rate",
|
||||
unit="percentage"
|
||||
))
|
||||
configs.append(
|
||||
await PricingConfig.create(
|
||||
config_key="storage_per_gb_month",
|
||||
config_value=Decimal("0.0045"),
|
||||
description="Storage cost per GB per month",
|
||||
unit="per_gb_month",
|
||||
)
|
||||
)
|
||||
configs.append(
|
||||
await PricingConfig.create(
|
||||
config_key="bandwidth_egress_per_gb",
|
||||
config_value=Decimal("0.009"),
|
||||
description="Bandwidth egress cost per GB",
|
||||
unit="per_gb",
|
||||
)
|
||||
)
|
||||
configs.append(
|
||||
await PricingConfig.create(
|
||||
config_key="free_tier_storage_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier storage in GB",
|
||||
unit="gb",
|
||||
)
|
||||
)
|
||||
configs.append(
|
||||
await PricingConfig.create(
|
||||
config_key="free_tier_bandwidth_gb",
|
||||
config_value=Decimal("15"),
|
||||
description="Free tier bandwidth in GB per month",
|
||||
unit="gb",
|
||||
)
|
||||
)
|
||||
configs.append(
|
||||
await PricingConfig.create(
|
||||
config_key="tax_rate_default",
|
||||
config_value=Decimal("0.0"),
|
||||
description="Default tax rate",
|
||||
unit="percentage",
|
||||
)
|
||||
)
|
||||
yield configs
|
||||
for config in configs:
|
||||
await config.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
|
||||
today = date.today()
|
||||
@ -61,13 +80,15 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
|
||||
await UsageAggregate.create(
|
||||
user=test_user,
|
||||
date=today,
|
||||
storage_bytes_avg=1024 ** 3 * 50,
|
||||
storage_bytes_peak=1024 ** 3 * 55,
|
||||
bandwidth_up_bytes=1024 ** 3 * 10,
|
||||
bandwidth_down_bytes=1024 ** 3 * 20
|
||||
storage_bytes_avg=1024**3 * 50,
|
||||
storage_bytes_peak=1024**3 * 55,
|
||||
bandwidth_up_bytes=1024**3 * 10,
|
||||
bandwidth_down_bytes=1024**3 * 20,
|
||||
)
|
||||
|
||||
invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
|
||||
invoice = await InvoiceGenerator.generate_monthly_invoice(
|
||||
test_user, today.year, today.month
|
||||
)
|
||||
|
||||
assert invoice is not None
|
||||
assert invoice.user_id == test_user.id
|
||||
@ -80,6 +101,7 @@ async def test_generate_monthly_invoice_with_usage(test_user, pricing_config):
|
||||
await invoice.delete()
|
||||
await UsageAggregate.filter(user=test_user).delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_config):
|
||||
today = date.today()
|
||||
@ -87,18 +109,21 @@ async def test_generate_monthly_invoice_below_free_tier(test_user, pricing_confi
|
||||
await UsageAggregate.create(
|
||||
user=test_user,
|
||||
date=today,
|
||||
storage_bytes_avg=1024 ** 3 * 10,
|
||||
storage_bytes_peak=1024 ** 3 * 12,
|
||||
bandwidth_up_bytes=1024 ** 3 * 5,
|
||||
bandwidth_down_bytes=1024 ** 3 * 10
|
||||
storage_bytes_avg=1024**3 * 10,
|
||||
storage_bytes_peak=1024**3 * 12,
|
||||
bandwidth_up_bytes=1024**3 * 5,
|
||||
bandwidth_down_bytes=1024**3 * 10,
|
||||
)
|
||||
|
||||
invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
|
||||
invoice = await InvoiceGenerator.generate_monthly_invoice(
|
||||
test_user, today.year, today.month
|
||||
)
|
||||
|
||||
assert invoice is None
|
||||
|
||||
await UsageAggregate.filter(user=test_user).delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finalize_invoice(test_user, pricing_config):
|
||||
invoice = await Invoice.create(
|
||||
@ -109,7 +134,7 @@ async def test_finalize_invoice(test_user, pricing_config):
|
||||
subtotal=Decimal("10.00"),
|
||||
tax=Decimal("0.00"),
|
||||
total=Decimal("10.00"),
|
||||
status="draft"
|
||||
status="draft",
|
||||
)
|
||||
|
||||
finalized = await InvoiceGenerator.finalize_invoice(invoice)
|
||||
@ -118,6 +143,7 @@ async def test_finalize_invoice(test_user, pricing_config):
|
||||
|
||||
await finalized.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finalize_invoice_already_finalized(test_user, pricing_config):
|
||||
invoice = await Invoice.create(
|
||||
@ -128,7 +154,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
|
||||
subtotal=Decimal("10.00"),
|
||||
tax=Decimal("0.00"),
|
||||
total=Decimal("10.00"),
|
||||
status="open"
|
||||
status="open",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
@ -136,6 +162,7 @@ async def test_finalize_invoice_already_finalized(test_user, pricing_config):
|
||||
|
||||
await invoice.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_invoice_paid(test_user, pricing_config):
|
||||
invoice = await Invoice.create(
|
||||
@ -146,7 +173,7 @@ async def test_mark_invoice_paid(test_user, pricing_config):
|
||||
subtotal=Decimal("10.00"),
|
||||
tax=Decimal("0.00"),
|
||||
total=Decimal("10.00"),
|
||||
status="open"
|
||||
status="open",
|
||||
)
|
||||
|
||||
paid = await InvoiceGenerator.mark_invoice_paid(invoice)
|
||||
@ -156,10 +183,13 @@ async def test_mark_invoice_paid(test_user, pricing_config):
|
||||
|
||||
await paid.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoice_with_tax(test_user, pricing_config):
|
||||
# Update tax rate
|
||||
updated = await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.21"))
|
||||
updated = await PricingConfig.filter(config_key="tax_rate_default").update(
|
||||
config_value=Decimal("0.21")
|
||||
)
|
||||
assert updated == 1 # Should update 1 row
|
||||
|
||||
# Verify the update worked
|
||||
@ -171,13 +201,15 @@ async def test_invoice_with_tax(test_user, pricing_config):
|
||||
await UsageAggregate.create(
|
||||
user=test_user,
|
||||
date=today,
|
||||
storage_bytes_avg=1024 ** 3 * 50,
|
||||
storage_bytes_peak=1024 ** 3 * 55,
|
||||
bandwidth_up_bytes=1024 ** 3 * 10,
|
||||
bandwidth_down_bytes=1024 ** 3 * 20
|
||||
storage_bytes_avg=1024**3 * 50,
|
||||
storage_bytes_peak=1024**3 * 55,
|
||||
bandwidth_up_bytes=1024**3 * 10,
|
||||
bandwidth_down_bytes=1024**3 * 20,
|
||||
)
|
||||
|
||||
invoice = await InvoiceGenerator.generate_monthly_invoice(test_user, today.year, today.month)
|
||||
invoice = await InvoiceGenerator.generate_monthly_invoice(
|
||||
test_user, today.year, today.month
|
||||
)
|
||||
|
||||
assert invoice is not None
|
||||
assert invoice.tax > 0
|
||||
@ -185,4 +217,6 @@ async def test_invoice_with_tax(test_user, pricing_config):
|
||||
|
||||
await invoice.delete()
|
||||
await UsageAggregate.filter(user=test_user).delete()
|
||||
await PricingConfig.filter(config_key="tax_rate_default").update(config_value=Decimal("0.0"))
|
||||
await PricingConfig.filter(config_key="tax_rate_default").update(
|
||||
config_value=Decimal("0.0")
|
||||
)
|
||||
|
||||
@ -4,21 +4,30 @@ from decimal import Decimal
|
||||
from datetime import date, datetime
|
||||
from mywebdav.models import User
|
||||
from mywebdav.billing.models import (
|
||||
SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate,
|
||||
Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent
|
||||
SubscriptionPlan,
|
||||
UserSubscription,
|
||||
UsageRecord,
|
||||
UsageAggregate,
|
||||
Invoice,
|
||||
InvoiceLineItem,
|
||||
PricingConfig,
|
||||
PaymentMethod,
|
||||
BillingEvent,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user():
|
||||
user = await User.create(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password="hashed_password_here",
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
yield user
|
||||
await user.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscription_plan_creation():
|
||||
plan = await SubscriptionPlan.create(
|
||||
@ -28,7 +37,7 @@ async def test_subscription_plan_creation():
|
||||
storage_gb=100,
|
||||
bandwidth_gb=100,
|
||||
price_monthly=Decimal("5.00"),
|
||||
price_yearly=Decimal("50.00")
|
||||
price_yearly=Decimal("50.00"),
|
||||
)
|
||||
|
||||
assert plan.name == "starter"
|
||||
@ -36,12 +45,11 @@ async def test_subscription_plan_creation():
|
||||
assert plan.price_monthly == Decimal("5.00")
|
||||
await plan.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_subscription_creation(test_user):
|
||||
subscription = await UserSubscription.create(
|
||||
user=test_user,
|
||||
billing_type="pay_as_you_go",
|
||||
status="active"
|
||||
user=test_user, billing_type="pay_as_you_go", status="active"
|
||||
)
|
||||
|
||||
assert subscription.user_id == test_user.id
|
||||
@ -49,6 +57,7 @@ async def test_user_subscription_creation(test_user):
|
||||
assert subscription.status == "active"
|
||||
await subscription.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_record_creation(test_user):
|
||||
usage = await UsageRecord.create(
|
||||
@ -57,7 +66,7 @@ async def test_usage_record_creation(test_user):
|
||||
amount_bytes=1024 * 1024 * 100,
|
||||
resource_type="file",
|
||||
resource_id=1,
|
||||
idempotency_key="test_key_123"
|
||||
idempotency_key="test_key_123",
|
||||
)
|
||||
|
||||
assert usage.user_id == test_user.id
|
||||
@ -65,6 +74,7 @@ async def test_usage_record_creation(test_user):
|
||||
assert usage.amount_bytes == 1024 * 1024 * 100
|
||||
await usage.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_usage_aggregate_creation(test_user):
|
||||
aggregate = await UsageAggregate.create(
|
||||
@ -73,13 +83,14 @@ async def test_usage_aggregate_creation(test_user):
|
||||
storage_bytes_avg=1024 * 1024 * 500,
|
||||
storage_bytes_peak=1024 * 1024 * 600,
|
||||
bandwidth_up_bytes=1024 * 1024 * 50,
|
||||
bandwidth_down_bytes=1024 * 1024 * 100
|
||||
bandwidth_down_bytes=1024 * 1024 * 100,
|
||||
)
|
||||
|
||||
assert aggregate.user_id == test_user.id
|
||||
assert aggregate.storage_bytes_avg == 1024 * 1024 * 500
|
||||
await aggregate.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoice_creation(test_user):
|
||||
invoice = await Invoice.create(
|
||||
@ -90,7 +101,7 @@ async def test_invoice_creation(test_user):
|
||||
subtotal=Decimal("10.00"),
|
||||
tax=Decimal("0.00"),
|
||||
total=Decimal("10.00"),
|
||||
status="draft"
|
||||
status="draft",
|
||||
)
|
||||
|
||||
assert invoice.user_id == test_user.id
|
||||
@ -98,6 +109,7 @@ async def test_invoice_creation(test_user):
|
||||
assert invoice.total == Decimal("10.00")
|
||||
await invoice.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoice_line_item_creation(test_user):
|
||||
invoice = await Invoice.create(
|
||||
@ -108,7 +120,7 @@ async def test_invoice_line_item_creation(test_user):
|
||||
subtotal=Decimal("10.00"),
|
||||
tax=Decimal("0.00"),
|
||||
total=Decimal("10.00"),
|
||||
status="draft"
|
||||
status="draft",
|
||||
)
|
||||
|
||||
line_item = await InvoiceLineItem.create(
|
||||
@ -117,7 +129,7 @@ async def test_invoice_line_item_creation(test_user):
|
||||
quantity=Decimal("100.000000"),
|
||||
unit_price=Decimal("0.100000"),
|
||||
amount=Decimal("10.0000"),
|
||||
item_type="storage"
|
||||
item_type="storage",
|
||||
)
|
||||
|
||||
assert line_item.invoice_id == invoice.id
|
||||
@ -126,6 +138,7 @@ async def test_invoice_line_item_creation(test_user):
|
||||
await line_item.delete()
|
||||
await invoice.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pricing_config_creation(test_user):
|
||||
config = await PricingConfig.create(
|
||||
@ -133,13 +146,14 @@ async def test_pricing_config_creation(test_user):
|
||||
config_value=Decimal("0.0045"),
|
||||
description="Storage cost per GB per month",
|
||||
unit="per_gb_month",
|
||||
updated_by=test_user
|
||||
updated_by=test_user,
|
||||
)
|
||||
|
||||
assert config.config_key == "storage_per_gb_month"
|
||||
assert config.config_value == Decimal("0.0045")
|
||||
await config.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_payment_method_creation(test_user):
|
||||
payment_method = await PaymentMethod.create(
|
||||
@ -150,7 +164,7 @@ async def test_payment_method_creation(test_user):
|
||||
last4="4242",
|
||||
brand="visa",
|
||||
exp_month=12,
|
||||
exp_year=2025
|
||||
exp_year=2025,
|
||||
)
|
||||
|
||||
assert payment_method.user_id == test_user.id
|
||||
@ -158,6 +172,7 @@ async def test_payment_method_creation(test_user):
|
||||
assert payment_method.is_default is True
|
||||
await payment_method.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_billing_event_creation(test_user):
|
||||
event = await BillingEvent.create(
|
||||
@ -165,7 +180,7 @@ async def test_billing_event_creation(test_user):
|
||||
event_type="invoice_created",
|
||||
stripe_event_id="evt_test_123",
|
||||
data={"invoice_id": 1},
|
||||
processed=False
|
||||
processed=False,
|
||||
)
|
||||
|
||||
assert event.user_id == test_user.id
|
||||
|
||||
@ -1,18 +1,35 @@
|
||||
import pytest
|
||||
|
||||
|
||||
def test_billing_module_imports():
|
||||
from mywebdav.billing import models, stripe_client, usage_tracker, invoice_generator, scheduler
|
||||
from mywebdav.billing import (
|
||||
models,
|
||||
stripe_client,
|
||||
usage_tracker,
|
||||
invoice_generator,
|
||||
scheduler,
|
||||
)
|
||||
|
||||
assert models is not None
|
||||
assert stripe_client is not None
|
||||
assert usage_tracker is not None
|
||||
assert invoice_generator is not None
|
||||
assert scheduler is not None
|
||||
|
||||
|
||||
def test_billing_models_exist():
|
||||
from mywebdav.billing.models import (
|
||||
SubscriptionPlan, UserSubscription, UsageRecord, UsageAggregate,
|
||||
Invoice, InvoiceLineItem, PricingConfig, PaymentMethod, BillingEvent
|
||||
SubscriptionPlan,
|
||||
UserSubscription,
|
||||
UsageRecord,
|
||||
UsageAggregate,
|
||||
Invoice,
|
||||
InvoiceLineItem,
|
||||
PricingConfig,
|
||||
PaymentMethod,
|
||||
BillingEvent,
|
||||
)
|
||||
|
||||
assert SubscriptionPlan is not None
|
||||
assert UserSubscription is not None
|
||||
assert UsageRecord is not None
|
||||
@ -23,55 +40,71 @@ def test_billing_models_exist():
|
||||
assert PaymentMethod is not None
|
||||
assert BillingEvent is not None
|
||||
|
||||
|
||||
def test_stripe_client_exists():
|
||||
from mywebdav.billing.stripe_client import StripeClient
|
||||
|
||||
assert StripeClient is not None
|
||||
assert hasattr(StripeClient, 'create_customer')
|
||||
assert hasattr(StripeClient, 'create_invoice')
|
||||
assert hasattr(StripeClient, 'finalize_invoice')
|
||||
assert hasattr(StripeClient, "create_customer")
|
||||
assert hasattr(StripeClient, "create_invoice")
|
||||
assert hasattr(StripeClient, "finalize_invoice")
|
||||
|
||||
|
||||
def test_usage_tracker_exists():
|
||||
from mywebdav.billing.usage_tracker import UsageTracker
|
||||
|
||||
assert UsageTracker is not None
|
||||
assert hasattr(UsageTracker, 'track_storage')
|
||||
assert hasattr(UsageTracker, 'track_bandwidth')
|
||||
assert hasattr(UsageTracker, 'aggregate_daily_usage')
|
||||
assert hasattr(UsageTracker, 'get_current_storage')
|
||||
assert hasattr(UsageTracker, 'get_monthly_usage')
|
||||
assert hasattr(UsageTracker, "track_storage")
|
||||
assert hasattr(UsageTracker, "track_bandwidth")
|
||||
assert hasattr(UsageTracker, "aggregate_daily_usage")
|
||||
assert hasattr(UsageTracker, "get_current_storage")
|
||||
assert hasattr(UsageTracker, "get_monthly_usage")
|
||||
|
||||
|
||||
def test_invoice_generator_exists():
|
||||
from mywebdav.billing.invoice_generator import InvoiceGenerator
|
||||
|
||||
assert InvoiceGenerator is not None
|
||||
assert hasattr(InvoiceGenerator, 'generate_monthly_invoice')
|
||||
assert hasattr(InvoiceGenerator, 'finalize_invoice')
|
||||
assert hasattr(InvoiceGenerator, 'mark_invoice_paid')
|
||||
assert hasattr(InvoiceGenerator, "generate_monthly_invoice")
|
||||
assert hasattr(InvoiceGenerator, "finalize_invoice")
|
||||
assert hasattr(InvoiceGenerator, "mark_invoice_paid")
|
||||
|
||||
|
||||
def test_scheduler_exists():
|
||||
from mywebdav.billing.scheduler import scheduler, start_scheduler, stop_scheduler
|
||||
|
||||
assert scheduler is not None
|
||||
assert callable(start_scheduler)
|
||||
assert callable(stop_scheduler)
|
||||
|
||||
|
||||
def test_routers_exist():
|
||||
from mywebdav.routers import billing, admin_billing
|
||||
|
||||
assert billing is not None
|
||||
assert admin_billing is not None
|
||||
assert hasattr(billing, 'router')
|
||||
assert hasattr(admin_billing, 'router')
|
||||
assert hasattr(billing, "router")
|
||||
assert hasattr(admin_billing, "router")
|
||||
|
||||
|
||||
def test_middleware_exists():
|
||||
from mywebdav.middleware.usage_tracking import UsageTrackingMiddleware
|
||||
|
||||
assert UsageTrackingMiddleware is not None
|
||||
|
||||
|
||||
def test_settings_updated():
|
||||
from mywebdav.settings import settings
|
||||
assert hasattr(settings, 'STRIPE_SECRET_KEY')
|
||||
assert hasattr(settings, 'STRIPE_PUBLISHABLE_KEY')
|
||||
assert hasattr(settings, 'STRIPE_WEBHOOK_SECRET')
|
||||
assert hasattr(settings, 'BILLING_ENABLED')
|
||||
|
||||
assert hasattr(settings, "STRIPE_SECRET_KEY")
|
||||
assert hasattr(settings, "STRIPE_PUBLISHABLE_KEY")
|
||||
assert hasattr(settings, "STRIPE_WEBHOOK_SECRET")
|
||||
assert hasattr(settings, "BILLING_ENABLED")
|
||||
|
||||
|
||||
def test_main_includes_billing():
|
||||
from mywebdav.main import app
|
||||
|
||||
routes = [route.path for route in app.routes]
|
||||
billing_routes = [r for r in routes if '/billing' in r]
|
||||
billing_routes = [r for r in routes if "/billing" in r]
|
||||
assert len(billing_routes) > 0
|
||||
|
||||
@ -6,24 +6,26 @@ from mywebdav.models import User, File, Folder
|
||||
from mywebdav.billing.models import UsageRecord, UsageAggregate
|
||||
from mywebdav.billing.usage_tracker import UsageTracker
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user():
|
||||
user = await User.create(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password="hashed_password_here",
|
||||
is_active=True
|
||||
is_active=True,
|
||||
)
|
||||
yield user
|
||||
await user.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_storage(test_user):
|
||||
await UsageTracker.track_storage(
|
||||
user=test_user,
|
||||
amount_bytes=1024 * 1024 * 100,
|
||||
resource_type="file",
|
||||
resource_id=1
|
||||
resource_id=1,
|
||||
)
|
||||
|
||||
records = await UsageRecord.filter(user=test_user, record_type="storage").all()
|
||||
@ -32,6 +34,7 @@ async def test_track_storage(test_user):
|
||||
|
||||
await records[0].delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_bandwidth_upload(test_user):
|
||||
await UsageTracker.track_bandwidth(
|
||||
@ -39,7 +42,7 @@ async def test_track_bandwidth_upload(test_user):
|
||||
amount_bytes=1024 * 1024 * 50,
|
||||
direction="up",
|
||||
resource_type="file",
|
||||
resource_id=1
|
||||
resource_id=1,
|
||||
)
|
||||
|
||||
records = await UsageRecord.filter(user=test_user, record_type="bandwidth_up").all()
|
||||
@ -48,6 +51,7 @@ async def test_track_bandwidth_upload(test_user):
|
||||
|
||||
await records[0].delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_track_bandwidth_download(test_user):
|
||||
await UsageTracker.track_bandwidth(
|
||||
@ -55,32 +59,28 @@ async def test_track_bandwidth_download(test_user):
|
||||
amount_bytes=1024 * 1024 * 75,
|
||||
direction="down",
|
||||
resource_type="file",
|
||||
resource_id=1
|
||||
resource_id=1,
|
||||
)
|
||||
|
||||
records = await UsageRecord.filter(user=test_user, record_type="bandwidth_down").all()
|
||||
records = await UsageRecord.filter(
|
||||
user=test_user, record_type="bandwidth_down"
|
||||
).all()
|
||||
assert len(records) == 1
|
||||
assert records[0].amount_bytes == 1024 * 1024 * 75
|
||||
|
||||
await records[0].delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregate_daily_usage(test_user):
|
||||
await UsageTracker.track_storage(
|
||||
user=test_user,
|
||||
amount_bytes=1024 * 1024 * 100
|
||||
await UsageTracker.track_storage(user=test_user, amount_bytes=1024 * 1024 * 100)
|
||||
|
||||
await UsageTracker.track_bandwidth(
|
||||
user=test_user, amount_bytes=1024 * 1024 * 50, direction="up"
|
||||
)
|
||||
|
||||
await UsageTracker.track_bandwidth(
|
||||
user=test_user,
|
||||
amount_bytes=1024 * 1024 * 50,
|
||||
direction="up"
|
||||
)
|
||||
|
||||
await UsageTracker.track_bandwidth(
|
||||
user=test_user,
|
||||
amount_bytes=1024 * 1024 * 75,
|
||||
direction="down"
|
||||
user=test_user, amount_bytes=1024 * 1024 * 75, direction="down"
|
||||
)
|
||||
|
||||
aggregate = await UsageTracker.aggregate_daily_usage(test_user, date.today())
|
||||
@ -93,12 +93,10 @@ async def test_aggregate_daily_usage(test_user):
|
||||
await aggregate.delete()
|
||||
await UsageRecord.filter(user=test_user).delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_storage(test_user):
|
||||
folder = await Folder.create(
|
||||
name="Test Folder",
|
||||
owner=test_user
|
||||
)
|
||||
folder = await Folder.create(name="Test Folder", owner=test_user)
|
||||
|
||||
file1 = await File.create(
|
||||
name="test1.txt",
|
||||
@ -107,7 +105,7 @@ async def test_get_current_storage(test_user):
|
||||
mime_type="text/plain",
|
||||
owner=test_user,
|
||||
parent=folder,
|
||||
is_deleted=False
|
||||
is_deleted=False,
|
||||
)
|
||||
|
||||
file2 = await File.create(
|
||||
@ -117,7 +115,7 @@ async def test_get_current_storage(test_user):
|
||||
mime_type="text/plain",
|
||||
owner=test_user,
|
||||
parent=folder,
|
||||
is_deleted=False
|
||||
is_deleted=False,
|
||||
)
|
||||
|
||||
storage = await UsageTracker.get_current_storage(test_user)
|
||||
@ -128,6 +126,7 @@ async def test_get_current_storage(test_user):
|
||||
await file2.delete()
|
||||
await folder.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_monthly_usage(test_user):
|
||||
today = date.today()
|
||||
@ -138,7 +137,7 @@ async def test_get_monthly_usage(test_user):
|
||||
storage_bytes_avg=1024 * 1024 * 1024 * 10,
|
||||
storage_bytes_peak=1024 * 1024 * 1024 * 12,
|
||||
bandwidth_up_bytes=1024 * 1024 * 1024 * 2,
|
||||
bandwidth_down_bytes=1024 * 1024 * 1024 * 5
|
||||
bandwidth_down_bytes=1024 * 1024 * 1024 * 5,
|
||||
)
|
||||
|
||||
usage = await UsageTracker.get_monthly_usage(test_user, today.year, today.month)
|
||||
@ -150,11 +149,14 @@ async def test_get_monthly_usage(test_user):
|
||||
|
||||
await aggregate.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_monthly_usage_empty(test_user):
|
||||
future_date = date.today() + timedelta(days=365)
|
||||
|
||||
usage = await UsageTracker.get_monthly_usage(test_user, future_date.year, future_date.month)
|
||||
usage = await UsageTracker.get_monthly_usage(
|
||||
test_user, future_date.year, future_date.month
|
||||
)
|
||||
|
||||
assert usage["storage_gb_avg"] == 0
|
||||
assert usage["storage_gb_peak"] == 0
|
||||
|
||||
@ -3,27 +3,21 @@ import pytest_asyncio
|
||||
import asyncio
|
||||
from tortoise import Tortoise
|
||||
|
||||
# Initialize Tortoise at module level
|
||||
async def _init_tortoise():
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def init_db():
|
||||
"""Initialize the database for testing."""
|
||||
await Tortoise.init(
|
||||
db_url="sqlite://:memory:",
|
||||
modules={
|
||||
"models": ["mywebdav.models"],
|
||||
"billing": ["mywebdav.billing.models"]
|
||||
}
|
||||
modules={"models": ["mywebdav.models"], "billing": ["mywebdav.billing.models"]},
|
||||
)
|
||||
await Tortoise.generate_schemas()
|
||||
|
||||
# Run initialization
|
||||
asyncio.run(_init_tortoise())
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def cleanup_db():
|
||||
yield
|
||||
await Tortoise.close_connections()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def setup_db(init_db):
|
||||
"""Ensure database is initialized before each test."""
|
||||
# The init_db fixture handles the setup
|
||||
yield
|
||||
|
||||
Loading…
Reference in New Issue
Block a user