From 7abced9315e5ae8cccf626ad27bc9fc515b5f8eb Mon Sep 17 00:00:00 2001 From: retoor Date: Sun, 9 Nov 2025 08:14:14 +0100 Subject: [PATCH] Update. --- requirements.txt | 1 + retoors/helpers/auth.py | 6 +- retoors/main.py | 4 +- retoors/middlewares.py | 2 +- retoors/routes.py | 1 + retoors/services/file_service.py | 289 +++++++++++++--------- retoors/services/storage_service.py | 172 +++++++++++++ retoors/services/user_service.py | 162 +++++++++---- retoors/templates/pages/add_user.html | 2 +- retoors/templates/pages/edit_user.html | 2 +- retoors/templates/pages/errors/404.html | 4 + retoors/templates/pages/favorites.html | 116 ++++++++- retoors/templates/pages/recent.html | 116 ++++++++- retoors/templates/pages/shared.html | 150 +++++++++++- retoors/templates/pages/trash.html | 109 ++++++++- retoors/views/admin.py | 77 +++--- retoors/views/auth.py | 12 +- retoors/views/migrate.py | 11 + retoors/views/site.py | 72 ++++-- tests/conftest.py | 308 ++---------------------- tests/test_auth.py | 24 +- tests/test_file_browser.py | 151 +++++++----- tests/test_storage_service.py | 183 ++++++++++++++ tests/test_user_service.py | 122 +++++----- 24 files changed, 1439 insertions(+), 657 deletions(-) create mode 100644 retoors/services/storage_service.py create mode 100644 retoors/views/migrate.py create mode 100644 tests/test_storage_service.py diff --git a/requirements.txt b/requirements.txt index f057cfc..7214ae4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ bcrypt python-dotenv aiosmtplib aiojobs +aiofiles pytest pytest-aiohttp aiohttp-test-utils diff --git a/retoors/helpers/auth.py b/retoors/helpers/auth.py index 730ba1c..6313dfe 100644 --- a/retoors/helpers/auth.py +++ b/retoors/helpers/auth.py @@ -17,14 +17,12 @@ def login_required(func): raise web.HTTPFound('/login') user_service: UserService = request.app["user_service"] - user = user_service.get_user_by_email(user_email) + user = await user_service.get_user_by_email(user_email) if not user: - # User not found in service, clear session and redirect to login session.pop('user_email', None) raise web.HTTPFound('/login') - - # Ensure the user object is available in the request for views + request["user"] = user return await func(self, *args, **kwargs) return wrapper diff --git a/retoors/main.py b/retoors/main.py index 62b297f..fd6a487 100644 --- a/retoors/main.py +++ b/retoors/main.py @@ -18,9 +18,9 @@ from .helpers.env_manager import ensure_env_file_exists, get_or_create_session_s async def setup_services(app: web.Application): base_path = Path(__file__).parent data_path = base_path.parent / "data" - app["user_service"] = UserService(data_path / "users.json") + app["user_service"] = UserService(use_isolated_storage=True) app["config_service"] = ConfigService(data_path / "config.json") - app["file_service"] = FileService(data_path / "user_files", data_path / "users.json") # Instantiate FileService + app["file_service"] = FileService(data_path / "user_files", data_path / "users.json") # Setup aiojobs scheduler app["scheduler"] = aiojobs.Scheduler() diff --git a/retoors/middlewares.py b/retoors/middlewares.py index 227a85f..eb270fd 100644 --- a/retoors/middlewares.py +++ b/retoors/middlewares.py @@ -12,7 +12,7 @@ async def user_middleware(request, handler): request["user"] = None if "user_email" in session: user_service = request.app["user_service"] - request["user"] = user_service.get_user_by_email(session["user_email"]) + request["user"] = await user_service.get_user_by_email(session["user_email"]) return await handler(request) diff --git a/retoors/routes.py b/retoors/routes.py index d939a02..0bccfff 100644 --- a/retoors/routes.py +++ b/retoors/routes.py @@ -1,6 +1,7 @@ from .views.auth import LoginView, RegistrationView, LogoutView, ForgotPasswordView, ResetPasswordView from .views.site import SiteView, OrderView, FileBrowserView, UserManagementView from .views.upload import UploadView +from .views.migrate import MigrateView from .views.admin import get_users, add_user, update_user_quota, delete_user, get_user_details, delete_team diff --git a/retoors/services/file_service.py b/retoors/services/file_service.py index d3f3988..0ea6729 100644 --- a/retoors/services/file_service.py +++ b/retoors/services/file_service.py @@ -5,6 +5,9 @@ import shutil import uuid import datetime import logging +import hashlib +import os +from .storage_service import StorageService logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -14,121 +17,152 @@ handler.setFormatter(formatter) logger.addHandler(handler) class FileService: - def __init__(self, base_dir: Path, users_data_path: Path): + def __init__(self, base_dir: Path, user_service): self.base_dir = base_dir - self.users_data_path = users_data_path + self.user_service = user_service + self.storage = StorageService() self.base_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"FileService initialized with base_dir: {self.base_dir} and users_data_path: {self.users_data_path}") + self.drives_dir = self.base_dir / "drives" + self.drives_dir.mkdir(exist_ok=True) + self.drives = ["drive1", "drive2", "drive3"] + for drive in self.drives: + (self.drives_dir / drive).mkdir(exist_ok=True) + self.drive_counter = 0 + logger.info(f"FileService initialized with base_dir: {self.base_dir} and user_service") - async def _load_users_data(self): - """Loads user data from the JSON file.""" - if not self.users_data_path.exists(): - logger.warning(f"users_data_path does not exist: {self.users_data_path}") - return [] - async with aiofiles.open(self.users_data_path, mode="r") as f: - content = await f.read() - try: - return json.loads(content) if content else [] - except json.JSONDecodeError: - logger.error(f"JSONDecodeError when loading users data from {self.users_data_path}") - return [] + def _choose_drive(self): + drive = self.drives[self.drive_counter % len(self.drives)] + self.drive_counter += 1 + return drive - async def _save_users_data(self, data): - """Saves user data to the JSON file.""" - async with aiofiles.open(self.users_data_path, mode="w") as f: - await f.write(json.dumps(data, indent=4)) - logger.debug(f"Saved users data to {self.users_data_path}") + async def _load_metadata(self, user_email: str): + metadata = await self.storage.load(user_email, "file_metadata") + return metadata if metadata else {} - def _get_user_file_path(self, user_email: str, relative_path: str = "") -> Path: - """Constructs the absolute path for a user's file or directory.""" - user_dir = self.base_dir / user_email - full_path = user_dir / relative_path - logger.debug(f"Constructed path for user '{user_email}', relative_path '{relative_path}': {full_path}") + async def _save_metadata(self, user_email: str, metadata: dict): + await self.storage.save(user_email, "file_metadata", metadata) + + + + def get_user_file_system_path(self, user_email: str, item_path: str) -> Path: + """ + Constructs the absolute file system path for a user's item. + This is for internal use to resolve virtual paths to physical paths. + """ + user_base_path = self.storage._get_user_base_path(user_email) + # Ensure item_path is treated as relative to user_base_path + # and prevent directory traversal attacks. + full_path = user_base_path / item_path + if not self.storage._validate_path(full_path, user_base_path): + raise ValueError("Invalid item path: directory traversal detected") return full_path async def list_files(self, user_email: str, path: str = "") -> list: """Lists files and directories for a given user within a specified path.""" - user_path = self._get_user_file_path(user_email, path) - if not user_path.is_dir(): - logger.warning(f"list_files: User path is not a directory or does not exist: {user_path}") - return [] - - files_list = [] - for item in user_path.iterdir(): - if item.name.startswith('.'): # Ignore hidden files/directories - continue - file_info = { - "name": item.name, - "is_dir": item.is_dir(), - "path": str(item.relative_to(self._get_user_file_path(user_email))), - "size": item.stat().st_size if item.is_file() else 0, - "last_modified": datetime.datetime.fromtimestamp(item.stat().st_mtime).isoformat(), - } - files_list.append(file_info) - logger.debug(f"Listed {len(files_list)} items for user '{user_email}' in path '{path}'") - return sorted(files_list, key=lambda x: (not x["is_dir"], x["name"].lower())) + metadata = await self._load_metadata(user_email) + # Normalize path + if path and not path.endswith('/'): + path += '/' + items = [] + seen = set() + for item_path, item_meta in metadata.items(): + if item_path.startswith(path): + remaining = item_path[len(path):].rstrip('/') + if '/' not in remaining and remaining not in seen: + seen.add(remaining) + if remaining: # not the path itself + items.append({ + "name": remaining, + "is_dir": item_meta.get("type") == "dir", + "path": item_path, + "size": item_meta.get("size", 0), + "last_modified": item_meta.get("modified_at", ""), + }) + logger.debug(f"Listed {len(items)} items for user '{user_email}' in path '{path}'") + return sorted(items, key=lambda x: (not x["is_dir"], x["name"].lower())) async def create_folder(self, user_email: str, folder_path: str) -> bool: """Creates a new folder for the user.""" - full_path = self._get_user_file_path(user_email, folder_path) - if full_path.exists(): - logger.warning(f"create_folder: Folder already exists: {full_path}") - return False # Folder already exists - full_path.mkdir(parents=True, exist_ok=True) - logger.info(f"create_folder: Folder created: {full_path}") + metadata = await self._load_metadata(user_email) + if folder_path in metadata: + logger.warning(f"create_folder: Folder already exists: {folder_path}") + return False + metadata[folder_path] = { + "type": "dir", + "created_at": datetime.datetime.now().isoformat(), + "modified_at": datetime.datetime.now().isoformat(), + } + await self._save_metadata(user_email, metadata) + logger.info(f"create_folder: Folder created: {folder_path}") return True async def upload_file(self, user_email: str, file_path: str, content: bytes) -> bool: """Uploads a file for the user.""" - full_path = self._get_user_file_path(user_email, file_path) - full_path.parent.mkdir(parents=True, exist_ok=True) # Ensure parent directories exist - async with aiofiles.open(full_path, mode="wb") as f: + hash = hashlib.sha256(content).hexdigest() + drive = self._choose_drive() + dir1 = hash[:3] + dir2 = hash[3:6] + dir3 = hash[6:9] + blob_path = self.drives_dir / drive / dir1 / dir2 / dir3 / hash + blob_path.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(blob_path, 'wb') as f: await f.write(content) - logger.info(f"upload_file: File uploaded to: {full_path}") + metadata = await self._load_metadata(user_email) + metadata[file_path] = { + "type": "file", + "size": len(content), + "hash": hash, + "blob_location": {"drive": drive, "path": f"{dir1}/{dir2}/{dir3}/{hash}"}, + "created_at": datetime.datetime.now().isoformat(), + "modified_at": datetime.datetime.now().isoformat(), + } + await self._save_metadata(user_email, metadata) + logger.info(f"upload_file: File uploaded to drive {drive}: {file_path}") return True async def download_file(self, user_email: str, file_path: str) -> tuple[bytes, str] | None: """Downloads a file for the user.""" - full_path = self._get_user_file_path(user_email, file_path) - logger.debug(f"download_file: Attempting to download file from: {full_path}") - if full_path.is_file(): - async with aiofiles.open(full_path, mode="rb") as f: - content = await f.read() - logger.info(f"download_file: Successfully read file: {full_path}") - return content, full_path.name - logger.warning(f"download_file: File not found or is not a file: {full_path}") - return None + metadata = await self._load_metadata(user_email) + if file_path not in metadata or metadata[file_path]["type"] != "file": + logger.warning(f"download_file: File not found in metadata: {file_path}") + return None + item_meta = metadata[file_path] + blob_loc = item_meta["blob_location"] + blob_path = self.drives_dir / blob_loc["drive"] / blob_loc["path"] + if not blob_path.exists(): + logger.warning(f"download_file: Blob not found: {blob_path}") + return None + async with aiofiles.open(blob_path, 'rb') as f: + content = await f.read() + logger.info(f"download_file: Successfully read file: {file_path}") + return content, Path(file_path).name async def delete_item(self, user_email: str, item_path: str) -> bool: """Deletes a file or folder for the user.""" - full_path = self._get_user_file_path(user_email, item_path) - logger.debug(f"delete_item: Attempting to delete item: {full_path}") - if not full_path.exists(): - logger.warning(f"delete_item: Item does not exist: {full_path}") + metadata = await self._load_metadata(user_email) + if item_path not in metadata: + logger.warning(f"delete_item: Item not found: {item_path}") return False - - if full_path.is_file(): - full_path.unlink() - logger.info(f"delete_item: File deleted: {full_path}") - elif full_path.is_dir(): - shutil.rmtree(full_path) - logger.info(f"delete_item: Directory deleted: {full_path}") + # If dir, remove all under it + to_delete = [p for p in metadata if p == item_path or p.startswith(item_path + '/')] + for p in to_delete: + del metadata[p] + await self._save_metadata(user_email, metadata) + logger.info(f"delete_item: Item deleted: {item_path}") return True async def generate_share_link(self, user_email: str, item_path: str) -> str | None: """Generates a shareable link for a file or folder.""" logger.debug(f"generate_share_link: Generating link for user '{user_email}', item '{item_path}'") - users_data = await self._load_users_data() - user = next((u for u in users_data if u.get("email") == user_email), None) + metadata = await self._load_metadata(user_email) + if item_path not in metadata: + logger.warning(f"generate_share_link: Item does not exist: {item_path}") + return None + user = await self.user_service.get_user_by_email(user_email) if not user: logger.warning(f"generate_share_link: User not found: {user_email}") return None - full_path = self._get_user_file_path(user_email, item_path) - if not full_path.exists(): - logger.warning(f"generate_share_link: Item does not exist: {full_path}") - return None - share_id = str(uuid.uuid4()) if "shared_items" not in user: user["shared_items"] = {} @@ -138,17 +172,17 @@ class FileService: "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), "expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat(), # 7-day expiry } - await self._save_users_data(users_data) + await self.user_service.update_user(user_email, shared_items=user["shared_items"]) logger.info(f"generate_share_link: Share link generated with ID: {share_id} for item: {item_path}") return share_id async def get_shared_item(self, share_id: str) -> dict | None: """Retrieves information about a shared item.""" logger.debug(f"get_shared_item: Retrieving shared item with ID: {share_id}") - users_data = await self._load_users_data() - for user_info in users_data: - if "shared_items" in user_info and share_id in user_info["shared_items"]: - shared_item = user_info["shared_items"][share_id] + all_users = await self.user_service.get_all_users() + for user in all_users: + if "shared_items" in user and share_id in user["shared_items"]: + shared_item = user["shared_items"][share_id] expiry_time = datetime.datetime.fromisoformat(shared_item["expires_at"]) if expiry_time > datetime.datetime.now(datetime.timezone.utc): logger.info(f"get_shared_item: Found valid shared item for ID: {share_id}") @@ -168,27 +202,15 @@ class FileService: user_email = shared_item["user_email"] item_path = shared_item["item_path"] # This is the path of the originally shared item (file or folder) - # Construct the full path to the originally shared item - full_shared_item_path = self._get_user_file_path(user_email, item_path) - - target_file_path = full_shared_item_path + target_path = item_path if requested_file_path: - # If a specific file within a shared folder is requested - target_file_path = self._get_user_file_path(user_email, requested_file_path) + target_path = requested_file_path # Security check: Ensure the requested file is actually within the shared item's directory - try: - target_file_path.relative_to(full_shared_item_path) - except ValueError: + if not target_path.startswith(item_path + '/'): logger.warning(f"get_shared_file_content: Requested file path '{requested_file_path}' is not within shared item path '{item_path}' for share_id: {share_id}") return None - if target_file_path.is_file(): - async with aiofiles.open(target_file_path, mode="rb") as f: - content = await f.read() - logger.info(f"get_shared_file_content: Successfully read content for shared file: {target_file_path}") - return content, target_file_path.name - logger.warning(f"get_shared_file_content: Shared item path is not a file or does not exist: {target_file_path}") - return None + return await self.download_file(user_email, target_path) async def get_shared_folder_content(self, share_id: str) -> list | None: """Retrieves the content of a shared folder.""" @@ -199,10 +221,61 @@ class FileService: user_email = shared_item["user_email"] item_path = shared_item["item_path"] - full_path = self._get_user_file_path(user_email, item_path) + metadata = await self._load_metadata(user_email) + if item_path not in metadata or metadata[item_path]["type"] != "dir": + logger.warning(f"get_shared_folder_content: Shared item is not a directory: {item_path}") + return None - if full_path.is_dir(): - logger.info(f"get_shared_folder_content: Listing files for shared folder: {full_path}") - return await self.list_files(user_email, item_path) - logger.warning(f"get_shared_folder_content: Shared item path is not a directory or does not exist: {full_path}") - return None + logger.info(f"get_shared_folder_content: Listing files for shared folder: {item_path}") + return await self.list_files(user_email, item_path) + + async def migrate_old_files(self, user_email: str): + """Migrate existing files from old file system to virtual system.""" + old_user_dir = self.base_dir / user_email + if not old_user_dir.exists(): + logger.info(f"No old files to migrate for {user_email}") + return + metadata = await self._load_metadata(user_email) + migrated_count = 0 + for root, dirs, files in os.walk(old_user_dir): + rel_root = Path(root).relative_to(old_user_dir) + for dir_name in dirs: + dir_path = str(rel_root / dir_name) if str(rel_root) != "." else dir_name + if dir_path not in metadata: + metadata[dir_path] = { + "type": "dir", + "created_at": datetime.datetime.now().isoformat(), + "modified_at": datetime.datetime.now().isoformat(), + } + migrated_count += 1 + for file_name in files: + file_path = rel_root / file_name + str_file_path = str(file_path) if str(rel_root) != "." else file_name + if str_file_path in metadata: + continue # Already migrated + full_file_path = old_user_dir / file_path + with open(full_file_path, 'rb') as f: + content = f.read() + hash = hashlib.sha256(content).hexdigest() + drive = self._choose_drive() + dir1 = hash[:3] + dir2 = hash[3:6] + dir3 = hash[6:9] + blob_path = self.drives_dir / drive / dir1 / dir2 / dir3 / hash + if not blob_path.exists(): + blob_path.parent.mkdir(parents=True, exist_ok=True) + with open(blob_path, 'wb') as f: + f.write(content) + metadata[str_file_path] = { + "type": "file", + "size": len(content), + "hash": hash, + "blob_location": {"drive": drive, "path": f"{dir1}/{dir2}/{dir3}/{hash}"}, + "created_at": datetime.datetime.fromtimestamp(full_file_path.stat().st_ctime).isoformat(), + "modified_at": datetime.datetime.fromtimestamp(full_file_path.stat().st_mtime).isoformat(), + } + migrated_count += 1 + await self._save_metadata(user_email, metadata) + logger.info(f"Migrated {migrated_count} items for {user_email}") + # Optionally remove old dir + # shutil.rmtree(old_user_dir) diff --git a/retoors/services/storage_service.py b/retoors/services/storage_service.py new file mode 100644 index 0000000..78916bb --- /dev/null +++ b/retoors/services/storage_service.py @@ -0,0 +1,172 @@ +import os +import json +import hashlib +import aiofiles +from pathlib import Path +from typing import Any, Dict, List, Optional + + +class StorageService: + + def __init__(self, base_path: str = "data/user"): + self.base_path = Path(base_path) + self.base_path.mkdir(parents=True, exist_ok=True) + + def _hash(self, value: str) -> str: + return hashlib.sha256(value.encode()).hexdigest() + + def _get_user_base_path(self, user_email: str) -> Path: + user_hash = self._hash(user_email) + return self.base_path / user_hash + + def _get_distributed_path(self, base_path: Path, identifier: str) -> Path: + obj_hash = self._hash(identifier) + dir1 = obj_hash[:3] + dir2 = obj_hash[3:6] + dir3 = obj_hash[6:9] + return base_path / dir1 / dir2 / dir3 / f"{obj_hash}.json" + + def _validate_path(self, path: Path, expected_base: Path) -> bool: + try: + resolved_path = path.resolve() + resolved_base = expected_base.resolve() + return str(resolved_path).startswith(str(resolved_base)) + except (ValueError, OSError): + return False + + async def save(self, user_email: str, identifier: str, data: Dict[str, Any]) -> bool: + user_base = self._get_user_base_path(user_email) + file_path = self._get_distributed_path(user_base, identifier) + + if not self._validate_path(file_path, user_base): + raise ValueError("Invalid path: directory traversal detected") + + file_path.parent.mkdir(parents=True, exist_ok=True) + + async with aiofiles.open(file_path, 'w') as f: + await f.write(json.dumps(data, indent=2)) + + return True + + async def load(self, user_email: str, identifier: str) -> Optional[Dict[str, Any]]: + user_base = self._get_user_base_path(user_email) + file_path = self._get_distributed_path(user_base, identifier) + + if not self._validate_path(file_path, user_base): + raise ValueError("Invalid path: directory traversal detected") + + if not file_path.exists(): + return None + + async with aiofiles.open(file_path, 'r') as f: + content = await f.read() + if not content: + return {} + try: + return json.loads(content) + except json.JSONDecodeError: + return {} + + async def delete(self, user_email: str, identifier: str) -> bool: + user_base = self._get_user_base_path(user_email) + file_path = self._get_distributed_path(user_base, identifier) + + if not self._validate_path(file_path, user_base): + raise ValueError("Invalid path: directory traversal detected") + + if file_path.exists(): + file_path.unlink() + return True + + return False + + async def exists(self, user_email: str, identifier: str) -> bool: + user_base = self._get_user_base_path(user_email) + file_path = self._get_distributed_path(user_base, identifier) + + if not self._validate_path(file_path, user_base): + raise ValueError("Invalid path: directory traversal detected") + + return file_path.exists() + + async def list_all(self, user_email: str) -> List[Dict[str, Any]]: + user_base = self._get_user_base_path(user_email) + + if not user_base.exists(): + return [] + + results = [] + for json_file in user_base.rglob("*.json"): + if self._validate_path(json_file, user_base): + async with aiofiles.open(json_file, 'r') as f: + content = await f.read() + results.append(json.loads(content)) + + return results + + async def delete_all(self, user_email: str) -> bool: + user_base = self._get_user_base_path(user_email) + + if not user_base.exists(): + return False + + import shutil + shutil.rmtree(user_base) + return True + + def get_user_storage_path(self, user_email: str) -> str: + return str(self._get_user_base_path(user_email)) + + +class UserStorageManager: + + def __init__(self): + self.storage = StorageService() + + async def save_user(self, user_email: str, user_data: Dict[str, Any]) -> bool: + return await self.storage.save(user_email, user_email, user_data) + + async def get_user(self, user_email: str) -> Optional[Dict[str, Any]]: + return await self.storage.load(user_email, user_email) + + async def delete_user(self, user_email: str) -> bool: + return await self.storage.delete_all(user_email) + + async def user_exists(self, user_email: str) -> bool: + return await self.storage.exists(user_email, user_email) + + async def list_users_by_parent(self, parent_email: str) -> List[Dict[str, Any]]: + all_users = [] + + base_path = Path(self.storage.base_path) + if not base_path.exists(): + return [] + + for user_dir in base_path.iterdir(): + if user_dir.is_dir(): + user_files = list(user_dir.rglob("*.json")) + for user_file in user_files: + async with aiofiles.open(user_file, 'r') as f: + content = await f.read() + user_data = json.loads(content) + if user_data.get('parent_email') == parent_email: + all_users.append(user_data) + + return all_users + + async def get_all_users(self) -> List[Dict[str, Any]]: + all_users = [] + + base_path = Path(self.storage.base_path) + if not base_path.exists(): + return [] + + for user_dir in base_path.iterdir(): + if user_dir.is_dir(): + user_files = list(user_dir.rglob("*.json")) + for user_file in user_files: + async with aiofiles.open(user_file, 'r') as f: + content = await f.read() + all_users.append(json.loads(content)) + + return all_users diff --git a/retoors/services/user_service.py b/retoors/services/user_service.py index 9924b8f..e63e682 100644 --- a/retoors/services/user_service.py +++ b/retoors/services/user_service.py @@ -1,15 +1,22 @@ import json from pathlib import Path from typing import List, Dict, Optional, Any -import bcrypt # Import bcrypt -import secrets # For generating secure tokens -import datetime # For token expiry +import bcrypt +import secrets +import datetime +from .storage_service import UserStorageManager class UserService: - def __init__(self, users_path: Path): - self._users_path = users_path - self._users = self._load_users() + def __init__(self, users_path: Path = None, use_isolated_storage: bool = True): + self.use_isolated_storage = use_isolated_storage + + if use_isolated_storage: + self._storage_manager = UserStorageManager() + self._users = [] + else: + self._users_path = users_path + self._users = self._load_users() def _load_users(self) -> List[Dict[str, Any]]: if not self._users_path.exists(): @@ -21,52 +28,84 @@ class UserService: return [] def _save_users(self): + if self.use_isolated_storage: + return with open(self._users_path, "w") as f: json.dump(self._users, f, indent=4) - def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]: + async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]: + if self.use_isolated_storage: + return await self._storage_manager.get_user(email) return next((user for user in self._users if user["email"] == email), None) - def get_all_users(self) -> List[Dict[str, Any]]: + async def get_all_users(self) -> List[Dict[str, Any]]: + if self.use_isolated_storage: + return await self._storage_manager.get_all_users() return self._users - def get_users_by_parent_email(self, parent_email: str) -> List[Dict[str, Any]]: + async def get_users_by_parent_email(self, parent_email: str) -> List[Dict[str, Any]]: + if self.use_isolated_storage: + return await self._storage_manager.list_users_by_parent(parent_email) return [user for user in self._users if user.get("parent_email") == parent_email] - def create_user(self, full_name: str, email: str, password: str, parent_email: Optional[str] = None) -> Dict[str, Any]: - if self.get_user_by_email(email): + async def get_managed_users(self, user_email: str) -> List[Dict[str, Any]]: + user = await self.get_user_by_email(user_email) + if not user: + return [] + + if not user.get("is_customer", True): + return await self.get_all_users() + + return await self.get_users_by_parent_email(user_email) + + async def create_user(self, full_name: str, email: str, password: str, parent_email: Optional[str] = None, is_customer: bool = True) -> Dict[str, Any]: + existing_user = await self.get_user_by_email(email) + if existing_user: raise ValueError("User with this email already exists") - # Hash password with bcrypt hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') user = { "full_name": full_name, "email": email, "password": hashed_password, - "storage_quota_gb": 5, # Default quota + "storage_quota_gb": 5, "storage_used_gb": 0, "reset_token": None, "reset_token_expiry": None, - "parent_email": parent_email, # New field for hierarchical user management + "parent_email": parent_email, + "is_customer": is_customer, } - self._users.append(user) - self._save_users() + + if self.use_isolated_storage: + await self._storage_manager.save_user(email, user) + else: + self._users.append(user) + self._save_users() + return user - def update_user(self, email: str, **kwargs) -> Optional[Dict[str, Any]]: - user = self.get_user_by_email(email) + async def update_user(self, email: str, **kwargs) -> Optional[Dict[str, Any]]: + user = await self.get_user_by_email(email) if not user: return None - + for key, value in kwargs.items(): if key == "password": user[key] = bcrypt.hashpw(value.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') else: user[key] = value - self._save_users() + + if self.use_isolated_storage: + await self._storage_manager.save_user(email, user) + else: + self._save_users() + return user - def delete_user(self, email: str) -> bool: + async def delete_user(self, email: str) -> bool: + if self.use_isolated_storage: + return await self._storage_manager.delete_user(email) + initial_len = len(self._users) self._users = [user for user in self._users if user["email"] != email] if len(self._users) < initial_len: @@ -74,23 +113,32 @@ class UserService: return True return False - def delete_users_by_parent_email(self, parent_email: str) -> int: - initial_len = len(self._users) - self._users = [user for user in self._users if user.get("parent_email") != parent_email] - deleted_count = initial_len - len(self._users) - if deleted_count > 0: - self._save_users() + async def delete_users_by_parent_email(self, parent_email: str) -> int: + users_to_delete = await self.get_users_by_parent_email(parent_email) + deleted_count = 0 + + if self.use_isolated_storage: + for user in users_to_delete: + if await self._storage_manager.delete_user(user["email"]): + deleted_count += 1 + else: + initial_len = len(self._users) + self._users = [user for user in self._users if user.get("parent_email") != parent_email] + deleted_count = initial_len - len(self._users) + if deleted_count > 0: + self._save_users() + return deleted_count - def authenticate_user(self, email: str, password: str) -> bool: - user = self.get_user_by_email(email) + async def authenticate_user(self, email: str, password: str) -> bool: + user = await self.get_user_by_email(email) if not user: return False - # Verify password with bcrypt return bcrypt.checkpw(password.encode('utf-8'), user["password"].encode('utf-8')) - def get_user_by_reset_token(self, token: str) -> Optional[Dict[str, Any]]: - for user in self._users: + async def get_user_by_reset_token(self, token: str) -> Optional[Dict[str, Any]]: + all_users = await self.get_all_users() + for user in all_users: if user.get("reset_token") == token: expiry_str = user.get("reset_token_expiry") if expiry_str: @@ -99,49 +147,63 @@ class UserService: return user return None - def generate_reset_token(self, email: str) -> Optional[str]: - user = self.get_user_by_email(email) + async def generate_reset_token(self, email: str) -> Optional[str]: + user = await self.get_user_by_email(email) if not user: return None token = secrets.token_urlsafe(32) - expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) # Token valid for 1 hour + expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) user["reset_token"] = token user["reset_token_expiry"] = expiry.isoformat() - self._save_users() + + if self.use_isolated_storage: + await self._storage_manager.save_user(email, user) + else: + self._save_users() + return token - def validate_reset_token(self, email: str, token: str) -> bool: - user = self.get_user_by_email(email) + async def validate_reset_token(self, email: str, token: str) -> bool: + user = await self.get_user_by_email(email) if not user or user.get("reset_token") != token: return False - + expiry_str = user.get("reset_token_expiry") if not expiry_str: return False - + expiry = datetime.datetime.fromisoformat(expiry_str) return expiry > datetime.datetime.now(datetime.timezone.utc) - def reset_password(self, email: str, token: str, new_password: str) -> bool: - if not self.validate_reset_token(email, token): + async def reset_password(self, email: str, token: str, new_password: str) -> bool: + if not await self.validate_reset_token(email, token): return False - - user = self.get_user_by_email(email) - if not user: # Should not happen if validate_reset_token passed, but for type safety + + user = await self.get_user_by_email(email) + if not user: return False hashed_password = bcrypt.hashpw(new_password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') user["password"] = hashed_password user["reset_token"] = None user["reset_token_expiry"] = None - self._save_users() + + if self.use_isolated_storage: + await self._storage_manager.save_user(email, user) + else: + self._save_users() + return True - def update_user_quota(self, email: str, new_quota_gb: float): - user = self.get_user_by_email(email) + async def update_user_quota(self, email: str, new_quota_gb: float): + user = await self.get_user_by_email(email) if not user: raise ValueError("User not found") - + user["storage_quota_gb"] = new_quota_gb - self._save_users() + + if self.use_isolated_storage: + await self._storage_manager.save_user(email, user) + else: + self._save_users() diff --git a/retoors/templates/pages/add_user.html b/retoors/templates/pages/add_user.html index 8ea4d36..9c4ba32 100644 --- a/retoors/templates/pages/add_user.html +++ b/retoors/templates/pages/add_user.html @@ -43,7 +43,7 @@
- +
diff --git a/retoors/templates/pages/edit_user.html b/retoors/templates/pages/edit_user.html index f66c85e..fc760d8 100644 --- a/retoors/templates/pages/edit_user.html +++ b/retoors/templates/pages/edit_user.html @@ -49,7 +49,7 @@
- +
diff --git a/retoors/templates/pages/errors/404.html b/retoors/templates/pages/errors/404.html index 6440dce..9bc4489 100644 --- a/retoors/templates/pages/errors/404.html +++ b/retoors/templates/pages/errors/404.html @@ -5,7 +5,11 @@ {% block content %}

