This commit is contained in:
retoor 2025-11-09 08:14:14 +01:00
parent f54941dd80
commit 7abced9315
24 changed files with 1439 additions and 657 deletions

View File

@ -5,6 +5,7 @@ bcrypt
python-dotenv python-dotenv
aiosmtplib aiosmtplib
aiojobs aiojobs
aiofiles
pytest pytest
pytest-aiohttp pytest-aiohttp
aiohttp-test-utils aiohttp-test-utils

View File

@ -17,14 +17,12 @@ def login_required(func):
raise web.HTTPFound('/login') raise web.HTTPFound('/login')
user_service: UserService = request.app["user_service"] user_service: UserService = request.app["user_service"]
user = user_service.get_user_by_email(user_email) user = await user_service.get_user_by_email(user_email)
if not user: if not user:
# User not found in service, clear session and redirect to login
session.pop('user_email', None) session.pop('user_email', None)
raise web.HTTPFound('/login') raise web.HTTPFound('/login')
# Ensure the user object is available in the request for views
request["user"] = user request["user"] = user
return await func(self, *args, **kwargs) return await func(self, *args, **kwargs)
return wrapper return wrapper

View File

@ -18,9 +18,9 @@ from .helpers.env_manager import ensure_env_file_exists, get_or_create_session_s
async def setup_services(app: web.Application): async def setup_services(app: web.Application):
base_path = Path(__file__).parent base_path = Path(__file__).parent
data_path = base_path.parent / "data" data_path = base_path.parent / "data"
app["user_service"] = UserService(data_path / "users.json") app["user_service"] = UserService(use_isolated_storage=True)
app["config_service"] = ConfigService(data_path / "config.json") app["config_service"] = ConfigService(data_path / "config.json")
app["file_service"] = FileService(data_path / "user_files", data_path / "users.json") # Instantiate FileService app["file_service"] = FileService(data_path / "user_files", data_path / "users.json")
# Setup aiojobs scheduler # Setup aiojobs scheduler
app["scheduler"] = aiojobs.Scheduler() app["scheduler"] = aiojobs.Scheduler()

View File

@ -12,7 +12,7 @@ async def user_middleware(request, handler):
request["user"] = None request["user"] = None
if "user_email" in session: if "user_email" in session:
user_service = request.app["user_service"] user_service = request.app["user_service"]
request["user"] = user_service.get_user_by_email(session["user_email"]) request["user"] = await user_service.get_user_by_email(session["user_email"])
return await handler(request) return await handler(request)

View File

@ -1,6 +1,7 @@
from .views.auth import LoginView, RegistrationView, LogoutView, ForgotPasswordView, ResetPasswordView from .views.auth import LoginView, RegistrationView, LogoutView, ForgotPasswordView, ResetPasswordView
from .views.site import SiteView, OrderView, FileBrowserView, UserManagementView from .views.site import SiteView, OrderView, FileBrowserView, UserManagementView
from .views.upload import UploadView from .views.upload import UploadView
from .views.migrate import MigrateView
from .views.admin import get_users, add_user, update_user_quota, delete_user, get_user_details, delete_team from .views.admin import get_users, add_user, update_user_quota, delete_user, get_user_details, delete_team

View File

@ -5,6 +5,9 @@ import shutil
import uuid import uuid
import datetime import datetime
import logging import logging
import hashlib
import os
from .storage_service import StorageService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -14,121 +17,152 @@ handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
class FileService: class FileService:
def __init__(self, base_dir: Path, users_data_path: Path): def __init__(self, base_dir: Path, user_service):
self.base_dir = base_dir self.base_dir = base_dir
self.users_data_path = users_data_path self.user_service = user_service
self.storage = StorageService()
self.base_dir.mkdir(parents=True, exist_ok=True) self.base_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"FileService initialized with base_dir: {self.base_dir} and users_data_path: {self.users_data_path}") self.drives_dir = self.base_dir / "drives"
self.drives_dir.mkdir(exist_ok=True)
self.drives = ["drive1", "drive2", "drive3"]
for drive in self.drives:
(self.drives_dir / drive).mkdir(exist_ok=True)
self.drive_counter = 0
logger.info(f"FileService initialized with base_dir: {self.base_dir} and user_service")
async def _load_users_data(self): def _choose_drive(self):
"""Loads user data from the JSON file.""" drive = self.drives[self.drive_counter % len(self.drives)]
if not self.users_data_path.exists(): self.drive_counter += 1
logger.warning(f"users_data_path does not exist: {self.users_data_path}") return drive
return []
async with aiofiles.open(self.users_data_path, mode="r") as f:
content = await f.read()
try:
return json.loads(content) if content else []
except json.JSONDecodeError:
logger.error(f"JSONDecodeError when loading users data from {self.users_data_path}")
return []
async def _save_users_data(self, data): async def _load_metadata(self, user_email: str):
"""Saves user data to the JSON file.""" metadata = await self.storage.load(user_email, "file_metadata")
async with aiofiles.open(self.users_data_path, mode="w") as f: return metadata if metadata else {}
await f.write(json.dumps(data, indent=4))
logger.debug(f"Saved users data to {self.users_data_path}")
def _get_user_file_path(self, user_email: str, relative_path: str = "") -> Path: async def _save_metadata(self, user_email: str, metadata: dict):
"""Constructs the absolute path for a user's file or directory.""" await self.storage.save(user_email, "file_metadata", metadata)
user_dir = self.base_dir / user_email
full_path = user_dir / relative_path
logger.debug(f"Constructed path for user '{user_email}', relative_path '{relative_path}': {full_path}")
def get_user_file_system_path(self, user_email: str, item_path: str) -> Path:
"""
Constructs the absolute file system path for a user's item.
This is for internal use to resolve virtual paths to physical paths.
"""
user_base_path = self.storage._get_user_base_path(user_email)
# Ensure item_path is treated as relative to user_base_path
# and prevent directory traversal attacks.
full_path = user_base_path / item_path
if not self.storage._validate_path(full_path, user_base_path):
raise ValueError("Invalid item path: directory traversal detected")
return full_path return full_path
async def list_files(self, user_email: str, path: str = "") -> list: async def list_files(self, user_email: str, path: str = "") -> list:
"""Lists files and directories for a given user within a specified path.""" """Lists files and directories for a given user within a specified path."""
user_path = self._get_user_file_path(user_email, path) metadata = await self._load_metadata(user_email)
if not user_path.is_dir(): # Normalize path
logger.warning(f"list_files: User path is not a directory or does not exist: {user_path}") if path and not path.endswith('/'):
return [] path += '/'
items = []
files_list = [] seen = set()
for item in user_path.iterdir(): for item_path, item_meta in metadata.items():
if item.name.startswith('.'): # Ignore hidden files/directories if item_path.startswith(path):
continue remaining = item_path[len(path):].rstrip('/')
file_info = { if '/' not in remaining and remaining not in seen:
"name": item.name, seen.add(remaining)
"is_dir": item.is_dir(), if remaining: # not the path itself
"path": str(item.relative_to(self._get_user_file_path(user_email))), items.append({
"size": item.stat().st_size if item.is_file() else 0, "name": remaining,
"last_modified": datetime.datetime.fromtimestamp(item.stat().st_mtime).isoformat(), "is_dir": item_meta.get("type") == "dir",
} "path": item_path,
files_list.append(file_info) "size": item_meta.get("size", 0),
logger.debug(f"Listed {len(files_list)} items for user '{user_email}' in path '{path}'") "last_modified": item_meta.get("modified_at", ""),
return sorted(files_list, key=lambda x: (not x["is_dir"], x["name"].lower())) })
logger.debug(f"Listed {len(items)} items for user '{user_email}' in path '{path}'")
return sorted(items, key=lambda x: (not x["is_dir"], x["name"].lower()))
async def create_folder(self, user_email: str, folder_path: str) -> bool: async def create_folder(self, user_email: str, folder_path: str) -> bool:
"""Creates a new folder for the user.""" """Creates a new folder for the user."""
full_path = self._get_user_file_path(user_email, folder_path) metadata = await self._load_metadata(user_email)
if full_path.exists(): if folder_path in metadata:
logger.warning(f"create_folder: Folder already exists: {full_path}") logger.warning(f"create_folder: Folder already exists: {folder_path}")
return False # Folder already exists return False
full_path.mkdir(parents=True, exist_ok=True) metadata[folder_path] = {
logger.info(f"create_folder: Folder created: {full_path}") "type": "dir",
"created_at": datetime.datetime.now().isoformat(),
"modified_at": datetime.datetime.now().isoformat(),
}
await self._save_metadata(user_email, metadata)
logger.info(f"create_folder: Folder created: {folder_path}")
return True return True
async def upload_file(self, user_email: str, file_path: str, content: bytes) -> bool: async def upload_file(self, user_email: str, file_path: str, content: bytes) -> bool:
"""Uploads a file for the user.""" """Uploads a file for the user."""
full_path = self._get_user_file_path(user_email, file_path) hash = hashlib.sha256(content).hexdigest()
full_path.parent.mkdir(parents=True, exist_ok=True) # Ensure parent directories exist drive = self._choose_drive()
async with aiofiles.open(full_path, mode="wb") as f: dir1 = hash[:3]
dir2 = hash[3:6]
dir3 = hash[6:9]
blob_path = self.drives_dir / drive / dir1 / dir2 / dir3 / hash
blob_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(blob_path, 'wb') as f:
await f.write(content) await f.write(content)
logger.info(f"upload_file: File uploaded to: {full_path}") metadata = await self._load_metadata(user_email)
metadata[file_path] = {
"type": "file",
"size": len(content),
"hash": hash,
"blob_location": {"drive": drive, "path": f"{dir1}/{dir2}/{dir3}/{hash}"},
"created_at": datetime.datetime.now().isoformat(),
"modified_at": datetime.datetime.now().isoformat(),
}
await self._save_metadata(user_email, metadata)
logger.info(f"upload_file: File uploaded to drive {drive}: {file_path}")
return True return True
async def download_file(self, user_email: str, file_path: str) -> tuple[bytes, str] | None: async def download_file(self, user_email: str, file_path: str) -> tuple[bytes, str] | None:
"""Downloads a file for the user.""" """Downloads a file for the user."""
full_path = self._get_user_file_path(user_email, file_path) metadata = await self._load_metadata(user_email)
logger.debug(f"download_file: Attempting to download file from: {full_path}") if file_path not in metadata or metadata[file_path]["type"] != "file":
if full_path.is_file(): logger.warning(f"download_file: File not found in metadata: {file_path}")
async with aiofiles.open(full_path, mode="rb") as f:
content = await f.read()
logger.info(f"download_file: Successfully read file: {full_path}")
return content, full_path.name
logger.warning(f"download_file: File not found or is not a file: {full_path}")
return None return None
item_meta = metadata[file_path]
blob_loc = item_meta["blob_location"]
blob_path = self.drives_dir / blob_loc["drive"] / blob_loc["path"]
if not blob_path.exists():
logger.warning(f"download_file: Blob not found: {blob_path}")
return None
async with aiofiles.open(blob_path, 'rb') as f:
content = await f.read()
logger.info(f"download_file: Successfully read file: {file_path}")
return content, Path(file_path).name
async def delete_item(self, user_email: str, item_path: str) -> bool: async def delete_item(self, user_email: str, item_path: str) -> bool:
"""Deletes a file or folder for the user.""" """Deletes a file or folder for the user."""
full_path = self._get_user_file_path(user_email, item_path) metadata = await self._load_metadata(user_email)
logger.debug(f"delete_item: Attempting to delete item: {full_path}") if item_path not in metadata:
if not full_path.exists(): logger.warning(f"delete_item: Item not found: {item_path}")
logger.warning(f"delete_item: Item does not exist: {full_path}")
return False return False
# If dir, remove all under it
if full_path.is_file(): to_delete = [p for p in metadata if p == item_path or p.startswith(item_path + '/')]
full_path.unlink() for p in to_delete:
logger.info(f"delete_item: File deleted: {full_path}") del metadata[p]
elif full_path.is_dir(): await self._save_metadata(user_email, metadata)
shutil.rmtree(full_path) logger.info(f"delete_item: Item deleted: {item_path}")
logger.info(f"delete_item: Directory deleted: {full_path}")
return True return True
async def generate_share_link(self, user_email: str, item_path: str) -> str | None: async def generate_share_link(self, user_email: str, item_path: str) -> str | None:
"""Generates a shareable link for a file or folder.""" """Generates a shareable link for a file or folder."""
logger.debug(f"generate_share_link: Generating link for user '{user_email}', item '{item_path}'") logger.debug(f"generate_share_link: Generating link for user '{user_email}', item '{item_path}'")
users_data = await self._load_users_data() metadata = await self._load_metadata(user_email)
user = next((u for u in users_data if u.get("email") == user_email), None) if item_path not in metadata:
logger.warning(f"generate_share_link: Item does not exist: {item_path}")
return None
user = await self.user_service.get_user_by_email(user_email)
if not user: if not user:
logger.warning(f"generate_share_link: User not found: {user_email}") logger.warning(f"generate_share_link: User not found: {user_email}")
return None return None
full_path = self._get_user_file_path(user_email, item_path)
if not full_path.exists():
logger.warning(f"generate_share_link: Item does not exist: {full_path}")
return None
share_id = str(uuid.uuid4()) share_id = str(uuid.uuid4())
if "shared_items" not in user: if "shared_items" not in user:
user["shared_items"] = {} user["shared_items"] = {}
@ -138,17 +172,17 @@ class FileService:
"created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat(), # 7-day expiry "expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat(), # 7-day expiry
} }
await self._save_users_data(users_data) await self.user_service.update_user(user_email, shared_items=user["shared_items"])
logger.info(f"generate_share_link: Share link generated with ID: {share_id} for item: {item_path}") logger.info(f"generate_share_link: Share link generated with ID: {share_id} for item: {item_path}")
return share_id return share_id
async def get_shared_item(self, share_id: str) -> dict | None: async def get_shared_item(self, share_id: str) -> dict | None:
"""Retrieves information about a shared item.""" """Retrieves information about a shared item."""
logger.debug(f"get_shared_item: Retrieving shared item with ID: {share_id}") logger.debug(f"get_shared_item: Retrieving shared item with ID: {share_id}")
users_data = await self._load_users_data() all_users = await self.user_service.get_all_users()
for user_info in users_data: for user in all_users:
if "shared_items" in user_info and share_id in user_info["shared_items"]: if "shared_items" in user and share_id in user["shared_items"]:
shared_item = user_info["shared_items"][share_id] shared_item = user["shared_items"][share_id]
expiry_time = datetime.datetime.fromisoformat(shared_item["expires_at"]) expiry_time = datetime.datetime.fromisoformat(shared_item["expires_at"])
if expiry_time > datetime.datetime.now(datetime.timezone.utc): if expiry_time > datetime.datetime.now(datetime.timezone.utc):
logger.info(f"get_shared_item: Found valid shared item for ID: {share_id}") logger.info(f"get_shared_item: Found valid shared item for ID: {share_id}")
@ -168,27 +202,15 @@ class FileService:
user_email = shared_item["user_email"] user_email = shared_item["user_email"]
item_path = shared_item["item_path"] # This is the path of the originally shared item (file or folder) item_path = shared_item["item_path"] # This is the path of the originally shared item (file or folder)
# Construct the full path to the originally shared item target_path = item_path
full_shared_item_path = self._get_user_file_path(user_email, item_path)
target_file_path = full_shared_item_path
if requested_file_path: if requested_file_path:
# If a specific file within a shared folder is requested target_path = requested_file_path
target_file_path = self._get_user_file_path(user_email, requested_file_path)
# Security check: Ensure the requested file is actually within the shared item's directory # Security check: Ensure the requested file is actually within the shared item's directory
try: if not target_path.startswith(item_path + '/'):
target_file_path.relative_to(full_shared_item_path)
except ValueError:
logger.warning(f"get_shared_file_content: Requested file path '{requested_file_path}' is not within shared item path '{item_path}' for share_id: {share_id}") logger.warning(f"get_shared_file_content: Requested file path '{requested_file_path}' is not within shared item path '{item_path}' for share_id: {share_id}")
return None return None
if target_file_path.is_file(): return await self.download_file(user_email, target_path)
async with aiofiles.open(target_file_path, mode="rb") as f:
content = await f.read()
logger.info(f"get_shared_file_content: Successfully read content for shared file: {target_file_path}")
return content, target_file_path.name
logger.warning(f"get_shared_file_content: Shared item path is not a file or does not exist: {target_file_path}")
return None
async def get_shared_folder_content(self, share_id: str) -> list | None: async def get_shared_folder_content(self, share_id: str) -> list | None:
"""Retrieves the content of a shared folder.""" """Retrieves the content of a shared folder."""
@ -199,10 +221,61 @@ class FileService:
user_email = shared_item["user_email"] user_email = shared_item["user_email"]
item_path = shared_item["item_path"] item_path = shared_item["item_path"]
full_path = self._get_user_file_path(user_email, item_path) metadata = await self._load_metadata(user_email)
if item_path not in metadata or metadata[item_path]["type"] != "dir":
if full_path.is_dir(): logger.warning(f"get_shared_folder_content: Shared item is not a directory: {item_path}")
logger.info(f"get_shared_folder_content: Listing files for shared folder: {full_path}")
return await self.list_files(user_email, item_path)
logger.warning(f"get_shared_folder_content: Shared item path is not a directory or does not exist: {full_path}")
return None return None
logger.info(f"get_shared_folder_content: Listing files for shared folder: {item_path}")
return await self.list_files(user_email, item_path)
async def migrate_old_files(self, user_email: str):
"""Migrate existing files from old file system to virtual system."""
old_user_dir = self.base_dir / user_email
if not old_user_dir.exists():
logger.info(f"No old files to migrate for {user_email}")
return
metadata = await self._load_metadata(user_email)
migrated_count = 0
for root, dirs, files in os.walk(old_user_dir):
rel_root = Path(root).relative_to(old_user_dir)
for dir_name in dirs:
dir_path = str(rel_root / dir_name) if str(rel_root) != "." else dir_name
if dir_path not in metadata:
metadata[dir_path] = {
"type": "dir",
"created_at": datetime.datetime.now().isoformat(),
"modified_at": datetime.datetime.now().isoformat(),
}
migrated_count += 1
for file_name in files:
file_path = rel_root / file_name
str_file_path = str(file_path) if str(rel_root) != "." else file_name
if str_file_path in metadata:
continue # Already migrated
full_file_path = old_user_dir / file_path
with open(full_file_path, 'rb') as f:
content = f.read()
hash = hashlib.sha256(content).hexdigest()
drive = self._choose_drive()
dir1 = hash[:3]
dir2 = hash[3:6]
dir3 = hash[6:9]
blob_path = self.drives_dir / drive / dir1 / dir2 / dir3 / hash
if not blob_path.exists():
blob_path.parent.mkdir(parents=True, exist_ok=True)
with open(blob_path, 'wb') as f:
f.write(content)
metadata[str_file_path] = {
"type": "file",
"size": len(content),
"hash": hash,
"blob_location": {"drive": drive, "path": f"{dir1}/{dir2}/{dir3}/{hash}"},
"created_at": datetime.datetime.fromtimestamp(full_file_path.stat().st_ctime).isoformat(),
"modified_at": datetime.datetime.fromtimestamp(full_file_path.stat().st_mtime).isoformat(),
}
migrated_count += 1
await self._save_metadata(user_email, metadata)
logger.info(f"Migrated {migrated_count} items for {user_email}")
# Optionally remove old dir
# shutil.rmtree(old_user_dir)

