From 389d417c1b4d46344475978b45125d757dcc673f Mon Sep 17 00:00:00 2001 From: retoor Date: Sat, 7 Jun 2025 05:47:54 +0200 Subject: [PATCH] Update. --- src/snek/__main__.py | 73 +- src/snek/app.py | 102 +- src/snek/balancer.py | 36 +- src/snek/mapper/__init__.py | 9 +- src/snek/mapper/container.py | 3 +- src/snek/model/__init__.py | 8 +- src/snek/model/channel_attachment.py | 7 +- src/snek/model/channel_message.py | 10 +- src/snek/model/container.py | 1 + src/snek/model/drive_item.py | 4 +- src/snek/model/repository.py | 6 +- src/snek/model/user.py | 2 +- src/snek/research/serptest.py | 53 +- src/snek/service/__init__.py | 10 +- src/snek/service/channel.py | 13 +- src/snek/service/channel_attachment.py | 20 +- src/snek/service/channel_message.py | 4 +- src/snek/service/chat.py | 10 +- src/snek/service/container.py | 41 +- src/snek/service/db.py | 23 +- src/snek/service/notification.py | 5 +- src/snek/service/push.py | 22 +- src/snek/service/repository.py | 21 +- src/snek/service/socket.py | 21 +- src/snek/service/user.py | 10 +- src/snek/sgit.py | 422 ++++--- src/snek/snode.py | 65 +- src/snek/sssh.py | 35 +- src/snek/system/docker.py | 54 +- src/snek/system/mapper.py | 17 +- src/snek/system/markdown.py | 27 +- src/snek/system/middleware.py | 10 +- src/snek/system/security.py | 3 + src/snek/system/template.py | 91 +- src/snek/system/websocket.py | 59 +- src/snek/view/avatar.py | 15 +- src/snek/view/avatar_animal.py | 1434 +++++++++++++++++------- src/snek/view/avatar_animal_view.py | 10 +- src/snek/view/channel.py | 59 +- src/snek/view/drive.py | 171 +-- src/snek/view/push.py | 2 +- src/snek/view/repository.py | 161 +-- src/snek/view/rpc.py | 148 ++- src/snek/view/settings/containers.py | 52 +- src/snek/view/settings/repositories.py | 44 +- 45 files changed, 2158 insertions(+), 1235 deletions(-) diff --git a/src/snek/__main__.py b/src/snek/__main__.py index 360bf5a..a2e433a 100644 --- a/src/snek/__main__.py +++ b/src/snek/__main__.py @@ -1,27 +1,31 @@ -import click -import uvloop -from aiohttp import web -import asyncio -from snek.app import Application -from IPython import start_ipython -import sqlite3 import pathlib -import shutil +import shutil +import sqlite3 + +import click +from aiohttp import web +from IPython import start_ipython + +from snek.app import Application + @click.group() def cli(): pass + @cli.command() -@click.option('--db_path',default="snek.db", help='Database to initialize if not exists.') -@click.option('--source',default=None, help='Database to initialize if not exists.') -def init(db_path,source): +@click.option( + "--db_path", default="snek.db", help="Database to initialize if not exists." +) +@click.option("--source", default=None, help="Database to initialize if not exists.") +def init(db_path, source): if source and pathlib.Path(source).exists(): print(f"Copying {source} to {db_path}") - shutil.copy2(source,db_path) + shutil.copy2(source, db_path) print("Database initialized.") return - + if pathlib.Path(db_path).exists(): return print(f"Initializing database at {db_path}") @@ -33,25 +37,44 @@ def init(db_path,source): db.close() print("Database initialized.") -@cli.command() -@click.option('--port', default=8081, show_default=True, help='Port to run the application on') -@click.option('--host', default='0.0.0.0', show_default=True, help='Host to run the application on') -@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application') -def serve(port, host, db_path): - #init(db_path) - #asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - web.run_app( - Application(db_path=f"sqlite:///{db_path}"), port=port, host=host - ) @cli.command() -@click.option('--db_path', default='snek.db', show_default=True, help='Database path for the application') +@click.option( + "--port", default=8081, show_default=True, help="Port to run the application on" +) +@click.option( + "--host", + default="0.0.0.0", + show_default=True, + help="Host to run the application on", +) +@click.option( + "--db_path", + default="snek.db", + show_default=True, + help="Database path for the application", +) +def serve(port, host, db_path): + # init(db_path) + # asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + web.run_app(Application(db_path=f"sqlite:///{db_path}"), port=port, host=host) + + +@cli.command() +@click.option( + "--db_path", + default="snek.db", + show_default=True, + help="Database path for the application", +) def shell(db_path): app = Application(db_path=f"sqlite:///{db_path}") - start_ipython(argv=[], user_ns={'app': app}) + start_ipython(argv=[], user_ns={"app": app}) + def main(): cli() + if __name__ == "__main__": main() diff --git a/src/snek/app.py b/src/snek/app.py index 03ef491..765e4fb 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -4,13 +4,15 @@ import pathlib import ssl import uuid from datetime import datetime + from snek import snode from snek.view.threads import ThreadsView logging.basicConfig(level=logging.DEBUG) -from ipaddress import ip_address from concurrent.futures import ThreadPoolExecutor +from ipaddress import ip_address +import IP2Location from aiohttp import web from aiohttp_session import ( get_session as session_get, @@ -20,54 +22,56 @@ from aiohttp_session import ( from aiohttp_session.cookie_storage import EncryptedCookieStorage from app.app import Application as BaseApplication from jinja2 import FileSystemLoader -import IP2Location -from snek.sssh import start_ssh_server -from snek.docs.app import Application as DocsApplication from snek.mapper import get_mappers from snek.service import get_services +from snek.sgit import GitApplication +from snek.sssh import start_ssh_server from snek.system import http from snek.system.cache import Cache from snek.system.markdown import MarkdownExtension from snek.system.middleware import auth_middleware, cors_middleware, csp_middleware from snek.system.profiler import profiler_handler -from snek.system.template import EmojiExtension, LinkifyExtension, PythonExtension +from snek.system.template import ( + EmojiExtension, + LinkifyExtension, + PythonExtension, + sanitize_html, +) from snek.view.about import AboutHTMLView, AboutMDView from snek.view.avatar import AvatarView +from snek.view.channel import ChannelAttachmentView, ChannelView from snek.view.docs import DocsHTMLView, DocsMDView -from snek.view.drive import DriveView -from snek.view.drive import DriveApiView +from snek.view.drive import DriveApiView, DriveView from snek.view.index import IndexView from snek.view.login import LoginView -from snek.view.push import PushView from snek.view.logout import LogoutView +from snek.view.push import PushView from snek.view.register import RegisterView -from snek.view.rpc import RPCView from snek.view.repository import RepositoryView +from snek.view.rpc import RPCView from snek.view.search_user import SearchUserView -from snek.view.settings.repositories import RepositoriesIndexView -from snek.view.settings.repositories import RepositoriesCreateView -from snek.view.settings.repositories import RepositoriesUpdateView -from snek.view.settings.repositories import RepositoriesDeleteView +from snek.view.settings.containers import ( + ContainersCreateView, + ContainersDeleteView, + ContainersIndexView, + ContainersUpdateView, +) from snek.view.settings.index import SettingsIndexView from snek.view.settings.profile import SettingsProfileView +from snek.view.settings.repositories import ( + RepositoriesCreateView, + RepositoriesDeleteView, + RepositoriesIndexView, + RepositoriesUpdateView, +) from snek.view.stats import StatsView from snek.view.status import StatusView from snek.view.terminal import TerminalSocketView, TerminalView from snek.view.upload import UploadView from snek.view.user import UserView from snek.view.web import WebView -from snek.view.channel import ChannelAttachmentView -from snek.view.channel import ChannelView -from snek.view.settings.containers import ( - ContainersIndexView, - ContainersCreateView, - ContainersUpdateView, - ContainersDeleteView, -) from snek.webdav import WebdavApplication -from snek.system.template import sanitize_html -from snek.sgit import GitApplication SESSION_KEY = b"c79a0c5fda4b424189c427d28c9f7c34" from snek.system.template import whitelist_attributes @@ -94,7 +98,7 @@ async def ip2location_middleware(request, handler): if not user: return response location = request.app.ip2location.get(ip) - original_city = user["city"] + user["city"] if user["city"] != location.city: user["country_long"] = location.country user["country_short"] = locaion.country_short @@ -121,7 +125,7 @@ class Application(BaseApplication): cors_middleware, web.normalize_path_middleware(merge_slashes=True), ip2location_middleware, - csp_middleware + csp_middleware, ] self.template_path = pathlib.Path(__file__).parent.joinpath("templates") self.static_path = pathlib.Path(__file__).parent.joinpath("static") @@ -139,7 +143,7 @@ class Application(BaseApplication): self.jinja2_env.add_extension(LinkifyExtension) self.jinja2_env.add_extension(PythonExtension) self.jinja2_env.add_extension(EmojiExtension) - self.jinja2_env.filters['sanitize'] = sanitize_html + self.jinja2_env.filters["sanitize"] = sanitize_html self.time_start = datetime.now() self.ssh_host = "0.0.0.0" self.ssh_port = 2242 @@ -272,8 +276,12 @@ class Application(BaseApplication): self.router.add_get("/http-photo", self.handle_http_photo) self.router.add_get("/rpc.ws", RPCView) self.router.add_get("/c/{channel:.*}", ChannelView) - self.router.add_view("/channel/{channel_uid}/attachment.bin",ChannelAttachmentView) - self.router.add_view("/channel/attachment/{relative_url:.*}",ChannelAttachmentView) + self.router.add_view( + "/channel/{channel_uid}/attachment.bin", ChannelAttachmentView + ) + self.router.add_view( + "/channel/attachment/{relative_url:.*}", ChannelAttachmentView + ) self.router.add_view("/channel/{channel}.html", WebView) self.router.add_view("/threads.html", ThreadsView) self.router.add_view("/terminal.ws", TerminalSocketView) @@ -284,15 +292,29 @@ class Application(BaseApplication): self.router.add_view("/stats.json", StatsView) self.router.add_view("/user/{user}.html", UserView) self.router.add_view("/repository/{username}/{repository}", RepositoryView) - self.router.add_view("/repository/{username}/{repository}/{path:.*}", RepositoryView) + self.router.add_view( + "/repository/{username}/{repository}/{path:.*}", RepositoryView + ) self.router.add_view("/settings/repositories/index.html", RepositoriesIndexView) - self.router.add_view("/settings/repositories/create.html", RepositoriesCreateView) - self.router.add_view("/settings/repositories/repository/{name}/update.html", RepositoriesUpdateView) - self.router.add_view("/settings/repositories/repository/{name}/delete.html", RepositoriesDeleteView) + self.router.add_view( + "/settings/repositories/create.html", RepositoriesCreateView + ) + self.router.add_view( + "/settings/repositories/repository/{name}/update.html", + RepositoriesUpdateView, + ) + self.router.add_view( + "/settings/repositories/repository/{name}/delete.html", + RepositoriesDeleteView, + ) self.router.add_view("/settings/containers/index.html", ContainersIndexView) self.router.add_view("/settings/containers/create.html", ContainersCreateView) - self.router.add_view("/settings/containers/container/{uid}/update.html", ContainersUpdateView) - self.router.add_view("/settings/containers/container/{uid}/delete.html", ContainersDeleteView) + self.router.add_view( + "/settings/containers/container/{uid}/update.html", ContainersUpdateView + ) + self.router.add_view( + "/settings/containers/container/{uid}/delete.html", ContainersDeleteView + ) self.webdav = WebdavApplication(self) self.git = GitApplication(self) self.add_subapp("/webdav", self.webdav) @@ -302,9 +324,9 @@ class Application(BaseApplication): async def handle_test(self, request): - return await whitelist_attributes(self.render_template( - "test.html", request, context={"name": "retoor"} - )) + return await whitelist_attributes( + self.render_template("test.html", request, context={"name": "retoor"}) + ) async def handle_http_get(self, request: web.Request): url = request.query.get("url") @@ -375,8 +397,8 @@ class Application(BaseApplication): self.jinja2_env.loader = self.original_loader - #rendered.text = whitelist_attributes(rendered.text) - #rendered.headers['Content-Lenght'] = len(rendered.text) + # rendered.text = whitelist_attributes(rendered.text) + # rendered.headers['Content-Lenght'] = len(rendered.text) return rendered async def static_handler(self, request): @@ -423,7 +445,7 @@ app = Application(db_path="sqlite:///snek.db") async def main(): ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_context.load_cert_chain('cert.pem', 'key.pem') + ssl_context.load_cert_chain("cert.pem", "key.pem") await web._run_app(app, port=8081, host="0.0.0.0", ssl_context=ssl_context) diff --git a/src/snek/balancer.py b/src/snek/balancer.py index 09d0b5a..44db9a0 100644 --- a/src/snek/balancer.py +++ b/src/snek/balancer.py @@ -1,6 +1,7 @@ import asyncio import sys + class LoadBalancer: def __init__(self, backend_ports): self.backend_ports = backend_ports @@ -8,27 +9,29 @@ class LoadBalancer: self.client_counts = [0] * len(backend_ports) self.lock = asyncio.Lock() - async def start_backend_servers(self,port,workers): + async def start_backend_servers(self, port, workers): for x in range(workers): port += 1 process = await asyncio.create_subprocess_exec( sys.executable, sys.argv[0], - 'backend', + "backend", str(port), stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) port += 1 self.backend_processes.append(process) - print(f"Started backend server on port {(port-1)/port} with PID {process.pid}") + print( + f"Started backend server on port {(port-1)/port} with PID {process.pid}" + ) async def handle_client(self, reader, writer): async with self.lock: min_clients = min(self.client_counts) server_index = self.client_counts.index(min_clients) self.client_counts[server_index] += 1 - backend = ('127.0.0.1', self.backend_ports[server_index]) + backend = ("127.0.0.1", self.backend_ports[server_index]) try: backend_reader, backend_writer = await asyncio.open_connection(*backend) @@ -62,10 +65,10 @@ class LoadBalancer: for i, count in enumerate(self.client_counts): print(f"Server {self.backend_ports[i]}: {count} clients") - async def start(self, host='0.0.0.0', port=8081,workers=5): - await self.start_backend_servers(port,workers) + async def start(self, host="0.0.0.0", port=8081, workers=5): + await self.start_backend_servers(port, workers) server = await asyncio.start_server(self.handle_client, host, port) - monitor_task = asyncio.create_task(self.monitor()) + asyncio.create_task(self.monitor()) # Handle shutdown gracefully try: @@ -80,6 +83,7 @@ class LoadBalancer: await asyncio.gather(*(p.wait() for p in self.backend_processes)) print("Backend processes terminated.") + async def backend_echo_server(port): async def handle_echo(reader, writer): try: @@ -94,10 +98,11 @@ async def backend_echo_server(port): finally: writer.close() - server = await asyncio.start_server(handle_echo, '127.0.0.1', port) + server = await asyncio.start_server(handle_echo, "127.0.0.1", port) print(f"Backend echo server running on port {port}") await server.serve_forever() + async def main(): backend_ports = [8001, 8003, 8005, 8006] # Launch backend echo servers @@ -105,19 +110,20 @@ async def main(): lb = LoadBalancer(backend_ports) await lb.start() + if __name__ == "__main__": if len(sys.argv) > 1: - if sys.argv[1] == 'backend': + if sys.argv[1] == "backend": port = int(sys.argv[2]) from snek.app import Application + snek = Application(port=port) - web.run_app(snek, port=port, host='127.0.0.1') - elif sys.argv[1] == 'sync': - from snek.sync import app - web.run_app(snek, port=port, host='127.0.0.1') + web.run_app(snek, port=port, host="127.0.0.1") + elif sys.argv[1] == "sync": + + web.run_app(snek, port=port, host="127.0.0.1") else: try: asyncio.run(main()) except KeyboardInterrupt: print("Shutting down...") - diff --git a/src/snek/mapper/__init__.py b/src/snek/mapper/__init__.py index 48a8a0e..062553a 100644 --- a/src/snek/mapper/__init__.py +++ b/src/snek/mapper/__init__.py @@ -1,22 +1,21 @@ import functools from snek.mapper.channel import ChannelMapper +from snek.mapper.channel_attachment import ChannelAttachmentMapper from snek.mapper.channel_member import ChannelMemberMapper from snek.mapper.channel_message import ChannelMessageMapper +from snek.mapper.container import ContainerMapper from snek.mapper.drive import DriveMapper from snek.mapper.drive_item import DriveItemMapper from snek.mapper.notification import NotificationMapper +from snek.mapper.push import PushMapper +from snek.mapper.repository import RepositoryMapper from snek.mapper.user import UserMapper from snek.mapper.user_property import UserPropertyMapper -from snek.mapper.repository import RepositoryMapper -from snek.mapper.push import PushMapper -from snek.mapper.channel_attachment import ChannelAttachmentMapper -from snek.mapper.container import ContainerMapper from snek.system.object import Object @functools.cache - def get_mappers(app=None): return Object( **{ diff --git a/src/snek/mapper/container.py b/src/snek/mapper/container.py index f88f9f0..c2a089f 100644 --- a/src/snek/mapper/container.py +++ b/src/snek/mapper/container.py @@ -1,6 +1,7 @@ from snek.model.container import Container from snek.system.mapper import BaseMapper + class ContainerMapper(BaseMapper): model_class = Container - table_name = "container" \ No newline at end of file + table_name = "container" diff --git a/src/snek/model/__init__.py b/src/snek/model/__init__.py index e5c6054..187ea4b 100644 --- a/src/snek/model/__init__.py +++ b/src/snek/model/__init__.py @@ -1,19 +1,19 @@ import functools from snek.model.channel import ChannelModel +from snek.model.channel_attachment import ChannelAttachmentModel from snek.model.channel_member import ChannelMemberModel # from snek.model.channel_message import ChannelMessageModel from snek.model.channel_message import ChannelMessageModel +from snek.model.container import Container from snek.model.drive import DriveModel from snek.model.drive_item import DriveItemModel from snek.model.notification import NotificationModel +from snek.model.push_registration import PushRegistrationModel +from snek.model.repository import RepositoryModel from snek.model.user import UserModel from snek.model.user_property import UserPropertyModel -from snek.model.repository import RepositoryModel -from snek.model.channel_attachment import ChannelAttachmentModel -from snek.model.push_registration import PushRegistrationModel -from snek.model.container import Container from snek.system.object import Object diff --git a/src/snek/model/channel_attachment.py b/src/snek/model/channel_attachment.py index 9add1b8..0c0fba2 100644 --- a/src/snek/model/channel_attachment.py +++ b/src/snek/model/channel_attachment.py @@ -1,4 +1,3 @@ -from snek.system.model import BaseModel from snek.system.model import BaseModel, ModelField @@ -11,6 +10,6 @@ class ChannelAttachmentModel(BaseModel): user_uid = ModelField(name="user_uid", required=True, kind=str) mime_type = ModelField(name="type", required=True, kind=str) relative_url = ModelField(name="relative_url", required=True, kind=str) - resource_type = ModelField(name="resource_type", required=True, kind=str,value="file") - - + resource_type = ModelField( + name="resource_type", required=True, kind=str, value="file" + ) diff --git a/src/snek/model/channel_message.py b/src/snek/model/channel_message.py index 505a40d..e3933e3 100644 --- a/src/snek/model/channel_message.py +++ b/src/snek/model/channel_message.py @@ -1,6 +1,8 @@ +from datetime import datetime, timezone + from snek.model.user import UserModel from snek.system.model import BaseModel, ModelField -from datetime import datetime,timezone + class ChannelMessageModel(BaseModel): channel_uid = ModelField(name="channel_uid", required=True, kind=str) @@ -10,7 +12,11 @@ class ChannelMessageModel(BaseModel): is_final = ModelField(name="is_final", required=True, kind=bool, value=True) def get_seconds_since_last_update(self): - return int((datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"])).total_seconds()) + return int( + ( + datetime.now(timezone.utc) - datetime.fromisoformat(self["updated_at"]) + ).total_seconds() + ) async def get_user(self) -> UserModel: return await self.app.services.user.get(uid=self["user_uid"]) diff --git a/src/snek/model/container.py b/src/snek/model/container.py index 3947a8d..ee64144 100644 --- a/src/snek/model/container.py +++ b/src/snek/model/container.py @@ -1,5 +1,6 @@ from snek.system.model import BaseModel, ModelField + class Container(BaseModel): id = ModelField(name="id", required=True, kind=str) name = ModelField(name="name", required=True, kind=str) diff --git a/src/snek/model/drive_item.py b/src/snek/model/drive_item.py index e2b55b4..82567a6 100644 --- a/src/snek/model/drive_item.py +++ b/src/snek/model/drive_item.py @@ -9,7 +9,9 @@ class DriveItemModel(BaseModel): path = ModelField(name="path", required=True, kind=str) file_type = ModelField(name="file_type", required=True, kind=str) file_size = ModelField(name="file_size", required=True, kind=int) - is_available = ModelField(name="is_available", required=True, kind=bool, initial_value=True) + is_available = ModelField( + name="is_available", required=True, kind=bool, initial_value=True + ) @property def extension(self): diff --git a/src/snek/model/repository.py b/src/snek/model/repository.py index 598cbb2..521e3ff 100644 --- a/src/snek/model/repository.py +++ b/src/snek/model/repository.py @@ -1,14 +1,10 @@ -from snek.model.user import UserModel from snek.system.model import BaseModel, ModelField class RepositoryModel(BaseModel): user_uid = ModelField(name="user_uid", required=True, kind=str) - + name = ModelField(name="name", required=True, kind=str) is_private = ModelField(name="is_private", required=False, kind=bool) - - - diff --git a/src/snek/model/user.py b/src/snek/model/user.py index afa4b53..93e83af 100644 --- a/src/snek/model/user.py +++ b/src/snek/model/user.py @@ -30,7 +30,7 @@ class UserModel(BaseModel): last_ping = ModelField(name="last_ping", required=False, kind=str) is_admin = ModelField(name="is_admin", required=False, kind=bool) - + country_short = ModelField(name="country_short", required=False, kind=str) country_long = ModelField(name="country_long", required=False, kind=str) city = ModelField(name="city", required=False, kind=str) diff --git a/src/snek/research/serptest.py b/src/snek/research/serptest.py index 94e22be..891cccc 100644 --- a/src/snek/research/serptest.py +++ b/src/snek/research/serptest.py @@ -1,51 +1,54 @@ -import snek.serpentarium - -import time - +import time from concurrent.futures import ProcessPoolExecutor +import snek.serpentarium + durations = [] + def task1(): - global durations + global durations client = snek.serpentarium.DatasetWrapper() - start=time.time() + start = time.time() for x in range(1500): - - client['a'].delete() - client['a'].insert({"foo": x}) - client['a'].find(foo=x) - client['a'].find_one(foo=x) - client['a'].count() - #print(client['a'].find(foo=x) ) - #print(client['a'].find_one(foo=x) ) - #print(client['a'].count()) + + client["a"].delete() + client["a"].insert({"foo": x}) + client["a"].find(foo=x) + client["a"].find_one(foo=x) + client["a"].count() + # print(client['a'].find(foo=x) ) + # print(client['a'].find_one(foo=x) ) + # print(client['a'].count()) client.close() duration1 = f"{time.time()-start}" durations.append(duration1) print(durations) + with ProcessPoolExecutor(max_workers=4) as executor: - tasks = [executor.submit(task1), + tasks = [ + executor.submit(task1), + executor.submit(task1), executor.submit(task1), executor.submit(task1), - executor.submit(task1) ] for task in tasks: task.result() -import dataset +import dataset + client = dataset.connect("sqlite:///snek.db") -start=time.time() +start = time.time() for x in range(1500): - client['a'].delete() - client['a'].insert({"foo": x}) - print([dict(row) for row in client['a'].find(foo=x)]) - print(dict(client['a'].find_one(foo=x) )) - print(client['a'].count()) + client["a"].delete() + client["a"].insert({"foo": x}) + print([dict(row) for row in client["a"].find(foo=x)]) + print(dict(client["a"].find_one(foo=x))) + print(client["a"].count()) duration2 = f"{time.time()-start}" -print(duration1,duration2) +print(duration1, duration2) diff --git a/src/snek/service/__init__.py b/src/snek/service/__init__.py index f1654fb..77702c2 100644 --- a/src/snek/service/__init__.py +++ b/src/snek/service/__init__.py @@ -1,22 +1,22 @@ import functools from snek.service.channel import ChannelService +from snek.service.channel_attachment import ChannelAttachmentService from snek.service.channel_member import ChannelMemberService from snek.service.channel_message import ChannelMessageService from snek.service.chat import ChatService +from snek.service.container import ContainerService +from snek.service.db import DBService from snek.service.drive import DriveService from snek.service.drive_item import DriveItemService from snek.service.notification import NotificationService +from snek.service.push import PushService +from snek.service.repository import RepositoryService from snek.service.socket import SocketService from snek.service.user import UserService -from snek.service.push import PushService from snek.service.user_property import UserPropertyService from snek.service.util import UtilService -from snek.service.repository import RepositoryService -from snek.service.channel_attachment import ChannelAttachmentService -from snek.service.container import ContainerService from snek.system.object import Object -from snek.service.db import DBService @functools.cache diff --git a/src/snek/service/channel.py b/src/snek/service/channel.py index a5c166d..0af0a4b 100644 --- a/src/snek/service/channel.py +++ b/src/snek/service/channel.py @@ -1,9 +1,9 @@ +import pathlib from datetime import datetime from snek.system.model import now from snek.system.service import BaseService -import pathlib class ChannelService(BaseService): mapper_name = "channel" @@ -17,12 +17,10 @@ class ChannelService(BaseService): pass return folder - async def get_attachment_folder(self, channel_uid,ensure=False): + async def get_attachment_folder(self, channel_uid, ensure=False): path = pathlib.Path(f"./drive/{channel_uid}/attachments") if ensure: - path.mkdir( - parents=True, exist_ok=True - ) + path.mkdir(parents=True, exist_ok=True) return path async def get(self, uid=None, **kwargs): @@ -78,7 +76,10 @@ class ChannelService(BaseService): return channel async def get_recent_users(self, channel_uid): - async for user in self.query("SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30", {"channel_uid": channel_uid}): + async for user in self.query( + "SELECT user.uid, user.username,user.color,user.last_ping,user.nick FROM channel_member INNER JOIN user ON user.uid = channel_member.user_uid WHERE channel_uid=:channel_uid AND user.last_ping >= datetime('now', '-3 minutes') ORDER BY last_ping DESC LIMIT 30", + {"channel_uid": channel_uid}, + ): yield user async def get_users(self, channel_uid): diff --git a/src/snek/service/channel_attachment.py b/src/snek/service/channel_attachment.py index 7c8ded6..755aaa0 100644 --- a/src/snek/service/channel_attachment.py +++ b/src/snek/service/channel_attachment.py @@ -1,25 +1,25 @@ -from snek.system.service import BaseService -import urllib.parse -import pathlib import mimetypes -import uuid + +from snek.system.service import BaseService + class ChannelAttachmentService(BaseService): - mapper_name="channel_attachment" + mapper_name = "channel_attachment" async def create_file(self, channel_uid, user_uid, name): attachment = await self.new() attachment["channel_uid"] = channel_uid - attachment['user_uid'] = user_uid + attachment["user_uid"] = user_uid attachment["name"] = name attachment["mime_type"] = mimetypes.guess_type(name)[0] - attachment['resource_type'] = "file" + attachment["resource_type"] = "file" real_file_name = f"{attachment['uid']}-{name}" - attachment["relative_url"] = (f"{attachment['uid']}-{name}") - attachment_folder = await self.services.channel.get_attachment_folder(channel_uid) + attachment["relative_url"] = f"{attachment['uid']}-{name}" + attachment_folder = await self.services.channel.get_attachment_folder( + channel_uid + ) attachment_path = attachment_folder.joinpath(real_file_name) attachment["path"] = str(attachment_path) if await self.save(attachment): return attachment raise Exception(f"Failed to create channel attachment: {attachment.errors}.") - diff --git a/src/snek/service/channel_message.py b/src/snek/service/channel_message.py index eefff96..6d70da7 100644 --- a/src/snek/service/channel_message.py +++ b/src/snek/service/channel_message.py @@ -11,7 +11,7 @@ class ChannelMessageService(BaseService): model["channel_uid"] = channel_uid model["user_uid"] = user_uid model["message"] = message - model['is_final'] = is_final + model["is_final"] = is_final context = {} @@ -56,7 +56,7 @@ class ChannelMessageService(BaseService): async def save(self, model): context = {} context.update(model.record) - user = await self.app.services.user.get(model['user_uid']) + user = await self.app.services.user.get(model["user_uid"]) context.update( { "user_uid": user["uid"], diff --git a/src/snek/service/chat.py b/src/snek/service/chat.py index f37f6c7..9130292 100644 --- a/src/snek/service/chat.py +++ b/src/snek/service/chat.py @@ -13,13 +13,13 @@ class ChatService(BaseService): channel["last_message_on"] = now() await self.services.channel.save(channel) await self.services.socket.broadcast( - channel['uid'], + channel["uid"], { "message": channel_message["message"], "html": channel_message["html"], - "user_uid": user['uid'], + "user_uid": user["uid"], "color": user["color"], - "channel_uid": channel['uid'], + "channel_uid": channel["uid"], "created_at": channel_message["created_at"], "updated_at": channel_message["updated_at"], "username": user["username"], @@ -32,14 +32,12 @@ class ChatService(BaseService): self.services.notification.create_channel_message(message_uid) ) - - async def send(self, user_uid, channel_uid, message, is_final=True): channel = await self.services.channel.get(uid=channel_uid) if not channel: raise Exception("Channel not found.") channel_message = await self.services.channel_message.create( - channel_uid, user_uid, message,is_final + channel_uid, user_uid, message, is_final ) channel_message_uid = channel_message["uid"] diff --git a/src/snek/service/container.py b/src/snek/service/container.py index 03c63aa..9d2706c 100644 --- a/src/snek/service/container.py +++ b/src/snek/service/container.py @@ -1,36 +1,51 @@ +from snek.system.docker import ComposeFileManager from snek.system.service import BaseService -from snek.system.docker import ComposeFileManager + class ContainerService(BaseService): mapper_name = "container" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + self.compose_path = "snek-container-compose.yml" self.compose = ComposeFileManager(self.compose_path) - - async def get_instances(self): return list(self.compose.list_instances()) async def get_container_name(self, channel_uid): return f"channel-{channel_uid}" - async def create(self, channel_uid, image="ubuntu:latest", command=None, cpus=1, memory='1024m', ports=None, volumes=None): + async def create( + self, + channel_uid, + image="ubuntu:latest", + command=None, + cpus=1, + memory="1024m", + ports=None, + volumes=None, + ): name = await self.get_container_name(channel_uid) self.compose.create_instance( - name, - image, - command, - cpus, - memory, - ports, - ["./"+str((await self.services.channel.get_home_folder(channel_uid))) +":"+ "/home/ubuntu"] + name, + image, + command, + cpus, + memory, + ports, + [ + "./" + + str(await self.services.channel.get_home_folder(channel_uid)) + + ":" + + "/home/ubuntu" + ], ) - async def create2(self, id, name, status, resources=None, user_uid=None, path=None, readonly=False): + async def create2( + self, id, name, status, resources=None, user_uid=None, path=None, readonly=False + ): model = await self.new() model["id"] = id model["name"] = name diff --git a/src/snek/service/db.py b/src/snek/service/db.py index fe9fb8a..8e583e9 100644 --- a/src/snek/service/db.py +++ b/src/snek/service/db.py @@ -1,23 +1,21 @@ -from snek.system.service import BaseService -import dataset -import uuid +import dataset + +from snek.system.service import BaseService -from datetime import datetime class DBService(BaseService): - + async def get_db(self, user_uid): - + home_folder = await self.app.services.user.get_home_folder(user_uid) home_folder.mkdir(parents=True, exist_ok=True) db_path = home_folder.joinpath("snek/user.db") db_path.parent.mkdir(parents=True, exist_ok=True) - return dataset.connect("sqlite:///" + str(db_path)) - + return dataset.connect("sqlite:///" + str(db_path)) + async def insert(self, user_uid, table_name, values): db = await self.get_db(user_uid) return db[table_name].insert(values) - async def update(self, user_uid, table_name, values, filters): db = await self.get_db(user_uid) @@ -33,7 +31,7 @@ class DBService(BaseService): async def find(self, user_uid, table_name, kwargs): db = await self.get_db(user_uid) - kwargs['_limit'] = kwargs.get('_limit', 30) + kwargs["_limit"] = kwargs.get("_limit", 30) return [dict(row) for row in db[table_name].find(**kwargs)] async def get(self, user_uid, table_name, filters): @@ -45,14 +43,13 @@ class DBService(BaseService): except ValueError: return None - async def delete(self, user_uid, table_name, filters): db = await self.get_db(user_uid) if not filters: filters = {} return db[table_name].delete(**filters) - async def query(self, sql,values): + async def query(self, sql, values): db = await self.app.db return [dict(row) for row in db.query(sql, values or {})] @@ -62,8 +59,6 @@ class DBService(BaseService): filters = {} return bool(db[table_name].find_one(**filters)) - - async def count(self, user_uid, table_name, filters): db = await self.get_db(user_uid) if not filters: diff --git a/src/snek/service/notification.py b/src/snek/service/notification.py index 8aaa1a4..fa3cdc3 100644 --- a/src/snek/service/notification.py +++ b/src/snek/service/notification.py @@ -1,6 +1,7 @@ +from snek.system.markdown import strip_markdown from snek.system.model import now from snek.system.service import BaseService -from snek.system.markdown import strip_markdown + class NotificationService(BaseService): mapper_name = "notification" @@ -34,7 +35,7 @@ class NotificationService(BaseService): uid=channel_message_uid ) if not channel_message["is_final"]: - return + return user = await self.services.user.get(uid=channel_message["user_uid"]) self.app.db.begin() async for channel_member in self.services.channel_member.find( diff --git a/src/snek/service/push.py b/src/snek/service/push.py index ccba4af..b4b72c9 100644 --- a/src/snek/service/push.py +++ b/src/snek/service/push.py @@ -1,25 +1,23 @@ +import base64 import json - -import aiohttp -from snek.system.service import BaseService +import os.path import random import time -import base64 import uuid -from functools import cache from pathlib import Path - -import jwt -from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.backends import default_backend from urllib.parse import urlparse -import os.path +import aiohttp +import jwt +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.hashes import SHA256 from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from snek.system.service import BaseService + # The only reason to persist the keys is to be able to use them in the web push PRIVATE_KEY_FILE = Path("./notification-private.pem") @@ -268,4 +266,4 @@ class PushService(BaseService): raise Exception( f"Failed to register push subscription for user {user_uid} with endpoint {endpoint}" - ) \ No newline at end of file + ) diff --git a/src/snek/service/repository.py b/src/snek/service/repository.py index 120c232..def3257 100644 --- a/src/snek/service/repository.py +++ b/src/snek/service/repository.py @@ -1,13 +1,17 @@ -from snek.system.service import BaseService -import asyncio +import asyncio import shutil +from snek.system.service import BaseService + + class RepositoryService(BaseService): mapper_name = "repository" async def delete(self, user_uid, name): loop = asyncio.get_event_loop() - repository_path = (await self.services.user.get_repository_path(user_uid)).joinpath(name) + repository_path = ( + await self.services.user.get_repository_path(user_uid) + ).joinpath(name) try: await loop.run_in_executor(None, shutil.rmtree, repository_path) except Exception as ex: @@ -15,7 +19,6 @@ class RepositoryService(BaseService): await super().delete(user_uid=user_uid, name=name) - async def exists(self, user_uid, name, **kwargs): kwargs["user_uid"] = user_uid kwargs["name"] = name @@ -29,18 +32,16 @@ class RepositoryService(BaseService): repository_path = str(repository_path) if not repository_path.endswith(".git"): repository_path += ".git" - command = ['git', 'init', '--bare', repository_path] + command = ["git", "init", "--bare", repository_path] process = await asyncio.subprocess.create_subprocess_exec( - *command, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await process.communicate() return process.returncode == 0 - async def create(self, user_uid, name,is_private=False): + async def create(self, user_uid, name, is_private=False): if await self.exists(user_uid=user_uid, name=name): - return False + return False if not await self.init(user_uid=user_uid, name=name): return False diff --git a/src/snek/service/socket.py b/src/snek/service/socket.py index d0ff024..ea002dc 100644 --- a/src/snek/service/socket.py +++ b/src/snek/service/socket.py @@ -1,12 +1,14 @@ +import asyncio +import logging +from datetime import datetime + from snek.model.user import UserModel from snek.system.service import BaseService -from datetime import datetime -import json -import asyncio -import logging + logger = logging.getLogger(__name__) from snek.system.model import now + class SocketService(BaseService): class Socket: @@ -15,7 +17,6 @@ class SocketService(BaseService): self.is_connected = True self.user = user - async def send_json(self, data): if not self.is_connected: return False @@ -40,7 +41,6 @@ class SocketService(BaseService): self.users = {} self.subscriptions = {} self.last_update = str(datetime.now()) - async def user_availability_service(self): logger.info("User availability update service started.") @@ -50,14 +50,15 @@ class SocketService(BaseService): for s in self.sockets: if not s.is_connected: continue - if not s.user in users_updated: + if s.user not in users_updated: s.user["last_ping"] = now() await self.app.services.user.save(s.user) users_updated.append(s.user) - logger.info(f"Updated user availability for {len(users_updated)} online users.") + logger.info( + f"Updated user availability for {len(users_updated)} online users." + ) await asyncio.sleep(60) - async def add(self, ws, user_uid): s = self.Socket(ws, await self.app.services.user.get(uid=user_uid)) self.sockets.add(s) @@ -81,7 +82,6 @@ class SocketService(BaseService): count += 1 return count - async def broadcast(self, channel_uid, message): await self._broadcast(channel_uid, message) @@ -102,4 +102,3 @@ class SocketService(BaseService): await s.close() logger.info(f"Removed socket for user {s.user['username']}") self.sockets.remove(s) - diff --git a/src/snek/service/user.py b/src/snek/service/user.py index 13f660f..34dc468 100644 --- a/src/snek/service/user.py +++ b/src/snek/service/user.py @@ -32,14 +32,14 @@ class UserService(BaseService): user["color"] = await self.services.util.random_light_hex_color() return await super().save(user) - def authenticate_sync(self,username,password): + def authenticate_sync(self, username, password): user = self.get_by_username_sync(username) - + if not user: return False if not security.verify_sync(password, user["password"]): return False - return True + return True async def authenticate(self, username, password): success = await self.validate_login(username, password) @@ -61,14 +61,12 @@ class UserService(BaseService): return None return path - - async def get_template_path(self, user_uid): path = pathlib.Path(f"./drive/{user_uid}/snek/templates") if not path.exists(): return None return path - + def get_by_username_sync(self, username): user = self.mapper.db["user"].find_one(username=username, deleted_at=None) return dict(user) diff --git a/src/snek/sgit.py b/src/snek/sgit.py index 3cbdcf6..4c2b6ac 100644 --- a/src/snek/sgit.py +++ b/src/snek/sgit.py @@ -1,48 +1,51 @@ +import asyncio +import base64 +import json +import logging import os -import aiohttp +import shutil +import tempfile + from aiohttp import web -import shutil -import json -import tempfile -import asyncio -import logging -import base64 -import pathlib -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger('git_server') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger("git_server") + class GitApplication(web.Application): def __init__(self, parent=None): - #import git - #globals()['git'] = git + # import git + # globals()['git'] = git self.parent = parent - super().__init__(client_max_size=1024*1024*1024*5) - self.add_routes([ - web.post('/create/{repo_name}', self.create_repository), - web.delete('/delete/{repo_name}', self.delete_repository), - web.get('/clone/{repo_name}', self.clone_repository), - web.post('/push/{repo_name}', self.push_repository), - web.post('/pull/{repo_name}', self.pull_repository), - web.get('/status/{repo_name}', self.status_repository), - #web.get('/list', self.list_repositories), - web.get('/branches/{repo_name}', self.list_branches), - web.post('/branches/{repo_name}', self.create_branch), - web.get('/log/{repo_name}', self.commit_log), - web.get('/file/{repo_name}/{file_path:.*}', self.file_content), - web.get('/{path:.+}/info/refs', self.git_smart_http), - web.post('/{path:.+}/git-upload-pack', self.git_smart_http), - web.post('/{path:.+}/git-receive-pack', self.git_smart_http), - web.get('/{repo_name}.git/info/refs', self.git_smart_http), - web.post('/{repo_name}.git/git-upload-pack', self.git_smart_http), - web.post('/{repo_name}.git/git-receive-pack', self.git_smart_http), - ]) - + super().__init__(client_max_size=1024 * 1024 * 1024 * 5) + self.add_routes( + [ + web.post("/create/{repo_name}", self.create_repository), + web.delete("/delete/{repo_name}", self.delete_repository), + web.get("/clone/{repo_name}", self.clone_repository), + web.post("/push/{repo_name}", self.push_repository), + web.post("/pull/{repo_name}", self.pull_repository), + web.get("/status/{repo_name}", self.status_repository), + # web.get('/list', self.list_repositories), + web.get("/branches/{repo_name}", self.list_branches), + web.post("/branches/{repo_name}", self.create_branch), + web.get("/log/{repo_name}", self.commit_log), + web.get("/file/{repo_name}/{file_path:.*}", self.file_content), + web.get("/{path:.+}/info/refs", self.git_smart_http), + web.post("/{path:.+}/git-upload-pack", self.git_smart_http), + web.post("/{path:.+}/git-receive-pack", self.git_smart_http), + web.get("/{repo_name}.git/info/refs", self.git_smart_http), + web.post("/{repo_name}.git/git-upload-pack", self.git_smart_http), + web.post("/{repo_name}.git/git-receive-pack", self.git_smart_http), + ] + ) async def check_basic_auth(self, request): auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Basic "): - return None,None + return None, None encoded_creds = auth_header.split("Basic ")[1] decoded_creds = base64.b64decode(encoded_creds).decode() username, password = decoded_creds.split(":", 1) @@ -50,27 +53,31 @@ class GitApplication(web.Application): username=username, password=password ) if not request["user"]: - return None,None - request["repository_path"] = await self.parent.services.user.get_repository_path( - request["user"]["uid"] + return None, None + request["repository_path"] = ( + await self.parent.services.user.get_repository_path(request["user"]["uid"]) ) - return request["user"]['username'],request["repository_path"] - + return request["user"]["username"], request["repository_path"] @staticmethod def require_auth(handler): async def wrapped(self, request, *args, **kwargs): username, repository_path = await self.check_basic_auth(request) if not username or not repository_path: - return web.Response(status=401, headers={'WWW-Authenticate': 'Basic'}, text='Authentication required') - request['username'] = username - request['repository_path'] = repository_path + return web.Response( + status=401, + headers={"WWW-Authenticate": "Basic"}, + text="Authentication required", + ) + request["username"] = username + request["repository_path"] = repository_path return await handler(self, request, *args, **kwargs) + return wrapped def repo_path(self, repository_path, repo_name): - return repository_path.joinpath(repo_name + '.git') + return repository_path.joinpath(repo_name + ".git") def check_repo_exists(self, repository_path, repo_name): repo_dir = self.repo_path(repository_path, repo_name) @@ -80,10 +87,10 @@ class GitApplication(web.Application): @require_auth async def create_repository(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - repository_path = request['repository_path'] - if not repo_name or '/' in repo_name or '..' in repo_name: + username = request["username"] + repo_name = request.match_info["repo_name"] + repository_path = request["repository_path"] + if not repo_name or "/" in repo_name or ".." in repo_name: return web.Response(text="Invalid repository name", status=400) repo_dir = self.repo_path(repository_path, repo_name) if os.path.exists(repo_dir): @@ -98,9 +105,9 @@ class GitApplication(web.Application): @require_auth async def delete_repository(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - repository_path = request['repository_path'] + username = request["username"] + repo_name = request.match_info["repo_name"] + repository_path = request["repository_path"] error_response = self.check_repo_exists(repository_path, repo_name) if error_response: return error_response @@ -115,9 +122,9 @@ class GitApplication(web.Application): @require_auth async def clone_repository(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - repository_path = request['repository_path'] + request["username"] + repo_name = request.match_info["repo_name"] + repository_path = request["repository_path"] error_response = self.check_repo_exists(repository_path, repo_name) if error_response: return error_response @@ -126,15 +133,15 @@ class GitApplication(web.Application): response_data = { "repository": repo_name, "clone_command": f"git clone {clone_url}", - "clone_url": clone_url + "clone_url": clone_url, } return web.json_response(response_data) @require_auth async def push_repository(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - repository_path = request['repository_path'] + username = request["username"] + repo_name = request.match_info["repo_name"] + repository_path = request["repository_path"] error_response = self.check_repo_exists(repository_path, repo_name) if error_response: return error_response @@ -142,34 +149,40 @@ class GitApplication(web.Application): data = await request.json() except json.JSONDecodeError: return web.Response(text="Invalid JSON data", status=400) - commit_message = data.get('commit_message', 'Update from server') - branch = data.get('branch', 'main') - changes = data.get('changes', []) + commit_message = data.get("commit_message", "Update from server") + branch = data.get("branch", "main") + changes = data.get("changes", []) if not changes: return web.Response(text="No changes provided", status=400) with tempfile.TemporaryDirectory() as temp_dir: - temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) + temp_repo = git.Repo.clone_from( + self.repo_path(repository_path, repo_name), temp_dir + ) for change in changes: - file_path = os.path.join(temp_dir, change.get('file', '')) - content = change.get('content', '') + file_path = os.path.join(temp_dir, change.get("file", "")) + content = change.get("content", "") os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, 'w') as f: + with open(file_path, "w") as f: f.write(content) temp_repo.git.add(A=True) - if not temp_repo.config_reader().has_section('user'): - temp_repo.config_writer().set_value("user", "name", "Git Server").release() - temp_repo.config_writer().set_value("user", "email", "git@server.local").release() + if not temp_repo.config_reader().has_section("user"): + temp_repo.config_writer().set_value( + "user", "name", "Git Server" + ).release() + temp_repo.config_writer().set_value( + "user", "email", "git@server.local" + ).release() temp_repo.index.commit(commit_message) - origin = temp_repo.remote('origin') + origin = temp_repo.remote("origin") origin.push(refspec=f"{branch}:{branch}") logger.info(f"Pushed to repository: {repo_name} for user {username}") return web.Response(text=f"Successfully pushed changes to {repo_name}") @require_auth async def pull_repository(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - repository_path = request['repository_path'] + username = request["username"] + repo_name = request.match_info["repo_name"] + repository_path = request["repository_path"] error_response = self.check_repo_exists(repository_path, repo_name) if error_response: return error_response @@ -177,13 +190,15 @@ class GitApplication(web.Application): data = await request.json() except json.JSONDecodeError: data = {} - remote_url = data.get('remote_url') - branch = data.get('branch', 'main') + remote_url = data.get("remote_url") + branch = data.get("branch", "main") if not remote_url: return web.Response(text="Remote URL is required", status=400) with tempfile.TemporaryDirectory() as temp_dir: try: - local_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) + local_repo = git.Repo.clone_from( + self.repo_path(repository_path, repo_name), temp_dir + ) remote_name = "pull_source" try: remote = local_repo.create_remote(remote_name, remote_url) @@ -192,38 +207,46 @@ class GitApplication(web.Application): remote.set_url(remote_url) remote.fetch() local_repo.git.merge(f"{remote_name}/{branch}") - origin = local_repo.remote('origin') + origin = local_repo.remote("origin") origin.push() - logger.info(f"Pulled to repository {repo_name} from {remote_url} for user {username}") - return web.Response(text=f"Successfully pulled changes from {remote_url} to {repo_name}") + logger.info( + f"Pulled to repository {repo_name} from {remote_url} for user {username}" + ) + return web.Response( + text=f"Successfully pulled changes from {remote_url} to {repo_name}" + ) except Exception as e: logger.error(f"Error pulling to {repo_name}: {str(e)}") return web.Response(text=f"Error pulling changes: {str(e)}", status=500) @require_auth async def status_repository(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - repository_path = request['repository_path'] + request["username"] + repo_name = request.match_info["repo_name"] + repository_path = request["repository_path"] error_response = self.check_repo_exists(repository_path, repo_name) if error_response: return error_response with tempfile.TemporaryDirectory() as temp_dir: try: - temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) + temp_repo = git.Repo.clone_from( + self.repo_path(repository_path, repo_name), temp_dir + ) branches = [b.name for b in temp_repo.branches] active_branch = temp_repo.active_branch.name commits = [] for commit in list(temp_repo.iter_commits(max_count=5)): - commits.append({ - "id": commit.hexsha, - "author": f"{commit.author.name} <{commit.author.email}>", - "date": commit.committed_datetime.isoformat(), - "message": commit.message - }) + commits.append( + { + "id": commit.hexsha, + "author": f"{commit.author.name} <{commit.author.email}>", + "date": commit.committed_datetime.isoformat(), + "message": commit.message, + } + ) files = [] for root, dirs, filenames in os.walk(temp_dir): - if '.git' in root: + if ".git" in root: continue for filename in filenames: full_path = os.path.join(root, filename) @@ -234,50 +257,58 @@ class GitApplication(web.Application): "branches": branches, "active_branch": active_branch, "recent_commits": commits, - "files": files + "files": files, } return web.json_response(status_info) except Exception as e: logger.error(f"Error getting status for {repo_name}: {str(e)}") - return web.Response(text=f"Error getting repository status: {str(e)}", status=500) + return web.Response( + text=f"Error getting repository status: {str(e)}", status=500 + ) @require_auth async def list_repositories(self, request): - username = request['username'] + request["username"] try: repos = [] user_dir = self.REPO_DIR if os.path.exists(user_dir): for item in os.listdir(user_dir): item_path = os.path.join(user_dir, item) - if os.path.isdir(item_path) and item.endswith('.git'): + if os.path.isdir(item_path) and item.endswith(".git"): repos.append(item[:-4]) - if request.query.get('format') == 'json': + if request.query.get("format") == "json": return web.json_response({"repositories": repos}) else: - return web.Response(text="\n".join(repos) if repos else "No repositories found") + return web.Response( + text="\n".join(repos) if repos else "No repositories found" + ) except Exception as e: logger.error(f"Error listing repositories: {str(e)}") - return web.Response(text=f"Error listing repositories: {str(e)}", status=500) + return web.Response( + text=f"Error listing repositories: {str(e)}", status=500 + ) @require_auth async def list_branches(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - repository_path = request['repository_path'] + request["username"] + repo_name = request.match_info["repo_name"] + repository_path = request["repository_path"] error_response = self.check_repo_exists(repository_path, repo_name) if error_response: return error_response with tempfile.TemporaryDirectory() as temp_dir: - temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) + temp_repo = git.Repo.clone_from( + self.repo_path(repository_path, repo_name), temp_dir + ) branches = [b.name for b in temp_repo.branches] return web.json_response({"branches": branches}) @require_auth async def create_branch(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - repository_path = request['repository_path'] + username = request["username"] + repo_name = request.match_info["repo_name"] + repository_path = request["repository_path"] error_response = self.check_repo_exists(repository_path, repo_name) if error_response: return error_response @@ -285,145 +316,168 @@ class GitApplication(web.Application): data = await request.json() except json.JSONDecodeError: return web.Response(text="Invalid JSON data", status=400) - branch_name = data.get('branch_name') - start_point = data.get('start_point', 'HEAD') + branch_name = data.get("branch_name") + start_point = data.get("start_point", "HEAD") if not branch_name: return web.Response(text="Branch name is required", status=400) with tempfile.TemporaryDirectory() as temp_dir: try: - temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) + temp_repo = git.Repo.clone_from( + self.repo_path(repository_path, repo_name), temp_dir + ) temp_repo.git.branch(branch_name, start_point) - temp_repo.git.push('origin', branch_name) - logger.info(f"Created branch {branch_name} in repository {repo_name} for user {username}") + temp_repo.git.push("origin", branch_name) + logger.info( + f"Created branch {branch_name} in repository {repo_name} for user {username}" + ) return web.Response(text=f"Created branch {branch_name}") except Exception as e: - logger.error(f"Error creating branch {branch_name} in {repo_name}: {str(e)}") + logger.error( + f"Error creating branch {branch_name} in {repo_name}: {str(e)}" + ) return web.Response(text=f"Error creating branch: {str(e)}", status=500) @require_auth async def commit_log(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - repository_path = request['repository_path'] + request["username"] + repo_name = request.match_info["repo_name"] + repository_path = request["repository_path"] error_response = self.check_repo_exists(repository_path, repo_name) if error_response: return error_response try: - limit = int(request.query.get('limit', 10)) - branch = request.query.get('branch', 'main') + limit = int(request.query.get("limit", 10)) + branch = request.query.get("branch", "main") except ValueError: return web.Response(text="Invalid limit parameter", status=400) with tempfile.TemporaryDirectory() as temp_dir: try: - temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) + temp_repo = git.Repo.clone_from( + self.repo_path(repository_path, repo_name), temp_dir + ) commits = [] try: for commit in list(temp_repo.iter_commits(branch, max_count=limit)): - commits.append({ - "id": commit.hexsha, - "short_id": commit.hexsha[:7], - "author": f"{commit.author.name} <{commit.author.email}>", - "date": commit.committed_datetime.isoformat(), - "message": commit.message.strip() - }) + commits.append( + { + "id": commit.hexsha, + "short_id": commit.hexsha[:7], + "author": f"{commit.author.name} <{commit.author.email}>", + "date": commit.committed_datetime.isoformat(), + "message": commit.message.strip(), + } + ) except git.GitCommandError as e: if "unknown revision or path" in str(e): commits = [] else: raise - return web.json_response({ - "repository": repo_name, - "branch": branch, - "commits": commits - }) + return web.json_response( + {"repository": repo_name, "branch": branch, "commits": commits} + ) except Exception as e: logger.error(f"Error getting commit log for {repo_name}: {str(e)}") - return web.Response(text=f"Error getting commit log: {str(e)}", status=500) + return web.Response( + text=f"Error getting commit log: {str(e)}", status=500 + ) @require_auth async def file_content(self, request): - username = request['username'] - repo_name = request.match_info['repo_name'] - file_path = request.match_info.get('file_path', '') - branch = request.query.get('branch', 'main') - repository_path = request['repository_path'] + request["username"] + repo_name = request.match_info["repo_name"] + file_path = request.match_info.get("file_path", "") + branch = request.query.get("branch", "main") + repository_path = request["repository_path"] error_response = self.check_repo_exists(repository_path, repo_name) if error_response: return error_response with tempfile.TemporaryDirectory() as temp_dir: try: - temp_repo = git.Repo.clone_from(self.repo_path(repository_path, repo_name), temp_dir) + temp_repo = git.Repo.clone_from( + self.repo_path(repository_path, repo_name), temp_dir + ) try: temp_repo.git.checkout(branch) except git.GitCommandError: return web.Response(text=f"Branch '{branch}' not found", status=404) file_full_path = os.path.join(temp_dir, file_path) if not os.path.exists(file_full_path): - return web.Response(text=f"File '{file_path}' not found", status=404) + return web.Response( + text=f"File '{file_path}' not found", status=404 + ) if os.path.isdir(file_full_path): files = os.listdir(file_full_path) - return web.json_response({ - "repository": repo_name, - "path": file_path, - "type": "directory", - "contents": files - }) + return web.json_response( + { + "repository": repo_name, + "path": file_path, + "type": "directory", + "contents": files, + } + ) else: try: - with open(file_full_path, 'r') as f: + with open(file_full_path) as f: content = f.read() return web.Response(text=content) except UnicodeDecodeError: - return web.Response(text=f"Cannot display binary file content for '{file_path}'", status=400) + return web.Response( + text=f"Cannot display binary file content for '{file_path}'", + status=400, + ) except Exception as e: logger.error(f"Error getting file content from {repo_name}: {str(e)}") - return web.Response(text=f"Error getting file content: {str(e)}", status=500) + return web.Response( + text=f"Error getting file content: {str(e)}", status=500 + ) @require_auth async def git_smart_http(self, request): - username = request['username'] - repository_path = request['repository_path'] + request["username"] + repository_path = request["repository_path"] path = request.path + async def get_repository_path(): - req_path = path.lstrip('/') - if req_path.endswith('/info/refs'): - repo_name = req_path[:-len('/info/refs')] - elif req_path.endswith('/git-upload-pack'): - repo_name = req_path[:-len('/git-upload-pack')] - elif req_path.endswith('/git-receive-pack'): - repo_name = req_path[:-len('/git-receive-pack')] + req_path = path.lstrip("/") + if req_path.endswith("/info/refs"): + repo_name = req_path[: -len("/info/refs")] + elif req_path.endswith("/git-upload-pack"): + repo_name = req_path[: -len("/git-upload-pack")] + elif req_path.endswith("/git-receive-pack"): + repo_name = req_path[: -len("/git-receive-pack")] else: repo_name = req_path - if repo_name.endswith('.git'): + if repo_name.endswith(".git"): repo_name = repo_name[:-4] repo_name = repo_name[4:] repo_dir = repository_path.joinpath(repo_name + ".git") logger.info(f"Resolved repo path: {repo_dir}") - return repo_dir + return repo_dir + async def handle_info_refs(service): repo_path = await get_repository_path() - + logger.info(f"handle_info_refs: {repo_path}") if not os.path.exists(repo_path): return web.Response(text="Repository not found", status=404) - cmd = [service, '--stateless-rpc', '--advertise-refs', str(repo_path)] + cmd = [service, "--stateless-rpc", "--advertise-refs", str(repo_path)] try: process = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await process.communicate() if process.returncode != 0: logger.error(f"Git command failed: {stderr.decode()}") - return web.Response(text=f"Git error: {stderr.decode()}", status=500) + return web.Response( + text=f"Git error: {stderr.decode()}", status=500 + ) response = web.StreamResponse( status=200, - reason='OK', + reason="OK", headers={ - 'Content-Type': f'application/x-{service}-advertisement', - 'Cache-Control': 'no-cache' - } + "Content-Type": f"application/x-{service}-advertisement", + "Cache-Control": "no-cache", + }, ) await response.prepare(request) packet = f"# service={service}\n" @@ -435,48 +489,58 @@ class GitApplication(web.Application): except Exception as e: logger.error(f"Error handling info/refs: {str(e)}") return web.Response(text=f"Server error: {str(e)}", status=500) + async def handle_service_rpc(service): repo_path = await get_repository_path() logger.info(f"handle_service_rpc: {repo_path}") if not os.path.exists(repo_path): return web.Response(text="Repository not found", status=404) - if not request.headers.get('Content-Type') == f'application/x-{service}-request': + if ( + not request.headers.get("Content-Type") + == f"application/x-{service}-request" + ): return web.Response(text="Invalid Content-Type", status=403) body = await request.read() - cmd = [service, '--stateless-rpc', str(repo_path)] + cmd = [service, "--stateless-rpc", str(repo_path)] try: process = await asyncio.create_subprocess_exec( *cmd, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await process.communicate(input=body) if process.returncode != 0: logger.error(f"Git command failed: {stderr.decode()}") - return web.Response(text=f"Git error: {stderr.decode()}", status=500) + return web.Response( + text=f"Git error: {stderr.decode()}", status=500 + ) return web.Response( - body=stdout, - content_type=f'application/x-{service}-result' + body=stdout, content_type=f"application/x-{service}-result" ) except Exception as e: logger.error(f"Error handling service RPC: {str(e)}") return web.Response(text=f"Server error: {str(e)}", status=500) - if request.method == 'GET' and path.endswith('/info/refs'): - service = request.query.get('service') - if service in ('git-upload-pack', 'git-receive-pack'): + + if request.method == "GET" and path.endswith("/info/refs"): + service = request.query.get("service") + if service in ("git-upload-pack", "git-receive-pack"): return await handle_info_refs(service) else: - return web.Response(text="Smart HTTP requires service parameter", status=400) - elif request.method == 'POST' and '/git-upload-pack' in path: - return await handle_service_rpc('git-upload-pack') - elif request.method == 'POST' and '/git-receive-pack' in path: - return await handle_service_rpc('git-receive-pack') + return web.Response( + text="Smart HTTP requires service parameter", status=400 + ) + elif request.method == "POST" and "/git-upload-pack" in path: + return await handle_service_rpc("git-upload-pack") + elif request.method == "POST" and "/git-receive-pack" in path: + return await handle_service_rpc("git-receive-pack") return web.Response(text="Not found", status=404) -if __name__ == '__main__': + +if __name__ == "__main__": try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) logger.info("Using uvloop for improved performance") except ImportError: diff --git a/src/snek/snode.py b/src/snek/snode.py index ae42a1f..517d52a 100644 --- a/src/snek/snode.py +++ b/src/snek/snode.py @@ -2,28 +2,28 @@ import aiohttp ENABLED = False -import aiohttp import asyncio -from aiohttp import web - +import json import sqlite3 -import dataset +import aiohttp +from aiohttp import web from sqlalchemy import event from sqlalchemy.engine import Engine -import json - queue = asyncio.Queue() + class State: - do_not_sync = False + do_not_sync = False + async def sync_service(app): if not ENABLED: - return + return session = aiohttp.ClientSession() - async with session.ws_connect('http://localhost:3131/ws') as ws: + async with session.ws_connect("http://localhost:3131/ws") as ws: + async def receive(): queries_synced = 0 @@ -31,7 +31,7 @@ async def sync_service(app): if msg.type == aiohttp.WSMsgType.TEXT: try: data = json.loads(msg.data) - State.do_not_sync = True + State.do_not_sync = True app.db.execute(*data) app.db.commit() State.do_not_sync = False @@ -41,21 +41,25 @@ async def sync_service(app): await app.services.socket.broadcast_event() except Exception as e: print(e) - pass - #print(f"Received: {msg.data}") + pass + # print(f"Received: {msg.data}") elif msg.type == aiohttp.WSMsgType.ERROR: break + async def write(): while True: msg = await queue.get() - await ws.send_str(json.dumps(msg,default=str)) + await ws.send_str(json.dumps(msg, default=str)) queue.task_done() await asyncio.gather(receive(), write()) await session.close() + queries_queued = 0 + + # Attach a listener to log all executed statements @event.listens_for(Engine, "before_cursor_execute") def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): @@ -63,7 +67,7 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem return global queries_queued if State.do_not_sync: - print(statement,parameters) + print(statement, parameters) return if statement.startswith("SELECT"): return @@ -71,17 +75,18 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem queries_queued += 1 print("Queries queued: " + str(queries_queued)) + async def websocket_handler(request): - queries_broadcasted = 0 + queries_broadcasted = 0 ws = web.WebSocketResponse() await ws.prepare(request) - request.app['websockets'].append(ws) + request.app["websockets"].append(ws) async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: - for client in request.app['websockets']: + for client in request.app["websockets"]: if client != ws: await client.send_str(msg.data) - cursor = request.app['db'].cursor() + cursor = request.app["db"].cursor() data = json.loads(msg.data) queries_broadcasted += 1 @@ -89,28 +94,32 @@ async def websocket_handler(request): cursor.close() print("Queries broadcasted: " + str(queries_broadcasted)) elif msg.type == aiohttp.WSMsgType.ERROR: - print(f'WebSocket connection closed with exception {ws.exception()}') + print(f"WebSocket connection closed with exception {ws.exception()}") - request.app['websockets'].remove(ws) + request.app["websockets"].remove(ws) return ws -app = web.Application() -app['websockets'] = [] -app.router.add_get('/ws', websocket_handler) +app = web.Application() +app["websockets"] = [] + +app.router.add_get("/ws", websocket_handler) + async def on_startup(app): - app['db'] = sqlite3.connect('snek.db') + app["db"] = sqlite3.connect("snek.db") print("Server starting...") + async def on_cleanup(app): - for ws in app['websockets']: + for ws in app["websockets"]: await ws.close() - app['db'].close() + app["db"].close() + app.on_startup.append(on_startup) app.on_cleanup.append(on_cleanup) -if __name__ == '__main__': - web.run_app(app, host='127.0.0.1', port=3131) +if __name__ == "__main__": + web.run_app(app, host="127.0.0.1", port=3131) diff --git a/src/snek/sssh.py b/src/snek/sssh.py index 6b179b6..bb83fa9 100644 --- a/src/snek/sssh.py +++ b/src/snek/sssh.py @@ -1,25 +1,30 @@ -import asyncssh import logging from pathlib import Path -global _app +import asyncssh + +global _app + def set_app(app): global _app - _app = app + _app = app + def get_app(): return _app + logger = logging.getLogger(__name__) roots = {} + class SFTPServer(asyncssh.SFTPServer): - + def __init__(self, chan: asyncssh.SSHServerChannel): self.root = get_app().services.user.get_home_folder_by_username( - chan.get_extra_info('username') + chan.get_extra_info("username") ) self.root.mkdir(exist_ok=True) self.root = str(self.root) @@ -30,20 +35,22 @@ class SFTPServer(asyncssh.SFTPServer): logger.debug(f"Mapping client path {path} to {mapped_path}") return str(mapped_path).encode() + class SSHServer(asyncssh.SSHServer): def password_auth_supported(self): return True def validate_password(self, username, password): logger.debug(f"Validating credentials for user {username}") - result = get_app().services.user.authenticate_sync(username,password) + result = get_app().services.user.authenticate_sync(username, password) logger.info(f"Validating credentials for user {username}: {result}") return result -async def start_ssh_server(app,host,port): - set_app(app) + +async def start_ssh_server(app, host, port): + set_app(app) logger.info("Starting SFTP server setup") - + host_key_path = Path("drive") / ".ssh" / "sftp_server_key" host_key_path.parent.mkdir(exist_ok=True, parents=True) try: @@ -63,14 +70,12 @@ async def start_ssh_server(app,host,port): x = await asyncssh.listen( host=host, port=port, - #process_factory=handle_client, + # process_factory=handle_client, server_host_keys=[key], server_factory=SSHServer, - sftp_factory=SFTPServer + sftp_factory=SFTPServer, ) return x - except Exception as e: + except Exception: logger.warning(f"Failed to start SFTP server. Already running.") - pass - - + pass diff --git a/src/snek/system/docker.py b/src/snek/system/docker.py index d651950..add91e4 100644 --- a/src/snek/system/docker.py +++ b/src/snek/system/docker.py @@ -1,58 +1,71 @@ -import yaml import copy +import yaml + + class ComposeFileManager: - def __init__(self, compose_path='docker-compose.yml'): + def __init__(self, compose_path="docker-compose.yml"): self.compose_path = compose_path self._load() def _load(self): try: - with open(self.compose_path, 'r') as f: + with open(self.compose_path) as f: self.compose = yaml.safe_load(f) or {} except FileNotFoundError: - self.compose = {'version': '3', 'services': {}} + self.compose = {"version": "3", "services": {}} def _save(self): - with open(self.compose_path, 'w') as f: + with open(self.compose_path, "w") as f: yaml.dump(self.compose, f, default_flow_style=False) def list_instances(self): - return list(self.compose.get('services', {}).keys()) + return list(self.compose.get("services", {}).keys()) - def create_instance(self, name, image, command=None, cpus=None, memory=None, ports=None, volumes=None): + def create_instance( + self, + name, + image, + command=None, + cpus=None, + memory=None, + ports=None, + volumes=None, + ): service = { - 'image': image, + "image": image, } if command: - service['command'] = command + service["command"] = command if cpus or memory: - service['deploy'] = {'resources': {'limits': {}}} + service["deploy"] = {"resources": {"limits": {}}} if cpus: - service['deploy']['resources']['limits']['cpus'] = str(cpus) + service["deploy"]["resources"]["limits"]["cpus"] = str(cpus) if memory: - service['deploy']['resources']['limits']['memory'] = str(memory) + service["deploy"]["resources"]["limits"]["memory"] = str(memory) if ports: - service['ports'] = [f"{host}:{container}" for container, host in ports.items()] + service["ports"] = [ + f"{host}:{container}" for container, host in ports.items() + ] if volumes: - service['volumes'] = volumes + service["volumes"] = volumes - self.compose.setdefault('services', {})[name] = service + self.compose.setdefault("services", {})[name] = service self._save() def remove_instance(self, name): - if name in self.compose.get('services', {}): - del self.compose['services'][name] + if name in self.compose.get("services", {}): + del self.compose["services"][name] self._save() def get_instance(self, name): - return self.compose.get('services', {}).get(name) + return self.compose.get("services", {}).get(name) def duplicate_instance(self, name, new_name): orig = self.get_instance(name) if not orig: raise ValueError(f"No such instance: {name}") - self.compose['services'][new_name] = copy.deepcopy(orig) + self.compose["services"][new_name] = copy.deepcopy(orig) self._save() def update_instance(self, name, **kwargs): @@ -62,11 +75,12 @@ class ComposeFileManager: for k, v in kwargs.items(): if v is not None: service[k] = v - self.compose['services'][name] = service + self.compose["services"][name] = service self._save() # Storage size is not tracked in compose files; would need Docker API for that. + # Example usage: # mgr = ComposeFileManager() # mgr.create_instance('web', 'nginx:latest', cpus=1, memory='512m', ports={80:8080}, volumes=['./data:/data']) diff --git a/src/snek/system/mapper.py b/src/snek/system/mapper.py index a0e37d2..fef7784 100644 --- a/src/snek/system/mapper.py +++ b/src/snek/system/mapper.py @@ -1,6 +1,7 @@ DEFAULT_LIMIT = 30 +import asyncio import typing -import asyncio + from snek.system.model import BaseModel @@ -12,21 +13,21 @@ class BaseMapper: def __init__(self, app): self.app = app - self.semaphore = asyncio.Semaphore(1) + self.semaphore = asyncio.Semaphore(1) self.default_limit = self.__class__.default_limit @property def db(self): return self.app.db - @property + @property def loop(self): return asyncio.get_event_loop() async def run_in_executor(self, func, *args, **kwargs): async with self.semaphore: return func(*args, **kwargs) - #return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs)) + # return await self.loop.run_in_executor(None, lambda: func(*args, **kwargs)) async def new(self): return self.model_class(mapper=self, app=self.app) @@ -38,8 +39,8 @@ class BaseMapper: async def get(self, uid: str = None, **kwargs) -> BaseModel: if uid: kwargs["uid"] = uid - - record = await self.run_in_executor(self.table.find_one,**kwargs) + + record = await self.run_in_executor(self.table.find_one, **kwargs) if not record: return None record = dict(record) @@ -50,7 +51,7 @@ class BaseMapper: return await self.model_class.from_record(mapper=self, record=record) async def exists(self, **kwargs): - return await self.run_in_executor(self.table.exists,**kwargs) + return await self.run_in_executor(self.table.exists, **kwargs) async def count(self, **kwargs) -> int: return await self.run_in_executor(self.table.count, **kwargs) @@ -71,7 +72,7 @@ class BaseMapper: yield model async def query(self, sql, *args): - for record in await self.run_in_executor(self.db.query,sql, *args): + for record in await self.run_in_executor(self.db.query, sql, *args): yield dict(record) async def update(self, model): diff --git a/src/snek/system/markdown.py b/src/snek/system/markdown.py index a823b5f..e83ff38 100644 --- a/src/snek/system/markdown.py +++ b/src/snek/system/markdown.py @@ -1,5 +1,5 @@ # Original source: https://brandonjay.dev/posts/2021/render-markdown-html-in-python-with-jinja2 -import re +import re from types import SimpleNamespace from app.cache import time_cache_async @@ -13,19 +13,20 @@ from pygments.lexers import get_lexer_by_name def strip_markdown(md_text): - # Remove code blocks ( - md_text = re.sub(r'[\s\S]?```', '', md_text) - md_text = re.sub(r'^\s{4,}.$', '', md_text, flags=re.MULTILINE) - md_text = re.sub(r'^\s{0,3}#{1,6}\s+', '', md_text, flags=re.MULTILINE) - md_text = re.sub(r'!\[.?\]\(.?\)', '', md_text) - md_text = re.sub(r'\[([^\]]+)\]\(.?\)', r'\1', md_text) - md_text = re.sub(r'(\*|_){1,3}(.+?)\1{1,3}', r'\2', md_text) - md_text = re.sub(r'^\s{0,3}>+\s?', '', md_text, flags=re.MULTILINE) - md_text = re.sub(r'^(\s)(\-{3,}|_{3,}|\{3,})\s$', '', md_text, flags=re.MULTILINE) - md_text = re.sub(r'[`~>#+\-=]', '', md_text) - md_text = re.sub(r'\s+', ' ', md_text) + # Remove code blocks ( + md_text = re.sub(r"[\s\S]?```", "", md_text) + md_text = re.sub(r"^\s{4,}.$", "", md_text, flags=re.MULTILINE) + md_text = re.sub(r"^\s{0,3}#{1,6}\s+", "", md_text, flags=re.MULTILINE) + md_text = re.sub(r"!\[.?\]\(.?\)", "", md_text) + md_text = re.sub(r"\[([^\]]+)\]\(.?\)", r"\1", md_text) + md_text = re.sub(r"(\*|_){1,3}(.+?)\1{1,3}", r"\2", md_text) + md_text = re.sub(r"^\s{0,3}>+\s?", "", md_text, flags=re.MULTILINE) + md_text = re.sub(r"^(\s)(\-{3,}|_{3,}|\{3,})\s$", "", md_text, flags=re.MULTILINE) + md_text = re.sub(r"[`~>#+\-=]", "", md_text) + md_text = re.sub(r"\s+", " ", md_text) return md_text.strip() + class MarkdownRenderer(HTMLRenderer): _allow_harmful_protocols = False @@ -40,7 +41,7 @@ class MarkdownRenderer(HTMLRenderer): formatter = html.HtmlFormatter() self.env.globals["highlight_styles"] = formatter.get_style_defs() - #def _escape(self, str): + # def _escape(self, str): # return str ##escape(str) def get_lexer(self, lang, default="bash"): diff --git a/src/snek/system/middleware.py b/src/snek/system/middleware.py index 368e3ae..93e5eaf 100644 --- a/src/snek/system/middleware.py +++ b/src/snek/system/middleware.py @@ -6,9 +6,9 @@ # MIT License: This code is distributed under the MIT License. -from aiohttp import web import secrets +from aiohttp import web csp_policy = ( "default-src 'self'; " @@ -22,14 +22,16 @@ csp_policy = ( def generate_nonce(): return secrets.token_hex(16) + @web.middleware async def csp_middleware(request, handler): response = await handler(request) - return response - nonce = generate_nonce() - response.headers['Content-Security-Policy'] = csp_policy.format(nonce=nonce) return response + nonce = generate_nonce() + response.headers["Content-Security-Policy"] = csp_policy.format(nonce=nonce) + return response + @web.middleware async def no_cors_middleware(request, handler): diff --git a/src/snek/system/security.py b/src/snek/system/security.py index 4ead284..a5b9302 100644 --- a/src/snek/system/security.py +++ b/src/snek/system/security.py @@ -63,9 +63,11 @@ def hash_sync(data: str, salt: str = DEFAULT_SALT) -> str: obj = hashlib.sha256(salted) return obj.hexdigest() + async def hash(data: str, salt: str = DEFAULT_SALT) -> str: return hash_sync(data, salt) + def verify_sync(string: str, hashed: str) -> bool: """Verify if the given string matches the hashed value. @@ -78,5 +80,6 @@ def verify_sync(string: str, hashed: str) -> bool: """ return hash_sync(string) == hashed + async def verify(string: str, hashed: str) -> bool: return verify_sync(string, hashed) diff --git a/src/snek/system/template.py b/src/snek/system/template.py index 3dd4357..81306e9 100644 --- a/src/snek/system/template.py +++ b/src/snek/system/template.py @@ -1,17 +1,16 @@ import mimetypes import re -from functools import lru_cache from types import SimpleNamespace -from urllib.parse import urlparse, parse_qs -from app.cache import time_cache +from urllib.parse import parse_qs, urlparse + +import bleach import emoji import requests +from app.cache import time_cache from bs4 import BeautifulSoup from jinja2 import TemplateSyntaxError, nodes from jinja2.ext import Extension from jinja2.nodes import Const -import bleach - emoji.EMOJI_DATA[''] = { "en": ":snek1:", @@ -81,13 +80,29 @@ emoji.EMOJI_DATA[ ALLOWED_TAGS = list(bleach.sanitizer.ALLOWED_TAGS) + [ - "img", "video", "audio", "source", "iframe", "picture", "span" + "img", + "video", + "audio", + "source", + "iframe", + "picture", + "span", ] ALLOWED_ATTRIBUTES = { **bleach.sanitizer.ALLOWED_ATTRIBUTES, "img": ["src", "alt", "title", "width", "height"], "a": ["href", "title", "target", "rel", "referrerpolicy", "class"], - "iframe": ["src", "width", "height", "frameborder", "allow", "allowfullscreen", "title", "referrerpolicy", "style"], + "iframe": [ + "src", + "width", + "height", + "frameborder", + "allow", + "allowfullscreen", + "title", + "referrerpolicy", + "style", + ], "video": ["src", "controls", "width", "height"], "audio": ["src", "controls"], "source": ["src", "type"], @@ -96,9 +111,6 @@ ALLOWED_ATTRIBUTES = { } - - - def sanitize_html(value): return bleach.clean( value, @@ -109,7 +121,6 @@ def sanitize_html(value): ) - def set_link_target_blank(text): soup = BeautifulSoup(text, "html.parser") @@ -118,26 +129,43 @@ def set_link_target_blank(text): element.attrs["rel"] = "noopener noreferrer" element.attrs["referrerpolicy"] = "no-referrer" element.attrs["href"] = element.attrs["href"].strip(".").strip(",") - + return str(soup) + SAFE_ATTRIBUTES = { - 'href', 'src', 'alt', 'title', 'width', 'height', 'style', 'id', 'class', - 'rel', 'type', 'name', 'value', 'placeholder', 'aria-hidden', 'aria-label', 'srcset' + "href", + "src", + "alt", + "title", + "width", + "height", + "style", + "id", + "class", + "rel", + "type", + "name", + "value", + "placeholder", + "aria-hidden", + "aria-label", + "srcset", } + def whitelist_attributes(html): - soup = BeautifulSoup(html, 'html.parser') + soup = BeautifulSoup(html, "html.parser") for tag in soup.find_all(): - if hasattr(tag, 'attrs'): - if tag.name in ['script','form','input']: - tag.replace_with('') + if hasattr(tag, "attrs"): + if tag.name in ["script", "form", "input"]: + tag.replace_with("") continue attrs = dict(tag.attrs) for attr in list(attrs): # Check if attribute is in the safe list or is a data-* attribute - if not (attr in SAFE_ATTRIBUTES or attr.startswith('data-')): + if not (attr in SAFE_ATTRIBUTES or attr.startswith("data-")): del tag.attrs[attr] return str(soup) @@ -236,12 +264,12 @@ def enrich_image_rendering(text): for element in soup.find_all("img"): if element.attrs["src"].startswith("/"): element.attrs["src"] += "?width=240&height=240" - picture_template = f''' + picture_template = f""" {element.attrs[ - ''' + """ element.replace_with(BeautifulSoup(picture_template, "html.parser")) return str(soup) @@ -285,7 +313,7 @@ def linkify_https(text): return set_link_target_blank(str(soup)) -@time_cache(timeout=60*60) +@time_cache(timeout=60 * 60) def get_url_content(url): try: response = requests.get(url, timeout=5) @@ -368,12 +396,11 @@ def embed_url(text): page_video = get_element_options(None, "video", "video", "video") page_audio = get_element_options(None, "audio", "audio", "audio") - preview_size = ( + ( get_element_options(None, None, None, "card") or "summary_large_image" ) - attachment_base = BeautifulSoup(str(element), "html.parser") attachments[original_link_name] = attachment_base @@ -412,20 +439,28 @@ def embed_url(text): ) description_element.append( - BeautifulSoup(f'{page_name}', "html.parser") + BeautifulSoup( + f'{page_name}', + "html.parser", + ) ) description_element.append( - BeautifulSoup(f"