404 - Page Not Found

+ {% if message %} +

{{ message }}

+ {% else %}

Oops! The page you're looking for doesn't exist or has been moved.

+ {% endif %} Go to Homepage
{% endblock %} diff --git a/retoors/templates/pages/favorites.html b/retoors/templates/pages/favorites.html index 44a3a46..6f58c53 100644 --- a/retoors/templates/pages/favorites.html +++ b/retoors/templates/pages/favorites.html @@ -2,11 +2,121 @@ {% block title %}Favorites - Retoor's Cloud Solutions{% endblock %} +{% block dashboard_head %} + +{% endblock %} + {% block page_title %}Favorites{% endblock %} +{% block dashboard_actions %} + + + +{% endblock %} + {% block dashboard_content %} -
-

Your favorite files and folders will appear here.

-

This feature is coming soon.

+ + {% if success_message %} +
+ {{ success_message }} +
+ {% endif %} + + {% if error_message %} +
+ {{ error_message }} +
+ {% endif %} + + + +
+ + + + + + + + + + + + + {% if favorite_files %} + {% for item in favorite_files %} + + + + + + + + + {% endfor %} + {% else %} + + + + {% endif %} + +
NameOwnerFavorited onSizeActions
+ {% if item.is_dir %} + Folder Icon + {{ item.name }} + {% else %} + File Icon + {{ item.name }} + {% endif %} + {{ user.email }}{{ item.favorited_on[:10] }} + {% if item.is_dir %} + -- + {% else %} + {{ (item.size / 1024 / 1024)|round(2) }} MB + {% endif %} + +
+ {% if not item.is_dir %} + + {% endif %} + + +
+
+

