From 25201421167304ce949d875b0022a6a31c9ba83e Mon Sep 17 00:00:00 2001 From: retoor Date: Sat, 29 Nov 2025 12:18:53 +0100 Subject: [PATCH] Update --- mywebdav/activity.py | 51 ++++- mywebdav/billing/usage_tracker.py | 91 ++++++-- mywebdav/enterprise/__init__.py | 13 ++ mywebdav/enterprise/dal.py | 335 ++++++++++++++++++++++++++++ mywebdav/enterprise/write_buffer.py | 178 +++++++++++++++ mywebdav/main.py | 25 +++ mywebdav/webdav.py | 121 +++++++--- 7 files changed, 753 insertions(+), 61 deletions(-) create mode 100644 mywebdav/enterprise/__init__.py create mode 100644 mywebdav/enterprise/dal.py create mode 100644 mywebdav/enterprise/write_buffer.py diff --git a/mywebdav/activity.py b/mywebdav/activity.py index 8537e56..771403a 100644 --- a/mywebdav/activity.py +++ b/mywebdav/activity.py @@ -1,5 +1,26 @@ -from typing import Optional +from typing import Optional, List, Dict from .models import Activity, User +import logging + +logger = logging.getLogger(__name__) + + +async def _flush_activities(records: List[Dict]): + from tortoise.transactions import in_transaction + try: + async with in_transaction(): + for record in records: + await Activity.create( + user_id=record.get("user_id"), + action=record["action"], + target_type=record["target_type"], + target_id=record["target_id"], + ip_address=record.get("ip_address"), + ) + logger.debug(f"Flushed {len(records)} activity records") + except Exception as e: + logger.error(f"Failed to flush activity records: {e}") + raise async def log_activity( @@ -9,10 +30,24 @@ async def log_activity( target_id: int, ip_address: Optional[str] = None, ): - await Activity.create( - user=user, - action=action, - target_type=target_type, - target_id=target_id, - ip_address=ip_address, - ) + try: + from .enterprise.write_buffer import get_write_buffer, WriteType + buffer = get_write_buffer() + await buffer.buffer( + WriteType.ACTIVITY, + { + "user_id": user.id if user else None, + "action": action, + "target_type": target_type, + "target_id": target_id, + "ip_address": ip_address, + }, + ) + except RuntimeError: + await Activity.create( + user=user, + action=action, + target_type=target_type, + target_id=target_id, + ip_address=ip_address, + ) diff --git a/mywebdav/billing/usage_tracker.py b/mywebdav/billing/usage_tracker.py index 9433a1d..c4723ac 100644 --- a/mywebdav/billing/usage_tracker.py +++ b/mywebdav/billing/usage_tracker.py @@ -1,11 +1,34 @@ import uuid +import logging from datetime import datetime, date, timezone, timedelta +from typing import List, Dict from tortoise.transactions import in_transaction from .models import UsageRecord, UsageAggregate from ..models import User +logger = logging.getLogger(__name__) + + +async def _flush_usage_records(records: List[Dict]): + try: + async with in_transaction(): + for record in records: + await UsageRecord.create( + user_id=record["user_id"], + record_type=record["record_type"], + amount_bytes=record["amount_bytes"], + resource_type=record.get("resource_type"), + resource_id=record.get("resource_id"), + idempotency_key=record["idempotency_key"], + metadata=record.get("metadata"), + ) + logger.debug(f"Flushed {len(records)} usage records") + except Exception as e: + logger.error(f"Failed to flush usage records: {e}") + raise + class UsageTracker: @staticmethod @@ -18,15 +41,31 @@ class UsageTracker: ): idempotency_key = f"storage_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}" - await UsageRecord.create( - user=user, - record_type="storage", - amount_bytes=amount_bytes, - resource_type=resource_type, - resource_id=resource_id, - idempotency_key=idempotency_key, - metadata=metadata, - ) + try: + from ..enterprise.write_buffer import get_write_buffer, WriteType + buffer = get_write_buffer() + await buffer.buffer( + WriteType.USAGE_RECORD, + { + "user_id": user.id, + "record_type": "storage", + "amount_bytes": amount_bytes, + "resource_type": resource_type, + "resource_id": resource_id, + "idempotency_key": idempotency_key, + "metadata": metadata, + }, + ) + except RuntimeError: + await UsageRecord.create( + user=user, + record_type="storage", + amount_bytes=amount_bytes, + resource_type=resource_type, + resource_id=resource_id, + idempotency_key=idempotency_key, + metadata=metadata, + ) @staticmethod async def track_bandwidth( @@ -40,15 +79,31 @@ class UsageTracker: record_type = f"bandwidth_{direction}" idempotency_key = f"{record_type}_{user.id}_{datetime.now(timezone.utc).timestamp()}_{uuid.uuid4().hex[:8]}" - await UsageRecord.create( - user=user, - record_type=record_type, - amount_bytes=amount_bytes, - resource_type=resource_type, - resource_id=resource_id, - idempotency_key=idempotency_key, - metadata=metadata, - ) + try: + from ..enterprise.write_buffer import get_write_buffer, WriteType + buffer = get_write_buffer() + await buffer.buffer( + WriteType.USAGE_RECORD, + { + "user_id": user.id, + "record_type": record_type, + "amount_bytes": amount_bytes, + "resource_type": resource_type, + "resource_id": resource_id, + "idempotency_key": idempotency_key, + "metadata": metadata, + }, + ) + except RuntimeError: + await UsageRecord.create( + user=user, + record_type=record_type, + amount_bytes=amount_bytes, + resource_type=resource_type, + resource_id=resource_id, + idempotency_key=idempotency_key, + metadata=metadata, + ) @staticmethod async def aggregate_daily_usage(user: User, target_date: date = None): diff --git a/mywebdav/enterprise/__init__.py b/mywebdav/enterprise/__init__.py new file mode 100644 index 0000000..db8d597 --- /dev/null +++ b/mywebdav/enterprise/__init__.py @@ -0,0 +1,13 @@ +from .dal import DataAccessLayer, get_dal, init_dal, shutdown_dal +from .write_buffer import WriteBuffer, get_write_buffer, init_write_buffer, shutdown_write_buffer + +__all__ = [ + "DataAccessLayer", + "get_dal", + "init_dal", + "shutdown_dal", + "WriteBuffer", + "get_write_buffer", + "init_write_buffer", + "shutdown_write_buffer", +] diff --git a/mywebdav/enterprise/dal.py b/mywebdav/enterprise/dal.py new file mode 100644 index 0000000..3982bae --- /dev/null +++ b/mywebdav/enterprise/dal.py @@ -0,0 +1,335 @@ +import asyncio +import hashlib +import time +from typing import Dict, List, Any, Optional, Tuple, Union +from dataclasses import dataclass, field +from collections import OrderedDict +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class CacheEntry: + value: Any + created_at: float = field(default_factory=time.time) + ttl: float = 300.0 + access_count: int = 0 + + @property + def is_expired(self) -> bool: + return time.time() - self.created_at > self.ttl + + +class DALCache: + def __init__(self, maxsize: int = 50000): + self.maxsize = maxsize + self._data: OrderedDict[str, CacheEntry] = OrderedDict() + self._lock = asyncio.Lock() + self._stats = {"hits": 0, "misses": 0, "evictions": 0} + + async def get(self, key: str) -> Optional[Any]: + async with self._lock: + if key in self._data: + entry = self._data[key] + if entry.is_expired: + del self._data[key] + self._stats["misses"] += 1 + return None + entry.access_count += 1 + self._data.move_to_end(key) + self._stats["hits"] += 1 + return entry.value + self._stats["misses"] += 1 + return None + + async def set(self, key: str, value: Any, ttl: float = 300.0): + async with self._lock: + if len(self._data) >= self.maxsize: + oldest_key = next(iter(self._data)) + del self._data[oldest_key] + self._stats["evictions"] += 1 + self._data[key] = CacheEntry(value=value, ttl=ttl) + + async def delete(self, key: str): + async with self._lock: + self._data.pop(key, None) + + async def invalidate_prefix(self, prefix: str) -> int: + async with self._lock: + keys_to_delete = [k for k in self._data.keys() if k.startswith(prefix)] + for k in keys_to_delete: + del self._data[k] + return len(keys_to_delete) + + async def clear(self): + async with self._lock: + self._data.clear() + + def get_stats(self) -> Dict[str, Any]: + total = self._stats["hits"] + self._stats["misses"] + hit_rate = (self._stats["hits"] / total * 100) if total > 0 else 0 + return { + **self._stats, + "size": len(self._data), + "hit_rate": round(hit_rate, 2), + } + + +class DataAccessLayer: + TTL_USER = 600.0 + TTL_FOLDER = 120.0 + TTL_FILE = 120.0 + TTL_PATH = 60.0 + TTL_CONTENTS = 30.0 + + def __init__(self, cache_size: int = 50000): + self._cache = DALCache(maxsize=cache_size) + self._path_cache: Dict[str, Tuple[Any, Any, bool]] = {} + self._path_lock = asyncio.Lock() + + async def start(self): + logger.info("DataAccessLayer started") + + async def stop(self): + await self._cache.clear() + async with self._path_lock: + self._path_cache.clear() + logger.info("DataAccessLayer stopped") + + def _user_key(self, user_id: int) -> str: + return f"user:{user_id}" + + def _user_by_name_key(self, username: str) -> str: + return f"user:name:{username}" + + def _folder_key(self, user_id: int, folder_id: Optional[int]) -> str: + return f"u:{user_id}:f:{folder_id or 'root'}" + + def _file_key(self, user_id: int, file_id: int) -> str: + return f"u:{user_id}:file:{file_id}" + + def _folder_contents_key(self, user_id: int, folder_id: Optional[int]) -> str: + return f"u:{user_id}:contents:{folder_id or 'root'}" + + def _path_key(self, user_id: int, path: str) -> str: + path_hash = hashlib.md5(path.encode()).hexdigest()[:12] + return f"u:{user_id}:path:{path_hash}" + + async def get_user_by_id(self, user_id: int) -> Optional[Any]: + key = self._user_key(user_id) + cached = await self._cache.get(key) + if cached is not None: + return cached + + from ..models import User + user = await User.get_or_none(id=user_id) + if user: + await self._cache.set(key, user, ttl=self.TTL_USER) + await self._cache.set(self._user_by_name_key(user.username), user, ttl=self.TTL_USER) + return user + + async def get_user_by_username(self, username: str) -> Optional[Any]: + key = self._user_by_name_key(username) + cached = await self._cache.get(key) + if cached is not None: + return cached + + from ..models import User + user = await User.get_or_none(username=username) + if user: + await self._cache.set(key, user, ttl=self.TTL_USER) + await self._cache.set(self._user_key(user.id), user, ttl=self.TTL_USER) + return user + + async def invalidate_user(self, user_id: int, username: Optional[str] = None): + await self._cache.delete(self._user_key(user_id)) + if username: + await self._cache.delete(self._user_by_name_key(username)) + await self._cache.invalidate_prefix(f"u:{user_id}:") + + async def get_folder( + self, + user_id: int, + folder_id: Optional[int] = None, + name: Optional[str] = None, + parent_id: Optional[int] = None, + ) -> Optional[Any]: + if folder_id: + key = self._folder_key(user_id, folder_id) + cached = await self._cache.get(key) + if cached is not None: + return cached + + from ..models import Folder + + if folder_id: + folder = await Folder.get_or_none(id=folder_id, owner_id=user_id, is_deleted=False) + elif name is not None: + folder = await Folder.get_or_none( + name=name, parent_id=parent_id, owner_id=user_id, is_deleted=False + ) + else: + return None + + if folder: + await self._cache.set(self._folder_key(user_id, folder.id), folder, ttl=self.TTL_FOLDER) + return folder + + async def get_file( + self, + user_id: int, + file_id: Optional[int] = None, + name: Optional[str] = None, + parent_id: Optional[int] = None, + ) -> Optional[Any]: + if file_id: + key = self._file_key(user_id, file_id) + cached = await self._cache.get(key) + if cached is not None: + return cached + + from ..models import File + + if file_id: + file = await File.get_or_none(id=file_id, owner_id=user_id, is_deleted=False) + elif name is not None: + file = await File.get_or_none( + name=name, parent_id=parent_id, owner_id=user_id, is_deleted=False + ) + else: + return None + + if file: + await self._cache.set(self._file_key(user_id, file.id), file, ttl=self.TTL_FILE) + return file + + async def get_folder_contents( + self, + user_id: int, + folder_id: Optional[int] = None, + ) -> Tuple[List[Any], List[Any]]: + key = self._folder_contents_key(user_id, folder_id) + cached = await self._cache.get(key) + if cached is not None: + return cached + + from ..models import Folder, File + + folders = await Folder.filter( + owner_id=user_id, parent_id=folder_id, is_deleted=False + ).all() + files = await File.filter( + owner_id=user_id, parent_id=folder_id, is_deleted=False + ).all() + + result = (folders, files) + await self._cache.set(key, result, ttl=self.TTL_CONTENTS) + return result + + async def resolve_path( + self, + user_id: int, + path_str: str, + ) -> Tuple[Optional[Any], Optional[Any], bool]: + path_str = path_str.strip("/") + if not path_str: + return None, None, True + + key = self._path_key(user_id, path_str) + cached = await self._cache.get(key) + if cached is not None: + return cached + + from ..models import Folder, File + + parts = [p for p in path_str.split("/") if p] + current_folder = None + current_folder_id = None + + for i, part in enumerate(parts[:-1]): + folder = await Folder.get_or_none( + name=part, parent_id=current_folder_id, owner_id=user_id, is_deleted=False + ) + if not folder: + result = (None, None, False) + await self._cache.set(key, result, ttl=self.TTL_PATH) + return result + current_folder = folder + current_folder_id = folder.id + await self._cache.set( + self._folder_key(user_id, folder.id), folder, ttl=self.TTL_FOLDER + ) + + last_part = parts[-1] + + folder = await Folder.get_or_none( + name=last_part, parent_id=current_folder_id, owner_id=user_id, is_deleted=False + ) + if folder: + await self._cache.set( + self._folder_key(user_id, folder.id), folder, ttl=self.TTL_FOLDER + ) + result = (folder, current_folder, True) + await self._cache.set(key, result, ttl=self.TTL_PATH) + return result + + file = await File.get_or_none( + name=last_part, parent_id=current_folder_id, owner_id=user_id, is_deleted=False + ) + if file: + await self._cache.set( + self._file_key(user_id, file.id), file, ttl=self.TTL_FILE + ) + result = (file, current_folder, True) + await self._cache.set(key, result, ttl=self.TTL_PATH) + return result + + result = (None, current_folder, False) + await self._cache.set(key, result, ttl=self.TTL_PATH) + return result + + async def invalidate_folder(self, user_id: int, folder_id: Optional[int] = None): + await self._cache.delete(self._folder_key(user_id, folder_id)) + await self._cache.delete(self._folder_contents_key(user_id, folder_id)) + await self._cache.invalidate_prefix(f"u:{user_id}:path:") + + async def invalidate_file(self, user_id: int, file_id: int, parent_id: Optional[int] = None): + await self._cache.delete(self._file_key(user_id, file_id)) + if parent_id is not None: + await self._cache.delete(self._folder_contents_key(user_id, parent_id)) + await self._cache.invalidate_prefix(f"u:{user_id}:path:") + + async def invalidate_path(self, user_id: int, path: str): + key = self._path_key(user_id, path) + await self._cache.delete(key) + + async def invalidate_user_paths(self, user_id: int): + await self._cache.invalidate_prefix(f"u:{user_id}:path:") + await self._cache.invalidate_prefix(f"u:{user_id}:contents:") + + def get_stats(self) -> Dict[str, Any]: + return self._cache.get_stats() + + +_dal: Optional[DataAccessLayer] = None + + +async def init_dal(cache_size: int = 50000) -> DataAccessLayer: + global _dal + _dal = DataAccessLayer(cache_size=cache_size) + await _dal.start() + return _dal + + +async def shutdown_dal(): + global _dal + if _dal: + await _dal.stop() + _dal = None + + +def get_dal() -> DataAccessLayer: + if not _dal: + raise RuntimeError("DataAccessLayer not initialized") + return _dal diff --git a/mywebdav/enterprise/write_buffer.py b/mywebdav/enterprise/write_buffer.py new file mode 100644 index 0000000..7109aeb --- /dev/null +++ b/mywebdav/enterprise/write_buffer.py @@ -0,0 +1,178 @@ +import asyncio +import time +import uuid +from typing import Dict, List, Any, Optional, Callable, Awaitable +from dataclasses import dataclass, field +from enum import Enum +from collections import defaultdict +import logging + +logger = logging.getLogger(__name__) + + +class WriteType(Enum): + ACTIVITY = "activity" + USAGE_RECORD = "usage_record" + FILE_ACCESS = "file_access" + WEBDAV_PROPERTY = "webdav_property" + + +@dataclass +class BufferedWrite: + write_type: WriteType + data: Dict[str, Any] + created_at: float = field(default_factory=time.time) + priority: int = 0 + + +class WriteBuffer: + def __init__( + self, + flush_interval: float = 60.0, + max_buffer_size: int = 1000, + immediate_types: Optional[set] = None, + ): + self.flush_interval = flush_interval + self.max_buffer_size = max_buffer_size + self.immediate_types = immediate_types or set() + self._buffers: Dict[WriteType, List[BufferedWrite]] = defaultdict(list) + self._lock = asyncio.Lock() + self._flush_task: Optional[asyncio.Task] = None + self._running = False + self._handlers: Dict[WriteType, Callable[[List[Dict]], Awaitable[None]]] = {} + self._stats = { + "buffered": 0, + "flushed": 0, + "immediate": 0, + "errors": 0, + } + + async def start(self): + self._running = True + self._flush_task = asyncio.create_task(self._background_flusher()) + logger.info(f"WriteBuffer started (interval={self.flush_interval}s)") + + async def stop(self): + self._running = False + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + await self.flush_all() + logger.info("WriteBuffer stopped") + + def register_handler( + self, + write_type: WriteType, + handler: Callable[[List[Dict]], Awaitable[None]], + ): + self._handlers[write_type] = handler + logger.debug(f"Registered write handler for {write_type.value}") + + async def buffer( + self, + write_type: WriteType, + data: Dict[str, Any], + priority: int = 0, + ): + if write_type in self.immediate_types: + await self._execute_immediate(write_type, data) + return + + async with self._lock: + self._buffers[write_type].append( + BufferedWrite(write_type=write_type, data=data, priority=priority) + ) + self._stats["buffered"] += 1 + + if len(self._buffers[write_type]) >= self.max_buffer_size: + await self._flush_type(write_type) + + async def _execute_immediate(self, write_type: WriteType, data: Dict[str, Any]): + handler = self._handlers.get(write_type) + if handler: + try: + await handler([data]) + self._stats["immediate"] += 1 + except Exception as e: + logger.error(f"Immediate write failed for {write_type.value}: {e}") + self._stats["errors"] += 1 + else: + logger.warning(f"No handler for immediate write type: {write_type.value}") + + async def _flush_type(self, write_type: WriteType): + if not self._buffers[write_type]: + return + + writes = self._buffers[write_type] + self._buffers[write_type] = [] + + handler = self._handlers.get(write_type) + if handler: + try: + data_list = [w.data for w in writes] + await handler(data_list) + self._stats["flushed"] += len(writes) + logger.debug(f"Flushed {len(writes)} {write_type.value} writes") + except Exception as e: + logger.error(f"Flush failed for {write_type.value}: {e}") + self._stats["errors"] += len(writes) + async with self._lock: + self._buffers[write_type].extend(writes) + + async def flush_all(self): + async with self._lock: + types_to_flush = list(self._buffers.keys()) + + for write_type in types_to_flush: + async with self._lock: + await self._flush_type(write_type) + + async def _background_flusher(self): + while self._running: + try: + await asyncio.sleep(self.flush_interval) + await self.flush_all() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in write buffer flusher: {e}") + + def get_stats(self) -> Dict[str, Any]: + buffer_sizes = {t.value: len(b) for t, b in self._buffers.items()} + return { + **self._stats, + "buffer_sizes": buffer_sizes, + "total_buffered": sum(buffer_sizes.values()), + } + + +_write_buffer: Optional[WriteBuffer] = None + + +async def init_write_buffer( + flush_interval: float = 60.0, + max_buffer_size: int = 1000, +) -> WriteBuffer: + global _write_buffer + _write_buffer = WriteBuffer( + flush_interval=flush_interval, + max_buffer_size=max_buffer_size, + ) + await _write_buffer.start() + return _write_buffer + + +async def shutdown_write_buffer(): + global _write_buffer + if _write_buffer: + await _write_buffer.stop() + _write_buffer = None + + +def get_write_buffer() -> WriteBuffer: + if not _write_buffer: + raise RuntimeError("WriteBuffer not initialized") + return _write_buffer diff --git a/mywebdav/main.py b/mywebdav/main.py index 7010d78..25c7ba5 100644 --- a/mywebdav/main.py +++ b/mywebdav/main.py @@ -38,6 +38,8 @@ enterprise_components = { "rate_limiter": None, "webdav_locks": None, "task_queue": None, + "dal": None, + "write_buffer": None, } @@ -50,6 +52,8 @@ async def init_enterprise_components(): from .auth_tokens import init_token_manager from .middleware.rate_limit import init_rate_limiter from .workers.queue import init_task_queue + from .enterprise.dal import init_dal + from .enterprise.write_buffer import init_write_buffer, WriteType enterprise_components["db_manager"] = await init_db_manager("data") logger.info("Enterprise: Per-user database manager initialized") @@ -57,6 +61,19 @@ async def init_enterprise_components(): enterprise_components["cache"] = await init_cache(maxsize=10000, flush_interval=30) logger.info("Enterprise: Cache layer initialized") + enterprise_components["dal"] = await init_dal(cache_size=50000) + logger.info("Enterprise: Data Access Layer initialized") + + enterprise_components["write_buffer"] = await init_write_buffer( + flush_interval=60.0, + max_buffer_size=1000, + ) + from .activity import _flush_activities + from .billing.usage_tracker import _flush_usage_records + enterprise_components["write_buffer"].register_handler(WriteType.ACTIVITY, _flush_activities) + enterprise_components["write_buffer"].register_handler(WriteType.USAGE_RECORD, _flush_usage_records) + logger.info("Enterprise: Write buffer initialized (60s flush interval)") + enterprise_components["lock_manager"] = await init_lock_manager(default_timeout=30.0) logger.info("Enterprise: Lock manager initialized") @@ -88,6 +105,8 @@ async def shutdown_enterprise_components(): from .auth_tokens import shutdown_token_manager from .middleware.rate_limit import shutdown_rate_limiter from .workers.queue import shutdown_task_queue + from .enterprise.dal import shutdown_dal + from .enterprise.write_buffer import shutdown_write_buffer await shutdown_task_queue() logger.info("Enterprise: Task queue stopped") @@ -104,6 +123,12 @@ async def shutdown_enterprise_components(): await shutdown_lock_manager() logger.info("Enterprise: Lock manager stopped") + await shutdown_write_buffer() + logger.info("Enterprise: Write buffer stopped (flushed pending writes)") + + await shutdown_dal() + logger.info("Enterprise: Data Access Layer stopped") + await shutdown_cache() logger.info("Enterprise: Cache layer stopped") diff --git a/mywebdav/webdav.py b/mywebdav/webdav.py index 22082ae..d71b4a7 100644 --- a/mywebdav/webdav.py +++ b/mywebdav/webdav.py @@ -75,7 +75,13 @@ async def basic_auth(authorization: Optional[str] = Header(None)): decoded = base64.b64decode(credentials).decode("utf-8") username, password = decoded.split(":", 1) - user = await User.get_or_none(username=username) + try: + from .enterprise.dal import get_dal + dal = get_dal() + user = await dal.get_user_by_username(username) + except RuntimeError: + user = await User.get_or_none(username=username) + if user and verify_password(password, user.hashed_password): return user except (ValueError, UnicodeDecodeError, base64.binascii.Error): @@ -85,12 +91,10 @@ async def basic_auth(authorization: Optional[str] = Header(None)): async def webdav_auth(request: Request, authorization: Optional[str] = Header(None)): - # First, try Basic Auth, which is common for WebDAV clients user = await basic_auth(authorization) if user: return user - # If Basic Auth fails or is not provided, try to authenticate using the session cookie token = request.cookies.get("access_token") if token: try: @@ -99,14 +103,17 @@ async def webdav_auth(request: Request, authorization: Optional[str] = Header(No ) username: str = payload.get("sub") if username: - user = await User.get_or_none(username=username) + try: + from .enterprise.dal import get_dal + dal = get_dal() + user = await dal.get_user_by_username(username) + except RuntimeError: + user = await User.get_or_none(username=username) if user: return user except JWTError: - # Token is invalid, fall through to the final exception pass - # If all authentication methods fail, raise 401 raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, headers={"WWW-Authenticate": 'Basic realm="MyWebdav WebDAV"'}, @@ -114,34 +121,29 @@ async def webdav_auth(request: Request, authorization: Optional[str] = Header(No async def resolve_path(path_str: str, user: User): - """ - Resolves a path string to a resource. - Returns a tuple: (resource, parent_folder, exists) - - resource: The File or Folder object at the path, or None if not found. - - parent_folder: The parent Folder object, or None if root. - - exists: Boolean indicating if the resource at the given path exists. - """ + try: + from .enterprise.dal import get_dal + dal = get_dal() + return await dal.resolve_path(user.id, path_str) + except RuntimeError: + pass + path_str = path_str.strip("/") - if not path_str: # Root directory - # The root exists conceptually, but has no specific resource object. - # It contains top-level files and folders. + if not path_str: return None, None, True parts = [p for p in path_str.split("/") if p] current_folder = None - # Traverse the path to find the parent of the target resource for i, part in enumerate(parts[:-1]): folder = await Folder.get_or_none( name=part, parent=current_folder, owner=user, is_deleted=False ) if not folder: - # A component in the middle of the path does not exist, so the full path cannot exist. return None, None, False current_folder = folder last_part = parts[-1] - # Check for the target resource itself (can be a folder or a file) folder = await Folder.get_or_none( name=last_part, parent=current_folder, owner=user, is_deleted=False ) @@ -154,7 +156,6 @@ async def resolve_path(path_str: str, user: User): if file: return file, current_folder, True - # The resource itself was not found, but the path to its parent is valid. return None, current_folder, False @@ -165,6 +166,38 @@ def build_href(base_path: str, name: str, is_collection: bool): return path +async def invalidate_cache_for_path(user_id: int, path_str: str, parent_id: Optional[int] = None): + try: + from .enterprise.dal import get_dal + dal = get_dal() + await dal.invalidate_path(user_id, path_str) + if parent_id is not None: + await dal.invalidate_folder(user_id, parent_id) + else: + await dal.invalidate_folder(user_id, None) + except RuntimeError: + pass + + +async def invalidate_cache_for_file(user_id: int, file_id: int, parent_id: Optional[int] = None): + try: + from .enterprise.dal import get_dal + dal = get_dal() + await dal.invalidate_file(user_id, file_id, parent_id) + except RuntimeError: + pass + + +async def invalidate_cache_for_folder(user_id: int, folder_id: Optional[int] = None): + try: + from .enterprise.dal import get_dal + dal = get_dal() + await dal.invalidate_folder(user_id, folder_id) + await dal.invalidate_user_paths(user_id) + except RuntimeError: + pass + + async def get_custom_properties(resource_type: str, resource_id: int): props = await WebDAVProperty.filter( resource_type=resource_type, resource_id=resource_id @@ -314,12 +347,17 @@ async def handle_propfind( res_href = base_href if isinstance(resource, File) else (base_href if base_href.endswith('/') else base_href + '/') await add_resource_to_response(resource, res_href) - # If depth is 1 or infinity, add children if depth in ["1", "infinity"]: target_folder = resource if isinstance(resource, Folder) else (None if full_path_str == "" else parent_folder) - - folders = await Folder.filter(owner=current_user, parent=target_folder, is_deleted=False) - files = await File.filter(owner=current_user, parent=target_folder, is_deleted=False) + target_folder_id = target_folder.id if target_folder else None + + try: + from .enterprise.dal import get_dal + dal = get_dal() + folders, files = await dal.get_folder_contents(current_user.id, target_folder_id) + except RuntimeError: + folders = await Folder.filter(owner=current_user, parent=target_folder, is_deleted=False) + files = await File.filter(owner=current_user, parent=target_folder, is_deleted=False) for folder in folders: child_href = build_href(base_href, folder.name, True) @@ -410,6 +448,7 @@ async def handle_put( name=file_name, parent=parent_folder, owner=current_user, is_deleted=False ) + parent_id = parent_folder.id if parent_folder else None if existing_file: old_size = existing_file.size existing_file.path, existing_file.size, existing_file.mime_type, existing_file.file_hash = \ @@ -418,6 +457,7 @@ async def handle_put( await existing_file.save() current_user.used_storage_bytes += (file_size - old_size) await current_user.save() + await invalidate_cache_for_file(current_user.id, existing_file.id, parent_id) await log_activity(current_user, "file_updated", "file", existing_file.id) return Response(status_code=204) else: @@ -427,6 +467,7 @@ async def handle_put( ) current_user.used_storage_bytes += file_size await current_user.save() + await invalidate_cache_for_path(current_user.id, full_path, parent_id) await log_activity(current_user, "file_created", "file", db_file.id) return Response(status_code=201) @@ -445,19 +486,23 @@ async def handle_delete( raise HTTPException(status_code=404, detail="Resource not found") if isinstance(resource, File): + parent_id = resource.parent_id resource.is_deleted = True resource.deleted_at = datetime.now() await resource.save() + await invalidate_cache_for_file(current_user.id, resource.id, parent_id) await log_activity(current_user, "file_deleted", "file", resource.id) elif isinstance(resource, Folder): child_files = await File.filter(parent=resource, is_deleted=False).count() child_folders = await Folder.filter(parent=resource, is_deleted=False).count() if child_files > 0 or child_folders > 0: raise HTTPException(status_code=409, detail="Folder is not empty") - + + parent_id = resource.parent_id resource.is_deleted = True resource.deleted_at = datetime.now() await resource.save() + await invalidate_cache_for_folder(current_user.id, parent_id) await log_activity(current_user, "folder_deleted", "folder", resource.id) return Response(status_code=204) @@ -485,8 +530,10 @@ async def handle_mkcol( raise HTTPException(status_code=409, detail="Parent collection does not exist") parent_folder = parent_resource if isinstance(parent_resource, Folder) else None - + parent_id = parent_folder.id if parent_folder else None + folder = await Folder.create(name=folder_name, parent=parent_folder, owner=current_user) + await invalidate_cache_for_folder(current_user.id, parent_id) await log_activity(current_user, "folder_created", "folder", folder.id) return Response(status_code=201) @@ -525,14 +572,15 @@ async def handle_copy( if existing_dest_exists and overwrite == "F": raise HTTPException(status_code=412, detail="Destination exists and Overwrite is 'F'") + dest_parent_id = dest_parent_folder.id if dest_parent_folder else None if existing_dest_exists: - # Update existing_dest instead of deleting and recreating existing_dest.path = source_resource.path existing_dest.size = source_resource.size existing_dest.mime_type = source_resource.mime_type existing_dest.file_hash = source_resource.file_hash existing_dest.updated_at = datetime.now() await existing_dest.save() + await invalidate_cache_for_file(current_user.id, existing_dest.id, dest_parent_id) await log_activity(current_user, "file_copied", "file", existing_dest.id) return Response(status_code=204) else: @@ -541,6 +589,7 @@ async def handle_copy( mime_type=source_resource.mime_type, file_hash=source_resource.file_hash, owner=current_user, parent=dest_parent_folder ) + await invalidate_cache_for_path(current_user.id, dest_path, dest_parent_id) await log_activity(current_user, "file_copied", "file", new_file.id) return Response(status_code=201) @@ -577,10 +626,11 @@ async def handle_move( if existing_dest_exists and overwrite == "F": raise HTTPException(status_code=412, detail="Destination exists and Overwrite is 'F'") - # Handle overwrite scenario + dest_parent_id = dest_parent_folder.id if dest_parent_folder else None + source_parent_id = source_resource.parent_id if hasattr(source_resource, 'parent_id') else None + if existing_dest_exists: if isinstance(source_resource, File): - # Update existing_dest with source_resource's properties existing_dest.name = dest_name existing_dest.path = source_resource.path existing_dest.size = source_resource.size @@ -589,19 +639,19 @@ async def handle_move( existing_dest.parent = dest_parent_folder existing_dest.updated_at = datetime.now() await existing_dest.save() + await invalidate_cache_for_file(current_user.id, existing_dest.id, dest_parent_id) await log_activity(current_user, "file_moved_overwrite", "file", existing_dest.id) elif isinstance(source_resource, Folder): raise HTTPException(status_code=501, detail="Folder move overwrite not implemented") - - # Mark source as deleted + source_resource.is_deleted = True source_resource.deleted_at = datetime.now() await source_resource.save() + await invalidate_cache_for_file(current_user.id, source_resource.id, source_parent_id) await log_activity(current_user, "file_deleted_after_move", "file", source_resource.id) - return Response(status_code=204) # Overwrite means 204 No Content + return Response(status_code=204) - # Handle non-overwrite scenario (create new resource at destination) else: if isinstance(source_resource, File): new_file = await File.create( @@ -609,17 +659,18 @@ async def handle_move( mime_type=source_resource.mime_type, file_hash=source_resource.file_hash, owner=current_user, parent=dest_parent_folder ) + await invalidate_cache_for_path(current_user.id, dest_path, dest_parent_id) await log_activity(current_user, "file_moved_created", "file", new_file.id) elif isinstance(source_resource, Folder): raise HTTPException(status_code=501, detail="Folder move not implemented") - # Mark source as deleted source_resource.is_deleted = True source_resource.deleted_at = datetime.now() await source_resource.save() + await invalidate_cache_for_file(current_user.id, source_resource.id, source_parent_id) await log_activity(current_user, "file_deleted_after_move", "file", source_resource.id) - return Response(status_code=201) # New resource created means 201 Created + return Response(status_code=201) @router.api_route("/{full_path:path}", methods=["LOCK"])