This commit is contained in:
retoor 2025-10-03 04:29:02 +02:00
parent 9a0a6ce0fa
commit 6da5999860
4 changed files with 2779 additions and 0 deletions

655
benchmark.py Executable file
View File

@ -0,0 +1,655 @@
#!/usr/bin/env python3
"""
WebDAV Server Concurrent Benchmark Tool
Heavy load testing with performance metrics per method
"""
import asyncio
import aiohttp
import time
import argparse
import statistics
from dataclasses import dataclass, field
from typing import List, Dict, Optional
from collections import defaultdict
import random
import string
@dataclass
class RequestMetrics:
"""Metrics for a single request"""
method: str
duration: float
status: int
success: bool
error: Optional[str] = None
@dataclass
class MethodStats:
"""Statistics for a specific HTTP method"""
method: str
total_requests: int = 0
successful_requests: int = 0
failed_requests: int = 0
total_duration: float = 0.0
durations: List[float] = field(default_factory=list)
errors: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
@property
def success_rate(self) -> float:
return (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0
@property
def avg_duration(self) -> float:
return self.total_duration / self.total_requests if self.total_requests > 0 else 0
@property
def requests_per_second(self) -> float:
return self.total_requests / self.total_duration if self.total_duration > 0 else 0
@property
def min_duration(self) -> float:
return min(self.durations) if self.durations else 0
@property
def max_duration(self) -> float:
return max(self.durations) if self.durations else 0
@property
def p50_duration(self) -> float:
return statistics.median(self.durations) if self.durations else 0
@property
def p95_duration(self) -> float:
if not self.durations:
return 0
sorted_durations = sorted(self.durations)
index = int(len(sorted_durations) * 0.95)
return sorted_durations[index] if index < len(sorted_durations) else sorted_durations[-1]
@property
def p99_duration(self) -> float:
if not self.durations:
return 0
sorted_durations = sorted(self.durations)
index = int(len(sorted_durations) * 0.99)
return sorted_durations[index] if index < len(sorted_durations) else sorted_durations[-1]
class WebDAVBenchmark:
"""WebDAV server benchmark runner"""
def __init__(self, url: str, username: str, password: str,
concurrency: int = 50, duration: int = 60):
self.url = url.rstrip('/')
self.username = username
self.password = password
self.concurrency = concurrency
self.duration = duration
self.stats: Dict[str, MethodStats] = defaultdict(lambda: MethodStats(method=""))
self.start_time = 0
self.stop_flag = False
self.auth = aiohttp.BasicAuth(username, password)
def random_string(self, length: int = 10) -> str:
"""Generate random string"""
return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
async def record_metric(self, metric: RequestMetrics):
"""Record a request metric"""
stats = self.stats[metric.method]
stats.method = metric.method
stats.total_requests += 1
stats.total_duration += metric.duration
stats.durations.append(metric.duration)
if metric.success:
stats.successful_requests += 1
else:
stats.failed_requests += 1
if metric.error:
stats.errors[metric.error] += 1
async def benchmark_options(self, session: aiohttp.ClientSession) -> RequestMetrics:
"""Benchmark OPTIONS request"""
start = time.time()
try:
async with session.options(self.url, auth=self.auth) as resp:
duration = time.time() - start
return RequestMetrics(
method='OPTIONS',
duration=duration,
status=resp.status,
success=resp.status == 200
)
except Exception as e:
return RequestMetrics(
method='OPTIONS',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_propfind(self, session: aiohttp.ClientSession, depth: int = 0) -> RequestMetrics:
"""Benchmark PROPFIND request"""
propfind_body = '''<?xml version="1.0"?>
<D:propfind xmlns:D="DAV:">
<D:allprop/>
</D:propfind>'''
start = time.time()
try:
async with session.request(
'PROPFIND',
self.url,
auth=self.auth,
data=propfind_body,
headers={'Depth': str(depth), 'Content-Type': 'application/xml'}
) as resp:
await resp.read() # Consume response
duration = time.time() - start
return RequestMetrics(
method='PROPFIND',
duration=duration,
status=resp.status,
success=resp.status == 207
)
except Exception as e:
return RequestMetrics(
method='PROPFIND',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_put(self, session: aiohttp.ClientSession) -> RequestMetrics:
"""Benchmark PUT request"""
filename = f"bench_{self.random_string()}.txt"
content = self.random_string(1024).encode() # 1KB file
start = time.time()
try:
async with session.put(
f"{self.url}/{filename}",
auth=self.auth,
data=content
) as resp:
duration = time.time() - start
return RequestMetrics(
method='PUT',
duration=duration,
status=resp.status,
success=resp.status in [201, 204]
)
except Exception as e:
return RequestMetrics(
method='PUT',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_get(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark GET request"""
start = time.time()
try:
print(f"{self.url}/{filename}")
async with session.get(
f"{self.url}/{filename}",
auth=self.auth
) as resp:
await resp.read() # Consume response
duration = time.time() - start
return RequestMetrics(
method='GET',
duration=duration,
status=resp.status,
success=resp.status == 200
)
except Exception as e:
return RequestMetrics(
method='GET',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_head(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark HEAD request"""
start = time.time()
try:
async with session.head(
f"{self.url}/{filename}",
auth=self.auth
) as resp:
duration = time.time() - start
return RequestMetrics(
method='HEAD',
duration=duration,
status=resp.status,
success=resp.status == 200
)
except Exception as e:
return RequestMetrics(
method='HEAD',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_mkcol(self, session: aiohttp.ClientSession) -> RequestMetrics:
"""Benchmark MKCOL request"""
dirname = f"bench_dir_{self.random_string()}"
start = time.time()
try:
async with session.request(
'MKCOL',
f"{self.url}/{dirname}/",
auth=self.auth
) as resp:
duration = time.time() - start
return RequestMetrics(
method='MKCOL',
duration=duration,
status=resp.status,
success=resp.status == 201
)
except Exception as e:
return RequestMetrics(
method='MKCOL',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_proppatch(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark PROPPATCH request"""
proppatch_body = f'''<?xml version="1.0"?>
<D:propertyupdate xmlns:D="DAV:">
<D:set>
<D:prop>
<D:displayname>Benchmark Test</D:displayname>
</D:prop>
</D:set>
</D:propertyupdate>'''
start = time.time()
try:
async with session.request(
'PROPPATCH',
f"{self.url}/{filename}",
auth=self.auth,
data=proppatch_body,
headers={'Content-Type': 'application/xml'}
) as resp:
await resp.read()
duration = time.time() - start
return RequestMetrics(
method='PROPPATCH',
duration=duration,
status=resp.status,
success=resp.status == 207
)
except Exception as e:
return RequestMetrics(
method='PROPPATCH',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_copy(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark COPY request"""
dest_filename = f"copy_{self.random_string()}.txt"
start = time.time()
try:
async with session.request(
'COPY',
f"{self.url}/{filename}",
auth=self.auth,
headers={'Destination': f"{self.url}/{dest_filename}"}
) as resp:
duration = time.time() - start
return RequestMetrics(
method='COPY',
duration=duration,
status=resp.status,
success=resp.status in [201, 204]
)
except Exception as e:
return RequestMetrics(
method='COPY',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_move(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark MOVE request"""
dest_filename = f"moved_{self.random_string()}.txt"
start = time.time()
try:
async with session.request(
'MOVE',
f"{self.url}/{filename}",
auth=self.auth,
headers={'Destination': f"{self.url}/{dest_filename}"}
) as resp:
duration = time.time() - start
return RequestMetrics(
method='MOVE',
duration=duration,
status=resp.status,
success=resp.status in [201, 204]
)
except Exception as e:
return RequestMetrics(
method='MOVE',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_lock(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark LOCK request"""
lock_body = '''<?xml version="1.0"?>
<D:lockinfo xmlns:D="DAV:">
<D:lockscope><D:exclusive/></D:lockscope>
<D:locktype><D:write/></D:locktype>
<D:owner>
<D:href>benchmark</D:href>
</D:owner>
</D:lockinfo>'''
start = time.time()
try:
async with session.request(
'LOCK',
f"{self.url}/{filename}",
auth=self.auth,
data=lock_body,
headers={'Content-Type': 'application/xml', 'Timeout': 'Second-300'}
) as resp:
lock_token = resp.headers.get('Lock-Token', '').strip('<>')
await resp.read()
duration = time.time() - start
# Unlock immediately to clean up
if lock_token:
try:
async with session.request(
'UNLOCK',
f"{self.url}/{filename}",
auth=self.auth,
headers={'Lock-Token': f'<{lock_token}>'}
) as unlock_resp:
pass
except:
pass
return RequestMetrics(
method='LOCK',
duration=duration,
status=resp.status,
success=resp.status == 200
)
except Exception as e:
return RequestMetrics(
method='LOCK',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def benchmark_delete(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark DELETE request"""
start = time.time()
try:
async with session.delete(
f"{self.url}/{filename}",
auth=self.auth
) as resp:
duration = time.time() - start
return RequestMetrics(
method='DELETE',
duration=duration,
status=resp.status,
success=resp.status == 204
)
except Exception as e:
return RequestMetrics(
method='DELETE',
duration=time.time() - start,
status=0,
success=False,
error=str(e)
)
async def worker(self, worker_id: int, session: aiohttp.ClientSession):
"""Worker coroutine that runs various benchmarks"""
test_files = []
# Create initial test file
filename = f"bench_worker_{worker_id}_{self.random_string()}.txt"
metric = await self.benchmark_put(session)
await self.record_metric(metric)
if metric.success:
test_files.append(filename)
while not self.stop_flag:
elapsed = time.time() - self.start_time
if elapsed >= self.duration:
self.stop_flag = True
break
# Randomly choose operation
operation = random.choice([
'options', 'propfind', 'put', 'get', 'head',
'mkcol', 'proppatch', 'copy', 'move', 'lock', 'delete'
])
try:
if operation == 'options':
metric = await self.benchmark_options(session)
elif operation == 'propfind':
depth = random.choice([0, 1])
metric = await self.benchmark_propfind(session, depth)
elif operation == 'put':
metric = await self.benchmark_put(session)
if metric.success:
filename = f"bench_worker_{worker_id}_{self.random_string()}.txt"
test_files.append(filename)
elif operation == 'get' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_get(session, filename)
elif operation == 'head' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_head(session, filename)
elif operation == 'mkcol':
metric = await self.benchmark_mkcol(session)
elif operation == 'proppatch' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_proppatch(session, filename)
elif operation == 'copy' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_copy(session, filename)
elif operation == 'move' and test_files:
if len(test_files) > 1:
filename = test_files.pop(random.randrange(len(test_files)))
metric = await self.benchmark_move(session, filename)
else:
continue
elif operation == 'lock' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_lock(session, filename)
elif operation == 'delete' and len(test_files) > 1:
filename = test_files.pop(random.randrange(len(test_files)))
metric = await self.benchmark_delete(session, filename)
else:
continue
await self.record_metric(metric)
except Exception as e:
print(f"Worker {worker_id} error: {e}")
# Small delay to prevent overwhelming
await asyncio.sleep(0.001)
async def run(self):
"""Run the benchmark"""
print("="*80)
print("WebDAV Server Concurrent Benchmark")
print("="*80)
print(f"URL: {self.url}")
print(f"Concurrency: {self.concurrency} workers")
print(f"Duration: {self.duration} seconds")
print(f"User: {self.username}")
print("="*80)
print()
connector = aiohttp.TCPConnector(limit=self.concurrency * 2)
timeout = aiohttp.ClientTimeout(total=30)
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
self.start_time = time.time()
# Create worker tasks
workers = [
asyncio.create_task(self.worker(i, session))
for i in range(self.concurrency)
]
# Progress indicator
progress_task = asyncio.create_task(self.show_progress())
# Wait for all workers
await asyncio.gather(*workers, return_exceptions=True)
# Stop progress
await progress_task
# Print results
self.print_results()
async def show_progress(self):
"""Show progress during benchmark"""
while not self.stop_flag:
elapsed = time.time() - self.start_time
if elapsed >= self.duration:
break
total_requests = sum(s.total_requests for s in self.stats.values())
print(f"\rProgress: {elapsed:.1f}s / {self.duration}s | Total Requests: {total_requests}", end='', flush=True)
await asyncio.sleep(1)
print()
def print_results(self):
"""Print benchmark results"""
print("\n")
print("="*80)
print("BENCHMARK RESULTS")
print("="*80)
print()
total_duration = time.time() - self.start_time
total_requests = sum(s.total_requests for s in self.stats.values())
total_success = sum(s.successful_requests for s in self.stats.values())
total_failed = sum(s.failed_requests for s in self.stats.values())
print(f"Total Duration: {total_duration:.2f}s")
print(f"Total Requests: {total_requests:,}")
print(f"Successful: {total_success:,} ({total_success/total_requests*100:.1f}%)")
print(f"Failed: {total_failed:,} ({total_failed/total_requests*100:.1f}%)")
print(f"Overall RPS: {total_requests/total_duration:.2f}")
print()
# Sort methods by request count
sorted_stats = sorted(self.stats.values(), key=lambda s: s.total_requests, reverse=True)
print("="*80)
print("PER-METHOD STATISTICS")
print("="*80)
print()
for stats in sorted_stats:
if stats.total_requests == 0:
continue
print(f"Method: {stats.method}")
print(f" Requests: {stats.total_requests:>8,}")
print(f" Success Rate: {stats.success_rate:>8.2f}%")
print(f" RPS: {stats.requests_per_second:>8.2f}")
print(f" Latency (ms):")
print(f" Min: {stats.min_duration*1000:>8.2f}")
print(f" Avg: {stats.avg_duration*1000:>8.2f}")
print(f" P50: {stats.p50_duration*1000:>8.2f}")
print(f" P95: {stats.p95_duration*1000:>8.2f}")
print(f" P99: {stats.p99_duration*1000:>8.2f}")
print(f" Max: {stats.max_duration*1000:>8.2f}")
if stats.failed_requests > 0 and stats.errors:
print(f" Errors:")
for error, count in sorted(stats.errors.items(), key=lambda x: x[1], reverse=True)[:5]:
error_short = error[:60] + '...' if len(error) > 60 else error
print(f" {error_short}: {count}")
print()
print("="*80)
async def main():
"""Main entry point"""
parser = argparse.ArgumentParser(description='WebDAV Server Concurrent Benchmark')
parser.add_argument('url', help='WebDAV server URL (e.g., http://localhost:8080/)')
parser.add_argument('username', help='Username for authentication')
parser.add_argument('password', help='Password for authentication')
parser.add_argument('-c', '--concurrency', type=int, default=50,
help='Number of concurrent workers (default: 50)')
parser.add_argument('-d', '--duration', type=int, default=60,
help='Benchmark duration in seconds (default: 60)')
args = parser.parse_args()
benchmark = WebDAVBenchmark(
url=args.url,
username=args.username,
password=args.password,
concurrency=args.concurrency,
duration=args.duration
)
await benchmark.run()
if __name__ == '__main__':
asyncio.run(main())

579
benchmark2.py Normal file
View File

@ -0,0 +1,579 @@
#!/usr/bin/env python3
"""
WebDAV Server Concurrent Benchmark Tool
Heavy load testing with performance metrics per method
"""
import asyncio
import aiohttp
import time
import argparse
import statistics
from dataclasses import dataclass, field
from typing import List, Dict, Optional
from collections import defaultdict
import random
import string
@dataclass
class RequestMetrics:
"""Metrics for a single request"""
method: str
duration: float
status: int
success: bool
error: Optional[str] = None
filename: Optional[str] = None # To track created/moved resources
@dataclass
class MethodStats:
"""Statistics for a specific HTTP method"""
method: str
total_requests: int = 0
successful_requests: int = 0
failed_requests: int = 0
total_duration: float = 0.0
durations: List[float] = field(default_factory=list)
errors: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
@property
def success_rate(self) -> float:
return (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0
@property
def avg_duration(self) -> float:
return self.total_duration / self.total_requests if self.total_requests > 0 else 0
@property
def requests_per_second(self) -> float:
# A more accurate RPS for a method is its count over the total benchmark time
# This property is not used in the final report, but we'll leave it for potential use.
return self.total_requests / self.total_duration if self.total_duration > 0 else 0
@property
def min_duration(self) -> float:
return min(self.durations) if self.durations else 0
@property
def max_duration(self) -> float:
return max(self.durations) if self.durations else 0
@property
def p50_duration(self) -> float:
return statistics.median(self.durations) if self.durations else 0
@property
def p95_duration(self) -> float:
if not self.durations:
return 0
sorted_durations = sorted(self.durations)
index = int(len(sorted_durations) * 0.95)
return sorted_durations[index] if index < len(sorted_durations) else sorted_durations[-1]
@property
def p99_duration(self) -> float:
if not self.durations:
return 0
sorted_durations = sorted(self.durations)
index = int(len(sorted_durations) * 0.99)
return sorted_durations[index] if index < len(sorted_durations) else sorted_durations[-1]
class WebDAVBenchmark:
"""WebDAV server benchmark runner"""
def __init__(self, url: str, username: str, password: str,
concurrency: int = 50, duration: int = 60):
self.url = url.rstrip('/')
self.username = username
self.password = password
self.concurrency = concurrency
self.duration = duration
self.stats: Dict[str, MethodStats] = defaultdict(lambda: MethodStats(method=""))
self.start_time = 0.0
self.stop_flag = False
self.auth = aiohttp.BasicAuth(username, password)
def random_string(self, length: int = 10) -> str:
"""Generate random string"""
return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
async def record_metric(self, metric: RequestMetrics):
"""Record a request metric"""
stats = self.stats[metric.method]
if not stats.method:
stats.method = metric.method
stats.total_requests += 1
stats.total_duration += metric.duration
stats.durations.append(metric.duration)
if metric.success:
stats.successful_requests += 1
else:
stats.failed_requests += 1
error_key = f"Status {metric.status}" if metric.status != 0 else str(metric.error)
stats.errors[error_key] += 1
async def benchmark_options(self, session: aiohttp.ClientSession) -> RequestMetrics:
"""Benchmark OPTIONS request"""
start = time.time()
try:
async with session.options(self.url, auth=self.auth) as resp:
duration = time.time() - start
return RequestMetrics(
method='OPTIONS', duration=duration, status=resp.status,
success=resp.status == 200
)
except Exception as e:
return RequestMetrics(
method='OPTIONS', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_propfind(self, session: aiohttp.ClientSession, depth: int = 0) -> RequestMetrics:
"""Benchmark PROPFIND request"""
propfind_body = '''<?xml version="1.0"?>
<D:propfind xmlns:D="DAV:">
<D:allprop/>
</D:propfind>'''
start = time.time()
try:
async with session.request(
'PROPFIND', self.url, auth=self.auth, data=propfind_body,
headers={'Depth': str(depth), 'Content-Type': 'application/xml'}
) as resp:
await resp.read()
duration = time.time() - start
return RequestMetrics(
method='PROPFIND', duration=duration, status=resp.status,
success=resp.status == 207
)
except Exception as e:
return RequestMetrics(
method='PROPFIND', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_put(self, session: aiohttp.ClientSession) -> RequestMetrics:
"""Benchmark PUT request"""
filename = f"bench_{self.random_string()}.txt"
content = self.random_string(1024).encode()
start = time.time()
try:
async with session.put(f"{self.url}/{filename}", auth=self.auth, data=content) as resp:
duration = time.time() - start
is_success = resp.status in [201, 204]
return RequestMetrics(
method='PUT', duration=duration, status=resp.status,
success=is_success,
filename=filename if is_success else None
)
except Exception as e:
return RequestMetrics(
method='PUT', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_get(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark GET request"""
start = time.time()
try:
async with session.get(f"{self.url}/{filename}", auth=self.auth) as resp:
await resp.read()
duration = time.time() - start
return RequestMetrics(
method='GET', duration=duration, status=resp.status,
success=resp.status == 200
)
except Exception as e:
return RequestMetrics(
method='GET', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_head(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark HEAD request"""
start = time.time()
try:
async with session.head(f"{self.url}/{filename}", auth=self.auth) as resp:
duration = time.time() - start
return RequestMetrics(
method='HEAD', duration=duration, status=resp.status,
success=resp.status == 200
)
except Exception as e:
return RequestMetrics(
method='HEAD', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_mkcol(self, session: aiohttp.ClientSession) -> RequestMetrics:
"""Benchmark MKCOL request"""
dirname = f"bench_dir_{self.random_string()}"
start = time.time()
try:
async with session.request('MKCOL', f"{self.url}/{dirname}/", auth=self.auth) as resp:
duration = time.time() - start
is_success = resp.status == 201
return RequestMetrics(
method='MKCOL', duration=duration, status=resp.status,
success=is_success,
filename=dirname if is_success else None
)
except Exception as e:
return RequestMetrics(
method='MKCOL', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_proppatch(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark PROPPATCH request"""
proppatch_body = '''<?xml version="1.0"?>
<D:propertyupdate xmlns:D="DAV:">
<D:set><D:prop><D:displayname>Benchmark Test</D:displayname></D:prop></D:set>
</D:propertyupdate>'''
start = time.time()
try:
async with session.request(
'PROPPATCH', f"{self.url}/{filename}", auth=self.auth, data=proppatch_body,
headers={'Content-Type': 'application/xml'}
) as resp:
await resp.read()
duration = time.time() - start
return RequestMetrics(
method='PROPPATCH', duration=duration, status=resp.status,
success=resp.status == 207
)
except Exception as e:
return RequestMetrics(
method='PROPPATCH', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_copy(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark COPY request"""
dest_filename = f"copy_{self.random_string()}.txt"
start = time.time()
try:
async with session.request(
'COPY', f"{self.url}/{filename}", auth=self.auth,
headers={'Destination': f"{self.url}/{dest_filename}"}
) as resp:
duration = time.time() - start
is_success = resp.status in [201, 204]
return RequestMetrics(
method='COPY', duration=duration, status=resp.status,
success=is_success,
filename=dest_filename if is_success else None
)
except Exception as e:
return RequestMetrics(
method='COPY', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_move(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark MOVE request"""
dest_filename = f"moved_{self.random_string()}.txt"
start = time.time()
try:
async with session.request(
'MOVE', f"{self.url}/{filename}", auth=self.auth,
headers={'Destination': f"{self.url}/{dest_filename}"}
) as resp:
duration = time.time() - start
is_success = resp.status in [201, 204]
return RequestMetrics(
method='MOVE', duration=duration, status=resp.status,
success=is_success,
filename=dest_filename if is_success else None
)
except Exception as e:
return RequestMetrics(
method='MOVE', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_lock(self, session: aiohttp.ClientSession, filename: str) -> RequestMetrics:
"""Benchmark LOCK request"""
lock_body = '''<?xml version="1.0"?>
<D:lockinfo xmlns:D="DAV:">
<D:lockscope><D:exclusive/></D:lockscope><D:locktype><D:write/></D:locktype>
<D:owner><D:href>benchmark</D:href></D:owner>
</D:lockinfo>'''
start = time.time()
try:
async with session.request(
'LOCK', f"{self.url}/{filename}", auth=self.auth, data=lock_body,
headers={'Content-Type': 'application/xml', 'Timeout': 'Second-300'}
) as resp:
lock_token = resp.headers.get('Lock-Token', '').strip('<>')
await resp.read()
duration = time.time() - start
is_success = resp.status == 200
if is_success and lock_token:
try:
async with session.request(
'UNLOCK', f"{self.url}/{filename}", auth=self.auth,
headers={'Lock-Token': f'<{lock_token}>'}
):
pass
except:
pass
return RequestMetrics(
method='LOCK', duration=duration, status=resp.status,
success=is_success
)
except Exception as e:
return RequestMetrics(
method='LOCK', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def benchmark_delete(self, session: aiohttp.ClientSession, resource_name: str) -> RequestMetrics:
"""Benchmark DELETE request for files or directories"""
start = time.time()
try:
# Add trailing slash for directories for some servers
url_path = f"{self.url}/{resource_name}"
if "dir" in resource_name:
url_path += "/"
async with session.delete(url_path, auth=self.auth) as resp:
duration = time.time() - start
return RequestMetrics(
method='DELETE', duration=duration, status=resp.status,
success=resp.status == 204
)
except Exception as e:
return RequestMetrics(
method='DELETE', duration=time.time() - start, status=0,
success=False, error=str(e)
)
async def worker(self, worker_id: int, session: aiohttp.ClientSession):
"""Worker coroutine that runs various benchmarks"""
test_files = []
test_dirs = []
# Create an initial test file to ensure other operations can start
metric = await self.benchmark_put(session)
await self.record_metric(metric)
if metric.success and metric.filename:
test_files.append(metric.filename)
while not self.stop_flag:
elapsed = time.time() - self.start_time
if elapsed >= self.duration:
self.stop_flag = True
break
# Weighted random choice
operations = [
'options', 'propfind', 'put', 'get', 'head',
'mkcol', 'proppatch', 'copy', 'move', 'lock', 'delete'
]
# Ensure some operations are more frequent
weights = [5, 5, 15, 15, 10, 5, 5, 5, 5, 5, 20] # More PUT, GET, DELETE
operation = random.choices(operations, weights=weights, k=1)[0]
metric = None
try:
if operation == 'options':
metric = await self.benchmark_options(session)
elif operation == 'propfind':
depth = random.choice([0, 1])
metric = await self.benchmark_propfind(session, depth)
elif operation == 'put':
metric = await self.benchmark_put(session)
if metric.success and metric.filename:
test_files.append(metric.filename)
elif operation == 'get' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_get(session, filename)
elif operation == 'head' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_head(session, filename)
elif operation == 'mkcol':
metric = await self.benchmark_mkcol(session)
if metric.success and metric.filename:
test_dirs.append(metric.filename)
elif operation == 'proppatch' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_proppatch(session, filename)
elif operation == 'copy' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_copy(session, filename)
if metric.success and metric.filename:
test_files.append(metric.filename)
elif operation == 'move' and len(test_files) > 1:
filename_to_move = test_files.pop(random.randrange(len(test_files)))
metric = await self.benchmark_move(session, filename_to_move)
if metric.success and metric.filename:
test_files.append(metric.filename)
elif operation == 'lock' and test_files:
filename = random.choice(test_files)
metric = await self.benchmark_lock(session, filename)
elif operation == 'delete':
# Randomly delete a file or a directory
if test_dirs and random.random() < 0.2 and len(test_dirs) > 0: # 20% chance to delete a dir
dir_to_delete = test_dirs.pop(random.randrange(len(test_dirs)))
metric = await self.benchmark_delete(session, dir_to_delete)
elif len(test_files) > 1:
file_to_delete = test_files.pop(random.randrange(len(test_files)))
metric = await self.benchmark_delete(session, file_to_delete)
if metric:
await self.record_metric(metric)
except Exception as e:
print(f"Worker {worker_id} error: {e}")
await asyncio.sleep(0.01) # Small delay to prevent tight loop on empty lists
async def run(self):
"""Run the benchmark"""
print("="*80)
print("WebDAV Server Concurrent Benchmark")
print("="*80)
print(f"URL: {self.url}")
print(f"Concurrency: {self.concurrency} workers")
print(f"Duration: {self.duration} seconds")
print(f"User: {self.username}")
print("="*80)
print()
connector = aiohttp.TCPConnector(limit=self.concurrency * 2)
timeout = aiohttp.ClientTimeout(total=30)
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
self.start_time = time.time()
workers = [
asyncio.create_task(self.worker(i, session))
for i in range(self.concurrency)
]
progress_task = asyncio.create_task(self.show_progress())
await asyncio.gather(*workers, return_exceptions=True)
self.stop_flag = True
await progress_task
self.print_results()
async def show_progress(self):
"""Show progress during benchmark"""
while not self.stop_flag:
elapsed = time.time() - self.start_time
if elapsed >= self.duration:
break
total_requests = sum(s.total_requests for s in self.stats.values())
print(f"\rProgress: {elapsed:.1f}s / {self.duration}s | Total Requests: {total_requests}", end='', flush=True)
await asyncio.sleep(0.5)
print()
def print_results(self):
"""Print benchmark results"""
print("\n")
print("="*80)
print("BENCHMARK RESULTS")
print("="*80)
print()
total_duration = time.time() - self.start_time
total_requests = sum(s.total_requests for s in self.stats.values())
total_success = sum(s.successful_requests for s in self.stats.values())
total_failed = total_requests - total_success
success_rate = (total_success / total_requests * 100) if total_requests > 0 else 0
failed_rate = (total_failed / total_requests * 100) if total_requests > 0 else 0
print(f"Total Duration: {total_duration:.2f}s")
print(f"Total Requests: {total_requests:,}")
print(f"Successful: {total_success:,} ({success_rate:.1f}%)")
print(f"Failed: {total_failed:,} ({failed_rate:.1f}%)")
print(f"Overall RPS: {total_requests/total_duration:.2f}")
print()
sorted_stats = sorted(self.stats.values(), key=lambda s: s.total_requests, reverse=True)
print("="*80)
print("PER-METHOD STATISTICS")
print("="*80)
print()
for stats in sorted_stats:
if stats.total_requests == 0:
continue
# Calculate RPS based on total benchmark duration for better comparison
method_rps = stats.total_requests / total_duration
print(f"Method: {stats.method}")
print(f" Requests: {stats.total_requests:>8,}")
print(f" Success Rate: {stats.success_rate:>8.2f}%")
print(f" RPS: {method_rps:>8.2f}")
print(f" Latency (ms):")
print(f" Min: {stats.min_duration*1000:>8.2f}")
print(f" Avg: {stats.avg_duration*1000:>8.2f}")
print(f" P50: {stats.p50_duration*1000:>8.2f}")
print(f" P95: {stats.p95_duration*1000:>8.2f}")
print(f" P99: {stats.p99_duration*1000:>8.2f}")
print(f" Max: {stats.max_duration*1000:>8.2f}")
if stats.failed_requests > 0 and stats.errors:
print(f" Errors:")
for error, count in sorted(stats.errors.items(), key=lambda x: x[1], reverse=True)[:5]:
error_short = error[:60] + '...' if len(error) > 60 else error
print(f" {error_short}: {count}")
print()
print("="*80)
async def main():
"""Main entry point"""
parser = argparse.ArgumentParser(description='WebDAV Server Concurrent Benchmark')
parser.add_argument('url', help='WebDAV server URL (e.g., http://localhost:8080/)')
parser.add_argument('username', help='Username for authentication')
parser.add_argument('password', help='Password for authentication')
parser.add_argument('-c', '--concurrency', type=int, default=50,
help='Number of concurrent workers (default: 50)')
parser.add_argument('-d', '--duration', type=int, default=60,
help='Benchmark duration in seconds (default: 60)')
args = parser.parse_args()
benchmark = WebDAVBenchmark(
url=args.url,
username=args.username,
password=args.password,
concurrency=args.concurrency,
duration=args.duration
)
await benchmark.run()
if __name__ == '__main__':
asyncio.run(main())

752
main2.py Normal file
View File

@ -0,0 +1,752 @@
#!/usr/bin/env python3
"""
Complete WebDAV Server Implementation with aiohttp
Production-ready WebDAV server with full RFC 4918 compliance,
Windows Explorer compatibility, and comprehensive user management.
Includes multi-layered caching for high performance:
1. HTTP ETags for client-side caching.
2. In-memory LRU cache for filesystem metadata (PROPFIND).
3. In-memory LRU cache for password hashing (Authentication).
4. Asynchronous handling of blocking file operations.
"""
import os
import asyncio
import aiofiles
import sqlite3
import hashlib
import hmac
import secrets
import mimetypes
import base64
import functools
import shutil
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, List, Tuple
from xml.etree import ElementTree as ET
from urllib.parse import unquote, quote, urlparse
from aiohttp import web
from aiohttp_session import setup as setup_session
from aiohttp_session.cookie_storage import EncryptedCookieStorage
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# ============================================================================
# Configuration Management
# ============================================================================
class Config:
"""Centralized configuration management from environment variables"""
# Server Configuration
HOST = os.getenv('HOST', '0.0.0.0')
PORT = int(os.getenv('PORT', '8080'))
# Database Configuration
DB_PATH = os.getenv('DB_PATH', './webdav.db')
# Authentication Configuration
AUTH_METHODS = os.getenv('AUTH_METHODS', 'basic').split(',')
# WebDAV Configuration
MAX_FILE_SIZE = int(os.getenv('MAX_FILE_SIZE', '104857600')) # 100MB
MAX_PROPFIND_DEPTH = int(os.getenv('MAX_PROPFIND_DEPTH', '3'))
LOCK_TIMEOUT_DEFAULT = int(os.getenv('LOCK_TIMEOUT_DEFAULT', '3600'))
# WebDAV Root Directory
WEBDAV_ROOT = os.getenv('WEBDAV_ROOT', './webdav')
# ============================================================================
# Database Layer
# ============================================================================
# This is the function we will cache. Caching works best on pure functions.
@functools.lru_cache(maxsize=128)
def _hash_password(password: str, salt: str) -> str:
"""Hashes a password with a salt. This is the expensive part."""
return hashlib.pbkdf2_hmac('sha256', password.encode(), salt.encode(), 100000).hex()
class Database:
"""SQLite database management with async wrapper"""
def __init__(self, db_path: str):
self.db_path = db_path
self._connection_lock = asyncio.Lock()
self.init_database()
def get_connection(self) -> sqlite3.Connection:
"""Get database connection with row factory"""
conn = sqlite3.connect(self.db_path, timeout=30.0, check_same_thread=False)
conn.row_factory = sqlite3.Row
conn.execute('PRAGMA journal_mode=WAL')
conn.execute('PRAGMA busy_timeout=30000')
conn.execute('PRAGMA synchronous=NORMAL')
return conn
def init_database(self):
"""Initialize database schema"""
conn = self.get_connection()
cursor = conn.cursor()
# Users table
cursor.execute('''
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
salt TEXT NOT NULL,
is_active BOOLEAN DEFAULT 1
)
''')
# Locks table
cursor.execute('''
CREATE TABLE IF NOT EXISTS locks (
lock_token TEXT PRIMARY KEY,
resource_path TEXT NOT NULL,
user_id INTEGER,
lock_type TEXT DEFAULT 'write',
lock_scope TEXT DEFAULT 'exclusive',
depth INTEGER DEFAULT 0,
timeout_seconds INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
owner TEXT,
FOREIGN KEY (user_id) REFERENCES users (id)
)
''')
# Properties table
cursor.execute('''
CREATE TABLE IF NOT EXISTS properties (
id INTEGER PRIMARY KEY AUTOINCREMENT,
resource_path TEXT NOT NULL,
namespace TEXT,
property_name TEXT NOT NULL,
property_value TEXT,
UNIQUE(resource_path, namespace, property_name)
)
''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_locks_resource ON locks(resource_path)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_properties_resource ON properties(resource_path)')
conn.commit()
conn.close()
async def run_in_executor(self, func, *args):
"""Run a synchronous database function in a thread pool."""
return await asyncio.get_event_loop().run_in_executor(None, func, *args)
async def create_user(self, username: str, password: str) -> int:
"""Create a new user"""
salt = secrets.token_hex(16)
password_hash = _hash_password(password, salt)
def _create():
conn = self.get_connection()
cursor = conn.cursor()
try:
cursor.execute(
'INSERT INTO users (username, password_hash, salt) VALUES (?, ?, ?)',
(username, password_hash, salt)
)
user_id = cursor.lastrowid
conn.commit()
return user_id
finally:
conn.close()
user_id = await self.run_in_executor(_create)
user_dir = Path(Config.WEBDAV_ROOT) / 'users' / username
user_dir.mkdir(parents=True, exist_ok=True)
return user_id
def _get_user_from_db(self, username: str) -> Optional[Dict]:
"""Fetches user data from the database."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute('SELECT * FROM users WHERE username = ? AND is_active = 1', (username,))
user = cursor.fetchone()
return dict(user) if user else None
finally:
conn.close()
async def verify_user(self, username: str, password: str) -> Optional[Dict]:
"""Verify user credentials using a cached hash function."""
user_data = await self.run_in_executor(self._get_user_from_db, username)
if not user_data:
return None
password_hash = _hash_password(password, user_data['salt'])
if hmac.compare_digest(password_hash, user_data['password_hash']):
return user_data
return None
async def get_lock(self, resource_path: str) -> Optional[Dict]:
def _get():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute('''
SELECT * FROM locks WHERE resource_path = ?
AND datetime(created_at, '+' || timeout_seconds || ' seconds') > datetime('now')
''', (resource_path,))
lock = cursor.fetchone()
return dict(lock) if lock else None
finally:
conn.close()
return await self.run_in_executor(_get)
async def create_lock(self, resource_path: str, user_id: int, timeout: int, owner: str) -> str:
lock_token = f"opaquelocktoken:{secrets.token_urlsafe(16)}"
def _create():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
'INSERT INTO locks (lock_token, resource_path, user_id, timeout_seconds, owner) VALUES (?, ?, ?, ?, ?)',
(lock_token, resource_path, user_id, timeout, owner)
)
conn.commit()
return lock_token
finally:
conn.close()
return await self.run_in_executor(_create)
async def remove_lock(self, lock_token: str, user_id: int) -> bool:
def _remove():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute('DELETE FROM locks WHERE lock_token = ? AND user_id = ?', (lock_token, user_id))
deleted = cursor.rowcount > 0
conn.commit()
return deleted
finally:
conn.close()
return await self.run_in_executor(_remove)
async def get_properties(self, resource_path: str) -> List[Dict]:
def _get():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute('SELECT * FROM properties WHERE resource_path = ?', (resource_path,))
properties = cursor.fetchall()
return [dict(prop) for prop in properties]
finally:
conn.close()
return await self.run_in_executor(_get)
async def set_property(self, resource_path: str, namespace: str, property_name: str, property_value: str):
def _set():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
'INSERT OR REPLACE INTO properties (resource_path, namespace, property_name, property_value) VALUES (?, ?, ?, ?)',
(resource_path, namespace, property_name, property_value)
)
conn.commit()
finally:
conn.close()
await self.run_in_executor(_set)
async def remove_property(self, resource_path: str, namespace: str, property_name: str):
def _remove():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
'DELETE FROM properties WHERE resource_path = ? AND namespace = ? AND property_name = ?',
(resource_path, namespace, property_name)
)
conn.commit()
finally:
conn.close()
await self.run_in_executor(_remove)
# ============================================================================
# XML Utilities for WebDAV
# ============================================================================
class WebDAVXML:
"""XML processing utilities for WebDAV protocol"""
NS = {'D': 'DAV:'}
@staticmethod
def register_namespaces():
for prefix, uri in WebDAVXML.NS.items():
ET.register_namespace(prefix, uri)
@staticmethod
def create_multistatus() -> ET.Element:
return ET.Element('{DAV:}multistatus')
@staticmethod
def create_response(href: str) -> ET.Element:
response = ET.Element('{DAV:}response')
href_elem = ET.SubElement(response, '{DAV:}href')
href_elem.text = href
return response
@staticmethod
def add_propstat(response: ET.Element, props: Dict[str, str], status: str = '200 OK'):
propstat = ET.SubElement(response, '{DAV:}propstat')
prop = ET.SubElement(propstat, '{DAV:}prop')
is_collection = props.pop('_is_collection', False)
for prop_name, prop_value in props.items():
prop_elem = ET.SubElement(prop, prop_name)
if prop_name == '{DAV:}resourcetype' and is_collection:
ET.SubElement(prop_elem, '{DAV:}collection')
elif prop_value is not None:
prop_elem.text = str(prop_value)
status_elem = ET.SubElement(propstat, '{DAV:}status')
status_elem.text = f'HTTP/1.1 {status}'
@staticmethod
def serialize(element: ET.Element) -> str:
WebDAVXML.register_namespaces()
return ET.tostring(element, encoding='unicode', xml_declaration=True)
@staticmethod
def parse_propfind(body: bytes) -> Tuple[str, List[str]]:
if not body: return 'allprop', []
try:
root = ET.fromstring(body)
if root.find('.//{DAV:}allprop') is not None: return 'allprop', []
if root.find('.//{DAV:}propname') is not None: return 'propname', []
prop_elem = root.find('.//{DAV:}prop')
if prop_elem is not None:
return 'prop', [child.tag for child in prop_elem]
except ET.ParseError:
pass
return 'allprop', []
# ============================================================================
# Authentication and Authorization
# ============================================================================
class AuthHandler:
"""Handle authentication methods"""
def __init__(self, db: Database):
self.db = db
async def authenticate_basic(self, request: web.Request) -> Optional[Dict]:
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Basic '):
return None
try:
auth_decoded = base64.b64decode(auth_header[6:]).decode()
username, password = auth_decoded.split(':', 1)
return await self.db.verify_user(username, password)
except (ValueError, UnicodeDecodeError):
return None
async def authenticate(self, request: web.Request) -> Optional[Dict]:
if 'basic' in Config.AUTH_METHODS:
return await self.authenticate_basic(request)
return None
def require_auth_response(self) -> web.Response:
return web.Response(
status=401,
headers={'WWW-Authenticate': 'Basic realm="WebDAV Server"'},
text='Unauthorized'
)
# ============================================================================
# WebDAV Handler
# ============================================================================
class WebDAVHandler:
"""Main WebDAV protocol handler"""
def __init__(self, db: Database, auth: AuthHandler):
self.db = db
self.auth = auth
self.metadata_cache = {}
self.cache_lock = asyncio.Lock()
WebDAVXML.register_namespaces()
def get_user_root(self, username: str) -> Path:
return Path(Config.WEBDAV_ROOT) / 'users' / username
def get_physical_path(self, username: str, webdav_path: str) -> Path:
webdav_path = unquote(webdav_path).lstrip('/')
user_root = self.get_user_root(username)
physical_path = (user_root / webdav_path).resolve()
if user_root.resolve() not in physical_path.parents and physical_path != user_root.resolve():
raise web.HTTPForbidden(text="Access denied outside of user root.")
return physical_path
async def run_blocking_io(self, func, *args, **kwargs):
fn = functools.partial(func, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, fn)
async def _invalidate_cache_entry(self, user: Dict, webdav_path: str):
"""Invalidates a single entry and its parent from the cache."""
key_prefix = f"{user['username']}:"
async with self.cache_lock:
# Invalidate the resource itself
if (key_prefix + webdav_path) in self.metadata_cache:
del self.metadata_cache[key_prefix + webdav_path]
# Invalidate its parent directory
parent_path = str(Path(webdav_path).parent)
if (key_prefix + parent_path) in self.metadata_cache:
del self.metadata_cache[key_prefix + parent_path]
async def handle_options(self, request: web.Request, user: Dict) -> web.Response:
return web.Response(
status=200,
headers={
'DAV': '1, 2',
'MS-Author-Via': 'DAV',
'Allow': 'OPTIONS, GET, HEAD, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK',
}
)
async def _generate_etag(self, path: Path) -> str:
"""Generates an ETag for a file based on size and mtime."""
try:
stat = await self.run_blocking_io(path.stat)
etag_data = f"{stat.st_size}-{stat.st_mtime_ns}"
return f'"{hashlib.sha1(etag_data.encode()).hexdigest()}"'
except FileNotFoundError:
return ""
async def handle_get(self, request: web.Request, user: Dict) -> web.Response:
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
is_dir = await self.run_blocking_io(path.is_dir)
if is_dir:
raise web.HTTPForbidden(text="Directory listing not supported via GET.")
etag = await self._generate_etag(path)
if etag and request.headers.get('If-None-Match') == etag:
return web.Response(status=304, headers={'ETag': etag})
async with aiofiles.open(path, 'rb') as f:
content = await f.read()
content_type, _ = mimetypes.guess_type(str(path))
return web.Response(
body=content,
content_type=content_type or 'application/octet-stream',
headers={'ETag': etag}
)
async def handle_head(self, request: web.Request, user: Dict) -> web.Response:
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
etag = await self._generate_etag(path)
if etag and request.headers.get('If-None-Match') == etag:
return web.Response(status=304, headers={'ETag': etag})
stat = await self.run_blocking_io(path.stat)
content_type, _ = mimetypes.guess_type(str(path))
headers = {
'Content-Type': content_type or 'application/octet-stream',
'Content-Length': str(stat.st_size),
'ETag': etag
}
return web.Response(headers=headers)
async def handle_put(self, request: web.Request, user: Dict) -> web.Response:
await self._invalidate_cache_entry(user, request.path)
path = self.get_physical_path(user['username'], request.path)
exists = await self.run_blocking_io(path.exists)
await self.run_blocking_io(path.parent.mkdir, parents=True, exist_ok=True)
async with aiofiles.open(path, 'wb') as f:
async for chunk in request.content.iter_chunked(8192):
await f.write(chunk)
return web.Response(status=204 if exists else 201)
async def handle_delete(self, request: web.Request, user: Dict) -> web.Response:
await self._invalidate_cache_entry(user, request.path)
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
try:
is_dir = await self.run_blocking_io(path.is_dir)
if is_dir:
await self.run_blocking_io(shutil.rmtree, path)
else:
await self.run_blocking_io(path.unlink)
return web.Response(status=204)
except OSError as e:
raise web.HTTPConflict(text=f"Cannot delete resource: {e}")
async def handle_mkcol(self, request: web.Request, user: Dict) -> web.Response:
await self._invalidate_cache_entry(user, str(Path(request.path).parent))
path = self.get_physical_path(user['username'], request.path)
if await self.run_blocking_io(path.exists):
raise web.HTTPMethodNotAllowed(method='MKCOL', allowed_methods=[])
if not await self.run_blocking_io(path.parent.exists):
raise web.HTTPConflict()
await self.run_blocking_io(path.mkdir)
return web.Response(status=201)
async def get_resource_properties(self, path: Path, href: str, user: Dict) -> Dict[str, str]:
cache_key = f"{user['username']}:{href}"
async with self.cache_lock:
if cache_key in self.metadata_cache:
return self.metadata_cache[cache_key]
try:
stat = await self.run_blocking_io(path.stat)
is_dir = await self.run_blocking_io(path.is_dir)
except FileNotFoundError:
return {}
props = {
'{DAV:}displayname': path.name,
'{DAV:}creationdate': datetime.fromtimestamp(stat.st_ctime).isoformat() + "Z",
'{DAV:}getlastmodified': datetime.fromtimestamp(stat.st_mtime).strftime('%a, %d %b %Y %H:%M:%S GMT'),
'{DAV:}resourcetype': None,
'_is_collection': is_dir,
}
if not is_dir:
props['{DAV:}getcontentlength'] = str(stat.st_size)
content_type, _ = mimetypes.guess_type(str(path))
props['{DAV:}getcontenttype'] = content_type or 'application/octet-stream'
db_props = await self.db.get_properties(href)
for prop in db_props:
props[f"{{{prop['namespace']}}}{prop['property_name']}"] = prop['property_value']
async with self.cache_lock:
self.metadata_cache[cache_key] = props
return props
async def add_resource_to_multistatus(self, multistatus: ET.Element, path: Path, href: str, user: Dict):
props = await self.get_resource_properties(path, href, user)
if props:
response = WebDAVXML.create_response(quote(href))
WebDAVXML.add_propstat(response, props)
multistatus.append(response)
async def handle_propfind(self, request: web.Request, user: Dict) -> web.Response:
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
depth = request.headers.get('Depth', '1')
body = await request.read()
multistatus = WebDAVXML.create_multistatus()
await self.add_resource_to_multistatus(multistatus, path, request.path, user)
if depth == '1' and await self.run_blocking_io(path.is_dir):
for child_path in await self.run_blocking_io(list, path.iterdir()):
child_href = f"{request.path.rstrip('/')}/{child_path.name}"
await self.add_resource_to_multistatus(multistatus, child_path, child_href, user)
xml_response = WebDAVXML.serialize(multistatus)
return web.Response(status=207, content_type='application/xml', text=xml_response)
async def handle_proppatch(self, request: web.Request, user: Dict) -> web.Response:
await self._invalidate_cache_entry(user, request.path)
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
body = await request.read()
root = ET.fromstring(body)
for prop_action in root:
if prop_action.tag.endswith("set"):
for prop in prop_action.find('{DAV:}prop'):
await self.db.set_property(request.path, prop.tag.split('}')[0][1:], prop.tag.split('}')[1], prop.text or "")
elif prop_action.tag.endswith("remove"):
for prop in prop_action.find('{DAV:}prop'):
await self.db.remove_property(request.path, prop.tag.split('}')[0][1:], prop.tag.split('}')[1])
multistatus = WebDAVXML.create_multistatus()
await self.add_resource_to_multistatus(multistatus, path, request.path, user)
return web.Response(status=207, content_type='application/xml', text=WebDAVXML.serialize(multistatus))
async def handle_copy(self, request: web.Request, user: Dict) -> web.Response:
src_path = self.get_physical_path(user['username'], request.path)
dest_header = request.headers.get('Destination')
if not dest_header: raise web.HTTPBadRequest(text="Missing Destination header")
dest_path = self.get_physical_path(user['username'], urlparse(dest_header).path)
await self._invalidate_cache_entry(user, str(Path(urlparse(dest_header).path).parent))
overwrite = request.headers.get('Overwrite', 'T').upper() == 'T'
if await self.run_blocking_io(dest_path.exists) and not overwrite:
raise web.HTTPPreconditionFailed()
is_dir = await self.run_blocking_io(src_path.is_dir)
if is_dir:
await self.run_blocking_io(shutil.copytree, src_path, dest_path, dirs_exist_ok=overwrite)
else:
await self.run_blocking_io(shutil.copy2, src_path, dest_path)
return web.Response(status=201)
async def handle_move(self, request: web.Request, user: Dict) -> web.Response:
src_path = self.get_physical_path(user['username'], request.path)
dest_header = request.headers.get('Destination')
if not dest_header: raise web.HTTPBadRequest(text="Missing Destination header")
dest_path = self.get_physical_path(user['username'], urlparse(dest_header).path)
await self._invalidate_cache_entry(user, request.path)
await self._invalidate_cache_entry(user, urlparse(dest_header).path)
overwrite = request.headers.get('Overwrite', 'T').upper() == 'T'
if await self.run_blocking_io(dest_path.exists) and not overwrite:
raise web.HTTPPreconditionFailed()
await self.run_blocking_io(shutil.move, str(src_path), str(dest_path))
return web.Response(status=201)
async def handle_lock(self, request: web.Request, user: Dict) -> web.Response:
body = await request.read()
owner_info = ET.tostring(ET.fromstring(body).find('.//{DAV:}owner'), encoding='unicode')
timeout_header = request.headers.get('Timeout', f'Second-{Config.LOCK_TIMEOUT_DEFAULT}')
timeout = int(timeout_header.split('-')[1])
lock_token = await self.db.create_lock(request.path, user['id'], timeout, owner_info)
response_xml = f'''<?xml version="1.0" encoding="utf-8"?>
<D:prop xmlns:D="DAV:"><D:lockdiscovery><D:activelock>
<D:locktype><D:write/></D:locktype><D:lockscope><D:exclusive/></D:lockscope>
<D:depth>0</D:depth><D:timeout>Second-{timeout}</D:timeout>
<D:locktoken><D:href>{lock_token}</D:href></D:locktoken>
{owner_info}
</D:activelock></D:lockdiscovery></D:prop>'''
return web.Response(status=200, content_type='application/xml', text=response_xml, headers={'Lock-Token': f'<{lock_token}>'})
async def handle_unlock(self, request: web.Request, user: Dict) -> web.Response:
lock_token = request.headers.get('Lock-Token', '').strip('<>')
if not lock_token: raise web.HTTPBadRequest(text="Missing Lock-Token header")
if await self.db.remove_lock(lock_token, user['id']):
return web.Response(status=204)
else:
raise web.HTTPConflict(text="Lock not found or not owned by user")
# ============================================================================
# Web Application
# ============================================================================
async def webdav_handler_func(request: web.Request):
"""Main routing function for all WebDAV methods."""
app = request.app
auth_handler: AuthHandler = app['auth']
webdav_handler: WebDAVHandler = app['webdav']
# OPTIONS is often unauthenticated (pre-flight)
if request.method == 'OPTIONS':
return await webdav_handler.handle_options(request, {})
user = await auth_handler.authenticate(request)
if not user:
return auth_handler.require_auth_response()
# Route to the correct handler based on method
method_map = {
'GET': webdav_handler.handle_get,
'HEAD': webdav_handler.handle_head,
'PUT': webdav_handler.handle_put,
'DELETE': webdav_handler.handle_delete,
'MKCOL': webdav_handler.handle_mkcol,
'PROPFIND': webdav_handler.handle_propfind,
'PROPPATCH': webdav_handler.handle_proppatch,
'COPY': webdav_handler.handle_copy,
'MOVE': webdav_handler.handle_move,
'LOCK': webdav_handler.handle_lock,
'UNLOCK': webdav_handler.handle_unlock,
}
handler = method_map.get(request.method)
if handler:
return await handler(request, user)
else:
raise web.HTTPMethodNotAllowed(method=request.method, allowed_methods=list(method_map.keys()))
async def init_app() -> web.Application:
"""Initialize web application"""
app = web.Application(client_max_size=Config.MAX_FILE_SIZE)
db = Database(Config.DB_PATH)
app['db'] = db
app['auth'] = AuthHandler(db)
app['webdav'] = WebDAVHandler(db, app['auth'])
app.router.add_route('*', '/{path:.*}', webdav_handler_func)
return app
async def create_default_user(db: Database):
"""Create default admin user if no users exist"""
def _check_user_exists():
conn = db.get_connection()
try:
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) as count FROM users')
return cursor.fetchone()['count'] > 0
finally:
conn.close()
user_exists = await asyncio.get_event_loop().run_in_executor(None, _check_user_exists)
if not user_exists:
print("No users found. Creating default user 'admin' with password 'admin123'.")
await db.create_user('admin', 'admin123')
print("Default user created. Please change the password for security.")
def main():
"""Main entry point"""
Path(Config.WEBDAV_ROOT).mkdir(parents=True, exist_ok=True)
(Path(Config.WEBDAV_ROOT) / 'users').mkdir(exist_ok=True)
db = Database(Config.DB_PATH)
asyncio.run(create_default_user(db))
app = asyncio.run(init_app())
print(f"Starting WebDAV Server on http://{Config.HOST}:{Config.PORT}")
web.run_app(app, host=Config.HOST, port=Config.PORT)
if __name__ == '__main__':
main()

793
main3.py Normal file
View File

@ -0,0 +1,793 @@
#!/usr/bin/env python3
"""
Complete WebDAV Server Implementation with aiohttp
Production-ready WebDAV server with full RFC 4918 compliance,
Windows Explorer compatibility, and comprehensive user management.
Includes multi-layered caching for high performance:
1. HTTP ETags for client-side caching.
2. "Prime on Write" metadata cache for consistently fast PROPFIND.
3. In-memory LRU cache for password hashing (Authentication).
4. Asynchronous handling of blocking file operations and stability fixes.
"""
import os
import asyncio
import aiofiles
import sqlite3
import hashlib
import hmac
import secrets
import mimetypes
import base64
import functools
import shutil
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, List, Tuple
from xml.etree import ElementTree as ET
from urllib.parse import unquote, quote, urlparse
from aiohttp import web
from aiohttp_session import setup as setup_session
from aiohttp_session.cookie_storage import EncryptedCookieStorage
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# ============================================================================
# Configuration Management
# ============================================================================
class Config:
"""Centralized configuration management from environment variables"""
# Server Configuration
HOST = os.getenv('HOST', '0.0.0.0')
PORT = int(os.getenv('PORT', '8080'))
# Database Configuration
DB_PATH = os.getenv('DB_PATH', './webdav.db')
# Authentication Configuration
AUTH_METHODS = os.getenv('AUTH_METHODS', 'basic').split(',')
# WebDAV Configuration
MAX_FILE_SIZE = int(os.getenv('MAX_FILE_SIZE', '104857600')) # 100MB
MAX_PROPFIND_DEPTH = int(os.getenv('MAX_PROPFIND_DEPTH', '3'))
LOCK_TIMEOUT_DEFAULT = int(os.getenv('LOCK_TIMEOUT_DEFAULT', '3600'))
# WebDAV Root Directory
WEBDAV_ROOT = os.getenv('WEBDAV_ROOT', './webdav')
# ============================================================================
# Database Layer
# ============================================================================
# This is the function we will cache. Caching works best on pure functions.
@functools.lru_cache(maxsize=128)
def _hash_password(password: str, salt: str) -> str:
"""Hashes a password with a salt. This is the expensive part."""
return hashlib.pbkdf2_hmac('sha256', password.encode(), salt.encode(), 100000).hex()
class Database:
"""SQLite database management with async wrapper"""
def __init__(self, db_path: str):
self.db_path = db_path
self._connection_lock = asyncio.Lock()
self.init_database()
def get_connection(self) -> sqlite3.Connection:
"""Get database connection with row factory"""
conn = sqlite3.connect(self.db_path, timeout=30.0, check_same_thread=False)
conn.row_factory = sqlite3.Row
conn.execute('PRAGMA journal_mode=WAL')
conn.execute('PRAGMA busy_timeout=30000')
conn.execute('PRAGMA synchronous=NORMAL')
return conn
def init_database(self):
"""Initialize database schema"""
conn = self.get_connection()
cursor = conn.cursor()
# Users table
cursor.execute('''
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
salt TEXT NOT NULL,
is_active BOOLEAN DEFAULT 1
)
''')
# Locks table
cursor.execute('''
CREATE TABLE IF NOT EXISTS locks (
lock_token TEXT PRIMARY KEY,
resource_path TEXT NOT NULL,
user_id INTEGER,
lock_type TEXT DEFAULT 'write',
lock_scope TEXT DEFAULT 'exclusive',
depth INTEGER DEFAULT 0,
timeout_seconds INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
owner TEXT,
FOREIGN KEY (user_id) REFERENCES users (id)
)
''')
# Properties table
cursor.execute('''
CREATE TABLE IF NOT EXISTS properties (
id INTEGER PRIMARY KEY AUTOINCREMENT,
resource_path TEXT NOT NULL,
namespace TEXT,
property_name TEXT NOT NULL,
property_value TEXT,
UNIQUE(resource_path, namespace, property_name)
)
''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_locks_resource ON locks(resource_path)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_properties_resource ON properties(resource_path)')
conn.commit()
conn.close()
async def run_in_executor(self, func, *args):
"""Run a synchronous database function in a thread pool."""
return await asyncio.get_event_loop().run_in_executor(None, func, *args)
async def create_user(self, username: str, password: str) -> int:
"""Create a new user"""
salt = secrets.token_hex(16)
password_hash = _hash_password(password, salt)
def _create():
conn = self.get_connection()
cursor = conn.cursor()
try:
cursor.execute(
'INSERT INTO users (username, password_hash, salt) VALUES (?, ?, ?)',
(username, password_hash, salt)
)
user_id = cursor.lastrowid
conn.commit()
return user_id
finally:
conn.close()
user_id = await self.run_in_executor(_create)
user_dir = Path(Config.WEBDAV_ROOT) / 'users' / username
user_dir.mkdir(parents=True, exist_ok=True)
return user_id
def _get_user_from_db(self, username: str) -> Optional[Dict]:
"""Fetches user data from the database."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute('SELECT * FROM users WHERE username = ? AND is_active = 1', (username,))
user = cursor.fetchone()
return dict(user) if user else None
finally:
conn.close()
async def verify_user(self, username: str, password: str) -> Optional[Dict]:
"""Verify user credentials using a cached hash function."""
user_data = await self.run_in_executor(self._get_user_from_db, username)
if not user_data:
return None
password_hash = _hash_password(password, user_data['salt'])
if hmac.compare_digest(password_hash, user_data['password_hash']):
return user_data
return None
async def get_lock(self, resource_path: str) -> Optional[Dict]:
def _get():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute('''
SELECT * FROM locks WHERE resource_path = ?
AND datetime(created_at, '+' || timeout_seconds || ' seconds') > datetime('now')
''', (resource_path,))
lock = cursor.fetchone()
return dict(lock) if lock else None
finally:
conn.close()
return await self.run_in_executor(_get)
async def create_lock(self, resource_path: str, user_id: int, timeout: int, owner: str) -> str:
lock_token = f"opaquelocktoken:{secrets.token_urlsafe(16)}"
def _create():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
'INSERT INTO locks (lock_token, resource_path, user_id, timeout_seconds, owner) VALUES (?, ?, ?, ?, ?)',
(lock_token, resource_path, user_id, timeout, owner)
)
conn.commit()
return lock_token
finally:
conn.close()
return await self.run_in_executor(_create)
async def remove_lock(self, lock_token: str, user_id: int) -> bool:
def _remove():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute('DELETE FROM locks WHERE lock_token = ? AND user_id = ?', (lock_token, user_id))
deleted = cursor.rowcount > 0
conn.commit()
return deleted
finally:
conn.close()
return await self.run_in_executor(_remove)
async def get_properties(self, resource_path: str) -> List[Dict]:
def _get():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute('SELECT * FROM properties WHERE resource_path = ?', (resource_path,))
properties = cursor.fetchall()
return [dict(prop) for prop in properties]
finally:
conn.close()
return await self.run_in_executor(_get)
async def set_property(self, resource_path: str, namespace: str, property_name: str, property_value: str):
def _set():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
'INSERT OR REPLACE INTO properties (resource_path, namespace, property_name, property_value) VALUES (?, ?, ?, ?)',
(resource_path, namespace, property_name, property_value)
)
conn.commit()
finally:
conn.close()
await self.run_in_executor(_set)
async def remove_property(self, resource_path: str, namespace: str, property_name: str):
def _remove():
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
'DELETE FROM properties WHERE resource_path = ? AND namespace = ? AND property_name = ?',
(resource_path, namespace, property_name)
)
conn.commit()
finally:
conn.close()
await self.run_in_executor(_remove)
# ============================================================================
# XML Utilities for WebDAV
# ============================================================================
class WebDAVXML:
"""XML processing utilities for WebDAV protocol"""
NS = {'D': 'DAV:'}
@staticmethod
def register_namespaces():
for prefix, uri in WebDAVXML.NS.items():
ET.register_namespace(prefix, uri)
@staticmethod
def create_multistatus() -> ET.Element:
return ET.Element('{DAV:}multistatus')
@staticmethod
def create_response(href: str) -> ET.Element:
response = ET.Element('{DAV:}response')
href_elem = ET.SubElement(response, '{DAV:}href')
href_elem.text = href
return response
@staticmethod
def add_propstat(response: ET.Element, props: Dict[str, str], status: str = '200 OK'):
propstat = ET.SubElement(response, '{DAV:}propstat')
prop = ET.SubElement(propstat, '{DAV:}prop')
is_collection = props.pop('_is_collection', False)
for prop_name, prop_value in props.items():
prop_elem = ET.SubElement(prop, prop_name)
if prop_name == '{DAV:}resourcetype' and is_collection:
ET.SubElement(prop_elem, '{DAV:}collection')
elif prop_value is not None:
prop_elem.text = str(prop_value)
status_elem = ET.SubElement(propstat, '{DAV:}status')
status_elem.text = f'HTTP/1.1 {status}'
@staticmethod
def serialize(element: ET.Element) -> str:
WebDAVXML.register_namespaces()
return ET.tostring(element, encoding='unicode', xml_declaration=True)
@staticmethod
def parse_propfind(body: bytes) -> Tuple[str, List[str]]:
if not body: return 'allprop', []
try:
root = ET.fromstring(body)
if root.find('.//{DAV:}allprop') is not None: return 'allprop', []
if root.find('.//{DAV:}propname') is not None: return 'propname', []
prop_elem = root.find('.//{DAV:}prop')
if prop_elem is not None:
return 'prop', [child.tag for child in prop_elem]
except ET.ParseError:
pass
return 'allprop', []
# ============================================================================
# Authentication and Authorization
# ============================================================================
class AuthHandler:
"""Handle authentication methods"""
def __init__(self, db: Database):
self.db = db
async def authenticate_basic(self, request: web.Request) -> Optional[Dict]:
auth_header = request.headers.get('Authorization')
if not auth_header or not auth_header.startswith('Basic '):
return None
try:
auth_decoded = base64.b64decode(auth_header[6:]).decode()
username, password = auth_decoded.split(':', 1)
return await self.db.verify_user(username, password)
except (ValueError, UnicodeDecodeError):
return None
async def authenticate(self, request: web.Request) -> Optional[Dict]:
if 'basic' in Config.AUTH_METHODS:
return await self.authenticate_basic(request)
return None
def require_auth_response(self) -> web.Response:
return web.Response(
status=401,
headers={'WWW-Authenticate': 'Basic realm="WebDAV Server"'},
text='Unauthorized'
)
# ============================================================================
# WebDAV Handler
# ============================================================================
class WebDAVHandler:
"""Main WebDAV protocol handler with prime-on-write caching"""
def __init__(self, db: Database, auth: AuthHandler):
self.db = db
self.auth = auth
self.metadata_cache = {}
self.cache_lock = asyncio.Lock()
WebDAVXML.register_namespaces()
def get_user_root(self, username: str) -> Path:
return Path(Config.WEBDAV_ROOT) / 'users' / username
def get_physical_path(self, username: str, webdav_path: str) -> Path:
webdav_path = unquote(webdav_path).lstrip('/')
user_root = self.get_user_root(username)
physical_path = (user_root / webdav_path).resolve()
if user_root.resolve() not in physical_path.parents and physical_path != user_root.resolve():
raise web.HTTPForbidden(text="Access denied outside of user root.")
return physical_path
async def run_blocking_io(self, func, *args, **kwargs):
fn = functools.partial(func, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, fn)
def get_cache_key(self, user: Dict, webdav_path: str) -> str:
return f"{user['id']}:{webdav_path}"
async def _invalidate_cache_entry(self, user: Dict, webdav_path: str):
"""Invalidates a single entry and its parent from the cache."""
async with self.cache_lock:
key = self.get_cache_key(user, webdav_path)
if key in self.metadata_cache:
del self.metadata_cache[key]
parent_path = str(Path(webdav_path).parent)
parent_key = self.get_cache_key(user, parent_path)
if parent_key in self.metadata_cache:
del self.metadata_cache[parent_key]
async def _update_and_cache_properties(self, path: Path, href: str, user: Dict) -> Dict:
"""Fetches properties for a resource and caches them."""
try:
stat = await self.run_blocking_io(path.stat)
is_dir = await self.run_blocking_io(path.is_dir)
except FileNotFoundError:
return {}
props = {
'{DAV:}displayname': path.name,
'{DAV:}creationdate': datetime.fromtimestamp(stat.st_ctime).isoformat() + "Z",
'{DAV:}getlastmodified': datetime.fromtimestamp(stat.st_mtime).strftime('%a, %d %b %Y %H:%M:%S GMT'),
'{DAV:}resourcetype': None,
'_is_collection': is_dir,
}
if not is_dir:
props['{DAV:}getcontentlength'] = str(stat.st_size)
content_type, _ = mimetypes.guess_type(str(path))
props['{DAV:}getcontenttype'] = content_type or 'application/octet-stream'
db_props = await self.db.get_properties(href)
for prop in db_props:
props[f"{{{prop['namespace']}}}{prop['property_name']}"] = prop['property_value']
key = self.get_cache_key(user, href)
async with self.cache_lock:
self.metadata_cache[key] = props
return props
async def get_resource_properties(self, path: Path, href: str, user: Dict) -> Dict:
"""Gets resource properties, using cache if available."""
key = self.get_cache_key(user, href)
async with self.cache_lock:
if key in self.metadata_cache:
return self.metadata_cache[key]
# On cache miss, fetch and update the cache
return await self._update_and_cache_properties(path, href, user)
async def handle_options(self, request: web.Request, user: Dict) -> web.Response:
return web.Response(
status=200,
headers={
'DAV': '1, 2',
'MS-Author-Via': 'DAV',
'Allow': 'OPTIONS, GET, HEAD, PUT, DELETE, PROPFIND, PROPPATCH, MKCOL, COPY, MOVE, LOCK, UNLOCK',
}
)
async def _generate_etag(self, path: Path) -> str:
"""Generates an ETag for a file based on size and mtime."""
try:
stat = await self.run_blocking_io(path.stat)
etag_data = f"{stat.st_size}-{stat.st_mtime_ns}"
return f'"{hashlib.sha1(etag_data.encode()).hexdigest()}"'
except FileNotFoundError:
return ""
async def handle_get(self, request: web.Request, user: Dict) -> web.Response:
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
if await self.run_blocking_io(path.is_dir):
raise web.HTTPForbidden(text="Directory listing not supported via GET.")
etag = await self._generate_etag(path)
if etag and request.headers.get('If-None-Match') == etag:
return web.Response(status=304, headers={'ETag': etag})
async with aiofiles.open(path, 'rb') as f:
content = await f.read()
content_type, _ = mimetypes.guess_type(str(path))
return web.Response(body=content, content_type=content_type or 'application/octet-stream', headers={'ETag': etag})
async def handle_head(self, request: web.Request, user: Dict) -> web.Response:
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
etag = await self._generate_etag(path)
if etag and request.headers.get('If-None-Match') == etag:
return web.Response(status=304, headers={'ETag': etag})
stat = await self.run_blocking_io(path.stat)
content_type, _ = mimetypes.guess_type(str(path))
headers = {'Content-Type': content_type or 'application/octet-stream', 'Content-Length': str(stat.st_size), 'ETag': etag}
return web.Response(headers=headers)
async def handle_put(self, request: web.Request, user: Dict) -> web.Response:
"""Handles PUT requests with robust error handling and cache priming."""
path = self.get_physical_path(user['username'], request.path)
try:
exists = await self.run_blocking_io(path.exists)
# Ensure parent directory exists
await self.run_blocking_io(path.parent.mkdir, parents=True, exist_ok=True)
# Write the file content
async with aiofiles.open(path, 'wb') as f:
async for chunk in request.content.iter_chunked(8192):
await f.write(chunk)
# After a successful write, prime the cache for the new file
await self._update_and_cache_properties(path, request.path, user)
return web.Response(status=204 if exists else 201)
finally:
# CRITICAL: Always invalidate the parent directory's cache after any PUT,
# even if the cache prime operation above fails. This prevents stale listings.
await self._invalidate_cache_entry(user, str(Path(request.path).parent))
async def handle_delete(self, request: web.Request, user: Dict) -> web.Response:
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
# Invalidate cache before deletion
await self._invalidate_cache_entry(user, request.path)
try:
if await self.run_blocking_io(path.is_dir):
await self.run_blocking_io(shutil.rmtree, path)
else:
await self.run_blocking_io(path.unlink)
return web.Response(status=204)
except OSError as e:
raise web.HTTPConflict(text=f"Cannot delete resource: {e}")
async def handle_mkcol(self, request: web.Request, user: Dict) -> web.Response:
path = self.get_physical_path(user['username'], request.path)
if await self.run_blocking_io(path.exists):
raise web.HTTPMethodNotAllowed(method='MKCOL', allowed_methods=[])
if not await self.run_blocking_io(path.parent.exists):
raise web.HTTPConflict()
await self.run_blocking_io(path.mkdir)
# Prime cache for the new directory and invalidate parent
await self._update_and_cache_properties(path, request.path, user)
await self._invalidate_cache_entry(user, str(Path(request.path).parent))
return web.Response(status=201)
async def add_resource_to_multistatus(self, multistatus: ET.Element, path: Path, href: str, user: Dict):
props = await self.get_resource_properties(path, href, user)
if props:
response = WebDAVXML.create_response(quote(href))
WebDAVXML.add_propstat(response, props)
multistatus.append(response)
async def handle_propfind(self, request: web.Request, user: Dict) -> web.Response:
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
depth = request.headers.get('Depth', '1')
multistatus = WebDAVXML.create_multistatus()
# Add the resource itself
await self.add_resource_to_multistatus(multistatus, path, request.path, user)
# Add children if depth=1 and it's a directory
if depth == '1' and await self.run_blocking_io(path.is_dir):
child_tasks = []
for child_path in await self.run_blocking_io(list, path.iterdir()):
child_href = f"{request.path.rstrip('/')}/{child_path.name}"
task = self.add_resource_to_multistatus(multistatus, child_path, child_href, user)
child_tasks.append(task)
await asyncio.gather(*child_tasks)
xml_response = WebDAVXML.serialize(multistatus)
return web.Response(status=207, content_type='application/xml', text=xml_response)
async def handle_proppatch(self, request: web.Request, user: Dict) -> web.Response:
"""Handles PROPPATCH with the typo fixed."""
path = self.get_physical_path(user['username'], request.path)
if not await self.run_blocking_io(path.exists):
raise web.HTTPNotFound()
body = await request.read()
root = ET.fromstring(body)
for prop_action in root:
prop_container = prop_action.find('{DAV:}prop')
if prop_container is None:
continue
for prop in prop_container:
# Correctly parse namespace and property name
tag_parts = prop.tag.split('}')
if len(tag_parts) != 2:
continue # Skip malformed tags
ns = tag_parts[0][1:]
name = tag_parts[1]
if prop_action.tag.endswith("set"):
await self.db.set_property(request.path, ns, name, prop.text or "")
elif prop_action.tag.endswith("remove"):
await self.db.remove_property(request.path, ns, name)
# Invalidate and update cache after property change
await self._update_and_cache_properties(path, request.path, user)
multistatus = WebDAVXML.create_multistatus()
await self.add_resource_to_multistatus(multistatus, path, request.path, user)
return web.Response(status=207, content_type='application/xml', text=WebDAVXML.serialize(multistatus))
async def handle_copy(self, request: web.Request, user: Dict) -> web.Response:
src_path = self.get_physical_path(user['username'], request.path)
dest_header = request.headers.get('Destination')
if not dest_header: raise web.HTTPBadRequest(text="Missing Destination header")
dest_href = urlparse(dest_header).path
dest_path = self.get_physical_path(user['username'], dest_href)
overwrite = request.headers.get('Overwrite', 'T').upper() == 'T'
if await self.run_blocking_io(dest_path.exists) and not overwrite:
raise web.HTTPPreconditionFailed()
if await self.run_blocking_io(src_path.is_dir):
await self.run_blocking_io(shutil.copytree, src_path, dest_path, dirs_exist_ok=overwrite)
else:
await self.run_blocking_io(shutil.copy2, src_path, dest_path)
# Prime cache for destination and invalidate parent
await self._update_and_cache_properties(dest_path, dest_href, user)
await self._invalidate_cache_entry(user, str(Path(dest_href).parent))
return web.Response(status=201)
async def handle_move(self, request: web.Request, user: Dict) -> web.Response:
src_path = self.get_physical_path(user['username'], request.path)
dest_header = request.headers.get('Destination')
if not dest_header: raise web.HTTPBadRequest(text="Missing Destination header")
dest_href = urlparse(dest_header).path
dest_path = self.get_physical_path(user['username'], dest_href)
overwrite = request.headers.get('Overwrite', 'T').upper() == 'T'
if await self.run_blocking_io(dest_path.exists) and not overwrite:
raise web.HTTPPreconditionFailed()
# Invalidate source before move
await self._invalidate_cache_entry(user, request.path)
await self.run_blocking_io(shutil.move, str(src_path), str(dest_path))
# Prime cache for new destination and invalidate its parent
await self._update_and_cache_properties(dest_path, dest_href, user)
await self._invalidate_cache_entry(user, str(Path(dest_href).parent))
return web.Response(status=201)
async def handle_lock(self, request: web.Request, user: Dict) -> web.Response:
body = await request.read()
owner_info = ET.tostring(ET.fromstring(body).find('.//{DAV:}owner'), encoding='unicode')
timeout_header = request.headers.get('Timeout', f'Second-{Config.LOCK_TIMEOUT_DEFAULT}')
timeout = int(timeout_header.split('-')[1])
lock_token = await self.db.create_lock(request.path, user['id'], timeout, owner_info)
response_xml = f'''<?xml version="1.0" encoding="utf-8"?>
<D:prop xmlns:D="DAV:"><D:lockdiscovery><D:activelock>
<D:locktype><D:write/></D:locktype><D:lockscope><D:exclusive/></D:lockscope>
<D:depth>0</D:depth><D:timeout>Second-{timeout}</D:timeout>
<D:locktoken><D:href>{lock_token}</D:href></D:locktoken>
{owner_info}
</D:activelock></D:lockdiscovery></D:prop>'''
return web.Response(status=200, content_type='application/xml', text=response_xml, headers={'Lock-Token': f'<{lock_token}>'})
async def handle_unlock(self, request: web.Request, user: Dict) -> web.Response:
lock_token = request.headers.get('Lock-Token', '').strip('<>')
if not lock_token: raise web.HTTPBadRequest(text="Missing Lock-Token header")
if await self.db.remove_lock(lock_token, user['id']):
return web.Response(status=204)
else:
raise web.HTTPConflict(text="Lock not found or not owned by user")
# ============================================================================
# Web Application
# ============================================================================
async def webdav_handler_func(request: web.Request):
"""Main routing function for all WebDAV methods."""
app = request.app
auth_handler: AuthHandler = app['auth']
webdav_handler: WebDAVHandler = app['webdav']
# OPTIONS is often unauthenticated (pre-flight)
if request.method == 'OPTIONS':
return await webdav_handler.handle_options(request, {})
user = await auth_handler.authenticate(request)
if not user:
return auth_handler.require_auth_response()
# Route to the correct handler based on method
method_map = {
'GET': webdav_handler.handle_get,
'HEAD': webdav_handler.handle_head,
'PUT': webdav_handler.handle_put,
'DELETE': webdav_handler.handle_delete,
'MKCOL': webdav_handler.handle_mkcol,
'PROPFIND': webdav_handler.handle_propfind,
'PROPPATCH': webdav_handler.handle_proppatch,
'COPY': webdav_handler.handle_copy,
'MOVE': webdav_handler.handle_move,
'LOCK': webdav_handler.handle_lock,
'UNLOCK': webdav_handler.handle_unlock,
}
handler = method_map.get(request.method)
if handler:
return await handler(request, user)
else:
raise web.HTTPMethodNotAllowed(method=request.method, allowed_methods=list(method_map.keys()))
async def init_app() -> web.Application:
"""Initialize web application"""
app = web.Application(client_max_size=Config.MAX_FILE_SIZE)
db = Database(Config.DB_PATH)
app['db'] = db
app['auth'] = AuthHandler(db)
app['webdav'] = WebDAVHandler(db, app['auth'])
app.router.add_route('*', '/{path:.*}', webdav_handler_func)
return app
async def create_default_user(db: Database):
"""Create default admin user if no users exist"""
def _check_user_exists():
conn = db.get_connection()
try:
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) as count FROM users')
return cursor.fetchone()['count'] > 0
finally:
conn.close()
user_exists = await asyncio.get_event_loop().run_in_executor(None, _check_user_exists)
if not user_exists:
print("No users found. Creating default user 'admin' with password 'admin123'.")
await db.create_user('admin', 'admin123')
print("Default user created. Please change the password for security.")
def main():
"""Main entry point"""
Path(Config.WEBDAV_ROOT).mkdir(parents=True, exist_ok=True)
(Path(Config.WEBDAV_ROOT) / 'users').mkdir(exist_ok=True)
db = Database(Config.DB_PATH)
asyncio.run(create_default_user(db))
app = asyncio.run(init_app())
print(f"Starting WebDAV Server on http://{Config.HOST}:{Config.PORT}")
web.run_app(app, host=Config.HOST, port=Config.PORT)
if __name__ == '__main__':
main()