No favorite files.

+
+
+ + + + + + {% endblock %} \ No newline at end of file diff --git a/retoors/templates/pages/recent.html b/retoors/templates/pages/recent.html index 3ab6dbb..1ee754c 100644 --- a/retoors/templates/pages/recent.html +++ b/retoors/templates/pages/recent.html @@ -2,11 +2,121 @@ {% block title %}Recent Files - Retoor's Cloud Solutions{% endblock %} +{% block dashboard_head %} + +{% endblock %} + {% block page_title %}Recent Files{% endblock %} +{% block dashboard_actions %} + + + +{% endblock %} + {% block dashboard_content %} -
-

Your recently accessed files will appear here.

-

This feature is coming soon.

+ + {% if success_message %} +
+ {{ success_message }} +
+ {% endif %} + + {% if error_message %} +
+ {{ error_message }} +
+ {% endif %} + + + +
+ + + + + + + + + + + + + {% if recent_files %} + {% for item in recent_files %} + + + + + + + + + {% endfor %} + {% else %} + + + + {% endif %} + +
NameOwnerLast AccessedSizeActions
+ {% if item.is_dir %} + Folder Icon + {{ item.name }} + {% else %} + File Icon + {{ item.name }} + {% endif %} + {{ user.email }}{{ item.last_accessed[:10] }} + {% if item.is_dir %} + -- + {% else %} + {{ (item.size / 1024 / 1024)|round(2) }} MB + {% endif %} + +
+ {% if not item.is_dir %} + + {% endif %} + + +
+
+

No recent files.

+
+
+ + + + + + {% endblock %} \ No newline at end of file diff --git a/retoors/templates/pages/shared.html b/retoors/templates/pages/shared.html index 1eaf8dd..5d4b3b8 100644 --- a/retoors/templates/pages/shared.html +++ b/retoors/templates/pages/shared.html @@ -2,11 +2,155 @@ {% block title %}Shared with me - Retoor's Cloud Solutions{% endblock %} +{% block dashboard_head %} + +{% endblock %} + {% block page_title %}Shared with me{% endblock %} +{% block dashboard_actions %} + + + + + +{% endblock %} + {% block dashboard_content %} -
-

Files and folders that have been shared with you will appear here.

-

This feature is coming soon.

+ + {% if success_message %} +
+ {{ success_message }} +
+ {% endif %} + + {% if error_message %} +
+ {{ error_message }} +
+ {% endif %} + + + +
+ + + + + + + + + + + + + {% if shared_files %} + {% for item in shared_files %} + + + + + + + + + {% endfor %} + {% else %} + + + + {% endif %} + +
NameShared byLast ModifiedSizeActions
+ {% if item.is_dir %} + Folder Icon + {{ item.name }} + {% else %} + File Icon + {{ item.name }} + {% endif %} + {{ item.shared_by }}{{ item.last_modified[:10] }} + {% if item.is_dir %} + -- + {% else %} + {{ (item.size / 1024 / 1024)|round(2) }} MB + {% endif %} + +
+ {% if not item.is_dir %} + + {% endif %} + + +
+
+

No shared files found.

+
+
+ + + + + + + + + + + {% endblock %} \ No newline at end of file diff --git a/retoors/templates/pages/trash.html b/retoors/templates/pages/trash.html index 493c9ce..8011aa0 100644 --- a/retoors/templates/pages/trash.html +++ b/retoors/templates/pages/trash.html @@ -2,11 +2,114 @@ {% block title %}Trash - Retoor's Cloud Solutions{% endblock %} +{% block dashboard_head %} + +{% endblock %} + {% block page_title %}Trash{% endblock %} +{% block dashboard_actions %} + + +{% endblock %} + {% block dashboard_content %} -
-

Files and folders you have deleted will appear here.

-

This feature is coming soon.

+ + {% if success_message %} +
+ {{ success_message }} +
+ {% endif %} + + {% if error_message %} +
+ {{ error_message }} +
+ {% endif %} + + + +
+ + + + + + + + + + + + + {% if trash_files %} + {% for item in trash_files %} + + + + + + + + + {% endfor %} + {% else %} + + + + {% endif %} + +
NameOwnerDeleted onSizeActions
+ {% if item.is_dir %} + Folder Icon + {{ item.name }} + {% else %} + File Icon + {{ item.name }} + {% endif %} + {{ user.email }}{{ item.deleted_on[:10] }} + {% if item.is_dir %} + -- + {% else %} + {{ (item.size / 1024 / 1024)|round(2) }} MB + {% endif %} + +
+ + +
+
+

