diff --git a/src/snek/__main__.py b/src/snek/__main__.py
index 360bf5a..a2e433a 100644
--- a/src/snek/__main__.py
+++ b/src/snek/__main__.py
@@ -1,27 +1,31 @@
-import click
-import uvloop
-from aiohttp import web
-import asyncio
-from snek.app import Application
-from IPython import start_ipython
-import sqlite3
import pathlib
-import shutil
+import shutil
+import sqlite3
+
+import click
+from aiohttp import web
+from IPython import start_ipython
+
+from snek.app import Application
+
@click.group()
def cli():
pass
+
@cli.command()
-@click.option('--db_path',default="snek.db", help='Database to initialize if not exists.')
-@click.option('--source',default=None, help='Database to initialize if not exists.')
-def init(db_path,source):
+@click.option(
+ "--db_path", default="snek.db", help="Database to initialize if not exists."
+)
+@click.option("--source", default=None, help="Database to initialize if not exists.")
+def init(db_path, source):
if source and pathlib.Path(source).exists():
print(f"Copying {source} to {db_path}")
- shutil.copy2(source,db_path)
+ shutil.copy2(source, db_path)
print("Database initialized.")
return
-
+
if pathlib.Path(db_path).exists():
return
print(f"Initializing database at {db_path}")
@@ -33,25 +37,44 @@ def init(db_path,source):
db.close()
print("Database initialized.")
-@cli.command()
-@click.option('--port', default=8081, show_default=True, help='Port to run the application on')
-@click.option('--host', default='0.0.0.0', show_default=True, help='Host to run the application on')
-@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application')
-def serve(port, host, db_path):
- #init(db_path)
- #asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
- web.run_app(
- Application(db_path=f"sqlite:///{db_path}"), port=port, host=host
- )
@cli.command()
-@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application')
+@click.option(
+ "--port", default=8081, show_default=True, help="Port to run the application on"
+)
+@click.option(
+ "--host",
+ default="0.0.0.0",
+ show_default=True,
+ help="Host to run the application on",
+)
+@click.option(
+ "--db_path",
+ default="snek.db",
+ show_default=True,
+ help="Database path for the application",
+)
+def serve(port, host, db_path):
+ # init(db_path)
+ # asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+ web.run_app(Application(db_path=f"sqlite:///{db_path}"), port=port, host=host)
+
+
+@cli.command()
+@click.option(
+ "--db_path",
+ default="snek.db",
+ show_default=True,
+ help="Database path for the application",
+)
def shell(db_path):
app = Application(db_path=f"sqlite:///{db_path}")
- start_ipython(argv=[], user_ns={'app': app})
+ start_ipython(argv=[], user_ns={"app": app})
+
def main():
cli()
+
if __name__ == "__main__":
main()
diff --git a/src/snek/app.py b/src/snek/app.py
index 03ef491..765e4fb 100644
--- a/src/snek/app.py
+++ b/src/snek/app.py
@@ -4,13 +4,15 @@ import pathlib
import ssl
import uuid
from datetime import datetime
+
from snek import snode
from snek.view.threads import ThreadsView
logging.basicConfig(level=logging.DEBUG)
-from ipaddress import ip_address
from concurrent.futures import ThreadPoolExecutor
+from ipaddress import ip_address
+import IP2Location
from aiohttp import web
from aiohttp_session import (
get_session as session_get,
@@ -20,54 +22,56 @@ from aiohttp_session import (
from aiohttp_session.cookie_storage import EncryptedCookieStorage
from app.app import Application as BaseApplication
from jinja2 import FileSystemLoader
-import IP2Location
-from snek.sssh import start_ssh_server
-from snek.docs.app import Application as DocsApplication
from snek.mapper import get_mappers
from snek.service import get_services
+from snek.sgit import GitApplication
+from snek.sssh import start_ssh_server
from snek.system import http
from snek.system.cache import Cache
from snek.system.markdown import MarkdownExtension
from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware
from snek.system.profiler import profiler_handler
-from snek.system.template import EmojiExtension, LinkifyExtension, PythonExtension
+from snek.system.template import (
+ EmojiExtension,
+ LinkifyExtension,
+ PythonExtension,
+ sanitize_html,
+)
from snek.view.about import AboutHTMLView, AboutMDView
from snek.view.avatar import AvatarView
+from snek.view.channel import ChannelAttachmentView, ChannelView
from snek.view.docs import DocsHTMLView, DocsMDView
-from snek.view.drive import DriveView
-from snek.view.drive import DriveApiView
+from snek.view.drive import DriveApiView, DriveView
from snek.view.index import IndexView
from snek.view.login import LoginView
-from snek.view.push import PushView
from snek.view.logout import LogoutView
+from snek.view.push import PushView
from snek.view.register import RegisterView
-from snek.view.rpc import RPCView
from snek.view.repository import RepositoryView
+from snek.view.rpc import RPCView
from snek.view.search_user import SearchUserView
-from snek.view.settings.repositories import RepositoriesIndexView
-from snek.view.settings.repositories import RepositoriesCreateView
-from snek.view.settings.repositories import RepositoriesUpdateView
-from snek.view.settings.repositories import RepositoriesDeleteView
+from snek.view.settings.containers import (
+ ContainersCreateView,
+ ContainersDeleteView,
+ ContainersIndexView,
+ ContainersUpdateView,
+)
from snek.view.settings.index import SettingsIndexView
from snek.view.settings.profile import SettingsProfileView
+from snek.view.settings.repositories import (
+ RepositoriesCreateView,
+ RepositoriesDeleteView,
+ RepositoriesIndexView,
+ RepositoriesUpdateView,
+)
from snek.view.stats import StatsView
from snek.view.status import StatusView
from snek.view.terminal import TerminalSocketView, TerminalView
from snek.view.upload import UploadView
from snek.view.user import UserView
from snek.view.web import WebView
-from snek.view.channel import ChannelAttachmentView
-from snek.view.channel import ChannelView
-from snek.view.settings.containers import (
- ContainersIndexView,
- ContainersCreateView,
- ContainersUpdateView,
- ContainersDeleteView,
-)
from snek.webdav import WebdavApplication
-from snek.system.template import sanitize_html
-from snek.sgit import GitApplication
SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34"
from snek.system.template import whitelist_attributes
@@ -94,7 +98,7 @@ async def ip2location_middleware(request, handler):
if not user:
return response
location = request.app.ip2location.get(ip)
- original_city = user["city"]
+ user["city"]
if user["city"] != location.city:
user["country_long"] = location.country
user["country_short"] = locaion.country_short
@@ -121,7 +125,7 @@ class Application(BaseApplication):
cors_middleware,
web.normalize_path_middleware(merge_slashes=True),
ip2location_middleware,
- csp_middleware
+ csp_middleware,
]
self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
self.static_path = pathlib.Path(__file__).parent.joinpath("static")
@@ -139,7 +143,7 @@ class Application(BaseApplication):
self.jinja2_env.add_extension(LinkifyExtension)
self.jinja2_env.add_extension(PythonExtension)
self.jinja2_env.add_extension(EmojiExtension)
- self.jinja2_env.filters['sanitize'] = sanitize_html
+ self.jinja2_env.filters["sanitize"] = sanitize_html
self.time_start = datetime.now()
self.ssh_host = "0.0.0.0"
self.ssh_port = 2242
@@ -272,8 +276,12 @@ class Application(BaseApplication):
self.router.add_get("/http-photo", self.handle_http_photo)
self.router.add_get("/rpc.ws", RPCView)
self.router.add_get("/c/{channel:.*}", ChannelView)
- self.router.add_view("/channel/{channel_uid}/attachment.bin",ChannelAttachmentView)
- self.router.add_view("/channel/attachment/{relative_url:.*}",ChannelAttachmentView)
+ self.router.add_view(
+ "/channel/{channel_uid}/attachment.bin", ChannelAttachmentView
+ )
+ self.router.add_view(
+ "/channel/attachment/{relative_url:.*}", ChannelAttachmentView
+ )
self.router.add_view("/channel/{channel}.html", WebView)
self.router.add_view("/threads.html", ThreadsView)
self.router.add_view("/terminal.ws", TerminalSocketView)
@@ -284,15 +292,29 @@ class Application(BaseApplication):
self.router.add_view("/stats.json", StatsView)
self.router.add_view("/user/{user}.html", UserView)
self.router.add_view("/repository/{username}/{repository}", RepositoryView)
- self.router.add_view("/repository/{username}/{repository}/{path:.*}", RepositoryView)
+ self.router.add_view(
+ "/repository/{username}/{repository}/{path:.*}", RepositoryView
+ )
self.router.add_view("/settings/repositories/index.html", RepositoriesIndexView)
- self.router.add_view("/settings/repositories/create.html", RepositoriesCreateView)
- self.router.add_view("/settings/repositories/repository/{name}/update.html", RepositoriesUpdateView)
- self.router.add_view("/settings/repositories/repository/{name}/delete.html", RepositoriesDeleteView)
+ self.router.add_view(
+ "/settings/repositories/create.html", RepositoriesCreateView
+ )
+ self.router.add_view(
+ "/settings/repositories/repository/{name}/update.html",
+ RepositoriesUpdateView,
+ )
+ self.router.add_view(
+ "/settings/repositories/repository/{name}/delete.html",
+ RepositoriesDeleteView,
+ )
self.router.add_view("/settings/containers/index.html", ContainersIndexView)
self.router.add_view("/settings/containers/create.html", ContainersCreateView)
- self.router.add_view("/settings/containers/container/{uid}/update.html", ContainersUpdateView)
- self.router.add_view("/settings/containers/container/{uid}/delete.html", ContainersDeleteView)
+ self.router.add_view(
+ "/settings/containers/container/{uid}/update.html", ContainersUpdateView
+ )
+ self.router.add_view(
+ "/settings/containers/container/{uid}/delete.html", ContainersDeleteView
+ )
self.webdav = WebdavApplication(self)
self.git = GitApplication(self)
self.add_subapp("/webdav", self.webdav)
@@ -302,9 +324,9 @@ class Application(BaseApplication):
async def handle_test(self, request):
- return await whitelist_attributes(self.render_template(
- "test.html", request, context={"name": "retoor"}
- ))
+ return await whitelist_attributes(
+ self.render_template("test.html", request, context={"name": "retoor"})
+ )
async def handle_http_get(self, request: web.Request):
url = request.query.get("url")
@@ -375,8 +397,8 @@ class Application(BaseApplication):
self.jinja2_env.loader = self.original_loader
- #rendered.text = whitelist_attributes(rendered.text)
- #rendered.headers['Content-Lenght'] = len(rendered.text)
+ # rendered.text = whitelist_attributes(rendered.text)
+ # rendered.headers['Content-Lenght'] = len(rendered.text)
return rendered
async def static_handler(self, request):
@@ -423,7 +445,7 @@ app = Application(db_path="sqlite:///snek.db")
async def main():
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
- ssl_context.load_cert_chain('cert.pem', 'key.pem')
+ ssl_context.load_cert_chain("cert.pem", "key.pem")
await web._run_app(app, port=8081, host="0.0.0.0", ssl_context=ssl_context)
diff --git a/src/snek/balancer.py b/src/snek/balancer.py
index 09d0b5a..44db9a0 100644
--- a/src/snek/balancer.py
+++ b/src/snek/balancer.py
@@ -1,6 +1,7 @@
import asyncio
import sys
+
class LoadBalancer:
def __init__(self, backend_ports):
self.backend_ports = backend_ports
@@ -8,27 +9,29 @@ class LoadBalancer:
self.client_counts = [0] * len(backend_ports)
self.lock = asyncio.Lock()
- async def start_backend_servers(self,port,workers):
+ async def start_backend_servers(self, port, workers):
for x in range(workers):
port += 1
process = await asyncio.create_subprocess_exec(
sys.executable,
sys.argv[0],
- 'backend',
+ "backend",
str(port),
stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE
+ stderr=asyncio.subprocess.PIPE,
)
port += 1
self.backend_processes.append(process)
- print(f"Started backend server on port {(port-1)/port} with PID {process.pid}")
+ print(
+ f"Started backend server on port {(port-1)/port} with PID {process.pid}"
+ )
async def handle_client(self, reader, writer):
async with self.lock:
min_clients = min(self.client_counts)
server_index = self.client_counts.index(min_clients)
self.client_counts[server_index] += 1
- backend = ('127.0.0.1', self.backend_ports[server_index])
+ backend = ("127.0.0.1", self.backend_ports[server_index])
try:
backend_reader, backend_writer = await asyncio.open_connection(*backend)
@@ -62,10 +65,10 @@ class LoadBalancer:
for i, count in enumerate(self.client_counts):
print(f"Server {self.backend_ports[i]}: {count} clients")
- async def start(self, host='0.0.0.0', port=8081,workers=5):
- await self.start_backend_servers(port,workers)
+ async def start(self, host="0.0.0.0", port=8081, workers=5):
+ await self.start_backend_servers(port, workers)
server = await asyncio.start_server(self.handle_client, host, port)
- monitor_task = asyncio.create_task(self.monitor())
+ asyncio.create_task(self.monitor())
# Handle shutdown gracefully
try:
@@ -80,6 +83,7 @@ class LoadBalancer:
await asyncio.gather(*(p.wait() for p in self.backend_processes))
print("Backend processes terminated.")
+
async def backend_echo_server(port):
async def handle_echo(reader, writer):
try:
@@ -94,10 +98,11 @@ async def backend_echo_server(port):
finally:
writer.close()
- server = await asyncio.start_server(handle_echo, '127.0.0.1', port)
+ server = await asyncio.start_server(handle_echo, "127.0.0.1", port)
print(f"Backend echo server running on port {port}")
await server.serve_forever()
+
async def main():
backend_ports = [8001, 8003, 8005, 8006]
# Launch backend echo servers
@@ -105,19 +110,20 @@ async def main():
lb = LoadBalancer(backend_ports)
await lb.start()
+
if __name__ == "__main__":
if len(sys.argv) > 1:
- if sys.argv[1] == 'backend':
+ if sys.argv[1] == "backend":
port = int(sys.argv[2])
from snek.app import Application
+
snek = Application(port=port)
- web.run_app(snek, port=port, host='127.0.0.1')
- elif sys.argv[1] == 'sync':
- from snek.sync import app
- web.run_app(snek, port=port, host='127.0.0.1')
+ web.run_app(snek, port=port, host="127.0.0.1")
+ elif sys.argv[1] == "sync":
+
+ web.run_app(snek, port=port, host="127.0.0.1")
else:
try:
asyncio.run(main())
except KeyboardInterrupt:
print("Shutting down...")
-
diff --git a/src/snek/mapper/__init__.py b/src/snek/mapper/__init__.py
index 48a8a0e..062553a 100644
--- a/src/snek/mapper/__init__.py
+++ b/src/snek/mapper/__init__.py
@@ -1,22 +1,21 @@
import functools
from snek.mapper.channel import ChannelMapper
+from snek.mapper.channel_attachment import ChannelAttachmentMapper
from snek.mapper.channel_member import ChannelMemberMapper
from snek.mapper.channel_message import ChannelMessageMapper
+from snek.mapper.container import ContainerMapper
from snek.mapper.drive import DriveMapper
from snek.mapper.drive_item import DriveItemMapper
from snek.mapper.notification import NotificationMapper
+from snek.mapper.push import PushMapper
+from snek.mapper.repository import RepositoryMapper
from snek.mapper.user import UserMapper
from snek.mapper.user_property import UserPropertyMapper
-from snek.mapper.repository import RepositoryMapper
-from snek.mapper.push import PushMapper
-from snek.mapper.channel_attachment import ChannelAttachmentMapper
-from snek.mapper.container import ContainerMapper
from snek.system.object import Object
@functools.cache
-
def get_mappers(app=None):
return Object(
**{
diff --git a/src/snek/mapper/container.py b/src/snek/mapper/container.py
index f88f9f0..c2a089f 100644
--- a/src/snek/mapper/container.py
+++ b/src/snek/mapper/container.py
@@ -1,6 +1,7 @@
from snek.model.container import Container
from snek.system.mapper import BaseMapper
+
class ContainerMapper(BaseMapper):
model_class = Container
- table_name = "container"
\ No newline at end of file
+ table_name = "container"
diff --git a/src/snek/model/__init__.py b/src/snek/model/__init__.py
index e5c6054..187ea4b 100644
--- a/src/snek/model/__init__.py
+++ b/src/snek/model/__init__.py
@@ -1,19 +1,19 @@
import functools
from snek.model.channel import ChannelModel
+from snek.model.channel_attachment import ChannelAttachmentModel
from snek.model.channel_member import ChannelMemberModel
# from snek.model.channel_message import ChannelMessageModel
from snek.model.channel_message import ChannelMessageModel
+from snek.model.container import Container
from snek.model.drive import DriveModel
from snek.model.drive_item import DriveItemModel
from snek.model.notification import NotificationModel
+from snek.model.push_registration import PushRegistrationModel
+from snek.model.repository import RepositoryModel
from snek.model.user import UserModel
from snek.model.user_property import UserPropertyModel
-from snek.model.repository import RepositoryModel
-from snek.model.channel_attachment import ChannelAttachmentModel
-from snek.model.push_registration import PushRegistrationModel
-from snek.model.container import Container
from snek.system.object import Object
diff --git a/src/snek/model/channel_attachment.py b/src/snek/model/channel_attachment.py
index 9add1b8..0c0fba2 100644
--- a/src/snek/model/channel_attachment.py
+++ b/src/snek/model/channel_attachment.py
@@ -1,4 +1,3 @@
-from snek.system.model import BaseModel
from snek.system.model import BaseModel, ModelField
@@ -11,6 +10,6 @@ class ChannelAttachmentModel(BaseModel):
user_uid = ModelField(name="user_uid", required=True, kind=str)
mime_type = ModelField(name="type", required=True, kind=str)
relative_url = ModelField(name="relative_url", required=True, kind=str)
- resource_type = ModelField(name="resource_type", required=True, kind=str,value="file")
-
-
+ resource_type = ModelField(
+ name="resource_type", required=True, kind=str, value="file"
+ )
diff --git a/src/snek/model/channel_message.py b/src/snek/model/channel_message.py
index 505a40d..e3933e3 100644
--- a/src/snek/model/channel_message.py
+++ b/src/snek/model/channel_message.py
@@ -1,6 +1,8 @@
+from datetime import datetime, timezone
+
from snek.model.user import UserModel
from snek.system.model import BaseModel, ModelField
-from datetime import datetime,timezone
+
class ChannelMessageModel(BaseModel):
channel_uid = ModelField(name="channel_uid", required=True, kind=str)
@@ -10,7 +12,11 @@ class ChannelMessageModel(BaseModel):
is_final = ModelField(name="is_final", required=True, kind=bool, value=True)
def get_seconds_since_last_update(self):
- return int((datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"])).total_seconds())
+ return int(
+ (
+ datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"])
+ ).total_seconds()
+ )
async def get_user(self) -> UserModel:
return await self.app.services.user.get(uid=self["user_uid"])
diff --git a/src/snek/model/container.py b/src/snek/model/container.py
index 3947a8d..ee64144 100644
--- a/src/snek/model/container.py
+++ b/src/snek/model/container.py
@@ -1,5 +1,6 @@
from snek.system.model import BaseModel, ModelField
+
class Container(BaseModel):
id = ModelField(name="id", required=True, kind=str)
name = ModelField(name="name", required=True, kind=str)
diff --git a/src/snek/model/drive_item.py b/src/snek/model/drive_item.py
index e2b55b4..82567a6 100644
--- a/src/snek/model/drive_item.py
+++ b/src/snek/model/drive_item.py
@@ -9,7 +9,9 @@ class DriveItemModel(BaseModel):
path = ModelField(name="path", required=True, kind=str)
file_type = ModelField(name="file_type", required=True, kind=str)
file_size = ModelField(name="file_size", required=True, kind=int)
- is_available = ModelField(name="is_available", required=True, kind=bool, initial_value=True)
+ is_available = ModelField(
+ name="is_available", required=True, kind=bool, initial_value=True
+ )
@property
def extension(self):
diff --git a/src/snek/model/repository.py b/src/snek/model/repository.py
index 598cbb2..521e3ff 100644
--- a/src/snek/model/repository.py
+++ b/src/snek/model/repository.py
@@ -1,14 +1,10 @@
-from snek.model.user import UserModel
from snek.system.model import BaseModel, ModelField
class RepositoryModel(BaseModel):
user_uid = ModelField(name="user_uid", required=True, kind=str)
-
+
name = ModelField(name="name", required=True, kind=str)
is_private = ModelField(name="is_private", required=False, kind=bool)
-
-
-
diff --git a/src/snek/model/user.py b/src/snek/model/user.py
index afa4b53..93e83af 100644
--- a/src/snek/model/user.py
+++ b/src/snek/model/user.py
@@ -30,7 +30,7 @@ class UserModel(BaseModel):
last_ping = ModelField(name="last_ping", required=False, kind=str)
is_admin = ModelField(name="is_admin", required=False, kind=bool)
-
+
country_short = ModelField(name="country_short", required=False, kind=str)
country_long = ModelField(name="country_long", required=False, kind=str)
city = ModelField(name="city", required=False, kind=str)
diff --git a/src/snek/research/serptest.py b/src/snek/research/serptest.py
index 94e22be..891cccc 100644
--- a/src/snek/research/serptest.py
+++ b/src/snek/research/serptest.py
@@ -1,51 +1,54 @@
-import snek.serpentarium
-
-import time
-
+import time
from concurrent.futures import ProcessPoolExecutor
+import snek.serpentarium
+
durations = []
+
def task1():
- global durations
+ global durations
client = snek.serpentarium.DatasetWrapper()
- start=time.time()
+ start = time.time()
for x in range(1500):
-
- client['a'].delete()
- client['a'].insert({"foo": x})
- client['a'].find(foo=x)
- client['a'].find_one(foo=x)
- client['a'].count()
- #print(client['a'].find(foo=x) )
- #print(client['a'].find_one(foo=x) )
- #print(client['a'].count())
+
+ client["a"].delete()
+ client["a"].insert({"foo": x})
+ client["a"].find(foo=x)
+ client["a"].find_one(foo=x)
+ client["a"].count()
+ # print(client['a'].find(foo=x) )
+ # print(client['a'].find_one(foo=x) )
+ # print(client['a'].count())
client.close()
duration1 = f"{time.time()-start}"
durations.append(duration1)
print(durations)
+
with ProcessPoolExecutor(max_workers=4) as executor:
- tasks = [executor.submit(task1),
+ tasks = [
+ executor.submit(task1),
+ executor.submit(task1),
executor.submit(task1),
executor.submit(task1),
- executor.submit(task1)
]
for task in tasks:
task.result()
-import dataset
+import dataset
+
client = dataset.connect("sqlite:///snek.db")
-start=time.time()
+start = time.time()
for x in range(1500):
- client['a'].delete()
- client['a'].insert({"foo": x})
- print([dict(row) for row in client['a'].find(foo=x)])
- print(dict(client['a'].find_one(foo=x) ))
- print(client['a'].count())
+ client["a"].delete()
+ client["a"].insert({"foo": x})
+ print([dict(row) for row in client["a"].find(foo=x)])
+ print(dict(client["a"].find_one(foo=x)))
+ print(client["a"].count())
duration2 = f"{time.time()-start}"
-print(duration1,duration2)
+print(duration1, duration2)
diff --git a/src/snek/service/__init__.py b/src/snek/service/__init__.py
index f1654fb..77702c2 100644
--- a/src/snek/service/__init__.py
+++ b/src/snek/service/__init__.py
@@ -1,22 +1,22 @@
import functools
from snek.service.channel import ChannelService
+from snek.service.channel_attachment import ChannelAttachmentService
from snek.service.channel_member import ChannelMemberService
from snek.service.channel_message import ChannelMessageService
from snek.service.chat import ChatService
+from snek.service.container import ContainerService
+from snek.service.db import DBService
from snek.service.drive import DriveService
from snek.service.drive_item import DriveItemService
from snek.service.notification import NotificationService
+from snek.service.push import PushService
+from snek.service.repository import RepositoryService
from snek.service.socket import SocketService
from snek.service.user import UserService
-from snek.service.push import PushService
from snek.service.user_property import UserPropertyService
from snek.service.util import UtilService
-from snek.service.repository import RepositoryService
-from snek.service.channel_attachment import ChannelAttachmentService
-from snek.service.container import ContainerService
from snek.system.object import Object
-from snek.service.db import DBService
@functools.cache
diff --git a/src/snek/service/channel.py b/src/snek/service/channel.py
index a5c166d..0af0a4b 100644
--- a/src/snek/service/channel.py
+++ b/src/snek/service/channel.py
@@ -1,9 +1,9 @@
+import pathlib
from datetime import datetime
from snek.system.model import now
from snek.system.service import BaseService
-import pathlib
class ChannelService(BaseService):
mapper_name = "channel"
@@ -17,12 +17,10 @@ class ChannelService(BaseService):
pass
return folder
- async def get_attachment_folder(self, channel_uid,ensure=False):
+ async def get_attachment_folder(self, channel_uid, ensure=False):
path = pathlib.Path(f"./drive/{channel_uid}/attachments")
if ensure:
- path.mkdir(
- parents=True, exist_ok=True
- )
+ path.mkdir(parents=True, exist_ok=True)
return path
async def get(self, uid=None, **kwargs):
@@ -78,7 +76,10 @@ class ChannelService(BaseService):
return channel
async def get_recent_users(self, channel_uid):
- async for user in self.query("SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30", {"channel_uid": channel_uid}):
+ async for user in self.query(
+ "SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30",
+ {"channel_uid": channel_uid},
+ ):
yield user
async def get_users(self, channel_uid):
diff --git a/src/snek/service/channel_attachment.py b/src/snek/service/channel_attachment.py
index 7c8ded6..755aaa0 100644
--- a/src/snek/service/channel_attachment.py
+++ b/src/snek/service/channel_attachment.py
@@ -1,25 +1,25 @@
-from snek.system.service import BaseService
-import urllib.parse
-import pathlib
import mimetypes
-import uuid
+
+from snek.system.service import BaseService
+
class ChannelAttachmentService(BaseService):
- mapper_name="channel_attachment"
+ mapper_name = "channel_attachment"
async def create_file(self, channel_uid, user_uid, name):
attachment = await self.new()
attachment["channel_uid"] = channel_uid
- attachment['user_uid'] = user_uid
+ attachment["user_uid"] = user_uid
attachment["name"] = name
attachment["mime_type"] = mimetypes.guess_type(name)[0]
- attachment['resource_type'] = "file"
+ attachment["resource_type"] = "file"
real_file_name = f"{attachment['uid']}-{name}"
- attachment["relative_url"] = (f"{attachment['uid']}-{name}")
- attachment_folder = await self.services.channel.get_attachment_folder(channel_uid)
+ attachment["relative_url"] = f"{attachment['uid']}-{name}"
+ attachment_folder = await self.services.channel.get_attachment_folder(
+ channel_uid
+ )
attachment_path = attachment_folder.joinpath(real_file_name)
attachment["path"] = str(attachment_path)
if await self.save(attachment):
return attachment
raise Exception(f"Failed to create channel attachment: {attachment.errors}.")
-
diff --git a/src/snek/service/channel_message.py b/src/snek/service/channel_message.py
index eefff96..6d70da7 100644
--- a/src/snek/service/channel_message.py
+++ b/src/snek/service/channel_message.py
@@ -11,7 +11,7 @@ class ChannelMessageService(BaseService):
model["channel_uid"] = channel_uid
model["user_uid"] = user_uid
model["message"] = message
- model['is_final'] = is_final
+ model["is_final"] = is_final
context = {}
@@ -56,7 +56,7 @@ class ChannelMessageService(BaseService):
async def save(self, model):
context = {}
context.update(model.record)
- user = await self.app.services.user.get(model['user_uid'])
+ user = await self.app.services.user.get(model["user_uid"])
context.update(
{
"user_uid": user["uid"],
diff --git a/src/snek/service/chat.py b/src/snek/service/chat.py
index f37f6c7..9130292 100644
--- a/src/snek/service/chat.py
+++ b/src/snek/service/chat.py
@@ -13,13 +13,13 @@ class ChatService(BaseService):
channel["last_message_on"] = now()
await self.services.channel.save(channel)
await self.services.socket.broadcast(
- channel['uid'],
+ channel["uid"],
{
"message": channel_message["message"],
"html": channel_message["html"],
- "user_uid": user['uid'],
+ "user_uid": user["uid"],
"color": user["color"],
- "channel_uid": channel['uid'],
+ "channel_uid": channel["uid"],
"created_at": channel_message["created_at"],
"updated_at": channel_message["updated_at"],
"username": user["username"],
@@ -32,14 +32,12 @@ class ChatService(BaseService):
self.services.notification.create_channel_message(message_uid)
)
-
-
async def send(self, user_uid, channel_uid, message, is_final=True):
channel = await self.services.channel.get(uid=channel_uid)
if not channel:
raise Exception("Channel not found.")
channel_message = await self.services.channel_message.create(
- channel_uid, user_uid, message,is_final
+ channel_uid, user_uid, message, is_final
)
channel_message_uid = channel_message["uid"]
diff --git a/src/snek/service/container.py b/src/snek/service/container.py
index 03c63aa..9d2706c 100644
--- a/src/snek/service/container.py
+++ b/src/snek/service/container.py
@@ -1,36 +1,51 @@
+from snek.system.docker import ComposeFileManager
from snek.system.service import BaseService
-from snek.system.docker import ComposeFileManager
+
class ContainerService(BaseService):
mapper_name = "container"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
-
+
self.compose_path = "snek-container-compose.yml"
self.compose = ComposeFileManager(self.compose_path)
-
-
async def get_instances(self):
return list(self.compose.list_instances())
async def get_container_name(self, channel_uid):
return f"channel-{channel_uid}"
- async def create(self, channel_uid, image="ubuntu:latest", command=None, cpus=1, memory='1024m', ports=None, volumes=None):
+ async def create(
+ self,
+ channel_uid,
+ image="ubuntu:latest",
+ command=None,
+ cpus=1,
+ memory="1024m",
+ ports=None,
+ volumes=None,
+ ):
name = await self.get_container_name(channel_uid)
self.compose.create_instance(
- name,
- image,
- command,
- cpus,
- memory,
- ports,
- ["./"+str((await self.services.channel.get_home_folder(channel_uid))) +":"+ "/home/ubuntu"]
+ name,
+ image,
+ command,
+ cpus,
+ memory,
+ ports,
+ [
+ "./"
+ + str(await self.services.channel.get_home_folder(channel_uid))
+ + ":"
+ + "/home/ubuntu"
+ ],
)
- async def create2(self, id, name, status, resources=None, user_uid=None, path=None, readonly=False):
+ async def create2(
+ self, id, name, status, resources=None, user_uid=None, path=None, readonly=False
+ ):
model = await self.new()
model["id"] = id
model["name"] = name
diff --git a/src/snek/service/db.py b/src/snek/service/db.py
index fe9fb8a..8e583e9 100644
--- a/src/snek/service/db.py
+++ b/src/snek/service/db.py
@@ -1,23 +1,21 @@
-from snek.system.service import BaseService
-import dataset
-import uuid
+import dataset
+
+from snek.system.service import BaseService
-from datetime import datetime
class DBService(BaseService):
-
+
async def get_db(self, user_uid):
-
+
home_folder = await self.app.services.user.get_home_folder(user_uid)
home_folder.mkdir(parents=True, exist_ok=True)
db_path = home_folder.joinpath("snek/user.db")
db_path.parent.mkdir(parents=True, exist_ok=True)
- return dataset.connect("sqlite:///" + str(db_path))
-
+ return dataset.connect("sqlite:///" + str(db_path))
+
async def insert(self, user_uid, table_name, values):
db = await self.get_db(user_uid)
return db[table_name].insert(values)
-
async def update(self, user_uid, table_name, values, filters):
db = await self.get_db(user_uid)
@@ -33,7 +31,7 @@ class DBService(BaseService):
async def find(self, user_uid, table_name, kwargs):
db = await self.get_db(user_uid)
- kwargs['_limit'] = kwargs.get('_limit', 30)
+ kwargs["_limit"] = kwargs.get("_limit", 30)
return [dict(row) for row in db[table_name].find(**kwargs)]
async def get(self, user_uid, table_name, filters):
@@ -45,14 +43,13 @@ class DBService(BaseService):
except ValueError:
return None
-
async def delete(self, user_uid, table_name, filters):
db = await self.get_db(user_uid)
if not filters:
filters = {}
return db[table_name].delete(**filters)
- async def query(self, sql,values):
+ async def query(self, sql, values):
db = await self.app.db
return [dict(row) for row in db.query(sql, values or {})]
@@ -62,8 +59,6 @@ class DBService(BaseService):
filters = {}
return bool(db[table_name].find_one(**filters))
-
-
async def count(self, user_uid, table_name, filters):
db = await self.get_db(user_uid)
if not filters:
diff --git a/src/snek/service/notification.py b/src/snek/service/notification.py
index 8aaa1a4..fa3cdc3 100644
--- a/src/snek/service/notification.py
+++ b/src/snek/service/notification.py
@@ -1,6 +1,7 @@
+from snek.system.markdown import strip_markdown
from snek.system.model import now
from snek.system.service import BaseService
-from snek.system.markdown import strip_markdown
+
class NotificationService(BaseService):
mapper_name = "notification"
@@ -34,7 +35,7 @@ class NotificationService(BaseService):
uid=channel_message_uid
)
if not channel_message["is_final"]:
- return
+ return
user = await self.services.user.get(uid=channel_message["user_uid"])
self.app.db.begin()
async for channel_member in self.services.channel_member.find(
diff --git a/src/snek/service/push.py b/src/snek/service/push.py
index ccba4af..b4b72c9 100644
--- a/src/snek/service/push.py
+++ b/src/snek/service/push.py
@@ -1,25 +1,23 @@
+import base64
import json
-
-import aiohttp
-from snek.system.service import BaseService
+import os.path
import random
import time
-import base64
import uuid
-from functools import cache
from pathlib import Path
-
-import jwt
-from cryptography.hazmat.primitives.asymmetric import ec
-from cryptography.hazmat.primitives import serialization
-from cryptography.hazmat.backends import default_backend
from urllib.parse import urlparse
-import os.path
+import aiohttp
+import jwt
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+from snek.system.service import BaseService
+
# The only reason to persist the keys is to be able to use them in the web push
PRIVATE_KEY_FILE = Path("./notification-private.pem")
@@ -268,4 +266,4 @@ class PushService(BaseService):
raise Exception(
f"Failed to register push subscription for user {user_uid} with endpoint {endpoint}"
- )
\ No newline at end of file
+ )
diff --git a/src/snek/service/repository.py b/src/snek/service/repository.py
index 120c232..def3257 100644
--- a/src/snek/service/repository.py
+++ b/src/snek/service/repository.py
@@ -1,13 +1,17 @@
-from snek.system.service import BaseService
-import asyncio
+import asyncio
import shutil
+from snek.system.service import BaseService
+
+
class RepositoryService(BaseService):
mapper_name = "repository"
async def delete(self, user_uid, name):
loop = asyncio.get_event_loop()
- repository_path = (await self.services.user.get_repository_path(user_uid)).joinpath(name)
+ repository_path = (
+ await self.services.user.get_repository_path(user_uid)
+ ).joinpath(name)
try:
await loop.run_in_executor(None, shutil.rmtree, repository_path)
except Exception as ex:
@@ -15,7 +19,6 @@ class RepositoryService(BaseService):
await super().delete(user_uid=user_uid, name=name)
-
async def exists(self, user_uid, name, **kwargs):
kwargs["user_uid"] = user_uid
kwargs["name"] = name
@@ -29,18 +32,16 @@ class RepositoryService(BaseService):
repository_path = str(repository_path)
if not repository_path.endswith(".git"):
repository_path += ".git"
- command = ['git', 'init', '--bare', repository_path]
+ command = ["git", "init", "--bare", repository_path]
process = await asyncio.subprocess.create_subprocess_exec(
- *command,
- stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE
+ *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
return process.returncode == 0
- async def create(self, user_uid, name,is_private=False):
+ async def create(self, user_uid, name, is_private=False):
if await self.exists(user_uid=user_uid, name=name):
- return False
+ return False
if not await self.init(user_uid=user_uid, name=name):
return False
diff --git a/src/snek/service/socket.py b/src/snek/service/socket.py
index d0ff024..ea002dc 100644
--- a/src/snek/service/socket.py
+++ b/src/snek/service/socket.py
@@ -1,12 +1,14 @@
+import asyncio
+import logging
+from datetime import datetime
+
from snek.model.user import UserModel
from snek.system.service import BaseService
-from datetime import datetime
-import json
-import asyncio
-import logging
+
logger = logging.getLogger(__name__)
from snek.system.model import now
+
class SocketService(BaseService):
class Socket:
@@ -15,7 +17,6 @@ class SocketService(BaseService):
self.is_connected = True
self.user = user
-
async def send_json(self, data):
if not self.is_connected:
return False
@@ -40,7 +41,6 @@ class SocketService(BaseService):
self.users = {}
self.subscriptions = {}
self.last_update = str(datetime.now())
-
async def user_availability_service(self):
logger.info("User availability update service started.")
@@ -50,14 +50,15 @@ class SocketService(BaseService):
for s in self.sockets:
if not s.is_connected:
continue
- if not s.user in users_updated:
+ if s.user not in users_updated:
s.user["last_ping"] = now()
await self.app.services.user.save(s.user)
users_updated.append(s.user)
- logger.info(f"Updated user availability for {len(users_updated)} online users.")
+ logger.info(
+ f"Updated user availability for {len(users_updated)} online users."
+ )
await asyncio.sleep(60)
-
async def add(self, ws, user_uid):
s = self.Socket(ws, await self.app.services.user.get(uid=user_uid))
self.sockets.add(s)
@@ -81,7 +82,6 @@ class SocketService(BaseService):
count += 1
return count
-
async def broadcast(self, channel_uid, message):
await self._broadcast(channel_uid, message)
@@ -102,4 +102,3 @@ class SocketService(BaseService):
await s.close()
logger.info(f"Removed socket for user {s.user['username']}")
self.sockets.remove(s)
-
diff --git a/src/snek/service/user.py b/src/snek/service/user.py
index 13f660f..34dc468 100644
--- a/src/snek/service/user.py
+++ b/src/snek/service/user.py
@@ -32,14 +32,14 @@ class UserService(BaseService):
user["color"] = await self.services.util.random_light_hex_color()
return await super().save(user)
- def authenticate_sync(self,username,password):
+ def authenticate_sync(self, username, password):
user = self.get_by_username_sync(username)
-
+
if not user:
return False
if not security.verify_sync(password, user["password"]):
return False
- return True
+ return True
async def authenticate(self, username, password):
success = await self.validate_login(username, password)
@@ -61,14 +61,12 @@ class UserService(BaseService):
return None
return path
-
-
async def get_template_path(self, user_uid):
path = pathlib.Path(f"./drive/{user_uid}/snek/templates")
if not path.exists():
return None
return path
-
+
def get_by_username_sync(self, username):
user = self.mapper.db["user"].find_one(username=username, deleted_at=None)
return dict(user)
diff --git a/src/snek/sgit.py b/src/snek/sgit.py
index 3cbdcf6..4c2b6ac 100644
--- a/src/snek/sgit.py
+++ b/src/snek/sgit.py
@@ -1,48 +1,51 @@
+import asyncio
+import base64
+import json
+import logging
import os
-import aiohttp
+import shutil
+import tempfile
+
from aiohttp import web
-import shutil
-import json
-import tempfile
-import asyncio
-import logging
-import base64
-import pathlib
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger('git_server')
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger("git_server")
+
class GitApplication(web.Application):
def __init__(self, parent=None):
- #import git
- #globals()['git'] = git
+ # import git
+ # globals()['git'] = git
self.parent = parent
- super().__init__(client_max_size=1024*1024*1024*5)
- self.add_routes([
- web.post('/create/{repo_name}', self.create_repository),
- web.delete('/delete/{repo_name}', self.delete_repository),
- web.get('/clone/{repo_name}', self.clone_repository),
- web.post('/push/{repo_name}', self.push_repository),
- web.post('/pull/{repo_name}', self.pull_repository),
- web.get('/status/{repo_name}', self.status_repository),
- #web.get('/list', self.list_repositories),
- web.get('/branches/{repo_name}', self.list_branches),
- web.post('/branches/{repo_name}', self.create_branch),
- web.get('/log/{repo_name}', self.commit_log),
- web.get('/file/{repo_name}/{file_path:.*}', self.file_content),
- web.get('/{path:.+}/info/refs', self.git_smart_http),
- web.post('/{path:.+}/git-upload-pack', self.git_smart_http),
- web.post('/{path:.+}/git-receive-pack', self.git_smart_http),
- web.get('/{repo_name}.git/info/refs', self.git_smart_http),
- web.post('/{repo_name}.git/git-upload-pack', self.git_smart_http),
- web.post('/{repo_name}.git/git-receive-pack', self.git_smart_http),
- ])
-
+ super().__init__(client_max_size=1024 * 1024 * 1024 * 5)
+ self.add_routes(
+ [
+ web.post("/create/{repo_name}", self.create_repository),
+ web.delete("/delete/{repo_name}", self.delete_repository),
+ web.get("/clone/{repo_name}", self.clone_repository),
+ web.post("/push/{repo_name}", self.push_repository),
+ web.post("/pull/{repo_name}", self.pull_repository),
+ web.get("/status/{repo_name}", self.status_repository),
+ # web.get('/list', self.list_repositories),
+ web.get("/branches/{repo_name}", self.list_branches),
+ web.post("/branches/{repo_name}", self.create_branch),
+ web.get("/log/{repo_name}", self.commit_log),
+ web.get("/file/{repo_name}/{file_path:.*}", self.file_content),
+ web.get("/{path:.+}/info/refs", self.git_smart_http),
+ web.post("/{path:.+}/git-upload-pack", self.git_smart_http),
+ web.post("/{path:.+}/git-receive-pack", self.git_smart_http),
+ web.get("/{repo_name}.git/info/refs", self.git_smart_http),
+ web.post("/{repo_name}.git/git-upload-pack", self.git_smart_http),
+ web.post("/{repo_name}.git/git-receive-pack", self.git_smart_http),
+ ]
+ )
async def check_basic_auth(self, request):
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Basic "):
- return None,None
+ return None, None
encoded_creds = auth_header.split("Basic ")[1]
decoded_creds = base64.b64decode(encoded_creds).decode()
username, password = decoded_creds.split(":", 1)
@@ -50,27 +53,31 @@ class GitApplication(web.Application):
username=username, password=password
)
if not request["user"]:
- return None,None
- request["repository_path"] = await self.parent.services.user.get_repository_path(
- request["user"]["uid"]
+ return None, None
+ request["repository_path"] = (
+ await self.parent.services.user.get_repository_path(request["user"]["uid"])
)
- return request["user"]['username'],request["repository_path"]
-
+ return request["user"]["username"], request["repository_path"]
@staticmethod
def require_auth(handler):
async def wrapped(self, request, *args, **kwargs):
username, repository_path = await self.check_basic_auth(request)
if not username or not repository_path:
- return web.Response(status=401, headers={'WWW-Authenticate': 'Basic'}, text='Authentication required')
- request['username'] = username
- request['repository_path'] = repository_path
+ return web.Response(
+ status=401,
+ headers={"WWW-Authenticate": "Basic"},
+ text="Authentication required",
+ )
+ request["username"] = username
+ request["repository_path"] = repository_path
return await handler(self, request, *args, **kwargs)
+
return wrapped
def repo_path(self, repository_path, repo_name):
- return repository_path.joinpath(repo_name + '.git')
+ return repository_path.joinpath(repo_name + ".git")
def check_repo_exists(self, repository_path, repo_name):
repo_dir = self.repo_path(repository_path, repo_name)
@@ -80,10 +87,10 @@ class GitApplication(web.Application):
@require_auth
async def create_repository(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- repository_path = request['repository_path']
- if not repo_name or '/' in repo_name or '..' in repo_name:
+ username = request["username"]
+ repo_name = request.match_info["repo_name"]
+ repository_path = request["repository_path"]
+ if not repo_name or "/" in repo_name or ".." in repo_name:
return web.Response(text="Invalid repository name", status=400)
repo_dir = self.repo_path(repository_path, repo_name)
if os.path.exists(repo_dir):
@@ -98,9 +105,9 @@ class GitApplication(web.Application):
@require_auth
async def delete_repository(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- repository_path = request['repository_path']
+ username = request["username"]
+ repo_name = request.match_info["repo_name"]
+ repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@@ -115,9 +122,9 @@ class GitApplication(web.Application):
@require_auth
async def clone_repository(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- repository_path = request['repository_path']
+ request["username"]
+ repo_name = request.match_info["repo_name"]
+ repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@@ -126,15 +133,15 @@ class GitApplication(web.Application):
response_data = {
"repository": repo_name,
"clone_command": f"git clone {clone_url}",
- "clone_url": clone_url
+ "clone_url": clone_url,
}
return web.json_response(response_data)
@require_auth
async def push_repository(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- repository_path = request['repository_path']
+ username = request["username"]
+ repo_name = request.match_info["repo_name"]
+ repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@@ -142,34 +149,40 @@ class GitApplication(web.Application):
data = await request.json()
except json.JSONDecodeError:
return web.Response(text="Invalid JSON data", status=400)
- commit_message = data.get('commit_message', 'Update from server')
- branch = data.get('branch', 'main')
- changes = data.get('changes', [])
+ commit_message = data.get("commit_message", "Update from server")
+ branch = data.get("branch", "main")
+ changes = data.get("changes", [])
if not changes:
return web.Response(text="No changes provided", status=400)
with tempfile.TemporaryDirectory() as temp_dir:
- temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
+ temp_repo = git.Repo.clone_from(
+ self.repo_path(repository_path, repo_name), temp_dir
+ )
for change in changes:
- file_path = os.path.join(temp_dir, change.get('file', ''))
- content = change.get('content', '')
+ file_path = os.path.join(temp_dir, change.get("file", ""))
+ content = change.get("content", "")
os.makedirs(os.path.dirname(file_path), exist_ok=True)
- with open(file_path, 'w') as f:
+ with open(file_path, "w") as f:
f.write(content)
temp_repo.git.add(A=True)
- if not temp_repo.config_reader().has_section('user'):
- temp_repo.config_writer().set_value("user", "name", "Git Server").release()
- temp_repo.config_writer().set_value("user", "email", "git@server.local").release()
+ if not temp_repo.config_reader().has_section("user"):
+ temp_repo.config_writer().set_value(
+ "user", "name", "Git Server"
+ ).release()
+ temp_repo.config_writer().set_value(
+ "user", "email", "git@server.local"
+ ).release()
temp_repo.index.commit(commit_message)
- origin = temp_repo.remote('origin')
+ origin = temp_repo.remote("origin")
origin.push(refspec=f"{branch}:{branch}")
logger.info(f"Pushed to repository: {repo_name} for user {username}")
return web.Response(text=f"Successfully pushed changes to {repo_name}")
@require_auth
async def pull_repository(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- repository_path = request['repository_path']
+ username = request["username"]
+ repo_name = request.match_info["repo_name"]
+ repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@@ -177,13 +190,15 @@ class GitApplication(web.Application):
data = await request.json()
except json.JSONDecodeError:
data = {}
- remote_url = data.get('remote_url')
- branch = data.get('branch', 'main')
+ remote_url = data.get("remote_url")
+ branch = data.get("branch", "main")
if not remote_url:
return web.Response(text="Remote URL is required", status=400)
with tempfile.TemporaryDirectory() as temp_dir:
try:
- local_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
+ local_repo = git.Repo.clone_from(
+ self.repo_path(repository_path, repo_name), temp_dir
+ )
remote_name = "pull_source"
try:
remote = local_repo.create_remote(remote_name, remote_url)
@@ -192,38 +207,46 @@ class GitApplication(web.Application):
remote.set_url(remote_url)
remote.fetch()
local_repo.git.merge(f"{remote_name}/{branch}")
- origin = local_repo.remote('origin')
+ origin = local_repo.remote("origin")
origin.push()
- logger.info(f"Pulled to repository {repo_name} from {remote_url} for user {username}")
- return web.Response(text=f"Successfully pulled changes from {remote_url} to {repo_name}")
+ logger.info(
+ f"Pulled to repository {repo_name} from {remote_url} for user {username}"
+ )
+ return web.Response(
+ text=f"Successfully pulled changes from {remote_url} to {repo_name}"
+ )
except Exception as e:
logger.error(f"Error pulling to {repo_name}: {str(e)}")
return web.Response(text=f"Error pulling changes: {str(e)}", status=500)
@require_auth
async def status_repository(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- repository_path = request['repository_path']
+ request["username"]
+ repo_name = request.match_info["repo_name"]
+ repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
with tempfile.TemporaryDirectory() as temp_dir:
try:
- temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
+ temp_repo = git.Repo.clone_from(
+ self.repo_path(repository_path, repo_name), temp_dir
+ )
branches = [b.name for b in temp_repo.branches]
active_branch = temp_repo.active_branch.name
commits = []
for commit in list(temp_repo.iter_commits(max_count=5)):
- commits.append({
- "id": commit.hexsha,
- "author": f"{commit.author.name} <{commit.author.email}>",
- "date": commit.committed_datetime.isoformat(),
- "message": commit.message
- })
+ commits.append(
+ {
+ "id": commit.hexsha,
+ "author": f"{commit.author.name} <{commit.author.email}>",
+ "date": commit.committed_datetime.isoformat(),
+ "message": commit.message,
+ }
+ )
files = []
for root, dirs, filenames in os.walk(temp_dir):
- if '.git' in root:
+ if ".git" in root:
continue
for filename in filenames:
full_path = os.path.join(root, filename)
@@ -234,50 +257,58 @@ class GitApplication(web.Application):
"branches": branches,
"active_branch": active_branch,
"recent_commits": commits,
- "files": files
+ "files": files,
}
return web.json_response(status_info)
except Exception as e:
logger.error(f"Error getting status for {repo_name}: {str(e)}")
- return web.Response(text=f"Error getting repository status: {str(e)}", status=500)
+ return web.Response(
+ text=f"Error getting repository status: {str(e)}", status=500
+ )
@require_auth
async def list_repositories(self, request):
- username = request['username']
+ request["username"]
try:
repos = []
user_dir = self.REPO_DIR
if os.path.exists(user_dir):
for item in os.listdir(user_dir):
item_path = os.path.join(user_dir, item)
- if os.path.isdir(item_path) and item.endswith('.git'):
+ if os.path.isdir(item_path) and item.endswith(".git"):
repos.append(item[:-4])
- if request.query.get('format') == 'json':
+ if request.query.get("format") == "json":
return web.json_response({"repositories": repos})
else:
- return web.Response(text="\n".join(repos) if repos else "No repositories found")
+ return web.Response(
+ text="\n".join(repos) if repos else "No repositories found"
+ )
except Exception as e:
logger.error(f"Error listing repositories: {str(e)}")
- return web.Response(text=f"Error listing repositories: {str(e)}", status=500)
+ return web.Response(
+ text=f"Error listing repositories: {str(e)}", status=500
+ )
@require_auth
async def list_branches(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- repository_path = request['repository_path']
+ request["username"]
+ repo_name = request.match_info["repo_name"]
+ repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
with tempfile.TemporaryDirectory() as temp_dir:
- temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
+ temp_repo = git.Repo.clone_from(
+ self.repo_path(repository_path, repo_name), temp_dir
+ )
branches = [b.name for b in temp_repo.branches]
return web.json_response({"branches": branches})
@require_auth
async def create_branch(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- repository_path = request['repository_path']
+ username = request["username"]
+ repo_name = request.match_info["repo_name"]
+ repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
@@ -285,145 +316,168 @@ class GitApplication(web.Application):
data = await request.json()
except json.JSONDecodeError:
return web.Response(text="Invalid JSON data", status=400)
- branch_name = data.get('branch_name')
- start_point = data.get('start_point', 'HEAD')
+ branch_name = data.get("branch_name")
+ start_point = data.get("start_point", "HEAD")
if not branch_name:
return web.Response(text="Branch name is required", status=400)
with tempfile.TemporaryDirectory() as temp_dir:
try:
- temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
+ temp_repo = git.Repo.clone_from(
+ self.repo_path(repository_path, repo_name), temp_dir
+ )
temp_repo.git.branch(branch_name, start_point)
- temp_repo.git.push('origin', branch_name)
- logger.info(f"Created branch {branch_name} in repository {repo_name} for user {username}")
+ temp_repo.git.push("origin", branch_name)
+ logger.info(
+ f"Created branch {branch_name} in repository {repo_name} for user {username}"
+ )
return web.Response(text=f"Created branch {branch_name}")
except Exception as e:
- logger.error(f"Error creating branch {branch_name} in {repo_name}: {str(e)}")
+ logger.error(
+ f"Error creating branch {branch_name} in {repo_name}: {str(e)}"
+ )
return web.Response(text=f"Error creating branch: {str(e)}", status=500)
@require_auth
async def commit_log(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- repository_path = request['repository_path']
+ request["username"]
+ repo_name = request.match_info["repo_name"]
+ repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
try:
- limit = int(request.query.get('limit', 10))
- branch = request.query.get('branch', 'main')
+ limit = int(request.query.get("limit", 10))
+ branch = request.query.get("branch", "main")
except ValueError:
return web.Response(text="Invalid limit parameter", status=400)
with tempfile.TemporaryDirectory() as temp_dir:
try:
- temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
+ temp_repo = git.Repo.clone_from(
+ self.repo_path(repository_path, repo_name), temp_dir
+ )
commits = []
try:
for commit in list(temp_repo.iter_commits(branch, max_count=limit)):
- commits.append({
- "id": commit.hexsha,
- "short_id": commit.hexsha[:7],
- "author": f"{commit.author.name} <{commit.author.email}>",
- "date": commit.committed_datetime.isoformat(),
- "message": commit.message.strip()
- })
+ commits.append(
+ {
+ "id": commit.hexsha,
+ "short_id": commit.hexsha[:7],
+ "author": f"{commit.author.name} <{commit.author.email}>",
+ "date": commit.committed_datetime.isoformat(),
+ "message": commit.message.strip(),
+ }
+ )
except git.GitCommandError as e:
if "unknown revision or path" in str(e):
commits = []
else:
raise
- return web.json_response({
- "repository": repo_name,
- "branch": branch,
- "commits": commits
- })
+ return web.json_response(
+ {"repository": repo_name, "branch": branch, "commits": commits}
+ )
except Exception as e:
logger.error(f"Error getting commit log for {repo_name}: {str(e)}")
- return web.Response(text=f"Error getting commit log: {str(e)}", status=500)
+ return web.Response(
+ text=f"Error getting commit log: {str(e)}", status=500
+ )
@require_auth
async def file_content(self, request):
- username = request['username']
- repo_name = request.match_info['repo_name']
- file_path = request.match_info.get('file_path', '')
- branch = request.query.get('branch', 'main')
- repository_path = request['repository_path']
+ request["username"]
+ repo_name = request.match_info["repo_name"]
+ file_path = request.match_info.get("file_path", "")
+ branch = request.query.get("branch", "main")
+ repository_path = request["repository_path"]
error_response = self.check_repo_exists(repository_path, repo_name)
if error_response:
return error_response
with tempfile.TemporaryDirectory() as temp_dir:
try:
- temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir)
+ temp_repo = git.Repo.clone_from(
+ self.repo_path(repository_path, repo_name), temp_dir
+ )
try:
temp_repo.git.checkout(branch)
except git.GitCommandError:
return web.Response(text=f"Branch '{branch}' not found", status=404)
file_full_path = os.path.join(temp_dir, file_path)
if not os.path.exists(file_full_path):
- return web.Response(text=f"File '{file_path}' not found", status=404)
+ return web.Response(
+ text=f"File '{file_path}' not found", status=404
+ )
if os.path.isdir(file_full_path):
files = os.listdir(file_full_path)
- return web.json_response({
- "repository": repo_name,
- "path": file_path,
- "type": "directory",
- "contents": files
- })
+ return web.json_response(
+ {
+ "repository": repo_name,
+ "path": file_path,
+ "type": "directory",
+ "contents": files,
+ }
+ )
else:
try:
- with open(file_full_path, 'r') as f:
+ with open(file_full_path) as f:
content = f.read()
return web.Response(text=content)
except UnicodeDecodeError:
- return web.Response(text=f"Cannot display binary file content for '{file_path}'", status=400)
+ return web.Response(
+ text=f"Cannot display binary file content for '{file_path}'",
+ status=400,
+ )
except Exception as e:
logger.error(f"Error getting file content from {repo_name}: {str(e)}")
- return web.Response(text=f"Error getting file content: {str(e)}", status=500)
+ return web.Response(
+ text=f"Error getting file content: {str(e)}", status=500
+ )
@require_auth
async def git_smart_http(self, request):
- username = request['username']
- repository_path = request['repository_path']
+ request["username"]
+ repository_path = request["repository_path"]
path = request.path
+
async def get_repository_path():
- req_path = path.lstrip('/')
- if req_path.endswith('/info/refs'):
- repo_name = req_path[:-len('/info/refs')]
- elif req_path.endswith('/git-upload-pack'):
- repo_name = req_path[:-len('/git-upload-pack')]
- elif req_path.endswith('/git-receive-pack'):
- repo_name = req_path[:-len('/git-receive-pack')]
+ req_path = path.lstrip("/")
+ if req_path.endswith("/info/refs"):
+ repo_name = req_path[: -len("/info/refs")]
+ elif req_path.endswith("/git-upload-pack"):
+ repo_name = req_path[: -len("/git-upload-pack")]
+ elif req_path.endswith("/git-receive-pack"):
+ repo_name = req_path[: -len("/git-receive-pack")]
else:
repo_name = req_path
- if repo_name.endswith('.git'):
+ if repo_name.endswith(".git"):
repo_name = repo_name[:-4]
repo_name = repo_name[4:]
repo_dir = repository_path.joinpath(repo_name + ".git")
logger.info(f"Resolved repo path: {repo_dir}")
- return repo_dir
+ return repo_dir
+
async def handle_info_refs(service):
repo_path = await get_repository_path()
-
+
logger.info(f"handle_info_refs: {repo_path}")
if not os.path.exists(repo_path):
return web.Response(text="Repository not found", status=404)
- cmd = [service, '--stateless-rpc', '--advertise-refs', str(repo_path)]
+ cmd = [service, "--stateless-rpc", "--advertise-refs", str(repo_path)]
try:
process = await asyncio.create_subprocess_exec(
- *cmd,
- stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE
+ *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
logger.error(f"Git command failed: {stderr.decode()}")
- return web.Response(text=f"Git error: {stderr.decode()}", status=500)
+ return web.Response(
+ text=f"Git error: {stderr.decode()}", status=500
+ )
response = web.StreamResponse(
status=200,
- reason='OK',
+ reason="OK",
headers={
- 'Content-Type': f'application/x-{service}-advertisement',
- 'Cache-Control': 'no-cache'
- }
+ "Content-Type": f"application/x-{service}-advertisement",
+ "Cache-Control": "no-cache",
+ },
)
await response.prepare(request)
packet = f"# service={service}\n"
@@ -435,48 +489,58 @@ class GitApplication(web.Application):
except Exception as e:
logger.error(f"Error handling info/refs: {str(e)}")
return web.Response(text=f"Server error: {str(e)}", status=500)
+
async def handle_service_rpc(service):
repo_path = await get_repository_path()
logger.info(f"handle_service_rpc: {repo_path}")
if not os.path.exists(repo_path):
return web.Response(text="Repository not found", status=404)
- if not request.headers.get('Content-Type') == f'application/x-{service}-request':
+ if (
+ not request.headers.get("Content-Type")
+ == f"application/x-{service}-request"
+ ):
return web.Response(text="Invalid Content-Type", status=403)
body = await request.read()
- cmd = [service, '--stateless-rpc', str(repo_path)]
+ cmd = [service, "--stateless-rpc", str(repo_path)]
try:
process = await asyncio.create_subprocess_exec(
*cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE
+ stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate(input=body)
if process.returncode != 0:
logger.error(f"Git command failed: {stderr.decode()}")
- return web.Response(text=f"Git error: {stderr.decode()}", status=500)
+ return web.Response(
+ text=f"Git error: {stderr.decode()}", status=500
+ )
return web.Response(
- body=stdout,
- content_type=f'application/x-{service}-result'
+ body=stdout, content_type=f"application/x-{service}-result"
)
except Exception as e:
logger.error(f"Error handling service RPC: {str(e)}")
return web.Response(text=f"Server error: {str(e)}", status=500)
- if request.method == 'GET' and path.endswith('/info/refs'):
- service = request.query.get('service')
- if service in ('git-upload-pack', 'git-receive-pack'):
+
+ if request.method == "GET" and path.endswith("/info/refs"):
+ service = request.query.get("service")
+ if service in ("git-upload-pack", "git-receive-pack"):
return await handle_info_refs(service)
else:
- return web.Response(text="Smart HTTP requires service parameter", status=400)
- elif request.method == 'POST' and '/git-upload-pack' in path:
- return await handle_service_rpc('git-upload-pack')
- elif request.method == 'POST' and '/git-receive-pack' in path:
- return await handle_service_rpc('git-receive-pack')
+ return web.Response(
+ text="Smart HTTP requires service parameter", status=400
+ )
+ elif request.method == "POST" and "/git-upload-pack" in path:
+ return await handle_service_rpc("git-upload-pack")
+ elif request.method == "POST" and "/git-receive-pack" in path:
+ return await handle_service_rpc("git-receive-pack")
return web.Response(text="Not found", status=404)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
try:
import uvloop
+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger.info("Using uvloop for improved performance")
except ImportError:
diff --git a/src/snek/snode.py b/src/snek/snode.py
index ae42a1f..517d52a 100644
--- a/src/snek/snode.py
+++ b/src/snek/snode.py
@@ -2,28 +2,28 @@ import aiohttp
ENABLED = False
-import aiohttp
import asyncio
-from aiohttp import web
-
+import json
import sqlite3
-import dataset
+import aiohttp
+from aiohttp import web
from sqlalchemy import event
from sqlalchemy.engine import Engine
-import json
-
queue = asyncio.Queue()
+
class State:
- do_not_sync = False
+ do_not_sync = False
+
async def sync_service(app):
if not ENABLED:
- return
+ return
session = aiohttp.ClientSession()
- async with session.ws_connect('http://localhost:3131/ws') as ws:
+ async with session.ws_connect("http://localhost:3131/ws") as ws:
+
async def receive():
queries_synced = 0
@@ -31,7 +31,7 @@ async def sync_service(app):
if msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
- State.do_not_sync = True
+ State.do_not_sync = True
app.db.execute(*data)
app.db.commit()
State.do_not_sync = False
@@ -41,21 +41,25 @@ async def sync_service(app):
await app.services.socket.broadcast_event()
except Exception as e:
print(e)
- pass
- #print(f"Received: {msg.data}")
+ pass
+ # print(f"Received: {msg.data}")
elif msg.type == aiohttp.WSMsgType.ERROR:
break
+
async def write():
while True:
msg = await queue.get()
- await ws.send_str(json.dumps(msg,default=str))
+ await ws.send_str(json.dumps(msg, default=str))
queue.task_done()
await asyncio.gather(receive(), write())
await session.close()
+
queries_queued = 0
+
+
# Attach a listener to log all executed statements
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
@@ -63,7 +67,7 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
return
global queries_queued
if State.do_not_sync:
- print(statement,parameters)
+ print(statement, parameters)
return
if statement.startswith("SELECT"):
return
@@ -71,17 +75,18 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
queries_queued += 1
print("Queries queued: " + str(queries_queued))
+
async def websocket_handler(request):
- queries_broadcasted = 0
+ queries_broadcasted = 0
ws = web.WebSocketResponse()
await ws.prepare(request)
- request.app['websockets'].append(ws)
+ request.app["websockets"].append(ws)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
- for client in request.app['websockets']:
+ for client in request.app["websockets"]:
if client != ws:
await client.send_str(msg.data)
- cursor = request.app['db'].cursor()
+ cursor = request.app["db"].cursor()
data = json.loads(msg.data)
queries_broadcasted += 1
@@ -89,28 +94,32 @@ async def websocket_handler(request):
cursor.close()
print("Queries broadcasted: " + str(queries_broadcasted))
elif msg.type == aiohttp.WSMsgType.ERROR:
- print(f'WebSocket connection closed with exception {ws.exception()}')
+ print(f"WebSocket connection closed with exception {ws.exception()}")
- request.app['websockets'].remove(ws)
+ request.app["websockets"].remove(ws)
return ws
-app = web.Application()
-app['websockets'] = []
-app.router.add_get('/ws', websocket_handler)
+app = web.Application()
+app["websockets"] = []
+
+app.router.add_get("/ws", websocket_handler)
+
async def on_startup(app):
- app['db'] = sqlite3.connect('snek.db')
+ app["db"] = sqlite3.connect("snek.db")
print("Server starting...")
+
async def on_cleanup(app):
- for ws in app['websockets']:
+ for ws in app["websockets"]:
await ws.close()
- app['db'].close()
+ app["db"].close()
+
app.on_startup.append(on_startup)
app.on_cleanup.append(on_cleanup)
-if __name__ == '__main__':
- web.run_app(app, host='127.0.0.1', port=3131)
+if __name__ == "__main__":
+ web.run_app(app, host="127.0.0.1", port=3131)
diff --git a/src/snek/sssh.py b/src/snek/sssh.py
index 6b179b6..bb83fa9 100644
--- a/src/snek/sssh.py
+++ b/src/snek/sssh.py
@@ -1,25 +1,30 @@
-import asyncssh
import logging
from pathlib import Path
-global _app
+import asyncssh
+
+global _app
+
def set_app(app):
global _app
- _app = app
+ _app = app
+
def get_app():
return _app
+
logger = logging.getLogger(__name__)
roots = {}
+
class SFTPServer(asyncssh.SFTPServer):
-
+
def __init__(self, chan: asyncssh.SSHServerChannel):
self.root = get_app().services.user.get_home_folder_by_username(
- chan.get_extra_info('username')
+ chan.get_extra_info("username")
)
self.root.mkdir(exist_ok=True)
self.root = str(self.root)
@@ -30,20 +35,22 @@ class SFTPServer(asyncssh.SFTPServer):
logger.debug(f"Mapping client path {path} to {mapped_path}")
return str(mapped_path).encode()
+
class SSHServer(asyncssh.SSHServer):
def password_auth_supported(self):
return True
def validate_password(self, username, password):
logger.debug(f"Validating credentials for user {username}")
- result = get_app().services.user.authenticate_sync(username,password)
+ result = get_app().services.user.authenticate_sync(username, password)
logger.info(f"Validating credentials for user {username}: {result}")
return result
-async def start_ssh_server(app,host,port):
- set_app(app)
+
+async def start_ssh_server(app, host, port):
+ set_app(app)
logger.info("Starting SFTP server setup")
-
+
host_key_path = Path("drive") / ".ssh" / "sftp_server_key"
host_key_path.parent.mkdir(exist_ok=True, parents=True)
try:
@@ -63,14 +70,12 @@ async def start_ssh_server(app,host,port):
x = await asyncssh.listen(
host=host,
port=port,
- #process_factory=handle_client,
+ # process_factory=handle_client,
server_host_keys=[key],
server_factory=SSHServer,
- sftp_factory=SFTPServer
+ sftp_factory=SFTPServer,
)
return x
- except Exception as e:
+ except Exception:
logger.warning(f"Failed to start SFTP server. Already running.")
- pass
-
-
+ pass
diff --git a/src/snek/system/docker.py b/src/snek/system/docker.py
index d651950..add91e4 100644
--- a/src/snek/system/docker.py
+++ b/src/snek/system/docker.py
@@ -1,58 +1,71 @@
-import yaml
import copy
+import yaml
+
+
class ComposeFileManager:
- def __init__(self, compose_path='docker-compose.yml'):
+ def __init__(self, compose_path="docker-compose.yml"):
self.compose_path = compose_path
self._load()
def _load(self):
try:
- with open(self.compose_path, 'r') as f:
+ with open(self.compose_path) as f:
self.compose = yaml.safe_load(f) or {}
except FileNotFoundError:
- self.compose = {'version': '3', 'services': {}}
+ self.compose = {"version": "3", "services": {}}
def _save(self):
- with open(self.compose_path, 'w') as f:
+ with open(self.compose_path, "w") as f:
yaml.dump(self.compose, f, default_flow_style=False)
def list_instances(self):
- return list(self.compose.get('services', {}).keys())
+ return list(self.compose.get("services", {}).keys())
- def create_instance(self, name, image, command=None, cpus=None, memory=None, ports=None, volumes=None):
+ def create_instance(
+ self,
+ name,
+ image,
+ command=None,
+ cpus=None,
+ memory=None,
+ ports=None,
+ volumes=None,
+ ):
service = {
- 'image': image,
+ "image": image,
}
if command:
- service['command'] = command
+ service["command"] = command
if cpus or memory:
- service['deploy'] = {'resources': {'limits': {}}}
+ service["deploy"] = {"resources": {"limits": {}}}
if cpus:
- service['deploy']['resources']['limits']['cpus'] = str(cpus)
+ service["deploy"]["resources"]["limits"]["cpus"] = str(cpus)
if memory:
- service['deploy']['resources']['limits']['memory'] = str(memory)
+ service["deploy"]["resources"]["limits"]["memory"] = str(memory)
if ports:
- service['ports'] = [f"{host}:{container}" for container, host in ports.items()]
+ service["ports"] = [
+ f"{host}:{container}" for container, host in ports.items()
+ ]
if volumes:
- service['volumes'] = volumes
+ service["volumes"] = volumes
- self.compose.setdefault('services', {})[name] = service
+ self.compose.setdefault("services", {})[name] = service
self._save()
def remove_instance(self, name):
- if name in self.compose.get('services', {}):
- del self.compose['services'][name]
+ if name in self.compose.get("services", {}):
+ del self.compose["services"][name]
self._save()
def get_instance(self, name):
- return self.compose.get('services', {}).get(name)
+ return self.compose.get("services", {}).get(name)
def duplicate_instance(self, name, new_name):
orig = self.get_instance(name)
if not orig:
raise ValueError(f"No such instance: {name}")
- self.compose['services'][new_name] = copy.deepcopy(orig)
+ self.compose["services"][new_name] = copy.deepcopy(orig)
self._save()
def update_instance(self, name, **kwargs):
@@ -62,11 +75,12 @@ class ComposeFileManager:
for k, v in kwargs.items():
if v is not None:
service[k] = v
- self.compose['services'][name] = service
+ self.compose["services"][name] = service
self._save()
# Storage size is not tracked in compose files; would need Docker API for that.
+
# Example usage:
# mgr = ComposeFileManager()
# mgr.create_instance('web', 'nginx:latest', cpus=1, memory='512m', ports={80:8080}, volumes=['./data:/data'])
diff --git a/src/snek/system/mapper.py b/src/snek/system/mapper.py
index a0e37d2..fef7784 100644
--- a/src/snek/system/mapper.py
+++ b/src/snek/system/mapper.py
@@ -1,6 +1,7 @@
DEFAULT_LIMIT = 30
+import asyncio
import typing
-import asyncio
+
from snek.system.model import BaseModel
@@ -12,21 +13,21 @@ class BaseMapper:
def __init__(self, app):
self.app = app
- self.semaphore = asyncio.Semaphore(1)
+ self.semaphore = asyncio.Semaphore(1)
self.default_limit = self.__class__.default_limit
@property
def db(self):
return self.app.db
- @property
+ @property
def loop(self):
return asyncio.get_event_loop()
async def run_in_executor(self, func, *args, **kwargs):
async with self.semaphore:
return func(*args, **kwargs)
- #return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs))
+ # return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs))
async def new(self):
return self.model_class(mapper=self, app=self.app)
@@ -38,8 +39,8 @@ class BaseMapper:
async def get(self, uid: str = None, **kwargs) -> BaseModel:
if uid:
kwargs["uid"] = uid
-
- record = await self.run_in_executor(self.table.find_one,**kwargs)
+
+ record = await self.run_in_executor(self.table.find_one, **kwargs)
if not record:
return None
record = dict(record)
@@ -50,7 +51,7 @@ class BaseMapper:
return await self.model_class.from_record(mapper=self, record=record)
async def exists(self, **kwargs):
- return await self.run_in_executor(self.table.exists,**kwargs)
+ return await self.run_in_executor(self.table.exists, **kwargs)
async def count(self, **kwargs) -> int:
return await self.run_in_executor(self.table.count, **kwargs)
@@ -71,7 +72,7 @@ class BaseMapper:
yield model
async def query(self, sql, *args):
- for record in await self.run_in_executor(self.db.query,sql, *args):
+ for record in await self.run_in_executor(self.db.query, sql, *args):
yield dict(record)
async def update(self, model):
diff --git a/src/snek/system/markdown.py b/src/snek/system/markdown.py
index a823b5f..e83ff38 100644
--- a/src/snek/system/markdown.py
+++ b/src/snek/system/markdown.py
@@ -1,5 +1,5 @@
# Original source: https://brandonjay.dev/posts/2021/render-markdown-html-in-python-with-jinja2
-import re
+import re
from types import SimpleNamespace
from app.cache import time_cache_async
@@ -13,19 +13,20 @@ from pygments.lexers import get_lexer_by_name
def strip_markdown(md_text):
- # Remove code blocks (
- md_text = re.sub(r'[\s\S]?```', '', md_text)
- md_text = re.sub(r'^\s{4,}.$', '', md_text, flags=re.MULTILINE)
- md_text = re.sub(r'^\s{0,3}#{1,6}\s+', '', md_text, flags=re.MULTILINE)
- md_text = re.sub(r'!\[.?\]\(.?\)', '', md_text)
- md_text = re.sub(r'\[([^\]]+)\]\(.?\)', r'\1', md_text)
- md_text = re.sub(r'(\*|_){1,3}(.+?)\1{1,3}', r'\2', md_text)
- md_text = re.sub(r'^\s{0,3}>+\s?', '', md_text, flags=re.MULTILINE)
- md_text = re.sub(r'^(\s)(\-{3,}|_{3,}|\{3,})\s$', '', md_text, flags=re.MULTILINE)
- md_text = re.sub(r'[`~>#+\-=]', '', md_text)
- md_text = re.sub(r'\s+', ' ', md_text)
+ # Remove code blocks (
+ md_text = re.sub(r"[\s\S]?```", "", md_text)
+ md_text = re.sub(r"^\s{4,}.$", "", md_text, flags=re.MULTILINE)
+ md_text = re.sub(r"^\s{0,3}#{1,6}\s+", "", md_text, flags=re.MULTILINE)
+ md_text = re.sub(r"!\[.?\]\(.?\)", "", md_text)
+ md_text = re.sub(r"\[([^\]]+)\]\(.?\)", r"\1", md_text)
+ md_text = re.sub(r"(\*|_){1,3}(.+?)\1{1,3}", r"\2", md_text)
+ md_text = re.sub(r"^\s{0,3}>+\s?", "", md_text, flags=re.MULTILINE)
+ md_text = re.sub(r"^(\s)(\-{3,}|_{3,}|\{3,})\s$", "", md_text, flags=re.MULTILINE)
+ md_text = re.sub(r"[`~>#+\-=]", "", md_text)
+ md_text = re.sub(r"\s+", " ", md_text)
return md_text.strip()
+
class MarkdownRenderer(HTMLRenderer):
_allow_harmful_protocols = False
@@ -40,7 +41,7 @@ class MarkdownRenderer(HTMLRenderer):
formatter = html.HtmlFormatter()
self.env.globals["highlight_styles"] = formatter.get_style_defs()
- #def _escape(self, str):
+ # def _escape(self, str):
# return str ##escape(str)
def get_lexer(self, lang, default="bash"):
diff --git a/src/snek/system/middleware.py b/src/snek/system/middleware.py
index 368e3ae..93e5eaf 100644
--- a/src/snek/system/middleware.py
+++ b/src/snek/system/middleware.py
@@ -6,9 +6,9 @@
# MIT License: This code is distributed under the MIT License.
-from aiohttp import web
import secrets
+from aiohttp import web
csp_policy = (
"default-src 'self'; "
@@ -22,14 +22,16 @@ csp_policy = (
def generate_nonce():
return secrets.token_hex(16)
+
@web.middleware
async def csp_middleware(request, handler):
response = await handler(request)
- return response
- nonce = generate_nonce()
- response.headers['Content-Security-Policy'] = csp_policy.format(nonce=nonce)
return response
+ nonce = generate_nonce()
+ response.headers["Content-Security-Policy"] = csp_policy.format(nonce=nonce)
+ return response
+
@web.middleware
async def no_cors_middleware(request, handler):
diff --git a/src/snek/system/security.py b/src/snek/system/security.py
index 4ead284..a5b9302 100644
--- a/src/snek/system/security.py
+++ b/src/snek/system/security.py
@@ -63,9 +63,11 @@ def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str:
obj = hashlib.sha256(salted)
return obj.hexdigest()
+
async def hash(data: str, salt: str = DEFAULT_SALT) -> str:
return hash_sync(data, salt)
+
def verify_sync(string: str, hashed: str) -> bool:
"""Verify if the given string matches the hashed value.
@@ -78,5 +80,6 @@ def verify_sync(string: str, hashed: str) -> bool:
"""
return hash_sync(string) == hashed
+
async def verify(string: str, hashed: str) -> bool:
return verify_sync(string, hashed)
diff --git a/src/snek/system/template.py b/src/snek/system/template.py
index 3dd4357..81306e9 100644
--- a/src/snek/system/template.py
+++ b/src/snek/system/template.py
@@ -1,17 +1,16 @@
import mimetypes
import re
-from functools import lru_cache
from types import SimpleNamespace
-from urllib.parse import urlparse, parse_qs
-from app.cache import time_cache
+from urllib.parse import parse_qs, urlparse
+
+import bleach
import emoji
import requests
+from app.cache import time_cache
from bs4 import BeautifulSoup
from jinja2 import TemplateSyntaxError, nodes
from jinja2.ext import Extension
from jinja2.nodes import Const
-import bleach
-
emoji.EMOJI_DATA[''] = {
"en": ":snek1:",
@@ -81,13 +80,29 @@ emoji.EMOJI_DATA[
ALLOWED_TAGS = list(bleach.sanitizer.ALLOWED_TAGS) + [
- "img", "video", "audio", "source", "iframe", "picture", "span"
+ "img",
+ "video",
+ "audio",
+ "source",
+ "iframe",
+ "picture",
+ "span",
]
ALLOWED_ATTRIBUTES = {
**bleach.sanitizer.ALLOWED_ATTRIBUTES,
"img": ["src", "alt", "title", "width", "height"],
"a": ["href", "title", "target", "rel", "referrerpolicy", "class"],
- "iframe": ["src", "width", "height", "frameborder", "allow", "allowfullscreen", "title", "referrerpolicy", "style"],
+ "iframe": [
+ "src",
+ "width",
+ "height",
+ "frameborder",
+ "allow",
+ "allowfullscreen",
+ "title",
+ "referrerpolicy",
+ "style",
+ ],
"video": ["src", "controls", "width", "height"],
"audio": ["src", "controls"],
"source": ["src", "type"],
@@ -96,9 +111,6 @@ ALLOWED_ATTRIBUTES = {
}
-
-
-
def sanitize_html(value):
return bleach.clean(
value,
@@ -109,7 +121,6 @@ def sanitize_html(value):
)
-
def set_link_target_blank(text):
soup = BeautifulSoup(text, "html.parser")
@@ -118,26 +129,43 @@ def set_link_target_blank(text):
element.attrs["rel"] = "noopener noreferrer"
element.attrs["referrerpolicy"] = "no-referrer"
element.attrs["href"] = element.attrs["href"].strip(".").strip(",")
-
+
return str(soup)
+
SAFE_ATTRIBUTES = {
- 'href', 'src', 'alt', 'title', 'width', 'height', 'style', 'id', 'class',
- 'rel', 'type', 'name', 'value', 'placeholder', 'aria-hidden', 'aria-label', 'srcset'
+ "href",
+ "src",
+ "alt",
+ "title",
+ "width",
+ "height",
+ "style",
+ "id",
+ "class",
+ "rel",
+ "type",
+ "name",
+ "value",
+ "placeholder",
+ "aria-hidden",
+ "aria-label",
+ "srcset",
}
+
def whitelist_attributes(html):
- soup = BeautifulSoup(html, 'html.parser')
+ soup = BeautifulSoup(html, "html.parser")
for tag in soup.find_all():
- if hasattr(tag, 'attrs'):
- if tag.name in ['script','form','input']:
- tag.replace_with('')
+ if hasattr(tag, "attrs"):
+ if tag.name in ["script", "form", "input"]:
+ tag.replace_with("")
continue
attrs = dict(tag.attrs)
for attr in list(attrs):
# Check if attribute is in the safe list or is a data-* attribute
- if not (attr in SAFE_ATTRIBUTES or attr.startswith('data-')):
+ if not (attr in SAFE_ATTRIBUTES or attr.startswith("data-")):
del tag.attrs[attr]
return str(soup)
@@ -236,12 +264,12 @@ def enrich_image_rendering(text):
for element in soup.find_all("img"):
if element.attrs["src"].startswith("/"):
element.attrs["src"] += "?width=240&height=240"
- picture_template = f'''
+ picture_template = f"""
-
{page_description or "No description available."}
", "html.parser") + BeautifulSoup( + f"{page_description or "No description available."}
", + "html.parser", + ) ) description_element.append( - BeautifulSoup(f"{original_link_name}
", "html.parser") + BeautifulSoup( + f"{original_link_name}
", + "html.parser", + ) ) render_element.append(description_element_base) - for attachment in attachments.values(): soup.append(attachment) diff --git a/src/snek/system/websocket.py b/src/snek/system/websocket.py index c54b094..14d40fd 100644 --- a/src/snek/system/websocket.py +++ b/src/snek/system/websocket.py @@ -1,22 +1,24 @@ class WebSocketClient: def __init__(self, hostname, port): - self.buffer = b'' + self.buffer = b"" self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.hostname = hostname + self.hostname = hostname self.port = port - self.connect() - + self.connect() + def __getattr__(self, method, *args, **kwargs): if method in self.__dict__.keys(): return self.__dict__[method] + def call(*args, **kwargs): - self.write(json.dumps({'method': method, 'args': args, 'kwargs': kwargs})) + self.write(json.dumps({"method": method, "args": args, "kwargs": kwargs})) return json.loads(self.read()) + return call def connect(self): self.socket.connect((self.hostname, self.port)) - key = base64.b64encode(b'1234123412341234').decode('utf-8') + key = base64.b64encode(b"1234123412341234").decode("utf-8") handshake = ( f"GET /db HTTP/1.1\r\n" f"Host: localhost:3131\r\n" @@ -25,57 +27,58 @@ class WebSocketClient: f"Sec-WebSocket-Key: {key}\r\n" f"Sec-WebSocket-Version: 13\r\n\r\n" ) - self.socket.sendall(handshake.encode('utf-8')) - response = self.read_until(b'\r\n\r\n') - if b'101 Switching Protocols' not in response: + self.socket.sendall(handshake.encode("utf-8")) + response = self.read_until(b"\r\n\r\n") + if b"101 Switching Protocols" not in response: raise Exception("Failed to connect to WebSocket") def write(self, message): - message_bytes = message.encode('utf-8') + message_bytes = message.encode("utf-8") length = len(message_bytes) if length <= 125: - self.socket.sendall(b'\x81' + bytes([length]) + message_bytes) + self.socket.sendall(b"\x81" + bytes([length]) + message_bytes) elif length >= 126 and length <= 65535: - self.socket.sendall(b'\x81' + bytes([126]) + length.to_bytes(2, 'big') + message_bytes) + self.socket.sendall( + b"\x81" + bytes([126]) + length.to_bytes(2, "big") + message_bytes + ) else: - self.socket.sendall(b'\x81' + bytes([127]) + length.to_bytes(8, 'big') + message_bytes) - + self.socket.sendall( + b"\x81" + bytes([127]) + length.to_bytes(8, "big") + message_bytes + ) - def read_until(self, delimiter): + def read_until(self, delimiter): while True: find_pos = self.buffer.find(delimiter) if find_pos != -1: - data = self.buffer[:find_pos+4] - self.buffer = self.buffer[find_pos+4:] - return data - + data = self.buffer[: find_pos + 4] + self.buffer = self.buffer[find_pos + 4 :] + return data + chunk = self.socket.recv(1024) if not chunk: return None self.buffer += chunk - + def read_exactly(self, length): while len(self.buffer) < length: chunk = self.socket.recv(length - len(self.buffer)) if not chunk: return None - self.buffer += chunk - response = self.buffer[: length] + self.buffer += chunk + response = self.buffer[:length] self.buffer = self.buffer[length:] return response def read(self): - frame = None + frame = None frame = self.read_exactly(2) length = frame[1] & 127 if length == 126: - length = int.from_bytes(self.read_exactly(2), 'big') + length = int.from_bytes(self.read_exactly(2), "big") elif length == 127: - length = int.from_bytes(self.read_exactly(8), 'big') + length = int.from_bytes(self.read_exactly(8), "big") message = self.read_exactly(length) return message - + def close(self): self.socket.close() - - diff --git a/src/snek/view/avatar.py b/src/snek/view/avatar.py index d129f5e..c984ec1 100644 --- a/src/snek/view/avatar.py +++ b/src/snek/view/avatar.py @@ -31,21 +31,17 @@ from multiavatar import multiavatar from snek.system.view import BaseView from snek.view.avatar_animal import generate_avatar_with_options -import functools class AvatarView(BaseView): login_required = False - - def __init__(self, *args,**kwargs): - super().__init__(*args,**kwargs) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.avatars = {} async def get(self): type_ = self.request.query.get("type") - type_match = { - "animal": self.get_animal, - "default": self.get_default - } + type_match = {"animal": self.get_animal, "default": self.get_default} handler = type_match.get(type_, self.get_default) return await handler() @@ -64,10 +60,9 @@ class AvatarView(BaseView): avatar = generate_avatar_with_options(self.request.query) await self.app.set(key, avatar) return web.Response(text=avatar, content_type="image/svg+xml") - except Exception as e: + except Exception: pass - def _get(self, uid): avatar = generate_avatar_with_options(self.request.query) return avatar diff --git a/src/snek/view/avatar_animal.py b/src/snek/view/avatar_animal.py index 3a7f653..42c47d3 100644 --- a/src/snek/view/avatar_animal.py +++ b/src/snek/view/avatar_animal.py @@ -1,23 +1,61 @@ -import random -import math import argparse import json -from typing import Dict, List, Tuple, Optional, Union +import math +import random +from typing import Dict, List, Optional + class AnimalAvatarGenerator: """A generator for animal-themed avatar SVGs.""" - + # Constants ANIMALS = [ - "cat", "dog", "fox", "wolf", "bear", "panda", "koala", "tiger", "lion", - "rabbit", "monkey", "elephant", "giraffe", "zebra", "penguin", "owl", - "deer", "raccoon", "squirrel", "hedgehog", "otter", "frog" + "cat", + "dog", + "fox", + "wolf", + "bear", + "panda", + "koala", + "tiger", + "lion", + "rabbit", + "monkey", + "elephant", + "giraffe", + "zebra", + "penguin", + "owl", + "deer", + "raccoon", + "squirrel", + "hedgehog", + "otter", + "frog", ] - + COLOR_PALETTES = { "natural": { - "cat": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000", "#808080", "#FFFFFF", "#FFA500"], - "dog": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000", "#808080", "#FFFFFF", "#FFA500"], + "cat": [ + "#8B4513", + "#A0522D", + "#D2B48C", + "#F5DEB3", + "#000000", + "#808080", + "#FFFFFF", + "#FFA500", + ], + "dog": [ + "#8B4513", + "#A0522D", + "#D2B48C", + "#F5DEB3", + "#000000", + "#808080", + "#FFFFFF", + "#FFA500", + ], "fox": ["#FF6600", "#FF7F00", "#FF8C00", "#FFA500", "#FFFFFF", "#000000"], "wolf": ["#808080", "#A9A9A9", "#D3D3D3", "#FFFFFF", "#000000", "#696969"], "bear": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000", "#FFFFFF"], @@ -25,35 +63,68 @@ class AnimalAvatarGenerator: "koala": ["#808080", "#A9A9A9", "#D3D3D3", "#FFFFFF", "#000000"], "tiger": ["#FF8C00", "#FF7F00", "#FFFFFF", "#000000"], "lion": ["#DAA520", "#B8860B", "#CD853F", "#D2B48C", "#FFFFFF", "#000000"], - "rabbit": ["#FFFFFF", "#F5F5F5", "#D3D3D3", "#A9A9A9", "#FFC0CB", "#000000"], + "rabbit": [ + "#FFFFFF", + "#F5F5F5", + "#D3D3D3", + "#A9A9A9", + "#FFC0CB", + "#000000", + ], "monkey": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000"], "elephant": ["#808080", "#A9A9A9", "#D3D3D3", "#FFFFFF", "#000000"], "giraffe": ["#DAA520", "#B8860B", "#F5DEB3", "#FFFFFF", "#000000"], "zebra": ["#000000", "#FFFFFF"], "penguin": ["#000000", "#FFFFFF", "#FFA500"], - "owl": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000", "#FFFFFF", "#FFC0CB"], + "owl": [ + "#8B4513", + "#A0522D", + "#D2B48C", + "#F5DEB3", + "#000000", + "#FFFFFF", + "#FFC0CB", + ], "deer": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#FFFFFF", "#000000"], "raccoon": ["#808080", "#A9A9A9", "#D3D3D3", "#FFFFFF", "#000000"], "squirrel": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000"], "hedgehog": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000"], "otter": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000"], - "frog": ["#008000", "#00FF00", "#ADFF2F", "#7FFF00", "#000000", "#FFFFFF"] + "frog": ["#008000", "#00FF00", "#ADFF2F", "#7FFF00", "#000000", "#FFFFFF"], }, "pastel": { - "all": ["#FFB6C1", "#FFD700", "#FFDAB9", "#98FB98", "#ADD8E6", "#DDA0DD", "#F0E68C", "#FFFFE0"] + "all": [ + "#FFB6C1", + "#FFD700", + "#FFDAB9", + "#98FB98", + "#ADD8E6", + "#DDA0DD", + "#F0E68C", + "#FFFFE0", + ] }, "vibrant": { - "all": ["#FF0000", "#00FF00", "#0000FF", "#FFFF00", "#FF00FF", "#00FFFF", "#FFA500", "#FF1493"] + "all": [ + "#FF0000", + "#00FF00", + "#0000FF", + "#FFFF00", + "#FF00FF", + "#00FFFF", + "#FFA500", + "#FF1493", + ] }, "mono": { "all": ["#000000", "#333333", "#666666", "#999999", "#CCCCCC", "#FFFFFF"] - } + }, } - + EYE_STYLES = ["round", "oval", "almond", "wide", "narrow", "cute"] - + FACE_SHAPES = ["round", "oval", "square", "heart", "triangular", "diamond"] - + EAR_STYLES = { "cat": ["pointed", "folded", "round", "large", "small"], "dog": ["floppy", "pointed", "round", "large", "small"], @@ -76,11 +147,11 @@ class AnimalAvatarGenerator: "squirrel": ["pointed", "small"], "hedgehog": ["round", "small"], "otter": ["round", "small"], - "frog": ["none"] + "frog": ["none"], } - + NOSE_STYLES = ["round", "triangular", "small", "large", "heart", "button"] - + SPECIAL_FEATURES = { "cat": ["whiskers", "stripes", "spots"], "dog": ["spots", "patch", "whiskers"], @@ -103,16 +174,16 @@ class AnimalAvatarGenerator: "squirrel": ["bushy_tail", "none"], "hedgehog": ["spikes", "none"], "otter": ["whiskers", "none"], - "frog": ["spots", "none"] + "frog": ["spots", "none"], } - + EXPRESSIONS = ["happy", "serious", "surprised", "sleepy", "wink"] - + def __init__(self, seed: Optional[int] = None): """Initialize the avatar generator with an optional seed for reproducibility.""" if seed is not None: random.seed(seed) - + def _get_colors(self, animal: str, color_palette: str) -> List[str]: """Get colors for the given animal and palette.""" if color_palette in self.COLOR_PALETTES: @@ -120,382 +191,790 @@ class AnimalAvatarGenerator: return self.COLOR_PALETTES[color_palette][animal] elif "all" in self.COLOR_PALETTES[color_palette]: return self.COLOR_PALETTES[color_palette]["all"] - + # Default to natural palette for the animal or general natural colors if animal in self.COLOR_PALETTES["natural"]: return self.COLOR_PALETTES["natural"][animal] - + # If no specific colors found, use a mix of browns and grays - return ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#808080", "#A9A9A9", "#FFFFFF"] - + return [ + "#8B4513", + "#A0522D", + "#D2B48C", + "#F5DEB3", + "#808080", + "#A9A9A9", + "#FFFFFF", + ] + def _get_ear_style(self, animal: str) -> str: """Get a random ear style appropriate for the animal.""" if animal in self.EAR_STYLES: return random.choice(self.EAR_STYLES[animal]) return "none" # Default for animals not in the list - + def _get_special_feature(self, animal: str) -> str: """Get a random special feature appropriate for the animal.""" if animal in self.SPECIAL_FEATURES: return random.choice(self.SPECIAL_FEATURES[animal]) return "none" # Default for animals not in the list - - def _draw_circle(self, cx: float, cy: float, r: float, fill: str, - stroke: str = "none", stroke_width: float = 1.0) -> str: + + def _draw_circle( + self, + cx: float, + cy: float, + r: float, + fill: str, + stroke: str = "none", + stroke_width: float = 1.0, + ) -> str: """Generate SVG for a circle.""" return f'{content}" except Exception as e: return f"
{e}" def get_icon(self, file): - if file.is_dir(): return "📁" - mime = mimetypes.guess_type(file.name)[0] or '' - if mime.startswith("image"): return "🖼️" - if mime.startswith("text"): return "📄" - if mime.startswith("audio"): return "🎵" - if mime.startswith("video"): return "🎬" - if file.name.endswith(".py"): return "🐍" + if file.is_dir(): + return "📁" + mime = mimetypes.guess_type(file.name)[0] or "" + if mime.startswith("image"): + return "🖼️" + if mime.startswith("text"): + return "📄" + if mime.startswith("audio"): + return "🎵" + if mime.startswith("video"): + return "🎬" + if file.name.endswith(".py"): + return "🐍" return "📦" - diff --git a/src/snek/view/rpc.py b/src/snek/view/rpc.py index 7227e27..a31c45b 100644 --- a/src/snek/view/rpc.py +++ b/src/snek/view/rpc.py @@ -7,19 +7,20 @@ # MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions. -import json -import traceback import asyncio +import json +import logging +import traceback + from aiohttp import web from snek.system.model import now from snek.system.profiler import Profiler from snek.system.view import BaseView -import logging - logger = logging.getLogger(__name__) + class RPCView(BaseView): class RPCApi: def __init__(self, view, ws): @@ -32,7 +33,7 @@ class RPCView(BaseView): async def _session_ensure(self): uid = await self.view.session_get("uid") - if not uid in self.user_session: + if uid not in self.user_session: self.user_session[uid] = { "said_hello": False, } @@ -50,45 +51,54 @@ class RPCView(BaseView): self._require_login() return await self.services.db.insert(self.user_uid, table_name, record) + async def db_update(self, table_name, record): self._require_login() return await self.services.db.update(self.user_uid, table_name, record) - async def set_typing(self,channel_uid,color=None): + + async def set_typing(self, channel_uid, color=None): self._require_login() user = await self.services.user.get(self.user_uid) if not color: color = user["color"] - return await self.services.socket.broadcast(channel_uid, { - "channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2", - "event": "set_typing", - "data": { - "event":"set_typing", - "user_uid": user['uid'], - "username": user["username"], - "nick": user["nick"], - "channel_uid": channel_uid, - "color": color - } - }) + return await self.services.socket.broadcast( + channel_uid, + { + "channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2", + "event": "set_typing", + "data": { + "event": "set_typing", + "user_uid": user["uid"], + "username": user["username"], + "nick": user["nick"], + "channel_uid": channel_uid, + "color": color, + }, + }, + ) async def db_delete(self, table_name, record): self._require_login() return await self.services.db.delete(self.user_uid, table_name, record) + async def db_get(self, table_name, record): self._require_login() return await self.services.db.get(self.user_uid, table_name, record) + async def db_find(self, table_name, record): self._require_login() return await self.services.db.find(self.user_uid, table_name, record) - async def db_upsert(self, table_name, record,keys): + + async def db_upsert(self, table_name, record, keys): self._require_login() - return await self.services.db.upsert(self.user_uid, table_name, record,keys) + return await self.services.db.upsert( + self.user_uid, table_name, record, keys + ) async def db_query(self, table_name, args): self._require_login() return await self.services.db.query(self.user_uid, table_name, sql, args) - @property def user_uid(self): return self.view.session.get("uid") @@ -195,52 +205,60 @@ class RPCView(BaseView): async def finalize_message(self, message_uid): self._require_login() - message = await self.services.channel_message.get(message_uid) + message = await self.services.channel_message.get(message_uid) if not message: - return False + return False if message["user_uid"] != self.user_uid: raise Exception("Not allowed") - - - if not message['is_final']: - await self.services.chat.finalize(message['uid']) + + if not message["is_final"]: + await self.services.chat.finalize(message["uid"]) return True - async def update_message_text(self,message_uid, text): + async def update_message_text(self, message_uid, text): self._require_login() - message = await self.services.channel_message.get(message_uid) + message = await self.services.channel_message.get(message_uid) if message["user_uid"] != self.user_uid: raise Exception("Not allowed") - + if message.get_seconds_since_last_update() > 3: await self.finalize_message(message["uid"]) - return {"error": "Message too old","seconds_since_last_update": message.get_seconds_since_last_update(),"success": False} + return { + "error": "Message too old", + "seconds_since_last_update": message.get_seconds_since_last_update(), + "success": False, + } - message['message'] = text + message["message"] = text if not text: - message['deleted_at'] = now() + message["deleted_at"] = now() else: - message['deleted_at'] = None + message["deleted_at"] = None await self.services.channel_message.save(message) - data = message.record - data['text'] = message["message"] - data['message_uid'] = message_uid + data = message.record + data["text"] = message["message"] + data["message_uid"] = message_uid + + await self.services.socket.broadcast( + message["channel_uid"], + { + "channel_uid": message["channel_uid"], + "event": "update_message_text", + "data": message.record, + }, + ) - await self.services.socket.broadcast(message["channel_uid"], { - "channel_uid": message["channel_uid"], - "event": "update_message_text", - "data": message.record - }) - return {"success": True} - async def send_message(self, channel_uid, message,is_final=True): + async def send_message(self, channel_uid, message, is_final=True): self._require_login() - message = await self.services.chat.send(self.user_uid, channel_uid, message,is_final) - + message = await self.services.chat.send( + self.user_uid, channel_uid, message, is_final + ) + return message["uid"] async def echo(self, *args): @@ -330,7 +348,8 @@ class RPCView(BaseView): self._require_login() results = [ - record async for record in self.services.channel.get_recent_users(channel_uid) + record + async for record in self.services.channel.get_recent_users(channel_uid) ] results = sorted(results, key=lambda x: x["nick"]) return results @@ -371,7 +390,6 @@ class RPCView(BaseView): await self.services.socket.send_to_user(self.user_uid, call) self._scheduled.remove(call) - async def ping(self, callId, *args): if self.user_uid: user = await self.services.user.get(uid=self.user_uid) @@ -379,17 +397,23 @@ class RPCView(BaseView): await self.services.user.save(user) return {"pong": args} - async def stars_render(self, channel_uid, message): - + for user in await self.get_online_users(channel_uid): try: - await self.services.socket.send_to_user(user['uid'], dict(event="stars_render", data={"channel_uid": channel_uid, "message":message})) + await self.services.socket.send_to_user( + user["uid"], + { + "event": "stars_render", + "data": {"channel_uid": channel_uid, "message": message}, + }, + ) except Exception as ex: print(ex) - + async def get(self): scheduled = [] + async def schedule(uid, seconds, call): scheduled.append(call) await asyncio.sleep(seconds) @@ -408,16 +432,18 @@ class RPCView(BaseView): await self.services.socket.subscribe( ws, subscription["channel_uid"], self.request.session.get("uid") ) - if not scheduled and self.request.app.uptime_seconds < 5: - await schedule(self.request.session.get("uid"),0,{"event":"refresh", "data": { - "message": "Finishing deployment"} - } - ) - await schedule(self.request.session.get("uid"),15,{"event": "deployed", "data": { - "uptime": self.request.app.uptime} - } + if not scheduled and self.request.app.uptime_seconds < 5: + await schedule( + self.request.session.get("uid"), + 0, + {"event": "refresh", "data": {"message": "Finishing deployment"}}, ) - + await schedule( + self.request.session.get("uid"), + 15, + {"event": "deployed", "data": {"uptime": self.request.app.uptime}}, + ) + rpc = RPCView.RPCApi(self, ws) async for msg in ws: if msg.type == web.WSMsgType.TEXT: diff --git a/src/snek/view/settings/containers.py b/src/snek/view/settings/containers.py index dae0237..b6c4bfc 100644 --- a/src/snek/view/settings/containers.py +++ b/src/snek/view/settings/containers.py @@ -1,45 +1,48 @@ -import asyncio from aiohttp import web from snek.system.view import BaseFormView -import pathlib + class ContainersIndexView(BaseFormView): login_required = True async def get(self): - + user_uid = self.session.get("uid") - + containers = [] async for container in self.services.container.find(user_uid=user_uid): containers.append(container.record) - + user = await self.services.user.get(uid=self.session.get("uid")) - return await self.render_template("settings/containers/index.html", {"containers": containers, "user": user}) + return await self.render_template( + "settings/containers/index.html", {"containers": containers, "user": user} + ) + class ContainersCreateView(BaseFormView): login_required = True async def get(self): - + return await self.render_template("settings/containers/create.html") async def post(self): data = await self.request.post() - container = await self.services.container.create( + await self.services.container.create( user_uid=self.session.get("uid"), - name=data['name'], - status=data['status'], - resources=data.get('resources', ''), - path=data.get('path', ''), - readonly=bool(data.get('readonly', False)) + name=data["name"], + status=data["status"], + resources=data.get("resources", ""), + path=data.get("path", ""), + readonly=bool(data.get("readonly", False)), ) return web.HTTPFound("/settings/containers/index.html") + class ContainersUpdateView(BaseFormView): login_required = True @@ -51,40 +54,43 @@ class ContainersUpdateView(BaseFormView): ) if not container: return web.HTTPNotFound() - return await self.render_template("settings/containers/update.html", {"container": container.record}) + return await self.render_template( + "settings/containers/update.html", {"container": container.record} + ) async def post(self): data = await self.request.post() container = await self.services.container.get( user_uid=self.session.get("uid"), uid=self.request.match_info["uid"] ) - container['status'] = data['status'] - container['resources'] = data.get('resources', '') - container['path'] = data.get('path', '') - container['readonly'] = bool(data.get('readonly', False)) + container["status"] = data["status"] + container["resources"] = data.get("resources", "") + container["path"] = data.get("path", "") + container["readonly"] = bool(data.get("readonly", False)) await self.services.container.save(container) return web.HTTPFound("/settings/containers/index.html") + class ContainersDeleteView(BaseFormView): login_required = True async def get(self): - + container = await self.services.container.get( user_uid=self.session.get("uid"), uid=self.request.match_info["uid"] ) if not container: return web.HTTPNotFound() - return await self.render_template("settings/containers/delete.html", {"container": container.record}) + return await self.render_template( + "settings/containers/delete.html", {"container": container.record} + ) async def post(self): user_uid = self.session.get("uid") uid = self.request.match_info["uid"] - container = await self.services.container.get( - user_uid=user_uid, uid=uid - ) + container = await self.services.container.get(user_uid=user_uid, uid=uid) if not container: return web.HTTPNotFound() await self.services.container.delete(user_uid=user_uid, uid=uid) diff --git a/src/snek/view/settings/repositories.py b/src/snek/view/settings/repositories.py index 093d229..dcc945b 100644 --- a/src/snek/view/settings/repositories.py +++ b/src/snek/view/settings/repositories.py @@ -1,26 +1,26 @@ -import asyncio from aiohttp import web from snek.system.view import BaseFormView -import pathlib + class RepositoriesIndexView(BaseFormView): login_required = True async def get(self): - + user_uid = self.session.get("uid") - + repositories = [] async for repository in self.services.repository.find(user_uid=user_uid): repositories.append(repository.record) - + user = await self.services.user.get(uid=self.session.get("uid")) - return await self.render_template("settings/repositories/index.html", {"repositories": repositories, "user": user}) - - + return await self.render_template( + "settings/repositories/index.html", + {"repositories": repositories, "user": user}, + ) class RepositoriesCreateView(BaseFormView): @@ -28,14 +28,19 @@ class RepositoriesCreateView(BaseFormView): login_required = True async def get(self): - + return await self.render_template("settings/repositories/create.html") async def post(self): data = await self.request.post() - repository = await self.services.repository.create(user_uid=self.session.get("uid"), name=data['name'], is_private=int(data.get('is_private',0))) + await self.services.repository.create( + user_uid=self.session.get("uid"), + name=data["name"], + is_private=int(data.get("is_private", 0)), + ) return web.HTTPFound("/settings/repositories/index.html") + class RepositoriesUpdateView(BaseFormView): login_required = True @@ -47,40 +52,41 @@ class RepositoriesUpdateView(BaseFormView): ) if not repository: return web.HTTPNotFound() - return await self.render_template("settings/repositories/update.html", {"repository": repository.record}) + return await self.render_template( + "settings/repositories/update.html", {"repository": repository.record} + ) async def post(self): data = await self.request.post() repository = await self.services.repository.get( user_uid=self.session.get("uid"), name=self.request.match_info["name"] ) - repository['is_private'] = int(data.get('is_private',0)) + repository["is_private"] = int(data.get("is_private", 0)) await self.services.repository.save(repository) return web.HTTPFound("/settings/repositories/index.html") + class RepositoriesDeleteView(BaseFormView): login_required = True async def get(self): - + repository = await self.services.repository.get( user_uid=self.session.get("uid"), name=self.request.match_info["name"] ) if not repository: return web.HTTPNotFound() - return await self.render_template("settings/repositories/delete.html", {"repository": repository.record}) + return await self.render_template( + "settings/repositories/delete.html", {"repository": repository.record} + ) async def post(self): user_uid = self.session.get("uid") name = self.request.match_info["name"] - repository = await self.services.repository.get( - user_uid=user_uid, name=name - ) + repository = await self.services.repository.get(user_uid=user_uid, name=name) if not repository: return web.HTTPNotFound() await self.services.repository.delete(user_uid=user_uid, name=name) return web.HTTPFound("/settings/repositories/index.html") - -