Update.
This commit is contained in:
parent
2c90044185
commit
01846bf23f
@ -20,6 +20,7 @@ from aiohttp_session.cookie_storage import EncryptedCookieStorage
|
||||
from app.app import Application as BaseApplication
|
||||
from jinja2 import FileSystemLoader
|
||||
|
||||
from snek.sssh import start_ssh_server
|
||||
from snek.docs.app import Application as DocsApplication
|
||||
from snek.mapper import get_mappers
|
||||
from snek.service import get_services
|
||||
@ -95,15 +96,22 @@ class Application(BaseApplication):
|
||||
self.jinja2_env.add_extension(LinkifyExtension)
|
||||
self.jinja2_env.add_extension(PythonExtension)
|
||||
self.jinja2_env.add_extension(EmojiExtension)
|
||||
|
||||
self.ssh_host = "0.0.0.0"
|
||||
self.ssh_port = 2042
|
||||
self.setup_router()
|
||||
self.ssh_server = None
|
||||
self.executor = None
|
||||
self.cache = Cache(self)
|
||||
self.services = get_services(app=self)
|
||||
self.mappers = get_mappers(app=self)
|
||||
self.on_startup.append(self.start_ssh_server)
|
||||
self.on_startup.append(self.prepare_asyncio)
|
||||
self.on_startup.append(self.prepare_database)
|
||||
|
||||
async def start_ssh_server(self, app):
|
||||
app.ssh_server = await start_ssh_server(app,app.ssh_host,app.ssh_port)
|
||||
asyncio.create_task(app.ssh_server.wait_closed())
|
||||
|
||||
async def prepare_asyncio(self, app):
|
||||
# app.loop = asyncio.get_running_loop()
|
||||
app.executor = ThreadPoolExecutor(max_workers=200)
|
||||
|
@ -32,10 +32,17 @@ class UserService(BaseService):
|
||||
user["color"] = await self.services.util.random_light_hex_color()
|
||||
return await super().save(user)
|
||||
|
||||
def authenticate_sync(self,username,password):
|
||||
user = self.get_by_username_sync(username)
|
||||
|
||||
if not user:
|
||||
return False
|
||||
if not security.verify_sync(password, user["password"]):
|
||||
return False
|
||||
return True
|
||||
|
||||
async def authenticate(self, username, password):
|
||||
print(username, password, flush=True)
|
||||
success = await self.validate_login(username, password)
|
||||
print(success, flush=True)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
@ -61,6 +68,20 @@ class UserService(BaseService):
|
||||
if not path.exists():
|
||||
return None
|
||||
return path
|
||||
|
||||
def get_by_username_sync(self, username):
|
||||
user = self.mapper.db["user"].find_one(username=username, deleted_at=None)
|
||||
return dict(user)
|
||||
|
||||
def get_home_folder_by_username(self, username):
|
||||
user = self.get_by_username_sync(username)
|
||||
folder = pathlib.Path(f"./drive/{user['uid']}")
|
||||
if not folder.exists():
|
||||
try:
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
except:
|
||||
pass
|
||||
return folder
|
||||
|
||||
async def get_home_folder(self, user_uid):
|
||||
folder = pathlib.Path(f"./drive/{user_uid}")
|
||||
|
95
src/snek/sssh.py
Normal file
95
src/snek/sssh.py
Normal file
@ -0,0 +1,95 @@
|
||||
import asyncio
|
||||
import asyncssh
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import pty
|
||||
|
||||
global _app
|
||||
|
||||
|
||||
def set_app(app):
|
||||
global _app
|
||||
_app = app
|
||||
|
||||
def get_app():
|
||||
return _app
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler("sftp_server.log"),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
roots = {}
|
||||
|
||||
class MySFTPServer(asyncssh.SFTPServer):
|
||||
|
||||
def __init__(self, chan: asyncssh.SSHServerChannel):
|
||||
self.root = get_app().services.user.get_home_folder_by_username(
|
||||
chan.get_extra_info('username')
|
||||
)
|
||||
self.root.mkdir(exist_ok=True)
|
||||
self.root = str(self.root)
|
||||
super().__init__(chan, chroot=self.root)
|
||||
|
||||
def map_path(self, path):
|
||||
|
||||
# Map client path to local filesystem path
|
||||
mapped_path = Path(self.root).joinpath(path.lstrip(b"/").decode())
|
||||
#Path(mapped_path).mkdir(exist_ok=True)
|
||||
print(mapped_path)
|
||||
logger.debug(f"Mapping client path {path} to {mapped_path}")
|
||||
return str(mapped_path).encode()
|
||||
|
||||
class MySSHServer(asyncssh.SSHServer):
|
||||
def password_auth_supported(self):
|
||||
return True
|
||||
|
||||
def validate_password(self, username, password):
|
||||
# Simple in-memory user database
|
||||
|
||||
logger.debug(f"Validating credentials for user {username}")
|
||||
return get_app().services.user.authenticate_sync(username,password)
|
||||
|
||||
|
||||
async def start_ssh_server(app,host,port):
|
||||
set_app(app)
|
||||
logger.info("Starting SFTP server setup")
|
||||
# Server host key (generate or load one)
|
||||
host_key_path = Path("drive") / ".ssh" / "sftp_server_key"
|
||||
host_key_path.parent.mkdir(exist_ok=True)
|
||||
try:
|
||||
if not host_key_path.exists():
|
||||
logger.info(f"Generating new host key at {host_key_path}")
|
||||
key = asyncssh.generate_private_key("ecdsa-sha2-nistp256")
|
||||
key.write_private_key(host_key_path)
|
||||
else:
|
||||
logger.info(f"Loading existing host key from {host_key_path}")
|
||||
key = asyncssh.read_private_key(host_key_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate or load host key: {e}")
|
||||
raise
|
||||
|
||||
# Start the SSH server with SFTP support
|
||||
logger.info("Starting SFTP server on localhost:8022")
|
||||
try:
|
||||
x = await asyncssh.listen(
|
||||
host=host,
|
||||
port=port,
|
||||
#process_factory=handle_client,
|
||||
server_host_keys=[key],
|
||||
server_factory=MySSHServer,
|
||||
sftp_factory=MySFTPServer
|
||||
)
|
||||
return x
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start SFTP server: {e}")
|
||||
raise
|
||||
|
||||
|
@ -40,7 +40,7 @@ def uid(value: str = None, ns: str = DEFAULT_NS) -> str:
|
||||
return str(uuid.uuid5(UIDNS(ns), value))
|
||||
|
||||
|
||||
async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
|
||||
def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str:
|
||||
"""Hash the given data with the specified salt using SHA-256.
|
||||
|
||||
Args:
|
||||
@ -63,8 +63,10 @@ async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
|
||||
obj = hashlib.sha256(salted)
|
||||
return obj.hexdigest()
|
||||
|
||||
async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
|
||||
return hash_sync(data, salt)
|
||||
|
||||
async def verify(string: str, hashed: str) -> bool:
|
||||
def verify_sync(string: str, hashed: str) -> bool:
|
||||
"""Verify if the given string matches the hashed value.
|
||||
|
||||
Args:
|
||||
@ -74,4 +76,7 @@ async def verify(string: str, hashed: str) -> bool:
|
||||
Returns:
|
||||
bool: True if the string matches the hashed value, False otherwise.
|
||||
"""
|
||||
return await hash(string) == hashed
|
||||
return hash_sync(string) == hashed
|
||||
|
||||
async def verify(string: str, hashed: str) -> bool:
|
||||
return verify_sync(string, hashed)
|
||||
|
Loading…
Reference in New Issue
Block a user