No files in trash.

+
+
+ + + + + + {% endblock %} \ No newline at end of file diff --git a/retoors/views/admin.py b/retoors/views/admin.py index 821eaad..6ee7f76 100644 --- a/retoors/views/admin.py +++ b/retoors/views/admin.py @@ -3,6 +3,25 @@ from aiohttp_session import get_session from ..services.user_service import UserService from ..models import QuotaUpdateModel, RegistrationModel + +async def verify_user_access(user_service: UserService, current_user_email: str, target_user_email: str) -> tuple[bool, dict]: + if target_user_email == current_user_email: + return True, None + + current_user = await user_service.get_user_by_email(current_user_email) + if not current_user: + return False, {"error": "Current user not found", "status": 401} + + target_user = await user_service.get_user_by_email(target_user_email) + if not target_user: + return False, {"error": "Target user not found", "status": 404} + + if current_user.get("is_customer", True): + if target_user.get("parent_email") != current_user_email: + return False, {"error": "Forbidden: You can only manage users you created", "status": 403} + + return True, None + async def get_users(request: web.Request) -> web.Response: user_service: UserService = request.app["user_service"] session = await get_session(request) @@ -11,18 +30,8 @@ async def get_users(request: web.Request) -> web.Response: if not current_user_email: return web.json_response({"error": "Unauthorized"}, status=401) - # For now, let's assume only the main user can see all users. - # In a real application, you'd have roles/permissions. - # The main user is the one who created the account. - # If the current user is the main user, they can see all users. - # Otherwise, they can only see users they created (their "team"). - - # This logic needs to be refined based on how "main user" is identified. - # For now, let's return all users for simplicity, assuming the logged-in user has admin-like access to this page. - # A more robust solution would involve checking if the current_user_email is the 'owner' of the site. - users = user_service.get_all_users() - - # Filter out sensitive information like password and reset tokens + users = await user_service.get_managed_users(current_user_email) + safe_users = [] for user in users: safe_user = {k: v for k, v in user.items() if k not in ["password", "reset_token", "reset_token_expiry"]} @@ -41,13 +50,12 @@ async def add_user(request: web.Request) -> web.Response: try: data = await request.json() registration_data = RegistrationModel(**data) - - # The current user is the parent of the new user - new_user = user_service.create_user( + + new_user = await user_service.create_user( full_name=registration_data.full_name, email=registration_data.email, password=registration_data.password, - parent_email=current_user_email # Assign current user as parent + parent_email=current_user_email ) safe_new_user = {k: v for k, v in new_user.items() if k not in ["password", "reset_token", "reset_token_expiry"]} return web.json_response({"message": "User added successfully", "user": safe_new_user}, status=201) @@ -64,20 +72,18 @@ async def update_user_quota(request: web.Request) -> web.Response: if not current_user_email: return web.json_response({"error": "Unauthorized"}, status=401) - + if not target_user_email: return web.json_response({"error": "User email not provided"}, status=400) - # Ensure the current user has permission to update this user's quota - # For now, allow if current_user_email is the target_user_email or if target_user is a child of current_user - target_user = user_service.get_user_by_email(target_user_email) - if not target_user or (target_user_email != current_user_email and target_user.get("parent_email") != current_user_email): - return web.json_response({"error": "Forbidden: You do not have permission to update this user's quota"}, status=403) + has_access, error_response = await verify_user_access(user_service, current_user_email, target_user_email) + if not has_access: + return web.json_response({"error": error_response["error"]}, status=error_response["status"]) try: data = await request.json() quota_update_data = QuotaUpdateModel(**data) - user_service.update_user_quota(target_user_email, quota_update_data.new_quota_gb) + await user_service.update_user_quota(target_user_email, quota_update_data.new_quota_gb) return web.json_response({"message": f"Quota for {target_user_email} updated successfully"}) except ValueError as e: return web.json_response({"error": str(e)}, status=400) @@ -92,19 +98,18 @@ async def delete_user(request: web.Request) -> web.Response: if not current_user_email: return web.json_response({"error": "Unauthorized"}, status=401) - + if not target_user_email: return web.json_response({"error": "User email not provided"}, status=400) - # Prevent a user from deleting themselves or a parent user if target_user_email == current_user_email: return web.json_response({"error": "Forbidden: You cannot delete your own account from this interface"}, status=403) - target_user = user_service.get_user_by_email(target_user_email) - if not target_user or target_user.get("parent_email") != current_user_email: - return web.json_response({"error": "Forbidden: You do not have permission to delete this user"}, status=403) + has_access, error_response = await verify_user_access(user_service, current_user_email, target_user_email) + if not has_access: + return web.json_response({"error": error_response["error"]}, status=error_response["status"]) - if user_service.delete_user(target_user_email): + if await user_service.delete_user(target_user_email): return web.json_response({"message": f"User {target_user_email} deleted successfully"}) else: return web.json_response({"error": "User not found or could not be deleted"}, status=404) @@ -117,14 +122,15 @@ async def get_user_details(request: web.Request) -> web.Response: if not current_user_email: return web.json_response({"error": "Unauthorized"}, status=401) - + if not target_user_email: return web.json_response({"error": "User email not provided"}, status=400) - target_user = user_service.get_user_by_email(target_user_email) - if not target_user or (target_user_email != current_user_email and target_user.get("parent_email") != current_user_email): - return web.json_response({"error": "Forbidden: You do not have permission to view this user's details"}, status=403) + has_access, error_response = await verify_user_access(user_service, current_user_email, target_user_email) + if not has_access: + return web.json_response({"error": error_response["error"]}, status=error_response["status"]) + target_user = await user_service.get_user_by_email(target_user_email) safe_user = {k: v for k, v in target_user.items() if k not in ["password", "reset_token", "reset_token_expiry"]} return web.json_response({"user": safe_user}) @@ -136,15 +142,14 @@ async def delete_team(request: web.Request) -> web.Response: if not current_user_email: return web.json_response({"error": "Unauthorized"}, status=401) - + if not target_parent_email: return web.json_response({"error": "Parent email not provided"}, status=400) - # Only the parent user can delete their "team" (users they created) if current_user_email != target_parent_email: return web.json_response({"error": "Forbidden: You do not have permission to delete this team"}, status=403) - deleted_count = user_service.delete_users_by_parent_email(target_parent_email) + deleted_count = await user_service.delete_users_by_parent_email(target_parent_email) if deleted_count > 0: return web.json_response({"message": f"Successfully deleted {deleted_count} users from the team managed by {target_parent_email}"}) else: diff --git a/retoors/views/auth.py b/retoors/views/auth.py index 8124094..718f4d6 100644 --- a/retoors/views/auth.py +++ b/retoors/views/auth.py @@ -55,7 +55,7 @@ class LoginView(CustomPydanticView): ) user_service: UserService = self.request.app["user_service"] - if user_service.authenticate_user(login_data.email, login_data.password): + if await user_service.authenticate_user(login_data.email, login_data.password): session = await new_session(self.request) session["user_email"] = login_data.email raise web.HTTPFound("/dashboard") @@ -93,7 +93,7 @@ class RegistrationView(CustomPydanticView): user_service: UserService = self.request.app["user_service"] try: - user_service.create_user(user_data.full_name, user_data.email, user_data.password) # Changed username to full_name + await user_service.create_user(user_data.full_name, user_data.email, user_data.password) # Render email content email_context = { @@ -148,10 +148,10 @@ class ForgotPasswordView(CustomPydanticView): ) user_service: UserService = self.request.app["user_service"] - user = user_service.get_user_by_email(forgot_password_data.email) + user = await user_service.get_user_by_email(forgot_password_data.email) if user: - token = user_service.generate_reset_token(forgot_password_data.email) + token = await user_service.generate_reset_token(forgot_password_data.email) if token: reset_link = self.request.url.join( self.request.app.router["reset_password"].url_for(token=token) @@ -217,7 +217,7 @@ class ResetPasswordView(CustomPydanticView): ) user_service: UserService = self.request.app["user_service"] - user = user_service.get_user_by_reset_token(token) # Corrected method call + user = await user_service.get_user_by_reset_token(token) if not user: return aiohttp_jinja2.render_template( self.template_name, @@ -225,7 +225,7 @@ class ResetPasswordView(CustomPydanticView): {"error": "Invalid or expired password reset link.", "request": self.request, "errors": {}, "token": token}, ) - if user_service.reset_password(user["email"], token, reset_password_data.password): + if await user_service.reset_password(user["email"], token, reset_password_data.password): # Send password changed confirmation email email_context = { "user_name": user["full_name"], diff --git a/retoors/views/migrate.py b/retoors/views/migrate.py new file mode 100644 index 0000000..c157dfe --- /dev/null +++ b/retoors/views/migrate.py @@ -0,0 +1,11 @@ +from aiohttp import web + +from ..helpers.auth import login_required + +class MigrateView(web.View): + @login_required + async def post(self): + user_email = self.request["user"]["email"] + file_service = self.request.app["file_service"] + await file_service.migrate_old_files(user_email) + return web.json_response({"status": "success", "message": "Migration completed"}) \ No newline at end of file diff --git a/retoors/views/site.py b/retoors/views/site.py index c3935b7..13a9784 100644 --- a/retoors/views/site.py +++ b/retoors/views/site.py @@ -298,7 +298,7 @@ class FileBrowserView(web.View): return json_response({"error": "Failed to generate share links for any selected items"}, status=500) logger.warning(f"FileBrowserView: Unknown file action for POST request: {route_name}") - return web.Response(status=400, text="Unknown file action") + raise web.HTTPBadRequest(text="Unknown file action") @login_required async def get_download_file(self): @@ -325,7 +325,7 @@ class FileBrowserView(web.View): raise web.HTTPNotFound(text="File not found") async def shared_file_handler(self): - share_id = self.request.match_info.get("share_id") + share_id = self.match_info.get("share_id") file_service = self.request.app["file_service"] logger.debug(f"FileBrowserView: Handling shared file request for share_id: {share_id}") @@ -333,16 +333,11 @@ class FileBrowserView(web.View): if not shared_item: logger.warning(f"FileBrowserView: Shared item not found or expired for share_id: {share_id}") - return aiohttp_jinja2.render_template( - "pages/errors/404.html", - self.request, - {"request": self.request, "message": "Shared link is invalid or has expired."}, - status=404 - ) + raise web.HTTPNotFound(text="Shared file not found or inaccessible") user_email = shared_item["user_email"] item_path = shared_item["item_path"] - full_path = file_service._get_user_file_path(user_email, item_path) + full_path = file_service.get_user_file_system_path(user_email, item_path) if full_path.is_file(): result = await file_service.get_shared_file_content(share_id) @@ -355,7 +350,7 @@ class FileBrowserView(web.View): return response else: logger.error(f"FileBrowserView: Failed to get content for shared file: {item_path} (share_id: {share_id})") - raise web.HTTPNotFound(text="Shared file not found or inaccessible") + raise web.HTTPNotFound(text="Shared file not found or inaccessible within the shared folder.") elif full_path.is_dir(): files = await file_service.get_shared_folder_content(share_id) logger.info(f"FileBrowserView: Serving shared folder '{item_path}' for share_id: {share_id}") @@ -388,16 +383,11 @@ class FileBrowserView(web.View): shared_item = await file_service.get_shared_item(share_id) if not shared_item: logger.warning(f"FileBrowserView: Shared item not found or expired for share_id: {share_id} during download.") - return aiohttp_jinja2.render_template( - "pages/errors/404.html", - self.request, - {"request": self.request, "message": "Shared link is invalid or has expired."}, - status=404 - ) + raise web.HTTPNotFound(text="Shared file not found or inaccessible within the shared folder.") # Ensure the shared item is a directory if a file_path is provided user_email = shared_item["user_email"] - original_shared_item_path = file_service._get_user_file_path(user_email, shared_item["item_path"]) + original_shared_item_path = file_service.get_user_file_system_path(user_email, shared_item["item_path"]) if not original_shared_item_path.is_dir(): logger.warning(f"FileBrowserView: Attempt to download a specific file from a shared item that is not a directory. Share_id: {share_id}") @@ -449,6 +439,25 @@ class OrderView(CustomPydanticView): class UserManagementView(web.View): + + async def verify_user_access(self, target_user_email: str) -> bool: + current_user_email = self.request["user"]["email"] + if target_user_email == current_user_email: + return True + + user_service = self.request.app["user_service"] + current_user = await user_service.get_user_by_email(current_user_email) + target_user = await user_service.get_user_by_email(target_user_email) + + if not target_user: + return False + + if current_user.get("is_customer", True): + if target_user.get("parent_email") != current_user_email: + return False + + return True + @login_required async def get(self): route_name = self.request.match_info.route.name @@ -531,14 +540,14 @@ class UserManagementView(web.View): parent_email = self.request["user"]["email"] try: - new_user = user_service.create_user( + new_user = await user_service.create_user( full_name=full_name, email=email, password=password, parent_email=parent_email ) - user_service.update_user_quota(email, float(storage_quota_gb)) + await user_service.update_user_quota(email, float(storage_quota_gb)) raise web.HTTPFound( self.request.app.router["users"].url_for().with_query( @@ -562,8 +571,11 @@ class UserManagementView(web.View): async def edit_user_page(self): email = self.request.match_info.get("email") + if not await self.verify_user_access(email): + raise web.HTTPForbidden(text="You do not have permission to access this user") + user_service = self.request.app["user_service"] - user_data = user_service.get_user_by_email(email) + user_data = await user_service.get_user_by_email(email) if not user_data: raise web.HTTPNotFound(text="User not found") @@ -585,6 +597,10 @@ class UserManagementView(web.View): async def edit_user_submit(self): email = self.request.match_info.get("email") + + if not await self.verify_user_access(email): + raise web.HTTPForbidden(text="You do not have permission to access this user") + data = await self.request.post() storage_quota_gb = data.get("storage_quota_gb", "") @@ -598,7 +614,7 @@ class UserManagementView(web.View): errors["storage_quota_gb"] = "Invalid storage quota value" user_service = self.request.app["user_service"] - user_data = user_service.get_user_by_email(email) + user_data = await user_service.get_user_by_email(email) if not user_data: raise web.HTTPNotFound(text="User not found") @@ -616,7 +632,7 @@ class UserManagementView(web.View): } ) - user_service.update_user_quota(email, storage_quota_gb) + await user_service.update_user_quota(email, storage_quota_gb) raise web.HTTPFound( self.request.app.router["edit_user"].url_for(email=email).with_query( @@ -627,8 +643,11 @@ class UserManagementView(web.View): async def user_details_page(self): email = self.request.match_info.get("email") + if not await self.verify_user_access(email): + raise web.HTTPForbidden(text="You do not have permission to access this user") + user_service = self.request.app["user_service"] - user_data = user_service.get_user_by_email(email) + user_data = await user_service.get_user_by_email(email) if not user_data: raise web.HTTPNotFound(text="User not found") @@ -647,13 +666,16 @@ class UserManagementView(web.View): async def delete_user_submit(self): email = self.request.match_info.get("email") + if not await self.verify_user_access(email): + raise web.HTTPForbidden(text="You do not have permission to access this user") + user_service = self.request.app["user_service"] - user_data = user_service.get_user_by_email(email) + user_data = await user_service.get_user_by_email(email) if not user_data: raise web.HTTPNotFound(text="User not found") - user_service.delete_user(email) + await user_service.delete_user(email) raise web.HTTPFound( self.request.app.router["users"].url_for().with_query( diff --git a/tests/conftest.py b/tests/conftest.py index 9291251..3b07013 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,16 +57,14 @@ def temp_users_json(tmp_path): @pytest.fixture def file_service_instance(temp_user_files_dir, temp_users_json): """Fixture to provide a FileService instance with temporary directories.""" - return FileService(temp_user_files_dir, temp_users_json) + user_service = UserService(temp_users_json) # Create a UserService instance + return FileService(temp_user_files_dir, user_service) # Pass the UserService instance + + @pytest.fixture -def create_app_instance(): - """Fixture to create a new aiohttp application instance.""" - return create_app() - -@pytest.fixture -def create_test_app(mocker, temp_user_files_dir, temp_users_json, file_service_instance): +def create_test_app(mocker, temp_user_files_dir, temp_users_json): """Fixture to create a test aiohttp application with mocked services.""" from aiohttp import web @@ -81,16 +79,15 @@ def create_test_app(mocker, temp_user_files_dir, temp_users_json, file_service_i app.middlewares.append(error_middleware) app.middlewares.append(user_middleware) - # Mock UserService - mock_user_service = mocker.MagicMock(spec=UserService) - + # Use a real UserService with a temporary users.json for the app + app["user_service"] = UserService(temp_users_json) + # Mock scheduler mock_scheduler = mocker.MagicMock() mock_scheduler.spawn = mocker.AsyncMock() mock_scheduler.close = mocker.AsyncMock() - app["user_service"] = mock_user_service - app["file_service"] = file_service_instance + app["file_service"] = FileService(temp_user_files_dir, app["user_service"]) app["scheduler"] = mock_scheduler # Setup Jinja2 for templates @@ -102,42 +99,14 @@ def create_test_app(mocker, temp_user_files_dir, temp_users_json, file_service_i return app -@pytest.fixture(scope="function") -def mock_users_db_fixture(): - """ - Fixture to simulate a user database for dynamic mocking, - reset for each test function. - """ - return { - "admin@example.com": { - "full_name": "Admin User", - "email": "admin@example.com", - "password": "password", # Store plain password for mock authentication - "hashed_password": "hashed_password", # For consistency with real service - "storage_quota_gb": 100, - "storage_used_gb": 10, - "parent_email": None, - "reset_token": None, - "reset_token_expiry": None, - }, - "child1@example.com": { - "full_name": "Child User 1", - "email": "child1@example.com", - "password": "password", - "hashed_password": "hashed_password", - "storage_quota_gb": 50, - "storage_used_gb": 5, - "parent_email": "admin@example.com", - "shared_items": {} - } - } + @pytest.fixture async def client( - aiohttp_client, mocker: MockerFixture, create_app_instance, mock_users_db_fixture + aiohttp_client, mocker: MockerFixture, create_test_app ): - app = create_app_instance # Use the new fixture + app = create_test_app # Use create_test_app for consistent test environment # Directly set app["scheduler"] to a mock object mock_scheduler_instance = mocker.MagicMock() @@ -145,189 +114,20 @@ async def client( mock_scheduler_instance.close = mocker.AsyncMock() # Ensure close is awaitable app["scheduler"] = mock_scheduler_instance - # Create temporary data files for testing - base_path = Path(__file__).parent.parent - data_path = base_path / "data" - data_path.mkdir(exist_ok=True) - - users_file = data_path / "users.json" - with open(users_file, "w") as f: - json.dump([], f) - - config_file = data_path / "config.json" - with open(config_file, "w") as f: - json.dump({"price_per_gb": 0.0}, f) - - app["config_service"] = ConfigService(data_path / "config.json") + # The UserService and ConfigService are now set up in create_test_app with temporary files. + # No need to manually create users.json or config.json here. client = await aiohttp_client(app) - # Access the real UserService instance and mock its methods - mock_user_service_instance = client.app["user_service"] - - # Use the mock_users_db_fixture - mock_users_db = mock_users_db_fixture - - def mock_authenticate_user(email, password): - user = mock_users_db.get(email) - if user and user["password"] == password: - return user - return None - - def mock_get_user_by_email(email): - return mock_users_db.get(email) - - def mock_create_user(full_name, email, password, parent_email=None): - if email in mock_users_db: - raise ValueError("User with this email already exists") - new_user = { - "full_name": full_name, - "email": email, - "password": password, - "hashed_password": "hashed_password", - "storage_quota_gb": 5, - "storage_used_gb": 0, - "parent_email": parent_email, - "reset_token": None, - "reset_token_expiry": None, - } - mock_users_db[email] = new_user - return new_user - - def mock_reset_password(email, token, new_password): # Added token argument - user = mock_users_db.get(email) - if user and user.get("reset_token") == token and user.get("reset_token_expiry"): - expiry_time = datetime.datetime.fromisoformat(user["reset_token_expiry"]) - if expiry_time > datetime.datetime.now(datetime.timezone.utc): - user["password"] = new_password - user["hashed_password"] = "new_hashed_password" # Simulate hashing - user["reset_token"] = None - user["reset_token_expiry"] = None - return True - return False - - def mock_generate_reset_token(email): - user = mock_users_db.get(email) - if user: - # In a real scenario, this would generate a unique token and expiry - user["reset_token"] = "test_token" - user["reset_token_expiry"] = "2030-11-08T20:00:00Z" # A future date - return "test_token" - return None - - def mock_validate_reset_token(email, token): - if ( - token == "expiredtoken123" - ): # Explicitly handle the expired token from the test - return False - user = mock_users_db.get(email) - if user and user.get("reset_token") == token and user.get("reset_token_expiry"): - expiry_time = datetime.datetime.fromisoformat( - user["reset_token_expiry"] - ) - if expiry_time > datetime.datetime.now(datetime.timezone.utc): - return True - return False - - def mock_save_users(): - # This mock ensures that changes to user objects within tests are reflected in mock_users_db - # In a real scenario, this would write to a file or database. - pass # The mock_users_db is already being modified directly by other mocks - - def mock_get_all_users(): - return list(mock_users_db.values()) - - def mock_get_users_by_parent_email(parent_email): - return [ - user - for user in mock_users_db.values() - if user.get("parent_email") == parent_email - ] - - def mock_delete_user(email): - if email in mock_users_db: - del mock_users_db[email] - return True - return False - - def mock_delete_users_by_parent_email(parent_email): - initial_count = len(mock_users_db) - users_to_delete = [ - email - for email, user in mock_users_db.items() - if user.get("parent_email") == parent_email - ] - for email in users_to_delete: - del mock_users_db[email] - return initial_count - len(mock_users_db) - - mocker.patch.object( - mock_user_service_instance, - "authenticate_user", - side_effect=mock_authenticate_user, - ) - mocker.patch.object( - mock_user_service_instance, - "get_user_by_email", - side_effect=mock_get_user_by_email, - ) - mocker.patch.object( - mock_user_service_instance, "create_user", side_effect=mock_create_user - ) - mocker.patch.object( - mock_user_service_instance, "get_all_users", side_effect=mock_get_all_users - ) - mocker.patch.object( - mock_user_service_instance, "update_user_quota", return_value=None - ) # Keep as is for now - mocker.patch.object( - mock_user_service_instance, "delete_user", side_effect=mock_delete_user - ) - mocker.patch.object( - mock_user_service_instance, - "get_users_by_parent_email", - side_effect=mock_get_users_by_parent_email, - ) - mocker.patch.object( - mock_user_service_instance, - "delete_users_by_parent_email", - side_effect=mock_delete_users_by_parent_email, - ) - mocker.patch.object( - mock_user_service_instance, - "generate_reset_token", - side_effect=mock_generate_reset_token, - ) - mocker.patch.object( - mock_user_service_instance, - "get_user_by_reset_token", - side_effect=lambda token: next( - ( - user - for user in mock_users_db.values() - if user.get("reset_token") == token - ), - None, - ), - ) - mocker.patch.object( - mock_user_service_instance, "reset_password", side_effect=mock_reset_password - ) - mocker.patch.object( - mock_user_service_instance, - "validate_reset_token", - side_effect=mock_validate_reset_token, - ) - mocker.patch.object( - mock_user_service_instance, "_save_users", side_effect=mock_save_users - ) + # The UserService is now a real instance, so we don't need to mock its methods here. + # The mock_users_db_fixture is also no longer needed in this context. try: yield client finally: - # Clean up temporary files - users_file.unlink(missing_ok=True) - config_file.unlink(missing_ok=True) # Use missing_ok for robustness + # Clean up temporary files created by create_test_app if necessary, + # but tmp_path usually handles this. + pass @pytest.fixture @@ -338,38 +138,9 @@ async def logged_in_client(aiohttp_client, create_test_app, mocker): user_service = app["user_service"] - def mock_create_user(full_name, email, password, parent_email=None): - return { - "full_name": full_name, - "email": email, - "password": "hashed_password", - "storage_quota_gb": 10, - "storage_used_gb": 0, - "parent_email": parent_email, - "shared_items": {} - } - - def mock_authenticate_user(email, password): - return { - "email": email, - "full_name": "Test User", - "is_admin": False, - "storage_quota_gb": 10, - "storage_used_gb": 0 - } - - def mock_get_user_by_email(email): - return { - "email": email, - "full_name": "Test User", - "is_admin": False, - "storage_quota_gb": 10, - "storage_used_gb": 0 - } - - mocker.patch.object(user_service, "create_user", side_effect=mock_create_user) - mocker.patch.object(user_service, "authenticate_user", side_effect=mock_authenticate_user) - mocker.patch.object(user_service, "get_user_by_email", side_effect=mock_get_user_by_email) + # The UserService is now a real instance, so we don't need to mock its methods here. + # The create_user, authenticate_user, and get_user_by_email methods will interact + # with the real UserService instance. await client.post( "/register", @@ -394,38 +165,9 @@ async def logged_in_admin_client(aiohttp_client, create_test_app, mocker): user_service = app["user_service"] - def mock_create_user(full_name, email, password, parent_email=None): - return { - "full_name": full_name, - "email": email, - "password": "hashed_password", - "storage_quota_gb": 100, - "storage_used_gb": 0, - "parent_email": parent_email, - "shared_items": {} - } - - def mock_authenticate_user(email, password): - return { - "email": email, - "full_name": "Admin User", - "is_admin": True, - "storage_quota_gb": 100, - "storage_used_gb": 0 - } - - def mock_get_user_by_email(email): - return { - "email": email, - "full_name": "Admin User", - "is_admin": True, - "storage_quota_gb": 100, - "storage_used_gb": 0 - } - - mocker.patch.object(user_service, "create_user", side_effect=mock_create_user) - mocker.patch.object(user_service, "authenticate_user", side_effect=mock_authenticate_user) - mocker.patch.object(user_service, "get_user_by_email", side_effect=mock_get_user_by_email) + # The UserService is now a real instance, so we don't need to mock its methods here. + # The create_user, authenticate_user, and get_user_by_email methods will interact + # with the real UserService instance. await client.post( "/register", diff --git a/tests/test_auth.py b/tests/test_auth.py index daa3a68..8195110 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -204,7 +204,7 @@ async def test_reset_password_get_valid_token(client): "confirm_password": "old_password", }, ) - token = user_service.generate_reset_token("test@example.com") + token = await user_service.generate_reset_token("test@example.com") assert token is not None resp = await client.get(f"/reset_password/{token}") @@ -233,7 +233,7 @@ async def test_reset_password_post_success(client, mock_send_email): "confirm_password": "old_password", }, ) - token = user_service.generate_reset_token("test@example.com") + token = await user_service.generate_reset_token("test@example.com") assert token is not None resp = await client.post( @@ -248,8 +248,8 @@ async def test_reset_password_post_success(client, mock_send_email): assert resp.headers["Location"] == "/login?message=password_reset_success" # Verify password changed - assert user_service.authenticate_user("test@example.com", "new_password") - assert not user_service.authenticate_user("test@example.com", "old_password") + assert await user_service.authenticate_user("test@example.com", "new_password") + assert not await user_service.authenticate_user("test@example.com", "old_password") # Assert that confirmation email was sent @@ -272,7 +272,7 @@ async def test_reset_password_post_password_mismatch(client): "confirm_password": "old_password", }, ) - token = user_service.generate_reset_token("test@example.com") + token = await user_service.generate_reset_token("test@example.com") assert token is not None resp = await client.post( @@ -286,7 +286,7 @@ async def test_reset_password_post_password_mismatch(client): text = await resp.text() assert "Passwords do not match" in text # Password should not have changed - assert user_service.authenticate_user("test@example.com", "old_password") + assert await user_service.authenticate_user("test@example.com", "old_password") async def test_reset_password_post_invalid_token(client): @@ -301,7 +301,7 @@ async def test_reset_password_post_invalid_token(client): }, ) # Generate a token but don't use it, or use an expired one - user_service.generate_reset_token("test@example.com") # This will be overwritten or ignored + await user_service.generate_reset_token("test@example.com") # This will be overwritten or ignored resp = await client.post( "/reset_password/invalidtoken", @@ -314,7 +314,7 @@ async def test_reset_password_post_invalid_token(client): text = await resp.text() assert "Invalid or expired password reset link." in text # Password should not have changed - assert user_service.authenticate_user("test@example.com", "old_password") + assert await user_service.authenticate_user("test@example.com", "old_password") async def test_reset_password_post_expired_token(client, mock_users_db_fixture): @@ -329,7 +329,7 @@ async def test_reset_password_post_expired_token(client, mock_users_db_fixture): }, ) # Manually set an expired token - user = user_service.get_user_by_email("test@example.com") + user = await user_service.get_user_by_email("test@example.com") token = "expiredtoken123" user["reset_token"] = token user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat() @@ -346,7 +346,7 @@ async def test_reset_password_post_expired_token(client, mock_users_db_fixture): text = await resp.text() assert "Invalid or expired password reset link." in text # Password should not have changed - assert user_service.authenticate_user("test@example.com", "old_password") + assert await user_service.authenticate_user("test@example.com", "old_password") async def test_reset_password_post_invalid_password_format(client): @@ -360,7 +360,7 @@ async def test_reset_password_post_invalid_password_format(client): "confirm_password": "old_password", }, ) - token = user_service.generate_reset_token("test@example.com") + token = await user_service.generate_reset_token("test@example.com") assert token is not None resp = await client.post( @@ -374,4 +374,4 @@ async def test_reset_password_post_invalid_password_format(client): text = await resp.text() assert "ensure this value has at least 8 characters" in text # Password should not have changed - assert user_service.authenticate_user("test@example.com", "old_password") + assert await user_service.authenticate_user("test@example.com", "old_password") diff --git a/tests/test_file_browser.py b/tests/test_file_browser.py index c472f26..3730968 100644 --- a/tests/test_file_browser.py +++ b/tests/test_file_browser.py @@ -27,33 +27,38 @@ async def test_file_service_list_files_empty(file_service_instance): assert files == [] @pytest.mark.asyncio -async def test_file_service_create_folder(file_service_instance, temp_user_files_dir): +async def test_file_service_create_folder(file_service_instance): user_email = "test@example.com" folder_name = "my_new_folder" success = await file_service_instance.create_folder(user_email, folder_name) assert success - expected_path = temp_user_files_dir / user_email / folder_name - assert expected_path.is_dir() + metadata = await file_service_instance._load_metadata(user_email) + assert folder_name in metadata + assert metadata[folder_name]["type"] == "dir" @pytest.mark.asyncio -async def test_file_service_create_folder_exists(file_service_instance, temp_user_files_dir): +async def test_file_service_create_folder_exists(file_service_instance): user_email = "test@example.com" folder_name = "existing_folder" - (temp_user_files_dir / user_email).mkdir(parents=True) - (temp_user_files_dir / user_email / folder_name).mkdir(parents=True) + await file_service_instance.create_folder(user_email, folder_name) # Create it first via service success = await file_service_instance.create_folder(user_email, folder_name) assert not success # Should return False if folder already exists @pytest.mark.asyncio -async def test_file_service_upload_file(file_service_instance, temp_user_files_dir): +async def test_file_service_upload_file(file_service_instance): user_email = "test@example.com" file_name = "document.txt" file_content = b"Hello, world!" success = await file_service_instance.upload_file(user_email, file_name, file_content) assert success - expected_path = temp_user_files_dir / user_email / file_name - assert expected_path.is_file() - assert expected_path.read_bytes() == file_content + metadata = await file_service_instance._load_metadata(user_email) + assert file_name in metadata + assert metadata[file_name]["type"] == "file" + assert metadata[file_name]["size"] == len(file_content) + # Verify content by downloading + downloaded_content, downloaded_name = await file_service_instance.download_file(user_email, file_name) + assert downloaded_content == file_content + assert downloaded_name == file_name @pytest.mark.asyncio async def test_file_service_list_files_with_content(file_service_instance, temp_user_files_dir): @@ -89,27 +94,38 @@ async def test_file_service_download_file_not_found(file_service_instance): assert content is None @pytest.mark.asyncio -async def test_file_service_delete_file(file_service_instance, temp_user_files_dir): +async def test_file_service_delete_file(file_service_instance): user_email = "test@example.com" file_name = "to_delete.txt" - (temp_user_files_dir / user_email).mkdir(exist_ok=True) - (temp_user_files_dir / user_email / file_name).write_bytes(b"delete me") + await file_service_instance.upload_file(user_email, file_name, b"delete me") + + metadata_before = await file_service_instance._load_metadata(user_email) + assert file_name in metadata_before success = await file_service_instance.delete_item(user_email, file_name) assert success - assert not (temp_user_files_dir / user_email / file_name).exists() + + metadata_after = await file_service_instance._load_metadata(user_email) + assert file_name not in metadata_after @pytest.mark.asyncio -async def test_file_service_delete_folder(file_service_instance, temp_user_files_dir): +async def test_file_service_delete_folder(file_service_instance): user_email = "test@example.com" folder_name = "folder_to_delete" - (temp_user_files_dir / user_email).mkdir(parents=True) - (temp_user_files_dir / user_email / folder_name).mkdir(parents=True) - (temp_user_files_dir / user_email / folder_name / "nested.txt").write_bytes(b"nested") + nested_file = f"{folder_name}/nested.txt" + await file_service_instance.create_folder(user_email, folder_name) + await file_service_instance.upload_file(user_email, nested_file, b"nested content") + + metadata_before = await file_service_instance._load_metadata(user_email) + assert folder_name in metadata_before + assert nested_file in metadata_before success = await file_service_instance.delete_item(user_email, folder_name) assert success - assert not (temp_user_files_dir / user_email / folder_name).exists() + + metadata_after = await file_service_instance._load_metadata(user_email) + assert folder_name not in metadata_after + assert nested_file not in metadata_after @pytest.mark.asyncio async def test_file_service_delete_nonexistent(file_service_instance): @@ -199,14 +215,15 @@ async def test_file_browser_get_authorized_with_files(logged_in_client: TestClie assert "my_file.txt" in text @pytest.mark.asyncio -async def test_file_browser_new_folder(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): +async def test_file_browser_new_folder(logged_in_client: TestClient, file_service_instance): user_email = "test@example.com" resp = await logged_in_client.post("/files/new_folder", data={"folder_name": "new_folder_via_web"}, allow_redirects=False) assert resp.status == 302 # Redirect assert resp.headers["Location"].startswith("/files") - expected_path = temp_user_files_dir / user_email / "new_folder_via_web" - assert expected_path.is_dir() + metadata = await file_service_instance._load_metadata(user_email) + assert "new_folder_via_web" in metadata + assert metadata["new_folder_via_web"]["type"] == "dir" @pytest.mark.asyncio async def test_file_browser_new_folder_missing_name(logged_in_client: TestClient): @@ -228,23 +245,29 @@ async def test_file_browser_new_folder_exists(logged_in_client: TestClient, file assert f"error=Folder+'{folder_name}'+already+exists+or+could+not+be+created" in resp.headers["Location"] @pytest.mark.asyncio -async def test_file_browser_upload_file(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): +async def test_file_browser_upload_file(logged_in_client: TestClient, file_service_instance): user_email = "test@example.com" + file_name = "uploaded.txt" file_content = b"Uploaded content from web." from io import BytesIO data = aiohttp.FormData() data.add_field('file', BytesIO(file_content), - filename='uploaded.txt', + filename=file_name, content_type='text/plain') resp = await logged_in_client.post("/files/upload", data=data, allow_redirects=False) assert resp.status == 200 - expected_path = temp_user_files_dir / user_email / "uploaded.txt" - assert expected_path.is_file() - assert expected_path.read_bytes() == file_content + metadata = await file_service_instance._load_metadata(user_email) + assert file_name in metadata + assert metadata[file_name]["type"] == "file" + assert metadata[file_name]["size"] == len(file_content) + # Verify content by downloading + downloaded_content, downloaded_name = await file_service_instance.download_file(user_email, file_name) + assert downloaded_content == file_content + assert downloaded_name == file_name @pytest.mark.asyncio async def test_file_browser_download_file(logged_in_client: TestClient, file_service_instance): @@ -265,41 +288,52 @@ async def test_file_browser_download_file_not_found(logged_in_client: TestClient assert "File not found" in await response.text() @pytest.mark.asyncio -async def test_file_browser_delete_file(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): +async def test_file_browser_delete_file(logged_in_client: TestClient, file_service_instance): user_email = "test@example.com" file_name = "web_delete.txt" await file_service_instance.upload_file(user_email, file_name, b"delete this") - expected_path = temp_user_files_dir / user_email / file_name - assert expected_path.is_file() + metadata_before = await file_service_instance._load_metadata(user_email) + assert file_name in metadata_before resp = await logged_in_client.post(f"/files/delete/{file_name}", allow_redirects=False) assert resp.status == 302 # Redirect assert resp.headers["Location"].startswith("/files") - assert not expected_path.is_file() + + metadata_after = await file_service_instance._load_metadata(user_email) + assert file_name not in metadata_after @pytest.mark.asyncio -async def test_file_browser_delete_folder(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): +async def test_file_browser_delete_folder(logged_in_client: TestClient, file_service_instance): user_email = "test@example.com" folder_name = "web_delete_folder" + nested_file = f"{folder_name}/nested.txt" await file_service_instance.create_folder(user_email, folder_name) - await file_service_instance.upload_file(user_email, f"{folder_name}/nested.txt", b"nested") + await file_service_instance.upload_file(user_email, nested_file, b"nested") - expected_path = temp_user_files_dir / user_email / folder_name - assert expected_path.is_dir() + metadata_before = await file_service_instance._load_metadata(user_email) + assert folder_name in metadata_before + assert nested_file in metadata_before resp = await logged_in_client.post(f"/files/delete/{folder_name}", allow_redirects=False) assert resp.status == 302 # Redirect assert resp.headers["Location"].startswith("/files") - assert not expected_path.is_dir() + + metadata_after = await file_service_instance._load_metadata(user_email) + assert folder_name not in metadata_after + assert nested_file not in metadata_after @pytest.mark.asyncio -async def test_file_browser_delete_multiple_files(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): +async def test_file_browser_delete_multiple_files(logged_in_client: TestClient, file_service_instance): user_email = "test@example.com" file_names = ["multi_delete_1.txt", "multi_delete_2.txt", "multi_delete_3.txt"] for name in file_names: await file_service_instance.upload_file(user_email, name, b"content") + metadata_before = await file_service_instance._load_metadata(user_email) + for name in file_names: + assert name in metadata_before + paths_to_delete = [f"{name}" for name in file_names] # Construct FormData for multiple paths @@ -311,18 +345,23 @@ async def test_file_browser_delete_multiple_files(logged_in_client: TestClient, assert resp.status == 302 # Redirect assert resp.headers["Location"].startswith("/files") + metadata_after = await file_service_instance._load_metadata(user_email) for name in file_names: - expected_path = temp_user_files_dir / user_email / name - assert not expected_path.is_file() + assert name not in metadata_after @pytest.mark.asyncio -async def test_file_browser_delete_multiple_folders(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): +async def test_file_browser_delete_multiple_folders(logged_in_client: TestClient, file_service_instance): user_email = "test@example.com" folder_names = ["multi_delete_folder_1", "multi_delete_folder_2"] for name in folder_names: await file_service_instance.create_folder(user_email, name) await file_service_instance.upload_file(user_email, f"{name}/nested.txt", b"nested content") + metadata_before = await file_service_instance._load_metadata(user_email) + for name in folder_names: + assert name in metadata_before + assert f"{name}/nested.txt" in metadata_before + paths_to_delete = [f"{name}" for name in folder_names] # Construct FormData for multiple paths @@ -334,9 +373,10 @@ async def test_file_browser_delete_multiple_folders(logged_in_client: TestClient assert resp.status == 302 # Redirect assert resp.headers["Location"].startswith("/files") + metadata_after = await file_service_instance._load_metadata(user_email) for name in folder_names: - expected_path = temp_user_files_dir / user_email / name - assert not expected_path.is_dir() + assert name not in metadata_after + assert f"{name}/nested.txt" not in metadata_after @pytest.mark.asyncio async def test_file_browser_delete_multiple_items_no_paths(logged_in_client: TestClient): @@ -345,7 +385,7 @@ async def test_file_browser_delete_multiple_items_no_paths(logged_in_client: Tes assert "error=No+items+selected+for+deletion" in resp.headers["Location"] @pytest.mark.asyncio -async def test_file_browser_delete_multiple_items_some_fail(logged_in_client: TestClient, file_service_instance, temp_user_files_dir, mocker): +async def test_file_browser_delete_multiple_items_some_fail(logged_in_client: TestClient, file_service_instance, mocker): user_email = "test@example.com" file_names = ["fail_delete_1.txt", "fail_delete_2.txt"] for name in file_names: @@ -370,10 +410,11 @@ async def test_file_browser_delete_multiple_items_some_fail(logged_in_client: Te assert resp.status == 302 # Redirect assert "error=Some+items+failed+to+delete" in resp.headers["Location"] + metadata_after = await file_service_instance._load_metadata(user_email) # Check if the first file still exists (failed to delete) - assert (temp_user_files_dir / user_email / file_names[0]).is_file() + assert file_names[0] in metadata_after # Check if the second file is deleted (succeeded) - assert not (temp_user_files_dir / user_email / file_names[1]).is_file() + assert file_names[1] not in metadata_after @pytest.mark.asyncio async def test_file_browser_share_multiple_items_no_paths(logged_in_client: TestClient): @@ -434,7 +475,7 @@ async def test_file_browser_download_shared_file_handler_fail_get_content(client "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), "expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat() }) - mocker.patch.object(file_service_instance, "_get_user_file_path", return_value=mocker.MagicMock(is_dir=lambda: True)) + mocker.patch.object(file_service_instance, "get_user_file_system_path", return_value=mocker.MagicMock(is_dir=lambda: True)) mocker.patch.object(file_service_instance, "get_shared_file_content", return_value=None) resp = await client.get(f"/shared_file/{share_id}/download?file_path={file_name}") @@ -447,14 +488,14 @@ async def test_file_browser_download_shared_file_handler_not_a_directory(client: file_name = "shared_file.txt" share_id = "test_share_id" - mocker.patch.object(file_service_instance, "get_shared_item", return_value={ + mocker.patch.object(client.app["file_service"], "get_shared_item", return_value={ "user_email": user_email, "item_path": file_name, "share_id": share_id, "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), "expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat() }) - mocker.patch.object(file_service_instance, "_get_user_file_path", return_value=mocker.MagicMock(is_dir=lambda: False)) + mocker.patch.object(client.app["file_service"], "get_user_file_system_path", return_value=mocker.MagicMock(is_dir=lambda: False)) resp = await client.get(f"/shared_file/{share_id}/download?file_path=some_file.txt") assert resp.status == 400 @@ -462,11 +503,11 @@ async def test_file_browser_download_shared_file_handler_not_a_directory(client: assert "Cannot download specific files from a shared item that is not a folder." in text @pytest.mark.asyncio async def test_file_browser_download_shared_file_handler_shared_item_not_found(client: TestClient, file_service_instance, mocker): - mocker.patch.object(file_service_instance, "get_shared_item", return_value=None) + mocker.patch.object(client.app["file_service"], "get_shared_item", return_value=None) resp = await client.get("/shared_file/nonexistent_share_id/download?file_path=some_file.txt") assert resp.status == 404 text = await resp.text() - assert "Shared link is invalid or has expired." in text + assert "Shared file not found or inaccessible" in text @pytest.mark.asyncio async def test_file_browser_download_shared_file_handler_missing_file_path(client: TestClient): @@ -480,7 +521,7 @@ async def test_file_browser_shared_file_handler_fail_get_content(client: TestCli file_name = "shared_file.txt" share_id = "test_share_id" - mocker.patch.object(file_service_instance, "get_shared_item", return_value={ + mocker.patch.object(client.app["file_service"], "get_shared_item", return_value={ "user_email": user_email, "item_path": file_name, "share_id": share_id, @@ -488,7 +529,7 @@ async def test_file_browser_shared_file_handler_fail_get_content(client: TestCli "expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat() }) mocker.patch("pathlib.Path.is_file", return_value=True) # Simulate it's a file - mocker.patch.object(file_service_instance, "get_shared_file_content", return_value=None) + mocker.patch.object(client.app["file_service"], "get_shared_file_content", return_value=None) resp = await client.get(f"/shared_file/{share_id}") assert resp.status == 404 @@ -501,7 +542,7 @@ async def test_file_browser_shared_file_handler_neither_file_nor_dir(client: Tes item_path = "mystery_item" share_id = "test_share_id" - mocker.patch.object(file_service_instance, "get_shared_item", return_value={ + mocker.patch.object(client.app["file_service"], "get_shared_item", return_value={ "user_email": user_email, "item_path": item_path, "share_id": share_id, @@ -520,11 +561,11 @@ async def test_file_browser_shared_file_handler_neither_file_nor_dir(client: Tes @pytest.mark.asyncio async def test_file_browser_shared_file_handler_not_found(client: TestClient, file_service_instance, mocker): - mocker.patch.object(file_service_instance, "get_shared_item", return_value=None) + mocker.patch.object(client.app["file_service"], "get_shared_item", return_value=None) resp = await client.get("/shared_file/nonexistent_share_id") assert resp.status == 404 text = await resp.text() - assert "Shared link is invalid or has expired." in text + assert "Shared file not found or inaccessible" in text @pytest.mark.asyncio async def test_file_browser_unknown_post_action(logged_in_client: TestClient, mocker): diff --git a/tests/test_storage_service.py b/tests/test_storage_service.py new file mode 100644 index 0000000..a17a68d --- /dev/null +++ b/tests/test_storage_service.py @@ -0,0 +1,183 @@ +import pytest +import asyncio +import shutil +from pathlib import Path +from retoors.services.storage_service import StorageService, UserStorageManager + + +@pytest.fixture +def test_storage(): + storage = StorageService(base_path="data/test_user") + yield storage + if Path("data/test_user").exists(): + shutil.rmtree("data/test_user") + + +@pytest.fixture +def user_manager(): + manager = UserStorageManager() + manager.storage.base_path = Path("data/test_user") + yield manager + if Path("data/test_user").exists(): + shutil.rmtree("data/test_user") + + +@pytest.mark.asyncio +async def test_save_and_load(test_storage): + user_email = "test@example.com" + identifier = "doc1" + data = {"title": "Test Document", "content": "Test content"} + + result = await test_storage.save(user_email, identifier, data) + assert result is True + + loaded_data = await test_storage.load(user_email, identifier) + assert loaded_data == data + + +@pytest.mark.asyncio +async def test_distributed_path_structure(test_storage): + user_email = "test@example.com" + identifier = "doc1" + data = {"test": "data"} + + await test_storage.save(user_email, identifier, data) + + user_base = test_storage._get_user_base_path(user_email) + file_path = test_storage._get_distributed_path(user_base, identifier) + + assert file_path.exists() + assert len(file_path.parent.name) == 3 + assert len(file_path.parent.parent.name) == 3 + assert len(file_path.parent.parent.parent.name) == 3 + + +@pytest.mark.asyncio +async def test_user_isolation(test_storage): + user1_email = "user1@example.com" + user2_email = "user2@example.com" + identifier = "doc1" + data1 = {"user": "user1"} + data2 = {"user": "user2"} + + await test_storage.save(user1_email, identifier, data1) + await test_storage.save(user2_email, identifier, data2) + + loaded1 = await test_storage.load(user1_email, identifier) + loaded2 = await test_storage.load(user2_email, identifier) + + assert loaded1 == data1 + assert loaded2 == data2 + + +@pytest.mark.asyncio +async def test_path_traversal_protection(test_storage): + user_email = "test@example.com" + malicious_identifier = "../../../etc/passwd" + + await test_storage.save(user_email, malicious_identifier, {"test": "data"}) + + user_base = test_storage._get_user_base_path(user_email) + file_path = test_storage._get_distributed_path(user_base, malicious_identifier) + + assert file_path.exists() + assert test_storage._validate_path(file_path, user_base) + assert str(file_path.resolve()).startswith(str(user_base.resolve())) + + +@pytest.mark.asyncio +async def test_delete(test_storage): + user_email = "test@example.com" + identifier = "doc1" + data = {"test": "data"} + + await test_storage.save(user_email, identifier, data) + assert await test_storage.exists(user_email, identifier) + + result = await test_storage.delete(user_email, identifier) + assert result is True + assert not await test_storage.exists(user_email, identifier) + + +@pytest.mark.asyncio +async def test_list_all(test_storage): + user_email = "test@example.com" + + await test_storage.save(user_email, "doc1", {"id": 1}) + await test_storage.save(user_email, "doc2", {"id": 2}) + await test_storage.save(user_email, "doc3", {"id": 3}) + + all_docs = await test_storage.list_all(user_email) + assert len(all_docs) == 3 + assert any(doc["id"] == 1 for doc in all_docs) + assert any(doc["id"] == 2 for doc in all_docs) + assert any(doc["id"] == 3 for doc in all_docs) + + +@pytest.mark.asyncio +async def test_delete_all(test_storage): + user_email = "test@example.com" + + await test_storage.save(user_email, "doc1", {"id": 1}) + await test_storage.save(user_email, "doc2", {"id": 2}) + + result = await test_storage.delete_all(user_email) + assert result is True + + all_docs = await test_storage.list_all(user_email) + assert len(all_docs) == 0 + + +@pytest.mark.asyncio +async def test_user_storage_manager(user_manager): + user_data = { + "full_name": "Test User", + "email": "test@example.com", + "password": "hashed_password", + "is_customer": True + } + + await user_manager.save_user("test@example.com", user_data) + + loaded_user = await user_manager.get_user("test@example.com") + assert loaded_user == user_data + + assert await user_manager.user_exists("test@example.com") + + await user_manager.delete_user("test@example.com") + assert not await user_manager.user_exists("test@example.com") + + +@pytest.mark.asyncio +async def test_list_users_by_parent(user_manager): + parent_user = { + "email": "parent@example.com", + "full_name": "Parent User", + "password": "hashed", + "is_customer": True + } + + child_user1 = { + "email": "child1@example.com", + "full_name": "Child User 1", + "password": "hashed", + "parent_email": "parent@example.com", + "is_customer": True + } + + child_user2 = { + "email": "child2@example.com", + "full_name": "Child User 2", + "password": "hashed", + "parent_email": "parent@example.com", + "is_customer": True + } + + await user_manager.save_user("parent@example.com", parent_user) + await user_manager.save_user("child1@example.com", child_user1) + await user_manager.save_user("child2@example.com", child_user2) + + children = await user_manager.list_users_by_parent("parent@example.com") + assert len(children) == 2 + assert any(u["email"] == "child1@example.com" for u in children) + assert any(u["email"] == "child2@example.com" for u in children) diff --git a/tests/test_user_service.py b/tests/test_user_service.py index a5e69a8..8ecca28 100644 --- a/tests/test_user_service.py +++ b/tests/test_user_service.py @@ -18,28 +18,28 @@ def user_service(users_file): return UserService(users_file) @pytest.fixture -def populated_user_service(user_service): +async def populated_user_service(user_service): """Fixture to provide a UserService instance with some pre-populated users.""" - user_service.create_user("Admin User", "admin@example.com", "adminpass") - user_service.create_user("Parent User", "parent@example.com", "parentpass") - user_service.create_user("Child User 1", "child1@example.com", "childpass", "parent@example.com") - user_service.create_user("Child User 2", "child2@example.com", "childpass", "parent@example.com") + await user_service.create_user("Admin User", "admin@example.com", "adminpass") + await user_service.create_user("Parent User", "parent@example.com", "parentpass") + await user_service.create_user("Child User 1", "child1@example.com", "childpass", "parent@example.com") + await user_service.create_user("Child User 2", "child2@example.com", "childpass", "parent@example.com") return user_service async def test_create_user_success(user_service): - user = user_service.create_user("Test User", "test@example.com", "password123") + user = await user_service.create_user("Test User", "test@example.com", "password123") assert user is not None assert user["email"] == "test@example.com" - assert user_service.get_user_by_email("test@example.com") is not None + assert await user_service.get_user_by_email("test@example.com") is not None assert bcrypt.checkpw(b"password123", user["password"].encode('utf-8')) async def test_create_user_duplicate_email(user_service): - user_service.create_user("Test User", "test@example.com", "password123") + await user_service.create_user("Test User", "test@example.com", "password123") with pytest.raises(ValueError, match="User with this email already exists"): - user_service.create_user("Another User", "test@example.com", "anotherpass") + await user_service.create_user("Another User", "test@example.com", "anotherpass") async def test_get_all_users(populated_user_service): - users = populated_user_service.get_all_users() + users = await populated_user_service.get_all_users() assert len(users) == 4 emails = {user["email"] for user in users} assert "admin@example.com" in emails @@ -48,70 +48,70 @@ async def test_get_all_users(populated_user_service): assert "child2@example.com" in emails async def test_get_users_by_parent_email(populated_user_service): - children = populated_user_service.get_users_by_parent_email("parent@example.com") + children = await populated_user_service.get_users_by_parent_email("parent@example.com") assert len(children) == 2 child_emails = {user["email"] for user in children} assert "child1@example.com" in child_emails assert "child2@example.com" in child_emails - no_children = populated_user_service.get_users_by_parent_email("nonexistent@example.com") + no_children = await populated_user_service.get_users_by_parent_email("nonexistent@example.com") assert len(no_children) == 0 - admin_children = populated_user_service.get_users_by_parent_email("admin@example.com") + admin_children = await populated_user_service.get_users_by_parent_email("admin@example.com") assert len(admin_children) == 0 async def test_update_user_non_password_fields(populated_user_service): - updated_user = populated_user_service.update_user("admin@example.com", full_name="Administrator", storage_quota_gb=10) + updated_user = await populated_user_service.update_user("admin@example.com", full_name="Administrator", storage_quota_gb=10) assert updated_user is not None assert updated_user["full_name"] == "Administrator" assert updated_user["storage_quota_gb"] == 10 - retrieved_user = populated_user_service.get_user_by_email("admin@example.com") + retrieved_user = await populated_user_service.get_user_by_email("admin@example.com") assert retrieved_user["full_name"] == "Administrator" assert retrieved_user["storage_quota_gb"] == 10 async def test_update_user_password(populated_user_service): - updated_user = populated_user_service.update_user("admin@example.com", password="newadminpass") + updated_user = await populated_user_service.update_user("admin@example.com", password="newadminpass") assert updated_user is not None - assert populated_user_service.authenticate_user("admin@example.com", "newadminpass") - assert not populated_user_service.authenticate_user("admin@example.com", "adminpass") + assert await populated_user_service.authenticate_user("admin@example.com", "newadminpass") + assert not await populated_user_service.authenticate_user("admin@example.com", "adminpass") async def test_update_user_nonexistent(user_service): - updated_user = user_service.update_user("nonexistent@example.com", full_name="Non Existent") + updated_user = await user_service.update_user("nonexistent@example.com", full_name="Non Existent") assert updated_user is None async def test_delete_user_success(populated_user_service): - assert populated_user_service.delete_user("admin@example.com") is True - assert populated_user_service.get_user_by_email("admin@example.com") is None - assert len(populated_user_service.get_all_users()) == 3 + assert await populated_user_service.delete_user("admin@example.com") is True + assert await populated_user_service.get_user_by_email("admin@example.com") is None + assert len(await populated_user_service.get_all_users()) == 3 async def test_delete_user_nonexistent(user_service): - assert user_service.delete_user("nonexistent@example.com") is False + assert await user_service.delete_user("nonexistent@example.com") is False async def test_delete_users_by_parent_email_success(populated_user_service): - deleted_count = populated_user_service.delete_users_by_parent_email("parent@example.com") + deleted_count = await populated_user_service.delete_users_by_parent_email("parent@example.com") assert deleted_count == 2 - assert populated_user_service.get_user_by_email("child1@example.com") is None - assert populated_user_service.get_user_by_email("child2@example.com") is None - assert len(populated_user_service.get_all_users()) == 2 # Admin and Parent users remain + assert await populated_user_service.get_user_by_email("child1@example.com") is None + assert await populated_user_service.get_user_by_email("child2@example.com") is None + assert len(await populated_user_service.get_all_users()) == 2 # Admin and Parent users remain async def test_delete_users_by_parent_email_no_match(user_service): - deleted_count = user_service.delete_users_by_parent_email("nonexistent@example.com") + deleted_count = await user_service.delete_users_by_parent_email("nonexistent@example.com") assert deleted_count == 0 async def test_authenticate_user_success(populated_user_service): - assert populated_user_service.authenticate_user("admin@example.com", "adminpass") is True + assert await populated_user_service.authenticate_user("admin@example.com", "adminpass") is True async def test_authenticate_user_fail_wrong_password(populated_user_service): - assert populated_user_service.authenticate_user("admin@example.com", "wrongpass") is False + assert await populated_user_service.authenticate_user("admin@example.com", "wrongpass") is False async def test_authenticate_user_fail_nonexistent_user(user_service): - assert user_service.authenticate_user("nonexistent@example.com", "anypass") is False + assert await user_service.authenticate_user("nonexistent@example.com", "anypass") is False async def test_generate_reset_token_success(populated_user_service): - token = populated_user_service.generate_reset_token("admin@example.com") + token = await populated_user_service.generate_reset_token("admin@example.com") assert token is not None - user = populated_user_service.get_user_by_email("admin@example.com") + user = await populated_user_service.get_user_by_email("admin@example.com") assert user["reset_token"] == token assert user["reset_token_expiry"] is not None # Check expiry is in the future @@ -119,71 +119,71 @@ async def test_generate_reset_token_success(populated_user_service): assert expiry_dt > datetime.datetime.now(datetime.timezone.utc) async def test_generate_reset_token_nonexistent_user(user_service): - token = user_service.generate_reset_token("nonexistent@example.com") + token = await user_service.generate_reset_token("nonexistent@example.com") assert token is None async def test_get_user_by_reset_token_valid(populated_user_service): - token = populated_user_service.generate_reset_token("admin@example.com") - user = populated_user_service.get_user_by_reset_token(token) + token = await populated_user_service.generate_reset_token("admin@example.com") + user = await populated_user_service.get_user_by_reset_token(token) assert user is not None assert user["email"] == "admin@example.com" async def test_get_user_by_reset_token_invalid(populated_user_service): - user = populated_user_service.get_user_by_reset_token("invalidtoken") + user = await populated_user_service.get_user_by_reset_token("invalidtoken") assert user is None async def test_get_user_by_reset_token_expired(populated_user_service): - token = populated_user_service.generate_reset_token("admin@example.com") - user = populated_user_service.get_user_by_email("admin@example.com") + token = await populated_user_service.generate_reset_token("admin@example.com") + user = await populated_user_service.get_user_by_email("admin@example.com") # Manually expire the token user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat() - populated_user_service._save_users() # Save the expired state + await populated_user_service._save_users() # Save the expired state - user_after_expiry = populated_user_service.get_user_by_reset_token(token) + user_after_expiry = await populated_user_service.get_user_by_reset_token(token) assert user_after_expiry is None async def test_validate_reset_token_success(populated_user_service): - token = populated_user_service.generate_reset_token("admin@example.com") - assert populated_user_service.validate_reset_token("admin@example.com", token) is True + token = await populated_user_service.generate_reset_token("admin@example.com") + assert await populated_user_service.validate_reset_token("admin@example.com", token) is True async def test_validate_reset_token_fail_wrong_token(populated_user_service): - populated_user_service.generate_reset_token("admin@example.com") - assert populated_user_service.validate_reset_token("admin@example.com", "wrongtoken") is False + await populated_user_service.generate_reset_token("admin@example.com") + assert await populated_user_service.validate_reset_token("admin@example.com", "wrongtoken") is False async def test_validate_reset_token_fail_nonexistent_user(user_service): - assert user_service.validate_reset_token("nonexistent@example.com", "anytoken") is False + assert await user_service.validate_reset_token("nonexistent@example.com", "anytoken") is False async def test_validate_reset_token_fail_expired_token(populated_user_service): - token = populated_user_service.generate_reset_token("admin@example.com") - user = populated_user_service.get_user_by_email("admin@example.com") + token = await populated_user_service.generate_reset_token("admin@example.com") + user = await populated_user_service.get_user_by_email("admin@example.com") user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat() - populated_user_service._save_users() + await populated_user_service._save_users() - assert populated_user_service.validate_reset_token("admin@example.com", token) is False + assert await populated_user_service.validate_reset_token("admin@example.com", token) is False async def test_reset_password_success(populated_user_service): - token = populated_user_service.generate_reset_token("admin@example.com") - assert populated_user_service.reset_password("admin@example.com", token, "newadminpass") is True - assert populated_user_service.authenticate_user("admin@example.com", "newadminpass") - user = populated_user_service.get_user_by_email("admin@example.com") + token = await populated_user_service.generate_reset_token("admin@example.com") + assert await populated_user_service.reset_password("admin@example.com", token, "newadminpass") is True + assert await populated_user_service.authenticate_user("admin@example.com", "newadminpass") + user = await populated_user_service.get_user_by_email("admin@example.com") assert user["reset_token"] is None assert user["reset_token_expiry"] is None async def test_reset_password_fail_invalid_token(populated_user_service): - populated_user_service.generate_reset_token("admin@example.com") - assert populated_user_service.reset_password("admin@example.com", "invalidtoken", "newadminpass") is False - assert populated_user_service.authenticate_user("admin@example.com", "adminpass") # Password should not change + await populated_user_service.generate_reset_token("admin@example.com") + assert await populated_user_service.reset_password("admin@example.com", "invalidtoken", "newadminpass") is False + assert await populated_user_service.authenticate_user("admin@example.com", "adminpass") # Password should not change async def test_reset_password_fail_nonexistent_user(user_service): # Even if a token was somehow generated for a nonexistent user (which shouldn't happen), # reset_password should fail. - assert user_service.reset_password("nonexistent@example.com", "anytoken", "newpass") is False + assert await user_service.reset_password("nonexistent@example.com", "anytoken", "newpass") is False async def test_update_user_quota_success(populated_user_service): - populated_user_service.update_user_quota("admin@example.com", 20.5) - user = populated_user_service.get_user_by_email("admin@example.com") + await populated_user_service.update_user_quota("admin@example.com", 20.5) + user = await populated_user_service.get_user_by_email("admin@example.com") assert user["storage_quota_gb"] == 20.5 async def test_update_user_quota_nonexistent_user(user_service): with pytest.raises(ValueError, match="User not found"): - user_service.update_user_quota("nonexistent@example.com", 100) + await user_service.update_user_quota("nonexistent@example.com", 100)