View File

@ -0,0 +1,172 @@
import os
import json
import hashlib
import aiofiles
from pathlib import Path
from typing import Any, Dict, List, Optional
class StorageService:
def __init__(self, base_path: str = "data/user"):
self.base_path = Path(base_path)
self.base_path.mkdir(parents=True, exist_ok=True)
def _hash(self, value: str) -> str:
return hashlib.sha256(value.encode()).hexdigest()
def _get_user_base_path(self, user_email: str) -> Path:
user_hash = self._hash(user_email)
return self.base_path / user_hash
def _get_distributed_path(self, base_path: Path, identifier: str) -> Path:
obj_hash = self._hash(identifier)
dir1 = obj_hash[:3]
dir2 = obj_hash[3:6]
dir3 = obj_hash[6:9]
return base_path / dir1 / dir2 / dir3 / f"{obj_hash}.json"
def _validate_path(self, path: Path, expected_base: Path) -> bool:
try:
resolved_path = path.resolve()
resolved_base = expected_base.resolve()
return str(resolved_path).startswith(str(resolved_base))
except (ValueError, OSError):
return False
async def save(self, user_email: str, identifier: str, data: Dict[str, Any]) -> bool:
user_base = self._get_user_base_path(user_email)
file_path = self._get_distributed_path(user_base, identifier)
if not self._validate_path(file_path, user_base):
raise ValueError("Invalid path: directory traversal detected")
file_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(file_path, 'w') as f:
await f.write(json.dumps(data, indent=2))
return True
async def load(self, user_email: str, identifier: str) -> Optional[Dict[str, Any]]:
user_base = self._get_user_base_path(user_email)
file_path = self._get_distributed_path(user_base, identifier)
if not self._validate_path(file_path, user_base):
raise ValueError("Invalid path: directory traversal detected")
if not file_path.exists():
return None
async with aiofiles.open(file_path, 'r') as f:
content = await f.read()
if not content:
return {}
try:
return json.loads(content)
except json.JSONDecodeError:
return {}
async def delete(self, user_email: str, identifier: str) -> bool:
user_base = self._get_user_base_path(user_email)
file_path = self._get_distributed_path(user_base, identifier)
if not self._validate_path(file_path, user_base):
raise ValueError("Invalid path: directory traversal detected")
if file_path.exists():
file_path.unlink()
return True
return False
async def exists(self, user_email: str, identifier: str) -> bool:
user_base = self._get_user_base_path(user_email)
file_path = self._get_distributed_path(user_base, identifier)
if not self._validate_path(file_path, user_base):
raise ValueError("Invalid path: directory traversal detected")
return file_path.exists()
async def list_all(self, user_email: str) -> List[Dict[str, Any]]:
user_base = self._get_user_base_path(user_email)
if not user_base.exists():
return []
results = []
for json_file in user_base.rglob("*.json"):
if self._validate_path(json_file, user_base):
async with aiofiles.open(json_file, 'r') as f:
content = await f.read()
results.append(json.loads(content))
return results
async def delete_all(self, user_email: str) -> bool:
user_base = self._get_user_base_path(user_email)
if not user_base.exists():
return False
import shutil
shutil.rmtree(user_base)
return True
def get_user_storage_path(self, user_email: str) -> str:
return str(self._get_user_base_path(user_email))
class UserStorageManager:
def __init__(self):
self.storage = StorageService()
async def save_user(self, user_email: str, user_data: Dict[str, Any]) -> bool:
return await self.storage.save(user_email, user_email, user_data)
async def get_user(self, user_email: str) -> Optional[Dict[str, Any]]:
return await self.storage.load(user_email, user_email)
async def delete_user(self, user_email: str) -> bool:
return await self.storage.delete_all(user_email)
async def user_exists(self, user_email: str) -> bool:
return await self.storage.exists(user_email, user_email)
async def list_users_by_parent(self, parent_email: str) -> List[Dict[str, Any]]:
all_users = []
base_path = Path(self.storage.base_path)
if not base_path.exists():
return []
for user_dir in base_path.iterdir():
if user_dir.is_dir():
user_files = list(user_dir.rglob("*.json"))
for user_file in user_files:
async with aiofiles.open(user_file, 'r') as f:
content = await f.read()
user_data = json.loads(content)
if user_data.get('parent_email') == parent_email:
all_users.append(user_data)
return all_users
async def get_all_users(self) -> List[Dict[str, Any]]:
all_users = []
base_path = Path(self.storage.base_path)
if not base_path.exists():
return []
for user_dir in base_path.iterdir():
if user_dir.is_dir():
user_files = list(user_dir.rglob("*.json"))
for user_file in user_files:
async with aiofiles.open(user_file, 'r') as f:
content = await f.read()
all_users.append(json.loads(content))
return all_users

View File

