Files and folders that have been shared with you will appear here.
This feature is coming soon.
+
+
+
{% endblock %}
\ No newline at end of file
diff --git a/retoors/templates/pages/trash.html b/retoors/templates/pages/trash.html
index 493c9ce..8011aa0 100644
--- a/retoors/templates/pages/trash.html
+++ b/retoors/templates/pages/trash.html
@@ -2,11 +2,114 @@
{% block title %}Trash - Retoor's Cloud Solutions{% endblock %}
+{% block dashboard_head %}
+
+{% endblock %}
+
{% block page_title %}Trash{% endblock %}
+{% block dashboard_actions %}
+
-
Files and folders you have deleted will appear here.
-
This feature is coming soon.
+
+ {% if success_message %}
+
+ {{ success_message }}
+
+ {% endif %}
+
+ {% if error_message %}
+
+ {{ error_message }}
+
+ {% endif %}
+
+
+
+
+
+
+
+
×
+
Confirm Restore
+
+
+
+
+
+
+
×
+
Confirm Permanent Delete
+
+
+
+
+
+
{% endblock %}
\ No newline at end of file
diff --git a/retoors/views/admin.py b/retoors/views/admin.py
index 821eaad..6ee7f76 100644
--- a/retoors/views/admin.py
+++ b/retoors/views/admin.py
@@ -3,6 +3,25 @@ from aiohttp_session import get_session
from ..services.user_service import UserService
from ..models import QuotaUpdateModel, RegistrationModel
+
+async def verify_user_access(user_service: UserService, current_user_email: str, target_user_email: str) -> tuple[bool, dict]:
+ if target_user_email == current_user_email:
+ return True, None
+
+ current_user = await user_service.get_user_by_email(current_user_email)
+ if not current_user:
+ return False, {"error": "Current user not found", "status": 401}
+
+ target_user = await user_service.get_user_by_email(target_user_email)
+ if not target_user:
+ return False, {"error": "Target user not found", "status": 404}
+
+ if current_user.get("is_customer", True):
+ if target_user.get("parent_email") != current_user_email:
+ return False, {"error": "Forbidden: You can only manage users you created", "status": 403}
+
+ return True, None
+
async def get_users(request: web.Request) -> web.Response:
user_service: UserService = request.app["user_service"]
session = await get_session(request)
@@ -11,18 +30,8 @@ async def get_users(request: web.Request) -> web.Response:
if not current_user_email:
return web.json_response({"error": "Unauthorized"}, status=401)
- # For now, let's assume only the main user can see all users.
- # In a real application, you'd have roles/permissions.
- # The main user is the one who created the account.
- # If the current user is the main user, they can see all users.
- # Otherwise, they can only see users they created (their "team").
-
- # This logic needs to be refined based on how "main user" is identified.
- # For now, let's return all users for simplicity, assuming the logged-in user has admin-like access to this page.
- # A more robust solution would involve checking if the current_user_email is the 'owner' of the site.
- users = user_service.get_all_users()
-
- # Filter out sensitive information like password and reset tokens
+ users = await user_service.get_managed_users(current_user_email)
+
safe_users = []
for user in users:
safe_user = {k: v for k, v in user.items() if k not in ["password", "reset_token", "reset_token_expiry"]}
@@ -41,13 +50,12 @@ async def add_user(request: web.Request) -> web.Response:
try:
data = await request.json()
registration_data = RegistrationModel(**data)
-
- # The current user is the parent of the new user
- new_user = user_service.create_user(
+
+ new_user = await user_service.create_user(
full_name=registration_data.full_name,
email=registration_data.email,
password=registration_data.password,
- parent_email=current_user_email # Assign current user as parent
+ parent_email=current_user_email
)
safe_new_user = {k: v for k, v in new_user.items() if k not in ["password", "reset_token", "reset_token_expiry"]}
return web.json_response({"message": "User added successfully", "user": safe_new_user}, status=201)
@@ -64,20 +72,18 @@ async def update_user_quota(request: web.Request) -> web.Response:
if not current_user_email:
return web.json_response({"error": "Unauthorized"}, status=401)
-
+
if not target_user_email:
return web.json_response({"error": "User email not provided"}, status=400)
- # Ensure the current user has permission to update this user's quota
- # For now, allow if current_user_email is the target_user_email or if target_user is a child of current_user
- target_user = user_service.get_user_by_email(target_user_email)
- if not target_user or (target_user_email != current_user_email and target_user.get("parent_email") != current_user_email):
- return web.json_response({"error": "Forbidden: You do not have permission to update this user's quota"}, status=403)
+ has_access, error_response = await verify_user_access(user_service, current_user_email, target_user_email)
+ if not has_access:
+ return web.json_response({"error": error_response["error"]}, status=error_response["status"])
try:
data = await request.json()
quota_update_data = QuotaUpdateModel(**data)
- user_service.update_user_quota(target_user_email, quota_update_data.new_quota_gb)
+ await user_service.update_user_quota(target_user_email, quota_update_data.new_quota_gb)
return web.json_response({"message": f"Quota for {target_user_email} updated successfully"})
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
@@ -92,19 +98,18 @@ async def delete_user(request: web.Request) -> web.Response:
if not current_user_email:
return web.json_response({"error": "Unauthorized"}, status=401)
-
+
if not target_user_email:
return web.json_response({"error": "User email not provided"}, status=400)
- # Prevent a user from deleting themselves or a parent user
if target_user_email == current_user_email:
return web.json_response({"error": "Forbidden: You cannot delete your own account from this interface"}, status=403)
- target_user = user_service.get_user_by_email(target_user_email)
- if not target_user or target_user.get("parent_email") != current_user_email:
- return web.json_response({"error": "Forbidden: You do not have permission to delete this user"}, status=403)
+ has_access, error_response = await verify_user_access(user_service, current_user_email, target_user_email)
+ if not has_access:
+ return web.json_response({"error": error_response["error"]}, status=error_response["status"])
- if user_service.delete_user(target_user_email):
+ if await user_service.delete_user(target_user_email):
return web.json_response({"message": f"User {target_user_email} deleted successfully"})
else:
return web.json_response({"error": "User not found or could not be deleted"}, status=404)
@@ -117,14 +122,15 @@ async def get_user_details(request: web.Request) -> web.Response:
if not current_user_email:
return web.json_response({"error": "Unauthorized"}, status=401)
-
+
if not target_user_email:
return web.json_response({"error": "User email not provided"}, status=400)
- target_user = user_service.get_user_by_email(target_user_email)
- if not target_user or (target_user_email != current_user_email and target_user.get("parent_email") != current_user_email):
- return web.json_response({"error": "Forbidden: You do not have permission to view this user's details"}, status=403)
+ has_access, error_response = await verify_user_access(user_service, current_user_email, target_user_email)
+ if not has_access:
+ return web.json_response({"error": error_response["error"]}, status=error_response["status"])
+ target_user = await user_service.get_user_by_email(target_user_email)
safe_user = {k: v for k, v in target_user.items() if k not in ["password", "reset_token", "reset_token_expiry"]}
return web.json_response({"user": safe_user})
@@ -136,15 +142,14 @@ async def delete_team(request: web.Request) -> web.Response:
if not current_user_email:
return web.json_response({"error": "Unauthorized"}, status=401)
-
+
if not target_parent_email:
return web.json_response({"error": "Parent email not provided"}, status=400)
- # Only the parent user can delete their "team" (users they created)
if current_user_email != target_parent_email:
return web.json_response({"error": "Forbidden: You do not have permission to delete this team"}, status=403)
- deleted_count = user_service.delete_users_by_parent_email(target_parent_email)
+ deleted_count = await user_service.delete_users_by_parent_email(target_parent_email)
if deleted_count > 0:
return web.json_response({"message": f"Successfully deleted {deleted_count} users from the team managed by {target_parent_email}"})
else:
diff --git a/retoors/views/auth.py b/retoors/views/auth.py
index 8124094..718f4d6 100644
--- a/retoors/views/auth.py
+++ b/retoors/views/auth.py
@@ -55,7 +55,7 @@ class LoginView(CustomPydanticView):
)
user_service: UserService = self.request.app["user_service"]
- if user_service.authenticate_user(login_data.email, login_data.password):
+ if await user_service.authenticate_user(login_data.email, login_data.password):
session = await new_session(self.request)
session["user_email"] = login_data.email
raise web.HTTPFound("/dashboard")
@@ -93,7 +93,7 @@ class RegistrationView(CustomPydanticView):
user_service: UserService = self.request.app["user_service"]
try:
- user_service.create_user(user_data.full_name, user_data.email, user_data.password) # Changed username to full_name
+ await user_service.create_user(user_data.full_name, user_data.email, user_data.password)
# Render email content
email_context = {
@@ -148,10 +148,10 @@ class ForgotPasswordView(CustomPydanticView):
)
user_service: UserService = self.request.app["user_service"]
- user = user_service.get_user_by_email(forgot_password_data.email)
+ user = await user_service.get_user_by_email(forgot_password_data.email)
if user:
- token = user_service.generate_reset_token(forgot_password_data.email)
+ token = await user_service.generate_reset_token(forgot_password_data.email)
if token:
reset_link = self.request.url.join(
self.request.app.router["reset_password"].url_for(token=token)
@@ -217,7 +217,7 @@ class ResetPasswordView(CustomPydanticView):
)
user_service: UserService = self.request.app["user_service"]
- user = user_service.get_user_by_reset_token(token) # Corrected method call
+ user = await user_service.get_user_by_reset_token(token)
if not user:
return aiohttp_jinja2.render_template(
self.template_name,
@@ -225,7 +225,7 @@ class ResetPasswordView(CustomPydanticView):
{"error": "Invalid or expired password reset link.", "request": self.request, "errors": {}, "token": token},
)
- if user_service.reset_password(user["email"], token, reset_password_data.password):
+ if await user_service.reset_password(user["email"], token, reset_password_data.password):
# Send password changed confirmation email
email_context = {
"user_name": user["full_name"],
diff --git a/retoors/views/migrate.py b/retoors/views/migrate.py
new file mode 100644
index 0000000..c157dfe
--- /dev/null
+++ b/retoors/views/migrate.py
@@ -0,0 +1,11 @@
+from aiohttp import web
+
+from ..helpers.auth import login_required
+
+class MigrateView(web.View):
+ @login_required
+ async def post(self):
+ user_email = self.request["user"]["email"]
+ file_service = self.request.app["file_service"]
+ await file_service.migrate_old_files(user_email)
+ return web.json_response({"status": "success", "message": "Migration completed"})
\ No newline at end of file
diff --git a/retoors/views/site.py b/retoors/views/site.py
index c3935b7..13a9784 100644
--- a/retoors/views/site.py
+++ b/retoors/views/site.py
@@ -298,7 +298,7 @@ class FileBrowserView(web.View):
return json_response({"error": "Failed to generate share links for any selected items"}, status=500)
logger.warning(f"FileBrowserView: Unknown file action for POST request: {route_name}")
- return web.Response(status=400, text="Unknown file action")
+ raise web.HTTPBadRequest(text="Unknown file action")
@login_required
async def get_download_file(self):
@@ -325,7 +325,7 @@ class FileBrowserView(web.View):
raise web.HTTPNotFound(text="File not found")
async def shared_file_handler(self):
- share_id = self.request.match_info.get("share_id")
+ share_id = self.match_info.get("share_id")
file_service = self.request.app["file_service"]
logger.debug(f"FileBrowserView: Handling shared file request for share_id: {share_id}")
@@ -333,16 +333,11 @@ class FileBrowserView(web.View):
if not shared_item:
logger.warning(f"FileBrowserView: Shared item not found or expired for share_id: {share_id}")
- return aiohttp_jinja2.render_template(
- "pages/errors/404.html",
- self.request,
- {"request": self.request, "message": "Shared link is invalid or has expired."},
- status=404
- )
+ raise web.HTTPNotFound(text="Shared file not found or inaccessible")
user_email = shared_item["user_email"]
item_path = shared_item["item_path"]
- full_path = file_service._get_user_file_path(user_email, item_path)
+ full_path = file_service.get_user_file_system_path(user_email, item_path)
if full_path.is_file():
result = await file_service.get_shared_file_content(share_id)
@@ -355,7 +350,7 @@ class FileBrowserView(web.View):
return response
else:
logger.error(f"FileBrowserView: Failed to get content for shared file: {item_path} (share_id: {share_id})")
- raise web.HTTPNotFound(text="Shared file not found or inaccessible")
+ raise web.HTTPNotFound(text="Shared file not found or inaccessible within the shared folder.")
elif full_path.is_dir():
files = await file_service.get_shared_folder_content(share_id)
logger.info(f"FileBrowserView: Serving shared folder '{item_path}' for share_id: {share_id}")
@@ -388,16 +383,11 @@ class FileBrowserView(web.View):
shared_item = await file_service.get_shared_item(share_id)
if not shared_item:
logger.warning(f"FileBrowserView: Shared item not found or expired for share_id: {share_id} during download.")
- return aiohttp_jinja2.render_template(
- "pages/errors/404.html",
- self.request,
- {"request": self.request, "message": "Shared link is invalid or has expired."},
- status=404
- )
+ raise web.HTTPNotFound(text="Shared file not found or inaccessible within the shared folder.")
# Ensure the shared item is a directory if a file_path is provided
user_email = shared_item["user_email"]
- original_shared_item_path = file_service._get_user_file_path(user_email, shared_item["item_path"])
+ original_shared_item_path = file_service.get_user_file_system_path(user_email, shared_item["item_path"])
if not original_shared_item_path.is_dir():
logger.warning(f"FileBrowserView: Attempt to download a specific file from a shared item that is not a directory. Share_id: {share_id}")
@@ -449,6 +439,25 @@ class OrderView(CustomPydanticView):
class UserManagementView(web.View):
+
+ async def verify_user_access(self, target_user_email: str) -> bool:
+ current_user_email = self.request["user"]["email"]
+ if target_user_email == current_user_email:
+ return True
+
+ user_service = self.request.app["user_service"]
+ current_user = await user_service.get_user_by_email(current_user_email)
+ target_user = await user_service.get_user_by_email(target_user_email)
+
+ if not target_user:
+ return False
+
+ if current_user.get("is_customer", True):
+ if target_user.get("parent_email") != current_user_email:
+ return False
+
+ return True
+
@login_required
async def get(self):
route_name = self.request.match_info.route.name
@@ -531,14 +540,14 @@ class UserManagementView(web.View):
parent_email = self.request["user"]["email"]
try:
- new_user = user_service.create_user(
+ new_user = await user_service.create_user(
full_name=full_name,
email=email,
password=password,
parent_email=parent_email
)
- user_service.update_user_quota(email, float(storage_quota_gb))
+ await user_service.update_user_quota(email, float(storage_quota_gb))
raise web.HTTPFound(
self.request.app.router["users"].url_for().with_query(
@@ -562,8 +571,11 @@ class UserManagementView(web.View):
async def edit_user_page(self):
email = self.request.match_info.get("email")
+ if not await self.verify_user_access(email):
+ raise web.HTTPForbidden(text="You do not have permission to access this user")
+
user_service = self.request.app["user_service"]
- user_data = user_service.get_user_by_email(email)
+ user_data = await user_service.get_user_by_email(email)
if not user_data:
raise web.HTTPNotFound(text="User not found")
@@ -585,6 +597,10 @@ class UserManagementView(web.View):
async def edit_user_submit(self):
email = self.request.match_info.get("email")
+
+ if not await self.verify_user_access(email):
+ raise web.HTTPForbidden(text="You do not have permission to access this user")
+
data = await self.request.post()
storage_quota_gb = data.get("storage_quota_gb", "")
@@ -598,7 +614,7 @@ class UserManagementView(web.View):
errors["storage_quota_gb"] = "Invalid storage quota value"
user_service = self.request.app["user_service"]
- user_data = user_service.get_user_by_email(email)
+ user_data = await user_service.get_user_by_email(email)
if not user_data:
raise web.HTTPNotFound(text="User not found")
@@ -616,7 +632,7 @@ class UserManagementView(web.View):
}
)
- user_service.update_user_quota(email, storage_quota_gb)
+ await user_service.update_user_quota(email, storage_quota_gb)
raise web.HTTPFound(
self.request.app.router["edit_user"].url_for(email=email).with_query(
@@ -627,8 +643,11 @@ class UserManagementView(web.View):
async def user_details_page(self):
email = self.request.match_info.get("email")
+ if not await self.verify_user_access(email):
+ raise web.HTTPForbidden(text="You do not have permission to access this user")
+
user_service = self.request.app["user_service"]
- user_data = user_service.get_user_by_email(email)
+ user_data = await user_service.get_user_by_email(email)
if not user_data:
raise web.HTTPNotFound(text="User not found")
@@ -647,13 +666,16 @@ class UserManagementView(web.View):
async def delete_user_submit(self):
email = self.request.match_info.get("email")
+ if not await self.verify_user_access(email):
+ raise web.HTTPForbidden(text="You do not have permission to access this user")
+
user_service = self.request.app["user_service"]
- user_data = user_service.get_user_by_email(email)
+ user_data = await user_service.get_user_by_email(email)
if not user_data:
raise web.HTTPNotFound(text="User not found")
- user_service.delete_user(email)
+ await user_service.delete_user(email)
raise web.HTTPFound(
self.request.app.router["users"].url_for().with_query(
diff --git a/tests/conftest.py b/tests/conftest.py
index 9291251..3b07013 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -57,16 +57,14 @@ def temp_users_json(tmp_path):
@pytest.fixture
def file_service_instance(temp_user_files_dir, temp_users_json):
"""Fixture to provide a FileService instance with temporary directories."""
- return FileService(temp_user_files_dir, temp_users_json)
+ user_service = UserService(temp_users_json) # Create a UserService instance
+ return FileService(temp_user_files_dir, user_service) # Pass the UserService instance
+
+
@pytest.fixture
-def create_app_instance():
- """Fixture to create a new aiohttp application instance."""
- return create_app()
-
-@pytest.fixture
-def create_test_app(mocker, temp_user_files_dir, temp_users_json, file_service_instance):
+def create_test_app(mocker, temp_user_files_dir, temp_users_json):
"""Fixture to create a test aiohttp application with mocked services."""
from aiohttp import web
@@ -81,16 +79,15 @@ def create_test_app(mocker, temp_user_files_dir, temp_users_json, file_service_i
app.middlewares.append(error_middleware)
app.middlewares.append(user_middleware)
- # Mock UserService
- mock_user_service = mocker.MagicMock(spec=UserService)
-
+ # Use a real UserService with a temporary users.json for the app
+ app["user_service"] = UserService(temp_users_json)
+
# Mock scheduler
mock_scheduler = mocker.MagicMock()
mock_scheduler.spawn = mocker.AsyncMock()
mock_scheduler.close = mocker.AsyncMock()
- app["user_service"] = mock_user_service
- app["file_service"] = file_service_instance
+ app["file_service"] = FileService(temp_user_files_dir, app["user_service"])
app["scheduler"] = mock_scheduler
# Setup Jinja2 for templates
@@ -102,42 +99,14 @@ def create_test_app(mocker, temp_user_files_dir, temp_users_json, file_service_i
return app
-@pytest.fixture(scope="function")
-def mock_users_db_fixture():
- """
- Fixture to simulate a user database for dynamic mocking,
- reset for each test function.
- """
- return {
- "admin@example.com": {
- "full_name": "Admin User",
- "email": "admin@example.com",
- "password": "password", # Store plain password for mock authentication
- "hashed_password": "hashed_password", # For consistency with real service
- "storage_quota_gb": 100,
- "storage_used_gb": 10,
- "parent_email": None,
- "reset_token": None,
- "reset_token_expiry": None,
- },
- "child1@example.com": {
- "full_name": "Child User 1",
- "email": "child1@example.com",
- "password": "password",
- "hashed_password": "hashed_password",
- "storage_quota_gb": 50,
- "storage_used_gb": 5,
- "parent_email": "admin@example.com",
- "shared_items": {}
- }
- }
+
@pytest.fixture
async def client(
- aiohttp_client, mocker: MockerFixture, create_app_instance, mock_users_db_fixture
+ aiohttp_client, mocker: MockerFixture, create_test_app
):
- app = create_app_instance # Use the new fixture
+ app = create_test_app # Use create_test_app for consistent test environment
# Directly set app["scheduler"] to a mock object
mock_scheduler_instance = mocker.MagicMock()
@@ -145,189 +114,20 @@ async def client(
mock_scheduler_instance.close = mocker.AsyncMock() # Ensure close is awaitable
app["scheduler"] = mock_scheduler_instance
- # Create temporary data files for testing
- base_path = Path(__file__).parent.parent
- data_path = base_path / "data"
- data_path.mkdir(exist_ok=True)
-
- users_file = data_path / "users.json"
- with open(users_file, "w") as f:
- json.dump([], f)
-
- config_file = data_path / "config.json"
- with open(config_file, "w") as f:
- json.dump({"price_per_gb": 0.0}, f)
-
- app["config_service"] = ConfigService(data_path / "config.json")
+ # The UserService and ConfigService are now set up in create_test_app with temporary files.
+ # No need to manually create users.json or config.json here.
client = await aiohttp_client(app)
- # Access the real UserService instance and mock its methods
- mock_user_service_instance = client.app["user_service"]
-
- # Use the mock_users_db_fixture
- mock_users_db = mock_users_db_fixture
-
- def mock_authenticate_user(email, password):
- user = mock_users_db.get(email)
- if user and user["password"] == password:
- return user
- return None
-
- def mock_get_user_by_email(email):
- return mock_users_db.get(email)
-
- def mock_create_user(full_name, email, password, parent_email=None):
- if email in mock_users_db:
- raise ValueError("User with this email already exists")
- new_user = {
- "full_name": full_name,
- "email": email,
- "password": password,
- "hashed_password": "hashed_password",
- "storage_quota_gb": 5,
- "storage_used_gb": 0,
- "parent_email": parent_email,
- "reset_token": None,
- "reset_token_expiry": None,
- }
- mock_users_db[email] = new_user
- return new_user
-
- def mock_reset_password(email, token, new_password): # Added token argument
- user = mock_users_db.get(email)
- if user and user.get("reset_token") == token and user.get("reset_token_expiry"):
- expiry_time = datetime.datetime.fromisoformat(user["reset_token_expiry"])
- if expiry_time > datetime.datetime.now(datetime.timezone.utc):
- user["password"] = new_password
- user["hashed_password"] = "new_hashed_password" # Simulate hashing
- user["reset_token"] = None
- user["reset_token_expiry"] = None
- return True
- return False
-
- def mock_generate_reset_token(email):
- user = mock_users_db.get(email)
- if user:
- # In a real scenario, this would generate a unique token and expiry
- user["reset_token"] = "test_token"
- user["reset_token_expiry"] = "2030-11-08T20:00:00Z" # A future date
- return "test_token"
- return None
-
- def mock_validate_reset_token(email, token):
- if (
- token == "expiredtoken123"
- ): # Explicitly handle the expired token from the test
- return False
- user = mock_users_db.get(email)
- if user and user.get("reset_token") == token and user.get("reset_token_expiry"):
- expiry_time = datetime.datetime.fromisoformat(
- user["reset_token_expiry"]
- )
- if expiry_time > datetime.datetime.now(datetime.timezone.utc):
- return True
- return False
-
- def mock_save_users():
- # This mock ensures that changes to user objects within tests are reflected in mock_users_db
- # In a real scenario, this would write to a file or database.
- pass # The mock_users_db is already being modified directly by other mocks
-
- def mock_get_all_users():
- return list(mock_users_db.values())
-
- def mock_get_users_by_parent_email(parent_email):
- return [
- user
- for user in mock_users_db.values()
- if user.get("parent_email") == parent_email
- ]
-
- def mock_delete_user(email):
- if email in mock_users_db:
- del mock_users_db[email]
- return True
- return False
-
- def mock_delete_users_by_parent_email(parent_email):
- initial_count = len(mock_users_db)
- users_to_delete = [
- email
- for email, user in mock_users_db.items()
- if user.get("parent_email") == parent_email
- ]
- for email in users_to_delete:
- del mock_users_db[email]
- return initial_count - len(mock_users_db)
-
- mocker.patch.object(
- mock_user_service_instance,
- "authenticate_user",
- side_effect=mock_authenticate_user,
- )
- mocker.patch.object(
- mock_user_service_instance,
- "get_user_by_email",
- side_effect=mock_get_user_by_email,
- )
- mocker.patch.object(
- mock_user_service_instance, "create_user", side_effect=mock_create_user
- )
- mocker.patch.object(
- mock_user_service_instance, "get_all_users", side_effect=mock_get_all_users
- )
- mocker.patch.object(
- mock_user_service_instance, "update_user_quota", return_value=None
- ) # Keep as is for now
- mocker.patch.object(
- mock_user_service_instance, "delete_user", side_effect=mock_delete_user
- )
- mocker.patch.object(
- mock_user_service_instance,
- "get_users_by_parent_email",
- side_effect=mock_get_users_by_parent_email,
- )
- mocker.patch.object(
- mock_user_service_instance,
- "delete_users_by_parent_email",
- side_effect=mock_delete_users_by_parent_email,
- )
- mocker.patch.object(
- mock_user_service_instance,
- "generate_reset_token",
- side_effect=mock_generate_reset_token,
- )
- mocker.patch.object(
- mock_user_service_instance,
- "get_user_by_reset_token",
- side_effect=lambda token: next(
- (
- user
- for user in mock_users_db.values()
- if user.get("reset_token") == token
- ),
- None,
- ),
- )
- mocker.patch.object(
- mock_user_service_instance, "reset_password", side_effect=mock_reset_password
- )
- mocker.patch.object(
- mock_user_service_instance,
- "validate_reset_token",
- side_effect=mock_validate_reset_token,
- )
- mocker.patch.object(
- mock_user_service_instance, "_save_users", side_effect=mock_save_users
- )
+ # The UserService is now a real instance, so we don't need to mock its methods here.
+ # The mock_users_db_fixture is also no longer needed in this context.
try:
yield client
finally:
- # Clean up temporary files
- users_file.unlink(missing_ok=True)
- config_file.unlink(missing_ok=True) # Use missing_ok for robustness
+ # Clean up temporary files created by create_test_app if necessary,
+ # but tmp_path usually handles this.
+ pass
@pytest.fixture
@@ -338,38 +138,9 @@ async def logged_in_client(aiohttp_client, create_test_app, mocker):
user_service = app["user_service"]
- def mock_create_user(full_name, email, password, parent_email=None):
- return {
- "full_name": full_name,
- "email": email,
- "password": "hashed_password",
- "storage_quota_gb": 10,
- "storage_used_gb": 0,
- "parent_email": parent_email,
- "shared_items": {}
- }
-
- def mock_authenticate_user(email, password):
- return {
- "email": email,
- "full_name": "Test User",
- "is_admin": False,
- "storage_quota_gb": 10,
- "storage_used_gb": 0
- }
-
- def mock_get_user_by_email(email):
- return {
- "email": email,
- "full_name": "Test User",
- "is_admin": False,
- "storage_quota_gb": 10,
- "storage_used_gb": 0
- }
-
- mocker.patch.object(user_service, "create_user", side_effect=mock_create_user)
- mocker.patch.object(user_service, "authenticate_user", side_effect=mock_authenticate_user)
- mocker.patch.object(user_service, "get_user_by_email", side_effect=mock_get_user_by_email)
+ # The UserService is now a real instance, so we don't need to mock its methods here.
+ # The create_user, authenticate_user, and get_user_by_email methods will interact
+ # with the real UserService instance.
await client.post(
"/register",
@@ -394,38 +165,9 @@ async def logged_in_admin_client(aiohttp_client, create_test_app, mocker):
user_service = app["user_service"]
- def mock_create_user(full_name, email, password, parent_email=None):
- return {
- "full_name": full_name,
- "email": email,
- "password": "hashed_password",
- "storage_quota_gb": 100,
- "storage_used_gb": 0,
- "parent_email": parent_email,
- "shared_items": {}
- }
-
- def mock_authenticate_user(email, password):
- return {
- "email": email,
- "full_name": "Admin User",
- "is_admin": True,
- "storage_quota_gb": 100,
- "storage_used_gb": 0
- }
-
- def mock_get_user_by_email(email):
- return {
- "email": email,
- "full_name": "Admin User",
- "is_admin": True,
- "storage_quota_gb": 100,
- "storage_used_gb": 0
- }
-
- mocker.patch.object(user_service, "create_user", side_effect=mock_create_user)
- mocker.patch.object(user_service, "authenticate_user", side_effect=mock_authenticate_user)
- mocker.patch.object(user_service, "get_user_by_email", side_effect=mock_get_user_by_email)
+ # The UserService is now a real instance, so we don't need to mock its methods here.
+ # The create_user, authenticate_user, and get_user_by_email methods will interact
+ # with the real UserService instance.
await client.post(
"/register",
diff --git a/tests/test_auth.py b/tests/test_auth.py
index daa3a68..8195110 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -204,7 +204,7 @@ async def test_reset_password_get_valid_token(client):
"confirm_password": "old_password",
},
)
- token = user_service.generate_reset_token("test@example.com")
+ token = await user_service.generate_reset_token("test@example.com")
assert token is not None
resp = await client.get(f"/reset_password/{token}")
@@ -233,7 +233,7 @@ async def test_reset_password_post_success(client, mock_send_email):
"confirm_password": "old_password",
},
)
- token = user_service.generate_reset_token("test@example.com")
+ token = await user_service.generate_reset_token("test@example.com")
assert token is not None
resp = await client.post(
@@ -248,8 +248,8 @@ async def test_reset_password_post_success(client, mock_send_email):
assert resp.headers["Location"] == "/login?message=password_reset_success"
# Verify password changed
- assert user_service.authenticate_user("test@example.com", "new_password")
- assert not user_service.authenticate_user("test@example.com", "old_password")
+ assert await user_service.authenticate_user("test@example.com", "new_password")
+ assert not await user_service.authenticate_user("test@example.com", "old_password")
# Assert that confirmation email was sent
@@ -272,7 +272,7 @@ async def test_reset_password_post_password_mismatch(client):
"confirm_password": "old_password",
},
)
- token = user_service.generate_reset_token("test@example.com")
+ token = await user_service.generate_reset_token("test@example.com")
assert token is not None
resp = await client.post(
@@ -286,7 +286,7 @@ async def test_reset_password_post_password_mismatch(client):
text = await resp.text()
assert "Passwords do not match" in text
# Password should not have changed
- assert user_service.authenticate_user("test@example.com", "old_password")
+ assert await user_service.authenticate_user("test@example.com", "old_password")
async def test_reset_password_post_invalid_token(client):
@@ -301,7 +301,7 @@ async def test_reset_password_post_invalid_token(client):
},
)
# Generate a token but don't use it, or use an expired one
- user_service.generate_reset_token("test@example.com") # This will be overwritten or ignored
+ await user_service.generate_reset_token("test@example.com") # This will be overwritten or ignored
resp = await client.post(
"/reset_password/invalidtoken",
@@ -314,7 +314,7 @@ async def test_reset_password_post_invalid_token(client):
text = await resp.text()
assert "Invalid or expired password reset link." in text
# Password should not have changed
- assert user_service.authenticate_user("test@example.com", "old_password")
+ assert await user_service.authenticate_user("test@example.com", "old_password")
async def test_reset_password_post_expired_token(client, mock_users_db_fixture):
@@ -329,7 +329,7 @@ async def test_reset_password_post_expired_token(client, mock_users_db_fixture):
},
)
# Manually set an expired token
- user = user_service.get_user_by_email("test@example.com")
+ user = await user_service.get_user_by_email("test@example.com")
token = "expiredtoken123"
user["reset_token"] = token
user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat()
@@ -346,7 +346,7 @@ async def test_reset_password_post_expired_token(client, mock_users_db_fixture):
text = await resp.text()
assert "Invalid or expired password reset link." in text
# Password should not have changed
- assert user_service.authenticate_user("test@example.com", "old_password")
+ assert await user_service.authenticate_user("test@example.com", "old_password")
async def test_reset_password_post_invalid_password_format(client):
@@ -360,7 +360,7 @@ async def test_reset_password_post_invalid_password_format(client):
"confirm_password": "old_password",
},
)
- token = user_service.generate_reset_token("test@example.com")
+ token = await user_service.generate_reset_token("test@example.com")
assert token is not None
resp = await client.post(
@@ -374,4 +374,4 @@ async def test_reset_password_post_invalid_password_format(client):
text = await resp.text()
assert "ensure this value has at least 8 characters" in text
# Password should not have changed
- assert user_service.authenticate_user("test@example.com", "old_password")
+ assert await user_service.authenticate_user("test@example.com", "old_password")
diff --git a/tests/test_file_browser.py b/tests/test_file_browser.py
index c472f26..3730968 100644
--- a/tests/test_file_browser.py
+++ b/tests/test_file_browser.py
@@ -27,33 +27,38 @@ async def test_file_service_list_files_empty(file_service_instance):
assert files == []
@pytest.mark.asyncio
-async def test_file_service_create_folder(file_service_instance, temp_user_files_dir):
+async def test_file_service_create_folder(file_service_instance):
user_email = "test@example.com"
folder_name = "my_new_folder"
success = await file_service_instance.create_folder(user_email, folder_name)
assert success
- expected_path = temp_user_files_dir / user_email / folder_name
- assert expected_path.is_dir()
+ metadata = await file_service_instance._load_metadata(user_email)
+ assert folder_name in metadata
+ assert metadata[folder_name]["type"] == "dir"
@pytest.mark.asyncio
-async def test_file_service_create_folder_exists(file_service_instance, temp_user_files_dir):
+async def test_file_service_create_folder_exists(file_service_instance):
user_email = "test@example.com"
folder_name = "existing_folder"
- (temp_user_files_dir / user_email).mkdir(parents=True)
- (temp_user_files_dir / user_email / folder_name).mkdir(parents=True)
+ await file_service_instance.create_folder(user_email, folder_name) # Create it first via service
success = await file_service_instance.create_folder(user_email, folder_name)
assert not success # Should return False if folder already exists
@pytest.mark.asyncio
-async def test_file_service_upload_file(file_service_instance, temp_user_files_dir):
+async def test_file_service_upload_file(file_service_instance):
user_email = "test@example.com"
file_name = "document.txt"
file_content = b"Hello, world!"
success = await file_service_instance.upload_file(user_email, file_name, file_content)
assert success
- expected_path = temp_user_files_dir / user_email / file_name
- assert expected_path.is_file()
- assert expected_path.read_bytes() == file_content
+ metadata = await file_service_instance._load_metadata(user_email)
+ assert file_name in metadata
+ assert metadata[file_name]["type"] == "file"
+ assert metadata[file_name]["size"] == len(file_content)
+ # Verify content by downloading
+ downloaded_content, downloaded_name = await file_service_instance.download_file(user_email, file_name)
+ assert downloaded_content == file_content
+ assert downloaded_name == file_name
@pytest.mark.asyncio
async def test_file_service_list_files_with_content(file_service_instance, temp_user_files_dir):
@@ -89,27 +94,38 @@ async def test_file_service_download_file_not_found(file_service_instance):
assert content is None
@pytest.mark.asyncio
-async def test_file_service_delete_file(file_service_instance, temp_user_files_dir):
+async def test_file_service_delete_file(file_service_instance):
user_email = "test@example.com"
file_name = "to_delete.txt"
- (temp_user_files_dir / user_email).mkdir(exist_ok=True)
- (temp_user_files_dir / user_email / file_name).write_bytes(b"delete me")
+ await file_service_instance.upload_file(user_email, file_name, b"delete me")
+
+ metadata_before = await file_service_instance._load_metadata(user_email)
+ assert file_name in metadata_before
success = await file_service_instance.delete_item(user_email, file_name)
assert success
- assert not (temp_user_files_dir / user_email / file_name).exists()
+
+ metadata_after = await file_service_instance._load_metadata(user_email)
+ assert file_name not in metadata_after
@pytest.mark.asyncio
-async def test_file_service_delete_folder(file_service_instance, temp_user_files_dir):
+async def test_file_service_delete_folder(file_service_instance):
user_email = "test@example.com"
folder_name = "folder_to_delete"
- (temp_user_files_dir / user_email).mkdir(parents=True)
- (temp_user_files_dir / user_email / folder_name).mkdir(parents=True)
- (temp_user_files_dir / user_email / folder_name / "nested.txt").write_bytes(b"nested")
+ nested_file = f"{folder_name}/nested.txt"
+ await file_service_instance.create_folder(user_email, folder_name)
+ await file_service_instance.upload_file(user_email, nested_file, b"nested content")
+
+ metadata_before = await file_service_instance._load_metadata(user_email)
+ assert folder_name in metadata_before
+ assert nested_file in metadata_before
success = await file_service_instance.delete_item(user_email, folder_name)
assert success
- assert not (temp_user_files_dir / user_email / folder_name).exists()
+
+ metadata_after = await file_service_instance._load_metadata(user_email)
+ assert folder_name not in metadata_after
+ assert nested_file not in metadata_after
@pytest.mark.asyncio
async def test_file_service_delete_nonexistent(file_service_instance):
@@ -199,14 +215,15 @@ async def test_file_browser_get_authorized_with_files(logged_in_client: TestClie
assert "my_file.txt" in text
@pytest.mark.asyncio
-async def test_file_browser_new_folder(logged_in_client: TestClient, file_service_instance, temp_user_files_dir):
+async def test_file_browser_new_folder(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com"
resp = await logged_in_client.post("/files/new_folder", data={"folder_name": "new_folder_via_web"}, allow_redirects=False)
assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files")
- expected_path = temp_user_files_dir / user_email / "new_folder_via_web"
- assert expected_path.is_dir()
+ metadata = await file_service_instance._load_metadata(user_email)
+ assert "new_folder_via_web" in metadata
+ assert metadata["new_folder_via_web"]["type"] == "dir"
@pytest.mark.asyncio
async def test_file_browser_new_folder_missing_name(logged_in_client: TestClient):
@@ -228,23 +245,29 @@ async def test_file_browser_new_folder_exists(logged_in_client: TestClient, file
assert f"error=Folder+'{folder_name}'+already+exists+or+could+not+be+created" in resp.headers["Location"]
@pytest.mark.asyncio
-async def test_file_browser_upload_file(logged_in_client: TestClient, file_service_instance, temp_user_files_dir):
+async def test_file_browser_upload_file(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com"
+ file_name = "uploaded.txt"
file_content = b"Uploaded content from web."
from io import BytesIO
data = aiohttp.FormData()
data.add_field('file',
BytesIO(file_content),
- filename='uploaded.txt',
+ filename=file_name,
content_type='text/plain')
resp = await logged_in_client.post("/files/upload", data=data, allow_redirects=False)
assert resp.status == 200
- expected_path = temp_user_files_dir / user_email / "uploaded.txt"
- assert expected_path.is_file()
- assert expected_path.read_bytes() == file_content
+ metadata = await file_service_instance._load_metadata(user_email)
+ assert file_name in metadata
+ assert metadata[file_name]["type"] == "file"
+ assert metadata[file_name]["size"] == len(file_content)
+ # Verify content by downloading
+ downloaded_content, downloaded_name = await file_service_instance.download_file(user_email, file_name)
+ assert downloaded_content == file_content
+ assert downloaded_name == file_name
@pytest.mark.asyncio
async def test_file_browser_download_file(logged_in_client: TestClient, file_service_instance):
@@ -265,41 +288,52 @@ async def test_file_browser_download_file_not_found(logged_in_client: TestClient
assert "File not found" in await response.text()
@pytest.mark.asyncio
-async def test_file_browser_delete_file(logged_in_client: TestClient, file_service_instance, temp_user_files_dir):
+async def test_file_browser_delete_file(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com"
file_name = "web_delete.txt"
await file_service_instance.upload_file(user_email, file_name, b"delete this")
- expected_path = temp_user_files_dir / user_email / file_name
- assert expected_path.is_file()
+ metadata_before = await file_service_instance._load_metadata(user_email)
+ assert file_name in metadata_before
resp = await logged_in_client.post(f"/files/delete/{file_name}", allow_redirects=False)
assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files")
- assert not expected_path.is_file()
+
+ metadata_after = await file_service_instance._load_metadata(user_email)
+ assert file_name not in metadata_after
@pytest.mark.asyncio
-async def test_file_browser_delete_folder(logged_in_client: TestClient, file_service_instance, temp_user_files_dir):
+async def test_file_browser_delete_folder(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com"
folder_name = "web_delete_folder"
+ nested_file = f"{folder_name}/nested.txt"
await file_service_instance.create_folder(user_email, folder_name)
- await file_service_instance.upload_file(user_email, f"{folder_name}/nested.txt", b"nested")
+ await file_service_instance.upload_file(user_email, nested_file, b"nested")
- expected_path = temp_user_files_dir / user_email / folder_name
- assert expected_path.is_dir()
+ metadata_before = await file_service_instance._load_metadata(user_email)
+ assert folder_name in metadata_before
+ assert nested_file in metadata_before
resp = await logged_in_client.post(f"/files/delete/{folder_name}", allow_redirects=False)
assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files")
- assert not expected_path.is_dir()
+
+ metadata_after = await file_service_instance._load_metadata(user_email)
+ assert folder_name not in metadata_after
+ assert nested_file not in metadata_after
@pytest.mark.asyncio
-async def test_file_browser_delete_multiple_files(logged_in_client: TestClient, file_service_instance, temp_user_files_dir):
+async def test_file_browser_delete_multiple_files(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com"
file_names = ["multi_delete_1.txt", "multi_delete_2.txt", "multi_delete_3.txt"]
for name in file_names:
await file_service_instance.upload_file(user_email, name, b"content")
+ metadata_before = await file_service_instance._load_metadata(user_email)
+ for name in file_names:
+ assert name in metadata_before
+
paths_to_delete = [f"{name}" for name in file_names]
# Construct FormData for multiple paths
@@ -311,18 +345,23 @@ async def test_file_browser_delete_multiple_files(logged_in_client: TestClient,
assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files")
+ metadata_after = await file_service_instance._load_metadata(user_email)
for name in file_names:
- expected_path = temp_user_files_dir / user_email / name
- assert not expected_path.is_file()
+ assert name not in metadata_after
@pytest.mark.asyncio
-async def test_file_browser_delete_multiple_folders(logged_in_client: TestClient, file_service_instance, temp_user_files_dir):
+async def test_file_browser_delete_multiple_folders(logged_in_client: TestClient, file_service_instance):
user_email = "test@example.com"
folder_names = ["multi_delete_folder_1", "multi_delete_folder_2"]
for name in folder_names:
await file_service_instance.create_folder(user_email, name)
await file_service_instance.upload_file(user_email, f"{name}/nested.txt", b"nested content")
+ metadata_before = await file_service_instance._load_metadata(user_email)
+ for name in folder_names:
+ assert name in metadata_before
+ assert f"{name}/nested.txt" in metadata_before
+
paths_to_delete = [f"{name}" for name in folder_names]
# Construct FormData for multiple paths
@@ -334,9 +373,10 @@ async def test_file_browser_delete_multiple_folders(logged_in_client: TestClient
assert resp.status == 302 # Redirect
assert resp.headers["Location"].startswith("/files")
+ metadata_after = await file_service_instance._load_metadata(user_email)
for name in folder_names:
- expected_path = temp_user_files_dir / user_email / name
- assert not expected_path.is_dir()
+ assert name not in metadata_after
+ assert f"{name}/nested.txt" not in metadata_after
@pytest.mark.asyncio
async def test_file_browser_delete_multiple_items_no_paths(logged_in_client: TestClient):
@@ -345,7 +385,7 @@ async def test_file_browser_delete_multiple_items_no_paths(logged_in_client: Tes
assert "error=No+items+selected+for+deletion" in resp.headers["Location"]
@pytest.mark.asyncio
-async def test_file_browser_delete_multiple_items_some_fail(logged_in_client: TestClient, file_service_instance, temp_user_files_dir, mocker):
+async def test_file_browser_delete_multiple_items_some_fail(logged_in_client: TestClient, file_service_instance, mocker):
user_email = "test@example.com"
file_names = ["fail_delete_1.txt", "fail_delete_2.txt"]
for name in file_names:
@@ -370,10 +410,11 @@ async def test_file_browser_delete_multiple_items_some_fail(logged_in_client: Te
assert resp.status == 302 # Redirect
assert "error=Some+items+failed+to+delete" in resp.headers["Location"]
+ metadata_after = await file_service_instance._load_metadata(user_email)
# Check if the first file still exists (failed to delete)
- assert (temp_user_files_dir / user_email / file_names[0]).is_file()
+ assert file_names[0] in metadata_after
# Check if the second file is deleted (succeeded)
- assert not (temp_user_files_dir / user_email / file_names[1]).is_file()
+ assert file_names[1] not in metadata_after
@pytest.mark.asyncio
async def test_file_browser_share_multiple_items_no_paths(logged_in_client: TestClient):
@@ -434,7 +475,7 @@ async def test_file_browser_download_shared_file_handler_fail_get_content(client
"created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat()
})
- mocker.patch.object(file_service_instance, "_get_user_file_path", return_value=mocker.MagicMock(is_dir=lambda: True))
+ mocker.patch.object(file_service_instance, "get_user_file_system_path", return_value=mocker.MagicMock(is_dir=lambda: True))
mocker.patch.object(file_service_instance, "get_shared_file_content", return_value=None)
resp = await client.get(f"/shared_file/{share_id}/download?file_path={file_name}")
@@ -447,14 +488,14 @@ async def test_file_browser_download_shared_file_handler_not_a_directory(client:
file_name = "shared_file.txt"
share_id = "test_share_id"
- mocker.patch.object(file_service_instance, "get_shared_item", return_value={
+ mocker.patch.object(client.app["file_service"], "get_shared_item", return_value={
"user_email": user_email,
"item_path": file_name,
"share_id": share_id,
"created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat()
})
- mocker.patch.object(file_service_instance, "_get_user_file_path", return_value=mocker.MagicMock(is_dir=lambda: False))
+ mocker.patch.object(client.app["file_service"], "get_user_file_system_path", return_value=mocker.MagicMock(is_dir=lambda: False))
resp = await client.get(f"/shared_file/{share_id}/download?file_path=some_file.txt")
assert resp.status == 400
@@ -462,11 +503,11 @@ async def test_file_browser_download_shared_file_handler_not_a_directory(client:
assert "Cannot download specific files from a shared item that is not a folder." in text
@pytest.mark.asyncio
async def test_file_browser_download_shared_file_handler_shared_item_not_found(client: TestClient, file_service_instance, mocker):
- mocker.patch.object(file_service_instance, "get_shared_item", return_value=None)
+ mocker.patch.object(client.app["file_service"], "get_shared_item", return_value=None)
resp = await client.get("/shared_file/nonexistent_share_id/download?file_path=some_file.txt")
assert resp.status == 404
text = await resp.text()
- assert "Shared link is invalid or has expired." in text
+ assert "Shared file not found or inaccessible" in text
@pytest.mark.asyncio
async def test_file_browser_download_shared_file_handler_missing_file_path(client: TestClient):
@@ -480,7 +521,7 @@ async def test_file_browser_shared_file_handler_fail_get_content(client: TestCli
file_name = "shared_file.txt"
share_id = "test_share_id"
- mocker.patch.object(file_service_instance, "get_shared_item", return_value={
+ mocker.patch.object(client.app["file_service"], "get_shared_item", return_value={
"user_email": user_email,
"item_path": file_name,
"share_id": share_id,
@@ -488,7 +529,7 @@ async def test_file_browser_shared_file_handler_fail_get_content(client: TestCli
"expires_at": (datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=7)).isoformat()
})
mocker.patch("pathlib.Path.is_file", return_value=True) # Simulate it's a file
- mocker.patch.object(file_service_instance, "get_shared_file_content", return_value=None)
+ mocker.patch.object(client.app["file_service"], "get_shared_file_content", return_value=None)
resp = await client.get(f"/shared_file/{share_id}")
assert resp.status == 404
@@ -501,7 +542,7 @@ async def test_file_browser_shared_file_handler_neither_file_nor_dir(client: Tes
item_path = "mystery_item"
share_id = "test_share_id"
- mocker.patch.object(file_service_instance, "get_shared_item", return_value={
+ mocker.patch.object(client.app["file_service"], "get_shared_item", return_value={
"user_email": user_email,
"item_path": item_path,
"share_id": share_id,
@@ -520,11 +561,11 @@ async def test_file_browser_shared_file_handler_neither_file_nor_dir(client: Tes
@pytest.mark.asyncio
async def test_file_browser_shared_file_handler_not_found(client: TestClient, file_service_instance, mocker):
- mocker.patch.object(file_service_instance, "get_shared_item", return_value=None)
+ mocker.patch.object(client.app["file_service"], "get_shared_item", return_value=None)
resp = await client.get("/shared_file/nonexistent_share_id")
assert resp.status == 404
text = await resp.text()
- assert "Shared link is invalid or has expired." in text
+ assert "Shared file not found or inaccessible" in text
@pytest.mark.asyncio
async def test_file_browser_unknown_post_action(logged_in_client: TestClient, mocker):
diff --git a/tests/test_storage_service.py b/tests/test_storage_service.py
new file mode 100644
index 0000000..a17a68d
--- /dev/null
+++ b/tests/test_storage_service.py
@@ -0,0 +1,183 @@
+import pytest
+import asyncio
+import shutil
+from pathlib import Path
+from retoors.services.storage_service import StorageService, UserStorageManager
+
+
+@pytest.fixture
+def test_storage():
+ storage = StorageService(base_path="data/test_user")
+ yield storage
+ if Path("data/test_user").exists():
+ shutil.rmtree("data/test_user")
+
+
+@pytest.fixture
+def user_manager():
+ manager = UserStorageManager()
+ manager.storage.base_path = Path("data/test_user")
+ yield manager
+ if Path("data/test_user").exists():
+ shutil.rmtree("data/test_user")
+
+
+@pytest.mark.asyncio
+async def test_save_and_load(test_storage):
+ user_email = "test@example.com"
+ identifier = "doc1"
+ data = {"title": "Test Document", "content": "Test content"}
+
+ result = await test_storage.save(user_email, identifier, data)
+ assert result is True
+
+ loaded_data = await test_storage.load(user_email, identifier)
+ assert loaded_data == data
+
+
+@pytest.mark.asyncio
+async def test_distributed_path_structure(test_storage):
+ user_email = "test@example.com"
+ identifier = "doc1"
+ data = {"test": "data"}
+
+ await test_storage.save(user_email, identifier, data)
+
+ user_base = test_storage._get_user_base_path(user_email)
+ file_path = test_storage._get_distributed_path(user_base, identifier)
+
+ assert file_path.exists()
+ assert len(file_path.parent.name) == 3
+ assert len(file_path.parent.parent.name) == 3
+ assert len(file_path.parent.parent.parent.name) == 3
+
+
+@pytest.mark.asyncio
+async def test_user_isolation(test_storage):
+ user1_email = "user1@example.com"
+ user2_email = "user2@example.com"
+ identifier = "doc1"
+ data1 = {"user": "user1"}
+ data2 = {"user": "user2"}
+
+ await test_storage.save(user1_email, identifier, data1)
+ await test_storage.save(user2_email, identifier, data2)
+
+ loaded1 = await test_storage.load(user1_email, identifier)
+ loaded2 = await test_storage.load(user2_email, identifier)
+
+ assert loaded1 == data1
+ assert loaded2 == data2
+
+
+@pytest.mark.asyncio
+async def test_path_traversal_protection(test_storage):
+ user_email = "test@example.com"
+ malicious_identifier = "../../../etc/passwd"
+
+ await test_storage.save(user_email, malicious_identifier, {"test": "data"})
+
+ user_base = test_storage._get_user_base_path(user_email)
+ file_path = test_storage._get_distributed_path(user_base, malicious_identifier)
+
+ assert file_path.exists()
+ assert test_storage._validate_path(file_path, user_base)
+ assert str(file_path.resolve()).startswith(str(user_base.resolve()))
+
+
+@pytest.mark.asyncio
+async def test_delete(test_storage):
+ user_email = "test@example.com"
+ identifier = "doc1"
+ data = {"test": "data"}
+
+ await test_storage.save(user_email, identifier, data)
+ assert await test_storage.exists(user_email, identifier)
+
+ result = await test_storage.delete(user_email, identifier)
+ assert result is True
+ assert not await test_storage.exists(user_email, identifier)
+
+
+@pytest.mark.asyncio
+async def test_list_all(test_storage):
+ user_email = "test@example.com"
+
+ await test_storage.save(user_email, "doc1", {"id": 1})
+ await test_storage.save(user_email, "doc2", {"id": 2})
+ await test_storage.save(user_email, "doc3", {"id": 3})
+
+ all_docs = await test_storage.list_all(user_email)
+ assert len(all_docs) == 3
+ assert any(doc["id"] == 1 for doc in all_docs)
+ assert any(doc["id"] == 2 for doc in all_docs)
+ assert any(doc["id"] == 3 for doc in all_docs)
+
+
+@pytest.mark.asyncio
+async def test_delete_all(test_storage):
+ user_email = "test@example.com"
+
+ await test_storage.save(user_email, "doc1", {"id": 1})
+ await test_storage.save(user_email, "doc2", {"id": 2})
+
+ result = await test_storage.delete_all(user_email)
+ assert result is True
+
+ all_docs = await test_storage.list_all(user_email)
+ assert len(all_docs) == 0
+
+
+@pytest.mark.asyncio
+async def test_user_storage_manager(user_manager):
+ user_data = {
+ "full_name": "Test User",
+ "email": "test@example.com",
+ "password": "hashed_password",
+ "is_customer": True
+ }
+
+ await user_manager.save_user("test@example.com", user_data)
+
+ loaded_user = await user_manager.get_user("test@example.com")
+ assert loaded_user == user_data
+
+ assert await user_manager.user_exists("test@example.com")
+
+ await user_manager.delete_user("test@example.com")
+ assert not await user_manager.user_exists("test@example.com")
+
+
+@pytest.mark.asyncio
+async def test_list_users_by_parent(user_manager):
+ parent_user = {
+ "email": "parent@example.com",
+ "full_name": "Parent User",
+ "password": "hashed",
+ "is_customer": True
+ }
+
+ child_user1 = {
+ "email": "child1@example.com",
+ "full_name": "Child User 1",
+ "password": "hashed",
+ "parent_email": "parent@example.com",
+ "is_customer": True
+ }
+
+ child_user2 = {
+ "email": "child2@example.com",
+ "full_name": "Child User 2",
+ "password": "hashed",
+ "parent_email": "parent@example.com",
+ "is_customer": True
+ }
+
+ await user_manager.save_user("parent@example.com", parent_user)
+ await user_manager.save_user("child1@example.com", child_user1)
+ await user_manager.save_user("child2@example.com", child_user2)
+
+ children = await user_manager.list_users_by_parent("parent@example.com")
+ assert len(children) == 2
+ assert any(u["email"] == "child1@example.com" for u in children)
+ assert any(u["email"] == "child2@example.com" for u in children)
diff --git a/tests/test_user_service.py b/tests/test_user_service.py
index a5e69a8..8ecca28 100644
--- a/tests/test_user_service.py
+++ b/tests/test_user_service.py
@@ -18,28 +18,28 @@ def user_service(users_file):
return UserService(users_file)
@pytest.fixture
-def populated_user_service(user_service):
+async def populated_user_service(user_service):
"""Fixture to provide a UserService instance with some pre-populated users."""
- user_service.create_user("Admin User", "admin@example.com", "adminpass")
- user_service.create_user("Parent User", "parent@example.com", "parentpass")
- user_service.create_user("Child User 1", "child1@example.com", "childpass", "parent@example.com")
- user_service.create_user("Child User 2", "child2@example.com", "childpass", "parent@example.com")
+ await user_service.create_user("Admin User", "admin@example.com", "adminpass")
+ await user_service.create_user("Parent User", "parent@example.com", "parentpass")
+ await user_service.create_user("Child User 1", "child1@example.com", "childpass", "parent@example.com")
+ await user_service.create_user("Child User 2", "child2@example.com", "childpass", "parent@example.com")
return user_service
async def test_create_user_success(user_service):
- user = user_service.create_user("Test User", "test@example.com", "password123")
+ user = await user_service.create_user("Test User", "test@example.com", "password123")
assert user is not None
assert user["email"] == "test@example.com"
- assert user_service.get_user_by_email("test@example.com") is not None
+ assert await user_service.get_user_by_email("test@example.com") is not None
assert bcrypt.checkpw(b"password123", user["password"].encode('utf-8'))
async def test_create_user_duplicate_email(user_service):
- user_service.create_user("Test User", "test@example.com", "password123")
+ await user_service.create_user("Test User", "test@example.com", "password123")
with pytest.raises(ValueError, match="User with this email already exists"):
- user_service.create_user("Another User", "test@example.com", "anotherpass")
+ await user_service.create_user("Another User", "test@example.com", "anotherpass")
async def test_get_all_users(populated_user_service):
- users = populated_user_service.get_all_users()
+ users = await populated_user_service.get_all_users()
assert len(users) == 4
emails = {user["email"] for user in users}
assert "admin@example.com" in emails
@@ -48,70 +48,70 @@ async def test_get_all_users(populated_user_service):
assert "child2@example.com" in emails
async def test_get_users_by_parent_email(populated_user_service):
- children = populated_user_service.get_users_by_parent_email("parent@example.com")
+ children = await populated_user_service.get_users_by_parent_email("parent@example.com")
assert len(children) == 2
child_emails = {user["email"] for user in children}
assert "child1@example.com" in child_emails
assert "child2@example.com" in child_emails
- no_children = populated_user_service.get_users_by_parent_email("nonexistent@example.com")
+ no_children = await populated_user_service.get_users_by_parent_email("nonexistent@example.com")
assert len(no_children) == 0
- admin_children = populated_user_service.get_users_by_parent_email("admin@example.com")
+ admin_children = await populated_user_service.get_users_by_parent_email("admin@example.com")
assert len(admin_children) == 0
async def test_update_user_non_password_fields(populated_user_service):
- updated_user = populated_user_service.update_user("admin@example.com", full_name="Administrator", storage_quota_gb=10)
+ updated_user = await populated_user_service.update_user("admin@example.com", full_name="Administrator", storage_quota_gb=10)
assert updated_user is not None
assert updated_user["full_name"] == "Administrator"
assert updated_user["storage_quota_gb"] == 10
- retrieved_user = populated_user_service.get_user_by_email("admin@example.com")
+ retrieved_user = await populated_user_service.get_user_by_email("admin@example.com")
assert retrieved_user["full_name"] == "Administrator"
assert retrieved_user["storage_quota_gb"] == 10
async def test_update_user_password(populated_user_service):
- updated_user = populated_user_service.update_user("admin@example.com", password="newadminpass")
+ updated_user = await populated_user_service.update_user("admin@example.com", password="newadminpass")
assert updated_user is not None
- assert populated_user_service.authenticate_user("admin@example.com", "newadminpass")
- assert not populated_user_service.authenticate_user("admin@example.com", "adminpass")
+ assert await populated_user_service.authenticate_user("admin@example.com", "newadminpass")
+ assert not await populated_user_service.authenticate_user("admin@example.com", "adminpass")
async def test_update_user_nonexistent(user_service):
- updated_user = user_service.update_user("nonexistent@example.com", full_name="Non Existent")
+ updated_user = await user_service.update_user("nonexistent@example.com", full_name="Non Existent")
assert updated_user is None
async def test_delete_user_success(populated_user_service):
- assert populated_user_service.delete_user("admin@example.com") is True
- assert populated_user_service.get_user_by_email("admin@example.com") is None
- assert len(populated_user_service.get_all_users()) == 3
+ assert await populated_user_service.delete_user("admin@example.com") is True
+ assert await populated_user_service.get_user_by_email("admin@example.com") is None
+ assert len(await populated_user_service.get_all_users()) == 3
async def test_delete_user_nonexistent(user_service):
- assert user_service.delete_user("nonexistent@example.com") is False
+ assert await user_service.delete_user("nonexistent@example.com") is False
async def test_delete_users_by_parent_email_success(populated_user_service):
- deleted_count = populated_user_service.delete_users_by_parent_email("parent@example.com")
+ deleted_count = await populated_user_service.delete_users_by_parent_email("parent@example.com")
assert deleted_count == 2
- assert populated_user_service.get_user_by_email("child1@example.com") is None
- assert populated_user_service.get_user_by_email("child2@example.com") is None
- assert len(populated_user_service.get_all_users()) == 2 # Admin and Parent users remain
+ assert await populated_user_service.get_user_by_email("child1@example.com") is None
+ assert await populated_user_service.get_user_by_email("child2@example.com") is None
+ assert len(await populated_user_service.get_all_users()) == 2 # Admin and Parent users remain
async def test_delete_users_by_parent_email_no_match(user_service):
- deleted_count = user_service.delete_users_by_parent_email("nonexistent@example.com")
+ deleted_count = await user_service.delete_users_by_parent_email("nonexistent@example.com")
assert deleted_count == 0
async def test_authenticate_user_success(populated_user_service):
- assert populated_user_service.authenticate_user("admin@example.com", "adminpass") is True
+ assert await populated_user_service.authenticate_user("admin@example.com", "adminpass") is True
async def test_authenticate_user_fail_wrong_password(populated_user_service):
- assert populated_user_service.authenticate_user("admin@example.com", "wrongpass") is False
+ assert await populated_user_service.authenticate_user("admin@example.com", "wrongpass") is False
async def test_authenticate_user_fail_nonexistent_user(user_service):
- assert user_service.authenticate_user("nonexistent@example.com", "anypass") is False
+ assert await user_service.authenticate_user("nonexistent@example.com", "anypass") is False
async def test_generate_reset_token_success(populated_user_service):
- token = populated_user_service.generate_reset_token("admin@example.com")
+ token = await populated_user_service.generate_reset_token("admin@example.com")
assert token is not None
- user = populated_user_service.get_user_by_email("admin@example.com")
+ user = await populated_user_service.get_user_by_email("admin@example.com")
assert user["reset_token"] == token
assert user["reset_token_expiry"] is not None
# Check expiry is in the future
@@ -119,71 +119,71 @@ async def test_generate_reset_token_success(populated_user_service):
assert expiry_dt > datetime.datetime.now(datetime.timezone.utc)
async def test_generate_reset_token_nonexistent_user(user_service):
- token = user_service.generate_reset_token("nonexistent@example.com")
+ token = await user_service.generate_reset_token("nonexistent@example.com")
assert token is None
async def test_get_user_by_reset_token_valid(populated_user_service):
- token = populated_user_service.generate_reset_token("admin@example.com")
- user = populated_user_service.get_user_by_reset_token(token)
+ token = await populated_user_service.generate_reset_token("admin@example.com")
+ user = await populated_user_service.get_user_by_reset_token(token)
assert user is not None
assert user["email"] == "admin@example.com"
async def test_get_user_by_reset_token_invalid(populated_user_service):
- user = populated_user_service.get_user_by_reset_token("invalidtoken")
+ user = await populated_user_service.get_user_by_reset_token("invalidtoken")
assert user is None
async def test_get_user_by_reset_token_expired(populated_user_service):
- token = populated_user_service.generate_reset_token("admin@example.com")
- user = populated_user_service.get_user_by_email("admin@example.com")
+ token = await populated_user_service.generate_reset_token("admin@example.com")
+ user = await populated_user_service.get_user_by_email("admin@example.com")
# Manually expire the token
user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat()
- populated_user_service._save_users() # Save the expired state
+ await populated_user_service._save_users() # Save the expired state
- user_after_expiry = populated_user_service.get_user_by_reset_token(token)
+ user_after_expiry = await populated_user_service.get_user_by_reset_token(token)
assert user_after_expiry is None
async def test_validate_reset_token_success(populated_user_service):
- token = populated_user_service.generate_reset_token("admin@example.com")
- assert populated_user_service.validate_reset_token("admin@example.com", token) is True
+ token = await populated_user_service.generate_reset_token("admin@example.com")
+ assert await populated_user_service.validate_reset_token("admin@example.com", token) is True
async def test_validate_reset_token_fail_wrong_token(populated_user_service):
- populated_user_service.generate_reset_token("admin@example.com")
- assert populated_user_service.validate_reset_token("admin@example.com", "wrongtoken") is False
+ await populated_user_service.generate_reset_token("admin@example.com")
+ assert await populated_user_service.validate_reset_token("admin@example.com", "wrongtoken") is False
async def test_validate_reset_token_fail_nonexistent_user(user_service):
- assert user_service.validate_reset_token("nonexistent@example.com", "anytoken") is False
+ assert await user_service.validate_reset_token("nonexistent@example.com", "anytoken") is False
async def test_validate_reset_token_fail_expired_token(populated_user_service):
- token = populated_user_service.generate_reset_token("admin@example.com")
- user = populated_user_service.get_user_by_email("admin@example.com")
+ token = await populated_user_service.generate_reset_token("admin@example.com")
+ user = await populated_user_service.get_user_by_email("admin@example.com")
user["reset_token_expiry"] = (datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)).isoformat()
- populated_user_service._save_users()
+ await populated_user_service._save_users()
- assert populated_user_service.validate_reset_token("admin@example.com", token) is False
+ assert await populated_user_service.validate_reset_token("admin@example.com", token) is False
async def test_reset_password_success(populated_user_service):
- token = populated_user_service.generate_reset_token("admin@example.com")
- assert populated_user_service.reset_password("admin@example.com", token, "newadminpass") is True
- assert populated_user_service.authenticate_user("admin@example.com", "newadminpass")
- user = populated_user_service.get_user_by_email("admin@example.com")
+ token = await populated_user_service.generate_reset_token("admin@example.com")
+ assert await populated_user_service.reset_password("admin@example.com", token, "newadminpass") is True
+ assert await populated_user_service.authenticate_user("admin@example.com", "newadminpass")
+ user = await populated_user_service.get_user_by_email("admin@example.com")
assert user["reset_token"] is None
assert user["reset_token_expiry"] is None
async def test_reset_password_fail_invalid_token(populated_user_service):
- populated_user_service.generate_reset_token("admin@example.com")
- assert populated_user_service.reset_password("admin@example.com", "invalidtoken", "newadminpass") is False
- assert populated_user_service.authenticate_user("admin@example.com", "adminpass") # Password should not change
+ await populated_user_service.generate_reset_token("admin@example.com")
+ assert await populated_user_service.reset_password("admin@example.com", "invalidtoken", "newadminpass") is False
+ assert await populated_user_service.authenticate_user("admin@example.com", "adminpass") # Password should not change
async def test_reset_password_fail_nonexistent_user(user_service):
# Even if a token was somehow generated for a nonexistent user (which shouldn't happen),
# reset_password should fail.
- assert user_service.reset_password("nonexistent@example.com", "anytoken", "newpass") is False
+ assert await user_service.reset_password("nonexistent@example.com", "anytoken", "newpass") is False
async def test_update_user_quota_success(populated_user_service):
- populated_user_service.update_user_quota("admin@example.com", 20.5)
- user = populated_user_service.get_user_by_email("admin@example.com")
+ await populated_user_service.update_user_quota("admin@example.com", 20.5)
+ user = await populated_user_service.get_user_by_email("admin@example.com")
assert user["storage_quota_gb"] == 20.5
async def test_update_user_quota_nonexistent_user(user_service):
with pytest.raises(ValueError, match="User not found"):
- user_service.update_user_quota("nonexistent@example.com", 100)
+ await user_service.update_user_quota("nonexistent@example.com", 100)