Update.
This commit is contained in:
		
							parent
							
								
									2c90044185
								
							
						
					
					
						commit
						01846bf23f
					
				| @ -20,6 +20,7 @@ from aiohttp_session.cookie_storage import EncryptedCookieStorage | |||||||
| from app.app import Application as BaseApplication | from app.app import Application as BaseApplication | ||||||
| from jinja2 import FileSystemLoader | from jinja2 import FileSystemLoader | ||||||
| 
 | 
 | ||||||
|  | from snek.sssh import start_ssh_server | ||||||
| from snek.docs.app import Application as DocsApplication | from snek.docs.app import Application as DocsApplication | ||||||
| from snek.mapper import get_mappers | from snek.mapper import get_mappers | ||||||
| from snek.service import get_services | from snek.service import get_services | ||||||
| @ -95,15 +96,22 @@ class Application(BaseApplication): | |||||||
|         self.jinja2_env.add_extension(LinkifyExtension) |         self.jinja2_env.add_extension(LinkifyExtension) | ||||||
|         self.jinja2_env.add_extension(PythonExtension) |         self.jinja2_env.add_extension(PythonExtension) | ||||||
|         self.jinja2_env.add_extension(EmojiExtension) |         self.jinja2_env.add_extension(EmojiExtension) | ||||||
| 
 |         self.ssh_host = "0.0.0.0" | ||||||
|  |         self.ssh_port = 2042 | ||||||
|         self.setup_router() |         self.setup_router() | ||||||
|  |         self.ssh_server = None  | ||||||
|         self.executor = None |         self.executor = None | ||||||
|         self.cache = Cache(self) |         self.cache = Cache(self) | ||||||
|         self.services = get_services(app=self) |         self.services = get_services(app=self) | ||||||
|         self.mappers = get_mappers(app=self) |         self.mappers = get_mappers(app=self) | ||||||
|  |         self.on_startup.append(self.start_ssh_server) | ||||||
|         self.on_startup.append(self.prepare_asyncio) |         self.on_startup.append(self.prepare_asyncio) | ||||||
|         self.on_startup.append(self.prepare_database) |         self.on_startup.append(self.prepare_database) | ||||||
| 
 | 
 | ||||||
|  |     async def start_ssh_server(self, app): | ||||||
|  |         app.ssh_server = await start_ssh_server(app,app.ssh_host,app.ssh_port) | ||||||
|  |         asyncio.create_task(app.ssh_server.wait_closed()) | ||||||
|  | 
 | ||||||
|     async def prepare_asyncio(self, app): |     async def prepare_asyncio(self, app): | ||||||
|         # app.loop = asyncio.get_running_loop() |         # app.loop = asyncio.get_running_loop() | ||||||
|         app.executor = ThreadPoolExecutor(max_workers=200) |         app.executor = ThreadPoolExecutor(max_workers=200) | ||||||
|  | |||||||
| @ -32,10 +32,17 @@ class UserService(BaseService): | |||||||
|             user["color"] = await self.services.util.random_light_hex_color() |             user["color"] = await self.services.util.random_light_hex_color() | ||||||
|         return await super().save(user) |         return await super().save(user) | ||||||
| 
 | 
 | ||||||
|  |     def authenticate_sync(self,username,password): | ||||||
|  |         user = self.get_by_username_sync(username) | ||||||
|  |            | ||||||
|  |         if not user: | ||||||
|  |             return False | ||||||
|  |         if not security.verify_sync(password, user["password"]): | ||||||
|  |             return False | ||||||
|  |         return True  | ||||||
|  | 
 | ||||||
|     async def authenticate(self, username, password): |     async def authenticate(self, username, password): | ||||||
|         print(username, password, flush=True) |  | ||||||
|         success = await self.validate_login(username, password) |         success = await self.validate_login(username, password) | ||||||
|         print(success, flush=True) |  | ||||||
|         if not success: |         if not success: | ||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
| @ -61,6 +68,20 @@ class UserService(BaseService): | |||||||
|         if not path.exists(): |         if not path.exists(): | ||||||
|             return None |             return None | ||||||
|         return path |         return path | ||||||
|  |      | ||||||
|  |     def get_by_username_sync(self, username): | ||||||
|  |         user = self.mapper.db["user"].find_one(username=username, deleted_at=None) | ||||||
|  |         return dict(user) | ||||||
|  | 
 | ||||||
|  |     def get_home_folder_by_username(self, username): | ||||||
|  |         user = self.get_by_username_sync(username) | ||||||
|  |         folder = pathlib.Path(f"./drive/{user['uid']}") | ||||||
|  |         if not folder.exists(): | ||||||
|  |             try: | ||||||
|  |                 folder.mkdir(parents=True, exist_ok=True) | ||||||
|  |             except: | ||||||
|  |                 pass | ||||||
|  |         return folder | ||||||
| 
 | 
 | ||||||
|     async def get_home_folder(self, user_uid): |     async def get_home_folder(self, user_uid): | ||||||
|         folder = pathlib.Path(f"./drive/{user_uid}") |         folder = pathlib.Path(f"./drive/{user_uid}") | ||||||
|  | |||||||
							
								
								
									
										95
									
								
								src/snek/sssh.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										95
									
								
								src/snek/sssh.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,95 @@ | |||||||
|  | import asyncio | ||||||
|  | import asyncssh | ||||||
|  | import logging | ||||||
|  | import os | ||||||
|  | from pathlib import Path | ||||||
|  | import sys  | ||||||
|  | import pty  | ||||||
|  | 
 | ||||||
|  | global _app  | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def set_app(app): | ||||||
|  |     global _app | ||||||
|  |     _app = app  | ||||||
|  | 
 | ||||||
