diff --git a/benchmark.py b/benchmark.py new file mode 100755 index 0000000..83d7089 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,655 @@ +#!/usr/bin/env python3 +""" +WebDAV Server Concurrent Benchmark Tool +Heavy load testing with performance metrics per method +""" + +import asyncio +import aiohttp +import time +import argparse +import statistics +from dataclasses import dataclass, field +from typing import List, Dict, Optional +from collections import defaultdict +import random +import string + + +@dataclass +class RequestMetrics: + """Metrics for a single request""" + method: str + duration: float + status: int + success: bool + error: Optional[str] = None + + +@dataclass +class MethodStats: + """Statistics for a specific HTTP method""" + method: str + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + total_duration: float = 0.0 + durations: List[float] = field(default_factory=list) + errors: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) + + @property + def success_rate(self) -> float: + return (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0 + + @property + def avg_duration(self) -> float: + return self.total_duration / self.total_requests if self.total_requests > 0 else 0 + + @property + def requests_per_second(self) -> float: + return self.total_requests / self.total_duration if self.total_duration > 0 else 0 + + @property + def min_duration(self) -> float: + return min(self.durations) if self.durations else 0 + + @property + def max_duration(self) -> float: + return max(self.durations) if self.durations else 0 + + @property + def p50_duration(self) -> float: + return statistics.median(self.durations) if self.durations else 0 + + @property + def p95_duration(self) -> float: + if not self.durations: + return 0 + sorted_durations = sorted(self.durations) + index = int(len(sorted_durations) * 0.95) + return sorted_durations[index] if index < len(sorted_durations) else sorted_durations[-1] + + @property + def p99_duration(self) -> float: + if not self.durations: + return 0 + sorted_durations = sorted(self.durations) + index = int(len(sorted_durations) * 0.99) + return sorted_durations[index] if index < len(sorted_durations) else sorted_durations[-1] + + +class WebDAVBenchmark: + """WebDAV server benchmark runner""" + + def __init__(self, url: str, username: str, password: str, + concurrency: int = 50, duration: int = 60): + self.url = url.rstrip('/') + self.username = username + self.password = password + self.concurrency = concurrency + self.duration = duration + self.stats: Dict[str, MethodStats] = defaultdict(lambda: MethodStats(method="")) + self.start_time = 0 + self.stop_flag = False + self.auth = aiohttp.BasicAuth(username, password) + + def random_string(self, length: int = 10) -> str: + """Generate random string""" + return ''.join(random.choices(string.ascii_letters + string.digits, k=length)) + + async def record_metric(self, metric: RequestMetrics): + """Record a request metric""" + stats = self.stats[metric.method] + stats.method = metric.method + stats.total_requests += 1 + stats.total_duration += metric.duration + stats.durations.append(metric.duration) + + if metric.success: + stats.successful_requests += 1 + else: + stats.failed_requests += 1 + if metric.error: + stats.errors[metric.error] += 1 + + async def benchmark_options(self, session: aiohttp.ClientSession) -> RequestMetrics: + """Benchmark OPTIONS request""" + start = time.time() + try: + async with 
session.options(self.url, auth=self.auth) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='OPTIONS',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status == 200
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='OPTIONS',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_propfind(self, session: aiohttp.ClientSession, depth: int = 0) -> RequestMetrics:
+        """Benchmark PROPFIND request"""
+        propfind_body = '''<?xml version="1.0" encoding="utf-8" ?>
+<D:propfind xmlns:D="DAV:">
+    <D:allprop/>
+</D:propfind>'''
+
+        start = time.time()
+        try:
+            async with session.request(
+                'PROPFIND',
+                self.url,
+                auth=self.auth,
+                data=propfind_body,
+                headers={'Depth': str(depth), 'Content-Type': 'application/xml'}
+            ) as resp:
+                await resp.read()  # Consume response
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='PROPFIND',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status == 207
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='PROPFIND',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_put(self, session: aiohttp.ClientSession) -> RequestMetrics:
+        """Benchmark PUT request"""
+        filename = f"bench_{self.random_string()}.txt"
+        content = self.random_string(1024).encode()  # 1KB file
+
+        start = time.time()
+        try:
+            async with session.put(
+                f"{self.url}/{filename}",
+                auth=self.auth,
+                data=content
+            ) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='PUT',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status in [201, 204]
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='PUT',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_get(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark GET request"""
+        start = time.time()
+        try:
+            async with session.get(
+                f"{self.url}/{filename}",
+                auth=self.auth
+            ) as resp:
+                await resp.read()  # Consume response
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='GET',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status == 200
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='GET',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_head(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark HEAD request"""
+        start = time.time()
+        try:
+            async with session.head(
+                f"{self.url}/{filename}",
+                auth=self.auth
+            ) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='HEAD',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status == 200
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='HEAD',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_mkcol(self, session: aiohttp.ClientSession) -> RequestMetrics:
+        """Benchmark MKCOL request"""
+        dirname = f"bench_dir_{self.random_string()}"
+
+        start = time.time()
+        try:
+            async with session.request(
+                'MKCOL',
+                f"{self.url}/{dirname}/",
+                auth=self.auth
+            ) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='MKCOL',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status == 201
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='MKCOL',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_proppatch(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark PROPPATCH request"""
+        proppatch_body = '''<?xml version="1.0" encoding="utf-8" ?>
+<D:propertyupdate xmlns:D="DAV:">
+    <D:set>
+        <D:prop>
+            <D:displayname>Benchmark Test</D:displayname>
+        </D:prop>
+    </D:set>
+</D:propertyupdate>'''
+
+        start = time.time()
+        try:
+            async with session.request(
+                'PROPPATCH',
+                f"{self.url}/{filename}",
+                auth=self.auth,
+                data=proppatch_body,
+                headers={'Content-Type': 'application/xml'}
+            ) as resp:
+                await resp.read()
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='PROPPATCH',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status == 207
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='PROPPATCH',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_copy(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark COPY request"""
+        dest_filename = f"copy_{self.random_string()}.txt"
+
+        start = time.time()
+        try:
+            async with session.request(
+                'COPY',
+                f"{self.url}/{filename}",
+                auth=self.auth,
+                headers={'Destination': f"{self.url}/{dest_filename}"}
+            ) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='COPY',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status in [201, 204]
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='COPY',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_move(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark MOVE request"""
+        dest_filename = f"moved_{self.random_string()}.txt"
+
+        start = time.time()
+        try:
+            async with session.request(
+                'MOVE',
+                f"{self.url}/{filename}",
+                auth=self.auth,
+                headers={'Destination': f"{self.url}/{dest_filename}"}
+            ) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='MOVE',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status in [201, 204]
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='MOVE',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_lock(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark LOCK request"""
+        lock_body = '''<?xml version="1.0" encoding="utf-8" ?>
+<D:lockinfo xmlns:D="DAV:">
+    <D:lockscope><D:exclusive/></D:lockscope>
+    <D:locktype><D:write/></D:locktype>
+    <D:owner>benchmark</D:owner>
+</D:lockinfo>'''
+
+        start = time.time()
+        try:
+            async with session.request(
+                'LOCK',
+                f"{self.url}/{filename}",
+                auth=self.auth,
+                data=lock_body,
+                headers={'Content-Type': 'application/xml', 'Timeout': 'Second-300'}
+            ) as resp:
+                lock_token = resp.headers.get('Lock-Token', '').strip('<>')
+                await resp.read()
+                duration = time.time() - start
+
+                # Unlock immediately to clean up
+                if lock_token:
+                    try:
+                        async with session.request(
+                            'UNLOCK',
+                            f"{self.url}/{filename}",
+                            auth=self.auth,
+                            headers={'Lock-Token': f'<{lock_token}>'}
+                        ) as unlock_resp:
+                            pass
+                    except Exception:
+                        pass
+
+                return RequestMetrics(
+                    method='LOCK',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status == 200
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='LOCK',
+                duration=time.time() - start,
+                status=0,
+                success=False,
+                error=str(e)
+            )
+
+    async def benchmark_delete(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark DELETE request"""
+        start = time.time()
+        try:
+            async with session.delete(
+                f"{self.url}/{filename}",
+                auth=self.auth
+            ) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='DELETE',
+                    duration=duration,
+                    status=resp.status,
+                    success=resp.status == 204
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='DELETE',
+                duration=time.time() - 
start, + status=0, + success=False, + error=str(e) + ) + + async def worker(self, worker_id: int, session: aiohttp.ClientSession): + """Worker coroutine that runs various benchmarks""" + test_files = [] + + # Create initial test file + filename = f"bench_worker_{worker_id}_{self.random_string()}.txt" + metric = await self.benchmark_put(session) + await self.record_metric(metric) + if metric.success: + test_files.append(filename) + + while not self.stop_flag: + elapsed = time.time() - self.start_time + if elapsed >= self.duration: + self.stop_flag = True + break + + # Randomly choose operation + operation = random.choice([ + 'options', 'propfind', 'put', 'get', 'head', + 'mkcol', 'proppatch', 'copy', 'move', 'lock', 'delete' + ]) + + try: + if operation == 'options': + metric = await self.benchmark_options(session) + + elif operation == 'propfind': + depth = random.choice([0, 1]) + metric = await self.benchmark_propfind(session, depth) + + elif operation == 'put': + metric = await self.benchmark_put(session) + if metric.success: + filename = f"bench_worker_{worker_id}_{self.random_string()}.txt" + test_files.append(filename) + + elif operation == 'get' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_get(session, filename) + + elif operation == 'head' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_head(session, filename) + + elif operation == 'mkcol': + metric = await self.benchmark_mkcol(session) + + elif operation == 'proppatch' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_proppatch(session, filename) + + elif operation == 'copy' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_copy(session, filename) + + elif operation == 'move' and test_files: + if len(test_files) > 1: + filename = test_files.pop(random.randrange(len(test_files))) + metric = await self.benchmark_move(session, filename) + else: + continue + + elif operation == 'lock' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_lock(session, filename) + + elif operation == 'delete' and len(test_files) > 1: + filename = test_files.pop(random.randrange(len(test_files))) + metric = await self.benchmark_delete(session, filename) + + else: + continue + + await self.record_metric(metric) + + except Exception as e: + print(f"Worker {worker_id} error: {e}") + + # Small delay to prevent overwhelming + await asyncio.sleep(0.001) + + async def run(self): + """Run the benchmark""" + print("="*80) + print("WebDAV Server Concurrent Benchmark") + print("="*80) + print(f"URL: {self.url}") + print(f"Concurrency: {self.concurrency} workers") + print(f"Duration: {self.duration} seconds") + print(f"User: {self.username}") + print("="*80) + print() + + connector = aiohttp.TCPConnector(limit=self.concurrency * 2) + timeout = aiohttp.ClientTimeout(total=30) + + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + self.start_time = time.time() + + # Create worker tasks + workers = [ + asyncio.create_task(self.worker(i, session)) + for i in range(self.concurrency) + ] + + # Progress indicator + progress_task = asyncio.create_task(self.show_progress()) + + # Wait for all workers + await asyncio.gather(*workers, return_exceptions=True) + + # Stop progress + await progress_task + + # Print results + self.print_results() + + async def show_progress(self): + """Show progress during benchmark""" + while not self.stop_flag: + elapsed = 
time.time() - self.start_time + if elapsed >= self.duration: + break + + total_requests = sum(s.total_requests for s in self.stats.values()) + print(f"\rProgress: {elapsed:.1f}s / {self.duration}s | Total Requests: {total_requests}", end='', flush=True) + await asyncio.sleep(1) + + print() + + def print_results(self): + """Print benchmark results""" + print("\n") + print("="*80) + print("BENCHMARK RESULTS") + print("="*80) + print() + + total_duration = time.time() - self.start_time + total_requests = sum(s.total_requests for s in self.stats.values()) + total_success = sum(s.successful_requests for s in self.stats.values()) + total_failed = sum(s.failed_requests for s in self.stats.values()) + + print(f"Total Duration: {total_duration:.2f}s") + print(f"Total Requests: {total_requests:,}") + print(f"Successful: {total_success:,} ({total_success/total_requests*100:.1f}%)") + print(f"Failed: {total_failed:,} ({total_failed/total_requests*100:.1f}%)") + print(f"Overall RPS: {total_requests/total_duration:.2f}") + print() + + # Sort methods by request count + sorted_stats = sorted(self.stats.values(), key=lambda s: s.total_requests, reverse=True) + + print("="*80) + print("PER-METHOD STATISTICS") + print("="*80) + print() + + for stats in sorted_stats: + if stats.total_requests == 0: + continue + + print(f"Method: {stats.method}") + print(f" Requests: {stats.total_requests:>8,}") + print(f" Success Rate: {stats.success_rate:>8.2f}%") + print(f" RPS: {stats.requests_per_second:>8.2f}") + print(f" Latency (ms):") + print(f" Min: {stats.min_duration*1000:>8.2f}") + print(f" Avg: {stats.avg_duration*1000:>8.2f}") + print(f" P50: {stats.p50_duration*1000:>8.2f}") + print(f" P95: {stats.p95_duration*1000:>8.2f}") + print(f" P99: {stats.p99_duration*1000:>8.2f}") + print(f" Max: {stats.max_duration*1000:>8.2f}") + + if stats.failed_requests > 0 and stats.errors: + print(f" Errors:") + for error, count in sorted(stats.errors.items(), key=lambda x: x[1], reverse=True)[:5]: + error_short = error[:60] + '...' 
if len(error) > 60 else error + print(f" {error_short}: {count}") + + print() + + print("="*80) + + +async def main(): + """Main entry point""" + parser = argparse.ArgumentParser(description='WebDAV Server Concurrent Benchmark') + parser.add_argument('url', help='WebDAV server URL (e.g., http://localhost:8080/)') + parser.add_argument('username', help='Username for authentication') + parser.add_argument('password', help='Password for authentication') + parser.add_argument('-c', '--concurrency', type=int, default=50, + help='Number of concurrent workers (default: 50)') + parser.add_argument('-d', '--duration', type=int, default=60, + help='Benchmark duration in seconds (default: 60)') + + args = parser.parse_args() + + benchmark = WebDAVBenchmark( + url=args.url, + username=args.username, + password=args.password, + concurrency=args.concurrency, + duration=args.duration + ) + + await benchmark.run() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/benchmark2.py b/benchmark2.py new file mode 100644 index 0000000..e5d67cf --- /dev/null +++ b/benchmark2.py @@ -0,0 +1,579 @@ +#!/usr/bin/env python3 +""" +WebDAV Server Concurrent Benchmark Tool +Heavy load testing with performance metrics per method +""" + +import asyncio +import aiohttp +import time +import argparse +import statistics +from dataclasses import dataclass, field +from typing import List, Dict, Optional +from collections import defaultdict +import random +import string + + +@dataclass +class RequestMetrics: + """Metrics for a single request""" + method: str + duration: float + status: int + success: bool + error: Optional[str] = None + filename: Optional[str] = None # To track created/moved resources + + +@dataclass +class MethodStats: + """Statistics for a specific HTTP method""" + method: str + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + total_duration: float = 0.0 + durations: List[float] = field(default_factory=list) + errors: Dict[str, int] = field(default_factory=lambda: defaultdict(int)) + + @property + def success_rate(self) -> float: + return (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0 + + @property + def avg_duration(self) -> float: + return self.total_duration / self.total_requests if self.total_requests > 0 else 0 + + @property + def requests_per_second(self) -> float: + # A more accurate RPS for a method is its count over the total benchmark time + # This property is not used in the final report, but we'll leave it for potential use. 
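+        # Worked example with made-up numbers: 100 requests whose latencies sum
+        # to 4.0 s give a per-request "service rate" of 25 req/s here, even if
+        # those 100 requests were spread over a 60 s wall-clock run (~1.67 RPS).
+        # The final report divides by total benchmark duration instead.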
+        return self.total_requests / self.total_duration if self.total_duration > 0 else 0
+
+    @property
+    def min_duration(self) -> float:
+        return min(self.durations) if self.durations else 0
+
+    @property
+    def max_duration(self) -> float:
+        return max(self.durations) if self.durations else 0
+
+    @property
+    def p50_duration(self) -> float:
+        return statistics.median(self.durations) if self.durations else 0
+
+    @property
+    def p95_duration(self) -> float:
+        if not self.durations:
+            return 0
+        sorted_durations = sorted(self.durations)
+        index = int(len(sorted_durations) * 0.95)
+        return sorted_durations[index] if index < len(sorted_durations) else sorted_durations[-1]
+
+    @property
+    def p99_duration(self) -> float:
+        if not self.durations:
+            return 0
+        sorted_durations = sorted(self.durations)
+        index = int(len(sorted_durations) * 0.99)
+        return sorted_durations[index] if index < len(sorted_durations) else sorted_durations[-1]
+
+
+class WebDAVBenchmark:
+    """WebDAV server benchmark runner"""
+
+    def __init__(self, url: str, username: str, password: str,
+                 concurrency: int = 50, duration: int = 60):
+        self.url = url.rstrip('/')
+        self.username = username
+        self.password = password
+        self.concurrency = concurrency
+        self.duration = duration
+        self.stats: Dict[str, MethodStats] = defaultdict(lambda: MethodStats(method=""))
+        self.start_time = 0.0
+        self.stop_flag = False
+        self.auth = aiohttp.BasicAuth(username, password)
+
+    def random_string(self, length: int = 10) -> str:
+        """Generate random string"""
+        return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
+
+    async def record_metric(self, metric: RequestMetrics):
+        """Record a request metric"""
+        stats = self.stats[metric.method]
+        if not stats.method:
+            stats.method = metric.method
+
+        stats.total_requests += 1
+        stats.total_duration += metric.duration
+        stats.durations.append(metric.duration)
+
+        if metric.success:
+            stats.successful_requests += 1
+        else:
+            stats.failed_requests += 1
+            error_key = f"Status {metric.status}" if metric.status != 0 else str(metric.error)
+            stats.errors[error_key] += 1
+
+    async def benchmark_options(self, session: aiohttp.ClientSession) -> RequestMetrics:
+        """Benchmark OPTIONS request"""
+        start = time.time()
+        try:
+            async with session.options(self.url, auth=self.auth) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='OPTIONS', duration=duration, status=resp.status,
+                    success=resp.status == 200
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='OPTIONS', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_propfind(self, session: aiohttp.ClientSession, depth: int = 0) -> RequestMetrics:
+        """Benchmark PROPFIND request"""
+        propfind_body = '''<?xml version="1.0" encoding="utf-8" ?>
+<D:propfind xmlns:D="DAV:">
+    <D:allprop/>
+</D:propfind>'''
+        start = time.time()
+        try:
+            async with session.request(
+                'PROPFIND', self.url, auth=self.auth, data=propfind_body,
+                headers={'Depth': str(depth), 'Content-Type': 'application/xml'}
+            ) as resp:
+                await resp.read()
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='PROPFIND', duration=duration, status=resp.status,
+                    success=resp.status == 207
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='PROPFIND', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_put(self, session: aiohttp.ClientSession) -> RequestMetrics:
+        """Benchmark PUT request"""
+        filename = f"bench_{self.random_string()}.txt"
+        content = self.random_string(1024).encode()
+        start = time.time()
+        try:
+            async with session.put(f"{self.url}/{filename}", auth=self.auth, data=content) as resp:
+                duration = time.time() - start
+                is_success = resp.status in [201, 204]
+                return RequestMetrics(
+                    method='PUT', duration=duration, status=resp.status,
+                    success=is_success,
+                    filename=filename if is_success else None
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='PUT', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_get(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark GET request"""
+        start = time.time()
+        try:
+            async with session.get(f"{self.url}/{filename}", auth=self.auth) as resp:
+                await resp.read()
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='GET', duration=duration, status=resp.status,
+                    success=resp.status == 200
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='GET', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_head(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark HEAD request"""
+        start = time.time()
+        try:
+            async with session.head(f"{self.url}/{filename}", auth=self.auth) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='HEAD', duration=duration, status=resp.status,
+                    success=resp.status == 200
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='HEAD', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_mkcol(self, session: aiohttp.ClientSession) -> RequestMetrics:
+        """Benchmark MKCOL request"""
+        dirname = f"bench_dir_{self.random_string()}"
+        start = time.time()
+        try:
+            async with session.request('MKCOL', f"{self.url}/{dirname}/", auth=self.auth) as resp:
+                duration = time.time() - start
+                is_success = resp.status == 201
+                return RequestMetrics(
+                    method='MKCOL', duration=duration, status=resp.status,
+                    success=is_success,
+                    filename=dirname if is_success else None
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='MKCOL', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_proppatch(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark PROPPATCH request"""
+        proppatch_body = '''<?xml version="1.0" encoding="utf-8" ?>
+<D:propertyupdate xmlns:D="DAV:">
+    <D:set><D:prop><D:displayname>Benchmark Test</D:displayname></D:prop></D:set>
+</D:propertyupdate>'''
+        start = time.time()
+        try:
+            async with session.request(
+                'PROPPATCH', f"{self.url}/{filename}", auth=self.auth, data=proppatch_body,
+                headers={'Content-Type': 'application/xml'}
+            ) as resp:
+                await resp.read()
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='PROPPATCH', duration=duration, status=resp.status,
+                    success=resp.status == 207
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='PROPPATCH', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_copy(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark COPY request"""
+        dest_filename = f"copy_{self.random_string()}.txt"
+        start = time.time()
+        try:
+            async with session.request(
+                'COPY', f"{self.url}/{filename}", auth=self.auth,
+                headers={'Destination': f"{self.url}/{dest_filename}"}
+            ) as resp:
+                duration = time.time() - start
+                is_success = resp.status in [201, 204]
+                return RequestMetrics(
+                    method='COPY', duration=duration, status=resp.status,
+                    success=is_success,
+                    filename=dest_filename if is_success else None
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='COPY', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_move(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark MOVE request"""
+        dest_filename = f"moved_{self.random_string()}.txt"
+        start = time.time()
+        try:
+            async with session.request(
+                'MOVE', f"{self.url}/{filename}", auth=self.auth,
+                headers={'Destination': f"{self.url}/{dest_filename}"}
+            ) as resp:
+                duration = time.time() - start
+                is_success = resp.status in [201, 204]
+                return RequestMetrics(
+                    method='MOVE', duration=duration, status=resp.status,
+                    success=is_success,
+                    filename=dest_filename if is_success else None
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='MOVE', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_lock(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
+        """Benchmark LOCK request"""
+        lock_body = '''<?xml version="1.0" encoding="utf-8" ?>
+<D:lockinfo xmlns:D="DAV:">
+    <D:lockscope><D:exclusive/></D:lockscope><D:locktype><D:write/></D:locktype>
+    <D:owner>benchmark</D:owner>
+</D:lockinfo>'''
+        start = time.time()
+        try:
+            async with session.request(
+                'LOCK', f"{self.url}/{filename}", auth=self.auth, data=lock_body,
+                headers={'Content-Type': 'application/xml', 'Timeout': 'Second-300'}
+            ) as resp:
+                lock_token = resp.headers.get('Lock-Token', '').strip('<>')
+                await resp.read()
+                duration = time.time() - start
+                is_success = resp.status == 200
+
+                if is_success and lock_token:
+                    try:
+                        async with session.request(
+                            'UNLOCK', f"{self.url}/{filename}", auth=self.auth,
+                            headers={'Lock-Token': f'<{lock_token}>'}
+                        ):
+                            pass
+                    except Exception:
+                        pass
+
+                return RequestMetrics(
+                    method='LOCK', duration=duration, status=resp.status,
+                    success=is_success
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='LOCK', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def benchmark_delete(self, session: aiohttp.ClientSession, resource_name: str) -> RequestMetrics:
+        """Benchmark DELETE request for files or directories"""
+        start = time.time()
+        try:
+            # Add trailing slash for directories for some servers
+            url_path = f"{self.url}/{resource_name}"
+            if "dir" in resource_name:
+                url_path += "/"
+
+            async with session.delete(url_path, auth=self.auth) as resp:
+                duration = time.time() - start
+                return RequestMetrics(
+                    method='DELETE', duration=duration, status=resp.status,
+                    success=resp.status == 204
+                )
+        except Exception as e:
+            return RequestMetrics(
+                method='DELETE', duration=time.time() - start, status=0,
+                success=False, error=str(e)
+            )
+
+    async def worker(self, worker_id: int, session: aiohttp.ClientSession):
+        """Worker coroutine that runs various benchmarks"""
+        test_files = []
+        test_dirs = []
+
+        # Create an initial test file to ensure other operations can start
+        metric = await self.benchmark_put(session)
+        await self.record_metric(metric)
+        if metric.success and metric.filename:
+            test_files.append(metric.filename)
+
+        while not self.stop_flag:
+            elapsed = time.time() - self.start_time
+            if elapsed >= self.duration:
+                self.stop_flag = True
+                break
+
+            # Weighted random choice
+            operations = [
+                'options', 'propfind', 'put', 'get', 'head',
+                'mkcol', 'proppatch', 'copy', 'move', 'lock', 'delete'
+            ]
+
+            # Ensure some operations are more frequent
+            weights = [5, 5, 15, 15, 10, 5, 5, 5, 5, 5, 20]  # More PUT, GET, DELETE
+            operation = random.choices(operations, weights=weights, k=1)[0]
+
+            metric = None
+            try:
+                if operation == 'options':
+                    metric = await self.benchmark_options(session)
+
+                elif operation == 'propfind':
+                    depth = random.choice([0, 1])
+                    metric = await self.benchmark_propfind(session, depth)
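+                    # Per RFC 4918, Depth: 0 returns properties for the target
+                    # itself and Depth: 1 adds its immediate children; "infinity"
+                    # also exists but is often disabled server-side, so the
+                    # benchmark sticks to 0 and 1.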
+ + elif operation == 'put': + metric = await self.benchmark_put(session) + if metric.success and metric.filename: + test_files.append(metric.filename) + + elif operation == 'get' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_get(session, filename) + + elif operation == 'head' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_head(session, filename) + + elif operation == 'mkcol': + metric = await self.benchmark_mkcol(session) + if metric.success and metric.filename: + test_dirs.append(metric.filename) + + elif operation == 'proppatch' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_proppatch(session, filename) + + elif operation == 'copy' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_copy(session, filename) + if metric.success and metric.filename: + test_files.append(metric.filename) + + elif operation == 'move' and len(test_files) > 1: + filename_to_move = test_files.pop(random.randrange(len(test_files))) + metric = await self.benchmark_move(session, filename_to_move) + if metric.success and metric.filename: + test_files.append(metric.filename) + + elif operation == 'lock' and test_files: + filename = random.choice(test_files) + metric = await self.benchmark_lock(session, filename) + + elif operation == 'delete': + # Randomly delete a file or a directory + if test_dirs and random.random() < 0.2 and len(test_dirs) > 0: # 20% chance to delete a dir + dir_to_delete = test_dirs.pop(random.randrange(len(test_dirs))) + metric = await self.benchmark_delete(session, dir_to_delete) + elif len(test_files) > 1: + file_to_delete = test_files.pop(random.randrange(len(test_files))) + metric = await self.benchmark_delete(session, file_to_delete) + + if metric: + await self.record_metric(metric) + + except Exception as e: + print(f"Worker {worker_id} error: {e}") + + await asyncio.sleep(0.01) # Small delay to prevent tight loop on empty lists + + async def run(self): + """Run the benchmark""" + print("="*80) + print("WebDAV Server Concurrent Benchmark") + print("="*80) + print(f"URL: {self.url}") + print(f"Concurrency: {self.concurrency} workers") + print(f"Duration: {self.duration} seconds") + print(f"User: {self.username}") + print("="*80) + print() + + connector = aiohttp.TCPConnector(limit=self.concurrency * 2) + timeout = aiohttp.ClientTimeout(total=30) + + async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: + self.start_time = time.time() + + workers = [ + asyncio.create_task(self.worker(i, session)) + for i in range(self.concurrency) + ] + + progress_task = asyncio.create_task(self.show_progress()) + + await asyncio.gather(*workers, return_exceptions=True) + + self.stop_flag = True + await progress_task + + self.print_results() + + async def show_progress(self): + """Show progress during benchmark""" + while not self.stop_flag: + elapsed = time.time() - self.start_time + if elapsed >= self.duration: + break + + total_requests = sum(s.total_requests for s in self.stats.values()) + print(f"\rProgress: {elapsed:.1f}s / {self.duration}s | Total Requests: {total_requests}", end='', flush=True) + await asyncio.sleep(0.5) + print() + + def print_results(self): + """Print benchmark results""" + print("\n") + print("="*80) + print("BENCHMARK RESULTS") + print("="*80) + print() + + total_duration = time.time() - self.start_time + total_requests = sum(s.total_requests for s in self.stats.values()) + total_success = 
sum(s.successful_requests for s in self.stats.values()) + total_failed = total_requests - total_success + + success_rate = (total_success / total_requests * 100) if total_requests > 0 else 0 + failed_rate = (total_failed / total_requests * 100) if total_requests > 0 else 0 + + print(f"Total Duration: {total_duration:.2f}s") + print(f"Total Requests: {total_requests:,}") + print(f"Successful: {total_success:,} ({success_rate:.1f}%)") + print(f"Failed: {total_failed:,} ({failed_rate:.1f}%)") + print(f"Overall RPS: {total_requests/total_duration:.2f}") + print() + + sorted_stats = sorted(self.stats.values(), key=lambda s: s.total_requests, reverse=True) + + print("="*80) + print("PER-METHOD STATISTICS") + print("="*80) + print() + + for stats in sorted_stats: + if stats.total_requests == 0: + continue + + # Calculate RPS based on total benchmark duration for better comparison + method_rps = stats.total_requests / total_duration + + print(f"Method: {stats.method}") + print(f" Requests: {stats.total_requests:>8,}") + print(f" Success Rate: {stats.success_rate:>8.2f}%") + print(f" RPS: {method_rps:>8.2f}") + print(f" Latency (ms):") + print(f" Min: {stats.min_duration*1000:>8.2f}") + print(f" Avg: {stats.avg_duration*1000:>8.2f}") + print(f" P50: {stats.p50_duration*1000:>8.2f}") + print(f" P95: {stats.p95_duration*1000:>8.2f}") + print(f" P99: {stats.p99_duration*1000:>8.2f}") + print(f" Max: {stats.max_duration*1000:>8.2f}") + + if stats.failed_requests > 0 and stats.errors: + print(f" Errors:") + for error, count in sorted(stats.errors.items(), key=lambda x: x[1], reverse=True)[:5]: + error_short = error[:60] + '...' if len(error) > 60 else error + print(f" {error_short}: {count}") + + print() + + print("="*80) + + +async def main(): + """Main entry point""" + parser = argparse.ArgumentParser(description='WebDAV Server Concurrent Benchmark') + parser.add_argument('url', help='WebDAV server URL (e.g., http://localhost:8080/)') + parser.add_argument('username', help='Username for authentication') + parser.add_argument('password', help='Password for authentication') + parser.add_argument('-c', '--concurrency', type=int, default=50, + help='Number of concurrent workers (default: 50)') + parser.add_argument('-d', '--duration', type=int, default=60, + help='Benchmark duration in seconds (default: 60)') + + args = parser.parse_args() + + benchmark = WebDAVBenchmark( + url=args.url, + username=args.username, + password=args.password, + concurrency=args.concurrency, + duration=args.duration + ) + + await benchmark.run() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/main2.py b/main2.py new file mode 100644 index 0000000..5b684cb --- /dev/null +++ b/main2.py @@ -0,0 +1,752 @@ +#!/usr/bin/env python3 +""" +Complete WebDAV Server Implementation with aiohttp +Production-ready WebDAV server with full RFC 4918 compliance, +Windows Explorer compatibility, and comprehensive user management. + +Includes multi-layered caching for high performance: +1. HTTP ETags for client-side caching. +2. In-memory LRU cache for filesystem metadata (PROPFIND). +3. In-memory LRU cache for password hashing (Authentication). +4. Asynchronous handling of blocking file operations. 
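+
+Example invocation (illustrative values; each variable is optional and falls
+back to the defaults in Config):
+
+    HOST=127.0.0.1 PORT=8080 WEBDAV_ROOT=./webdav DB_PATH=./webdav.db python main2.py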
+""" + +import os +import asyncio +import aiofiles +import sqlite3 +import hashlib +import hmac +import secrets +import mimetypes +import base64 +import functools +import shutil +from datetime import datetime +from pathlib import Path +from typing import Optional, Dict, List, Tuple +from xml.etree import ElementTree as ET +from urllib.parse import unquote, quote, urlparse + +from aiohttp import web +from aiohttp_session import setup as setup_session +from aiohttp_session.cookie_storage import EncryptedCookieStorage +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# ============================================================================ +# Configuration Management +# ============================================================================ + +class Config: + """Centralized configuration management from environment variables""" + + # Server Configuration + HOST = os.getenv('HOST', '0.0.0.0') + PORT = int(os.getenv('PORT', '8080')) + + # Database Configuration + DB_PATH = os.getenv('DB_PATH', './webdav.db') + + # Authentication Configuration + AUTH_METHODS = os.getenv('AUTH_METHODS', 'basic').split(',') + + # WebDAV Configuration + MAX_FILE_SIZE = int(os.getenv('MAX_FILE_SIZE', '104857600')) # 100MB + MAX_PROPFIND_DEPTH = int(os.getenv('MAX_PROPFIND_DEPTH', '3')) + LOCK_TIMEOUT_DEFAULT = int(os.getenv('LOCK_TIMEOUT_DEFAULT', '3600')) + + # WebDAV Root Directory + WEBDAV_ROOT = os.getenv('WEBDAV_ROOT', './webdav') + + +# ============================================================================ +# Database Layer +# ============================================================================ + +# This is the function we will cache. Caching works best on pure functions. +@functools.lru_cache(maxsize=128) +def _hash_password(password: str, salt: str) -> str: + """Hashes a password with a salt. 
This is the expensive part.""" + return hashlib.pbkdf2_hmac('sha256', password.encode(), salt.encode(), 100000).hex() + + +class Database: + """SQLite database management with async wrapper""" + + def __init__(self, db_path: str): + self.db_path = db_path + self._connection_lock = asyncio.Lock() + self.init_database() + + def get_connection(self) -> sqlite3.Connection: + """Get database connection with row factory""" + conn = sqlite3.connect(self.db_path, timeout=30.0, check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.execute('PRAGMA journal_mode=WAL') + conn.execute('PRAGMA busy_timeout=30000') + conn.execute('PRAGMA synchronous=NORMAL') + return conn + + def init_database(self): + """Initialize database schema""" + conn = self.get_connection() + cursor = conn.cursor() + + # Users table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + salt TEXT NOT NULL, + is_active BOOLEAN DEFAULT 1 + ) + ''') + + # Locks table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS locks ( + lock_token TEXT PRIMARY KEY, + resource_path TEXT NOT NULL, + user_id INTEGER, + lock_type TEXT DEFAULT 'write', + lock_scope TEXT DEFAULT 'exclusive', + depth INTEGER DEFAULT 0, + timeout_seconds INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + owner TEXT, + FOREIGN KEY (user_id) REFERENCES users (id) + ) + ''') + + # Properties table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS properties ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + resource_path TEXT NOT NULL, + namespace TEXT, + property_name TEXT NOT NULL, + property_value TEXT, + UNIQUE(resource_path, namespace, property_name) + ) + ''') + + cursor.execute('CREATE INDEX IF NOT EXISTS idx_locks_resource ON locks(resource_path)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_properties_resource ON properties(resource_path)') + + conn.commit() + conn.close() + + async def run_in_executor(self, func, *args): + """Run a synchronous database function in a thread pool.""" + return await asyncio.get_event_loop().run_in_executor(None, func, *args) + + async def create_user(self, username: str, password: str) -> int: + """Create a new user""" + salt = secrets.token_hex(16) + password_hash = _hash_password(password, salt) + + def _create(): + conn = self.get_connection() + cursor = conn.cursor() + try: + cursor.execute( + 'INSERT INTO users (username, password_hash, salt) VALUES (?, ?, ?)', + (username, password_hash, salt) + ) + user_id = cursor.lastrowid + conn.commit() + return user_id + finally: + conn.close() + + user_id = await self.run_in_executor(_create) + + user_dir = Path(Config.WEBDAV_ROOT) / 'users' / username + user_dir.mkdir(parents=True, exist_ok=True) + + return user_id + + def _get_user_from_db(self, username: str) -> Optional[Dict]: + """Fetches user data from the database.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute('SELECT * FROM users WHERE username = ? 
AND is_active = 1', (username,)) + user = cursor.fetchone() + return dict(user) if user else None + finally: + conn.close() + + async def verify_user(self, username: str, password: str) -> Optional[Dict]: + """Verify user credentials using a cached hash function.""" + user_data = await self.run_in_executor(self._get_user_from_db, username) + if not user_data: + return None + + password_hash = _hash_password(password, user_data['salt']) + if hmac.compare_digest(password_hash, user_data['password_hash']): + return user_data + return None + + async def get_lock(self, resource_path: str) -> Optional[Dict]: + def _get(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute(''' + SELECT * FROM locks WHERE resource_path = ? + AND datetime(created_at, '+' || timeout_seconds || ' seconds') > datetime('now') + ''', (resource_path,)) + lock = cursor.fetchone() + return dict(lock) if lock else None + finally: + conn.close() + return await self.run_in_executor(_get) + + async def create_lock(self, resource_path: str, user_id: int, timeout: int, owner: str) -> str: + lock_token = f"opaquelocktoken:{secrets.token_urlsafe(16)}" + def _create(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + 'INSERT INTO locks (lock_token, resource_path, user_id, timeout_seconds, owner) VALUES (?, ?, ?, ?, ?)', + (lock_token, resource_path, user_id, timeout, owner) + ) + conn.commit() + return lock_token + finally: + conn.close() + return await self.run_in_executor(_create) + + async def remove_lock(self, lock_token: str, user_id: int) -> bool: + def _remove(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute('DELETE FROM locks WHERE lock_token = ? AND user_id = ?', (lock_token, user_id)) + deleted = cursor.rowcount > 0 + conn.commit() + return deleted + finally: + conn.close() + return await self.run_in_executor(_remove) + + async def get_properties(self, resource_path: str) -> List[Dict]: + def _get(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute('SELECT * FROM properties WHERE resource_path = ?', (resource_path,)) + properties = cursor.fetchall() + return [dict(prop) for prop in properties] + finally: + conn.close() + return await self.run_in_executor(_get) + + async def set_property(self, resource_path: str, namespace: str, property_name: str, property_value: str): + def _set(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + 'INSERT OR REPLACE INTO properties (resource_path, namespace, property_name, property_value) VALUES (?, ?, ?, ?)', + (resource_path, namespace, property_name, property_value) + ) + conn.commit() + finally: + conn.close() + await self.run_in_executor(_set) + + async def remove_property(self, resource_path: str, namespace: str, property_name: str): + def _remove(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + 'DELETE FROM properties WHERE resource_path = ? AND namespace = ? 
AND property_name = ?', + (resource_path, namespace, property_name) + ) + conn.commit() + finally: + conn.close() + await self.run_in_executor(_remove) + +# ============================================================================ +# XML Utilities for WebDAV +# ============================================================================ + +class WebDAVXML: + """XML processing utilities for WebDAV protocol""" + NS = {'D': 'DAV:'} + + @staticmethod + def register_namespaces(): + for prefix, uri in WebDAVXML.NS.items(): + ET.register_namespace(prefix, uri) + + @staticmethod + def create_multistatus() -> ET.Element: + return ET.Element('{DAV:}multistatus') + + @staticmethod + def create_response(href: str) -> ET.Element: + response = ET.Element('{DAV:}response') + href_elem = ET.SubElement(response, '{DAV:}href') + href_elem.text = href + return response + + @staticmethod + def add_propstat(response: ET.Element, props: Dict[str, str], status: str = '200 OK'): + propstat = ET.SubElement(response, '{DAV:}propstat') + prop = ET.SubElement(propstat, '{DAV:}prop') + + is_collection = props.pop('_is_collection', False) + + for prop_name, prop_value in props.items(): + prop_elem = ET.SubElement(prop, prop_name) + if prop_name == '{DAV:}resourcetype' and is_collection: + ET.SubElement(prop_elem, '{DAV:}collection') + elif prop_value is not None: + prop_elem.text = str(prop_value) + + status_elem = ET.SubElement(propstat, '{DAV:}status') + status_elem.text = f'HTTP/1.1 {status}' + + @staticmethod + def serialize(element: ET.Element) -> str: + WebDAVXML.register_namespaces() + return ET.tostring(element, encoding='unicode', xml_declaration=True) + + @staticmethod + def parse_propfind(body: bytes) -> Tuple[str, List[str]]: + if not body: return 'allprop', [] + try: + root = ET.fromstring(body) + if root.find('.//{DAV:}allprop') is not None: return 'allprop', [] + if root.find('.//{DAV:}propname') is not None: return 'propname', [] + prop_elem = root.find('.//{DAV:}prop') + if prop_elem is not None: + return 'prop', [child.tag for child in prop_elem] + except ET.ParseError: + pass + return 'allprop', [] + + +# ============================================================================ +# Authentication and Authorization +# ============================================================================ + +class AuthHandler: + """Handle authentication methods""" + + def __init__(self, db: Database): + self.db = db + + async def authenticate_basic(self, request: web.Request) -> Optional[Dict]: + auth_header = request.headers.get('Authorization') + if not auth_header or not auth_header.startswith('Basic '): + return None + try: + auth_decoded = base64.b64decode(auth_header[6:]).decode() + username, password = auth_decoded.split(':', 1) + return await self.db.verify_user(username, password) + except (ValueError, UnicodeDecodeError): + return None + + async def authenticate(self, request: web.Request) -> Optional[Dict]: + if 'basic' in Config.AUTH_METHODS: + return await self.authenticate_basic(request) + return None + + def require_auth_response(self) -> web.Response: + return web.Response( + status=401, + headers={'WWW-Authenticate': 'Basic realm="WebDAV Server"'}, + text='Unauthorized' + ) + + +# ============================================================================ +# WebDAV Handler +# ============================================================================ + +class WebDAVHandler: + """Main WebDAV protocol handler""" + + def __init__(self, db: Database, auth: AuthHandler): + self.db = db + 
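+        # The metadata cache below is a plain dict keyed by "username:path";
+        # it is unbounded and shrinks only when _invalidate_cache_entry evicts
+        # a resource and its parent on mutating requests (PUT, DELETE, MKCOL,
+        # PROPPATCH, COPY, MOVE).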
self.auth = auth + self.metadata_cache = {} + self.cache_lock = asyncio.Lock() + WebDAVXML.register_namespaces() + + def get_user_root(self, username: str) -> Path: + return Path(Config.WEBDAV_ROOT) / 'users' / username + + def get_physical_path(self, username: str, webdav_path: str) -> Path: + webdav_path = unquote(webdav_path).lstrip('/') + user_root = self.get_user_root(username) + physical_path = (user_root / webdav_path).resolve() + + if user_root.resolve() not in physical_path.parents and physical_path != user_root.resolve(): + raise web.HTTPForbidden(text="Access denied outside of user root.") + + return physical_path + + async def run_blocking_io(self, func, *args, **kwargs): + fn = functools.partial(func, *args, **kwargs) + return await asyncio.get_event_loop().run_in_executor(None, fn) + + async def _invalidate_cache_entry(self, user: Dict, webdav_path: str): + """Invalidates a single entry and its parent from the cache.""" + key_prefix = f"{user['username']}:" + async with self.cache_lock: + # Invalidate the resource itself + if (key_prefix + webdav_path) in self.metadata_cache: + del self.metadata_cache[key_prefix + webdav_path] + + # Invalidate its parent directory + parent_path = str(Path(webdav_path).parent) + if (key_prefix + parent_path) in self.metadata_cache: + del self.metadata_cache[key_prefix + parent_path] + + async def handle_options(self, request: web.Request, user: Dict) -> web.Response: + return web.Response( + status=200, + headers={ + 'DAV': '1, 2', + 'MS-Author-Via': 'DAV', + 'Allow': 'OPTIONS, GET, HEAD, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK', + } + ) + + async def _generate_etag(self, path: Path) -> str: + """Generates an ETag for a file based on size and mtime.""" + try: + stat = await self.run_blocking_io(path.stat) + etag_data = f"{stat.st_size}-{stat.st_mtime_ns}" + return f'"{hashlib.sha1(etag_data.encode()).hexdigest()}"' + except FileNotFoundError: + return "" + + async def handle_get(self, request: web.Request, user: Dict) -> web.Response: + path = self.get_physical_path(user['username'], request.path) + + if not await self.run_blocking_io(path.exists): + raise web.HTTPNotFound() + + is_dir = await self.run_blocking_io(path.is_dir) + if is_dir: + raise web.HTTPForbidden(text="Directory listing not supported via GET.") + + etag = await self._generate_etag(path) + if etag and request.headers.get('If-None-Match') == etag: + return web.Response(status=304, headers={'ETag': etag}) + + async with aiofiles.open(path, 'rb') as f: + content = await f.read() + + content_type, _ = mimetypes.guess_type(str(path)) + return web.Response( + body=content, + content_type=content_type or 'application/octet-stream', + headers={'ETag': etag} + ) + + async def handle_head(self, request: web.Request, user: Dict) -> web.Response: + path = self.get_physical_path(user['username'], request.path) + + if not await self.run_blocking_io(path.exists): + raise web.HTTPNotFound() + + etag = await self._generate_etag(path) + if etag and request.headers.get('If-None-Match') == etag: + return web.Response(status=304, headers={'ETag': etag}) + + stat = await self.run_blocking_io(path.stat) + content_type, _ = mimetypes.guess_type(str(path)) + headers = { + 'Content-Type': content_type or 'application/octet-stream', + 'Content-Length': str(stat.st_size), + 'ETag': etag + } + return web.Response(headers=headers) + + async def handle_put(self, request: web.Request, user: Dict) -> web.Response: + await self._invalidate_cache_entry(user, request.path) + path = 
self.get_physical_path(user['username'], request.path) + + exists = await self.run_blocking_io(path.exists) + + await self.run_blocking_io(path.parent.mkdir, parents=True, exist_ok=True) + + async with aiofiles.open(path, 'wb') as f: + async for chunk in request.content.iter_chunked(8192): + await f.write(chunk) + + return web.Response(status=204 if exists else 201) + + async def handle_delete(self, request: web.Request, user: Dict) -> web.Response: + await self._invalidate_cache_entry(user, request.path) + path = self.get_physical_path(user['username'], request.path) + + if not await self.run_blocking_io(path.exists): + raise web.HTTPNotFound() + + try: + is_dir = await self.run_blocking_io(path.is_dir) + if is_dir: + await self.run_blocking_io(shutil.rmtree, path) + else: + await self.run_blocking_io(path.unlink) + return web.Response(status=204) + except OSError as e: + raise web.HTTPConflict(text=f"Cannot delete resource: {e}") + + async def handle_mkcol(self, request: web.Request, user: Dict) -> web.Response: + await self._invalidate_cache_entry(user, str(Path(request.path).parent)) + path = self.get_physical_path(user['username'], request.path) + + if await self.run_blocking_io(path.exists): + raise web.HTTPMethodNotAllowed(method='MKCOL', allowed_methods=[]) + + if not await self.run_blocking_io(path.parent.exists): + raise web.HTTPConflict() + + await self.run_blocking_io(path.mkdir) + return web.Response(status=201) + + async def get_resource_properties(self, path: Path, href: str, user: Dict) -> Dict[str, str]: + cache_key = f"{user['username']}:{href}" + async with self.cache_lock: + if cache_key in self.metadata_cache: + return self.metadata_cache[cache_key] + + try: + stat = await self.run_blocking_io(path.stat) + is_dir = await self.run_blocking_io(path.is_dir) + except FileNotFoundError: + return {} + + props = { + '{DAV:}displayname': path.name, + '{DAV:}creationdate': datetime.fromtimestamp(stat.st_ctime).isoformat() + "Z", + '{DAV:}getlastmodified': datetime.fromtimestamp(stat.st_mtime).strftime('%a, %d %b %Y %H:%M:%S GMT'), + '{DAV:}resourcetype': None, + '_is_collection': is_dir, + } + if not is_dir: + props['{DAV:}getcontentlength'] = str(stat.st_size) + content_type, _ = mimetypes.guess_type(str(path)) + props['{DAV:}getcontenttype'] = content_type or 'application/octet-stream' + + db_props = await self.db.get_properties(href) + for prop in db_props: + props[f"{{{prop['namespace']}}}{prop['property_name']}"] = prop['property_value'] + + async with self.cache_lock: + self.metadata_cache[cache_key] = props + return props + + async def add_resource_to_multistatus(self, multistatus: ET.Element, path: Path, href: str, user: Dict): + props = await self.get_resource_properties(path, href, user) + if props: + response = WebDAVXML.create_response(quote(href)) + WebDAVXML.add_propstat(response, props) + multistatus.append(response) + + async def handle_propfind(self, request: web.Request, user: Dict) -> web.Response: + path = self.get_physical_path(user['username'], request.path) + if not await self.run_blocking_io(path.exists): + raise web.HTTPNotFound() + + depth = request.headers.get('Depth', '1') + body = await request.read() + + multistatus = WebDAVXML.create_multistatus() + await self.add_resource_to_multistatus(multistatus, path, request.path, user) + + if depth == '1' and await self.run_blocking_io(path.is_dir): + for child_path in await self.run_blocking_io(list, path.iterdir()): + child_href = f"{request.path.rstrip('/')}/{child_path.name}" + await 
self.add_resource_to_multistatus(multistatus, child_path, child_href, user)
+
+        xml_response = WebDAVXML.serialize(multistatus)
+        return web.Response(status=207, content_type='application/xml', text=xml_response)
+
+    async def handle_proppatch(self, request: web.Request, user: Dict) -> web.Response:
+        await self._invalidate_cache_entry(user, request.path)
+        path = self.get_physical_path(user['username'], request.path)
+        if not await self.run_blocking_io(path.exists):
+            raise web.HTTPNotFound()
+
+        body = await request.read()
+        root = ET.fromstring(body)
+
+        for prop_action in root:
+            if prop_action.tag.endswith("set"):
+                for prop in prop_action.find('{DAV:}prop'):
+                    await self.db.set_property(request.path, prop.tag.split('}')[0][1:], prop.tag.split('}')[1], prop.text or "")
+            elif prop_action.tag.endswith("remove"):
+                for prop in prop_action.find('{DAV:}prop'):
+                    await self.db.remove_property(request.path, prop.tag.split('}')[0][1:], prop.tag.split('}')[1])
+
+        multistatus = WebDAVXML.create_multistatus()
+        await self.add_resource_to_multistatus(multistatus, path, request.path, user)
+        return web.Response(status=207, content_type='application/xml', text=WebDAVXML.serialize(multistatus))
+
+    async def handle_copy(self, request: web.Request, user: Dict) -> web.Response:
+        src_path = self.get_physical_path(user['username'], request.path)
+        dest_header = request.headers.get('Destination')
+        if not dest_header: raise web.HTTPBadRequest(text="Missing Destination header")
+
+        dest_path = self.get_physical_path(user['username'], urlparse(dest_header).path)
+        await self._invalidate_cache_entry(user, str(Path(urlparse(dest_header).path).parent))
+
+        overwrite = request.headers.get('Overwrite', 'T').upper() == 'T'
+        if await self.run_blocking_io(dest_path.exists) and not overwrite:
+            raise web.HTTPPreconditionFailed()
+
+        is_dir = await self.run_blocking_io(src_path.is_dir)
+        if is_dir:
+            await self.run_blocking_io(shutil.copytree, src_path, dest_path, dirs_exist_ok=overwrite)
+        else:
+            await self.run_blocking_io(shutil.copy2, src_path, dest_path)
+
+        return web.Response(status=201)
+
+    async def handle_move(self, request: web.Request, user: Dict) -> web.Response:
+        src_path = self.get_physical_path(user['username'], request.path)
+        dest_header = request.headers.get('Destination')
+        if not dest_header: raise web.HTTPBadRequest(text="Missing Destination header")
+
+        dest_path = self.get_physical_path(user['username'], urlparse(dest_header).path)
+        await self._invalidate_cache_entry(user, request.path)
+        await self._invalidate_cache_entry(user, urlparse(dest_header).path)
+
+        overwrite = request.headers.get('Overwrite', 'T').upper() == 'T'
+        if await self.run_blocking_io(dest_path.exists) and not overwrite:
+            raise web.HTTPPreconditionFailed()
+
+        await self.run_blocking_io(shutil.move, str(src_path), str(dest_path))
+        return web.Response(status=201)
+
+    async def handle_lock(self, request: web.Request, user: Dict) -> web.Response:
+        body = await request.read()
+        # Guard: lock-refresh requests may arrive without a body or owner element
+        owner_elem = ET.fromstring(body).find('.//{DAV:}owner') if body else None
+        owner_info = ET.tostring(owner_elem, encoding='unicode') if owner_elem is not None else '<D:owner>unknown</D:owner>'
+        timeout_header = request.headers.get('Timeout', f'Second-{Config.LOCK_TIMEOUT_DEFAULT}')
+        try:
+            timeout = int(timeout_header.split('-')[1])
+        except (IndexError, ValueError):
+            timeout = Config.LOCK_TIMEOUT_DEFAULT
+
+        lock_token = await self.db.create_lock(request.path, user['id'], timeout, owner_info)
+
+        response_xml = f'''<?xml version="1.0" encoding="utf-8" ?>
+<D:prop xmlns:D="DAV:">
+  <D:lockdiscovery>
+    <D:activelock>
+      <D:locktype><D:write/></D:locktype>
+      <D:lockscope><D:exclusive/></D:lockscope>
+      <D:depth>0</D:depth><D:timeout>Second-{timeout}</D:timeout>
+      <D:locktoken><D:href>{lock_token}</D:href></D:locktoken>
+      {owner_info}
+    </D:activelock>
+  </D:lockdiscovery>
+</D:prop>'''
+        return web.Response(status=200, content_type='application/xml', text=response_xml, headers={'Lock-Token': f'<{lock_token}>'})
+
+    async def 
handle_unlock(self, request: web.Request, user: Dict) -> web.Response: + lock_token = request.headers.get('Lock-Token', '').strip('<>') + if not lock_token: raise web.HTTPBadRequest(text="Missing Lock-Token header") + + if await self.db.remove_lock(lock_token, user['id']): + return web.Response(status=204) + else: + raise web.HTTPConflict(text="Lock not found or not owned by user") + + +# ============================================================================ +# Web Application +# ============================================================================ + +async def webdav_handler_func(request: web.Request): + """Main routing function for all WebDAV methods.""" + app = request.app + auth_handler: AuthHandler = app['auth'] + webdav_handler: WebDAVHandler = app['webdav'] + + # OPTIONS is often unauthenticated (pre-flight) + if request.method == 'OPTIONS': + return await webdav_handler.handle_options(request, {}) + + user = await auth_handler.authenticate(request) + if not user: + return auth_handler.require_auth_response() + + # Route to the correct handler based on method + method_map = { + 'GET': webdav_handler.handle_get, + 'HEAD': webdav_handler.handle_head, + 'PUT': webdav_handler.handle_put, + 'DELETE': webdav_handler.handle_delete, + 'MKCOL': webdav_handler.handle_mkcol, + 'PROPFIND': webdav_handler.handle_propfind, + 'PROPPATCH': webdav_handler.handle_proppatch, + 'COPY': webdav_handler.handle_copy, + 'MOVE': webdav_handler.handle_move, + 'LOCK': webdav_handler.handle_lock, + 'UNLOCK': webdav_handler.handle_unlock, + } + + handler = method_map.get(request.method) + if handler: + return await handler(request, user) + else: + raise web.HTTPMethodNotAllowed(method=request.method, allowed_methods=list(method_map.keys())) + + +async def init_app() -> web.Application: + """Initialize web application""" + app = web.Application(client_max_size=Config.MAX_FILE_SIZE) + + db = Database(Config.DB_PATH) + app['db'] = db + app['auth'] = AuthHandler(db) + app['webdav'] = WebDAVHandler(db, app['auth']) + + app.router.add_route('*', '/{path:.*}', webdav_handler_func) + return app + + +async def create_default_user(db: Database): + """Create default admin user if no users exist""" + def _check_user_exists(): + conn = db.get_connection() + try: + cursor = conn.cursor() + cursor.execute('SELECT COUNT(*) as count FROM users') + return cursor.fetchone()['count'] > 0 + finally: + conn.close() + + user_exists = await asyncio.get_event_loop().run_in_executor(None, _check_user_exists) + if not user_exists: + print("No users found. Creating default user 'admin' with password 'admin123'.") + await db.create_user('admin', 'admin123') + print("Default user created. Please change the password for security.") + + +def main(): + """Main entry point""" + Path(Config.WEBDAV_ROOT).mkdir(parents=True, exist_ok=True) + (Path(Config.WEBDAV_ROOT) / 'users').mkdir(exist_ok=True) + + db = Database(Config.DB_PATH) + asyncio.run(create_default_user(db)) + + app = asyncio.run(init_app()) + + print(f"Starting WebDAV Server on http://{Config.HOST}:{Config.PORT}") + web.run_app(app, host=Config.HOST, port=Config.PORT) + + +if __name__ == '__main__': + main() diff --git a/main3.py b/main3.py new file mode 100644 index 0000000..a3b1d00 --- /dev/null +++ b/main3.py @@ -0,0 +1,793 @@ +#!/usr/bin/env python3 +""" +Complete WebDAV Server Implementation with aiohttp +Production-ready WebDAV server with full RFC 4918 compliance, +Windows Explorer compatibility, and comprehensive user management. 
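+
+Configuration is read from environment variables (optionally via a .env
+file loaded with python-dotenv; see the Config class below). A minimal
+.env sketch; the values here are illustrative defaults, not requirements:
+
+    HOST=127.0.0.1
+    PORT=8080
+    WEBDAV_ROOT=./webdav
+    DB_PATH=./webdav.db
+    AUTH_METHODS=basic
+    MAX_FILE_SIZE=104857600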
+ +Includes multi-layered caching for high performance: +1. HTTP ETags for client-side caching. +2. "Prime on Write" metadata cache for consistently fast PROPFIND. +3. In-memory LRU cache for password hashing (Authentication). +4. Asynchronous handling of blocking file operations and stability fixes. +""" + +import os +import asyncio +import aiofiles +import sqlite3 +import hashlib +import hmac +import secrets +import mimetypes +import base64 +import functools +import shutil +from datetime import datetime +from pathlib import Path +from typing import Optional, Dict, List, Tuple +from xml.etree import ElementTree as ET +from urllib.parse import unquote, quote, urlparse + +from aiohttp import web +from aiohttp_session import setup as setup_session +from aiohttp_session.cookie_storage import EncryptedCookieStorage +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# ============================================================================ +# Configuration Management +# ============================================================================ + +class Config: + """Centralized configuration management from environment variables""" + + # Server Configuration + HOST = os.getenv('HOST', '0.0.0.0') + PORT = int(os.getenv('PORT', '8080')) + + # Database Configuration + DB_PATH = os.getenv('DB_PATH', './webdav.db') + + # Authentication Configuration + AUTH_METHODS = os.getenv('AUTH_METHODS', 'basic').split(',') + + # WebDAV Configuration + MAX_FILE_SIZE = int(os.getenv('MAX_FILE_SIZE', '104857600')) # 100MB + MAX_PROPFIND_DEPTH = int(os.getenv('MAX_PROPFIND_DEPTH', '3')) + LOCK_TIMEOUT_DEFAULT = int(os.getenv('LOCK_TIMEOUT_DEFAULT', '3600')) + + # WebDAV Root Directory + WEBDAV_ROOT = os.getenv('WEBDAV_ROOT', './webdav') + + +# ============================================================================ +# Database Layer +# ============================================================================ + +# This is the function we will cache. Caching works best on pure functions. +@functools.lru_cache(maxsize=128) +def _hash_password(password: str, salt: str) -> str: + """Hashes a password with a salt. 
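+    One call performs 100,000 rounds of PBKDF2-HMAC-SHA256.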
This is the expensive part.""" + return hashlib.pbkdf2_hmac('sha256', password.encode(), salt.encode(), 100000).hex() + + +class Database: + """SQLite database management with async wrapper""" + + def __init__(self, db_path: str): + self.db_path = db_path + self._connection_lock = asyncio.Lock() + self.init_database() + + def get_connection(self) -> sqlite3.Connection: + """Get database connection with row factory""" + conn = sqlite3.connect(self.db_path, timeout=30.0, check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.execute('PRAGMA journal_mode=WAL') + conn.execute('PRAGMA busy_timeout=30000') + conn.execute('PRAGMA synchronous=NORMAL') + return conn + + def init_database(self): + """Initialize database schema""" + conn = self.get_connection() + cursor = conn.cursor() + + # Users table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + salt TEXT NOT NULL, + is_active BOOLEAN DEFAULT 1 + ) + ''') + + # Locks table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS locks ( + lock_token TEXT PRIMARY KEY, + resource_path TEXT NOT NULL, + user_id INTEGER, + lock_type TEXT DEFAULT 'write', + lock_scope TEXT DEFAULT 'exclusive', + depth INTEGER DEFAULT 0, + timeout_seconds INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + owner TEXT, + FOREIGN KEY (user_id) REFERENCES users (id) + ) + ''') + + # Properties table + cursor.execute(''' + CREATE TABLE IF NOT EXISTS properties ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + resource_path TEXT NOT NULL, + namespace TEXT, + property_name TEXT NOT NULL, + property_value TEXT, + UNIQUE(resource_path, namespace, property_name) + ) + ''') + + cursor.execute('CREATE INDEX IF NOT EXISTS idx_locks_resource ON locks(resource_path)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_properties_resource ON properties(resource_path)') + + conn.commit() + conn.close() + + async def run_in_executor(self, func, *args): + """Run a synchronous database function in a thread pool.""" + return await asyncio.get_event_loop().run_in_executor(None, func, *args) + + async def create_user(self, username: str, password: str) -> int: + """Create a new user""" + salt = secrets.token_hex(16) + password_hash = _hash_password(password, salt) + + def _create(): + conn = self.get_connection() + cursor = conn.cursor() + try: + cursor.execute( + 'INSERT INTO users (username, password_hash, salt) VALUES (?, ?, ?)', + (username, password_hash, salt) + ) + user_id = cursor.lastrowid + conn.commit() + return user_id + finally: + conn.close() + + user_id = await self.run_in_executor(_create) + + user_dir = Path(Config.WEBDAV_ROOT) / 'users' / username + user_dir.mkdir(parents=True, exist_ok=True) + + return user_id + + def _get_user_from_db(self, username: str) -> Optional[Dict]: + """Fetches user data from the database.""" + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute('SELECT * FROM users WHERE username = ? 
AND is_active = 1', (username,)) + user = cursor.fetchone() + return dict(user) if user else None + finally: + conn.close() + + async def verify_user(self, username: str, password: str) -> Optional[Dict]: + """Verify user credentials using a cached hash function.""" + user_data = await self.run_in_executor(self._get_user_from_db, username) + if not user_data: + return None + + password_hash = _hash_password(password, user_data['salt']) + if hmac.compare_digest(password_hash, user_data['password_hash']): + return user_data + return None + + async def get_lock(self, resource_path: str) -> Optional[Dict]: + def _get(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute(''' + SELECT * FROM locks WHERE resource_path = ? + AND datetime(created_at, '+' || timeout_seconds || ' seconds') > datetime('now') + ''', (resource_path,)) + lock = cursor.fetchone() + return dict(lock) if lock else None + finally: + conn.close() + return await self.run_in_executor(_get) + + async def create_lock(self, resource_path: str, user_id: int, timeout: int, owner: str) -> str: + lock_token = f"opaquelocktoken:{secrets.token_urlsafe(16)}" + def _create(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + 'INSERT INTO locks (lock_token, resource_path, user_id, timeout_seconds, owner) VALUES (?, ?, ?, ?, ?)', + (lock_token, resource_path, user_id, timeout, owner) + ) + conn.commit() + return lock_token + finally: + conn.close() + return await self.run_in_executor(_create) + + async def remove_lock(self, lock_token: str, user_id: int) -> bool: + def _remove(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute('DELETE FROM locks WHERE lock_token = ? AND user_id = ?', (lock_token, user_id)) + deleted = cursor.rowcount > 0 + conn.commit() + return deleted + finally: + conn.close() + return await self.run_in_executor(_remove) + + async def get_properties(self, resource_path: str) -> List[Dict]: + def _get(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute('SELECT * FROM properties WHERE resource_path = ?', (resource_path,)) + properties = cursor.fetchall() + return [dict(prop) for prop in properties] + finally: + conn.close() + return await self.run_in_executor(_get) + + async def set_property(self, resource_path: str, namespace: str, property_name: str, property_value: str): + def _set(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + 'INSERT OR REPLACE INTO properties (resource_path, namespace, property_name, property_value) VALUES (?, ?, ?, ?)', + (resource_path, namespace, property_name, property_value) + ) + conn.commit() + finally: + conn.close() + await self.run_in_executor(_set) + + async def remove_property(self, resource_path: str, namespace: str, property_name: str): + def _remove(): + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute( + 'DELETE FROM properties WHERE resource_path = ? AND namespace = ? 
AND property_name = ?', + (resource_path, namespace, property_name) + ) + conn.commit() + finally: + conn.close() + await self.run_in_executor(_remove) + +# ============================================================================ +# XML Utilities for WebDAV +# ============================================================================ + +class WebDAVXML: + """XML processing utilities for WebDAV protocol""" + NS = {'D': 'DAV:'} + + @staticmethod + def register_namespaces(): + for prefix, uri in WebDAVXML.NS.items(): + ET.register_namespace(prefix, uri) + + @staticmethod + def create_multistatus() -> ET.Element: + return ET.Element('{DAV:}multistatus') + + @staticmethod + def create_response(href: str) -> ET.Element: + response = ET.Element('{DAV:}response') + href_elem = ET.SubElement(response, '{DAV:}href') + href_elem.text = href + return response + + @staticmethod + def add_propstat(response: ET.Element, props: Dict[str, str], status: str = '200 OK'): + propstat = ET.SubElement(response, '{DAV:}propstat') + prop = ET.SubElement(propstat, '{DAV:}prop') + + is_collection = props.pop('_is_collection', False) + + for prop_name, prop_value in props.items(): + prop_elem = ET.SubElement(prop, prop_name) + if prop_name == '{DAV:}resourcetype' and is_collection: + ET.SubElement(prop_elem, '{DAV:}collection') + elif prop_value is not None: + prop_elem.text = str(prop_value) + + status_elem = ET.SubElement(propstat, '{DAV:}status') + status_elem.text = f'HTTP/1.1 {status}' + + @staticmethod + def serialize(element: ET.Element) -> str: + WebDAVXML.register_namespaces() + return ET.tostring(element, encoding='unicode', xml_declaration=True) + + @staticmethod + def parse_propfind(body: bytes) -> Tuple[str, List[str]]: + if not body: return 'allprop', [] + try: + root = ET.fromstring(body) + if root.find('.//{DAV:}allprop') is not None: return 'allprop', [] + if root.find('.//{DAV:}propname') is not None: return 'propname', [] + prop_elem = root.find('.//{DAV:}prop') + if prop_elem is not None: + return 'prop', [child.tag for child in prop_elem] + except ET.ParseError: + pass + return 'allprop', [] + + +# ============================================================================ +# Authentication and Authorization +# ============================================================================ + +class AuthHandler: + """Handle authentication methods""" + + def __init__(self, db: Database): + self.db = db + + async def authenticate_basic(self, request: web.Request) -> Optional[Dict]: + auth_header = request.headers.get('Authorization') + if not auth_header or not auth_header.startswith('Basic '): + return None + try: + auth_decoded = base64.b64decode(auth_header[6:]).decode() + username, password = auth_decoded.split(':', 1) + return await self.db.verify_user(username, password) + except (ValueError, UnicodeDecodeError): + return None + + async def authenticate(self, request: web.Request) -> Optional[Dict]: + if 'basic' in Config.AUTH_METHODS: + return await self.authenticate_basic(request) + return None + + def require_auth_response(self) -> web.Response: + return web.Response( + status=401, + headers={'WWW-Authenticate': 'Basic realm="WebDAV Server"'}, + text='Unauthorized' + ) + + +# ============================================================================ +# WebDAV Handler +# ============================================================================ + +class WebDAVHandler: + """Main WebDAV protocol handler with prime-on-write caching""" + + def __init__(self, db: Database, auth: 
AuthHandler):
+        self.db = db
+        self.auth = auth
+        self.metadata_cache = {}
+        self.cache_lock = asyncio.Lock()
+        WebDAVXML.register_namespaces()
+
+    def get_user_root(self, username: str) -> Path:
+        return Path(Config.WEBDAV_ROOT) / 'users' / username
+
+    def get_physical_path(self, username: str, webdav_path: str) -> Path:
+        webdav_path = unquote(webdav_path).lstrip('/')
+        user_root = self.get_user_root(username)
+        physical_path = (user_root / webdav_path).resolve()
+
+        if user_root.resolve() not in physical_path.parents and physical_path != user_root.resolve():
+            raise web.HTTPForbidden(text="Access denied outside of user root.")
+
+        return physical_path
+
+    async def run_blocking_io(self, func, *args, **kwargs):
+        fn = functools.partial(func, *args, **kwargs)
+        return await asyncio.get_event_loop().run_in_executor(None, fn)
+
+    def get_cache_key(self, user: Dict, webdav_path: str) -> str:
+        return f"{user['id']}:{webdav_path}"
+
+    async def _invalidate_cache_entry(self, user: Dict, webdav_path: str):
+        """Invalidates a single entry and its parent from the cache."""
+        async with self.cache_lock:
+            key = self.get_cache_key(user, webdav_path)
+            if key in self.metadata_cache:
+                del self.metadata_cache[key]
+
+            parent_path = str(Path(webdav_path).parent)
+            parent_key = self.get_cache_key(user, parent_path)
+            if parent_key in self.metadata_cache:
+                del self.metadata_cache[parent_key]
+
+    async def _update_and_cache_properties(self, path: Path, href: str, user: Dict) -> Dict:
+        """Fetches properties for a resource and caches them."""
+        try:
+            stat = await self.run_blocking_io(path.stat)
+            is_dir = await self.run_blocking_io(path.is_dir)
+        except FileNotFoundError:
+            return {}
+
+        props = {
+            '{DAV:}displayname': path.name,
+            # Convert epoch seconds in UTC so the trailing "Z" and "GMT"
+            # designators are accurate regardless of the server's local zone.
+            '{DAV:}creationdate': datetime.utcfromtimestamp(stat.st_ctime).isoformat() + "Z",
+            '{DAV:}getlastmodified': datetime.utcfromtimestamp(stat.st_mtime).strftime('%a, %d %b %Y %H:%M:%S GMT'),
+            '{DAV:}resourcetype': None,
+            '_is_collection': is_dir,
+        }
+        if not is_dir:
+            props['{DAV:}getcontentlength'] = str(stat.st_size)
+            content_type, _ = mimetypes.guess_type(str(path))
+            props['{DAV:}getcontenttype'] = content_type or 'application/octet-stream'
+
+        db_props = await self.db.get_properties(href)
+        for prop in db_props:
+            props[f"{{{prop['namespace']}}}{prop['property_name']}"] = prop['property_value']
+
+        key = self.get_cache_key(user, href)
+        async with self.cache_lock:
+            self.metadata_cache[key] = props
+        return props
+
+    async def get_resource_properties(self, path: Path, href: str, user: Dict) -> Dict:
+        """Gets resource properties, using cache if available."""
+        key = self.get_cache_key(user, href)
+        async with self.cache_lock:
+            if key in self.metadata_cache:
+                return self.metadata_cache[key]
+
+        # On cache miss, fetch and update the cache
+        return await self._update_and_cache_properties(path, href, user)
+
+    async def handle_options(self, request: web.Request, user: Dict) -> web.Response:
+        return web.Response(
+            status=200,
+            headers={
+                'DAV': '1, 2',
+                'MS-Author-Via': 'DAV',
+                'Allow': 'OPTIONS, GET, HEAD, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK',
+            }
+        )
+
+    async def _generate_etag(self, path: Path) -> str:
+        """Generates an ETag for a file based on size and mtime."""
+        try:
+            stat = await self.run_blocking_io(path.stat)
+            etag_data = f"{stat.st_size}-{stat.st_mtime_ns}"
+            return f'"{hashlib.sha1(etag_data.encode()).hexdigest()}"'
+        except FileNotFoundError:
+            return ""
+
+    async def handle_get(self, request: web.Request, user: Dict) -> 
web.Response: + path = self.get_physical_path(user['username'], request.path) + if not await self.run_blocking_io(path.exists): + raise web.HTTPNotFound() + + if await self.run_blocking_io(path.is_dir): + raise web.HTTPForbidden(text="Directory listing not supported via GET.") + + etag = await self._generate_etag(path) + if etag and request.headers.get('If-None-Match') == etag: + return web.Response(status=304, headers={'ETag': etag}) + + async with aiofiles.open(path, 'rb') as f: + content = await f.read() + + content_type, _ = mimetypes.guess_type(str(path)) + return web.Response(body=content, content_type=content_type or 'application/octet-stream', headers={'ETag': etag}) + + async def handle_head(self, request: web.Request, user: Dict) -> web.Response: + path = self.get_physical_path(user['username'], request.path) + if not await self.run_blocking_io(path.exists): + raise web.HTTPNotFound() + + etag = await self._generate_etag(path) + if etag and request.headers.get('If-None-Match') == etag: + return web.Response(status=304, headers={'ETag': etag}) + + stat = await self.run_blocking_io(path.stat) + content_type, _ = mimetypes.guess_type(str(path)) + headers = {'Content-Type': content_type or 'application/octet-stream', 'Content-Length': str(stat.st_size), 'ETag': etag} + return web.Response(headers=headers) + + async def handle_put(self, request: web.Request, user: Dict) -> web.Response: + """Handles PUT requests with robust error handling and cache priming.""" + path = self.get_physical_path(user['username'], request.path) + + try: + exists = await self.run_blocking_io(path.exists) + + # Ensure parent directory exists + await self.run_blocking_io(path.parent.mkdir, parents=True, exist_ok=True) + + # Write the file content + async with aiofiles.open(path, 'wb') as f: + async for chunk in request.content.iter_chunked(8192): + await f.write(chunk) + + # After a successful write, prime the cache for the new file + await self._update_and_cache_properties(path, request.path, user) + + return web.Response(status=204 if exists else 201) + + finally: + # CRITICAL: Always invalidate the parent directory's cache after any PUT, + # even if the cache prime operation above fails. This prevents stale listings. 
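+            # Example: creating /docs/new.txt updates the mtime of /docs, so the
+            # cached props for /docs (notably {DAV:}getlastmodified) must be
+            # refetched on the next PROPFIND rather than served stale.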
+ await self._invalidate_cache_entry(user, str(Path(request.path).parent)) + + async def handle_delete(self, request: web.Request, user: Dict) -> web.Response: + path = self.get_physical_path(user['username'], request.path) + if not await self.run_blocking_io(path.exists): + raise web.HTTPNotFound() + + # Invalidate cache before deletion + await self._invalidate_cache_entry(user, request.path) + + try: + if await self.run_blocking_io(path.is_dir): + await self.run_blocking_io(shutil.rmtree, path) + else: + await self.run_blocking_io(path.unlink) + return web.Response(status=204) + except OSError as e: + raise web.HTTPConflict(text=f"Cannot delete resource: {e}") + + async def handle_mkcol(self, request: web.Request, user: Dict) -> web.Response: + path = self.get_physical_path(user['username'], request.path) + if await self.run_blocking_io(path.exists): + raise web.HTTPMethodNotAllowed(method='MKCOL', allowed_methods=[]) + + if not await self.run_blocking_io(path.parent.exists): + raise web.HTTPConflict() + + await self.run_blocking_io(path.mkdir) + + # Prime cache for the new directory and invalidate parent + await self._update_and_cache_properties(path, request.path, user) + await self._invalidate_cache_entry(user, str(Path(request.path).parent)) + + return web.Response(status=201) + + async def add_resource_to_multistatus(self, multistatus: ET.Element, path: Path, href: str, user: Dict): + props = await self.get_resource_properties(path, href, user) + if props: + response = WebDAVXML.create_response(quote(href)) + WebDAVXML.add_propstat(response, props) + multistatus.append(response) + + async def handle_propfind(self, request: web.Request, user: Dict) -> web.Response: + path = self.get_physical_path(user['username'], request.path) + if not await self.run_blocking_io(path.exists): + raise web.HTTPNotFound() + + depth = request.headers.get('Depth', '1') + multistatus = WebDAVXML.create_multistatus() + + # Add the resource itself + await self.add_resource_to_multistatus(multistatus, path, request.path, user) + + # Add children if depth=1 and it's a directory + if depth == '1' and await self.run_blocking_io(path.is_dir): + child_tasks = [] + for child_path in await self.run_blocking_io(list, path.iterdir()): + child_href = f"{request.path.rstrip('/')}/{child_path.name}" + task = self.add_resource_to_multistatus(multistatus, child_path, child_href, user) + child_tasks.append(task) + await asyncio.gather(*child_tasks) + + xml_response = WebDAVXML.serialize(multistatus) + return web.Response(status=207, content_type='application/xml', text=xml_response) + + async def handle_proppatch(self, request: web.Request, user: Dict) -> web.Response: + """Handles PROPPATCH with the typo fixed.""" + path = self.get_physical_path(user['username'], request.path) + if not await self.run_blocking_io(path.exists): + raise web.HTTPNotFound() + + body = await request.read() + root = ET.fromstring(body) + + for prop_action in root: + prop_container = prop_action.find('{DAV:}prop') + if prop_container is None: + continue + + for prop in prop_container: + # Correctly parse namespace and property name + tag_parts = prop.tag.split('}') + if len(tag_parts) != 2: + continue # Skip malformed tags + + ns = tag_parts[0][1:] + name = tag_parts[1] + + if prop_action.tag.endswith("set"): + await self.db.set_property(request.path, ns, name, prop.text or "") + elif prop_action.tag.endswith("remove"): + await self.db.remove_property(request.path, ns, name) + + # Invalidate and update cache after property change + await 
self._update_and_cache_properties(path, request.path, user)
+
+        multistatus = WebDAVXML.create_multistatus()
+        await self.add_resource_to_multistatus(multistatus, path, request.path, user)
+        return web.Response(status=207, content_type='application/xml', text=WebDAVXML.serialize(multistatus))
+
+
+    async def handle_copy(self, request: web.Request, user: Dict) -> web.Response:
+        src_path = self.get_physical_path(user['username'], request.path)
+        dest_header = request.headers.get('Destination')
+        if not dest_header: raise web.HTTPBadRequest(text="Missing Destination header")
+
+        dest_href = urlparse(dest_header).path
+        dest_path = self.get_physical_path(user['username'], dest_href)
+
+        overwrite = request.headers.get('Overwrite', 'T').upper() == 'T'
+        if await self.run_blocking_io(dest_path.exists) and not overwrite:
+            raise web.HTTPPreconditionFailed()
+
+        if await self.run_blocking_io(src_path.is_dir):
+            await self.run_blocking_io(shutil.copytree, src_path, dest_path, dirs_exist_ok=overwrite)
+        else:
+            await self.run_blocking_io(shutil.copy2, src_path, dest_path)
+
+        # Prime cache for destination and invalidate parent
+        await self._update_and_cache_properties(dest_path, dest_href, user)
+        await self._invalidate_cache_entry(user, str(Path(dest_href).parent))
+
+        return web.Response(status=201)
+
+    async def handle_move(self, request: web.Request, user: Dict) -> web.Response:
+        src_path = self.get_physical_path(user['username'], request.path)
+        dest_header = request.headers.get('Destination')
+        if not dest_header: raise web.HTTPBadRequest(text="Missing Destination header")
+
+        dest_href = urlparse(dest_header).path
+        dest_path = self.get_physical_path(user['username'], dest_href)
+
+        overwrite = request.headers.get('Overwrite', 'T').upper() == 'T'
+        if await self.run_blocking_io(dest_path.exists) and not overwrite:
+            raise web.HTTPPreconditionFailed()
+
+        # Invalidate source before move
+        await self._invalidate_cache_entry(user, request.path)
+
+        await self.run_blocking_io(shutil.move, str(src_path), str(dest_path))
+
+        # Prime cache for new destination and invalidate its parent
+        await self._update_and_cache_properties(dest_path, dest_href, user)
+        await self._invalidate_cache_entry(user, str(Path(dest_href).parent))
+
+        return web.Response(status=201)
+
+    async def handle_lock(self, request: web.Request, user: Dict) -> web.Response:
+        body = await request.read()
+        # An empty body is a lock refresh and carries no <D:owner> element.
+        owner_elem = ET.fromstring(body).find('.//{DAV:}owner') if body else None
+        owner_info = ET.tostring(owner_elem, encoding='unicode') if owner_elem is not None else ''
+        timeout_header = request.headers.get('Timeout', f'Second-{Config.LOCK_TIMEOUT_DEFAULT}')
+        try:
+            timeout = int(timeout_header.split('-')[1])
+        except (IndexError, ValueError):  # e.g. 'Infinite'
+            timeout = Config.LOCK_TIMEOUT_DEFAULT
+
+        lock_token = await self.db.create_lock(request.path, user['id'], timeout, owner_info)
+
+        # RFC 4918 lockdiscovery body advertising the new exclusive write lock.
+        response_xml = f'''<?xml version="1.0" encoding="utf-8"?>
+<D:prop xmlns:D="DAV:">
+  <D:lockdiscovery>
+    <D:activelock>
+      <D:locktype><D:write/></D:locktype>
+      <D:lockscope><D:exclusive/></D:lockscope>
+      <D:depth>0</D:depth>
+      <D:timeout>Second-{timeout}</D:timeout>
+      <D:locktoken><D:href>{lock_token}</D:href></D:locktoken>
+      {owner_info}
+    </D:activelock>
+  </D:lockdiscovery>
+</D:prop>'''
+        return web.Response(status=200, content_type='application/xml', text=response_xml, headers={'Lock-Token': f'<{lock_token}>'})
+
+    async def handle_unlock(self, request: web.Request, user: Dict) -> web.Response:
+        lock_token = request.headers.get('Lock-Token', '').strip('<>')
+        if not lock_token: raise web.HTTPBadRequest(text="Missing Lock-Token header")
+
+        if await self.db.remove_lock(lock_token, user['id']):
+            return web.Response(status=204)
+        else:
+            raise web.HTTPConflict(text="Lock not found or not owned by user")
+
+
+# ============================================================================
+# Web Application
+# ============================================================================
+
+async def webdav_handler_func(request: 
web.Request): + """Main routing function for all WebDAV methods.""" + app = request.app + auth_handler: AuthHandler = app['auth'] + webdav_handler: WebDAVHandler = app['webdav'] + + # OPTIONS is often unauthenticated (pre-flight) + if request.method == 'OPTIONS': + return await webdav_handler.handle_options(request, {}) + + user = await auth_handler.authenticate(request) + if not user: + return auth_handler.require_auth_response() + + # Route to the correct handler based on method + method_map = { + 'GET': webdav_handler.handle_get, + 'HEAD': webdav_handler.handle_head, + 'PUT': webdav_handler.handle_put, + 'DELETE': webdav_handler.handle_delete, + 'MKCOL': webdav_handler.handle_mkcol, + 'PROPFIND': webdav_handler.handle_propfind, + 'PROPPATCH': webdav_handler.handle_proppatch, + 'COPY': webdav_handler.handle_copy, + 'MOVE': webdav_handler.handle_move, + 'LOCK': webdav_handler.handle_lock, + 'UNLOCK': webdav_handler.handle_unlock, + } + + handler = method_map.get(request.method) + if handler: + return await handler(request, user) + else: + raise web.HTTPMethodNotAllowed(method=request.method, allowed_methods=list(method_map.keys())) + + +async def init_app() -> web.Application: + """Initialize web application""" + app = web.Application(client_max_size=Config.MAX_FILE_SIZE) + + db = Database(Config.DB_PATH) + app['db'] = db + app['auth'] = AuthHandler(db) + app['webdav'] = WebDAVHandler(db, app['auth']) + + app.router.add_route('*', '/{path:.*}', webdav_handler_func) + return app + + +async def create_default_user(db: Database): + """Create default admin user if no users exist""" + def _check_user_exists(): + conn = db.get_connection() + try: + cursor = conn.cursor() + cursor.execute('SELECT COUNT(*) as count FROM users') + return cursor.fetchone()['count'] > 0 + finally: + conn.close() + + user_exists = await asyncio.get_event_loop().run_in_executor(None, _check_user_exists) + if not user_exists: + print("No users found. Creating default user 'admin' with password 'admin123'.") + await db.create_user('admin', 'admin123') + print("Default user created. Please change the password for security.") + + +def main(): + """Main entry point""" + Path(Config.WEBDAV_ROOT).mkdir(parents=True, exist_ok=True) + (Path(Config.WEBDAV_ROOT) / 'users').mkdir(exist_ok=True) + + db = Database(Config.DB_PATH) + asyncio.run(create_default_user(db)) + + app = asyncio.run(init_app()) + + print(f"Starting WebDAV Server on http://{Config.HOST}:{Config.PORT}") + web.run_app(app, host=Config.HOST, port=Config.PORT) + + +if __name__ == '__main__': + main()
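
Editorial appendix: a minimal end-to-end smoke test for the server above. It assumes an instance running locally on the default port with the auto-created admin/admin123 account; the target file name hello.txt is illustrative. The script exercises PUT, PROPFIND, and the ETag / If-None-Match round-trip implemented in handle_get.

#!/usr/bin/env python3
"""Smoke test: PUT a file, list the root, then confirm the conditional
GET short-circuits with 304. Adjust URL and credentials to your setup."""

import asyncio

import aiohttp


async def main():
    auth = aiohttp.BasicAuth('admin', 'admin123')
    base = 'http://localhost:8080'

    async with aiohttp.ClientSession(auth=auth) as session:
        # Create (or overwrite) a small file; expect 201 on create, 204 on overwrite.
        async with session.put(f'{base}/hello.txt', data=b'hello webdav') as resp:
            print('PUT:', resp.status)

        # List the collection root; expect 207 Multi-Status.
        async with session.request('PROPFIND', base + '/', headers={'Depth': '1'}) as resp:
            await resp.read()
            print('PROPFIND:', resp.status)

        # First GET captures the ETag the server derives from size + mtime.
        async with session.get(f'{base}/hello.txt') as resp:
            await resp.read()
            etag = resp.headers.get('ETag')
            print('GET:', resp.status, 'ETag:', etag)

        # A conditional re-GET with the same ETag should return 304 Not Modified.
        if etag:
            async with session.get(f'{base}/hello.txt',
                                   headers={'If-None-Match': etag}) as resp:
                print('Conditional GET:', resp.status)


if __name__ == '__main__':
    asyncio.run(main())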