Update.
This commit is contained in:
		
							parent
							
								
									2c90044185
								
							
						
					
					
						commit
						01846bf23f
					
				| @ -20,6 +20,7 @@ from aiohttp_session.cookie_storage import EncryptedCookieStorage | ||||
| from app.app import Application as BaseApplication | ||||
| from jinja2 import FileSystemLoader | ||||
| 
 | ||||
| from snek.sssh import start_ssh_server | ||||
| from snek.docs.app import Application as DocsApplication | ||||
| from snek.mapper import get_mappers | ||||
| from snek.service import get_services | ||||
| @ -95,15 +96,22 @@ class Application(BaseApplication): | ||||
|         self.jinja2_env.add_extension(LinkifyExtension) | ||||
|         self.jinja2_env.add_extension(PythonExtension) | ||||
|         self.jinja2_env.add_extension(EmojiExtension) | ||||
| 
 | ||||
|         self.ssh_host = "0.0.0.0" | ||||
|         self.ssh_port = 2042 | ||||
|         self.setup_router() | ||||
|         self.ssh_server = None  | ||||
|         self.executor = None | ||||
|         self.cache = Cache(self) | ||||
|         self.services = get_services(app=self) | ||||
|         self.mappers = get_mappers(app=self) | ||||
|         self.on_startup.append(self.start_ssh_server) | ||||
|         self.on_startup.append(self.prepare_asyncio) | ||||
|         self.on_startup.append(self.prepare_database) | ||||
| 
 | ||||
|     async def start_ssh_server(self, app): | ||||
|         app.ssh_server = await start_ssh_server(app,app.ssh_host,app.ssh_port) | ||||
|         asyncio.create_task(app.ssh_server.wait_closed()) | ||||
| 
 | ||||
|     async def prepare_asyncio(self, app): | ||||
|         # app.loop = asyncio.get_running_loop() | ||||
|         app.executor = ThreadPoolExecutor(max_workers=200) | ||||
|  | ||||
| @ -32,10 +32,17 @@ class UserService(BaseService): | ||||
|             user["color"] = await self.services.util.random_light_hex_color() | ||||
|         return await super().save(user) | ||||
| 
 | ||||
|     def authenticate_sync(self,username,password): | ||||
|         user = self.get_by_username_sync(username) | ||||
|            | ||||
|         if not user: | ||||
|             return False | ||||
|         if not security.verify_sync(password, user["password"]): | ||||
|             return False | ||||
|         return True  | ||||
| 
 | ||||
|     async def authenticate(self, username, password): | ||||
|         print(username, password, flush=True) | ||||
|         success = await self.validate_login(username, password) | ||||
|         print(success, flush=True) | ||||
|         if not success: | ||||
|             return None | ||||
| 
 | ||||
| @ -62,6 +69,20 @@ class UserService(BaseService): | ||||
|             return None | ||||
|         return path | ||||
|      | ||||
|     def get_by_username_sync(self, username): | ||||
|         user = self.mapper.db["user"].find_one(username=username, deleted_at=None) | ||||
|         return dict(user) | ||||
| 
 | ||||
|     def get_home_folder_by_username(self, username): | ||||
|         user = self.get_by_username_sync(username) | ||||
|         folder = pathlib.Path(f"./drive/{user['uid']}") | ||||
|         if not folder.exists(): | ||||
|             try: | ||||
|                 folder.mkdir(parents=True, exist_ok=True) | ||||
|             except: | ||||
|                 pass | ||||
|         return folder | ||||
| 
 | ||||
|     async def get_home_folder(self, user_uid): | ||||
|         folder = pathlib.Path(f"./drive/{user_uid}") | ||||
|         if not folder.exists(): | ||||
|  | ||||
							
								
								
									
										95
									
								
								src/snek/sssh.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										95
									
								
								src/snek/sssh.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,95 @@ | ||||
| import asyncio | ||||
| import asyncssh | ||||
| import logging | ||||
| import os | ||||
| from pathlib import Path | ||||
| import sys  | ||||
| import pty  | ||||
| 
 | ||||
| global _app  | ||||
| 
 | ||||
| 
 | ||||
| def set_app(app): | ||||
|     global _app | ||||
|     _app = app  | ||||
| 
 | ||||
| def get_app(): | ||||
|     return _app | ||||
| 
 | ||||
| logging.basicConfig( | ||||
|     level=logging.DEBUG, | ||||
|     format="%(asctime)s - %(levelname)s - %(message)s", | ||||
|     handlers=[ | ||||
|         logging.FileHandler("sftp_server.log"), | ||||
|         logging.StreamHandler() | ||||
|     ] | ||||
| ) | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| roots = {} | ||||
| 
 | ||||