@ -1,13 +1,20 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
import bcrypt # Import bcrypt import bcrypt
import secrets # For generating secure tokens import secrets
import datetime # For token expiry import datetime
from .storage_service import UserStorageManager
class UserService: class UserService:
def __init__(self, users_path: Path): def __init__(self, users_path: Path = None, use_isolated_storage: bool = True):
self.use_isolated_storage = use_isolated_storage
if use_isolated_storage:
self._storage_manager = UserStorageManager()
self._users = []
else:
self._users_path = users_path self._users_path = users_path
self._users = self._load_users() self._users = self._load_users()
@ -21,40 +28,64 @@ class UserService:
return [] return []
def _save_users(self): def _save_users(self):
if self.use_isolated_storage:
return
with open(self._users_path, "w") as f: with open(self._users_path, "w") as f:
json.dump(self._users, f, indent=4) json.dump(self._users, f, indent=4)
def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]: async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
if self.use_isolated_storage:
return await self._storage_manager.get_user(email)
return next((user for user in self._users if user["email"] == email), None) return next((user for user in self._users if user["email"] == email), None)
def get_all_users(self) -> List[Dict[str, Any]]: async def get_all_users(self) -> List[Dict[str, Any]]:
if self.use_isolated_storage:
return await self._storage_manager.get_all_users()
return self._users return self._users
def get_users_by_parent_email(self, parent_email: str) -> List[Dict[str, Any]]: async def get_users_by_parent_email(self, parent_email: str) -> List[Dict[str, Any]]:
if self.use_isolated_storage:
return await self._storage_manager.list_users_by_parent(parent_email)
return [user for user in self._users if user.get("parent_email") == parent_email] return [user for user in self._users if user.get("parent_email") == parent_email]
def create_user(self, full_name: str, email: str, password: str, parent_email: Optional[str] = None) -> Dict[str, Any]: async def get_managed_users(self, user_email: str) -> List[Dict[str, Any]]:
if self.get_user_by_email(email): user = await self.get_user_by_email(user_email)
if not user:
return []
if not user.get("is_customer", True):
return await self.get_all_users()
return await self.get_users_by_parent_email(user_email)
async def create_user(self, full_name: str, email: str, password: str, parent_email: Optional[str] = None, is_customer: bool = True) -> Dict[str, Any]:
existing_user = await self.get_user_by_email(email)
if existing_user:
raise ValueError("User with this email already exists") raise ValueError("User with this email already exists")
# Hash password with bcrypt
hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
user = { user = {
"full_name": full_name, "full_name": full_name,
"email": email, "email": email,
"password": hashed_password, "password": hashed_password,
"storage_quota_gb": 5, # Default quota "storage_quota_gb": 5,
"storage_used_gb": 0, "storage_used_gb": 0,
"reset_token": None, "reset_token": None,
"reset_token_expiry": None, "reset_token_expiry": None,
"parent_email": parent_email, # New field for hierarchical user management "parent_email": parent_email,
"is_customer": is_customer,
} }
if self.use_isolated_storage:
await self._storage_manager.save_user(email, user)
else:
self._users.append(user) self._users.append(user)
self._save_users() self._save_users()
return user return user
def update_user(self, email: str, **kwargs) -> Optional[Dict[str, Any]]: async def update_user(self, email: str, **kwargs) -> Optional[Dict[str, Any]]:
user = self.get_user_by_email(email) user = await self.get_user_by_email(email)
if not user: if not user:
return None return None
@ -63,10 +94,18 @@ class UserService:
user[key] = bcrypt.hashpw(value.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') user[key] = bcrypt.hashpw(value.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
else: else:
user[key] = value user[key] = value
if self.use_isolated_storage:
await self._storage_manager.save_user(email, user)
else:
self._save_users() self._save_users()
return user return user
def delete_user(self, email: str) -> bool: async def delete_user(self, email: str) -> bool:
if self.use_isolated_storage:
return await self._storage_manager.delete_user(email)
initial_len = len(self._users) initial_len = len(self._users)
self._users = [user for user in self._users if user["email"] != email] self._users = [user for user in self._users if user["email"] != email]
if len(self._users) < initial_len: if len(self._users) < initial_len:
@ -74,23 +113,32 @@ class UserService:
return True return True
return False return False
def delete_users_by_parent_email(self, parent_email: str) -> int: async def delete_users_by_parent_email(self, parent_email: str) -> int:
users_to_delete = await self.get_users_by_parent_email(parent_email)
deleted_count = 0
if self.use_isolated_storage:
for user in users_to_delete:
if await self._storage_manager.delete_user(user["email"]):
deleted_count += 1
else:
initial_len = len(self._users) initial_len = len(self._users)
self._users = [user for user in self._users if user.get("parent_email") != parent_email] self._users = [user for user in self._users if user.get("parent_email") != parent_email]
deleted_count = initial_len - len(self._users) deleted_count = initial_len - len(self._users)
if deleted_count > 0: if deleted_count > 0:
self._save_users() self._save_users()
return deleted_count return deleted_count
def authenticate_user(self, email: str, password: str) -> bool: async def authenticate_user(self, email: str, password: str) -> bool:
user = self.get_user_by_email(email) user = await self.get_user_by_email(email)
if not user: if not user:
return False return False
# Verify password with bcrypt
return bcrypt.checkpw(password.encode('utf-8'), user["password"].encode('utf-8')) return bcrypt.checkpw(password.encode('utf-8'), user["password"].encode('utf-8'))
def get_user_by_reset_token(self, token: str) -> Optional[Dict[str, Any]]: async def get_user_by_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
for user in self._users: all_users = await self.get_all_users()
for user in all_users:
if user.get("reset_token") == token: if user.get("reset_token") == token:
expiry_str = user.get("reset_token_expiry") expiry_str = user.get("reset_token_expiry")
if expiry_str: if expiry_str:
@ -99,20 +147,25 @@ class UserService:
return user return user
return None return None
def generate_reset_token(self, email: str) -> Optional[str]: async def generate_reset_token(self, email: str) -> Optional[str]:
user = self.get_user_by_email(email) user = await self.get_user_by_email(email)
if not user: if not user:
return None return None
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) # Token valid for 1 hour expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)
user["reset_token"] = token user["reset_token"] = token
user["reset_token_expiry"] = expiry.isoformat() user["reset_token_expiry"] = expiry.isoformat()
if self.use_isolated_storage:
await self._storage_manager.save_user(email, user)
else:
self._save_users() self._save_users()
return token return token
def validate_reset_token(self, email: str, token: str) -> bool: async def validate_reset_token(self, email: str, token: str) -> bool:
user = self.get_user_by_email(email) user = await self.get_user_by_email(email)
if not user or user.get("reset_token") != token: if not user or user.get("reset_token") != token:
return False return False
@ -123,25 +176,34 @@ class UserService:
expiry = datetime.datetime.fromisoformat(expiry_str) expiry = datetime.datetime.fromisoformat(expiry_str)
return expiry > datetime.datetime.now(datetime.timezone.utc) return expiry > datetime.datetime.now(datetime.timezone.utc)
def reset_password(self, email: str, token: str, new_password: str) -> bool: async def reset_password(self, email: str, token: str, new_password: str) -> bool:
if not self.validate_reset_token(email, token): if not await self.validate_reset_token(email, token):
return False return False
user = self.get_user_by_email(email) user = await self.get_user_by_email(email)
if not user: # Should not happen if validate_reset_token passed, but for type safety if not user:
return False return False
hashed_password = bcrypt.hashpw(new_password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') hashed_password = bcrypt.hashpw(new_password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
user["password"] = hashed_password user["password"] = hashed_password
user["reset_token"] = None user["reset_token"] = None
user["reset_token_expiry"] = None user["reset_token_expiry"] = None
if self.use_isolated_storage:
await self._storage_manager.save_user(email, user)
else:
self._save_users() self._save_users()
return True return True
def update_user_quota(self, email: str, new_quota_gb: float): async def update_user_quota(self, email: str, new_quota_gb: float):
user = self.get_user_by_email(email) user = await self.get_user_by_email(email)
if not user: if not user:
raise ValueError("User not found") raise ValueError("User not found")
user["storage_quota_gb"] = new_quota_gb user["storage_quota_gb"] = new_quota_gb
if self.use_isolated_storage:
await self._storage_manager.save_user(email, user)
else:
self._save_users() self._save_users()

View File

@ -43,7 +43,7 @@
</div> </div>
<div class="form-group"> <div class="form-group">
<label for="storage_quota_gb">Storage Quota (GB)</label> <label for="storage_quota_gb">Storage Quota (GB)</label>
<input type="number" id="storage_quota_gb" name="storage_quota_gb" min="1" value="10" required class="form-input"> <custom-slider min="1" max="1000" value="{{ form_data.storage_quota_gb if form_data and form_data.storage_quota_gb else 10 }}" step="1" name="storage_quota_gb"></custom-slider>
</div> </div>
<div style="display: flex; gap: 10px; margin-top: 20px;"> <div style="display: flex; gap: 10px; margin-top: 20px;">
<button type="submit" class="btn-primary">Add User</button> <button type="submit" class="btn-primary">Add User</button>

View File

@ -49,7 +49,7 @@
<form action="/users/{{ user_data.email }}/edit" method="post" class="order-form"> <form action="/users/{{ user_data.email }}/edit" method="post" class="order-form">
<div class="form-group"> <div class="form-group">
<label for="storage_quota_gb">Storage Quota (GB)</label> <label for="storage_quota_gb">Storage Quota (GB)</label>
<input type="number" id="storage_quota_gb" name="storage_quota_gb" min="1" value="{{ user_data.storage_quota_gb }}" required class="form-input"> <custom-slider min="1" max="1000" value="{{ user_data.storage_quota_gb }}" step="1" name="storage_quota_gb"></custom-slider>
</div> </div>
<div style="display: flex; gap: 10px; margin-top: 20px;"> <div style="display: flex; gap: 10px; margin-top: 20px;">
<button type="submit" class="btn-primary">Update Quota</button> <button type="submit" class="btn-primary">Update Quota</button>

View File

@ -5,7 +5,11 @@
{% block content %} {% block content %}
<div class="error-page-container"> <div class="error-page-container">
<h1>404 - Page Not Found</h1> <h1>404 - Page Not Found</h1>
{% if message %}
<p>{{ message }}</p>
{% else %}
<p>Oops! The page you're looking for doesn't exist or has been moved.</p> <p>Oops! The page you're looking for doesn't exist or has been moved.</p>
{% endif %}
<a href="/" class="btn-primary">Go to Homepage</a> <a href="/" class="btn-primary">Go to Homepage</a>
</div> </div>
{% endblock %} {% endblock %}

View File

@ -2,11 +2,121 @@
{% block title %}Favorites - Retoor's Cloud Solutions{% endblock %} {% block title %}Favorites - Retoor's Cloud Solutions{% endblock %}
{% block dashboard_head %}
<link rel="stylesheet" href="/static/css/components/file_browser.css">
{% endblock %}
{% block page_title %}Favorites{% endblock %} {% block page_title %}Favorites{% endblock %}
{% block dashboard_content %} {% block dashboard_actions %}
<div class="content-section"> <button class="btn-outline" id="download-selected-btn" disabled>⬇️</button>
<p>Your favorite files and folders will appear here.</p> <button class="btn-outline" id="share-selected-btn" disabled>🔗</button>
<p>This feature is coming soon.</p> <button class="btn-outline" id="remove-favorite-selected-btn" disabled>Remove from Favorites</button>
</div> {% endblock %}
{% block dashboard_content %}
{% if success_message %}
<div class="alert alert-success">
{{ success_message }}
</div>
{% endif %}
{% if error_message %}
<div class="alert alert-error">
{{ error_message }}
</div>
{% endif %}
<input type="text" class="file-search-bar" placeholder="Search favorites..." id="search-bar">
<div class="file-list-table">
<table>
<thead>
<tr>
<th><input type="checkbox" id="select-all"></th>
<th>Name</th>
<th>Owner</th>
<th>Favorited on</th>
<th>Size</th>
<th>Actions</th>
</tr>
</thead>
<tbody id="file-list-body">
{% if favorite_files %}
{% for item in favorite_files %}
<tr data-path="{{ item.path }}" data-is-dir="{{ item.is_dir }}">
<td><input type="checkbox" class="file-checkbox" data-path="{{ item.path }}" data-is-dir="{{ item.is_dir }}"></td>
<td>
{% if item.is_dir %}
<img src="/static/images/icon-families.svg" alt="Folder Icon" class="file-icon">
<a href="/files?path={{ item.path }}">{{ item.name }}</a>
{% else %}
<img src="/static/images/icon-professionals.svg" alt="File Icon" class="file-icon">
{{ item.name }}
{% endif %}
</td>
<td>{{ user.email }}</td>
<td>{{ item.favorited_on[:10] }}</td>
<td>
{% if item.is_dir %}
--
{% else %}
{{ (item.size / 1024 / 1024)|round(2) }} MB
{% endif %}
</td>
<td>
<div class="action-buttons">
{% if not item.is_dir %}
<button class="btn-small download-file-btn" data-path="{{ item.path }}">⬇️</button>
{% endif %}
<button class="btn-small share-file-btn" data-path="{{ item.path }}" data-name="{{ item.name }}">🔗</button>
<button class="btn-small btn-danger remove-favorite-btn" data-path="{{ item.path }}" data-name="{{ item.name }}">Remove</button>
</div>
</td>
</tr>
{% endfor %}
{% else %}
<tr>
<td colspan="6" style="text-align: center; padding: 40px;">
<p>No favorite files.</p>
</td>
</tr>
{% endif %}
</tbody>
</table>
</div>
<div id="share-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('share-modal')">&times;</span>
<h3>Share File</h3>
<p id="share-file-name"></p>
<div id="share-link-container" style="display: none;">
<input type="text" id="share-link-input" readonly class="form-input">
<button class="btn-primary" id="copy-share-link-btn">Copy Link</button>
<div id="share-links-list" class="share-links-list"></div>
</div>
<div id="share-loading">Generating share link...</div>
<div class="modal-actions">
<button type="button" class="btn-outline" onclick="closeModal('share-modal')">Close</button>
</div>
</div>
</div>
<div id="remove-favorite-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('remove-favorite-modal')">&times;</span>
<h3>Remove from Favorites</h3>
<p id="remove-favorite-message"></p>
<form id="remove-favorite-form" method="post">
<div class="modal-actions">
<button type="submit" class="btn-danger">Remove</button>
<button type="button" class="btn-outline" onclick="closeModal('remove-favorite-modal')">Cancel</button>
</div>
</form>
</div>
</div>
<script type="module" src="/static/js/main.js"></script>
{% endblock %} {% endblock %}

View File

@ -2,11 +2,121 @@
{% block title %}Recent Files - Retoor's Cloud Solutions{% endblock %} {% block title %}Recent Files - Retoor's Cloud Solutions{% endblock %}
{% block dashboard_head %}
<link rel="stylesheet" href="/static/css/components/file_browser.css">
{% endblock %}
{% block page_title %}Recent Files{% endblock %} {% block page_title %}Recent Files{% endblock %}
{% block dashboard_content %} {% block dashboard_actions %}
<div class="content-section"> <button class="btn-outline" id="download-selected-btn" disabled>⬇️</button>
<p>Your recently accessed files will appear here.</p> <button class="btn-outline" id="share-selected-btn" disabled>🔗</button>
<p>This feature is coming soon.</p> <button class="btn-outline" id="add-favorite-selected-btn" disabled>Add to Favorites</button>
</div> {% endblock %}
{% block dashboard_content %}
{% if success_message %}
<div class="alert alert-success">
{{ success_message }}
</div>
{% endif %}
{% if error_message %}
<div class="alert alert-error">
{{ error_message }}
</div>
{% endif %}
<input type="text" class="file-search-bar" placeholder="Search recent files..." id="search-bar">
<div class="file-list-table">
<table>
<thead>
<tr>
<th><input type="checkbox" id="select-all"></th>
<th>Name</th>
<th>Owner</th>
<th>Last Accessed</th>
<th>Size</th>
<th>Actions</th>
</tr>
</thead>
<tbody id="file-list-body">
{% if recent_files %}
{% for item in recent_files %}
<tr data-path="{{ item.path }}" data-is-dir="{{ item.is_dir }}">
<td><input type="checkbox" class="file-checkbox" data-path="{{ item.path }}" data-is-dir="{{ item.is_dir }}"></td>
<td>
{% if item.is_dir %}
<img src="/static/images/icon-families.svg" alt="Folder Icon" class="file-icon">
<a href="/files?path={{ item.path }}">{{ item.name }}</a>
{% else %}
<img src="/static/images/icon-professionals.svg" alt="File Icon" class="file-icon">
{{ item.name }}
{% endif %}
</td>
<td>{{ user.email }}</td>
<td>{{ item.last_accessed[:10] }}</td>
<td>
{% if item.is_dir %}
--
{% else %}
{{ (item.size / 1024 / 1024)|round(2) }} MB
{% endif %}
</td>
<td>
<div class="action-buttons">
{% if not item.is_dir %}
<button class="btn-small download-file-btn" data-path="{{ item.path }}">⬇️</button>
{% endif %}
<button class="btn-small share-file-btn" data-path="{{ item.path }}" data-name="{{ item.name }}">🔗</button>
<button class="btn-small add-favorite-btn" data-path="{{ item.path }}" data-name="{{ item.name }}"></button>
</div>
</td>
</tr>
{% endfor %}
{% else %}
<tr>
<td colspan="6" style="text-align: center; padding: 40px;">
<p>No recent files.</p>
</td>
</tr>
{% endif %}
</tbody>
</table>
</div>
<div id="share-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('share-modal')">&times;</span>
<h3>Share File</h3>
<p id="share-file-name"></p>
<div id="share-link-container" style="display: none;">
<input type="text" id="share-link-input" readonly class="form-input">
<button class="btn-primary" id="copy-share-link-btn">Copy Link</button>
<div id="share-links-list" class="share-links-list"></div>
</div>
<div id="share-loading">Generating share link...</div>
<div class="modal-actions">
<button type="button" class="btn-outline" onclick="closeModal('share-modal')">Close</button>
</div>
</div>
</div>
<div id="add-favorite-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('add-favorite-modal')">&times;</span>
<h3>Add to Favorites</h3>
<p id="add-favorite-message"></p>
<form id="add-favorite-form" method="post">
<div class="modal-actions">
<button type="submit" class="btn-primary">Add</button>
<button type="button" class="btn-outline" onclick="closeModal('add-favorite-modal')">Cancel</button>
</div>
</form>
</div>
</div>
<script type="module" src="/static/js/main.js"></script>
{% endblock %} {% endblock %}

View File

@ -2,11 +2,155 @@
{% block title %}Shared with me - Retoor's Cloud Solutions{% endblock %} {% block title %}Shared with me - Retoor's Cloud Solutions{% endblock %}
{% block dashboard_head %}
<link rel="stylesheet" href="/static/css/components/file_browser.css">
{% endblock %}
{% block page_title %}Shared with me{% endblock %} {% block page_title %}Shared with me{% endblock %}
{% block dashboard_content %} {% block dashboard_actions %}
<div class="content-section"> <button class="btn-primary" id="new-folder-btn">+ New</button>
<p>Files and folders that have been shared with you will appear here.</p> <button class="btn-outline" id="upload-btn">Upload</button>
<p>This feature is coming soon.</p> <button class="btn-outline" id="download-selected-btn" disabled>⬇️</button>
</div> <button class="btn-outline" id="share-selected-btn" disabled>🔗</button>
<button class="btn-outline" id="delete-selected-btn" disabled>Delete</button>
{% endblock %}
{% block dashboard_content %}
{% if success_message %}
<div class="alert alert-success">
{{ success_message }}
</div>
{% endif %}
{% if error_message %}
<div class="alert alert-error">
{{ error_message }}
</div>
{% endif %}
<input type="text" class="file-search-bar" placeholder="Search shared files..." id="search-bar">
<div class="file-list-table">
<table>
<thead>
<tr>
<th><input type="checkbox" id="select-all"></th>
<th>Name</th>
<th>Shared by</th>
<th>Last Modified</th>
<th>Size</th>
<th>Actions</th>
</tr>
</thead>
<tbody id="file-list-body">
{% if shared_files %}
{% for item in shared_files %}
<tr data-path="{{ item.path }}" data-is-dir="{{ item.is_dir }}">
<td><input type="checkbox" class="file-checkbox" data-path="{{ item.path }}" data-is-dir="{{ item.is_dir }}"></td>
<td>
{% if item.is_dir %}
<img src="/static/images/icon-families.svg" alt="Folder Icon" class="file-icon">
<a href="/shared?path={{ item.path }}">{{ item.name }}</a>
{% else %}
<img src="/static/images/icon-professionals.svg" alt="File Icon" class="file-icon">
{{ item.name }}
{% endif %}
</td>
<td>{{ item.shared_by }}</td>
<td>{{ item.last_modified[:10] }}</td>
<td>
{% if item.is_dir %}
--
{% else %}
{{ (item.size / 1024 / 1024)|round(2) }} MB
{% endif %}
</td>
<td>
<div class="action-buttons">
{% if not item.is_dir %}
<button class="btn-small download-file-btn" data-path="{{ item.path }}">⬇️</button>
{% endif %}
<button class="btn-small share-file-btn" data-path="{{ item.path }}" data-name="{{ item.name }}">🔗</button>
<button class="btn-small btn-danger delete-file-btn" data-path="{{ item.path }}" data-name="{{ item.name }}">Delete</button>
</div>
</td>
</tr>
{% endfor %}
{% else %}
<tr>
<td colspan="6" style="text-align: center; padding: 40px;">
<p>No shared files found.</p>
</td>
</tr>
{% endif %}
</tbody>
</table>
</div>
<div id="new-folder-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('new-folder-modal')">&times;</span>
<h3>Create New Folder</h3>
<form action="/shared/new_folder" method="post">
<input type="text" name="folder_name" placeholder="Folder name" required class="form-input">
<div class="modal-actions">
<button type="submit" class="btn-primary">Create</button>
<button type="button" class="btn-outline" onclick="closeModal('new-folder-modal')">Cancel</button>
</div>
</form>
</div>
</div>
<div id="upload-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('upload-modal')">&times;</span>
<h3>Upload Files</h3>
<div class="upload-area">
<input type="file" name="file" multiple class="form-input" id="file-input-multiple" style="display: none;">
<label for="file-input-multiple" class="btn-outline upload-button">Select Files</label>
<div id="selected-files-preview" class="selected-files-preview"></div>
<div id="upload-progress-container" class="upload-progress-container"></div>
</div>
<div class="modal-actions">
<button type="button" class="btn-primary" id="start-upload-btn" disabled>Upload</button>
<button type="button" class="btn-outline" onclick="closeModal('upload-modal')">Cancel</button>
</div>
</div>
</div>
<div id="share-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('share-modal')">&times;</span>
<h3>Share File</h3>
<p id="share-file-name"></p>
<div id="share-link-container" style="display: none;">
<input type="text" id="share-link-input" readonly class="form-input">
<button class="btn-primary" id="copy-share-link-btn">Copy Link</button>
<div id="share-links-list" class="share-links-list"></div>
</div>
<div id="share-loading">Generating share link...</div>
<div class="modal-actions">
<button type="button" class="btn-outline" onclick="closeModal('share-modal')">Close</button>
</div>
</div>
</div>
<div id="delete-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('delete-modal')">&times;</span>
<h3>Confirm Delete</h3>
<p id="delete-message"></p>
<form id="delete-form" method="post">
<div class="modal-actions">
<button type="submit" class="btn-danger">Delete</button>
<button type="button" class="btn-outline" onclick="closeModal('delete-modal')">Cancel</button>
</div>
</form>
</div>
</div>
<script type="module" src="/static/js/components/upload.js"></script>
<script type="module" src="/static/js/main.js"></script>
{% endblock %} {% endblock %}

View File

@ -2,11 +2,114 @@
{% block title %}Trash - Retoor's Cloud Solutions{% endblock %} {% block title %}Trash - Retoor's Cloud Solutions{% endblock %}
{% block dashboard_head %}
<link rel="stylesheet" href="/static/css/components/file_browser.css">
{% endblock %}
{% block page_title %}Trash{% endblock %} {% block page_title %}Trash{% endblock %}
{% block dashboard_content %} {% block dashboard_actions %}
<div class="content-section"> <button class="btn-outline" id="restore-selected-btn" disabled>Restore</button>
<p>Files and folders you have deleted will appear here.</p> <button class="btn-outline" id="delete-selected-btn" disabled>Delete Permanently</button>
<p>This feature is coming soon.</p> {% endblock %}
</div>
{% block dashboard_content %}
{% if success_message %}
<div class="alert alert-success">
{{ success_message }}
</div>
{% endif %}
{% if error_message %}
<div class="alert alert-error">
{{ error_message }}
</div>
{% endif %}
<input type="text" class="file-search-bar" placeholder="Search trash..." id="search-bar">
<div class="file-list-table">
<table>
<thead>
<tr>
<th><input type="checkbox" id="select-all"></th>
<th>Name</th>
<th>Owner</th>
<th>Deleted on</th>
<th>Size</th>
<th>Actions</th>
</tr>
</thead>
<tbody id="file-list-body">
{% if trash_files %}
{% for item in trash_files %}
<tr data-path="{{ item.path }}" data-is-dir="{{ item.is_dir }}">
<td><input type="checkbox" class="file-checkbox" data-path="{{ item.path }}" data-is-dir="{{ item.is_dir }}"></td>
<td>
{% if item.is_dir %}
<img src="/static/images/icon-families.svg" alt="Folder Icon" class="file-icon">
{{ item.name }}
{% else %}
<img src="/static/images/icon-professionals.svg" alt="File Icon" class="file-icon">
{{ item.name }}
{% endif %}
</td>
<td>{{ user.email }}</td>
<td>{{ item.deleted_on[:10] }}</td>
<td>
{% if item.is_dir %}
--
{% else %}
{{ (item.size / 1024 / 1024)|round(2) }} MB
{% endif %}
</td>
<td>
<div class="action-buttons">
<button class="btn-small restore-file-btn" data-path="{{ item.path }}">Restore</button>
<button class="btn-small btn-danger delete-file-btn" data-path="{{ item.path }}" data-name="{{ item.name }}">Delete Permanently</button>
</div>
</td>
</tr>
{% endfor %}
{% else %}
<tr>
<td colspan="6" style="text-align: center; padding: 40px;">
<p>No files in trash.</p>
</td>
</tr>
{% endif %}
</tbody>
</table>
</div>
<div id="restore-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('restore-modal')">&times;</span>
<h3>Confirm Restore</h3>
<p id="restore-message"></p>
<form id="restore-form" method="post">
<div class="modal-actions">
<button type="submit" class="btn-primary">Restore</button>
<button type="button" class="btn-outline" onclick="closeModal('restore-modal')">Cancel</button>
</div>
</form>
</div>
</div>
<div id="delete-modal" class="modal">
<div class="modal-content">
<span class="close" onclick="closeModal('delete-modal')">&times;</span>
<h3>Confirm Permanent Delete</h3>
<p id="delete-message"></p>
<form id="delete-form" method="post">
<div class="modal-actions">
<button type="submit" class="btn-danger">Delete Permanently</button>
<button type="button" class="btn-outline" onclick="closeModal('delete-modal')">Cancel</button>
</div>
</form>
</div>
</div>
<script type="module" src="/static/js/main.js"></script>
{% endblock %} {% endblock %}

View File

@ -3,6 +3,25 @@ from aiohttp_session import get_session
from ..services.user_service import UserService from ..services.user_service import UserService
from ..models import QuotaUpdateModel, RegistrationModel from ..models import QuotaUpdateModel, RegistrationModel
async def verify_user_access(user_service: UserService, current_user_email: str, target_user_email: str) -> tuple[bool, dict]:
if target_user_email == current_user_email:
return True, None
current_user = await user_service.get_user_by_email(current_user_email)
if not current_user:
return False, {"error": "Current user not found", "status": 401}
target_user = await user_service.get_user_by_email(target_user_email)
if not target_user:
return False, {"error": "Target user not found", "status": 404}
if current_user.get("is_customer", True):
if target_user.get("parent_email") != current_user_email:
return False, {"error": "Forbidden: You can only manage users you created", "status": 403}
return True, None
async def get_users(request: web.Request) -> web.Response: async def get_users(request: web.Request) -> web.Response:
user_service: UserService = request.app["user_service"] user_service: UserService = request.app["user_service"]
session = await get_session(request) session = await get_session(request)
@ -11,18 +30,8 @@ async def get_users(request: web.Request) -> web.Response:
if not current_user_email: if not current_user_email:
return web.json_response({"error": "Unauthorized"}, status=401) return web.json_response({"error": "Unauthorized"}, status=401)
# For now, let's assume only the main user can see all users. users = await user_service.get_managed_users(current_user_email)
# In a real application, you'd have roles/permissions.
# The main user is the one who created the account.
# If the current user is the main user, they can see all users.
# Otherwise, they can only see users they created (their "team").
# This logic needs to be refined based on how "main user" is identified.
# For now, let's return all users for simplicity, assuming the logged-in user has admin-like access to this page.
# A more robust solution would involve checking if the current_user_email is the 'owner' of the site.
users = user_service.get_all_users()
# Filter out sensitive information like password and reset tokens
safe_users = [] safe_users = []
for user in users: for user in users:
safe_user = {k: v for k, v in user.items() if k not in ["password", "reset_token", "reset_token_expiry"]} safe_user = {k: v for k, v in user.items() if k not in ["password", "reset_token", "reset_token_expiry"]}
@ -42,12 +51,11 @@ async def add_user(request: web.Request) -> web.Response:
data = await request.json() data = await request.json()
registration_data = RegistrationModel(**data) registration_data = RegistrationModel(**data)
# The current user is the parent of the new user new_user = await user_service.create_user(
new_user = user_service.create_user(
full_name=registration_data.full_name, full_name=registration_data.full_name,
email=registration_data.email, email=registration_data.email,
password=registration_data.password, password=registration_data.password,
parent_email=current_user_email # Assign current user as parent parent_email=current_user_email
) )
safe_new_user = {k: v for k, v in new_user.items() if k not in ["password", "reset_token", "reset_token_expiry"]} safe_new_user = {k: v for k, v in new_user.items() if k not in ["password", "reset_token", "reset_token_expiry"]}
return web.json_response({"message": "User added successfully", "user": safe_new_user}, status=201) return web.json_response({"message": "User added successfully", "user": safe_new_user}, status=201)
@ -68,16 +76,14 @@ async def update_user_quota(request: web.Request) -> web.Response:
if not target_user_email: if not target_user_email:
return web.json_response({"error": "User email not provided"}, status=400) return web.json_response({"error": "User email not provided"}, status=400)
# Ensure the current user has permission to update this user's quota has_access, error_response = await verify_user_access(user_service, current_user_email, target_user_email)
# For now, allow if current_user_email is the target_user_email or if target_user is a child of current_user if not has_access:
target_user = user_service.get_user_by_email(target_user_email) return web.json_response({"error": error_response["error"]}, status=error_response["status"])
if not target_user or (target_user_email != current_user_email and target_user.get("parent_email") != current_user_email):
return web.json_response({"error": "Forbidden: You do not have permission to update this user's quota"}, status=403)
try: try:
data = await request.json() data = await request.json()
quota_update_data = QuotaUpdateModel(**data) quota_update_data = QuotaUpdateModel(**data)
user_service.update_user_quota(target_user_email, quota_update_data.new_quota_gb) await user_service.update_user_quota(target_user_email, quota_update_data.new_quota_gb)
return web.json_response({"message": f"Quota for {target_user_email} updated successfully"}) return web.json_response({"message": f"Quota for {target_user_email} updated successfully"})
except ValueError as e: except ValueError as e:
return web.json_response({"error": str(e)}, status=400) return web.json_response({"error": str(e)}, status=400)
@ -96,15 +102,14 @@ async def delete_user(request: web.Request) -> web.Response:
if not target_user_email: if not target_user_email:
return web.json_response({"error": "User email not provided"}, status=400) return web.json_response({"error": "User email not provided"}, status=400)
# Prevent a user from deleting themselves or a parent user
if target_user_email == current_user_email: if target_user_email == current_user_email:
return web.json_response({"error": "Forbidden: You cannot delete your own account from this interface"}, status=403) return web.json_response({"error": "Forbidden: You cannot delete your own account from this interface"}, status=403)
target_user = user_service.get_user_by_email(target_user_email) has_access, error_response = await verify_user_access(user_service, current_user_email, target_user_email)
if not target_user or target_user.get("parent_email") != current_user_email: if not has_access:
return web.json_response({"error": "Forbidden: You do not have permission to delete this user"}, status=403) return web.json_response({"error": error_response["error"]}, status=error_response["status"])
if user_service.delete_user(target_user_email): if await user_service.delete_user(target_user_email):
return web.json_response({"message": f"User {target_user_email} deleted successfully"}) return web.json_response({"message": f"User {target_user_email} deleted successfully"})
else: else:
return web.json_response({"error": "User not found or could not be deleted"}, status=404) return web.json_response({"error": "User not found or could not be deleted"}, status=404)
@ -121,10 +126,11 @@ async def get_user_details(request: web.Request) -> web.Response:
if not target_user_email: if not target_user_email:
return web.json_response({"error": "User email not provided"}, status=400) return web.json_response({"error": "User email not provided"}, status=400)
target_user = user_service.get_user_by_email(target_user_email) has_access, error_response = await verify_user_access(user_service, current_user_email, target_user_email)
if not target_user or (target_user_email != current_user_email and target_user.get("parent_email") != current_user_email): if not has_access:
return web.json_response({"error": "Forbidden: You do not have permission to view this user's details"}, status=403) return web.json_response({"error": error_response["error"]}, status=error_response["status"])
target_user = await user_service.get_user_by_email(target_user_email)
safe_user = {k: v for k, v in target_user.items() if k not in ["password", "reset_token", "reset_token_expiry"]} safe_user = {k: v for k, v in target_user.items() if k not in ["password", "reset_token", "reset_token_expiry"]}
return web.json_response({"user": safe_user}) return web.json_response({"user": safe_user})
@ -140,11 +146,10 @@ async def delete_team(request: web.Request) -> web.Response:
if not target_parent_email: if not target_parent_email:
return web.json_response({"error": "Parent email not provided"}, status=400) return web.json_response({"error": "Parent email not provided"}, status=400)
# Only the parent user can delete their "team" (users they created)
if current_user_email != target_parent_email: if current_user_email != target_parent_email:
return web.json_response({"error": "Forbidden: You do not have permission to delete this team"}, status=403) return web.json_response({"error": "Forbidden: You do not have permission to delete this team"}, status=403)
deleted_count = user_service.delete_users_by_parent_email(target_parent_email) deleted_count = await user_service.delete_users_by_parent_email(target_parent_email)
if deleted_count > 0: if deleted_count > 0:
return web.json_response({"message": f"Successfully deleted {deleted_count} users from the team managed by {target_parent_email}"}) return web.json_response({"message": f"Successfully deleted {deleted_count} users from the team managed by {target_parent_email}"})
else: else:

