Update.
This commit is contained in:
parent
2c90044185
commit
01846bf23f
@ -20,6 +20,7 @@ from aiohttp_session.cookie_storage import EncryptedCookieStorage
|
|||||||
from app.app import Application as BaseApplication
|
from app.app import Application as BaseApplication
|
||||||
from jinja2 import FileSystemLoader
|
from jinja2 import FileSystemLoader
|
||||||
|
|
||||||
|
from snek.sssh import start_ssh_server
|
||||||
from snek.docs.app import Application as DocsApplication
|
from snek.docs.app import Application as DocsApplication
|
||||||
from snek.mapper import get_mappers
|
from snek.mapper import get_mappers
|
||||||
from snek.service import get_services
|
from snek.service import get_services
|
||||||
@ -95,15 +96,22 @@ class Application(BaseApplication):
|
|||||||
self.jinja2_env.add_extension(LinkifyExtension)
|
self.jinja2_env.add_extension(LinkifyExtension)
|
||||||
self.jinja2_env.add_extension(PythonExtension)
|
self.jinja2_env.add_extension(PythonExtension)
|
||||||
self.jinja2_env.add_extension(EmojiExtension)
|
self.jinja2_env.add_extension(EmojiExtension)
|
||||||
|
self.ssh_host = "0.0.0.0"
|
||||||
|
self.ssh_port = 2042
|
||||||
self.setup_router()
|
self.setup_router()
|
||||||
|
self.ssh_server = None
|
||||||
self.executor = None
|
self.executor = None
|
||||||
self.cache = Cache(self)
|
self.cache = Cache(self)
|
||||||
self.services = get_services(app=self)
|
self.services = get_services(app=self)
|
||||||
self.mappers = get_mappers(app=self)
|
self.mappers = get_mappers(app=self)
|
||||||
|
self.on_startup.append(self.start_ssh_server)
|
||||||
self.on_startup.append(self.prepare_asyncio)
|
self.on_startup.append(self.prepare_asyncio)
|
||||||
self.on_startup.append(self.prepare_database)
|
self.on_startup.append(self.prepare_database)
|
||||||
|
|
||||||
|
async def start_ssh_server(self, app):
|
||||||
|
app.ssh_server = await start_ssh_server(app,app.ssh_host,app.ssh_port)
|
||||||
|
asyncio.create_task(app.ssh_server.wait_closed())
|
||||||
|
|
||||||
async def prepare_asyncio(self, app):
|
async def prepare_asyncio(self, app):
|
||||||
# app.loop = asyncio.get_running_loop()
|
# app.loop = asyncio.get_running_loop()
|
||||||
app.executor = ThreadPoolExecutor(max_workers=200)
|
app.executor = ThreadPoolExecutor(max_workers=200)
|
||||||
|
@ -32,10 +32,17 @@ class UserService(BaseService):
|
|||||||
user["color"] = await self.services.util.random_light_hex_color()
|
user["color"] = await self.services.util.random_light_hex_color()
|
||||||
return await super().save(user)
|
return await super().save(user)
|
||||||
|
|
||||||
|
def authenticate_sync(self,username,password):
|
||||||
|
user = self.get_by_username_sync(username)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
if not security.verify_sync(password, user["password"]):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
async def authenticate(self, username, password):
|
async def authenticate(self, username, password):
|
||||||
print(username, password, flush=True)
|
|
||||||
success = await self.validate_login(username, password)
|
success = await self.validate_login(username, password)
|
||||||
print(success, flush=True)
|
|
||||||
if not success:
|
if not success:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -61,6 +68,20 @@ class UserService(BaseService):
|
|||||||
if not path.exists():
|
if not path.exists():
|
||||||
return None
|
return None
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
def get_by_username_sync(self, username):
|
||||||
|
user = self.mapper.db["user"].find_one(username=username, deleted_at=None)
|
||||||
|
return dict(user)
|
||||||
|
|
||||||
|
def get_home_folder_by_username(self, username):
|
||||||
|
user = self.get_by_username_sync(username)
|
||||||
|
folder = pathlib.Path(f"./drive/{user['uid']}")
|
||||||
|
if not folder.exists():
|
||||||
|
try:
|
||||||
|
folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return folder
|
||||||
|
|
||||||
async def get_home_folder(self, user_uid):
|
async def get_home_folder(self, user_uid):
|
||||||
folder = pathlib.Path(f"./drive/{user_uid}")
|
folder = pathlib.Path(f"./drive/{user_uid}")
|
||||||
|
95
src/snek/sssh.py
Normal file
95
src/snek/sssh.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
import asyncio
|
||||||
|
import asyncssh
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
import pty
|
||||||
|
|
||||||
|
global _app
|
||||||
|
|
||||||
|
|
||||||
|
def set_app(app):
|
||||||
|
global _app
|
||||||
|
_app = app
|
||||||
|
|
||||||
|
def get_app():
|
||||||
|
return _app
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler("sftp_server.log"),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
roots = {}
|
||||||
|
|
||||||
|
class MySFTPServer(asyncssh.SFTPServer):
|
||||||
|
|
||||||
|
def __init__(self, chan: asyncssh.SSHServerChannel):
|
||||||
|
self.root = get_app().services.user.get_home_folder_by_username(
|
||||||
|
chan.get_extra_info('username')
|
||||||
|
)
|
||||||
|
self.root.mkdir(exist_ok=True)
|
||||||
|
self.root = str(self.root)
|
||||||
|
super().__init__(chan, chroot=self.root)
|
||||||
|
|
||||||
|
def map_path(self, path):
|
||||||
|
|
||||||
|
# Map client path to local filesystem path
|
||||||
|
mapped_path = Path(self.root).joinpath(path.lstrip(b"/").decode())
|
||||||
|
#Path(mapped_path).mkdir(exist_ok=True)
|
||||||
|
print(mapped_path)
|
||||||
|
logger.debug(f"Mapping client path {path} to {mapped_path}")
|
||||||
|
return str(mapped_path).encode()
|
||||||
|
|
||||||
|
class MySSHServer(asyncssh.SSHServer):
|
||||||
|
def password_auth_supported(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_password(self, username, password):
|
||||||
|
# Simple in-memory user database
|
||||||
|
|
||||||
|
logger.debug(f"Validating credentials for user {username}")
|
||||||
|
return get_app().services.user.authenticate_sync(username,password)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_ssh_server(app,host,port):
|
||||||
|
set_app(app)
|
||||||
|
logger.info("Starting SFTP server setup")
|
||||||
|
# Server host key (generate or load one)
|
||||||
|
host_key_path = Path("drive") / ".ssh" / "sftp_server_key"
|
||||||
|
host_key_path.parent.mkdir(exist_ok=True)
|
||||||
|
try:
|
||||||
|
if not host_key_path.exists():
|
||||||
|
logger.info(f"Generating new host key at {host_key_path}")
|
||||||
|
key = asyncssh.generate_private_key("ecdsa-sha2-nistp256")
|
||||||
|
key.write_private_key(host_key_path)
|
||||||
|
else:
|
||||||
|
logger.info(f"Loading existing host key from {host_key_path}")
|
||||||
|
key = asyncssh.read_private_key(host_key_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to generate or load host key: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Start the SSH server with SFTP support
|
||||||
|
logger.info("Starting SFTP server on localhost:8022")
|
||||||
|
try:
|
||||||
|
x = await asyncssh.listen(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
#process_factory=handle_client,
|
||||||
|
server_host_keys=[key],
|
||||||
|
server_factory=MySSHServer,
|
||||||
|
sftp_factory=MySFTPServer
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start SFTP server: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
@ -40,7 +40,7 @@ def uid(value: str = None, ns: str = DEFAULT_NS) -> str:
|
|||||||
return str(uuid.uuid5(UIDNS(ns), value))
|
return str(uuid.uuid5(UIDNS(ns), value))
|
||||||
|
|
||||||
|
|
||||||
async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
|
def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str:
|
||||||
"""Hash the given data with the specified salt using SHA-256.
|
"""Hash the given data with the specified salt using SHA-256.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -63,8 +63,10 @@ async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
|
|||||||
obj = hashlib.sha256(salted)
|
obj = hashlib.sha256(salted)
|
||||||
return obj.hexdigest()
|
return obj.hexdigest()
|
||||||
|
|
||||||
|
async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
|
||||||
|
return hash_sync(data, salt)
|
||||||
|
|
||||||
async def verify(string: str, hashed: str) -> bool:
|
def verify_sync(string: str, hashed: str) -> bool:
|
||||||
"""Verify if the given string matches the hashed value.
|
"""Verify if the given string matches the hashed value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -74,4 +76,7 @@ async def verify(string: str, hashed: str) -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if the string matches the hashed value, False otherwise.
|
bool: True if the string matches the hashed value, False otherwise.
|
||||||
"""
|
"""
|
||||||
return await hash(string) == hashed
|
return hash_sync(string) == hashed
|
||||||
|
|
||||||
|
async def verify(string: str, hashed: str) -> bool:
|
||||||
|
return verify_sync(string, hashed)
|
||||||
|
Loading…
Reference in New Issue
Block a user