|  | def get_app(): | ||||||
|  |     return _app | ||||||
|  | 
 | ||||||
|  | logging.basicConfig( | ||||||
|  |     level=logging.DEBUG, | ||||||
|  |     format="%(asctime)s - %(levelname)s - %(message)s", | ||||||
|  |     handlers=[ | ||||||
|  |         logging.FileHandler("sftp_server.log"), | ||||||
|  |         logging.StreamHandler() | ||||||
|  |     ] | ||||||
|  | ) | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  | 
 | ||||||
|  | roots = {} | ||||||
|  | 
 | ||||||
|  | class MySFTPServer(asyncssh.SFTPServer): | ||||||
|  |     | ||||||
|  |     def __init__(self, chan: asyncssh.SSHServerChannel): | ||||||
|  |         self.root = get_app().services.user.get_home_folder_by_username( | ||||||
|  |             chan.get_extra_info('username') | ||||||
|  |         ) | ||||||
|  |         self.root.mkdir(exist_ok=True) | ||||||
|  |         self.root = str(self.root) | ||||||
|  |         super().__init__(chan, chroot=self.root) | ||||||
|  | 
 | ||||||
|  |     def map_path(self, path): | ||||||
|  | 
 | ||||||
|  |         # Map client path to local filesystem path | ||||||
|  |         mapped_path = Path(self.root).joinpath(path.lstrip(b"/").decode()) | ||||||
|  |         #Path(mapped_path).mkdir(exist_ok=True) | ||||||
|  |         print(mapped_path) | ||||||
|  |         logger.debug(f"Mapping client path {path} to {mapped_path}") | ||||||
|  |         return str(mapped_path).encode() | ||||||
|  | 
 | ||||||
|  | class MySSHServer(asyncssh.SSHServer): | ||||||
|  |     def password_auth_supported(self): | ||||||
|  |         return True | ||||||
|  | 
 | ||||||
|  |     def validate_password(self, username, password): | ||||||
|  |         # Simple in-memory user database | ||||||
|  |          | ||||||
|  |         logger.debug(f"Validating credentials for user {username}") | ||||||
|  |         return get_app().services.user.authenticate_sync(username,password) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def start_ssh_server(app,host,port): | ||||||
|  |     set_app(app)  | ||||||
|  |     logger.info("Starting SFTP server setup") | ||||||
|  |     # Server host key (generate or load one) | ||||||
|  |     host_key_path = Path("drive") / ".ssh" / "sftp_server_key" | ||||||
|  |     host_key_path.parent.mkdir(exist_ok=True) | ||||||
|  |     try: | ||||||
|  |         if not host_key_path.exists(): | ||||||
|  |             logger.info(f"Generating new host key at {host_key_path}") | ||||||
|  |             key = asyncssh.generate_private_key("ecdsa-sha2-nistp256") | ||||||
|  |             key.write_private_key(host_key_path) | ||||||
|  |         else: | ||||||
|  |             logger.info(f"Loading existing host key from {host_key_path}") | ||||||
|  |             key = asyncssh.read_private_key(host_key_path) | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error(f"Failed to generate or load host key: {e}") | ||||||
|  |         raise | ||||||
|  | 
 | ||||||
|  |     # Start the SSH server with SFTP support | ||||||
|  |     logger.info("Starting SFTP server on localhost:8022") | ||||||
|  |     try: | ||||||
|  |         x = await asyncssh.listen( | ||||||
|  |             host=host, | ||||||
|  |             port=port, | ||||||
|  |             #process_factory=handle_client, | ||||||
|  |             server_host_keys=[key], | ||||||
|  |             server_factory=MySSHServer, | ||||||
|  |             sftp_factory=MySFTPServer | ||||||
|  |         ) | ||||||
|  |         return x | ||||||
|  |     except Exception as e: | ||||||
|  |         logger.error(f"Failed to start SFTP server: {e}") | ||||||
|  |         raise | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @ -40,7 +40,7 @@ def uid(value: str = None, ns: str = DEFAULT_NS) -> str: | |||||||
|     return str(uuid.uuid5(UIDNS(ns), value)) |     return str(uuid.uuid5(UIDNS(ns), value)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def hash(data: str, salt: str = DEFAULT_SALT) -> str: | def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str: | ||||||
|     """Hash the given data with the specified salt using SHA-256. |     """Hash the given data with the specified salt using SHA-256. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
| @ -63,8 +63,10 @@ async def hash(data: str, salt: str = DEFAULT_SALT) -> str: | |||||||
|     obj = hashlib.sha256(salted) |     obj = hashlib.sha256(salted) | ||||||
|     return obj.hexdigest() |     return obj.hexdigest() | ||||||
| 
 | 
 | ||||||
|  | async def hash(data: str, salt: str = DEFAULT_SALT) -> str: | ||||||
|  |     return hash_sync(data, salt) | ||||||
| 
 | 
 | ||||||
| async def verify(string: str, hashed: str) -> bool: | def verify_sync(string: str, hashed: str) -> bool: | ||||||
|     """Verify if the given string matches the hashed value. |     """Verify if the given string matches the hashed value. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
| @ -74,4 +76,7 @@ async def verify(string: str, hashed: str) -> bool: | |||||||
|     Returns: |     Returns: | ||||||
|         bool: True if the string matches the hashed value, False otherwise. |         bool: True if the string matches the hashed value, False otherwise. | ||||||
|     """ |     """ | ||||||
|     return await hash(string) == hashed |     return hash_sync(string) == hashed | ||||||
|  | 
 | ||||||
|  | async def verify(string: str, hashed: str) -> bool: | ||||||
|  |     return verify_sync(string, hashed) | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user