View File

@ -55,7 +55,7 @@ class LoginView(CustomPydanticView):
) )
user_service: UserService = self.request.app["user_service"] user_service: UserService = self.request.app["user_service"]
if user_service.authenticate_user(login_data.email, login_data.password): if await user_service.authenticate_user(login_data.email, login_data.password):
session = await new_session(self.request) session = await new_session(self.request)
session["user_email"] = login_data.email session["user_email"] = login_data.email
raise web.HTTPFound("/dashboard") raise web.HTTPFound("/dashboard")
@ -93,7 +93,7 @@ class RegistrationView(CustomPydanticView):
user_service: UserService = self.request.app["user_service"] user_service: UserService = self.request.app["user_service"]
try: try:
user_service.create_user(user_data.full_name, user_data.email, user_data.password) # Changed username to full_name await user_service.create_user(user_data.full_name, user_data.email, user_data.password)
# Render email content # Render email content
email_context = { email_context = {
@ -148,10 +148,10 @@ class ForgotPasswordView(CustomPydanticView):
) )
user_service: UserService = self.request.app["user_service"] user_service: UserService = self.request.app["user_service"]
user = user_service.get_user_by_email(forgot_password_data.email) user = await user_service.get_user_by_email(forgot_password_data.email)
if user: if user:
token = user_service.generate_reset_token(forgot_password_data.email) token = await user_service.generate_reset_token(forgot_password_data.email)
if token: if token:
reset_link = self.request.url.join( reset_link = self.request.url.join(
self.request.app.router["reset_password"].url_for(token=token) self.request.app.router["reset_password"].url_for(token=token)
@ -217,7 +217,7 @@ class ResetPasswordView(CustomPydanticView):
) )
user_service: UserService = self.request.app["user_service"] user_service: UserService = self.request.app["user_service"]
user = user_service.get_user_by_reset_token(token) # Corrected method call user = await user_service.get_user_by_reset_token(token)
if not user: if not user:
return aiohttp_jinja2.render_template( return aiohttp_jinja2.render_template(
self.template_name, self.template_name,
@ -225,7 +225,7 @@ class ResetPasswordView(CustomPydanticView):
{"error": "Invalid or expired password reset link.", "request": self.request, "errors": {}, "token": token}, {"error": "Invalid or expired password reset link.", "request": self.request, "errors": {}, "token": token},
) )
if user_service.reset_password(user["email"], token, reset_password_data.password): if await user_service.reset_password(user["email"], token, reset_password_data.password):
# Send password changed confirmation email # Send password changed confirmation email
email_context = { email_context = {
"user_name": user["full_name"], "user_name": user["full_name"],

11
retoors/views/migrate.py Normal file
View File

@ -0,0 +1,11 @@
from aiohttp import web
from ..helpers.auth import login_required
class MigrateView(web.View):
@login_required
async def post(self):
user_email = self.request["user"]["email"]
file_service = self.request.app["file_service"]
await file_service.migrate_old_files(user_email)
return web.json_response({"status": "success", "message": "Migration completed"})

View File

@ -298,7 +298,7 @@ class FileBrowserView(web.View):
return json_response({"error": "Failed to generate share links for any selected items"}, status=500) return json_response({"error": "Failed to generate share links for any selected items"}, status=500)
logger.warning(f"FileBrowserView: Unknown file action for POST request: {route_name}") logger.warning(f"FileBrowserView: Unknown file action for POST request: {route_name}")
return web.Response(status=400, text="Unknown file action") raise web.HTTPBadRequest(text="Unknown file action")
@login_required @login_required
async def get_download_file(self): async def get_download_file(self):
@ -325,7 +325,7 @@ class FileBrowserView(web.View):
raise web.HTTPNotFound(text="File not found") raise web.HTTPNotFound(text="File not found")
async def shared_file_handler(self): async def shared_file_handler(self):
share_id = self.request.match_info.get("share_id") share_id = self.match_info.get("share_id")
file_service = self.request.app["file_service"] file_service = self.request.app["file_service"]
logger.debug(f"FileBrowserView: Handling shared file request for share_id: {share_id}") logger.debug(f"FileBrowserView: Handling shared file request for share_id: {share_id}")
@ -333,16 +333,11 @@ class FileBrowserView(web.View):
if not shared_item: if not shared_item:
logger.warning(f"FileBrowserView: Shared item not found or expired for share_id: {share_id}") logger.warning(f"FileBrowserView: Shared item not found or expired for share_id: {share_id}")
return aiohttp_jinja2.render_template( raise web.HTTPNotFound(text="Shared file not found or inaccessible")
"pages/errors/404.html",
self.request,
{"request": self.request, "message": "Shared link is invalid or has expired."},
status=404
)
user_email = shared_item["user_email"] user_email = shared_item["user_email"]
item_path = shared_item["item_path"] item_path = shared_item["item_path"]
full_path = file_service._get_user_file_path(user_email, item_path) full_path = file_service.get_user_file_system_path(user_email, item_path)
if full_path.is_file(): if full_path.is_file():
result = await file_service.get_shared_file_content(share_id) result = await file_service.get_shared_file_content(share_id)
@ -355,7 +350,7 @@ class FileBrowserView(web.View):
return response return response
else: else:
logger.error(f"FileBrowserView: Failed to get content for shared file: {item_path} (share_id: {share_id})") logger.error(f"FileBrowserView: Failed to get content for shared file: {item_path} (share_id: {share_id})")
raise web.HTTPNotFound(text="Shared file not found or inaccessible") raise web.HTTPNotFound(text="Shared file not found or inaccessible within the shared folder.")
elif full_path.is_dir(): elif full_path.is_dir():
files = await file_service.get_shared_folder_content(share_id) files = await file_service.get_shared_folder_content(share_id)
logger.info(f"FileBrowserView: Serving shared folder '{item_path}' for share_id: {share_id}") logger.info(f"FileBrowserView: Serving shared folder '{item_path}' for share_id: {share_id}")
@ -388,16 +383,11 @@ class FileBrowserView(web.View):
shared_item = await file_service.get_shared_item(share_id) shared_item = await file_service.get_shared_item(share_id)
if not shared_item: if not shared_item:
logger.warning(f"FileBrowserView: Shared item not found or expired for share_id: {share_id} during download.") logger.warning(f"FileBrowserView: Shared item not found or expired for share_id: {share_id} during download.")
return aiohttp_jinja2.render_template( raise web.HTTPNotFound(text="Shared file not found or inaccessible within the shared folder.")
"pages/errors/404.html",
self.request,
{"request": self.request, "message": "Shared link is invalid or has expired."},
status=404
)
# Ensure the shared item is a directory if a file_path is provided # Ensure the shared item is a directory if a file_path is provided
user_email = shared_item["user_email"] user_email = shared_item["user_email"]
original_shared_item_path = file_service._get_user_file_path(user_email, shared_item["item_path"]) original_shared_item_path = file_service.get_user_file_system_path(user_email, shared_item["item_path"])
if not original_shared_item_path.is_dir(): if not original_shared_item_path.is_dir():
logger.warning(f"FileBrowserView: Attempt to download a specific file from a shared item that is not a directory. Share_id: {share_id}") logger.warning(f"FileBrowserView: Attempt to download a specific file from a shared item that is not a directory. Share_id: {share_id}")
@ -449,6 +439,25 @@ class OrderView(CustomPydanticView):
class UserManagementView(web.View): class UserManagementView(web.View):
async def verify_user_access(self, target_user_email: str) -> bool:
current_user_email = self.request["user"]["email"]
if target_user_email == current_user_email:
return True
user_service = self.request.app["user_service"]
current_user = await user_service.get_user_by_email(current_user_email)
target_user = await user_service.get_user_by_email(target_user_email)
if not target_user:
return False
if current_user.get("is_customer", True):
if target_user.get("parent_email") != current_user_email:
return False
return True
@login_required @login_required
async def get(self): async def get(self):
route_name = self.request.match_info.route.name route_name = self.request.match_info.route.name
@ -531,14 +540,14 @@ class UserManagementView(web.View):
parent_email = self.request["user"]["email"] parent_email = self.request["user"]["email"]
try: try:
new_user = user_service.create_user( new_user = await user_service.create_user(
full_name=full_name, full_name=full_name,
email=email, email=email,
password=password, password=password,
parent_email=parent_email parent_email=parent_email
) )
user_service.update_user_quota(email, float(storage_quota_gb)) await user_service.update_user_quota(email, float(storage_quota_gb))
raise web.HTTPFound( raise web.HTTPFound(
self.request.app.router["users"].url_for().with_query( self.request.app.router["users"].url_for().with_query(
@ -562,8 +571,11 @@ class UserManagementView(web.View):
async def edit_user_page(self): async def edit_user_page(self):
email = self.request.match_info.get("email") email = self.request.match_info.get("email")
if not await self.verify_user_access(email):
raise web.HTTPForbidden(text="You do not have permission to access this user")
user_service = self.request.app["user_service"] user_service = self.request.app["user_service"]
user_data = user_service.get_user_by_email(email) user_data = await user_service.get_user_by_email(email)
if not user_data: if not user_data:
raise web.HTTPNotFound(text="User not found") raise web.HTTPNotFound(text="User not found")
@ -585,6 +597,10 @@ class UserManagementView(web.View):
async def edit_user_submit(self): async def edit_user_submit(self):
email = self.request.match_info.get("email") email = self.request.match_info.get("email")
if not await self.verify_user_access(email):
raise web.HTTPForbidden(text="You do not have permission to access this user")
data = await self.request.post() data = await self.request.post()
storage_quota_gb = data.get("storage_quota_gb", "") storage_quota_gb = data.get("storage_quota_gb", "")
@ -598,7 +614,7 @@ class UserManagementView(web.View):
errors["storage_quota_gb"] = "Invalid storage quota value" errors["storage_quota_gb"] = "Invalid storage quota value"
user_service = self.request.app["user_service"] user_service = self.request.app["user_service"]
user_data = user_service.get_user_by_email(email) user_data = await user_service.get_user_by_email(email)
if not user_data: if not user_data:
raise web.HTTPNotFound(text="User not found") raise web.HTTPNotFound(text="User not found")
@ -616,7 +632,7 @@ class UserManagementView(web.View):
} }
) )
user_service.update_user_quota(email, storage_quota_gb) await user_service.update_user_quota(email, storage_quota_gb)
raise web.HTTPFound( raise web.HTTPFound(
self.request.app.router["edit_user"].url_for(email=email).with_query( self.request.app.router["edit_user"].url_for(email=email).with_query(
@ -627,8 +643,11 @@ class UserManagementView(web.View):
async def user_details_page(self): async def user_details_page(self):
email = self.request.match_info.get("email") email = self.request.match_info.get("email")
if not await self.verify_user_access(email):
raise web.HTTPForbidden(text="You do not have permission to access this user")
user_service = self.request.app["user_service"] user_service = self.request.app["user_service"]
user_data = user_service.get_user_by_email(email) user_data = await user_service.get_user_by_email(email)
if not user_data: if not user_data:
raise web.HTTPNotFound(text="User not found") raise web.HTTPNotFound(text="User not found")
@ -647,13 +666,16 @@ class UserManagementView(web.View):
async def delete_user_submit(self): async def delete_user_submit(self):
email = self.request.match_info.get("email") email = self.request.match_info.get("email")
if not await self.verify_user_access(email):
raise web.HTTPForbidden(text="You do not have permission to access this user")
user_service = self.request.app["user_service"] user_service = self.request.app["user_service"]
user_data = user_service.get_user_by_email(email) user_data = await user_service.get_user_by_email(email)
if not user_data: if not user_data:
raise web.HTTPNotFound(text="User not found") raise web.HTTPNotFound(text="User not found")
user_service.delete_user(email) await user_service.delete_user(email)
raise web.HTTPFound( raise web.HTTPFound(
self.request.app.router["users"].url_for().with_query( self.request.app.router["users"].url_for().with_query(

View File

@ -57,16 +57,14 @@ def temp_users_json(tmp_path):
@pytest.fixture @pytest.fixture
def file_service_instance(temp_user_files_dir, temp_users_json): def file_service_instance(temp_user_files_dir, temp_users_json):
"""Fixture to provide a FileService instance with temporary directories.""" """Fixture to provide a FileService instance with temporary directories."""
return FileService(temp_user_files_dir, temp_users_json) user_service = UserService(temp_users_json) # Create a UserService instance
return FileService(temp_user_files_dir, user_service) # Pass the UserService instance
@pytest.fixture @pytest.fixture
def create_app_instance(): def create_test_app(mocker, temp_user_files_dir, temp_users_json):
"""Fixture to create a new aiohttp application instance."""
return create_app()
@pytest.fixture
def create_test_app(mocker, temp_user_files_dir, temp_users_json, file_service_instance):
"""Fixture to create a test aiohttp application with mocked services.""" """Fixture to create a test aiohttp application with mocked services."""
from aiohttp import web from aiohttp import web
@ -81,16 +79,15 @@ def create_test_app(mocker, temp_user_files_dir, temp_users_json, file_service_i
app.middlewares.append(error_middleware) app.middlewares.append(error_middleware)
app.middlewares.append(user_middleware) app.middlewares.append(user_middleware)
# Mock UserService # Use a real UserService with a temporary users.json for the app
mock_user_service = mocker.MagicMock(spec=UserService) app["user_service"] = UserService(temp_users_json)
# Mock scheduler # Mock scheduler
mock_scheduler = mocker.MagicMock() mock_scheduler = mocker.MagicMock()
mock_scheduler.spawn = mocker.AsyncMock() mock_scheduler.spawn = mocker.AsyncMock()
mock_scheduler.close = mocker.AsyncMock() mock_scheduler.close = mocker.AsyncMock()
app["user_service"] = mock_user_service app["file_service"] = FileService(temp_user_files_dir, app["user_service"])
app["file_service"] = file_service_instance
app["scheduler"] = mock_scheduler app["scheduler"] = mock_scheduler
# Setup Jinja2 for templates # Setup Jinja2 for templates
@ -102,42 +99,14 @@ def create_test_app(mocker, temp_user_files_dir, temp_users_json, file_service_i
return app return app
@pytest.fixture(scope="function")
def mock_users_db_fixture():
"""
Fixture to simulate a user database for dynamic mocking,
reset for each test function.
"""
return {
"admin@example.com": {
"full_name": "Admin User",
"email": "admin@example.com",
"password": "password", # Store plain password for mock authentication
"hashed_password": "hashed_password", # For consistency with real service
"storage_quota_gb": 100,
"storage_used_gb": 10,
"parent_email": None,
"reset_token": None,
"reset_token_expiry": None,
},
"child1@example.com": {
"full_name": "Child User 1",
"email": "child1@example.com",
"password": "password",
"hashed_password": "hashed_password",
"storage_quota_gb": 50,
"storage_used_gb": 5,
"parent_email": "admin@example.com",
"shared_items": {}
}
}
@pytest.fixture @pytest.fixture
async def client( async def client(
aiohttp_client, mocker: MockerFixture, create_app_instance, mock_users_db_fixture aiohttp_client, mocker: MockerFixture, create_test_app
): ):
app = create_app_instance # Use the new fixture app = create_test_app # Use create_test_app for consistent test environment
# Directly set app["scheduler"] to a mock object # Directly set app["scheduler"] to a mock object
mock_scheduler_instance = mocker.MagicMock() mock_scheduler_instance = mocker.MagicMock()
@ -145,189 +114,20 @@ async def client(
mock_scheduler_instance.close = mocker.AsyncMock() # Ensure close is awaitable mock_scheduler_instance.close = mocker.AsyncMock() # Ensure close is awaitable
app["scheduler"] = mock_scheduler_instance app["scheduler"] = mock_scheduler_instance
# Create temporary data files for testing # The UserService and ConfigService are now set up in create_test_app with temporary files.
base_path = Path(__file__).parent.parent # No need to manually create users.json or config.json here.
data_path = base_path / "data"
data_path.mkdir(exist_ok=True)
users_file = data_path / "users.json"
with open(users_file, "w") as f:
json.dump([], f)
config_file = data_path / "config.json"
with open(config_file, "w") as f:
json.dump({"price_per_gb": 0.0}, f)
app["config_service"] = ConfigService(data_path / "config.json")
client = await aiohttp_client(app) client = await aiohttp_client(app)
# Access the real UserService instance and mock its methods # The UserService is now a real instance, so we don't need to mock its methods here.
mock_user_service_instance = client.app["user_service"] # The mock_users_db_fixture is also no longer needed in this context.
# Use the mock_users_db_fixture
mock_users_db = mock_users_db_fixture
def mock_authenticate_user(email, password):
user = mock_users_db.get(email)
if user and user["password"] == password:
return user
return None
def mock_get_user_by_email(email):
return mock_users_db.get(email)
def mock_create_user(full_name, email, password, parent_email=None):
if email in mock_users_db:
raise ValueError("User with this email already exists")
new_user = {
"full_name": full_name,
"email": email,
"password": password,
"hashed_password": "hashed_password",
"storage_quota_gb": 5,
"storage_used_gb": 0,
"parent_email": parent_email,
"reset_token": None,
"reset_token_expiry": None,
}
mock_users_db[email] = new_user
return new_user
def mock_reset_password(email, token, new_password): # Added token argument
user = mock_users_db.get(email)
if user and user.get("reset_token") == token and user.get("reset_token_expiry"):
expiry_time = datetime.datetime.fromisoformat(user["reset_token_expiry"])
if expiry_time > datetime.datetime.now(datetime.timezone.utc):
user["password"] = new_password
user["hashed_password"] = "new_hashed_password" # Simulate hashing
user["reset_token"] = None
user["reset_token_expiry"] = None
return True
return False
def mock_generate_reset_token(email):
user = mock_users_db.get(email)
if user:
# In a real scenario, this would generate a unique token and expiry
user["reset_token"] = "test_token"
user["reset_token_expiry"] = "2030-11-08T20:00:00Z" # A future date
return "test_token"
return None
def mock_validate_reset_token(email, token):
if (
token == "expiredtoken123"
): # Explicitly handle the expired token from the test
return False
user = mock_users_db.get(email)
if user and user.get("reset_token") == token and user.get("reset_token_expiry"):
expiry_time = datetime.datetime.fromisoformat(
user["reset_token_expiry"]
)
if expiry_time > datetime.datetime.now(datetime.timezone.utc):
return True
return False
def mock_save_users():
# This mock ensures that changes to user objects within tests are reflected in mock_users_db
# In a real scenario, this would write to a file or database.
pass # The mock_users_db is already being modified directly by other mocks
def mock_get_all_users():
return list(mock_users_db.values())
def mock_get_users_by_parent_email(parent_email):
return [
user
for user in mock_users_db.values()
if user.get("parent_email") == parent_email
]
def mock_delete_user(email):
if email in mock_users_db:
del mock_users_db[email]
return True
return False
def mock_delete_users_by_parent_email(parent_email):
initial_count = len(mock_users_db)
users_to_delete = [
email
for email, user in mock_users_db.items()
if user.get("parent_email") == parent_email
]
for email in users_to_delete:
del mock_users_db[email]
return initial_count - len(mock_users_db)
mocker.patch.object(
mock_user_service_instance,
"authenticate_user",
side_effect=mock_authenticate_user,
)
mocker.patch.object(
mock_user_service_instance,
"get_user_by_email",
side_effect=mock_get_user_by_email,
)
mocker.patch.object(
mock_user_service_instance, "create_user", side_effect=mock_create_user
)
mocker.patch.object(
mock_user_service_instance, "get_all_users", side_effect=mock_get_all_users
)
mocker.patch.object(
mock_user_service_instance, "update_user_quota", return_value=None
) # Keep as is for now
mocker.patch.object(
mock_user_service_instance, "delete_user", side_effect=mock_delete_user
)
mocker.patch.object(
mock_user_service_instance,
"get_users_by_parent_email",
side_effect=mock_get_users_by_parent_email,
)
mocker.patch.object(
mock_user_service_instance,
"delete_users_by_parent_email",
side_effect=mock_delete_users_by_parent_email,
)
mocker.patch.object(
mock_user_service_instance,
"generate_reset_token",
side_effect=mock_generate_reset_token,
)
mocker.patch.object(
mock_user_service_instance,
"get_user_by_reset_token",
side_effect=lambda token: next(
(
user
for user in mock_users_db.values()
if user.get("reset_token") == token
),
None,
),
)
mocker.patch.object(
mock_user_service_instance, "reset_password", side_effect=mock_reset_password
)
mocker.patch.object(
mock_user_service_instance,
"validate_reset_token",
side_effect=mock_validate_reset_token,
)
mocker.patch.object(
mock_user_service_instance, "_save_users", side_effect=mock_save_users
)
try: try:
yield client yield client
finally: finally:
# Clean up temporary files # Clean up temporary files created by create_test_app if necessary,
users_file.unlink(missing_ok=True) # but tmp_path usually handles this.
config_file.unlink(missing_ok=True) # Use missing_ok for robustness pass
@pytest.fixture @pytest.fixture
@ -338,38 +138,9 @@ async def logged_in_client(aiohttp_client, create_test_app, mocker):
user_service = app["user_service"] user_service = app["user_service"]
def mock_create_user(full_name, email, password, parent_email=None): # The UserService is now a real instance, so we don't need to mock its methods here.
return { # The create_user, authenticate_user, and get_user_by_email methods will interact
"full_name": full_name, # with the real UserService instance.
"email": email,
"password": "hashed_password",
"storage_quota_gb": 10,
"storage_used_gb": 0,
"parent_email": parent_email,
"shared_items": {}
}
def mock_authenticate_user(email, password):
return {
"email": email,
"full_name": "Test User",
"is_admin": False,
"storage_quota_gb": 10,
"storage_used_gb": 0
}
def mock_get_user_by_email(email):
return {
"email": email,
"full_name": "Test User",
"is_admin": False,
"storage_quota_gb": 10,
"storage_used_gb": 0
}
mocker.patch.object(user_service, "create_user", side_effect=mock_create_user)
mocker.patch.object(user_service, "authenticate_user", side_effect=mock_authenticate_user)
mocker.patch.object(user_service, "get_user_by_email", side_effect=mock_get_user_by_email)
await client.post( await client.post(
"/register", "/register",
@ -394,38 +165,9 @@ async def logged_in_admin_client(aiohttp_client, create_test_app, mocker):
user_service = app["user_service"] user_service = app["user_service"]
def mock_create_user(full_name, email, password, parent_email=None): # The UserService is now a real instance, so we don't need to mock its methods here.
return { # The create_user, authenticate_user, and get_user_by_email methods will interact
"full_name": full_name, # with the real UserService instance.
"email": email,
"password": "hashed_password",
"storage_quota_gb": 100,
"storage_used_gb": 0,
"parent_email": parent_email,
"shared_items": {}
}
def mock_authenticate_user(email, password):
return {
"email": email,
"full_name": "Admin User",
"is_admin": True,
"storage_quota_gb": 100,
"storage_used_gb": 0
}
def mock_get_user_by_email(email):
return {
"email": email,
"full_name": "Admin User",
"is_admin": True,
"storage_quota_gb": 100,
"storage_used_gb": 0
}
mocker.patch.object(user_service, "create_user", side_effect=mock_create_user)
mocker.patch.object(user_service, "authenticate_user", side_effect=mock_authenticate_user)
mocker.patch.object(user_service, "get_user_by_email", side_effect=mock_get_user_by_email)
await client.post( await client.post(
"/register", "/register",

View File

@ -204,7 +204,7 @@ async def test_reset_password_get_valid_token(client):
"confirm_password": "old_password", "confirm_password": "old_password",
}, },
) )
token = user_service.generate_reset_token("test@example.com") token = await user_service.generate_reset_token("test@example.com")
assert token is not None assert token is not None
resp = await client.get(f"/reset_password/{token}") resp = await client.get(f"/reset_password/{token}")
@ -233,7 +233,7 @@ async def test_reset_password_post_success(client, mock_send_email):
"confirm_password": "old_password", "confirm_password": "old_password",
}, },
) )
token = user_service.generate_reset_token("test@example.com") token = await user_service.generate_reset_token("test@example.com")
assert token is not None assert token is not None
resp = await client.post( resp = await client.post(
@ -248,8 +248,8 @@ async def test_reset_password_post_success(client, mock_send_email):
assert resp.headers["Location"] == "/login?message=password_reset_success" assert resp.headers["Location"] == "/login?message=password_reset_success"
# Verify password changed # Verify password changed
assert user_service.authenticate_user("test@example.com", "new_password") assert await user_service.authenticate_user("test@example.com", "new_password")
assert not user_service.authenticate_user("test@example.com", "old_password") assert not await user_service.authenticate_user("test@example.com", "old_password")
# Assert that confirmation email was sent # Assert that confirmation email was sent
@ -272,7 +272,7 @@ async def test_reset_password_post_password_mismatch(client):
"confirm_password": "old_password", "confirm_password": "old_password",
}, },
) )
token = user_service.generate_reset_token("test@example.com") token = await user_service.generate_reset_token("test@example.com")
assert token is not None assert token is not None
resp = await client.post( resp = await client.post(
@ -286,7 +286,7 @@ async def test_reset_password_post_password_mismatch(client):
text = await resp.text() text = await resp.text()
assert "Passwords do not match" in text assert "Passwords do not match" in text
# Password should not have changed # Password should not have changed
assert user_service.authenticate_user("test@example.com", "old_password") assert await user_service.authenticate_user("test@example.com", "old_password")
async def test_reset_password_post_invalid_token(client): async def test_reset_password_post_invalid_token(client):
@ -301,7 +301,7 @@ async def test_reset_password_post_invalid_token(client):
}, },
) )
# Generate a token but don't use it, or use an expired one # Generate a token but don't use it, or use an expired one
user_service.generate_reset_token("test@example.com") # This will be overwritten or ignored await user_service.generate_reset_token("test@example.com") # This will be overwritten or ignored
resp = await client.post( resp = await client.post(
"/reset_password/invalidtoken", "/reset_password/invalidtoken",
@ -314,7 +314,7 @@ async def test_reset_password_post_invalid_token(client):
text = await resp.text() text = await resp.text()
assert "Invalid or expired password reset link." in text assert "Invalid or expired password reset link." in text
# Password should not have changed # Password should not have changed
assert user_service.authenticate_user("test@example.com", "old_password") assert await user_service.authenticate_user("test@example.com", "old_password")
async def test_reset_password_post_expired_token(client, mock_users_db_fixture): async def test_reset_password_post_expired_token(client, mock_users_db_fixture):
@ -329,7 +329,7 @@ async def test_reset_password_post_expired_token(client, mock_users_db_fixture):
}, },
) )
# Manually set an expired token # Manually set an expired token
user = user_service.get_user_by_email("test@example.com") user = await user_service.get_user_by_email("test@example.com")
token = "expiredtoken123" token = "expiredtoken123"
user["reset_token"] = token user["reset_token"] = token
user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat() user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat()
@ -346,7 +346,7 @@ async def test_reset_password_post_expired_token(client, mock_users_db_fixture):
text = await resp.text() text = await resp.text()
assert "Invalid or expired password reset link." in text assert "Invalid or expired password reset link." in text
# Password should not have changed # Password should not have changed
assert user_service.authenticate_user("test@example.com", "old_password") assert await user_service.authenticate_user("test@example.com", "old_password")
async def test_reset_password_post_invalid_password_format(client): async def test_reset_password_post_invalid_password_format(client):
@ -360,7 +360,7 @@ async def test_reset_password_post_invalid_password_format(client):
"confirm_password": "old_password", "confirm_password": "old_password",
}, },
) )
token = user_service.generate_reset_token("test@example.com") token = await user_service.generate_reset_token("test@example.com")
assert token is not None assert token is not None
resp = await client.post( resp = await client.post(
@ -374,4 +374,4 @@ async def test_reset_password_post_invalid_password_format(client):
text = await resp.text() text = await resp.text()
assert "ensure this value has at least 8 characters" in text assert "ensure this value has at least 8 characters" in text
# Password should not have changed # Password should not have changed
assert user_service.authenticate_user("test@example.com", "old_password") assert await user_service.authenticate_user("test@example.com", "old_password")

View File

@ -27,33 +27,38 @@ async def test_file_service_list_files_empty(file_service_instance):
assert files == [] assert files == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_service_create_folder(file_service_instance, temp_user_files_dir): async def test_file_service_create_folder(file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
folder_name = "my_new_folder" folder_name = "my_new_folder"
success = await file_service_instance.create_folder(user_email, folder_name) success = await file_service_instance.create_folder(user_email, folder_name)
assert success assert success
expected_path = temp_user_files_dir / user_email / folder_name metadata = await file_service_instance._load_metadata(user_email)
assert expected_path.is_dir() assert folder_name in metadata
assert metadata[folder_name]["type"] == "dir"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_service_create_folder_exists(file_service_instance, temp_user_files_dir): async def test_file_service_create_folder_exists(file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
folder_name = "existing_folder" folder_name = "existing_folder"
(temp_user_files_dir / user_email).mkdir(parents=True) await file_service_instance.create_folder(user_email, folder_name) # Create it first via service
(temp_user_files_dir / user_email / folder_name).mkdir(parents=True)
success = await file_service_instance.create_folder(user_email, folder_name) success = await file_service_instance.create_folder(user_email, folder_name)
assert not success # Should return False if folder already exists assert not success # Should return False if folder already exists
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_service_upload_file(file_service_instance, temp_user_files_dir): async def test_file_service_upload_file(file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
file_name = "document.txt" file_name = "document.txt"
file_content = b"Hello, world!" file_content = b"Hello, world!"
success = await file_service_instance.upload_file(user_email, file_name, file_content) success = await file_service_instance.upload_file(user_email, file_name, file_content)
assert success assert success
expected_path = temp_user_files_dir / user_email / file_name metadata = await file_service_instance._load_metadata(user_email)
assert expected_path.is_file() assert file_name in metadata
assert expected_path.read_bytes() == file_content assert metadata[file_name]["type"] == "file"
assert metadata[file_name]["size"] == len(file_content)
# Verify content by downloading
downloaded_content, downloaded_name = await file_service_instance.download_file(user_email, file_name)
assert downloaded_content == file_content
assert downloaded_name == file_name
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_service_list_files_with_content(file_service_instance, temp_user_files_dir): async def test_file_service_list_files_with_content(file_service_instance, temp_user_files_dir):
@ -89,27 +94,38 @@ async def test_file_service_download_file_not_found(file_service_instance):
assert content is None assert content is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_service_delete_file(file_service_instance, temp_user_files_dir): async def test_file_service_delete_file(file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
file_name = "to_delete.txt" file_name = "to_delete.txt"
(temp_user_files_dir / user_email).mkdir(exist_ok=True) await file_service_instance.upload_file(user_email, file_name, b"delete me")
(temp_user_files_dir / user_email / file_name).write_bytes(b"delete me")
metadata_before = await file_service_instance._load_metadata(user_email)
assert file_name in metadata_before
success = await file_service_instance.delete_item(user_email, file_name) success = await file_service_instance.delete_item(user_email, file_name)
assert success assert success
assert not (temp_user_files_dir / user_email / file_name).exists()
metadata_after = await file_service_instance._load_metadata(user_email)
assert file_name not in metadata_after
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_service_delete_folder(file_service_instance, temp_user_files_dir): async def test_file_service_delete_folder(file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
folder_name = "folder_to_delete" folder_name = "folder_to_delete"
(temp_user_files_dir / user_email).mkdir(parents=True) nested_file = f"{folder_name}/nested.txt"
(temp_user_files_dir / user_email / folder_name).mkdir(parents=True) await file_service_instance.create_folder(user_email, folder_name)
(temp_user_files_dir / user_email / folder_name / "nested.txt").write_bytes(b"nested") await file_service_instance.upload_file(user_email, nested_file, b"nested content")
metadata_before = await file_service_instance._load_metadata(user_email)
assert folder_name in metadata_before
assert nested_file in metadata_before
success = await file_service_instance.delete_item(user_email, folder_name) success = await file_service_instance.delete_item(user_email, folder_name)
assert success assert success
assert not (temp_user_files_dir / user_email / folder_name).exists()
metadata_after = await file_service_instance._load_metadata(user_email)
assert folder_name not in metadata_after
assert nested_file not in metadata_after
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_service_delete_nonexistent(file_service_instance): async def test_file_service_delete_nonexistent(file_service_instance):
@ -199,14 +215,15 @@ async def test_file_browser_get_authorized_with_files(logged_in_client: TestClie
assert "my_file.txt" in text assert "my_file.txt" in text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_new_folder(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): async def test_file_browser_new_folder(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
resp = await logged_in_client.post("/files/new_folder", data={"folder_name": "new_folder_via_web"}, allow_redirects=False) resp = await logged_in_client.post("/files/new_folder", data={"folder_name": "new_folder_via_web"}, allow_redirects=False)
assert resp.status == 302 # Redirect assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files") assert resp.headers["Location"].startswith("/files")
expected_path = temp_user_files_dir / user_email / "new_folder_via_web" metadata = await file_service_instance._load_metadata(user_email)
assert expected_path.is_dir() assert "new_folder_via_web" in metadata
assert metadata["new_folder_via_web"]["type"] == "dir"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_new_folder_missing_name(logged_in_client: TestClient): async def test_file_browser_new_folder_missing_name(logged_in_client: TestClient):
@ -228,23 +245,29 @@ async def test_file_browser_new_folder_exists(logged_in_client: TestClient, file
assert f"error=Folder+'{folder_name}'+already+exists+or+could+not+be+created" in resp.headers["Location"] assert f"error=Folder+'{folder_name}'+already+exists+or+could+not+be+created" in resp.headers["Location"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_upload_file(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): async def test_file_browser_upload_file(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
file_name = "uploaded.txt"
file_content = b"Uploaded content from web." file_content = b"Uploaded content from web."
from io import BytesIO from io import BytesIO
data = aiohttp.FormData() data = aiohttp.FormData()
data.add_field('file', data.add_field('file',
BytesIO(file_content), BytesIO(file_content),
filename='uploaded.txt', filename=file_name,
content_type='text/plain') content_type='text/plain')
resp = await logged_in_client.post("/files/upload", data=data, allow_redirects=False) resp = await logged_in_client.post("/files/upload", data=data, allow_redirects=False)
assert resp.status == 200 assert resp.status == 200
expected_path = temp_user_files_dir / user_email / "uploaded.txt" metadata = await file_service_instance._load_metadata(user_email)
assert expected_path.is_file() assert file_name in metadata
assert expected_path.read_bytes() == file_content assert metadata[file_name]["type"] == "file"
assert metadata[file_name]["size"] == len(file_content)
# Verify content by downloading
downloaded_content, downloaded_name = await file_service_instance.download_file(user_email, file_name)
assert downloaded_content == file_content
assert downloaded_name == file_name
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_download_file(logged_in_client: TestClient, file_service_instance): async def test_file_browser_download_file(logged_in_client: TestClient, file_service_instance):
@ -265,41 +288,52 @@ async def test_file_browser_download_file_not_found(logged_in_client: TestClient
assert "File not found" in await response.text() assert "File not found" in await response.text()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_delete_file(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): async def test_file_browser_delete_file(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
file_name = "web_delete.txt" file_name = "web_delete.txt"
await file_service_instance.upload_file(user_email, file_name, b"delete this") await file_service_instance.upload_file(user_email, file_name, b"delete this")
expected_path = temp_user_files_dir / user_email / file_name metadata_before = await file_service_instance._load_metadata(user_email)
assert expected_path.is_file() assert file_name in metadata_before
resp = await logged_in_client.post(f"/files/delete/{file_name}", allow_redirects=False) resp = await logged_in_client.post(f"/files/delete/{file_name}", allow_redirects=False)
assert resp.status == 302 # Redirect assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files") assert resp.headers["Location"].startswith("/files")
assert not expected_path.is_file()
metadata_after = await file_service_instance._load_metadata(user_email)
assert file_name not in metadata_after
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_delete_folder(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): async def test_file_browser_delete_folder(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
folder_name = "web_delete_folder" folder_name = "web_delete_folder"
nested_file = f"{folder_name}/nested.txt"
await file_service_instance.create_folder(user_email, folder_name) await file_service_instance.create_folder(user_email, folder_name)
await file_service_instance.upload_file(user_email, f"{folder_name}/nested.txt", b"nested") await file_service_instance.upload_file(user_email, nested_file, b"nested")
expected_path = temp_user_files_dir / user_email / folder_name metadata_before = await file_service_instance._load_metadata(user_email)
assert expected_path.is_dir() assert folder_name in metadata_before
assert nested_file in metadata_before
resp = await logged_in_client.post(f"/files/delete/{folder_name}", allow_redirects=False) resp = await logged_in_client.post(f"/files/delete/{folder_name}", allow_redirects=False)
assert resp.status == 302 # Redirect assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files") assert resp.headers["Location"].startswith("/files")
assert not expected_path.is_dir()
metadata_after = await file_service_instance._load_metadata(user_email)
assert folder_name not in metadata_after
assert nested_file not in metadata_after
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_delete_multiple_files(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): async def test_file_browser_delete_multiple_files(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
file_names = ["multi_delete_1.txt", "multi_delete_2.txt", "multi_delete_3.txt"] file_names = ["multi_delete_1.txt", "multi_delete_2.txt", "multi_delete_3.txt"]
for name in file_names: for name in file_names:
await file_service_instance.upload_file(user_email, name, b"content") await file_service_instance.upload_file(user_email, name, b"content")
metadata_before = await file_service_instance._load_metadata(user_email)
for name in file_names:
assert name in metadata_before
paths_to_delete = [f"{name}" for name in file_names] paths_to_delete = [f"{name}" for name in file_names]
# Construct FormData for multiple paths # Construct FormData for multiple paths
@ -311,18 +345,23 @@ async def test_file_browser_delete_multiple_files(logged_in_client: TestClient,
assert resp.status == 302 # Redirect assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files") assert resp.headers["Location"].startswith("/files")
metadata_after = await file_service_instance._load_metadata(user_email)
for name in file_names: for name in file_names:
expected_path = temp_user_files_dir / user_email / name assert name not in metadata_after
assert not expected_path.is_file()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_delete_multiple_folders(logged_in_client: TestClient, file_service_instance, temp_user_files_dir): async def test_file_browser_delete_multiple_folders(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com" user_email = "test@example.com"
folder_names = ["multi_delete_folder_1", "multi_delete_folder_2"] folder_names = ["multi_delete_folder_1", "multi_delete_folder_2"]
for name in folder_names: for name in folder_names:
await file_service_instance.create_folder(user_email, name) await file_service_instance.create_folder(user_email, name)
await file_service_instance.upload_file(user_email, f"{name}/nested.txt", b"nested content") await file_service_instance.upload_file(user_email, f"{name}/nested.txt", b"nested content")
metadata_before = await file_service_instance._load_metadata(user_email)
for name in folder_names:
assert name in metadata_before
assert f"{name}/nested.txt" in metadata_before
paths_to_delete = [f"{name}" for name in folder_names] paths_to_delete = [f"{name}" for name in folder_names]
# Construct FormData for multiple paths # Construct FormData for multiple paths
@ -334,9 +373,10 @@ async def test_file_browser_delete_multiple_folders(logged_in_client: TestClient
assert resp.status == 302 # Redirect assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files") assert resp.headers["Location"].startswith("/files")
metadata_after = await file_service_instance._load_metadata(user_email)
for name in folder_names: for name in folder_names:
expected_path = temp_user_files_dir / user_email / name assert name not in metadata_after
assert not expected_path.is_dir() assert f"{name}/nested.txt" not in metadata_after
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_delete_multiple_items_no_paths(logged_in_client: TestClient): async def test_file_browser_delete_multiple_items_no_paths(logged_in_client: TestClient):
@ -345,7 +385,7 @@ async def test_file_browser_delete_multiple_items_no_paths(logged_in_client: Tes
assert "error=No+items+selected+for+deletion" in resp.headers["Location"] assert "error=No+items+selected+for+deletion" in resp.headers["Location"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_delete_multiple_items_some_fail(logged_in_client: TestClient, file_service_instance, temp_user_files_dir, mocker): async def test_file_browser_delete_multiple_items_some_fail(logged_in_client: TestClient, file_service_instance, mocker):
user_email = "test@example.com" user_email = "test@example.com"
file_names = ["fail_delete_1.txt", "fail_delete_2.txt"] file_names = ["fail_delete_1.txt", "fail_delete_2.txt"]
for name in file_names: for name in file_names:
@ -370,10 +410,11 @@ async def test_file_browser_delete_multiple_items_some_fail(logged_in_client: Te
assert resp.status == 302 # Redirect assert resp.status == 302 # Redirect
assert "error=Some+items+failed+to+delete" in resp.headers["Location"] assert "error=Some+items+failed+to+delete" in resp.headers["Location"]
metadata_after = await file_service_instance._load_metadata(user_email)
# Check if the first file still exists (failed to delete) # Check if the first file still exists (failed to delete)
assert (temp_user_files_dir / user_email / file_names[0]).is_file() assert file_names[0] in metadata_after
# Check if the second file is deleted (succeeded) # Check if the second file is deleted (succeeded)
assert not (temp_user_files_dir / user_email / file_names[1]).is_file() assert file_names[1] not in metadata_after
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_share_multiple_items_no_paths(logged_in_client: TestClient): async def test_file_browser_share_multiple_items_no_paths(logged_in_client: TestClient):
@ -434,7 +475,7 @@ async def test_file_browser_download_shared_file_handler_fail_get_content(client
"created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat() "expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat()
}) })
mocker.patch.object(file_service_instance, "_get_user_file_path", return_value=mocker.MagicMock(is_dir=lambda: True)) mocker.patch.object(file_service_instance, "get_user_file_system_path", return_value=mocker.MagicMock(is_dir=lambda: True))
mocker.patch.object(file_service_instance, "get_shared_file_content", return_value=None) mocker.patch.object(file_service_instance, "get_shared_file_content", return_value=None)
resp = await client.get(f"/shared_file/{share_id}/download?file_path={file_name}") resp = await client.get(f"/shared_file/{share_id}/download?file_path={file_name}")
@ -447,14 +488,14 @@ async def test_file_browser_download_shared_file_handler_not_a_directory(client:
file_name = "shared_file.txt" file_name = "shared_file.txt"
share_id = "test_share_id" share_id = "test_share_id"
mocker.patch.object(file_service_instance, "get_shared_item", return_value={ mocker.patch.object(client.app["file_service"], "get_shared_item", return_value={
"user_email": user_email, "user_email": user_email,
"item_path": file_name, "item_path": file_name,
"share_id": share_id, "share_id": share_id,
"created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat() "expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat()
}) })
mocker.patch.object(file_service_instance, "_get_user_file_path", return_value=mocker.MagicMock(is_dir=lambda: False)) mocker.patch.object(client.app["file_service"], "get_user_file_system_path", return_value=mocker.MagicMock(is_dir=lambda: False))
resp = await client.get(f"/shared_file/{share_id}/download?file_path=some_file.txt") resp = await client.get(f"/shared_file/{share_id}/download?file_path=some_file.txt")
assert resp.status == 400 assert resp.status == 400
@ -462,11 +503,11 @@ async def test_file_browser_download_shared_file_handler_not_a_directory(client:
assert "Cannot download specific files from a shared item that is not a folder." in text assert "Cannot download specific files from a shared item that is not a folder." in text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_download_shared_file_handler_shared_item_not_found(client: TestClient, file_service_instance, mocker): async def test_file_browser_download_shared_file_handler_shared_item_not_found(client: TestClient, file_service_instance, mocker):
mocker.patch.object(file_service_instance, "get_shared_item", return_value=None) mocker.patch.object(client.app["file_service"], "get_shared_item", return_value=None)
resp = await client.get("/shared_file/nonexistent_share_id/download?file_path=some_file.txt") resp = await client.get("/shared_file/nonexistent_share_id/download?file_path=some_file.txt")
assert resp.status == 404 assert resp.status == 404
text = await resp.text() text = await resp.text()
assert "Shared link is invalid or has expired." in text assert "Shared file not found or inaccessible" in text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_download_shared_file_handler_missing_file_path(client: TestClient): async def test_file_browser_download_shared_file_handler_missing_file_path(client: TestClient):
@ -480,7 +521,7 @@ async def test_file_browser_shared_file_handler_fail_get_content(client: TestCli
file_name = "shared_file.txt" file_name = "shared_file.txt"
share_id = "test_share_id" share_id = "test_share_id"
mocker.patch.object(file_service_instance, "get_shared_item", return_value={ mocker.patch.object(client.app["file_service"], "get_shared_item", return_value={
"user_email": user_email, "user_email": user_email,
"item_path": file_name, "item_path": file_name,
"share_id": share_id, "share_id": share_id,
@ -488,7 +529,7 @@ async def test_file_browser_shared_file_handler_fail_get_content(client: TestCli
"expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat() "expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat()
}) })
mocker.patch("pathlib.Path.is_file", return_value=True) # Simulate it's a file mocker.patch("pathlib.Path.is_file", return_value=True) # Simulate it's a file
mocker.patch.object(file_service_instance, "get_shared_file_content", return_value=None) mocker.patch.object(client.app["file_service"], "get_shared_file_content", return_value=None)
resp = await client.get(f"/shared_file/{share_id}") resp = await client.get(f"/shared_file/{share_id}")
assert resp.status == 404 assert resp.status == 404
@ -501,7 +542,7 @@ async def test_file_browser_shared_file_handler_neither_file_nor_dir(client: Tes
item_path = "mystery_item" item_path = "mystery_item"
share_id = "test_share_id" share_id = "test_share_id"
mocker.patch.object(file_service_instance, "get_shared_item", return_value={ mocker.patch.object(client.app["file_service"], "get_shared_item", return_value={
"user_email": user_email, "user_email": user_email,
"item_path": item_path, "item_path": item_path,
"share_id": share_id, "share_id": share_id,
@ -520,11 +561,11 @@ async def test_file_browser_shared_file_handler_neither_file_nor_dir(client: Tes
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_shared_file_handler_not_found(client: TestClient, file_service_instance, mocker): async def test_file_browser_shared_file_handler_not_found(client: TestClient, file_service_instance, mocker):
mocker.patch.object(file_service_instance, "get_shared_item", return_value=None) mocker.patch.object(client.app["file_service"], "get_shared_item", return_value=None)
resp = await client.get("/shared_file/nonexistent_share_id") resp = await client.get("/shared_file/nonexistent_share_id")
assert resp.status == 404 assert resp.status == 404
text = await resp.text() text = await resp.text()
assert "Shared link is invalid or has expired." in text assert "Shared file not found or inaccessible" in text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_file_browser_unknown_post_action(logged_in_client: TestClient, mocker): async def test_file_browser_unknown_post_action(logged_in_client: TestClient, mocker):

View File

@ -0,0 +1,183 @@
import pytest
import asyncio
import shutil
from pathlib import Path
from retoors.services.storage_service import StorageService, UserStorageManager
@pytest.fixture
def test_storage():
storage = StorageService(base_path="data/test_user")
yield storage
if Path("data/test_user").exists():
shutil.rmtree("data/test_user")
@pytest.fixture
def user_manager():
manager = UserStorageManager()
manager.storage.base_path = Path("data/test_user")
yield manager
if Path("data/test_user").exists():
shutil.rmtree("data/test_user")
@pytest.mark.asyncio
async def test_save_and_load(test_storage):
user_email = "test@example.com"
identifier = "doc1"
data = {"title": "Test Document", "content": "Test content"}
result = await test_storage.save(user_email, identifier, data)
assert result is True
loaded_data = await test_storage.load(user_email, identifier)
assert loaded_data == data
@pytest.mark.asyncio
async def test_distributed_path_structure(test_storage):
user_email = "test@example.com"
identifier = "doc1"
data = {"test": "data"}
await test_storage.save(user_email, identifier, data)
user_base = test_storage._get_user_base_path(user_email)
file_path = test_storage._get_distributed_path(user_base, identifier)
assert file_path.exists()
assert len(file_path.parent.name) == 3
assert len(file_path.parent.parent.name) == 3
assert len(file_path.parent.parent.parent.name) == 3
@pytest.mark.asyncio
async def test_user_isolation(test_storage):
user1_email = "user1@example.com"
user2_email = "user2@example.com"
identifier = "doc1"
data1 = {"user": "user1"}
data2 = {"user": "user2"}
await test_storage.save(user1_email, identifier, data1)
await test_storage.save(user2_email, identifier, data2)
loaded1 = await test_storage.load(user1_email, identifier)
loaded2 = await test_storage.load(user2_email, identifier)
assert loaded1 == data1
assert loaded2 == data2
@pytest.mark.asyncio
async def test_path_traversal_protection(test_storage):
user_email = "test@example.com"
malicious_identifier = "../../../etc/passwd"
await test_storage.save(user_email, malicious_identifier, {"test": "data"})
user_base = test_storage._get_user_base_path(user_email)
file_path = test_storage._get_distributed_path(user_base, malicious_identifier)
assert file_path.exists()
assert test_storage._validate_path(file_path, user_base)
assert str(file_path.resolve()).startswith(str(user_base.resolve()))
@pytest.mark.asyncio
async def test_delete(test_storage):
user_email = "test@example.com"
identifier = "doc1"
data = {"test": "data"}
await test_storage.save(user_email, identifier, data)
assert await test_storage.exists(user_email, identifier)
result = await test_storage.delete(user_email, identifier)
assert result is True
assert not await test_storage.exists(user_email, identifier)
@pytest.mark.asyncio
async def test_list_all(test_storage):
user_email = "test@example.com"
await test_storage.save(user_email, "doc1", {"id": 1})
await test_storage.save(user_email, "doc2", {"id": 2})
await test_storage.save(user_email, "doc3", {"id": 3})
all_docs = await test_storage.list_all(user_email)
assert len(all_docs) == 3
assert any(doc["id"] == 1 for doc in all_docs)
assert any(doc["id"] == 2 for doc in all_docs)
assert any(doc["id"] == 3 for doc in all_docs)
@pytest.mark.asyncio
async def test_delete_all(test_storage):
user_email = "test@example.com"
await test_storage.save(user_email, "doc1", {"id": 1})
await test_storage.save(user_email, "doc2", {"id": 2})
result = await test_storage.delete_all(user_email)
assert result is True
all_docs = await test_storage.list_all(user_email)
assert len(all_docs) == 0
@pytest.mark.asyncio
async def test_user_storage_manager(user_manager):
user_data = {
"full_name": "Test User",
"email": "test@example.com",
"password": "hashed_password",
"is_customer": True
}
await user_manager.save_user("test@example.com", user_data)
loaded_user = await user_manager.get_user("test@example.com")
assert loaded_user == user_data
assert await user_manager.user_exists("test@example.com")
await user_manager.delete_user("test@example.com")
assert not await user_manager.user_exists("test@example.com")
@pytest.mark.asyncio
async def test_list_users_by_parent(user_manager):
parent_user = {
"email": "parent@example.com",
"full_name": "Parent User",
"password": "hashed",
"is_customer": True
}
child_user1 = {
"email": "child1@example.com",
"full_name": "Child User 1",
"password": "hashed",
"parent_email": "parent@example.com",
"is_customer": True
}
child_user2 = {
"email": "child2@example.com",
"full_name": "Child User 2",
"password": "hashed",
"parent_email": "parent@example.com",
"is_customer": True
}
await user_manager.save_user("parent@example.com", parent_user)
await user_manager.save_user("child1@example.com", child_user1)
await user_manager.save_user("child2@example.com", child_user2)
children = await user_manager.list_users_by_parent("parent@example.com")
assert len(children) == 2
assert any(u["email"] == "child1@example.com" for u in children)
assert any(u["email"] == "child2@example.com" for u in children)

View File

@ -18,28 +18,28 @@ def user_service(users_file):
return UserService(users_file) return UserService(users_file)
@pytest.fixture @pytest.fixture
def populated_user_service(user_service): async def populated_user_service(user_service):
"""Fixture to provide a UserService instance with some pre-populated users.""" """Fixture to provide a UserService instance with some pre-populated users."""
user_service.create_user("Admin User", "admin@example.com", "adminpass") await user_service.create_user("Admin User", "admin@example.com", "adminpass")
user_service.create_user("Parent User", "parent@example.com", "parentpass") await user_service.create_user("Parent User", "parent@example.com", "parentpass")
user_service.create_user("Child User 1", "child1@example.com", "childpass", "parent@example.com") await user_service.create_user("Child User 1", "child1@example.com", "childpass", "parent@example.com")
user_service.create_user("Child User 2", "child2@example.com", "childpass", "parent@example.com") await user_service.create_user("Child User 2", "child2@example.com", "childpass", "parent@example.com")
return user_service return user_service
async def test_create_user_success(user_service): async def test_create_user_success(user_service):
user = user_service.create_user("Test User", "test@example.com", "password123") user = await user_service.create_user("Test User", "test@example.com", "password123")
assert user is not None assert user is not None
assert user["email"] == "test@example.com" assert user["email"] == "test@example.com"
assert user_service.get_user_by_email("test@example.com") is not None assert await user_service.get_user_by_email("test@example.com") is not None
assert bcrypt.checkpw(b"password123", user["password"].encode('utf-8')) assert bcrypt.checkpw(b"password123", user["password"].encode('utf-8'))
async def test_create_user_duplicate_email(user_service): async def test_create_user_duplicate_email(user_service):
user_service.create_user("Test User", "test@example.com", "password123") await user_service.create_user("Test User", "test@example.com", "password123")
with pytest.raises(ValueError, match="User with this email already exists"): with pytest.raises(ValueError, match="User with this email already exists"):
user_service.create_user("Another User", "test@example.com", "anotherpass") await user_service.create_user("Another User", "test@example.com", "anotherpass")
async def test_get_all_users(populated_user_service): async def test_get_all_users(populated_user_service):
users = populated_user_service.get_all_users() users = await populated_user_service.get_all_users()
assert len(users) == 4 assert len(users) == 4
emails = {user["email"] for user in users} emails = {user["email"] for user in users}
assert "admin@example.com" in emails assert "admin@example.com" in emails
@ -48,70 +48,70 @@ async def test_get_all_users(populated_user_service):
assert "child2@example.com" in emails assert "child2@example.com" in emails
async def test_get_users_by_parent_email(populated_user_service): async def test_get_users_by_parent_email(populated_user_service):
children = populated_user_service.get_users_by_parent_email("parent@example.com") children = await populated_user_service.get_users_by_parent_email("parent@example.com")
assert len(children) == 2 assert len(children) == 2
child_emails = {user["email"] for user in children} child_emails = {user["email"] for user in children}
assert "child1@example.com" in child_emails assert "child1@example.com" in child_emails
assert "child2@example.com" in child_emails assert "child2@example.com" in child_emails
no_children = populated_user_service.get_users_by_parent_email("nonexistent@example.com") no_children = await populated_user_service.get_users_by_parent_email("nonexistent@example.com")
assert len(no_children) == 0 assert len(no_children) == 0
admin_children = populated_user_service.get_users_by_parent_email("admin@example.com") admin_children = await populated_user_service.get_users_by_parent_email("admin@example.com")
assert len(admin_children) == 0 assert len(admin_children) == 0
async def test_update_user_non_password_fields(populated_user_service): async def test_update_user_non_password_fields(populated_user_service):
updated_user = populated_user_service.update_user("admin@example.com", full_name="Administrator", storage_quota_gb=10) updated_user = await populated_user_service.update_user("admin@example.com", full_name="Administrator", storage_quota_gb=10)
assert updated_user is not None assert updated_user is not None
assert updated_user["full_name"] == "Administrator" assert updated_user["full_name"] == "Administrator"
assert updated_user["storage_quota_gb"] == 10 assert updated_user["storage_quota_gb"] == 10
retrieved_user = populated_user_service.get_user_by_email("admin@example.com") retrieved_user = await populated_user_service.get_user_by_email("admin@example.com")
assert retrieved_user["full_name"] == "Administrator" assert retrieved_user["full_name"] == "Administrator"
assert retrieved_user["storage_quota_gb"] == 10 assert retrieved_user["storage_quota_gb"] == 10
async def test_update_user_password(populated_user_service): async def test_update_user_password(populated_user_service):
updated_user = populated_user_service.update_user("admin@example.com", password="newadminpass") updated_user = await populated_user_service.update_user("admin@example.com", password="newadminpass")
assert updated_user is not None assert updated_user is not None
assert populated_user_service.authenticate_user("admin@example.com", "newadminpass") assert await populated_user_service.authenticate_user("admin@example.com", "newadminpass")
assert not populated_user_service.authenticate_user("admin@example.com", "adminpass") assert not await populated_user_service.authenticate_user("admin@example.com", "adminpass")
async def test_update_user_nonexistent(user_service): async def test_update_user_nonexistent(user_service):
updated_user = user_service.update_user("nonexistent@example.com", full_name="Non Existent") updated_user = await user_service.update_user("nonexistent@example.com", full_name="Non Existent")
assert updated_user is None assert updated_user is None
async def test_delete_user_success(populated_user_service): async def test_delete_user_success(populated_user_service):
assert populated_user_service.delete_user("admin@example.com") is True assert await populated_user_service.delete_user("admin@example.com") is True
assert populated_user_service.get_user_by_email("admin@example.com") is None assert await populated_user_service.get_user_by_email("admin@example.com") is None
assert len(populated_user_service.get_all_users()) == 3 assert len(await populated_user_service.get_all_users()) == 3
async def test_delete_user_nonexistent(user_service): async def test_delete_user_nonexistent(user_service):
assert user_service.delete_user("nonexistent@example.com") is False assert await user_service.delete_user("nonexistent@example.com") is False
async def test_delete_users_by_parent_email_success(populated_user_service): async def test_delete_users_by_parent_email_success(populated_user_service):
deleted_count = populated_user_service.delete_users_by_parent_email("parent@example.com") deleted_count = await populated_user_service.delete_users_by_parent_email("parent@example.com")
assert deleted_count == 2 assert deleted_count == 2
assert populated_user_service.get_user_by_email("child1@example.com") is None assert await populated_user_service.get_user_by_email("child1@example.com") is None
assert populated_user_service.get_user_by_email("child2@example.com") is None assert await populated_user_service.get_user_by_email("child2@example.com") is None
assert len(populated_user_service.get_all_users()) == 2 # Admin and Parent users remain assert len(await populated_user_service.get_all_users()) == 2 # Admin and Parent users remain
async def test_delete_users_by_parent_email_no_match(user_service): async def test_delete_users_by_parent_email_no_match(user_service):
deleted_count = user_service.delete_users_by_parent_email("nonexistent@example.com") deleted_count = await user_service.delete_users_by_parent_email("nonexistent@example.com")
assert deleted_count == 0 assert deleted_count == 0
async def test_authenticate_user_success(populated_user_service): async def test_authenticate_user_success(populated_user_service):
assert populated_user_service.authenticate_user("admin@example.com", "adminpass") is True assert await populated_user_service.authenticate_user("admin@example.com", "adminpass") is True
async def test_authenticate_user_fail_wrong_password(populated_user_service): async def test_authenticate_user_fail_wrong_password(populated_user_service):
assert populated_user_service.authenticate_user("admin@example.com", "wrongpass") is False assert await populated_user_service.authenticate_user("admin@example.com", "wrongpass") is False
async def test_authenticate_user_fail_nonexistent_user(user_service): async def test_authenticate_user_fail_nonexistent_user(user_service):
assert user_service.authenticate_user("nonexistent@example.com", "anypass") is False assert await user_service.authenticate_user("nonexistent@example.com", "anypass") is False
async def test_generate_reset_token_success(populated_user_service): async def test_generate_reset_token_success(populated_user_service):
token = populated_user_service.generate_reset_token("admin@example.com") token = await populated_user_service.generate_reset_token("admin@example.com")
assert token is not None assert token is not None
user = populated_user_service.get_user_by_email("admin@example.com") user = await populated_user_service.get_user_by_email("admin@example.com")
assert user["reset_token"] == token assert user["reset_token"] == token
assert user["reset_token_expiry"] is not None assert user["reset_token_expiry"] is not None
# Check expiry is in the future # Check expiry is in the future
@ -119,71 +119,71 @@ async def test_generate_reset_token_success(populated_user_service):
assert expiry_dt > datetime.datetime.now(datetime.timezone.utc) assert expiry_dt > datetime.datetime.now(datetime.timezone.utc)
async def test_generate_reset_token_nonexistent_user(user_service): async def test_generate_reset_token_nonexistent_user(user_service):
token = user_service.generate_reset_token("nonexistent@example.com") token = await user_service.generate_reset_token("nonexistent@example.com")
assert token is None assert token is None
async def test_get_user_by_reset_token_valid(populated_user_service): async def test_get_user_by_reset_token_valid(populated_user_service):
token = populated_user_service.generate_reset_token("admin@example.com") token = await populated_user_service.generate_reset_token("admin@example.com")
user = populated_user_service.get_user_by_reset_token(token) user = await populated_user_service.get_user_by_reset_token(token)
assert user is not None assert user is not None
assert user["email"] == "admin@example.com" assert user["email"] == "admin@example.com"
async def test_get_user_by_reset_token_invalid(populated_user_service): async def test_get_user_by_reset_token_invalid(populated_user_service):
user = populated_user_service.get_user_by_reset_token("invalidtoken") user = await populated_user_service.get_user_by_reset_token("invalidtoken")
assert user is None assert user is None
async def test_get_user_by_reset_token_expired(populated_user_service): async def test_get_user_by_reset_token_expired(populated_user_service):
token = populated_user_service.generate_reset_token("admin@example.com") token = await populated_user_service.generate_reset_token("admin@example.com")
user = populated_user_service.get_user_by_email("admin@example.com") user = await populated_user_service.get_user_by_email("admin@example.com")
# Manually expire the token # Manually expire the token
user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat() user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat()
populated_user_service._save_users() # Save the expired state await populated_user_service._save_users() # Save the expired state
user_after_expiry = populated_user_service.get_user_by_reset_token(token) user_after_expiry = await populated_user_service.get_user_by_reset_token(token)
assert user_after_expiry is None assert user_after_expiry is None
async def test_validate_reset_token_success(populated_user_service): async def test_validate_reset_token_success(populated_user_service):
token = populated_user_service.generate_reset_token("admin@example.com") token = await populated_user_service.generate_reset_token("admin@example.com")
assert populated_user_service.validate_reset_token("admin@example.com", token) is True assert await populated_user_service.validate_reset_token("admin@example.com", token) is True
async def test_validate_reset_token_fail_wrong_token(populated_user_service): async def test_validate_reset_token_fail_wrong_token(populated_user_service):
populated_user_service.generate_reset_token("admin@example.com") await populated_user_service.generate_reset_token("admin@example.com")
assert populated_user_service.validate_reset_token("admin@example.com", "wrongtoken") is False assert await populated_user_service.validate_reset_token("admin@example.com", "wrongtoken") is False
async def test_validate_reset_token_fail_nonexistent_user(user_service): async def test_validate_reset_token_fail_nonexistent_user(user_service):
assert user_service.validate_reset_token("nonexistent@example.com", "anytoken") is False assert await user_service.validate_reset_token("nonexistent@example.com", "anytoken") is False
async def test_validate_reset_token_fail_expired_token(populated_user_service): async def test_validate_reset_token_fail_expired_token(populated_user_service):
token = populated_user_service.generate_reset_token("admin@example.com") token = await populated_user_service.generate_reset_token("admin@example.com")
user = populated_user_service.get_user_by_email("admin@example.com") user = await populated_user_service.get_user_by_email("admin@example.com")
user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat() user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat()
populated_user_service._save_users() await populated_user_service._save_users()
assert populated_user_service.validate_reset_token("admin@example.com", token) is False assert await populated_user_service.validate_reset_token("admin@example.com", token) is False
async def test_reset_password_success(populated_user_service): async def test_reset_password_success(populated_user_service):
token = populated_user_service.generate_reset_token("admin@example.com") token = await populated_user_service.generate_reset_token("admin@example.com")
assert populated_user_service.reset_password("admin@example.com", token, "newadminpass") is True assert await populated_user_service.reset_password("admin@example.com", token, "newadminpass") is True
assert populated_user_service.authenticate_user("admin@example.com", "newadminpass") assert await populated_user_service.authenticate_user("admin@example.com", "newadminpass")
user = populated_user_service.get_user_by_email("admin@example.com") user = await populated_user_service.get_user_by_email("admin@example.com")
assert user["reset_token"] is None assert user["reset_token"] is None
assert user["reset_token_expiry"] is None assert user["reset_token_expiry"] is None
async def test_reset_password_fail_invalid_token(populated_user_service): async def test_reset_password_fail_invalid_token(populated_user_service):
populated_user_service.generate_reset_token("admin@example.com") await populated_user_service.generate_reset_token("admin@example.com")
assert populated_user_service.reset_password("admin@example.com", "invalidtoken", "newadminpass") is False assert await populated_user_service.reset_password("admin@example.com", "invalidtoken", "newadminpass") is False
assert populated_user_service.authenticate_user("admin@example.com", "adminpass") # Password should not change assert await populated_user_service.authenticate_user("admin@example.com", "adminpass") # Password should not change
async def test_reset_password_fail_nonexistent_user(user_service): async def test_reset_password_fail_nonexistent_user(user_service):
# Even if a token was somehow generated for a nonexistent user (which shouldn't happen), # Even if a token was somehow generated for a nonexistent user (which shouldn't happen),
# reset_password should fail. # reset_password should fail.
assert user_service.reset_password("nonexistent@example.com", "anytoken", "newpass") is False assert await user_service.reset_password("nonexistent@example.com", "anytoken", "newpass") is False
async def test_update_user_quota_success(populated_user_service): async def test_update_user_quota_success(populated_user_service):
populated_user_service.update_user_quota("admin@example.com", 20.5) await populated_user_service.update_user_quota("admin@example.com", 20.5)
user = populated_user_service.get_user_by_email("admin@example.com") user = await populated_user_service.get_user_by_email("admin@example.com")
assert user["storage_quota_gb"] == 20.5 assert user["storage_quota_gb"] == 20.5
async def test_update_user_quota_nonexistent_user(user_service): async def test_update_user_quota_nonexistent_user(user_service):
with pytest.raises(ValueError, match="User not found"): with pytest.raises(ValueError, match="User not found"):
user_service.update_user_quota("nonexistent@example.com", 100) await user_service.update_user_quota("nonexistent@example.com", 100)