From 01846bf23f7883007b99a2e100240bf3b35b30f2 Mon Sep 17 00:00:00 2001 From: retoor Date: Sun, 11 May 2025 07:52:22 +0200 Subject: [PATCH] Update. --- src/snek/app.py | 10 +++- src/snek/service/user.py | 25 +++++++++- src/snek/sssh.py | 95 +++++++++++++++++++++++++++++++++++++ src/snek/system/security.py | 11 +++-- 4 files changed, 135 insertions(+), 6 deletions(-) create mode 100644 src/snek/sssh.py diff --git a/src/snek/app.py b/src/snek/app.py index 0a6b018..706f74e 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -20,6 +20,7 @@ from aiohttp_session.cookie_storage import EncryptedCookieStorage from app.app import Application as BaseApplication from jinja2 import FileSystemLoader +from snek.sssh import start_ssh_server from snek.docs.app import Application as DocsApplication from snek.mapper import get_mappers from snek.service import get_services @@ -95,15 +96,22 @@ class Application(BaseApplication): self.jinja2_env.add_extension(LinkifyExtension) self.jinja2_env.add_extension(PythonExtension) self.jinja2_env.add_extension(EmojiExtension) - + self.ssh_host = "0.0.0.0" + self.ssh_port = 2042 self.setup_router() + self.ssh_server = None self.executor = None self.cache = Cache(self) self.services = get_services(app=self) self.mappers = get_mappers(app=self) + self.on_startup.append(self.start_ssh_server) self.on_startup.append(self.prepare_asyncio) self.on_startup.append(self.prepare_database) + async def start_ssh_server(self, app): + app.ssh_server = await start_ssh_server(app,app.ssh_host,app.ssh_port) + asyncio.create_task(app.ssh_server.wait_closed()) + async def prepare_asyncio(self, app): # app.loop = asyncio.get_running_loop() app.executor = ThreadPoolExecutor(max_workers=200) diff --git a/src/snek/service/user.py b/src/snek/service/user.py index 76e6d1c..13f660f 100644 --- a/src/snek/service/user.py +++ b/src/snek/service/user.py @@ -32,10 +32,17 @@ class UserService(BaseService): user["color"] = await self.services.util.random_light_hex_color() return await super().save(user) + def authenticate_sync(self,username,password): + user = self.get_by_username_sync(username) + + if not user: + return False + if not security.verify_sync(password, user["password"]): + return False + return True + async def authenticate(self, username, password): - print(username, password, flush=True) success = await self.validate_login(username, password) - print(success, flush=True) if not success: return None @@ -61,6 +68,20 @@ class UserService(BaseService): if not path.exists(): return None return path + + def get_by_username_sync(self, username): + user = self.mapper.db["user"].find_one(username=username, deleted_at=None) + return dict(user) + + def get_home_folder_by_username(self, username): + user = self.get_by_username_sync(username) + folder = pathlib.Path(f"./drive/{user['uid']}") + if not folder.exists(): + try: + folder.mkdir(parents=True, exist_ok=True) + except: + pass + return folder async def get_home_folder(self, user_uid): folder = pathlib.Path(f"./drive/{user_uid}") diff --git a/src/snek/sssh.py b/src/snek/sssh.py new file mode 100644 index 0000000..0106a17 --- /dev/null +++ b/src/snek/sssh.py @@ -0,0 +1,95 @@ +import asyncio +import asyncssh +import logging +import os +from pathlib import Path +import sys +import pty + +global _app + + +def set_app(app): + global _app + _app = app + +def get_app(): + return _app + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler("sftp_server.log"), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +roots = {} + +class MySFTPServer(asyncssh.SFTPServer): + + def __init__(self, chan: asyncssh.SSHServerChannel): + self.root = get_app().services.user.get_home_folder_by_username( + chan.get_extra_info('username') + ) + self.root.mkdir(exist_ok=True) + self.root = str(self.root) + super().__init__(chan, chroot=self.root) + + def map_path(self, path): + + # Map client path to local filesystem path + mapped_path = Path(self.root).joinpath(path.lstrip(b"/").decode()) + #Path(mapped_path).mkdir(exist_ok=True) + print(mapped_path) + logger.debug(f"Mapping client path {path} to {mapped_path}") + return str(mapped_path).encode() + +class MySSHServer(asyncssh.SSHServer): + def password_auth_supported(self): + return True + + def validate_password(self, username, password): + # Simple in-memory user database + + logger.debug(f"Validating credentials for user {username}") + return get_app().services.user.authenticate_sync(username,password) + + +async def start_ssh_server(app,host,port): + set_app(app) + logger.info("Starting SFTP server setup") + # Server host key (generate or load one) + host_key_path = Path("drive") / ".ssh" / "sftp_server_key" + host_key_path.parent.mkdir(exist_ok=True) + try: + if not host_key_path.exists(): + logger.info(f"Generating new host key at {host_key_path}") + key = asyncssh.generate_private_key("ecdsa-sha2-nistp256") + key.write_private_key(host_key_path) + else: + logger.info(f"Loading existing host key from {host_key_path}") + key = asyncssh.read_private_key(host_key_path) + except Exception as e: + logger.error(f"Failed to generate or load host key: {e}") + raise + + # Start the SSH server with SFTP support + logger.info("Starting SFTP server on localhost:8022") + try: + x = await asyncssh.listen( + host=host, + port=port, + #process_factory=handle_client, + server_host_keys=[key], + server_factory=MySSHServer, + sftp_factory=MySFTPServer + ) + return x + except Exception as e: + logger.error(f"Failed to start SFTP server: {e}") + raise + + diff --git a/src/snek/system/security.py b/src/snek/system/security.py index 43b61fe..4ead284 100644 --- a/src/snek/system/security.py +++ b/src/snek/system/security.py @@ -40,7 +40,7 @@ def uid(value: str = None, ns: str = DEFAULT_NS) -> str: return str(uuid.uuid5(UIDNS(ns), value)) -async def hash(data: str, salt: str = DEFAULT_SALT) -> str: +def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str: """Hash the given data with the specified salt using SHA-256. Args: @@ -63,8 +63,10 @@ async def hash(data: str, salt: str = DEFAULT_SALT) -> str: obj = hashlib.sha256(salted) return obj.hexdigest() +async def hash(data: str, salt: str = DEFAULT_SALT) -> str: + return hash_sync(data, salt) -async def verify(string: str, hashed: str) -> bool: +def verify_sync(string: str, hashed: str) -> bool: """Verify if the given string matches the hashed value. Args: @@ -74,4 +76,7 @@ async def verify(string: str, hashed: str) -> bool: Returns: bool: True if the string matches the hashed value, False otherwise. """ - return await hash(string) == hashed + return hash_sync(string) == hashed + +async def verify(string: str, hashed: str) -> bool: + return verify_sync(string, hashed)