Compare commits
	
		
			3 Commits
		
	
	
		
			main
			...
			bugfix/ret
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					51b4898078 | ||
| 
						 | 
					40f9646251 | ||
| 5e99e894e9 | 
@ -40,7 +40,8 @@ dependencies = [
 | 
				
			|||||||
    "pillow-heif",
 | 
					    "pillow-heif",
 | 
				
			||||||
    "IP2Location",
 | 
					    "IP2Location",
 | 
				
			||||||
    "bleach",
 | 
					    "bleach",
 | 
				
			||||||
    "sentry-sdk"
 | 
					    "sentry-sdk",
 | 
				
			||||||
 | 
					    "aiosqlite"
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[tool.setuptools.packages.find]
 | 
					[tool.setuptools.packages.find]
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,3 @@
 | 
				
			|||||||
import logging 
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
logging.basicConfig(level=logging.INFO)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import pathlib
 | 
					import pathlib
 | 
				
			||||||
import shutil
 | 
					import shutil
 | 
				
			||||||
import sqlite3
 | 
					import sqlite3
 | 
				
			||||||
 | 
				
			|||||||
@ -9,13 +9,10 @@ from contextlib import asynccontextmanager
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from snek import snode
 | 
					from snek import snode
 | 
				
			||||||
from snek.view.threads import ThreadsView
 | 
					from snek.view.threads import ThreadsView
 | 
				
			||||||
 | 
					from snek.system.ads import AsyncDataSet
 | 
				
			||||||
logging.basicConfig(level=logging.DEBUG)
 | 
					logging.basicConfig(level=logging.DEBUG)
 | 
				
			||||||
from concurrent.futures import ThreadPoolExecutor
 | 
					from concurrent.futures import ThreadPoolExecutor
 | 
				
			||||||
from ipaddress import ip_address
 | 
					from ipaddress import ip_address
 | 
				
			||||||
import time
 | 
					 | 
				
			||||||
import uuid
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import IP2Location
 | 
					import IP2Location
 | 
				
			||||||
from aiohttp import web
 | 
					from aiohttp import web
 | 
				
			||||||
@ -34,6 +31,7 @@ from snek.sgit import GitApplication
 | 
				
			|||||||
from snek.sssh import start_ssh_server
 | 
					from snek.sssh import start_ssh_server
 | 
				
			||||||
from snek.system import http
 | 
					from snek.system import http
 | 
				
			||||||
from snek.system.cache import Cache
 | 
					from snek.system.cache import Cache
 | 
				
			||||||
 | 
					from snek.system.stats import middleware as stats_middleware, create_stats_structure, stats_handler
 | 
				
			||||||
from snek.system.markdown import MarkdownExtension
 | 
					from snek.system.markdown import MarkdownExtension
 | 
				
			||||||
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
 | 
					from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
 | 
				
			||||||
from snek.system.profiler import profiler_handler
 | 
					from snek.system.profiler import profiler_handler
 | 
				
			||||||
@ -43,7 +41,6 @@ from snek.system.template import (
 | 
				
			|||||||
    PythonExtension,
 | 
					    PythonExtension,
 | 
				
			||||||
    sanitize_html,
 | 
					    sanitize_html,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from snek.view.new import NewView
 | 
					 | 
				
			||||||
from snek.view.about import AboutHTMLView, AboutMDView
 | 
					from snek.view.about import AboutHTMLView, AboutMDView
 | 
				
			||||||
from snek.view.avatar import AvatarView
 | 
					from snek.view.avatar import AvatarView
 | 
				
			||||||
from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView
 | 
					from snek.view.channel import ChannelAttachmentView,ChannelAttachmentUploadView, ChannelView
 | 
				
			||||||
@ -110,7 +107,7 @@ async def ip2location_middleware(request, handler):
 | 
				
			|||||||
    user["city"]
 | 
					    user["city"]
 | 
				
			||||||
    if user["city"] != location.city:
 | 
					    if user["city"] != location.city:
 | 
				
			||||||
        user["country_long"] = location.country
 | 
					        user["country_long"] = location.country
 | 
				
			||||||
        user["country_short"] = locaion.country_short
 | 
					        user["country_short"] = location.country_short
 | 
				
			||||||
        user["city"] = location.city
 | 
					        user["city"] = location.city
 | 
				
			||||||
        user["region"] = location.region
 | 
					        user["region"] = location.region
 | 
				
			||||||
        user["latitude"] = location.latitude
 | 
					        user["latitude"] = location.latitude
 | 
				
			||||||
@ -131,7 +128,10 @@ async def trailing_slash_middleware(request, handler):
 | 
				
			|||||||
class Application(BaseApplication):
 | 
					class Application(BaseApplication):
 | 
				
			||||||
    def __init__(self, *args, **kwargs):
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
        middlewares = [
 | 
					        middlewares = [
 | 
				
			||||||
 | 
					            stats_middleware,
 | 
				
			||||||
            cors_middleware,
 | 
					            cors_middleware,
 | 
				
			||||||
 | 
					            web.normalize_path_middleware(merge_slashes=True),
 | 
				
			||||||
 | 
					            ip2location_middleware,
 | 
				
			||||||
            csp_middleware,
 | 
					            csp_middleware,
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
        self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
 | 
					        self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
 | 
				
			||||||
@ -142,6 +142,7 @@ class Application(BaseApplication):
 | 
				
			|||||||
            client_max_size=1024 * 1024 * 1024 * 5 * args,
 | 
					            client_max_size=1024 * 1024 * 1024 * 5 * args,
 | 
				
			||||||
            **kwargs,
 | 
					            **kwargs,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        self.db = AsyncDataSet(kwargs["db_path"].replace("sqlite:///", ""))
 | 
				
			||||||
        session_setup(self, EncryptedCookieStorage(SESSION_KEY))
 | 
					        session_setup(self, EncryptedCookieStorage(SESSION_KEY))
 | 
				
			||||||
        self.tasks = asyncio.Queue()
 | 
					        self.tasks = asyncio.Queue()
 | 
				
			||||||
        self._middlewares.append(session_middleware)
 | 
					        self._middlewares.append(session_middleware)
 | 
				
			||||||
@ -164,17 +165,22 @@ class Application(BaseApplication):
 | 
				
			|||||||
        self.mappers = get_mappers(app=self)
 | 
					        self.mappers = get_mappers(app=self)
 | 
				
			||||||
        self.broadcast_service = None
 | 
					        self.broadcast_service = None
 | 
				
			||||||
        self.user_availability_service_task = None
 | 
					        self.user_availability_service_task = None
 | 
				
			||||||
        
 | 
					
 | 
				
			||||||
        self.setup_router()
 | 
					        self.setup_router()
 | 
				
			||||||
        base_path = pathlib.Path(__file__).parent
 | 
					        base_path = pathlib.Path(__file__).parent
 | 
				
			||||||
        self.ip2location = IP2Location.IP2Location(
 | 
					        self.ip2location = IP2Location.IP2Location(
 | 
				
			||||||
            base_path.joinpath("IP2LOCATION-LITE-DB11.BIN")
 | 
					            base_path.joinpath("IP2LOCATION-LITE-DB11.BIN")
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        self.on_startup.append(self.prepare_stats)
 | 
				
			||||||
        self.on_startup.append(self.prepare_asyncio)
 | 
					        self.on_startup.append(self.prepare_asyncio)
 | 
				
			||||||
        self.on_startup.append(self.start_user_availability_service)
 | 
					        self.on_startup.append(self.start_user_availability_service)
 | 
				
			||||||
        self.on_startup.append(self.start_ssh_server)
 | 
					        self.on_startup.append(self.start_ssh_server)
 | 
				
			||||||
        self.on_startup.append(self.prepare_database)
 | 
					        #self.on_startup.append(self.prepare_database)
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
 | 
					    async def prepare_stats(self, app):
 | 
				
			||||||
 | 
					        app['stats'] = create_stats_structure()
 | 
				
			||||||
 | 
					        print("Stats prepared", flush=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def uptime_seconds(self):
 | 
					    def uptime_seconds(self):
 | 
				
			||||||
@ -237,21 +243,11 @@ class Application(BaseApplication):
 | 
				
			|||||||
            except Exception as ex:
 | 
					            except Exception as ex:
 | 
				
			||||||
                print(ex)
 | 
					                print(ex)
 | 
				
			||||||
            self.db.commit()
 | 
					            self.db.commit()
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def prepare_database(self, app):
 | 
					    async def prepare_database(self, app):
 | 
				
			||||||
        self.db.query("PRAGMA journal_mode=WAL")
 | 
					        await self.db.query_raw("PRAGMA journal_mode=WAL")
 | 
				
			||||||
        self.db.query("PRAGMA syncnorm=off")
 | 
					        await self.db.query_raw("PRAGMA syncnorm=off")
 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            if not self.db["user"].has_index("username"):
 | 
					 | 
				
			||||||
                self.db["user"].create_index("username", unique=True)
 | 
					 | 
				
			||||||
            if not self.db["channel_member"].has_index(["channel_uid", "user_uid"]):
 | 
					 | 
				
			||||||
                self.db["channel_member"].create_index(["channel_uid", "user_uid"])
 | 
					 | 
				
			||||||
            if not self.db["channel_message"].has_index(["channel_uid", "user_uid"]):
 | 
					 | 
				
			||||||
                self.db["channel_message"].create_index(["channel_uid", "user_uid"])
 | 
					 | 
				
			||||||
        except:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        await self.services.drive.prepare_all()
 | 
					        await self.services.drive.prepare_all()
 | 
				
			||||||
        self.loop.create_task(self.task_runner())
 | 
					        self.loop.create_task(self.task_runner())
 | 
				
			||||||
@ -264,8 +260,6 @@ class Application(BaseApplication):
 | 
				
			|||||||
            name="static",
 | 
					            name="static",
 | 
				
			||||||
            show_index=True,
 | 
					            show_index=True,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.router.add_view("/new.html", NewView)
 | 
					 | 
				
			||||||
        self.router.add_view("/profiler.html", profiler_handler)
 | 
					        self.router.add_view("/profiler.html", profiler_handler)
 | 
				
			||||||
        self.router.add_view("/container/sock/{channel_uid}.json", ContainerView)
 | 
					        self.router.add_view("/container/sock/{channel_uid}.json", ContainerView)
 | 
				
			||||||
        self.router.add_view("/about.html", AboutHTMLView)
 | 
					        self.router.add_view("/about.html", AboutHTMLView)
 | 
				
			||||||
@ -284,9 +278,9 @@ class Application(BaseApplication):
 | 
				
			|||||||
        self.router.add_view("/login.json", LoginView)
 | 
					        self.router.add_view("/login.json", LoginView)
 | 
				
			||||||
        self.router.add_view("/register.html", RegisterView)
 | 
					        self.router.add_view("/register.html", RegisterView)
 | 
				
			||||||
        self.router.add_view("/register.json", RegisterView)
 | 
					        self.router.add_view("/register.json", RegisterView)
 | 
				
			||||||
       # self.router.add_view("/drive/{rel_path:.*}", DriveView)
 | 
					        self.router.add_view("/drive/{rel_path:.*}", DriveView)
 | 
				
			||||||
       ## self.router.add_view("/drive.bin", UploadView)
 | 
					        self.router.add_view("/drive.bin", UploadView)
 | 
				
			||||||
       # self.router.add_view("/drive.bin/{uid}.{ext}", UploadView)
 | 
					        self.router.add_view("/drive.bin/{uid}.{ext}", UploadView)
 | 
				
			||||||
        self.router.add_view("/search-user.html", SearchUserView)
 | 
					        self.router.add_view("/search-user.html", SearchUserView)
 | 
				
			||||||
        self.router.add_view("/search-user.json", SearchUserView)
 | 
					        self.router.add_view("/search-user.json", SearchUserView)
 | 
				
			||||||
        self.router.add_view("/avatar/{uid}.svg", AvatarView)
 | 
					        self.router.add_view("/avatar/{uid}.svg", AvatarView)
 | 
				
			||||||
@ -294,25 +288,26 @@ class Application(BaseApplication):
 | 
				
			|||||||
        self.router.add_get("/http-photo", self.handle_http_photo)
 | 
					        self.router.add_get("/http-photo", self.handle_http_photo)
 | 
				
			||||||
        self.router.add_get("/rpc.ws", RPCView)
 | 
					        self.router.add_get("/rpc.ws", RPCView)
 | 
				
			||||||
        self.router.add_get("/c/{channel:.*}", ChannelView)
 | 
					        self.router.add_get("/c/{channel:.*}", ChannelView)
 | 
				
			||||||
        #self.router.add_view(
 | 
					        self.router.add_view(
 | 
				
			||||||
        #    "/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
 | 
					            "/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
 | 
				
			||||||
        #)
 | 
					        )
 | 
				
			||||||
        #self.router.add_view(
 | 
					        self.router.add_view(
 | 
				
			||||||
        #    "/channel/{channel_uid}/drive.json", ChannelDriveApiView
 | 
					            "/channel/{channel_uid}/drive.json", ChannelDriveApiView
 | 
				
			||||||
        #)
 | 
					        )
 | 
				
			||||||
        self.router.add_view(
 | 
					        self.router.add_view(
 | 
				
			||||||
            "/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView
 | 
					            "/channel/{channel_uid}/attachment.sock", ChannelAttachmentUploadView
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        self.router.add_view(
 | 
					        self.router.add_view(
 | 
				
			||||||
            "/channel/attachment/{relative_url:.*}", ChannelAttachmentView
 | 
					            "/channel/attachment/{relative_url:.*}", ChannelAttachmentView
 | 
				
			||||||
        )#
 | 
					        )
 | 
				
			||||||
        self.router.add_view("/channel/{channel}.html", WebView)
 | 
					        self.router.add_view("/channel/{channel}.html", WebView)
 | 
				
			||||||
        self.router.add_view("/threads.html", ThreadsView)
 | 
					        self.router.add_view("/threads.html", ThreadsView)
 | 
				
			||||||
        self.router.add_view("/terminal.ws", TerminalSocketView)
 | 
					        self.router.add_view("/terminal.ws", TerminalSocketView)
 | 
				
			||||||
        self.router.add_view("/terminal.html", TerminalView)
 | 
					        self.router.add_view("/terminal.html", TerminalView)
 | 
				
			||||||
        #self.router.add_view("/drive.json", DriveApiView)
 | 
					        self.router.add_view("/drive.json", DriveApiView)
 | 
				
			||||||
        #self.router.add_view("/drive.html", DriveView)
 | 
					        self.router.add_view("/drive.html", DriveView)
 | 
				
			||||||
        #self.router.add_view("/drive/{drive}.json", DriveView)
 | 
					        self.router.add_view("/drive/{drive}.json", DriveView)
 | 
				
			||||||
 | 
					        self.router.add_get("/stats.html", stats_handler)
 | 
				
			||||||
        self.router.add_view("/stats.json", StatsView)
 | 
					        self.router.add_view("/stats.json", StatsView)
 | 
				
			||||||
        self.router.add_view("/user/{user}.html", UserView)
 | 
					        self.router.add_view("/user/{user}.html", UserView)
 | 
				
			||||||
        self.router.add_view("/repository/{username}/{repository}", RepositoryView)
 | 
					        self.router.add_view("/repository/{username}/{repository}", RepositoryView)
 | 
				
			||||||
@ -367,8 +362,6 @@ class Application(BaseApplication):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # @time_cache_async(60)
 | 
					    # @time_cache_async(60)
 | 
				
			||||||
    async def render_template(self, template, request, context=None):
 | 
					    async def render_template(self, template, request, context=None):
 | 
				
			||||||
        start_time = time.perf_counter()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        channels = []
 | 
					        channels = []
 | 
				
			||||||
        if not context:
 | 
					        if not context:
 | 
				
			||||||
            context = {}
 | 
					            context = {}
 | 
				
			||||||
@ -419,22 +412,20 @@ class Application(BaseApplication):
 | 
				
			|||||||
        self.jinja2_env.loader = await self.get_user_template_loader(
 | 
					        self.jinja2_env.loader = await self.get_user_template_loader(
 | 
				
			||||||
            request.session.get("uid")
 | 
					            request.session.get("uid")
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            context["nonce"] = request['csp_nonce']
 | 
					            context["nonce"] = request['csp_nonce']
 | 
				
			||||||
        except:
 | 
					        except:
 | 
				
			||||||
            context['nonce'] = '?' 
 | 
					            context['nonce'] = '?'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        rendered = await super().render_template(template, request, context)
 | 
					        rendered = await super().render_template(template, request, context)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.jinja2_env.loader = self.original_loader
 | 
					        self.jinja2_env.loader = self.original_loader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        end_time = time.perf_counter()
 | 
					 | 
				
			||||||
        print(f"render_template took {end_time - start_time:.4f} seconds")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # rendered.text = whitelist_attributes(rendered.text)
 | 
					        # rendered.text = whitelist_attributes(rendered.text)
 | 
				
			||||||
        # rendered.headers['Content-Lenght'] = len(rendered.text)
 | 
					        # rendered.headers['Content-Lenght'] = len(rendered.text)
 | 
				
			||||||
        return rendered
 | 
					        return rendered
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def static_handler(self, request):
 | 
					    async def static_handler(self, request):
 | 
				
			||||||
        file_name = request.match_info.get("filename", "")
 | 
					        file_name = request.match_info.get("filename", "")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -460,7 +451,7 @@ class Application(BaseApplication):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    async def get_user_template_loader(self, uid=None):
 | 
					    async def get_user_template_loader(self, uid=None):
 | 
				
			||||||
        template_paths = []
 | 
					        template_paths = []
 | 
				
			||||||
        for admin_uid in self.services.user.get_admin_uids():
 | 
					        for admin_uid in await self.services.user.get_admin_uids():
 | 
				
			||||||
            user_template_path = await self.services.user.get_template_path(admin_uid)
 | 
					            user_template_path = await self.services.user.get_template_path(admin_uid)
 | 
				
			||||||
            if user_template_path:
 | 
					            if user_template_path:
 | 
				
			||||||
                template_paths.append(user_template_path)
 | 
					                template_paths.append(user_template_path)
 | 
				
			||||||
@ -472,7 +463,7 @@ class Application(BaseApplication):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        template_paths.append(self.template_path)
 | 
					        template_paths.append(self.template_path)
 | 
				
			||||||
        return FileSystemLoader(template_paths)
 | 
					        return FileSystemLoader(template_paths)
 | 
				
			||||||
    
 | 
					
 | 
				
			||||||
    @asynccontextmanager
 | 
					    @asynccontextmanager
 | 
				
			||||||
    async def no_save(self):
 | 
					    async def no_save(self):
 | 
				
			||||||
        stats = {
 | 
					        stats = {
 | 
				
			||||||
@ -487,7 +478,7 @@ class Application(BaseApplication):
 | 
				
			|||||||
        self.services.channel_message.mapper.save = patched_save
 | 
					        self.services.channel_message.mapper.save = patched_save
 | 
				
			||||||
        raised_exception = None
 | 
					        raised_exception = None
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            yield 
 | 
					            yield
 | 
				
			||||||
        except Exception as ex:
 | 
					        except Exception as ex:
 | 
				
			||||||
            raised_exception = ex
 | 
					            raised_exception = ex
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
 | 
				
			|||||||
@ -6,11 +6,11 @@ class UserMapper(BaseMapper):
 | 
				
			|||||||
    table_name = "user"
 | 
					    table_name = "user"
 | 
				
			||||||
    model_class = UserModel
 | 
					    model_class = UserModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_admin_uids(self):
 | 
					    async def get_admin_uids(self):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            return [
 | 
					            return [
 | 
				
			||||||
                user["uid"]
 | 
					                user["uid"]
 | 
				
			||||||
                for user in self.db.query(
 | 
					                for user in await self.db.query(
 | 
				
			||||||
                    "SELECT uid FROM user WHERE is_admin = :is_admin",
 | 
					                    "SELECT uid FROM user WHERE is_admin = :is_admin",
 | 
				
			||||||
                    {"is_admin": True},
 | 
					                    {"is_admin": True},
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
				
			|||||||
@ -23,7 +23,7 @@ class ChannelModel(BaseModel):
 | 
				
			|||||||
            history_start_filter = f" AND created_at > '{self['history_start']}' "
 | 
					            history_start_filter = f" AND created_at > '{self['history_start']}' "
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            async for model in self.app.services.channel_message.query(
 | 
					            async for model in self.app.services.channel_message.query(
 | 
				
			||||||
                "SELECT uid FROM channel_message WHERE channel_uid=:channel_uid" + history_start_filter + " ORDER BY id DESC LIMIT 1",
 | 
					                "SELECT uid FROM channel_message WHERE channel_uid=:channel_uid" + history_start_filter + " ORDER BY created_at DESC LIMIT 1",
 | 
				
			||||||
                {"channel_uid": self["uid"]},
 | 
					                {"channel_uid": self["uid"]},
 | 
				
			||||||
            ):
 | 
					            ):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -39,20 +39,19 @@ CREATE TABLE IF NOT EXISTS channel (
 | 
				
			|||||||
);
 | 
					);
 | 
				
			||||||
CREATE INDEX IF NOT EXISTS ix_channel_e2577dd78b54fe28 ON channel (uid);
 | 
					CREATE INDEX IF NOT EXISTS ix_channel_e2577dd78b54fe28 ON channel (uid);
 | 
				
			||||||
CREATE TABLE IF NOT EXISTS channel_member (
 | 
					CREATE TABLE IF NOT EXISTS channel_member (
 | 
				
			||||||
	id INTEGER NOT NULL,
 | 
						id INTEGER NOT NULL, 
 | 
				
			||||||
	channel_uid TEXT,
 | 
						channel_uid TEXT, 
 | 
				
			||||||
	created_at TEXT,
 | 
						created_at TEXT, 
 | 
				
			||||||
	deleted_at TEXT,
 | 
						deleted_at TEXT, 
 | 
				
			||||||
	is_banned BOOLEAN,
 | 
						is_banned BOOLEAN, 
 | 
				
			||||||
	is_moderator BOOLEAN,
 | 
						is_moderator BOOLEAN, 
 | 
				
			||||||
	is_muted BOOLEAN,
 | 
						is_muted BOOLEAN, 
 | 
				
			||||||
	is_read_only BOOLEAN,
 | 
						is_read_only BOOLEAN, 
 | 
				
			||||||
	label TEXT,
 | 
						label TEXT, 
 | 
				
			||||||
	new_count BIGINT,
 | 
						new_count BIGINT, 
 | 
				
			||||||
	last_read_at TEXT,
 | 
						uid TEXT, 
 | 
				
			||||||
	uid TEXT,
 | 
						updated_at TEXT, 
 | 
				
			||||||
	updated_at TEXT,
 | 
						user_uid TEXT, 
 | 
				
			||||||
	user_uid TEXT,
 | 
					 | 
				
			||||||
	PRIMARY KEY (id)
 | 
						PRIMARY KEY (id)
 | 
				
			||||||
);
 | 
					);
 | 
				
			||||||
CREATE INDEX IF NOT EXISTS ix_channel_member_e2577dd78b54fe28 ON channel_member (uid);
 | 
					CREATE INDEX IF NOT EXISTS ix_channel_member_e2577dd78b54fe28 ON channel_member (uid);
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,4 @@
 | 
				
			|||||||
from snek.system.service import BaseService
 | 
					from snek.system.service import BaseService
 | 
				
			||||||
from snek.system.model import now
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ChannelMemberService(BaseService):
 | 
					class ChannelMemberService(BaseService):
 | 
				
			||||||
@ -9,7 +8,6 @@ class ChannelMemberService(BaseService):
 | 
				
			|||||||
    async def mark_as_read(self, channel_uid, user_uid):
 | 
					    async def mark_as_read(self, channel_uid, user_uid):
 | 
				
			||||||
        channel_member = await self.get(channel_uid=channel_uid, user_uid=user_uid)
 | 
					        channel_member = await self.get(channel_uid=channel_uid, user_uid=user_uid)
 | 
				
			||||||
        channel_member["new_count"] = 0
 | 
					        channel_member["new_count"] = 0
 | 
				
			||||||
        channel_member["last_read_at"] = now()
 | 
					 | 
				
			||||||
        return await self.save(channel_member)
 | 
					        return await self.save(channel_member)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_user_uids(self, channel_uid):
 | 
					    async def get_user_uids(self, channel_uid):
 | 
				
			||||||
 | 
				
			|||||||
@ -1,21 +1,6 @@
 | 
				
			|||||||
from snek.system.service import BaseService
 | 
					from snek.system.service import BaseService
 | 
				
			||||||
from snek.system.template import sanitize_html
 | 
					from snek.system.template import sanitize_html
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
from concurrent.futures import ProcessPoolExecutor
 | 
					 | 
				
			||||||
import json
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from jinja2 import Environment, FileSystemLoader
 | 
					 | 
				
			||||||
global jinja2_env
 | 
					 | 
				
			||||||
import pathlib 
 | 
					 | 
				
			||||||
template_path = pathlib.Path(__file__).parent.parent.joinpath("templates")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def render(context):
 | 
					 | 
				
			||||||
    template =jinja2_env.get_template("message.html")
 | 
					 | 
				
			||||||
    return sanitize_html(template.render(**context))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ChannelMessageService(BaseService):
 | 
					class ChannelMessageService(BaseService):
 | 
				
			||||||
    mapper_name = "channel_message"
 | 
					    mapper_name = "channel_message"
 | 
				
			||||||
@ -23,26 +8,9 @@ class ChannelMessageService(BaseService):
 | 
				
			|||||||
    def __init__(self, *args, **kwargs):
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
        super().__init__(*args, **kwargs)
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
        self._configured_indexes = False
 | 
					        self._configured_indexes = False
 | 
				
			||||||
        self._executor_pools = {}
 | 
					 | 
				
			||||||
        global jinja2_env
 | 
					 | 
				
			||||||
        jinja2_env = self.app.jinja2_env
 | 
					 | 
				
			||||||
        self._max_workers = 1
 | 
					 | 
				
			||||||
    def get_or_create_executor(self, uid):
 | 
					 | 
				
			||||||
        if not uid in self._executor_pools:
 | 
					 | 
				
			||||||
            self._executor_pools[uid] = ProcessPoolExecutor(max_workers=self._max_workers)
 | 
					 | 
				
			||||||
        print("Executors available", len(self._executor_pools))
 | 
					 | 
				
			||||||
        return self._executor_pools[uid]
 | 
					 | 
				
			||||||
    
 | 
					 | 
				
			||||||
    def delete_executor(self, uid):
 | 
					 | 
				
			||||||
        if uid in self._executor_pools:
 | 
					 | 
				
			||||||
            self._executor_pools[uid].shutdown()
 | 
					 | 
				
			||||||
            del self._executor_pools[uid]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def maintenance(self):
 | 
					    async def maintenance(self):
 | 
				
			||||||
        args = {}
 | 
					        args = {}
 | 
				
			||||||
        
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return 
 | 
					 | 
				
			||||||
        for message in self.mapper.db["channel_message"].find():
 | 
					        for message in self.mapper.db["channel_message"].find():
 | 
				
			||||||
            print(message)
 | 
					            print(message)
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
@ -61,12 +29,12 @@ class ChannelMessageService(BaseService):
 | 
				
			|||||||
                )
 | 
					                )
 | 
				
			||||||
                if html != message["html"]:
 | 
					                if html != message["html"]:
 | 
				
			||||||
                    print("Reredefined message", message["uid"])
 | 
					                    print("Reredefined message", message["uid"])
 | 
				
			||||||
                    
 | 
					
 | 
				
			||||||
            except Exception as ex:
 | 
					            except Exception as ex:
 | 
				
			||||||
                time.sleep(0.1)
 | 
					                time.sleep(0.1)
 | 
				
			||||||
                print(ex, flush=True)
 | 
					                print(ex, flush=True)
 | 
				
			||||||
                
 | 
					
 | 
				
			||||||
                
 | 
					
 | 
				
			||||||
        while True:
 | 
					        while True:
 | 
				
			||||||
            changed = 0
 | 
					            changed = 0
 | 
				
			||||||
            async for message in self.find(is_final=False):
 | 
					            async for message in self.find(is_final=False):
 | 
				
			||||||
@ -101,14 +69,10 @@ class ChannelMessageService(BaseService):
 | 
				
			|||||||
                "color": user["color"],
 | 
					                "color": user["color"],
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        loop = asyncio.get_event_loop()
 | 
					 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
 | 
					            template = self.app.jinja2_env.get_template("message.html")
 | 
				
			||||||
            context = json.loads(json.dumps(context, default=str)) 
 | 
					            model["html"] = template.render(**context)
 | 
				
			||||||
 | 
					            model['html'] = sanitize_html(model['html'])
 | 
				
			||||||
 | 
					 | 
				
			||||||
            model["html"] = await loop.run_in_executor(self.get_or_create_executor(model["uid"]), render,context)
 | 
					 | 
				
			||||||
            #model['html'] = await loop.run_in_executor(self.get_or_create_executor(user["uid"]), sanitize_html,model['html'])
 | 
					 | 
				
			||||||
        except Exception as ex:
 | 
					        except Exception as ex:
 | 
				
			||||||
            print(ex, flush=True)
 | 
					            print(ex, flush=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -127,8 +91,6 @@ class ChannelMessageService(BaseService):
 | 
				
			|||||||
                        ["deleted_at"], unique=False
 | 
					                        ["deleted_at"], unique=False
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
            self._configured_indexes = True
 | 
					            self._configured_indexes = True
 | 
				
			||||||
            if model['is_final']:
 | 
					 | 
				
			||||||
                self.delete_executor(model['uid'])
 | 
					 | 
				
			||||||
            return model
 | 
					            return model
 | 
				
			||||||
        raise Exception(f"Failed to create channel message: {model.errors}.")
 | 
					        raise Exception(f"Failed to create channel message: {model.errors}.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -140,7 +102,7 @@ class ChannelMessageService(BaseService):
 | 
				
			|||||||
        #if not message["html"].startswith("<chat-message"):
 | 
					        #if not message["html"].startswith("<chat-message"):
 | 
				
			||||||
            #message = await self.get(uid=message["uid"])
 | 
					            #message = await self.get(uid=message["uid"])
 | 
				
			||||||
            #await self.save(message)
 | 
					            #await self.save(message)
 | 
				
			||||||
        
 | 
					
 | 
				
			||||||
        return {
 | 
					        return {
 | 
				
			||||||
            "uid": message["uid"],
 | 
					            "uid": message["uid"],
 | 
				
			||||||
            "color": user["color"],
 | 
					            "color": user["color"],
 | 
				
			||||||
@ -165,15 +127,10 @@ class ChannelMessageService(BaseService):
 | 
				
			|||||||
                "color": user["color"],
 | 
					                "color": user["color"],
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        context = json.loads(json.dumps(context, default=str)) 
 | 
					        template = self.app.jinja2_env.get_template("message.html")
 | 
				
			||||||
        loop = asyncio.get_event_loop()
 | 
					        model["html"] = template.render(**context)
 | 
				
			||||||
        model["html"] = await loop.run_in_executor(self.get_or_create_executor(model["uid"]), render, context)
 | 
					        model['html'] = sanitize_html(model['html'])
 | 
				
			||||||
        #model['html'] = await loop.run_in_executor(self.get_or_create_executor(user["uid"]), sanitize_html,model['html'])
 | 
					        return await super().save(model)
 | 
				
			||||||
 | 
					 | 
				
			||||||
        result = await super().save(model)
 | 
					 | 
				
			||||||
        if model['is_final']:
 | 
					 | 
				
			||||||
            self.delete_executor(model['uid'])
 | 
					 | 
				
			||||||
        return result 
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def offset(self, channel_uid, page=0, timestamp=None, page_size=30):
 | 
					    async def offset(self, channel_uid, page=0, timestamp=None, page_size=30):
 | 
				
			||||||
        channel = await self.services.channel.get(uid=channel_uid)
 | 
					        channel = await self.services.channel.get(uid=channel_uid)
 | 
				
			||||||
@ -199,22 +156,22 @@ class ChannelMessageService(BaseService):
 | 
				
			|||||||
            elif page > 0:
 | 
					            elif page > 0:
 | 
				
			||||||
                async for model in self.query(
 | 
					                async for model in self.query(
 | 
				
			||||||
                    f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size",
 | 
					                    f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid WHERE created_at < :timestamp {history_start_filter} ORDER BY created_at DESC LIMIT :page_size",
 | 
				
			||||||
                    {
 | 
					                    *{
 | 
				
			||||||
                        "channel_uid": channel_uid,
 | 
					                        "channel_uid": channel_uid,
 | 
				
			||||||
                        "page_size": page_size,
 | 
					                        "page_size": page_size,
 | 
				
			||||||
                        "offset": offset,
 | 
					                        "offset": offset,
 | 
				
			||||||
                        "timestamp": timestamp,
 | 
					                        "timestamp": timestamp,
 | 
				
			||||||
                    },
 | 
					                    }.values(),
 | 
				
			||||||
                ):
 | 
					                ):
 | 
				
			||||||
                    results.append(model)
 | 
					                    results.append(model)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                async for model in self.query(
 | 
					                async for model in self.query(
 | 
				
			||||||
                    f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset",
 | 
					                    f"SELECT * FROM channel_message WHERE channel_uid=:channel_uid {history_start_filter} ORDER BY created_at DESC LIMIT :page_size OFFSET :offset",
 | 
				
			||||||
                    {
 | 
					                    *{
 | 
				
			||||||
                        "channel_uid": channel_uid,
 | 
					                        "channel_uid": channel_uid,
 | 
				
			||||||
                        "page_size": page_size,
 | 
					                        "page_size": page_size,
 | 
				
			||||||
                        "offset": offset,
 | 
					                        "offset": offset,
 | 
				
			||||||
                    },
 | 
					                    }.values(),
 | 
				
			||||||
                ):
 | 
					                ):
 | 
				
			||||||
                    results.append(model)
 | 
					                    results.append(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -44,29 +44,19 @@ class SocketService(BaseService):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    async def user_availability_service(self):
 | 
					    async def user_availability_service(self):
 | 
				
			||||||
        logger.info("User availability update service started.")
 | 
					        logger.info("User availability update service started.")
 | 
				
			||||||
        logger.debug("Entering the main loop.")
 | 
					 | 
				
			||||||
        while True:
 | 
					        while True:
 | 
				
			||||||
            logger.info("Updating user availability...")
 | 
					            logger.info("Updating user availability...")
 | 
				
			||||||
            logger.debug("Initializing users_updated list.")
 | 
					 | 
				
			||||||
            users_updated = []
 | 
					            users_updated = []
 | 
				
			||||||
            logger.debug("Iterating over sockets.")
 | 
					 | 
				
			||||||
            for s in self.sockets:
 | 
					            for s in self.sockets:
 | 
				
			||||||
                logger.debug(f"Checking connection status for socket: {s}.")
 | 
					 | 
				
			||||||
                if not s.is_connected:
 | 
					                if not s.is_connected:
 | 
				
			||||||
                    logger.debug("Socket is not connected, continuing to next socket.")
 | 
					 | 
				
			||||||
                    continue
 | 
					                    continue
 | 
				
			||||||
                logger.debug(f"Checking if user {s.user} is already updated.")
 | 
					 | 
				
			||||||
                if s.user not in users_updated:
 | 
					                if s.user not in users_updated:
 | 
				
			||||||
                    logger.debug(f"Updating last_ping for user: {s.user}.")
 | 
					 | 
				
			||||||
                    s.user["last_ping"] = now()
 | 
					                    s.user["last_ping"] = now()
 | 
				
			||||||
                    logger.debug(f"Saving user {s.user} to the database.")
 | 
					 | 
				
			||||||
                    await self.app.services.user.save(s.user)
 | 
					                    await self.app.services.user.save(s.user)
 | 
				
			||||||
                    logger.debug(f"Adding user {s.user} to users_updated list.")
 | 
					 | 
				
			||||||
                    users_updated.append(s.user)
 | 
					                    users_updated.append(s.user)
 | 
				
			||||||
            logger.info(
 | 
					            logger.info(
 | 
				
			||||||
                f"Updated user availability for {len(users_updated)} online users."
 | 
					                f"Updated user availability for {len(users_updated)} online users."
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            logger.debug("Sleeping for 60 seconds before the next update.")
 | 
					 | 
				
			||||||
            await asyncio.sleep(60)
 | 
					            await asyncio.sleep(60)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def add(self, ws, user_uid):
 | 
					    async def add(self, ws, user_uid):
 | 
				
			||||||
 | 
				
			|||||||
@ -101,6 +101,8 @@ class UserService(BaseService):
 | 
				
			|||||||
        model.username.value = username
 | 
					        model.username.value = username
 | 
				
			||||||
        model.password.value = await security.hash(password)
 | 
					        model.password.value = await security.hash(password)
 | 
				
			||||||
        if await self.save(model):
 | 
					        if await self.save(model):
 | 
				
			||||||
 | 
					            for x in range(10):
 | 
				
			||||||
 | 
					                print("Jazeker!!!")
 | 
				
			||||||
            if model:
 | 
					            if model:
 | 
				
			||||||
                channel = await self.services.channel.ensure_public_channel(
 | 
					                channel = await self.services.channel.ensure_public_channel(
 | 
				
			||||||
                    model["uid"]
 | 
					                    model["uid"]
 | 
				
			||||||
 | 
				
			|||||||
@ -7,7 +7,8 @@ class UserPropertyService(BaseService):
 | 
				
			|||||||
    mapper_name = "user_property"
 | 
					    mapper_name = "user_property"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def set(self, user_uid, name, value):
 | 
					    async def set(self, user_uid, name, value):
 | 
				
			||||||
        self.mapper.db["user_property"].upsert(
 | 
					        self.mapper.db.upsert(
 | 
				
			||||||
 | 
					                "user_property",
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "user_uid": user_uid,
 | 
					                "user_uid": user_uid,
 | 
				
			||||||
                "name": name,
 | 
					                "name": name,
 | 
				
			||||||
 | 
				
			|||||||
@ -65,7 +65,7 @@ export class Chat extends EventHandler {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    return new Promise((resolve) => {
 | 
					    return new Promise((resolve) => {
 | 
				
			||||||
      this._waitConnect = resolve;
 | 
					      this._waitConnect = resolve;
 | 
				
			||||||
      //console.debug("Connecting..");
 | 
					      console.debug("Connecting..");
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
      try {
 | 
					      try {
 | 
				
			||||||
        this._socket = new WebSocket(this._url);
 | 
					        this._socket = new WebSocket(this._url);
 | 
				
			||||||
@ -142,7 +142,7 @@ export class NotificationAudio {
 | 
				
			|||||||
      new Audio(this.sounds[soundIndex])
 | 
					      new Audio(this.sounds[soundIndex])
 | 
				
			||||||
        .play()
 | 
					        .play()
 | 
				
			||||||
        .then(() => {
 | 
					        .then(() => {
 | 
				
			||||||
          //console.debug("Gave sound notification");
 | 
					          console.debug("Gave sound notification");
 | 
				
			||||||
        })
 | 
					        })
 | 
				
			||||||
        .catch((error) => {
 | 
					        .catch((error) => {
 | 
				
			||||||
          console.error("Notification failed:", error);
 | 
					          console.error("Notification failed:", error);
 | 
				
			||||||
 | 
				
			|||||||
@ -407,48 +407,34 @@ a {
 | 
				
			|||||||
  width: 250px;
 | 
					  width: 250px;
 | 
				
			||||||
  padding-left: 20px;
 | 
					  padding-left: 20px;
 | 
				
			||||||
  padding-right: 20px;
 | 
					  padding-right: 20px;
 | 
				
			||||||
  padding-top: 20px;
 | 
					  padding-top: 10px;
 | 
				
			||||||
  overflow-y: auto;
 | 
					  overflow-y: auto;
 | 
				
			||||||
  grid-area: sidebar;
 | 
					  grid-area: sidebar;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
.sidebar h2 {
 | 
					.sidebar h2 {
 | 
				
			||||||
  color: #f05a28;
 | 
					  color: #f05a28;
 | 
				
			||||||
  font-size: 0.75em;
 | 
					  font-size: 1.2em;
 | 
				
			||||||
  font-weight: 600;
 | 
					  margin-bottom: 20px;
 | 
				
			||||||
  text-transform: uppercase;
 | 
					 | 
				
			||||||
  letter-spacing: 0.5px;
 | 
					 | 
				
			||||||
  margin-bottom: 10px;
 | 
					 | 
				
			||||||
  margin-top: 20px;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
.sidebar h2:first-child {
 | 
					 | 
				
			||||||
  margin-top: 0;
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
.sidebar ul {
 | 
					.sidebar ul {
 | 
				
			||||||
  list-style: none;
 | 
					  list-style: none;
 | 
				
			||||||
  margin-bottom: 15px;
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
.sidebar ul li {
 | 
					.sidebar ul li {
 | 
				
			||||||
  margin-bottom: 4px;
 | 
					  margin-bottom: 15px;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
.sidebar ul li a {
 | 
					.sidebar ul li a {
 | 
				
			||||||
  color: #ccc;
 | 
					  color: #ccc;
 | 
				
			||||||
  text-decoration: none;
 | 
					  text-decoration: none;
 | 
				
			||||||
  font-size: 0.9em;
 | 
					  font-size: 1em;
 | 
				
			||||||
  transition: color 0.3s;
 | 
					  transition: color 0.3s;
 | 
				
			||||||
  display: block;
 | 
					 | 
				
			||||||
  padding: 4px 8px;
 | 
					 | 
				
			||||||
  border-radius: 4px;
 | 
					 | 
				
			||||||
  transition: background-color 0.2s, color 0.2s;
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
.sidebar ul li a:hover {
 | 
					.sidebar ul li a:hover {
 | 
				
			||||||
  color: #fff;
 | 
					  color: #fff;
 | 
				
			||||||
  background-color: rgba(255, 255, 255, 0.05);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@keyframes glow {
 | 
					@keyframes glow {
 | 
				
			||||||
 | 
				
			|||||||
@ -81,152 +81,7 @@ class ChatInputComponent extends NjetComponent {
 | 
				
			|||||||
    return Array.from(text.matchAll(/@([a-zA-Z0-9_-]+)/g), m => m[1]);
 | 
					    return Array.from(text.matchAll(/@([a-zA-Z0-9_-]+)/g), m => m[1]);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
  matchMentionsToAuthors(mentions, authors) {
 | 
					  matchMentionsToAuthors(mentions, authors) {
 | 
				
			||||||
    return mentions.map((mention) => {
 | 
					 | 
				
			||||||
      const lowerMention = mention.toLowerCase();
 | 
					 | 
				
			||||||
      let bestMatch = null;
 | 
					 | 
				
			||||||
      let bestScore = 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      for (const author of authors) {
 | 
					 | 
				
			||||||
        const lowerAuthor = author.toLowerCase();
 | 
					 | 
				
			||||||
        let score = 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (lowerMention === lowerAuthor) {
 | 
					 | 
				
			||||||
          score = 100;
 | 
					 | 
				
			||||||
        } else if (lowerAuthor.startsWith(lowerMention)) {
 | 
					 | 
				
			||||||
          score = 90 + (5 * (lowerMention.length / lowerAuthor.length));
 | 
					 | 
				
			||||||
        } else if (lowerAuthor.includes(lowerMention) && lowerMention.length >= 2) {
 | 
					 | 
				
			||||||
          const position = lowerAuthor.indexOf(lowerMention);
 | 
					 | 
				
			||||||
          score = 80 - (10 * (position / lowerAuthor.length));
 | 
					 | 
				
			||||||
        } else if (this.isFuzzyMatch(lowerMention, lowerAuthor)) {
 | 
					 | 
				
			||||||
          const ratio = lowerMention.length / lowerAuthor.length;
 | 
					 | 
				
			||||||
          score = 40 + (20 * ratio);
 | 
					 | 
				
			||||||
        } else if (this.isCloseMatch(lowerMention, lowerAuthor)) {
 | 
					 | 
				
			||||||
          score = 30 + (10 * (lowerMention.length / lowerAuthor.length));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (score > bestScore) {
 | 
					 | 
				
			||||||
          bestScore = score;
 | 
					 | 
				
			||||||
          bestMatch = author;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      const minScore = 40;
 | 
					 | 
				
			||||||
      return {
 | 
					 | 
				
			||||||
        mention,
 | 
					 | 
				
			||||||
        closestAuthor: bestScore >= minScore ? bestMatch : null,
 | 
					 | 
				
			||||||
        score: bestScore,
 | 
					 | 
				
			||||||
      };
 | 
					 | 
				
			||||||
    });
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  isFuzzyMatch(needle, haystack) {
 | 
					 | 
				
			||||||
    if (needle.length < 2) return false;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    let needleIndex = 0;
 | 
					 | 
				
			||||||
    for (let i = 0; i < haystack.length && needleIndex < needle.length; i++) {
 | 
					 | 
				
			||||||
      if (haystack[i] === needle[needleIndex]) {
 | 
					 | 
				
			||||||
        needleIndex++;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return needleIndex === needle.length;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  isCloseMatch(str1, str2) {
 | 
					 | 
				
			||||||
    if (Math.abs(str1.length - str2.length) > 2) return false;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const shorter = str1.length <= str2.length ? str1 : str2;
 | 
					 | 
				
			||||||
    const longer = str1.length > str2.length ? str1 : str2;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    let differences = 0;
 | 
					 | 
				
			||||||
    let j = 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (let i = 0; i < shorter.length && j < longer.length; i++) {
 | 
					 | 
				
			||||||
      if (shorter[i] !== longer[j]) {
 | 
					 | 
				
			||||||
        differences++;
 | 
					 | 
				
			||||||
        if (j + 1 < longer.length && shorter[i] === longer[j + 1]) {
 | 
					 | 
				
			||||||
          j++;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      j++;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    differences += Math.abs(longer.length - j);
 | 
					 | 
				
			||||||
    return differences <= 2;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  matchMentions4ToAuthors(mentions, authors) {
 | 
					 | 
				
			||||||
    return mentions.map((mention) => {
 | 
					 | 
				
			||||||
      let closestAuthor = null;
 | 
					 | 
				
			||||||
      let minDistance = Infinity;
 | 
					 | 
				
			||||||
      const lowerMention = mention.toLowerCase();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      authors.forEach((author) => {
 | 
					 | 
				
			||||||
        const lowerAuthor = author.toLowerCase();
 | 
					 | 
				
			||||||
        let distance = this.levenshteinDistance(lowerMention, lowerAuthor);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (!this.isSubsequence(lowerMention, lowerAuthor)) {
 | 
					 | 
				
			||||||
          distance += 10;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (distance < minDistance) {
 | 
					 | 
				
			||||||
          minDistance = distance;
 | 
					 | 
				
			||||||
          closestAuthor = author;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      if (minDistance < 5) {
 | 
					 | 
				
			||||||
        closestAuthor = 0;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
      return { mention, closestAuthor, distance: minDistance };
 | 
					 | 
				
			||||||
    });
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  levenshteinDistance(a, b) {
 | 
					 | 
				
			||||||
    const matrix = [];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (let i = 0; i <= b.length; i++) {
 | 
					 | 
				
			||||||
      matrix[i] = [i];
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    for (let j = 0; j <= a.length; j++) {
 | 
					 | 
				
			||||||
      matrix[0][j] = j;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (let i = 1; i <= b.length; i++) {
 | 
					 | 
				
			||||||
      for (let j = 1; j <= a.length; j++) {
 | 
					 | 
				
			||||||
        if (b.charAt(i - 1) === a.charAt(j - 1)) {
 | 
					 | 
				
			||||||
          matrix[i][j] = matrix[i - 1][j - 1];
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
          matrix[i][j] = Math.min(
 | 
					 | 
				
			||||||
            matrix[i - 1][j] + 1,
 | 
					 | 
				
			||||||
            matrix[i][j - 1] + 1,
 | 
					 | 
				
			||||||
            matrix[i - 1][j - 1] + 1
 | 
					 | 
				
			||||||
          );
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return matrix[b.length][a.length];
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  replaceMentionsWithAuthors(text) {
 | 
					 | 
				
			||||||
    const authors = this.getAuthors();
 | 
					 | 
				
			||||||
    const mentions = this.extractMentions(text);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const matches = this.matchMentionsToAuthors(mentions, authors);
 | 
					 | 
				
			||||||
    let updatedText = text;
 | 
					 | 
				
			||||||
    matches.forEach(({ mention, closestAuthor }) => {
 | 
					 | 
				
			||||||
      if(closestAuthor){
 | 
					 | 
				
			||||||
        const mentionRegex = new RegExp(`@${mention}`, 'g');
 | 
					 | 
				
			||||||
        updatedText = updatedText.replace(mentionRegex, `@${closestAuthor}`);
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return updatedText;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  matchMentions2ToAuthors(mentions, authors) {
 | 
					 | 
				
			||||||
    return mentions.map(mention => {
 | 
					    return mentions.map(mention => {
 | 
				
			||||||
      let closestAuthor = null;
 | 
					      let closestAuthor = null;
 | 
				
			||||||
      let minDistance = Infinity;
 | 
					      let minDistance = Infinity;
 | 
				
			||||||
@ -245,14 +100,53 @@ class ChatInputComponent extends NjetComponent {
 | 
				
			|||||||
          closestAuthor = author;
 | 
					          closestAuthor = author;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
      });
 | 
					      });
 | 
				
			||||||
    if (minDistance < 5){
 | 
					
 | 
				
			||||||
        closestAuthor = 0;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
      return { mention, closestAuthor, distance: minDistance };
 | 
					      return { mention, closestAuthor, distance: minDistance };
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  levenshteinDistance(a, b) {
 | 
				
			||||||
 | 
					    const matrix = [];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Initialize the first row and column
 | 
				
			||||||
 | 
					    for (let i = 0; i <= b.length; i++) {
 | 
				
			||||||
 | 
					      matrix[i] = [i];
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    for (let j = 0; j <= a.length; j++) {
 | 
				
			||||||
 | 
					      matrix[0][j] = j;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Fill in the matrix
 | 
				
			||||||
 | 
					    for (let i = 1; i <= b.length; i++) {
 | 
				
			||||||
 | 
					      for (let j = 1; j <= a.length; j++) {
 | 
				
			||||||
 | 
					        if (b.charAt(i - 1) === a.charAt(j - 1)) {
 | 
				
			||||||
 | 
					          matrix[i][j] = matrix[i - 1][j - 1];
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					          matrix[i][j] = Math.min(
 | 
				
			||||||
 | 
					            matrix[i - 1][j] + 1,    // Deletion
 | 
				
			||||||
 | 
					            matrix[i][j - 1] + 1,    // Insertion
 | 
				
			||||||
 | 
					            matrix[i - 1][j - 1] + 1 // Substitution
 | 
				
			||||||
 | 
					          );
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return matrix[b.length][a.length];
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  replaceMentionsWithAuthors(text) {
 | 
				
			||||||
 | 
					    const authors = this.getAuthors();
 | 
				
			||||||
 | 
					    const mentions = this.extractMentions(text);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const matches = this.matchMentionsToAuthors(mentions, authors);
 | 
				
			||||||
 | 
					    let updatedText = text;
 | 
				
			||||||
 | 
					    matches.forEach(({ mention, closestAuthor }) => {
 | 
				
			||||||
 | 
					      const mentionRegex = new RegExp(`@${mention}`, 'g');
 | 
				
			||||||
 | 
					      updatedText = updatedText.replace(mentionRegex, `@${closestAuthor}`);
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return updatedText;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  textToLeet(text) {
 | 
					  textToLeet(text) {
 | 
				
			||||||
  // L33t speak character mapping
 | 
					  // L33t speak character mapping
 | 
				
			||||||
  const leetMap = {
 | 
					  const leetMap = {
 | 
				
			||||||
@ -498,7 +392,7 @@ textToLeetAdvanced(text) {
 | 
				
			|||||||
      }
 | 
					      }
 | 
				
			||||||
      j++;
 | 
					      j++;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    return i === s.length && s.length > 1;
 | 
					    return i === s.length;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  flagTyping() {
 | 
					  flagTyping() {
 | 
				
			||||||
 | 
				
			|||||||
@ -34,12 +34,6 @@ export class ReplyEvent extends Event {
 | 
				
			|||||||
      img.replaceWith(document.createTextNode(src));
 | 
					      img.replaceWith(document.createTextNode(src));
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
     // Replace <video> with just their src
 | 
					 | 
				
			||||||
    newMessage.querySelectorAll('video').forEach(vid => {
 | 
					 | 
				
			||||||
      const src = vid.src || vid.currentSrc || vid.querySelector('source').src;
 | 
					 | 
				
			||||||
      vid.replaceWith(document.createTextNode(src));
 | 
					 | 
				
			||||||
    });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    // Replace <iframe> with their src
 | 
					    // Replace <iframe> with their src
 | 
				
			||||||
    newMessage.querySelectorAll('iframe').forEach(iframe => {
 | 
					    newMessage.querySelectorAll('iframe').forEach(iframe => {
 | 
				
			||||||
      const src = iframe.src || iframe.currentSrc;
 | 
					      const src = iframe.src || iframe.currentSrc;
 | 
				
			||||||
@ -262,9 +256,7 @@ class MessageList extends HTMLElement {
 | 
				
			|||||||
    this.querySelectorAll('.avatar').forEach((el) => {
 | 
					    this.querySelectorAll('.avatar').forEach((el) => {
 | 
				
			||||||
      const anchor = el.closest('a');
 | 
					      const anchor = el.closest('a');
 | 
				
			||||||
      if (anchor && typeof anchor.href === 'string' && anchor.href.includes(uid)) {
 | 
					      if (anchor && typeof anchor.href === 'string' && anchor.href.includes(uid)) {
 | 
				
			||||||
        if(!lastElement)
 | 
					        lastElement = el;
 | 
				
			||||||
          lastElement = el;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
    if (lastElement) {
 | 
					    if (lastElement) {
 | 
				
			||||||
@ -286,14 +278,10 @@ class MessageList extends HTMLElement {
 | 
				
			|||||||
  upsertMessage(data) {
 | 
					  upsertMessage(data) {
 | 
				
			||||||
    let message = this.messageMap.get(data.uid);
 | 
					    let message = this.messageMap.get(data.uid);
 | 
				
			||||||
    if (message && (data.is_final || !data.message)) {
 | 
					    if (message && (data.is_final || !data.message)) {
 | 
				
			||||||
      //message.parentElement?.removeChild(message);
 | 
					 | 
				
			||||||
      // TO force insert 
 | 
					 | 
				
			||||||
      //message = null;
 | 
					 | 
				
			||||||
      
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    if(message && !data.message){
 | 
					 | 
				
			||||||
      message.parentElement?.removeChild(message);
 | 
					      message.parentElement?.removeChild(message);
 | 
				
			||||||
 | 
					      // TO force insert 
 | 
				
			||||||
      message = null;
 | 
					      message = null;
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    if (!data.message) return;
 | 
					    if (!data.message) return;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -52,7 +52,7 @@ self.addEventListener("push", async (event) => {
 | 
				
			|||||||
        data,
 | 
					        data,
 | 
				
			||||||
    }).then(e => console.log("Showing notification", e)).catch(console.error);
 | 
					    }).then(e => console.log("Showing notification", e)).catch(console.error);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//    event.waitUntil(reg);
 | 
					    event.waitUntil(reg);
 | 
				
			||||||
});
 | 
					});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -86,7 +86,6 @@ export class Socket extends EventHandler {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    this.emit("data", data.data);
 | 
					    this.emit("data", data.data);
 | 
				
			||||||
    if (data["event"]) {
 | 
					    if (data["event"]) {
 | 
				
			||||||
      console.info([data.event,data.data])
 | 
					 | 
				
			||||||
      this.emit(data.event, data.data);
 | 
					      this.emit(data.event, data.data);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
@ -100,7 +99,7 @@ export class Socket extends EventHandler {
 | 
				
			|||||||
        console.log("Reconnecting");
 | 
					        console.log("Reconnecting");
 | 
				
			||||||
        this.emit("reconnecting");
 | 
					        this.emit("reconnecting");
 | 
				
			||||||
          return this.connect();
 | 
					          return this.connect();
 | 
				
			||||||
      }, 4000);
 | 
					      }, 0);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  _camelToSnake(str) {
 | 
					  _camelToSnake(str) {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										1212
									
								
								src/snek/system/ads.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1212
									
								
								src/snek/system/ads.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -1,137 +1,75 @@
 | 
				
			|||||||
import asyncio
 | 
					 | 
				
			||||||
import functools
 | 
					import functools
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
from collections import OrderedDict
 | 
					
 | 
				
			||||||
from snek.system import security
 | 
					from snek.system import security
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cache = functools.cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
CACHE_MAX_ITEMS_DEFAULT = 5000
 | 
					CACHE_MAX_ITEMS_DEFAULT = 5000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Cache:
 | 
					class Cache:
 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    An asynchronous, thread-safe, in-memory LRU (Least Recently Used) cache.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    This implementation uses an OrderedDict for efficient O(1) time complexity
 | 
					 | 
				
			||||||
    for its core get, set, and delete operations.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, app, max_items=CACHE_MAX_ITEMS_DEFAULT):
 | 
					    def __init__(self, app, max_items=CACHE_MAX_ITEMS_DEFAULT):
 | 
				
			||||||
        self.app = app
 | 
					        self.app = app
 | 
				
			||||||
        # OrderedDict is the core of the LRU logic. It remembers the order
 | 
					        self.cache = {}
 | 
				
			||||||
        # in which items were inserted.
 | 
					 | 
				
			||||||
        self.cache: OrderedDict = OrderedDict()
 | 
					 | 
				
			||||||
        self.max_items = max_items
 | 
					        self.max_items = max_items
 | 
				
			||||||
        self.stats = {}
 | 
					        self.stats = {}
 | 
				
			||||||
        self.enabled = True
 | 
					        self.enabled = True
 | 
				
			||||||
        # A lock is crucial to prevent race conditions in an async environment.
 | 
					        self.lru = []
 | 
				
			||||||
        self._lock = asyncio.Lock()
 | 
					        self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4
 | 
				
			||||||
        self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4 
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get(self, key):
 | 
					    async def get(self, args):
 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Retrieves an item from the cache. If found, it's marked as recently used.
 | 
					 | 
				
			||||||
        Returns None if the item is not found or the cache is disabled.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if not self.enabled:
 | 
					        if not self.enabled:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					        await self.update_stat(args, "get")
 | 
				
			||||||
        #async with self._lock:
 | 
					        try:
 | 
				
			||||||
        if key not in self.cache:
 | 
					            self.lru.pop(self.lru.index(args))
 | 
				
			||||||
            await self.update_stat(key, "get")
 | 
					        except:
 | 
				
			||||||
 | 
					            # print("Cache miss!", args, flush=True)
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					        self.lru.insert(0, args)
 | 
				
			||||||
        # Mark as recently used by moving it to the end of the OrderedDict.
 | 
					        while len(self.lru) > self.max_items:
 | 
				
			||||||
        # This is an O(1) operation.
 | 
					            self.cache.pop(self.lru[-1])
 | 
				
			||||||
        self.cache.move_to_end(key)
 | 
					            self.lru.pop()
 | 
				
			||||||
        await self.update_stat(key, "get")
 | 
					        # print("Cache hit!", args, flush=True)
 | 
				
			||||||
        return self.cache[key]
 | 
					        return self.cache[args]
 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def set(self, key, value):
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Adds or updates an item in the cache and marks it as recently used.
 | 
					 | 
				
			||||||
        If the cache exceeds its maximum size, the least recently used item is evicted.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if not self.enabled:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        # comment
 | 
					 | 
				
			||||||
        #async with self._lock:
 | 
					 | 
				
			||||||
        is_new = key not in self.cache
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Add or update the item. If it exists, it's moved to the end.
 | 
					 | 
				
			||||||
        self.cache[key] = value
 | 
					 | 
				
			||||||
        self.cache.move_to_end(key)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        await self.update_stat(key, "set")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Evict the least recently used item if the cache is full.
 | 
					 | 
				
			||||||
        # This is an O(1) operation.
 | 
					 | 
				
			||||||
        if len(self.cache) > self.max_items:
 | 
					 | 
				
			||||||
            # popitem(last=False) removes and returns the first (oldest) item.
 | 
					 | 
				
			||||||
            evicted_key, _ = self.cache.popitem(last=False)
 | 
					 | 
				
			||||||
            # Optionally, you could log the evicted key here.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if is_new:
 | 
					 | 
				
			||||||
            self.version += 1
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def delete(self, key):
 | 
					 | 
				
			||||||
        """Removes an item from the cache if it exists."""
 | 
					 | 
				
			||||||
        if not self.enabled:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async with self._lock:
 | 
					 | 
				
			||||||
            if key in self.cache:
 | 
					 | 
				
			||||||
                await self.update_stat(key, "delete")
 | 
					 | 
				
			||||||
                # Deleting from OrderedDict is an O(1) operation on average.
 | 
					 | 
				
			||||||
                del self.cache[key]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get_stats(self):
 | 
					    async def get_stats(self):
 | 
				
			||||||
        """Returns statistics for all items currently in the cache."""
 | 
					        all_ = []
 | 
				
			||||||
        async with self._lock:
 | 
					        for key in self.lru:
 | 
				
			||||||
            stats_list = []
 | 
					            all_.append(
 | 
				
			||||||
            # Items are iterated from oldest to newest. We reverse to show
 | 
					                {
 | 
				
			||||||
            # most recently used items first.
 | 
					 | 
				
			||||||
            for key in reversed(self.cache):
 | 
					 | 
				
			||||||
                stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0})
 | 
					 | 
				
			||||||
                value = self.cache[key]
 | 
					 | 
				
			||||||
                value_record = value.record if hasattr(value, 'record') else value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                stats_list.append({
 | 
					 | 
				
			||||||
                    "key": key,
 | 
					                    "key": key,
 | 
				
			||||||
                    "set": stat_data.get("set", 0),
 | 
					                    "set": self.stats[key]["set"],
 | 
				
			||||||
                    "get": stat_data.get("get", 0),
 | 
					                    "get": self.stats[key]["get"],
 | 
				
			||||||
                    "delete": stat_data.get("delete", 0),
 | 
					                    "delete": self.stats[key]["delete"],
 | 
				
			||||||
                    "value": str(self.serialize(value_record)),
 | 
					                    "value": str(self.serialize(self.cache[key].record)),
 | 
				
			||||||
                })
 | 
					                }
 | 
				
			||||||
            return stats_list
 | 
					            )
 | 
				
			||||||
 | 
					        return all_
 | 
				
			||||||
    async def update_stat(self, key, action):
 | 
					 | 
				
			||||||
        """Updates hit/miss/set counts for a given cache key."""
 | 
					 | 
				
			||||||
        # This method is already called within a locked context,
 | 
					 | 
				
			||||||
        # but the lock makes it safe if ever called directly.
 | 
					 | 
				
			||||||
        async with self._lock:
 | 
					 | 
				
			||||||
            if key not in self.stats:
 | 
					 | 
				
			||||||
                self.stats[key] = {"set": 0, "get": 0, "delete": 0}
 | 
					 | 
				
			||||||
            self.stats[key][action] += 1
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def serialize(self, obj):
 | 
					    def serialize(self, obj):
 | 
				
			||||||
        """A synchronous helper to create a serializable representation of an object."""
 | 
					 | 
				
			||||||
        if not isinstance(obj, dict):
 | 
					 | 
				
			||||||
            return obj
 | 
					 | 
				
			||||||
        cpy = obj.copy()
 | 
					        cpy = obj.copy()
 | 
				
			||||||
        for key_to_remove in ["created_at", "deleted_at", "email", "password"]:
 | 
					        cpy.pop("created_at", None)
 | 
				
			||||||
            cpy.pop(key_to_remove, None)
 | 
					        cpy.pop("deleted_at", None)
 | 
				
			||||||
 | 
					        cpy.pop("email", None)
 | 
				
			||||||
 | 
					        cpy.pop("password", None)
 | 
				
			||||||
        return cpy
 | 
					        return cpy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def update_stat(self, key, action):
 | 
				
			||||||
 | 
					        if key not in self.stats:
 | 
				
			||||||
 | 
					            self.stats[key] = {"set": 0, "get": 0, "delete": 0}
 | 
				
			||||||
 | 
					        self.stats[key][action] = self.stats[key][action] + 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def json_default(self, value):
 | 
					    def json_default(self, value):
 | 
				
			||||||
        """JSON serializer fallback for objects that are not directly serializable."""
 | 
					        # if hasattr(value, "to_json"):
 | 
				
			||||||
 | 
					        #    return value.to_json()
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            return json.dumps(value.__dict__, default=str)
 | 
					            return json.dumps(value.__dict__, default=str)
 | 
				
			||||||
        except:
 | 
					        except:
 | 
				
			||||||
            return str(value)
 | 
					            return str(value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def create_cache_key(self, args, kwargs):
 | 
					    async def create_cache_key(self, args, kwargs):
 | 
				
			||||||
        """Creates a consistent, hashable cache key from function arguments."""
 | 
					 | 
				
			||||||
        # security.hash is async, so this method remains async.
 | 
					 | 
				
			||||||
        return await security.hash(
 | 
					        return await security.hash(
 | 
				
			||||||
            json.dumps(
 | 
					            json.dumps(
 | 
				
			||||||
                {"args": args, "kwargs": kwargs},
 | 
					                {"args": args, "kwargs": kwargs},
 | 
				
			||||||
@ -140,8 +78,38 @@ class Cache:
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def set(self, args, result):
 | 
				
			||||||
 | 
					        if not self.enabled:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        is_new = args not in self.cache
 | 
				
			||||||
 | 
					        self.cache[args] = result
 | 
				
			||||||
 | 
					        await self.update_stat(args, "set")
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            self.lru.pop(self.lru.index(args))
 | 
				
			||||||
 | 
					        except (ValueError, IndexError):
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        self.lru.insert(0, args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        while len(self.lru) > self.max_items:
 | 
				
			||||||
 | 
					            self.cache.pop(self.lru[-1])
 | 
				
			||||||
 | 
					            self.lru.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if is_new:
 | 
				
			||||||
 | 
					            self.version += 1
 | 
				
			||||||
 | 
					            # print(f"Cache store! {len(self.lru)} items. New version:", self.version, flush=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def delete(self, args):
 | 
				
			||||||
 | 
					        if not self.enabled:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        await self.update_stat(args, "delete")
 | 
				
			||||||
 | 
					        if args in self.cache:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                self.lru.pop(self.lru.index(args))
 | 
				
			||||||
 | 
					            except IndexError:
 | 
				
			||||||
 | 
					                pass
 | 
				
			||||||
 | 
					            del self.cache[args]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def async_cache(self, func):
 | 
					    def async_cache(self, func):
 | 
				
			||||||
        """Decorator to cache the results of an async function."""
 | 
					 | 
				
			||||||
        @functools.wraps(func)
 | 
					        @functools.wraps(func)
 | 
				
			||||||
        async def wrapper(*args, **kwargs):
 | 
					        async def wrapper(*args, **kwargs):
 | 
				
			||||||
            cache_key = await self.create_cache_key(args, kwargs)
 | 
					            cache_key = await self.create_cache_key(args, kwargs)
 | 
				
			||||||
@ -151,14 +119,33 @@ class Cache:
 | 
				
			|||||||
            result = await func(*args, **kwargs)
 | 
					            result = await func(*args, **kwargs)
 | 
				
			||||||
            await self.set(cache_key, result)
 | 
					            await self.set(cache_key, result)
 | 
				
			||||||
            return result
 | 
					            return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return wrapper
 | 
					        return wrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def async_delete_cache(self, func):
 | 
					    def async_delete_cache(self, func):
 | 
				
			||||||
        """Decorator to invalidate a cache entry before running an async function."""
 | 
					 | 
				
			||||||
        @functools.wraps(func)
 | 
					        @functools.wraps(func)
 | 
				
			||||||
        async def wrapper(*args, **kwargs):
 | 
					        async def wrapper(*args, **kwargs):
 | 
				
			||||||
            cache_key = await self.create_cache_key(args, kwargs)
 | 
					            cache_key = await self.create_cache_key(args, kwargs)
 | 
				
			||||||
            await self.delete(cache_key)
 | 
					            if cache_key in self.cache:
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    self.lru.pop(self.lru.index(cache_key))
 | 
				
			||||||
 | 
					                except IndexError:
 | 
				
			||||||
 | 
					                    pass
 | 
				
			||||||
 | 
					                del self.cache[cache_key]
 | 
				
			||||||
            return await func(*args, **kwargs)
 | 
					            return await func(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return wrapper
 | 
					        return wrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def async_cache(func):
 | 
				
			||||||
 | 
					    cache = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @functools.wraps(func)
 | 
				
			||||||
 | 
					    async def wrapper(*args):
 | 
				
			||||||
 | 
					        if args in cache:
 | 
				
			||||||
 | 
					            return cache[args]
 | 
				
			||||||
 | 
					        result = await func(*args)
 | 
				
			||||||
 | 
					        cache[args] = result
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return wrapper
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,7 @@
 | 
				
			|||||||
DEFAULT_LIMIT = 30
 | 
					DEFAULT_LIMIT = 30
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import typing
 | 
					import typing
 | 
				
			||||||
import time
 | 
					import traceback
 | 
				
			||||||
from snek.system.model import BaseModel
 | 
					from snek.system.model import BaseModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -15,6 +15,8 @@ class BaseMapper:
 | 
				
			|||||||
    def __init__(self, app):
 | 
					    def __init__(self, app):
 | 
				
			||||||
        self.app = app
 | 
					        self.app = app
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.default_limit = self.__class__.default_limit
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def db(self):
 | 
					    def db(self):
 | 
				
			||||||
        return self.app.db
 | 
					        return self.app.db
 | 
				
			||||||
@ -25,7 +27,6 @@ class BaseMapper:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    async def run_in_executor(self, func, *args, **kwargs):
 | 
					    async def run_in_executor(self, func, *args, **kwargs):
 | 
				
			||||||
        use_semaphore = kwargs.pop("use_semaphore", False)
 | 
					        use_semaphore = kwargs.pop("use_semaphore", False)
 | 
				
			||||||
        start_time = time.time()
 | 
					 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        def _execute(): 
 | 
					        def _execute(): 
 | 
				
			||||||
            result = func(*args, **kwargs)
 | 
					            result = func(*args, **kwargs)
 | 
				
			||||||
@ -33,8 +34,10 @@ class BaseMapper:
 | 
				
			|||||||
                self.db.commit()
 | 
					                self.db.commit()
 | 
				
			||||||
            return result 
 | 
					            return result 
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        async with self.semaphore:
 | 
					        return _execute()
 | 
				
			||||||
            return await asyncio.to_thread(_execute)
 | 
					        #async with self.semaphore:
 | 
				
			||||||
 | 
					        #    return await self.loop.run_in_executor(None, _execute)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def new(self):
 | 
					    async def new(self):
 | 
				
			||||||
        return self.model_class(mapper=self, app=self.app)
 | 
					        return self.model_class(mapper=self, app=self.app)
 | 
				
			||||||
@ -48,7 +51,9 @@ class BaseMapper:
 | 
				
			|||||||
            kwargs["uid"] = uid
 | 
					            kwargs["uid"] = uid
 | 
				
			||||||
        if not kwargs.get("deleted_at"):
 | 
					        if not kwargs.get("deleted_at"):
 | 
				
			||||||
            kwargs["deleted_at"] = None
 | 
					            kwargs["deleted_at"] = None
 | 
				
			||||||
        record = await self.run_in_executor(self.table.find_one, **kwargs)
 | 
					        #traceback.print_exc()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        record = await self.db.get(self.table_name, kwargs)
 | 
				
			||||||
        if not record:
 | 
					        if not record:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
        record = dict(record)
 | 
					        record = dict(record)
 | 
				
			||||||
@ -58,23 +63,29 @@ class BaseMapper:
 | 
				
			|||||||
        return model
 | 
					        return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def exists(self, **kwargs):
 | 
					    async def exists(self, **kwargs):
 | 
				
			||||||
        return await self.run_in_executor(self.table.exists, **kwargs)
 | 
					        return await self.db.count(self.table_name, kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        #return await self.run_in_executor(self.table.exists, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def count(self, **kwargs) -> int:
 | 
					    async def count(self, **kwargs) -> int:
 | 
				
			||||||
        return await self.run_in_executor(self.table.count, **kwargs)
 | 
					        return await self.db.count(self.table_name,kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def save(self, model: BaseModel) -> bool:
 | 
					    async def save(self, model: BaseModel) -> bool:
 | 
				
			||||||
        if not model.record.get("uid"):
 | 
					        if not model.record.get("uid"):
 | 
				
			||||||
            raise Exception(f"Attempt to save without uid: {model.record}.")
 | 
					            raise Exception(f"Attempt to save without uid: {model.record}.")
 | 
				
			||||||
        model.updated_at.update()
 | 
					        model.updated_at.update() 
 | 
				
			||||||
        return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
 | 
					        await self.upsert(model)
 | 
				
			||||||
 | 
					        return model
 | 
				
			||||||
 | 
					        #return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def find(self, **kwargs) -> typing.AsyncGenerator:
 | 
					    async def find(self, **kwargs) -> typing.AsyncGenerator:
 | 
				
			||||||
        if not kwargs.get("_limit"):
 | 
					        if not kwargs.get("_limit"):
 | 
				
			||||||
            kwargs["_limit"] = self.default_limit
 | 
					            kwargs["_limit"] = self.default_limit
 | 
				
			||||||
        if not kwargs.get("deleted_at"):
 | 
					        if not kwargs.get("deleted_at"):
 | 
				
			||||||
            kwargs["deleted_at"] = None
 | 
					            kwargs["deleted_at"] = None
 | 
				
			||||||
        for record in await self.run_in_executor(self.table.find, **kwargs):
 | 
					        for record in await self.db.find(self.table_name, kwargs):
 | 
				
			||||||
            model = await self.new()
 | 
					            model = await self.new()
 | 
				
			||||||
            for key, value in record.items():
 | 
					            for key, value in record.items():
 | 
				
			||||||
                model[key] = value
 | 
					                model[key] = value
 | 
				
			||||||
@ -85,21 +96,21 @@ class BaseMapper:
 | 
				
			|||||||
        return "insert" in sql or "update" in sql or "delete" in sql
 | 
					        return "insert" in sql or "update" in sql or "delete" in sql
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def query(self, sql, *args):
 | 
					    async def query(self, sql, *args):
 | 
				
			||||||
        for record in await self.run_in_executor(self.db.query, sql, *args, use_semaphore=await self._use_semaphore(sql)):
 | 
					        for record in await self.db.query(sql, *args):
 | 
				
			||||||
            yield dict(record)
 | 
					            yield dict(record)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def update(self, model):
 | 
					    async def update(self, model):
 | 
				
			||||||
        if not model["deleted_at"] is None:
 | 
					        if not model["deleted_at"] is None:
 | 
				
			||||||
            raise Exception("Can't update deleted record.")
 | 
					            raise Exception("Can't update deleted record.")
 | 
				
			||||||
        model.updated_at.update()
 | 
					        model.updated_at.update()
 | 
				
			||||||
        return await self.run_in_executor(self.table.update, model.record, ["uid"],use_semaphore=True)
 | 
					        return await self.db.update(self.table_name, model.record, {"uid": model["uid"]})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def upsert(self, model):
 | 
					    async def upsert(self, model):
 | 
				
			||||||
        model.updated_at.update()
 | 
					        model.updated_at.update()
 | 
				
			||||||
        return await self.run_in_executor(self.table.upsert, model.record, ["uid"],use_semaphore=True)
 | 
					        await self.db.upsert(self.table_name, model.record, {"uid": model["uid"]})
 | 
				
			||||||
 | 
					        return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def delete(self, **kwargs) -> int:
 | 
					    async def delete(self, **kwargs) -> int:
 | 
				
			||||||
        if not kwargs or not isinstance(kwargs, dict):
 | 
					        if not kwargs or not isinstance(kwargs, dict):
 | 
				
			||||||
            raise Exception("Can't execute delete with no filter.")
 | 
					            raise Exception("Can't execute delete with no filter.")
 | 
				
			||||||
        kwargs["use_semaphore"] = True
 | 
					        return await self.db.delete(self.table_name, kwargs)
 | 
				
			||||||
        return await self.run_in_executor(self.table.delete, **kwargs)
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -53,7 +53,7 @@ async def auth_middleware(request, handler):
 | 
				
			|||||||
    request["user"] = None
 | 
					    request["user"] = None
 | 
				
			||||||
    if request.session.get("uid") and request.session.get("logged_in"):
 | 
					    if request.session.get("uid") and request.session.get("logged_in"):
 | 
				
			||||||
        request["user"] = await request.app.services.user.get(
 | 
					        request["user"] = await request.app.services.user.get(
 | 
				
			||||||
            uid=request.app.session.get("uid")
 | 
					            uid=request.session.get("uid")
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    return await handler(request)
 | 
					    return await handler(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -69,5 +69,5 @@ async def cors_middleware(request, handler):
 | 
				
			|||||||
    response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
 | 
					    response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
 | 
				
			||||||
    response.headers["Access-Control-Allow-Headers"] = "*"
 | 
					    response.headers["Access-Control-Allow-Headers"] = "*"
 | 
				
			||||||
    response.headers["Access-Control-Allow-Credentials"] = "true"
 | 
					    response.headers["Access-Control-Allow-Credentials"] = "true"
 | 
				
			||||||
 
 | 
					
 | 
				
			||||||
    return response
 | 
					    return response
 | 
				
			||||||
 | 
				
			|||||||
@ -36,12 +36,12 @@ class BaseService:
 | 
				
			|||||||
        return await self.mapper.new()
 | 
					        return await self.mapper.new()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def query(self, sql, *args):
 | 
					    async def query(self, sql, *args):
 | 
				
			||||||
        for record in self.app.db.query(sql, *args):
 | 
					        for record in await self.app.db.query(sql, *args):
 | 
				
			||||||
            yield record
 | 
					            yield record
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def get(self, *args, **kwargs):
 | 
					    async def get(self, *args, **kwargs):
 | 
				
			||||||
        if not "deleted_at" in kwargs:
 | 
					        if not "deleted_at" in kwargs:
 | 
				
			||||||
            kwargs["deleted_at"] = None 
 | 
					            kwargs["deleted_at"] = None
 | 
				
			||||||
        uid = kwargs.get("uid")
 | 
					        uid = kwargs.get("uid")
 | 
				
			||||||
        if args:
 | 
					        if args:
 | 
				
			||||||
            uid = args[0]
 | 
					            uid = args[0]
 | 
				
			||||||
@ -50,7 +50,7 @@ class BaseService:
 | 
				
			|||||||
            if result and result.__class__ == self.mapper.model_class:
 | 
					            if result and result.__class__ == self.mapper.model_class:
 | 
				
			||||||
                return result
 | 
					                return result
 | 
				
			||||||
            kwargs["uid"] = uid
 | 
					            kwargs["uid"] = uid
 | 
				
			||||||
        
 | 
					        print(kwargs,"ZZZZZZZ")
 | 
				
			||||||
        result = await self.mapper.get(**kwargs)
 | 
					        result = await self.mapper.get(**kwargs)
 | 
				
			||||||
        if result:
 | 
					        if result:
 | 
				
			||||||
            await self.cache.set(result["uid"], result)
 | 
					            await self.cache.set(result["uid"], result)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										129
									
								
								src/snek/system/stats.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										129
									
								
								src/snek/system/stats.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,129 @@
 | 
				
			|||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					from aiohttp import web, WSMsgType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from datetime import datetime, timedelta, timezone
 | 
				
			||||||
 | 
					from collections import defaultdict
 | 
				
			||||||
 | 
					import html
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def create_stats_structure():
 | 
				
			||||||
 | 
					    """Creates the nested dictionary structure for storing statistics."""
 | 
				
			||||||
 | 
					    def nested_dd():
 | 
				
			||||||
 | 
					        return defaultdict(lambda: defaultdict(int))
 | 
				
			||||||
 | 
					    return defaultdict(nested_dd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_time_keys(dt: datetime):
 | 
				
			||||||
 | 
					    """Generates dictionary keys for different time granularities."""
 | 
				
			||||||
 | 
					    return {
 | 
				
			||||||
 | 
					        "hour": dt.strftime('%Y-%m-%d-%H'),
 | 
				
			||||||
 | 
					        "day": dt.strftime('%Y-%m-%d'),
 | 
				
			||||||
 | 
					        "week": dt.strftime('%Y-%W'), # Week number, Monday is first day
 | 
				
			||||||
 | 
					        "month": dt.strftime('%Y-%m'),
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def update_stats_counters(stats_dict: defaultdict, now: datetime):
 | 
				
			||||||
 | 
					    """Increments the appropriate time-based counters in a stats dictionary."""
 | 
				
			||||||
 | 
					    keys = get_time_keys(now)
 | 
				
			||||||
 | 
					    stats_dict['by_hour'][keys['hour']] += 1
 | 
				
			||||||
 | 
					    stats_dict['by_day'][keys['day']] += 1
 | 
				
			||||||
 | 
					    stats_dict['by_week'][keys['week']] += 1
 | 
				
			||||||
 | 
					    stats_dict['by_month'][keys['month']] += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def generate_time_series_svg(title: str, data: list[tuple[str, int]], y_label: str) -> str:
 | 
				
			||||||
 | 
					    """Generates a responsive SVG bar chart for time-series data."""
 | 
				
			||||||
 | 
					    if not data:
 | 
				
			||||||
 | 
					        return f"<h3>{html.escape(title)}</h3><p>No data yet.</p>"
 | 
				
			||||||
 | 
					    max_val = max(item[1] for item in data) if data else 1
 | 
				
			||||||
 | 
					    svg_height, svg_width = 250, 600
 | 
				
			||||||
 | 
					    bar_padding = 5
 | 
				
			||||||
 | 
					    bar_width = (svg_width - 50) / len(data) - bar_padding
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    bars = ""
 | 
				
			||||||
 | 
					    labels = ""
 | 
				
			||||||
 | 
					    for i, (key, val) in enumerate(data):
 | 
				
			||||||
 | 
					        bar_height = (val / max_val) * (svg_height - 50) if max_val > 0 else 0
 | 
				
			||||||
 | 
					        x = i * (bar_width + bar_padding) + 40
 | 
				
			||||||
 | 
					        y = svg_height - bar_height - 30
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        bars += f'<rect x="{x}" y="{y}" width="{bar_width}" height="{bar_height}" fill="#007BFF"><title>{html.escape(key)}: {val}</title></rect>'
 | 
				
			||||||
 | 
					        labels += f'<text x="{x + bar_width / 2}" y="{svg_height - 15}" font-size="11" text-anchor="middle">{html.escape(key)}</text>'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return f"""
 | 
				
			||||||
 | 
					    <h3>{html.escape(title)}</h3>
 | 
				
			||||||
 | 
					    <div style="border:1px solid #ccc; padding: 10px; border-radius: 5px;">
 | 
				
			||||||
 | 
					        <svg viewBox="0 0 {svg_width} {svg_height}" style="width:100%; height:auto;">
 | 
				
			||||||
 | 
					            <g>{bars}</g>
 | 
				
			||||||
 | 
					            <g>{labels}</g>
 | 
				
			||||||
 | 
					            <line x1="35" y1="10" x2="35" y2="{svg_height - 30}" stroke="#aaa" stroke-width="1" />
 | 
				
			||||||
 | 
					            <line x1="35" y1="{svg_height - 30}" x2="{svg_width - 10}" y2="{svg_height - 30}" stroke="#aaa" stroke-width="1" />
 | 
				
			||||||
 | 
					            <text x="5" y="{svg_height - 30}" font-size="12">0</text>
 | 
				
			||||||
 | 
					            <text x="5" y="20" font-size="12">{max_val}</text>
 | 
				
			||||||
 | 
					        </svg>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@web.middleware
 | 
				
			||||||
 | 
					async def middleware(request, handler):
 | 
				
			||||||
 | 
					    """Middleware to count all incoming HTTP requests."""
 | 
				
			||||||
 | 
					    # Avoid counting requests to the stats page itself
 | 
				
			||||||
 | 
					    if request.path.startswith('/stats.html'):
 | 
				
			||||||
 | 
					        return await handler(request)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					    update_stats_counters(request.app['stats']['http_requests'], datetime.now(timezone.utc))
 | 
				
			||||||
 | 
					    return await handler(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def update_websocket_stats(app):
 | 
				
			||||||
 | 
					    update_stats_counters(app['stats']['websocket_requests'], datetime.now(timezone.utc))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def pipe_and_count_websocket(ws_from, ws_to, stats_dict):
 | 
				
			||||||
 | 
					    """This function proxies WebSocket messages AND counts them."""
 | 
				
			||||||
 | 
					    async for msg in ws_from:
 | 
				
			||||||
 | 
					        # This is the key part for monitoring WebSockets
 | 
				
			||||||
 | 
					        update_stats_counters(stats_dict, datetime.now(timezone.utc))
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if msg.type == WSMsgType.TEXT:
 | 
				
			||||||
 | 
					            await ws_to.send_str(msg.data)
 | 
				
			||||||
 | 
					        elif msg.type == WSMsgType.BINARY:
 | 
				
			||||||
 | 
					            await ws_to.send_bytes(msg.data)
 | 
				
			||||||
 | 
					        elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
 | 
				
			||||||
 | 
					            await ws_to.close(code=ws_from.close_code)
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def stats_handler(request: web.Request):
 | 
				
			||||||
 | 
					    """Handler to display the statistics dashboard."""
 | 
				
			||||||
 | 
					    stats = request.app['stats']
 | 
				
			||||||
 | 
					    now = datetime.now(timezone.utc)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # Helper to prepare data for charts
 | 
				
			||||||
 | 
					    def get_data(source, period, count):
 | 
				
			||||||
 | 
					        data = []
 | 
				
			||||||
 | 
					        for i in range(count - 1, -1, -1):
 | 
				
			||||||
 | 
					            if period == 'hour':
 | 
				
			||||||
 | 
					                dt = now - timedelta(hours=i)
 | 
				
			||||||
 | 
					                key, label = dt.strftime('%Y-%m-%d-%H'), dt.strftime('%H:00')
 | 
				
			||||||
 | 
					                data.append((label, source['by_hour'].get(key, 0)))
 | 
				
			||||||
 | 
					            elif period == 'day':
 | 
				
			||||||
 | 
					                dt = now - timedelta(days=i)
 | 
				
			||||||
 | 
					                key, label = dt.strftime('%Y-%m-%d'), dt.strftime('%a')
 | 
				
			||||||
 | 
					                data.append((label, source['by_day'].get(key, 0)))
 | 
				
			||||||
 | 
					        return data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http_hourly = get_data(stats['http_requests'], 'hour', 24)
 | 
				
			||||||
 | 
					    ws_hourly = get_data(stats['ws_messages'], 'hour', 24)
 | 
				
			||||||
 | 
					    http_daily = get_data(stats['http_requests'], 'day', 7)
 | 
				
			||||||
 | 
					    ws_daily = get_data(stats['ws_messages'], 'day', 7)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    body = f"""
 | 
				
			||||||
 | 
					    <html><head><title>App Stats</title><meta http-equiv="refresh" content="30"></head>
 | 
				
			||||||
 | 
					    <body>
 | 
				
			||||||
 | 
					        <h2>Application Dashboard</h2>
 | 
				
			||||||
 | 
					        <h3>Last 24 Hours</h3>
 | 
				
			||||||
 | 
					        {generate_time_series_svg("HTTP Requests", http_hourly, "Reqs/Hour")}
 | 
				
			||||||
 | 
					        {generate_time_series_svg("WebSocket Messages", ws_hourly, "Msgs/Hour")}
 | 
				
			||||||
 | 
					        <h3>Last 7 Days</h3>
 | 
				
			||||||
 | 
					        {generate_time_series_svg("HTTP Requests", http_daily, "Reqs/Day")}
 | 
				
			||||||
 | 
					        {generate_time_series_svg("WebSocket Messages", ws_daily, "Msgs/Day")}
 | 
				
			||||||
 | 
					    </body></html>
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return web.Response(text=body, content_type='text/html')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -32,7 +32,7 @@
 | 
				
			|||||||
  <link rel="stylesheet" href="/base.css">
 | 
					  <link rel="stylesheet" href="/base.css">
 | 
				
			||||||
  <link rel="icon" type="image/png" href="/image/snek_logo_32x32.png" sizes="32x32">
 | 
					  <link rel="icon" type="image/png" href="/image/snek_logo_32x32.png" sizes="32x32">
 | 
				
			||||||
  <link rel="icon" type="image/png" href="/image/snek_logo_64x64.png" sizes="64x64">
 | 
					  <link rel="icon" type="image/png" href="/image/snek_logo_64x64.png" sizes="64x64">
 | 
				
			||||||
  <script nonce="{{nonce}}" defer src="https://umami.molodetz.nl/script.js" data-website-id="d127c3e4-dc70-4041-a1c8-bcc32c2492ea" defer></script>
 | 
					  <script nonce="{{nonce}}" defer src="https://umami.molodetz.nl/script.js" data-website-id="d127c3e4-dc70-4041-a1c8-bcc32c2492ea"></script>
 | 
				
			||||||
</head>
 | 
					</head>
 | 
				
			||||||
<body>
 | 
					<body>
 | 
				
			||||||
  <header>
 | 
					  <header>
 | 
				
			||||||
 | 
				
			|||||||
@ -1,97 +0,0 @@
 | 
				
			|||||||
<html>
 | 
					 | 
				
			||||||
    <head>
 | 
					 | 
				
			||||||
  <meta name="viewport" content="width=device-width, initial-scale=1.0, interactive-widget=resizes-content">
 | 
					 | 
				
			||||||
    <link rel="manifest" href="/manifest.json" />
 | 
					 | 
				
			||||||
        <style>
 | 
					 | 
				
			||||||
            body{
 | 
					 | 
				
			||||||
            background-color: black;
 | 
					 | 
				
			||||||
            color: white;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            </style>
 | 
					 | 
				
			||||||
    </head>
 | 
					 | 
				
			||||||
    <body>
 | 
					 | 
				
			||||||
<script type="module">
 | 
					 | 
				
			||||||
import { Socket } from "./socket.js";
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ChatWindow extends HTMLElement {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    constructor() {
 | 
					 | 
				
			||||||
      super();
 | 
					 | 
				
			||||||
      this.component = document.createElement("div");
 | 
					 | 
				
			||||||
    this.message_list = document.createElement("div")
 | 
					 | 
				
			||||||
        this.component.appendChild(this.message_list);
 | 
					 | 
				
			||||||
        this.chat_input = document.createElement("div")
 | 
					 | 
				
			||||||
        this.component.appendChild(this.chat_input);
 | 
					 | 
				
			||||||
        this.channelUid = null
 | 
					 | 
				
			||||||
        this.channelUid = this.getAttribute("channel")  
 | 
					 | 
				
			||||||
        this.inputText = document.createElement("textarea")
 | 
					 | 
				
			||||||
        this.inputText.addEventListener("keyup",(e)=>{
 | 
					 | 
				
			||||||
            
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            this.rpc.sendMessage(this.channelUid, e.target.value,false)
 | 
					 | 
				
			||||||
            if(e.key == "Enter" && !e.shiftKey){
 | 
					 | 
				
			||||||
                
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                this.rpc.sendMessage(this.channelUid, e.target.value,true)
 | 
					 | 
				
			||||||
                
 | 
					 | 
				
			||||||
                e.target.value = ""
 | 
					 | 
				
			||||||
            }else{
 | 
					 | 
				
			||||||
                //this.rpc.sendMessage(this.channelUid, e.target.value, false)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        })
 | 
					 | 
				
			||||||
        this.component.appendChild(this.inputText)
 | 
					 | 
				
			||||||
        this.ws = new Socket();
 | 
					 | 
				
			||||||
        this.ws.addEventListener("channel-message", this.handleMessage.bind(this))
 | 
					 | 
				
			||||||
        this.rpc = this.ws.client
 | 
					 | 
				
			||||||
        this.ws.addEventListener("update_message_text",this.handleMessage.bind(this))
 | 
					 | 
				
			||||||
        window.chat = this
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async handleMessage(data,data2) {
 | 
					 | 
				
			||||||
    if(data2 && data2.event)
 | 
					 | 
				
			||||||
        data = data.data
 | 
					 | 
				
			||||||
        console.info(["update-messagettt",data])
 | 
					 | 
				
			||||||
        console.warn(data.uid)
 | 
					 | 
				
			||||||
        if(!data.html)
 | 
					 | 
				
			||||||
        return
 | 
					 | 
				
			||||||
        let div = this.message_list.querySelector('[data-uid="' + data.uid + '"]');
 | 
					 | 
				
			||||||
        console.info(div)
 | 
					 | 
				
			||||||
        if(!div){
 | 
					 | 
				
			||||||
            let temp = document.createElement("chat-message");
 | 
					 | 
				
			||||||
            temp.innerHTML = data.html
 | 
					 | 
				
			||||||
            this.message_list.appendChild(temp)
 | 
					 | 
				
			||||||
            //this.message_list.replace(div,temp)
 | 
					 | 
				
			||||||
            //div.innerHTML = data.html 
 | 
					 | 
				
			||||||
            //this.message_list.appendChild(div);
 | 
					 | 
				
			||||||
        }else{
 | 
					 | 
				
			||||||
        //    alert("HIERR")
 | 
					 | 
				
			||||||
            let temp = document.createElement("chat-message");
 | 
					 | 
				
			||||||
            temp.innerHTML = data.html;
 | 
					 | 
				
			||||||
            div.innerHTML = temp.innerHTML
 | 
					 | 
				
			||||||
            console.info("REPLACE")
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async connectedCallback() {
 | 
					 | 
				
			||||||
        await this.rpc.ping(this.channel)
 | 
					 | 
				
			||||||
        console.info(this.channelUid)
 | 
					 | 
				
			||||||
        this.messages = await this.rpc.getMessages(this.channelUid, 0, 0);
 | 
					 | 
				
			||||||
        this.messages.forEach((msg) => {
 | 
					 | 
				
			||||||
            const temp = document.createElement("div");
 | 
					 | 
				
			||||||
            temp.innerHTML = msg.html;
 | 
					 | 
				
			||||||
            this.message_list.appendChild(temp.firstChild);
 | 
					 | 
				
			||||||
        })
 | 
					 | 
				
			||||||
this.appendChild(this.component);
 | 
					 | 
				
			||||||
        
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
customElements.define("chat-window", ChatWindow);
 | 
					 | 
				
			||||||
</script>
 | 
					 | 
				
			||||||
        <chat-window channel="df3e1259-7d1a-4184-b75c-3befd5bf08e1"></chat-window>
 | 
					 | 
				
			||||||
    </body>
 | 
					 | 
				
			||||||
</html>
 | 
					 | 
				
			||||||
@ -1,6 +1,6 @@
 | 
				
			|||||||
{% extends "app.html" %}
 | 
					{% extends "app.html" %}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
{% block header_text %}<h2 style="color:#fff">{{ name }}</h2>{% endblock %}
 | 
					{% block header_text %}<h2 style="color:#fff">{{ name }}</h2>{% endblock %} 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
{% block main %}
 | 
					{% block main %}
 | 
				
			||||||
{% include "channel.html" %}
 | 
					{% include "channel.html" %}
 | 
				
			||||||
@ -8,14 +8,11 @@
 | 
				
			|||||||
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/xterm/css/xterm.css">
 | 
					    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/xterm/css/xterm.css">
 | 
				
			||||||
    <script src="https://cdn.jsdelivr.net/npm/xterm/lib/xterm.js"></script>
 | 
					    <script src="https://cdn.jsdelivr.net/npm/xterm/lib/xterm.js"></script>
 | 
				
			||||||
    <script src="https://cdn.jsdelivr.net/npm/xterm-addon-fit/lib/xterm-addon-fit.js"></script>
 | 
					    <script src="https://cdn.jsdelivr.net/npm/xterm-addon-fit/lib/xterm-addon-fit.js"></script>
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
    <div id="terminal" class="hidden"></div>
 | 
					    <div id="terminal" class="hidden"></div>
 | 
				
			||||||
        <button id="jump-to-unread-btn" style="display: none; position: absolute; top: 10px; right: 10px; z-index: 1000; padding: 8px 16px; background: #007bff; color: white; border: none; border-radius: 4px; cursor: pointer; box-shadow: 0 2px 8px rgba(0,0,0,0.2);">
 | 
					 | 
				
			||||||
            Jump to First Unread
 | 
					 | 
				
			||||||
        </button>
 | 
					 | 
				
			||||||
        <message-list class="chat-messages">
 | 
					        <message-list class="chat-messages">
 | 
				
			||||||
        {% if not messages %}
 | 
					        {% if not messages %}
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
        <div>
 | 
					        <div>
 | 
				
			||||||
          <h1>Welcome to your new channel!</h1>
 | 
					          <h1>Welcome to your new channel!</h1>
 | 
				
			||||||
          <p>This is the start of something great. Use the commands below to get started:</p>
 | 
					          <p>This is the start of something great. Use the commands below to get started:</p>
 | 
				
			||||||
@ -40,9 +37,10 @@
 | 
				
			|||||||
{% include "dialog_help.html" %}
 | 
					{% include "dialog_help.html" %}
 | 
				
			||||||
{% include "dialog_online.html" %}
 | 
					{% include "dialog_online.html" %}
 | 
				
			||||||
<script type="module">
 | 
					<script type="module">
 | 
				
			||||||
    import {app} from "/app.js";
 | 
					    import { app } from "/app.js";
 | 
				
			||||||
 | 
					import { Schedule } from "/schedule.js";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // --- Cache selectors ---
 | 
					// --- Cache selectors ---
 | 
				
			||||||
const chatInputField = document.querySelector("chat-input");
 | 
					const chatInputField = document.querySelector("chat-input");
 | 
				
			||||||
const messagesContainer = document.querySelector(".chat-messages");
 | 
					const messagesContainer = document.querySelector(".chat-messages");
 | 
				
			||||||
const chatArea = document.querySelector(".chat-area");
 | 
					const chatArea = document.querySelector(".chat-area");
 | 
				
			||||||
@ -54,7 +52,7 @@ chatInputField.autoCompletions = {
 | 
				
			|||||||
    "/clear": () => { messagesContainer.innerHTML = ''; },
 | 
					    "/clear": () => { messagesContainer.innerHTML = ''; },
 | 
				
			||||||
    "/live": () => { chatInputField.liveType = !chatInputField.liveType; },
 | 
					    "/live": () => { chatInputField.liveType = !chatInputField.liveType; },
 | 
				
			||||||
    "/help": showHelp,
 | 
					    "/help": showHelp,
 | 
				
			||||||
    "/container": async() =>{
 | 
					    "/container": async() =>{ 
 | 
				
			||||||
        containerDialog.openWithStatus()
 | 
					        containerDialog.openWithStatus()
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
@ -101,61 +99,21 @@ setInterval(() => requestIdleCallback(updateTimes), 30000);
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// --- Paste & drag/drop uploads ---
 | 
					// --- Paste & drag/drop uploads ---
 | 
				
			||||||
const textBox = chatInputField.textarea;
 | 
					const textBox = chatInputField.textarea;
 | 
				
			||||||
 | 
					 | 
				
			||||||
function uploadDataTransfer(dt) {
 | 
					 | 
				
			||||||
    if (dt.items.length > 0) {
 | 
					 | 
				
			||||||
        const uploadButton = chatInputField.fileUploadGrid;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for (const item of dt.items) {
 | 
					 | 
				
			||||||
            if (item.kind === "file") {
 | 
					 | 
				
			||||||
                const file = item.getAsFile();
 | 
					 | 
				
			||||||
                if (file) {
 | 
					 | 
				
			||||||
                   uploadButton.uploadsStarted++
 | 
					 | 
				
			||||||
                   uploadButton.createTile(file)
 | 
					 | 
				
			||||||
                } else {
 | 
					 | 
				
			||||||
                    console.error("Failed to get file from DataTransferItem");
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
textBox.addEventListener("paste", async (e) => {
 | 
					textBox.addEventListener("paste", async (e) => {
 | 
				
			||||||
    try {
 | 
					    try {
 | 
				
			||||||
        console.log("Pasted data:", e.clipboardData);
 | 
					        const clipboardItems = await navigator.clipboard.read();
 | 
				
			||||||
        if (e.clipboardData.types.every(v => v.startsWith("text/"))) {
 | 
					        const dt = new DataTransfer();
 | 
				
			||||||
            const codeType = e.clipboardData.types.find(t => !t.startsWith('text/plain') && !t.startsWith('text/html') && !t.startsWith('text/rtf'));
 | 
					        for (const item of clipboardItems) {
 | 
				
			||||||
            const probablyCode = codeType ||e.clipboardData.types.some(t => !t.startsWith('text/plain'));
 | 
					            for (const type of item.types.filter(t => !t.startsWith('text/'))) {
 | 
				
			||||||
            for (const item of e.clipboardData.items) {
 | 
					                const blob = await item.getType(type);
 | 
				
			||||||
                if (item.kind === "string" && item.type === "text/plain") {
 | 
					                dt.items.add(new File([blob], "image.png", { type }));
 | 
				
			||||||
                    e.preventDefault();
 | 
					 | 
				
			||||||
                    item.getAsString(text => {
 | 
					 | 
				
			||||||
                        const value = chatInputField.value;
 | 
					 | 
				
			||||||
                        if (probablyCode) {
 | 
					 | 
				
			||||||
                            let code = text;
 | 
					 | 
				
			||||||
                            const minIndentDepth = code.split('\n').reduce((acc, line) => {
 | 
					 | 
				
			||||||
                                if (!line.trim()) return acc;
 | 
					 | 
				
			||||||
                                const match = line.match(/^(\s*)/);
 | 
					 | 
				
			||||||
                                return match ? Math.min(acc, match[1].length) : acc;
 | 
					 | 
				
			||||||
                            }, 9000);
 | 
					 | 
				
			||||||
                            code = code.split('\n').map(line => line.slice(minIndentDepth)).join('\n')
 | 
					 | 
				
			||||||
                            text = `\`\`\`${codeType?.split('/')?.[1] ?? ''}\n${code}\n\`\`\`\n`;
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                        const area = chatInputField.textarea
 | 
					 | 
				
			||||||
                        if(area){
 | 
					 | 
				
			||||||
                            const start = area.selectionStart
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            if ("\n" !== value[start - 1]) {
 | 
					 | 
				
			||||||
                                text = `\n${text}`;
 | 
					 | 
				
			||||||
                            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                            document.execCommand('insertText', false, text);
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    });
 | 
					 | 
				
			||||||
                    break;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        } else {
 | 
					        }
 | 
				
			||||||
            uploadDataTransfer(e.clipboardData);
 | 
					        if (dt.items.length > 0) {
 | 
				
			||||||
 | 
					            const uploadButton = chatInputField.uploadButton;
 | 
				
			||||||
 | 
					            const input = uploadButton.shadowRoot.querySelector('.file-input');
 | 
				
			||||||
 | 
					            input.files = dt.files;
 | 
				
			||||||
 | 
					            await uploadButton.uploadFiles();
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    } catch (error) {
 | 
					    } catch (error) {
 | 
				
			||||||
        console.error("Failed to read clipboard contents: ", error);
 | 
					        console.error("Failed to read clipboard contents: ", error);
 | 
				
			||||||
@ -163,7 +121,13 @@ textBox.addEventListener("paste", async (e) => {
 | 
				
			|||||||
});
 | 
					});
 | 
				
			||||||
chatArea.addEventListener("drop", async (e) => {
 | 
					chatArea.addEventListener("drop", async (e) => {
 | 
				
			||||||
    e.preventDefault();
 | 
					    e.preventDefault();
 | 
				
			||||||
    uploadDataTransfer(e.dataTransfer);
 | 
					    const dt = e.dataTransfer;
 | 
				
			||||||
 | 
					    if (dt.items.length > 0) {
 | 
				
			||||||
 | 
					        const uploadButton = chatInputField.uploadButton;
 | 
				
			||||||
 | 
					        const input = uploadButton.shadowRoot.querySelector('.file-input');
 | 
				
			||||||
 | 
					        input.files = dt.files;
 | 
				
			||||||
 | 
					        await uploadButton.uploadFiles();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
});
 | 
					});
 | 
				
			||||||
chatArea.addEventListener("dragover", e => {
 | 
					chatArea.addEventListener("dragover", e => {
 | 
				
			||||||
    e.preventDefault();
 | 
					    e.preventDefault();
 | 
				
			||||||
@ -216,7 +180,7 @@ app.ws.addEventListener("starfield.render_word", (data) => {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// --- Channel message event ---
 | 
					// --- Channel message event ---
 | 
				
			||||||
app.addEventListener("channel-message", (data) => {
 | 
					app.addEventListener("channel-message", (data) => {
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
    let display = data.text && data.text.trim() ? 'block' : 'none';
 | 
					    let display = data.text && data.text.trim() ? 'block' : 'none';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (data.channel_uid !== channelUid) {
 | 
					    if (data.channel_uid !== channelUid) {
 | 
				
			||||||
@ -257,7 +221,7 @@ document.addEventListener('keydown', function(event) {
 | 
				
			|||||||
        clearTimeout(keyTimeout);
 | 
					        clearTimeout(keyTimeout);
 | 
				
			||||||
        keyTimeout = setTimeout(() => { gPressCount = 0; }, 300);
 | 
					        keyTimeout = setTimeout(() => { gPressCount = 0; }, 300);
 | 
				
			||||||
        if (gPressCount === 2) {
 | 
					        if (gPressCount === 2) {
 | 
				
			||||||
            gPressCount = 0;
 | 
					            gPressCount = 0; 
 | 
				
			||||||
            messagesContainer.lastElementChild?.scrollIntoView({ block: "end", inline: "nearest" });
 | 
					            messagesContainer.lastElementChild?.scrollIntoView({ block: "end", inline: "nearest" });
 | 
				
			||||||
            loadExtra();
 | 
					            loadExtra();
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@ -303,76 +267,7 @@ function isScrolledPastHalf() {
 | 
				
			|||||||
// --- Initial layout update ---
 | 
					// --- Initial layout update ---
 | 
				
			||||||
updateLayout(true);
 | 
					updateLayout(true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// --- Jump to unread functionality ---
 | 
					 | 
				
			||||||
const jumpToUnreadBtn = document.getElementById('jump-to-unread-btn');
 | 
					 | 
				
			||||||
let firstUnreadMessageUid = null;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
async function checkForUnreadMessages() {
 | 
					 | 
				
			||||||
    try {
 | 
					 | 
				
			||||||
        const uid = await app.rpc.getFirstUnreadMessageUid(channelUid);
 | 
					 | 
				
			||||||
        if (uid) {
 | 
					 | 
				
			||||||
            firstUnreadMessageUid = uid;
 | 
					 | 
				
			||||||
            const messageElement = messagesContainer.querySelector(`[data-uid="${uid}"]`);
 | 
					 | 
				
			||||||
            if (messageElement && !messagesContainer.isElementVisible(messageElement)) {
 | 
					 | 
				
			||||||
                jumpToUnreadBtn.style.display = 'block';
 | 
					 | 
				
			||||||
            } else {
 | 
					 | 
				
			||||||
                jumpToUnreadBtn.style.display = 'none';
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
            jumpToUnreadBtn.style.display = 'none';
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    } catch (error) {
 | 
					 | 
				
			||||||
        console.error('Error checking for unread messages:', error);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
async function jumpToUnread() {
 | 
					 | 
				
			||||||
    if (!firstUnreadMessageUid) {
 | 
					 | 
				
			||||||
        await checkForUnreadMessages();
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (firstUnreadMessageUid) {
 | 
					 | 
				
			||||||
        let messageElement = messagesContainer.querySelector(`[data-uid="${firstUnreadMessageUid}"]`);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (!messageElement) {
 | 
					 | 
				
			||||||
            const messages = await app.rpc.getMessages(channelUid, 0, null);
 | 
					 | 
				
			||||||
            const targetMessage = messages.find(m => m.uid === firstUnreadMessageUid);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if (targetMessage) {
 | 
					 | 
				
			||||||
                const temp = document.createElement("div");
 | 
					 | 
				
			||||||
                temp.innerHTML = targetMessage.html;
 | 
					 | 
				
			||||||
                const newMessageElement = temp.firstChild;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                messagesContainer.endOfMessages.after(newMessageElement);
 | 
					 | 
				
			||||||
                messagesContainer.messageMap.set(targetMessage.uid, newMessageElement);
 | 
					 | 
				
			||||||
                messagesContainer._observer.observe(newMessageElement);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                messageElement = newMessageElement;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if (messageElement) {
 | 
					 | 
				
			||||||
            messageElement.scrollIntoView({ behavior: 'smooth', block: 'center' });
 | 
					 | 
				
			||||||
            messageElement.style.animation = 'highlight-fade 2s';
 | 
					 | 
				
			||||||
            setTimeout(() => {
 | 
					 | 
				
			||||||
                messageElement.style.animation = '';
 | 
					 | 
				
			||||||
            }, 2000);
 | 
					 | 
				
			||||||
            jumpToUnreadBtn.style.display = 'none';
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
jumpToUnreadBtn.addEventListener('click', jumpToUnread);
 | 
					 | 
				
			||||||
checkForUnreadMessages();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const style = document.createElement('style');
 | 
					 | 
				
			||||||
style.textContent = `
 | 
					 | 
				
			||||||
    @keyframes highlight-fade {
 | 
					 | 
				
			||||||
        0% { background-color: rgba(255, 255, 0, 0.3); }
 | 
					 | 
				
			||||||
        100% { background-color: transparent; }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
`;
 | 
					 | 
				
			||||||
document.head.appendChild(style);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
</script>
 | 
					</script>
 | 
				
			||||||
{% endblock %}
 | 
					{% endblock %}
 | 
				
			||||||
 | 
				
			|||||||
@ -71,10 +71,7 @@ class AvatarView(BaseView):
 | 
				
			|||||||
        uid = self.request.match_info.get("uid")
 | 
					        uid = self.request.match_info.get("uid")
 | 
				
			||||||
        if uid == "unique":
 | 
					        if uid == "unique":
 | 
				
			||||||
            uid = str(uuid.uuid4())
 | 
					            uid = str(uuid.uuid4())
 | 
				
			||||||
        avatar = await self.app.get(uid)
 | 
					        avatar = multiavatar.multiavatar(uid, True, None)
 | 
				
			||||||
        if not avatar:
 | 
					 | 
				
			||||||
            avatar = multiavatar.multiavatar(uid, True, None)
 | 
					 | 
				
			||||||
            await self.app.set(uid, avatar)
 | 
					 | 
				
			||||||
        response = web.Response(text=avatar, content_type="image/svg+xml")
 | 
					        response = web.Response(text=avatar, content_type="image/svg+xml")
 | 
				
			||||||
        response.headers["Cache-Control"] = f"public, max-age={1337*42}"
 | 
					        response.headers["Cache-Control"] = f"public, max-age={1337*42}"
 | 
				
			||||||
        return response
 | 
					        return response
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +0,0 @@
 | 
				
			|||||||
from snek.system.view import BaseView
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class NewView(BaseView):
 | 
					 | 
				
			||||||
    login_required = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def get(self):
 | 
					 | 
				
			||||||
        return await self.render_template("new.html")
 | 
					 | 
				
			||||||
@ -6,6 +6,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions.
 | 
					# MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from snek.system.stats import update_websocket_stats 
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
@ -175,26 +176,6 @@ class RPCView(BaseView):
 | 
				
			|||||||
                messages.append(extended_dict)
 | 
					                messages.append(extended_dict)
 | 
				
			||||||
            return messages
 | 
					            return messages
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        async def get_first_unread_message_uid(self, channel_uid):
 | 
					 | 
				
			||||||
            self._require_login()
 | 
					 | 
				
			||||||
            channel_member = await self.services.channel_member.get(
 | 
					 | 
				
			||||||
                channel_uid=channel_uid, user_uid=self.user_uid
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            if not channel_member:
 | 
					 | 
				
			||||||
                return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            last_read_at = channel_member.get("last_read_at")
 | 
					 | 
				
			||||||
            if not last_read_at:
 | 
					 | 
				
			||||||
                return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            async for message in self.services.channel_message.query(
 | 
					 | 
				
			||||||
                "SELECT uid FROM channel_message WHERE channel_uid=:channel_uid AND created_at > :last_read_at AND deleted_at IS NULL ORDER BY created_at ASC LIMIT 1",
 | 
					 | 
				
			||||||
                {"channel_uid": channel_uid, "last_read_at": last_read_at}
 | 
					 | 
				
			||||||
            ):
 | 
					 | 
				
			||||||
                return message["uid"]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        async def get_channels(self):
 | 
					        async def get_channels(self):
 | 
				
			||||||
            self._require_login()
 | 
					            self._require_login()
 | 
				
			||||||
            channels = []
 | 
					            channels = []
 | 
				
			||||||
@ -526,7 +507,9 @@ class RPCView(BaseView):
 | 
				
			|||||||
                    raise Exception("Method not found")
 | 
					                    raise Exception("Method not found")
 | 
				
			||||||
                success = True
 | 
					                success = True
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
 | 
					                    update_websocket_stats(self.app)
 | 
				
			||||||
                    result = await method(*args)
 | 
					                    result = await method(*args)
 | 
				
			||||||
 | 
					                    update_websocket_stats(self.app)
 | 
				
			||||||
                except Exception as ex:
 | 
					                except Exception as ex:
 | 
				
			||||||
                    result = {"exception": str(ex), "traceback": traceback.format_exc()}
 | 
					                    result = {"exception": str(ex), "traceback": traceback.format_exc()}
 | 
				
			||||||
                    success = False
 | 
					                    success = False
 | 
				
			||||||
@ -546,8 +529,8 @@ class RPCView(BaseView):
 | 
				
			|||||||
            try:
 | 
					            try:
 | 
				
			||||||
                await self.ws.send_str(json.dumps(obj, default=str))
 | 
					                await self.ws.send_str(json.dumps(obj, default=str))
 | 
				
			||||||
            except Exception as ex:
 | 
					            except Exception as ex:
 | 
				
			||||||
                print("THIS IS THE DeAL>",str(ex), flush=True)
 | 
					 | 
				
			||||||
                await self.services.socket.delete(self.ws)
 | 
					                await self.services.socket.delete(self.ws)
 | 
				
			||||||
 | 
					                await self.ws.close()
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        async def get_online_users(self, channel_uid):
 | 
					        async def get_online_users(self, channel_uid):
 | 
				
			||||||
            self._require_login()
 | 
					            self._require_login()
 | 
				
			||||||
@ -655,7 +638,7 @@ class RPCView(BaseView):
 | 
				
			|||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    await rpc(msg.json())
 | 
					                    await rpc(msg.json())
 | 
				
			||||||
                except Exception as ex:
 | 
					                except Exception as ex:
 | 
				
			||||||
                    print("XXXXXXXXXX Deleting socket", ex, flush=True)
 | 
					                    print("Deleting socket", ex, flush=True)
 | 
				
			||||||
                    logger.exception(ex)
 | 
					                    logger.exception(ex)
 | 
				
			||||||
                    await self.services.socket.delete(ws)
 | 
					                    await self.services.socket.delete(ws)
 | 
				
			||||||
                    break
 | 
					                    break
 | 
				
			||||||
 | 
				
			|||||||
@ -38,6 +38,10 @@ class WebView(BaseView):
 | 
				
			|||||||
        channel = await self.services.channel.get(
 | 
					        channel = await self.services.channel.get(
 | 
				
			||||||
            uid=self.request.match_info.get("channel")
 | 
					            uid=self.request.match_info.get("channel")
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        print(self.session.get("uid"),"ZZZZZZZZZZ")
 | 
				
			||||||
 | 
					        qq = await self.services.user.get(uid=self.session.get("uid"))
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        print("GGGGGGGGGG",qq)
 | 
				
			||||||
        if not channel:
 | 
					        if not channel:
 | 
				
			||||||
            user = await self.services.user.get(
 | 
					            user = await self.services.user.get(
 | 
				
			||||||
                uid=self.request.match_info.get("channel")
 | 
					                uid=self.request.match_info.get("channel")
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										78
									
								
								src/snekssh/app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								src/snekssh/app.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,78 @@
 | 
				
			|||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asyncssh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					asyncssh.set_debug_level(2)
 | 
				
			||||||
 | 
					logging.basicConfig(level=logging.DEBUG)
 | 
				
			||||||
 | 
					# Configuration for SFTP server
 | 
				
			||||||
 | 
					SFTP_ROOT = "."  # Directory to serve
 | 
				
			||||||
 | 
					USERNAME = "test"
 | 
				
			||||||
 | 
					PASSWORD = "woeii"
 | 
				
			||||||
 | 
					HOST = "localhost"
 | 
				
			||||||
 | 
					PORT = 2225
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MySFTPServer(asyncssh.SFTPServer):
 | 
				
			||||||
 | 
					    def __init__(self, chan):
 | 
				
			||||||
 | 
					        super().__init__(chan)
 | 
				
			||||||
 | 
					        self.root = os.path.abspath(SFTP_ROOT)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def stat(self, path):
 | 
				
			||||||
 | 
					        """Handles 'stat' command from SFTP client"""
 | 
				
			||||||
 | 
					        full_path = os.path.join(self.root, path.lstrip("/"))
 | 
				
			||||||
 | 
					        return await super().stat(full_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def open(self, path, flags, attrs):
 | 
				
			||||||
 | 
					        """Handles file open requests"""
 | 
				
			||||||
 | 
					        full_path = os.path.join(self.root, path.lstrip("/"))
 | 
				
			||||||
 | 
					        return await super().open(full_path, flags, attrs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def listdir(self, path):
 | 
				
			||||||
 | 
					        """Handles directory listing"""
 | 
				
			||||||
 | 
					        full_path = os.path.join(self.root, path.lstrip("/"))
 | 
				
			||||||
 | 
					        return await super().listdir(full_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MySSHServer(asyncssh.SSHServer):
 | 
				
			||||||
 | 
					    """Custom SSH server to handle authentication"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def connection_made(self, conn):
 | 
				
			||||||
 | 
					        print(f"New connection from {conn.get_extra_info('peername')}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def connection_lost(self, exc):
 | 
				
			||||||
 | 
					        print("Client disconnected")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def begin_auth(self, username):
 | 
				
			||||||
 | 
					        return True  # No additional authentication steps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def password_auth_supported(self):
 | 
				
			||||||
 | 
					        return True  # Support password authentication
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate_password(self, username, password):
 | 
				
			||||||
 | 
					        print(username, password)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					        return username == USERNAME and password == PASSWORD
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def start_sftp_server():
 | 
				
			||||||
 | 
					    os.makedirs(SFTP_ROOT, exist_ok=True)  # Ensure the root directory exists
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await asyncssh.create_server(
 | 
				
			||||||
 | 
					        lambda: MySSHServer(),
 | 
				
			||||||
 | 
					        host=HOST,
 | 
				
			||||||
 | 
					        port=PORT,
 | 
				
			||||||
 | 
					        server_host_keys=["ssh_host_key"],
 | 
				
			||||||
 | 
					        process_factory=MySFTPServer,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    print(f"SFTP server running on {HOST}:{PORT}")
 | 
				
			||||||
 | 
					    await asyncio.Future()  # Keep running forever
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        asyncio.run(start_sftp_server())
 | 
				
			||||||
 | 
					    except (OSError, asyncssh.Error) as e:
 | 
				
			||||||
 | 
					        print(f"Error starting SFTP server: {e}")
 | 
				
			||||||
							
								
								
									
										77
									
								
								src/snekssh/app2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								src/snekssh/app2.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,77 @@
 | 
				
			|||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asyncssh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# SSH Server Configuration
 | 
				
			||||||
 | 
					HOST = "0.0.0.0"
 | 
				
			||||||
 | 
					PORT = 2225
 | 
				
			||||||
 | 
					USERNAME = "user"
 | 
				
			||||||
 | 
					PASSWORD = "password"
 | 
				
			||||||
 | 
					SHELL = "/bin/sh"  # Change to another shell if needed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CustomSSHServer(asyncssh.SSHServer):
 | 
				
			||||||
 | 
					    def connection_made(self, conn):
 | 
				
			||||||
 | 
					        print(f"New connection from {conn.get_extra_info('peername')}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def connection_lost(self, exc):
 | 
				
			||||||
 | 
					        print("Client disconnected")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def password_auth_supported(self):
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate_password(self, username, password):
 | 
				
			||||||
 | 
					        return username == USERNAME and password == PASSWORD
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def custom_bash_process(process):
 | 
				
			||||||
 | 
					    """Spawns a custom bash shell process"""
 | 
				
			||||||
 | 
					    env = os.environ.copy()
 | 
				
			||||||
 | 
					    env["TERM"] = "xterm-256color"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Start the Bash shell
 | 
				
			||||||
 | 
					    bash_proc = await asyncio.create_subprocess_exec(
 | 
				
			||||||
 | 
					        SHELL,
 | 
				
			||||||
 | 
					        "-i",
 | 
				
			||||||
 | 
					        stdin=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					        stdout=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					        stderr=asyncio.subprocess.PIPE,
 | 
				
			||||||
 | 
					        env=env,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def read_output():
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            data = await bash_proc.stdout.read(1)
 | 
				
			||||||
 | 
					            if not data:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            process.stdout.write(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def read_input():
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            data = await process.stdin.read(1)
 | 
				
			||||||
 | 
					            if not data:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            bash_proc.stdin.write(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await asyncio.gather(read_output(), read_input())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def start_ssh_server():
 | 
				
			||||||
 | 
					    """Starts the AsyncSSH server with Bash"""
 | 
				
			||||||
 | 
					    await asyncssh.create_server(
 | 
				
			||||||
 | 
					        lambda: CustomSSHServer(),
 | 
				
			||||||
 | 
					        host=HOST,
 | 
				
			||||||
 | 
					        port=PORT,
 | 
				
			||||||
 | 
					        server_host_keys=["ssh_host_key"],
 | 
				
			||||||
 | 
					        process_factory=custom_bash_process,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    print(f"SSH server running on {HOST}:{PORT}")
 | 
				
			||||||
 | 
					    await asyncio.Future()  # Keep running
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        asyncio.run(start_ssh_server())
 | 
				
			||||||
 | 
					    except (OSError, asyncssh.Error) as e:
 | 
				
			||||||
 | 
					        print(f"Error starting SSH server: {e}")
 | 
				
			||||||
							
								
								
									
										74
									
								
								src/snekssh/app3.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								src/snekssh/app3.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,74 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env python3.7
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright (c) 2013-2024 by Ron Frederick <ronf@timeheart.net> and others.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This program and the accompanying materials are made available under
 | 
				
			||||||
 | 
					# the terms of the Eclipse Public License v2.0 which accompanies this
 | 
				
			||||||
 | 
					# distribution and is available at:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.eclipse.org/legal/epl-2.0/
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This program may also be made available under the following secondary
 | 
				
			||||||
 | 
					# licenses when the conditions for such availability set forth in the
 | 
				
			||||||
 | 
					# Eclipse Public License v2.0 are satisfied:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    GNU General Public License, Version 2.0, or any later versions of
 | 
				
			||||||
 | 
					#    that license
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Contributors:
 | 
				
			||||||
 | 
					#     Ron Frederick - initial implementation, API, and documentation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# To run this program, the file ``ssh_host_key`` must exist with an SSH
 | 
				
			||||||
 | 
					# private key in it to use as a server host key. An SSH host certificate
 | 
				
			||||||
 | 
					# can optionally be provided in the file ``ssh_host_key-cert.pub``.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# The file ``ssh_user_ca`` must exist with a cert-authority entry of
 | 
				
			||||||
 | 
					# the certificate authority which can sign valid client certificates.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asyncssh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def handle_client(process: asyncssh.SSHServerProcess) -> None:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    width, height, pixwidth, pixheight = process.term_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    process.stdout.write(
 | 
				
			||||||
 | 
					        f"Terminal type: {process.term_type}, " f"size: {width}x{height}"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    if pixwidth and pixheight:
 | 
				
			||||||
 | 
					        process.stdout.write(f" ({pixwidth}x{pixheight} pixels)")
 | 
				
			||||||
 | 
					    process.stdout.write("\nTry resizing your window!\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    while not process.stdin.at_eof():
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            await process.stdin.read()
 | 
				
			||||||
 | 
					        except asyncssh.TerminalSizeChanged as exc:
 | 
				
			||||||
 | 
					            process.stdout.write(f"New window size: {exc.width}x{exc.height}")
 | 
				
			||||||
 | 
					            if exc.pixwidth and exc.pixheight:
 | 
				
			||||||
 | 
					                process.stdout.write(f" ({exc.pixwidth}" f"x{exc.pixheight} pixels)")
 | 
				
			||||||
 | 
					            process.stdout.write("\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def start_server() -> None:
 | 
				
			||||||
 | 
					    await asyncssh.listen(
 | 
				
			||||||
 | 
					        "",
 | 
				
			||||||
 | 
					        2230,
 | 
				
			||||||
 | 
					        server_host_keys=["ssh_host_key"],
 | 
				
			||||||
 | 
					        # authorized_client_keys='ssh_user_ca',
 | 
				
			||||||
 | 
					        process_factory=handle_client,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					loop = asyncio.new_event_loop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    loop.run_until_complete(start_server())
 | 
				
			||||||
 | 
					except (OSError, asyncssh.Error) as exc:
 | 
				
			||||||
 | 
					    sys.exit("Error starting server: " + str(exc))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					loop.run_forever()
 | 
				
			||||||
							
								
								
									
										90
									
								
								src/snekssh/app4.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								src/snekssh/app4.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env python3.7
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright (c) 2013-2024 by Ron Frederick <ronf@timeheart.net> and others.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This program and the accompanying materials are made available under
 | 
				
			||||||
 | 
					# the terms of the Eclipse Public License v2.0 which accompanies this
 | 
				
			||||||
 | 
					# distribution and is available at:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.eclipse.org/legal/epl-2.0/
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This program may also be made available under the following secondary
 | 
				
			||||||
 | 
					# licenses when the conditions for such availability set forth in the
 | 
				
			||||||
 | 
					# Eclipse Public License v2.0 are satisfied:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    GNU General Public License, Version 2.0, or any later versions of
 | 
				
			||||||
 | 
					#    that license
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Contributors:
 | 
				
			||||||
 | 
					#     Ron Frederick - initial implementation, API, and documentation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# To run this program, the file ``ssh_host_key`` must exist with an SSH
 | 
				
			||||||
 | 
					# private key in it to use as a server host key. An SSH host certificate
 | 
				
			||||||
 | 
					# can optionally be provided in the file ``ssh_host_key-cert.pub``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asyncssh
 | 
				
			||||||
 | 
					import bcrypt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					passwords = {
 | 
				
			||||||
 | 
					    "guest": b"",  # guest account with no password
 | 
				
			||||||
 | 
					    "user": bcrypt.hashpw(b"user", bcrypt.gensalt()),
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def handle_client(process: asyncssh.SSHServerProcess) -> None:
 | 
				
			||||||
 | 
					    username = process.get_extra_info("username")
 | 
				
			||||||
 | 
					    process.stdout.write(f"Welcome to my SSH server, {username}!\n")
 | 
				
			||||||
 | 
					    # process.exit(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class MySSHServer(asyncssh.SSHServer):
 | 
				
			||||||
 | 
					    def connection_made(self, conn: asyncssh.SSHServerConnection) -> None:
 | 
				
			||||||
 | 
					        peername = conn.get_extra_info("peername")[0]
 | 
				
			||||||
 | 
					        print(f"SSH connection received from {peername}.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def connection_lost(self, exc: Optional[Exception]) -> None:
 | 
				
			||||||
 | 
					        if exc:
 | 
				
			||||||
 | 
					            print("SSH connection error: " + str(exc), file=sys.stderr)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            print("SSH connection closed.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def begin_auth(self, username: str) -> bool:
 | 
				
			||||||
 | 
					        # If the user's password is the empty string, no auth is required
 | 
				
			||||||
 | 
					        return passwords.get(username) != b""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def password_auth_supported(self) -> bool:
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate_password(self, username: str, password: str) -> bool:
 | 
				
			||||||
 | 
					        if username not in passwords:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        pw = passwords[username]
 | 
				
			||||||
 | 
					        if not password and not pw:
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					        return bcrypt.checkpw(password.encode("utf-8"), pw)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def start_server() -> None:
 | 
				
			||||||
 | 
					    await asyncssh.create_server(
 | 
				
			||||||
 | 
					        MySSHServer,
 | 
				
			||||||
 | 
					        "",
 | 
				
			||||||
 | 
					        2231,
 | 
				
			||||||
 | 
					        server_host_keys=["ssh_host_key"],
 | 
				
			||||||
 | 
					        process_factory=handle_client,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					loop = asyncio.new_event_loop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    loop.run_until_complete(start_server())
 | 
				
			||||||
 | 
					except (OSError, asyncssh.Error) as exc:
 | 
				
			||||||
 | 
					    sys.exit("Error starting server: " + str(exc))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					loop.run_forever()
 | 
				
			||||||
							
								
								
									
										112
									
								
								src/snekssh/app5.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								src/snekssh/app5.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,112 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env python3.7
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright (c) 2016-2024 by Ron Frederick <ronf@timeheart.net> and others.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This program and the accompanying materials are made available under
 | 
				
			||||||
 | 
					# the terms of the Eclipse Public License v2.0 which accompanies this
 | 
				
			||||||
 | 
					# distribution and is available at:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.eclipse.org/legal/epl-2.0/
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This program may also be made available under the following secondary
 | 
				
			||||||
 | 
					# licenses when the conditions for such availability set forth in the
 | 
				
			||||||
 | 
					# Eclipse Public License v2.0 are satisfied:
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#    GNU General Public License, Version 2.0, or any later versions of
 | 
				
			||||||
 | 
					#    that license
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Contributors:
 | 
				
			||||||
 | 
					#     Ron Frederick - initial implementation, API, and documentation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# To run this program, the file ``ssh_host_key`` must exist with an SSH
 | 
				
			||||||
 | 
					# private key in it to use as a server host key. An SSH host certificate
 | 
				
			||||||
 | 
					# can optionally be provided in the file ``ssh_host_key-cert.pub``.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# The file ``ssh_user_ca`` must exist with a cert-authority entry of
 | 
				
			||||||
 | 
					# the certificate authority which can sign valid client certificates.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					from typing import List, cast
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import asyncssh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ChatClient:
 | 
				
			||||||
 | 
					    _clients: List["ChatClient"] = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, process: asyncssh.SSHServerProcess):
 | 
				
			||||||
 | 
					        self._process = process
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    async def handle_client(cls, process: asyncssh.SSHServerProcess):
 | 
				
			||||||
 | 
					        await cls(process).run()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def readline(self) -> str:
 | 
				
			||||||
 | 
					        return cast(str, self._process.stdin.readline())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def write(self, msg: str) -> None:
 | 
				
			||||||
 | 
					        self._process.stdout.write(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def broadcast(self, msg: str) -> None:
 | 
				
			||||||
 | 
					        for client in self._clients:
 | 
				
			||||||
 | 
					            if client != self:
 | 
				
			||||||
 | 
					                client.write(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def begin_auth(self, username: str) -> bool:
 | 
				
			||||||
 | 
					        # If the user's password is the empty string, no auth is required
 | 
				
			||||||
 | 
					        # return False
 | 
				
			||||||
 | 
					        return True  # passwords.get(username) != b''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def password_auth_supported(self) -> bool:
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validate_password(self, username: str, password: str) -> bool:
 | 
				
			||||||
 | 
					        # if username not in passwords:
 | 
				
			||||||
 | 
					        #    return False
 | 
				
			||||||
 | 
					        # pw = passwords[username]
 | 
				
			||||||
 | 
					        # if not password and not pw:
 | 
				
			||||||
 | 
					        #    return True
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					        # return bcrypt.checkpw(password.encode('utf-8'), pw)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def run(self) -> None:
 | 
				
			||||||
 | 
					        self.write("Welcome to chat!\n\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.write("Enter your name: ")
 | 
				
			||||||
 | 
					        name = (await self.readline()).rstrip("\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.write(f"\n{len(self._clients)} other users are connected.\n\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._clients.append(self)
 | 
				
			||||||
 | 
					        self.broadcast(f"*** {name} has entered chat ***\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            async for line in self._process.stdin:
 | 
				
			||||||
 | 
					                self.broadcast(f"{name}: {line}")
 | 
				
			||||||
 | 
					        except asyncssh.BreakReceived:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.broadcast(f"*** {name} has left chat ***\n")
 | 
				
			||||||
 | 
					        self._clients.remove(self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def start_server() -> None:
 | 
				
			||||||
 | 
					    await asyncssh.listen(
 | 
				
			||||||
 | 
					        "",
 | 
				
			||||||
 | 
					        2235,
 | 
				
			||||||
 | 
					        server_host_keys=["ssh_host_key"],
 | 
				
			||||||
 | 
					        process_factory=ChatClient.handle_client,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					loop = asyncio.new_event_loop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    loop.run_until_complete(start_server())
 | 
				
			||||||
 | 
					except (OSError, asyncssh.Error) as exc:
 | 
				
			||||||
 | 
					    sys.exit("Error starting server: " + str(exc))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					loop.run_forever()
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user