| class MySFTPServer(asyncssh.SFTPServer): | ||||
|     | ||||
|     def __init__(self, chan: asyncssh.SSHServerChannel): | ||||
|         self.root = get_app().services.user.get_home_folder_by_username( | ||||
|             chan.get_extra_info('username') | ||||
|         ) | ||||
|         self.root.mkdir(exist_ok=True) | ||||
|         self.root = str(self.root) | ||||
|         super().__init__(chan, chroot=self.root) | ||||
| 
 | ||||
|     def map_path(self, path): | ||||
| 
 | ||||
|         # Map client path to local filesystem path | ||||
|         mapped_path = Path(self.root).joinpath(path.lstrip(b"/").decode()) | ||||
|         #Path(mapped_path).mkdir(exist_ok=True) | ||||
|         print(mapped_path) | ||||
|         logger.debug(f"Mapping client path {path} to {mapped_path}") | ||||
|         return str(mapped_path).encode() | ||||
| 
 | ||||
| class MySSHServer(asyncssh.SSHServer): | ||||
|     def password_auth_supported(self): | ||||
|         return True | ||||
| 
 | ||||
|     def validate_password(self, username, password): | ||||
|         # Simple in-memory user database | ||||
|          | ||||
|         logger.debug(f"Validating credentials for user {username}") | ||||
|         return get_app().services.user.authenticate_sync(username,password) | ||||
| 
 | ||||
| 
 | ||||
| async def start_ssh_server(app,host,port): | ||||
|     set_app(app)  | ||||
|     logger.info("Starting SFTP server setup") | ||||
|     # Server host key (generate or load one) | ||||
|     host_key_path = Path("drive") / ".ssh" / "sftp_server_key" | ||||
|     host_key_path.parent.mkdir(exist_ok=True) | ||||
|     try: | ||||
|         if not host_key_path.exists(): | ||||
|             logger.info(f"Generating new host key at {host_key_path}") | ||||
|             key = asyncssh.generate_private_key("ecdsa-sha2-nistp256") | ||||
|             key.write_private_key(host_key_path) | ||||
|         else: | ||||
|             logger.info(f"Loading existing host key from {host_key_path}") | ||||
|             key = asyncssh.read_private_key(host_key_path) | ||||
|     except Exception as e: | ||||
|         logger.error(f"Failed to generate or load host key: {e}") | ||||
|         raise | ||||
| 
 | ||||
|     # Start the SSH server with SFTP support | ||||
|     logger.info("Starting SFTP server on localhost:8022") | ||||
|     try: | ||||
|         x = await asyncssh.listen( | ||||
|             host=host, | ||||
|             port=port, | ||||
|             #process_factory=handle_client, | ||||
|             server_host_keys=[key], | ||||
|             server_factory=MySSHServer, | ||||
|             sftp_factory=MySFTPServer | ||||
|         ) | ||||
|         return x | ||||
|     except Exception as e: | ||||
|         logger.error(f"Failed to start SFTP server: {e}") | ||||
|         raise | ||||
| 
 | ||||
| 
 | ||||
| @ -40,7 +40,7 @@ def uid(value: str = None, ns: str = DEFAULT_NS) -> str: | ||||
|     return str(uuid.uuid5(UIDNS(ns), value)) | ||||
| 
 | ||||
| 
 | ||||
| async def hash(data: str, salt: str = DEFAULT_SALT) -> str: | ||||
| def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str: | ||||
|     """Hash the given data with the specified salt using SHA-256. | ||||
| 
 | ||||
|     Args: | ||||
| @ -63,8 +63,10 @@ async def hash(data: str, salt: str = DEFAULT_SALT) -> str: | ||||
|     obj = hashlib.sha256(salted) | ||||
|     return obj.hexdigest() | ||||
| 
 | ||||
| async def hash(data: str, salt: str = DEFAULT_SALT) -> str: | ||||
|     return hash_sync(data, salt) | ||||
| 
 | ||||
| async def verify(string: str, hashed: str) -> bool: | ||||
| def verify_sync(string: str, hashed: str) -> bool: | ||||
|     """Verify if the given string matches the hashed value. | ||||
| 
 | ||||
|     Args: | ||||
| @ -74,4 +76,7 @@ async def verify(string: str, hashed: str) -> bool: | ||||
|     Returns: | ||||
|         bool: True if the string matches the hashed value, False otherwise. | ||||
|     """ | ||||
|     return await hash(string) == hashed | ||||
|     return hash_sync(string) == hashed | ||||
| 
 | ||||
| async def verify(string: str, hashed: str) -> bool: | ||||
|     return verify_sync(string, hashed) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user