{page_description or "No description available."}

", "html.parser") + BeautifulSoup( + f"

{page_description or "No description available."}

", + "html.parser", + ) ) description_element.append( - BeautifulSoup(f"", "html.parser") + BeautifulSoup( + f"", + "html.parser", + ) ) render_element.append(description_element_base) - for attachment in attachments.values(): soup.append(attachment) diff --git a/src/snek/system/websocket.py b/src/snek/system/websocket.py index c54b094..14d40fd 100644 --- a/src/snek/system/websocket.py +++ b/src/snek/system/websocket.py @@ -1,22 +1,24 @@ class WebSocketClient: def __init__(self, hostname, port): - self.buffer = b'' + self.buffer = b"" self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.hostname = hostname + self.hostname = hostname self.port = port - self.connect() - + self.connect() + def __getattr__(self, method, *args, **kwargs): if method in self.__dict__.keys(): return self.__dict__[method] + def call(*args, **kwargs): - self.write(json.dumps({'method': method, 'args': args, 'kwargs': kwargs})) + self.write(json.dumps({"method": method, "args": args, "kwargs": kwargs})) return json.loads(self.read()) + return call def connect(self): self.socket.connect((self.hostname, self.port)) - key = base64.b64encode(b'1234123412341234').decode('utf-8') + key = base64.b64encode(b"1234123412341234").decode("utf-8") handshake = ( f"GET /db HTTP/1.1\r\n" f"Host: localhost:3131\r\n" @@ -25,57 +27,58 @@ class WebSocketClient: f"Sec-WebSocket-Key: {key}\r\n" f"Sec-WebSocket-Version: 13\r\n\r\n" ) - self.socket.sendall(handshake.encode('utf-8')) - response = self.read_until(b'\r\n\r\n') - if b'101 Switching Protocols' not in response: + self.socket.sendall(handshake.encode("utf-8")) + response = self.read_until(b"\r\n\r\n") + if b"101 Switching Protocols" not in response: raise Exception("Failed to connect to WebSocket") def write(self, message): - message_bytes = message.encode('utf-8') + message_bytes = message.encode("utf-8") length = len(message_bytes) if length <= 125: - self.socket.sendall(b'\x81' + bytes([length]) + message_bytes) + self.socket.sendall(b"\x81" + bytes([length]) + message_bytes) elif length >= 126 and length <= 65535: - self.socket.sendall(b'\x81' + bytes([126]) + length.to_bytes(2, 'big') + message_bytes) + self.socket.sendall( + b"\x81" + bytes([126]) + length.to_bytes(2, "big") + message_bytes + ) else: - self.socket.sendall(b'\x81' + bytes([127]) + length.to_bytes(8, 'big') + message_bytes) - + self.socket.sendall( + b"\x81" + bytes([127]) + length.to_bytes(8, "big") + message_bytes + ) - def read_until(self, delimiter): + def read_until(self, delimiter): while True: find_pos = self.buffer.find(delimiter) if find_pos != -1: - data = self.buffer[:find_pos+4] - self.buffer = self.buffer[find_pos+4:] - return data - + data = self.buffer[: find_pos + 4] + self.buffer = self.buffer[find_pos + 4 :] + return data + chunk = self.socket.recv(1024) if not chunk: return None self.buffer += chunk - + def read_exactly(self, length): while len(self.buffer) < length: chunk = self.socket.recv(length - len(self.buffer)) if not chunk: return None - self.buffer += chunk - response = self.buffer[: length] + self.buffer += chunk + response = self.buffer[:length] self.buffer = self.buffer[length:] return response def read(self): - frame = None + frame = None frame = self.read_exactly(2) length = frame[1] & 127 if length == 126: - length = int.from_bytes(self.read_exactly(2), 'big') + length = int.from_bytes(self.read_exactly(2), "big") elif length == 127: - length = int.from_bytes(self.read_exactly(8), 'big') + length = int.from_bytes(self.read_exactly(8), "big") message = self.read_exactly(length) return message - + def close(self): self.socket.close() - - diff --git a/src/snek/view/avatar.py b/src/snek/view/avatar.py index d129f5e..c984ec1 100644 --- a/src/snek/view/avatar.py +++ b/src/snek/view/avatar.py @@ -31,21 +31,17 @@ from multiavatar import multiavatar from snek.system.view import BaseView from snek.view.avatar_animal import generate_avatar_with_options -import functools class AvatarView(BaseView): login_required = False - - def __init__(self, *args,**kwargs): - super().__init__(*args,**kwargs) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.avatars = {} async def get(self): type_ = self.request.query.get("type") - type_match = { - "animal": self.get_animal, - "default": self.get_default - } + type_match = {"animal": self.get_animal, "default": self.get_default} handler = type_match.get(type_, self.get_default) return await handler() @@ -64,10 +60,9 @@ class AvatarView(BaseView): avatar = generate_avatar_with_options(self.request.query) await self.app.set(key, avatar) return web.Response(text=avatar, content_type="image/svg+xml") - except Exception as e: + except Exception: pass - def _get(self, uid): avatar = generate_avatar_with_options(self.request.query) return avatar diff --git a/src/snek/view/avatar_animal.py b/src/snek/view/avatar_animal.py index 3a7f653..42c47d3 100644 --- a/src/snek/view/avatar_animal.py +++ b/src/snek/view/avatar_animal.py @@ -1,23 +1,61 @@ -import random -import math import argparse import json -from typing import Dict, List, Tuple, Optional, Union +import math +import random +from typing import Dict, List, Optional + class AnimalAvatarGenerator: """A generator for animal-themed avatar SVGs.""" - + # Constants ANIMALS = [ - "cat", "dog", "fox", "wolf", "bear", "panda", "koala", "tiger", "lion", - "rabbit", "monkey", "elephant", "giraffe", "zebra", "penguin", "owl", - "deer", "raccoon", "squirrel", "hedgehog", "otter", "frog" + "cat", + "dog", + "fox", + "wolf", + "bear", + "panda", + "koala", + "tiger", + "lion", + "rabbit", + "monkey", + "elephant", + "giraffe", + "zebra", + "penguin", + "owl", + "deer", + "raccoon", + "squirrel", + "hedgehog", + "otter", + "frog", ] - + COLOR_PALETTES = { "natural": { - "cat": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000", "#808080", "#FFFFFF", "#FFA500"], - "dog": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000", "#808080", "#FFFFFF", "#FFA500"], + "cat": [ + "#8B4513", + "#A0522D", + "#D2B48C", + "#F5DEB3", + "#000000", + "#808080", + "#FFFFFF", + "#FFA500", + ], + "dog": [ + "#8B4513", + "#A0522D", + "#D2B48C", + "#F5DEB3", + "#000000", + "#808080", + "#FFFFFF", + "#FFA500", + ], "fox": ["#FF6600", "#FF7F00", "#FF8C00", "#FFA500", "#FFFFFF", "#000000"], "wolf": ["#808080", "#A9A9A9", "#D3D3D3", "#FFFFFF", "#000000", "#696969"], "bear": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000", "#FFFFFF"], @@ -25,35 +63,68 @@ class AnimalAvatarGenerator: "koala": ["#808080", "#A9A9A9", "#D3D3D3", "#FFFFFF", "#000000"], "tiger": ["#FF8C00", "#FF7F00", "#FFFFFF", "#000000"], "lion": ["#DAA520", "#B8860B", "#CD853F", "#D2B48C", "#FFFFFF", "#000000"], - "rabbit": ["#FFFFFF", "#F5F5F5", "#D3D3D3", "#A9A9A9", "#FFC0CB", "#000000"], + "rabbit": [ + "#FFFFFF", + "#F5F5F5", + "#D3D3D3", + "#A9A9A9", + "#FFC0CB", + "#000000", + ], "monkey": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000"], "elephant": ["#808080", "#A9A9A9", "#D3D3D3", "#FFFFFF", "#000000"], "giraffe": ["#DAA520", "#B8860B", "#F5DEB3", "#FFFFFF", "#000000"], "zebra": ["#000000", "#FFFFFF"], "penguin": ["#000000", "#FFFFFF", "#FFA500"], - "owl": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000", "#FFFFFF", "#FFC0CB"], + "owl": [ + "#8B4513", + "#A0522D", + "#D2B48C", + "#F5DEB3", + "#000000", + "#FFFFFF", + "#FFC0CB", + ], "deer": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#FFFFFF", "#000000"], "raccoon": ["#808080", "#A9A9A9", "#D3D3D3", "#FFFFFF", "#000000"], "squirrel": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000"], "hedgehog": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000"], "otter": ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#000000"], - "frog": ["#008000", "#00FF00", "#ADFF2F", "#7FFF00", "#000000", "#FFFFFF"] + "frog": ["#008000", "#00FF00", "#ADFF2F", "#7FFF00", "#000000", "#FFFFFF"], }, "pastel": { - "all": ["#FFB6C1", "#FFD700", "#FFDAB9", "#98FB98", "#ADD8E6", "#DDA0DD", "#F0E68C", "#FFFFE0"] + "all": [ + "#FFB6C1", + "#FFD700", + "#FFDAB9", + "#98FB98", + "#ADD8E6", + "#DDA0DD", + "#F0E68C", + "#FFFFE0", + ] }, "vibrant": { - "all": ["#FF0000", "#00FF00", "#0000FF", "#FFFF00", "#FF00FF", "#00FFFF", "#FFA500", "#FF1493"] + "all": [ + "#FF0000", + "#00FF00", + "#0000FF", + "#FFFF00", + "#FF00FF", + "#00FFFF", + "#FFA500", + "#FF1493", + ] }, "mono": { "all": ["#000000", "#333333", "#666666", "#999999", "#CCCCCC", "#FFFFFF"] - } + }, } - + EYE_STYLES = ["round", "oval", "almond", "wide", "narrow", "cute"] - + FACE_SHAPES = ["round", "oval", "square", "heart", "triangular", "diamond"] - + EAR_STYLES = { "cat": ["pointed", "folded", "round", "large", "small"], "dog": ["floppy", "pointed", "round", "large", "small"], @@ -76,11 +147,11 @@ class AnimalAvatarGenerator: "squirrel": ["pointed", "small"], "hedgehog": ["round", "small"], "otter": ["round", "small"], - "frog": ["none"] + "frog": ["none"], } - + NOSE_STYLES = ["round", "triangular", "small", "large", "heart", "button"] - + SPECIAL_FEATURES = { "cat": ["whiskers", "stripes", "spots"], "dog": ["spots", "patch", "whiskers"], @@ -103,16 +174,16 @@ class AnimalAvatarGenerator: "squirrel": ["bushy_tail", "none"], "hedgehog": ["spikes", "none"], "otter": ["whiskers", "none"], - "frog": ["spots", "none"] + "frog": ["spots", "none"], } - + EXPRESSIONS = ["happy", "serious", "surprised", "sleepy", "wink"] - + def __init__(self, seed: Optional[int] = None): """Initialize the avatar generator with an optional seed for reproducibility.""" if seed is not None: random.seed(seed) - + def _get_colors(self, animal: str, color_palette: str) -> List[str]: """Get colors for the given animal and palette.""" if color_palette in self.COLOR_PALETTES: @@ -120,382 +191,790 @@ class AnimalAvatarGenerator: return self.COLOR_PALETTES[color_palette][animal] elif "all" in self.COLOR_PALETTES[color_palette]: return self.COLOR_PALETTES[color_palette]["all"] - + # Default to natural palette for the animal or general natural colors if animal in self.COLOR_PALETTES["natural"]: return self.COLOR_PALETTES["natural"][animal] - + # If no specific colors found, use a mix of browns and grays - return ["#8B4513", "#A0522D", "#D2B48C", "#F5DEB3", "#808080", "#A9A9A9", "#FFFFFF"] - + return [ + "#8B4513", + "#A0522D", + "#D2B48C", + "#F5DEB3", + "#808080", + "#A9A9A9", + "#FFFFFF", + ] + def _get_ear_style(self, animal: str) -> str: """Get a random ear style appropriate for the animal.""" if animal in self.EAR_STYLES: return random.choice(self.EAR_STYLES[animal]) return "none" # Default for animals not in the list - + def _get_special_feature(self, animal: str) -> str: """Get a random special feature appropriate for the animal.""" if animal in self.SPECIAL_FEATURES: return random.choice(self.SPECIAL_FEATURES[animal]) return "none" # Default for animals not in the list - - def _draw_circle(self, cx: float, cy: float, r: float, fill: str, - stroke: str = "none", stroke_width: float = 1.0) -> str: + + def _draw_circle( + self, + cx: float, + cy: float, + r: float, + fill: str, + stroke: str = "none", + stroke_width: float = 1.0, + ) -> str: """Generate SVG for a circle.""" return f'' - - def _draw_ellipse(self, cx: float, cy: float, rx: float, ry: float, - fill: str, stroke: str = "none", stroke_width: float = 1.0) -> str: + + def _draw_ellipse( + self, + cx: float, + cy: float, + rx: float, + ry: float, + fill: str, + stroke: str = "none", + stroke_width: float = 1.0, + ) -> str: """Generate SVG for an ellipse.""" return f'' - - def _draw_path(self, d: str, fill: str, stroke: str = "none", - stroke_width: float = 1.0, stroke_linecap: str = "round") -> str: + + def _draw_path( + self, + d: str, + fill: str, + stroke: str = "none", + stroke_width: float = 1.0, + stroke_linecap: str = "round", + ) -> str: """Generate SVG for a path.""" return f'' - - def _draw_polygon(self, points: str, fill: str, stroke: str = "none", - stroke_width: float = 1.0) -> str: + + def _draw_polygon( + self, points: str, fill: str, stroke: str = "none", stroke_width: float = 1.0 + ) -> str: """Generate SVG for a polygon.""" return f'' - - def _draw_face(self, animal: str, face_shape: str, face_color: str, - x: float, y: float, size: float) -> str: + + def _draw_face( + self, + animal: str, + face_shape: str, + face_color: str, + x: float, + y: float, + size: float, + ) -> str: """Draw the animal's face based on face shape.""" elements = [] - + if face_shape == "round": elements.append(self._draw_circle(x, y, size * 0.4, face_color)) elif face_shape == "oval": - elements.append(self._draw_ellipse(x, y, size * 0.35, size * 0.45, face_color)) + elements.append( + self._draw_ellipse(x, y, size * 0.35, size * 0.45, face_color) + ) elif face_shape == "square": - points = (f"{x-size*0.35},{y-size*0.35} {x+size*0.35},{y-size*0.35} " - f"{x+size*0.35},{y+size*0.35} {x-size*0.35},{y+size*0.35}") + points = ( + f"{x-size*0.35},{y-size*0.35} {x+size*0.35},{y-size*0.35} " + f"{x+size*0.35},{y+size*0.35} {x-size*0.35},{y+size*0.35}" + ) elements.append(self._draw_polygon(points, face_color)) elif face_shape == "heart": # Create a heart shape using paths cx, cy = x, y + size * 0.05 r = size * 0.2 - path = (f"M {cx} {cy-r*0.4} " - f"C {cx-r*1.5} {cy-r*1.5}, {cx-r*2} {cy+r*0.5}, {cx} {cy+r} " - f"C {cx+r*2} {cy+r*0.5}, {cx+r*1.5} {cy-r*1.5}, {cx} {cy-r*0.4} Z") + path = ( + f"M {cx} {cy-r*0.4} " + f"C {cx-r*1.5} {cy-r*1.5}, {cx-r*2} {cy+r*0.5}, {cx} {cy+r} " + f"C {cx+r*2} {cy+r*0.5}, {cx+r*1.5} {cy-r*1.5}, {cx} {cy-r*0.4} Z" + ) elements.append(self._draw_path(path, face_color)) elif face_shape == "triangular": - points = f"{x},{y-size*0.4} {x+size*0.4},{y+size*0.3} {x-size*0.4},{y+size*0.3}" + points = ( + f"{x},{y-size*0.4} {x+size*0.4},{y+size*0.3} {x-size*0.4},{y+size*0.3}" + ) elements.append(self._draw_polygon(points, face_color)) elif face_shape == "diamond": - points = f"{x},{y-size*0.4} {x+size*0.35},{y} {x},{y+size*0.4} {x-size*0.35},{y}" + points = ( + f"{x},{y-size*0.4} {x+size*0.35},{y} {x},{y+size*0.4} {x-size*0.35},{y}" + ) elements.append(self._draw_polygon(points, face_color)) - + return "\n".join(elements) - - def _draw_ears(self, animal: str, ear_style: str, face_color: str, - inner_color: str, x: float, y: float, size: float) -> str: + + def _draw_ears( + self, + animal: str, + ear_style: str, + face_color: str, + inner_color: str, + x: float, + y: float, + size: float, + ) -> str: """Draw the animal's ears based on ear style.""" elements = [] - + if ear_style == "none": return "" - + if ear_style == "pointed": # Left ear points_left = f"{x-size*0.2},{y-size*0.1} {x-size*0.35},{y-size*0.45} {x-size*0.05},{y-size*0.15}" elements.append(self._draw_polygon(points_left, face_color)) - + # Right ear points_right = f"{x+size*0.2},{y-size*0.1} {x+size*0.35},{y-size*0.45} {x+size*0.05},{y-size*0.15}" elements.append(self._draw_polygon(points_right, face_color)) - + # Inner ears points_inner_left = f"{x-size*0.2},{y-size*0.13} {x-size*0.3},{y-size*0.38} {x-size*0.1},{y-size*0.17}" elements.append(self._draw_polygon(points_inner_left, inner_color)) - + points_inner_right = f"{x+size*0.2},{y-size*0.13} {x+size*0.3},{y-size*0.38} {x+size*0.1},{y-size*0.17}" elements.append(self._draw_polygon(points_inner_right, inner_color)) - + elif ear_style == "round": # Left ear - elements.append(self._draw_circle(x-size*0.25, y-size*0.25, size*0.15, face_color)) - + elements.append( + self._draw_circle( + x - size * 0.25, y - size * 0.25, size * 0.15, face_color + ) + ) + # Right ear - elements.append(self._draw_circle(x+size*0.25, y-size*0.25, size*0.15, face_color)) - + elements.append( + self._draw_circle( + x + size * 0.25, y - size * 0.25, size * 0.15, face_color + ) + ) + # Inner ears - elements.append(self._draw_circle(x-size*0.25, y-size*0.25, size*0.08, inner_color)) - elements.append(self._draw_circle(x+size*0.25, y-size*0.25, size*0.08, inner_color)) - + elements.append( + self._draw_circle( + x - size * 0.25, y - size * 0.25, size * 0.08, inner_color + ) + ) + elements.append( + self._draw_circle( + x + size * 0.25, y - size * 0.25, size * 0.08, inner_color + ) + ) + elif ear_style == "folded" or ear_style == "floppy": # Left ear - path_left = (f"M {x-size*0.15} {y-size*0.15} " - f"C {x-size*0.3} {y-size*0.4}, {x-size*0.4} {y-size*0.2}, {x-size*0.35} {y}") - elements.append(self._draw_path(path_left, face_color, stroke="none", stroke_width=size*0.08)) - + path_left = ( + f"M {x-size*0.15} {y-size*0.15} " + f"C {x-size*0.3} {y-size*0.4}, {x-size*0.4} {y-size*0.2}, {x-size*0.35} {y}" + ) + elements.append( + self._draw_path( + path_left, face_color, stroke="none", stroke_width=size * 0.08 + ) + ) + # Right ear - path_right = (f"M {x+size*0.15} {y-size*0.15} " - f"C {x+size*0.3} {y-size*0.4}, {x+size*0.4} {y-size*0.2}, {x+size*0.35} {y}") - elements.append(self._draw_path(path_right, face_color, stroke="none", stroke_width=size*0.08)) - + path_right = ( + f"M {x+size*0.15} {y-size*0.15} " + f"C {x+size*0.3} {y-size*0.4}, {x+size*0.4} {y-size*0.2}, {x+size*0.35} {y}" + ) + elements.append( + self._draw_path( + path_right, face_color, stroke="none", stroke_width=size * 0.08 + ) + ) + elif ear_style == "long" or ear_style == "standing": # Left ear points_left = f"{x-size*0.2},{y-size*0.15} {x-size*0.3},{y-size*0.6} {x-size*0.05},{y-size*0.15}" elements.append(self._draw_polygon(points_left, face_color)) - + # Right ear points_right = f"{x+size*0.2},{y-size*0.15} {x+size*0.3},{y-size*0.6} {x+size*0.05},{y-size*0.15}" elements.append(self._draw_polygon(points_right, face_color)) - + # Inner ears points_inner_left = f"{x-size*0.18},{y-size*0.18} {x-size*0.25},{y-size*0.5} {x-size*0.1},{y-size*0.18}" elements.append(self._draw_polygon(points_inner_left, inner_color)) - + points_inner_right = f"{x+size*0.18},{y-size*0.18} {x+size*0.25},{y-size*0.5} {x+size*0.1},{y-size*0.18}" elements.append(self._draw_polygon(points_inner_right, inner_color)) - + elif ear_style == "large": # Left ear - elements.append(self._draw_ellipse(x-size*0.25, y-size*0.3, size*0.2, size*0.25, face_color)) - + elements.append( + self._draw_ellipse( + x - size * 0.25, y - size * 0.3, size * 0.2, size * 0.25, face_color + ) + ) + # Right ear - elements.append(self._draw_ellipse(x+size*0.25, y-size*0.3, size*0.2, size*0.25, face_color)) - + elements.append( + self._draw_ellipse( + x + size * 0.25, y - size * 0.3, size * 0.2, size * 0.25, face_color + ) + ) + # Inner ears - elements.append(self._draw_ellipse(x-size*0.25, y-size*0.3, size*0.1, size*0.15, inner_color)) - elements.append(self._draw_ellipse(x+size*0.25, y-size*0.3, size*0.1, size*0.15, inner_color)) - + elements.append( + self._draw_ellipse( + x - size * 0.25, + y - size * 0.3, + size * 0.1, + size * 0.15, + inner_color, + ) + ) + elements.append( + self._draw_ellipse( + x + size * 0.25, + y - size * 0.3, + size * 0.1, + size * 0.15, + inner_color, + ) + ) + elif ear_style == "small": # Left ear - elements.append(self._draw_ellipse(x-size*0.22, y-size*0.22, size*0.1, size*0.12, face_color)) - + elements.append( + self._draw_ellipse( + x - size * 0.22, + y - size * 0.22, + size * 0.1, + size * 0.12, + face_color, + ) + ) + # Right ear - elements.append(self._draw_ellipse(x+size*0.22, y-size*0.22, size*0.1, size*0.12, face_color)) - + elements.append( + self._draw_ellipse( + x + size * 0.22, + y - size * 0.22, + size * 0.1, + size * 0.12, + face_color, + ) + ) + # Inner ears - elements.append(self._draw_ellipse(x-size*0.22, y-size*0.22, size*0.05, size*0.06, inner_color)) - elements.append(self._draw_ellipse(x+size*0.22, y-size*0.22, size*0.05, size*0.06, inner_color)) - + elements.append( + self._draw_ellipse( + x - size * 0.22, + y - size * 0.22, + size * 0.05, + size * 0.06, + inner_color, + ) + ) + elements.append( + self._draw_ellipse( + x + size * 0.22, + y - size * 0.22, + size * 0.05, + size * 0.06, + inner_color, + ) + ) + elif ear_style == "tufted" and animal == "owl": # Left ear tuft points_left = f"{x-size*0.2},{y-size*0.15} {x-size*0.35},{y-size*0.45} {x-size*0.05},{y-size*0.15}" elements.append(self._draw_polygon(points_left, face_color)) - + # Right ear tuft points_right = f"{x+size*0.2},{y-size*0.15} {x+size*0.35},{y-size*0.45} {x+size*0.05},{y-size*0.15}" elements.append(self._draw_polygon(points_right, face_color)) - + elif ear_style == "wide" and animal == "elephant": # Left ear - path_left = (f"M {x-size*0.15} {y-size*0.1} " - f"C {x-size*0.5} {y-size*0.2}, {x-size*0.6} {y+size*0.2}, {x-size*0.15} {y+size*0.2}") - elements.append(self._draw_path(path_left, face_color, stroke=face_color, stroke_width=size*0.04)) - + path_left = ( + f"M {x-size*0.15} {y-size*0.1} " + f"C {x-size*0.5} {y-size*0.2}, {x-size*0.6} {y+size*0.2}, {x-size*0.15} {y+size*0.2}" + ) + elements.append( + self._draw_path( + path_left, face_color, stroke=face_color, stroke_width=size * 0.04 + ) + ) + # Right ear - path_right = (f"M {x+size*0.15} {y-size*0.1} " - f"C {x+size*0.5} {y-size*0.2}, {x+size*0.6} {y+size*0.2}, {x+size*0.15} {y+size*0.2}") - elements.append(self._draw_path(path_right, face_color, stroke=face_color, stroke_width=size*0.04)) - + path_right = ( + f"M {x+size*0.15} {y-size*0.1} " + f"C {x+size*0.5} {y-size*0.2}, {x+size*0.6} {y+size*0.2}, {x+size*0.15} {y+size*0.2}" + ) + elements.append( + self._draw_path( + path_right, face_color, stroke=face_color, stroke_width=size * 0.04 + ) + ) + return "\n".join(elements) - - def _draw_eyes(self, eye_style: str, expression: str, eye_color: str, - x: float, y: float, size: float) -> str: + + def _draw_eyes( + self, + eye_style: str, + expression: str, + eye_color: str, + x: float, + y: float, + size: float, + ) -> str: """Draw the animal's eyes based on eye style and expression.""" elements = [] eye_spacing = size * 0.2 - + if eye_style == "round": eye_size = size * 0.08 # Left eye - elements.append(self._draw_circle(x - eye_spacing, y - size * 0.05, eye_size, "#FFFFFF")) - elements.append(self._draw_circle(x - eye_spacing, y - size * 0.05, eye_size * 0.6, eye_color)) - + elements.append( + self._draw_circle(x - eye_spacing, y - size * 0.05, eye_size, "#FFFFFF") + ) + elements.append( + self._draw_circle( + x - eye_spacing, y - size * 0.05, eye_size * 0.6, eye_color + ) + ) + # Right eye - elements.append(self._draw_circle(x + eye_spacing, y - size * 0.05, eye_size, "#FFFFFF")) - elements.append(self._draw_circle(x + eye_spacing, y - size * 0.05, eye_size * 0.6, eye_color)) - + elements.append( + self._draw_circle(x + eye_spacing, y - size * 0.05, eye_size, "#FFFFFF") + ) + elements.append( + self._draw_circle( + x + eye_spacing, y - size * 0.05, eye_size * 0.6, eye_color + ) + ) + elif eye_style == "oval": eye_width = size * 0.1 eye_height = size * 0.07 # Left eye - elements.append(self._draw_ellipse(x - eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF")) - elements.append(self._draw_ellipse(x - eye_spacing, y - size * 0.05, eye_width * 0.6, eye_height * 0.6, eye_color)) - + elements.append( + self._draw_ellipse( + x - eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF" + ) + ) + elements.append( + self._draw_ellipse( + x - eye_spacing, + y - size * 0.05, + eye_width * 0.6, + eye_height * 0.6, + eye_color, + ) + ) + # Right eye - elements.append(self._draw_ellipse(x + eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF")) - elements.append(self._draw_ellipse(x + eye_spacing, y - size * 0.05, eye_width * 0.6, eye_height * 0.6, eye_color)) - + elements.append( + self._draw_ellipse( + x + eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF" + ) + ) + elements.append( + self._draw_ellipse( + x + eye_spacing, + y - size * 0.05, + eye_width * 0.6, + eye_height * 0.6, + eye_color, + ) + ) + elif eye_style == "almond": # Left eye - almond shape - path_left = (f"M {x-eye_spacing-size*0.1} {y-size*0.05} " - f"C {x-eye_spacing} {y-size*0.12}, {x-eye_spacing} {y+size*0.02}, {x-eye_spacing+size*0.1} {y-size*0.05} " - f"C {x-eye_spacing} {y+size*0.02}, {x-eye_spacing} {y-size*0.12}, {x-eye_spacing-size*0.1} {y-size*0.05} Z") + path_left = ( + f"M {x-eye_spacing-size*0.1} {y-size*0.05} " + f"C {x-eye_spacing} {y-size*0.12}, {x-eye_spacing} {y+size*0.02}, {x-eye_spacing+size*0.1} {y-size*0.05} " + f"C {x-eye_spacing} {y+size*0.02}, {x-eye_spacing} {y-size*0.12}, {x-eye_spacing-size*0.1} {y-size*0.05} Z" + ) elements.append(self._draw_path(path_left, "#FFFFFF")) - elements.append(self._draw_circle(x - eye_spacing, y - size * 0.05, size * 0.05, eye_color)) - + elements.append( + self._draw_circle( + x - eye_spacing, y - size * 0.05, size * 0.05, eye_color + ) + ) + # Right eye - almond shape - path_right = (f"M {x+eye_spacing-size*0.1} {y-size*0.05} " - f"C {x+eye_spacing} {y-size*0.12}, {x+eye_spacing} {y+size*0.02}, {x+eye_spacing+size*0.1} {y-size*0.05} " - f"C {x+eye_spacing} {y+size*0.02}, {x+eye_spacing} {y-size*0.12}, {x+eye_spacing-size*0.1} {y-size*0.05} Z") + path_right = ( + f"M {x+eye_spacing-size*0.1} {y-size*0.05} " + f"C {x+eye_spacing} {y-size*0.12}, {x+eye_spacing} {y+size*0.02}, {x+eye_spacing+size*0.1} {y-size*0.05} " + f"C {x+eye_spacing} {y+size*0.02}, {x+eye_spacing} {y-size*0.12}, {x+eye_spacing-size*0.1} {y-size*0.05} Z" + ) elements.append(self._draw_path(path_right, "#FFFFFF")) - elements.append(self._draw_circle(x + eye_spacing, y - size * 0.05, size * 0.05, eye_color)) - + elements.append( + self._draw_circle( + x + eye_spacing, y - size * 0.05, size * 0.05, eye_color + ) + ) + elif eye_style == "wide": eye_width = size * 0.12 eye_height = size * 0.08 # Left eye - elements.append(self._draw_ellipse(x - eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF")) - elements.append(self._draw_ellipse(x - eye_spacing, y - size * 0.05, eye_width * 0.5, eye_height * 0.6, eye_color)) - + elements.append( + self._draw_ellipse( + x - eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF" + ) + ) + elements.append( + self._draw_ellipse( + x - eye_spacing, + y - size * 0.05, + eye_width * 0.5, + eye_height * 0.6, + eye_color, + ) + ) + # Right eye - elements.append(self._draw_ellipse(x + eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF")) - elements.append(self._draw_ellipse(x + eye_spacing, y - size * 0.05, eye_width * 0.5, eye_height * 0.6, eye_color)) - + elements.append( + self._draw_ellipse( + x + eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF" + ) + ) + elements.append( + self._draw_ellipse( + x + eye_spacing, + y - size * 0.05, + eye_width * 0.5, + eye_height * 0.6, + eye_color, + ) + ) + elif eye_style == "narrow": eye_width = size * 0.12 eye_height = size * 0.04 # Left eye - elements.append(self._draw_ellipse(x - eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF")) - elements.append(self._draw_ellipse(x - eye_spacing, y - size * 0.05, eye_width * 0.4, eye_height * 0.7, eye_color)) - + elements.append( + self._draw_ellipse( + x - eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF" + ) + ) + elements.append( + self._draw_ellipse( + x - eye_spacing, + y - size * 0.05, + eye_width * 0.4, + eye_height * 0.7, + eye_color, + ) + ) + # Right eye - elements.append(self._draw_ellipse(x + eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF")) - elements.append(self._draw_ellipse(x + eye_spacing, y - size * 0.05, eye_width * 0.4, eye_height * 0.7, eye_color)) - + elements.append( + self._draw_ellipse( + x + eye_spacing, y - size * 0.05, eye_width, eye_height, "#FFFFFF" + ) + ) + elements.append( + self._draw_ellipse( + x + eye_spacing, + y - size * 0.05, + eye_width * 0.4, + eye_height * 0.7, + eye_color, + ) + ) + elif eye_style == "cute": eye_size = size * 0.1 # Left eye - elements.append(self._draw_circle(x - eye_spacing, y - size * 0.05, eye_size, "#FFFFFF")) - elements.append(self._draw_circle(x - eye_spacing - size * 0.02, y - size * 0.07, eye_size * 0.5, eye_color)) - elements.append(self._draw_circle(x - eye_spacing - size * 0.03, y - size * 0.08, eye_size * 0.15, "#FFFFFF")) - + elements.append( + self._draw_circle(x - eye_spacing, y - size * 0.05, eye_size, "#FFFFFF") + ) + elements.append( + self._draw_circle( + x - eye_spacing - size * 0.02, + y - size * 0.07, + eye_size * 0.5, + eye_color, + ) + ) + elements.append( + self._draw_circle( + x - eye_spacing - size * 0.03, + y - size * 0.08, + eye_size * 0.15, + "#FFFFFF", + ) + ) + # Right eye - elements.append(self._draw_circle(x + eye_spacing, y - size * 0.05, eye_size, "#FFFFFF")) - elements.append(self._draw_circle(x + eye_spacing - size * 0.02, y - size * 0.07, eye_size * 0.5, eye_color)) - elements.append(self._draw_circle(x + eye_spacing - size * 0.03, y - size * 0.08, eye_size * 0.15, "#FFFFFF")) - + elements.append( + self._draw_circle(x + eye_spacing, y - size * 0.05, eye_size, "#FFFFFF") + ) + elements.append( + self._draw_circle( + x + eye_spacing - size * 0.02, + y - size * 0.07, + eye_size * 0.5, + eye_color, + ) + ) + elements.append( + self._draw_circle( + x + eye_spacing - size * 0.03, + y - size * 0.08, + eye_size * 0.15, + "#FFFFFF", + ) + ) + # Apply expression if expression == "happy": # Close bottom half of eyes slightly - path_left = (f"M {x-eye_spacing-size*0.1} {y-size*0.03} " - f"Q {x-eye_spacing} {y+size*0.02}, {x-eye_spacing+size*0.1} {y-size*0.03}") - elements.append(self._draw_path(path_left, face_color, stroke="#000000", stroke_width=size*0.01)) - - path_right = (f"M {x+eye_spacing-size*0.1} {y-size*0.03} " - f"Q {x+eye_spacing} {y+size*0.02}, {x+eye_spacing+size*0.1} {y-size*0.03}") - elements.append(self._draw_path(path_right, face_color, stroke="#000000", stroke_width=size*0.01)) - + path_left = ( + f"M {x-eye_spacing-size*0.1} {y-size*0.03} " + f"Q {x-eye_spacing} {y+size*0.02}, {x-eye_spacing+size*0.1} {y-size*0.03}" + ) + elements.append( + self._draw_path( + path_left, face_color, stroke="#000000", stroke_width=size * 0.01 + ) + ) + + path_right = ( + f"M {x+eye_spacing-size*0.1} {y-size*0.03} " + f"Q {x+eye_spacing} {y+size*0.02}, {x+eye_spacing+size*0.1} {y-size*0.03}" + ) + elements.append( + self._draw_path( + path_right, face_color, stroke="#000000", stroke_width=size * 0.01 + ) + ) + elif expression == "serious": # Serious eyebrows - path_left = (f"M {x-eye_spacing-size*0.08} {y-size*0.12} " - f"L {x-eye_spacing+size*0.08} {y-size*0.15}") - elements.append(self._draw_path(path_left, "none", stroke="#000000", stroke_width=size*0.02)) - - path_right = (f"M {x+eye_spacing-size*0.08} {y-size*0.15} " - f"L {x+eye_spacing+size*0.08} {y-size*0.12}") - elements.append(self._draw_path(path_right, "none", stroke="#000000", stroke_width=size*0.02)) - + path_left = ( + f"M {x-eye_spacing-size*0.08} {y-size*0.12} " + f"L {x-eye_spacing+size*0.08} {y-size*0.15}" + ) + elements.append( + self._draw_path( + path_left, "none", stroke="#000000", stroke_width=size * 0.02 + ) + ) + + path_right = ( + f"M {x+eye_spacing-size*0.08} {y-size*0.15} " + f"L {x+eye_spacing+size*0.08} {y-size*0.12}" + ) + elements.append( + self._draw_path( + path_right, "none", stroke="#000000", stroke_width=size * 0.02 + ) + ) + elif expression == "surprised": # Raise eyebrows - path_left = (f"M {x-eye_spacing-size*0.08} {y-size*0.15} " - f"Q {x-eye_spacing} {y-size*0.18}, {x-eye_spacing+size*0.08} {y-size*0.15}") - elements.append(self._draw_path(path_left, "none", stroke="#000000", stroke_width=size*0.015)) - - path_right = (f"M {x+eye_spacing-size*0.08} {y-size*0.15} " - f"Q {x+eye_spacing} {y-size*0.18}, {x+eye_spacing+size*0.08} {y-size*0.15}") - elements.append(self._draw_path(path_right, "none", stroke="#000000", stroke_width=size*0.015)) - + path_left = ( + f"M {x-eye_spacing-size*0.08} {y-size*0.15} " + f"Q {x-eye_spacing} {y-size*0.18}, {x-eye_spacing+size*0.08} {y-size*0.15}" + ) + elements.append( + self._draw_path( + path_left, "none", stroke="#000000", stroke_width=size * 0.015 + ) + ) + + path_right = ( + f"M {x+eye_spacing-size*0.08} {y-size*0.15} " + f"Q {x+eye_spacing} {y-size*0.18}, {x+eye_spacing+size*0.08} {y-size*0.15}" + ) + elements.append( + self._draw_path( + path_right, "none", stroke="#000000", stroke_width=size * 0.015 + ) + ) + elif expression == "sleepy": # Half-closed eyes - path_left = (f"M {x-eye_spacing-size*0.1} {y-size*0.08} " - f"Q {x-eye_spacing} {y-size*0.01}, {x-eye_spacing+size*0.1} {y-size*0.08}") - elements.append(self._draw_path(path_left, "none", stroke="#000000", stroke_width=size*0.015)) - - path_right = (f"M {x+eye_spacing-size*0.1} {y-size*0.08} " - f"Q {x+eye_spacing} {y-size*0.01}, {x+eye_spacing+size*0.1} {y-size*0.08}") - elements.append(self._draw_path(path_right, "none", stroke="#000000", stroke_width=size*0.015)) - + path_left = ( + f"M {x-eye_spacing-size*0.1} {y-size*0.08} " + f"Q {x-eye_spacing} {y-size*0.01}, {x-eye_spacing+size*0.1} {y-size*0.08}" + ) + elements.append( + self._draw_path( + path_left, "none", stroke="#000000", stroke_width=size * 0.015 + ) + ) + + path_right = ( + f"M {x+eye_spacing-size*0.1} {y-size*0.08} " + f"Q {x+eye_spacing} {y-size*0.01}, {x+eye_spacing+size*0.1} {y-size*0.08}" + ) + elements.append( + self._draw_path( + path_right, "none", stroke="#000000", stroke_width=size * 0.015 + ) + ) + elif expression == "wink": # Right eye normal # Left eye winking - path_left = (f"M {x-eye_spacing-size*0.1} {y-size*0.05} " - f"Q {x-eye_spacing} {y-size*0.1}, {x-eye_spacing+size*0.1} {y-size*0.05}") - elements.append(self._draw_path(path_left, "none", stroke="#000000", stroke_width=size*0.02)) - + path_left = ( + f"M {x-eye_spacing-size*0.1} {y-size*0.05} " + f"Q {x-eye_spacing} {y-size*0.1}, {x-eye_spacing+size*0.1} {y-size*0.05}" + ) + elements.append( + self._draw_path( + path_left, "none", stroke="#000000", stroke_width=size * 0.02 + ) + ) + return "\n".join(elements) - - def _draw_nose(self, animal: str, nose_style: str, nose_color: str, - x: float, y: float, size: float) -> str: + + def _draw_nose( + self, + animal: str, + nose_style: str, + nose_color: str, + x: float, + y: float, + size: float, + ) -> str: """Draw the animal's nose based on nose style.""" elements = [] - + if nose_style == "round": - elements.append(self._draw_circle(x, y + size * 0.05, size * 0.08, nose_color)) - + elements.append( + self._draw_circle(x, y + size * 0.05, size * 0.08, nose_color) + ) + elif nose_style == "triangular": points = f"{x},{y+size*0.02} {x-size*0.08},{y+size*0.12} {x+size*0.08},{y+size*0.12}" elements.append(self._draw_polygon(points, nose_color)) - + elif nose_style == "small": - elements.append(self._draw_circle(x, y + size * 0.05, size * 0.05, nose_color)) - + elements.append( + self._draw_circle(x, y + size * 0.05, size * 0.05, nose_color) + ) + elif nose_style == "large": if animal in ["dog", "bear", "panda", "koala"]: - elements.append(self._draw_ellipse(x, y + size * 0.05, size * 0.1, size * 0.08, nose_color)) + elements.append( + self._draw_ellipse( + x, y + size * 0.05, size * 0.1, size * 0.08, nose_color + ) + ) else: - elements.append(self._draw_circle(x, y + size * 0.05, size * 0.1, nose_color)) - + elements.append( + self._draw_circle(x, y + size * 0.05, size * 0.1, nose_color) + ) + elif nose_style == "heart": # Heart-shaped nose cx, cy = x, y + size * 0.05 r = size * 0.06 - path = (f"M {cx} {cy-r*0.2} " - f"C {cx-r*1.5} {cy-r*1.2}, {cx-r*1.8} {cy+r*0.6}, {cx} {cy+r*0.8} " - f"C {cx+r*1.8} {cy+r*0.6}, {cx+r*1.5} {cy-r*1.2}, {cx} {cy-r*0.2} Z") + path = ( + f"M {cx} {cy-r*0.2} " + f"C {cx-r*1.5} {cy-r*1.2}, {cx-r*1.8} {cy+r*0.6}, {cx} {cy+r*0.8} " + f"C {cx+r*1.8} {cy+r*0.6}, {cx+r*1.5} {cy-r*1.2}, {cx} {cy-r*0.2} Z" + ) elements.append(self._draw_path(path, nose_color)) - + elif nose_style == "button": - elements.append(self._draw_circle(x, y + size * 0.05, size * 0.06, nose_color)) - elements.append(self._draw_circle(x, y + size * 0.05, size * 0.04, nose_color, stroke="#000000", stroke_width=size*0.01)) - + elements.append( + self._draw_circle(x, y + size * 0.05, size * 0.06, nose_color) + ) + elements.append( + self._draw_circle( + x, + y + size * 0.05, + size * 0.04, + nose_color, + stroke="#000000", + stroke_width=size * 0.01, + ) + ) + return "\n".join(elements) - + def _draw_mouth(self, expression: str, x: float, y: float, size: float) -> str: """Draw the animal's mouth based on expression.""" elements = [] - + if expression == "happy": - path = (f"M {x-size*0.15} {y+size*0.12} " - f"Q {x} {y+size*0.25}, {x+size*0.15} {y+size*0.12}") - elements.append(self._draw_path(path, "none", stroke="#000000", stroke_width=size*0.02)) - + path = ( + f"M {x-size*0.15} {y+size*0.12} " + f"Q {x} {y+size*0.25}, {x+size*0.15} {y+size*0.12}" + ) + elements.append( + self._draw_path( + path, "none", stroke="#000000", stroke_width=size * 0.02 + ) + ) + elif expression == "serious": - path = (f"M {x-size*0.12} {y+size*0.15} " - f"L {x+size*0.12} {y+size*0.15}") - elements.append(self._draw_path(path, "none", stroke="#000000", stroke_width=size*0.02)) - + path = f"M {x-size*0.12} {y+size*0.15} " f"L {x+size*0.12} {y+size*0.15}" + elements.append( + self._draw_path( + path, "none", stroke="#000000", stroke_width=size * 0.02 + ) + ) + elif expression == "surprised": - elements.append(self._draw_ellipse(x, y + size * 0.15, size * 0.06, size * 0.08, "#FFFFFF", - stroke="#000000", stroke_width=size*0.01)) - + elements.append( + self._draw_ellipse( + x, + y + size * 0.15, + size * 0.06, + size * 0.08, + "#FFFFFF", + stroke="#000000", + stroke_width=size * 0.01, + ) + ) + elif expression == "sleepy": - path = (f"M {x-size*0.08} {y+size*0.15} " - f"Q {x} {y+size*0.12}, {x+size*0.08} {y+size*0.15}") - elements.append(self._draw_path(path, "none", stroke="#000000", stroke_width=size*0.015)) - + path = ( + f"M {x-size*0.08} {y+size*0.15} " + f"Q {x} {y+size*0.12}, {x+size*0.08} {y+size*0.15}" + ) + elements.append( + self._draw_path( + path, "none", stroke="#000000", stroke_width=size * 0.015 + ) + ) + elif expression == "wink": - path = (f"M {x-size*0.15} {y+size*0.12} " - f"Q {x} {y+size*0.22}, {x+size*0.15} {y+size*0.12}") - elements.append(self._draw_path(path, "none", stroke="#000000", stroke_width=size*0.02)) - + path = ( + f"M {x-size*0.15} {y+size*0.12} " + f"Q {x} {y+size*0.22}, {x+size*0.15} {y+size*0.12}" + ) + elements.append( + self._draw_path( + path, "none", stroke="#000000", stroke_width=size * 0.02 + ) + ) + return "\n".join(elements) - - def _draw_special_features(self, animal: str, special_feature: str, face_color: str, - accent_color: str, x: float, y: float, size: float) -> str: + + def _draw_special_features( + self, + animal: str, + special_feature: str, + face_color: str, + accent_color: str, + x: float, + y: float, + size: float, + ) -> str: """Draw special features based on animal and feature type.""" elements = [] - + if special_feature == "none": return "" - + elif special_feature == "whiskers": # Left whiskers for i in range(3): @@ -503,36 +982,58 @@ class AnimalAvatarGenerator: length = size * 0.25 end_x = x - size * 0.15 + length * math.cos(math.radians(angle)) end_y = y + size * 0.05 + length * math.sin(math.radians(angle)) - elements.append(self._draw_path( - f"M {x-size*0.15} {y+size*0.05} L {end_x} {end_y}", - "none", stroke="#000000", stroke_width=size*0.01)) - + elements.append( + self._draw_path( + f"M {x-size*0.15} {y+size*0.05} L {end_x} {end_y}", + "none", + stroke="#000000", + stroke_width=size * 0.01, + ) + ) + # Right whiskers for i in range(3): angle = -150 + i * 30 length = size * 0.25 end_x = x + size * 0.15 + length * math.cos(math.radians(angle)) end_y = y + size * 0.05 + length * math.sin(math.radians(angle)) - elements.append(self._draw_path( - f"M {x+size*0.15} {y+size*0.05} L {end_x} {end_y}", - "none", stroke="#000000", stroke_width=size*0.01)) - + elements.append( + self._draw_path( + f"M {x+size*0.15} {y+size*0.05} L {end_x} {end_y}", + "none", + stroke="#000000", + stroke_width=size * 0.01, + ) + ) + elif special_feature == "stripes": # Vertical stripes if animal == "tiger": for i in range(3): offset = -size * 0.2 + i * size * 0.2 - path = (f"M {x+offset} {y-size*0.3} " - f"Q {x+offset+size*0.1} {y}, {x+offset} {y+size*0.3}") - elements.append(self._draw_path(path, "none", stroke=accent_color, stroke_width=size*0.04)) + path = ( + f"M {x+offset} {y-size*0.3} " + f"Q {x+offset+size*0.1} {y}, {x+offset} {y+size*0.3}" + ) + elements.append( + self._draw_path( + path, "none", stroke=accent_color, stroke_width=size * 0.04 + ) + ) # Horizontal stripes for zebra elif animal == "zebra": for i in range(3): offset = -size * 0.2 + i * size * 0.2 - path = (f"M {x-size*0.3} {y+offset} " - f"Q {x} {y+offset+size*0.1}, {x+size*0.3} {y+offset}") - elements.append(self._draw_path(path, "none", stroke=accent_color, stroke_width=size*0.04)) - + path = ( + f"M {x-size*0.3} {y+offset} " + f"Q {x} {y+offset+size*0.1}, {x+size*0.3} {y+offset}" + ) + elements.append( + self._draw_path( + path, "none", stroke=accent_color, stroke_width=size * 0.04 + ) + ) + elif special_feature == "spots": # Random spots num_spots = random.randint(3, 6) @@ -540,41 +1041,75 @@ class AnimalAvatarGenerator: spot_x = x + random.uniform(-size * 0.3, size * 0.3) spot_y = y + random.uniform(-size * 0.3, size * 0.3) spot_size = random.uniform(size * 0.03, size * 0.08) - elements.append(self._draw_circle(spot_x, spot_y, spot_size, accent_color)) - + elements.append( + self._draw_circle(spot_x, spot_y, spot_size, accent_color) + ) + elif special_feature == "patch": # Eye patch or face patch if animal == "dog": - elements.append(self._draw_ellipse(x - size * 0.2, y, size * 0.2, size * 0.25, accent_color)) + elements.append( + self._draw_ellipse( + x - size * 0.2, y, size * 0.2, size * 0.25, accent_color + ) + ) else: # Generic face patch - elements.append(self._draw_ellipse(x, y + size * 0.2, size * 0.2, size * 0.15, accent_color)) - + elements.append( + self._draw_ellipse( + x, y + size * 0.2, size * 0.2, size * 0.15, accent_color + ) + ) + elif special_feature == "mask": if animal == "raccoon": # Raccoon mask - path = (f"M {x-size*0.3} {y-size*0.1} " - f"Q {x} {y-size*0.3}, {x+size*0.3} {y-size*0.1} " - f"Q {x+size*0.2} {y+size*0.1}, {x} {y+size*0.15} " - f"Q {x-size*0.2} {y+size*0.1}, {x-size*0.3} {y-size*0.1} Z") + path = ( + f"M {x-size*0.3} {y-size*0.1} " + f"Q {x} {y-size*0.3}, {x+size*0.3} {y-size*0.1} " + f"Q {x+size*0.2} {y+size*0.1}, {x} {y+size*0.15} " + f"Q {x-size*0.2} {y+size*0.1}, {x-size*0.3} {y-size*0.1} Z" + ) elements.append(self._draw_path(path, accent_color)) elif animal in ["fox", "wolf"]: # Fox/wolf mask - path = (f"M {x-size*0.3} {y-size*0.1} " - f"L {x} {y+size*0.1} " - f"L {x+size*0.3} {y-size*0.1} " - f"Q {x+size*0.15} {y-size*0.05}, {x} {y-size*0.1} " - f"Q {x-size*0.15} {y-size*0.05}, {x-size*0.3} {y-size*0.1} Z") + path = ( + f"M {x-size*0.3} {y-size*0.1} " + f"L {x} {y+size*0.1} " + f"L {x+size*0.3} {y-size*0.1} " + f"Q {x+size*0.15} {y-size*0.05}, {x} {y-size*0.1} " + f"Q {x-size*0.15} {y-size*0.05}, {x-size*0.3} {y-size*0.1} Z" + ) elements.append(self._draw_path(path, accent_color)) - + elif special_feature == "eye_patches" and animal == "panda": # Panda eye patches - elements.append(self._draw_ellipse(x - size * 0.2, y - size * 0.05, size * 0.15, size * 0.2, accent_color)) - elements.append(self._draw_ellipse(x + size * 0.2, y - size * 0.05, size * 0.15, size * 0.2, accent_color)) - + elements.append( + self._draw_ellipse( + x - size * 0.2, + y - size * 0.05, + size * 0.15, + size * 0.2, + accent_color, + ) + ) + elements.append( + self._draw_ellipse( + x + size * 0.2, + y - size * 0.05, + size * 0.15, + size * 0.2, + accent_color, + ) + ) + elif special_feature == "nose_patch": - elements.append(self._draw_ellipse(x, y + size * 0.05, size * 0.12, size * 0.1, accent_color)) - + elements.append( + self._draw_ellipse( + x, y + size * 0.05, size * 0.12, size * 0.1, accent_color + ) + ) + elif special_feature == "mane" and animal == "lion": # Lion mane for i in range(12): @@ -583,83 +1118,135 @@ class AnimalAvatarGenerator: outer_y = y + size * 0.5 * math.sin(math.radians(angle)) inner_x = x + size * 0.3 * math.cos(math.radians(angle)) inner_y = y + size * 0.3 * math.sin(math.radians(angle)) - + # Draw mane sections - path = (f"M {inner_x} {inner_y} " - f"L {outer_x} {outer_y} " - f"A {size*0.5} {size*0.5} 0 0 1 " - f"{x + size * 0.5 * math.cos(math.radians(angle + 30))} " - f"{y + size * 0.5 * math.sin(math.radians(angle + 30))} " - f"L {x + size * 0.3 * math.cos(math.radians(angle + 30))} " - f"{y + size * 0.3 * math.sin(math.radians(angle + 30))} Z") + path = ( + f"M {inner_x} {inner_y} " + f"L {outer_x} {outer_y} " + f"A {size*0.5} {size*0.5} 0 0 1 " + f"{x + size * 0.5 * math.cos(math.radians(angle + 30))} " + f"{y + size * 0.5 * math.sin(math.radians(angle + 30))} " + f"L {x + size * 0.3 * math.cos(math.radians(angle + 30))} " + f"{y + size * 0.3 * math.sin(math.radians(angle + 30))} Z" + ) elements.append(self._draw_path(path, accent_color)) - + elif special_feature == "tusks" and animal == "elephant": # Elephant tusks - path_left = (f"M {x-size*0.15} {y+size*0.1} " - f"Q {x-size*0.3} {y+size*0.3}, {x-size*0.35} {y+size*0.5}") - elements.append(self._draw_path(path_left, "#FFFFF0", stroke="#F5F5DC", stroke_width=size*0.04)) - - path_right = (f"M {x+size*0.15} {y+size*0.1} " - f"Q {x+size*0.3} {y+size*0.3}, {x+size*0.35} {y+size*0.5}") - elements.append(self._draw_path(path_right, "#FFFFF0", stroke="#F5F5DC", stroke_width=size*0.04)) - + path_left = ( + f"M {x-size*0.15} {y+size*0.1} " + f"Q {x-size*0.3} {y+size*0.3}, {x-size*0.35} {y+size*0.5}" + ) + elements.append( + self._draw_path( + path_left, "#FFFFF0", stroke="#F5F5DC", stroke_width=size * 0.04 + ) + ) + + path_right = ( + f"M {x+size*0.15} {y+size*0.1} " + f"Q {x+size*0.3} {y+size*0.3}, {x+size*0.35} {y+size*0.5}" + ) + elements.append( + self._draw_path( + path_right, "#FFFFF0", stroke="#F5F5DC", stroke_width=size * 0.04 + ) + ) + elif special_feature == "antlers" and animal == "deer": # Deer antlers # Left antler - path_left = (f"M {x-size*0.15} {y-size*0.2} " - f"L {x-size*0.3} {y-size*0.45} " - f"L {x-size*0.4} {y-size*0.4} " - f"M {x-size*0.3} {y-size*0.45} " - f"L {x-size*0.2} {y-size*0.5}") - elements.append(self._draw_path(path_left, "none", stroke=accent_color, stroke_width=size*0.03)) - + path_left = ( + f"M {x-size*0.15} {y-size*0.2} " + f"L {x-size*0.3} {y-size*0.45} " + f"L {x-size*0.4} {y-size*0.4} " + f"M {x-size*0.3} {y-size*0.45} " + f"L {x-size*0.2} {y-size*0.5}" + ) + elements.append( + self._draw_path( + path_left, "none", stroke=accent_color, stroke_width=size * 0.03 + ) + ) + # Right antler - path_right = (f"M {x+size*0.15} {y-size*0.2} " - f"L {x+size*0.3} {y-size*0.45} " - f"L {x+size*0.4} {y-size*0.4} " - f"M {x+size*0.3} {y-size*0.45} " - f"L {x+size*0.2} {y-size*0.5}") - elements.append(self._draw_path(path_right, "none", stroke=accent_color, stroke_width=size*0.03)) - + path_right = ( + f"M {x+size*0.15} {y-size*0.2} " + f"L {x+size*0.3} {y-size*0.45} " + f"L {x+size*0.4} {y-size*0.4} " + f"M {x+size*0.3} {y-size*0.45} " + f"L {x+size*0.2} {y-size*0.5}" + ) + elements.append( + self._draw_path( + path_right, "none", stroke=accent_color, stroke_width=size * 0.03 + ) + ) + elif special_feature == "bushy_tail" and animal == "squirrel": # Squirrel bushy tail - path = (f"M {x+size*0.1} {y+size*0.2} " - f"Q {x+size*0.5} {y}, {x+size*0.3} {y-size*0.3} " - f"Q {x+size*0.4} {y-size*0.4}, {x+size*0.5} {y-size*0.35} " - f"Q {x+size*0.45} {y-size*0.25}, {x+size*0.6} {y-size*0.3} " - f"Q {x+size*0.55} {y-size*0.1}, {x+size*0.4} {y+size*0.1} " - f"Q {x+size*0.3} {y+size*0.3}, {x+size*0.1} {y+size*0.2} Z") + path = ( + f"M {x+size*0.1} {y+size*0.2} " + f"Q {x+size*0.5} {y}, {x+size*0.3} {y-size*0.3} " + f"Q {x+size*0.4} {y-size*0.4}, {x+size*0.5} {y-size*0.35} " + f"Q {x+size*0.45} {y-size*0.25}, {x+size*0.6} {y-size*0.3} " + f"Q {x+size*0.55} {y-size*0.1}, {x+size*0.4} {y+size*0.1} " + f"Q {x+size*0.3} {y+size*0.3}, {x+size*0.1} {y+size*0.2} Z" + ) elements.append(self._draw_path(path, accent_color)) - + elif special_feature == "brush_tail" and animal in ["fox", "wolf"]: # Fox/wolf brush tail - path = (f"M {x+size*0.1} {y+size*0.2} " - f"Q {x+size*0.4} {y+size*0.1}, {x+size*0.5} {y-size*0.1} " - f"Q {x+size*0.6} {y-size*0.2}, {x+size*0.7} {y-size*0.1} " - f"Q {x+size*0.65} {y}, {x+size*0.6} {y+size*0.1} " - f"Q {x+size*0.5} {y+size*0.2}, {x+size*0.3} {y+size*0.3} Z") + path = ( + f"M {x+size*0.1} {y+size*0.2} " + f"Q {x+size*0.4} {y+size*0.1}, {x+size*0.5} {y-size*0.1} " + f"Q {x+size*0.6} {y-size*0.2}, {x+size*0.7} {y-size*0.1} " + f"Q {x+size*0.65} {y}, {x+size*0.6} {y+size*0.1} " + f"Q {x+size*0.5} {y+size*0.2}, {x+size*0.3} {y+size*0.3} Z" + ) elements.append(self._draw_path(path, face_color)) # Tail tip - elements.append(self._draw_ellipse(x + size * 0.6, y - size * 0.05, size * 0.12, size * 0.08, accent_color)) - + elements.append( + self._draw_ellipse( + x + size * 0.6, + y - size * 0.05, + size * 0.12, + size * 0.08, + accent_color, + ) + ) + elif special_feature == "bib" and animal == "penguin": # Penguin bib/chest - path = (f"M {x-size*0.2} {y} " - f"Q {x} {y+size*0.4}, {x+size*0.2} {y} " - f"Q {x} {y+size*0.1}, {x-size*0.2} {y} Z") + path = ( + f"M {x-size*0.2} {y} " + f"Q {x} {y+size*0.4}, {x+size*0.2} {y} " + f"Q {x} {y+size*0.1}, {x-size*0.2} {y} Z" + ) elements.append(self._draw_path(path, "#FFFFFF")) - + elif special_feature == "feather_tufts" and animal == "owl": # Owl feather tufts - path_left = (f"M {x-size*0.1} {y-size*0.3} " - f"Q {x-size*0.15} {y-size*0.45}, {x-size*0.05} {y-size*0.5}") - elements.append(self._draw_path(path_left, face_color, stroke="#000000", stroke_width=size*0.01)) - - path_right = (f"M {x+size*0.1} {y-size*0.3} " - f"Q {x+size*0.15} {y-size*0.45}, {x+size*0.05} {y-size*0.5}") - elements.append(self._draw_path(path_right, face_color, stroke="#000000", stroke_width=size*0.01)) - + path_left = ( + f"M {x-size*0.1} {y-size*0.3} " + f"Q {x-size*0.15} {y-size*0.45}, {x-size*0.05} {y-size*0.5}" + ) + elements.append( + self._draw_path( + path_left, face_color, stroke="#000000", stroke_width=size * 0.01 + ) + ) + + path_right = ( + f"M {x+size*0.1} {y-size*0.3} " + f"Q {x+size*0.15} {y-size*0.45}, {x+size*0.05} {y-size*0.5}" + ) + elements.append( + self._draw_path( + path_right, face_color, stroke="#000000", stroke_width=size * 0.01 + ) + ) + elif special_feature == "spikes" and animal == "hedgehog": # Hedgehog spikes for i in range(12): @@ -668,25 +1255,44 @@ class AnimalAvatarGenerator: inner_y = y + size * 0.3 * math.sin(math.radians(angle)) outer_x = x + size * 0.5 * math.cos(math.radians(angle)) outer_y = y + size * 0.5 * math.sin(math.radians(angle)) - + path = f"M {inner_x} {inner_y} L {outer_x} {outer_y}" - elements.append(self._draw_path(path, "none", stroke=accent_color, stroke_width=size*0.03)) - + elements.append( + self._draw_path( + path, "none", stroke=accent_color, stroke_width=size * 0.03 + ) + ) + elif special_feature == "cheek_patches" and animal == "monkey": # Monkey cheek patches - elements.append(self._draw_circle(x - size * 0.25, y + size * 0.1, size * 0.12, accent_color)) - elements.append(self._draw_circle(x + size * 0.25, y + size * 0.1, size * 0.12, accent_color)) - + elements.append( + self._draw_circle( + x - size * 0.25, y + size * 0.1, size * 0.12, accent_color + ) + ) + elements.append( + self._draw_circle( + x + size * 0.25, y + size * 0.1, size * 0.12, accent_color + ) + ) + return "\n".join(elements) - - def generate_avatar(self, animal: Optional[str] = None, color_palette: str = "natural", - face_shape: Optional[str] = None, eye_style: Optional[str] = None, - ear_style: Optional[str] = None, nose_style: Optional[str] = None, - expression: Optional[str] = None, special_feature: Optional[str] = None, - size: int = 500) -> str: + + def generate_avatar( + self, + animal: Optional[str] = None, + color_palette: str = "natural", + face_shape: Optional[str] = None, + eye_style: Optional[str] = None, + ear_style: Optional[str] = None, + nose_style: Optional[str] = None, + expression: Optional[str] = None, + special_feature: Optional[str] = None, + size: int = 500, + ) -> str: """ Generate an animal avatar with the specified parameters. - + Args: animal: Animal type (e.g., "cat", "dog"). If None, a random animal is selected. color_palette: Color palette to use (e.g., "natural", "pastel", "vibrant", "mono"). @@ -697,95 +1303,115 @@ class AnimalAvatarGenerator: expression: Facial expression. If None, a random expression is selected. special_feature: Special feature to add. If None, a random feature is selected for the animal. size: Size of the avatar in pixels. - + Returns: SVG string representation of the generated avatar. """ # Select random animal if not specified if animal is None or animal not in self.ANIMALS: animal = random.choice(self.ANIMALS) - + # Select random options if not specified if face_shape is None or face_shape not in self.FACE_SHAPES: face_shape = random.choice(self.FACE_SHAPES) - + if eye_style is None or eye_style not in self.EYE_STYLES: eye_style = random.choice(self.EYE_STYLES) - + if ear_style is None: ear_style = self._get_ear_style(animal) - + if nose_style is None or nose_style not in self.NOSE_STYLES: nose_style = random.choice(self.NOSE_STYLES) - + if expression is None or expression not in self.EXPRESSIONS: expression = random.choice(self.EXPRESSIONS) - + if special_feature is None: special_feature = self._get_special_feature(animal) - + # Get colors colors = self._get_colors(animal, color_palette) face_color = random.choice(colors) - + # Make sure accent color is different from face color remaining_colors = [c for c in colors if c != face_color] if not remaining_colors: remaining_colors = ["#000000", "#FFFFFF"] accent_color = random.choice(remaining_colors) - + # Ensure inner ear color is different from face color inner_ear_color = random.choice(remaining_colors) - + # Eye color options eye_colors = ["#000000", "#331800", "#0000FF", "#008000", "#FFA500", "#800080"] eye_color = random.choice(eye_colors) - + # Nose color options based on animal - if animal in ["dog", "cat", "fox", "wolf", "bear", "panda", "koala", "tiger", "lion"]: + if animal in [ + "dog", + "cat", + "fox", + "wolf", + "bear", + "panda", + "koala", + "tiger", + "lion", + ]: nose_color = "#000000" else: nose_color = accent_color - + # Center coordinates x, y = size / 2, size / 2 - + # Generate SVG elements elements = [] - + # Draw the face first elements.append(self._draw_face(animal, face_shape, face_color, x, y, size)) - + # Draw special features behind the face if needed if special_feature in ["mane", "spikes"]: - elements.append(self._draw_special_features(animal, special_feature, face_color, - accent_color, x, y, size)) - + elements.append( + self._draw_special_features( + animal, special_feature, face_color, accent_color, x, y, size + ) + ) + # Draw ears - elements.append(self._draw_ears(animal, ear_style, face_color, inner_ear_color, x, y, size)) - + elements.append( + self._draw_ears(animal, ear_style, face_color, inner_ear_color, x, y, size) + ) + # Draw eyes elements.append(self._draw_eyes(eye_style, expression, eye_color, x, y, size)) - + # Draw nose elements.append(self._draw_nose(animal, nose_style, nose_color, x, y, size)) - + # Draw mouth elements.append(self._draw_mouth(expression, x, y, size)) - + # Draw special features that should be in front if special_feature not in ["mane", "spikes"]: - elements.append(self._draw_special_features(animal, special_feature, face_color, - accent_color, x, y, size)) - + elements.append( + self._draw_special_features( + animal, special_feature, face_color, accent_color, x, y, size + ) + ) + # Assemble SVG - svg_content = '\n'.join(elements) - svg = (f'\n' - f'{svg_content}\n' - f'') - + svg_content = "\n".join(elements) + svg = ( + f'\n' + f"{svg_content}\n" + f"" + ) + return svg - + def get_avatar_options(self) -> Dict[str, List[str]]: """Return all available avatar options.""" return { @@ -793,20 +1419,21 @@ class AnimalAvatarGenerator: "color_palettes": list(self.COLOR_PALETTES.keys()), "face_shapes": self.FACE_SHAPES, "eye_styles": self.EYE_STYLES, - "ear_styles": {animal: styles for animal, styles in self.EAR_STYLES.items()}, + "ear_styles": dict(self.EAR_STYLES.items()), "nose_styles": self.NOSE_STYLES, "expressions": self.EXPRESSIONS, - "special_features": {animal: features for animal, features in self.SPECIAL_FEATURES.items()} + "special_features": dict(self.SPECIAL_FEATURES.items()), } - + def generate_random_avatar(self, size: int = 500) -> str: """Generate a completely random avatar.""" return self.generate_avatar(size=size) + def generate_avatar_with_options(options: Dict) -> str: """Generate an avatar with the given options.""" generator = AnimalAvatarGenerator(seed=options.get("seed")) - + return generator.generate_avatar( animal=options.get("animal"), color_palette=options.get("color_palette", "natural"), @@ -816,40 +1443,58 @@ def generate_avatar_with_options(options: Dict) -> str: nose_style=options.get("nose_style"), expression=options.get("expression"), special_feature=options.get("special_feature"), - size=options.get("size", 500) + size=options.get("size", 500), ) + def list_avatar_options() -> Dict[str, List[str]]: """Return all available avatar options.""" generator = AnimalAvatarGenerator() return generator.get_avatar_options() + def create_avatar_app(): """Command-line interface for the avatar generator.""" parser = argparse.ArgumentParser(description="Generate animal avatars") - parser.add_argument("--animal", help="Animal type", choices=AnimalAvatarGenerator.ANIMALS) - parser.add_argument("--color-palette", help="Color palette", default="natural", - choices=["natural", "pastel", "vibrant", "mono"]) - parser.add_argument("--face-shape", help="Face shape", choices=AnimalAvatarGenerator.FACE_SHAPES) - parser.add_argument("--eye-style", help="Eye style", choices=AnimalAvatarGenerator.EYE_STYLES) + parser.add_argument( + "--animal", help="Animal type", choices=AnimalAvatarGenerator.ANIMALS + ) + parser.add_argument( + "--color-palette", + help="Color palette", + default="natural", + choices=["natural", "pastel", "vibrant", "mono"], + ) + parser.add_argument( + "--face-shape", help="Face shape", choices=AnimalAvatarGenerator.FACE_SHAPES + ) + parser.add_argument( + "--eye-style", help="Eye style", choices=AnimalAvatarGenerator.EYE_STYLES + ) parser.add_argument("--ear-style", help="Ear style") - parser.add_argument("--nose-style", help="Nose style", choices=AnimalAvatarGenerator.NOSE_STYLES) - parser.add_argument("--expression", help="Expression", choices=AnimalAvatarGenerator.EXPRESSIONS) + parser.add_argument( + "--nose-style", help="Nose style", choices=AnimalAvatarGenerator.NOSE_STYLES + ) + parser.add_argument( + "--expression", help="Expression", choices=AnimalAvatarGenerator.EXPRESSIONS + ) parser.add_argument("--special-feature", help="Special feature") parser.add_argument("--size", help="Size in pixels", type=int, default=500) parser.add_argument("--seed", help="Random seed for reproducibility", type=int) parser.add_argument("--output", help="Output file path", default="avatar.svg") - parser.add_argument("--list-options", help="List all available options", action="store_true") - + parser.add_argument( + "--list-options", help="List all available options", action="store_true" + ) + args = parser.parse_args() - + if args.list_options: options = list_avatar_options() print(json.dumps(options, indent=2)) return - + generator = AnimalAvatarGenerator(seed=args.seed) - + svg = generator.generate_avatar( animal=args.animal, color_palette=args.color_palette, @@ -859,13 +1504,14 @@ def create_avatar_app(): nose_style=args.nose_style, expression=args.expression, special_feature=args.special_feature, - size=args.size + size=args.size, ) - + with open(args.output, "w") as f: f.write(svg) - + print(f"Avatar saved to {args.output}") + if __name__ == "__main__": create_avatar_app() diff --git a/src/snek/view/avatar_animal_view.py b/src/snek/view/avatar_animal_view.py index ad81301..ebcca95 100644 --- a/src/snek/view/avatar_animal_view.py +++ b/src/snek/view/avatar_animal_view.py @@ -31,13 +31,12 @@ from multiavatar import multiavatar from snek.system.view import BaseView from snek.view.avatar_animal import generate_avatar_with_options -import functools class AvatarView(BaseView): login_required = False - - def __init__(self, *args,**kwargs): - super().__init__(*args,**kwargs) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.avatars = {} async def get(self): @@ -45,10 +44,9 @@ class AvatarView(BaseView): while True: try: return web.Response(text=self._get(uid), content_type="image/svg+xml") - except Exception as e: + except Exception: pass - def _get(self, uid): if uid in self.avatars: return self.avatars[uid] diff --git a/src/snek/view/channel.py b/src/snek/view/channel.py index 6aba26a..c3b8587 100644 --- a/src/snek/view/channel.py +++ b/src/snek/view/channel.py @@ -1,15 +1,15 @@ import asyncio import mimetypes +import pathlib +import urllib.parse from os.path import isfile -from PIL import Image -import pillow_heif.HeifImagePlugin - -from snek.system.view import BaseView import aiofiles from aiohttp import web -import pathlib -import urllib.parse +from PIL import Image + +from snek.system.view import BaseView + class ChannelAttachmentView(BaseView): async def get(self): @@ -73,18 +73,21 @@ class ChannelAttachmentView(BaseView): # response.write is async but image.save is not naughty_steal = response.write tasks = [] + def sync_writer(*args, **kwargs): tasks.append(naughty_steal(*args, **kwargs)) return True - + setattr(response, "write", sync_writer) - image.save(response, format=format_, quality=100, optimize=True, save_all=True) + image.save( + response, format=format_, quality=100, optimize=True, save_all=True + ) setattr(response, "write", naughty_steal) - + await asyncio.gather(*tasks) - + return response else: response = web.FileResponse(channel_attachment["path"]) @@ -127,7 +130,9 @@ class ChannelAttachmentView(BaseView): attachment_records = [] for attachment in attachments: attachment_record = attachment.record - attachment_record['relative_url'] = urllib.parse.quote(attachment_record['relative_url']) + attachment_record["relative_url"] = urllib.parse.quote( + attachment_record["relative_url"] + ) attachment_records.append(attachment_record) return web.json_response( @@ -138,23 +143,37 @@ class ChannelAttachmentView(BaseView): } ) + class ChannelView(BaseView): async def get(self): channel_name = self.request.match_info.get("channel") - if(channel_name is None): + if channel_name is None: return web.HTTPNotFound() channel = await self.services.channel.get(label="#" + channel_name) - if(channel is None): + if channel is None: channel = await self.services.channel.get(label=channel_name) channel = await self.services.channel.get(channel_name) - if(channel is None): + if channel is None: channel = await self.services.channel.get(label=channel_name) - if(channel is None): - user = await self.services.user.get(uid=self.session.get("uid")) + if channel is None: + user = await self.services.user.get(uid=self.session.get("uid")) is_listed = self.request.query.get("listed", False) == "true" is_private = self.request.query.get("private", False) == "true" - channel = await self.services.channel.create(label=channel_name,created_by_uid=user['uid'],description="No description provided.",tag="user",is_private=is_private,is_listed=is_listed) - channel_member = await self.services.channel_member.create(channel_uid=channel['uid'],user_uid=user['uid'],is_moderator=True,is_read_only=False,is_muted=False,is_banned=False) - - return web.HTTPFound("/channel/{}.html".format(channel["uid"])) + channel = await self.services.channel.create( + label=channel_name, + created_by_uid=user["uid"], + description="No description provided.", + tag="user", + is_private=is_private, + is_listed=is_listed, + ) + await self.services.channel_member.create( + channel_uid=channel["uid"], + user_uid=user["uid"], + is_moderator=True, + is_read_only=False, + is_muted=False, + is_banned=False, + ) + return web.HTTPFound("/channel/{}.html".format(channel["uid"])) diff --git a/src/snek/view/drive.py b/src/snek/view/drive.py index d6de50f..e972539 100644 --- a/src/snek/view/drive.py +++ b/src/snek/view/drive.py @@ -1,20 +1,15 @@ +import mimetypes +import os +import urllib.parse +from datetime import datetime +from pathlib import Path +from urllib.parse import quote, unquote + from aiohttp import web from snek.system.view import BaseView -import os -import mimetypes -from aiohttp import web -from urllib.parse import unquote, quote -from datetime import datetime - - - -from aiohttp import web -from pathlib import Path -import mimetypes, urllib.parse - class DriveView(BaseView): async def get(self): @@ -27,7 +22,7 @@ class DriveView(BaseView): return web.HTTPNotFound(reason="Path not found") if target.is_dir(): - return await self.render_template("drive.html",{"path": rel_path}) + return await self.render_template("drive.html", {"path": rel_path}) if target.is_file(): return web.FileResponse(target) @@ -37,7 +32,7 @@ class DriveApiView(BaseView): target = await self.services.user.get_home_folder(self.session.get("uid")) rel = self.request.query.get("path", "") offset = int(self.request.query.get("offset", 0)) - limit = int(self.request.query.get("limit", 20)) + limit = int(self.request.query.get("limit", 20)) if rel: target = target.joinpath(rel) @@ -47,58 +42,54 @@ class DriveApiView(BaseView): if target.is_dir(): entries = [] - for p in sorted(target.iterdir(), key=lambda p: (p.is_file(), p.name.lower())): + for p in sorted( + target.iterdir(), key=lambda p: (p.is_file(), p.name.lower()) + ): item_path = (Path(rel) / p.name).as_posix() - mime = mimetypes.guess_type(p.name)[0] if p.is_file() else "inode/directory" - url = (self.request.url.with_path(f"/drive/{urllib.parse.quote(item_path)}") - if p.is_file() else None) - entries.append({ - "name": p.name, - "type": "directory" if p.is_dir() else "file", - "mimetype": mime, - "size": p.stat().st_size if p.is_file() else None, - "path": item_path, - "url": url, - }) - import json + mime = ( + mimetypes.guess_type(p.name)[0] + if p.is_file() + else "inode/directory" + ) + url = ( + self.request.url.with_path( + f"/drive/{urllib.parse.quote(item_path)}" + ) + if p.is_file() + else None + ) + entries.append( + { + "name": p.name, + "type": "directory" if p.is_dir() else "file", + "mimetype": mime, + "size": p.stat().st_size if p.is_file() else None, + "path": item_path, + "url": url, + } + ) + import json + total = len(entries) - items = entries[offset:offset+limit] - return web.json_response({ - "items": json.loads(json.dumps(items,default=str)), - "pagination": {"offset": offset, "limit": limit, "total": total} - }) - + items = entries[offset : offset + limit] + return web.json_response( + { + "items": json.loads(json.dumps(items, default=str)), + "pagination": {"offset": offset, "limit": limit, "total": total}, + } + ) + url = self.request.url.with_path(f"/drive/{urllib.parse.quote(rel)}") - return web.json_response({ - "name": target.name, - "type": "file", - "mimetype": mimetypes.guess_type(target.name)[0], - "size": target.stat().st_size, - "path": rel, - "url": str(url), - }) - - - - - - - - - - - - - - - - - - - - - - + return web.json_response( + { + "name": target.name, + "type": "file", + "mimetype": mimetypes.guess_type(target.name)[0], + "size": target.stat().st_size, + "path": rel, + "url": str(url), + } + ) class DriveView222(BaseView): @@ -124,7 +115,11 @@ class DriveView222(BaseView): entry_path = os.path.join(dir_path, entry) stat = os.stat(entry_path) is_dir = os.path.isdir(entry_path) - mimetype = None if is_dir else (mimetypes.guess_type(entry_path)[0] or "application/octet-stream") + mimetype = ( + None + if is_dir + else (mimetypes.guess_type(entry_path)[0] or "application/octet-stream") + ) size = stat.st_size if not is_dir else None created_at = datetime.fromtimestamp(stat.st_ctime).isoformat() updated_at = datetime.fromtimestamp(stat.st_mtime).isoformat() @@ -155,15 +150,20 @@ class DriveView222(BaseView): start = (page - 1) * page_size end = start + page_size paged_entries = entries[start:end] - details = [await self.entry_details(full_path, entry, rel_path) for entry in paged_entries] - return web.json_response({ - "path": rel_path, - "absolute_url": abs_url, - "entries": details, - "total": len(entries), - "page": page, - "page_size": page_size, - }) + details = [ + await self.entry_details(full_path, entry, rel_path) + for entry in paged_entries + ] + return web.json_response( + { + "path": rel_path, + "absolute_url": abs_url, + "entries": details, + "total": len(entries), + "page": page, + "page_size": page_size, + } + ) else: with open(full_path, "rb") as f: content = f.read() @@ -180,14 +180,18 @@ class DriveView222(BaseView): data = await self.request.post() if data.get("type") == "dir": os.makedirs(full_path) - return web.json_response({"status": "created", "type": "dir", "absolute_url": abs_url}) + return web.json_response( + {"status": "created", "type": "dir", "absolute_url": abs_url} + ) else: file_field = data.get("file") if not file_field: raise web.HTTPBadRequest(reason="No file uploaded") with open(full_path, "wb") as f: f.write(file_field.file.read()) - return web.json_response({"status": "created", "type": "file", "absolute_url": abs_url}) + return web.json_response( + {"status": "created", "type": "file", "absolute_url": abs_url} + ) async def put(self): rel_path = self.request.match_info.get("rel_path", "") @@ -210,10 +214,14 @@ class DriveView222(BaseView): raise web.HTTPNotFound(reason="Path not found") if os.path.isdir(full_path): os.rmdir(full_path) - return web.json_response({"status": "deleted", "type": "dir", "absolute_url": abs_url}) + return web.json_response( + {"status": "deleted", "type": "dir", "absolute_url": abs_url} + ) else: os.remove(full_path) - return web.json_response({"status": "deleted", "type": "file", "absolute_url": abs_url}) + return web.json_response( + {"status": "deleted", "type": "file", "absolute_url": abs_url} + ) class DriveViewi2(BaseView): @@ -223,20 +231,17 @@ class DriveViewi2(BaseView): async def get(self): drive_uid = self.request.match_info.get("drive") - before = self.request.query.get("before") - filters = {} + filters = {} if before: filters["created_at__lt"] = before if drive_uid: - filters['drive_uid'] = drive_uid + filters["drive_uid"] = drive_uid drive = await self.services.drive.get(uid=drive_uid) drive_items = [] - - - + async for item in self.services.drive_item.find(**filters): record = item.record record["url"] = "/drive.bin/" + record["uid"] + "." + item.extension diff --git a/src/snek/view/push.py b/src/snek/view/push.py index c2ba644..169ef63 100644 --- a/src/snek/view/push.py +++ b/src/snek/view/push.py @@ -82,4 +82,4 @@ class PushView(BaseFormView): print(post_notification.text) print(post_notification.headers) - return await self.json_response({ "registered": True }) \ No newline at end of file + return await self.json_response({"registered": True}) diff --git a/src/snek/view/repository.py b/src/snek/view/repository.py index 60fa1bb..4c78e4d 100644 --- a/src/snek/view/repository.py +++ b/src/snek/view/repository.py @@ -1,14 +1,13 @@ -import os +import asyncio import mimetypes +import os import urllib.parse from pathlib import Path + import humanize from aiohttp import web + from snek.system.view import BaseView -import asyncio - - - class BareRepoNavigator: @@ -25,18 +24,18 @@ class BareRepoNavigator: except Exception as e: print(f"Error opening repository: {str(e)}") sys.exit(1) - + self.repo_path = repo_path self.branches = list(self.repo.branches) self.current_branch = None self.current_commit = None self.current_path = "" self.history = [] - + def get_branches(self): """Return a list of branch names in the repository.""" return [branch.name for branch in self.branches] - + def set_branch(self, branch_name): """Set the current branch.""" try: @@ -47,23 +46,27 @@ class BareRepoNavigator: return True except IndexError: return False - + def get_commits(self, count=10): """Get the latest commits on the current branch.""" if not self.current_branch: return [] - + commits = [] for commit in self.repo.iter_commits(self.current_branch, max_count=count): - commits.append({ - 'hash': commit.hexsha, - 'short_hash': commit.hexsha[:7], - 'message': commit.message.strip(), - 'author': commit.author.name, - 'date': datetime.fromtimestamp(commit.committed_date).strftime('%Y-%m-%d %H:%M:%S') - }) + commits.append( + { + "hash": commit.hexsha, + "short_hash": commit.hexsha[:7], + "message": commit.message.strip(), + "author": commit.author.name, + "date": datetime.fromtimestamp(commit.committed_date).strftime( + "%Y-%m-%d %H:%M:%S" + ), + } + ) return commits - + def set_commit(self, commit_hash): """Set the current commit by hash.""" try: @@ -73,48 +76,48 @@ class BareRepoNavigator: return True except ValueError: return False - + def list_directory(self, path=""): """List the contents of a directory in the current commit.""" if not self.current_commit: - return {'dirs': [], 'files': []} - + return {"dirs": [], "files": []} + dirs = [] files = [] - + try: # Get the tree at the current path if path: tree = self.current_commit.tree[path] - if not hasattr(tree, 'trees'): # It's a blob, not a tree - return {'dirs': [], 'files': [path]} + if not hasattr(tree, "trees"): # It's a blob, not a tree + return {"dirs": [], "files": [path]} else: tree = self.current_commit.tree - + # List directories and files for item in tree: - if item.type == 'tree': + if item.type == "tree": item_path = os.path.join(path, item.name) if path else item.name dirs.append(item_path) - elif item.type == 'blob': + elif item.type == "blob": item_path = os.path.join(path, item.name) if path else item.name files.append(item_path) - + dirs.sort() files.sort() - return {'dirs': dirs, 'files': files} - + return {"dirs": dirs, "files": files} + except KeyError: - return {'dirs': [], 'files': []} - + return {"dirs": [], "files": []} + def get_file_content(self, file_path): """Get the content of a file in the current commit.""" if not self.current_commit: return None - + try: blob = self.current_commit.tree[file_path] - return blob.data_stream.read().decode('utf-8', errors='replace') + return blob.data_stream.read().decode("utf-8", errors="replace") except (KeyError, UnicodeDecodeError): try: # Try to get as binary if text decoding fails @@ -122,12 +125,12 @@ class BareRepoNavigator: return blob.data_stream.read() except: return None - + def navigate_to(self, path): """Navigate to a specific path, updating the current path.""" if not self.current_commit: return False - + try: if path: self.current_commit.tree[path] # Check if path exists @@ -136,7 +139,7 @@ class BareRepoNavigator: return True except KeyError: return False - + def navigate_back(self): """Navigate back to the previous path.""" if self.history: @@ -145,12 +148,13 @@ class BareRepoNavigator: return False - class RepositoryView(BaseView): login_required = True - def checkout_bare_repo(self, bare_repo_path: Path, target_path: Path, ref: str = 'HEAD'): + def checkout_bare_repo( + self, bare_repo_path: Path, target_path: Path, ref: str = "HEAD" + ): repo = Repo(bare_repo_path) assert repo.bare, "Repository is not bare." @@ -159,23 +163,22 @@ class RepositoryView(BaseView): for blob in tree.traverse(): target_file = target_path / blob.path - + target_file.parent.mkdir(parents=True, exist_ok=True) print(blob.path) - with open(target_file, 'wb') as f: + with open(target_file, "wb") as f: f.write(blob.data_stream.read()) - async def get(self): - base_repo_path = Path("drive/repositories") + base_repo_path = Path("drive/repositories") authenticated_user_id = self.session.get("uid") - username = self.request.match_info.get('username') - repo_name = self.request.match_info.get('repository') - rel_path = self.request.match_info.get('path', '') + username = self.request.match_info.get("username") + repo_name = self.request.match_info.get("repository") + rel_path = self.request.match_info.get("path", "") user = None if not username.count("-") == 4: user = await self.app.services.user.get(username=username) @@ -185,22 +188,24 @@ class RepositoryView(BaseView): else: user = await self.app.services.user.get(uid=username) - repo = await self.app.services.repository.get(name=repo_name, user_uid=user["uid"]) + repo = await self.app.services.repository.get( + name=repo_name, user_uid=user["uid"] + ) if not repo: return web.Response(text="404 Not Found", status=404) - if repo['is_private'] and authenticated_user_id != repo['uid']: - return web.Response(text="404 Not Found", status=404) + if repo["is_private"] and authenticated_user_id != repo["uid"]: + return web.Response(text="404 Not Found", status=404) - repo_root_base = (base_repo_path / user['uid'] / (repo_name + ".git")).resolve() - repo_root = (base_repo_path / user['uid'] / repo_name).resolve() + repo_root_base = (base_repo_path / user["uid"] / (repo_name + ".git")).resolve() + repo_root = (base_repo_path / user["uid"] / repo_name).resolve() try: loop = asyncio.get_event_loop() - await loop.run_in_executor(None, - self.checkout_bare_repo, repo_root_base, repo_root + await loop.run_in_executor( + None, self.checkout_bare_repo, repo_root_base, repo_root ) except: pass - + if not repo_root.exists() or not repo_root.is_dir(): return web.Response(text="404 Not Found", status=404) @@ -211,24 +216,35 @@ class RepositoryView(BaseView): return web.Response(text="404 Not Found", status=404) if abs_path.is_dir(): - return web.Response(text=self.render_directory(abs_path, username, repo_name, safe_rel_path), content_type='text/html') + return web.Response( + text=self.render_directory( + abs_path, username, repo_name, safe_rel_path + ), + content_type="text/html", + ) else: - return web.Response(text=self.render_file(abs_path), content_type='text/html') + return web.Response( + text=self.render_file(abs_path), content_type="text/html" + ) def render_directory(self, abs_path, username, repo_name, safe_rel_path): - entries = sorted(abs_path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())) + entries = sorted( + abs_path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()) + ) items = [] if safe_rel_path: parent_path = Path(safe_rel_path).parent - parent_link = f"/repository/{username}/{repo_name}/{parent_path}".rstrip('/') + parent_link = f"/repository/{username}/{repo_name}/{parent_path}".rstrip( + "/" + ) items.append(f'
  • ⬅️ ..
  • ') for entry in entries: link_path = urllib.parse.quote(str(Path(safe_rel_path) / entry.name)) - link = f"/repository/{username}/{repo_name}/{link_path}".rstrip('/') - display = entry.name + ('/' if entry.is_dir() else '') - size = '' if entry.is_dir() else humanize.naturalsize(entry.stat().st_size) + link = f"/repository/{username}/{repo_name}/{link_path}".rstrip("/") + display = entry.name + ("/" if entry.is_dir() else "") + size = "" if entry.is_dir() else humanize.naturalsize(entry.stat().st_size) icon = self.get_icon(entry) items.append(f'
  • {icon} {display} {size}
  • ') @@ -247,19 +263,24 @@ class RepositoryView(BaseView): def render_file(self, abs_path): try: - with open(abs_path, 'r', encoding='utf-8', errors='ignore') as f: + with open(abs_path, encoding="utf-8", errors="ignore") as f: content = f.read() return f"
    {content}
    " except Exception as e: return f"

    Error

    {e}
    " def get_icon(self, file): - if file.is_dir(): return "📁" - mime = mimetypes.guess_type(file.name)[0] or '' - if mime.startswith("image"): return "🖼️" - if mime.startswith("text"): return "📄" - if mime.startswith("audio"): return "🎵" - if mime.startswith("video"): return "🎬" - if file.name.endswith(".py"): return "🐍" + if file.is_dir(): + return "📁" + mime = mimetypes.guess_type(file.name)[0] or "" + if mime.startswith("image"): + return "🖼️" + if mime.startswith("text"): + return "📄" + if mime.startswith("audio"): + return "🎵" + if mime.startswith("video"): + return "🎬" + if file.name.endswith(".py"): + return "🐍" return "📦" - diff --git a/src/snek/view/rpc.py b/src/snek/view/rpc.py index 7227e27..a31c45b 100644 --- a/src/snek/view/rpc.py +++ b/src/snek/view/rpc.py @@ -7,19 +7,20 @@ # MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions. -import json -import traceback import asyncio +import json +import logging +import traceback + from aiohttp import web from snek.system.model import now from snek.system.profiler import Profiler from snek.system.view import BaseView -import logging - logger = logging.getLogger(__name__) + class RPCView(BaseView): class RPCApi: def __init__(self, view, ws): @@ -32,7 +33,7 @@ class RPCView(BaseView): async def _session_ensure(self): uid = await self.view.session_get("uid") - if not uid in self.user_session: + if uid not in self.user_session: self.user_session[uid] = { "said_hello": False, } @@ -50,45 +51,54 @@ class RPCView(BaseView): self._require_login() return await self.services.db.insert(self.user_uid, table_name, record) + async def db_update(self, table_name, record): self._require_login() return await self.services.db.update(self.user_uid, table_name, record) - async def set_typing(self,channel_uid,color=None): + + async def set_typing(self, channel_uid, color=None): self._require_login() user = await self.services.user.get(self.user_uid) if not color: color = user["color"] - return await self.services.socket.broadcast(channel_uid, { - "channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2", - "event": "set_typing", - "data": { - "event":"set_typing", - "user_uid": user['uid'], - "username": user["username"], - "nick": user["nick"], - "channel_uid": channel_uid, - "color": color - } - }) + return await self.services.socket.broadcast( + channel_uid, + { + "channel_uid": "293ecf12-08c9-494b-b423-48ba1a2d12c2", + "event": "set_typing", + "data": { + "event": "set_typing", + "user_uid": user["uid"], + "username": user["username"], + "nick": user["nick"], + "channel_uid": channel_uid, + "color": color, + }, + }, + ) async def db_delete(self, table_name, record): self._require_login() return await self.services.db.delete(self.user_uid, table_name, record) + async def db_get(self, table_name, record): self._require_login() return await self.services.db.get(self.user_uid, table_name, record) + async def db_find(self, table_name, record): self._require_login() return await self.services.db.find(self.user_uid, table_name, record) - async def db_upsert(self, table_name, record,keys): + + async def db_upsert(self, table_name, record, keys): self._require_login() - return await self.services.db.upsert(self.user_uid, table_name, record,keys) + return await self.services.db.upsert( + self.user_uid, table_name, record, keys + ) async def db_query(self, table_name, args): self._require_login() return await self.services.db.query(self.user_uid, table_name, sql, args) - @property def user_uid(self): return self.view.session.get("uid") @@ -195,52 +205,60 @@ class RPCView(BaseView): async def finalize_message(self, message_uid): self._require_login() - message = await self.services.channel_message.get(message_uid) + message = await self.services.channel_message.get(message_uid) if not message: - return False + return False if message["user_uid"] != self.user_uid: raise Exception("Not allowed") - - - if not message['is_final']: - await self.services.chat.finalize(message['uid']) + + if not message["is_final"]: + await self.services.chat.finalize(message["uid"]) return True - async def update_message_text(self,message_uid, text): + async def update_message_text(self, message_uid, text): self._require_login() - message = await self.services.channel_message.get(message_uid) + message = await self.services.channel_message.get(message_uid) if message["user_uid"] != self.user_uid: raise Exception("Not allowed") - + if message.get_seconds_since_last_update() > 3: await self.finalize_message(message["uid"]) - return {"error": "Message too old","seconds_since_last_update": message.get_seconds_since_last_update(),"success": False} + return { + "error": "Message too old", + "seconds_since_last_update": message.get_seconds_since_last_update(), + "success": False, + } - message['message'] = text + message["message"] = text if not text: - message['deleted_at'] = now() + message["deleted_at"] = now() else: - message['deleted_at'] = None + message["deleted_at"] = None await self.services.channel_message.save(message) - data = message.record - data['text'] = message["message"] - data['message_uid'] = message_uid + data = message.record + data["text"] = message["message"] + data["message_uid"] = message_uid + + await self.services.socket.broadcast( + message["channel_uid"], + { + "channel_uid": message["channel_uid"], + "event": "update_message_text", + "data": message.record, + }, + ) - await self.services.socket.broadcast(message["channel_uid"], { - "channel_uid": message["channel_uid"], - "event": "update_message_text", - "data": message.record - }) - return {"success": True} - async def send_message(self, channel_uid, message,is_final=True): + async def send_message(self, channel_uid, message, is_final=True): self._require_login() - message = await self.services.chat.send(self.user_uid, channel_uid, message,is_final) - + message = await self.services.chat.send( + self.user_uid, channel_uid, message, is_final + ) + return message["uid"] async def echo(self, *args): @@ -330,7 +348,8 @@ class RPCView(BaseView): self._require_login() results = [ - record async for record in self.services.channel.get_recent_users(channel_uid) + record + async for record in self.services.channel.get_recent_users(channel_uid) ] results = sorted(results, key=lambda x: x["nick"]) return results @@ -371,7 +390,6 @@ class RPCView(BaseView): await self.services.socket.send_to_user(self.user_uid, call) self._scheduled.remove(call) - async def ping(self, callId, *args): if self.user_uid: user = await self.services.user.get(uid=self.user_uid) @@ -379,17 +397,23 @@ class RPCView(BaseView): await self.services.user.save(user) return {"pong": args} - async def stars_render(self, channel_uid, message): - + for user in await self.get_online_users(channel_uid): try: - await self.services.socket.send_to_user(user['uid'], dict(event="stars_render", data={"channel_uid": channel_uid, "message":message})) + await self.services.socket.send_to_user( + user["uid"], + { + "event": "stars_render", + "data": {"channel_uid": channel_uid, "message": message}, + }, + ) except Exception as ex: print(ex) - + async def get(self): scheduled = [] + async def schedule(uid, seconds, call): scheduled.append(call) await asyncio.sleep(seconds) @@ -408,16 +432,18 @@ class RPCView(BaseView): await self.services.socket.subscribe( ws, subscription["channel_uid"], self.request.session.get("uid") ) - if not scheduled and self.request.app.uptime_seconds < 5: - await schedule(self.request.session.get("uid"),0,{"event":"refresh", "data": { - "message": "Finishing deployment"} - } - ) - await schedule(self.request.session.get("uid"),15,{"event": "deployed", "data": { - "uptime": self.request.app.uptime} - } + if not scheduled and self.request.app.uptime_seconds < 5: + await schedule( + self.request.session.get("uid"), + 0, + {"event": "refresh", "data": {"message": "Finishing deployment"}}, ) - + await schedule( + self.request.session.get("uid"), + 15, + {"event": "deployed", "data": {"uptime": self.request.app.uptime}}, + ) + rpc = RPCView.RPCApi(self, ws) async for msg in ws: if msg.type == web.WSMsgType.TEXT: diff --git a/src/snek/view/settings/containers.py b/src/snek/view/settings/containers.py index dae0237..b6c4bfc 100644 --- a/src/snek/view/settings/containers.py +++ b/src/snek/view/settings/containers.py @@ -1,45 +1,48 @@ -import asyncio from aiohttp import web from snek.system.view import BaseFormView -import pathlib + class ContainersIndexView(BaseFormView): login_required = True async def get(self): - + user_uid = self.session.get("uid") - + containers = [] async for container in self.services.container.find(user_uid=user_uid): containers.append(container.record) - + user = await self.services.user.get(uid=self.session.get("uid")) - return await self.render_template("settings/containers/index.html", {"containers": containers, "user": user}) + return await self.render_template( + "settings/containers/index.html", {"containers": containers, "user": user} + ) + class ContainersCreateView(BaseFormView): login_required = True async def get(self): - + return await self.render_template("settings/containers/create.html") async def post(self): data = await self.request.post() - container = await self.services.container.create( + await self.services.container.create( user_uid=self.session.get("uid"), - name=data['name'], - status=data['status'], - resources=data.get('resources', ''), - path=data.get('path', ''), - readonly=bool(data.get('readonly', False)) + name=data["name"], + status=data["status"], + resources=data.get("resources", ""), + path=data.get("path", ""), + readonly=bool(data.get("readonly", False)), ) return web.HTTPFound("/settings/containers/index.html") + class ContainersUpdateView(BaseFormView): login_required = True @@ -51,40 +54,43 @@ class ContainersUpdateView(BaseFormView): ) if not container: return web.HTTPNotFound() - return await self.render_template("settings/containers/update.html", {"container": container.record}) + return await self.render_template( + "settings/containers/update.html", {"container": container.record} + ) async def post(self): data = await self.request.post() container = await self.services.container.get( user_uid=self.session.get("uid"), uid=self.request.match_info["uid"] ) - container['status'] = data['status'] - container['resources'] = data.get('resources', '') - container['path'] = data.get('path', '') - container['readonly'] = bool(data.get('readonly', False)) + container["status"] = data["status"] + container["resources"] = data.get("resources", "") + container["path"] = data.get("path", "") + container["readonly"] = bool(data.get("readonly", False)) await self.services.container.save(container) return web.HTTPFound("/settings/containers/index.html") + class ContainersDeleteView(BaseFormView): login_required = True async def get(self): - + container = await self.services.container.get( user_uid=self.session.get("uid"), uid=self.request.match_info["uid"] ) if not container: return web.HTTPNotFound() - return await self.render_template("settings/containers/delete.html", {"container": container.record}) + return await self.render_template( + "settings/containers/delete.html", {"container": container.record} + ) async def post(self): user_uid = self.session.get("uid") uid = self.request.match_info["uid"] - container = await self.services.container.get( - user_uid=user_uid, uid=uid - ) + container = await self.services.container.get(user_uid=user_uid, uid=uid) if not container: return web.HTTPNotFound() await self.services.container.delete(user_uid=user_uid, uid=uid) diff --git a/src/snek/view/settings/repositories.py b/src/snek/view/settings/repositories.py index 093d229..dcc945b 100644 --- a/src/snek/view/settings/repositories.py +++ b/src/snek/view/settings/repositories.py @@ -1,26 +1,26 @@ -import asyncio from aiohttp import web from snek.system.view import BaseFormView -import pathlib + class RepositoriesIndexView(BaseFormView): login_required = True async def get(self): - + user_uid = self.session.get("uid") - + repositories = [] async for repository in self.services.repository.find(user_uid=user_uid): repositories.append(repository.record) - + user = await self.services.user.get(uid=self.session.get("uid")) - return await self.render_template("settings/repositories/index.html", {"repositories": repositories, "user": user}) - - + return await self.render_template( + "settings/repositories/index.html", + {"repositories": repositories, "user": user}, + ) class RepositoriesCreateView(BaseFormView): @@ -28,14 +28,19 @@ class RepositoriesCreateView(BaseFormView): login_required = True async def get(self): - + return await self.render_template("settings/repositories/create.html") async def post(self): data = await self.request.post() - repository = await self.services.repository.create(user_uid=self.session.get("uid"), name=data['name'], is_private=int(data.get('is_private',0))) + await self.services.repository.create( + user_uid=self.session.get("uid"), + name=data["name"], + is_private=int(data.get("is_private", 0)), + ) return web.HTTPFound("/settings/repositories/index.html") + class RepositoriesUpdateView(BaseFormView): login_required = True @@ -47,40 +52,41 @@ class RepositoriesUpdateView(BaseFormView): ) if not repository: return web.HTTPNotFound() - return await self.render_template("settings/repositories/update.html", {"repository": repository.record}) + return await self.render_template( + "settings/repositories/update.html", {"repository": repository.record} + ) async def post(self): data = await self.request.post() repository = await self.services.repository.get( user_uid=self.session.get("uid"), name=self.request.match_info["name"] ) - repository['is_private'] = int(data.get('is_private',0)) + repository["is_private"] = int(data.get("is_private", 0)) await self.services.repository.save(repository) return web.HTTPFound("/settings/repositories/index.html") + class RepositoriesDeleteView(BaseFormView): login_required = True async def get(self): - + repository = await self.services.repository.get( user_uid=self.session.get("uid"), name=self.request.match_info["name"] ) if not repository: return web.HTTPNotFound() - return await self.render_template("settings/repositories/delete.html", {"repository": repository.record}) + return await self.render_template( + "settings/repositories/delete.html", {"repository": repository.record} + ) async def post(self): user_uid = self.session.get("uid") name = self.request.match_info["name"] - repository = await self.services.repository.get( - user_uid=user_uid, name=name - ) + repository = await self.services.repository.get(user_uid=user_uid, name=name) if not repository: return web.HTTPNotFound() await self.services.repository.delete(user_uid=user_uid, name=name) return web.HTTPFound("/settings/repositories/index.html") - -